diff options
| author | Andrei Lebedev <lebdron@gmail.com> | 2022-11-03 11:09:37 +0100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-03 10:09:37 +0000 | 
| commit | 27ed77aabba8c9eb08d66f34092b1bfcc22c482e (patch) | |
| tree | 7cc41fc5e398009a5cf8e7e4156afb0246aa34d3 /youtube_dl/utils.py | |
| parent | c4b19a88169fa76c5eb665d274e7270a0fe452c4 (diff) | |
[utils] Backport traverse_obj (etc) from yt-dlp (#31156)
* Backport traverse_obj and closely related functions from yt-dlp (code by pukkandan)
* Backport LazyList, variadic(), try_call (code by pukkandan)
* Recast using yt-dlp's newer traverse_obj() implementation and tests (code by grub4k)
* Add tests for Unicode case folding support matching Py3.5+ (requires f102e3d)
* Improve/add tests for variadic, try_call, join_nonempty
Co-authored-by: dirkf <fieldhouse@gmx.net>
Diffstat (limited to 'youtube_dl/utils.py')
| -rw-r--r-- | youtube_dl/utils.py | 339 | 
1 files changed, 339 insertions, 0 deletions
| diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 23a65a81c..e3c3ccff9 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -43,6 +43,7 @@ from .compat import (      compat_HTTPError,      compat_basestring,      compat_chr, +    compat_collections_abc,      compat_cookiejar,      compat_ctypes_WINFUNCTYPE,      compat_etree_fromstring, @@ -1685,6 +1686,7 @@ USER_AGENTS = {  NO_DEFAULT = object() +IDENTITY = lambda x: x  ENGLISH_MONTH_NAMES = [      'January', 'February', 'March', 'April', 'May', 'June', @@ -3867,6 +3869,105 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):          return unrecognized +class LazyList(compat_collections_abc.Sequence): +    """Lazy immutable list from an iterable +    Note that slices of a LazyList are lists and not LazyList""" + +    class IndexError(IndexError): +        def __init__(self, cause=None): +            if cause: +                # reproduce `raise from` +                self.__cause__ = cause +            super(IndexError, self).__init__() + +    def __init__(self, iterable, **kwargs): +        # kwarg-only +        reverse = kwargs.get('reverse', False) +        _cache = kwargs.get('_cache') + +        self._iterable = iter(iterable) +        self._cache = [] if _cache is None else _cache +        self._reversed = reverse + +    def __iter__(self): +        if self._reversed: +            # We need to consume the entire iterable to iterate in reverse +            for item in self.exhaust(): +                yield item +            return +        for item in self._cache: +            yield item +        for item in self._iterable: +            self._cache.append(item) +            yield item + +    def _exhaust(self): +        self._cache.extend(self._iterable) +        self._iterable = []  # Discard the emptied iterable to make it pickle-able +        return self._cache + +    def exhaust(self): +        """Evaluate the entire iterable""" +        return 
self._exhaust()[::-1 if self._reversed else 1] + +    @staticmethod +    def _reverse_index(x): +        return None if x is None else ~x + +    def __getitem__(self, idx): +        if isinstance(idx, slice): +            if self._reversed: +                idx = slice(self._reverse_index(idx.start), self._reverse_index(idx.stop), -(idx.step or 1)) +            start, stop, step = idx.start, idx.stop, idx.step or 1 +        elif isinstance(idx, int): +            if self._reversed: +                idx = self._reverse_index(idx) +            start, stop, step = idx, idx, 0 +        else: +            raise TypeError('indices must be integers or slices') +        if ((start or 0) < 0 or (stop or 0) < 0 +                or (start is None and step < 0) +                or (stop is None and step > 0)): +            # We need to consume the entire iterable to be able to slice from the end +            # Obviously, never use this with infinite iterables +            self._exhaust() +            try: +                return self._cache[idx] +            except IndexError as e: +                raise self.IndexError(e) +        n = max(start or 0, stop or 0) - len(self._cache) + 1 +        if n > 0: +            self._cache.extend(itertools.islice(self._iterable, n)) +        try: +            return self._cache[idx] +        except IndexError as e: +            raise self.IndexError(e) + +    def __bool__(self): +        try: +            self[-1] if self._reversed else self[0] +        except self.IndexError: +            return False +        return True + +    def __len__(self): +        self._exhaust() +        return len(self._cache) + +    def __reversed__(self): +        return type(self)(self._iterable, reverse=not self._reversed, _cache=self._cache) + +    def __copy__(self): +        return type(self)(self._iterable, reverse=self._reversed, _cache=self._cache) + +    def __repr__(self): +        # repr and str should mimic a list. 
So we exhaust the iterable +        return repr(self.exhaust()) + +    def __str__(self): +        return repr(self.exhaust()) + +  class PagedList(object):      def __len__(self):          # This is only useful for tests @@ -4092,6 +4193,10 @@ def multipart_encode(data, boundary=None):      return out, content_type +def variadic(x, allowed_types=(compat_str, bytes, dict)): +    return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,) + +  def dict_get(d, key_or_keys, default=None, skip_false_values=True):      if isinstance(key_or_keys, (list, tuple)):          for key in key_or_keys: @@ -4102,6 +4207,23 @@ def dict_get(d, key_or_keys, default=None, skip_false_values=True):      return d.get(key_or_keys, default) +def try_call(*funcs, **kwargs): + +    # parameter defaults +    expected_type = kwargs.get('expected_type') +    fargs = kwargs.get('args', []) +    fkwargs = kwargs.get('kwargs', {}) + +    for f in funcs: +        try: +            val = f(*fargs, **fkwargs) +        except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError): +            pass +        else: +            if expected_type is None or isinstance(val, expected_type): +                return val + +  def try_get(src, getter, expected_type=None):      if not isinstance(getter, (list, tuple)):          getter = [getter] @@ -5835,3 +5957,220 @@ def clean_podcast_url(url):                  st\.fm # https://podsights.com/docs/              )/e          )/''', '', url) + + +def traverse_obj(obj, *paths, **kwargs): +    """ +    Safely traverse nested `dict`s and `Sequence`s + +    >>> obj = [{}, {"key": "value"}] +    >>> traverse_obj(obj, (1, "key")) +    "value" + +    Each of the provided `paths` is tested and the first producing a valid result will be returned. +    The next path will also be tested if the path branched but no results could be found. +    Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. 
+    A value of None is treated as the absence of a value. + +    The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. + +    The keys in the path can be one of: +        - `None`:           Return the current object. +        - `str`/`int`:      Return `obj[key]`. For `re.Match, return `obj.group(key)`. +        - `slice`:          Branch out and return all values in `obj[key]`. +        - `Ellipsis`:       Branch out and return a list of all values. +        - `tuple`/`list`:   Branch out and return a list of all matching values. +                            Read as: `[traverse_obj(obj, branch) for branch in branches]`. +        - `function`:       Branch out and return values filtered by the function. +                            Read as: `[value for key, value in obj if function(key, value)]`. +                            For `Sequence`s, `key` is the index of the value. +        - `dict`            Transform the current object and return a matching dict. +                            Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + +        `tuple`, `list`, and `dict` all support nested paths and branches. + +    @params paths           Paths which to traverse by. +    Keyword arguments: +    @param default          Value to return if the paths do not match. +    @param expected_type    If a `type`, only accept final values of this type. +                            If any other callable, try to call the function on each result. +    @param get_all          If `False`, return the first matching result, otherwise all matching ones. +    @param casesense        If `False`, consider string dictionary keys as case insensitive. + +    The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + +    @param _is_user_input    Whether the keys are generated from user input. +                            If `True` strings get converted to `int`/`slice` if needed. 
+    @param _traverse_string  Whether to traverse into objects as strings. +                            If `True`, any non-compatible object will first be +                            converted into a string and then traversed into. + + +    @returns                The result of the object traversal. +                            If successful, `get_all=True`, and the path branches at least once, +                            then a list of results is returned instead. +                            A list is always returned if the last path branches and no `default` is given. +    """ + +    # parameter defaults +    default = kwargs.get('default', NO_DEFAULT) +    expected_type = kwargs.get('expected_type') +    get_all = kwargs.get('get_all', True) +    casesense = kwargs.get('casesense', True) +    _is_user_input = kwargs.get('_is_user_input', False) +    _traverse_string = kwargs.get('_traverse_string', False) + +    # instant compat +    str = compat_str + +    is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes)) +    # stand-in until compat_re_Match is added +    compat_re_Match = type(re.match('a', 'a')) +    # stand-in until casefold.py is added +    try: +        ''.casefold() +        compat_casefold = lambda s: s.casefold() +    except AttributeError: +        compat_casefold = lambda s: s.lower() +    casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k + +    if isinstance(expected_type, type): +        type_test = lambda val: val if isinstance(val, expected_type) else None +    else: +        type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) + +    def from_iterable(iterables): +        # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F +        for it in iterables: +            for item in it: +                yield item + +    def apply_key(key, obj): +        if obj is None: +            return + +        elif key is None: +            yield obj + +        
elif isinstance(key, (list, tuple)): +            for branch in key: +                _, result = apply_path(obj, branch) +                for item in result: +                    yield item + +        elif key is Ellipsis: +            result = [] +            if isinstance(obj, compat_collections_abc.Mapping): +                result = obj.values() +            elif is_sequence(obj): +                result = obj +            elif isinstance(obj, compat_re_Match): +                result = obj.groups() +            elif _traverse_string: +                result = str(obj) +            for item in result: +                yield item + +        elif callable(key): +            if is_sequence(obj): +                iter_obj = enumerate(obj) +            elif isinstance(obj, compat_collections_abc.Mapping): +                iter_obj = obj.items() +            elif isinstance(obj, compat_re_Match): +                iter_obj = enumerate(itertools.chain([obj.group()], obj.groups())) +            elif _traverse_string: +                iter_obj = enumerate(str(obj)) +            else: +                return +            for item in (v for k, v in iter_obj if try_call(key, args=(k, v))): +                yield item + +        elif isinstance(key, dict): +            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) +            yield dict((k, v if v is not None else default) for k, v in iter_obj +                       if v is not None or default is not NO_DEFAULT) + +        elif isinstance(obj, compat_collections_abc.Mapping): +            yield (obj.get(key) if casesense or (key in obj) +                   else next((v for k, v in obj.items() if casefold(k) == key), None)) + +        elif isinstance(obj, compat_re_Match): +            if isinstance(key, int) or casesense: +                try: +                    yield obj.group(key) +                    return +                except IndexError: +                    pass +            if not 
isinstance(key, str): +                return + +            yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + +        else: +            if _is_user_input: +                key = (int_or_none(key) if ':' not in key +                       else slice(*map(int_or_none, key.split(':')))) + +            if not isinstance(key, (int, slice)): +                return + +            if not is_sequence(obj): +                if not _traverse_string: +                    return +                obj = str(obj) + +            try: +                yield obj[key] +            except IndexError: +                pass + +    def apply_path(start_obj, path): +        objs = (start_obj,) +        has_branched = False + +        for key in variadic(path): +            if _is_user_input and key == ':': +                key = Ellipsis + +            if not casesense and isinstance(key, str): +                key = compat_casefold(key) + +            if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key): +                has_branched = True + +            key_func = functools.partial(apply_key, key) +            objs = from_iterable(map(key_func, objs)) + +        return has_branched, objs + +    def _traverse_obj(obj, path, use_list=True): +        has_branched, results = apply_path(obj, path) +        results = LazyList(x for x in map(type_test, results) if x is not None) + +        if get_all and has_branched: +            return results.exhaust() if results or use_list else None + +        return results[0] if results else None + +    for index, path in enumerate(paths, 1): +        use_list = default is NO_DEFAULT and index == len(paths) +        result = _traverse_obj(obj, path, use_list) +        if result is not None: +            return result + +    return None if default is NO_DEFAULT else default + + +def get_first(obj, keys, **kwargs): +    return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, 
**kwargs) + + +def join_nonempty(*values, **kwargs): + +    # parameter defaults +    delim = kwargs.get('delim', '-') +    from_dict = kwargs.get('from_dict') + +    if from_dict is not None: +        values = (traverse_obj(from_dict, variadic(v)) for v in values) +    return delim.join(map(compat_str, filter(None, values))) | 
