# Auto-label a folder of images with SAM2 point prompts, then train a
# YOLO11 segmentation model on the generated dataset.
import gc
import glob
import os

import cv2
import numpy as np
import torch
import yaml
from ultralytics import SAM, YOLO

# The SAM2Predictor import path moved between ultralytics releases; try the
# package-level name first, then the explicit predict module.
try:
    from ultralytics.models.sam import SAM2Predictor
except ImportError:
    from ultralytics.models.sam.predict import SAM2Predictor
|
|
|
|
# ==========================================
# 1. CONFIGURATION
# ==========================================
SOURCE_IMAGES_PATH = "Datasets/02/"  # Your raw images
OUTPUT_BASE = "yolo_dataset"  # Output folder
SAM_MODEL_PATH = "Models/sam2_b.pt"  # SAM2 Model
YOLO_MODEL_NAME = "yolo11m-seg.pt"  # YOLO Model (n=nano, m=medium)

TRAIN_EPOCHS = 10  # training epochs for the YOLO phase
TRAIN_IMGSZ = 640  # training image size (pixels)
CLASS_NAME = "my_object"  # single class; written as class id 0 in labels
|
|
|
|
# ==========================================
# 2. SETUP & INITIALIZATION
# ==========================================
# YOLO-format dataset layout: images/train and labels/train under OUTPUT_BASE.
OUTPUT_IMG_DIR = os.path.join(OUTPUT_BASE, "images", "train")
OUTPUT_LBL_DIR = os.path.join(OUTPUT_BASE, "labels", "train")

os.makedirs(OUTPUT_IMG_DIR, exist_ok=True)
os.makedirs(OUTPUT_LBL_DIR, exist_ok=True)

# Collect source images. FIX: the original matched lowercase extensions only,
# silently skipping e.g. ".JPG" files from cameras. Match both cases; the set
# de-duplicates hits on case-insensitive filesystems where both patterns
# match the same file.
extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
_found = set()
for ext in extensions:
    for pattern in (ext, ext.upper()):
        _found.update(glob.glob(os.path.join(SOURCE_IMAGES_PATH, pattern)))
image_files = sorted(_found)

if not image_files:
    raise ValueError(f"No images found in {SOURCE_IMAGES_PATH}")

print(f"--- PHASE 1: AUTO-LABELING ({len(image_files)} images) ---")
print(f"Loading SAM2: {SAM_MODEL_PATH}...")

# SAM2 predictor configured for point-prompted segmentation.
overrides = dict(conf=0.25, task='segment', mode='predict', model=SAM_MODEL_PATH)
predictor = SAM2Predictor(overrides=overrides)
|
|
|
|
# ==========================================
|
|
# 3. HELPER: MASK TO POLYGON
|
|
# ==========================================
|
|
def mask_to_yolo_polygon(mask, img_width, img_height):
    """Convert a binary mask into YOLO-segmentation polygons.

    Each external contour of *mask* becomes one flat list of normalized
    [x1, y1, x2, y2, ...] coordinates, clamped to [0, 1]. Contours with
    fewer than 3 points (not enough for a polygon) are dropped.
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    polygons = []
    for cnt in contours:
        flat = cnt.flatten().tolist()
        if len(flat) < 6:  # fewer than 3 (x, y) pairs: skip
            continue
        poly = []
        for px, py in zip(flat[0::2], flat[1::2]):
            poly.append(max(0, min(1, px / img_width)))
            poly.append(max(0, min(1, py / img_height)))
        polygons.append(poly)
    return polygons
|
|
|
|
# ==========================================
|
|
# 4. INTERACTIVE POINT PROMPT
|
|
# ==========================================
|
|
# ==========================================
# 4. INTERACTIVE POINT PROMPT
# ==========================================
# Load the first image; the user clicks prompt points on it, and those same
# pixel coordinates are reused as prompts for every image in the batch.
first_image_path = image_files[0]
first_image = cv2.imread(first_image_path)
if first_image is None:
    raise ValueError("Could not load first image")

clone = first_image.copy()  # pristine copy, used to redraw after undo/reset
input_points = []           # clicked [x, y] prompt coordinates
input_labels = []           # one label per point; 1 = foreground
|
|
|
|
def mouse_callback(event, x, y, flags, param):
    """Collect SAM prompt points: left-click adds a point, right-click undoes."""
    global input_points, input_labels, first_image

    if event == cv2.EVENT_LBUTTONDOWN:
        # Left Click: add a positive (foreground) point and mark it.
        input_points.append([x, y])
        input_labels.append(1)  # 1 = Positive
        cv2.circle(first_image, (x, y), 5, (0, 255, 0), -1)  # Green dot
        cv2.imshow("Define Object", first_image)
        print(f"Added point: {x}, {y}")
        return

    if event == cv2.EVENT_RBUTTONDOWN and input_points:
        # Right Click: drop the most recent point, then redraw the
        # remaining ones onto a fresh copy of the clean image.
        input_points.pop()
        input_labels.pop()
        first_image = clone.copy()
        for px, py in input_points:
            cv2.circle(first_image, (px, py), 5, (0, 255, 0), -1)
        cv2.imshow("Define Object", first_image)
        print("Removed last point")
|
|
|
|
# Open the prompt window and wait for the user to place and confirm points.
cv2.namedWindow("Define Object")
cv2.setMouseCallback("Define Object", mouse_callback)
cv2.imshow("Define Object", first_image)

print("\n--- INSTRUCTIONS ---")
print("1. LEFT CLICK to add points on the object.")
print("2. RIGHT CLICK to remove the last point.")
print("3. Press SPACE or ENTER to confirm and start.")
print("4. Press 'r' to reset all points.")

while True:
    key = cv2.waitKey(1) & 0xFF
    if key in (32, 13):  # Space/Enter: confirm (needs at least one point)
        if input_points:
            break
        print("Please add at least one point.")
    elif key == ord('r'):  # Reset: wipe points, restore the clean image
        first_image = clone.copy()
        input_points = []
        input_labels = []
        cv2.imshow("Define Object", first_image)
    elif key == 27:  # Esc: abort the whole run
        # FIX: the original called exit(), a site-module helper that is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        # Close windows first, then raise SystemExit explicitly.
        cv2.destroyAllWindows()
        raise SystemExit(0)

cv2.destroyAllWindows()
|
|
|
|
# ==========================================
|
|
# 5. BATCH LABEL GENERATION
|
|
# ==========================================
|
|
# ==========================================
# 5. BATCH LABEL GENERATION
# ==========================================
# Apply the prompt points picked on the first image, at the same pixel
# coordinates, to every image. NOTE(review): this assumes the object sits in
# roughly the same place in all frames (e.g. a fixed camera) — confirm.
print(f"\nApplying {len(input_points)} points to all images...")
valid_labels_count = 0

for img_path in image_files:
    filename = os.path.basename(img_path)
    img = cv2.imread(img_path)
    if img is None:
        continue  # unreadable/corrupt file: skip it
    h, w = img.shape[:2]

    # SAM2 inference. `points` is an [N, 2] list, `labels` an [N] list (1=fg).
    predictor.set_image(img)
    results = predictor(points=input_points, labels=input_labels)

    if results[0].masks is not None:
        masks = results[0].masks.data.cpu().numpy()

        # Copy the image into the dataset and write its YOLO-seg label file.
        cv2.imwrite(os.path.join(OUTPUT_IMG_DIR, filename), img)

        txt_path = os.path.join(OUTPUT_LBL_DIR, os.path.splitext(filename)[0] + ".txt")
        with open(txt_path, 'w') as f:
            for mask in masks:
                polygons = mask_to_yolo_polygon(mask, w, h)
                for poly in polygons:
                    # Class id 0: this is a single-class dataset.
                    f.write(f"0 {' '.join(map(str, poly))}\n")

        valid_labels_count += 1
        # BUG FIX: the original f-string had no placeholder and printed the
        # literal text "(unknown)" for every image; report the file name.
        print(f"Labeled: {filename}")

print(f"--- Labeling Complete. {valid_labels_count} images prepared. ---")
|
|
|
|
# ==========================================
|
|
# 6. MEMORY CLEANUP
|
|
# ==========================================
|
|
# ==========================================
# 6. MEMORY CLEANUP
# ==========================================
# Drop the SAM2 predictor before loading YOLO so both models never occupy
# (GPU) memory at the same time.
del predictor
del overrides
gc.collect()
if torch.cuda.is_available():  # defensive: skip CUDA calls on CPU-only hosts
    torch.cuda.empty_cache()
print("SAM2 unloaded.")
|
|
|
|
# ==========================================
|
|
# 7. GENERATE YAML & TRAIN
|
|
# ==========================================
|
|
# ==========================================
# 7. GENERATE YAML & TRAIN
# ==========================================
yaml_path = os.path.join(OUTPUT_BASE, "dataset.yaml")
yaml_data = {
    'path': os.path.abspath(OUTPUT_BASE),
    'train': 'images/train',
    # NOTE(review): val reuses the training split, so validation metrics will
    # be optimistic — add a held-out split for a real evaluation.
    'val': 'images/train',
    'names': {0: CLASS_NAME},
}

with open(yaml_path, 'w') as f:
    yaml.dump(yaml_data, f)

print("--- PHASE 2: TRAINING YOLO11 ---")
model = YOLO(YOLO_MODEL_NAME)
# FIX: the original hard-coded device=0, which crashes on machines without a
# CUDA device; fall back to CPU when no GPU is present.
results = model.train(
    data=yaml_path,
    epochs=TRAIN_EPOCHS,
    imgsz=TRAIN_IMGSZ,
    batch=4,
    device=0 if torch.cuda.is_available() else 'cpu',
)

print("DONE.")