"""Service implementations for AdGuard Control Hub integration."""
import asyncio
import logging
import re
from datetime import datetime, timedelta
from typing import Any, Dict, Iterator, List

import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv

from .api import AdGuardHomeAPI
from .const import (
    DOMAIN,
    BLOCKED_SERVICES,
    ATTR_CLIENT_NAME,
    ATTR_SERVICES,
    ATTR_DURATION,
    ATTR_CLIENTS,
    ATTR_CLIENT_PATTERN,
    ATTR_SETTINGS,
)

_LOGGER = logging.getLogger(__name__)

# Service schemas
SCHEMA_BLOCK_SERVICES = vol.Schema({
    vol.Required(ATTR_CLIENT_NAME): cv.string,
    vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})

SCHEMA_UNBLOCK_SERVICES = vol.Schema({
    vol.Required(ATTR_CLIENT_NAME): cv.string,
    vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})

SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
    vol.Required(ATTR_DURATION): cv.positive_int,
    vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
})

SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
    vol.Required(ATTR_CLIENT_PATTERN): cv.string,
    vol.Required(ATTR_SETTINGS): vol.Schema({
        vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
        vol.Optional("filtering_enabled"): cv.boolean,
        vol.Optional("safebrowsing_enabled"): cv.boolean,
        vol.Optional("parental_enabled"): cv.boolean,
    }),
})

SCHEMA_ADD_CLIENT = vol.Schema({
    vol.Required("name"): cv.string,
    vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
    vol.Optional("mac"): cv.string,
    vol.Optional("use_global_settings"): cv.boolean,
    vol.Optional("filtering_enabled"): cv.boolean,
    vol.Optional("parental_enabled"): cv.boolean,
    vol.Optional("safebrowsing_enabled"): cv.boolean,
    vol.Optional("safesearch_enabled"): cv.boolean,
    vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})

SCHEMA_REMOVE_CLIENT = vol.Schema({
    vol.Required("name"): cv.string,
})

# Optional start/end strings for a single weekday in a blocking schedule.
_DAY_WINDOW_SCHEMA = vol.Schema({
    vol.Optional("start"): cv.string,
    vol.Optional("end"): cv.string,
})

_SCHEDULE_DAYS = ("sun", "mon", "tue", "wed", "thu", "fri", "sat")

SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
    vol.Required(ATTR_CLIENT_NAME): cv.string,
    vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
    vol.Required("schedule"): vol.Schema({
        vol.Optional("time_zone", default="Local"): cv.string,
        **{vol.Optional(day): _DAY_WINDOW_SCHEMA for day in _SCHEDULE_DAYS},
    }),
})

# Service names
SERVICE_BLOCK_SERVICES = "block_services"
SERVICE_UNBLOCK_SERVICES = "unblock_services"
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
SERVICE_ADD_CLIENT = "add_client"
SERVICE_REMOVE_CLIENT = "remove_client"
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"

# Client settings that bulk_update_clients copies verbatim onto each matching client.
_BULK_BOOLEAN_SETTINGS = ("filtering_enabled", "safebrowsing_enabled", "parental_enabled")


def _blocked_service_ids(blocked: Any) -> List[str]:
    """Return the blocked service IDs held in a client's ``blocked_services`` value.

    The value is handled either as a dict carrying an ``ids`` list or as a plain
    list; a missing/falsy value (or ``ids: None``) yields an empty list.
    """
    if isinstance(blocked, dict):
        return list(blocked.get("ids") or [])
    return list(blocked) if blocked else []


def _compile_client_pattern(pattern: str) -> "re.Pattern[str]":
    """Compile a shell-style wildcard pattern into a case-insensitive regex.

    ``*`` matches any run of characters and ``?`` matches exactly one; every other
    character is matched literally. Use ``fullmatch`` so the entire client name
    has to match, as with a shell glob.
    """
    regex = re.escape(pattern).replace(r"\*", ".*").replace(r"\?", ".")
    return re.compile(regex, re.IGNORECASE)


def _build_client_update(client: Dict[str, Any], settings: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of ``client`` with the bulk-update ``settings`` applied."""
    data = {**client}
    if "blocked_services" in settings:
        # NOTE(review): this replaces any existing blocking schedule with a plain
        # "Local" one, matching the previous behaviour — confirm that is intended.
        data["blocked_services"] = {
            "ids": settings["blocked_services"],
            "schedule": {"time_zone": "Local"},
        }
    for key in _BULK_BOOLEAN_SETTINGS:
        if key in settings:
            data[key] = settings[key]
    return data


class AdGuardControlHubServices:
    """Handle services for AdGuard Control Hub."""

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the services."""
        self.hass = hass
        # Pending emergency-unblock timers, keyed by "host:port" (global protection)
        # or "host:port_client" (per-client blocked services).
        self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}

    def _service_table(self):
        """Return (service name, handler, schema) for every service this class provides."""
        return (
            (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
            (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
            (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
            (SERVICE_BULK_UPDATE_CLIENTS, self.bulk_update_clients, SCHEMA_BULK_UPDATE_CLIENTS),
            (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
            (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
            (SERVICE_SCHEDULE_SERVICE_BLOCK, self.schedule_service_block, SCHEMA_SCHEDULE_SERVICE_BLOCK),
        )

    def register_services(self) -> None:
        """Register all services under the integration domain."""
        for name, handler, schema in self._service_table():
            self.hass.services.register(DOMAIN, name, handler, schema=schema)

    def unregister_services(self) -> None:
        """Unregister all services registered by ``register_services``."""
        for name, _handler, _schema in self._service_table():
            self.hass.services.remove(DOMAIN, name)

    def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
        """Get API instance for a specific config entry."""
        return self.hass.data[DOMAIN][entry_id]["api"]

    def _iter_entries(self) -> Iterator[Dict[str, Any]]:
        """Yield each per-config-entry data dict stored under the domain.

        Values that are not dicts with an ``api`` key are skipped, so one unexpected
        entry cannot abort a service call before its per-instance error handling.
        """
        for entry_data in self.hass.data.get(DOMAIN, {}).values():
            if isinstance(entry_data, dict) and "api" in entry_data:
                yield entry_data

    async def block_services(self, call: ServiceCall) -> None:
        """Add services to a client's blocked list on every AdGuard instance.

        Services already blocked are kept; duplicates are dropped while preserving
        the existing order. Failures are logged per instance.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Blocking services %s for client %s", services, client_name)

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
                    continue

                current_services = _blocked_service_ids(client.get("blocked_services"))
                # dict.fromkeys de-duplicates with a stable order (a set would not).
                updated_services = list(dict.fromkeys(current_services + list(services)))

                # NOTE(review): only IDs are sent, so an existing blocking schedule
                # may be dropped depending on the API helper — confirm.
                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully blocked services for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to block services for %s: %s", client_name, err)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Remove services from a client's blocked list on every AdGuard instance.

        Instances where the client does not exist are skipped. Failures are logged
        per instance.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Unblocking services %s for client %s", services, client_name)

        to_remove = set(services)
        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    continue

                current_services = _blocked_service_ids(client.get("blocked_services"))
                updated_services = [s for s in current_services if s not in to_remove]

                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully unblocked services for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to unblock services for %s: %s", client_name, err)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Emergency unblock - temporarily disable protection.

        With ``clients`` containing "all", global protection is turned off on every
        instance; otherwise each named client's blocked services are cleared. Either
        way a timer restores the previous state after ``duration`` seconds.

        A new call supersedes any window still running: the earlier window is closed
        (its state restored) before this one starts, so a client's real original
        blocked services are captured rather than the temporarily-cleared list.
        """
        duration = call.data[ATTR_DURATION]  # seconds
        clients = call.data[ATTR_CLIENTS]

        _LOGGER.warning("Emergency unblock activated for %s seconds", duration)

        await self._cancel_emergency_tasks()

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                if "all" in clients:
                    await api.set_protection(False)
                    self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = asyncio.create_task(
                        self._delayed_enable_protection(api, duration)
                    )
                    continue

                for client_name in clients:
                    client = await api.get_client_by_name(client_name)
                    if not client:
                        continue
                    original_blocked = client.get("blocked_services", {})
                    await api.update_client_blocked_services(client_name, [])
                    self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = asyncio.create_task(
                        self._delayed_restore_client(api, client_name, original_blocked, duration)
                    )
            except Exception as err:
                _LOGGER.error("Failed to execute emergency unblock: %s", err)

    async def _cancel_emergency_tasks(self) -> None:
        """Cancel every pending emergency-unblock timer and wait for it to finish.

        The timer coroutines restore state in ``finally``, so awaiting them here
        guarantees the previous window's restore has run before a new window reads
        client state or toggles protection again.
        """
        pending = list(self._emergency_unblock_tasks.values())
        self._emergency_unblock_tasks.clear()
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
        """Re-enable global protection after ``delay`` seconds.

        Runs the re-enable step even when the timer is cancelled (e.g. superseded by
        a newer emergency unblock) so protection is never left off indefinitely.
        """
        try:
            await asyncio.sleep(delay)
        finally:
            try:
                await api.set_protection(True)
                _LOGGER.info("Emergency unblock expired - protection re-enabled")
            except Exception as err:
                _LOGGER.error("Failed to re-enable protection: %s", err)

    async def _delayed_restore_client(
        self, api: AdGuardHomeAPI, client_name: str, original_blocked: Any, delay: int
    ) -> None:
        """Restore a client's blocked services after ``delay`` seconds.

        ``original_blocked`` is the client's ``blocked_services`` value captured before
        it was cleared. The restore also runs if the timer is cancelled, so the
        client's blocking is never lost when a window is superseded.
        """
        try:
            await asyncio.sleep(delay)
        finally:
            try:
                services = _blocked_service_ids(original_blocked)
                await api.update_client_blocked_services(client_name, services)
                _LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to restore client blocking: %s", err)

    async def bulk_update_clients(self, call: ServiceCall) -> None:
        """Apply settings to every client whose name matches a wildcard pattern.

        The pattern supports ``*`` and ``?`` and must match the whole client name
        (case-insensitive). Clients are read from each entry's coordinator; a failure
        on one client is logged and the remaining clients are still updated.
        """
        pattern = call.data[ATTR_CLIENT_PATTERN]
        settings = call.data[ATTR_SETTINGS]

        _LOGGER.info("Bulk updating clients matching pattern: %s", pattern)

        compiled_pattern = _compile_client_pattern(pattern)

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                # assumes coordinator.clients maps client name -> client dict
                clients = entry_data["coordinator"].clients
                matching_clients = [name for name in clients if compiled_pattern.fullmatch(name)]
            except Exception as err:
                _LOGGER.error("Failed to bulk update clients: %s", err)
                continue

            _LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)

            for client_name in matching_clients:
                try:
                    update_data = {
                        "name": client_name,
                        "data": _build_client_update(clients[client_name], settings),
                    }
                    await api.update_client(update_data)
                    _LOGGER.info("Updated client: %s", client_name)
                except Exception as err:
                    _LOGGER.error("Failed to update client %s: %s", client_name, err)

    async def add_client(self, call: ServiceCall) -> None:
        """Add a new client on every AdGuard instance.

        A non-empty ``blocked_services`` list is wrapped as ``{"ids": ..., "schedule":
        {"time_zone": "Local"}}`` before it is sent to the API.
        """
        client_data = dict(call.data)

        if client_data.get("blocked_services"):
            client_data["blocked_services"] = {
                "ids": client_data["blocked_services"],
                "schedule": {"time_zone": "Local"},
            }

        _LOGGER.info("Adding new client: %s", client_data["name"])

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.add_client(client_data)
                _LOGGER.info("Successfully added client: %s", client_data["name"])
            except Exception as err:
                _LOGGER.error("Failed to add client %s: %s", client_data["name"], err)

    async def remove_client(self, call: ServiceCall) -> None:
        """Remove a client by name from every AdGuard instance."""
        client_name = call.data["name"]

        _LOGGER.info("Removing client: %s", client_name)

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.delete_client(client_name)
                _LOGGER.info("Successfully removed client: %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to remove client %s: %s", client_name, err)

    async def schedule_service_block(self, call: ServiceCall) -> None:
        """Set a client's blocked services together with a weekly schedule.

        The validated ``schedule`` dict (time zone plus optional per-day start/end
        strings) is passed straight through to the API helper.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]
        schedule = call.data["schedule"]

        _LOGGER.info("Scheduling service blocking for client %s", client_name)

        for entry_data in self._iter_entries():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.update_client_blocked_services(client_name, services, schedule)
                _LOGGER.info("Successfully scheduled service blocking for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)