import logging

from rest_framework import status
from rest_framework.decorators import api_view, permission_classes, throttle_classes
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import ScopedRateThrottle, SimpleRateThrottle
from django.contrib.auth import authenticate, get_user_model
from django.utils import timezone
from django.db.models import Sum, Count
from .serializers import UserSerializer
from .models import ActiveSession, LoginRecord, get_client_ip, parse_device_type
from .tokens import SessionRefreshToken
from django.contrib.auth.hashers import check_password

logger = logging.getLogger(__name__)

User = get_user_model()


class LoginRateThrottle(SimpleRateThrottle):
    """Throttle login attempts with the ``'login'`` rate from ``DEFAULT_THROTTLE_RATES``.

    This used to subclass ``ScopedRateThrottle``. That class ignores its own ``scope``
    attribute: on every request it re-reads the scope from the view's ``throttle_scope``.
    ``@api_view`` function views do not define one, so every request was allowed and
    login was never throttled. A fixed scope on ``SimpleRateThrottle`` avoids that.
    """

    scope = 'login'

    def get_rate(self):
        # If no 'login' rate is configured, stay unthrottled (the previous effective
        # behaviour) instead of raising ImproperlyConfigured on every login request.
        return self.THROTTLE_RATES.get(self.scope)

    def get_cache_key(self, request, view):
        # Same identity rule ScopedRateThrottle uses: user pk when authenticated,
        # otherwise the client identity derived from the request (IP / proxy headers).
        if request.user and request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}


def _get_str_field(data, key):
    """Return ``data[key]`` if it is a string, else ``''``.

    JSON bodies may carry numbers, nulls, lists, etc.; treating those as "missing"
    avoids AttributeError/TypeError (HTTP 500) in the string handling below.
    """
    value = data.get(key, '')
    return value if isinstance(value, str) else ''


def _local_today():
    """Return today's date in the current Django time zone.

    ``created_at__date`` lookups convert stored datetimes to the current time zone when
    ``USE_TZ=True``, so "today" must be computed in that same zone rather than in UTC.
    Falls back to the naive date when ``USE_TZ=False`` (``localdate`` rejects naive values).
    """
    now = timezone.now()
    if timezone.is_aware(now):
        return timezone.localdate(now)
    return now.date()


@api_view(['POST'])
@permission_classes([AllowAny])
def register_view(request):
    """POST /api/v1/auth/register — disabled, all accounts created by admins."""
    return Response(
        {'error': 'registration_disabled', 'message': '公开注册已关闭,请联系管理员'},
        status=status.HTTP_403_FORBIDDEN,
    )


def _enforce_session_limit(user, device_type):
    """Make room for one new session of ``device_type`` by deleting the user's oldest ones.

    Limits come from the ``QuotaConfig`` row with pk=1:
    - desktop uses ``max_desktop_sessions`` (1 if no config row exists);
    - mobile uses ``max_mobile_sessions`` (0 if no config row exists);
    - any other device type is limited to 1.

    A limit <= 0 means "not enforced": nothing is deleted and login proceeds.
    """
    # Local import, presumably to avoid a circular import between apps.
    from apps.generation.models import QuotaConfig

    config = QuotaConfig.objects.filter(pk=1).first()
    if device_type == 'desktop':
        max_sessions = config.max_desktop_sessions if config else 1
    elif device_type == 'mobile':
        max_sessions = config.max_mobile_sessions if config else 0
    else:
        # NOTE(review): device types other than desktop/mobile (e.g. 'unknown') are
        # capped at one session; an old comment suggested they were meant to be
        # unlimited — confirm the intended policy.
        max_sessions = 1

    if max_sessions <= 0:
        # 0 is treated as "no enforcement" (e.g. mobile enforcement not enabled yet),
        # not as "no sessions allowed".
        return

    existing = ActiveSession.objects.filter(
        user=user, device_type=device_type
    ).order_by('created_at')
    # Keep at most max_sessions - 1 so the session about to be created fits the limit.
    over_count = existing.count() - max_sessions + 1
    if over_count > 0:
        ids_to_remove = list(existing.values_list('id', flat=True)[:over_count])
        ActiveSession.objects.filter(id__in=ids_to_remove).delete()


def _authenticate_username_or_email(request, identifier, password):
    """Authenticate ``identifier`` as a username first, then as an email address.

    Returns the authenticated user or ``None``. An empty identifier is never looked
    up by email, and an email shared by several accounts is treated as no match
    instead of raising ``MultipleObjectsReturned``.
    """
    user = authenticate(request, username=identifier, password=password)
    if user is not None or not identifier:
        return user
    try:
        user_by_email = User.objects.get(email=identifier)
    except (User.DoesNotExist, User.MultipleObjectsReturned):
        return None
    # get_username() yields the USERNAME_FIELD value the auth backend looks up.
    return authenticate(request, username=user_by_email.get_username(), password=password)


def _account_disabled_response(user):
    """Return an error Response if ``user`` or their team is disabled, else ``None``."""
    if not user.is_active:
        return Response(
            {'code': 'user_disabled', 'message': '您的账号已被禁用,请联系团队管理员'},
            status=status.HTTP_401_UNAUTHORIZED,
        )
    if user.team and not user.team.is_active:
        return Response(
            {'code': 'team_disabled', 'message': '您所在的团队已被禁用,请联系平台管理员'},
            status=status.HTTP_403_FORBIDDEN,
        )
    return None


def _run_login_anomaly_checks(user, login_record, ip):
    """Fill in IP geolocation on ``login_record`` and run login anomaly detection.

    Best-effort: any exception is logged and swallowed so it never blocks login.
    Because anomaly processing may disable the user or team, both are reloaded
    afterwards. Returns the disabled-account Response in that case, else ``None``.
    """
    try:
        from utils.geo_client import resolve_ip_location
        country, province, city, source = resolve_ip_location(ip)
        login_record.geo_country = country
        login_record.geo_province = province
        login_record.geo_city = city
        login_record.geo_source = source
        login_record.save(update_fields=['geo_country', 'geo_province', 'geo_city', 'geo_source'])

        from utils.anomaly_detector import check_login_anomaly, process_anomalies
        anomalies = check_login_anomaly(login_record)
        if anomalies:
            process_anomalies(login_record, anomalies)

        # Re-check after processing: the anomaly detector may have just disabled the user/team.
        user.refresh_from_db()
        if user.team:
            user.team.refresh_from_db()
        return _account_disabled_response(user)
    except Exception:
        logger.exception('Anomaly detection failed for login %s', login_record.pk)
        return None


@api_view(['POST'])
@permission_classes([AllowAny])
@throttle_classes([LoginRateThrottle])
def login_view(request):
    """POST /api/v1/auth/login — log in with username or email and issue session tokens.

    Steps:
    1. Authenticate by username, falling back to email.
    2. Reject disabled users or teams.
    3. Record the login and run geo/anomaly checks.
    4. Enforce the per-device-type session limit.
    5. Create an ``ActiveSession`` and return the serialized user plus access/refresh tokens.
    """
    username = _get_str_field(request.data, 'username').strip()
    password = request.data.get('password', '')

    user = _authenticate_username_or_email(request, username, password)
    if user is None:
        return Response(
            {'error': 'invalid_credentials', 'message': '用户名或密码错误'},
            status=status.HTTP_401_UNAUTHORIZED
        )

    disabled = _account_disabled_response(user)
    if disabled is not None:
        return disabled

    # Record login IP and User-Agent; geo fields are filled in by the anomaly step.
    ip = get_client_ip(request)
    user_agent = request.META.get('HTTP_USER_AGENT', '')
    login_record = LoginRecord.objects.create(
        user=user,
        team=user.team,
        ip_address=ip,
        user_agent=user_agent,
        geo_country='',
        geo_province='',
        geo_city='',
        geo_source='',
    )

    disabled = _run_login_anomaly_checks(user, login_record, ip)
    if disabled is not None:
        return disabled

    # Concurrent session management
    device_type = parse_device_type(user_agent)
    _enforce_session_limit(user, device_type)
    session = ActiveSession.objects.create(user=user, device_type=device_type, user_agent=user_agent)

    refresh = SessionRefreshToken.for_user_session(user, session.session_id)
    return Response({
        'user': UserSerializer(user).data,
        'tokens': {
            'access': str(refresh.access_token),
            'refresh': str(refresh),
        }
    })


@api_view(['POST'])
@permission_classes([IsAuthenticated])
def logout_view(request):
    """POST /api/v1/auth/logout — delete the current session, marking the user offline."""
    session_id = getattr(request, 'session_id', None)
    if session_id:
        ActiveSession.objects.filter(user=request.user, session_id=session_id).delete()
    else:
        # Fallback: no session id on the request, so clear all of this user's sessions.
        ActiveSession.objects.filter(user=request.user).delete()
    return Response({'detail': 'ok'})


def _team_payload(team, first_of_month):
    """Build the ``team`` section of /auth/me: pool, balances, monthly usage and token prices.

    Token prices are the ``QuotaConfig`` (pk=1) base prices scaled by the team's
    ``markup_percentage``.
    """
    from apps.generation.models import GenerationRecord, QuotaConfig

    # One query for both monthly team totals (seconds and cost).
    month_totals = GenerationRecord.objects.filter(
        user__team=team,
        created_at__date__gte=first_of_month,
    ).aggregate(seconds=Sum('seconds_consumed'), cost=Sum('cost_amount'))
    team_monthly_used = month_totals['seconds'] or 0
    team_monthly_spent = month_totals['cost'] or 0

    config = QuotaConfig.objects.get_or_create(pk=1)[0]
    markup_mult = 1 + float(team.markup_percentage) / 100
    return {
        'id': team.id,
        'name': team.name,
        'total_seconds_pool': team.total_seconds_pool,
        'total_seconds_used': team.total_seconds_used,
        'remaining_seconds': team.remaining_seconds,
        'monthly_seconds_limit': team.monthly_seconds_limit,
        'monthly_seconds_used': team_monthly_used,
        'balance': float(team.balance),
        'total_spent': float(team.total_spent),
        'available_balance': float(team.available_balance),
        'monthly_spending_limit': float(team.monthly_spending_limit),
        'monthly_spent': float(team_monthly_spent),
        'frozen_amount': float(team.frozen_amount),
        'token_price': float(config.base_token_price) * markup_mult,
        'token_price_video': float(config.base_token_price_video) * markup_mult,
        'token_price_fast': float(config.base_token_price_fast) * markup_mult,
        'token_price_fast_video': float(config.base_token_price_fast_video) * markup_mult,
        'is_active': team.is_active,
    }


@api_view(['GET'])
@permission_classes([IsAuthenticated])
def me_view(request):
    """GET /api/v1/auth/me — returns role, team info, and quota."""
    user = request.user
    today = _local_today()
    first_of_month = today.replace(day=1)

    # Seconds consumed and generation count for each period in a single query each.
    daily = user.generation_records.filter(
        created_at__date=today
    ).aggregate(seconds=Sum('seconds_consumed'), count=Count('pk'))
    monthly = user.generation_records.filter(
        created_at__date__gte=first_of_month
    ).aggregate(seconds=Sum('seconds_consumed'), count=Count('pk'))

    data = UserSerializer(user).data
    data['quota'] = {
        'daily_seconds_limit': user.daily_seconds_limit,
        'daily_seconds_used': daily['seconds'] or 0,
        'monthly_seconds_limit': user.monthly_seconds_limit,
        'monthly_seconds_used': monthly['seconds'] or 0,
        'daily_generation_limit': user.daily_generation_limit,
        'daily_generation_used': daily['count'],
        'monthly_generation_limit': user.monthly_generation_limit,
        'monthly_generation_used': monthly['count'],
    }

    team = user.team
    if team:
        data['team'] = _team_payload(team, first_of_month)
        data['team_disabled'] = not team.is_active
    else:
        data['team'] = None
        data['team_disabled'] = False

    return Response(data)


@api_view(['POST'])
@permission_classes([IsAuthenticated])
def change_password_view(request):
    """POST /api/v1/auth/change-password — user changes own password.

    Requires ``old_password`` and ``new_password`` (at least 8 characters) as strings.
    On success, clears ``must_change_password`` and revokes the user's other active
    sessions, keeping the one making this request when its id is known.
    """
    user = request.user
    old_password = _get_str_field(request.data, 'old_password')
    new_password = _get_str_field(request.data, 'new_password')

    if not old_password or not new_password:
        return Response(
            {'error': 'missing_fields', 'message': '请填写旧密码和新密码'},
            status=status.HTTP_400_BAD_REQUEST,
        )
    if len(new_password) < 8:
        return Response(
            {'error': 'password_too_short', 'message': '新密码至少8位'},
            status=status.HTTP_400_BAD_REQUEST,
        )
    # The model method also transparently upgrades the stored hash if needed.
    if not user.check_password(old_password):
        return Response(
            {'error': 'wrong_password', 'message': '旧密码错误'},
            status=status.HTTP_400_BAD_REQUEST,
        )

    user.set_password(new_password)
    user.must_change_password = False
    user.save(update_fields=['password', 'must_change_password'])

    # A password change should end sessions opened elsewhere with the old password.
    session_id = getattr(request, 'session_id', None)
    if session_id:
        ActiveSession.objects.filter(user=user).exclude(session_id=session_id).delete()

    return Response({
        'message': '密码修改成功',
        'user': UserSerializer(user).data,
    })