diff --git a/stdlib/@tests/test_cases/check_functools.py b/stdlib/@tests/test_cases/check_functools.py index a25c0111adf9..5ce2bfad9599 100644 --- a/stdlib/@tests/test_cases/check_functools.py +++ b/stdlib/@tests/test_cases/check_functools.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import cache, cached_property, wraps +from functools import cache, cached_property, singledispatch, wraps from typing import Callable, TypeVar from typing_extensions import ParamSpec, assert_type @@ -108,3 +108,31 @@ class CachedChild(CachedParent): @cache def method(self) -> Child: return Child() + + +def check_singledispatch_simple() -> None: + @singledispatch + def sd_fun(arg: object) -> str: + return "" + + @sd_fun.register + def _(int_arg: int) -> str: + return "" + + sd_fun.dispatch(42) + sd_fun.dispatch("") + sd_fun.dispatch(1, 2) # type: ignore + + +def check_singledispatch_additional_args() -> None: + @singledispatch + def sd_fun(arg: object, posonly: str, /, verbose: bool = False) -> str: + return "" + + @sd_fun.register + def _(int_arg: int, posonly: str, /, verbose: bool = False) -> str: + return "" + + sd_fun.dispatch(5.4, "") + sd_fun.dispatch(5.4, "", verbose=True) + sd_fun.dispatch(1, 2) # type: ignore diff --git a/stdlib/functools.pyi b/stdlib/functools.pyi index 57bc3f179f7a..adbdd7d83cec 100644 --- a/stdlib/functools.pyi +++ b/stdlib/functools.pyi @@ -4,7 +4,7 @@ from _typeshed import SupportsAllComparisons, SupportsItems from collections.abc import Callable, Hashable, Iterable, Sized from types import GenericAlias from typing import Any, Final, Generic, Literal, NamedTuple, TypedDict, TypeVar, final, overload, type_check_only -from typing_extensions import ParamSpec, Self, TypeAlias, disjoint_base +from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, disjoint_base __all__ = [ "update_wrapper", @@ -26,6 +26,8 @@ __all__ = [ _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) _S = TypeVar("_S") +_P = ParamSpec("_P", default=...) +_R = TypeVar("_R", default=Any) _PWrapped = ParamSpec("_PWrapped") _RWrapped = TypeVar("_RWrapped") _PWrapper = ParamSpec("_PWrapper") @@ -189,44 +191,69 @@ class partialmethod(Generic[_T]): def __isabstractmethod__(self) -> bool: ... def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... -if sys.version_info >= (3, 11): - _RegType: TypeAlias = type[Any] | types.UnionType -else: - _RegType: TypeAlias = type[Any] - @type_check_only -class _SingleDispatchCallable(Generic[_T]): - registry: types.MappingProxyType[Any, Callable[..., _T]] - def dispatch(self, cls: Any) -> Callable[..., _T]: ... +class _SingleDispatchCallable(Generic[_P, _R]): + # First argument pf the callables in the registry is the type to dispatch on. + registry: types.MappingProxyType[Any, Callable[Concatenate[Any, _P], _R]] + def dispatch(self, cls: type[_S]) -> Callable[Concatenate[_S, _P], _R]: ... + if sys.version_info >= (3, 11): + # @fun.register(complex | str) + # def _(arg, verbose=False): ... + @overload + def register( + self, cls: types.UnionType, func: None = None + ) -> Callable[[Callable[Concatenate[_S, _P], _R]], Callable[Concatenate[_S, _P], _R]]: ... # @fun.register(complex) # def _(arg, verbose=False): ... @overload - def register(self, cls: _RegType, func: None = None) -> Callable[[Callable[..., _T]], Callable[..., _T]]: ... + def register( # type: ignore[overload-overlap] + self, cls: type[_S], func: None = None + ) -> Callable[[Callable[Concatenate[_S, _P], _R]], Callable[Concatenate[_S, _P], _R]]: ... # @fun.register # def _(arg: int, verbose=False): @overload - def register(self, cls: Callable[..., _T], func: None = None) -> Callable[..., _T]: ... + def register(self, cls: Callable[Concatenate[_S, _P], _R], func: None = None) -> Callable[Concatenate[_S, _P], _R]: ... + if sys.version_info >= (3, 11): + # fun.register(int, lambda x: x) + @overload + def register( + self, cls: types.UnionType, func: Callable[Concatenate[_S, _P], _R] + ) -> Callable[Concatenate[_S, _P], _R]: ... # fun.register(int, lambda x: x) @overload - def register(self, cls: _RegType, func: Callable[..., _T]) -> Callable[..., _T]: ... + def register(self, cls: type[_S], func: Callable[Concatenate[_S, _P], _R]) -> Callable[Concatenate[_S, _P], _R]: ... def _clear_cache(self) -> None: ... - def __call__(self, /, *args: Any, **kwargs: Any) -> _T: ... + def __call__(self, arg: object, /, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... -def singledispatch(func: Callable[..., _T]) -> _SingleDispatchCallable[_T]: ... +def singledispatch(func: Callable[Concatenate[object, _P], _R]) -> _SingleDispatchCallable[_P, _R]: ... -class singledispatchmethod(Generic[_T]): - dispatcher: _SingleDispatchCallable[_T] - func: Callable[..., _T] - def __init__(self, func: Callable[..., _T]) -> None: ... +class singledispatchmethod(Generic[_P, _R]): + dispatcher: _SingleDispatchCallable[_P, _R] + func: Callable[_P, _R] + def __init__(self, func: Callable[Concatenate[object, _P], _R]) -> None: ... @property def __isabstractmethod__(self) -> bool: ... + if sys.version_info >= (3, 11): + @overload + def register( + self, cls: types.UnionType, method: None = None + ) -> Callable[[Callable[Concatenate[_S, _P], _R]], Callable[Concatenate[_S, _P], _R]]: ... + @overload - def register(self, cls: _RegType, method: None = None) -> Callable[[Callable[..., _T]], Callable[..., _T]]: ... + def register( # type: ignore[overload-overlap] + self, cls: type[_S], method: None = None + ) -> Callable[[Callable[Concatenate[_S, _P], _R]], Callable[Concatenate[_S, _P], _R]]: ... @overload - def register(self, cls: Callable[..., _T], method: None = None) -> Callable[..., _T]: ... + def register(self, cls: Callable[Concatenate[_S, _P], _R], method: None = None) -> Callable[Concatenate[_S, _P], _R]: ... + if sys.version_info >= (3, 11): + @overload + def register( + self, cls: types.UnionType, method: Callable[Concatenate[_S, _P], _R] + ) -> Callable[Concatenate[_S, _P], _R]: ... + @overload - def register(self, cls: _RegType, method: Callable[..., _T]) -> Callable[..., _T]: ... - def __get__(self, obj: _S, cls: type[_S] | None = None) -> Callable[..., _T]: ... + def register(self, cls: type[_S], method: Callable[Concatenate[_S, _P], _R]) -> Callable[Concatenate[_S, _P], _R]: ... + def __get__(self, obj: _S, cls: type[_S] | None = None) -> Callable[Concatenate[_S, _P], _R]: ... class cached_property(Generic[_T_co]): func: Callable[[Any], _T_co] diff --git a/stubs/aiofiles/aiofiles/threadpool/__init__.pyi b/stubs/aiofiles/aiofiles/threadpool/__init__.pyi index 51cb1d2fd8bd..554cd6ea1109 100644 --- a/stubs/aiofiles/aiofiles/threadpool/__init__.pyi +++ b/stubs/aiofiles/aiofiles/threadpool/__init__.pyi @@ -99,7 +99,7 @@ def open( executor: Executor | None = None, ) -> AiofilesContextManager[_UnknownAsyncBinaryIO]: ... -wrap: _SingleDispatchCallable[Any] +wrap: _SingleDispatchCallable[Any, Any] stdin: AsyncTextIndirectIOWrapper stdout: AsyncTextIndirectIOWrapper