2025-04-08 08:07:41 +08:00

493 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, Depends, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
import admin
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, ForeignKey, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import random
import datetime
import os
import uuid
from pathlib import Path
# Create the FastAPI app
app = FastAPI(
    title="Lottery System API",
    description="Backend API for lottery system with SQLite database",
    version="1.0.0",
)

# Include the admin router (its endpoints live in the local ``admin`` module).
app.include_router(admin.router)

# CORS: allowed browser origins come from the comma-separated
# LOTTERY_CORS_ORIGINS environment variable, e.g.
#   LOTTERY_CORS_ORIGINS="https://lottery.example.com,https://admin.example.com"
# When it is unset or blank this falls back to "*" (any origin), the original
# development default. Set it to your frontend domain(s) in production.
# NOTE(review): with "*" plus allow_credentials=True, Starlette's CORSMiddleware
# answers credentialed requests with the caller's own Origin, so any site can
# make cookie-bearing requests (verify for your Starlette version) — another
# reason to restrict the list in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("LOTTERY_CORS_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Database setup.
# The connection URL can be overridden with the LOTTERY_DATABASE_URL environment
# variable. By default it is a SQLite file resolved relative to the process's
# current working directory (not this file's directory).
DATABASE_URL = os.environ.get("LOTTERY_DATABASE_URL", "sqlite:///./lottery.db")
# ``check_same_thread`` is a sqlite3-only connect argument: it allows a
# connection to be used from a thread other than the one that created it
# (FastAPI may run sync endpoints on worker threads). Other DBAPI drivers would
# likely reject it, so it is only passed for SQLite URLs.
_connect_args = {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
engine = create_engine(DATABASE_URL, connect_args=_connect_args)
# Sessions neither autocommit nor autoflush before queries; endpoints commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base for the ORM models.
# NOTE(review): declarative_base is imported from sqlalchemy.ext.declarative,
# which SQLAlchemy 2.0 deprecates in favour of sqlalchemy.orm.declarative_base.
Base = declarative_base()
# Database dependency
def get_db():
    """FastAPI dependency that provides one ``Session`` per request.

    The session is yielded to the endpoint and closed after the endpoint
    finishes, including when it raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Models
class User(Base):
    """A lottery participant.

    All contact fields are optional; ``address`` is the shipping address
    written by the address-submission endpoints.
    """

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50), nullable=True)
    phone = Column(String(20), nullable=True)
    address = Column(Text, nullable=True)
    # Filled in by the database's now() as part of the INSERT.
    created_at = Column(DateTime, default=func.now())
class Prize(Base):
    """A prize tier.

    ``probability`` is the tier's weight in a lottery draw and
    ``available_quantity`` its remaining stock; only active rows with stock
    left are offered by the draw endpoints.
    """

    __tablename__ = "prizes"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100))
    description = Column(Text, nullable=True)
    probability = Column(Float, default=0)
    available_quantity = Column(Integer, default=0)
    is_active = Column(Boolean, default=True)
class Card(Base):
    """A collectible card.

    ``card_id`` is the public, unique code (e.g. ``CARD-1A2B3C4D``); the
    ``collected_*`` columns stay empty until a user collects the card.
    """

    __tablename__ = "cards"

    id = Column(Integer, primary_key=True, index=True)
    card_id = Column(String(50), unique=True)
    card_type = Column(String(20))
    is_collected = Column(Boolean, default=False)
    collected_by = Column(Integer, ForeignKey("users.id"), nullable=True)
    collected_at = Column(DateTime, nullable=True)
class UserPrize(Base):
    """One prize outcome awarded to a user, with its shipping status.

    Rows are written both for lottery draws (including no-win results) and
    for card-collection prize claims.
    """

    __tablename__ = "user_prizes"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    prize_id = Column(Integer, ForeignKey("prizes.id"))
    # Filled in by the database's now() as part of the INSERT.
    awarded_at = Column(DateTime, default=func.now())
    is_shipped = Column(Boolean, default=False)
    shipped_at = Column(DateTime, nullable=True)
# Create tables.
# Runs at import time. create_all only creates tables that do not exist yet and
# never alters an existing table, so later column changes to the models are not
# applied to an already-created database.
Base.metadata.create_all(engine, checkfirst=True)
# Pydantic Models for API
class UserCreate(BaseModel):
    """Request body for creating a user; every field may be omitted."""

    name: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
class UserUpdate(BaseModel):
    """Request body for a partial user update; omitted fields are left unchanged."""

    name: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
class PrizeCreate(BaseModel):
    """Payload describing a prize tier.

    NOTE(review): not used by the endpoints in this module; presumably
    consumed by the admin router — confirm. ``probability`` is not
    range-checked here.
    """

    name: str
    description: Optional[str] = None
    probability: float
    available_quantity: int = 0
    is_active: bool = True
class CardCreate(BaseModel):
    """Request body for generating a card of a given type."""

    card_type: str
class CardCollect(BaseModel):
    """Request body for assigning the card ``card_id`` to user ``user_id``."""

    card_id: str
    user_id: int
class DrawResult(BaseModel):
    """Response of the draw and card-claim endpoints.

    On success ``prize`` carries the prize's id/name/description; on failure
    ``message`` explains why.
    """

    success: bool
    prize: Optional[Dict[str, Any]] = None
    message: Optional[str] = None
class ShippingUpdate(BaseModel):
    """Shipping-address payload.

    NOTE(review): the shipping endpoints in this module parse the raw body
    instead (to accept both ``user_id`` and ``userId``), so this model is not
    used by them.
    """

    user_id: int
    address: str
# Initialize default prizes if none exist
def initialize_prizes(db: Session):
    """Seed the ``prizes`` table with the default tiers when it is empty.

    If any prize row already exists nothing is written, so prizes changed
    after the first start are never reset. The default probabilities add
    up to 1.0.

    Args:
        db: An open session; committed here when seeding, but not closed.
    """
    if db.query(Prize).count() > 0:
        return

    # The names/descriptions are user-facing Chinese text:
    # first prize / deluxe gift pack, second prize / fine gift,
    # third prize / souvenir, consolation prize / small gift,
    # "thanks for participating" / "better luck next time" (the no-win outcome).
    seed_rows = [
        {"name": "一等奖", "description": "豪华大礼包", "probability": 0.01, "available_quantity": 5},
        {"name": "二等奖", "description": "精美礼品", "probability": 0.05, "available_quantity": 20},
        {"name": "三等奖", "description": "纪念品", "probability": 0.2, "available_quantity": 50},
        {"name": "鼓励奖", "description": "小礼品", "probability": 0.3, "available_quantity": 100},
        {"name": "谢谢参与", "description": "下次再来", "probability": 0.44, "available_quantity": 999},
    ]
    db.add_all([Prize(**row) for row in seed_rows])
    db.commit()
# API Routes
@app.on_event("startup")
async def startup_event():
    """Seed the default prizes once when the application starts.

    Uses a dedicated session, because there is no request here for
    ``get_db`` to serve. The session is closed even if seeding raises.
    """
    db = SessionLocal()
    try:
        initialize_prizes(db)
    finally:
        db.close()
@app.get("/")
def read_root():
    """Welcome endpoint: returns a static greeting."""
    greeting = "Welcome to the Lottery System API"
    return {"message": greeting}
# User routes
@app.post("/users/", response_model=dict)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    """Create a user and return the generated id.

    Every ``UserCreate`` field is optional, so an empty JSON object creates a
    user with no contact details.
    """
    new_user = User(name=user.name, phone=user.phone, address=user.address)
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return {"success": True, "user_id": new_user.id}
@app.put("/users/{user_id}", response_model=dict)
def update_user(user_id: int, user: UserUpdate, db: Session = Depends(get_db)):
    """Partially update a user's contact details.

    Only fields present in the request body are written; a field sent
    explicitly as ``null`` is cleared.

    Raises:
        HTTPException(404): if no user has ``user_id``.
    """
    existing = db.query(User).filter(User.id == user_id).first()
    if existing is None:
        raise HTTPException(status_code=404, detail="User not found")

    changes = user.dict(exclude_unset=True)
    for field_name in changes:
        setattr(existing, field_name, changes[field_name])

    db.commit()
    db.refresh(existing)
    return {"success": True, "message": "User updated successfully"}
# Draw lottery route
@app.post("/draw/{user_id}", response_model=DrawResult)
def draw_lottery(user_id: int, db: Session = Depends(get_db)):
    """Run one weighted lottery draw for ``user_id`` and record the outcome.

    Every active, in-stock prize owns a slice of ``[0, 1)`` as wide as its
    ``probability``, with slices laid out in prize-id order. A uniform random
    number picks the slice. If the number lands past every slice — e.g.
    because a prize sold out and the remaining probabilities sum to less
    than 1 — the "谢谢参与" ("thanks for participating") prize is used.

    Every outcome, including "谢谢参与", is stored in ``user_prizes`` so each
    user's draw history is complete; only real prizes consume stock.

    NOTE(review): no per-user draw limit is enforced here.

    Returns:
        ``DrawResult`` with the prize's id/name/description, or
        ``success=False`` when no prize is active and in stock.

    Raises:
        HTTPException(404): if the user does not exist.
    """
    no_win_name = "谢谢参与"

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Explicit ordering makes the slice layout deterministic; without it the
    # row order is whatever the database happens to return.
    prizes = (
        db.query(Prize)
        .filter(Prize.is_active == True, Prize.available_quantity > 0)  # noqa: E712
        .order_by(Prize.id)
        .all()
    )
    if not prizes:
        return DrawResult(success=False, message="No prizes available")

    roll = random.random()  # uniform in [0.0, 1.0)
    cumulative_prob = 0.0
    selected_prize = None
    for prize in prizes:
        cumulative_prob += prize.probability or 0.0  # a NULL probability counts as 0
        # Strict "<": each slice is half-open, [previous, cumulative), so a
        # zero-probability prize can never be picked, even when the roll lands
        # exactly on a boundary (e.g. roll == 0.0).
        if roll < cumulative_prob:
            selected_prize = prize
            break

    if selected_prize is None:
        # Fall back to the no-win prize by name rather than to "the last row",
        # which is a real prize whenever one was added after the defaults.
        # The last row is used only if the no-win prize is inactive or missing.
        selected_prize = next(
            (prize for prize in prizes if prize.name == no_win_name), prizes[-1]
        )

    # Record every result, including no-win, as the user's draw history.
    db.add(UserPrize(user_id=user_id, prize_id=selected_prize.id, is_shipped=False))

    # Only real prizes consume stock.
    if selected_prize.name != no_win_name:
        # Decrement in SQL ("SET available_quantity = available_quantity - 1")
        # instead of writing back a value computed in Python, so concurrent
        # draws cannot overwrite each other's decrement with a stale count.
        # NOTE(review): concurrent draws of the last unit can still push the
        # count below zero; preventing that needs a conditional UPDATE.
        selected_prize.available_quantity = Prize.available_quantity - 1

    db.commit()

    print(f"用户 {user_id} 抽中了 {selected_prize.name},信息已存入数据库")

    return DrawResult(
        success=True,
        prize={
            "id": selected_prize.id,
            "name": selected_prize.name,
            "description": selected_prize.description,
        },
    )
# Card routes
@app.post("/cards/generate", response_model=dict)
def generate_card(card: CardCreate, db: Session = Depends(get_db)):
    """Create a new, uncollected card of ``card.card_type``.

    IDs look like ``CARD-1A2B3C4D``: 8 random hex digits, i.e. only 32 bits.
    Collisions with the unique ``card_id`` column become plausible once many
    cards exist, so a failed insert is retried with a fresh ID instead of
    surfacing immediately as a server error.

    Returns:
        ``{"success": True, "card_id": <new id>}``.

    Raises:
        sqlalchemy.exc.IntegrityError: if every attempt collides.
    """
    from sqlalchemy.exc import IntegrityError

    max_attempts = 5
    for attempt in range(1, max_attempts + 1):
        card_id = f"CARD-{uuid.uuid4().hex[:8].upper()}"
        db.add(Card(card_id=card_id, card_type=card.card_type))
        try:
            db.commit()
        except IntegrityError:
            # card_id is the only unique column besides the primary key, so
            # this is an ID collision: discard the failed insert and retry.
            db.rollback()
            if attempt == max_attempts:
                raise
        else:
            return {"success": True, "card_id": card_id}
@app.post("/cards/collect", response_model=dict)
def collect_card(card_data: CardCollect, db: Session = Depends(get_db)):
    """Assign an uncollected card to a user.

    Returns ``{"success": False, ...}`` (HTTP 200) if the card was already
    collected. Otherwise returns the card id, the user's total number of
    collected cards, and whether that total has reached 5.

    Raises:
        HTTPException(404): unknown ``card_id`` or unknown ``user_id``.
    """
    card = db.query(Card).filter(Card.card_id == card_data.card_id).first()
    if not card:
        raise HTTPException(status_code=404, detail="Card not found")

    if card.is_collected:
        return {"success": False, "message": "Card is already collected"}

    # Reject unknown users, as draw_lottery does. Without this check a mistyped
    # user_id would permanently burn the card: SQLite does not enforce the
    # collected_by foreign key unless PRAGMA foreign_keys is enabled, and this
    # module does not enable it.
    if not db.query(User).filter(User.id == card_data.user_id).first():
        raise HTTPException(status_code=404, detail="User not found")

    # Claim the card with a conditional UPDATE so two concurrent requests for
    # the same card cannot both succeed: the second one matches zero rows.
    claimed = (
        db.query(Card)
        .filter(Card.id == card.id, Card.is_collected == False)  # noqa: E712
        .update(
            {
                Card.is_collected: True,
                Card.collected_by: card_data.user_id,
                Card.collected_at: datetime.datetime.now(),
            },
            synchronize_session=False,
        )
    )
    if not claimed:
        db.rollback()
        return {"success": False, "message": "Card is already collected"}
    db.commit()

    # Total cards this user now holds; 5 or more counts as a complete set.
    collected_cards = db.query(Card).filter(
        Card.collected_by == card_data.user_id,
        Card.is_collected == True,  # noqa: E712
    ).count()

    return {
        "success": True,
        "card_id": card.card_id,
        "collected_count": collected_cards,
        "has_complete_set": collected_cards >= 5,
    }
@app.get("/users/{user_id}/cards", response_model=dict)
def get_user_cards(user_id: int, db: Session = Depends(get_db)):
    """List the cards a user has collected and whether they reach the 5-card threshold."""
    owned = (
        db.query(Card)
        .filter(Card.collected_by == user_id, Card.is_collected == True)  # noqa: E712
        .all()
    )
    total = len(owned)
    summaries = [{"card_id": c.card_id, "card_type": c.card_type} for c in owned]
    return {
        "success": True,
        "cards": summaries,
        "count": total,
        "has_complete_set": total >= 5,
    }
# Prize claiming for card collection
@app.post("/cards/claim-prize/{user_id}", response_model=DrawResult)
def claim_prize_for_cards(user_id: int, db: Session = Depends(get_db)):
    """Award one random real prize to a user holding at least 5 collected cards.

    Every active, in-stock prize except "谢谢参与" ("thanks for participating")
    is equally likely; the ``probability`` weights are not used here.

    NOTE(review): the cards are not consumed or marked as redeemed, so a user
    with 5+ cards can call this repeatedly and receive a prize every time.
    Fixing that needs a schema change (e.g. a redeemed flag on ``cards``) —
    confirm the intended rule.

    Returns:
        ``DrawResult`` with the prize info, or ``success=False`` if the user
        has fewer than 5 cards or no eligible prize is in stock.
    """
    card_count = (
        db.query(Card)
        .filter(Card.collected_by == user_id, Card.is_collected == True)  # noqa: E712
        .count()
    )
    if card_count < 5:
        return DrawResult(success=False, message="Not enough cards collected. Need at least 5 cards.")

    # Candidate prizes: active, in stock, and not the no-win entry.
    prizes = (
        db.query(Prize)
        .filter(
            Prize.is_active == True,  # noqa: E712
            Prize.available_quantity > 0,
            Prize.name != "谢谢参与",
        )
        .all()
    )
    if not prizes:
        return DrawResult(success=False, message="No prizes available")

    selected_prize = random.choice(prizes)

    # Decrement in SQL ("SET available_quantity = available_quantity - 1") so
    # concurrent claims cannot overwrite each other's update with the stale
    # count read above.
    selected_prize.available_quantity = Prize.available_quantity - 1

    db.add(UserPrize(user_id=user_id, prize_id=selected_prize.id))
    db.commit()

    return DrawResult(
        success=True,
        prize={
            "id": selected_prize.id,
            "name": selected_prize.name,
            "description": selected_prize.description,
        },
    )
# Update shipping address (legacy endpoint, kept for backward compatibility)
@app.post("/shipping/update", response_model=dict)
async def update_shipping(request: Request, db: Session = Depends(get_db)):
    """Save a user's shipping address and mark all their unshipped prizes as shipped.

    The JSON body is parsed by hand rather than through a Pydantic model, so
    that both ``user_id`` and ``userId`` are accepted::

        {"user_id": 1, "address": "..."}   or   {"userId": 1, "address": "..."}

    The address change and the shipped flags are committed in one transaction.

    Raises:
        HTTPException(400): missing or non-integer user id, or empty address.
        HTTPException(404): no user with that id.
    """
    import json

    raw_body = await request.body()
    try:
        body_data = json.loads(raw_body.decode("utf-8")) if raw_body else {}
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Undecodable or malformed body: treat it as empty, so the checks below
        # answer with a 400 rather than an unhandled 500.
        print("Failed to parse JSON")
        body_data = {}
    if not isinstance(body_data, dict):
        # Valid JSON that is not an object (list, string, number, null).
        body_data = {}

    try:
        if "user_id" in body_data:
            raw_user_id = body_data["user_id"]
        elif "userId" in body_data:
            raw_user_id = body_data["userId"]
        else:
            raise ValueError("user_id is required but not found in request")
        if isinstance(raw_user_id, bool):
            # int() would silently turn JSON true/false into user 1/0.
            raise ValueError("user_id must be an integer, not a boolean")
        user_id = int(raw_user_id)
    except (ValueError, TypeError) as e:
        print(f"Error processing data: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid data format: {str(e)}") from e

    # A JSON null address counts as missing; str(None) would store "None".
    raw_address = body_data.get("address")
    address = "" if raw_address is None else str(raw_address).strip()
    if not address:
        raise HTTPException(status_code=400, detail="Address is required")

    # Log only the id: the raw body and the address are personal data.
    print(f"Extracted user_id: {user_id}")

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    user.address = address
    # Mark every pending prize record of this user (including "谢谢参与" draw
    # records) as shipped in a single UPDATE. It is committed together with
    # the address change.
    db.query(UserPrize).filter(
        UserPrize.user_id == user_id,
        UserPrize.is_shipped == False,  # noqa: E712
    ).update(
        {UserPrize.is_shipped: True, UserPrize.shipped_at: datetime.datetime.now()},
        synchronize_session=False,
    )
    db.commit()

    return {"success": True, "message": "Shipping information updated"}
# Dedicated endpoint for submitting a shipping address
@app.post("/users/address", response_model=dict)
async def submit_address(request: Request, db: Session = Depends(get_db)):
    """Save a user's shipping address and mark all their unshipped prizes as shipped.

    Same contract as the legacy ``POST /shipping/update`` endpoint; only the
    success message differs. The JSON body is parsed by hand so that both
    ``user_id`` and ``userId`` are accepted::

        {"user_id": 1, "address": "..."}   or   {"userId": 1, "address": "..."}

    The address change and the shipped flags are committed in one transaction.

    Raises:
        HTTPException(400): missing or non-integer user id, or empty address.
        HTTPException(404): no user with that id.
    """
    import json

    raw_body = await request.body()
    try:
        body_data = json.loads(raw_body.decode("utf-8")) if raw_body else {}
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Undecodable or malformed body: treat it as empty, so the checks below
        # answer with a 400 rather than an unhandled 500.
        print("Failed to parse JSON")
        body_data = {}
    if not isinstance(body_data, dict):
        # Valid JSON that is not an object (list, string, number, null).
        body_data = {}

    try:
        if "user_id" in body_data:
            raw_user_id = body_data["user_id"]
        elif "userId" in body_data:
            raw_user_id = body_data["userId"]
        else:
            raise ValueError("user_id is required but not found in request")
        if isinstance(raw_user_id, bool):
            # int() would silently turn JSON true/false into user 1/0.
            raise ValueError("user_id must be an integer, not a boolean")
        user_id = int(raw_user_id)
    except (ValueError, TypeError) as e:
        print(f"Error processing data: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid data format: {str(e)}") from e

    # A JSON null address counts as missing; str(None) would store "None".
    raw_address = body_data.get("address")
    address = "" if raw_address is None else str(raw_address).strip()
    if not address:
        raise HTTPException(status_code=400, detail="Address is required")

    # Log only the id: the raw body and the address are personal data.
    print(f"Extracted user_id: {user_id}")

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    user.address = address
    # Mark every pending prize record of this user (including "谢谢参与" draw
    # records) as shipped in a single UPDATE. It is committed together with
    # the address change.
    db.query(UserPrize).filter(
        UserPrize.user_id == user_id,
        UserPrize.is_shipped == False,  # noqa: E712
    ).update(
        {UserPrize.is_shipped: True, UserPrize.shipped_at: datetime.datetime.now()},
        synchronize_session=False,
    )
    db.commit()

    return {"success": True, "message": "Shipping address submitted successfully"}
# Stats routes for dashboard
@app.get("/stats", response_model=dict)
def get_stats(db: Session = Depends(get_db)):
    """Return aggregate counters for the dashboard.

    ``prizes_awarded``/``prizes_shipped`` count ``user_prizes`` rows, which
    include "谢谢参与" (no-win) draw records. ``prize_distribution`` has one
    entry per prize with the number of times it was awarded (0 if never).
    """
    stats = {
        "user_count": db.query(User).count(),
        "prizes_awarded": db.query(UserPrize).count(),
        "prizes_shipped": db.query(UserPrize).filter(UserPrize.is_shipped == True).count(),  # noqa: E712
        "cards_collected": db.query(Card).filter(Card.is_collected == True).count(),  # noqa: E712
    }

    # LEFT OUTER JOIN so prizes that were never awarded still appear with 0.
    awarded_per_prize = (
        db.query(Prize.name, func.count(UserPrize.id))
        .outerjoin(UserPrize, UserPrize.prize_id == Prize.id)
        .group_by(Prize.id)
        .all()
    )
    stats["prize_distribution"] = [
        {"name": prize_name, "count": times} for prize_name, times in awarded_per_prize
    ]
    return stats
# Run with: uvicorn main:app --reload
if __name__ == "__main__":
    # Imported only when run as a script, so importing this module (e.g. by
    # another ASGI server) does not require uvicorn.
    import uvicorn

    # Bind to all interfaces on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)