-
Мазур Грета Евгеньевна authored03d67af5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from transformers import BertTokenizer
import torch
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
# Убедитесь, что класс MultiTaskBert определён, как в вашем первоначальном коде
class MultiTaskBert(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.classifier_safety = nn.Linear(config.hidden_size, 2)
self.classifier_attack = nn.Linear(config.hidden_size, 4)
def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
# Переводим тензоры на устройство
input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
pooled_output = outputs.last_hidden_state[:, 0, :]
logits_safety = self.classifier_safety(pooled_output)
logits_attack = self.classifier_attack(pooled_output)
loss = None
if labels is not None:
labels_safety, labels_attack = labels[:, 0], labels[:, 1]
loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
loss_attack = nn.CrossEntropyLoss()(logits_attack, labels_attack)
loss = loss_safety + loss_attack
return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}
# Загрузка модели с LoRA адаптерами
model = MultiTaskBert.from_pretrained('./fine-tuned-bert-lora_new').to(device)
# Восстановление модели с LoRA
lora_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["query", "value"],
modules_to_save=[] # Не сохраняем дополнительные модули (classifier и т.д.)
)
model = get_peft_model(model, lora_config)
# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('./fine-tuned-bert-lora_new')
# Функция для классификации текста с LoRA
def classify_text_with_lora(text):
# Токенизация текста
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
# Получение предсказаний
with torch.no_grad():
outputs = model(**inputs)
# Извлечение логитов для безопасности и типа атаки
logits_safety = outputs['logits_safety']
logits_attack = outputs['logits_attack']
# Применение softmax для получения вероятностей
safety_probs = torch.softmax(logits_safety, dim=1)
attack_probs = torch.softmax(logits_attack, dim=1)
# Получение предсказанных меток для безопасности
pred_safety = torch.argmax(safety_probs, dim=1).item()
# Перевод числовых меток в текстовые метки
safety_labels = ['safe', 'unsafe']
attack_labels = ['jailbreak', 'evasion', 'generic attack', 'injection']
# Возвращаем только метку безопасности, и тип атаки только для unsafe
if pred_safety == 1: # unsafe
pred_attack = torch.argmax(attack_probs, dim=1).item()
return safety_labels[pred_safety], attack_labels[pred_attack]
else: # safe
return safety_labels[pred_safety], None # Тип атаки не выводится для safe
# Пример классификации текста
text = "привет как дела"
safety_pred, attack_pred = classify_text_with_lora(text)
print(f"Predicted safety: {safety_pred}")
if attack_pred:
print(f"Predicted attack type: {attack_pred}")