diff --git a/backend/requirements.txt b/backend/requirements.txt
index a2e2f1c..28cacb4 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -10,4 +10,5 @@ ip-region>=1.0
 volcengine>=1.0.218
 Pillow>=10.0
 celery>=5.3,<6.0
+gevent>=24.2
 redis>=5.0,<6.0
diff --git a/backend/tests/mock_airdrama.py b/backend/tests/mock_airdrama.py
new file mode 100644
index 0000000..f0fbfb3
--- /dev/null
+++ b/backend/tests/mock_airdrama.py
@@ -0,0 +1,101 @@
+"""
+Stand-in for utils.airdrama_client used by the polling benchmark.
+
+query_task always answers "running", and every call is tallied in Redis so
+the bench process can watch concurrency across worker processes. The
+benchmark worker registers this module in place of the real client.
+"""
+import os
+import time
+import redis
+
+# Redis serves as the cross-process counter store.
+_redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/1')
+_r = redis.from_url(_redis_url)
+COUNTER_KEY = 'bench:poll_count'
+ACTIVE_KEY = 'bench:active'
+PEAK_KEY = 'bench:peak'
+TASKS_KEY = 'bench:tasks_seen'
+
+# Raise PEAK_KEY to ARGV[1] only when that is larger, atomically on the
+# server. A client-side GET/compare/SET lets concurrent pollers replace a
+# higher peak with a lower one.
+_raise_peak = _r.register_script("""
+local peak = tonumber(redis.call('GET', KEYS[1]) or '0')
+if tonumber(ARGV[1]) > peak then
+    redis.call('SET', KEYS[1], ARGV[1])
+end
+return 0
+""")
+
+
+def query_task(task_id):
+    """Record one poll in Redis and report the task as still running.
+
+    Args:
+        task_id: provider task id; added to the set of tasks seen.
+
+    Returns:
+        Always ``{'status': 'running'}``.
+    """
+    pipe = _r.pipeline()
+    pipe.incr(COUNTER_KEY)
+    pipe.incr(ACTIVE_KEY)
+    pipe.sadd(TASKS_KEY, task_id)
+    # INCR returns the post-increment value, so no extra GET is needed.
+    _, active, _ = pipe.execute()
+
+    try:
+        _raise_peak(keys=[PEAK_KEY], args=[active])
+        time.sleep(0.2)  # simulate 200 ms of network latency
+    finally:
+        # Release the in-flight slot even if the call is interrupted,
+        # otherwise bench:active drifts upward for the rest of the run.
+        _r.decr(ACTIVE_KEY)
+
+    return {'status': 'running'}
+
+
+def map_status(ark_status):
+    """Map a provider status string to the local status name.
+
+    Unknown statuses map to ``'processing'``.
+    """
+    mapping = {
+        'running': 'processing',
+        'submitted': 'queued',
+        'queued': 'queued',
+        'succeeded': 'completed',
+        'failed': 'failed',
+    }
+    return mapping.get(ark_status, 'processing')
+
+
+def extract_video_url(resp):
+    """Mock: never yields a video URL."""
+    return None
+
+
+class AirDramaAPIError(Exception):
+    """Mock of the client's API error; carries code, messages and status."""
+
+    def __init__(self, code, message, status_code=400):
+        self.code = code
+        self.api_message = message
+        self.user_message = message
+        # Kept so callers can read it; it was accepted but dropped before.
+        self.status_code = status_code
+        super().__init__(f'{code}: {message}')
+
+
+ERROR_MESSAGES = {}
+
+
+def create_task(**kwargs):
+    """Mock create_task: returns a fixed task id."""
+    return {'id': 'mock-task-id'}
+
+
+def download_video(url):
+    """Mock download: returns empty bytes."""
+    return b''
diff --git a/backend/tests/test_poll_concurrency.py b/backend/tests/test_poll_concurrency.py
new file mode 100644
index 0000000..64d806c
--- /dev/null
+++ b/backend/tests/test_poll_concurrency.py
@@ -0,0 +1,223 @@
+"""
+Celery poll_video_task concurrency benchmark (two-step run).
+
+Step 1: start a worker whose Volcano Engine API client is mocked.
+Step 2: dispatch tasks and monitor the counters.
+
+Connection settings come from the environment; no credentials are kept in
+this file. REDIS_URL defaults to a local instance. DB_HOST, DB_USER and
+DB_PASSWORD must be exported before running either command.
+
+Usage:
+    cd backend && source venv/bin/activate
+    export REDIS_URL=... DB_HOST=... DB_USER=... DB_PASSWORD=...
+
+    # terminal 1: start the mock worker
+    python tests/test_poll_concurrency.py worker
+
+    # terminal 2: dispatch + monitor
+    python tests/test_poll_concurrency.py bench --tasks 100 --duration 30
+"""
+import argparse
+import os
+import re
+import sys
+import time
+
+# Shared environment. Secrets are deliberately not defaulted here.
+REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/1')
+os.environ['REDIS_URL'] = REDIS_URL
+os.environ['USE_MYSQL'] = 'true'
+os.environ.setdefault('DB_NAME', 'video_auto')
+os.environ.setdefault('DB_PORT', '3306')
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
+
+BENCH_KEYS = ('bench:poll_count', 'bench:active', 'bench:peak', 'bench:tasks_seen')
+REQUIRED_DB_ENV = ('DB_HOST', 'DB_USER', 'DB_PASSWORD')
+
+
+def _redact_url(url):
+    """Return url with any password in its userinfo replaced by ``***``."""
+    return re.sub(r'(://[^:/@]*:)[^/]*@', r'\1***@', url)
+
+
+def _require_db_env():
+    """Exit with a clear message if a required DB_* variable is unset."""
+    missing = [name for name in REQUIRED_DB_ENV if not os.environ.get(name)]
+    if missing:
+        sys.exit('Missing required environment variables: ' + ', '.join(missing))
+
+
+def _positive_int(text):
+    """argparse type: an integer >= 1 (the report divides by these values)."""
+    value = int(text)
+    if value < 1:
+        raise argparse.ArgumentTypeError(f'must be a positive integer, got {text}')
+    return value
+
+
+def cmd_worker(args):
+    """Start a gevent Celery worker with airdrama_client replaced by the mock.
+
+    Args:
+        args: parsed CLI namespace; reads ``args.concurrency``.
+    """
+    # gevent has to monkey-patch before Django, Celery and redis are imported.
+    from gevent import monkey
+    monkey.patch_all()
+
+    _require_db_env()
+
+    # Register the mock under the real module name so later imports get it.
+    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+    import mock_airdrama
+    sys.modules['utils.airdrama_client'] = mock_airdrama
+
+    import django
+    django.setup()
+
+    print(f'[worker] 启动中... (mock 火山 API, concurrency={args.concurrency})')
+    print(f'[worker] Redis: {_redact_url(REDIS_URL)}')
+
+    from config.celery import app
+    app.Worker(
+        pool='gevent',
+        concurrency=args.concurrency,
+        loglevel='INFO',
+        without_heartbeat=True,
+        without_mingle=True,
+        without_gossip=True,
+    ).start()
+
+
+def cmd_bench(args):
+    """Create test records, dispatch one poll task per record, then report.
+
+    Prints the Redis counters once per second for ``args.duration`` seconds
+    and finishes with a summary.
+
+    Args:
+        args: parsed CLI namespace; reads ``args.tasks`` and ``args.duration``.
+    """
+    _require_db_env()
+
+    import django
+    django.setup()
+
+    import redis as redis_lib
+    r = redis_lib.from_url(REDIS_URL)
+
+    from apps.accounts.models import User, Team
+    from apps.generation.models import GenerationRecord
+    from apps.generation.tasks import poll_video_task
+
+    num_tasks = args.tasks
+    duration = args.duration
+
+    print(f'\n{"="*60}')
+    print(' Celery gevent 轮询并发压测')
+    print(f' 任务数: {num_tasks}')
+    print(f' 观察时长: {duration} 秒')
+    print(f' Redis: {_redact_url(REDIS_URL)}')
+    print(f'{"="*60}\n')
+
+    # Reset the counters left over from any previous run.
+    r.delete(*BENCH_KEYS)
+
+    # Prepare test data.
+    team, _ = Team.objects.get_or_create(name='压测团队', defaults={'total_seconds_pool': 999999})
+    user, _ = User.objects.get_or_create(username='bench_user', defaults={
+        'email': 'bench@test.com', 'team': team,
+    })
+    GenerationRecord.objects.filter(prompt__startswith='压测任务').delete()
+
+    records = []
+    for i in range(num_tasks):
+        record = GenerationRecord.objects.create(
+            user=user,
+            prompt=f'压测任务 {i}',
+            mode='universal',
+            model='seedance_2.0',
+            aspect_ratio='16:9',
+            duration=5,
+            status='processing',
+            ark_task_id=f'bench-{i:04d}',
+        )
+        records.append(record)
+    print(f'[准备] 已创建 {num_tasks} 个测试记录')
+
+    # Empty the queue. This drops everything waiting on the default "celery"
+    # queue, so REDIS_URL must point at a disposable broker.
+    r.delete('celery')
+    print('[准备] 已清空 Redis 队列\n')
+
+    # Dispatch.
+    print(f'[派发] 正在派发 {num_tasks} 个轮询任务...')
+    t0 = time.time()
+    for record in records:
+        poll_video_task.delay(record.id)
+    print(f'[派发] 完成,耗时 {time.time()-t0:.1f} 秒\n')
+
+    # Monitor.
+    print(f'[监控] 开始观察 {duration} 秒...\n')
+    print(f' {"时间":>6s} {"总查询":>8s} {"当前并发":>8s} {"峰值并发":>8s} {"QPS":>8s} {"任务覆盖":>10s}')
+    print(f' {"-"*6} {"-"*8} {"-"*8} {"-"*8} {"-"*8} {"-"*10}')
+
+    last_count = 0
+    for sec in range(1, duration + 1):
+        time.sleep(1)
+        ct = int(r.get('bench:poll_count') or 0)
+        ca = int(r.get('bench:active') or 0)
+        cp = int(r.get('bench:peak') or 0)
+        tp = r.scard('bench:tasks_seen')
+        qps = ct - last_count  # polls during the last one-second interval
+        last_count = ct
+        print(f' {sec:>5d}s {ct:>8d} {ca:>8d} {cp:>8d} {qps:>8d} {tp:>9d}/{num_tasks}')
+
+    # Results.
+    ft = int(r.get('bench:poll_count') or 0)
+    fp = int(r.get('bench:peak') or 0)
+    tp = r.scard('bench:tasks_seen')
+
+    print(f'\n{"="*60}')
+    print(' 测试结果')
+    print(f'{"="*60}')
+    print(f' 总查询次数: {ft}')
+    print(f' 平均 QPS: {ft / duration:.1f}')
+    print(f' 峰值并发查询: {fp}')
+    print(f' 任务覆盖率: {tp}/{num_tasks} ({tp*100//num_tasks}%)')
+    print(f'{"="*60}\n')
+
+    if tp == num_tasks:
+        print(f' PASS: 所有 {num_tasks} 个任务都被成功轮询')
+    else:
+        print(f' WARNING: 只有 {tp}/{num_tasks} 个任务被轮询到')
+
+    # Cleanup removes only the Redis counters; DB rows stay so the worker
+    # can still look them up. To purge them afterwards, with the same DB_*
+    # variables exported:
+    #   DJANGO_SETTINGS_MODULE=config.settings USE_MYSQL=true python -c "import django; django.setup(); from apps.generation.models import GenerationRecord; print(GenerationRecord.objects.filter(prompt__startswith='压测任务').delete())"
+    r.delete(*BENCH_KEYS)
+    print(' 已清理 Redis 计数器(DB 记录保留给 worker)')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Celery 轮询并发压测')
+    sub = parser.add_subparsers(dest='cmd')
+
+    p_worker = sub.add_parser('worker', help='启动 mock worker')
+    p_worker.add_argument('--concurrency', type=_positive_int, default=200)
+
+    p_bench = sub.add_parser('bench', help='派发任务 + 监控')
+    p_bench.add_argument('--tasks', type=_positive_int, default=100)
+    p_bench.add_argument('--duration', type=_positive_int, default=30)
+
+    args = parser.parse_args()
+    if args.cmd == 'worker':
+        cmd_worker(args)
+    elif args.cmd == 'bench':
+        cmd_bench(args)
+    else:
+        parser.print_help()
diff --git a/k8s/celery-deployment.yaml b/k8s/celery-deployment.yaml
index 9152c74..d3a593b 100644
--- a/k8s/celery-deployment.yaml
+++ b/k8s/celery-deployment.yaml
@@ -14,12 +14,14 @@ spec:
       labels:
         app: celery-worker
     spec:
+      imagePullSecrets:
+        - name: swr-secret
       containers:
         - name: celery-worker
           image: ${CI_REGISTRY_IMAGE}/video-backend:latest
           imagePullPolicy: Always
-          command: ["celery", "-A", "config", "worker", "--loglevel=info", "--concurrency=4", "-B"]
-          env:
+          command: ["celery", "-A", "config", "worker", "--loglevel=info", "--pool=gevent", "--concurrency=200"]
+          env: &shared-env
             - name: USE_MYSQL
               value: "true"
             - name: DJANGO_DEBUG
@@ -34,7 +36,7 @@ spec:
             # Redis
             - name: REDIS_URL
               value: "redis://zyc:Zyc188208@redis-shzlsczo52dft8mia.redis.ivolces.com:6379/0"
-            # Database (Volcano Engine RDS - 默认测试环境,生产环境通过 CI 替换)
+            # Database (Volcano Engine RDS)
             - name: DB_HOST
               value: "mysql8351f937d637.rds.ivolces.com"
             - name: DB_NAME
@@ -78,8 +80,21 @@ spec:
               value: "true"
           resources:
             requests:
+              memory: "256Mi"
+              cpu: "200m"
+            limits:
+              memory: "1Gi"
+              cpu: "1000m"
+        # NOTE(review): celery beat must run as a single instance; keep this Deployment at replicas: 1 (not visible in this hunk) or move beat to its own Deployment.
+        - name: celery-beat
+          image: ${CI_REGISTRY_IMAGE}/video-backend:latest
+          imagePullPolicy: Always
+          command: ["celery", "-A", "config", "beat", "--loglevel=info"]
+          env: *shared-env
+          resources:
+            requests:
+              memory: "64Mi"
+              cpu: "50m"
+            limits:
               memory: "128Mi"
               cpu: "100m"
-            limits:
-              memory: "512Mi"
-              cpu: "500m"