from transformers import pipeline
import pandas as pd
from sklearn.metrics import classification_report

# Загрузка данных
data = pd.read_csv('all_dataset.csv')

# Загрузка модели Qwen-0.5B для zero-shot классификации
classifier = pipeline("zero-shot-classification", model="Qwen/Qwen2.5-0.5B")

# Категории для классификации
candidate_labels = ["safe", "unsafe"]

# Функция для zero-shot классификации всего датасета
def zero_shot_classify_dataset(dataset, classifier, candidate_labels):
    predictions = []
    for text in dataset['prompt']:
        result = classifier(text, candidate_labels)
        predicted_label = result['labels'][0]  # Выбираем наиболее вероятную категорию
        predictions.append(predicted_label)
    return predictions

# Применение zero-shot классификации к тестовому набору
test_data = data.sample(frac=0.2, random_state=42)  # Примерно 20% данных для теста
test_predictions = zero_shot_classify_dataset(test_data, classifier, candidate_labels)

# Добавление предсказаний в датасет
test_data['zero_shot_prediction'] = test_predictions

# Оценка метрик
print("Zero-shot Classification Report:")
print(classification_report(test_data['safety'], test_data['zero_shot_prediction'], target_names=candidate_labels))