277 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			277 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Service implementations for AdGuard Control Hub integration."""
 | |
| import asyncio
 | |
| import logging
 | |
| from typing import Any, Dict
 | |
| 
 | |
| import voluptuous as vol
 | |
| from homeassistant.core import HomeAssistant, ServiceCall
 | |
| from homeassistant.helpers import config_validation as cv
 | |
| 
 | |
| from .api import AdGuardHomeAPI, AdGuardHomeError
 | |
| from .const import (
 | |
|     DOMAIN,
 | |
|     BLOCKED_SERVICES,
 | |
|     ATTR_CLIENT_NAME,
 | |
|     ATTR_SERVICES,
 | |
|     ATTR_DURATION,
 | |
|     ATTR_CLIENTS,
 | |
|     ATTR_ENABLED,
 | |
|     SERVICE_BLOCK_SERVICES,
 | |
|     SERVICE_UNBLOCK_SERVICES,
 | |
|     SERVICE_EMERGENCY_UNBLOCK,
 | |
|     SERVICE_ADD_CLIENT,
 | |
|     SERVICE_REMOVE_CLIENT,
 | |
|     SERVICE_REFRESH_DATA,
 | |
| )
 | |
| 
 | |
_LOGGER = logging.getLogger(__name__)

# Service schemas
#
# NOTE: list defaults are given as factories (``list`` / ``lambda``) rather than
# literals. Voluptuous returns a non-callable default object as-is on every
# validation, so a literal ``[]`` would be one list shared by all service calls.

# block_services and unblock_services accept the same payload: a client name
# plus one or more service IDs that must be keys of BLOCKED_SERVICES.
_CLIENT_SERVICES_SCHEMA = vol.Schema({
    vol.Required(ATTR_CLIENT_NAME): cv.string,
    vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})

SCHEMA_BLOCK_SERVICES = _CLIENT_SERVICES_SCHEMA

SCHEMA_UNBLOCK_SERVICES = _CLIENT_SERVICES_SCHEMA

SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
    # Duration in seconds (it is passed to asyncio.sleep by the handler).
    vol.Required(ATTR_DURATION): cv.positive_int,
    # "all" targets global protection instead of individual clients.
    vol.Optional(ATTR_CLIENTS, default=lambda: ["all"]): vol.All(cv.ensure_list, [cv.string]),
})

SCHEMA_ADD_CLIENT = vol.Schema({
    vol.Required("name"): cv.string,
    vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
    vol.Optional("filtering_enabled", default=True): cv.boolean,
    vol.Optional("safebrowsing_enabled", default=False): cv.boolean,
    vol.Optional("parental_enabled", default=False): cv.boolean,
    vol.Optional("safesearch_enabled", default=False): cv.boolean,
    vol.Optional("use_global_blocked_services", default=True): cv.boolean,
    vol.Optional("blocked_services", default=list): vol.All(cv.ensure_list, [cv.string]),
})

SCHEMA_REMOVE_CLIENT = vol.Schema({
    vol.Required("name"): cv.string,
})

SCHEMA_REFRESH_DATA = vol.Schema({})
 | |
| 
 | |
| 
 | |
class AdGuardControlHubServices:
    """Register and handle the AdGuard Control Hub service calls.

    Each handler applies the requested action to every AdGuard Home API
    instance stored under ``hass.data[DOMAIN]``. That means every entry value
    that is a dict containing an ``"api"`` key. A failure on one instance is
    logged and does not stop the remaining instances from being processed.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the service handler.

        Args:
            hass: Home Assistant instance the services are registered on.
        """
        self.hass = hass
        # Strong references to delayed "re-enable" tasks started by
        # emergency_unblock. The event loop keeps only weak references to
        # tasks, so an unreferenced task may be garbage-collected before it
        # finishes, which would leave protection disabled indefinitely.
        self._background_tasks: set[asyncio.Task] = set()

    def _service_definitions(self) -> list[tuple[str, Any, Any]]:
        """Return ``(service name, handler, schema)`` for every service."""
        return [
            (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
            (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
            (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
            (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
            (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
            (SERVICE_REFRESH_DATA, self.refresh_data, SCHEMA_REFRESH_DATA),
        ]

    def register_services(self) -> None:
        """Register every service that is not already registered under DOMAIN."""
        _LOGGER.debug("Registering AdGuard Control Hub services")

        for service_name, service_func, schema in self._service_definitions():
            if not self.hass.services.has_service(DOMAIN, service_name):
                self.hass.services.register(DOMAIN, service_name, service_func, schema=schema)
                _LOGGER.debug("Registered service: %s", service_name)

    def unregister_services(self) -> None:
        """Remove every one of this integration's services that is registered."""
        _LOGGER.debug("Unregistering AdGuard Control Hub services")

        for service_name, _handler, _schema in self._service_definitions():
            if self.hass.services.has_service(DOMAIN, service_name):
                self.hass.services.remove(DOMAIN, service_name)
                _LOGGER.debug("Unregistered service: %s", service_name)

    def _get_api_instances(self) -> list[AdGuardHomeAPI]:
        """Return the ``"api"`` object of every dict entry in ``hass.data[DOMAIN]``."""
        return [
            entry_data["api"]
            for entry_data in self.hass.data.get(DOMAIN, {}).values()
            if isinstance(entry_data, dict) and "api" in entry_data
        ]

    def _track_task(self, coro: Any) -> None:
        """Run *coro* as a task and keep a reference to it until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    @staticmethod
    def _extract_blocked_service_ids(client: dict[str, Any]) -> list[str]:
        """Return a copy of the client's currently blocked service IDs.

        ``blocked_services`` may be either a dict holding the IDs under
        ``"ids"`` or a plain list. A missing key or ``None`` value at either
        level yields an empty list.
        """
        blocked = client.get("blocked_services")
        if isinstance(blocked, dict):
            blocked = blocked.get("ids")
        return list(blocked or [])

    async def block_services(self, call: ServiceCall) -> None:
        """Add services to a client's blocked-services list on every instance.

        Services already blocked for the client are kept. The requested ones
        are appended, with duplicates removed and existing order preserved.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Blocking services %s for client %s", services, client_name)

        success_count = 0
        for api in self._get_api_instances():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found", client_name)
                    continue

                current_services = self._extract_blocked_service_ids(client)
                # dict.fromkeys de-duplicates while keeping insertion order, so
                # the result is deterministic (a set would scramble it).
                updated_services = list(dict.fromkeys([*current_services, *services]))
                await api.update_client_blocked_services(client_name, updated_services)
                success_count += 1
                _LOGGER.info("Successfully blocked services for %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error blocking services for %s: %s", client_name, err)
            except Exception as err:
                _LOGGER.exception("Unexpected error blocking services for %s: %s", client_name, err)

        if success_count == 0:
            _LOGGER.error("Failed to block services for %s on any instance", client_name)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Remove services from a client's blocked-services list on every instance."""
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Unblocking services %s for client %s", services, client_name)

        to_remove = set(services)
        success_count = 0
        for api in self._get_api_instances():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found", client_name)
                    continue

                current_services = self._extract_blocked_service_ids(client)
                updated_services = [s for s in current_services if s not in to_remove]
                await api.update_client_blocked_services(client_name, updated_services)
                success_count += 1
                _LOGGER.info("Successfully unblocked services for %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error unblocking services for %s: %s", client_name, err)
            except Exception as err:
                _LOGGER.exception("Unexpected error unblocking services for %s: %s", client_name, err)

        if success_count == 0:
            _LOGGER.error("Failed to unblock services for %s on any instance", client_name)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Emergency unblock: temporarily disable protection.

        If ``"all"`` is among the requested clients, global protection is
        turned off on every instance. Otherwise filtering is turned off for
        each named client. In both cases the change is reverted after
        ``duration`` seconds by a tracked background task.
        """
        duration = call.data[ATTR_DURATION]
        clients = call.data[ATTR_CLIENTS]

        _LOGGER.warning("Emergency unblock activated for %s seconds", duration)

        for api in self._get_api_instances():
            try:
                if "all" in clients:
                    await api.set_protection(False)
                    _LOGGER.warning("Protection disabled for %s:%s", api.host, api.port)
                    # Re-enable after duration (only scheduled if disabling succeeded).
                    self._track_task(self._restore_protection(api, duration))
                else:
                    await self._emergency_unblock_clients(api, clients, duration)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error during emergency unblock: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error during emergency unblock: %s", err)

    async def _restore_protection(self, api: AdGuardHomeAPI, duration: int) -> None:
        """Wait *duration* seconds, then turn global protection back on for *api*."""
        await asyncio.sleep(duration)
        try:
            await api.set_protection(True)
            _LOGGER.info("Emergency unblock expired - protection re-enabled for %s:%s",
                         api.host, api.port)
        except Exception as err:
            # Background task: log instead of leaving an unretrieved task exception.
            _LOGGER.error("Failed to re-enable protection for %s:%s: %s",
                          api.host, api.port, err)

    async def _emergency_unblock_clients(
        self, api: AdGuardHomeAPI, clients: list[str], duration: int
    ) -> None:
        """Disable filtering for each named client and schedule its restoration."""
        for client_name in clients:
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found", client_name)
                    continue

                # Only clients that were filtering before get re-enabled later,
                # so a client that already had filtering off is left as it was.
                was_filtering = bool(client.get("filtering_enabled", True))
                update_data = {
                    "name": client_name,
                    "data": {**client, "filtering_enabled": False},
                }
                await api.update_client(update_data)
                _LOGGER.info("Emergency unblock applied to client %s", client_name)

                if was_filtering:
                    self._track_task(
                        self._restore_client_filtering(api, client_name, duration)
                    )
            except Exception as err:
                _LOGGER.error("Failed to emergency unblock client %s: %s", client_name, err)

    async def _restore_client_filtering(
        self, api: AdGuardHomeAPI, client_name: str, duration: int
    ) -> None:
        """Wait *duration* seconds, then turn filtering back on for *client_name*.

        The client is re-fetched first so that changes made to it during the
        unblock window are not overwritten with a stale snapshot.
        """
        await asyncio.sleep(duration)
        try:
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.warning(
                    "Cannot re-enable filtering for client %s: client not found", client_name
                )
                return
            await api.update_client({
                "name": client_name,
                "data": {**client, "filtering_enabled": True},
            })
            _LOGGER.info(
                "Emergency unblock expired - filtering re-enabled for client %s", client_name
            )
        except Exception as err:
            _LOGGER.error("Failed to re-enable filtering for client %s: %s", client_name, err)

    async def add_client(self, call: ServiceCall) -> None:
        """Add a new client, built from the validated service data, to every instance."""
        client_data = dict(call.data)

        _LOGGER.info("Adding new client: %s", client_data.get("name"))

        success_count = 0
        for api in self._get_api_instances():
            try:
                await api.add_client(client_data)
                success_count += 1
                _LOGGER.info("Successfully added client: %s", client_data.get("name"))
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error adding client: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error adding client: %s", err)

        if success_count == 0:
            _LOGGER.error("Failed to add client %s on any instance", client_data.get("name"))

    async def remove_client(self, call: ServiceCall) -> None:
        """Delete the named client from every instance."""
        # "name" is required by SCHEMA_REMOVE_CLIENT, so direct indexing is safe.
        client_name = call.data["name"]

        _LOGGER.info("Removing client: %s", client_name)

        success_count = 0
        for api in self._get_api_instances():
            try:
                await api.delete_client(client_name)
                success_count += 1
                _LOGGER.info("Successfully removed client: %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error removing client: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error removing client: %s", err)

        if success_count == 0:
            _LOGGER.error("Failed to remove client %s on any instance", client_name)

    async def refresh_data(self, call: ServiceCall) -> None:
        """Request a refresh from every coordinator stored under ``hass.data[DOMAIN]``."""
        _LOGGER.info("Manually refreshing AdGuard Control Hub data")

        for entry_data in self.hass.data.get(DOMAIN, {}).values():
            if isinstance(entry_data, dict) and "coordinator" in entry_data:
                coordinator = entry_data["coordinator"]
                try:
                    await coordinator.async_request_refresh()
                    _LOGGER.debug("Refreshed coordinator data")
                except Exception as err:
                    _LOGGER.error("Failed to refresh coordinator: %s", err)
 |