diff --git a/custom_components/adguard_hub/__init__.py b/custom_components/adguard_hub/__init__.py index 5584aee..7c1013f 100644 --- a/custom_components/adguard_hub/__init__.py +++ b/custom_components/adguard_hub/__init__.py @@ -1,20 +1,24 @@ """ -🛡️ AdGuard Control Hub for Home Assistant. +AdGuard Control Hub for Home Assistant. -Transform your AdGuard Home into a smart network management powerhouse with +Transform your AdGuard Home into a smart network management powerhouse with complete client control, service blocking, and automation capabilities. """ import asyncio import logging from datetime import timedelta +from typing import Dict, Any + from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .api import AdGuardHomeAPI, AdGuardConnectionError from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL -from .api import AdGuardHomeAPI +from .services import AdGuardControlHubServices _LOGGER = logging.getLogger(__name__) @@ -23,6 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up AdGuard Control Hub from a config entry.""" session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True)) + # Create API instance api = AdGuardHomeAPI( host=entry.data[CONF_HOST], port=entry.data[CONF_PORT], @@ -34,16 +39,26 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Test the connection try: - await api.test_connection() - _LOGGER.info("Successfully connected to AdGuard Home at %s:%s", - entry.data[CONF_HOST], entry.data[CONF_PORT]) + if not await api.test_connection(): + raise ConfigEntryNotReady("Unable to connect to AdGuard Home") + + _LOGGER.info( + "Successfully connected to AdGuard Home at %s:%s", + entry.data[CONF_HOST], + entry.data[CONF_PORT] + ) except Exception as err: _LOGGER.error("Failed to connect to AdGuard Home: %s", err) - raise ConfigEntryNotReady(f"Unable to connect: {err}") + raise ConfigEntryNotReady(f"Unable to connect: {err}") from err # Create update coordinator coordinator = AdGuardControlHubCoordinator(hass, api) - await coordinator.async_config_entry_first_refresh() + + try: + await coordinator.async_config_entry_first_refresh() + except Exception as err: + _LOGGER.error("Failed to perform initial data refresh: %s", err) + raise ConfigEntryNotReady(f"Failed to fetch initial data: {err}") from err # Store data hass.data.setdefault(DOMAIN, {}) @@ -53,9 +68,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: } # Set up platforms - await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + try: + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + except Exception as err: + _LOGGER.error("Failed to set up platforms: %s", err) + # Clean up on failure + hass.data[DOMAIN].pop(entry.entry_id) + raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err - _LOGGER.info("AdGuard Control Hub setup complete") + # Register services (only once, not per config entry) + if not hass.services.has_service(DOMAIN, "block_services"): + services = AdGuardControlHubServices(hass) + services.register_services() + + # Store services instance for cleanup + hass.data.setdefault(f"{DOMAIN}_services", services) + + _LOGGER.info("AdGuard Control Hub setup complete for %s:%s", + entry.data[CONF_HOST], entry.data[CONF_PORT]) return True @@ -64,8 +94,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unload_ok: + # Remove this entry's data hass.data[DOMAIN].pop(entry.entry_id) + # Unregister services if this was the last entry + if not hass.data[DOMAIN]: # No more entries + services = hass.data.get(f"{DOMAIN}_services") + if services: + services.unregister_services() + hass.data.pop(f"{DOMAIN}_services", None) + + # Also clean up the empty domain entry + hass.data.pop(DOMAIN, None) + return unload_ok @@ -81,36 +122,54 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator): update_interval=timedelta(seconds=SCAN_INTERVAL), ) self.api = api - self._clients = {} - self._statistics = {} - self._protection_status = {} + self._clients: Dict[str, Any] = {} + self._statistics: Dict[str, Any] = {} + self._protection_status: Dict[str, Any] = {} - async def _async_update_data(self): + async def _async_update_data(self) -> Dict[str, Any]: """Fetch data from AdGuard Home.""" try: # Fetch all data concurrently for better performance - results = await asyncio.gather( + tasks = [ self.api.get_clients(), self.api.get_statistics(), self.api.get_status(), - return_exceptions=True, - ) + ] + results = await asyncio.gather(*tasks, return_exceptions=True) clients, statistics, status = results - # Handle any exceptions + # Handle any exceptions in individual requests for i, result in enumerate(results): if isinstance(result, Exception): endpoint_names = ["clients", "statistics", "status"] - _LOGGER.warning("Error fetching %s: %s", endpoint_names[i], result) + _LOGGER.warning( + "Error fetching %s from %s:%s: %s", + endpoint_names[i], + self.api.host, + self.api.port, + result + ) # Update stored data (use empty dict if fetch failed) - self._clients = { - client["name"]: client - for client in (clients.get("clients", []) if not isinstance(clients, Exception) else []) - } - self._statistics = statistics if not isinstance(statistics, Exception) else {} - self._protection_status = status if not isinstance(status, Exception) else {} + if not isinstance(clients, Exception): + self._clients = { + client["name"]: client + for client in clients.get("clients", []) + if client.get("name") # Ensure client has a name + } + else: + _LOGGER.warning("Failed to update clients data, keeping previous data") + + if not isinstance(statistics, Exception): + self._statistics = statistics + else: + _LOGGER.warning("Failed to update statistics data, keeping previous data") + + if not isinstance(status, Exception): + self._protection_status = status + else: + _LOGGER.warning("Failed to update status data, keeping previous data") return { "clients": self._clients, @@ -118,20 +177,40 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator): "status": self._protection_status, } + except AdGuardConnectionError as err: + raise UpdateFailed(f"Connection error to AdGuard Home: {err}") from err except Exception as err: - raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") + raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") from err @property - def clients(self): + def clients(self) -> Dict[str, Any]: """Return clients data.""" return self._clients @property - def statistics(self): + def statistics(self) -> Dict[str, Any]: """Return statistics data.""" return self._statistics @property - def protection_status(self): + def protection_status(self) -> Dict[str, Any]: """Return protection status data.""" return self._protection_status + + def get_client(self, client_name: str) -> Dict[str, Any] | None: + """Get a specific client by name.""" + return self._clients.get(client_name) + + def has_client(self, client_name: str) -> bool: + """Check if a client exists.""" + return client_name in self._clients + + @property + def client_count(self) -> int: + """Return the number of clients.""" + return len(self._clients) + + @property + def is_protection_enabled(self) -> bool: + """Return True if protection is enabled.""" + return self._protection_status.get("protection_enabled", False) diff --git a/custom_components/adguard_hub/api.py b/custom_components/adguard_hub/api.py index 34d18a2..ebd78b8 100644 --- a/custom_components/adguard_hub/api.py +++ b/custom_components/adguard_hub/api.py @@ -1,102 +1,207 @@ """API wrapper for AdGuard Home.""" +import asyncio import logging -from typing import Any +from typing import Any, Dict, Optional + import aiohttp -from aiohttp import BasicAuth +from aiohttp import BasicAuth, ClientError, ClientTimeout + from .const import API_ENDPOINTS _LOGGER = logging.getLogger(__name__) +# Custom exceptions +class AdGuardHomeError(Exception): + """Base exception for AdGuard Home API.""" + pass + +class AdGuardConnectionError(AdGuardHomeError): + """Exception for connection errors.""" + pass + +class AdGuardAuthError(AdGuardHomeError): + """Exception for authentication errors.""" + pass + +class AdGuardNotFoundError(AdGuardHomeError): + """Exception for not found errors.""" + pass class AdGuardHomeAPI: """API wrapper for AdGuard Home.""" - def __init__(self, host: str, port: int = 3000, username: str = None, - password: str = None, ssl: bool = False, session=None): + def __init__( + self, + host: str, + port: int = 3000, + username: Optional[str] = None, + password: Optional[str] = None, + ssl: bool = False, + session: Optional[aiohttp.ClientSession] = None, + timeout: int = 10, + ): + """Initialize the API wrapper.""" self.host = host self.port = port self.username = username self.password = password self.ssl = ssl - self.session = session + self._session = session + self._timeout = ClientTimeout(total=timeout) protocol = "https" if ssl else "http" self.base_url = f"{protocol}://{host}:{port}" + self._own_session = session is None - async def _request(self, method: str, endpoint: str, data: dict = None) -> dict: - """Make an API request.""" + async def __aenter__(self): + """Async context manager entry.""" + if self._own_session: + self._session = aiohttp.ClientSession(timeout=self._timeout) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._own_session and self._session: + await self._session.close() + + @property + def session(self) -> aiohttp.ClientSession: + """Get the session, creating one if needed.""" + if not self._session: + self._session = aiohttp.ClientSession(timeout=self._timeout) + return self._session + + async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]: + """Make an API request with comprehensive error handling.""" url = f"{self.base_url}{endpoint}" headers = {"Content-Type": "application/json"} auth = None + if self.username and self.password: auth = BasicAuth(self.username, self.password) try: - async with self.session.request(method, url, json=data, headers=headers, auth=auth) as response: + async with self.session.request( + method, url, json=data, headers=headers, auth=auth + ) as response: + + # Handle different HTTP status codes + if response.status == 401: + raise AdGuardAuthError("Authentication failed - check username/password") + elif response.status == 403: + raise AdGuardAuthError("Access forbidden - insufficient permissions") + elif response.status == 404: + raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}") + elif response.status >= 500: + raise AdGuardConnectionError(f"Server error {response.status}") + response.raise_for_status() + + # Handle empty responses if response.status == 204 or not response.content_length: return {} - return await response.json() + + try: + return await response.json() + except aiohttp.ContentTypeError: + # Handle non-JSON responses + text = await response.text() + _LOGGER.warning("Non-JSON response received: %s", text) + return {"response": text} + + except asyncio.TimeoutError as err: + raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}") + except ClientError as err: + raise AdGuardConnectionError(f"Client error: {err}") except Exception as err: - _LOGGER.error("Error communicating with AdGuard Home: %s", err) - raise + _LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err) + raise AdGuardHomeError(f"Unexpected error: {err}") async def test_connection(self) -> bool: - """Test the connection.""" + """Test the connection to AdGuard Home.""" try: await self._request("GET", API_ENDPOINTS["status"]) return True - except: + except Exception as err: + _LOGGER.debug("Connection test failed: %s", err) return False - async def get_status(self) -> dict: - """Get server status.""" + async def get_status(self) -> Dict[str, Any]: + """Get server status information.""" return await self._request("GET", API_ENDPOINTS["status"]) - async def get_clients(self) -> dict: - """Get all clients.""" + async def get_clients(self) -> Dict[str, Any]: + """Get all configured clients.""" return await self._request("GET", API_ENDPOINTS["clients"]) - async def get_statistics(self) -> dict: - """Get statistics.""" + async def get_statistics(self) -> Dict[str, Any]: + """Get DNS query statistics.""" return await self._request("GET", API_ENDPOINTS["stats"]) - async def set_protection(self, enabled: bool) -> dict: - """Enable or disable protection.""" + async def set_protection(self, enabled: bool) -> Dict[str, Any]: + """Enable or disable AdGuard protection.""" data = {"enabled": enabled} return await self._request("POST", API_ENDPOINTS["protection"], data) - async def add_client(self, client_data: dict) -> dict: - """Add a new client.""" + async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]: + """Add a new client configuration.""" + # Validate required fields + if "name" not in client_data: + raise ValueError("Client name is required") + if "ids" not in client_data or not client_data["ids"]: + raise ValueError("Client IDs are required") + return await self._request("POST", API_ENDPOINTS["clients_add"], client_data) - async def update_client(self, client_data: dict) -> dict: - """Update an existing client.""" + async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]: + """Update an existing client configuration.""" + if "name" not in client_data: + raise ValueError("Client name is required for update") + if "data" not in client_data: + raise ValueError("Client data is required for update") + return await self._request("POST", API_ENDPOINTS["clients_update"], client_data) - async def delete_client(self, client_name: str) -> dict: - """Delete a client.""" + async def delete_client(self, client_name: str) -> Dict[str, Any]: + """Delete a client configuration.""" + if not client_name: + raise ValueError("Client name is required") + data = {"name": client_name} return await self._request("POST", API_ENDPOINTS["clients_delete"], data) - async def get_client_by_name(self, client_name: str) -> dict: + async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]: """Get a specific client by name.""" - clients_data = await self.get_clients() - clients = clients_data.get("clients", []) + if not client_name: + return None - for client in clients: - if client.get("name") == client_name: - return client + try: + clients_data = await self.get_clients() + clients = clients_data.get("clients", []) - return None + for client in clients: + if client.get("name") == client_name: + return client - async def update_client_blocked_services(self, client_name: str, blocked_services: list, - schedule: dict = None) -> dict: + return None + except Exception as err: + _LOGGER.error("Failed to get client %s: %s", client_name, err) + return None + + async def update_client_blocked_services( + self, + client_name: str, + blocked_services: list, + schedule: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Update blocked services for a specific client.""" + if not client_name: + raise ValueError("Client name is required") + client = await self.get_client_by_name(client_name) if not client: - raise ValueError(f"Client '{client_name}' not found") + raise AdGuardNotFoundError(f"Client '{client_name}' not found") - # Prepare the blocked services data + # Prepare the blocked services data with proper structure if schedule: blocked_services_data = { "ids": blocked_services, @@ -110,7 +215,7 @@ class AdGuardHomeAPI: } } - # Update the client + # Update the client with new blocked services update_data = { "name": client_name, "data": { @@ -121,18 +226,23 @@ class AdGuardHomeAPI: return await self.update_client(update_data) - async def toggle_client_service(self, client_name: str, service_id: str, enabled: bool) -> dict: + async def toggle_client_service( + self, client_name: str, service_id: str, enabled: bool + ) -> Dict[str, Any]: """Toggle a specific service for a client.""" + if not client_name or not service_id: + raise ValueError("Client name and service ID are required") + client = await self.get_client_by_name(client_name) if not client: - raise ValueError(f"Client '{client_name}' not found") + raise AdGuardNotFoundError(f"Client '{client_name}' not found") # Get current blocked services blocked_services = client.get("blocked_services", {}) if isinstance(blocked_services, dict): service_ids = blocked_services.get("ids", []) else: - # Handle old format (list) + # Handle legacy format (direct list) service_ids = blocked_services if blocked_services else [] # Update the service list @@ -142,3 +252,12 @@ class AdGuardHomeAPI: service_ids.remove(service_id) return await self.update_client_blocked_services(client_name, service_ids) + + async def get_blocked_services(self) -> Dict[str, Any]: + """Get available blocked services.""" + return await self._request("GET", API_ENDPOINTS["blocked_services_all"]) + + async def close(self) -> None: + """Close the API session if we own it.""" + if self._own_session and self._session: + await self._session.close() diff --git a/custom_components/adguard_hub/binary_sensor.py b/custom_components/adguard_hub/binary_sensor.py new file mode 100644 index 0000000..33d5412 --- /dev/null +++ b/custom_components/adguard_hub/binary_sensor.py @@ -0,0 +1,166 @@ +"""Binary sensor platform for AdGuard Control Hub integration.""" +import logging +from typing import Any + +from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import CoordinatorEntity + +from . import AdGuardControlHubCoordinator +from .api import AdGuardHomeAPI +from .const import DOMAIN, MANUFACTURER, ICON_PROTECTION, ICON_PROTECTION_OFF + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up AdGuard Control Hub binary sensor platform.""" + coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"] + api = hass.data[DOMAIN][config_entry.entry_id]["api"] + + entities = [ + AdGuardProtectionBinarySensor(coordinator, api), + AdGuardFilteringBinarySensor(coordinator, api), + AdGuardSafeBrowsingBinarySensor(coordinator, api), + AdGuardParentalControlBinarySensor(coordinator, api), + AdGuardSafeSearchBinarySensor(coordinator, api), + ] + + async_add_entities(entities) + + +class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity): + """Base class for AdGuard binary sensors.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator) + self.api = api + self._attr_device_info = { + "identifiers": {(DOMAIN, f"{api.host}:{api.port}")}, + "name": f"AdGuard Control Hub ({api.host})", + "manufacturer": MANUFACTURER, + "model": "AdGuard Home", + } + + +class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show AdGuard protection status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled" + self._attr_name = "AdGuard Protection Status" + self._attr_device_class = BinarySensorDeviceClass.RUNNING + + @property + def is_on(self) -> bool | None: + """Return true if protection is enabled.""" + return self.coordinator.protection_status.get("protection_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + status = self.coordinator.protection_status + return { + "dns_port": status.get("dns_port", "N/A"), + "http_port": status.get("http_port", "N/A"), + "version": status.get("version", "N/A"), + "running": status.get("running", False), + } + + +class AdGuardFilteringBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show DNS filtering status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_filtering_enabled" + self._attr_name = "AdGuard DNS Filtering" + self._attr_device_class = BinarySensorDeviceClass.RUNNING + + @property + def is_on(self) -> bool | None: + """Return true if DNS filtering is enabled.""" + return self.coordinator.protection_status.get("filtering_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:dns" if self.is_on else "mdi:dns-off" + + +class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show Safe Browsing status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled" + self._attr_name = "AdGuard Safe Browsing" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + + @property + def is_on(self) -> bool | None: + """Return true if Safe Browsing is enabled.""" + return self.coordinator.protection_status.get("safebrowsing_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:security" if self.is_on else "mdi:security-off" + + +class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show Parental Control status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled" + self._attr_name = "AdGuard Parental Control" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + + @property + def is_on(self) -> bool | None: + """Return true if Parental Control is enabled.""" + return self.coordinator.protection_status.get("parental_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:account-child" if self.is_on else "mdi:account-child-outline" + + +class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show Safe Search status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled" + self._attr_name = "AdGuard Safe Search" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + + @property + def is_on(self) -> bool | None: + """Return true if Safe Search is enabled.""" + return self.coordinator.protection_status.get("safesearch_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:search-web" if self.is_on else "mdi:web-off" diff --git a/custom_components/adguard_hub/config_flow.py b/custom_components/adguard_hub/config_flow.py index 7679f9f..2b76fa9 100644 --- a/custom_components/adguard_hub/config_flow.py +++ b/custom_components/adguard_hub/config_flow.py @@ -1,73 +1,128 @@ """Config flow for AdGuard Control Hub integration.""" +import asyncio import logging -from typing import Any +from typing import Any, Dict, Optional + import voluptuous as vol from homeassistant import config_entries from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.helpers.aiohttp_client import async_get_clientsession -from .api import AdGuardHomeAPI -from .const import CONF_SSL, CONF_VERIFY_SSL, DEFAULT_PORT, DEFAULT_SSL, DEFAULT_VERIFY_SSL, DOMAIN +from homeassistant.data_entry_flow import FlowResult +import homeassistant.helpers.config_validation as cv + +from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError +from .const import ( + CONF_SSL, + CONF_VERIFY_SSL, + DEFAULT_PORT, + DEFAULT_SSL, + DEFAULT_VERIFY_SSL, + DOMAIN, +) _LOGGER = logging.getLogger(__name__) STEP_USER_DATA_SCHEMA = vol.Schema({ - vol.Required(CONF_HOST): str, - vol.Optional(CONF_PORT, default=DEFAULT_PORT): int, - vol.Optional(CONF_USERNAME): str, - vol.Optional(CONF_PASSWORD): str, - vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool, - vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): bool, + vol.Required(CONF_HOST): cv.string, + vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, + vol.Optional(CONF_USERNAME): cv.string, + vol.Optional(CONF_PASSWORD): cv.string, + vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, + vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, }) -async def validate_input(hass, data: dict) -> dict: + +async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: """Validate the user input allows us to connect.""" + # Normalize host + host = data[CONF_HOST].strip() + if not host: + raise InvalidHost("Host cannot be empty") + + # Remove protocol if provided + if host.startswith(("http://", "https://")): + host = host.split("://", 1)[1] + data[CONF_HOST] = host + + # Validate port + port = data[CONF_PORT] + if not (1 <= port <= 65535): + raise InvalidPort("Port must be between 1 and 65535") + + # Create session with appropriate SSL settings session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True)) + # Create API instance api = AdGuardHomeAPI( - host=data[CONF_HOST], - port=data[CONF_PORT], + host=host, + port=port, username=data.get(CONF_USERNAME), password=data.get(CONF_PASSWORD), ssl=data.get(CONF_SSL, False), session=session, + timeout=10, # 10 second timeout for setup ) # Test the connection - if not await api.test_connection(): - raise CannotConnect - - # Get server info try: - status = await api.get_status() - version = status.get("version", "unknown") - return { - "title": f"AdGuard Control Hub ({data[CONF_HOST]})", - "version": version - } - except Exception as err: - _LOGGER.exception("Unexpected exception: %s", err) + if not await api.test_connection(): + raise CannotConnect("Failed to connect to AdGuard Home") + + # Get additional server info if possible + try: + status = await api.get_status() + version = status.get("version", "unknown") + dns_port = status.get("dns_port", "N/A") + + return { + "title": f"AdGuard Control Hub ({host})", + "version": version, + "dns_port": dns_port, + "host": host, + } + except Exception as err: + _LOGGER.warning("Could not get server status, but connection works: %s", err) + return { + "title": f"AdGuard Control Hub ({host})", + "version": "unknown", + "dns_port": "N/A", + "host": host, + } + + except AdGuardAuthError as err: + _LOGGER.error("Authentication failed: %s", err) + raise InvalidAuth from err + except AdGuardConnectionError as err: + _LOGGER.error("Connection failed: %s", err) + if "timeout" in str(err).lower(): + raise Timeout from err raise CannotConnect from err + except asyncio.TimeoutError as err: + _LOGGER.error("Connection timeout: %s", err) + raise Timeout from err + except Exception as err: + _LOGGER.exception("Unexpected error during validation: %s", err) + raise CannotConnect from err + class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for AdGuard Control Hub.""" VERSION = 1 + CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL - async def async_step_user(self, user_input: dict[str, Any] | None = None): + async def async_step_user( + self, user_input: Optional[Dict[str, Any]] = None + ) -> FlowResult: """Handle the initial step.""" - errors: dict[str, str] = {} + errors: Dict[str, str] = {} if user_input is not None: try: info = await validate_input(self.hass, user_input) - except CannotConnect: - errors["base"] = "cannot_connect" - except Exception: - _LOGGER.exception("Unexpected exception") - errors["base"] = "unknown" - else: + # Create unique ID based on host and port - unique_id = f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}" + unique_id = f"{info['host']}:{user_input[CONF_PORT]}" await self.async_set_unique_id(unique_id) self._abort_if_unique_id_configured() @@ -76,11 +131,83 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): data=user_input, ) + except CannotConnect: + errors["base"] = "cannot_connect" + except InvalidAuth: + errors["base"] = "invalid_auth" + except InvalidHost: + errors[CONF_HOST] = "invalid_host" + except InvalidPort: + errors[CONF_PORT] = "invalid_port" + except Timeout: + errors["base"] = "timeout" + except Exception: + _LOGGER.exception("Unexpected exception during config flow") + errors["base"] = "unknown" + return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors, ) + async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult: + """Handle configuration import.""" + return await self.async_step_user(import_info) + + @staticmethod + def async_get_options_flow(config_entry): + """Get the options flow for this handler.""" + return OptionsFlowHandler(config_entry) + + +class OptionsFlowHandler(config_entries.OptionsFlow): + """Handle options flow for AdGuard Control Hub.""" + + def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + """Initialize options flow.""" + self.config_entry = config_entry + + async def async_step_init( + self, user_input: Optional[Dict[str, Any]] = None + ) -> FlowResult: + """Handle options flow.""" + if user_input is not None: + return self.async_create_entry(title="", data=user_input) + + options_schema = vol.Schema({ + vol.Optional( + "scan_interval", + default=self.config_entry.options.get("scan_interval", 30), + ): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)), + vol.Optional( + "timeout", + default=self.config_entry.options.get("timeout", 10), + ): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)), + }) + + return self.async_show_form( + step_id="init", + data_schema=options_schema, + ) + + +# Custom exceptions class CannotConnect(Exception): - """Error to indicate we cannot connect.""" \ No newline at end of file + """Error to indicate we cannot connect.""" + + +class InvalidAuth(Exception): + """Error to indicate there is invalid auth.""" + + +class InvalidHost(Exception): + """Error to indicate invalid host.""" + + +class InvalidPort(Exception): + """Error to indicate invalid port.""" + + +class Timeout(Exception): + """Error to indicate connection timeout.""" diff --git a/custom_components/adguard_hub/const.py b/custom_components/adguard_hub/const.py index 881843c..4b99d31 100644 --- a/custom_components/adguard_hub/const.py +++ b/custom_components/adguard_hub/const.py @@ -17,7 +17,7 @@ SCAN_INTERVAL: Final = 30 # Platforms PLATFORMS: Final = [ "switch", - "binary_sensor", + "binary_sensor", "sensor", ] @@ -26,7 +26,7 @@ API_ENDPOINTS: Final = { "status": "/control/status", "clients": "/control/clients", "clients_add": "/control/clients/add", - "clients_update": "/control/clients/update", + "clients_update": "/control/clients/update", "clients_delete": "/control/clients/delete", "blocked_services_all": "/control/blocked_services/all", "blocked_services_get": "/control/blocked_services/get", @@ -39,7 +39,7 @@ API_ENDPOINTS: Final = { BLOCKED_SERVICES: Final = { # Social Media "youtube": "YouTube", - "facebook": "Facebook", + "facebook": "Facebook", "instagram": "Instagram", "tiktok": "TikTok", "twitter": "Twitter/X", @@ -62,7 +62,7 @@ BLOCKED_SERVICES: Final = { "amazon": "Amazon", "ebay": "eBay", - # Communication + # Communication "whatsapp": "WhatsApp", "telegram": "Telegram", "discord": "Discord", @@ -89,4 +89,4 @@ ICON_CLIENT: Final = "mdi:devices" ICON_CLIENT_OFFLINE: Final = "mdi:devices-off" ICON_BLOCKED_SERVICE: Final = "mdi:block-helper" ICON_ALLOWED_SERVICE: Final = "mdi:check-circle" -ICON_STATISTICS: Final = "mdi:chart-line" \ No newline at end of file +ICON_STATISTICS: Final = "mdi:chart-line" diff --git a/custom_components/adguard_hub/manifest.json b/custom_components/adguard_hub/manifest.json index ed709bb..b462811 100644 --- a/custom_components/adguard_hub/manifest.json +++ b/custom_components/adguard_hub/manifest.json @@ -1,14 +1,14 @@ { - "domain": "adguard_hub", - "name": "AdGuard Control Hub", - "codeowners": ["@sq4ind"], - "config_flow": true, - "dependencies": [], - "documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub", - "integration_type": "hub", - "iot_class": "local_polling", - "requirements": [ - "aiohttp>=3.8.0" - ], - "version": "1.0.0" + "domain": "adguard_hub", + "name": "AdGuard Control Hub", + "codeowners": ["@sq4ind"], + "config_flow": true, + "dependencies": [], + "documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub", + "integration_type": "hub", + "iot_class": "local_polling", + "requirements": [ + "aiohttp>=3.8.0" + ], + "version": "1.0.0" } \ No newline at end of file diff --git a/custom_components/adguard_hub/sensor.py b/custom_components/adguard_hub/sensor.py new file mode 100644 index 0000000..3a9ab5d --- /dev/null +++ b/custom_components/adguard_hub/sensor.py @@ -0,0 +1,185 @@ +"""Sensor platform for AdGuard Control Hub integration.""" +import logging +from datetime import datetime, timezone +from typing import Any + +from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import PERCENTAGE +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import CoordinatorEntity + +from . import AdGuardControlHubCoordinator +from .api import AdGuardHomeAPI +from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up AdGuard Control Hub sensor platform.""" + coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"] + api = hass.data[DOMAIN][config_entry.entry_id]["api"] + + entities = [ + AdGuardQueriesCounterSensor(coordinator, api), + AdGuardBlockedCounterSensor(coordinator, api), + AdGuardBlockingPercentageSensor(coordinator, api), + AdGuardRuleCountSensor(coordinator, api), + AdGuardClientCountSensor(coordinator, api), + AdGuardUpstreamAverageTimeSensor(coordinator, api), + ] + + async_add_entities(entities) + + +class AdGuardBaseSensor(CoordinatorEntity, SensorEntity): + """Base class for AdGuard sensors.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator) + self.api = api + self._attr_device_info = { + "identifiers": {(DOMAIN, f"{api.host}:{api.port}")}, + "name": f"AdGuard Control Hub ({api.host})", + "manufacturer": MANUFACTURER, + "model": "AdGuard Home", + } + + +class AdGuardQueriesCounterSensor(AdGuardBaseSensor): + """Sensor to track DNS queries count.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_dns_queries" + self._attr_name = "AdGuard DNS Queries" + self._attr_icon = ICON_STATISTICS + self._attr_state_class = SensorStateClass.TOTAL_INCREASING + self._attr_native_unit_of_measurement = "queries" + + @property + def native_value(self) -> int | None: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + return stats.get("num_dns_queries", 0) + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + stats = self.coordinator.statistics + return { + "queries_today": stats.get("num_dns_queries_today", 0), + "queries_blocked_today": stats.get("num_blocked_filtering_today", 0), + "last_updated": datetime.now(timezone.utc).isoformat(), + } + + +class AdGuardBlockedCounterSensor(AdGuardBaseSensor): + """Sensor to track blocked queries count.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries" + self._attr_name = "AdGuard Blocked Queries" + self._attr_icon = "mdi:shield-check" + self._attr_state_class = SensorStateClass.TOTAL_INCREASING + self._attr_native_unit_of_measurement = "queries" + + @property + def native_value(self) -> int | None: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + return stats.get("num_blocked_filtering", 0) + + +class AdGuardBlockingPercentageSensor(AdGuardBaseSensor): + """Sensor to track blocking percentage.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage" + self._attr_name = "AdGuard Blocking Percentage" + self._attr_icon = "mdi:percent" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = PERCENTAGE + self._attr_device_class = SensorDeviceClass.POWER_FACTOR + + @property + def native_value(self) -> float | None: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + total_queries = stats.get("num_dns_queries", 0) + blocked_queries = stats.get("num_blocked_filtering", 0) + + if total_queries == 0: + return 0 + + percentage = (blocked_queries / total_queries) * 100 + return round(percentage, 2) + + +class AdGuardRuleCountSensor(AdGuardBaseSensor): + """Sensor to track filtering rules count.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_rules_count" + self._attr_name = "AdGuard Rules Count" + self._attr_icon = "mdi:format-list-numbered" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = "rules" + + @property + def native_value(self) -> int | None: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + return stats.get("filtering_rules_count", 0) + + +class AdGuardClientCountSensor(AdGuardBaseSensor): + """Sensor to track active clients count.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_clients_count" + self._attr_name = "AdGuard Clients Count" + self._attr_icon = "mdi:account-multiple" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = "clients" + + @property + def native_value(self) -> int | None: + """Return the state of the sensor.""" + return len(self.coordinator.clients) + + +class AdGuardUpstreamAverageTimeSensor(AdGuardBaseSensor): + """Sensor to track upstream servers average response time.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_upstream_response_time" + self._attr_name = "AdGuard Upstream Response Time" + self._attr_icon = "mdi:timer" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = "ms" + self._attr_device_class = SensorDeviceClass.DURATION + + @property + def native_value(self) -> float | None: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + return stats.get("avg_processing_time", 0) diff --git a/custom_components/adguard_hub/services.py b/custom_components/adguard_hub/services.py index 5b52d35..72925e6 100644 --- a/custom_components/adguard_hub/services.py +++ b/custom_components/adguard_hub/services.py @@ -1,38 +1,438 @@ -"""Services for AdGuard Control Hub integration.""" +"""Service implementations for AdGuard Control Hub integration.""" +import asyncio import logging -from homeassistant.core import HomeAssistant +from datetime import datetime, timedelta +from typing import Any, Dict, List + +import voluptuous as vol +from homeassistant.core import HomeAssistant, ServiceCall +from homeassistant.helpers import config_validation as cv + from .api import AdGuardHomeAPI +from .const import ( + DOMAIN, + BLOCKED_SERVICES, + ATTR_CLIENT_NAME, + ATTR_SERVICES, + ATTR_DURATION, + ATTR_CLIENTS, + ATTR_CLIENT_PATTERN, + ATTR_SETTINGS, +) _LOGGER = logging.getLogger(__name__) -async def async_register_services(hass: HomeAssistant, api: AdGuardHomeAPI) -> None: - """Register integration services.""" +# Service schemas +SCHEMA_BLOCK_SERVICES = vol.Schema({ + vol.Required(ATTR_CLIENT_NAME): cv.string, + vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), +}) - async def emergency_unblock_service(call): - """Emergency unblock service.""" - duration = call.data.get("duration", 300) - clients = call.data.get("clients", ["all"]) +SCHEMA_UNBLOCK_SERVICES = vol.Schema({ + vol.Required(ATTR_CLIENT_NAME): cv.string, + vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), +}) +SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({ + vol.Required(ATTR_DURATION): cv.positive_int, + vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]), +}) + +SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({ + vol.Required(ATTR_CLIENT_PATTERN): cv.string, + vol.Required(ATTR_SETTINGS): vol.Schema({ + vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), + vol.Optional("filtering_enabled"): cv.boolean, + vol.Optional("safebrowsing_enabled"): cv.boolean, + vol.Optional("parental_enabled"): cv.boolean, + }), +}) + +SCHEMA_ADD_CLIENT = vol.Schema({ + vol.Required("name"): cv.string, + vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]), + vol.Optional("mac"): cv.string, + vol.Optional("use_global_settings"): cv.boolean, + vol.Optional("filtering_enabled"): cv.boolean, + vol.Optional("parental_enabled"): cv.boolean, + vol.Optional("safebrowsing_enabled"): cv.boolean, + vol.Optional("safesearch_enabled"): cv.boolean, + vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), +}) + +SCHEMA_REMOVE_CLIENT = vol.Schema({ + vol.Required("name"): cv.string, +}) + +SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({ + vol.Required(ATTR_CLIENT_NAME): cv.string, + vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), + vol.Required("schedule"): vol.Schema({ + vol.Optional("time_zone", default="Local"): cv.string, + vol.Optional("sun"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("mon"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("tue"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("wed"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("thu"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("fri"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + vol.Optional("sat"): vol.Schema({ + vol.Optional("start"): cv.string, + vol.Optional("end"): cv.string, + }), + }), +}) + +# Service names +SERVICE_BLOCK_SERVICES = "block_services" +SERVICE_UNBLOCK_SERVICES = "unblock_services" +SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock" +SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients" +SERVICE_ADD_CLIENT = "add_client" +SERVICE_REMOVE_CLIENT = "remove_client" +SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block" + + +class AdGuardControlHubServices: + """Handle services for AdGuard Control Hub.""" + + def __init__(self, hass: HomeAssistant): + """Initialize the services.""" + self.hass = hass + self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {} + + def register_services(self) -> None: + """Register all services.""" + self.hass.services.register( + DOMAIN, + SERVICE_BLOCK_SERVICES, + self.block_services, + schema=SCHEMA_BLOCK_SERVICES, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_UNBLOCK_SERVICES, + self.unblock_services, + schema=SCHEMA_UNBLOCK_SERVICES, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_EMERGENCY_UNBLOCK, + self.emergency_unblock, + schema=SCHEMA_EMERGENCY_UNBLOCK, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_BULK_UPDATE_CLIENTS, + self.bulk_update_clients, + schema=SCHEMA_BULK_UPDATE_CLIENTS, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_ADD_CLIENT, + self.add_client, + schema=SCHEMA_ADD_CLIENT, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_REMOVE_CLIENT, + self.remove_client, + schema=SCHEMA_REMOVE_CLIENT, + ) + + self.hass.services.register( + DOMAIN, + SERVICE_SCHEDULE_SERVICE_BLOCK, + self.schedule_service_block, + schema=SCHEMA_SCHEDULE_SERVICE_BLOCK, + ) + + def unregister_services(self) -> None: + """Unregister all services.""" + services = [ + SERVICE_BLOCK_SERVICES, + SERVICE_UNBLOCK_SERVICES, + SERVICE_EMERGENCY_UNBLOCK, + SERVICE_BULK_UPDATE_CLIENTS, + SERVICE_ADD_CLIENT, + SERVICE_REMOVE_CLIENT, + SERVICE_SCHEDULE_SERVICE_BLOCK, + ] + + for service in services: + self.hass.services.remove(DOMAIN, service) + + def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI: + """Get API instance for a specific config entry.""" + return self.hass.data[DOMAIN][entry_id]["api"] + + async def block_services(self, call: ServiceCall) -> None: + """Block services for a specific client.""" + client_name = call.data[ATTR_CLIENT_NAME] + services = call.data[ATTR_SERVICES] + + _LOGGER.info("Blocking services %s for client %s", services, client_name) + + # Get all API instances (for multiple AdGuard instances) + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + # Get current client data + client = await api.get_client_by_name(client_name) + if not client: + _LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port) + continue + + # Get current blocked services + current_blocked = client.get("blocked_services", {}) + if isinstance(current_blocked, dict): + current_services = current_blocked.get("ids", []) + else: + current_services = current_blocked if current_blocked else [] + + # Add new services to block + updated_services = list(set(current_services + services)) + + # Update client + await api.update_client_blocked_services(client_name, updated_services) + _LOGGER.info("Successfully blocked services for %s", client_name) + + except Exception as err: + _LOGGER.error("Failed to block services for %s: %s", client_name, err) + + async def unblock_services(self, call: ServiceCall) -> None: + """Unblock services for a specific client.""" + client_name = call.data[ATTR_CLIENT_NAME] + services = call.data[ATTR_SERVICES] + + _LOGGER.info("Unblocking services %s for client %s", services, client_name) + + # Get all API instances + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + # Get current client data + client = await api.get_client_by_name(client_name) + if not client: + continue + + # Get current blocked services + current_blocked = client.get("blocked_services", {}) + if isinstance(current_blocked, dict): + current_services = current_blocked.get("ids", []) + else: + current_services = current_blocked if current_blocked else [] + + # Remove services to unblock + updated_services = [s for s in current_services if s not in services] + + # Update client + await api.update_client_blocked_services(client_name, updated_services) + _LOGGER.info("Successfully unblocked services for %s", client_name) + + except Exception as err: + _LOGGER.error("Failed to unblock services for %s: %s", client_name, err) + + async def emergency_unblock(self, call: ServiceCall) -> None: + """Emergency unblock - temporarily disable protection.""" + duration = call.data[ATTR_DURATION] # seconds + clients = call.data[ATTR_CLIENTS] + + _LOGGER.warning("Emergency unblock activated for %s seconds", duration) + + # Cancel any existing emergency unblock tasks + for task in self._emergency_unblock_tasks.values(): + task.cancel() + self._emergency_unblock_tasks.clear() + + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + if "all" in clients: + # Disable global protection + await api.set_protection(False) + + # Schedule re-enable + task = asyncio.create_task( + self._delayed_enable_protection(api, duration) + ) + self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task + else: + # Disable protection for specific clients + for client_name in clients: + client = await api.get_client_by_name(client_name) + if client: + # Store original blocked services + original_blocked = client.get("blocked_services", {}) + + # Clear blocked services temporarily + await api.update_client_blocked_services(client_name, []) + + # Schedule restore + task = asyncio.create_task( + self._delayed_restore_client(api, client_name, original_blocked, duration) + ) + self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task + + except Exception as err: + _LOGGER.error("Failed to execute emergency unblock: %s", err) + + async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None: + """Re-enable protection after delay.""" + await asyncio.sleep(delay) try: - if "all" in clients: - await api.set_protection(False) - _LOGGER.info("Emergency unblock activated globally for %d seconds", duration) - else: - _LOGGER.info("Emergency unblock activated for clients: %s", clients) + await api.set_protection(True) + _LOGGER.info("Emergency unblock expired - protection re-enabled") except Exception as err: - _LOGGER.error("Failed to execute emergency unblock: %s", err) - raise + _LOGGER.error("Failed to re-enable protection: %s", err) - # Register emergency unblock service - hass.services.async_register( - "adguard_hub", - "emergency_unblock", - emergency_unblock_service - ) + async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str, + original_blocked: Dict, delay: int) -> None: + """Restore client blocked services after delay.""" + await asyncio.sleep(delay) + try: + if isinstance(original_blocked, dict): + services = original_blocked.get("ids", []) + else: + services = original_blocked if original_blocked else [] - _LOGGER.info("AdGuard Control Hub services registered") + await api.update_client_blocked_services(client_name, services) + _LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name) + except Exception as err: + _LOGGER.error("Failed to restore client blocking: %s", err) -async def async_unregister_services(hass: HomeAssistant) -> None: - """Unregister integration services.""" - hass.services.async_remove("adguard_hub", "emergency_unblock") - _LOGGER.info("AdGuard Control Hub services unregistered") \ No newline at end of file + async def bulk_update_clients(self, call: ServiceCall) -> None: + """Update multiple clients matching a pattern.""" + import re + + pattern = call.data[ATTR_CLIENT_PATTERN] + settings = call.data[ATTR_SETTINGS] + + _LOGGER.info("Bulk updating clients matching pattern: %s", pattern) + + # Convert pattern to regex + regex_pattern = pattern.replace("*", ".*").replace("?", ".") + compiled_pattern = re.compile(regex_pattern, re.IGNORECASE) + + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + coordinator = entry_data["coordinator"] + + try: + # Get all clients + clients = coordinator.clients + + matching_clients = [] + for client_name in clients.keys(): + if compiled_pattern.match(client_name): + matching_clients.append(client_name) + + _LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients) + + # Update each matching client + for client_name in matching_clients: + client = clients[client_name] + + # Prepare update data + update_data = { + "name": client_name, + "data": {**client} # Start with current data + } + + # Apply settings + if "blocked_services" in settings: + blocked_services_data = { + "ids": settings["blocked_services"], + "schedule": {"time_zone": "Local"} + } + update_data["data"]["blocked_services"] = blocked_services_data + + if "filtering_enabled" in settings: + update_data["data"]["filtering_enabled"] = settings["filtering_enabled"] + + if "safebrowsing_enabled" in settings: + update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"] + + if "parental_enabled" in settings: + update_data["data"]["parental_enabled"] = settings["parental_enabled"] + + # Update the client + await api.update_client(update_data) + _LOGGER.info("Updated client: %s", client_name) + + except Exception as err: + _LOGGER.error("Failed to bulk update clients: %s", err) + + async def add_client(self, call: ServiceCall) -> None: + """Add a new client.""" + client_data = dict(call.data) + + # Convert blocked_services to proper format + if "blocked_services" in client_data and client_data["blocked_services"]: + blocked_services_data = { + "ids": client_data["blocked_services"], + "schedule": {"time_zone": "Local"} + } + client_data["blocked_services"] = blocked_services_data + + _LOGGER.info("Adding new client: %s", client_data["name"]) + + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + await api.add_client(client_data) + _LOGGER.info("Successfully added client: %s", client_data["name"]) + except Exception as err: + _LOGGER.error("Failed to add client %s: %s", client_data["name"], err) + + async def remove_client(self, call: ServiceCall) -> None: + """Remove a client.""" + client_name = call.data["name"] + + _LOGGER.info("Removing client: %s", client_name) + + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + await api.delete_client(client_name) + _LOGGER.info("Successfully removed client: %s", client_name) + except Exception as err: + _LOGGER.error("Failed to remove client %s: %s", client_name, err) + + async def schedule_service_block(self, call: ServiceCall) -> None: + """Schedule service blocking with time-based rules.""" + client_name = call.data[ATTR_CLIENT_NAME] + services = call.data[ATTR_SERVICES] + schedule = call.data["schedule"] + + _LOGGER.info("Scheduling service blocking for client %s", client_name) + + for entry_data in self.hass.data[DOMAIN].values(): + api: AdGuardHomeAPI = entry_data["api"] + try: + await api.update_client_blocked_services(client_name, services, schedule) + _LOGGER.info("Successfully scheduled service blocking for %s", client_name) + except Exception as err: + _LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err) diff --git a/custom_components/adguard_hub/strings.json b/custom_components/adguard_hub/strings.json index 0af874b..e9064e6 100644 --- a/custom_components/adguard_hub/strings.json +++ b/custom_components/adguard_hub/strings.json @@ -1,27 +1,151 @@ { - "config": { - "step": { - "user": { - "title": "AdGuard Control Hub", - "description": "Connect to your AdGuard Home instance for complete network control", - "data": { - "host": "AdGuard Home IP Address", - "port": "Port (usually 3000)", - "username": "Admin Username", - "password": "Admin Password", - "ssl": "Use HTTPS connection", - "verify_ssl": "Verify SSL certificate" + "config": { + "step": { + "user": { + "title": "AdGuard Control Hub", + "description": "Configure your AdGuard Home connection", + "data": { + "host": "Host", + "port": "Port", + "username": "Username", + "password": "Password", + "ssl": "Use SSL", + "verify_ssl": "Verify SSL Certificate" + } + } + }, + "error": { + "cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.", + "invalid_auth": "Invalid username or password", + "timeout": "Connection timeout. Please check your network connection.", + "unknown": "An unexpected error occurred" + }, + "abort": { + "already_configured": "AdGuard Control Hub is already configured for this host and port" } - } }, - "error": { - "cannot_connect": "Failed to connect to AdGuard Home. Check IP address, port, and credentials.", - "invalid_auth": "Invalid username or password. Please check your admin credentials.", - "unknown": "Unexpected error occurred. Please check logs for details." + "options": { + "step": { + "init": { + "title": "AdGuard Control Hub Options", + "description": "Configure advanced options", + "data": { + "scan_interval": "Update interval (seconds)", + "timeout": "Connection timeout (seconds)" + } + } + } }, - "abort": { - "already_configured": "This AdGuard Home instance is already configured", - "cannot_connect": "Cannot connect to AdGuard Home" + "services": { + "block_services": { + "name": "Block Services", + "description": "Block specific services for a client", + "fields": { + "client_name": { + "name": "Client Name", + "description": "Name of the client to block services for" + }, + "services": { + "name": "Services", + "description": "List of services to block" + } + } + }, + "unblock_services": { + "name": "Unblock Services", + "description": "Unblock specific services for a client", + "fields": { + "client_name": { + "name": "Client Name", + "description": "Name of the client to unblock services for" + }, + "services": { + "name": "Services", + "description": "List of services to unblock" + } + } + }, + "emergency_unblock": { + "name": "Emergency Unblock", + "description": "Temporarily disable blocking for emergency access", + "fields": { + "duration": { + "name": "Duration", + "description": "Duration in seconds to keep unblocked" + }, + "clients": { + "name": "Clients", + "description": "List of client names (use 'all' for global)" + } + } + }, + "bulk_update_clients": { + "name": "Bulk Update Clients", + "description": "Update multiple clients matching a pattern", + "fields": { + "client_pattern": { + "name": "Client Pattern", + "description": "Pattern to match client names (supports wildcards)" + }, + "settings": { + "name": "Settings", + "description": "Settings to apply to matching clients" + } + } + }, + "add_client": { + "name": "Add Client", + "description": "Add a new client configuration", + "fields": { + "name": { + "name": "Name", + "description": "Client name" + }, + "ids": { + "name": "IDs", + "description": "List of IP addresses or CIDR ranges" + }, + "mac": { + "name": "MAC Address", + "description": "MAC address (optional)" + }, + "filtering_enabled": { + "name": "Filtering Enabled", + "description": "Enable DNS filtering for this client" + }, + "blocked_services": { + "name": "Blocked Services", + "description": "List of services to block" + } + } + }, + "remove_client": { + "name": "Remove Client", + "description": "Remove a client configuration", + "fields": { + "name": { + "name": "Name", + "description": "Name of the client to remove" + } + } + }, + "schedule_service_block": { + "name": "Schedule Service Block", + "description": "Schedule time-based service blocking", + "fields": { + "client_name": { + "name": "Client Name", + "description": "Name of the client" + }, + "services": { + "name": "Services", + "description": "List of services to block" + }, + "schedule": { + "name": "Schedule", + "description": "Time-based schedule configuration" + } + } + } } - } } \ No newline at end of file diff --git a/custom_components/adguard_hub/switch.py b/custom_components/adguard_hub/switch.py index 9c87400..67ae347 100644 --- a/custom_components/adguard_hub/switch.py +++ b/custom_components/adguard_hub/switch.py @@ -1,10 +1,13 @@ """Switch platform for AdGuard Control Hub integration.""" import logging +from typing import Any + from homeassistant.components.switch import SwitchEntity from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity + from . import AdGuardControlHubCoordinator from .api import AdGuardHomeAPI from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER @@ -12,7 +15,11 @@ from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MA _LOGGER = logging.getLogger(__name__) -async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback): +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: """Set up AdGuard Control Hub switch platform.""" coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"] api = hass.data[DOMAIN][config_entry.entry_id]["api"] @@ -32,6 +39,7 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity): """Base class for AdGuard switches.""" def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the switch.""" super().__init__(coordinator) self.api = api self._attr_device_info = { @@ -46,31 +54,64 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch): """Switch to control global AdGuard protection.""" def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + """Initialize the switch.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_protection" self._attr_name = "AdGuard Protection" @property - def is_on(self) -> bool: + def is_on(self) -> bool | None: + """Return true if protection is enabled.""" return self.coordinator.protection_status.get("protection_enabled", False) @property def icon(self) -> str: + """Return the icon for the switch.""" return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF - async def async_turn_on(self, **kwargs): - await self.api.set_protection(True) - await self.coordinator.async_request_refresh() + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + status = self.coordinator.protection_status + stats = self.coordinator.statistics + return { + "dns_port": status.get("dns_port", "N/A"), + "queries_today": stats.get("num_dns_queries_today", 0), + "blocked_today": stats.get("num_blocked_filtering_today", 0), + "version": status.get("version", "N/A"), + } - async def async_turn_off(self, **kwargs): - await self.api.set_protection(False) - await self.coordinator.async_request_refresh() + async def async_turn_on(self, **kwargs: Any) -> None: + """Turn on AdGuard protection.""" + try: + await self.api.set_protection(True) + await self.coordinator.async_request_refresh() + _LOGGER.info("AdGuard protection enabled") + except Exception as err: + _LOGGER.error("Failed to enable AdGuard protection: %s", err) + raise + + async def async_turn_off(self, **kwargs: Any) -> None: + """Turn off AdGuard protection.""" + try: + await self.api.set_protection(False) + await self.coordinator.async_request_refresh() + _LOGGER.info("AdGuard protection disabled") + except Exception as err: + _LOGGER.error("Failed to disable AdGuard protection: %s", err) + raise class AdGuardClientSwitch(AdGuardBaseSwitch): """Switch to control client-specific protection.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI, client_name: str): + def __init__( + self, + coordinator: AdGuardControlHubCoordinator, + api: AdGuardHomeAPI, + client_name: str, + ): + """Initialize the switch.""" super().__init__(coordinator, api) self.client_name = client_name self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}" @@ -78,16 +119,81 @@ class AdGuardClientSwitch(AdGuardBaseSwitch): self._attr_icon = ICON_CLIENT @property - def is_on(self) -> bool: + def is_on(self) -> bool | None: + """Return true if client protection is enabled.""" client = self.coordinator.clients.get(self.client_name, {}) return client.get("filtering_enabled", True) - async def async_turn_on(self, **kwargs): - # This would update client settings - simplified for basic functionality - _LOGGER.info("Would enable protection for %s", self.client_name) - await self.coordinator.async_request_refresh() + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + client = self.coordinator.clients.get(self.client_name, {}) + blocked_services = client.get("blocked_services", {}) - async def async_turn_off(self, **kwargs): - # This would update client settings - simplified for basic functionality - _LOGGER.info("Would disable protection for %s", self.client_name) - await self.coordinator.async_request_refresh() + if isinstance(blocked_services, dict): + service_ids = blocked_services.get("ids", []) + else: + service_ids = blocked_services if blocked_services else [] + + return { + "client_ids": client.get("ids", []), + "mac": client.get("mac", ""), + "use_global_settings": client.get("use_global_settings", True), + "safebrowsing_enabled": client.get("safebrowsing_enabled", False), + "parental_enabled": client.get("parental_enabled", False), + "safesearch_enabled": client.get("safesearch_enabled", False), + "blocked_services": service_ids, + "blocked_services_count": len(service_ids), + } + + async def async_turn_on(self, **kwargs: Any) -> None: + """Enable protection for this client.""" + try: + # Get current client data + client = await self.api.get_client_by_name(self.client_name) + if not client: + _LOGGER.error("Client %s not found", self.client_name) + return + + # Update client with filtering enabled + update_data = { + "name": self.client_name, + "data": { + **client, + "filtering_enabled": True, + } + } + + await self.api.update_client(update_data) + await self.coordinator.async_request_refresh() + _LOGGER.info("Enabled protection for client %s", self.client_name) + + except Exception as err: + _LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err) + raise + + async def async_turn_off(self, **kwargs: Any) -> None: + """Disable protection for this client.""" + try: + # Get current client data + client = await self.api.get_client_by_name(self.client_name) + if not client: + _LOGGER.error("Client %s not found", self.client_name) + return + + # Update client with filtering disabled + update_data = { + "name": self.client_name, + "data": { + **client, + "filtering_enabled": False, + } + } + + await self.api.update_client(update_data) + await self.coordinator.async_request_refresh() + _LOGGER.info("Disabled protection for client %s", self.client_name) + + except Exception as err: + _LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err) + raise diff --git a/tests/test_api.py b/tests/test_api.py index e52a55d..733eec0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -3,6 +3,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock from custom_components.adguard_hub.api import AdGuardHomeAPI + @pytest.fixture def mock_session(): """Mock aiohttp session.""" @@ -15,6 +16,7 @@ def mock_session(): session.request = AsyncMock(return_value=response) return session + async def test_api_connection(mock_session): """Test API connection.""" api = AdGuardHomeAPI( @@ -28,13 +30,14 @@ async def test_api_connection(mock_session): result = await api.test_connection() assert result is True + async def test_api_get_status(mock_session): """Test getting status.""" api = AdGuardHomeAPI( - host="test-host", + host="test-host", port=3000, session=mock_session ) status = await api.get_status() - assert status == {"status": "ok"} \ No newline at end of file + assert status == {"status": "ok"} diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..0917994 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,223 @@ +"""Test the complete AdGuard Control Hub integration.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from homeassistant.core import HomeAssistant +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME + +from custom_components.adguard_hub import async_setup_entry, async_unload_entry +from custom_components.adguard_hub.api import AdGuardHomeAPI +from custom_components.adguard_hub.const import DOMAIN + + +@pytest.fixture +def mock_config_entry(): + """Mock config entry.""" + return ConfigEntry( + version=1, + domain=DOMAIN, + title="Test AdGuard", + data={ + CONF_HOST: "192.168.1.100", + CONF_PORT: 3000, + CONF_USERNAME: "admin", + CONF_PASSWORD: "password", + }, + source="user", + entry_id="test_entry_id", + ) + + +@pytest.fixture +def mock_api(): + """Mock API instance.""" + api = MagicMock(spec=AdGuardHomeAPI) + api.host = "192.168.1.100" + api.port = 3000 + api.test_connection = AsyncMock(return_value=True) + api.get_status = AsyncMock(return_value={ + "protection_enabled": True, + "version": "v0.107.0", + "dns_port": 53, + "running": True, + }) + api.get_clients = AsyncMock(return_value={ + "clients": [ + { + "name": "test_client", + "ids": ["192.168.1.50"], + "filtering_enabled": True, + "blocked_services": {"ids": ["youtube"]}, + } + ] + }) + api.get_statistics = AsyncMock(return_value={ + "num_dns_queries": 1000, + "num_blocked_filtering": 100, + "num_dns_queries_today": 500, + "num_blocked_filtering_today": 50, + "filtering_rules_count": 50000, + "avg_processing_time": 2.5, + }) + return api + + +@pytest.mark.asyncio +async def test_setup_entry_success(hass: HomeAssistant, mock_config_entry, mock_api): + """Test successful setup of config entry.""" + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \ + patch("custom_components.adguard_hub.async_get_clientsession"), \ + patch.object(hass.config_entries, "async_forward_entry_setups", return_value=True): + + result = await async_setup_entry(hass, mock_config_entry) + + assert result is True + assert DOMAIN in hass.data + assert mock_config_entry.entry_id in hass.data[DOMAIN] + assert "coordinator" in hass.data[DOMAIN][mock_config_entry.entry_id] + assert "api" in hass.data[DOMAIN][mock_config_entry.entry_id] + + +@pytest.mark.asyncio +async def test_setup_entry_connection_failure(hass: HomeAssistant, mock_config_entry): + """Test setup failure due to connection error.""" + mock_api = MagicMock(spec=AdGuardHomeAPI) + mock_api.test_connection = AsyncMock(return_value=False) + + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \ + patch("custom_components.adguard_hub.async_get_clientsession"), \ + pytest.raises(Exception): # Should raise ConfigEntryNotReady + + await async_setup_entry(hass, mock_config_entry) + + +@pytest.mark.asyncio +async def test_unload_entry(hass: HomeAssistant, mock_config_entry): + """Test unloading of config entry.""" + # Set up initial data + hass.data[DOMAIN] = { + mock_config_entry.entry_id: { + "coordinator": MagicMock(), + "api": MagicMock(), + } + } + + with patch.object(hass.config_entries, "async_unload_platforms", return_value=True): + result = await async_unload_entry(hass, mock_config_entry) + + assert result is True + assert mock_config_entry.entry_id not in hass.data[DOMAIN] + + +@pytest.mark.asyncio +async def test_coordinator_data_update(hass: HomeAssistant, mock_api): + """Test coordinator data update functionality.""" + from custom_components.adguard_hub import AdGuardControlHubCoordinator + + coordinator = AdGuardControlHubCoordinator(hass, mock_api) + + # Test successful data update + data = await coordinator._async_update_data() + + assert "clients" in data + assert "statistics" in data + assert "status" in data + assert "test_client" in data["clients"] + assert data["statistics"]["num_dns_queries"] == 1000 + assert data["status"]["protection_enabled"] is True + + +@pytest.mark.asyncio +async def test_api_error_handling(mock_api): + """Test API error handling.""" + from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError + + # Test connection error + mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed")) + + with pytest.raises(AdGuardConnectionError): + await mock_api.get_status() + + # Test auth error + mock_api.get_clients = AsyncMock(side_effect=AdGuardAuthError("Auth failed")) + + with pytest.raises(AdGuardAuthError): + await mock_api.get_clients() + + +@pytest.mark.asyncio +async def test_services_registration(hass: HomeAssistant): + """Test that services are properly registered.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + services = AdGuardControlHubServices(hass) + services.register_services() + + # Check that services are registered + assert hass.services.has_service(DOMAIN, "block_services") + assert hass.services.has_service(DOMAIN, "unblock_services") + assert hass.services.has_service(DOMAIN, "emergency_unblock") + assert hass.services.has_service(DOMAIN, "bulk_update_clients") + assert hass.services.has_service(DOMAIN, "add_client") + assert hass.services.has_service(DOMAIN, "remove_client") + + # Clean up + services.unregister_services() + + +def test_blocked_services_constants(): + """Test that blocked services are properly defined.""" + from custom_components.adguard_hub.const import BLOCKED_SERVICES + + assert "youtube" in BLOCKED_SERVICES + assert "netflix" in BLOCKED_SERVICES + assert "gaming" in BLOCKED_SERVICES + assert "facebook" in BLOCKED_SERVICES + + # Check friendly names are defined + assert BLOCKED_SERVICES["youtube"] == "YouTube" + assert BLOCKED_SERVICES["netflix"] == "Netflix" + + +def test_api_endpoints(): + """Test that API endpoints are properly defined.""" + from custom_components.adguard_hub.const import API_ENDPOINTS + + required_endpoints = [ + "status", "clients", "stats", "protection", + "clients_add", "clients_update", "clients_delete" + ] + + for endpoint in required_endpoints: + assert endpoint in API_ENDPOINTS + assert API_ENDPOINTS[endpoint].startswith("/") + + +@pytest.mark.asyncio +async def test_client_operations(mock_api): + """Test client add/update/delete operations.""" + # Test add client + client_data = { + "name": "new_client", + "ids": ["192.168.1.200"], + "filtering_enabled": True, + } + + mock_api.add_client = AsyncMock(return_value={"success": True}) + result = await mock_api.add_client(client_data) + assert result["success"] is True + + # Test update client + update_data = { + "name": "new_client", + "data": {"filtering_enabled": False} + } + + mock_api.update_client = AsyncMock(return_value={"success": True}) + result = await mock_api.update_client(update_data) + assert result["success"] is True + + # Test delete client + mock_api.delete_client = AsyncMock(return_value={"success": True}) + result = await mock_api.delete_client("new_client") + assert result["success"] is True