From 0f5e5ee80a770158c461c119f9251fec1c5cf020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?= =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?= =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru> Date: Wed, 26 Mar 2025 19:14:17 +0300 Subject: [PATCH] pereobuch --- .ipynb_checkpoints/pereobuch-checkpoint.py | 21 ++++++++++++++------- pereobuch.py | 21 ++++++++++++++------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/.ipynb_checkpoints/pereobuch-checkpoint.py b/.ipynb_checkpoints/pereobuch-checkpoint.py index f87b138..f10fc56 100644 --- a/.ipynb_checkpoints/pereobuch-checkpoint.py +++ b/.ipynb_checkpoints/pereobuch-checkpoint.py @@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module): # Функции потерь self.safety_loss = nn.CrossEntropyLoss() self.attack_loss = nn.CrossEntropyLoss() + + self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE) - def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None): + def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, + labels_safety=None, labels_attack=None): + # Поддержка обоих вариантов ввода outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, + inputs_embeds=inputs_embeds, return_dict=True ) pooled_output = outputs.last_hidden_state[:, 0, :] - - # Предсказания logits_safety = self.safety_classifier(pooled_output) logits_attack = self.attack_classifier(pooled_output) - # Расчет потерь loss = None if labels_safety is not None: - loss_safety = self.safety_loss(logits_safety, labels_safety) + loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety) - # Потери для атак только для unsafe текстов mask = (labels_safety == 1) if mask.any(): - loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask]) + loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)( + logits_attack[mask], + labels_attack[mask] + ) loss = loss_safety + 0.5 * loss_attack else: loss = loss_safety @@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module): 'loss': loss } + # 4. Метрики def compute_metrics(p): preds_safety = np.argmax(p.predictions[0], axis=1) @@ -879,6 +884,7 @@ def main(): metric_for_best_model="safety_f1", greater_is_better=True, fp16=True, + remove_unused_columns=True, # Убедитесь, что это True ) # Обучение @@ -889,6 +895,7 @@ def main(): eval_dataset=val_dataset, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=2)] + label_names=["labels_safety", "labels_attack"] ) trainer.train() diff --git a/pereobuch.py b/pereobuch.py index f87b138..f10fc56 100644 --- a/pereobuch.py +++ b/pereobuch.py @@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module): # Функции потерь self.safety_loss = nn.CrossEntropyLoss() self.attack_loss = nn.CrossEntropyLoss() + + self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE) - def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None): + def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, + labels_safety=None, labels_attack=None): + # Поддержка обоих вариантов ввода outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, + inputs_embeds=inputs_embeds, return_dict=True ) pooled_output = outputs.last_hidden_state[:, 0, :] - - # Предсказания logits_safety = self.safety_classifier(pooled_output) logits_attack = self.attack_classifier(pooled_output) - # Расчет потерь loss = None if labels_safety is not None: - loss_safety = self.safety_loss(logits_safety, labels_safety) + loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety) - # Потери для атак только для unsafe текстов mask = (labels_safety == 1) if mask.any(): - loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask]) + loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)( + logits_attack[mask], + labels_attack[mask] + ) loss = loss_safety + 0.5 * loss_attack else: loss = loss_safety @@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module): 'loss': loss } + # 4. Метрики def compute_metrics(p): preds_safety = np.argmax(p.predictions[0], axis=1) @@ -879,6 +884,7 @@ def main(): metric_for_best_model="safety_f1", greater_is_better=True, fp16=True, + remove_unused_columns=True, # Убедитесь, что это True ) # Обучение @@ -889,6 +895,7 @@ def main(): eval_dataset=val_dataset, compute_metrics=compute_metrics, callbacks=[EarlyStoppingCallback(early_stopping_patience=2)] + label_names=["labels_safety", "labels_attack"] ) trainer.train() -- GitLab