diff options
| author | dirkf <fieldhouse@gmx.net> | 2025-11-01 20:24:43 +0000 |
|---|---|---|
| committer | dirkf <fieldhouse@gmx.net> | 2025-11-21 01:52:11 +0000 |
| commit | 23a848c3141ad2ba1e7bb62708f5ed72ef81c98a (patch) | |
| tree | 652f25823c6181222c72e940ac45b9dafd5f39d1 | |
| parent | a96a77875023407233bae4111a36d113b756a4e3 (diff) | |
[utils] Add `partial_application` decorator function
Thx: yt-dlp/yt-dlp#10653
| -rw-r--r-- | test/test_utils.py | 16 | ||||
| -rw-r--r-- | youtube_dl/utils.py | 33 |
2 files changed, 49 insertions, 0 deletions
diff --git a/test/test_utils.py b/test/test_utils.py index 2947cce7e..9aca4df63 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -69,6 +69,7 @@ from youtube_dl.utils import ( parse_iso8601, parse_resolution, parse_qs, + partial_application, pkcs1pad, prepend_extension, read_batch_urls, @@ -1723,6 +1724,21 @@ Line 1 'a', 'b', 'c', 'd', from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + def test_partial_application(self): + test_fn = partial_application(lambda x, kwarg=None: '{0}, kwarg={1!r}'.format(x, kwarg)) + self.assertTrue( + callable(test_fn(kwarg=10)), + 'missing positional parameter should apply partially') + self.assertEqual( + test_fn(10, kwarg=0.1), '10, kwarg=0.1', + 'positionally passed argument should call function') + self.assertEqual( + test_fn(x=10), '10, kwarg=None', + 'keyword passed positional should call function') + self.assertEqual( + test_fn(kwarg=0.1)(10), '10, kwarg=0.1', + 'call after partial application should call the function') + if __name__ == '__main__': unittest.main() diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 437257f5b..f2c02829b 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1861,6 +1861,39 @@ def write_json_file(obj, fn): raise +class partial_application(object): + """Allow a function to use pre-set argument values""" + + # see _try_bind_args() + try: + inspect.signature + + @staticmethod + def required_args(fn): + return [ + param.name for param in inspect.signature(fn).parameters.values() + if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + and param.default is inspect.Parameter.empty)] + + except AttributeError: + + # Py < 3.3 + @staticmethod + def required_args(fn): + fn_args = inspect.getargspec(fn) + n_defaults = len(fn_args.defaults or []) + return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None] + + def __new__(cls, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if set(cls.required_args(func)[len(args):]).difference(kwargs): + return functools.partial(func, *args, **kwargs) + return func(*args, **kwargs) + + return wrapped + + if sys.version_info >= (2, 7): def find_xpath_attr(node, xpath, key, val=None): """ Find the xpath xpath[@key=val] """ |
