"""Celery tasks for async video generation polling.""" import logging from celery import shared_task logger = logging.getLogger(__name__) # 轮询间隔(秒):每次查完后重新入队,不占 worker 进程 POLL_INTERVAL = 5 @shared_task(bind=True, max_retries=None, ignore_result=True) def poll_video_task(self, record_id): """Poll Volcano API for a video generation task. 每次只执行一轮查询,查完通过 self.retry 重新入队。 这样 worker 不会被 sleep 占死,重启也不丢任务。 """ from django.utils import timezone from apps.generation.models import GenerationRecord from utils.airdrama_client import query_task, map_status # 防重复:同一 record 同一时刻只允许一个 poll 在执行 from django.core.cache import cache lock_key = f'poll_lock:{record_id}' if not cache.add(lock_key, '1', timeout=POLL_INTERVAL * 3): logger.info('poll_video_task: record %s already being polled, skipping', record_id) return try: record = GenerationRecord.objects.get(pk=record_id) except GenerationRecord.DoesNotExist: logger.warning('poll_video_task: record %s not found', record_id) cache.delete(lock_key) return ark_task_id = record.ark_task_id if not ark_task_id: logger.warning('poll_video_task: record %s has no ark_task_id', record_id) cache.delete(lock_key) return if record.status not in ('queued', 'processing'): logger.info('poll_video_task: record %s already in terminal state: %s', record_id, record.status) cache.delete(lock_key) return # Poll Volcano API try: ark_resp = query_task(ark_task_id) new_status = map_status(ark_resp.get('status', '')) except Exception: logger.exception('poll_video_task: API query failed for %s, will retry', ark_task_id) cache.delete(lock_key) raise self.retry(countdown=POLL_INTERVAL) if new_status in ('queued', 'processing'): # Still running — update status, then re-enqueue record.status = new_status record.save(update_fields=['status', 'updated_at']) cache.delete(lock_key) raise self.retry(countdown=POLL_INTERVAL) # Terminal state reached — process result record.status = new_status returned_seed = ark_resp.get('seed') if returned_seed is not None: record.seed = returned_seed if new_status == 'completed': _handle_completed(record, ark_resp) elif new_status == 'failed': _handle_failed(record, ark_resp) record.completed_at = timezone.now() record.save(update_fields=[ 'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error', 'seed', 'completed_at', ]) logger.info( 'poll_video_task: record=%s ark=%s final_status=%s', record_id, ark_task_id, new_status, ) def _handle_completed(record, ark_resp): """Process a completed task: persist video to TOS and settle payment.""" from utils.airdrama_client import extract_video_url video_url = extract_video_url(ark_resp) if video_url: try: from utils.tos_client import upload_from_url record.result_url = upload_from_url(video_url, folder='results') except Exception: logger.exception('poll_video_task: failed to persist video to TOS') record.result_url = video_url # Extract thumbnail from completed video try: from utils.media_utils import extract_video_info from utils.tos_client import upload_file thumb_file, _ = extract_video_info(record.result_url) if thumb_file: record.thumbnail_url = upload_file(thumb_file, folder='thumbnails') except Exception: logger.exception('poll_video_task: failed to extract video thumbnail') # 结算:按实际 tokens 扣费 usage = ark_resp.get('usage', {}) total_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0 if total_tokens > 0: from apps.generation.views import _settle_payment _settle_payment(record, total_tokens) else: from apps.generation.views import _release_freeze _release_freeze(record) @shared_task(ignore_result=True) def recover_stuck_tasks(): """定时扫描卡在 processing/queued 超过 3 分钟的任务,重新派发轮询。""" from datetime import timedelta from django.utils import timezone from apps.generation.models import GenerationRecord cutoff = timezone.now() - timedelta(minutes=3) stuck_records = GenerationRecord.objects.filter( status__in=('queued', 'processing'), ark_task_id__isnull=False, updated_at__lt=cutoff, ).exclude(ark_task_id='') count = 0 for record in stuck_records: logger.warning('recover_stuck_tasks: re-dispatching record=%s ark=%s', record.id, record.ark_task_id) try: poll_video_task.delay(record.id) count += 1 except Exception: logger.error('recover_stuck_tasks: failed to dispatch record=%s', record.id) if count: logger.info('recover_stuck_tasks: re-dispatched %d stuck tasks', count) def _handle_failed(record, ark_resp): """Process a failed task: record error and release frozen amount.""" from utils.airdrama_client import ERROR_MESSAGES error = ark_resp.get('error', {}) code = error.get('code', '') if isinstance(error, dict) else '' raw_msg = error.get('message', '') if isinstance(error, dict) else str(error) record.error_message = ERROR_MESSAGES.get(code, raw_msg) record.raw_error = f'{code}: {raw_msg}' if code else raw_msg usage = ark_resp.get('usage', {}) total_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0 if total_tokens > 0: from apps.generation.views import _settle_payment _settle_payment(record, total_tokens) else: from apps.generation.views import _release_freeze _release_freeze(record) @shared_task(ignore_result=True) def process_asset_media(asset_id): """Extract thumbnail + duration for video/audio assets asynchronously.""" from apps.generation.models import Asset try: asset = Asset.objects.select_related('group').get(pk=asset_id) except Asset.DoesNotExist: logger.warning('process_asset_media: asset %s not found', asset_id) return from utils.media_utils import extract_video_info, get_audio_duration from utils.tos_client import upload_file if asset.asset_type == 'Video': thumb_file, dur = extract_video_info(asset.url) if thumb_file: try: asset.thumbnail_url = upload_file(thumb_file, folder='thumbnails') except Exception: logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id) asset.duration = dur asset.save(update_fields=['thumbnail_url', 'duration']) group = asset.group if not group.thumbnail_url and asset.thumbnail_url: group.thumbnail_url = asset.thumbnail_url group.save(update_fields=['thumbnail_url']) elif asset.asset_type == 'Audio': asset.duration = get_audio_duration(asset.url) asset.save(update_fields=['duration']) logger.info('process_asset_media: asset %s done (type=%s, dur=%.1f)', asset_id, asset.asset_type, asset.duration)