# pupilometer/Training/run.py
# (original file listing: 195 lines, 6.4 KiB, Python)
import os
import glob
import cv2
import yaml
import gc
import torch
import numpy as np
from ultralytics import SAM, YOLO
# Try importing the specific Predictor class
try:
from ultralytics.models.sam import SAM2Predictor
except ImportError:
from ultralytics.models.sam.predict import SAM2Predictor
# ==========================================
# 1. CONFIGURATION
# ==========================================
SOURCE_IMAGES_PATH = "Datasets/02/" # Your raw images
OUTPUT_BASE = "yolo_dataset" # Output folder
SAM_MODEL_PATH = "Models/sam2_b.pt" # SAM2 Model
YOLO_MODEL_NAME = "yolo11m-seg.pt" # YOLO Model (n=nano, m=medium)
TRAIN_EPOCHS = 10
TRAIN_IMGSZ = 640
CLASS_NAME = "my_object"
# ==========================================
# 2. SETUP & INITIALIZATION
# ==========================================
# Dataset layout expected by Ultralytics: <base>/images/train + <base>/labels/train.
OUTPUT_IMG_DIR = os.path.join(OUTPUT_BASE, "images", "train")
OUTPUT_LBL_DIR = os.path.join(OUTPUT_BASE, "labels", "train")
os.makedirs(OUTPUT_IMG_DIR, exist_ok=True)
os.makedirs(OUTPUT_LBL_DIR, exist_ok=True)
# Collect every image with a supported extension; sort for a deterministic
# processing order (the FIRST image is used for the interactive prompt below).
extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
image_files = []
for ext in extensions:
    image_files.extend(glob.glob(os.path.join(SOURCE_IMAGES_PATH, ext)))
image_files.sort()
if not image_files:
    raise ValueError(f"No images found in {SOURCE_IMAGES_PATH}")
print(f"--- PHASE 1: AUTO-LABELING ({len(image_files)} images) ---")
print(f"Loading SAM2: {SAM_MODEL_PATH}...")
# Build a SAM2 predictor configured for point-prompted segmentation.
overrides = dict(conf=0.25, task='segment', mode='predict', model=SAM_MODEL_PATH)
predictor = SAM2Predictor(overrides=overrides)
# ==========================================
# 3. HELPER: MASK TO POLYGON
# ==========================================
def mask_to_yolo_polygon(mask, img_width, img_height):
    """Convert a binary mask into YOLO-seg polygons.

    Each external contour becomes one flat list
    [x1, y1, x2, y2, ...] with coordinates normalized to [0, 1].
    Degenerate contours (fewer than 3 vertices) are skipped.
    """
    found, _ = cv2.findContours(mask.astype(np.uint8),
                                cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_SIMPLE)
    result = []
    for cnt in found:
        coords = cnt.flatten().tolist()
        # A polygon needs at least 3 points -> 6 coordinate values.
        if len(coords) < 6:
            continue
        poly = []
        # Pair up (x, y) values and clamp each to the unit interval.
        for px, py in zip(coords[0::2], coords[1::2]):
            poly.append(max(0, min(1, px / img_width)))
            poly.append(max(0, min(1, py / img_height)))
        result.append(poly)
    return result
# ==========================================
# 4. INTERACTIVE POINT PROMPT
# ==========================================
# The user clicks prompt points on the FIRST image only; the same points
# are then reused as SAM prompts for every image in the batch (Section 5).
first_image_path = image_files[0]
first_image = cv2.imread(first_image_path)
if first_image is None: raise ValueError("Could not load first image")
clone = first_image.copy()  # pristine copy, used to redraw after undo/reset
input_points = []  # [x, y] pixel coordinates of clicked prompt points
input_labels = []  # one label per point; 1 = positive/foreground for SAM
def mouse_callback(event, x, y, flags, param):
    """Collect positive SAM prompt points with the mouse.

    Left click appends a foreground point (label 1) and marks it with a
    green dot; right click undoes the most recent point and redraws the
    remaining markers on a clean copy of the frame. Mutates the module
    globals ``input_points``, ``input_labels`` and ``first_image``.
    """
    global input_points, input_labels, first_image
    if event == cv2.EVENT_LBUTTONDOWN:
        # Add a positive (foreground) prompt and visualize it.
        input_points.append([x, y])
        input_labels.append(1)
        cv2.circle(first_image, (x, y), 5, (0, 255, 0), -1)
        cv2.imshow("Define Object", first_image)
        print(f"Added point: {x}, {y}")
    elif event == cv2.EVENT_RBUTTONDOWN and input_points:
        # Undo: drop the newest point, then rebuild the overlay from
        # the pristine clone so the removed dot disappears.
        input_points.pop()
        input_labels.pop()
        first_image = clone.copy()
        for px, py in input_points:
            cv2.circle(first_image, (px, py), 5, (0, 255, 0), -1)
        cv2.imshow("Define Object", first_image)
        print("Removed last point")
# Event loop: the mouse callback mutates the point lists while we poll
# the keyboard for confirm / reset / abort.
cv2.namedWindow("Define Object")
cv2.setMouseCallback("Define Object", mouse_callback)
cv2.imshow("Define Object", first_image)
print("\n--- INSTRUCTIONS ---")
print("1. LEFT CLICK to add points on the object.")
print("2. RIGHT CLICK to remove the last point.")
print("3. Press SPACE or ENTER to confirm and start.")
print("4. Press 'r' to reset all points.")
while True:
    key = cv2.waitKey(1) & 0xFF
    if key in [32, 13]:  # Space/Enter: confirm, but require >= 1 point
        if input_points:
            break
        print("Please add at least one point.")
    elif key == ord('r'):  # Reset: wipe points, restore the clean frame
        first_image = clone.copy()
        input_points = []
        input_labels = []
        cv2.imshow("Define Object", first_image)
    elif key == 27:  # Esc: abort the whole script
        # FIX: close the OpenCV window before leaving, and raise
        # SystemExit instead of the site-provided exit() helper.
        cv2.destroyAllWindows()
        raise SystemExit(0)
cv2.destroyAllWindows()
# ==========================================
# 5. BATCH LABEL GENERATION
# ==========================================
# Reuse the user-selected prompt points on every image; each image that
# yields at least one mask is copied into the dataset with its label file.
print(f"\nApplying {len(input_points)} points to all images...")
valid_labels_count = 0
for img_path in image_files:
    filename = os.path.basename(img_path)
    img = cv2.imread(img_path)
    if img is None:
        continue  # skip unreadable files rather than aborting the batch
    h, w = img.shape[:2]
    # SAM2 Inference
    predictor.set_image(img)
    # Use points instead of bboxes.
    # points expects list of shape [N, 2], labels list of shape [N]
    results = predictor(points=input_points, labels=input_labels)
    if results[0].masks is not None:
        masks = results[0].masks.data.cpu().numpy()
        # Save the image next to its label so the pair stays in sync.
        cv2.imwrite(os.path.join(OUTPUT_IMG_DIR, filename), img)
        txt_path = os.path.join(OUTPUT_LBL_DIR, os.path.splitext(filename)[0] + ".txt")
        with open(txt_path, 'w') as f:
            for mask in masks:
                polygons = mask_to_yolo_polygon(mask, w, h)
                for poly in polygons:
                    # Class id 0 (single-class dataset) + normalized polygon.
                    f.write(f"0 {' '.join(map(str, poly))}\n")
        valid_labels_count += 1
        # FIX: the progress message had no placeholder ("(unknown)");
        # report the file that was just labeled.
        print(f"Labeled: {filename}")
print(f"--- Labeling Complete. {valid_labels_count} images prepared. ---")
# ==========================================
# 6. MEMORY CLEANUP
# ==========================================
# Release SAM2 before loading YOLO so the two models never have to share
# GPU memory at the same time.
del predictor
del overrides
gc.collect()
# FIX: only touch the CUDA allocator when CUDA actually exists, so the
# script also runs cleanly on CPU-only hosts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("SAM2 unloaded.")
# ==========================================
# 7. GENERATE YAML & TRAIN
# ==========================================
yaml_path = os.path.join(OUTPUT_BASE, "dataset.yaml")
yaml_data = {
    'path': os.path.abspath(OUTPUT_BASE),
    'train': 'images/train',
    # NOTE: validation reuses the training split — acceptable for a quick
    # auto-label bootstrap, but validation metrics will be optimistic.
    'val': 'images/train',
    'names': {0: CLASS_NAME}
}
with open(yaml_path, 'w') as f:
    yaml.dump(yaml_data, f)
print("--- PHASE 2: TRAINING YOLO11 ---")
model = YOLO(YOLO_MODEL_NAME)
# FIX: device was hard-coded to GPU 0, which crashes on CPU-only machines;
# fall back to CPU when CUDA is unavailable.
train_device = 0 if torch.cuda.is_available() else 'cpu'
results = model.train(data=yaml_path, epochs=TRAIN_EPOCHS, imgsz=TRAIN_IMGSZ,
                      batch=4, device=train_device)
print("DONE.")