Refactoring

This commit is contained in:
Matthias Bisping 2023-01-04 13:32:57 +01:00
parent 77f565c652
commit ac84494613
2 changed files with 18 additions and 4 deletions

View File

@ -4,7 +4,9 @@ from __future__ import annotations
from operator import attrgetter
from typing import TYPE_CHECKING, Iterable
from funcy import juxt, rpartial
from funcy import juxt, rpartial, compose, lflatten
from cv_analysis.utils import lift
if TYPE_CHECKING:
from cv_analysis.utils.structures import Rectangle
@ -130,8 +132,11 @@ def intersection_along_axis(alpha, beta, axis):
def get_component_accessor(component):
return attrgetter(f"{axis}{component}")
c1 = get_component_accessor(1)
c2 = get_component_accessor(2)
def make_access_components_and_sort_fn(component):
return compose(sorted, lift(get_component_accessor(component)))
c1 = make_access_components_and_sort_fn(1)
c2 = make_access_components_and_sort_fn(2)
# Cases:
# a b
@ -147,6 +152,6 @@ def intersection_along_axis(alpha, beta, axis):
# b a
# (----[--)----] => [b1, a1, b2, a2] => max(0, (b2 - a1)) = (b2 - a1)
coords = [*sorted([c1(alpha), c1(beta)]), *sorted([c2(alpha), c2(beta)])]
coords = lflatten(juxt(c1, c2)((alpha, beta)))
intersection = max(0, coords[2] - coords[1])
return intersection

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from numpy import generic
import cv2
@ -17,3 +19,10 @@ def npconvert(ob):
if isinstance(ob, generic):
return ob.item()
raise TypeError
def lift(fn):
def lifted(coll):
yield from map(fn, coll)
return lifted