Files
adguard-control-hub/custom_components/adguard_hub/services.py
Rafal Zielinski 8281a1813d
Some checks failed
Integration Testing / Integration Tests (2024.12.0, 3.11) (push) Failing after 22s
Integration Testing / Integration Tests (2024.12.0, 3.12) (push) Failing after 21s
Integration Testing / Integration Tests (2024.12.0, 3.13) (push) Failing after 1m32s
Integration Testing / Integration Tests (2025.9.4, 3.11) (push) Failing after 15s
Integration Testing / Integration Tests (2025.9.4, 3.12) (push) Failing after 20s
Integration Testing / Integration Tests (2025.9.4, 3.13) (push) Failing after 20s
fix: Fix CI/CD issues and enhance integration
Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
2025-09-28 17:24:46 +01:00

277 lines
12 KiB
Python

"""Service implementations for AdGuard Control Hub integration."""
import asyncio
import logging
from typing import Any, Dict
import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv
from .api import AdGuardHomeAPI, AdGuardHomeError
from .const import (
DOMAIN,
BLOCKED_SERVICES,
ATTR_CLIENT_NAME,
ATTR_SERVICES,
ATTR_DURATION,
ATTR_CLIENTS,
ATTR_ENABLED,
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_REFRESH_DATA,
)
_LOGGER = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Service schemas
# ---------------------------------------------------------------------------

# Block: a client name plus one or more service ids; every id must be a key
# of BLOCKED_SERVICES.
SCHEMA_BLOCK_SERVICES = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): vol.All(
            cv.ensure_list,
            [vol.In(BLOCKED_SERVICES.keys())],
        ),
    }
)

# Unblock: same shape as the block schema.
SCHEMA_UNBLOCK_SERVICES = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): vol.All(
            cv.ensure_list,
            [vol.In(BLOCKED_SERVICES.keys())],
        ),
    }
)

# Emergency unblock: a positive integer duration and an optional list of
# client names that defaults to ["all"].
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema(
    {
        vol.Required(ATTR_DURATION): cv.positive_int,
        vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(
            cv.ensure_list,
            [cv.string],
        ),
    }
)

# Add client: name and identifiers are mandatory, every flag has a default.
SCHEMA_ADD_CLIENT = vol.Schema(
    {
        vol.Required("name"): cv.string,
        vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
        vol.Optional("filtering_enabled", default=True): cv.boolean,
        vol.Optional("safebrowsing_enabled", default=False): cv.boolean,
        vol.Optional("parental_enabled", default=False): cv.boolean,
        vol.Optional("safesearch_enabled", default=False): cv.boolean,
        vol.Optional("use_global_blocked_services", default=True): cv.boolean,
        vol.Optional("blocked_services", default=[]): vol.All(
            cv.ensure_list,
            [cv.string],
        ),
    }
)

# Remove client: only the client name is needed.
SCHEMA_REMOVE_CLIENT = vol.Schema(
    {
        vol.Required("name"): cv.string,
    }
)

# Refresh data: declares no fields.
SCHEMA_REFRESH_DATA = vol.Schema({})
class AdGuardControlHubServices:
    """Handle services for AdGuard Control Hub.

    An instance registers the integration's services on a Home Assistant
    instance and applies every service call to each API object found under
    ``hass.data[DOMAIN]``.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the services.

        Args:
            hass: Home Assistant instance whose service registry and ``data``
                store are used by the handlers.
        """
        self.hass = hass
        # Strong references to the delayed "re-enable" tasks. The event loop
        # keeps only weak references to tasks, so a task nobody references
        # may be garbage-collected before it has finished.
        self._background_tasks: set[asyncio.Task] = set()

    def _service_definitions(self) -> list[tuple]:
        """Return ``(service name, handler, schema)`` for every service."""
        return [
            (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
            (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
            (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
            (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
            (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
            (SERVICE_REFRESH_DATA, self.refresh_data, SCHEMA_REFRESH_DATA),
        ]

    def register_services(self) -> None:
        """Register all services that are not registered yet."""
        _LOGGER.debug("Registering AdGuard Control Hub services")
        # NOTE(review): ``services.register`` looks like the synchronous
        # registry API. If this method is invoked from the event loop the
        # ``async_register`` variant may be required - confirm against the
        # caller before changing it.
        for service_name, service_func, schema in self._service_definitions():
            if self.hass.services.has_service(DOMAIN, service_name):
                continue
            self.hass.services.register(DOMAIN, service_name, service_func, schema=schema)
            _LOGGER.debug("Registered service: %s", service_name)

    def unregister_services(self) -> None:
        """Unregister all services that are currently registered."""
        _LOGGER.debug("Unregistering AdGuard Control Hub services")
        # NOTE(review): same sync-vs-async question as in register_services
        # applies to ``services.remove`` - confirm against the caller.
        for service_name, _handler, _schema in self._service_definitions():
            if not self.hass.services.has_service(DOMAIN, service_name):
                continue
            self.hass.services.remove(DOMAIN, service_name)
            _LOGGER.debug("Unregistered service: %s", service_name)

    def _get_api_instances(self) -> list[AdGuardHomeAPI]:
        """Get all API instances.

        Returns:
            The ``"api"`` value of every dict stored under
            ``hass.data[DOMAIN]``; entries that are not dicts or have no
            ``"api"`` key are skipped.
        """
        return [
            entry_data["api"]
            for entry_data in self.hass.data.get(DOMAIN, {}).values()
            if isinstance(entry_data, dict) and "api" in entry_data
        ]

    @staticmethod
    def _blocked_service_ids(client: dict[str, Any]) -> list[str]:
        """Return the blocked service ids stored on a client record.

        ``blocked_services`` is accepted either as a dict carrying an
        ``"ids"`` list or as a plain list; a missing or ``None`` value yields
        an empty list.
        """
        blocked = client.get("blocked_services")
        if isinstance(blocked, dict):
            blocked = blocked.get("ids")
        return list(blocked or [])

    def _track_task(self, task: asyncio.Task) -> None:
        """Hold a reference to ``task`` until it is done."""
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def block_services(self, call: ServiceCall) -> None:
        """Block services for a specific client.

        The requested services are added to the client's current blocked
        list on every API instance. Failures are logged, not raised.

        Args:
            call: Service call carrying ``ATTR_CLIENT_NAME`` and
                ``ATTR_SERVICES``.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]
        _LOGGER.info("Blocking services %s for client %s", services, client_name)

        success_count = 0
        for api in self._get_api_instances():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found", client_name)
                    continue
                current_services = self._blocked_service_ids(client)
                # dict.fromkeys removes duplicates while keeping a stable,
                # deterministic order (a set would shuffle the ids).
                updated_services = list(dict.fromkeys([*current_services, *services]))
                await api.update_client_blocked_services(client_name, updated_services)
                success_count += 1
                _LOGGER.info("Successfully blocked services for %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error blocking services for %s: %s", client_name, err)
            except Exception as err:
                _LOGGER.exception("Unexpected error blocking services for %s: %s", client_name, err)

        if success_count == 0:
            _LOGGER.error("Failed to block services for %s on any instance", client_name)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Unblock services for a specific client.

        The requested services are removed from the client's current blocked
        list on every API instance. Failures are logged, not raised.

        Args:
            call: Service call carrying ``ATTR_CLIENT_NAME`` and
                ``ATTR_SERVICES``.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]
        _LOGGER.info("Unblocking services %s for client %s", services, client_name)

        # Built once so the membership test below is O(1) per element.
        to_remove = set(services)
        success_count = 0
        for api in self._get_api_instances():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning("Client %s not found", client_name)
                    continue
                updated_services = [
                    service
                    for service in self._blocked_service_ids(client)
                    if service not in to_remove
                ]
                await api.update_client_blocked_services(client_name, updated_services)
                success_count += 1
                _LOGGER.info("Successfully unblocked services for %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error unblocking services for %s: %s", client_name, err)
            except Exception as err:
                _LOGGER.exception("Unexpected error unblocking services for %s: %s", client_name, err)

        if success_count == 0:
            _LOGGER.error("Failed to unblock services for %s on any instance", client_name)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Emergency unblock - temporarily disable protection.

        When ``"all"`` is among the requested clients, protection is switched
        off on every API instance and switched back on after the duration.
        Otherwise filtering is switched off for each named client and
        restored after the duration. Failures are logged, not raised.

        Args:
            call: Service call carrying ``ATTR_DURATION`` (seconds) and
                ``ATTR_CLIENTS``.
        """
        duration = call.data[ATTR_DURATION]
        clients = call.data[ATTR_CLIENTS]
        _LOGGER.warning("Emergency unblock activated for %s seconds", duration)

        for api in self._get_api_instances():
            try:
                if "all" in clients:
                    await api.set_protection(False)
                    _LOGGER.warning("Protection disabled for %s:%s", api.host, api.port)
                    # Re-enable after duration; the task is tracked so it
                    # cannot be garbage-collected while it sleeps.
                    self._track_task(
                        asyncio.create_task(self._reenable_protection_later(api, duration))
                    )
                else:
                    # Individual client emergency unblock
                    for client_name in clients:
                        await self._emergency_unblock_client(api, client_name, duration)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error during emergency unblock: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error during emergency unblock: %s", err)

    async def _reenable_protection_later(self, api: AdGuardHomeAPI, delay: float) -> None:
        """Wait ``delay`` seconds, then switch protection back on for ``api``."""
        await asyncio.sleep(delay)
        try:
            await api.set_protection(True)
            _LOGGER.info(
                "Emergency unblock expired - protection re-enabled for %s:%s",
                api.host,
                api.port,
            )
        except Exception as err:
            _LOGGER.error(
                "Failed to re-enable protection for %s:%s: %s",
                api.host,
                api.port,
                err,
            )

    async def _emergency_unblock_client(
        self, api: AdGuardHomeAPI, client_name: str, duration: float
    ) -> None:
        """Disable filtering for one client and schedule its restoration."""
        try:
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.warning("Client %s not found", client_name)
                return
            update_data = {
                "name": client_name,
                "data": {**client, "filtering_enabled": False},
            }
            await api.update_client(update_data)
            _LOGGER.info("Emergency unblock applied to client %s", client_name)
        except Exception as err:
            _LOGGER.error("Failed to emergency unblock client %s: %s", client_name, err)
            return

        # Only restore what this call switched off: a client whose filtering
        # was already disabled is left as it was. A missing flag is treated
        # as "enabled".
        if client.get("filtering_enabled", True):
            self._track_task(
                asyncio.create_task(
                    self._restore_client_filtering_later(api, client_name, duration)
                )
            )

    async def _restore_client_filtering_later(
        self, api: AdGuardHomeAPI, client_name: str, delay: float
    ) -> None:
        """Wait ``delay`` seconds, then re-enable filtering for one client."""
        await asyncio.sleep(delay)
        try:
            # Fetch the record again so that settings changed during the
            # unblock window are not overwritten by a stale copy.
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.warning("Client %s not found", client_name)
                return
            update_data = {
                "name": client_name,
                "data": {**client, "filtering_enabled": True},
            }
            await api.update_client(update_data)
            _LOGGER.info(
                "Emergency unblock expired - filtering re-enabled for client %s",
                client_name,
            )
        except Exception as err:
            _LOGGER.error("Failed to re-enable filtering for client %s: %s", client_name, err)

    async def add_client(self, call: ServiceCall) -> None:
        """Add a new client.

        The call data is copied into a plain dict and passed to
        ``add_client`` on every API instance. Failures are logged, not
        raised.

        Args:
            call: Service call whose data describes the client.
        """
        client_data = dict(call.data)
        client_name = client_data.get("name")
        _LOGGER.info("Adding new client: %s", client_name)

        success_count = 0
        for api in self._get_api_instances():
            try:
                await api.add_client(client_data)
                success_count += 1
                _LOGGER.info("Successfully added client: %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error adding client: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error adding client: %s", err)

        if success_count == 0:
            _LOGGER.error("Failed to add client %s on any instance", client_name)

    async def remove_client(self, call: ServiceCall) -> None:
        """Remove a client.

        Calls ``delete_client`` on every API instance. Failures are logged,
        not raised.

        Args:
            call: Service call carrying the client ``"name"``.
        """
        client_name = call.data["name"]
        _LOGGER.info("Removing client: %s", client_name)

        success_count = 0
        for api in self._get_api_instances():
            try:
                await api.delete_client(client_name)
                success_count += 1
                _LOGGER.info("Successfully removed client: %s", client_name)
            except AdGuardHomeError as err:
                _LOGGER.error("AdGuard error removing client: %s", err)
            except Exception as err:
                _LOGGER.exception("Unexpected error removing client: %s", err)

        if success_count == 0:
            _LOGGER.error("Failed to remove client %s on any instance", client_name)

    async def refresh_data(self, call: ServiceCall) -> None:
        """Refresh data for all coordinators.

        Requests a refresh from the ``"coordinator"`` value of every dict
        stored under ``hass.data[DOMAIN]``. Failures are logged, not raised.

        Args:
            call: Service call; its data is not used.
        """
        _LOGGER.info("Manually refreshing AdGuard Control Hub data")
        for entry_data in self.hass.data.get(DOMAIN, {}).values():
            if not (isinstance(entry_data, dict) and "coordinator" in entry_data):
                continue
            coordinator = entry_data["coordinator"]
            try:
                await coordinator.async_request_refresh()
                _LOGGER.debug("Refreshed coordinator data")
            except Exception as err:
                _LOGGER.error("Failed to refresh coordinator: %s", err)