Some checks failed
Code Quality Check / Code Formatting (push) Failing after 23s
Code Quality Check / Security Analysis (push) Failing after 25s
Integration Testing / Integration Tests (2024.12.0, 3.13) (push) Failing after 1m38s
Integration Testing / Integration Tests (2025.9.4, 3.13) (push) Failing after 24s
Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
247 lines
7.9 KiB
Python
247 lines
7.9 KiB
Python
"""AdGuard Control Hub services."""
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Dict, List
|
|
|
|
from homeassistant.core import HomeAssistant, ServiceCall
|
|
from homeassistant.helpers import config_validation as cv
|
|
import voluptuous as vol
|
|
|
|
from .api import AdGuardConnectionError, AdGuardHomeError
|
|
from .const import (
|
|
ATTR_CLIENT_NAME,
|
|
ATTR_CLIENTS,
|
|
ATTR_DURATION,
|
|
ATTR_SERVICES,
|
|
BLOCKED_SERVICES,
|
|
DOMAIN,
|
|
SERVICE_ADD_CLIENT,
|
|
SERVICE_BLOCK_SERVICES,
|
|
SERVICE_EMERGENCY_UNBLOCK,
|
|
SERVICE_REFRESH_DATA,
|
|
SERVICE_REMOVE_CLIENT,
|
|
SERVICE_UNBLOCK_SERVICES,
|
|
)
|
|
|
|
# Logger for this module, named after the module itself (``__name__``).
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AdGuardControlHubServices:
    """Register and handle the AdGuard Control Hub Home Assistant services.

    Handlers look up the API client and coordinator from the first entry in
    ``hass.data[DOMAIN]`` that provides them, perform the AdGuard Home call,
    request a coordinator refresh and log the outcome. ``AdGuardHomeError``
    raised during a call is logged rather than propagated.
    """

    # Service-data schemas, applied by Home Assistant before a handler runs.
    # ``cv.ensure_list`` turns a lone string into a one-element list (a bare
    # string would otherwise be iterated character by character by
    # ``set.update``), and ``cv.positive_int`` coerces ``duration`` to an int
    # so the delayed ``asyncio.sleep`` cannot fail after protection is already
    # disabled. Unknown keys are allowed so previously accepted calls still are.
    _CLIENT_SERVICES_SCHEMA = vol.Schema(
        {
            vol.Required(ATTR_CLIENT_NAME): cv.string,
            vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [cv.string]),
        },
        extra=vol.ALLOW_EXTRA,
    )
    _EMERGENCY_UNBLOCK_SCHEMA = vol.Schema(
        {
            vol.Optional(ATTR_DURATION, default=300): cv.positive_int,
            vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(
                cv.ensure_list, [cv.string]
            ),
        },
        extra=vol.ALLOW_EXTRA,
    )
    # add_client forwards the whole payload to the API, so only "name" is
    # validated; every other key passes through untouched.
    _CLIENT_NAME_SCHEMA = vol.Schema(
        {vol.Required("name"): cv.string},
        extra=vol.ALLOW_EXTRA,
    )

    def __init__(self, hass: HomeAssistant) -> None:
        """Store the Home Assistant instance used for registration and lookups."""
        self.hass = hass
        # Pending task that re-enables protection after an emergency unblock.
        # Kept so that a newer emergency unblock replaces it instead of racing it.
        self._restore_task: asyncio.Task | None = None

    def _service_table(self) -> list[tuple[str, Any, Any]]:
        """Return ``(service name, handler, schema)`` for every owned service."""
        return [
            (
                SERVICE_BLOCK_SERVICES,
                self.block_services,
                self._CLIENT_SERVICES_SCHEMA,
            ),
            (
                SERVICE_UNBLOCK_SERVICES,
                self.unblock_services,
                self._CLIENT_SERVICES_SCHEMA,
            ),
            (
                SERVICE_EMERGENCY_UNBLOCK,
                self.emergency_unblock,
                self._EMERGENCY_UNBLOCK_SCHEMA,
            ),
            (SERVICE_ADD_CLIENT, self.add_client, self._CLIENT_NAME_SCHEMA),
            (SERVICE_REMOVE_CLIENT, self.remove_client, self._CLIENT_NAME_SCHEMA),
            (SERVICE_REFRESH_DATA, self.refresh_data, None),
        ]

    def register_services(self) -> None:
        """Register every AdGuard Control Hub service under ``DOMAIN``.

        NOTE(review): ``hass.services.register`` is the thread-safe variant;
        if this method is invoked from the event loop (e.g. directly inside
        ``async_setup_entry``), ``hass.services.async_register`` is the
        appropriate call - confirm against the caller.
        """
        for service, handler, schema in self._service_table():
            self.hass.services.register(DOMAIN, service, handler, schema=schema)

        _LOGGER.info("AdGuard Control Hub services registered")

    def unregister_services(self) -> None:
        """Remove every AdGuard Control Hub service that is currently registered.

        A pending emergency-unblock restore task is deliberately left running
        so protection is still re-enabled when its timer expires.
        """
        for service, _handler, _schema in self._service_table():
            if self.hass.services.has_service(DOMAIN, service):
                self.hass.services.remove(DOMAIN, service)

        _LOGGER.info("AdGuard Control Hub services unregistered")

    def _first_entry_value(self, key: str) -> Any:
        """Return ``key`` from the first entry dict in ``hass.data[DOMAIN]`` that has it.

        Returns None when the domain has no data or no entry provides ``key``.
        """
        # .get() so a missing DOMAIN (integration not set up / unloaded)
        # does not surface as a bare KeyError.
        for entry_data in self.hass.data.get(DOMAIN, {}).values():
            if isinstance(entry_data, dict) and key in entry_data:
                return entry_data[key]
        return None

    def _get_api(self) -> Any:
        """Return the API instance from the first entry that provides one.

        Raises:
            AdGuardConnectionError: if no entry exposes an ``"api"``.
        """
        api = self._first_entry_value("api")
        if api is None:
            raise AdGuardConnectionError("No AdGuard Control Hub API available")
        return api

    def _get_coordinator(self) -> Any:
        """Return the coordinator from the first entry that provides one.

        Raises:
            AdGuardConnectionError: if no entry exposes a ``"coordinator"``.
        """
        coordinator = self._first_entry_value("coordinator")
        if coordinator is None:
            raise AdGuardConnectionError(
                "No AdGuard Control Hub coordinator available"
            )
        return coordinator

    @staticmethod
    def _blocked_services_of(client: Any) -> set:
        """Return the client's currently blocked services as a fresh set.

        Treats a missing or ``None`` ``blocked_services`` field as empty.
        """
        return set(client.get("blocked_services") or [])

    async def block_services(self, call: ServiceCall) -> None:
        """Add ``services`` to a client's blocked-services list.

        Services the client already blocks are kept; the union is sent back
        to the API in sorted order.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services_to_block = call.data[ATTR_SERVICES]

        try:
            api = self._get_api()
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.error("Client '%s' not found", client_name)
                return

            blocked = self._blocked_services_of(client)
            blocked.update(services_to_block)
            # Sorted so the payload is deterministic (set order is not).
            await api.update_client_blocked_services(client_name, sorted(blocked))

            await self._get_coordinator().async_request_refresh()

            _LOGGER.info(
                "Blocked services %s for client '%s'", services_to_block, client_name
            )
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to block services for '%s': %s", client_name, err)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Remove ``services`` from a client's blocked-services list.

        Services not currently blocked are ignored; the remainder is sent back
        to the API in sorted order.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services_to_unblock = call.data[ATTR_SERVICES]

        try:
            api = self._get_api()
            client = await api.get_client_by_name(client_name)
            if not client:
                _LOGGER.error("Client '%s' not found", client_name)
                return

            blocked = self._blocked_services_of(client)
            blocked.difference_update(services_to_unblock)
            # Sorted so the payload is deterministic (set order is not).
            await api.update_client_blocked_services(client_name, sorted(blocked))

            await self._get_coordinator().async_request_refresh()

            _LOGGER.info(
                "Unblocked services %s for client '%s'", services_to_unblock, client_name
            )
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to unblock services for '%s': %s", client_name, err)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Disable AdGuard protection globally for ``duration`` seconds.

        Only the global form is implemented: ``clients`` must contain "all"
        (the default). Any other ``clients`` value is logged and ignored.
        After protection is disabled, a background task re-enables it once
        ``duration`` seconds have passed. A newer emergency unblock cancels
        and replaces a still-pending restore, so the latest window wins.
        """
        duration = call.data.get(ATTR_DURATION, 300)
        clients = call.data.get(ATTR_CLIENTS, ["all"])

        if "all" not in clients:
            _LOGGER.warning(
                "Emergency unblock for specific clients %s is not supported; "
                "include 'all' to disable protection globally",
                clients,
            )
            return

        try:
            api = self._get_api()
            await api.set_protection(False)
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to activate emergency unblock: %s", err)
            return

        _LOGGER.warning(
            "Emergency unblock activated globally for %d seconds", duration
        )

        # Schedule the restore immediately after disabling protection, before
        # anything else that could fail, so protection is never left off.
        if self._restore_task is not None and not self._restore_task.done():
            self._restore_task.cancel()
        self._restore_task = self.hass.async_create_background_task(
            self._restore_protection(duration),
            name=f"{DOMAIN}_emergency_unblock_restore",
        )

        try:
            await self._get_coordinator().async_request_refresh()
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to activate emergency unblock: %s", err)

    async def _restore_protection(self, duration: int) -> None:
        """Wait ``duration`` seconds, then re-enable protection and refresh."""
        await asyncio.sleep(duration)
        try:
            # Looked up afresh: the entry may have been reloaded meanwhile.
            await self._get_api().set_protection(True)
            await self._get_coordinator().async_request_refresh()
            _LOGGER.info("Emergency unblock period ended, protection restored")
        except Exception as err:  # background task: log instead of dying silently
            _LOGGER.error(
                "Failed to restore protection after emergency unblock: %s", err
            )

    async def add_client(self, call: ServiceCall) -> None:
        """Create a client from the service payload.

        The entire (validated) service data is forwarded to ``api.add_client``;
        it must contain at least ``"name"``.
        """
        client_data = dict(call.data)

        try:
            api = self._get_api()
            await api.add_client(client_data)

            await self._get_coordinator().async_request_refresh()

            _LOGGER.info("Added new client: %s", client_data["name"])
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to add client '%s': %s", client_data["name"], err)

    async def remove_client(self, call: ServiceCall) -> None:
        """Delete the client named by the ``"name"`` field of the service data."""
        client_name = call.data["name"]

        try:
            api = self._get_api()
            await api.delete_client(client_name)

            await self._get_coordinator().async_request_refresh()

            _LOGGER.info("Removed client: %s", client_name)
        except AdGuardHomeError as err:
            _LOGGER.error("Failed to remove client '%s': %s", client_name, err)

    async def refresh_data(self, call: ServiceCall) -> None:
        """Ask the coordinator to refresh data from AdGuard Home.

        Any failure, including a missing coordinator, is logged and swallowed.
        """
        try:
            await self._get_coordinator().async_request_refresh()
            _LOGGER.info("Data refresh requested")
        except Exception as err:  # service boundary: report, don't raise
            _LOGGER.error("Failed to refresh data: %s", err)