From b9ab91fd2bd7eed11671afba0d2cce5663bb46b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?= =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?= =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru> Date: Thu, 27 Mar 2025 03:31:42 +0300 Subject: [PATCH] supermega --- .ipynb_checkpoints/ULTRAMegaOB-checkpoint.py | 203 +++++++++++++------ ULTRAMegaOB.py | 203 +++++++++++++------ 2 files changed, 274 insertions(+), 132 deletions(-) diff --git a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py index 66ae391..fbde2fe 100644 --- a/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py +++ b/.ipynb_checkpoints/ULTRAMegaOB-checkpoint.py @@ -17,6 +17,11 @@ import logging import nlpaug.augmenter.word as naw from collections import defaultdict from sklearn.metrics import classification_report +import nltk +nltk.download('punkt', quiet=True) +nltk.download('averaged_perceptron_tagger', quiet=True) +nltk.download('wordnet', quiet=True) +nltk.download('omw-1.4', quiet=True) # Настройка логгирования @@ -178,23 +183,76 @@ def compute_metrics(p): -def augment_text(text, num_augments): - """Генерация аугментированных примеров СЃ проверками""" - if len(text) > 1000: - logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов") - return [text] +# def augment_text(text, num_augments): +# """Генерация аугментированных примеров СЃ проверками""" +# if len(text) > 1000: +# logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов") +# return [text] - if not isinstance(text, str) or len(text.strip()) < 10: - return [text] +# if not isinstance(text, str) or len(text.strip()) < 10: +# return [text] - text = text.replace('\n', ' ').strip() +# text = text.replace('\n', ' ').strip() - augmented = set() +# augmented = set() +# try: +# # Английские СЃРёРЅРѕРЅРёРјС‹ +# eng_augs = synonym_aug.augment(text, n=num_augments) +# if eng_augs: +# augmented.update(a for a in eng_augs if isinstance(a, str)) + +# # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹ +# try: +# ru_augs = ru_synonym_aug.augment(text, n=num_augments) +# if ru_augs: +# augmented.update(a for a in ru_augs if isinstance(a, str)) +# except Exception as e: +# logger.warning(f"Ошибка СЂСѓСЃСЃРєРѕР№ аугментации: {str(e)}") + +# # Обратный перевод +# if len(augmented) < num_augments: +# try: +# if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'): +# tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented)) +# else: +# tr_augs = translation_aug.augment(text, n=num_augments-len(augmented)) + +# if tr_augs: +# augmented.update(a.replace(' ##', '') for a in tr_augs +# if isinstance(a, str) and a is not None) + +# except Exception as e: +# logger.warning(f"Ошибка перевода: {str(e)}") + +# if not augmented: +# logger.debug(f"РќРµ удалось аугментировать текст: {text[:50]}...") +# return [text] + +# augmented = list(set(augmented)) +# return list(augmented)[:num_augments] if augmented else [text] +# except Exception as e: +# logger.error(f"Критическая ошибка аугментации: {str(e)}") +# return [text] + +def augment_text(text, num_augments): + """Безопасная генерация аугментированных примеров""" try: + if len(text) > 1000: + return [text[:1000]] # Обрезаем слишком длинные тексты + + if not isinstance(text, str) or len(text.strip()) < 10: + return [text] + + text = text.replace('\n', ' ').strip() + augmented = set([text]) # Начинаем СЃ оригинала + # Английские СЃРёРЅРѕРЅРёРјС‹ - eng_augs = synonym_aug.augment(text, n=num_augments) - if eng_augs: - augmented.update(a for a in eng_augs if isinstance(a, str)) + try: + eng_augs = synonym_aug.augment(text, n=num_augments) + if eng_augs: + augmented.update(a for a in eng_augs if isinstance(a, str)) + except Exception as e: + logger.warning(f"Ошибка английской аугментации: {str(e)}") # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹ try: @@ -204,78 +262,91 @@ def augment_text(text, num_augments): except Exception as e: logger.warning(f"Ошибка СЂСѓСЃСЃРєРѕР№ аугментации: {str(e)}") - # Обратный перевод - if len(augmented) < num_augments: - try: - if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'): - tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented)) - else: - tr_augs = translation_aug.augment(text, n=num_augments-len(augmented)) - - if tr_augs: - augmented.update(a.replace(' ##', '') for a in tr_augs - if isinstance(a, str) and a is not None) - - except Exception as e: - logger.warning(f"Ошибка перевода: {str(e)}") - - if not augmented: - logger.debug(f"РќРµ удалось аугментировать текст: {text[:50]}...") - return [text] - - augmented = list(set(augmented)) return list(augmented)[:num_augments] if augmented else [text] + except Exception as e: logger.error(f"Критическая ошибка аугментации: {str(e)}") return [text] +# def balance_attack_types(unsafe_data): +# """Балансировка типов атак СЃ аугментацией""" +# if len(unsafe_data) == 0: +# logger.warning("Получен пустой DataFrame для балансировки") +# return pd.DataFrame() + +# # Логирование РёСЃС…РѕРґРЅРѕРіРѕ распределения +# original_counts = unsafe_data['type'].value_counts() +# logger.info("\nРСЃС…РѕРґРЅРѕРµ распределение типов атак:") +# logger.info(original_counts.to_string()) + +# attack_counts = unsafe_data['type'].value_counts() +# max_count = attack_counts.max() + +# balanced = [] +# for attack_type, count in attack_counts.items(): +# subset = unsafe_data[unsafe_data['type'] == attack_type] + +# if count < max_count: +# num_needed = max_count - count +# num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed) + +# augmented = subset.sample(n=num_augments, replace=True) +# augmented['prompt'] = augmented['prompt'].apply( +# lambda x: augment_text(x, 1)[0] # Просто берем первый элемент возвращаемого СЃРїРёСЃРєР° +# ) + +# # Логирование аугментированных примеров +# logger.info(f"\nАугментация для {attack_type}:") +# logger.info(f"Рсходных примеров: {len(subset)}") +# logger.info(f"Создано аугментированных: {len(augmented)}") +# if len(augmented) > 0: +# logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...") + +# subset = pd.concat([subset, augmented]).sample(frac=1) + +# balanced.append(subset.sample(n=max_count, replace=False)) + +# result = pd.concat(balanced).sample(frac=1) + +# # Логирование итогового распределения +# logger.info("\nРтоговое распределение после балансировки:") +# logger.info(result['type'].value_counts().to_string()) + +# return result def balance_attack_types(unsafe_data): - """Балансировка типов атак СЃ аугментацией""" + """Устойчивая балансировка классов""" if len(unsafe_data) == 0: - logger.warning("Получен пустой DataFrame для балансировки") return pd.DataFrame() - # Логирование РёСЃС…РѕРґРЅРѕРіРѕ распределения - original_counts = unsafe_data['type'].value_counts() - logger.info("\nРСЃС…РѕРґРЅРѕРµ распределение типов атак:") - logger.info(original_counts.to_string()) + # Логирование статистики + type_counts = unsafe_data['type'].value_counts() + logger.info(f"\nРСЃС…РѕРґРЅРѕРµ распределение:\n{type_counts.to_string()}") - attack_counts = unsafe_data['type'].value_counts() - max_count = attack_counts.max() + # Определяем целевое количество для балансировки + target_count = type_counts.max() + balanced_dfs = [] - balanced = [] - for attack_type, count in attack_counts.items(): - subset = unsafe_data[unsafe_data['type'] == attack_type] + for attack_type, count in type_counts.items(): + subset = unsafe_data[unsafe_data['type'] == attack_type].copy() - if count < max_count: - num_needed = max_count - count - num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed) + if count < target_count: + needed = target_count - count + augment_factor = min(Config.AUGMENTATION_FACTOR.get(attack_type, 1), needed) - augmented = subset.sample(n=num_augments, replace=True) - augmented['prompt'] = augmented['prompt'].apply( - lambda x: augment_text(x, 1)[0] # Просто берем первый элемент возвращаемого СЃРїРёСЃРєР° - ) - - # Логирование аугментированных примеров - logger.info(f"\nАугментация для {attack_type}:") - logger.info(f"Рсходных примеров: {len(subset)}") - logger.info(f"Создано аугментированных: {len(augmented)}") - if len(augmented) > 0: - logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...") + # Безопасная аугментация + augmented_samples = subset.sample(n=augment_factor, replace=True) + augmented_samples['prompt'] = augmented_samples['prompt'].apply( + lambda x: augment_text(x, 1)[0] + ) - subset = pd.concat([subset, augmented]).sample(frac=1) + subset = pd.concat([subset, augmented_samples]) - balanced.append(subset.sample(n=max_count, replace=False)) - - result = pd.concat(balanced).sample(frac=1) - - # Логирование итогового распределения - logger.info("\nРтоговое распределение после балансировки:") - logger.info(result['type'].value_counts().to_string()) + # Фиксируем размер выборки + balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) - return result + return pd.concat(balanced_dfs).sample(frac=1) diff --git a/ULTRAMegaOB.py b/ULTRAMegaOB.py index 66ae391..fbde2fe 100644 --- a/ULTRAMegaOB.py +++ b/ULTRAMegaOB.py @@ -17,6 +17,11 @@ import logging import nlpaug.augmenter.word as naw from collections import defaultdict from sklearn.metrics import classification_report +import nltk +nltk.download('punkt', quiet=True) +nltk.download('averaged_perceptron_tagger', quiet=True) +nltk.download('wordnet', quiet=True) +nltk.download('omw-1.4', quiet=True) # Настройка логгирования @@ -178,23 +183,76 @@ def compute_metrics(p): -def augment_text(text, num_augments): - """Генерация аугментированных примеров СЃ проверками""" - if len(text) > 1000: - logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов") - return [text] +# def augment_text(text, num_augments): +# """Генерация аугментированных примеров СЃ проверками""" +# if len(text) > 1000: +# logger.debug(f"Текст слишком длинный для аугментации: {len(text)} символов") +# return [text] - if not isinstance(text, str) or len(text.strip()) < 10: - return [text] +# if not isinstance(text, str) or len(text.strip()) < 10: +# return [text] - text = text.replace('\n', ' ').strip() +# text = text.replace('\n', ' ').strip() - augmented = set() +# augmented = set() +# try: +# # Английские СЃРёРЅРѕРЅРёРјС‹ +# eng_augs = synonym_aug.augment(text, n=num_augments) +# if eng_augs: +# augmented.update(a for a in eng_augs if isinstance(a, str)) + +# # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹ +# try: +# ru_augs = ru_synonym_aug.augment(text, n=num_augments) +# if ru_augs: +# augmented.update(a for a in ru_augs if isinstance(a, str)) +# except Exception as e: +# logger.warning(f"Ошибка СЂСѓСЃСЃРєРѕР№ аугментации: {str(e)}") + +# # Обратный перевод +# if len(augmented) < num_augments: +# try: +# if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'): +# tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented)) +# else: +# tr_augs = translation_aug.augment(text, n=num_augments-len(augmented)) + +# if tr_augs: +# augmented.update(a.replace(' ##', '') for a in tr_augs +# if isinstance(a, str) and a is not None) + +# except Exception as e: +# logger.warning(f"Ошибка перевода: {str(e)}") + +# if not augmented: +# logger.debug(f"РќРµ удалось аугментировать текст: {text[:50]}...") +# return [text] + +# augmented = list(set(augmented)) +# return list(augmented)[:num_augments] if augmented else [text] +# except Exception as e: +# logger.error(f"Критическая ошибка аугментации: {str(e)}") +# return [text] + +def augment_text(text, num_augments): + """Безопасная генерация аугментированных примеров""" try: + if len(text) > 1000: + return [text[:1000]] # Обрезаем слишком длинные тексты + + if not isinstance(text, str) or len(text.strip()) < 10: + return [text] + + text = text.replace('\n', ' ').strip() + augmented = set([text]) # Начинаем СЃ оригинала + # Английские СЃРёРЅРѕРЅРёРјС‹ - eng_augs = synonym_aug.augment(text, n=num_augments) - if eng_augs: - augmented.update(a for a in eng_augs if isinstance(a, str)) + try: + eng_augs = synonym_aug.augment(text, n=num_augments) + if eng_augs: + augmented.update(a for a in eng_augs if isinstance(a, str)) + except Exception as e: + logger.warning(f"Ошибка английской аугментации: {str(e)}") # Р СѓСЃСЃРєРёРµ СЃРёРЅРѕРЅРёРјС‹ try: @@ -204,78 +262,91 @@ def augment_text(text, num_augments): except Exception as e: logger.warning(f"Ошибка СЂСѓСЃСЃРєРѕР№ аугментации: {str(e)}") - # Обратный перевод - if len(augmented) < num_augments: - try: - if any(cyr_char in text for cyr_char in 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'): - tr_augs = translation_aug_ru.augment(text, n=num_augments-len(augmented)) - else: - tr_augs = translation_aug.augment(text, n=num_augments-len(augmented)) - - if tr_augs: - augmented.update(a.replace(' ##', '') for a in tr_augs - if isinstance(a, str) and a is not None) - - except Exception as e: - logger.warning(f"Ошибка перевода: {str(e)}") - - if not augmented: - logger.debug(f"РќРµ удалось аугментировать текст: {text[:50]}...") - return [text] - - augmented = list(set(augmented)) return list(augmented)[:num_augments] if augmented else [text] + except Exception as e: logger.error(f"Критическая ошибка аугментации: {str(e)}") return [text] +# def balance_attack_types(unsafe_data): +# """Балансировка типов атак СЃ аугментацией""" +# if len(unsafe_data) == 0: +# logger.warning("Получен пустой DataFrame для балансировки") +# return pd.DataFrame() + +# # Логирование РёСЃС…РѕРґРЅРѕРіРѕ распределения +# original_counts = unsafe_data['type'].value_counts() +# logger.info("\nРСЃС…РѕРґРЅРѕРµ распределение типов атак:") +# logger.info(original_counts.to_string()) + +# attack_counts = unsafe_data['type'].value_counts() +# max_count = attack_counts.max() + +# balanced = [] +# for attack_type, count in attack_counts.items(): +# subset = unsafe_data[unsafe_data['type'] == attack_type] + +# if count < max_count: +# num_needed = max_count - count +# num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed) + +# augmented = subset.sample(n=num_augments, replace=True) +# augmented['prompt'] = augmented['prompt'].apply( +# lambda x: augment_text(x, 1)[0] # Просто берем первый элемент возвращаемого СЃРїРёСЃРєР° +# ) + +# # Логирование аугментированных примеров +# logger.info(f"\nАугментация для {attack_type}:") +# logger.info(f"Рсходных примеров: {len(subset)}") +# logger.info(f"Создано аугментированных: {len(augmented)}") +# if len(augmented) > 0: +# logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...") + +# subset = pd.concat([subset, augmented]).sample(frac=1) + +# balanced.append(subset.sample(n=max_count, replace=False)) + +# result = pd.concat(balanced).sample(frac=1) + +# # Логирование итогового распределения +# logger.info("\nРтоговое распределение после балансировки:") +# logger.info(result['type'].value_counts().to_string()) + +# return result def balance_attack_types(unsafe_data): - """Балансировка типов атак СЃ аугментацией""" + """Устойчивая балансировка классов""" if len(unsafe_data) == 0: - logger.warning("Получен пустой DataFrame для балансировки") return pd.DataFrame() - # Логирование РёСЃС…РѕРґРЅРѕРіРѕ распределения - original_counts = unsafe_data['type'].value_counts() - logger.info("\nРСЃС…РѕРґРЅРѕРµ распределение типов атак:") - logger.info(original_counts.to_string()) + # Логирование статистики + type_counts = unsafe_data['type'].value_counts() + logger.info(f"\nРСЃС…РѕРґРЅРѕРµ распределение:\n{type_counts.to_string()}") - attack_counts = unsafe_data['type'].value_counts() - max_count = attack_counts.max() + # Определяем целевое количество для балансировки + target_count = type_counts.max() + balanced_dfs = [] - balanced = [] - for attack_type, count in attack_counts.items(): - subset = unsafe_data[unsafe_data['type'] == attack_type] + for attack_type, count in type_counts.items(): + subset = unsafe_data[unsafe_data['type'] == attack_type].copy() - if count < max_count: - num_needed = max_count - count - num_augments = min(Config.AUGMENTATION_FACTOR[attack_type], num_needed) + if count < target_count: + needed = target_count - count + augment_factor = min(Config.AUGMENTATION_FACTOR.get(attack_type, 1), needed) - augmented = subset.sample(n=num_augments, replace=True) - augmented['prompt'] = augmented['prompt'].apply( - lambda x: augment_text(x, 1)[0] # Просто берем первый элемент возвращаемого СЃРїРёСЃРєР° - ) - - # Логирование аугментированных примеров - logger.info(f"\nАугментация для {attack_type}:") - logger.info(f"Рсходных примеров: {len(subset)}") - logger.info(f"Создано аугментированных: {len(augmented)}") - if len(augmented) > 0: - logger.info(f"Пример аугментированного текста:\n{augmented.iloc[0]['prompt'][:200]}...") + # Безопасная аугментация + augmented_samples = subset.sample(n=augment_factor, replace=True) + augmented_samples['prompt'] = augmented_samples['prompt'].apply( + lambda x: augment_text(x, 1)[0] + ) - subset = pd.concat([subset, augmented]).sample(frac=1) + subset = pd.concat([subset, augmented_samples]) - balanced.append(subset.sample(n=max_count, replace=False)) - - result = pd.concat(balanced).sample(frac=1) - - # Логирование итогового распределения - logger.info("\nРтоговое распределение после балансировки:") - logger.info(result['type'].value_counts().to_string()) + # Фиксируем размер выборки + balanced_dfs.append(subset.sample(n=target_count, replace=len(subset) < target_count)) - return result + return pd.concat(balanced_dfs).sample(frac=1) -- GitLab