2026-04-04 20:13:23 +08:00

198 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Celery tasks for async video generation polling."""
import logging
from celery import shared_task
logger = logging.getLogger(__name__)
@shared_task(ignore_result=True)
def poll_video_task(record_id):
    """Poll the Volcano API once for a video generation task.

    One-shot task: query the API once, update the DB record, and finish.
    It never retries itself; ``recover_stuck_tasks`` (run from beat, every
    30 seconds according to the original note) re-dispatches it for each
    record that is still in progress.

    Args:
        record_id: Primary key of the ``GenerationRecord`` to poll.

    Returns:
        None. Outcomes are written to the record and to the log.
    """
    from django.utils import timezone
    from apps.generation.models import GenerationRecord
    from utils.airdrama_client import query_task, map_status

    try:
        record = GenerationRecord.objects.get(pk=record_id)
    except GenerationRecord.DoesNotExist:
        logger.warning('poll_video_task: record %s not found', record_id)
        return

    # Nothing to do once the record has left the in-progress states.
    if record.status not in ('queued', 'processing'):
        return

    ark_task_id = record.ark_task_id
    if not ark_task_id:
        logger.warning('poll_video_task: record %s has no ark_task_id', record_id)
        return

    # Poll Volcano API. A failure is logged and swallowed on purpose: the
    # record keeps its in-progress status, so recover_stuck_tasks will
    # dispatch another poll for it.
    try:
        ark_resp = query_task(ark_task_id)
        new_status = map_status(ark_resp.get('status', ''))
    except Exception:
        logger.exception('poll_video_task: API query failed for record=%s ark=%s', record_id, ark_task_id)
        return

    if new_status in ('queued', 'processing'):
        record.status = new_status
        record.save(update_fields=['status', 'updated_at'])
        return

    # Terminal state reached — process result.
    # NOTE(review): the record only becomes terminal in the DB at the final
    # save below, so a sweep that runs while the handlers are still working
    # can dispatch a second poll for the same record — confirm the
    # settlement helpers are idempotent.
    record.status = new_status
    returned_seed = ark_resp.get('seed')
    if returned_seed is not None:
        record.seed = returned_seed

    if new_status == 'completed':
        _handle_completed(record, ark_resp)
    elif new_status == 'failed':
        _handle_failed(record, ark_resp)

    record.completed_at = timezone.now()
    # 'updated_at' is listed here just like in the in-progress save above, so
    # the terminal transition persists it too (presumably an auto_now field —
    # TODO confirm against the model).
    record.save(update_fields=[
        'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error',
        'seed', 'completed_at', 'updated_at',
    ])
    logger.info(
        'poll_video_task: record=%s ark=%s final_status=%s',
        record_id, ark_task_id, new_status,
    )
def _handle_completed(record, ark_resp):
    """Process a completed task: persist video to TOS, extract thumbnail, settle payment.

    Mutates ``record`` in memory only (``result_url``, ``thumbnail_url`` and,
    on a storage failure, ``error_message``); this function does not call
    ``record.save()`` itself.

    Args:
        record: The generation record being finalized.
        ark_resp: Response dict returned by the task-query API.
    """
    import os
    from utils.airdrama_client import extract_video_url

    video_url = extract_video_url(ark_resp)
    if video_url:
        # Download once to temp file, reuse for TOS upload + thumbnail extraction
        tmp_path = None
        try:
            from utils.media_utils import download_to_temp, extract_video_info_from_file
            from utils.tos_client import upload_from_file_path, upload_file

            tmp_path = download_to_temp(video_url, '.mp4')
            # Upload video to TOS from file (streaming, no full memory load)
            record.result_url = upload_from_file_path(tmp_path, folder='results', content_type='video/mp4')
            # Extract thumbnail from the same local file (no second download)
            thumb_file, _ = extract_video_info_from_file(tmp_path)
            if thumb_file:
                record.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
        except Exception:
            # Best effort: a storage/thumbnail failure must not block settlement.
            logger.exception('poll_video_task: failed to persist video / extract thumbnail')
            if not record.result_url:
                # The video itself was not persisted: fall back to the
                # provider URL and surface a message on the record.
                record.result_url = video_url
                record.error_message = '视频保存失败临时链接将在24小时后过期请联系管理员'
        finally:
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)

    # Settlement: charge by the tokens actually used.
    usage = ark_resp.get('usage', {})
    reported_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0
    # A present-but-null 'total_tokens' would make `None > 0` raise TypeError
    # and abort settlement; treat it as "no usage reported".
    total_tokens = reported_tokens or 0
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def recover_stuck_tasks():
    """Dispatch one ``poll_video_task`` for every in-progress record.

    Selects records whose status is ``queued`` or ``processing`` and that
    have a non-empty ``ark_task_id``, then enqueues a poll for each.
    ``poll_video_task`` is one-shot and does not retry itself, so this sweep
    is what keeps polling going (scheduled every 30 seconds according to the
    original note).
    """
    from apps.generation.models import GenerationRecord

    active_records = GenerationRecord.objects.filter(
        status__in=('queued', 'processing'),
        ark_task_id__isnull=False,
    ).exclude(ark_task_id='').values_list('id', flat=True)

    count = 0
    for record_id in active_records:
        try:
            poll_video_task.delay(record_id)
        except Exception:
            # Keep sweeping the remaining records; logger.exception records
            # the traceback so the dispatch failure can be diagnosed.
            logger.exception('recover_stuck_tasks: failed to dispatch record=%s', record_id)
        else:
            count += 1

    if count:
        logger.info('recover_stuck_tasks: dispatched %d active tasks', count)
def _handle_failed(record, ark_resp):
    """Process a failed task: record error and release frozen amount.

    Sets ``record.error_message`` (mapped through ``ERROR_MESSAGES`` by error
    code, falling back to the raw message) and ``record.raw_error`` in memory;
    this function does not call ``record.save()`` itself. If the response
    reports a positive ``usage.total_tokens`` the payment is settled for that
    amount, otherwise the frozen amount is released.

    Args:
        record: The generation record being finalized.
        ark_resp: Response dict returned by the task-query API.
    """
    from utils.airdrama_client import ERROR_MESSAGES

    # `or {}` also covers an explicit null 'error', which would otherwise be
    # stringified into the literal text 'None' and stored as the message.
    error = ark_resp.get('error') or {}
    if isinstance(error, dict):
        code = error.get('code', '')
        # A null message is normalised to '' for the same reason.
        raw_msg = error.get('message') or ''
    else:
        code = ''
        raw_msg = str(error)

    record.error_message = ERROR_MESSAGES.get(code, raw_msg)
    record.raw_error = f'{code}: {raw_msg}' if code else raw_msg

    usage = ark_resp.get('usage', {})
    reported_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0
    # A present-but-null 'total_tokens' would make `None > 0` raise TypeError;
    # treat it as "no usage reported".
    total_tokens = reported_tokens or 0
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def process_asset_media(asset_id):
    """Fill in thumbnail and duration for a video/audio asset in the background.

    Video assets get a thumbnail (uploaded to the ``thumbnails`` folder) and a
    duration, and their group's thumbnail is filled in when it is still empty.
    Audio assets only get a duration. Other asset types are left untouched.

    Args:
        asset_id: Primary key of the ``Asset`` to process.
    """
    from apps.generation.models import Asset

    try:
        asset = Asset.objects.select_related('group').get(pk=asset_id)
    except Asset.DoesNotExist:
        logger.warning('process_asset_media: asset %s not found', asset_id)
        return

    from utils.media_utils import extract_video_info, get_audio_duration
    from utils.tos_client import upload_file

    media_kind = asset.asset_type
    if media_kind == 'Video':
        thumbnail, seconds = extract_video_info(asset.url)
        if thumbnail:
            try:
                asset.thumbnail_url = upload_file(thumbnail, folder='thumbnails')
            except Exception:
                # A failed thumbnail upload does not stop the duration update.
                logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id)
        # Original note: None = ffprobe failed, frontend skips duration check.
        if seconds > 0:
            asset.duration = seconds
        else:
            asset.duration = None
        asset.save(update_fields=['thumbnail_url', 'duration'])
        _sync_group_thumbnail(asset)
    elif media_kind == 'Audio':
        seconds = get_audio_duration(asset.url)
        if seconds > 0:
            asset.duration = seconds
        else:
            asset.duration = None
        asset.save(update_fields=['duration'])

    logger.info('process_asset_media: asset %s done (type=%s, dur=%s)', asset_id, asset.asset_type, asset.duration)


def _sync_group_thumbnail(asset):
    """Copy the asset's thumbnail onto its group if the group has none yet."""
    from apps.generation.models import AssetGroup
    from django.db import transaction

    try:
        # Row is fetched with select_for_update inside atomic(); the original
        # comment describes this as the concurrent-safe way to set it once.
        with transaction.atomic():
            locked_group = AssetGroup.objects.select_for_update().get(pk=asset.group_id)
            if not locked_group.thumbnail_url and asset.thumbnail_url:
                locked_group.thumbnail_url = asset.thumbnail_url
                locked_group.save(update_fields=['thumbnail_url'])
    except AssetGroup.DoesNotExist:
        logger.warning('process_asset_media: group %s deleted, skipping thumbnail update', asset.group_id)