From 2553252d69e39f75600f0cf44d6e41c9ac7e73bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?= =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?= =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru> Date: Fri, 28 Mar 2025 01:35:01 +0300 Subject: [PATCH] supermega --- .ipynb_checkpoints/ULTRAMegaOB-checkpoint.py | 42 ++++++++++++-------- ULTRAMegaOB.py | 42 ++++++++++++-------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py index b6d4986..0edac7d 100644 --- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py +++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py @@ -46,7 +46,7 @@ class Config: SAVE_DIR = './safety_model' MAX_LENGTH = 192 BATCH_SIZE = 16 - EPOCHS = 2 + EPOCHS = 10 SAFETY_THRESHOLD = 0.5 TEST_SIZE = 0.2 VAL_SIZE = 0.1 @@ -175,24 +175,34 @@ def compute_metrics(p): # Метрики для безопасности preds_safety = np.argmax(safety_preds, axis=1) - safety_report = classification_report( - labels_safety, - preds_safety, - labels=[0, 1], - target_names=["safe", "unsafe"], - output_dict=True, - zero_division=0 - ) - if 'unsafe' not in safety_report: - logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!") - return {'eval_unsafe_recall': 0.0} # Возвращаем значение по умолчанию + try: + safety_report = classification_report( + labels_safety, + preds_safety, + labels=[0, 1], + target_names=["safe", "unsafe"], + output_dict=True, + zero_division=0 + ) + except Exception as e: + logger.error(f"Ошибка при создании отчета: {str(e)}") + safety_report = { + "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "accuracy": 0.0 + } + + # Гарантированное получение значений с проверкой вложенных ключей + unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0) + safe_precision = safety_report.get("safe", {}).get("precision", 0.0) + accuracy = safety_report.get("accuracy", 0.0) - # Формируем метрики ТОЛЬКО с префиксом eval_ + # Формируем метрики с префиксом eval_ metrics = { - 'eval_accuracy': safety_report.get("accuracy", 0), - 'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0), - 'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0), + 'eval_accuracy': accuracy, + 'eval_unsafe_recall': unsafe_recall, + 'eval_safe_precision': safe_precision, } logger.info(f"Метрики для ранней остановки: {metrics}") diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py index b6d4986..0edac7d 100644 --- a/ULTRAMegaOB.py +++ b/ULTRAMegaOB.py @@ -46,7 +46,7 @@ class Config: SAVE_DIR = './safety_model' MAX_LENGTH = 192 BATCH_SIZE = 16 - EPOCHS = 2 + EPOCHS = 10 SAFETY_THRESHOLD = 0.5 TEST_SIZE = 0.2 VAL_SIZE = 0.1 @@ -175,24 +175,34 @@ def compute_metrics(p): # Метрики для безопасности preds_safety = np.argmax(safety_preds, axis=1) - safety_report = classification_report( - labels_safety, - preds_safety, - labels=[0, 1], - target_names=["safe", "unsafe"], - output_dict=True, - zero_division=0 - ) - if 'unsafe' not in safety_report: - logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!") - return {'eval_unsafe_recall': 0.0} # Возвращаем значение по умолчанию + try: + safety_report = classification_report( + labels_safety, + preds_safety, + labels=[0, 1], + target_names=["safe", "unsafe"], + output_dict=True, + zero_division=0 + ) + except Exception as e: + logger.error(f"Ошибка при создании отчета: {str(e)}") + safety_report = { + "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0}, + "accuracy": 0.0 + } + + # Гарантированное получение значений с проверкой вложенных ключей + unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0) + safe_precision = safety_report.get("safe", {}).get("precision", 0.0) + accuracy = safety_report.get("accuracy", 0.0) - # Формируем метрики ТОЛЬКО с префиксом eval_ + # Формируем метрики с префиксом eval_ metrics = { - 'eval_accuracy': safety_report.get("accuracy", 0), - 'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0), - 'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0), + 'eval_accuracy': accuracy, + 'eval_unsafe_recall': unsafe_recall, + 'eval_safe_precision': safe_precision, } logger.info(f"Метрики для ранней остановки: {metrics}") -- GitLab