%% Cell type:code id: tags:
 
``` python
!python -m pip install opencv-python
!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
```
 
%% Cell type:markdown id: tags:
 
# Dataset collection and preprocessing
The dataset was assembled from Yandex Images, Pinterest, and Unsplash search results, plus several small datasets from Roboflow ([one](https://universe.roboflow.com/darinas-workspace/cup-kettle), [two](https://universe.roboflow.com/badredslim/electric_kettles_detection), [three](https://universe.roboflow.com/pruebas-de-200/teapot-f8mdf)). All images were screened for near-duplicates with the `geeqie` image viewer, which removed most of them. The images were then processed with ImageMagick to a uniform size of 640×640 pixels; a sketch of that step follows.
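
A minimal sketch of the resize step, assuming ImageMagick's `mogrify` and a center-crop-to-square geometry (the actual flags used are not shown in the diff):

%% Cell type:code id: tags:

``` python
# Assumed invocation: scale the short side to 640, then center-crop to 640x640.
!mogrify -resize "640x640^" -gravity center -extent 640x640 dataset/train/*.jpg
!mogrify -resize "640x640^" -gravity center -extent 640x640 dataset/val/*.jpg
```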
 
%% Cell type:markdown id: tags:
 
# Automatic dataset annotation
 
%% Cell type:markdown id: tags:
 
Let's try annotating the dataset with `SAM-2`. Unfortunately, this model turned out to be a poor fit for our task: as it turns out, it only delineates object boundaries without attaching any semantic meaning to them. Here is an example of detecting a kettle from a single manually chosen point that belongs to it.
 
%% Cell type:code id: tags:
 
``` python
!python -m pip install sam2
!python -m pip install huggingface_hub
```
 
%% Cell type:code id: tags:
 
``` python
import os
 
HOME = os.getcwd()
TRAIN_DIR = f"{HOME}/dataset/train"
VAL_DIR = f"{HOME}/dataset/val"
```
 
%% Cell type:code id: tags:
 
``` python
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
```
 
%% Cell type:code id: tags:
 
``` python
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-base-plus", device="cpu"
)
 
image = Image.open(os.path.join(TRAIN_DIR, "0001.jpg"))
image
```
 
%% Output
 
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>
 
%% Cell type:code id: tags:
 
``` python
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, _, _ = predictor.predict(point_coords=[(300, 320)], point_labels=[1])
```
 
%% Output
 
/Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
warnings.warn(
 
%% Cell type:code id: tags:
 
``` python
image_np = np.array(image)
color_map = plt.get_cmap("jet")
for mask in masks:
    colored_mask = color_map(mask.astype(np.float32))[:, :, :3]
    colored_mask = (colored_mask * 255).astype(np.uint8)
    image_np = np.where(mask[..., np.newaxis], colored_mask, image_np)
plt.imshow(image_np)
plt.axis("off")
plt.show()
```
 
%% Output
 
 
%% Cell type:markdown id: tags:
 
After studying `CLIP`, I realized it doesn't fit either, since it does the inverse: identifying which objects appear in an image without saying where they are.
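
For illustration, a minimal sketch of what CLIP does give you (the model name and prompts here are assumptions, and this was not run for the lab):

%% Cell type:code id: tags:

``` python
# Illustrative only: CLIP scores image-text similarity for the whole image
# and returns no bounding boxes.
import os

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

img = Image.open(os.path.join(TRAIN_DIR, "0001.jpg"))
inputs = clip_processor(
    text=["a photo of a kettle", "a photo of a teapot"],
    images=img,
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    probs = clip_model(**inputs).logits_per_image.softmax(dim=1)
print(probs)  # image-level class probabilities, with no localization
```

%% Cell type:markdown id: tags: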
 
Finally, I settled on `GroundingDINO`, which fits our task: you feed it a list of text prompts and, once the image is processed, get back a list of boxes around objects of those types. Let's put it to use!
 
%% Cell type:code id: tags:
 
``` python
# Install GroundingDINO
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
%pip install -q -r requirements.txt
%pip install -e .
!mkdir weights
%cd weights
!curl -L -O https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
%cd {HOME}/GroundingDINO
from groundingdino.util.inference import load_model, load_image, predict, annotate
%cd {HOME}
```
 
%% Output
 
/Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
 
/Users/ischknv/Documents/GitHub/miem/aimm/lab-2/GroundingDINO
/Users/ischknv/Documents/GitHub/miem/aimm/lab-2
 
UserWarning: Failed to load custom C++ ops. Running on CPU mode Only!
UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
 
%% Cell type:code id: tags:
 
``` python
# Resolve paths to the config and weights
CONFIG_NAME = "GroundingDINO_SwinT_OGC.py"
CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/", CONFIG_NAME)
print(f"CONFIG: {CONFIG_PATH}; exist: {os.path.isfile(CONFIG_PATH)}")
WEIGHTS_NAME = "groundingdino_swint_ogc.pth"
WEIGHTS_PATH = os.path.join(HOME, "GroundingDINO", "weights", WEIGHTS_NAME)
print(f"WEIGHTS: {WEIGHTS_PATH}; exist: {os.path.isfile(WEIGHTS_PATH)}")
```
 
%% Output
 
CONFIG: /Users/ischknv/Documents/GitHub/miem/aimm/lab-2/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py; exist: True
WEIGHTS: /Users/ischknv/Documents/GitHub/miem/aimm/lab-2/GroundingDINO/weights/groundingdino_swint_ogc.pth; exist: True
 
%% Cell type:code id: tags:
 
``` python
import torch
import supervision as sv
 
model = load_model(CONFIG_PATH, WEIGHTS_PATH)
TEXT_PROMPT = "kettle . teapot ."
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
IMAGE_SIZE = 640
device = "cuda" if torch.cuda.is_available() else "cpu"
```
 
%% Output
 
UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3596.)
 
final text_encoder_type: bert-base-uncased
 
%% Cell type:code id: tags:
 
``` python
images = sorted(m for m in os.listdir(TRAIN_DIR) if m.lower().endswith(('.png', '.jpg', '.jpeg')))
image_source, image = load_image(os.path.join(TRAIN_DIR, images[1]))
boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
    device=device,
)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
%matplotlib inline
sv.plot_image(annotated_frame, (8, 8))
```
 
%% Output
 
UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
UserWarning: None of the inputs have requires_grad=True. Gradients will be None
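
%% Cell type:markdown id: tags:

One step worth spelling out (my reading of the GroundingDINO API, worth verifying): `predict` returns boxes as normalized `(cx, cy, w, h)`, which is exactly the layout YOLO label files use, so the boxes can be written to the `.txt` annotations below without any conversion. A quick sanity check:

%% Cell type:code id: tags:

``` python
# Illustrative check: normalized box coordinates should all lie in [0, 1].
print(boxes.shape)
assert ((boxes >= 0) & (boxes <= 1)).all()
```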
 
 
%% Cell type:markdown id: tags:
 
Now let's process the whole dataset with the chosen model:
 
%% Cell type:code id: tags:
 
``` python
import os
from IPython.display import ProgressBar
 
 
def prepare_dataset(directory, store_only_best=False):
    images = sorted(
        m
        for m in os.listdir(directory)
        if m.lower().endswith((".png", ".jpg", ".jpeg"))
    )
    pbar = ProgressBar(total=len(images))
    pbar.display()

    for image_name in images:
        txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + ".txt")
        if os.path.exists(txt_path):
            pbar.progress += 1
            continue
        image_path = os.path.join(directory, image_name)
        image_source, image = load_image(image_path)
        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
            device=device,
        )
        yolo_annotations = []
        if store_only_best:
            # keep only the highest-confidence detection
            best_logit_index = logits.argmax()
            boxes = [boxes[best_logit_index]]
            phrases = [phrases[best_logit_index]]
        for box, phrase in zip(boxes, phrases):
            class_id = 0 if "kettle" in phrase.lower() else 1
            yolo_box = box.tolist()  # already normalized (cx, cy, w, h), as YOLO expects
            yolo_annotations.append(f"{class_id} {' '.join(map(str, yolo_box))}")

        with open(txt_path, "w") as f:
            f.write("\n".join(yolo_annotations))
            f.write("\n")

        pbar.progress += 1


prepare_dataset(VAL_DIR, store_only_best=True)
```
 
%% Output
 
 
%% Cell type:markdown id: tags:
 
Let's look at the annotation results:
 
%% Cell type:code id: tags:
 
``` python
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
 
 
def display_image_with_annotations(directory, image_name):
    image_path = os.path.join(directory, image_name)
    img = Image.open(image_path)
    fig, ax = plt.subplots(1)
    ax.imshow(img)
    txt_path = os.path.join(directory, os.path.splitext(image_name)[0] + ".txt")
    with open(txt_path, "r") as f:
        annotations = f.read().strip().split("\n")
    for annotation in annotations:
        class_id, x_center, y_center, width, height = map(float, annotation.split())
        img_width, img_height = img.size[0], img.size[1]
        x = (x_center - width / 2) * img_width
        y = (y_center - height / 2) * img_height
        box_width = width * img_width
        box_height = height * img_height
        rect = patches.Rectangle(
            (x, y), box_width, box_height, linewidth=2, edgecolor="r", facecolor="none"
        )
        ax.add_patch(rect)
        label = "Kettle" if class_id == 0 else "Teapot"
        plt.text(x, y - 10, label, color="r", fontweight="bold")
    plt.title(f"Image: {directory}/{image_name}")
    plt.axis("off")
    plt.show()


display_image_with_annotations(TRAIN_DIR, "0002.jpg")
```
 
%% Output
 
 
%% Cell type:markdown id: tags:
 
# Тренировка YOLO
 
%% Cell type:code id: tags:
 
``` python
# Install Ultralytics
!python -m pip install ultralytics
```
 
%% Output
 
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
 
Requirement already satisfied: ultralytics in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (8.3.19)
Requirement already satisfied: numpy>=1.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.1.1)
Requirement already satisfied: matplotlib>=3.3.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (3.9.2)
Requirement already satisfied: opencv-python>=4.6.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.10.0.84)
Requirement already satisfied: pillow>=7.1.2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (10.4.0)
Requirement already satisfied: pyyaml>=5.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.2)
Requirement already satisfied: requests>=2.23.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.32.3)
Requirement already satisfied: scipy>=1.4.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (1.14.1)
Requirement already satisfied: torch>=1.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.5.0)
Requirement already satisfied: torchvision>=0.9.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.20.0)
Requirement already satisfied: tqdm>=4.64.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (4.66.5)
Requirement already satisfied: psutil in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (6.0.0)
Requirement already satisfied: py-cpuinfo in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (9.0.0)
Requirement already satisfied: pandas>=1.1.4 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.2.3)
Requirement already satisfied: seaborn>=0.11.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (0.13.2)
Requirement already satisfied: ultralytics-thop>=2.0.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from ultralytics) (2.0.9)
Requirement already satisfied: contourpy>=1.0.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.3.0)
Requirement already satisfied: cycler>=0.10 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (4.54.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (1.4.7)
Requirement already satisfied: packaging>=20.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (24.1)
Requirement already satisfied: pyparsing>=2.3.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (3.1.4)
Requirement already satisfied: python-dateutil>=2.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from matplotlib>=3.3.0->ultralytics) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from pandas>=1.1.4->ultralytics) (2024.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2.2.3)
Requirement already satisfied: certifi>=2017.4.17 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from requests>=2.23.0->ultralytics) (2024.8.30)
Requirement already satisfied: filelock in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (4.12.2)
Requirement already satisfied: networkx in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.4.2)
Requirement already satisfied: jinja2 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (3.1.4)
Requirement already satisfied: fsspec in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (2024.10.0)
Requirement already satisfied: setuptools in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (75.1.0)
Requirement already satisfied: sympy==1.13.1 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from torch>=1.8.0->ultralytics) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from sympy==1.13.1->torch>=1.8.0->ultralytics) (1.3.0)
Requirement already satisfied: six>=1.5 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib>=3.3.0->ultralytics) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /Users/ischknv/Documents/GitHub/miem/aimm/.venv/lib/python3.12/site-packages (from jinja2->torch>=1.8.0->ultralytics) (2.1.5)
 
%% Cell type:code id: tags:
 
``` python
# Load the model
from ultralytics import YOLO
 
model = YOLO("yolov8s.pt")
```
 
%% Cell type:code id: tags:
 
``` python
# Read the W&B key from the environment rather than hard-coding a secret in the notebook.
WANDB_API_KEY = os.environ["WANDB_API_KEY"]
```
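
%% Cell type:markdown id: tags:

Presumably the key is meant for logging the run to Weights & Biases; a minimal sketch of wiring it up (the `wandb.login` call is my assumption and is not shown in the diff):

%% Cell type:code id: tags:

``` python
# Hypothetical: authenticate with W&B so the training run below gets logged.
import wandb

wandb.login(key=WANDB_API_KEY)
```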

%% Cell type:code id: tags:

``` python
# Train the model
model.train(data=f"{HOME}/dataset/data.yaml", device="cuda", batch=16, epochs=50)
# `device="cuda"` is set because training was actually run on a different, GPU-equipped machine.
```
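
%% Cell type:markdown id: tags:

The `data.yaml` file itself is not part of this diff; a plausible layout, assuming the directory structure above and the class ids used during annotation (0 = kettle, 1 = teapot):

``` yaml
# Assumed Ultralytics dataset config, not the actual file
path: dataset
train: train
val: val
names:
  0: kettle
  1: teapot
```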
 
%% Cell type:markdown id: tags:
 
The training output got accidentally erased, but it really was here :)
 
%% Cell type:code id: tags:
 
``` python
from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2
 
model_path = f"{HOME}/runs/detect/train/weights/best.pt"
model = YOLO(model_path)
results = model.predict(os.path.join(TRAIN_DIR, "0002.jpg"))
result = results[0]
res_plotted = result.plot()
fig, ax = plt.subplots(1)
ax.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()
```
 
%% Output
 
image 1/1 /Users/ischknv/Documents/GitHub/miem/aimm/lab-2/dataset/train/0002.jpg: 640x640 1 teapot, 159.1ms
Speed: 1.1ms preprocess, 159.1ms inference, 0.5ms postprocess per image at shape (1, 3, 640, 640)
 
 
%% Cell type:markdown id: tags:
 
# Results
A fine-tuned YOLO model now recognizes kettles: stovetop, electric, and teapots. The classes are not always told apart correctly (for example, the image above should be `kettle`, yet it was detected as `teapot`), but this is mostly a dataset-quality problem (I can't always articulate the difference myself). The training results live in the [runs folder](lab-2/runs/detect/train/). The best `mAP50` was `0.85204`: roughly 0.7 for kettle and 0.9 for teapot, so it may be worth merging the two classes to avoid 'confusing' the model and to raise accuracy; a sketch of that merge follows.
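
%% Cell type:markdown id: tags:

A minimal sketch of that class merge (hypothetical, not run for this lab): rewrite every label file to class `0`; retraining would also need a one-class `data.yaml`.

%% Cell type:code id: tags:

``` python
# Hypothetical sketch: collapse both classes into class 0 in all label files.
import glob
import os

for txt_path in glob.glob(os.path.join(HOME, "dataset", "*", "*.txt")):
    with open(txt_path) as f:
        rows = [line.split() for line in f.read().splitlines() if line.strip()]
    with open(txt_path, "w") as f:
        for row in rows:
            f.write(" ".join(["0"] + row[1:]) + "\n")
```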

%% Cell type:markdown id: tags:

Bonus: run the fine-tuned model live on a webcam feed and record the annotated stream to `output.avi`:

%% Cell type:code id: tags:

``` python
import time

import cv2
import numpy as np
from ultralytics import YOLO

CONFIDENCE = 0.5
font_scale = 1
thickness = 1
labels = ["kettle", "teapot"]
colors = np.random.randint(0, 255, size=(len(labels), 3), dtype="uint8")

model = YOLO("/Users/ischknv/Documents/GitHub/miem/aimm/lab-2/runs/detect/train/weights/best.pt")
cap = cv2.VideoCapture(0)
_, image = cap.read()
h, w = image.shape[:2]
fourcc = cv2.VideoWriter_fourcc(*"XVID")
out = cv2.VideoWriter("output.avi", fourcc, 20.0, (w, h))

while True:
    _, image = cap.read()
    start = time.perf_counter()
    # run inference on the image
    # see: https://docs.ultralytics.com/modes/predict/#arguments for full list of arguments
    results = model.predict(image, conf=CONFIDENCE)[0]
    time_took = time.perf_counter() - start
    print("Time took:", time_took)
    # loop over the detections
    for data in results.boxes.data.tolist():
        # get the bounding box coordinates, confidence, and class id
        xmin, ymin, xmax, ymax, confidence, class_id = data
        # convert the coordinates and the class id to integers
        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
        class_id = int(class_id)
        # draw a bounding box rectangle and label on the image
        color = [int(c) for c in colors[class_id]]
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color=color, thickness=thickness)
        text = f"{labels[class_id]}: {confidence:.2f}"
        # calculate text width & height to draw the transparent box behind the text
        (text_width, text_height) = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=thickness
        )[0]
        text_offset_x = xmin
        text_offset_y = ymin - 5
        box_coords = (
            (text_offset_x, text_offset_y),
            (text_offset_x + text_width + 2, text_offset_y - text_height),
        )
        overlay = image.copy()
        cv2.rectangle(overlay, box_coords[0], box_coords[1], color=color, thickness=cv2.FILLED)
        # add opacity (transparency) to the box
        image = cv2.addWeighted(overlay, 0.6, image, 0.4, 0)
        # now put the text (label: confidence)
        cv2.putText(image, text, (xmin, ymin - 5), cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=font_scale, color=(0, 0, 0), thickness=thickness)
    # end time to compute the fps
    end = time.perf_counter()
    # calculate frames per second and draw it on the frame
    fps = f"FPS: {1 / (end - start):.2f}"
    cv2.putText(image, fps, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 6)
    out.write(image)
    cv2.imshow("image", image)
    if ord("q") == cv2.waitKey(1):
        break

cap.release()
out.release()
cv2.destroyAllWindows()
```