"""Core sync engine: runs the full sync cycle across all eligible replicas.""" import asyncio import hashlib import sqlite3 from dataclasses import dataclass from datetime import datetime, timezone from typing import Optional from sqlmodel import Session, select from ..config import get_config from ..crypto import decrypt from ..database import get_engine from ..logger import emit_log from ..models import Replica, SyncMap, SyncRun from .. import metrics from .paperless import PaperlessClient, PaperlessError _sync_lock = asyncio.Lock() @dataclass class SyncProgress: running: bool = False phase: str = "" docs_done: int = 0 docs_total: int = 0 _progress = SyncProgress() def get_progress() -> SyncProgress: return _progress async def run_sync_cycle( triggered_by: str = "scheduler", replica_id: Optional[int] = None, ) -> bool: """Trigger a sync cycle in the background. Returns False if already running.""" if _sync_lock.locked(): return False asyncio.create_task(_do_sync(triggered_by, replica_id)) return True async def _get_settings() -> dict: from ..scheduler import SETTINGS_DEFAULTS from ..models import Setting with Session(get_engine()) as s: rows = s.exec(select(Setting)).all() result = dict(SETTINGS_DEFAULTS) for row in rows: if row.value is not None: result[row.key] = row.value return result async def _ensure_schema_parity( master: PaperlessClient, replica: PaperlessClient, ) -> dict: """Create missing tags/correspondents/document_types/custom_fields on replica. Returns maps: master_id → replica_id for each entity type.""" master_tags = {t["name"]: t for t in await master.get_tags()} replica_tags = {t["name"]: t for t in await replica.get_tags()} tag_map: dict[int, int] = {} for name, mt in master_tags.items(): rt = replica_tags.get(name) or await replica.create_tag( name, color=mt.get("color", ""), is_inbox_tag=mt.get("is_inbox_tag", False), ) tag_map[mt["id"]] = rt["id"] master_corrs = {c["name"]: c for c in await master.get_correspondents()} replica_corrs = {c["name"]: c for c in await replica.get_correspondents()} corr_map: dict[int, int] = {} for name, mc in master_corrs.items(): rc = replica_corrs.get(name) or await replica.create_correspondent(name) corr_map[mc["id"]] = rc["id"] master_dts = {d["name"]: d for d in await master.get_document_types()} replica_dts = {d["name"]: d for d in await replica.get_document_types()} dt_map: dict[int, int] = {} for name, mdt in master_dts.items(): rdt = replica_dts.get(name) or await replica.create_document_type(name) dt_map[mdt["id"]] = rdt["id"] master_cfs = {cf["name"]: cf for cf in await master.get_custom_fields()} replica_cfs = {cf["name"]: cf for cf in await replica.get_custom_fields()} cf_map: dict[int, int] = {} for name, mcf in master_cfs.items(): rcf = replica_cfs.get(name) or await replica.create_custom_field( name, mcf.get("data_type", "string") ) cf_map[mcf["id"]] = rcf["id"] return { "tags": tag_map, "correspondents": corr_map, "document_types": dt_map, "custom_fields": cf_map, } def _translate_metadata(meta: dict, maps: dict) -> dict: """Translate master entity IDs to replica entity IDs.""" result: dict = { "title": meta.get("title", ""), "created": meta.get("created") or meta.get("created_date"), "archive_serial_number": meta.get("archive_serial_number"), } if meta.get("correspondent") is not None: result["correspondent"] = maps["correspondents"].get(meta["correspondent"]) if meta.get("document_type") is not None: result["document_type"] = maps["document_types"].get(meta["document_type"]) result["tags"] = [ maps["tags"][t] for t in meta.get("tags", []) if t in maps["tags"] ] cf_list = [] for cf_entry in meta.get("custom_fields", []): master_cf_id = cf_entry.get("field") if master_cf_id in maps["custom_fields"]: cf_list.append( { "field": maps["custom_fields"][master_cf_id], "value": cf_entry.get("value"), } ) result["custom_fields"] = cf_list return result def _sha256(data: bytes) -> str: return hashlib.sha256(data).hexdigest() async def _resolve_pending_tasks( replica: PaperlessClient, replica_obj: Replica, task_poll_timeout: int, run_id: int, session: Session, ) -> tuple[int, int]: """Resolve pending sync_map entries. Returns (resolved, failed).""" pending = session.exec( select(SyncMap).where( SyncMap.replica_id == replica_obj.id, SyncMap.status == "pending", SyncMap.task_id.is_not(None), # type: ignore[union-attr] ) ).all() resolved = failed = 0 now = datetime.now(timezone.utc) for entry in pending: try: task = await replica.get_task(entry.task_id) # type: ignore[arg-type] status = task.get("status", "") age_seconds = 0 if entry.last_synced: last = entry.last_synced if last.tzinfo is None: last = last.replace(tzinfo=timezone.utc) age_seconds = (now - last).total_seconds() if not task or age_seconds > task_poll_timeout: entry.status = "error" entry.error_msg = "task timed out" entry.retry_count += 1 session.add(entry) emit_log( "warning", f"Task timed out for doc {entry.master_doc_id}", replica=replica_obj.name, replica_id=replica_obj.id, doc_id=entry.master_doc_id, run_id=run_id, session=session, ) failed += 1 elif status == "SUCCESS": # Extract replica_doc_id from task result related = task.get("related_document") if related is not None: entry.replica_doc_id = int(str(related)) entry.task_id = None entry.status = "ok" entry.last_synced = now session.add(entry) resolved += 1 elif status in ("FAILURE", "REVOKED"): entry.status = "error" entry.error_msg = task.get("result", "task failed")[:500] entry.retry_count += 1 session.add(entry) emit_log( "warning", f"Task failed for doc {entry.master_doc_id}: {entry.error_msg}", replica=replica_obj.name, replica_id=replica_obj.id, doc_id=entry.master_doc_id, run_id=run_id, session=session, ) failed += 1 # else: still PENDING/STARTED — leave it except Exception as e: emit_log( "warning", f"Could not check task for doc {entry.master_doc_id}: {e}", replica=replica_obj.name, replica_id=replica_obj.id, run_id=run_id, session=session, ) if pending: session.commit() return resolved, failed async def _sync_replica( replica_obj: Replica, master: PaperlessClient, changed_docs: list[dict], settings: dict, run_id: int, engine, ) -> tuple[int, int]: """Sync one replica. Returns (docs_synced, docs_failed).""" config = get_config() max_concurrent = int(settings.get("max_concurrent_requests", "4")) task_poll_timeout = int(settings.get("task_poll_timeout_seconds", "600")) replica_token = decrypt(replica_obj.api_token, config.secret_key) replica_semaphore = asyncio.Semaphore(max_concurrent) docs_synced = docs_failed = 0 async with PaperlessClient( replica_obj.url, replica_token, replica_semaphore ) as replica: with Session(engine) as session: # Step 5a: ensure schema parity _progress.phase = f"schema parity — {replica_obj.name}" try: maps = await _ensure_schema_parity(master, replica) except Exception as e: emit_log( "error", f"Schema parity failed: {e}", replica=replica_obj.name, replica_id=replica_obj.id, run_id=run_id, session=session, ) raise # Step 5b: resolve pending tasks _progress.phase = f"resolving tasks — {replica_obj.name}" await _resolve_pending_tasks( replica, replica_obj, task_poll_timeout, run_id, session ) # Step 5c: collect docs to process last_ts = replica_obj.last_sync_ts if last_ts and last_ts.tzinfo is None: last_ts = last_ts.replace(tzinfo=timezone.utc) docs_for_replica = [ d for d in changed_docs if last_ts is None or _parse_dt(d.get("modified", "")) is None or _parse_dt(d.get("modified", "")) >= last_ts ] # Include error-status docs (capped at 50) error_entries = session.exec( select(SyncMap).where( SyncMap.replica_id == replica_obj.id, SyncMap.status == "error", ) ).all()[:50] error_doc_ids = {e.master_doc_id for e in error_entries} existing_ids = {d["id"] for d in docs_for_replica} for e in error_entries: if e.master_doc_id not in existing_ids: docs_for_replica.append({"id": e.master_doc_id, "_retry": True}) _progress.docs_total = len(docs_for_replica) _progress.docs_done = 0 _progress.phase = f"syncing {replica_obj.name}" # Step 5d: process each document for doc_stub in docs_for_replica: doc_id = doc_stub["id"] try: # Fetch full metadata from master meta = await master.get_document(doc_id) file_bytes = await master.download_document(doc_id, original=True) checksum = _sha256(file_bytes) filename = meta.get("original_file_name") or f"document-{doc_id}.pdf" translated = _translate_metadata(meta, maps) existing = session.exec( select(SyncMap).where( SyncMap.replica_id == replica_obj.id, SyncMap.master_doc_id == doc_id, ) ).first() if existing and existing.replica_doc_id is not None and existing.status == "ok": # Update metadata on replica await replica.patch_document(existing.replica_doc_id, translated) existing.last_synced = datetime.now(timezone.utc) existing.file_checksum = checksum session.add(existing) session.commit() docs_synced += 1 emit_log( "info", f"Updated doc {doc_id} → replica {existing.replica_doc_id}", replica=replica_obj.name, replica_id=replica_obj.id, doc_id=doc_id, run_id=run_id, session=session, ) else: # Upload new document task_id = await master_post_to_replica( replica, file_bytes, filename, translated ) now = datetime.now(timezone.utc) if existing: existing.task_id = task_id existing.status = "pending" existing.replica_doc_id = None existing.file_checksum = checksum existing.last_synced = now existing.retry_count = existing.retry_count + 1 session.add(existing) else: entry = SyncMap( replica_id=replica_obj.id, master_doc_id=doc_id, task_id=task_id, status="pending", file_checksum=checksum, last_synced=now, ) session.add(entry) session.commit() emit_log( "info", f"Uploaded doc {doc_id}, task {task_id}", replica=replica_obj.name, replica_id=replica_obj.id, doc_id=doc_id, run_id=run_id, session=session, ) except Exception as e: docs_failed += 1 emit_log( "error", f"Failed to sync doc {doc_id}: {e}", replica=replica_obj.name, replica_id=replica_obj.id, doc_id=doc_id, run_id=run_id, session=session, ) # Mark as error in sync_map existing = session.exec( select(SyncMap).where( SyncMap.replica_id == replica_obj.id, SyncMap.master_doc_id == doc_id, ) ).first() if existing: existing.status = "error" existing.error_msg = str(e)[:500] session.add(existing) session.commit() _progress.docs_done += 1 metrics.docs_total.labels( replica=replica_obj.name, status="ok" if docs_failed == 0 else "error", ).inc() return docs_synced, docs_failed async def master_post_to_replica( replica: PaperlessClient, file_bytes: bytes, filename: str, metadata: dict, ) -> str: return await replica.post_document(file_bytes, filename, metadata) def _parse_dt(s: str) -> datetime | None: if not s: return None try: dt = datetime.fromisoformat(s.replace("Z", "+00:00")) if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) return dt except Exception: return None async def _do_sync(triggered_by: str, target_replica_id: Optional[int]) -> None: global _progress async with _sync_lock: _progress = SyncProgress(running=True, phase="starting") metrics.sync_running.set(1) config = get_config() engine = get_engine() start_time = datetime.now(timezone.utc) run_id: Optional[int] = None try: settings = await _get_settings() master_url = settings.get("master_url", "") master_token_enc = settings.get("master_token", "") if not master_url or not master_token_enc: emit_log("error", "Master URL or token not configured") return master_token = decrypt(master_token_enc, config.secret_key) max_concurrent = int(settings.get("max_concurrent_requests", "4")) sync_cycle_timeout = int(settings.get("sync_cycle_timeout_seconds", "1800")) suspend_threshold = int(settings.get("replica_suspend_threshold", "5")) # Create sync_run record with Session(engine) as session: sync_run = SyncRun( replica_id=target_replica_id, started_at=start_time, triggered_by=triggered_by, ) session.add(sync_run) session.commit() session.refresh(sync_run) run_id = sync_run.id # Determine eligible replicas with Session(engine) as session: stmt = select(Replica).where(Replica.enabled == True) # noqa: E712 if target_replica_id: stmt = stmt.where(Replica.id == target_replica_id) all_replicas = session.exec(stmt).all() now = datetime.now(timezone.utc) eligible: list[Replica] = [] for r in all_replicas: if r.suspended_at is not None: continue if r.sync_interval_seconds is not None and r.last_sync_ts: last = r.last_sync_ts if last.tzinfo is None: last = last.replace(tzinfo=timezone.utc) if (now - last).total_seconds() < r.sync_interval_seconds: continue eligible.append(r) if not eligible: emit_log("info", "No eligible replicas for this cycle") _close_run(engine, run_id, 0, 0, False) return # Find min last_sync_ts for master query. # If ANY eligible replica has never synced, fetch ALL master docs. any_never_synced = any(r.last_sync_ts is None for r in eligible) if any_never_synced: modified_gte = None else: last_sync_times = [r.last_sync_ts for r in eligible] # type: ignore[misc] min_ts = min( (t if t.tzinfo else t.replace(tzinfo=timezone.utc)) for t in last_sync_times ) modified_gte = min_ts.isoformat() master_semaphore = asyncio.Semaphore(max_concurrent) result_container = [0, 0] try: await asyncio.wait_for( _run_all_replicas( eligible=eligible, master_url=master_url, master_token=master_token, master_semaphore=master_semaphore, modified_gte=modified_gte, settings=settings, run_id=run_id, suspend_threshold=suspend_threshold, engine=engine, start_time=start_time, result_container=result_container, ), timeout=sync_cycle_timeout, ) except asyncio.TimeoutError: emit_log( "warning", f"Sync cycle timed out after {sync_cycle_timeout}s", ) _close_run(engine, run_id, 0, 0, True) return _close_run(engine, run_id, result_container[0], result_container[1], False) _do_backup(config.db_path) except Exception as e: emit_log("error", f"Sync cycle crashed: {e}") if run_id: _close_run(engine, run_id, 0, 0, False) finally: elapsed = (datetime.now(timezone.utc) - start_time).total_seconds() metrics.sync_duration.labels(triggered_by=triggered_by).observe(elapsed) metrics.sync_running.set(0) _progress = SyncProgress(running=False) async def _run_all_replicas( *, eligible: list[Replica], master_url: str, master_token: str, master_semaphore: asyncio.Semaphore, modified_gte: str | None, settings: dict, run_id: int, suspend_threshold: int, engine, start_time: datetime, result_container: list, ) -> None: """Fetch changed docs once, then sync each replica.""" _progress.phase = "fetching master documents" async with PaperlessClient(master_url, master_token, master_semaphore) as master: changed_docs = await master.get_all_documents(modified_gte=modified_gte) total_synced = total_failed = 0 for replica_obj in eligible: _progress.phase = f"syncing {replica_obj.name}" try: async with PaperlessClient( master_url, master_token, master_semaphore ) as master: synced, failed = await _sync_replica( replica_obj=replica_obj, master=master, changed_docs=changed_docs, settings=settings, run_id=run_id, engine=engine, ) total_synced += synced total_failed += failed # Update replica success state with Session(engine) as session: r = session.get(Replica, replica_obj.id) if r: r.last_sync_ts = start_time r.consecutive_failures = 0 session.add(r) session.commit() metrics.replica_consecutive_failures.labels(replica=replica_obj.name).set(0) # Check alert threshold alert_threshold = int(settings.get("alert_error_threshold", "5")) if failed >= alert_threshold: await _send_alert( replica_obj, "sync_failures_threshold", {"docs_synced": synced, "docs_failed": failed}, settings, engine, ) except Exception as e: emit_log( "error", f"Replica sync failed: {e}", replica=replica_obj.name, replica_id=replica_obj.id, run_id=run_id, ) total_failed += 1 with Session(engine) as session: r = session.get(Replica, replica_obj.id) if r: r.consecutive_failures += 1 if r.consecutive_failures >= suspend_threshold: r.suspended_at = datetime.now(timezone.utc) emit_log( "error", f"Replica {r.name} suspended after {r.consecutive_failures} consecutive failures", replica=r.name, replica_id=r.id, ) await _send_alert( r, "replica_suspended", {"docs_synced": 0, "docs_failed": 1}, settings, engine, ) session.add(r) session.commit() metrics.replica_consecutive_failures.labels( replica=replica_obj.name ).set(replica_obj.consecutive_failures + 1) # Update Prometheus lag with Session(engine) as session: r = session.get(Replica, replica_obj.id) if r and r.last_sync_ts: ts = r.last_sync_ts if ts.tzinfo is None: ts = ts.replace(tzinfo=timezone.utc) lag = (datetime.now(timezone.utc) - ts).total_seconds() metrics.replica_lag.labels(replica=replica_obj.name).set(lag) result_container[0] = total_synced result_container[1] = total_failed async def _send_alert( replica: Replica, event: str, run_stats: dict, settings: dict, engine, ) -> None: import httpx target_type = settings.get("alert_target_type", "") target_url = settings.get("alert_target_url", "") cooldown = int(settings.get("alert_cooldown_seconds", "3600")) if not target_type or not target_url: return now = datetime.now(timezone.utc) if replica.last_alert_at: last = replica.last_alert_at if last.tzinfo is None: last = last.replace(tzinfo=timezone.utc) if (now - last).total_seconds() < cooldown: return payload = { "event": event, "replica": replica.name, "replica_url": replica.url, "consecutive_failures": replica.consecutive_failures, "docs_failed": run_stats.get("docs_failed", 0), "docs_synced": run_stats.get("docs_synced", 0), "timestamp": now.isoformat(), } config = get_config() token_enc = settings.get("alert_target_token", "") token = decrypt(token_enc, config.secret_key) if token_enc else "" try: async with httpx.AsyncClient(timeout=10.0) as client: if target_type == "gotify": await client.post( f"{target_url}/message", json={ "title": "pngx-controller alert", "message": str(payload), "priority": 7, }, headers={"X-Gotify-Key": token}, ) elif target_type == "webhook": headers = {} if token: headers["Authorization"] = token await client.post(target_url, json=payload, headers=headers) with Session(engine) as session: r = session.get(Replica, replica.id) if r: r.last_alert_at = now session.add(r) session.commit() except Exception as e: emit_log("warning", f"Alert send failed: {e}") def _close_run( engine, run_id: int, synced: int, failed: int, timed_out: bool ) -> None: with Session(engine) as session: sr = session.get(SyncRun, run_id) if sr: sr.finished_at = datetime.now(timezone.utc) sr.docs_synced = synced sr.docs_failed = failed sr.timed_out = timed_out session.add(sr) session.commit() def _do_backup(db_path: str) -> None: """Copy DB to .bak file after a successful sync run.""" import os bak_path = db_path + ".bak" try: import sqlite3 as _sqlite3 src = _sqlite3.connect(db_path) dst = _sqlite3.connect(bak_path) src.backup(dst) dst.close() src.close() except Exception as e: emit_log("warning", f"DB backup failed: {e}")