Skip to content

Commit d973f69

Browse files
committed
feat: add __iter__ support to LayoutElements and Elements for native iteration
1 parent 18c73ca commit d973f69

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

test_unstructured_inference/inference/test_layout_element.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion
2-
1+
from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements, TextRegion
2+
import numpy as np
33

44
def test_layout_element_do_dict(mock_layout_element):
55
expected = {
@@ -18,3 +18,33 @@ def test_layout_element_from_region(mock_rectangle):
1818
region = TextRegion(bbox=mock_rectangle)
1919

2020
assert LayoutElement.from_region(region) == expected
21+
22+
23+
def test_layout_elements_iter_support():
    """Check that a ``LayoutElements`` object can be consumed by ``list()``.

    Builds a collection holding exactly one element and verifies that
    iterating it yields a single ``LayoutElement`` carrying the text and the
    class name that were supplied to the constructor.
    """
    layout_elements = LayoutElements(
        element_coords=np.array([[0, 0, 100, 100]]),
        texts=np.array(["sample"]),
        element_probs=np.array([0.9]),
        element_class_ids=np.array([0]),
        # class id 0 is mapped to the "Text" type asserted below
        element_class_id_map={0: "Text"},
        sources=np.array(["test_source"]),
        text_as_html=np.array(["<p>sample</p>"]),
        table_as_cells=np.array([None]),
    )

    # list() drives the object through its __iter__() method
    elements = list(layout_elements)

    assert len(elements) == 1
    only_element = elements[0]
    assert isinstance(only_element, LayoutElement)
    assert only_element.text == "sample"
    assert only_element.type == "Text"
50+

unstructured_inference/inference/elements.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def __post_init__(self):
228228

229229
def __getitem__(self, indices) -> TextRegions:
230230
return self.slice(indices)
231+
232+
def __iter__(self):
    """Allow the collection to be used directly in ``for`` loops and ``list()``.

    Returns whatever ``self.iter_elements()`` returns, so iterating the
    object is the same as iterating ``iter_elements()``.
    """
    element_iterator = self.iter_elements()
    return element_iterator
231234

232235
def slice(self, indices) -> TextRegions:
233236
"""slice text regions based on indices"""

unstructured_inference/inference/layoutelement.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def __eq__(self, other: object) -> bool:
7777

7878
def __getitem__(self, indices):
7979
return self.slice(indices)
80+
81+
def __iter__(self):
    """Make the object natively iterable (``for el in obj`` / ``list(obj)``).

    Delegates to ``self.iter_elements()`` and hands its result back
    unchanged.
    """
    element_iterator = self.iter_elements()
    return element_iterator
8083

8184
def slice(self, indices) -> LayoutElements:
8285
"""slice and return only selected indices"""

0 commit comments

Comments
 (0)