199 lines
9.8 KiB
Python

from django.db import transaction
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.ai.services import (
create_export_job,
generate_base_asset,
generate_project_script,
generate_storyboard,
poll_video_segment,
submit_video_segment,
)
from apps.common.api import TeamScopedViewSetMixin
from .models import BaseAssetGroup, Project, ProjectStage, ScriptVersion, Timeline, TimelineClip, VideoSegment
from .serializers import (
BaseAssetGroupSerializer,
ExportJobSerializer,
ProjectSerializer,
ScriptVersionSerializer,
StoryboardVersionSerializer,
VideoSegmentVersionSerializer,
)
from .services.pipeline import STAGE_ORDER
from .tasks import poll_video_segment_task, run_export_job_task
def promote_base_asset_stage_if_ready(project: Project) -> bool:
adopted_kind_count = (
project.base_asset_groups.filter(adopted_asset__isnull=False).values("kind").distinct().count()
)
if adopted_kind_count < 3:
return False
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.BASE_ASSETS)
stage.status = ProjectStage.Status.SUCCEEDED
stage.save(update_fields=["status", "updated_at"])
project.current_stage = ProjectStage.Stage.STORYBOARD
project.status = Project.Status.STORYBOARDING
project.save(update_fields=["current_stage", "status", "updated_at"])
return True
class ProjectViewSet(TeamScopedViewSetMixin, ModelViewSet):
queryset = Project.objects.select_related("product", "timeline").prefetch_related(
"stages",
"video_segments",
"script_versions",
"script_versions__segments",
"base_asset_groups",
"base_asset_groups__candidate_assets",
"storyboard_versions",
"storyboard_versions__frames",
"timeline__clips",
).all()
serializer_class = ProjectSerializer
search_fields = ["name", "product__title"]
ordering_fields = ["created_at", "updated_at", "name"]
@transaction.atomic
def perform_create(self, serializer):
project = serializer.save(team=self.get_team(), created_by=self.request.user)
for stage in STAGE_ORDER:
ProjectStage.objects.create(project=project, stage=stage)
for index in range(4):
VideoSegment.objects.create(project=project, sort_order=index, target_duration_seconds=15)
@action(detail=True, methods=["post"], url_path="generate-script")
def generate_script(self, request, pk=None):
project = self.get_object()
script = generate_project_script(
project=project,
user=request.user,
user_prompt=request.data.get("prompt", ""),
selling_point_ids=request.data.get("selling_point_ids") or [],
)
return Response(ScriptVersionSerializer(script).data, status=status.HTTP_201_CREATED)
@action(detail=True, methods=["post"], url_path="adopt-script")
@transaction.atomic
def adopt_script(self, request, pk=None):
project = self.get_object()
script_id = request.data.get("script_version_id")
script = ScriptVersion.objects.select_for_update().get(project=project, id=script_id)
ScriptVersion.objects.filter(project=project).update(is_adopted=False)
script.is_adopted = True
script.save(update_fields=["is_adopted", "updated_at"])
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.SCRIPT)
stage.status = ProjectStage.Status.SUCCEEDED
stage.save(update_fields=["status", "updated_at"])
project.current_stage = ProjectStage.Stage.BASE_ASSETS
project.status = Project.Status.ASSETING
project.save(update_fields=["current_stage", "status", "updated_at"])
return Response(ScriptVersionSerializer(script).data)
@action(detail=True, methods=["post"], url_path="generate-base-asset")
def generate_base_asset_action(self, request, pk=None):
project = self.get_object()
kind = request.data.get("kind")
if kind not in BaseAssetGroup.Kind.values:
return Response({"detail": "invalid base asset kind"}, status=status.HTTP_400_BAD_REQUEST)
group = generate_base_asset(project=project, user=request.user, kind=kind, prompt=request.data.get("prompt", ""))
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.BASE_ASSETS)
stage.status = ProjectStage.Status.NEEDS_REVIEW
stage.save(update_fields=["status", "updated_at"])
promote_base_asset_stage_if_ready(project)
return Response(BaseAssetGroupSerializer(group).data, status=status.HTTP_201_CREATED)
@action(detail=True, methods=["post"], url_path="adopt-base-asset")
@transaction.atomic
def adopt_base_asset(self, request, pk=None):
project = self.get_object()
group = BaseAssetGroup.objects.select_for_update().get(project=project, id=request.data.get("group_id"))
asset_id = request.data.get("asset_id")
if not group.candidate_assets.filter(id=asset_id).exists():
return Response({"detail": "asset is not a candidate"}, status=status.HTTP_400_BAD_REQUEST)
group.adopted_asset_id = asset_id
group.save(update_fields=["adopted_asset", "updated_at"])
promote_base_asset_stage_if_ready(project)
return Response(BaseAssetGroupSerializer(group).data)
@action(detail=True, methods=["post"], url_path="generate-storyboard")
def generate_storyboard_action(self, request, pk=None):
project = self.get_object()
storyboard = generate_storyboard(project=project, user=request.user, prompt=request.data.get("prompt", ""))
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.STORYBOARD)
stage.status = ProjectStage.Status.SUCCEEDED
stage.save(update_fields=["status", "updated_at"])
project.current_stage = ProjectStage.Stage.VIDEO
project.status = Project.Status.VIDEOING
project.save(update_fields=["current_stage", "status", "updated_at"])
return Response(StoryboardVersionSerializer(storyboard).data, status=status.HTTP_201_CREATED)
@action(detail=True, methods=["post"], url_path="skip-storyboard")
@transaction.atomic
def skip_storyboard(self, request, pk=None):
project = self.get_object()
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.STORYBOARD)
stage.status = ProjectStage.Status.SKIPPED
stage.save(update_fields=["status", "updated_at"])
project.current_stage = ProjectStage.Stage.VIDEO
project.status = Project.Status.VIDEOING
project.save(update_fields=["current_stage", "status", "updated_at"])
return Response(ProjectSerializer(project).data)
@action(detail=True, methods=["post"], url_path="submit-video-segment")
def submit_video_segment_action(self, request, pk=None):
project = self.get_object()
segment = VideoSegment.objects.get(project=project, id=request.data.get("video_segment_id"))
submit_video_segment(video_segment=segment, user=request.user, prompt=request.data.get("prompt", ""))
poll_video_segment_task.apply_async(args=[str(segment.id)], countdown=30)
return Response(ProjectSerializer(project).data, status=status.HTTP_202_ACCEPTED)
@action(detail=True, methods=["post"], url_path="poll-video-segment")
def poll_video_segment_action(self, request, pk=None):
project = self.get_object()
segment = VideoSegment.objects.get(project=project, id=request.data.get("video_segment_id"))
version = poll_video_segment(video_segment=segment, user=request.user)
if version is None:
return Response({"status": segment.status}, status=status.HTTP_202_ACCEPTED)
return Response(VideoSegmentVersionSerializer(version).data)
@action(detail=True, methods=["post"], url_path="submit-export")
@transaction.atomic
def submit_export(self, request, pk=None):
project = self.get_object()
missing_segments = project.video_segments.filter(adopted_version__isnull=True).count()
if missing_segments:
return Response(
{"detail": f"{missing_segments} video segments are not ready for export"},
status=status.HTTP_400_BAD_REQUEST,
)
timeline, _ = Timeline.objects.get_or_create(
project=project,
defaults={"name": f"{project.name} Timeline", "duration_seconds": 60},
)
if not timeline.clips.exists():
start_ms = 0
for segment in project.video_segments.select_related("adopted_version__asset").order_by("sort_order"):
if segment.adopted_version_id:
TimelineClip.objects.create(
timeline=timeline,
asset=segment.adopted_version.asset,
sort_order=segment.sort_order,
start_ms=start_ms,
duration_ms=segment.target_duration_seconds * 1000,
)
start_ms += segment.target_duration_seconds * 1000
export_job = create_export_job(timeline=timeline, user=request.user)
run_export_job_task.delay(str(export_job.id))
stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.EXPORT)
stage.status = ProjectStage.Status.QUEUED
stage.save(update_fields=["status", "updated_at"])
project.current_stage = ProjectStage.Stage.EXPORT
project.status = Project.Status.EXPORTING
project.save(update_fields=["current_stage", "status", "updated_at"])
return Response(ExportJobSerializer(export_job).data, status=status.HTTP_202_ACCEPTED)