diff options
| author | Sergey M․ <dstftw@gmail.com> | 2016-02-07 08:13:04 +0600 | 
|---|---|---|
| committer | Sergey M․ <dstftw@gmail.com> | 2016-02-07 08:13:04 +0600 | 
| commit | 86296ad2cd5702a66b05aae79ec9196b62e41e9a (patch) | |
| tree | 55e6e2d21294d298d164881b274e69bc62e77dde | |
| parent | 52f5889f77e04420634d70145d689660ea6bfe24 (diff) | |
[utils] Add ability to control skipping false values in dict_get
| -rw-r--r-- | test/test_utils.py | 13 | ||||
| -rw-r--r-- | youtube_dl/utils.py | 7 | 
2 files changed, 15 insertions, 5 deletions
| diff --git a/test/test_utils.py b/test/test_utils.py index e3dd019af..909d0e51d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -452,9 +452,15 @@ class TestUtil(unittest.TestCase):          self.assertTrue(isinstance(data, bytes))      def test_dict_get(self): -        d = { -            'a': 42, +        FALSE_VALUES = { +            'none': None, +            'false': False, +            'zero': 0, +            'empty_string': '', +            'empty_list': [],          } +        d = FALSE_VALUES.copy() +        d['a'] = 42          self.assertEqual(dict_get(d, 'a'), 42)          self.assertEqual(dict_get(d, 'b'), None)          self.assertEqual(dict_get(d, 'b', 42), 42) @@ -463,6 +469,9 @@ class TestUtil(unittest.TestCase):          self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42)          self.assertEqual(dict_get(d, ('b', 'c', )), None)          self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) +        for key, false_value in FALSE_VALUES.items(): +            self.assertEqual(dict_get(d, ('b', 'c', key, )), None) +            self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)      def test_encode_compat_str(self):          self.assertEqual(encode_compat_str(b'\xd1\x82\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'), 'тест') diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 652dba59d..f3b0180ab 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1717,11 +1717,12 @@ def encode_dict(d, encoding='utf-8'):      return dict((encode(k), encode(v)) for k, v in d.items()) -def dict_get(d, key_or_keys, default=None): +def dict_get(d, key_or_keys, default=None, skip_false_values=True):      if isinstance(key_or_keys, (list, tuple)):          for key in key_or_keys: -            if d.get(key): -                return d[key] +            if key not in d or d[key] is None or skip_false_values and not d[key]: +                continue +            return d[key]          return default      return d.get(key_or_keys, default) | 
