From 8e3f3f8eb972b9cae44f8b8f69375ab5d6867c86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?= =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?= =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru> Date: Sun, 23 Mar 2025 22:32:35 +0300 Subject: [PATCH] zero-shot --- .ipynb_checkpoints/goll-checkpoint.py | 33 ++++++++++++++++++++++++++- goll.py | 33 ++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/.ipynb_checkpoints/goll-checkpoint.py b/.ipynb_checkpoints/goll-checkpoint.py index ce47b77..f21f2e2 100644 --- a/.ipynb_checkpoints/goll-checkpoint.py +++ b/.ipynb_checkpoints/goll-checkpoint.py @@ -1 +1,32 @@ -print("hello") \ No newline at end of file +from transformers import pipeline +import pandas as pd +from sklearn.metrics import classification_report + +# Загрузка данных +data = pd.read_csv('all_dataset.csv') + +# Загрузка модели Qwen-0.5B для zero-shot классификации +classifier = pipeline("zero-shot-classification", model="Qwen/Qwen2.5-0.5B") + +# Категории для классификации +candidate_labels = ["safe", "unsafe"] + +# Функция для zero-shot классификации всего датасета +def zero_shot_classify_dataset(dataset, classifier, candidate_labels): + predictions = [] + for text in dataset['prompt']: + result = classifier(text, candidate_labels) + predicted_label = result['labels'][0] # Выбираем наиболее вероятную категорию + predictions.append(predicted_label) + return predictions + +# Применение zero-shot классификации к тестовому набору +test_data = data.sample(frac=0.2, random_state=42) # Примерно 20% данных для теста +test_predictions = zero_shot_classify_dataset(test_data, classifier, candidate_labels) + +# Добавление предсказаний в датасет +test_data['zero_shot_prediction'] = test_predictions + +# Оценка метрик +print("Zero-shot Classification Report:") +print(classification_report(test_data['safety'], test_data['zero_shot_prediction'], target_names=candidate_labels)) \ No newline at end of file diff --git a/goll.py b/goll.py index ce47b77..f21f2e2 100644 --- a/goll.py +++ b/goll.py @@ -1 +1,32 @@ -print("hello") \ No newline at end of file +from transformers import pipeline +import pandas as pd +from sklearn.metrics import classification_report + +# Загрузка данных +data = pd.read_csv('all_dataset.csv') + +# Загрузка модели Qwen-0.5B для zero-shot классификации +classifier = pipeline("zero-shot-classification", model="Qwen/Qwen2.5-0.5B") + +# Категории для классификации +candidate_labels = ["safe", "unsafe"] + +# Функция для zero-shot классификации всего датасета +def zero_shot_classify_dataset(dataset, classifier, candidate_labels): + predictions = [] + for text in dataset['prompt']: + result = classifier(text, candidate_labels) + predicted_label = result['labels'][0] # Выбираем наиболее вероятную категорию + predictions.append(predicted_label) + return predictions + +# Применение zero-shot классификации к тестовому набору +test_data = data.sample(frac=0.2, random_state=42) # Примерно 20% данных для теста +test_predictions = zero_shot_classify_dataset(test_data, classifier, candidate_labels) + +# Добавление предсказаний в датасет +test_data['zero_shot_prediction'] = test_predictions + +# Оценка метрик +print("Zero-shot Classification Report:") +print(classification_report(test_data['safety'], test_data['zero_shot_prediction'], target_names=candidate_labels)) \ No newline at end of file -- GitLab