1. Python FastAPI概述

FastAPI是一个现代、快速的Python Web框架,基于标准Python类型提示构建,具有高性能、自动API文档生成、数据验证等特点。本文将详细介绍FastAPI框架、异步编程、API设计、数据验证和性能优化的完整解决方案。

1.1 核心功能

  1. 高性能API: 基于Starlette和Pydantic的高性能API框架
  2. 异步编程: 原生支持异步和并发编程
  3. 数据验证: 自动数据验证和序列化
  4. API文档: 自动生成交互式API文档
  5. 类型提示: 基于Python类型提示的API设计

1.2 技术架构

客户端     →  FastAPI应用  →  路由处理  →  数据验证
   ↓            ↓              ↓            ↓
HTTP请求   →  中间件       →  业务逻辑  →  响应序列化
   ↓            ↓              ↓            ↓
异步处理   →  依赖注入     →  数据库    →  JSON响应

2. FastAPI配置

2.1 FastAPI应用配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
FastAPI应用配置
"""
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseSettings
import uvicorn
import logging

# Module-level logger for this file; the root logging handler is configured
# at INFO level so informational application messages are emitted by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class Settings(BaseSettings):
    """Application settings.

    pydantic's ``BaseSettings`` fills these fields from environment variables
    and from the ``.env`` file named in ``Config.env_file``; the literals below
    are the fallbacks used when no external value is provided.

    NOTE(review): ``BaseSettings`` is imported from ``pydantic`` in this
    section, which only works on Pydantic v1 (v2 moved it to the separate
    ``pydantic-settings`` package), while other sections use v2-style config
    (``from_attributes``). Confirm which Pydantic version is targeted.
    """
    app_name: str = "FastAPI应用"
    app_version: str = "1.0.0"
    debug: bool = False
    host: str = "0.0.0.0"
    port: int = 8000

    # Database connection URL.
    database_url: str = "sqlite:///./app.db"

    # Redis connection URL.
    redis_url: str = "redis://localhost:6379"

    # JWT signing settings.
    # NOTE(review): the default secret is a public placeholder; anyone who
    # knows it can forge tokens. Override it via the environment / .env.
    secret_key: str = "your-secret-key"
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    class Config:
        # Additional file from which settings values are read.
        env_file = ".env"

# Instantiate the settings once at import time; the app below reads from it.
settings = Settings()

# Create the FastAPI application.
app = FastAPI(
    title=settings.app_name,
    version=settings.app_version,
    description="高性能异步API开发框架",
    debug=settings.debug
)

# Add CORS middleware.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed cross-origin requests (Starlette reflects the request
# Origin for cookie-bearing requests in this configuration — verify against the
# installed version). Restrict allow_origins for production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add trusted-host middleware.
# NOTE(review): allowed_hosts=["*"] accepts every Host header, so as
# configured this middleware does no filtering; list the real hostnames to
# enable Host-header validation.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["*"]
)

@app.get("/")
async def root():
    """Return a welcome message together with the configured app version."""
    payload = {"message": "欢迎使用FastAPI"}
    payload["version"] = settings.app_version
    return payload

@app.get("/health")
async def health_check():
    """Report service liveness.

    Returns:
        A dict with ``status`` (``"healthy"``) and ``timestamp``: the current
        UTC time in ISO-8601 form with a ``Z`` suffix, e.g.
        ``"2024-01-01T00:00:00Z"``.
    """
    from datetime import datetime, timezone

    # Report the actual time of the check; a constant timestamp (as before)
    # gives monitoring no way to tell a fresh response from a stale one.
    now = datetime.now(timezone.utc)
    return {"status": "healthy", "timestamp": now.strftime("%Y-%m-%dT%H:%M:%SZ")}

if __name__ == "__main__":
    # Development entry point: serve "main:app" with uvicorn; auto-reload is
    # enabled only when debug mode is on.
    launch_kwargs = {
        "host": settings.host,
        "port": settings.port,
        "reload": settings.debug,
    }
    uvicorn.run("main:app", **launch_kwargs)

2.2 数据库配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
数据库配置
"""
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from datetime import datetime
from typing import Generator

# Database connection URL.
DATABASE_URL = "sqlite:///./app.db"

# ``check_same_thread`` is a sqlite3-specific connect() argument. It is needed
# here because a session may be used from a different thread than the one that
# opened the connection. Other DBAPI drivers reject unknown keyword arguments,
# so only pass it for SQLite URLs; DATABASE_URL can then point at another
# database without further changes.
_connect_args = (
    {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
)

# Create the database engine.
engine = create_engine(DATABASE_URL, connect_args=_connect_args)

# Session factory: explicit commits, no autoflush before queries.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class that all ORM models inherit from.
Base = declarative_base()

class User(Base):
    """ORM model for the ``users`` table (application accounts)."""
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    # Login name; unique across all users.
    username = Column(String(50), unique=True, index=True, nullable=False)
    # Unique across all users.
    email = Column(String(100), unique=True, index=True, nullable=False)
    # Output of the password hasher; the plaintext password is not stored.
    hashed_password = Column(String(255), nullable=False)
    # Inactive accounts are rejected by get_current_active_user.
    is_active = Column(Boolean, default=True)
    # Naive UTC timestamps from datetime.utcnow; updated_at is refreshed by
    # SQLAlchemy whenever the row is updated through the ORM.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

class Product(Base):
    """ORM model for the ``products`` table."""
    __tablename__ = "products"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(100), nullable=False)
    description = Column(String(500))
    price = Column(Integer, nullable=False) # Price in cents (integer minor units)
    # Units in stock.
    stock = Column(Integer, default=0)
    # False marks a soft-deleted product (set by the delete endpoint).
    is_active = Column(Boolean, default=True)
    # Naive UTC timestamps; updated_at is refreshed on ORM updates.
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

# Create any missing tables for all models registered on Base. This runs at
# import time and only creates absent tables; existing tables are not altered
# (there is no migration step).
Base.metadata.create_all(bind=engine)

def get_db() -> Generator[Session, None, None]:
    """Yield a database session for one request and always close it afterwards.

    Used as a FastAPI dependency; the ``finally`` clause runs even if the
    endpoint using the session raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

3. 数据模型和验证

3.1 Pydantic模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Pydantic数据模型
"""
from pydantic import BaseModel, EmailStr, validator
from typing import Optional, List
from datetime import datetime
from enum import Enum

class UserRole(str, Enum):
    """User roles; as a ``str`` mixin, each member compares equal to its value.

    NOTE(review): not referenced by any model or endpoint in this document.
    """
    ADMIN = "admin"
    USER = "user"
    GUEST = "guest"

class UserBase(BaseModel):
    """Public user fields shared by the create and response schemas."""
    username: str
    # Validated as an email address by pydantic's EmailStr.
    email: EmailStr
    is_active: bool = True

class UserCreate(UserBase):
    """Registration payload: public user fields plus the plaintext password."""
    password: str

    @validator('password')
    def validate_password(cls, v):
        """Require at least 6 characters and at most 72 bytes (UTF-8).

        The upper bound exists because bcrypt, the hasher configured in the
        auth section, only uses the first 72 bytes of its input. Longer
        passwords would be silently truncated, and newer ``bcrypt`` versions
        may reject them outright at hashing time.
        """
        if len(v) < 6:
            raise ValueError('密码长度不能少于6位')
        # Measured in bytes, not characters: multi-byte characters (e.g.
        # Chinese) reach the bcrypt limit sooner.
        if len(v.encode('utf-8')) > 72:
            raise ValueError('密码长度不能超过72字节')
        return v

    @validator('username')
    def validate_username(cls, v):
        """Require a username of at least 3 characters."""
        if len(v) < 3:
            raise ValueError('用户名长度不能少于3位')
        return v

class UserUpdate(BaseModel):
    """Partial-update payload for a user; only supplied fields are applied.

    The username length rule mirrors ``UserCreate``, so an update cannot store
    a username that registration would have rejected.
    """
    username: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None

    @validator('username')
    def validate_username(cls, v):
        """Apply the registration length rule when a username string is given."""
        # None is permitted by the field type; only actual strings are checked.
        if v is not None and len(v) < 3:
            raise ValueError('用户名长度不能少于3位')
        return v

class UserResponse(UserBase):
    """User representation returned by the API; it has no password field."""
    id: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building this schema from ORM objects via attribute access, so
        # endpoints can return SQLAlchemy ``User`` instances directly.
        # NOTE(review): ``from_attributes`` is the Pydantic v2 key; on v1 the
        # equivalent is ``orm_mode`` — confirm the targeted version.
        from_attributes = True

class ProductBase(BaseModel):
    """Product fields shared by the create and response schemas."""
    name: str
    description: Optional[str] = None
    # Price in cents, matching the integer ``Product.price`` column.
    price: int
    stock: int = 0

class ProductCreate(ProductBase):
    """Payload for creating a product."""

    @validator('price')
    def validate_price(cls, v):
        """Price (in cents) must be strictly positive."""
        if v <= 0:
            raise ValueError('价格必须大于0')
        return v

    @validator('name')
    def validate_name(cls, v):
        """Strip surrounding whitespace and require at least 2 characters."""
        if len(v.strip()) < 2:
            raise ValueError('产品名称不能少于2个字符')
        return v.strip()

    @validator('stock')
    def validate_stock(cls, v):
        """Stock counts units on hand and therefore cannot be negative."""
        if v < 0:
            raise ValueError('库存不能为负数')
        return v

class ProductUpdate(BaseModel):
    """Partial-update payload for a product; only supplied fields are applied.

    Supplied values follow the same rules as ``ProductCreate``, so an update
    cannot store data that creation would reject (e.g. a zero price or an
    empty name).
    """
    name: Optional[str] = None
    description: Optional[str] = None
    price: Optional[int] = None
    stock: Optional[int] = None

    @validator('price')
    def validate_price(cls, v):
        """When given, the price (in cents) must be strictly positive."""
        if v is not None and v <= 0:
            raise ValueError('价格必须大于0')
        return v

    @validator('name')
    def validate_name(cls, v):
        """When given, strip whitespace and require at least 2 characters."""
        if v is None:
            return v
        if len(v.strip()) < 2:
            raise ValueError('产品名称不能少于2个字符')
        return v.strip()

    @validator('stock')
    def validate_stock(cls, v):
        """When given, stock cannot be negative."""
        if v is not None and v < 0:
            raise ValueError('库存不能为负数')
        return v

class ProductResponse(ProductBase):
    """Product representation returned by the API."""
    id: int
    # False means the product was soft-deleted.
    is_active: bool
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow building this schema from ORM ``Product`` objects.
        # NOTE(review): v2 key; Pydantic v1 uses ``orm_mode`` instead.
        from_attributes = True

class Token(BaseModel):
    """Response body of the login endpoint."""
    # The encoded JWT.
    access_token: str
    # Token type; the login endpoint returns "bearer".
    token_type: str

class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""
    # Value of the JWT ``sub`` claim, i.e. the username.
    username: Optional[str] = None

class PaginationParams(BaseModel):
    """Page-based pagination: 1-based ``page`` and a ``size`` between 1 and 100."""
    page: int = 1
    size: int = 10

    @validator('page')
    def validate_page(cls, v):
        """Reject page numbers below 1."""
        is_valid = v >= 1
        if not is_valid:
            raise ValueError('页码必须大于0')
        return v

    @validator('size')
    def validate_size(cls, v):
        """Reject page sizes outside the inclusive range 1..100."""
        if not 1 <= v <= 100:
            raise ValueError('页面大小必须在1-100之间')
        return v

class PaginatedResponse(BaseModel):
    """Envelope for one page of results plus pagination metadata.

    NOTE(review): not used by any endpoint shown in this document.
    """
    # Items on the current page, as plain dicts.
    items: List[dict]
    # Presumably the total number of matching items across all pages — confirm.
    total: int
    page: int
    size: int
    # Presumably ceil(total / size); this model does not compute it.
    pages: int

4. 用户管理API

4.1 用户认证和授权

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
用户认证和授权
"""
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from typing import Optional
import logging

logger = logging.getLogger(__name__)

# Password hashing context. bcrypt is the only (and default) scheme;
# deprecated="auto" marks any non-default scheme as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 bearer-token scheme. ``tokenUrl`` is advertised in the OpenAPI schema,
# and the interactive docs POST credentials to it, so it must match the real
# login route. That route is mounted at ``/users/token`` (router prefix
# ``/users`` + ``/token``). The relative URL resolves against the docs page.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/token")

# JWT settings.
# NOTE(review): these duplicate Settings.secret_key / algorithm /
# access_token_expire_minutes, and the secret is a public placeholder; load
# them from configuration in any real deployment.
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored *hashed_password*."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches

def get_password_hash(password: str) -> str:
    """Hash *password* with the configured context and return the hash string."""
    digest = pwd_context.hash(password)
    return digest

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT with an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": username}``. The caller's dict is
            not modified.
        expires_delta: Token lifetime. When omitted (or zero-length), the token
            expires after 15 minutes.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    lifetime = expires_delta or timedelta(minutes=15)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)

def authenticate_user(db: Session, username: str, password: str):
    """Check a username/password pair.

    Returns:
        The matching ``User`` when the credentials are valid, otherwise ``False``.

    For an unknown username a dummy hash verification is still performed, so
    that response time does not reveal whether the account exists.
    """
    user = db.query(User).filter(User.username == username).first()
    if not user:
        # Without this, unknown usernames return immediately while known ones
        # pay for a bcrypt check, which leaks account existence via timing.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user

def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer token to the corresponding ``User``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            cannot be decoded, has no ``sub`` claim, or names a user that does
            not exist.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    account = db.query(User).filter(User.username == token_data.username).first()
    if account is None:
        raise unauthorized
    return account

def get_current_active_user(current_user: User = Depends(get_current_user)):
    """Return the authenticated user; disabled accounts get HTTP 400."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=400, detail="用户已被禁用")

4.2 用户管理API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
用户管理API
"""
import logging
from typing import List

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)

# All user-management routes are mounted under /users and grouped under one
# tag in the generated docs.
router = APIRouter(tags=["用户管理"], prefix="/users")

@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(user: UserCreate, db: Session = Depends(get_db)):
    """Register a new user.

    Returns:
        The created ``User`` (serialized through ``UserResponse``), HTTP 201.

    Raises:
        HTTPException: 400 if the username or email is already taken;
            500 on unexpected errors, after rolling back the transaction.
    """
    # NOTE(review): this is an ``async def`` that does blocking work (sync
    # SQLAlchemy queries, bcrypt hashing) on the event loop thread.
    try:
        # Reject duplicates up front with a specific message.
        if db.query(User).filter(User.username == user.username).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )
        if db.query(User).filter(User.email == user.email).first():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="邮箱已存在"
            )

        db_user = User(
            username=user.username,
            email=user.email,
            hashed_password=get_password_hash(user.password),
            is_active=user.is_active
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)

        logger.info(f"用户创建成功: {user.username}")
        return db_user

    except HTTPException:
        # Deliberate client errors must reach the caller unchanged; without
        # this clause the generic handler below turned the 400s into 500s.
        raise
    except Exception as e:
        logger.error(f"创建用户失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="创建用户失败"
        )

@router.get("/", response_model=List[UserResponse])
async def get_users(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(100, ge=1, le=100, description="返回的记录数"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """List users, one page at a time (requires an authenticated, active caller).

    Args:
        skip: Number of rows to skip (offset), >= 0.
        limit: Maximum number of rows to return, 1-100. The bounds match the
            product list endpoints. Without them a huge limit could dump the
            whole table, and SQLite treats a negative LIMIT as "no limit".
        current_user: Unused in the body; declaring it enforces authentication.

    Raises:
        HTTPException: 500 on unexpected database errors.
    """
    try:
        return db.query(User).offset(skip).limit(limit).all()
    except Exception as e:
        logger.error(f"获取用户列表失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取用户列表失败"
        )

@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Return one user by id (requires an authenticated, active caller).

    Raises:
        HTTPException: 404 if no user has *user_id*; 500 on unexpected errors.
    """
    try:
        found = db.query(User).filter(User.id == user_id).first()
    except Exception as e:
        logger.error(f"获取用户详情失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取用户详情失败"
        )

    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在"
        )
    return found

@router.put("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: int,
    user_update: UserUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Apply a partial update to a user.

    Only fields present in the request body are written. A changed username or
    email is first checked for conflicts, mirroring ``create_user``.

    NOTE(review): any authenticated, active user may update any account; there
    is no ownership or role check here.

    Raises:
        HTTPException: 404 if the user does not exist; 400 if the new username
            or email is already taken; 500 on unexpected errors (rolled back).
    """
    try:
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        update_data = user_update.dict(exclude_unset=True)

        # Without these checks a clash would only surface as an IntegrityError
        # from the unique constraints at commit time, reported as a bare 500.
        new_username = update_data.get("username")
        if new_username is not None and new_username != user.username:
            if db.query(User).filter(User.username == new_username).first():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="用户名已存在"
                )

        new_email = update_data.get("email")
        if new_email is not None and new_email != user.email:
            if db.query(User).filter(User.email == new_email).first():
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="邮箱已存在"
                )

        for field, value in update_data.items():
            setattr(user, field, value)

        db.commit()
        db.refresh(user)

        logger.info(f"用户更新成功: {user.username}")
        return user

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新用户失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="更新用户失败"
        )

@router.delete("/{user_id}")
async def delete_user(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Permanently delete a user by id (hard delete).

    NOTE(review): any authenticated, active user may delete any account.

    Raises:
        HTTPException: 404 if the user does not exist; 500 on unexpected errors,
            after rolling back.
    """
    try:
        target = db.query(User).filter(User.id == user_id).first()
        if target is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        db.delete(target)
        db.commit()

        logger.info(f"用户删除成功: {target.username}")
        return {"message": "用户删除成功"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除用户失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除用户失败"
        )

@router.post("/token", response_model=Token)
async def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    """Exchange OAuth2 password-form credentials for a bearer token.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 for wrong credentials; 500 on unexpected errors.
    """
    try:
        account = authenticate_user(db, form_data.username, form_data.password)
        if not account:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"},
            )

        token = create_access_token(
            data={"sub": account.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )

        logger.info(f"用户登录成功: {account.username}")
        return {"access_token": token, "token_type": "bearer"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"用户登录失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败"
        )

5. 产品管理API

5.1 产品管理API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
产品管理API
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)

# All product routes are mounted under /products and grouped under one tag in
# the generated docs.
product_router = APIRouter(tags=["产品管理"], prefix="/products")

@product_router.post("/", response_model=ProductResponse, status_code=status.HTTP_201_CREATED)
async def create_product(
    product: ProductCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Create a product (requires an authenticated, active caller).

    Returns:
        The stored ``Product`` (serialized via ``ProductResponse``), HTTP 201.

    Raises:
        HTTPException: 500 on unexpected errors, after rolling back.
    """
    columns = dict(
        name=product.name,
        description=product.description,
        price=product.price,
        stock=product.stock,
    )
    try:
        new_product = Product(**columns)
        db.add(new_product)
        db.commit()
        db.refresh(new_product)

        logger.info(f"产品创建成功: {product.name}")
        return new_product

    except Exception as e:
        logger.error(f"创建产品失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="创建产品失败"
        )

@product_router.get("/", response_model=List[ProductResponse])
async def get_products(
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(10, ge=1, le=100, description="返回的记录数"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    db: Session = Depends(get_db)
):
    """List active (not soft-deleted) products, optionally filtered by name.

    Args:
        skip: Number of rows to skip (offset).
        limit: Maximum number of rows to return (1-100).
        search: If non-empty, keep only products whose name contains this
            text. ``%`` and ``_`` are matched literally, not as SQL wildcards.

    Raises:
        HTTPException: 500 on unexpected database errors.
    """
    try:
        # ``== True`` builds a SQL comparison here; it is not a Python truth test.
        query = db.query(Product).filter(Product.is_active == True)

        if search:
            # autoescape=True escapes LIKE metacharacters in user input, so a
            # search for "50%" matches that literal text instead of behaving
            # like a wildcard search for "50".
            query = query.filter(Product.name.contains(search, autoescape=True))

        return query.offset(skip).limit(limit).all()

    except Exception as e:
        logger.error(f"获取产品列表失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取产品列表失败"
        )

@product_router.get("/{product_id}", response_model=ProductResponse)
async def get_product(
    product_id: int,
    db: Session = Depends(get_db)
):
    """Return one active product by id; soft-deleted products count as missing.

    Raises:
        HTTPException: 404 if there is no active product with *product_id*;
            500 on unexpected errors.
    """
    try:
        match = (
            db.query(Product)
            .filter(Product.id == product_id, Product.is_active == True)
            .first()
        )
    except Exception as e:
        logger.error(f"获取产品详情失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取产品详情失败"
        )

    if match is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="产品不存在"
        )
    return match

@product_router.put("/{product_id}", response_model=ProductResponse)
async def update_product(
    product_id: int,
    product_update: ProductUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Apply a partial update to a product; only supplied fields are written.

    NOTE(review): unlike get_product, this does not filter on ``is_active``, so
    soft-deleted products can still be updated — confirm that is intended.

    Raises:
        HTTPException: 404 if the product does not exist; 500 on unexpected
            errors, after rolling back.
    """
    try:
        target = db.query(Product).filter(Product.id == product_id).first()
        if target is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="产品不存在"
            )

        changes = product_update.dict(exclude_unset=True)
        for attr_name, new_value in changes.items():
            setattr(target, attr_name, new_value)

        db.commit()
        db.refresh(target)

        logger.info(f"产品更新成功: {target.name}")
        return target

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新产品失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="更新产品失败"
        )

@product_router.delete("/{product_id}")
async def delete_product(
    product_id: int,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Soft-delete a product by setting ``is_active`` to False.

    The row is kept; list/detail endpoints that filter on ``is_active`` stop
    returning it.

    Raises:
        HTTPException: 404 if the product does not exist; 500 on unexpected
            errors, after rolling back.
    """
    try:
        target = db.query(Product).filter(Product.id == product_id).first()
        if target is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="产品不存在"
            )

        target.is_active = False
        db.commit()

        logger.info(f"产品删除成功: {target.name}")
        return {"message": "产品删除成功"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除产品失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除产品失败"
        )

@product_router.get("/search/", response_model=List[ProductResponse])
async def search_products(
    q: str = Query(..., min_length=2, description="搜索关键词"),
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=100),
    db: Session = Depends(get_db)
):
    """Search active products whose name contains *q* (matched literally).

    NOTE(review): this route is registered after ``/{product_id}``. A request
    to ``/products/search`` (no trailing slash) matches that route first and
    fails integer validation with 422; only ``/products/search/`` reaches here.
    Registering this route before ``/{product_id}`` would avoid that.

    Raises:
        HTTPException: 500 on unexpected database errors.
    """
    try:
        return db.query(Product).filter(
            Product.is_active == True,
            # autoescape=True: "%" and "_" in the user's query are literal.
            Product.name.contains(q, autoescape=True)
        ).offset(skip).limit(limit).all()

    except Exception as e:
        logger.error(f"搜索产品失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="搜索产品失败"
        )

6. 中间件和异常处理

6.1 自定义中间件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
自定义中间件
"""
from fastapi import Request, Response
from fastapi.responses import JSONResponse
import time
import logging
import uuid

logger = logging.getLogger(__name__)

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Tag each request with a unique ID and report how long it took.

    Sets ``request.state.request_id`` (read by the exception handlers) and adds
    ``X-Request-ID`` and ``X-Process-Time`` (elapsed seconds as a float string)
    to the response. Start and end of each request are logged.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump when it is adjusted, producing negative or
    # inflated durations.
    start_time = time.perf_counter()

    request_id = str(uuid.uuid4())
    request.state.request_id = request_id

    logger.info(f"请求开始: {request.method} {request.url} - ID: {request_id}")

    response = await call_next(request)

    process_time = time.perf_counter() - start_time

    response.headers["X-Process-Time"] = str(process_time)
    response.headers["X-Request-ID"] = request_id

    logger.info(f"请求结束: {request.method} {request.url} - ID: {request_id} - 耗时: {process_time:.3f}s")

    return response

@app.middleware("http")
async def add_cors_header(request: Request, call_next):
    """Give responses permissive CORS headers without overriding existing ones.

    ``setdefault`` keeps any values already set further down the stack (e.g.
    by ``CORSMiddleware``: an echoed ``Origin`` for credentialed requests, or
    the headers of a preflight response). Only responses that lack a header
    receive the wildcard default. Unconditional assignment, as before,
    replaced those values with ``*``, which browsers refuse for credentialed
    requests.
    """
    response = await call_next(request)
    response.headers.setdefault("Access-Control-Allow-Origin", "*")
    response.headers.setdefault("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
    response.headers.setdefault("Access-Control-Allow-Headers", "*")
    return response

# Global exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"error": {code, message, request_id}}``.

    Headers attached to the exception (e.g. ``WWW-Authenticate: Bearer`` on the
    401s raised by the auth dependencies) are forwarded to the response;
    dropping them breaks the OAuth2 bearer challenge.

    NOTE(review): this is registered for FastAPI's ``HTTPException``;
    Starlette's own errors (e.g. 404 for unknown routes) may keep the default
    format — verify if a uniform error body is required.
    """
    request_id = getattr(request.state, 'request_id', 'unknown')

    logger.error(f"HTTP异常: {exc.status_code} - {exc.detail} - ID: {request_id}")

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "code": exc.status_code,
                "message": exc.detail,
                "request_id": request_id
            }
        },
        headers=getattr(exc, "headers", None),
    )

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception with its traceback and return a generic 500."""
    rid = getattr(request.state, 'request_id', 'unknown')
    logger.error(f"未处理异常: {str(exc)} - ID: {rid}", exc_info=True)

    body = {
        "error": {
            "code": 500,
            "message": "内部服务器错误",
            "request_id": rid,
        }
    }
    return JSONResponse(status_code=500, content=body)

7. 性能优化

7.1 缓存和异步优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
性能优化
"""
import redis
import asyncio
from functools import wraps
from typing import Any, Callable
import json
import logging

logger = logging.getLogger(__name__)

# Redis connection.
# NOTE(review): ``redis.Redis`` is the synchronous client; calls made from
# async functions block the event loop for each round-trip. Host/port are also
# hard-coded instead of coming from Settings.redis_url. Consider
# ``redis.asyncio`` and configuration-driven connection settings.
redis_client = redis.Redis(host='localhost', port=6379, db=0)

def cache_result(expire_time: int = 300):
    """Cache an async function's JSON-serializable result in Redis.

    Args:
        expire_time: Time-to-live of each cache entry, in seconds.

    Returns:
        A decorator for ``async def`` functions. The cache key combines the
        function's module and qualified name with a SHA-256 digest of the
        ``repr`` of its arguments. The key is therefore the same in every
        process and after restarts, and keyword-argument order does not matter.

    Redis problems are logged and never break the call. On a read error or an
    undecodable cached value, the function runs normally. On a write error,
    its result is still returned. The wrapped function runs at most once per
    call, and its own exceptions propagate unchanged.

    Note: values round-trip through JSON with ``default=str``, so a cache hit
    can return different types than a miss (e.g. datetimes become strings).
    """
    import hashlib

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Built-in hash() of a str is salted per process (PYTHONHASHSEED),
            # so keys derived from it are not shared across workers or restarts.
            raw_key = repr((args, sorted(kwargs.items())))
            digest = hashlib.sha256(raw_key.encode("utf-8")).hexdigest()
            cache_key = f"{func.__module__}.{func.__qualname__}:{digest}"

            try:
                cached_result = redis_client.get(cache_key)
            except Exception as e:
                logger.error(f"缓存操作失败: {str(e)}")
                cached_result = None

            if cached_result is not None:
                try:
                    value = json.loads(cached_result)
                except ValueError as e:
                    # Corrupt entry: fall through and recompute.
                    logger.error(f"缓存操作失败: {str(e)}")
                else:
                    logger.info(f"缓存命中: {cache_key}")
                    return value

            # Call the real function outside any try. Its errors must reach
            # the caller, and it must not be re-run: re-running it from an
            # except branch would repeat its side effects.
            result = await func(*args, **kwargs)

            try:
                redis_client.setex(
                    cache_key,
                    expire_time,
                    json.dumps(result, default=str)
                )
                logger.info(f"缓存存储: {cache_key}")
            except Exception as e:
                logger.error(f"缓存操作失败: {str(e)}")

            return result

        return wrapper
    return decorator

@cache_result(expire_time=600)
async def get_cached_products(skip: int, limit: int, search: str = None):
    """Return a simulated page of products, cached for 10 minutes.

    Builds ``limit`` fake products with ids ``skip`` .. ``skip + limit - 1``.
    If *search* is non-empty, only products whose name contains it are kept.
    """
    await asyncio.sleep(0.1)  # stand-in for database latency

    page = []
    for product_id in range(skip, skip + limit):
        page.append({
            "id": product_id,
            "name": f"产品{product_id}",
            "description": f"产品{product_id}的描述",
            "price": 1000 + product_id * 100,
            "stock": 50 - product_id,
        })

    if not search:
        return page
    return [item for item in page if search in item["name"]]

# Async task queue
class TaskQueue:
    """Collects ``asyncio`` tasks so they can be awaited together.

    NOTE(review): the module-level ``task_queue`` below is shared by every
    caller; concurrent callers of ``wait_all`` would await (and collect) each
    other's tasks. A per-request instance avoids that.
    """

    def __init__(self):
        self.tasks = []

    async def add_task(self, func: Callable, *args, **kwargs):
        """Schedule ``func(*args, **kwargs)`` (a coroutine function) as a task.

        Returns:
            The created ``asyncio.Task``, which is also recorded as pending.
        """
        task = asyncio.create_task(func(*args, **kwargs))
        self.tasks.append(task)
        return task

    async def wait_all(self):
        """Await all pending tasks and return their results in scheduling order.

        The pending list is detached before awaiting, so the queue is empty
        afterwards even if a task raises. In that case the first exception
        propagates, as with ``asyncio.gather``. Clearing only after a
        successful gather left finished tasks in the list, and every later
        call re-raised the same error.

        Returns:
            A list of task results; empty if nothing was pending.
        """
        pending, self.tasks = self.tasks, []
        if not pending:
            return []
        return await asyncio.gather(*pending)

# Global task queue
task_queue = TaskQueue()

@app.post("/products/batch")
async def create_products_batch(
    products: List[ProductCreate],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_active_user)
):
    """Create several products in one transaction (all or nothing).

    All rows are added to the request's session and committed once; any
    failure rolls the entire batch back.

    Returns:
        ``{"message": ...}`` stating how many products were created.

    Raises:
        HTTPException: 500 if the batch could not be stored.
    """
    # Earlier versions scheduled one task per product on the module-level
    # queue, all sharing this request's Session, which SQLAlchemy documents as
    # unsafe for concurrent use. Each row was committed separately, so a
    # failure left a partial batch, and the shared queue could pick up tasks
    # from concurrent requests. A single add_all + commit is atomic and simpler.
    try:
        db.add_all([
            Product(
                name=item.name,
                description=item.description,
                price=item.price,
                stock=item.stock,
            )
            for item in products
        ])
        db.commit()

        logger.info(f"批量创建产品成功: {len(products)}个")
        return {"message": f"成功创建{len(products)}个产品"}

    except Exception as e:
        logger.error(f"批量创建产品失败: {str(e)}")
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="批量创建产品失败"
        )

async def create_single_product(product_data: ProductCreate, db: Session):
    """Insert one product, commit right away, and return the refreshed row."""
    new_row = Product(
        name=product_data.name,
        description=product_data.description,
        price=product_data.price,
        stock=product_data.stock,
    )

    db.add(new_row)
    db.commit()
    db.refresh(new_row)
    return new_row

8. 测试和部署

8.1 单元测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
单元测试
"""
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from main import app, get_db
from database import Base

# Test database: a separate SQLite file used only by these tests.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# The application's create_all only targets its own engine, so the schema must
# also be created on the test engine. Dropping first gives every run a clean
# database; a "testuser" left over from a previous run would otherwise make
# test_create_user fail as a duplicate.
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)

def override_get_db():
    """Replacement for ``get_db`` that yields a session on the test database.

    The session is opened before the ``try``. If opening it fails, that error
    propagates as-is instead of being masked by an ``UnboundLocalError`` from
    ``db.close()`` in ``finally``.
    """
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()

# Make every endpoint that depends on get_db use the test session instead.
app.dependency_overrides[get_db] = override_get_db

# In-process test client for issuing requests against the app.
client = TestClient(app)

def test_create_user():
    """Registering a user returns 201 and echoes the public fields plus an id."""
    payload = {
        "username": "testuser",
        "email": "test@example.com",
        "password": "testpassword",
    }

    response = client.post("/users/", json=payload)
    assert response.status_code == 201

    body = response.json()
    for field in ("username", "email"):
        assert body[field] == payload[field]
    assert "id" in body

def test_get_users():
    """Listing users requires a bearer token: register, log in, then list.

    GET /users/ depends on get_current_active_user, so an unauthenticated
    request gets 401. The test obtains a token from /users/token first.
    """
    credentials = {
        "username": "listuser",
        "email": "list@example.com",
        "password": "listpassword",
    }
    # Registration may report a duplicate if the database was not reset
    # between runs; only the login below has to succeed.
    client.post("/users/", json=credentials)
    login = client.post(
        "/users/token",
        data={"username": credentials["username"], "password": credentials["password"]},
    )
    assert login.status_code == 200
    headers = {"Authorization": f"Bearer {login.json()['access_token']}"}

    response = client.get("/users/", headers=headers)
    assert response.status_code == 200

    data = response.json()
    assert isinstance(data, list)

def test_create_product():
    """Creating a product requires a bearer token and returns 201 with an id.

    POST /products/ depends on get_current_active_user, so the test logs in
    before posting.
    """
    credentials = {
        "username": "productuser",
        "email": "product@example.com",
        "password": "productpassword",
    }
    # May report a duplicate if the database persisted from a previous run.
    client.post("/users/", json=credentials)
    login = client.post(
        "/users/token",
        data={"username": credentials["username"], "password": credentials["password"]},
    )
    assert login.status_code == 200
    headers = {"Authorization": f"Bearer {login.json()['access_token']}"}

    product_data = {
        "name": "测试产品",
        "description": "测试产品描述",
        "price": 1000,
        "stock": 10
    }

    response = client.post("/products/", json=product_data, headers=headers)
    assert response.status_code == 201

    data = response.json()
    assert data["name"] == product_data["name"]
    assert data["price"] == product_data["price"]
    assert "id" in data

def test_get_products():
    """The public product list endpoint answers 200 with a JSON array."""
    resp = client.get("/products/")
    assert resp.status_code == 200
    assert isinstance(resp.json(), list)

if __name__ == "__main__":
    # Allow running this file directly: hand this module to pytest.
    pytest.main([__file__])

9. 总结

通过Python FastAPI的实现,我们成功构建了一个高性能的异步API框架。关键特性包括:

9.1 核心优势

  1. 高性能API: 基于Starlette和Pydantic的高性能API框架
  2. 异步编程: 原生支持异步和并发编程
  3. 数据验证: 自动数据验证和序列化
  4. API文档: 自动生成交互式API文档
  5. 类型提示: 基于Python类型提示的API设计

9.2 最佳实践

  1. API设计: RESTful API设计原则
  2. 数据验证: 使用Pydantic进行数据验证
  3. 异步编程: 充分利用异步编程优势
  4. 错误处理: 完善的异常处理机制
  5. 性能优化: 缓存和异步优化策略

这套Python FastAPI方案不仅能够提供高性能的API开发能力,还包含了数据验证、异步编程、性能优化等核心功能,是现代Python Web开发的重要框架。