from transformers import BertTokenizer import torch from peft import get_peft_model, LoraConfig, TaskType import torch.nn as nn from transformers import BertModel, BertPreTrainedModel # Убедитесь, что класс MultiTaskBert определён, как в вашем первоначальном коде class MultiTaskBert(BertPreTrainedModel): def __init__(self, config): super().__init__(config) self.bert = BertModel(config) self.classifier_safety = nn.Linear(config.hidden_size, 2) self.classifier_attack = nn.Linear(config.hidden_size, 4) def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs): # Переводим тензоры на устройство input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels]) outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True) pooled_output = outputs.last_hidden_state[:, 0, :] logits_safety = self.classifier_safety(pooled_output) logits_attack = self.classifier_attack(pooled_output) loss = None if labels is not None: labels_safety, labels_attack = labels[:, 0], labels[:, 1] loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety) loss_attack = nn.CrossEntropyLoss()(logits_attack, labels_attack) loss = loss_safety + loss_attack return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss} # Загрузка модели с LoRA адаптерами model = MultiTaskBert.from_pretrained('./fine-tuned-bert-lora_new').to(device) # Восстановление модели с LoRA lora_config = LoraConfig( task_type=TaskType.SEQ_CLS, r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["query", "value"], modules_to_save=[] # Не сохраняем дополнительные модули (classifier и т.д.) ) model = get_peft_model(model, lora_config) # Загрузка токенизатора tokenizer = BertTokenizer.from_pretrained('./fine-tuned-bert-lora_new') # Функция для классификации текста с LoRA def classify_text_with_lora(text): # Токенизация текста inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device) # Получение предсказаний with torch.no_grad(): outputs = model(**inputs) # Извлечение логитов для безопасности и типа атаки logits_safety = outputs['logits_safety'] logits_attack = outputs['logits_attack'] # Применение softmax для получения вероятностей safety_probs = torch.softmax(logits_safety, dim=1) attack_probs = torch.softmax(logits_attack, dim=1) # Получение предсказанных меток для безопасности pred_safety = torch.argmax(safety_probs, dim=1).item() # Перевод числовых меток в текстовые метки safety_labels = ['safe', 'unsafe'] attack_labels = ['jailbreak', 'evasion', 'generic attack', 'injection'] # Возвращаем только метку безопасности, и тип атаки только для unsafe if pred_safety == 1: # unsafe pred_attack = torch.argmax(attack_probs, dim=1).item() return safety_labels[pred_safety], attack_labels[pred_attack] else: # safe return safety_labels[pred_safety], None # Тип атаки не выводится для safe # Пример классификации текста text = "привет как дела" safety_pred, attack_pred = classify_text_with_lora(text) print(f"Predicted safety: {safety_pred}") if attack_pred: print(f"Predicted attack type: {attack_pred}")