from collections import namedtuple
from functools import partial

import cv2
from matplotlib import pyplot as plt


def show_mpl(image):
    """Display ``image`` with matplotlib in a single 20x20-inch axes."""
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(20, 20)
    ax.imshow(image)
    plt.show()


def show_cv2(image):
    """Display ``image`` in an (untitled) OpenCV window and wait for a key press."""
    cv2.imshow("", image)
    # A delay of 0 makes waitKey wait indefinitely for a key press.
    cv2.waitKey(0)


def copy_and_normalize_channels(image):
    """Return a copy of ``image``, converted from grayscale to BGR when possible.

    The input is never modified. The conversion is best-effort: if OpenCV
    rejects the image for ``COLOR_GRAY2BGR`` (e.g. it already has 3 channels),
    the unconverted copy is returned. The ``draw_*`` helpers call this before
    drawing colored annotations.
    """
    image = image.copy()
    try:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    except cv2.error:
        # Deliberate: images OpenCV cannot treat as grayscale are returned as-is.
        pass
    return image


def draw_contours(image, contours):
    """Return a copy of ``image`` with every contour outlined in green (thickness 4).

    Args:
        image: Image to draw on; it is copied, not modified.
        contours: Iterable of contours, each a point array as produced by
            ``cv2.findContours`` (presumably (N, 1, 2) int arrays -- confirm).
    """
    image = copy_and_normalize_channels(image)
    for cont in contours:
        # cv2.drawContours expects a *list* of contours. Passing the bare point
        # array makes OpenCV treat each point as a separate contour, so only
        # isolated dots get drawn instead of the outline.
        cv2.drawContours(image, [cont], -1, (0, 255, 0), 4)
    return image


def draw_rectangles(image, rectangles, color=None):
    """Return a copy of ``image`` with axis-aligned rectangles drawn on it.

    Args:
        image: Image to draw on; it is copied, not modified.
        rectangles: Iterable of ``(x, y, w, h)`` tuples.
        color: Color passed to ``cv2.rectangle``; ``None`` means green
            ``(0, 255, 0)``.
    """
    image = copy_and_normalize_channels(image)
    # Compare against None explicitly so falsy-but-valid colors (e.g. 0) are kept.
    if color is None:
        color = (0, 255, 0)
    for x, y, w, h in rectangles:
        cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
    return image


def draw_stats(image, stats, annotate=False):
    """Return a copy of ``image`` with a green box drawn for each stats row.

    Args:
        image: Image to draw on; it is copied, not modified.
        stats: Sequence of ``(x, y, w, h, area)`` rows (presumably the stats
            array from ``cv2.connectedComponentsWithStats`` -- confirm).
        annotate: If true, also write the x/y/w/h values inside each box.
    """
    image = copy_and_normalize_channels(image)
    keys = ["x", "y", "w", "h"]

    def annotate_stat(x, y, w, h):
        # Labels are stacked upward from the bottom-left corner of the box,
        # 20 px apart, in the order x, y, w, h.
        for i, (key, value) in enumerate(zip(keys, [x, y, w, h])):
            anno = f"{key} = {value}"
            xann = int(x + 5)
            yann = int(y + h - (20 * (i + 1)))
            cv2.putText(
                image, anno, (xann, yann),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2,
            )

    def draw_stat(stat):
        x, y, w, h, _area = stat
        cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)
        if annotate:
            annotate_stat(x, y, w, h)

    # NOTE(review): the first two rows are always skipped. Row 0 is presumably
    # the background label; why row 1 is skipped is not visible here -- confirm.
    for stat in stats[2:]:
        draw_stat(stat)
    return image


def remove_overlapping(rectangles):
    """Keep only the ``(x, y, w, h)`` rectangles that overlap no other rectangle.

    Two rectangles overlap when their intersection area is positive, so
    rectangles that merely share an edge are kept. A rectangle is never
    compared with rectangles *equal* to it, so exact duplicates survive.
    Returns a lazy iterator of ``(x, y, w, h)`` tuples.
    """

    def overlap(a, b):
        return compute_intersection(a, b) > 0

    def does_not_overlap(rect, rectangles):
        return not any(overlap(rect, r2) for r2 in rectangles if not rect == r2)

    rectangles = list(map(xywh_to_vec_rect, rectangles))
    rectangles = filter(partial(does_not_overlap, rectangles=rectangles), rectangles)
    rectangles = map(vec_rect_to_xywh, rectangles)
    return rectangles


def remove_included(rectangles):
    """Drop ``(x, y, w, h)`` rectangles fully contained in another rectangle.

    Containment is inclusive (shared edges count as inside). A rectangle is
    never compared with rectangles *equal* to it, so exact duplicates are all
    kept. Returns a lazy iterator of ``(x, y, w, h)`` tuples.
    """

    def included(a, b):
        # True when b lies entirely inside a.
        return b.xmin >= a.xmin and b.ymin >= a.ymin and b.xmax <= a.xmax and b.ymax <= a.ymax

    def is_not_included(rect, rectangles):
        return not any(included(r2, rect) for r2 in rectangles if not rect == r2)

    rectangles = list(map(xywh_to_vec_rect, rectangles))
    rectangles = filter(partial(is_not_included, rectangles=rectangles), rectangles)
    rectangles = map(vec_rect_to_xywh, rectangles)
    return rectangles


# Corner-based rectangle: (xmin, ymin) top-left, (xmax, ymax) bottom-right.
Rectangle = namedtuple("Rectangle", "xmin ymin xmax ymax")


def make_box(x1, y1, x2, y2):
    """Return the corners as a dict with keys ``x1``, ``y1``, ``x2``, ``y2``."""
    keys = "x1", "y1", "x2", "y2"
    return dict(zip(keys, [x1, y1, x2, y2]))


def compute_intersection(a, b):
    """Return the intersection area of two ``Rectangle``s, or 0 if they are disjoint."""
    dx = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    dy = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)
    return dx * dy if (dx >= 0) and (dy >= 0) else 0


def has_no_parent(hierarchy):
    """Return True if the hierarchy entry's last element (parent index) is <= 0.

    Assumes ``hierarchy`` is one OpenCV contour-hierarchy entry
    ``[next, previous, first_child, parent]`` -- confirm against callers.
    NOTE(review): OpenCV marks "no parent" with a negative index, and 0 is a
    valid contour index, so ``<= 0`` also treats children of contour 0 as
    parentless. Confirm this is intended (e.g. contour 0 is a frame).
    """
    return hierarchy[-1] <= 0


def xywh_to_vec_rect(rect):
    """Convert an ``(x, y, w, h)`` tuple to a corner-based ``Rectangle``."""
    x1, y1, w, h = rect
    x2 = x1 + w
    y2 = y1 + h
    return Rectangle(x1, y1, x2, y2)


def vec_rect_to_xywh(rect):
    """Convert a corner-based rectangle ``(xmin, ymin, xmax, ymax)`` to ``(x, y, w, h)``."""
    x, y, x2, y2 = rect
    w = x2 - x
    h = y2 - y
    return x, y, w, h