From 3575d1d6dee52184c04c4447b8603b7c31a5d258 Mon Sep 17 00:00:00 2001
From: Ivan Chekanov <78589729+ichekanov@users.noreply.github.com>
Date: Wed, 23 Oct 2024 00:06:18 +0300
Subject: [PATCH] Update descriptions

---
 lab-2/result.ipynb | 176 +++++++++++++++++++++++++++++++++------------
 1 file changed, 131 insertions(+), 45 deletions(-)

diff --git a/lab-2/result.ipynb b/lab-2/result.ipynb
index 75f32a9..fc2eeda 100644
--- a/lab-2/result.ipynb
+++ b/lab-2/result.ipynb
@@ -55,6 +55,7 @@
       "outputs": [],
       "source": [
         "import os\n",
+        "\n",
         "HOME = os.getcwd()\n",
         "TRAIN_DIR = f\"{HOME}/dataset/train\"\n",
         "VAL_DIR = f\"{HOME}/dataset/val\""
@@ -101,7 +102,9 @@
         }
       ],
       "source": [
-        "predictor = SAM2ImagePredictor.from_pretrained(\"facebook/sam2-hiera-base-plus\", device=\"cpu\")\n",
+        "predictor = SAM2ImagePredictor.from_pretrained(\n",
+        "    \"facebook/sam2-hiera-base-plus\", device=\"cpu\"\n",
+        ")\n",
         "\n",
         "image = Image.open(os.path.join(TRAIN_DIR, \"0001.jpg\"))\n",
         "image"
@@ -126,7 +129,7 @@
       "source": [
         "with torch.inference_mode(), torch.autocast(\"cuda\", dtype=torch.bfloat16):\n",
         "    predictor.set_image(image)\n",
-        "    masks, _, _ = predictor.predict([(300,320)], [1])"
+        "    masks, _, _ = predictor.predict([(300, 320)], [1])"
       ]
     },
     {
@@ -160,7 +163,7 @@
         "    colored_mask = (colored_mask * 255).astype(np.uint8)\n",
         "    image_np = np.where(mask[..., np.newaxis], colored_mask, image_np)\n",
         "plt.imshow(image_np)\n",
-        "plt.axis('off')\n",
+        "plt.axis(\"off\")\n",
         "plt.show()"
       ]
     },
@@ -399,6 +402,13 @@
         "sv.plot_image(annotated_frame, (8, 8))"
       ]
     },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "Обработаем весь датасет выбранной моделью:"
+      ]
+    },
     {
       "cell_type": "code",
       "execution_count": 14,
@@ -421,13 +431,18 @@
         "import os\n",
         "from IPython.display import ProgressBar\n",
         "\n",
-        "def convert_dataset(directory, store_only_best=False):\n",
-        "    images = sorted(m for m in os.listdir(directory) if m.lower().endswith(('.png', '.jpg', '.jpeg')))\n",
+        "\n",
+        "def prepare_dataset(directory, store_only_best=False):\n",
+        "    images = sorted(\n",
+        "        m\n",
+        "        for m in os.listdir(directory)\n",
+        "        if m.lower().endswith((\".png\", \".jpg\", \".jpeg\"))\n",
+        "    )\n",
         "    pbar = ProgressBar(total=len(images))\n",
         "    pbar.display()\n",
         "\n",
         "    for image_name in images:\n",
-        "        txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n",
+        "        txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n",
         "        if os.path.exists(txt_path):\n",
         "            pbar.progress += 1\n",
         "            continue\n",
@@ -439,7 +454,7 @@
         "            caption=TEXT_PROMPT,\n",
         "            box_threshold=BOX_TRESHOLD,\n",
         "            text_threshold=TEXT_TRESHOLD,\n",
-        "            device=device\n",
+        "            device=device,\n",
         "        )\n",
         "        yolo_annotations = []\n",
         "        if store_only_best:\n",
@@ -447,18 +462,26 @@
         "            boxes = [boxes[best_logit_index]]\n",
         "            phrases = [phrases[best_logit_index]]\n",
         "        for box, phrase in zip(boxes, phrases):\n",
-        "            class_id = 0 if 'kettle' in phrase.lower() else 1\n",
+        "            class_id = 0 if \"kettle\" in phrase.lower() else 1\n",
         "            yolo_box = box.tolist()\n",
         "            yolo_annotations.append(f\"{class_id} {' '.join(map(str, yolo_box))}\")\n",
-        "    \n",
-        "        txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n",
-        "        with open(txt_path, 'w') as f:\n",
-        "            f.write('\\n'.join(yolo_annotations))\n",
-        "            f.write('\\n')\n",
+        "\n",
+        "        txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n",
+        "        with open(txt_path, \"w\") as f:\n",
+        "            f.write(\"\\n\".join(yolo_annotations))\n",
+        "            f.write(\"\\n\")\n",
         "\n",
         "        pbar.progress += 1\n",
         "\n",
-        "convert_dataset(VAL_DIR, store_only_best=True)\n"
+        "\n",
+        "prepare_dataset(VAL_DIR, store_only_best=True)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "Посмотрим на результаты обработки:"
       ]
     },
     {
@@ -482,56 +505,106 @@
         "import matplotlib.patches as patches\n",
         "from PIL import Image\n",
         "\n",
+        "\n",
         "def display_image_with_annotations(directory, image_name):\n",
-        "    # Load the image\n",
         "    image_path = os.path.join(directory, image_name)\n",
         "    img = Image.open(image_path)\n",
-        "    \n",
-        "    # Create figure and axes\n",
         "    fig, ax = plt.subplots(1)\n",
-        "    \n",
-        "    # Display the image\n",
         "    ax.imshow(img)\n",
-        "    \n",
-        "    # Load annotations\n",
-        "    txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n",
-        "    with open(txt_path, 'r') as f:\n",
-        "        annotations = f.read().strip().split('\\n')\n",
-        "    \n",
-        "    # Draw bounding boxes\n",
+        "    txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n",
+        "    with open(txt_path, \"r\") as f:\n",
+        "        annotations = f.read().strip().split(\"\\n\")\n",
         "    for annotation in annotations:\n",
         "        class_id, x_center, y_center, width, height = map(float, annotation.split())\n",
-        "\n",
-        "        # Convert YOLO format to pixel coordinates\n",
         "        img_width, img_height = img.size[0], img.size[1]\n",
-        "        x = (x_center - width/2) * img_width\n",
-        "        y = (y_center - height/2) * img_height\n",
+        "        x = (x_center - width / 2) * img_width\n",
+        "        y = (y_center - height / 2) * img_height\n",
         "        box_width = width * img_width\n",
         "        box_height = height * img_height\n",
-        "        \n",
-        "        # Create a Rectangle patch\n",
-        "        rect = patches.Rectangle((x, y), box_width, box_height, linewidth=2, edgecolor='r', facecolor='none')\n",
-        "        \n",
-        "        # Add the patch to the Axes\n",
+        "        rect = patches.Rectangle(\n",
+        "            (x, y), box_width, box_height, linewidth=2, edgecolor=\"r\", facecolor=\"none\"\n",
+        "        )\n",
         "        ax.add_patch(rect)\n",
-        "        \n",
-        "        # # Add label\n",
         "        label = \"Kettle\" if class_id == 0 else \"Teapot\"\n",
-        "        plt.text(x, y-10, label, color='r', fontweight='bold')\n",
-        "    \n",
+        "        plt.text(x, y - 10, label, color=\"r\", fontweight=\"bold\")\n",
         "    plt.title(f\"Image: {directory}/{image_name}\")\n",
-        "    plt.axis('off')\n",
+        "    plt.axis(\"off\")\n",
         "    plt.show()\n",
         "\n",
-        "display_image_with_annotations(TRAIN_DIR, \"0002.jpg\")\n"
+        "\n",
+        "display_image_with_annotations(TRAIN_DIR, \"0002.jpg\")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "# Тренировка YOLO"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 22,
       "metadata": {},
-      "outputs": [],
+      "outputs": [
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+            "To disable this warning, you can either:\n",
+            "\t- Avoid using `tokenizers` before the fork if possible\n",
+            "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Requirement already satisfied: ultralytics in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (8.3.19)\n",
+            "Requirement already satisfied: numpy>=1.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.1.1)\n",
+            "Requirement already satisfied: matplotlib>=3.3.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (3.9.2)\n",
+            "Requirement already satisfied: opencv-python>=4.6.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.10.0.84)\n",
+            "Requirement already satisfied: pillow>=7.1.2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (10.4.0)\n",
+            "Requirement already satisfied: pyyaml>=5.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.2)\n",
+            "Requirement already satisfied: requests>=2.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.32.3)\n",
+            "Requirement already satisfied: scipy>=1.4.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (1.14.1)\n",
+            "Requirement already satisfied: torch>=1.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.5.0)\n",
+            "Requirement already satisfied: torchvision>=0.9.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.20.0)\n",
+            "Requirement already satisfied: tqdm>=4.64.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.66.5)\n",
+            "Requirement already satisfied: psutil in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.0)\n",
+            "Requirement already satisfied: py-cpuinfo in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (9.0.0)\n",
+            "Requirement already satisfied: pandas>=1.1.4 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.2.3)\n",
+            "Requirement already satisfied: seaborn>=0.11.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.13.2)\n",
+            "Requirement already satisfied: ultralytics-thop>=2.0.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.0.9)\n",
+            "Requirement already satisfied: contourpy>=1.0.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.3.0)\n",
+            "Requirement already satisfied: cycler>=0.10 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)\n",
+            "Requirement already satisfied: fonttools>=4.22.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (4.54.1)\n",
+            "Requirement already satisfied: kiwisolver>=1.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.4.7)\n",
+            "Requirement already satisfied: packaging>=20.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (24.1)\n",
+            "Requirement already satisfied: pyparsing>=2.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (3.1.4)\n",
+            "Requirement already satisfied: python-dateutil>=2.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)\n",
+            "Requirement already satisfied: pytz>=2020.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)\n",
+            "Requirement already satisfied: tzdata>=2022.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)\n",
+            "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.3.2)\n",
+            "Requirement already satisfied: idna<4,>=2.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.10)\n",
+            "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2.2.3)\n",
+            "Requirement already satisfied: certifi>=2017.4.17 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2024.8.30)\n",
+            "Requirement already satisfied: filelock in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.16.1)\n",
+            "Requirement already satisfied: typing-extensions>=4.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (4.12.2)\n",
+            "Requirement already satisfied: networkx in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.4.2)\n",
+            "Requirement already satisfied: jinja2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.1.4)\n",
+            "Requirement already satisfied: fsspec in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (2024.10.0)\n",
+            "Requirement already satisfied: setuptools in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (75.1.0)\n",
+            "Requirement already satisfied: sympy==1.13.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (1.13.1)\n",
+            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from sympy==1.13.1->torch>=1.8.0->ultralytics) (1.3.0)\n",
+            "Requirement already satisfied: six>=1.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.16.0)\n",
+            "Requirement already satisfied: MarkupSafe>=2.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from jinja2->torch>=1.8.0->ultralytics) (2.1.5)\n"
+          ]
+        }
+      ],
       "source": [
+        "# Установим Ultralytics\n",
         "!python -m pip install ultralytics"
       ]
     },
@@ -541,7 +614,9 @@
       "metadata": {},
       "outputs": [],
       "source": [
+        "# Загрузим модель\n",
         "from ultralytics import YOLO\n",
+        "\n",
         "model = YOLO(\"yolov8s.pt\")"
       ]
     },
@@ -551,7 +626,9 @@
       "metadata": {},
       "outputs": [],
       "source": [
-        "model.train(data=f\"{HOME}/dataset/data.yaml\", device=\"cuda\", batch=16, epochs=50)"
+        "# Тренируем модель\n",
+        "model.train(data=f\"{HOME}/dataset/data.yaml\", device=\"cuda\", batch=16, epochs=50)\n",
+        "# Для тренировки используется режим `cuda`, так как тренировку делал на другой машине."
       ]
     },
     {
@@ -590,6 +667,7 @@
         "from ultralytics import YOLO\n",
         "import matplotlib.pyplot as plt\n",
         "import cv2\n",
+        "\n",
         "model_path = f\"{HOME}/runs/detect/train/weights/best.pt\"\n",
         "model = YOLO(model_path)\n",
         "results = model.predict(os.path.join(TRAIN_DIR, \"0002.jpg\"))\n",
@@ -597,9 +675,17 @@
         "res_plotted = result.plot()\n",
         "fig, ax = plt.subplots(1)\n",
         "ax.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))\n",
-        "plt.axis('off')\n",
+        "plt.axis(\"off\")\n",
         "plt.show()"
       ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {},
+      "source": [
+        "# Результаты\n",
+        "Сделан файн-тюн YOLO, который умеет распознавать чайники: газовые, электрические и заварочные. Между собой классы отличаются не всегда (например, на картинке выше должен быть `kettle`, а распознан `teapot`), но это скорее проблема качества датасета (я и сам не всегда могу чётко сказать разницу). Результаты тренировки можно найти в [папке runs](runs/detect/train/). Лучшая метрика `mAP50` составила `0.85204`: около 0.7 для kettle и 0.9 для teapot — есть смысл объединить эти классы, чтобы не \"путать\" модель и повысить точность."
+      ]
     }
   ],
   "metadata": {
-- 
GitLab