aboutsummaryrefslogtreecommitdiff
path: root/yt_dlp/networking/_helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/networking/_helper.py')
-rw-r--r--yt_dlp/networking/_helper.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py
index a43c57bb4..4c9dbf25d 100644
--- a/yt_dlp/networking/_helper.py
+++ b/yt_dlp/networking/_helper.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
import functools
+import socket
import ssl
import sys
import typing
@@ -206,3 +207,59 @@ def wrap_request_errors(func):
e.handler = self
raise
return wrapper
+
+
+def _socket_connect(ip_addr, timeout, source_address):
+ af, socktype, proto, canonname, sa = ip_addr
+ sock = socket.socket(af, socktype, proto)
+ try:
+ if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+ sock.settimeout(timeout)
+ if source_address:
+ sock.bind(source_address)
+ sock.connect(sa)
+ return sock
+ except socket.error:
+ sock.close()
+ raise
+
+
+def create_connection(
+ address,
+ timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None,
+ *,
+ _create_socket_func=_socket_connect
+):
+ # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6.
+ # This filters the addresses based on the given source_address.
+ # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
+ host, port = address
+ ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
+ if not ip_addrs:
+ raise socket.error('getaddrinfo returns an empty list')
+ if source_address is not None:
+ af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6
+ ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
+ if not ip_addrs:
+ raise OSError(
+ f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
+ f'Can\'t use "{source_address[0]}" as source address')
+
+ err = None
+ for ip_addr in ip_addrs:
+ try:
+ sock = _create_socket_func(ip_addr, timeout, source_address)
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None
+ return sock
+ except socket.error as e:
+ err = e
+
+ try:
+ raise err
+ finally:
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None