Skip to content

Commit b4b4fea

Browse files
authoredSep 9, 2021
Add efficientdet models (#67)
* add effdet check * Add effdet models and catalogs * clean-up * register effdet models * Add generalized image type support for layout models * Add effdet tests * Update reqs
1 parent 9b73ff1 commit b4b4fea

File tree

12 files changed

+402
-24
lines changed

12 files changed

+402
-24
lines changed
 

‎dev-requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytest
2+
torch
23
numpy
34
opencv-python
45
pandas
@@ -11,4 +12,5 @@ google-cloud-vision==1
1112
pytesseract
1213
pycocotools
1314
git+https://github.com/facebookresearch/detectron2.git@v0.4#egg=detectron2
14-
paddlepaddle
15+
paddlepaddle
16+
effdet

‎setup.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
"pandas",
2525
"pillow",
2626
"pyyaml>=5.1",
27-
"torch",
2827
"torchvision",
2928
"iopath",
3029
],
@@ -33,6 +32,10 @@
3332
'google-cloud-vision==1',
3433
'pytesseract'
3534
],
35+
"effdet": [
36+
"torch",
37+
"effdet"
38+
]
3639
},
3740
include_package_data=True
3841
)

‎src/layoutparser/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
_LazyModule,
77
is_detectron2_available,
88
is_paddle_available,
9+
is_effdet_available,
910
is_pytesseract_available,
1011
is_gcv_available,
1112
)
@@ -45,6 +46,9 @@
4546
if is_paddle_available():
4647
_import_structure["models.paddledetection"] = ["PaddleDetectionLayoutModel"]
4748

49+
if is_effdet_available():
50+
_import_structure["models.effdet"] = ["EfficientDetLayoutModel"]
51+
4852
if is_pytesseract_available():
4953
_import_structure["ocr.tesseract_agent"] = [
5054
"TesseractAgent",

‎src/layoutparser/file_utils.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,18 @@
3939
# The name of the paddlepaddle library:
4040
# Install name: pip install paddlepaddle
4141
# Import name: import paddle
42-
_paddle_version = importlib_metadata.version("paddlepaddle")
42+
_paddle_version = importlib_metadata.version("paddlepaddle")
4343
logger.debug(f"Paddle version {_paddle_version} available.")
4444
except importlib_metadata.PackageNotFoundError:
4545
_paddle_available = False
4646

47+
# Detect the effdet backend. The module must be importable *and* have
# installed distribution metadata; if the metadata lookup fails, the backend
# is marked unavailable (same pattern as the paddle detection above).
_effdet_available = importlib.util.find_spec("effdet") is not None
try:
    _effdet_version = importlib_metadata.version("effdet")
    logger.debug(f"Effdet version {_effdet_version} available.")
except importlib_metadata.PackageNotFoundError:
    # Fix: the flag consulted by is_effdet_available() must be cleared here;
    # the previous code assigned `_effdet_version = False` instead, leaving
    # `_effdet_available` stale.
    _effdet_available = False
53+
4754
###########################################
4855
############## OCR Tool Deps ##############
4956
###########################################
@@ -78,12 +85,16 @@ def is_torch_cuda_available():
7885
return False
7986

8087

88+
def is_detectron2_available():
89+
return _detectron2_available
90+
91+
8192
def is_paddle_available():
8293
return _paddle_available
8394

8495

85-
def is_detectron2_available():
86-
return _detectron2_available
96+
def is_effdet_available():
97+
return _effdet_available
8798

8899

89100
def is_pytesseract_available():
@@ -111,6 +122,11 @@ def is_gcv_available():
111122
installation page: https://github.com/PaddlePaddle/Paddle and follow the ones that match your environment.
112123
"""
113124

125+
EFFDET_IMPORT_ERROR = """
126+
{0} requires the effdet library but it was not found in your environment. You can install it with pip:
127+
`pip install effdet`
128+
"""
129+
114130
PYTESSERACT_IMPORT_ERROR = """
115131
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
116132
`pip install pytesseract`
@@ -126,6 +142,7 @@ def is_gcv_available():
126142
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
127143
("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
128144
("paddle", (is_paddle_available, PADDLE_IMPORT_ERROR)),
("effdet", (is_effdet_available, EFFDET_IMPORT_ERROR)),
129146
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
130147
("google-cloud-vision", (is_gcv_available, GCV_IMPORT_ERROR)),
131148
]
@@ -172,7 +189,7 @@ def __init__(
172189
self._import_structure = import_structure
173190

174191
# Following [PEP 366](https://www.python.org/dev/peps/pep-0366/)
175-
# The __package__ variable should be set
192+
# The __package__ variable should be set
176193
# https://docs.python.org/3/reference/import.html#__package__
177194
self.__package__ = self.__name__
178195

@@ -198,4 +215,4 @@ def _get_module(self, module_name: str):
198215
return importlib.import_module("." + module_name, self.__name__)
199216

200217
def __reduce__(self):
201-
return (self.__class__, (self._name, self.__file__, self._import_structure))
218+
return (self.__class__, (self._name, self.__file__, self._import_structure))

‎src/layoutparser/models/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .detectron2.layoutmodel import Detectron2LayoutModel
2-
from .paddledetection.layoutmodel import PaddleDetectionLayoutModel
2+
from .paddledetection.layoutmodel import PaddleDetectionLayoutModel
3+
from .effdet.layoutmodel import EfficientDetLayoutModel

‎src/layoutparser/models/base_layoutmodel.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
from typing import Union
12
from abc import ABC, abstractmethod
23

34
from ..file_utils import requires_backends
45

56

67
class BaseLayoutModel(ABC):
7-
88
@property
99
@abstractmethod
1010
def DETECTOR_NAME(self):
1111
pass
12-
12+
13+
@abstractmethod
14+
def detect(self, image):
15+
pass
16+
1317
@abstractmethod
14-
def detect(self):
18+
def image_loader(self, image: Union["ndarray", "Image"]):
19+
"""It will process the input images appropriately to the target format.
20+
"""
1521
pass
1622

1723
# Add lazy loading mechanisms for layout models, refer to

‎src/layoutparser/models/detectron2/layoutmodel.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Union
12
from PIL import Image
23
import numpy as np
34

@@ -41,7 +42,7 @@ class Detectron2LayoutModel(BaseLayoutModel):
4142
4243
Examples::
4344
>>> import layoutparser as lp
44-
>>> model = lp.models.Detectron2LayoutModel('lp://HJDataset/faster_rcnn_R_50_FPN_3x/config')
45+
>>> model = lp.Detectron2LayoutModel('lp://HJDataset/faster_rcnn_R_50_FPN_3x/config')
4546
>>> model.detect(image)
4647
4748
"""
@@ -108,7 +109,7 @@ def _reconstruct_path_with_detector_name(self, path: str) -> str:
108109
model_name_segments = model_name.split("/")
109110
if (
110111
len(model_name_segments) == 3
111-
and "detectron2" not in model_name_segments
112+
and self.DETECTOR_NAME not in model_name_segments
112113
):
113114
return "lp://" + self.DETECTOR_NAME + "/" + path[len("lp://") :]
114115
return path
@@ -148,12 +149,16 @@ def detect(self, image):
148149
:obj:`~layoutparser.Layout`: The detected layout of the input image
149150
"""
150151

152+
image = self.image_loader(image)
153+
outputs = self.model(image)
154+
layout = self.gather_output(outputs)
155+
return layout
156+
157+
def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
151158
# Convert PIL Image Input
152159
if isinstance(image, Image.Image):
153160
if image.mode != "RGB":
154161
image = image.convert("RGB")
155162
image = np.array(image)
156163

157-
outputs = self.model(image)
158-
layout = self.gather_output(outputs)
159-
return layout
164+
return image
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import catalog as _UNUSED
2+
from .layoutmodel import EfficientDetLayoutModel
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from iopath.common.file_io import PathHandler
2+
3+
from ..base_catalog import PathManager
4+
5+
# Download URLs of the pretrained weights, keyed by dataset name and then by
# effdet model name.
MODEL_CATALOG = {
    "PubLayNet": {
        "tf_efficientdet_d0": "https://www.dropbox.com/s/ukbw5s673633hsw/publaynet-tf_efficientdet_d0.pth.tar?dl=1",
        "tf_efficientdet_d1": "https://www.dropbox.com/s/gxy11xkkiwnpgog/publaynet-tf_efficientdet_d1.pth.tar?dl=1",
    },
    "MFD": {
        "tf_efficientdet_d0": "https://www.dropbox.com/s/dkr22iux7thlhel/mfd-tf_efficientdet_d0.pth.tar?dl=1",
        "tf_efficientdet_d1": "https://www.dropbox.com/s/icmbiaqr5s9bz1x/mfd-tf_efficientdet_d1.pth.tar?dl=1",
    },
}

# In effdet training scripts, it requires the label_map starting
# from 1 instead of 0
#
# Every dataset in MODEL_CATALOG needs an entry here: the layout model derives
# `num_classes` from the label map of the requested dataset, so a missing
# dataset would fail with a KeyError before the model is even created.
LABEL_MAP_CATALOG = {
    "PubLayNet": {
        1: "Text",
        2: "Title",
        3: "List",
        4: "Table",
        5: "Figure",
    },
    # NOTE(review): MFD is assumed to be single-class ("Equation"), mirroring
    # the label map used for MFD elsewhere in the project -- confirm that the
    # released checkpoints were trained with num_classes == 1.
    "MFD": {
        1: "Equation",
    },
}
27+
28+
class LayoutParserEfficientDetModelHandler(PathHandler):
    """
    Path handler that maps ``lp://efficientdet/<dataset>/<model>/weight``
    paths onto the download URLs stored in ``MODEL_CATALOG``.
    """

    PREFIX = "lp://efficientdet/"

    def _get_supported_prefixes(self):
        """Return the list of path prefixes served by this handler."""
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        """Resolve an ``lp://efficientdet/...`` path to a local file path.

        The part after the prefix is split on "/": the first segment is the
        dataset name, the last one the data type, and everything in between
        forms the model name.

        Raises:
            ValueError: if the data type is anything other than "weight".
            KeyError: if the dataset or model is not in ``MODEL_CATALOG``.
        """
        relative_name = path[len(self.PREFIX) :]
        dataset_name, *model_parts, data_type = relative_name.split("/")

        # Only weights are hosted for the EfficientDet models.
        if data_type != "weight":
            raise ValueError(f"Unknown data_type {data_type}")

        model_url = MODEL_CATALOG[dataset_name]["/".join(model_parts)]
        return PathManager.get_local_path(model_url, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        """Open the resolved local copy of ``path`` through ``PathManager``."""
        local_path = self._get_local_path(path)
        return PathManager.open(local_path, mode, **kwargs)
53+
PathManager.register_handler(LayoutParserEfficientDetModelHandler())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from typing import List, Optional, Union, Dict, Any, Tuple
2+
3+
from PIL import Image
4+
import numpy as np
5+
6+
from .catalog import PathManager, LABEL_MAP_CATALOG
7+
from ..base_layoutmodel import BaseLayoutModel
8+
from ...elements import Rectangle, TextBlock, Layout
9+
10+
from ...file_utils import is_effdet_available, is_torch_cuda_available
11+
12+
if is_effdet_available():
13+
import torch
14+
from effdet import create_model
15+
from effdet.data.transforms import (
16+
IMAGENET_DEFAULT_MEAN,
17+
IMAGENET_DEFAULT_STD,
18+
transforms_coco_eval,
19+
)
20+
21+
22+
class InputTransform:
    """Turn a PIL image into the normalized, batched tensor fed to the model.

    Args:
        image_size:
            Target image size, forwarded to ``transforms_coco_eval``.
        mean (optional):
            Per-channel mean used for normalization. Defaults to
            ``IMAGENET_DEFAULT_MEAN`` when `None`.
        std (optional):
            Per-channel standard deviation used for normalization. Defaults
            to ``IMAGENET_DEFAULT_STD`` when `None`.
    """

    def __init__(
        self,
        image_size,
        mean=None,
        std=None,
    ):
        # The effdet constants are resolved here rather than in the signature:
        # they are only imported when effdet is installed, and default values
        # are evaluated when the class is defined, which would make merely
        # importing this module fail without effdet.
        self.mean = IMAGENET_DEFAULT_MEAN if mean is None else mean
        self.std = IMAGENET_DEFAULT_STD if std is None else std

        self.transform = transforms_coco_eval(
            image_size,
            interpolation="bilinear",
            use_prefetcher=True,
            fill_color="mean",
            mean=self.mean,
            std=self.std,
        )

        # Scaled by 255 -- presumably because the transform above yields raw
        # 0-255 pixel values when use_prefetcher=True; confirm against effdet.
        self.mean_tensor = torch.tensor([x * 255 for x in self.mean]).view(1, 3, 1, 1)
        self.std_tensor = torch.tensor([x * 255 for x in self.std]).view(1, 3, 1, 1)

    def preprocess(self, image: "Image.Image") -> "Tuple[torch.Tensor, Dict]":
        """Convert one image into model inputs.

        Args:
            image: The PIL image to process; it is converted to RGB first.

        Returns:
            A tuple ``(model_input, image_info)``: the normalized float tensor
            with a leading batch dimension of 1, and a dict whose values are
            the transform's image-info entries as tensors with a leading
            batch dimension of 1.
        """
        image = image.convert("RGB")
        image_info = {"img_size": image.size}

        model_input, image_info = self.transform(image, image_info)
        image_info = {
            key: torch.tensor(val).unsqueeze(0) for key, val in image_info.items()
        }

        model_input = torch.tensor(model_input).unsqueeze(0)
        model_input = model_input.float().sub_(self.mean_tensor).div_(self.std_tensor)

        return model_input, image_info
59+
60+
61+
class EfficientDetLayoutModel(BaseLayoutModel):
    """Create an EfficientDet-based Layout Detection Model

    Args:
        config_path (:obj:`str`):
            Either a LayoutParser model path of the form
            ``lp://<dataset>/<model_name>/config`` (the detector name may
            optionally follow ``lp://``), or an effdet model name that is
            handed to ``effdet.create_model`` unchanged.
        model_path (:obj:`str`, None):
            The path to the saved weights of the model. Required when
            `config_path` is not an ``lp://`` path; when omitted for an
            ``lp://`` path the weights are resolved from the model catalog.
            Defaults to `None`.
        label_map (:obj:`dict`, optional):
            The map from the model prediction (ids) to real
            word labels (strings). If the config is from one of the supported
            datasets, Layout Parser will automatically initialize the label_map.
            Defaults to `None`.
        extra_config (:obj:`dict`, optional):
            Extra configuration passed to the EfficientDet model
            configuration. Currently supported arguments:
                num_classes: specifying the number of classes for the models
                output_confidence_threshold: minimum object prediction confidence to retain
        enforce_cpu (:obj:`bool`, optional):
            When set to `True`, it will enforce using cpu even if it is on a CUDA
            available device.
        device (:obj:`str`, optional):
            The torch device to run on. Used only when CUDA is available and
            `enforce_cpu` is `False`; defaults to ``"cuda"`` in that case.

    Examples::
        >>> import layoutparser as lp
        >>> model = lp.EfficientDetLayoutModel("lp://PubLayNet/tf_efficientdet_d0/config")
        >>> model.detect(image)

    """

    DEPENDENCIES = ["effdet"]
    DETECTOR_NAME = "efficientdet"

    DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.25

    # Scheme prefix of LayoutParser model paths.
    _LP_PREFIX = "lp://"

    def __init__(
        self,
        config_path: str,
        model_path: Optional[str] = None,
        label_map: Optional[Dict] = None,
        extra_config: Optional[Dict] = None,
        enforce_cpu: bool = False,
        device: Optional[str] = None,
    ):

        # Fall back to CPU when it is requested explicitly or CUDA is missing;
        # otherwise honour the caller's device and default to "cuda".
        if enforce_cpu or not is_torch_cuda_available():
            device = "cpu"
        elif device is None:
            device = "cuda"
        self.device = device

        extra_config = extra_config if extra_config is not None else {}

        self._initialize_model(config_path, model_path, label_map, extra_config)

        self.output_confidence_threshold = extra_config.get(
            "output_confidence_threshold", self.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
        )

        self.preprocessor = InputTransform(self.config.image_size)

    def _initialize_model(
        self,
        config_path: str,
        model_path: Optional[str],
        label_map: Optional[Dict],
        extra_config: Optional[Dict],
    ):
        """Create the effdet model and set ``model``, ``config`` and ``label_map``.

        Raises:
            KeyError: for an ``lp://`` config whose dataset has no built-in
                label map when neither `label_map` nor
                ``extra_config["num_classes"]`` is given.
            AssertionError: for a non-``lp://`` config without `model_path`,
                or without both `label_map` and ``num_classes``.
        """
        extra_config = extra_config if extra_config is not None else {}

        if config_path.startswith(self._LP_PREFIX):
            # If it's officially supported by layoutparser.
            # Slice the prefix off: str.lstrip("lp://") strips *characters*,
            # so it would also eat leading 'l'/'p' letters of the dataset name.
            segments = config_path[len(self._LP_PREFIX) :].split("/")
            if segments[0] == self.DETECTOR_NAME:
                # Canonical form "lp://efficientdet/<dataset>/<model>/config"
                segments = segments[1:]
            dataset_name, model_name = segments[0:2]

            if label_map is None:
                label_map = LABEL_MAP_CATALOG.get(dataset_name)

            if label_map is not None:
                num_classes = len(label_map)
            elif "num_classes" in extra_config:
                num_classes = extra_config["num_classes"]
            else:
                raise KeyError(
                    f"No built-in label_map for dataset {dataset_name!r}; "
                    "specify the label_map or add num_classes in the extra_config"
                )

            if model_path is None:
                # Download the model when model_path is not specified. Only
                # the trailing "config" segment is swapped for "weight" so a
                # dataset/model name containing "config" is left intact.
                weight_path = config_path
                if weight_path.endswith("/config"):
                    weight_path = weight_path[: -len("config")] + "weight"
                model_path = self._reconstruct_path_with_detector_name(weight_path)
        else:
            assert (
                model_path is not None
            ), "When the specified model is not layoutparser-based, you need to specify the model_path"

            assert (
                label_map is not None or "num_classes" in extra_config
            ), "When the specified model is not layoutparser-based, you need to specify the label_map or add num_classes in the extra_config"

            model_name = config_path
            num_classes = len(label_map) if label_map else extra_config["num_classes"]

        # It might be an lp:// path or an https URL
        model_path = PathManager.get_local_path(model_path)

        self.model = create_model(
            model_name,
            num_classes=num_classes,
            bench_task="predict",
            pretrained=True,
            checkpoint_path=model_path,
        )

        self.model.to(self.device)
        self.model.eval()
        self.config = self.model.config
        self.label_map = label_map if label_map is not None else {}

    def _reconstruct_path_with_detector_name(self, path: str) -> str:
        """This function will add the detector name (efficientdet) into the
        lp model config path to get the "canonical" model name.

        Args:
            path (str): The given input path that might or might not contain the detector name.

        Returns:
            str: a modified path that contains the detector name.
        """
        if path.startswith(self._LP_PREFIX):
            model_name = path[len(self._LP_PREFIX) :]
            model_name_segments = model_name.split("/")
            if (
                len(model_name_segments) == 3
                and self.DETECTOR_NAME not in model_name_segments
            ):
                return self._LP_PREFIX + self.DETECTOR_NAME + "/" + model_name
        return path

    def detect(self, image: Union["np.ndarray", "Image.Image"]):
        """Detect the layout of a given image.

        Args:
            image (:obj:`np.ndarray` or `PIL.Image`): The input image to detect.

        Returns:
            :obj:`~layoutparser.Layout`: The detected layout of the input image
        """

        image = self.image_loader(image)

        model_inputs, image_info = self.preprocessor.preprocess(image)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            model_outputs = self.model(
                model_inputs.to(self.device),
                {key: val.to(self.device) for key, val in image_info.items()},
            )

        return self.gather_output(model_outputs)

    def gather_output(self, model_outputs: "torch.Tensor") -> Layout:
        """Convert raw model detections into a :obj:`Layout` of text blocks.

        Each detection row is read as ``[x_1, y_1, x_2, y_2, score, category]``.
        The first four values are passed to ``Rectangle`` as they are; the
        input tensor is not modified.
        """

        model_outputs = model_outputs.detach().cpu()
        box_predictions = Layout()

        for index, sample in enumerate(model_outputs):
            for det in sample:
                x_1, y_1, x_2, y_2, score, pred_cat = det[0:6].tolist()

                if (
                    score < self.output_confidence_threshold
                ):  # stop when below this threshold, scores in descending order
                    break

                pred_cat = int(pred_cat)
                box_predictions.append(
                    TextBlock(
                        block=Rectangle(x_1, y_1, x_2, y_2),
                        score=score,
                        id=index,
                        type=self.label_map.get(pred_cat, pred_cat),
                    )
                )

        return box_predictions

    def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
        """Convert the input into a PIL image; non-ndarray inputs pass through.

        A numpy array is assumed to come from cv2, i.e. with BGR channel
        order, and is converted to an RGB PIL image.
        """

        # Convert cv2 Image Input
        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                # Single-channel array: there is no channel axis to reverse
                # (reversing the last axis would mirror the image instead).
                return Image.fromarray(image)

            # In this case, we assume the image is loaded by cv2
            # and the channel order is BGR
            image = image[..., ::-1]
            image = Image.fromarray(image, mode="RGB")

        return image

‎src/layoutparser/models/paddledetection/layoutmodel.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class PaddleDetectionLayoutModel(BaseLayoutModel):
8181
Examples::
8282
>>> import layoutparser as lp
8383
>>> model = lp.models.PaddleDetectionLayoutModel('
84-
lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config')
84+
lp://PubLayNet/ppyolov2_r50vd_dcn_365e/config')
8585
>>> model.detect(image)
8686
8787
"""
@@ -155,7 +155,7 @@ def _reconstruct_path_with_detector_name(self, path: str) -> str:
155155
model_name_segments = model_name.split("/")
156156
if (
157157
len(model_name_segments) == 3
158-
and "paddledetection" not in model_name_segments
158+
and self.DETECTOR_NAME not in model_name_segments
159159
):
160160
return "lp://" + self.DETECTOR_NAME + "/" + path[len("lp://") :]
161161
return path
@@ -276,10 +276,7 @@ def detect(self, image):
276276
"""
277277

278278
# Convert PIL Image Input
279-
if isinstance(image, Image.Image):
280-
if image.mode != "RGB":
281-
image = image.convert("RGB")
282-
image = np.array(image)
279+
image = self.image_loader(image)
283280

284281
inputs = self.preprocess(image)
285282

@@ -295,4 +292,13 @@ def detect(self, image):
295292
np_boxes = boxes_tensor.copy_to_cpu()
296293

297294
layout = self.gather_output(np_boxes)
298-
return layout
295+
return layout
296+
297+
def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
298+
299+
if isinstance(image, Image.Image):
300+
if image.mode != "RGB":
301+
image = image.convert("RGB")
302+
image = np.array(image)
303+
304+
return image

‎tests/test_model.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
"lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config",
2323
]
2424

25+
# Every EfficientDet config in the model zoo: each dataset paired with each
# model variant, in dataset-major order.
ALL_EFFDET_MODEL_CONFIGS = [
    f"lp://{dataset}/tf_efficientdet_{variant}/config"
    for dataset in ("PubLayNet", "MFD")
    for variant in ("d0", "d1")
]
31+
2532
def test_Detectron2Model(is_large_scale=False):
2633

2734
if is_large_scale:
@@ -72,4 +79,18 @@ def test_PaddleDetectionModel(is_large_scale=False):
7279
# Test in enforce CPU mode
7380
model = PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", enforce_cpu=True)
7481
image = cv2.imread("tests/fixtures/model/test_model_image.jpg")
75-
layout = model.detect(image)
82+
layout = model.detect(image)
83+
84+
def test_EffDetModel(is_large_scale=False):
    """Smoke-test EfficientDet layout detection on the fixture image.

    With ``is_large_scale`` set, every config in ``ALL_EFFDET_MODEL_CONFIGS``
    is exercised; otherwise only the PubLayNet d0 model is loaded.
    """
    if is_large_scale:
        configs_under_test = ALL_EFFDET_MODEL_CONFIGS
    else:
        configs_under_test = ["lp://PubLayNet/tf_efficientdet_d0/config"]

    for config in configs_under_test:
        model = EfficientDetLayoutModel(config)

        image = cv2.imread("tests/fixtures/model/test_model_image.jpg")
        layout = model.detect(image)

0 commit comments

Comments
 (0)
Please sign in to comment.