Some checks failed
🧪 Integration Testing / 🔧 Test Integration (2025.9.4, 3.13) (push) Failing after 24s
Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
440 lines
17 KiB
Python
440 lines
17 KiB
Python
"""Service implementations for AdGuard Control Hub integration."""
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Dict, List
|
|
|
|
import voluptuous as vol
|
|
from homeassistant.core import HomeAssistant, ServiceCall
|
|
from homeassistant.helpers import config_validation as cv
|
|
|
|
from .api import AdGuardHomeAPI
|
|
from .const import (
|
|
DOMAIN,
|
|
BLOCKED_SERVICES,
|
|
ATTR_CLIENT_NAME,
|
|
ATTR_SERVICES,
|
|
ATTR_DURATION,
|
|
ATTR_CLIENTS,
|
|
ATTR_CLIENT_PATTERN,
|
|
ATTR_SETTINGS,
|
|
)
|
|
|
|
_LOGGER = logging.getLogger(__name__)


# Service schemas

# A list of service ids: a lone string is wrapped into a list by
# cv.ensure_list, and every id must be one of the BLOCKED_SERVICES keys.
_SERVICE_IDS = vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())])

# Optional "start"/"end" pair for one day of a blocking schedule.  Both are
# only checked to be strings here.
# NOTE(review): their expected format is defined by AdGuard Home and is not
# validated by this schema.
_DAY_WINDOW = vol.Schema(
    {
        vol.Optional("start"): cv.string,
        vol.Optional("end"): cv.string,
    }
)

# Day keys accepted inside "schedule", in declaration order.
_SCHEDULE_DAYS = ("sun", "mon", "tue", "wed", "thu", "fri", "sat")

SCHEMA_BLOCK_SERVICES = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): _SERVICE_IDS,
    }
)

SCHEMA_UNBLOCK_SERVICES = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): _SERVICE_IDS,
    }
)

SCHEMA_EMERGENCY_UNBLOCK = vol.Schema(
    {
        vol.Required(ATTR_DURATION): cv.positive_int,
        vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(
            cv.ensure_list, [cv.string]
        ),
    }
)

SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_PATTERN): cv.string,
        vol.Required(ATTR_SETTINGS): vol.Schema(
            {
                vol.Optional("blocked_services"): _SERVICE_IDS,
                vol.Optional("filtering_enabled"): cv.boolean,
                vol.Optional("safebrowsing_enabled"): cv.boolean,
                vol.Optional("parental_enabled"): cv.boolean,
            }
        ),
    }
)

SCHEMA_ADD_CLIENT = vol.Schema(
    {
        vol.Required("name"): cv.string,
        vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
        vol.Optional("mac"): cv.string,
        vol.Optional("use_global_settings"): cv.boolean,
        vol.Optional("filtering_enabled"): cv.boolean,
        vol.Optional("parental_enabled"): cv.boolean,
        vol.Optional("safebrowsing_enabled"): cv.boolean,
        vol.Optional("safesearch_enabled"): cv.boolean,
        vol.Optional("blocked_services"): _SERVICE_IDS,
    }
)

SCHEMA_REMOVE_CLIENT = vol.Schema({vol.Required("name"): cv.string})

SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): _SERVICE_IDS,
        vol.Required("schedule"): vol.Schema(
            {
                vol.Optional("time_zone", default="Local"): cv.string,
                **{vol.Optional(day): _DAY_WINDOW for day in _SCHEDULE_DAYS},
            }
        ),
    }
)
|
|
|
|
# Names under which each service handler is registered in this integration's
# domain (see AdGuardControlHubServices.register_services).
SERVICE_BLOCK_SERVICES = "block_services"
SERVICE_UNBLOCK_SERVICES = "unblock_services"
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
SERVICE_ADD_CLIENT = "add_client"
SERVICE_REMOVE_CLIENT = "remove_client"
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
|
|
|
|
|
|
class AdGuardControlHubServices:
    """Handle services for AdGuard Control Hub.

    Every service handler fans out over all AdGuard Home instances found in
    ``hass.data[DOMAIN]``. A failure on one instance (or one client) is
    logged and does not stop the remaining ones.
    """

    def __init__(self, hass: HomeAssistant):
        """Initialize the services.

        Args:
            hass: Home Assistant instance whose service registry and
                ``hass.data[DOMAIN]`` entries are used.
        """
        self.hass = hass
        # Pending emergency-unblock timers, keyed by "host:port" (global
        # protection re-enable) or "host:port_<client name>" (per-client
        # restore). Finished tasks remove themselves via a done-callback.
        self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
        # blocked_services value captured when a per-client emergency unblock
        # started. It is kept while the matching restore task is pending, so
        # that re-triggering the unblock does not replace it with the
        # already-cleared value.
        self._emergency_originals: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def _service_definitions(self) -> List[Any]:
        """Return ``(service name, handler, schema)`` tuples in registration order."""
        return [
            (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
            (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
            (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
            (
                SERVICE_BULK_UPDATE_CLIENTS,
                self.bulk_update_clients,
                SCHEMA_BULK_UPDATE_CLIENTS,
            ),
            (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
            (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
            (
                SERVICE_SCHEDULE_SERVICE_BLOCK,
                self.schedule_service_block,
                SCHEMA_SCHEDULE_SERVICE_BLOCK,
            ),
        ]

    def register_services(self) -> None:
        """Register all services with their validation schemas."""
        for service, handler, schema in self._service_definitions():
            self.hass.services.register(DOMAIN, service, handler, schema=schema)

    def unregister_services(self) -> None:
        """Unregister every service of this integration that is still registered."""
        for service, _handler, _schema in self._service_definitions():
            if self.hass.services.has_service(DOMAIN, service):
                self.hass.services.remove(DOMAIN, service)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
        """Get API instance for a specific config entry."""
        return self.hass.data[DOMAIN][entry_id]["api"]

    def _entry_data_snapshot(self) -> List[Dict[str, Any]]:
        """Return a list copy of the per-entry data dicts that carry an ``"api"``.

        Handlers await between instances, and config entries may be added or
        removed meanwhile. Iterating the live dict view would then raise
        ``RuntimeError: dictionary changed size during iteration``. Values
        that are not dicts with an ``"api"`` key are skipped. A missing
        ``DOMAIN`` key yields an empty list.
        """
        domain_data = self.hass.data.get(DOMAIN, {})
        return [
            entry_data
            for entry_data in list(domain_data.values())
            if isinstance(entry_data, dict) and "api" in entry_data
        ]

    @staticmethod
    def _extract_service_ids(blocked: Any) -> List[str]:
        """Normalise a client's ``blocked_services`` value to a list of service ids.

        Accepts the dict form (``{"ids": [...], ...}``) or a bare list.
        ``None``, a missing/``null`` ``"ids"`` or an empty value yields ``[]``.
        """
        if isinstance(blocked, dict):
            blocked = blocked.get("ids")
        return list(blocked) if blocked else []

    @staticmethod
    def _glob_to_regex(pattern: str) -> "re.Pattern[str]":
        """Compile a shell-style glob (``*`` and ``?`` wildcards) into a regex.

        All other characters are matched literally, so client names containing
        regex metacharacters such as ``.``, ``(`` or ``+`` work. Matching is
        case-insensitive. Callers use ``fullmatch``, so the whole name must
        match the glob.
        """
        import re

        parts = []
        for char in pattern:
            if char == "*":
                parts.append(".*")
            elif char == "?":
                parts.append(".")
            else:
                parts.append(re.escape(char))
        return re.compile("".join(parts), re.IGNORECASE | re.DOTALL)

    @staticmethod
    def _apply_bulk_settings(
        client: Dict[str, Any], settings: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return a shallow copy of *client* with the bulk *settings* applied."""
        data = {**client}  # start with current data; the coordinator copy is untouched
        if "blocked_services" in settings:
            data["blocked_services"] = {
                "ids": settings["blocked_services"],
                "schedule": {"time_zone": "Local"},
            }
        for flag in ("filtering_enabled", "safebrowsing_enabled", "parental_enabled"):
            if flag in settings:
                data[flag] = settings[flag]
        return data

    def _schedule_emergency_task(self, key: str, coro: Any) -> None:
        """Run *coro* as the pending emergency task for *key*.

        Only a previous task with the same *key* is cancelled. Timers for
        other instances/clients keep running, so their protection is still
        restored.
        """
        previous = self._emergency_unblock_tasks.pop(key, None)
        if previous is not None and not previous.done():
            previous.cancel()
        task = asyncio.create_task(coro)
        self._emergency_unblock_tasks[key] = task
        task.add_done_callback(
            lambda finished, k=key: self._forget_emergency_task(k, finished)
        )

    def _forget_emergency_task(self, key: str, task: asyncio.Task) -> None:
        """Drop bookkeeping for *key* once its current task has finished.

        A task that was replaced (cancelled by a newer call) is no longer the
        registered one, so its callback leaves the newer state alone.
        """
        if self._emergency_unblock_tasks.get(key) is task:
            del self._emergency_unblock_tasks[key]
            self._emergency_originals.pop(key, None)

    # ------------------------------------------------------------------
    # Service handlers
    # ------------------------------------------------------------------

    async def block_services(self, call: ServiceCall) -> None:
        """Add services to a client's blocked list on every instance.

        Ids already blocked for the client are kept. New ids are appended,
        duplicates are dropped and the existing order is preserved.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Blocking services %s for client %s", services, client_name)

        # Get all API instances (for multiple AdGuard instances)
        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning(
                        "Client %s not found on %s:%s", client_name, api.host, api.port
                    )
                    continue

                current_services = self._extract_service_ids(
                    client.get("blocked_services")
                )
                # dict.fromkeys de-duplicates with a stable order; a set would
                # reshuffle the ids on every call.
                updated_services = list(
                    dict.fromkeys(current_services + list(services))
                )

                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully blocked services for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to block services for %s: %s", client_name, err)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Remove services from a client's blocked list on every instance."""
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Unblocking services %s for client %s", services, client_name)

        to_remove = set(services)
        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    _LOGGER.warning(
                        "Client %s not found on %s:%s", client_name, api.host, api.port
                    )
                    continue

                current_services = self._extract_service_ids(
                    client.get("blocked_services")
                )
                updated_services = [s for s in current_services if s not in to_remove]

                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully unblocked services for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to unblock services for %s: %s", client_name, err)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Emergency unblock: temporarily disable protection.

        With ``"all"`` in the client list, global protection is switched off
        on each instance and re-enabled after ``duration`` seconds. Otherwise
        each named client's blocked services are cleared and restored after
        ``duration`` seconds.

        Calling again for the same instance or client restarts its timer.
        Pending timers for other instances and clients are left running.
        """
        duration = call.data[ATTR_DURATION]  # seconds
        clients = call.data[ATTR_CLIENTS]

        _LOGGER.warning("Emergency unblock activated for %s seconds", duration)

        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            instance_key = f"{api.host}:{api.port}"

            if "all" in clients:
                try:
                    await api.set_protection(False)
                except Exception as err:
                    _LOGGER.error("Failed to execute emergency unblock: %s", err)
                    continue
                self._schedule_emergency_task(
                    instance_key, self._delayed_enable_protection(api, duration)
                )
                continue

            for client_name in clients:
                await self._emergency_unblock_client(
                    api, instance_key, client_name, duration
                )

    async def _emergency_unblock_client(
        self, api: AdGuardHomeAPI, instance_key: str, client_name: str, duration: int
    ) -> None:
        """Clear one client's blocked services and schedule their restoration."""
        key = f"{instance_key}_{client_name}"
        try:
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.warning(
                    "Client %s not found on %s:%s", client_name, api.host, api.port
                )
                return

            # While an earlier unblock for this client is still pending, its
            # blocked services are already cleared. Reuse the value captured
            # back then, so the eventual restore does not lose it.
            original_blocked = self._emergency_originals.get(
                key, client.get("blocked_services", {})
            )

            await api.update_client_blocked_services(client_name, [])
        except Exception as err:
            _LOGGER.error("Failed to execute emergency unblock: %s", err)
            return

        self._emergency_originals[key] = original_blocked
        self._schedule_emergency_task(
            key,
            self._delayed_restore_client(api, client_name, original_blocked, duration),
        )

    async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
        """Re-enable global protection on *api* after *delay* seconds."""
        await asyncio.sleep(delay)
        try:
            await api.set_protection(True)
            _LOGGER.info("Emergency unblock expired - protection re-enabled")
        except Exception as err:
            _LOGGER.error("Failed to re-enable protection: %s", err)

    async def _delayed_restore_client(
        self, api: AdGuardHomeAPI, client_name: str, original_blocked: Any, delay: int
    ) -> None:
        """Restore a client's blocked service ids after *delay* seconds.

        *original_blocked* is the client's ``blocked_services`` value, either a
        dict with ``"ids"`` or a bare list, captured before it was cleared.
        """
        await asyncio.sleep(delay)
        try:
            services = self._extract_service_ids(original_blocked)
            await api.update_client_blocked_services(client_name, services)
            _LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
        except Exception as err:
            _LOGGER.error("Failed to restore client blocking: %s", err)

    async def bulk_update_clients(self, call: ServiceCall) -> None:
        """Update every client whose name matches a glob pattern.

        ``*`` matches any run of characters and ``?`` matches one character.
        Matching is case-insensitive and the whole name must match. Only the
        settings present in the call are changed.
        """
        pattern = call.data[ATTR_CLIENT_PATTERN]
        settings = call.data[ATTR_SETTINGS]

        _LOGGER.info("Bulk updating clients matching pattern: %s", pattern)

        compiled_pattern = self._glob_to_regex(pattern)

        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                # Snapshot: the coordinator may refresh while we await below.
                # NOTE(review): assumes coordinator.clients maps name -> client dict.
                clients = dict(entry_data["coordinator"].clients)
            except Exception as err:
                _LOGGER.error("Failed to bulk update clients: %s", err)
                continue

            matching_clients = [
                name for name in clients if compiled_pattern.fullmatch(name)
            ]
            _LOGGER.info(
                "Found %d matching clients: %s", len(matching_clients), matching_clients
            )

            # One failing client must not prevent updating the others.
            for client_name in matching_clients:
                update_data = {
                    "name": client_name,
                    "data": self._apply_bulk_settings(clients[client_name], settings),
                }
                try:
                    await api.update_client(update_data)
                except Exception as err:
                    _LOGGER.error(
                        "Failed to bulk update client %s: %s", client_name, err
                    )
                    continue
                _LOGGER.info("Updated client: %s", client_name)

    async def add_client(self, call: ServiceCall) -> None:
        """Add a new client to every instance."""
        client_data = dict(call.data)

        # Convert a non-empty blocked_services list to the {"ids", "schedule"} form
        if client_data.get("blocked_services"):
            client_data["blocked_services"] = {
                "ids": client_data["blocked_services"],
                "schedule": {"time_zone": "Local"},
            }

        _LOGGER.info("Adding new client: %s", client_data["name"])

        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.add_client(client_data)
                _LOGGER.info("Successfully added client: %s", client_data["name"])
            except Exception as err:
                _LOGGER.error("Failed to add client %s: %s", client_data["name"], err)

    async def remove_client(self, call: ServiceCall) -> None:
        """Remove a client by name from every instance."""
        client_name = call.data["name"]

        _LOGGER.info("Removing client: %s", client_name)

        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.delete_client(client_name)
                _LOGGER.info("Successfully removed client: %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to remove client %s: %s", client_name, err)

    async def schedule_service_block(self, call: ServiceCall) -> None:
        """Set a client's blocked services together with a time-based schedule.

        The given services replace the client's current list; they are not
        merged with it.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]
        schedule = call.data["schedule"]

        _LOGGER.info("Scheduling service blocking for client %s", client_name)

        for entry_data in self._entry_data_snapshot():
            api: AdGuardHomeAPI = entry_data["api"]
            try:
                await api.update_client_blocked_services(client_name, services, schedule)
                _LOGGER.info("Successfully scheduled service blocking for %s", client_name)
            except Exception as err:
                _LOGGER.error(
                    "Failed to schedule service blocking for %s: %s", client_name, err
                )
|