diff options
Diffstat (limited to 'yt_dlp/utils/_utils.py')
-rw-r--r-- | yt_dlp/utils/_utils.py | 43 |
1 files changed, 27 insertions, 16 deletions
diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index 8535d2830..e30008e93 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -212,6 +212,23 @@ def write_json_file(obj, fn): raise +def partial_application(func): + sig = inspect.signature(func) + required_args = [ + param.name for param in sig.parameters.values() + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + if param.default is inspect.Parameter.empty + ] + + @functools.wraps(func) + def wrapped(*args, **kwargs): + if set(required_args[len(args):]).difference(kwargs): + return functools.partial(func, *args, **kwargs) + return func(*args, **kwargs) + + return wrapped + + def find_xpath_attr(node, xpath, key, val=None): """ Find the xpath xpath[@key=val] """ assert re.match(r'^[a-zA-Z_-]+$', key) @@ -1192,6 +1209,7 @@ def extract_timezone(date_str, default=None): return timezone, date_str +@partial_application def parse_iso8601(date_str, delimiter='T', timezone=None): """ Return a UNIX timestamp from the given date """ @@ -1269,6 +1287,7 @@ def unified_timestamp(date_str, day_first=True): return calendar.timegm(timetuple) + pm_delta * 3600 - timezone.total_seconds() +@partial_application def determine_ext(url, default_ext='unknown_video'): if url is None or '.' not in url: return default_ext @@ -1944,7 +1963,7 @@ def remove_start(s, start): def remove_end(s, end): - return s[:-len(end)] if s is not None and s.endswith(end) else s + return s[:-len(end)] if s is not None and end and s.endswith(end) else s def remove_quotes(s): @@ -1973,6 +1992,7 @@ def base_url(url): return re.match(r'https?://[^?#]+/', url).group() +@partial_application def urljoin(base, path): if isinstance(path, bytes): path = path.decode() @@ -1988,21 +2008,6 @@ def urljoin(base, path): return urllib.parse.urljoin(base, path) -def partial_application(func): - sig = inspect.signature(func) - - @functools.wraps(func) - def wrapped(*args, **kwargs): - try: - sig.bind(*args, **kwargs) - except TypeError: - return functools.partial(func, *args, **kwargs) - else: - return func(*args, **kwargs) - - return wrapped - - @partial_application def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None): if get_attr and v is not None: @@ -2583,6 +2588,7 @@ def urlencode_postdata(*args, **kargs): return urllib.parse.urlencode(*args, **kargs).encode('ascii') +@partial_application def update_url(url, *, query_update=None, **kwargs): """Replace URL components specified by kwargs @param url str or parse url tuple @@ -2603,6 +2609,7 @@ def update_url(url, *, query_update=None, **kwargs): return urllib.parse.urlunparse(url._replace(**kwargs)) +@partial_application def update_url_query(url, query): return update_url(url, query_update=query) @@ -2924,6 +2931,7 @@ def error_to_str(err): return f'{type(err).__name__}: {err}' +@partial_application def mimetype2ext(mt, default=NO_DEFAULT): if not isinstance(mt, str): if default is not NO_DEFAULT: @@ -4664,6 +4672,7 @@ def to_high_limit_path(path): return path +@partial_application def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY): val = traversal.traverse_obj(obj, *variadic(field)) if not val if ignore is NO_DEFAULT else val in variadic(ignore): @@ -4828,6 +4837,7 @@ def number_of_digits(number): return len('%d' % number) +@partial_application def join_nonempty(*values, delim='-', from_dict=None): if from_dict is not None: values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values) @@ -5278,6 +5288,7 @@ class RetryManager: time.sleep(delay) +@partial_application def make_archive_id(ie, video_id): ie_key = ie if isinstance(ie, str) else ie.ie_key() return f'{ie_key.lower()} {video_id}' |