diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py index b6d4986a290f4f19ebb373c0e8944e2b3a764ab0..0edac7d92cfa83422695bfa5a0ec02a2d66c8673 100644 --- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py +++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py @@ -46,7 +46,7 @@ class Config: SAVE_DIR = './safety_model' MAX_LENGTH = 192 BATCH_SIZE = 16 - EPOCHS = 2 + EPOCHS = 10 SAFETY_THRESHOLD = 0.5 TEST_SIZE = 0.2 VAL_SIZE = 0.1 @@ -175,24 +175,34 @@ def compute_metrics(p): # Метрики для безопасности preds_safety = np.argmax(safety_preds, axis=1) - safety_report = classification_report( - labels_safety, - preds_safety, - labels=[0, 1], - target_names=["safe", "unsafe"], - output_dict=True, - zero_division=0 - ) - if 'unsafe' not in safety_report: - logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!") - return {'eval_unsafe_recall': 0.0} # Возвращаем значение по умолчанию + try: + safety_report = classification_report( + labels_safety, + preds_safety, + labels=[0, 1], + target_names=["safe", "unsafe"], + output_dict=True, + zero_division=0 + ) + except Exception as e: + logger.error(f"Ошибка при создании отчета: {str(e)}") + safety_report = { + "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "accuracy": 0.0 + } + + # Гарантированное получение значений с проверкой вложенных ключей + unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0) + safe_precision = safety_report.get("safe", {}).get("precision", 0.0) + accuracy = safety_report.get("accuracy", 0.0) - # Формируем метрики ТОЛЬКО с префиксом eval_ + # Формируем метрики с префиксом eval_ metrics = { - 'eval_accuracy': safety_report.get("accuracy", 0), - 'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0), - 'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0), + 'eval_accuracy': accuracy, + 'eval_unsafe_recall': unsafe_recall, + 'eval_safe_precision': safe_precision, } logger.info(f"Метрики для ранней остановки: {metrics}") diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py index b6d4986a290f4f19ebb373c0e8944e2b3a764ab0..0edac7d92cfa83422695bfa5a0ec02a2d66c8673 100644 --- a/ULTRAMegaOB.py +++ b/ULTRAMegaOB.py @@ -46,7 +46,7 @@ class Config: SAVE_DIR = './safety_model' MAX_LENGTH = 192 BATCH_SIZE = 16 - EPOCHS = 2 + EPOCHS = 10 SAFETY_THRESHOLD = 0.5 TEST_SIZE = 0.2 VAL_SIZE = 0.1 @@ -175,24 +175,34 @@ def compute_metrics(p): # Метрики для безопасности preds_safety = np.argmax(safety_preds, axis=1) - safety_report = classification_report( - labels_safety, - preds_safety, - labels=[0, 1], - target_names=["safe", "unsafe"], - output_dict=True, - zero_division=0 - ) - if 'unsafe' not in safety_report: - logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!") - return {'eval_unsafe_recall': 0.0} # Возвращаем значение по умолчанию + try: + safety_report = classification_report( + labels_safety, + preds_safety, + labels=[0, 1], + target_names=["safe", "unsafe"], + output_dict=True, + zero_division=0 + ) + except Exception as e: + logger.error(f"Ошибка при создании отчета: {str(e)}") + safety_report = { + "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "accuracy": 0.0 + } + + # Гарантированное получение значений с проверкой вложенных ключей + unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0) + safe_precision = safety_report.get("safe", {}).get("precision", 0.0) + accuracy = safety_report.get("accuracy", 0.0) - # Формируем метрики ТОЛЬКО с префиксом eval_ + # Формируем метрики с префиксом eval_ metrics = { - 'eval_accuracy': safety_report.get("accuracy", 0), - 'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0), - 'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0), + 'eval_accuracy': accuracy, + 'eval_unsafe_recall': unsafe_recall, + 'eval_safe_precision': safe_precision, } logger.info(f"Метрики для ранней остановки: {metrics}")