diff options
| -rw-r--r-- | test/test_utils.py | 7 | ||||
| -rw-r--r-- | youtube_dl/utils.py | 15 | 
2 files changed, 19 insertions, 3 deletions
| diff --git a/test/test_utils.py b/test/test_utils.py index 0c34f0e55..5a56ad776 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -275,9 +275,16 @@ class TestUtil(unittest.TestCase):          p = xml.etree.ElementTree.SubElement(div, 'p')          p.text = 'Foo'          self.assertEqual(xpath_element(doc, 'div/p'), p) +        self.assertEqual(xpath_element(doc, ['div/p']), p) +        self.assertEqual(xpath_element(doc, ['div/bar', 'div/p']), p)          self.assertEqual(xpath_element(doc, 'div/bar', default='default'), 'default') +        self.assertEqual(xpath_element(doc, ['div/bar'], default='default'), 'default')          self.assertTrue(xpath_element(doc, 'div/bar') is None) +        self.assertTrue(xpath_element(doc, ['div/bar']) is None) +        self.assertTrue(xpath_element(doc, ['div/bar'], 'div/baz') is None)          self.assertRaises(ExtractorError, xpath_element, doc, 'div/bar', fatal=True) +        self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar'], fatal=True) +        self.assertRaises(ExtractorError, xpath_element, doc, ['div/bar', 'div/baz'], fatal=True)      def test_xpath_text(self):          testxml = '''<root> diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 558c9c7d5..89c88a4d3 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -178,10 +178,19 @@ def xpath_with_ns(path, ns_map):  def xpath_element(node, xpath, name=None, fatal=False, default=NO_DEFAULT): -    if sys.version_info < (2, 7):  # Crazy 2.6 -        xpath = xpath.encode('ascii') +    def _find_xpath(xpath): +        if sys.version_info < (2, 7):  # Crazy 2.6 +            xpath = xpath.encode('ascii') +        return node.find(xpath) + +    if isinstance(xpath, (str, compat_str)): +        n = _find_xpath(xpath) +    else: +        for xp in xpath: +            n = _find_xpath(xp) +            if n is not None: +                break -    n = node.find(xpath)      if n is None:          if default is not NO_DEFAULT:              return default | 
