diff options
Diffstat (limited to 'youtube_dl/utils.py')
| -rw-r--r-- | youtube_dl/utils.py | 199 | 
1 files changed, 144 insertions, 55 deletions
| diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 0d0bbe8f6..d920c65a4 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -24,6 +24,7 @@ import socket  import struct  import subprocess  import sys +import tempfile  import traceback  import xml.etree.ElementTree  import zlib @@ -191,6 +192,13 @@ try:  except ImportError:  # Python 2.6      from xml.parsers.expat import ExpatError as compat_xml_parse_error +try: +    from shlex import quote as shlex_quote +except ImportError:  # Python < 3.3 +    def shlex_quote(s): +        return "'" + s.replace("'", "'\"'\"'") + "'" + +  def compat_ord(c):      if type(c) is int: return c      else: return ord(c) @@ -228,18 +236,42 @@ else:          assert type(s) == type(u'')          print(s) -# In Python 2.x, json.dump expects a bytestream. -# In Python 3.x, it writes to a character stream -if sys.version_info < (3,0): -    def write_json_file(obj, fn): -        with open(fn, 'wb') as f: -            json.dump(obj, f) -else: -    def write_json_file(obj, fn): -        with open(fn, 'w', encoding='utf-8') as f: -            json.dump(obj, f) -if sys.version_info >= (2,7): +def write_json_file(obj, fn): +    """ Encode obj as JSON and write it to fn, atomically """ + +    args = { +        'suffix': '.tmp', +        'prefix': os.path.basename(fn) + '.', +        'dir': os.path.dirname(fn), +        'delete': False, +    } + +    # In Python 2.x, json.dump expects a bytestream. +    # In Python 3.x, it writes to a character stream +    if sys.version_info < (3, 0): +        args['mode'] = 'wb' +    else: +        args.update({ +            'mode': 'w', +            'encoding': 'utf-8', +        }) + +    tf = tempfile.NamedTemporaryFile(**args) + +    try: +        with tf: +            json.dump(obj, tf) +        os.rename(tf.name, fn) +    except: +        try: +            os.remove(tf.name) +        except OSError: +            pass +        raise + + +if sys.version_info >= (2, 7):      def find_xpath_attr(node, xpath, key, val):          """ Find the xpath xpath[@key=val] """          assert re.match(r'^[a-zA-Z-]+$', key) @@ -266,30 +298,6 @@ def xpath_with_ns(path, ns_map):              replaced.append('{%s}%s' % (ns_map[ns], tag))      return '/'.join(replaced) -def htmlentity_transform(matchobj): -    """Transforms an HTML entity to a character. - -    This function receives a match object and is intended to be used with -    the re.sub() function. -    """ -    entity = matchobj.group(1) - -    # Known non-numeric HTML entity -    if entity in compat_html_entities.name2codepoint: -        return compat_chr(compat_html_entities.name2codepoint[entity]) - -    mobj = re.match(u'(?u)#(x?\\d+)', entity) -    if mobj is not None: -        numstr = mobj.group(1) -        if numstr.startswith(u'x'): -            base = 16 -            numstr = u'0%s' % numstr -        else: -            base = 10 -        return compat_chr(int(numstr, base)) - -    # Unknown entity in name, return its literal representation -    return (u'&%s;' % entity)  compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix  class BaseHTMLParser(compat_html_parser.HTMLParser): @@ -511,13 +519,33 @@ def orderedSet(iterable):      return res +def _htmlentity_transform(entity): +    """Transforms an HTML entity to a character.""" +    # Known non-numeric HTML entity +    if entity in compat_html_entities.name2codepoint: +        return compat_chr(compat_html_entities.name2codepoint[entity]) + +    mobj = re.match(r'#(x?[0-9]+)', entity) +    if mobj is not None: +        numstr = mobj.group(1) +        if numstr.startswith(u'x'): +            base = 16 +            numstr = u'0%s' % numstr +        else: +            base = 10 +        return compat_chr(int(numstr, base)) + +    # Unknown entity in name, return its literal representation +    return (u'&%s;' % entity) + +  def unescapeHTML(s):      if s is None:          return None      assert type(s) == compat_str -    result = re.sub(r'(?u)&(.+?);', htmlentity_transform, s) -    return result +    return re.sub( +        r'&([^;]+);', lambda m: _htmlentity_transform(m.group(1)), s)  def encodeFilename(s, for_subprocess=False): @@ -589,7 +617,7 @@ def make_HTTPS_handler(opts_no_check_certificate, **kwargs):                      self.sock = sock                      self._tunnel()                  try: -                    self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_SSLv3) +                    self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_TLSv1)                  except ssl.SSLError:                      self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, ssl_version=ssl.PROTOCOL_SSLv23) @@ -597,8 +625,14 @@ def make_HTTPS_handler(opts_no_check_certificate, **kwargs):              def https_open(self, req):                  return self.do_open(HTTPSConnectionV3, req)          return HTTPSHandlerV3(**kwargs) -    else: -        context = ssl.SSLContext(ssl.PROTOCOL_SSLv3) +    elif hasattr(ssl, 'create_default_context'):  # Python >= 3.4 +        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +        context.options &= ~ssl.OP_NO_SSLv3  # Allow older, not-as-secure SSLv3 +        if opts_no_check_certificate: +            context.verify_mode = ssl.CERT_NONE +        return compat_urllib_request.HTTPSHandler(context=context, **kwargs) +    else:  # Python < 3.4 +        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)          context.verify_mode = (ssl.CERT_NONE                                 if opts_no_check_certificate                                 else ssl.CERT_REQUIRED) @@ -734,10 +768,9 @@ class YoutubeDLHandler(compat_urllib_request.HTTPHandler):          return ret      def http_request(self, req): -        for h,v in std_headers.items(): -            if h in req.headers: -                del req.headers[h] -            req.add_header(h, v) +        for h, v in std_headers.items(): +            if h not in req.headers: +                req.add_header(h, v)          if 'Youtubedl-no-compression' in req.headers:              if 'Accept-encoding' in req.headers:                  del req.headers['Accept-encoding'] @@ -827,8 +860,10 @@ def unified_strdate(date_str):          '%b %dnd %Y %I:%M%p',          '%b %dth %Y %I:%M%p',          '%Y-%m-%d', +        '%Y/%m/%d',          '%d.%m.%Y',          '%d/%m/%Y', +        '%d/%m/%y',          '%Y/%m/%d %H:%M:%S',          '%Y-%m-%d %H:%M:%S',          '%d.%m.%Y %H:%M', @@ -852,6 +887,8 @@ def unified_strdate(date_str):      return upload_date  def determine_ext(url, default_ext=u'unknown_video'): +    if url is None: +        return default_ext      guess = url.partition(u'?')[0].rpartition(u'.')[2]      if re.match(r'^[A-Za-z0-9]+$', guess):          return guess @@ -1045,12 +1082,6 @@ def intlist_to_bytes(xs):          return bytes(xs) -def get_cachedir(params={}): -    cache_root = os.environ.get('XDG_CACHE_HOME', -                                os.path.expanduser('~/.cache')) -    return params.get('cachedir', os.path.join(cache_root, 'youtube-dl')) - -  # Cross-platform file locking  if sys.platform == 'win32':      import ctypes.wintypes @@ -1110,10 +1141,10 @@ else:      import fcntl      def _lock_file(f, exclusive): -        fcntl.lockf(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH) +        fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH)      def _unlock_file(f): -        fcntl.lockf(f, fcntl.LOCK_UN) +        fcntl.flock(f, fcntl.LOCK_UN)  class locked_file(object): @@ -1257,6 +1288,12 @@ def remove_start(s, start):      return s +def remove_end(s, end): +    if s.endswith(end): +        return s[:-len(end)] +    return s + +  def url_basename(url):      path = compat_urlparse.urlparse(url).path      return path.strip(u'/').split(u'/')[-1] @@ -1271,13 +1308,20 @@ def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1):      if get_attr:          if v is not None:              v = getattr(v, get_attr, None) +    if v == '': +        v = None      return default if v is None else (int(v) * invscale // scale) +def str_or_none(v, default=None): +    return default if v is None else compat_str(v) + +  def str_to_int(int_str): +    """ A more relaxed version of int_or_none """      if int_str is None:          return None -    int_str = re.sub(r'[,\.]', u'', int_str) +    int_str = re.sub(r'[,\.\+]', u'', int_str)      return int(int_str) @@ -1289,8 +1333,10 @@ def parse_duration(s):      if s is None:          return None +    s = s.strip() +      m = re.match( -        r'(?:(?:(?P<hours>[0-9]+)[:h])?(?P<mins>[0-9]+)[:m])?(?P<secs>[0-9]+)s?(?::[0-9]+)?$', s) +        r'(?i)(?:(?:(?P<hours>[0-9]+)\s*(?:[:h]|hours?)\s*)?(?P<mins>[0-9]+)\s*(?:[:m]|mins?|minutes?)\s*)?(?P<secs>[0-9]+)(?P<ms>\.[0-9]+)?\s*(?:s|secs?|seconds?)?$', s)      if not m:          return None      res = int(m.group('secs')) @@ -1298,6 +1344,8 @@ def parse_duration(s):          res += int(m.group('mins')) * 60          if m.group('hours'):              res += int(m.group('hours')) * 60 * 60 +    if m.group('ms'): +        res += float(m.group('ms'))      return res @@ -1408,6 +1456,12 @@ def urlencode_postdata(*args, **kargs):      return compat_urllib_parse.urlencode(*args, **kargs).encode('ascii') +try: +    etree_iter = xml.etree.ElementTree.Element.iter +except AttributeError:  # Python <=2.6 +    etree_iter = lambda n: n.findall('.//*') + +  def parse_xml(s):      class TreeBuilder(xml.etree.ElementTree.TreeBuilder):          def doctype(self, name, pubid, system): @@ -1415,7 +1469,14 @@ def parse_xml(s):      parser = xml.etree.ElementTree.XMLParser(target=TreeBuilder())      kwargs = {'parser': parser} if sys.version_info >= (2, 7) else {} -    return xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs) +    tree = xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs) +    # Fix up XML parser in Python 2.x +    if sys.version_info < (3, 0): +        for n in etree_iter(tree): +            if n.text is not None: +                if not isinstance(n.text, compat_str): +                    n.text = n.text.decode('utf-8') +    return tree  if sys.version_info < (3, 0) and sys.platform == 'win32': @@ -1440,6 +1501,34 @@ def strip_jsonp(code):      return re.sub(r'(?s)^[a-zA-Z0-9_]+\s*\(\s*(.*)\);?\s*?\s*$', r'\1', code) +def js_to_json(code): +    def fix_kv(m): +        key = m.group(2) +        if key.startswith("'"): +            assert key.endswith("'") +            assert '"' not in key +            key = '"%s"' % key[1:-1] +        elif not key.startswith('"'): +            key = '"%s"' % key + +        value = m.group(4) +        if value.startswith("'"): +            assert value.endswith("'") +            assert '"' not in value +            value = '"%s"' % value[1:-1] + +        return m.group(1) + key + m.group(3) + value + +    res = re.sub(r'''(?x) +            ([{,]\s*) +            ("[^"]*"|\'[^\']*\'|[a-z0-9A-Z]+) +            (:\s*) +            ([0-9.]+|true|false|"[^"]*"|\'[^\']*\'|\[|\{) +        ''', fix_kv, code) +    res = re.sub(r',(\s*\])', lambda m: m.group(1), res) +    return res + +  def qualities(quality_ids):      """ Get a numeric quality value out of a list of possible values """      def q(qid): | 
