"""Config flow for AdGuard Control Hub integration."""
import asyncio
import logging
import re
from typing import Any, Dict, Optional

import voluptuous as vol

from homeassistant import config_entries
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv

from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError, AdGuardTimeoutError
from .const import (
    CONF_SSL,
    CONF_VERIFY_SSL,
    DEFAULT_PORT,
    DEFAULT_SSL,
    DEFAULT_VERIFY_SSL,
    DOMAIN,
)

_LOGGER = logging.getLogger(__name__)

STEP_USER_DATA_SCHEMA = vol.Schema({
    vol.Required(CONF_HOST): cv.string,
    vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
    vol.Optional(CONF_USERNAME): cv.string,
    vol.Optional(CONF_PASSWORD): cv.string,
    vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
    vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
})

# Leading "http://" or "https://" in any letter case, as users often paste a URL.
_SCHEME_RE = re.compile(r"^https?://", re.IGNORECASE)
# First character that starts a URL path, query string or fragment.
_URL_TAIL_RE = re.compile(r"[/?#]")


class CannotConnect(Exception):
    """Error to indicate we cannot connect."""


class InvalidAuth(Exception):
    """Error to indicate there is invalid auth."""


class InvalidHost(Exception):
    """Error to indicate invalid host."""


class InvalidPort(Exception):
    """Error to indicate invalid port."""


class Timeout(Exception):
    """Error to indicate connection timeout."""


def validate_host(host: str) -> str:
    """Validate and clean a user-supplied host.

    Accepts either a bare host or a pasted URL and reduces it to the host
    part. It removes surrounding whitespace, a leading ``http://`` or
    ``https://`` scheme (matched case-insensitively), and any path, query
    string or fragment.

    Args:
        host: Raw value entered by the user.

    Returns:
        The cleaned host. A ``:port`` suffix, if present, is left untouched.

    Raises:
        InvalidHost: If nothing remains after cleaning (e.g. ``"http://"``),
            or if the host contains whitespace.
    """
    host = host.strip()
    if not host:
        raise InvalidHost("Host cannot be empty")

    host = _SCHEME_RE.sub("", host, count=1)
    host = _URL_TAIL_RE.split(host, maxsplit=1)[0]

    if not host:
        raise InvalidHost("Host cannot be empty")
    if any(char.isspace() for char in host):
        raise InvalidHost("Host cannot contain whitespace")
    return host


async def _async_get_version(api: AdGuardHomeAPI) -> str:
    """Return the version reported by ``api.get_status()``, or "unknown".

    This is best effort. It is only called after the connection test has
    succeeded, so a failure here must not block setup. It is logged at
    debug level instead of being silently discarded.
    """
    try:
        status = await api.get_status()
        return status.get("version", "unknown")
    except Exception:  # pylint: disable=broad-except
        _LOGGER.debug("Could not read AdGuard Home status", exc_info=True)
        return "unknown"


async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate that the user input allows us to connect.

    ``data[CONF_HOST]`` is replaced in place with the cleaned host, so the
    caller stores the normalised value in the config entry.

    Args:
        hass: Home Assistant instance, used to obtain the shared aiohttp session.
        data: Form input validated by ``STEP_USER_DATA_SCHEMA``.

    Returns:
        Dict with ``title``, ``version`` ("unknown" if the status call fails)
        and ``host`` (the cleaned host).

    Raises:
        InvalidHost: The host is empty or malformed.
        InvalidPort: The port is outside 1-65535.
        InvalidAuth: AdGuard Home rejected the credentials.
        Timeout: The connection attempt timed out.
        CannotConnect: Any other connection failure.
    """
    host = validate_host(data[CONF_HOST])
    data[CONF_HOST] = host

    port = data[CONF_PORT]
    if not 1 <= port <= 65535:
        raise InvalidPort("Port must be between 1 and 65535")

    verify_ssl = data.get(CONF_VERIFY_SSL, DEFAULT_VERIFY_SSL)
    session = async_get_clientsession(hass, verify_ssl)
    api = AdGuardHomeAPI(
        host=host,
        port=port,
        username=data.get(CONF_USERNAME),
        password=data.get(CONF_PASSWORD),
        ssl=data.get(CONF_SSL, DEFAULT_SSL),
        verify_ssl=verify_ssl,
        session=session,
        timeout=10,
    )

    # Only the connection test is mapped to flow errors here. The "returned
    # False" case is handled below, so the broad handler does not log it as
    # an unexpected error.
    try:
        connected = await api.test_connection()
    except AdGuardAuthError as err:
        raise InvalidAuth from err
    except AdGuardTimeoutError as err:
        raise Timeout from err
    except AdGuardConnectionError as err:
        # Some timeouts may surface as generic connection errors; classify by message.
        if "timeout" in str(err).lower():
            raise Timeout from err
        raise CannotConnect from err
    except asyncio.TimeoutError as err:
        raise Timeout from err
    except Exception as err:  # pylint: disable=broad-except
        _LOGGER.exception("Unexpected error during validation")
        raise CannotConnect from err

    if not connected:
        raise CannotConnect("Failed to connect to AdGuard Home")

    return {
        "title": f"AdGuard Control Hub ({host})",
        "version": await _async_get_version(api),
        "host": host,
    }


class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow for AdGuard Control Hub."""

    VERSION = 1
    MINOR_VERSION = 1

    async def async_step_user(
        self, user_input: Optional[Dict[str, Any]] = None
    ) -> FlowResult:
        """Handle the initial step.

        With no input, show the form. With input, validate it. On success,
        create an entry whose unique ID is ``"<cleaned host>:<port>"``. On
        failure, re-show the form with the matching error key.
        """
        errors: Dict[str, str] = {}

        if user_input is not None:
            try:
                info = await validate_input(self.hass, user_input)
            except CannotConnect:
                errors["base"] = "cannot_connect"
            except InvalidAuth:
                errors["base"] = "invalid_auth"
            except InvalidHost:
                errors[CONF_HOST] = "invalid_host"
            except InvalidPort:
                errors[CONF_PORT] = "invalid_port"
            except Timeout:
                errors["base"] = "timeout"
            except Exception:  # pylint: disable=broad-except
                _LOGGER.exception("Unexpected exception")
                errors["base"] = "unknown"
            else:
                # These calls stay outside the try on purpose. They abort the
                # flow by raising AbortFlow (an Exception subclass), which the
                # broad handler above would otherwise turn into "unknown".
                await self.async_set_unique_id(
                    f"{info['host']}:{user_input[CONF_PORT]}"
                )
                self._abort_if_unique_id_configured()
                return self.async_create_entry(
                    title=info["title"],
                    data=user_input,
                )

        return self.async_show_form(
            step_id="user",
            data_schema=STEP_USER_DATA_SCHEMA,
            errors=errors,
        )