From 2553252d69e39f75600f0cf44d6e41c9ac7e73bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?=
 =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?=
 =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru>
Date: Fri, 28 Mar 2025 01:35:01 +0300
Subject: [PATCH] supermega

---
 .ipynb_checkpoints/ULTRAMegaOB-checkpoint.py | 42 ++++++++++++--------
 ULTRAMegaOB.py                               | 42 ++++++++++++--------
 2 files changed, 52 insertions(+), 32 deletions(-)

diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
index b6d4986..0edac7d 100644
--- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
+++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
@@ -46,7 +46,7 @@ class Config:
     SAVE_DIR = './safety_model'
     MAX_LENGTH = 192
     BATCH_SIZE = 16
-    EPOCHS = 2
+    EPOCHS = 10
     SAFETY_THRESHOLD = 0.5
     TEST_SIZE = 0.2
     VAL_SIZE = 0.1
@@ -175,24 +175,34 @@ def compute_metrics(p):
 
     # Метрики для безопасности
     preds_safety = np.argmax(safety_preds, axis=1)
-    safety_report = classification_report(
-        labels_safety, 
-        preds_safety,
-        labels=[0, 1],
-        target_names=["safe", "unsafe"],
-        output_dict=True,
-        zero_division=0
-    )
     
-    if 'unsafe' not in safety_report:
-        logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!")
-        return {'eval_unsafe_recall': 0.0}  # Возвращаем значение по умолчанию
+    try:
+        safety_report = classification_report(
+            labels_safety, 
+            preds_safety,
+            labels=[0, 1],
+            target_names=["safe", "unsafe"],
+            output_dict=True,
+            zero_division=0
+        )
+    except Exception as e:
+        logger.error(f"Ошибка при создании отчета: {str(e)}")
+        safety_report = {
+            "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0},
+            "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0},
+            "accuracy": 0.0
+        }
+
+    # Гарантированное получение значений с проверкой вложенных ключей
+    unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0)
+    safe_precision = safety_report.get("safe", {}).get("precision", 0.0)
+    accuracy = safety_report.get("accuracy", 0.0)
 
-    # Формируем метрики ТОЛЬКО с префиксом eval_
+    # Формируем метрики с префиксом eval_
     metrics = {
-        'eval_accuracy': safety_report.get("accuracy", 0),
-        'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0),
-        'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0),
+        'eval_accuracy': accuracy,
+        'eval_unsafe_recall': unsafe_recall,
+        'eval_safe_precision': safe_precision,
     }
 
     logger.info(f"Метрики для ранней остановки: {metrics}")
diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py
index b6d4986..0edac7d 100644
--- a/ULTRAMegaOB.py
+++ b/ULTRAMegaOB.py
@@ -46,7 +46,7 @@ class Config:
     SAVE_DIR = './safety_model'
     MAX_LENGTH = 192
     BATCH_SIZE = 16
-    EPOCHS = 2
+    EPOCHS = 10
     SAFETY_THRESHOLD = 0.5
     TEST_SIZE = 0.2
     VAL_SIZE = 0.1
@@ -175,24 +175,34 @@ def compute_metrics(p):
 
     # Метрики для безопасности
     preds_safety = np.argmax(safety_preds, axis=1)
-    safety_report = classification_report(
-        labels_safety, 
-        preds_safety,
-        labels=[0, 1],
-        target_names=["safe", "unsafe"],
-        output_dict=True,
-        zero_division=0
-    )
     
-    if 'unsafe' not in safety_report:
-        logger.error("Отсутствуют примеры 'unsafe' для вычисления метрик!")
-        return {'eval_unsafe_recall': 0.0}  # Возвращаем значение по умолчанию
+    try:
+        safety_report = classification_report(
+            labels_safety, 
+            preds_safety,
+            labels=[0, 1],
+            target_names=["safe", "unsafe"],
+            output_dict=True,
+            zero_division=0
+        )
+    except Exception as e:
+        logger.error(f"Ошибка при создании отчета: {str(e)}")
+        safety_report = {
+            "safe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0},
+            "unsafe": {"precision": 0.0, "recall": 0.0, "f1-score": 0.0, "support": 0},
+            "accuracy": 0.0
+        }
+
+    # Гарантированное получение значений с проверкой вложенных ключей
+    unsafe_recall = safety_report.get("unsafe", {}).get("recall", 0.0)
+    safe_precision = safety_report.get("safe", {}).get("precision", 0.0)
+    accuracy = safety_report.get("accuracy", 0.0)
 
-    # Формируем метрики ТОЛЬКО с префиксом eval_
+    # Формируем метрики с префиксом eval_
     metrics = {
-        'eval_accuracy': safety_report.get("accuracy", 0),
-        'eval_unsafe_recall': safety_report.get("unsafe", {}).get("recall", 0),
-        'eval_safe_precision': safety_report.get("safe", {}).get("precision", 0),
+        'eval_accuracy': accuracy,
+        'eval_unsafe_recall': unsafe_recall,
+        'eval_safe_precision': safe_precision,
     }
 
     logger.info(f"Метрики для ранней остановки: {metrics}")
-- 
GitLab