Revert "fix: Complete fixes: tests, workflows, coverage"
This reverts commit ed94d40e96
.
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
"""AdGuard Home API client."""
|
||||
"""API wrapper for AdGuard Home."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from aiohttp import BasicAuth, ClientError, ClientTimeout
|
||||
|
||||
from .const import API_ENDPOINTS
|
||||
|
||||
@@ -12,141 +12,228 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdGuardHomeError(Exception):
|
||||
"""Base exception for AdGuard Home errors."""
|
||||
"""Base exception for AdGuard Home API."""
|
||||
|
||||
|
||||
class AdGuardConnectionError(AdGuardHomeError):
|
||||
"""Connection error."""
|
||||
"""Exception for connection errors."""
|
||||
|
||||
|
||||
class AdGuardAuthError(AdGuardHomeError):
|
||||
"""Authentication error."""
|
||||
"""Exception for authentication errors."""
|
||||
|
||||
|
||||
class AdGuardNotFoundError(AdGuardHomeError):
|
||||
"""Exception for not found errors."""
|
||||
|
||||
|
||||
class AdGuardTimeoutError(AdGuardHomeError):
|
||||
"""Timeout error."""
|
||||
"""Exception for timeout errors."""
|
||||
|
||||
|
||||
class AdGuardHomeAPI:
|
||||
"""AdGuard Home API client."""
|
||||
"""API wrapper for AdGuard Home."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
port: int = 3000,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
verify_ssl: bool = True,
|
||||
session: Optional[aiohttp.ClientSession] = None,
|
||||
timeout: int = 30,
|
||||
timeout: int = 10,
|
||||
verify_ssl: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the API client."""
|
||||
"""Initialize the API wrapper."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.ssl = ssl
|
||||
self.verify_ssl = verify_ssl
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self._session = session
|
||||
self._auth = None
|
||||
self._timeout = ClientTimeout(total=timeout)
|
||||
protocol = "https" if ssl else "http"
|
||||
self.base_url = f"{protocol}://{host}:{port}"
|
||||
self._own_session = session is None
|
||||
|
||||
if username and password:
|
||||
self._auth = aiohttp.BasicAuth(username, password)
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
if self._own_session:
|
||||
connector = aiohttp.TCPConnector(ssl=self.verify_ssl)
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self._timeout,
|
||||
connector=connector
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
if self._own_session and self._session:
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def base_url(self) -> str:
|
||||
"""Return the base URL."""
|
||||
protocol = "https" if self.ssl else "http"
|
||||
return f"{protocol}://{self.host}:{self.port}"
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
"""Get the session, creating one if needed."""
|
||||
if not self._session:
|
||||
connector = aiohttp.TCPConnector(ssl=self.verify_ssl)
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self._timeout,
|
||||
connector=connector
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def _request(
|
||||
self, method: str, endpoint: str, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""Make a request to the API."""
|
||||
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Make an API request."""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
auth = None
|
||||
|
||||
if self.username and self.password:
|
||||
auth = BasicAuth(self.username, self.password)
|
||||
|
||||
try:
|
||||
async with self._session.request(
|
||||
method,
|
||||
url,
|
||||
auth=self._auth,
|
||||
timeout=self.timeout,
|
||||
ssl=self.verify_ssl if self.ssl else None,
|
||||
**kwargs
|
||||
async with self.session.request(
|
||||
method, url, json=data, headers=headers, auth=auth, ssl=self.verify_ssl
|
||||
) as response:
|
||||
|
||||
if response.status == 401:
|
||||
raise AdGuardAuthError("Authentication failed")
|
||||
elif response.status == 404:
|
||||
raise AdGuardConnectionError(f"Endpoint not found: {endpoint}")
|
||||
elif response.status >= 400:
|
||||
raise AdGuardConnectionError(f"HTTP {response.status}: {response.reason}")
|
||||
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
||||
elif response.status >= 500:
|
||||
raise AdGuardConnectionError(f"Server error {response.status}")
|
||||
|
||||
return await response.json()
|
||||
response.raise_for_status()
|
||||
|
||||
# Handle empty responses
|
||||
if response.status == 204 or not response.content_length:
|
||||
return {}
|
||||
|
||||
try:
|
||||
return await response.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
# If not JSON, return text response
|
||||
text = await response.text()
|
||||
return {"response": text}
|
||||
|
||||
except asyncio.TimeoutError as err:
|
||||
raise AdGuardTimeoutError(f"Request timeout for {url}") from err
|
||||
except aiohttp.ClientConnectorError as err:
|
||||
raise AdGuardConnectionError(f"Connection failed to {url}: {err}") from err
|
||||
except aiohttp.ClientError as err:
|
||||
raise AdGuardConnectionError(f"Client error for {url}: {err}") from err
|
||||
raise AdGuardTimeoutError(f"Request timeout: {err}") from err
|
||||
except ClientError as err:
|
||||
raise AdGuardConnectionError(f"Client error: {err}") from err
|
||||
except Exception as err:
|
||||
raise AdGuardHomeError(f"Unexpected error for {url}: {err}") from err
|
||||
if isinstance(err, AdGuardHomeError):
|
||||
raise
|
||||
raise AdGuardHomeError(f"Unexpected error: {err}") from err
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test the connection to AdGuard Home."""
|
||||
try:
|
||||
await self.get_status()
|
||||
return True
|
||||
except Exception as err:
|
||||
_LOGGER.error("Connection test failed: %s", err)
|
||||
response = await self._request("GET", API_ENDPOINTS["status"])
|
||||
return isinstance(response, dict) and len(response) > 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_status(self) -> Dict[str, Any]:
|
||||
"""Get AdGuard Home status."""
|
||||
"""Get server status information."""
|
||||
return await self._request("GET", API_ENDPOINTS["status"])
|
||||
|
||||
async def get_clients(self) -> Dict[str, Any]:
|
||||
"""Get clients list."""
|
||||
"""Get all configured clients."""
|
||||
return await self._request("GET", API_ENDPOINTS["clients"])
|
||||
|
||||
async def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get DNS query statistics."""
|
||||
return await self._request("GET", API_ENDPOINTS["stats"])
|
||||
|
||||
async def set_protection(self, enabled: bool) -> None:
|
||||
"""Enable or disable protection."""
|
||||
async def set_protection(self, enabled: bool) -> Dict[str, Any]:
|
||||
"""Enable or disable AdGuard protection."""
|
||||
data = {"enabled": enabled}
|
||||
await self._request("POST", API_ENDPOINTS["protection"], json=data)
|
||||
return await self._request("POST", API_ENDPOINTS["protection"], data)
|
||||
|
||||
async def get_client_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get client by name."""
|
||||
clients_data = await self.get_clients()
|
||||
for client in clients_data.get("clients", []):
|
||||
if client.get("name") == name:
|
||||
return client
|
||||
return None
|
||||
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Add a new client configuration."""
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required")
|
||||
if "ids" not in client_data or not client_data["ids"]:
|
||||
raise ValueError("Client IDs are required")
|
||||
|
||||
return await self._request("POST", API_ENDPOINTS["clients_add"], client_data)
|
||||
|
||||
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update an existing client configuration."""
|
||||
if "name" not in client_data:
|
||||
raise ValueError("Client name is required")
|
||||
if "data" not in client_data:
|
||||
raise ValueError("Client data is required")
|
||||
|
||||
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
||||
|
||||
async def delete_client(self, client_name: str) -> Dict[str, Any]:
|
||||
"""Delete a client configuration."""
|
||||
if not client_name:
|
||||
raise ValueError("Client name is required")
|
||||
|
||||
data = {"name": client_name}
|
||||
return await self._request("POST", API_ENDPOINTS["clients_delete"], data)
|
||||
|
||||
async def get_client_by_name(self, client_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a specific client by name."""
|
||||
if not client_name:
|
||||
return None
|
||||
|
||||
try:
|
||||
clients_data = await self.get_clients()
|
||||
clients = clients_data.get("clients", [])
|
||||
|
||||
for client in clients:
|
||||
if client.get("name") == client_name:
|
||||
return client
|
||||
|
||||
return None
|
||||
except Exception as err:
|
||||
_LOGGER.error("Error getting client %s: %s", client_name, err)
|
||||
return None
|
||||
|
||||
async def update_client_blocked_services(
|
||||
self, client_name: str, blocked_services: List[str]
|
||||
) -> None:
|
||||
"""Update blocked services for a client."""
|
||||
self,
|
||||
client_name: str,
|
||||
blocked_services: list,
|
||||
) -> Dict[str, Any]:
|
||||
"""Update blocked services for a specific client."""
|
||||
if not client_name:
|
||||
raise ValueError("Client name is required")
|
||||
|
||||
client = await self.get_client_by_name(client_name)
|
||||
if not client:
|
||||
raise AdGuardConnectionError(f"Client '{client_name}' not found")
|
||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||
|
||||
# Update client with new blocked services
|
||||
client_data = client.copy()
|
||||
client_data["blocked_services"] = blocked_services
|
||||
# Format blocked services data according to AdGuard Home API
|
||||
blocked_services_data = {
|
||||
"ids": blocked_services,
|
||||
"schedule": {"time_zone": "Local"}
|
||||
}
|
||||
|
||||
await self._request("POST", API_ENDPOINTS["clients_update"], json=client_data)
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {
|
||||
**client,
|
||||
"blocked_services": blocked_services_data
|
||||
}
|
||||
}
|
||||
|
||||
async def add_client(self, client_data: Dict[str, Any]) -> None:
|
||||
"""Add a new client."""
|
||||
await self._request("POST", API_ENDPOINTS["clients_add"], json=client_data)
|
||||
return await self.update_client(update_data)
|
||||
|
||||
async def delete_client(self, client_name: str) -> None:
|
||||
"""Delete a client."""
|
||||
data = {"name": client_name}
|
||||
await self._request("POST", API_ENDPOINTS["clients_delete"], json=data)
|
||||
async def get_blocked_services_list(self) -> Dict[str, Any]:
|
||||
"""Get list of available blocked services."""
|
||||
try:
|
||||
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
|
||||
except Exception as err:
|
||||
_LOGGER.error("Error getting blocked services list: %s", err)
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the API session if we own it."""
|
||||
if self._own_session and self._session:
|
||||
await self._session.close()
|
||||
|
Reference in New Issue
Block a user