diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py index 1d00d46a5e65e6ac73228a87a8ea2a6c1b733549..61cc4caf3c08a82cac2070280bf0858640070de8 100644 --- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py +++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py @@ -141,21 +141,46 @@ def compute_metrics(p): "support": attack_report[attack_type]["support"] } metrics = { - 'eval_accuracy': safety_report["accuracy"], - 'eval_f1': safety_report["weighted avg"]["f1-score"], - 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс - 'eval_safe_precision': safety_report["safe"]["precision"], + 'eval_accuracy': safety_report["accuracy"], + 'eval_f1': safety_report["weighted avg"]["f1-score"], + 'eval_unsafe_recall': safety_report["unsafe"]["recall"], + 'eval_safe_precision': safety_report["safe"]["precision"], } - # Добавляем метрики для атак (только если есть unsafe примеры) if attack_details: metrics.update({ 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) }) - logger.info(f"Возвращаемые метрики: {metrics}") + # Добавляем проверку РЅР° наличие unsafe примеров + + if np.sum(unsafe_mask) == 0: + metrics['eval_unsafe_recall'] = 0.0 + logger.warning("Р’ валидационной выборке отсутствуют unsafe примеры!") + + return metrics + # metrics = { + # 'eval_accuracy': safety_report["accuracy"], + # 'eval_f1': safety_report["weighted avg"]["f1-score"], + # 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс + # 'eval_safe_precision': safety_report["safe"]["precision"], + # } + + # # Добавляем метрики для атак (только если есть unsafe примеры) + # if attack_details: + # metrics.update({ + # 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), + # 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) + # }) + + # logger.info(f"Возвращаемые метрики: {metrics}") + # return metrics + + + + # # Формирование полного лога метрик # full_metrics = { # "safety": { @@ -278,8 +303,9 @@ def augment_text(text, num_augments): tr_augs = translation_aug.augment(text, n=num_augments) if tr_augs: - augmented.update(a.replace(' ##', '') for a in tr_augs - if isinstance(a, str) and a is not None) + + augmented.update(a.replace(' ##', '') for a in tr_augs if isinstance(a, str) and a is not None) + except Exception as e: logger.debug(f"Обратный перевод пропущен: {str(e)}") @@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data): # Фиксируем размер выборки balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) - return pd.concat(balanced_dfs).sample(frac=1) + # Объединяем РІСЃРµ сбалансированные данные + result = pd.concat(balanced_dfs).sample(frac=1) + # Логирование итогового распределения + logger.info("\nРтоговое распределение после балансировки:") + logger.info(result['type'].value_counts().to_string()) + + # Проверка минимального количества примеров + if result['type'].value_counts().min() == 0: + raise ValueError("Нулевое количество примеров для РѕРґРЅРѕРіРѕ РёР· классов атак") + + return result def load_and_balance_data(): @@ -411,6 +447,9 @@ def load_and_balance_data(): logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}") logger.info(f"РўРёРїС‹ атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}") + if (balanced_data['safety'] == 'unsafe').sum() == 0: + raise ValueError("No unsafe examples after balancing!") + return balanced_data except Exception as e: @@ -516,6 +555,10 @@ def train_model(): stratify=train_data['safety'], random_state=Config.SEED ) + logger.info("\nРаспределение классов РІ train:") + logger.info(train_data['safety'].value_counts()) + logger.info("\nРаспределение классов РІ validation:") + logger.info(val_data['safety'].value_counts()) # 2. Токенизация tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME) @@ -561,7 +604,7 @@ def train_model(): report_to="none", seed=Config.SEED, max_grad_norm=1.0, - metric_for_best_model="eval_unsafe_recall", + metric_for_best_model="eval_unsafe_recall", greater_is_better=True, load_best_model_at_end=True, ) diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py index 1d00d46a5e65e6ac73228a87a8ea2a6c1b733549..61cc4caf3c08a82cac2070280bf0858640070de8 100644 --- a/ULTRAMegaOB.py +++ b/ULTRAMegaOB.py @@ -141,21 +141,46 @@ def compute_metrics(p): "support": attack_report[attack_type]["support"] } metrics = { - 'eval_accuracy': safety_report["accuracy"], - 'eval_f1': safety_report["weighted avg"]["f1-score"], - 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс - 'eval_safe_precision': safety_report["safe"]["precision"], + 'eval_accuracy': safety_report["accuracy"], + 'eval_f1': safety_report["weighted avg"]["f1-score"], + 'eval_unsafe_recall': safety_report["unsafe"]["recall"], + 'eval_safe_precision': safety_report["safe"]["precision"], } - # Добавляем метрики для атак (только если есть unsafe примеры) if attack_details: metrics.update({ 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) }) - logger.info(f"Возвращаемые метрики: {metrics}") + # Добавляем проверку РЅР° наличие unsafe примеров + + if np.sum(unsafe_mask) == 0: + metrics['eval_unsafe_recall'] = 0.0 + logger.warning("Р’ валидационной выборке отсутствуют unsafe примеры!") + + return metrics + # metrics = { + # 'eval_accuracy': safety_report["accuracy"], + # 'eval_f1': safety_report["weighted avg"]["f1-score"], + # 'eval_unsafe_recall': safety_report["unsafe"]["recall"], # Добавляем eval_ префикс + # 'eval_safe_precision': safety_report["safe"]["precision"], + # } + + # # Добавляем метрики для атак (только если есть unsafe примеры) + # if attack_details: + # metrics.update({ + # 'eval_evasion_precision': attack_details.get("evasion", {}).get("precision", 0), + # 'eval_generic_attack_recall': attack_details.get("generic attack", {}).get("recall", 0) + # }) + + # logger.info(f"Возвращаемые метрики: {metrics}") + # return metrics + + + + # # Формирование полного лога метрик # full_metrics = { # "safety": { @@ -278,8 +303,9 @@ def augment_text(text, num_augments): tr_augs = translation_aug.augment(text, n=num_augments) if tr_augs: - augmented.update(a.replace(' ##', '') for a in tr_augs - if isinstance(a, str) and a is not None) + + augmented.update(a.replace(' ##', '') for a in tr_augs if isinstance(a, str) and a is not None) + except Exception as e: logger.debug(f"Обратный перевод пропущен: {str(e)}") @@ -367,8 +393,18 @@ def balance_attack_types(unsafe_data): # Фиксируем размер выборки balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) - return pd.concat(balanced_dfs).sample(frac=1) + # Объединяем РІСЃРµ сбалансированные данные + result = pd.concat(balanced_dfs).sample(frac=1) + # Логирование итогового распределения + logger.info("\nРтоговое распределение после балансировки:") + logger.info(result['type'].value_counts().to_string()) + + # Проверка минимального количества примеров + if result['type'].value_counts().min() == 0: + raise ValueError("Нулевое количество примеров для РѕРґРЅРѕРіРѕ РёР· классов атак") + + return result def load_and_balance_data(): @@ -411,6 +447,9 @@ def load_and_balance_data(): logger.info(f"Безопасные/Небезопасные: {balanced_data['safety'].value_counts().to_dict()}") logger.info(f"РўРёРїС‹ атак:\n{balanced_data[balanced_data['safety']=='unsafe']['type'].value_counts()}") + if (balanced_data['safety'] == 'unsafe').sum() == 0: + raise ValueError("No unsafe examples after balancing!") + return balanced_data except Exception as e: @@ -516,6 +555,10 @@ def train_model(): stratify=train_data['safety'], random_state=Config.SEED ) + logger.info("\nРаспределение классов РІ train:") + logger.info(train_data['safety'].value_counts()) + logger.info("\nРаспределение классов РІ validation:") + logger.info(val_data['safety'].value_counts()) # 2. Токенизация tokenizer = BertTokenizer.from_pretrained(Config.MODEL_NAME) @@ -561,7 +604,7 @@ def train_model(): report_to="none", seed=Config.SEED, max_grad_norm=1.0, - metric_for_best_model="eval_unsafe_recall", + metric_for_best_model="eval_unsafe_recall", greater_is_better=True, load_best_model_at_end=True, )