"""API wrapper for AdGuard Home.""" import asyncio import logging from typing import Any, Dict, Optional import aiohttp from aiohttp import BasicAuth, ClientError, ClientTimeout from .const import API_ENDPOINTS _LOGGER = logging.getLogger(__name__) class AdGuardHomeError(Exception): """Base exception for AdGuard Home API.""" class AdGuardConnectionError(AdGuardHomeError): """Exception for connection errors.""" class AdGuardAuthError(AdGuardHomeError): """Exception for authentication errors.""" class AdGuardNotFoundError(AdGuardHomeError): """Exception for not found errors.""" class AdGuardTimeoutError(AdGuardHomeError): """Exception for timeout errors.""" class AdGuardHomeAPI: """API wrapper for AdGuard Home.""" def __init__( self, host: str, port: int = 3000, username: Optional[str] = None, password: Optional[str] = None, ssl: bool = False, session: Optional[aiohttp.ClientSession] = None, timeout: int = 10, verify_ssl: bool = True, ) -> None: """Initialize the API wrapper.""" self.host = host self.port = port self.username = username self.password = password self.ssl = ssl self.verify_ssl = verify_ssl self._session = session self._timeout = ClientTimeout(total=timeout) protocol = "https" if ssl else "http" self.base_url = f"{protocol}://{host}:{port}" self._own_session = session is None async def __aenter__(self): """Async context manager entry.""" if self._own_session: connector = aiohttp.TCPConnector(ssl=self.verify_ssl) self._session = aiohttp.ClientSession( timeout=self._timeout, connector=connector ) return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" if self._own_session and self._session: await self._session.close() @property def session(self) -> aiohttp.ClientSession: """Get the session, creating one if needed.""" if not self._session: connector = aiohttp.TCPConnector(ssl=self.verify_ssl) self._session = aiohttp.ClientSession( timeout=self._timeout, connector=connector ) return self._session async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]: """Make an API request.""" url = f"{self.base_url}{endpoint}" headers = {"Content-Type": "application/json"} auth = None if self.username and self.password: auth = BasicAuth(self.username, self.password) try: async with self.session.request( method, url, json=data, headers=headers, auth=auth, ssl=self.verify_ssl ) as response: if response.status == 401: raise AdGuardAuthError("Authentication failed") elif response.status == 404: raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}") elif response.status >= 500: raise AdGuardConnectionError(f"Server error {response.status}") response.raise_for_status() # Handle empty responses if response.status == 204 or not response.content_length: return {} try: return await response.json() except (aiohttp.ContentTypeError, ValueError): # If not JSON, return text response text = await response.text() return {"response": text} except asyncio.TimeoutError as err: raise AdGuardTimeoutError(f"Request timeout: {err}") from err except ClientError as err: raise AdGuardConnectionError(f"Client error: {err}") from err except Exception as err: if isinstance(err, AdGuardHomeError): raise raise AdGuardHomeError(f"Unexpected error: {err}") from err async def test_connection(self) -> bool: """Test the connection to AdGuard Home.""" try: response = await self._request("GET", API_ENDPOINTS["status"]) return isinstance(response, dict) and len(response) > 0 except Exception: return False async def get_status(self) -> Dict[str, Any]: """Get server status information.""" return await self._request("GET", API_ENDPOINTS["status"]) async def get_clients(self) -> Dict[str, Any]: """Get all configured clients.""" return await self._request("GET", API_ENDPOINTS["clients"]) async def get_statistics(self) -> Dict[str, Any]: """Get DNS query statistics.""" return await self._request("GET", API_ENDPOINTS["stats"]) async def set_protection(self, enabled: bool) -> Dict[str, Any]: """Enable or disable AdGuard protection.""" data = {"enabled": enabled} return await self._request("POST", API_ENDPOINTS["protection"], data) async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]: """Add a new client configuration.""" if "name" not in client_data: raise ValueError("Client name is required") if "ids" not in client_data or not client_data["ids"]: raise ValueError("Client IDs are required") return await self._request("POST", API_ENDPOINTS["clients_add"], client_data) async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]: """Update an existing client configuration.""" if "name" not in client_data: raise ValueError("Client name is required") if "data" not in client_data: raise ValueError("Client data is required") return await self._request("POST", API_ENDPOINTS["clients_update"], client_data) async def delete_client(self, client_name: str) -> Dict[str, Any]: """Delete a client configuration.""" if not client_name: raise ValueError("Client name is required") data = {"name": client_name} return await self._request("POST", API_ENDPOINTS["clients_delete"], data) async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]: """Get a specific client by name.""" if not client_name: return None try: clients_data = await self.get_clients() clients = clients_data.get("clients", []) for client in clients: if client.get("name") == client_name: return client return None except Exception as err: _LOGGER.error("Error getting client %s: %s", client_name, err) return None async def update_client_blocked_services( self, client_name: str, blocked_services: list, ) -> Dict[str, Any]: """Update blocked services for a specific client.""" if not client_name: raise ValueError("Client name is required") client = await self.get_client_by_name(client_name) if not client: raise AdGuardNotFoundError(f"Client '{client_name}' not found") # Format blocked services data according to AdGuard Home API blocked_services_data = { "ids": blocked_services, "schedule": {"time_zone": "Local"} } update_data = { "name": client_name, "data": { **client, "blocked_services": blocked_services_data } } return await self.update_client(update_data) async def get_blocked_services_list(self) -> Dict[str, Any]: """Get list of available blocked services.""" try: return await self._request("GET", API_ENDPOINTS["blocked_services_all"]) except Exception as err: _LOGGER.error("Error getting blocked services list: %s", err) return {} async def close(self) -> None: """Close the API session if we own it.""" if self._own_session and self._session: await self._session.close()