diff options
Diffstat (limited to 'yt_dlp/utils/traversal.py')
-rw-r--r-- | yt_dlp/utils/traversal.py | 35 |
1 files changed, 32 insertions, 3 deletions
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 5a2f69fcc..8938f4c78 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -3,6 +3,7 @@ import contextlib import inspect import itertools import re +import xml.etree.ElementTree from ._utils import ( IDENTITY, @@ -118,7 +119,7 @@ def traverse_obj( branching = True if isinstance(obj, collections.abc.Mapping): result = obj.values() - elif is_iterable_like(obj): + elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): result = obj elif isinstance(obj, re.Match): result = obj.groups() @@ -132,7 +133,7 @@ def traverse_obj( branching = True if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() - elif is_iterable_like(obj): + elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element): iter_obj = enumerate(obj) elif isinstance(obj, re.Match): iter_obj = itertools.chain( @@ -168,7 +169,7 @@ def traverse_obj( result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, (int, slice)): - if is_iterable_like(obj, collections.abc.Sequence): + if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)): branching = isinstance(key, slice) with contextlib.suppress(IndexError): result = obj[key] @@ -176,6 +177,34 @@ def traverse_obj( with contextlib.suppress(IndexError): result = str(obj)[key] + elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str): + xpath, _, special = key.rpartition('/') + if not special.startswith('@') and special != 'text()': + xpath = key + special = None + + # Allow abbreviations of relative paths, absolute paths error + if xpath.startswith('/'): + xpath = f'.{xpath}' + elif xpath and not xpath.startswith('./'): + xpath = f'./{xpath}' + + def apply_specials(element): + if special is None: + return element + if special == '@': + return element.attrib + if special.startswith('@'): + return try_call(element.attrib.get, args=(special[1:],)) + if special == 'text()': + return element.text + assert False, f'apply_specials is missing case for {special!r}' + + if xpath: + result = list(map(apply_specials, obj.iterfind(xpath))) + else: + result = apply_specials(obj) + return branching, result if branching else (result,) def lazy_last(iterable): |