aboutsummaryrefslogtreecommitdiff
path: root/yt_dlp/utils/networking.py
diff options
context:
space:
mode:
Diffstat (limited to 'yt_dlp/utils/networking.py')
-rw-r--r--yt_dlp/utils/networking.py146
1 files changed, 131 insertions, 15 deletions
diff --git a/yt_dlp/utils/networking.py b/yt_dlp/utils/networking.py
index 933b164be..542abace8 100644
--- a/yt_dlp/utils/networking.py
+++ b/yt_dlp/utils/networking.py
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
import collections
+import collections.abc
import random
+import typing
import urllib.parse
import urllib.request
-from ._utils import remove_start
+if typing.TYPE_CHECKING:
+ T = typing.TypeVar('T')
+
+from ._utils import NO_DEFAULT, remove_start
def random_user_agent():
@@ -51,32 +58,141 @@ def random_user_agent():
return _USER_AGENT_TPL % random.choice(_CHROME_VERSIONS)
-class HTTPHeaderDict(collections.UserDict, dict):
+class HTTPHeaderDict(dict):
"""
Store and access keys case-insensitively.
The constructor can take multiple dicts, in which keys in the latter are prioritised.
+
+ Retains a case sensitive mapping of the headers, which can be accessed via `.sensitive()`.
"""
+ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+ obj = dict.__new__(cls, *args, **kwargs)
+ obj.__sensitive_map = {}
+ return obj
- def __init__(self, *args, **kwargs):
+ def __init__(self, /, *args, **kwargs):
super().__init__()
- for dct in args:
- if dct is not None:
- self.update(dct)
- self.update(kwargs)
+ self.__sensitive_map = {}
+
+ for dct in filter(None, args):
+ self.update(dct)
+ if kwargs:
+ self.update(kwargs)
+
+ def sensitive(self, /) -> dict[str, str]:
+ return {
+ self.__sensitive_map[key]: value
+ for key, value in self.items()
+ }
+
+ def __contains__(self, key: str, /) -> bool:
+ return super().__contains__(key.title() if isinstance(key, str) else key)
+
+ def __delitem__(self, key: str, /) -> None:
+ key = key.title()
+ del self.__sensitive_map[key]
+ super().__delitem__(key)
- def __setitem__(self, key, value):
+ def __getitem__(self, key, /) -> str:
+ return super().__getitem__(key.title())
+
+ def __ior__(self, other, /):
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ self.update(other)
+ return
+ return NotImplemented
+
+ def __or__(self, other, /) -> typing.Self:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ return type(self)(self.sensitive(), other)
+ return NotImplemented
+
+ def __ror__(self, other, /) -> typing.Self:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, dict):
+ return type(self)(other, self.sensitive())
+ return NotImplemented
+
+ def __setitem__(self, key: str, value, /) -> None:
if isinstance(value, bytes):
value = value.decode('latin-1')
- super().__setitem__(key.title(), str(value).strip())
+ key_title = key.title()
+ self.__sensitive_map[key_title] = key
+ super().__setitem__(key_title, str(value).strip())
- def __getitem__(self, key):
- return super().__getitem__(key.title())
+ def clear(self, /) -> None:
+ self.__sensitive_map.clear()
+ super().clear()
- def __delitem__(self, key):
- super().__delitem__(key.title())
+ def copy(self, /) -> typing.Self:
+ return type(self)(self.sensitive())
- def __contains__(self, key):
- return super().__contains__(key.title() if isinstance(key, str) else key)
+ @typing.overload
+ def get(self, key: str, /) -> str | None: ...
+
+ @typing.overload
+ def get(self, key: str, /, default: T) -> str | T: ...
+
+ def get(self, key, /, default=NO_DEFAULT):
+ key = key.title()
+ if default is NO_DEFAULT:
+ return super().get(key)
+ return super().get(key, default)
+
+ @typing.overload
+ def pop(self, key: str, /) -> str: ...
+
+ @typing.overload
+ def pop(self, key: str, /, default: T) -> str | T: ...
+
+ def pop(self, key, /, default=NO_DEFAULT):
+ key = key.title()
+ if default is NO_DEFAULT:
+ self.__sensitive_map.pop(key)
+ return super().pop(key)
+ self.__sensitive_map.pop(key, default)
+ return super().pop(key, default)
+
+ def popitem(self) -> tuple[str, str]:
+ self.__sensitive_map.popitem()
+ return super().popitem()
+
+ @typing.overload
+ def setdefault(self, key: str, /) -> str: ...
+
+ @typing.overload
+ def setdefault(self, key: str, /, default) -> str: ...
+
+ def setdefault(self, key, /, default=None) -> str:
+ key = key.title()
+ if key in self.__sensitive_map:
+ return super().__getitem__(key)
+
+ self[key] = default or ''
+ return self[key]
+
+ def update(self, other, /, **kwargs) -> None:
+ if isinstance(other, type(self)):
+ other = other.sensitive()
+ if isinstance(other, collections.abc.Mapping):
+ for key, value in other.items():
+ self[key] = value
+
+ elif hasattr(other, 'keys'):
+ for key in other.keys(): # noqa: SIM118
+ self[key] = other[key]
+
+ else:
+ for key, value in other:
+ self[key] = value
+
+ for key, value in kwargs.items():
+ self[key] = value
std_headers = HTTPHeaderDict({