aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_traversal.py17
-rw-r--r--test/test_utils.py14
-rw-r--r--yt_dlp/utils/_utils.py43
-rw-r--r--yt_dlp/utils/traversal.py14
4 files changed, 71 insertions, 17 deletions
diff --git a/test/test_traversal.py b/test/test_traversal.py
index 9179dadda..f1d123bd6 100644
--- a/test/test_traversal.py
+++ b/test/test_traversal.py
@@ -12,9 +12,10 @@ from yt_dlp.utils import (
str_or_none,
)
from yt_dlp.utils.traversal import (
- traverse_obj,
require,
subs_list_to_dict,
+ traverse_obj,
+ trim_str,
)
_TEST_DATA = {
@@ -495,6 +496,20 @@ class TestTraversalHelpers:
{'url': 'https://example.com/subs/en2', 'ext': 'ext'},
]}, '`quality` key should sort subtitle list accordingly'
+ def test_trim_str(self):
+ with pytest.raises(TypeError):
+ trim_str('positional')
+
+ assert callable(trim_str(start='a'))
+ assert trim_str(start='ab')('abc') == 'c'
+ assert trim_str(end='bc')('abc') == 'a'
+ assert trim_str(start='a', end='c')('abc') == 'b'
+ assert trim_str(start='ab', end='c')('abc') == ''
+ assert trim_str(start='a', end='bc')('abc') == ''
+ assert trim_str(start='ab', end='bc')('abc') == ''
+ assert trim_str(start='abc', end='abc')('abc') == ''
+ assert trim_str(start='', end='')('abc') == 'abc'
+
class TestDictGet:
def test_dict_get(self):
diff --git a/test/test_utils.py b/test/test_utils.py
index d4b846f56..04f91547a 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -4,6 +4,7 @@
import os
import sys
import unittest
+import unittest.mock
import warnings
import datetime as dt
@@ -71,6 +72,7 @@ from yt_dlp.utils import (
intlist_to_bytes,
iri_to_uri,
is_html,
+ join_nonempty,
js_to_json,
limit_length,
locked_file,
@@ -343,11 +345,13 @@ class TestUtil(unittest.TestCase):
self.assertEqual(remove_start(None, 'A - '), None)
self.assertEqual(remove_start('A - B', 'A - '), 'B')
self.assertEqual(remove_start('B - A', 'A - '), 'B - A')
+ self.assertEqual(remove_start('non-empty', ''), 'non-empty')
def test_remove_end(self):
self.assertEqual(remove_end(None, ' - B'), None)
self.assertEqual(remove_end('A - B', ' - B'), 'A')
self.assertEqual(remove_end('B - A', ' - B'), 'B - A')
+ self.assertEqual(remove_end('non-empty', ''), 'non-empty')
def test_remove_quotes(self):
self.assertEqual(remove_quotes(None), None)
@@ -2148,6 +2152,16 @@ Line 1
assert run_shell(args) == expected
assert run_shell(shell_quote(args, shell=True)) == expected
+ def test_partial_application(self):
+ assert callable(int_or_none(scale=10)), 'missing positional parameter should apply partially'
+ assert int_or_none(10, scale=0.1) == 100, 'positionally passed argument should call function'
+ assert int_or_none(v=10) == 10, 'keyword passed positional should call function'
+ assert int_or_none(scale=0.1)(10) == 100, 'call after partial applicatino should call the function'
+
+ assert callable(join_nonempty(delim=', ')), 'varargs positional should apply partially'
+ assert callable(join_nonempty()), 'varargs positional should apply partially'
+ assert join_nonempty(None, delim=', ') == '', 'passed varargs should call the function'
+
if __name__ == '__main__':
unittest.main()
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py
index 8535d2830..e30008e93 100644
--- a/yt_dlp/utils/_utils.py
+++ b/yt_dlp/utils/_utils.py
@@ -212,6 +212,23 @@ def write_json_file(obj, fn):
raise
+def partial_application(func):
+ sig = inspect.signature(func)
+ required_args = [
+ param.name for param in sig.parameters.values()
+ if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
+ if param.default is inspect.Parameter.empty
+ ]
+
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if set(required_args[len(args):]).difference(kwargs):
+ return functools.partial(func, *args, **kwargs)
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
def find_xpath_attr(node, xpath, key, val=None):
""" Find the xpath xpath[@key=val] """
assert re.match(r'^[a-zA-Z_-]+$', key)
@@ -1192,6 +1209,7 @@ def extract_timezone(date_str, default=None):
return timezone, date_str
+@partial_application
def parse_iso8601(date_str, delimiter='T', timezone=None):
""" Return a UNIX timestamp from the given date """
@@ -1269,6 +1287,7 @@ def unified_timestamp(date_str, day_first=True):
return calendar.timegm(timetuple) + pm_delta * 3600 - timezone.total_seconds()
+@partial_application
def determine_ext(url, default_ext='unknown_video'):
if url is None or '.' not in url:
return default_ext
@@ -1944,7 +1963,7 @@ def remove_start(s, start):
def remove_end(s, end):
- return s[:-len(end)] if s is not None and s.endswith(end) else s
+ return s[:-len(end)] if s is not None and end and s.endswith(end) else s
def remove_quotes(s):
@@ -1973,6 +1992,7 @@ def base_url(url):
return re.match(r'https?://[^?#]+/', url).group()
+@partial_application
def urljoin(base, path):
if isinstance(path, bytes):
path = path.decode()
@@ -1988,21 +2008,6 @@ def urljoin(base, path):
return urllib.parse.urljoin(base, path)
-def partial_application(func):
- sig = inspect.signature(func)
-
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- try:
- sig.bind(*args, **kwargs)
- except TypeError:
- return functools.partial(func, *args, **kwargs)
- else:
- return func(*args, **kwargs)
-
- return wrapped
-
-
@partial_application
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None):
if get_attr and v is not None:
@@ -2583,6 +2588,7 @@ def urlencode_postdata(*args, **kargs):
return urllib.parse.urlencode(*args, **kargs).encode('ascii')
+@partial_application
def update_url(url, *, query_update=None, **kwargs):
"""Replace URL components specified by kwargs
@param url str or parse url tuple
@@ -2603,6 +2609,7 @@ def update_url(url, *, query_update=None, **kwargs):
return urllib.parse.urlunparse(url._replace(**kwargs))
+@partial_application
def update_url_query(url, query):
return update_url(url, query_update=query)
@@ -2924,6 +2931,7 @@ def error_to_str(err):
return f'{type(err).__name__}: {err}'
+@partial_application
def mimetype2ext(mt, default=NO_DEFAULT):
if not isinstance(mt, str):
if default is not NO_DEFAULT:
@@ -4664,6 +4672,7 @@ def to_high_limit_path(path):
return path
+@partial_application
def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY):
val = traversal.traverse_obj(obj, *variadic(field))
if not val if ignore is NO_DEFAULT else val in variadic(ignore):
@@ -4828,6 +4837,7 @@ def number_of_digits(number):
return len('%d' % number)
+@partial_application
def join_nonempty(*values, delim='-', from_dict=None):
if from_dict is not None:
values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values)
@@ -5278,6 +5288,7 @@ class RetryManager:
time.sleep(delay)
+@partial_application
def make_archive_id(ie, video_id):
ie_key = ie if isinstance(ie, str) else ie.ie_key()
return f'{ie_key.lower()} {video_id}'
diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py
index 0eef817ea..dd9b4690b 100644
--- a/yt_dlp/utils/traversal.py
+++ b/yt_dlp/utils/traversal.py
@@ -435,6 +435,20 @@ def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False):
return functools.partial(func, cls)
+def trim_str(*, start=None, end=None):
+ def trim(s):
+ if s is None:
+ return None
+ start_idx = 0
+ if start and s.startswith(start):
+ start_idx = len(start)
+ if end and s.endswith(end):
+ return s[start_idx:-len(end)]
+ return s[start_idx:]
+
+ return trim
+
+
def get_first(obj, *paths, **kwargs):
return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)