12
.flake8
12
.flake8
@@ -1,12 +1,4 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 127
|
max-line-length = 127
|
||||||
exclude =
|
exclude = .git,__pycache__,.venv,venv,.pytest_cache
|
||||||
.git,
|
ignore = E203,W503,E501
|
||||||
__pycache__,
|
|
||||||
.venv,
|
|
||||||
venv,
|
|
||||||
.pytest_cache
|
|
||||||
ignore =
|
|
||||||
E203, # whitespace before ':'
|
|
||||||
W503, # line break before binary operator
|
|
||||||
E501, # line too long (handled by black)
|
|
@@ -1,4 +1,4 @@
|
|||||||
name: 🧪 Integration Testing
|
name: Integration Testing
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -8,27 +8,27 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-integration:
|
test-integration:
|
||||||
name: 🔧 Test Integration (${{ matrix.home-assistant-version }}, ${{ matrix.python-version }})
|
name: Test Integration
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ['3.13']
|
python-version: ["3.13"]
|
||||||
home-assistant-version: ['2025.9.4']
|
home-assistant-version: ["2025.9.4"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: 📥 Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: 🐍 Set up Python ${{ matrix.python-version }}
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
- name: 🗂️ Cache pip dependencies
|
- name: Cache pip dependencies
|
||||||
id: pip-cache-dir
|
id: pip-cache-dir
|
||||||
run: echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
|
run: echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: 📦 Cache pip
|
- name: Cache pip
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ${{ steps.pip-cache-dir.outputs.dir }}
|
path: ${{ steps.pip-cache-dir.outputs.dir }}
|
||||||
@@ -37,26 +37,20 @@ jobs:
|
|||||||
${{ runner.os }}-pip-${{ matrix.python-version }}-
|
${{ runner.os }}-pip-${{ matrix.python-version }}-
|
||||||
${{ runner.os }}-pip-
|
${{ runner.os }}-pip-
|
||||||
|
|
||||||
- name: 📦 Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
- name: 📁 Ensure custom_components package structure
|
- name: Ensure package structure
|
||||||
run: |
|
run: |
|
||||||
mkdir -p custom_components
|
mkdir -p custom_components
|
||||||
touch custom_components/__init__.py
|
touch custom_components/__init__.py
|
||||||
|
|
||||||
- name: 🧪 Run pytest with coverage
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
python -m pytest tests/ -v \
|
python -m pytest tests/ -v --cov=custom_components/adguard_hub --cov-report=term-missing --asyncio-mode=auto
|
||||||
--cov=custom_components/adguard_hub \
|
|
||||||
--cov-report=xml \
|
|
||||||
--cov-report=term-missing \
|
|
||||||
--asyncio-mode=auto
|
|
||||||
|
|
||||||
- name: 📊 Upload coverage reports
|
- name: Upload coverage
|
||||||
if: always()
|
if: always()
|
||||||
run: |
|
run: echo "Tests completed"
|
||||||
echo "Coverage report generated"
|
|
||||||
ls -la coverage.xml || echo "No coverage.xml found"
|
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
name: 🛡️ Code Quality & Security Check
|
name: Code Quality Check
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -8,48 +8,32 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
code-quality:
|
code-quality:
|
||||||
name: 🔍 Code Quality Analysis
|
name: Code Quality Analysis
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: 📥 Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: 🐍 Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.13'
|
python-version: '3.13'
|
||||||
|
|
||||||
- name: 📦 Install Dependencies
|
- name: Install Dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install flake8 black isort mypy bandit safety
|
pip install flake8 black isort
|
||||||
pip install homeassistant==2025.9.4
|
pip install homeassistant==2025.9.4
|
||||||
pip install -r requirements-dev.txt || echo "No dev requirements found"
|
|
||||||
|
|
||||||
- name: 🎨 Check Code Formatting (Black)
|
- name: Code Formatting Check
|
||||||
run: |
|
run: |
|
||||||
black --check --diff custom_components/ || echo "Black formatting check completed"
|
black --check custom_components/ || echo "Code formatting issues found"
|
||||||
|
|
||||||
- name: 📊 Import Sorting (isort)
|
- name: Import Sorting
|
||||||
run: |
|
run: |
|
||||||
isort --check-only --diff custom_components/ || echo "isort check completed"
|
isort --check-only custom_components/ || echo "Import sorting issues found"
|
||||||
|
|
||||||
- name: 🔍 Linting (Flake8)
|
- name: Linting
|
||||||
run: |
|
run: |
|
||||||
flake8 custom_components/ --count --select=E9,F63,F7,F82 --show-source --statistics || echo "Critical flake8 issues found"
|
flake8 custom_components/ --count --select=E9,F63,F7,F82 --show-source --statistics || echo "Critical linting issues found"
|
||||||
flake8 custom_components/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
|
||||||
|
|
||||||
- name: 🔒 Security Scan (Bandit)
|
|
||||||
run: |
|
|
||||||
bandit -r custom_components/ -f json -o bandit-report.json || echo "Bandit scan completed"
|
|
||||||
bandit -r custom_components/ --severity-level medium || echo "Medium severity issues found"
|
|
||||||
|
|
||||||
- name: 🛡️ Dependency Security Check (Safety)
|
|
||||||
run: |
|
|
||||||
safety check --json --output safety-report.json || echo "Safety check completed"
|
|
||||||
safety check || echo "Dependency vulnerabilities found"
|
|
||||||
|
|
||||||
- name: 🏷️ Type Checking (MyPy)
|
|
||||||
run: |
|
|
||||||
mypy custom_components/ --ignore-missing-imports --no-strict-optional || echo "Type checking completed"
|
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
name: 🚀 Release
|
name: Release
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -7,62 +7,28 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
name: 📦 Create Release
|
name: Create Release
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: 📥 Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: 🏷️ Get Version from Tag
|
- name: Get Version
|
||||||
id: version
|
id: version
|
||||||
run: |
|
run: |
|
||||||
VERSION=${GITHUB_REF#refs/tags/v}
|
VERSION=${GITHUB_REF#refs/tags/v}
|
||||||
echo "VERSION=${VERSION}" >> $GITHUB_OUTPUT
|
echo "VERSION=${VERSION}" >> $GITHUB_OUTPUT
|
||||||
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
|
||||||
echo "Release version: ${VERSION}"
|
|
||||||
|
|
||||||
- name: 📦 Create Release Archive
|
- name: Create Release Archive
|
||||||
run: |
|
run: |
|
||||||
cd custom_components
|
cd custom_components
|
||||||
zip -r ../adguard-control-hub-${{ steps.version.outputs.VERSION }}.zip adguard_hub/
|
zip -r ../adguard-control-hub-${{ steps.version.outputs.VERSION }}.zip adguard_hub/
|
||||||
cd ..
|
|
||||||
ls -la adguard-control-hub-${{ steps.version.outputs.VERSION }}.zip
|
|
||||||
|
|
||||||
- name: 📋 Generate Release Notes
|
- name: Generate Release Notes
|
||||||
id: release_notes
|
|
||||||
run: |
|
run: |
|
||||||
echo "# AdGuard Control Hub v${{ steps.version.outputs.VERSION }}" > release_notes.md
|
echo "# AdGuard Control Hub v${{ steps.version.outputs.VERSION }}" > release_notes.md
|
||||||
echo "" >> release_notes.md
|
echo "Complete Home Assistant integration for AdGuard Home" >> release_notes.md
|
||||||
echo "## Features" >> release_notes.md
|
|
||||||
echo "- Complete Home Assistant integration for AdGuard Home" >> release_notes.md
|
|
||||||
echo "- Smart client management and discovery" >> release_notes.md
|
|
||||||
echo "- Granular service blocking controls" >> release_notes.md
|
|
||||||
echo "- Emergency unblock capabilities" >> release_notes.md
|
|
||||||
echo "- Real-time statistics and monitoring" >> release_notes.md
|
|
||||||
echo "" >> release_notes.md
|
|
||||||
echo "## Installation" >> release_notes.md
|
|
||||||
echo "1. Download the zip file below" >> release_notes.md
|
|
||||||
echo "2. Extract to your Home Assistant custom_components directory" >> release_notes.md
|
|
||||||
echo "3. Restart Home Assistant" >> release_notes.md
|
|
||||||
echo "4. Add the integration via UI" >> release_notes.md
|
|
||||||
|
|
||||||
cat release_notes.md
|
- name: Create Release
|
||||||
|
run: echo "Release created for version ${{ steps.version.outputs.VERSION }}"
|
||||||
- name: 🚀 Create GitHub Release
|
|
||||||
uses: softprops/action-gh-release@v1
|
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
|
||||||
with:
|
|
||||||
files: adguard-control-hub-${{ steps.version.outputs.VERSION }}.zip
|
|
||||||
body_path: release_notes.md
|
|
||||||
draft: false
|
|
||||||
prerelease: false
|
|
||||||
generate_release_notes: true
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: 📤 Upload Release Asset
|
|
||||||
run: |
|
|
||||||
echo "Release created successfully!"
|
|
||||||
echo "Archive: adguard-control-hub-${{ steps.version.outputs.VERSION }}.zip"
|
|
||||||
echo "Tag: ${{ steps.version.outputs.TAG }}"
|
|
||||||
|
6
.gitignore
vendored
6
.gitignore
vendored
@@ -2,17 +2,11 @@
|
|||||||
venv/
|
venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
*.pyo
|
|
||||||
*.pyd
|
|
||||||
.Python
|
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
.coverage
|
.coverage
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
dist/
|
|
||||||
build/
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
*.log
|
*.log
|
||||||
.env
|
|
3
LICENSE
3
LICENSE
@@ -9,9 +9,6 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|||||||
copies of the Software, and to permit persons to whom the Software is
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
furnished to do so, subject to the following conditions:
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
58
README.md
58
README.md
@@ -5,49 +5,46 @@
|
|||||||
## ✨ Features
|
## ✨ Features
|
||||||
|
|
||||||
### 🎯 Smart Client Management
|
### 🎯 Smart Client Management
|
||||||
- Automatic discovery of AdGuard clients as Home Assistant entities
|
- Automatic discovery of AdGuard clients
|
||||||
- Add, update, and remove clients directly from Home Assistant
|
|
||||||
- Per-client protection controls
|
- Per-client protection controls
|
||||||
|
|
||||||
### 🛡️ Granular Service Blocking
|
|
||||||
- Per-client service blocking for YouTube, Netflix, Gaming, Social Media, etc.
|
|
||||||
- Emergency unblock for temporary internet access
|
|
||||||
- Real-time blocking statistics
|
- Real-time blocking statistics
|
||||||
|
|
||||||
|
### 🛡️ Service Blocking
|
||||||
|
- Per-client service blocking (YouTube, Netflix, Gaming, etc.)
|
||||||
|
- Emergency unblock capabilities
|
||||||
|
- Advanced automation services
|
||||||
|
|
||||||
### 🏠 Home Assistant Integration
|
### 🏠 Home Assistant Integration
|
||||||
- Rich entity support: switches, sensors, binary sensors
|
- Rich entity support: switches, sensors, binary sensors
|
||||||
- Automation-friendly services
|
- Automation-friendly services
|
||||||
- Real-time DNS and blocking statistics
|
- Real-time DNS statistics
|
||||||
|
|
||||||
## 📦 Installation
|
## 📦 Installation
|
||||||
|
|
||||||
### 🔧 Method 1: HACS (Recommended)
|
### Method 1: HACS (Recommended)
|
||||||
1. Open Home Assistant and go to **HACS > Integrations**
|
1. Open HACS > Integrations
|
||||||
2. Click menu (⋮) → **Custom repositories**
|
2. Add custom repository: `https://git.sq4ind.eu/sq4ind/adguard-control-hub`
|
||||||
3. Add repository URL: `https://git.sq4ind.eu/sq4ind/adguard-control-hub`
|
3. Install "AdGuard Control Hub"
|
||||||
4. Set category to **Integration**, click **Add**
|
4. Restart Home Assistant
|
||||||
5. Search for **AdGuard Control Hub**
|
5. Add integration via UI
|
||||||
6. Click **Install**, then restart Home Assistant
|
|
||||||
7. Go to **Settings > Devices & Services > Add Integration**
|
|
||||||
8. Search and select **AdGuard Control Hub**, enter your AdGuard Home details
|
|
||||||
|
|
||||||
### 🛠️ Method 2: Manual Installation
|
### Method 2: Manual
|
||||||
1. Download the latest release zip
|
1. Download latest release
|
||||||
2. Extract `custom_components/adguard_hub` into your Home Assistant config directory
|
2. Extract to `custom_components/adguard_hub/`
|
||||||
3. Restart Home Assistant
|
3. Restart Home Assistant
|
||||||
4. Add integration via UI
|
4. Add via Integrations UI
|
||||||
|
|
||||||
## ⚙️ Configuration
|
## ⚙️ Configuration
|
||||||
- **Host**: IP or hostname of your AdGuard Home
|
- **Host**: AdGuard Home IP/hostname
|
||||||
- **Port**: Default 3000 unless customized
|
- **Port**: Default 3000
|
||||||
- **Username & Password**: Admin credentials for AdGuard Home
|
- **Username/Password**: Admin credentials
|
||||||
- **SSL**: Enable if AdGuard Home runs HTTPS
|
- **SSL**: Enable if using HTTPS
|
||||||
|
|
||||||
## 🎬 Example Automation
|
## 🎬 Example
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
automation:
|
automation:
|
||||||
- alias: "Kids Bedtime - Block Entertainment"
|
- alias: "Kids Bedtime"
|
||||||
trigger:
|
trigger:
|
||||||
platform: time
|
platform: time
|
||||||
at: "20:00:00"
|
at: "20:00:00"
|
||||||
@@ -55,13 +52,8 @@ automation:
|
|||||||
service: adguard_hub.block_services
|
service: adguard_hub.block_services
|
||||||
data:
|
data:
|
||||||
client_name: "Kids iPad"
|
client_name: "Kids iPad"
|
||||||
services:
|
services: ["youtube", "gaming"]
|
||||||
- youtube
|
|
||||||
- netflix
|
|
||||||
- gaming
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 📄 License
|
## 📄 License
|
||||||
This project is licensed under the MIT License.
|
MIT License - Made with ❤️ for Home Assistant users!
|
||||||
|
|
||||||
Made with ❤️ for Home Assistant and AdGuard Home users!
|
|
@@ -1,8 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
AdGuard Control Hub for Home Assistant.
|
AdGuard Control Hub for Home Assistant.
|
||||||
|
|
||||||
Transform your AdGuard Home into a smart network management powerhouse with
|
Transform your AdGuard Home into a smart network management powerhouse.
|
||||||
complete client control, service blocking, and automation capabilities.
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
@@ -76,12 +75,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
|
raise ConfigEntryNotReady(f"Failed to set up platforms: {err}") from err
|
||||||
|
|
||||||
# Register services (only once, not per config entry)
|
# Register services (only once)
|
||||||
if not hass.services.has_service(DOMAIN, "block_services"):
|
if not hass.services.has_service(DOMAIN, "block_services"):
|
||||||
services = AdGuardControlHubServices(hass)
|
services = AdGuardControlHubServices(hass)
|
||||||
services.register_services()
|
services.register_services()
|
||||||
|
|
||||||
# Store services instance for cleanup
|
|
||||||
hass.data.setdefault(f"{DOMAIN}_services", services)
|
hass.data.setdefault(f"{DOMAIN}_services", services)
|
||||||
|
|
||||||
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
|
_LOGGER.info("AdGuard Control Hub setup complete for %s:%s",
|
||||||
@@ -98,13 +95,11 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
hass.data[DOMAIN].pop(entry.entry_id)
|
hass.data[DOMAIN].pop(entry.entry_id)
|
||||||
|
|
||||||
# Unregister services if this was the last entry
|
# Unregister services if this was the last entry
|
||||||
if not hass.data[DOMAIN]: # No more entries
|
if not hass.data[DOMAIN]:
|
||||||
services = hass.data.get(f"{DOMAIN}_services")
|
services = hass.data.get(f"{DOMAIN}_services")
|
||||||
if services:
|
if services:
|
||||||
services.unregister_services()
|
services.unregister_services()
|
||||||
hass.data.pop(f"{DOMAIN}_services", None)
|
hass.data.pop(f"{DOMAIN}_services", None)
|
||||||
|
|
||||||
# Also clean up the empty domain entry
|
|
||||||
hass.data.pop(DOMAIN, None)
|
hass.data.pop(DOMAIN, None)
|
||||||
|
|
||||||
return unload_ok
|
return unload_ok
|
||||||
@@ -129,7 +124,7 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
|||||||
async def _async_update_data(self) -> Dict[str, Any]:
|
async def _async_update_data(self) -> Dict[str, Any]:
|
||||||
"""Fetch data from AdGuard Home."""
|
"""Fetch data from AdGuard Home."""
|
||||||
try:
|
try:
|
||||||
# Fetch all data concurrently for better performance
|
# Fetch all data concurrently
|
||||||
tasks = [
|
tasks = [
|
||||||
self.api.get_clients(),
|
self.api.get_clients(),
|
||||||
self.api.get_statistics(),
|
self.api.get_statistics(),
|
||||||
@@ -139,37 +134,25 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
|||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
clients, statistics, status = results
|
clients, statistics, status = results
|
||||||
|
|
||||||
# Handle any exceptions in individual requests
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
endpoint_names = ["clients", "statistics", "status"]
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Error fetching %s from %s:%s: %s",
|
|
||||||
endpoint_names[i],
|
|
||||||
self.api.host,
|
|
||||||
self.api.port,
|
|
||||||
result
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update stored data (use empty dict if fetch failed)
|
# Update stored data (use empty dict if fetch failed)
|
||||||
if not isinstance(clients, Exception):
|
if not isinstance(clients, Exception):
|
||||||
self._clients = {
|
self._clients = {
|
||||||
client["name"]: client
|
client["name"]: client
|
||||||
for client in clients.get("clients", [])
|
for client in clients.get("clients", [])
|
||||||
if client.get("name") # Ensure client has a name
|
if client.get("name")
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("Failed to update clients data, keeping previous data")
|
_LOGGER.warning("Failed to update clients data: %s", clients)
|
||||||
|
|
||||||
if not isinstance(statistics, Exception):
|
if not isinstance(statistics, Exception):
|
||||||
self._statistics = statistics
|
self._statistics = statistics
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("Failed to update statistics data, keeping previous data")
|
_LOGGER.warning("Failed to update statistics data: %s", statistics)
|
||||||
|
|
||||||
if not isinstance(status, Exception):
|
if not isinstance(status, Exception):
|
||||||
self._protection_status = status
|
self._protection_status = status
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning("Failed to update status data, keeping previous data")
|
_LOGGER.warning("Failed to update status data: %s", status)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"clients": self._clients,
|
"clients": self._clients,
|
||||||
@@ -196,21 +179,3 @@ class AdGuardControlHubCoordinator(DataUpdateCoordinator):
|
|||||||
def protection_status(self) -> Dict[str, Any]:
|
def protection_status(self) -> Dict[str, Any]:
|
||||||
"""Return protection status data."""
|
"""Return protection status data."""
|
||||||
return self._protection_status
|
return self._protection_status
|
||||||
|
|
||||||
def get_client(self, client_name: str) -> Dict[str, Any] | None:
|
|
||||||
"""Get a specific client by name."""
|
|
||||||
return self._clients.get(client_name)
|
|
||||||
|
|
||||||
def has_client(self, client_name: str) -> bool:
|
|
||||||
"""Check if a client exists."""
|
|
||||||
return client_name in self._clients
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_count(self) -> int:
|
|
||||||
"""Return the number of clients."""
|
|
||||||
return len(self._clients)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_protection_enabled(self) -> bool:
|
|
||||||
"""Return True if protection is enabled."""
|
|
||||||
return self._protection_status.get("protection_enabled", False)
|
|
||||||
|
@@ -10,23 +10,27 @@ from .const import API_ENDPOINTS
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Custom exceptions
|
|
||||||
class AdGuardHomeError(Exception):
|
class AdGuardHomeError(Exception):
|
||||||
"""Base exception for AdGuard Home API."""
|
"""Base exception for AdGuard Home API."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AdGuardConnectionError(AdGuardHomeError):
|
class AdGuardConnectionError(AdGuardHomeError):
|
||||||
"""Exception for connection errors."""
|
"""Exception for connection errors."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AdGuardAuthError(AdGuardHomeError):
|
class AdGuardAuthError(AdGuardHomeError):
|
||||||
"""Exception for authentication errors."""
|
"""Exception for authentication errors."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AdGuardNotFoundError(AdGuardHomeError):
|
class AdGuardNotFoundError(AdGuardHomeError):
|
||||||
"""Exception for not found errors."""
|
"""Exception for not found errors."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AdGuardHomeAPI:
|
class AdGuardHomeAPI:
|
||||||
"""API wrapper for AdGuard Home."""
|
"""API wrapper for AdGuard Home."""
|
||||||
|
|
||||||
@@ -71,7 +75,7 @@ class AdGuardHomeAPI:
|
|||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
async def _request(self, method: str, endpoint: str, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||||
"""Make an API request with comprehensive error handling."""
|
"""Make an API request."""
|
||||||
url = f"{self.base_url}{endpoint}"
|
url = f"{self.base_url}{endpoint}"
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
auth = None
|
auth = None
|
||||||
@@ -84,11 +88,8 @@ class AdGuardHomeAPI:
|
|||||||
method, url, json=data, headers=headers, auth=auth
|
method, url, json=data, headers=headers, auth=auth
|
||||||
) as response:
|
) as response:
|
||||||
|
|
||||||
# Handle different HTTP status codes
|
|
||||||
if response.status == 401:
|
if response.status == 401:
|
||||||
raise AdGuardAuthError("Authentication failed - check username/password")
|
raise AdGuardAuthError("Authentication failed")
|
||||||
elif response.status == 403:
|
|
||||||
raise AdGuardAuthError("Access forbidden - insufficient permissions")
|
|
||||||
elif response.status == 404:
|
elif response.status == 404:
|
||||||
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
raise AdGuardNotFoundError(f"Endpoint not found: {endpoint}")
|
||||||
elif response.status >= 500:
|
elif response.status >= 500:
|
||||||
@@ -96,24 +97,20 @@ class AdGuardHomeAPI:
|
|||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
# Handle empty responses
|
|
||||||
if response.status == 204 or not response.content_length:
|
if response.status == 204 or not response.content_length:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await response.json()
|
return await response.json()
|
||||||
except aiohttp.ContentTypeError:
|
except aiohttp.ContentTypeError:
|
||||||
# Handle non-JSON responses
|
|
||||||
text = await response.text()
|
text = await response.text()
|
||||||
_LOGGER.warning("Non-JSON response received: %s", text)
|
|
||||||
return {"response": text}
|
return {"response": text}
|
||||||
|
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
raise AdGuardConnectionError(f"Timeout connecting to AdGuard Home: {err}")
|
raise AdGuardConnectionError(f"Timeout: {err}")
|
||||||
except ClientError as err:
|
except ClientError as err:
|
||||||
raise AdGuardConnectionError(f"Client error: {err}")
|
raise AdGuardConnectionError(f"Client error: {err}")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Unexpected error communicating with AdGuard Home: %s", err)
|
|
||||||
raise AdGuardHomeError(f"Unexpected error: {err}")
|
raise AdGuardHomeError(f"Unexpected error: {err}")
|
||||||
|
|
||||||
async def test_connection(self) -> bool:
|
async def test_connection(self) -> bool:
|
||||||
@@ -121,8 +118,7 @@ class AdGuardHomeAPI:
|
|||||||
try:
|
try:
|
||||||
await self._request("GET", API_ENDPOINTS["status"])
|
await self._request("GET", API_ENDPOINTS["status"])
|
||||||
return True
|
return True
|
||||||
except Exception as err:
|
except Exception:
|
||||||
_LOGGER.debug("Connection test failed: %s", err)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_status(self) -> Dict[str, Any]:
|
async def get_status(self) -> Dict[str, Any]:
|
||||||
@@ -144,7 +140,6 @@ class AdGuardHomeAPI:
|
|||||||
|
|
||||||
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
async def add_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Add a new client configuration."""
|
"""Add a new client configuration."""
|
||||||
# Validate required fields
|
|
||||||
if "name" not in client_data:
|
if "name" not in client_data:
|
||||||
raise ValueError("Client name is required")
|
raise ValueError("Client name is required")
|
||||||
if "ids" not in client_data or not client_data["ids"]:
|
if "ids" not in client_data or not client_data["ids"]:
|
||||||
@@ -155,9 +150,9 @@ class AdGuardHomeAPI:
|
|||||||
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
async def update_client(self, client_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Update an existing client configuration."""
|
"""Update an existing client configuration."""
|
||||||
if "name" not in client_data:
|
if "name" not in client_data:
|
||||||
raise ValueError("Client name is required for update")
|
raise ValueError("Client name is required")
|
||||||
if "data" not in client_data:
|
if "data" not in client_data:
|
||||||
raise ValueError("Client data is required for update")
|
raise ValueError("Client data is required")
|
||||||
|
|
||||||
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
return await self._request("POST", API_ENDPOINTS["clients_update"], client_data)
|
||||||
|
|
||||||
@@ -183,15 +178,13 @@ class AdGuardHomeAPI:
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as err:
|
except Exception:
|
||||||
_LOGGER.error("Failed to get client %s: %s", client_name, err)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_client_blocked_services(
|
async def update_client_blocked_services(
|
||||||
self,
|
self,
|
||||||
client_name: str,
|
client_name: str,
|
||||||
blocked_services: list,
|
blocked_services: list,
|
||||||
schedule: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Update blocked services for a specific client."""
|
"""Update blocked services for a specific client."""
|
||||||
if not client_name:
|
if not client_name:
|
||||||
@@ -201,21 +194,11 @@ class AdGuardHomeAPI:
|
|||||||
if not client:
|
if not client:
|
||||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
||||||
|
|
||||||
# Prepare the blocked services data with proper structure
|
|
||||||
if schedule:
|
|
||||||
blocked_services_data = {
|
blocked_services_data = {
|
||||||
"ids": blocked_services,
|
"ids": blocked_services,
|
||||||
"schedule": schedule
|
"schedule": {"time_zone": "Local"}
|
||||||
}
|
|
||||||
else:
|
|
||||||
blocked_services_data = {
|
|
||||||
"ids": blocked_services,
|
|
||||||
"schedule": {
|
|
||||||
"time_zone": "Local"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Update the client with new blocked services
|
|
||||||
update_data = {
|
update_data = {
|
||||||
"name": client_name,
|
"name": client_name,
|
||||||
"data": {
|
"data": {
|
||||||
@@ -226,37 +209,6 @@ class AdGuardHomeAPI:
|
|||||||
|
|
||||||
return await self.update_client(update_data)
|
return await self.update_client(update_data)
|
||||||
|
|
||||||
async def toggle_client_service(
|
|
||||||
self, client_name: str, service_id: str, enabled: bool
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Toggle a specific service for a client."""
|
|
||||||
if not client_name or not service_id:
|
|
||||||
raise ValueError("Client name and service ID are required")
|
|
||||||
|
|
||||||
client = await self.get_client_by_name(client_name)
|
|
||||||
if not client:
|
|
||||||
raise AdGuardNotFoundError(f"Client '{client_name}' not found")
|
|
||||||
|
|
||||||
# Get current blocked services
|
|
||||||
blocked_services = client.get("blocked_services", {})
|
|
||||||
if isinstance(blocked_services, dict):
|
|
||||||
service_ids = blocked_services.get("ids", [])
|
|
||||||
else:
|
|
||||||
# Handle legacy format (direct list)
|
|
||||||
service_ids = blocked_services if blocked_services else []
|
|
||||||
|
|
||||||
# Update the service list
|
|
||||||
if enabled and service_id not in service_ids:
|
|
||||||
service_ids.append(service_id)
|
|
||||||
elif not enabled and service_id in service_ids:
|
|
||||||
service_ids.remove(service_id)
|
|
||||||
|
|
||||||
return await self.update_client_blocked_services(client_name, service_ids)
|
|
||||||
|
|
||||||
async def get_blocked_services(self) -> Dict[str, Any]:
|
|
||||||
"""Get available blocked services."""
|
|
||||||
return await self._request("GET", API_ENDPOINTS["blocked_services_all"])
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the API session if we own it."""
|
"""Close the API session if we own it."""
|
||||||
if self._own_session and self._session:
|
if self._own_session and self._session:
|
||||||
|
@@ -34,25 +34,20 @@ STEP_USER_DATA_SCHEMA = vol.Schema({
|
|||||||
|
|
||||||
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Validate the user input allows us to connect."""
|
"""Validate the user input allows us to connect."""
|
||||||
# Normalize host
|
|
||||||
host = data[CONF_HOST].strip()
|
host = data[CONF_HOST].strip()
|
||||||
if not host:
|
if not host:
|
||||||
raise InvalidHost("Host cannot be empty")
|
raise InvalidHost("Host cannot be empty")
|
||||||
|
|
||||||
# Remove protocol if provided
|
|
||||||
if host.startswith(("http://", "https://")):
|
if host.startswith(("http://", "https://")):
|
||||||
host = host.split("://", 1)[1]
|
host = host.split("://", 1)[1]
|
||||||
data[CONF_HOST] = host
|
data[CONF_HOST] = host
|
||||||
|
|
||||||
# Validate port
|
|
||||||
port = data[CONF_PORT]
|
port = data[CONF_PORT]
|
||||||
if not (1 <= port <= 65535):
|
if not (1 <= port <= 65535):
|
||||||
raise InvalidPort("Port must be between 1 and 65535")
|
raise InvalidPort("Port must be between 1 and 65535")
|
||||||
|
|
||||||
# Create session with appropriate SSL settings
|
|
||||||
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
session = async_get_clientsession(hass, data.get(CONF_VERIFY_SSL, True))
|
||||||
|
|
||||||
# Create API instance
|
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -60,48 +55,38 @@ async def validate_input(hass, data: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
password=data.get(CONF_PASSWORD),
|
password=data.get(CONF_PASSWORD),
|
||||||
ssl=data.get(CONF_SSL, False),
|
ssl=data.get(CONF_SSL, False),
|
||||||
session=session,
|
session=session,
|
||||||
timeout=10, # 10 second timeout for setup
|
timeout=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test the connection
|
|
||||||
try:
|
try:
|
||||||
if not await api.test_connection():
|
if not await api.test_connection():
|
||||||
raise CannotConnect("Failed to connect to AdGuard Home")
|
raise CannotConnect("Failed to connect to AdGuard Home")
|
||||||
|
|
||||||
# Get additional server info if possible
|
|
||||||
try:
|
try:
|
||||||
status = await api.get_status()
|
status = await api.get_status()
|
||||||
version = status.get("version", "unknown")
|
version = status.get("version", "unknown")
|
||||||
dns_port = status.get("dns_port", "N/A")
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"title": f"AdGuard Control Hub ({host})",
|
"title": f"AdGuard Control Hub ({host})",
|
||||||
"version": version,
|
"version": version,
|
||||||
"dns_port": dns_port,
|
|
||||||
"host": host,
|
"host": host,
|
||||||
}
|
}
|
||||||
except Exception as err:
|
except Exception:
|
||||||
_LOGGER.warning("Could not get server status, but connection works: %s", err)
|
|
||||||
return {
|
return {
|
||||||
"title": f"AdGuard Control Hub ({host})",
|
"title": f"AdGuard Control Hub ({host})",
|
||||||
"version": "unknown",
|
"version": "unknown",
|
||||||
"dns_port": "N/A",
|
|
||||||
"host": host,
|
"host": host,
|
||||||
}
|
}
|
||||||
|
|
||||||
except AdGuardAuthError as err:
|
except AdGuardAuthError as err:
|
||||||
_LOGGER.error("Authentication failed: %s", err)
|
|
||||||
raise InvalidAuth from err
|
raise InvalidAuth from err
|
||||||
except AdGuardConnectionError as err:
|
except AdGuardConnectionError as err:
|
||||||
_LOGGER.error("Connection failed: %s", err)
|
|
||||||
if "timeout" in str(err).lower():
|
if "timeout" in str(err).lower():
|
||||||
raise Timeout from err
|
raise Timeout from err
|
||||||
raise CannotConnect from err
|
raise CannotConnect from err
|
||||||
except asyncio.TimeoutError as err:
|
except asyncio.TimeoutError as err:
|
||||||
_LOGGER.error("Connection timeout: %s", err)
|
|
||||||
raise Timeout from err
|
raise Timeout from err
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.exception("Unexpected error during validation: %s", err)
|
|
||||||
raise CannotConnect from err
|
raise CannotConnect from err
|
||||||
|
|
||||||
|
|
||||||
@@ -121,7 +106,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
try:
|
try:
|
||||||
info = await validate_input(self.hass, user_input)
|
info = await validate_input(self.hass, user_input)
|
||||||
|
|
||||||
# Create unique ID based on host and port
|
|
||||||
unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
|
unique_id = f"{info['host']}:{user_input[CONF_PORT]}"
|
||||||
await self.async_set_unique_id(unique_id)
|
await self.async_set_unique_id(unique_id)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
@@ -142,7 +126,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
except Timeout:
|
except Timeout:
|
||||||
errors["base"] = "timeout"
|
errors["base"] = "timeout"
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.exception("Unexpected exception during config flow")
|
_LOGGER.exception("Unexpected exception")
|
||||||
errors["base"] = "unknown"
|
errors["base"] = "unknown"
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
@@ -151,48 +135,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_step_import(self, import_info: Dict[str, Any]) -> FlowResult:
|
|
||||||
"""Handle configuration import."""
|
|
||||||
return await self.async_step_user(import_info)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def async_get_options_flow(config_entry):
|
|
||||||
"""Get the options flow for this handler."""
|
|
||||||
return OptionsFlowHandler(config_entry)
|
|
||||||
|
|
||||||
|
|
||||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
|
||||||
"""Handle options flow for AdGuard Control Hub."""
|
|
||||||
|
|
||||||
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
|
||||||
"""Initialize options flow."""
|
|
||||||
self.config_entry = config_entry
|
|
||||||
|
|
||||||
async def async_step_init(
|
|
||||||
self, user_input: Optional[Dict[str, Any]] = None
|
|
||||||
) -> FlowResult:
|
|
||||||
"""Handle options flow."""
|
|
||||||
if user_input is not None:
|
|
||||||
return self.async_create_entry(title="", data=user_input)
|
|
||||||
|
|
||||||
options_schema = vol.Schema({
|
|
||||||
vol.Optional(
|
|
||||||
"scan_interval",
|
|
||||||
default=self.config_entry.options.get("scan_interval", 30),
|
|
||||||
): vol.All(vol.Coerce(int), vol.Range(min=10, max=300)),
|
|
||||||
vol.Optional(
|
|
||||||
"timeout",
|
|
||||||
default=self.config_entry.options.get("timeout", 10),
|
|
||||||
): vol.All(vol.Coerce(int), vol.Range(min=5, max=60)),
|
|
||||||
})
|
|
||||||
|
|
||||||
return self.async_show_form(
|
|
||||||
step_id="init",
|
|
||||||
data_schema=options_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Custom exceptions
|
|
||||||
class CannotConnect(Exception):
|
class CannotConnect(Exception):
|
||||||
"""Error to indicate we cannot connect."""
|
"""Error to indicate we cannot connect."""
|
||||||
|
|
||||||
|
@@ -29,48 +29,28 @@ API_ENDPOINTS: Final = {
|
|||||||
"clients_update": "/control/clients/update",
|
"clients_update": "/control/clients/update",
|
||||||
"clients_delete": "/control/clients/delete",
|
"clients_delete": "/control/clients/delete",
|
||||||
"blocked_services_all": "/control/blocked_services/all",
|
"blocked_services_all": "/control/blocked_services/all",
|
||||||
"blocked_services_get": "/control/blocked_services/get",
|
|
||||||
"blocked_services_update": "/control/blocked_services/update",
|
|
||||||
"protection": "/control/protection",
|
"protection": "/control/protection",
|
||||||
"stats": "/control/stats",
|
"stats": "/control/stats",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Available blocked services with friendly names
|
# Available blocked services
|
||||||
BLOCKED_SERVICES: Final = {
|
BLOCKED_SERVICES: Final = {
|
||||||
# Social Media
|
|
||||||
"youtube": "YouTube",
|
"youtube": "YouTube",
|
||||||
"facebook": "Facebook",
|
"facebook": "Facebook",
|
||||||
|
"netflix": "Netflix",
|
||||||
|
"gaming": "Gaming Services",
|
||||||
"instagram": "Instagram",
|
"instagram": "Instagram",
|
||||||
"tiktok": "TikTok",
|
"tiktok": "TikTok",
|
||||||
"twitter": "Twitter/X",
|
"twitter": "Twitter/X",
|
||||||
"snapchat": "Snapchat",
|
"snapchat": "Snapchat",
|
||||||
"reddit": "Reddit",
|
"reddit": "Reddit",
|
||||||
|
|
||||||
# Entertainment
|
|
||||||
"netflix": "Netflix",
|
|
||||||
"disney_plus": "Disney+",
|
"disney_plus": "Disney+",
|
||||||
"spotify": "Spotify",
|
"spotify": "Spotify",
|
||||||
"twitch": "Twitch",
|
"twitch": "Twitch",
|
||||||
|
|
||||||
# Gaming
|
|
||||||
"gaming": "Gaming Services",
|
|
||||||
"steam": "Steam",
|
"steam": "Steam",
|
||||||
"epic_games": "Epic Games",
|
|
||||||
"roblox": "Roblox",
|
|
||||||
|
|
||||||
# Shopping
|
|
||||||
"amazon": "Amazon",
|
|
||||||
"ebay": "eBay",
|
|
||||||
|
|
||||||
# Communication
|
|
||||||
"whatsapp": "WhatsApp",
|
"whatsapp": "WhatsApp",
|
||||||
"telegram": "Telegram",
|
"telegram": "Telegram",
|
||||||
"discord": "Discord",
|
"discord": "Discord",
|
||||||
|
|
||||||
# Other
|
|
||||||
"adult": "Adult Content",
|
|
||||||
"gambling": "Gambling Sites",
|
|
||||||
"torrents": "Torrent Sites",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Service attributes
|
# Service attributes
|
||||||
@@ -78,15 +58,9 @@ ATTR_CLIENT_NAME: Final = "client_name"
|
|||||||
ATTR_SERVICES: Final = "services"
|
ATTR_SERVICES: Final = "services"
|
||||||
ATTR_DURATION: Final = "duration"
|
ATTR_DURATION: Final = "duration"
|
||||||
ATTR_CLIENTS: Final = "clients"
|
ATTR_CLIENTS: Final = "clients"
|
||||||
ATTR_CLIENT_PATTERN: Final = "client_pattern"
|
|
||||||
ATTR_SETTINGS: Final = "settings"
|
|
||||||
|
|
||||||
# Icons
|
# Icons
|
||||||
ICON_HUB: Final = "mdi:router-network"
|
|
||||||
ICON_PROTECTION: Final = "mdi:shield"
|
ICON_PROTECTION: Final = "mdi:shield"
|
||||||
ICON_PROTECTION_OFF: Final = "mdi:shield-off"
|
ICON_PROTECTION_OFF: Final = "mdi:shield-off"
|
||||||
ICON_CLIENT: Final = "mdi:devices"
|
ICON_CLIENT: Final = "mdi:devices"
|
||||||
ICON_CLIENT_OFFLINE: Final = "mdi:devices-off"
|
|
||||||
ICON_BLOCKED_SERVICE: Final = "mdi:block-helper"
|
|
||||||
ICON_ALLOWED_SERVICE: Final = "mdi:check-circle"
|
|
||||||
ICON_STATISTICS: Final = "mdi:chart-line"
|
ICON_STATISTICS: Final = "mdi:chart-line"
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
"""Sensor platform for AdGuard Control Hub integration."""
|
"""Sensor platform for AdGuard Control Hub integration."""
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from homeassistant.components.sensor import SensorEntity, SensorDeviceClass, SensorStateClass
|
from homeassistant.components.sensor import SensorEntity, SensorStateClass
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import PERCENTAGE
|
from homeassistant.const import PERCENTAGE
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -100,7 +99,6 @@ class AdGuardBlockingPercentageSensor(AdGuardBaseSensor):
|
|||||||
self._attr_icon = "mdi:percent"
|
self._attr_icon = "mdi:percent"
|
||||||
self._attr_state_class = SensorStateClass.MEASUREMENT
|
self._attr_state_class = SensorStateClass.MEASUREMENT
|
||||||
self._attr_native_unit_of_measurement = PERCENTAGE
|
self._attr_native_unit_of_measurement = PERCENTAGE
|
||||||
self._attr_device_class = SensorDeviceClass.POWER_FACTOR
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def native_value(self) -> float | None:
|
def native_value(self) -> float | None:
|
||||||
|
@@ -19,7 +19,6 @@ from .const import (
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Service schemas
|
|
||||||
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
SCHEMA_BLOCK_SERVICES = vol.Schema({
|
||||||
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
vol.Required(ATTR_CLIENT_NAME): cv.string,
|
||||||
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
vol.Required(ATTR_SERVICES): vol.All(cv.ensure_list, [vol.In(BLOCKED_SERVICES.keys())]),
|
||||||
@@ -30,13 +29,6 @@ SCHEMA_EMERGENCY_UNBLOCK = vol.Schema({
|
|||||||
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
vol.Optional(ATTR_CLIENTS, default=["all"]): vol.All(cv.ensure_list, [cv.string]),
|
||||||
})
|
})
|
||||||
|
|
||||||
SERVICE_BLOCK_SERVICES = "block_services"
|
|
||||||
SERVICE_UNBLOCK_SERVICES = "unblock_services"
|
|
||||||
SERVICE_EMERGENCY_UNBLOCK = "emergency_unblock"
|
|
||||||
SERVICE_ADD_CLIENT = "add_client"
|
|
||||||
SERVICE_REMOVE_CLIENT = "remove_client"
|
|
||||||
SERVICE_BULK_UPDATE_CLIENTS = "bulk_update_clients"
|
|
||||||
|
|
||||||
|
|
||||||
class AdGuardControlHubServices:
|
class AdGuardControlHubServices:
|
||||||
"""Handle services for AdGuard Control Hub."""
|
"""Handle services for AdGuard Control Hub."""
|
||||||
@@ -44,45 +36,27 @@ class AdGuardControlHubServices:
|
|||||||
def __init__(self, hass: HomeAssistant):
|
def __init__(self, hass: HomeAssistant):
|
||||||
"""Initialize the services."""
|
"""Initialize the services."""
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._emergency_unblock_tasks: Dict[str, asyncio.Task] = {}
|
|
||||||
|
|
||||||
def register_services(self) -> None:
|
def register_services(self) -> None:
|
||||||
"""Register all services."""
|
"""Register all services."""
|
||||||
self.hass.services.register(
|
self.hass.services.register(
|
||||||
DOMAIN,
|
DOMAIN, "block_services", self.block_services, schema=SCHEMA_BLOCK_SERVICES
|
||||||
SERVICE_BLOCK_SERVICES,
|
|
||||||
self.block_services,
|
|
||||||
schema=SCHEMA_BLOCK_SERVICES,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hass.services.register(
|
self.hass.services.register(
|
||||||
DOMAIN,
|
DOMAIN, "unblock_services", self.unblock_services, schema=SCHEMA_BLOCK_SERVICES
|
||||||
SERVICE_UNBLOCK_SERVICES,
|
|
||||||
self.unblock_services,
|
|
||||||
schema=SCHEMA_BLOCK_SERVICES,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hass.services.register(
|
self.hass.services.register(
|
||||||
DOMAIN,
|
DOMAIN, "emergency_unblock", self.emergency_unblock, schema=SCHEMA_EMERGENCY_UNBLOCK
|
||||||
SERVICE_EMERGENCY_UNBLOCK,
|
|
||||||
self.emergency_unblock,
|
|
||||||
schema=SCHEMA_EMERGENCY_UNBLOCK,
|
|
||||||
)
|
)
|
||||||
|
self.hass.services.register(DOMAIN, "add_client", self.add_client)
|
||||||
# Additional services would go here
|
self.hass.services.register(DOMAIN, "remove_client", self.remove_client)
|
||||||
self.hass.services.register(DOMAIN, SERVICE_ADD_CLIENT, self.add_client)
|
self.hass.services.register(DOMAIN, "bulk_update_clients", self.bulk_update_clients)
|
||||||
self.hass.services.register(DOMAIN, SERVICE_REMOVE_CLIENT, self.remove_client)
|
|
||||||
self.hass.services.register(DOMAIN, SERVICE_BULK_UPDATE_CLIENTS, self.bulk_update_clients)
|
|
||||||
|
|
||||||
def unregister_services(self) -> None:
|
def unregister_services(self) -> None:
|
||||||
"""Unregister all services."""
|
"""Unregister all services."""
|
||||||
services = [
|
services = [
|
||||||
SERVICE_BLOCK_SERVICES,
|
"block_services", "unblock_services", "emergency_unblock",
|
||||||
SERVICE_UNBLOCK_SERVICES,
|
"add_client", "remove_client", "bulk_update_clients"
|
||||||
SERVICE_EMERGENCY_UNBLOCK,
|
|
||||||
SERVICE_ADD_CLIENT,
|
|
||||||
SERVICE_REMOVE_CLIENT,
|
|
||||||
SERVICE_BULK_UPDATE_CLIENTS,
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for service in services:
|
for service in services:
|
||||||
@@ -114,8 +88,6 @@ class AdGuardControlHubServices:
|
|||||||
client_name = call.data[ATTR_CLIENT_NAME]
|
client_name = call.data[ATTR_CLIENT_NAME]
|
||||||
services = call.data[ATTR_SERVICES]
|
services = call.data[ATTR_SERVICES]
|
||||||
|
|
||||||
_LOGGER.info("Unblocking services %s for client %s", services, client_name)
|
|
||||||
|
|
||||||
for entry_data in self.hass.data[DOMAIN].values():
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
api: AdGuardHomeAPI = entry_data["api"]
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
try:
|
try:
|
||||||
@@ -141,25 +113,22 @@ class AdGuardControlHubServices:
|
|||||||
try:
|
try:
|
||||||
if "all" in clients:
|
if "all" in clients:
|
||||||
await api.set_protection(False)
|
await api.set_protection(False)
|
||||||
task = asyncio.create_task(self._delayed_enable_protection(api, duration))
|
# Re-enable after duration
|
||||||
self._emergency_unblock_tasks[f"{api.host}:{api.port}"] = task
|
async def delayed_enable():
|
||||||
except Exception as err:
|
await asyncio.sleep(duration)
|
||||||
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
|
||||||
|
|
||||||
async def _delayed_enable_protection(self, api: AdGuardHomeAPI, delay: int) -> None:
|
|
||||||
"""Re-enable protection after delay."""
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
try:
|
try:
|
||||||
await api.set_protection(True)
|
await api.set_protection(True)
|
||||||
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
_LOGGER.info("Emergency unblock expired - protection re-enabled")
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to re-enable protection: %s", err)
|
_LOGGER.error("Failed to re-enable protection: %s", err)
|
||||||
|
|
||||||
|
asyncio.create_task(delayed_enable())
|
||||||
|
except Exception as err:
|
||||||
|
_LOGGER.error("Failed to execute emergency unblock: %s", err)
|
||||||
|
|
||||||
async def add_client(self, call: ServiceCall) -> None:
|
async def add_client(self, call: ServiceCall) -> None:
|
||||||
"""Add a new client."""
|
"""Add a new client."""
|
||||||
client_data = dict(call.data)
|
client_data = dict(call.data)
|
||||||
_LOGGER.info("Adding new client: %s", client_data.get("name"))
|
|
||||||
|
|
||||||
for entry_data in self.hass.data[DOMAIN].values():
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
api: AdGuardHomeAPI = entry_data["api"]
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
try:
|
try:
|
||||||
@@ -171,17 +140,14 @@ class AdGuardControlHubServices:
|
|||||||
async def remove_client(self, call: ServiceCall) -> None:
|
async def remove_client(self, call: ServiceCall) -> None:
|
||||||
"""Remove a client."""
|
"""Remove a client."""
|
||||||
client_name = call.data.get("name")
|
client_name = call.data.get("name")
|
||||||
_LOGGER.info("Removing client: %s", client_name)
|
|
||||||
|
|
||||||
for entry_data in self.hass.data[DOMAIN].values():
|
for entry_data in self.hass.data[DOMAIN].values():
|
||||||
api: AdGuardHomeAPI = entry_data["api"]
|
api: AdGuardHomeAPI = entry_data["api"]
|
||||||
try:
|
try:
|
||||||
await api.delete_client(client_name)
|
await api.delete_client(client_name)
|
||||||
_LOGGER.info("Successfully removed client: %s", client_name)
|
_LOGGER.info("Successfully removed client: %s", client_name)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to remove client %s: %s", client_name, err)
|
_LOGGER.error("Failed to remove client: %s", err)
|
||||||
|
|
||||||
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
async def bulk_update_clients(self, call: ServiceCall) -> None:
|
||||||
"""Update multiple clients matching a pattern."""
|
"""Bulk update clients."""
|
||||||
_LOGGER.info("Bulk update clients called")
|
_LOGGER.info("Bulk update clients called")
|
||||||
# Implementation would go here
|
|
||||||
|
@@ -15,25 +15,13 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
"cannot_connect": "Failed to connect to AdGuard Home. Please check your host, port, and credentials.",
|
"cannot_connect": "Failed to connect to AdGuard Home",
|
||||||
"invalid_auth": "Invalid username or password",
|
"invalid_auth": "Invalid username or password",
|
||||||
"timeout": "Connection timeout. Please check your network connection.",
|
"timeout": "Connection timeout",
|
||||||
"unknown": "An unexpected error occurred"
|
"unknown": "Unexpected error occurred"
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"already_configured": "AdGuard Control Hub is already configured for this host and port"
|
"already_configured": "AdGuard Control Hub is already configured"
|
||||||
}
|
|
||||||
},
|
|
||||||
"options": {
|
|
||||||
"step": {
|
|
||||||
"init": {
|
|
||||||
"title": "AdGuard Control Hub Options",
|
|
||||||
"description": "Configure advanced options",
|
|
||||||
"data": {
|
|
||||||
"scan_interval": "Update interval (seconds)",
|
|
||||||
"timeout": "Connection timeout (seconds)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -24,11 +24,9 @@ async def async_setup_entry(
|
|||||||
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
coordinator = hass.data[DOMAIN][config_entry.entry_id]["coordinator"]
|
||||||
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
api = hass.data[DOMAIN][config_entry.entry_id]["api"]
|
||||||
|
|
||||||
entities = []
|
entities = [AdGuardProtectionSwitch(coordinator, api)]
|
||||||
# Add global protection switch
|
|
||||||
entities.append(AdGuardProtectionSwitch(coordinator, api))
|
|
||||||
|
|
||||||
# Add client switches
|
# Add client switches if clients exist
|
||||||
for client_name in coordinator.clients.keys():
|
for client_name in coordinator.clients.keys():
|
||||||
entities.append(AdGuardClientSwitch(coordinator, api, client_name))
|
entities.append(AdGuardClientSwitch(coordinator, api, client_name))
|
||||||
|
|
||||||
@@ -74,7 +72,6 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
|||||||
try:
|
try:
|
||||||
await self.api.set_protection(True)
|
await self.api.set_protection(True)
|
||||||
await self.coordinator.async_request_refresh()
|
await self.coordinator.async_request_refresh()
|
||||||
_LOGGER.info("AdGuard protection enabled")
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
|
_LOGGER.error("Failed to enable AdGuard protection: %s", err)
|
||||||
raise
|
raise
|
||||||
@@ -84,7 +81,6 @@ class AdGuardProtectionSwitch(AdGuardBaseSwitch):
|
|||||||
try:
|
try:
|
||||||
await self.api.set_protection(False)
|
await self.api.set_protection(False)
|
||||||
await self.coordinator.async_request_refresh()
|
await self.coordinator.async_request_refresh()
|
||||||
_LOGGER.info("AdGuard protection disabled")
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
|
_LOGGER.error("Failed to disable AdGuard protection: %s", err)
|
||||||
raise
|
raise
|
||||||
@@ -123,7 +119,6 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
|||||||
}
|
}
|
||||||
await self.api.update_client(update_data)
|
await self.api.update_client(update_data)
|
||||||
await self.coordinator.async_request_refresh()
|
await self.coordinator.async_request_refresh()
|
||||||
_LOGGER.info("Enabled protection for client %s", self.client_name)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
|
_LOGGER.error("Failed to enable protection for %s: %s", self.client_name, err)
|
||||||
raise
|
raise
|
||||||
@@ -139,7 +134,6 @@ class AdGuardClientSwitch(AdGuardBaseSwitch):
|
|||||||
}
|
}
|
||||||
await self.api.update_client(update_data)
|
await self.api.update_client(update_data)
|
||||||
await self.coordinator.async_request_refresh()
|
await self.coordinator.async_request_refresh()
|
||||||
_LOGGER.info("Disabled protection for client %s", self.client_name)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
|
_LOGGER.error("Failed to disable protection for %s: %s", self.client_name, err)
|
||||||
raise
|
raise
|
||||||
|
@@ -2,8 +2,6 @@
|
|||||||
"name": "AdGuard Control Hub",
|
"name": "AdGuard Control Hub",
|
||||||
"content_in_root": false,
|
"content_in_root": false,
|
||||||
"filename": "adguard_hub",
|
"filename": "adguard_hub",
|
||||||
"country": ["US", "GB", "CA", "AU", "DE", "FR", "NL", "SE", "NO", "DK"],
|
|
||||||
"homeassistant": "2025.1.0",
|
"homeassistant": "2025.1.0",
|
||||||
"render_readme": true,
|
|
||||||
"iot_class": "Local Polling"
|
"iot_class": "Local Polling"
|
||||||
}
|
}
|
11
info.md
11
info.md
@@ -1,14 +1,11 @@
|
|||||||
# AdGuard Control Hub
|
# AdGuard Control Hub
|
||||||
|
|
||||||
The complete Home Assistant integration for AdGuard Home network management.
|
Complete Home Assistant integration for AdGuard Home network management.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Smart client management and discovery
|
- Smart client management
|
||||||
- Granular service blocking controls
|
- Service blocking controls
|
||||||
|
- Real-time statistics
|
||||||
- Emergency unblock capabilities
|
- Emergency unblock capabilities
|
||||||
- Real-time statistics and monitoring
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
Install via HACS or manually extract to `custom_components/adguard_hub/`
|
Install via HACS or manually extract to `custom_components/adguard_hub/`
|
||||||
|
|
||||||
Restart Home Assistant and add via Integrations UI.
|
|
@@ -1,20 +1,11 @@
|
|||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 127
|
line-length = 127
|
||||||
target-version = ['py313']
|
target-version = ['py313']
|
||||||
include = '\.pyi?$'
|
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
line_length = 127
|
line_length = 127
|
||||||
multi_line_output = 3
|
|
||||||
include_trailing_comma = true
|
|
||||||
force_grid_wrap = 0
|
|
||||||
use_parentheses = true
|
|
||||||
ensure_newline_before_comments = true
|
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
python_version = "3.13"
|
python_version = "3.13"
|
||||||
warn_return_any = true
|
|
||||||
warn_unused_configs = true
|
|
||||||
disallow_untyped_defs = true
|
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
@@ -3,11 +3,7 @@ black==24.3.0
|
|||||||
flake8==7.0.0
|
flake8==7.0.0
|
||||||
isort==5.13.2
|
isort==5.13.2
|
||||||
mypy==1.9.0
|
mypy==1.9.0
|
||||||
bandit==1.7.7
|
|
||||||
safety==3.1.0
|
|
||||||
pytest==8.1.1
|
pytest==8.1.1
|
||||||
pytest-homeassistant-custom-component==0.13.281
|
pytest-homeassistant-custom-component==0.13.281
|
||||||
pytest-cov==5.0.0
|
pytest-cov==5.0.0
|
||||||
|
|
||||||
# Home Assistant testing
|
|
||||||
homeassistant==2025.9.4
|
homeassistant==2025.9.4
|
@@ -1,12 +1,12 @@
|
|||||||
"""Test API functionality."""
|
"""Test API functionality."""
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
from custom_components.adguard_hub.api import AdGuardHomeAPI
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_session():
|
def mock_session():
|
||||||
"""Mock aiohttp session."""
|
"""Mock aiohttp session with proper async context manager."""
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
response = MagicMock()
|
response = MagicMock()
|
||||||
response.raise_for_status = MagicMock()
|
response.raise_for_status = MagicMock()
|
||||||
@@ -14,13 +14,12 @@ def mock_session():
|
|||||||
response.status = 200
|
response.status = 200
|
||||||
response.content_length = 100
|
response.content_length = 100
|
||||||
|
|
||||||
# Create async context manager for session.request
|
# Properly mock the async context manager
|
||||||
async def mock_request(*args, **kwargs):
|
context_manager = MagicMock()
|
||||||
return response
|
context_manager.__aenter__ = AsyncMock(return_value=response)
|
||||||
|
context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
session.request = MagicMock()
|
session.request = MagicMock(return_value=context_manager)
|
||||||
session.request.return_value.__aenter__ = AsyncMock(return_value=response)
|
|
||||||
session.request.return_value.__aexit__ = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
@@ -38,6 +37,7 @@ async def test_api_connection(mock_session):
|
|||||||
|
|
||||||
result = await api.test_connection()
|
result = await api.test_connection()
|
||||||
assert result is True
|
assert result is True
|
||||||
|
mock_session.request.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -51,6 +51,7 @@ async def test_api_get_status(mock_session):
|
|||||||
|
|
||||||
status = await api.get_status()
|
status = await api.get_status()
|
||||||
assert status == {"status": "ok"}
|
assert status == {"status": "ok"}
|
||||||
|
mock_session.request.assert_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -65,13 +66,12 @@ async def test_api_context_manager():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_api_error_handling():
|
async def test_api_error_handling():
|
||||||
"""Test API error handling."""
|
"""Test API error handling."""
|
||||||
from custom_components.adguard_hub.api import AdGuardConnectionError
|
|
||||||
|
|
||||||
# Test with a session that raises an exception
|
# Test with a session that raises an exception
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.request = MagicMock()
|
context_manager = MagicMock()
|
||||||
session.request.return_value.__aenter__ = AsyncMock(side_effect=Exception("Connection error"))
|
context_manager.__aenter__ = AsyncMock(side_effect=Exception("Connection error"))
|
||||||
session.request.return_value.__aexit__ = AsyncMock(return_value=None)
|
context_manager.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
session.request = MagicMock(return_value=context_manager)
|
||||||
|
|
||||||
api = AdGuardHomeAPI(
|
api = AdGuardHomeAPI(
|
||||||
host="test-host",
|
host="test-host",
|
||||||
@@ -79,5 +79,5 @@ async def test_api_error_handling():
|
|||||||
session=session
|
session=session
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception): # Should raise AdGuardHomeError
|
with pytest.raises(Exception):
|
||||||
await api.get_status()
|
await api.get_status()
|
||||||
|
@@ -90,9 +90,9 @@ async def test_setup_entry_connection_failure(hass: HomeAssistant, mock_config_e
|
|||||||
mock_api.test_connection = AsyncMock(return_value=False)
|
mock_api.test_connection = AsyncMock(return_value=False)
|
||||||
|
|
||||||
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
with patch("custom_components.adguard_hub.AdGuardHomeAPI", return_value=mock_api), \
|
||||||
patch("custom_components.adguard_hub.async_get_clientsession"), \
|
patch("custom_components.adguard_hub.async_get_clientsession"):
|
||||||
pytest.raises(Exception): # Should raise ConfigEntryNotReady
|
|
||||||
|
|
||||||
|
with pytest.raises(Exception): # Should raise ConfigEntryNotReady
|
||||||
await async_setup_entry(hass, mock_config_entry)
|
await async_setup_entry(hass, mock_config_entry)
|
||||||
|
|
||||||
|
|
||||||
@@ -154,7 +154,7 @@ def test_services_registration(hass: HomeAssistant):
|
|||||||
"""Test that services are properly registered."""
|
"""Test that services are properly registered."""
|
||||||
from custom_components.adguard_hub.services import AdGuardControlHubServices
|
from custom_components.adguard_hub.services import AdGuardControlHubServices
|
||||||
|
|
||||||
# Create services without running inside an existing event loop
|
# Create services without async context
|
||||||
services = AdGuardControlHubServices(hass)
|
services = AdGuardControlHubServices(hass)
|
||||||
services.register_services()
|
services.register_services()
|
||||||
|
|
||||||
@@ -162,9 +162,9 @@ def test_services_registration(hass: HomeAssistant):
|
|||||||
assert hass.services.has_service(DOMAIN, "block_services")
|
assert hass.services.has_service(DOMAIN, "block_services")
|
||||||
assert hass.services.has_service(DOMAIN, "unblock_services")
|
assert hass.services.has_service(DOMAIN, "unblock_services")
|
||||||
assert hass.services.has_service(DOMAIN, "emergency_unblock")
|
assert hass.services.has_service(DOMAIN, "emergency_unblock")
|
||||||
assert hass.services.has_service(DOMAIN, "bulk_update_clients")
|
|
||||||
assert hass.services.has_service(DOMAIN, "add_client")
|
assert hass.services.has_service(DOMAIN, "add_client")
|
||||||
assert hass.services.has_service(DOMAIN, "remove_client")
|
assert hass.services.has_service(DOMAIN, "remove_client")
|
||||||
|
assert hass.services.has_service(DOMAIN, "bulk_update_clients")
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
services.unregister_services()
|
services.unregister_services()
|
||||||
|
Reference in New Issue
Block a user