diff --git a/flight-comparator/api_server.py b/flight-comparator/api_server.py index 545f5ab..d6ec9f1 100644 --- a/flight-comparator/api_server.py +++ b/flight-comparator/api_server.py @@ -40,7 +40,7 @@ T = TypeVar('T') # Import existing modules from airports import download_and_build_airport_data from database import get_connection -from scan_processor import start_scan_processor +from scan_processor import start_scan_processor, start_resume_processor, pause_scan_task, stop_scan_task # ============================================================================= @@ -221,11 +221,13 @@ rate_limiter = RateLimiter() # Rate limit configurations (requests per minute) RATE_LIMITS = { - 'default': (200, 60), # 200 requests per 60 seconds (~3 req/sec) - 'scans': (50, 60), # 50 scan creations per minute - 'logs': (100, 60), # 100 log requests per minute - 'airports': (500, 60), # 500 airport searches per minute - 'schedules': (30, 60), # 30 schedule requests per minute + 'default': (200, 60), # 200 requests per 60 seconds (~3 req/sec) + 'scans': (50, 60), # 50 scan creations per minute + 'logs': (100, 60), # 100 log requests per minute + 'airports': (500, 60), # 500 airport searches per minute + 'schedules': (30, 60), # 30 schedule requests per minute + 'scan_control': (30, 60), # 30 pause/cancel requests per minute + 'scan_resume': (10, 60), # 10 resume requests per minute } @@ -236,7 +238,11 @@ def get_rate_limit_for_path(path: str) -> tuple[str, int, int]: Returns: tuple: (endpoint_name, limit, window) """ - if '/scans' in path and path.count('/') == 3: # POST /api/v1/scans + if '/scans' in path and (path.endswith('/pause') or path.endswith('/cancel')): + return 'scan_control', *RATE_LIMITS['scan_control'] + elif '/scans' in path and path.endswith('/resume'): + return 'scan_resume', *RATE_LIMITS['scan_resume'] + elif '/scans' in path and path.count('/') == 3: # POST /api/v1/scans return 'scans', *RATE_LIMITS['scans'] elif '/logs' in path: return 'logs', *RATE_LIMITS['logs'] @@ -930,6 +936,8 @@ class Scan(BaseModel): seat_class: str = Field(..., description="Seat class") adults: int = Field(..., ge=1, le=9, description="Number of adults") scheduled_scan_id: Optional[int] = Field(None, description="ID of the schedule that created this scan") + started_at: Optional[str] = Field(None, description="ISO timestamp when scan processing started") + completed_at: Optional[str] = Field(None, description="ISO timestamp when scan completed or failed") class ScanCreateResponse(BaseModel): @@ -1254,7 +1262,8 @@ async def create_scan(request: ScanRequest): SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, - seat_class, adults, scheduled_scan_id + seat_class, adults, scheduled_scan_id, + started_at, completed_at FROM scans WHERE id = ? """, (scan_id,)) @@ -1280,7 +1289,9 @@ async def create_scan(request: ScanRequest): error_message=row[11], seat_class=row[12], adults=row[13], - scheduled_scan_id=row[14] if len(row) > 14 else None + scheduled_scan_id=row[14] if len(row) > 14 else None, + started_at=row[15] if len(row) > 15 else None, + completed_at=row[16] if len(row) > 16 else None, ) logging.info(f"Scan created: ID={scan_id}, origin={scan.origin}, country={scan.country}, dates={scan.start_date} to {scan.end_date}") @@ -1330,10 +1341,10 @@ async def list_scans( where_clause = "" params = [] if status: - if status not in ['pending', 'running', 'completed', 'failed']: + if status not in ['pending', 'running', 'completed', 'failed', 'paused', 'cancelled']: raise HTTPException( status_code=400, - detail=f"Invalid status: {status}. Must be one of: pending, running, completed, failed" + detail=f"Invalid status: {status}. Must be one of: pending, running, completed, failed, paused, cancelled" ) where_clause = "WHERE status = ?" params.append(status) @@ -1359,7 +1370,8 @@ async def list_scans( SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, - seat_class, adults, scheduled_scan_id + seat_class, adults, scheduled_scan_id, + started_at, completed_at FROM scans {where_clause} ORDER BY created_at DESC @@ -1387,7 +1399,9 @@ async def list_scans( error_message=row[11], seat_class=row[12], adults=row[13], - scheduled_scan_id=row[14] if len(row) > 14 else None + scheduled_scan_id=row[14] if len(row) > 14 else None, + started_at=row[15] if len(row) > 15 else None, + completed_at=row[16] if len(row) > 16 else None, )) # Build pagination metadata @@ -1428,7 +1442,8 @@ async def get_scan_status(scan_id: int): SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, - seat_class, adults, scheduled_scan_id + seat_class, adults, scheduled_scan_id, + started_at, completed_at FROM scans WHERE id = ? """, (scan_id,)) @@ -1457,7 +1472,9 @@ async def get_scan_status(scan_id: int): error_message=row[11], seat_class=row[12], adults=row[13], - scheduled_scan_id=row[14] if len(row) > 14 else None + scheduled_scan_id=row[14] if len(row) > 14 else None, + started_at=row[15] if len(row) > 15 else None, + completed_at=row[16] if len(row) > 16 else None, ) except HTTPException: @@ -1507,6 +1524,155 @@ async def delete_scan(scan_id: int): raise HTTPException(status_code=500, detail=f"Failed to delete scan: {str(e)}") +@router_v1.post("/scans/{scan_id}/pause") +async def pause_scan(scan_id: int): + """ + Pause a running or pending scan. + + Stops the background task and marks the scan as 'paused'. + The scan can be resumed later via POST /scans/{id}/resume. + Returns 409 if the scan is not in a pauseable state (not pending/running). + """ + try: + conn = get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)) + row = cursor.fetchone() + + if not row: + conn.close() + raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}") + + if row[0] not in ('pending', 'running'): + conn.close() + raise HTTPException( + status_code=409, + detail=f"Cannot pause a scan with status '{row[0]}'. Only pending or running scans can be paused." + ) + + cursor.execute(""" + UPDATE scans + SET status = 'paused', + completed_at = CURRENT_TIMESTAMP, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (scan_id,)) + conn.commit() + conn.close() + + pause_scan_task(scan_id) + logging.info(f"Scan {scan_id} paused") + + return {"id": scan_id, "status": "paused"} + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to pause scan: {str(e)}") + + +@router_v1.post("/scans/{scan_id}/cancel") +async def cancel_scan(scan_id: int): + """ + Cancel a running or pending scan permanently. + + Stops the background task and marks the scan as 'cancelled'. + Partial results are preserved. Use Re-run to start a new scan. + Returns 409 if the scan is not in a cancellable state. + """ + try: + conn = get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)) + row = cursor.fetchone() + + if not row: + conn.close() + raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}") + + if row[0] not in ('pending', 'running'): + conn.close() + raise HTTPException( + status_code=409, + detail=f"Cannot cancel a scan with status '{row[0]}'. Only pending or running scans can be cancelled." + ) + + cursor.execute(""" + UPDATE scans + SET status = 'cancelled', + completed_at = CURRENT_TIMESTAMP, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (scan_id,)) + conn.commit() + conn.close() + + stop_scan_task(scan_id) + logging.info(f"Scan {scan_id} cancelled") + + return {"id": scan_id, "status": "cancelled"} + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to cancel scan: {str(e)}") + + +@router_v1.post("/scans/{scan_id}/resume") +async def resume_scan(scan_id: int): + """ + Resume a paused scan. + + Resets progress counters and restarts the background worker. + Already-queried routes are instant cache hits so progress races quickly + through them before settling on uncompleted routes. + Returns 409 if the scan is not paused. + """ + try: + conn = get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)) + row = cursor.fetchone() + + if not row: + conn.close() + raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}") + + if row[0] != 'paused': + conn.close() + raise HTTPException( + status_code=409, + detail=f"Cannot resume a scan with status '{row[0]}'. Only paused scans can be resumed." + ) + + # Reset counters so the progress bar starts fresh; the processor will race + # through cache hits before slowing on uncompleted routes. + cursor.execute(""" + UPDATE scans + SET status = 'pending', + routes_scanned = 0, + started_at = NULL, + completed_at = NULL, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (scan_id,)) + conn.commit() + conn.close() + + start_resume_processor(scan_id) + logging.info(f"Scan {scan_id} resumed") + + return {"id": scan_id, "status": "pending"} + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to resume scan: {str(e)}") + + @router_v1.get("/scans/{scan_id}/routes", response_model=PaginatedResponse[Route]) async def get_scan_routes( scan_id: int, diff --git a/flight-comparator/database/init_db.py b/flight-comparator/database/init_db.py index d992820..67df59a 100644 --- a/flight-comparator/database/init_db.py +++ b/flight-comparator/database/init_db.py @@ -199,6 +199,108 @@ def _migrate_add_scheduled_scan_id_to_scans(conn, verbose=True): print(" βœ… Migration complete: scheduled_scan_id column added to scans") +def _migrate_add_timing_columns_to_scans(conn, verbose=True): + """ + Migration: add started_at and completed_at columns to the scans table. + + started_at β€” set when status transitions to 'running' + completed_at β€” set when status transitions to 'completed' or 'failed' + Both are nullable so existing rows are unaffected. + """ + cursor = conn.execute("PRAGMA table_info(scans)") + columns = [row[1] for row in cursor.fetchall()] + if not columns: + return # Fresh install: scans table doesn't exist yet β€” schema will create the columns + if 'started_at' in columns and 'completed_at' in columns: + return # Already migrated + + if verbose: + print(" πŸ”„ Migrating scans table: adding started_at and completed_at columns...") + + if 'started_at' not in columns: + conn.execute("ALTER TABLE scans ADD COLUMN started_at TIMESTAMP") + if 'completed_at' not in columns: + conn.execute("ALTER TABLE scans ADD COLUMN completed_at TIMESTAMP") + conn.commit() + + if verbose: + print(" βœ… Migration complete: started_at and completed_at columns added to scans") + + +def _migrate_add_pause_cancel_status(conn, verbose=True): + """ + Migration: Extend status CHECK constraint to include 'paused' and 'cancelled'. + + Needed for cancel/pause/resume scan flow control feature. + Uses the same table-recreation pattern as _migrate_relax_country_constraint + because SQLite doesn't support modifying CHECK constraints in-place. + """ + cursor = conn.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='scans'" + ) + row = cursor.fetchone() + if not row or 'paused' in row[0]: + return # Table doesn't exist yet (fresh install) or already migrated + + if verbose: + print(" πŸ”„ Migrating scans table: adding 'paused' and 'cancelled' status values...") + + # SQLite doesn't support ALTER TABLE MODIFY COLUMN, so recreate the table. + # Use PRAGMA foreign_keys = OFF to avoid FK errors during the swap. + conn.execute("PRAGMA foreign_keys = OFF") + # Drop triggers that reference scans (they are recreated by executescript below). + conn.execute("DROP TRIGGER IF EXISTS update_scans_timestamp") + conn.execute("DROP TRIGGER IF EXISTS update_scan_flight_count_insert") + conn.execute("DROP TRIGGER IF EXISTS update_scan_flight_count_update") + conn.execute("DROP TRIGGER IF EXISTS update_scan_flight_count_delete") + conn.execute(""" + CREATE TABLE scans_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + origin TEXT NOT NULL CHECK(length(origin) = 3), + country TEXT NOT NULL CHECK(length(country) >= 2), + start_date TEXT NOT NULL, + end_date TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at TIMESTAMP, + completed_at TIMESTAMP, + status TEXT NOT NULL DEFAULT 'pending' + CHECK(status IN ('pending', 'running', 'completed', 'failed', 'cancelled', 'paused')), + total_routes INTEGER NOT NULL DEFAULT 0 CHECK(total_routes >= 0), + routes_scanned INTEGER NOT NULL DEFAULT 0 CHECK(routes_scanned >= 0), + total_flights INTEGER NOT NULL DEFAULT 0 CHECK(total_flights >= 0), + error_message TEXT, + seat_class TEXT DEFAULT 'economy', + adults INTEGER DEFAULT 1 CHECK(adults > 0 AND adults <= 9), + scheduled_scan_id INTEGER, + CHECK(end_date >= start_date), + CHECK(routes_scanned <= total_routes OR total_routes = 0) + ) + """) + # Use named columns to handle different column orderings (ALTER TABLE vs fresh schema). + conn.execute(""" + INSERT INTO scans_new ( + id, origin, country, start_date, end_date, + created_at, updated_at, started_at, completed_at, + status, total_routes, routes_scanned, total_flights, + error_message, seat_class, adults, scheduled_scan_id + ) + SELECT + id, origin, country, start_date, end_date, + created_at, updated_at, started_at, completed_at, + status, total_routes, routes_scanned, total_flights, + error_message, seat_class, adults, scheduled_scan_id + FROM scans + """) + conn.execute("DROP TABLE scans") + conn.execute("ALTER TABLE scans_new RENAME TO scans") + conn.execute("PRAGMA foreign_keys = ON") + conn.commit() + + if verbose: + print(" βœ… Migration complete: status now accepts 'paused' and 'cancelled'") + + def initialize_database(db_path=None, verbose=True): """ Initialize or migrate the database. @@ -245,6 +347,8 @@ def initialize_database(db_path=None, verbose=True): _migrate_relax_country_constraint(conn, verbose) _migrate_add_routes_unique_index(conn, verbose) _migrate_add_scheduled_scan_id_to_scans(conn, verbose) + _migrate_add_timing_columns_to_scans(conn, verbose) + _migrate_add_pause_cancel_status(conn, verbose) # Load and execute schema schema_sql = load_schema() diff --git a/flight-comparator/database/schema.sql b/flight-comparator/database/schema.sql index c39a2a3..124120a 100644 --- a/flight-comparator/database/schema.sql +++ b/flight-comparator/database/schema.sql @@ -28,10 +28,12 @@ CREATE TABLE IF NOT EXISTS scans ( -- Timestamps (auto-managed) created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + started_at TIMESTAMP, -- Set when status transitions to 'running' + completed_at TIMESTAMP, -- Set when status transitions to 'completed' or 'failed' -- Scan status (enforced enum via CHECK) status TEXT NOT NULL DEFAULT 'pending' - CHECK(status IN ('pending', 'running', 'completed', 'failed')), + CHECK(status IN ('pending', 'running', 'completed', 'failed', 'cancelled', 'paused')), -- Progress tracking total_routes INTEGER NOT NULL DEFAULT 0 CHECK(total_routes >= 0), diff --git a/flight-comparator/frontend/src/api.ts b/flight-comparator/frontend/src/api.ts index cbb8c5f..6993a18 100644 --- a/flight-comparator/frontend/src/api.ts +++ b/flight-comparator/frontend/src/api.ts @@ -14,7 +14,7 @@ export interface Scan { country: string; start_date: string; end_date: string; - status: 'pending' | 'running' | 'completed' | 'failed'; + status: 'pending' | 'running' | 'completed' | 'failed' | 'paused' | 'cancelled'; created_at: string; updated_at: string; total_routes: number; @@ -24,6 +24,8 @@ export interface Scan { seat_class: string; adults: number; scheduled_scan_id?: number; + started_at?: string; // ISO-8601 UTC β€” set when status transitions to 'running' + completed_at?: string; // ISO-8601 UTC β€” set when status transitions to 'completed' or 'failed' } export interface Schedule { @@ -160,6 +162,10 @@ export const scanApi = { }, delete: (id: number) => api.delete(`/scans/${id}`), + + pause: (id: number) => api.post(`/scans/${id}/pause`), + cancel: (id: number) => api.post(`/scans/${id}/cancel`), + resume: (id: number) => api.post(`/scans/${id}/resume`), }; export const airportApi = { diff --git a/flight-comparator/frontend/src/components/StatusChip.tsx b/flight-comparator/frontend/src/components/StatusChip.tsx index d1c6f8f..8a514b1 100644 --- a/flight-comparator/frontend/src/components/StatusChip.tsx +++ b/flight-comparator/frontend/src/components/StatusChip.tsx @@ -1,8 +1,8 @@ -import { CheckCircle2, Loader2, Clock, XCircle } from 'lucide-react'; +import { CheckCircle2, Loader2, Clock, XCircle, PauseCircle, Ban } from 'lucide-react'; import type { LucideIcon } from 'lucide-react'; import { cn } from '../lib/utils'; -export type ScanStatus = 'completed' | 'running' | 'pending' | 'failed'; +export type ScanStatus = 'completed' | 'running' | 'pending' | 'failed' | 'paused' | 'cancelled'; interface StatusConfig { icon: LucideIcon; @@ -38,6 +38,18 @@ const CONFIGS: Record = { chipClass: 'bg-[#FDECEA] text-[#A50E0E] border border-[#F5C6C6]', iconClass: 'text-[#A50E0E]', }, + paused: { + icon: PauseCircle, + label: 'paused', + chipClass: 'bg-[#FEF7E0] text-[#7A5200] border border-[#F9D659]', + iconClass: 'text-[#7A5200]', + }, + cancelled: { + icon: Ban, + label: 'cancelled', + chipClass: 'bg-[#F3F3F3] text-[#5F6368] border border-[#DADCE0]', + iconClass: 'text-[#5F6368]', + }, }; interface StatusChipProps { diff --git a/flight-comparator/frontend/src/pages/ScanDetails.tsx b/flight-comparator/frontend/src/pages/ScanDetails.tsx index c2f26b0..e2927f9 100644 --- a/flight-comparator/frontend/src/pages/ScanDetails.tsx +++ b/flight-comparator/frontend/src/pages/ScanDetails.tsx @@ -8,6 +8,7 @@ import { Users, Armchair, Clock, + Timer, ChevronRight, ChevronUp, ChevronDown, @@ -17,6 +18,9 @@ import { RotateCcw, Trash2, Info, + Pause, + Play, + X, } from 'lucide-react'; import { scanApi } from '../api'; import type { Scan, Route, Flight } from '../api'; @@ -25,6 +29,8 @@ import type { ScanStatus } from '../components/StatusChip'; import StatCard from '../components/StatCard'; import EmptyState from '../components/EmptyState'; import { SkeletonStatCard, SkeletonTableRow } from '../components/SkeletonCard'; +import ScanTimer, { formatDuration } from '../components/ScanTimer'; +import { useScanTimer } from '../hooks/useScanTimer'; import { cn } from '../lib/utils'; const formatPrice = (price?: number) => @@ -52,6 +58,13 @@ export default function ScanDetails() { const [rerunning, setRerunning] = useState(false); const [confirmDelete, setConfirmDelete] = useState(false); const [deleting, setDeleting] = useState(false); + const [confirmPause, setConfirmPause] = useState(false); + const [confirmCancel, setConfirmCancel] = useState(false); + const [stopping, setStopping] = useState(false); + const [resuming, setResuming] = useState(false); + + // Must be called unconditionally before any early returns (Rules of Hooks) + const timer = useScanTimer(scan); useEffect(() => { if (id) loadScanDetails(); @@ -156,6 +169,47 @@ export default function ScanDetails() { } }; + const handlePause = async () => { + if (!scan) return; + setStopping(true); + try { + await scanApi.pause(scan.id); + await loadScanDetails(); + } catch { + // fall through + } finally { + setStopping(false); + setConfirmPause(false); + } + }; + + const handleCancel = async () => { + if (!scan) return; + setStopping(true); + try { + await scanApi.cancel(scan.id); + await loadScanDetails(); + } catch { + // fall through + } finally { + setStopping(false); + setConfirmCancel(false); + } + }; + + const handleResume = async () => { + if (!scan) return; + setResuming(true); + try { + await scanApi.resume(scan.id); + await loadScanDetails(); + } catch { + // fall through + } finally { + setResuming(false); + } + }; + const SortIcon = ({ field }: { field: typeof sortField }) => { if (sortField !== field) return ; return sortDirection === 'asc' @@ -261,51 +315,168 @@ export default function ScanDetails() { )} {/* Row 4: actions */} -
- {/* Re-run */} - +
- {/* Delete β€” inline confirm */} - {confirmDelete ? ( -
- Delete this scan? + {/* ── Active (pending / running): Pause + Cancel ── */} + {isActive && ( + <> + {/* Pause β€” inline confirm */} + {confirmPause ? ( +
+ Pause this scan? + + +
+ ) : ( + + )} + + {/* Cancel β€” inline confirm */} + {confirmCancel ? ( +
+ Cancel this scan? + + +
+ ) : ( + + )} + + )} + + {/* ── Paused: Resume + Re-run + Delete ── */} + {scan.status === 'paused' && ( + <> + -
- ) : ( - + + {confirmDelete ? ( +
+ Delete this scan? + + +
+ ) : ( + + )} + + )} + + {/* ── Completed / Failed / Cancelled: Re-run + Delete ── */} + {!isActive && scan.status !== 'paused' && ( + <> + + + {confirmDelete ? ( +
+ Delete this scan? + + +
+ ) : ( + + )} + )}
{/* ── Stat cards ────────────────────────────────────────────── */} -
+
{loading ? ( [0, 1, 2].map(i => ) ) : ( @@ -313,6 +484,14 @@ export default function ScanDetails() { + {!isActive && scan.started_at && scan.completed_at && ( + + )} )}
@@ -340,6 +519,9 @@ export default function ScanDetails() {

{scan.routes_scanned} of {scan.total_routes > 0 ? scan.total_routes : '?'} routes Β· auto-refreshing every 3 s

+ {scan.status === 'running' && scan.started_at && ( + + )}
)} diff --git a/flight-comparator/scan_processor.py b/flight-comparator/scan_processor.py index f7e77c5..65092d0 100644 --- a/flight-comparator/scan_processor.py +++ b/flight-comparator/scan_processor.py @@ -21,6 +21,34 @@ from searcher_v3 import search_multiple_routes logger = logging.getLogger(__name__) +# ───────────────────────────────────────────────────────────────────────────── +# Task registry β€” tracks running asyncio tasks so they can be cancelled. +# ───────────────────────────────────────────────────────────────────────────── + +_running_tasks: dict[int, asyncio.Task] = {} +_cancel_reasons: dict[int, str] = {} + + +def cancel_scan_task(scan_id: int) -> bool: + """Cancel the background task for a scan. Returns True if a task was found and cancelled.""" + task = _running_tasks.get(scan_id) + if task and not task.done(): + task.cancel() + return True + return False + + +def pause_scan_task(scan_id: int) -> bool: + """Signal the running task to stop with status='paused'. Returns True if task was found.""" + _cancel_reasons[scan_id] = 'paused' + return cancel_scan_task(scan_id) + + +def stop_scan_task(scan_id: int) -> bool: + """Signal the running task to stop with status='cancelled'. Returns True if task was found.""" + _cancel_reasons[scan_id] = 'cancelled' + return cancel_scan_task(scan_id) + def _write_route_incremental(scan_id: int, destination: str, dest_name: str, dest_city: str, @@ -156,10 +184,10 @@ async def process_scan(scan_id: int): logger.info(f"[Scan {scan_id}] Scan details: {origin} -> {country_or_airports}, {start_date_str} to {end_date_str}") - # Update status to 'running' + # Update status to 'running' and record when processing started cursor.execute(""" UPDATE scans - SET status = 'running', updated_at = CURRENT_TIMESTAMP + SET status = 'running', started_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ? """, (scan_id,)) conn.commit() @@ -192,6 +220,7 @@ async def process_scan(scan_id: int): UPDATE scans SET status = 'failed', error_message = ?, + completed_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ? """, (f"Failed to resolve airports: {str(e)}", scan_id)) @@ -294,11 +323,12 @@ async def process_scan(scan_id: int): "SELECT COALESCE(SUM(flight_count), 0) FROM routes WHERE scan_id = ?", (scan_id,) ).fetchone()[0] - # Update scan to completed + # Update scan to completed and record finish time cursor.execute(""" UPDATE scans SET status = 'completed', total_flights = ?, + completed_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ? """, (total_flights_saved, scan_id)) @@ -306,6 +336,24 @@ async def process_scan(scan_id: int): logger.info(f"[Scan {scan_id}] βœ… Scan completed successfully! {routes_saved} routes saved with {total_flights_saved} flights") + except asyncio.CancelledError: + reason = _cancel_reasons.pop(scan_id, 'cancelled') + logger.info(f"[Scan {scan_id}] Scan {reason} by user request") + try: + if conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE scans + SET status = ?, + completed_at = CURRENT_TIMESTAMP, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? + """, (reason, scan_id)) + conn.commit() + except Exception as update_error: + logger.error(f"[Scan {scan_id}] Failed to update {reason} status: {str(update_error)}") + raise # must re-raise so asyncio marks the task as cancelled + except Exception as e: logger.error(f"[Scan {scan_id}] ❌ Scan failed with error: {str(e)}", exc_info=True) @@ -317,6 +365,7 @@ async def process_scan(scan_id: int): UPDATE scans SET status = 'failed', error_message = ?, + completed_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ? """, (str(e), scan_id)) @@ -340,5 +389,28 @@ def start_scan_processor(scan_id: int): asyncio.Task: The background task """ task = asyncio.create_task(process_scan(scan_id)) + _running_tasks[scan_id] = task + task.add_done_callback(lambda _: _running_tasks.pop(scan_id, None)) logger.info(f"[Scan {scan_id}] Background task created") return task + + +def start_resume_processor(scan_id: int): + """ + Resume processing a paused scan as a background task. + + The API endpoint has already reset status to 'pending' and cleared counters. + process_scan() will transition the status to 'running' and re-run all routes, + getting instant cache hits for already-queried routes. + + Args: + scan_id: The ID of the paused scan to resume + + Returns: + asyncio.Task: The background task + """ + task = asyncio.create_task(process_scan(scan_id)) + _running_tasks[scan_id] = task + task.add_done_callback(lambda _: _running_tasks.pop(scan_id, None)) + logger.info(f"[Scan {scan_id}] Resume task created") + return task diff --git a/flight-comparator/tests/test_api_endpoints.py b/flight-comparator/tests/test_api_endpoints.py index 3568d2f..9bb8f3f 100644 --- a/flight-comparator/tests/test_api_endpoints.py +++ b/flight-comparator/tests/test_api_endpoints.py @@ -245,6 +245,45 @@ class TestScanEndpoints: assert data["data"][0]["destination"] == "FRA" assert data["data"][0]["min_price"] == 50 + def test_get_scan_paused_status(self, client: TestClient, create_test_scan): + """Test that GET /scans/{id} returns paused status correctly.""" + scan_id = create_test_scan(status='paused') + response = client.get(f"/api/v1/scans/{scan_id}") + assert response.status_code == 200 + assert response.json()["status"] == "paused" + + def test_get_scan_cancelled_status(self, client: TestClient, create_test_scan): + """Test that GET /scans/{id} returns cancelled status correctly.""" + scan_id = create_test_scan(status='cancelled') + response = client.get(f"/api/v1/scans/{scan_id}") + assert response.status_code == 200 + assert response.json()["status"] == "cancelled" + + def test_list_scans_filter_paused(self, client: TestClient, create_test_scan): + """Test filtering scans by paused status.""" + create_test_scan(status='paused') + create_test_scan(status='completed') + create_test_scan(status='running') + + response = client.get("/api/v1/scans?status=paused") + + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 1 + assert data["data"][0]["status"] == "paused" + + def test_list_scans_filter_cancelled(self, client: TestClient, create_test_scan): + """Test filtering scans by cancelled status.""" + create_test_scan(status='cancelled') + create_test_scan(status='pending') + + response = client.get("/api/v1/scans?status=cancelled") + + assert response.status_code == 200 + data = response.json() + assert len(data["data"]) == 1 + assert data["data"][0]["status"] == "cancelled" + @pytest.mark.unit @pytest.mark.api diff --git a/flight-comparator/tests/test_integration.py b/flight-comparator/tests/test_integration.py index cdced73..aef2104 100644 --- a/flight-comparator/tests/test_integration.py +++ b/flight-comparator/tests/test_integration.py @@ -86,6 +86,25 @@ class TestScanWorkflow: prices = [r["min_price"] for r in routes] assert prices == sorted(prices) + def test_pause_and_resume_preserves_scan_id(self, client: TestClient, create_test_scan): + """Resume returns the same scan id, not a new one (unlike Re-run).""" + scan_id = create_test_scan(status='running') + + # Pause + pause_resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert pause_resp.status_code == 200 + assert pause_resp.json()["id"] == scan_id + + # Resume + resume_resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resume_resp.status_code == 200 + assert resume_resp.json()["id"] == scan_id + + # Confirm scan still exists with same id + get_resp = client.get(f"/api/v1/scans/{scan_id}") + assert get_resp.status_code == 200 + assert get_resp.json()["id"] == scan_id + @pytest.mark.integration @pytest.mark.database diff --git a/flight-comparator/tests/test_scan_control.py b/flight-comparator/tests/test_scan_control.py new file mode 100644 index 0000000..7a17517 --- /dev/null +++ b/flight-comparator/tests/test_scan_control.py @@ -0,0 +1,370 @@ +""" +Tests for scan control endpoints: pause, cancel, resume. + +Covers API behaviour, DB state, status transitions, rate limit headers, +and schema-level acceptance of the new 'paused' and 'cancelled' values. +""" + +import pytest +import sqlite3 +from fastapi.testclient import TestClient + + +# ============================================================================= +# TestScanControlEndpoints β€” API unit tests +# ============================================================================= + +@pytest.mark.unit +@pytest.mark.api +class TestScanControlEndpoints: + """Tests for pause, cancel, and resume endpoints in isolation.""" + + # ── Pause ────────────────────────────────────────────────────────────── + + def test_pause_running_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "paused" + assert body["id"] == scan_id + + def test_pause_pending_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='pending') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.status_code == 200 + assert resp.json()["status"] == "paused" + + def test_pause_nonexistent_scan(self, client: TestClient): + resp = client.post("/api/v1/scans/99999/pause") + assert resp.status_code == 404 + + def test_pause_completed_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='completed') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.status_code == 409 + + def test_pause_already_paused_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='paused') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.status_code == 409 + + def test_pause_cancelled_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='cancelled') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.status_code == 409 + + # ── Cancel ───────────────────────────────────────────────────────────── + + def test_cancel_running_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert resp.status_code == 200 + assert resp.json()["status"] == "cancelled" + + def test_cancel_pending_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='pending') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert resp.status_code == 200 + assert resp.json()["status"] == "cancelled" + + def test_cancel_nonexistent_scan(self, client: TestClient): + resp = client.post("/api/v1/scans/99999/cancel") + assert resp.status_code == 404 + + def test_cancel_completed_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='completed') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert resp.status_code == 409 + + def test_cancel_already_cancelled_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='cancelled') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert resp.status_code == 409 + + # ── Resume ───────────────────────────────────────────────────────────── + + def test_resume_paused_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='paused') + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "pending" + assert body["id"] == scan_id + + def test_resume_nonexistent_scan(self, client: TestClient): + resp = client.post("/api/v1/scans/99999/resume") + assert resp.status_code == 404 + + def test_resume_running_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resp.status_code == 409 + + def test_resume_cancelled_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='cancelled') + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resp.status_code == 409 + + def test_resume_completed_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='completed') + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resp.status_code == 409 + + # ── Response shape ────────────────────────────────────────────────────── + + def test_pause_response_shape(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + body = client.post(f"/api/v1/scans/{scan_id}/pause").json() + assert "id" in body + assert "status" in body + + def test_cancel_response_shape(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + body = client.post(f"/api/v1/scans/{scan_id}/cancel").json() + assert "id" in body + assert "status" in body + + def test_resume_response_shape(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='paused') + body = client.post(f"/api/v1/scans/{scan_id}/resume").json() + assert "id" in body + assert "status" in body + + +# ============================================================================= +# TestScanControlDatabaseState β€” verify DB state after operations +# ============================================================================= + +@pytest.mark.database +class TestScanControlDatabaseState: + """Tests that verify SQLite state after pause/cancel/resume operations.""" + + def test_pause_sets_completed_at(self, client: TestClient, create_test_scan, clean_database): + scan_id = create_test_scan(status='running') + client.post(f"/api/v1/scans/{scan_id}/pause") + conn = sqlite3.connect(clean_database) + row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] is not None + + def test_cancel_sets_completed_at(self, client: TestClient, create_test_scan, clean_database): + scan_id = create_test_scan(status='running') + client.post(f"/api/v1/scans/{scan_id}/cancel") + conn = sqlite3.connect(clean_database) + row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] is not None + + def test_resume_clears_completed_at(self, client: TestClient, create_test_scan, clean_database): + scan_id = create_test_scan(status='paused') + client.post(f"/api/v1/scans/{scan_id}/resume") + conn = sqlite3.connect(clean_database) + row = conn.execute("SELECT completed_at FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] is None + + def test_resume_resets_started_at_from_old_value(self, client: TestClient, create_test_scan, clean_database): + """After resume, started_at is no longer the old seeded timestamp. + + The endpoint clears started_at; the background processor may then + set a new timestamp immediately. Either way, the old value is gone. + """ + old_timestamp = '2026-01-01 10:00:00' + scan_id = create_test_scan(status='paused') + conn = sqlite3.connect(clean_database) + conn.execute("UPDATE scans SET started_at = ? WHERE id = ?", (old_timestamp, scan_id)) + conn.commit() + conn.close() + + client.post(f"/api/v1/scans/{scan_id}/resume") + + conn = sqlite3.connect(clean_database) + row = conn.execute("SELECT started_at FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + # The endpoint cleared the old timestamp; the processor may have set a new one + assert row[0] != old_timestamp + + def test_resume_resets_routes_scanned(self, client: TestClient, create_test_scan, clean_database): + scan_id = create_test_scan(status='paused') + conn = sqlite3.connect(clean_database) + conn.execute("UPDATE scans SET routes_scanned = 50, total_routes = 100 WHERE id = ?", (scan_id,)) + conn.commit() + conn.close() + client.post(f"/api/v1/scans/{scan_id}/resume") + conn = sqlite3.connect(clean_database) + row = conn.execute("SELECT routes_scanned FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] == 0 + + def test_pause_preserves_routes( + self, client: TestClient, create_test_scan, create_test_route, clean_database + ): + scan_id = create_test_scan(status='running') + create_test_route(scan_id=scan_id, destination='MUC') + client.post(f"/api/v1/scans/{scan_id}/pause") + conn = sqlite3.connect(clean_database) + count = conn.execute( + "SELECT COUNT(*) FROM routes WHERE scan_id = ?", (scan_id,) + ).fetchone()[0] + conn.close() + assert count == 1 + + def test_cancel_preserves_routes( + self, client: TestClient, create_test_scan, create_test_route, clean_database + ): + scan_id = create_test_scan(status='running') + create_test_route(scan_id=scan_id, destination='MUC') + client.post(f"/api/v1/scans/{scan_id}/cancel") + conn = sqlite3.connect(clean_database) + count = conn.execute( + "SELECT COUNT(*) FROM routes WHERE scan_id = ?", (scan_id,) + ).fetchone()[0] + conn.close() + assert count == 1 + + +# ============================================================================= +# TestScanControlStatusTransitions β€” full workflow integration tests +# ============================================================================= + +@pytest.mark.integration +@pytest.mark.database +class TestScanControlStatusTransitions: + """Full workflow tests across multiple API calls.""" + + def test_running_to_paused_to_pending(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + # Pause it + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert resp.json()["status"] == "paused" + # Verify persisted + assert client.get(f"/api/v1/scans/{scan_id}").json()["status"] == "paused" + # Resume β†’ pending (background processor moves to running) + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert resp.json()["status"] == "pending" + + def test_running_to_cancelled(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert resp.json()["status"] == "cancelled" + assert client.get(f"/api/v1/scans/{scan_id}").json()["status"] == "cancelled" + + def test_pause_then_delete(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='paused') + resp = client.delete(f"/api/v1/scans/{scan_id}") + assert resp.status_code == 204 + + def test_cancel_then_delete(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='cancelled') + resp = client.delete(f"/api/v1/scans/{scan_id}") + assert resp.status_code == 204 + + def test_cannot_delete_running_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.delete(f"/api/v1/scans/{scan_id}") + assert resp.status_code == 409 + + def test_cannot_delete_pending_scan(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='pending') + resp = client.delete(f"/api/v1/scans/{scan_id}") + assert resp.status_code == 409 + + def test_list_scans_filter_paused(self, client: TestClient, create_test_scan): + paused_id = create_test_scan(status='paused') + create_test_scan(status='running') + create_test_scan(status='completed') + resp = client.get("/api/v1/scans?status=paused") + assert resp.status_code == 200 + scans = resp.json()["data"] + assert len(scans) >= 1 + assert all(s["status"] == "paused" for s in scans) + assert any(s["id"] == paused_id for s in scans) + + def test_list_scans_filter_cancelled(self, client: TestClient, create_test_scan): + cancelled_id = create_test_scan(status='cancelled') + create_test_scan(status='running') + resp = client.get("/api/v1/scans?status=cancelled") + assert resp.status_code == 200 + scans = resp.json()["data"] + assert len(scans) >= 1 + assert all(s["status"] == "cancelled" for s in scans) + assert any(s["id"] == cancelled_id for s in scans) + + +# ============================================================================= +# TestScanControlRateLimits β€” rate limit headers on control endpoints +# ============================================================================= + +@pytest.mark.api +class TestScanControlRateLimits: + """Verify that rate limit response headers are present on control endpoints.""" + + def test_pause_rate_limit_headers(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/pause") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + + def test_cancel_rate_limit_headers(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='running') + resp = client.post(f"/api/v1/scans/{scan_id}/cancel") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + + def test_resume_rate_limit_headers(self, client: TestClient, create_test_scan): + scan_id = create_test_scan(status='paused') + resp = client.post(f"/api/v1/scans/{scan_id}/resume") + assert "x-ratelimit-limit" in resp.headers + assert "x-ratelimit-remaining" in resp.headers + + +# ============================================================================= +# TestScanControlNewStatuses β€” schema-level acceptance of new status values +# ============================================================================= + +@pytest.mark.database +class TestScanControlNewStatuses: + """Verify the new status values are accepted/rejected at the SQLite level.""" + + def test_paused_status_accepted_by_schema(self, clean_database, create_test_scan): + scan_id = create_test_scan(status='pending') + conn = sqlite3.connect(clean_database) + conn.execute("UPDATE scans SET status='paused' WHERE id = ?", (scan_id,)) + conn.commit() + row = conn.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] == 'paused' + + def test_cancelled_status_accepted_by_schema(self, clean_database, create_test_scan): + scan_id = create_test_scan(status='pending') + conn = sqlite3.connect(clean_database) + conn.execute("UPDATE scans SET status='cancelled' WHERE id = ?", (scan_id,)) + conn.commit() + row = conn.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)).fetchone() + conn.close() + assert row[0] == 'cancelled' + + def test_invalid_status_rejected_by_schema(self, clean_database, create_test_scan): + scan_id = create_test_scan(status='pending') + conn = sqlite3.connect(clean_database) + with pytest.raises(sqlite3.IntegrityError): + conn.execute("UPDATE scans SET status='stopped' WHERE id = ?", (scan_id,)) + conn.commit() + conn.close() + + def test_filter_active_scans_excludes_paused(self, clean_database, create_test_scan): + paused_id = create_test_scan(status='paused') + conn = sqlite3.connect(clean_database) + rows = conn.execute("SELECT id FROM active_scans").fetchall() + conn.close() + ids = [r[0] for r in rows] + assert paused_id not in ids + + def test_filter_active_scans_excludes_cancelled(self, clean_database, create_test_scan): + cancelled_id = create_test_scan(status='cancelled') + conn = sqlite3.connect(clean_database) + rows = conn.execute("SELECT id FROM active_scans").fetchall() + conn.close() + ids = [r[0] for r in rows] + assert cancelled_id not in ids diff --git a/flight-comparator/tests/test_scan_processor_control.py b/flight-comparator/tests/test_scan_processor_control.py new file mode 100644 index 0000000..982a1ca --- /dev/null +++ b/flight-comparator/tests/test_scan_processor_control.py @@ -0,0 +1,127 @@ +""" +Tests for scan_processor task registry and control functions. + +Tests cancel_scan_task, pause_scan_task, stop_scan_task, and the +done-callback that removes tasks from the registry on completion. +""" + +import asyncio +import pytest +import sys +import os +from unittest.mock import MagicMock + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from scan_processor import ( + _running_tasks, + _cancel_reasons, + cancel_scan_task, + pause_scan_task, + stop_scan_task, +) + + +class TestScanProcessorControl: + """Tests for task registry and cancel/pause/stop functions.""" + + def teardown_method(self, _method): + """Clean up any test state from _running_tasks and _cancel_reasons.""" + for key in [9001, 8001, 8002, 7001]: + _running_tasks.pop(key, None) + _cancel_reasons.pop(key, None) + + # ── cancel_scan_task ─────────────────────────────────────────────────── + + def test_cancel_scan_task_returns_false_when_no_task(self): + """Returns False when no task is registered for the given scan id.""" + result = cancel_scan_task(99999) + assert result is False + + def test_cancel_scan_task_returns_true_when_task_exists(self): + """Returns True and calls task.cancel() when a live task is registered.""" + mock_task = MagicMock() + mock_task.done.return_value = False + _running_tasks[9001] = mock_task + + result = cancel_scan_task(9001) + + assert result is True + mock_task.cancel.assert_called_once() + + def test_cancel_scan_task_returns_false_for_completed_task(self): + """Returns False when the registered task is already done.""" + mock_task = MagicMock() + mock_task.done.return_value = True + _running_tasks[9001] = mock_task + + result = cancel_scan_task(9001) + + assert result is False + mock_task.cancel.assert_not_called() + + # ── pause_scan_task ──────────────────────────────────────────────────── + + def test_pause_sets_cancel_reason_paused(self): + """pause_scan_task sets _cancel_reasons[id] = 'paused'.""" + mock_task = MagicMock() + mock_task.done.return_value = False + _running_tasks[8001] = mock_task + + pause_scan_task(8001) + + assert _cancel_reasons.get(8001) == 'paused' + + def test_pause_calls_cancel_on_task(self): + """pause_scan_task triggers cancellation of the underlying task.""" + mock_task = MagicMock() + mock_task.done.return_value = False + _running_tasks[8001] = mock_task + + result = pause_scan_task(8001) + + assert result is True + mock_task.cancel.assert_called_once() + + # ── stop_scan_task ───────────────────────────────────────────────────── + + def test_stop_sets_cancel_reason_cancelled(self): + """stop_scan_task sets _cancel_reasons[id] = 'cancelled'.""" + mock_task = MagicMock() + mock_task.done.return_value = False + _running_tasks[8002] = mock_task + + stop_scan_task(8002) + + assert _cancel_reasons.get(8002) == 'cancelled' + + def test_stop_calls_cancel_on_task(self): + """stop_scan_task triggers cancellation of the underlying task.""" + mock_task = MagicMock() + mock_task.done.return_value = False + _running_tasks[8002] = mock_task + + result = stop_scan_task(8002) + + assert result is True + mock_task.cancel.assert_called_once() + + # ── done callback ────────────────────────────────────────────────────── + + def test_task_removed_from_registry_on_completion(self): + """The done-callback registered by start_scan_processor removes the task.""" + + async def run(): + async def quick(): + return + + task = asyncio.create_task(quick()) + _running_tasks[7001] = task + task.add_done_callback(lambda _: _running_tasks.pop(7001, None)) + await task + # Yield to let done callbacks fire + await asyncio.sleep(0) + return 7001 not in _running_tasks + + result = asyncio.run(run()) + assert result is True