From b9ab91fd2bd7eed11671afba0d2cce5663bb46b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?=
 =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?=
 =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru>
Date: Thu, 27 Mar 2025 03:31:42 +0300
Subject: [PATCH] supermega

---
 .ipynb_checkpoints/ULTRAMegaOB-checkpoint.py | 203 +++++++++++++------
 ULTRAMegaOB.py                               | 203 +++++++++++++------
 2 files changed, 274 insertions(+), 132 deletions(-)

diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
index 66ae391..fbde2fe 100644
--- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
+++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py
@@ -17,6 +17,11 @@ import logging
 import nlpaug.augmenter.word as naw
 from collections import defaultdict
 from sklearn.metrics import classification_report
+import nltk
+nltk.download('punkt', quiet=True)
+nltk.download('averaged_perceptron_tagger', quiet=True)
+nltk.download('wordnet', quiet=True)
+nltk.download('omw-1.4', quiet=True)
 
 
 # Настройка логгирования
@@ -178,23 +183,76 @@ def compute_metrics(p):
 
 
 
-def augment_text(text, num_augments):
-    """Генерация аугментированных примеров с проверками"""
-    if len(text) > 1000:
-        logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов")
-        return [text]
+# def augment_text(text, num_augments):
+#     """Генерация аугментированных примеров с проверками"""
+#     if len(text) > 1000:
+#         logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов")
+#         return [text]
     
-    if not isinstance(text, str) or len(text.strip()) < 10:
-        return [text]
+#     if not isinstance(text, str) or len(text.strip()) < 10:
+#         return [text]
         
-    text = text.replace('\n', ' ').strip()
+#     text = text.replace('\n', ' ').strip()
     
-    augmented = set()
+#     augmented = set()
+#     try:
+#         # Английские синонимы
+#         eng_augs = synonym_aug.augment(text, n=num_augments)
+#         if eng_augs:
+#             augmented.update(a for a in eng_augs if isinstance(a, str))
+        
+#         # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹
+#         try:
+#             ru_augs = ru_synonym_aug.augment(text, n=num_augments)
+#             if ru_augs:
+#                 augmented.update(a for a in ru_augs if isinstance(a, str))
+#         except Exception as e:
+#             logger.warning(f"Ошибка русской аугментации: {str(e)}")
+        
+#         # Обратный перевод
+#         if len(augmented) < num_augments:
+#             try:
+#                 if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'):
+#                     tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented))
+#                 else:
+#                     tr_augs = translation_aug.augment(text, n=num_augments-len(augmented))
+                    
+#                 if tr_augs:
+#                     augmented.update(a.replace(' ##', '') for a in tr_augs 
+#                                  if isinstance(a, str) and a is not None)
+                    
+#             except Exception as e:
+#                 logger.warning(f"Ошибка перевода: {str(e)}")
+                
+#         if not augmented:
+#             logger.debug(f"Не удалось аугментировать текст: {text[:50]}...")
+#             return [text]
+            
+#         augmented = list(set(augmented))
+#         return list(augmented)[:num_augments] if augmented else [text]
+#     except Exception as e:
+#         logger.error(f"Критическая ошибка аугментации: {str(e)}")
+#         return [text]
+
+def augment_text(text, num_augments):
+    """Безопасная генерация аугментированных примеров"""
     try:
+        if len(text) > 1000:
+            return [text[:1000]]  # Обрезаем слишком длинные тексты
+        
+        if not isinstance(text, str) or len(text.strip()) < 10:
+            return [text]
+            
+        text = text.replace('\n', ' ').strip()
+        augmented = set([text])  # Начинаем с оригинала
+        
         # Английские синонимы
-        eng_augs = synonym_aug.augment(text, n=num_augments)
-        if eng_augs:
-            augmented.update(a for a in eng_augs if isinstance(a, str))
+        try:
+            eng_augs = synonym_aug.augment(text, n=num_augments)
+            if eng_augs:
+                augmented.update(a for a in eng_augs if isinstance(a, str))
+        except Exception as e:
+            logger.warning(f"Ошибка английской аугментации: {str(e)}")
         
         # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹
         try:
@@ -204,78 +262,91 @@ def augment_text(text, num_augments):
         except Exception as e:
             logger.warning(f"Ошибка русской аугментации: {str(e)}")
         
-        # Обратный перевод
-        if len(augmented) < num_augments:
-            try:
-                if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'):
-                    tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented))
-                else:
-                    tr_augs = translation_aug.augment(text, n=num_augments-len(augmented))
-                    
-                if tr_augs:
-                    augmented.update(a.replace(' ##', '') for a in tr_augs 
-                                 if isinstance(a, str) and a is not None)
-                    
-            except Exception as e:
-                logger.warning(f"Ошибка перевода: {str(e)}")
-                
-        if not augmented:
-            logger.debug(f"Не удалось аугментировать текст: {text[:50]}...")
-            return [text]
-            
-        augmented = list(set(augmented))
         return list(augmented)[:num_augments] if augmented else [text]
+        
     except Exception as e:
         logger.error(f"Критическая ошибка аугментации: {str(e)}")
         return [text]
 
 
+# def balance_attack_types(unsafe_data):
+#     """Балансировка типов атак с аугментацией"""
+#     if len(unsafe_data) == 0:
+#         logger.warning("Получен пустой DataFrame для балансировки")
+#         return pd.DataFrame()
+    
+#     # Логирование исходного распределения
+#     original_counts = unsafe_data['type'].value_counts()
+#     logger.info("\nИсходное распределение типов атак:")
+#     logger.info(original_counts.to_string())
+    
+#     attack_counts = unsafe_data['type'].value_counts()
+#     max_count = attack_counts.max()
+    
+#     balanced = []
+#     for attack_type, count in attack_counts.items():
+#         subset = unsafe_data[unsafe_data['type'] == attack_type]
+        
+#         if count < max_count:
+#             num_needed = max_count - count
+#             num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed)
+            
+#             augmented = subset.sample(n=num_augments, replace=True)
+#             augmented['prompt'] = augmented['prompt'].apply(
+#             lambda x: augment_text(x, 1)[0]  # Просто берем первый элемент возвращаемого списка
+#                 )
+            
+#             # Логирование аугментированных примеров
+#             logger.info(f"\nАугментация для {attack_type}:")
+#             logger.info(f"Исходных примеров: {len(subset)}")
+#             logger.info(f"Создано аугментированных: {len(augmented)}")
+#             if len(augmented) > 0:
+#                 logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...")
+            
+#             subset = pd.concat([subset, augmented]).sample(frac=1)
+        
+#         balanced.append(subset.sample(n=max_count, replace=False))
+    
+#     result = pd.concat(balanced).sample(frac=1)
+    
+#     # Логирование итогового распределения
+#     logger.info("\nИтоговое распределение после балансировки:")
+#     logger.info(result['type'].value_counts().to_string())
+    
+#     return result
 
 def balance_attack_types(unsafe_data):
-    """Балансировка типов атак с аугментацией"""
+    """Устойчивая балансировка классов"""
     if len(unsafe_data) == 0:
-        logger.warning("Получен пустой DataFrame для балансировки")
         return pd.DataFrame()
     
-    # Логирование исходного распределения
-    original_counts = unsafe_data['type'].value_counts()
-    logger.info("\nИсходное распределение типов атак:")
-    logger.info(original_counts.to_string())
+    # Логирование статистики
+    type_counts = unsafe_data['type'].value_counts()
+    logger.info(f"\nИсходное распределение:\n{type_counts.to_string()}")
     
-    attack_counts = unsafe_data['type'].value_counts()
-    max_count = attack_counts.max()
+    # Определяем целевое количество для балансировки
+    target_count = type_counts.max()
+    balanced_dfs = []
     
-    balanced = []
-    for attack_type, count in attack_counts.items():
-        subset = unsafe_data[unsafe_data['type'] == attack_type]
+    for attack_type, count in type_counts.items():
+        subset = unsafe_data[unsafe_data['type'] == attack_type].copy()
         
-        if count < max_count:
-            num_needed = max_count - count
-            num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed)
+        if count < target_count:
+            needed = target_count - count
+            augment_factor = min(Config.AUGMENTATION_FACTOR.get(attack_type, 1), needed)
             
-            augmented = subset.sample(n=num_augments, replace=True)
-            augmented['prompt'] = augmented['prompt'].apply(
-            lambda x: augment_text(x, 1)[0]  # Просто берем первый элемент возвращаемого списка
-                )
-            
-            # Логирование аугментированных примеров
-            logger.info(f"\nАугментация для {attack_type}:")
-            logger.info(f"Исходных примеров: {len(subset)}")
-            logger.info(f"Создано аугментированных: {len(augmented)}")
-            if len(augmented) > 0:
-                logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...")
+            # Безопасная аугментация
+            augmented_samples = subset.sample(n=augment_factor, replace=True)
+            augmented_samples['prompt'] = augmented_samples['prompt'].apply(
+                lambda x: augment_text(x, 1)[0]
+            )
             
-            subset = pd.concat([subset, augmented]).sample(frac=1)
+            subset = pd.concat([subset, augmented_samples])
         
-        balanced.append(subset.sample(n=max_count, replace=False))
-    
-    result = pd.concat(balanced).sample(frac=1)
-    
-    # Логирование итогового распределения
-    logger.info("\nИтоговое распределение после балансировки:")
-    logger.info(result['type'].value_counts().to_string())
+        # Фиксируем размер выборки
+        balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count))
     
-    return result
+    return pd.concat(balanced_dfs).sample(frac=1)
     
 
 
diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py
index 66ae391..fbde2fe 100644
--- a/ULTRAMegaOB.py
+++ b/ULTRAMegaOB.py
@@ -17,6 +17,11 @@ import logging
 import nlpaug.augmenter.word as naw
 from collections import defaultdict
 from sklearn.metrics import classification_report
+import nltk
+nltk.download('punkt', quiet=True)
+nltk.download('averaged_perceptron_tagger', quiet=True)
+nltk.download('wordnet', quiet=True)
+nltk.download('omw-1.4', quiet=True)
 
 
 # Настройка логгирования
@@ -178,23 +183,76 @@ def compute_metrics(p):
 
 
 
-def augment_text(text, num_augments):
-    """Генерация аугментированных примеров с проверками"""
-    if len(text) > 1000:
-        logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов")
-        return [text]
+# def augment_text(text, num_augments):
+#     """Генерация аугментированных примеров с проверками"""
+#     if len(text) > 1000:
+#         logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов")
+#         return [text]
     
-    if not isinstance(text, str) or len(text.strip()) < 10:
-        return [text]
+#     if not isinstance(text, str) or len(text.strip()) < 10:
+#         return [text]
         
-    text = text.replace('\n', ' ').strip()
+#     text = text.replace('\n', ' ').strip()
     
-    augmented = set()
+#     augmented = set()
+#     try:
+#         # Английские синонимы
+#         eng_augs = synonym_aug.augment(text, n=num_augments)
+#         if eng_augs:
+#             augmented.update(a for a in eng_augs if isinstance(a, str))
+        
+#         # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹
+#         try:
+#             ru_augs = ru_synonym_aug.augment(text, n=num_augments)
+#             if ru_augs:
+#                 augmented.update(a for a in ru_augs if isinstance(a, str))
+#         except Exception as e:
+#             logger.warning(f"Ошибка русской аугментации: {str(e)}")
+        
+#         # Обратный перевод
+#         if len(augmented) < num_augments:
+#             try:
+#                 if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'):
+#                     tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented))
+#                 else:
+#                     tr_augs = translation_aug.augment(text, n=num_augments-len(augmented))
+                    
+#                 if tr_augs:
+#                     augmented.update(a.replace(' ##', '') for a in tr_augs 
+#                                  if isinstance(a, str) and a is not None)
+                    
+#             except Exception as e:
+#                 logger.warning(f"Ошибка перевода: {str(e)}")
+                
+#         if not augmented:
+#             logger.debug(f"Не удалось аугментировать текст: {text[:50]}...")
+#             return [text]
+            
+#         augmented = list(set(augmented))
+#         return list(augmented)[:num_augments] if augmented else [text]
+#     except Exception as e:
+#         logger.error(f"Критическая ошибка аугментации: {str(e)}")
+#         return [text]
+
+def augment_text(text, num_augments):
+    """Безопасная генерация аугментированных примеров"""
     try:
+        if len(text) > 1000:
+            return [text[:1000]]  # Обрезаем слишком длинные тексты
+        
+        if not isinstance(text, str) or len(text.strip()) < 10:
+            return [text]
+            
+        text = text.replace('\n', ' ').strip()
+        augmented = set([text])  # Начинаем с оригинала
+        
         # Английские синонимы
-        eng_augs = synonym_aug.augment(text, n=num_augments)
-        if eng_augs:
-            augmented.update(a for a in eng_augs if isinstance(a, str))
+        try:
+            eng_augs = synonym_aug.augment(text, n=num_augments)
+            if eng_augs:
+                augmented.update(a for a in eng_augs if isinstance(a, str))
+        except Exception as e:
+            logger.warning(f"Ошибка английской аугментации: {str(e)}")
         
         # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹
         try:
@@ -204,78 +262,91 @@ def augment_text(text, num_augments):
         except Exception as e:
             logger.warning(f"Ошибка русской аугментации: {str(e)}")
         
-        # Обратный перевод
-        if len(augmented) < num_augments:
-            try:
-                if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'):
-                    tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented))
-                else:
-                    tr_augs = translation_aug.augment(text, n=num_augments-len(augmented))
-                    
-                if tr_augs:
-                    augmented.update(a.replace(' ##', '') for a in tr_augs 
-                                 if isinstance(a, str) and a is not None)
-                    
-            except Exception as e:
-                logger.warning(f"Ошибка перевода: {str(e)}")
-                
-        if not augmented:
-            logger.debug(f"Не удалось аугментировать текст: {text[:50]}...")
-            return [text]
-            
-        augmented = list(set(augmented))
         return list(augmented)[:num_augments] if augmented else [text]
+        
     except Exception as e:
         logger.error(f"Критическая ошибка аугментации: {str(e)}")
         return [text]
 
 
+# def balance_attack_types(unsafe_data):
+#     """Балансировка типов атак с аугментацией"""
+#     if len(unsafe_data) == 0:
+#         logger.warning("Получен пустой DataFrame для балансировки")
+#         return pd.DataFrame()
+    
+#     # Логирование исходного распределения
+#     original_counts = unsafe_data['type'].value_counts()
+#     logger.info("\nИсходное распределение типов атак:")
+#     logger.info(original_counts.to_string())
+    
+#     attack_counts = unsafe_data['type'].value_counts()
+#     max_count = attack_counts.max()
+    
+#     balanced = []
+#     for attack_type, count in attack_counts.items():
+#         subset = unsafe_data[unsafe_data['type'] == attack_type]
+        
+#         if count < max_count:
+#             num_needed = max_count - count
+#             num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed)
+            
+#             augmented = subset.sample(n=num_augments, replace=True)
+#             augmented['prompt'] = augmented['prompt'].apply(
+#             lambda x: augment_text(x, 1)[0]  # Просто берем первый элемент возвращаемого списка
+#                 )
+            
+#             # Логирование аугментированных примеров
+#             logger.info(f"\nАугментация для {attack_type}:")
+#             logger.info(f"Исходных примеров: {len(subset)}")
+#             logger.info(f"Создано аугментированных: {len(augmented)}")
+#             if len(augmented) > 0:
+#                 logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...")
+            
+#             subset = pd.concat([subset, augmented]).sample(frac=1)
+        
+#         balanced.append(subset.sample(n=max_count, replace=False))
+    
+#     result = pd.concat(balanced).sample(frac=1)
+    
+#     # Логирование итогового распределения
+#     logger.info("\nИтоговое распределение после балансировки:")
+#     logger.info(result['type'].value_counts().to_string())
+    
+#     return result
 
 def balance_attack_types(unsafe_data):
-    """Балансировка типов атак с аугментацией"""
+    """Устойчивая балансировка классов"""
     if len(unsafe_data) == 0:
-        logger.warning("Получен пустой DataFrame для балансировки")
         return pd.DataFrame()
     
-    # Логирование исходного распределения
-    original_counts = unsafe_data['type'].value_counts()
-    logger.info("\nИсходное распределение типов атак:")
-    logger.info(original_counts.to_string())
+    # Логирование статистики
+    type_counts = unsafe_data['type'].value_counts()
+    logger.info(f"\nИсходное распределение:\n{type_counts.to_string()}")
     
-    attack_counts = unsafe_data['type'].value_counts()
-    max_count = attack_counts.max()
+    # Определяем целевое количество для балансировки
+    target_count = type_counts.max()
+    balanced_dfs = []
     
-    balanced = []
-    for attack_type, count in attack_counts.items():
-        subset = unsafe_data[unsafe_data['type'] == attack_type]
+    for attack_type, count in type_counts.items():
+        subset = unsafe_data[unsafe_data['type'] == attack_type].copy()
         
-        if count < max_count:
-            num_needed = max_count - count
-            num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed)
+        if count < target_count:
+            needed = target_count - count
+            augment_factor = min(Config.AUGMENTATION_FACTOR.get(attack_type, 1), needed)
             
-            augmented = subset.sample(n=num_augments, replace=True)
-            augmented['prompt'] = augmented['prompt'].apply(
-            lambda x: augment_text(x, 1)[0]  # Просто берем первый элемент возвращаемого списка
-                )
-            
-            # Логирование аугментированных примеров
-            logger.info(f"\nАугментация для {attack_type}:")
-            logger.info(f"Исходных примеров: {len(subset)}")
-            logger.info(f"Создано аугментированных: {len(augmented)}")
-            if len(augmented) > 0:
-                logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...")
+            # Безопасная аугментация
+            augmented_samples = subset.sample(n=augment_factor, replace=True)
+            augmented_samples['prompt'] = augmented_samples['prompt'].apply(
+                lambda x: augment_text(x, 1)[0]
+            )
             
-            subset = pd.concat([subset, augmented]).sample(frac=1)
+            subset = pd.concat([subset, augmented_samples])
         
-        balanced.append(subset.sample(n=max_count, replace=False))
-    
-    result = pd.concat(balanced).sample(frac=1)
-    
-    # Логирование итогового распределения
-    logger.info("\nИтоговое распределение после балансировки:")
-    logger.info(result['type'].value_counts().to_string())
+        # Фиксируем размер выборки
+        balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count))
     
-    return result
+    return pd.concat(balanced_dfs).sample(frac=1)
     
 
 
-- 
GitLab