@@ -1,20 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
🛡️ AdGuard Control Hub for Home Assistant.
|
AdGuard Control Hub for Home Assistant.
|
||||||
|
|
||||||
Transform your AdGuard Home into a smart network management powerhouse with
|
Transform your AdGuard Home into a smart network management powerhouse with
|
||||||
complete client control, service blocking, and automation capabilities.
|
complete client control, service blocking, and automation capabilities.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryNotReady
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||||
|
|
||||||
|
from .api import AdGuardHomeAPI, AdGuardConnectionError
|
||||||
from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL
|
from .const import DOMAIN, PLATFORMS, SCAN_INTERVAL, CONF_SSL, CONF_VERIFY_SSL
|
||||||
from .api import AdGuardHomeAPI
|
from .services import AdGuardControlHubServices
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,6 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
"""Set up AdGuard Control Hub from a config entry."""
|
"""Set up AdGuard Control Hub from a config entry."""
|
||||||
session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True))
|
session = async_get_clientsession(hass, entry.data.get(CONF_VERIFY_SSL, True))
|
||||||
|
|
||||||
|
# Create API instance
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
host=entry.data[CONF_HOST],
|
host=entry.data[CONF_HOST],
|
||||||
port=entry.data[CONF_PORT],
|
port=entry.data[CONF_PORT],
|
||||||
@@ -34,16 +39,26 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
# Test the connection
|
# Test the connection
|
||||||
try:
|
try:
|
||||||
await api.test_connection()
|
if not await api.test_connection():
|
||||||
_LOGGER.info("Successfully connected to AdGuard Home at %s:%s",
|
raise ConfigEntryNotReady("Unable to connect to AdGuard Home")
|
||||||
entry.data[CONF_HOST], entry.data[CONF_PORT])
|
|
||||||
|
_LOGGER.info(
|
||||||
|
"Successfully connected to AdGuard Home at %s:%s",
|
||||||
|
entry.data[CONF_HOST],
|
||||||
|
entry.data[CONF_PORT]
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to connect to AdGuard Home: %s", err)
|
_LOGGER.error("Failed to connect to AdGuard Home: %s", err)
|
||||||
raise ConfigEntryNotReady(f"Unable to connect: {err}")
|
raise ConfigEntryNotReady(f"Unable to connect: {err}") from err
|
||||||
|
|
||||||
# Create update coordinator
|
# Create update coordinator
|
||||||
coordinator = AdGuardControlHubCoordinator(hass, api)
|
coordinator = AdGuardControlHubCoordinator(hass, api)
|
||||||
await coordinator.async_config_entry_first_refresh()
|
|
||||||
|
try:
|
||||||
|
await coordinator.async_config_entry_first_refresh()
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to perform initial data refresh: %s", err)
|
||||||
|
raise ConfigEntryNotReady(f"Failed to fetch initial data: {err}") from err
|
||||||
|
|
||||||
# Store data
|
# Store data
|
||||||
hass.data.setdefault(DOMAIN, {})
|
hass.data.setdefault(DOMAIN, {})
|
||||||
@@ -53,9 +68,24 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Set up platforms
|
# Set up platforms
|
||||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
try:
|
||||||
|
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to set up platforms: %s", err)
|
||||||
|
# Clean up on failure
|
||||||
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
|
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
|
||||||
|
|
||||||
_LOGGER.info("AdGuard Control Hub setup complete")
|
# Register services (only once, not per config entry)
|
||||||
|
if not hass.services.has_service(DOMAIN, "block_services"):
|
||||||
|
services = AdGuardControlHubServices(hass)
|
||||||
|
services.register_services()
|
||||||
|
|
||||||
|
# Store services instance for cleanup
|
||||||
|
hass.data.setdefault(f"{DOMAIN}_services", services)
|
||||||
|
|
||||||
|
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
|
||||||
|
entry.data[CONF_HOST], entry.data[CONF_PORT])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -64,8 +94,19 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||||
|
|
||||||
if unload_ok:
|
if unload_ok:
|
||||||
|
# Remove this entry's data
|
||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
|
|
||||||
|
# Unregister services if this was the last entry
|
||||||
|
if not hass.data[DOMAIN]: # No more entries
|
||||||
|
services = hass.data.get(f"{DOMAIN}_services")
|
||||||
|
if services:
|
||||||
|
services.unregister_services()
|
||||||
|
hass.data.pop(f"{DOMAIN}_services", None)
|
||||||
|
|
||||||
|
# Also clean up the empty domain entry
|
||||||
|
hass.data.pop(DOMAIN, None)
|
||||||
|
|
||||||
return unload_ok
|
return unload_ok
|
||||||
|
|
||||||
|
|
||||||
@@ -81,36 +122,54 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
|||||||
update_interval=timedelta(seconds=SCAN_INTERVAL),
|
update_interval=timedelta(seconds=SCAN_INTERVAL),
|
||||||
)
|
)
|
||||||
self.api = api
|
self.api = api
|
||||||
self._clients = {}
|
self._clients: Dict[str, Any] = {}
|
||||||
self._statistics = {}
|
self._statistics: Dict[str, Any] = {}
|
||||||
self._protection_status = {}
|
self._protection_status: Dict[str, Any] = {}
|
||||||
|
|
||||||
async def _async_update_data(self):
|
async def _async_update_data(self) -> Dict[str, Any]:
|
||||||
"""Fetch data from AdGuard Home."""
|
"""Fetch data from AdGuard Home."""
|
||||||
try:
|
try:
|
||||||
# Fetch all data concurrently for better performance
|
# Fetch all data concurrently for better performance
|
||||||
results = await asyncio.gather(
|
tasks = [
|
||||||
self.api.get_clients(),
|
self.api.get_clients(),
|
||||||
self.api.get_statistics(),
|
self.api.get_statistics(),
|
||||||
self.api.get_status(),
|
self.api.get_status(),
|
||||||
return_exceptions=True,
|
]
|
||||||
)
|
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
clients, statistics, status = results
|
clients, statistics, status = results
|
||||||
|
|
||||||
# Handle any exceptions
|
# Handle any exceptions in individual requests
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
endpoint_names = ["clients", "statistics", "status"]
|
endpoint_names = ["clients", "statistics", "status"]
|
||||||
_LOGGER.warning("Error fetching %s: %s", endpoint_names[i], result)
|
_LOGGER.warning(
|
||||||
|
"Error fetching %s from %s:%s: %s",
|
||||||
|
endpoint_names[i],
|
||||||
|
self.api.host,
|
||||||
|
self.api.port,
|
||||||
|
result
|
||||||
|
)
|
||||||
|
|
||||||
# Update stored data (use empty dict if fetch failed)
|
# Update stored data (use empty dict if fetch failed)
|
||||||
self._clients = {
|
if not isinstance(clients, Exception):
|
||||||
client["name"]: client
|
self._clients = {
|
||||||
for client in (clients.get("clients", []) if not isinstance(clients, Exception) else [])
|
client["name"]: client
|
||||||
}
|
for client in clients.get("clients", [])
|
||||||
self._statistics = statistics if not isinstance(statistics, Exception) else {}
|
if client.get("name") # Ensure client has a name
|
||||||
self._protection_status = status if not isinstance(status, Exception) else {}
|
}
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Failed to update clients data, keeping previous data")
|
||||||
|
|
||||||
|
if not isinstance(statistics, Exception):
|
||||||
|
self._statistics = statistics
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Failed to update statistics data, keeping previous data")
|
||||||
|
|
||||||
|
if not isinstance(status, Exception):
|
||||||
|
self._protection_status = status
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Failed to update status data, keeping previous data")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"clients": self._clients,
|
"clients": self._clients,
|
||||||
@@ -118,20 +177,40 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
|||||||
"status": self._protection_status,
|
"status": self._protection_status,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except AdGuardConnectionError as err:
|
||||||
|
raise UpdateFailed(f"Connection error to AdGuard Home: {err}") from err
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}")
|
raise UpdateFailed(f"Error communicating with AdGuard Control Hub: {err}") from err
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def clients(self):
|
def clients(self) -> Dict[str, Any]:
|
||||||
"""Return clients data."""
|
"""Return clients data."""
|
||||||
return self._clients
|
return self._clients
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def statistics(self):
|
def statistics(self) -> Dict[str, Any]:
|
||||||
"""Return statistics data."""
|
"""Return statistics data."""
|
||||||
return self._statistics
|
return self._statistics
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def protection_status(self):
|
def protection_status(self) -> Dict[str, Any]:
|
||||||
"""Return protection status data."""
|
"""Return protection status data."""
|
||||||
return self._protection_status
|
return self._protection_status
|
||||||
|
|
||||||
|
def get_client(self, client_name: str) -> Dict[str, Any] | None:
|
||||||
|
"""Get a specific client by name."""
|
||||||
|
return self._clients.get(client_name)
|
||||||
|
|
||||||
|
def has_client(self, client_name: str) -> bool:
|
||||||
|
"""Check if a client exists."""
|
||||||
|
return client_name in self._clients
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_count(self) -> int:
|
||||||
|
"""Return the number of clients."""
|
||||||
|
return len(self._clients)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_protection_enabled(self) -> bool:
|
||||||
|
"""Return True if protection is enabled."""
|
||||||
|
return self._protection_status.get("protection_enabled", False)
|
||||||
|
@@ -1,102 +1,207 @@
|
|||||||
"""API wrapper for AdGuard Home."""
|
"""API wrapper for AdGuard Home."""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import BasicAuth
|
from aiohttp import BasicAuth, ClientError, ClientTimeout
|
||||||
|
|
||||||
from .const import API_ENDPOINTS
|
from .const import API_ENDPOINTS
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Custom exceptions
|
||||||
|
class AdGuardHomeError(Exception):
|
||||||
|
"""Base exception for AdGuard Home API."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AdGuardConnectionError(AdGuardHomeError):
|
||||||
|
"""Exception for connection errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AdGuardAuthError(AdGuardHomeError):
|
||||||
|
"""Exception for authentication errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AdGuardNotFoundError(AdGuardHomeError):
|
||||||
|
"""Exception for not found errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
class AdGuardHomeAPI:
|
class AdGuardHomeAPI:
|
||||||
"""API wrapper for AdGuard Home."""
|
"""API wrapper for AdGuard Home."""
|
||||||
|
|
||||||
def __init__(self, host: str, port: int = 3000, username: str = None,
|
def __init__(
|
||||||
password: str = None, ssl: bool = False, session=None):
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int = 3000,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
ssl: bool = False,
|
||||||
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
):
|
||||||
|
"""Initialize the API wrapper."""
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.username = username
|
self.username = username
|
||||||
self.password = password
|
self.password = password
|
||||||
self.ssl = ssl
|
self.ssl = ssl
|
||||||
self.session = session
|
self._session = session
|
||||||
|
self._timeout = ClientTimeout(total=timeout)
|
||||||
protocol = "https" if ssl else "http"
|
protocol = "https" if ssl else "http"
|
||||||
self.base_url = f"{protocol}://{host}:{port}"
|
self.base_url = f"{protocol}://{host}:{port}"
|
||||||
|
self._own_session = session is None
|
||||||
|
|
||||||
async def _request(self, method: str, endpoint: str, data: dict = None) -> dict:
|
async def __aenter__(self):
|
||||||
"""Make an API request."""
|
"""Async context manager entry."""
|
||||||
|
if self._own_session:
|
||||||
|
self._session = aiohttp.ClientSession(timeout=self._timeout)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit."""
|
||||||
|
if self._own_session and self._session:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> aiohttp.ClientSession:
|
||||||
|
"""Get the session, creating one if needed."""
|
||||||
|
if not self._session:
|
||||||
|
self._session = aiohttp.ClientSession(timeout=self._timeout)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
|
"""Make an API request with comprehensive error handling."""
|
||||||
url = f"{self.base_url}{endpoint}"
|
url = f"{self.base_url}{endpoint}"
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
auth = None
|
auth = None
|
||||||
|
|
||||||
if self.username and self.password:
|
if self.username and self.password:
|
||||||
auth = BasicAuth(self.username, self.password)
|
auth = BasicAuth(self.username, self.password)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.session.request(method, url, json=data, headers=headers, auth=auth) as response:
|
async with self.session.request(
|
||||||
|
method, url, json=data, headers=headers, auth=auth
|
||||||
|
) as response:
|
||||||
|
|
||||||
|
# Handle different HTTP status codes
|
||||||
|
if response.status == 401:
|
||||||
|
raise AdGuardAuthError("Authentication failed - check username/password")
|
||||||
|
elif response.status == 403:
|
||||||
|
raise AdGuardAuthError("Access forbidden - insufficient permissions")
|
||||||
|
elif response.status == 404:
|
||||||
|
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
||||||
|
elif response.status >= 500:
|
||||||
|
raise AdGuardConnectionError(f"Server error {response.status}")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Handle empty responses
|
||||||
if response.status == 204 or not response.content_length:
|
if response.status == 204 or not response.content_length:
|
||||||
return {}
|
return {}
|
||||||
return await response.json()
|
|
||||||
|
try:
|
||||||
|
return await response.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
# Handle non-JSON responses
|
||||||
|
text = await response.text()
|
||||||
|
_LOGGER.warning("Non-JSON response received: %s", text)
|
||||||
|
return {"response": text}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError as err:
|
||||||
|
raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}")
|
||||||
|
except ClientError as err:
|
||||||
|
raise AdGuardConnectionError(f"Client error: {err}")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Error communicating with AdGuard Home: %s", err)
|
_LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err)
|
||||||
raise
|
raise AdGuardHomeError(f"Unexpected error: {err}")
|
||||||
|
|
||||||
async def test_connection(self) -> bool:
|
async def test_connection(self) -> bool:
|
||||||
"""Test the connection."""
|
"""Test the connection to AdGuard Home."""
|
||||||
try:
|
try:
|
||||||
await self._request("GET", API_ENDPOINTS["status"])
|
await self._request("GET", API_ENDPOINTS["status"])
|
||||||
return True
|
return True
|
||||||
except:
|
except Exception as err:
|
||||||
|
_LOGGER.debug("Connection test failed: %s", err)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_status(self) -> dict:
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
"""Get server status."""
|
"""Get server status information."""
|
||||||
return await self._request("GET", API_ENDPOINTS["status"])
|
return await self._request("GET", API_ENDPOINTS["status"])
|
||||||
|
|
||||||
async def get_clients(self) -> dict:
|
async def get_clients(self) -> Dict[str, Any]:
|
||||||
"""Get all clients."""
|
"""Get all configured clients."""
|
||||||
return await self._request("GET", API_ENDPOINTS["clients"])
|
return await self._request("GET", API_ENDPOINTS["clients"])
|
||||||
|
|
||||||
async def get_statistics(self) -> dict:
|
async def get_statistics(self) -> Dict[str, Any]:
|
||||||
"""Get statistics."""
|
"""Get DNS query statistics."""
|
||||||
return await self._request("GET", API_ENDPOINTS["stats"])
|
return await self._request("GET", API_ENDPOINTS["stats"])
|
||||||
|
|
||||||
async def set_protection(self, enabled: bool) -> dict:
|
async def set_protection(self, enabled: bool) -> Dict[str, Any]:
|
||||||
"""Enable or disable protection."""
|
"""Enable or disable AdGuard protection."""
|
||||||
data = {"enabled": enabled}
|
data = {"enabled": enabled}
|
||||||
return await self._request("POST", API_ENDPOINTS["protection"], data)
|
return await self._request("POST", API_ENDPOINTS["protection"], data)
|
||||||
|
|
||||||
async def add_client(self, client_data: dict) -> dict:
|
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Add a new client."""
|
"""Add a new client configuration."""
|
||||||
|
# Validate required fields
|
||||||
|
if "name" not in client_data:
|
||||||
|
raise ValueError("Client name is required")
|
||||||
|
if "ids" not in client_data or not client_data["ids"]:
|
||||||
|
raise ValueError("Client IDs are required")
|
||||||
|
|
||||||
return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)
|
return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)
|
||||||
|
|
||||||
async def update_client(self, client_data: dict) -> dict:
|
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Update an existing client."""
|
"""Update an existing client configuration."""
|
||||||
|
if "name" not in client_data:
|
||||||
|
raise ValueError("Client name is required for update")
|
||||||
|
if "data" not in client_data:
|
||||||
|
raise ValueError("Client data is required for update")
|
||||||
|
|
||||||
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
||||||
|
|
||||||
async def delete_client(self, client_name: str) -> dict:
|
async def delete_client(self, client_name: str) -> Dict[str, Any]:
|
||||||
"""Delete a client."""
|
"""Delete a client configuration."""
|
||||||
|
if not client_name:
|
||||||
|
raise ValueError("Client name is required")
|
||||||
|
|
||||||
data = {"name": client_name}
|
data = {"name": client_name}
|
||||||
return await self._request("POST", API_ENDPOINTS["clients_delete"], data)
|
return await self._request("POST", API_ENDPOINTS["clients_delete"], data)
|
||||||
|
|
||||||
async def get_client_by_name(self, client_name: str) -> dict:
|
async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Get a specific client by name."""
|
"""Get a specific client by name."""
|
||||||
clients_data = await self.get_clients()
|
if not client_name:
|
||||||
clients = clients_data.get("clients", [])
|
return None
|
||||||
|
|
||||||
for client in clients:
|
try:
|
||||||
if client.get("name") == client_name:
|
clients_data = await self.get_clients()
|
||||||
return client
|
clients = clients_data.get("clients", [])
|
||||||
|
|
||||||
return None
|
for client in clients:
|
||||||
|
if client.get("name") == client_name:
|
||||||
|
return client
|
||||||
|
|
||||||
async def update_client_blocked_services(self, client_name: str, blocked_services: list,
|
return None
|
||||||
schedule: dict = None) -> dict:
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to get client %s: %s", client_name, err)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_client_blocked_services(
|
||||||
|
self,
|
||||||
|
client_name: str,
|
||||||
|
blocked_services: list,
|
||||||
|
schedule: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Update blocked services for a specific client."""
|
"""Update blocked services for a specific client."""
|
||||||
|
if not client_name:
|
||||||
|
raise ValueError("Client name is required")
|
||||||
|
|
||||||
client = await self.get_client_by_name(client_name)
|
client = await self.get_client_by_name(client_name)
|
||||||
if not client:
|
if not client:
|
||||||
raise ValueError(f"Client '{client_name}' not found")
|
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||||
|
|
||||||
# Prepare the blocked services data
|
# Prepare the blocked services data with proper structure
|
||||||
if schedule:
|
if schedule:
|
||||||
blocked_services_data = {
|
blocked_services_data = {
|
||||||
"ids": blocked_services,
|
"ids": blocked_services,
|
||||||
@@ -110,7 +215,7 @@ class AdGuardHomeAPI:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update the client
|
# Update the client with new blocked services
|
||||||
update_data = {
|
update_data = {
|
||||||
"name": client_name,
|
"name": client_name,
|
||||||
"data": {
|
"data": {
|
||||||
@@ -121,18 +226,23 @@ class AdGuardHomeAPI:
|
|||||||
|
|
||||||
return await self.update_client(update_data)
|
return await self.update_client(update_data)
|
||||||
|
|
||||||
async def toggle_client_service(self, client_name: str, service_id: str, enabled: bool) -> dict:
|
async def toggle_client_service(
|
||||||
|
self, client_name: str, service_id: str, enabled: bool
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Toggle a specific service for a client."""
|
"""Toggle a specific service for a client."""
|
||||||
|
if not client_name or not service_id:
|
||||||
|
raise ValueError("Client name and service ID are required")
|
||||||
|
|
||||||
client = await self.get_client_by_name(client_name)
|
client = await self.get_client_by_name(client_name)
|
||||||
if not client:
|
if not client:
|
||||||
raise ValueError(f"Client '{client_name}' not found")
|
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||||
|
|
||||||
# Get current blocked services
|
# Get current blocked services
|
||||||
blocked_services = client.get("blocked_services", {})
|
blocked_services = client.get("blocked_services", {})
|
||||||
if isinstance(blocked_services, dict):
|
if isinstance(blocked_services, dict):
|
||||||
service_ids = blocked_services.get("ids", [])
|
service_ids = blocked_services.get("ids", [])
|
||||||
else:
|
else:
|
||||||
# Handle old format (list)
|
# Handle legacy format (direct list)
|
||||||
service_ids = blocked_services if blocked_services else []
|
service_ids = blocked_services if blocked_services else []
|
||||||
|
|
||||||
# Update the service list
|
# Update the service list
|
||||||
@@ -142,3 +252,12 @@ class AdGuardHomeAPI:
|
|||||||
service_ids.remove(service_id)
|
service_ids.remove(service_id)
|
||||||
|
|
||||||
return await self.update_client_blocked_services(client_name, service_ids)
|
return await self.update_client_blocked_services(client_name, service_ids)
|
||||||
|
|
||||||
|
async def get_blocked_services(self) -> Dict[str, Any]:
|
||||||
|
"""Get available blocked services."""
|
||||||
|
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the API session if we own it."""
|
||||||
|
if self._own_session and self._session:
|
||||||
|
await self._session.close()
|
||||||
|
166
custom_components/adguard_hub/binary_sensor.py
Normal file
166
custom_components/adguard_hub/binary_sensor.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Binary sensor platform for AdGuard Control Hub integration."""
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
from . import AdGuardControlHubCoordinator
|
||||||
|
from .api import AdGuardHomeAPI
|
||||||
|
from .const import DOMAIN, MANUFACTURER, ICON_PROTECTION, ICON_PROTECTION_OFF
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up AdGuard Control Hub binary sensor platform."""
|
||||||
|
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||||
|
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
AdGuardProtectionBinarySensor(coordinator, api),
|
||||||
|
AdGuardFilteringBinarySensor(coordinator, api),
|
||||||
|
AdGuardSafeBrowsingBinarySensor(coordinator, api),
|
||||||
|
AdGuardParentalControlBinarySensor(coordinator, api),
|
||||||
|
AdGuardSafeSearchBinarySensor(coordinator, api),
|
||||||
|
]
|
||||||
|
|
||||||
|
async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity):
|
||||||
|
"""Base class for AdGuard binary sensors."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator)
|
||||||
|
self.api = api
|
||||||
|
self._attr_device_info = {
|
||||||
|
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
|
||||||
|
"name": f"AdGuard Control Hub ({api.host})",
|
||||||
|
"manufacturer": MANUFACTURER,
|
||||||
|
"model": "AdGuard Home",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
|
||||||
|
"""Binary sensor to show AdGuard protection status."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled"
|
||||||
|
self._attr_name = "AdGuard Protection Status"
|
||||||
|
self._attr_device_class = BinarySensorDeviceClass.RUNNING
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if protection is enabled."""
|
||||||
|
return self.coordinator.protection_status.get("protection_enabled", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the binary sensor."""
|
||||||
|
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_state_attributes(self) -> dict[str, Any]:
|
||||||
|
"""Return additional state attributes."""
|
||||||
|
status = self.coordinator.protection_status
|
||||||
|
return {
|
||||||
|
"dns_port": status.get("dns_port", "N/A"),
|
||||||
|
"http_port": status.get("http_port", "N/A"),
|
||||||
|
"version": status.get("version", "N/A"),
|
||||||
|
"running": status.get("running", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardFilteringBinarySensor(AdGuardBaseBinarySensor):
|
||||||
|
"""Binary sensor to show DNS filtering status."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_filtering_enabled"
|
||||||
|
self._attr_name = "AdGuard DNS Filtering"
|
||||||
|
self._attr_device_class = BinarySensorDeviceClass.RUNNING
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if DNS filtering is enabled."""
|
||||||
|
return self.coordinator.protection_status.get("filtering_enabled", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the binary sensor."""
|
||||||
|
return "mdi:dns" if self.is_on else "mdi:dns-off"
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
|
||||||
|
"""Binary sensor to show Safe Browsing status."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled"
|
||||||
|
self._attr_name = "AdGuard Safe Browsing"
|
||||||
|
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if Safe Browsing is enabled."""
|
||||||
|
return self.coordinator.protection_status.get("safebrowsing_enabled", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the binary sensor."""
|
||||||
|
return "mdi:security" if self.is_on else "mdi:security-off"
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor):
|
||||||
|
"""Binary sensor to show Parental Control status."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled"
|
||||||
|
self._attr_name = "AdGuard Parental Control"
|
||||||
|
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if Parental Control is enabled."""
|
||||||
|
return self.coordinator.protection_status.get("parental_enabled", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the binary sensor."""
|
||||||
|
return "mdi:account-child" if self.is_on else "mdi:account-child-outline"
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor):
|
||||||
|
"""Binary sensor to show Safe Search status."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the binary sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled"
|
||||||
|
self._attr_name = "AdGuard Safe Search"
|
||||||
|
self._attr_device_class = BinarySensorDeviceClass.SAFETY
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if Safe Search is enabled."""
|
||||||
|
return self.coordinator.protection_status.get("safesearch_enabled", False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the binary sensor."""
|
||||||
|
return "mdi:search-web" if self.is_on else "mdi:web-off"
|
@@ -1,73 +1,128 @@
|
|||||||
"""Config flow for AdGuard Control Hub integration."""
|
"""Config flow for AdGuard Control Hub integration."""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from .api import AdGuardHomeAPI
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from .const import CONF_SSL, CONF_VERIFY_SSL, DEFAULT_PORT, DEFAULT_SSL, DEFAULT_VERIFY_SSL, DOMAIN
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
|
||||||
|
from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError
|
||||||
|
from .const import (
|
||||||
|
CONF_SSL,
|
||||||
|
CONF_VERIFY_SSL,
|
||||||
|
DEFAULT_PORT,
|
||||||
|
DEFAULT_SSL,
|
||||||
|
DEFAULT_VERIFY_SSL,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
STEP_USER_DATA_SCHEMA = vol.Schema({
|
STEP_USER_DATA_SCHEMA = vol.Schema({
|
||||||
vol.Required(CONF_HOST): str,
|
vol.Required(CONF_HOST): cv.string,
|
||||||
vol.Optional(CONF_PORT, default=DEFAULT_PORT): int,
|
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
|
||||||
vol.Optional(CONF_USERNAME): str,
|
vol.Optional(CONF_USERNAME): cv.string,
|
||||||
vol.Optional(CONF_PASSWORD): str,
|
vol.Optional(CONF_PASSWORD): cv.string,
|
||||||
vol.Optional(CONF_SSL, default=DEFAULT_SSL): bool,
|
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
|
||||||
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): bool,
|
vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean,
|
||||||
})
|
})
|
||||||
|
|
||||||
async def validate_input(hass, data: dict) -> dict:
|
|
||||||
|
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate the user input allows us to connect."""
|
"""Validate the user input allows us to connect."""
|
||||||
|
# Normalize host
|
||||||
|
host = data[CONF_HOST].strip()
|
||||||
|
if not host:
|
||||||
|
raise InvalidHost("Host cannot be empty")
|
||||||
|
|
||||||
|
# Remove protocol if provided
|
||||||
|
if host.startswith(("http://", "https://")):
|
||||||
|
host = host.split("://", 1)[1]
|
||||||
|
data[CONF_HOST] = host
|
||||||
|
|
||||||
|
# Validate port
|
||||||
|
port = data[CONF_PORT]
|
||||||
|
if not (1 <= port <= 65535):
|
||||||
|
raise InvalidPort("Port must be between 1 and 65535")
|
||||||
|
|
||||||
|
# Create session with appropriate SSL settings
|
||||||
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
||||||
|
|
||||||
|
# Create API instance
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
host=data[CONF_HOST],
|
host=host,
|
||||||
port=data[CONF_PORT],
|
port=port,
|
||||||
username=data.get(CONF_USERNAME),
|
username=data.get(CONF_USERNAME),
|
||||||
password=data.get(CONF_PASSWORD),
|
password=data.get(CONF_PASSWORD),
|
||||||
ssl=data.get(CONF_SSL, False),
|
ssl=data.get(CONF_SSL, False),
|
||||||
session=session,
|
session=session,
|
||||||
|
timeout=10, # 10 second timeout for setup
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test the connection
|
# Test the connection
|
||||||
if not await api.test_connection():
|
|
||||||
raise CannotConnect
|
|
||||||
|
|
||||||
# Get server info
|
|
||||||
try:
|
try:
|
||||||
status = await api.get_status()
|
if not await api.test_connection():
|
||||||
version = status.get("version", "unknown")
|
raise CannotConnect("Failed to connect to AdGuard Home")
|
||||||
return {
|
|
||||||
"title": f"AdGuard Control Hub ({data[CONF_HOST]})",
|
# Get additional server info if possible
|
||||||
"version": version
|
try:
|
||||||
}
|
status = await api.get_status()
|
||||||
except Exception as err:
|
version = status.get("version", "unknown")
|
||||||
_LOGGER.exception("Unexpected exception: %s", err)
|
dns_port = status.get("dns_port", "N/A")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"title": f"AdGuard Control Hub ({host})",
|
||||||
|
"version": version,
|
||||||
|
"dns_port": dns_port,
|
||||||
|
"host": host,
|
||||||
|
}
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.warning("Could not get server status, but connection works: %s", err)
|
||||||
|
return {
|
||||||
|
"title": f"AdGuard Control Hub ({host})",
|
||||||
|
"version": "unknown",
|
||||||
|
"dns_port": "N/A",
|
||||||
|
"host": host,
|
||||||
|
}
|
||||||
|
|
||||||
|
except AdGuardAuthError as err:
|
||||||
|
_LOGGER.error("Authentication failed: %s", err)
|
||||||
|
raise InvalidAuth from err
|
||||||
|
except AdGuardConnectionError as err:
|
||||||
|
_LOGGER.error("Connection failed: %s", err)
|
||||||
|
if "timeout" in str(err).lower():
|
||||||
|
raise Timeout from err
|
||||||
raise CannotConnect from err
|
raise CannotConnect from err
|
||||||
|
except asyncio.TimeoutError as err:
|
||||||
|
_LOGGER.error("Connection timeout: %s", err)
|
||||||
|
raise Timeout from err
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.exception("Unexpected error during validation: %s", err)
|
||||||
|
raise CannotConnect from err
|
||||||
|
|
||||||
|
|
||||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
"""Handle a config flow for AdGuard Control Hub."""
|
"""Handle a config flow for AdGuard Control Hub."""
|
||||||
|
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_POLL
|
||||||
|
|
||||||
async def async_step_user(self, user_input: dict[str, Any] | None = None):
|
async def async_step_user(
|
||||||
|
self, user_input: Optional[Dict[str, Any]] = None
|
||||||
|
) -> FlowResult:
|
||||||
"""Handle the initial step."""
|
"""Handle the initial step."""
|
||||||
errors: dict[str, str] = {}
|
errors: Dict[str, str] = {}
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
info = await validate_input(self.hass, user_input)
|
info = await validate_input(self.hass, user_input)
|
||||||
except CannotConnect:
|
|
||||||
errors["base"] = "cannot_connect"
|
|
||||||
except Exception:
|
|
||||||
_LOGGER.exception("Unexpected exception")
|
|
||||||
errors["base"] = "unknown"
|
|
||||||
else:
|
|
||||||
# Create unique ID based on host and port
|
# Create unique ID based on host and port
|
||||||
unique_id = f"{user_input[CONF_HOST]}:{user_input[CONF_PORT]}"
|
unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
|
||||||
await self.async_set_unique_id(unique_id)
|
await self.async_set_unique_id(unique_id)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
@@ -76,11 +131,83 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
data=user_input,
|
data=user_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except CannotConnect:
|
||||||
|
errors["base"] = "cannot_connect"
|
||||||
|
except InvalidAuth:
|
||||||
|
errors["base"] = "invalid_auth"
|
||||||
|
except InvalidHost:
|
||||||
|
errors[CONF_HOST] = "invalid_host"
|
||||||
|
except InvalidPort:
|
||||||
|
errors[CONF_PORT] = "invalid_port"
|
||||||
|
except Timeout:
|
||||||
|
errors["base"] = "timeout"
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception("Unexpected exception during config flow")
|
||||||
|
errors["base"] = "unknown"
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user",
|
step_id="user",
|
||||||
data_schema=STEP_USER_DATA_SCHEMA,
|
data_schema=STEP_USER_DATA_SCHEMA,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult:
|
||||||
|
"""Handle configuration import."""
|
||||||
|
return await self.async_step_user(import_info)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def async_get_options_flow(config_entry):
|
||||||
|
"""Get the options flow for this handler."""
|
||||||
|
return OptionsFlowHandler(config_entry)
|
||||||
|
|
||||||
|
|
||||||
|
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
|
"""Handle options flow for AdGuard Control Hub."""
|
||||||
|
|
||||||
|
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
||||||
|
"""Initialize options flow."""
|
||||||
|
self.config_entry = config_entry
|
||||||
|
|
||||||
|
async def async_step_init(
|
||||||
|
self, user_input: Optional[Dict[str, Any]] = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle options flow."""
|
||||||
|
if user_input is not None:
|
||||||
|
return self.async_create_entry(title="", data=user_input)
|
||||||
|
|
||||||
|
options_schema = vol.Schema({
|
||||||
|
vol.Optional(
|
||||||
|
"scan_interval",
|
||||||
|
default=self.config_entry.options.get("scan_interval", 30),
|
||||||
|
): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)),
|
||||||
|
vol.Optional(
|
||||||
|
"timeout",
|
||||||
|
default=self.config_entry.options.get("timeout", 10),
|
||||||
|
): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)),
|
||||||
|
})
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="init",
|
||||||
|
data_schema=options_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Custom exceptions
|
||||||
class CannotConnect(Exception):
|
class CannotConnect(Exception):
|
||||||
"""Error to indicate we cannot connect."""
|
"""Error to indicate we cannot connect."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuth(Exception):
|
||||||
|
"""Error to indicate there is invalid auth."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidHost(Exception):
|
||||||
|
"""Error to indicate invalid host."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPort(Exception):
|
||||||
|
"""Error to indicate invalid port."""
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout(Exception):
|
||||||
|
"""Error to indicate connection timeout."""
|
||||||
|
@@ -17,7 +17,7 @@ SCAN_INTERVAL: Final = 30
|
|||||||
# Platforms
|
# Platforms
|
||||||
PLATFORMS: Final = [
|
PLATFORMS: Final = [
|
||||||
"switch",
|
"switch",
|
||||||
"binary_sensor",
|
"binary_sensor",
|
||||||
"sensor",
|
"sensor",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ API_ENDPOINTS: Final = {
|
|||||||
"status": "/control/status",
|
"status": "/control/status",
|
||||||
"clients": "/control/clients",
|
"clients": "/control/clients",
|
||||||
"clients_add": "/control/clients/add",
|
"clients_add": "/control/clients/add",
|
||||||
"clients_update": "/control/clients/update",
|
"clients_update": "/control/clients/update",
|
||||||
"clients_delete": "/control/clients/delete",
|
"clients_delete": "/control/clients/delete",
|
||||||
"blocked_services_all": "/control/blocked_services/all",
|
"blocked_services_all": "/control/blocked_services/all",
|
||||||
"blocked_services_get": "/control/blocked_services/get",
|
"blocked_services_get": "/control/blocked_services/get",
|
||||||
@@ -39,7 +39,7 @@ API_ENDPOINTS: Final = {
|
|||||||
BLOCKED_SERVICES: Final = {
|
BLOCKED_SERVICES: Final = {
|
||||||
# Social Media
|
# Social Media
|
||||||
"youtube": "YouTube",
|
"youtube": "YouTube",
|
||||||
"facebook": "Facebook",
|
"facebook": "Facebook",
|
||||||
"instagram": "Instagram",
|
"instagram": "Instagram",
|
||||||
"tiktok": "TikTok",
|
"tiktok": "TikTok",
|
||||||
"twitter": "Twitter/X",
|
"twitter": "Twitter/X",
|
||||||
@@ -62,7 +62,7 @@ BLOCKED_SERVICES: Final = {
|
|||||||
"amazon": "Amazon",
|
"amazon": "Amazon",
|
||||||
"ebay": "eBay",
|
"ebay": "eBay",
|
||||||
|
|
||||||
# Communication
|
# Communication
|
||||||
"whatsapp": "WhatsApp",
|
"whatsapp": "WhatsApp",
|
||||||
"telegram": "Telegram",
|
"telegram": "Telegram",
|
||||||
"discord": "Discord",
|
"discord": "Discord",
|
||||||
@@ -89,4 +89,4 @@ ICON_CLIENT: Final = "mdi:devices"
|
|||||||
ICON_CLIENT_OFFLINE: Final = "mdi:devices-off"
|
ICON_CLIENT_OFFLINE: Final = "mdi:devices-off"
|
||||||
ICON_BLOCKED_SERVICE: Final = "mdi:block-helper"
|
ICON_BLOCKED_SERVICE: Final = "mdi:block-helper"
|
||||||
ICON_ALLOWED_SERVICE: Final = "mdi:check-circle"
|
ICON_ALLOWED_SERVICE: Final = "mdi:check-circle"
|
||||||
ICON_STATISTICS: Final = "mdi:chart-line"
|
ICON_STATISTICS: Final = "mdi:chart-line"
|
||||||
|
@@ -1,14 +1,14 @@
|
|||||||
{
|
{
|
||||||
"domain": "adguard_hub",
|
"domain": "adguard_hub",
|
||||||
"name": "AdGuard Control Hub",
|
"name": "AdGuard Control Hub",
|
||||||
"codeowners": ["@sq4ind"],
|
"codeowners": ["@sq4ind"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
"dependencies": [],
|
"dependencies": [],
|
||||||
"documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub",
|
"documentation": "https://git.sq4ind.eu/sq4ind/adguard-control-hub",
|
||||||
"integration_type": "hub",
|
"integration_type": "hub",
|
||||||
"iot_class": "local_polling",
|
"iot_class": "local_polling",
|
||||||
"requirements": [
|
"requirements": [
|
||||||
"aiohttp>=3.8.0"
|
"aiohttp>=3.8.0"
|
||||||
],
|
],
|
||||||
"version": "1.0.0"
|
"version": "1.0.0"
|
||||||
}
|
}
|
185
custom_components/adguard_hub/sensor.py
Normal file
185
custom_components/adguard_hub/sensor.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""Sensor platform for AdGuard Control Hub integration."""
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import PERCENTAGE
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
|
from . import AdGuardControlHubCoordinator
|
||||||
|
from .api import AdGuardHomeAPI
|
||||||
|
from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up AdGuard Control Hub sensor platform."""
|
||||||
|
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||||
|
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||||
|
|
||||||
|
entities = [
|
||||||
|
AdGuardQueriesCounterSensor(coordinator, api),
|
||||||
|
AdGuardBlockedCounterSensor(coordinator, api),
|
||||||
|
AdGuardBlockingPercentageSensor(coordinator, api),
|
||||||
|
AdGuardRuleCountSensor(coordinator, api),
|
||||||
|
AdGuardClientCountSensor(coordinator, api),
|
||||||
|
AdGuardUpstreamAverageTimeSensor(coordinator, api),
|
||||||
|
]
|
||||||
|
|
||||||
|
async_add_entities(entities)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardBaseSensor(CoordinatorEntity, SensorEntity):
|
||||||
|
"""Base class for AdGuard sensors."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator)
|
||||||
|
self.api = api
|
||||||
|
self._attr_device_info = {
|
||||||
|
"identifiers": {(DOMAIN, f"{api.host}:{api.port}")},
|
||||||
|
"name": f"AdGuard Control Hub ({api.host})",
|
||||||
|
"manufacturer": MANUFACTURER,
|
||||||
|
"model": "AdGuard Home",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardQueriesCounterSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track DNS queries count."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_dns_queries"
|
||||||
|
self._attr_name = "AdGuard DNS Queries"
|
||||||
|
self._attr_icon = ICON_STATISTICS
|
||||||
|
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
|
||||||
|
self._attr_native_unit_of_measurement = "queries"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> int | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return stats.get("num_dns_queries", 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_state_attributes(self) -> dict[str, Any]:
|
||||||
|
"""Return additional state attributes."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return {
|
||||||
|
"queries_today": stats.get("num_dns_queries_today", 0),
|
||||||
|
"queries_blocked_today": stats.get("num_blocked_filtering_today", 0),
|
||||||
|
"last_updated": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardBlockedCounterSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track blocked queries count."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries"
|
||||||
|
self._attr_name = "AdGuard Blocked Queries"
|
||||||
|
self._attr_icon = "mdi:shield-check"
|
||||||
|
self._attr_state_class = SensorStateClass.TOTAL_INCREASING
|
||||||
|
self._attr_native_unit_of_measurement = "queries"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> int | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return stats.get("num_blocked_filtering", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track blocking percentage."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage"
|
||||||
|
self._attr_name = "AdGuard Blocking Percentage"
|
||||||
|
self._attr_icon = "mdi:percent"
|
||||||
|
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||||
|
self._attr_native_unit_of_measurement = PERCENTAGE
|
||||||
|
self._attr_device_class = SensorDeviceClass.POWER_FACTOR
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> float | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
total_queries = stats.get("num_dns_queries", 0)
|
||||||
|
blocked_queries = stats.get("num_blocked_filtering", 0)
|
||||||
|
|
||||||
|
if total_queries == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
percentage = (blocked_queries / total_queries) * 100
|
||||||
|
return round(percentage, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardRuleCountSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track filtering rules count."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_rules_count"
|
||||||
|
self._attr_name = "AdGuard Rules Count"
|
||||||
|
self._attr_icon = "mdi:format-list-numbered"
|
||||||
|
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||||
|
self._attr_native_unit_of_measurement = "rules"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> int | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return stats.get("filtering_rules_count", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardClientCountSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track active clients count."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_clients_count"
|
||||||
|
self._attr_name = "AdGuard Clients Count"
|
||||||
|
self._attr_icon = "mdi:account-multiple"
|
||||||
|
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||||
|
self._attr_native_unit_of_measurement = "clients"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> int | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
return len(self.coordinator.clients)
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardUpstreamAverageTimeSensor(AdGuardBaseSensor):
|
||||||
|
"""Sensor to track upstream servers average response time."""
|
||||||
|
|
||||||
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the sensor."""
|
||||||
|
super().__init__(coordinator, api)
|
||||||
|
self._attr_unique_id = f"{api.host}_{api.port}_upstream_response_time"
|
||||||
|
self._attr_name = "AdGuard Upstream Response Time"
|
||||||
|
self._attr_icon = "mdi:timer"
|
||||||
|
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||||
|
self._attr_native_unit_of_measurement = "ms"
|
||||||
|
self._attr_device_class = SensorDeviceClass.DURATION
|
||||||
|
|
||||||
|
@property
|
||||||
|
def native_value(self) -> float | None:
|
||||||
|
"""Return the state of the sensor."""
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return stats.get("avg_processing_time", 0)
|
@@ -1,38 +1,438 @@
|
|||||||
"""Services for AdGuard Control Hub integration."""
|
"""Service implementations for AdGuard Control Hub integration."""
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from homeassistant.core import HomeAssistant
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
from homeassistant.core import HomeAssistant, ServiceCall
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
|
||||||
from .api import AdGuardHomeAPI
|
from .api import AdGuardHomeAPI
|
||||||
|
from .const import (
|
||||||
|
DOMAIN,
|
||||||
|
BLOCKED_SERVICES,
|
||||||
|
ATTR_CLIENT_NAME,
|
||||||
|
ATTR_SERVICES,
|
||||||
|
ATTR_DURATION,
|
||||||
|
ATTR_CLIENTS,
|
||||||
|
ATTR_CLIENT_PATTERN,
|
||||||
|
ATTR_SETTINGS,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def async_register_services(hass: HomeAssistant, api: AdGuardHomeAPI) -> None:
|
# Service schemas
|
||||||
"""Register integration services."""
|
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||||
|
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||||
|
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
|
})
|
||||||
|
|
||||||
async def emergency_unblock_service(call):
|
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
|
||||||
"""Emergency unblock service."""
|
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||||
duration = call.data.get("duration", 300)
|
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
clients = call.data.get("clients", ["all"])
|
})
|
||||||
|
|
||||||
|
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
||||||
|
vol.Required(ATTR_DURATION): cv.positive_int,
|
||||||
|
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||||
|
})
|
||||||
|
|
||||||
|
SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
|
||||||
|
vol.Required(ATTR_CLIENT_PATTERN): cv.string,
|
||||||
|
vol.Required(ATTR_SETTINGS): vol.Schema({
|
||||||
|
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
|
vol.Optional("filtering_enabled"): cv.boolean,
|
||||||
|
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||||
|
vol.Optional("parental_enabled"): cv.boolean,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
SCHEMA_ADD_CLIENT = vol.Schema({
|
||||||
|
vol.Required("name"): cv.string,
|
||||||
|
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
|
||||||
|
vol.Optional("mac"): cv.string,
|
||||||
|
vol.Optional("use_global_settings"): cv.boolean,
|
||||||
|
vol.Optional("filtering_enabled"): cv.boolean,
|
||||||
|
vol.Optional("parental_enabled"): cv.boolean,
|
||||||
|
vol.Optional("safebrowsing_enabled"): cv.boolean,
|
||||||
|
vol.Optional("safesearch_enabled"): cv.boolean,
|
||||||
|
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
|
})
|
||||||
|
|
||||||
|
SCHEMA_REMOVE_CLIENT = vol.Schema({
|
||||||
|
vol.Required("name"): cv.string,
|
||||||
|
})
|
||||||
|
|
||||||
|
SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
|
||||||
|
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||||
|
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
|
vol.Required("schedule"): vol.Schema({
|
||||||
|
vol.Optional("time_zone", default="Local"): cv.string,
|
||||||
|
vol.Optional("sun"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("mon"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("tue"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("wed"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("thu"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("fri"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
vol.Optional("sat"): vol.Schema({
|
||||||
|
vol.Optional("start"): cv.string,
|
||||||
|
vol.Optional("end"): cv.string,
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Service names
|
||||||
|
SERVICE_BLOCK_SERVICES = "block_services"
|
||||||
|
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
||||||
|
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
|
||||||
|
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
||||||
|
SERVICE_ADD_CLIENT = "add_client"
|
||||||
|
SERVICE_REMOVE_CLIENT = "remove_client"
|
||||||
|
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
|
||||||
|
|
||||||
|
|
||||||
|
class AdGuardControlHubServices:
|
||||||
|
"""Handle services for AdGuard Control Hub."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant):
|
||||||
|
"""Initialize the services."""
|
||||||
|
self.hass = hass
|
||||||
|
self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
def register_services(self) -> None:
|
||||||
|
"""Register all services."""
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_BLOCK_SERVICES,
|
||||||
|
self.block_services,
|
||||||
|
schema=SCHEMA_BLOCK_SERVICES,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_UNBLOCK_SERVICES,
|
||||||
|
self.unblock_services,
|
||||||
|
schema=SCHEMA_UNBLOCK_SERVICES,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_EMERGENCY_UNBLOCK,
|
||||||
|
self.emergency_unblock,
|
||||||
|
schema=SCHEMA_EMERGENCY_UNBLOCK,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_BULK_UPDATE_CLIENTS,
|
||||||
|
self.bulk_update_clients,
|
||||||
|
schema=SCHEMA_BULK_UPDATE_CLIENTS,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_ADD_CLIENT,
|
||||||
|
self.add_client,
|
||||||
|
schema=SCHEMA_ADD_CLIENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_REMOVE_CLIENT,
|
||||||
|
self.remove_client,
|
||||||
|
schema=SCHEMA_REMOVE_CLIENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.services.register(
|
||||||
|
DOMAIN,
|
||||||
|
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||||
|
self.schedule_service_block,
|
||||||
|
schema=SCHEMA_SCHEDULE_SERVICE_BLOCK,
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister_services(self) -> None:
|
||||||
|
"""Unregister all services."""
|
||||||
|
services = [
|
||||||
|
SERVICE_BLOCK_SERVICES,
|
||||||
|
SERVICE_UNBLOCK_SERVICES,
|
||||||
|
SERVICE_EMERGENCY_UNBLOCK,
|
||||||
|
SERVICE_BULK_UPDATE_CLIENTS,
|
||||||
|
SERVICE_ADD_CLIENT,
|
||||||
|
SERVICE_REMOVE_CLIENT,
|
||||||
|
SERVICE_SCHEDULE_SERVICE_BLOCK,
|
||||||
|
]
|
||||||
|
|
||||||
|
for service in services:
|
||||||
|
self.hass.services.remove(DOMAIN, service)
|
||||||
|
|
||||||
|
def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
|
||||||
|
"""Get API instance for a specific config entry."""
|
||||||
|
return self.hass.data[DOMAIN][entry_id]["api"]
|
||||||
|
|
||||||
|
async def block_services(self, call: ServiceCall) -> None:
|
||||||
|
"""Block services for a specific client."""
|
||||||
|
client_name = call.data[ATTR_CLIENT_NAME]
|
||||||
|
services = call.data[ATTR_SERVICES]
|
||||||
|
|
||||||
|
_LOGGER.info("Blocking services %s for client %s", services, client_name)
|
||||||
|
|
||||||
|
# Get all API instances (for multiple AdGuard instances)
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
# Get current client data
|
||||||
|
client = await api.get_client_by_name(client_name)
|
||||||
|
if not client:
|
||||||
|
_LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get current blocked services
|
||||||
|
current_blocked = client.get("blocked_services", {})
|
||||||
|
if isinstance(current_blocked, dict):
|
||||||
|
current_services = current_blocked.get("ids", [])
|
||||||
|
else:
|
||||||
|
current_services = current_blocked if current_blocked else []
|
||||||
|
|
||||||
|
# Add new services to block
|
||||||
|
updated_services = list(set(current_services + services))
|
||||||
|
|
||||||
|
# Update client
|
||||||
|
await api.update_client_blocked_services(client_name, updated_services)
|
||||||
|
_LOGGER.info("Successfully blocked services for %s", client_name)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to block services for %s: %s", client_name, err)
|
||||||
|
|
||||||
|
async def unblock_services(self, call: ServiceCall) -> None:
|
||||||
|
"""Unblock services for a specific client."""
|
||||||
|
client_name = call.data[ATTR_CLIENT_NAME]
|
||||||
|
services = call.data[ATTR_SERVICES]
|
||||||
|
|
||||||
|
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
||||||
|
|
||||||
|
# Get all API instances
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
# Get current client data
|
||||||
|
client = await api.get_client_by_name(client_name)
|
||||||
|
if not client:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get current blocked services
|
||||||
|
current_blocked = client.get("blocked_services", {})
|
||||||
|
if isinstance(current_blocked, dict):
|
||||||
|
current_services = current_blocked.get("ids", [])
|
||||||
|
else:
|
||||||
|
current_services = current_blocked if current_blocked else []
|
||||||
|
|
||||||
|
# Remove services to unblock
|
||||||
|
updated_services = [s for s in current_services if s not in services]
|
||||||
|
|
||||||
|
# Update client
|
||||||
|
await api.update_client_blocked_services(client_name, updated_services)
|
||||||
|
_LOGGER.info("Successfully unblocked services for %s", client_name)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err)
|
||||||
|
|
||||||
|
async def emergency_unblock(self, call: ServiceCall) -> None:
|
||||||
|
"""Emergency unblock - temporarily disable protection."""
|
||||||
|
duration = call.data[ATTR_DURATION] # seconds
|
||||||
|
clients = call.data[ATTR_CLIENTS]
|
||||||
|
|
||||||
|
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
|
||||||
|
|
||||||
|
# Cancel any existing emergency unblock tasks
|
||||||
|
for task in self._emergency_unblock_tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
self._emergency_unblock_tasks.clear()
|
||||||
|
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
if "all" in clients:
|
||||||
|
# Disable global protection
|
||||||
|
await api.set_protection(False)
|
||||||
|
|
||||||
|
# Schedule re-enable
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._delayed_enable_protection(api, duration)
|
||||||
|
)
|
||||||
|
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
|
||||||
|
else:
|
||||||
|
# Disable protection for specific clients
|
||||||
|
for client_name in clients:
|
||||||
|
client = await api.get_client_by_name(client_name)
|
||||||
|
if client:
|
||||||
|
# Store original blocked services
|
||||||
|
original_blocked = client.get("blocked_services", {})
|
||||||
|
|
||||||
|
# Clear blocked services temporarily
|
||||||
|
await api.update_client_blocked_services(client_name, [])
|
||||||
|
|
||||||
|
# Schedule restore
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._delayed_restore_client(api, client_name, original_blocked, duration)
|
||||||
|
)
|
||||||
|
self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||||
|
|
||||||
|
async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
|
||||||
|
"""Re-enable protection after delay."""
|
||||||
|
await asyncio.sleep(delay)
|
||||||
try:
|
try:
|
||||||
if "all" in clients:
|
await api.set_protection(True)
|
||||||
await api.set_protection(False)
|
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
||||||
_LOGGER.info("Emergency unblock activated globally for %d seconds", duration)
|
|
||||||
else:
|
|
||||||
_LOGGER.info("Emergency unblock activated for clients: %s", clients)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||||
raise
|
|
||||||
|
|
||||||
# Register emergency unblock service
|
async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str,
|
||||||
hass.services.async_register(
|
original_blocked: Dict, delay: int) -> None:
|
||||||
"adguard_hub",
|
"""Restore client blocked services after delay."""
|
||||||
"emergency_unblock",
|
await asyncio.sleep(delay)
|
||||||
emergency_unblock_service
|
try:
|
||||||
)
|
if isinstance(original_blocked, dict):
|
||||||
|
services = original_blocked.get("ids", [])
|
||||||
|
else:
|
||||||
|
services = original_blocked if original_blocked else []
|
||||||
|
|
||||||
_LOGGER.info("AdGuard Control Hub services registered")
|
await api.update_client_blocked_services(client_name, services)
|
||||||
|
_LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to restore client blocking: %s", err)
|
||||||
|
|
||||||
async def async_unregister_services(hass: HomeAssistant) -> None:
|
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||||
"""Unregister integration services."""
|
"""Update multiple clients matching a pattern."""
|
||||||
hass.services.async_remove("adguard_hub", "emergency_unblock")
|
import re
|
||||||
_LOGGER.info("AdGuard Control Hub services unregistered")
|
|
||||||
|
pattern = call.data[ATTR_CLIENT_PATTERN]
|
||||||
|
settings = call.data[ATTR_SETTINGS]
|
||||||
|
|
||||||
|
_LOGGER.info("Bulk updating clients matching pattern: %s", pattern)
|
||||||
|
|
||||||
|
# Convert pattern to regex
|
||||||
|
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
|
||||||
|
compiled_pattern = re.compile(regex_pattern, re.IGNORECASE)
|
||||||
|
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
coordinator = entry_data["coordinator"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get all clients
|
||||||
|
clients = coordinator.clients
|
||||||
|
|
||||||
|
matching_clients = []
|
||||||
|
for client_name in clients.keys():
|
||||||
|
if compiled_pattern.match(client_name):
|
||||||
|
matching_clients.append(client_name)
|
||||||
|
|
||||||
|
_LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)
|
||||||
|
|
||||||
|
# Update each matching client
|
||||||
|
for client_name in matching_clients:
|
||||||
|
client = clients[client_name]
|
||||||
|
|
||||||
|
# Prepare update data
|
||||||
|
update_data = {
|
||||||
|
"name": client_name,
|
||||||
|
"data": {**client} # Start with current data
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply settings
|
||||||
|
if "blocked_services" in settings:
|
||||||
|
blocked_services_data = {
|
||||||
|
"ids": settings["blocked_services"],
|
||||||
|
"schedule": {"time_zone": "Local"}
|
||||||
|
}
|
||||||
|
update_data["data"]["blocked_services"] = blocked_services_data
|
||||||
|
|
||||||
|
if "filtering_enabled" in settings:
|
||||||
|
update_data["data"]["filtering_enabled"] = settings["filtering_enabled"]
|
||||||
|
|
||||||
|
if "safebrowsing_enabled" in settings:
|
||||||
|
update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"]
|
||||||
|
|
||||||
|
if "parental_enabled" in settings:
|
||||||
|
update_data["data"]["parental_enabled"] = settings["parental_enabled"]
|
||||||
|
|
||||||
|
# Update the client
|
||||||
|
await api.update_client(update_data)
|
||||||
|
_LOGGER.info("Updated client: %s", client_name)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to bulk update clients: %s", err)
|
||||||
|
|
||||||
|
async def add_client(self, call: ServiceCall) -> None:
|
||||||
|
"""Add a new client."""
|
||||||
|
client_data = dict(call.data)
|
||||||
|
|
||||||
|
# Convert blocked_services to proper format
|
||||||
|
if "blocked_services" in client_data and client_data["blocked_services"]:
|
||||||
|
blocked_services_data = {
|
||||||
|
"ids": client_data["blocked_services"],
|
||||||
|
"schedule": {"time_zone": "Local"}
|
||||||
|
}
|
||||||
|
client_data["blocked_services"] = blocked_services_data
|
||||||
|
|
||||||
|
_LOGGER.info("Adding new client: %s", client_data["name"])
|
||||||
|
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
await api.add_client(client_data)
|
||||||
|
_LOGGER.info("Successfully added client: %s", client_data["name"])
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to add client %s: %s", client_data["name"], err)
|
||||||
|
|
||||||
|
async def remove_client(self, call: ServiceCall) -> None:
|
||||||
|
"""Remove a client."""
|
||||||
|
client_name = call.data["name"]
|
||||||
|
|
||||||
|
_LOGGER.info("Removing client: %s", client_name)
|
||||||
|
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
await api.delete_client(client_name)
|
||||||
|
_LOGGER.info("Successfully removed client: %s", client_name)
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
|
||||||
|
|
||||||
|
async def schedule_service_block(self, call: ServiceCall) -> None:
|
||||||
|
"""Schedule service blocking with time-based rules."""
|
||||||
|
client_name = call.data[ATTR_CLIENT_NAME]
|
||||||
|
services = call.data[ATTR_SERVICES]
|
||||||
|
schedule = call.data["schedule"]
|
||||||
|
|
||||||
|
_LOGGER.info("Scheduling service blocking for client %s", client_name)
|
||||||
|
|
||||||
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
|
try:
|
||||||
|
await api.update_client_blocked_services(client_name, services, schedule)
|
||||||
|
_LOGGER.info("Successfully scheduled service blocking for %s", client_name)
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)
|
||||||
|
@@ -1,27 +1,151 @@
|
|||||||
{
|
{
|
||||||
"config": {
|
"config": {
|
||||||
"step": {
|
"step": {
|
||||||
"user": {
|
"user": {
|
||||||
"title": "AdGuard Control Hub",
|
"title": "AdGuard Control Hub",
|
||||||
"description": "Connect to your AdGuard Home instance for complete network control",
|
"description": "Configure your AdGuard Home connection",
|
||||||
"data": {
|
"data": {
|
||||||
"host": "AdGuard Home IP Address",
|
"host": "Host",
|
||||||
"port": "Port (usually 3000)",
|
"port": "Port",
|
||||||
"username": "Admin Username",
|
"username": "Username",
|
||||||
"password": "Admin Password",
|
"password": "Password",
|
||||||
"ssl": "Use HTTPS connection",
|
"ssl": "Use SSL",
|
||||||
"verify_ssl": "Verify SSL certificate"
|
"verify_ssl": "Verify SSL Certificate"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.",
|
||||||
|
"invalid_auth": "Invalid username or password",
|
||||||
|
"timeout": "Connection timeout. Please check your network connection.",
|
||||||
|
"unknown": "An unexpected error occurred"
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"already_configured": "AdGuard Control Hub is already configured for this host and port"
|
||||||
}
|
}
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"error": {
|
"options": {
|
||||||
"cannot_connect": "Failed to connect to AdGuard Home. Check IP address, port, and credentials.",
|
"step": {
|
||||||
"invalid_auth": "Invalid username or password. Please check your admin credentials.",
|
"init": {
|
||||||
"unknown": "Unexpected error occurred. Please check logs for details."
|
"title": "AdGuard Control Hub Options",
|
||||||
|
"description": "Configure advanced options",
|
||||||
|
"data": {
|
||||||
|
"scan_interval": "Update interval (seconds)",
|
||||||
|
"timeout": "Connection timeout (seconds)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"abort": {
|
"services": {
|
||||||
"already_configured": "This AdGuard Home instance is already configured",
|
"block_services": {
|
||||||
"cannot_connect": "Cannot connect to AdGuard Home"
|
"name": "Block Services",
|
||||||
|
"description": "Block specific services for a client",
|
||||||
|
"fields": {
|
||||||
|
"client_name": {
|
||||||
|
"name": "Client Name",
|
||||||
|
"description": "Name of the client to block services for"
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"name": "Services",
|
||||||
|
"description": "List of services to block"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"unblock_services": {
|
||||||
|
"name": "Unblock Services",
|
||||||
|
"description": "Unblock specific services for a client",
|
||||||
|
"fields": {
|
||||||
|
"client_name": {
|
||||||
|
"name": "Client Name",
|
||||||
|
"description": "Name of the client to unblock services for"
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"name": "Services",
|
||||||
|
"description": "List of services to unblock"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"emergency_unblock": {
|
||||||
|
"name": "Emergency Unblock",
|
||||||
|
"description": "Temporarily disable blocking for emergency access",
|
||||||
|
"fields": {
|
||||||
|
"duration": {
|
||||||
|
"name": "Duration",
|
||||||
|
"description": "Duration in seconds to keep unblocked"
|
||||||
|
},
|
||||||
|
"clients": {
|
||||||
|
"name": "Clients",
|
||||||
|
"description": "List of client names (use 'all' for global)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"bulk_update_clients": {
|
||||||
|
"name": "Bulk Update Clients",
|
||||||
|
"description": "Update multiple clients matching a pattern",
|
||||||
|
"fields": {
|
||||||
|
"client_pattern": {
|
||||||
|
"name": "Client Pattern",
|
||||||
|
"description": "Pattern to match client names (supports wildcards)"
|
||||||
|
},
|
||||||
|
"settings": {
|
||||||
|
"name": "Settings",
|
||||||
|
"description": "Settings to apply to matching clients"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"add_client": {
|
||||||
|
"name": "Add Client",
|
||||||
|
"description": "Add a new client configuration",
|
||||||
|
"fields": {
|
||||||
|
"name": {
|
||||||
|
"name": "Name",
|
||||||
|
"description": "Client name"
|
||||||
|
},
|
||||||
|
"ids": {
|
||||||
|
"name": "IDs",
|
||||||
|
"description": "List of IP addresses or CIDR ranges"
|
||||||
|
},
|
||||||
|
"mac": {
|
||||||
|
"name": "MAC Address",
|
||||||
|
"description": "MAC address (optional)"
|
||||||
|
},
|
||||||
|
"filtering_enabled": {
|
||||||
|
"name": "Filtering Enabled",
|
||||||
|
"description": "Enable DNS filtering for this client"
|
||||||
|
},
|
||||||
|
"blocked_services": {
|
||||||
|
"name": "Blocked Services",
|
||||||
|
"description": "List of services to block"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"remove_client": {
|
||||||
|
"name": "Remove Client",
|
||||||
|
"description": "Remove a client configuration",
|
||||||
|
"fields": {
|
||||||
|
"name": {
|
||||||
|
"name": "Name",
|
||||||
|
"description": "Name of the client to remove"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"schedule_service_block": {
|
||||||
|
"name": "Schedule Service Block",
|
||||||
|
"description": "Schedule time-based service blocking",
|
||||||
|
"fields": {
|
||||||
|
"client_name": {
|
||||||
|
"name": "Client Name",
|
||||||
|
"description": "Name of the client"
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"name": "Services",
|
||||||
|
"description": "List of services to block"
|
||||||
|
},
|
||||||
|
"schedule": {
|
||||||
|
"name": "Schedule",
|
||||||
|
"description": "Time-based schedule configuration"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
@@ -1,10 +1,13 @@
|
|||||||
"""Switch platform for AdGuard Control Hub integration."""
|
"""Switch platform for AdGuard Control Hub integration."""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.switch import SwitchEntity
|
from homeassistant.components.switch import SwitchEntity
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||||
|
|
||||||
from . import AdGuardControlHubCoordinator
|
from . import AdGuardControlHubCoordinator
|
||||||
from .api import AdGuardHomeAPI
|
from .api import AdGuardHomeAPI
|
||||||
from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER
|
from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER
|
||||||
@@ -12,7 +15,11 @@ from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MA
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback):
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
"""Set up AdGuard Control Hub switch platform."""
|
"""Set up AdGuard Control Hub switch platform."""
|
||||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||||
@@ -32,6 +39,7 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity):
|
|||||||
"""Base class for AdGuard switches."""
|
"""Base class for AdGuard switches."""
|
||||||
|
|
||||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the switch."""
|
||||||
super().__init__(coordinator)
|
super().__init__(coordinator)
|
||||||
self.api = api
|
self.api = api
|
||||||
self._attr_device_info = {
|
self._attr_device_info = {
|
||||||
@@ -46,31 +54,64 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
|||||||
"""Switch to control global AdGuard protection."""
|
"""Switch to control global AdGuard protection."""
|
||||||
|
|
||||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
|
||||||
|
"""Initialize the switch."""
|
||||||
super().__init__(coordinator, api)
|
super().__init__(coordinator, api)
|
||||||
self._attr_unique_id = f"{api.host}_{api.port}_protection"
|
self._attr_unique_id = f"{api.host}_{api.port}_protection"
|
||||||
self._attr_name = "AdGuard Protection"
|
self._attr_name = "AdGuard Protection"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_on(self) -> bool:
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if protection is enabled."""
|
||||||
return self.coordinator.protection_status.get("protection_enabled", False)
|
return self.coordinator.protection_status.get("protection_enabled", False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def icon(self) -> str:
|
def icon(self) -> str:
|
||||||
|
"""Return the icon for the switch."""
|
||||||
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
|
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
|
||||||
|
|
||||||
async def async_turn_on(self, **kwargs):
|
@property
|
||||||
await self.api.set_protection(True)
|
def extra_state_attributes(self) -> dict[str, Any]:
|
||||||
await self.coordinator.async_request_refresh()
|
"""Return additional state attributes."""
|
||||||
|
status = self.coordinator.protection_status
|
||||||
|
stats = self.coordinator.statistics
|
||||||
|
return {
|
||||||
|
"dns_port": status.get("dns_port", "N/A"),
|
||||||
|
"queries_today": stats.get("num_dns_queries_today", 0),
|
||||||
|
"blocked_today": stats.get("num_blocked_filtering_today", 0),
|
||||||
|
"version": status.get("version", "N/A"),
|
||||||
|
}
|
||||||
|
|
||||||
async def async_turn_off(self, **kwargs):
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
await self.api.set_protection(False)
|
"""Turn on AdGuard protection."""
|
||||||
await self.coordinator.async_request_refresh()
|
try:
|
||||||
|
await self.api.set_protection(True)
|
||||||
|
await self.coordinator.async_request_refresh()
|
||||||
|
_LOGGER.info("AdGuard protection enabled")
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||||
|
"""Turn off AdGuard protection."""
|
||||||
|
try:
|
||||||
|
await self.api.set_protection(False)
|
||||||
|
await self.coordinator.async_request_refresh()
|
||||||
|
_LOGGER.info("AdGuard protection disabled")
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class AdGuardClientSwitch(AdGuardBaseSwitch):
|
class AdGuardClientSwitch(AdGuardBaseSwitch):
|
||||||
"""Switch to control client-specific protection."""
|
"""Switch to control client-specific protection."""
|
||||||
|
|
||||||
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI, client_name: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
coordinator: AdGuardControlHubCoordinator,
|
||||||
|
api: AdGuardHomeAPI,
|
||||||
|
client_name: str,
|
||||||
|
):
|
||||||
|
"""Initialize the switch."""
|
||||||
super().__init__(coordinator, api)
|
super().__init__(coordinator, api)
|
||||||
self.client_name = client_name
|
self.client_name = client_name
|
||||||
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}"
|
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}"
|
||||||
@@ -78,16 +119,81 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
|||||||
self._attr_icon = ICON_CLIENT
|
self._attr_icon = ICON_CLIENT
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_on(self) -> bool:
|
def is_on(self) -> bool | None:
|
||||||
|
"""Return true if client protection is enabled."""
|
||||||
client = self.coordinator.clients.get(self.client_name, {})
|
client = self.coordinator.clients.get(self.client_name, {})
|
||||||
return client.get("filtering_enabled", True)
|
return client.get("filtering_enabled", True)
|
||||||
|
|
||||||
async def async_turn_on(self, **kwargs):
|
@property
|
||||||
# This would update client settings - simplified for basic functionality
|
def extra_state_attributes(self) -> dict[str, Any]:
|
||||||
_LOGGER.info("Would enable protection for %s", self.client_name)
|
"""Return additional state attributes."""
|
||||||
await self.coordinator.async_request_refresh()
|
client = self.coordinator.clients.get(self.client_name, {})
|
||||||
|
blocked_services = client.get("blocked_services", {})
|
||||||
|
|
||||||
async def async_turn_off(self, **kwargs):
|
if isinstance(blocked_services, dict):
|
||||||
# This would update client settings - simplified for basic functionality
|
service_ids = blocked_services.get("ids", [])
|
||||||
_LOGGER.info("Would disable protection for %s", self.client_name)
|
else:
|
||||||
await self.coordinator.async_request_refresh()
|
service_ids = blocked_services if blocked_services else []
|
||||||
|
|
||||||
|
return {
|
||||||
|
"client_ids": client.get("ids", []),
|
||||||
|
"mac": client.get("mac", ""),
|
||||||
|
"use_global_settings": client.get("use_global_settings", True),
|
||||||
|
"safebrowsing_enabled": client.get("safebrowsing_enabled", False),
|
||||||
|
"parental_enabled": client.get("parental_enabled", False),
|
||||||
|
"safesearch_enabled": client.get("safesearch_enabled", False),
|
||||||
|
"blocked_services": service_ids,
|
||||||
|
"blocked_services_count": len(service_ids),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
|
"""Enable protection for this client."""
|
||||||
|
try:
|
||||||
|
# Get current client data
|
||||||
|
client = await self.api.get_client_by_name(self.client_name)
|
||||||
|
if not client:
|
||||||
|
_LOGGER.error("Client %s not found", self.client_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update client with filtering enabled
|
||||||
|
update_data = {
|
||||||
|
"name": self.client_name,
|
||||||
|
"data": {
|
||||||
|
**client,
|
||||||
|
"filtering_enabled": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.api.update_client(update_data)
|
||||||
|
await self.coordinator.async_request_refresh()
|
||||||
|
_LOGGER.info("Enabled protection for client %s", self.client_name)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||||
|
"""Disable protection for this client."""
|
||||||
|
try:
|
||||||
|
# Get current client data
|
||||||
|
client = await self.api.get_client_by_name(self.client_name)
|
||||||
|
if not client:
|
||||||
|
_LOGGER.error("Client %s not found", self.client_name)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Update client with filtering disabled
|
||||||
|
update_data = {
|
||||||
|
"name": self.client_name,
|
||||||
|
"data": {
|
||||||
|
**client,
|
||||||
|
"filtering_enabled": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.api.update_client(update_data)
|
||||||
|
await self.coordinator.async_request_refresh()
|
||||||
|
_LOGGER.info("Disabled protection for client %s", self.client_name)
|
||||||
|
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
|
||||||
|
raise
|
||||||
|
@@ -3,6 +3,7 @@ import pytest
|
|||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_session():
|
def mock_session():
|
||||||
"""Mock aiohttp session."""
|
"""Mock aiohttp session."""
|
||||||
@@ -15,6 +16,7 @@ def mock_session():
|
|||||||
session.request = AsyncMock(return_value=response)
|
session.request = AsyncMock(return_value=response)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def test_api_connection(mock_session):
|
async def test_api_connection(mock_session):
|
||||||
"""Test API connection."""
|
"""Test API connection."""
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
@@ -28,13 +30,14 @@ async def test_api_connection(mock_session):
|
|||||||
result = await api.test_connection()
|
result = await api.test_connection()
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
async def test_api_get_status(mock_session):
|
async def test_api_get_status(mock_session):
|
||||||
"""Test getting status."""
|
"""Test getting status."""
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
host="test-host",
|
host="test-host",
|
||||||
port=3000,
|
port=3000,
|
||||||
session=mock_session
|
session=mock_session
|
||||||
)
|
)
|
||||||
|
|
||||||
status = await api.get_status()
|
status = await api.get_status()
|
||||||
assert status == {"status": "ok"}
|
assert status == {"status": "ok"}
|
||||||
|
223
tests/test_integration.py
Normal file
223
tests/test_integration.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""Test the complete AdGuard Control Hub integration."""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
|
||||||
|
|
||||||
|
from custom_components.adguard_hub import async_setup_entry, async_unload_entry
|
||||||
|
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
||||||
|
from custom_components.adguard_hub.const import DOMAIN
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_config_entry():
|
||||||
|
"""Mock config entry."""
|
||||||
|
return ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title="Test AdGuard",
|
||||||
|
data={
|
||||||
|
CONF_HOST: "192.168.1.100",
|
||||||
|
CONF_PORT: 3000,
|
||||||
|
CONF_USERNAME: "admin",
|
||||||
|
CONF_PASSWORD: "password",
|
||||||
|
},
|
||||||
|
source="user",
|
||||||
|
entry_id="test_entry_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_api():
|
||||||
|
"""Mock API instance."""
|
||||||
|
api = MagicMock(spec=AdGuardHomeAPI)
|
||||||
|
api.host = "192.168.1.100"
|
||||||
|
api.port = 3000
|
||||||
|
api.test_connection = AsyncMock(return_value=True)
|
||||||
|
api.get_status = AsyncMock(return_value={
|
||||||
|
"protection_enabled": True,
|
||||||
|
"version": "v0.107.0",
|
||||||
|
"dns_port": 53,
|
||||||
|
"running": True,
|
||||||
|
})
|
||||||
|
api.get_clients = AsyncMock(return_value={
|
||||||
|
"clients": [
|
||||||
|
{
|
||||||
|
"name": "test_client",
|
||||||
|
"ids": ["192.168.1.50"],
|
||||||
|
"filtering_enabled": True,
|
||||||
|
"blocked_services": {"ids": ["youtube"]},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
api.get_statistics = AsyncMock(return_value={
|
||||||
|
"num_dns_queries": 1000,
|
||||||
|
"num_blocked_filtering": 100,
|
||||||
|
"num_dns_queries_today": 500,
|
||||||
|
"num_blocked_filtering_today": 50,
|
||||||
|
"filtering_rules_count": 50000,
|
||||||
|
"avg_processing_time": 2.5,
|
||||||
|
})
|
||||||
|
return api
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_entry_success(hass: HomeAssistant, mock_config_entry, mock_api):
|
||||||
|
"""Test successful setup of config entry."""
|
||||||
|
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
||||||
|
patch("custom_components.adguard_hub.async_get_clientsession"), \
|
||||||
|
patch.object(hass.config_entries, "async_forward_entry_setups", return_value=True):
|
||||||
|
|
||||||
|
result = await async_setup_entry(hass, mock_config_entry)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert DOMAIN in hass.data
|
||||||
|
assert mock_config_entry.entry_id in hass.data[DOMAIN]
|
||||||
|
assert "coordinator" in hass.data[DOMAIN][mock_config_entry.entry_id]
|
||||||
|
assert "api" in hass.data[DOMAIN][mock_config_entry.entry_id]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_entry_connection_failure(hass: HomeAssistant, mock_config_entry):
|
||||||
|
"""Test setup failure due to connection error."""
|
||||||
|
mock_api = MagicMock(spec=AdGuardHomeAPI)
|
||||||
|
mock_api.test_connection = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
||||||
|
patch("custom_components.adguard_hub.async_get_clientsession"), \
|
||||||
|
pytest.raises(Exception): # Should raise ConfigEntryNotReady
|
||||||
|
|
||||||
|
await async_setup_entry(hass, mock_config_entry)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unload_entry(hass: HomeAssistant, mock_config_entry):
|
||||||
|
"""Test unloading of config entry."""
|
||||||
|
# Set up initial data
|
||||||
|
hass.data[DOMAIN] = {
|
||||||
|
mock_config_entry.entry_id: {
|
||||||
|
"coordinator": MagicMock(),
|
||||||
|
"api": MagicMock(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(hass.config_entries, "async_unload_platforms", return_value=True):
|
||||||
|
result = await async_unload_entry(hass, mock_config_entry)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert mock_config_entry.entry_id not in hass.data[DOMAIN]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_coordinator_data_update(hass: HomeAssistant, mock_api):
|
||||||
|
"""Test coordinator data update functionality."""
|
||||||
|
from custom_components.adguard_hub import AdGuardControlHubCoordinator
|
||||||
|
|
||||||
|
coordinator = AdGuardControlHubCoordinator(hass, mock_api)
|
||||||
|
|
||||||
|
# Test successful data update
|
||||||
|
data = await coordinator._async_update_data()
|
||||||
|
|
||||||
|
assert "clients" in data
|
||||||
|
assert "statistics" in data
|
||||||
|
assert "status" in data
|
||||||
|
assert "test_client" in data["clients"]
|
||||||
|
assert data["statistics"]["num_dns_queries"] == 1000
|
||||||
|
assert data["status"]["protection_enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_error_handling(mock_api):
|
||||||
|
"""Test API error handling."""
|
||||||
|
from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError
|
||||||
|
|
||||||
|
# Test connection error
|
||||||
|
mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
|
||||||
|
|
||||||
|
with pytest.raises(AdGuardConnectionError):
|
||||||
|
await mock_api.get_status()
|
||||||
|
|
||||||
|
# Test auth error
|
||||||
|
mock_api.get_clients = AsyncMock(side_effect=AdGuardAuthError("Auth failed"))
|
||||||
|
|
||||||
|
with pytest.raises(AdGuardAuthError):
|
||||||
|
await mock_api.get_clients()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_services_registration(hass: HomeAssistant):
|
||||||
|
"""Test that services are properly registered."""
|
||||||
|
from custom_components.adguard_hub.services import AdGuardControlHubServices
|
||||||
|
|
||||||
|
services = AdGuardControlHubServices(hass)
|
||||||
|
services.register_services()
|
||||||
|
|
||||||
|
# Check that services are registered
|
||||||
|
assert hass.services.has_service(DOMAIN, "block_services")
|
||||||
|
assert hass.services.has_service(DOMAIN, "unblock_services")
|
||||||
|
assert hass.services.has_service(DOMAIN, "emergency_unblock")
|
||||||
|
assert hass.services.has_service(DOMAIN, "bulk_update_clients")
|
||||||
|
assert hass.services.has_service(DOMAIN, "add_client")
|
||||||
|
assert hass.services.has_service(DOMAIN, "remove_client")
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
services.unregister_services()
|
||||||
|
|
||||||
|
|
||||||
|
def test_blocked_services_constants():
|
||||||
|
"""Test that blocked services are properly defined."""
|
||||||
|
from custom_components.adguard_hub.const import BLOCKED_SERVICES
|
||||||
|
|
||||||
|
assert "youtube" in BLOCKED_SERVICES
|
||||||
|
assert "netflix" in BLOCKED_SERVICES
|
||||||
|
assert "gaming" in BLOCKED_SERVICES
|
||||||
|
assert "facebook" in BLOCKED_SERVICES
|
||||||
|
|
||||||
|
# Check friendly names are defined
|
||||||
|
assert BLOCKED_SERVICES["youtube"] == "YouTube"
|
||||||
|
assert BLOCKED_SERVICES["netflix"] == "Netflix"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_endpoints():
|
||||||
|
"""Test that API endpoints are properly defined."""
|
||||||
|
from custom_components.adguard_hub.const import API_ENDPOINTS
|
||||||
|
|
||||||
|
required_endpoints = [
|
||||||
|
"status", "clients", "stats", "protection",
|
||||||
|
"clients_add", "clients_update", "clients_delete"
|
||||||
|
]
|
||||||
|
|
||||||
|
for endpoint in required_endpoints:
|
||||||
|
assert endpoint in API_ENDPOINTS
|
||||||
|
assert API_ENDPOINTS[endpoint].startswith("/")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_operations(mock_api):
|
||||||
|
"""Test client add/update/delete operations."""
|
||||||
|
# Test add client
|
||||||
|
client_data = {
|
||||||
|
"name": "new_client",
|
||||||
|
"ids": ["192.168.1.200"],
|
||||||
|
"filtering_enabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_api.add_client = AsyncMock(return_value={"success": True})
|
||||||
|
result = await mock_api.add_client(client_data)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
# Test update client
|
||||||
|
update_data = {
|
||||||
|
"name": "new_client",
|
||||||
|
"data": {"filtering_enabled": False}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_api.update_client = AsyncMock(return_value={"success": True})
|
||||||
|
result = await mock_api.update_client(update_data)
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
# Test delete client
|
||||||
|
mock_api.delete_client = AsyncMock(return_value={"success": True})
|
||||||
|
result = await mock_api.delete_client("new_client")
|
||||||
|
assert result["success"] is True
|
Reference in New Issue
Block a user