168 lines
7.4 KiB
Python

from django.db.models import Q
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.assets.models import Asset
from apps.billing.models import CreditAccount
from apps.common.api import TeamScopedViewSetMixin
from apps.projects.models import Project
from .models import Notification
from .serializers import NotificationSerializer
# Display label for each ``Project.current_stage`` value.
_PROJECT_STAGE_LABELS = {
    "script": "Stage 1 · 脚本",
    "base_assets": "Stage 2 · 基础资产",
    "storyboard": "Stage 3 · 故事板",
    "video": "Stage 4 · 视频",
    "export": "Stage 5 · 导出",
}


def project_stage_label(project):
    """Return the human-readable pipeline stage label for *project*.

    Any ``current_stage`` value not present in the table (including ``None``)
    falls back to the first stage's label.
    """
    fallback = _PROJECT_STAGE_LABELS["script"]
    return _PROJECT_STAGE_LABELS.get(project.current_stage, fallback)
def project_priority(project):
    """Map *project*'s status onto a notification priority.

    ``COMPLETED`` maps to ``OK`` and ``FAILED`` maps to ``ERR``; every other
    status is reported as ``INFO``.
    """
    current = project.status
    # Checked in order; the first status that compares equal wins.
    for status_value, priority in (
        (Project.Status.COMPLETED, Notification.Priority.OK),
        (Project.Status.FAILED, Notification.Priority.ERR),
    ):
        if current == status_value:
            return priority
    return Notification.Priority.INFO
# Number of most-recently-updated projects / assets that get a notification per call.
_RECENT_PROJECT_LIMIT = 5
_RECENT_ASSET_LIMIT = 3
# Credit balance at or below which the low-balance warning is created.
_LOW_BALANCE_THRESHOLD = 100


def _create_notification_once(team, user, dedupe_key, **payload):
    """Create *user*'s notification in *team* unless one with *dedupe_key* exists.

    ``payload`` is only used as ``defaults`` when the row is created; an
    existing row with the same key is left unchanged.
    """
    Notification.objects.get_or_create(
        team=team,
        recipient=user,
        dedupe_key=dedupe_key,
        defaults=payload,
    )


def _owner_label(obj):
    """Return the creator's username, or "成员" when *obj* has no ``created_by``."""
    return obj.created_by.username if obj.created_by_id else "成员"


def _ensure_welcome_notification(team, user):
    """Create the one-off system welcome notice."""
    _create_notification_once(
        team,
        user,
        "system:welcome",
        notification_type=Notification.Type.SYSTEM,
        priority=Notification.Priority.INFO,
        title="团队已接入 AirShelf",
        brief="真实消息中心已启用,状态会写入 Django 数据库。",
        body="消息已从演示数据切换为团队级通知表。已读、未读、归档等操作都会持久化保存。",
        source="Airshelf 系统",
        stage="系统公告",
        owner_label="系统",
        cost_label="-",
        related_url="settings.html#sec-notify",
    )


def _ensure_project_notifications(team, user):
    """Create one notification per (project, status, stage) for the latest projects."""
    recent_projects = (
        Project.objects.filter(team=team)
        .select_related("product", "created_by")
        .order_by("-updated_at")[:_RECENT_PROJECT_LIMIT]
    )
    for project in recent_projects:
        product_title = project.product.title if project.product_id else "未绑定商品"
        # Computed once; it appears in the brief, the body and the stage field.
        stage_label = project_stage_label(project)
        _create_notification_once(
            team,
            user,
            # Status and stage are part of the key, so every transition yields a new row.
            f"project:{project.id}:status:{project.status}:{project.current_stage}",
            notification_type=Notification.Type.TASK,
            priority=project_priority(project),
            title=f"项目「{project.name}」状态更新",
            brief=f"{product_title} · {stage_label} · {project.get_status_display()}",
            body=f"项目「{project.name}」当前处于 {stage_label}。这条消息来自 Django 项目表,刷新后状态会保持一致。",
            source="视频项目",
            project=project,
            stage=stage_label,
            owner_label=_owner_label(project),
            cost_label="-",
            related_url=f"pipeline.html?project_id={project.id}",
            metadata={"status": project.status, "current_stage": project.current_stage},
        )


def _ensure_asset_notifications(team, user):
    """Create one "added to library" notification per recently updated asset."""
    recent_assets = (
        Asset.objects.filter(team=team)
        .select_related("created_by")
        .order_by("-updated_at")[:_RECENT_ASSET_LIMIT]
    )
    for asset in recent_assets:
        _create_notification_once(
            team,
            user,
            f"asset:{asset.id}:created",
            notification_type=Notification.Type.TASK,
            priority=Notification.Priority.OK,
            title=f"资产「{asset.name}」已加入资产库",
            brief=f"{asset.get_category_display()} · {asset.get_asset_type_display()}",
            body="资产记录来自真实资产表。后续上传、AI 生成、导出成片都可以在这里形成团队通知。",
            source="资产库",
            stage="资产入库",
            owner_label=_owner_label(asset),
            cost_label="-",
            related_url="library.html",
            metadata={"asset_id": str(asset.id), "category": asset.category, "asset_type": asset.asset_type},
        )


def _ensure_low_balance_notification(team, user):
    """Create a low-balance warning when the team's credit balance is at or below the threshold.

    Side effect: creates the team's ``CreditAccount`` if it does not exist yet.
    """
    account, _ = CreditAccount.objects.get_or_create(team=team)
    # NOTE(review): the check is ``<=`` while the body text says "below 100";
    # confirm which boundary is intended.
    if account.balance <= _LOW_BALANCE_THRESHOLD:
        # NOTE(review): the dedupe key only contains the account id, so this
        # warning is created at most once per account/user. Later dips below the
        # threshold (e.g. after a recharge) will not produce a new one. The brief
        # and cost_label also keep the balance from creation time.
        # Confirm this is intended.
        _create_notification_once(
            team,
            user,
            f"billing:low-balance:{account.id}",
            notification_type=Notification.Type.BILLING,
            priority=Notification.Priority.WARN,
            title="团队余额低于预警线",
            brief=f"当前余额 ¥{account.balance:.2f},建议及时充值。",
            body="余额低于 100 元时系统会生成预警通知。充值或调低成员额度后可在消费页查看最新账本。",
            source="计费中心",
            stage="余额监控",
            owner_label="系统",
            cost_label=f"¥{account.balance:.2f}",
            related_url="account.html",
        )


def ensure_team_notifications(team, user):
    """Seed *user*'s inbox in *team* with notifications derived from live data.

    Every notification has a ``dedupe_key`` and is created with
    ``get_or_create``. Repeated calls therefore neither duplicate nor modify
    existing rows. Sources, in this order:

    * a one-off system welcome notice;
    * the most recently updated projects (one row per status/stage combination);
    * the most recently updated assets (one row per asset);
    * a low-balance warning when the team's credit balance is at or below the threshold.

    ``NotificationViewSet.list`` calls this on every listing request, so each
    call issues several queries.

    Args:
        team: The team whose projects, assets and credit account are inspected.
        user: The recipient of the created notifications.
    """
    _ensure_welcome_notification(team, user)
    _ensure_project_notifications(team, user)
    _ensure_asset_notifications(team, user)
    _ensure_low_balance_notification(team, user)
class NotificationViewSet(TeamScopedViewSetMixin, ModelViewSet):
    """Notification inbox API for the requesting user.

    Exposes non-archived notifications addressed to the current user, plus
    those with no recipient, on top of the mixin's base queryset. It also
    provides bulk and per-item read-state actions and an archive action.

    NOTE(review): read/archive state is stored on the notification row itself.
    Recipient-less rows are included in every user's queryset, so marking one
    read, unread or archived presumably affects every user who can see it.
    Confirm this is intended.
    """

    serializer_class = NotificationSerializer
    queryset = Notification.objects.select_related("team", "recipient", "project").all()
    search_fields = ["title", "brief", "body", "source", "stage"]
    ordering_fields = ["created_at", "updated_at"]
    ordering = ["-created_at"]

    # ``?type=`` values that mean "do not filter by notification_type".
    _UNFILTERED_TYPES = {"all", "unread"}
    # ``?unread=`` values (case-sensitive) that restrict results to unread rows.
    _TRUTHY_FLAGS = {"1", "true", "yes"}

    def get_queryset(self):
        """Return the user's visible, non-archived notifications.

        Honours the ``type`` and ``unread`` query parameters.
        """
        visible = super().get_queryset().filter(archived_at__isnull=True)
        visible = visible.filter(Q(recipient=self.request.user) | Q(recipient__isnull=True))
        params = self.request.query_params
        kind = params.get("type")
        if kind and kind not in self._UNFILTERED_TYPES:
            visible = visible.filter(notification_type=kind)
        if params.get("unread") in self._TRUTHY_FLAGS:
            visible = visible.filter(is_read=False)
        return visible

    def _unread_count(self):
        """Count unread rows in ``get_queryset()``.

        Because it reuses ``get_queryset()``, the ``type``/``unread`` query
        params apply here as well.
        """
        return self.get_queryset().filter(is_read=False).count()

    def _bulk_result(self, updated):
        """Build the response shared by the bulk read-state actions."""
        return Response({"updated": updated, "unread_count": self._unread_count()})

    def list(self, request, *args, **kwargs):
        """Seed team notifications, list them, and attach ``unread_count`` when possible."""
        ensure_team_notifications(self.get_team(), request.user)
        response = super().list(request, *args, **kwargs)
        unread_count = self._unread_count()
        # Only a dict payload (e.g. a paginated envelope) can carry the extra key;
        # a plain list payload is returned unchanged.
        if isinstance(response.data, dict):
            response.data["unread_count"] = unread_count
        return response

    def perform_create(self, serializer):
        """Save a new notification scoped to the current team and addressed to the caller."""
        serializer.save(team=self.get_team(), recipient=self.request.user)

    @action(detail=False, methods=["post"], url_path="mark-all-read")
    def mark_all_read(self, request):
        """Mark every currently visible unread notification as read."""
        now = timezone.now()
        # ``update()`` skips ``auto_now``, hence the explicit ``updated_at``.
        pending = self.get_queryset().filter(is_read=False)
        updated = pending.update(is_read=True, read_at=now, updated_at=now)
        return self._bulk_result(updated)

    @action(detail=False, methods=["post"], url_path="mark-all-unread")
    def mark_all_unread(self, request):
        """Mark every currently visible read notification as unread."""
        now = timezone.now()
        already_read = self.get_queryset().filter(is_read=True)
        updated = already_read.update(is_read=False, read_at=None, updated_at=now)
        return self._bulk_result(updated)

    @action(detail=True, methods=["post"], url_path="mark-read")
    def mark_read(self, request, pk=None):
        """Mark a single notification as read and return its serialized form."""
        target = self.get_object()
        target.mark_read()
        serializer = self.get_serializer(target)
        return Response(serializer.data)

    @action(detail=True, methods=["post"], url_path="mark-unread")
    def mark_unread(self, request, pk=None):
        """Mark a single notification as unread and return its serialized form."""
        target = self.get_object()
        target.mark_unread()
        serializer = self.get_serializer(target)
        return Response(serializer.data)

    @action(detail=True, methods=["post"], url_path="archive")
    def archive(self, request, pk=None):
        """Archive a single notification; it then drops out of ``get_queryset()``."""
        self.get_object().archive()
        return Response(status=status.HTTP_204_NO_CONTENT)