import os
import pandas as pd
import torch
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from datasets import Dataset
from transformers import (
    BertTokenizer, BertPreTrainedModel, BertModel,
    Trainer, TrainingArguments
)
from torch import nn
from peft import get_peft_model, LoraConfig, TaskType
import shutil

# Configuration
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NUM_FOLDS = 5                      # cross-validation folds
BATCH_SIZE = 16                    # per-device batch size
EPOCHS = 3                         # training epochs per fold
MODEL_NAME = 'bert-base-uncased'
OUTPUT_DIR = './results'

# Load the dataset; expected columns: 'prompt', 'safety', 'type'
data = pd.read_csv('all_dataset.csv')

# Tokenization with a separate label column per task
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Attack-type string -> class id; any unlisted type falls into class 3.
_ATTACK_CLASS_IDS = {"jailbreak": 0, "evasion": 1, "generic attack": 2}

def preprocess_function(examples):
    """Tokenize a batch of prompts and attach integer labels for both heads.

    Safety head: 0 for "safe", 1 for anything else.
    Attack head: jailbreak=0, evasion=1, generic attack=2, other=3.
    """
    encoded = tokenizer(
        examples['prompt'],
        truncation=True,
        padding='max_length',
        max_length=512,
    )
    encoded['labels_safety'] = [int(label != "safe") for label in examples['safety']]
    encoded['labels_attack'] = [_ATTACK_CLASS_IDS.get(label, 3) for label in examples['type']]
    return encoded

# Model
class MultiTaskBert(BertPreTrainedModel):
    """BERT encoder shared by two classification heads.

    Heads:
      * ``classifier_safety`` — 2-way safe/unsafe classification.
      * ``classifier_attack`` — 4-way attack-type classification.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)
        # Required for PreTrainedModel subclasses: initializes weights
        # (including the freshly added heads) and runs final processing.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels_safety=None, labels_attack=None, **kwargs):
        """Encode the batch and score both tasks.

        Returns a dict with ``loss`` (sum of the two cross-entropies, or
        ``None`` when either label set is absent), ``logits_safety`` of
        shape (batch, 2) and ``logits_attack`` of shape (batch, 4).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Use the [CLS] token's hidden state as the sequence representation.
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels_safety is not None and labels_attack is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits_safety, labels_safety) + loss_fct(logits_attack, labels_attack)

        # 'loss' first follows the HF convention for dict model outputs.
        return {'loss': loss, 'logits_safety': logits_safety, 'logits_attack': logits_attack}

# Cross-validation
kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
all_metrics = []
best_fold = None
best_fold_metrics = {'eval_f1_safety': -1}

for fold, (train_idx, val_idx) in enumerate(kf.split(data)):
    print(f"\n=== Fold {fold + 1}/{NUM_FOLDS} ===")

    # Prepare this fold's train/validation splits
    train_fold = data.iloc[train_idx]
    val_fold = data.iloc[val_idx]

    train_dataset = Dataset.from_pandas(train_fold).map(preprocess_function, batched=True)
    val_dataset = Dataset.from_pandas(val_fold).map(preprocess_function, batched=True)

    # Fresh model per fold, wrapped with LoRA adapters
    model = MultiTaskBert.from_pretrained(MODEL_NAME).to(device)
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query", "value"],
        # Keep both task heads fully trainable and saved with the adapter.
        modules_to_save=["classifier_safety", "classifier_attack"]
    )
    model = get_peft_model(model, peft_config)

    # Metrics. With multiple entries in `label_names`, the Trainer hands
    # `label_ids` over as a tuple of arrays (one per label name), not as a
    # single 2-D array — so unpack the tuple instead of column-slicing.
    def compute_metrics(p):
        preds_safety = np.argmax(p.predictions[0], axis=1)
        preds_attack = np.argmax(p.predictions[1], axis=1)
        labels_safety, labels_attack = p.label_ids
        return {
            'eval_f1_safety': f1_score(labels_safety, preds_safety, average='weighted'),
            'eval_f1_attack': f1_score(labels_attack, preds_attack, average='weighted')
        }

    # Training configuration. `label_names` is a TrainingArguments field;
    # Trainer.__init__ does not accept it as a keyword argument.
    training_args = TrainingArguments(
        output_dir=os.path.join(OUTPUT_DIR, f'fold_{fold}'),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=3e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1_safety",
        greater_is_better=True,
        fp16=torch.cuda.is_available(),  # fp16 requires a CUDA device
        label_names=["labels_safety", "labels_attack"],
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    trainer.train()
    fold_metrics = trainer.evaluate()
    all_metrics.append(fold_metrics)

    # Track and persist the best fold by safety F1
    if fold_metrics['eval_f1_safety'] > best_fold_metrics['eval_f1_safety']:
        best_fold = fold
        best_fold_metrics = fold_metrics

        best_model_dir = os.path.join(OUTPUT_DIR, 'best_model')

        # Drop the previous best model before overwriting
        if os.path.exists(best_model_dir):
            shutil.rmtree(best_model_dir)

        # save_pretrained on a PEFT model already writes the LoRA adapter
        # weights and config (PeftModel has no `save_adapter` method).
        model.save_pretrained(best_model_dir)
        tokenizer.save_pretrained(best_model_dir)

# Final cross-validation summary (user-facing strings are intentionally Russian)
print("\n=== Результаты кросс-валидации ===")
for i, metrics in enumerate(all_metrics):
    print(f"Fold {i + 1}: F1 Safety = {metrics['eval_f1_safety']:.4f}, F1 Attack = {metrics['eval_f1_attack']:.4f}")

# `best_fold` / `best_fold_metrics` were set inside the CV loop above.
print(f"\nЛучшая модель: Fold {best_fold + 1}")
print(f"F1 Safety: {best_fold_metrics['eval_f1_safety']:.4f}")
print(f"F1 Attack: {best_fold_metrics['eval_f1_attack']:.4f}")
print(f"\nМодель сохранена в: {os.path.join(OUTPUT_DIR, 'best_model')}")