An error occurred while loading the file. Please try again.
-
Adam Reese authored
Signed-off-by:
Adam Reese <adam@reese.io>
Unverified b4ed1de6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from peft import PeftConfig, PeftModel
from torch import nn
from transformers import BertModel, BertTokenizer, BertTokenizerFast
# 1. Device selection: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory the tokenizer and the LoRA adapter (config + weights) are loaded from.
MODEL_DIR = "./fine-tuned-bert-lora_new"
# 2. Tokenizer loading.
# `use_fast` is an AutoTokenizer option; BertTokenizer is the slow pure-Python
# class and ignores that keyword, so the previous call never produced a fast
# tokenizer. Load the fast (Rust-backed) BERT tokenizer class directly instead.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
# 3. Model architecture
class MultiTaskBert(nn.Module):
    """BERT encoder shared by two classification heads.

    Both heads read the final hidden state of the first token ([CLS]):
    ``safety_head`` emits safe/unsafe logits and ``attack_head`` emits one
    logit per attack type.

    Args:
        base_model: Encoder called as ``base_model(input_ids=...,
            attention_mask=..., return_dict=True)`` that returns an object with
            a rank-3 ``last_hidden_state`` (batch, seq, hidden). Its
            ``config.hidden_size`` sets the head input width; 768 is used when
            no config is exposed.
        num_safety_labels: Output size of the safety head (default 2: safe/unsafe).
        num_attack_labels: Output size of the attack head (default 4 attack types).

    NOTE(review): both heads start from nn.Linear's random initialisation;
    trained head weights must be loaded separately for meaningful outputs.
    """

    def __init__(self, base_model, num_safety_labels=2, num_attack_labels=4):
        super().__init__()
        self.bert = base_model
        # Take the encoder width from its config instead of hard-coding 768
        # (the bert-base width), so larger/smaller encoders also work.
        config = getattr(base_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        # Create the heads on the encoder's device so they are colocated with
        # it without depending on a module-level `device` global.
        first_param = next(base_model.parameters(), None)
        head_device = first_param.device if first_param is not None else None
        self.safety_head = nn.Linear(hidden_size, num_safety_labels, device=head_device)
        self.attack_head = nn.Linear(hidden_size, num_attack_labels, device=head_device)

    def forward(self, input_ids, attention_mask):
        """Return ``{'safety': logits, 'attack': logits}`` for a token batch."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # First position is the [CLS] token; its final hidden state serves as
        # the sequence summary for both heads.
        pooled = outputs.last_hidden_state[:, 0, :]
        return {
            'safety': self.safety_head(pooled),
            'attack': self.attack_head(pooled),
        }
# 4. Model assembly.
# The adapter config records which pretrained encoder the LoRA weights were
# trained on; rebuild that encoder without BERT's pooler (the heads read the
# [CLS] hidden state directly) and move it to the target device.
peft_config = PeftConfig.from_pretrained(MODEL_DIR)
base_model = BertModel.from_pretrained(
    peft_config.base_model_name_or_path,
    add_pooling_layer=False,
).to(device)
model = MultiTaskBert(base_model).to(device)
# Attach the LoRA adapter to the encoder and immediately fold its deltas into
# the base weights, so inference runs on a plain encoder with no PEFT wrapper.
model.bert = PeftModel.from_pretrained(model.bert, MODEL_DIR).merge_and_unload()
# NOTE(review): nothing here loads weights for model.safety_head /
# model.attack_head -- they keep nn.Linear's random initialisation, so the
# probabilities from predict() are not the fine-tuned ones unless the trained
# head weights are restored. TODO confirm how the heads were saved in training
# and load them at this point.
model.eval()
# 5. Prediction
# Output label names, in the order of the corresponding head logits.
_SAFETY_LABELS = ("safe", "unsafe")
_ATTACK_LABELS = ("jailbreak", "evasion", "generic", "injection")


def predict(text):
    """Score *text* for safety and attack type with the global ``model``.

    Args:
        text: Input string. When a list is given, only the first item's
            scores are returned (same as before).

    Returns:
        ``{'safety': {'safe': p, 'unsafe': p},
        'attack': {'jailbreak': p, 'evasion': p, 'generic': p, 'injection': p}}``
        where each ``p`` is a softmax probability as a Python float.
    """
    # padding=True pads only to the longest sequence in this call (i.e. not at
    # all for a single text); padding to 512 forced a full-length forward pass
    # for every input. Truncation still caps inputs at 512 tokens.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # MultiTaskBert.forward accepts only input_ids/attention_mask, so pass
    # those explicitly rather than the whole encoding (which may also carry
    # token_type_ids).
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    safety_probs = torch.softmax(outputs['safety'], dim=-1)[0].tolist()
    attack_probs = torch.softmax(outputs['attack'], dim=-1)[0].tolist()
    return {
        'safety': dict(zip(_SAFETY_LABELS, safety_probs)),
        'attack': dict(zip(_ATTACK_LABELS, attack_probs)),
    }
# 6. Usage example -- runs only when this file is executed as a script, so
# importing it (e.g. to reuse predict) does not print demo output.
if __name__ == "__main__":
    text = "How to bypass security"
    result = predict(text)
    print(f"\nРезультат для текста: '{text}'")
    print(f"Безопасность: Safe {result['safety']['safe']:.1%} | Unsafe {result['safety']['unsafe']:.1%}")
    # The attack-type breakdown is shown only when the text is classified unsafe.
    if result['safety']['unsafe'] > 0.5:
        print("\nВероятности типов атак:")
        for name, prob in result['attack'].items():
            print(f"{name}: {prob:.1%}")