154 lines
6.5 KiB
Python
154 lines
6.5 KiB
Python
"""Service implementations for AdGuard Control Hub integration."""
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Dict
|
|
|
|
import voluptuous as vol
|
|
from homeassistant.core import HomeAssistant, ServiceCall
|
|
from homeassistant.helpers import config_validation as cv
|
|
|
|
from .api import AdGuardHomeAPI
|
|
from .const import (
|
|
DOMAIN,
|
|
BLOCKED_SERVICES,
|
|
ATTR_CLIENT_NAME,
|
|
ATTR_SERVICES,
|
|
ATTR_DURATION,
|
|
ATTR_CLIENTS,
|
|
)
|
|
|
|
# Module-wide logger, named after this module so log records can be filtered
# per integration file.
_LOGGER = logging.getLogger(name=__name__)
|
|
|
|
# Payload accepted by block_services and unblock_services: a client name and
# one or more service IDs (a single value is wrapped into a list by
# cv.ensure_list); every ID must be a key of BLOCKED_SERVICES.
SCHEMA_BLOCK_SERVICES = vol.Schema(
    {
        vol.Required(ATTR_CLIENT_NAME): cv.string,
        vol.Required(ATTR_SERVICES): vol.All(
            cv.ensure_list,
            [vol.In(BLOCKED_SERVICES.keys())],
        ),
    }
)
|
|
|
|
# Payload accepted by emergency_unblock: a duration validated by
# cv.positive_int (passed to asyncio.sleep by the handler, i.e. seconds) and
# an optional list of client names that defaults to ["all"].
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema(
    {
        vol.Required(ATTR_DURATION): cv.positive_int,
        vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(
            cv.ensure_list,
            [cv.string],
        ),
    }
)
|
|
|
|
|
|
class AdGuardControlHubServices:
    """Register and handle the AdGuard Control Hub service calls.

    Each handler iterates over every entry stored under ``hass.data[DOMAIN]``
    and applies the requested change through that entry's ``"api"`` object
    (an ``AdGuardHomeAPI``). An exception raised for one instance is logged
    and does not stop the remaining instances from being processed.
    """

    # Every service name this class registers, kept in one place so that
    # unregister_services cannot drift from register_services.
    _SERVICE_NAMES = (
        "block_services",
        "unblock_services",
        "emergency_unblock",
        "add_client",
        "remove_client",
        "bulk_update_clients",
    )

    def __init__(self, hass: HomeAssistant):
        """Initialize the services.

        Args:
            hass: Home Assistant instance the services are registered on.
        """
        self.hass = hass
        # Strong references to pending "re-enable protection" tasks. The event
        # loop keeps only weak references to tasks, so an otherwise
        # unreferenced task may be garbage-collected before it completes.
        self._background_tasks: set[asyncio.Task] = set()

    def register_services(self) -> None:
        """Register all services with Home Assistant.

        block_services and unblock_services share SCHEMA_BLOCK_SERVICES,
        emergency_unblock uses SCHEMA_EMERGENCY_UNBLOCK, and the
        client-management services are registered without a schema.
        """
        # NOTE(review): hass.services.register is the synchronous API; if this
        # method is called from inside the event loop, async_register may be
        # required instead -- confirm against the caller.
        self.hass.services.register(
            DOMAIN, "block_services", self.block_services, schema=SCHEMA_BLOCK_SERVICES
        )
        self.hass.services.register(
            DOMAIN, "unblock_services", self.unblock_services, schema=SCHEMA_BLOCK_SERVICES
        )
        self.hass.services.register(
            DOMAIN, "emergency_unblock", self.emergency_unblock, schema=SCHEMA_EMERGENCY_UNBLOCK
        )
        self.hass.services.register(DOMAIN, "add_client", self.add_client)
        self.hass.services.register(DOMAIN, "remove_client", self.remove_client)
        self.hass.services.register(DOMAIN, "bulk_update_clients", self.bulk_update_clients)

    def unregister_services(self) -> None:
        """Remove every service registered by register_services, if present."""
        for service in self._SERVICE_NAMES:
            if self.hass.services.has_service(DOMAIN, service):
                self.hass.services.remove(DOMAIN, service)

    def _apis(self) -> "list[AdGuardHomeAPI]":
        """Return the ``"api"`` object of every entry in ``hass.data[DOMAIN]``.

        Returns an empty list when the domain key is absent (nothing loaded
        yet) instead of raising ``KeyError``. The result is a snapshot, so
        entries being added or removed while a handler awaits cannot break
        the iteration.
        """
        domain_data = self.hass.data.get(DOMAIN, {})
        return [entry_data["api"] for entry_data in domain_data.values()]

    @staticmethod
    def _blocked_service_ids(client: dict) -> list:
        """Return the service IDs currently blocked for *client*.

        ``blocked_services`` is accepted either as a dict holding the IDs
        under ``"ids"`` or as a plain list; a missing, empty or ``None``
        value yields an empty list.
        """
        current_blocked = client.get("blocked_services", {})
        if isinstance(current_blocked, dict):
            return list(current_blocked.get("ids") or [])
        return list(current_blocked or [])

    async def block_services(self, call: ServiceCall) -> None:
        """Add services to a client's blocked list on every instance.

        Services that are already blocked are kept. The merged list is
        de-duplicated while preserving order (existing IDs first, then the
        newly requested ones). Instances on which the client is not found
        are skipped.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]

        _LOGGER.info("Blocking services %s for client %s", services, client_name)

        for api in self._apis():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    continue
                current_services = self._blocked_service_ids(client)
                # dict.fromkeys de-duplicates with a stable order (unlike set),
                # so the list sent to AdGuard is deterministic.
                updated_services = list(dict.fromkeys([*current_services, *services]))
                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully blocked services for %s", client_name)
            except Exception as err:
                # Broad catch on purpose: one failing instance must not stop
                # the others from being updated.
                _LOGGER.error("Failed to block services for %s: %s", client_name, err)

    async def unblock_services(self, call: ServiceCall) -> None:
        """Remove services from a client's blocked list on every instance.

        Instances on which the client is not found are skipped.
        """
        client_name = call.data[ATTR_CLIENT_NAME]
        services = call.data[ATTR_SERVICES]
        to_remove = set(services)

        _LOGGER.info("Unblocking services %s for client %s", services, client_name)

        for api in self._apis():
            try:
                client = await api.get_client_by_name(client_name)
                if not client:
                    continue
                updated_services = [
                    service_id
                    for service_id in self._blocked_service_ids(client)
                    if service_id not in to_remove
                ]
                await api.update_client_blocked_services(client_name, updated_services)
                _LOGGER.info("Successfully unblocked services for %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to unblock services for %s: %s", client_name, err)

    async def emergency_unblock(self, call: ServiceCall) -> None:
        """Temporarily disable protection on every instance.

        Protection is switched off immediately and a background task turns
        it back on after ``duration`` seconds. Only the ``"all"`` client
        selector is implemented; a request without it is logged and ignored.
        """
        duration = call.data[ATTR_DURATION]
        clients = call.data[ATTR_CLIENTS]

        if "all" not in clients:
            _LOGGER.warning(
                "Emergency unblock for specific clients %s is not supported; "
                "only 'all' is implemented",
                clients,
            )
            return

        _LOGGER.warning("Emergency unblock activated for %s seconds", duration)

        for api in self._apis():
            try:
                await api.set_protection(False)
            except Exception as err:
                _LOGGER.error("Failed to execute emergency unblock: %s", err)
                continue
            self._schedule_protection_reenable(api, duration)

    def _schedule_protection_reenable(self, api: "AdGuardHomeAPI", duration: int) -> None:
        """Start a tracked background task that re-enables protection on *api*.

        *api* is passed as an argument (not captured from a loop variable) so
        each task re-enables the instance it was scheduled for.
        """
        task = asyncio.create_task(self._reenable_protection_later(api, duration))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    @staticmethod
    async def _reenable_protection_later(api: "AdGuardHomeAPI", duration: int) -> None:
        """Sleep *duration* seconds, then turn protection on for *api*."""
        await asyncio.sleep(duration)
        try:
            await api.set_protection(True)
            _LOGGER.info("Emergency unblock expired - protection re-enabled")
        except Exception as err:
            _LOGGER.error("Failed to re-enable protection: %s", err)

    async def add_client(self, call: ServiceCall) -> None:
        """Add a client, built from the raw service data, on every instance."""
        client_data = dict(call.data)
        for api in self._apis():
            try:
                await api.add_client(client_data)
                _LOGGER.info("Successfully added client: %s", client_data.get("name"))
            except Exception as err:
                _LOGGER.error("Failed to add client: %s", err)

    async def remove_client(self, call: ServiceCall) -> None:
        """Remove the client named by ``call.data["name"]`` on every instance.

        The service has no schema, so a missing or empty name is rejected
        here instead of being forwarded to the API as ``None``.
        """
        client_name = call.data.get("name")
        if not client_name:
            _LOGGER.error("Failed to remove client: no client name provided")
            return

        for api in self._apis():
            try:
                await api.delete_client(client_name)
                _LOGGER.info("Successfully removed client: %s", client_name)
            except Exception as err:
                _LOGGER.error("Failed to remove client: %s", err)

    async def bulk_update_clients(self, call: ServiceCall) -> None:
        """Bulk update clients.

        Currently a placeholder: it only logs that it was called and does
        not touch any AdGuard instance.
        """
        _LOGGER.info("Bulk update clients called")