import pandas as pd
from sklearn.model_selection import train_test_split

# Загрузка данных
data = pd.read_csv('all_dataset.csv')

# Разделение данных на train, validation и test
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)



from datasets import Dataset

# Преобразование данных в формат Dataset
train_dataset = Dataset.from_pandas(train_data)
val_dataset = Dataset.from_pandas(val_data)
test_dataset = Dataset.from_pandas(test_data)


from transformers import BertTokenizer

# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Функция для токенизации
# def preprocess_function(examples):
#     return tokenizer(examples['prompt'], truncation=True, padding=True, max_length=512)


def preprocess_function(examples):
    # Токенизация текста
    tokenized = tokenizer(examples['prompt'], truncation=True, padding=True, max_length=512)

    # Подготовка меток
    labels_safety = [0 if label == "safe" else 1 for label in examples['safety']]
    labels_attack = [0 if label == "jailbreak" else 1 if label == "evasion" else 2 if label == "generic attack" else 3 for label in examples['type']]

    # Объединяем метки в один тензор
    tokenized['labels'] = list(zip(labels_safety, labels_attack))
    return tokenized

# Токенизация данных
train_dataset = train_dataset.map(preprocess_function, batched=True)
val_dataset = val_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.map(preprocess_function, batched=True)


from transformers import BertPreTrainedModel, BertModel
from torch import nn
import torch


class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier_safety = nn.Linear(config.hidden_size, 2)  # safe/unsafe
        self.classifier_attack = nn.Linear(config.hidden_size, 4)  # jailbreak, evasion, generic attack, injection

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,  # Поглощаем все лишние аргументы
    ):
        # Если переданы inputs_embeds, используем их вместо input_ids
        if inputs_embeds is not None:
            outputs = self.bert(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        else:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        pooled_output = outputs.last_hidden_state[:, 0, :]  # Используем [CLS] токен

        # Классификация safe/unsafe
        logits_safety = self.classifier_safety(pooled_output)

        # Классификация типов атак
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            # Разделяем labels на labels_safety и labels_attack
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]

            # Вычисляем потери для обеих задач
            loss_fct = nn.CrossEntropyLoss()
            loss_safety = loss_fct(logits_safety, labels_safety)
            loss_attack = loss_fct(logits_attack, labels_attack)
            loss = loss_safety + loss_attack  # Общий loss

        return {
            'logits_safety': logits_safety,
            'logits_attack': logits_attack,
            'loss': loss,
            'attentions': outputs.attentions if output_attentions else None,
            'hidden_states': outputs.hidden_states if output_hidden_states else None,
        }



# Загрузка модели
model = MultiTaskBert.from_pretrained('bert-base-uncased')


from peft import get_peft_model, LoraConfig, TaskType

# Конфигурация LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value"],
)

# Добавление LoRA к модели
model = get_peft_model(model, lora_config)

# Вывод информации о trainable параметрах
model.print_trainable_parameters()




from transformers import Trainer, TrainingArguments
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

# Функция для вычисления метрик
def compute_metrics(p):
    preds_safety = np.argmax(p.predictions[0], axis=1)
    preds_attack = np.argmax(p.predictions[1], axis=1)
    labels_safety = p.label_ids[0]
    labels_attack = p.label_ids[1]

    # Метрики для safe/unsafe
    f1_safety = f1_score(labels_safety, preds_safety, average='weighted')
    precision_safety = precision_score(labels_safety, preds_safety, average='weighted')
    recall_safety = recall_score(labels_safety, preds_safety, average='weighted')

    # Метрики для типов атак
    f1_attack = f1_score(labels_attack, preds_attack, average='weighted')
    precision_attack = precision_score(labels_attack, preds_attack, average='weighted')
    recall_attack = recall_score(labels_attack, preds_attack, average='weighted')

    return {
        'f1_safety': f1_safety,
        'precision_safety': precision_safety,
        'recall_safety': recall_safety,
        'f1_attack': f1_attack,
        'precision_attack': precision_attack,
        'recall_attack': recall_attack,
    }

# Аргументы для обучения


training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1_safety",
    greater_is_better=True,
    label_names=["labels_safety", "labels_attack"],
    report_to="none",  # Отключаем W&B
)



# Создание Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# Обучение модели
trainer.train()

# Оценка модели на тестовом наборе
results = trainer.evaluate(test_dataset)
print("Fine-tuned Model Evaluation Results:")
print(results)


# Сохранение модели
model.save_pretrained('./fine-tuned-bert-lora-multi-task')
tokenizer.save_pretrained('./fine-tuned-bert-lora-multi-task')