# ==========================================
# IMPORTS
# ==========================================
import os
import glob
import cv2
import yaml
import gc
import torch
import numpy as np
from ultralytics import SAM, YOLO

# The predictor class moved between ultralytics versions; try both locations.
try:
    from ultralytics.models.sam import SAM2Predictor
except ImportError:
    from ultralytics.models.sam.predict import SAM2Predictor

# ==========================================
# 1. CONFIGURATION
# ==========================================
SOURCE_IMAGES_PATH = "Datasets/02/"   # Your raw images
OUTPUT_BASE = "yolo_dataset"          # Output folder
SAM_MODEL_PATH = "Models/sam2_b.pt"   # SAM2 Model
YOLO_MODEL_NAME = "yolo11m-seg.pt"    # YOLO Model (n=nano, m=medium)
TRAIN_EPOCHS = 10
TRAIN_IMGSZ = 640
CLASS_NAME = "my_object"

# ==========================================
# 2. SETUP & INITIALIZATION
# ==========================================
OUTPUT_IMG_DIR = os.path.join(OUTPUT_BASE, "images", "train")
OUTPUT_LBL_DIR = os.path.join(OUTPUT_BASE, "labels", "train")
for _out_dir in (OUTPUT_IMG_DIR, OUTPUT_LBL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Collect every source image matching a supported extension, in sorted order.
_patterns = ('*.jpg', '*.jpeg', '*.png', '*.bmp')
image_files = sorted(
    path
    for pattern in _patterns
    for path in glob.glob(os.path.join(SOURCE_IMAGES_PATH, pattern))
)

if not image_files:
    raise ValueError(f"No images found in {SOURCE_IMAGES_PATH}")

print(f"--- PHASE 1: AUTO-LABELING ({len(image_files)} images) ---")
print(f"Loading SAM2: {SAM_MODEL_PATH}...")

# Build the SAM2 predictor once; it is reused for every image in Phase 1.
overrides = dict(conf=0.25, task='segment', mode='predict', model=SAM_MODEL_PATH)
predictor = SAM2Predictor(overrides=overrides)
# ==========================================
# 3. HELPER: MASK TO POLYGON
# ==========================================
def mask_to_yolo_polygon(mask, img_width, img_height):
    """Convert a binary mask to YOLO-segmentation polygons.

    Args:
        mask: 2-D array, non-zero pixels are foreground. May be at any
            resolution; it is resized to the image size if it differs.
        img_width: Width of the source image in pixels.
        img_height: Height of the source image in pixels.

    Returns:
        list[list[float]]: One flat [x1, y1, x2, y2, ...] list per external
        contour, every coordinate normalized to [0, 1].
    """
    mask = mask.astype(np.uint8)
    # Robustness fix: Ultralytics/SAM can return masks at the model's
    # inference resolution rather than the original image size, in which
    # case normalizing mask-space contour coordinates by the image dims
    # would be wrong. Resize to image space first (no-op when shapes match).
    if mask.shape[:2] != (img_height, img_width):
        mask = cv2.resize(mask, (img_width, img_height),
                          interpolation=cv2.INTER_NEAREST)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    polygons = []
    for contour in contours:
        points = contour.flatten().tolist()
        if len(points) < 6:  # need at least 3 vertices for a valid polygon
            continue
        normalized_poly = []
        for i in range(0, len(points), 2):
            # Clamp so border rounding never produces out-of-range coords.
            x_norm = max(0, min(1, points[i] / img_width))
            y_norm = max(0, min(1, points[i + 1] / img_height))
            normalized_poly.extend([x_norm, y_norm])
        polygons.append(normalized_poly)
    return polygons


# ==========================================
# 4. INTERACTIVE POINT PROMPT
# ==========================================
first_image_path = image_files[0]
first_image = cv2.imread(first_image_path)
if first_image is None:
    raise ValueError("Could not load first image")

clone = first_image.copy()  # pristine copy used to redraw after undo/reset
input_points = []           # [[x, y], ...] clicked prompt points
input_labels = []           # parallel list; 1 = positive (foreground)


def mouse_callback(event, x, y, flags, param):
    """Collect positive prompt points; right-click undoes the last one."""
    global input_points, input_labels, first_image
    # Left Click: Add Positive Point (Foreground)
    if event == cv2.EVENT_LBUTTONDOWN:
        input_points.append([x, y])
        input_labels.append(1)  # 1 = Positive
        # Visualize
        cv2.circle(first_image, (x, y), 5, (0, 255, 0), -1)  # Green dot
        cv2.imshow("Define Object", first_image)
        print(f"Added point: {x}, {y}")
    # Right Click: Remove Last Point (Undo)
    elif event == cv2.EVENT_RBUTTONDOWN:
        if input_points:
            input_points.pop()
            input_labels.pop()
            # Redraw image from the clean clone with the remaining points.
            first_image = clone.copy()
            for pt in input_points:
                cv2.circle(first_image, (pt[0], pt[1]), 5, (0, 255, 0), -1)
            cv2.imshow("Define Object", first_image)
            print("Removed last point")


cv2.namedWindow("Define Object")
cv2.setMouseCallback("Define Object", mouse_callback)
cv2.imshow("Define Object", first_image)

print("\n--- INSTRUCTIONS ---")
print("1. LEFT CLICK to add points on the object.")
print("2. RIGHT CLICK to remove the last point.")
print("3. Press SPACE or ENTER to confirm and start.")
print("4. Press 'r' to reset all points.")
# Event loop: wait for the user to confirm the prompt points.
while True:
    key = cv2.waitKey(1) & 0xFF
    if key in (32, 13):  # Space / Enter: confirm and start labeling
        if input_points:
            break
        print("Please add at least one point.")
    elif key == ord('r'):  # 'r': reset all points
        first_image = clone.copy()
        input_points = []
        input_labels = []
        cv2.imshow("Define Object", first_image)
    elif key == 27:  # Esc: abort the whole run
        raise SystemExit

cv2.destroyAllWindows()

# ==========================================
# 5. BATCH LABEL GENERATION
# ==========================================
print(f"\nApplying {len(input_points)} points to all images...")

valid_labels_count = 0
for img_path in image_files:
    filename = os.path.basename(img_path)
    img = cv2.imread(img_path)
    if img is None:
        continue  # unreadable file: skip it rather than abort the batch
    h, w = img.shape[:2]

    # SAM2 Inference
    predictor.set_image(img)
    # Use points instead of bboxes:
    # points expects list of shape [N, 2], labels list of shape [N]
    results = predictor(points=input_points, labels=input_labels)

    if results[0].masks is not None:
        masks = results[0].masks.data.cpu().numpy()

        # Save Image
        cv2.imwrite(os.path.join(OUTPUT_IMG_DIR, filename), img)

        # Save Label: class id 0, one polygon per line (YOLO seg format).
        txt_path = os.path.join(OUTPUT_LBL_DIR,
                                os.path.splitext(filename)[0] + ".txt")
        with open(txt_path, 'w') as f:
            for mask in masks:
                polygons = mask_to_yolo_polygon(mask, w, h)
                for poly in polygons:
                    f.write(f"0 {' '.join(map(str, poly))}\n")
        valid_labels_count += 1
        # Bug fix: the original printed a literal placeholder instead of the
        # name of the file that was just labeled.
        print(f"Labeled: {filename}")

print(f"--- Labeling Complete. {valid_labels_count} images prepared. ---")

# ==========================================
# 6. MEMORY CLEANUP
# ==========================================
# Drop SAM2 before loading YOLO so both models never occupy memory at once.
del predictor
del overrides
gc.collect()
torch.cuda.empty_cache()  # no-op on CPU-only builds
print("SAM2 unloaded.")
# ==========================================
# 7. GENERATE YAML & TRAIN
# ==========================================
yaml_path = os.path.join(OUTPUT_BASE, "dataset.yaml")
yaml_data = {
    'path': os.path.abspath(OUTPUT_BASE),
    'train': 'images/train',
    # NOTE(review): validates on the training images — fine for a quick
    # sanity run, but metrics will be optimistic; split out a real val set.
    'val': 'images/train',
    'names': {0: CLASS_NAME},
}
with open(yaml_path, 'w') as f:
    yaml.dump(yaml_data, f)

print("--- PHASE 2: TRAINING YOLO11 ---")
model = YOLO(YOLO_MODEL_NAME)
# Robustness fix: fall back to CPU when no CUDA device is present instead of
# crashing on the hard-coded GPU index 0.
results = model.train(
    data=yaml_path,
    epochs=TRAIN_EPOCHS,
    imgsz=TRAIN_IMGSZ,
    batch=4,
    device=0 if torch.cuda.is_available() else 'cpu',
)
print("DONE.")