Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sub modules #78

Merged
merged 22 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor functionality to locations under new mod struct
  • Loading branch information
RNKuhns committed Dec 3, 2022
commit 38e19b97599cbfed7190e1fa4d57801127a28fd5
12 changes: 10 additions & 2 deletions skbase/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
"""Custom exceptions used in ``skbase``."""
from typing import List

__author__: List[str] = ["mloning", "rnkuhns"]
__all__: List[str] = ["NotFittedError"]
__author__: List[str] = ["fkiraly", "mloning", "rnkuhns"]
__all__: List[str] = ["FixtureGenerationError", "NotFittedError"]


class FixtureGenerationError(Exception):
"""Raised when a fixture fails to generate."""

def __init__(self, fixture_name="", err=None):
self.fixture_name = fixture_name
super().__init__(f"fixture {fixture_name} failed to generate. {err}")


class NotFittedError(ValueError, AttributeError):
Expand Down
71 changes: 1 addition & 70 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List

from skbase.base._base import BaseEstimator
from skbase.utils._nested_seq import flatten, is_flat, unflatten

__author__: List[str] = ["mloning", "fkiraly"]
__all__: List[str] = ["BaseMetaEstimator"]
Expand Down Expand Up @@ -589,76 +590,6 @@ def _tagchain_is_linked_set(
self.set_tags(**{mid_tag_name: mid_tag_val_not})


def flatten(obj):
"""Flatten nested list/tuple structure.

Parameters
----------
obj: nested list/tuple structure

Returns
-------
list or tuple, tuple if obj was tuple, list otherwise
flat iterable, containing non-list/tuple elements in obj in same order as in obj

Example
-------
>>> flatten([1, 2, [3, (4, 5)], 6])
[1, 2, 3, 4, 5, 6]
"""
if not isinstance(obj, (list, tuple)):
return [obj]
else:
return type(obj)([y for x in obj for y in flatten(x)])


def unflatten(obj, template):
"""Invert flattening, given template for nested list/tuple structure.

Parameters
----------
obj : list or tuple of elements
template : nested list/tuple structure
number of non-list/tuple elements of obj and template must be equal

Returns
-------
rest : list or tuple of elements
has element bracketing exactly as `template`
and elements in sequence exactly as `obj`

Example
-------
>>> unflatten([1, 2, 3, 4, 5, 6], [6, 3, [5, (2, 4)], 1])
[1, 2, [3, (4, 5)], 6]
"""
if not isinstance(template, (list, tuple)):
return obj[0]

list_or_tuple = type(template)
ls = [unflat_len(x) for x in template]
for i in range(1, len(ls)):
ls[i] += ls[i - 1]
ls = [0] + ls

res = [unflatten(obj[ls[i] : ls[i + 1]], template[i]) for i in range(len(ls) - 1)]

return list_or_tuple(res)


def unflat_len(obj):
"""Return number of non-list/tuple elements in obj."""
if not isinstance(obj, (list, tuple)):
return 1
else:
return sum([unflat_len(x) for x in obj])


def is_flat(obj):
"""Check whether list or tuple is flat, returns true if yes, false if nested."""
return not any(isinstance(x, (list, tuple)) for x in obj)


class _HeterogenousMetaEstimator(BaseMetaEstimator):
"""Handles parameter management for estimators composed of named estimators.

Expand Down
50 changes: 3 additions & 47 deletions skbase/testing/utils/_conditional_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,14 @@
from copy import deepcopy
from typing import Callable, Dict, List

import numpy as np
from skbase._exceptions import FixtureGenerationError
from skbase.utils._misc import _remove_single
from skbase.validate._types import _check_list_of_str

__author__: List[str] = ["fkiraly"]
__all__: List[str] = ["create_conditional_fixtures_and_names"]


class FixtureGenerationError(Exception):
"""Raised when a fixture fails to generate."""

def __init__(self, fixture_name="", err=None):
self.fixture_name = fixture_name
super().__init__(f"fixture {fixture_name} failed to generate. {err}")


def create_conditional_fixtures_and_names(
test_name: str,
fixture_vars: List[str],
Expand Down Expand Up @@ -206,41 +200,3 @@ def get_fixtures(fixture_var, **kwargs):
fixture_prod = [deepcopy(x) for x in fixture_prod]

return fixture_param_str, fixture_prod, fixture_names


def _check_list_of_str(obj, name="obj"):
"""Check whether obj is a list of str.

Parameters
----------
obj : any object, check whether is list of str
name : str, default="obj", name of obj to display in error message

Returns
-------
obj, unaltered

Raises
------
TypeError if obj is not list of str
"""
if not isinstance(obj, list) or not np.all(isinstance(x, str) for x in obj):
raise TypeError(f"{name} must be a list of str")
return obj


def _remove_single(x):
"""Remove tuple wrapping from singleton.

Parameters
----------
x : tuple

Returns
-------
x[0] if x is a singleton, otherwise x
"""
if len(x) == 1:
return x[0]
else:
return x
21 changes: 21 additions & 0 deletions skbase/validate/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@
__all__: List[str] = []


def _check_list_of_str(obj, name="obj"):
"""Check whether obj is a list of str.

Parameters
----------
obj : any object, check whether is list of str
name : str, default="obj", name of obj to display in error message

Returns
-------
obj, unaltered

Raises
------
TypeError if obj is not list of str
"""
if not isinstance(obj, list) or not all(isinstance(x, str) for x in obj):
raise TypeError(f"{name} must be a list of str")
return obj


def _check_list_of_str_or_error(arg_to_check, arg_name):
"""Check that certain arguments are str or list of str.

Expand Down