Revert "fix: Complete fixes: tests, workflows, coverage"
This reverts commit ed94d40e96
.
This commit is contained in:
@@ -1,81 +1,94 @@
|
||||
"""AdGuard Control Hub services."""
|
||||
"""Service implementations for AdGuard Control Hub integration."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant, ServiceCall
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
import voluptuous as vol
|
||||
|
||||
from .api import AdGuardConnectionError, AdGuardHomeError
|
||||
from .api import AdGuardHomeAPI, AdGuardHomeError
|
||||
from .const import (
|
||||
ATTR_CLIENT_NAME,
|
||||
ATTR_CLIENTS,
|
||||
ATTR_DURATION,
|
||||
ATTR_SERVICES,
|
||||
BLOCKED_SERVICES,
|
||||
DOMAIN,
|
||||
SERVICE_ADD_CLIENT,
|
||||
BLOCKED_SERVICES,
|
||||
ATTR_CLIENT_NAME,
|
||||
ATTR_SERVICES,
|
||||
ATTR_DURATION,
|
||||
ATTR_CLIENTS,
|
||||
ATTR_ENABLED,
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
SERVICE_REFRESH_DATA,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
SERVICE_ADD_CLIENT,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
SERVICE_REFRESH_DATA,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Service schemas
|
||||
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_UNBLOCK_SERVICES = vol.Schema({
|
||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||
})
|
||||
|
||||
SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
||||
vol.Required(ATTR_DURATION): cv.positive_int,
|
||||
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||
})
|
||||
|
||||
SCHEMA_ADD_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
vol.Required("ids"): vol.All(cv.ensure_list, [cv.string]),
|
||||
vol.Optional("filtering_enabled", default=True): cv.boolean,
|
||||
vol.Optional("safebrowsing_enabled", default=False): cv.boolean,
|
||||
vol.Optional("parental_enabled", default=False): cv.boolean,
|
||||
vol.Optional("safesearch_enabled", default=False): cv.boolean,
|
||||
vol.Optional("use_global_blocked_services", default=True): cv.boolean,
|
||||
vol.Optional("blocked_services", default=[]): vol.All(cv.ensure_list, [cv.string]),
|
||||
})
|
||||
|
||||
SCHEMA_REMOVE_CLIENT = vol.Schema({
|
||||
vol.Required("name"): cv.string,
|
||||
})
|
||||
|
||||
SCHEMA_REFRESH_DATA = vol.Schema({})
|
||||
|
||||
|
||||
class AdGuardControlHubServices:
|
||||
"""AdGuard Control Hub services."""
|
||||
"""Handle services for AdGuard Control Hub."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize services."""
|
||||
"""Initialize the services."""
|
||||
self.hass = hass
|
||||
|
||||
def register_services(self) -> None:
|
||||
"""Register services."""
|
||||
# FIXED: All service constants are now properly defined
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
self.block_services,
|
||||
)
|
||||
"""Register all services."""
|
||||
_LOGGER.debug("Registering AdGuard Control Hub services")
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
self.unblock_services,
|
||||
)
|
||||
services = [
|
||||
(SERVICE_BLOCK_SERVICES, self.block_services, SCHEMA_BLOCK_SERVICES),
|
||||
(SERVICE_UNBLOCK_SERVICES, self.unblock_services, SCHEMA_UNBLOCK_SERVICES),
|
||||
(SERVICE_EMERGENCY_UNBLOCK, self.emergency_unblock, SCHEMA_EMERGENCY_UNBLOCK),
|
||||
(SERVICE_ADD_CLIENT, self.add_client, SCHEMA_ADD_CLIENT),
|
||||
(SERVICE_REMOVE_CLIENT, self.remove_client, SCHEMA_REMOVE_CLIENT),
|
||||
(SERVICE_REFRESH_DATA, self.refresh_data, SCHEMA_REFRESH_DATA),
|
||||
]
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_EMERGENCY_UNBLOCK,
|
||||
self.emergency_unblock,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_ADD_CLIENT,
|
||||
self.add_client,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_REMOVE_CLIENT,
|
||||
self.remove_client,
|
||||
)
|
||||
|
||||
self.hass.services.register(
|
||||
DOMAIN,
|
||||
SERVICE_REFRESH_DATA,
|
||||
self.refresh_data,
|
||||
)
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub services registered")
|
||||
for service_name, service_func, schema in services:
|
||||
if not self.hass.services.has_service(DOMAIN, service_name):
|
||||
self.hass.services.register(DOMAIN, service_name, service_func, schema=schema)
|
||||
_LOGGER.debug("Registered service: %s", service_name)
|
||||
|
||||
def unregister_services(self) -> None:
|
||||
"""Unregister services."""
|
||||
"""Unregister all services."""
|
||||
_LOGGER.debug("Unregistering AdGuard Control Hub services")
|
||||
|
||||
services = [
|
||||
SERVICE_BLOCK_SERVICES,
|
||||
SERVICE_UNBLOCK_SERVICES,
|
||||
@@ -85,163 +98,179 @@ class AdGuardControlHubServices:
|
||||
SERVICE_REFRESH_DATA,
|
||||
]
|
||||
|
||||
for service in services:
|
||||
if self.hass.services.has_service(DOMAIN, service):
|
||||
self.hass.services.remove(DOMAIN, service)
|
||||
for service_name in services:
|
||||
if self.hass.services.has_service(DOMAIN, service_name):
|
||||
self.hass.services.remove(DOMAIN, service_name)
|
||||
_LOGGER.debug("Unregistered service: %s", service_name)
|
||||
|
||||
_LOGGER.info("AdGuard Control Hub services unregistered")
|
||||
|
||||
def _get_api(self):
|
||||
"""Get API instance from first available entry."""
|
||||
for entry_id, entry_data in self.hass.data[DOMAIN].items():
|
||||
def _get_api_instances(self) -> list[AdGuardHomeAPI]:
|
||||
"""Get all API instances."""
|
||||
apis = []
|
||||
for entry_data in self.hass.data.get(DOMAIN, {}).values():
|
||||
if isinstance(entry_data, dict) and "api" in entry_data:
|
||||
return entry_data["api"]
|
||||
raise AdGuardConnectionError("No AdGuard Control Hub API available")
|
||||
|
||||
def _get_coordinator(self):
|
||||
"""Get coordinator instance from first available entry."""
|
||||
for entry_id, entry_data in self.hass.data[DOMAIN].items():
|
||||
if isinstance(entry_data, dict) and "coordinator" in entry_data:
|
||||
return entry_data["coordinator"]
|
||||
raise AdGuardConnectionError("No AdGuard Control Hub coordinator available")
|
||||
apis.append(entry_data["api"])
|
||||
return apis
|
||||
|
||||
async def block_services(self, call: ServiceCall) -> None:
|
||||
"""Block services for a client."""
|
||||
"""Block services for a specific client."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services_to_block = call.data[ATTR_SERVICES]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
|
||||
try:
|
||||
api = self._get_api()
|
||||
client = await api.get_client_by_name(client_name)
|
||||
_LOGGER.info("Blocking services %s for client %s", services, client_name)
|
||||
|
||||
if not client:
|
||||
_LOGGER.error("Client '%s' not found", client_name)
|
||||
return
|
||||
success_count = 0
|
||||
for api in self._get_api_instances():
|
||||
try:
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if client:
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked or []
|
||||
|
||||
# Get current blocked services and add new ones
|
||||
current_blocked = set(client.get("blocked_services", []))
|
||||
current_blocked.update(services_to_block)
|
||||
updated_services = list(set(current_services + services))
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
success_count += 1
|
||||
_LOGGER.info("Successfully blocked services for %s", client_name)
|
||||
else:
|
||||
_LOGGER.warning("Client %s not found", client_name)
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("AdGuard error blocking services for %s: %s", client_name, err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error blocking services for %s: %s", client_name, err)
|
||||
|
||||
await api.update_client_blocked_services(
|
||||
client_name, list(current_blocked)
|
||||
)
|
||||
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
_LOGGER.info(
|
||||
"Blocked services %s for client '%s'", services_to_block, client_name
|
||||
)
|
||||
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("Failed to block services for '%s': %s", client_name, err)
|
||||
if success_count == 0:
|
||||
_LOGGER.error("Failed to block services for %s on any instance", client_name)
|
||||
|
||||
async def unblock_services(self, call: ServiceCall) -> None:
|
||||
"""Unblock services for a client."""
|
||||
"""Unblock services for a specific client."""
|
||||
client_name = call.data[ATTR_CLIENT_NAME]
|
||||
services_to_unblock = call.data[ATTR_SERVICES]
|
||||
services = call.data[ATTR_SERVICES]
|
||||
|
||||
try:
|
||||
api = self._get_api()
|
||||
client = await api.get_client_by_name(client_name)
|
||||
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
||||
|
||||
if not client:
|
||||
_LOGGER.error("Client '%s' not found", client_name)
|
||||
return
|
||||
success_count = 0
|
||||
for api in self._get_api_instances():
|
||||
try:
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if client:
|
||||
current_blocked = client.get("blocked_services", {})
|
||||
if isinstance(current_blocked, dict):
|
||||
current_services = current_blocked.get("ids", [])
|
||||
else:
|
||||
current_services = current_blocked or []
|
||||
|
||||
# Get current blocked services and remove specified ones
|
||||
current_blocked = set(client.get("blocked_services", []))
|
||||
current_blocked.difference_update(services_to_unblock)
|
||||
updated_services = [s for s in current_services if s not in services]
|
||||
await api.update_client_blocked_services(client_name, updated_services)
|
||||
success_count += 1
|
||||
_LOGGER.info("Successfully unblocked services for %s", client_name)
|
||||
else:
|
||||
_LOGGER.warning("Client %s not found", client_name)
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("AdGuard error unblocking services for %s: %s", client_name, err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error unblocking services for %s: %s", client_name, err)
|
||||
|
||||
await api.update_client_blocked_services(
|
||||
client_name, list(current_blocked)
|
||||
)
|
||||
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
|
||||
_LOGGER.info(
|
||||
"Unblocked services %s for client '%s'", services_to_unblock, client_name
|
||||
)
|
||||
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("Failed to unblock services for '%s': %s", client_name, err)
|
||||
if success_count == 0:
|
||||
_LOGGER.error("Failed to unblock services for %s on any instance", client_name)
|
||||
|
||||
async def emergency_unblock(self, call: ServiceCall) -> None:
|
||||
"""Emergency unblock - disable protection temporarily."""
|
||||
duration = call.data.get(ATTR_DURATION, 300)
|
||||
clients = call.data.get(ATTR_CLIENTS, ["all"])
|
||||
"""Emergency unblock - temporarily disable protection."""
|
||||
duration = call.data[ATTR_DURATION]
|
||||
clients = call.data[ATTR_CLIENTS]
|
||||
|
||||
try:
|
||||
api = self._get_api()
|
||||
_LOGGER.warning("Emergency unblock activated for %s seconds", duration)
|
||||
|
||||
if "all" in clients:
|
||||
# Global protection disable
|
||||
await api.set_protection(False)
|
||||
_LOGGER.warning(
|
||||
"Emergency unblock activated globally for %d seconds", duration
|
||||
)
|
||||
for api in self._get_api_instances():
|
||||
try:
|
||||
if "all" in clients:
|
||||
await api.set_protection(False)
|
||||
_LOGGER.warning("Protection disabled for %s:%s", api.host, api.port)
|
||||
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
# Re-enable after duration
|
||||
async def delayed_enable(api_instance: AdGuardHomeAPI):
|
||||
await asyncio.sleep(duration)
|
||||
try:
|
||||
await api_instance.set_protection(True)
|
||||
_LOGGER.info("Emergency unblock expired - protection re-enabled for %s:%s",
|
||||
api_instance.host, api_instance.port)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to re-enable protection for %s:%s: %s",
|
||||
api_instance.host, api_instance.port, err)
|
||||
|
||||
# Schedule re-enabling protection
|
||||
async def restore_protection():
|
||||
await asyncio.sleep(duration)
|
||||
try:
|
||||
if "all" in clients:
|
||||
await api.set_protection(True)
|
||||
asyncio.create_task(delayed_enable(api))
|
||||
else:
|
||||
# Individual client emergency unblock
|
||||
for client_name in clients:
|
||||
if client_name == "all":
|
||||
continue
|
||||
try:
|
||||
client = await api.get_client_by_name(client_name)
|
||||
if client:
|
||||
update_data = {
|
||||
"name": client_name,
|
||||
"data": {**client, "filtering_enabled": False}
|
||||
}
|
||||
await api.update_client(update_data)
|
||||
_LOGGER.info("Emergency unblock applied to client %s", client_name)
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to emergency unblock client %s: %s", client_name, err)
|
||||
|
||||
await coordinator.async_request_refresh()
|
||||
_LOGGER.info("Emergency unblock period ended, protection restored")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to restore protection after emergency unblock: %s", err)
|
||||
|
||||
# Schedule restoration
|
||||
self.hass.async_create_task(restore_protection())
|
||||
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("Failed to activate emergency unblock: %s", err)
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("AdGuard error during emergency unblock: %s", err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error during emergency unblock: %s", err)
|
||||
|
||||
async def add_client(self, call: ServiceCall) -> None:
|
||||
"""Add a new client."""
|
||||
client_data = dict(call.data)
|
||||
|
||||
try:
|
||||
api = self._get_api()
|
||||
await api.add_client(client_data)
|
||||
_LOGGER.info("Adding new client: %s", client_data.get("name"))
|
||||
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
success_count = 0
|
||||
for api in self._get_api_instances():
|
||||
try:
|
||||
await api.add_client(client_data)
|
||||
success_count += 1
|
||||
_LOGGER.info("Successfully added client: %s", client_data.get("name"))
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("AdGuard error adding client: %s", err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error adding client: %s", err)
|
||||
|
||||
_LOGGER.info("Added new client: %s", client_data["name"])
|
||||
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("Failed to add client '%s': %s", client_data["name"], err)
|
||||
if success_count == 0:
|
||||
_LOGGER.error("Failed to add client %s on any instance", client_data.get("name"))
|
||||
|
||||
async def remove_client(self, call: ServiceCall) -> None:
|
||||
"""Remove a client."""
|
||||
client_name = call.data["name"]
|
||||
client_name = call.data.get("name")
|
||||
|
||||
try:
|
||||
api = self._get_api()
|
||||
await api.delete_client(client_name)
|
||||
_LOGGER.info("Removing client: %s", client_name)
|
||||
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
success_count = 0
|
||||
for api in self._get_api_instances():
|
||||
try:
|
||||
await api.delete_client(client_name)
|
||||
success_count += 1
|
||||
_LOGGER.info("Successfully removed client: %s", client_name)
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("AdGuard error removing client: %s", err)
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Unexpected error removing client: %s", err)
|
||||
|
||||
_LOGGER.info("Removed client: %s", client_name)
|
||||
|
||||
except AdGuardHomeError as err:
|
||||
_LOGGER.error("Failed to remove client '%s': %s", client_name, err)
|
||||
if success_count == 0:
|
||||
_LOGGER.error("Failed to remove client %s on any instance", client_name)
|
||||
|
||||
async def refresh_data(self, call: ServiceCall) -> None:
|
||||
"""Refresh data from AdGuard Home."""
|
||||
try:
|
||||
coordinator = self._get_coordinator()
|
||||
await coordinator.async_request_refresh()
|
||||
"""Refresh data for all coordinators."""
|
||||
_LOGGER.info("Manually refreshing AdGuard Control Hub data")
|
||||
|
||||
_LOGGER.info("Data refresh requested")
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to refresh data: %s", err)
|
||||
for entry_data in self.hass.data.get(DOMAIN, {}).values():
|
||||
if isinstance(entry_data, dict) and "coordinator" in entry_data:
|
||||
coordinator = entry_data["coordinator"]
|
||||
try:
|
||||
await coordinator.async_request_refresh()
|
||||
_LOGGER.debug("Refreshed coordinator data")
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to refresh coordinator: %s", err)
|
||||
|
Reference in New Issue
Block a user