import os
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
from datasets import Dataset, load_from_disk
from transformers import BertTokenizer, BertPreTrainedModel, BertModel, Trainer, TrainingArguments
from torch import nn
from peft import get_peft_model, LoraConfig, TaskType

# Очистка кеша
torch.cuda.empty_cache()

# Определяем устройство (GPU или CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Пути для сохранения токенизированных данных
TOKENIZED_DATA_DIR = "./tokenized_data"
TRAIN_TOKENIZED_PATH = os.path.join(TOKENIZED_DATA_DIR, "train")
VAL_TOKENIZED_PATH = os.path.join(TOKENIZED_DATA_DIR, "val")
TEST_TOKENIZED_PATH = os.path.join(TOKENIZED_DATA_DIR, "test")

# Загрузка данных
data = pd.read_csv('all_dataset.csv')

# Разделение данных на train, validation и test
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)

# Преобразование данных в Dataset
train_dataset = Dataset.from_pandas(train_data)
val_dataset = Dataset.from_pandas(val_data)
test_dataset = Dataset.from_pandas(test_data)

# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Функция токенизации
def preprocess_function(examples):
    tokenized = tokenizer(examples['prompt'], truncation=True, padding=True, max_length=512)

    labels_safety = [0 if label == "safe" else 1 for label in examples['safety']]
    labels_attack = [0 if label == "jailbreak" else 1 if label == "evasion" else 2 if label == "generic attack" else 3 for label in examples['type']]

    tokenized['labels'] = list(zip(labels_safety, labels_attack))
    return tokenized

# Токенизация данных (если не сохранены, то создаем)
if os.path.exists(TRAIN_TOKENIZED_PATH) and os.path.exists(VAL_TOKENIZED_PATH) and os.path.exists(TEST_TOKENIZED_PATH):
    train_dataset = load_from_disk(TRAIN_TOKENIZED_PATH)
    val_dataset = load_from_disk(VAL_TOKENIZED_PATH)
    test_dataset = load_from_disk(TEST_TOKENIZED_PATH)
else:
    train_dataset = train_dataset.map(preprocess_function, batched=True)
    val_dataset = val_dataset.map(preprocess_function, batched=True)
    test_dataset = test_dataset.map(preprocess_function, batched=True)

    os.makedirs(TOKENIZED_DATA_DIR, exist_ok=True)
    train_dataset.save_to_disk(TRAIN_TOKENIZED_PATH)
    val_dataset.save_to_disk(VAL_TOKENIZED_PATH)
    test_dataset.save_to_disk(TEST_TOKENIZED_PATH)

# Вычисление весов классов
class_weights_task1 = compute_class_weight('balanced', classes=np.unique(train_data['safety']), y=train_data['safety'])
class_weights_task2 = compute_class_weight('balanced', classes=np.unique(train_data[train_data['safety'] == 'unsafe']['type']),
                                            y=train_data[train_data['safety'] == 'unsafe']['type'])

# Перевод весов в тензоры
class_weights_task1_tensor = torch.tensor(class_weights_task1, dtype=torch.float32).to(device)
class_weights_task2_tensor = torch.tensor(class_weights_task2, dtype=torch.float32).to(device)

# Определение модели
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Переводим тензоры на устройство
        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
            loss_safety = nn.CrossEntropyLoss(weight=class_weights_task1_tensor)(logits_safety, labels_safety)
            loss_attack = nn.CrossEntropyLoss(weight=class_weights_task2_tensor)(logits_attack, labels_attack)
            loss = loss_safety + loss_attack

        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}

# Создание модели
model = MultiTaskBert.from_pretrained('bert-base-uncased').to(device)

# Настройка LoRA.
# Явно исключаем сохранение модулей, не адаптированных LoRA (например, классификаторов),
# чтобы не возникало KeyError при загрузке.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value"],
    modules_to_save=[]  # Не сохраняем дополнительные модули (classifier и т.д.)
)
model = get_peft_model(model, lora_config)

# Функция вычисления метрик
def compute_metrics(p):
    preds_safety = np.argmax(p.predictions[0], axis=1)
    preds_attack = np.argmax(p.predictions[1], axis=1)
    labels_safety, labels_attack = p.label_ids[:, 0], p.label_ids[:, 1]

    return {
        'f1_safety': f1_score(labels_safety, preds_safety, average='weighted'),
        'precision_safety': precision_score(labels_safety, preds_safety, average='weighted'),
        'recall_safety': recall_score(labels_safety, preds_safety, average='weighted'),
        'f1_attack': f1_score(labels_attack, preds_attack, average='weighted'),
        'precision_attack': precision_score(labels_attack, preds_attack, average='weighted'),
        'recall_attack': recall_score(labels_attack, preds_attack, average='weighted'),
    }

# Аргументы обучения
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1_safety",
    greater_is_better=True,
    fp16=True,
    max_grad_norm=1.0,
    warmup_steps=100,
    report_to="none",
)

# Обучение
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics)
trainer.train()

# Оценка
val_results = trainer.evaluate(val_dataset)
test_results = trainer.evaluate(test_dataset)

print("Validation Results:", val_results)
print("Test Results:", test_results)

# График потерь
logs = trainer.state.log_history
train_loss = [log["loss"] for log in logs if "loss" in log]
val_loss = [log["eval_loss"] for log in logs if "eval_loss" in log]

plt.plot(train_loss, label="Train Loss")
plt.plot(val_loss, label="Validation Loss")
plt.legend()
plt.show()

# # Сохранение модели вместе с адаптерами LoRA
# trainer.save_model('./fine-tuned-bert-lora_new')
# tokenizer.save_pretrained('./fine-tuned-bert-lora_new')
# Сохранение модели, адаптеров LoRA и токенизатора
model.save_pretrained('./fine-tuned-bert-lora_new2')  # Сохраняет модель и её веса
model.save_adapter('./fine-tuned-bert-lora_new2')  # Сохраняет адаптеры LoRA
tokenizer.save_pretrained('./fine-tuned-bert-lora_new2')  # Сохраняет токенизатор
print("Все сохранено")