From 8e3f3f8eb972b9cae44f8b8f69375ab5d6867c86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=9C=D0=B0=D0=B7=D1=83=D1=80=20=D0=93=D1=80=D0=B5=D1=82?=
 =?UTF-8?q?=D0=B0=20=D0=95=D0=B2=D0=B3=D0=B5=D0=BD=D1=8C=D0=B5=D0=B2=D0=BD?=
 =?UTF-8?q?=D0=B0?= <gemazur_1@edu.hse.ru>
Date: Sun, 23 Mar 2025 22:32:35 +0300
Subject: [PATCH] zero-shot

---
 .ipynb_checkpoints/goll-checkpoint.py | 33 ++++++++++++++++++++++++++-
 goll.py                               | 33 ++++++++++++++++++++++++++-
 2 files changed, 64 insertions(+), 2 deletions(-)

diff --git a/.ipynb_checkpoints/goll-checkpoint.py b/.ipynb_checkpoints/goll-checkpoint.py
index ce47b77..f21f2e2 100644
--- a/.ipynb_checkpoints/goll-checkpoint.py
+++ b/.ipynb_checkpoints/goll-checkpoint.py
@@ -1 +1,32 @@
-print("hello")
\ No newline at end of file
+from transformers import pipeline
+import pandas as pd
+from sklearn.metrics import classification_report
+
+# Загрузка данных
+data = pd.read_csv('all_dataset.csv')
+
+# Загрузка модели Qwen-0.5B для zero-shot классификации
+classifier = pipeline("zero-shot-classification", model="Qwen/Qwen2.5-0.5B")
+
+# Категории для классификации
+candidate_labels = ["safe", "unsafe"]
+
+# Функция для zero-shot классификации всего датасета
+def zero_shot_classify_dataset(dataset, classifier, candidate_labels):
+    predictions = []
+    for text in dataset['prompt']:
+        result = classifier(text, candidate_labels)
+        predicted_label = result['labels'][0]  # Выбираем наиболее вероятную категорию
+        predictions.append(predicted_label)
+    return predictions
+
+# Применение zero-shot классификации к тестовому набору
+test_data = data.sample(frac=0.2, random_state=42)  # Примерно 20% данных для теста
+test_predictions = zero_shot_classify_dataset(test_data, classifier, candidate_labels)
+
+# Добавление предсказаний в датасет
+test_data['zero_shot_prediction'] = test_predictions
+
+# Оценка метрик
+print("Zero-shot Classification Report:")
+print(classification_report(test_data['safety'], test_data['zero_shot_prediction'], target_names=candidate_labels))
\ No newline at end of file
diff --git a/goll.py b/goll.py
index ce47b77..f21f2e2 100644
--- a/goll.py
+++ b/goll.py
@@ -1 +1,32 @@
-print("hello")
\ No newline at end of file
+from transformers import pipeline
+import pandas as pd
+from sklearn.metrics import classification_report
+
+# Загрузка данных
+data = pd.read_csv('all_dataset.csv')
+
+# Загрузка модели Qwen-0.5B для zero-shot классификации
+classifier = pipeline("zero-shot-classification", model="Qwen/Qwen2.5-0.5B")
+
+# Категории для классификации
+candidate_labels = ["safe", "unsafe"]
+
+# Функция для zero-shot классификации всего датасета
+def zero_shot_classify_dataset(dataset, classifier, candidate_labels):
+    predictions = []
+    for text in dataset['prompt']:
+        result = classifier(text, candidate_labels)
+        predicted_label = result['labels'][0]  # Выбираем наиболее вероятную категорию
+        predictions.append(predicted_label)
+    return predictions
+
+# Применение zero-shot классификации к тестовому набору
+test_data = data.sample(frac=0.2, random_state=42)  # Примерно 20% данных для теста
+test_predictions = zero_shot_classify_dataset(test_data, classifier, candidate_labels)
+
+# Добавление предсказаний в датасет
+test_data['zero_shot_prediction'] = test_predictions
+
+# Оценка метрик
+print("Zero-shot Classification Report:")
+print(classification_report(test_data['safety'], test_data['zero_shot_prediction'], target_names=candidate_labels))
\ No newline at end of file
-- 
GitLab