"""Project API: team-scoped project CRUD plus one POST action per pipeline step."""

from functools import partial

from django.db import transaction
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from apps.ai.services import (
    create_export_job,
    generate_base_asset,
    generate_project_script,
    generate_storyboard,
    poll_video_segment,
    submit_video_segment,
)
from apps.common.api import TeamScopedViewSetMixin

from .models import (
    BaseAssetGroup,
    Project,
    ProjectStage,
    ScriptVersion,
    Timeline,
    TimelineClip,
    VideoSegment,
)
from .serializers import (
    BaseAssetGroupSerializer,
    ExportJobSerializer,
    ProjectSerializer,
    ScriptVersionSerializer,
    StoryboardVersionSerializer,
    VideoSegmentVersionSerializer,
)
from .services.pipeline import STAGE_ORDER
from .tasks import poll_video_segment_task, run_export_job_task


def _set_stage_status(project: Project, stage, stage_status) -> ProjectStage:
    """Get or create *project*'s ProjectStage row for *stage* and persist *stage_status* on it."""
    project_stage, _ = ProjectStage.objects.get_or_create(project=project, stage=stage)
    project_stage.status = stage_status
    project_stage.save(update_fields=["status", "updated_at"])
    return project_stage


def _move_project_to(project: Project, stage, project_status) -> None:
    """Set *project*'s current stage and status, saving only those fields (plus updated_at)."""
    project.current_stage = stage
    project.status = project_status
    project.save(update_fields=["current_stage", "status", "updated_at"])


def _populate_timeline(project: Project, timeline: Timeline) -> None:
    """Lay *project*'s adopted video segments end-to-end, in sort order, as clips on *timeline*."""
    start_ms = 0
    segments = project.video_segments.select_related("adopted_version__asset").order_by("sort_order")
    for segment in segments:
        # Segments without an adopted version contribute no clip and no time offset.
        if not segment.adopted_version_id:
            continue
        duration_ms = segment.target_duration_seconds * 1000
        TimelineClip.objects.create(
            timeline=timeline,
            asset=segment.adopted_version.asset,
            sort_order=segment.sort_order,
            start_ms=start_ms,
            duration_ms=duration_ms,
        )
        start_ms += duration_ms


def promote_base_asset_stage_if_ready(project: Project) -> bool:
    """Advance *project* past the base-asset stage once enough asset kinds are adopted.

    Counts the distinct ``kind`` values among the project's base-asset groups
    that have an adopted asset. When at least three kinds are covered, the
    BASE_ASSETS stage is marked SUCCEEDED and the project is moved to the
    STORYBOARD stage with STORYBOARDING status.

    Returns:
        True if the promotion was applied, False if too few kinds are adopted.
    """
    adopted_kind_count = (
        project.base_asset_groups.filter(adopted_asset__isnull=False)
        .values("kind")
        .distinct()
        .count()
    )
    # NOTE(review): 3 presumably equals the number of required BaseAssetGroup
    # kinds -- confirm against BaseAssetGroup.Kind.
    if adopted_kind_count < 3:
        return False

    # NOTE(review): this runs whenever the threshold is met, so a project that
    # has already advanced past STORYBOARD is moved back to it -- confirm intended.
    _set_stage_status(project, ProjectStage.Stage.BASE_ASSETS, ProjectStage.Status.SUCCEEDED)
    _move_project_to(project, ProjectStage.Stage.STORYBOARD, Project.Status.STORYBOARDING)
    return True


class ProjectViewSet(TeamScopedViewSetMixin, ModelViewSet):
    """Team-scoped CRUD for projects plus POST actions that drive each pipeline stage."""

    queryset = (
        Project.objects.select_related("product", "timeline")
        .prefetch_related(
            "stages",
            "video_segments",
            "script_versions",
            "script_versions__segments",
            "base_asset_groups",
            "base_asset_groups__candidate_assets",
            "storyboard_versions",
            "storyboard_versions__frames",
            "timeline__clips",
        )
        .all()
    )
    serializer_class = ProjectSerializer
    search_fields = ["name", "product__title"]
    ordering_fields = ["created_at", "updated_at", "name"]

    @transaction.atomic
    def perform_create(self, serializer):
        """Save the project for the caller's team, then seed one row per stage and four 15-second segments."""
        project = serializer.save(team=self.get_team(), created_by=self.request.user)
        for stage in STAGE_ORDER:
            ProjectStage.objects.create(project=project, stage=stage)
        for index in range(4):
            VideoSegment.objects.create(project=project, sort_order=index, target_duration_seconds=15)

    @action(detail=True, methods=["post"], url_path="generate-script")
    def generate_script(self, request, pk=None):
        """Generate a new script version from ``prompt`` and ``selling_point_ids``; responds 201."""
        project = self.get_object()
        script = generate_project_script(
            project=project,
            user=request.user,
            user_prompt=request.data.get("prompt", ""),
            selling_point_ids=request.data.get("selling_point_ids") or [],
        )
        return Response(ScriptVersionSerializer(script).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="adopt-script")
    @transaction.atomic
    def adopt_script(self, request, pk=None):
        """Mark ``script_version_id`` as the project's only adopted script and advance to BASE_ASSETS.

        Responds 404 when the script version does not exist for this project.
        """
        project = self.get_object()
        # get_object_or_404 turns unknown/missing/malformed ids into a 404
        # instead of an unhandled DoesNotExist/ValueError (HTTP 500).
        script = get_object_or_404(
            ScriptVersion.objects.select_for_update(),
            project=project,
            id=request.data.get("script_version_id"),
        )
        ScriptVersion.objects.filter(project=project).update(is_adopted=False)
        script.is_adopted = True
        script.save(update_fields=["is_adopted", "updated_at"])
        _set_stage_status(project, ProjectStage.Stage.SCRIPT, ProjectStage.Status.SUCCEEDED)
        _move_project_to(project, ProjectStage.Stage.BASE_ASSETS, Project.Status.ASSETING)
        return Response(ScriptVersionSerializer(script).data)

    @action(detail=True, methods=["post"], url_path="generate-base-asset")
    def generate_base_asset_action(self, request, pk=None):
        """Generate a base-asset group of the requested ``kind``; responds 400 on an unknown kind.

        Marks BASE_ASSETS as NEEDS_REVIEW, then attempts promotion, which
        overwrites that status with SUCCEEDED when enough kinds are adopted.
        """
        project = self.get_object()
        kind = request.data.get("kind")
        if kind not in BaseAssetGroup.Kind.values:
            return Response({"detail": "invalid base asset kind"}, status=status.HTTP_400_BAD_REQUEST)
        group = generate_base_asset(
            project=project,
            user=request.user,
            kind=kind,
            prompt=request.data.get("prompt", ""),
        )
        _set_stage_status(project, ProjectStage.Stage.BASE_ASSETS, ProjectStage.Status.NEEDS_REVIEW)
        promote_base_asset_stage_if_ready(project)
        return Response(BaseAssetGroupSerializer(group).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="adopt-base-asset")
    @transaction.atomic
    def adopt_base_asset(self, request, pk=None):
        """Adopt ``asset_id`` for base-asset group ``group_id`` and try to promote the stage.

        Responds 404 for an unknown group and 400 when the asset is not one of
        the group's candidates.
        """
        project = self.get_object()
        group = get_object_or_404(
            BaseAssetGroup.objects.select_for_update(),
            project=project,
            id=request.data.get("group_id"),
        )
        asset_id = request.data.get("asset_id")
        if not group.candidate_assets.filter(id=asset_id).exists():
            return Response({"detail": "asset is not a candidate"}, status=status.HTTP_400_BAD_REQUEST)
        group.adopted_asset_id = asset_id
        group.save(update_fields=["adopted_asset", "updated_at"])
        promote_base_asset_stage_if_ready(project)
        return Response(BaseAssetGroupSerializer(group).data)

    @action(detail=True, methods=["post"], url_path="generate-storyboard")
    def generate_storyboard_action(self, request, pk=None):
        """Generate a storyboard from ``prompt``, mark STORYBOARD succeeded and advance to VIDEO; responds 201."""
        project = self.get_object()
        storyboard = generate_storyboard(project=project, user=request.user, prompt=request.data.get("prompt", ""))
        _set_stage_status(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SUCCEEDED)
        _move_project_to(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        return Response(StoryboardVersionSerializer(storyboard).data, status=status.HTTP_201_CREATED)

    @action(detail=True, methods=["post"], url_path="skip-storyboard")
    @transaction.atomic
    def skip_storyboard(self, request, pk=None):
        """Mark STORYBOARD as skipped and advance the project to VIDEO."""
        project = self.get_object()
        _set_stage_status(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SKIPPED)
        _move_project_to(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        return Response(ProjectSerializer(project).data)

    @action(detail=True, methods=["post"], url_path="submit-video-segment")
    def submit_video_segment_action(self, request, pk=None):
        """Submit ``video_segment_id`` for generation and schedule a poll 30 s later; responds 202.

        Responds 404 when the segment does not exist for this project.
        """
        project = self.get_object()
        segment = get_object_or_404(VideoSegment, project=project, id=request.data.get("video_segment_id"))
        submit_video_segment(video_segment=segment, user=request.user, prompt=request.data.get("prompt", ""))
        # Enqueue only once any surrounding transaction commits so the worker
        # sees the submitted state; with no active transaction this runs now.
        transaction.on_commit(
            partial(poll_video_segment_task.apply_async, args=[str(segment.id)], countdown=30)
        )
        return Response(ProjectSerializer(project).data, status=status.HTTP_202_ACCEPTED)

    @action(detail=True, methods=["post"], url_path="poll-video-segment")
    def poll_video_segment_action(self, request, pk=None):
        """Poll ``video_segment_id``: 202 with the segment status while pending, else the new version.

        Responds 404 when the segment does not exist for this project.
        """
        project = self.get_object()
        segment = get_object_or_404(VideoSegment, project=project, id=request.data.get("video_segment_id"))
        version = poll_video_segment(video_segment=segment, user=request.user)
        if version is None:
            return Response({"status": segment.status}, status=status.HTTP_202_ACCEPTED)
        return Response(VideoSegmentVersionSerializer(version).data)

    @action(detail=True, methods=["post"], url_path="submit-export")
    @transaction.atomic
    def submit_export(self, request, pk=None):
        """Build the timeline from adopted segments and queue an export job; responds 202.

        Responds 400 when any video segment has no adopted version. Clips are
        only created when the timeline has none yet.
        """
        project = self.get_object()
        missing_segments = project.video_segments.filter(adopted_version__isnull=True).count()
        if missing_segments:
            return Response(
                {"detail": f"{missing_segments} video segments are not ready for export"},
                status=status.HTTP_400_BAD_REQUEST,
            )

        timeline, _ = Timeline.objects.get_or_create(
            project=project,
            defaults={"name": f"{project.name} Timeline", "duration_seconds": 60},
        )
        if not timeline.clips.exists():
            _populate_timeline(project, timeline)

        export_job = create_export_job(timeline=timeline, user=request.user)
        # Enqueue after commit: dispatching inside the atomic block lets a fast
        # worker look up the ExportJob before it is visible, or run a job whose
        # transaction is later rolled back.
        transaction.on_commit(partial(run_export_job_task.delay, str(export_job.id)))

        _set_stage_status(project, ProjectStage.Stage.EXPORT, ProjectStage.Status.QUEUED)
        _move_project_to(project, ProjectStage.Stage.EXPORT, Project.Status.EXPORTING)
        return Response(ExportJobSerializer(export_job).data, status=status.HTTP_202_ACCEPTED)