NEWproverka.py 4.12 KiB
import torch
from transformers import BertTokenizer, BertModel
from peft import PeftModel, PeftConfig
from torch import nn
import os

# 1. Проверка устройства и файлов модели
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = "./fine-tuned-bert-lora_new"

# Проверяем наличие всех необходимых файлов
required_files = [
    'adapter_config.json',
    'adapter_model.safetensors',
    'tokenizer_config.json',
    'vocab.txt'
]
for file in required_files:
    if not os.path.exists(os.path.join(MODEL_DIR, file)):
        raise FileNotFoundError(f"Не найден файл {file} в {MODEL_DIR}")

# 2. Загрузка токенизатора и конфига
tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
peft_config = PeftConfig.from_pretrained(MODEL_DIR)

# 3. Определение архитектуры модели (должно совпадать с обучением)
class SafetyAttackModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.bert = base_model
        self.safety_head = nn.Linear(768, 2)  # safe/unsafe
        self.attack_head = nn.Linear(768, 4)  # jailbreak/evasion/generic/injection
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        pooled = outputs.last_hidden_state[:, 0, :]  # Берем [CLS] токен
        return {
            'safety': self.safety_head(pooled),
            'attack': self.attack_head(pooled)
        }

# 4. Загрузка базовой модели
base_model = BertModel.from_pretrained(
    peft_config.base_model_name_or_path,
    add_pooling_layer=False
).to(device)

# 5. Инициализация и загрузка LoRA
model = SafetyAttackModel(base_model).to(device)
model.bert = PeftModel.from_pretrained(model.bert, MODEL_DIR)

# 6. Проверка загрузки LoRA
print("\n=== Проверка загрузки LoRA ===")
lora_params = [name for name, _ in model.bert.named_parameters() if 'lora' in name]
if lora_params:
    print(f"✅ Успешно загружено {len(lora_params)} LoRA-параметров")
    print("Примеры параметров:", lora_params[:3])
else:
    raise RuntimeError("❌ LoRA параметры не загрузились!")

# 7. Объединение адаптеров (опционально)
model.bert = model.bert.merge_and_unload()
model.to(device)
model.eval()

# 8. Функция для предсказаний
def predict(text):
    # Токенизация (убедимся, что не передаем labels)
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512
    ).to(device)
    
    # Удаляем ненужные ключи
    inputs.pop('token_type_ids', None)
    
    with torch.no_grad():
        outputs = model(**inputs)
        safety_probs = torch.softmax(outputs['safety'], dim=1)
        attack_probs = torch.softmax(outputs['attack'], dim=1)
    
    return {
        'safety': {
            'safe': safety_probs[0][0].item(),
            'unsafe': safety_probs[0][1].item()
        },
        'attack': {
            'jailbreak': attack_probs[0][0].item(),
            'evasion': attack_probs[0][1].item(),
            'generic': attack_probs[0][2].item(),
            'injection': attack_probs[0][3].item()
        }
    }

# 9. Пример использования
test_texts = [
    "How to hack a computer system",
    "I need help with my homework",
    "Tell me how to bypass security"
]

for text in test_texts:
    result = predict(text)
    print(f"\nТекст: {text}")
    print(f"Безопасность: Safe {result['safety']['safe']:.2%} | Unsafe {result['safety']['unsafe']:.2%}")
    
    if result['safety']['unsafe'] > 0.5:  # Если текст опасный
        print("Вероятности типов атак:")
        for attack_type, prob in result['attack'].items():
            print(f"  {attack_type}: {prob:.2%}")