Source code for cpauger.paste_obj

from pycocotools.coco import COCO
import cv2
import numpy as np
import json
from typing import List, Dict, Tuple
import os
import random

__all__ = ["paste_object", "paste_crops_on_bkgs"]
# Function to adjust segmentation based on bbox
def adjust_segmentation(bbox, segmentation):
    x_offset, y_offset = bbox[0], bbox[1]
    adjusted_segmentation = []
    for polygon in segmentation:
        adjusted_polygon = []
        for i in range(0, len(polygon), 2):
            adjusted_polygon.append(polygon[i] + x_offset)
            adjusted_polygon.append(polygon[i + 1] + y_offset)
        adjusted_segmentation.append(adjusted_polygon)
    return adjusted_segmentation


def get_polygon_coordinates(mask):
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    segmentation = []
    for contour in contours:
        contour = contour.flatten().tolist()
        segmentation.append(contour)
    return segmentation

def get_adjusted_width_height(new_w, new_h, obj_w, obj_h):
    if (int(new_w) <= 0) or (int(new_h) <= 0):
        new_w = obj_w + 20
        new_h = obj_h + 20
    return new_w, new_h

def get_adjusted_object(x, y, resized_object, dest_image):
    obj_bottom_loc = y+resized_object.shape[0]
    obj_right_loc = x+resized_object.shape[1]
    #print(f"obj_bottom_loc: {obj_bottom_loc}")
    #print(f"obj_right_loc: {obj_right_loc}")
    while obj_bottom_loc > dest_image.shape[0]:
        obj_bottom_loc = dest_image.shape[0] - y  
        #print(f"reset obj_bottom_loc: to {obj_bottom_loc}")
        resized_object = cv2.resize(resized_object, (resized_object.shape[1], 
                                                    obj_bottom_loc
                                                    ), 
                                    interpolation=cv2.INTER_AREA
                                    )
        #print(f"In obj_bottom_loc while loop")
    while obj_right_loc > dest_image.shape[1]:
        obj_right_loc = dest_image.shape[1] - x 
        #print(f"reset obj_right_loc: to {obj_right_loc}")
        resized_object = cv2.resize(resized_object, (obj_right_loc, 
                                                    resized_object.shape[0]
                                                    ), 
                                    interpolation=cv2.INTER_AREA
                                    )
        #print(f"In obj_right_loc while loop")
    return resized_object

def get_paste_location_coordinate(sample_location_randomly, dest_w,
                                dest_h
                                ):
    if sample_location_randomly:
        min_x = random.random()
        max_x = random.uniform(min_x, 1)
        min_y = random.random()
        max_y = random.uniform(min_y, 1)
        x = int(min_x * dest_w)
        y = int(min_y * dest_h)
        max_x = int(max_x * dest_w)
        max_y = int(max_y * dest_h)
    else:
        x = int(min_x * dest_w)
        y = int(min_y * dest_h)
        max_x = int(max_x * dest_w)
        max_y = int(max_y * dest_h)
    return x, y, max_x, max_y

def get_resized_object(resize_w, resize_h, cropped_object,):
    if resize_w and resize_h:
        obj_h, obj_w = cropped_object.shape[:2]
        aspect_ratio = obj_w / obj_h
        if resize_w / resize_h > aspect_ratio:
            resize_w = int(resize_h * aspect_ratio)
        else:
            resize_h = int(resize_w / aspect_ratio)
        resized_object = cv2.resize(cropped_object, (resize_w, resize_h), 
                                    interpolation=cv2.INTER_AREA
                                    )
    else:
        resized_object = cropped_object
    return resized_object

def get_scaled_object_width_height(x, y, max_x, max_y, obj_w, obj_h):
    scale_x = (max_x - x) / obj_w
    if scale_x <= 0:
        scale_x = 0.1
    scale_y = (max_y - y) / obj_h
    if scale_y <= 0:
        scale_y = 0.1
        
    scale = min(scale_x, scale_y)
    new_w = int(obj_w * scale)
    new_h = int(obj_h * scale)
    return new_w, new_h
                            
[docs] def paste_object(dest_img_path, cropped_objects: Dict[str, List[np.ndarray]], min_x=None, min_y=None, max_x=None, max_y=None, resize_w=None, resize_h=None, sample_location_randomly: bool = True, )->Tuple[np.ndarray, List, List, List]: # Load the destination image dest_image = cv2.imread(dest_img_path, cv2.IMREAD_UNCHANGED) dest_image = cv2.cvtColor(dest_image, cv2.COLOR_BGR2RGB) dest_h, dest_w = dest_image.shape[:2] if not isinstance(cropped_objects, dict): raise ValueError(f"""cropped_objects is expected to be a dictionary of key being the category_id and value being a list of cropped object image (np.ndarray) """) bboxes, segmentations, category_ids = [], [], [] # Resize the cropped object if resize dimensions are provided for cat_id in cropped_objects: cat_cropped_objects = cropped_objects[cat_id] if not isinstance(cat_cropped_objects, list): cat_cropped_objects = [cat_cropped_objects] for cropped_object in cat_cropped_objects: x, y, max_x, max_y = get_paste_location_coordinate(sample_location_randomly=sample_location_randomly, dest_w=dest_w, dest_h=dest_h ) resized_object = get_resized_object(resize_w=resize_w, resize_h=resize_h, cropped_object=cropped_object ) # Ensure the resized object fits within the specified area obj_h, obj_w = resized_object.shape[:2] if obj_w > (max_x - x) or obj_h > (max_y - y): new_w, new_h = get_scaled_object_width_height(x, y, max_x, max_y, obj_w=obj_w, obj_h=obj_h ) new_w, new_h = get_adjusted_width_height(new_w=new_w, new_h=new_h, obj_w=obj_w, obj_h=obj_h ) resized_object = cv2.resize(resized_object, (new_w, new_h), interpolation=cv2.INTER_AREA ) resized_object = get_adjusted_object(x=x, y=y, resized_object=resized_object, dest_image=dest_image ) if resized_object.shape[2] == 3: resized_object = cv2.cvtColor(resized_object, cv2.COLOR_RGB2RGBA) mask = resized_object[:, :, 3] mask_inv = cv2.bitwise_not(mask) resized_object = resized_object[:, :, :3] # Define the region of interest (ROI) in the destination image roi = dest_image[y:y+resized_object.shape[0], x:x+resized_object.shape[1]] roi = cv2.resize(roi, (mask_inv.shape[1], mask_inv.shape[0]), interpolation=cv2.INTER_AREA ) #mask_inv = cv2.resize(mask_inv, (roi.shape[1], roi.shape[0])) # Black-out the area of the object in the ROI #print(f"roi: {roi.shape} ==== mask_inv: {mask_inv.shape}") img_bg = cv2.bitwise_and(roi, roi, mask=mask_inv) # Take only region of the object from the object image obj_fg = cv2.bitwise_and(resized_object, resized_object, mask=mask) # sometimes, we are unable to get a segmask large enough for visualozation or # not enough polygon coordinate. # we want to check if that is the case and then just skip it without pasting it segmentation = get_polygon_coordinates(mask=mask) bbox = [x, y, resized_object.shape[1], resized_object.shape[0]] adjusted_segmentation = adjust_segmentation(bbox=bbox, segmentation=segmentation) if len(adjusted_segmentation[0]) < 8: print(f"adjusted_segmentation: {adjusted_segmentation}") print("segemntation has less than 8 polygon coordinates hence not used and object not pasted") else: # Put the object in the ROI and modify the destination image dst = cv2.add(img_bg, obj_fg) dest_image[y:y+resized_object.shape[0], x:x+resized_object.shape[1]] = dst segmentations.append(adjusted_segmentation) category_ids.append(int(cat_id)) bboxes.append(bbox) return dest_image, bboxes, segmentations, category_ids
[docs] def paste_crops_on_bkgs(bkgs, all_crops, objs_paste_num: Dict, output_img_dir, save_coco_ann_as, min_x=None, min_y=None, max_x=None, max_y=None, resize_width=None, resize_height=None, sample_location_randomly=True ): os.makedirs(output_img_dir, exist_ok=True) coco_ann = {"categories": [{"id": obj_idx+1, "name": obj} for obj_idx, obj in enumerate(sorted(objs_paste_num)) ], "images": [], "annotations": [] } ann_ids = [] for bkg_idx, bkg in enumerate(bkgs): sampled_obj = {obj_idx+1: random.sample(all_crops[obj], int(objs_paste_num[obj])) for obj_idx, obj in enumerate(sorted(objs_paste_num)) } dest_img, bboxes, segmasks, category_ids = paste_object(dest_img_path=bkg, ## showed also return the obj_idx as category_id cropped_objects=sampled_obj, min_x=min_x, min_y=min_y, max_x=max_x, max_y=max_y, resize_h=resize_height, resize_w=resize_width, sample_location_randomly=sample_location_randomly ) file_name = os.path.basename(bkg) img_path = os.path.join(output_img_dir, file_name) dest_img = cv2.cvtColor(dest_img, cv2.COLOR_BGR2RGB) cv2.imwrite(img_path, dest_img) assert(len(bboxes) == len(segmasks) == len(category_ids)), f"""bboxes: {len(bboxes)}, segmasks: {len(segmasks)} and category_ids: {len(category_ids)} are not equal length""" img_height, img_width = dest_img.shape[0], dest_img.shape[1] img_id = bkg_idx+1 image_info = {"file_name": file_name, "height": img_height, "width": img_width, "id": img_id } coco_ann["images"].append(image_info) for ann_ins in range(0, len(bboxes)): bbox = bboxes[ann_ins] segmask = segmasks[ann_ins] ann_id = len(ann_ids) + 1 ann_ids.append(ann_id) category_id = category_ids[ann_ins] annotation = {"id": ann_id, "image_id": img_id, "category_id": category_id, "bbox": bbox, "segmentation": segmask } coco_ann["annotations"].append(annotation) with open(save_coco_ann_as, "w") as filepath: json.dump(coco_ann, filepath)