from datetime import timedelta
from decimal import Decimal, InvalidOperation

from django.db import transaction
from django.db.models import Sum
from django.db.models.functions import TruncDate
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from apps.ai.models import AITask
from apps.common.api import get_current_team

from .models import CreditAccount, CreditLedger
from .serializers import CreditAccountSerializer, CreditLedgerSerializer

# Maps AITask.task_type to the four aggregate buckets used by the account
# page's "distribution by stage" breakdown.
_STAGE_BUCKET = {
    AITask.Type.SCRIPT_GENERATION: "script",
    AITask.Type.SCRIPT_OPTIMIZATION: "script",
    AITask.Type.PRODUCT_IMAGE: "base",
    AITask.Type.PERSON_IMAGE: "base",
    AITask.Type.SCENE_IMAGE: "base",
    AITask.Type.STORYBOARD: "storyboard",
    AITask.Type.VIDEO_SEGMENT: "video",
    AITask.Type.EXPORT: "video",
}


def _int_query_param(request, name, default):
    """Return query parameter *name* as an int, or *default* if absent or unparseable."""
    try:
        return int(request.query_params.get(name, default))
    except (TypeError, ValueError):
        return default


@api_view(["GET"])
@permission_classes([IsAuthenticated])
def summary(request):
    """Return the current team's credit account and the sum of its CHARGE ledger amounts.

    The account row is created on first access. ``charged_total`` falls back
    to ``0`` when the team has no CHARGE rows (the aggregate yields ``None``).
    """
    team = get_current_team(request.user)
    account, _ = CreditAccount.objects.get_or_create(team=team)
    charged = CreditLedger.objects.filter(team=team, ledger_type=CreditLedger.Type.CHARGE).aggregate(
        total=Sum("amount")
    )["total"] or 0
    return Response(
        {
            "account": CreditAccountSerializer(account).data,
            "charged_total": charged,
        }
    )


@api_view(["GET"])
@permission_classes([IsAuthenticated])
def ledgers(request):
    """Return one page of the current team's ledger rows, newest first.

    Query parameters:
        project: optional project id to filter by.
        user: optional user id to filter by.
        page: 1-based page number (default 1; values below 1 become 1).
        page_size: rows per page (default 10; clamped to 1..100).

    Response keys: ``count`` (total matching rows), ``page``, ``page_size``,
    ``results`` (serialized rows of the requested page).
    """
    team = get_current_team(request.user)
    # "-pk" is a tiebreaker: created_at alone is not unique, and without a
    # total order rows sharing a timestamp may repeat or go missing across pages.
    queryset = (
        CreditLedger.objects.filter(team=team)
        .select_related("user", "project", "task")
        .order_by("-created_at", "-pk")
    )
    project_id = request.query_params.get("project")
    user_id = request.query_params.get("user")
    if project_id:
        queryset = queryset.filter(project_id=project_id)
    if user_id:
        queryset = queryset.filter(user_id=user_id)

    # Server-side pagination: the total grows with the ledger (a previous
    # hard-coded [:100] slice always capped the result at 100 rows).
    page = max(1, _int_query_param(request, "page", 1))
    page_size = max(1, min(_int_query_param(request, "page_size", 10), 100))

    total = queryset.count()
    start = (page - 1) * page_size
    rows = queryset[start:start + page_size]
    return Response(
        {
            "count": total,
            "page": page,
            "page_size": page_size,
            "results": CreditLedgerSerializer(rows, many=True).data,
        }
    )


@api_view(["POST"])
@permission_classes([IsAuthenticated])
def recharge(request):
    """Credit ``amount + bonus`` to the current team's account and record a RECHARGE ledger row.

    Request body:
        amount: paid amount; must parse as a finite decimal greater than 0.
        bonus: optional bonus credits; must parse as a finite decimal >= 0.
        channel: optional label, truncated to 32 characters (default "manual").

    Returns 201 with the serialized account and new ledger row, or 400 with a
    ``detail`` message when validation fails.
    """
    # NOTE(review): the only gate visible here is IsAuthenticated, so any
    # authenticated user can credit their team — confirm payment verification
    # or role checks happen elsewhere.
    team = get_current_team(request.user)
    try:
        amount = Decimal(str(request.data.get("amount", "0")))
        bonus = Decimal(str(request.data.get("bonus", "0")))
    except (InvalidOperation, TypeError):
        return Response({"detail": "invalid amount"}, status=status.HTTP_400_BAD_REQUEST)
    # Decimal() accepts "NaN" and "Infinity". Ordering comparisons against NaN
    # raise InvalidOperation, and Infinity would pass the positivity check and
    # be added to the balance, so reject non-finite values up front.
    if not (amount.is_finite() and bonus.is_finite()):
        return Response({"detail": "invalid amount"}, status=status.HTTP_400_BAD_REQUEST)
    if amount <= 0:
        return Response({"detail": "amount must be positive"}, status=status.HTTP_400_BAD_REQUEST)
    if bonus < 0:
        return Response({"detail": "bonus cannot be negative"}, status=status.HTTP_400_BAD_REQUEST)

    channel = str(request.data.get("channel") or "manual")[:32]
    credited = amount + bonus

    # Lock the account row so the balance update and the ledger's
    # balance_after snapshot are written in one transaction.
    with transaction.atomic():
        account, _ = CreditAccount.objects.select_for_update().get_or_create(team=team)
        account.balance += credited
        account.save(update_fields=["balance", "updated_at"])
        ledger = CreditLedger.objects.create(
            team=team,
            user=request.user,
            ledger_type=CreditLedger.Type.RECHARGE,
            amount=credited,
            balance_after=account.balance,
            reason="团队充值",
            metadata={"channel": channel, "paid_amount": str(amount), "bonus": str(bonus)},
        )

    return Response(
        {
            "account": CreditAccountSerializer(account).data,
            "ledger": CreditLedgerSerializer(ledger).data,
        },
        status=status.HTTP_201_CREATED,
    )


@api_view(["GET"])
@permission_classes([IsAuthenticated])
def trend(request):
    """Account-page spending analysis built from the team's CHARGE ledger rows.

    Query parameters:
        range: "day" (default; last 14 days), "week" (last 8 weeks, each
            starting on Monday) or "month" (last 6 calendar months). Any other
            value is treated as "day".

    Response keys:
        daily: list of ``{"date", "label", "amount"}`` buckets for the chosen range.
        total_14d / avg / peak: sum, mean (4 decimal places) and maximum of the
            bucket amounts. ``total_14d`` keeps its name for every range.
        by_stage: current-month amounts per stage bucket (see ``_STAGE_BUCKET``).
        by_project: current-month amounts keyed by project id.

    All amounts are returned as strings.
    """
    team = get_current_team(request.user)
    today = timezone.localdate()
    rng = request.query_params.get("range", "day")
    charges = CreditLedger.objects.filter(team=team, ledger_type=CreditLedger.Type.CHARGE)

    def _daily_amounts(win_start):
        """Return ``{date: summed amount}`` for charges dated on or after *win_start*."""
        rows = (
            charges.filter(created_at__date__gte=win_start)
            .annotate(day=TruncDate("created_at"))
            .values("day")
            .annotate(amount=Sum("amount"))
            # Explicit ordering keeps any default model ordering out of the
            # GROUP BY, so each day yields exactly one row.
            .order_by("day")
        )
        return {row["day"]: row["amount"] or Decimal("0") for row in rows}

    # Pick the window and bucketing by range: day = last 14 days /
    # week = last 8 weeks / month = last 6 calendar months (gaps filled with 0).
    series = []
    if rng == "week":
        monday = today - timedelta(days=today.weekday())
        starts = [monday - timedelta(weeks=(7 - i)) for i in range(8)]
        amt_by_day = _daily_amounts(starts[0])
        for s in starts:
            total = sum(
                (amt_by_day.get(s + timedelta(days=k), Decimal("0")) for k in range(7)),
                Decimal("0"),
            )
            series.append({"date": s.isoformat(), "label": s.strftime("%m/%d"), "amount": str(total)})
    elif rng == "month":
        # Walk back from the current month, then reverse to oldest-first.
        seq = []
        y, m = today.year, today.month
        for _ in range(6):
            seq.append((y, m))
            m -= 1
            if m == 0:
                m, y = 12, y - 1
        seq.reverse()
        amt_by_day = _daily_amounts(today.replace(year=seq[0][0], month=seq[0][1], day=1))
        for yy, mm in seq:
            total = sum(
                (v for d, v in amt_by_day.items() if d.year == yy and d.month == mm),
                Decimal("0"),
            )
            series.append({"date": f"{yy}-{mm:02d}-01", "label": f"{mm}月", "amount": str(total)})
    else:
        start = today - timedelta(days=13)
        amt_by_day = _daily_amounts(start)
        for i in range(14):
            d = start + timedelta(days=i)
            series.append(
                {
                    "date": d.isoformat(),
                    "label": d.strftime("%m/%d"),
                    "amount": str(amt_by_day.get(d, Decimal("0"))),
                }
            )

    daily = series
    total_14d = sum((Decimal(s["amount"]) for s in series), Decimal("0"))
    peak = max((Decimal(s["amount"]) for s in series), default=Decimal("0"))
    avg = (total_14d / len(series)).quantize(Decimal("0.0001")) if series else Decimal("0")

    # Current-month distribution by stage (task.task_type -> 4 buckets) and by project.
    month_start = today.replace(day=1)
    month_charges = charges.filter(created_at__date__gte=month_start).select_related("task")
    by_stage = {
        "script": Decimal("0"),
        "base": Decimal("0"),
        "storyboard": Decimal("0"),
        "video": Decimal("0"),
    }
    project_amounts: dict[str, Decimal] = {}
    for row in month_charges:
        task = row.task
        # Rows without a task, or with an unmapped task type, count toward no stage.
        bucket = _STAGE_BUCKET.get(task.task_type) if task else None
        if bucket:
            by_stage[bucket] += row.amount
        pid = str(row.project_id) if row.project_id else None
        if pid:
            project_amounts[pid] = project_amounts.get(pid, Decimal("0")) + row.amount

    return Response(
        {
            "daily": daily,
            "total_14d": str(total_14d),
            "avg": str(avg),
            "peak": str(peak),
            "by_stage": {k: str(v) for k, v in by_stage.items()},
            "by_project": {k: str(v) for k, v in project_amounts.items()},
        }
    )