proverkabert.py 3.90 KiB
from transformers import BertTokenizer
import torch
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel

# Убедитесь, что класс MultiTaskBert определён, как в вашем первоначальном коде
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Переводим тензоры на устройство
        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
            loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
            loss_attack = nn.CrossEntropyLoss()(logits_attack, labels_attack)
            loss = loss_safety + loss_attack

        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}

# Загрузка модели с LoRA адаптерами
model = MultiTaskBert.from_pretrained('./fine-tuned-bert-lora_new').to(device)

# Восстановление модели с LoRA
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value"],
    modules_to_save=[]  # Не сохраняем дополнительные модули (classifier и т.д.)
)

model = get_peft_model(model, lora_config)

# Загрузка токенизатора
tokenizer = BertTokenizer.from_pretrained('./fine-tuned-bert-lora_new')

# Функция для классификации текста с LoRA
def classify_text_with_lora(text):
    # Токенизация текста
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)

    # Получение предсказаний
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Извлечение логитов для безопасности и типа атаки
    logits_safety = outputs['logits_safety']
    logits_attack = outputs['logits_attack']

    # Применение softmax для получения вероятностей
    safety_probs = torch.softmax(logits_safety, dim=1)
    attack_probs = torch.softmax(logits_attack, dim=1)

    # Получение предсказанных меток для безопасности
    pred_safety = torch.argmax(safety_probs, dim=1).item()

    # Перевод числовых меток в текстовые метки
    safety_labels = ['safe', 'unsafe']
    attack_labels = ['jailbreak', 'evasion', 'generic attack', 'injection']

    # Возвращаем только метку безопасности, и тип атаки только для unsafe
    if pred_safety == 1:  # unsafe
        pred_attack = torch.argmax(attack_probs, dim=1).item()
        return safety_labels[pred_safety], attack_labels[pred_attack]
    else:  # safe
        return safety_labels[pred_safety], None  # Тип атаки не выводится для safe

# Пример классификации текста
text = "привет как дела"
safety_pred, attack_pred = classify_text_with_lora(text)

print(f"Predicted safety: {safety_pred}")
if attack_pred:
    print(f"Predicted attack type: {attack_pred}")