diff options
author | Sergey M․ <dstftw@gmail.com> | 2015-10-31 22:39:44 +0600 |
---|---|---|
committer | Sergey M․ <dstftw@gmail.com> | 2015-10-31 22:39:44 +0600 |
commit | 578c074575f45ffdfd032d7b84f6fe449614f511 (patch) | |
tree | 009d49c675a8abc9b2b43ad6bb4a7081ad11b3e1 | |
parent | 8cdb5c845336ad3dc48c85a0558a38bd42972b00 (diff) |
[utils] Support list of xpath in xpath_element
-rw-r--r-- | test/test_utils.py | 7 | ||||
-rw-r--r-- | youtube_dl/utils.py | 15 |
2 files changed, 19 insertions, 3 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 0c34f0e55..5a56ad776 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -275,9 +275,16 @@ class TestUtil(unittest.TestCase): p = xml.etree.ElementTree.SubElement(div, 'p') p.text = 'Foo' self.assertEqual(xpath_element(doc, 'div/p'), p) + self.assertEqual(xpath_element(doc, ['div/p']), p) + self.assertEqual(xpath_element(doc, ['div/bar', 'div/p']), p) self.assertEqual(xpath_element(doc, 'div/bar', default='default'), 'default') + self.assertEqual(xpath_element(doc, ['div/bar'], default='default'), 'default') self.assertTrue(xpath_element(doc, 'div/bar') is None) + self.assertTrue(xpath_element(doc, ['div/bar']) is None) + self.assertTrue(xpath_element(doc, ['div/bar'], 'div/baz') is None) self.assertRaises(ExtractorError, xpath_element, doc, 'div/bar', fatal=True) + self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar'], fatal=True) + self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar', 'div/baz'], fatal=True) def test_xpath_text(self): testxml = '''<root> diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 558c9c7d5..89c88a4d3 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -178,10 +178,19 @@ def xpath_with_ns(path, ns_map): def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT): - if sys.version_info < (2, 7): # Crazy 2.6 - xpath = xpath.encode('ascii') + def _find_xpath(xpath): + if sys.version_info < (2, 7): # Crazy 2.6 + xpath = xpath.encode('ascii') + return node.find(xpath) + + if isinstance(xpath, (str, compat_str)): + n = _find_xpath(xpath) + else: + for xp in xpath: + n = _find_xpath(xp) + if n is not None: + break - n = node.find(xpath) if n is None: if default is not NO_DEFAULT: return default |