#!/usr/bin/env python3
"""
Flight Radar Web API Server v2.0

Provides REST API for the web frontend to:
- Search airports
- Configure scans
- Retrieve flight data
- View application logs

API Version: v1 (all endpoints under /api/v1/)

Run with: uvicorn api_server:app --reload
"""
from fastapi import FastAPI, APIRouter, HTTPException, Query, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from typing import Optional, List, Generic, TypeVar
from pydantic import BaseModel, Field, validator, ValidationError
from contextlib import asynccontextmanager
from functools import lru_cache
from datetime import datetime, date, timedelta, timezone
import json
import os
import re
import uuid
import traceback
import math
import logging
import time
from collections import deque, defaultdict
from threading import Lock

# Generic type for pagination
T = TypeVar('T')

# Import existing modules
from airports import download_and_build_airport_data
from database import get_connection
from scan_processor import start_scan_processor


# =============================================================================
# In-Memory Log Buffer
# =============================================================================

class LogBuffer:
    """Thread-safe circular buffer for storing application logs in memory."""

    def __init__(self, maxlen=1000):
        # deque(maxlen=...) silently evicts the oldest entry once full,
        # giving fixed memory use without explicit trimming.
        self.buffer = deque(maxlen=maxlen)
        self.lock = Lock()

    def add(self, log_entry: dict):
        """Add a log entry to the buffer."""
        with self.lock:
            self.buffer.append(log_entry)

    def get_all(self) -> List[dict]:
        """Get all log entries (newest first)."""
        with self.lock:
            # Copy under the lock so callers never see concurrent mutation.
            return list(reversed(self.buffer))

    def clear(self):
        """Clear all log entries."""
        with self.lock:
            self.buffer.clear()


class BufferedLogHandler(logging.Handler):
    """Custom logging handler that stores logs in memory buffer."""

    def __init__(self, log_buffer: LogBuffer):
        super().__init__()
        self.log_buffer = log_buffer

    def emit(self, record: logging.LogRecord):
        """Emit a log record to the buffer."""
        try:
            log_entry = {
                # FIX: the original converted record.created to *local* time
                # and then appended a 'Z' (UTC) designator, mislabeling every
                # timestamp. Convert to real UTC so the 'Z' is truthful; the
                # string shape ("...T...microsZ") is unchanged.
                'timestamp': datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat().replace('+00:00', 'Z'),
                'level': record.levelname,
                'message': record.getMessage(),
                'module': record.module,
                'function': record.funcName,
                'line': record.lineno,
            }
            # Add exception info if present
            if record.exc_info:
                log_entry['exception'] = self.formatter.formatException(record.exc_info) if self.formatter else str(record.exc_info)
            self.log_buffer.add(log_entry)
        except Exception:
            # Standard logging convention: never let logging itself raise.
            self.handleError(record)


# Initialize log buffer
log_buffer = LogBuffer(maxlen=1000)

# Configure logging to use buffer
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Add buffered handler
buffered_handler = BufferedLogHandler(log_buffer)
buffered_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
buffered_handler.setFormatter(formatter)
logger.addHandler(buffered_handler)


# =============================================================================
# Rate Limiting
# =============================================================================

class RateLimiter:
    """
    Sliding window rate limiter with per-IP and per-endpoint tracking.

    Uses a sliding window algorithm to track requests per IP address and endpoint.
    Each endpoint has independent rate limiting per IP.
    Automatically cleans up old entries to prevent memory leaks.
    """

    def __init__(self):
        # IP -> endpoint -> deque of request timestamps (monotonically appended,
        # trimmed from the left as they age out of the window).
        self.requests = defaultdict(lambda: defaultdict(deque))
        self.lock = Lock()
        self.last_cleanup = time.time()
        self.cleanup_interval = 60  # Clean up every 60 seconds

    def is_allowed(self, client_ip: str, endpoint: str, limit: int, window: int) -> tuple[bool, dict]:
        """
        Check if a request is allowed based on rate limit.

        Args:
            client_ip: Client IP address
            endpoint: Endpoint identifier (e.g., 'scans', 'airports')
            limit: Maximum number of requests allowed
            window: Time window in seconds

        Returns:
            tuple: (is_allowed, rate_limit_info)
                rate_limit_info contains: limit, remaining, reset_time
                (plus retry_after when the request is denied)
        """
        with self.lock:
            now = time.time()
            cutoff = now - window

            # Get request history for this IP and endpoint
            request_times = self.requests[client_ip][endpoint]

            # Remove requests outside the current window
            while request_times and request_times[0] < cutoff:
                request_times.popleft()

            # Calculate remaining requests
            current_count = len(request_times)
            remaining = max(0, limit - current_count)

            # Calculate reset time (when oldest request expires)
            if request_times:
                reset_time = int(request_times[0] + window)
            else:
                reset_time = int(now + window)

            # Check if limit exceeded
            if current_count >= limit:
                return False, {
                    'limit': limit,
                    'remaining': 0,
                    'reset': reset_time,
                    'retry_after': int(request_times[0] + window - now)
                }

            # Allow request and record it
            request_times.append(now)

            # Periodic cleanup (amortized: runs at most once per cleanup_interval)
            if now - self.last_cleanup > self.cleanup_interval:
                self._cleanup(cutoff)
                self.last_cleanup = now

            return True, {
                'limit': limit,
                'remaining': remaining - 1,  # -1 because we just added this request
                'reset': reset_time
            }

    def _cleanup(self, cutoff: float):
        """Remove old entries to prevent memory leaks.

        Caller must hold self.lock (only invoked from is_allowed).
        """
        ips_to_remove = []
        for ip, endpoints in self.requests.items():
            endpoints_to_remove = []
            for endpoint, request_times in endpoints.items():
                # Remove old requests
                while request_times and request_times[0] < cutoff:
                    request_times.popleft()
                # If no requests left, mark endpoint for removal
                if not request_times:
                    endpoints_to_remove.append(endpoint)
            # Remove endpoints with no requests (deferred so we don't mutate
            # the dict while iterating it)
            for endpoint in endpoints_to_remove:
                del endpoints[endpoint]
            # If no endpoints left, mark IP for removal
            if not endpoints:
                ips_to_remove.append(ip)
        # Remove IPs with no recent requests
        for ip in ips_to_remove:
            del self.requests[ip]


# Initialize rate limiter
rate_limiter = RateLimiter() # Rate limit configurations (requests per minute) RATE_LIMITS = { 'default': (200, 60), # 200 requests per 60 seconds (~3 req/sec) 'scans': (50, 60), # 50 scan creations per minute 'logs': (100, 60), # 100 log requests per minute 'airports': (500, 60), # 500 airport searches per minute } def get_rate_limit_for_path(path: str) -> tuple[str, int, int]: """ Get rate limit configuration for a given path. Returns: tuple: (endpoint_name, limit, window) """ if '/scans' in path and path.count('/') == 3: # POST /api/v1/scans return 'scans', *RATE_LIMITS['scans'] elif '/logs' in path: return 'logs', *RATE_LIMITS['logs'] elif '/airports' in path: return 'airports', *RATE_LIMITS['airports'] else: return 'default', *RATE_LIMITS['default'] @asynccontextmanager async def lifespan(app: FastAPI): """Initialize airport data and database on server start.""" logging.info("Flight Radar API v2.0 starting up...") # Initialize airport data try: download_and_build_airport_data() print("✅ Airport database initialized") logging.info("Airport database initialized successfully") except Exception as e: print(f"❌ Failed to initialize airport database: {e}") logging.error(f"Failed to initialize airport database: {e}") # Initialize web app database try: from database import initialize_database initialize_database(verbose=False) print("✅ Web app database initialized") logging.info("Web app database initialized successfully") except Exception as e: print(f"⚠️ Database initialization: {e}") logging.warning(f"Database initialization issue: {e}") # Cleanup stuck scans from previous server session try: conn = get_connection() cursor = conn.cursor() # Find scans stuck in 'running' state cursor.execute(""" SELECT id, origin, country, created_at FROM scans WHERE status = 'running' """) stuck_scans = cursor.fetchall() if stuck_scans: logging.warning(f"Found {len(stuck_scans)} scan(s) stuck in 'running' state from previous session") print(f"⚠️ Found {len(stuck_scans)} stuck 
scan(s), cleaning up...") # Update stuck scans to 'failed' status for scan_id, origin, country, created_at in stuck_scans: cursor.execute(""" UPDATE scans SET status = 'failed', error_message = 'Server restarted while scan was running', updated_at = CURRENT_TIMESTAMP WHERE id = ? """, (scan_id,)) logging.info(f"Cleaned up stuck scan: ID={scan_id}, origin={origin}, country={country}, created={created_at}") conn.commit() print(f"✅ Cleaned up {len(stuck_scans)} stuck scan(s)") logging.info(f"Successfully cleaned up {len(stuck_scans)} stuck scan(s)") else: logging.info("No stuck scans found - database is clean") conn.close() except Exception as e: logging.error(f"Failed to cleanup stuck scans: {e}", exc_info=True) print(f"⚠️ Scan cleanup warning: {e}") logging.info("Flight Radar API v2.0 startup complete") yield logging.info("Flight Radar API v2.0 shutting down") app = FastAPI( title="Flight Radar API", description="API for discovering and tracking direct flights", version="2.0.0", lifespan=lifespan ) # Configure CORS based on environment # Development: localhost origins # Production: specific frontend URL from environment variable ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "").split(",") if os.getenv("ALLOWED_ORIGINS") else [ "http://localhost:5173", # Vite dev server "http://localhost:3000", # Alternative dev port "http://127.0.0.1:5173", "http://127.0.0.1:3000", "http://localhost", # Docker "http://localhost:80", ] app.add_middleware( CORSMiddleware, allow_origins=ALLOWED_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Request tracking middleware @app.middleware("http") async def add_request_id(request: Request, call_next): """ Add unique request ID to each request for tracking and debugging. Request ID is included in error responses and can be used for log correlation. 
""" request_id = str(uuid.uuid4())[:8] request.state.request_id = request_id response = await call_next(request) response.headers["X-Request-ID"] = request_id return response # Rate limiting middleware @app.middleware("http") async def rate_limit_middleware(request: Request, call_next): """ Rate limiting middleware using sliding window algorithm. Limits requests per IP address based on endpoint type. Returns 429 Too Many Requests when limit is exceeded. """ # Skip rate limiting for health check if request.url.path == "/health": return await call_next(request) # Get client IP (handle proxy headers) client_ip = request.client.host if forwarded_for := request.headers.get("X-Forwarded-For"): client_ip = forwarded_for.split(",")[0].strip() # Get rate limit for this path endpoint, limit, window = get_rate_limit_for_path(request.url.path) # Check rate limit is_allowed, rate_info = rate_limiter.is_allowed(client_ip, endpoint, limit, window) if not is_allowed: # Log rate limit exceeded logging.warning(f"Rate limit exceeded for IP {client_ip} on {request.url.path}") # Return 429 Too Many Requests return JSONResponse( status_code=429, content={ 'error': 'rate_limit_exceeded', 'message': f'Rate limit exceeded. 
Maximum {limit} requests per {window} seconds.', 'limit': rate_info['limit'], 'retry_after': rate_info['retry_after'], 'timestamp': datetime.utcnow().isoformat() + 'Z', 'path': request.url.path, 'request_id': getattr(request.state, 'request_id', 'unknown') }, headers={ 'X-RateLimit-Limit': str(rate_info['limit']), 'X-RateLimit-Remaining': '0', 'X-RateLimit-Reset': str(rate_info['reset']), 'Retry-After': str(rate_info['retry_after']) } ) # Process request and add rate limit headers response = await call_next(request) response.headers["X-RateLimit-Limit"] = str(rate_info['limit']) response.headers["X-RateLimit-Remaining"] = str(rate_info['remaining']) response.headers["X-RateLimit-Reset"] = str(rate_info['reset']) return response # Create API v1 router router_v1 = APIRouter(prefix="/api/v1", tags=["v1"]) # ============================================================================= # Error Handling Middleware & Exception Handlers # ============================================================================= @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): """ Handle Pydantic validation errors with user-friendly messages. Converts technical validation errors into readable format. 
""" errors = [] for error in exc.errors(): # Extract field name from location tuple field = error['loc'][-1] if error['loc'] else 'unknown' # Get error message msg = error.get('msg', 'Validation error') # For value_error type, extract custom message from ctx if error['type'] == 'value_error' and 'ctx' in error and 'error' in error['ctx']: # This is our custom validator error msg = error['msg'].replace('Value error, ', '') errors.append({ 'field': field, 'message': msg, 'type': error['type'] }) request_id = getattr(request.state, 'request_id', 'unknown') return JSONResponse( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={ 'error': 'validation_error', 'message': 'Invalid input data', 'errors': errors, 'timestamp': datetime.utcnow().isoformat() + 'Z', 'path': request.url.path, 'request_id': request_id } ) @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """ Handle HTTP exceptions with consistent format. """ request_id = getattr(request.state, 'request_id', 'unknown') return JSONResponse( status_code=exc.status_code, content={ 'error': get_error_code(exc.status_code), 'message': exc.detail, 'timestamp': datetime.utcnow().isoformat() + 'Z', 'path': request.url.path, 'request_id': request_id } ) @app.exception_handler(Exception) async def general_exception_handler(request: Request, exc: Exception): """ Catch-all handler for unexpected errors. Logs full traceback but returns safe error to user. """ request_id = getattr(request.state, 'request_id', 'unknown') # Log full error details (in production, send to logging service) print(f"\n{'='*60}") print(f"REQUEST ID: {request_id}") print(f"Path: {request.method} {request.url.path}") print(f"Error: {type(exc).__name__}: {str(exc)}") print(f"{'='*60}") traceback.print_exc() print(f"{'='*60}\n") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={ 'error': 'internal_server_error', 'message': 'An unexpected error occurred. 
Please try again later.', 'timestamp': datetime.utcnow().isoformat() + 'Z', 'path': request.url.path, 'request_id': request_id } ) def get_error_code(status_code: int) -> str: """Map HTTP status code to error code string.""" codes = { 400: 'bad_request', 401: 'unauthorized', 403: 'forbidden', 404: 'not_found', 422: 'validation_error', 429: 'rate_limit_exceeded', 500: 'internal_server_error', 503: 'service_unavailable', } return codes.get(status_code, 'unknown_error') # ============================================================================= # Data Models with Validation # ============================================================================= class Airport(BaseModel): """Airport information model.""" iata: str = Field(..., min_length=3, max_length=3, description="3-letter IATA code") name: str = Field(..., min_length=1, max_length=200, description="Airport name") city: str = Field(..., max_length=100, description="City name") country: str = Field(..., min_length=2, max_length=2, description="2-letter country code") latitude: float = Field(..., ge=-90, le=90, description="Latitude (-90 to 90)") longitude: float = Field(..., ge=-180, le=180, description="Longitude (-180 to 180)") @validator('iata') def validate_iata(cls, v): if not re.match(r'^[A-Z]{3}$', v): raise ValueError('IATA code must be 3 uppercase letters (e.g., MUC, BDS)') return v @validator('country') def validate_country(cls, v): if not re.match(r'^[A-Z]{2}$', v): raise ValueError('Country code must be 2 uppercase letters (e.g., DE, IT)') return v class Country(BaseModel): """Country information model.""" code: str = Field(..., min_length=2, max_length=2, description="2-letter ISO country code") name: str = Field(..., min_length=1, max_length=100, description="Country name") airport_count: int = Field(..., ge=0, description="Number of airports") @validator('code') def validate_code(cls, v): if not re.match(r'^[A-Z]{2}$', v): raise ValueError('Country code must be 2 uppercase letters (e.g., DE, 
IT)') return v class ScanRequest(BaseModel): """Flight scan request model with comprehensive validation.""" origin: str = Field( ..., min_length=3, max_length=3, description="Origin airport IATA code (3 uppercase letters)" ) destination_country: Optional[str] = Field( None, min_length=2, max_length=2, description="Destination country code (2 uppercase letters)", alias="country" # Allow both 'country' and 'destination_country' ) destinations: Optional[List[str]] = Field( None, description="List of destination airport IATA codes (alternative to country)" ) start_date: Optional[str] = Field( None, description="Start date in ISO format (YYYY-MM-DD). Default: tomorrow" ) end_date: Optional[str] = Field( None, description="End date in ISO format (YYYY-MM-DD). Default: start + window_months" ) window_months: int = Field( default=3, ge=1, le=12, description="Time window in months (1-12)" ) seat_class: str = Field( default="economy", description="Seat class: economy, premium, business, or first" ) adults: int = Field( default=1, ge=1, le=9, description="Number of adults (1-9)" ) @validator('origin') def validate_origin(cls, v): v = v.upper() # Normalize to uppercase if not re.match(r'^[A-Z]{3}$', v): raise ValueError('Origin must be a 3-letter IATA code (e.g., BDS, MUC)') return v @validator('destination_country') def validate_destination_country(cls, v): if v is None: return v v = v.upper() # Normalize to uppercase if not re.match(r'^[A-Z]{2}$', v): raise ValueError('Country must be a 2-letter ISO code (e.g., DE, IT)') return v @validator('destinations') def validate_destinations(cls, v, values): if v is None: return v # Normalize to uppercase and validate each code normalized = [] for code in v: code = code.strip().upper() if not re.match(r'^[A-Z]{3}$', code): raise ValueError(f'Invalid destination airport code: {code}. 
Must be 3-letter IATA code.') normalized.append(code) # Check for duplicates if len(normalized) != len(set(normalized)): raise ValueError('Destination list contains duplicate airport codes') # Limit to reasonable number if len(normalized) > 50: raise ValueError('Maximum 50 destination airports allowed') if len(normalized) == 0: raise ValueError('At least one destination airport required') return normalized @validator('destinations', pre=False, always=True) def check_destination_mode(cls, v, values): """Ensure either country or destinations is provided, but not both.""" country = values.get('destination_country') if country and v: raise ValueError('Provide either country OR destinations, not both') if not country and not v: raise ValueError('Must provide either country or destinations') return v @validator('start_date') def validate_start_date(cls, v): if v is None: return v try: parsed_date = datetime.strptime(v, '%Y-%m-%d').date() # Allow past dates for historical scans # if parsed_date < date.today(): # raise ValueError('Start date must be today or in the future') return v except ValueError as e: if 'does not match format' in str(e): raise ValueError('Start date must be in ISO format (YYYY-MM-DD), e.g., 2026-04-01') raise @validator('end_date') def validate_end_date(cls, v, values): if v is None: return v try: end = datetime.strptime(v, '%Y-%m-%d').date() if 'start_date' in values and values['start_date']: start = datetime.strptime(values['start_date'], '%Y-%m-%d').date() if end < start: raise ValueError('End date must be on or after start date') return v except ValueError as e: if 'does not match format' in str(e): raise ValueError('End date must be in ISO format (YYYY-MM-DD), e.g., 2026-06-30') raise @validator('seat_class') def validate_seat_class(cls, v): allowed = ['economy', 'premium', 'business', 'first'] v = v.lower() if v not in allowed: raise ValueError(f'Seat class must be one of: {", ".join(allowed)}') return v class Config: 
allow_population_by_field_name = True # Allow both 'country' and 'destination_country' class ScanStatus(BaseModel): """Scan status model.""" scan_id: str = Field(..., min_length=1, description="Unique scan identifier") status: str = Field(..., description="Scan status: pending, running, completed, or failed") progress: int = Field(..., ge=0, le=100, description="Progress percentage (0-100)") routes_scanned: int = Field(..., ge=0, description="Number of routes scanned") routes_total: int = Field(..., ge=0, description="Total number of routes") flights_found: int = Field(..., ge=0, description="Total flights found") started_at: str = Field(..., description="ISO timestamp when scan started") completed_at: Optional[str] = Field(None, description="ISO timestamp when scan completed") @validator('status') def validate_status(cls, v): allowed = ['pending', 'running', 'completed', 'failed'] if v not in allowed: raise ValueError(f'Status must be one of: {", ".join(allowed)}') return v @validator('routes_scanned') def validate_routes_scanned(cls, v, values): if 'routes_total' in values and values['routes_total'] > 0: if v > values['routes_total']: raise ValueError('routes_scanned cannot exceed routes_total') return v class PaginationMetadata(BaseModel): """Pagination metadata for paginated responses.""" page: int = Field(..., ge=1, description="Current page number") limit: int = Field(..., ge=1, le=500, description="Items per page") total: int = Field(..., ge=0, description="Total number of items") pages: int = Field(..., ge=0, description="Total number of pages") has_next: bool = Field(..., description="Whether there is a next page") has_prev: bool = Field(..., description="Whether there is a previous page") class PaginatedResponse(BaseModel, Generic[T]): """Generic paginated response wrapper.""" data: List[T] = Field(..., description="List of items for current page") pagination: PaginationMetadata = Field(..., description="Pagination metadata") class Config: # Pydantic v2: 
Enable arbitrary types for Generic support arbitrary_types_allowed = True class Route(BaseModel): """Route model - represents a discovered flight route.""" id: int = Field(..., description="Route ID") scan_id: int = Field(..., description="Parent scan ID") destination: str = Field(..., description="Destination airport IATA code") destination_name: str = Field(..., description="Destination airport name") destination_city: Optional[str] = Field(None, description="Destination city") flight_count: int = Field(..., ge=0, description="Number of flights found") airlines: List[str] = Field(..., description="List of airlines operating this route") min_price: Optional[float] = Field(None, ge=0, description="Minimum price found") max_price: Optional[float] = Field(None, ge=0, description="Maximum price found") avg_price: Optional[float] = Field(None, ge=0, description="Average price") created_at: str = Field(..., description="ISO timestamp when route was discovered") class Flight(BaseModel): """Individual flight discovered by a scan.""" id: int = Field(..., description="Flight ID") scan_id: int = Field(..., description="Parent scan ID") destination: str = Field(..., description="Destination airport IATA code") date: str = Field(..., description="Flight date (YYYY-MM-DD)") airline: Optional[str] = Field(None, description="Operating airline") departure_time: Optional[str] = Field(None, description="Departure time (HH:MM)") arrival_time: Optional[str] = Field(None, description="Arrival time (HH:MM)") price: Optional[float] = Field(None, ge=0, description="Price in EUR") stops: int = Field(0, ge=0, description="Number of stops (0 = direct)") class Scan(BaseModel): """Scan model - represents a flight scan with full details.""" id: int = Field(..., description="Scan ID") origin: str = Field(..., description="Origin airport IATA code") country: str = Field(..., description="Destination country code") start_date: str = Field(..., description="Start date (YYYY-MM-DD)") end_date: str = 
Field(..., description="End date (YYYY-MM-DD)") created_at: str = Field(..., description="ISO timestamp when scan was created") updated_at: str = Field(..., description="ISO timestamp when scan was last updated") status: str = Field(..., description="Scan status: pending, running, completed, or failed") total_routes: int = Field(..., ge=0, description="Total number of routes to scan") routes_scanned: int = Field(..., ge=0, description="Number of routes scanned so far") total_flights: int = Field(..., ge=0, description="Total number of flights found") error_message: Optional[str] = Field(None, description="Error message if scan failed") seat_class: str = Field(..., description="Seat class") adults: int = Field(..., ge=1, le=9, description="Number of adults") class ScanCreateResponse(BaseModel): """Response after creating a new scan.""" id: int = Field(..., description="Scan ID") status: str = Field(..., description="Scan status") message: str = Field(..., description="Status message") scan: Scan = Field(..., description="Full scan details") class LogEntry(BaseModel): """Log entry model.""" timestamp: str = Field(..., description="ISO timestamp when log was created") level: str = Field(..., description="Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL") message: str = Field(..., description="Log message") module: str = Field(..., description="Module name where log originated") function: str = Field(..., description="Function name where log originated") line: int = Field(..., description="Line number where log originated") exception: Optional[str] = Field(None, description="Exception traceback if present") # ============================================================================= # Root Endpoints (not versioned) # ============================================================================= @app.get("/") async def root(): """API root endpoint.""" return { "name": "Flight Radar API", "version": "2.0.0", "api_version": "v1", "docs": "/docs", "endpoints": { 
"airports": "/api/v1/airports", "scans": "/api/v1/scans", "logs": "/api/v1/logs" }, "status": "online" } @app.get("/health") async def health_check(): """Health check endpoint for monitoring.""" return {"status": "healthy", "version": "2.0.0"} # ============================================================================= # API v1 - Airport Search & Discovery # ============================================================================= @router_v1.get("/airports", response_model=PaginatedResponse[Airport]) async def search_airports( q: str = Query(..., min_length=2, max_length=100, description="Search query (IATA, city, or country name)"), page: int = Query(1, ge=1, description="Page number (starts at 1)"), limit: int = Query(20, ge=1, le=100, description="Items per page (max 100)") ): """ Search airports by IATA code, name, city, or country with pagination. Examples: /api/v1/airports?q=mun&page=1&limit=20 → First page of Munich results /api/v1/airports?q=FRA → Frankfurt (default page=1, limit=20) /api/v1/airports?q=germany&limit=50 → All German airports, 50 per page /api/v1/airports?q=BDS → Brindisi Returns: Paginated response with airport data and pagination metadata """ try: airports_data = get_airport_data() except FileNotFoundError: raise HTTPException( status_code=503, detail="Airport database unavailable. Please try again later." ) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to load airport data: {e}") query = q.lower().strip() # Priority buckets — higher bucket = shown first p0_exact_iata: list[Airport] = [] # IATA == query exactly (e.g. "BER") p1_iata_prefix: list[Airport] = [] # IATA starts with query (e.g. "BE" → BER) p2_city_prefix: list[Airport] = [] # city starts with query (e.g. 
"ber" → Berlin) p3_city_contains: list[Airport] = [] # city contains query p4_name_prefix: list[Airport] = [] # name starts with query p5_name_contains: list[Airport] = [] # name contains query p6_country: list[Airport] = [] # country code contains query seen: set[str] = set() for airport in airports_data: try: iata_l = airport['iata'].lower() city_l = airport.get('city', '').lower() name_l = airport['name'].lower() country_l = airport.get('country', '').lower() if iata_l in seen: continue obj = Airport(**airport) if iata_l == query: p0_exact_iata.append(obj) elif iata_l.startswith(query): p1_iata_prefix.append(obj) elif city_l.startswith(query): p2_city_prefix.append(obj) elif query in city_l: p3_city_contains.append(obj) elif name_l.startswith(query): p4_name_prefix.append(obj) elif query in name_l: p5_name_contains.append(obj) elif query in country_l: p6_country.append(obj) else: continue seen.add(iata_l) except Exception: # Skip airports with invalid data (e.g., invalid IATA codes like 'DU9') continue results = ( p0_exact_iata + p1_iata_prefix + p2_city_prefix + p3_city_contains + p4_name_prefix + p5_name_contains + p6_country ) # Calculate pagination total = len(results) total_pages = math.ceil(total / limit) if total > 0 else 0 # Validate page number if page > total_pages and total_pages > 0: raise HTTPException( status_code=404, detail=f"Page {page} does not exist. Total pages: {total_pages}" ) # Paginate results start_idx = (page - 1) * limit end_idx = start_idx + limit page_results = results[start_idx:end_idx] # Build pagination metadata pagination = PaginationMetadata( page=page, limit=limit, total=total, pages=total_pages, has_next=page < total_pages, has_prev=page > 1 ) return PaginatedResponse(data=page_results, pagination=pagination) @router_v1.get("/airports/country/{country_code}", response_model=List[Airport]) async def get_airports_by_country(country_code: str): """ Get all airports in a specific country. 
Examples: /api/airports/country/DE → All 95 German airports /api/airports/country/IT → All Italian airports /api/airports/country/US → All US airports """ try: airports_data = get_airport_data() except FileNotFoundError: raise HTTPException( status_code=503, detail="Airport database unavailable. Please try again later." ) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to load airport data: {e}") country_airports = [ Airport(**airport) for airport in airports_data if airport['country'] == country_code.upper() ] if not country_airports: raise HTTPException( status_code=404, detail=f"No airports found for country code: {country_code}" ) return country_airports @router_v1.get("/airports/{iata}", response_model=Airport) async def get_airport(iata: str): """ Get details for a specific airport by IATA code. Example: /api/airports/BDS → Brindisi Airport details """ try: airports_data = get_airport_data() except FileNotFoundError: raise HTTPException( status_code=503, detail="Airport database unavailable. Please try again later." ) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to load airport data: {e}") iata = iata.upper() airport = next((ap for ap in airports_data if ap['iata'] == iata), None) if not airport: raise HTTPException(status_code=404, detail=f"Airport not found: {iata}") return Airport(**airport) @router_v1.get("/countries", response_model=List[Country]) async def get_countries(): """ Get list of all countries with airports. Returns country codes, names, and airport counts. """ try: airports_data = get_airport_data() except FileNotFoundError: raise HTTPException( status_code=503, detail="Airport database unavailable. Please try again later." 
) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to load airport data: {e}") # Count airports per country country_counts = {} for airport in airports_data: country = airport['country'] country_counts[country] = country_counts.get(country, 0) + 1 # Get country names (we'll need a mapping file for this) # For now, just return codes countries = [ Country( code=code, name=code, # TODO: Add country name mapping airport_count=count ) for code, count in sorted(country_counts.items()) ] return countries # ============================================================================= # Scan Management (TODO: Implement async scanning) # ============================================================================= @router_v1.post("/scans", response_model=ScanCreateResponse) async def create_scan(request: ScanRequest): """ Create a new flight scan. This creates a scan record in the database with 'pending' status. The actual scanning will be performed by a background worker. Returns the created scan details. """ try: # Parse and validate dates if request.start_date: start_date = request.start_date else: # Default to tomorrow start_date = (date.today() + timedelta(days=1)).isoformat() if request.end_date: end_date = request.end_date else: # Default to start_date + window_months start = datetime.strptime(start_date, '%Y-%m-%d').date() end = start + timedelta(days=30 * request.window_months) end_date = end.isoformat() # Determine destination mode and prepare country field # Store either country code (2 letters) or comma-separated airport codes if request.destination_country: country_or_airports = request.destination_country else: # Store comma-separated list of destination airports country_or_airports = ','.join(request.destinations) # Insert scan into database conn = get_connection() cursor = conn.cursor() cursor.execute(""" INSERT INTO scans ( origin, country, start_date, end_date, status, seat_class, adults ) VALUES (?, ?, ?, ?, ?, ?, ?) 
""", ( request.origin, country_or_airports, start_date, end_date, 'pending', request.seat_class, request.adults )) scan_id = cursor.lastrowid conn.commit() # Fetch the created scan cursor.execute(""" SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, seat_class, adults FROM scans WHERE id = ? """, (scan_id,)) row = cursor.fetchone() conn.close() if not row: raise HTTPException(status_code=500, detail="Failed to create scan") scan = Scan( id=row[0], origin=row[1], country=row[2], start_date=row[3], end_date=row[4], created_at=row[5], updated_at=row[6], status=row[7], total_routes=row[8], routes_scanned=row[9], total_flights=row[10], error_message=row[11], seat_class=row[12], adults=row[13] ) logging.info(f"Scan created: ID={scan_id}, origin={scan.origin}, country={scan.country}, dates={scan.start_date} to {scan.end_date}") # Start background processing try: start_scan_processor(scan_id) logging.info(f"Background scan processor started for scan {scan_id}") except Exception as bg_error: logging.error(f"Failed to start background processor for scan {scan_id}: {str(bg_error)}") # Don't fail the request - scan is created, just not processed yet return ScanCreateResponse( id=scan_id, status='pending', message=f'Scan created successfully. Processing started. Scan ID: {scan_id}', scan=scan ) except Exception as e: import traceback traceback.print_exc() logging.error(f"Failed to create scan: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to create scan: {str(e)}" ) @router_v1.get("/scans", response_model=PaginatedResponse[Scan]) async def list_scans( page: int = Query(1, ge=1, description="Page number"), limit: int = Query(20, ge=1, le=100, description="Items per page"), status: Optional[str] = Query(None, description="Filter by status: pending, running, completed, or failed") ): """ List all scans with pagination. Optionally filter by status. 
Results are ordered by creation date (most recent first). """ try: conn = get_connection() cursor = conn.cursor() # Build WHERE clause for status filter where_clause = "" params = [] if status: if status not in ['pending', 'running', 'completed', 'failed']: raise HTTPException( status_code=400, detail=f"Invalid status: {status}. Must be one of: pending, running, completed, failed" ) where_clause = "WHERE status = ?" params.append(status) # Get total count count_query = f"SELECT COUNT(*) FROM scans {where_clause}" cursor.execute(count_query, params) total = cursor.fetchone()[0] # Calculate pagination total_pages = math.ceil(total / limit) if total > 0 else 0 # Validate page number if page > total_pages and total_pages > 0: raise HTTPException( status_code=404, detail=f"Page {page} does not exist. Total pages: {total_pages}" ) # Get paginated results offset = (page - 1) * limit query = f""" SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, seat_class, adults FROM scans {where_clause} ORDER BY created_at DESC LIMIT ? OFFSET ? 
""" cursor.execute(query, params + [limit, offset]) rows = cursor.fetchall() conn.close() # Convert to Scan models scans = [] for row in rows: scans.append(Scan( id=row[0], origin=row[1], country=row[2], start_date=row[3], end_date=row[4], created_at=row[5], updated_at=row[6], status=row[7], total_routes=row[8], routes_scanned=row[9], total_flights=row[10], error_message=row[11], seat_class=row[12], adults=row[13] )) # Build pagination metadata pagination = PaginationMetadata( page=page, limit=limit, total=total, pages=total_pages, has_next=page < total_pages, has_prev=page > 1 ) return PaginatedResponse(data=scans, pagination=pagination) except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException( status_code=500, detail=f"Failed to list scans: {str(e)}" ) @router_v1.get("/scans/{scan_id}", response_model=Scan) async def get_scan_status(scan_id: int): """ Get details of a specific scan. Returns full scan information including progress, status, and statistics. """ try: conn = get_connection() cursor = conn.cursor() cursor.execute(""" SELECT id, origin, country, start_date, end_date, created_at, updated_at, status, total_routes, routes_scanned, total_flights, error_message, seat_class, adults FROM scans WHERE id = ? 
""", (scan_id,)) row = cursor.fetchone() conn.close() if not row: raise HTTPException( status_code=404, detail=f"Scan not found: {scan_id}" ) return Scan( id=row[0], origin=row[1], country=row[2], start_date=row[3], end_date=row[4], created_at=row[5], updated_at=row[6], status=row[7], total_routes=row[8], routes_scanned=row[9], total_flights=row[10], error_message=row[11], seat_class=row[12], adults=row[13] ) except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException( status_code=500, detail=f"Failed to get scan: {str(e)}" ) @router_v1.delete("/scans/{scan_id}", status_code=204) async def delete_scan(scan_id: int): """ Delete a scan and all its associated routes and flights (CASCADE). Returns 409 if the scan is currently running or pending. """ try: conn = get_connection() cursor = conn.cursor() cursor.execute("SELECT status FROM scans WHERE id = ?", (scan_id,)) row = cursor.fetchone() if not row: conn.close() raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}") if row[0] in ('pending', 'running'): conn.close() raise HTTPException( status_code=409, detail="Cannot delete a scan that is currently pending or running." ) cursor.execute("DELETE FROM scans WHERE id = ?", (scan_id,)) conn.commit() conn.close() logging.info(f"Scan {scan_id} deleted") except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to delete scan: {str(e)}") @router_v1.get("/scans/{scan_id}/routes", response_model=PaginatedResponse[Route]) async def get_scan_routes( scan_id: int, page: int = Query(1, ge=1, description="Page number"), limit: int = Query(20, ge=1, le=100, description="Items per page") ): """ Get all routes discovered by a specific scan. Returns paginated list of routes with flight counts, airlines, and price statistics. Results are ordered by minimum price (cheapest first). 
""" try: conn = get_connection() cursor = conn.cursor() # Verify scan exists cursor.execute("SELECT id FROM scans WHERE id = ?", (scan_id,)) if not cursor.fetchone(): conn.close() raise HTTPException( status_code=404, detail=f"Scan not found: {scan_id}" ) # Get total count of routes for this scan cursor.execute("SELECT COUNT(*) FROM routes WHERE scan_id = ?", (scan_id,)) total = cursor.fetchone()[0] # Calculate pagination total_pages = math.ceil(total / limit) if total > 0 else 0 # Validate page number if page > total_pages and total_pages > 0: conn.close() raise HTTPException( status_code=404, detail=f"Page {page} does not exist. Total pages: {total_pages}" ) # Get paginated results offset = (page - 1) * limit cursor.execute(""" SELECT id, scan_id, destination, destination_name, destination_city, flight_count, airlines, min_price, max_price, avg_price, created_at FROM routes WHERE scan_id = ? ORDER BY CASE WHEN min_price IS NULL THEN 1 ELSE 0 END, min_price ASC, flight_count DESC LIMIT ? OFFSET ? 
""", (scan_id, limit, offset)) rows = cursor.fetchall() conn.close() # Convert to Route models, enriching name/city from airport DB when missing lookup = _iata_lookup() routes = [] for row in rows: # Parse airlines JSON try: airlines = json.loads(row[6]) if row[6] else [] except: airlines = [] dest = row[2] dest_name = row[3] or dest dest_city = row[4] or '' # If name was never resolved (stored as IATA code), look it up now if dest_name == dest: airport = lookup.get(dest, {}) dest_name = airport.get('name', dest) dest_city = airport.get('city', dest_city) routes.append(Route( id=row[0], scan_id=row[1], destination=dest, destination_name=dest_name, destination_city=dest_city, flight_count=row[5], airlines=airlines, min_price=row[7], max_price=row[8], avg_price=row[9], created_at=row[10] )) # Build pagination metadata pagination = PaginationMetadata( page=page, limit=limit, total=total, pages=total_pages, has_next=page < total_pages, has_prev=page > 1 ) return PaginatedResponse(data=routes, pagination=pagination) except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException( status_code=500, detail=f"Failed to get routes: {str(e)}" ) @router_v1.get("/scans/{scan_id}/flights", response_model=PaginatedResponse[Flight]) async def get_scan_flights( scan_id: int, destination: Optional[str] = Query(None, min_length=3, max_length=3, description="Filter by destination IATA code"), page: int = Query(1, ge=1, description="Page number"), limit: int = Query(50, ge=1, le=200, description="Items per page") ): """ Get individual flights discovered by a specific scan. Optionally filter by destination airport code. Results are ordered by price ascending. 
""" try: conn = get_connection() cursor = conn.cursor() cursor.execute("SELECT id FROM scans WHERE id = ?", (scan_id,)) if not cursor.fetchone(): conn.close() raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}") if destination: cursor.execute( "SELECT COUNT(*) FROM flights WHERE scan_id = ? AND destination = ?", (scan_id, destination.upper()) ) else: cursor.execute("SELECT COUNT(*) FROM flights WHERE scan_id = ?", (scan_id,)) total = cursor.fetchone()[0] total_pages = math.ceil(total / limit) if total > 0 else 0 offset = (page - 1) * limit if destination: cursor.execute(""" SELECT id, scan_id, destination, date, airline, departure_time, arrival_time, price, stops FROM flights WHERE scan_id = ? AND destination = ? ORDER BY price ASC, date ASC LIMIT ? OFFSET ? """, (scan_id, destination.upper(), limit, offset)) else: cursor.execute(""" SELECT id, scan_id, destination, date, airline, departure_time, arrival_time, price, stops FROM flights WHERE scan_id = ? ORDER BY price ASC, date ASC LIMIT ? OFFSET ? 
""", (scan_id, limit, offset)) rows = cursor.fetchall() conn.close() flights = [ Flight( id=row[0], scan_id=row[1], destination=row[2], date=row[3], airline=row[4], departure_time=row[5], arrival_time=row[6], price=row[7], stops=row[8] ) for row in rows ] pagination = PaginationMetadata( page=page, limit=limit, total=total, pages=total_pages, has_next=page < total_pages, has_prev=page > 1 ) return PaginatedResponse(data=flights, pagination=pagination) except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get flights: {str(e)}") @router_v1.get("/logs", response_model=PaginatedResponse[LogEntry]) async def get_logs( page: int = Query(1, ge=1, description="Page number"), limit: int = Query(50, ge=1, le=500, description="Items per page"), level: Optional[str] = Query(None, description="Filter by log level: DEBUG, INFO, WARNING, ERROR, CRITICAL"), search: Optional[str] = Query(None, min_length=1, description="Search in log messages") ): """ Get application logs with pagination and filtering. Logs are stored in memory (circular buffer, max 1000 entries). Results are ordered by timestamp (newest first). Query Parameters: - page: Page number (default: 1) - limit: Items per page (default: 50, max: 500) - level: Filter by log level (optional) - search: Search text in messages (case-insensitive, optional) """ try: # Get all logs from buffer all_logs = log_buffer.get_all() # Apply level filter if level: level_upper = level.upper() valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] if level_upper not in valid_levels: raise HTTPException( status_code=400, detail=f"Invalid log level: {level}. 
Must be one of: {', '.join(valid_levels)}" ) all_logs = [log for log in all_logs if log['level'] == level_upper] # Apply search filter if search: search_lower = search.lower() all_logs = [ log for log in all_logs if search_lower in log['message'].lower() ] # Calculate pagination total = len(all_logs) total_pages = math.ceil(total / limit) if total > 0 else 0 # Validate page number if page > total_pages and total_pages > 0: raise HTTPException( status_code=404, detail=f"Page {page} does not exist. Total pages: {total_pages}" ) # Paginate results start_idx = (page - 1) * limit end_idx = start_idx + limit page_logs = all_logs[start_idx:end_idx] # Convert to LogEntry models log_entries = [LogEntry(**log) for log in page_logs] # Build pagination metadata pagination = PaginationMetadata( page=page, limit=limit, total=total, pages=total_pages, has_next=page < total_pages, has_prev=page > 1 ) return PaginatedResponse(data=log_entries, pagination=pagination) except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException( status_code=500, detail=f"Failed to get logs: {str(e)}" ) @router_v1.get("/flights/{route_id}") async def get_flights(route_id: str): """ Get all flights for a specific route. Returns daily flight data for calendar view. """ # TODO: Implement flight data retrieval raise HTTPException(status_code=501, detail="Flights endpoint not yet implemented") # ============================================================================= # Include Router (IMPORTANT!) 
# ============================================================================= app.include_router(router_v1) # ============================================================================= # Helper Functions # ============================================================================= # Airports missing from the OpenFlights dataset (opened/renamed after dataset freeze) _MISSING_AIRPORTS = [ {'iata': 'BER', 'name': 'Berlin Brandenburg Airport', 'city': 'Berlin', 'country': 'DE'}, {'iata': 'IST', 'name': 'Istanbul Airport', 'city': 'Istanbul', 'country': 'TR'}, ] @lru_cache(maxsize=1) def get_airport_data(): """ Load airport data from the existing airports.py module. Returns list of airport dictionaries. """ from pathlib import Path json_path = Path(__file__).parent / "data" / "airports_by_country.json" if not json_path.exists(): raise FileNotFoundError("Airport database not found. Run airports.py first.") with open(json_path, 'r', encoding='utf-8') as f: airports_by_country = json.load(f) # Flatten the data structure airports = [] for country_code, country_airports in airports_by_country.items(): for airport in country_airports: airports.append({ 'iata': airport['iata'], 'name': airport['name'], 'city': airport.get('city', ''), 'country': country_code, 'latitude': airport.get('lat', 0.0), 'longitude': airport.get('lon', 0.0), }) # Patch in modern airports missing from the OpenFlights dataset existing_iatas = {a['iata'] for a in airports} for extra in _MISSING_AIRPORTS: if extra['iata'] not in existing_iatas: airports.append({ 'iata': extra['iata'], 'name': extra['name'], 'city': extra['city'], 'country': extra['country'], 'latitude': 0.0, 'longitude': 0.0, }) return airports @lru_cache(maxsize=1) def _iata_lookup() -> dict: """Return {iata: airport_dict} built from get_airport_data(). Cached.""" return {a['iata']: a for a in get_airport_data()} if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)