aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey M․ <dstftw@gmail.com>2016-12-17 18:44:53 +0700
committerSergey M․ <dstftw@gmail.com>2016-12-17 18:49:55 +0700
commitb0c65c677f5298df8653df1e382b406bea420ba3 (patch)
tree4c3c654a572b305b8528adee10a9ce78b4e25146
parent594601f54570b8e79606002b6342dd5fcdc1f133 (diff)
downloadyoutube-dl-b0c65c677f5298df8653df1e382b406bea420ba3.tar.xz
[utils] Improve urljoin
-rw-r--r--test/test_utils.py3
-rw-r--r--youtube_dl/utils.py4
2 files changed, 5 insertions, 2 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 3f45b0bd1..1cdac82fc 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -448,11 +448,14 @@ class TestUtil(unittest.TestCase):
def test_urljoin(self):
self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+ self.assertEqual(urljoin('//foo.de/', '/a/b/c.txt'), '//foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+ self.assertEqual(urljoin('http://foo.de/', '//foo.de/a/b/c.txt'), '//foo.de/a/b/c.txt')
self.assertEqual(urljoin(None, 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+ self.assertEqual(urljoin(None, '//foo.de/a/b/c.txt'), '//foo.de/a/b/c.txt')
self.assertEqual(urljoin('', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin(['foobar'], 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
self.assertEqual(urljoin('http://foo.de/', None), None)
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 694e9a600..528d87bb9 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -1703,9 +1703,9 @@ def base_url(url):
def urljoin(base, path):
if not isinstance(path, compat_str) or not path:
return None
- if re.match(r'https?://', path):
+ if re.match(r'^(?:https?:)?//', path):
return path
- if not isinstance(base, compat_str) or not re.match(r'https?://', base):
+ if not isinstance(base, compat_str) or not re.match(r'^(?:https?:)?//', base):
return None
return compat_urlparse.urljoin(base, path)