seaislee1209 9a6d95a69d
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 3m13s
fix: v0.18.0 商业级加固 — 并发安全、流式上传、错误反馈、类型修复
- TOS 流式上传 upload_from_file_path(避免大文件 OOM)
- 视频生成完成后下载一次复用(TOS 上传 + 首帧提取)
- 并发安全:group thumbnail 用 select_for_update 原子更新
- 跨团队校验:_resolve_asset_group_all 加 group__team 过滤
- 异常信息脱敏:文件上传失败不再泄露内部异常
- SSRF 防护:download_to_temp 校验 URL scheme
- poll lock 终态释放:cache.delete 在 record.save 后调用
- duration=null 语义区分:ffprobe 失败存 None 非 0
- 前端 duration 未知 toast 警告:素材时长未确定时提示用户
- 搜索 API 失败 toast:素材搜索失败时反馈用户
- 视频保存降级标记:临时 URL 降级时设 error_message
- TypeScript 类型修复:AssetItem/AssetSearchResult.duration 改为 number|null
- rebuildMentionSpans 补完 assetId/assetType/assetName/duration 属性
- paste DOMPurify 白名单补完新 data attributes
- resolved_url NameError 修复:非素材库视频/音频引用用 url
- process_asset_media group 删除保护
- download_to_temp 改为 public API
- 清理前端死代码

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-04 18:49:08 +08:00

218 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Celery tasks for async video generation polling."""
import logging
from celery import shared_task
logger = logging.getLogger(__name__)
# 轮询间隔(秒):每次查完后重新入队,不占 worker 进程
POLL_INTERVAL = 5
@shared_task(bind=True, max_retries=None, ignore_result=True)
def poll_video_task(self, record_id):
    """Poll Volcano API for a video generation task.

    Performs exactly one polling round, then re-enqueues itself via
    ``self.retry`` — workers are never blocked sleeping, and a worker
    restart does not lose in-flight tasks.

    Args:
        record_id: Primary key of the GenerationRecord being polled.
    """
    from django.utils import timezone
    from django.core.cache import cache
    from apps.generation.models import GenerationRecord
    from utils.airdrama_client import query_task, map_status

    # Dedup guard: only one poll per record may run at any moment.
    lock_key = f'poll_lock:{record_id}'
    if not cache.add(lock_key, '1', timeout=POLL_INTERVAL * 3):
        logger.info('poll_video_task: record %s already being polled, skipping', record_id)
        return
    # try/finally guarantees the lock is released on every exit path —
    # early returns, the retry re-enqueue, and unexpected failures during
    # result processing (previously an exception in record.save() left the
    # lock to expire only via its cache timeout).
    try:
        try:
            record = GenerationRecord.objects.get(pk=record_id)
        except GenerationRecord.DoesNotExist:
            logger.warning('poll_video_task: record %s not found', record_id)
            return
        ark_task_id = record.ark_task_id
        if not ark_task_id:
            logger.warning('poll_video_task: record %s has no ark_task_id', record_id)
            return
        if record.status not in ('queued', 'processing'):
            logger.info('poll_video_task: record %s already in terminal state: %s', record_id, record.status)
            return
        # Poll Volcano API.
        try:
            ark_resp = query_task(ark_task_id)
            new_status = map_status(ark_resp.get('status', ''))
        except Exception:
            logger.exception('poll_video_task: API query failed for %s, will retry', ark_task_id)
            raise self.retry(countdown=POLL_INTERVAL)
        if new_status in ('queued', 'processing'):
            # Still running — persist the status, then re-enqueue.
            record.status = new_status
            record.save(update_fields=['status', 'updated_at'])
            raise self.retry(countdown=POLL_INTERVAL)
        # Terminal state reached — process result.
        record.status = new_status
        returned_seed = ark_resp.get('seed')
        if returned_seed is not None:
            record.seed = returned_seed
        if new_status == 'completed':
            _handle_completed(record, ark_resp)
        elif new_status == 'failed':
            _handle_failed(record, ark_resp)
        record.completed_at = timezone.now()
        record.save(update_fields=[
            'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error',
            'seed', 'completed_at',
        ])
        logger.info(
            'poll_video_task: record=%s ark=%s final_status=%s',
            record_id, ark_task_id, new_status,
        )
    finally:
        cache.delete(lock_key)
def _handle_completed(record, ark_resp):
    """Process a completed task: persist video to TOS, extract thumbnail, settle payment."""
    import os
    from utils.airdrama_client import extract_video_url

    video_url = extract_video_url(ark_resp)
    if video_url:
        # Download once into a temp file and reuse it for both the TOS
        # upload and the thumbnail extraction.
        local_path = None
        try:
            from utils.media_utils import download_to_temp, extract_video_info_from_file
            from utils.tos_client import upload_from_file_path, upload_file
            local_path = download_to_temp(video_url, '.mp4')
            # Stream the video to TOS from disk, avoiding a full in-memory load.
            record.result_url = upload_from_file_path(local_path, folder='results', content_type='video/mp4')
            # Thumbnail comes from the same local file — no second download.
            frame, _ = extract_video_info_from_file(local_path)
            if frame:
                record.thumbnail_url = upload_file(frame, folder='thumbnails')
        except Exception:
            logger.exception('poll_video_task: failed to persist video / extract thumbnail')
            if not record.result_url:
                # Degrade to the provider's temporary URL and flag the record.
                record.result_url = video_url
                record.error_message = '视频保存失败临时链接将在24小时后过期请联系管理员'
        finally:
            if local_path and os.path.exists(local_path):
                os.unlink(local_path)
    # Settlement: charge by actual token usage when any was reported,
    # otherwise release the frozen amount.
    usage = ark_resp.get('usage', {})
    if not isinstance(usage, dict):
        usage = {}
    tokens_spent = usage.get('total_tokens', 0)
    if tokens_spent > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, tokens_spent)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def recover_stuck_tasks():
    """Re-dispatch polling for tasks stuck in queued/processing for over 3 minutes.

    Intended to run on a periodic schedule as a safety net for polls that
    died before reaching a terminal state.
    """
    from datetime import timedelta
    from django.utils import timezone
    from apps.generation.models import GenerationRecord

    cutoff = timezone.now() - timedelta(minutes=3)
    stuck_records = GenerationRecord.objects.filter(
        status__in=('queued', 'processing'),
        ark_task_id__isnull=False,
        updated_at__lt=cutoff,
    ).exclude(ark_task_id='')
    count = 0
    for record in stuck_records:
        logger.warning('recover_stuck_tasks: re-dispatching record=%s ark=%s', record.id, record.ark_task_id)
        try:
            poll_video_task.delay(record.id)
            count += 1
        except Exception:
            # logger.exception keeps the broker error's traceback, which
            # the previous logger.error call silently dropped.
            logger.exception('recover_stuck_tasks: failed to dispatch record=%s', record.id)
    if count:
        logger.info('recover_stuck_tasks: re-dispatched %d stuck tasks', count)
def _handle_failed(record, ark_resp):
    """Process a failed task: record error and release frozen amount."""
    from utils.airdrama_client import ERROR_MESSAGES

    err = ark_resp.get('error', {})
    if isinstance(err, dict):
        code = err.get('code', '')
        raw_msg = err.get('message', '')
    else:
        code = ''
        raw_msg = str(err)
    # Prefer the curated message for known codes; fall back to the raw text.
    record.error_message = ERROR_MESSAGES.get(code, raw_msg)
    record.raw_error = f'{code}: {raw_msg}' if code else raw_msg

    # Settle by consumed tokens when any were reported; otherwise release
    # the frozen amount in full.
    usage = ark_resp.get('usage', {})
    consumed = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0
    if consumed > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, consumed)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def process_asset_media(asset_id):
    """Extract thumbnail + duration for video/audio assets asynchronously.

    Args:
        asset_id: Primary key of the Asset to process. Assets that are
            neither Video nor Audio are left untouched.
    """
    from apps.generation.models import Asset
    try:
        asset = Asset.objects.select_related('group').get(pk=asset_id)
    except Asset.DoesNotExist:
        logger.warning('process_asset_media: asset %s not found', asset_id)
        return
    from utils.media_utils import extract_video_info, get_audio_duration
    from utils.tos_client import upload_file

    def _normalized_duration(raw):
        # Map probe failures to None so the frontend skips its duration
        # check. The truthiness guard also protects against a None return
        # from the probe helper — the old `raw > 0` would raise TypeError
        # on None in Python 3.
        return raw if raw and raw > 0 else None

    if asset.asset_type == 'Video':
        thumb_file, dur = extract_video_info(asset.url)
        if thumb_file:
            try:
                asset.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
            except Exception:
                logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id)
        asset.duration = _normalized_duration(dur)
        asset.save(update_fields=['thumbnail_url', 'duration'])
        # Atomic update: only set the group thumbnail if it is still empty.
        # select_for_update serializes concurrent uploads racing to fill it.
        from apps.generation.models import AssetGroup
        from django.db import transaction
        try:
            with transaction.atomic():
                group = AssetGroup.objects.select_for_update().get(pk=asset.group_id)
                if not group.thumbnail_url and asset.thumbnail_url:
                    group.thumbnail_url = asset.thumbnail_url
                    group.save(update_fields=['thumbnail_url'])
        except AssetGroup.DoesNotExist:
            logger.warning('process_asset_media: group %s deleted, skipping thumbnail update', asset.group_id)
    elif asset.asset_type == 'Audio':
        asset.duration = _normalized_duration(get_audio_duration(asset.url))
        asset.save(update_fields=['duration'])
    logger.info('process_asset_media: asset %s done (type=%s, dur=%s)', asset_id, asset.asset_type, asset.duration)