From 0f5e5ee80a770158c461c119f9251fec1c5cf020 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?=
 =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?=
 =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru>
Date: Wed, 26 Mar 2025 19:14:17 +0300
Subject: [PATCH] pereobuch

---
 .ipynb_checkpoints/pereobuch-checkpoint.py | 21 ++++++++++++++-------
 pereobuch.py                               | 21 ++++++++++++++-------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/.ipynb_checkpoints/pereobuch-checkpoint.py b/.ipynb_checkpoints/pereobuch-checkpoint.py
index f87b138..f10fc56 100644
--- a/.ipynb_checkpoints/pereobuch-checkpoint.py
+++ b/.ipynb_checkpoints/pereobuch-checkpoint.py
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
         # Функции потерь
         self.safety_loss = nn.CrossEntropyLoss()
         self.attack_loss = nn.CrossEntropyLoss()
+        
+        self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE)
     
-    def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None):
+    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, 
+                labels_safety=None, labels_attack=None):
+        # Поддержка обоих вариантов ввода
         outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             return_dict=True
         )
         
         pooled_output = outputs.last_hidden_state[:, 0, :]
-        
-        # Предсказания
         logits_safety = self.safety_classifier(pooled_output)
         logits_attack = self.attack_classifier(pooled_output)
         
-        # Расчет потерь
         loss = None
         if labels_safety is not None:
-            loss_safety = self.safety_loss(logits_safety, labels_safety)
+            loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
             
-            # Потери для атак только для unsafe текстов
             mask = (labels_safety == 1)
             if mask.any():
-                loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask])
+                loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)(
+                    logits_attack[mask],
+                    labels_attack[mask]
+                )
                 loss = loss_safety + 0.5 * loss_attack
             else:
                 loss = loss_safety
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
             'loss': loss
         }
 
+
 # 4. Метрики
 def compute_metrics(p):
     preds_safety = np.argmax(p.predictions[0], axis=1)
@@ -879,6 +884,7 @@ def main():
         metric_for_best_model="safety_f1",
         greater_is_better=True,
         fp16=True,
+        remove_unused_columns=True,  # Убедитесь, что это True
     )
     
     # Обучение
@@ -889,6 +895,7 @@ def main():
         eval_dataset=val_dataset,
         compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
+        label_names=["labels_safety", "labels_attack"] 
     )
     
     trainer.train()
diff --git a/pereobuch.py b/pereobuch.py
index f87b138..f10fc56 100644
--- a/pereobuch.py
+++ b/pereobuch.py
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
         # Функции потерь
         self.safety_loss = nn.CrossEntropyLoss()
         self.attack_loss = nn.CrossEntropyLoss()
+        
+        self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE)
     
-    def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None):
+    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, 
+                labels_safety=None, labels_attack=None):
+        # Поддержка обоих вариантов ввода
         outputs = self.bert(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             return_dict=True
         )
         
         pooled_output = outputs.last_hidden_state[:, 0, :]
-        
-        # Предсказания
         logits_safety = self.safety_classifier(pooled_output)
         logits_attack = self.attack_classifier(pooled_output)
         
-        # Расчет потерь
         loss = None
         if labels_safety is not None:
-            loss_safety = self.safety_loss(logits_safety, labels_safety)
+            loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
             
-            # Потери для атак только для unsafe текстов
             mask = (labels_safety == 1)
             if mask.any():
-                loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask])
+                loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)(
+                    logits_attack[mask],
+                    labels_attack[mask]
+                )
                 loss = loss_safety + 0.5 * loss_attack
             else:
                 loss = loss_safety
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
             'loss': loss
         }
 
+
 # 4. Метрики
 def compute_metrics(p):
     preds_safety = np.argmax(p.predictions[0], axis=1)
@@ -879,6 +884,7 @@ def main():
         metric_for_best_model="safety_f1",
         greater_is_better=True,
         fp16=True,
+        remove_unused_columns=True,  # Убедитесь, что это True
     )
     
     # Обучение
@@ -889,6 +895,7 @@ def main():
         eval_dataset=val_dataset,
         compute_metrics=compute_metrics,
         callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
+        label_names=["labels_safety", "labels_attack"] 
     )
     
     trainer.train()
-- 
GitLab