From 92826dec1409821d2cd56e66a55ea09c0fc73a5c Mon Sep 17 00:00:00 2001 From: zyc <1439655764@qq.com> Date: Tue, 9 Jun 2026 14:46:16 +0800 Subject: [PATCH] feat(core/backend): pipeline continuity + threaded ffmpeg burn-in export + upload/save-timeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Video pipeline (script→assets→storyboard→video→stitch): - robust split_script_into_segments (4 non-empty scenes), scene-aware storyboard/video prompts - link VideoSegment→ScriptSegment + storyboard-frame reference image (graceful text fallback) - idempotent poll_video_segment (no double-charge on repeated polling) - threaded export (no Celery worker needed) + poll-export endpoint - run_export_job rewritten to filter_complex: per-clip trim, xfade transitions, subtitle burn-in (Pillow PNG overlay; this ffmpeg lacks libass), BGM mix - upload-video-segment / upload-bgm / save-timeline endpoints - serializers embed asset preview URLs (beat assets pagination); Pillow added to requirements Also includes prior uncommitted backend work: account preferences/sessions, billing trend, product/asset endpoints, accounts 0002 migration. 
Co-Authored-By: Claude Opus 4.8 --- .../0002_loginsession_userpreference.py | 94 +++++ core/backend/apps/accounts/models.py | 38 ++ core/backend/apps/accounts/serializers.py | 22 +- core/backend/apps/accounts/urls.py | 8 + core/backend/apps/accounts/views.py | 91 ++++- core/backend/apps/ai/services.py | 333 ++++++++++++++++-- core/backend/apps/billing/urls.py | 3 +- core/backend/apps/billing/views.py | 94 +++++ core/backend/apps/products/views.py | 68 +++- core/backend/apps/projects/serializers.py | 75 +++- core/backend/apps/projects/services/export.py | 260 ++++++++++++-- core/backend/apps/projects/views.py | 239 ++++++++++++- core/backend/requirements.txt | 1 + 13 files changed, 1229 insertions(+), 97 deletions(-) create mode 100644 core/backend/apps/accounts/migrations/0002_loginsession_userpreference.py diff --git a/core/backend/apps/accounts/migrations/0002_loginsession_userpreference.py b/core/backend/apps/accounts/migrations/0002_loginsession_userpreference.py new file mode 100644 index 0000000..7afa01d --- /dev/null +++ b/core/backend/apps/accounts/migrations/0002_loginsession_userpreference.py @@ -0,0 +1,94 @@ +# Generated by Django 5.1.15 on 2026-06-08 09:48 + +import apps.accounts.models +import django.db.models.deletion +import uuid +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("accounts", "0001_initial"), + ] + + operations = [ + migrations.CreateModel( + name="LoginSession", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("user_agent", models.CharField(blank=True, max_length=400)), + ("ip_address", models.GenericIPAddressField(blank=True, null=True)), + ("last_seen_at", models.DateTimeField(auto_now=True)), + ("revoked_at", models.DateTimeField(blank=True, 
null=True)), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="login_sessions", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-last_seen_at"], + }, + ), + migrations.CreateModel( + name="UserPreference", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "notify", + models.JSONField( + blank=True, default=apps.accounts.models._default_notify + ), + ), + ("two_factor_enabled", models.BooleanField(default=False)), + ( + "creation_defaults", + models.JSONField( + blank=True, default=apps.accounts.models._default_creation + ), + ), + ( + "display", + models.JSONField( + blank=True, default=apps.accounts.models._default_display + ), + ), + ( + "user", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="preference", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/core/backend/apps/accounts/models.py b/core/backend/apps/accounts/models.py index af64bff..3c8fc35 100644 --- a/core/backend/apps/accounts/models.py +++ b/core/backend/apps/accounts/models.py @@ -34,6 +34,44 @@ class Team(TimeStampedModel): return self.name +def _default_notify() -> dict: + return {"n-export": True, "n-fail": True, "n-quota": True, "n-login": True} + + +def _default_creation() -> dict: + return {"template": "pain", "duration": "60", "subtitle": "big-variety", "bgm": "kapian", "transition": "fade"} + + +def _default_display() -> dict: + return {"appearance": "system", "language": "zh", "density": "standard"} + + +class UserPreference(TimeStampedModel): + """用户设置:通知策略 / 两步验证 / 创作默认 / 显示偏好。服务端持久化(替代前端 localStorage)。""" + + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="preference") + notify = 
models.JSONField(default=_default_notify, blank=True) + two_factor_enabled = models.BooleanField(default=False) + creation_defaults = models.JSONField(default=_default_creation, blank=True) + display = models.JSONField(default=_default_display, blank=True) + + def __str__(self) -> str: + return f"prefs/{self.user}" + + +class LoginSession(TimeStampedModel): + """登录会话记录:每次登录写一条(设备 UA / IP / 时间),供设置页「在用设备」展示与下线。""" + + user = models.ForeignKey(User, on_delete=models.CASCADE, related_name="login_sessions") + user_agent = models.CharField(max_length=400, blank=True) + ip_address = models.GenericIPAddressField(null=True, blank=True) + last_seen_at = models.DateTimeField(auto_now=True) + revoked_at = models.DateTimeField(null=True, blank=True) + + class Meta: + ordering = ["-last_seen_at"] + + class TeamMember(TimeStampedModel): class Role(models.TextChoices): OWNER = "owner", "Owner" diff --git a/core/backend/apps/accounts/serializers.py b/core/backend/apps/accounts/serializers.py index 5ec237d..ae4908d 100644 --- a/core/backend/apps/accounts/serializers.py +++ b/core/backend/apps/accounts/serializers.py @@ -2,7 +2,7 @@ from rest_framework import serializers from apps.billing.models import CreditAccount -from .models import Team, TeamMember, User +from .models import LoginSession, Team, TeamMember, User, UserPreference class UserSerializer(serializers.ModelSerializer): @@ -12,6 +12,26 @@ class UserSerializer(serializers.ModelSerializer): read_only_fields = ["id", "status"] +class UserPreferenceSerializer(serializers.ModelSerializer): + class Meta: + model = UserPreference + fields = ["notify", "two_factor_enabled", "creation_defaults", "display", "updated_at"] + read_only_fields = ["updated_at"] + + +class LoginSessionSerializer(serializers.ModelSerializer): + is_current = serializers.SerializerMethodField() + + class Meta: + model = LoginSession + fields = ["id", "user_agent", "ip_address", "last_seen_at", "created_at", "is_current"] + read_only_fields = fields + + def 
get_is_current(self, obj) -> bool: + ctx = self.context or {} + return bool(obj.ip_address and obj.ip_address == ctx.get("current_ip") and obj.user_agent == ctx.get("current_ua")) + + class TeamSerializer(serializers.ModelSerializer): class Meta: model = Team diff --git a/core/backend/apps/accounts/urls.py b/core/backend/apps/accounts/urls.py index f0473fa..4b2fef0 100644 --- a/core/backend/apps/accounts/urls.py +++ b/core/backend/apps/accounts/urls.py @@ -3,9 +3,13 @@ from django.urls import path from .views import ( change_password, login, + login_sessions, logout, me, + preferences, register, + revoke_login_session, + revoke_other_sessions, team_member_detail, team_member_password, team_members, @@ -20,6 +24,10 @@ urlpatterns = [ path("me/", me, name="auth-me"), path("me/password/", change_password, name="auth-change-password"), path("me/avatar/", update_avatar, name="auth-avatar"), + path("me/preferences/", preferences, name="auth-preferences"), + path("me/sessions/", login_sessions, name="auth-sessions"), + path("me/sessions/revoke-others/", revoke_other_sessions, name="auth-sessions-revoke-others"), + path("me/sessions//revoke/", revoke_login_session, name="auth-session-revoke"), path("team/members/", team_members, name="team-members"), path("team/members//", team_member_detail, name="team-member-detail"), path("team/members//password/", team_member_password, name="team-member-password"), diff --git a/core/backend/apps/accounts/views.py b/core/backend/apps/accounts/views.py index f343559..2282694 100644 --- a/core/backend/apps/accounts/views.py +++ b/core/backend/apps/accounts/views.py @@ -12,8 +12,16 @@ from rest_framework.response import Response from apps.common.api import get_current_team -from .models import TeamMember, User -from .serializers import LoginSerializer, RegisterSerializer, TeamMemberSerializer, TeamSerializer, UserSerializer +from .models import LoginSession, TeamMember, User, UserPreference +from .serializers import ( + LoginSerializer, + 
LoginSessionSerializer, + RegisterSerializer, + TeamMemberSerializer, + TeamSerializer, + UserPreferenceSerializer, + UserSerializer, +) def auth_payload(user, team, token): @@ -24,6 +32,25 @@ def auth_payload(user, team, token): } +def _client_ip(request): + forwarded = request.META.get("HTTP_X_FORWARDED_FOR", "") + if forwarded: + return forwarded.split(",")[0].strip() + return request.META.get("REMOTE_ADDR") or None + + +def record_login_session(request, user): + """登录成功后记录一条会话(设备 UA / IP),供设置页「在用设备」展示。""" + try: + LoginSession.objects.create( + user=user, + user_agent=(request.META.get("HTTP_USER_AGENT") or "")[:400], + ip_address=_client_ip(request), + ) + except Exception: # noqa: BLE001 — 会话记录失败不应阻断登录 + pass + + @api_view(["POST"]) @permission_classes([]) def register(request): @@ -31,6 +58,7 @@ def register(request): serializer.is_valid(raise_exception=True) data = serializer.save() token, _ = Token.objects.get_or_create(user=data["user"]) + record_login_session(request, data["user"]) return Response(auth_payload(data["user"], data["team"], token), status=status.HTTP_201_CREATED) @@ -48,6 +76,7 @@ def login(request): return Response({"detail": "invalid credentials"}, status=status.HTTP_400_BAD_REQUEST) team = get_current_team(user) token, _ = Token.objects.get_or_create(user=user) + record_login_session(request, user) return Response(auth_payload(user, team, token)) @@ -97,12 +126,19 @@ def change_password(request): return Response({"token": token.key}) -@api_view(["POST"]) +@api_view(["POST", "DELETE"]) @parser_classes([MultiPartParser, FormParser]) @permission_classes([IsAuthenticated]) def update_avatar(request): from apps.assets.storage import TosStorage + # DELETE = 恢复默认头像(清空 avatar_url,前端回退到首字母占位) + if request.method == "DELETE": + user = request.user + user.avatar_url = "" + user.save(update_fields=["avatar_url"]) + return Response(UserSerializer(user).data) + upload = request.FILES.get("file") if upload is None: return Response({"detail": "no 
file"}, status=status.HTTP_400_BAD_REQUEST) @@ -223,3 +259,52 @@ def team_member_password(request, member_id): member.user.save(update_fields=["password"]) Token.objects.filter(user=member.user).delete() return Response(status=status.HTTP_204_NO_CONTENT) + + +@api_view(["GET", "PUT", "PATCH"]) +@permission_classes([IsAuthenticated]) +def preferences(request): + """用户设置:通知策略 / 两步验证 / 创作默认 / 显示偏好。服务端持久化。""" + pref, _ = UserPreference.objects.get_or_create(user=request.user) + if request.method in ("PUT", "PATCH"): + serializer = UserPreferenceSerializer(pref, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) + serializer.save() + pref.refresh_from_db() + return Response(UserPreferenceSerializer(pref).data) + + +@api_view(["GET"]) +@permission_classes([IsAuthenticated]) +def login_sessions(request): + """在用设备:返回未下线的登录会话(最近 20 条)。""" + sessions = LoginSession.objects.filter(user=request.user, revoked_at__isnull=True)[:20] + current_ip = _client_ip(request) + current_ua = (request.META.get("HTTP_USER_AGENT") or "")[:400] + data = LoginSessionSerializer(sessions, many=True, context={"current_ip": current_ip, "current_ua": current_ua}).data + return Response(data) + + +@api_view(["POST"]) +@permission_classes([IsAuthenticated]) +def revoke_login_session(request, session_id): + """下线单个设备会话。""" + from django.utils import timezone + + updated = LoginSession.objects.filter(user=request.user, id=session_id, revoked_at__isnull=True).update( + revoked_at=timezone.now() + ) + return Response({"revoked": updated}) + + +@api_view(["POST"]) +@permission_classes([IsAuthenticated]) +def revoke_other_sessions(request): + """下线除当前外的所有其他设备:旋转 token(令其他端 token 失效)+ 标记会话已下线。""" + from django.utils import timezone + + LoginSession.objects.filter(user=request.user, revoked_at__isnull=True).update(revoked_at=timezone.now()) + Token.objects.filter(user=request.user).delete() + token, _ = Token.objects.get_or_create(user=request.user) + record_login_session(request, 
request.user) + return Response({"token": token.key}) diff --git a/core/backend/apps/ai/services.py b/core/backend/apps/ai/services.py index b437b17..669528b 100644 --- a/core/backend/apps/ai/services.py +++ b/core/backend/apps/ai/services.py @@ -1,4 +1,6 @@ +import re import uuid +from datetime import timedelta from decimal import Decimal from django.db import transaction from django.utils import timezone @@ -59,13 +61,43 @@ def build_script_prompt(*, project, user_prompt: str, selling_point_ids: list[st return [{"role": "system", "content": system}, {"role": "user", "content": user}] -def split_script_into_segments(content: str) -> list[str]: - blocks = [line.strip() for line in content.splitlines() if line.strip()] - if len(blocks) >= 4: - return blocks[:4] - if not content.strip(): - return [""] * 4 - return [content.strip()] + [""] * (4 - len(blocks or [content])) +def split_script_into_segments(content: str, count: int = 4) -> list[str]: + """把一段脚本稳健地拆成 `count` 个分镜文本,保证每镜都非空、且所有内容都被分配到某一镜。 + + 原实现按行 `[:4]`,ARK 返回整段散文时常变成「第1镜有词、2/3/4镜全空」, + 导致后续故事板帧 / 视频段拿到空提示词,前后内容断裂。这里改为: + 优先按空行/标号块切,块数够就把全部块均匀分桶;块不够再按句子切;仍不够则补齐。 + """ + + def _bucketize(items: list[str], joiner: str) -> list[str]: + buckets: list[list[str]] = [[] for _ in range(count)] + per = len(items) / count + for index, item in enumerate(items): + buckets[min(count - 1, int(index / per))].append(item) + return [joiner.join(bucket).strip() for bucket in buckets] + + text = (content or "").strip() + if not text: + return [""] * count + + # 1) 优先按空行分段;只有一段时退回按行分 + blocks = [block.strip() for block in re.split(r"\n\s*\n", text) if block.strip()] + if len(blocks) < 2: + blocks = [line.strip() for line in text.splitlines() if line.strip()] + if len(blocks) >= count: + return _bucketize(blocks, "\n") + + # 2) 段落不足:按中英文句末标点切句,再均匀分桶 + sentences = [s.strip() for s in re.split(r"(?<=[。!?!?.;;\n])", text) if s.strip()] + if len(sentences) >= count: + return _bucketize(sentences, " ") + + # 3) 仍不足:用已有块/句补齐到 
count,绝不留空镜 + base = blocks or sentences or [text] + filled = list(base) + while len(filled) < count: + filled.append(base[-1]) + return filled[:count] @transaction.atomic @@ -242,30 +274,98 @@ def generate_base_asset(*, project, user, kind: str, prompt: str) -> BaseAssetGr raise -def generate_storyboard(*, project, user, prompt: str = "") -> StoryboardVersion: +def _scene_context(project) -> str: + """从商品 + 已采用基础资产提炼一句「风格锚点」,贯穿故事板 / 视频,保证各镜内容一致。""" + product = project.product + parts = [f"商品:{product.title}"] + if product.brand: + parts.append(f"品牌:{product.brand}") + if product.category: + parts.append(f"类目:{product.category}") + if getattr(product, "target_audience", ""): + parts.append(f"人群:{product.target_audience}") + adopted_kinds = set( + project.base_asset_groups.filter(adopted_asset__isnull=False).values_list("kind", flat=True) + ) + if BaseAssetGroup.Kind.PERSON in adopted_kinds: + parts.append("真人出镜,保持人物一致") + if BaseAssetGroup.Kind.SCENE in adopted_kinds: + parts.append("统一场景与色调") + return " · ".join(parts) + + +def build_storyboard_frame_prompt(project, version, segment) -> str: + """单帧故事板提示词:风格锚点 + 本镜画面(回退旁白)+ 版本统一指令。""" + visual = (segment.visual_prompt or segment.narration or "").strip() + lines = [ + _scene_context(project), + f"第 {segment.sort_order + 1} 镜画面:{visual}" if visual else f"第 {segment.sort_order + 1} 镜", + ] + if version.prompt: + lines.append(version.prompt.strip()) + lines.append("电商竖屏分镜图,构图清晰,可直接指导视频生成") + return "\n".join(line for line in lines if line) + + +def build_video_segment_prompt(project, video_segment, scene, user_prompt: str) -> str: + """单段视频提示词:把本镜旁白 + 画面 + 风格锚点织进去,让每个视频片段跟住对应脚本/故事板。""" + lines = [_scene_context(project)] + if scene is not None: + if scene.narration: + lines.append(f"旁白:{scene.narration.strip()}") + visual = (scene.visual_prompt or scene.narration or "").strip() + if visual: + lines.append(f"画面:{visual}") + if user_prompt: + lines.append(user_prompt.strip()) + lines.append( + f"第 
{video_segment.sort_order + 1} 段 · {video_segment.target_duration_seconds}s · " + "9:16 竖屏电商带货短视频,镜头稳定,商品露出清晰,节奏有转化感" + ) + return "\n".join(line for line in lines if line) + + +def submit_storyboard(*, project, user, prompt: str = "") -> StoryboardVersion: + """异步故事板·提交:快速创建(或复用)一个未采用的版本,不在此处生图。逐帧生成交给 generate_storyboard_frame(轮询)。""" adopted_script = project.script_versions.filter(is_adopted=True).prefetch_related("segments").first() if adopted_script is None: raise ValueError("script must be adopted before generating storyboard") - model_config = get_default_model(ModelConfig.Capability.IMAGE) - if model_config is None: + if get_default_model(ModelConfig.Capability.IMAGE) is None: raise ValueError("no active image model configured") + # 复用尚未完成(未采用)的版本,避免重复提交产生多版本;否则新建 + version = project.storyboard_versions.filter(is_adopted=False).order_by("-created_at").first() + if version is None: + version = StoryboardVersion.objects.create(project=project, prompt=prompt) + elif prompt and version.prompt != prompt: + version.prompt = prompt + version.save(update_fields=["prompt", "updated_at"]) + return version - storyboard = StoryboardVersion.objects.create(project=project, prompt=prompt) - provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None) - for segment in adopted_script.segments.all(): - task = create_ai_task( - project=project, - user=user, - task_type=AITask.Type.STORYBOARD, - model_config=model_config, - request_payload={"model": model_config.name, "endpoint": model_config.endpoint, "prompt": segment.visual_prompt}, - ) + +def _storyboard_frame_worker(task_id, version_id, segment_id, user_id) -> None: + """后台线程:真正调 ARK 生成一帧故事板图并落库。每次 poll 不阻塞在此——HTTP 永远秒回。""" + import threading # noqa: F401 — 仅标注此函数运行在独立线程 + from django.db import connections + + from apps.accounts.models import User + + try: + task = AITask.objects.select_related("model_config__provider").get(id=task_id) + version = 
StoryboardVersion.objects.select_related("project__team").get(id=version_id) + segment = ScriptSegment.objects.get(id=segment_id) + user = User.objects.get(id=user_id) + project = version.project + model_config = task.model_config reservation = task.credit_reservation + task.status = AITask.Status.SUBMITTED + task.save(update_fields=["status", "updated_at"]) try: + provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None) + frame_prompt = task.request_payload.get("prompt") or build_storyboard_frame_prompt(project, version, segment) response = provider.image_generation( model=model_config.name, endpoint=model_config.endpoint, - prompt=f"{prompt}\n{segment.visual_prompt}".strip(), + prompt=frame_prompt, ) media = provider.extract_first_media_url(response) task.status = AITask.Status.SUCCEEDED @@ -285,22 +385,141 @@ def generate_storyboard(*, project, user, prompt: str = "") -> StoryboardVersion asset_type=Asset.Type.IMAGE, ) StoryboardFrame.objects.create( - storyboard=storyboard, + storyboard=version, script_segment=segment, asset=asset, sort_order=segment.sort_order, prompt=segment.visual_prompt, ) - except Exception as exc: + except Exception as exc: # noqa: BLE001 — 失败回滚额度,标记任务失败供 poll 上报 task.status = AITask.Status.FAILED task.error_message = str(exc) task.completed_at = timezone.now() task.save(update_fields=["status", "error_message", "completed_at", "updated_at"]) release_credit(reservation=reservation, reason=str(exc)) - raise - storyboard.is_adopted = True - storyboard.save(update_fields=["is_adopted", "updated_at"]) - return storyboard + finally: + connections.close_all() # 释放该线程的 DB 连接 + + +def generate_storyboard_frame(*, project, user) -> dict: + """异步故事板·轮询(秒回):读取进度;若无帧在生成则后台起线程生成下一帧。永不阻塞在 ARK 调用上。 + 返回 {status: generating|succeeded|failed, done, total, version_id}。全部完成→采用版本。""" + import threading + + version = project.storyboard_versions.filter(is_adopted=False).order_by("-created_at").first() + adopted_script = 
project.script_versions.filter(is_adopted=True).prefetch_related("segments").first() + if version is None or adopted_script is None: + latest = project.storyboard_versions.order_by("-created_at").first() + n = latest.frames.count() if latest else 0 + return {"status": "succeeded", "done": n, "total": n, "version_id": str(latest.id) if latest else ""} + + segments = list(adopted_script.segments.all().order_by("sort_order")) + total = len(segments) + done_segment_ids = set(version.frames.values_list("script_segment_id", flat=True)) + done = len(done_segment_ids) + + if done >= total: + _finalize_storyboard(project, version) + return {"status": "succeeded", "done": total, "total": total, "version_id": str(version.id)} + + # 该版本内是否已有帧在后台生成中(RESERVED/SUBMITTED 的故事板任务即为「占位锁」)。 + # 仅算「近 3 分钟内」的任务:若进程/线程意外中断留下僵尸任务,超时后不再视为在生成,允许重新发起。 + stale_cutoff = timezone.now() - timedelta(minutes=3) + inflight = AITask.objects.filter( + project=project, + task_type=AITask.Type.STORYBOARD, + status__in=[AITask.Status.CREATED, AITask.Status.RESERVED, AITask.Status.SUBMITTED], + request_payload__storyboard_version=str(version.id), + created_at__gte=stale_cutoff, + ).exists() + if inflight: + return {"status": "generating", "done": done, "total": total, "version_id": str(version.id)} + + pending = [s for s in segments if s.id not in done_segment_ids] + segment = pending[0] + # 单帧失败次数上限,避免持续失败时无限重试 + failed_for_segment = AITask.objects.filter( + project=project, + task_type=AITask.Type.STORYBOARD, + status=AITask.Status.FAILED, + request_payload__storyboard_segment=str(segment.id), + ).count() + if failed_for_segment >= 2: + last = AITask.objects.filter(project=project, task_type=AITask.Type.STORYBOARD, status=AITask.Status.FAILED, + request_payload__storyboard_segment=str(segment.id)).order_by("-created_at").first() + return {"status": "failed", "done": done, "total": total, "version_id": str(version.id), + "error": last.error_message if last else "storyboard frame failed"} + + 
model_config = get_default_model(ModelConfig.Capability.IMAGE) + task = create_ai_task( + project=project, + user=user, + task_type=AITask.Type.STORYBOARD, + model_config=model_config, + request_payload={ + "model": model_config.name, + "endpoint": model_config.endpoint, + "prompt": build_storyboard_frame_prompt(project, version, segment), + "storyboard_version": str(version.id), + "storyboard_segment": str(segment.id), + }, + ) + threading.Thread( + target=_storyboard_frame_worker, + args=(str(task.id), str(version.id), str(segment.id), str(user.id)), + daemon=True, + ).start() + return {"status": "generating", "done": done, "total": total, "version_id": str(version.id)} + + +def _finalize_storyboard(project, version) -> None: + """全部帧就绪:采用该版本(反采用其余版本)。项目阶段推进由视图负责(与原同步实现一致)。""" + project.storyboard_versions.exclude(id=version.id).update(is_adopted=False) + if not version.is_adopted: + version.is_adopted = True + version.save(update_fields=["is_adopted", "updated_at"]) + + +def _asset_preview_url(asset) -> str: + """资产主文件的可公开访问 URL(已写绝对 URL 优先,否则实时签 TOS GET)。""" + if asset is None: + return "" + primary = asset.files.filter(is_primary=True).first() or asset.files.first() + if primary is None: + return "" + if primary.preview_url: + return primary.preview_url + try: + return TosStorage().presigned_get_url(object_key=primary.object_key) + except Exception: + return "" + + +def _video_reference_images(project, video_segment) -> list[str]: + """为本视频段挑一张视觉参考图:优先本镜故事板帧,兜底已采用商品基础资产。""" + version = ( + project.storyboard_versions.filter(is_adopted=True).order_by("-created_at").first() + or project.storyboard_versions.order_by("-created_at").first() + ) + if version is not None: + frame = ( + version.frames.filter(sort_order=video_segment.sort_order).first() + or version.frames.order_by("sort_order").first() + ) + if frame is not None: + url = _asset_preview_url(frame.asset) + if url: + return [url] + product_group = ( + 
project.base_asset_groups.filter(kind=BaseAssetGroup.Kind.PRODUCT, adopted_asset__isnull=False) + .order_by("-created_at") + .first() + ) + if product_group is not None: + url = _asset_preview_url(product_group.adopted_asset) + if url: + return [url] + return [] def submit_video_segment(*, video_segment: VideoSegment, user, prompt: str) -> VideoSegmentVersion | None: @@ -308,6 +527,20 @@ def submit_video_segment(*, video_segment: VideoSegment, user, prompt: str) -> V if model_config is None: raise ValueError("no active video model configured") project = video_segment.project + + # 衔接:按 sort_order 把视频段绑到对应脚本镜,并织出跟住该镜的提示词。 + scene = None + adopted_script = project.script_versions.filter(is_adopted=True).prefetch_related("segments").first() + if adopted_script is not None: + scene = adopted_script.segments.filter(sort_order=video_segment.sort_order).first() + if scene is not None and video_segment.script_segment_id != scene.id: + video_segment.script_segment = scene + video_segment.save(update_fields=["script_segment", "updated_at"]) + final_prompt = build_video_segment_prompt(project, video_segment, scene, prompt) + + # 参考图:优先用本镜故事板帧,其次商品/人物基础资产,给视频做视觉锚点(衔接故事板→视频)。 + reference_images = _video_reference_images(project, video_segment) + task = create_ai_task( project=project, user=user, @@ -316,22 +549,38 @@ def submit_video_segment(*, video_segment: VideoSegment, user, prompt: str) -> V request_payload={ "model": model_config.name, "endpoint": model_config.endpoint, - "prompt": prompt, + "prompt": final_prompt, "duration": video_segment.target_duration_seconds, "ratio": "9:16", "video_segment_id": str(video_segment.id), + "reference_images": reference_images, }, ) try: provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None) - response = provider.create_video_task( - model=model_config.name, - endpoint=model_config.endpoint, - prompt=prompt, - duration=video_segment.target_duration_seconds, - ratio="9:16", - resolution="720p", - ) + try: + 
response = provider.create_video_task( + model=model_config.name, + endpoint=model_config.endpoint, + prompt=final_prompt, + duration=video_segment.target_duration_seconds, + ratio="9:16", + resolution="720p", + reference_images=reference_images or None, + ) + except Exception: + # 降级:带参考图被拒时退回纯文生视频(文本里已含本镜旁白/画面,衔接不丢) + if not reference_images: + raise + response = provider.create_video_task( + model=model_config.name, + endpoint=model_config.endpoint, + prompt=final_prompt, + duration=video_segment.target_duration_seconds, + ratio="9:16", + resolution="720p", + reference_images=None, + ) task.provider_task_id = str(response.get("id") or response.get("task_id") or "") task.response_payload = response task.status = AITask.Status.SUBMITTED @@ -353,6 +602,12 @@ def submit_video_segment(*, video_segment: VideoSegment, user, prompt: str) -> V def poll_video_segment(*, video_segment: VideoSegment, user) -> VideoSegmentVersion | None: + # 幂等:已完成的段直接回采用版;已失败的段不再 poll。避免对已成功 task 再 poll → 二次建版 / 二次扣费。 + if video_segment.status == VideoSegment.Status.SUCCEEDED: + return video_segment.adopted_version or video_segment.versions.order_by("-created_at").first() + if video_segment.status == VideoSegment.Status.FAILED: + return None + task = video_segment.versions.order_by("-created_at").first() ai_task = None if task: @@ -366,6 +621,12 @@ def poll_video_segment(*, video_segment: VideoSegment, user) -> VideoSegmentVers if ai_task is None: raise ValueError("no active video generation task") + # task 已终态(可能被并发的 worker / 另一次 poll 处理过):直接回已有版,不再调 ARK。 + if ai_task.status == AITask.Status.SUCCEEDED: + return video_segment.versions.filter(task=ai_task).order_by("-created_at").first() + if ai_task.status in (AITask.Status.FAILED, AITask.Status.CANCELLED): + return None + provider = VolcanoArkProvider(base_url=ai_task.model_config.provider.base_url or None) response = provider.poll_video_task(endpoint=ai_task.model_config.endpoint, provider_task_id=ai_task.provider_task_id) remote_status = 
response.get("status") diff --git a/core/backend/apps/billing/urls.py b/core/backend/apps/billing/urls.py index 235b052..1c0595a 100644 --- a/core/backend/apps/billing/urls.py +++ b/core/backend/apps/billing/urls.py @@ -1,9 +1,10 @@ from django.urls import path -from .views import ledgers, recharge, summary +from .views import ledgers, recharge, summary, trend urlpatterns = [ path("summary/", summary, name="billing-summary"), path("ledgers/", ledgers, name="billing-ledgers"), path("recharge/", recharge, name="billing-recharge"), + path("trend/", trend, name="billing-trend"), ] diff --git a/core/backend/apps/billing/views.py b/core/backend/apps/billing/views.py index 3c3c014..191a2cd 100644 --- a/core/backend/apps/billing/views.py +++ b/core/backend/apps/billing/views.py @@ -1,17 +1,33 @@ +from datetime import timedelta from decimal import Decimal, InvalidOperation from django.db import transaction from django.db.models import Sum +from django.db.models.functions import TruncDate +from django.utils import timezone from rest_framework import status from rest_framework.decorators import api_view, permission_classes from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response +from apps.ai.models import AITask from apps.common.api import get_current_team from .models import CreditAccount, CreditLedger from .serializers import CreditAccountSerializer, CreditLedgerSerializer +# AITask.task_type → 账户页「按阶段分布」的 4 个聚合桶 +_STAGE_BUCKET = { + AITask.Type.SCRIPT_GENERATION: "script", + AITask.Type.SCRIPT_OPTIMIZATION: "script", + AITask.Type.PRODUCT_IMAGE: "base", + AITask.Type.PERSON_IMAGE: "base", + AITask.Type.SCENE_IMAGE: "base", + AITask.Type.STORYBOARD: "storyboard", + AITask.Type.VIDEO_SEGMENT: "video", + AITask.Type.EXPORT: "video", +} + @api_view(["GET"]) @permission_classes([IsAuthenticated]) @@ -78,3 +94,81 @@ def recharge(request): }, status=status.HTTP_201_CREATED, ) + + +@api_view(["GET"]) 
+@permission_classes([IsAuthenticated]) +def trend(request): + """账户页消费分析:消费趋势(日/周/月可切)+ 本月按阶段/按项目分布。全部来自真实 CHARGE 流水。""" + team = get_current_team(request.user) + today = timezone.localdate() + rng = request.query_params.get("range", "day") + charges = CreditLedger.objects.filter(team=team, ledger_type=CreditLedger.Type.CHARGE) + + def _daily_amounts(win_start): + rows = ( + charges.filter(created_at__date__gte=win_start) + .annotate(day=TruncDate("created_at")) + .values("day") + .annotate(amount=Sum("amount")) + ) + return {row["day"]: row["amount"] or Decimal("0") for row in rows} + + # 按 range 选窗口与分桶:日=近 14 天 / 周=近 8 周 / 月=近 6 个自然月(缺口补 0) + series = [] + if rng == "week": + monday = today - timedelta(days=today.weekday()) + starts = [monday - timedelta(weeks=(7 - i)) for i in range(8)] + amt_by_day = _daily_amounts(starts[0]) + for s in starts: + total = sum((amt_by_day.get(s + timedelta(days=k), Decimal("0")) for k in range(7)), Decimal("0")) + series.append({"date": s.isoformat(), "label": s.strftime("%m/%d"), "amount": str(total)}) + elif rng == "month": + seq = [] + y, m = today.year, today.month + for _ in range(6): + seq.append((y, m)) + m -= 1 + if m == 0: + m, y = 12, y - 1 + seq.reverse() + amt_by_day = _daily_amounts(today.replace(year=seq[0][0], month=seq[0][1], day=1)) + for yy, mm in seq: + total = sum((v for d, v in amt_by_day.items() if d.year == yy and d.month == mm), Decimal("0")) + series.append({"date": f"{yy}-{mm:02d}-01", "label": f"{mm}月", "amount": str(total)}) + else: + start = today - timedelta(days=13) + amt_by_day = _daily_amounts(start) + for i in range(14): + d = start + timedelta(days=i) + series.append({"date": d.isoformat(), "label": d.strftime("%m/%d"), "amount": str(amt_by_day.get(d, Decimal("0")))}) + + daily = series + total_14d = sum((Decimal(s["amount"]) for s in series), Decimal("0")) + peak = max((Decimal(s["amount"]) for s in series), default=Decimal("0")) + avg = (total_14d / len(series)).quantize(Decimal("0.0001")) if 
series else Decimal("0") + + # 本月按阶段分布(task.task_type → 4 桶) + month_start = today.replace(day=1) + month_charges = charges.filter(created_at__date__gte=month_start).select_related("task") + by_stage = {"script": Decimal("0"), "base": Decimal("0"), "storyboard": Decimal("0"), "video": Decimal("0")} + project_amounts: dict[str, Decimal] = {} + for row in month_charges: + task = row.task + bucket = _STAGE_BUCKET.get(task.task_type) if task else None + if bucket: + by_stage[bucket] += row.amount + pid = str(row.project_id) if row.project_id else None + if pid: + project_amounts[pid] = project_amounts.get(pid, Decimal("0")) + row.amount + + return Response( + { + "daily": daily, + "total_14d": str(total_14d), + "avg": str(avg), + "peak": str(peak), + "by_stage": {k: str(v) for k, v in by_stage.items()}, + "by_project": {k: str(v) for k, v in project_amounts.items()}, + } + ) diff --git a/core/backend/apps/products/views.py b/core/backend/apps/products/views.py index 5f73ef1..9878752 100644 --- a/core/backend/apps/products/views.py +++ b/core/backend/apps/products/views.py @@ -1,8 +1,18 @@ +from pathlib import Path +import uuid + +from django.db import transaction +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.parsers import FormParser, MultiPartParser +from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -from apps.common.api import TeamScopedViewSetMixin +from apps.assets.models import Asset, AssetFile +from apps.assets.storage import TosStorage +from apps.common.api import TeamScopedViewSetMixin, get_current_team -from .models import Product +from .models import Product, ProductImage from .serializers import ProductSerializer @@ -12,3 +22,57 @@ class ProductViewSet(TeamScopedViewSetMixin, ModelViewSet): search_fields = ["title", "brand", "category"] ordering_fields = ["created_at", "updated_at", "title"] + @action(detail=True, methods=["post"], url_path="images", 
@action(detail=True, methods=["delete"], url_path=r"images/(?P<image_id>[^/.]+)")
def delete_image(self, request, pk=None, image_id=None):
    """Detach one image from the product.

    Only the ``ProductImage`` link row is deleted; the ``Asset`` it points
    to is not touched by this method.

    Args:
        request: The DRF request.
        pk: Product primary key (resolved through ``self.get_object()``).
        image_id: Primary key of the ``ProductImage`` row, captured from the
            URL by the ``image_id`` named group.

    Returns:
        404 with ``{"detail": "image not found"}`` when no matching link
        exists (including when ``image_id`` is not a well-formed key),
        otherwise 200 with the re-serialized product.
    """
    # Local import: only this handler needs it.
    from django.core.exceptions import ValidationError

    product = self.get_object()
    try:
        deleted, _ = ProductImage.objects.filter(product=product, id=image_id).delete()
    except (ValidationError, ValueError):
        # The URL pattern accepts any path segment, so a malformed id can
        # reach the ORM; treat it as "no such image" rather than a 500.
        deleted = 0
    if not deleted:
        return Response({"detail": "image not found"}, status=status.HTTP_404_NOT_FOUND)
    # NOTE(review): if the removed row was the primary image, nothing here
    # promotes another one — confirm whether callers rely on a primary.
    product.refresh_from_db()
    return Response(ProductSerializer(product).data, status=status.HTTP_200_OK)
index 0164bd4..242f622 100644 --- a/core/backend/apps/projects/serializers.py +++ b/core/backend/apps/projects/serializers.py @@ -1,5 +1,7 @@ from rest_framework import serializers +from apps.assets.serializers import AssetFileSerializer + from .models import ( BaseAssetGroup, BgmTrack, @@ -18,6 +20,16 @@ from .models import ( ) +def _asset_preview_url(asset) -> str: + """资产主文件的可播放/可显示 URL(主图优先,其次首张),内嵌进各阶段序列化, + 让前端缩略图不再依赖(分页 20 条的)团队 assets 列表解析——团队资产 >20 时新生成的图本会丢。""" + if asset is None: + return "" + files = list(asset.files.all()) + primary = next((f for f in files if f.is_primary), files[0] if files else None) + return AssetFileSerializer().get_preview_url(primary) if primary else "" + + class ProjectStageSerializer(serializers.ModelSerializer): class Meta: model = ProjectStage @@ -27,10 +39,11 @@ class ProjectStageSerializer(serializers.ModelSerializer): class VideoSegmentSerializer(serializers.ModelSerializer): adopted_asset = serializers.SerializerMethodField() + adopted_asset_url = serializers.SerializerMethodField() class Meta: model = VideoSegment - fields = ["id", "sort_order", "target_duration_seconds", "status", "error_message", "adopted_version", "adopted_asset"] + fields = ["id", "sort_order", "target_duration_seconds", "status", "error_message", "adopted_version", "adopted_asset", "adopted_asset_url"] read_only_fields = ["id", "sort_order", "target_duration_seconds", "status", "error_message", "adopted_version"] def get_adopted_asset(self, obj): @@ -38,22 +51,39 @@ class VideoSegmentSerializer(serializers.ModelSerializer): version = obj.adopted_version return str(version.asset_id) if version and version.asset_id else None + def get_adopted_asset_url(self, obj) -> str: + version = obj.adopted_version + return _asset_preview_url(version.asset) if version is not None else "" + class BaseAssetGroupSerializer(serializers.ModelSerializer): candidate_assets = serializers.PrimaryKeyRelatedField(many=True, read_only=True) + adopted_asset_url = 
class TimelineClipSerializer(serializers.ModelSerializer):
    """Serialize a timeline clip together with its playable asset URL.

    ``asset_url`` and ``asset_is_video`` are embedded so the frontend player
    does not have to resolve the clip's asset through the paginated team
    asset list.
    """

    asset_url = serializers.SerializerMethodField()
    asset_is_video = serializers.SerializerMethodField()

    class Meta:
        model = TimelineClip
        fields = ["id", "asset", "asset_url", "asset_is_video", "sort_order", "start_ms", "duration_ms", "trim_start_ms", "trim_end_ms"]
        read_only_fields = ["id", "asset_url", "asset_is_video"]

    def _primary_file(self, obj):
        """Return the clip asset's primary file (or its first file), else ``None``."""
        asset = obj.asset
        if asset is None:
            return None
        files = list(asset.files.all())
        return next((f for f in files if f.is_primary), files[0] if files else None)

    def get_asset_url(self, obj) -> str:
        """Return the preview URL of the clip's asset, or ``""`` when unavailable."""
        # Reuse the module-level helper so every serializer in this file picks
        # the "main file" the same way.
        return _asset_preview_url(obj.asset)

    def get_asset_is_video(self, obj) -> bool:
        """Tell whether the clip's asset is a video.

        True when the asset type is ``"video"``, or when the main file's
        content type contains ``"video/"``.
        """
        asset = obj.asset
        if asset is None:
            return False
        if asset.asset_type == "video":
            return True
        main_file = self._primary_file(obj)
        return bool(main_file and "video/" in (main_file.content_type or ""))
0, 0, 255), "stroke_w": 4, "box": None}, # 朴素白底 + "cinema": {"size": 56, "fill": (255, 255, 255, 255), "stroke": (0, 0, 0, 0), "stroke_w": 0, "box": (0, 0, 0, 165)}, # 影视黑底 + "handwrite": {"size": 60, "fill": (255, 255, 255, 255), "stroke": (250, 93, 25, 255), "stroke_w": 7, "box": None}, # 手写描边(主橙 #fa5d19) + "variety": {"size": 60, "fill": (255, 220, 60, 255), "stroke": (0, 0, 0, 255), "stroke_w": 6, "box": None}, # 综艺暖黄 +} +# 候选 CJK 字体(mac 优先,Linux 兜底) +_FONT_CANDIDATES = [ + "/System/Library/Fonts/STHeiti Medium.ttc", + "/System/Library/Fonts/Hiragino Sans GB.ttc", + "/System/Library/Fonts/PingFang.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", +] +# 转场(UI 选项)→ ffmpeg xfade transition 名。"none" 表示纯拼接。 +XFADE_MAP: dict[str, str] = { + "fade": "fade", "dissolve": "dissolve", "slide": "slideleft", "slideleft": "slideleft", + "slideright": "slideright", "wipe": "wiperight", "wiperight": "wiperight", "circle": "circleopen", "smooth": "smoothleft", +} + + def _download_asset_primary_file(asset, target_path: Path) -> None: primary = asset.files.filter(is_primary=True).first() or asset.files.first() if primary is None: @@ -20,6 +44,174 @@ def _download_asset_primary_file(asset, target_path: Path) -> None: target_path.write_bytes(response.content) +def _load_font(size: int): + from PIL import ImageFont + + for path in _FONT_CANDIDATES: + try: + return ImageFont.truetype(path, size, index=0) + except Exception: # noqa: BLE001 + continue + return ImageFont.load_default() + + +def _wrap_cjk(draw, text: str, font, max_width: int) -> list[str]: + """按像素宽折行(中文逐字、英文整体不强拆)。""" + lines: list[str] = [] + line = "" + for ch in text: + trial = line + ch + if draw.textlength(trial, font=font) <= max_width or not line: + line = trial + else: + lines.append(line) + line = ch + if line: + lines.append(line) + return lines[:3] # 最多 3 行,够长截断 + + +def _render_subtitle_png(text: str, style_key: str, path: Path) -> 
tuple[int, int]: + """把一条字幕渲染成 1080 宽的透明 PNG(居中,带描边/底框),返回 (w,h)。""" + from PIL import Image, ImageDraw + + st = SUBTITLE_STYLES.get(style_key) or SUBTITLE_STYLES["plain"] + canvas_w = 1080 + margin_x = 90 + font = _load_font(st["size"]) + probe = ImageDraw.Draw(Image.new("RGBA", (10, 10))) + lines = _wrap_cjk(probe, (text or "").strip().replace("\n", " "), font, canvas_w - 2 * margin_x) + line_h = st["size"] + 16 + pad = 22 + text_h = line_h * len(lines) + canvas_h = text_h + 2 * pad + img = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0)) + draw = ImageDraw.Draw(img) + if st["box"]: + widest = max((draw.textlength(ln, font=font) for ln in lines), default=0) + box_w = int(widest) + 2 * pad + 24 + x0 = (canvas_w - box_w) // 2 + draw.rounded_rectangle([x0, 0, x0 + box_w, canvas_h], radius=16, fill=st["box"]) + y = pad + for ln in lines: + w = draw.textlength(ln, font=font) + x = (canvas_w - w) / 2 + draw.text((x, y), ln, font=font, fill=st["fill"], + stroke_width=st["stroke_w"], stroke_fill=st["stroke"]) + y += line_h + img.save(path) + return canvas_w, canvas_h + + +def _clip_specs(clips) -> list[dict]: + """每个 clip 的入点/出点/时长(秒),考虑 trim。""" + specs = [] + for clip in clips: + ts = (clip.trim_start_ms or 0) / 1000.0 + te = (clip.trim_end_ms / 1000.0) if clip.trim_end_ms else ts + (clip.duration_ms or 15000) / 1000.0 + specs.append({"ts": ts, "te": te, "dur": max(0.1, te - ts)}) + return specs + + +def _output_starts(specs: list[dict], xfade: float) -> tuple[list[float], float]: + """每个 clip 在输出时间轴上的起点 + 输出总时长(xfade 会压缩总长)。""" + starts, cum = [], 0.0 + for i, s in enumerate(specs): + starts.append(0.0 if i == 0 else max(0.0, cum - i * xfade)) + cum += s["dur"] + total = sum(s["dur"] for s in specs) - (len(specs) - 1) * xfade if xfade > 0 else sum(s["dur"] for s in specs) + return starts, max(0.1, total) + + +def _build_export_command(*, n: int, specs: list[dict], starts: list[float], total: float, + transition: str, sub_overlays: list[tuple[str, float, float]], 
+ bgm_name: str | None, bgm_volume: float) -> list[str]: + parts: list[str] = [] + for i, s in enumerate(specs): + parts.append( + f"[{i}:v]trim=start={s['ts']:.3f}:end={s['te']:.3f},setpts=PTS-STARTPTS," + "scale=1080:1920:force_original_aspect_ratio=decrease," + "pad=1080:1920:(ow-iw)/2:(oh-ih)/2,setsar=1,fps=30,format=yuv420p[v" + str(i) + "]" + ) + xname = XFADE_MAP.get(transition or "none") + if xname and n > 1: + prev = "v0" + for i in range(1, n): + out = "vbase" if i == n - 1 else f"x{i}" + parts.append(f"[{prev}][v{i}]xfade=transition={xname}:duration=0.5:offset={starts[i]:.3f}[{out}]") + prev = out + else: + parts.append("".join(f"[v{i}]" for i in range(n)) + f"concat=n={n}:v=1:a=0[vbase]") + + # 字幕:每条一张 PNG,按时间窗 overlay 到底部居中(本机 ffmpeg 无 libass,用图片烧入) + sub_base = n + (1 if bgm_name else 0) + vlabel = "vbase" + for j, (_png, start, end) in enumerate(sub_overlays): + idx = sub_base + j + out = "vout" if j == len(sub_overlays) - 1 else f"ov{j}" + parts.append( + f"[{vlabel}][{idx}:v]overlay=x=(W-w)/2:y=H-h-150:enable='between(t,{start:.3f},{end:.3f})'[{out}]" + ) + vlabel = out + if bgm_name: + parts.append(f"[{n}:a]volume={bgm_volume:.3f},atrim=0:{total:.3f},asetpts=PTS-STARTPTS[aout]") + + cmd = ["ffmpeg", "-y"] + for i in range(n): + cmd += ["-i", f"clip{i}.mp4"] + if bgm_name: + cmd += ["-stream_loop", "-1", "-i", bgm_name] + for png, _s, _e in sub_overlays: + cmd += ["-loop", "1", "-i", png] + cmd += ["-filter_complex", ";".join(parts), "-map", f"[{vlabel}]"] + if bgm_name: + cmd += ["-map", "[aout]"] + cmd += ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "-preset", "veryfast"] + if bgm_name: + cmd += ["-c:a", "aac", "-b:a", "192k"] + cmd += ["-t", f"{total:.3f}", "-movflags", "+faststart", "output.mp4"] + return cmd + + +def run_export_job_in_thread(export_job_id: str) -> None: + """后台线程跑拼接导出。本机无 Celery worker(dev),故事板/视频已用线程模式,导出沿用同一打法: + HTTP 秒回,真实 ffmpeg 拼接在线程里跑,前端轮询 poll-export 看进度 / 取成片。失败落库供轮询上报。""" + + def _worker() -> None: + try: + 
run_export_job(export_job_id) + except Exception as exc: # noqa: BLE001 — 失败落库,poll-export 据此上报 + job = ExportJob.objects.filter(id=export_job_id).first() + if job is not None: + job.status = ExportJob.Status.FAILED + job.error_message = str(exc) + job.save(update_fields=["status", "error_message", "updated_at"]) + finally: + connections.close_all() + + threading.Thread(target=_worker, daemon=True).start() + + +def _subtitle_cues(timeline, project, specs, starts, total) -> list[tuple[float, float, str]]: + """字幕条目:文本取 SubtitleTrack.content,空则回退脚本旁白;时间按输出布局(对 xfade 也对齐)。""" + track = timeline.subtitle_tracks.filter(enabled=True).first() or timeline.subtitle_tracks.first() + if track is None or track.enabled is False: + return [] + texts: list[str] = [str((c or {}).get("text", "")) for c in (track.content or [])] + if not any(t.strip() for t in texts): + script = project.script_versions.filter(is_adopted=True).prefetch_related("segments").first() + if script is not None: + texts = [seg.narration for seg in script.segments.all().order_by("sort_order")] + cues: list[tuple[float, float, str]] = [] + for i in range(len(specs)): + text = texts[i] if i < len(texts) else "" + start = starts[i] + end = starts[i + 1] if i + 1 < len(starts) else total + if text and text.strip(): + cues.append((start, max(start + 0.5, end), text)) + return cues + + def run_export_job(export_job_id: str) -> ExportJob: export_job = ExportJob.objects.select_related("timeline", "timeline__project").get(id=export_job_id) timeline = export_job.timeline @@ -32,43 +224,45 @@ def run_export_job(export_job_id: str) -> ExportJob: export_job.progress = 10 export_job.save(update_fields=["status", "progress", "updated_at"]) + transition = str((timeline.metadata or {}).get("transition", {}).get("type", "none")) + bgm_track = timeline.bgm_tracks.select_related("asset").first() + subtitle_track = timeline.subtitle_tracks.filter(enabled=True).first() + style_key = str((subtitle_track.style or {}).get("key", 
"plain")) if subtitle_track else "plain" + + specs = _clip_specs(clips) + xfade = 0.5 if XFADE_MAP.get(transition) and len(clips) > 1 else 0.0 + starts, total = _output_starts(specs, xfade) + with tempfile.TemporaryDirectory(prefix="airshelf-export-") as tmp_dir: tmp = Path(tmp_dir) - concat_file = tmp / "concat.txt" - downloaded_files: list[Path] = [] for index, clip in enumerate(clips): - clip_path = tmp / f"clip-{index}.mp4" - _download_asset_primary_file(clip.asset, clip_path) - downloaded_files.append(clip_path) - concat_file.write_text( - "\n".join(f"file '{path.as_posix()}'" for path in downloaded_files), - encoding="utf-8", + _download_asset_primary_file(clip.asset, tmp / f"clip{index}.mp4") + + bgm_name = None + if bgm_track is not None and bgm_track.asset_id: + primary = bgm_track.asset.files.filter(is_primary=True).first() or bgm_track.asset.files.first() + suffix = Path(primary.object_key).suffix or ".mp3" if primary else ".mp3" + bgm_name = f"bgm{suffix}" + _download_asset_primary_file(bgm_track.asset, tmp / bgm_name) + + cues = _subtitle_cues(timeline, project, specs, starts, total) + sub_overlays: list[tuple[str, float, float]] = [] + for i, (start, end, text) in enumerate(cues): + png = f"sub{i}.png" + _render_subtitle_png(text, style_key, tmp / png) + sub_overlays.append((png, start, end)) + + export_job.progress = 35 + export_job.save(update_fields=["progress", "updated_at"]) + + command = _build_export_command( + n=len(clips), specs=specs, starts=starts, total=total, transition=transition, + sub_overlays=sub_overlays, bgm_name=bgm_name, bgm_volume=(bgm_track.volume / 100.0) if bgm_track else 1.0, ) + proc = subprocess.run(command, cwd=str(tmp), capture_output=True) + if proc.returncode != 0: + raise RuntimeError(f"ffmpeg export failed: {proc.stderr.decode('utf-8', 'ignore')[-1200:]}") output_path = tmp / "output.mp4" - command = [ - "ffmpeg", - "-y", - "-f", - "concat", - "-safe", - "0", - "-i", - str(concat_file), - "-vf", - 
"scale=1080:1920:force_original_aspect_ratio=decrease,pad=1080:1920:(ow-iw)/2:(oh-ih)/2", - "-r", - "30", - "-c:v", - "libx264", - "-pix_fmt", - "yuv420p", - "-c:a", - "aac", - "-movflags", - "+faststart", - str(output_path), - ] - subprocess.run(command, check=True, capture_output=True) export_job.progress = 85 export_job.save(update_fields=["progress", "updated_at"]) diff --git a/core/backend/apps/projects/views.py b/core/backend/apps/projects/views.py index 7e775e8..71cdf0f 100644 --- a/core/backend/apps/projects/views.py +++ b/core/backend/apps/projects/views.py @@ -1,6 +1,11 @@ +import logging +from pathlib import Path +import uuid + from django.db import transaction from rest_framework import status from rest_framework.decorators import action +from rest_framework.parsers import FormParser, MultiPartParser from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet @@ -8,13 +13,29 @@ from apps.ai.services import ( create_export_job, generate_base_asset, generate_project_script, - generate_storyboard, + generate_storyboard_frame, poll_video_segment, + submit_storyboard, submit_video_segment, ) +from apps.assets.models import Asset, AssetFile +from apps.assets.serializers import AssetFileSerializer +from apps.assets.storage import TosStorage from apps.common.api import TeamScopedViewSetMixin -from .models import BaseAssetGroup, Project, ProjectStage, ScriptVersion, Timeline, TimelineClip, VideoSegment +from .models import ( + BaseAssetGroup, + BgmTrack, + ExportJob, + Project, + ProjectStage, + ScriptVersion, + SubtitleTrack, + Timeline, + TimelineClip, + VideoSegment, + VideoSegmentVersion, +) from .serializers import ( BaseAssetGroupSerializer, ExportJobSerializer, @@ -23,8 +44,32 @@ from .serializers import ( StoryboardVersionSerializer, VideoSegmentVersionSerializer, ) +from .services.export import run_export_job_in_thread from .services.pipeline import STAGE_ORDER -from .tasks import poll_video_segment_task, 
def _store_uploaded_asset(*, team, user, upload, asset_type: str, category: str, name: str) -> Asset:
    """Persist an uploaded file and register it as an asset.

    The file is pushed to object storage under
    ``teams/<team id>/uploads/<asset id><suffix>``, then an ``Asset`` and its
    primary ``AssetFile`` are created. Shared by the video-segment and BGM
    upload endpoints.

    Args:
        team: Owning team.
        user: Uploading user, stored as ``created_by``.
        upload: The uploaded file (``name``, ``file``, ``content_type``).
        asset_type: Asset type; also picks the fallback suffix when the file
            name has none (``.mp4`` for video, ``.mp3`` otherwise).
        category: Asset category.
        name: Display name of the new asset.

    Returns:
        The newly created ``Asset``.
    """
    extension = Path(upload.name).suffix.lower()
    if not extension:
        extension = ".mp4" if asset_type == Asset.Type.VIDEO else ".mp3"

    new_id = uuid.uuid4()
    storage = TosStorage()
    stored = storage.upload_fileobj(
        fileobj=upload.file,
        object_key=f"teams/{team.id}/uploads/{new_id}{extension}",
        content_type=upload.content_type or "application/octet-stream",
    )

    asset = Asset.objects.create(
        id=new_id,
        team=team,
        created_by=user,
        name=name,
        asset_type=asset_type,
        source=Asset.Source.UPLOAD,
        category=category,
    )
    AssetFile.objects.create(
        asset=asset,
        object_key=stored.object_key,
        bucket=stored.bucket,
        content_type=stored.content_type,
        size_bytes=stored.size_bytes,
        is_primary=True,
    )
    return asset
@action(detail=True, methods=["post"], url_path="poll-storyboard")
def poll_storyboard_action(self, request, pk=None):
    """Advance asynchronous storyboard generation by one step.

    Each call delegates to ``generate_storyboard_frame``. When its result
    reports ``status == "succeeded"``, the STORYBOARD stage is marked
    succeeded and the project moves on to the VIDEO stage.

    Returns:
        The result dict from ``generate_storyboard_frame``; HTTP 200 once the
        storyboard has succeeded, HTTP 202 otherwise.
    """
    project = self.get_object()
    result = generate_storyboard_frame(project=project, user=request.user)
    finished = result.get("status") == "succeeded"
    if finished:
        # Stage and project must move together: without the transaction a
        # failure between the two saves leaves the stage succeeded while the
        # project is still parked on the storyboard stage.
        with transaction.atomic():
            stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.STORYBOARD)
            stage.status = ProjectStage.Status.SUCCEEDED
            stage.save(update_fields=["status", "updated_at"])
            project.current_stage = ProjectStage.Stage.VIDEO
            project.status = Project.Status.VIDEOING
            project.save(update_fields=["current_stage", "status", "updated_at"])
    http_status = status.HTTP_200_OK if finished else status.HTTP_202_ACCEPTED
    return Response(result, status=http_status)
@action(detail=True, methods=["post", "get"], url_path="poll-export")
def poll_export_action(self, request, pk=None):
    """Report the state of the project's most recent export job.

    Returns the job's status, progress, output asset id, output URL and
    error message. As a side effect, the EXPORT stage is flipped to
    SUCCEEDED or FAILED the first time the job is seen in that state.
    When the project has no timeline or no export job yet, a
    ``not_started`` payload is returned.
    """
    project = self.get_object()
    timeline = getattr(project, "timeline", None)
    export_job = None
    if timeline is not None:
        export_job = timeline.export_jobs.order_by("-created_at").first()
    if export_job is None:
        return Response({"status": "not_started", "progress": 0, "output_url": ""})

    # Resolve the finished file's URL, if the job has produced one.
    output = export_job.output_asset
    output_url = ""
    if output is not None:
        main_file = output.files.filter(is_primary=True).first() or output.files.first()
        if main_file is not None:
            output_url = AssetFileSerializer().get_preview_url(main_file)

    # Mirror a terminal job state onto the EXPORT stage, writing only on change.
    job_status = export_job.status
    if job_status == ExportJob.Status.SUCCEEDED:
        export_stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.EXPORT)
        if export_stage.status != ProjectStage.Status.SUCCEEDED:
            export_stage.status = ProjectStage.Status.SUCCEEDED
            export_stage.save(update_fields=["status", "updated_at"])
    elif job_status == ExportJob.Status.FAILED:
        export_stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.EXPORT)
        if export_stage.status != ProjectStage.Status.FAILED:
            export_stage.status = ProjectStage.Status.FAILED
            export_stage.error_message = export_job.error_message
            export_stage.save(update_fields=["status", "error_message", "updated_at"])

    payload = {
        "status": export_job.status,
        "progress": export_job.progress,
        "output_asset": str(output.id) if output is not None else None,
        "output_url": output_url,
        "error_message": export_job.error_message,
    }
    return Response(payload)
@action(detail=True, methods=["post"], url_path="upload-bgm", parser_classes=[MultiPartParser, FormParser])
@transaction.atomic
def upload_bgm_action(self, request, pk=None):
    """Upload a BGM audio file and make it the timeline's only BGM track.

    The file is stored as an audio ``Asset``; any existing BGM tracks on the
    timeline are removed and a new one is created with the requested volume
    (0-100, default 60). The timeline is created when the project has none.

    Returns:
        400 when no file was uploaded or ``volume`` is not an integer,
        otherwise 201 with the re-fetched, re-serialized project.
    """
    project = self.get_object()
    # Read from FILES (not data): a plain text field named "file" must be
    # rejected as "no file" instead of crashing later on ``upload.name``.
    upload = request.FILES.get("file")
    if upload is None:
        return Response({"detail": "no file uploaded"}, status=status.HTTP_400_BAD_REQUEST)

    # Validate the volume before touching storage so a bad request does not
    # leave an uploaded object behind.
    try:
        volume = int(request.data.get("volume") or 60)
    except (TypeError, ValueError):
        return Response({"detail": "volume must be an integer"}, status=status.HTTP_400_BAD_REQUEST)

    timeline, _ = Timeline.objects.get_or_create(
        project=project, defaults={"name": f"{project.name} Timeline", "duration_seconds": 60}
    )
    asset = _store_uploaded_asset(
        team=project.team, user=request.user, upload=upload,
        asset_type=Asset.Type.AUDIO, category=Asset.Category.UPLOAD, name=f"{project.name}-BGM",
    )
    timeline.bgm_tracks.all().delete()
    BgmTrack.objects.create(timeline=timeline, asset=asset, volume=max(0, min(100, volume)), start_ms=0)
    # Re-fetch so the response reflects the new track rather than the
    # relations loaded at the start of the request.
    return Response(ProjectSerializer(self.get_object()).data, status=status.HTTP_201_CREATED)
@action(detail=True, methods=["post", "put"], url_path="save-timeline")
@transaction.atomic
def save_timeline_action(self, request, pk=None):
    """Persist the editor's draft state of the timeline in one call.

    Recognised keys of the request body (all optional):

    * ``clips`` (list): replaces every clip. Entries need an ``asset`` id
      belonging to the project's team; entries that are not objects, lack an
      asset, or reference an unknown asset are skipped. ``duration_ms``
      defaults to 15000 and ``trim_start_ms`` to 0.
    * ``subtitle`` (dict): ``enabled``, ``style_key`` and ``content``.
    * ``bgm`` (dict): ``clear`` removes the BGM tracks, otherwise ``volume``
      (clamped to 0-100) updates the first track.
    * ``transition`` (dict) and ``draft`` (dict): stored in the timeline
      metadata.

    All numeric input is validated before anything is written, so a
    malformed request never leaves the timeline half-replaced.

    Returns:
        400 when a duration, trim or volume value is not an integer,
        otherwise 200 with the re-fetched, re-serialized project.
    """
    project = self.get_object()
    data = request.data

    def _bad_request(detail):
        return Response({"detail": detail}, status=status.HTTP_400_BAD_REQUEST)

    # ---- Phase 1: parse and validate, no writes yet -----------------------
    clips = data.get("clips")
    parsed_clips = None
    if isinstance(clips, list):
        parsed_clips = []
        for index, clip in enumerate(clips):
            if not isinstance(clip, dict) or not clip.get("asset"):
                continue
            try:
                # Normalise the id so lookup and comparison are case-insensitive
                # and a malformed id is skipped instead of failing the query.
                asset_uuid = uuid.UUID(str(clip["asset"]))
            except ValueError:
                continue
            try:
                trim_end = clip.get("trim_end_ms")
                parsed_clips.append({
                    "asset_id": asset_uuid,
                    "sort_order": index,  # position in the submitted list
                    "duration_ms": int(clip.get("duration_ms") or 15000),
                    "trim_start_ms": int(clip.get("trim_start_ms") or 0),
                    "trim_end_ms": int(trim_end) if trim_end not in (None, "") else None,
                })
            except (TypeError, ValueError):
                return _bad_request("clip durations and trims must be integers")

    bgm = data.get("bgm")
    bgm_volume = None
    if isinstance(bgm, dict) and not bgm.get("clear") and bgm.get("volume") is not None:
        try:
            bgm_volume = max(0, min(100, int(bgm["volume"])))
        except (TypeError, ValueError):
            return _bad_request("bgm volume must be an integer")

    # ---- Phase 2: apply ---------------------------------------------------
    timeline, _ = Timeline.objects.get_or_create(
        project=project, defaults={"name": f"{project.name} Timeline", "duration_seconds": 60}
    )

    if parsed_clips is not None:
        # Only assets owned by the project's team may be placed on the timeline.
        known_ids = {
            str(asset_id)
            for asset_id in Asset.objects.filter(
                team=project.team, id__in=[c["asset_id"] for c in parsed_clips]
            ).values_list("id", flat=True)
        }
        timeline.clips.all().delete()
        start_ms = 0
        for clip in parsed_clips:
            if str(clip["asset_id"]) not in known_ids:
                continue
            TimelineClip.objects.create(
                timeline=timeline, asset_id=clip["asset_id"], sort_order=clip["sort_order"],
                start_ms=start_ms, duration_ms=clip["duration_ms"],
                trim_start_ms=clip["trim_start_ms"], trim_end_ms=clip["trim_end_ms"],
            )
            start_ms += clip["duration_ms"]
        timeline.duration_seconds = max(1, round(start_ms / 1000))

    subtitle = data.get("subtitle")
    if isinstance(subtitle, dict):
        track = timeline.subtitle_tracks.first() or SubtitleTrack(timeline=timeline)
        track.enabled = bool(subtitle.get("enabled", True))
        style = dict(track.style or {})
        if subtitle.get("style_key"):
            style["key"] = subtitle["style_key"]
        track.style = style
        if isinstance(subtitle.get("content"), list):
            track.content = subtitle["content"]
        track.save()

    if isinstance(bgm, dict):
        if bgm.get("clear"):
            timeline.bgm_tracks.all().delete()
        elif bgm_volume is not None:
            track = timeline.bgm_tracks.first()
            if track is not None:
                track.volume = bgm_volume
                track.save(update_fields=["volume", "updated_at"])

    metadata = dict(timeline.metadata or {})
    if isinstance(data.get("transition"), dict):
        metadata["transition"] = {"type": str(data["transition"].get("type", "none"))}
    if isinstance(data.get("draft"), dict):
        metadata["draft"] = data["draft"]
    timeline.metadata = metadata
    timeline.save(update_fields=["metadata", "duration_seconds", "updated_at"])
    # Re-fetch so the response reflects the rows written above.
    return Response(ProjectSerializer(self.get_object()).data)
index 4f1994a..9313837 100644 --- a/core/backend/requirements.txt +++ b/core/backend/requirements.txt @@ -10,3 +10,4 @@ requests>=2.31,<3.0 gunicorn>=21.2,<23.0 whitenoise>=6.6,<7.0 +Pillow>=10.0