Start using ParamSpec for decorator functions (#63148)

This commit is contained in:
Marc Mueller 2022-01-04 18:37:46 +01:00 committed by GitHub
parent 3a32fe9a34
commit 53496c019c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 82 additions and 40 deletions

View file

@ -2,17 +2,18 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Sequence from collections.abc import Awaitable, Callable, Sequence
import contextlib import contextlib
from datetime import datetime, timedelta from datetime import datetime, timedelta
import functools import functools
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar
from async_upnp_client import UpnpService, UpnpStateVariable from async_upnp_client import UpnpService, UpnpStateVariable
from async_upnp_client.const import NotificationSubType from async_upnp_client.const import NotificationSubType
from async_upnp_client.exceptions import UpnpError, UpnpResponseError from async_upnp_client.exceptions import UpnpError, UpnpResponseError
from async_upnp_client.profiles.dlna import DmrDevice, PlayMode, TransportState from async_upnp_client.profiles.dlna import DmrDevice, PlayMode, TransportState
from async_upnp_client.utils import async_get_local_ip from async_upnp_client.utils import async_get_local_ip
from typing_extensions import Concatenate, ParamSpec
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import ssdp from homeassistant.components import ssdp
@ -65,27 +66,32 @@ from .data import EventListenAddr, get_domain_data
PARALLEL_UPDATES = 0 PARALLEL_UPDATES = 0
Func = TypeVar("Func", bound=Callable[..., Any]) _T = TypeVar("_T", bound="DlnaDmrEntity")
_R = TypeVar("_R")
_P = ParamSpec("_P")
def catch_request_errors(func: Func) -> Func: def catch_request_errors(
func: Callable[Concatenate[_T, _P], Awaitable[_R]] # type: ignore[misc]
) -> Callable[Concatenate[_T, _P], Awaitable[_R | None]]: # type: ignore[misc]
"""Catch UpnpError errors.""" """Catch UpnpError errors."""
@functools.wraps(func) @functools.wraps(func)
async def wrapper(self: "DlnaDmrEntity", *args: Any, **kwargs: Any) -> Any: async def wrapper(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R | None:
"""Catch UpnpError errors and check availability before and after request.""" """Catch UpnpError errors and check availability before and after request."""
if not self.available: if not self.available:
_LOGGER.warning( _LOGGER.warning(
"Device disappeared when trying to call service %s", func.__name__ "Device disappeared when trying to call service %s", func.__name__
) )
return return None
try: try:
return await func(self, *args, **kwargs) return await func(self, *args, **kwargs) # type: ignore[no-any-return] # mypy can't yet infer 'func'
except UpnpError as err: except UpnpError as err:
self.check_available = True self.check_available = True
_LOGGER.error("Error during call %s: %r", func.__name__, err) _LOGGER.error("Error during call %s: %r", func.__name__, err)
return None
return cast(Func, wrapper) return wrapper
async def async_setup_entry( async def async_setup_entry(

View file

@ -1,21 +1,29 @@
"""Utilities for Evil Genius Labs.""" """Utilities for Evil Genius Labs."""
from collections.abc import Callable from __future__ import annotations
from collections.abc import Awaitable, Callable
from functools import wraps from functools import wraps
from typing import Any, TypeVar, cast from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from . import EvilGeniusEntity from . import EvilGeniusEntity
CallableT = TypeVar("CallableT", bound=Callable) _T = TypeVar("_T", bound=EvilGeniusEntity)
_R = TypeVar("_R")
_P = ParamSpec("_P")
def update_when_done(func: CallableT) -> CallableT: def update_when_done(
func: Callable[Concatenate[_T, _P], Awaitable[_R]] # type: ignore[misc]
) -> Callable[Concatenate[_T, _P], Awaitable[_R]]: # type: ignore[misc]
"""Decorate function to trigger update when function is done.""" """Decorate function to trigger update when function is done."""
@wraps(func) @wraps(func)
async def wrapper(self: EvilGeniusEntity, *args: Any, **kwargs: Any) -> Any: async def wrapper(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Wrap function.""" """Wrap function."""
result = await func(self, *args, **kwargs) result = await func(self, *args, **kwargs)
await self.coordinator.async_request_refresh() await self.coordinator.async_request_refresh()
return result return result # type: ignore[no-any-return] # mypy can't yet infer 'func'
return cast(CallableT, wrapper) return wrapper

View file

@ -1,10 +1,12 @@
"""Helper methods for common tasks.""" """Helper methods for common tasks."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from typing import TYPE_CHECKING, TypeVar
from soco.exceptions import SoCoException, SoCoUPnPException from soco.exceptions import SoCoException, SoCoUPnPException
from typing_extensions import Concatenate, ParamSpec
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.helpers.dispatcher import dispatcher_send
@ -19,20 +21,26 @@ if TYPE_CHECKING:
UID_PREFIX = "RINCON_" UID_PREFIX = "RINCON_"
UID_POSTFIX = "01400" UID_POSTFIX = "01400"
WrapFuncType = TypeVar("WrapFuncType", bound=Callable[..., Any])
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_T = TypeVar("_T", "SonosSpeaker", "SonosEntity")
_R = TypeVar("_R")
_P = ParamSpec("_P")
def soco_error( def soco_error(
errorcodes: list[str] | None = None, raise_on_err: bool = True errorcodes: list[str] | None = None, raise_on_err: bool = True
) -> Callable: ) -> Callable[ # type: ignore[misc]
[Callable[Concatenate[_T, _P], _R]], Callable[Concatenate[_T, _P], _R | None]
]:
"""Filter out specified UPnP errors and raise exceptions for service calls.""" """Filter out specified UPnP errors and raise exceptions for service calls."""
def decorator(funct: WrapFuncType) -> WrapFuncType: def decorator(
funct: Callable[Concatenate[_T, _P], _R] # type: ignore[misc]
) -> Callable[Concatenate[_T, _P], _R | None]: # type: ignore[misc]
"""Decorate functions.""" """Decorate functions."""
def wrapper(self: SonosSpeaker | SonosEntity, *args: Any, **kwargs: Any) -> Any: def wrapper(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R | None:
"""Wrap for all soco UPnP exception.""" """Wrap for all soco UPnP exception."""
try: try:
result = funct(self, *args, **kwargs) result = funct(self, *args, **kwargs)
@ -65,7 +73,7 @@ def soco_error(
) )
return result return result
return cast(WrapFuncType, wrapper) return wrapper
return decorator return decorator

View file

@ -1,9 +1,11 @@
"""Common code for tplink.""" """Common code for tplink."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, TypeVar, cast from collections.abc import Awaitable, Callable
from typing import TypeVar
from kasa import SmartDevice from kasa import SmartDevice
from typing_extensions import Concatenate, ParamSpec
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity import DeviceInfo
@ -12,19 +14,20 @@ from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN from .const import DOMAIN
from .coordinator import TPLinkDataUpdateCoordinator from .coordinator import TPLinkDataUpdateCoordinator
WrapFuncType = TypeVar("WrapFuncType", bound=Callable[..., Any]) _T = TypeVar("_T", bound="CoordinatedTPLinkEntity")
_P = ParamSpec("_P")
def async_refresh_after(func: WrapFuncType) -> WrapFuncType: def async_refresh_after(
func: Callable[Concatenate[_T, _P], Awaitable[None]] # type: ignore[misc]
) -> Callable[Concatenate[_T, _P], Awaitable[None]]: # type: ignore[misc]
"""Define a wrapper to refresh after.""" """Define a wrapper to refresh after."""
async def _async_wrap( async def _async_wrap(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> None:
self: CoordinatedTPLinkEntity, *args: Any, **kwargs: Any
) -> None:
await func(self, *args, **kwargs) await func(self, *args, **kwargs)
await self.coordinator.async_request_refresh_without_children() await self.coordinator.async_request_refresh_without_children()
return cast(WrapFuncType, _async_wrap) return _async_wrap
class CoordinatedTPLinkEntity(CoordinatorEntity): class CoordinatedTPLinkEntity(CoordinatorEntity):

View file

@ -1,12 +1,14 @@
"""Provide functionality to interact with the vlc telnet interface.""" """Provide functionality to interact with the vlc telnet interface."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Any, Callable, TypeVar, cast from typing import Any, TypeVar
from aiovlc.client import Client from aiovlc.client import Client
from aiovlc.exceptions import AuthError, CommandError, ConnectError from aiovlc.exceptions import AuthError, CommandError, ConnectError
from typing_extensions import Concatenate, ParamSpec
from homeassistant.components.media_player import MediaPlayerEntity from homeassistant.components.media_player import MediaPlayerEntity
from homeassistant.components.media_player.const import ( from homeassistant.components.media_player.const import (
@ -49,7 +51,9 @@ SUPPORT_VLC = (
| SUPPORT_VOLUME_SET | SUPPORT_VOLUME_SET
) )
Func = TypeVar("Func", bound=Callable[..., Any]) _T = TypeVar("_T", bound="VlcDevice")
_R = TypeVar("_R")
_P = ParamSpec("_P")
async def async_setup_entry( async def async_setup_entry(
@ -64,11 +68,13 @@ async def async_setup_entry(
async_add_entities([VlcDevice(entry, vlc, name, available)], True) async_add_entities([VlcDevice(entry, vlc, name, available)], True)
def catch_vlc_errors(func: Func) -> Func: def catch_vlc_errors(
func: Callable[Concatenate[_T, _P], Awaitable[None]] # type: ignore[misc]
) -> Callable[Concatenate[_T, _P], Awaitable[None]]: # type: ignore[misc]
"""Catch VLC errors.""" """Catch VLC errors."""
@wraps(func) @wraps(func)
async def wrapper(self: VlcDevice, *args: Any, **kwargs: Any) -> Any: async def wrapper(self: VlcDevice, *args: _P.args, **kwargs: _P.kwargs) -> None:
"""Catch VLC errors and modify availability.""" """Catch VLC errors and modify availability."""
try: try:
await func(self, *args, **kwargs) await func(self, *args, **kwargs)
@ -80,7 +86,7 @@ def catch_vlc_errors(func: Func) -> Func:
LOGGER.error("Connection error: %s", err) LOGGER.error("Connection error: %s", err)
self._available = False self._available = False
return cast(Func, wrapper) return wrapper
class VlcDevice(MediaPlayerEntity): class VlcDevice(MediaPlayerEntity):

View file

@ -2,10 +2,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Any, Callable, TypeVar, cast from typing import Any, TypeVar
from typing_extensions import ParamSpec
from homeassistant.components.hassio import ( from homeassistant.components.hassio import (
async_create_backup, async_create_backup,
@ -35,7 +38,8 @@ from .const import (
LOGGER, LOGGER,
) )
F = TypeVar("F", bound=Callable[..., Any]) # pylint: disable=invalid-name _R = TypeVar("_R")
_P = ParamSpec("_P")
DATA_ADDON_MANAGER = f"{DOMAIN}_addon_manager" DATA_ADDON_MANAGER = f"{DOMAIN}_addon_manager"
@ -47,13 +51,17 @@ def get_addon_manager(hass: HomeAssistant) -> AddonManager:
return AddonManager(hass) return AddonManager(hass)
def api_error(error_message: str) -> Callable[[F], F]: def api_error(
error_message: str,
) -> Callable[[Callable[_P, Awaitable[_R]]], Callable[_P, Awaitable[_R]]]:
"""Handle HassioAPIError and raise a specific AddonError.""" """Handle HassioAPIError and raise a specific AddonError."""
def handle_hassio_api_error(func: F) -> F: def handle_hassio_api_error(
func: Callable[_P, Awaitable[_R]]
) -> Callable[_P, Awaitable[_R]]:
"""Handle a HassioAPIError.""" """Handle a HassioAPIError."""
async def wrapper(*args, **kwargs): # type: ignore async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Wrap an add-on manager method.""" """Wrap an add-on manager method."""
try: try:
return_value = await func(*args, **kwargs) return_value = await func(*args, **kwargs)
@ -62,7 +70,7 @@ def api_error(error_message: str) -> Callable[[F], F]:
return return_value return return_value
return cast(F, wrapper) return wrapper
return handle_hassio_api_error return handle_hassio_api_error

View file

@ -29,6 +29,7 @@ pyyaml==6.0
requests==2.26.0 requests==2.26.0
scapy==2.4.5 scapy==2.4.5
sqlalchemy==1.4.27 sqlalchemy==1.4.27
typing-extensions>=3.10.0.2,<5.0
voluptuous-serialize==2.5.0 voluptuous-serialize==2.5.0
voluptuous==0.12.2 voluptuous==0.12.2
yarl==1.6.3 yarl==1.6.3

View file

@ -20,6 +20,7 @@ pip>=8.0.3,<20.3
python-slugify==4.0.1 python-slugify==4.0.1
pyyaml==6.0 pyyaml==6.0
requests==2.26.0 requests==2.26.0
typing-extensions>=3.10.0.2,<5.0
voluptuous==0.12.2 voluptuous==0.12.2
voluptuous-serialize==2.5.0 voluptuous-serialize==2.5.0
yarl==1.6.3 yarl==1.6.3

View file

@ -52,6 +52,7 @@ REQUIRES = [
"python-slugify==4.0.1", "python-slugify==4.0.1",
"pyyaml==6.0", "pyyaml==6.0",
"requests==2.26.0", "requests==2.26.0",
"typing-extensions>=3.10.0.2,<5.0",
"voluptuous==0.12.2", "voluptuous==0.12.2",
"voluptuous-serialize==2.5.0", "voluptuous-serialize==2.5.0",
"yarl==1.6.3", "yarl==1.6.3",