Some checks failed
Build and Deploy / build-and-deploy (push) Has been cancelled
- MutationObserver 立刻同步 editorHtml(删 @ 标签后时长/数量立即重置) - parseAssetMentionsFromDOM 从 DOM 实时读取(不用 stale state) - renderPromptWithMentions 支持音频 ♫ + 视频首帧 + assetType - rebuildMentionSpans 按 label 长度降序匹配(防子串冲突) - 删除素材后 group 缩略图优先找图片/视频(不用音频 URL) - 素材组整组删除功能(后端 DELETE + 前端按钮) - Celery poll 架构重构(一次性任务 + recover_stuck_tasks 统一驱动) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
216 lines
8.0 KiB
Python
216 lines
8.0 KiB
Python
"""Celery tasks for async video generation polling."""
|
||
|
||
import logging
|
||
|
||
from celery import shared_task
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@shared_task(ignore_result=True)
def poll_video_task(record_id):
    """Poll the Volcano API once for a video generation task.

    One-shot task: query the API a single time, update the DB, and exit.
    It never retries itself — recover_stuck_tasks (scheduled via Celery
    beat) is the sole driver and re-dispatches it while the record is
    still in flight.

    A Redis lock prevents the same record from being processed
    concurrently (e.g. dispatched again while _handle_completed is still
    persisting the result).
    """
    from django.core.cache import cache

    # Guard: cache.add is atomic, so only one worker wins the lock.
    # _handle_completed can take a while, hence the generous TTL.
    guard_key = f'poll_lock:{record_id}'
    acquired = cache.add(guard_key, '1', timeout=120)
    if not acquired:
        return

    try:
        _do_poll(record_id)
    except Exception:
        logger.exception('poll_video_task: unexpected error for record=%s', record_id)
    finally:
        # Always release the lock so the next beat sweep can re-dispatch.
        cache.delete(guard_key)
def _do_poll(record_id):
    """Run one poll cycle for *record_id*; called by poll_video_task.

    Fetches the record, queries the Volcano API once, and either
    refreshes the in-flight status or — on a terminal state — delegates
    to _handle_completed / _handle_failed and persists the final fields.
    """
    from django.utils import timezone
    from apps.generation.models import GenerationRecord
    from utils.airdrama_client import query_task, map_status

    try:
        record = GenerationRecord.objects.get(pk=record_id)
    except GenerationRecord.DoesNotExist:
        logger.warning('poll_video_task: record %s not found', record_id)
        return

    # Another worker (or user action) may already have finalized it.
    if record.status not in ('queued', 'processing'):
        return

    ark_task_id = record.ark_task_id
    if not ark_task_id:
        logger.warning('poll_video_task: record %s has no ark_task_id', record_id)
        return

    # Poll Volcano API. A transient failure is not fatal: the next
    # recover_stuck_tasks sweep will dispatch another poll.
    try:
        ark_resp = query_task(ark_task_id)
        new_status = map_status(ark_resp.get('status', ''))
    except Exception:
        logger.exception('poll_video_task: API query failed for record=%s ark=%s', record_id, ark_task_id)
        return

    if new_status in ('queued', 'processing'):
        record.status = new_status
        record.save(update_fields=['status', 'updated_at'])
        return

    # Terminal state reached — process result
    record.status = new_status

    # Persist the seed the backend actually used, when reported.
    returned_seed = ark_resp.get('seed')
    if returned_seed is not None:
        record.seed = returned_seed

    if new_status == 'completed':
        _handle_completed(record, ark_resp)
    elif new_status == 'failed':
        _handle_failed(record, ark_resp)

    record.completed_at = timezone.now()
    # Fix: include 'updated_at' here as well — the in-flight save above
    # lists it, which indicates an auto_now field; Django only refreshes
    # auto_now fields named in update_fields, so omitting it left the
    # timestamp stale on the terminal save.
    record.save(update_fields=[
        'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error',
        'seed', 'completed_at', 'updated_at',
    ])

    logger.info(
        'poll_video_task: record=%s ark=%s final_status=%s',
        record_id, ark_task_id, new_status,
    )
def _handle_completed(record, ark_resp):
    """Process a completed task: persist video to TOS, extract thumbnail, settle payment.

    Mutates *record* in place (result_url, thumbnail_url, error_message);
    the caller (_do_poll) performs the final save().
    """
    import os
    from utils.airdrama_client import extract_video_url

    video_url = extract_video_url(ark_resp)
    if video_url:
        # Download once to temp file, reuse for TOS upload + thumbnail extraction
        tmp_path = None
        try:
            from utils.media_utils import download_to_temp, extract_video_info_from_file
            from utils.tos_client import upload_from_file_path, upload_file

            tmp_path = download_to_temp(video_url, '.mp4')

            # Upload video to TOS from file (streaming, no full memory load)
            record.result_url = upload_from_file_path(tmp_path, folder='results', content_type='video/mp4')

            # Extract thumbnail from the same local file (no second download)
            thumb_file, _ = extract_video_info_from_file(tmp_path)
            if thumb_file:
                record.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
        except Exception:
            logger.exception('poll_video_task: failed to persist video / extract thumbnail')
            # Best-effort fallback: keep the provider's temporary URL so the
            # user still gets a playable link, and warn that it will expire.
            # Only applies when the TOS upload itself failed.
            if not record.result_url:
                record.result_url = video_url
                record.error_message = '视频保存失败,临时链接将在24小时后过期,请联系管理员'
        finally:
            # Remove the temp download regardless of upload outcome.
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)

    # Settlement: charge based on the tokens actually consumed.
    usage = ark_resp.get('usage', {})
    total_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        # No usage reported — release the frozen amount instead of charging.
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def recover_stuck_tasks():
    """Sweep all in-flight generation records and dispatch a poll for each.

    Scheduled periodically via Celery beat. poll_video_task is a one-shot
    task that never retries itself, so this sweep is the sole driver of
    polling: a record stays 'queued'/'processing' until one of the
    dispatched polls observes a terminal state and finalizes it.
    """
    from apps.generation.models import GenerationRecord

    # Only records that have actually been submitted upstream
    # (non-null, non-empty ark_task_id) can be polled.
    active_records = GenerationRecord.objects.filter(
        status__in=('queued', 'processing'),
        ark_task_id__isnull=False,
    ).exclude(ark_task_id='').values_list('id', flat=True)

    count = 0
    for record_id in active_records:
        try:
            poll_video_task.delay(record_id)
            count += 1
        except Exception:
            # Fix: logger.exception preserves the traceback (logger.error
            # dropped it), matching the error-logging style used by the
            # other tasks in this module — useful when the broker is down.
            logger.exception('recover_stuck_tasks: failed to dispatch record=%s', record_id)

    if count:
        logger.info('recover_stuck_tasks: dispatched %d active tasks', count)
def _handle_failed(record, ark_resp):
    """Process a failed task: record error and release frozen amount.

    Mutates *record* in place (error_message, raw_error); the caller
    (_do_poll) performs the final save().
    """
    from utils.airdrama_client import ERROR_MESSAGES

    err_payload = ark_resp.get('error', {})
    if isinstance(err_payload, dict):
        code = err_payload.get('code', '')
        raw_msg = err_payload.get('message', '')
    else:
        code = ''
        raw_msg = str(err_payload)

    # Prefer a curated human-readable message keyed by the error code;
    # fall back to whatever the API returned.
    record.error_message = ERROR_MESSAGES.get(code, raw_msg)
    if code:
        record.raw_error = f'{code}: {raw_msg}'
    else:
        record.raw_error = raw_msg

    # Charge for any tokens actually consumed; otherwise release the freeze.
    usage = ark_resp.get('usage', {})
    total_tokens = 0
    if isinstance(usage, dict):
        total_tokens = usage.get('total_tokens', 0)
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def process_asset_media(asset_id):
    """Extract thumbnail + duration for video/audio assets asynchronously.

    Video assets: grab a thumbnail frame, upload it to TOS, store duration,
    and opportunistically promote the thumbnail to the asset's group.
    Audio assets: store duration only. Other asset types are left untouched.
    """
    from apps.generation.models import Asset
    try:
        asset = Asset.objects.select_related('group').get(pk=asset_id)
    except Asset.DoesNotExist:
        logger.warning('process_asset_media: asset %s not found', asset_id)
        return

    from utils.media_utils import extract_video_info, get_audio_duration
    from utils.tos_client import upload_file

    if asset.asset_type == 'Video':
        # NOTE(review): assumes extract_video_info returns (thumb_file_or_None,
        # numeric_duration) with 0 signalling ffprobe failure — confirm helper contract.
        thumb_file, dur = extract_video_info(asset.url)
        if thumb_file:
            try:
                asset.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
            except Exception:
                # Upload failure is non-fatal: duration is still saved below.
                logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id)
        asset.duration = dur if dur > 0 else None  # None = ffprobe failed, frontend skips duration check
        asset.save(update_fields=['thumbnail_url', 'duration'])
        # Atomic update: only set group thumbnail if still empty (concurrent-safe)
        from apps.generation.models import AssetGroup
        from django.db import transaction
        try:
            with transaction.atomic():
                # select_for_update serializes concurrent workers racing to
                # set the group's first thumbnail.
                group = AssetGroup.objects.select_for_update().get(pk=asset.group_id)
                if not group.thumbnail_url and asset.thumbnail_url:
                    group.thumbnail_url = asset.thumbnail_url
                    group.save(update_fields=['thumbnail_url'])
        except AssetGroup.DoesNotExist:
            # Group was deleted between asset fetch and lock — nothing to update.
            logger.warning('process_asset_media: group %s deleted, skipping thumbnail update', asset.group_id)
    elif asset.asset_type == 'Audio':
        # NOTE(review): assumes get_audio_duration returns a number (0 on
        # failure) rather than None — confirm against utils.media_utils.
        dur = get_audio_duration(asset.url)
        asset.duration = dur if dur > 0 else None
        asset.save(update_fields=['duration'])

    logger.info('process_asset_media: asset %s done (type=%s, dur=%s)', asset_id, asset.asset_type, asset.duration)