aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordirkf <fieldhouse@gmx.net>2025-11-01 20:24:43 +0000
committerdirkf <fieldhouse@gmx.net>2025-11-21 01:52:11 +0000
commit23a848c3141ad2ba1e7bb62708f5ed72ef81c98a (patch)
tree652f25823c6181222c72e940ac45b9dafd5f39d1
parenta96a77875023407233bae4111a36d113b756a4e3 (diff)
[utils] Add `partial_application` decorator function
Thx: yt-dlp/yt-dlp#10653
-rw-r--r--test/test_utils.py16
-rw-r--r--youtube_dl/utils.py33
2 files changed, 49 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py
index 2947cce7e..9aca4df63 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -69,6 +69,7 @@ from youtube_dl.utils import (
parse_iso8601,
parse_resolution,
parse_qs,
+ partial_application,
pkcs1pad,
prepend_extension,
read_batch_urls,
@@ -1723,6 +1724,21 @@ Line 1
'a', 'b', 'c', 'd',
from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
+ def test_partial_application(self):
+ test_fn = partial_application(lambda x, kwarg=None: '{0}, kwarg={1!r}'.format(x, kwarg))
+ self.assertTrue(
+ callable(test_fn(kwarg=10)),
+ 'missing positional parameter should apply partially')
+ self.assertEqual(
+ test_fn(10, kwarg=0.1), '10, kwarg=0.1',
+ 'positionally passed argument should call function')
+ self.assertEqual(
+ test_fn(x=10), '10, kwarg=None',
+ 'keyword passed positional should call function')
+ self.assertEqual(
+ test_fn(kwarg=0.1)(10), '10, kwarg=0.1',
+ 'call after partial application should call the function')
+
if __name__ == '__main__':
unittest.main()
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 437257f5b..f2c02829b 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -1861,6 +1861,39 @@ def write_json_file(obj, fn):
raise
+class partial_application(object):
+ """Allow a function to use pre-set argument values"""
+
+ # see _try_bind_args()
+ try:
+ inspect.signature
+
+ @staticmethod
+ def required_args(fn):
+ return [
+ param.name for param in inspect.signature(fn).parameters.values()
+ if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ and param.default is inspect.Parameter.empty)]
+
+ except AttributeError:
+
+ # Py < 3.3
+ @staticmethod
+ def required_args(fn):
+ fn_args = inspect.getargspec(fn)
+ n_defaults = len(fn_args.defaults or [])
+ return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None]
+
+ def __new__(cls, func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if set(cls.required_args(func)[len(args):]).difference(kwargs):
+ return functools.partial(func, *args, **kwargs)
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
if sys.version_info >= (2, 7):
def find_xpath_attr(node, xpath, key, val=None):
""" Find the xpath xpath[@key=val] """