diff options
Diffstat (limited to 'youtube_dl/utils.py')
| -rw-r--r-- | youtube_dl/utils.py | 440 |
1 files changed, 370 insertions, 70 deletions
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 113c913df..02a49ff49 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -53,6 +53,8 @@ from .compat import ( compat_etree_fromstring, compat_etree_iterfind, compat_expanduser, + compat_filter as filter, + compat_filter_fns, compat_html_entities, compat_html_entities_html5, compat_http_client, @@ -1717,21 +1719,6 @@ TIMEZONE_NAMES = { 'PST': -8, 'PDT': -7 # Pacific } -KNOWN_EXTENSIONS = ( - 'mp4', 'm4a', 'm4p', 'm4b', 'm4r', 'm4v', 'aac', - 'flv', 'f4v', 'f4a', 'f4b', - 'webm', 'ogg', 'ogv', 'oga', 'ogx', 'spx', 'opus', - 'mkv', 'mka', 'mk3d', - 'avi', 'divx', - 'mov', - 'asf', 'wmv', 'wma', - '3gp', '3g2', - 'mp3', - 'flac', - 'ape', - 'wav', - 'f4f', 'f4m', 'm3u8', 'smil') - # needed for sanitizing filenames in restricted mode ACCENT_CHARS = dict(zip('ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ', itertools.chain('AAAAAA', ['AE'], 'CEEEEIIIIDNOOOOOOO', ['OE'], 'UUUUUY', ['TH', 'ss'], @@ -1874,6 +1861,39 @@ def write_json_file(obj, fn): raise +class partial_application(object): + """Allow a function to use pre-set argument values""" + + # see _try_bind_args() + try: + inspect.signature + + @staticmethod + def required_args(fn): + return [ + param.name for param in inspect.signature(fn).parameters.values() + if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + and param.default is inspect.Parameter.empty)] + + except AttributeError: + + # Py < 3.3 + @staticmethod + def required_args(fn): + fn_args = inspect.getargspec(fn) + n_defaults = len(fn_args.defaults or []) + return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None] + + def __new__(cls, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if set(cls.required_args(func)[len(args):]).difference(kwargs): + return functools.partial(func, *args, **kwargs) + return func(*args, **kwargs) + + return wrapped + + if sys.version_info >= (2, 7): def find_xpath_attr(node, xpath, key, val=None): """ Find the xpath xpath[@key=val] """ @@ -3167,6 +3187,7 @@ def extract_timezone(date_str): return timezone, date_str +@partial_application def parse_iso8601(date_str, delimiter='T', timezone=None): """ Return a UNIX timestamp from the given date """ @@ -3244,6 +3265,7 @@ def unified_timestamp(date_str, day_first=True): return calendar.timegm(timetuple) + pm_delta * 3600 - compat_datetime_timedelta_total_seconds(timezone) +@partial_application def determine_ext(url, default_ext='unknown_video'): if url is None or '.' not in url: return default_ext @@ -3822,6 +3844,7 @@ def base_url(url): return re.match(r'https?://[^?#&]+/', url).group() +@partial_application def urljoin(base, path): path = _decode_compat_str(path, encoding='utf-8', or_none=True) if not path: @@ -3846,6 +3869,7 @@ class PUTRequest(compat_urllib_request.Request): return 'PUT' +@partial_application def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None): if get_attr: if v is not None: @@ -3872,6 +3896,7 @@ def str_to_int(int_str): return int_or_none(int_str) +@partial_application def float_or_none(v, scale=1, invscale=1, default=None): if v is None: return default @@ -3906,38 +3931,46 @@ def parse_duration(s): return None s = s.strip() + if not s: + return None days, hours, mins, secs, ms = [None] * 5 - m = re.match(r'(?:(?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?(?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?Z?$', s) + m = re.match(r'''(?x) + (?P<before_secs> + (?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)? + (?P<mins>[0-9]+):)? + (?P<secs>(?(before_secs)[0-9]{1,2}|[0-9]+)) + (?:[.:](?P<ms>[0-9]+))?Z?$ + ''', s) if m: - days, hours, mins, secs, ms = m.groups() + days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms') else: m = re.match( r'''(?ix)(?:P? (?: - [0-9]+\s*y(?:ears?)?\s* + [0-9]+\s*y(?:ears?)?,?\s* )? (?: - [0-9]+\s*m(?:onths?)?\s* + [0-9]+\s*m(?:onths?)?,?\s* )? (?: - [0-9]+\s*w(?:eeks?)?\s* + [0-9]+\s*w(?:eeks?)?,?\s* )? (?: - (?P<days>[0-9]+)\s*d(?:ays?)?\s* + (?P<days>[0-9]+)\s*d(?:ays?)?,?\s* )? T)? (?: - (?P<hours>[0-9]+)\s*h(?:ours?)?\s* + (?P<hours>[0-9]+)\s*h(?:(?:ou)?rs?)?,?\s* )? (?: - (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?\s* + (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?,?\s* )? (?: - (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*s(?:ec(?:ond)?s?)?\s* + (?P<secs>[0-9]+)(?:\.(?P<ms>[0-9]+))?\s*s(?:ec(?:ond)?s?)?\s* )?Z?$''', s) if m: - days, hours, mins, secs, ms = m.groups() + days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms') else: m = re.match(r'(?i)(?:(?P<hours>[0-9.]+)\s*(?:hours?)|(?P<mins>[0-9.]+)\s*(?:mins?\.?|minutes?)\s*)Z?$', s) if m: @@ -3945,33 +3978,32 @@ def parse_duration(s): else: return None - duration = 0 - if secs: - duration += float(secs) - if mins: - duration += float(mins) * 60 - if hours: - duration += float(hours) * 60 * 60 - if days: - duration += float(days) * 24 * 60 * 60 - if ms: - duration += float(ms) + duration = ( + ((((float(days) * 24) if days else 0) + + (float(hours) if hours else 0)) * 60 + + (float(mins) if mins else 0)) * 60 + + (float(secs) if secs else 0) + + (float(ms) / 10 ** len(ms) if ms else 0)) + return duration -def prepend_extension(filename, ext, expected_real_ext=None): +def _change_extension(prepend, filename, ext, expected_real_ext=None): name, real_ext = os.path.splitext(filename) - return ( - '{0}.{1}{2}'.format(name, ext, real_ext) - if not expected_real_ext or real_ext[1:] == expected_real_ext - else '{0}.{1}'.format(filename, ext)) + sanitize_extension = _UnsafeExtensionError.sanitize_extension + if not expected_real_ext or real_ext.partition('.')[0::2] == ('', expected_real_ext): + filename = name + if prepend and real_ext: + sanitize_extension(ext, prepend=prepend) + return ''.join((filename, '.', ext, real_ext)) -def replace_extension(filename, ext, expected_real_ext=None): - name, real_ext = os.path.splitext(filename) - return '{0}.{1}'.format( - name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename, - ext) + # Mitigate path traversal and file impersonation attacks + return '.'.join((filename, sanitize_extension(ext))) + + +prepend_extension = functools.partial(_change_extension, True) +replace_extension = functools.partial(_change_extension, False) def check_executable(exe, args=[]): @@ -4216,12 +4248,16 @@ def lowercase_escape(s): s) -def escape_rfc3986(s): +def escape_rfc3986(s, safe=None): """Escape non-ASCII characters as suggested by RFC 3986""" if sys.version_info < (3, 0): s = _encode_compat_str(s, 'utf-8') + if safe is not None: + safe = _encode_compat_str(safe, 'utf-8') + if safe is None: + safe = b"%/;:@&=+$,!~*'()?#[]" # ensure unicode: after quoting, it can always be converted - return compat_str(compat_urllib_parse.quote(s, b"%/;:@&=+$,!~*'()?#[]")) + return compat_str(compat_urllib_parse.quote(s, safe)) def escape_url(url): @@ -4259,6 +4295,7 @@ def urlencode_postdata(*args, **kargs): return compat_urllib_parse_urlencode(*args, **kargs).encode('ascii') +@partial_application def update_url(url, **kwargs): """Replace URL components specified by kwargs url: compat_str or parsed URL tuple @@ -4280,6 +4317,7 @@ def update_url(url, **kwargs): return compat_urllib_parse.urlunparse(url._replace(**kwargs)) +@partial_application def update_url_query(url, query): return update_url(url, query_update=query) @@ -4706,30 +4744,45 @@ def parse_codecs(codecs_str): if not codecs_str: return {} split_codecs = list(filter(None, map( - lambda str: str.strip(), codecs_str.strip().strip(',').split(',')))) - vcodec, acodec = None, None + lambda s: s.strip(), codecs_str.strip().split(',')))) + vcodec, acodec, hdr = None, None, None for full_codec in split_codecs: - codec = full_codec.split('.')[0] - if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora'): - if not vcodec: - vcodec = full_codec - elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'): + codec, rest = full_codec.partition('.')[::2] + codec = codec.lower() + full_codec = '.'.join((codec, rest)) if rest else codec + codec = re.sub(r'0+(?=\d)', '', codec) + if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', + 'h263', 'h264', 'mp4v', 'hvc1', 'av1', 'theora', 'dvh1', 'dvhe'): + if vcodec: + continue + vcodec = full_codec + if codec in ('dvh1', 'dvhe'): + hdr = 'DV' + elif codec in ('av1', 'vp9'): + n, m = { + 'av1': (2, '10'), + 'vp9': (0, '2'), + }[codec] + if (rest.split('.', n + 1)[n:] or [''])[0].lstrip('0') == m: + hdr = 'HDR10' + elif codec in ('flac', 'mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-4', + 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'): if not acodec: acodec = full_codec else: - write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr) - if not vcodec and not acodec: - if len(split_codecs) == 2: - return { - 'vcodec': split_codecs[0], - 'acodec': split_codecs[1], - } - else: - return { + write_string('WARNING: Unknown codec %s\n' % (full_codec,), sys.stderr) + + return ( + filter_dict({ 'vcodec': vcodec or 'none', 'acodec': acodec or 'none', - } - return {} + 'dynamic_range': hdr, + }) if vcodec or acodec + else { + 'vcodec': split_codecs[0], + 'acodec': split_codecs[1], + } if len(split_codecs) == 2 + else {}) def urlhandle_detect_ext(url_handle): @@ -6291,6 +6344,7 @@ def traverse_obj(obj, *paths, **kwargs): Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. - `any`-builtin: Take the first matching object and return it, resetting branching. - `all`-builtin: Take all matching objects and return them as a list, resetting branching. + - `filter`-builtin: Return the value if it is truthy, `None` otherwise. `tuple`, `list`, and `dict` all support nested paths and branches. @@ -6332,6 +6386,11 @@ def traverse_obj(obj, *paths, **kwargs): # instant compat str = compat_str + from .compat import ( + compat_builtins_dict as dict_, # the basic dict type + compat_dict as dict, # dict preserving imsertion order + ) + casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k if isinstance(expected_type, type): @@ -6414,7 +6473,7 @@ def traverse_obj(obj, *paths, **kwargs): if not branching: # string traversal result = ''.join(result) - elif isinstance(key, dict): + elif isinstance(key, dict_): iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) result = dict((k, v if v is not None else default) for k, v in iter_obj if v is not None or default is not NO_DEFAULT) or None @@ -6492,7 +6551,7 @@ def traverse_obj(obj, *paths, **kwargs): has_branched = False key = None - for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): + for last, key in lazy_last(variadic(path, (str, bytes, dict_, set))): if not casesense and isinstance(key, str): key = compat_casefold(key) @@ -6505,6 +6564,11 @@ def traverse_obj(obj, *paths, **kwargs): objs = (list(filtered_objs),) continue + # filter might be from __builtin__, future_builtins, or itertools.ifilter + if key in compat_filter_fns: + objs = filter(None, objs) + continue + if __debug__ and callable(key): # Verify function signature _try_bind_args(key, None, None) @@ -6517,10 +6581,10 @@ def traverse_obj(obj, *paths, **kwargs): objs = from_iterable(new_objs) - if test_type and not isinstance(key, (dict, list, tuple)): + if test_type and not isinstance(key, (dict_, list, tuple)): objs = map(type_test, objs) - return objs, has_branched, isinstance(key, dict) + return objs, has_branched, isinstance(key, dict_) def _traverse_obj(obj, path, allow_empty, test_type): results, has_branched, is_dict = apply_path(obj, path, test_type) @@ -6543,6 +6607,76 @@ def traverse_obj(obj, *paths, **kwargs): return None if default is NO_DEFAULT else default +def value(value): + return lambda _: value + + +class require(ExtractorError): + def __init__(self, name, expected=False): + super(require, self).__init__( + 'Unable to extract {0}'.format(name), expected=expected) + + def __call__(self, value): + if value is None: + raise self + + return value + + +@partial_application +# typing: (subs: list[dict], /, *, lang='und', ext=None) -> dict[str, list[dict] +def subs_list_to_dict(subs, lang='und', ext=None): + """ + Convert subtitles from a traversal into a subtitle dict. + The path should have an `all` immediately before this function. + + Arguments: + `lang` The default language tag for subtitle dicts with no + `lang` (`und`: undefined) + `ext` The default value for `ext` in the subtitle dicts + + In the dict you can set the following additional items: + `id` The language tag to which the subtitle dict should be added + `quality` The sort order for each subtitle dict + """ + + result = collections.defaultdict(list) + + for sub in subs: + tn_url = url_or_none(sub.pop('url', None)) + if tn_url: + sub['url'] = tn_url + elif not sub.get('data'): + continue + sub_lang = sub.pop('id', None) + if not isinstance(sub_lang, compat_str): + if not lang: + continue + sub_lang = lang + sub_ext = sub.get('ext') + if not isinstance(sub_ext, compat_str): + if not ext: + sub.pop('ext', None) + else: + sub['ext'] = ext + result[sub_lang].append(sub) + result = dict(result) + + for subs in result.values(): + subs.sort(key=lambda x: x.pop('quality', 0) or 0) + + return result + + +def unpack(func, **kwargs): + """Make a function that applies `partial(func, **kwargs)` to its argument as *args""" + @functools.wraps(func) + def inner(items): + return func(*items, **kwargs) + + return inner + + def T(*x): """ For use in yt-dl instead of {type, ...} or set((type, ...)) """ return set(x) @@ -6561,3 +6695,169 @@ def join_nonempty(*values, **kwargs): if from_dict is not None: values = (traverse_obj(from_dict, variadic(v)) for v in values) return delim.join(map(compat_str, filter(None, values))) + + +class Namespace(object): + """Immutable namespace""" + + def __init__(self, **kw_attr): + self.__dict__.update(kw_attr) + + def __iter__(self): + return iter(self.__dict__.values()) + + @property + def items_(self): + return self.__dict__.items() + + +MEDIA_EXTENSIONS = Namespace( + common_video=('avi', 'flv', 'mkv', 'mov', 'mp4', 'webm'), + video=('3g2', '3gp', 'f4v', 'mk3d', 'divx', 'mpg', 'ogv', 'm4v', 'wmv'), + common_audio=('aiff', 'alac', 'flac', 'm4a', 'mka', 'mp3', 'ogg', 'opus', 'wav'), + audio=('aac', 'ape', 'asf', 'f4a', 'f4b', 'm4b', 'm4p', 'm4r', 'oga', 'ogx', 'spx', 'vorbis', 'wma', 'weba'), + thumbnails=('jpg', 'png', 'webp'), + # storyboards=('mhtml', ), + subtitles=('srt', 'vtt', 'ass', 'lrc', 'ttml'), + manifests=('f4f', 'f4m', 'm3u8', 'smil', 'mpd'), +) +MEDIA_EXTENSIONS.video = MEDIA_EXTENSIONS.common_video + MEDIA_EXTENSIONS.video +MEDIA_EXTENSIONS.audio = MEDIA_EXTENSIONS.common_audio + MEDIA_EXTENSIONS.audio + +KNOWN_EXTENSIONS = ( + MEDIA_EXTENSIONS.video + MEDIA_EXTENSIONS.audio + + MEDIA_EXTENSIONS.manifests +) + + +class _UnsafeExtensionError(Exception): + """ + Mitigation exception for unwanted file overwrite/path traversal + + Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j + """ + _ALLOWED_EXTENSIONS = frozenset(itertools.chain( + ( # internal + 'description', + 'json', + 'meta', + 'orig', + 'part', + 'temp', + 'uncut', + 'unknown_video', + 'ytdl', + ), + # video + MEDIA_EXTENSIONS.video, ( + 'asx', + 'ismv', + 'm2t', + 'm2ts', + 'm2v', + 'm4s', + 'mng', + 'mp2v', + 'mp4v', + 'mpe', + 'mpeg', + 'mpeg1', + 'mpeg2', + 'mpeg4', + 'mxf', + 'ogm', + 'qt', + 'rm', + 'swf', + 'ts', + 'vob', + 'vp9', + ), + # audio + MEDIA_EXTENSIONS.audio, ( + '3ga', + 'ac3', + 'adts', + 'aif', + 'au', + 'dts', + 'isma', + 'it', + 'mid', + 'mod', + 'mpga', + 'mp1', + 'mp2', + 'mp4a', + 'mpa', + 'ra', + 'shn', + 'xm', + ), + # image + MEDIA_EXTENSIONS.thumbnails, ( + 'avif', + 'bmp', + 'gif', + 'ico', + 'heic', + 'jng', + 'jpeg', + 'jxl', + 'svg', + 'tif', + 'tiff', + 'wbmp', + ), + # subtitle + MEDIA_EXTENSIONS.subtitles, ( + 'dfxp', + 'fs', + 'ismt', + 'json3', + 'sami', + 'scc', + 'srv1', + 'srv2', + 'srv3', + 'ssa', + 'tt', + 'xml', + ), + # others + MEDIA_EXTENSIONS.manifests, + ( + # not used in yt-dl + # *MEDIA_EXTENSIONS.storyboards, + # 'desktop', + # 'ism', + # 'm3u', + # 'sbv', + # 'swp', + # 'url', + # 'webloc', + ))) + + def __init__(self, extension): + super(_UnsafeExtensionError, self).__init__('unsafe file extension: {0!r}'.format(extension)) + self.extension = extension + + # support --no-check-extensions + lenient = False + + @classmethod + def sanitize_extension(cls, extension, **kwargs): + # ... /, *, prepend=False + prepend = kwargs.get('prepend', False) + + if '/' in extension or '\\' in extension: + raise cls(extension) + + if not prepend: + last = extension.rpartition('.')[-1] + if last == 'bin': + extension = last = 'unknown_video' + if not (cls.lenient or last.lower() in cls._ALLOWED_EXTENSIONS): + raise cls(extension) + + return extension |
