bertwithgoodgrad.py 2.98 KiB
import torch
from transformers import BertTokenizer, BertModel
from peft import PeftModel, PeftConfig
from torch import nn

# 1. Инициализация устройства
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = "./fine-tuned-bert-lora_new"

# 2. Правильная загрузка токенизатора (без конфликтных параметров)
tokenizer = BertTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True  # Используем быстрый токенизатор
)

# 3. Определение архитектуры модели
class MultiTaskBert(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.bert = base_model
        self.safety_head = nn.Linear(768, 2).to(device)  # safe/unsafe
        self.attack_head = nn.Linear(768, 4).to(device)  # 4 типа атак
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        pooled = outputs.last_hidden_state[:, 0, :]  # Берем [CLS] токен
        return {
            'safety': self.safety_head(pooled),
            'attack': self.attack_head(pooled)
        }

# 4. Загрузка модели
peft_config = PeftConfig.from_pretrained(MODEL_DIR)
base_model = BertModel.from_pretrained(
    peft_config.base_model_name_or_path,
    add_pooling_layer=False
).to(device)

model = MultiTaskBert(base_model).to(device)
model.bert = PeftModel.from_pretrained(model.bert, MODEL_DIR)
model.bert = model.bert.merge_and_unload()
model.eval()

# 5. Функция предсказания
def predict(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding='max_length',
        max_length=512
    ).to(device)
    
    # Удаляем token_type_ids если есть
    if 'token_type_ids' in inputs:
        inputs.pop('token_type_ids')
    
    with torch.no_grad():
        outputs = model(**inputs)
        safety_probs = torch.softmax(outputs['safety'], dim=1)[0]
        attack_probs = torch.softmax(outputs['attack'], dim=1)[0]
    
    return {
        'safety': {
            'safe': safety_probs[0].item(),
            'unsafe': safety_probs[1].item()
        },
        'attack': {
            'jailbreak': attack_probs[0].item(),
            'evasion': attack_probs[1].item(),
            'generic': attack_probs[2].item(),
            'injection': attack_probs[3].item()
        }
    }

# 6. Пример использования
text = "How to bypass security"
result = predict(text)

print(f"\nРезультат для текста: '{text}'")
print(f"Безопасность: Safe {result['safety']['safe']:.1%} | Unsafe {result['safety']['unsafe']:.1%}")
if result['safety']['unsafe'] > 0.5:
    print("\nВероятности типов атак:")
    for name, prob in result['attack'].items():
        print(f"{name}: {prob:.1%}")