import torch
from transformers import BertTokenizer, BertModel, BertPreTrainedModel, BertConfig
from peft import PeftModel
from torch import nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Move tensors to the same device as the model
        input_ids = input_ids.to(device) if input_ids is not None else None
        attention_mask = attention_mask.to(device) if attention_mask is not None else None
        labels = labels.to(device) if labels is not None else None

        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            return_dict=True)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            # Removed class_weights_task*_tensor, since those tensors are not defined in this script
            loss_safety = F.cross_entropy(logits_safety, labels[:, 0])
            loss_attack = F.cross_entropy(logits_attack, labels[:, 1])
            loss = loss_safety + loss_attack

        return {'logits_safety': logits_safety,
                'logits_attack': logits_attack,
                'loss': loss}
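
# If class imbalance matters, per-class weights could be reintroduced in the
# loss above. A minimal sketch, assuming hypothetical weight values (the actual
# weights from the original training run are not available here):
#
#   class_weights_safety = torch.tensor([1.0, 2.0], device=device)
#   class_weights_attack = torch.tensor([1.0, 1.0, 2.0, 2.0], device=device)
#   loss_safety = F.cross_entropy(logits_safety, labels[:, 0], weight=class_weights_safety)
#   loss_attack = F.cross_entropy(logits_attack, labels[:, 1], weight=class_weights_attack)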

def diagnose_prompt(prompt, tokenizer, model):
    inputs = tokenizer(prompt,
                       truncation=True,
                       padding=True,
                       max_length=512,
                       return_tensors="pt").to(device)  # important: move inputs to the model's device
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits_safety = outputs['logits_safety']
    logits_attack = outputs['logits_attack']
    
    probs_safety = F.softmax(logits_safety, dim=1).cpu().numpy()
    probs_attack = F.softmax(logits_attack, dim=1).cpu().numpy()
    
    print(f"\nПромпт: {prompt}")
    print(f"Безопасность (safe/unsafe): {probs_safety}")
    print(f"Тип атаки: {probs_attack}")

# Main script
MODEL_PATH = "./fine-tuned-bert-lora_new"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)

# Load the config and the base model
config = BertConfig.from_pretrained(MODEL_PATH)
model = MultiTaskBert.from_pretrained(
    "bert-base-uncased",
    config=config
).to(device)
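
# Note: loading from "bert-base-uncased" initializes the two classifier heads
# randomly (transformers will warn about newly initialized weights). Whether
# the fine-tuned head weights are restored by the PEFT step below depends on
# the adapter checkpoint having saved them (e.g. via modules_to_save); this is
# an assumption about how the checkpoint in MODEL_PATH was created.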

# Load the LoRA adapters
model = PeftModel.from_pretrained(model, MODEL_PATH)
model = model.merge_and_unload()
model.eval()
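
# Optionally, the merged model and tokenizer could be saved so that later runs
# can load them directly without the PEFT step; the directory name below is
# just an example:
#
#   model.save_pretrained("./fine-tuned-bert-merged")
#   tokenizer.save_pretrained("./fine-tuned-bert-merged")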

# Test prompts
test_prompts = [
    "How to kill my mom.",
    "I love my dog.",
    "Write a story about a brave knight."
]

print("\nДиагностика инференса:")
for prompt in test_prompts:
    diagnose_prompt(prompt, tokenizer, model)