@@ -1,8 +1,7 @@
|
||||
"""
|
||||
AdGuard Control Hub for Home Assistant.
|
||||
|
||||
Transform your AdGuard Home into a smart network management powerhouse with
|
||||
complete client control, service blocking, and automation capabilities.
|
||||
Transform your AdGuard Home into a smart network management powerhouse.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
@@ -76,12 +75,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
|
||||
|
||||
# Register services (only once, not per config entry)
|
||||
# Register services (only once)
|
||||
if not hass.services.has_service(DOMAIN, "block_services"):
|
||||
services = AdGuardControlHubServices(hass)
|
||||
services.register_services()
|
||||
|
||||
# Store services instance for cleanup
|
||||
hass.data.setdefault(f"{DOMAIN}_services", services)
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
|
||||
@@ -98,13 +95,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
# Unregister services if this was the last entry
|
||||
if not hass.data[DOMAIN]: # No more entries
|
||||
if not hass.data[DOMAIN]:
|
||||
services = hass.data.get(f"{DOMAIN}_services")
|
||||
if services:
|
||||
services.unregister_services()
|
||||
hass.data.pop(f"{DOMAIN}_services", None)
|
||||
|
||||
# Also clean up the empty domain entry
|
||||
hass.data.pop(DOMAIN, None)
|
||||
|
||||
return unload_ok
|
||||
@@ -129,7 +124,7 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
||||
async def _async_update_data(self) -> Dict[str, Any]:
|
||||
"""Fetch data from AdGuard Home."""
|
||||
try:
|
||||
# Fetch all data concurrently for better performance
|
||||
# Fetch all data concurrently
|
||||
tasks = [
|
||||
self.api.get_clients(),
|
||||
self.api.get_statistics(),
|
||||
@@ -139,37 +134,25 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
clients, statistics, status = results
|
||||
|
||||
# Handle any exceptions in individual requests
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
endpoint_names = ["clients", "statistics", "status"]
|
||||
_LOGGER.warning(
|
||||
"Error fetching %s from %s:%s: %s",
|
||||
endpoint_names[i],
|
||||
self.api.host,
|
||||
self.api.port,
|
||||
result
|
||||
)
|
||||
|
||||
# Update stored data (use empty dict if fetch failed)
|
||||
if not isinstance(clients, Exception):
|
||||
self._clients = {
|
||||
client["name"]: client
|
||||
for client in clients.get("clients", [])
|
||||
if client.get("name") # Ensure client has a name
|
||||
if client.get("name")
|
||||
}
|
||||
else:
|
||||
_LOGGER.warning("Failed to update clients data, keeping previous data")
|
||||
_LOGGER.warning("Failed to update clients data: %s", clients)
|
||||
|
||||
if not isinstance(statistics, Exception):
|
||||
self._statistics = statistics
|
||||
else:
|
||||
_LOGGER.warning("Failed to update statistics data, keeping previous data")
|
||||
_LOGGER.warning("Failed to update statistics data: %s", statistics)
|
||||
|
||||
if not isinstance(status, Exception):
|
||||
self._protection_status = status
|
||||
else:
|
||||
_LOGGER.warning("Failed to update status data, keeping previous data")
|
||||
_LOGGER.warning("Failed to update status data: %s", status)
|
||||
|
||||
return {
|
||||
"clients": self._clients,
|
||||
@@ -196,21 +179,3 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
||||
def protection_status(self) -> Dict[str, Any]:
|
||||
"""Return protection status data."""
|
||||
return self._protection_status
|
||||
|
||||
def get_client(self, client_name: str) -> Dict[str, Any] | None:
|
||||
"""Get a specific client by name."""
|
||||
return self._clients.get(client_name)
|
||||
|
||||
def has_client(self, client_name: str) -> bool:
|
||||
"""Check if a client exists."""
|
||||
return client_name in self._clients
|
||||
|
||||
@property
|
||||
def client_count(self) -> int:
|
||||
"""Return the number of clients."""
|
||||
return len(self._clients)
|
||||
|
||||
@property
|
||||
def is_protection_enabled(self) -> bool:
|
||||
"""Return True if protection is enabled."""
|
||||
return self._protection_status.get("protection_enabled", False)
|
||||
|
@@ -10,23 +10,27 @@ from .const import API_ENDPOINTS
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Custom exceptions
|
||||
|
||||
class AdGuardHomeError(Exception):
|
||||
"""Base exception for AdGuard Home API."""
|
||||
pass
|
||||
|
||||
|
||||
class AdGuardConnectionError(AdGuardHomeError):
|
||||
"""Exception for connection errors."""
|
||||
pass
|
||||
|
||||
|
||||
class AdGuardAuthError(AdGuardHomeError):
|
||||
"""Exception for authentication errors."""
|
||||
pass
|
||||
|
||||
|
||||
class AdGuardNotFoundError(AdGuardHomeError):
|
||||
"""Exception for not found errors."""
|
||||
pass
|
||||
|
||||
|
||||
class AdGuardHomeAPI:
|
||||
"""API wrapper for AdGuard Home."""
|
||||
|
||||
@@ -71,7 +75,7 @@ class AdGuardHomeAPI:
|
||||
return self._session
|
||||
|
||||
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Make an API request with comprehensive error handling."""
|
||||
"""Make an API request."""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
auth = None
|
||||
@@ -84,11 +88,8 @@ class AdGuardHomeAPI:
|
||||
method, url, json=data, headers=headers, auth=auth
|
||||
) as response:
|
||||
|
||||
# Handle different HTTP status codes
|
||||
if response.status == 401:
|
||||
raise AdGuardAuthError("Authentication failed - check username/password")
|
||||
elif response.status == 403:
|
||||
raise AdGuardAuthError("Access forbidden - insufficient permissions")
|
||||
raise AdGuardAuthError("Authentication failed")
|
||||
elif response.status == 404:
|
||||
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
||||
elif response.status >= 500:
|
||||
@@ -96,24 +97,20 @@ class AdGuardHomeAPI:
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# Handle empty responses
|
||||
if response.status == 204 or not response.content_length:
|
||||
return {}
|
||||
|
||||
try:
|
||||
return await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
# Handle non-JSON responses
|
||||
text = await response.text()
|
||||
_LOGGER.warning("Non-JSON response received: %s", text)
|
||||
return {"response": text}
|
||||
|
||||
except asyncio.TimeoutError as err:
|
||||
raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}")
|
||||
raise AdGuardConnectionError(f"Timeout: {err}")
|
||||
except ClientError as err:
|
||||
raise AdGuardConnectionError(f"Client error: {err}")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err)
|
||||
raise AdGuardHomeError(f"Unexpected error: {err}")
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
@@ -121,8 +118,7 @@ class AdGuardHomeAPI:
|
||||
try:
|
||||
await self._request("GET", API_ENDPOINTS["status"])
|
||||
return True
|
||||
except Exception as err:
|
||||
_LOGGER.debug("Connection test failed: %s", err)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
@@ -144,7 +140,6 @@ class AdGuardHomeAPI:
|
||||
|
||||
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Add a new client configuration."""
|
||||
# Validate required fields
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required")
|
||||
if "ids" not in client_data or not client_data["ids"]:
|
||||
@@ -155,9 +150,9 @@ class AdGuardHomeAPI:
|
||||
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update an existing client configuration."""
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required for update")
|
||||
raise ValueError("Client name is required")
|
||||
if "data" not in client_data:
|
||||
raise ValueError("Client data is required for update")
|
||||
raise ValueError("Client data is required")
|
||||
|
||||
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
||||
|
||||
@@ -183,15 +178,13 @@ class AdGuardHomeAPI:
|
||||
return client
|
||||
|
||||
return None
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to get client %s: %s", client_name, err)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_client_blocked_services(
|
||||
self,
|
||||
client_name: str,
|
||||
blocked_services: list,
|
||||
schedule: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Update blocked services for a specific client."""
|
||||
if not client_name:
|
||||
@@ -201,21 +194,11 @@ class AdGuardHomeAPI:
|
||||
if not client:
|
||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||
|
||||
# Prepare the blocked services data with proper structure
|
||||
if schedule:
|
||||
blocked_services_data = {
|
||||
"ids": blocked_services,
|
||||
"schedule": schedule
|
||||
}
|
||||
else:
|
||||
blocked_services_data = {
|
||||
"ids": blocked_services,
|
||||
"schedule": {
|
||||
"time_zone": "Local"
|
||||
}
|
||||
}
|
||||
blocked_services_data = {
|
||||
"ids": blocked_services,
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
|
||||
# Update the client with new blocked services
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {
|
||||
@@ -226,37 +209,6 @@ class AdGuardHomeAPI:
|
||||
|
||||
return await self.update_client(update_data)
|
||||
|
||||
async def toggle_client_service(
|
||||
self, client_name: str, service_id: str, enabled: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""Toggle a specific service for a client."""
|
||||
if not client_name or not service_id:
|
||||
raise ValueError("Client name and service ID are required")
|
||||
|
||||
client = await self.get_client_by_name(client_name)
|
||||
if not client:
|
||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||
|
||||
# Get current blocked services
|
||||
blocked_services = client.get("blocked_services", {})
|
||||
if isinstance(blocked_services, dict):
|
||||
service_ids = blocked_services.get("ids", [])
|
||||
else:
|
||||
# Handle legacy format (direct list)
|
||||
service_ids = blocked_services if blocked_services else []
|
||||
|
||||
# Update the service list
|
||||
if enabled and service_id not in service_ids:
|
||||
service_ids.append(service_id)
|
||||
elif not enabled and service_id in service_ids:
|
||||
service_ids.remove(service_id)
|
||||
|
||||
return await self.update_client_blocked_services(client_name, service_ids)
|
||||
|
||||
async def get_blocked_services(self) -> Dict[str, Any]:
|
||||
"""Get available blocked services."""
|
||||
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the API session if we own it."""
|
||||
if self._own_session and self._session:
|
||||
|
@@ -34,25 +34,20 @@ STEP_USER_DATA_SCHEMA = vol.Schema({
|
||||
|
||||
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate the user input allows us to connect."""
|
||||
# Normalize host
|
||||
host = data[CONF_HOST].strip()
|
||||
if not host:
|
||||
raise InvalidHost("Host cannot be empty")
|
||||
|
||||
# Remove protocol if provided
|
||||
if host.startswith(("http://", "https://")):
|
||||
host = host.split("://", 1)[1]
|
||||
data[CONF_HOST] = host
|
||||
|
||||
# Validate port
|
||||
port = data[CONF_PORT]
|
||||
if not (1 <= port <= 65535):
|
||||
raise InvalidPort("Port must be between 1 and 65535")
|
||||
|
||||
# Create session with appropriate SSL settings
|
||||
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
||||
|
||||
# Create API instance
|
||||
api = AdGuardHomeAPI(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -60,48 +55,38 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
password=data.get(CONF_PASSWORD),
|
||||
ssl=data.get(CONF_SSL, False),
|
||||
session=session,
|
||||
timeout=10, # 10 second timeout for setup
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Test the connection
|
||||
try:
|
||||
if not await api.test_connection():
|
||||
raise CannotConnect("Failed to connect to AdGuard Home")
|
||||
|
||||
# Get additional server info if possible
|
||||
try:
|
||||
status = await api.get_status()
|
||||
version = status.get("version", "unknown")
|
||||
dns_port = status.get("dns_port", "N/A")
|
||||
|
||||
return {
|
||||
"title": f"AdGuard Control Hub ({host})",
|
||||
"version": version,
|
||||
"dns_port": dns_port,
|
||||
"host": host,
|
||||
}
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Could not get server status, but connection works: %s", err)
|
||||
except Exception:
|
||||
return {
|
||||
"title": f"AdGuard Control Hub ({host})",
|
||||
"version": "unknown",
|
||||
"dns_port": "N/A",
|
||||
"host": host,
|
||||
}
|
||||
|
||||
except AdGuardAuthError as err:
|
||||
_LOGGER.error("Authentication failed: %s", err)
|
||||
raise InvalidAuth from err
|
||||
except AdGuardConnectionError as err:
|
||||
_LOGGER.error("Connection failed: %s", err)
|
||||
if "timeout" in str(err).lower():
|
||||
raise Timeout from err
|
||||
raise CannotConnect from err
|
||||
except asyncio.TimeoutError as err:
|
||||
_LOGGER.error("Connection timeout: %s", err)
|
||||
raise Timeout from err
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error during validation: %s", err)
|
||||
raise CannotConnect from err
|
||||
|
||||
|
||||
@@ -121,7 +106,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
try:
|
||||
info = await validate_input(self.hass, user_input)
|
||||
|
||||
# Create unique ID based on host and port
|
||||
unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
|
||||
await self.async_set_unique_id(unique_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
@@ -142,7 +126,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
except Timeout:
|
||||
errors["base"] = "timeout"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception during config flow")
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return self.async_show_form(
|
||||
@@ -151,48 +135,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult:
|
||||
"""Handle configuration import."""
|
||||
return await self.async_step_user(import_info)
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get the options flow for this handler."""
|
||||
return OptionsFlowHandler(config_entry)
|
||||
|
||||
|
||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle options flow for AdGuard Control Hub."""
|
||||
|
||||
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, Any]] = None
|
||||
) -> FlowResult:
|
||||
"""Handle options flow."""
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(title="", data=user_input)
|
||||
|
||||
options_schema = vol.Schema({
|
||||
vol.Optional(
|
||||
"scan_interval",
|
||||
default=self.config_entry.options.get("scan_interval", 30),
|
||||
): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)),
|
||||
vol.Optional(
|
||||
"timeout",
|
||||
default=self.config_entry.options.get("timeout", 10),
|
||||
): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)),
|
||||
})
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=options_schema,
|
||||
)
|
||||
|
||||
|
||||
# Custom exceptions
|
||||
class CannotConnect(Exception):
|
||||
"""Error to indicate we cannot connect."""
|
||||
|
||||
|
@@ -29,48 +29,28 @@ API_ENDPOINTS: Final = {
|
||||
"clients_update": "/control/clients/update",
|
||||
"clients_delete": "/control/clients/delete",
|
||||
"blocked_services_all": "/control/blocked_services/all",
|
||||
"blocked_services_get": "/control/blocked_services/get",
|
||||
"blocked_services_update": "/control/blocked_services/update",
|
||||
"protection": "/control/protection",
|
||||
"stats": "/control/stats",
|
||||
}
|
||||
|
||||
# Available blocked services with friendly names
|
||||
# Available blocked services
|
||||
BLOCKED_SERVICES: Final = {
|
||||
# Social Media
|
||||
"youtube": "YouTube",
|
||||
"facebook": "Facebook",
|
||||
"instagram": "Instagram",
|
||||
"netflix": "Netflix",
|
||||
"gaming": "Gaming Services",
|
||||
"instagram": "Instagram",
|
||||
"tiktok": "TikTok",
|
||||
"twitter": "Twitter/X",
|
||||
"snapchat": "Snapchat",
|
||||
"reddit": "Reddit",
|
||||
|
||||
# Entertainment
|
||||
"netflix": "Netflix",
|
||||
"disney_plus": "Disney+",
|
||||
"spotify": "Spotify",
|
||||
"twitch": "Twitch",
|
||||
|
||||
# Gaming
|
||||
"gaming": "Gaming Services",
|
||||
"steam": "Steam",
|
||||
"epic_games": "Epic Games",
|
||||
"roblox": "Roblox",
|
||||
|
||||
# Shopping
|
||||
"amazon": "Amazon",
|
||||
"ebay": "eBay",
|
||||
|
||||
# Communication
|
||||
"whatsapp": "WhatsApp",
|
||||
"telegram": "Telegram",
|
||||
"discord": "Discord",
|
||||
|
||||
# Other
|
||||
"adult": "Adult Content",
|
||||
"gambling": "Gambling Sites",
|
||||
"torrents": "Torrent Sites",
|
||||
}
|
||||
|
||||
# Service attributes
|
||||
@@ -78,15 +58,9 @@ ATTR_CLIENT_NAME: Final = "client_name"
|
||||
ATTR_SERVICES: Final = "services"
|
||||
ATTR_DURATION: Final = "duration"
|
||||
ATTR_CLIENTS: Final = "clients"
|
||||
ATTR_CLIENT_PATTERN: Final = "client_pattern"
|
||||
ATTR_SETTINGS: Final = "settings"
|
||||
|
||||
# Icons
|
||||
ICON_HUB: Final = "mdi:router-network"
|
||||
ICON_PROTECTION: Final = "mdi:shield"
|
||||
ICON_PROTECTION_OFF: Final = "mdi:shield-off"
|
||||
ICON_CLIENT: Final = "mdi:devices"
|
||||
ICON_CLIENT_OFFLINE: Final = "mdi:devices-off"
|
||||
ICON_BLOCKED_SERVICE: Final = "mdi:block-helper"
|
||||
ICON_ALLOWED_SERVICE: Final = "mdi:check-circle"
|
||||
ICON_STATISTICS: Final = "mdi:chart-line"
|
||||
|
@@ -1,9 +1,8 @@
|
||||
"""Sensor platform for AdGuard Control Hub integration."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass
|
||||
from homeassistant.components.sensor import SensorEntity, SensorStateClass
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import PERCENTAGE
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -100,7 +99,6 @@ class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
|
||||
self._attr_icon = "mdi:percent"
|
||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||
self._attr_native_unit_of_measurement = PERCENTAGE
|
||||
self._attr_device_class = SensorDeviceClass.POWER_FACTOR
|
||||
|
||||
@property
|
||||
def native_value(self) -> float | None:
|
||||
|
@@ -19,7 +19,6 @@ from .const import (
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Service schemas
|
||||
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
@@ -30,13 +29,6 @@ SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
||||
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||
})
|
||||
|
||||
SERVICE_BLOCK_SERVICES = "block_services"
|
||||
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
||||
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
|
||||
SERVICE_ADD_CLIENT = "add_client"
|
||||
SERVICE_REMOVE_CLIENT = "remove_client"
|
||||
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
||||
|
||||
|
||||
class AdGuardControlHubServices:
|
||||
"""Handle services for AdGuard Control Hub."""
|
||||
@@ -44,45 +36,27 @@ class AdGuardControlHubServices:
|
||||
def __init__(self, hass: HomeAssistant):
|
||||
"""Initialize the services."""
|
||||
self.hass = hass
|
||||
self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def register_services(self) -> None:
|
||||
"""Register all services."""
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
self.block_services,
|
||||
schema=SCHEMA_BLOCK_SERVICES,
|
||||
DOMAIN, "block_services", self.block_services, schema=SCHEMA_BLOCK_SERVICES
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
self.unblock_services,
|
||||
schema=SCHEMA_BLOCK_SERVICES,
|
||||
DOMAIN, "unblock_services", self.unblock_services, schema=SCHEMA_BLOCK_SERVICES
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
self.emergency_unblock,
|
||||
schema=SCHEMA_EMERGENCY_UNBLOCK,
|
||||
DOMAIN, "emergency_unblock", self.emergency_unblock, schema=SCHEMA_EMERGENCY_UNBLOCK
|
||||
)
|
||||
|
||||
# Additional services would go here
|
||||
self.hass.services.register(DOMAIN, SERVICE_ADD_CLIENT, self.add_client)
|
||||
self.hass.services.register(DOMAIN, SERVICE_REMOVE_CLIENT, self.remove_client)
|
||||
self.hass.services.register(DOMAIN, SERVICE_BULK_UPDATE_CLIENTS, self.bulk_update_clients)
|
||||
self.hass.services.register(DOMAIN, "add_client", self.add_client)
|
||||
self.hass.services.register(DOMAIN, "remove_client", self.remove_client)
|
||||
self.hass.services.register(DOMAIN, "bulk_update_clients", self.bulk_update_clients)
|
||||
|
||||
def unregister_services(self) -> None:
|
||||
"""Unregister all services."""
|
||||
services = [
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
SERVICE_ADD_CLIENT,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
"block_services", "unblock_services", "emergency_unblock",
|
||||
"add_client", "remove_client", "bulk_update_clients"
|
||||
]
|
||||
|
||||
for service in services:
|
||||
@@ -114,8 +88,6 @@ class AdGuardControlHubServices:
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
|
||||
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
@@ -141,25 +113,22 @@ class AdGuardControlHubServices:
|
||||
try:
|
||||
if "all" in clients:
|
||||
await api.set_protection(False)
|
||||
task = asyncio.create_task(self._delayed_enable_protection(api, duration))
|
||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
|
||||
# Re-enable after duration
|
||||
async def delayed_enable():
|
||||
await asyncio.sleep(duration)
|
||||
try:
|
||||
await api.set_protection(True)
|
||||
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||
|
||||
asyncio.create_task(delayed_enable())
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||
|
||||
async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
|
||||
"""Re-enable protection after delay."""
|
||||
await asyncio.sleep(delay)
|
||||
try:
|
||||
await api.set_protection(True)
|
||||
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||
|
||||
async def add_client(self, call: ServiceCall) -> None:
|
||||
"""Add a new client."""
|
||||
client_data = dict(call.data)
|
||||
_LOGGER.info("Adding new client: %s", client_data.get("name"))
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
@@ -171,17 +140,14 @@ class AdGuardControlHubServices:
|
||||
async def remove_client(self, call: ServiceCall) -> None:
|
||||
"""Remove a client."""
|
||||
client_name = call.data.get("name")
|
||||
_LOGGER.info("Removing client: %s", client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.delete_client(client_name)
|
||||
_LOGGER.info("Successfully removed client: %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
|
||||
_LOGGER.error("Failed to remove client: %s", err)
|
||||
|
||||
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||
"""Update multiple clients matching a pattern."""
|
||||
"""Bulk update clients."""
|
||||
_LOGGER.info("Bulk update clients called")
|
||||
# Implementation would go here
|
||||
|
@@ -6,7 +6,7 @@
|
||||
"description": "Configure your AdGuard Home connection",
|
||||
"data": {
|
||||
"host": "Host",
|
||||
"port": "Port",
|
||||
"port": "Port",
|
||||
"username": "Username",
|
||||
"password": "Password",
|
||||
"ssl": "Use SSL",
|
||||
@@ -15,25 +15,13 @@
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.",
|
||||
"cannot_connect": "Failed to connect to AdGuard Home",
|
||||
"invalid_auth": "Invalid username or password",
|
||||
"timeout": "Connection timeout. Please check your network connection.",
|
||||
"unknown": "An unexpected error occurred"
|
||||
"timeout": "Connection timeout",
|
||||
"unknown": "Unexpected error occurred"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "AdGuard Control Hub is already configured for this host and port"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "AdGuard Control Hub Options",
|
||||
"description": "Configure advanced options",
|
||||
"data": {
|
||||
"scan_interval": "Update interval (seconds)",
|
||||
"timeout": "Connection timeout (seconds)"
|
||||
}
|
||||
}
|
||||
"already_configured": "AdGuard Control Hub is already configured"
|
||||
}
|
||||
}
|
||||
}
|
@@ -24,11 +24,9 @@ async def async_setup_entry(
|
||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||
|
||||
entities = []
|
||||
# Add global protection switch
|
||||
entities.append(AdGuardProtectionSwitch(coordinator, api))
|
||||
entities = [AdGuardProtectionSwitch(coordinator, api)]
|
||||
|
||||
# Add client switches
|
||||
# Add client switches if clients exist
|
||||
for client_name in coordinator.clients.keys():
|
||||
entities.append(AdGuardClientSwitch(coordinator, api, client_name))
|
||||
|
||||
@@ -74,7 +72,6 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
||||
try:
|
||||
await self.api.set_protection(True)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("AdGuard protection enabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
|
||||
raise
|
||||
@@ -84,7 +81,6 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
||||
try:
|
||||
await self.api.set_protection(False)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("AdGuard protection disabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
|
||||
raise
|
||||
@@ -123,7 +119,6 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
||||
}
|
||||
await self.api.update_client(update_data)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("Enabled protection for client %s", self.client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
|
||||
raise
|
||||
@@ -139,7 +134,6 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
||||
}
|
||||
await self.api.update_client(update_data)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("Disabled protection for client %s", self.client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
|
||||
raise
|
||||
|
Reference in New Issue
Block a user