Commit 0f5e5ee8 authored by Мазур Грета Евгеньевна's avatar Мазур Грета Евгеньевна
Browse files

pereobuch

parent f6e09856
No related merge requests found
Showing with 28 additions and 14 deletions
+28 -14
......@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
# Функции потерь
self.safety_loss = nn.CrossEntropyLoss()
self.attack_loss = nn.CrossEntropyLoss()
self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE)
def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None):
def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
labels_safety=None, labels_attack=None):
# Поддержка обоих вариантов ввода
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
return_dict=True
)
pooled_output = outputs.last_hidden_state[:, 0, :]
# Предсказания
logits_safety = self.safety_classifier(pooled_output)
logits_attack = self.attack_classifier(pooled_output)
# Расчет потерь
loss = None
if labels_safety is not None:
loss_safety = self.safety_loss(logits_safety, labels_safety)
loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
# Потери для атак только для unsafe текстов
mask = (labels_safety == 1)
if mask.any():
loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask])
loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)(
logits_attack[mask],
labels_attack[mask]
)
loss = loss_safety + 0.5 * loss_attack
else:
loss = loss_safety
......@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
'loss': loss
}
# 4. Метрики
def compute_metrics(p):
preds_safety = np.argmax(p.predictions[0], axis=1)
......@@ -879,6 +884,7 @@ def main():
metric_for_best_model="safety_f1",
greater_is_better=True,
fp16=True,
remove_unused_columns=True, # Убедитесь, что это True
)
# Обучение
......@@ -889,6 +895,7 @@ def main():
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
label_names=["labels_safety", "labels_attack"]
)
trainer.train()
......
......@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
# Функции потерь
self.safety_loss = nn.CrossEntropyLoss()
self.attack_loss = nn.CrossEntropyLoss()
self.attack_weights = torch.tensor([1.0, 0.5, 10.0, 20.0]).to(DEVICE)
def forward(self, input_ids, attention_mask, labels_safety=None, labels_attack=None):
def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
labels_safety=None, labels_attack=None):
# Поддержка обоих вариантов ввода
outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
return_dict=True
)
pooled_output = outputs.last_hidden_state[:, 0, :]
# Предсказания
logits_safety = self.safety_classifier(pooled_output)
logits_attack = self.attack_classifier(pooled_output)
# Расчет потерь
loss = None
if labels_safety is not None:
loss_safety = self.safety_loss(logits_safety, labels_safety)
loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
# Потери для атак только для unsafe текстов
mask = (labels_safety == 1)
if mask.any():
loss_attack = self.attack_loss(logits_attack[mask], labels_attack[mask])
loss_attack = nn.CrossEntropyLoss(weight=self.attack_weights)(
logits_attack[mask],
labels_attack[mask]
)
loss = loss_safety + 0.5 * loss_attack
else:
loss = loss_safety
......@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
'loss': loss
}
# 4. Метрики
def compute_metrics(p):
preds_safety = np.argmax(p.predictions[0], axis=1)
......@@ -879,6 +884,7 @@ def main():
metric_for_best_model="safety_f1",
greater_is_better=True,
fp16=True,
remove_unused_columns=True, # Убедитесь, что это True
)
# Обучение
......@@ -889,6 +895,7 @@ def main():
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
label_names=["labels_safety", "labels_attack"]
)
trainer.train()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment