Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Automodel Design #146

Merged
merged 4 commits into from
Aug 6, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
improve automodel design
1. raise warning for no available models
2. identify models based on dataset names
lolipopshock committed Aug 5, 2022
commit 1731ae8ca19f08c344380b82f57d6bd2991c35f5
70 changes: 58 additions & 12 deletions src/layoutparser/models/auto_layoutmodel.py
Original file line number Diff line number Diff line change
@@ -13,26 +13,57 @@
# limitations under the License.

from typing import Optional, Dict, Union, List
from .detectron2.layoutmodel import Detectron2LayoutModel
from .paddledetection.layoutmodel import PaddleDetectionLayoutModel
from .effdet.layoutmodel import EfficientDetLayoutModel
from collections import defaultdict

from .model_config import (
is_lp_layout_model_config_any_format,
)
from ..file_utils import (
is_effdet_available,
is_detectron2_available,
is_paddle_available,
)

ALL_AVAILABLE_BACKENDS = dict()
ALL_AVAILABLE_DATASETS = defaultdict(list)

if is_effdet_available():
from .effdet.layoutmodel import EfficientDetLayoutModel
from .effdet.catalog import MODEL_CATALOG as _effdet_model_catalog

# fmt: off
ALL_AVAILABLE_BACKENDS[EfficientDetLayoutModel.DETECTOR_NAME] = EfficientDetLayoutModel
for dataset_name in _effdet_model_catalog:
ALL_AVAILABLE_DATASETS[dataset_name].append(EfficientDetLayoutModel.DETECTOR_NAME)
# fmt: on

if is_detectron2_available():
from .detectron2.layoutmodel import Detectron2LayoutModel
from .detectron2.catalog import MODEL_CATALOG as _detectron2_model_catalog

# fmt: off
ALL_AVAILABLE_BACKENDS[Detectron2LayoutModel.DETECTOR_NAME] = Detectron2LayoutModel
for dataset_name in _detectron2_model_catalog:
ALL_AVAILABLE_DATASETS[dataset_name].append(Detectron2LayoutModel.DETECTOR_NAME)
# fmt: on

if is_paddle_available():
from .paddledetection.layoutmodel import PaddleDetectionLayoutModel
from .paddledetection.catalog import MODEL_CATALOG as _paddle_model_catalog

ALL_AVAILABLE_BACKENDS = {
Detectron2LayoutModel.DETECTOR_NAME: Detectron2LayoutModel,
PaddleDetectionLayoutModel.DETECTOR_NAME: PaddleDetectionLayoutModel,
EfficientDetLayoutModel.DETECTOR_NAME: EfficientDetLayoutModel,
}
# fmt: off
ALL_AVAILABLE_BACKENDS[PaddleDetectionLayoutModel.DETECTOR_NAME] = PaddleDetectionLayoutModel
for dataset_name in _paddle_model_catalog:
ALL_AVAILABLE_DATASETS[dataset_name].append(PaddleDetectionLayoutModel.DETECTOR_NAME)
# fmt: on


def AutoLayoutModel(
config_path: str,
model_path: Optional[str] = None,
label_map: Optional[Dict]=None,
device: Optional[str]=None,
extra_config: Optional[Union[Dict, List]]=None,
label_map: Optional[Dict] = None,
device: Optional[str] = None,
extra_config: Optional[Union[Dict, List]] = None,
) -> "BaseLayoutModel":
"""[summary]
@@ -50,7 +81,7 @@ def AutoLayoutModel(
Defaults to `None`.
device(:obj:`str`, optional):
Whether to use cuda or cpu devices. If not set, LayoutParser will
automatically determine the device to initialize the models on.
automatically determine the device to initialize the models on.
extra_config (:obj:`dict`, optional):
Extra configuration used for initializing the layout model.
@@ -59,6 +90,8 @@ def AutoLayoutModel(
"""
if not is_lp_layout_model_config_any_format(config_path):
raise ValueError(f"Invalid model config_path {config_path}")

# Try to search for the model keywords
for backend_name in ALL_AVAILABLE_BACKENDS:
if backend_name in config_path:
return ALL_AVAILABLE_BACKENDS[backend_name](
@@ -68,3 +101,16 @@ def AutoLayoutModel(
extra_config=extra_config,
device=device,
)

# Try to search for the dataset keywords
for dataset_name in ALL_AVAILABLE_DATASETS:
if dataset_name in config_path:
return ALL_AVAILABLE_BACKENDS[ALL_AVAILABLE_DATASETS[dataset_name][0]](
config_path,
model_path=model_path,
label_map=label_map,
extra_config=extra_config,
device=device,
)

raise ValueError(f"No available model found for {config_path}")