Skip to content

Commit c6ab388

Browse files
authored
✨ add basic support for line items (#150)
1 parent 93ec83a commit c6ab388

File tree

4 files changed

+172
-8
lines changed

4 files changed

+172
-8
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import Dict, List, Sequence
2+
3+
from mindee.documents.custom.custom_v1_fields import ListField, ListFieldValue
4+
from mindee.geometry import (
5+
Quadrilateral,
6+
get_bounding_box,
7+
get_min_max_y,
8+
is_point_in_y,
9+
merge_polygons,
10+
)
11+
12+
13+
def _array_product(array: Sequence[float]) -> float:
    """
    Multiply together every value of a sequence of floats.

    An empty sequence yields ``1.0``, the multiplicative identity.

    :array: Sequence of floats to multiply
    """
    result = 1.0
    for factor in array:
        result *= factor
    return result
23+
24+
25+
def _find_best_anchor(anchors: Sequence[str], fields: Dict[str, ListField]) -> str:
    """
    Pick, among the candidate `anchors`, the field holding the most values.

    Candidates are examined in the order given; on a tie the earliest one wins.
    The result is the name of the field, or an empty string when no candidate
    has any value.
    """
    best_name = ""
    best_count = 0
    for name in anchors:
        count = len(fields[name].values)
        # Strict comparison: a later candidate must have more rows to take over.
        if count > best_count:
            best_name, best_count = name, count
    return best_name
39+
40+
41+
def _get_empty_field() -> ListFieldValue:
    """Build a placeholder field value: no content, no polygon, zero confidence."""
    blank_prediction = {"content": "", "polygon": [], "confidence": 0.0}
    return ListFieldValue(blank_prediction)
44+
45+
46+
class Line:
    """
    One reconstructed row of a table of line items.

    The attributes are only declared here; they are assigned by the code that
    builds the row.
    """

    # Position of the row among the reconstructed rows.
    row_number: int
    # Value of each column on this row, keyed by field name.
    fields: Dict[str, ListFieldValue]
    # Box enclosing the fields of the row.
    bounding_box: Quadrilateral
52+
53+
54+
def get_line_items(
    anchors: Sequence[str], columns: Sequence[str], fields: Dict[str, ListField]
) -> List[Line]:
    """
    Reconstruct line items from fields.

    The anchor is the candidate field with the most values; each of its values
    starts a row. A row spans vertically from the minimum y of its anchor value
    to the minimum y of the next anchor value (or to 1.0, the bottom of the page,
    for the last row). Every column value whose centroid falls inside that span
    is attached to the row.

    :anchors: Possible fields to use as an anchor
    :columns: All fields which are columns
    :fields: Mapping of field name to the field holding its values
    :return: One ``Line`` per anchor value; empty when no anchor could be found
    """
    line_items: List[Line] = []
    anchor = _find_best_anchor(anchors, fields)
    if not anchor:
        # NOTE(review): printing from library code is unusual; kept as-is so the
        # observable output does not change. Consider `warnings` or `logging`.
        print(Warning("Could not find an anchor!"))
        return line_items

    # Create one row per anchor value, every column starting as an empty placeholder.
    # NOTE(review): the anchor cell is the very object found in `fields`, so when
    # the anchor is also listed in `columns` it is updated in place below.
    for item in fields[anchor].values:
        line_item = Line()
        line_item.fields = {f: _get_empty_field() for f in columns}
        line_item.fields[anchor] = item
        line_items.append(line_item)

    last_idx = len(line_items) - 1
    for idx, line in enumerate(line_items):
        # Compute the sliding window between this anchor item and the next one.
        min_y, _ = get_min_max_y(line.fields[anchor].polygon)
        if idx != last_idx:
            max_y, _ = get_min_max_y(line_items[idx + 1].fields[anchor].polygon)
        else:
            max_y = 1.0  # bottom of page

        # Attach to the row the values of each column found inside the window.
        for field in columns:
            field_words = [
                word
                for word in fields[field].values
                if is_point_in_y(word.polygon.centroid, min_y, max_y)
            ]
            if not field_words:
                # Nothing matched: keep the cell untouched. Overwriting it would
                # give an empty cell a confidence of 1.0 (product of no values).
                continue
            cell = line.fields[field]
            cell.content = " ".join(v.content for v in field_words)
            try:
                cell.polygon = merge_polygons([v.polygon for v in field_words])
            except ValueError:
                # Best effort: the polygons could not be merged, so the cell
                # keeps the polygon it already had.
                pass
            cell.confidence = _array_product([v.confidence for v in field_words])

        # The row's box encloses the anchor and every column cell of the row.
        all_polygons = [line.fields[anchor].polygon]
        all_polygons.extend(line.fields[field].polygon for field in columns)
        line.bounding_box = get_bounding_box(merge_polygons(all_polygons))
        line.row_number = idx
    return line_items

mindee/geometry.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ class Quadrilateral(NamedTuple):
2828
bottom_left: Point
2929
"""Bottom left Point"""
3030

31+
@property
def centroid(self) -> Point:
    """Central point of this quadrilateral, as computed by ``get_centroid``."""
    center = get_centroid(self)
    return center
35+
3136

3237
class BBox(NamedTuple):
3338
"""Contains exactly 4 coordinates."""
@@ -73,6 +78,11 @@ class Polygon(list):
7378
Inherits from base class ``list`` so is compatible with type ``Points``.
7479
"""
7580

81+
@property
def centroid(self) -> Point:
    """Central point of this polygon, as computed by ``get_centroid``."""
    center = get_centroid(self)
    return center
85+
7686

7787
Points = Sequence[Point]
7888

@@ -132,9 +142,9 @@ def get_bbox(points: Points) -> BBox:
132142
return BBox(x_min, y_min, x_max, y_max)
133143

134144

135-
def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral:
145+
def merge_polygons(vertices: Sequence[Polygon]) -> Polygon:
136146
"""
137-
Given a sequence of polygons, calculate a bounding box that encompasses all polygons.
147+
Given a sequence of polygons, calculate a polygon box that encompasses all polygons.
138148
139149
:param vertices: List of polygons
140150
:return: A bounding box that encompasses all polygons
@@ -143,11 +153,13 @@ def get_bounding_box_for_polygons(vertices: Sequence[Polygon]) -> Quadrilateral:
143153
y_max = max(y for v in vertices for _, y in v)
144154
x_min = min(x for v in vertices for x, _ in v)
145155
x_max = max(x for v in vertices for x, _ in v)
146-
return Quadrilateral(
147-
Point(x_min, y_min),
148-
Point(x_max, y_min),
149-
Point(x_max, y_max),
150-
Point(x_min, y_max),
156+
return Polygon(
157+
[
158+
Point(x_min, y_min),
159+
Point(x_max, y_min),
160+
Point(x_max, y_max),
161+
Point(x_min, y_max),
162+
]
151163
)
152164

153165

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import json
2+
3+
from mindee.documents import CustomV1
4+
from mindee.documents.custom.line_items import get_line_items
5+
from tests import CUSTOM_DATA_DIR
6+
7+
8+
def test_single_table_01():
    """Line items of the single-table sample are rebuilt with the expected rows."""
    json_data_path = f"{CUSTOM_DATA_DIR}/response_v1/line_items/single_table_01.json"
    # Context manager so the fixture file is closed as soon as it has been parsed.
    with open(json_data_path, "r", encoding="utf-8") as json_file:
        json_data = json.load(json_file)
    doc = CustomV1(
        "field_test", api_prediction=json_data["document"]["inference"], page_n=None
    )
    anchors = ["beneficiary_birth_date"]
    columns = [
        "beneficiary_name",
        "beneficiary_birth_date",
        "beneficiary_rank",
        "beneficiary_number",
    ]
    line_items = get_line_items(anchors, columns, doc.fields)
    assert len(line_items) == 3
    assert line_items[0].fields["beneficiary_name"].content == "JAMES BOND 007"
    assert line_items[0].fields["beneficiary_birth_date"].content == "1970-11-11"
    assert line_items[0].row_number == 0
    assert line_items[1].fields["beneficiary_name"].content == "HARRY POTTER"
    assert line_items[1].fields["beneficiary_birth_date"].content == "2010-07-18"
    assert line_items[1].row_number == 1
    assert line_items[2].fields["beneficiary_name"].content == "DRAGO MALFOY"
    assert line_items[2].fields["beneficiary_birth_date"].content == "2015-07-05"
    assert line_items[2].row_number == 2

tests/test_geometry.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,19 @@ def test_get_centroid(rectangle_a):
8787

8888

8989
def test_bounding_box_several_polygons(rectangle_b, quadrangle_a):
90-
assert geometry.get_bounding_box_for_polygons((rectangle_b, quadrangle_a)) == (
90+
merged = geometry.merge_polygons((rectangle_b, quadrangle_a))
91+
assert geometry.get_bounding_box(merged) == (
9192
(0.124, 0.407),
9293
(0.381, 0.407),
9394
(0.381, 0.546),
9495
(0.124, 0.546),
9596
)
97+
98+
99+
def test_polygon_merge(rectangle_b, quadrangle_a):
    """Merging two polygons yields the corner points enclosing both of them."""
    expected_corners = [
        (0.124, 0.407),
        (0.381, 0.407),
        (0.381, 0.546),
        (0.124, 0.546),
    ]
    merged = geometry.merge_polygons((rectangle_b, quadrangle_a))
    assert merged == expected_corners

0 commit comments

Comments
 (0)