import logging
import uuid
from pathlib import Path

from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import transaction
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.parsers import FormParser, MultiPartParser
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from apps.ai.services import (
    create_export_job,
    generate_base_asset,
    generate_project_script,
    generate_storyboard_frame,
    poll_video_segment,
    submit_storyboard,
    submit_video_segment,
)
from apps.assets.models import Asset, AssetFile
from apps.assets.serializers import AssetFileSerializer
from apps.assets.storage import TosStorage
from apps.common.api import TeamScopedViewSetMixin

from .models import (
    BaseAssetGroup,
    BgmTrack,
    ExportJob,
    Project,
    ProjectStage,
    ScriptVersion,
    SubtitleTrack,
    Timeline,
    TimelineClip,
    VideoSegment,
    VideoSegmentVersion,
)
from .serializers import (
    BaseAssetGroupSerializer,
    ExportJobSerializer,
    ProjectSerializer,
    ScriptVersionSerializer,
    StoryboardVersionSerializer,
    VideoSegmentVersionSerializer,
)
from .services.export import run_export_job_in_thread
from .services.pipeline import STAGE_ORDER
from .tasks import poll_video_segment_task

logger = logging.getLogger(__name__)

# The BASE_ASSETS stage is considered done once this many distinct BaseAssetGroup kinds
# have an adopted asset. NOTE(review): presumably one per BaseAssetGroup.Kind — confirm.
_REQUIRED_ADOPTED_BASE_ASSET_KINDS = 3
# Clip length used when the client omits duration_ms (or sends a non-positive one).
_DEFAULT_CLIP_DURATION_MS = 15000
# BGM volume used when the upload form omits "volume"; volumes are clamped to 0..100.
_DEFAULT_BGM_VOLUME = 60
# Initial duration_seconds given to a freshly created Timeline.
_NEW_TIMELINE_DURATION_SECONDS = 60


def _get_or_none(queryset, **filters):
    """Return the first row of *queryset* matching *filters*, or ``None``.

    A malformed filter value (for example a non-UUID string compared with a UUID key) makes
    Django raise ``ValueError`` / ``TypeError`` / ``ValidationError`` instead of finding no row.
    Those are treated as "not found" so a bad client id yields a 4xx rather than a 500.
    """
    try:
        return queryset.filter(**filters).first()
    except (TypeError, ValueError, DjangoValidationError):
        return None


def _parse_int(value, default, *, field: str):
    """Coerce a client-supplied value to ``int``.

    ``None`` and ``""`` yield *default* (which may itself be ``None``). Any other value that
    ``int()`` rejects raises a DRF ``ValidationError`` keyed by *field*, which DRF renders
    as HTTP 400.
    """
    if value is None or value == "":
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        raise ValidationError({field: "must be an integer"}) from None


def _parse_uuid(value):
    """Return *value* as a ``uuid.UUID``, or ``None`` if it is empty or not a valid UUID."""
    if not value:
        return None
    try:
        return uuid.UUID(str(value))
    except ValueError:
        return None


def _mark_stage(project: Project, stage, new_status) -> ProjectStage:
    """Set the (*project*, *stage*) ProjectStage row to *new_status*, creating the row if missing."""
    project_stage, _ = ProjectStage.objects.get_or_create(project=project, stage=stage)
    project_stage.status = new_status
    project_stage.save(update_fields=["status", "updated_at"])
    return project_stage


def _move_project(project: Project, stage, project_status) -> None:
    """Persist *project*'s current pipeline stage and overall status."""
    project.current_stage = stage
    project.status = project_status
    project.save(update_fields=["current_stage", "status", "updated_at"])


def _get_or_create_timeline(project: Project) -> Timeline:
    """Return *project*'s Timeline, creating it with a default name and duration when absent."""
    timeline, _ = Timeline.objects.get_or_create(
        project=project,
        defaults={"name": f"{project.name} Timeline", "duration_seconds": _NEW_TIMELINE_DURATION_SECONDS},
    )
    return timeline


def _store_uploaded_asset(*, team, user, upload, asset_type: str, category: str, name: str) -> Asset:
    """Store an uploaded file in TOS and record it as an Asset with one primary AssetFile.

    Shared by the upload-video-segment and upload-BGM actions. The object key is
    ``teams/<team id>/uploads/<new asset uuid><suffix>``. The suffix is taken from the
    upload's file name; without one it defaults to ``.mp4`` for video assets and ``.mp3``
    otherwise.
    """
    suffix = Path(upload.name).suffix.lower() or (".mp4" if asset_type == Asset.Type.VIDEO else ".mp3")
    asset_id = uuid.uuid4()
    object_key = f"teams/{team.id}/uploads/{asset_id}{suffix}"
    stored = TosStorage().upload_fileobj(
        fileobj=upload.file,
        object_key=object_key,
        content_type=upload.content_type or "application/octet-stream",
    )
    asset = Asset.objects.create(
        id=asset_id,
        team=team,
        created_by=user,
        name=name,
        asset_type=asset_type,
        source=Asset.Source.UPLOAD,
        category=category,
    )
    AssetFile.objects.create(
        asset=asset,
        object_key=stored.object_key,
        bucket=stored.bucket,
        content_type=stored.content_type,
        size_bytes=stored.size_bytes,
        is_primary=True,
    )
    return asset


def promote_base_asset_stage_if_ready(project: Project) -> bool:
    """Close the BASE_ASSETS stage and move *project* to STORYBOARD once enough kinds are adopted.

    Counts the distinct ``kind`` values among the project's base-asset groups that have an
    adopted asset. Below the threshold nothing is written and ``False`` is returned.
    Otherwise the BASE_ASSETS stage is marked SUCCEEDED, the project is set to
    STORYBOARD / STORYBOARDING, and ``True`` is returned. The writes happen on every call
    that meets the threshold, including calls for an already-promoted project.
    """
    adopted_kind_count = (
        project.base_asset_groups.filter(adopted_asset__isnull=False).values("kind").distinct().count()
    )
    if adopted_kind_count < _REQUIRED_ADOPTED_BASE_ASSET_KINDS:
        return False
    _mark_stage(project, ProjectStage.Stage.BASE_ASSETS, ProjectStage.Status.SUCCEEDED)
    _move_project(project, ProjectStage.Stage.STORYBOARD, Project.Status.STORYBOARDING)
    return True


def _replace_timeline_clips(*, timeline: Timeline, team, clips: list) -> int:
    """Replace every clip of *timeline* with *clips* laid end to end; return the total length in ms.

    Entries are skipped when they are not dicts, have an empty or non-UUID ``asset``, or
    reference an asset outside *team*. ``sort_order`` keeps each clip's index in the
    submitted list. All numeric fields are validated before the existing clips are deleted,
    so a 400 leaves the timeline untouched.
    """
    # Pass 1: entries that name a syntactically valid asset id, with their original index.
    candidates = []
    for index, clip in enumerate(clips):
        if not isinstance(clip, dict):
            continue
        asset_uuid = _parse_uuid(clip.get("asset"))
        if asset_uuid is not None:
            candidates.append((index, clip, asset_uuid))

    # One query decides team ownership; compare as strings whatever type values_list yields.
    owned_ids = {
        str(pk)
        for pk in Asset.objects.filter(team=team, id__in=[u for _, _, u in candidates]).values_list(
            "id", flat=True
        )
    }

    # Pass 2: validate numbers for the surviving entries before touching the database.
    rows = []
    for index, clip, asset_uuid in candidates:
        if str(asset_uuid) not in owned_ids:
            continue
        duration = _parse_int(clip.get("duration_ms"), _DEFAULT_CLIP_DURATION_MS, field="duration_ms")
        if duration <= 0:
            # A zero or negative length would make later clips start before earlier ones.
            duration = _DEFAULT_CLIP_DURATION_MS
        rows.append(
            (
                index,
                asset_uuid,
                duration,
                _parse_int(clip.get("trim_start_ms"), 0, field="trim_start_ms"),
                _parse_int(clip.get("trim_end_ms"), None, field="trim_end_ms"),
            )
        )

    timeline.clips.all().delete()
    start_ms = 0
    for index, asset_uuid, duration, trim_start_ms, trim_end_ms in rows:
        TimelineClip.objects.create(
            timeline=timeline,
            asset_id=asset_uuid,
            sort_order=index,
            start_ms=start_ms,
            duration_ms=duration,
            trim_start_ms=trim_start_ms,
            trim_end_ms=trim_end_ms,
        )
        start_ms += duration
    return start_ms


def _apply_subtitle(timeline: Timeline, subtitle: dict) -> None:
    """Update (or create) the timeline's first subtitle track from the *subtitle* payload.

    ``enabled`` defaults to True. ``style_key`` is merged into the existing style dict under
    ``"key"``, and ``content`` replaces the track content only when it is a list.
    """
    track = timeline.subtitle_tracks.first() or SubtitleTrack(timeline=timeline)
    track.enabled = bool(subtitle.get("enabled", True))
    style = dict(track.style or {})
    if subtitle.get("style_key"):
        style["key"] = subtitle["style_key"]
    track.style = style
    if isinstance(subtitle.get("content"), list):
        track.content = subtitle["content"]
    track.save()


def _apply_bgm(timeline: Timeline, bgm: dict) -> None:
    """Apply the *bgm* payload: ``clear`` removes all BGM tracks, else ``volume`` updates the first one."""
    track = timeline.bgm_tracks.first()
    if bgm.get("clear"):
        timeline.bgm_tracks.all().delete()
    elif track is not None and bgm.get("volume") is not None:
        volume = _parse_int(bgm["volume"], track.volume, field="volume")
        track.volume = max(0, min(100, volume))
        track.save(update_fields=["volume", "updated_at"])


class ProjectViewSet(TeamScopedViewSetMixin, ModelViewSet):
    """Team-scoped CRUD for projects plus the pipeline actions.

    The actions cover script, base assets, storyboard, video segments, export, uploads
    and timeline drafts.
    """

    queryset = Project.objects.select_related("product", "timeline").prefetch_related(
        "stages",
        "video_segments",
        "video_segments__adopted_version__asset__files",
        "script_versions",
        "script_versions__segments",
        "base_asset_groups",
        "base_asset_groups__adopted_asset__files",
        "base_asset_groups__candidate_assets__files",
        "storyboard_versions",
        "storyboard_versions__frames__asset__files",
        "timeline__clips__asset__files",
    ).all()
    serializer_class = ProjectSerializer
    search_fields = ["name", "product__title"]
    ordering_fields = ["created_at", "updated_at", "name"]

    @transaction.atomic
    def perform_create(self, serializer):
        """Create the project, one ProjectStage per STAGE_ORDER entry, and four 15-second video segments."""
        project = serializer.save(team=self.get_team(), created_by=self.request.user)
        for stage in STAGE_ORDER:
            ProjectStage.objects.create(project=project, stage=stage)
        for index in range(4):
            VideoSegment.objects.create(project=project, sort_order=index, target_duration_seconds=15)

    @action(detail=True, methods=["post"], url_path="generate-script")
    def generate_script(self, request, pk=None):
        """Generate a new script version from ``prompt`` and ``selling_point_ids``; return it with 201."""
        project = self.get_object()
        script = generate_project_script(
            project=project,
            user=request.user,
            user_prompt=request.data.get("prompt", ""),
            selling_point_ids=request.data.get("selling_point_ids") or [],
        )
        return Response(ScriptVersionSerializer(script).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="adopt-script")
    @transaction.atomic
    def adopt_script(self, request, pk=None):
        """Adopt ``script_version_id`` as the project's only adopted script and advance to BASE_ASSETS.

        Returns 404 when the script version does not exist in this project or the id is malformed.
        """
        project = self.get_object()
        script = _get_or_none(
            ScriptVersion.objects.select_for_update(),
            project=project,
            id=request.data.get("script_version_id"),
        )
        if script is None:
            return Response({"detail": "script version not found"}, status=status.HTTP_404_NOT_FOUND)
        ScriptVersion.objects.filter(project=project).update(is_adopted=False)
        script.is_adopted = True
        script.save(update_fields=["is_adopted", "updated_at"])
        _mark_stage(project, ProjectStage.Stage.SCRIPT, ProjectStage.Status.SUCCEEDED)
        _move_project(project, ProjectStage.Stage.BASE_ASSETS, Project.Status.ASSETING)
        return Response(ScriptVersionSerializer(script).data)

    @action(detail=True, methods=["post"], url_path="generate-base-asset")
    def generate_base_asset_action(self, request, pk=None):
        """Generate candidates for one base-asset ``kind``.

        Marks BASE_ASSETS as NEEDS_REVIEW, then re-checks promotion. Returns 400 for an
        unknown kind.
        """
        project = self.get_object()
        kind = request.data.get("kind")
        if kind not in BaseAssetGroup.Kind.values:
            return Response({"detail": "invalid base asset kind"}, status=status.HTTP_400_BAD_REQUEST)
        group = generate_base_asset(project=project, user=request.user, kind=kind, prompt=request.data.get("prompt", ""))
        _mark_stage(project, ProjectStage.Stage.BASE_ASSETS, ProjectStage.Status.NEEDS_REVIEW)
        promote_base_asset_stage_if_ready(project)
        return Response(BaseAssetGroupSerializer(group).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="adopt-base-asset")
    @transaction.atomic
    def adopt_base_asset(self, request, pk=None):
        """Adopt ``asset_id`` for base-asset group ``group_id``, then re-check stage promotion.

        Returns 404 for an unknown group and 400 when the asset is not one of the group's candidates.
        """
        project = self.get_object()
        group = _get_or_none(
            BaseAssetGroup.objects.select_for_update(),
            project=project,
            id=request.data.get("group_id"),
        )
        if group is None:
            return Response({"detail": "base asset group not found"}, status=status.HTTP_404_NOT_FOUND)
        candidate = _get_or_none(group.candidate_assets.all(), id=request.data.get("asset_id"))
        if candidate is None:
            return Response({"detail": "asset is not a candidate"}, status=status.HTTP_400_BAD_REQUEST)
        group.adopted_asset = candidate
        group.save(update_fields=["adopted_asset", "updated_at"])
        promote_base_asset_stage_if_ready(project)
        return Response(BaseAssetGroupSerializer(group).data)

    @action(detail=True, methods=["post"], url_path="generate-storyboard")
    def generate_storyboard_action(self, request, pk=None):
        """Async storyboard, submit step.

        Creates a storyboard version quickly: no images are generated here and the project
        stage is not advanced. The STORYBOARD stage is marked RUNNING. The frontend then
        calls poll-storyboard to generate frames one at a time.
        """
        project = self.get_object()
        storyboard = submit_storyboard(project=project, user=request.user, prompt=request.data.get("prompt", ""))
        _mark_stage(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.RUNNING)
        return Response(StoryboardVersionSerializer(storyboard).data, status=status.HTTP_202_ACCEPTED)

    @action(detail=True, methods=["post"], url_path="poll-storyboard")
    def poll_storyboard_action(self, request, pk=None):
        """Async storyboard, poll step.

        Each call generates the next frame (one ARK call, roughly 20 s). When the result
        reports ``"succeeded"``, STORYBOARD is marked SUCCEEDED, the project advances to
        VIDEO, and 200 is returned. Otherwise 202 is returned.
        """
        project = self.get_object()
        result = generate_storyboard_frame(project=project, user=request.user)
        finished = result.get("status") == "succeeded"
        if finished:
            _mark_stage(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SUCCEEDED)
            _move_project(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        http_status = status.HTTP_200_OK if finished else status.HTTP_202_ACCEPTED
        return Response(result, status=http_status)

    @action(detail=True, methods=["post"], url_path="skip-storyboard")
    @transaction.atomic
    def skip_storyboard(self, request, pk=None):
        """Mark STORYBOARD as SKIPPED and advance the project to VIDEO."""
        project = self.get_object()
        _mark_stage(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SKIPPED)
        _move_project(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        # Re-fetch: the instance above carries the stage prefetch from before the update.
        return Response(ProjectSerializer(self.get_object()).data)

    @action(detail=True, methods=["post"], url_path="submit-video-segment")
    def submit_video_segment_action(self, request, pk=None):
        """Submit generation of ``video_segment_id`` and schedule background polling; 404 if unknown."""
        project = self.get_object()
        segment = _get_or_none(VideoSegment.objects.all(), project=project, id=request.data.get("video_segment_id"))
        if segment is None:
            return Response({"detail": "video segment not found"}, status=status.HTTP_404_NOT_FOUND)
        submit_video_segment(video_segment=segment, user=request.user, prompt=request.data.get("prompt", ""))
        # With a Celery worker, polling happens automatically. Without one (local dev), the
        # frontend drives poll-video-segment. An unavailable queue must not turn the submit
        # into a 500: the job is already submitted to ARK, and polling is a secondary path.
        try:
            poll_video_segment_task.apply_async(args=[str(segment.id)], countdown=30)
        except Exception:  # noqa: BLE001
            logger.warning("poll_video_segment_task enqueue failed; relying on client polling", exc_info=True)
        # Re-fetch so the serialized segments reflect the submission, not the earlier prefetch.
        return Response(ProjectSerializer(self.get_object()).data, status=status.HTTP_202_ACCEPTED)

    @action(detail=True, methods=["post"], url_path="poll-video-segment")
    def poll_video_segment_action(self, request, pk=None):
        """Poll one video segment.

        Returns the new version when one is ready. Otherwise returns 202 with the segment
        status, or 404 if the segment is unknown.
        """
        project = self.get_object()
        segment = _get_or_none(VideoSegment.objects.all(), project=project, id=request.data.get("video_segment_id"))
        if segment is None:
            return Response({"detail": "video segment not found"}, status=status.HTTP_404_NOT_FOUND)
        version = poll_video_segment(video_segment=segment, user=request.user)
        if version is None:
            return Response({"status": segment.status}, status=status.HTTP_202_ACCEPTED)
        return Response(VideoSegmentVersionSerializer(version).data)

    @action(detail=True, methods=["post"], url_path="submit-export")
    @transaction.atomic
    def submit_export(self, request, pk=None):
        """Build the timeline from adopted segments (if it has no clips yet) and start an export job.

        Returns 400 while any video segment lacks an adopted version. If the timeline already
        has clips (for example a saved draft), they are exported as-is.
        """
        project = self.get_object()
        missing_segments = project.video_segments.filter(adopted_version__isnull=True).count()
        if missing_segments:
            return Response(
                {"detail": f"{missing_segments} video segments are not ready for export"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        timeline = _get_or_create_timeline(project)
        if not timeline.clips.exists():
            start_ms = 0
            for segment in project.video_segments.select_related("adopted_version__asset").order_by("sort_order"):
                if segment.adopted_version_id:
                    TimelineClip.objects.create(
                        timeline=timeline,
                        asset=segment.adopted_version.asset,
                        sort_order=segment.sort_order,
                        start_ms=start_ms,
                        duration_ms=segment.target_duration_seconds * 1000,
                    )
                    start_ms += segment.target_duration_seconds * 1000
        export_job = create_export_job(timeline=timeline, user=request.user)
        _mark_stage(project, ProjectStage.Stage.EXPORT, ProjectStage.Status.RUNNING)
        _move_project(project, ProjectStage.Stage.EXPORT, Project.Status.EXPORTING)
        # A background thread runs the real ffmpeg concatenation, so no Celery worker is needed.
        # It starts only after commit, so it sees the job row. The frontend polls poll-export
        # for progress and the finished video.
        transaction.on_commit(lambda: run_export_job_in_thread(str(export_job.id)))
        return Response(ExportJobSerializer(export_job).data, status=status.HTTP_202_ACCEPTED)

    @action(detail=True, methods=["post", "get"], url_path="poll-export")
    def poll_export_action(self, request, pk=None):
        """Concat export, poll step.

        Reports the latest export job's status, progress and output URL. When the job has
        SUCCEEDED or FAILED, the EXPORT stage is updated to match; the write happens only
        once per status.
        """
        project = self.get_object()
        timeline = getattr(project, "timeline", None)
        export_job = timeline.export_jobs.order_by("-created_at").first() if timeline is not None else None
        if export_job is None:
            return Response({"status": "not_started", "progress": 0, "output_url": ""})
        output_url = ""
        output = export_job.output_asset
        if output is not None:
            primary = output.files.filter(is_primary=True).first() or output.files.first()
            if primary is not None:
                output_url = AssetFileSerializer().get_preview_url(primary)
        if export_job.status == ExportJob.Status.SUCCEEDED:
            stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.EXPORT)
            if stage.status != ProjectStage.Status.SUCCEEDED:
                stage.status = ProjectStage.Status.SUCCEEDED
                stage.save(update_fields=["status", "updated_at"])
        elif export_job.status == ExportJob.Status.FAILED:
            stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.EXPORT)
            if stage.status != ProjectStage.Status.FAILED:
                stage.status = ProjectStage.Status.FAILED
                stage.error_message = export_job.error_message
                stage.save(update_fields=["status", "error_message", "updated_at"])
        return Response({
            "status": export_job.status,
            "progress": export_job.progress,
            "output_asset": str(output.id) if output is not None else None,
            "output_url": output_url,
            "error_message": export_job.error_message,
        })

    @action(detail=True, methods=["post"], url_path="upload-video-segment", parser_classes=[MultiPartParser, FormParser])
    @transaction.atomic
    def upload_video_segment_action(self, request, pk=None):
        """Replace a segment with a user-supplied video.

        The file is stored in TOS as a video Asset and wrapped in a new VideoSegmentVersion.
        That version becomes the segment's only adopted version, and the segment is marked
        SUCCEEDED.

        Returns 400 without a file and 404 for an unknown segment.
        """
        project = self.get_object()
        upload = request.data.get("file")
        if upload is None:
            return Response({"detail": "no file uploaded"}, status=status.HTTP_400_BAD_REQUEST)
        segment = _get_or_none(VideoSegment.objects.all(), project=project, id=request.data.get("video_segment_id"))
        if segment is None:
            return Response({"detail": "video segment not found"}, status=status.HTTP_404_NOT_FOUND)
        asset = _store_uploaded_asset(
            team=project.team,
            user=request.user,
            upload=upload,
            asset_type=Asset.Type.VIDEO,
            category=Asset.Category.VIDEO_CLIP,
            name=f"{project.name}-上传-{segment.sort_order + 1}",
        )
        # Keep at most one adopted version per segment, as adopt_script does for scripts.
        VideoSegmentVersion.objects.filter(video_segment=segment, is_adopted=True).update(is_adopted=False)
        version = VideoSegmentVersion.objects.create(
            video_segment=segment,
            asset=asset,
            prompt="用户上传",
            is_adopted=True,
            metadata={"source": "upload"},
        )
        segment.adopted_version = version
        segment.status = VideoSegment.Status.SUCCEEDED
        segment.error_message = ""
        segment.save(update_fields=["adopted_version", "status", "error_message", "updated_at"])
        # Re-fetch: the instance above was prefetched before the new version existed.
        return Response(ProjectSerializer(self.get_object()).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="upload-bgm", parser_classes=[MultiPartParser, FormParser])
    @transaction.atomic
    def upload_bgm_action(self, request, pk=None):
        """Upload a BGM audio file and make it the timeline's only BGM track.

        The file is stored in TOS as an audio Asset. The optional ``volume`` is an integer
        clamped to 0..100 (default 60). Returns 400 without a file or with a non-integer
        volume; both checks run before anything is uploaded.
        """
        project = self.get_object()
        upload = request.data.get("file")
        if upload is None:
            return Response({"detail": "no file uploaded"}, status=status.HTTP_400_BAD_REQUEST)
        volume = _parse_int(request.data.get("volume"), _DEFAULT_BGM_VOLUME, field="volume")
        timeline = _get_or_create_timeline(project)
        asset = _store_uploaded_asset(
            team=project.team,
            user=request.user,
            upload=upload,
            asset_type=Asset.Type.AUDIO,
            category=Asset.Category.UPLOAD,
            name=f"{project.name}-BGM",
        )
        timeline.bgm_tracks.all().delete()
        BgmTrack.objects.create(timeline=timeline, asset=asset, volume=max(0, min(100, volume)), start_ms=0)
        return Response(ProjectSerializer(self.get_object()).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post", "put"], url_path="save-timeline")
    @transaction.atomic
    def save_timeline_action(self, request, pk=None):
        """Save a draft by persisting the whole timeline editing state.

        Each optional payload key is applied only when it has the expected type:

        - ``clips`` (list): replaces all clips; their order and trims become the new timeline.
        - ``subtitle`` (dict): updates subtitle style and content.
        - ``bgm`` (dict): clears the BGM or sets its volume.
        - ``transition`` (dict) and ``draft`` (dict): stored in timeline metadata.

        Invalid integers yield 400, and the atomic block rolls back every change.
        """
        project = self.get_object()
        timeline = _get_or_create_timeline(project)
        data = request.data

        clips = data.get("clips")
        if isinstance(clips, list):
            total_ms = _replace_timeline_clips(timeline=timeline, team=project.team, clips=clips)
            timeline.duration_seconds = max(1, round(total_ms / 1000))

        subtitle = data.get("subtitle")
        if isinstance(subtitle, dict):
            _apply_subtitle(timeline, subtitle)

        bgm = data.get("bgm")
        if isinstance(bgm, dict):
            _apply_bgm(timeline, bgm)

        metadata = dict(timeline.metadata or {})
        if isinstance(data.get("transition"), dict):
            metadata["transition"] = {"type": str(data["transition"].get("type", "none"))}
        if isinstance(data.get("draft"), dict):
            metadata["draft"] = data["draft"]
        timeline.metadata = metadata
        timeline.save(update_fields=["metadata", "duration_seconds", "updated_at"])
        return Response(ProjectSerializer(self.get_object()).data)