Skip to content

Commit 2b54450

Browse files
authored Apr 3, 2022
Better layout comp (#128)
* Fix layout comp
* more robust layout objects
* improving loading dict
1 parent 7d5f97a commit 2b54450

File tree

5 files changed

+70
-25
lines changed

5 files changed

+70
-25
lines changed
 

‎src/layoutparser/elements/layout.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import List, Union, Dict, Dict, Any, Optional
16-
from collections.abc import MutableSequence
16+
from collections.abc import MutableSequence, Iterable
1717
from copy import copy
1818

1919
import pandas as pd
@@ -47,6 +47,21 @@ class Layout(MutableSequence):
4747
"""
4848

4949
def __init__(self, blocks: Optional[List] = None, *, page_data: Dict = None):
50+
51+
if not (
52+
(blocks is None)
53+
or (isinstance(blocks, Iterable) and blocks.__class__.__name__ != "Layout")
54+
):
55+
56+
if blocks.__class__.__name__ == "Layout":
57+
error_msg = f"Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)"
58+
else:
59+
error_msg = f"Blocks should be a list of layout elements or empty (None), instead got {blocks}.\n"
60+
raise ValueError(error_msg)
61+
62+
if isinstance(blocks, tuple):
63+
blocks = list(blocks) # <- more robust handling for tuple-like inputs
64+
5065
self._blocks = blocks if blocks is not None else []
5166
self.page_data = page_data or {}
5267

@@ -76,10 +91,7 @@ def __repr__(self):
7691

7792
def __eq__(self, other):
7893
if isinstance(other, Layout):
79-
return (
80-
all((a, b) for a, b in zip(self, other))
81-
and self.page_data == other.page_data
82-
)
94+
return self._blocks == other._blocks and self.page_data == other.page_data
8395
else:
8496
return False
8597

‎src/layoutparser/io/basic.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def load_dict(data: Union[Dict, List[Dict]]) -> Union[BaseLayoutElement, Layout]
6969
if isinstance(data, dict):
7070
if "page_data" in data:
7171
# It is a layout instance
72-
return Layout(load_dict(data["blocks"]), page_data=data["page_data"])
72+
return Layout(load_dict(data["blocks"])._blocks, page_data=data["page_data"])
7373
else:
7474

7575
if data["block_type"] not in BASECOORD_ELEMENT_NAMEMAP:
@@ -140,7 +140,10 @@ def load_dataframe(df: pd.DataFrame, block_type: str = None) -> Layout:
140140
else:
141141
df["block_type"] = block_type
142142

143-
if "id" not in df.columns:
144-
df["id"] = df.index
145-
143+
print((df.columns), TextBlock._features, any(col in TextBlock._features for col in df.columns))
144+
if any(col in TextBlock._features for col in df.columns):
145+
# Automatically setting index for textblock
146+
if "id" not in df.columns:
147+
df["id"] = df.index
148+
146149
return load_dict(df.apply(lambda x: x.dropna().to_dict(), axis=1).to_list())

‎src/layoutparser/io/pdf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ def extract_words_for_page(
6565
)
6666

6767
page_tokens = load_dataframe(
68-
df.rename(
68+
df.reset_index().rename(
6969
columns={
7070
"x0": "x_1",
7171
"x1": "x_2",
7272
"top": "y_1",
7373
"bottom": "y_2",
74+
"index": "id",
7475
"fontname": "type", # also loading fontname as "type"
7576
}
7677
),

‎tests/test_elements.py

+43-14
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@
1616
import numpy as np
1717
import pandas as pd
1818

19-
from layoutparser.elements import (
20-
Interval,
21-
Rectangle,
22-
Quadrilateral,
23-
TextBlock,
24-
Layout
25-
)
19+
from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout
2620
from layoutparser.elements.errors import InvalidShapeError, NotSupportedShapeError
2721

2822

@@ -242,12 +236,25 @@ def test_layout():
242236
r = Rectangle(3, 3, 5, 6)
243237
t = TextBlock(i, id=1, type=2, text="12")
244238

239+
# Test Initializations
240+
l = Layout([i, q, r])
241+
l = Layout((i,q))
242+
Layout([l])
243+
with pytest.raises(ValueError):
244+
Layout(l)
245+
246+
# Test tuple-like inputs
247+
l = Layout((i, q, r))
248+
assert l._blocks == [i, q, r]
249+
l.append(i)
250+
251+
# Test apply functions
245252
l = Layout([i, q, r])
246253
l.get_texts()
247-
l.condition_on(i)
248-
l.relative_to(q)
249-
l.filter_by(t)
250-
l.is_in(r)
254+
assert l.filter_by(t) == Layout([i])
255+
assert l.condition_on(i) == Layout([block.condition_on(i) for block in [i, q, r]])
256+
assert l.relative_to(q) == Layout([block.relative_to(q) for block in [i, q, r]])
257+
assert l.is_in(r) == Layout([block.is_in(r) for block in [i, q, r]])
251258
assert l.get_homogeneous_blocks() == [i.to_quadrilateral(), q, r.to_quadrilateral()]
252259

253260
i2 = TextBlock(i, id=1, type=2, text="12")
@@ -286,17 +293,39 @@ def test_layout():
286293
l + l2
287294

288295
# Test sort
296+
## When sorting inplace, it should return None
297+
l = Layout([i])
298+
assert l.sort(key=lambda x: x.coordinates[1], inplace=True) is None
299+
300+
## Make sure only sorting inplace works
289301
l = Layout([i, i.shift(2)])
290302
l.sort(key=lambda x: x.coordinates[1], reverse=True)
303+
assert l != Layout([i.shift(2), i])
304+
l.sort(key=lambda x: x.coordinates[1], reverse=True, inplace=True)
291305
assert l == Layout([i.shift(2), i])
292306

293307
l = Layout([q, r, i], page_data={"width": 200, "height": 400})
294-
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout(
308+
assert l.sort(key=lambda x: x.coordinates[0]) == Layout(
295309
[i, q, r], page_data={"width": 200, "height": 400}
296310
)
297311

298312
l = Layout([q, t])
299-
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout([q, t])
313+
assert l.sort(key=lambda x: x.coordinates[0]) == Layout([t, q])
314+
315+
316+
def test_layout_comp():
317+
a = Layout([Rectangle(1, 2, 3, 4)])
318+
b = Layout([Rectangle(1, 2, 3, 4)])
319+
320+
assert a == b
321+
322+
a.append(Rectangle(1, 2, 3, 5))
323+
assert a != b
324+
b.append(Rectangle(1, 2, 3, 5))
325+
assert a == b
326+
327+
a = Layout([TextBlock(Rectangle(1, 2, 3, 4))])
328+
assert a != b
300329

301330

302331
def test_shape_operations():
@@ -428,4 +457,4 @@ def test_dict():
428457

429458
l2 = Layout([i2, r2, q2])
430459
l2_dict = {"page_data": {}, "blocks": [i_dict, r_dict, q_dict]}
431-
assert l2.to_dict() == l2_dict
460+
assert l2.to_dict() == l2_dict

‎tests/test_io.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_csv():
6060
_l.page_data = {"width": 200, "height": 200}
6161
assert _l == l
6262

63-
i2 = TextBlock(i, "")
63+
i2 = i # <- Allow mixmode loading
6464
r2 = TextBlock(r, id=24)
6565
q2 = TextBlock(q, text="test", parent=45)
6666
l2 = Layout([i2, r2, q2])

0 commit comments

Comments
 (0)
Please sign in to comment.