diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 38be2cd437f200..d1161c0b4a2b44 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -247,6 +247,13 @@ def check_element(element): self.assertRegex(repr(element), r"^$") element = ET.Element("tag", key="value") + # Verify type checking for ElementTree constructor + + with self.assertRaises(TypeError): + tree = ET.ElementTree("") + with self.assertRaises(TypeError): + tree = ET.ElementTree(ET.ElementTree()) + # Make sure all standard element methods exist. def check_method(method): diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 44ab5d18624e73..4dabb211d0b863 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -527,7 +527,9 @@ class ElementTree: """ def __init__(self, element=None, file=None): - # assert element is None or iselement(element) + if element is not None and not iselement(element): + raise TypeError(f"element must be etree.Element, " + f"not {type(element).__name__}") self._root = element # first node if file: self.parse(file) @@ -543,7 +545,9 @@ def _setroot(self, element): with the given element. Use with care! """ - # assert iselement(element) + if not iselement(element): + raise TypeError(f"element must be etree.Element, " + f"not {type(element).__name__}") self._root = element def parse(self, source, parser=None):