import logging

import pandas as pd
import torch
from torch import nn
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)
from peft import get_peft_model, LoraConfig, TaskType
import nltk

# Set up NLTK once at the beginning
nltk.download('punkt', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('model_training.log'), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

class ModelConfig:
    """Упрощенная конфигурация модели"""
    def __init__(self):
        self.model_name = 'distilbert-base-multilingual-cased'  # Более легкая модель
        self.max_length = 128  # Уменьшенная длина последовательности
        self.batch_size = 8
        self.epochs = 5  # Меньше эпох
        self.safety_threshold = 0.5
        self.test_size = 0.2
        self.val_size = 0.1
        self.early_stopping_patience = 2
        self.learning_rate = 2e-5
        self.seed = 42
        self.fp16 = True
        self.gradient_accumulation_steps = 4  # Уменьшено
        self.max_grad_norm = 1.0
        self.lora_r = 4  # Уменьшено
        self.lora_alpha = 8  # Уменьшено
        self.lora_dropout = 0.1

class SafetyModel(nn.Module):
    """Simplified two-head model (binary safety + attack type) to save memory"""
    def __init__(self, model_name: str):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.safety_head = nn.Linear(self.bert.config.hidden_size, 2)
        self.attack_head = nn.Linear(self.bert.config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels_safety=None, labels_attack=None, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token serves as the pooled representation
        pooled = outputs.last_hidden_state[:, 0, :]
        logits_safety = self.safety_head(pooled)
        logits_attack = self.attack_head(pooled)
        result = {'logits_safety': logits_safety, 'logits_attack': logits_attack}
        # Trainer expects a 'loss' key when labels are supplied
        if labels_safety is not None:
            loss = nn.functional.cross_entropy(logits_safety, labels_safety)
            if labels_attack is not None:
                # -1 marks safe prompts with no attack type; exclude them from the attack loss
                valid = labels_attack != -1
                if valid.any():
                    loss = loss + nn.functional.cross_entropy(logits_attack[valid], labels_attack[valid])
            result['loss'] = loss
        return result

def load_data() -> pd.DataFrame:
    """Загрузка данных без балансировки"""
    try:
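        # Expected columns in all_dataset.csv (assumed from how the data is used below):
        # 'prompt' (text), 'safety' ("safe" for benign rows) and 'type' (attack class for unsafe rows)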
        data = pd.read_csv('all_dataset.csv')
        data = data.dropna(subset=['prompt'])
        data['prompt'] = data['prompt'].str.strip()
        data = data[data['prompt'].str.len() > 0]
        return data
    except Exception as e:
        logger.error(f"Ошибка загрузки данных: {str(e)}")
        raise

def tokenize_data(tokenizer, df: pd.DataFrame, max_length: int) -> Dataset:
    """Simplified tokenization"""
    df = df.copy()
    df['labels_safety'] = df['safety'].apply(lambda x: 0 if x == "safe" else 1)
    # Map attack types to class ids; -1 marks rows with no (known) attack type
    df['labels_attack'] = df['type'].map(
        {'jailbreak': 0, 'injection': 1, 'evasion': 2, 'generic attack': 3, 'generic_attack': 3}
    ).fillna(-1).astype(int)
    df.loc[df['safety'] == 'safe', 'labels_attack'] = -1

    dataset = Dataset.from_pandas(df)

    def preprocess(examples):
        return tokenizer(
            examples['prompt'],
            truncation=True,
            padding='max_length',
            max_length=max_length
        )

    # Keep the label columns so the Trainer can compute the loss
    keep = {'labels_safety', 'labels_attack'}
    return dataset.map(
        preprocess,
        batched=True,
        batch_size=1000,
        remove_columns=[c for c in dataset.column_names if c not in keep]
    )

def train():
    """Основная функция обучения"""
    try:
        config = ModelConfig()
        set_seed(config.seed)
        
        # Load the data
        logger.info("Loading data...")
        data = load_data()
        
        # Split the data
        train_data, test_data = train_test_split(
            data, test_size=config.test_size, random_state=config.seed
        )
        train_data, val_data = train_test_split(
            train_data, test_size=config.val_size, random_state=config.seed
        )
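        # Note: val_size is taken from the remaining train portion, so validation is
        # roughly config.val_size * (1 - config.test_size) = 8% of the full dataset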
        
        # Tokenization
        logger.info("Tokenizing...")
        tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        train_dataset = tokenize_data(tokenizer, train_data, config.max_length)
        val_dataset = tokenize_data(tokenizer, val_data, config.max_length)
        
        # Model
        logger.info("Initializing the model...")
        model = SafetyModel(config.model_name)
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=["q_lin", "v_lin"],  # DistilBERT attention projection layers
            modules_to_save=["safety_head", "attack_head"]  # keep the classification heads trainable and saved
        )
        model = get_peft_model(model, peft_config)
        
        # Training
        training_args = TrainingArguments(
            output_dir='./output',
            evaluation_strategy="epoch",
            save_strategy="epoch",  # must match the evaluation strategy for load_best_model_at_end
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size * 2,
            num_train_epochs=config.epochs,
            learning_rate=config.learning_rate,
            max_grad_norm=config.max_grad_norm,
            seed=config.seed,
            fp16=config.fp16 and torch.cuda.is_available(),
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            load_best_model_at_end=True,
            metric_for_best_model='eval_loss',
            greater_is_better=False,
            label_names=['labels_safety', 'labels_attack']  # label columns for the custom multi-head model
        )
        
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)]
        )
        
        logger.info("Старт обучения...")
        trainer.train()
        
        # Save the adapter and tokenizer
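        # Note: save_pretrained on a PEFT-wrapped model stores only the LoRA adapter weights
        # (plus the modules_to_save heads); the base SafetyModel must be rebuilt when loading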
        model.save_pretrained('./model')
        tokenizer.save_pretrained('./model')
        
        logger.info("Обучение завершено!")
        
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise

if __name__ == "__main__":
    train()