diff options
author | dirkf <fieldhouse@gmx.net> | 2024-04-21 23:42:08 +0100 |
---|---|---|
committer | dirkf <fieldhouse@gmx.net> | 2024-05-30 15:46:36 +0100 |
commit | 06da64ee51cd405b9392ba484cf7d3d31a88ee30 (patch) | |
tree | 9e52705c22131dea945bf015447d080faef778ba /test | |
parent | a08f2b7e4567cdc50c0614ee0a4ffdff49b8b6e6 (diff) |
[utils] Update traverse_obj() from yt-dlp
* remove `is_user_input` option per https://github.com/yt-dlp/yt-dlp/pull/8673
* support traversal of compat_xml_etree_ElementTree_Element per https://github.com/yt-dlp/yt-dlp/pull/8911
* allow un/branching using all and any per https://github.com/yt-dlp/yt-dlp/pull/9571
* support traversal of compat_cookies.Morsel and multiple types in `set()` keys per https://github.com/yt-dlp/yt-dlp/pull/9577
thx Grub4k for these
* also, move traversal tests to a separate class
* allow for unordered dicts in tests for Py<3.7
Diffstat (limited to 'test')
-rw-r--r-- | test/test_utils.py | 257 |
1 files changed, 189 insertions, 68 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index ca36909a8..179d21cf5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -123,6 +123,7 @@ from youtube_dl.compat import ( compat_chr, compat_etree_fromstring, compat_getenv, + compat_http_cookies, compat_os_name, compat_setenv, compat_str, @@ -132,10 +133,6 @@ from youtube_dl.compat import ( class TestUtil(unittest.TestCase): - # yt-dlp shim - def assertCountEqual(self, expected, got, msg='count should be the same'): - return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) - def test_timeconvert(self): self.assertTrue(timeconvert('') is None) self.assertTrue(timeconvert('bougrg') is None) @@ -740,28 +737,6 @@ class TestUtil(unittest.TestCase): self.assertRaises( ValueError, multipart_encode, {b'field': b'value'}, boundary='value') - def test_dict_get(self): - FALSE_VALUES = { - 'none': None, - 'false': False, - 'zero': 0, - 'empty_string': '', - 'empty_list': [], - } - d = FALSE_VALUES.copy() - d['a'] = 42 - self.assertEqual(dict_get(d, 'a'), 42) - self.assertEqual(dict_get(d, 'b'), None) - self.assertEqual(dict_get(d, 'b', 42), 42) - self.assertEqual(dict_get(d, ('a', )), 42) - self.assertEqual(dict_get(d, ('b', 'a', )), 42) - self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42) - self.assertEqual(dict_get(d, ('b', 'c', )), None) - self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) - for key, false_value in FALSE_VALUES.items(): - self.assertEqual(dict_get(d, ('b', 'c', key, )), None) - self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value) - def test_merge_dicts(self): self.assertEqual(merge_dicts({'a': 1}, {'b': 2}), {'a': 1, 'b': 2}) self.assertEqual(merge_dicts({'a': 1}, {'a': 2}), {'a': 1}) @@ -1703,24 +1678,46 @@ Line 1 self.assertEqual(variadic('spam', allowed_types=dict), 'spam') self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam') + def test_join_nonempty(self): + self.assertEqual(join_nonempty('a', 'b'), 'a-b') + self.assertEqual(join_nonempty( + 'a', 'b', 'c', 'd', + from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + + +class TestTraversal(unittest.TestCase): + str = compat_str + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': Ellipsis, + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + 'dict': {}, + } + + # yt-dlp shim + def assertCountEqual(self, expected, got, msg='count should be the same'): + return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) + + def assertMaybeCountEqual(self, *args, **kwargs): + if sys.version_info < (3, 7): + # random dict order + return self.assertCountEqual(*args, **kwargs) + else: + return self.assertEqual(*args, **kwargs) + def test_traverse_obj(self): - str = compat_str - _TEST_DATA = { - 100: 100, - 1.2: 1.2, - 'str': 'str', - 'None': None, - '...': Ellipsis, - 'urls': [ - {'index': 0, 'url': 'https://www.example.com/0'}, - {'index': 1, 'url': 'https://www.example.com/1'}, - ], - 'data': ( - {'index': 2}, - {'index': 3}, - ), - 'dict': {}, - } + str = self.str + _TEST_DATA = self._TEST_DATA # define a pukka Iterable def iter_range(stop): @@ -1771,15 +1768,19 @@ Line 1 # Test set as key (transformation/type, like `expected_type`) self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'], msg='Function in set should be a transformation') + self.assertEqual(traverse_obj(_TEST_DATA, ('fail', T(lambda _: 'const'))), 'const', + msg='Function in set should always be called') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'], msg='Type in set should be a type filter') + self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str, int))), [100, 'str'], + msg='Multiple types in set should be a type filter') self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA, msg='A single set should be wrapped into a path') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'], msg='Transformation function should not raise') - self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))), - [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], - msg='Function in set should be a transformation') + self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))), + [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], + msg='Function in set should be a transformation') if __debug__: with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): traverse_obj(_TEST_DATA, set()) @@ -1992,23 +1993,6 @@ Line 1 self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [], msg='branching should result in list if `traverse_string`') - # Test is_user_input behavior - _IS_USER_INPUT_DATA = {'range8': list(range(8))} - self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), - _is_user_input=True), 3, - msg='allow for string indexing if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), - _is_user_input=True), tuple(range(8))[3:], - msg='allow for string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), - _is_user_input=True), tuple(range(8))[:4:2], - msg='allow step in string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), - _is_user_input=True), range(8), - msg='`:` should be treated as `...` if `is_user_input`') - with self.assertRaises(TypeError, msg='too many params should result in error'): - traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True) - # Test re.Match as input obj mobj = re.match(r'^0(12)(?P<group>3)(4)?$', '0123') self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None], @@ -2030,14 +2014,151 @@ Line 1 self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], msg='function on a `re.Match` should give group name as well') + # Test xml.etree.ElementTree.Element as input obj + etree = compat_etree_fromstring('''<?xml version="1.0"?> + <data> + <country name="Liechtenstein"> + <rank>1</rank> + <year>2008</year> + <gdppc>141100</gdppc> + <neighbor name="Austria" direction="E"/> + <neighbor name="Switzerland" direction="W"/> + </country> + <country name="Singapore"> + <rank>4</rank> + <year>2011</year> + <gdppc>59900</gdppc> + <neighbor name="Malaysia" direction="N"/> + </country> + <country name="Panama"> + <rank>68</rank> + <year>2011</year> + <gdppc>13600</gdppc> + <neighbor name="Costa Rica" direction="W"/> + <neighbor name="Colombia" direction="E"/> + </country> + </data>''') + self.assertEqual(traverse_obj(etree, ''), etree, + msg='empty str key should return the element itself') + self.assertEqual(traverse_obj(etree, 'country'), list(etree), + msg='str key should return all children with that tag name') + self.assertEqual(traverse_obj(etree, Ellipsis), list(etree), + msg='`...` as key should return all children') + self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]], + msg='function as key should get element as value') + self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]], + msg='function as key should get index as key') + self.assertEqual(traverse_obj(etree, 0), etree[0], + msg='int key should return the nth child') + self.assertEqual(traverse_obj(etree, './/neighbor/@name'), + ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'], + msg='`@<attribute>` at end of path should give that attribute') + self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None], + msg='`@<nonexistent>` at end of path should give `None`') + self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'}, + msg='`@` should give the full attribute dict') + self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'], + msg='`text()` at end of path should give the inner text') + self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'], + msg='full python xpath features should be supported') + self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein', + msg='special transformations should act on current element') + self.assertEqual(traverse_obj(etree, ('country', 0, Ellipsis, 'text()', T(int_or_none))), [1, 2008, 141100], + msg='special transformations should act on current element') + + def test_traversal_unbranching(self): + # str = self.str + _TEST_DATA = self._TEST_DATA + + self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), all]), [100, 1.2], + msg='`all` should give all results as list') + self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), any]), 100, + msg='`any` should give the first result') + self.assertEqual(traverse_obj(_TEST_DATA, [100, all]), [100], + msg='`all` should give list if non branching') + self.assertEqual(traverse_obj(_TEST_DATA, [100, any]), 100, + msg='`any` should give single item if non branching') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]), [100], + msg='`all` should filter `None` and empty dict') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]), 100, + msg='`any` should filter `None` and empty dict') + self.assertEqual(traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }]), {'all': [100, 1.2], 'any': 100}, + msg='`all`/`any` should apply to each dict path separately') + self.assertEqual(traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }], get_all=False), {'all': [100, 1.2], 'any': 100}, + msg='`all`/`any` should apply to dict regardless of `get_all`') + self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, T(float)]), None, + msg='`all` should reset branching status') + self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, T(float)]), None, + msg='`any` should reset branching status') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, Ellipsis, T(float)]), [1.2], + msg='`all` should allow further branching') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, Ellipsis, 'index']), [0, 1], + msg='`any` should allow further branching') + + def test_traversal_morsel(self): + values = { + 'expires': 'a', + 'path': 'b', + 'comment': 'c', + 'domain': 'd', + 'max-age': 'e', + 'secure': 'f', + 'httponly': 'g', + 'version': 'h', + 'samesite': 'i', + } + # SameSite added in Py3.8, breaks .update for 3.5-3.7 + if sys.version_info < (3, 8): + del values['samesite'] + morsel = compat_http_cookies.Morsel() + morsel.set(str('item_key'), 'item_value', 'coded_value') + morsel.update(values) + values['key'] = str('item_key') + values['value'] = 'item_value' + values = dict((str(k), v) for k, v in values.items()) + # make test pass even without ordered dict + value_set = set(values.values()) + + for key, value in values.items(): + self.assertEqual(traverse_obj(morsel, key), value, + msg='Morsel should provide access to all values') + self.assertEqual(set(traverse_obj(morsel, Ellipsis)), value_set, + msg='`...` should yield all values') + self.assertEqual(set(traverse_obj(morsel, lambda k, v: True)), value_set, + msg='function key should yield all values') + self.assertIs(traverse_obj(morsel, [(None,), any]), morsel, + msg='Morsel should not be implicitly changed to dict on usage') + def test_get_first(self): self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') - def test_join_nonempty(self): - self.assertEqual(join_nonempty('a', 'b'), 'a-b') - self.assertEqual(join_nonempty( - 'a', 'b', 'c', 'd', - from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + def test_dict_get(self): + FALSE_VALUES = { + 'none': None, + 'false': False, + 'zero': 0, + 'empty_string': '', + 'empty_list': [], + } + d = FALSE_VALUES.copy() + d['a'] = 42 + self.assertEqual(dict_get(d, 'a'), 42) + self.assertEqual(dict_get(d, 'b'), None) + self.assertEqual(dict_get(d, 'b', 42), 42) + self.assertEqual(dict_get(d, ('a', )), 42) + self.assertEqual(dict_get(d, ('b', 'a', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', )), None) + self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) + for key, false_value in FALSE_VALUES.items(): + self.assertEqual(dict_get(d, ('b', 'c', key, )), None) + self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value) if __name__ == '__main__': |