199 lines
9.8 KiB
Python
199 lines
9.8 KiB
Python
from django.db import transaction
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from apps.ai.services import (
    create_export_job,
    generate_base_asset,
    generate_project_script,
    generate_storyboard,
    poll_video_segment,
    submit_video_segment,
)
from apps.common.api import TeamScopedViewSetMixin

from .models import BaseAssetGroup, Project, ProjectStage, ScriptVersion, Timeline, TimelineClip, VideoSegment
from .serializers import (
    BaseAssetGroupSerializer,
    ExportJobSerializer,
    ProjectSerializer,
    ScriptVersionSerializer,
    StoryboardVersionSerializer,
    VideoSegmentVersionSerializer,
)
from .services.pipeline import STAGE_ORDER
from .tasks import poll_video_segment_task, run_export_job_task
|
|
|
|
|
|
def promote_base_asset_stage_if_ready(project: Project) -> bool:
    """Mark the base-asset stage as succeeded once enough asset kinds are adopted.

    Counts the distinct ``kind`` values among the project's base asset groups
    that have an ``adopted_asset``. When at least three kinds are covered, the
    project's ``BASE_ASSETS`` ``ProjectStage`` row is set to ``SUCCEEDED``.
    The project itself is moved to the storyboard stage only if it has not
    already progressed past the base-asset stage.

    Args:
        project: The project whose base-asset progress is checked.

    Returns:
        True if the base-asset stage was marked as succeeded, False if fewer
        than three kinds have an adopted asset.
    """
    adopted_kind_count = (
        project.base_asset_groups.filter(adopted_asset__isnull=False).values("kind").distinct().count()
    )
    if adopted_kind_count < 3:
        return False

    # Keep the stage row and the project pointer consistent even when the
    # caller is not already inside a transaction.
    with transaction.atomic():
        stage, _ = ProjectStage.objects.get_or_create(project=project, stage=ProjectStage.Stage.BASE_ASSETS)
        stage.status = ProjectStage.Status.SUCCEEDED
        stage.save(update_fields=["status", "updated_at"])

        # Re-adopting or regenerating an asset after the pipeline has moved on
        # (e.g. to video or export) must not drag the project back to storyboard.
        if not _is_past_base_assets(project.current_stage):
            project.current_stage = ProjectStage.Stage.STORYBOARD
            project.status = Project.Status.STORYBOARDING
            project.save(update_fields=["current_stage", "status", "updated_at"])
    return True


def _is_past_base_assets(current_stage) -> bool:
    """Return True if *current_stage* comes after BASE_ASSETS in ``STAGE_ORDER``.

    A stage that is not listed in ``STAGE_ORDER`` is treated as "not past",
    which keeps the previous always-advance behavior for it.
    """
    order = list(STAGE_ORDER)
    try:
        return order.index(current_stage) > order.index(ProjectStage.Stage.BASE_ASSETS)
    except ValueError:
        return False
|
|
|
|
|
|
class ProjectViewSet(TeamScopedViewSetMixin, ModelViewSet):
    """CRUD endpoints for projects plus the actions that drive the production pipeline.

    Each ``@action`` handles one pipeline step: script, base assets, storyboard,
    video segments or export. It calls into ``apps.ai.services`` and/or updates
    the matching ``ProjectStage`` row together with the project's
    ``current_stage`` / ``status``.
    """

    queryset = Project.objects.select_related("product", "timeline").prefetch_related(
        "stages",
        "video_segments",
        "script_versions",
        "script_versions__segments",
        "base_asset_groups",
        "base_asset_groups__candidate_assets",
        "storyboard_versions",
        "storyboard_versions__frames",
        "timeline__clips",
    ).all()
    serializer_class = ProjectSerializer
    search_fields = ["name", "product__title"]
    ordering_fields = ["created_at", "updated_at", "name"]

    # Number and target length (seconds) of the video segments seeded for a
    # newly created project. Subclasses may override them.
    initial_video_segment_count = 4
    initial_video_segment_duration_seconds = 15

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _set_stage_status(project, stage, new_status):
        """Set the status of *project*'s ``ProjectStage`` row for *stage*, creating the row if missing."""
        stage_row, _ = ProjectStage.objects.get_or_create(project=project, stage=stage)
        stage_row.status = new_status
        stage_row.save(update_fields=["status", "updated_at"])

    @staticmethod
    def _set_project_stage(project, stage, new_status):
        """Point *project* at pipeline *stage* with the project-level *new_status*."""
        project.current_stage = stage
        project.status = new_status
        project.save(update_fields=["current_stage", "status", "updated_at"])

    def _fresh_project_data(self):
        """Re-load the current project and return its serialized representation.

        ``get_object()`` evaluates ``queryset`` including its ``prefetch_related``
        caches. A project instance fetched before an action's writes therefore
        still carries the old stages / video segments. Fetching again after the
        writes serializes the current state.
        """
        project = self.get_object()
        return ProjectSerializer(project, context=self.get_serializer_context()).data

    @staticmethod
    def _layout_timeline_clips(project, timeline):
        """Create one clip per adopted video segment, laid out back to back in ``sort_order``."""
        start_ms = 0
        segments = project.video_segments.select_related("adopted_version__asset").order_by("sort_order")
        for segment in segments:
            if not segment.adopted_version_id:
                continue
            duration_ms = segment.target_duration_seconds * 1000
            TimelineClip.objects.create(
                timeline=timeline,
                asset=segment.adopted_version.asset,
                sort_order=segment.sort_order,
                start_ms=start_ms,
                duration_ms=duration_ms,
            )
            start_ms += duration_ms

    # ------------------------------------------------------------------ CRUD

    @transaction.atomic
    def perform_create(self, serializer):
        """Create the project, one ``ProjectStage`` per entry in ``STAGE_ORDER``, and its initial video segments."""
        project = serializer.save(team=self.get_team(), created_by=self.request.user)
        for stage in STAGE_ORDER:
            ProjectStage.objects.create(project=project, stage=stage)
        for index in range(self.initial_video_segment_count):
            VideoSegment.objects.create(
                project=project,
                sort_order=index,
                target_duration_seconds=self.initial_video_segment_duration_seconds,
            )

    # ------------------------------------------------------------------ script

    @action(detail=True, methods=["post"], url_path="generate-script")
    def generate_script(self, request, pk=None):
        """Generate a new script version from ``prompt`` and ``selling_point_ids``; responds 201."""
        project = self.get_object()
        script = generate_project_script(
            project=project,
            user=request.user,
            user_prompt=request.data.get("prompt", ""),
            selling_point_ids=request.data.get("selling_point_ids") or [],
        )
        return Response(
            ScriptVersionSerializer(script, context=self.get_serializer_context()).data,
            status=status.HTTP_201_CREATED,
        )

    @action(detail=True, methods=["post"], url_path="adopt-script")
    @transaction.atomic
    def adopt_script(self, request, pk=None):
        """Make ``script_version_id`` the project's only adopted script and move to base assets.

        Responds 404 if the id is missing, malformed, or not one of this
        project's script versions.
        """
        project = self.get_object()
        script = get_object_or_404(
            ScriptVersion.objects.select_for_update(),
            project=project,
            id=request.data.get("script_version_id"),
        )
        ScriptVersion.objects.filter(project=project).update(is_adopted=False)
        script.is_adopted = True
        script.save(update_fields=["is_adopted", "updated_at"])
        self._set_stage_status(project, ProjectStage.Stage.SCRIPT, ProjectStage.Status.SUCCEEDED)
        self._set_project_stage(project, ProjectStage.Stage.BASE_ASSETS, Project.Status.ASSETING)
        return Response(ScriptVersionSerializer(script, context=self.get_serializer_context()).data)

    # ------------------------------------------------------------------ base assets

    @action(detail=True, methods=["post"], url_path="generate-base-asset")
    def generate_base_asset_action(self, request, pk=None):
        """Generate candidates for one base-asset ``kind``; responds 400 on an unknown kind, else 201."""
        project = self.get_object()
        kind = request.data.get("kind")
        if kind not in BaseAssetGroup.Kind.values:
            return Response({"detail": "invalid base asset kind"}, status=status.HTTP_400_BAD_REQUEST)
        # The generation call itself stays outside any transaction; only the
        # stage bookkeeping that follows is made atomic.
        group = generate_base_asset(project=project, user=request.user, kind=kind, prompt=request.data.get("prompt", ""))
        with transaction.atomic():
            self._set_stage_status(project, ProjectStage.Stage.BASE_ASSETS, ProjectStage.Status.NEEDS_REVIEW)
            promote_base_asset_stage_if_ready(project)
        return Response(
            BaseAssetGroupSerializer(group, context=self.get_serializer_context()).data,
            status=status.HTTP_201_CREATED,
        )

    @action(detail=True, methods=["post"], url_path="adopt-base-asset")
    @transaction.atomic
    def adopt_base_asset(self, request, pk=None):
        """Adopt ``asset_id`` for the base asset group ``group_id``.

        Responds 404 if the group is missing or belongs to another project, and
        400 if the asset is not one of the group's candidates.
        """
        project = self.get_object()
        group = get_object_or_404(
            BaseAssetGroup.objects.select_for_update(),
            project=project,
            id=request.data.get("group_id"),
        )
        asset_id = request.data.get("asset_id")
        if not group.candidate_assets.filter(id=asset_id).exists():
            return Response({"detail": "asset is not a candidate"}, status=status.HTTP_400_BAD_REQUEST)
        group.adopted_asset_id = asset_id
        group.save(update_fields=["adopted_asset", "updated_at"])
        promote_base_asset_stage_if_ready(project)
        return Response(BaseAssetGroupSerializer(group, context=self.get_serializer_context()).data)

    # ------------------------------------------------------------------ storyboard

    @action(detail=True, methods=["post"], url_path="generate-storyboard")
    def generate_storyboard_action(self, request, pk=None):
        """Generate a storyboard, mark the storyboard stage succeeded and move to video; responds 201."""
        project = self.get_object()
        storyboard = generate_storyboard(project=project, user=request.user, prompt=request.data.get("prompt", ""))
        with transaction.atomic():
            self._set_stage_status(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SUCCEEDED)
            self._set_project_stage(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        return Response(
            StoryboardVersionSerializer(storyboard, context=self.get_serializer_context()).data,
            status=status.HTTP_201_CREATED,
        )

    @action(detail=True, methods=["post"], url_path="skip-storyboard")
    @transaction.atomic
    def skip_storyboard(self, request, pk=None):
        """Mark the storyboard stage skipped, move to video, and return the updated project."""
        project = self.get_object()
        self._set_stage_status(project, ProjectStage.Stage.STORYBOARD, ProjectStage.Status.SKIPPED)
        self._set_project_stage(project, ProjectStage.Stage.VIDEO, Project.Status.VIDEOING)
        return Response(self._fresh_project_data())

    # ------------------------------------------------------------------ video

    @action(detail=True, methods=["post"], url_path="submit-video-segment")
    def submit_video_segment_action(self, request, pk=None):
        """Submit ``video_segment_id`` for generation and schedule a poll 30 s later; responds 202.

        Responds 404 if the segment is missing or belongs to another project.
        """
        project = self.get_object()
        segment = get_object_or_404(
            VideoSegment.objects.all(),
            project=project,
            id=request.data.get("video_segment_id"),
        )
        submit_video_segment(video_segment=segment, user=request.user, prompt=request.data.get("prompt", ""))
        segment_id = str(segment.id)
        # If a surrounding transaction is active (e.g. ATOMIC_REQUESTS), only
        # schedule the poll once it commits; otherwise this runs immediately.
        transaction.on_commit(lambda: poll_video_segment_task.apply_async(args=[segment_id], countdown=30))
        return Response(self._fresh_project_data(), status=status.HTTP_202_ACCEPTED)

    @action(detail=True, methods=["post"], url_path="poll-video-segment")
    def poll_video_segment_action(self, request, pk=None):
        """Poll ``video_segment_id`` once.

        Returns the new version when ``poll_video_segment`` yields one;
        otherwise responds 202 with the segment's status. Responds 404 if the
        segment is missing or belongs to another project.
        """
        project = self.get_object()
        segment = get_object_or_404(
            VideoSegment.objects.all(),
            project=project,
            id=request.data.get("video_segment_id"),
        )
        version = poll_video_segment(video_segment=segment, user=request.user)
        if version is None:
            return Response({"status": segment.status}, status=status.HTTP_202_ACCEPTED)
        return Response(VideoSegmentVersionSerializer(version, context=self.get_serializer_context()).data)

    # ------------------------------------------------------------------ export

    @action(detail=True, methods=["post"], url_path="submit-export")
    @transaction.atomic
    def submit_export(self, request, pk=None):
        """Build the timeline if needed, create an export job and queue it; responds 202.

        Responds 400 while any video segment still lacks an adopted version.
        """
        project = self.get_object()
        missing_segments = project.video_segments.filter(adopted_version__isnull=True).count()
        if missing_segments:
            return Response(
                {"detail": f"{missing_segments} video segments are not ready for export"},
                status=status.HTTP_400_BAD_REQUEST,
            )
        timeline, _ = Timeline.objects.get_or_create(
            project=project,
            defaults={"name": f"{project.name} Timeline", "duration_seconds": 60},
        )
        # NOTE(review): clips are laid out only when the timeline has none, so a
        # segment re-adopted after the first export keeps its old clip — confirm
        # this is intended.
        if not timeline.clips.exists():
            self._layout_timeline_clips(project, timeline)
        export_job = create_export_job(timeline=timeline, user=request.user)
        export_job_id = str(export_job.id)
        # Enqueue only after this transaction commits. Otherwise the worker may
        # not yet see the ExportJob row, and the QUEUED write below could
        # overwrite progress the task already recorded.
        transaction.on_commit(lambda: run_export_job_task.delay(export_job_id))
        self._set_stage_status(project, ProjectStage.Stage.EXPORT, ProjectStage.Status.QUEUED)
        self._set_project_stage(project, ProjectStage.Stage.EXPORT, Project.Status.EXPORTING)
        return Response(
            ExportJobSerializer(export_job, context=self.get_serializer_context()).data,
            status=status.HTTP_202_ACCEPTED,
        )
|