proekt / obuch / Commits · Commit ebdddc50

Authored 3 weeks ago by Мазур Грета Евгеньевна
Parent: 271d1c9e · branch: master · no related merge requests found

    supermega

Showing 2 changed files with 106 additions and 20 deletions:

    .ipynb_checkpoints/ULTRAMegaOB-checkpoint.py   +53 −10
    ULTRAMegaOB.py                                 +53 −10
.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py · +53 −10 · view file @ ebdddc50
@@ -141,21 +141,46 @@ def compute_metrics(p):
                 "support": attack_report[attack_type]["support"]
             }
     metrics = {
         'eval_accuracy': safety_report["accuracy"],
         'eval_f1': safety_report["weighted avg"]["f1-score"],
-        'eval_unsafe_recall': safety_report["unsafe"]["recall"],  # add the eval_ prefix
+        'eval_unsafe_recall': safety_report["unsafe"]["recall"],
         'eval_safe_precision': safety_report["safe"]["precision"],
     }
+    # Add attack metrics (only when unsafe examples are present)
     if attack_details:
         metrics.update({
             'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
             'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
         })
+    logger.info(f"Metrics returned: {metrics}")
+    # Add a check for the presence of unsafe examples
+    if np.sum(unsafe_mask) == 0:
+        metrics['eval_unsafe_recall'] = 0.0
+        logger.warning("No unsafe examples in the validation set!")
     return metrics
+    # metrics = {
+    #     'eval_accuracy': safety_report["accuracy"],
+    #     'eval_f1': safety_report["weighted avg"]["f1-score"],
+    #     'eval_unsafe_recall': safety_report["unsafe"]["recall"],  # add the eval_ prefix
+    #     'eval_safe_precision': safety_report["safe"]["precision"],
+    # }
+    # # Add attack metrics (only when unsafe examples are present)
+    # if attack_details:
+    #     metrics.update({
+    #         'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
+    #         'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
+    #     })
+    # logger.info(f"Metrics returned: {metrics}")
+    # return metrics
     # # Build the full metrics log
     # full_metrics = {
     #     "safety": {
@@ -278,8 +303,9 @@ def augment_text(text, num_augments):
             tr_augs = translation_aug.augment(text, n=num_augments)
             if tr_augs:
                 augmented.update(a.replace(' ##', '') for a in tr_augs if isinstance(a, str) and a is not None)
         except Exception as e:
             logger.debug(f"Back-translation skipped: {str(e)}")
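The pattern in this hunk (try the augmenter, keep only string results, log rather than raise on failure) keeps augmentation best-effort. A standalone sketch assuming nlpaug's BackTranslationAug; the model names are illustrative defaults, not taken from this commit:

import logging
import nlpaug.augmenter.word as naw

logger = logging.getLogger(__name__)

def back_translate(text, num_augments=2):
    augmented = set()
    try:
        aug = naw.BackTranslationAug(
            from_model_name="facebook/wmt19-en-de",
            to_model_name="facebook/wmt19-de-en",
        )
        results = aug.augment(text, n=num_augments) or []
        # Drop None/non-str outputs and strip WordPiece '##' artifacts.
        augmented.update(r.replace(" ##", "") for r in results if isinstance(r, str))
    except Exception as e:
        # Augmentation is best-effort; never fail the pipeline because of it.
        logger.debug(f"Back-translation skipped: {e}")
    return augmented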
@@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data):
         # Pin the sample size
         balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count))
-    return pd.concat(balanced_dfs).sample(frac=1)
+    # Combine all the balanced data
+    result = pd.concat(balanced_dfs).sample(frac=1)
+    # Log the final distribution
+    logger.info("\nFinal distribution after balancing:")
+    logger.info(result['type'].value_counts().to_string())
+    # Check the minimum example count per class
+    if result['type'].value_counts().min() == 0:
+        raise ValueError("Zero examples for one of the attack classes")
+    return result
 
 def load_and_balance_data():
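The key move in this hunk is replace=len(subset) < target_count: classes smaller than the target are oversampled with replacement, larger ones are downsampled without it, so every attack type ends up with exactly target_count rows. A toy, self-contained illustration (the data is made up):

import pandas as pd

df = pd.DataFrame({
    "type": ["evasion"] * 3 + ["generic attack"] * 10,
    "prompt": [f"example {i}" for i in range(13)],
})

target_count = 5
balanced = pd.concat(
    # Sample with replacement only when the class is smaller than the target;
    # otherwise draw without replacement so rows stay unique.
    grp.sample(n=target_count, replace=len(grp) < target_count)
    for _, grp in df.groupby("type")
).sample(frac=1)  # final shuffle, as in the diff

print(balanced["type"].value_counts())  # both classes: 5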
@@ -411,6 +447,9 @@ def load_and_balance_data():
         logger.info(f"Safe/unsafe: {balanced_data['safety'].value_counts().to_dict()}")
         logger.info(f"Attack types:\n{balanced_data[balanced_data['safety'] == 'unsafe']['type'].value_counts()}")
+        if (balanced_data['safety'] == 'unsafe').sum() == 0:
+            raise ValueError("No unsafe examples after balancing!")
         return balanced_data
     except Exception as e:
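Failing fast here is sensible: if balancing ever produced zero unsafe rows, the unsafe recall that drives checkpoint selection would be meaningless. The same guard as a reusable check; the column name follows the diff, the function itself is hypothetical:

import pandas as pd

def assert_has_unsafe(df: pd.DataFrame) -> pd.DataFrame:
    # Without unsafe rows, 'unsafe' recall is undefined and model selection
    # on eval_unsafe_recall would silently degrade.
    if (df["safety"] == "unsafe").sum() == 0:
        raise ValueError("No unsafe examples after balancing!")
    return df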
@@ -516,6 +555,10 @@ def train_model():
         stratify=train_data['safety'],
         random_state=Config.SEED
     )
+    logger.info("\nClass distribution in train:")
+    logger.info(train_data['safety'].value_counts())
+    logger.info("\nClass distribution in validation:")
+    logger.info(val_data['safety'].value_counts())
     # 2. Tokenization
     tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
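The stratify argument is what makes the two logged distributions comparable: it preserves the safe/unsafe ratio in both splits. A runnable sketch with made-up data; the test size and seed value are assumptions, the diff only shows stratify and random_state:

import pandas as pd
from sklearn.model_selection import train_test_split

SEED = 42  # stand-in for Config.SEED

data = pd.DataFrame({
    "prompt": [f"p{i}" for i in range(10)],
    "safety": ["safe"] * 7 + ["unsafe"] * 3,
})

# stratify keeps the safe/unsafe ratio (here 7:3) in both splits.
train_data, val_data = train_test_split(
    data, test_size=0.3, stratify=data["safety"], random_state=SEED
)
print(train_data["safety"].value_counts())
print(val_data["safety"].value_counts())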
@@ -561,7 +604,7 @@ def train_model():
         report_to="none",
         seed=Config.SEED,
         max_grad_norm=1.0,
         metric_for_best_model="eval_unsafe_recall",
         greater_is_better=True,
         load_best_model_at_end=True,
     )
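Taken together, the last three arguments make the Trainer checkpoint on unsafe recall: it compares checkpoints on eval_unsafe_recall, treats larger values as better, and reloads the winning checkpoint after training. A sketch of a compatible TrainingArguments, assuming the transformers Trainer API; the output directory and evaluation/save strategies are assumptions, and load_best_model_at_end requires those two strategies to match:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                 # assumed
    eval_strategy="epoch",            # 'evaluation_strategy' in older releases
    save_strategy="epoch",            # must match eval_strategy for best-model reload
    report_to="none",
    seed=42,                          # stand-in for Config.SEED
    max_grad_norm=1.0,
    metric_for_best_model="eval_unsafe_recall",
    greater_is_better=True,
    load_best_model_at_end=True,
)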
ULTRAMegaOB.py · +53 −10 · view file @ ebdddc50 (identical changes to the checkpoint copy above)