diff options
author | Simon Sawicki <contact@grub4k.xyz> | 2024-01-05 21:26:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-05 21:26:17 +0100 |
commit | ffbd4f2a02fee387ea5e0a267ce32df5259111ac (patch) | |
tree | 66e0792e0566835aa4589e245cbfdd06ff798aff | |
parent | 292d60b1ed3b9fe5bcb2775a894cca99b0f9473e (diff) |
[utils] `traverse_obj`: Support `xml.etree.ElementTree.Element` (#8911)
Authored by: Grub4K
-rw-r--r-- | test/test_utils.py | 52 | ||||
-rw-r--r-- | yt_dlp/utils/traversal.py | 35 |
2 files changed, 84 insertions, 3 deletions
```diff
diff --git a/test/test_utils.py b/test/test_utils.py
index c3e387cd0..09c648cf8 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -2340,6 +2340,58 @@ Line 1
         self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
                          msg='function on a `re.Match` should give group name as well')
 
+        # Test xml.etree.ElementTree.Element as input obj
+        etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
+        <data>
+            <country name="Liechtenstein">
+                <rank>1</rank>
+                <year>2008</year>
+                <gdppc>141100</gdppc>
+                <neighbor name="Austria" direction="E"/>
+                <neighbor name="Switzerland" direction="W"/>
+            </country>
+            <country name="Singapore">
+                <rank>4</rank>
+                <year>2011</year>
+                <gdppc>59900</gdppc>
+                <neighbor name="Malaysia" direction="N"/>
+            </country>
+            <country name="Panama">
+                <rank>68</rank>
+                <year>2011</year>
+                <gdppc>13600</gdppc>
+                <neighbor name="Costa Rica" direction="W"/>
+                <neighbor name="Colombia" direction="E"/>
+            </country>
+        </data>''')
+        self.assertEqual(traverse_obj(etree, ''), etree,
+                         msg='empty str key should return the element itself')
+        self.assertEqual(traverse_obj(etree, 'country'), list(etree),
+                         msg='str key should lead all children with that tag name')
+        self.assertEqual(traverse_obj(etree, ...), list(etree),
+                         msg='`...` as key should return all children')
+        self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
+                         msg='function as key should get element as value')
+        self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
+                         msg='function as key should get index as key')
+        self.assertEqual(traverse_obj(etree, 0), etree[0],
+                         msg='int key should return the nth child')
+        self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
+                         ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
+                         msg='`@<attribute>` at end of path should give that attribute')
+        self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
+                         msg='`@<nonexistant>` at end of path should give `None`')
+        self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
+                         msg='`@` should give the full attribute dict')
+        self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
+                         msg='`text()` at end of path should give the inner text')
+        self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
+                         msg='full python xpath features should be supported')
+        self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
+                         msg='special transformations should act on current element')
+        self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
+                         msg='special transformations should act on current element')
+
     def test_http_header_dict(self):
         headers = HTTPHeaderDict()
         headers['ytdl-test'] = b'0'
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
index 5a2f69fcc..8938f4c78 100644
--- a/yt_dlp/utils/traversal.py
+++ b/yt_dlp/utils/traversal.py
@@ -3,6 +3,7 @@ import contextlib
 import inspect
 import itertools
 import re
+import xml.etree.ElementTree
 
 from ._utils import (
     IDENTITY,
@@ -118,7 +119,7 @@ def traverse_obj(
             branching = True
             if isinstance(obj, collections.abc.Mapping):
                 result = obj.values()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 result = obj
             elif isinstance(obj, re.Match):
                 result = obj.groups()
@@ -132,7 +133,7 @@ def traverse_obj(
             branching = True
             if isinstance(obj, collections.abc.Mapping):
                 iter_obj = obj.items()
-            elif is_iterable_like(obj):
+            elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
                 iter_obj = enumerate(obj)
             elif isinstance(obj, re.Match):
                 iter_obj = itertools.chain(
@@ -168,7 +169,7 @@ def traverse_obj(
                 result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
 
         elif isinstance(key, (int, slice)):
-            if is_iterable_like(obj, collections.abc.Sequence):
+            if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
                 branching = isinstance(key, slice)
                 with contextlib.suppress(IndexError):
                     result = obj[key]
@@ -176,6 +177,34 @@ def traverse_obj(
                 with contextlib.suppress(IndexError):
                     result = str(obj)[key]
 
+        elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
+            xpath, _, special = key.rpartition('/')
+            if not special.startswith('@') and special != 'text()':
+                xpath = key
+                special = None
+
+            # Allow abbreviations of relative paths, absolute paths error
+            if xpath.startswith('/'):
+                xpath = f'.{xpath}'
+            elif xpath and not xpath.startswith('./'):
+                xpath = f'./{xpath}'
+
+            def apply_specials(element):
+                if special is None:
+                    return element
+                if special == '@':
+                    return element.attrib
+                if special.startswith('@'):
+                    return try_call(element.attrib.get, args=(special[1:],))
+                if special == 'text()':
+                    return element.text
+                assert False, f'apply_specials is missing case for {special!r}'
+
+            if xpath:
+                result = list(map(apply_specials, obj.iterfind(xpath)))
+            else:
+                result = apply_specials(obj)
+
         return branching, result if branching else (result,)
 
     def lazy_last(iterable):
```