"""API wrapper for AdGuard Home.""" import logging from typing import Any import aiohttp from aiohttp import BasicAuth from .const import API_ENDPOINTS _LOGGER = logging.getLogger(__name__) class AdGuardHomeAPI: """API wrapper for AdGuard Home.""" def __init__(self, host: str, port: int = 3000, username: str = None, password: str = None, ssl: bool = False, session = None): self.host = host self.port = port self.username = username self.password = password self.ssl = ssl self.session = session protocol = "https" if ssl else "http" self.base_url = f"{protocol}://{host}:{port}" async def _request(self, method: str, endpoint: str, data: dict = None) -> dict: """Make an API request.""" url = f"{self.base_url}{endpoint}" headers = {"Content-Type": "application/json"} auth = None if self.username and self.password: auth = BasicAuth(self.username, self.password) try: async with self.session.request(method, url, json=data, headers=headers, auth=auth) as response: response.raise_for_status() if response.status == 204 or not response.content_length: return {} return await response.json() except Exception as err: _LOGGER.error("Error communicating with AdGuard Home: %s", err) raise async def test_connection(self) -> bool: """Test the connection.""" try: await self._request("GET", API_ENDPOINTS["status"]) return True except: return False async def get_status(self) -> dict: """Get server status.""" return await self._request("GET", API_ENDPOINTS["status"]) async def get_clients(self) -> dict: """Get all clients.""" return await self._request("GET", API_ENDPOINTS["clients"]) async def get_statistics(self) -> dict: """Get statistics.""" return await self._request("GET", API_ENDPOINTS["stats"]) async def set_protection(self, enabled: bool) -> dict: """Enable or disable protection.""" data = {"enabled": enabled} return await self._request("POST", API_ENDPOINTS["protection"], data) async def add_client(self, client_data: dict) -> dict: """Add a new client.""" return await self._request("POST", API_ENDPOINTS["clients_add"], client_data) async def update_client(self, client_data: dict) -> dict: """Update an existing client.""" return await self._request("POST", API_ENDPOINTS["clients_update"], client_data) async def delete_client(self, client_name: str) -> dict: """Delete a client.""" data = {"name": client_name} return await self._request("POST", API_ENDPOINTS["clients_delete"], data) async def get_client_by_name(self, client_name: str) -> dict: """Get a specific client by name.""" clients_data = await self.get_clients() clients = clients_data.get("clients", []) for client in clients: if client.get("name") == client_name: return client return None async def update_client_blocked_services(self, client_name: str, blocked_services: list, schedule: dict = None) -> dict: """Update blocked services for a specific client.""" client = await self.get_client_by_name(client_name) if not client: raise ValueError(f"Client '{client_name}' not found") # Prepare the blocked services data if schedule: blocked_services_data = { "ids": blocked_services, "schedule": schedule } else: blocked_services_data = { "ids": blocked_services, "schedule": { "time_zone": "Local" } } # Update the client update_data = { "name": client_name, "data": { **client, "blocked_services": blocked_services_data } } return await self.update_client(update_data) async def toggle_client_service(self, client_name: str, service_id: str, enabled: bool) -> dict: """Toggle a specific service for a client.""" client = await self.get_client_by_name(client_name) if not client: raise ValueError(f"Client '{client_name}' not found") # Get current blocked services blocked_services = client.get("blocked_services", {}) if isinstance(blocked_services, dict): service_ids = blocked_services.get("ids", []) else: # Handle old format (list) service_ids = blocked_services if blocked_services else [] # Update the service list if enabled and service_id not in service_ids: service_ids.append(service_id) elif not enabled and service_id in service_ids: service_ids.remove(service_id) return await self.update_client_blocked_services(client_name, service_ids)