From 3575d1d6dee52184c04c4447b8603b7c31a5d258 Mon Sep 17 00:00:00 2001 From: Ivan Chekanov <78589729+ichekanov@users.noreply.github.com> Date: Wed, 23 Oct 2024 00:06:18 +0300 Subject: [PATCH] Update descriptions --- lab-2/result.ipynb | 176 +++++++++++++++++++++++++++++++++------------ 1 file changed, 131 insertions(+), 45 deletions(-) diff --git a/lab-2/result.ipynb b/lab-2/result.ipynb index 75f32a9..fc2eeda 100644 --- a/lab-2/result.ipynb +++ b/lab-2/result.ipynb @@ -55,6 +55,7 @@ "outputs": [], "source": [ "import os\n", + "\n", "HOME = os.getcwd()\n", "TRAIN_DIR = f\"{HOME}/dataset/train\"\n", "VAL_DIR = f\"{HOME}/dataset/val\"" @@ -101,7 +102,9 @@ } ], "source": [ - "predictor = SAM2ImagePredictor.from_pretrained(\"facebook/sam2-hiera-base-plus\", device=\"cpu\")\n", + "predictor = SAM2ImagePredictor.from_pretrained(\n", + " \"facebook/sam2-hiera-base-plus\", device=\"cpu\"\n", + ")\n", "\n", "image = Image.open(os.path.join(TRAIN_DIR, \"0001.jpg\"))\n", "image" @@ -126,7 +129,7 @@ "source": [ "with torch.inference_mode(), torch.autocast(\"cuda\", dtype=torch.bfloat16):\n", " predictor.set_image(image)\n", - " masks, _, _ = predictor.predict([(300,320)], [1])" + " masks, _, _ = predictor.predict([(300, 320)], [1])" ] }, { @@ -160,7 +163,7 @@ " colored_mask = (colored_mask * 255).astype(np.uint8)\n", " image_np = np.where(mask[..., np.newaxis], colored_mask, image_np)\n", "plt.imshow(image_np)\n", - "plt.axis('off')\n", + "plt.axis(\"off\")\n", "plt.show()" ] }, @@ -399,6 +402,13 @@ "sv.plot_image(annotated_frame, (8, 8))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Обработаем весь датасет выбранной моделью:" + ] + }, { "cell_type": "code", "execution_count": 14, @@ -421,13 +431,18 @@ "import os\n", "from IPython.display import ProgressBar\n", "\n", - "def convert_dataset(directory, store_only_best=False):\n", - " images = sorted(m for m in os.listdir(directory) if m.lower().endswith(('.png', '.jpg', '.jpeg')))\n", + "\n", + "def prepare_dataset(directory, store_only_best=False):\n", + " images = sorted(\n", + " m\n", + " for m in os.listdir(directory)\n", + " if m.lower().endswith((\".png\", \".jpg\", \".jpeg\"))\n", + " )\n", " pbar = ProgressBar(total=len(images))\n", " pbar.display()\n", "\n", " for image_name in images:\n", - " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n", + " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n", " if os.path.exists(txt_path):\n", " pbar.progress += 1\n", " continue\n", @@ -439,7 +454,7 @@ " caption=TEXT_PROMPT,\n", " box_threshold=BOX_TRESHOLD,\n", " text_threshold=TEXT_TRESHOLD,\n", - " device=device\n", + " device=device,\n", " )\n", " yolo_annotations = []\n", " if store_only_best:\n", @@ -447,18 +462,26 @@ " boxes = [boxes[best_logit_index]]\n", " phrases = [phrases[best_logit_index]]\n", " for box, phrase in zip(boxes, phrases):\n", - " class_id = 0 if 'kettle' in phrase.lower() else 1\n", + " class_id = 0 if \"kettle\" in phrase.lower() else 1\n", " yolo_box = box.tolist()\n", " yolo_annotations.append(f\"{class_id} {' '.join(map(str, yolo_box))}\")\n", - " \n", - " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n", - " with open(txt_path, 'w') as f:\n", - " f.write('\\n'.join(yolo_annotations))\n", - " f.write('\\n')\n", + "\n", + " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n", + " with open(txt_path, \"w\") as f:\n", + " f.write(\"\\n\".join(yolo_annotations))\n", + " f.write(\"\\n\")\n", "\n", " pbar.progress += 1\n", "\n", - "convert_dataset(VAL_DIR, store_only_best=True)\n" + "\n", + "prepare_dataset(VAL_DIR, store_only_best=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Посмотрим на результаты обработки: " ] }, { @@ -482,56 +505,106 @@ "import matplotlib.patches as patches\n", "from PIL import Image\n", "\n", + "\n", "def display_image_with_annotations(directory, image_name):\n", - " # Load the image\n", " image_path = os.path.join(directory, image_name)\n", " img = Image.open(image_path)\n", - " \n", - " # Create figure and axes\n", " fig, ax = plt.subplots(1)\n", - " \n", - " # Display the image\n", " ax.imshow(img)\n", - " \n", - " # Load annotations\n", - " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + '.txt')\n", - " with open(txt_path, 'r') as f:\n", - " annotations = f.read().strip().split('\\n')\n", - " \n", - " # Draw bounding boxes\n", + " txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + \".txt\")\n", + " with open(txt_path, \"r\") as f:\n", + " annotations = f.read().strip().split(\"\\n\")\n", " for annotation in annotations:\n", " class_id, x_center, y_center, width, height = map(float, annotation.split())\n", - "\n", - " # Convert YOLO format to pixel coordinates\n", " img_width, img_height = img.size[0], img.size[1]\n", - " x = (x_center - width/2) * img_width\n", - " y = (y_center - height/2) * img_height\n", + " x = (x_center - width / 2) * img_width\n", + " y = (y_center - height / 2) * img_height\n", " box_width = width * img_width\n", " box_height = height * img_height\n", - " \n", - " # Create a Rectangle patch\n", - " rect = patches.Rectangle((x, y), box_width, box_height, linewidth=2, edgecolor='r', facecolor='none')\n", - " \n", - " # Add the patch to the Axes\n", + " rect = patches.Rectangle(\n", + " (x, y), box_width, box_height, linewidth=2, edgecolor=\"r\", facecolor=\"none\"\n", + " )\n", " ax.add_patch(rect)\n", - " \n", - " # # Add label\n", " label = \"Kettle\" if class_id == 0 else \"Teapot\"\n", - " plt.text(x, y-10, label, color='r', fontweight='bold')\n", - " \n", + " plt.text(x, y - 10, label, color=\"r\", fontweight=\"bold\")\n", " plt.title(f\"Image: {directory}/{image_name}\")\n", - " plt.axis('off')\n", + " plt.axis(\"off\")\n", " plt.show()\n", "\n", - "display_image_with_annotations(TRAIN_DIR, \"0002.jpg\")\n" + "\n", + "display_image_with_annotations(TRAIN_DIR, \"0002.jpg\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Тренировка YOLO" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: ultralytics in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (8.3.19)\n", + "Requirement already satisfied: numpy>=1.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.1.1)\n", + "Requirement already satisfied: matplotlib>=3.3.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (3.9.2)\n", + "Requirement already satisfied: opencv-python>=4.6.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.10.0.84)\n", + "Requirement already satisfied: pillow>=7.1.2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (10.4.0)\n", + "Requirement already satisfied: pyyaml>=5.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.2)\n", + "Requirement already satisfied: requests>=2.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.32.3)\n", + "Requirement already satisfied: scipy>=1.4.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (1.14.1)\n", + "Requirement already satisfied: torch>=1.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.5.0)\n", + "Requirement already satisfied: torchvision>=0.9.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.20.0)\n", + "Requirement already satisfied: tqdm>=4.64.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.66.5)\n", + "Requirement already satisfied: psutil in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.0)\n", + "Requirement already satisfied: py-cpuinfo in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (9.0.0)\n", + "Requirement already satisfied: pandas>=1.1.4 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.2.3)\n", + "Requirement already satisfied: seaborn>=0.11.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.13.2)\n", + "Requirement already satisfied: ultralytics-thop>=2.0.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.0.9)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.3.0)\n", + "Requirement already satisfied: cycler>=0.10 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (4.54.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.4.7)\n", + "Requirement already satisfied: packaging>=20.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (24.1)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (3.1.4)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2024.8.30)\n", + "Requirement already satisfied: filelock in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (4.12.2)\n", + "Requirement already satisfied: networkx in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.1.4)\n", + "Requirement already satisfied: fsspec in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (2024.10.0)\n", + "Requirement already satisfied: setuptools in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (75.1.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from sympy==1.13.1->torch>=1.8.0->ultralytics) (1.3.0)\n", + "Requirement already satisfied: six>=1.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.16.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from jinja2->torch>=1.8.0->ultralytics) (2.1.5)\n" + ] + } + ], "source": [ + "# Установим Ultralytics\n", "!python -m pip install ultralytics" ] }, @@ -541,7 +614,9 @@ "metadata": {}, "outputs": [], "source": [ + "# Загрузим модель\n", "from ultralytics import YOLO\n", + "\n", "model = YOLO(\"yolov8s.pt\")" ] }, @@ -551,7 +626,9 @@ "metadata": {}, "outputs": [], "source": [ - "model.train(data=f\"{HOME}/dataset/data.yaml\", device=\"cuda\", batch=16, epochs=50)" + "# Тренируем модель\n", + "model.train(data=f\"{HOME}/dataset/data.yaml\", device=\"cuda\", batch=16, epochs=50)\n", + "# Для тренировки используется режим `cuda`, так как тренировку делал на другой машине." ] }, { @@ -590,6 +667,7 @@ "from ultralytics import YOLO\n", "import matplotlib.pyplot as plt\n", "import cv2\n", + "\n", "model_path = f\"{HOME}/runs/detect/train/weights/best.pt\"\n", "model = YOLO(model_path)\n", "results = model.predict(os.path.join(TRAIN_DIR, \"0002.jpg\"))\n", @@ -597,9 +675,17 @@ "res_plotted = result.plot()\n", "fig, ax = plt.subplots(1)\n", "ax.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))\n", - "plt.axis('off')\n", + "plt.axis(\"off\")\n", "plt.show()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Результаты\n", + "Сделан файн-тюн YOLO, который умеет распознавать чайники: газовые, электрические и заварочные. Между собой классы отличаются не всегда (например, на картинке выше должен быть `kettle`, а распознан `teapot`), но это скорее проблема качества датасета (я и сам не всегда могу чётко сказать разницу). Результаты тренировки можно найти в [папке runs](lab-2/runs/detect/train/). Лучшая метрика `mAP50` составила `0.85204`: около 0.7 для kettle и 0.9 для teapot — есть смысл объединить эти классы, чтобы не \"путать\" модель и повысить точность." + ] } ], "metadata": { -- GitLab