diff options
author | Sergey M․ <dstftw@gmail.com> | 2016-02-07 06:12:53 +0600 |
---|---|---|
committer | Sergey M․ <dstftw@gmail.com> | 2016-02-07 06:12:53 +0600 |
commit | cbecc9b9039d5166185a41ca4d9d6c4d11595c52 (patch) | |
tree | bc0fc97d71af4a36b5deb236f7e957240ea9efa0 | |
parent | b8b465af3e83fb19c1818c2fa83f0c5f753dd917 (diff) |
[utils] Add dict_get convenience method
-rw-r--r-- | test/test_utils.py | 14 | ||||
-rw-r--r-- | youtube_dl/utils.py | 9 |
2 files changed, 23 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 1c3290d9b..e3dd019af 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,6 +22,7 @@ from youtube_dl.utils import ( DateRange, detect_exe_version, determine_ext, + dict_get, encode_compat_str, encodeFilename, escape_rfc3986, @@ -450,6 +451,19 @@ class TestUtil(unittest.TestCase): data = urlencode_postdata({'username': 'foo@bar.com', 'password': '1234'}) self.assertTrue(isinstance(data, bytes)) + def test_dict_get(self): + d = { + 'a': 42, + } + self.assertEqual(dict_get(d, 'a'), 42) + self.assertEqual(dict_get(d, 'b'), None) + self.assertEqual(dict_get(d, 'b', 42), 42) + self.assertEqual(dict_get(d, ('a', )), 42) + self.assertEqual(dict_get(d, ('b', 'a', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', )), None) + self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) + def test_encode_compat_str(self): self.assertEqual(encode_compat_str(b'\xd1\x82\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'), 'тест') self.assertEqual(encode_compat_str('тест', 'utf-8'), 'тест') diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 4262ad6ac..652dba59d 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1717,6 +1717,15 @@ def encode_dict(d, encoding='utf-8'): return dict((encode(k), encode(v)) for k, v in d.items()) +def dict_get(d, key_or_keys, default=None): + if isinstance(key_or_keys, (list, tuple)): + for key in key_or_keys: + if d.get(key): + return d[key] + return default + return d.get(key_or_keys, default) + + def encode_compat_str(string, encoding=preferredencoding(), errors='strict'): return string if isinstance(string, compat_str) else compat_str(string, encoding, errors) |