AirShelf/core/backend/apps/ai/services.py

422 lines
18 KiB
Python

import uuid
from decimal import Decimal
from django.db import transaction
from django.utils import timezone
from apps.ai.models import AITask, ModelConfig
from apps.ai.providers import VolcanoArkProvider
from apps.assets.models import Asset, AssetFile
from apps.assets.storage import TosStorage
from apps.billing.services.ledger import charge_reserved_credit, release_credit, reserve_credit
from apps.projects.models import (
BaseAssetGroup,
ExportJob,
ProjectStage,
ScriptSegment,
ScriptVersion,
StoryboardFrame,
StoryboardVersion,
VideoSegment,
VideoSegmentVersion,
)
def get_default_model(capability: str) -> ModelConfig | None:
    """Return the earliest-created active model configuration for a capability.

    Only rows whose own status is ``ModelConfig.Status.ACTIVE`` and whose
    provider status is ``"active"`` are considered. Candidates are ordered by
    ``created_at`` ascending and the first one is returned with its provider
    loaded through ``select_related``.

    Args:
        capability: Value matched against ``ModelConfig.capability``.

    Returns:
        The first matching ``ModelConfig``, or ``None`` when no row matches.
        Callers must handle the ``None`` case themselves.
    """
    return (
        ModelConfig.objects.select_related("provider")
        .filter(capability=capability, status=ModelConfig.Status.ACTIVE, provider__status="active")
        .order_by("created_at")
        .first()
    )
def estimate_cost(model_config: ModelConfig) -> Decimal:
    """Return the amount to reserve for one call to ``model_config``.

    Args:
        model_config: Model configuration whose ``unit_price`` is read.

    Returns:
        ``model_config.unit_price`` when it is greater than zero, otherwise
        the flat fallback ``Decimal("1.0000")``.
    """
    if model_config.unit_price > 0:
        return model_config.unit_price
    # Non-positive prices fall back to a flat amount.
    return Decimal("1.0000")
def build_script_prompt(*, project, user_prompt: str, selling_point_ids: list[str] | None = None) -> list[dict[str, str]]:
product = project.product
selling_points = product.selling_points.all()
if selling_point_ids:
selling_points = selling_points.filter(id__in=selling_point_ids)
selling_text = "\n".join(f"- {item.title}: {item.detail}" for item in selling_points)
system = (
"你是电商短视频脚本导演。请为 9:16 竖屏带货短视频生成 60 秒脚本,"
"拆成 4 个 15 秒段落。每段包含旁白、画面描述、商品露出方式和转场建议。"
)
user = f"""
商品标题:{product.title}
品牌:{product.brand or "未填写"}
类目:{product.category or "未填写"}
目标人群:{product.target_audience or "未填写"}
商品描述:{product.description or "未填写"}
卖点:
{selling_text or "未选择卖点,请根据商品信息自行提炼。"}
用户补充需求:
{user_prompt or "生成一条结构完整、节奏清晰、适合投放的带货短视频脚本。"}
""".strip()
return [{"role": "system", "content": system}, {"role": "user", "content": user}]
def split_script_into_segments(content: str) -> list[str]:
    """Split script text into exactly four segment strings.

    When the text has four or more non-blank lines, the first four (each
    stripped) become the segments. Otherwise the whole stripped text is kept
    in the first segment and the remaining segments are empty strings.

    The previous implementation sized the padding from the number of
    non-blank lines, so inputs with two or three lines produced three or two
    segments instead of four.

    Args:
        content: Raw script text; may be empty or whitespace only.

    Returns:
        A list of exactly four strings.
    """
    segment_count = 4
    non_blank_lines = [line.strip() for line in content.splitlines() if line.strip()]
    if len(non_blank_lines) >= segment_count:
        return non_blank_lines[:segment_count]
    # Too few lines to split: one leading segment plus fixed padding, so the
    # result length never depends on how many lines the text happened to have.
    return [content.strip()] + [""] * (segment_count - 1)
@transaction.atomic
def create_ai_task(*, project, user, task_type: str, model_config: ModelConfig, request_payload: dict) -> AITask:
    """Create an AI task for ``project`` and reserve credit for it.

    The task row is created with status ``CREATED``, ``reserve_credit`` is
    called for the estimated cost, and the task is then moved to ``RESERVED``.
    The whole function runs under ``transaction.atomic``.

    Args:
        project: Project the task belongs to; its ``team`` and ``id`` are read.
        user: User recorded as the task creator and passed to ``reserve_credit``.
        task_type: Task type value stored on the task.
        model_config: Model configuration used for the cost estimate.
        request_payload: Payload stored on the task.

    Returns:
        The saved ``AITask`` in status ``RESERVED``.
    """
    reserved_amount = estimate_cost(model_config)
    owning_team = project.team
    # A fresh UUID is appended on every call, so each call gets its own key.
    key = f"{task_type}:{project.id}:{uuid.uuid4()}"

    new_task = AITask.objects.create(
        team=owning_team,
        created_by=user,
        project=project,
        task_type=task_type,
        status=AITask.Status.CREATED,
        model_config=model_config,
        idempotency_key=key,
        request_payload=request_payload,
        estimated_cost=reserved_amount,
    )
    reserve_credit(team=owning_team, user=user, task=new_task, amount=reserved_amount)

    new_task.status = AITask.Status.RESERVED
    new_task.save(update_fields=["status", "updated_at"])
    return new_task
def generate_project_script(*, project, user, user_prompt: str, selling_point_ids: list[str] | None = None) -> ScriptVersion:
    """Generate a script for ``project`` with the default text model and store it.

    Flow: pick the default text model, create an AI task (which reserves
    credit), mark it ``SUBMITTED``, call the provider's chat completion, then
    in one transaction mark the task ``SUCCEEDED``, charge the reservation,
    create the ``ScriptVersion`` with one ``ScriptSegment`` per split segment,
    and set the project's script stage to ``NEEDS_REVIEW``.

    Args:
        project: Project to generate the script for.
        user: User the task is created for.
        user_prompt: Extra user requirements forwarded to the prompt builder.
        selling_point_ids: Optional selling point ids forwarded to the prompt
            builder.

    Returns:
        The newly created, not yet adopted ``ScriptVersion``.

    Raises:
        ValueError: If no active text model is configured.
        Exception: Any error raised after the task was created is re-raised
            once the task has been marked ``FAILED`` and the credit released.
    """
    model_config = get_default_model(ModelConfig.Capability.TEXT)
    if model_config is None:
        raise ValueError("no active text model configured")

    chat_messages = build_script_prompt(
        project=project,
        user_prompt=user_prompt,
        selling_point_ids=selling_point_ids,
    )
    request_payload = {
        "model": model_config.name,
        "endpoint": model_config.endpoint,
        "messages": chat_messages,
    }
    task = create_ai_task(
        project=project,
        user=user,
        task_type=AITask.Type.SCRIPT_GENERATION,
        model_config=model_config,
        request_payload=request_payload,
    )
    reservation = task.credit_reservation

    try:
        task.status = AITask.Status.SUBMITTED
        task.submitted_at = timezone.now()
        task.save(update_fields=["status", "submitted_at", "updated_at"])

        provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None)
        response = provider.chat_completion(
            model=model_config.name,
            endpoint=model_config.endpoint,
            messages=chat_messages,
        )
        content = provider.extract_text(response)

        # Task completion, charging and the generated rows commit together.
        with transaction.atomic():
            task.status = AITask.Status.SUCCEEDED
            task.response_payload = response
            task.actual_cost = task.estimated_cost
            task.completed_at = timezone.now()
            task.save(update_fields=["status", "response_payload", "actual_cost", "completed_at", "updated_at"])
            charge_reserved_credit(reservation=reservation, actual_amount=task.actual_cost)

            script = ScriptVersion.objects.create(
                project=project,
                task=task,
                title="AI 脚本",
                content=content,
                source="ai",
                is_adopted=False,
            )
            for position, text in enumerate(split_script_into_segments(content)):
                # The same text seeds both the narration and the visual prompt.
                ScriptSegment.objects.create(
                    script_version=script,
                    sort_order=position,
                    duration_seconds=15,
                    narration=text,
                    visual_prompt=text,
                )

            script_stage = ProjectStage.objects.get_or_create(
                project=project,
                stage=ProjectStage.Stage.SCRIPT,
            )[0]
            script_stage.status = ProjectStage.Status.NEEDS_REVIEW
            script_stage.save(update_fields=["status", "updated_at"])
        return script
    except Exception as exc:
        # Record the failure and hand the reserved credit back before
        # propagating the original error to the caller.
        with transaction.atomic():
            task.status = AITask.Status.FAILED
            task.error_message = str(exc)
            task.completed_at = timezone.now()
            task.save(update_fields=["status", "error_message", "completed_at", "updated_at"])
            release_credit(reservation=reservation, reason=str(exc))
        raise
def _store_generated_media(*, team, user, project, task, media: str, name: str, category: str, asset_type: str) -> Asset:
    """Upload generated media and register it as an AI-generated asset.

    The media is converted with ``VolcanoArkProvider.media_to_bytes``,
    uploaded under
    ``teams/<team id>/projects/<project id>/generated/<asset id><suffix>``,
    and recorded as an ``Asset`` plus one primary ``AssetFile``.

    Args:
        team: Owning team; its ``id`` is part of the object key.
        user: User recorded as the asset creator.
        project: Project whose ``id`` is part of the object key.
        task: AI task stored as the asset's ``origin_task``.
        media: Media reference handed to ``VolcanoArkProvider.media_to_bytes``.
        name: Asset name.
        category: Asset category value.
        asset_type: Asset type value.

    Returns:
        The created ``Asset``.
    """
    fileobj, content_type = VolcanoArkProvider.media_to_bytes(media)

    # First marker found in the content type wins; anything else is ".png".
    suffix_by_marker = (("video", ".mp4"), ("jpeg", ".jpg"), ("webp", ".webp"))
    suffix = next(
        (extension for marker, extension in suffix_by_marker if marker in content_type),
        ".png",
    )

    # The asset id is generated up front so it can also name the stored object.
    asset_id = uuid.uuid4()
    stored = TosStorage().upload_fileobj(
        fileobj=fileobj,
        object_key=f"teams/{team.id}/projects/{project.id}/generated/{asset_id}{suffix}",
        content_type=content_type,
    )

    generated_asset = Asset.objects.create(
        id=asset_id,
        team=team,
        created_by=user,
        name=name,
        asset_type=asset_type,
        source=Asset.Source.AI_GENERATED,
        category=category,
        origin_task=task,
    )
    AssetFile.objects.create(
        asset=generated_asset,
        object_key=stored.object_key,
        bucket=stored.bucket,
        content_type=stored.content_type,
        size_bytes=stored.size_bytes,
        is_primary=True,
    )
    return generated_asset
def generate_base_asset(*, project, user, kind: str, prompt: str) -> BaseAssetGroup:
    """Generate one base image (product, person or scene) for ``project``.

    Flow: pick the default image model, create an AI task (which reserves
    credit), call the provider's image generation, then in a single
    transaction store the media as an asset, create the ``BaseAssetGroup``
    with that asset adopted, mark the task ``SUCCEEDED`` and charge the
    reservation.

    The artifacts are stored in the same transaction as the charge, and
    before it, so a storage failure can no longer leave a charged reservation
    that the failure handler then tries to release. The failure bookkeeping
    is atomic as well.

    Args:
        project: Project the asset group belongs to.
        user: User the task and asset are created for.
        kind: One of the ``BaseAssetGroup.Kind`` values ``PRODUCT``,
            ``PERSON`` or ``SCENE``.
        prompt: Image prompt sent to the provider.

    Returns:
        The created ``BaseAssetGroup``.

    Raises:
        ValueError: If no active image model is configured.
        KeyError: If ``kind`` is not one of the supported kinds (raised before
            any task is created).
        Exception: Any error raised after the task was created is re-raised
            once the task has been marked ``FAILED`` and the credit released.
    """
    model_config = get_default_model(ModelConfig.Capability.IMAGE)
    if model_config is None:
        raise ValueError("no active image model configured")

    payload = {"model": model_config.name, "endpoint": model_config.endpoint, "prompt": prompt, "kind": kind}
    # Both lookups happen before the task exists, so an unsupported kind fails
    # without reserving any credit.
    task_type = {
        BaseAssetGroup.Kind.PRODUCT: AITask.Type.PRODUCT_IMAGE,
        BaseAssetGroup.Kind.PERSON: AITask.Type.PERSON_IMAGE,
        BaseAssetGroup.Kind.SCENE: AITask.Type.SCENE_IMAGE,
    }[kind]
    category = {
        BaseAssetGroup.Kind.PRODUCT: Asset.Category.PRODUCT_IMAGE,
        BaseAssetGroup.Kind.PERSON: Asset.Category.PERSON,
        BaseAssetGroup.Kind.SCENE: Asset.Category.SCENE,
    }[kind]

    task = create_ai_task(
        project=project,
        user=user,
        task_type=task_type,
        model_config=model_config,
        request_payload=payload,
    )
    reservation = task.credit_reservation
    try:
        provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None)
        response = provider.image_generation(model=model_config.name, endpoint=model_config.endpoint, prompt=prompt)
        media = provider.extract_first_media_url(response)

        with transaction.atomic():
            # Store the result first: if this fails, nothing below is
            # committed and the reservation is still uncharged.
            asset = _store_generated_media(
                team=project.team,
                user=user,
                project=project,
                task=task,
                media=media,
                name=f"{project.name}-{kind}",
                category=category,
                asset_type=Asset.Type.IMAGE,
            )
            group = BaseAssetGroup.objects.create(project=project, kind=kind, task=task, prompt=prompt)
            group.candidate_assets.add(asset)
            group.adopted_asset = asset
            group.save(update_fields=["adopted_asset", "updated_at"])

            task.status = AITask.Status.SUCCEEDED
            task.response_payload = response
            task.actual_cost = task.estimated_cost
            task.completed_at = timezone.now()
            task.save(update_fields=["status", "response_payload", "actual_cost", "completed_at", "updated_at"])
            charge_reserved_credit(reservation=reservation, actual_amount=task.actual_cost)
        return group
    except Exception as exc:
        reason = str(exc)
        # Marking the task failed and releasing the credit succeed or fail
        # together, matching generate_project_script.
        with transaction.atomic():
            task.status = AITask.Status.FAILED
            task.error_message = reason
            task.completed_at = timezone.now()
            task.save(update_fields=["status", "error_message", "completed_at", "updated_at"])
            release_credit(reservation=reservation, reason=reason)
        raise
def generate_storyboard(*, project, user, prompt: str = "") -> StoryboardVersion:
    """Generate one storyboard frame per segment of the adopted script.

    A new ``StoryboardVersion`` is created, then for every segment of the
    adopted script an AI task is created (reserving credit), an image is
    generated, and in a single transaction the image is stored as an asset,
    a ``StoryboardFrame`` is created, the task is marked ``SUCCEEDED`` and the
    reservation is charged. After all segments succeed the storyboard is
    marked adopted.

    Changes from the previous behaviour:
      * The prompt recorded in the task's ``request_payload`` is now the
        prompt actually sent to the provider (global prompt plus the segment
        prompt) instead of the segment prompt alone.
      * Storage and charging happen atomically, with storage first, so a
        storage failure cannot leave a charged reservation that is then
        released.
      * Failure bookkeeping is atomic.

    Args:
        project: Project whose adopted script is used.
        user: User the tasks and assets are created for.
        prompt: Optional global prompt prepended to every segment's visual
            prompt.

    Returns:
        The created ``StoryboardVersion`` with ``is_adopted`` set.

    Raises:
        ValueError: If the project has no adopted script or no active image
            model is configured.
        Exception: The first per-segment error is re-raised after that
            segment's task has been marked ``FAILED`` and its credit released;
            frames of earlier segments remain in place.
    """
    adopted_script = project.script_versions.filter(is_adopted=True).prefetch_related("segments").first()
    if adopted_script is None:
        raise ValueError("script must be adopted before generating storyboard")
    model_config = get_default_model(ModelConfig.Capability.IMAGE)
    if model_config is None:
        raise ValueError("no active image model configured")

    storyboard = StoryboardVersion.objects.create(project=project, prompt=prompt)
    provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None)

    for segment in adopted_script.segments.all():
        # Built once so the recorded payload and the provider call agree.
        frame_prompt = f"{prompt}\n{segment.visual_prompt}".strip()
        task = create_ai_task(
            project=project,
            user=user,
            task_type=AITask.Type.STORYBOARD,
            model_config=model_config,
            request_payload={"model": model_config.name, "endpoint": model_config.endpoint, "prompt": frame_prompt},
        )
        reservation = task.credit_reservation
        try:
            response = provider.image_generation(
                model=model_config.name,
                endpoint=model_config.endpoint,
                prompt=frame_prompt,
            )
            media = provider.extract_first_media_url(response)

            with transaction.atomic():
                # Store the frame first: if this fails, the charge below is
                # never committed.
                asset = _store_generated_media(
                    team=project.team,
                    user=user,
                    project=project,
                    task=task,
                    media=media,
                    name=f"{project.name}-storyboard-{segment.sort_order + 1}",
                    category=Asset.Category.SCENE,
                    asset_type=Asset.Type.IMAGE,
                )
                StoryboardFrame.objects.create(
                    storyboard=storyboard,
                    script_segment=segment,
                    asset=asset,
                    sort_order=segment.sort_order,
                    prompt=segment.visual_prompt,
                )

                task.status = AITask.Status.SUCCEEDED
                task.response_payload = response
                task.actual_cost = task.estimated_cost
                task.completed_at = timezone.now()
                task.save(update_fields=["status", "response_payload", "actual_cost", "completed_at", "updated_at"])
                charge_reserved_credit(reservation=reservation, actual_amount=task.actual_cost)
        except Exception as exc:
            reason = str(exc)
            with transaction.atomic():
                task.status = AITask.Status.FAILED
                task.error_message = reason
                task.completed_at = timezone.now()
                task.save(update_fields=["status", "error_message", "completed_at", "updated_at"])
                release_credit(reservation=reservation, reason=reason)
            raise

    storyboard.is_adopted = True
    storyboard.save(update_fields=["is_adopted", "updated_at"])
    return storyboard
def submit_video_segment(*, video_segment: VideoSegment, user, prompt: str) -> VideoSegmentVersion | None:
    """Submit an asynchronous video generation job for ``video_segment``.

    An AI task is created (reserving credit), the provider's video task is
    created, and the provider task id is stored on the AI task, which moves to
    ``SUBMITTED`` while the segment moves to ``RUNNING``. The result is
    collected later by polling.

    Changes from the previous behaviour:
      * A provider response carrying neither ``id`` nor ``task_id`` is treated
        as a failure. Before, an empty id was stored and the task marked
        ``SUBMITTED``, leaving it impossible to poll with credit reserved.
      * The success and failure state transitions are each written in one
        transaction.

    Args:
        video_segment: Segment to generate; its project and
            ``target_duration_seconds`` are read.
        user: User the task is created for.
        prompt: Video prompt sent to the provider.

    Returns:
        Always ``None``; no version exists until polling completes.

    Raises:
        ValueError: If no active video model is configured, or the provider
            response contains no task id.
        Exception: Any error raised after the task was created is re-raised
            once the task and segment have been marked ``FAILED`` and the
            credit released.
    """
    model_config = get_default_model(ModelConfig.Capability.VIDEO)
    if model_config is None:
        raise ValueError("no active video model configured")
    project = video_segment.project
    task = create_ai_task(
        project=project,
        user=user,
        task_type=AITask.Type.VIDEO_SEGMENT,
        model_config=model_config,
        request_payload={
            "model": model_config.name,
            "endpoint": model_config.endpoint,
            "prompt": prompt,
            "duration": video_segment.target_duration_seconds,
            "ratio": "9:16",
            # Lets the poller find this task from the segment.
            "video_segment_id": str(video_segment.id),
        },
    )
    try:
        provider = VolcanoArkProvider(base_url=model_config.provider.base_url or None)
        response = provider.create_video_task(
            model=model_config.name,
            endpoint=model_config.endpoint,
            prompt=prompt,
            duration=video_segment.target_duration_seconds,
            ratio="9:16",
            resolution="720p",
        )
        provider_task_id = str(response.get("id") or response.get("task_id") or "")
        if not provider_task_id:
            # Without an id the job can never be polled; fail now so the
            # reserved credit is released by the handler below.
            raise ValueError("provider did not return a video task id")

        with transaction.atomic():
            task.provider_task_id = provider_task_id
            task.response_payload = response
            task.status = AITask.Status.SUBMITTED
            task.submitted_at = timezone.now()
            task.save(update_fields=["provider_task_id", "response_payload", "status", "submitted_at", "updated_at"])
            video_segment.status = VideoSegment.Status.RUNNING
            video_segment.save(update_fields=["status", "updated_at"])
        return None
    except Exception as exc:
        reason = str(exc)
        with transaction.atomic():
            task.status = AITask.Status.FAILED
            task.error_message = reason
            task.completed_at = timezone.now()
            task.save(update_fields=["status", "error_message", "completed_at", "updated_at"])
            release_credit(reservation=task.credit_reservation, reason=reason)
            video_segment.status = VideoSegment.Status.FAILED
            video_segment.error_message = reason
            video_segment.save(update_fields=["status", "error_message", "updated_at"])
        raise
def poll_video_segment(*, video_segment: VideoSegment, user) -> VideoSegmentVersion | None:
    """Poll the provider for the segment's in-flight video task and apply the result.

    Task selection: the task attached to the segment's latest version is used
    only while it is still ``SUBMITTED`` or ``POLLING``; otherwise the newest
    in-flight ``VIDEO_SEGMENT`` task whose payload references this segment is
    used.

    Changes from the previous behaviour:
      * A finished task attached to an existing version is no longer polled
        again. Before, it was re-polled, re-stored and charged a second time,
        and it shadowed any newer in-flight task.
      * A provider response with ``"error": null`` no longer crashes the
        failure branch, and a missing or empty error message falls back to
        ``"video generation failed"``.
      * The failure and success state transitions are each written in one
        transaction.

    Args:
        video_segment: Segment whose generation is being polled.
        user: User recorded as creator of the stored asset.

    Returns:
        The new adopted ``VideoSegmentVersion`` when the provider reports a
        finished result; ``None`` while the job is still running or when it
        failed.

    Raises:
        ValueError: If there is no in-flight video generation task for the
            segment.
    """
    in_flight_statuses = [AITask.Status.SUBMITTED, AITask.Status.POLLING]

    latest_version = video_segment.versions.order_by("-created_at").first()
    ai_task = latest_version.task if latest_version else None
    if ai_task is None or ai_task.status not in in_flight_statuses:
        ai_task = video_segment.project.ai_tasks.filter(
            task_type=AITask.Type.VIDEO_SEGMENT,
            request_payload__video_segment_id=str(video_segment.id),
            status__in=in_flight_statuses,
        ).order_by("-created_at").first()
    if ai_task is None:
        raise ValueError("no active video generation task")

    provider = VolcanoArkProvider(base_url=ai_task.model_config.provider.base_url or None)
    response = provider.poll_video_task(endpoint=ai_task.model_config.endpoint, provider_task_id=ai_task.provider_task_id)
    remote_status = response.get("status")

    if remote_status in {"queued", "running", "processing"}:
        ai_task.status = AITask.Status.POLLING
        ai_task.response_payload = response
        ai_task.save(update_fields=["status", "response_payload", "updated_at"])
        return None

    if remote_status in {"failed", "expired", "cancelled"}:
        # "error" may be absent or null in the response; both are tolerated.
        error_message = (response.get("error") or {}).get("message") or "video generation failed"
        with transaction.atomic():
            ai_task.status = AITask.Status.FAILED
            ai_task.response_payload = response
            ai_task.error_message = error_message
            ai_task.completed_at = timezone.now()
            ai_task.save(update_fields=["status", "response_payload", "error_message", "completed_at", "updated_at"])
            release_credit(reservation=ai_task.credit_reservation, reason=error_message)
            video_segment.status = VideoSegment.Status.FAILED
            video_segment.error_message = error_message
            video_segment.save(update_fields=["status", "error_message", "updated_at"])
        return None

    # NOTE(review): every status outside the two sets above is treated as a
    # finished job, as before - confirm against the provider's status list.
    media = provider.extract_first_media_url(response)
    with transaction.atomic():
        # Store the clip first: if this fails, the task stays in flight and
        # nothing is charged, so a later poll can retry.
        asset = _store_generated_media(
            team=video_segment.project.team,
            user=user,
            project=video_segment.project,
            task=ai_task,
            media=media,
            name=f"{video_segment.project.name}-segment-{video_segment.sort_order + 1}",
            category=Asset.Category.VIDEO_CLIP,
            asset_type=Asset.Type.VIDEO,
        )
        ai_task.status = AITask.Status.SUCCEEDED
        ai_task.response_payload = response
        ai_task.actual_cost = ai_task.estimated_cost
        ai_task.completed_at = timezone.now()
        ai_task.save(update_fields=["status", "response_payload", "actual_cost", "completed_at", "updated_at"])
        charge_reserved_credit(reservation=ai_task.credit_reservation, actual_amount=ai_task.actual_cost)

        version = VideoSegmentVersion.objects.create(
            video_segment=video_segment,
            task=ai_task,
            asset=asset,
            prompt=ai_task.request_payload.get("prompt", ""),
            is_adopted=True,
        )
        video_segment.adopted_version = version
        video_segment.status = VideoSegment.Status.SUCCEEDED
        video_segment.error_message = ""
        video_segment.save(update_fields=["adopted_version", "status", "error_message", "updated_at"])
    return version
def create_export_job(*, timeline, user) -> ExportJob:
    """Create a queued export job for ``timeline``.

    Args:
        timeline: Timeline the export job is attached to.
        user: Accepted for interface compatibility; not used here.

    Returns:
        The created ``ExportJob`` with status ``QUEUED``.
    """
    queued_job = ExportJob.objects.create(
        timeline=timeline,
        status=ExportJob.Status.QUEUED,
    )
    return queued_job