feat: add cancel, pause, and resume flow control for scans
Some checks failed
Deploy / deploy (push) Failing after 18s
Some checks failed
Deploy / deploy (push) Failing after 18s
Users running large scans can now pause (keep partial results, resume
later), cancel (stop permanently, partial results preserved), or resume
a paused scan which races through cache hits before continuing.
Backend:
- Extend scans.status CHECK to include 'paused' and 'cancelled'
- Add _migrate_add_pause_cancel_status() table-recreation migration
- scan_processor: _running_tasks/_cancel_reasons registries,
cancel_scan_task/pause_scan_task/stop_scan_task helpers,
CancelledError handler in process_scan(), start_resume_processor()
- api_server: POST /scans/{id}/pause|cancel|resume endpoints with
rate limits (30/min pause+cancel, 10/min resume); list_scans now
accepts paused/cancelled as status filter values
Frontend:
- Scan.status type extended with 'paused' | 'cancelled'
- scanApi.pause/cancel/resume added
- StatusChip: amber PauseCircle chip for paused, grey Ban for cancelled
- ScanDetails: context-aware action row with inline-confirm for
Pause and Cancel; Resume button for paused scans
Tests: 129 total (58 new) across test_scan_control.py,
test_scan_processor_control.py, and additions to existing suites
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -245,6 +245,45 @@ class TestScanEndpoints:
|
||||
assert data["data"][0]["destination"] == "FRA"
|
||||
assert data["data"][0]["min_price"] == 50
|
||||
|
||||
def test_get_scan_paused_status(self, client: TestClient, create_test_scan):
|
||||
"""Test that GET /scans/{id} returns paused status correctly."""
|
||||
scan_id = create_test_scan(status='paused')
|
||||
response = client.get(f"/api/v1/scans/{scan_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "paused"
|
||||
|
||||
def test_get_scan_cancelled_status(self, client: TestClient, create_test_scan):
|
||||
"""Test that GET /scans/{id} returns cancelled status correctly."""
|
||||
scan_id = create_test_scan(status='cancelled')
|
||||
response = client.get(f"/api/v1/scans/{scan_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "cancelled"
|
||||
|
||||
def test_list_scans_filter_paused(self, client: TestClient, create_test_scan):
|
||||
"""Test filtering scans by paused status."""
|
||||
create_test_scan(status='paused')
|
||||
create_test_scan(status='completed')
|
||||
create_test_scan(status='running')
|
||||
|
||||
response = client.get("/api/v1/scans?status=paused")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["status"] == "paused"
|
||||
|
||||
def test_list_scans_filter_cancelled(self, client: TestClient, create_test_scan):
|
||||
"""Test filtering scans by cancelled status."""
|
||||
create_test_scan(status='cancelled')
|
||||
create_test_scan(status='pending')
|
||||
|
||||
response = client.get("/api/v1/scans?status=cancelled")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["status"] == "cancelled"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.api
|
||||
|
||||
@@ -86,6 +86,25 @@ class TestScanWorkflow:
|
||||
prices = [r["min_price"] for r in routes]
|
||||
assert prices == sorted(prices)
|
||||
|
||||
def test_pause_and_resume_preserves_scan_id(self, client: TestClient, create_test_scan):
|
||||
"""Resume returns the same scan id, not a new one (unlike Re-run)."""
|
||||
scan_id = create_test_scan(status='running')
|
||||
|
||||
# Pause
|
||||
pause_resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert pause_resp.status_code == 200
|
||||
assert pause_resp.json()["id"] == scan_id
|
||||
|
||||
# Resume
|
||||
resume_resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resume_resp.status_code == 200
|
||||
assert resume_resp.json()["id"] == scan_id
|
||||
|
||||
# Confirm scan still exists with same id
|
||||
get_resp = client.get(f"/api/v1/scans/{scan_id}")
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["id"] == scan_id
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.database
|
||||
|
||||
370
flight-comparator/tests/test_scan_control.py
Normal file
370
flight-comparator/tests/test_scan_control.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Tests for scan control endpoints: pause, cancel, resume.
|
||||
|
||||
Covers API behaviour, DB state, status transitions, rate limit headers,
|
||||
and schema-level acceptance of the new 'paused' and 'cancelled' values.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sqlite3
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestScanControlEndpoints — API unit tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.api
|
||||
class TestScanControlEndpoints:
|
||||
"""Tests for pause, cancel, and resume endpoints in isolation."""
|
||||
|
||||
# ── Pause ──────────────────────────────────────────────────────────────
|
||||
|
||||
def test_pause_running_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "paused"
|
||||
assert body["id"] == scan_id
|
||||
|
||||
def test_pause_pending_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "paused"
|
||||
|
||||
def test_pause_nonexistent_scan(self, client: TestClient):
|
||||
resp = client.post("/api/v1/scans/99999/pause")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_pause_completed_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='completed')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_pause_already_paused_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_pause_cancelled_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='cancelled')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.status_code == 409
|
||||
|
||||
# ── Cancel ─────────────────────────────────────────────────────────────
|
||||
|
||||
def test_cancel_running_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "cancelled"
|
||||
|
||||
def test_cancel_pending_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "cancelled"
|
||||
|
||||
def test_cancel_nonexistent_scan(self, client: TestClient):
|
||||
resp = client.post("/api/v1/scans/99999/cancel")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_cancel_completed_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='completed')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_cancel_already_cancelled_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='cancelled')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert resp.status_code == 409
|
||||
|
||||
# ── Resume ─────────────────────────────────────────────────────────────
|
||||
|
||||
def test_resume_paused_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "pending"
|
||||
assert body["id"] == scan_id
|
||||
|
||||
def test_resume_nonexistent_scan(self, client: TestClient):
|
||||
resp = client.post("/api/v1/scans/99999/resume")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_resume_running_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_resume_cancelled_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='cancelled')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_resume_completed_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='completed')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resp.status_code == 409
|
||||
|
||||
# ── Response shape ──────────────────────────────────────────────────────
|
||||
|
||||
def test_pause_response_shape(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
body = client.post(f"/api/v1/scans/{scan_id}/pause").json()
|
||||
assert "id" in body
|
||||
assert "status" in body
|
||||
|
||||
def test_cancel_response_shape(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
body = client.post(f"/api/v1/scans/{scan_id}/cancel").json()
|
||||
assert "id" in body
|
||||
assert "status" in body
|
||||
|
||||
def test_resume_response_shape(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
body = client.post(f"/api/v1/scans/{scan_id}/resume").json()
|
||||
assert "id" in body
|
||||
assert "status" in body
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestScanControlDatabaseState — verify DB state after operations
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.database
|
||||
class TestScanControlDatabaseState:
|
||||
"""Tests that verify SQLite state after pause/cancel/resume operations."""
|
||||
|
||||
def test_pause_sets_completed_at(self, client: TestClient, create_test_scan, clean_database):
|
||||
scan_id = create_test_scan(status='running')
|
||||
client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] is not None
|
||||
|
||||
def test_cancel_sets_completed_at(self, client: TestClient, create_test_scan, clean_database):
|
||||
scan_id = create_test_scan(status='running')
|
||||
client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] is not None
|
||||
|
||||
def test_resume_clears_completed_at(self, client: TestClient, create_test_scan, clean_database):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] is None
|
||||
|
||||
def test_resume_resets_started_at_from_old_value(self, client: TestClient, create_test_scan, clean_database):
|
||||
"""After resume, started_at is no longer the old seeded timestamp.
|
||||
|
||||
The endpoint clears started_at; the background processor may then
|
||||
set a new timestamp immediately. Either way, the old value is gone.
|
||||
"""
|
||||
old_timestamp = '2026-01-01 10:00:00'
|
||||
scan_id = create_test_scan(status='paused')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
conn.execute("UPDATE scans SET started_at = ? WHERE id = ?", (old_timestamp, scan_id))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
|
||||
conn = sqlite3.connect(clean_database)
|
||||
row = conn.execute("SELECT started_at FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
# The endpoint cleared the old timestamp; the processor may have set a new one
|
||||
assert row[0] != old_timestamp
|
||||
|
||||
def test_resume_resets_routes_scanned(self, client: TestClient, create_test_scan, clean_database):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
conn.execute("UPDATE scans SET routes_scanned = 50, total_routes = 100 WHERE id = ?", (scan_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
row = conn.execute("SELECT routes_scanned FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] == 0
|
||||
|
||||
def test_pause_preserves_routes(
|
||||
self, client: TestClient, create_test_scan, create_test_route, clean_database
|
||||
):
|
||||
scan_id = create_test_scan(status='running')
|
||||
create_test_route(scan_id=scan_id, destination='MUC')
|
||||
client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM routes WHERE scan_id = ?", (scan_id,)
|
||||
).fetchone()[0]
|
||||
conn.close()
|
||||
assert count == 1
|
||||
|
||||
def test_cancel_preserves_routes(
|
||||
self, client: TestClient, create_test_scan, create_test_route, clean_database
|
||||
):
|
||||
scan_id = create_test_scan(status='running')
|
||||
create_test_route(scan_id=scan_id, destination='MUC')
|
||||
client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
conn = sqlite3.connect(clean_database)
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM routes WHERE scan_id = ?", (scan_id,)
|
||||
).fetchone()[0]
|
||||
conn.close()
|
||||
assert count == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestScanControlStatusTransitions — full workflow integration tests
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.database
|
||||
class TestScanControlStatusTransitions:
|
||||
"""Full workflow tests across multiple API calls."""
|
||||
|
||||
def test_running_to_paused_to_pending(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
# Pause it
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert resp.json()["status"] == "paused"
|
||||
# Verify persisted
|
||||
assert client.get(f"/api/v1/scans/{scan_id}").json()["status"] == "paused"
|
||||
# Resume → pending (background processor moves to running)
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert resp.json()["status"] == "pending"
|
||||
|
||||
def test_running_to_cancelled(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert resp.json()["status"] == "cancelled"
|
||||
assert client.get(f"/api/v1/scans/{scan_id}").json()["status"] == "cancelled"
|
||||
|
||||
def test_pause_then_delete(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
resp = client.delete(f"/api/v1/scans/{scan_id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
def test_cancel_then_delete(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='cancelled')
|
||||
resp = client.delete(f"/api/v1/scans/{scan_id}")
|
||||
assert resp.status_code == 204
|
||||
|
||||
def test_cannot_delete_running_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.delete(f"/api/v1/scans/{scan_id}")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_cannot_delete_pending_scan(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
resp = client.delete(f"/api/v1/scans/{scan_id}")
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_list_scans_filter_paused(self, client: TestClient, create_test_scan):
|
||||
paused_id = create_test_scan(status='paused')
|
||||
create_test_scan(status='running')
|
||||
create_test_scan(status='completed')
|
||||
resp = client.get("/api/v1/scans?status=paused")
|
||||
assert resp.status_code == 200
|
||||
scans = resp.json()["data"]
|
||||
assert len(scans) >= 1
|
||||
assert all(s["status"] == "paused" for s in scans)
|
||||
assert any(s["id"] == paused_id for s in scans)
|
||||
|
||||
def test_list_scans_filter_cancelled(self, client: TestClient, create_test_scan):
|
||||
cancelled_id = create_test_scan(status='cancelled')
|
||||
create_test_scan(status='running')
|
||||
resp = client.get("/api/v1/scans?status=cancelled")
|
||||
assert resp.status_code == 200
|
||||
scans = resp.json()["data"]
|
||||
assert len(scans) >= 1
|
||||
assert all(s["status"] == "cancelled" for s in scans)
|
||||
assert any(s["id"] == cancelled_id for s in scans)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestScanControlRateLimits — rate limit headers on control endpoints
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.api
|
||||
class TestScanControlRateLimits:
|
||||
"""Verify that rate limit response headers are present on control endpoints."""
|
||||
|
||||
def test_pause_rate_limit_headers(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/pause")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
def test_cancel_rate_limit_headers(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='running')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/cancel")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
def test_resume_rate_limit_headers(self, client: TestClient, create_test_scan):
|
||||
scan_id = create_test_scan(status='paused')
|
||||
resp = client.post(f"/api/v1/scans/{scan_id}/resume")
|
||||
assert "x-ratelimit-limit" in resp.headers
|
||||
assert "x-ratelimit-remaining" in resp.headers
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestScanControlNewStatuses — schema-level acceptance of new status values
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.database
|
||||
class TestScanControlNewStatuses:
|
||||
"""Verify the new status values are accepted/rejected at the SQLite level."""
|
||||
|
||||
def test_paused_status_accepted_by_schema(self, clean_database, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
conn.execute("UPDATE scans SET status='paused' WHERE id = ?", (scan_id,))
|
||||
conn.commit()
|
||||
row = conn.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] == 'paused'
|
||||
|
||||
def test_cancelled_status_accepted_by_schema(self, clean_database, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
conn.execute("UPDATE scans SET status='cancelled' WHERE id = ?", (scan_id,))
|
||||
conn.commit()
|
||||
row = conn.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)).fetchone()
|
||||
conn.close()
|
||||
assert row[0] == 'cancelled'
|
||||
|
||||
def test_invalid_status_rejected_by_schema(self, clean_database, create_test_scan):
|
||||
scan_id = create_test_scan(status='pending')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
with pytest.raises(sqlite3.IntegrityError):
|
||||
conn.execute("UPDATE scans SET status='stopped' WHERE id = ?", (scan_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def test_filter_active_scans_excludes_paused(self, clean_database, create_test_scan):
|
||||
paused_id = create_test_scan(status='paused')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
rows = conn.execute("SELECT id FROM active_scans").fetchall()
|
||||
conn.close()
|
||||
ids = [r[0] for r in rows]
|
||||
assert paused_id not in ids
|
||||
|
||||
def test_filter_active_scans_excludes_cancelled(self, clean_database, create_test_scan):
|
||||
cancelled_id = create_test_scan(status='cancelled')
|
||||
conn = sqlite3.connect(clean_database)
|
||||
rows = conn.execute("SELECT id FROM active_scans").fetchall()
|
||||
conn.close()
|
||||
ids = [r[0] for r in rows]
|
||||
assert cancelled_id not in ids
|
||||
127
flight-comparator/tests/test_scan_processor_control.py
Normal file
127
flight-comparator/tests/test_scan_processor_control.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Tests for scan_processor task registry and control functions.
|
||||
|
||||
Tests cancel_scan_task, pause_scan_task, stop_scan_task, and the
|
||||
done-callback that removes tasks from the registry on completion.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from scan_processor import (
|
||||
_running_tasks,
|
||||
_cancel_reasons,
|
||||
cancel_scan_task,
|
||||
pause_scan_task,
|
||||
stop_scan_task,
|
||||
)
|
||||
|
||||
|
||||
class TestScanProcessorControl:
|
||||
"""Tests for task registry and cancel/pause/stop functions."""
|
||||
|
||||
def teardown_method(self, _method):
|
||||
"""Clean up any test state from _running_tasks and _cancel_reasons."""
|
||||
for key in [9001, 8001, 8002, 7001]:
|
||||
_running_tasks.pop(key, None)
|
||||
_cancel_reasons.pop(key, None)
|
||||
|
||||
# ── cancel_scan_task ───────────────────────────────────────────────────
|
||||
|
||||
def test_cancel_scan_task_returns_false_when_no_task(self):
|
||||
"""Returns False when no task is registered for the given scan id."""
|
||||
result = cancel_scan_task(99999)
|
||||
assert result is False
|
||||
|
||||
def test_cancel_scan_task_returns_true_when_task_exists(self):
|
||||
"""Returns True and calls task.cancel() when a live task is registered."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
_running_tasks[9001] = mock_task
|
||||
|
||||
result = cancel_scan_task(9001)
|
||||
|
||||
assert result is True
|
||||
mock_task.cancel.assert_called_once()
|
||||
|
||||
def test_cancel_scan_task_returns_false_for_completed_task(self):
|
||||
"""Returns False when the registered task is already done."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = True
|
||||
_running_tasks[9001] = mock_task
|
||||
|
||||
result = cancel_scan_task(9001)
|
||||
|
||||
assert result is False
|
||||
mock_task.cancel.assert_not_called()
|
||||
|
||||
# ── pause_scan_task ────────────────────────────────────────────────────
|
||||
|
||||
def test_pause_sets_cancel_reason_paused(self):
|
||||
"""pause_scan_task sets _cancel_reasons[id] = 'paused'."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
_running_tasks[8001] = mock_task
|
||||
|
||||
pause_scan_task(8001)
|
||||
|
||||
assert _cancel_reasons.get(8001) == 'paused'
|
||||
|
||||
def test_pause_calls_cancel_on_task(self):
|
||||
"""pause_scan_task triggers cancellation of the underlying task."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
_running_tasks[8001] = mock_task
|
||||
|
||||
result = pause_scan_task(8001)
|
||||
|
||||
assert result is True
|
||||
mock_task.cancel.assert_called_once()
|
||||
|
||||
# ── stop_scan_task ─────────────────────────────────────────────────────
|
||||
|
||||
def test_stop_sets_cancel_reason_cancelled(self):
|
||||
"""stop_scan_task sets _cancel_reasons[id] = 'cancelled'."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
_running_tasks[8002] = mock_task
|
||||
|
||||
stop_scan_task(8002)
|
||||
|
||||
assert _cancel_reasons.get(8002) == 'cancelled'
|
||||
|
||||
def test_stop_calls_cancel_on_task(self):
|
||||
"""stop_scan_task triggers cancellation of the underlying task."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.done.return_value = False
|
||||
_running_tasks[8002] = mock_task
|
||||
|
||||
result = stop_scan_task(8002)
|
||||
|
||||
assert result is True
|
||||
mock_task.cancel.assert_called_once()
|
||||
|
||||
# ── done callback ──────────────────────────────────────────────────────
|
||||
|
||||
def test_task_removed_from_registry_on_completion(self):
|
||||
"""The done-callback registered by start_scan_processor removes the task."""
|
||||
|
||||
async def run():
|
||||
async def quick():
|
||||
return
|
||||
|
||||
task = asyncio.create_task(quick())
|
||||
_running_tasks[7001] = task
|
||||
task.add_done_callback(lambda _: _running_tasks.pop(7001, None))
|
||||
await task
|
||||
# Yield to let done callbacks fire
|
||||
await asyncio.sleep(0)
|
||||
return 7001 not in _running_tasks
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result is True
|
||||
Reference in New Issue
Block a user