feat: implement reverse scan (country → specific airports)
All checks were successful
Deploy / deploy (push) Successful in 30s

- DB schema: relaxed the origin CHECK constraint to accept 2 or more characters
  (reverse scans store a 2-letter ISO country code as origin), added a scan_mode
  column to scans and scheduled_scans, added origin_airport to routes and
  flights, and updated the unique index to
  (scan_id, COALESCE(origin_airport,''), destination)
- Migrations: init_db.py recreates tables and adds columns via guarded ALTERs
- API: scan_mode field on ScanRequest/Scan; Route/Flight expose origin_airport;
  GET /scans/{id}/flights accepts origin_airport filter; CreateScheduleRequest
  and Schedule carry scan_mode; scheduler and run-now pass scan_mode through
- scan_processor: _write_route_incremental accepts origin_airport; process_scan
  branches on scan_mode=reverse (country → airports × destinations × dates)
- Frontend: new CountrySelect component (populated from GET /api/v1/countries);
  Scans page adds Direction toggle + CountrySelect for both modes; ScanDetails
  shows Origin column for reverse scans and uses composite route keys; Re-run
  preserves scan_mode

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-01 17:58:55 +01:00
parent 7ece1f9b45
commit 77d2a46264
9 changed files with 1070 additions and 279 deletions

View File

@@ -38,7 +38,10 @@ from threading import Lock
T = TypeVar('T')
# Import existing modules
from airports import download_and_build_airport_data
from airports import download_and_build_airport_data, COUNTRY_NAME_TO_ISO
# Inverted mapping: ISO code → country name (for /countries endpoint).
# NOTE(review): if COUNTRY_NAME_TO_ISO maps several names (aliases) to the same
# ISO code, this comprehension keeps only the last name iterated for that code —
# confirm the surviving name is the preferred display name.
_ISO_TO_COUNTRY_NAME = {v: k for k, v in COUNTRY_NAME_TO_ISO.items()}
from database import get_connection
from scan_processor import start_scan_processor, start_resume_processor, pause_scan_task, stop_scan_task
@@ -294,7 +297,7 @@ def _check_and_run_due_schedules():
now_str = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
cursor.execute("""
SELECT id, origin, country, window_months, seat_class, adults,
SELECT id, origin, country, scan_mode, window_months, seat_class, adults,
frequency, hour, minute, day_of_week, day_of_month
FROM scheduled_scans
WHERE enabled = 1 AND next_run_at <= ?
@@ -302,7 +305,7 @@ def _check_and_run_due_schedules():
due = cursor.fetchall()
for row in due:
(sched_id, origin, country, window_months, seat_class, adults,
(sched_id, origin, country, scan_mode, window_months, seat_class, adults,
frequency, hour, minute, day_of_week, day_of_month) = row
# Concurrency guard: skip if a scan for this schedule is still active
@@ -323,10 +326,10 @@ def _check_and_run_due_schedules():
conn.execute("""
INSERT INTO scans (
origin, country, start_date, end_date,
origin, country, scan_mode, start_date, end_date,
status, seat_class, adults, scheduled_scan_id
) VALUES (?, ?, ?, ?, 'pending', ?, ?, ?)
""", (origin, country, start_date, end_date,
) VALUES (?, ?, ?, ?, ?, 'pending', ?, ?, ?)
""", (origin, country, scan_mode, start_date, end_date,
seat_class, adults, sched_id))
conn.commit()
scan_id = conn.execute("SELECT last_insert_rowid()").fetchone()[0]
@@ -705,11 +708,15 @@ class Country(BaseModel):
class ScanRequest(BaseModel):
"""Flight scan request model with comprehensive validation."""
scan_mode: str = Field(
'forward',
description="Scan direction: 'forward' (IATA → country) or 'reverse' (country → IATAs)"
)
origin: str = Field(
...,
min_length=3,
min_length=2,
max_length=3,
description="Origin airport IATA code (3 uppercase letters)"
description="Origin airport IATA code (forward) or ISO country code (reverse)"
)
destination_country: Optional[str] = Field(
None,
@@ -747,11 +754,22 @@ class ScanRequest(BaseModel):
description="Number of adults (1-9)"
)
@validator('scan_mode')
def validate_scan_mode(cls, v):
if v not in ('forward', 'reverse'):
raise ValueError("scan_mode must be 'forward' or 'reverse'")
return v
@validator('origin')
def validate_origin(cls, v):
v = v.upper() # Normalize to uppercase
if not re.match(r'^[A-Z]{3}$', v):
raise ValueError('Origin must be a 3-letter IATA code (e.g., BDS, MUC)')
def validate_origin(cls, v, values):
v = v.strip().upper()
mode = values.get('scan_mode', 'forward')
if mode == 'reverse':
if not re.match(r'^[A-Z]{2}$', v):
raise ValueError('For reverse scans, origin must be a 2-letter ISO country code (e.g., DE, IT)')
else:
if not re.match(r'^[A-Z]{3}$', v):
raise ValueError('Origin must be a 3-letter IATA code (e.g., BDS, MUC)')
return v
@validator('destination_country')
@@ -791,16 +809,20 @@ class ScanRequest(BaseModel):
@validator('destinations', pre=False, always=True)
def check_destination_mode(cls, v, values):
"""Ensure either country or destinations is provided, but not both."""
"""Ensure correct destination fields for the chosen scan_mode."""
country = values.get('destination_country')
mode = values.get('scan_mode', 'forward')
if country and v:
raise ValueError('Provide either country OR destinations, not both')
if not country and not v:
raise ValueError('Must provide either country or destinations')
return v
if mode == 'reverse':
if not v:
raise ValueError('Reverse scans require destinations (list of destination airport IATA codes)')
return v
else:
if country and v:
raise ValueError('Provide either country OR destinations, not both')
if not country and not v:
raise ValueError('Must provide either country or destinations')
return v
@validator('start_date')
def validate_start_date(cls, v):
@@ -895,6 +917,7 @@ class Route(BaseModel):
"""Route model - represents a discovered flight route."""
id: int = Field(..., description="Route ID")
scan_id: int = Field(..., description="Parent scan ID")
origin_airport: Optional[str] = Field(None, description="Origin airport IATA code (reverse scans only)")
destination: str = Field(..., description="Destination airport IATA code")
destination_name: str = Field(..., description="Destination airport name")
destination_city: Optional[str] = Field(None, description="Destination city")
@@ -910,6 +933,7 @@ class Flight(BaseModel):
"""Individual flight discovered by a scan."""
id: int = Field(..., description="Flight ID")
scan_id: int = Field(..., description="Parent scan ID")
origin_airport: Optional[str] = Field(None, description="Origin airport IATA code (reverse scans only)")
destination: str = Field(..., description="Destination airport IATA code")
date: str = Field(..., description="Flight date (YYYY-MM-DD)")
airline: Optional[str] = Field(None, description="Operating airline")
@@ -922,8 +946,9 @@ class Flight(BaseModel):
class Scan(BaseModel):
"""Scan model - represents a flight scan with full details."""
id: int = Field(..., description="Scan ID")
origin: str = Field(..., description="Origin airport IATA code")
country: str = Field(..., description="Destination country code")
scan_mode: str = Field('forward', description="Scan direction: forward or reverse")
origin: str = Field(..., description="Origin airport IATA code (forward) or ISO country code (reverse)")
country: str = Field(..., description="Destination country code or comma-separated destination IATAs")
start_date: str = Field(..., description="Start date (YYYY-MM-DD)")
end_date: str = Field(..., description="End date (YYYY-MM-DD)")
created_at: str = Field(..., description="ISO timestamp when scan was created")
@@ -1183,16 +1208,14 @@ async def get_countries():
country = airport['country']
country_counts[country] = country_counts.get(country, 0) + 1
# Get country names (we'll need a mapping file for this)
# For now, just return codes
countries = [
countries = sorted([
Country(
code=code,
name=code, # TODO: Add country name mapping
name=_ISO_TO_COUNTRY_NAME.get(code, code),
airport_count=count
)
for code, count in sorted(country_counts.items())
]
for code, count in country_counts.items()
], key=lambda c: c.name)
return countries
@@ -1241,12 +1264,13 @@ async def create_scan(request: ScanRequest):
cursor.execute("""
INSERT INTO scans (
origin, country, start_date, end_date,
origin, country, scan_mode, start_date, end_date,
status, seat_class, adults
) VALUES (?, ?, ?, ?, ?, ?, ?)
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (
request.origin,
country_or_airports,
request.scan_mode,
start_date,
end_date,
'pending',
@@ -1259,7 +1283,7 @@ async def create_scan(request: ScanRequest):
# Fetch the created scan
cursor.execute("""
SELECT id, origin, country, start_date, end_date,
SELECT id, origin, country, scan_mode, start_date, end_date,
created_at, updated_at, status, total_routes,
routes_scanned, total_flights, error_message,
seat_class, adults, scheduled_scan_id,
@@ -1278,20 +1302,21 @@ async def create_scan(request: ScanRequest):
id=row[0],
origin=row[1],
country=row[2],
start_date=row[3],
end_date=row[4],
created_at=row[5],
updated_at=row[6],
status=row[7],
total_routes=row[8],
routes_scanned=row[9],
total_flights=row[10],
error_message=row[11],
seat_class=row[12],
adults=row[13],
scheduled_scan_id=row[14] if len(row) > 14 else None,
started_at=row[15] if len(row) > 15 else None,
completed_at=row[16] if len(row) > 16 else None,
scan_mode=row[3],
start_date=row[4],
end_date=row[5],
created_at=row[6],
updated_at=row[7],
status=row[8],
total_routes=row[9],
routes_scanned=row[10],
total_flights=row[11],
error_message=row[12],
seat_class=row[13],
adults=row[14],
scheduled_scan_id=row[15] if len(row) > 15 else None,
started_at=row[16] if len(row) > 16 else None,
completed_at=row[17] if len(row) > 17 else None,
)
logging.info(f"Scan created: ID={scan_id}, origin={scan.origin}, country={scan.country}, dates={scan.start_date} to {scan.end_date}")
@@ -1367,7 +1392,7 @@ async def list_scans(
# Get paginated results
offset = (page - 1) * limit
query = f"""
SELECT id, origin, country, start_date, end_date,
SELECT id, origin, country, scan_mode, start_date, end_date,
created_at, updated_at, status, total_routes,
routes_scanned, total_flights, error_message,
seat_class, adults, scheduled_scan_id,
@@ -1388,20 +1413,21 @@ async def list_scans(
id=row[0],
origin=row[1],
country=row[2],
start_date=row[3],
end_date=row[4],
created_at=row[5],
updated_at=row[6],
status=row[7],
total_routes=row[8],
routes_scanned=row[9],
total_flights=row[10],
error_message=row[11],
seat_class=row[12],
adults=row[13],
scheduled_scan_id=row[14] if len(row) > 14 else None,
started_at=row[15] if len(row) > 15 else None,
completed_at=row[16] if len(row) > 16 else None,
scan_mode=row[3],
start_date=row[4],
end_date=row[5],
created_at=row[6],
updated_at=row[7],
status=row[8],
total_routes=row[9],
routes_scanned=row[10],
total_flights=row[11],
error_message=row[12],
seat_class=row[13],
adults=row[14],
scheduled_scan_id=row[15] if len(row) > 15 else None,
started_at=row[16] if len(row) > 16 else None,
completed_at=row[17] if len(row) > 17 else None,
))
# Build pagination metadata
@@ -1439,7 +1465,7 @@ async def get_scan_status(scan_id: int):
cursor = conn.cursor()
cursor.execute("""
SELECT id, origin, country, start_date, end_date,
SELECT id, origin, country, scan_mode, start_date, end_date,
created_at, updated_at, status, total_routes,
routes_scanned, total_flights, error_message,
seat_class, adults, scheduled_scan_id,
@@ -1461,20 +1487,21 @@ async def get_scan_status(scan_id: int):
id=row[0],
origin=row[1],
country=row[2],
start_date=row[3],
end_date=row[4],
created_at=row[5],
updated_at=row[6],
status=row[7],
total_routes=row[8],
routes_scanned=row[9],
total_flights=row[10],
error_message=row[11],
seat_class=row[12],
adults=row[13],
scheduled_scan_id=row[14] if len(row) > 14 else None,
started_at=row[15] if len(row) > 15 else None,
completed_at=row[16] if len(row) > 16 else None,
scan_mode=row[3],
start_date=row[4],
end_date=row[5],
created_at=row[6],
updated_at=row[7],
status=row[8],
total_routes=row[9],
routes_scanned=row[10],
total_flights=row[11],
error_message=row[12],
seat_class=row[13],
adults=row[14],
scheduled_scan_id=row[15] if len(row) > 15 else None,
started_at=row[16] if len(row) > 16 else None,
completed_at=row[17] if len(row) > 17 else None,
)
except HTTPException:
@@ -1716,7 +1743,7 @@ async def get_scan_routes(
# Get paginated results
offset = (page - 1) * limit
cursor.execute("""
SELECT id, scan_id, destination, destination_name, destination_city,
SELECT id, scan_id, origin_airport, destination, destination_name, destination_city,
flight_count, airlines, min_price, max_price, avg_price, created_at
FROM routes
WHERE scan_id = ?
@@ -1736,13 +1763,13 @@ async def get_scan_routes(
for row in rows:
# Parse airlines JSON
try:
airlines = json.loads(row[6]) if row[6] else []
airlines = json.loads(row[7]) if row[7] else []
except:
airlines = []
dest = row[2]
dest_name = row[3] or dest
dest_city = row[4] or ''
dest = row[3]
dest_name = row[4] or dest
dest_city = row[5] or ''
# If name was never resolved (stored as IATA code), look it up now
if dest_name == dest:
@@ -1753,15 +1780,16 @@ async def get_scan_routes(
routes.append(Route(
id=row[0],
scan_id=row[1],
origin_airport=row[2],
destination=dest,
destination_name=dest_name,
destination_city=dest_city,
flight_count=row[5],
flight_count=row[6],
airlines=airlines,
min_price=row[7],
max_price=row[8],
avg_price=row[9],
created_at=row[10]
min_price=row[8],
max_price=row[9],
avg_price=row[10],
created_at=row[11]
))
# Build pagination metadata
@@ -1791,14 +1819,15 @@ async def get_scan_routes(
async def get_scan_flights(
scan_id: int,
destination: Optional[str] = Query(None, min_length=3, max_length=3, description="Filter by destination IATA code"),
origin_airport: Optional[str] = Query(None, min_length=3, max_length=3, description="Filter by origin airport IATA code (reverse scans)"),
page: int = Query(1, ge=1, description="Page number"),
limit: int = Query(50, ge=1, le=200, description="Items per page")
):
"""
Get individual flights discovered by a specific scan.
Optionally filter by destination airport code.
Results are ordered by price ascending.
Optionally filter by destination and/or origin airport code.
Results are ordered by date then price ascending.
"""
try:
conn = get_connection()
@@ -1809,45 +1838,41 @@ async def get_scan_flights(
conn.close()
raise HTTPException(status_code=404, detail=f"Scan not found: {scan_id}")
# Build dynamic WHERE clause
conditions = ["scan_id = ?"]
params: list = [scan_id]
if destination:
cursor.execute(
"SELECT COUNT(*) FROM flights WHERE scan_id = ? AND destination = ?",
(scan_id, destination.upper())
)
else:
cursor.execute("SELECT COUNT(*) FROM flights WHERE scan_id = ?", (scan_id,))
conditions.append("destination = ?")
params.append(destination.upper())
if origin_airport:
conditions.append("origin_airport = ?")
params.append(origin_airport.upper())
where = " AND ".join(conditions)
cursor.execute(f"SELECT COUNT(*) FROM flights WHERE {where}", params)
total = cursor.fetchone()[0]
total_pages = math.ceil(total / limit) if total > 0 else 0
offset = (page - 1) * limit
if destination:
cursor.execute("""
SELECT id, scan_id, destination, date, airline,
departure_time, arrival_time, price, stops
FROM flights
WHERE scan_id = ? AND destination = ?
ORDER BY date ASC, price ASC
LIMIT ? OFFSET ?
""", (scan_id, destination.upper(), limit, offset))
else:
cursor.execute("""
SELECT id, scan_id, destination, date, airline,
departure_time, arrival_time, price, stops
FROM flights
WHERE scan_id = ?
ORDER BY date ASC, price ASC
LIMIT ? OFFSET ?
""", (scan_id, limit, offset))
cursor.execute(f"""
SELECT id, scan_id, origin_airport, destination, date, airline,
departure_time, arrival_time, price, stops
FROM flights
WHERE {where}
ORDER BY date ASC, price ASC
LIMIT ? OFFSET ?
""", params + [limit, offset])
rows = cursor.fetchall()
conn.close()
flights = [
Flight(
id=row[0], scan_id=row[1], destination=row[2], date=row[3],
airline=row[4], departure_time=row[5], arrival_time=row[6],
price=row[7], stops=row[8]
id=row[0], scan_id=row[1], origin_airport=row[2],
destination=row[3], date=row[4], airline=row[5],
departure_time=row[6], arrival_time=row[7],
price=row[8], stops=row[9]
)
for row in rows
]
@@ -1965,7 +1990,8 @@ async def get_flights_stub(route_id: str):
class CreateScheduleRequest(BaseModel):
"""Request body for creating or updating a scheduled scan."""
origin: str = Field(..., description="Origin airport IATA code (3 letters)")
scan_mode: str = Field('forward', description="Scan direction: 'forward' or 'reverse'")
origin: str = Field(..., description="Origin airport IATA code (forward) or ISO country code (reverse)")
country: str = Field(..., description="Destination country ISO code (2 letters) or comma-separated IATA codes")
window_months: int = Field(1, ge=1, le=12, description="Months of data per scan run")
seat_class: str = Field('economy', description="Seat class")
@@ -2027,6 +2053,7 @@ class UpdateScheduleRequest(BaseModel):
class Schedule(BaseModel):
"""A recurring scheduled scan."""
id: int
scan_mode: str
origin: str
country: str
window_months: int
@@ -2049,6 +2076,7 @@ def _row_to_schedule(row, recent_scan_ids: list) -> Schedule:
"""Convert a DB row (sqlite3.Row or tuple) to a Schedule model."""
return Schedule(
id=row['id'],
scan_mode=row['scan_mode'] if 'scan_mode' in row.keys() else 'forward',
origin=row['origin'],
country=row['country'],
window_months=row['window_months'],
@@ -2126,12 +2154,12 @@ async def create_schedule(request: CreateScheduleRequest):
conn = get_connection()
conn.execute("""
INSERT INTO scheduled_scans (
origin, country, window_months, seat_class, adults,
scan_mode, origin, country, window_months, seat_class, adults,
label, frequency, hour, minute, day_of_week, day_of_month,
enabled, next_run_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?)
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?)
""", (
request.origin, request.country, request.window_months,
request.scan_mode, request.origin, request.country, request.window_months,
request.seat_class, request.adults, request.label,
request.frequency, request.hour, request.minute,
request.day_of_week, request.day_of_month, next_run_str,
@@ -2279,11 +2307,13 @@ async def run_schedule_now(schedule_id: int):
conn.execute("""
INSERT INTO scans (
origin, country, start_date, end_date,
origin, country, scan_mode, start_date, end_date,
status, seat_class, adults, scheduled_scan_id
) VALUES (?, ?, ?, ?, 'pending', ?, ?, ?)
) VALUES (?, ?, ?, ?, ?, 'pending', ?, ?, ?)
""", (
row['origin'], row['country'], start_date, end_date,
row['origin'], row['country'],
row['scan_mode'] if 'scan_mode' in row.keys() else 'forward',
start_date, end_date,
row['seat_class'], row['adults'], schedule_id,
))
conn.commit()