Diffstat (limited to 'youtube_dl/utils.py')
-rw-r--r--  youtube_dl/utils.py | 440
1 file changed, 370 insertions(+), 70 deletions(-)
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 113c913df..02a49ff49 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -53,6 +53,8 @@ from .compat import (
compat_etree_fromstring,
compat_etree_iterfind,
compat_expanduser,
+ compat_filter as filter,
+ compat_filter_fns,
compat_html_entities,
compat_html_entities_html5,
compat_http_client,
@@ -1717,21 +1719,6 @@ TIMEZONE_NAMES = {
'PST': -8, 'PDT': -7 # Pacific
}
-KNOWN_EXTENSIONS = (
- 'mp4', 'm4a', 'm4p', 'm4b', 'm4r', 'm4v', 'aac',
- 'flv', 'f4v', 'f4a', 'f4b',
- 'webm', 'ogg', 'ogv', 'oga', 'ogx', 'spx', 'opus',
- 'mkv', 'mka', 'mk3d',
- 'avi', 'divx',
- 'mov',
- 'asf', 'wmv', 'wma',
- '3gp', '3g2',
- 'mp3',
- 'flac',
- 'ape',
- 'wav',
- 'f4f', 'f4m', 'm3u8', 'smil')
-
# needed for sanitizing filenames in restricted mode
ACCENT_CHARS = dict(zip('ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ',
itertools.chain('AAAAAA', ['AE'], 'CEEEEIIIIDNOOOOOOO', ['OE'], 'UUUUUY', ['TH', 'ss'],
@@ -1874,6 +1861,39 @@ def write_json_file(obj, fn):
raise
+class partial_application(object):
+ """Allow a function to use pre-set argument values"""
+
+ # see _try_bind_args()
+ try:
+ inspect.signature
+
+ @staticmethod
+ def required_args(fn):
+ return [
+ param.name for param in inspect.signature(fn).parameters.values()
+ if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ and param.default is inspect.Parameter.empty)]
+
+ except AttributeError:
+
+ # Py < 3.3
+ @staticmethod
+ def required_args(fn):
+ fn_args = inspect.getargspec(fn)
+ n_defaults = len(fn_args.defaults or [])
+ return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None]
+
+ def __new__(cls, func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if set(cls.required_args(func)[len(args):]).difference(kwargs):
+ return functools.partial(func, *args, **kwargs)
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
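Illustrative sketch (not part of the patch): with this decorator, a call that omits any required positional argument returns a functools.partial instead of failing, while fully specified calls behave as before; `scale` below is a made-up example function.

    @partial_application
    def scale(value, factor):
        return value * factor

    double = scale(factor=2)   # required arg `value` still missing -> functools.partial
    double(21)                 # == 42, the deferred call completes normally
    scale(21, 2)               # == 42, fully specified calls are unchanged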
if sys.version_info >= (2, 7):
def find_xpath_attr(node, xpath, key, val=None):
""" Find the xpath xpath[@key=val] """
@@ -3167,6 +3187,7 @@ def extract_timezone(date_str):
return timezone, date_str
+@partial_application
def parse_iso8601(date_str, delimiter='T', timezone=None):
""" Return a UNIX timestamp from the given date """
@@ -3244,6 +3265,7 @@ def unified_timestamp(date_str, day_first=True):
return calendar.timegm(timetuple) + pm_delta * 3600 - compat_datetime_timedelta_total_seconds(timezone)
+@partial_application
def determine_ext(url, default_ext='unknown_video'):
if url is None or '.' not in url:
return default_ext
@@ -3822,6 +3844,7 @@ def base_url(url):
return re.match(r'https?://[^?#&]+/', url).group()
+@partial_application
def urljoin(base, path):
path = _decode_compat_str(path, encoding='utf-8', or_none=True)
if not path:
@@ -3846,6 +3869,7 @@ class PUTRequest(compat_urllib_request.Request):
return 'PUT'
+@partial_application
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None):
if get_attr:
if v is not None:
@@ -3872,6 +3896,7 @@ def str_to_int(int_str):
return int_or_none(int_str)
+@partial_application
def float_or_none(v, scale=1, invscale=1, default=None):
if v is None:
return default
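For example (assumed behaviour, not from the patch's tests), the decorator lets these numeric helpers be pre-configured and reused:

    to_seconds = float_or_none(scale=1000)   # positional `v` missing -> partial
    to_seconds(1234)                          # == 1.234
    float_or_none(None, scale=1000)           # == None, as before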
@@ -3906,38 +3931,46 @@ def parse_duration(s):
return None
s = s.strip()
+ if not s:
+ return None
days, hours, mins, secs, ms = [None] * 5
- m = re.match(r'(?:(?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?(?P<mins>[0-9]+):)?(?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?Z?$', s)
+ m = re.match(r'''(?x)
+ (?P<before_secs>
+ (?:(?:(?P<days>[0-9]+):)?(?P<hours>[0-9]+):)?
+ (?P<mins>[0-9]+):)?
+ (?P<secs>(?(before_secs)[0-9]{1,2}|[0-9]+))
+ (?:[.:](?P<ms>[0-9]+))?Z?$
+ ''', s)
if m:
- days, hours, mins, secs, ms = m.groups()
+ days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms')
else:
m = re.match(
r'''(?ix)(?:P?
(?:
- [0-9]+\s*y(?:ears?)?\s*
+ [0-9]+\s*y(?:ears?)?,?\s*
)?
(?:
- [0-9]+\s*m(?:onths?)?\s*
+ [0-9]+\s*m(?:onths?)?,?\s*
)?
(?:
- [0-9]+\s*w(?:eeks?)?\s*
+ [0-9]+\s*w(?:eeks?)?,?\s*
)?
(?:
- (?P<days>[0-9]+)\s*d(?:ays?)?\s*
+ (?P<days>[0-9]+)\s*d(?:ays?)?,?\s*
)?
T)?
(?:
- (?P<hours>[0-9]+)\s*h(?:ours?)?\s*
+ (?P<hours>[0-9]+)\s*h(?:(?:ou)?rs?)?,?\s*
)?
(?:
- (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?\s*
+ (?P<mins>[0-9]+)\s*m(?:in(?:ute)?s?)?,?\s*
)?
(?:
- (?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*s(?:ec(?:ond)?s?)?\s*
+ (?P<secs>[0-9]+)(?:\.(?P<ms>[0-9]+))?\s*s(?:ec(?:ond)?s?)?\s*
)?Z?$''', s)
if m:
- days, hours, mins, secs, ms = m.groups()
+ days, hours, mins, secs, ms = m.group('days', 'hours', 'mins', 'secs', 'ms')
else:
m = re.match(r'(?i)(?:(?P<hours>[0-9.]+)\s*(?:hours?)|(?P<mins>[0-9.]+)\s*(?:mins?\.?|minutes?)\s*)Z?$', s)
if m:
@@ -3945,33 +3978,32 @@ def parse_duration(s):
else:
return None
- duration = 0
- if secs:
- duration += float(secs)
- if mins:
- duration += float(mins) * 60
- if hours:
- duration += float(hours) * 60 * 60
- if days:
- duration += float(days) * 24 * 60 * 60
- if ms:
- duration += float(ms)
+ duration = (
+ ((((float(days) * 24) if days else 0)
+ + (float(hours) if hours else 0)) * 60
+ + (float(mins) if mins else 0)) * 60
+ + (float(secs) if secs else 0)
+ + (float(ms) / 10 ** len(ms) if ms else 0))
+
return duration
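Illustrative values accepted by the widened grammar (assumed, not taken from the patch's test suite):

    parse_duration('1:23:45')        # == 5025.0
    parse_duration('1:02:03.05')     # == 3723.05
    parse_duration('3 min, 4 sec')   # == 184.0  (comma separators now tolerated)
    parse_duration('')               # == None   (new empty-string early return)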
-def prepend_extension(filename, ext, expected_real_ext=None):
+def _change_extension(prepend, filename, ext, expected_real_ext=None):
name, real_ext = os.path.splitext(filename)
- return (
- '{0}.{1}{2}'.format(name, ext, real_ext)
- if not expected_real_ext or real_ext[1:] == expected_real_ext
- else '{0}.{1}'.format(filename, ext))
+ sanitize_extension = _UnsafeExtensionError.sanitize_extension
+ if not expected_real_ext or real_ext.partition('.')[0::2] == ('', expected_real_ext):
+ filename = name
+ if prepend and real_ext:
+ sanitize_extension(ext, prepend=prepend)
+ return ''.join((filename, '.', ext, real_ext))
-def replace_extension(filename, ext, expected_real_ext=None):
- name, real_ext = os.path.splitext(filename)
- return '{0}.{1}'.format(
- name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
- ext)
+ # Mitigate path traversal and file impersonation attacks
+ return '.'.join((filename, sanitize_extension(ext)))
+
+
+prepend_extension = functools.partial(_change_extension, True)
+replace_extension = functools.partial(_change_extension, False)
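A minimal sketch of the resulting behaviour (file names are made up):

    prepend_extension('video.webm', 'temp')         # -> 'video.temp.webm'
    replace_extension('video.webm', 'mp4')          # -> 'video.mp4'
    replace_extension('video.webm', 'mp4', 'mkv')   # -> 'video.webm.mp4' (real ext mismatch)
    replace_extension('video.webm', '../evil')      # raises _UnsafeExtensionError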
def check_executable(exe, args=[]):
@@ -4216,12 +4248,16 @@ def lowercase_escape(s):
s)
-def escape_rfc3986(s):
+def escape_rfc3986(s, safe=None):
"""Escape non-ASCII characters as suggested by RFC 3986"""
if sys.version_info < (3, 0):
s = _encode_compat_str(s, 'utf-8')
+ if safe is not None:
+ safe = _encode_compat_str(safe, 'utf-8')
+ if safe is None:
+ safe = b"%/;:@&=+$,!~*'()?#[]"
# ensure unicode: after quoting, it can always be converted
- return compat_str(compat_urllib_parse.quote(s, b"%/;:@&=+$,!~*'()?#[]"))
+ return compat_str(compat_urllib_parse.quote(s, safe))
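Hedged example of the new `safe` override (URLs are hypothetical):

    escape_rfc3986('http://example.com/a b?q=ä')   # -> 'http://example.com/a%20b?q=%C3%A4'
    escape_rfc3986('a/b c', safe='')               # -> 'a%2Fb%20c', nothing left unescaped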
def escape_url(url):
@@ -4259,6 +4295,7 @@ def urlencode_postdata(*args, **kargs):
return compat_urllib_parse_urlencode(*args, **kargs).encode('ascii')
+@partial_application
def update_url(url, **kwargs):
"""Replace URL components specified by kwargs
url: compat_str or parsed URL tuple
@@ -4280,6 +4317,7 @@ def update_url(url, **kwargs):
return compat_urllib_parse.urlunparse(url._replace(**kwargs))
+@partial_application
def update_url_query(url, query):
return update_url(url, query_update=query)
@@ -4706,30 +4744,45 @@ def parse_codecs(codecs_str):
if not codecs_str:
return {}
split_codecs = list(filter(None, map(
- lambda str: str.strip(), codecs_str.strip().strip(',').split(','))))
- vcodec, acodec = None, None
+ lambda s: s.strip(), codecs_str.strip().split(','))))
+ vcodec, acodec, hdr = None, None, None
for full_codec in split_codecs:
- codec = full_codec.split('.')[0]
- if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2', 'h263', 'h264', 'mp4v', 'hvc1', 'av01', 'theora'):
- if not vcodec:
- vcodec = full_codec
- elif codec in ('mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
+ codec, rest = full_codec.partition('.')[::2]
+ codec = codec.lower()
+ full_codec = '.'.join((codec, rest)) if rest else codec
+ codec = re.sub(r'0+(?=\d)', '', codec)
+ if codec in ('avc1', 'avc2', 'avc3', 'avc4', 'vp9', 'vp8', 'hev1', 'hev2',
+ 'h263', 'h264', 'mp4v', 'hvc1', 'av1', 'theora', 'dvh1', 'dvhe'):
+ if vcodec:
+ continue
+ vcodec = full_codec
+ if codec in ('dvh1', 'dvhe'):
+ hdr = 'DV'
+ elif codec in ('av1', 'vp9'):
+ n, m = {
+ 'av1': (2, '10'),
+ 'vp9': (0, '2'),
+ }[codec]
+ if (rest.split('.', n + 1)[n:] or [''])[0].lstrip('0') == m:
+ hdr = 'HDR10'
+ elif codec in ('flac', 'mp4a', 'opus', 'vorbis', 'mp3', 'aac', 'ac-4',
+ 'ac-3', 'ec-3', 'eac3', 'dtsc', 'dtse', 'dtsh', 'dtsl'):
if not acodec:
acodec = full_codec
else:
- write_string('WARNING: Unknown codec %s\n' % full_codec, sys.stderr)
- if not vcodec and not acodec:
- if len(split_codecs) == 2:
- return {
- 'vcodec': split_codecs[0],
- 'acodec': split_codecs[1],
- }
- else:
- return {
+ write_string('WARNING: Unknown codec %s\n' % (full_codec,), sys.stderr)
+
+ return (
+ filter_dict({
'vcodec': vcodec or 'none',
'acodec': acodec or 'none',
- }
- return {}
+ 'dynamic_range': hdr,
+ }) if vcodec or acodec
+ else {
+ 'vcodec': split_codecs[0],
+ 'acodec': split_codecs[1],
+ } if len(split_codecs) == 2
+ else {})
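Illustrative results (assumed, not from the patch's tests); `filter_dict` drops `dynamic_range` when no HDR format is detected:

    parse_codecs('dvhe.05.06, ec-3')
    # -> {'vcodec': 'dvhe.05.06', 'acodec': 'ec-3', 'dynamic_range': 'DV'}
    parse_codecs('vp09.02.10.10.01.09.16.09.01, opus')
    # -> {'vcodec': 'vp09.02.10.10.01.09.16.09.01', 'acodec': 'opus', 'dynamic_range': 'HDR10'}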
def urlhandle_detect_ext(url_handle):
@@ -6291,6 +6344,7 @@ def traverse_obj(obj, *paths, **kwargs):
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
- `any`-builtin: Take the first matching object and return it, resetting branching.
- `all`-builtin: Take all matching objects and return them as a list, resetting branching.
+ - `filter`-builtin: Return the value if it is truthy, `None` otherwise.
`tuple`, `list`, and `dict` all support nested paths and branches.
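A small assumed example of the new `filter` path element:

    traverse_obj({'title': 'abc'}, ('title', filter))    # -> 'abc'
    traverse_obj({'formats': []}, ('formats', filter))   # -> None, the empty list is falsy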
@@ -6332,6 +6386,11 @@ def traverse_obj(obj, *paths, **kwargs):
# instant compat
str = compat_str
+ from .compat import (
+ compat_builtins_dict as dict_, # the basic dict type
+ compat_dict as dict, # dict preserving insertion order
+ )
+
casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
if isinstance(expected_type, type):
@@ -6414,7 +6473,7 @@ def traverse_obj(obj, *paths, **kwargs):
if not branching: # string traversal
result = ''.join(result)
- elif isinstance(key, dict):
+ elif isinstance(key, dict_):
iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
result = dict((k, v if v is not None else default) for k, v in iter_obj
if v is not None or default is not NO_DEFAULT) or None
@@ -6492,7 +6551,7 @@ def traverse_obj(obj, *paths, **kwargs):
has_branched = False
key = None
- for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
+ for last, key in lazy_last(variadic(path, (str, bytes, dict_, set))):
if not casesense and isinstance(key, str):
key = compat_casefold(key)
@@ -6505,6 +6564,11 @@ def traverse_obj(obj, *paths, **kwargs):
objs = (list(filtered_objs),)
continue
+ # filter might be from __builtin__, future_builtins, or itertools.ifilter
+ if key in compat_filter_fns:
+ objs = filter(None, objs)
+ continue
+
if __debug__ and callable(key):
# Verify function signature
_try_bind_args(key, None, None)
@@ -6517,10 +6581,10 @@ def traverse_obj(obj, *paths, **kwargs):
objs = from_iterable(new_objs)
- if test_type and not isinstance(key, (dict, list, tuple)):
+ if test_type and not isinstance(key, (dict_, list, tuple)):
objs = map(type_test, objs)
- return objs, has_branched, isinstance(key, dict)
+ return objs, has_branched, isinstance(key, dict_)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
@@ -6543,6 +6607,76 @@ def traverse_obj(obj, *paths, **kwargs):
return None if default is NO_DEFAULT else default
+def value(value):
+ return lambda _: value
+
+
+class require(ExtractorError):
+ def __init__(self, name, expected=False):
+ super(require, self).__init__(
+ 'Unable to extract {0}'.format(name), expected=expected)
+
+ def __call__(self, value):
+ if value is None:
+ raise self
+
+ return value
+
+
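Direct illustrative calls (inputs are made up); both helpers are presumably meant as terminal steps of a traverse_obj path:

    value(42)('anything')                              # -> 42, ignoring its argument
    require('stream URL')('https://cdn.example.com')   # -> the value, unchanged
    require('stream URL')(None)                        # raises ExtractorError('Unable to extract stream URL')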
+@partial_application
+ # typing: (subs: list[dict], /, *, lang='und', ext=None) -> dict[str, list[dict]]
+def subs_list_to_dict(subs, lang='und', ext=None):
+ """
+ Convert subtitles from a traversal into a subtitle dict.
+ The path should have an `all` immediately before this function.
+
+ Arguments:
+ `lang` The default language tag for subtitle dicts with no
+ `lang` (`und`: undefined)
+ `ext` The default value for `ext` in the subtitle dicts
+
+ In the dict you can set the following additional items:
+ `id` The language tag to which the subtitle dict should be added
+ `quality` The sort order for each subtitle dict
+ """
+
+ result = collections.defaultdict(list)
+
+ for sub in subs:
+ tn_url = url_or_none(sub.pop('url', None))
+ if tn_url:
+ sub['url'] = tn_url
+ elif not sub.get('data'):
+ continue
+ sub_lang = sub.pop('id', None)
+ if not isinstance(sub_lang, compat_str):
+ if not lang:
+ continue
+ sub_lang = lang
+ sub_ext = sub.get('ext')
+ if not isinstance(sub_ext, compat_str):
+ if not ext:
+ sub.pop('ext', None)
+ else:
+ sub['ext'] = ext
+ result[sub_lang].append(sub)
+ result = dict(result)
+
+ for subs in result.values():
+ subs.sort(key=lambda x: x.pop('quality', 0) or 0)
+
+ return result
+
+
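A hedged example with invented subtitle entries; thanks to @partial_application, `subs_list_to_dict(lang='en')` can also be used as the final step of a traversal:

    subs_list_to_dict([
        {'url': 'https://example.com/en.vtt', 'id': 'en'},
        {'data': 'WEBVTT ...', 'ext': 'vtt'},
        {'id': 'de'},                             # no url/data: dropped
    ], lang='und')
    # -> {'en':  [{'url': 'https://example.com/en.vtt'}],
    #     'und': [{'data': 'WEBVTT ...', 'ext': 'vtt'}]}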
+def unpack(func, **kwargs):
+ """Make a function that applies `partial(func, **kwargs)` to its argument as *args"""
+ @functools.wraps(func)
+ def inner(items):
+ return func(*items, **kwargs)
+
+ return inner
+
+
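For instance (assumed usage), combining it with join_nonempty:

    unpack(join_nonempty, delim=' ')(['The', None, 'quick', 'fox'])   # -> 'The quick fox'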
def T(*x):
""" For use in yt-dl instead of {type, ...} or set((type, ...)) """
return set(x)
@@ -6561,3 +6695,169 @@ def join_nonempty(*values, **kwargs):
if from_dict is not None:
values = (traverse_obj(from_dict, variadic(v)) for v in values)
return delim.join(map(compat_str, filter(None, values)))
+
+
+class Namespace(object):
+ """Immutable namespace"""
+
+ def __init__(self, **kw_attr):
+ self.__dict__.update(kw_attr)
+
+ def __iter__(self):
+ return iter(self.__dict__.values())
+
+ @property
+ def items_(self):
+ return self.__dict__.items()
+
+
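Quick assumed illustration of the Namespace semantics:

    ns = Namespace(jpg='image/jpeg', png='image/png')
    ns.jpg           # -> 'image/jpeg'
    sorted(ns)       # -> ['image/jpeg', 'image/png'], iteration yields the values
    dict(ns.items_)  # -> {'jpg': 'image/jpeg', 'png': 'image/png'}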
+MEDIA_EXTENSIONS = Namespace(
+ common_video=('avi', 'flv', 'mkv', 'mov', 'mp4', 'webm'),
+ video=('3g2', '3gp', 'f4v', 'mk3d', 'divx', 'mpg', 'ogv', 'm4v', 'wmv'),
+ common_audio=('aiff', 'alac', 'flac', 'm4a', 'mka', 'mp3', 'ogg', 'opus', 'wav'),
+ audio=('aac', 'ape', 'asf', 'f4a', 'f4b', 'm4b', 'm4p', 'm4r', 'oga', 'ogx', 'spx', 'vorbis', 'wma', 'weba'),
+ thumbnails=('jpg', 'png', 'webp'),
+ # storyboards=('mhtml', ),
+ subtitles=('srt', 'vtt', 'ass', 'lrc', 'ttml'),
+ manifests=('f4f', 'f4m', 'm3u8', 'smil', 'mpd'),
+)
+MEDIA_EXTENSIONS.video = MEDIA_EXTENSIONS.common_video + MEDIA_EXTENSIONS.video
+MEDIA_EXTENSIONS.audio = MEDIA_EXTENSIONS.common_audio + MEDIA_EXTENSIONS.audio
+
+KNOWN_EXTENSIONS = (
+ MEDIA_EXTENSIONS.video + MEDIA_EXTENSIONS.audio
+ + MEDIA_EXTENSIONS.manifests
+)
+
+
+class _UnsafeExtensionError(Exception):
+ """
+ Mitigation exception for unwanted file overwrite/path traversal
+
+ Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j
+ """
+ _ALLOWED_EXTENSIONS = frozenset(itertools.chain(
+ ( # internal
+ 'description',
+ 'json',
+ 'meta',
+ 'orig',
+ 'part',
+ 'temp',
+ 'uncut',
+ 'unknown_video',
+ 'ytdl',
+ ),
+ # video
+ MEDIA_EXTENSIONS.video, (
+ 'asx',
+ 'ismv',
+ 'm2t',
+ 'm2ts',
+ 'm2v',
+ 'm4s',
+ 'mng',
+ 'mp2v',
+ 'mp4v',
+ 'mpe',
+ 'mpeg',
+ 'mpeg1',
+ 'mpeg2',
+ 'mpeg4',
+ 'mxf',
+ 'ogm',
+ 'qt',
+ 'rm',
+ 'swf',
+ 'ts',
+ 'vob',
+ 'vp9',
+ ),
+ # audio
+ MEDIA_EXTENSIONS.audio, (
+ '3ga',
+ 'ac3',
+ 'adts',
+ 'aif',
+ 'au',
+ 'dts',
+ 'isma',
+ 'it',
+ 'mid',
+ 'mod',
+ 'mpga',
+ 'mp1',
+ 'mp2',
+ 'mp4a',
+ 'mpa',
+ 'ra',
+ 'shn',
+ 'xm',
+ ),
+ # image
+ MEDIA_EXTENSIONS.thumbnails, (
+ 'avif',
+ 'bmp',
+ 'gif',
+ 'ico',
+ 'heic',
+ 'jng',
+ 'jpeg',
+ 'jxl',
+ 'svg',
+ 'tif',
+ 'tiff',
+ 'wbmp',
+ ),
+ # subtitle
+ MEDIA_EXTENSIONS.subtitles, (
+ 'dfxp',
+ 'fs',
+ 'ismt',
+ 'json3',
+ 'sami',
+ 'scc',
+ 'srv1',
+ 'srv2',
+ 'srv3',
+ 'ssa',
+ 'tt',
+ 'xml',
+ ),
+ # others
+ MEDIA_EXTENSIONS.manifests,
+ (
+ # not used in yt-dl
+ # *MEDIA_EXTENSIONS.storyboards,
+ # 'desktop',
+ # 'ism',
+ # 'm3u',
+ # 'sbv',
+ # 'swp',
+ # 'url',
+ # 'webloc',
+ )))
+
+ def __init__(self, extension):
+ super(_UnsafeExtensionError, self).__init__('unsafe file extension: {0!r}'.format(extension))
+ self.extension = extension
+
+ # support --no-check-extensions
+ lenient = False
+
+ @classmethod
+ def sanitize_extension(cls, extension, **kwargs):
+ # ... /, *, prepend=False
+ prepend = kwargs.get('prepend', False)
+
+ if '/' in extension or '\\' in extension:
+ raise cls(extension)
+
+ if not prepend:
+ last = extension.rpartition('.')[-1]
+ if last == 'bin':
+ extension = last = 'unknown_video'
+ if not (cls.lenient or last.lower() in cls._ALLOWED_EXTENSIONS):
+ raise cls(extension)
+
+ return extension
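Hedged examples of the sanitizer's behaviour (inputs are made up):

    sanitize = _UnsafeExtensionError.sanitize_extension
    sanitize('mp4')                  # -> 'mp4'
    sanitize('bin')                  # -> 'unknown_video', rewritten rather than rejected
    sanitize('temp', prepend=True)   # -> 'temp', prepended parts only need to avoid path separators
    sanitize('php')                  # raises _UnsafeExtensionError unless `lenient` is set
    sanitize('../etc/passwd')        # raises _UnsafeExtensionError, path separator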