fix: multiple fixes
Some checks failed
🧪 Integration Testing / 🔧 Test Integration (2025.9.4, 3.13) (push) Failing after 26s
🛡️ Code Quality & Security Check / 🔍 Code Quality Analysis (push) Failing after 15s

Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
This commit is contained in:
2025-09-28 15:10:39 +01:00
parent 75f705d4e9
commit 13905df0ee
12 changed files with 1721 additions and 189 deletions

View File

@@ -1,5 +1,5 @@
""" """
🛡️ AdGuard Control Hub for Home Assistant. AdGuard Control Hub for Home Assistant.
Transform your AdGuard Home into a smart network management powerhouse with Transform your AdGuard Home into a smart network management powerhouse with
complete client control, service blocking, and automation capabilities. complete client control, service blocking, and automation capabilities.
@@ -7,14 +7,18 @@ complete client control, service blocking, and automation capabilities.
import asyncio import asyncio
import logging import logging
from datetime import timedelta from datetime import timedelta
from typing import Dict, Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .api import AdGuardHomeAPI, AdGuardConnectionError
from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL
from .api import AdGuardHomeAPI from .services import AdGuardControlHubServices
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -23,6 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up AdGuard Control Hub from a config entry.""" """Set up AdGuard Control Hub from a config entry."""
session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True)) session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True))
# Create API instance
api = AdGuardHomeAPI( api = AdGuardHomeAPI(
host=entry.data[CONF_HOST], host=entry.data[CONF_HOST],
port=entry.data[CONF_PORT], port=entry.data[CONF_PORT],
@@ -34,16 +39,26 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Test the connection # Test the connection
try: try:
await api.test_connection() if not await api.test_connection():
_LOGGER.info("Successfully connected to AdGuard Home at %s:%s", raise ConfigEntryNotReady("Unable to connect to AdGuard Home")
entry.data[CONF_HOST], entry.data[CONF_PORT])
_LOGGER.info(
"Successfully connected to AdGuard Home at %s:%s",
entry.data[CONF_HOST],
entry.data[CONF_PORT]
)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to connect to AdGuard Home: %s", err) _LOGGER.error("Failed to connect to AdGuard Home: %s", err)
raise ConfigEntryNotReady(f"Unable to connect: {err}") raise ConfigEntryNotReady(f"Unable to connect: {err}") from err
# Create update coordinator # Create update coordinator
coordinator = AdGuardControlHubCoordinator(hass, api) coordinator = AdGuardControlHubCoordinator(hass, api)
await coordinator.async_config_entry_first_refresh()
try:
await coordinator.async_config_entry_first_refresh()
except Exception as err:
_LOGGER.error("Failed to perform initial data refresh: %s", err)
raise ConfigEntryNotReady(f"Failed to fetch initial data: {err}") from err
# Store data # Store data
hass.data.setdefault(DOMAIN, {}) hass.data.setdefault(DOMAIN, {})
@@ -53,9 +68,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
} }
# Set up platforms # Set up platforms
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) try:
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
except Exception as err:
_LOGGER.error("Failed to set up platforms: %s", err)
# Clean up on failure
hass.data[DOMAIN].pop(entry.entry_id)
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
_LOGGER.info("AdGuard Control Hub setup complete") # Register services (only once, not per config entry)
if not hass.services.has_service(DOMAIN, "block_services"):
services = AdGuardControlHubServices(hass)
services.register_services()
# Store services instance for cleanup
hass.data.setdefault(f"{DOMAIN}_services", services)
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
entry.data[CONF_HOST], entry.data[CONF_PORT])
return True return True
@@ -64,8 +94,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok: if unload_ok:
# Remove this entry's data
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
# Unregister services if this was the last entry
if not hass.data[DOMAIN]: # No more entries
services = hass.data.get(f"{DOMAIN}_services")
if services:
services.unregister_services()
hass.data.pop(f"{DOMAIN}_services", None)
# Also clean up the empty domain entry
hass.data.pop(DOMAIN, None)
return unload_ok return unload_ok
@@ -81,36 +122,54 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
update_interval=timedelta(seconds=SCAN_INTERVAL), update_interval=timedelta(seconds=SCAN_INTERVAL),
) )
self.api = api self.api = api
self._clients = {} self._clients: Dict[str, Any] = {}
self._statistics = {} self._statistics: Dict[str, Any] = {}
self._protection_status = {} self._protection_status: Dict[str, Any] = {}
async def _async_update_data(self): async def _async_update_data(self) -> Dict[str, Any]:
"""Fetch data from AdGuard Home.""" """Fetch data from AdGuard Home."""
try: try:
# Fetch all data concurrently for better performance # Fetch all data concurrently for better performance
results = await asyncio.gather( tasks = [
self.api.get_clients(), self.api.get_clients(),
self.api.get_statistics(), self.api.get_statistics(),
self.api.get_status(), self.api.get_status(),
return_exceptions=True, ]
)
results = await asyncio.gather(*tasks, return_exceptions=True)
clients, statistics, status = results clients, statistics, status = results
# Handle any exceptions # Handle any exceptions in individual requests
for i, result in enumerate(results): for i, result in enumerate(results):
if isinstance(result, Exception): if isinstance(result, Exception):
endpoint_names = ["clients", "statistics", "status"] endpoint_names = ["clients", "statistics", "status"]
_LOGGER.warning("Error fetching %s: %s", endpoint_names[i], result) _LOGGER.warning(
"Error fetching %s from %s:%s: %s",
endpoint_names[i],
self.api.host,
self.api.port,
result
)
# Update stored data (use empty dict if fetch failed) # Update stored data (use empty dict if fetch failed)
self._clients = { if not isinstance(clients, Exception):
client["name"]: client self._clients = {
for client in (clients.get("clients", []) if not isinstance(clients, Exception) else []) client["name"]: client
} for client in clients.get("clients", [])
self._statistics = statistics if not isinstance(statistics, Exception) else {} if client.get("name") # Ensure client has a name
self._protection_status = status if not isinstance(status, Exception) else {} }
else:
_LOGGER.warning("Failed to update clients data, keeping previous data")
if not isinstance(statistics, Exception):
self._statistics = statistics
else:
_LOGGER.warning("Failed to update statistics data, keeping previous data")
if not isinstance(status, Exception):
self._protection_status = status
else:
_LOGGER.warning("Failed to update status data, keeping previous data")
return { return {
"clients": self._clients, "clients": self._clients,
@@ -118,20 +177,40 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
"status": self._protection_status, "status": self._protection_status,
} }
except AdGuardConnectionError as err:
raise UpdateFailed(f"Connection error to AdGuard Home: {err}") from err
except Exception as err: except Exception as err:
raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") from err
@property @property
def clients(self): def clients(self) -> Dict[str, Any]:
"""Return clients data.""" """Return clients data."""
return self._clients return self._clients
@property @property
def statistics(self): def statistics(self) -> Dict[str, Any]:
"""Return statistics data.""" """Return statistics data."""
return self._statistics return self._statistics
@property @property
def protection_status(self): def protection_status(self) -> Dict[str, Any]:
"""Return protection status data.""" """Return protection status data."""
return self._protection_status return self._protection_status
def get_client(self, client_name: str) -> Dict[str, Any] | None:
"""Get a specific client by name."""
return self._clients.get(client_name)
def has_client(self, client_name: str) -> bool:
"""Check if a client exists."""
return client_name in self._clients
@property
def client_count(self) -> int:
"""Return the number of clients."""
return len(self._clients)
@property
def is_protection_enabled(self) -> bool:
"""Return True if protection is enabled."""
return self._protection_status.get("protection_enabled", False)

View File

@@ -1,102 +1,207 @@
"""API wrapper for AdGuard Home.""" """API wrapper for AdGuard Home."""
import asyncio
import logging import logging
from typing import Any from typing import Any, Dict, Optional
import aiohttp import aiohttp
from aiohttp import BasicAuth from aiohttp import BasicAuth, ClientError, ClientTimeout
from .const import API_ENDPOINTS from .const import API_ENDPOINTS
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Custom exceptions
class AdGuardHomeError(Exception):
"""Base exception for AdGuard Home API."""
pass
class AdGuardConnectionError(AdGuardHomeError):
"""Exception for connection errors."""
pass
class AdGuardAuthError(AdGuardHomeError):
"""Exception for authentication errors."""
pass
class AdGuardNotFoundError(AdGuardHomeError):
"""Exception for not found errors."""
pass
class AdGuardHomeAPI: class AdGuardHomeAPI:
"""API wrapper for AdGuard Home.""" """API wrapper for AdGuard Home."""
def __init__(self, host: str, port: int = 3000, username: str = None, def __init__(
password: str = None, ssl: bool = False, session=None): self,
host: str,
port: int = 3000,
username: Optional[str] = None,
password: Optional[str] = None,
ssl: bool = False,
session: Optional[aiohttp.ClientSession] = None,
timeout: int = 10,
):
"""Initialize the API wrapper."""
self.host = host self.host = host
self.port = port self.port = port
self.username = username self.username = username
self.password = password self.password = password
self.ssl = ssl self.ssl = ssl
self.session = session self._session = session
self._timeout = ClientTimeout(total=timeout)
protocol = "https" if ssl else "http" protocol = "https" if ssl else "http"
self.base_url = f"{protocol}://{host}:{port}" self.base_url = f"{protocol}://{host}:{port}"
self._own_session = session is None
async def _request(self, method: str, endpoint: str, data: dict = None) -> dict: async def __aenter__(self):
"""Make an API request.""" """Async context manager entry."""
if self._own_session:
self._session = aiohttp.ClientSession(timeout=self._timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
if self._own_session and self._session:
await self._session.close()
@property
def session(self) -> aiohttp.ClientSession:
"""Get the session, creating one if needed."""
if not self._session:
self._session = aiohttp.ClientSession(timeout=self._timeout)
return self._session
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
"""Make an API request with comprehensive error handling."""
url = f"{self.base_url}{endpoint}" url = f"{self.base_url}{endpoint}"
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
auth = None auth = None
if self.username and self.password: if self.username and self.password:
auth = BasicAuth(self.username, self.password) auth = BasicAuth(self.username, self.password)
try: try:
async with self.session.request(method, url, json=data, headers=headers, auth=auth) as response: async with self.session.request(
method, url, json=data, headers=headers, auth=auth
) as response:
# Handle different HTTP status codes
if response.status == 401:
raise AdGuardAuthError("Authentication failed - check username/password")
elif response.status == 403:
raise AdGuardAuthError("Access forbidden - insufficient permissions")
elif response.status == 404:
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
elif response.status >= 500:
raise AdGuardConnectionError(f"Server error {response.status}")
response.raise_for_status() response.raise_for_status()
# Handle empty responses
if response.status == 204 or not response.content_length: if response.status == 204 or not response.content_length:
return {} return {}
return await response.json()
try:
return await response.json()
except aiohttp.ContentTypeError:
# Handle non-JSON responses
text = await response.text()
_LOGGER.warning("Non-JSON response received: %s", text)
return {"response": text}
except asyncio.TimeoutError as err:
raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}")
except ClientError as err:
raise AdGuardConnectionError(f"Client error: {err}")
except Exception as err: except Exception as err:
_LOGGER.error("Error communicating with AdGuard Home: %s", err) _LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err)
raise raise AdGuardHomeError(f"Unexpected error: {err}")
async def test_connection(self) -> bool: async def test_connection(self) -> bool:
"""Test the connection.""" """Test the connection to AdGuard Home."""
try: try:
await self._request("GET", API_ENDPOINTS["status"]) await self._request("GET", API_ENDPOINTS["status"])
return True return True
except: except Exception as err:
_LOGGER.debug("Connection test failed: %s", err)
return False return False
async def get_status(self) -> dict: async def get_status(self) -> Dict[str, Any]:
"""Get server status.""" """Get server status information."""
return await self._request("GET", API_ENDPOINTS["status"]) return await self._request("GET", API_ENDPOINTS["status"])
async def get_clients(self) -> dict: async def get_clients(self) -> Dict[str, Any]:
"""Get all clients.""" """Get all configured clients."""
return await self._request("GET", API_ENDPOINTS["clients"]) return await self._request("GET", API_ENDPOINTS["clients"])
async def get_statistics(self) -> dict: async def get_statistics(self) -> Dict[str, Any]:
"""Get statistics.""" """Get DNS query statistics."""
return await self._request("GET", API_ENDPOINTS["stats"]) return await self._request("GET", API_ENDPOINTS["stats"])
async def set_protection(self, enabled: bool) -> dict: async def set_protection(self, enabled: bool) -> Dict[str, Any]:
"""Enable or disable protection.""" """Enable or disable AdGuard protection."""
data = {"enabled": enabled} data = {"enabled": enabled}
return await self._request("POST", API_ENDPOINTS["protection"], data) return await self._request("POST", API_ENDPOINTS["protection"], data)
async def add_client(self, client_data: dict) -> dict: async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
"""Add a new client.""" """Add a new client configuration."""
# Validate required fields
if "name" not in client_data:
raise ValueError("Client name is required")
if "ids" not in client_data or not client_data["ids"]:
raise ValueError("Client IDs are required")
return await self._request("POST", API_ENDPOINTS["clients_add"], client_data) return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)
async def update_client(self, client_data: dict) -> dict: async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
"""Update an existing client.""" """Update an existing client configuration."""
if "name" not in client_data:
raise ValueError("Client name is required for update")
if "data" not in client_data:
raise ValueError("Client data is required for update")
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data) return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
async def delete_client(self, client_name: str) -> dict: async def delete_client(self, client_name: str) -> Dict[str, Any]:
"""Delete a client.""" """Delete a client configuration."""
if not client_name:
raise ValueError("Client name is required")
data = {"name": client_name} data = {"name": client_name}
return await self._request("POST", API_ENDPOINTS["clients_delete"], data) return await self._request("POST", API_ENDPOINTS["clients_delete"], data)
async def get_client_by_name(self, client_name: str) -> dict: async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
"""Get a specific client by name.""" """Get a specific client by name."""
clients_data = await self.get_clients() if not client_name:
clients = clients_data.get("clients", []) return None
for client in clients: try:
if client.get("name") == client_name: clients_data = await self.get_clients()
return client clients = clients_data.get("clients", [])
return None for client in clients:
if client.get("name") == client_name:
return client
async def update_client_blocked_services(self, client_name: str, blocked_services: list, return None
schedule: dict = None) -> dict: except Exception as err:
_LOGGER.error("Failed to get client %s: %s", client_name, err)
return None
async def update_client_blocked_services(
self,
client_name: str,
blocked_services: list,
schedule: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Update blocked services for a specific client.""" """Update blocked services for a specific client."""
if not client_name:
raise ValueError("Client name is required")
client = await self.get_client_by_name(client_name) client = await self.get_client_by_name(client_name)
if not client: if not client:
raise ValueError(f"Client '{client_name}' not found") raise AdGuardNotFoundError(f"Client '{client_name}' not found")
# Prepare the blocked services data # Prepare the blocked services data with proper structure
if schedule: if schedule:
blocked_services_data = { blocked_services_data = {
"ids": blocked_services, "ids": blocked_services,
@@ -110,7 +215,7 @@ class AdGuardHomeAPI:
} }
} }
# Update the client # Update the client with new blocked services
update_data = { update_data = {
"name": client_name, "name": client_name,
"data": { "data": {
@@ -121,18 +226,23 @@ class AdGuardHomeAPI:
return await self.update_client(update_data) return await self.update_client(update_data)
async def toggle_client_service(self, client_name: str, service_id: str, enabled: bool) -> dict: async def toggle_client_service(
self, client_name: str, service_id: str, enabled: bool
) -> Dict[str, Any]:
"""Toggle a specific service for a client.""" """Toggle a specific service for a client."""
if not client_name or not service_id:
raise ValueError("Client name and service ID are required")
client = await self.get_client_by_name(client_name) client = await self.get_client_by_name(client_name)
if not client: if not client:
raise ValueError(f"Client '{client_name}' not found") raise AdGuardNotFoundError(f"Client '{client_name}' not found")
# Get current blocked services # Get current blocked services
blocked_services = client.get("blocked_services", {}) blocked_services = client.get("blocked_services", {})
if isinstance(blocked_services, dict): if isinstance(blocked_services, dict):
service_ids = blocked_services.get("ids", []) service_ids = blocked_services.get("ids", [])
else: else:
# Handle old format (list) # Handle legacy format (direct list)
service_ids = blocked_services if blocked_services else [] service_ids = blocked_services if blocked_services else []
# Update the service list # Update the service list
@@ -142,3 +252,12 @@ class AdGuardHomeAPI:
service_ids.remove(service_id) service_ids.remove(service_id)
return await self.update_client_blocked_services(client_name, service_ids) return await self.update_client_blocked_services(client_name, service_ids)
async def get_blocked_services(self) -> Dict[str, Any]:
"""Get available blocked services."""
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
async def close(self) -> None:
"""Close the API session if we own it."""
if self._own_session and self._session:
await self._session.close()

View File

@@ -0,0 +1,166 @@
"""Binary sensor platform for AdGuard Control Hub integration."""
import logging
from typing import Any
from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AdGuardControlHubCoordinator
from .api import AdGuardHomeAPI
from .const import DOMAIN, MANUFACTURER, ICON_PROTECTION, ICON_PROTECTION_OFF
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up AdGuard Control Hub binary sensor platform."""
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
entities = [
AdGuardProtectionBinarySensor(coordinator, api),
AdGuardFilteringBinarySensor(coordinator, api),
AdGuardSafeBrowsingBinarySensor(coordinator, api),
AdGuardParentalControlBinarySensor(coordinator, api),
AdGuardSafeSearchBinarySensor(coordinator, api),
]
async_add_entities(entities)
class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity):
"""Base class for AdGuard binary sensors."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator)
self.api = api
self._attr_device_info = {
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
"name": f"AdGuard Control Hub ({api.host})",
"manufacturer": MANUFACTURER,
"model": "AdGuard Home",
}
class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show AdGuard protection status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled"
self._attr_name = "AdGuard Protection Status"
self._attr_device_class = BinarySensorDeviceClass.RUNNING
@property
def is_on(self) -> bool | None:
"""Return true if protection is enabled."""
return self.coordinator.protection_status.get("protection_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
status = self.coordinator.protection_status
return {
"dns_port": status.get("dns_port", "N/A"),
"http_port": status.get("http_port", "N/A"),
"version": status.get("version", "N/A"),
"running": status.get("running", False),
}
class AdGuardFilteringBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show DNS filtering status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_filtering_enabled"
self._attr_name = "AdGuard DNS Filtering"
self._attr_device_class = BinarySensorDeviceClass.RUNNING
@property
def is_on(self) -> bool | None:
"""Return true if DNS filtering is enabled."""
return self.coordinator.protection_status.get("filtering_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:dns" if self.is_on else "mdi:dns-off"
class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Safe Browsing status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled"
self._attr_name = "AdGuard Safe Browsing"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Safe Browsing is enabled."""
return self.coordinator.protection_status.get("safebrowsing_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:security" if self.is_on else "mdi:security-off"
class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Parental Control status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled"
self._attr_name = "AdGuard Parental Control"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Parental Control is enabled."""
return self.coordinator.protection_status.get("parental_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:account-child" if self.is_on else "mdi:account-child-outline"
class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Safe Search status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled"
self._attr_name = "AdGuard Safe Search"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Safe Search is enabled."""
return self.coordinator.protection_status.get("safesearch_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:search-web" if self.is_on else "mdi:web-off"

View File

@@ -1,73 +1,128 @@
"""Config flow for AdGuard Control Hub integration.""" """Config flow for AdGuard Control Hub integration."""
import asyncio
import logging import logging
from typing import Any from typing import Any, Dict, Optional
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .api import AdGuardHomeAPI from homeassistant.data_entry_flow import FlowResult
from .const import CONF_SSL, CONF_VERIFY_SSL, DEFAULT_PORT, DEFAULT_SSL, DEFAULT_VERIFY_SSL, DOMAIN import homeassistant.helpers.config_validation as cv
from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError
from .const import (
CONF_SSL,
CONF_VERIFY_SSL,
DEFAULT_PORT,
DEFAULT_SSL,
DEFAULT_VERIFY_SSL,
DOMAIN,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
STEP_USER_DATA_SCHEMA = vol.Schema({ STEP_USER_DATA_SCHEMA = vol.Schema({
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int, vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_USERNAME): str, vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PASSWORD): str, vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool, vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): bool, vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
}) })
async def validate_input(hass, data: dict) -> dict:
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the user input allows us to connect.""" """Validate the user input allows us to connect."""
# Normalize host
host = data[CONF_HOST].strip()
if not host:
raise InvalidHost("Host cannot be empty")
# Remove protocol if provided
if host.startswith(("http://", "https://")):
host = host.split("://", 1)[1]
data[CONF_HOST] = host
# Validate port
port = data[CONF_PORT]
if not (1 <= port <= 65535):
raise InvalidPort("Port must be between 1 and 65535")
# Create session with appropriate SSL settings
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True)) session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
# Create API instance
api = AdGuardHomeAPI( api = AdGuardHomeAPI(
host=data[CONF_HOST], host=host,
port=data[CONF_PORT], port=port,
username=data.get(CONF_USERNAME), username=data.get(CONF_USERNAME),
password=data.get(CONF_PASSWORD), password=data.get(CONF_PASSWORD),
ssl=data.get(CONF_SSL, False), ssl=data.get(CONF_SSL, False),
session=session, session=session,
timeout=10, # 10 second timeout for setup
) )
# Test the connection # Test the connection
if not await api.test_connection():
raise CannotConnect
# Get server info
try: try:
status = await api.get_status() if not await api.test_connection():
version = status.get("version", "unknown") raise CannotConnect("Failed to connect to AdGuard Home")
return {
"title": f"AdGuard Control Hub ({data[CONF_HOST]})", # Get additional server info if possible
"version": version try:
} status = await api.get_status()
except Exception as err: version = status.get("version", "unknown")
_LOGGER.exception("Unexpected exception: %s", err) dns_port = status.get("dns_port", "N/A")
return {
"title": f"AdGuard Control Hub ({host})",
"version": version,
"dns_port": dns_port,
"host": host,
}
except Exception as err:
_LOGGER.warning("Could not get server status, but connection works: %s", err)
return {
"title": f"AdGuard Control Hub ({host})",
"version": "unknown",
"dns_port": "N/A",
"host": host,
}
except AdGuardAuthError as err:
_LOGGER.error("Authentication failed: %s", err)
raise InvalidAuth from err
except AdGuardConnectionError as err:
_LOGGER.error("Connection failed: %s", err)
if "timeout" in str(err).lower():
raise Timeout from err
raise CannotConnect from err raise CannotConnect from err
except asyncio.TimeoutError as err:
_LOGGER.error("Connection timeout: %s", err)
raise Timeout from err
except Exception as err:
_LOGGER.exception("Unexpected error during validation: %s", err)
raise CannotConnect from err
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for AdGuard Control Hub.""" """Handle a config flow for AdGuard Control Hub."""
VERSION = 1 VERSION = 1
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL
async def async_step_user(self, user_input: dict[str, Any] | None = None): async def async_step_user(
self, user_input: Optional[Dict[str, Any]] = None
) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
errors: dict[str, str] = {} errors: Dict[str, str] = {}
if user_input is not None: if user_input is not None:
try: try:
info = await validate_input(self.hass, user_input) info = await validate_input(self.hass, user_input)
except CannotConnect:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
# Create unique ID based on host and port # Create unique ID based on host and port
unique_id = f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}" unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
await self.async_set_unique_id(unique_id) await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
@@ -76,11 +131,83 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
data=user_input, data=user_input,
) )
except CannotConnect:
errors["base"] = "cannot_connect"
except InvalidAuth:
errors["base"] = "invalid_auth"
except InvalidHost:
errors[CONF_HOST] = "invalid_host"
except InvalidPort:
errors[CONF_PORT] = "invalid_port"
except Timeout:
errors["base"] = "timeout"
except Exception:
_LOGGER.exception("Unexpected exception during config flow")
errors["base"] = "unknown"
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
data_schema=STEP_USER_DATA_SCHEMA, data_schema=STEP_USER_DATA_SCHEMA,
errors=errors, errors=errors,
) )
async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult:
"""Handle configuration import."""
return await self.async_step_user(import_info)
@staticmethod
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)
class OptionsFlowHandler(config_entries.OptionsFlow):
"""Handle options flow for AdGuard Control Hub."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
async def async_step_init(
self, user_input: Optional[Dict[str, Any]] = None
) -> FlowResult:
"""Handle options flow."""
if user_input is not None:
return self.async_create_entry(title="", data=user_input)
options_schema = vol.Schema({
vol.Optional(
"scan_interval",
default=self.config_entry.options.get("scan_interval", 30),
): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)),
vol.Optional(
"timeout",
default=self.config_entry.options.get("timeout", 10),
): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)),
})
return self.async_show_form(
step_id="init",
data_schema=options_schema,
)
# Custom exceptions
class CannotConnect(Exception): class CannotConnect(Exception):
"""Error to indicate we cannot connect.""" """Error to indicate we cannot connect."""
class InvalidAuth(Exception):
"""Error to indicate there is invalid auth."""
class InvalidHost(Exception):
"""Error to indicate invalid host."""
class InvalidPort(Exception):
"""Error to indicate invalid port."""
class Timeout(Exception):
"""Error to indicate connection timeout."""

View File

@@ -1,14 +1,14 @@
{ {
"domain": "adguard_hub", "domain": "adguard_hub",
"name": "AdGuard Control Hub", "name": "AdGuard Control Hub",
"codeowners": ["@sq4ind"], "codeowners": ["@sq4ind"],
"config_flow": true, "config_flow": true,
"dependencies": [], "dependencies": [],
"documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub", "documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub",
"integration_type": "hub", "integration_type": "hub",
"iot_class": "local_polling", "iot_class": "local_polling",
"requirements": [ "requirements": [
"aiohttp>=3.8.0" "aiohttp>=3.8.0"
], ],
"version": "1.0.0" "version": "1.0.0"
} }

View File

@@ -0,0 +1,185 @@
"""Sensor platform for AdGuard Control Hub integration."""
import logging
from datetime import datetime, timezone
from typing import Any
from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import PERCENTAGE
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AdGuardControlHubCoordinator
from .api import AdGuardHomeAPI
from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up AdGuard Control Hub sensor platform."""
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
entities = [
AdGuardQueriesCounterSensor(coordinator, api),
AdGuardBlockedCounterSensor(coordinator, api),
AdGuardBlockingPercentageSensor(coordinator, api),
AdGuardRuleCountSensor(coordinator, api),
AdGuardClientCountSensor(coordinator, api),
AdGuardUpstreamAverageTimeSensor(coordinator, api),
]
async_add_entities(entities)
class AdGuardBaseSensor(CoordinatorEntity, SensorEntity):
"""Base class for AdGuard sensors."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator)
self.api = api
self._attr_device_info = {
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
"name": f"AdGuard Control Hub ({api.host})",
"manufacturer": MANUFACTURER,
"model": "AdGuard Home",
}
class AdGuardQueriesCounterSensor(AdGuardBaseSensor):
"""Sensor to track DNS queries count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_dns_queries"
self._attr_name = "AdGuard DNS Queries"
self._attr_icon = ICON_STATISTICS
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
self._attr_native_unit_of_measurement = "queries"
@property
def native_value(self) -> int | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("num_dns_queries", 0)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
stats = self.coordinator.statistics
return {
"queries_today": stats.get("num_dns_queries_today", 0),
"queries_blocked_today": stats.get("num_blocked_filtering_today", 0),
"last_updated": datetime.now(timezone.utc).isoformat(),
}
class AdGuardBlockedCounterSensor(AdGuardBaseSensor):
"""Sensor to track blocked queries count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries"
self._attr_name = "AdGuard Blocked Queries"
self._attr_icon = "mdi:shield-check"
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
self._attr_native_unit_of_measurement = "queries"
@property
def native_value(self) -> int | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("num_blocked_filtering", 0)
class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
"""Sensor to track blocking percentage."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage"
self._attr_name = "AdGuard Blocking Percentage"
self._attr_icon = "mdi:percent"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = PERCENTAGE
self._attr_device_class = SensorDeviceClass.POWER_FACTOR
@property
def native_value(self) -> float | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
total_queries = stats.get("num_dns_queries", 0)
blocked_queries = stats.get("num_blocked_filtering", 0)
if total_queries == 0:
return 0
percentage = (blocked_queries / total_queries) * 100
return round(percentage, 2)
class AdGuardRuleCountSensor(AdGuardBaseSensor):
"""Sensor to track filtering rules count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_rules_count"
self._attr_name = "AdGuard Rules Count"
self._attr_icon = "mdi:format-list-numbered"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "rules"
@property
def native_value(self) -> int | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("filtering_rules_count", 0)
class AdGuardClientCountSensor(AdGuardBaseSensor):
"""Sensor to track active clients count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_clients_count"
self._attr_name = "AdGuard Clients Count"
self._attr_icon = "mdi:account-multiple"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "clients"
@property
def native_value(self) -> int | None:
"""Return the state of the sensor."""
return len(self.coordinator.clients)
class AdGuardUpstreamAverageTimeSensor(AdGuardBaseSensor):
"""Sensor to track upstream servers average response time."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_upstream_response_time"
self._attr_name = "AdGuard Upstream Response Time"
self._attr_icon = "mdi:timer"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "ms"
self._attr_device_class = SensorDeviceClass.DURATION
@property
def native_value(self) -> float | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("avg_processing_time", 0)

View File

@@ -1,38 +1,438 @@
"""Services for AdGuard Control Hub integration.""" """Service implementations for AdGuard Control Hub integration."""
import asyncio
import logging import logging
from homeassistant.core import HomeAssistant from datetime import datetime, timedelta
from typing import Any, Dict, List
import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv
from .api import AdGuardHomeAPI from .api import AdGuardHomeAPI
from .const import (
DOMAIN,
BLOCKED_SERVICES,
ATTR_CLIENT_NAME,
ATTR_SERVICES,
ATTR_DURATION,
ATTR_CLIENTS,
ATTR_CLIENT_PATTERN,
ATTR_SETTINGS,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_register_services(hass: HomeAssistant, api: AdGuardHomeAPI) -> None: # Service schemas
"""Register integration services.""" SCHEMA_BLOCK_SERVICES = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
async def emergency_unblock_service(call): SCHEMA_UNBLOCK_SERVICES = vol.Schema({
"""Emergency unblock service.""" vol.Required(ATTR_CLIENT_NAME): cv.string,
duration = call.data.get("duration", 300) vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
clients = call.data.get("clients", ["all"]) })
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
vol.Required(ATTR_DURATION): cv.positive_int,
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
})
SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
vol.Required(ATTR_CLIENT_PATTERN): cv.string,
vol.Required(ATTR_SETTINGS): vol.Schema({
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
vol.Optional("filtering_enabled"): cv.boolean,
vol.Optional("safebrowsing_enabled"): cv.boolean,
vol.Optional("parental_enabled"): cv.boolean,
}),
})
SCHEMA_ADD_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("mac"): cv.string,
vol.Optional("use_global_settings"): cv.boolean,
vol.Optional("filtering_enabled"): cv.boolean,
vol.Optional("parental_enabled"): cv.boolean,
vol.Optional("safebrowsing_enabled"): cv.boolean,
vol.Optional("safesearch_enabled"): cv.boolean,
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
SCHEMA_REMOVE_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
})
SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
vol.Required("schedule"): vol.Schema({
vol.Optional("time_zone", default="Local"): cv.string,
vol.Optional("sun"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("mon"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("tue"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("wed"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("thu"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("fri"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("sat"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
}),
})
# Service names
SERVICE_BLOCK_SERVICES = "block_services"
SERVICE_UNBLOCK_SERVICES = "unblock_services"
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
SERVICE_ADD_CLIENT = "add_client"
SERVICE_REMOVE_CLIENT = "remove_client"
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
class AdGuardControlHubServices:
"""Handle services for AdGuard Control Hub."""
def __init__(self, hass: HomeAssistant):
"""Initialize the services."""
self.hass = hass
self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
def register_services(self) -> None:
"""Register all services."""
self.hass.services.register(
DOMAIN,
SERVICE_BLOCK_SERVICES,
self.block_services,
schema=SCHEMA_BLOCK_SERVICES,
)
self.hass.services.register(
DOMAIN,
SERVICE_UNBLOCK_SERVICES,
self.unblock_services,
schema=SCHEMA_UNBLOCK_SERVICES,
)
self.hass.services.register(
DOMAIN,
SERVICE_EMERGENCY_UNBLOCK,
self.emergency_unblock,
schema=SCHEMA_EMERGENCY_UNBLOCK,
)
self.hass.services.register(
DOMAIN,
SERVICE_BULK_UPDATE_CLIENTS,
self.bulk_update_clients,
schema=SCHEMA_BULK_UPDATE_CLIENTS,
)
self.hass.services.register(
DOMAIN,
SERVICE_ADD_CLIENT,
self.add_client,
schema=SCHEMA_ADD_CLIENT,
)
self.hass.services.register(
DOMAIN,
SERVICE_REMOVE_CLIENT,
self.remove_client,
schema=SCHEMA_REMOVE_CLIENT,
)
self.hass.services.register(
DOMAIN,
SERVICE_SCHEDULE_SERVICE_BLOCK,
self.schedule_service_block,
schema=SCHEMA_SCHEDULE_SERVICE_BLOCK,
)
def unregister_services(self) -> None:
"""Unregister all services."""
services = [
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_BULK_UPDATE_CLIENTS,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_SCHEDULE_SERVICE_BLOCK,
]
for service in services:
self.hass.services.remove(DOMAIN, service)
def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
"""Get API instance for a specific config entry."""
return self.hass.data[DOMAIN][entry_id]["api"]
async def block_services(self, call: ServiceCall) -> None:
"""Block services for a specific client."""
client_name = call.data[ATTR_CLIENT_NAME]
services = call.data[ATTR_SERVICES]
_LOGGER.info("Blocking services %s for client %s", services, client_name)
# Get all API instances (for multiple AdGuard instances)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
# Get current client data
client = await api.get_client_by_name(client_name)
if not client:
_LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
continue
# Get current blocked services
current_blocked = client.get("blocked_services", {})
if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked if current_blocked else []
# Add new services to block
updated_services = list(set(current_services + services))
# Update client
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully blocked services for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to block services for %s: %s", client_name, err)
async def unblock_services(self, call: ServiceCall) -> None:
"""Unblock services for a specific client."""
client_name = call.data[ATTR_CLIENT_NAME]
services = call.data[ATTR_SERVICES]
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
# Get all API instances
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
# Get current client data
client = await api.get_client_by_name(client_name)
if not client:
continue
# Get current blocked services
current_blocked = client.get("blocked_services", {})
if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked if current_blocked else []
# Remove services to unblock
updated_services = [s for s in current_services if s not in services]
# Update client
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully unblocked services for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err)
async def emergency_unblock(self, call: ServiceCall) -> None:
"""Emergency unblock - temporarily disable protection."""
duration = call.data[ATTR_DURATION] # seconds
clients = call.data[ATTR_CLIENTS]
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
# Cancel any existing emergency unblock tasks
for task in self._emergency_unblock_tasks.values():
task.cancel()
self._emergency_unblock_tasks.clear()
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
if "all" in clients:
# Disable global protection
await api.set_protection(False)
# Schedule re-enable
task = asyncio.create_task(
self._delayed_enable_protection(api, duration)
)
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
else:
# Disable protection for specific clients
for client_name in clients:
client = await api.get_client_by_name(client_name)
if client:
# Store original blocked services
original_blocked = client.get("blocked_services", {})
# Clear blocked services temporarily
await api.update_client_blocked_services(client_name, [])
# Schedule restore
task = asyncio.create_task(
self._delayed_restore_client(api, client_name, original_blocked, duration)
)
self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task
except Exception as err:
_LOGGER.error("Failed to execute emergency unblock: %s", err)
async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
"""Re-enable protection after delay."""
await asyncio.sleep(delay)
try: try:
if "all" in clients: await api.set_protection(True)
await api.set_protection(False) _LOGGER.info("Emergency unblock expired - protection re-enabled")
_LOGGER.info("Emergency unblock activated globally for %d seconds", duration)
else:
_LOGGER.info("Emergency unblock activated for clients: %s", clients)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to execute emergency unblock: %s", err) _LOGGER.error("Failed to re-enable protection: %s", err)
raise
# Register emergency unblock service async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str,
hass.services.async_register( original_blocked: Dict, delay: int) -> None:
"adguard_hub", """Restore client blocked services after delay."""
"emergency_unblock", await asyncio.sleep(delay)
emergency_unblock_service try:
) if isinstance(original_blocked, dict):
services = original_blocked.get("ids", [])
else:
services = original_blocked if original_blocked else []
_LOGGER.info("AdGuard Control Hub services registered") await api.update_client_blocked_services(client_name, services)
_LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to restore client blocking: %s", err)
async def async_unregister_services(hass: HomeAssistant) -> None: async def bulk_update_clients(self, call: ServiceCall) -> None:
"""Unregister integration services.""" """Update multiple clients matching a pattern."""
hass.services.async_remove("adguard_hub", "emergency_unblock") import re
_LOGGER.info("AdGuard Control Hub services unregistered")
pattern = call.data[ATTR_CLIENT_PATTERN]
settings = call.data[ATTR_SETTINGS]
_LOGGER.info("Bulk updating clients matching pattern: %s", pattern)
# Convert pattern to regex
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
compiled_pattern = re.compile(regex_pattern, re.IGNORECASE)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
coordinator = entry_data["coordinator"]
try:
# Get all clients
clients = coordinator.clients
matching_clients = []
for client_name in clients.keys():
if compiled_pattern.match(client_name):
matching_clients.append(client_name)
_LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)
# Update each matching client
for client_name in matching_clients:
client = clients[client_name]
# Prepare update data
update_data = {
"name": client_name,
"data": {**client} # Start with current data
}
# Apply settings
if "blocked_services" in settings:
blocked_services_data = {
"ids": settings["blocked_services"],
"schedule": {"time_zone": "Local"}
}
update_data["data"]["blocked_services"] = blocked_services_data
if "filtering_enabled" in settings:
update_data["data"]["filtering_enabled"] = settings["filtering_enabled"]
if "safebrowsing_enabled" in settings:
update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"]
if "parental_enabled" in settings:
update_data["data"]["parental_enabled"] = settings["parental_enabled"]
# Update the client
await api.update_client(update_data)
_LOGGER.info("Updated client: %s", client_name)
except Exception as err:
_LOGGER.error("Failed to bulk update clients: %s", err)
async def add_client(self, call: ServiceCall) -> None:
"""Add a new client."""
client_data = dict(call.data)
# Convert blocked_services to proper format
if "blocked_services" in client_data and client_data["blocked_services"]:
blocked_services_data = {
"ids": client_data["blocked_services"],
"schedule": {"time_zone": "Local"}
}
client_data["blocked_services"] = blocked_services_data
_LOGGER.info("Adding new client: %s", client_data["name"])
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
await api.add_client(client_data)
_LOGGER.info("Successfully added client: %s", client_data["name"])
except Exception as err:
_LOGGER.error("Failed to add client %s: %s", client_data["name"], err)
async def remove_client(self, call: ServiceCall) -> None:
"""Remove a client."""
client_name = call.data["name"]
_LOGGER.info("Removing client: %s", client_name)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
await api.delete_client(client_name)
_LOGGER.info("Successfully removed client: %s", client_name)
except Exception as err:
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
async def schedule_service_block(self, call: ServiceCall) -> None:
"""Schedule service blocking with time-based rules."""
client_name = call.data[ATTR_CLIENT_NAME]
services = call.data[ATTR_SERVICES]
schedule = call.data["schedule"]
_LOGGER.info("Scheduling service blocking for client %s", client_name)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
await api.update_client_blocked_services(client_name, services, schedule)
_LOGGER.info("Successfully scheduled service blocking for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)

View File

@@ -1,27 +1,151 @@
{ {
"config": { "config": {
"step": { "step": {
"user": { "user": {
"title": "AdGuard Control Hub", "title": "AdGuard Control Hub",
"description": "Connect to your AdGuard Home instance for complete network control", "description": "Configure your AdGuard Home connection",
"data": { "data": {
"host": "AdGuard Home IP Address", "host": "Host",
"port": "Port (usually 3000)", "port": "Port",
"username": "Admin Username", "username": "Username",
"password": "Admin Password", "password": "Password",
"ssl": "Use HTTPS connection", "ssl": "Use SSL",
"verify_ssl": "Verify SSL certificate" "verify_ssl": "Verify SSL Certificate"
}
}
},
"error": {
"cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.",
"invalid_auth": "Invalid username or password",
"timeout": "Connection timeout. Please check your network connection.",
"unknown": "An unexpected error occurred"
},
"abort": {
"already_configured": "AdGuard Control Hub is already configured for this host and port"
} }
}
}, },
"error": { "options": {
"cannot_connect": "Failed to connect to AdGuard Home. Check IP address, port, and credentials.", "step": {
"invalid_auth": "Invalid username or password. Please check your admin credentials.", "init": {
"unknown": "Unexpected error occurred. Please check logs for details." "title": "AdGuard Control Hub Options",
"description": "Configure advanced options",
"data": {
"scan_interval": "Update interval (seconds)",
"timeout": "Connection timeout (seconds)"
}
}
}
}, },
"abort": { "services": {
"already_configured": "This AdGuard Home instance is already configured", "block_services": {
"cannot_connect": "Cannot connect to AdGuard Home" "name": "Block Services",
"description": "Block specific services for a client",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client to block services for"
},
"services": {
"name": "Services",
"description": "List of services to block"
}
}
},
"unblock_services": {
"name": "Unblock Services",
"description": "Unblock specific services for a client",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client to unblock services for"
},
"services": {
"name": "Services",
"description": "List of services to unblock"
}
}
},
"emergency_unblock": {
"name": "Emergency Unblock",
"description": "Temporarily disable blocking for emergency access",
"fields": {
"duration": {
"name": "Duration",
"description": "Duration in seconds to keep unblocked"
},
"clients": {
"name": "Clients",
"description": "List of client names (use 'all' for global)"
}
}
},
"bulk_update_clients": {
"name": "Bulk Update Clients",
"description": "Update multiple clients matching a pattern",
"fields": {
"client_pattern": {
"name": "Client Pattern",
"description": "Pattern to match client names (supports wildcards)"
},
"settings": {
"name": "Settings",
"description": "Settings to apply to matching clients"
}
}
},
"add_client": {
"name": "Add Client",
"description": "Add a new client configuration",
"fields": {
"name": {
"name": "Name",
"description": "Client name"
},
"ids": {
"name": "IDs",
"description": "List of IP addresses or CIDR ranges"
},
"mac": {
"name": "MAC Address",
"description": "MAC address (optional)"
},
"filtering_enabled": {
"name": "Filtering Enabled",
"description": "Enable DNS filtering for this client"
},
"blocked_services": {
"name": "Blocked Services",
"description": "List of services to block"
}
}
},
"remove_client": {
"name": "Remove Client",
"description": "Remove a client configuration",
"fields": {
"name": {
"name": "Name",
"description": "Name of the client to remove"
}
}
},
"schedule_service_block": {
"name": "Schedule Service Block",
"description": "Schedule time-based service blocking",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client"
},
"services": {
"name": "Services",
"description": "List of services to block"
},
"schedule": {
"name": "Schedule",
"description": "Time-based schedule configuration"
}
}
}
} }
}
} }

View File

@@ -1,10 +1,13 @@
"""Switch platform for AdGuard Control Hub integration.""" """Switch platform for AdGuard Control Hub integration."""
import logging import logging
from typing import Any
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AdGuardControlHubCoordinator from . import AdGuardControlHubCoordinator
from .api import AdGuardHomeAPI from .api import AdGuardHomeAPI
from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER
@@ -12,7 +15,11 @@ from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MA
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback): async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up AdGuard Control Hub switch platform.""" """Set up AdGuard Control Hub switch platform."""
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"] coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
api = hass.data[DOMAIN][config_entry.entry_id]["api"] api = hass.data[DOMAIN][config_entry.entry_id]["api"]
@@ -32,6 +39,7 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity):
"""Base class for AdGuard switches.""" """Base class for AdGuard switches."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the switch."""
super().__init__(coordinator) super().__init__(coordinator)
self.api = api self.api = api
self._attr_device_info = { self._attr_device_info = {
@@ -46,31 +54,64 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
"""Switch to control global AdGuard protection.""" """Switch to control global AdGuard protection."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the switch."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_protection" self._attr_unique_id = f"{api.host}_{api.port}_protection"
self._attr_name = "AdGuard Protection" self._attr_name = "AdGuard Protection"
@property @property
def is_on(self) -> bool: def is_on(self) -> bool | None:
"""Return true if protection is enabled."""
return self.coordinator.protection_status.get("protection_enabled", False) return self.coordinator.protection_status.get("protection_enabled", False)
@property @property
def icon(self) -> str: def icon(self) -> str:
"""Return the icon for the switch."""
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
async def async_turn_on(self, **kwargs): @property
await self.api.set_protection(True) def extra_state_attributes(self) -> dict[str, Any]:
await self.coordinator.async_request_refresh() """Return additional state attributes."""
status = self.coordinator.protection_status
stats = self.coordinator.statistics
return {
"dns_port": status.get("dns_port", "N/A"),
"queries_today": stats.get("num_dns_queries_today", 0),
"blocked_today": stats.get("num_blocked_filtering_today", 0),
"version": status.get("version", "N/A"),
}
async def async_turn_off(self, **kwargs): async def async_turn_on(self, **kwargs: Any) -> None:
await self.api.set_protection(False) """Turn on AdGuard protection."""
await self.coordinator.async_request_refresh() try:
await self.api.set_protection(True)
await self.coordinator.async_request_refresh()
_LOGGER.info("AdGuard protection enabled")
except Exception as err:
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
raise
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off AdGuard protection."""
try:
await self.api.set_protection(False)
await self.coordinator.async_request_refresh()
_LOGGER.info("AdGuard protection disabled")
except Exception as err:
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
raise
class AdGuardClientSwitch(AdGuardBaseSwitch): class AdGuardClientSwitch(AdGuardBaseSwitch):
"""Switch to control client-specific protection.""" """Switch to control client-specific protection."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI, client_name: str): def __init__(
self,
coordinator: AdGuardControlHubCoordinator,
api: AdGuardHomeAPI,
client_name: str,
):
"""Initialize the switch."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self.client_name = client_name self.client_name = client_name
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}" self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}"
@@ -78,16 +119,81 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
self._attr_icon = ICON_CLIENT self._attr_icon = ICON_CLIENT
@property @property
def is_on(self) -> bool: def is_on(self) -> bool | None:
"""Return true if client protection is enabled."""
client = self.coordinator.clients.get(self.client_name, {}) client = self.coordinator.clients.get(self.client_name, {})
return client.get("filtering_enabled", True) return client.get("filtering_enabled", True)
async def async_turn_on(self, **kwargs): @property
# This would update client settings - simplified for basic functionality def extra_state_attributes(self) -> dict[str, Any]:
_LOGGER.info("Would enable protection for %s", self.client_name) """Return additional state attributes."""
await self.coordinator.async_request_refresh() client = self.coordinator.clients.get(self.client_name, {})
blocked_services = client.get("blocked_services", {})
async def async_turn_off(self, **kwargs): if isinstance(blocked_services, dict):
# This would update client settings - simplified for basic functionality service_ids = blocked_services.get("ids", [])
_LOGGER.info("Would disable protection for %s", self.client_name) else:
await self.coordinator.async_request_refresh() service_ids = blocked_services if blocked_services else []
return {
"client_ids": client.get("ids", []),
"mac": client.get("mac", ""),
"use_global_settings": client.get("use_global_settings", True),
"safebrowsing_enabled": client.get("safebrowsing_enabled", False),
"parental_enabled": client.get("parental_enabled", False),
"safesearch_enabled": client.get("safesearch_enabled", False),
"blocked_services": service_ids,
"blocked_services_count": len(service_ids),
}
async def async_turn_on(self, **kwargs: Any) -> None:
"""Enable protection for this client."""
try:
# Get current client data
client = await self.api.get_client_by_name(self.client_name)
if not client:
_LOGGER.error("Client %s not found", self.client_name)
return
# Update client with filtering enabled
update_data = {
"name": self.client_name,
"data": {
**client,
"filtering_enabled": True,
}
}
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Enabled protection for client %s", self.client_name)
except Exception as err:
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
raise
async def async_turn_off(self, **kwargs: Any) -> None:
"""Disable protection for this client."""
try:
# Get current client data
client = await self.api.get_client_by_name(self.client_name)
if not client:
_LOGGER.error("Client %s not found", self.client_name)
return
# Update client with filtering disabled
update_data = {
"name": self.client_name,
"data": {
**client,
"filtering_enabled": False,
}
}
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Disabled protection for client %s", self.client_name)
except Exception as err:
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
raise

View File

@@ -3,6 +3,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from custom_components.adguard_hub.api import AdGuardHomeAPI from custom_components.adguard_hub.api import AdGuardHomeAPI
@pytest.fixture @pytest.fixture
def mock_session(): def mock_session():
"""Mock aiohttp session.""" """Mock aiohttp session."""
@@ -15,6 +16,7 @@ def mock_session():
session.request = AsyncMock(return_value=response) session.request = AsyncMock(return_value=response)
return session return session
async def test_api_connection(mock_session): async def test_api_connection(mock_session):
"""Test API connection.""" """Test API connection."""
api = AdGuardHomeAPI( api = AdGuardHomeAPI(
@@ -28,6 +30,7 @@ async def test_api_connection(mock_session):
result = await api.test_connection() result = await api.test_connection()
assert result is True assert result is True
async def test_api_get_status(mock_session): async def test_api_get_status(mock_session):
"""Test getting status.""" """Test getting status."""
api = AdGuardHomeAPI( api = AdGuardHomeAPI(

223
tests/test_integration.py Normal file
View File

@@ -0,0 +1,223 @@
"""Test the complete AdGuard Control Hub integration."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from homeassistant.core import HomeAssistant
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
from custom_components.adguard_hub import async_setup_entry, async_unload_entry
from custom_components.adguard_hub.api import AdGuardHomeAPI
from custom_components.adguard_hub.const import DOMAIN
@pytest.fixture
def mock_config_entry():
"""Mock config entry."""
return ConfigEntry(
version=1,
domain=DOMAIN,
title="Test AdGuard",
data={
CONF_HOST: "192.168.1.100",
CONF_PORT: 3000,
CONF_USERNAME: "admin",
CONF_PASSWORD: "password",
},
source="user",
entry_id="test_entry_id",
)
@pytest.fixture
def mock_api():
"""Mock API instance."""
api = MagicMock(spec=AdGuardHomeAPI)
api.host = "192.168.1.100"
api.port = 3000
api.test_connection = AsyncMock(return_value=True)
api.get_status = AsyncMock(return_value={
"protection_enabled": True,
"version": "v0.107.0",
"dns_port": 53,
"running": True,
})
api.get_clients = AsyncMock(return_value={
"clients": [
{
"name": "test_client",
"ids": ["192.168.1.50"],
"filtering_enabled": True,
"blocked_services": {"ids": ["youtube"]},
}
]
})
api.get_statistics = AsyncMock(return_value={
"num_dns_queries": 1000,
"num_blocked_filtering": 100,
"num_dns_queries_today": 500,
"num_blocked_filtering_today": 50,
"filtering_rules_count": 50000,
"avg_processing_time": 2.5,
})
return api
@pytest.mark.asyncio
async def test_setup_entry_success(hass: HomeAssistant, mock_config_entry, mock_api):
"""Test successful setup of config entry."""
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
patch("custom_components.adguard_hub.async_get_clientsession"), \
patch.object(hass.config_entries, "async_forward_entry_setups", return_value=True):
result = await async_setup_entry(hass, mock_config_entry)
assert result is True
assert DOMAIN in hass.data
assert mock_config_entry.entry_id in hass.data[DOMAIN]
assert "coordinator" in hass.data[DOMAIN][mock_config_entry.entry_id]
assert "api" in hass.data[DOMAIN][mock_config_entry.entry_id]
@pytest.mark.asyncio
async def test_setup_entry_connection_failure(hass: HomeAssistant, mock_config_entry):
"""Test setup failure due to connection error."""
mock_api = MagicMock(spec=AdGuardHomeAPI)
mock_api.test_connection = AsyncMock(return_value=False)
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
patch("custom_components.adguard_hub.async_get_clientsession"), \
pytest.raises(Exception): # Should raise ConfigEntryNotReady
await async_setup_entry(hass, mock_config_entry)
@pytest.mark.asyncio
async def test_unload_entry(hass: HomeAssistant, mock_config_entry):
"""Test unloading of config entry."""
# Set up initial data
hass.data[DOMAIN] = {
mock_config_entry.entry_id: {
"coordinator": MagicMock(),
"api": MagicMock(),
}
}
with patch.object(hass.config_entries, "async_unload_platforms", return_value=True):
result = await async_unload_entry(hass, mock_config_entry)
assert result is True
assert mock_config_entry.entry_id not in hass.data[DOMAIN]
@pytest.mark.asyncio
async def test_coordinator_data_update(hass: HomeAssistant, mock_api):
"""Test coordinator data update functionality."""
from custom_components.adguard_hub import AdGuardControlHubCoordinator
coordinator = AdGuardControlHubCoordinator(hass, mock_api)
# Test successful data update
data = await coordinator._async_update_data()
assert "clients" in data
assert "statistics" in data
assert "status" in data
assert "test_client" in data["clients"]
assert data["statistics"]["num_dns_queries"] == 1000
assert data["status"]["protection_enabled"] is True
@pytest.mark.asyncio
async def test_api_error_handling(mock_api):
"""Test API error handling."""
from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError
# Test connection error
mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
with pytest.raises(AdGuardConnectionError):
await mock_api.get_status()
# Test auth error
mock_api.get_clients = AsyncMock(side_effect=AdGuardAuthError("Auth failed"))
with pytest.raises(AdGuardAuthError):
await mock_api.get_clients()
@pytest.mark.asyncio
async def test_services_registration(hass: HomeAssistant):
"""Test that services are properly registered."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
services = AdGuardControlHubServices(hass)
services.register_services()
# Check that services are registered
assert hass.services.has_service(DOMAIN, "block_services")
assert hass.services.has_service(DOMAIN, "unblock_services")
assert hass.services.has_service(DOMAIN, "emergency_unblock")
assert hass.services.has_service(DOMAIN, "bulk_update_clients")
assert hass.services.has_service(DOMAIN, "add_client")
assert hass.services.has_service(DOMAIN, "remove_client")
# Clean up
services.unregister_services()
def test_blocked_services_constants():
"""Test that blocked services are properly defined."""
from custom_components.adguard_hub.const import BLOCKED_SERVICES
assert "youtube" in BLOCKED_SERVICES
assert "netflix" in BLOCKED_SERVICES
assert "gaming" in BLOCKED_SERVICES
assert "facebook" in BLOCKED_SERVICES
# Check friendly names are defined
assert BLOCKED_SERVICES["youtube"] == "YouTube"
assert BLOCKED_SERVICES["netflix"] == "Netflix"
def test_api_endpoints():
"""Test that API endpoints are properly defined."""
from custom_components.adguard_hub.const import API_ENDPOINTS
required_endpoints = [
"status", "clients", "stats", "protection",
"clients_add", "clients_update", "clients_delete"
]
for endpoint in required_endpoints:
assert endpoint in API_ENDPOINTS
assert API_ENDPOINTS[endpoint].startswith("/")
@pytest.mark.asyncio
async def test_client_operations(mock_api):
"""Test client add/update/delete operations."""
# Test add client
client_data = {
"name": "new_client",
"ids": ["192.168.1.200"],
"filtering_enabled": True,
}
mock_api.add_client = AsyncMock(return_value={"success": True})
result = await mock_api.add_client(client_data)
assert result["success"] is True
# Test update client
update_data = {
"name": "new_client",
"data": {"filtering_enabled": False}
}
mock_api.update_client = AsyncMock(return_value={"success": True})
result = await mock_api.update_client(update_data)
assert result["success"] is True
# Test delete client
mock_api.delete_client = AsyncMock(return_value={"success": True})
result = await mock_api.delete_client("new_client")
assert result["success"] is True