aboutsummaryrefslogtreecommitdiff
path: root/test/test_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_utils.py')
-rw-r--r--test/test_utils.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 251739686..3ff1f8b55 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -130,6 +130,7 @@ from yt_dlp.utils import (
xpath_text,
xpath_with_ns,
)
+from yt_dlp.utils._utils import _UnsafeExtensionError
from yt_dlp.utils.networking import (
HTTPHeaderDict,
escape_rfc3986,
@@ -281,6 +282,13 @@ class TestUtil(unittest.TestCase):
finally:
os.environ['HOME'] = old_home or ''
+ _uncommon_extensions = [
+ ('exe', 'abc.exe.ext'),
+ ('de', 'abc.de.ext'),
+ ('../.mp4', None),
+ ('..\\.mp4', None),
+ ]
+
def test_prepend_extension(self):
self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext')
self.assertEqual(prepend_extension('abc.ext', 'temp', 'ext'), 'abc.temp.ext')
@@ -289,6 +297,19 @@ class TestUtil(unittest.TestCase):
self.assertEqual(prepend_extension('.abc', 'temp'), '.abc.temp')
self.assertEqual(prepend_extension('.abc.ext', 'temp'), '.abc.temp.ext')
+ # Test uncommon extensions
+ self.assertEqual(prepend_extension('abc.ext', 'bin'), 'abc.bin.ext')
+ for ext, result in self._uncommon_extensions:
+ with self.assertRaises(_UnsafeExtensionError):
+ prepend_extension('abc', ext)
+ if result:
+ self.assertEqual(prepend_extension('abc.ext', ext, 'ext'), result)
+ else:
+ with self.assertRaises(_UnsafeExtensionError):
+ prepend_extension('abc.ext', ext, 'ext')
+ with self.assertRaises(_UnsafeExtensionError):
+ prepend_extension('abc.unexpected_ext', ext, 'ext')
+
def test_replace_extension(self):
self.assertEqual(replace_extension('abc.ext', 'temp'), 'abc.temp')
self.assertEqual(replace_extension('abc.ext', 'temp', 'ext'), 'abc.temp')
@@ -297,6 +318,16 @@ class TestUtil(unittest.TestCase):
self.assertEqual(replace_extension('.abc', 'temp'), '.abc.temp')
self.assertEqual(replace_extension('.abc.ext', 'temp'), '.abc.temp')
+ # Test uncommon extensions
+ self.assertEqual(replace_extension('abc.ext', 'bin'), 'abc.unknown_video')
+ for ext, _ in self._uncommon_extensions:
+ with self.assertRaises(_UnsafeExtensionError):
+ replace_extension('abc', ext)
+ with self.assertRaises(_UnsafeExtensionError):
+ replace_extension('abc.ext', ext, 'ext')
+ with self.assertRaises(_UnsafeExtensionError):
+ replace_extension('abc.unexpected_ext', ext, 'ext')
+
def test_subtitles_filename(self):
self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt'), 'abc.en.vtt')
self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt', 'ext'), 'abc.en.vtt')