@@ -1,8 +1,7 @@
|
||||
"""Service implementations for AdGuard Control Hub integration."""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
@@ -16,8 +15,6 @@ from .const import (
|
||||
ATTR_SERVICES,
|
||||
ATTR_DURATION,
|
||||
ATTR_CLIENTS,
|
||||
ATTR_CLIENT_PATTERN,
|
||||
ATTR_SETTINGS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -28,86 +25,17 @@ SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
||||
vol.Required(ATTR_DURATION): cv.positive_int,
|
||||
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||
})
|
||||
|
||||
SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_PATTERN): cv.string,
|
||||
vol.Required(ATTR_SETTINGS): vol.Schema({
|
||||
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
vol.Optional("filtering_enabled"): cv.boolean,
|
||||
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||
vol.Optional("parental_enabled"): cv.boolean,
|
||||
}),
|
||||
})
|
||||
|
||||
SCHEMA_ADD_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("mac"): cv.string,
|
||||
vol.Optional("use_global_settings"): cv.boolean,
|
||||
vol.Optional("filtering_enabled"): cv.boolean,
|
||||
vol.Optional("parental_enabled"): cv.boolean,
|
||||
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||
vol.Optional("safesearch_enabled"): cv.boolean,
|
||||
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_REMOVE_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
})
|
||||
|
||||
SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
vol.Required("schedule"): vol.Schema({
|
||||
vol.Optional("time_zone", default="Local"): cv.string,
|
||||
vol.Optional("sun"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("mon"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("tue"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("wed"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("thu"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("fri"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("sat"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
# Service names
|
||||
SERVICE_BLOCK_SERVICES = "block_services"
|
||||
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
||||
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
||||
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
|
||||
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
||||
SERVICE_ADD_CLIENT = "add_client"
|
||||
SERVICE_REMOVE_CLIENT = "remove_client"
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
|
||||
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
||||
|
||||
|
||||
class AdGuardControlHubServices:
|
||||
@@ -131,7 +59,7 @@ class AdGuardControlHubServices:
|
||||
DOMAIN,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
self.unblock_services,
|
||||
schema=SCHEMA_UNBLOCK_SERVICES,
|
||||
schema=SCHEMA_BLOCK_SERVICES,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
@@ -141,33 +69,10 @@ class AdGuardControlHubServices:
|
||||
schema=SCHEMA_EMERGENCY_UNBLOCK,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
self.bulk_update_clients,
|
||||
schema=SCHEMA_BULK_UPDATE_CLIENTS,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_ADD_CLIENT,
|
||||
self.add_client,
|
||||
schema=SCHEMA_ADD_CLIENT,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
self.remove_client,
|
||||
schema=SCHEMA_REMOVE_CLIENT,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||
self.schedule_service_block,
|
||||
schema=SCHEMA_SCHEDULE_SERVICE_BLOCK,
|
||||
)
|
||||
# Additional services would go here
|
||||
self.hass.services.register(DOMAIN, SERVICE_ADD_CLIENT, self.add_client)
|
||||
self.hass.services.register(DOMAIN, SERVICE_REMOVE_CLIENT, self.remove_client)
|
||||
self.hass.services.register(DOMAIN, SERVICE_BULK_UPDATE_CLIENTS, self.bulk_update_clients)
|
||||
|
||||
def unregister_services(self) -> None:
|
||||
"""Unregister all services."""
|
||||
@@ -175,20 +80,15 @@ class AdGuardControlHubServices:
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
SERVICE_ADD_CLIENT,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
]
|
||||
|
||||
for service in services:
|
||||
if self.hass.services.has_service(DOMAIN, service):
|
||||
self.hass.services.remove(DOMAIN, service)
|
||||
|
||||
def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
|
||||
"""Get API instance for a specific config entry."""
|
||||
return self.hass.data[DOMAIN][entry_id]["api"]
|
||||
|
||||
async def block_services(self, call: ServiceCall) -> None:
|
||||
"""Block services for a specific client."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
@@ -196,30 +96,16 @@ class AdGuardControlHubServices:
|
||||
|
||||
_LOGGER.info("Blocking services %s for client %s", services, client_name)
|
||||
|
||||
# Get all API instances (for multiple AdGuard instances)
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
# Get current client data
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if not client:
|
||||
_LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
|
||||
continue
|
||||
|
||||
# Get current blocked services
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked if current_blocked else []
|
||||
|
||||
# Add new services to block
|
||||
updated_services = list(set(current_services + services))
|
||||
|
||||
# Update client
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully blocked services for %s", client_name)
|
||||
|
||||
if client:
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or []
|
||||
updated_services = list(set(current_services + services))
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully blocked services for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to block services for %s: %s", client_name, err)
|
||||
|
||||
@@ -230,73 +116,33 @@ class AdGuardControlHubServices:
|
||||
|
||||
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
||||
|
||||
# Get all API instances
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
# Get current client data
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if not client:
|
||||
continue
|
||||
|
||||
# Get current blocked services
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked if current_blocked else []
|
||||
|
||||
# Remove services to unblock
|
||||
updated_services = [s for s in current_services if s not in services]
|
||||
|
||||
# Update client
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully unblocked services for %s", client_name)
|
||||
|
||||
if client:
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or []
|
||||
updated_services = [s for s in current_services if s not in services]
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully unblocked services for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err)
|
||||
|
||||
async def emergency_unblock(self, call: ServiceCall) -> None:
|
||||
"""Emergency unblock - temporarily disable protection."""
|
||||
duration = call.data[ATTR_DURATION] # seconds
|
||||
duration = call.data[ATTR_DURATION]
|
||||
clients = call.data[ATTR_CLIENTS]
|
||||
|
||||
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
|
||||
|
||||
# Cancel any existing emergency unblock tasks
|
||||
for task in self._emergency_unblock_tasks.values():
|
||||
task.cancel()
|
||||
self._emergency_unblock_tasks.clear()
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
if "all" in clients:
|
||||
# Disable global protection
|
||||
await api.set_protection(False)
|
||||
|
||||
# Schedule re-enable
|
||||
task = asyncio.create_task(
|
||||
self._delayed_enable_protection(api, duration)
|
||||
)
|
||||
task = asyncio.create_task(self._delayed_enable_protection(api, duration))
|
||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
|
||||
else:
|
||||
# Disable protection for specific clients
|
||||
for client_name in clients:
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if client:
|
||||
# Store original blocked services
|
||||
original_blocked = client.get("blocked_services", {})
|
||||
|
||||
# Clear blocked services temporarily
|
||||
await api.update_client_blocked_services(client_name, [])
|
||||
|
||||
# Schedule restore
|
||||
task = asyncio.create_task(
|
||||
self._delayed_restore_client(api, client_name, original_blocked, duration)
|
||||
)
|
||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||
|
||||
@@ -309,109 +155,22 @@ class AdGuardControlHubServices:
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||
|
||||
async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str,
|
||||
original_blocked: Dict, delay: int) -> None:
|
||||
"""Restore client blocked services after delay."""
|
||||
await asyncio.sleep(delay)
|
||||
try:
|
||||
if isinstance(original_blocked, dict):
|
||||
services = original_blocked.get("ids", [])
|
||||
else:
|
||||
services = original_blocked if original_blocked else []
|
||||
|
||||
await api.update_client_blocked_services(client_name, services)
|
||||
_LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to restore client blocking: %s", err)
|
||||
|
||||
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||
"""Update multiple clients matching a pattern."""
|
||||
import re
|
||||
|
||||
pattern = call.data[ATTR_CLIENT_PATTERN]
|
||||
settings = call.data[ATTR_SETTINGS]
|
||||
|
||||
_LOGGER.info("Bulk updating clients matching pattern: %s", pattern)
|
||||
|
||||
# Convert pattern to regex
|
||||
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
|
||||
compiled_pattern = re.compile(regex_pattern, re.IGNORECASE)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
coordinator = entry_data["coordinator"]
|
||||
|
||||
try:
|
||||
# Get all clients
|
||||
clients = coordinator.clients
|
||||
|
||||
matching_clients = []
|
||||
for client_name in clients.keys():
|
||||
if compiled_pattern.match(client_name):
|
||||
matching_clients.append(client_name)
|
||||
|
||||
_LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)
|
||||
|
||||
# Update each matching client
|
||||
for client_name in matching_clients:
|
||||
client = clients[client_name]
|
||||
|
||||
# Prepare update data
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {**client} # Start with current data
|
||||
}
|
||||
|
||||
# Apply settings
|
||||
if "blocked_services" in settings:
|
||||
blocked_services_data = {
|
||||
"ids": settings["blocked_services"],
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
update_data["data"]["blocked_services"] = blocked_services_data
|
||||
|
||||
if "filtering_enabled" in settings:
|
||||
update_data["data"]["filtering_enabled"] = settings["filtering_enabled"]
|
||||
|
||||
if "safebrowsing_enabled" in settings:
|
||||
update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"]
|
||||
|
||||
if "parental_enabled" in settings:
|
||||
update_data["data"]["parental_enabled"] = settings["parental_enabled"]
|
||||
|
||||
# Update the client
|
||||
await api.update_client(update_data)
|
||||
_LOGGER.info("Updated client: %s", client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to bulk update clients: %s", err)
|
||||
|
||||
async def add_client(self, call: ServiceCall) -> None:
|
||||
"""Add a new client."""
|
||||
client_data = dict(call.data)
|
||||
|
||||
# Convert blocked_services to proper format
|
||||
if "blocked_services" in client_data and client_data["blocked_services"]:
|
||||
blocked_services_data = {
|
||||
"ids": client_data["blocked_services"],
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
client_data["blocked_services"] = blocked_services_data
|
||||
|
||||
_LOGGER.info("Adding new client: %s", client_data["name"])
|
||||
_LOGGER.info("Adding new client: %s", client_data.get("name"))
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.add_client(client_data)
|
||||
_LOGGER.info("Successfully added client: %s", client_data["name"])
|
||||
_LOGGER.info("Successfully added client: %s", client_data.get("name"))
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to add client %s: %s", client_data["name"], err)
|
||||
_LOGGER.error("Failed to add client: %s", err)
|
||||
|
||||
async def remove_client(self, call: ServiceCall) -> None:
|
||||
"""Remove a client."""
|
||||
client_name = call.data["name"]
|
||||
|
||||
client_name = call.data.get("name")
|
||||
_LOGGER.info("Removing client: %s", client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
@@ -422,18 +181,7 @@ class AdGuardControlHubServices:
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
|
||||
|
||||
async def schedule_service_block(self, call: ServiceCall) -> None:
|
||||
"""Schedule service blocking with time-based rules."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
schedule = call.data["schedule"]
|
||||
|
||||
_LOGGER.info("Scheduling service blocking for client %s", client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.update_client_blocked_services(client_name, services, schedule)
|
||||
_LOGGER.info("Successfully scheduled service blocking for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)
|
||||
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||
"""Update multiple clients matching a pattern."""
|
||||
_LOGGER.info("Bulk update clients called")
|
||||
# Implementation would go here
|
||||
|
Reference in New Issue
Block a user