Skip to content
GitLab
Explore
Projects
Groups
Topics
Snippets
Projects
Groups
Topics
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
proekt
obuch
Commits
0f5e5ee8
Commit
0f5e5ee8
authored
1 week ago
by
Мазур Грета Евгеньевна
Browse files
Options
Download
Patches
Plain Diff
pereobuch
parent
f6e09856
master
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
.ipynb_checkpoints/pereobuch-checkpoint.py
+14
-7
.ipynb_checkpoints/pereobuch-checkpoint.py
pereobuch.py
+14
-7
pereobuch.py
with
28 additions
and
14 deletions
+28
-14
.ipynb_checkpoints/pereobuch-checkpoint.py
+
14
−
7
View file @
0f5e5ee8
...
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
...
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
# Функции потерь
# Функции потерь
self
.
safety_loss
=
nn
.
CrossEntropyLoss
()
self
.
safety_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_weights
=
torch
.
tensor
([
1.0
,
0.5
,
10.0
,
20.0
]).
to
(
DEVICE
)
def
forward
(
self
,
input_ids
,
attention_mask
,
labels_safety
=
None
,
labels_attack
=
None
):
def
forward
(
self
,
input_ids
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
labels_safety
=
None
,
labels_attack
=
None
):
# Поддержка обоих вариантов ввода
outputs
=
self
.
bert
(
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
inputs_embeds
=
inputs_embeds
,
return_dict
=
True
return_dict
=
True
)
)
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
# Предсказания
logits_safety
=
self
.
safety_classifier
(
pooled_output
)
logits_safety
=
self
.
safety_classifier
(
pooled_output
)
logits_attack
=
self
.
attack_classifier
(
pooled_output
)
logits_attack
=
self
.
attack_classifier
(
pooled_output
)
# Расчет потерь
loss
=
None
loss
=
None
if
labels_safety
is
not
None
:
if
labels_safety
is
not
None
:
loss_safety
=
self
.
safety_l
oss
(
logits_safety
,
labels_safety
)
loss_safety
=
nn
.
CrossEntropyL
oss
(
)(
logits_safety
,
labels_safety
)
# Потери для атак только для unsafe текстов
mask
=
(
labels_safety
==
1
)
mask
=
(
labels_safety
==
1
)
if
mask
.
any
():
if
mask
.
any
():
loss_attack
=
self
.
attack_loss
(
logits_attack
[
mask
],
labels_attack
[
mask
])
loss_attack
=
nn
.
CrossEntropyLoss
(
weight
=
self
.
attack_weights
)(
logits_attack
[
mask
],
labels_attack
[
mask
]
)
loss
=
loss_safety
+
0.5
*
loss_attack
loss
=
loss_safety
+
0.5
*
loss_attack
else
:
else
:
loss
=
loss_safety
loss
=
loss_safety
...
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
...
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
'loss'
:
loss
'loss'
:
loss
}
}
# 4. Метрики
# 4. Метрики
def
compute_metrics
(
p
):
def
compute_metrics
(
p
):
preds_safety
=
np
.
argmax
(
p
.
predictions
[
0
],
axis
=
1
)
preds_safety
=
np
.
argmax
(
p
.
predictions
[
0
],
axis
=
1
)
...
@@ -879,6 +884,7 @@ def main():
...
@@ -879,6 +884,7 @@ def main():
metric_for_best_model
=
"safety_f1"
,
metric_for_best_model
=
"safety_f1"
,
greater_is_better
=
True
,
greater_is_better
=
True
,
fp16
=
True
,
fp16
=
True
,
remove_unused_columns
=
True
,
# Убедитесь, что это True
)
)
# Обучение
# Обучение
...
@@ -889,6 +895,7 @@ def main():
...
@@ -889,6 +895,7 @@ def main():
eval_dataset
=
val_dataset
,
eval_dataset
=
val_dataset
,
compute_metrics
=
compute_metrics
,
compute_metrics
=
compute_metrics
,
callbacks
=
[
EarlyStoppingCallback
(
early_stopping_patience
=
2
)]
callbacks
=
[
EarlyStoppingCallback
(
early_stopping_patience
=
2
)]
label_names
=
[
"labels_safety"
,
"labels_attack"
]
)
)
trainer
.
train
()
trainer
.
train
()
...
...
This diff is collapsed.
Click to expand it.
pereobuch.py
+
14
−
7
View file @
0f5e5ee8
...
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
...
@@ -782,29 +782,33 @@ class SafetyAndAttackModel(nn.Module):
# Функции потерь
# Функции потерь
self
.
safety_loss
=
nn
.
CrossEntropyLoss
()
self
.
safety_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_loss
=
nn
.
CrossEntropyLoss
()
self
.
attack_weights
=
torch
.
tensor
([
1.0
,
0.5
,
10.0
,
20.0
]).
to
(
DEVICE
)
def
forward
(
self
,
input_ids
,
attention_mask
,
labels_safety
=
None
,
labels_attack
=
None
):
def
forward
(
self
,
input_ids
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
labels_safety
=
None
,
labels_attack
=
None
):
# Поддержка обоих вариантов ввода
outputs
=
self
.
bert
(
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
inputs_embeds
=
inputs_embeds
,
return_dict
=
True
return_dict
=
True
)
)
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
# Предсказания
logits_safety
=
self
.
safety_classifier
(
pooled_output
)
logits_safety
=
self
.
safety_classifier
(
pooled_output
)
logits_attack
=
self
.
attack_classifier
(
pooled_output
)
logits_attack
=
self
.
attack_classifier
(
pooled_output
)
# Расчет потерь
loss
=
None
loss
=
None
if
labels_safety
is
not
None
:
if
labels_safety
is
not
None
:
loss_safety
=
self
.
safety_l
oss
(
logits_safety
,
labels_safety
)
loss_safety
=
nn
.
CrossEntropyL
oss
(
)(
logits_safety
,
labels_safety
)
# Потери для атак только для unsafe текстов
mask
=
(
labels_safety
==
1
)
mask
=
(
labels_safety
==
1
)
if
mask
.
any
():
if
mask
.
any
():
loss_attack
=
self
.
attack_loss
(
logits_attack
[
mask
],
labels_attack
[
mask
])
loss_attack
=
nn
.
CrossEntropyLoss
(
weight
=
self
.
attack_weights
)(
logits_attack
[
mask
],
labels_attack
[
mask
]
)
loss
=
loss_safety
+
0.5
*
loss_attack
loss
=
loss_safety
+
0.5
*
loss_attack
else
:
else
:
loss
=
loss_safety
loss
=
loss_safety
...
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
...
@@ -815,6 +819,7 @@ class SafetyAndAttackModel(nn.Module):
'loss'
:
loss
'loss'
:
loss
}
}
# 4. Метрики
# 4. Метрики
def
compute_metrics
(
p
):
def
compute_metrics
(
p
):
preds_safety
=
np
.
argmax
(
p
.
predictions
[
0
],
axis
=
1
)
preds_safety
=
np
.
argmax
(
p
.
predictions
[
0
],
axis
=
1
)
...
@@ -879,6 +884,7 @@ def main():
...
@@ -879,6 +884,7 @@ def main():
metric_for_best_model
=
"safety_f1"
,
metric_for_best_model
=
"safety_f1"
,
greater_is_better
=
True
,
greater_is_better
=
True
,
fp16
=
True
,
fp16
=
True
,
remove_unused_columns
=
True
,
# Убедитесь, что это True
)
)
# Обучение
# Обучение
...
@@ -889,6 +895,7 @@ def main():
...
@@ -889,6 +895,7 @@ def main():
eval_dataset
=
val_dataset
,
eval_dataset
=
val_dataset
,
compute_metrics
=
compute_metrics
,
compute_metrics
=
compute_metrics
,
callbacks
=
[
EarlyStoppingCallback
(
early_stopping_patience
=
2
)]
callbacks
=
[
EarlyStoppingCallback
(
early_stopping_patience
=
2
)]
label_names
=
[
"labels_safety"
,
"labels_attack"
]
)
)
trainer
.
train
()
trainer
.
train
()
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment
Menu
Explore
Projects
Groups
Topics
Snippets