author     Simon Sawicki <contact@grub4k.xyz>    2025-03-03 00:10:01 +0100
committer  GitHub <noreply@github.com>           2025-03-03 00:10:01 +0100
commit     7d18fed8f1983fe6de4ddc810dfb2761ba5744ac (patch)
tree       b090448b1098b43359a0ce87c531bb54178482ca
parent     79ec2fdff75c8c1bb89b550266849ad4dec48dd3 (diff)
[networking] Add `keep_header_casing` extension (#11652)
Authored by: coletdjnz, Grub4K
Co-authored-by: coletdjnz <coletdjnz@protonmail.com>
 test/test_networking.py          |  13
 test/test_utils.py               |  23
 test/test_websockets.py          |  22
 yt_dlp/networking/_requests.py   |   8
 yt_dlp/networking/_urllib.py     |   8
 yt_dlp/networking/_websockets.py |   8
 yt_dlp/networking/common.py      |  19
 yt_dlp/networking/impersonate.py |  22
 yt_dlp/utils/networking.py       | 146
 9 files changed, 229 insertions(+), 40 deletions(-)
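
The new `keep_header_casing` request extension lets a caller send headers with their original casing instead of the normalized Title-Case that `HTTPHeaderDict` stores. A minimal usage sketch against the classes this diff touches (the target URL is a placeholder; nothing here is new API beyond what the diff adds):

    from yt_dlp.networking import Request
    from yt_dlp.utils.networking import HTTPHeaderDict

    headers = HTTPHeaderDict({'X-test-heaDer': 'test'})
    assert list(headers) == ['X-Test-Header']                # normalized view
    assert headers.sensitive() == {'X-test-heaDer': 'test'}  # original casing

    # Handlers that support the extension (Urllib and CurlCFFI, per the test
    # below) put the original casing on the wire when it is set:
    req = Request('http://127.0.0.1:8080/headers', headers=headers,
                  extensions={'keep_header_casing': True})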
diff --git a/test/test_networking.py b/test/test_networking.py
index d96624af1..63914bc4b 100644
--- a/test/test_networking.py
+++ b/test/test_networking.py
@@ -720,6 +720,15 @@ class TestHTTPRequestHandler(TestRequestHandlerBase):
rh, Request(
f'http://127.0.0.1:{self.http_port}/headers', proxies={'all': 'http://10.255.255.255'})).close()
+ @pytest.mark.skip_handlers_if(lambda _, handler: handler not in ['Urllib', 'CurlCFFI'], 'handler does not support keep_header_casing')
+ def test_keep_header_casing(self, handler):
+ with handler() as rh:
+ res = validate_and_send(
+ rh, Request(
+ f'http://127.0.0.1:{self.http_port}/headers', headers={'X-test-heaDer': 'test'}, extensions={'keep_header_casing': True})).read().decode()
+
+ assert 'X-test-heaDer: test' in res
+
@pytest.mark.parametrize('handler', ['Urllib', 'Requests', 'CurlCFFI'], indirect=True)
class TestClientCertificate:
@@ -1289,6 +1298,7 @@ class TestRequestHandlerValidation:
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
+ ({'keep_header_casing': True}, UnsupportedRequest),
]),
('Requests', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError),
@@ -1299,6 +1309,9 @@ class TestRequestHandlerValidation:
({'legacy_ssl': False}, False),
({'legacy_ssl': True}, False),
({'legacy_ssl': 'notabool'}, AssertionError),
+ ({'keep_header_casing': False}, False),
+ ({'keep_header_casing': True}, False),
+ ({'keep_header_casing': 'notabool'}, AssertionError),
]),
('CurlCFFI', 'http', [
({'cookiejar': 'notacookiejar'}, AssertionError),
diff --git a/test/test_utils.py b/test/test_utils.py
index 8f81d0b1b..65f28db36 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -3,19 +3,20 @@
# Allow direct execution
import os
import sys
-import unittest
-import unittest.mock
-import warnings
-import datetime as dt
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import contextlib
+import datetime as dt
import io
import itertools
import json
+import pickle
import subprocess
+import unittest
+import unittest.mock
+import warnings
import xml.etree.ElementTree
from yt_dlp.compat import (
@@ -2087,21 +2088,26 @@ Line 1
headers = HTTPHeaderDict()
headers['ytdl-test'] = b'0'
self.assertEqual(list(headers.items()), [('Ytdl-Test', '0')])
+ self.assertEqual(list(headers.sensitive().items()), [('ytdl-test', '0')])
headers['ytdl-test'] = 1
self.assertEqual(list(headers.items()), [('Ytdl-Test', '1')])
+ self.assertEqual(list(headers.sensitive().items()), [('ytdl-test', '1')])
headers['Ytdl-test'] = '2'
self.assertEqual(list(headers.items()), [('Ytdl-Test', '2')])
+ self.assertEqual(list(headers.sensitive().items()), [('Ytdl-test', '2')])
self.assertTrue('ytDl-Test' in headers)
self.assertEqual(str(headers), str(dict(headers)))
self.assertEqual(repr(headers), str(dict(headers)))
headers.update({'X-dlp': 'data'})
self.assertEqual(set(headers.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data')})
+ self.assertEqual(set(headers.sensitive().items()), {('Ytdl-test', '2'), ('X-dlp', 'data')})
self.assertEqual(dict(headers), {'Ytdl-Test': '2', 'X-Dlp': 'data'})
self.assertEqual(len(headers), 2)
self.assertEqual(headers.copy(), headers)
- headers2 = HTTPHeaderDict({'X-dlp': 'data3'}, **headers, **{'X-dlp': 'data2'})
+ headers2 = HTTPHeaderDict({'X-dlp': 'data3'}, headers, **{'X-dlP': 'data2'})
self.assertEqual(set(headers2.items()), {('Ytdl-Test', '2'), ('X-Dlp', 'data2')})
+ self.assertEqual(set(headers2.sensitive().items()), {('Ytdl-test', '2'), ('X-dlP', 'data2')})
self.assertEqual(len(headers2), 2)
headers2.clear()
self.assertEqual(len(headers2), 0)
@@ -2109,16 +2115,23 @@ Line 1
# ensure we prefer latter headers
headers3 = HTTPHeaderDict({'Ytdl-TeSt': 1}, {'Ytdl-test': 2})
self.assertEqual(set(headers3.items()), {('Ytdl-Test', '2')})
+ self.assertEqual(set(headers3.sensitive().items()), {('Ytdl-test', '2')})
del headers3['ytdl-tesT']
self.assertEqual(dict(headers3), {})
headers4 = HTTPHeaderDict({'ytdl-test': 'data;'})
self.assertEqual(set(headers4.items()), {('Ytdl-Test', 'data;')})
+ self.assertEqual(set(headers4.sensitive().items()), {('ytdl-test', 'data;')})
# common mistake: strip whitespace from values
# https://github.com/yt-dlp/yt-dlp/issues/8729
headers5 = HTTPHeaderDict({'ytdl-test': ' data; '})
self.assertEqual(set(headers5.items()), {('Ytdl-Test', 'data;')})
+ self.assertEqual(set(headers5.sensitive().items()), {('ytdl-test', 'data;')})
+
+ # test if picklable
+ headers6 = HTTPHeaderDict(a=1, b=2)
+ self.assertEqual(pickle.loads(pickle.dumps(headers6)), headers6)
def test_extract_basic_auth(self):
assert extract_basic_auth('http://:foo.bar') == ('http://:foo.bar', None)
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 06112cc0b..dead5fe5c 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -44,7 +44,7 @@ def websocket_handler(websocket):
return websocket.send('2')
elif isinstance(message, str):
if message == 'headers':
- return websocket.send(json.dumps(dict(websocket.request.headers)))
+ return websocket.send(json.dumps(dict(websocket.request.headers.raw_items())))
elif message == 'path':
return websocket.send(websocket.request.path)
elif message == 'source_address':
@@ -266,18 +266,18 @@ class TestWebsSocketRequestHandlerConformance:
with handler(cookiejar=cookiejar) as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
- assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+ assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
ws.close()
with handler() as rh:
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
- assert 'cookie' not in json.loads(ws.recv())
+ assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
ws.close()
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
ws.send('headers')
- assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+ assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
ws.close()
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
@@ -287,7 +287,7 @@ class TestWebsSocketRequestHandlerConformance:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie', extensions={'cookiejar': YoutubeDLCookieJar()}))
ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': YoutubeDLCookieJar()}))
ws.send('headers')
- assert 'cookie' not in json.loads(ws.recv())
+ assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
ws.close()
@pytest.mark.skip_handler('Websockets', 'Set-Cookie not supported by websockets')
@@ -298,12 +298,12 @@ class TestWebsSocketRequestHandlerConformance:
ws_validate_and_send(rh, Request(f'{self.ws_base_url}/get_cookie'))
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
- assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
+ assert HTTPHeaderDict(json.loads(ws.recv()))['cookie'] == 'test=ytdlp'
ws.close()
cookiejar.clear_session_cookies()
ws = ws_validate_and_send(rh, Request(self.ws_base_url))
ws.send('headers')
- assert 'cookie' not in json.loads(ws.recv())
+ assert 'cookie' not in HTTPHeaderDict(json.loads(ws.recv()))
ws.close()
def test_source_address(self, handler):
@@ -341,6 +341,14 @@ class TestWebsSocketRequestHandlerConformance:
assert headers['test3'] == 'test3'
ws.close()
+ def test_keep_header_casing(self, handler):
+ with handler(headers=HTTPHeaderDict({'x-TeSt1': 'test'})) as rh:
+ ws = ws_validate_and_send(rh, Request(self.ws_base_url, headers={'x-TeSt2': 'test'}, extensions={'keep_header_casing': True}))
+ ws.send('headers')
+ headers = json.loads(ws.recv())
+ assert 'x-TeSt1' in headers
+ assert 'x-TeSt2' in headers
+
@pytest.mark.parametrize('client_cert', (
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
{
diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py
index 7de95ab3b..23775845d 100644
--- a/yt_dlp/networking/_requests.py
+++ b/yt_dlp/networking/_requests.py
@@ -296,6 +296,7 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
extensions.pop('cookiejar', None)
extensions.pop('timeout', None)
extensions.pop('legacy_ssl', None)
+ extensions.pop('keep_header_casing', None)
def _create_instance(self, cookiejar, legacy_ssl_support=None):
session = RequestsSession()
@@ -312,11 +313,12 @@ class RequestsRH(RequestHandler, InstanceStoreMixin):
session.trust_env = False # no need, we already load proxies from env
return session
- def _send(self, request):
-
- headers = self._merge_headers(request.headers)
+ def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+ def _send(self, request):
+
+ headers = self._get_headers(request)
max_redirects_exceeded = False
session = self._get_instance(
diff --git a/yt_dlp/networking/_urllib.py b/yt_dlp/networking/_urllib.py
index 510bb2a69..a188b35f5 100644
--- a/yt_dlp/networking/_urllib.py
+++ b/yt_dlp/networking/_urllib.py
@@ -379,13 +379,15 @@ class UrllibRH(RequestHandler, InstanceStoreMixin):
opener.addheaders = []
return opener
- def _send(self, request):
- headers = self._merge_headers(request.headers)
+ def _prepare_headers(self, _, headers):
add_accept_encoding_header(headers, SUPPORTED_ENCODINGS)
+
+ def _send(self, request):
+ headers = self._get_headers(request)
urllib_req = urllib.request.Request(
url=request.url,
data=request.data,
- headers=dict(headers),
+ headers=headers,
method=request.method,
)
diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py
index ec55567da..7e5ab4600 100644
--- a/yt_dlp/networking/_websockets.py
+++ b/yt_dlp/networking/_websockets.py
@@ -116,6 +116,7 @@ class WebsocketsRH(WebSocketRequestHandler):
extensions.pop('timeout', None)
extensions.pop('cookiejar', None)
extensions.pop('legacy_ssl', None)
+ extensions.pop('keep_header_casing', None)
def close(self):
# Remove the logging handler that contains a reference to our logger
@@ -123,15 +124,16 @@ class WebsocketsRH(WebSocketRequestHandler):
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
- def _send(self, request):
- timeout = self._calculate_timeout(request)
- headers = self._merge_headers(request.headers)
+ def _prepare_headers(self, request, headers):
if 'cookie' not in headers:
cookiejar = self._get_cookiejar(request)
cookie_header = cookiejar.get_cookie_header(request.url)
if cookie_header:
headers['cookie'] = cookie_header
+ def _send(self, request):
+ timeout = self._calculate_timeout(request)
+ headers = self._get_headers(request)
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
diff --git a/yt_dlp/networking/common.py b/yt_dlp/networking/common.py
index e8951c7e7..ddceaa9a9 100644
--- a/yt_dlp/networking/common.py
+++ b/yt_dlp/networking/common.py
@@ -206,6 +206,7 @@ class RequestHandler(abc.ABC):
- `cookiejar`: Cookiejar to use for this request.
- `timeout`: socket timeout to use for this request.
- `legacy_ssl`: Enable legacy SSL options for this request. See legacy_ssl_support.
+ - `keep_header_casing`: Keep the casing of headers when sending the request.
To enable these, add extensions.pop('<extension>', None) to _check_extensions
Apart from the url protocol, proxies dict may contain the following keys:
@@ -259,6 +260,23 @@ class RequestHandler(abc.ABC):
def _merge_headers(self, request_headers):
return HTTPHeaderDict(self.headers, request_headers)
+ def _prepare_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
+ """Additional operations to prepare headers before building. To be extended by subclasses.
+ @param request: Request object
+ @param headers: Merged headers to prepare
+ """
+
+ def _get_headers(self, request: Request) -> dict[str, str]:
+ """
+ Get headers for external use.
+ Subclasses may define a _prepare_headers method to modify headers after merge but before building.
+ """
+ headers = self._merge_headers(request.headers)
+ self._prepare_headers(request, headers)
+ if request.extensions.get('keep_header_casing'):
+ return headers.sensitive()
+ return dict(headers)
+
def _calculate_timeout(self, request):
return float(request.extensions.get('timeout') or self.timeout)
@@ -317,6 +335,7 @@ class RequestHandler(abc.ABC):
assert isinstance(extensions.get('cookiejar'), (YoutubeDLCookieJar, NoneType))
assert isinstance(extensions.get('timeout'), (float, int, NoneType))
assert isinstance(extensions.get('legacy_ssl'), (bool, NoneType))
+ assert isinstance(extensions.get('keep_header_casing'), (bool, NoneType))
def _validate(self, request):
self._check_url_scheme(request)
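
The two hooks added to `RequestHandler` above split header handling into a prepare phase (mutate the merged `HTTPHeaderDict` in place) and a build phase (return a plain dict in the requested casing). A sketch of how a subclass composes them; `ExampleRH` and its body are hypothetical, only the hook names come from this diff:

    from yt_dlp.networking.common import RequestHandler
    from yt_dlp.utils.networking import HTTPHeaderDict

    class ExampleRH(RequestHandler):  # hypothetical handler, for illustration
        _SUPPORTED_URL_SCHEMES = ('http',)

        def _check_extensions(self, extensions):
            extensions.pop('keep_header_casing', None)  # declare support

        def _prepare_headers(self, request, headers: HTTPHeaderDict) -> None:
            # runs after instance and request headers merge, before building
            headers.setdefault('accept-encoding', 'identity')

        def _send(self, request):
            # plain dict; original casing iff keep_header_casing was requested
            headers = self._get_headers(request)
            ...  # hand `headers` to the underlying transport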
diff --git a/yt_dlp/networking/impersonate.py b/yt_dlp/networking/impersonate.py
index 0626b3b49..b90d10b76 100644
--- a/yt_dlp/networking/impersonate.py
+++ b/yt_dlp/networking/impersonate.py
@@ -5,11 +5,11 @@ from abc import ABC
from dataclasses import dataclass
from typing import Any
-from .common import RequestHandler, register_preference
+from .common import RequestHandler, register_preference, Request
from .exceptions import UnsupportedRequest
from ..compat.types import NoneType
from ..utils import classproperty, join_nonempty
-from ..utils.networking import std_headers
+from ..utils.networking import std_headers, HTTPHeaderDict
@dataclass(order=True, frozen=True)
@@ -123,7 +123,17 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
"""Get the requested target for the request"""
return self._resolve_target(request.extensions.get('impersonate') or self.impersonate)
- def _get_impersonate_headers(self, request):
+ def _prepare_impersonate_headers(self, request: Request, headers: HTTPHeaderDict) -> None: # noqa: B027
+ """Additional operations to prepare headers before building. To be extended by subclasses.
+ @param request: Request object
+ @param headers: Merged headers to prepare
+ """
+
+ def _get_impersonate_headers(self, request: Request) -> dict[str, str]:
+ """
+ Get headers for external impersonation use.
+ Subclasses may define a _prepare_impersonate_headers method to modify headers after merge but before building.
+ """
headers = self._merge_headers(request.headers)
if self._get_request_target(request) is not None:
# remove all headers present in std_headers
@@ -131,7 +141,11 @@ class ImpersonateRequestHandler(RequestHandler, ABC):
for k, v in std_headers.items():
if headers.get(k) == v:
headers.pop(k)
- return headers
+
+ self._prepare_impersonate_headers(request, headers)
+ if request.extensions.get('keep_header_casing'):
+ return headers.sensitive()
+ return dict(headers)
@register_preference(ImpersonateRequestHandler)
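
The `std_headers` loop above keeps impersonation fingerprints intact: any default header the caller did not explicitly override is dropped so the impersonation target can supply its own. Its effect, restated with the real helpers:

    from yt_dlp.utils.networking import HTTPHeaderDict, std_headers

    headers = HTTPHeaderDict(std_headers, {'X-Custom': 'kept'})
    for k, v in std_headers.items():
        if headers.get(k) == v:  # only untouched defaults match
            headers.pop(k)

    assert 'X-Custom' in headers        # user-supplied header survives
    assert 'User-Agent' not in headers  # matching default was dropped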
diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py
index 933b164be..542abace8 100644
--- a/yt_dlp/utils/networking.py
+++ b/yt_dlp/utils/networking.py
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
import collections
+import collections.abc
import random
+import typing
import urllib.parse
import urllib.request
-from ._utils import remove_start
+if typing.TYPE_CHECKING:
+ T = typing.TypeVar('T')
+
+from ._utils import NO_DEFAULT, remove_start
def random_user_agent():
@@ -51,32 +58,141 @@ def random_user_agent():
return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
-class HTTPHeaderDict(collections.UserDict, dict):
+class HTTPHeaderDict(dict):
"""
Store and access keys case-insensitively.
The constructor can take multiple dicts, in which keys in the latter are prioritised.
+
+ Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`.
"""
+ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+ obj = dict.__new__(cls, *args, **kwargs)
+ obj.__sensitive_map = {}
+ return obj
- def __init__(self, *args, **kwargs):
+ def __init__(self, /, *args, **kwargs):
super().__init__()
- for dct in args:
- if dct is not None:
- self.update(dct)
- self.update(kwargs)
+ self.__sensitive_map = {}
+
+ for dct in filter(None, args):
+ self.update(dct)
+ if kwargs:
+ self.update(kwargs)
+
+ def sensitive(self, /) -> dict[str, str]:
+ return {
+ self.__sensitive_map[key]: value
+ for key, value in self.items()
+ }
+
+ def __contains__(self, key: str, /) -> bool:
+ return super().__contains__(key.title() if isinstance(key, str) else key)
+
+ def __delitem__(self, key: str, /) -> None:
+ key = key.title()
+ del self.__sensitive_map[key]
+ super().__delitem__(key)
- def __setitem__(self, key, value):
+ def __getitem__(self, key, /) -> str:
+ return super().__getitem__(key.title())
+
+ def __ior__(self, other, /):
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ self.update(other)
+            return self
+ return NotImplemented
+
+ def __or__(self, other, /) -> typing.Self:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ return type(self)(self.sensitive(), other)
+ return NotImplemented
+
+ def __ror__(self, other, /) -> typing.Self:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ return type(self)(other, self.sensitive())
+ return NotImplemented
+
+ def __setitem__(self, key: str, value, /) -> None:
if isinstance(value, bytes):
value = value.decode('latin-1')
- super().__setitem__(key.title(), str(value).strip())
+ key_title = key.title()
+ self.__sensitive_map[key_title] = key
+ super().__setitem__(key_title, str(value).strip())
- def __getitem__(self, key):
- return super().__getitem__(key.title())
+ def clear(self, /) -> None:
+ self.__sensitive_map.clear()
+ super().clear()
- def __delitem__(self, key):
- super().__delitem__(key.title())
+ def copy(self, /) -> typing.Self:
+ return type(self)(self.sensitive())
- def __contains__(self, key):
- return super().__contains__(key.title() if isinstance(key, str) else key)
+ @typing.overload
+ def get(self, key: str, /) -> str | None: ...
+
+ @typing.overload
+ def get(self, key: str, /, default: T) -> str | T: ...
+
+ def get(self, key, /, default=NO_DEFAULT):
+ key = key.title()
+ if default is NO_DEFAULT:
+ return super().get(key)
+ return super().get(key, default)
+
+ @typing.overload
+ def pop(self, key: str, /) -> str: ...
+
+ @typing.overload
+ def pop(self, key: str, /, default: T) -> str | T: ...
+
+ def pop(self, key, /, default=NO_DEFAULT):
+ key = key.title()
+ if default is NO_DEFAULT:
+ self.__sensitive_map.pop(key)
+ return super().pop(key)
+ self.__sensitive_map.pop(key, default)
+ return super().pop(key, default)
+
+ def popitem(self) -> tuple[str, str]:
+ self.__sensitive_map.popitem()
+ return super().popitem()
+
+ @typing.overload
+ def setdefault(self, key: str, /) -> str: ...
+
+ @typing.overload
+ def setdefault(self, key: str, /, default) -> str: ...
+
+ def setdefault(self, key, /, default=None) -> str:
+ key = key.title()
+ if key in self.__sensitive_map:
+ return super().__getitem__(key)
+
+ self[key] = default or ''
+ return self[key]
+
+ def update(self, other, /, **kwargs) -> None:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, collections.abc.Mapping):
+ for key, value in other.items():
+ self[key] = value
+
+ elif hasattr(other, 'keys'):
+ for key in other.keys(): # noqa: SIM118
+ self[key] = other[key]
+
+ else:
+ for key, value in other:
+ self[key] = value
+
+ for key, value in kwargs.items():
+ self[key] = value
std_headers = HTTPHeaderDict({
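
Taken together, the rewritten `HTTPHeaderDict` behaves as the updated tests demonstrate; a condensed sketch using only values from the tests above:

    from yt_dlp.utils.networking import HTTPHeaderDict

    headers = HTTPHeaderDict()
    headers['ytdl-test'] = b'0'         # bytes values are decoded as latin-1
    assert list(headers.items()) == [('Ytdl-Test', '0')]
    assert list(headers.sensitive().items()) == [('ytdl-test', '0')]

    headers['Ytdl-test'] = ' 2 '        # values are stripped
    assert headers['YTDL-TEST'] == '2'  # lookups are case-insensitive
    assert headers.sensitive() == {'Ytdl-test': '2'}  # last-set casing wins

    # later sources win, both in the constructor and in merges
    merged = HTTPHeaderDict({'X-dlp': 'data3'}, headers, **{'X-dlP': 'data2'})
    assert dict(merged) == {'Ytdl-Test': '2', 'X-Dlp': 'data2'}
    assert merged.sensitive() == {'Ytdl-test': '2', 'X-dlP': 'data2'}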