240 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			240 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """API wrapper for AdGuard Home."""
 | |
| import asyncio
 | |
| import logging
 | |
| from typing import Any, Dict, Optional
 | |
| 
 | |
| import aiohttp
 | |
| from aiohttp import BasicAuth, ClientError, ClientTimeout
 | |
| 
 | |
| from .const import API_ENDPOINTS
 | |
| 
 | |
# Module-level logger named after this module (``__name__``); AdGuardHomeAPI
# uses it to log errors from its best-effort lookups.
_LOGGER = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class AdGuardHomeError(Exception):
    """Root of the AdGuard Home client's exception hierarchy.

    Catching this type handles every error the API wrapper raises on purpose.
    """


class AdGuardConnectionError(AdGuardHomeError):
    """Raised for transport failures and error responses from the server."""


class AdGuardAuthError(AdGuardHomeError):
    """Raised when the server rejects the supplied credentials."""


class AdGuardNotFoundError(AdGuardHomeError):
    """Raised when a requested endpoint or client does not exist."""


class AdGuardTimeoutError(AdGuardHomeError):
    """Raised when a request does not complete within the configured timeout."""
 | |
| 
 | |
| 
 | |
class AdGuardHomeAPI:
    """Async client for the AdGuard Home HTTP control API.

    The client either uses an ``aiohttp.ClientSession`` supplied by the caller
    (which it never closes) or lazily creates and owns its own session. An
    owned session is closed by :meth:`close` or when leaving ``async with``.
    """

    def __init__(
        self,
        host: str,
        port: int = 3000,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
        session: Optional[aiohttp.ClientSession] = None,
        timeout: float = 10,
        verify_ssl: bool = True,
    ) -> None:
        """Initialize the API wrapper.

        Args:
            host: Hostname or IP address of the server. Bare IPv6 literals
                (e.g. ``::1``) are wrapped in brackets when building the URL.
            port: TCP port of the web interface.
            username: Username for HTTP basic auth; only used together with
                ``password``.
            password: Password for HTTP basic auth.
            ssl: Use ``https`` instead of ``http``.
            session: Optional externally managed session. When given, this
                object never closes it.
            timeout: Total per-request timeout in seconds. It is applied to
                every request, including those made through an external session.
            verify_ssl: Verify the server's TLS certificate.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.ssl = ssl
        self.verify_ssl = verify_ssl
        self._session = session
        self._timeout = ClientTimeout(total=timeout)
        protocol = "https" if ssl else "http"
        # IPv6 literals must be bracketed inside a URL authority component.
        url_host = f"[{host}]" if ":" in host and not host.startswith("[") else host
        self.base_url = f"{protocol}://{url_host}:{port}"
        self._own_session = session is None

    def _create_session(self) -> aiohttp.ClientSession:
        """Build a new session owned by this object."""
        connector = aiohttp.TCPConnector(ssl=self.verify_ssl)
        return aiohttp.ClientSession(timeout=self._timeout, connector=connector)

    async def __aenter__(self) -> "AdGuardHomeAPI":
        """Enter the async context, making sure an owned session is open."""
        # Reuse a session the ``session`` property may already have created
        # instead of replacing (and leaking) it.
        if self._own_session and (self._session is None or self._session.closed):
            self._session = self._create_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the async context, closing the session if this object owns it."""
        await self.close()

    @property
    def session(self) -> aiohttp.ClientSession:
        """Return the active session, creating an owned one if needed.

        A previously closed owned session is replaced, so the client stays
        usable after :meth:`close`. An external session is returned as-is,
        because this object does not manage its lifecycle.
        """
        if self._session is None or (self._own_session and self._session.closed):
            self._session = self._create_session()
        return self._session

    async def _request(
        self, method: str, endpoint: str, data: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Send a request to the API and decode the response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended to :attr:`base_url`.
            data: Optional JSON-serialisable request body.

        Returns:
            The decoded JSON body. An empty body (or HTTP 204) yields ``{}``.
            A body that is not JSON yields ``{"response": <text>}``.

        Raises:
            AdGuardAuthError: on HTTP 401.
            AdGuardNotFoundError: on HTTP 404.
            AdGuardConnectionError: on HTTP 5xx, any other HTTP error status,
                or a transport-level client error.
            AdGuardTimeoutError: when the request exceeds the timeout.
            AdGuardHomeError: for any other unexpected failure.
        """
        url = f"{self.base_url}{endpoint}"
        headers = {"Content-Type": "application/json"}
        auth = None
        if self.username and self.password:
            auth = BasicAuth(self.username, self.password)

        try:
            async with self.session.request(
                method,
                url,
                json=data,
                headers=headers,
                auth=auth,
                ssl=self.verify_ssl,
                # Pass the timeout per request so it also applies to an
                # externally supplied session.
                timeout=self._timeout,
            ) as response:
                if response.status == 401:
                    raise AdGuardAuthError("Authentication failed")
                if response.status == 404:
                    raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
                if response.status >= 500:
                    raise AdGuardConnectionError(f"Server error {response.status}")

                response.raise_for_status()

                if response.status == 204:
                    return {}

                # Inspect the actual body rather than Content-Length: chunked
                # responses carry no Content-Length but still have a body.
                body = await response.read()
                if not body.strip():
                    return {}

                try:
                    return await response.json()
                except (aiohttp.ContentTypeError, ValueError):
                    # Not JSON (e.g. a plain-text acknowledgement): return the text.
                    return {"response": await response.text()}

        except AdGuardHomeError:
            # Already mapped to the package hierarchy above; propagate unchanged.
            raise
        except asyncio.TimeoutError as err:
            raise AdGuardTimeoutError(f"Request timeout: {err}") from err
        except ClientError as err:
            raise AdGuardConnectionError(f"Client error: {err}") from err
        except Exception as err:
            # Normalise anything else so callers only need to catch AdGuardHomeError.
            raise AdGuardHomeError(f"Unexpected error: {err}") from err

    async def test_connection(self) -> bool:
        """Return True if the status endpoint answers with a non-empty object.

        Best-effort: API errors are logged at debug level and reported as False.
        """
        try:
            response = await self._request("GET", API_ENDPOINTS["status"])
        except AdGuardHomeError as err:
            _LOGGER.debug("Connection test to %s failed: %s", self.base_url, err)
            return False
        return isinstance(response, dict) and len(response) > 0

    async def get_status(self) -> Dict[str, Any]:
        """Return the server status payload."""
        return await self._request("GET", API_ENDPOINTS["status"])

    async def get_clients(self) -> Dict[str, Any]:
        """Return the configured-clients payload."""
        return await self._request("GET", API_ENDPOINTS["clients"])

    async def get_statistics(self) -> Dict[str, Any]:
        """Return the DNS query statistics payload."""
        return await self._request("GET", API_ENDPOINTS["stats"])

    async def set_protection(self, enabled: bool) -> Dict[str, Any]:
        """Enable or disable protection by POSTing ``{"enabled": enabled}``."""
        data = {"enabled": enabled}
        return await self._request("POST", API_ENDPOINTS["protection"], data)

    async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
        """Add a new client configuration.

        Raises:
            ValueError: if ``client_data`` lacks ``"name"`` or has no ``"ids"``.
        """
        if "name" not in client_data:
            raise ValueError("Client name is required")
        if not client_data.get("ids"):
            raise ValueError("Client IDs are required")

        return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)

    async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
        """Update an existing client configuration.

        ``client_data`` must contain the client's current ``"name"`` and the
        new configuration under ``"data"``.

        Raises:
            ValueError: if ``"name"`` or ``"data"`` is missing.
        """
        if "name" not in client_data:
            raise ValueError("Client name is required")
        if "data" not in client_data:
            raise ValueError("Client data is required")

        return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)

    async def delete_client(self, client_name: str) -> Dict[str, Any]:
        """Delete the client named ``client_name``.

        Raises:
            ValueError: if ``client_name`` is empty.
        """
        if not client_name:
            raise ValueError("Client name is required")

        data = {"name": client_name}
        return await self._request("POST", API_ENDPOINTS["clients_delete"], data)

    @staticmethod
    def _find_client(clients_data: Any, client_name: str) -> Optional[Dict[str, Any]]:
        """Pick the client named ``client_name`` out of a ``get_clients`` payload."""
        if not isinstance(clients_data, dict):
            return None
        # ``or []`` also covers a payload where "clients" is present but null.
        for client in clients_data.get("clients") or []:
            if isinstance(client, dict) and client.get("name") == client_name:
                return client
        return None

    async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
        """Return the configured client named ``client_name``, or None.

        Best-effort: API errors are logged and reported as None, so None does
        not distinguish "no such client" from "lookup failed".
        """
        if not client_name:
            return None

        try:
            clients_data = await self.get_clients()
        except AdGuardHomeError as err:
            _LOGGER.error("Error getting client %s: %s", client_name, err)
            return None
        return self._find_client(clients_data, client_name)

    async def update_client_blocked_services(
        self,
        client_name: str,
        blocked_services: list,
    ) -> Dict[str, Any]:
        """Replace the blocked services of the client named ``client_name``.

        Args:
            client_name: Name of an existing client.
            blocked_services: Service IDs to block. Any iterable is accepted;
                it is copied into a list.

        Raises:
            ValueError: if ``client_name`` is empty.
            AdGuardNotFoundError: if no client with that name exists.
            AdGuardHomeError: (or a subclass) if fetching or updating fails.
        """
        if not client_name:
            raise ValueError("Client name is required")

        # Fetch directly rather than via get_client_by_name, so connection and
        # auth failures propagate instead of surfacing as "not found".
        client = self._find_client(await self.get_clients(), client_name)
        if not client:
            raise AdGuardNotFoundError(f"Client '{client_name}' not found")

        # NOTE(review): confirm this {"ids", "schedule"} shape against the
        # AdGuard Home OpenAPI spec for clients/update. The "Local" schedule
        # replaces any schedule previously stored under "blocked_services".
        blocked_services_data = {
            "ids": list(blocked_services or []),
            "schedule": {"time_zone": "Local"},
        }

        update_data = {
            "name": client_name,
            "data": {**client, "blocked_services": blocked_services_data},
        }

        return await self.update_client(update_data)

    async def get_blocked_services_list(self) -> Dict[str, Any]:
        """Return the available blocked services, or ``{}`` if the request fails.

        Best-effort: API errors are logged instead of raised.
        """
        try:
            return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
        except AdGuardHomeError as err:
            _LOGGER.error("Error getting blocked services list: %s", err)
            return {}

    async def close(self) -> None:
        """Close the session if this object owns it; safe to call repeatedly."""
        if self._own_session and self._session is not None:
            await self._session.close()
            # Drop the closed session so a later request opens a fresh one.
            self._session = None
 |