196 lines
6.0 KiB
Python
196 lines
6.0 KiB
Python
from operator import itemgetter
|
|
from pathlib import Path
|
|
from typing import Callable, Optional, Tuple
|
|
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from kn_utils.logging import logger # type: ignore
|
|
from numpy import ndarray as Array
|
|
from scipy.stats import norm # type: ignore
|
|
|
|
|
|
def show_multiple(arrs: Tuple[Array], title: str = ""):
    """Plot several arrays on one fresh pyplot figure and display it.

    Args:
        arrs: Arrays to draw; each one is handed to ``plt.plot`` in order.
        title: Text set as the plot title.
    """
    # Clear and close whatever figure pyplot currently holds before drawing.
    for reset in (plt.clf, plt.cla, plt.close):
        reset()
    for series in arrs:
        plt.plot(series)
    plt.title(title)
    plt.show()
|
|
|
|
|
|
def show(arr: Array, title: str = ""):
    """Plot a single array on a fresh pyplot figure and display it.

    Args:
        arr: Array passed to ``plt.plot``.
        title: Text set as the plot title.
    """
    # Clear and close the current pyplot figure so earlier plots do not leak in.
    for reset in (plt.clf, plt.cla, plt.close):
        reset()
    plt.plot(arr)
    plt.title(title)
    plt.show()
|
|
|
|
|
|
def save_plot(arr: Array, name: str, title: str = "") -> None:
    """Plot a single array on a fresh pyplot figure and save it as a PNG.

    Args:
        arr: Array passed to ``plt.plot``.
        name: Output path without extension; ``".png"`` is appended.
        title: Text set as the plot title.
    """
    # Clear and close the current pyplot figure so earlier plots do not leak in.
    for reset in (plt.clf, plt.cla, plt.close):
        reset()
    plt.plot(arr)
    plt.title(title)
    destination = Path(str(name) + ".png")
    plt.savefig(destination)
|
|
|
|
|
|
def save_lines(img: Array, lines: list[dict[str, int]]) -> None:
    """Draw line segments onto ``img`` and write the result to ``/tmp/lines.png``.

    The image is first converted with ``cv2.COLOR_GRAY2RGB``; every segment is
    then drawn in green with a thickness of 3.

    Args:
        img: Image to draw on (converted from gray to three channels first).
        lines: Segments, each a dict with the keys ``x1``, ``y1``, ``x2``, ``y2``.
    """
    canvas = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    for segment in lines:
        start = (segment["x1"], segment["y1"])
        end = (segment["x2"], segment["y2"])
        canvas = cv2.line(canvas, start, end, color=(0, 255, 0), thickness=3)
    cv2.imwrite("/tmp/lines.png", canvas)
|
|
|
|
|
|
def make_gaussian_kernel(kernel_size: int, sd: float) -> Array:
    """Return a 1-D Gaussian kernel whose weights sum to one.

    Args:
        kernel_size: Requested number of taps; an even value is bumped to the
            next odd number so the kernel has a centre tap.
        sd: Standard deviation of the Gaussian, in taps.

    Returns:
        The normal pdf sampled at integer offsets around zero, normalised to
        unit sum.
    """
    if kernel_size % 2 == 0:
        kernel_size += 1
    half_width = (kernel_size - 1) // 2
    offsets = np.arange(-half_width, half_width + 1)
    weights = norm.pdf(offsets, scale=sd)
    return weights / weights.sum()
|
|
|
|
|
|
def make_gaussian_nonpositive_kernel(kernel_size: int, sd: float) -> Array:
    """Return a unit-sum Gaussian kernel sampled at integer offsets.

    NOTE(review): despite the name, this computes exactly what
    ``make_gaussian_kernel`` computes and every weight is non-negative.
    Confirm whether a different kernel was intended.

    Args:
        kernel_size: Requested number of taps; even values are increased by one.
        sd: Standard deviation of the Gaussian, in taps.

    Returns:
        The normalised kernel.
    """
    odd_size = kernel_size if kernel_size % 2 else kernel_size + 1
    reach = (odd_size - 1) // 2
    positions = np.arange(-reach, reach + 1)
    density = norm.pdf(positions, scale=sd)
    density /= density.sum()
    return density
|
|
|
|
|
|
def make_quadratic_kernel(kernel_size: int, ratio: float) -> Array:
    """Return a unit-sum parabolic kernel with a given edge-to-centre ratio.

    Before normalisation the kernel is ``w**2 - (1 - ratio) * x**2`` for the
    integer offsets ``x`` in ``[-w, w]``, where ``w`` is the half width. The
    outermost tap is therefore ``ratio`` times the centre tap. This equals the
    textbook form ``-x**2 + w**2 / (1 - ratio)`` scaled by ``(1 - ratio)``, a
    factor that cancels in the normalisation, but it stays finite at
    ``ratio == 1`` where it gives the flat (box) kernel.

    Args:
        kernel_size: Requested number of taps; even values are increased by one.
        ratio: Weight of the outermost taps relative to the centre tap.

    Returns:
        The normalised kernel as a float array.

    Raises:
        ValueError: If ``kernel_size`` is negative.
    """
    kernel_size += int(not kernel_size % 2)
    wing_size = (kernel_size - 1) // 2
    if wing_size < 0:
        raise ValueError(f"kernel_size must be non-negative, got size {kernel_size}")
    if wing_size == 0:
        # A single tap has no shape to scale; the parabola would be 0 / 0 here.
        return np.ones(1)

    offsets = np.arange(-wing_size, wing_size + 1, dtype=float)
    kernel = wing_size**2 - (1.0 - ratio) * offsets**2
    kernel /= np.sum(kernel)
    return kernel
|
|
|
|
|
|
def min_avg_for_interval(filtered: Array, interval: int) -> tuple[float, int]:
    """Find the starting offset whose every-``interval``-th samples average lowest.

    Args:
        filtered: 1-D signal to sample.
        interval: Step between sampled positions; offsets ``0 .. interval - 1``
            are tried.

    Returns:
        ``(lowest_mean, offset)`` where ``offset`` is the first starting
        position that reaches ``lowest_mean``.
    """
    phase_means: list[float] = [
        np.mean(filtered[offset::interval]) for offset in range(interval)
    ]
    lowest: float = min(phase_means)
    return lowest, phase_means.index(lowest)
|
|
|
|
|
|
def search_intervals(filtered: Array, min_interval: int, max_interval: int):
    """Pick the interval (and its offset) with the lowest sampled average.

    Every interval in ``min_interval .. max_interval`` (inclusive) is scored
    with ``min_avg_for_interval``; on ties the smallest interval wins.

    Args:
        filtered: 1-D signal to sample.
        min_interval: Smallest interval to try.
        max_interval: Largest interval to try (inclusive).

    Returns:
        ``(interval, offset)`` of the best-scoring candidate.
    """
    candidates = (
        (step, *min_avg_for_interval(filtered, step))
        for step in range(min_interval, max_interval + 1)
    )
    winning_step, _, winning_offset = min(candidates, key=itemgetter(1))
    return winning_step, winning_offset
|
|
|
|
|
|
def filter_array(
|
|
array: Array,
|
|
sum_filter: Array | None,
|
|
padding: Optional[Array] = None,
|
|
pad_value_function: Callable[[Array], float] = lambda x: 255.0, # np.mean,
|
|
) -> Array:
|
|
if sum_filter is None:
|
|
return array
|
|
fsize = len(sum_filter)
|
|
assert fsize % 2
|
|
if padding is None: # ensures that output size matches the input size
|
|
pad = int((fsize - 1) / 2)
|
|
padding = np.full(pad, pad_value_function(array))
|
|
|
|
return np.convolve(np.concatenate((padding, array, padding)), sum_filter, "valid")
|
|
|
|
|
|
# Width (number of taps) and standard deviation for each Gaussian smoothing
# pass. make_gaussian_kernel increases even widths by one, so a width of 30
# yields a 31-tap kernel.
ROW_FILTER1_WIDTH = 30
ROW_FILTER1_SD = 6
ROW_FILTER2_WIDTH = 20
ROW_FILTER2_SD = 4
COL_FILTER1_WIDTH = 90
COL_FILTER1_SD = 15
COL_FILTER2_WIDTH = 70
COL_FILTER2_SD = 12
COL_FILTER3_WIDTH = 200
COL_FILTER3_SD = 20
# Smoothing kernels keyed first by axis ("row" / "col") and then by pass
# number (1-3). A None entry means the pass is skipped: filter_array returns
# its input unchanged when given None as the kernel.
FILTERS: dict[str, dict[int, np.ndarray | None]] = {
    "row": {
        1: make_gaussian_kernel(ROW_FILTER1_WIDTH, ROW_FILTER1_SD),
        2: make_gaussian_kernel(ROW_FILTER2_WIDTH, ROW_FILTER2_SD),
        3: None,
    },
    "col": {
        1: make_gaussian_kernel(COL_FILTER1_WIDTH, COL_FILTER1_SD),
        2: make_gaussian_kernel(COL_FILTER2_WIDTH, COL_FILTER2_SD),
        3: make_gaussian_kernel(COL_FILTER3_WIDTH, COL_FILTER3_SD),
    },
}
|
|
|
|
|
|
def filter_fp_col_lines(line_list: list[int], filt_sums: Array) -> list[int]:
    """Keep only the candidate lines that clearly stand out from a nearby minimum.

    Candidates are paired positionally with the strict local minima of
    ``filt_sums``. If the first candidate lies after the first minimum, that
    minimum is dropped and the last index of ``filt_sums`` is appended, so each
    candidate is compared with a minimum to its right. A candidate survives
    when its value exceeds the paired minimum by more than the standard
    deviation of ``filt_sums``. Candidates left without a partner (``zip``
    stops at the shorter sequence) are discarded.

    Args:
        line_list: Candidate indices into ``filt_sums``.
        filt_sums: 1-D profile the candidates were found in.

    Returns:
        The surviving candidate indices, in their original order.
    """
    if not list(line_list):
        return []

    inner = filt_sums[1:-1]
    is_local_min = (inner < filt_sums[:-2]) & (inner < filt_sums[2:])
    valleys = list(np.nonzero(is_local_min)[0] + 1)
    if not valleys:
        return []

    if line_list[0] > valleys[0]:
        valleys = valleys[1:] + [len(filt_sums) - 1]

    min_drop = np.std(filt_sums)
    kept = []
    for peak, valley in zip(line_list, valleys):
        if (filt_sums[peak] - filt_sums[valley]) > min_drop:
            kept.append(peak)
    return kept
|
|
|
|
|
|
def get_lines_either(table_array: Array, horizontal=True) -> list[int]:
    """Locate table lines along one axis of an image array.

    The image is averaged along one axis to get an intensity profile. Every
    profile position that is darker than 30% of 255, or that has such a
    position as a direct neighbour, is raised to 1000. The profile is then
    smoothed with the ``FILTERS`` kernels for the chosen axis, and its strict
    local maxima are reported. Column candidates are additionally pruned with
    ``filter_fp_col_lines``.

    Args:
        table_array: Image array; assumed 2-D (rows x columns) as passed by
            ``infer_lines`` - TODO confirm for other callers.
        horizontal: ``True`` to find row positions (mean over ``axis=1``),
            ``False`` to find column positions (mean over ``axis=0``).

    Returns:
        Line positions as indices along the chosen axis of ``table_array``.
    """
    key = "row" if horizontal else "col"

    sums = np.mean(table_array, axis=int(horizontal))
    threshold = 0.3 * 255
    predicate = 1000.0 * (sums < threshold)
    # Dilate the dark mask by one sample on each side. The result is two
    # samples shorter than the profile: entry j describes original position j + 1.
    sums = np.maximum(
        np.maximum(sums[1:-1], predicate[1:-1]),
        np.maximum(predicate[:-2], predicate[2:]),
    )

    filtered_sums = sums
    for stage in (1, 2, 3):
        filtered_sums = filter_array(filtered_sums, FILTERS[key][stage])

    # Strict local maxima; the "+ 1" undoes the [1:-1] slice used for comparison.
    lines = list(
        np.where((filtered_sums[1:-1] > filtered_sums[:-2]) * (filtered_sums[1:-1] > filtered_sums[2:]))[0] + 1
    )
    if not horizontal:
        # Must run on indices into filtered_sums, i.e. before the shift below.
        lines = filter_fp_col_lines(lines, filtered_sums)

    # Map indices of the trimmed profile back to positions in table_array;
    # without this every reported line sits one pixel too low in index.
    return [idx + 1 for idx in lines]
|
|
|
|
|
|
def img_bytes_to_array(img_bytes: bytes) -> Array:
    """Decode an encoded image (e.g. PNG or JPEG bytes) as a grayscale array.

    Args:
        img_bytes: Raw bytes of an encoded image file.

    Returns:
        The image decoded with ``cv2.IMREAD_GRAYSCALE``.

    Raises:
        ValueError: If OpenCV cannot decode the data. ``cv2.imdecode`` reports
            that by returning ``None``, which would otherwise surface later as
            an unrelated error in the caller.
    """
    buffer = np.frombuffer(img_bytes, np.uint8)
    img_np = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE)
    if img_np is None:
        raise ValueError("img_bytes could not be decoded as an image")
    return img_np
|
|
|
|
|
|
def infer_lines(img: Array) -> dict[str, dict[str, int] | list[dict[str, int]]]:
    """Detect row and column lines in a table image.

    The image is binarised (threshold 220, max value 255), then row and column
    positions are obtained from ``get_lines_either`` and turned into segments
    spanning the full width or height.

    Side effects: writes ``/tmp/table.png`` (input) and ``/tmp/table_bin.png``
    (binarised image), and calls ``save_lines`` with the detected segments.

    Args:
        img: Image array; its ``shape`` must unpack into exactly
            ``(height, width)``.

    Returns:
        A dict with ``"tableLines"`` (row segments first, then column segments,
        each with keys ``x1``, ``y1``, ``x2``, ``y2``) and ``"imageInfo"``
        (``height`` and ``width`` of the image).
    """
    cv2.imwrite("/tmp/table.png", img)
    binary = cv2.threshold(img, 220, 255, cv2.THRESH_BINARY)[1]
    cv2.imwrite("/tmp/table_bin.png", binary)
    height, width = (int(dim) for dim in binary.shape)

    row_positions = get_lines_either(binary, horizontal=True)
    col_positions = get_lines_either(binary, horizontal=False)

    horizontal_lines = [
        {"x1": 0, "y1": int(row), "x2": width, "y2": int(row)} for row in row_positions
    ]
    vertical_lines = [
        {"x1": int(col), "y1": 0, "x2": int(col), "y2": height} for col in col_positions
    ]
    lines = horizontal_lines + vertical_lines

    save_lines(binary, lines)

    return {"tableLines": lines, "imageInfo": {"height": height, "width": width}}
|