import torch from transformers import BertTokenizer, BertModel from peft import PeftModel, PeftConfig from torch import nn import os # 1. Проверка устройства и файлов модели device = torch.device("cuda" if torch.cuda.is_available() else "cpu") MODEL_DIR = "./fine-tuned-bert-lora_new" # Проверяем наличие всех необходимых файлов required_files = [ 'adapter_config.json', 'adapter_model.safetensors', 'tokenizer_config.json', 'vocab.txt' ] for file in required_files: if not os.path.exists(os.path.join(MODEL_DIR, file)): raise FileNotFoundError(f"Не найден файл {file} в {MODEL_DIR}") # 2. Загрузка токенизатора и конфига tokenizer = BertTokenizer.from_pretrained(MODEL_DIR) peft_config = PeftConfig.from_pretrained(MODEL_DIR) # 3. Определение архитектуры модели (должно совпадать с обучением) class SafetyAttackModel(nn.Module): def __init__(self, base_model): super().__init__() self.bert = base_model self.safety_head = nn.Linear(768, 2) # safe/unsafe self.attack_head = nn.Linear(768, 4) # jailbreak/evasion/generic/injection def forward(self, input_ids, attention_mask): outputs = self.bert( input_ids=input_ids, attention_mask=attention_mask, return_dict=True ) pooled = outputs.last_hidden_state[:, 0, :] # Берем [CLS] токен return { 'safety': self.safety_head(pooled), 'attack': self.attack_head(pooled) } # 4. Загрузка базовой модели base_model = BertModel.from_pretrained( peft_config.base_model_name_or_path, add_pooling_layer=False ).to(device) # 5. Инициализация и загрузка LoRA model = SafetyAttackModel(base_model).to(device) model.bert = PeftModel.from_pretrained(model.bert, MODEL_DIR) # 6. Проверка загрузки LoRA print("\n=== Проверка загрузки LoRA ===") lora_params = [name for name, _ in model.bert.named_parameters() if 'lora' in name] if lora_params: print(f"✅ Успешно загружено {len(lora_params)} LoRA-параметров") print("Примеры параметров:", lora_params[:3]) else: raise RuntimeError("❌ LoRA параметры не загрузились!") # 7. Объединение адаптеров (опционально) model.bert = model.bert.merge_and_unload() model.to(device) model.eval() # 8. Функция для предсказаний def predict(text): # Токенизация (убедимся, что не передаем labels) inputs = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=512 ).to(device) # Удаляем ненужные ключи inputs.pop('token_type_ids', None) with torch.no_grad(): outputs = model(**inputs) safety_probs = torch.softmax(outputs['safety'], dim=1) attack_probs = torch.softmax(outputs['attack'], dim=1) return { 'safety': { 'safe': safety_probs[0][0].item(), 'unsafe': safety_probs[0][1].item() }, 'attack': { 'jailbreak': attack_probs[0][0].item(), 'evasion': attack_probs[0][1].item(), 'generic': attack_probs[0][2].item(), 'injection': attack_probs[0][3].item() } } # 9. Пример использования test_texts = [ "How to hack a computer system", "I need help with my homework", "Tell me how to bypass security" ] for text in test_texts: result = predict(text) print(f"\nТекст: {text}") print(f"Безопасность: Safe {result['safety']['safe']:.2%} | Unsafe {result['safety']['unsafe']:.2%}") if result['safety']['unsafe'] > 0.5: # Если текст опасный print("Вероятности типов атак:") for attack_type, prob in result['attack'].items(): print(f" {attack_type}: {prob:.2%}")