aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_traversal.py32
-rw-r--r--yt_dlp/utils/traversal.py9
2 files changed, 41 insertions, 0 deletions
diff --git a/test/test_traversal.py b/test/test_traversal.py
index 3b247d059..0b2f3fb5d 100644
--- a/test/test_traversal.py
+++ b/test/test_traversal.py
@@ -377,3 +377,35 @@ class TestTraversal:
'special transformations should act on current element'
assert traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})) == [1, 2008, 141100], \
'special transformations should act on current element'
+
+ def test_traversal_unbranching(self):
+ assert traverse_obj(_TEST_DATA, [(100, 1.2), all]) == [100, 1.2], \
+ '`all` should give all results as list'
+ assert traverse_obj(_TEST_DATA, [(100, 1.2), any]) == 100, \
+ '`any` should give the first result'
+ assert traverse_obj(_TEST_DATA, [100, all]) == [100], \
+ '`all` should give list if non branching'
+ assert traverse_obj(_TEST_DATA, [100, any]) == 100, \
+ '`any` should give single item if non branching'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]) == [100], \
+ '`all` should filter `None` and empty dict'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]) == 100, \
+ '`any` should filter `None` and empty dict'
+ assert traverse_obj(_TEST_DATA, [{
+ 'all': [('dict', 'None', 100, 1.2), all],
+ 'any': [('dict', 'None', 100, 1.2), any],
+ }]) == {'all': [100, 1.2], 'any': 100}, \
+ '`all`/`any` should apply to each dict path separately'
+ assert traverse_obj(_TEST_DATA, [{
+ 'all': [('dict', 'None', 100, 1.2), all],
+ 'any': [('dict', 'None', 100, 1.2), any],
+ }], get_all=False) == {'all': [100, 1.2], 'any': 100}, \
+ '`all`/`any` should apply to dict regardless of `get_all`'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, {float}]) is None, \
+ '`all` should reset branching status'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, {float}]) is None, \
+ '`any` should reset branching status'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, ..., {float}]) == [1.2], \
+ '`all` should allow further branching'
+ assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \
+ '`any` should allow further branching'
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
index 8938f4c78..926a3d0a1 100644
--- a/yt_dlp/utils/traversal.py
+++ b/yt_dlp/utils/traversal.py
@@ -228,6 +228,15 @@ def traverse_obj(
if not casesense and isinstance(key, str):
key = key.casefold()
+ if key in (any, all):
+ has_branched = False
+ filtered_objs = (obj for obj in objs if obj not in (None, {}))
+ if key is any:
+ objs = (next(filtered_objs, None),)
+ else:
+ objs = (list(filtered_objs),)
+ continue
+
if __debug__ and callable(key):
# Verify function signature
inspect.signature(key).bind(None, None)