aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_traversal.py39
-rw-r--r--youtube_dl/traversal.py3
-rw-r--r--youtube_dl/utils.py25
3 files changed, 63 insertions, 4 deletions
diff --git a/test/test_traversal.py b/test/test_traversal.py
index 21f81136f..5d08b8dbb 100644
--- a/test/test_traversal.py
+++ b/test/test_traversal.py
@@ -15,8 +15,11 @@ import re
from youtube_dl.traversal import (
dict_get,
get_first,
+ require,
T,
traverse_obj,
+ unpack,
+ value,
)
from youtube_dl.compat import (
compat_chr as chr,
@@ -27,7 +30,9 @@ from youtube_dl.compat import (
compat_zip as zip,
)
from youtube_dl.utils import (
+ ExtractorError,
int_or_none,
+ join_nonempty,
str_or_none,
)
@@ -462,8 +467,8 @@ class TestTraversal(_TestCase):
}),
values = dict((str(k), v) for k, v in values.items())
- for key, value in values.items():
- self.assertEqual(traverse_obj(morsel, key), value,
+ for key, val in values.items():
+ self.assertEqual(traverse_obj(morsel, key), val,
msg='Morsel should provide access to all values')
values = list(values.values())
self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values,
@@ -481,8 +486,31 @@ class TestTraversal(_TestCase):
[True, 1, 1.1, 'str', {0: 0}, [1]],
'`filter` should filter falsy values')
- def test_get_first(self):
- self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
+
+class TestTraversalHelpers(_TestCase):
+ def test_traversal_require(self):
+ with self.assertRaises(ExtractorError, msg='Missing `value` should raise'):
+ traverse_obj(_TEST_DATA, ('None', T(require('value'))))
+ self.assertEqual(
+ traverse_obj(_TEST_DATA, ('str', T(require('value')))), 'str',
+ '`require` should pass through non-`None` values')
+
+ def test_unpack(self):
+ self.assertEqual(
+ unpack(lambda *x: ''.join(map(compat_str, x)))([1, 2, 3]), '123')
+ self.assertEqual(
+ unpack(join_nonempty)([1, 2, 3]), '1-2-3')
+ self.assertEqual(
+ unpack(join_nonempty, delim=' ')([1, 2, 3]), '1 2 3')
+ with self.assertRaises(TypeError):
+ unpack(join_nonempty)()
+ with self.assertRaises(TypeError):
+ unpack()
+
+ def test_value(self):
+ self.assertEqual(
+ traverse_obj(_TEST_DATA, ('str', T(value('other')))), 'other',
+ '`value` should substitute specified value')
class TestDictGet(_TestCase):
@@ -508,6 +536,9 @@ class TestDictGet(_TestCase):
self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
+ def test_get_first(self):
+ self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
+
if __name__ == '__main__':
unittest.main()
diff --git a/youtube_dl/traversal.py b/youtube_dl/traversal.py
index 834cfef7f..e4e8758c6 100644
--- a/youtube_dl/traversal.py
+++ b/youtube_dl/traversal.py
@@ -5,6 +5,9 @@
from .utils import (
dict_get,
get_first,
+ require,
T,
traverse_obj,
+ unpack,
+ value,
)
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 29d62130a..437257f5b 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -6543,6 +6543,31 @@ def traverse_obj(obj, *paths, **kwargs):
return None if default is NO_DEFAULT else default
+def value(value):
+ return lambda _: value
+
+
+class require(ExtractorError):
+ def __init__(self, name, expected=False):
+ super(require, self).__init__(
+ 'Unable to extract {0}'.format(name), expected=expected)
+
+ def __call__(self, value):
+ if value is None:
+ raise self
+
+ return value
+
+
+def unpack(func, **kwargs):
+ """Make a function that applies `partial(func, **kwargs)` to its argument as *args"""
+ @functools.wraps(func)
+ def inner(items):
+ return func(*items, **kwargs)
+
+ return inner
+
+
def T(*x):
""" For use in yt-dl instead of {type, ...} or set((type, ...)) """
return set(x)