"""Async PostgreSQL engine, schema bootstrap, and request-scoped session dependency.

Connection settings come from the environment (a ``.env`` file is loaded first):
``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``, ``DB_PORT`` (default ``5432``),
``DB_NAME`` and ``DB_ECHO`` (default ``true``; set to ``false`` to stop SQL echo).
"""

import logging
import os
from typing import AsyncGenerator, Optional
from urllib.parse import quote

from dotenv import load_dotenv
from sqlalchemy import text
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, create_engine
from sqlmodel.ext.asyncio.session import AsyncSession

logger = logging.getLogger(__name__)

load_dotenv()

DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
# SQL statement echo stays on by default (the previous hard-coded behaviour).
DB_ECHO = os.getenv("DB_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}


def _url_escape(value: Optional[str]) -> str:
    """Percent-encode *value* for the user/password part of a database URL.

    Without escaping, a password containing ``@``, ``:``, ``/`` or ``%`` would
    corrupt the URL; SQLAlchemy decodes these escapes when it parses the URL.
    ``None`` is rendered as the text ``None``, exactly as the plain f-string did.
    """
    return quote(str(value), safe="")


DATABASE_URL = (
    f"postgresql+asyncpg://{_url_escape(DB_USER)}:{_url_escape(DB_PASSWORD)}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

engine = create_async_engine(
    DATABASE_URL,
    echo=DB_ECHO,
    future=True,
    pool_pre_ping=True,  # test pooled connections on checkout, replacing dead ones
    pool_recycle=180,  # seconds before a pooled connection is replaced
    pool_size=5,
    max_overflow=10,
    pool_timeout=30,
)

# Additive schema changes for databases whose tables predate newer model
# fields: ``create_all`` only creates missing tables and never alters existing
# ones. Every statement is written to be safe to re-run on each startup.
_MIGRATIONS = (
    # repairtask: new columns
    "ALTER TABLE repairtask ADD COLUMN IF NOT EXISTS repair_round INTEGER DEFAULT 1",
    "ALTER TABLE repairtask ADD COLUMN IF NOT EXISTS failure_reason TEXT",
    # errorlog: log source support
    "ALTER TABLE errorlog ADD COLUMN IF NOT EXISTS source VARCHAR(20) DEFAULT 'runtime'",
    "ALTER TABLE errorlog ALTER COLUMN file_path DROP NOT NULL",
    "ALTER TABLE errorlog ALTER COLUMN line_number DROP NOT NULL",
    "CREATE INDEX IF NOT EXISTS ix_errorlog_source ON errorlog (source)",
    # errorlog: failure_reason
    "ALTER TABLE errorlog ADD COLUMN IF NOT EXISTS failure_reason TEXT",
    # errorlog: bug severity (1-10, AI-assessed rating) and its justification
    "ALTER TABLE errorlog ADD COLUMN IF NOT EXISTS severity INTEGER",
    "ALTER TABLE errorlog ADD COLUMN IF NOT EXISTS severity_reason TEXT",
    # Seed the project table from project ids already present in errorlog
    """INSERT INTO project (project_id, created_at, updated_at) SELECT DISTINCT e.project_id, NOW(), NOW() FROM errorlog e WHERE NOT EXISTS (SELECT 1 FROM project p WHERE p.project_id = e.project_id)""",
)


async def init_db() -> None:
    """Create any missing tables, then apply every statement in ``_MIGRATIONS``.

    All work runs in a single transaction that commits when the block exits.
    Each migration runs inside its own SAVEPOINT. In PostgreSQL a failed
    statement aborts the enclosing transaction, so without the savepoint:
    - every later statement would fail too;
    - the final COMMIT would end as a rollback, discarding ``create_all``.
    Instead, a failing migration is rolled back on its own and logged, and the
    remaining ones still run.
    """
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
        for sql in _MIGRATIONS:
            try:
                async with conn.begin_nested():
                    await conn.execute(text(sql))
            except DBAPIError as exc:
                logger.warning("Skipping failed migration %r: %s", sql, exc)


_async_session_factory = sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)

# Total number of connection attempts made per get_session() call.
_MAX_RETRIES = 2
_CONNECTION_ERRORS = (DBAPIError, OperationalError, ConnectionResetError, OSError)


async def _ensure_database_reachable() -> None:
    """Check that a database connection can be obtained, retrying on failure.

    Checking a connection out of the pool serves as the probe:
    - a new connection is only returned once it has connected;
    - a pooled one is pinged first (``pool_pre_ping=True``).
    Up to ``_MAX_RETRIES`` attempts are made. After a failed attempt the pool
    is disposed, so the next attempt opens fresh connections.

    Raises:
        The last connection error, once every attempt has failed.
    """
    for attempt in range(1, _MAX_RETRIES + 1):
        try:
            async with engine.connect():
                pass
        except _CONNECTION_ERRORS as e:
            if attempt >= _MAX_RETRIES:
                raise
            logger.warning(
                "Database connection error (attempt %d/%d): %s",
                attempt,
                _MAX_RETRIES,
                e,
            )
            await engine.dispose()
        else:
            return


async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` and close it when the consumer is finished.

    Connectivity is verified, with retries, *before* the session is yielded.
    Exceptions thrown back in at the ``yield`` (as wrappers such as
    ``contextlib.asynccontextmanager`` do) propagate after the session is
    closed, and are deliberately not retried, for two reasons:
    - a generator driven that way may yield only once;
    - a ``DBAPIError`` raised there may be an ordinary statement error
      (e.g. ``IntegrityError``) rather than a connection problem.
    """
    await _ensure_database_reachable()
    async with _async_session_factory() as session:
        yield session