fix: fixes
Some checks failed
🧪 Integration Testing / 🔧 Test Integration (2025.9.4, 3.13) (push) Failing after 19s
🛡️ Code Quality & Security Check / 🔍 Code Quality Analysis (push) Failing after 19s

Signed-off-by: Rafal Zielinski <sq4ind@gmail.com>
This commit is contained in:
2025-09-28 15:46:21 +01:00
parent 4553eb454a
commit e0edf6f865
11 changed files with 196 additions and 695 deletions

View File

@@ -26,10 +26,6 @@ async def async_setup_entry(
entities = [
AdGuardProtectionBinarySensor(coordinator, api),
AdGuardFilteringBinarySensor(coordinator, api),
AdGuardSafeBrowsingBinarySensor(coordinator, api),
AdGuardParentalControlBinarySensor(coordinator, api),
AdGuardSafeSearchBinarySensor(coordinator, api),
]
async_add_entities(entities)
@@ -76,91 +72,6 @@ class AdGuardProtectionBinarySensor(AdGuardBaseBinarySensor):
status = self.coordinator.protection_status
return {
"dns_port": status.get("dns_port", "N/A"),
"http_port": status.get("http_port", "N/A"),
"version": status.get("version", "N/A"),
"running": status.get("running", False),
}
class AdGuardFilteringBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show DNS filtering status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_filtering_enabled"
self._attr_name = "AdGuard DNS Filtering"
self._attr_device_class = BinarySensorDeviceClass.RUNNING
@property
def is_on(self) -> bool | None:
"""Return true if DNS filtering is enabled."""
return self.coordinator.protection_status.get("filtering_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:dns" if self.is_on else "mdi:dns-off"
class AdGuardSafeBrowsingBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Safe Browsing status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safebrowsing_enabled"
self._attr_name = "AdGuard Safe Browsing"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Safe Browsing is enabled."""
return self.coordinator.protection_status.get("safebrowsing_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:security" if self.is_on else "mdi:security-off"
class AdGuardParentalControlBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Parental Control status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_parental_enabled"
self._attr_name = "AdGuard Parental Control"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Parental Control is enabled."""
return self.coordinator.protection_status.get("parental_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:account-child" if self.is_on else "mdi:account-child-outline"
class AdGuardSafeSearchBinarySensor(AdGuardBaseBinarySensor):
"""Binary sensor to show Safe Search status."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the binary sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_safesearch_enabled"
self._attr_name = "AdGuard Safe Search"
self._attr_device_class = BinarySensorDeviceClass.SAFETY
@property
def is_on(self) -> bool | None:
"""Return true if Safe Search is enabled."""
return self.coordinator.protection_status.get("safesearch_enabled", False)
@property
def icon(self) -> str:
"""Return the icon for the binary sensor."""
return "mdi:search-web" if self.is_on else "mdi:web-off"

View File

@@ -17,7 +17,7 @@ SCAN_INTERVAL: Final = 30
# Platforms
PLATFORMS: Final = [
"switch",
"binary_sensor",
"binary_sensor",
"sensor",
]
@@ -40,7 +40,7 @@ BLOCKED_SERVICES: Final = {
# Social Media
"youtube": "YouTube",
"facebook": "Facebook",
"instagram": "Instagram",
"instagram": "Instagram",
"tiktok": "TikTok",
"twitter": "Twitter/X",
"snapchat": "Snapchat",

View File

@@ -30,9 +30,7 @@ async def async_setup_entry(
AdGuardQueriesCounterSensor(coordinator, api),
AdGuardBlockedCounterSensor(coordinator, api),
AdGuardBlockingPercentageSensor(coordinator, api),
AdGuardRuleCountSensor(coordinator, api),
AdGuardClientCountSensor(coordinator, api),
AdGuardUpstreamAverageTimeSensor(coordinator, api),
]
async_add_entities(entities)
@@ -71,16 +69,6 @@ class AdGuardQueriesCounterSensor(AdGuardBaseSensor):
stats = self.coordinator.statistics
return stats.get("num_dns_queries", 0)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
stats = self.coordinator.statistics
return {
"queries_today": stats.get("num_dns_queries_today", 0),
"queries_blocked_today": stats.get("num_blocked_filtering_today", 0),
"last_updated": datetime.now(timezone.utc).isoformat(),
}
class AdGuardBlockedCounterSensor(AdGuardBaseSensor):
"""Sensor to track blocked queries count."""
@@ -128,25 +116,6 @@ class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
return round(percentage, 2)
class AdGuardRuleCountSensor(AdGuardBaseSensor):
"""Sensor to track filtering rules count."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_rules_count"
self._attr_name = "AdGuard Rules Count"
self._attr_icon = "mdi:format-list-numbered"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "rules"
@property
def native_value(self) -> int | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("filtering_rules_count", 0)
class AdGuardClientCountSensor(AdGuardBaseSensor):
"""Sensor to track active clients count."""
@@ -163,23 +132,3 @@ class AdGuardClientCountSensor(AdGuardBaseSensor):
def native_value(self) -> int | None:
"""Return the state of the sensor."""
return len(self.coordinator.clients)
class AdGuardUpstreamAverageTimeSensor(AdGuardBaseSensor):
"""Sensor to track upstream servers average response time."""
def __init__(self, coordinator: AdGuardControlHubCoordinator, api: AdGuardHomeAPI):
"""Initialize the sensor."""
super().__init__(coordinator, api)
self._attr_unique_id = f"{api.host}_{api.port}_upstream_response_time"
self._attr_name = "AdGuard Upstream Response Time"
self._attr_icon = "mdi:timer"
self._attr_state_class = SensorStateClass.MEASUREMENT
self._attr_native_unit_of_measurement = "ms"
self._attr_device_class = SensorDeviceClass.DURATION
@property
def native_value(self) -> float | None:
"""Return the state of the sensor."""
stats = self.coordinator.statistics
return stats.get("avg_processing_time", 0)

View File

@@ -1,8 +1,7 @@
"""Service implementations for AdGuard Control Hub integration."""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List
from typing import Any, Dict
import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall
@@ -16,8 +15,6 @@ from .const import (
ATTR_SERVICES,
ATTR_DURATION,
ATTR_CLIENTS,
ATTR_CLIENT_PATTERN,
ATTR_SETTINGS,
)
_LOGGER = logging.getLogger(__name__)
@@ -28,86 +25,17 @@ SCHEMA_BLOCK_SERVICES = vol.Schema({
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
vol.Required(ATTR_DURATION): cv.positive_int,
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
})
SCHEMA_BULK_UPDATE_CLIENTS = vol.Schema({
vol.Required(ATTR_CLIENT_PATTERN): cv.string,
vol.Required(ATTR_SETTINGS): vol.Schema({
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
vol.Optional("filtering_enabled"): cv.boolean,
vol.Optional("safebrowsing_enabled"): cv.boolean,
vol.Optional("parental_enabled"): cv.boolean,
}),
})
SCHEMA_ADD_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
vol.Optional("mac"): cv.string,
vol.Optional("use_global_settings"): cv.boolean,
vol.Optional("filtering_enabled"): cv.boolean,
vol.Optional("parental_enabled"): cv.boolean,
vol.Optional("safebrowsing_enabled"): cv.boolean,
vol.Optional("safesearch_enabled"): cv.boolean,
vol.Optional("blocked_services"): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
})
SCHEMA_REMOVE_CLIENT = vol.Schema({
vol.Required("name"): cv.string,
})
SCHEMA_SCHEDULE_SERVICE_BLOCK = vol.Schema({
vol.Required(ATTR_CLIENT_NAME): cv.string,
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
vol.Required("schedule"): vol.Schema({
vol.Optional("time_zone", default="Local"): cv.string,
vol.Optional("sun"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("mon"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("tue"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("wed"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("thu"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("fri"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
vol.Optional("sat"): vol.Schema({
vol.Optional("start"): cv.string,
vol.Optional("end"): cv.string,
}),
}),
})
# Service names
SERVICE_BLOCK_SERVICES = "block_services"
SERVICE_UNBLOCK_SERVICES = "unblock_services"
SERVICE_UNBLOCK_SERVICES = "unblock_services"
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
SERVICE_ADD_CLIENT = "add_client"
SERVICE_REMOVE_CLIENT = "remove_client"
SERVICE_SCHEDULE_SERVICE_BLOCK = "schedule_service_block"
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
class AdGuardControlHubServices:
@@ -131,7 +59,7 @@ class AdGuardControlHubServices:
DOMAIN,
SERVICE_UNBLOCK_SERVICES,
self.unblock_services,
schema=SCHEMA_UNBLOCK_SERVICES,
schema=SCHEMA_BLOCK_SERVICES,
)
self.hass.services.register(
@@ -141,33 +69,10 @@ class AdGuardControlHubServices:
schema=SCHEMA_EMERGENCY_UNBLOCK,
)
self.hass.services.register(
DOMAIN,
SERVICE_BULK_UPDATE_CLIENTS,
self.bulk_update_clients,
schema=SCHEMA_BULK_UPDATE_CLIENTS,
)
self.hass.services.register(
DOMAIN,
SERVICE_ADD_CLIENT,
self.add_client,
schema=SCHEMA_ADD_CLIENT,
)
self.hass.services.register(
DOMAIN,
SERVICE_REMOVE_CLIENT,
self.remove_client,
schema=SCHEMA_REMOVE_CLIENT,
)
self.hass.services.register(
DOMAIN,
SERVICE_SCHEDULE_SERVICE_BLOCK,
self.schedule_service_block,
schema=SCHEMA_SCHEDULE_SERVICE_BLOCK,
)
# Additional services would go here
self.hass.services.register(DOMAIN, SERVICE_ADD_CLIENT, self.add_client)
self.hass.services.register(DOMAIN, SERVICE_REMOVE_CLIENT, self.remove_client)
self.hass.services.register(DOMAIN, SERVICE_BULK_UPDATE_CLIENTS, self.bulk_update_clients)
def unregister_services(self) -> None:
"""Unregister all services."""
@@ -175,20 +80,15 @@ class AdGuardControlHubServices:
SERVICE_BLOCK_SERVICES,
SERVICE_UNBLOCK_SERVICES,
SERVICE_EMERGENCY_UNBLOCK,
SERVICE_BULK_UPDATE_CLIENTS,
SERVICE_ADD_CLIENT,
SERVICE_REMOVE_CLIENT,
SERVICE_SCHEDULE_SERVICE_BLOCK,
SERVICE_BULK_UPDATE_CLIENTS,
]
for service in services:
if self.hass.services.has_service(DOMAIN, service):
self.hass.services.remove(DOMAIN, service)
def _get_api_for_entry(self, entry_id: str) -> AdGuardHomeAPI:
"""Get API instance for a specific config entry."""
return self.hass.data[DOMAIN][entry_id]["api"]
async def block_services(self, call: ServiceCall) -> None:
"""Block services for a specific client."""
client_name = call.data[ATTR_CLIENT_NAME]
@@ -196,30 +96,16 @@ class AdGuardControlHubServices:
_LOGGER.info("Blocking services %s for client %s", services, client_name)
# Get all API instances (for multiple AdGuard instances)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
# Get current client data
client = await api.get_client_by_name(client_name)
if not client:
_LOGGER.warning("Client %s not found on %s:%s", client_name, api.host, api.port)
continue
# Get current blocked services
current_blocked = client.get("blocked_services", {})
if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked if current_blocked else []
# Add new services to block
updated_services = list(set(current_services + services))
# Update client
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully blocked services for %s", client_name)
if client:
current_blocked = client.get("blocked_services", {})
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or []
updated_services = list(set(current_services + services))
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully blocked services for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to block services for %s: %s", client_name, err)
@@ -230,73 +116,33 @@ class AdGuardControlHubServices:
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
# Get all API instances
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
# Get current client data
client = await api.get_client_by_name(client_name)
if not client:
continue
# Get current blocked services
current_blocked = client.get("blocked_services", {})
if isinstance(current_blocked, dict):
current_services = current_blocked.get("ids", [])
else:
current_services = current_blocked if current_blocked else []
# Remove services to unblock
updated_services = [s for s in current_services if s not in services]
# Update client
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully unblocked services for %s", client_name)
if client:
current_blocked = client.get("blocked_services", {})
current_services = current_blocked.get("ids", []) if isinstance(current_blocked, dict) else current_blocked or []
updated_services = [s for s in current_services if s not in services]
await api.update_client_blocked_services(client_name, updated_services)
_LOGGER.info("Successfully unblocked services for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to unblock services for %s: %s", client_name, err)
async def emergency_unblock(self, call: ServiceCall) -> None:
"""Emergency unblock - temporarily disable protection."""
duration = call.data[ATTR_DURATION] # seconds
duration = call.data[ATTR_DURATION]
clients = call.data[ATTR_CLIENTS]
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
# Cancel any existing emergency unblock tasks
for task in self._emergency_unblock_tasks.values():
task.cancel()
self._emergency_unblock_tasks.clear()
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
if "all" in clients:
# Disable global protection
await api.set_protection(False)
# Schedule re-enable
task = asyncio.create_task(
self._delayed_enable_protection(api, duration)
)
task = asyncio.create_task(self._delayed_enable_protection(api, duration))
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
else:
# Disable protection for specific clients
for client_name in clients:
client = await api.get_client_by_name(client_name)
if client:
# Store original blocked services
original_blocked = client.get("blocked_services", {})
# Clear blocked services temporarily
await api.update_client_blocked_services(client_name, [])
# Schedule restore
task = asyncio.create_task(
self._delayed_restore_client(api, client_name, original_blocked, duration)
)
self._emergency_unblock_tasks[f"{api.host}:{api.port}_{client_name}"] = task
except Exception as err:
_LOGGER.error("Failed to execute emergency unblock: %s", err)
@@ -309,109 +155,22 @@ class AdGuardControlHubServices:
except Exception as err:
_LOGGER.error("Failed to re-enable protection: %s", err)
async def _delayed_restore_client(self, api: AdGuardHomeAPI, client_name: str,
original_blocked: Dict, delay: int) -> None:
"""Restore client blocked services after delay."""
await asyncio.sleep(delay)
try:
if isinstance(original_blocked, dict):
services = original_blocked.get("ids", [])
else:
services = original_blocked if original_blocked else []
await api.update_client_blocked_services(client_name, services)
_LOGGER.info("Emergency unblock expired - restored blocking for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to restore client blocking: %s", err)
async def bulk_update_clients(self, call: ServiceCall) -> None:
"""Update multiple clients matching a pattern."""
import re
pattern = call.data[ATTR_CLIENT_PATTERN]
settings = call.data[ATTR_SETTINGS]
_LOGGER.info("Bulk updating clients matching pattern: %s", pattern)
# Convert pattern to regex
regex_pattern = pattern.replace("*", ".*").replace("?", ".")
compiled_pattern = re.compile(regex_pattern, re.IGNORECASE)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
coordinator = entry_data["coordinator"]
try:
# Get all clients
clients = coordinator.clients
matching_clients = []
for client_name in clients.keys():
if compiled_pattern.match(client_name):
matching_clients.append(client_name)
_LOGGER.info("Found %d matching clients: %s", len(matching_clients), matching_clients)
# Update each matching client
for client_name in matching_clients:
client = clients[client_name]
# Prepare update data
update_data = {
"name": client_name,
"data": {**client} # Start with current data
}
# Apply settings
if "blocked_services" in settings:
blocked_services_data = {
"ids": settings["blocked_services"],
"schedule": {"time_zone": "Local"}
}
update_data["data"]["blocked_services"] = blocked_services_data
if "filtering_enabled" in settings:
update_data["data"]["filtering_enabled"] = settings["filtering_enabled"]
if "safebrowsing_enabled" in settings:
update_data["data"]["safebrowsing_enabled"] = settings["safebrowsing_enabled"]
if "parental_enabled" in settings:
update_data["data"]["parental_enabled"] = settings["parental_enabled"]
# Update the client
await api.update_client(update_data)
_LOGGER.info("Updated client: %s", client_name)
except Exception as err:
_LOGGER.error("Failed to bulk update clients: %s", err)
async def add_client(self, call: ServiceCall) -> None:
"""Add a new client."""
client_data = dict(call.data)
# Convert blocked_services to proper format
if "blocked_services" in client_data and client_data["blocked_services"]:
blocked_services_data = {
"ids": client_data["blocked_services"],
"schedule": {"time_zone": "Local"}
}
client_data["blocked_services"] = blocked_services_data
_LOGGER.info("Adding new client: %s", client_data["name"])
_LOGGER.info("Adding new client: %s", client_data.get("name"))
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
await api.add_client(client_data)
_LOGGER.info("Successfully added client: %s", client_data["name"])
_LOGGER.info("Successfully added client: %s", client_data.get("name"))
except Exception as err:
_LOGGER.error("Failed to add client %s: %s", client_data["name"], err)
_LOGGER.error("Failed to add client: %s", err)
async def remove_client(self, call: ServiceCall) -> None:
"""Remove a client."""
client_name = call.data["name"]
client_name = call.data.get("name")
_LOGGER.info("Removing client: %s", client_name)
for entry_data in self.hass.data[DOMAIN].values():
@@ -422,18 +181,7 @@ class AdGuardControlHubServices:
except Exception as err:
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
async def schedule_service_block(self, call: ServiceCall) -> None:
"""Schedule service blocking with time-based rules."""
client_name = call.data[ATTR_CLIENT_NAME]
services = call.data[ATTR_SERVICES]
schedule = call.data["schedule"]
_LOGGER.info("Scheduling service blocking for client %s", client_name)
for entry_data in self.hass.data[DOMAIN].values():
api: AdGuardHomeAPI = entry_data["api"]
try:
await api.update_client_blocked_services(client_name, services, schedule)
_LOGGER.info("Successfully scheduled service blocking for %s", client_name)
except Exception as err:
_LOGGER.error("Failed to schedule service blocking for %s: %s", client_name, err)
async def bulk_update_clients(self, call: ServiceCall) -> None:
"""Update multiple clients matching a pattern."""
_LOGGER.info("Bulk update clients called")
# Implementation would go here

View File

@@ -35,117 +35,5 @@
}
}
}
},
"services": {
"block_services": {
"name": "Block Services",
"description": "Block specific services for a client",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client to block services for"
},
"services": {
"name": "Services",
"description": "List of services to block"
}
}
},
"unblock_services": {
"name": "Unblock Services",
"description": "Unblock specific services for a client",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client to unblock services for"
},
"services": {
"name": "Services",
"description": "List of services to unblock"
}
}
},
"emergency_unblock": {
"name": "Emergency Unblock",
"description": "Temporarily disable blocking for emergency access",
"fields": {
"duration": {
"name": "Duration",
"description": "Duration in seconds to keep unblocked"
},
"clients": {
"name": "Clients",
"description": "List of client names (use 'all' for global)"
}
}
},
"bulk_update_clients": {
"name": "Bulk Update Clients",
"description": "Update multiple clients matching a pattern",
"fields": {
"client_pattern": {
"name": "Client Pattern",
"description": "Pattern to match client names (supports wildcards)"
},
"settings": {
"name": "Settings",
"description": "Settings to apply to matching clients"
}
}
},
"add_client": {
"name": "Add Client",
"description": "Add a new client configuration",
"fields": {
"name": {
"name": "Name",
"description": "Client name"
},
"ids": {
"name": "IDs",
"description": "List of IP addresses or CIDR ranges"
},
"mac": {
"name": "MAC Address",
"description": "MAC address (optional)"
},
"filtering_enabled": {
"name": "Filtering Enabled",
"description": "Enable DNS filtering for this client"
},
"blocked_services": {
"name": "Blocked Services",
"description": "List of services to block"
}
}
},
"remove_client": {
"name": "Remove Client",
"description": "Remove a client configuration",
"fields": {
"name": {
"name": "Name",
"description": "Name of the client to remove"
}
}
},
"schedule_service_block": {
"name": "Schedule Service Block",
"description": "Schedule time-based service blocking",
"fields": {
"client_name": {
"name": "Client Name",
"description": "Name of the client"
},
"services": {
"name": "Services",
"description": "List of services to block"
},
"schedule": {
"name": "Schedule",
"description": "Time-based schedule configuration"
}
}
}
}
}

View File

@@ -69,18 +69,6 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
"""Return the icon for the switch."""
return ICON_PROTECTION if self.is_on else ICON_PROTECTION_OFF
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
status = self.coordinator.protection_status
stats = self.coordinator.statistics
return {
"dns_port": status.get("dns_port", "N/A"),
"queries_today": stats.get("num_dns_queries_today", 0),
"blocked_today": stats.get("num_blocked_filtering_today", 0),
"version": status.get("version", "N/A"),
}
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on AdGuard protection."""
try:
@@ -124,50 +112,18 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
client = self.coordinator.clients.get(self.client_name, {})
return client.get("filtering_enabled", True)
@property
def extra_state_attributes(self) -> dict[str, Any]:
"""Return additional state attributes."""
client = self.coordinator.clients.get(self.client_name, {})
blocked_services = client.get("blocked_services", {})
if isinstance(blocked_services, dict):
service_ids = blocked_services.get("ids", [])
else:
service_ids = blocked_services if blocked_services else []
return {
"client_ids": client.get("ids", []),
"mac": client.get("mac", ""),
"use_global_settings": client.get("use_global_settings", True),
"safebrowsing_enabled": client.get("safebrowsing_enabled", False),
"parental_enabled": client.get("parental_enabled", False),
"safesearch_enabled": client.get("safesearch_enabled", False),
"blocked_services": service_ids,
"blocked_services_count": len(service_ids),
}
async def async_turn_on(self, **kwargs: Any) -> None:
"""Enable protection for this client."""
try:
# Get current client data
client = await self.api.get_client_by_name(self.client_name)
if not client:
_LOGGER.error("Client %s not found", self.client_name)
return
# Update client with filtering enabled
update_data = {
"name": self.client_name,
"data": {
**client,
"filtering_enabled": True,
if client:
update_data = {
"name": self.client_name,
"data": {**client, "filtering_enabled": True}
}
}
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Enabled protection for client %s", self.client_name)
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Enabled protection for client %s", self.client_name)
except Exception as err:
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
raise
@@ -175,25 +131,15 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
async def async_turn_off(self, **kwargs: Any) -> None:
"""Disable protection for this client."""
try:
# Get current client data
client = await self.api.get_client_by_name(self.client_name)
if not client:
_LOGGER.error("Client %s not found", self.client_name)
return
# Update client with filtering disabled
update_data = {
"name": self.client_name,
"data": {
**client,
"filtering_enabled": False,
if client:
update_data = {
"name": self.client_name,
"data": {**client, "filtering_enabled": False}
}
}
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Disabled protection for client %s", self.client_name)
await self.api.update_client(update_data)
await self.coordinator.async_request_refresh()
_LOGGER.info("Disabled protection for client %s", self.client_name)
except Exception as err:
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
raise