三、安全优化:让系统更坚固
3.1 安全威胁模型
┌─────────────────────────────────────────────────────────────────┐
│ 常见安全威胁 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ Web层 │
│ ├── SQL注入:恶意SQL语句绕过安全检查 │
│ ├── XSS:注入恶意脚本到网页 │
│ ├── CSRF:伪造用户请求 │
│ └── SSRF:服务器端请求伪造 │
│ │
│ 认证层 │
│ ├── 暴力破解:尝试大量密码组合 │
│ ├── 会话劫持:窃取用户Session │
│ ├── JWT攻击:算法混淆、密钥泄露 │
│ └── OAuth漏洞:授权码拦截 │
│ │
│ 数据层 │
│ ├── 数据泄露:敏感信息暴露 │
│ ├── 权限绕过:越权访问数据 │
│ └── 注入攻击:NoSQL注入、LDAP注入 │
│ │
│ 基础设施层 │
│ ├── DDoS攻击:耗尽系统资源 │
│ ├── 容器逃逸:从容器攻击宿主机 │
│ └── 依赖投毒:恶意依赖包 │
│ │
└─────────────────────────────────────────────────────────────────┘
3.2 输入验证与防御
3.2.1 SQL注入防御
# ❌ DANGEROUS: SQL assembled by string interpolation
def get_user_dangerous(user_id):
    """Anti-pattern kept as a counter-example: never build SQL this way.

    ``user_id`` is pasted into the SQL text, so an input such as
    ``"1 OR 1=1"`` changes the meaning of the query (SQL injection).
    """
    return db.execute(f"SELECT * FROM users WHERE id = {user_id}")  # injectable
# ✅ SAFE: parameterized query
def get_user_safe(user_id):
    """Fetch a user by id, passing the id to the driver as a bound parameter.

    The SQL text never contains user data; the driver sends ``user_id``
    separately. ``%s`` is the DB-API "format" paramstyle
    (NOTE(review): assumes the driver behind ``db`` uses that paramstyle).
    """
    bound_params = (user_id,)
    return db.execute("SELECT * FROM users WHERE id = %s", bound_params)
# ✅ SQLAlchemy textual SQL with a named bind parameter (safe, though not the ORM itself)
from sqlalchemy import text
def get_user_orm(user_id):
    """Fetch the first matching user row using a SQLAlchemy ``text()`` statement.

    ``:id`` is a named bind parameter; its value is supplied separately in the
    mapping, so user input never becomes part of the SQL text.
    Returns the first row, or ``None`` when there is no match.
    """
    statement = text("SELECT * FROM users WHERE id = :id")
    result = db.execute(statement, {"id": user_id})
    return result.fetchone()
# Input filtering (defense in depth only; NOT a substitute for parameterized queries)
import re
def sanitize_input(input_str):
    """Strip SQL comments and statement separators, and double single quotes.

    This is a blacklist filter and can be bypassed; always use parameterized
    queries. Do not pass the result into a parameterized query either: the
    doubled quotes would then be stored literally (``O'Brien`` -> ``O''Brien``).

    Args:
        input_str: Raw user-supplied string.

    Returns:
        The filtered string.
    """
    # Block comments may span lines. Replace them with a space (SQL treats a
    # comment as whitespace) so removing one cannot glue two tokens together.
    input_str = re.sub(r"/\*.*?\*/", " ", input_str, flags=re.DOTALL)
    # "--" line comments on EVERY line. Without MULTILINE, "$" only anchors at
    # the end of the string, so comments on earlier lines were left in place.
    input_str = re.sub(r"--.*$", "", input_str, flags=re.MULTILINE)
    # Remove semicolons to block stacked statements.
    input_str = input_str.replace(";", "")
    # Escape single quotes SQL-style.
    input_str = input_str.replace("'", "''")
    return input_str
3.2.2 XSS防御
from html import escape
import bleach
# ❌ DANGEROUS: user input emitted into HTML verbatim
def render_comment_dangerous(comment):
    """Anti-pattern kept as a counter-example: no escaping at all.

    A comment like ``<script>alert('xss')</script>`` is rendered as live
    markup, i.e. stored XSS.
    """
    return "<div>{}</div>".format(comment)
# ✅ SAFE: HTML-escape before embedding
def render_comment_safe(comment):
    """Wrap ``comment`` in a ``<div>`` after escaping it with ``html.escape``.

    ``&``, ``<``, ``>`` and both quote characters become entities, so the
    text is displayed rather than interpreted as markup.
    """
    escaped_text = escape(comment)
    return "<div>%s</div>" % escaped_text
# ✅ Allow-list sanitising (permits a small subset of HTML)
ALLOWED_TAGS = ['b', 'i', 'em', 'strong', 'a', 'p', 'br']
# Only <a> may carry attributes, and only these two.
ALLOWED_ATTRIBUTES = {'a': ['href', 'title']}
def render_comment_whitelist(comment):
    """Sanitise ``comment`` with bleach, keeping only the allow-listed tags/attributes.

    NOTE(review): bleach is deprecated upstream; consider migrating to nh3.
    """
    sanitized = bleach.clean(
        comment,
        tags=ALLOWED_TAGS,
        attributes=ALLOWED_ATTRIBUTES,
    )
    return sanitized
# Content-Security-Policy response header
def add_csp_headers(response, script_nonce=None):
    """Attach a Content-Security-Policy header to ``response`` and return it.

    ``script-src`` deliberately omits ``'unsafe-inline'``: allowing inline
    scripts lets any injected ``<script>`` run and removes most of CSP's XSS
    protection. Inline scripts that are genuinely needed should carry a
    per-response nonce.

    Args:
        response: Object with a mutable ``headers`` mapping.
        script_nonce: Optional per-response random nonce. When given,
            ``'nonce-<value>'`` is added to ``script-src`` so inline scripts
            tagged with that nonce may execute.

    Returns:
        The same ``response`` object.
    """
    script_src = "script-src 'self' https://trusted.cdn.com"
    if script_nonce:
        script_src += f" 'nonce-{script_nonce}'"
    response.headers['Content-Security-Policy'] = (
        "default-src 'self'; "
        f"{script_src}; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self'; "
        "connect-src 'self' https://api.example.com; "
        # Plugins (<object>/<embed>) are a classic CSP bypass; disable them outright.
        "object-src 'none'; "
        "frame-ancestors 'none'; "
        "base-uri 'self'; "
        "form-action 'self'"
    )
    return response
3.2.3 认证与授权
# JWT安全实践
import jwt
from datetime import datetime, timedelta
class JWTManager:
    """Issue and verify signed access/refresh JWTs, with optional revocation.

    Args:
        secret_key: Key passed to PyJWT for signing and verification.
        algorithm: JWT algorithm; ``"none"`` is rejected because it disables
            signature checks.
        access_ttl: Access-token lifetime in seconds (default 900 = 15 minutes).
        refresh_ttl: Refresh-token lifetime in seconds (default 86400 = 24 hours).
        redis_client: Optional client exposing ``setex``/``exists``, used as the
            revocation list. Without it revocation is disabled:
            ``is_token_revoked`` always returns False and ``revoke_token`` raises.
    """

    def __init__(self, secret_key, algorithm="HS256", access_ttl=900, refresh_ttl=86400,
                 redis_client=None):
        if str(algorithm).lower() == "none":
            raise ValueError("unsigned JWTs (alg=none) are not allowed")
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_ttl = access_ttl
        self.refresh_ttl = refresh_ttl
        self.redis = redis_client

    @staticmethod
    def _now():
        """Return the current time as a timezone-aware UTC datetime."""
        from datetime import timezone
        return datetime.now(timezone.utc)

    @staticmethod
    def _new_jti():
        """Return a fresh random token id (used as the revocation key)."""
        import uuid
        return str(uuid.uuid4())

    def create_access_token(self, user_id, roles=None):
        """Create a short-lived access token carrying the user's roles."""
        now = self._now()
        payload = {
            "sub": str(user_id),
            "type": "access",
            "roles": roles or [],
            "iat": now,
            "exp": now + timedelta(seconds=self.access_ttl),
            # A unique id lets an individual token be revoked. On its own it
            # does not prevent replay of a stolen, still-valid token.
            "jti": self._new_jti(),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def create_refresh_token(self, user_id):
        """Create a long-lived refresh token (no roles)."""
        now = self._now()
        payload = {
            "sub": str(user_id),
            "type": "refresh",
            "iat": now,
            "exp": now + timedelta(seconds=self.refresh_ttl),
            "jti": self._new_jti(),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token, expected_type="access"):
        """Decode and validate ``token``; return its payload.

        Raises:
            TokenExpiredError: the ``exp`` claim has passed.
            InvalidTokenError: bad signature, malformed token or missing claims.
            InvalidTokenTypeError: ``type`` claim differs from ``expected_type``.
            TokenRevokedError: the token's ``jti`` is on the revocation list.
        """
        # Pinning `algorithms` to the configured one blocks algorithm-confusion
        # attacks. Only the decode is wrapped, so the errors raised below are
        # never re-wrapped by these handlers.
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"require": ["exp", "iat", "jti", "type"]},
            )
        except jwt.ExpiredSignatureError as exc:
            raise TokenExpiredError() from exc
        except jwt.InvalidTokenError as exc:
            raise InvalidTokenError(str(exc)) from exc
        # Stop a refresh token being used as an access token, and vice versa.
        if payload.get("type") != expected_type:
            raise InvalidTokenTypeError()
        if self.is_token_revoked(payload["jti"]):
            raise TokenRevokedError()
        return payload

    def is_token_revoked(self, jti):
        """Return True if ``jti`` is on the revocation list (False when none is configured)."""
        if self.redis is None:
            return False
        return bool(self.redis.exists(f"revoked:{jti}"))

    def revoke_token(self, jti, ttl=None):
        """Put ``jti`` on the revocation list.

        Args:
            jti: Token id to revoke.
            ttl: Seconds to keep the entry. Defaults to the longest token
                lifetime, so a revoked refresh token cannot become valid again
                before it expires.

        Raises:
            RuntimeError: no ``redis_client`` was configured.
        """
        if self.redis is None:
            raise RuntimeError("token revocation requires a redis_client")
        if ttl is None:
            ttl = max(self.access_ttl, self.refresh_ttl)
        self.redis.setex(f"revoked:{jti}", ttl, "1")
# Role-based permission decorator
def require_roles(*required_roles):
    """Decorator factory: allow the call only if the current user has ANY of ``required_roles``.

    The current user comes from ``get_current_user()``, expected to return a
    mapping with an optional ``"roles"`` list (NOTE(review): confirm against its
    definition). A missing user or missing/None roles is treated as having no
    roles, so access is denied rather than crashing.

    Raises:
        PermissionDeniedError: the user holds none of the required roles.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            user = get_current_user()
            user_roles = (user or {}).get("roles") or []
            if not any(role in user_roles for role in required_roles):
                raise PermissionDeniedError(f"Required roles: {required_roles}")
            return func(*args, **kwargs)
        return wrapper
    return decorator
# Usage example
@require_roles("admin", "superuser")
def delete_user(user_id):
    """Delete a user; the caller must hold the "admin" or "superuser" role.

    Business logic intentionally omitted in this example.
    """
    ...
3.3 敏感数据保护
3.3.1 数据加密
from cryptography.fernet import Fernet
import base64
import hashlib
class DataEncryption:
    """Encrypt/decrypt sensitive values with Fernet, keyed from a master secret via PBKDF2.

    Args:
        master_key: Master secret (str or bytes) from which the Fernet key is derived.
        salt: KDF salt (str or bytes). Defaults to the legacy fixed salt so keys
            stay identical to earlier configurations. New deployments should
            pass a random, per-deployment salt and store it with the data.
        iterations: PBKDF2-HMAC-SHA256 iteration count. NOTE(review): current
            guidance recommends far more than 100000; raising it changes the
            derived key, so existing ciphertext would need re-encryption.
    """

    DEFAULT_SALT = b'salt_should_be_random'

    def __init__(self, master_key, salt=None, iterations=100000):
        if isinstance(master_key, str):
            master_key = master_key.encode()
        if salt is None:
            salt = self.DEFAULT_SALT
        elif isinstance(salt, str):
            salt = salt.encode()
        # Stdlib PBKDF2-HMAC-SHA256 -> 32 raw bytes. Fernet expects them
        # url-safe base64 encoded.
        raw_key = hashlib.pbkdf2_hmac("sha256", master_key, salt, iterations, dklen=32)
        self.cipher = Fernet(base64.urlsafe_b64encode(raw_key))

    def encrypt_sensitive(self, data):
        """Encrypt ``data`` (str or bytes) and return the token as a str."""
        if isinstance(data, str):
            data = data.encode()
        return self.cipher.encrypt(data).decode()

    def decrypt_sensitive(self, encrypted_data):
        """Decrypt a token (str or bytes) produced by ``encrypt_sensitive`` and return a str."""
        if isinstance(encrypted_data, str):
            encrypted_data = encrypted_data.encode()
        return self.cipher.decrypt(encrypted_data).decode()
# Field-level encryption
class EncryptedField:
    """Descriptor that keeps only an encrypted copy of a str attribute.

    Assigning a str stores ``cipher.encrypt(value.encode())`` (bytes) in the
    instance ``__dict__`` under the attribute's name. Reading decrypts it back
    to str. ``None`` is stored and returned as-is.
    """

    def __init__(self, encryption_key):
        # NOTE(review): presumably a Fernet key (url-safe base64, 32 bytes).
        self.cipher = Fernet(encryption_key)

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner):
        if instance is None:
            # Class-level access returns the descriptor itself.
            return self
        token = instance.__dict__.get(self.name)
        return None if token is None else self.cipher.decrypt(token).decode()

    def __set__(self, instance, value):
        instance.__dict__[self.name] = (
            None if value is None else self.cipher.encrypt(value.encode())
        )
# Example model
class User:
    """Example model whose ``ssn`` and ``credit_card`` are stored encrypted.

    Both attributes are ``EncryptedField`` descriptors, so assigning them keeps
    only ciphertext on the instance; reading them returns plaintext.
    """

    # NOTE(review): `encryption_key` must exist at module level before this
    # class body runs; presumably a Fernet key. Confirm where it is defined.
    ssn = EncryptedField(encryption_key)  # social security number
    credit_card = EncryptedField(encryption_key)

    def __init__(self, name, ssn, credit_card):
        self.name = name
        # Both assignments go through the descriptors -> stored encrypted.
        self.ssn, self.credit_card = ssn, credit_card
# Ciphertext as it would appear in the database:
# INSERT INTO users (name, ssn) VALUES ('Alice', 'gAAAAAB...')
3.3.2 敏感信息脱敏
class DataMasking:
    """Mask personal data for display and logging.

    Falsy input (``None``, ``""``) is returned unchanged. Values shorter than
    the expected format are still masked (fail closed) instead of being
    returned verbatim, so unexpected formats never leak the raw value.
    """

    @staticmethod
    def mask_phone(phone):
        """Phone masking: ``13812341234`` -> ``138****1234``.

        Values of 11+ characters keep the first 3 and last 4 characters;
        shorter values are masked completely.
        """
        if not phone:
            return phone
        if len(phone) < 11:
            return "*" * len(phone)
        return phone[:3] + "****" + phone[-4:]

    @staticmethod
    def mask_email(email):
        """Email masking: ``alice@example.com`` -> ``a***e@example.com``.

        Local parts of 1-2 characters are masked completely. Values without
        ``@`` are returned unchanged.
        """
        if not email or "@" not in email:
            return email
        local, domain = email.split("@", 1)
        if len(local) <= 2:
            masked_local = "*" * len(local)
        else:
            masked_local = local[0] + "***" + local[-1]
        return f"{masked_local}@{domain}"

    @staticmethod
    def mask_id_card(id_card):
        """ID-card masking: ``330102199001011234`` -> ``3301**********1234``.

        - 18+ characters: first 4 and last 4 are kept, with 10 asterisks between.
        - Shorter values longer than 8 characters (e.g. legacy 15-digit IDs):
          the whole middle is masked.
        - 8 characters or fewer: fully masked.
        """
        if not id_card:
            return id_card
        n = len(id_card)
        if n >= 18:
            return id_card[:4] + "*" * 10 + id_card[-4:]
        if n > 8:
            return id_card[:4] + "*" * (n - 8) + id_card[-4:]
        return "*" * n

    @staticmethod
    def mask_bank_card(card_no):
        """Bank-card masking: ``6222021234567890`` -> ``**** **** **** 7890``.

        Any value longer than 4 characters shows only its last 4 characters;
        shorter values are fully masked.
        """
        if not card_no:
            return card_no
        if len(card_no) <= 4:
            return "*" * len(card_no)
        return "**** **** **** " + card_no[-4:]

    @staticmethod
    def mask_json(data, sensitive_fields):
        """Recursively replace values of keys in ``sensitive_fields`` with ``"******"``.

        Walks nested dicts and lists; other values are returned as-is. Key
        matching is exact (case-sensitive).
        """
        if isinstance(data, dict):
            result = {}
            for key, value in data.items():
                if key in sensitive_fields:
                    result[key] = "******"
                else:
                    result[key] = DataMasking.mask_json(value, sensitive_fields)
            return result
        elif isinstance(data, list):
            return [DataMasking.mask_json(item, sensitive_fields) for item in data]
        else:
            return data
# Log redaction
class SensitiveDataFilter(logging.Filter):
    """Logging filter that redacts credentials in the formatted log message.

    Each pattern in ``SENSITIVE_PATTERNS`` is applied case-insensitively to
    ``record.getMessage()``. Only the message text is redacted; exception
    tracebacks and ``extra`` fields are left untouched. Always returns True,
    so records are modified but never dropped.
    """

    SENSITIVE_PATTERNS = [
        (r'password["\']?\s*[:=]\s*["\']?([^"\'\s]+)', r'password="***"'),
        (r'token["\']?\s*[:=]\s*["\']?([^"\'\s]+)', r'token="***"'),
        (r'api_key["\']?\s*[:=]\s*["\']?([^"\'\s]+)', r'api_key="***"'),
        (r'credit_card["\']?\s*[:=]\s*["\']?(\d{13,19})', r'credit_card="***"'),
    ]

    def filter(self, record):
        # One record can pass through this filter several times (logger filter
        # plus one per handler). The replacements are not idempotent
        # (re-masking appends stray quotes), so redact each record only once.
        if getattr(record, "_sensitive_redacted", False):
            return True
        msg = record.getMessage()
        for pattern, replacement in self.SENSITIVE_PATTERNS:
            msg = re.sub(pattern, replacement, msg, flags=re.IGNORECASE)
        record.msg = msg
        # `msg` is already fully formatted. Clear the args so handlers do not
        # apply %-formatting again, which would raise on leftover arguments.
        record.args = ()
        record._sensitive_redacted = True
        return True
def _install_sensitive_data_filter():
    """Attach one SensitiveDataFilter to the root logger and its current handlers.

    A filter on a logger only sees records logged directly on that logger, not
    records propagated from child loggers, so the handlers need it too.
    Handlers added to the root logger later must be given the filter as well.
    """
    root = logging.getLogger()
    redactor = SensitiveDataFilter()
    root.addFilter(redactor)
    for handler in root.handlers:
        handler.addFilter(redactor)
_install_sensitive_data_filter()
3.4 API安全
# API security best practices
# 1. Request signature verification
class APISignature:
    """HMAC-SHA256 request signing and verification.

    The signed string is ``METHOD\\nPATH\\nsorted k=v pairs joined by &\\nTIMESTAMP``.

    NOTE(review) on known limitations of this scheme (left unchanged to stay
    compatible with existing clients):

    - The request body is not signed.
    - Parameter values are not URL-encoded, so ``{"a": "1&b=2"}`` and
      ``{"a": "1", "b": "2"}`` produce the same signature.
    - Without a nonce store, a request can be replayed within the skew window.

    Args:
        secret_key: Shared secret (str or bytes).
        max_skew: Maximum allowed clock difference in seconds (default 300).
    """

    def __init__(self, secret_key, max_skew=300):
        self.secret_key = secret_key
        self.max_skew = max_skew

    def generate_signature(self, method, path, params, timestamp):
        """Return the hex HMAC-SHA256 signature for the given request parts."""
        import hashlib
        import hmac

        sorted_params = sorted(params.items()) if params else []
        param_str = "&".join(f"{k}={v}" for k, v in sorted_params)
        sign_str = f"{method}\n{path}\n{param_str}\n{timestamp}"
        key = self.secret_key if isinstance(self.secret_key, bytes) else self.secret_key.encode()
        return hmac.new(key, sign_str.encode(), hashlib.sha256).hexdigest()

    def verify_signature(self, request):
        """Return True only if the request carries a fresh, valid signature.

        Reads the ``X-Signature`` and ``X-Timestamp`` headers. A missing
        header, a non-integer timestamp, a timestamp outside ``max_skew``, or a
        mismatched signature all yield False (never an exception).
        """
        import hmac
        import time

        signature = request.headers.get("X-Signature")
        timestamp = request.headers.get("X-Timestamp")
        if not signature or not timestamp:
            return False
        try:
            ts = int(timestamp)
        except ValueError:
            return False
        # Reject stale or future timestamps to limit the replay window.
        if abs(time.time() - ts) > self.max_skew:
            return False
        expected = self.generate_signature(
            request.method,
            request.path,
            request.args.to_dict(),
            timestamp,
        )
        # Constant-time compare on bytes: compare_digest raises TypeError for
        # non-ASCII str input, which an attacker controls via the header.
        return hmac.compare_digest(signature.encode(), expected.encode())
# 2. Rate limiting
from fastapi import Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Limits are keyed by client address. NOTE(review): behind a reverse proxy this
# is the proxy's address unless forwarded headers are handled upstream.
limiter = Limiter(key_func=get_remote_address)
# slowapi's documented setup: expose the limiter on app.state and register its
# RateLimitExceeded handler, which builds the 429 response. The handler was
# imported here but never registered.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@app.post("/api/login")
@limiter.limit("5/minute")  # at most 5 login attempts per minute per client key
async def login(request: Request, credentials: LoginRequest):
    # Placeholder for the login logic. `request` stays in the signature
    # because slowapi's limit decorator requires the endpoint to accept it.
    ...
@app.post("/api/orders")
@limiter.limit("100/minute")  # at most 100 order creations per minute per client key
async def create_order(request: Request, order: OrderRequest):
    # Placeholder for order creation; `request` is needed by slowapi's decorator.
    ...
# 3. Request validation
from pydantic import BaseModel, validator, Field
class CreateUserRequest(BaseModel):
    """Payload for user registration with basic field and password-strength checks."""

    username: str = Field(..., min_length=3, max_length=50)
    # NOTE(review): a regex only checks the rough shape of an address; consider
    # pydantic's EmailStr for stricter validation.
    email: str = Field(..., regex=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
    # NOTE(review): consider a max_length as well, so extremely long inputs
    # cannot make password hashing expensive.
    password: str = Field(..., min_length=8)

    @validator('password')
    def validate_password(cls, v):
        """Require an uppercase letter, a lowercase letter, a digit and a special character.

        Raises:
            ValueError: naming the first missing character class.
        """
        if not re.search(r'[A-Z]', v):
            raise ValueError('Password must contain at least one uppercase letter')
        if not re.search(r'[a-z]', v):
            raise ValueError('Password must contain at least one lowercase letter')
        if not re.search(r'\d', v):
            raise ValueError('Password must contain at least one digit')
        # Any character that is not an ASCII letter, digit or whitespace counts
        # as special. The previous hand-picked list rejected common symbols
        # such as "-", "_", "=", "+", "[" and "/".
        if not re.search(r'[^A-Za-z0-9\s]', v):
            raise ValueError('Password must contain at least one special character')
        return v
# 4. CORS configuration
from fastapi.middleware.cors import CORSMiddleware
# Explicit origin allow-list: browsers refuse a "*" origin on credentialed
# requests, and a wildcard would expose the API to every site anyway.
_cors_settings = {
    "allow_origins": ["https://trusted-domain.com"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE"],
    "allow_headers": ["Authorization", "Content-Type"],
    "expose_headers": ["X-Request-Id"],
    "max_age": 3600,  # seconds a browser may cache the preflight response
}
app.add_middleware(CORSMiddleware, **_cors_settings)