diff options
| -rw-r--r-- | test/test_utils.py | 14 | ||||
| -rw-r--r-- | youtube_dl/utils.py | 10 | 
2 files changed, 24 insertions, 0 deletions
| diff --git a/test/test_utils.py b/test/test_utils.py index 2e3cd0179..3f45b0bd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -70,6 +70,7 @@ from youtube_dl.utils import (      lowercase_escape,      url_basename,      base_url, +    urljoin,      urlencode_postdata,      urshift,      update_url_query, @@ -445,6 +446,19 @@ class TestUtil(unittest.TestCase):          self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/')          self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/') +    def test_urljoin(self): +        self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('http://foo.de', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('http://foo.de/', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin(None, 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin(['foobar'], 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') +        self.assertEqual(urljoin('http://foo.de/', None), None) +        self.assertEqual(urljoin('http://foo.de/', ''), None) +        self.assertEqual(urljoin('http://foo.de/', ['foobar']), None) +      def test_parse_age_limit(self):          self.assertEqual(parse_age_limit(None), None)          self.assertEqual(parse_age_limit(False), None) diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 3d4951ad9..694e9a600 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1700,6 +1700,16 @@ def base_url(url):      return re.match(r'https?://[^?#&]+/', url).group() +def urljoin(base, path): +    if not isinstance(path, compat_str) or not path: +        return None +    if re.match(r'https?://', path): +        return path +    if not isinstance(base, compat_str) or not re.match(r'https?://', base): +        return None +    return compat_urlparse.urljoin(base, path) + +  class HEADRequest(compat_urllib_request.Request):      def get_method(self):          return 'HEAD' | 
