diff options
| author | Andrei Lebedev <lebdron@gmail.com> | 2022-11-03 11:09:37 +0100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-03 10:09:37 +0000 | 
| commit | 27ed77aabba8c9eb08d66f34092b1bfcc22c482e (patch) | |
| tree | 7cc41fc5e398009a5cf8e7e4156afb0246aa34d3 /test/test_utils.py | |
| parent | c4b19a88169fa76c5eb665d274e7270a0fe452c4 (diff) | |
[utils] Backport traverse_obj (etc) from yt-dlp (#31156)
* Backport traverse_obj and closely related functions from yt-dlp (code by pukkandan)
* Backport LazyList, variadic(), try_call (code by pukkandan)
* Recast using yt-dlp's newer traverse_obj() implementation and tests (code by grub4k)
* Add tests for Unicode case folding support matching Py3.5+ (requires f102e3d)
* Improve/add tests for variadic, try_call, join_nonempty
Co-authored-by: dirkf <fieldhouse@gmx.net>
Diffstat (limited to 'test/test_utils.py')
| -rw-r--r-- | test/test_utils.py | 323 | 
1 files changed, 323 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index f1a748dde..9d364c863 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,9 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # Various small unit tests  import io +import itertools  import json +import re  import xml.etree.ElementTree  from youtube_dl.utils import ( @@ -40,11 +42,14 @@ from youtube_dl.utils import (      get_element_by_attribute,      get_elements_by_class,      get_elements_by_attribute, +    get_first,      InAdvancePagedList,      int_or_none,      intlist_to_bytes,      is_html, +    join_nonempty,      js_to_json, +    LazyList,      limit_length,      merge_dicts,      mimetype2ext, @@ -79,6 +84,8 @@ from youtube_dl.utils import (      strip_or_none,      subtitles_filename,      timeconvert, +    traverse_obj, +    try_call,      unescapeHTML,      unified_strdate,      unified_timestamp, @@ -92,6 +99,7 @@ from youtube_dl.utils import (      urlencode_postdata,      urshift,      update_url_query, +    variadic,      version_tuple,      xpath_with_ns,      xpath_element, @@ -112,12 +120,18 @@ from youtube_dl.compat import (      compat_getenv,      compat_os_name,      compat_setenv, +    compat_str,      compat_urlparse,      compat_parse_qs,  )  class TestUtil(unittest.TestCase): + +    # yt-dlp shim +    def assertCountEqual(self, expected, got, msg='count should be the same'): +        return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) +      def test_timeconvert(self):          self.assertTrue(timeconvert('') is None)          self.assertTrue(timeconvert('bougrg') is None) @@ -1478,6 +1492,315 @@ Line 1          self.assertEqual(clean_podcast_url('https://www.podtrac.com/pts/redirect.mp3/chtbl.com/track/5899E/traffic.megaphone.fm/HSW7835899191.mp3'), 'https://traffic.megaphone.fm/HSW7835899191.mp3')          
self.assertEqual(clean_podcast_url('https://play.podtrac.com/npr-344098539/edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3'), 'https://edge1.pod.npr.org/anon.npr-podcasts/podcast/npr/waitwait/2020/10/20201003_waitwait_wwdtmpodcast201003-015621a5-f035-4eca-a9a1-7c118d90bc3c.mp3') +    def test_LazyList(self): +        it = list(range(10)) + +        self.assertEqual(list(LazyList(it)), it) +        self.assertEqual(LazyList(it).exhaust(), it) +        self.assertEqual(LazyList(it)[5], it[5]) + +        self.assertEqual(LazyList(it)[5:], it[5:]) +        self.assertEqual(LazyList(it)[:5], it[:5]) +        self.assertEqual(LazyList(it)[::2], it[::2]) +        self.assertEqual(LazyList(it)[1::2], it[1::2]) +        self.assertEqual(LazyList(it)[5::-1], it[5::-1]) +        self.assertEqual(LazyList(it)[6:2:-2], it[6:2:-2]) +        self.assertEqual(LazyList(it)[::-1], it[::-1]) + +        self.assertTrue(LazyList(it)) +        self.assertFalse(LazyList(range(0))) +        self.assertEqual(len(LazyList(it)), len(it)) +        self.assertEqual(repr(LazyList(it)), repr(it)) +        self.assertEqual(compat_str(LazyList(it)), compat_str(it)) + +        self.assertEqual(list(LazyList(it, reverse=True)), it[::-1]) +        self.assertEqual(list(reversed(LazyList(it))[::-1]), it) +        self.assertEqual(list(reversed(LazyList(it))[1:3:7]), it[::-1][1:3:7]) + +    def test_LazyList_laziness(self): + +        def test(ll, idx, val, cache): +            self.assertEqual(ll[idx], val) +            self.assertEqual(ll._cache, list(cache)) + +        ll = LazyList(range(10)) +        test(ll, 0, 0, range(1)) +        test(ll, 5, 5, range(6)) +        test(ll, -3, 7, range(10)) + +        ll = LazyList(range(10), reverse=True) +        test(ll, -1, 0, range(1)) +        test(ll, 3, 6, range(10)) + +        ll = LazyList(itertools.count()) +        test(ll, 10, 10, range(11)) +        ll = 
reversed(ll) +        test(ll, -15, 14, range(15)) + +    def test_try_call(self): +        def total(*x, **kwargs): +            return sum(x) + sum(kwargs.values()) + +        self.assertEqual(try_call(None), None, +                         msg='not a fn should give None') +        self.assertEqual(try_call(lambda: 1), 1, +                         msg='int fn with no expected_type should give int') +        self.assertEqual(try_call(lambda: 1, expected_type=int), 1, +                         msg='int fn with expected_type int should give int') +        self.assertEqual(try_call(lambda: 1, expected_type=dict), None, +                         msg='int fn with wrong expected_type should give None') +        self.assertEqual(try_call(total, args=(0, 1, 0, ), expected_type=int), 1, +                         msg='fn should accept arglist') +        self.assertEqual(try_call(total, kwargs={'a': 0, 'b': 1, 'c': 0}, expected_type=int), 1, +                         msg='fn should accept kwargs') +        self.assertEqual(try_call(lambda: 1, expected_type=dict), None, +                         msg='int fn with no expected_type should give None') +        self.assertEqual(try_call(lambda x: {}, total, args=(42, ), expected_type=int), 42, +                         msg='expect first int result with expected_type int') + +    def test_variadic(self): +        self.assertEqual(variadic(None), (None, )) +        self.assertEqual(variadic('spam'), ('spam', )) +        self.assertEqual(variadic('spam', allowed_types=dict), 'spam') + +    def test_traverse_obj(self): +        _TEST_DATA = { +            100: 100, +            1.2: 1.2, +            'str': 'str', +            'None': None, +            '...': Ellipsis, +            'urls': [ +                {'index': 0, 'url': 'https://www.example.com/0'}, +                {'index': 1, 'url': 'https://www.example.com/1'}, +            ], +            'data': ( +                {'index': 2}, +                {'index': 3}, +           
 ), +            'dict': {}, +        } + +        # Test base functionality +        self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', +                         msg='allow tuple path') +        self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str', +                         msg='allow list path') +        self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str', +                         msg='allow iterable path') +        self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str', +                         msg='single items should be treated as a path') +        self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA) +        self.assertEqual(traverse_obj(_TEST_DATA, 100), 100) +        self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2) + +        # Test Ellipsis behavior +        self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), +                              (item for item in _TEST_DATA.values() if item is not None), +                              msg='`...` should give all values except `None`') +        self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), +                              msg='`...` selection for dicts should select all values') +        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), +                         ['https://www.example.com/0', 'https://www.example.com/1'], +                         msg='nested `...` queries should work') +        self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), +                              msg='`...` query result should be flattened') + +        # Test function as key +        self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), +                         [_TEST_DATA['urls']], +                         msg='function as query key should perform a filter based on (key, value)') +        
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), {'str'}, +                              msg='exceptions in the query function should be caught') + +        # Test alternative paths +        self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', +                         msg='multiple `paths` should be treated as alternative paths') +        self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', +                         msg='alternatives should exit early') +        self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, +                         msg='alternatives should return `default` if exhausted') +        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, 'fail'), 100), 100, +                         msg='alternatives should track their own branching return') +        self.assertEqual(traverse_obj(_TEST_DATA, ('dict', Ellipsis), ('data', Ellipsis)), list(_TEST_DATA['data']), +                         msg='alternatives on empty objects should search further') + +        # Test branch and path nesting +        self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], +                         msg='tuple as key should be treated as branches') +        self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'], +                         msg='list as key should be treated as branches') +        self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'], +                         msg='double nesting in path should be treated as paths') +        self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1], +                         msg='do not fail early on branching') +        self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))), +                              ['https://www.example.com/0', 'https://www.example.com/1'], +      
                        msg='triple nesting in path should be treated as branches') +        self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (Ellipsis, 'url')))), +                         ['https://www.example.com/0', 'https://www.example.com/1'], +                         msg='ellipsis as branch path start gets flattened') + +        # Test dictionary as key +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2}, +                         msg='dict key should result in a dict with the same keys') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}), +                         {0: 'https://www.example.com/0'}, +                         msg='dict key should allow paths') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}), +                         {0: ['https://www.example.com/0']}, +                         msg='tuple in dict path should be treated as branches') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}), +                         {0: ['https://www.example.com/0']}, +                         msg='double nesting in dict path should be treated as paths') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), +                         {0: ['https://www.example.com/1', 'https://www.example.com/0']}, +                         msg='triple nesting in dict path should be treated as branches') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, +                         msg='remove `None` values when dict key') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis}, +                         msg='do not remove `None` values if `default`') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, +                         msg='do not remove empty values when dict key') +        
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}}, +                         msg='do not remove empty values when dict key and a default') +        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []}, +                         msg='if branch in dict key not successful, return `[]`') + +        # Testing default parameter behavior +        _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} +        self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None, +                         msg='default value should be `None`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=Ellipsis), Ellipsis, +                         msg='chained fails should result in default') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0, +                         msg='should not short cirquit on `None`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1, +                         msg='invalid dict key should result in `default`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1, +                         msg='`None` is a deliberate sentinel and should become `default`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, +                         msg='`IndexError` should result in `default`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=1), 1, +                         msg='if branched but not successful return `default` if defined, not `[]`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail'), default=None), None, +                         msg='if branched but not successful return `default` even if `default` is `None`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, (Ellipsis, 'fail')), [], +                         msg='if branched but not successful return `[]`, not `default`') +        self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 
Ellipsis)), [], +                         msg='if branched but object is empty return `[]`, not `default`') + +        # Testing expected_type behavior +        _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} +        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str', +                         msg='accept matching `expected_type` type') +        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, +                         msg='reject non matching `expected_type` type') +        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0', +                         msg='transform type using type function') +        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', +                                      expected_type=lambda _: 1 / 0), None, +                         msg='wrap expected_type function in try_call') +        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'], +                         msg='eliminate items that expected_type fails on') + +        # Test get_all behavior +        _GET_ALL_DATA = {'key': [0, 1, 2]} +        self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', Ellipsis), get_all=False), 0, +                         msg='if not `get_all`, return only first matching value') +        self.assertEqual(traverse_obj(_GET_ALL_DATA, Ellipsis, get_all=False), [0, 1, 2], +                         msg='do not overflatten if not `get_all`') + +        # Test casesense behavior +        _CASESENSE_DATA = { +            'KeY': 'value0', +            0: { +                'KeY': 'value1', +                0: {'KeY': 'value2'}, +            }, +            # FULLWIDTH LATIN CAPITAL LETTER K +            '\uff2bey': 'value3', +        } +        self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None, +                         msg='dict keys should be case sensitive unless `casesense`') +    
    self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY', +                                      casesense=False), 'value0', +                         msg='allow non matching key case if `casesense`') +        self.assertEqual(traverse_obj(_CASESENSE_DATA, '\uff4bey',  # FULLWIDTH LATIN SMALL LETTER K +                                      casesense=False), 'value3', +                         msg='allow non matching Unicode key case if `casesense`') +        self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)), +                                      casesense=False), ['value1'], +                         msg='allow non matching key case in branch if `casesense`') +        self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)), +                                      casesense=False), ['value2'], +                         msg='allow non matching key case in branch path if `casesense`') + +        # Test traverse_string behavior +        _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2} +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None, +                         msg='do not traverse into string if not `traverse_string`') +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), +                                      _traverse_string=True), 's', +                         msg='traverse into string if `traverse_string`') +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), +                                      _traverse_string=True), '.', +                         msg='traverse into converted data if `traverse_string`') +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis), +                                      _traverse_string=True), list('str'), +                         msg='`...` branching into string should result in list') +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), +                                      
_traverse_string=True), ['s', 'r'], +                         msg='branching into string should result in list') +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), +                                      _traverse_string=True), list('str'), +                         msg='function branching into string should result in list') + +        # Test is_user_input behavior +        _IS_USER_INPUT_DATA = {'range8': list(range(8))} +        self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), +                                      _is_user_input=True), 3, +                         msg='allow for string indexing if `is_user_input`') +        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), +                                           _is_user_input=True), tuple(range(8))[3:], +                              msg='allow for string slice if `is_user_input`') +        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), +                                           _is_user_input=True), tuple(range(8))[:4:2], +                              msg='allow step in string slice if `is_user_input`') +        self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), +                                           _is_user_input=True), range(8), +                              msg='`:` should be treated as `...` if `is_user_input`') +        with self.assertRaises(TypeError, msg='too many params should result in error'): +            traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True) + +        # Test re.Match as input obj +        mobj = re.match(r'^0(12)(?P<group>3)(4)?$', '0123') +        self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None], +                         msg='`...` on a `re.Match` should give its `groups()`') +        self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 2)), ['0123', '3'], +                 
        msg='function on a `re.Match` should give groupno, value starting at 0') +        self.assertEqual(traverse_obj(mobj, 'group'), '3', +                         msg='str key on a `re.Match` should give group with that name') +        self.assertEqual(traverse_obj(mobj, 2), '3', +                         msg='int key on a `re.Match` should give group with that name') +        self.assertEqual(traverse_obj(mobj, 'gRoUp', casesense=False), '3', +                         msg='str key on a `re.Match` should respect casesense') +        self.assertEqual(traverse_obj(mobj, 'fail'), None, +                         msg='failing str key on a `re.Match` should return `default`') +        self.assertEqual(traverse_obj(mobj, 'gRoUpS', casesense=False), None, +                         msg='failing str key on a `re.Match` should return `default`') +        self.assertEqual(traverse_obj(mobj, 8), None, +                         msg='failing int key on a `re.Match` should return `default`') + +    def test_get_first(self): +        self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') + +    def test_join_nonempty(self): +        self.assertEqual(join_nonempty('a', 'b'), 'a-b') +        self.assertEqual(join_nonempty( +            'a', 'b', 'c', 'd', +            from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') +  if __name__ == '__main__':      unittest.main()  | 
