# cv-analysis-service/src/cv_analysis/table_inference.py
#
# 162 lines
# 4.6 KiB
# Python

from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple
from typing import Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
from numpy import ndarray as Array
from scipy.signal import argrelextrema
from scipy.stats import norm
import fitz
from pdf2img.conversion import convert_pages_to_images
def show_multiple(arrs: Iterable[Array], title: str = "") -> None:
    """Plot several arrays as line plots on one figure and display it.

    Args:
        arrs: The arrays to draw; each one is passed to ``plt.plot`` in
            iteration order, so they share the same axes.
        title: Title placed on the plot.
    """
    # Reset pyplot's implicit "current figure" state so that lines from an
    # earlier call do not leak into this plot.
    plt.clf()
    plt.cla()
    plt.close()
    for a in arrs:
        plt.plot(a)
    plt.title(title)
    plt.show()
def show(arr: Array, title: str = ""):
    """Draw ``arr`` as a single line plot with ``title`` and display it."""
    # Start from a clean pyplot state: clear figure, clear axes, close.
    for reset in (plt.clf, plt.cla, plt.close):
        reset()

    plt.plot(arr)
    plt.title(title)
    plt.show()
def save_plot(arr: Array, name: str, title: str = "") -> None:
    """Draw ``arr`` as a line plot and write it to ``<name>.png``.

    Args:
        arr: Values to plot.
        name: Output path without extension; it is converted with ``str()``
            before ".png" is appended.
        title: Title placed on the plot.
    """
    # Start from a clean pyplot state: clear figure, clear axes, close.
    for reset in (plt.clf, plt.cla, plt.close):
        reset()

    plt.plot(arr)
    plt.title(title)

    target = Path(str(name) + ".png")
    plt.savefig(target)
def make_gaussian_kernel(kernel_size: int, sd: float) -> Array:
    """Build a normalized, odd-length Gaussian smoothing kernel.

    Args:
        kernel_size: Requested number of taps; an even value is bumped up by
            one so the kernel has a well-defined center.
        sd: Standard deviation of the Gaussian, in taps.

    Returns:
        A 1-D array of zero-mean normal densities sampled at integer offsets
        around the center, scaled so that the entries sum to 1.
    """
    if kernel_size % 2 == 0:
        kernel_size += 1
    half_width = (kernel_size - 1) // 2

    offsets = np.arange(-half_width, half_width + 1)
    weights = norm.pdf(offsets, scale=sd)
    return weights / np.sum(weights)
def make_gaussian_nonpositive_kernel(kernel_size: int, sd: float) -> Array:
    """Build a normalized, odd-length Gaussian kernel.

    NOTE(review): despite the "nonpositive" in the name, this computes exactly
    what ``make_gaussian_kernel`` computes (all entries are positive and sum
    to 1). Confirm whether a different kernel shape was intended.

    Args:
        kernel_size: Requested number of taps; an even value is bumped up by
            one so the kernel is symmetric around a center tap.
        sd: Standard deviation of the Gaussian, in taps.

    Returns:
        A 1-D array of normal densities at integer offsets from the center,
        divided by their sum.
    """
    is_even = not kernel_size % 2
    if is_even:
        kernel_size = kernel_size + 1
    reach = (kernel_size - 1) // 2

    positions = np.arange(-reach, reach + 1)
    density = norm.pdf(positions, scale=sd)
    total = np.sum(density)
    return density / total
def make_quadratic_kernel(kernel_size: int, ratio: float) -> Array:
    """Build a normalized, odd-length kernel shaped like an inverted parabola.

    The raw profile is ``-x**2`` at integer offsets ``x`` from the center. It
    is lifted by ``(max - min) / (1 - ratio)`` and then divided by its sum.

    Args:
        kernel_size: Requested number of taps; an even value is bumped up by
            one so the kernel has a center tap.
        ratio: Controls the size of the constant lift applied before
            normalization (larger values, approaching 1, flatten the kernel).

    Returns:
        A 1-D float array whose entries sum to 1.
    """
    if kernel_size % 2 == 0:
        kernel_size += 1
    half_width = (kernel_size - 1) // 2

    offsets = np.arange(-half_width, half_width + 1)
    parabola = (-(offsets ** 2)).astype(float)

    spread = parabola.max() - parabola.min()
    lifted = parabola + spread / (1 - ratio)
    return lifted / np.sum(lifted)
def min_avg_for_interval(filtered: Array, interval: int) -> Tuple[float, int]:
    """Find the phase offset whose strided samples have the lowest mean.

    For every offset ``start`` in ``range(interval)`` the values
    ``filtered[start], filtered[start + interval], ...`` are averaged.

    Args:
        filtered: 1-D signal to sample (an array or any sliceable sequence).
        interval: Stride between consecutive samples; must be at least 1.

    Returns:
        ``(lowest_mean, offset)`` where ``offset`` is the first start index
        that achieves ``lowest_mean``.

    Raises:
        ValueError: If ``interval`` is smaller than 1.
    """
    if interval < 1:
        raise ValueError(f"interval must be >= 1, got {interval}")
    # Strided slicing selects the same elements as indexing with
    # range(start, len(filtered), interval), without building an index list.
    avgs = [np.mean(filtered[start::interval]) for start in range(interval)]
    best = min(avgs)
    return best, avgs.index(best)
def search_intervals(filtered: Array, min_interval: int, max_interval: int):
    """Pick the sampling interval whose best phase has the lowest mean.

    Every interval in ``[min_interval, max_interval]`` (inclusive) is scored
    with ``min_avg_for_interval``; the interval with the smallest score wins,
    with ties going to the smallest interval.

    Returns:
        ``(interval, offset)`` for the winning interval and its phase offset.
    """
    candidates = []
    for step in range(min_interval, max_interval + 1):
        lowest_avg, offset = min_avg_for_interval(filtered, step)
        candidates.append((lowest_avg, step, offset))

    def by_average(candidate):
        return candidate[0]

    _, best_step, best_offset = min(candidates, key=by_average)
    return best_step, best_offset
def filter_array(
    array: Array,
    sum_filter: Optional[Array],
    padding: Optional[Array] = None,
    pad_value_function: Callable[[Array], float] = np.mean,
) -> Array:
    """Convolve a 1-D signal with an odd-length kernel, keeping its length.

    Args:
        array: The 1-D signal to smooth.
        sum_filter: Convolution kernel with an odd number of taps. ``None``
            (or an empty kernel) disables filtering and ``array`` is returned
            unchanged.
        padding: Values placed on both ends of ``array`` before convolving.
            When omitted, ``(len(sum_filter) - 1) // 2`` copies of
            ``pad_value_function(array)`` are used, which makes the output
            the same length as the input.
        pad_value_function: Computes the fill value for the default padding.

    Returns:
        The "valid" part of the convolution of the padded signal with
        ``sum_filter``.
    """
    # The truth value of a multi-element ndarray is ambiguous (``not kernel``
    # raises ValueError), so test for "no filter" explicitly.
    if sum_filter is None or len(sum_filter) == 0:
        return array
    fsize = len(sum_filter)
    assert fsize % 2, "sum_filter must have an odd number of taps"
    if padding is None:  # ensures that output size matches the input size
        pad = (fsize - 1) // 2
        padding = np.full(pad, pad_value_function(array))
    padded = np.concatenate((padding, array, padding))
    return np.convolve(padded, sum_filter, "valid")
# Smoothing kernels used by get_lines_either, keyed by direction ("row" or
# "col") and then by filter stage (1, 2). A stage set to None means that no
# kernel is applied for that stage.
FILTERS = {
    "row": {
        1: make_gaussian_kernel(kernel_size=30, sd=6),
        2: make_gaussian_kernel(kernel_size=20, sd=4),
    },
    "col": {
        1: make_gaussian_kernel(kernel_size=70, sd=10),
        2: None,
    },
}
def get_lines_either(
    table_array: Array, horizontal=True, threshold: float = 0.3
) -> Array:
    """Locate candidate table lines along one direction of an image array.

    The array is averaged into a 1-D profile, the profile is passed through
    the two filter stages configured in ``FILTERS``, and the indices of the
    strict local maxima of the result are returned.

    Args:
        table_array: Image array; for a 2-D input, ``horizontal=True``
            averages over axis 1 (one value per row) and ``horizontal=False``
            averages over axis 0 (one value per column).
        horizontal: Selects the averaging axis and the "row" or "col" entry
            of ``FILTERS``.
        threshold: Profile entries below this value are raised to 1 before
            filtering.

    Returns:
        Indices of the local maxima of the filtered profile.
    """
    key = "row" if horizontal else "col"
    sums = np.mean(table_array, axis=int(horizontal))
    # The boolean mask is promoted to 0/1 by np.maximum: entries below the
    # threshold become 1, entries at or above it are left unchanged.
    sums = np.maximum(sums, (sums < threshold))
    # Apply the stages in sequence; stage 2 must consume the output of
    # stage 1, otherwise the first smoothing pass is thrown away.
    filtered_sums = filter_array(sums, FILTERS[key][1])
    filtered_sums = filter_array(filtered_sums, FILTERS[key][2])
    lines = argrelextrema(filtered_sums, np.greater)[0]
    return lines
def img_bytes_to_array(img_bytes: bytes) -> Array:
    """Decode an encoded image held in memory into a grayscale array.

    NOTE(review): the result of ``cv2.imdecode`` is returned as-is; check how
    undecodable input is reported and whether callers need a guard.
    """
    raw = np.frombuffer(img_bytes, dtype=np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE)
def infer_lines(img: Array) -> "dict[str, list[dict[str, int]] | dict[str, int]]":
    """Infer horizontal and vertical table lines for a 2-D image array.

    Args:
        img: Image array with exactly two dimensions (height, width).

    Returns:
        A dict with two entries:

        * ``"tableLines"``: a list of ``{"x1", "y1", "x2", "y2"}`` dicts,
          first the horizontal lines (spanning ``x = 0 .. width``), then the
          vertical lines (spanning ``y = 0 .. height``).
        * ``"imageInfo"``: ``{"height": ..., "width": ...}`` of ``img``.

        All coordinates are built-in ``int`` values.
    """
    h, w = img.shape
    # The detected indices are NumPy integer scalars; convert them so the
    # result consists of plain Python types only (NumPy scalars are rejected
    # by the standard json encoder, for instance).
    row_vals = [int(r) for r in get_lines_either(img, horizontal=True)]
    col_vals = [int(c) for c in get_lines_either(img, horizontal=False)]
    horizontal_lines = [{"x1": 0, "y1": r, "x2": w, "y2": r} for r in row_vals]
    vertical_lines = [{"x1": c, "y1": 0, "x2": c, "y2": h} for c in col_vals]
    return {
        "tableLines": horizontal_lines + vertical_lines,
        "imageInfo": {"height": h, "width": w},
    }