"""Inbox notification API: seeding helpers and the team-scoped NotificationViewSet."""

from django.db.models import Count, Q
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from apps.assets.models import Asset
from apps.billing.models import CreditAccount
from apps.common.api import TeamScopedViewSetMixin
from apps.projects.models import Project

from .models import Notification
from .serializers import NotificationSerializer


class NotificationPagination(PageNumberPagination):
    """Inbox infinite-scroll pagination.

    Serves 10 notifications per batch; the frontend may override the batch
    size with ``?page_size=`` (capped at 100).
    """

    page_size = 10
    page_size_query_param = "page_size"
    max_page_size = 100


# Display label for each ``Project.current_stage`` key. Built once at import
# time instead of on every ``project_stage_label`` call.
_PROJECT_STAGE_LABELS = {
    "script": "Stage 1 · 脚本",
    "base_assets": "Stage 2 · 基础资产",
    "storyboard": "Stage 3 · 故事板",
    "video": "Stage 4 · 视频",
    "export": "Stage 5 · 导出",
}
# Label used when ``current_stage`` is not one of the keys above.
_DEFAULT_PROJECT_STAGE_LABEL = "Stage 1 · 脚本"


def project_stage_label(project):
    """Return the display label for ``project.current_stage``.

    Unknown stage keys fall back to the first-stage label.
    """
    return _PROJECT_STAGE_LABELS.get(project.current_stage, _DEFAULT_PROJECT_STAGE_LABEL)


def project_priority(project):
    """Map a project's status to a notification priority.

    ``COMPLETED`` -> ``OK``, ``FAILED`` -> ``ERR``, any other status -> ``INFO``.
    """
    if project.status == Project.Status.COMPLETED:
        return Notification.Priority.OK
    if project.status == Project.Status.FAILED:
        return Notification.Priority.ERR
    return Notification.Priority.INFO


def ensure_team_notifications(team, user):
    """Seed ``user``'s inbox for ``team`` from current database state.

    At most once per ``dedupe_key`` it creates:

    * a welcome/system notice;
    * a status notice for each of the 5 most recently updated projects
      (status and stage are part of the key, so every transition yields a
      new notice);
    * an "added to library" notice for each of the 3 most recently updated
      assets;
    * a low-balance warning when the team's credit balance is <= 100.

    Existing notifications are never updated. Note that this performs
    database writes (including ``CreditAccount.get_or_create``) and is called
    from the list endpoint, so every inbox read may write.
    """

    def create_once(dedupe_key, **payload):
        # Keyed on (team, recipient, dedupe_key): repeated calls never
        # duplicate a notice, and ``payload`` only applies on first creation.
        Notification.objects.get_or_create(
            team=team,
            recipient=user,
            dedupe_key=dedupe_key,
            defaults=payload,
        )

    create_once(
        "system:welcome",
        notification_type=Notification.Type.SYSTEM,
        priority=Notification.Priority.INFO,
        title="团队已接入 AirShelf",
        brief="真实消息中心已启用,状态会写入 Django 数据库。",
        body="消息已从演示数据切换为团队级通知表。已读、未读、归档等操作都会持久化保存。",
        source="Airshelf 系统",
        stage="系统公告",
        owner_label="系统",
        cost_label="-",
        related_url="settings.html#sec-notify",
    )

    recent_projects = (
        Project.objects.filter(team=team)
        .select_related("product", "created_by")
        .order_by("-updated_at")[:5]
    )
    for project in recent_projects:
        product_title = project.product.title if project.product_id else "未绑定商品"
        stage_label = project_stage_label(project)
        create_once(
            f"project:{project.id}:status:{project.status}:{project.current_stage}",
            notification_type=Notification.Type.TASK,
            priority=project_priority(project),
            title=f"项目「{project.name}」状态更新",
            brief=f"{product_title} · {stage_label} · {project.get_status_display()}",
            body=f"项目「{project.name}」当前处于 {stage_label}。这条消息来自 Django 项目表,刷新后状态会保持一致。",
            source="视频项目",
            project=project,
            stage=stage_label,
            owner_label=project.created_by.username if project.created_by_id else "成员",
            cost_label="-",
            related_url=f"pipeline.html?project_id={project.id}",
            metadata={"status": project.status, "current_stage": project.current_stage},
        )

    recent_assets = (
        Asset.objects.filter(team=team)
        .select_related("created_by")
        .order_by("-updated_at")[:3]
    )
    for asset in recent_assets:
        create_once(
            f"asset:{asset.id}:created",
            notification_type=Notification.Type.TASK,
            priority=Notification.Priority.OK,
            title=f"资产「{asset.name}」已加入资产库",
            brief=f"{asset.get_category_display()} · {asset.get_asset_type_display()}",
            body="资产记录来自真实资产表。后续上传、AI 生成、导出成片都可以在这里形成团队通知。",
            source="资产库",
            stage="资产入库",
            owner_label=asset.created_by.username if asset.created_by_id else "成员",
            cost_label="-",
            related_url="library.html",
            metadata={"asset_id": str(asset.id), "category": asset.category, "asset_type": asset.asset_type},
        )

    account, _ = CreditAccount.objects.get_or_create(team=team)
    # Warning threshold is inclusive (<= 100). The key is per account, so the
    # warning is created once and not re-issued after a top-up.
    if account.balance <= 100:
        create_once(
            f"billing:low-balance:{account.id}",
            notification_type=Notification.Type.BILLING,
            priority=Notification.Priority.WARN,
            title="团队余额低于预警线",
            brief=f"当前余额 ¥{account.balance:.2f},建议及时充值。",
            body="余额低于 100 元时系统会生成预警通知。充值或调低成员额度后可在消费页查看最新账本。",
            source="计费中心",
            stage="余额监控",
            owner_label="系统",
            cost_label=f"¥{account.balance:.2f}",
            related_url="account.html",
        )


class NotificationViewSet(TeamScopedViewSetMixin, ModelViewSet):
    """Inbox API for the requesting user within the current team.

    Visible rows are non-archived notifications addressed to the requester
    or to nobody (``recipient=None``). NOTE(review): ``recipient=None`` rows
    are single shared rows, so read/archive changes made through this API
    affect every team member who sees them.

    List query params:
      * ``type`` -- any value other than ``"all"`` / ``"unread"`` filters by
        ``notification_type``.
      * ``unread`` -- ``"1"``, ``"true"`` or ``"yes"`` keeps only unread rows.
    """

    serializer_class = NotificationSerializer
    queryset = Notification.objects.select_related("team", "recipient", "project").all()
    pagination_class = NotificationPagination
    search_fields = ["title", "brief", "body", "source", "stage"]
    ordering_fields = ["created_at", "updated_at"]
    ordering = ["-created_at"]

    def _recipient_scope(self):
        """Return the base set for badge/chip counts.

        Team + recipient + not archived, without the tab, unread or search
        filters. Team filtering is presumably done by
        ``TeamScopedViewSetMixin.get_queryset`` (defined elsewhere; verify).
        """
        queryset = super().get_queryset().filter(archived_at__isnull=True)
        user = self.request.user
        return queryset.filter(Q(recipient=user) | Q(recipient__isnull=True))

    def _unread_count(self):
        """Return the absolute unread count (same scope as the list badge)."""
        return self._recipient_scope().filter(is_read=False).count()

    def _type_counts(self):
        """Return absolute per-category counts over ``_recipient_scope()``.

        Computed in a single aggregate query rather than one COUNT per chip.
        """
        totals = self._recipient_scope().aggregate(
            n_all=Count("pk"),
            n_unread=Count("pk", filter=Q(is_read=False)),
            n_task=Count("pk", filter=Q(notification_type="task")),
            n_team=Count("pk", filter=Q(notification_type="team")),
            n_billing=Count("pk", filter=Q(notification_type="billing")),
            n_system=Count("pk", filter=Q(notification_type="system")),
        )
        return {
            "all": totals["n_all"],
            "unread": totals["n_unread"],
            "task": totals["n_task"],
            "team": totals["n_team"],
            "billing": totals["n_billing"],
            "system": totals["n_system"],
        }

    def get_queryset(self):
        """Apply the ``type`` / ``unread`` query-param filters to the base scope.

        NOTE(review): DRF's ``get_object`` also goes through this method, so
        these list filters apply to detail routes too, and archived rows 404.
        """
        queryset = self._recipient_scope()
        notification_type = self.request.query_params.get("type")
        # "all" and "unread" are UI tab names, not notification types.
        if notification_type and notification_type not in {"all", "unread"}:
            queryset = queryset.filter(notification_type=notification_type)
        if self.request.query_params.get("unread") in {"1", "true", "yes"}:
            queryset = queryset.filter(is_read=False)
        return queryset

    def list(self, request, *args, **kwargs):
        """Seed the inbox, then return a page plus absolute badge/chip counts.

        ``unread_count`` and ``type_counts`` ignore the current tab and search
        filters so the category chips always show absolute totals (per the
        design mock).
        """
        ensure_team_notifications(self.get_team(), request.user)
        response = super().list(request, *args, **kwargs)
        data = response.data
        # Paginated responses are dicts; only those carry the extra counters.
        if isinstance(data, dict):
            type_counts = self._type_counts()
            data["unread_count"] = type_counts["unread"]
            data["type_counts"] = type_counts
        return response

    def perform_create(self, serializer):
        """Persist a new notification for the current team, addressed to the requester."""
        serializer.save(team=self.get_team(), recipient=self.request.user)

    @action(detail=False, methods=["post"], url_path="mark-all-read")
    def mark_all_read(self, request):
        """Mark every unread row in the current (filtered) view as read.

        Returns the number of rows updated and the absolute unread count,
        using the same scope as the list endpoint's ``unread_count``.
        """
        now = timezone.now()
        count = self.get_queryset().filter(is_read=False).update(is_read=True, read_at=now, updated_at=now)
        return Response({"updated": count, "unread_count": self._unread_count()})

    @action(detail=False, methods=["post"], url_path="mark-all-unread")
    def mark_all_unread(self, request):
        """Mark every read row in the current (filtered) view as unread.

        Returns the number of rows updated and the absolute unread count,
        using the same scope as the list endpoint's ``unread_count``.
        """
        now = timezone.now()
        count = self.get_queryset().filter(is_read=True).update(is_read=False, read_at=None, updated_at=now)
        return Response({"updated": count, "unread_count": self._unread_count()})

    @action(detail=True, methods=["post"], url_path="mark-read")
    def mark_read(self, request, pk=None):
        """Mark one notification read via ``Notification.mark_read()``; return it serialized."""
        notification = self.get_object()
        notification.mark_read()
        return Response(self.get_serializer(notification).data)

    @action(detail=True, methods=["post"], url_path="mark-unread")
    def mark_unread(self, request, pk=None):
        """Mark one notification unread via ``Notification.mark_unread()``; return it serialized."""
        notification = self.get_object()
        notification.mark_unread()
        return Response(self.get_serializer(notification).data)

    @action(detail=True, methods=["post"], url_path="archive")
    def archive(self, request, pk=None):
        """Archive one notification via ``Notification.archive()``; respond 204 with no body."""
        notification = self.get_object()
        notification.archive()
        return Response(status=status.HTTP_204_NO_CONTENT)