Skip to content

Commit

Permalink
Fix bug that made List and Dict unusable types on python 3.7 and 3.8 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pereman2 authored Nov 16, 2020
1 parent 69b8a37 commit 800e8a7
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 34 deletions.
4 changes: 2 additions & 2 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def get_dict_key_value_types(ref_type: Any) -> Tuple[Any, Any]:

key_type: Any
element_type: Any
if ref_type is None:
if ref_type is None or ref_type == Dict:
key_type = Any
element_type = Any
else:
Expand All @@ -491,7 +491,7 @@ def valid_value_annotation_type(type_: Any) -> bool:


def _valid_dict_key_annotation_type(type_: Any) -> bool:
return type_ is Any or issubclass(type_, (str, Enum))
return type_ is None or type_ is Any or issubclass(type_, (str, Enum))


def is_primitive_type(type_: Any) -> bool:
Expand Down
1 change: 0 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def __setstate__(self, d: Dict[str, Any]) -> None:
d["_metadata"].ref_type = List[element_type] # type: ignore
else:
assert False

self.__dict__.update(d)

@abstractmethod
Expand Down
1 change: 1 addition & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def _create_impl( # noqa F811
element_type = get_list_element_type(ref_type)
return ListConfig(
element_type=element_type,
ref_type=ref_type,
content=obj,
parent=parent,
flags=flags,
Expand Down
12 changes: 12 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,23 @@ class Package:
modules: List[Module] = MISSING


@dataclass
class UntypedList:
list: List = field(default_factory=lambda: [1, 2]) # type: ignore
opt_list: Optional[List] = None # type: ignore


@dataclass
class SubscriptedList:
list: List[int] = field(default_factory=lambda: [1, 2])


@dataclass
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
opt_dict: Optional[Dict] = None # type: ignore


@dataclass
class SubscriptedDict:
dict: Dict[str, int] = field(default_factory=lambda: {"foo": 4})
Expand Down
17 changes: 12 additions & 5 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class AnyTypeConfig:

@attr.s(auto_attribs=True)
class BoolConfig:

# with default value
with_default: bool = True

Expand All @@ -82,7 +81,6 @@ class BoolConfig:

@attr.s(auto_attribs=True)
class IntegersConfig:

# with default value
with_default: int = 10

Expand All @@ -98,7 +96,6 @@ class IntegersConfig:

@attr.s(auto_attribs=True)
class StringConfig:

# with default value
with_default: str = "foo"

Expand All @@ -114,7 +111,6 @@ class StringConfig:

@attr.s(auto_attribs=True)
class FloatConfig:

# with default value
with_default: float = 0.10

Expand All @@ -130,7 +126,6 @@ class FloatConfig:

@attr.s(auto_attribs=True)
class EnumConfig:

# with default value
with_default: Color = Color.BLUE

Expand Down Expand Up @@ -485,3 +480,15 @@ class MissingStructuredConfigField:
class ListClass:
list: List[int] = []
tuple: Tuple[int, int] = (1, 2)


@attr.s(auto_attribs=True)
class UntypedList:
list: List = [1, 2] # type: ignore
opt_list: Optional[List] = None # type: ignore


@attr.s(auto_attribs=True)
class UntypedDict:
dict: Dict = {"foo": "var"} # type: ignore
opt_dict: Optional[Dict] = None # type: ignore
12 changes: 12 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,15 @@ class MissingStructuredConfigField:
class ListClass:
list: List[int] = field(default_factory=lambda: [])
tuple: Tuple[int, int] = field(default_factory=lambda: (1, 2))


@dataclass
class UntypedList:
list: List = field(default_factory=lambda: [1, 2]) # type: ignore
opt_list: Optional[List] = None # type: ignore


@dataclass
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
opt_dict: Optional[Dict] = None # type: ignore
20 changes: 19 additions & 1 deletion tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from importlib import import_module
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import pytest

Expand Down Expand Up @@ -811,6 +812,23 @@ def test_recursive_list(self, class_type: str) -> None:
cfg = OmegaConf.structured(o)
assert cfg == {"d": [{"d": "???"}, {"d": "???"}]}

def test_create_untyped_dict(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.structured(module.UntypedDict)
dt = Dict[Union[str, Enum], Any]
assert _utils.get_ref_type(cfg, "dict") == dt
assert _utils.get_ref_type(cfg, "opt_dict") == Optional[dt]
assert cfg.dict == {"foo": "var"}
assert cfg.opt_dict is None

def test_create_untyped_list(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.structured(module.UntypedList)
assert _utils.get_ref_type(cfg, "list") == List[Any]
assert _utils.get_ref_type(cfg, "opt_list") == Optional[List[Any]]
assert cfg.list == [1, 2]
assert cfg.opt_list is None


def validate_frozen_impl(conf: DictConfig) -> None:
with pytest.raises(ReadonlyConfigError):
Expand Down
17 changes: 16 additions & 1 deletion tests/test_create.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Testing for OmegaConf"""
import re
import sys
from typing import Any, Dict, List
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import pytest
import yaml
Expand Down Expand Up @@ -198,3 +199,17 @@ def test_create_unmodified_loader() -> None:
yaml_cfg = yaml.load("gitrev: 100e100", Loader=yaml.loader.SafeLoader)
assert cfg.gitrev == 1e102
assert yaml_cfg["gitrev"] == "100e100"


def test_create_untyped_list() -> None:
from omegaconf._utils import get_ref_type

cfg = ListConfig(ref_type=List, content=[])
assert get_ref_type(cfg) == Optional[List[Any]]


def test_create_untyped_dict() -> None:
from omegaconf._utils import get_ref_type

cfg = DictConfig(ref_type=Dict, content={})
assert get_ref_type(cfg) == Optional[Dict[Union[str, Enum], Any]]
89 changes: 65 additions & 24 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
import pathlib
import pickle
import tempfile
from enum import Enum
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Optional, Type, Union

import pytest

from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf._utils import get_ref_type

from . import PersonA, PersonD, SubscriptedDict, SubscriptedList
from . import (
PersonA,
PersonD,
SubscriptedDict,
SubscriptedList,
UntypedDict,
UntypedList,
)


def save_load_from_file(conf: Any, resolve: bool, expected: Any) -> None:
Expand Down Expand Up @@ -118,24 +126,26 @@ def test_save_illegal_type() -> None:
OmegaConf.save(OmegaConf.create(), 1000) # type: ignore


def test_pickle_dict() -> None:
with tempfile.TemporaryFile() as fp:
c = OmegaConf.create({"a": "b"})
pickle.dump(c, fp)
fp.flush()
fp.seek(0)
c1 = pickle.load(fp)
assert c == c1


def test_pickle_list() -> None:
@pytest.mark.parametrize( # type: ignore
"obj,ref_type",
[
({"a": "b"}, Dict[Union[str, Enum], Any]),
([1, 2, 3], List[Any]),
],
)
def test_pickle(obj: Any, ref_type: Any) -> None:
with tempfile.TemporaryFile() as fp:
c = OmegaConf.create([1, 2, 3])
c = OmegaConf.create(obj)
pickle.dump(c, fp)
fp.flush()
fp.seek(0)
c1 = pickle.load(fp)
assert c == c1
assert get_ref_type(c1) == Optional[ref_type]
assert c1._metadata.element_type is Any
assert c1._metadata.optional is True
if isinstance(c, DictConfig):
assert c1._metadata.key_type is Any


def test_load_duplicate_keys_top() -> None:
Expand Down Expand Up @@ -189,15 +199,47 @@ def test_load_empty_file(tmpdir: str) -> None:


@pytest.mark.parametrize( # type: ignore
"input_,key,element_type,key_type,optional,ref_type",
"input_,node,element_type,key_type,optional,ref_type",
[
(UntypedList, "list", Any, Any, False, List[Any]),
(UntypedList, "opt_list", Any, Any, True, Optional[List[Any]]),
(UntypedDict, "dict", Any, Any, False, Dict[Union[str, Enum], Any]),
(
UntypedDict,
"opt_dict",
Any,
Any,
True,
Optional[Dict[Union[str, Enum], Any]],
),
(SubscriptedDict, "dict", int, str, False, Dict[str, int]),
(SubscriptedList, "list", int, None, False, List[int]),
(SubscriptedList, "list", int, Any, False, List[int]),
(
DictConfig(
content={"a": "foo"},
ref_type=Dict[str, str],
element_type=str,
key_type=str,
),
None,
str,
str,
True,
Optional[Dict[str, str]],
),
(
ListConfig(content=[1, 2], ref_type=List[int], element_type=int),
None,
int,
Any,
True,
Optional[List[int]],
),
],
)
def test_pickle_generic(
def test_pickle_untyped(
input_: Any,
key: str,
node: str,
optional: bool,
element_type: Any,
key_type: Any,
Expand All @@ -218,10 +260,9 @@ def get_node(cfg: Any, key: str) -> Any:
else:
return cfg._get_node(key)

node = get_node(cfg2, key)
assert cfg == cfg2
assert get_ref_type(node) == ref_type
assert node._metadata.element_type == element_type
assert node._metadata.optional == optional
assert get_ref_type(get_node(cfg2, node)) == ref_type
assert get_node(cfg2, node)._metadata.element_type == element_type
assert get_node(cfg2, node)._metadata.optional == optional
if isinstance(input_, DictConfig):
assert node._metadata.key_type == key_type
assert get_node(cfg2, node)._metadata.key_type == key_type

0 comments on commit 800e8a7

Please sign in to comment.