""" FastAPI API 测试 - 测试修复报告相关接口的新字段和过滤功能 """ import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import sessionmaker from app.main import app from app.database import get_session from app.models import ErrorLog, LogStatus, RepairTask # 使用 SQLite 内存数据库 TEST_DATABASE_URL = "sqlite+aiosqlite:///./test.db" test_engine = create_async_engine(TEST_DATABASE_URL, echo=False) async def override_get_session(): async_session = sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: yield session app.dependency_overrides[get_session] = override_get_session @pytest_asyncio.fixture(autouse=True) async def setup_db(): """每个测试前重建数据库""" async with test_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.drop_all) await conn.run_sync(SQLModel.metadata.create_all) yield async with test_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.drop_all) @pytest_asyncio.fixture async def seed_data(): """插入测试数据""" async_session = sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: # 创建两个 ErrorLog log1 = ErrorLog( id=1, project_id="rtc_backend", environment="production", level="ERROR", error_type="ValueError", error_message="test error 1", file_path="app/views.py", line_number=42, stack_trace=["line1", "line2"], fingerprint="fp001", status=LogStatus.FIX_FAILED, ) log2 = ErrorLog( id=2, project_id="rtc_backend", environment="production", level="ERROR", error_type="TypeError", error_message="test error 2", file_path="app/models.py", line_number=10, stack_trace=["line1"], fingerprint="fp002", status=LogStatus.FIXED, ) session.add(log1) session.add(log2) await session.commit() # 创建 RepairTask 记录(含新字段) task1 = RepairTask( id=1, error_log_id=1, status=LogStatus.FIXING, project_id="rtc_backend", ai_analysis="round 1 analysis", fix_plan="plan", code_diff="diff1", modified_files=["file1.py"], test_output="FAILED test_foo", test_passed=False, repair_round=1, failure_reason="测试未通过 (第 1/3 轮)", ) task2 = RepairTask( id=2, error_log_id=1, status=LogStatus.FIX_FAILED, project_id="rtc_backend", ai_analysis="round 2 analysis", fix_plan="plan", code_diff="diff2", modified_files=["file1.py"], test_output="FAILED test_foo", test_passed=False, repair_round=2, failure_reason="测试未通过 (第 2/3 轮)", ) task3 = RepairTask( id=3, error_log_id=2, status=LogStatus.FIXED, project_id="rtc_backend", ai_analysis="fixed analysis", fix_plan="plan", code_diff="diff3", modified_files=["file2.py"], test_output="OK", test_passed=True, repair_round=1, failure_reason=None, ) session.add_all([task1, task2, task3]) await session.commit() @pytest.mark.asyncio async def test_create_repair_report_with_new_fields(): """测试创建修复报告时包含 repair_round 和 failure_reason""" async_session = sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: log = ErrorLog( id=10, project_id="rtc_backend", environment="production", level="ERROR", error_type="ValueError", error_message="test", file_path="x.py", line_number=1, stack_trace=[], fingerprint="fp_test_create", status=LogStatus.FIXING, ) session.add(log) await session.commit() transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.post("/api/v1/repair/reports", json={ "error_log_id": 10, "status": "FIX_FAILED", "project_id": "rtc_backend", "ai_analysis": "analysis content", "fix_plan": "fix plan", "code_diff": "some diff", "modified_files": ["a.py"], "test_output": "FAILED: test_something", "test_passed": False, "repair_round": 2, "failure_reason": "测试未通过 (第 2/3 轮)", }) assert resp.status_code == 200 data = resp.json() assert data["message"] == "Report uploaded" report_id = data["id"] resp2 = await client.get(f"/api/v1/repair/reports/{report_id}") assert resp2.status_code == 200 report = resp2.json() assert report["repair_round"] == 2 assert report["failure_reason"] == "测试未通过 (第 2/3 轮)" assert report["test_passed"] is False @pytest.mark.asyncio async def test_create_repair_report_success_no_failure_reason(): """测试成功报告的 failure_reason 为 null""" async_session = sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) async with async_session() as session: log = ErrorLog( id=11, project_id="rtc_backend", environment="production", level="ERROR", error_type="ValueError", error_message="test", file_path="x.py", line_number=1, stack_trace=[], fingerprint="fp_test_success", status=LogStatus.FIXING, ) session.add(log) await session.commit() transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.post("/api/v1/repair/reports", json={ "error_log_id": 11, "status": "FIXED", "project_id": "rtc_backend", "ai_analysis": "fixed it", "fix_plan": "plan", "code_diff": "diff", "modified_files": ["b.py"], "test_output": "OK all passed", "test_passed": True, "repair_round": 1, }) assert resp.status_code == 200 report_id = resp.json()["id"] resp2 = await client.get(f"/api/v1/repair/reports/{report_id}") report = resp2.json() assert report["repair_round"] == 1 assert report["failure_reason"] is None assert report["test_passed"] is True @pytest.mark.asyncio async def test_filter_repair_reports_by_error_log_id(seed_data): """测试按 error_log_id 过滤修复报告""" transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.get("/api/v1/repair/reports", params={"error_log_id": 1}) assert resp.status_code == 200 data = resp.json() assert data["total"] == 2 assert len(data["items"]) == 2 for item in data["items"]: assert item["error_log_id"] == 1 resp2 = await client.get("/api/v1/repair/reports", params={"error_log_id": 2}) data2 = resp2.json() assert data2["total"] == 1 assert data2["items"][0]["error_log_id"] == 2 assert data2["items"][0]["repair_round"] == 1 assert data2["items"][0]["failure_reason"] is None @pytest.mark.asyncio async def test_filter_repair_reports_no_results(seed_data): """测试按不存在的 error_log_id 查询返回空""" transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.get("/api/v1/repair/reports", params={"error_log_id": 999}) assert resp.status_code == 200 data = resp.json() assert data["total"] == 0 assert data["items"] == [] @pytest.mark.asyncio async def test_repair_report_detail_has_new_fields(seed_data): """测试修复报告详情包含新字段""" transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: resp = await client.get("/api/v1/repair/reports/1") assert resp.status_code == 200 report = resp.json() assert report["repair_round"] == 1 assert report["failure_reason"] == "测试未通过 (第 1/3 轮)" resp2 = await client.get("/api/v1/repair/reports/3") report2 = resp2.json() assert report2["repair_round"] == 1 assert report2["failure_reason"] is None assert report2["test_passed"] is True