seaislee1209 2e72c82116
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 4m39s
Merge branch 'dev' of https://gitea.airlabs.art/zyc/video-shuoshan into dev
2026-04-04 17:36:39 +08:00

201 lines
7.4 KiB
Python

"""Celery tasks for async video generation polling."""
import logging
from celery import shared_task
logger = logging.getLogger(__name__)
# 轮询间隔(秒):每次查完后重新入队,不占 worker 进程
POLL_INTERVAL = 5
@shared_task(bind=True, max_retries=None, ignore_result=True)
def poll_video_task(self, record_id):
    """Run a single polling round against the Volcano API for one record.

    Each invocation performs exactly one status query. While the remote task
    is still running, the task re-enqueues itself through ``self.retry``
    rather than sleeping, so a worker is not tied up between polls and the
    polling survives a worker restart.

    A per-record cache lock keeps two polls of the same record from running
    at the same time. The lock is released on every exit path, including
    unexpected exceptions.

    Args:
        record_id: Primary key of the ``GenerationRecord`` to poll.

    Raises:
        celery.exceptions.Retry: Via ``self.retry`` when the API query fails
            or the remote task is still queued/processing.
    """
    from django.utils import timezone
    from apps.generation.models import GenerationRecord
    from utils.airdrama_client import query_task, map_status
    # De-duplication: only one poll per record may run at any moment.
    from django.core.cache import cache

    lock_key = f'poll_lock:{record_id}'
    if not cache.add(lock_key, '1', timeout=POLL_INTERVAL * 3):
        logger.info('poll_video_task: record %s already being polled, skipping', record_id)
        return

    lock_held = True

    def release_lock():
        """Delete the lock once; later calls are no-ops."""
        nonlocal lock_held
        if lock_held:
            lock_held = False
            cache.delete(lock_key)

    try:
        try:
            record = GenerationRecord.objects.get(pk=record_id)
        except GenerationRecord.DoesNotExist:
            logger.warning('poll_video_task: record %s not found', record_id)
            return

        ark_task_id = record.ark_task_id
        if not ark_task_id:
            logger.warning('poll_video_task: record %s has no ark_task_id', record_id)
            return

        if record.status not in ('queued', 'processing'):
            logger.info('poll_video_task: record %s already in terminal state: %s', record_id, record.status)
            return

        # Poll Volcano API
        try:
            ark_resp = query_task(ark_task_id)
            new_status = map_status(ark_resp.get('status', ''))
        except Exception:
            logger.exception('poll_video_task: API query failed for %s, will retry', ark_task_id)
            # Release before re-enqueueing so the retried run can take the lock.
            release_lock()
            raise self.retry(countdown=POLL_INTERVAL)

        if new_status in ('queued', 'processing'):
            # Still running: persist the latest status, then re-enqueue.
            record.status = new_status
            record.save(update_fields=['status', 'updated_at'])
            release_lock()
            raise self.retry(countdown=POLL_INTERVAL)

        # Terminal state reached: process the result.
        # NOTE(review): the lock TTL is POLL_INTERVAL * 3 seconds, while the
        # completion handler uploads media; if that takes longer than the TTL
        # another poll could presumably enter and settle twice -- confirm.
        record.status = new_status
        returned_seed = ark_resp.get('seed')
        if returned_seed is not None:
            record.seed = returned_seed

        # NOTE(review): any terminal status other than 'completed'/'failed'
        # is saved as-is without settlement or freeze release -- confirm
        # map_status() cannot return one.
        if new_status == 'completed':
            _handle_completed(record, ark_resp)
        elif new_status == 'failed':
            _handle_failed(record, ark_resp)

        record.completed_at = timezone.now()
        record.save(update_fields=[
            'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error',
            'seed', 'completed_at',
        ])
        logger.info(
            'poll_video_task: record=%s ark=%s final_status=%s',
            record_id, ark_task_id, new_status,
        )
    finally:
        # Safety net: covers early returns, the terminal path and any
        # unexpected exception (e.g. a database error).
        release_lock()
def _handle_completed(record, ark_resp):
    """Process a completed task: persist the video to TOS and settle payment.

    Mutates ``record`` in place (``result_url`` and ``thumbnail_url``); it
    does not call ``record.save()`` itself.

    Args:
        record: The generation record being finalized.
        ark_resp: Response mapping returned by the task query; read for the
            video URL and the ``usage`` block.
    """
    from utils.airdrama_client import extract_video_url

    video_url = extract_video_url(ark_resp)
    if video_url:
        try:
            from utils.tos_client import upload_from_url
            record.result_url = upload_from_url(video_url, folder='results')
        except Exception:
            # Best effort: fall back to the source URL when persisting fails.
            logger.exception('poll_video_task: failed to persist video to TOS')
            record.result_url = video_url

    # Extract a thumbnail from the completed video. Skipped when there is no
    # result URL, since there is nothing to read frames from.
    if record.result_url:
        try:
            from utils.media_utils import extract_video_info
            from utils.tos_client import upload_file
            thumb_file, _ = extract_video_info(record.result_url)
            if thumb_file:
                record.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
        except Exception:
            # A missing thumbnail must not fail the whole completion.
            logger.exception('poll_video_task: failed to extract video thumbnail')

    # Settlement: charge by the actual token usage.
    usage = ark_resp.get('usage', {})
    # ``or 0`` also covers a key that is present with a null value, which
    # would otherwise make the ``> 0`` comparison raise TypeError.
    total_tokens = (usage.get('total_tokens') or 0) if isinstance(usage, dict) else 0
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def recover_stuck_tasks():
    """Re-dispatch polling for records stuck in queued/processing.

    Meant to run on a schedule. A record counts as stuck when its status is
    ``queued`` or ``processing``, it has a non-empty ``ark_task_id`` and its
    ``updated_at`` is more than 3 minutes old. A failed dispatch is logged
    with its traceback and does not stop the remaining records.
    """
    from datetime import timedelta
    from django.utils import timezone
    from apps.generation.models import GenerationRecord

    cutoff = timezone.now() - timedelta(minutes=3)
    stuck_records = GenerationRecord.objects.filter(
        status__in=('queued', 'processing'),
        ark_task_id__isnull=False,
        updated_at__lt=cutoff,
    ).exclude(ark_task_id='')

    count = 0
    for record in stuck_records:
        logger.warning('recover_stuck_tasks: re-dispatching record=%s ark=%s', record.id, record.ark_task_id)
        try:
            poll_video_task.delay(record.id)
            count += 1
        except Exception:
            # logger.exception keeps the traceback, so the cause of the
            # dispatch failure is visible in the logs.
            logger.exception('recover_stuck_tasks: failed to dispatch record=%s', record.id)

    if count:
        logger.info('recover_stuck_tasks: re-dispatched %d stuck tasks', count)
def _handle_failed(record, ark_resp):
    """Process a failed task: record the error and settle or release funds.

    Sets ``record.error_message`` (looked up in ``ERROR_MESSAGES`` by error
    code, falling back to the raw message) and ``record.raw_error``. It does
    not call ``record.save()`` itself.

    Args:
        record: The generation record being finalized.
        ark_resp: Response mapping returned by the task query; read for the
            ``error`` and ``usage`` blocks.
    """
    from utils.airdrama_client import ERROR_MESSAGES

    # A missing or null ``error`` is treated as "no details" instead of
    # being stringified into the literal text 'None'.
    error = ark_resp.get('error') or {}
    if isinstance(error, dict):
        code = error.get('code') or ''
        raw_msg = error.get('message') or ''
    else:
        code = ''
        raw_msg = str(error)

    record.error_message = ERROR_MESSAGES.get(code, raw_msg)
    record.raw_error = f'{code}: {raw_msg}' if code else raw_msg

    usage = ark_resp.get('usage', {})
    # ``or 0`` also covers a key that is present with a null value, which
    # would otherwise make the ``> 0`` comparison raise TypeError.
    total_tokens = (usage.get('total_tokens') or 0) if isinstance(usage, dict) else 0
    if total_tokens > 0:
        from apps.generation.views import _settle_payment
        _settle_payment(record, total_tokens)
    else:
        from apps.generation.views import _release_freeze
        _release_freeze(record)
@shared_task(ignore_result=True)
def process_asset_media(asset_id):
    """Extract thumbnail and duration for video/audio assets asynchronously.

    For ``Video`` assets the thumbnail is uploaded to the ``thumbnails``
    folder and the duration is stored; the asset's group takes over the
    thumbnail when it has none. For ``Audio`` assets only the duration is
    stored. Other asset types are left unchanged.

    Args:
        asset_id: Primary key of the ``Asset`` to process.
    """
    from apps.generation.models import Asset
    try:
        asset = Asset.objects.select_related('group').get(pk=asset_id)
    except Asset.DoesNotExist:
        logger.warning('process_asset_media: asset %s not found', asset_id)
        return

    from utils.media_utils import extract_video_info, get_audio_duration
    from utils.tos_client import upload_file

    if asset.asset_type == 'Video':
        thumb_file, dur = extract_video_info(asset.url)
        if thumb_file:
            try:
                asset.thumbnail_url = upload_file(thumb_file, folder='thumbnails')
            except Exception:
                # Best effort: the duration is still saved without a thumbnail.
                logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id)
        asset.duration = dur
        asset.save(update_fields=['thumbnail_url', 'duration'])

        # Let the group reuse this thumbnail if it has none yet.
        group = asset.group
        if not group.thumbnail_url and asset.thumbnail_url:
            group.thumbnail_url = asset.thumbnail_url
            group.save(update_fields=['thumbnail_url'])
    elif asset.asset_type == 'Audio':
        asset.duration = get_audio_duration(asset.url)
        asset.save(update_fields=['duration'])

    # '%.1f' cannot format None; logging would report a formatting error and
    # drop the message, so a missing duration is logged with '%s' instead.
    if asset.duration is None:
        logger.info('process_asset_media: asset %s done (type=%s, dur=%s)', asset_id, asset.asset_type, asset.duration)
    else:
        logger.info('process_asset_media: asset %s done (type=%s, dur=%.1f)', asset_id, asset.asset_type, asset.duration)