diff options
| author | dirkf <fieldhouse@gmx.net> | 2023-05-03 12:40:09 +0100 | 
|---|---|---|
| committer | dirkf <fieldhouse@gmx.net> | 2023-07-19 22:14:50 +0100 | 
| commit | 825a40744bf9aeb743452db24e43d3eb61feb6c2 (patch) | |
| tree | 3cf1be097f4352ec0272c0d3c31043f06957ce4a | |
| parent | 47214e46d852e9d7ddf81d69a8e70806e2396c6c (diff) | |
[utils] Align traverse_obj() with yt-dlp

Thanks Grub4k for these:
* traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc
* traverse `set` key for transformations/filters, `re.Match` group names, from
  https://github.com/yt-dlp/yt-dlp/commit/776995bc109c5cd1aa56b684fada2ce718a386ec, etc
* traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174
* always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | test/test_utils.py | 37 |
| -rw-r--r-- | youtube_dl/utils.py | 9 |
2 files changed, 23 insertions, 23 deletions
| diff --git a/test/test_utils.py b/test/test_utils.py index 2ee727caf..1b5d170fe 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,7 @@ import xml.etree.ElementTree  from youtube_dl.utils import (      age_restricted,      args_to_str, -    encode_base_n, +    base_url,      caesar,      clean_html,      clean_podcast_url, @@ -29,10 +29,12 @@ from youtube_dl.utils import (      detect_exe_version,      determine_ext,      dict_get, +    encode_base_n,      encode_compat_str,      encodeFilename,      escape_rfc3986,      escape_url, +    expand_path,      extract_attributes,      ExtractorError,      find_xpath_attr, @@ -51,6 +53,7 @@ from youtube_dl.utils import (      js_to_json,      LazyList,      limit_length, +    lowercase_escape,      merge_dicts,      mimetype2ext,      month_by_name, @@ -66,17 +69,16 @@ from youtube_dl.utils import (      parse_resolution,      parse_bitrate,      pkcs1pad, -    read_batch_urls, -    sanitize_filename, -    sanitize_path, -    sanitize_url, -    expand_path,      prepend_extension, -    replace_extension, +    read_batch_urls,      remove_start,      remove_end,      remove_quotes, +    replace_extension,      rot47, +    sanitize_filename, +    sanitize_path, +    sanitize_url,      shell_quote,      smuggle_url,      str_or_none, @@ -93,10 +95,8 @@ from youtube_dl.utils import (      unified_timestamp,      unsmuggle_url,      uppercase_escape, -    lowercase_escape,      url_basename,      url_or_none, -    base_url,      urljoin,      urlencode_postdata,      urshift, @@ -1586,6 +1586,11 @@ Line 1              'dict': {},          } +        # define a pukka Iterable +        def iter_range(stop): +            for from_ in range(stop): +                yield from_ +          # Test base functionality          self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',                           msg='allow tuple path') @@ -1602,13 +1607,13 @@ Line 1          # Test Ellipsis behavior          
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),                                (item for item in _TEST_DATA.values() if item not in (None, {})), -                              msg='`...` should give all non discarded values') +                              msg='`...` should give all non-discarded values')          self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),                                msg='`...` selection for dicts should select all values')          self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),                           ['https://www.example.com/0', 'https://www.example.com/1'],                           msg='nested `...` queries should work') -        self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), +        self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4),                                msg='`...` query result should be flattened')          self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),                           msg='`...` should accept iterables') @@ -1618,7 +1623,7 @@ Line 1                           [_TEST_DATA['urls']],                           msg='function as query key should perform a filter based on (key, value)')          self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)), -                              msg='exceptions in the query function should be catched') +                              msg='exceptions in the query function should be caught')          self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],                           msg='function key should accept iterables')          if __debug__: @@ -1706,7 +1711,7 @@ Line 1          self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},                           msg='remove empty values when dict key')          
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, -                         msg='use `default` when dict key and `default`') +                         msg='use `default` when dict key and a default')          self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},                           msg='remove empty values when nested dict key fails')          self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, @@ -1768,7 +1773,7 @@ Line 1          self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),                           'str', msg='accept matching `expected_type` type')          self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), -                         None, msg='reject non matching `expected_type` type') +                         None, msg='reject non-matching `expected_type` type')          self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),                           '0', msg='transform type using type function')          self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), @@ -1780,7 +1785,7 @@ Line 1          self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),                           {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')          self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int), -                         1, msg='expected_type should not filter non final dict values') +                         1, msg='expected_type should not filter non-final dict values')          self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),                           {0: {0: 100}}, msg='expected_type should transform deep dict values')          self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], 
expected_type=type(Ellipsis)), @@ -1838,7 +1843,7 @@ Line 1          self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),                                        _traverse_string=True), 'sr',                           msg='`slice` should result in string if `traverse_string`') -        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), +        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'),                                        _traverse_string=True), 'str',                           msg='function should result in string if `traverse_string`')          self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 494f8341b..b77a7fb0e 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT):  def dict_get(d, key_or_keys, default=None, skip_false_values=True): -    if isinstance(key_or_keys, (list, tuple)): -        for key in key_or_keys: -            if key not in d or d[key] is None or skip_false_values and not d[key]: -                continue -            return d[key] -        return default -    return d.get(key_or_keys, default) +    exp = (lambda x: x or None) if skip_false_values else IDENTITY +    return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default)  def try_call(*funcs, **kwargs): | 
