This commit is contained in:
zyc 2026-04-04 20:13:23 +08:00
parent 6353d2ec4f
commit ded5c4c44f
2 changed files with 22 additions and 42 deletions

View File

@ -6,44 +6,30 @@ from celery import shared_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 轮询间隔(秒):每次查完后重新入队,不占 worker 进程
POLL_INTERVAL = 5
@shared_task(ignore_result=True)
def poll_video_task(record_id):
"""Poll Volcano API once for a video generation task.
@shared_task(bind=True, max_retries=None, ignore_result=True) 一次性任务查一次 API更新 DB结束
def poll_video_task(self, record_id): recover_stuck_tasksbeat 每30秒调度统一驱动不再自己 retry
"""Poll Volcano API for a video generation task.
每次只执行一轮查询查完通过 self.retry 重新入队
这样 worker 不会被 sleep 占死重启也不丢任务
""" """
from django.utils import timezone from django.utils import timezone
from apps.generation.models import GenerationRecord from apps.generation.models import GenerationRecord
from utils.airdrama_client import query_task, map_status from utils.airdrama_client import query_task, map_status
# 防重复:同一 record 同一时刻只允许一个 poll 在执行
from django.core.cache import cache
lock_key = f'poll_lock:{record_id}'
if not cache.add(lock_key, '1', timeout=POLL_INTERVAL * 3):
logger.info('poll_video_task: record %s already being polled, skipping', record_id)
return
try: try:
record = GenerationRecord.objects.get(pk=record_id) record = GenerationRecord.objects.get(pk=record_id)
except GenerationRecord.DoesNotExist: except GenerationRecord.DoesNotExist:
logger.warning('poll_video_task: record %s not found', record_id) logger.warning('poll_video_task: record %s not found', record_id)
cache.delete(lock_key) return
if record.status not in ('queued', 'processing'):
return return
ark_task_id = record.ark_task_id ark_task_id = record.ark_task_id
if not ark_task_id: if not ark_task_id:
logger.warning('poll_video_task: record %s has no ark_task_id', record_id) logger.warning('poll_video_task: record %s has no ark_task_id', record_id)
cache.delete(lock_key)
return
if record.status not in ('queued', 'processing'):
logger.info('poll_video_task: record %s already in terminal state: %s', record_id, record.status)
cache.delete(lock_key)
return return
# Poll Volcano API # Poll Volcano API
@ -51,16 +37,13 @@ def poll_video_task(self, record_id):
ark_resp = query_task(ark_task_id) ark_resp = query_task(ark_task_id)
new_status = map_status(ark_resp.get('status', '')) new_status = map_status(ark_resp.get('status', ''))
except Exception: except Exception:
logger.exception('poll_video_task: API query failed for %s, will retry', ark_task_id) logger.exception('poll_video_task: API query failed for record=%s ark=%s', record_id, ark_task_id)
cache.delete(lock_key) return
raise self.retry(countdown=POLL_INTERVAL)
if new_status in ('queued', 'processing'): if new_status in ('queued', 'processing'):
# Still running — update status, then re-enqueue
record.status = new_status record.status = new_status
record.save(update_fields=['status', 'updated_at']) record.save(update_fields=['status', 'updated_at'])
cache.delete(lock_key) return
raise self.retry(countdown=POLL_INTERVAL)
# Terminal state reached — process result # Terminal state reached — process result
record.status = new_status record.status = new_status
@ -80,7 +63,6 @@ def poll_video_task(self, record_id):
'seed', 'completed_at', 'seed', 'completed_at',
]) ])
cache.delete(lock_key)
logger.info( logger.info(
'poll_video_task: record=%s ark=%s final_status=%s', 'poll_video_task: record=%s ark=%s final_status=%s',
record_id, ark_task_id, new_status, record_id, ark_task_id, new_status,
@ -131,29 +113,27 @@ def _handle_completed(record, ark_resp):
@shared_task(ignore_result=True) @shared_task(ignore_result=True)
def recover_stuck_tasks(): def recover_stuck_tasks():
"""定时扫描卡在 processing/queued 超过 3 分钟的任务,重新派发轮询。""" """每30秒扫一次所有进行中的任务统一派发轮询。
from datetime import timedelta
from django.utils import timezone poll_video_task 是一次性任务不再自己 retry由这里统一驱动
"""
from apps.generation.models import GenerationRecord from apps.generation.models import GenerationRecord
cutoff = timezone.now() - timedelta(minutes=3) active_records = GenerationRecord.objects.filter(
stuck_records = GenerationRecord.objects.filter(
status__in=('queued', 'processing'), status__in=('queued', 'processing'),
ark_task_id__isnull=False, ark_task_id__isnull=False,
updated_at__lt=cutoff, ).exclude(ark_task_id='').values_list('id', flat=True)
).exclude(ark_task_id='')
count = 0 count = 0
for record in stuck_records: for record_id in active_records:
logger.warning('recover_stuck_tasks: re-dispatching record=%s ark=%s', record.id, record.ark_task_id)
try: try:
poll_video_task.delay(record.id) poll_video_task.delay(record_id)
count += 1 count += 1
except Exception: except Exception:
logger.error('recover_stuck_tasks: failed to dispatch record=%s', record.id) logger.error('recover_stuck_tasks: failed to dispatch record=%s', record_id)
if count: if count:
logger.info('recover_stuck_tasks: re-dispatched %d stuck tasks', count) logger.info('recover_stuck_tasks: dispatched %d active tasks', count)
def _handle_failed(record, ark_resp): def _handle_failed(record, ark_resp):

View File

@ -182,7 +182,7 @@ CELERY_TIMEZONE = 'Asia/Shanghai'
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
'recover-stuck-tasks': { 'recover-stuck-tasks': {
'task': 'apps.generation.tasks.recover_stuck_tasks', 'task': 'apps.generation.tasks.recover_stuck_tasks',
'schedule': 180, # 每 3 分钟 'schedule': 30, # 每 30 秒
}, },
} }