aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_socks.py38
-rw-r--r--yt_dlp/networking/_helper.py57
-rw-r--r--yt_dlp/networking/_urllib.py68
-rw-r--r--yt_dlp/socks.py31
4 files changed, 110 insertions, 84 deletions
diff --git a/test/test_socks.py b/test/test_socks.py
index 95ffce275..211ee814d 100644
--- a/test/test_socks.py
+++ b/test/test_socks.py
@@ -281,17 +281,13 @@ class TestSocks4Proxy:
rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='socks4a implementation currently broken when destination is not a domain name'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
assert response['version'] == 4
- assert response['ipv4_address'] == '127.0.0.1'
- assert response['domain_address'] is None
+ assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx):
@@ -302,10 +298,7 @@ class TestSocks4Proxy:
assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost'
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='source_address is not yet supported for socks4 proxies'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
@@ -327,10 +320,7 @@ class TestSocks4Proxy:
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='IPv6 socks4 proxies are not yet supported'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
@@ -342,7 +332,7 @@ class TestSocks4Proxy:
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
- with handler(proxies={'all': f'socks4://{server_address}'}, timeout=1) as rh:
+ with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
with pytest.raises(TransportError):
ctx.socks_info_request(rh)
@@ -383,7 +373,7 @@ class TestSocks5Proxy:
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='localhost')
- assert response['ipv4_address'] == '127.0.0.1'
+ assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
@@ -404,22 +394,15 @@ class TestSocks5Proxy:
assert response['domain_address'] is None
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='IPv6 destination addresses are not yet supported'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='[::1]')
assert response['ipv6_address'] == '::1'
- assert response['port'] == 80
assert response['version'] == 5
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='IPv6 socks5 proxies are not yet supported'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
@@ -430,10 +413,7 @@ class TestSocks5Proxy:
# XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test...
- @pytest.mark.parametrize('handler,ctx', [
- pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
- reason='source_address is not yet supported for socks5 proxies'))
- ], indirect=True)
+ @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py
index a43c57bb4..4c9dbf25d 100644
--- a/yt_dlp/networking/_helper.py
+++ b/yt_dlp/networking/_helper.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import contextlib
import functools
+import socket
import ssl
import sys
import typing
@@ -206,3 +207,59 @@ def wrap_request_errors(func):
e.handler = self
raise
return wrapper
+
+
+def _socket_connect(ip_addr, timeout, source_address):
+ af, socktype, proto, canonname, sa = ip_addr
+ sock = socket.socket(af, socktype, proto)
+ try:
+ if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
+ sock.settimeout(timeout)
+ if source_address:
+ sock.bind(source_address)
+ sock.connect(sa)
+ return sock
+ except socket.error:
+ sock.close()
+ raise
+
+
+def create_connection(
+ address,
+ timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None,
+ *,
+ _create_socket_func=_socket_connect
+):
+ # Work around socket.create_connection() which tries all addresses from getaddrinfo() including IPv6.
+ # This filters the addresses based on the given source_address.
+ # Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810
+ host, port = address
+ ip_addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
+ if not ip_addrs:
+ raise socket.error('getaddrinfo returns an empty list')
+ if source_address is not None:
+ af = socket.AF_INET if ':' not in source_address[0] else socket.AF_INET6
+ ip_addrs = [addr for addr in ip_addrs if addr[0] == af]
+ if not ip_addrs:
+ raise OSError(
+ f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
+ f'Can\'t use "{source_address[0]}" as source address')
+
+ err = None
+ for ip_addr in ip_addrs:
+ try:
+ sock = _create_socket_func(ip_addr, timeout, source_address)
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None
+ return sock
+ except socket.error as e:
+ err = e
+
+ try:
+ raise err
+ finally:
+ # Explicitly break __traceback__ reference cycle
+ # https://bugs.python.org/issue36820
+ err = None
diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py
index 3c0647ecf..c327f7744 100644
--- a/yt_dlp/networking/_urllib.py
+++ b/yt_dlp/networking/_urllib.py
@@ -23,6 +23,7 @@ from urllib.request import (
from ._helper import (
InstanceStoreMixin,
add_accept_encoding_header,
+ create_connection,
get_redirect_method,
make_socks_proxy_opts,
select_proxy,
@@ -54,44 +55,10 @@ if brotli:
def _create_http_connection(http_class, source_address, *args, **kwargs):
hc = http_class(*args, **kwargs)
+ if hasattr(hc, '_create_connection'):
+ hc._create_connection = create_connection
+
if source_address is not None:
- # This is to workaround _create_connection() from socket where it will try all
- # address data from getaddrinfo() including IPv6. This filters the result from
- # getaddrinfo() based on the source_address value.
- # This is based on the cpython socket.create_connection() function.
- # https://github.com/python/cpython/blob/master/Lib/socket.py#L691
- def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
- host, port = address
- err = None
- addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
- af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
- ip_addrs = [addr for addr in addrs if addr[0] == af]
- if addrs and not ip_addrs:
- ip_version = 'v4' if af == socket.AF_INET else 'v6'
- raise OSError(
- "No remote IP%s addresses available for connect, can't use '%s' as source address"
- % (ip_version, source_address[0]))
- for res in ip_addrs:
- af, socktype, proto, canonname, sa = res
- sock = None
- try:
- sock = socket.socket(af, socktype, proto)
- if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
- sock.settimeout(timeout)
- sock.bind(source_address)
- sock.connect(sa)
- err = None # Explicitly break reference cycle
- return sock
- except OSError as _:
- err = _
- if sock is not None:
- sock.close()
- if err is not None:
- raise err
- else:
- raise OSError('getaddrinfo returns an empty list')
- if hasattr(hc, '_create_connection'):
- hc._create_connection = _create_connection
hc.source_address = (source_address, 0)
return hc
@@ -220,13 +187,28 @@ def make_socks_conn_class(base_class, socks_proxy):
proxy_args = make_socks_proxy_opts(socks_proxy)
class SocksConnection(base_class):
- def connect(self):
- self.sock = sockssocket()
- self.sock.setproxy(**proxy_args)
- if type(self.timeout) in (int, float): # noqa: E721
- self.sock.settimeout(self.timeout)
- self.sock.connect((self.host, self.port))
+ _create_connection = create_connection
+ def connect(self):
+ def sock_socket_connect(ip_addr, timeout, source_address):
+ af, socktype, proto, canonname, sa = ip_addr
+ sock = sockssocket(af, socktype, proto)
+ try:
+ connect_proxy_args = proxy_args.copy()
+ connect_proxy_args.update({'addr': sa[0], 'port': sa[1]})
+ sock.setproxy(**connect_proxy_args)
+ if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721
+ sock.settimeout(timeout)
+ if source_address:
+ sock.bind(source_address)
+ sock.connect((self.host, self.port))
+ return sock
+ except socket.error:
+ sock.close()
+ raise
+ self.sock = create_connection(
+ (proxy_args['addr'], proxy_args['port']), timeout=self.timeout,
+ source_address=self.source_address, _create_socket_func=sock_socket_connect)
if isinstance(self, http.client.HTTPSConnection):
self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)
diff --git a/yt_dlp/socks.py b/yt_dlp/socks.py
index f93328f63..e7f41d7e2 100644
--- a/yt_dlp/socks.py
+++ b/yt_dlp/socks.py
@@ -134,26 +134,31 @@ class sockssocket(socket.socket):
self.close()
raise InvalidVersionError(expected_version, got_version)
- def _resolve_address(self, destaddr, default, use_remote_dns):
- try:
- return socket.inet_aton(destaddr)
- except OSError:
- if use_remote_dns and self._proxy.remote_dns:
- return default
- else:
- return socket.inet_aton(socket.gethostbyname(destaddr))
+ def _resolve_address(self, destaddr, default, use_remote_dns, family=None):
+ for f in (family,) if family else (socket.AF_INET, socket.AF_INET6):
+ try:
+ return f, socket.inet_pton(f, destaddr)
+ except OSError:
+ continue
+
+ if use_remote_dns and self._proxy.remote_dns:
+ return 0, default
+ else:
+ res = socket.getaddrinfo(destaddr, None, family=family or 0)
+ f, _, _, _, ipaddr = res[0]
+ return f, socket.inet_pton(f, ipaddr[0])
def _setup_socks4(self, address, is_4a=False):
destaddr, port = address
- ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a)
+ _, ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a, family=socket.AF_INET)
packet = struct.pack('!BBH', SOCKS4_VERSION, Socks4Command.CMD_CONNECT, port) + ipaddr
username = (self._proxy.username or '').encode()
packet += username + b'\x00'
- if is_4a and self._proxy.remote_dns:
+ if is_4a and self._proxy.remote_dns and ipaddr == SOCKS4_DEFAULT_DSTIP:
packet += destaddr.encode() + b'\x00'
self.sendall(packet)
@@ -210,7 +215,7 @@ class sockssocket(socket.socket):
def _setup_socks5(self, address):
destaddr, port = address
- ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)
+ family, ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)
self._socks5_auth()
@@ -220,8 +225,10 @@ class sockssocket(socket.socket):
destaddr = destaddr.encode()
packet += struct.pack('!B', Socks5AddressType.ATYP_DOMAINNAME)
packet += self._len_and_data(destaddr)
- else:
+ elif family == socket.AF_INET:
packet += struct.pack('!B', Socks5AddressType.ATYP_IPV4) + ipaddr
+ elif family == socket.AF_INET6:
+ packet += struct.pack('!B', Socks5AddressType.ATYP_IPV6) + ipaddr
packet += struct.pack('!H', port)
self.sendall(packet)