Some checks failed
Deploy / deploy (push) Failing after 18s
Users running large scans can now pause a scan (partial results are kept
and the scan can be resumed later), cancel it (it stops permanently; partial
results are preserved), or resume a paused scan, which quickly re-processes
already-completed work as cache hits before continuing with the rest.
Backend:
- Extend scans.status CHECK to include 'paused' and 'cancelled'
- Add _migrate_add_pause_cancel_status() table-recreation migration
- scan_processor: _running_tasks/_cancel_reasons registries,
cancel_scan_task/pause_scan_task/stop_scan_task helpers,
CancelledError handler in process_scan(), start_resume_processor()
- api_server: POST /scans/{id}/pause|cancel|resume endpoints with
rate limits (30/min pause+cancel, 10/min resume); list_scans now
accepts paused/cancelled as status filter values
Frontend:
- Scan.status type extended with 'paused' | 'cancelled'
- scanApi.pause/cancel/resume added
- StatusChip: amber PauseCircle chip for paused, grey Ban for cancelled
- ScanDetails: context-aware action row with inline-confirm for
Pause and Cancel; Resume button for paused scans
Tests: 129 total (58 new) across test_scan_control.py,
test_scan_processor_control.py, and additions to existing suites
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
403 lines
14 KiB
Python
403 lines
14 KiB
Python
"""
|
|
Unit tests for API endpoints.
|
|
|
|
Tests all API endpoints with various scenarios including success cases,
|
|
error cases, validation, pagination, and edge cases.
|
|
"""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
class TestHealthEndpoint:
    """Behaviour of the unauthenticated /health probe."""

    def test_health_endpoint(self, client: TestClient):
        """GET /health answers 200 with the fixed status/version payload."""
        expected = {"status": "healthy", "version": "2.0.0"}

        resp = client.get("/health")

        assert resp.status_code == 200
        assert resp.json() == expected

    def test_health_no_rate_limit(self, client: TestClient):
        """GET /health is exempt from rate limiting, so no limit headers appear."""
        headers = client.get("/health").headers

        for header_name in ("x-ratelimit-limit", "x-ratelimit-remaining"):
            assert header_name not in headers
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
class TestAirportEndpoints:
    """Behaviour of the airport search endpoint (/api/v1/airports)."""

    def test_search_airports_valid(self, client: TestClient):
        """A valid IATA query yields a non-empty, paginated list of airports."""
        resp = client.get("/api/v1/airports?q=MUC")
        assert resp.status_code == 200

        body = resp.json()
        assert "data" in body
        assert "pagination" in body

        results = body["data"]
        assert isinstance(results, list)
        assert len(results) > 0

        # The top hit must look like an airport record matching the query.
        top_hit = results[0]
        assert "iata" in top_hit
        assert "name" in top_hit
        assert "MUC" in top_hit["iata"]

    def test_search_airports_pagination(self, client: TestClient):
        """page/limit are echoed in the pagination block and cap the page size."""
        resp = client.get("/api/v1/airports?q=airport&page=1&limit=5")
        assert resp.status_code == 200

        body = resp.json()
        pagination = body["pagination"]
        assert pagination["page"] == 1
        assert pagination["limit"] == 5
        assert len(body["data"]) <= 5

    def test_search_airports_invalid_query_too_short(self, client: TestClient):
        """A one-character query is rejected with a validation error."""
        resp = client.get("/api/v1/airports?q=M")

        assert resp.status_code == 422
        assert resp.json()["error"] == "validation_error"

    def test_search_airports_rate_limit_headers(self, client: TestClient):
        """Airport search responses carry the full set of rate-limit headers."""
        resp = client.get("/api/v1/airports?q=MUC")
        assert resp.status_code == 200

        for header_name in (
            "x-ratelimit-limit",
            "x-ratelimit-remaining",
            "x-ratelimit-reset",
        ):
            assert header_name in resp.headers
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
@pytest.mark.database
class TestScanEndpoints:
    """Tests for scan management endpoints.

    Covers scan creation (valid input, server-side defaults, validation
    errors), listing (empty, populated, pagination, status filters including
    ``paused`` and ``cancelled``), single-scan lookup, and per-scan routes.
    """

    def test_create_scan_valid(self, client: TestClient, sample_scan_data):
        """Test creating a scan with valid data."""
        response = client.post("/api/v1/scans", json=sample_scan_data)

        assert response.status_code == 200
        data = response.json()

        assert data["status"] == "pending"
        assert data["id"] > 0
        assert data["scan"]["origin"] == sample_scan_data["origin"]
        assert data["scan"]["country"] == sample_scan_data["country"]

    def test_create_scan_with_defaults(self, client: TestClient):
        """Test creating a scan with default dates."""
        data = {
            "origin": "MUC",
            "country": "IT",
            "window_months": 3,
        }

        response = client.post("/api/v1/scans", json=data)

        assert response.status_code == 200
        scan = response.json()["scan"]

        # Fields omitted from the request are filled in by the server.
        assert "start_date" in scan
        assert "end_date" in scan
        assert scan["seat_class"] == "economy"
        assert scan["adults"] == 1

    def test_create_scan_invalid_origin(self, client: TestClient):
        """Test creating a scan with invalid origin."""
        data = {
            "origin": "INVALID",  # Too long
            "country": "DE",
        }

        response = client.post("/api/v1/scans", json=data)

        assert response.status_code == 422
        error = response.json()
        assert error["error"] == "validation_error"

    def test_create_scan_invalid_country(self, client: TestClient):
        """Test creating a scan with invalid country."""
        data = {
            "origin": "BDS",
            "country": "DEU",  # Too long
        }

        response = client.post("/api/v1/scans", json=data)

        assert response.status_code == 422
        # Same error envelope as the invalid-origin case.
        error = response.json()
        assert error["error"] == "validation_error"

    def test_list_scans_empty(self, client: TestClient):
        """Test listing scans when database is empty."""
        response = client.get("/api/v1/scans")

        assert response.status_code == 200
        data = response.json()

        assert data["data"] == []
        assert data["pagination"]["total"] == 0

    def test_list_scans_with_data(self, client: TestClient, create_test_scan):
        """Test listing scans with data."""
        create_test_scan(origin="BDS", country="DE")
        create_test_scan(origin="MUC", country="IT")

        response = client.get("/api/v1/scans")

        assert response.status_code == 200
        data = response.json()

        assert len(data["data"]) == 2
        assert data["pagination"]["total"] == 2

    def test_list_scans_pagination(self, client: TestClient, create_test_scan):
        """Test scan list pagination."""
        for _ in range(5):
            create_test_scan(origin="BDS", country="DE")

        response = client.get("/api/v1/scans?page=1&limit=2")

        assert response.status_code == 200
        data = response.json()

        # 5 scans at 2 per page -> 3 pages, and page 1 has a successor.
        assert len(data["data"]) == 2
        assert data["pagination"]["total"] == 5
        assert data["pagination"]["pages"] == 3
        assert data["pagination"]["has_next"] is True

    def test_list_scans_filter_by_status(self, client: TestClient, create_test_scan):
        """Test filtering scans by status."""
        create_test_scan(status="pending")
        create_test_scan(status="completed")
        create_test_scan(status="pending")

        response = client.get("/api/v1/scans?status=pending")

        assert response.status_code == 200
        data = response.json()

        assert len(data["data"]) == 2
        assert all(scan["status"] == "pending" for scan in data["data"])

    def test_get_scan_by_id(self, client: TestClient, create_test_scan):
        """Test getting a specific scan by ID."""
        scan_id = create_test_scan(origin="FRA", country="ES")

        response = client.get(f"/api/v1/scans/{scan_id}")

        assert response.status_code == 200
        data = response.json()

        assert data["id"] == scan_id
        assert data["origin"] == "FRA"
        assert data["country"] == "ES"

    def test_get_scan_not_found(self, client: TestClient):
        """Test getting a non-existent scan."""
        response = client.get("/api/v1/scans/999")

        assert response.status_code == 404
        error = response.json()
        assert error["error"] == "not_found"
        assert "999" in error["message"]

    def test_get_scan_routes_empty(self, client: TestClient, create_test_scan):
        """Test getting routes for a scan with no routes."""
        scan_id = create_test_scan()

        response = client.get(f"/api/v1/scans/{scan_id}/routes")

        assert response.status_code == 200
        data = response.json()

        assert data["data"] == []
        assert data["pagination"]["total"] == 0

    def test_get_scan_routes_with_data(self, client: TestClient, create_test_scan, create_test_route):
        """Test getting routes for a scan with data, ordered cheapest first."""
        scan_id = create_test_scan()
        create_test_route(scan_id=scan_id, destination="MUC", min_price=100)
        create_test_route(scan_id=scan_id, destination="FRA", min_price=50)

        response = client.get(f"/api/v1/scans/{scan_id}/routes")

        assert response.status_code == 200
        routes = response.json()["data"]

        assert len(routes) == 2
        # Routes are inserted most-expensive first, so checking the whole
        # sequence proves the endpoint sorts by price rather than insertion.
        assert [route["destination"] for route in routes] == ["FRA", "MUC"]
        assert [route["min_price"] for route in routes] == [50, 100]

    def test_get_scan_paused_status(self, client: TestClient, create_test_scan):
        """Test that GET /scans/{id} returns paused status correctly."""
        scan_id = create_test_scan(status="paused")

        response = client.get(f"/api/v1/scans/{scan_id}")

        assert response.status_code == 200
        assert response.json()["status"] == "paused"

    def test_get_scan_cancelled_status(self, client: TestClient, create_test_scan):
        """Test that GET /scans/{id} returns cancelled status correctly."""
        scan_id = create_test_scan(status="cancelled")

        response = client.get(f"/api/v1/scans/{scan_id}")

        assert response.status_code == 200
        assert response.json()["status"] == "cancelled"

    def test_list_scans_filter_paused(self, client: TestClient, create_test_scan):
        """Test filtering scans by paused status."""
        create_test_scan(status="paused")
        create_test_scan(status="completed")
        create_test_scan(status="running")

        response = client.get("/api/v1/scans?status=paused")

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 1
        assert data["data"][0]["status"] == "paused"

    def test_list_scans_filter_cancelled(self, client: TestClient, create_test_scan):
        """Test filtering scans by cancelled status."""
        create_test_scan(status="cancelled")
        create_test_scan(status="pending")

        response = client.get("/api/v1/scans?status=cancelled")

        assert response.status_code == 200
        data = response.json()
        assert len(data["data"]) == 1
        assert data["data"][0]["status"] == "cancelled"
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
class TestLogEndpoints:
    """Tests for log viewer endpoints."""

    def test_get_logs_empty(self, client: TestClient):
        """Test that listing logs succeeds and returns the paginated envelope.

        The buffer is not guaranteed to be empty (startup may already have
        produced log lines), so only the response shape is checked.
        """
        response = client.get("/api/v1/logs")

        assert response.status_code == 200
        data = response.json()

        assert "data" in data
        assert "pagination" in data

    def test_get_logs_with_level_filter(self, client: TestClient):
        """Test filtering logs by level: every returned entry has that level."""
        response = client.get("/api/v1/logs?level=INFO")

        assert response.status_code == 200
        logs = response.json()["data"]

        # all() is vacuously True for an empty page, so no emptiness guard is
        # needed; the list check still rejects a malformed (e.g. null) payload.
        assert isinstance(logs, list)
        assert all(log["level"] == "INFO" for log in logs)

    def test_get_logs_invalid_level(self, client: TestClient):
        """Test filtering logs with invalid level."""
        response = client.get("/api/v1/logs?level=INVALID")

        assert response.status_code == 400
        error = response.json()
        assert error["error"] == "bad_request"

    def test_get_logs_search(self, client: TestClient):
        """Test searching logs by text (matched case-insensitively here)."""
        response = client.get("/api/v1/logs?search=startup")

        assert response.status_code == 200
        logs = response.json()["data"]

        # As above: all() handles the empty case, the list check guards shape.
        assert isinstance(logs, list)
        assert all("startup" in log["message"].lower() for log in logs)
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
class TestErrorHandling:
    """Shape of error responses and request-ID propagation."""

    def test_request_id_in_error(self, client: TestClient):
        """A 404 error body carries an 8-character request ID."""
        resp = client.get("/api/v1/scans/999")
        assert resp.status_code == 404

        body = resp.json()
        assert "request_id" in body
        # UUID shortened to 8 chars.
        assert len(body["request_id"]) == 8

    def test_request_id_in_headers(self, client: TestClient):
        """The scan listing response exposes the 8-character request ID header."""
        headers = client.get("/api/v1/scans").headers

        assert "x-request-id" in headers
        assert len(headers["x-request-id"]) == 8

    def test_validation_error_format(self, client: TestClient):
        """Validation failures return a structured list of per-field errors."""
        resp = client.post("/api/v1/scans", json={"origin": "TOOLONG", "country": "DE"})
        assert resp.status_code == 422

        body = resp.json()
        assert body["error"] == "validation_error"
        assert "errors" in body

        details = body["errors"]
        assert isinstance(details, list)
        assert len(details) > 0
        assert "field" in details[0]
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.api
class TestRateLimiting:
    """Tests for rate limiting."""

    def test_rate_limit_headers_present(self, client: TestClient):
        """Test that rate limit headers are present."""
        response = client.get("/api/v1/airports?q=MUC")

        assert "x-ratelimit-limit" in response.headers
        assert "x-ratelimit-remaining" in response.headers
        assert "x-ratelimit-reset" in response.headers

    def test_rate_limit_decreases(self, client: TestClient):
        """Test that rate limit remaining decreases."""
        response1 = client.get("/api/v1/airports?q=MUC")
        remaining1 = int(response1.headers["x-ratelimit-remaining"])

        response2 = client.get("/api/v1/airports?q=MUC")
        remaining2 = int(response2.headers["x-ratelimit-remaining"])

        assert remaining2 < remaining1

    def test_rate_limit_exceeded(self, client: TestClient):
        """Test rate limit exceeded response.

        Repeatedly creates scans until the server answers 429, then checks the
        error payload of that 429 response.
        """
        payload = {"origin": "BDS", "country": "DE"}
        statuses = []

        # Make requests until the limit is reached (scans endpoint has a limit
        # of 10). Stop at the first 429 rather than sending extra requests, and
        # remember every status so a failure shows what actually happened.
        for _ in range(12):
            response = client.post("/api/v1/scans", json=payload)
            statuses.append(response.status_code)
            if response.status_code == 429:
                break

        assert response.status_code == 429, f"never rate-limited; statuses={statuses}"
        error = response.json()
        assert error["error"] == "rate_limit_exceeded"
        assert "retry_after" in error
|