Commit ebdddc50 authored by Мазур Грета Евгеньевна's avatar Мазур Грета Евгеньевна
Browse files

supermega

parent 271d1c9e
No related merge requests found
Showing with 106 additions and 20 deletions
+106 -20
...@@ -141,21 +141,46 @@ def compute_metrics(p): ...@@ -141,21 +141,46 @@ def compute_metrics(p):
"support": attack_report[attack_type]["support"] "support": attack_report[attack_type]["support"]
} }
metrics = { metrics = {
'eval_accuracy': safety_report["accuracy"], 'eval_accuracy': safety_report["accuracy"],
'eval_f1': safety_report["weighted avg"]["f1-score"], 'eval_f1': safety_report["weighted avg"]["f1-score"],
'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс 'eval_unsafe_recall': safety_report["unsafe"]["recall"],
'eval_safe_precision': safety_report["safe"]["precision"], 'eval_safe_precision': safety_report["safe"]["precision"],
} }
# Добавляем метрики для атак (только если есть unsafe примеры)
if attack_details: if attack_details:
metrics.update({ metrics.update({
'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
}) })
logger.info(f"Возвращаемые метрики: {metrics}") # Добавляем проверку на наличие unsafe примеров
if np.sum(unsafe_mask) == 0:
metrics['eval_unsafe_recall'] = 0.0
logger.warning("В валидационной выборке отсутствуют unsafe примеры!")
return metrics return metrics
# metrics = {
# 'eval_accuracy': safety_report["accuracy"],
# 'eval_f1': safety_report["weighted avg"]["f1-score"],
# 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс
# 'eval_safe_precision': safety_report["safe"]["precision"],
# }
# # Добавляем метрики для атак (только если есть unsafe примеры)
# if attack_details:
# metrics.update({
# 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
# 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
# })
# logger.info(f"Возвращаемые метрики: {metrics}")
# return metrics
# # Формирование полного лога метрик # # Формирование полного лога метрик
# full_metrics = { # full_metrics = {
# "safety": { # "safety": {
...@@ -278,8 +303,9 @@ def augment_text(text, num_augments): ...@@ -278,8 +303,9 @@ def augment_text(text, num_augments):
tr_augs = translation_aug.augment(text, n=num_augments) tr_augs = translation_aug.augment(text, n=num_augments)
if tr_augs: if tr_augs:
augmented.update(a.replace(' ##', '') for a in tr_augs
if isinstance(a, str) and a is not None) augmented.update(a.replace(' ##', '') for a in tr_augs if isinstance(a, str) and a is not None)
except Exception as e: except Exception as e:
logger.debug(f"Обратный перевод пропущен: {str(e)}") logger.debug(f"Обратный перевод пропущен: {str(e)}")
...@@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data): ...@@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data):
# Фиксируем размер выборки # Фиксируем размер выборки
balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count))
return pd.concat(balanced_dfs).sample(frac=1) # Объединяем все сбалансированные данные
result = pd.concat(balanced_dfs).sample(frac=1)
# Логирование итогового распределения
logger.info("\nИтоговое распределение после балансировки:")
logger.info(result['type'].value_counts().to_string())
# Проверка минимального количества примеров
if result['type'].value_counts().min() == 0:
raise ValueError("Нулевое количество примеров для одного из классов атак")
return result
def load_and_balance_data(): def load_and_balance_data():
...@@ -411,6 +447,9 @@ def load_and_balance_data(): ...@@ -411,6 +447,9 @@ def load_and_balance_data():
logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}") logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}")
logger.info(f"Типы атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}") logger.info(f"Типы атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}")
if (balanced_data['safety'] == 'unsafe').sum() == 0:
raise ValueError("No unsafe examples after balancing!")
return balanced_data return balanced_data
except Exception as e: except Exception as e:
...@@ -516,6 +555,10 @@ def train_model(): ...@@ -516,6 +555,10 @@ def train_model():
stratify=train_data['safety'], stratify=train_data['safety'],
random_state=Config.SEED random_state=Config.SEED
) )
logger.info("\nРаспределение классов в train:")
logger.info(train_data['safety'].value_counts())
logger.info("\nРаспределение классов в validation:")
logger.info(val_data['safety'].value_counts())
# 2. Токенизация # 2. Токенизация
tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME) tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
...@@ -561,7 +604,7 @@ def train_model(): ...@@ -561,7 +604,7 @@ def train_model():
report_to="none", report_to="none",
seed=Config.SEED, seed=Config.SEED,
max_grad_norm=1.0, max_grad_norm=1.0,
metric_for_best_model="eval_unsafe_recall", metric_for_best_model="eval_unsafe_recall",
greater_is_better=True, greater_is_better=True,
load_best_model_at_end=True, load_best_model_at_end=True,
) )
......
...@@ -141,21 +141,46 @@ def compute_metrics(p): ...@@ -141,21 +141,46 @@ def compute_metrics(p):
"support": attack_report[attack_type]["support"] "support": attack_report[attack_type]["support"]
} }
metrics = { metrics = {
'eval_accuracy': safety_report["accuracy"], 'eval_accuracy': safety_report["accuracy"],
'eval_f1': safety_report["weighted avg"]["f1-score"], 'eval_f1': safety_report["weighted avg"]["f1-score"],
'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс 'eval_unsafe_recall': safety_report["unsafe"]["recall"],
'eval_safe_precision': safety_report["safe"]["precision"], 'eval_safe_precision': safety_report["safe"]["precision"],
} }
# Добавляем метрики для атак (только если есть unsafe примеры)
if attack_details: if attack_details:
metrics.update({ metrics.update({
'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
}) })
logger.info(f"Возвращаемые метрики: {metrics}") # Добавляем проверку на наличие unsafe примеров
if np.sum(unsafe_mask) == 0:
metrics['eval_unsafe_recall'] = 0.0
logger.warning("В валидационной выборке отсутствуют unsafe примеры!")
return metrics return metrics
# metrics = {
# 'eval_accuracy': safety_report["accuracy"],
# 'eval_f1': safety_report["weighted avg"]["f1-score"],
# 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс
# 'eval_safe_precision': safety_report["safe"]["precision"],
# }
# # Добавляем метрики для атак (только если есть unsafe примеры)
# if attack_details:
# metrics.update({
# 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0),
# 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0)
# })
# logger.info(f"Возвращаемые метрики: {metrics}")
# return metrics
# # Формирование полного лога метрик # # Формирование полного лога метрик
# full_metrics = { # full_metrics = {
# "safety": { # "safety": {
...@@ -278,8 +303,9 @@ def augment_text(text, num_augments): ...@@ -278,8 +303,9 @@ def augment_text(text, num_augments):
tr_augs = translation_aug.augment(text, n=num_augments) tr_augs = translation_aug.augment(text, n=num_augments)
if tr_augs: if tr_augs:
augmented.update(a.replace(' ##', '') for a in tr_augs
if isinstance(a, str) and a is not None) augmented.update(a.replace(' ##', '') for a in tr_augs if isinstance(a, str) and a is not None)
except Exception as e: except Exception as e:
logger.debug(f"Обратный перевод пропущен: {str(e)}") logger.debug(f"Обратный перевод пропущен: {str(e)}")
...@@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data): ...@@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data):
# Фиксируем размер выборки # Фиксируем размер выборки
balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count))
return pd.concat(balanced_dfs).sample(frac=1) # Объединяем все сбалансированные данные
result = pd.concat(balanced_dfs).sample(frac=1)
# Логирование итогового распределения
logger.info("\nИтоговое распределение после балансировки:")
logger.info(result['type'].value_counts().to_string())
# Проверка минимального количества примеров
if result['type'].value_counts().min() == 0:
raise ValueError("Нулевое количество примеров для одного из классов атак")
return result
def load_and_balance_data(): def load_and_balance_data():
...@@ -411,6 +447,9 @@ def load_and_balance_data(): ...@@ -411,6 +447,9 @@ def load_and_balance_data():
logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}") logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}")
logger.info(f"Типы атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}") logger.info(f"Типы атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}")
if (balanced_data['safety'] == 'unsafe').sum() == 0:
raise ValueError("No unsafe examples after balancing!")
return balanced_data return balanced_data
except Exception as e: except Exception as e:
...@@ -516,6 +555,10 @@ def train_model(): ...@@ -516,6 +555,10 @@ def train_model():
stratify=train_data['safety'], stratify=train_data['safety'],
random_state=Config.SEED random_state=Config.SEED
) )
logger.info("\nРаспределение классов в train:")
logger.info(train_data['safety'].value_counts())
logger.info("\nРаспределение классов в validation:")
logger.info(val_data['safety'].value_counts())
# 2. Токенизация # 2. Токенизация
tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME) tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME)
...@@ -561,7 +604,7 @@ def train_model(): ...@@ -561,7 +604,7 @@ def train_model():
report_to="none", report_to="none",
seed=Config.SEED, seed=Config.SEED,
max_grad_norm=1.0, max_grad_norm=1.0,
metric_for_best_model="eval_unsafe_recall", metric_for_best_model="eval_unsafe_recall",
greater_is_better=True, greater_is_better=True,
load_best_model_at_end=True, load_best_model_at_end=True,
) )
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment