diff --git a/app/database.py b/app/database.py index 441149f..79b26c9 100644 --- a/app/database.py +++ b/app/database.py @@ -16,16 +16,38 @@ DB_NAME = os.getenv("DB_NAME") DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" -engine = create_async_engine( - DATABASE_URL, - echo=True, - future=True, - pool_pre_ping=True, - pool_recycle=300, -) +_engine = None + + +def get_engine(): + """Lazy engine creation to ensure pool is bound to the current event loop.""" + global _engine + if _engine is None: + _engine = create_async_engine( + DATABASE_URL, + echo=True, + future=True, + pool_pre_ping=True, + pool_recycle=300, + ) + return _engine + + +async def dispose_engine(): + """Dispose engine and reset so next call creates a fresh one.""" + global _engine + if _engine is not None: + await _engine.dispose() + _engine = None + + +# Module-level alias for backward compatibility +engine = None # Use get_engine() instead + async def init_db(): - async with engine.begin() as conn: + eng = get_engine() + async with eng.begin() as conn: # await conn.run_sync(SQLModel.metadata.drop_all) await conn.run_sync(SQLModel.metadata.create_all) @@ -59,7 +81,7 @@ async def init_db(): async def get_session() -> AsyncSession: async_session = sessionmaker( - engine, class_=AsyncSession, expire_on_commit=False + get_engine(), class_=AsyncSession, expire_on_commit=False ) async with async_session() as session: yield session diff --git a/app/main.py b/app/main.py index 9bfbfb3..3e0ce13 100644 --- a/app/main.py +++ b/app/main.py @@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel import select, func, text -from .database import init_db, get_session, engine +from .database import init_db, get_session, get_engine from .models import ErrorLog, ErrorLogCreate, LogStatus, TaskStatusUpdate, RepairTask, RepairTaskCreate, Project, ProjectUpdate from .gitea_client import GiteaClient from .self_report import self_report_error @@ -49,7 +49,7 @@ async def _register_self_projects(): "description": "日志中台 React 管理端", }, ] - async_session = sa_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async_session = sa_sessionmaker(get_engine(), class_=AsyncSession, expire_on_commit=False) async with async_session() as session: for proj_data in projects: stmt = select(Project).where(Project.project_id == proj_data["project_id"]) diff --git a/app/self_report.py b/app/self_report.py index 29cbabf..4c7706f 100644 --- a/app/self_report.py +++ b/app/self_report.py @@ -9,7 +9,7 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.orm import sessionmaker -from .database import engine +from .database import get_engine from .models import ErrorLog, LogStatus, Project PROJECT_ID = "log_center_api" @@ -35,7 +35,7 @@ async def self_report_error(exc: Exception, context: dict = None): raw = f"{PROJECT_ID}|{error_type}|{file_path}|{line_number}" fingerprint = hashlib.md5(raw.encode()).hexdigest() - async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async_session = sessionmaker(get_engine(), class_=AsyncSession, expire_on_commit=False) async with async_session() as session: # 去重检查 stmt = select(ErrorLog).where(ErrorLog.fingerprint == fingerprint)