Improve schema typing (3) (#120521)

This commit is contained in:
Marc Mueller 2024-06-26 11:30:07 +02:00 committed by GitHub
parent afbd24adfe
commit d527113d59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 44 additions and 35 deletions

View file

@ -58,9 +58,9 @@ class InputButtonStorageCollection(collection.DictStorageCollection):
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
async def _process_create_data(self, data: dict) -> vol.Schema:
async def _process_create_data(self, data: dict) -> dict[str, str]:
"""Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data)
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
@callback
def _get_suggested_id(self, info: dict) -> str:

View file

@ -163,9 +163,9 @@ class InputTextStorageCollection(collection.DictStorageCollection):
CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))
async def _process_create_data(self, data: dict[str, Any]) -> vol.Schema:
async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data)
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]
@callback
def _get_suggested_id(self, info: dict[str, Any]) -> str:

View file

@ -302,7 +302,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:
def preprocess_turn_on_alternatives(
hass: HomeAssistant, params: dict[str, Any]
hass: HomeAssistant, params: dict[str, Any] | dict[str | vol.Optional, Any]
) -> None:
"""Process extra data for turn light on request.
@ -406,7 +406,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
# of the light base platform.
hass.async_create_task(profiles.async_initialize(), eager_start=True)
def preprocess_data(data: dict[str, Any]) -> dict[str | vol.Optional, Any]:
def preprocess_data(
data: dict[str | vol.Optional, Any],
) -> dict[str | vol.Optional, Any]:
"""Preprocess the service data."""
base: dict[str | vol.Optional, Any] = {
entity_field: data.pop(entity_field)

View file

@ -226,14 +226,16 @@ class MotionEyeOptionsFlow(OptionsFlow):
if self.show_advanced_options:
# The input URL is not validated as being a URL, to allow for the possibility
# the template input won't be a valid URL until after it's rendered
stream_kwargs = {}
description: dict[str, str] | None = None
if CONF_STREAM_URL_TEMPLATE in self._config_entry.options:
stream_kwargs["description"] = {
description = {
"suggested_value": self._config_entry.options[
CONF_STREAM_URL_TEMPLATE
]
}
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, **stream_kwargs)] = str
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, description=description)] = (
str
)
return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))

View file

@ -167,8 +167,9 @@ async def async_get_action_capabilities(
hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]:
"""List action capabilities."""
return {"extra_fields": DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE], {})}
if (fields := DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE])) is None:
return {}
return {"extra_fields": fields}
async def _execute_service_based_action(

View file

@ -80,10 +80,8 @@ def validate_event_data(obj: dict) -> dict:
except ValidationError as exc:
# Filter out required field errors if keys can be missing, and if there are
# still errors, raise an exception
if errors := [
error for error in exc.errors() if error["type"] != "value_error.missing"
]:
raise vol.MultipleInvalid(errors) from exc
if [error for error in exc.errors() if error["type"] != "value_error.missing"]:
raise vol.MultipleInvalid from exc
return obj

View file

@ -5,7 +5,7 @@ from __future__ import annotations
import abc
import asyncio
from collections import defaultdict
from collections.abc import Callable, Container, Iterable, Mapping
from collections.abc import Callable, Container, Hashable, Iterable, Mapping
from contextlib import suppress
import copy
from dataclasses import dataclass
@ -13,7 +13,7 @@ from enum import StrEnum
from functools import partial
import logging
from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict
from typing import Any, Generic, Required, TypedDict, cast
from typing_extensions import TypeVar
import voluptuous as vol
@ -120,7 +120,7 @@ class InvalidData(vol.Invalid): # type: ignore[misc]
def __init__(
self,
message: str,
path: list[str | vol.Marker] | None,
path: list[Hashable] | None,
error_message: str | None,
schema_errors: dict[str, Any],
**kwargs: Any,
@ -384,6 +384,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
if (
data_schema := cur_step.get("data_schema")
) is not None and user_input is not None:
data_schema = cast(vol.Schema, data_schema)
try:
user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex:
@ -694,7 +695,7 @@ class FlowHandler(Generic[_FlowResultT, _HandlerT]):
):
# Copy the marker to not modify the flow schema
new_key = copy.copy(key)
new_key.description = {"suggested_value": suggested_values[key]}
new_key.description = {"suggested_value": suggested_values[key.schema]}
schema[new_key] = val
return vol.Schema(schema)

View file

@ -981,7 +981,7 @@ def removed(
def key_value_schemas(
key: str,
value_schemas: dict[Hashable, VolSchemaType],
value_schemas: dict[Hashable, VolSchemaType | Callable[[Any], dict[str, Any]]],
default_schema: VolSchemaType | None = None,
default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]:
@ -1016,12 +1016,12 @@ def key_value_schemas(
# Validator helpers
def key_dependency(
def key_dependency[_KT: Hashable, _VT](
key: Hashable, dependency: Hashable
) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]:
) -> Callable[[dict[_KT, _VT]], dict[_KT, _VT]]:
"""Validate that all dependencies exist for key."""
def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]:
def validator(value: dict[_KT, _VT]) -> dict[_KT, _VT]:
"""Test dependencies."""
if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict")
@ -1405,13 +1405,13 @@ STATE_CONDITION_ATTRIBUTE_SCHEMA = vol.Schema(
)
def STATE_CONDITION_SCHEMA(value: Any) -> dict:
def STATE_CONDITION_SCHEMA(value: Any) -> dict[str, Any]:
"""Validate a state condition."""
if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary")
if CONF_ATTRIBUTE in value:
validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
validated: dict[str, Any] = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
else:
validated = STATE_CONDITION_STATE_SCHEMA(value)

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Collection, Coroutine, Iterable
from collections.abc import Callable, Collection, Coroutine, Iterable
import dataclasses
from dataclasses import dataclass, field
from enum import Enum, auto
@ -37,6 +37,9 @@ from .typing import VolSchemaType
_LOGGER = logging.getLogger(__name__)
type _SlotsType = dict[str, Any]
type _IntentSlotsType = dict[
str | tuple[str, str], VolSchemaType | Callable[[Any], Any]
]
INTENT_TURN_OFF = "HassTurnOff"
INTENT_TURN_ON = "HassTurnOn"
@ -808,8 +811,8 @@ class DynamicServiceIntentHandler(IntentHandler):
self,
intent_type: str,
speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
required_slots: _IntentSlotsType | None = None,
optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None,
required_features: int | None = None,
required_states: set[str] | None = None,
@ -825,7 +828,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.description = description
self.platforms = platforms
self.required_slots: dict[tuple[str, str], VolSchemaType] = {}
self.required_slots: _IntentSlotsType = {}
if required_slots:
for key, value_schema in required_slots.items():
if isinstance(key, str):
@ -834,7 +837,7 @@ class DynamicServiceIntentHandler(IntentHandler):
self.required_slots[key] = value_schema
self.optional_slots: dict[tuple[str, str], VolSchemaType] = {}
self.optional_slots: _IntentSlotsType = {}
if optional_slots:
for key, value_schema in optional_slots.items():
if isinstance(key, str):
@ -1108,8 +1111,8 @@ class ServiceIntentHandler(DynamicServiceIntentHandler):
domain: str,
service: str,
speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
required_slots: _IntentSlotsType | None = None,
optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None,
required_features: int | None = None,
required_states: set[str] | None = None,

View file

@ -175,7 +175,9 @@ class SchemaCommonFlowHandler:
and key.default is not vol.UNDEFINED
and key not in self._options
):
user_input[str(key.schema)] = key.default()
user_input[str(key.schema)] = cast(
Callable[[], Any], key.default
)()
if user_input is not None and form_step.validate_user_input is not None:
# Do extra validation of user input
@ -215,7 +217,7 @@ class SchemaCommonFlowHandler:
)
):
# Key not present, delete keys old value (if present) too
values.pop(key, None)
values.pop(key.schema, None)
async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep
@ -491,7 +493,7 @@ def wrapped_entity_config_entry_title(
def entity_selector_without_own_entities(
handler: SchemaOptionsFlowHandler,
entity_selector_config: selector.EntitySelectorConfig,
) -> vol.Schema:
) -> selector.EntitySelector:
"""Return an entity selector which excludes own entities."""
entity_registry = er.async_get(handler.hass)
entities = er.async_entries_for_config_entry(