@@ -1,20 +1,24 @@
|
||||
"""
|
||||
🛡️ AdGuard Control Hub for Home Assistant.
|
||||
AdGuard Control Hub for Home Assistant.
|
||||
|
||||
Transform your AdGuard Home into a smart network management powerhouse with
|
||||
Transform your AdGuard Home into a smart network management powerhouse with
|
||||
complete client control, service blocking, and automation capabilities.
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Any
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
|
||||
from .api import AdGuardHomeAPI, AdGuardConnectionError
|
||||
from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL
|
||||
from .api import AdGuardHomeAPI
|
||||
from .services import AdGuardControlHubServices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,6 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up AdGuard Control Hub from a config entry."""
|
||||
session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True))
|
||||
|
||||
# Create API instance
|
||||
api = AdGuardHomeAPI(
|
||||
host=entry.data[CONF_HOST],
|
||||
port=entry.data[CONF_PORT],
|
||||
@@ -34,16 +39,26 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
# Test the connection
|
||||
try:
|
||||
await api.test_connection()
|
||||
_LOGGER.info("Successfully connected to AdGuard Home at %s:%s",
|
||||
entry.data[CONF_HOST], entry.data[CONF_PORT])
|
||||
if not await api.test_connection():
|
||||
raise ConfigEntryNotReady("Unable to connect to AdGuard Home")
|
||||
|
||||
_LOGGER.info(
|
||||
"Successfully connected to AdGuard Home at %s:%s",
|
||||
entry.data[CONF_HOST],
|
||||
entry.data[CONF_PORT]
|
||||
)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to connect to AdGuard Home: %s", err)
|
||||
raise ConfigEntryNotReady(f"Unable to connect: {err}")
|
||||
raise ConfigEntryNotReady(f"Unable to connect: {err}") from err
|
||||
|
||||
# Create update coordinator
|
||||
coordinator = AdGuardControlHubCoordinator(hass, api)
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
|
||||
try:
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to perform initial data refresh: %s", err)
|
||||
raise ConfigEntryNotReady(f"Failed to fetch initial data: {err}") from err
|
||||
|
||||
# Store data
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
@@ -53,9 +68,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
}
|
||||
|
||||
# Set up platforms
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
try:
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to set up platforms: %s", err)
|
||||
# Clean up on failure
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub setup complete")
|
||||
# Register services (only once, not per config entry)
|
||||
if not hass.services.has_service(DOMAIN, "block_services"):
|
||||
services = AdGuardControlHubServices(hass)
|
||||
services.register_services()
|
||||
|
||||
# Store services instance for cleanup
|
||||
hass.data.setdefault(f"{DOMAIN}_services", services)
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
|
||||
entry.data[CONF_HOST], entry.data[CONF_PORT])
|
||||
return True
|
||||
|
||||
|
||||
@@ -64,8 +94,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
if unload_ok:
|
||||
# Remove this entry's data
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
# Unregister services if this was the last entry
|
||||
if not hass.data[DOMAIN]: # No more entries
|
||||
services = hass.data.get(f"{DOMAIN}_services")
|
||||
if services:
|
||||
services.unregister_services()
|
||||
hass.data.pop(f"{DOMAIN}_services", None)
|
||||
|
||||
# Also clean up the empty domain entry
|
||||
hass.data.pop(DOMAIN, None)
|
||||
|
||||
return unload_ok
|
||||
|
||||
|
||||
@@ -81,36 +122,54 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
||||
update_interval=timedelta(seconds=SCAN_INTERVAL),
|
||||
)
|
||||
self.api = api
|
||||
self._clients = {}
|
||||
self._statistics = {}
|
||||
self._protection_status = {}
|
||||
self._clients: Dict[str, Any] = {}
|
||||
self._statistics: Dict[str, Any] = {}
|
||||
self._protection_status: Dict[str, Any] = {}
|
||||
|
||||
async def _async_update_data(self):
|
||||
async def _async_update_data(self) -> Dict[str, Any]:
|
||||
"""Fetch data from AdGuard Home."""
|
||||
try:
|
||||
# Fetch all data concurrently for better performance
|
||||
results = await asyncio.gather(
|
||||
tasks = [
|
||||
self.api.get_clients(),
|
||||
self.api.get_statistics(),
|
||||
self.api.get_status(),
|
||||
return_exceptions=True,
|
||||
)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
clients, statistics, status = results
|
||||
|
||||
# Handle any exceptions
|
||||
# Handle any exceptions in individual requests
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
endpoint_names = ["clients", "statistics", "status"]
|
||||
_LOGGER.warning("Error fetching %s: %s", endpoint_names[i], result)
|
||||
_LOGGER.warning(
|
||||
"Error fetching %s from %s:%s: %s",
|
||||
endpoint_names[i],
|
||||
self.api.host,
|
||||
self.api.port,
|
||||
result
|
||||
)
|
||||
|
||||
# Update stored data (use empty dict if fetch failed)
|
||||
self._clients = {
|
||||
client["name"]: client
|
||||
for client in (clients.get("clients", []) if not isinstance(clients, Exception) else [])
|
||||
}
|
||||
self._statistics = statistics if not isinstance(statistics, Exception) else {}
|
||||
self._protection_status = status if not isinstance(status, Exception) else {}
|
||||
if not isinstance(clients, Exception):
|
||||
self._clients = {
|
||||
client["name"]: client
|
||||
for client in clients.get("clients", [])
|
||||
if client.get("name") # Ensure client has a name
|
||||
}
|
||||
else:
|
||||
_LOGGER.warning("Failed to update clients data, keeping previous data")
|
||||
|
||||
if not isinstance(statistics, Exception):
|
||||
self._statistics = statistics
|
||||
else:
|
||||
_LOGGER.warning("Failed to update statistics data, keeping previous data")
|
||||
|
||||
if not isinstance(status, Exception):
|
||||
self._protection_status = status
|
||||
else:
|
||||
_LOGGER.warning("Failed to update status data, keeping previous data")
|
||||
|
||||
return {
|
||||
"clients": self._clients,
|
||||
@@ -118,20 +177,40 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
||||
"status": self._protection_status,
|
||||
}
|
||||
|
||||
except AdGuardConnectionError as err:
|
||||
raise UpdateFailed(f"Connection error to AdGuard Home: {err}") from err
|
||||
except Exception as err:
|
||||
raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}")
|
||||
raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") from err
|
||||
|
||||
@property
|
||||
def clients(self):
|
||||
def clients(self) -> Dict[str, Any]:
|
||||
"""Return clients data."""
|
||||
return self._clients
|
||||
|
||||
@property
|
||||
def statistics(self):
|
||||
def statistics(self) -> Dict[str, Any]:
|
||||
"""Return statistics data."""
|
||||
return self._statistics
|
||||
|
||||
@property
|
||||
def protection_status(self):
|
||||
def protection_status(self) -> Dict[str, Any]:
|
||||
"""Return protection status data."""
|
||||
return self._protection_status
|
||||
|
||||
def get_client(self, client_name: str) -> Dict[str, Any] | None:
|
||||
"""Get a specific client by name."""
|
||||
return self._clients.get(client_name)
|
||||
|
||||
def has_client(self, client_name: str) -> bool:
|
||||
"""Check if a client exists."""
|
||||
return client_name in self._clients
|
||||
|
||||
@property
|
||||
def client_count(self) -> int:
|
||||
"""Return the number of clients."""
|
||||
return len(self._clients)
|
||||
|
||||
@property
|
||||
def is_protection_enabled(self) -> bool:
|
||||
"""Return True if protection is enabled."""
|
||||
return self._protection_status.get("protection_enabled", False)
|
||||
|
@@ -1,102 +1,207 @@
|
||||
"""API wrapper for AdGuard Home."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import BasicAuth
|
||||
from aiohttp import BasicAuth, ClientError, ClientTimeout
|
||||
|
||||
from .const import API_ENDPOINTS
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Custom exceptions
|
||||
class AdGuardHomeError(Exception):
|
||||
"""Base exception for AdGuard Home API."""
|
||||
pass
|
||||
|
||||
class AdGuardConnectionError(AdGuardHomeError):
|
||||
"""Exception for connection errors."""
|
||||
pass
|
||||
|
||||
class AdGuardAuthError(AdGuardHomeError):
|
||||
"""Exception for authentication errors."""
|
||||
pass
|
||||
|
||||
class AdGuardNotFoundError(AdGuardHomeError):
|
||||
"""Exception for not found errors."""
|
||||
pass
|
||||
|
||||
class AdGuardHomeAPI:
|
||||
"""API wrapper for AdGuard Home."""
|
||||
|
||||
def __init__(self, host: str, port: int = 3000, username: str = None,
|
||||
password: str = None, ssl: bool = False, session=None):
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int = 3000,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
timeout: int = 10,
|
||||
):
|
||||
"""Initialize the API wrapper."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.ssl = ssl
|
||||
self.session = session
|
||||
self._session = session
|
||||
self._timeout = ClientTimeout(total=timeout)
|
||||
protocol = "https" if ssl else "http"
|
||||
self.base_url = f"{protocol}://{host}:{port}"
|
||||
self._own_session = session is None
|
||||
|
||||
async def _request(self, method: str, endpoint: str, data: dict = None) -> dict:
|
||||
"""Make an API request."""
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
if self._own_session:
|
||||
self._session = aiohttp.ClientSession(timeout=self._timeout)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
if self._own_session and self._session:
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
"""Get the session, creating one if needed."""
|
||||
if not self._session:
|
||||
self._session = aiohttp.ClientSession(timeout=self._timeout)
|
||||
return self._session
|
||||
|
||||
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Make an API request with comprehensive error handling."""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
auth = None
|
||||
|
||||
if self.username and self.password:
|
||||
auth = BasicAuth(self.username, self.password)
|
||||
|
||||
try:
|
||||
async with self.session.request(method, url, json=data, headers=headers, auth=auth) as response:
|
||||
async with self.session.request(
|
||||
method, url, json=data, headers=headers, auth=auth
|
||||
) as response:
|
||||
|
||||
# Handle different HTTP status codes
|
||||
if response.status == 401:
|
||||
raise AdGuardAuthError("Authentication failed - check username/password")
|
||||
elif response.status == 403:
|
||||
raise AdGuardAuthError("Access forbidden - insufficient permissions")
|
||||
elif response.status == 404:
|
||||
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
||||
elif response.status >= 500:
|
||||
raise AdGuardConnectionError(f"Server error {response.status}")
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# Handle empty responses
|
||||
if response.status == 204 or not response.content_length:
|
||||
return {}
|
||||
return await response.json()
|
||||
|
||||
try:
|
||||
return await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
# Handle non-JSON responses
|
||||
text = await response.text()
|
||||
_LOGGER.warning("Non-JSON response received: %s", text)
|
||||
return {"response": text}
|
||||
|
||||
except asyncio.TimeoutError as err:
|
||||
raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}")
|
||||
except ClientError as err:
|
||||
raise AdGuardConnectionError(f"Client error: {err}")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Error communicating with AdGuard Home: %s", err)
|
||||
raise
|
||||
_LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err)
|
||||
raise AdGuardHomeError(f"Unexpected error: {err}")
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test the connection."""
|
||||
"""Test the connection to AdGuard Home."""
|
||||
try:
|
||||
await self._request("GET", API_ENDPOINTS["status"])
|
||||
return True
|
||||
except:
|
||||
except Exception as err:
|
||||
_LOGGER.debug("Connection test failed: %s", err)
|
||||
return False
|
||||
|
||||
async def get_status(self) -> dict:
|
||||
"""Get server status."""
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get server status information."""
|
||||
return await self._request("GET", API_ENDPOINTS["status"])
|
||||
|
||||
async def get_clients(self) -> dict:
|
||||
"""Get all clients."""
|
||||
async def get_clients(self) -> Dict[str, Any]:
|
||||
"""Get all configured clients."""
|
||||
return await self._request("GET", API_ENDPOINTS["clients"])
|
||||
|
||||
async def get_statistics(self) -> dict:
|
||||
"""Get statistics."""
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get DNS query statistics."""
|
||||
return await self._request("GET", API_ENDPOINTS["stats"])
|
||||
|
||||
async def set_protection(self, enabled: bool) -> dict:
|
||||
"""Enable or disable protection."""
|
||||
async def set_protection(self, enabled: bool) -> Dict[str, Any]:
|
||||
"""Enable or disable AdGuard protection."""
|
||||
data = {"enabled": enabled}
|
||||
return await self._request("POST", API_ENDPOINTS["protection"], data)
|
||||
|
||||
async def add_client(self, client_data: dict) -> dict:
|
||||
"""Add a new client."""
|
||||
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Add a new client configuration."""
|
||||
# Validate required fields
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required")
|
||||
if "ids" not in client_data or not client_data["ids"]:
|
||||
raise ValueError("Client IDs are required")
|
||||
|
||||
return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)
|
||||
|
||||
async def update_client(self, client_data: dict) -> dict:
|
||||
"""Update an existing client."""
|
||||
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update an existing client configuration."""
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required for update")
|
||||
if "data" not in client_data:
|
||||
raise ValueError("Client data is required for update")
|
||||
|
||||
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
||||
|
||||
async def delete_client(self, client_name: str) -> dict:
|
||||
"""Delete a client."""
|
||||
async def delete_client(self, client_name: str) -> Dict[str, Any]:
|
||||
"""Delete a client configuration."""
|
||||
if not client_name:
|
||||
raise ValueError("Client name is required")
|
||||
|
||||
data = {"name": client_name}
|
||||
return await self._request("POST", API_ENDPOINTS["clients_delete"], data)
|
||||
|
||||
async def get_client_by_name(self, client_name: str) -> dict:
|
||||
async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific client by name."""
|
||||
clients_data = await self.get_clients()
|
||||
clients = clients_data.get("clients", [])
|
||||
if not client_name:
|
||||
return None
|
||||
|
||||
for client in clients:
|
||||
if client.get("name") == client_name:
|
||||
return client
|
||||
try:
|
||||
clients_data = await self.get_clients()
|
||||
clients = clients_data.get("clients", [])
|
||||
|
||||
return None
|
||||
for client in clients:
|
||||
if client.get("name") == client_name:
|
||||
return client
|
||||
|
||||
async def update_client_blocked_services(self, client_name: str, blocked_services: list,
|
||||
schedule: dict = None) -> dict:
|
||||
return None
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to get client %s: %s", client_name, err)
|
||||
return None
|
||||
|
||||
async def update_client_blocked_services(
|
||||
self,
|
||||
client_name: str,
|
||||
blocked_services: list,
|
||||
schedule: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Update blocked services for a specific client."""
|
||||
if not client_name:
|
||||
raise ValueError("Client name is required")
|
||||
|
||||
client = await self.get_client_by_name(client_name)
|
||||
if not client:
|
||||
raise ValueError(f"Client '{client_name}' not found")
|
||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||
|
||||
# Prepare the blocked services data
|
||||
# Prepare the blocked services data with proper structure
|
||||
if schedule:
|
||||
blocked_services_data = {
|
||||
"ids": blocked_services,
|
||||
@@ -110,7 +215,7 @@ class AdGuardHomeAPI:
|
||||
}
|
||||
}
|
||||
|
||||
# Update the client
|
||||
# Update the client with new blocked services
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {
|
||||
@@ -121,18 +226,23 @@ class AdGuardHomeAPI:
|
||||
|
||||
return await self.update_client(update_data)
|
||||
|
||||
async def toggle_client_service(self, client_name: str, service_id: str, enabled: bool) -> dict:
|
||||
async def toggle_client_service(
|
||||
self, client_name: str, service_id: str, enabled: bool
|
||||
) -> Dict[str, Any]:
|
||||
"""Toggle a specific service for a client."""
|
||||
if not client_name or not service_id:
|
||||
raise ValueError("Client name and service ID are required")
|
||||
|
||||
client = await self.get_client_by_name(client_name)
|
||||
if not client:
|
||||
raise ValueError(f"Client '{client_name}' not found")
|
||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||
|
||||
# Get current blocked services
|
||||
blocked_services = client.get("blocked_services", {})
|
||||
if isinstance(blocked_services, dict):
|
||||
service_ids = blocked_services.get("ids", [])
|
||||
else:
|
||||
# Handle old format (list)
|
||||
# Handle legacy format (direct list)
|
||||
service_ids = blocked_services if blocked_services else []
|
||||
|
||||
# Update the service list
|
||||
@@ -142,3 +252,12 @@ class AdGuardHomeAPI:
|
||||
service_ids.remove(service_id)
|
||||
|
||||
return await self.update_client_blocked_services(client_name, service_ids)
|
||||
|
||||
async def get_blocked_services(self) -> Dict[str, Any]:
|
||||
"""Get available blocked services."""
|
||||
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the API session if we own it."""
|
||||
if self._own_session and self._session:
|
||||
await self._session.close()
|
||||
|
166
custom_components/adguard_hub/binary_sensor.py
Normal file
166
custom_components/adguard_hub/binary_sensor.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Binary sensor platform for AdGuard Control Hub integration."""
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import AdGuardControlHubCoordinator
|
||||
from .api import AdGuardHomeAPI
|
||||
from .const import DOMAIN, MANUFACTURER, ICON_PROTECTION, ICON_PROTECTION_OFF
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AdGuard Control Hub binary sensor platform."""
|
||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||
|
||||
entities = [
|
||||
AdGuardProtectionBinarySensor(coordinator, api),
|
||||
AdGuardFilteringBinarySensor(coordinator, api),
|
||||
AdGuardSafeBrowsingBinarySensor(coordinator, api),
|
||||
AdGuardParentalControlBinarySensor(coordinator, api),
|
||||
AdGuardSafeSearchBinarySensor(coordinator, api),
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity):
|
||||
"""Base class for AdGuard binary sensors."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator)
|
||||
self.api = api
|
||||
self._attr_device_info = {
|
||||
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
|
||||
"name": f"AdGuard Control Hub ({api.host})",
|
||||
"manufacturer": MANUFACTURER,
|
||||
"model": "AdGuard Home",
|
||||
}
|
||||
|
||||
|
||||
class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
|
||||
"""Binary sensor to show AdGuard protection status."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled"
|
||||
self._attr_name = "AdGuard Protection Status"
|
||||
self._attr_device_class = BinarySensorDeviceClass.RUNNING
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if protection is enabled."""
|
||||
return self.coordinator.protection_status.get("protection_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the binary sensor."""
|
||||
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return additional state attributes."""
|
||||
status = self.coordinator.protection_status
|
||||
return {
|
||||
"dns_port": status.get("dns_port", "N/A"),
|
||||
"http_port": status.get("http_port", "N/A"),
|
||||
"version": status.get("version", "N/A"),
|
||||
"running": status.get("running", False),
|
||||
}
|
||||
|
||||
|
||||
class AdGuardFilteringBinarySensor(AdGuardBaseBinarySensor):
|
||||
"""Binary sensor to show DNS filtering status."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_filtering_enabled"
|
||||
self._attr_name = "AdGuard DNS Filtering"
|
||||
self._attr_device_class = BinarySensorDeviceClass.RUNNING
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if DNS filtering is enabled."""
|
||||
return self.coordinator.protection_status.get("filtering_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the binary sensor."""
|
||||
return "mdi:dns" if self.is_on else "mdi:dns-off"
|
||||
|
||||
|
||||
class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
|
||||
"""Binary sensor to show Safe Browsing status."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled"
|
||||
self._attr_name = "AdGuard Safe Browsing"
|
||||
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if Safe Browsing is enabled."""
|
||||
return self.coordinator.protection_status.get("safebrowsing_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the binary sensor."""
|
||||
return "mdi:security" if self.is_on else "mdi:security-off"
|
||||
|
||||
|
||||
class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor):
|
||||
"""Binary sensor to show Parental Control status."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled"
|
||||
self._attr_name = "AdGuard Parental Control"
|
||||
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if Parental Control is enabled."""
|
||||
return self.coordinator.protection_status.get("parental_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the binary sensor."""
|
||||
return "mdi:account-child" if self.is_on else "mdi:account-child-outline"
|
||||
|
||||
|
||||
class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor):
|
||||
"""Binary sensor to show Safe Search status."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the binary sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled"
|
||||
self._attr_name = "AdGuard Safe Search"
|
||||
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if Safe Search is enabled."""
|
||||
return self.coordinator.protection_status.get("safesearch_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the binary sensor."""
|
||||
return "mdi:search-web" if self.is_on else "mdi:web-off"
|
@@ -1,73 +1,128 @@
|
||||
"""Config flow for AdGuard Control Hub integration."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import voluptuous as vol
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from .api import AdGuardHomeAPI
|
||||
from .const import CONF_SSL, CONF_VERIFY_SSL, DEFAULT_PORT, DEFAULT_SSL, DEFAULT_VERIFY_SSL, DOMAIN
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError
|
||||
from .const import (
|
||||
CONF_SSL,
|
||||
CONF_VERIFY_SSL,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_SSL,
|
||||
DEFAULT_VERIFY_SSL,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema({
|
||||
vol.Required(CONF_HOST): str,
|
||||
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int,
|
||||
vol.Optional(CONF_USERNAME): str,
|
||||
vol.Optional(CONF_PASSWORD): str,
|
||||
vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool,
|
||||
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): bool,
|
||||
vol.Required(CONF_HOST): cv.string,
|
||||
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
|
||||
vol.Optional(CONF_USERNAME): cv.string,
|
||||
vol.Optional(CONF_PASSWORD): cv.string,
|
||||
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
|
||||
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
|
||||
})
|
||||
|
||||
async def validate_input(hass, data: dict) -> dict:
|
||||
|
||||
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate the user input allows us to connect."""
|
||||
# Normalize host
|
||||
host = data[CONF_HOST].strip()
|
||||
if not host:
|
||||
raise InvalidHost("Host cannot be empty")
|
||||
|
||||
# Remove protocol if provided
|
||||
if host.startswith(("http://", "https://")):
|
||||
host = host.split("://", 1)[1]
|
||||
data[CONF_HOST] = host
|
||||
|
||||
# Validate port
|
||||
port = data[CONF_PORT]
|
||||
if not (1 <= port <= 65535):
|
||||
raise InvalidPort("Port must be between 1 and 65535")
|
||||
|
||||
# Create session with appropriate SSL settings
|
||||
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
||||
|
||||
# Create API instance
|
||||
api = AdGuardHomeAPI(
|
||||
host=data[CONF_HOST],
|
||||
port=data[CONF_PORT],
|
||||
host=host,
|
||||
port=port,
|
||||
username=data.get(CONF_USERNAME),
|
||||
password=data.get(CONF_PASSWORD),
|
||||
ssl=data.get(CONF_SSL, False),
|
||||
session=session,
|
||||
timeout=10, # 10 second timeout for setup
|
||||
)
|
||||
|
||||
# Test the connection
|
||||
if not await api.test_connection():
|
||||
raise CannotConnect
|
||||
|
||||
# Get server info
|
||||
try:
|
||||
status = await api.get_status()
|
||||
version = status.get("version", "unknown")
|
||||
return {
|
||||
"title": f"AdGuard Control Hub ({data[CONF_HOST]})",
|
||||
"version": version
|
||||
}
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected exception: %s", err)
|
||||
if not await api.test_connection():
|
||||
raise CannotConnect("Failed to connect to AdGuard Home")
|
||||
|
||||
# Get additional server info if possible
|
||||
try:
|
||||
status = await api.get_status()
|
||||
version = status.get("version", "unknown")
|
||||
dns_port = status.get("dns_port", "N/A")
|
||||
|
||||
return {
|
||||
"title": f"AdGuard Control Hub ({host})",
|
||||
"version": version,
|
||||
"dns_port": dns_port,
|
||||
"host": host,
|
||||
}
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Could not get server status, but connection works: %s", err)
|
||||
return {
|
||||
"title": f"AdGuard Control Hub ({host})",
|
||||
"version": "unknown",
|
||||
"dns_port": "N/A",
|
||||
"host": host,
|
||||
}
|
||||
|
||||
except AdGuardAuthError as err:
|
||||
_LOGGER.error("Authentication failed: %s", err)
|
||||
raise InvalidAuth from err
|
||||
except AdGuardConnectionError as err:
|
||||
_LOGGER.error("Connection failed: %s", err)
|
||||
if "timeout" in str(err).lower():
|
||||
raise Timeout from err
|
||||
raise CannotConnect from err
|
||||
except asyncio.TimeoutError as err:
|
||||
_LOGGER.error("Connection timeout: %s", err)
|
||||
raise Timeout from err
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error during validation: %s", err)
|
||||
raise CannotConnect from err
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for AdGuard Control Hub."""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL
|
||||
|
||||
async def async_step_user(self, user_input: dict[str, Any] | None = None):
|
||||
async def async_step_user(
|
||||
self, user_input: Optional[Dict[str, Any]] = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step."""
|
||||
errors: dict[str, str] = {}
|
||||
errors: Dict[str, str] = {}
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
info = await validate_input(self.hass, user_input)
|
||||
except CannotConnect:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
|
||||
# Create unique ID based on host and port
|
||||
unique_id = f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}"
|
||||
unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
|
||||
await self.async_set_unique_id(unique_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
@@ -76,11 +131,83 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
except CannotConnect:
|
||||
errors["base"] = "cannot_connect"
|
||||
except InvalidAuth:
|
||||
errors["base"] = "invalid_auth"
|
||||
except InvalidHost:
|
||||
errors[CONF_HOST] = "invalid_host"
|
||||
except InvalidPort:
|
||||
errors[CONF_PORT] = "invalid_port"
|
||||
except Timeout:
|
||||
errors["base"] = "timeout"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception during config flow")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=STEP_USER_DATA_SCHEMA,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult:
|
||||
"""Handle configuration import."""
|
||||
return await self.async_step_user(import_info)
|
||||
|
||||
@staticmethod
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get the options flow for this handler."""
|
||||
return OptionsFlowHandler(config_entry)
|
||||
|
||||
|
||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle options flow for AdGuard Control Hub."""
|
||||
|
||||
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, Any]] = None
|
||||
) -> FlowResult:
|
||||
"""Handle options flow."""
|
||||
if user_input is not None:
|
||||
return self.async_create_entry(title="", data=user_input)
|
||||
|
||||
options_schema = vol.Schema({
|
||||
vol.Optional(
|
||||
"scan_interval",
|
||||
default=self.config_entry.options.get("scan_interval", 30),
|
||||
): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)),
|
||||
vol.Optional(
|
||||
"timeout",
|
||||
default=self.config_entry.options.get("timeout", 10),
|
||||
): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)),
|
||||
})
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=options_schema,
|
||||
)
|
||||
|
||||
|
||||
# Custom exceptions
|
||||
class CannotConnect(Exception):
|
||||
"""Error to indicate we cannot connect."""
|
||||
"""Error to indicate we cannot connect."""
|
||||
|
||||
|
||||
class InvalidAuth(Exception):
|
||||
"""Error to indicate there is invalid auth."""
|
||||
|
||||
|
||||
class InvalidHost(Exception):
|
||||
"""Error to indicate invalid host."""
|
||||
|
||||
|
||||
class InvalidPort(Exception):
|
||||
"""Error to indicate invalid port."""
|
||||
|
||||
|
||||
class Timeout(Exception):
|
||||
"""Error to indicate connection timeout."""
|
||||
|
@@ -17,7 +17,7 @@ SCAN_INTERVAL: Final = 30
|
||||
# Platforms
|
||||
PLATFORMS: Final = [
|
||||
"switch",
|
||||
"binary_sensor",
|
||||
"binary_sensor",
|
||||
"sensor",
|
||||
]
|
||||
|
||||
@@ -26,7 +26,7 @@ API_ENDPOINTS: Final = {
|
||||
"status": "/control/status",
|
||||
"clients": "/control/clients",
|
||||
"clients_add": "/control/clients/add",
|
||||
"clients_update": "/control/clients/update",
|
||||
"clients_update": "/control/clients/update",
|
||||
"clients_delete": "/control/clients/delete",
|
||||
"blocked_services_all": "/control/blocked_services/all",
|
||||
"blocked_services_get": "/control/blocked_services/get",
|
||||
@@ -39,7 +39,7 @@ API_ENDPOINTS: Final = {
|
||||
BLOCKED_SERVICES: Final = {
|
||||
# Social Media
|
||||
"youtube": "YouTube",
|
||||
"facebook": "Facebook",
|
||||
"facebook": "Facebook",
|
||||
"instagram": "Instagram",
|
||||
"tiktok": "TikTok",
|
||||
"twitter": "Twitter/X",
|
||||
@@ -62,7 +62,7 @@ BLOCKED_SERVICES: Final = {
|
||||
"amazon": "Amazon",
|
||||
"ebay": "eBay",
|
||||
|
||||
# Communication
|
||||
# Communication
|
||||
"whatsapp": "WhatsApp",
|
||||
"telegram": "Telegram",
|
||||
"discord": "Discord",
|
||||
@@ -89,4 +89,4 @@ ICON_CLIENT: Final = "mdi:devices"
|
||||
ICON_CLIENT_OFFLINE: Final = "mdi:devices-off"
|
||||
ICON_BLOCKED_SERVICE: Final = "mdi:block-helper"
|
||||
ICON_ALLOWED_SERVICE: Final = "mdi:check-circle"
|
||||
ICON_STATISTICS: Final = "mdi:chart-line"
|
||||
ICON_STATISTICS: Final = "mdi:chart-line"
|
||||
|
@@ -1,14 +1,14 @@
|
||||
{
|
||||
"domain": "adguard_hub",
|
||||
"name": "AdGuard Control Hub",
|
||||
"codeowners": ["@sq4ind"],
|
||||
"config_flow": true,
|
||||
"dependencies": [],
|
||||
"documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub",
|
||||
"integration_type": "hub",
|
||||
"iot_class": "local_polling",
|
||||
"requirements": [
|
||||
"aiohttp>=3.8.0"
|
||||
],
|
||||
"version": "1.0.0"
|
||||
"domain": "adguard_hub",
|
||||
"name": "AdGuard Control Hub",
|
||||
"codeowners": ["@sq4ind"],
|
||||
"config_flow": true,
|
||||
"dependencies": [],
|
||||
"documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub",
|
||||
"integration_type": "hub",
|
||||
"iot_class": "local_polling",
|
||||
"requirements": [
|
||||
"aiohttp>=3.8.0"
|
||||
],
|
||||
"version": "1.0.0"
|
||||
}
|
185
custom_components/adguard_hub/sensor.py
Normal file
185
custom_components/adguard_hub/sensor.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Sensor platform for AdGuard Control Hub integration."""
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import PERCENTAGE
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import AdGuardControlHubCoordinator
|
||||
from .api import AdGuardHomeAPI
|
||||
from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AdGuard Control Hub sensor platform."""
|
||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||
|
||||
entities = [
|
||||
AdGuardQueriesCounterSensor(coordinator, api),
|
||||
AdGuardBlockedCounterSensor(coordinator, api),
|
||||
AdGuardBlockingPercentageSensor(coordinator, api),
|
||||
AdGuardRuleCountSensor(coordinator, api),
|
||||
AdGuardClientCountSensor(coordinator, api),
|
||||
AdGuardUpstreamAverageTimeSensor(coordinator, api),
|
||||
]
|
||||
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class AdGuardBaseSensor(CoordinatorEntity, SensorEntity):
|
||||
"""Base class for AdGuard sensors."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator)
|
||||
self.api = api
|
||||
self._attr_device_info = {
|
||||
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
|
||||
"name": f"AdGuard Control Hub ({api.host})",
|
||||
"manufacturer": MANUFACTURER,
|
||||
"model": "AdGuard Home",
|
||||
}
|
||||
|
||||
|
||||
class AdGuardQueriesCounterSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track DNS queries count."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_dns_queries"
|
||||
self._attr_name = "AdGuard DNS Queries"
|
||||
self._attr_icon = ICON_STATISTICS
|
||||
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
|
||||
self._attr_native_unit_of_measurement = "queries"
|
||||
|
||||
@property
|
||||
def native_value(self) -> int | None:
|
||||
"""Return the state of the sensor."""
|
||||
stats = self.coordinator.statistics
|
||||
return stats.get("num_dns_queries", 0)
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return additional state attributes."""
|
||||
stats = self.coordinator.statistics
|
||||
return {
|
||||
"queries_today": stats.get("num_dns_queries_today", 0),
|
||||
"queries_blocked_today": stats.get("num_blocked_filtering_today", 0),
|
||||
"last_updated": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class AdGuardBlockedCounterSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track blocked queries count."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries"
|
||||
self._attr_name = "AdGuard Blocked Queries"
|
||||
self._attr_icon = "mdi:shield-check"
|
||||
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
|
||||
self._attr_native_unit_of_measurement = "queries"
|
||||
|
||||
@property
|
||||
def native_value(self) -> int | None:
|
||||
"""Return the state of the sensor."""
|
||||
stats = self.coordinator.statistics
|
||||
return stats.get("num_blocked_filtering", 0)
|
||||
|
||||
|
||||
class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track blocking percentage."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage"
|
||||
self._attr_name = "AdGuard Blocking Percentage"
|
||||
self._attr_icon = "mdi:percent"
|
||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||
self._attr_native_unit_of_measurement = PERCENTAGE
|
||||
self._attr_device_class = SensorDeviceClass.POWER_FACTOR
|
||||
|
||||
@property
|
||||
def native_value(self) -> float | None:
|
||||
"""Return the state of the sensor."""
|
||||
stats = self.coordinator.statistics
|
||||
total_queries = stats.get("num_dns_queries", 0)
|
||||
blocked_queries = stats.get("num_blocked_filtering", 0)
|
||||
|
||||
if total_queries == 0:
|
||||
return 0
|
||||
|
||||
percentage = (blocked_queries / total_queries) * 100
|
||||
return round(percentage, 2)
|
||||
|
||||
|
||||
class AdGuardRuleCountSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track filtering rules count."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_rules_count"
|
||||
self._attr_name = "AdGuard Rules Count"
|
||||
self._attr_icon = "mdi:format-list-numbered"
|
||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||
self._attr_native_unit_of_measurement = "rules"
|
||||
|
||||
@property
|
||||
def native_value(self) -> int | None:
|
||||
"""Return the state of the sensor."""
|
||||
stats = self.coordinator.statistics
|
||||
return stats.get("filtering_rules_count", 0)
|
||||
|
||||
|
||||
class AdGuardClientCountSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track active clients count."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_clients_count"
|
||||
self._attr_name = "AdGuard Clients Count"
|
||||
self._attr_icon = "mdi:account-multiple"
|
||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||
self._attr_native_unit_of_measurement = "clients"
|
||||
|
||||
@property
|
||||
def native_value(self) -> int | None:
|
||||
"""Return the state of the sensor."""
|
||||
return len(self.coordinator.clients)
|
||||
|
||||
|
||||
class AdGuardUpstreamAverageTimeSensor(AdGuardBaseSensor):
|
||||
"""Sensor to track upstream servers average response time."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the sensor."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_upstream_response_time"
|
||||
self._attr_name = "AdGuard Upstream Response Time"
|
||||
self._attr_icon = "mdi:timer"
|
||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||
self._attr_native_unit_of_measurement = "ms"
|
||||
self._attr_device_class = SensorDeviceClass.DURATION
|
||||
|
||||
@property
|
||||
def native_value(self) -> float | None:
|
||||
"""Return the state of the sensor."""
|
||||
stats = self.coordinator.statistics
|
||||
return stats.get("avg_processing_time", 0)
|
@@ -1,38 +1,438 @@
|
||||
"""Services for AdGuard Control Hub integration."""
|
||||
"""Service implementations for AdGuard Control Hub integration."""
|
||||
import asyncio
|
||||
import logging
|
||||
from homeassistant.core import HomeAssistant
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from .api import AdGuardHomeAPI
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
BLOCKED_SERVICES,
|
||||
ATTR_CLIENT_NAME,
|
||||
ATTR_SERVICES,
|
||||
ATTR_DURATION,
|
||||
ATTR_CLIENTS,
|
||||
ATTR_CLIENT_PATTERN,
|
||||
ATTR_SETTINGS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
async def async_register_services(hass: HomeAssistant, api: AdGuardHomeAPI) -> None:
|
||||
"""Register integration services."""
|
||||
# Service schemas
|
||||
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
async def emergency_unblock_service(call):
|
||||
"""Emergency unblock service."""
|
||||
duration = call.data.get("duration", 300)
|
||||
clients = call.data.get("clients", ["all"])
|
||||
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
||||
vol.Required(ATTR_DURATION): cv.positive_int,
|
||||
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||
})
|
||||
|
||||
SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_PATTERN): cv.string,
|
||||
vol.Required(ATTR_SETTINGS): vol.Schema({
|
||||
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
vol.Optional("filtering_enabled"): cv.boolean,
|
||||
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||
vol.Optional("parental_enabled"): cv.boolean,
|
||||
}),
|
||||
})
|
||||
|
||||
SCHEMA_ADD_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("mac"): cv.string,
|
||||
vol.Optional("use_global_settings"): cv.boolean,
|
||||
vol.Optional("filtering_enabled"): cv.boolean,
|
||||
vol.Optional("parental_enabled"): cv.boolean,
|
||||
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||
vol.Optional("safesearch_enabled"): cv.boolean,
|
||||
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_REMOVE_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
})
|
||||
|
||||
SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
vol.Required("schedule"): vol.Schema({
|
||||
vol.Optional("time_zone", default="Local"): cv.string,
|
||||
vol.Optional("sun"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("mon"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("tue"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("wed"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("thu"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("fri"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
vol.Optional("sat"): vol.Schema({
|
||||
vol.Optional("start"): cv.string,
|
||||
vol.Optional("end"): cv.string,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
|
||||
# Service names
|
||||
SERVICE_BLOCK_SERVICES = "block_services"
|
||||
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
||||
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
|
||||
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
||||
SERVICE_ADD_CLIENT = "add_client"
|
||||
SERVICE_REMOVE_CLIENT = "remove_client"
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
|
||||
|
||||
|
||||
class AdGuardControlHubServices:
|
||||
"""Handle services for AdGuard Control Hub."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant):
|
||||
"""Initialize the services."""
|
||||
self.hass = hass
|
||||
self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
def register_services(self) -> None:
|
||||
"""Register all services."""
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
self.block_services,
|
||||
schema=SCHEMA_BLOCK_SERVICES,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
self.unblock_services,
|
||||
schema=SCHEMA_UNBLOCK_SERVICES,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
self.emergency_unblock,
|
||||
schema=SCHEMA_EMERGENCY_UNBLOCK,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
self.bulk_update_clients,
|
||||
schema=SCHEMA_BULK_UPDATE_CLIENTS,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_ADD_CLIENT,
|
||||
self.add_client,
|
||||
schema=SCHEMA_ADD_CLIENT,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
self.remove_client,
|
||||
schema=SCHEMA_REMOVE_CLIENT,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||
self.schedule_service_block,
|
||||
schema=SCHEMA_SCHEDULE_SERVICE_BLOCK,
|
||||
)
|
||||
|
||||
def unregister_services(self) -> None:
|
||||
"""Unregister all services."""
|
||||
services = [
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
SERVICE_BULK_UPDATE_CLIENTS,
|
||||
SERVICE_ADD_CLIENT,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||
]
|
||||
|
||||
for service in services:
|
||||
self.hass.services.remove(DOMAIN, service)
|
||||
|
||||
def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
|
||||
"""Get API instance for a specific config entry."""
|
||||
return self.hass.data[DOMAIN][entry_id]["api"]
|
||||
|
||||
async def block_services(self, call: ServiceCall) -> None:
|
||||
"""Block services for a specific client."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
|
||||
_LOGGER.info("Blocking services %s for client %s", services, client_name)
|
||||
|
||||
# Get all API instances (for multiple AdGuard instances)
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
# Get current client data
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if not client:
|
||||
_LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
|
||||
continue
|
||||
|
||||
# Get current blocked services
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked if current_blocked else []
|
||||
|
||||
# Add new services to block
|
||||
updated_services = list(set(current_services + services))
|
||||
|
||||
# Update client
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully blocked services for %s", client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to block services for %s: %s", client_name, err)
|
||||
|
||||
async def unblock_services(self, call: ServiceCall) -> None:
|
||||
"""Unblock services for a specific client."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
|
||||
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
||||
|
||||
# Get all API instances
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
# Get current client data
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if not client:
|
||||
continue
|
||||
|
||||
# Get current blocked services
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked if current_blocked else []
|
||||
|
||||
# Remove services to unblock
|
||||
updated_services = [s for s in current_services if s not in services]
|
||||
|
||||
# Update client
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
_LOGGER.info("Successfully unblocked services for %s", client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err)
|
||||
|
||||
async def emergency_unblock(self, call: ServiceCall) -> None:
|
||||
"""Emergency unblock - temporarily disable protection."""
|
||||
duration = call.data[ATTR_DURATION] # seconds
|
||||
clients = call.data[ATTR_CLIENTS]
|
||||
|
||||
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
|
||||
|
||||
# Cancel any existing emergency unblock tasks
|
||||
for task in self._emergency_unblock_tasks.values():
|
||||
task.cancel()
|
||||
self._emergency_unblock_tasks.clear()
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
if "all" in clients:
|
||||
# Disable global protection
|
||||
await api.set_protection(False)
|
||||
|
||||
# Schedule re-enable
|
||||
task = asyncio.create_task(
|
||||
self._delayed_enable_protection(api, duration)
|
||||
)
|
||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
|
||||
else:
|
||||
# Disable protection for specific clients
|
||||
for client_name in clients:
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if client:
|
||||
# Store original blocked services
|
||||
original_blocked = client.get("blocked_services", {})
|
||||
|
||||
# Clear blocked services temporarily
|
||||
await api.update_client_blocked_services(client_name, [])
|
||||
|
||||
# Schedule restore
|
||||
task = asyncio.create_task(
|
||||
self._delayed_restore_client(api, client_name, original_blocked, duration)
|
||||
)
|
||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||
|
||||
async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
|
||||
"""Re-enable protection after delay."""
|
||||
await asyncio.sleep(delay)
|
||||
try:
|
||||
if "all" in clients:
|
||||
await api.set_protection(False)
|
||||
_LOGGER.info("Emergency unblock activated globally for %d seconds", duration)
|
||||
else:
|
||||
_LOGGER.info("Emergency unblock activated for clients: %s", clients)
|
||||
await api.set_protection(True)
|
||||
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||
raise
|
||||
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||
|
||||
# Register emergency unblock service
|
||||
hass.services.async_register(
|
||||
"adguard_hub",
|
||||
"emergency_unblock",
|
||||
emergency_unblock_service
|
||||
)
|
||||
async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str,
|
||||
original_blocked: Dict, delay: int) -> None:
|
||||
"""Restore client blocked services after delay."""
|
||||
await asyncio.sleep(delay)
|
||||
try:
|
||||
if isinstance(original_blocked, dict):
|
||||
services = original_blocked.get("ids", [])
|
||||
else:
|
||||
services = original_blocked if original_blocked else []
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub services registered")
|
||||
await api.update_client_blocked_services(client_name, services)
|
||||
_LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to restore client blocking: %s", err)
|
||||
|
||||
async def async_unregister_services(hass: HomeAssistant) -> None:
|
||||
"""Unregister integration services."""
|
||||
hass.services.async_remove("adguard_hub", "emergency_unblock")
|
||||
_LOGGER.info("AdGuard Control Hub services unregistered")
|
||||
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||
"""Update multiple clients matching a pattern."""
|
||||
import re
|
||||
|
||||
pattern = call.data[ATTR_CLIENT_PATTERN]
|
||||
settings = call.data[ATTR_SETTINGS]
|
||||
|
||||
_LOGGER.info("Bulk updating clients matching pattern: %s", pattern)
|
||||
|
||||
# Convert pattern to regex
|
||||
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
|
||||
compiled_pattern = re.compile(regex_pattern, re.IGNORECASE)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
coordinator = entry_data["coordinator"]
|
||||
|
||||
try:
|
||||
# Get all clients
|
||||
clients = coordinator.clients
|
||||
|
||||
matching_clients = []
|
||||
for client_name in clients.keys():
|
||||
if compiled_pattern.match(client_name):
|
||||
matching_clients.append(client_name)
|
||||
|
||||
_LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)
|
||||
|
||||
# Update each matching client
|
||||
for client_name in matching_clients:
|
||||
client = clients[client_name]
|
||||
|
||||
# Prepare update data
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {**client} # Start with current data
|
||||
}
|
||||
|
||||
# Apply settings
|
||||
if "blocked_services" in settings:
|
||||
blocked_services_data = {
|
||||
"ids": settings["blocked_services"],
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
update_data["data"]["blocked_services"] = blocked_services_data
|
||||
|
||||
if "filtering_enabled" in settings:
|
||||
update_data["data"]["filtering_enabled"] = settings["filtering_enabled"]
|
||||
|
||||
if "safebrowsing_enabled" in settings:
|
||||
update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"]
|
||||
|
||||
if "parental_enabled" in settings:
|
||||
update_data["data"]["parental_enabled"] = settings["parental_enabled"]
|
||||
|
||||
# Update the client
|
||||
await api.update_client(update_data)
|
||||
_LOGGER.info("Updated client: %s", client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to bulk update clients: %s", err)
|
||||
|
||||
async def add_client(self, call: ServiceCall) -> None:
|
||||
"""Add a new client."""
|
||||
client_data = dict(call.data)
|
||||
|
||||
# Convert blocked_services to proper format
|
||||
if "blocked_services" in client_data and client_data["blocked_services"]:
|
||||
blocked_services_data = {
|
||||
"ids": client_data["blocked_services"],
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
client_data["blocked_services"] = blocked_services_data
|
||||
|
||||
_LOGGER.info("Adding new client: %s", client_data["name"])
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.add_client(client_data)
|
||||
_LOGGER.info("Successfully added client: %s", client_data["name"])
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to add client %s: %s", client_data["name"], err)
|
||||
|
||||
async def remove_client(self, call: ServiceCall) -> None:
|
||||
"""Remove a client."""
|
||||
client_name = call.data["name"]
|
||||
|
||||
_LOGGER.info("Removing client: %s", client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.delete_client(client_name)
|
||||
_LOGGER.info("Successfully removed client: %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
|
||||
|
||||
async def schedule_service_block(self, call: ServiceCall) -> None:
|
||||
"""Schedule service blocking with time-based rules."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
schedule = call.data["schedule"]
|
||||
|
||||
_LOGGER.info("Scheduling service blocking for client %s", client_name)
|
||||
|
||||
for entry_data in self.hass.data[DOMAIN].values():
|
||||
api: AdGuardHomeAPI = entry_data["api"]
|
||||
try:
|
||||
await api.update_client_blocked_services(client_name, services, schedule)
|
||||
_LOGGER.info("Successfully scheduled service blocking for %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)
|
||||
|
@@ -1,27 +1,151 @@
|
||||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "AdGuard Control Hub",
|
||||
"description": "Connect to your AdGuard Home instance for complete network control",
|
||||
"data": {
|
||||
"host": "AdGuard Home IP Address",
|
||||
"port": "Port (usually 3000)",
|
||||
"username": "Admin Username",
|
||||
"password": "Admin Password",
|
||||
"ssl": "Use HTTPS connection",
|
||||
"verify_ssl": "Verify SSL certificate"
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "AdGuard Control Hub",
|
||||
"description": "Configure your AdGuard Home connection",
|
||||
"data": {
|
||||
"host": "Host",
|
||||
"port": "Port",
|
||||
"username": "Username",
|
||||
"password": "Password",
|
||||
"ssl": "Use SSL",
|
||||
"verify_ssl": "Verify SSL Certificate"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.",
|
||||
"invalid_auth": "Invalid username or password",
|
||||
"timeout": "Connection timeout. Please check your network connection.",
|
||||
"unknown": "An unexpected error occurred"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "AdGuard Control Hub is already configured for this host and port"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Failed to connect to AdGuard Home. Check IP address, port, and credentials.",
|
||||
"invalid_auth": "Invalid username or password. Please check your admin credentials.",
|
||||
"unknown": "Unexpected error occurred. Please check logs for details."
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "AdGuard Control Hub Options",
|
||||
"description": "Configure advanced options",
|
||||
"data": {
|
||||
"scan_interval": "Update interval (seconds)",
|
||||
"timeout": "Connection timeout (seconds)"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "This AdGuard Home instance is already configured",
|
||||
"cannot_connect": "Cannot connect to AdGuard Home"
|
||||
"services": {
|
||||
"block_services": {
|
||||
"name": "Block Services",
|
||||
"description": "Block specific services for a client",
|
||||
"fields": {
|
||||
"client_name": {
|
||||
"name": "Client Name",
|
||||
"description": "Name of the client to block services for"
|
||||
},
|
||||
"services": {
|
||||
"name": "Services",
|
||||
"description": "List of services to block"
|
||||
}
|
||||
}
|
||||
},
|
||||
"unblock_services": {
|
||||
"name": "Unblock Services",
|
||||
"description": "Unblock specific services for a client",
|
||||
"fields": {
|
||||
"client_name": {
|
||||
"name": "Client Name",
|
||||
"description": "Name of the client to unblock services for"
|
||||
},
|
||||
"services": {
|
||||
"name": "Services",
|
||||
"description": "List of services to unblock"
|
||||
}
|
||||
}
|
||||
},
|
||||
"emergency_unblock": {
|
||||
"name": "Emergency Unblock",
|
||||
"description": "Temporarily disable blocking for emergency access",
|
||||
"fields": {
|
||||
"duration": {
|
||||
"name": "Duration",
|
||||
"description": "Duration in seconds to keep unblocked"
|
||||
},
|
||||
"clients": {
|
||||
"name": "Clients",
|
||||
"description": "List of client names (use 'all' for global)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"bulk_update_clients": {
|
||||
"name": "Bulk Update Clients",
|
||||
"description": "Update multiple clients matching a pattern",
|
||||
"fields": {
|
||||
"client_pattern": {
|
||||
"name": "Client Pattern",
|
||||
"description": "Pattern to match client names (supports wildcards)"
|
||||
},
|
||||
"settings": {
|
||||
"name": "Settings",
|
||||
"description": "Settings to apply to matching clients"
|
||||
}
|
||||
}
|
||||
},
|
||||
"add_client": {
|
||||
"name": "Add Client",
|
||||
"description": "Add a new client configuration",
|
||||
"fields": {
|
||||
"name": {
|
||||
"name": "Name",
|
||||
"description": "Client name"
|
||||
},
|
||||
"ids": {
|
||||
"name": "IDs",
|
||||
"description": "List of IP addresses or CIDR ranges"
|
||||
},
|
||||
"mac": {
|
||||
"name": "MAC Address",
|
||||
"description": "MAC address (optional)"
|
||||
},
|
||||
"filtering_enabled": {
|
||||
"name": "Filtering Enabled",
|
||||
"description": "Enable DNS filtering for this client"
|
||||
},
|
||||
"blocked_services": {
|
||||
"name": "Blocked Services",
|
||||
"description": "List of services to block"
|
||||
}
|
||||
}
|
||||
},
|
||||
"remove_client": {
|
||||
"name": "Remove Client",
|
||||
"description": "Remove a client configuration",
|
||||
"fields": {
|
||||
"name": {
|
||||
"name": "Name",
|
||||
"description": "Name of the client to remove"
|
||||
}
|
||||
}
|
||||
},
|
||||
"schedule_service_block": {
|
||||
"name": "Schedule Service Block",
|
||||
"description": "Schedule time-based service blocking",
|
||||
"fields": {
|
||||
"client_name": {
|
||||
"name": "Client Name",
|
||||
"description": "Name of the client"
|
||||
},
|
||||
"services": {
|
||||
"name": "Services",
|
||||
"description": "List of services to block"
|
||||
},
|
||||
"schedule": {
|
||||
"name": "Schedule",
|
||||
"description": "Time-based schedule configuration"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,10 +1,13 @@
|
||||
"""Switch platform for AdGuard Control Hub integration."""
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.switch import SwitchEntity
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
|
||||
from . import AdGuardControlHubCoordinator
|
||||
from .api import AdGuardHomeAPI
|
||||
from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER
|
||||
@@ -12,7 +15,11 @@ from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MA
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback):
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up AdGuard Control Hub switch platform."""
|
||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||
@@ -32,6 +39,7 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity):
|
||||
"""Base class for AdGuard switches."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the switch."""
|
||||
super().__init__(coordinator)
|
||||
self.api = api
|
||||
self._attr_device_info = {
|
||||
@@ -46,31 +54,64 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
||||
"""Switch to control global AdGuard protection."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||
"""Initialize the switch."""
|
||||
super().__init__(coordinator, api)
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_protection"
|
||||
self._attr_name = "AdGuard Protection"
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if protection is enabled."""
|
||||
return self.coordinator.protection_status.get("protection_enabled", False)
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the icon for the switch."""
|
||||
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
await self.api.set_protection(True)
|
||||
await self.coordinator.async_request_refresh()
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return additional state attributes."""
|
||||
status = self.coordinator.protection_status
|
||||
stats = self.coordinator.statistics
|
||||
return {
|
||||
"dns_port": status.get("dns_port", "N/A"),
|
||||
"queries_today": stats.get("num_dns_queries_today", 0),
|
||||
"blocked_today": stats.get("num_blocked_filtering_today", 0),
|
||||
"version": status.get("version", "N/A"),
|
||||
}
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
await self.api.set_protection(False)
|
||||
await self.coordinator.async_request_refresh()
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on AdGuard protection."""
|
||||
try:
|
||||
await self.api.set_protection(True)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("AdGuard protection enabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
|
||||
raise
|
||||
|
||||
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||
"""Turn off AdGuard protection."""
|
||||
try:
|
||||
await self.api.set_protection(False)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("AdGuard protection disabled")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
|
||||
raise
|
||||
|
||||
|
||||
class AdGuardClientSwitch(AdGuardBaseSwitch):
|
||||
"""Switch to control client-specific protection."""
|
||||
|
||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI, client_name: str):
|
||||
def __init__(
|
||||
self,
|
||||
coordinator: AdGuardControlHubCoordinator,
|
||||
api: AdGuardHomeAPI,
|
||||
client_name: str,
|
||||
):
|
||||
"""Initialize the switch."""
|
||||
super().__init__(coordinator, api)
|
||||
self.client_name = client_name
|
||||
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}"
|
||||
@@ -78,16 +119,81 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
||||
self._attr_icon = ICON_CLIENT
|
||||
|
||||
@property
|
||||
def is_on(self) -> bool:
|
||||
def is_on(self) -> bool | None:
|
||||
"""Return true if client protection is enabled."""
|
||||
client = self.coordinator.clients.get(self.client_name, {})
|
||||
return client.get("filtering_enabled", True)
|
||||
|
||||
async def async_turn_on(self, **kwargs):
|
||||
# This would update client settings - simplified for basic functionality
|
||||
_LOGGER.info("Would enable protection for %s", self.client_name)
|
||||
await self.coordinator.async_request_refresh()
|
||||
@property
|
||||
def extra_state_attributes(self) -> dict[str, Any]:
|
||||
"""Return additional state attributes."""
|
||||
client = self.coordinator.clients.get(self.client_name, {})
|
||||
blocked_services = client.get("blocked_services", {})
|
||||
|
||||
async def async_turn_off(self, **kwargs):
|
||||
# This would update client settings - simplified for basic functionality
|
||||
_LOGGER.info("Would disable protection for %s", self.client_name)
|
||||
await self.coordinator.async_request_refresh()
|
||||
if isinstance(blocked_services, dict):
|
||||
service_ids = blocked_services.get("ids", [])
|
||||
else:
|
||||
service_ids = blocked_services if blocked_services else []
|
||||
|
||||
return {
|
||||
"client_ids": client.get("ids", []),
|
||||
"mac": client.get("mac", ""),
|
||||
"use_global_settings": client.get("use_global_settings", True),
|
||||
"safebrowsing_enabled": client.get("safebrowsing_enabled", False),
|
||||
"parental_enabled": client.get("parental_enabled", False),
|
||||
"safesearch_enabled": client.get("safesearch_enabled", False),
|
||||
"blocked_services": service_ids,
|
||||
"blocked_services_count": len(service_ids),
|
||||
}
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Enable protection for this client."""
|
||||
try:
|
||||
# Get current client data
|
||||
client = await self.api.get_client_by_name(self.client_name)
|
||||
if not client:
|
||||
_LOGGER.error("Client %s not found", self.client_name)
|
||||
return
|
||||
|
||||
# Update client with filtering enabled
|
||||
update_data = {
|
||||
"name": self.client_name,
|
||||
"data": {
|
||||
**client,
|
||||
"filtering_enabled": True,
|
||||
}
|
||||
}
|
||||
|
||||
await self.api.update_client(update_data)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("Enabled protection for client %s", self.client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
|
||||
raise
|
||||
|
||||
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||
"""Disable protection for this client."""
|
||||
try:
|
||||
# Get current client data
|
||||
client = await self.api.get_client_by_name(self.client_name)
|
||||
if not client:
|
||||
_LOGGER.error("Client %s not found", self.client_name)
|
||||
return
|
||||
|
||||
# Update client with filtering disabled
|
||||
update_data = {
|
||||
"name": self.client_name,
|
||||
"data": {
|
||||
**client,
|
||||
"filtering_enabled": False,
|
||||
}
|
||||
}
|
||||
|
||||
await self.api.update_client(update_data)
|
||||
await self.coordinator.async_request_refresh()
|
||||
_LOGGER.info("Disabled protection for client %s", self.client_name)
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
|
||||
raise
|
||||
|
@@ -3,6 +3,7 @@ import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock aiohttp session."""
|
||||
@@ -15,6 +16,7 @@ def mock_session():
|
||||
session.request = AsyncMock(return_value=response)
|
||||
return session
|
||||
|
||||
|
||||
async def test_api_connection(mock_session):
|
||||
"""Test API connection."""
|
||||
api = AdGuardHomeAPI(
|
||||
@@ -28,13 +30,14 @@ async def test_api_connection(mock_session):
|
||||
result = await api.test_connection()
|
||||
assert result is True
|
||||
|
||||
|
||||
async def test_api_get_status(mock_session):
|
||||
"""Test getting status."""
|
||||
api = AdGuardHomeAPI(
|
||||
host="test-host",
|
||||
host="test-host",
|
||||
port=3000,
|
||||
session=mock_session
|
||||
)
|
||||
|
||||
status = await api.get_status()
|
||||
assert status == {"status": "ok"}
|
||||
assert status == {"status": "ok"}
|
||||
|
223
tests/test_integration.py
Normal file
223
tests/test_integration.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Test the complete AdGuard Control Hub integration."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||
|
||||
from custom_components.adguard_hub import async_setup_entry, async_unload_entry
|
||||
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
||||
from custom_components.adguard_hub.const import DOMAIN
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry():
|
||||
"""Mock config entry."""
|
||||
return ConfigEntry(
|
||||
version=1,
|
||||
domain=DOMAIN,
|
||||
title="Test AdGuard",
|
||||
data={
|
||||
CONF_HOST: "192.168.1.100",
|
||||
CONF_PORT: 3000,
|
||||
CONF_USERNAME: "admin",
|
||||
CONF_PASSWORD: "password",
|
||||
},
|
||||
source="user",
|
||||
entry_id="test_entry_id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api():
|
||||
"""Mock API instance."""
|
||||
api = MagicMock(spec=AdGuardHomeAPI)
|
||||
api.host = "192.168.1.100"
|
||||
api.port = 3000
|
||||
api.test_connection = AsyncMock(return_value=True)
|
||||
api.get_status = AsyncMock(return_value={
|
||||
"protection_enabled": True,
|
||||
"version": "v0.107.0",
|
||||
"dns_port": 53,
|
||||
"running": True,
|
||||
})
|
||||
api.get_clients = AsyncMock(return_value={
|
||||
"clients": [
|
||||
{
|
||||
"name": "test_client",
|
||||
"ids": ["192.168.1.50"],
|
||||
"filtering_enabled": True,
|
||||
"blocked_services": {"ids": ["youtube"]},
|
||||
}
|
||||
]
|
||||
})
|
||||
api.get_statistics = AsyncMock(return_value={
|
||||
"num_dns_queries": 1000,
|
||||
"num_blocked_filtering": 100,
|
||||
"num_dns_queries_today": 500,
|
||||
"num_blocked_filtering_today": 50,
|
||||
"filtering_rules_count": 50000,
|
||||
"avg_processing_time": 2.5,
|
||||
})
|
||||
return api
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_entry_success(hass: HomeAssistant, mock_config_entry, mock_api):
|
||||
"""Test successful setup of config entry."""
|
||||
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
||||
patch("custom_components.adguard_hub.async_get_clientsession"), \
|
||||
patch.object(hass.config_entries, "async_forward_entry_setups", return_value=True):
|
||||
|
||||
result = await async_setup_entry(hass, mock_config_entry)
|
||||
|
||||
assert result is True
|
||||
assert DOMAIN in hass.data
|
||||
assert mock_config_entry.entry_id in hass.data[DOMAIN]
|
||||
assert "coordinator" in hass.data[DOMAIN][mock_config_entry.entry_id]
|
||||
assert "api" in hass.data[DOMAIN][mock_config_entry.entry_id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_entry_connection_failure(hass: HomeAssistant, mock_config_entry):
|
||||
"""Test setup failure due to connection error."""
|
||||
mock_api = MagicMock(spec=AdGuardHomeAPI)
|
||||
mock_api.test_connection = AsyncMock(return_value=False)
|
||||
|
||||
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
||||
patch("custom_components.adguard_hub.async_get_clientsession"), \
|
||||
pytest.raises(Exception): # Should raise ConfigEntryNotReady
|
||||
|
||||
await async_setup_entry(hass, mock_config_entry)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_entry(hass: HomeAssistant, mock_config_entry):
|
||||
"""Test unloading of config entry."""
|
||||
# Set up initial data
|
||||
hass.data[DOMAIN] = {
|
||||
mock_config_entry.entry_id: {
|
||||
"coordinator": MagicMock(),
|
||||
"api": MagicMock(),
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(hass.config_entries, "async_unload_platforms", return_value=True):
|
||||
result = await async_unload_entry(hass, mock_config_entry)
|
||||
|
||||
assert result is True
|
||||
assert mock_config_entry.entry_id not in hass.data[DOMAIN]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_coordinator_data_update(hass: HomeAssistant, mock_api):
|
||||
"""Test coordinator data update functionality."""
|
||||
from custom_components.adguard_hub import AdGuardControlHubCoordinator
|
||||
|
||||
coordinator = AdGuardControlHubCoordinator(hass, mock_api)
|
||||
|
||||
# Test successful data update
|
||||
data = await coordinator._async_update_data()
|
||||
|
||||
assert "clients" in data
|
||||
assert "statistics" in data
|
||||
assert "status" in data
|
||||
assert "test_client" in data["clients"]
|
||||
assert data["statistics"]["num_dns_queries"] == 1000
|
||||
assert data["status"]["protection_enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(mock_api):
|
||||
"""Test API error handling."""
|
||||
from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError
|
||||
|
||||
# Test connection error
|
||||
mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
|
||||
|
||||
with pytest.raises(AdGuardConnectionError):
|
||||
await mock_api.get_status()
|
||||
|
||||
# Test auth error
|
||||
mock_api.get_clients = AsyncMock(side_effect=AdGuardAuthError("Auth failed"))
|
||||
|
||||
with pytest.raises(AdGuardAuthError):
|
||||
await mock_api.get_clients()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_services_registration(hass: HomeAssistant):
|
||||
"""Test that services are properly registered."""
|
||||
from custom_components.adguard_hub.services import AdGuardControlHubServices
|
||||
|
||||
services = AdGuardControlHubServices(hass)
|
||||
services.register_services()
|
||||
|
||||
# Check that services are registered
|
||||
assert hass.services.has_service(DOMAIN, "block_services")
|
||||
assert hass.services.has_service(DOMAIN, "unblock_services")
|
||||
assert hass.services.has_service(DOMAIN, "emergency_unblock")
|
||||
assert hass.services.has_service(DOMAIN, "bulk_update_clients")
|
||||
assert hass.services.has_service(DOMAIN, "add_client")
|
||||
assert hass.services.has_service(DOMAIN, "remove_client")
|
||||
|
||||
# Clean up
|
||||
services.unregister_services()
|
||||
|
||||
|
||||
def test_blocked_services_constants():
|
||||
"""Test that blocked services are properly defined."""
|
||||
from custom_components.adguard_hub.const import BLOCKED_SERVICES
|
||||
|
||||
assert "youtube" in BLOCKED_SERVICES
|
||||
assert "netflix" in BLOCKED_SERVICES
|
||||
assert "gaming" in BLOCKED_SERVICES
|
||||
assert "facebook" in BLOCKED_SERVICES
|
||||
|
||||
# Check friendly names are defined
|
||||
assert BLOCKED_SERVICES["youtube"] == "YouTube"
|
||||
assert BLOCKED_SERVICES["netflix"] == "Netflix"
|
||||
|
||||
|
||||
def test_api_endpoints():
|
||||
"""Test that API endpoints are properly defined."""
|
||||
from custom_components.adguard_hub.const import API_ENDPOINTS
|
||||
|
||||
required_endpoints = [
|
||||
"status", "clients", "stats", "protection",
|
||||
"clients_add", "clients_update", "clients_delete"
|
||||
]
|
||||
|
||||
for endpoint in required_endpoints:
|
||||
assert endpoint in API_ENDPOINTS
|
||||
assert API_ENDPOINTS[endpoint].startswith("/")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_operations(mock_api):
|
||||
"""Test client add/update/delete operations."""
|
||||
# Test add client
|
||||
client_data = {
|
||||
"name": "new_client",
|
||||
"ids": ["192.168.1.200"],
|
||||
"filtering_enabled": True,
|
||||
}
|
||||
|
||||
mock_api.add_client = AsyncMock(return_value={"success": True})
|
||||
result = await mock_api.add_client(client_data)
|
||||
assert result["success"] is True
|
||||
|
||||
# Test update client
|
||||
update_data = {
|
||||
"name": "new_client",
|
||||
"data": {"filtering_enabled": False}
|
||||
}
|
||||
|
||||
mock_api.update_client = AsyncMock(return_value={"success": True})
|
||||
result = await mock_api.update_client(update_data)
|
||||
assert result["success"] is True
|
||||
|
||||
# Test delete client
|
||||
mock_api.delete_client = AsyncMock(return_value={"success": True})
|
||||
result = await mock_api.delete_client("new_client")
|
||||
assert result["success"] is True
|
Reference in New Issue
Block a user