Source code for cpauger.crop_obj

from pycocotools.coco import COCO
import cv2
import json
from typing import Union, List, Dict
import os


def _crop_masked_regions(image, mask) -> List:
    """Cut each external contour of ``mask`` out of ``image`` as an RGBA crop.

    For every external contour of the binary ``mask`` (as produced by
    ``COCO.annToMask``), the contour's bounding rectangle is cut from the BGR
    ``image``. Pixels outside the mask are zeroed. The crop is converted to
    RGBA, and its alpha channel is set to ``mask * 255``, so the background
    becomes transparent.
    """
    crops = []
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        region = image[y:y + h, x:x + w]
        region_mask = mask[y:y + h, x:x + w]
        # Zero out background pixels, then expose the mask as the alpha
        # channel so the background is transparent.
        region = cv2.bitwise_and(region, region, mask=region_mask)
        region = cv2.cvtColor(region, cv2.COLOR_BGR2RGBA)
        region[:, :, 3] = region_mask * 255
        crops.append(region)
    return crops


def crop_obj_per_image(obj_names: list, imgname: Union[str, List], img_dir,
                       coco_ann_file: str
                       ) -> Union[Dict[str, List], None]:
    """Crop every instance of the requested categories from one image.

    Args:
        obj_names: Category names to crop. Duplicate names are ignored, so
            each object is cropped only once.
        imgname: Image file name or path. Only its base name is used. It is
            matched against the base name of each ``file_name`` in the
            annotation file's ``images`` list.
        img_dir: Directory containing the image file.
        coco_ann_file: Path to a COCO-format annotation JSON file.

    Returns:
        Dict mapping each requested category name that is annotated in this
        image to a list of RGBA crops. There is one crop per external contour
        of each matching annotation mask. Names not annotated in the image
        are absent. The dict is empty when none of ``obj_names`` occur.
        ``None`` is never returned; the ``Union`` annotation is kept for
        backward compatibility.

    Raises:
        IndexError: If ``imgname`` is not listed in ``coco_ann_file``.
        FileNotFoundError: If there is something to crop but the image file
            cannot be read by OpenCV.
    """
    imgname = os.path.basename(imgname)

    # COCO() already parses the JSON file; reuse its data instead of loading
    # the same file a second time with json.load.
    coco = COCO(coco_ann_file)
    coco_data = coco.dataset
    categories = coco_data["categories"]
    category_id_to_name_map = {cat["id"]: cat["name"] for cat in categories}
    category_name_to_id_map = {cat["name"]: cat["id"] for cat in categories}

    image_info = next(
        (info for info in coco_data["images"]
         if os.path.basename(info["file_name"]) == imgname),
        None,
    )
    if image_info is None:
        # IndexError matches what the previous `[...][0]` lookup raised.
        raise IndexError(
            f"Image {imgname!r} is not listed in {coco_ann_file!r}"
        )
    image_id = image_info["id"]

    img_ann = [ann for ann in coco_data["annotations"]
               if ann["image_id"] == image_id]
    img_objnames = {category_id_to_name_map[ann["category_id"]]
                    for ann in img_ann}

    # dict.fromkeys de-duplicates while keeping the caller's order, so a
    # repeated name does not produce duplicated crops.
    objs_to_crop = [name for name in dict.fromkeys(obj_names)
                    if name in img_objnames]

    cropped_objs_collection = {}
    if not objs_to_crop:
        # Nothing to cut out, so skip reading the image entirely.
        return cropped_objs_collection

    img_path = os.path.join(img_dir, imgname)
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image {img_path!r}")

    for objname in objs_to_crop:
        obj_id = category_name_to_id_map[objname]
        object_masks = []
        for ann in img_ann:
            if ann["category_id"] == obj_id:
                object_masks.extend(
                    _crop_masked_regions(image, coco.annToMask(ann))
                )
        cropped_objs_collection[objname] = object_masks
    return cropped_objs_collection
def collate_all_crops(object_to_cropped, imgnames_for_crop, img_dir,
                      coco_ann_file
                      ):
    """Crop the requested objects from several images and merge the results.

    Args:
        object_to_cropped: Category names to crop. Passed to
            ``crop_obj_per_image`` as ``obj_names``.
        imgnames_for_crop: Iterable of image file names or paths. Only the
            base name of each is used.
        img_dir: Directory containing the image files.
        coco_ann_file: Path to a COCO-format annotation JSON file.

    Returns:
        Dict mapping each category name to one list holding all of its crops.
        Crops are ordered by image, then by their order within each image.
        Names never returned by ``crop_obj_per_image`` for any image are
        absent.
    """
    all_crops = {}
    for img in imgnames_for_crop:
        # NOTE: crop_obj_per_image re-reads coco_ann_file on every call, so
        # the total cost grows with the number of images.
        crop_obj = crop_obj_per_image(obj_names=object_to_cropped,
                                      imgname=os.path.basename(img),
                                      img_dir=img_dir,
                                      coco_ann_file=coco_ann_file)
        # crop_obj_per_image is annotated as possibly returning None.
        if not crop_obj:
            continue
        for object_name, crops in crop_obj.items():
            # Accumulate into a list owned by all_crops, so lists returned by
            # crop_obj_per_image are never aliased or mutated. A None value
            # contributes nothing.
            all_crops.setdefault(object_name, []).extend(crops or [])
    return all_crops