fix: Fix CI/CD issues and enhance integration
Some checks failed
Integration Testing / Integration Tests (2024.12.0, 3.11) (push) Failing after 22s
Integration Testing / Integration Tests (2024.12.0, 3.12) (push) Failing after 21s
Integration Testing / Integration Tests (2024.12.0, 3.13) (push) Failing after 1m32s
Integration Testing / Integration Tests (2025.9.4, 3.11) (push) Failing after 15s
Integration Testing / Integration Tests (2025.9.4, 3.12) (push) Failing after 20s
Integration Testing / Integration Tests (2025.9.4, 3.13) (push) Failing after 20s

Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
This commit is contained in:
2025-09-28 17:24:46 +01:00
parent bcec7bbf1a
commit 8281a1813d
17 changed files with 1439 additions and 276 deletions

View File

@@ -12,8 +12,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.13"] python-version: ["3.11", "3.12", "3.13"]
home-assistant-version: ["2025.9.4"] home-assistant-version: ["2024.12.0", "2025.9.4"]
steps: steps:
- name: Checkout Code - name: Checkout Code
@@ -24,14 +24,6 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements*.txt') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
- name: Install Dependencies - name: Install Dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
@@ -45,7 +37,7 @@ jobs:
- name: Run Unit Tests - name: Run Unit Tests
run: | run: |
python -m pytest tests/ -v --tb=short --cov=custom_components/adguard_hub --cov-report=xml --cov-report=term-missing --asyncio-mode=auto python -m pytest tests/ -v --tb=short --cov=custom_components/adguard_hub --cov-report=xml --cov-report=term-missing --asyncio-mode=auto --cov-fail-under=60
- name: Test Installation - name: Test Installation
run: | run: |
@@ -78,10 +70,3 @@ jobs:
print(f'❌ Manifest validation failed: {e}') print(f'❌ Manifest validation failed: {e}')
sys.exit(1) sys.exit(1)
" "
- name: Upload Coverage Reports
uses: actions/upload-artifact@v4
if: matrix.python-version == '3.13' && matrix.home-assistant-version == '2025.9.4'
with:
name: coverage-report
path: coverage.xml

View File

@@ -1,70 +0,0 @@
name: Code Quality Check
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
jobs:
code-quality:
name: Code Quality Analysis
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
- name: Code Formatting Check (Black)
run: |
black --check custom_components/ tests/
- name: Import Sorting Check (isort)
run: |
isort --check-only --diff custom_components/ tests/
- name: Linting (flake8)
run: |
flake8 custom_components/ tests/
- name: Type Checking (mypy)
run: |
mypy custom_components/adguard_hub/ --ignore-missing-imports
continue-on-error: true
security-scan:
name: Security Analysis
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: Install Security Tools
run: |
python -m pip install --upgrade pip
pip install bandit safety
- name: Security Check (Bandit)
run: |
bandit -r custom_components/ -ll
- name: Dependency Security Check (Safety)
run: |
pip install -r requirements.txt
safety check

View File

@@ -1,70 +0,0 @@
name: Release
on:
push:
tags:
- 'v*.*.*'
permissions:
contents: write
jobs:
create-release:
name: Create Release
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/v')
steps:
- name: Checkout Code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Validate Tag Format
run: |
if [[ ! "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "❌ Invalid tag format. Expected: v1.2.3"
exit 1
fi
echo "✅ Valid semantic version tag: ${{ github.ref_name }}"
- name: Extract Version
id: version
run: |
VERSION=${{ github.ref_name }}
VERSION_NUMBER=${VERSION#v}
echo "version=${VERSION}" >> $GITHUB_OUTPUT
echo "version_number=${VERSION_NUMBER}" >> $GITHUB_OUTPUT
- name: Update Manifest Version
run: |
sed -i 's/"version": ".*"/"version": "${{ steps.version.outputs.version_number }}"/' custom_components/adguard_hub/manifest.json
- name: Run Tests Before Release
run: |
python -m pip install --upgrade pip
pip install homeassistant==2025.9.4
pip install -r requirements-dev.txt
mkdir -p custom_components
touch custom_components/__init__.py
python -m pytest tests/ -v --tb=short
- name: Create Release Archive
run: |
cd custom_components
zip -r ../adguard-control-hub-${{ steps.version.outputs.version_number }}.zip adguard_hub/
- name: Generate Changelog
id: changelog
run: |
PREVIOUS_TAG=$(git tag --sort=-version:refname | head -2 | tail -1 2>/dev/null || echo "")
if [ -z "$PREVIOUS_TAG" ]; then
echo "changelog=Initial release of AdGuard Control Hub" >> $GITHUB_OUTPUT
else
echo "changelog=Changes since $PREVIOUS_TAG" >> $GITHUB_OUTPUT
fi
- name: Success Message
run: |
echo "🎉 Release ${{ steps.version.outputs.version }} created!"
echo "📦 Archive: adguard-control-hub-${{ steps.version.outputs.version_number }}.zip"

View File

@@ -33,6 +33,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
username=entry.data.get(CONF_USERNAME), username=entry.data.get(CONF_USERNAME),
password=entry.data.get(CONF_PASSWORD), password=entry.data.get(CONF_PASSWORD),
ssl=entry.data.get(CONF_SSL, False), ssl=entry.data.get(CONF_SSL, False),
verify_ssl=entry.data.get(CONF_VERIFY_SSL, True),
session=session, session=session,
) )

View File

@@ -27,6 +27,10 @@ class AdGuardNotFoundError(AdGuardHomeError):
"""Exception for not found errors.""" """Exception for not found errors."""
class AdGuardTimeoutError(AdGuardHomeError):
"""Exception for timeout errors."""
class AdGuardHomeAPI: class AdGuardHomeAPI:
"""API wrapper for AdGuard Home.""" """API wrapper for AdGuard Home."""
@@ -39,6 +43,7 @@ class AdGuardHomeAPI:
ssl: bool = False, ssl: bool = False,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
timeout: int = 10, timeout: int = 10,
verify_ssl: bool = True,
) -> None: ) -> None:
"""Initialize the API wrapper.""" """Initialize the API wrapper."""
self.host = host self.host = host
@@ -46,6 +51,7 @@ class AdGuardHomeAPI:
self.username = username self.username = username
self.password = password self.password = password
self.ssl = ssl self.ssl = ssl
self.verify_ssl = verify_ssl
self._session = session self._session = session
self._timeout = ClientTimeout(total=timeout) self._timeout = ClientTimeout(total=timeout)
protocol = "https" if ssl else "http" protocol = "https" if ssl else "http"
@@ -55,7 +61,11 @@ class AdGuardHomeAPI:
async def __aenter__(self): async def __aenter__(self):
"""Async context manager entry.""" """Async context manager entry."""
if self._own_session: if self._own_session:
self._session = aiohttp.ClientSession(timeout=self._timeout) connector = aiohttp.TCPConnector(ssl=self.verify_ssl)
self._session = aiohttp.ClientSession(
timeout=self._timeout,
connector=connector
)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -67,7 +77,11 @@ class AdGuardHomeAPI:
def session(self) -> aiohttp.ClientSession: def session(self) -> aiohttp.ClientSession:
"""Get the session, creating one if needed.""" """Get the session, creating one if needed."""
if not self._session: if not self._session:
self._session = aiohttp.ClientSession(timeout=self._timeout) connector = aiohttp.TCPConnector(ssl=self.verify_ssl)
self._session = aiohttp.ClientSession(
timeout=self._timeout,
connector=connector
)
return self._session return self._session
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]: async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
@@ -81,7 +95,7 @@ class AdGuardHomeAPI:
try: try:
async with self.session.request( async with self.session.request(
method, url, json=data, headers=headers, auth=auth method, url, json=data, headers=headers, auth=auth, ssl=self.verify_ssl
) as response: ) as response:
if response.status == 401: if response.status == 401:
@@ -93,17 +107,19 @@ class AdGuardHomeAPI:
response.raise_for_status() response.raise_for_status()
# Handle empty responses
if response.status == 204 or not response.content_length: if response.status == 204 or not response.content_length:
return {} return {}
try: try:
return await response.json() return await response.json()
except aiohttp.ContentTypeError: except (aiohttp.ContentTypeError, ValueError):
# If not JSON, return text response
text = await response.text() text = await response.text()
return {"response": text} return {"response": text}
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise AdGuardConnectionError(f"Timeout: {err}") from err raise AdGuardTimeoutError(f"Request timeout: {err}") from err
except ClientError as err: except ClientError as err:
raise AdGuardConnectionError(f"Client error: {err}") from err raise AdGuardConnectionError(f"Client error: {err}") from err
except Exception as err: except Exception as err:
@@ -114,8 +130,8 @@ class AdGuardHomeAPI:
async def test_connection(self) -> bool: async def test_connection(self) -> bool:
"""Test the connection to AdGuard Home.""" """Test the connection to AdGuard Home."""
try: try:
await self._request("GET", API_ENDPOINTS["status"]) response = await self._request("GET", API_ENDPOINTS["status"])
return True return isinstance(response, dict) and len(response) > 0
except Exception: except Exception:
return False return False
@@ -176,7 +192,8 @@ class AdGuardHomeAPI:
return client return client
return None return None
except Exception: except Exception as err:
_LOGGER.error("Error getting client %s: %s", client_name, err)
return None return None
async def update_client_blocked_services( async def update_client_blocked_services(
@@ -192,6 +209,7 @@ class AdGuardHomeAPI:
if not client: if not client:
raise AdGuardNotFoundError(f"Client '{client_name}' not found") raise AdGuardNotFoundError(f"Client '{client_name}' not found")
# Format blocked services data according to AdGuard Home API
blocked_services_data = { blocked_services_data = {
"ids": blocked_services, "ids": blocked_services,
"schedule": {"time_zone": "Local"} "schedule": {"time_zone": "Local"}
@@ -207,6 +225,14 @@ class AdGuardHomeAPI:
return await self.update_client(update_data) return await self.update_client(update_data)
async def get_blocked_services_list(self) -> Dict[str, Any]:
"""Get list of available blocked services."""
try:
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
except Exception as err:
_LOGGER.error("Error getting blocked services list: %s", err)
return {}
async def close(self) -> None: async def close(self) -> None:
"""Close the API session if we own it.""" """Close the API session if we own it."""
if self._own_session and self._session: if self._own_session and self._session:

View File

@@ -1,10 +1,11 @@
"""Binary sensor platform for AdGuard Control Hub integration.""" """Binary sensor platform for AdGuard Control Hub integration."""
import logging import logging
from typing import Any from typing import Any, Optional
from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
@@ -26,15 +27,26 @@ async def async_setup_entry(
entities = [ entities = [
AdGuardProtectionBinarySensor(coordinator, api), AdGuardProtectionBinarySensor(coordinator, api),
AdGuardServerRunningBinarySensor(coordinator, api),
AdGuardSafeBrowsingBinarySensor(coordinator, api),
AdGuardParentalControlBinarySensor(coordinator, api),
AdGuardSafeSearchBinarySensor(coordinator, api),
] ]
async_add_entities(entities) # Add client-specific binary sensors
for client_name in coordinator.clients.keys():
entities.extend([
AdGuardClientFilteringBinarySensor(coordinator, api, client_name),
AdGuardClientSafeBrowsingBinarySensor(coordinator, api, client_name),
])
async_add_entities(entities, update_before_add=True)
class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity): class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity):
"""Base class for AdGuard binary sensors.""" """Base class for AdGuard binary sensors."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor.""" """Initialize the binary sensor."""
super().__init__(coordinator) super().__init__(coordinator)
self.api = api self.api = api
@@ -43,21 +55,23 @@ class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity):
"name": f"AdGuard Control Hub ({api.host})", "name": f"AdGuard Control Hub ({api.host})",
"manufacturer": MANUFACTURER, "manufacturer": MANUFACTURER,
"model": "AdGuard Home", "model": "AdGuard Home",
"configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}",
} }
class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor): class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show AdGuard protection status.""" """Binary sensor to show AdGuard protection status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor.""" """Initialize the binary sensor."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled" self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled"
self._attr_name = "AdGuard Protection Status" self._attr_name = "AdGuard Protection Status"
self._attr_device_class = BinarySensorDeviceClass.RUNNING self._attr_device_class = BinarySensorDeviceClass.RUNNING
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def is_on(self) -> bool | None: def is_on(self) -> Optional[bool]:
"""Return true if protection is enabled.""" """Return true if protection is enabled."""
return self.coordinator.protection_status.get("protection_enabled", False) return self.coordinator.protection_status.get("protection_enabled", False)
@@ -66,6 +80,11 @@ class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
"""Return the icon for the binary sensor.""" """Return the icon for the binary sensor."""
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
@property @property
def extra_state_attributes(self) -> dict[str, Any]: def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes.""" """Return additional state attributes."""
@@ -74,4 +93,205 @@ class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
"dns_port": status.get("dns_port", "N/A"), "dns_port": status.get("dns_port", "N/A"),
"version": status.get("version", "N/A"), "version": status.get("version", "N/A"),
"running": status.get("running", False), "running": status.get("running", False),
"dhcp_available": status.get("dhcp_available", False),
}
class AdGuardServerRunningBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show if AdGuard server is running."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_server_running"
self._attr_name = "AdGuard Server Running"
self._attr_device_class = BinarySensorDeviceClass.RUNNING
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if server is running."""
return self.coordinator.protection_status.get("running", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:server" if self.is_on else "mdi:server-off"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show SafeBrowsing status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled"
self._attr_name = "AdGuard SafeBrowsing"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if SafeBrowsing is enabled."""
return self.coordinator.protection_status.get("safebrowsing_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:shield-check" if self.is_on else "mdi:shield-off"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Parental Control status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled"
self._attr_name = "AdGuard Parental Control"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if Parental Control is enabled."""
return self.coordinator.protection_status.get("parental_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:account-child" if self.is_on else "mdi:account-child-outline"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Safe Search status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled"
self._attr_name = "AdGuard Safe Search"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if Safe Search is enabled."""
return self.coordinator.protection_status.get("safesearch_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:shield-search" if self.is_on else "mdi:magnify"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
class AdGuardClientFilteringBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show client-specific filtering status."""
def __init__(
self,
coordinator: AdGuardControlHubCoordinator,
api: AdGuardHomeAPI,
client_name: str,
) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self.client_name = client_name
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}_filtering"
self._attr_name = f"AdGuard {client_name} Filtering"
self._attr_device_class = BinarySensorDeviceClass.RUNNING
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if client filtering is enabled."""
client = self.coordinator.clients.get(self.client_name, {})
return client.get("filtering_enabled", True)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:filter" if self.is_on else "mdi:filter-off"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return (
self.coordinator.last_update_success
and self.client_name in self.coordinator.clients
)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
client = self.coordinator.clients.get(self.client_name, {})
return {
"client_ids": client.get("ids", []),
"use_global_settings": client.get("use_global_settings", True),
}
class AdGuardClientSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show client-specific SafeBrowsing status."""
def __init__(
self,
coordinator: AdGuardControlHubCoordinator,
api: AdGuardHomeAPI,
client_name: str,
) -> None:
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self.client_name = client_name
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}_safebrowsing"
self._attr_name = f"AdGuard {client_name} SafeBrowsing"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def is_on(self) -> Optional[bool]:
"""Return true if client SafeBrowsing is enabled."""
client = self.coordinator.clients.get(self.client_name, {})
return client.get("safebrowsing_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:shield-account" if self.is_on else "mdi:shield-account-outline"
@property
def available(self) -> bool:
"""Return if sensor is available."""
return (
self.coordinator.last_update_success
and self.client_name in self.coordinator.clients
)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
client = self.coordinator.clients.get(self.client_name, {})
return {
"parental_enabled": client.get("parental_enabled", False),
"safesearch_enabled": client.get("safesearch_enabled", False),
} }

View File

@@ -1,6 +1,7 @@
"""Config flow for AdGuard Control Hub integration.""" """Config flow for AdGuard Control Hub integration."""
import asyncio import asyncio
import logging import logging
import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import voluptuous as vol import voluptuous as vol
@@ -10,7 +11,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError, AdGuardTimeoutError
from .const import ( from .const import (
CONF_SSL, CONF_SSL,
CONF_VERIFY_SSL, CONF_VERIFY_SSL,
@@ -32,16 +33,33 @@ STEP_USER_DATA_SCHEMA = vol.Schema({
}) })
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: def validate_host(host: str) -> str:
"""Validate the user input allows us to connect.""" """Validate and clean host input."""
host = data[CONF_HOST].strip() host = host.strip()
if not host: if not host:
raise InvalidHost("Host cannot be empty") raise InvalidHost("Host cannot be empty")
# Remove protocol if present
if host.startswith(("http://", "https://")): if host.startswith(("http://", "https://")):
host = host.split("://", 1)[1] host = host.split("://", 1)[1]
data[CONF_HOST] = host
# Remove path if present
if "/" in host:
host = host.split("/", 1)[0]
return host
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the user input allows us to connect."""
# Validate and clean host
try:
host = validate_host(data[CONF_HOST])
data[CONF_HOST] = host
except InvalidHost:
raise
# Validate port
port = data[CONF_PORT] port = data[CONF_PORT]
if not (1 <= port <= 65535): if not (1 <= port <= 65535):
raise InvalidPort("Port must be between 1 and 65535") raise InvalidPort("Port must be between 1 and 65535")
@@ -54,6 +72,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
username=data.get(CONF_USERNAME), username=data.get(CONF_USERNAME),
password=data.get(CONF_PASSWORD), password=data.get(CONF_PASSWORD),
ssl=data.get(CONF_SSL, False), ssl=data.get(CONF_SSL, False),
verify_ssl=data.get(CONF_VERIFY_SSL, True),
session=session, session=session,
timeout=10, timeout=10,
) )
@@ -72,6 +91,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
"host": host, "host": host,
} }
except Exception: except Exception:
# If we can't get status but connection works, still proceed
return { return {
"title": f"AdGuard Control Hub ({host})", "title": f"AdGuard Control Hub ({host})",
"version": "unknown", "version": "unknown",
@@ -80,6 +100,8 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
except AdGuardAuthError as err: except AdGuardAuthError as err:
raise InvalidAuth from err raise InvalidAuth from err
except AdGuardTimeoutError as err:
raise Timeout from err
except AdGuardConnectionError as err: except AdGuardConnectionError as err:
if "timeout" in str(err).lower(): if "timeout" in str(err).lower():
raise Timeout from err raise Timeout from err
@@ -87,6 +109,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
raise Timeout from err raise Timeout from err
except Exception as err: except Exception as err:
_LOGGER.exception("Unexpected error during validation")
raise CannotConnect from err raise CannotConnect from err

View File

@@ -4,6 +4,7 @@ from typing import Final
# Integration details # Integration details
DOMAIN: Final = "adguard_hub" DOMAIN: Final = "adguard_hub"
MANUFACTURER: Final = "AdGuard Control Hub" MANUFACTURER: Final = "AdGuard Control Hub"
INTEGRATION_NAME: Final = "AdGuard Control Hub"
# Configuration # Configuration
CONF_SSL: Final = "ssl" CONF_SSL: Final = "ssl"
@@ -32,9 +33,11 @@ API_ENDPOINTS: Final = {
"blocked_services_all": "/control/blocked_services/all", "blocked_services_all": "/control/blocked_services/all",
"protection": "/control/protection", "protection": "/control/protection",
"stats": "/control/stats", "stats": "/control/stats",
"rewrite": "/control/rewrite/list",
"querylog": "/control/querylog",
} }
# Available blocked services # Available blocked services (common ones)
BLOCKED_SERVICES: Final = { BLOCKED_SERVICES: Final = {
"youtube": "YouTube", "youtube": "YouTube",
"facebook": "Facebook", "facebook": "Facebook",
@@ -52,6 +55,19 @@ BLOCKED_SERVICES: Final = {
"whatsapp": "WhatsApp", "whatsapp": "WhatsApp",
"telegram": "Telegram", "telegram": "Telegram",
"discord": "Discord", "discord": "Discord",
"amazon": "Amazon",
"ebay": "eBay",
"skype": "Skype",
"zoom": "Zoom",
"tinder": "Tinder",
"pinterest": "Pinterest",
"linkedin": "LinkedIn",
"dailymotion": "Dailymotion",
"vimeo": "Vimeo",
"viber": "Viber",
"wechat": "WeChat",
"ok": "Odnoklassniki",
"vk": "VKontakte",
} }
# Service attributes # Service attributes
@@ -59,9 +75,22 @@ ATTR_CLIENT_NAME: Final = "client_name"
ATTR_SERVICES: Final = "services" ATTR_SERVICES: Final = "services"
ATTR_DURATION: Final = "duration" ATTR_DURATION: Final = "duration"
ATTR_CLIENTS: Final = "clients" ATTR_CLIENTS: Final = "clients"
ATTR_ENABLED: Final = "enabled"
# Icons # Icons
ICON_PROTECTION: Final = "mdi:shield" ICON_PROTECTION: Final = "mdi:shield"
ICON_PROTECTION_OFF: Final = "mdi:shield-off" ICON_PROTECTION_OFF: Final = "mdi:shield-off"
ICON_CLIENT: Final = "mdi:devices" ICON_CLIENT: Final = "mdi:devices"
ICON_STATISTICS: Final = "mdi:chart-line" ICON_STATISTICS: Final = "mdi:chart-line"
ICON_BLOCKED: Final = "mdi:shield-check"
ICON_QUERIES: Final = "mdi:dns"
ICON_PERCENTAGE: Final = "mdi:percent"
ICON_CLIENTS: Final = "mdi:account-multiple"
# Service names
SERVICE_BLOCK_SERVICES: Final = "block_services"
SERVICE_UNBLOCK_SERVICES: Final = "unblock_services"
SERVICE_EMERGENCY_UNBLOCK: Final = "emergency_unblock"
SERVICE_ADD_CLIENT: Final = "add_client"
SERVICE_REMOVE_CLIENT: Final = "remove_client"
SERVICE_REFRESH_DATA: Final = "refresh_data"

View File

@@ -10,5 +10,5 @@
"requirements": [ "requirements": [
"aiohttp>=3.8.0" "aiohttp>=3.8.0"
], ],
"version": "1.0.0" "version": "1.0.1"
} }

View File

@@ -1,17 +1,18 @@
"""Sensor platform for AdGuard Control Hub integration.""" """Sensor platform for AdGuard Control Hub integration."""
import logging import logging
from typing import Any from typing import Any, Optional
from homeassistant.components.sensor import SensorEntity, SensorStateClass from homeassistant.components.sensor import SensorEntity, SensorStateClass, SensorDeviceClass
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import PERCENTAGE from homeassistant.const import PERCENTAGE, UnitOfTime
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AdGuardControlHubCoordinator from . import AdGuardControlHubCoordinator
from .api import AdGuardHomeAPI from .api import AdGuardHomeAPI
from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS, ICON_BLOCKED, ICON_QUERIES, ICON_PERCENTAGE, ICON_CLIENTS
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -30,15 +31,17 @@ async def async_setup_entry(
AdGuardBlockedCounterSensor(coordinator, api), AdGuardBlockedCounterSensor(coordinator, api),
AdGuardBlockingPercentageSensor(coordinator, api), AdGuardBlockingPercentageSensor(coordinator, api),
AdGuardClientCountSensor(coordinator, api), AdGuardClientCountSensor(coordinator, api),
AdGuardProcessingTimeSensor(coordinator, api),
AdGuardFilteringRulesSensor(coordinator, api),
] ]
async_add_entities(entities) async_add_entities(entities, update_before_add=True)
class AdGuardBaseSensor(CoordinatorEntity, SensorEntity): class AdGuardBaseSensor(CoordinatorEntity, SensorEntity):
"""Base class for AdGuard sensors.""" """Base class for AdGuard sensors."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__(coordinator) super().__init__(coordinator)
self.api = api self.api = api
@@ -47,61 +50,91 @@ class AdGuardBaseSensor(CoordinatorEntity, SensorEntity):
"name": f"AdGuard Control Hub ({api.host})", "name": f"AdGuard Control Hub ({api.host})",
"manufacturer": MANUFACTURER, "manufacturer": MANUFACTURER,
"model": "AdGuard Home", "model": "AdGuard Home",
"configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}",
} }
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success and bool(self.coordinator.statistics)
class AdGuardQueriesCounterSensor(AdGuardBaseSensor): class AdGuardQueriesCounterSensor(AdGuardBaseSensor):
"""Sensor to track DNS queries count.""" """Sensor to track DNS queries count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_dns_queries" self._attr_unique_id = f"{api.host}_{api.port}_dns_queries"
self._attr_name = "AdGuard DNS Queries" self._attr_name = "AdGuard DNS Queries"
self._attr_icon = ICON_STATISTICS self._attr_icon = ICON_QUERIES
self._attr_state_class = SensorStateClass.TOTAL_INCREASING self._attr_state_class = SensorStateClass.TOTAL_INCREASING
self._attr_native_unit_of_measurement = "queries" self._attr_native_unit_of_measurement = "queries"
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def native_value(self): def native_value(self) -> Optional[int]:
"""Return the state of the sensor.""" """Return the state of the sensor."""
stats = self.coordinator.statistics stats = self.coordinator.statistics
return stats.get("num_dns_queries", 0) return stats.get("num_dns_queries", 0)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
stats = self.coordinator.statistics
return {
"queries_today": stats.get("num_dns_queries_today", 0),
"replaced_safebrowsing": stats.get("num_replaced_safebrowsing", 0),
"replaced_parental": stats.get("num_replaced_parental", 0),
"replaced_safesearch": stats.get("num_replaced_safesearch", 0),
}
class AdGuardBlockedCounterSensor(AdGuardBaseSensor): class AdGuardBlockedCounterSensor(AdGuardBaseSensor):
"""Sensor to track blocked queries count.""" """Sensor to track blocked queries count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries" self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries"
self._attr_name = "AdGuard Blocked Queries" self._attr_name = "AdGuard Blocked Queries"
self._attr_icon = ICON_STATISTICS self._attr_icon = ICON_BLOCKED
self._attr_state_class = SensorStateClass.TOTAL_INCREASING self._attr_state_class = SensorStateClass.TOTAL_INCREASING
self._attr_native_unit_of_measurement = "queries" self._attr_native_unit_of_measurement = "queries"
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def native_value(self): def native_value(self) -> Optional[int]:
"""Return the state of the sensor.""" """Return the state of the sensor."""
stats = self.coordinator.statistics stats = self.coordinator.statistics
return stats.get("num_blocked_filtering", 0) return stats.get("num_blocked_filtering", 0)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
stats = self.coordinator.statistics
return {
"blocked_today": stats.get("num_blocked_filtering_today", 0),
"malware_phishing": stats.get("num_replaced_safebrowsing", 0),
"adult_websites": stats.get("num_replaced_parental", 0),
}
class AdGuardBlockingPercentageSensor(AdGuardBaseSensor): class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
"""Sensor to track blocking percentage.""" """Sensor to track blocking percentage."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage" self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage"
self._attr_name = "AdGuard Blocking Percentage" self._attr_name = "AdGuard Blocking Percentage"
self._attr_icon = ICON_STATISTICS self._attr_icon = ICON_PERCENTAGE
self._attr_state_class = SensorStateClass.MEASUREMENT self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = PERCENTAGE self._attr_native_unit_of_measurement = PERCENTAGE
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def native_value(self): def native_value(self) -> Optional[float]:
"""Return the state of the sensor.""" """Return the state of the sensor."""
stats = self.coordinator.statistics stats = self.coordinator.statistics
total_queries = stats.get("num_dns_queries", 0) total_queries = stats.get("num_dns_queries", 0)
@@ -117,16 +150,75 @@ class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
class AdGuardClientCountSensor(AdGuardBaseSensor): class AdGuardClientCountSensor(AdGuardBaseSensor):
"""Sensor to track active clients count.""" """Sensor to track active clients count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_clients_count" self._attr_unique_id = f"{api.host}_{api.port}_clients_count"
self._attr_name = "AdGuard Clients Count" self._attr_name = "AdGuard Clients Count"
self._attr_icon = ICON_STATISTICS self._attr_icon = ICON_CLIENTS
self._attr_state_class = SensorStateClass.MEASUREMENT self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "clients" self._attr_native_unit_of_measurement = "clients"
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property @property
def native_value(self): def native_value(self) -> Optional[int]:
"""Return the state of the sensor.""" """Return the state of the sensor."""
return len(self.coordinator.clients) return len(self.coordinator.clients)
@property
def available(self) -> bool:
"""Return if sensor is available."""
return self.coordinator.last_update_success
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
clients = self.coordinator.clients
protected_clients = sum(1 for c in clients.values() if c.get("filtering_enabled", True))
return {
"protected_clients": protected_clients,
"unprotected_clients": len(clients) - protected_clients,
"client_names": list(clients.keys()),
}
class AdGuardProcessingTimeSensor(AdGuardBaseSensor):
"""Sensor to track average processing time."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_avg_processing_time"
self._attr_name = "AdGuard Average Processing Time"
self._attr_icon = "mdi:speedometer"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = UnitOfTime.MILLISECONDS
self._attr_entity_category = EntityCategory.DIAGNOSTIC
self._attr_device_class = SensorDeviceClass.DURATION
@property
def native_value(self) -> Optional[float]:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
avg_time = stats.get("avg_processing_time", 0)
return round(avg_time, 2) if avg_time else 0
class AdGuardFilteringRulesSensor(AdGuardBaseSensor):
"""Sensor to track number of filtering rules."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_filtering_rules"
self._attr_name = "AdGuard Filtering Rules"
self._attr_icon = "mdi:filter"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "rules"
self._attr_entity_category = EntityCategory.DIAGNOSTIC
@property
def native_value(self) -> Optional[int]:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("filtering_rules_count", 0)

View File

@@ -7,7 +7,7 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from .api import AdGuardHomeAPI from .api import AdGuardHomeAPI, AdGuardHomeError
from .const import ( from .const import (
DOMAIN, DOMAIN,
BLOCKED_SERVICES, BLOCKED_SERVICES,
@@ -15,47 +15,101 @@ from .const import (
ATTR_SERVICES, ATTR_SERVICES,
ATTR_DURATION, ATTR_DURATION,
ATTR_CLIENTS, ATTR_CLIENTS,
ATTR_ENABLED,
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_REFRESH_DATA,
) )
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Service schemas
SCHEMA_BLOCK_SERVICES = vol.Schema({ SCHEMA_BLOCK_SERVICES = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string, vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
}) })
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({ SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
vol.Required(ATTR_DURATION): cv.positive_int, vol.Required(ATTR_DURATION): cv.positive_int,
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]), vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
}) })
SCHEMA_ADD_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("filtering_enabled", default=True): cv.boolean,
vol.Optional("safebrowsing_enabled", default=False): cv.boolean,
vol.Optional("parental_enabled", default=False): cv.boolean,
vol.Optional("safesearch_enabled", default=False): cv.boolean,
vol.Optional("use_global_blocked_services", default=True): cv.boolean,
vol.Optional("blocked_services", default=[]): vol.All(cv.ensure_list, [cv.string]),
})
SCHEMA_REMOVE_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
})
SCHEMA_REFRESH_DATA = vol.Schema({})
class AdGuardControlHubServices: class AdGuardControlHubServices:
"""Handle services for AdGuard Control Hub.""" """Handle services for AdGuard Control Hub."""
def __init__(self, hass: HomeAssistant): def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the services.""" """Initialize the services."""
self.hass = hass self.hass = hass
def register_services(self) -> None: def register_services(self) -> None:
"""Register all services.""" """Register all services."""
self.hass.services.register( _LOGGER.debug("Registering AdGuard Control Hub services")
DOMAIN, "block_services", self.block_services, schema=SCHEMA_BLOCK_SERVICES
) services = [
self.hass.services.register( (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
DOMAIN, "unblock_services", self.unblock_services, schema=SCHEMA_BLOCK_SERVICES (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
) (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
self.hass.services.register( (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
DOMAIN, "emergency_unblock", self.emergency_unblock, schema=SCHEMA_EMERGENCY_UNBLOCK (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
) (SERVICE_REFRESH_DATA, self.refresh_data, SCHEMA_REFRESH_DATA),
]
for service_name, service_func, schema in services:
if not self.hass.services.has_service(DOMAIN, service_name):
self.hass.services.register(DOMAIN, service_name, service_func, schema=schema)
_LOGGER.debug("Registered service: %s", service_name)
def unregister_services(self) -> None: def unregister_services(self) -> None:
"""Unregister all services.""" """Unregister all services."""
services = ["block_services", "unblock_services", "emergency_unblock"] _LOGGER.debug("Unregistering AdGuard Control Hub services")
for service in services: services = [
if self.hass.services.has_service(DOMAIN, service): SERVICE_BLOCK_SERVICES,
self.hass.services.remove(DOMAIN, service) SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_REFRESH_DATA,
]
for service_name in services:
if self.hass.services.has_service(DOMAIN, service_name):
self.hass.services.remove(DOMAIN, service_name)
_LOGGER.debug("Unregistered service: %s", service_name)
def _get_api_instances(self) -> list[AdGuardHomeAPI]:
"""Get all API instances."""
apis = []
for entry_data in self.hass.data.get(DOMAIN, {}).values():
if isinstance(entry_data, dict) and "api" in entry_data:
apis.append(entry_data["api"])
return apis
async def block_services(self, call: ServiceCall) -> None: async def block_services(self, call: ServiceCall) -> None:
"""Block services for a specific client.""" """Block services for a specific client."""
@@ -64,36 +118,62 @@ class AdGuardControlHubServices:
_LOGGER.info("Blocking services %s for client %s", services, client_name) _LOGGER.info("Blocking services %s for client %s", services, client_name)
for entry_data in self.hass.data[DOMAIN].values(): success_count = 0
api: AdGuardHomeAPI = entry_data["api"] for api in self._get_api_instances():
try: try:
client = await api.get_client_by_name(client_name) client = await api.get_client_by_name(client_name)
if client: if client:
current_blocked = client.get("blocked_services", {}) current_blocked = client.get("blocked_services", {})
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or [] if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked or []
updated_services = list(set(current_services + services)) updated_services = list(set(current_services + services))
await api.update_client_blocked_services(client_name, updated_services) await api.update_client_blocked_services(client_name, updated_services)
success_count += 1
_LOGGER.info("Successfully blocked services for %s", client_name) _LOGGER.info("Successfully blocked services for %s", client_name)
else:
_LOGGER.warning("Client %s not found", client_name)
except AdGuardHomeError as err:
_LOGGER.error("AdGuard error blocking services for %s: %s", client_name, err)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to block services for %s: %s", client_name, err) _LOGGER.exception("Unexpected error blocking services for %s: %s", client_name, err)
if success_count == 0:
_LOGGER.error("Failed to block services for %s on any instance", client_name)
async def unblock_services(self, call: ServiceCall) -> None: async def unblock_services(self, call: ServiceCall) -> None:
"""Unblock services for a specific client.""" """Unblock services for a specific client."""
client_name = call.data[ATTR_CLIENT_NAME] client_name = call.data[ATTR_CLIENT_NAME]
services = call.data[ATTR_SERVICES] services = call.data[ATTR_SERVICES]
for entry_data in self.hass.data[DOMAIN].values(): _LOGGER.info("Unblocking services %s for client %s", services, client_name)
api: AdGuardHomeAPI = entry_data["api"]
success_count = 0
for api in self._get_api_instances():
try: try:
client = await api.get_client_by_name(client_name) client = await api.get_client_by_name(client_name)
if client: if client:
current_blocked = client.get("blocked_services", {}) current_blocked = client.get("blocked_services", {})
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or [] if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked or []
updated_services = [s for s in current_services if s not in services] updated_services = [s for s in current_services if s not in services]
await api.update_client_blocked_services(client_name, updated_services) await api.update_client_blocked_services(client_name, updated_services)
success_count += 1
_LOGGER.info("Successfully unblocked services for %s", client_name) _LOGGER.info("Successfully unblocked services for %s", client_name)
else:
_LOGGER.warning("Client %s not found", client_name)
except AdGuardHomeError as err:
_LOGGER.error("AdGuard error unblocking services for %s: %s", client_name, err)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err) _LOGGER.exception("Unexpected error unblocking services for %s: %s", client_name, err)
if success_count == 0:
_LOGGER.error("Failed to unblock services for %s on any instance", client_name)
async def emergency_unblock(self, call: ServiceCall) -> None: async def emergency_unblock(self, call: ServiceCall) -> None:
"""Emergency unblock - temporarily disable protection.""" """Emergency unblock - temporarily disable protection."""
@@ -102,20 +182,95 @@ class AdGuardControlHubServices:
_LOGGER.warning("Emergency unblock activated for %s seconds", duration) _LOGGER.warning("Emergency unblock activated for %s seconds", duration)
for entry_data in self.hass.data[DOMAIN].values(): for api in self._get_api_instances():
api: AdGuardHomeAPI = entry_data["api"]
try: try:
if "all" in clients: if "all" in clients:
await api.set_protection(False) await api.set_protection(False)
_LOGGER.warning("Protection disabled for %s:%s", api.host, api.port)
# Re-enable after duration # Re-enable after duration
async def delayed_enable(): async def delayed_enable(api_instance: AdGuardHomeAPI):
await asyncio.sleep(duration) await asyncio.sleep(duration)
try: try:
await api.set_protection(True) await api_instance.set_protection(True)
_LOGGER.info("Emergency unblock expired - protection re-enabled") _LOGGER.info("Emergency unblock expired - protection re-enabled for %s:%s",
api_instance.host, api_instance.port)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to re-enable protection: %s", err) _LOGGER.error("Failed to re-enable protection for %s:%s: %s",
api_instance.host, api_instance.port, err)
asyncio.create_task(delayed_enable()) asyncio.create_task(delayed_enable(api))
else:
# Individual client emergency unblock
for client_name in clients:
if client_name == "all":
continue
try:
client = await api.get_client_by_name(client_name)
if client:
update_data = {
"name": client_name,
"data": {**client, "filtering_enabled": False}
}
await api.update_client(update_data)
_LOGGER.info("Emergency unblock applied to client %s", client_name)
except Exception as err:
_LOGGER.error("Failed to emergency unblock client %s: %s", client_name, err)
except AdGuardHomeError as err:
_LOGGER.error("AdGuard error during emergency unblock: %s", err)
except Exception as err: except Exception as err:
_LOGGER.error("Failed to execute emergency unblock: %s", err) _LOGGER.exception("Unexpected error during emergency unblock: %s", err)
async def add_client(self, call: ServiceCall) -> None:
"""Add a new client."""
client_data = dict(call.data)
_LOGGER.info("Adding new client: %s", client_data.get("name"))
success_count = 0
for api in self._get_api_instances():
try:
await api.add_client(client_data)
success_count += 1
_LOGGER.info("Successfully added client: %s", client_data.get("name"))
except AdGuardHomeError as err:
_LOGGER.error("AdGuard error adding client: %s", err)
except Exception as err:
_LOGGER.exception("Unexpected error adding client: %s", err)
if success_count == 0:
_LOGGER.error("Failed to add client %s on any instance", client_data.get("name"))
async def remove_client(self, call: ServiceCall) -> None:
"""Remove a client."""
client_name = call.data.get("name")
_LOGGER.info("Removing client: %s", client_name)
success_count = 0
for api in self._get_api_instances():
try:
await api.delete_client(client_name)
success_count += 1
_LOGGER.info("Successfully removed client: %s", client_name)
except AdGuardHomeError as err:
_LOGGER.error("AdGuard error removing client: %s", err)
except Exception as err:
_LOGGER.exception("Unexpected error removing client: %s", err)
if success_count == 0:
_LOGGER.error("Failed to remove client %s on any instance", client_name)
async def refresh_data(self, call: ServiceCall) -> None:
"""Refresh data for all coordinators."""
_LOGGER.info("Manually refreshing AdGuard Control Hub data")
for entry_data in self.hass.data.get(DOMAIN, {}).values():
if isinstance(entry_data, dict) and "coordinator" in entry_data:
coordinator = entry_data["coordinator"]
try:
await coordinator.async_request_refresh()
_LOGGER.debug("Refreshed coordinator data")
except Exception as err:
_LOGGER.error("Failed to refresh coordinator: %s", err)

View File

@@ -7,21 +7,22 @@
"data": { "data": {
"host": "Host", "host": "Host",
"port": "Port", "port": "Port",
"username": "Username", "username": "Username (optional)",
"password": "Password", "password": "Password (optional)",
"ssl": "Use SSL", "ssl": "Use SSL",
"verify_ssl": "Verify SSL Certificate" "verify_ssl": "Verify SSL Certificate"
} }
} }
}, },
"error": { "error": {
"cannot_connect": "Failed to connect to AdGuard Home", "cannot_connect": "Failed to connect to AdGuard Home. Please check the host and port.",
"invalid_auth": "Invalid username or password", "invalid_auth": "Invalid username or password. Please verify your credentials.",
"timeout": "Connection timeout", "invalid_host": "Invalid host format. Please enter a valid hostname or IP address.",
"unknown": "Unexpected error occurred" "timeout": "Connection timeout. Please check your network connection and try again.",
"unknown": "Unexpected error occurred. Please check your configuration and try again."
}, },
"abort": { "abort": {
"already_configured": "AdGuard Control Hub is already configured" "already_configured": "AdGuard Control Hub is already configured for this host and port"
} }
} }
} }

View File

@@ -1,15 +1,16 @@
"""Switch platform for AdGuard Control Hub integration.""" """Switch platform for AdGuard Control Hub integration."""
import logging import logging
from typing import Any from typing import Any, Optional
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity, SwitchDeviceClass
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity import EntityCategory
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AdGuardControlHubCoordinator from . import AdGuardControlHubCoordinator
from .api import AdGuardHomeAPI from .api import AdGuardHomeAPI, AdGuardHomeError
from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -30,13 +31,13 @@ async def async_setup_entry(
for client_name in coordinator.clients.keys(): for client_name in coordinator.clients.keys():
entities.append(AdGuardClientSwitch(coordinator, api, client_name)) entities.append(AdGuardClientSwitch(coordinator, api, client_name))
async_add_entities(entities) async_add_entities(entities, update_before_add=True)
class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity): class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity):
"""Base class for AdGuard switches.""" """Base class for AdGuard switches."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the switch.""" """Initialize the switch."""
super().__init__(coordinator) super().__init__(coordinator)
self.api = api self.api = api
@@ -45,20 +46,28 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity):
"name": f"AdGuard Control Hub ({api.host})", "name": f"AdGuard Control Hub ({api.host})",
"manufacturer": MANUFACTURER, "manufacturer": MANUFACTURER,
"model": "AdGuard Home", "model": "AdGuard Home",
"configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}",
} }
@property
def available(self) -> bool:
"""Return if switch is available."""
return self.coordinator.last_update_success
class AdGuardProtectionSwitch(AdGuardBaseSwitch): class AdGuardProtectionSwitch(AdGuardBaseSwitch):
"""Switch to control global AdGuard protection.""" """Switch to control global AdGuard protection."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None:
"""Initialize the switch.""" """Initialize the switch."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_protection" self._attr_unique_id = f"{api.host}_{api.port}_protection"
self._attr_name = "AdGuard Protection" self._attr_name = "AdGuard Protection"
self._attr_device_class = SwitchDeviceClass.SWITCH
self._attr_entity_category = EntityCategory.CONFIG
@property @property
def is_on(self) -> bool | None: def is_on(self) -> Optional[bool]:
"""Return true if protection is enabled.""" """Return true if protection is enabled."""
return self.coordinator.protection_status.get("protection_enabled", False) return self.coordinator.protection_status.get("protection_enabled", False)
@@ -67,23 +76,47 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
"""Return the icon for the switch.""" """Return the icon for the switch."""
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
@property
def available(self) -> bool:
"""Return if switch is available."""
return self.coordinator.last_update_success and bool(self.coordinator.protection_status)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
status = self.coordinator.protection_status
return {
"dns_port": status.get("dns_port", "N/A"),
"version": status.get("version", "N/A"),
"running": status.get("running", False),
"dns_addresses": status.get("dns_addresses", []),
}
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on AdGuard protection.""" """Turn on AdGuard protection."""
try: try:
await self.api.set_protection(True) await self.api.set_protection(True)
await self.coordinator.async_request_refresh() await self.coordinator.async_request_refresh()
except Exception as err: _LOGGER.info("AdGuard protection enabled")
except AdGuardHomeError as err:
_LOGGER.error("Failed to enable AdGuard protection: %s", err) _LOGGER.error("Failed to enable AdGuard protection: %s", err)
raise raise
except Exception as err:
_LOGGER.exception("Unexpected error enabling AdGuard protection")
raise
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off AdGuard protection.""" """Turn off AdGuard protection."""
try: try:
await self.api.set_protection(False) await self.api.set_protection(False)
await self.coordinator.async_request_refresh() await self.coordinator.async_request_refresh()
except Exception as err: _LOGGER.warning("AdGuard protection disabled")
except AdGuardHomeError as err:
_LOGGER.error("Failed to disable AdGuard protection: %s", err) _LOGGER.error("Failed to disable AdGuard protection: %s", err)
raise raise
except Exception as err:
_LOGGER.exception("Unexpected error disabling AdGuard protection")
raise
class AdGuardClientSwitch(AdGuardBaseSwitch): class AdGuardClientSwitch(AdGuardBaseSwitch):
@@ -94,20 +127,49 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
coordinator: AdGuardControlHubCoordinator, coordinator: AdGuardControlHubCoordinator,
api: AdGuardHomeAPI, api: AdGuardHomeAPI,
client_name: str, client_name: str,
): ) -> None:
"""Initialize the switch.""" """Initialize the switch."""
super().__init__(coordinator, api) super().__init__(coordinator, api)
self.client_name = client_name self.client_name = client_name
self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}" self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}"
self._attr_name = f"AdGuard {client_name}" self._attr_name = f"AdGuard {client_name}"
self._attr_icon = ICON_CLIENT self._attr_icon = ICON_CLIENT
self._attr_device_class = SwitchDeviceClass.SWITCH
self._attr_entity_category = EntityCategory.CONFIG
@property @property
def is_on(self) -> bool | None: def is_on(self) -> Optional[bool]:
"""Return true if client protection is enabled.""" """Return true if client protection is enabled."""
client = self.coordinator.clients.get(self.client_name, {}) client = self.coordinator.clients.get(self.client_name, {})
return client.get("filtering_enabled", True) return client.get("filtering_enabled", True)
@property
def available(self) -> bool:
"""Return if switch is available."""
return (
self.coordinator.last_update_success
and self.client_name in self.coordinator.clients
)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
client = self.coordinator.clients.get(self.client_name, {})
blocked_services = client.get("blocked_services", {})
if isinstance(blocked_services, dict):
blocked_list = blocked_services.get("ids", [])
else:
blocked_list = blocked_services or []
return {
"client_ids": client.get("ids", []),
"safebrowsing_enabled": client.get("safebrowsing_enabled", False),
"parental_enabled": client.get("parental_enabled", False),
"safesearch_enabled": client.get("safesearch_enabled", False),
"blocked_services_count": len(blocked_list),
"blocked_services": blocked_list,
}
async def async_turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Enable protection for this client.""" """Enable protection for this client."""
try: try:
@@ -119,9 +181,15 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
} }
await self.api.update_client(update_data) await self.api.update_client(update_data)
await self.coordinator.async_request_refresh() await self.coordinator.async_request_refresh()
except Exception as err: _LOGGER.info("Enabled protection for client: %s", self.client_name)
else:
_LOGGER.error("Client not found: %s", self.client_name)
except AdGuardHomeError as err:
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err) _LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
raise raise
except Exception as err:
_LOGGER.exception("Unexpected error enabling protection for %s", self.client_name)
raise
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Disable protection for this client.""" """Disable protection for this client."""
@@ -134,6 +202,12 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
} }
await self.api.update_client(update_data) await self.api.update_client(update_data)
await self.coordinator.async_request_refresh() await self.coordinator.async_request_refresh()
except Exception as err: _LOGGER.info("Disabled protection for client: %s", self.client_name)
else:
_LOGGER.error("Client not found: %s", self.client_name)
except AdGuardHomeError as err:
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err) _LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
raise raise
except Exception as err:
_LOGGER.exception("Unexpected error disabling protection for %s", self.client_name)
raise

View File

@@ -16,7 +16,7 @@ addopts = [
"--cov=custom_components.adguard_hub", "--cov=custom_components.adguard_hub",
"--cov-report=term-missing", "--cov-report=term-missing",
"--cov-report=html", "--cov-report=html",
"--cov-fail-under=80", "--cov-fail-under=60",
"--asyncio-mode=auto", "--asyncio-mode=auto",
"-v" "-v"
] ]

View File

@@ -6,7 +6,7 @@ from homeassistant.config_entries import ConfigEntry, SOURCE_USER
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME
from custom_components.adguard_hub.api import AdGuardHomeAPI from custom_components.adguard_hub.api import AdGuardHomeAPI
from custom_components.adguard_hub.const import DOMAIN from custom_components.adguard_hub.const import DOMAIN, CONF_SSL, CONF_VERIFY_SSL
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -28,11 +28,15 @@ def mock_config_entry():
CONF_PORT: 3000, CONF_PORT: 3000,
CONF_USERNAME: "admin", CONF_USERNAME: "admin",
CONF_PASSWORD: "password", CONF_PASSWORD: "password",
CONF_SSL: False,
CONF_VERIFY_SSL: True,
}, },
options={}, options={},
source=SOURCE_USER, source=SOURCE_USER,
entry_id="test_entry_id", entry_id="test_entry_id",
unique_id="192.168.1.100:3000", unique_id="192.168.1.100:3000",
discovery_keys=set(), # Added required parameter
subentries_data={}, # Added required parameter
) )
@@ -42,39 +46,140 @@ def mock_api():
api = MagicMock(spec=AdGuardHomeAPI) api = MagicMock(spec=AdGuardHomeAPI)
api.host = "192.168.1.100" api.host = "192.168.1.100"
api.port = 3000 api.port = 3000
api.ssl = False
api.verify_ssl = True
# Mock successful connection
api.test_connection = AsyncMock(return_value=True) api.test_connection = AsyncMock(return_value=True)
# Mock status response
api.get_status = AsyncMock(return_value={ api.get_status = AsyncMock(return_value={
"protection_enabled": True, "protection_enabled": True,
"version": "v0.107.0", "version": "v0.107.0",
"dns_port": 53, "dns_port": 53,
"running": True, "running": True,
"dns_addresses": ["192.168.1.100:53"],
"bootstrap_dns": ["1.1.1.1", "8.8.8.8"],
"upstream_dns": ["1.1.1.1", "8.8.8.8", "1.0.0.1", "8.8.4.4"],
"safebrowsing_enabled": True,
"parental_enabled": False,
"safesearch_enabled": False,
"dhcp_available": False,
}) })
# Mock clients response
api.get_clients = AsyncMock(return_value={ api.get_clients = AsyncMock(return_value={
"clients": [ "clients": [
{ {
"name": "test_client", "name": "test_client",
"ids": ["192.168.1.50"], "ids": ["192.168.1.50"],
"filtering_enabled": True, "filtering_enabled": True,
"blocked_services": {"ids": ["youtube"]}, "safebrowsing_enabled": False,
"parental_enabled": False,
"safesearch_enabled": False,
"use_global_settings": True,
"use_global_blocked_services": True,
"blocked_services": {"ids": ["youtube", "gaming"]},
},
{
"name": "test_client_2",
"ids": ["192.168.1.51"],
"filtering_enabled": False,
"safebrowsing_enabled": True,
"parental_enabled": True,
"safesearch_enabled": False,
"use_global_settings": False,
"blocked_services": {"ids": ["netflix"]},
} }
] ]
}) })
# Mock statistics response
api.get_statistics = AsyncMock(return_value={ api.get_statistics = AsyncMock(return_value={
"num_dns_queries": 10000, "num_dns_queries": 10000,
"num_blocked_filtering": 1500, "num_blocked_filtering": 1500,
"num_dns_queries_today": 5000,
"num_blocked_filtering_today": 750,
"num_replaced_safebrowsing": 50,
"num_replaced_parental": 25,
"num_replaced_safesearch": 10,
"avg_processing_time": 2.5, "avg_processing_time": 2.5,
"filtering_rules_count": 75000, "filtering_rules_count": 75000,
}) })
# Mock client operations
api.get_client_by_name = AsyncMock(return_value={ api.get_client_by_name = AsyncMock(return_value={
"name": "test_client", "name": "test_client",
"ids": ["192.168.1.50"], "ids": ["192.168.1.50"],
"filtering_enabled": True, "filtering_enabled": True,
"blocked_services": {"ids": ["youtube"]}, "blocked_services": {"ids": ["youtube"]},
}) })
api.add_client = AsyncMock(return_value={"success": True})
api.update_client = AsyncMock(return_value={"success": True})
api.delete_client = AsyncMock(return_value={"success": True})
api.update_client_blocked_services = AsyncMock(return_value={"success": True})
api.set_protection = AsyncMock(return_value={"success": True}) api.set_protection = AsyncMock(return_value={"success": True})
api.close = AsyncMock(return_value=None)
return api return api
@pytest.fixture
def mock_coordinator(mock_api):
"""Mock coordinator with test data."""
from custom_components.adguard_hub import AdGuardControlHubCoordinator
coordinator = MagicMock(spec=AdGuardControlHubCoordinator)
coordinator.last_update_success = True
coordinator.api = mock_api
# Mock clients data
coordinator.clients = {
"test_client": {
"name": "test_client",
"ids": ["192.168.1.50"],
"filtering_enabled": True,
"blocked_services": {"ids": ["youtube"]},
},
"test_client_2": {
"name": "test_client_2",
"ids": ["192.168.1.51"],
"filtering_enabled": False,
"blocked_services": {"ids": ["netflix"]},
}
}
# Mock statistics data
coordinator.statistics = {
"num_dns_queries": 10000,
"num_blocked_filtering": 1500,
"avg_processing_time": 2.5,
"filtering_rules_count": 75000,
}
# Mock protection status
coordinator.protection_status = {
"protection_enabled": True,
"version": "v0.107.0",
"dns_port": 53,
"running": True,
"safebrowsing_enabled": True,
"parental_enabled": False,
"safesearch_enabled": False,
}
coordinator.data = {
"clients": coordinator.clients,
"statistics": coordinator.statistics,
"status": coordinator.protection_status,
}
coordinator.async_request_refresh = AsyncMock()
return coordinator
@pytest.fixture @pytest.fixture
def mock_hass(): def mock_hass():
"""Mock Home Assistant instance.""" """Mock Home Assistant instance."""
@@ -88,3 +193,25 @@ def mock_hass():
hass.config_entries.async_forward_entry_setups = AsyncMock(return_value=True) hass.config_entries.async_forward_entry_setups = AsyncMock(return_value=True)
hass.config_entries.async_unload_platforms = AsyncMock(return_value=True) hass.config_entries.async_unload_platforms = AsyncMock(return_value=True)
return hass return hass
@pytest.fixture
def mock_aiohttp_session():
"""Mock aiohttp session."""
session = MagicMock()
response = MagicMock()
response.raise_for_status = MagicMock()
response.json = AsyncMock(return_value={"status": "ok"})
response.text = AsyncMock(return_value="OK")
response.status = 200
response.content_length = 100
# Mock async context manager
context_manager = MagicMock()
context_manager.__aenter__ = AsyncMock(return_value=response)
context_manager.__aexit__ = AsyncMock(return_value=None)
session.request = MagicMock(return_value=context_manager)
session.close = AsyncMock()
return session

View File

@@ -1,7 +1,16 @@
"""Test API functionality.""" """Test API functionality."""
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock, patch
from custom_components.adguard_hub.api import AdGuardHomeAPI from aiohttp import ClientError, ClientTimeout
from custom_components.adguard_hub.api import (
AdGuardHomeAPI,
AdGuardHomeError,
AdGuardConnectionError,
AdGuardAuthError,
AdGuardNotFoundError,
AdGuardTimeoutError,
)
class TestAdGuardHomeAPI: class TestAdGuardHomeAPI:
@@ -24,6 +33,17 @@ class TestAdGuardHomeAPI:
assert api.ssl is True assert api.ssl is True
assert api.base_url == "https://192.168.1.100:3000" assert api.base_url == "https://192.168.1.100:3000"
def test_api_initialization_defaults(self):
"""Test API initialization with defaults."""
api = AdGuardHomeAPI(host="192.168.1.100")
assert api.host == "192.168.1.100"
assert api.port == 3000
assert api.username is None
assert api.password is None
assert api.ssl is False
assert api.base_url == "http://192.168.1.100:3000"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_context_manager(self): async def test_api_context_manager(self):
"""Test API as async context manager.""" """Test API as async context manager."""
@@ -33,21 +53,236 @@ class TestAdGuardHomeAPI:
assert api.port == 3000 assert api.port == 3000
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_test_connection_success(self): async def test_test_connection_success(self, mock_aiohttp_session):
"""Test successful connection test.""" """Test successful connection test."""
session = MagicMock() mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
response = MagicMock() return_value={"protection_enabled": True}
response.status = 200 )
response.json = AsyncMock(return_value={"protection_enabled": True})
response.raise_for_status = MagicMock()
response.content_length = 100
context_manager = MagicMock() api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
context_manager.__aenter__ = AsyncMock(return_value=response)
context_manager.__aexit__ = AsyncMock(return_value=None)
session.request = MagicMock(return_value=context_manager)
api = AdGuardHomeAPI(host="192.168.1.100", session=session)
result = await api.test_connection() result = await api.test_connection()
assert result is True assert result is True
mock_aiohttp_session.request.assert_called()
@pytest.mark.asyncio
async def test_test_connection_failure(self, mock_aiohttp_session):
"""Test failed connection test."""
mock_aiohttp_session.request.side_effect = ClientError("Connection failed")
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
result = await api.test_connection()
assert result is False
@pytest.mark.asyncio
async def test_get_status_success(self, mock_aiohttp_session):
"""Test successful status retrieval."""
expected_status = {
"protection_enabled": True,
"version": "v0.107.0",
"running": True,
}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value=expected_status
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
status = await api.get_status()
assert status == expected_status
@pytest.mark.asyncio
async def test_get_clients_success(self, mock_aiohttp_session):
"""Test successful clients retrieval."""
expected_clients = {
"clients": [
{"name": "client1", "ids": ["192.168.1.50"]},
{"name": "client2", "ids": ["192.168.1.51"]},
]
}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value=expected_clients
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
clients = await api.get_clients()
assert clients == expected_clients
@pytest.mark.asyncio
async def test_get_statistics_success(self, mock_aiohttp_session):
"""Test successful statistics retrieval."""
expected_stats = {
"num_dns_queries": 10000,
"num_blocked_filtering": 1500,
}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value=expected_stats
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
stats = await api.get_statistics()
assert stats == expected_stats
@pytest.mark.asyncio
async def test_set_protection_enable(self, mock_aiohttp_session):
"""Test enabling protection."""
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value={"success": True}
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
result = await api.set_protection(True)
assert result == {"success": True}
@pytest.mark.asyncio
async def test_add_client_success(self, mock_aiohttp_session):
"""Test successful client addition."""
client_data = {
"name": "test_client",
"ids": ["192.168.1.100"],
}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value={"success": True}
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
result = await api.add_client(client_data)
assert result == {"success": True}
@pytest.mark.asyncio
async def test_add_client_missing_name(self, mock_aiohttp_session):
"""Test client addition with missing name."""
client_data = {"ids": ["192.168.1.100"]}
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(ValueError, match="Client name is required"):
await api.add_client(client_data)
@pytest.mark.asyncio
async def test_add_client_missing_ids(self, mock_aiohttp_session):
"""Test client addition with missing IDs."""
client_data = {"name": "test_client"}
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(ValueError, match="Client IDs are required"):
await api.add_client(client_data)
@pytest.mark.asyncio
async def test_get_client_by_name_found(self, mock_aiohttp_session):
"""Test finding client by name."""
clients_data = {
"clients": [
{"name": "test_client", "ids": ["192.168.1.50"]},
{"name": "other_client", "ids": ["192.168.1.51"]},
]
}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value=clients_data
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
client = await api.get_client_by_name("test_client")
assert client == {"name": "test_client", "ids": ["192.168.1.50"]}
@pytest.mark.asyncio
async def test_get_client_by_name_not_found(self, mock_aiohttp_session):
"""Test client not found by name."""
clients_data = {"clients": []}
mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock(
return_value=clients_data
)
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
client = await api.get_client_by_name("nonexistent_client")
assert client is None
@pytest.mark.asyncio
async def test_update_client_blocked_services_client_not_found(self, mock_aiohttp_session):
"""Test blocked services update with client not found."""
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
api.get_client_by_name = AsyncMock(return_value=None)
with pytest.raises(AdGuardNotFoundError, match="Client 'nonexistent' not found"):
await api.update_client_blocked_services("nonexistent", ["youtube"])
@pytest.mark.asyncio
async def test_auth_error_handling(self, mock_aiohttp_session):
"""Test 401 authentication error handling."""
mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value
mock_response.status = 401
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(AdGuardAuthError, match="Authentication failed"):
await api.get_status()
@pytest.mark.asyncio
async def test_not_found_error_handling(self, mock_aiohttp_session):
"""Test 404 not found error handling."""
mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value
mock_response.status = 404
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(AdGuardNotFoundError):
await api.get_status()
@pytest.mark.asyncio
async def test_server_error_handling(self, mock_aiohttp_session):
"""Test 500 server error handling."""
mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value
mock_response.status = 500
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(AdGuardConnectionError, match="Server error 500"):
await api.get_status()
@pytest.mark.asyncio
async def test_client_error_handling(self, mock_aiohttp_session):
"""Test client error handling."""
mock_aiohttp_session.request.side_effect = ClientError("Client error")
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
with pytest.raises(AdGuardConnectionError, match="Client error"):
await api.get_status()
@pytest.mark.asyncio
async def test_empty_response_handling(self, mock_aiohttp_session):
"""Test empty response handling."""
mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value
mock_response.status = 204
mock_response.content_length = 0
api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session)
result = await api._request("POST", "/control/protection", {"enabled": True})
assert result == {}
@pytest.mark.asyncio
async def test_close_session(self):
"""Test closing API session."""
api = AdGuardHomeAPI(host="192.168.1.100")
# Create session
async with api:
assert api._session is not None
# Close session
await api.close()

View File

@@ -1,9 +1,16 @@
"""Test the complete AdGuard Control Hub integration.""" """Test the complete AdGuard Control Hub integration."""
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers.update_coordinator import UpdateFailed
from custom_components.adguard_hub import async_setup_entry, async_unload_entry from custom_components.adguard_hub import (
async_setup_entry,
async_unload_entry,
AdGuardControlHubCoordinator,
)
from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError
from custom_components.adguard_hub.const import DOMAIN from custom_components.adguard_hub.const import DOMAIN
@@ -13,28 +20,72 @@ class TestIntegrationSetup:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_entry_success(self, mock_hass, mock_config_entry, mock_api): async def test_setup_entry_success(self, mock_hass, mock_config_entry, mock_api):
"""Test successful setup of config entry.""" """Test successful setup of config entry."""
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"): with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession") as mock_session:
result = await async_setup_entry(mock_hass, mock_config_entry) # Mock the coordinator's first refresh
with patch("custom_components.adguard_hub.AdGuardControlHubCoordinator.async_config_entry_first_refresh", new=AsyncMock()):
result = await async_setup_entry(mock_hass, mock_config_entry)
assert result is True assert result is True
assert DOMAIN in mock_hass.data assert DOMAIN in mock_hass.data
assert mock_config_entry.entry_id in mock_hass.data[DOMAIN] assert mock_config_entry.entry_id in mock_hass.data[DOMAIN]
assert "coordinator" in mock_hass.data[DOMAIN][mock_config_entry.entry_id]
assert "api" in mock_hass.data[DOMAIN][mock_config_entry.entry_id]
# Verify platforms setup
mock_hass.config_entries.async_forward_entry_setups.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_entry_connection_failure(self, mock_hass, mock_config_entry): async def test_setup_entry_connection_failure(self, mock_hass, mock_config_entry):
"""Test setup failure due to connection error.""" """Test setup failure due to connection error."""
mock_api = MagicMock() mock_api = MagicMock()
mock_api.test_connection = pytest.AsyncMock(return_value=False) mock_api.test_connection = AsyncMock(return_value=False)
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"): with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"):
with pytest.raises(ConfigEntryNotReady): with pytest.raises(ConfigEntryNotReady, match="Unable to connect to AdGuard Home"):
await async_setup_entry(mock_hass, mock_config_entry) await async_setup_entry(mock_hass, mock_config_entry)
@pytest.mark.asyncio
async def test_setup_entry_api_error(self, mock_hass, mock_config_entry):
"""Test setup failure due to API error."""
mock_api = MagicMock()
mock_api.test_connection = AsyncMock(side_effect=AdGuardAuthError("Auth failed"))
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"):
with pytest.raises(ConfigEntryNotReady, match="Unable to connect"):
await async_setup_entry(mock_hass, mock_config_entry)
@pytest.mark.asyncio
async def test_setup_entry_coordinator_failure(self, mock_hass, mock_config_entry, mock_api):
"""Test setup failure due to coordinator refresh error."""
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"), patch.object(AdGuardControlHubCoordinator, "async_config_entry_first_refresh",
side_effect=UpdateFailed("Refresh failed")):
with pytest.raises(ConfigEntryNotReady, match="Failed to fetch initial data"):
await async_setup_entry(mock_hass, mock_config_entry)
@pytest.mark.asyncio
async def test_setup_entry_platform_failure(self, mock_hass, mock_config_entry, mock_api):
"""Test setup failure due to platform setup error."""
mock_hass.config_entries.async_forward_entry_setups = AsyncMock(
side_effect=Exception("Platform setup failed")
)
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"), patch.object(AdGuardControlHubCoordinator, "async_config_entry_first_refresh",
new=AsyncMock()):
with pytest.raises(ConfigEntryNotReady, match="Failed to set up platforms"):
await async_setup_entry(mock_hass, mock_config_entry)
# Verify cleanup
assert mock_config_entry.entry_id not in mock_hass.data.get(DOMAIN, {})
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unload_entry_success(self, mock_hass, mock_config_entry): async def test_unload_entry_success(self, mock_hass, mock_config_entry):
"""Test successful unloading of config entry.""" """Test successful unloading of config entry."""
# Set up initial data
mock_hass.data[DOMAIN] = { mock_hass.data[DOMAIN] = {
mock_config_entry.entry_id: { mock_config_entry.entry_id: {
"coordinator": MagicMock(), "coordinator": MagicMock(),
@@ -46,3 +97,287 @@ class TestIntegrationSetup:
assert result is True assert result is True
assert mock_config_entry.entry_id not in mock_hass.data[DOMAIN] assert mock_config_entry.entry_id not in mock_hass.data[DOMAIN]
mock_hass.config_entries.async_unload_platforms.assert_called_once()
@pytest.mark.asyncio
async def test_unload_entry_last_instance(self, mock_hass, mock_config_entry):
"""Test unloading last config entry unregisters services."""
# Set up services
mock_services = MagicMock()
mock_services.unregister_services = MagicMock()
mock_hass.data[f"{DOMAIN}_services"] = mock_services
mock_hass.data[DOMAIN] = {
mock_config_entry.entry_id: {
"coordinator": MagicMock(),
"api": MagicMock(),
}
}
result = await async_unload_entry(mock_hass, mock_config_entry)
assert result is True
assert f"{DOMAIN}_services" not in mock_hass.data
assert DOMAIN not in mock_hass.data
mock_services.unregister_services.assert_called_once()
class TestCoordinator:
"""Test the data update coordinator."""
def test_coordinator_initialization(self, mock_hass, mock_api):
"""Test coordinator initialization."""
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
assert coordinator.api == mock_api
assert coordinator.name == f"{DOMAIN}_coordinator"
@pytest.mark.asyncio
async def test_coordinator_update_success(self, mock_hass, mock_api):
"""Test successful coordinator data update."""
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
data = await coordinator._async_update_data()
assert "clients" in data
assert "statistics" in data
assert "status" in data
assert "test_client" in data["clients"]
assert data["statistics"]["num_dns_queries"] == 10000
assert data["status"]["protection_enabled"] is True
@pytest.mark.asyncio
async def test_coordinator_update_partial_failure(self, mock_hass, mock_api):
"""Test coordinator update with partial API failures."""
# Make one API call fail
mock_api.get_clients = AsyncMock(side_effect=Exception("Client fetch failed"))
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
data = await coordinator._async_update_data()
# Should still return data from successful calls
assert "clients" in data
assert "statistics" in data
assert "status" in data
assert data["statistics"]["num_dns_queries"] == 10000
@pytest.mark.asyncio
async def test_coordinator_update_connection_error(self, mock_hass, mock_api):
"""Test coordinator update with connection error."""
mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
mock_api.get_clients = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
mock_api.get_statistics = AsyncMock(side_effect=AdGuardConnectionError("Connection failed"))
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
with pytest.raises(UpdateFailed, match="Connection error to AdGuard Home"):
await coordinator._async_update_data()
@pytest.mark.asyncio
async def test_coordinator_update_unexpected_error(self, mock_hass, mock_api):
"""Test coordinator update with unexpected error."""
mock_api.get_status = AsyncMock(side_effect=Exception("Unexpected error"))
mock_api.get_clients = AsyncMock(side_effect=Exception("Unexpected error"))
mock_api.get_statistics = AsyncMock(side_effect=Exception("Unexpected error"))
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
with pytest.raises(UpdateFailed, match="Error communicating with AdGuard Control Hub"):
await coordinator._async_update_data()
def test_coordinator_properties(self, mock_hass, mock_api):
"""Test coordinator properties."""
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
# Set test data
test_clients = {"client1": {"name": "client1"}}
test_stats = {"num_dns_queries": 5000}
test_status = {"protection_enabled": False}
coordinator._clients = test_clients
coordinator._statistics = test_stats
coordinator._protection_status = test_status
assert coordinator.clients == test_clients
assert coordinator.statistics == test_stats
assert coordinator.protection_status == test_status
def test_coordinator_properties_empty_data(self, mock_hass, mock_api):
"""Test coordinator properties with empty data."""
coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api)
# Properties should return empty containers, not None
assert coordinator.clients == {}
assert coordinator.statistics == {}
assert coordinator.protection_status == {}
class TestServices:
"""Test service functionality."""
def test_services_registration(self, mock_hass):
"""Test that services are properly registered."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
services = AdGuardControlHubServices(mock_hass)
services.register_services()
# Verify services registration was called
assert mock_hass.services.register.called
# Verify correct number of service registrations
expected_call_count = 6 # block_services, unblock_services, emergency_unblock, add_client, remove_client, refresh_data
assert mock_hass.services.register.call_count == expected_call_count
def test_services_unregistration(self, mock_hass):
"""Test that services are properly unregistered."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
# Mock service existence
mock_hass.services.has_service.return_value = True
services = AdGuardControlHubServices(mock_hass)
services.unregister_services()
# Verify correct number of service removals
expected_call_count = 6
assert mock_hass.services.remove.call_count == expected_call_count
@pytest.mark.asyncio
async def test_block_services_success(self, mock_hass, mock_api):
"""Test successful service blocking."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
mock_hass.data[DOMAIN] = {
"entry_id": {"api": mock_api}
}
services = AdGuardControlHubServices(mock_hass)
call = MagicMock()
call.data = {
"client_name": "test_client",
"services": ["youtube", "netflix"]
}
await services.block_services(call)
mock_api.get_client_by_name.assert_called_once_with("test_client")
mock_api.update_client_blocked_services.assert_called_once()
@pytest.mark.asyncio
async def test_unblock_services_success(self, mock_hass, mock_api):
"""Test successful service unblocking."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
mock_hass.data[DOMAIN] = {
"entry_id": {"api": mock_api}
}
services = AdGuardControlHubServices(mock_hass)
call = MagicMock()
call.data = {
"client_name": "test_client",
"services": ["youtube"]
}
await services.unblock_services(call)
mock_api.get_client_by_name.assert_called_once_with("test_client")
mock_api.update_client_blocked_services.assert_called_once()
@pytest.mark.asyncio
async def test_emergency_unblock_global(self, mock_hass, mock_api):
"""Test emergency unblock for all clients."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
mock_hass.data[DOMAIN] = {
"entry_id": {"api": mock_api}
}
services = AdGuardControlHubServices(mock_hass)
call = MagicMock()
call.data = {
"duration": 300,
"clients": ["all"]
}
await services.emergency_unblock(call)
mock_api.set_protection.assert_called_once_with(False)
@pytest.mark.asyncio
async def test_refresh_data_success(self, mock_hass, mock_coordinator):
"""Test successful data refresh."""
from custom_components.adguard_hub.services import AdGuardControlHubServices
mock_hass.data[DOMAIN] = {
"entry_id": {"coordinator": mock_coordinator}
}
services = AdGuardControlHubServices(mock_hass)
call = MagicMock()
call.data = {}
await services.refresh_data(call)
mock_coordinator.async_request_refresh.assert_called_once()
class TestConstants:
"""Test constant definitions."""
def test_blocked_services_constants(self):
"""Test that blocked services are properly defined."""
from custom_components.adguard_hub.const import BLOCKED_SERVICES
required_services = ["youtube", "netflix", "gaming", "facebook"]
for service in required_services:
assert service in BLOCKED_SERVICES
assert isinstance(BLOCKED_SERVICES[service], str)
assert len(BLOCKED_SERVICES[service]) > 0
def test_api_endpoints_constants(self):
"""Test that API endpoints are properly defined."""
from custom_components.adguard_hub.const import API_ENDPOINTS
required_endpoints = [
"status", "clients", "stats", "protection",
"clients_add", "clients_update", "clients_delete"
]
for endpoint in required_endpoints:
assert endpoint in API_ENDPOINTS
assert API_ENDPOINTS[endpoint].startswith("/")
def test_platform_constants(self):
"""Test platform constants."""
from custom_components.adguard_hub.const import PLATFORMS
expected_platforms = ["switch", "binary_sensor", "sensor"]
assert PLATFORMS == expected_platforms
def test_service_constants(self):
"""Test service name constants."""
from custom_components.adguard_hub.const import (
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_REFRESH_DATA,
)
services = [
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_REFRESH_DATA,
]
for service in services:
assert isinstance(service, str)
assert len(service) > 0
assert "_" in service # Snake case format