264 lines
9.4 KiB
Python
264 lines
9.4 KiB
Python
"""API wrapper for AdGuard Home."""
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Dict, Optional
|
|
|
|
import aiohttp
|
|
from aiohttp import BasicAuth, ClientError, ClientTimeout
|
|
|
|
from .const import API_ENDPOINTS
|
|
|
|
# Module-level logger, named after this module so log output can be filtered per module.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
# Custom exceptions
|
|
class AdGuardHomeError(Exception):
    """Root of the exception hierarchy raised by the AdGuard Home API wrapper."""
|
|
|
|
class AdGuardConnectionError(AdGuardHomeError):
    """Signals a transport-level failure: timeout, client error or server (5xx) error."""
|
|
|
|
class AdGuardAuthError(AdGuardHomeError):
    """Signals rejected credentials or insufficient permissions (HTTP 401/403)."""
|
|
|
|
class AdGuardNotFoundError(AdGuardHomeError):
    """Signals that a requested endpoint or named client does not exist."""
|
|
|
|
class AdGuardHomeAPI:
    """Asynchronous client for the AdGuard Home HTTP API.

    The wrapper can either borrow an ``aiohttp.ClientSession`` supplied by
    the caller or lazily create (and later close) its own. It may be used as
    an async context manager, or closed explicitly with :meth:`close`.

    All request failures surface as :class:`AdGuardHomeError` or one of its
    subclasses.
    """

    def __init__(
        self,
        host: str,
        port: int = 3000,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
        session: Optional["aiohttp.ClientSession"] = None,
        timeout: int = 10,
    ) -> None:
        """Store connection settings and derive the base URL.

        Args:
            host: Host name or IP address of the AdGuard Home instance.
            port: TCP port of the web interface.
            username: User name for HTTP basic auth (optional).
            password: Password for HTTP basic auth (optional).
            ssl: Use ``https`` instead of ``http`` when true.
            session: Existing client session to reuse. When omitted, the
                wrapper creates and owns its own session.
            timeout: Total per-request timeout in seconds.
        """
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.ssl = ssl
        self._session = session
        self._timeout = ClientTimeout(total=timeout)
        protocol = "https" if ssl else "http"
        self.base_url = f"{protocol}://{host}:{port}"
        # Only sessions created by this wrapper are closed by it.
        self._own_session = session is None

    async def __aenter__(self) -> "AdGuardHomeAPI":
        """Enter the async context, creating an owned session if needed."""
        # Do not replace a session that was already created lazily; that
        # would leak the earlier one.
        if self._own_session and self._session is None:
            self._session = aiohttp.ClientSession(timeout=self._timeout)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the async context, closing the session if we own it."""
        await self.close()

    @property
    def session(self) -> "aiohttp.ClientSession":
        """Return the active session, creating one on first use."""
        if self._session is None:
            self._session = aiohttp.ClientSession(timeout=self._timeout)
        return self._session

    async def _request(
        self, method: str, endpoint: str, data: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Send one API request and decode the response.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended to ``base_url``.
            data: Optional JSON-serialisable request body.

        Returns:
            The decoded JSON body, ``{}`` for an empty body, or
            ``{"response": <text>}`` when the body is not JSON.

        Raises:
            AdGuardAuthError: On HTTP 401 or 403.
            AdGuardNotFoundError: On HTTP 404.
            AdGuardConnectionError: On HTTP 5xx, timeouts and other aiohttp
                client errors (including remaining 4xx statuses).
            AdGuardHomeError: On any other unexpected failure.
        """
        url = f"{self.base_url}{endpoint}"
        headers = {"Content-Type": "application/json"}
        auth = None
        if self.username and self.password:
            auth = BasicAuth(self.username, self.password)

        try:
            # The timeout is passed per request so that it is honoured even
            # when the session was supplied by the caller.
            async with self.session.request(
                method,
                url,
                json=data,
                headers=headers,
                auth=auth,
                timeout=self._timeout,
            ) as response:
                status = response.status
                if status == 401:
                    raise AdGuardAuthError(
                        "Authentication failed - check username/password"
                    )
                if status == 403:
                    raise AdGuardAuthError(
                        "Access forbidden - insufficient permissions"
                    )
                if status == 404:
                    raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
                if status >= 500:
                    raise AdGuardConnectionError(f"Server error {status}")

                response.raise_for_status()

                if status == 204:
                    return {}

                # Content-Length is absent for chunked/compressed responses,
                # so emptiness is decided from the body itself. read()
                # caches the payload for the json()/text() calls below.
                body = await response.read()
                if not body.strip():
                    return {}

                try:
                    return await response.json()
                except aiohttp.ContentTypeError:
                    # Some endpoints answer with plain text instead of JSON.
                    text = await response.text()
                    _LOGGER.warning("Non-JSON response received: %s", text)
                    return {"response": text}

        except AdGuardHomeError:
            # Our own, already specific errors must not be re-wrapped by the
            # generic handler below.
            raise
        except asyncio.TimeoutError as err:
            raise AdGuardConnectionError(
                f"Timeout connecting to AdGuard Home: {err}"
            ) from err
        except ClientError as err:
            raise AdGuardConnectionError(f"Client error: {err}") from err
        except Exception as err:
            _LOGGER.error(
                "Unexpected error communicating with AdGuard Home: %s", err
            )
            raise AdGuardHomeError(f"Unexpected error: {err}") from err

    async def test_connection(self) -> bool:
        """Return True if the status endpoint can be fetched, else False.

        Failures are deliberately swallowed (logged at debug level) so the
        method can be used as a simple reachability probe.
        """
        try:
            await self._request("GET", API_ENDPOINTS["status"])
        except Exception as err:
            _LOGGER.debug("Connection test failed: %s", err)
            return False
        return True

    async def get_status(self) -> Dict[str, Any]:
        """Get server status information."""
        return await self._request("GET", API_ENDPOINTS["status"])

    async def get_clients(self) -> Dict[str, Any]:
        """Get all configured clients."""
        return await self._request("GET", API_ENDPOINTS["clients"])

    async def get_statistics(self) -> Dict[str, Any]:
        """Get DNS query statistics."""
        return await self._request("GET", API_ENDPOINTS["stats"])

    async def set_protection(self, enabled: bool) -> Dict[str, Any]:
        """Enable or disable AdGuard protection."""
        return await self._request(
            "POST", API_ENDPOINTS["protection"], {"enabled": enabled}
        )

    async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
        """Add a new client configuration.

        Raises:
            ValueError: If ``client_data`` lacks ``"name"`` or has no
                (non-empty) ``"ids"``.
        """
        if "name" not in client_data:
            raise ValueError("Client name is required")
        if not client_data.get("ids"):
            raise ValueError("Client IDs are required")

        return await self._request(
            "POST", API_ENDPOINTS["clients_add"], client_data
        )

    async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
        """Update an existing client configuration.

        Args:
            client_data: Mapping with the current client ``"name"`` and the
                new configuration under ``"data"``.

        Raises:
            ValueError: If either ``"name"`` or ``"data"`` is missing.
        """
        if "name" not in client_data:
            raise ValueError("Client name is required for update")
        if "data" not in client_data:
            raise ValueError("Client data is required for update")

        return await self._request(
            "POST", API_ENDPOINTS["clients_update"], client_data
        )

    async def delete_client(self, client_name: str) -> Dict[str, Any]:
        """Delete a client configuration.

        Raises:
            ValueError: If ``client_name`` is empty.
        """
        if not client_name:
            raise ValueError("Client name is required")

        return await self._request(
            "POST", API_ENDPOINTS["clients_delete"], {"name": client_name}
        )

    async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
        """Return the client whose ``"name"`` equals ``client_name``.

        Returns ``None`` when the name is empty, no client matches, or the
        client list could not be fetched (the failure is logged).
        """
        if not client_name:
            return None

        try:
            clients_data = await self.get_clients()
            # Tolerate a JSON null for the list.
            clients = clients_data.get("clients") or []
            return next(
                (entry for entry in clients if entry.get("name") == client_name),
                None,
            )
        except Exception as err:
            # Best-effort lookup: report the failure but do not propagate it.
            _LOGGER.error("Failed to get client %s: %s", client_name, err)
            return None

    async def update_client_blocked_services(
        self,
        client_name: str,
        blocked_services: list,
        schedule: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Replace the blocked-services list of a client.

        Args:
            client_name: Name of the client to update.
            blocked_services: Service IDs that should be blocked.
            schedule: Optional schedule mapping; when omitted or empty,
                ``{"time_zone": "Local"}`` is sent.

        Raises:
            ValueError: If ``client_name`` is empty.
            AdGuardNotFoundError: If no client with that name was found.
        """
        if not client_name:
            raise ValueError("Client name is required")

        client = await self.get_client_by_name(client_name)
        if not client:
            raise AdGuardNotFoundError(f"Client '{client_name}' not found")

        blocked_services_data = {
            "ids": blocked_services,
            "schedule": schedule if schedule else {"time_zone": "Local"},
        }

        # Send the full existing client config with only blocked_services
        # overridden.
        update_data = {
            "name": client_name,
            "data": {
                **client,
                "blocked_services": blocked_services_data,
            },
        }

        return await self.update_client(update_data)

    async def toggle_client_service(
        self, client_name: str, service_id: str, enabled: bool
    ) -> Dict[str, Any]:
        """Add or remove one service from a client's blocked list.

        Args:
            client_name: Name of the client to update.
            service_id: Identifier of the service to toggle.
            enabled: When true, ``service_id`` is added to the blocked list;
                when false, it is removed from it.

        Raises:
            ValueError: If ``client_name`` or ``service_id`` is empty.
            AdGuardNotFoundError: If no client with that name was found.
        """
        if not client_name or not service_id:
            raise ValueError("Client name and service ID are required")

        client = await self.get_client_by_name(client_name)
        if not client:
            raise AdGuardNotFoundError(f"Client '{client_name}' not found")

        current = client.get("blocked_services") or {}
        if isinstance(current, dict):
            # "ids" may be missing or null; copy so the fetched client data
            # is not mutated.
            service_ids = list(current.get("ids") or [])
            # Keep the existing schedule instead of resetting it on toggle.
            schedule = current.get("schedule")
        else:
            # Legacy format: blocked_services is a plain list of IDs.
            service_ids = list(current)
            schedule = None

        if enabled and service_id not in service_ids:
            service_ids.append(service_id)
        elif not enabled and service_id in service_ids:
            service_ids.remove(service_id)

        return await self.update_client_blocked_services(
            client_name, service_ids, schedule
        )

    async def get_blocked_services(self) -> Dict[str, Any]:
        """Get available blocked services."""
        return await self._request("GET", API_ENDPOINTS["blocked_services_all"])

    async def close(self) -> None:
        """Close the API session if we own it.

        The reference is dropped afterwards so that a later request lazily
        creates a fresh session instead of reusing the closed one.
        """
        if self._own_session and self._session is not None:
            await self._session.close()
            self._session = None
|