493 lines
16 KiB
Python
493 lines
16 KiB
Python
from fastapi import FastAPI, Depends, HTTPException, UploadFile, File, Form, BackgroundTasks, Request
|
||
import admin
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, ForeignKey, DateTime, Text
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
from sqlalchemy.sql import func
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional, Dict, Any
|
||
import random
|
||
import datetime
|
||
import os
|
||
import uuid
|
||
from pathlib import Path
|
||
|
||
# Application object; every route in this module is registered on it.
app = FastAPI(
    title="Lottery System API",
    description="Backend API for lottery system with SQLite database",
    version="1.0.0",
)

# Mount the routes exposed by the separate ``admin`` module.
app.include_router(admin.router)

# CORS settings that let a browser frontend on another origin call the API.
# NOTE(review): a wildcard origin is combined with allow_credentials=True;
# in production, replace "*" with the real frontend domain.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
||
|
||
# Database setup: a SQLite file named lottery.db, addressed by a relative path.
DATABASE_URL = "sqlite:///./lottery.db"

# check_same_thread=False is handed to the SQLite driver through connect_args
# so that a connection may be used outside the thread that opened it.
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Factory for sessions bound to the engine; commits and flushes are explicit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class shared by all ORM models in this module.
Base = declarative_base()
|
||
|
||
# Database dependency
def get_db():
    """Yield a fresh database session and close it afterwards.

    Written as a generator so it can be used with ``Depends``: the session
    is closed in the ``finally`` clause once the consumer is finished,
    whether or not an exception was raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||
|
||
# Models
|
||
class User(Base):
    """A participant row in the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)

    # Contact details; every one of them may be left empty.
    name = Column(String(50), nullable=True)
    phone = Column(String(20), nullable=True)
    address = Column(Text, nullable=True)

    # Filled with the database's current timestamp on insert.
    created_at = Column(DateTime, default=func.now())
|
||
|
||
class Prize(Base):
    """A prize definition row in the ``prizes`` table."""

    __tablename__ = "prizes"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100))
    description = Column(Text, nullable=True)

    # Weight used by the draw endpoint when choosing a prize.
    probability = Column(Float, default=0)

    # Remaining stock; the draw and claim endpoints only offer prizes with
    # a quantity above zero.
    available_quantity = Column(Integer, default=0)

    # Inactive prizes are filtered out by the draw and claim endpoints.
    is_active = Column(Boolean, default=True)
|
||
|
||
class Card(Base):
    """A collectible card row in the ``cards`` table."""

    __tablename__ = "cards"

    id = Column(Integer, primary_key=True, index=True)

    # Public identifier of the card (unique), distinct from the numeric key.
    card_id = Column(String(50), unique=True)
    card_type = Column(String(20))

    # Collection state: flag, collecting user and time of collection.
    is_collected = Column(Boolean, default=False)
    collected_by = Column(Integer, ForeignKey("users.id"), nullable=True)
    collected_at = Column(DateTime, nullable=True)
|
||
|
||
class UserPrize(Base):
    """One awarded prize, linking a user to a prize in ``user_prizes``."""

    __tablename__ = "user_prizes"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    prize_id = Column(Integer, ForeignKey("prizes.id"))

    # Filled with the database's current timestamp on insert.
    awarded_at = Column(DateTime, default=func.now())

    # Shipping state: flag plus the time it was set.
    is_shipped = Column(Boolean, default=False)
    shipped_at = Column(DateTime, nullable=True)
|
||
|
||
# Create the tables for every model declared on Base (runs at import time).
Base.metadata.create_all(engine)
|
||
|
||
# Pydantic Models for API
|
||
class UserCreate(BaseModel):
    """Request body for creating a user; every field is optional."""

    name: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
|
||
|
||
class UserUpdate(BaseModel):
    """Request body for updating a user; omitted fields default to None."""

    name: Optional[str] = None
    phone: Optional[str] = None
    address: Optional[str] = None
|
||
|
||
class PrizeCreate(BaseModel):
    """Request body describing a prize; name and probability are required."""

    name: str
    description: Optional[str] = None
    probability: float
    available_quantity: int = 0
    is_active: bool = True
|
||
|
||
class CardCreate(BaseModel):
    """Request body for generating a card of the given type."""

    card_type: str
|
||
|
||
class CardCollect(BaseModel):
    """Request body for collecting a card on behalf of a user."""

    card_id: str
    user_id: int
|
||
|
||
class DrawResult(BaseModel):
    """Response of the draw and claim endpoints.

    ``prize`` carries the prize details on success; ``message`` carries an
    explanation when no prize is returned.
    """

    success: bool
    prize: Optional[Dict[str, Any]] = None
    message: Optional[str] = None
|
||
|
||
class ShippingUpdate(BaseModel):
    """Shape of a shipping-address submission: a user id and an address."""

    user_id: int
    address: str
|
||
|
||
# Initialize default prizes if none exist
def initialize_prizes(db: Session):
    """Seed the prize table with the five default prizes when it is empty.

    Args:
        db: Session used to count existing prizes and insert the defaults.

    Nothing is inserted (and nothing is committed) when at least one prize
    row already exists.
    """
    if db.query(Prize).count() != 0:
        return

    # Each row: (name, description, probability, available_quantity).
    seed_rows = (
        ("一等奖", "豪华大礼包", 0.01, 5),
        ("二等奖", "精美礼品", 0.05, 20),
        ("三等奖", "纪念品", 0.2, 50),
        ("鼓励奖", "小礼品", 0.3, 100),
        ("谢谢参与", "下次再来", 0.44, 999),
    )

    db.add_all(
        [
            Prize(
                name=name,
                description=description,
                probability=probability,
                available_quantity=quantity,
            )
            for name, description, probability, quantity in seed_rows
        ]
    )
    db.commit()
|
||
|
||
# API Routes
|
||
@app.on_event("startup")
async def startup_event():
    """Seed the default prizes when the application starts.

    Opens a dedicated session for the seeding step and always closes it,
    so a failure inside ``initialize_prizes`` cannot leak the session.
    """
    db = SessionLocal()
    try:
        initialize_prizes(db)
    finally:
        # Release the session even when seeding raises.
        db.close()
|
||
|
||
@app.get("/")
def read_root():
    """Return a static welcome message for the API root."""
    welcome = {"message": "Welcome to the Lottery System API"}
    return welcome
|
||
|
||
# User routes
|
||
@app.post("/users/", response_model=dict)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
    """Insert a user built from the request body and return its id.

    Args:
        user: Fields for the new user.
        db: Database session supplied by ``get_db``.

    Returns:
        ``{"success": True, "user_id": <new id>}``.
    """
    fields = user.dict()
    new_user = User(**fields)

    db.add(new_user)
    db.commit()
    # Reload the row so the generated primary key is available.
    db.refresh(new_user)

    return {"success": True, "user_id": new_user.id}
|
||
|
||
@app.put("/users/{user_id}", response_model=dict)
def update_user(user_id: int, user: UserUpdate, db: Session = Depends(get_db)):
    """Apply the fields present in the request body to an existing user.

    Args:
        user_id: Primary key of the user to change.
        user: Partial update; only fields explicitly sent are applied.
        db: Database session supplied by ``get_db``.

    Returns:
        A success flag and a confirmation message.

    Raises:
        HTTPException: 404 when no user has the given id.
    """
    existing = db.query(User).filter(User.id == user_id).first()
    if existing is None:
        raise HTTPException(status_code=404, detail="User not found")

    # exclude_unset keeps fields the client did not send from being reset.
    changes = user.dict(exclude_unset=True)
    for field_name in changes:
        setattr(existing, field_name, changes[field_name])

    db.commit()
    db.refresh(existing)
    return {"success": True, "message": "User updated successfully"}
|
||
|
||
# Draw lottery route
def _pick_prize_by_roll(prizes, roll):
    """Return the prize whose cumulative-probability band contains ``roll``.

    Prizes are laid out one after another on the number line in list order,
    each occupying a half-open band ``[lower, lower + probability)``.

    Args:
        prizes: Non-empty sequence of objects with a ``probability`` attribute.
        roll: Number in ``[0.0, 1.0)``, as produced by ``random.random()``.

    Returns:
        The matching prize, or the last prize in ``prizes`` when ``roll``
        lies beyond the combined probability of all listed prizes.
    """
    upper_bound = 0
    for prize in prizes:
        # A missing probability is treated as zero instead of raising.
        upper_bound += prize.probability or 0
        # Strict comparison keeps the bands half-open: a prize with
        # probability 0 owns an empty band and cannot win, even when the
        # roll is exactly 0.0.
        if roll < upper_bound:
            return prize
    return prizes[-1]


@app.post("/draw/{user_id}", response_model=DrawResult)
def draw_lottery(user_id: int, db: Session = Depends(get_db)):
    """Draw one prize for a user and record the result.

    Every draw, including the "谢谢参与" consolation entry, is stored in
    ``user_prizes`` so each user's draw history is kept. Stock is only
    decremented for prizes other than "谢谢参与".

    Args:
        user_id: Primary key of the user who is drawing.
        db: Database session supplied by ``get_db``.

    Returns:
        ``DrawResult`` with the prize details, or ``success=False`` when no
        active prize with remaining stock exists.

    Raises:
        HTTPException: 404 when no user has the given id.
    """
    # Verify user exists
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Active prizes with stock, in primary-key order so the probability
    # bands and the fallback to the last prize are deterministic.
    # NOTE(review): with the seeded defaults the last prize by id is
    # presumably "谢谢参与" - confirm for custom prize sets.
    prizes = (
        db.query(Prize)
        .filter(Prize.is_active == True, Prize.available_quantity > 0)
        .order_by(Prize.id)
        .all()
    )
    if not prizes:
        return DrawResult(success=False, message="No prizes available")

    selected_prize = _pick_prize_by_roll(prizes, random.random())

    # Record the draw result for this user.
    db.add(
        UserPrize(
            user_id=user_id,
            prize_id=selected_prize.id,
            is_shipped=False,
        )
    )

    # Only real prizes consume stock.
    if selected_prize.name != "谢谢参与":
        selected_prize.available_quantity -= 1

    # Persist the history row and the stock change together.
    db.commit()

    # Log the outcome.
    print(f"用户 {user_id} 抽中了 {selected_prize.name},信息已存入数据库")

    return DrawResult(
        success=True,
        prize={
            "id": selected_prize.id,
            "name": selected_prize.name,
            "description": selected_prize.description,
        },
    )
|
||
|
||
# Card routes
|
||
@app.post("/cards/generate", response_model=dict)
def generate_card(card: CardCreate, db: Session = Depends(get_db)):
    """Create a card of the requested type with a generated identifier.

    The identifier is ``CARD-`` followed by the first eight hex digits of a
    random UUID, upper-cased.

    Args:
        card: Request body holding the card type.
        db: Database session supplied by ``get_db``.

    Returns:
        ``{"success": True, "card_id": <identifier>}``.
    """
    token = uuid.uuid4().hex[:8].upper()
    card_id = f"CARD-{token}"

    new_card = Card(card_type=card.card_type, card_id=card_id)
    db.add(new_card)
    db.commit()
    db.refresh(new_card)

    return {"success": True, "card_id": card_id}
|
||
|
||
@app.post("/cards/collect", response_model=dict)
def collect_card(card_data: CardCollect, db: Session = Depends(get_db)):
    """Mark a card as collected by a user and report the user's progress.

    Args:
        card_data: Identifier of the card and id of the collecting user.
        db: Database session supplied by ``get_db``.

    Returns:
        On success: the card id, how many cards the user has collected and
        whether that count has reached five. When the card was already
        collected: ``{"success": False, "message": ...}``.

    Raises:
        HTTPException: 404 when the card or the user does not exist.
    """
    # Check if card exists
    card = db.query(Card).filter(Card.card_id == card_data.card_id).first()
    if not card:
        raise HTTPException(status_code=404, detail="Card not found")

    # Check if card is already collected
    if card.is_collected:
        return {"success": False, "message": "Card is already collected"}

    # Refuse to attribute the card to a user id that has no row, matching
    # the user check done by the other user-scoped endpoints.
    user = db.query(User).filter(User.id == card_data.user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Collect the card
    card.is_collected = True
    card.collected_by = card_data.user_id
    card.collected_at = datetime.datetime.now()
    db.commit()

    # Count the user's collected cards; five or more is a complete set.
    collected_cards = db.query(Card).filter(
        Card.collected_by == card_data.user_id,
        Card.is_collected == True,
    ).count()

    return {
        "success": True,
        "card_id": card.card_id,
        "collected_count": collected_cards,
        "has_complete_set": collected_cards >= 5,
    }
|
||
|
||
@app.get("/users/{user_id}/cards", response_model=dict)
def get_user_cards(user_id: int, db: Session = Depends(get_db)):
    """List the cards collected by a user.

    Args:
        user_id: Id of the user whose cards are listed.
        db: Database session supplied by ``get_db``.

    Returns:
        The card id/type pairs, their count, and whether the count has
        reached five.
    """
    owned = (
        db.query(Card)
        .filter(Card.collected_by == user_id, Card.is_collected == True)
        .all()
    )
    total = len(owned)

    summaries = []
    for entry in owned:
        summaries.append({"card_id": entry.card_id, "card_type": entry.card_type})

    return {
        "success": True,
        "cards": summaries,
        "count": total,
        "has_complete_set": total >= 5,
    }
|
||
|
||
# Prize claiming for card collection
@app.post("/cards/claim-prize/{user_id}", response_model=DrawResult)
def claim_prize_for_cards(user_id: int, db: Session = Depends(get_db)):
    """Award a uniformly random real prize to a user holding five cards.

    Args:
        user_id: Id of the claiming user.
        db: Database session supplied by ``get_db``.

    Returns:
        ``DrawResult`` with the prize details, or ``success=False`` when the
        user has fewer than five collected cards or no eligible prize is
        in stock.
    """
    # NOTE(review): nothing here consumes or marks the cards used for the
    # claim, so the same five cards pass this check on every call.
    owned_cards = (
        db.query(Card)
        .filter(Card.collected_by == user_id, Card.is_collected == True)
        .count()
    )
    if owned_cards < 5:
        return DrawResult(success=False, message="Not enough cards collected. Need at least 5 cards.")

    # Eligible prizes: active, in stock, and not the "谢谢参与" entry.
    candidates = (
        db.query(Prize)
        .filter(
            Prize.is_active == True,
            Prize.available_quantity > 0,
            Prize.name != "谢谢参与",
        )
        .all()
    )
    if not candidates:
        return DrawResult(success=False, message="No prizes available")

    # Every eligible prize is equally likely; probabilities are ignored.
    selected_prize = random.choice(candidates)

    # Take one unit from stock and record the win in the same commit.
    selected_prize.available_quantity -= 1
    win_record = UserPrize(user_id=user_id, prize_id=selected_prize.id)
    db.add(win_record)
    db.commit()

    prize_info = {
        "id": selected_prize.id,
        "name": selected_prize.name,
        "description": selected_prize.description,
    }
    return DrawResult(success=True, prize=prize_info)
|
||
|
||
# Update shipping address (legacy endpoint kept for backward compatibility)
def _parse_shipping_payload(raw_body):
    """Extract ``(user_id, address)`` from a raw JSON request body.

    The user id is accepted under either ``user_id`` or ``userId``. A body
    that is empty, not valid UTF-8, not valid JSON, or not a JSON object is
    treated as an empty object, which then fails the user-id check.

    Args:
        raw_body: Request body as bytes.

    Returns:
        Tuple of the integer user id and the stripped, non-empty address.

    Raises:
        HTTPException: 400 when the user id is missing or cannot be
            converted to an integer, or when the address is missing, null
            or blank.
    """
    import json

    try:
        body_str = raw_body.decode()
        print(f"Raw request body: {body_str}")
        body_data = json.loads(body_str) if body_str else {}
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError, so malformed text and malformed JSON end up here.
        print("Failed to parse JSON")
        body_data = {}

    # Only a JSON object can carry the expected keys.
    if not isinstance(body_data, dict):
        body_data = {}

    print(f"Parsed body data: {body_data}")

    try:
        # Try the supported key spellings for the user id.
        if "user_id" in body_data:
            user_id = int(body_data["user_id"])
        elif "userId" in body_data:
            user_id = int(body_data["userId"])
        else:
            raise ValueError("user_id is required but not found in request")
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError covers int() of an infinite float such as Infinity.
        print(f"Error processing data: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid data format: {str(e)}")

    # A JSON null counts as missing; str(None) would otherwise store the
    # literal text "None" as the address.
    raw_address = body_data.get("address")
    address = "" if raw_address is None else str(raw_address).strip()
    if not address:
        raise HTTPException(status_code=400, detail="Address is required")

    print(f"Extracted user_id: {user_id}, address: {address}")
    return user_id, address


def _store_address_and_mark_shipped(db, user_id, address):
    """Save a user's address and flag that user's unshipped prizes as shipped.

    Both changes are written in a single commit so that the address and the
    shipped flags cannot be saved independently of each other.

    Args:
        db: Database session.
        user_id: Id of the user to update.
        address: Address to store on the user.

    Raises:
        HTTPException: 404 when no user has the given id.
    """
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    user.address = address

    # Mark every prize row of this user that is not yet shipped.
    pending = db.query(UserPrize).filter(
        UserPrize.user_id == user_id,
        UserPrize.is_shipped == False,
    ).all()
    for awarded in pending:
        awarded.is_shipped = True
        awarded.shipped_at = datetime.datetime.now()

    db.commit()


@app.post("/shipping/update", response_model=dict)
async def update_shipping(request: Request, db: Session = Depends(get_db)):
    """Store a shipping address sent as raw JSON (legacy endpoint).

    Args:
        request: Incoming request; its body must be a JSON object with a
            user id (``user_id`` or ``userId``) and an ``address``.
        db: Database session supplied by ``get_db``.

    Returns:
        A success flag and a confirmation message.

    Raises:
        HTTPException: 400 for an invalid payload, 404 for an unknown user.
    """
    user_id, address = _parse_shipping_payload(await request.body())
    _store_address_and_mark_shipped(db, user_id, address)
    return {"success": True, "message": "Shipping information updated"}


# Dedicated endpoint for submitting a delivery address
@app.post("/users/address", response_model=dict)
async def submit_address(request: Request, db: Session = Depends(get_db)):
    """Store a delivery address sent as raw JSON.

    Behaves like ``update_shipping`` apart from the confirmation message.

    Args:
        request: Incoming request; its body must be a JSON object with a
            user id (``user_id`` or ``userId``) and an ``address``.
        db: Database session supplied by ``get_db``.

    Returns:
        A success flag and a confirmation message.

    Raises:
        HTTPException: 400 for an invalid payload, 404 for an unknown user.
    """
    user_id, address = _parse_shipping_payload(await request.body())
    _store_address_and_mark_shipped(db, user_id, address)
    return {"success": True, "message": "Shipping address submitted successfully"}
|
||
|
||
# Stats routes for dashboard
@app.get("/stats", response_model=dict)
def get_stats(db: Session = Depends(get_db)):
    """Return aggregate counters for the dashboard.

    Args:
        db: Database session supplied by ``get_db``.

    Returns:
        Counts of users, awarded prizes, shipped prizes and collected
        cards, plus the number of awards per prize.
    """
    stats = {
        "user_count": db.query(User).count(),
        "prizes_awarded": db.query(UserPrize).count(),
        "prizes_shipped": db.query(UserPrize).filter(UserPrize.is_shipped == True).count(),
        "cards_collected": db.query(Card).filter(Card.is_collected == True).count(),
    }

    # Outer join so prizes that were never awarded still appear, with a
    # count of zero.
    distribution_rows = (
        db.query(Prize.name, func.count(UserPrize.id))
        .join(UserPrize, UserPrize.prize_id == Prize.id, isouter=True)
        .group_by(Prize.id)
        .all()
    )
    stats["prize_distribution"] = [
        {"name": row[0], "count": row[1]} for row in distribution_rows
    ]
    return stats
|
||
|
||
# Run with: uvicorn main:app --reload
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # Listen on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
|