2025-04-07 16:56:13 +08:00

267 lines
8.2 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from passlib.context import CryptContext
from typing import List, Optional
from datetime import datetime, timedelta
from pydantic import BaseModel
import jwt
from jose import JWTError, jwt
from models import User, Prize, Card, UserPrize
from sqlalchemy import desc
# Database dependency
# Session factory shared by every request; built on the first get_db() call.
_session_factory = None


def get_db():
    """Yield a SQLAlchemy session and close it once the caller is done.

    Intended for use as a FastAPI dependency (``Depends(get_db)``).

    The engine and session factory are created once, on first use, and then
    reused. Building a new engine per request (as this function used to do)
    allocates a fresh connection pool each time and never disposes of it.

    Yields:
        A session bound to the SQLite database at ``./lottery.db``.
    """
    global _session_factory
    if _session_factory is None:
        # Imports stay function-local, as in the original, so importing this
        # module does not by itself set up any database machinery.
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy import create_engine

        DATABASE_URL = "sqlite:///./lottery.db"
        engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
        # If two first requests race here, one factory simply replaces the
        # other; both are equivalent, so no lock is taken.
        _session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = _session_factory()
    try:
        yield db
    finally:
        # Always release the session, including when the request handler raises.
        db.close()
# Admin router: every route registered on it is served under the "/admin" prefix.
router = APIRouter(prefix="/admin", tags=["admin"])
# Password hashing context (bcrypt) used to hash and verify admin passwords.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT settings
# SECRET_KEY and ALGORITHM are used both to sign tokens (create_access_token)
# and to verify them (get_current_admin).
# NOTE(review): SECRET_KEY is a hard-coded placeholder; anyone who knows it can
# forge admin tokens. Load it from secret configuration before deploying.
SECRET_KEY = "YOUR_SECRET_KEY_HERE" # In production, use a secure secret key
ALGORITHM = "HS256"
# Lifetime, in minutes, of the tokens issued by the login route.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Token model: response body of the login route.
class Token(BaseModel):
    # Encoded JWT produced by create_access_token.
    access_token: str
    # Token scheme; the login route always sends "bearer".
    token_type: str
# TokenData model: identity recovered from a decoded token.
class TokenData(BaseModel):
    # Value of the token's "sub" claim, as read by get_current_admin.
    username: Optional[str] = None
# Admin model: the admin fields exposed by the /me route (its response_model).
class AdminUser(BaseModel):
    # Login name; also the key into fake_admin_db.
    username: str
    email: Optional[str] = None
    # NOTE(review): nothing in this file checks is_active when authenticating;
    # confirm whether inactive admins should be rejected.
    is_active: bool = True
# Admin in DB: the stored form of an admin, i.e. AdminUser plus the password hash.
class AdminUserInDB(AdminUser):
    # Hash checked by verify_password during login.
    hashed_password: str
# Admin list (this would typically come from a database)
# Maps username -> keyword arguments for AdminUserInDB. The password below is
# hashed when this module is imported.
# NOTE(review): hard-coded default credentials; replace before deploying.
fake_admin_db = {
    "admin": {
        "username": "admin",
        "email": "admin@example.com",
        # In production, use a secure password
        "hashed_password": pwd_context.hash("adminpassword"),
        "is_active": True,
    },
}
# Functions for authentication
def verify_password(plain_password, hashed_password):
    """Return whether ``plain_password`` matches ``hashed_password``.

    The comparison is delegated to the module's ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Return the hash of ``password`` produced by the module's ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def get_admin_user(db, username: str):
    """Look up an admin by username in ``fake_admin_db``.

    Args:
        db: Unused; the lookup reads the in-memory ``fake_admin_db`` only.
        username: Login name to look up.

    Returns:
        An ``AdminUserInDB`` built from the stored record, or ``None`` when
        no admin has that username.
    """
    record = fake_admin_db.get(username)
    if record is None:
        return None
    return AdminUserInDB(**record)
def authenticate_admin(username: str, password: str):
    """Check an admin's credentials.

    Args:
        username: Login name to look up.
        password: Plain-text password to verify against the stored hash.

    Returns:
        The matching ``AdminUserInDB`` on success; ``False`` when the user is
        unknown or the password does not match.
    """
    admin = get_admin_user(None, username)
    # The password is only checked when the lookup found an admin.
    if admin and verify_password(password, admin.hashed_password):
        return admin
    return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The dict is copied, not mutated.
        expires_delta: Token lifetime. When ``None`` the token expires after
            15 minutes. An explicit zero ``timedelta`` is honoured (previously
            it was treated as "not supplied" because it is falsy).

    Returns:
        The token encoded with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    # Local import: the file's top-level imports do not bring in ``timezone``.
    from datetime import timezone

    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Timezone-aware "now" in UTC; replaces the deprecated, naive
    # datetime.utcnow() while producing the same instant.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# OAuth2 scheme
# tokenUrl is the router prefix ("/admin") plus the login route path ("/token");
# keep it in sync with both. get_current_admin depends on this scheme to obtain
# the token string.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/admin/token")
# Dependency to get current admin
async def get_current_admin(token: str = Depends(oauth2_scheme)):
    """Resolve the token supplied by ``oauth2_scheme`` to an admin user.

    Args:
        token: Token string provided by the ``oauth2_scheme`` dependency.

    Returns:
        The ``AdminUserInDB`` named by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names an unknown admin.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    subject: str = claims.get("sub")
    if subject is None:
        raise credentials_exception
    token_data = TokenData(username=subject)
    admin = get_admin_user(None, username=token_data.username)
    if admin is None:
        raise credentials_exception
    return admin
# Login route
@router.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    # Exchange a username/password form for a bearer token.
    # Responds 401 when the credentials do not match a known admin.
    admin = authenticate_admin(form_data.username, form_data.password)
    if admin:
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        # The admin's username becomes the "sub" claim read by get_current_admin.
        token = create_access_token(
            data={"sub": admin.username}, expires_delta=lifetime
        )
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
# Protected route to test authentication
@router.get("/me", response_model=AdminUser)
async def read_admin_me(current_user: AdminUserInDB = Depends(get_current_admin)):
    # Echo back the admin resolved from the request's token. The route declares
    # response_model=AdminUser, the model without hashed_password.
    return current_user
# Order model for responses: one UserPrize row combined with the related
# User and Prize columns selected by the order routes.
class OrderResponse(BaseModel):
    # UserPrize primary key; this is the "order id" used in the route paths.
    id: int
    # Recipient (columns taken from User).
    user_id: int
    user_name: Optional[str] = None
    user_phone: Optional[str] = None
    user_address: Optional[str] = None
    # Awarded prize (columns taken from Prize).
    prize_id: int
    prize_name: str
    prize_description: Optional[str] = None
    # Award and shipping state (columns taken from UserPrize).
    awarded_at: datetime
    is_shipped: bool
    shipped_at: Optional[datetime] = None
# Get all orders
@router.get("/orders", response_model=List[OrderResponse])
async def get_orders(
    current_user: AdminUserInDB = Depends(get_current_admin),
    db: Session = Depends(get_db),
    skip: int = 0,
    limit: int = 100,
):
    # List orders (UserPrize rows joined with their User and Prize), most
    # recently awarded first. skip/limit are passed straight to OFFSET/LIMIT.
    # current_user is unused in the body; declaring it makes the route require
    # a valid admin token.
    selected_columns = (
        UserPrize.id,
        UserPrize.user_id,
        User.name.label("user_name"),
        User.phone.label("user_phone"),
        User.address.label("user_address"),
        UserPrize.prize_id,
        Prize.name.label("prize_name"),
        Prize.description.label("prize_description"),
        UserPrize.awarded_at,
        UserPrize.is_shipped,
        UserPrize.shipped_at,
    )
    page_query = (
        db.query(*selected_columns)
        .join(User, UserPrize.user_id == User.id)
        .join(Prize, UserPrize.prize_id == Prize.id)
        .order_by(desc(UserPrize.awarded_at))
        .offset(skip)
        .limit(limit)
    )
    responses = []
    for row in page_query.all():
        responses.append(
            OrderResponse(
                id=row.id,
                user_id=row.user_id,
                user_name=row.user_name,
                user_phone=row.user_phone,
                user_address=row.user_address,
                prize_id=row.prize_id,
                prize_name=row.prize_name,
                prize_description=row.prize_description,
                awarded_at=row.awarded_at,
                is_shipped=row.is_shipped,
                shipped_at=row.shipped_at,
            )
        )
    return responses
# Update order shipping status
@router.put("/orders/{order_id}/ship", response_model=dict)
async def update_order_status(
    order_id: int,
    current_user: AdminUserInDB = Depends(get_current_admin),
    db: Session = Depends(get_db),
):
    # Mark one order as shipped and stamp the shipping time.
    # Responds 404 when no UserPrize row has the given id.
    # current_user is unused in the body; declaring it makes the route require
    # a valid admin token.
    matching_orders = db.query(UserPrize).filter(UserPrize.id == order_id)
    shipment = matching_orders.first()
    if not shipment:
        raise HTTPException(status_code=404, detail="Order not found")
    shipment.is_shipped = True
    # NOTE(review): datetime.now() is naive local time; confirm this matches how
    # awarded_at is recorded. Re-shipping an order overwrites shipped_at.
    shipment.shipped_at = datetime.now()
    db.commit()
    return {"success": True, "message": "Order marked as shipped"}
# Get order details
@router.get("/orders/{order_id}", response_model=OrderResponse)
async def get_order_details(
    order_id: int,
    current_user: AdminUserInDB = Depends(get_current_admin),
    db: Session = Depends(get_db),
):
    # Fetch a single order (UserPrize row joined with its User and Prize).
    # Responds 404 when no row has the given id.
    # current_user is unused in the body; declaring it makes the route require
    # a valid admin token.
    selected_columns = (
        UserPrize.id,
        UserPrize.user_id,
        User.name.label("user_name"),
        User.phone.label("user_phone"),
        User.address.label("user_address"),
        UserPrize.prize_id,
        Prize.name.label("prize_name"),
        Prize.description.label("prize_description"),
        UserPrize.awarded_at,
        UserPrize.is_shipped,
        UserPrize.shipped_at,
    )
    row = (
        db.query(*selected_columns)
        .join(User, UserPrize.user_id == User.id)
        .join(Prize, UserPrize.prize_id == Prize.id)
        .filter(UserPrize.id == order_id)
        .first()
    )
    if not row:
        raise HTTPException(status_code=404, detail="Order not found")
    # Each name is both a column/label on the result row and an OrderResponse
    # field, so the response is populated by copying attributes in this order.
    response_fields = (
        "id",
        "user_id",
        "user_name",
        "user_phone",
        "user_address",
        "prize_id",
        "prize_name",
        "prize_description",
        "awarded_at",
        "is_shipped",
        "shipped_at",
    )
    return OrderResponse(**{name: getattr(row, name) for name in response_fields})