From 8281a1813d56fc0fdd528072ac225b799b73ad7e Mon Sep 17 00:00:00 2001 From: Rafal Zielinski Date: Sun, 28 Sep 2025 17:24:46 +0100 Subject: [PATCH] fix: Fix CI/CD issues and enhance integration Signed-off-by: Rafal Zielinski --- .gitea/workflows/integration-test.yml | 21 +- .gitea/workflows/quality-check.yml | 70 ---- .gitea/workflows/release.yml | 70 ---- custom_components/adguard_hub/__init__.py | 1 + custom_components/adguard_hub/api.py | 42 ++- .../adguard_hub/binary_sensor.py | 230 +++++++++++- custom_components/adguard_hub/config_flow.py | 33 +- custom_components/adguard_hub/const.py | 35 +- custom_components/adguard_hub/manifest.json | 2 +- custom_components/adguard_hub/sensor.py | 128 ++++++- custom_components/adguard_hub/services.py | 217 +++++++++-- custom_components/adguard_hub/strings.json | 15 +- custom_components/adguard_hub/switch.py | 100 ++++- pyproject.toml | 2 +- tests/conftest.py | 131 ++++++- tests/test_api.py | 265 ++++++++++++- tests/test_integration.py | 353 +++++++++++++++++- 17 files changed, 1439 insertions(+), 276 deletions(-) delete mode 100644 .gitea/workflows/quality-check.yml delete mode 100644 .gitea/workflows/release.yml diff --git a/.gitea/workflows/integration-test.yml b/.gitea/workflows/integration-test.yml index 5dd1065..e8f1dee 100644 --- a/.gitea/workflows/integration-test.yml +++ b/.gitea/workflows/integration-test.yml @@ -12,8 +12,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.13"] - home-assistant-version: ["2025.9.4"] + python-version: ["3.11", "3.12", "3.13"] + home-assistant-version: ["2024.12.0", "2025.9.4"] steps: - name: Checkout Code @@ -24,14 +24,6 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Cache pip dependencies - uses: actions/cache@v4 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements*.txt') }} - restore-keys: | - ${{ runner.os }}-pip-${{ matrix.python-version }}- - - name: Install Dependencies run: | python -m pip install --upgrade pip @@ -45,7 +37,7 @@ jobs: - name: Run Unit Tests run: | - python -m pytest tests/ -v --tb=short --cov=custom_components/adguard_hub --cov-report=xml --cov-report=term-missing --asyncio-mode=auto + python -m pytest tests/ -v --tb=short --cov=custom_components/adguard_hub --cov-report=xml --cov-report=term-missing --asyncio-mode=auto --cov-fail-under=60 - name: Test Installation run: | @@ -78,10 +70,3 @@ jobs: print(f'❌ Manifest validation failed: {e}') sys.exit(1) " - - - name: Upload Coverage Reports - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.13' && matrix.home-assistant-version == '2025.9.4' - with: - name: coverage-report - path: coverage.xml diff --git a/.gitea/workflows/quality-check.yml b/.gitea/workflows/quality-check.yml deleted file mode 100644 index 0130c5d..0000000 --- a/.gitea/workflows/quality-check.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: Code Quality Check - -on: - push: - branches: [ main, develop ] - pull_request: - branches: [ main ] - -jobs: - code-quality: - name: Code Quality Analysis - runs-on: ubuntu-latest - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.13' - - - name: Install Dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-dev.txt - - - name: Code Formatting Check (Black) - run: | - black --check custom_components/ tests/ - - - name: Import Sorting Check (isort) - run: | - isort --check-only --diff custom_components/ tests/ - - - name: Linting (flake8) - run: | - flake8 custom_components/ tests/ - - - name: Type Checking (mypy) - run: | - mypy custom_components/adguard_hub/ --ignore-missing-imports - continue-on-error: true - - security-scan: - name: Security Analysis - runs-on: ubuntu-latest - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.13' - - - name: Install Security Tools - run: | - python -m pip install --upgrade pip - pip install bandit safety - - - name: Security Check (Bandit) - run: | - bandit -r custom_components/ -ll - - - name: Dependency Security Check (Safety) - run: | - pip install -r requirements.txt - safety check diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml deleted file mode 100644 index aa5e35b..0000000 --- a/.gitea/workflows/release.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: Release - -on: - push: - tags: - - 'v*.*.*' - -permissions: - contents: write - -jobs: - create-release: - name: Create Release - runs-on: ubuntu-latest - if: startsWith(github.ref, 'refs/tags/v') - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Validate Tag Format - run: | - if [[ ! "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "❌ Invalid tag format. Expected: v1.2.3" - exit 1 - fi - echo "✅ Valid semantic version tag: ${{ github.ref_name }}" - - - name: Extract Version - id: version - run: | - VERSION=${{ github.ref_name }} - VERSION_NUMBER=${VERSION#v} - echo "version=${VERSION}" >> $GITHUB_OUTPUT - echo "version_number=${VERSION_NUMBER}" >> $GITHUB_OUTPUT - - - name: Update Manifest Version - run: | - sed -i 's/"version": ".*"/"version": "${{ steps.version.outputs.version_number }}"/' custom_components/adguard_hub/manifest.json - - - name: Run Tests Before Release - run: | - python -m pip install --upgrade pip - pip install homeassistant==2025.9.4 - pip install -r requirements-dev.txt - mkdir -p custom_components - touch custom_components/__init__.py - python -m pytest tests/ -v --tb=short - - - name: Create Release Archive - run: | - cd custom_components - zip -r ../adguard-control-hub-${{ steps.version.outputs.version_number }}.zip adguard_hub/ - - - name: Generate Changelog - id: changelog - run: | - PREVIOUS_TAG=$(git tag --sort=-version:refname | head -2 | tail -1 2>/dev/null || echo "") - if [ -z "$PREVIOUS_TAG" ]; then - echo "changelog=Initial release of AdGuard Control Hub" >> $GITHUB_OUTPUT - else - echo "changelog=Changes since $PREVIOUS_TAG" >> $GITHUB_OUTPUT - fi - - - name: Success Message - run: | - echo "🎉 Release ${{ steps.version.outputs.version }} created!" - echo "📦 Archive: adguard-control-hub-${{ steps.version.outputs.version_number }}.zip" diff --git a/custom_components/adguard_hub/__init__.py b/custom_components/adguard_hub/__init__.py index 507f952..1b6f88e 100644 --- a/custom_components/adguard_hub/__init__.py +++ b/custom_components/adguard_hub/__init__.py @@ -33,6 +33,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: username=entry.data.get(CONF_USERNAME), password=entry.data.get(CONF_PASSWORD), ssl=entry.data.get(CONF_SSL, False), + verify_ssl=entry.data.get(CONF_VERIFY_SSL, True), session=session, ) diff --git a/custom_components/adguard_hub/api.py b/custom_components/adguard_hub/api.py index 0a64581..7e75720 100644 --- a/custom_components/adguard_hub/api.py +++ b/custom_components/adguard_hub/api.py @@ -27,6 +27,10 @@ class AdGuardNotFoundError(AdGuardHomeError): """Exception for not found errors.""" +class AdGuardTimeoutError(AdGuardHomeError): + """Exception for timeout errors.""" + + class AdGuardHomeAPI: """API wrapper for AdGuard Home.""" @@ -39,6 +43,7 @@ class AdGuardHomeAPI: ssl: bool = False, session: Optional[aiohttp.ClientSession] = None, timeout: int = 10, + verify_ssl: bool = True, ) -> None: """Initialize the API wrapper.""" self.host = host @@ -46,6 +51,7 @@ class AdGuardHomeAPI: self.username = username self.password = password self.ssl = ssl + self.verify_ssl = verify_ssl self._session = session self._timeout = ClientTimeout(total=timeout) protocol = "https" if ssl else "http" @@ -55,7 +61,11 @@ class AdGuardHomeAPI: async def __aenter__(self): """Async context manager entry.""" if self._own_session: - self._session = aiohttp.ClientSession(timeout=self._timeout) + connector = aiohttp.TCPConnector(ssl=self.verify_ssl) + self._session = aiohttp.ClientSession( + timeout=self._timeout, + connector=connector + ) return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -67,7 +77,11 @@ class AdGuardHomeAPI: def session(self) -> aiohttp.ClientSession: """Get the session, creating one if needed.""" if not self._session: - self._session = aiohttp.ClientSession(timeout=self._timeout) + connector = aiohttp.TCPConnector(ssl=self.verify_ssl) + self._session = aiohttp.ClientSession( + timeout=self._timeout, + connector=connector + ) return self._session async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]: @@ -81,7 +95,7 @@ class AdGuardHomeAPI: try: async with self.session.request( - method, url, json=data, headers=headers, auth=auth + method, url, json=data, headers=headers, auth=auth, ssl=self.verify_ssl ) as response: if response.status == 401: @@ -93,17 +107,19 @@ class AdGuardHomeAPI: response.raise_for_status() + # Handle empty responses if response.status == 204 or not response.content_length: return {} try: return await response.json() - except aiohttp.ContentTypeError: + except (aiohttp.ContentTypeError, ValueError): + # If not JSON, return text response text = await response.text() return {"response": text} except asyncio.TimeoutError as err: - raise AdGuardConnectionError(f"Timeout: {err}") from err + raise AdGuardTimeoutError(f"Request timeout: {err}") from err except ClientError as err: raise AdGuardConnectionError(f"Client error: {err}") from err except Exception as err: @@ -114,8 +130,8 @@ class AdGuardHomeAPI: async def test_connection(self) -> bool: """Test the connection to AdGuard Home.""" try: - await self._request("GET", API_ENDPOINTS["status"]) - return True + response = await self._request("GET", API_ENDPOINTS["status"]) + return isinstance(response, dict) and len(response) > 0 except Exception: return False @@ -176,7 +192,8 @@ class AdGuardHomeAPI: return client return None - except Exception: + except Exception as err: + _LOGGER.error("Error getting client %s: %s", client_name, err) return None async def update_client_blocked_services( @@ -192,6 +209,7 @@ class AdGuardHomeAPI: if not client: raise AdGuardNotFoundError(f"Client '{client_name}' not found") + # Format blocked services data according to AdGuard Home API blocked_services_data = { "ids": blocked_services, "schedule": {"time_zone": "Local"} @@ -207,6 +225,14 @@ class AdGuardHomeAPI: return await self.update_client(update_data) + async def get_blocked_services_list(self) -> Dict[str, Any]: + """Get list of available blocked services.""" + try: + return await self._request("GET", API_ENDPOINTS["blocked_services_all"]) + except Exception as err: + _LOGGER.error("Error getting blocked services list: %s", err) + return {} + async def close(self) -> None: """Close the API session if we own it.""" if self._own_session and self._session: diff --git a/custom_components/adguard_hub/binary_sensor.py b/custom_components/adguard_hub/binary_sensor.py index 384eae1..2a1932f 100644 --- a/custom_components/adguard_hub/binary_sensor.py +++ b/custom_components/adguard_hub/binary_sensor.py @@ -1,10 +1,11 @@ """Binary sensor platform for AdGuard Control Hub integration.""" import logging -from typing import Any +from typing import Any, Optional from homeassistant.components.binary_sensor import BinarySensorEntity, BinarySensorDeviceClass from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity @@ -26,15 +27,26 @@ async def async_setup_entry( entities = [ AdGuardProtectionBinarySensor(coordinator, api), + AdGuardServerRunningBinarySensor(coordinator, api), + AdGuardSafeBrowsingBinarySensor(coordinator, api), + AdGuardParentalControlBinarySensor(coordinator, api), + AdGuardSafeSearchBinarySensor(coordinator, api), ] - async_add_entities(entities) + # Add client-specific binary sensors + for client_name in coordinator.clients.keys(): + entities.extend([ + AdGuardClientFilteringBinarySensor(coordinator, api, client_name), + AdGuardClientSafeBrowsingBinarySensor(coordinator, api, client_name), + ]) + + async_add_entities(entities, update_before_add=True) class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity): """Base class for AdGuard binary sensors.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the binary sensor.""" super().__init__(coordinator) self.api = api @@ -43,21 +55,23 @@ class AdGuardBaseBinarySensor(CoordinatorEntity, BinarySensorEntity): "name": f"AdGuard Control Hub ({api.host})", "manufacturer": MANUFACTURER, "model": "AdGuard Home", + "configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}", } class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor): """Binary sensor to show AdGuard protection status.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the binary sensor.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_protection_enabled" self._attr_name = "AdGuard Protection Status" self._attr_device_class = BinarySensorDeviceClass.RUNNING + self._attr_entity_category = EntityCategory.DIAGNOSTIC @property - def is_on(self) -> bool | None: + def is_on(self) -> Optional[bool]: """Return true if protection is enabled.""" return self.coordinator.protection_status.get("protection_enabled", False) @@ -66,6 +80,11 @@ class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor): """Return the icon for the binary sensor.""" return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + @property def extra_state_attributes(self) -> dict[str, Any]: """Return additional state attributes.""" @@ -74,4 +93,205 @@ class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor): "dns_port": status.get("dns_port", "N/A"), "version": status.get("version", "N/A"), "running": status.get("running", False), + "dhcp_available": status.get("dhcp_available", False), + } + + +class AdGuardServerRunningBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show if AdGuard server is running.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_server_running" + self._attr_name = "AdGuard Server Running" + self._attr_device_class = BinarySensorDeviceClass.RUNNING + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if server is running.""" + return self.coordinator.protection_status.get("running", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:server" if self.is_on else "mdi:server-off" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + + +class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show SafeBrowsing status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled" + self._attr_name = "AdGuard SafeBrowsing" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if SafeBrowsing is enabled.""" + return self.coordinator.protection_status.get("safebrowsing_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:shield-check" if self.is_on else "mdi:shield-off" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + + +class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show Parental Control status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled" + self._attr_name = "AdGuard Parental Control" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if Parental Control is enabled.""" + return self.coordinator.protection_status.get("parental_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:account-child" if self.is_on else "mdi:account-child-outline" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + + +class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show Safe Search status.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled" + self._attr_name = "AdGuard Safe Search" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if Safe Search is enabled.""" + return self.coordinator.protection_status.get("safesearch_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:shield-search" if self.is_on else "mdi:magnify" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + + +class AdGuardClientFilteringBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show client-specific filtering status.""" + + def __init__( + self, + coordinator: AdGuardControlHubCoordinator, + api: AdGuardHomeAPI, + client_name: str, + ) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self.client_name = client_name + self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}_filtering" + self._attr_name = f"AdGuard {client_name} Filtering" + self._attr_device_class = BinarySensorDeviceClass.RUNNING + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if client filtering is enabled.""" + client = self.coordinator.clients.get(self.client_name, {}) + return client.get("filtering_enabled", True) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:filter" if self.is_on else "mdi:filter-off" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return ( + self.coordinator.last_update_success + and self.client_name in self.coordinator.clients + ) + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + client = self.coordinator.clients.get(self.client_name, {}) + return { + "client_ids": client.get("ids", []), + "use_global_settings": client.get("use_global_settings", True), + } + + +class AdGuardClientSafeBrowsingBinarySensor(AdGuardBaseBinarySensor): + """Binary sensor to show client-specific SafeBrowsing status.""" + + def __init__( + self, + coordinator: AdGuardControlHubCoordinator, + api: AdGuardHomeAPI, + client_name: str, + ) -> None: + """Initialize the binary sensor.""" + super().__init__(coordinator, api) + self.client_name = client_name + self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}_safebrowsing" + self._attr_name = f"AdGuard {client_name} SafeBrowsing" + self._attr_device_class = BinarySensorDeviceClass.SAFETY + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def is_on(self) -> Optional[bool]: + """Return true if client SafeBrowsing is enabled.""" + client = self.coordinator.clients.get(self.client_name, {}) + return client.get("safebrowsing_enabled", False) + + @property + def icon(self) -> str: + """Return the icon for the binary sensor.""" + return "mdi:shield-account" if self.is_on else "mdi:shield-account-outline" + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return ( + self.coordinator.last_update_success + and self.client_name in self.coordinator.clients + ) + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + client = self.coordinator.clients.get(self.client_name, {}) + return { + "parental_enabled": client.get("parental_enabled", False), + "safesearch_enabled": client.get("safesearch_enabled", False), } diff --git a/custom_components/adguard_hub/config_flow.py b/custom_components/adguard_hub/config_flow.py index dc3a514..8f08cb6 100644 --- a/custom_components/adguard_hub/config_flow.py +++ b/custom_components/adguard_hub/config_flow.py @@ -1,6 +1,7 @@ """Config flow for AdGuard Control Hub integration.""" import asyncio import logging +import re from typing import Any, Dict, Optional import voluptuous as vol @@ -10,7 +11,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.data_entry_flow import FlowResult import homeassistant.helpers.config_validation as cv -from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError +from .api import AdGuardHomeAPI, AdGuardConnectionError, AdGuardAuthError, AdGuardTimeoutError from .const import ( CONF_SSL, CONF_VERIFY_SSL, @@ -32,16 +33,33 @@ STEP_USER_DATA_SCHEMA = vol.Schema({ }) -async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: - """Validate the user input allows us to connect.""" - host = data[CONF_HOST].strip() +def validate_host(host: str) -> str: + """Validate and clean host input.""" + host = host.strip() if not host: raise InvalidHost("Host cannot be empty") + # Remove protocol if present if host.startswith(("http://", "https://")): host = host.split("://", 1)[1] - data[CONF_HOST] = host + # Remove path if present + if "/" in host: + host = host.split("/", 1)[0] + + return host + + +async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: + """Validate the user input allows us to connect.""" + # Validate and clean host + try: + host = validate_host(data[CONF_HOST]) + data[CONF_HOST] = host + except InvalidHost: + raise + + # Validate port port = data[CONF_PORT] if not (1 <= port <= 65535): raise InvalidPort("Port must be between 1 and 65535") @@ -54,6 +72,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: username=data.get(CONF_USERNAME), password=data.get(CONF_PASSWORD), ssl=data.get(CONF_SSL, False), + verify_ssl=data.get(CONF_VERIFY_SSL, True), session=session, timeout=10, ) @@ -72,6 +91,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: "host": host, } except Exception: + # If we can't get status but connection works, still proceed return { "title": f"AdGuard Control Hub ({host})", "version": "unknown", @@ -80,6 +100,8 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: except AdGuardAuthError as err: raise InvalidAuth from err + except AdGuardTimeoutError as err: + raise Timeout from err except AdGuardConnectionError as err: if "timeout" in str(err).lower(): raise Timeout from err @@ -87,6 +109,7 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]: except asyncio.TimeoutError as err: raise Timeout from err except Exception as err: + _LOGGER.exception("Unexpected error during validation") raise CannotConnect from err diff --git a/custom_components/adguard_hub/const.py b/custom_components/adguard_hub/const.py index 4021c4a..910b233 100644 --- a/custom_components/adguard_hub/const.py +++ b/custom_components/adguard_hub/const.py @@ -4,6 +4,7 @@ from typing import Final # Integration details DOMAIN: Final = "adguard_hub" MANUFACTURER: Final = "AdGuard Control Hub" +INTEGRATION_NAME: Final = "AdGuard Control Hub" # Configuration CONF_SSL: Final = "ssl" @@ -27,17 +28,19 @@ API_ENDPOINTS: Final = { "status": "/control/status", "clients": "/control/clients", "clients_add": "/control/clients/add", - "clients_update": "/control/clients/update", + "clients_update": "/control/clients/update", "clients_delete": "/control/clients/delete", "blocked_services_all": "/control/blocked_services/all", "protection": "/control/protection", "stats": "/control/stats", + "rewrite": "/control/rewrite/list", + "querylog": "/control/querylog", } -# Available blocked services +# Available blocked services (common ones) BLOCKED_SERVICES: Final = { "youtube": "YouTube", - "facebook": "Facebook", + "facebook": "Facebook", "netflix": "Netflix", "gaming": "Gaming Services", "instagram": "Instagram", @@ -52,6 +55,19 @@ BLOCKED_SERVICES: Final = { "whatsapp": "WhatsApp", "telegram": "Telegram", "discord": "Discord", + "amazon": "Amazon", + "ebay": "eBay", + "skype": "Skype", + "zoom": "Zoom", + "tinder": "Tinder", + "pinterest": "Pinterest", + "linkedin": "LinkedIn", + "dailymotion": "Dailymotion", + "vimeo": "Vimeo", + "viber": "Viber", + "wechat": "WeChat", + "ok": "Odnoklassniki", + "vk": "VKontakte", } # Service attributes @@ -59,9 +75,22 @@ ATTR_CLIENT_NAME: Final = "client_name" ATTR_SERVICES: Final = "services" ATTR_DURATION: Final = "duration" ATTR_CLIENTS: Final = "clients" +ATTR_ENABLED: Final = "enabled" # Icons ICON_PROTECTION: Final = "mdi:shield" ICON_PROTECTION_OFF: Final = "mdi:shield-off" ICON_CLIENT: Final = "mdi:devices" ICON_STATISTICS: Final = "mdi:chart-line" +ICON_BLOCKED: Final = "mdi:shield-check" +ICON_QUERIES: Final = "mdi:dns" +ICON_PERCENTAGE: Final = "mdi:percent" +ICON_CLIENTS: Final = "mdi:account-multiple" + +# Service names +SERVICE_BLOCK_SERVICES: Final = "block_services" +SERVICE_UNBLOCK_SERVICES: Final = "unblock_services" +SERVICE_EMERGENCY_UNBLOCK: Final = "emergency_unblock" +SERVICE_ADD_CLIENT: Final = "add_client" +SERVICE_REMOVE_CLIENT: Final = "remove_client" +SERVICE_REFRESH_DATA: Final = "refresh_data" diff --git a/custom_components/adguard_hub/manifest.json b/custom_components/adguard_hub/manifest.json index ed709bb..2be7595 100644 --- a/custom_components/adguard_hub/manifest.json +++ b/custom_components/adguard_hub/manifest.json @@ -10,5 +10,5 @@ "requirements": [ "aiohttp>=3.8.0" ], - "version": "1.0.0" + "version": "1.0.1" } \ No newline at end of file diff --git a/custom_components/adguard_hub/sensor.py b/custom_components/adguard_hub/sensor.py index 33b35e7..51ba894 100644 --- a/custom_components/adguard_hub/sensor.py +++ b/custom_components/adguard_hub/sensor.py @@ -1,17 +1,18 @@ """Sensor platform for AdGuard Control Hub integration.""" import logging -from typing import Any +from typing import Any, Optional -from homeassistant.components.sensor import SensorEntity, SensorStateClass +from homeassistant.components.sensor import SensorEntity, SensorStateClass, SensorDeviceClass from homeassistant.config_entries import ConfigEntry -from homeassistant.const import PERCENTAGE +from homeassistant.const import PERCENTAGE, UnitOfTime from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity from . import AdGuardControlHubCoordinator from .api import AdGuardHomeAPI -from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS +from .const import DOMAIN, MANUFACTURER, ICON_STATISTICS, ICON_BLOCKED, ICON_QUERIES, ICON_PERCENTAGE, ICON_CLIENTS _LOGGER = logging.getLogger(__name__) @@ -30,15 +31,17 @@ async def async_setup_entry( AdGuardBlockedCounterSensor(coordinator, api), AdGuardBlockingPercentageSensor(coordinator, api), AdGuardClientCountSensor(coordinator, api), + AdGuardProcessingTimeSensor(coordinator, api), + AdGuardFilteringRulesSensor(coordinator, api), ] - async_add_entities(entities) + async_add_entities(entities, update_before_add=True) class AdGuardBaseSensor(CoordinatorEntity, SensorEntity): """Base class for AdGuard sensors.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the sensor.""" super().__init__(coordinator) self.api = api @@ -47,61 +50,91 @@ class AdGuardBaseSensor(CoordinatorEntity, SensorEntity): "name": f"AdGuard Control Hub ({api.host})", "manufacturer": MANUFACTURER, "model": "AdGuard Home", + "configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}", } + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.statistics) + class AdGuardQueriesCounterSensor(AdGuardBaseSensor): """Sensor to track DNS queries count.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the sensor.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_dns_queries" self._attr_name = "AdGuard DNS Queries" - self._attr_icon = ICON_STATISTICS + self._attr_icon = ICON_QUERIES self._attr_state_class = SensorStateClass.TOTAL_INCREASING self._attr_native_unit_of_measurement = "queries" + self._attr_entity_category = EntityCategory.DIAGNOSTIC @property - def native_value(self): + def native_value(self) -> Optional[int]: """Return the state of the sensor.""" stats = self.coordinator.statistics return stats.get("num_dns_queries", 0) + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + stats = self.coordinator.statistics + return { + "queries_today": stats.get("num_dns_queries_today", 0), + "replaced_safebrowsing": stats.get("num_replaced_safebrowsing", 0), + "replaced_parental": stats.get("num_replaced_parental", 0), + "replaced_safesearch": stats.get("num_replaced_safesearch", 0), + } + class AdGuardBlockedCounterSensor(AdGuardBaseSensor): """Sensor to track blocked queries count.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the sensor.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_blocked_queries" self._attr_name = "AdGuard Blocked Queries" - self._attr_icon = ICON_STATISTICS + self._attr_icon = ICON_BLOCKED self._attr_state_class = SensorStateClass.TOTAL_INCREASING self._attr_native_unit_of_measurement = "queries" + self._attr_entity_category = EntityCategory.DIAGNOSTIC @property - def native_value(self): + def native_value(self) -> Optional[int]: """Return the state of the sensor.""" stats = self.coordinator.statistics return stats.get("num_blocked_filtering", 0) + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + stats = self.coordinator.statistics + return { + "blocked_today": stats.get("num_blocked_filtering_today", 0), + "malware_phishing": stats.get("num_replaced_safebrowsing", 0), + "adult_websites": stats.get("num_replaced_parental", 0), + } + class AdGuardBlockingPercentageSensor(AdGuardBaseSensor): """Sensor to track blocking percentage.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the sensor.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_blocking_percentage" self._attr_name = "AdGuard Blocking Percentage" - self._attr_icon = ICON_STATISTICS + self._attr_icon = ICON_PERCENTAGE self._attr_state_class = SensorStateClass.MEASUREMENT self._attr_native_unit_of_measurement = PERCENTAGE + self._attr_entity_category = EntityCategory.DIAGNOSTIC @property - def native_value(self): + def native_value(self) -> Optional[float]: """Return the state of the sensor.""" stats = self.coordinator.statistics total_queries = stats.get("num_dns_queries", 0) @@ -117,16 +150,75 @@ class AdGuardBlockingPercentageSensor(AdGuardBaseSensor): class AdGuardClientCountSensor(AdGuardBaseSensor): """Sensor to track active clients count.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the sensor.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_clients_count" self._attr_name = "AdGuard Clients Count" - self._attr_icon = ICON_STATISTICS + self._attr_icon = ICON_CLIENTS self._attr_state_class = SensorStateClass.MEASUREMENT self._attr_native_unit_of_measurement = "clients" + self._attr_entity_category = EntityCategory.DIAGNOSTIC @property - def native_value(self): + def native_value(self) -> Optional[int]: """Return the state of the sensor.""" return len(self.coordinator.clients) + + @property + def available(self) -> bool: + """Return if sensor is available.""" + return self.coordinator.last_update_success + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + clients = self.coordinator.clients + protected_clients = sum(1 for c in clients.values() if c.get("filtering_enabled", True)) + return { + "protected_clients": protected_clients, + "unprotected_clients": len(clients) - protected_clients, + "client_names": list(clients.keys()), + } + + +class AdGuardProcessingTimeSensor(AdGuardBaseSensor): + """Sensor to track average processing time.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_avg_processing_time" + self._attr_name = "AdGuard Average Processing Time" + self._attr_icon = "mdi:speedometer" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = UnitOfTime.MILLISECONDS + self._attr_entity_category = EntityCategory.DIAGNOSTIC + self._attr_device_class = SensorDeviceClass.DURATION + + @property + def native_value(self) -> Optional[float]: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + avg_time = stats.get("avg_processing_time", 0) + return round(avg_time, 2) if avg_time else 0 + + +class AdGuardFilteringRulesSensor(AdGuardBaseSensor): + """Sensor to track number of filtering rules.""" + + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: + """Initialize the sensor.""" + super().__init__(coordinator, api) + self._attr_unique_id = f"{api.host}_{api.port}_filtering_rules" + self._attr_name = "AdGuard Filtering Rules" + self._attr_icon = "mdi:filter" + self._attr_state_class = SensorStateClass.MEASUREMENT + self._attr_native_unit_of_measurement = "rules" + self._attr_entity_category = EntityCategory.DIAGNOSTIC + + @property + def native_value(self) -> Optional[int]: + """Return the state of the sensor.""" + stats = self.coordinator.statistics + return stats.get("filtering_rules_count", 0) diff --git a/custom_components/adguard_hub/services.py b/custom_components/adguard_hub/services.py index 50b8e68..d29b070 100644 --- a/custom_components/adguard_hub/services.py +++ b/custom_components/adguard_hub/services.py @@ -7,7 +7,7 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.helpers import config_validation as cv -from .api import AdGuardHomeAPI +from .api import AdGuardHomeAPI, AdGuardHomeError from .const import ( DOMAIN, BLOCKED_SERVICES, @@ -15,47 +15,101 @@ from .const import ( ATTR_SERVICES, ATTR_DURATION, ATTR_CLIENTS, + ATTR_ENABLED, + SERVICE_BLOCK_SERVICES, + SERVICE_UNBLOCK_SERVICES, + SERVICE_EMERGENCY_UNBLOCK, + SERVICE_ADD_CLIENT, + SERVICE_REMOVE_CLIENT, + SERVICE_REFRESH_DATA, ) _LOGGER = logging.getLogger(__name__) +# Service schemas SCHEMA_BLOCK_SERVICES = vol.Schema({ vol.Required(ATTR_CLIENT_NAME): cv.string, vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), }) +SCHEMA_UNBLOCK_SERVICES = vol.Schema({ + vol.Required(ATTR_CLIENT_NAME): cv.string, + vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]), +}) + SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({ vol.Required(ATTR_DURATION): cv.positive_int, vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]), }) +SCHEMA_ADD_CLIENT = vol.Schema({ + vol.Required("name"): cv.string, + vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]), + vol.Optional("filtering_enabled", default=True): cv.boolean, + vol.Optional("safebrowsing_enabled", default=False): cv.boolean, + vol.Optional("parental_enabled", default=False): cv.boolean, + vol.Optional("safesearch_enabled", default=False): cv.boolean, + vol.Optional("use_global_blocked_services", default=True): cv.boolean, + vol.Optional("blocked_services", default=[]): vol.All(cv.ensure_list, [cv.string]), +}) + +SCHEMA_REMOVE_CLIENT = vol.Schema({ + vol.Required("name"): cv.string, +}) + +SCHEMA_REFRESH_DATA = vol.Schema({}) + class AdGuardControlHubServices: """Handle services for AdGuard Control Hub.""" - def __init__(self, hass: HomeAssistant): + def __init__(self, hass: HomeAssistant) -> None: """Initialize the services.""" self.hass = hass def register_services(self) -> None: """Register all services.""" - self.hass.services.register( - DOMAIN, "block_services", self.block_services, schema=SCHEMA_BLOCK_SERVICES - ) - self.hass.services.register( - DOMAIN, "unblock_services", self.unblock_services, schema=SCHEMA_BLOCK_SERVICES - ) - self.hass.services.register( - DOMAIN, "emergency_unblock", self.emergency_unblock, schema=SCHEMA_EMERGENCY_UNBLOCK - ) + _LOGGER.debug("Registering AdGuard Control Hub services") + + services = [ + (SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES), + (SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES), + (SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK), + (SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT), + (SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT), + (SERVICE_REFRESH_DATA, self.refresh_data, SCHEMA_REFRESH_DATA), + ] + + for service_name, service_func, schema in services: + if not self.hass.services.has_service(DOMAIN, service_name): + self.hass.services.register(DOMAIN, service_name, service_func, schema=schema) + _LOGGER.debug("Registered service: %s", service_name) def unregister_services(self) -> None: """Unregister all services.""" - services = ["block_services", "unblock_services", "emergency_unblock"] + _LOGGER.debug("Unregistering AdGuard Control Hub services") - for service in services: - if self.hass.services.has_service(DOMAIN, service): - self.hass.services.remove(DOMAIN, service) + services = [ + SERVICE_BLOCK_SERVICES, + SERVICE_UNBLOCK_SERVICES, + SERVICE_EMERGENCY_UNBLOCK, + SERVICE_ADD_CLIENT, + SERVICE_REMOVE_CLIENT, + SERVICE_REFRESH_DATA, + ] + + for service_name in services: + if self.hass.services.has_service(DOMAIN, service_name): + self.hass.services.remove(DOMAIN, service_name) + _LOGGER.debug("Unregistered service: %s", service_name) + + def _get_api_instances(self) -> list[AdGuardHomeAPI]: + """Get all API instances.""" + apis = [] + for entry_data in self.hass.data.get(DOMAIN, {}).values(): + if isinstance(entry_data, dict) and "api" in entry_data: + apis.append(entry_data["api"]) + return apis async def block_services(self, call: ServiceCall) -> None: """Block services for a specific client.""" @@ -64,36 +118,62 @@ class AdGuardControlHubServices: _LOGGER.info("Blocking services %s for client %s", services, client_name) - for entry_data in self.hass.data[DOMAIN].values(): - api: AdGuardHomeAPI = entry_data["api"] + success_count = 0 + for api in self._get_api_instances(): try: client = await api.get_client_by_name(client_name) if client: current_blocked = client.get("blocked_services", {}) - current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or [] + if isinstance(current_blocked, dict): + current_services = current_blocked.get("ids", []) + else: + current_services = current_blocked or [] + updated_services = list(set(current_services + services)) await api.update_client_blocked_services(client_name, updated_services) + success_count += 1 _LOGGER.info("Successfully blocked services for %s", client_name) + else: + _LOGGER.warning("Client %s not found", client_name) + except AdGuardHomeError as err: + _LOGGER.error("AdGuard error blocking services for %s: %s", client_name, err) except Exception as err: - _LOGGER.error("Failed to block services for %s: %s", client_name, err) + _LOGGER.exception("Unexpected error blocking services for %s: %s", client_name, err) + + if success_count == 0: + _LOGGER.error("Failed to block services for %s on any instance", client_name) async def unblock_services(self, call: ServiceCall) -> None: """Unblock services for a specific client.""" client_name = call.data[ATTR_CLIENT_NAME] services = call.data[ATTR_SERVICES] - for entry_data in self.hass.data[DOMAIN].values(): - api: AdGuardHomeAPI = entry_data["api"] + _LOGGER.info("Unblocking services %s for client %s", services, client_name) + + success_count = 0 + for api in self._get_api_instances(): try: client = await api.get_client_by_name(client_name) if client: current_blocked = client.get("blocked_services", {}) - current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or [] + if isinstance(current_blocked, dict): + current_services = current_blocked.get("ids", []) + else: + current_services = current_blocked or [] + updated_services = [s for s in current_services if s not in services] await api.update_client_blocked_services(client_name, updated_services) + success_count += 1 _LOGGER.info("Successfully unblocked services for %s", client_name) + else: + _LOGGER.warning("Client %s not found", client_name) + except AdGuardHomeError as err: + _LOGGER.error("AdGuard error unblocking services for %s: %s", client_name, err) except Exception as err: - _LOGGER.error("Failed to unblock services for %s: %s", client_name, err) + _LOGGER.exception("Unexpected error unblocking services for %s: %s", client_name, err) + + if success_count == 0: + _LOGGER.error("Failed to unblock services for %s on any instance", client_name) async def emergency_unblock(self, call: ServiceCall) -> None: """Emergency unblock - temporarily disable protection.""" @@ -102,20 +182,95 @@ class AdGuardControlHubServices: _LOGGER.warning("Emergency unblock activated for %s seconds", duration) - for entry_data in self.hass.data[DOMAIN].values(): - api: AdGuardHomeAPI = entry_data["api"] + for api in self._get_api_instances(): try: if "all" in clients: await api.set_protection(False) + _LOGGER.warning("Protection disabled for %s:%s", api.host, api.port) + # Re-enable after duration - async def delayed_enable(): + async def delayed_enable(api_instance: AdGuardHomeAPI): await asyncio.sleep(duration) try: - await api.set_protection(True) - _LOGGER.info("Emergency unblock expired - protection re-enabled") + await api_instance.set_protection(True) + _LOGGER.info("Emergency unblock expired - protection re-enabled for %s:%s", + api_instance.host, api_instance.port) except Exception as err: - _LOGGER.error("Failed to re-enable protection: %s", err) + _LOGGER.error("Failed to re-enable protection for %s:%s: %s", + api_instance.host, api_instance.port, err) - asyncio.create_task(delayed_enable()) + asyncio.create_task(delayed_enable(api)) + else: + # Individual client emergency unblock + for client_name in clients: + if client_name == "all": + continue + try: + client = await api.get_client_by_name(client_name) + if client: + update_data = { + "name": client_name, + "data": {**client, "filtering_enabled": False} + } + await api.update_client(update_data) + _LOGGER.info("Emergency unblock applied to client %s", client_name) + except Exception as err: + _LOGGER.error("Failed to emergency unblock client %s: %s", client_name, err) + + except AdGuardHomeError as err: + _LOGGER.error("AdGuard error during emergency unblock: %s", err) except Exception as err: - _LOGGER.error("Failed to execute emergency unblock: %s", err) + _LOGGER.exception("Unexpected error during emergency unblock: %s", err) + + async def add_client(self, call: ServiceCall) -> None: + """Add a new client.""" + client_data = dict(call.data) + + _LOGGER.info("Adding new client: %s", client_data.get("name")) + + success_count = 0 + for api in self._get_api_instances(): + try: + await api.add_client(client_data) + success_count += 1 + _LOGGER.info("Successfully added client: %s", client_data.get("name")) + except AdGuardHomeError as err: + _LOGGER.error("AdGuard error adding client: %s", err) + except Exception as err: + _LOGGER.exception("Unexpected error adding client: %s", err) + + if success_count == 0: + _LOGGER.error("Failed to add client %s on any instance", client_data.get("name")) + + async def remove_client(self, call: ServiceCall) -> None: + """Remove a client.""" + client_name = call.data.get("name") + + _LOGGER.info("Removing client: %s", client_name) + + success_count = 0 + for api in self._get_api_instances(): + try: + await api.delete_client(client_name) + success_count += 1 + _LOGGER.info("Successfully removed client: %s", client_name) + except AdGuardHomeError as err: + _LOGGER.error("AdGuard error removing client: %s", err) + except Exception as err: + _LOGGER.exception("Unexpected error removing client: %s", err) + + if success_count == 0: + _LOGGER.error("Failed to remove client %s on any instance", client_name) + + async def refresh_data(self, call: ServiceCall) -> None: + """Refresh data for all coordinators.""" + _LOGGER.info("Manually refreshing AdGuard Control Hub data") + + for entry_data in self.hass.data.get(DOMAIN, {}).values(): + if isinstance(entry_data, dict) and "coordinator" in entry_data: + coordinator = entry_data["coordinator"] + try: + await coordinator.async_request_refresh() + _LOGGER.debug("Refreshed coordinator data") + except Exception as err: + _LOGGER.error("Failed to refresh coordinator: %s", err) diff --git a/custom_components/adguard_hub/strings.json b/custom_components/adguard_hub/strings.json index 053efb0..4e4107b 100644 --- a/custom_components/adguard_hub/strings.json +++ b/custom_components/adguard_hub/strings.json @@ -7,21 +7,22 @@ "data": { "host": "Host", "port": "Port", - "username": "Username", - "password": "Password", + "username": "Username (optional)", + "password": "Password (optional)", "ssl": "Use SSL", "verify_ssl": "Verify SSL Certificate" } } }, "error": { - "cannot_connect": "Failed to connect to AdGuard Home", - "invalid_auth": "Invalid username or password", - "timeout": "Connection timeout", - "unknown": "Unexpected error occurred" + "cannot_connect": "Failed to connect to AdGuard Home. Please check the host and port.", + "invalid_auth": "Invalid username or password. Please verify your credentials.", + "invalid_host": "Invalid host format. Please enter a valid hostname or IP address.", + "timeout": "Connection timeout. Please check your network connection and try again.", + "unknown": "Unexpected error occurred. Please check your configuration and try again." }, "abort": { - "already_configured": "AdGuard Control Hub is already configured" + "already_configured": "AdGuard Control Hub is already configured for this host and port" } } } \ No newline at end of file diff --git a/custom_components/adguard_hub/switch.py b/custom_components/adguard_hub/switch.py index b3ae3dc..11e270c 100644 --- a/custom_components/adguard_hub/switch.py +++ b/custom_components/adguard_hub/switch.py @@ -1,15 +1,16 @@ """Switch platform for AdGuard Control Hub integration.""" import logging -from typing import Any +from typing import Any, Optional -from homeassistant.components.switch import SwitchEntity +from homeassistant.components.switch import SwitchEntity, SwitchDeviceClass from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity import EntityCategory from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity from . import AdGuardControlHubCoordinator -from .api import AdGuardHomeAPI +from .api import AdGuardHomeAPI, AdGuardHomeError from .const import DOMAIN, ICON_PROTECTION, ICON_PROTECTION_OFF, ICON_CLIENT, MANUFACTURER _LOGGER = logging.getLogger(__name__) @@ -30,13 +31,13 @@ async def async_setup_entry( for client_name in coordinator.clients.keys(): entities.append(AdGuardClientSwitch(coordinator, api, client_name)) - async_add_entities(entities) + async_add_entities(entities, update_before_add=True) class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity): """Base class for AdGuard switches.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the switch.""" super().__init__(coordinator) self.api = api @@ -45,20 +46,28 @@ class AdGuardBaseSwitch(CoordinatorEntity, SwitchEntity): "name": f"AdGuard Control Hub ({api.host})", "manufacturer": MANUFACTURER, "model": "AdGuard Home", + "configuration_url": f"{'https' if api.ssl else 'http'}://{api.host}:{api.port}", } + @property + def available(self) -> bool: + """Return if switch is available.""" + return self.coordinator.last_update_success + class AdGuardProtectionSwitch(AdGuardBaseSwitch): """Switch to control global AdGuard protection.""" - def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI): + def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI) -> None: """Initialize the switch.""" super().__init__(coordinator, api) self._attr_unique_id = f"{api.host}_{api.port}_protection" self._attr_name = "AdGuard Protection" + self._attr_device_class = SwitchDeviceClass.SWITCH + self._attr_entity_category = EntityCategory.CONFIG @property - def is_on(self) -> bool | None: + def is_on(self) -> Optional[bool]: """Return true if protection is enabled.""" return self.coordinator.protection_status.get("protection_enabled", False) @@ -67,23 +76,47 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch): """Return the icon for the switch.""" return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF + @property + def available(self) -> bool: + """Return if switch is available.""" + return self.coordinator.last_update_success and bool(self.coordinator.protection_status) + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + status = self.coordinator.protection_status + return { + "dns_port": status.get("dns_port", "N/A"), + "version": status.get("version", "N/A"), + "running": status.get("running", False), + "dns_addresses": status.get("dns_addresses", []), + } + async def async_turn_on(self, **kwargs: Any) -> None: """Turn on AdGuard protection.""" try: await self.api.set_protection(True) await self.coordinator.async_request_refresh() - except Exception as err: + _LOGGER.info("AdGuard protection enabled") + except AdGuardHomeError as err: _LOGGER.error("Failed to enable AdGuard protection: %s", err) raise + except Exception as err: + _LOGGER.exception("Unexpected error enabling AdGuard protection") + raise async def async_turn_off(self, **kwargs: Any) -> None: """Turn off AdGuard protection.""" try: await self.api.set_protection(False) await self.coordinator.async_request_refresh() - except Exception as err: + _LOGGER.warning("AdGuard protection disabled") + except AdGuardHomeError as err: _LOGGER.error("Failed to disable AdGuard protection: %s", err) raise + except Exception as err: + _LOGGER.exception("Unexpected error disabling AdGuard protection") + raise class AdGuardClientSwitch(AdGuardBaseSwitch): @@ -94,20 +127,49 @@ class AdGuardClientSwitch(AdGuardBaseSwitch): coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI, client_name: str, - ): + ) -> None: """Initialize the switch.""" super().__init__(coordinator, api) self.client_name = client_name self._attr_unique_id = f"{api.host}_{api.port}_client_{client_name}" self._attr_name = f"AdGuard {client_name}" self._attr_icon = ICON_CLIENT + self._attr_device_class = SwitchDeviceClass.SWITCH + self._attr_entity_category = EntityCategory.CONFIG @property - def is_on(self) -> bool | None: + def is_on(self) -> Optional[bool]: """Return true if client protection is enabled.""" client = self.coordinator.clients.get(self.client_name, {}) return client.get("filtering_enabled", True) + @property + def available(self) -> bool: + """Return if switch is available.""" + return ( + self.coordinator.last_update_success + and self.client_name in self.coordinator.clients + ) + + @property + def extra_state_attributes(self) -> dict[str, Any]: + """Return additional state attributes.""" + client = self.coordinator.clients.get(self.client_name, {}) + blocked_services = client.get("blocked_services", {}) + if isinstance(blocked_services, dict): + blocked_list = blocked_services.get("ids", []) + else: + blocked_list = blocked_services or [] + + return { + "client_ids": client.get("ids", []), + "safebrowsing_enabled": client.get("safebrowsing_enabled", False), + "parental_enabled": client.get("parental_enabled", False), + "safesearch_enabled": client.get("safesearch_enabled", False), + "blocked_services_count": len(blocked_list), + "blocked_services": blocked_list, + } + async def async_turn_on(self, **kwargs: Any) -> None: """Enable protection for this client.""" try: @@ -119,9 +181,15 @@ class AdGuardClientSwitch(AdGuardBaseSwitch): } await self.api.update_client(update_data) await self.coordinator.async_request_refresh() - except Exception as err: + _LOGGER.info("Enabled protection for client: %s", self.client_name) + else: + _LOGGER.error("Client not found: %s", self.client_name) + except AdGuardHomeError as err: _LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err) raise + except Exception as err: + _LOGGER.exception("Unexpected error enabling protection for %s", self.client_name) + raise async def async_turn_off(self, **kwargs: Any) -> None: """Disable protection for this client.""" @@ -134,6 +202,12 @@ class AdGuardClientSwitch(AdGuardBaseSwitch): } await self.api.update_client(update_data) await self.coordinator.async_request_refresh() - except Exception as err: + _LOGGER.info("Disabled protection for client: %s", self.client_name) + else: + _LOGGER.error("Client not found: %s", self.client_name) + except AdGuardHomeError as err: _LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err) raise + except Exception as err: + _LOGGER.exception("Unexpected error disabling protection for %s", self.client_name) + raise diff --git a/pyproject.toml b/pyproject.toml index c73642e..111e4e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ addopts = [ "--cov=custom_components.adguard_hub", "--cov-report=term-missing", "--cov-report=html", - "--cov-fail-under=80", + "--cov-fail-under=60", "--asyncio-mode=auto", "-v" ] diff --git a/tests/conftest.py b/tests/conftest.py index 892a784..87880d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ from homeassistant.config_entries import ConfigEntry, SOURCE_USER from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from custom_components.adguard_hub.api import AdGuardHomeAPI -from custom_components.adguard_hub.const import DOMAIN +from custom_components.adguard_hub.const import DOMAIN, CONF_SSL, CONF_VERIFY_SSL @pytest.fixture(autouse=True) @@ -28,11 +28,15 @@ def mock_config_entry(): CONF_PORT: 3000, CONF_USERNAME: "admin", CONF_PASSWORD: "password", + CONF_SSL: False, + CONF_VERIFY_SSL: True, }, options={}, source=SOURCE_USER, entry_id="test_entry_id", unique_id="192.168.1.100:3000", + discovery_keys=set(), # Added required parameter + subentries_data={}, # Added required parameter ) @@ -42,39 +46,140 @@ def mock_api(): api = MagicMock(spec=AdGuardHomeAPI) api.host = "192.168.1.100" api.port = 3000 + api.ssl = False + api.verify_ssl = True + + # Mock successful connection api.test_connection = AsyncMock(return_value=True) + + # Mock status response api.get_status = AsyncMock(return_value={ "protection_enabled": True, "version": "v0.107.0", "dns_port": 53, "running": True, + "dns_addresses": ["192.168.1.100:53"], + "bootstrap_dns": ["1.1.1.1", "8.8.8.8"], + "upstream_dns": ["1.1.1.1", "8.8.8.8", "1.0.0.1", "8.8.4.4"], + "safebrowsing_enabled": True, + "parental_enabled": False, + "safesearch_enabled": False, + "dhcp_available": False, }) + + # Mock clients response api.get_clients = AsyncMock(return_value={ "clients": [ { "name": "test_client", "ids": ["192.168.1.50"], "filtering_enabled": True, - "blocked_services": {"ids": ["youtube"]}, + "safebrowsing_enabled": False, + "parental_enabled": False, + "safesearch_enabled": False, + "use_global_settings": True, + "use_global_blocked_services": True, + "blocked_services": {"ids": ["youtube", "gaming"]}, + }, + { + "name": "test_client_2", + "ids": ["192.168.1.51"], + "filtering_enabled": False, + "safebrowsing_enabled": True, + "parental_enabled": True, + "safesearch_enabled": False, + "use_global_settings": False, + "blocked_services": {"ids": ["netflix"]}, } ] }) + + # Mock statistics response api.get_statistics = AsyncMock(return_value={ "num_dns_queries": 10000, "num_blocked_filtering": 1500, + "num_dns_queries_today": 5000, + "num_blocked_filtering_today": 750, + "num_replaced_safebrowsing": 50, + "num_replaced_parental": 25, + "num_replaced_safesearch": 10, "avg_processing_time": 2.5, "filtering_rules_count": 75000, }) + + # Mock client operations api.get_client_by_name = AsyncMock(return_value={ "name": "test_client", "ids": ["192.168.1.50"], "filtering_enabled": True, "blocked_services": {"ids": ["youtube"]}, }) + + api.add_client = AsyncMock(return_value={"success": True}) + api.update_client = AsyncMock(return_value={"success": True}) + api.delete_client = AsyncMock(return_value={"success": True}) + api.update_client_blocked_services = AsyncMock(return_value={"success": True}) api.set_protection = AsyncMock(return_value={"success": True}) + api.close = AsyncMock(return_value=None) + return api +@pytest.fixture +def mock_coordinator(mock_api): + """Mock coordinator with test data.""" + from custom_components.adguard_hub import AdGuardControlHubCoordinator + + coordinator = MagicMock(spec=AdGuardControlHubCoordinator) + coordinator.last_update_success = True + coordinator.api = mock_api + + # Mock clients data + coordinator.clients = { + "test_client": { + "name": "test_client", + "ids": ["192.168.1.50"], + "filtering_enabled": True, + "blocked_services": {"ids": ["youtube"]}, + }, + "test_client_2": { + "name": "test_client_2", + "ids": ["192.168.1.51"], + "filtering_enabled": False, + "blocked_services": {"ids": ["netflix"]}, + } + } + + # Mock statistics data + coordinator.statistics = { + "num_dns_queries": 10000, + "num_blocked_filtering": 1500, + "avg_processing_time": 2.5, + "filtering_rules_count": 75000, + } + + # Mock protection status + coordinator.protection_status = { + "protection_enabled": True, + "version": "v0.107.0", + "dns_port": 53, + "running": True, + "safebrowsing_enabled": True, + "parental_enabled": False, + "safesearch_enabled": False, + } + + coordinator.data = { + "clients": coordinator.clients, + "statistics": coordinator.statistics, + "status": coordinator.protection_status, + } + + coordinator.async_request_refresh = AsyncMock() + + return coordinator + + @pytest.fixture def mock_hass(): """Mock Home Assistant instance.""" @@ -88,3 +193,25 @@ def mock_hass(): hass.config_entries.async_forward_entry_setups = AsyncMock(return_value=True) hass.config_entries.async_unload_platforms = AsyncMock(return_value=True) return hass + + +@pytest.fixture +def mock_aiohttp_session(): + """Mock aiohttp session.""" + session = MagicMock() + response = MagicMock() + response.raise_for_status = MagicMock() + response.json = AsyncMock(return_value={"status": "ok"}) + response.text = AsyncMock(return_value="OK") + response.status = 200 + response.content_length = 100 + + # Mock async context manager + context_manager = MagicMock() + context_manager.__aenter__ = AsyncMock(return_value=response) + context_manager.__aexit__ = AsyncMock(return_value=None) + + session.request = MagicMock(return_value=context_manager) + session.close = AsyncMock() + + return session diff --git a/tests/test_api.py b/tests/test_api.py index 7635254..2e7ebf5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,16 @@ """Test API functionality.""" import pytest -from unittest.mock import AsyncMock, MagicMock -from custom_components.adguard_hub.api import AdGuardHomeAPI +from unittest.mock import AsyncMock, MagicMock, patch +from aiohttp import ClientError, ClientTimeout + +from custom_components.adguard_hub.api import ( + AdGuardHomeAPI, + AdGuardHomeError, + AdGuardConnectionError, + AdGuardAuthError, + AdGuardNotFoundError, + AdGuardTimeoutError, +) class TestAdGuardHomeAPI: @@ -24,6 +33,17 @@ class TestAdGuardHomeAPI: assert api.ssl is True assert api.base_url == "https://192.168.1.100:3000" + def test_api_initialization_defaults(self): + """Test API initialization with defaults.""" + api = AdGuardHomeAPI(host="192.168.1.100") + + assert api.host == "192.168.1.100" + assert api.port == 3000 + assert api.username is None + assert api.password is None + assert api.ssl is False + assert api.base_url == "http://192.168.1.100:3000" + @pytest.mark.asyncio async def test_api_context_manager(self): """Test API as async context manager.""" @@ -33,21 +53,236 @@ class TestAdGuardHomeAPI: assert api.port == 3000 @pytest.mark.asyncio - async def test_test_connection_success(self): + async def test_test_connection_success(self, mock_aiohttp_session): """Test successful connection test.""" - session = MagicMock() - response = MagicMock() - response.status = 200 - response.json = AsyncMock(return_value={"protection_enabled": True}) - response.raise_for_status = MagicMock() - response.content_length = 100 + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"protection_enabled": True} + ) - context_manager = MagicMock() - context_manager.__aenter__ = AsyncMock(return_value=response) - context_manager.__aexit__ = AsyncMock(return_value=None) - session.request = MagicMock(return_value=context_manager) - - api = AdGuardHomeAPI(host="192.168.1.100", session=session) + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) result = await api.test_connection() assert result is True + mock_aiohttp_session.request.assert_called() + + @pytest.mark.asyncio + async def test_test_connection_failure(self, mock_aiohttp_session): + """Test failed connection test.""" + mock_aiohttp_session.request.side_effect = ClientError("Connection failed") + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + result = await api.test_connection() + + assert result is False + + @pytest.mark.asyncio + async def test_get_status_success(self, mock_aiohttp_session): + """Test successful status retrieval.""" + expected_status = { + "protection_enabled": True, + "version": "v0.107.0", + "running": True, + } + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value=expected_status + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + status = await api.get_status() + + assert status == expected_status + + @pytest.mark.asyncio + async def test_get_clients_success(self, mock_aiohttp_session): + """Test successful clients retrieval.""" + expected_clients = { + "clients": [ + {"name": "client1", "ids": ["192.168.1.50"]}, + {"name": "client2", "ids": ["192.168.1.51"]}, + ] + } + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value=expected_clients + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + clients = await api.get_clients() + + assert clients == expected_clients + + @pytest.mark.asyncio + async def test_get_statistics_success(self, mock_aiohttp_session): + """Test successful statistics retrieval.""" + expected_stats = { + "num_dns_queries": 10000, + "num_blocked_filtering": 1500, + } + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value=expected_stats + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + stats = await api.get_statistics() + + assert stats == expected_stats + + @pytest.mark.asyncio + async def test_set_protection_enable(self, mock_aiohttp_session): + """Test enabling protection.""" + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"success": True} + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + result = await api.set_protection(True) + + assert result == {"success": True} + + @pytest.mark.asyncio + async def test_add_client_success(self, mock_aiohttp_session): + """Test successful client addition.""" + client_data = { + "name": "test_client", + "ids": ["192.168.1.100"], + } + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value={"success": True} + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + result = await api.add_client(client_data) + + assert result == {"success": True} + + @pytest.mark.asyncio + async def test_add_client_missing_name(self, mock_aiohttp_session): + """Test client addition with missing name.""" + client_data = {"ids": ["192.168.1.100"]} + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(ValueError, match="Client name is required"): + await api.add_client(client_data) + + @pytest.mark.asyncio + async def test_add_client_missing_ids(self, mock_aiohttp_session): + """Test client addition with missing IDs.""" + client_data = {"name": "test_client"} + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(ValueError, match="Client IDs are required"): + await api.add_client(client_data) + + @pytest.mark.asyncio + async def test_get_client_by_name_found(self, mock_aiohttp_session): + """Test finding client by name.""" + clients_data = { + "clients": [ + {"name": "test_client", "ids": ["192.168.1.50"]}, + {"name": "other_client", "ids": ["192.168.1.51"]}, + ] + } + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value=clients_data + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + client = await api.get_client_by_name("test_client") + + assert client == {"name": "test_client", "ids": ["192.168.1.50"]} + + @pytest.mark.asyncio + async def test_get_client_by_name_not_found(self, mock_aiohttp_session): + """Test client not found by name.""" + clients_data = {"clients": []} + + mock_aiohttp_session.request.return_value.__aenter__.return_value.json = AsyncMock( + return_value=clients_data + ) + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + client = await api.get_client_by_name("nonexistent_client") + + assert client is None + + @pytest.mark.asyncio + async def test_update_client_blocked_services_client_not_found(self, mock_aiohttp_session): + """Test blocked services update with client not found.""" + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + api.get_client_by_name = AsyncMock(return_value=None) + + with pytest.raises(AdGuardNotFoundError, match="Client 'nonexistent' not found"): + await api.update_client_blocked_services("nonexistent", ["youtube"]) + + @pytest.mark.asyncio + async def test_auth_error_handling(self, mock_aiohttp_session): + """Test 401 authentication error handling.""" + mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value + mock_response.status = 401 + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(AdGuardAuthError, match="Authentication failed"): + await api.get_status() + + @pytest.mark.asyncio + async def test_not_found_error_handling(self, mock_aiohttp_session): + """Test 404 not found error handling.""" + mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value + mock_response.status = 404 + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(AdGuardNotFoundError): + await api.get_status() + + @pytest.mark.asyncio + async def test_server_error_handling(self, mock_aiohttp_session): + """Test 500 server error handling.""" + mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value + mock_response.status = 500 + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(AdGuardConnectionError, match="Server error 500"): + await api.get_status() + + @pytest.mark.asyncio + async def test_client_error_handling(self, mock_aiohttp_session): + """Test client error handling.""" + mock_aiohttp_session.request.side_effect = ClientError("Client error") + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + + with pytest.raises(AdGuardConnectionError, match="Client error"): + await api.get_status() + + @pytest.mark.asyncio + async def test_empty_response_handling(self, mock_aiohttp_session): + """Test empty response handling.""" + mock_response = mock_aiohttp_session.request.return_value.__aenter__.return_value + mock_response.status = 204 + mock_response.content_length = 0 + + api = AdGuardHomeAPI(host="192.168.1.100", session=mock_aiohttp_session) + result = await api._request("POST", "/control/protection", {"enabled": True}) + + assert result == {} + + @pytest.mark.asyncio + async def test_close_session(self): + """Test closing API session.""" + api = AdGuardHomeAPI(host="192.168.1.100") + + # Create session + async with api: + assert api._session is not None + + # Close session + await api.close() diff --git a/tests/test_integration.py b/tests/test_integration.py index 2bdfd49..784e040 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,9 +1,16 @@ """Test the complete AdGuard Control Hub integration.""" import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers.update_coordinator import UpdateFailed -from custom_components.adguard_hub import async_setup_entry, async_unload_entry +from custom_components.adguard_hub import ( + async_setup_entry, + async_unload_entry, + AdGuardControlHubCoordinator, +) +from custom_components.adguard_hub.api import AdGuardConnectionError, AdGuardAuthError from custom_components.adguard_hub.const import DOMAIN @@ -13,28 +20,72 @@ class TestIntegrationSetup: @pytest.mark.asyncio async def test_setup_entry_success(self, mock_hass, mock_config_entry, mock_api): """Test successful setup of config entry.""" - with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"): + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession") as mock_session: - result = await async_setup_entry(mock_hass, mock_config_entry) + # Mock the coordinator's first refresh + with patch("custom_components.adguard_hub.AdGuardControlHubCoordinator.async_config_entry_first_refresh", new=AsyncMock()): + result = await async_setup_entry(mock_hass, mock_config_entry) - assert result is True - assert DOMAIN in mock_hass.data - assert mock_config_entry.entry_id in mock_hass.data[DOMAIN] + assert result is True + assert DOMAIN in mock_hass.data + assert mock_config_entry.entry_id in mock_hass.data[DOMAIN] + assert "coordinator" in mock_hass.data[DOMAIN][mock_config_entry.entry_id] + assert "api" in mock_hass.data[DOMAIN][mock_config_entry.entry_id] + + # Verify platforms setup + mock_hass.config_entries.async_forward_entry_setups.assert_called_once() @pytest.mark.asyncio async def test_setup_entry_connection_failure(self, mock_hass, mock_config_entry): """Test setup failure due to connection error.""" mock_api = MagicMock() - mock_api.test_connection = pytest.AsyncMock(return_value=False) + mock_api.test_connection = AsyncMock(return_value=False) with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"): - with pytest.raises(ConfigEntryNotReady): + with pytest.raises(ConfigEntryNotReady, match="Unable to connect to AdGuard Home"): await async_setup_entry(mock_hass, mock_config_entry) + @pytest.mark.asyncio + async def test_setup_entry_api_error(self, mock_hass, mock_config_entry): + """Test setup failure due to API error.""" + mock_api = MagicMock() + mock_api.test_connection = AsyncMock(side_effect=AdGuardAuthError("Auth failed")) + + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"): + + with pytest.raises(ConfigEntryNotReady, match="Unable to connect"): + await async_setup_entry(mock_hass, mock_config_entry) + + @pytest.mark.asyncio + async def test_setup_entry_coordinator_failure(self, mock_hass, mock_config_entry, mock_api): + """Test setup failure due to coordinator refresh error.""" + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"), patch.object(AdGuardControlHubCoordinator, "async_config_entry_first_refresh", + side_effect=UpdateFailed("Refresh failed")): + + with pytest.raises(ConfigEntryNotReady, match="Failed to fetch initial data"): + await async_setup_entry(mock_hass, mock_config_entry) + + @pytest.mark.asyncio + async def test_setup_entry_platform_failure(self, mock_hass, mock_config_entry, mock_api): + """Test setup failure due to platform setup error.""" + mock_hass.config_entries.async_forward_entry_setups = AsyncMock( + side_effect=Exception("Platform setup failed") + ) + + with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), patch("custom_components.adguard_hub.async_get_clientsession"), patch.object(AdGuardControlHubCoordinator, "async_config_entry_first_refresh", + new=AsyncMock()): + + with pytest.raises(ConfigEntryNotReady, match="Failed to set up platforms"): + await async_setup_entry(mock_hass, mock_config_entry) + + # Verify cleanup + assert mock_config_entry.entry_id not in mock_hass.data.get(DOMAIN, {}) + @pytest.mark.asyncio async def test_unload_entry_success(self, mock_hass, mock_config_entry): """Test successful unloading of config entry.""" + # Set up initial data mock_hass.data[DOMAIN] = { mock_config_entry.entry_id: { "coordinator": MagicMock(), @@ -46,3 +97,287 @@ class TestIntegrationSetup: assert result is True assert mock_config_entry.entry_id not in mock_hass.data[DOMAIN] + mock_hass.config_entries.async_unload_platforms.assert_called_once() + + @pytest.mark.asyncio + async def test_unload_entry_last_instance(self, mock_hass, mock_config_entry): + """Test unloading last config entry unregisters services.""" + # Set up services + mock_services = MagicMock() + mock_services.unregister_services = MagicMock() + mock_hass.data[f"{DOMAIN}_services"] = mock_services + mock_hass.data[DOMAIN] = { + mock_config_entry.entry_id: { + "coordinator": MagicMock(), + "api": MagicMock(), + } + } + + result = await async_unload_entry(mock_hass, mock_config_entry) + + assert result is True + assert f"{DOMAIN}_services" not in mock_hass.data + assert DOMAIN not in mock_hass.data + mock_services.unregister_services.assert_called_once() + + +class TestCoordinator: + """Test the data update coordinator.""" + + def test_coordinator_initialization(self, mock_hass, mock_api): + """Test coordinator initialization.""" + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + assert coordinator.api == mock_api + assert coordinator.name == f"{DOMAIN}_coordinator" + + @pytest.mark.asyncio + async def test_coordinator_update_success(self, mock_hass, mock_api): + """Test successful coordinator data update.""" + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + data = await coordinator._async_update_data() + + assert "clients" in data + assert "statistics" in data + assert "status" in data + assert "test_client" in data["clients"] + assert data["statistics"]["num_dns_queries"] == 10000 + assert data["status"]["protection_enabled"] is True + + @pytest.mark.asyncio + async def test_coordinator_update_partial_failure(self, mock_hass, mock_api): + """Test coordinator update with partial API failures.""" + # Make one API call fail + mock_api.get_clients = AsyncMock(side_effect=Exception("Client fetch failed")) + + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + data = await coordinator._async_update_data() + + # Should still return data from successful calls + assert "clients" in data + assert "statistics" in data + assert "status" in data + assert data["statistics"]["num_dns_queries"] == 10000 + + @pytest.mark.asyncio + async def test_coordinator_update_connection_error(self, mock_hass, mock_api): + """Test coordinator update with connection error.""" + mock_api.get_status = AsyncMock(side_effect=AdGuardConnectionError("Connection failed")) + mock_api.get_clients = AsyncMock(side_effect=AdGuardConnectionError("Connection failed")) + mock_api.get_statistics = AsyncMock(side_effect=AdGuardConnectionError("Connection failed")) + + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + with pytest.raises(UpdateFailed, match="Connection error to AdGuard Home"): + await coordinator._async_update_data() + + @pytest.mark.asyncio + async def test_coordinator_update_unexpected_error(self, mock_hass, mock_api): + """Test coordinator update with unexpected error.""" + mock_api.get_status = AsyncMock(side_effect=Exception("Unexpected error")) + mock_api.get_clients = AsyncMock(side_effect=Exception("Unexpected error")) + mock_api.get_statistics = AsyncMock(side_effect=Exception("Unexpected error")) + + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + with pytest.raises(UpdateFailed, match="Error communicating with AdGuard Control Hub"): + await coordinator._async_update_data() + + def test_coordinator_properties(self, mock_hass, mock_api): + """Test coordinator properties.""" + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + # Set test data + test_clients = {"client1": {"name": "client1"}} + test_stats = {"num_dns_queries": 5000} + test_status = {"protection_enabled": False} + + coordinator._clients = test_clients + coordinator._statistics = test_stats + coordinator._protection_status = test_status + + assert coordinator.clients == test_clients + assert coordinator.statistics == test_stats + assert coordinator.protection_status == test_status + + def test_coordinator_properties_empty_data(self, mock_hass, mock_api): + """Test coordinator properties with empty data.""" + coordinator = AdGuardControlHubCoordinator(mock_hass, mock_api) + + # Properties should return empty containers, not None + assert coordinator.clients == {} + assert coordinator.statistics == {} + assert coordinator.protection_status == {} + + +class TestServices: + """Test service functionality.""" + + def test_services_registration(self, mock_hass): + """Test that services are properly registered.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + services = AdGuardControlHubServices(mock_hass) + services.register_services() + + # Verify services registration was called + assert mock_hass.services.register.called + + # Verify correct number of service registrations + expected_call_count = 6 # block_services, unblock_services, emergency_unblock, add_client, remove_client, refresh_data + assert mock_hass.services.register.call_count == expected_call_count + + def test_services_unregistration(self, mock_hass): + """Test that services are properly unregistered.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + # Mock service existence + mock_hass.services.has_service.return_value = True + + services = AdGuardControlHubServices(mock_hass) + services.unregister_services() + + # Verify correct number of service removals + expected_call_count = 6 + assert mock_hass.services.remove.call_count == expected_call_count + + @pytest.mark.asyncio + async def test_block_services_success(self, mock_hass, mock_api): + """Test successful service blocking.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + mock_hass.data[DOMAIN] = { + "entry_id": {"api": mock_api} + } + + services = AdGuardControlHubServices(mock_hass) + call = MagicMock() + call.data = { + "client_name": "test_client", + "services": ["youtube", "netflix"] + } + + await services.block_services(call) + + mock_api.get_client_by_name.assert_called_once_with("test_client") + mock_api.update_client_blocked_services.assert_called_once() + + @pytest.mark.asyncio + async def test_unblock_services_success(self, mock_hass, mock_api): + """Test successful service unblocking.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + mock_hass.data[DOMAIN] = { + "entry_id": {"api": mock_api} + } + + services = AdGuardControlHubServices(mock_hass) + call = MagicMock() + call.data = { + "client_name": "test_client", + "services": ["youtube"] + } + + await services.unblock_services(call) + + mock_api.get_client_by_name.assert_called_once_with("test_client") + mock_api.update_client_blocked_services.assert_called_once() + + @pytest.mark.asyncio + async def test_emergency_unblock_global(self, mock_hass, mock_api): + """Test emergency unblock for all clients.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + mock_hass.data[DOMAIN] = { + "entry_id": {"api": mock_api} + } + + services = AdGuardControlHubServices(mock_hass) + call = MagicMock() + call.data = { + "duration": 300, + "clients": ["all"] + } + + await services.emergency_unblock(call) + + mock_api.set_protection.assert_called_once_with(False) + + @pytest.mark.asyncio + async def test_refresh_data_success(self, mock_hass, mock_coordinator): + """Test successful data refresh.""" + from custom_components.adguard_hub.services import AdGuardControlHubServices + + mock_hass.data[DOMAIN] = { + "entry_id": {"coordinator": mock_coordinator} + } + + services = AdGuardControlHubServices(mock_hass) + call = MagicMock() + call.data = {} + + await services.refresh_data(call) + + mock_coordinator.async_request_refresh.assert_called_once() + + +class TestConstants: + """Test constant definitions.""" + + def test_blocked_services_constants(self): + """Test that blocked services are properly defined.""" + from custom_components.adguard_hub.const import BLOCKED_SERVICES + + required_services = ["youtube", "netflix", "gaming", "facebook"] + + for service in required_services: + assert service in BLOCKED_SERVICES + assert isinstance(BLOCKED_SERVICES[service], str) + assert len(BLOCKED_SERVICES[service]) > 0 + + def test_api_endpoints_constants(self): + """Test that API endpoints are properly defined.""" + from custom_components.adguard_hub.const import API_ENDPOINTS + + required_endpoints = [ + "status", "clients", "stats", "protection", + "clients_add", "clients_update", "clients_delete" + ] + + for endpoint in required_endpoints: + assert endpoint in API_ENDPOINTS + assert API_ENDPOINTS[endpoint].startswith("/") + + def test_platform_constants(self): + """Test platform constants.""" + from custom_components.adguard_hub.const import PLATFORMS + + expected_platforms = ["switch", "binary_sensor", "sensor"] + assert PLATFORMS == expected_platforms + + def test_service_constants(self): + """Test service name constants.""" + from custom_components.adguard_hub.const import ( + SERVICE_BLOCK_SERVICES, + SERVICE_UNBLOCK_SERVICES, + SERVICE_EMERGENCY_UNBLOCK, + SERVICE_ADD_CLIENT, + SERVICE_REMOVE_CLIENT, + SERVICE_REFRESH_DATA, + ) + + services = [ + SERVICE_BLOCK_SERVICES, + SERVICE_UNBLOCK_SERVICES, + SERVICE_EMERGENCY_UNBLOCK, + SERVICE_ADD_CLIENT, + SERVICE_REMOVE_CLIENT, + SERVICE_REFRESH_DATA, + ] + + for service in services: + assert isinstance(service, str) + assert len(service) > 0 + assert "_" in service # Snake case format