diff options
| -rw-r--r-- | test/test_traversal.py | 39 | ||||
| -rw-r--r-- | youtube_dl/traversal.py | 3 | ||||
| -rw-r--r-- | youtube_dl/utils.py | 25 |
3 files changed, 63 insertions, 4 deletions
diff --git a/test/test_traversal.py b/test/test_traversal.py index 21f81136f..5d08b8dbb 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -15,8 +15,11 @@ import re from youtube_dl.traversal import ( dict_get, get_first, + require, T, traverse_obj, + unpack, + value, ) from youtube_dl.compat import ( compat_chr as chr, @@ -27,7 +30,9 @@ from youtube_dl.compat import ( compat_zip as zip, ) from youtube_dl.utils import ( + ExtractorError, int_or_none, + join_nonempty, str_or_none, ) @@ -462,8 +467,8 @@ class TestTraversal(_TestCase): }), values = dict((str(k), v) for k, v in values.items()) - for key, value in values.items(): - self.assertEqual(traverse_obj(morsel, key), value, + for key, val in values.items(): + self.assertEqual(traverse_obj(morsel, key), val, msg='Morsel should provide access to all values') values = list(values.values()) self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values, @@ -481,8 +486,31 @@ class TestTraversal(_TestCase): [True, 1, 1.1, 'str', {0: 0}, [1]], '`filter` should filter falsy values') - def test_get_first(self): - self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') + +class TestTraversalHelpers(_TestCase): + def test_traversal_require(self): + with self.assertRaises(ExtractorError, msg='Missing `value` should raise'): + traverse_obj(_TEST_DATA, ('None', T(require('value')))) + self.assertEqual( + traverse_obj(_TEST_DATA, ('str', T(require('value')))), 'str', + '`require` should pass through non-`None` values') + + def test_unpack(self): + self.assertEqual( + unpack(lambda *x: ''.join(map(compat_str, x)))([1, 2, 3]), '123') + self.assertEqual( + unpack(join_nonempty)([1, 2, 3]), '1-2-3') + self.assertEqual( + unpack(join_nonempty, delim=' ')([1, 2, 3]), '1 2 3') + with self.assertRaises(TypeError): + unpack(join_nonempty)() + with self.assertRaises(TypeError): + unpack() + + def test_value(self): + self.assertEqual( + traverse_obj(_TEST_DATA, ('str', T(value('other')))), 'other', + '`value` should substitute specified value') class TestDictGet(_TestCase): @@ -508,6 +536,9 @@ class TestDictGet(_TestCase): self.assertEqual(dict_get(d, ('b', 'c', key, )), None) self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value) + def test_get_first(self): + self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') + if __name__ == '__main__': unittest.main() diff --git a/youtube_dl/traversal.py b/youtube_dl/traversal.py index 834cfef7f..e4e8758c6 100644 --- a/youtube_dl/traversal.py +++ b/youtube_dl/traversal.py @@ -5,6 +5,9 @@ from .utils import ( dict_get, get_first, + require, T, traverse_obj, + unpack, + value, ) diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 29d62130a..437257f5b 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -6543,6 +6543,31 @@ def traverse_obj(obj, *paths, **kwargs): return None if default is NO_DEFAULT else default +def value(value): + return lambda _: value + + +class require(ExtractorError): + def __init__(self, name, expected=False): + super(require, self).__init__( + 'Unable to extract {0}'.format(name), expected=expected) + + def __call__(self, value): + if value is None: + raise self + + return value + + +def unpack(func, **kwargs): + """Make a function that applies `partial(func, **kwargs)` to its argument as *args""" + @functools.wraps(func) + def inner(items): + return func(*items, **kwargs) + + return inner + + def T(*x): """ For use in yt-dl instead of {type, ...} or set((type, ...)) """ return set(x) |
