"""Celery tasks for async video generation polling.""" import logging from celery import shared_task logger = logging.getLogger(__name__) @shared_task(ignore_result=True) def poll_video_task(record_id): """Poll Volcano API once for a video generation task. 一次性任务:查一次 API,更新 DB,结束。 由 recover_stuck_tasks(beat 每30秒调度)统一驱动,不再自己 retry。 """ from django.utils import timezone from apps.generation.models import GenerationRecord from utils.airdrama_client import query_task, map_status try: record = GenerationRecord.objects.get(pk=record_id) except GenerationRecord.DoesNotExist: logger.warning('poll_video_task: record %s not found', record_id) return if record.status not in ('queued', 'processing'): return ark_task_id = record.ark_task_id if not ark_task_id: logger.warning('poll_video_task: record %s has no ark_task_id', record_id) return # Poll Volcano API try: ark_resp = query_task(ark_task_id) new_status = map_status(ark_resp.get('status', '')) except Exception: logger.exception('poll_video_task: API query failed for record=%s ark=%s', record_id, ark_task_id) return if new_status in ('queued', 'processing'): record.status = new_status record.save(update_fields=['status', 'updated_at']) return # Terminal state reached — process result record.status = new_status returned_seed = ark_resp.get('seed') if returned_seed is not None: record.seed = returned_seed if new_status == 'completed': _handle_completed(record, ark_resp) elif new_status == 'failed': _handle_failed(record, ark_resp) record.completed_at = timezone.now() record.save(update_fields=[ 'status', 'result_url', 'thumbnail_url', 'error_message', 'raw_error', 'seed', 'completed_at', ]) logger.info( 'poll_video_task: record=%s ark=%s final_status=%s', record_id, ark_task_id, new_status, ) def _handle_completed(record, ark_resp): """Process a completed task: persist video to TOS, extract thumbnail, settle payment.""" import os from utils.airdrama_client import extract_video_url video_url = extract_video_url(ark_resp) if video_url: # Download once to temp file, reuse for TOS upload + thumbnail extraction tmp_path = None try: from utils.media_utils import download_to_temp, extract_video_info_from_file from utils.tos_client import upload_from_file_path, upload_file tmp_path = download_to_temp(video_url, '.mp4') # Upload video to TOS from file (streaming, no full memory load) record.result_url = upload_from_file_path(tmp_path, folder='results', content_type='video/mp4') # Extract thumbnail from the same local file (no second download) thumb_file, _ = extract_video_info_from_file(tmp_path) if thumb_file: record.thumbnail_url = upload_file(thumb_file, folder='thumbnails') except Exception: logger.exception('poll_video_task: failed to persist video / extract thumbnail') if not record.result_url: record.result_url = video_url record.error_message = '视频保存失败,临时链接将在24小时后过期,请联系管理员' finally: if tmp_path and os.path.exists(tmp_path): os.unlink(tmp_path) # 结算:按实际 tokens 扣费 usage = ark_resp.get('usage', {}) total_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0 if total_tokens > 0: from apps.generation.views import _settle_payment _settle_payment(record, total_tokens) else: from apps.generation.views import _release_freeze _release_freeze(record) @shared_task(ignore_result=True) def recover_stuck_tasks(): """每30秒扫一次所有进行中的任务,统一派发轮询。 poll_video_task 是一次性任务,不再自己 retry,由这里统一驱动。 """ from apps.generation.models import GenerationRecord active_records = GenerationRecord.objects.filter( status__in=('queued', 'processing'), ark_task_id__isnull=False, ).exclude(ark_task_id='').values_list('id', flat=True) count = 0 for record_id in active_records: try: poll_video_task.delay(record_id) count += 1 except Exception: logger.error('recover_stuck_tasks: failed to dispatch record=%s', record_id) if count: logger.info('recover_stuck_tasks: dispatched %d active tasks', count) def _handle_failed(record, ark_resp): """Process a failed task: record error and release frozen amount.""" from utils.airdrama_client import ERROR_MESSAGES error = ark_resp.get('error', {}) code = error.get('code', '') if isinstance(error, dict) else '' raw_msg = error.get('message', '') if isinstance(error, dict) else str(error) record.error_message = ERROR_MESSAGES.get(code, raw_msg) record.raw_error = f'{code}: {raw_msg}' if code else raw_msg usage = ark_resp.get('usage', {}) total_tokens = usage.get('total_tokens', 0) if isinstance(usage, dict) else 0 if total_tokens > 0: from apps.generation.views import _settle_payment _settle_payment(record, total_tokens) else: from apps.generation.views import _release_freeze _release_freeze(record) @shared_task(ignore_result=True) def process_asset_media(asset_id): """Extract thumbnail + duration for video/audio assets asynchronously.""" from apps.generation.models import Asset try: asset = Asset.objects.select_related('group').get(pk=asset_id) except Asset.DoesNotExist: logger.warning('process_asset_media: asset %s not found', asset_id) return from utils.media_utils import extract_video_info, get_audio_duration from utils.tos_client import upload_file if asset.asset_type == 'Video': thumb_file, dur = extract_video_info(asset.url) if thumb_file: try: asset.thumbnail_url = upload_file(thumb_file, folder='thumbnails') except Exception: logger.exception('process_asset_media: thumbnail upload failed for asset %s', asset_id) asset.duration = dur if dur > 0 else None # None = ffprobe failed, frontend skips duration check asset.save(update_fields=['thumbnail_url', 'duration']) # Atomic update: only set group thumbnail if still empty (concurrent-safe) from apps.generation.models import AssetGroup from django.db import transaction try: with transaction.atomic(): group = AssetGroup.objects.select_for_update().get(pk=asset.group_id) if not group.thumbnail_url and asset.thumbnail_url: group.thumbnail_url = asset.thumbnail_url group.save(update_fields=['thumbnail_url']) except AssetGroup.DoesNotExist: logger.warning('process_asset_media: group %s deleted, skipping thumbnail update', asset.group_id) elif asset.asset_type == 'Audio': dur = get_audio_duration(asset.url) asset.duration = dur if dur > 0 else None asset.save(update_fields=['duration']) logger.info('process_asset_media: asset %s done (type=%s, dur=%s)', asset_id, asset.asset_type, asset.duration)