diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/conftest.py | 5 | ||||
-rw-r--r-- | test/test_networking.py | 79 | ||||
-rw-r--r-- | test/test_socks.py | 62 | ||||
-rw-r--r-- | test/test_websockets.py | 380 |
4 files changed, 483 insertions, 43 deletions
diff --git a/test/conftest.py b/test/conftest.py index 15549d30b..2fbc269e1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -19,3 +19,8 @@ def handler(request): pytest.skip(f'{RH_KEY} request handler is not available') return functools.partial(handler, logger=FakeLogger) + + +def validate_and_send(rh, req): + rh.validate(req) + return rh.send(req) diff --git a/test/test_networking.py b/test/test_networking.py index 4466fc048..64af6e459 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -52,6 +52,8 @@ from yt_dlp.networking.exceptions import ( from yt_dlp.utils._utils import _YDLLogger as FakeLogger from yt_dlp.utils.networking import HTTPHeaderDict +from test.conftest import validate_and_send + TEST_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -275,11 +277,6 @@ class HTTPTestRequestHandler(http.server.BaseHTTPRequestHandler): self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode()) -def validate_and_send(rh, req): - rh.validate(req) - return rh.send(req) - - class TestRequestHandlerBase: @classmethod def setup_class(cls): @@ -872,8 +869,9 @@ class TestRequestsRequestHandler(TestRequestHandlerBase): ]) @pytest.mark.parametrize('handler', ['Requests'], indirect=True) def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): - from urllib3.response import HTTPResponse as Urllib3Response from requests.models import Response as RequestsResponse + from urllib3.response import HTTPResponse as Urllib3Response + from yt_dlp.networking._requests import RequestsResponseAdapter requests_res = RequestsResponse() requests_res.raw = Urllib3Response(body=b'', status=200) @@ -929,13 +927,17 @@ class TestRequestHandlerValidation: ('http', False, {}), ('https', False, {}), ]), + ('Websockets', [ + ('ws', False, {}), + ('wss', False, {}), + ]), (NoCheckRH, [('http', False, {})]), (ValidationRH, [('http', UnsupportedRequest, {})]) ] PROXY_SCHEME_TESTS = [ # scheme, expected to fail - ('Urllib', [ + ('Urllib', 'http', [ ('http', False), ('https', UnsupportedRequest), ('socks4', False), @@ -944,7 +946,7 @@ class TestRequestHandlerValidation: ('socks5h', False), ('socks', UnsupportedRequest), ]), - ('Requests', [ + ('Requests', 'http', [ ('http', False), ('https', False), ('socks4', False), @@ -952,8 +954,11 @@ class TestRequestHandlerValidation: ('socks5', False), ('socks5h', False), ]), - (NoCheckRH, [('http', False)]), - (HTTPSupportedRH, [('http', UnsupportedRequest)]), + (NoCheckRH, 'http', [('http', False)]), + (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), + ('Websockets', 'ws', [('http', UnsupportedRequest)]), + (NoCheckRH, 'http', [('http', False)]), + (HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]), ] PROXY_KEY_TESTS = [ @@ -972,7 +977,7 @@ class TestRequestHandlerValidation: ] EXTENSION_TESTS = [ - ('Urllib', [ + ('Urllib', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': YoutubeDLCookieJar()}, False), ({'cookiejar': CookieJar()}, AssertionError), @@ -980,17 +985,21 @@ class TestRequestHandlerValidation: ({'timeout': 'notatimeout'}, AssertionError), ({'unsupported': 'value'}, UnsupportedRequest), ]), - ('Requests', [ + ('Requests', 'http', [ ({'cookiejar': 'notacookiejar'}, AssertionError), ({'cookiejar': YoutubeDLCookieJar()}, False), ({'timeout': 1}, False), ({'timeout': 'notatimeout'}, AssertionError), ({'unsupported': 'value'}, UnsupportedRequest), ]), - (NoCheckRH, [ + (NoCheckRH, 'http', [ ({'cookiejar': 'notacookiejar'}, False), ({'somerandom': 'test'}, False), # but any extension is allowed through ]), + ('Websockets', 'ws', [ + ({'cookiejar': YoutubeDLCookieJar()}, False), + ({'timeout': 2}, False), + ]), ] @pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [ @@ -1016,14 +1025,14 @@ class TestRequestHandlerValidation: run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'})) run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'}) - @pytest.mark.parametrize('handler,scheme,fail', [ - (handler_tests[0], scheme, fail) + @pytest.mark.parametrize('handler,req_scheme,scheme,fail', [ + (handler_tests[0], handler_tests[1], scheme, fail) for handler_tests in PROXY_SCHEME_TESTS - for scheme, fail in handler_tests[1] + for scheme, fail in handler_tests[2] ], indirect=['handler']) - def test_proxy_scheme(self, handler, scheme, fail): - run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'})) - run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'}) + def test_proxy_scheme(self, handler, req_scheme, scheme, fail): + run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'})) + run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'}) @pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True) def test_empty_proxy(self, handler): @@ -1035,14 +1044,14 @@ class TestRequestHandlerValidation: def test_invalid_proxy_url(self, handler, proxy_url): run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url})) - @pytest.mark.parametrize('handler,extensions,fail', [ - (handler_tests[0], extensions, fail) + @pytest.mark.parametrize('handler,scheme,extensions,fail', [ + (handler_tests[0], handler_tests[1], extensions, fail) for handler_tests in EXTENSION_TESTS - for extensions, fail in handler_tests[1] + for extensions, fail in handler_tests[2] ], indirect=['handler']) - def test_extension(self, handler, extensions, fail): + def test_extension(self, handler, scheme, extensions, fail): run_validation( - handler, fail, Request('http://', extensions=extensions)) + handler, fail, Request(f'{scheme}://', extensions=extensions)) def test_invalid_request_type(self): rh = self.ValidationRH(logger=FakeLogger()) @@ -1075,6 +1084,22 @@ class FakeRHYDL(FakeYDL): self._request_director = self.build_request_director([FakeRH]) +class AllUnsupportedRHYDL(FakeYDL): + + def __init__(self, *args, **kwargs): + + class UnsupportedRH(RequestHandler): + def _send(self, request: Request): + pass + + _SUPPORTED_FEATURES = () + _SUPPORTED_PROXY_SCHEMES = () + _SUPPORTED_URL_SCHEMES = () + + super().__init__(*args, **kwargs) + self._request_director = self.build_request_director([UnsupportedRH]) + + class TestRequestDirector: def test_handler_operations(self): @@ -1234,6 +1259,12 @@ class TestYoutubeDLNetworking: with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'): ydl.urlopen('file://') + @pytest.mark.parametrize('scheme', (['ws', 'wss'])) + def test_websocket_unavailable_error(self, scheme): + with AllUnsupportedRHYDL() as ydl: + with pytest.raises(RequestError, match=r'This request requires WebSocket support'): + ydl.urlopen(f'{scheme}://') + def test_legacy_server_connect_error(self): with FakeRHYDL() as ydl: for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'): diff --git a/test/test_socks.py b/test/test_socks.py index d8ac88dad..71f783e13 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -210,6 +210,16 @@ class SocksHTTPTestRequestHandler(http.server.BaseHTTPRequestHandler, SocksTestR self.wfile.write(payload.encode()) +class SocksWebSocketTestRequestHandler(SocksTestRequestHandler): + def handle(self): + import websockets.sync.server + protocol = websockets.ServerProtocol() + connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) + connection.handshake() + connection.send(json.dumps(self.socks_info)) + connection.close() + + @contextlib.contextmanager def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs): server = server_thread = None @@ -252,8 +262,22 @@ class HTTPSocksTestProxyContext(SocksProxyTestContext): return json.loads(handler.send(request).read().decode()) +class WebSocketSocksTestProxyContext(SocksProxyTestContext): + REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler + + def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs): + request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs) + handler.validate(request) + ws = handler.send(request) + ws.send('socks_info') + socks_info = ws.recv() + ws.close() + return json.loads(socks_info) + + CTX_MAP = { 'http': HTTPSocksTestProxyContext, + 'ws': WebSocketSocksTestProxyContext, } @@ -263,7 +287,7 @@ def ctx(request): class TestSocks4Proxy: - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4_no_auth(self, handler, ctx): with handler() as rh: with ctx.socks_server(Socks4ProxyHandler) as server_address: @@ -271,7 +295,7 @@ class TestSocks4Proxy: rh, proxies={'all': f'socks4://{server_address}'}) assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4_auth(self, handler, ctx): with handler() as rh: with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address: @@ -281,7 +305,7 @@ class TestSocks4Proxy: rh, proxies={'all': f'socks4://user:@{server_address}'}) assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4a_ipv4_target(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: @@ -289,7 +313,7 @@ class TestSocks4Proxy: assert response['version'] == 4 assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1') - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks4a_domain_target(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: with handler(proxies={'all': f'socks4a://{server_address}'}) as rh: @@ -298,7 +322,7 @@ class TestSocks4Proxy: assert response['ipv4_address'] is None assert response['domain_address'] == 'localhost' - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -308,7 +332,7 @@ class TestSocks4Proxy: assert response['client_address'][0] == source_address assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) @pytest.mark.parametrize('reply_code', [ Socks4CD.REQUEST_REJECTED_OR_FAILED, Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD, @@ -320,7 +344,7 @@ class TestSocks4Proxy: with pytest.raises(ProxyError): ctx.socks_info_request(rh) - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv6_socks4_proxy(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks4://{server_address}'}) as rh: @@ -329,7 +353,7 @@ class TestSocks4Proxy: assert response['ipv4_address'] == '127.0.0.1' assert response['version'] == 4 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_timeout(self, handler, ctx): with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address: with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh: @@ -339,7 +363,7 @@ class TestSocks4Proxy: class TestSocks5Proxy: - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_no_auth(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -347,7 +371,7 @@ class TestSocks5Proxy: assert response['auth_methods'] == [0x0] assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_user_pass(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address: with handler() as rh: @@ -360,7 +384,7 @@ class TestSocks5Proxy: assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS] assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_ipv4_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -368,7 +392,7 @@ class TestSocks5Proxy: assert response['ipv4_address'] == '127.0.0.1' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_domain_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -376,7 +400,7 @@ class TestSocks5Proxy: assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1') assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5h_domain_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: @@ -385,7 +409,7 @@ class TestSocks5Proxy: assert response['domain_address'] == 'localhost' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5h_ip_target(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5h://{server_address}'}) as rh: @@ -394,7 +418,7 @@ class TestSocks5Proxy: assert response['domain_address'] is None assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_socks5_ipv6_destination(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -402,7 +426,7 @@ class TestSocks5Proxy: assert response['ipv6_address'] == '::1' assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv6_socks5_proxy(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address: with handler(proxies={'all': f'socks5://{server_address}'}) as rh: @@ -413,7 +437,7 @@ class TestSocks5Proxy: # XXX: is there any feasible way of testing IPv6 source addresses? # Same would go for non-proxy source_address test... - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) def test_ipv4_client_source_address(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -422,7 +446,7 @@ class TestSocks5Proxy: assert response['client_address'][0] == source_address assert response['version'] == 5 - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True) @pytest.mark.parametrize('reply_code', [ Socks5Reply.GENERAL_FAILURE, Socks5Reply.CONNECTION_NOT_ALLOWED, @@ -439,7 +463,7 @@ class TestSocks5Proxy: with pytest.raises(ProxyError): ctx.socks_info_request(rh) - @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True) + @pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True) def test_timeout(self, handler, ctx): with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address: with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh: diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 000000000..39d3c7d72 --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,380 @@ +#!/usr/bin/env python3 + +# Allow direct execution +import os +import sys + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import http.client +import http.cookiejar +import http.server +import json +import random +import ssl +import threading + +from yt_dlp import socks +from yt_dlp.cookies import YoutubeDLCookieJar +from yt_dlp.dependencies import websockets +from yt_dlp.networking import Request +from yt_dlp.networking.exceptions import ( + CertificateVerifyError, + HTTPError, + ProxyError, + RequestError, + SSLError, + TransportError, +) +from yt_dlp.utils.networking import HTTPHeaderDict + +from test.conftest import validate_and_send + +TEST_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def websocket_handler(websocket): + for message in websocket: + if isinstance(message, bytes): + if message == b'bytes': + return websocket.send('2') + elif isinstance(message, str): + if message == 'headers': + return websocket.send(json.dumps(dict(websocket.request.headers))) + elif message == 'path': + return websocket.send(websocket.request.path) + elif message == 'source_address': + return websocket.send(websocket.remote_address[0]) + elif message == 'str': + return websocket.send('1') + return websocket.send(message) + + +def process_request(self, request): + if request.path.startswith('/gen_'): + status = http.HTTPStatus(int(request.path[5:])) + if 300 <= status.value <= 300: + return websockets.http11.Response( + status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'') + return self.protocol.reject(status.value, status.phrase) + return self.protocol.accept(request) + + +def create_websocket_server(**ws_kwargs): + import websockets.sync.server + wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) + ws_port = wsd.socket.getsockname()[1] + ws_server_thread = threading.Thread(target=wsd.serve_forever) + ws_server_thread.daemon = True + ws_server_thread.start() + return ws_server_thread, ws_port + + +def create_ws_websocket_server(): + return create_websocket_server() + + +def create_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(certfn, None) + return create_websocket_server(ssl_context=sslctx) + + +MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate') + + +def create_mtls_wss_websocket_server(): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt') + + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.verify_mode = ssl.CERT_REQUIRED + sslctx.load_verify_locations(cafile=cacertfn) + sslctx.load_cert_chain(certfn, None) + + return create_websocket_server(ssl_context=sslctx) + + +@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') +class TestWebsSocketRequestHandlerConformance: + @classmethod + def setup_class(cls): + cls.ws_thread, cls.ws_port = create_ws_websocket_server() + cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}' + + cls.wss_thread, cls.wss_port = create_wss_websocket_server() + cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}' + + cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)) + cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}' + + cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server() + cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}' + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_basic_websockets(self, handler): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + assert 'upgrade' in ws.headers + assert ws.status == 101 + ws.send('foo') + assert ws.recv() == 'foo' + ws.close() + + # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6 + @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)]) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_send_types(self, handler, msg, opcode): + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send(msg) + assert int(ws.recv()) == opcode + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_verify_cert(self, handler): + with handler() as rh: + with pytest.raises(CertificateVerifyError): + validate_and_send(rh, Request(self.wss_base_url)) + + with handler(verify=False) as rh: + ws = validate_and_send(rh, Request(self.wss_base_url)) + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_ssl_error(self, handler): + with handler(verify=False) as rh: + with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info: + validate_and_send(rh, Request(self.bad_wss_host)) + assert not issubclass(exc_info.type, CertificateVerifyError) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('path,expected', [ + # Unicode characters should be encoded with uppercase percent-encoding + ('/中文', '/%E4%B8%AD%E6%96%87'), + # don't normalize existing percent encodings + ('/%c7%9f', '/%c7%9f'), + ]) + def test_percent_encode(self, handler, path, expected): + with handler() as rh: + ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) + ws.send('path') + assert ws.recv() == expected + assert ws.status == 101 + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_remove_dot_segments(self, handler): + with handler() as rh: + # This isn't a comprehensive test, + # but it should be enough to check whether the handler is removing dot segments + ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) + assert ws.status == 101 + ws.send('path') + assert ws.recv() == '/test' + ws.close() + + # We are restricted to known HTTP status codes in http.HTTPStatus + # Redirects are not supported for websockets + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511)) + def test_raise_http_error(self, handler, status): + with handler() as rh: + with pytest.raises(HTTPError) as exc_info: + validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) + assert exc_info.value.status == status + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + @pytest.mark.parametrize('params,extensions', [ + ({'timeout': 0.00001}, {}), + ({}, {'timeout': 0.00001}), + ]) + def test_timeout(self, handler, params, extensions): + with handler(**params) as rh: + with pytest.raises(TransportError): + validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_cookies(self, handler): + cookiejar = YoutubeDLCookieJar() + cookiejar.set_cookie(http.cookiejar.Cookie( + version=0, name='test', value='ytdlp', port=None, port_specified=False, + domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/', + path_specified=True, secure=False, expires=None, discard=False, comment=None, + comment_url=None, rest={})) + + with handler(cookiejar=cookiejar) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + with handler() as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + assert 'cookie' not in json.loads(ws.recv()) + ws.close() + + ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) + ws.send('headers') + assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_source_address(self, handler): + source_address = f'127.0.0.{random.randint(5, 255)}' + with handler(source_address=source_address) as rh: + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('source_address') + assert source_address == ws.recv() + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_response_url(self, handler): + with handler() as rh: + url = f'{self.ws_base_url}/something' + ws = validate_and_send(rh, Request(url)) + assert ws.url == url + ws.close() + + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_request_headers(self, handler): + with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: + # Global Headers + ws = validate_and_send(rh, Request(self.ws_base_url)) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + ws.close() + + # Per request headers, merged with global + ws = validate_and_send(rh, Request( + self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) + ws.send('headers') + headers = HTTPHeaderDict(json.loads(ws.recv())) + assert headers['test1'] == 'test' + assert headers['test2'] == 'changed' + assert headers['test3'] == 'test3' + ws.close() + + @pytest.mark.parametrize('client_cert', ( + {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')}, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'), + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'), + 'client_certificate_password': 'foobar', + }, + { + 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'), + 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'), + 'client_certificate_password': 'foobar', + } + )) + @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) + def test_mtls(self, handler, client_cert): + with handler( + # Disable client-side validation of unacceptable self-signed testcert.pem + # The test is of a check on the server side, so unaffected + verify=False, + client_cert=client_cert + ) as rh: + validate_and_send(rh, Request(self.mtls_wss_base_url)).close() + + +def create_fake_ws_connection(raised): + import websockets.sync.client + + class FakeWsConnection(websockets.sync.client.ClientConnection): + def __init__(self, *args, **kwargs): + class FakeResponse: + body = b'' + headers = {} + status_code = 101 + reason_phrase = 'test' + + self.response = FakeResponse() + + def send(self, *args, **kwargs): + raise raised() + + def recv(self, *args, **kwargs): + raise raised() + + def close(self, *args, **kwargs): + return + + return FakeWsConnection() + + +@pytest.mark.parametrize('handler', ['Websockets'], indirect=True) +class TestWebsocketsRequestHandler: + @pytest.mark.parametrize('raised,expected', [ + # https://websockets.readthedocs.io/en/stable/reference/exceptions.html + (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError), + # Requires a response object. Should be covered by HTTP error tests. + # (lambda: websockets.exceptions.InvalidStatus(), TransportError), + (lambda: websockets.exceptions.InvalidHandshake(), TransportError), + # These are subclasses of InvalidHandshake + (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError), + (lambda: websockets.exceptions.NegotiationError(), TransportError), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError), + (lambda: TimeoutError(), TransportError), + # These may be raised by our create_connection implementation, which should also be caught + (lambda: OSError(), TransportError), + (lambda: ssl.SSLError(), SSLError), + (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError), + (lambda: socks.ProxyError(), ProxyError), + ]) + def test_request_error_mapping(self, handler, monkeypatch, raised, expected): + import websockets.sync.client + + import yt_dlp.networking._websockets + with handler() as rh: + def fake_connect(*args, **kwargs): + raise raised() + monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None) + monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect) + with pytest.raises(expected) as exc_info: + rh.send(Request('ws://fake-url')) + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: TypeError(), RequestError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.send('test') + assert exc_info.type is expected + + @pytest.mark.parametrize('raised,expected,match', [ + # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv + (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None), + (lambda: RuntimeError(), TransportError, None), + (lambda: TimeoutError(), TransportError, None), + (lambda: socks.ProxyError(), ProxyError, None), + # Catch-all + (lambda: websockets.exceptions.WebSocketException(), TransportError, None), + ]) + def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match): + from yt_dlp.networking._websockets import WebsocketsResponseAdapter + ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url') + with pytest.raises(expected, match=match) as exc_info: + ws.recv() + assert exc_info.type is expected |