36 lines
1.5 KiB
Python
36 lines
1.5 KiB
Python
"""Authentication routes: login and current-user lookup."""
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.orm import Session
|
|
from database import get_db
|
|
from models import User
|
|
from schemas import LoginRequest, Token, UserOut
|
|
from auth import verify_password, create_access_token, get_current_user
|
|
|
|
# Router for the authentication endpoints. Routes registered on it are served
# under the /api/auth prefix and tagged "认证" ("authentication") in the
# generated OpenAPI docs.
router = APIRouter(tags=["认证"], prefix="/api/auth")
|
|
|
|
|
|
@router.post("/login", response_model=Token)
def login(req: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate a user by username and password and issue a bearer token.

    Args:
        req: Login payload providing ``username`` and ``password``.
        db: Database session injected via ``get_db``.

    Returns:
        A mapping with ``access_token`` and ``token_type`` ("bearer"),
        serialized through the ``Token`` response model.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge)
            when no user has this username or the password does not match;
            403 when the matched account's ``is_active`` is falsy.
    """
    user = db.query(User).filter(User.username == req.username).first()

    # One message covers both "no such user" and "wrong password", so the
    # response body does not reveal which usernames exist.
    # NOTE(review): the unknown-user path skips verify_password, so response
    # timing may still hint at whether a username exists. Consider verifying
    # against a dummy hash in that case.
    if not user or not verify_password(req.password, user.password_hash):
        # RFC 7235 requires every 401 response to carry a WWW-Authenticate
        # challenge; "Bearer" matches the token type issued below.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Checked only after the password matches, so a deactivated account is
    # not disclosed to callers who do not know its credentials.
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="账号已停用")

    # The token subject is the user's id, stringified.
    token = create_access_token(data={"sub": str(user.id)})
    return {"access_token": token, "token_type": "bearer"}
|
|
|
|
|
|
def _unwrap_enum(raw):
    """Return ``raw.value`` if *raw* has a ``value`` attribute (e.g. an Enum member), else *raw*."""
    return getattr(raw, "value", raw)


@router.get("/me", response_model=UserOut)
def get_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the user resolved by ``get_current_user``.

    ``phase_group`` and ``role`` are unwrapped to their ``.value`` when they
    expose one (Enum-like columns), so the response carries plain values.
    The other fields are copied from the user object unchanged.
    """
    me = current_user
    fields = dict(
        id=me.id,
        username=me.username,
        name=me.name,
        phase_group=_unwrap_enum(me.phase_group),
        role=_unwrap_enum(me.role),
        monthly_salary=me.monthly_salary,
        daily_cost=me.daily_cost,
        is_active=me.is_active,
        created_at=me.created_at,
    )
    return UserOut(**fields)
|