# dify_admin/api/app.py
#
# 444 lines
# 16 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse
from account_manager import AccountManager as DifyAccountManager
from backend_account_manager import BackendAccountManager
# Shared manager for back-office (admin panel) accounts; used by the /api/auth/*
# routes, by get_current_user, and by the admin check in reset_password.
backend_account_manager = BackendAccountManager()
from tenant_manager import TenantManager
from api_user_manager import APIAuthManager
from operation_logger import OperationLogger
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta
from typing import Optional
import logging
from account_manager import AccountManager
from database import get_db_cursor
import os  # used only to read the JWT signing key from the environment

# Logging: configure the root logger once for the whole service.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- JWT configuration --------------------------------------------------------
# The placeholder key is kept as the fallback so deployments that never set a
# secret keep working (and tokens they already issued stay valid).
_DEFAULT_SECRET_KEY = "your-secret-key-here"
# Production deployments should set JWT_SECRET_KEY: with HS256 anyone who knows
# the signing key can forge a token for any user.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    logger.warning("JWT_SECRET_KEY is not set; falling back to the insecure default signing key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440  # extended to 24 hours

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 bearer-token scheme. tokenUrl only feeds the OpenAPI docs, but it must
# name the real password-login route so the docs' "Authorize" dialog posts to
# an endpoint that exists; the previous value "token" resolved to /token,
# which is not defined in this module.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")

app = FastAPI()
# Ensure the database tables exist when the module is imported. init_sqlite_db
# is expected to use "CREATE TABLE IF NOT EXISTS"-style DDL, so existing tables
# are not overwritten. Any failure is logged and re-raised so the service never
# starts with an incomplete schema.
try:
    from database import init_sqlite_db

    init_sqlite_db()
except Exception as init_error:
    logger.error(f"数据库表初始化失败: {init_error}")
    raise
else:
    logger.info("数据库表检查完成,缺失的表已创建")
# CORS settings for the admin front-end.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; consider restricting
# allow_origins to the known front-end origin(s) in production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "expose_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Module-level collaborators: op_logger records login attempts in auth_login.
# NOTE(review): api_auth does not appear to be used in this module — confirm
# before removing it.
api_auth = APIAuthManager()
op_logger = OperationLogger()
def verify_password(plain_password: str, hashed_password: str):
    """Check a plaintext password against a hash using the module's bcrypt context.

    Returns the result of ``pwd_context.verify`` (truthy on a match).
    NOTE: the login routes in this module verify passwords through the account
    managers rather than this helper.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str):
    """Hash *password* with the module's bcrypt CryptContext and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def authenticate_user(username: str, password: str):
    """Authenticate *username*/*password* against the Dify account store.

    AccountManager.get_user_by_username may return either a dict or a
    positional tuple (used for test users); tuples are converted to a dict with
    keys id, username, email, password and password_salt before verification.

    Returns:
        The user record (dict) when the password verifies, otherwise False.
        Any exception is logged and also reported as False.
    """
    try:
        record = AccountManager.get_user_by_username(username)
        if not record:
            return False

        if isinstance(record, tuple):
            # Positional layout: (id, username, email, password, password_salt, ...)
            field_names = ("id", "username", "email", "password", "password_salt")
            record = {name: record[position] for position, name in enumerate(field_names)}

        password_ok = AccountManager.verify_password(password, record["password"], record["password_salt"])
        return record if password_ok else False
    except Exception as e:
        logger.error(f"认证失败: {e}")
        return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. {"sub": username}); the dict is copied, not mutated.
        expires_delta: Token lifetime. When None, a 15-minute lifetime is used.

    Returns:
        The token produced by ``jwt.encode`` with SECRET_KEY and ALGORITHM, with an
        ``exp`` claim set to now (UTC) plus the lifetime.
    """
    to_encode = data.copy()
    # Explicit None check: timedelta(0) is falsy but is still an explicit
    # lifetime and must not silently become the 15-minute default.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Always use an aware UTC timestamp. datetime.utcnow() returns a naive value,
    # is deprecated since Python 3.12, and was inconsistent with the other branch.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that resolves the bearer token to the calling user.

    The token is decoded with SECRET_KEY/ALGORITHM and its "sub" claim is used
    as the username. The back-office account store is consulted first, then the
    Dify account store.

    Returns:
        dict with "id" (str), "username" and "email".

    Raises:
        HTTPException(401): the token cannot be decoded, has no "sub" claim, or
            names a user found in neither store.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception

    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Back-office accounts take precedence over Dify accounts with the same name.
    user = (
        backend_account_manager.get_user_by_username(username)
        or AccountManager.get_user_by_username(username)
    )
    if not user:
        raise credentials_exception

    if isinstance(user, tuple):
        # AccountManager can return positional rows (authenticate_user handles
        # the same case) laid out as (id, username, email, ...). Tuples have no
        # .get(), which would otherwise raise AttributeError (HTTP 500).
        return {"id": str(user[0]), "username": user[1], "email": user[2]}

    return {
        "id": str(user.get("id")),
        "username": user.get("username"),
        "email": user.get("email", ""),
    }
# --- Authentication routes ----------------------------------------------------
@app.options("/api/auth/register", include_in_schema=False)
async def auth_register_options():
    """Reply to OPTIONS /api/auth/register with 204 and permissive CORS headers.

    NOTE(review): CORSMiddleware normally answers genuine CORS preflight
    requests before routing, so this handler may only see plain OPTIONS
    requests — confirm it is still needed.
    """
    response = Response(status_code=204)
    for header_name, header_value in (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "POST, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type"),
    ):
        response.headers[header_name] = header_value
    return response
@app.post("/api/auth/register")
async def auth_register(request: Request):
    """Register a back-office (admin panel) account from form fields.

    Expects form fields ``username``, ``password`` and ``email``; surrounding
    whitespace is stripped from each.

    Returns:
        dict with ``user_id`` (str), ``username`` and ``email`` of the new account.

    Raises:
        HTTPException(422): a field is missing, blank, or not a plain-text value.
        HTTPException(400): the username already exists ("用户名已存在") or the
            account could not be created ("注册失败").
    """
    form_data = await request.form()

    def _text_field(name):
        # Multipart forms may carry an uploaded file under any field name; treat
        # a non-text value as missing instead of crashing on .strip() with a 500.
        value = form_data.get(name, "")
        return value.strip() if isinstance(value, str) else ""

    username = _text_field("username")
    # NOTE(review): the password is stripped here, but /api/auth/login checks
    # the raw form password, so a password with leading/trailing spaces cannot
    # be used as typed. Confirm the front-end trims consistently.
    password = _text_field("password")
    email = _text_field("email")
    if not all([username, password, email]):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="缺少必要参数"
        )
    try:
        # Reject duplicate usernames before creating the account.
        if backend_account_manager.get_user_by_username(username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )
        user = backend_account_manager.create_account(username, email, password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"]
        }
    except HTTPException:
        # Let deliberate HTTP errors (the duplicate-username 400 above) through
        # with their specific detail instead of the generic "注册失败".
        raise
    except Exception as e:
        logger.error(f"注册后台账号失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="注册失败"
        ) from e
@app.post("/api/auth/login")
async def auth_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Log a back-office user in and issue a bearer access token.

    Each attempt is recorded via op_logger under endpoint "/api/auth/login":
    failures as LOGIN_ATTEMPT/FAILED with user_id 0, successes as LOGIN/SUCCESS.

    Returns:
        {"access_token": <token>, "token_type": "bearer"}; the token lives for
        ACCESS_TOKEN_EXPIRE_MINUTES.

    Raises:
        HTTPException(401): unknown username or wrong password.
    """
    client_ip = request.client.host if request.client else "unknown"
    account = backend_account_manager.get_user_by_username(form_data.username)
    # Short-circuits: the password is only checked when the account exists.
    credentials_ok = bool(account) and backend_account_manager.verify_password(
        form_data.password, account["password"], account["password_salt"]
    )

    if not credentials_ok:
        op_logger.log_operation(
            user_id=0,
            operation_type="LOGIN_ATTEMPT",
            endpoint="/api/auth/login",
            parameters=f"username={form_data.username}, ip={client_ip}",
            status="FAILED",
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": account["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    op_logger.log_operation(
        user_id=account["id"],
        operation_type="LOGIN",
        endpoint="/api/auth/login",
        parameters=f"ip={client_ip}",
        status="SUCCESS",
    )
    return {"access_token": token, "token_type": "bearer"}
@app.post("/api/user/login")
async def user_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()):
    """Alias of /api/auth/login: same credential check, token and audit logging.

    Note that the audit entries still name "/api/auth/login" as the endpoint.
    """
    token_response = await auth_login(request, form_data)
    return token_response
@app.post("/api/auth/refresh")
async def refresh_token(current_user: dict = Depends(get_current_user)):
    """Issue a new access token for a caller whose current token is still accepted.

    The new token carries the same "sub" (username) and lives for
    ACCESS_TOKEN_EXPIRE_MINUTES.
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    new_token = create_access_token(data={"sub": current_user["username"]}, expires_delta=lifetime)
    return {"access_token": new_token, "token_type": "bearer"}
# --- Account management routes ------------------------------------------------
@app.post("/api/dify_accounts/")
async def create_dify_account(account: AccountCreate):
    """Create a Dify account and return its id, username, email and creation time.

    NOTE(review): unlike the other account routes this one has no
    get_current_user dependency, so it can be called without a token —
    confirm this is intended.

    Raises:
        HTTPException(400): account creation failed for any reason.
    """
    try:
        created = AccountManager.create_account(account.username, account.email, account.password)
        summary = {
            "user_id": str(created["id"]),
            "username": created["username"],
            "email": created["email"],
            "created_at": created["created_at"],
        }
        return summary
    except Exception as e:
        logger.error(f"创建Dify账户失败: {e}")
        raise HTTPException(status_code=400, detail="创建Dify账户失败")
@app.get("/api/accounts/search")
async def search_accounts(
    search: str = None,
    page: int = 1,
    page_size: int = 10,
    current_user: dict = Depends(get_current_user)
):
    """Search Dify accounts with pagination (requires authentication).

    ``search``, ``page`` and ``page_size`` are forwarded unchanged to
    AccountManager.search_accounts.

    Returns:
        {"accounts": [...], "total": ...}. Each account has id (str), username,
        email, status (default "active") and created_at (defaults to the
        current UTC time when missing).

    Raises:
        HTTPException(400): the search failed for any reason.
    """
    try:
        result = AccountManager.search_accounts(search, page, page_size)

        def _serialize(row):
            return {
                "id": str(row["id"]),
                "username": row["username"],
                "email": row["email"],
                "status": row.get("status", "active"),
                "created_at": row.get("created_at", datetime.now(timezone.utc)),
            }

        return {
            "accounts": [_serialize(row) for row in result["data"]],
            "total": result["total"],
        }
    except Exception as e:
        logger.error(f"搜索账户失败: {e}")
        raise HTTPException(status_code=400, detail="搜索账户失败")
@app.get("/api/dify_accounts/{username}")
async def get_dify_account(username: str, current_user: dict = Depends(get_current_user)):
    """Return a Dify account together with its tenant memberships (authenticated).

    Returns:
        dict with ``user_id`` (str), ``username``, ``email``, ``created_at`` and
        ``tenants`` (passed through from AccountManager.get_account_tenants).

    Raises:
        HTTPException(404): no Dify account with this username.
        HTTPException(400): a ValueError was raised during the lookup.
        HTTPException(500): any other unexpected error.
    """
    try:
        account = AccountManager.get_user_by_username(username)
        if not account:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Dify账户不存在")
        # get_account_tenants already returns formatted tenant info; use it as-is.
        tenant_info = AccountManager.get_account_tenants(account["id"])
        return {
            "user_id": str(account["id"]),
            "username": account["username"],
            "email": account["email"],
            "created_at": account["created_at"],
            "tenants": tenant_info
        }
    except HTTPException:
        # Pass the intentional 404 through unchanged. Routing it through the
        # generic handler logged a spurious error and relied on finding "404"
        # in str(e), which also misreported unrelated errors whose message
        # happened to contain "404".
        raise
    except ValueError as e:
        logger.error(f"参数格式错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="参数格式错误"
        ) from e
    except Exception as e:
        logger.error(f"查询Dify账户失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器内部错误"
        ) from e
@app.put("/api/accounts/password")
async def change_password(
    password_change: PasswordChange,
    current_user: dict = Depends(get_current_user)
):
    """Change the password of the Dify account whose username matches the caller.

    The supplied current password must verify against the stored hash before
    the new password is written via AccountManager.update_password.

    Raises:
        HTTPException(404): no Dify account with the caller's username.
        HTTPException(401): ``current_password`` is wrong.
        HTTPException(400): any other failure while changing the password.
    """
    try:
        user = AccountManager.get_user_by_username(current_user["username"])
        if not user:
            # get_current_user also accepts back-office users, who may have no
            # Dify account; report that explicitly instead of failing on None[...].
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )
        if not AccountManager.verify_password(
            password_change.current_password,
            user["password"],
            user["password_salt"]
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="当前密码不正确"
            )
        # NOTE(review): the email comes from current_user, which may be the
        # back-office record rather than the Dify account — confirm
        # update_password expects that value.
        AccountManager.update_password(
            current_user["username"],
            current_user["email"],
            password_change.new_password
        )
        return {"message": "密码修改成功"}
    except HTTPException:
        # Keep deliberate errors (404 / 401 above) instead of collapsing them
        # into the generic 400 below.
        raise
    except Exception as e:
        logger.error(f"修改密码失败: {e}")
        raise HTTPException(status_code=400, detail="修改密码失败") from e
@app.post("/api/accounts/{account_id}/reset-password")
async def reset_password(
    account_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Reset a Dify account's password (admin only).

    Only the back-office user named "admin" may call this. ``account_id`` must
    be a UUID in canonical 8-4-4-4-12 hyphenated form; surrounding whitespace
    and single/double quotes are removed first.

    Raises:
        HTTPException(403): the caller is not the back-office "admin" user.
        HTTPException(400): ``account_id`` is not a canonical UUID, or the reset
            failed unexpectedly.
        HTTPException(404): AccountManager.reset_password reported failure.
    """
    try:
        # Only the back-office "admin" account may reset passwords.
        admin_user = backend_account_manager.get_user_by_username(current_user["username"])
        if not admin_user or admin_user.get("username") != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要管理员权限"
            )

        logger.info(f"完整请求参数: account_id={account_id}")
        # Clients sometimes send the id wrapped in quotes or padded with spaces.
        clean_account_id = account_id.strip().strip('"').strip("'")
        logger.info(f"清理后的account_id: {clean_account_id}")
        try:
            if len(clean_account_id) != 36 or clean_account_id.count("-") != 4:
                raise ValueError("UUID格式不正确")
            parsed_uuid = uuid.UUID(clean_account_id)
            # uuid.UUID discards hyphens before parsing, so 36 chars with 4
            # misplaced hyphens would still parse; require the canonical layout.
            if str(parsed_uuid) != clean_account_id.lower():
                raise ValueError("UUID格式不正确")
            logger.info(f"成功解析为UUID: {parsed_uuid}")
        except ValueError as e:
            logger.error(f"无效的account_id格式: {clean_account_id}, 原始值: {account_id}, 错误: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"账号ID格式无效: {clean_account_id} (必须是标准的UUID格式)"
            ) from e

        # Use the validated id: the raw path value may still contain the
        # quotes/whitespace stripped above.
        success = AccountManager.reset_password(clean_account_id)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )
        return {"message": "密码重置成功"}
    except HTTPException:
        # Preserve the deliberate 403/400/404 responses raised above instead of
        # collapsing them into the generic 400 below.
        raise
    except Exception as e:
        logger.error(f"重置密码失败: {e}")
        raise HTTPException(status_code=400, detail="重置密码失败") from e
# --- Tenant management routes -------------------------------------------------
@app.post("/api/tenants/")
async def create_tenant(tenant: TenantCreate, current_user: dict = Depends(get_current_user)):
    """Create a tenant by name (requires authentication).

    Only ``tenant.name`` is passed to TenantManager.create_tenant. The response
    echoes the request's name and description and uses the current UTC time as
    ``created_at``.
    NOTE(review): the description is not handed to TenantManager, so it is
    presumably not persisted — confirm this is intended.

    Raises:
        HTTPException(400): tenant creation failed for any reason.
    """
    try:
        new_id = TenantManager.create_tenant(tenant.name)
        tenant_payload = {
            "id": str(new_id),
            "name": tenant.name,
            "description": tenant.description,
            "created_at": datetime.now(timezone.utc),
        }
        return {"message": "租户创建成功", "tenant": tenant_payload}
    except Exception as e:
        logger.error(f"创建租户失败: {e}")
        raise HTTPException(status_code=400, detail="创建租户失败")
@app.get("/api/tenants/")
async def list_tenants(current_user: dict = Depends(get_current_user)):
    """List every tenant (requires authentication).

    Returns:
        {"tenants": [...]}; each entry has id (str), name, description
        (default "") and created_at (defaults to the current UTC time).

    Raises:
        HTTPException(400): the listing failed for any reason.
    """
    try:
        rows = TenantManager.get_all_tenants()

        def _as_response(row):
            return {
                "id": str(row["id"]),
                "name": row["name"],
                "description": row.get("description", ""),
                "created_at": row.get("created_at", datetime.now(timezone.utc)),
            }

        return {"tenants": [_as_response(row) for row in rows]}
    except Exception as e:
        logger.error(f"查询租户列表失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户列表失败")
@app.get("/api/tenants/{name}")
async def get_tenant(name: str, current_user: dict = Depends(get_current_user)):
    """Look up a single tenant by name (requires authentication).

    Returns:
        {"tenant": {...}} with id (str), name, description (default "") and
        created_at (defaults to the current UTC time when missing).

    Raises:
        HTTPException(404): no tenant with this name.
        HTTPException(400): the lookup failed for any other reason.
    """
    try:
        tenant = TenantManager.get_tenant_by_name(name)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        return {
            "tenant": {
                "id": str(tenant["id"]),
                "name": tenant["name"],
                "description": tenant.get("description", ""),
                "created_at": tenant.get("created_at", datetime.now(timezone.utc))
            }
        }
    except HTTPException:
        # Pass the intentional 404 through unchanged. Recovering it from the
        # generic handler by searching str(e) for "404" also turned unrelated
        # errors whose message contained "404" into false not-found responses.
        raise
    except Exception as e:
        logger.error(f"查询租户失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户失败") from e
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces, port 8001.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")