aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey M․ <dstftw@gmail.com>2015-10-31 22:39:44 +0600
committerSergey M․ <dstftw@gmail.com>2015-10-31 22:39:44 +0600
commit578c074575f45ffdfd032d7b84f6fe449614f511 (patch)
tree009d49c675a8abc9b2b43ad6bb4a7081ad11b3e1
parent8cdb5c845336ad3dc48c85a0558a38bd42972b00 (diff)
downloadyoutube-dl-578c074575f45ffdfd032d7b84f6fe449614f511.tar.xz
[utils] Support list of xpath in xpath_element
-rw-r--r--test/test_utils.py7
-rw-r--r--youtube_dl/utils.py15
2 files changed, 19 insertions, 3 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 0c34f0e55..5a56ad776 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -275,9 +275,16 @@ class TestUtil(unittest.TestCase):
p = xml.etree.ElementTree.SubElement(div, 'p')
p.text = 'Foo'
self.assertEqual(xpath_element(doc, 'div/p'), p)
+ self.assertEqual(xpath_element(doc, ['div/p']), p)
+ self.assertEqual(xpath_element(doc, ['div/bar', 'div/p']), p)
self.assertEqual(xpath_element(doc, 'div/bar', default='default'), 'default')
+ self.assertEqual(xpath_element(doc, ['div/bar'], default='default'), 'default')
self.assertTrue(xpath_element(doc, 'div/bar') is None)
+ self.assertTrue(xpath_element(doc, ['div/bar']) is None)
+ self.assertTrue(xpath_element(doc, ['div/bar'], 'div/baz') is None)
self.assertRaises(ExtractorError, xpath_element, doc, 'div/bar', fatal=True)
+ self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar'], fatal=True)
+ self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar', 'div/baz'], fatal=True)
def test_xpath_text(self):
testxml = '''<root>
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 558c9c7d5..89c88a4d3 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -178,10 +178,19 @@ def xpath_with_ns(path, ns_map):
def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT):
- if sys.version_info < (2, 7): # Crazy 2.6
- xpath = xpath.encode('ascii')
+ def _find_xpath(xpath):
+ if sys.version_info < (2, 7): # Crazy 2.6
+ xpath = xpath.encode('ascii')
+ return node.find(xpath)
+
+ if isinstance(xpath, (str, compat_str)):
+ n = _find_xpath(xpath)
+ else:
+ for xp in xpath:
+ n = _find_xpath(xp)
+ if n is not None:
+ break
- n = node.find(xpath)
if n is None:
if default is not NO_DEFAULT:
return default