# 1. 添加README说明项目结构
# 2. 配置Python和Node.js的.gitignore
# 3. 包含认证模块和账号管理的前后端基础代码
# 4. 开发计划文档记录当前阶段任务
# 365 lines · 13 KiB · Python
import logging
import os
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from typing import Optional

from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext

from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse
from account_manager import AccountManager as DifyAccountManager
from account_manager import AccountManager
from backend_account_manager import BackendAccountManager
from tenant_manager import TenantManager
from api_user_manager import APIAuthManager
from operation_logger import OperationLogger
from database import get_db_cursor

backend_account_manager = BackendAccountManager()
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# JWT configuration.
# Development-only fallback signing key: anyone who knows it can forge valid
# tokens, so production deployments must set the SECRET_KEY environment variable.
_DEV_SECRET_KEY = "your-secret-key-here"
SECRET_KEY = os.environ.get("SECRET_KEY", _DEV_SECRET_KEY)
if SECRET_KEY == _DEV_SECRET_KEY:
    logger.warning("SECRET_KEY is not set; using the insecure development default")
ALGORITHM = "HS256"
# Access-token lifetime in minutes; override with ACCESS_TOKEN_EXPIRE_MINUTES.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))

# Password hashing (bcrypt via passlib).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 bearer scheme. tokenUrl only feeds the OpenAPI/Swagger "Authorize"
# dialog. It must name the real login route: this app defines no "/token" route.
# A relative URL keeps it working behind a path-prefixing proxy.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/login")
app = FastAPI()

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable. When it is unset, the original allow-all ("*") behaviour is kept.
# When it is set to an empty string, no cross-origin access is allowed.
# NOTE(review): with allow_credentials=True and a "*" origin, Starlette echoes
# the caller's Origin back (preflights and cookie-bearing requests). In effect,
# any website may then make credentialed requests, so set CORS_ALLOW_ORIGINS
# to the real frontend origin(s) in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Module-level service singletons.
# NOTE(review): api_auth is not referenced by any route in this file.
api_auth = APIAuthManager()
op_logger = OperationLogger()
def verify_password(plain_password: str, hashed_password: str):
    """Check a plaintext password against a stored passlib hash.

    Thin wrapper around the module-level ``pwd_context``. The result of
    ``pwd_context.verify`` is returned unchanged.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str):
    """Hash *password* with the module-level ``pwd_context`` (bcrypt)."""
    digest = pwd_context.hash(password)
    return digest
def authenticate_user(username: str, password: str):
    """Authenticate *username*/*password* against the Dify account store.

    NOTE(review): no route in this file calls this helper. /api/auth/login
    authenticates against ``backend_account_manager`` instead.

    Returns:
        The user record as a dict on success. ``False`` if the user does not
        exist, the password does not match, or any error occurs while checking.
    """
    try:
        user = AccountManager.get_user_by_username(username)
        if not user:
            return False

        # Compatibility with test users, which come back as positional tuples
        # (id, username, email, password, password_salt). Normalise them to a
        # dict so a single verification path serves both row shapes.
        if isinstance(user, tuple):
            user = {
                "id": user[0],
                "username": user[1],
                "email": user[2],
                "password": user[3],
                "password_salt": user[4],
            }

        if not AccountManager.verify_password(password, user["password"], user["password_salt"]):
            return False
        return user
    except Exception as e:
        # Best-effort by design: any failure means "not authenticated". Log the
        # traceback anyway so genuine errors (DB down, bad row) stay diagnosable.
        logger.exception("认证失败: %s", e)
        return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is copied,
            not mutated.
        expires_delta: Token lifetime. When omitted (``None``), the token
            expires after 15 minutes.

    Returns:
        The token encoded with ``SECRET_KEY`` / ``ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None so an explicit timedelta(0) is honoured instead of
    # silently falling back to the default (timedelta(0) is falsy).
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Always use an aware UTC timestamp; datetime.utcnow() is naive and
    # deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that resolves the bearer token to the logged-in account.

    Tokens are issued by ``/api/auth/login`` for accounts held by
    ``backend_account_manager``, so the ``sub`` claim is looked up in that same
    store. Resolving it against the Dify ``AccountManager`` caused two problems:
    backend admins without a same-named Dify account were rejected, and a token
    could be mapped onto an unrelated Dify user who shared the username.

    Returns:
        dict with the account's ``id``, ``username`` and ``email``.

    Raises:
        HTTPException: 401 if the token cannot be decoded/verified, carries no
            ``sub`` claim, or names an account that no longer exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username = payload.get("sub")
    if not username:
        raise credentials_exception

    user = backend_account_manager.get_user_by_username(username)
    if not user:
        raise credentials_exception
    # assumes backend records carry "email", like the dict returned by
    # backend_account_manager.create_account — TODO confirm
    return {
        "id": user["id"],
        "username": user["username"],
        "email": user["email"],
    }
# Authentication routes
@app.options("/api/auth/register", include_in_schema=False)
async def auth_register_options():
    """Answer an OPTIONS request for the register endpoint with 204 No Content.

    NOTE(review): Starlette's CORSMiddleware answers real CORS preflights
    (OPTIONS + Access-Control-Request-Method) before routing. In practice this
    handler only sees plain OPTIONS requests.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type",
    }
    return Response(status_code=204, headers=cors_headers)
@app.post("/api/auth/register")
async def auth_register(request: Request):
    """Register a backend admin account.

    Reads ``username``, ``password`` and ``email`` from the submitted form.

    Returns:
        dict with ``user_id`` (as a string), ``username`` and ``email``.

    Raises:
        HTTPException:
            422 "缺少必要参数" when a field is missing, blank, or not a text value.
            400 "用户名已存在" when the username is already taken.
            400 "注册失败" for any other failure.
    """
    form_data = await request.form()

    def _text_field(name: str) -> str:
        # Form values may also be uploaded files; treat those as missing
        # instead of crashing on .strip().
        value = form_data.get(name, "")
        return value if isinstance(value, str) else ""

    username = _text_field("username").strip()
    email = _text_field("email").strip()
    # Keep the password exactly as submitted. /api/auth/login verifies the raw
    # form value, so stripping here would lock out passwords with leading or
    # trailing spaces. Whitespace-only passwords are still rejected below.
    password = _text_field("password")

    if not all([username, password.strip(), email]):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="缺少必要参数"
        )
    try:
        # Reject duplicate usernames
        if backend_account_manager.get_user_by_username(username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )

        # Create the backend admin account
        user = backend_account_manager.create_account(username, email, password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"]
        }
    except HTTPException:
        # Let intentional errors (e.g. "用户名已存在") through unchanged instead
        # of masking them as the generic failure below.
        raise
    except Exception as e:
        logger.exception("注册后台账号失败: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="注册失败"
        ) from e
@app.post("/api/auth/login")
async def auth_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Log a backend account in and return a bearer token (auth route).

    Each attempt is recorded via ``op_logger``:
    - failures as LOGIN_ATTEMPT/FAILED with user_id 0;
    - successes as LOGIN/SUCCESS.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when the user is unknown or the password is wrong.
    """
    peer = request.client
    client_ip = peer.host if peer else "unknown"

    account = backend_account_manager.get_user_by_username(form_data.username)
    credentials_ok = bool(account) and backend_account_manager.verify_password(
        form_data.password, account["password"], account["password_salt"]
    )
    if not credentials_ok:
        op_logger.log_operation(
            user_id=0,
            operation_type="LOGIN_ATTEMPT",
            endpoint="/api/auth/login",
            parameters=f"username={form_data.username}, ip={client_ip}",
            status="FAILED",
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": account["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    op_logger.log_operation(
        user_id=account["id"],
        operation_type="LOGIN",
        endpoint="/api/auth/login",
        parameters=f"ip={client_ip}",
        status="SUCCESS",
    )
    return {"access_token": token, "token_type": "bearer"}
@app.post("/api/user/login")
async def user_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Alias of /api/auth/login (user route); delegates to ``auth_login``."""
    token_response = await auth_login(request, form_data)
    return token_response
@app.post("/api/auth/refresh")
async def refresh_token(current_user: dict = Depends(get_current_user)):
    """Issue a fresh access token for the already-authenticated user.

    The new token uses the standard ACCESS_TOKEN_EXPIRE_MINUTES lifetime.
    """
    new_token = create_access_token(
        data={"sub": current_user["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": new_token, "token_type": "bearer"}
# Account management routes
@app.post("/api/dify_accounts/")
async def create_dify_account(account: AccountCreate, current_user: dict = Depends(get_current_user)):
    """Create a Dify account. Requires a valid bearer token.

    Returns:
        dict with ``user_id`` (as a string), ``username``, ``email`` and
        ``created_at`` taken from the created record.

    Raises:
        HTTPException: 401 without a valid token; 400 if creation fails.
    """
    # The authentication dependency matches the other account routes. Without
    # it, anyone who could reach the API could create Dify accounts.
    try:
        user = AccountManager.create_account(account.username, account.email, account.password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"],
            "created_at": user["created_at"]
        }
    except Exception as e:
        logger.exception("创建Dify账户失败: %s", e)
        raise HTTPException(status_code=400, detail="创建Dify账户失败") from e
@app.get("/api/accounts/search")
async def search_accounts(
    search: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    current_user: dict = Depends(get_current_user)
):
    """Search Dify accounts with pagination. Requires a valid bearer token.

    Args:
        search: Optional search term, passed through to
            ``AccountManager.search_accounts``.
        page: Page number, starting at 1.
        page_size: Number of accounts per page (at least 1).

    Returns:
        ``{"accounts": [...], "total": ...}``. A record lacking ``status`` is
        reported as "active"; one lacking ``created_at`` gets the current UTC time.

    Raises:
        HTTPException: 422 for page/page_size below 1, 400 if the search fails.
    """
    # Reject impossible paging values up front instead of handing the account
    # manager a zero/negative page (presumably turned into a negative offset).
    if page < 1 or page_size < 1:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="分页参数无效"
        )
    try:
        accounts = AccountManager.search_accounts(search, page, page_size)
        return {
            "accounts": [{
                "id": str(a["id"]),
                "username": a["username"],
                "email": a["email"],
                "status": a.get("status", "active"),
                "created_at": a.get("created_at", datetime.now(timezone.utc))
            } for a in accounts["data"]],
            "total": accounts["total"]
        }
    except Exception as e:
        logger.exception("搜索账户失败: %s", e)
        raise HTTPException(status_code=400, detail="搜索账户失败") from e
@app.get("/api/dify_accounts/{username}")
async def get_dify_account(username: str, current_user: dict = Depends(get_current_user)):
    """Look up a Dify account by username. Requires a valid bearer token.

    Returns:
        dict with ``user_id`` (as a string), ``username``, ``email`` and
        ``created_at``.

    Raises:
        HTTPException: 404 "Dify账户不存在" when no such account exists, 400
            "查询Dify账户失败" for any other failure.
    """
    try:
        account = AccountManager.get_user_by_username(username)
        if not account:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Dify账户不存在")
        return {
            "user_id": str(account["id"]),
            "username": account["username"],
            "email": account["email"],
            "created_at": account["created_at"]
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (the 404 above) as-is. Previously
        # the code sniffed "404" in the message text, which also turned any
        # unrelated error mentioning "404" into a not-found response.
        raise
    except Exception as e:
        logger.exception("查询Dify账户失败: %s", e)
        raise HTTPException(status_code=400, detail="查询Dify账户失败") from e
@app.put("/api/accounts/password")
async def change_password(
    password_change: PasswordChange,
    current_user: dict = Depends(get_current_user)
):
    """Change the password of the Dify account named after the current user.

    The caller must supply that account's current password.

    Returns:
        ``{"message": "密码修改成功"}`` on success.

    Raises:
        HTTPException:
            401 "当前密码不正确" when the current password is wrong.
            400 "修改密码失败" when the account is missing or the update fails.
    """
    try:
        user = AccountManager.get_user_by_username(current_user["username"])
        if not user:
            raise HTTPException(status_code=400, detail="修改密码失败")

        if not AccountManager.verify_password(
            password_change.current_password,
            user["password"],
            user["password_salt"]
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="当前密码不正确"
            )

        # Use the email from the Dify record being updated, so the update
        # targets that record whichever store resolved current_user.
        AccountManager.update_password(
            current_user["username"],
            user["email"],
            password_change.new_password
        )
        return {"message": "密码修改成功"}
    except HTTPException:
        # Keep the intentional 401/400 above instead of collapsing them into
        # the generic failure below.
        raise
    except Exception as e:
        logger.exception("修改密码失败: %s", e)
        raise HTTPException(status_code=400, detail="修改密码失败") from e
# Tenant management routes
@app.post("/api/tenants/")
async def create_tenant(tenant: TenantCreate, current_user: dict = Depends(get_current_user)):
    """Create a tenant. Requires a valid bearer token.

    The response echoes the submitted name and description, plus a
    ``created_at`` set to the current UTC time.

    NOTE(review): only ``tenant.name`` is passed to ``TenantManager.create_tenant``.
    The ``description`` is echoed back but apparently never persisted, and
    ``created_at`` is generated here rather than read from storage — confirm
    against TenantManager.
    """
    try:
        new_tenant_id = TenantManager.create_tenant(tenant.name)
        tenant_info = {
            "id": str(new_tenant_id),
            "name": tenant.name,
            "description": tenant.description,
            "created_at": datetime.now(timezone.utc),
        }
        return {"message": "租户创建成功", "tenant": tenant_info}
    except Exception as e:
        logger.error(f"创建租户失败: {e}")
        raise HTTPException(status_code=400, detail="创建租户失败")
@app.get("/api/tenants/")
async def list_tenants(current_user: dict = Depends(get_current_user)):
    """List all tenants. Requires a valid bearer token.

    A missing ``description`` is reported as "" and a missing ``created_at``
    as the current UTC time.
    """
    try:
        tenant_rows = TenantManager.get_all_tenants()
        serialized = []
        for row in tenant_rows:
            serialized.append({
                "id": str(row["id"]),
                "name": row["name"],
                "description": row.get("description", ""),
                "created_at": row.get("created_at", datetime.now(timezone.utc)),
            })
        return {"tenants": serialized}
    except Exception as e:
        logger.error(f"查询租户列表失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户列表失败")
@app.get("/api/tenants/{name}")
async def get_tenant(name: str, current_user: dict = Depends(get_current_user)):
    """Look up a single tenant by name. Requires a valid bearer token.

    Returns:
        ``{"tenant": {...}}``. A missing ``description`` is reported as "" and
        a missing ``created_at`` as the current UTC time.

    Raises:
        HTTPException: 404 "租户不存在" when no such tenant exists, 400
            "查询租户失败" for any other failure.
    """
    try:
        tenant = TenantManager.get_tenant_by_name(name)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        return {
            "tenant": {
                "id": str(tenant["id"]),
                "name": tenant["name"],
                "description": tenant.get("description", ""),
                "created_at": tenant.get("created_at", datetime.now(timezone.utc))
            }
        }
    except HTTPException:
        # Re-raise the intentional 404 directly instead of sniffing "404" in
        # the error text, which misfired on any message containing "404".
        raise
    except Exception as e:
        logger.exception("查询租户失败: %s", e)
        raise HTTPException(status_code=400, detail="查询租户失败") from e
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces, port 8001.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")