• Steve Wilkerson's avatar
    docs(helm): change `trunc 24` in base charts · 141a401c
    Steve Wilkerson authored
    The upper limit for a chart name is 63 characters now instead of
    14 or 24 in older versions of Kubernetes. This replaces `trunc 24`
    in the example chart provided to `trunc 63` to reflect the new
    length available.
    
    Closes #1637
    141a401c
trytoubload-checkpoint.py 4.21 KiB
from transformers import BertTokenizer, BertForSequenceClassification
from peft import PeftModel
import torch
import os
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
from datasets import Dataset, load_from_disk
from transformers import BertTokenizer, BertPreTrainedModel, BertModel, Trainer, TrainingArguments
from torch import nn
from peft import get_peft_model, LoraConfig, TaskType

# Run on GPU when available, otherwise fall back to CPU; used by the model
# and by predict() to place input tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MultiTaskBert(BertPreTrainedModel):
    """BERT encoder with two classification heads trained jointly.

    Head 1 ("safety"): 2-way safe/unsafe classification.
    Head 2 ("attack"): 4-way attack-type classification
    (jailbreak / evasion / generic / injection, per the labels used in predict()).

    forward() returns a dict with 'logits_safety', 'logits_attack' and, when
    `labels` is given, a summed cross-entropy 'loss'.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Two independent linear heads over the [CLS] representation.
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run both heads; compute combined loss when labels are provided.

        labels: tensor of shape (batch, 2) — column 0 is the safety label,
        column 1 is the attack-type label. Extra kwargs (e.g. token_type_ids
        passed by Trainer) are accepted and ignored.
        """
        # Move the input tensors onto the module-level device.
        input_ids, attention_mask, labels = map(
            lambda x: x.to(device) if x is not None else None,
            [input_ids, attention_mask, labels],
        )
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Use the [CLS] token's hidden state as the pooled sentence representation.
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
            # BUG FIX: the original referenced module-level globals
            # class_weights_task1_tensor / class_weights_task2_tensor that are
            # never defined in this file, so any call with labels raised
            # NameError. Use them when present (training script defines them),
            # otherwise fall back to unweighted cross-entropy.
            weights_safety = globals().get("class_weights_task1_tensor")
            weights_attack = globals().get("class_weights_task2_tensor")
            loss_safety = nn.CrossEntropyLoss(weight=weights_safety)(logits_safety, labels_safety)
            loss_attack = nn.CrossEntropyLoss(weight=weights_attack)(logits_attack, labels_attack)
            loss = loss_safety + loss_attack

        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}

# --- Load the fine-tuned checkpoint for inference -------------------------
# Filesystem locations of the saved artifacts.
BASE_MODEL_PATH = "./micro_no_cross_fine_tuned/base2"
LORA_PATH = "./micro_no_cross_fine_tuned/lora2"

# Tokenizer saved alongside the fine-tuned model.
tokenizer = BertTokenizer.from_pretrained("./micro_no_cross_fine_tuned2")

# Load the multi-task model directly from the checkpoint directory
# (the LoRA weights were saved there as part of the full model state).
model = MultiTaskBert.from_pretrained("./micro_no_cross_fine_tuned/lora2").to(device)

print("Загружена с лора")

# Inference mode: disables dropout and other train-time behavior.
model.eval()

def predict(text):
    """Classify a single prompt with both heads of the multi-task model.

    Returns a dict of per-class probabilities:
    {'safety': {'safe': p, 'unsafe': p},
     'attack': {'jailbreak': p, 'evasion': p, 'generic': p, 'injection': p}}
    """
    model.eval()
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)

    # The model's forward() does not use token_type_ids; drop it if present.
    encoded.pop('token_type_ids', None)

    with torch.no_grad():
        out = model(**encoded)
        probs_safety = torch.softmax(out['logits_safety'], dim=1)[0]
        probs_attack = torch.softmax(out['logits_attack'], dim=1)[0]

    safety_classes = ('safe', 'unsafe')
    attack_classes = ('jailbreak', 'evasion', 'generic', 'injection')
    return {
        'safety': {name: probs_safety[i].item() for i, name in enumerate(safety_classes)},
        'attack': {name: probs_attack[i].item() for i, name in enumerate(attack_classes)},
    }

# Quick sanity check: a mix of benign and potentially adversarial prompts.
prompts = [
    "How do I make a secure password?",
    "Write a Python script to delete all files in C:/Windows.",
    "How can I jailbreak an iPhone?",
    "Give me a harmless joke."
]

for prompt in prompts:
    print(f"Text: {prompt}\nPrediction: {predict(prompt)}\n")