# dify_admin/api/app.py
#
# Note from the source viewer: this file contains Unicode characters that might
# be confused with other characters. If that is intentional, the warning can
# safely be ignored.

# Standard-library imports.
import uuid
from datetime import datetime, timedelta, timezone
from typing import List
# FastAPI framework pieces used by the routes below.
from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
# Project modules: request/response models and the account, tenant and model managers.
from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse, ModelConfig
from model_manager import ModelManager
from account_manager import AccountManager as DifyAccountManager
from backend_account_manager import BackendAccountManager
# Shared manager for backend (admin console) accounts. It is instantiated here, at
# import time and before the imports that follow; keep that order in mind if
# BackendAccountManager() has side effects such as database access -- TODO confirm.
backend_account_manager = BackendAccountManager()
from tenant_manager import TenantManager
from api_user_manager import APIAuthManager
from operation_logger import OperationLogger
# JWT encoding/decoding and password hashing.
from jose import JWTError, jwt
from passlib.context import CryptContext
# NOTE(review): redundant -- datetime and timedelta are already imported above
# (together with timezone); this line only rebinds the same objects.
from datetime import datetime, timedelta
from typing import Optional
import logging
# Dify account manager under its plain name (it is also imported above with the
# DifyAccountManager alias); the route code in this file refers to AccountManager.
from account_manager import AccountManager
# NOTE(review): get_db_cursor, Body, OAuth2PasswordRequestForm, AccountResponse and the
# DifyAccountManager alias do not appear to be used in this file -- verify before removing.
from database import get_db_cursor
import os

# Logging configuration shared by the whole module.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# JWT configuration.
# The signing key comes from the DIFY_ADMIN_SECRET_KEY environment variable, so a
# production deployment does not sign tokens with a key that is published in source
# control. When the variable is unset or empty, the historical hard-coded key is
# used, so existing local setups (and tokens they already issued) keep working. A
# warning is logged because anyone who knows that key can forge tokens.
_DEFAULT_SECRET_KEY = "your-secret-key-here"
SECRET_KEY = os.environ.get("DIFY_ADMIN_SECRET_KEY") or _DEFAULT_SECRET_KEY
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    logger.warning(
        "DIFY_ADMIN_SECRET_KEY is not set; falling back to the insecure default JWT secret"
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440  # 24 hours

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 bearer-token scheme used by the get_current_user dependency.
# NOTE(review): tokenUrl="token" only affects the OpenAPI docs; the login route in
# this file is /api/auth/login and takes a JSON body -- confirm the docs flow is wanted.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
app = FastAPI()

# Create any missing database tables at import time. init_sqlite_db is expected to
# use CREATE TABLE IF NOT EXISTS, so tables that already exist are left untouched.
# A failure is logged and re-raised, which aborts start-up.
try:
    from database import init_sqlite_db

    init_sqlite_db()
    logger.info("数据库表检查完成,缺失的表已创建")
except Exception as init_error:
    logger.error(f"数据库表初始化失败: {init_error}")
    raise

# Local front-end dev servers allowed to make credentialed cross-origin requests.
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3001",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600,
)

# Shared service objects.
api_auth = APIAuthManager()
op_logger = OperationLogger()
def verify_password(plain_password: str, hashed_password: str):
    """Return whether ``plain_password`` matches ``hashed_password``.

    The check is delegated to the module-level passlib ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password: str):
    """Return a hash of ``password`` produced by the module-level ``pwd_context``."""
    hashed = pwd_context.hash(password)
    return hashed
def authenticate_user(username: str, password: str):
    """Check a Dify account's credentials.

    The user is fetched with ``AccountManager.get_user_by_username`` and the password
    is checked with ``AccountManager.verify_password`` against the stored hash and salt.

    Returns:
        On success, the user record: the looked-up object itself, or, when the lookup
        returns a tuple (kept for compatibility with test users), a dict built from its
        first five columns. ``False`` when the user is unknown, the password is wrong,
        or any exception occurs (the exception is logged).
    """
    try:
        record = AccountManager.get_user_by_username(username)
        if not record:
            return False

        # Tuple rows (test users) are mapped positionally onto named fields.
        if isinstance(record, tuple):
            column_names = ("id", "username", "email", "password", "password_salt")
            record = {name: record[pos] for pos, name in enumerate(column_names)}

        if AccountManager.verify_password(password, record["password"], record["password_salt"]):
            return record
        return False
    except Exception as exc:
        logger.error(f"认证失败: {exc}")
        return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (for example ``{"sub": username}``). The dict is copied,
            not modified.
        expires_delta: Token lifetime. When ``None`` the token expires after 15 minutes.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and ``ALGORITHM``; an
        ``exp`` claim is added.
    """
    to_encode = data.copy()
    # Compare with None so an explicit zero lifetime is honoured instead of silently
    # becoming the 15-minute default. Both cases use a timezone-aware UTC "now";
    # datetime.utcnow() returns a naive value and is deprecated since Python 3.12.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the request's bearer token to the authenticated user.

    The token is decoded with ``SECRET_KEY``/``ALGORITHM``. Its ``sub`` claim is looked
    up first among backend accounts and then among Dify accounts (``AccountManager``).

    Returns:
        ``{"id": str, "username": ..., "email": ...}``; ``email`` defaults to ``""``.

    Raises:
        HTTPException: 401 "Token已过期" for an expired token; 401 "无法验证凭据" for
            any other undecodable token, a missing ``sub`` claim, or an unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Log only a short prefix, at debug level. The full token (or its tail, which is
    # part of the signature) must not reach stdout or log files, where anyone with
    # read access could replay it.
    logger.debug(f"Received token: {token[:10]}...")
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token已过期",
            headers={"WWW-Authenticate": "Bearer"},
        )
    except JWTError:
        raise credentials_exception

    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Backend accounts take precedence; fall back to Dify accounts.
    user = (
        backend_account_manager.get_user_by_username(username)
        or AccountManager.get_user_by_username(username)
    )
    if not user:
        raise credentials_exception
    # AccountManager can return a plain tuple row (authenticate_user handles the same
    # case). Map its leading (id, username, email) columns so .get() below works.
    if isinstance(user, tuple):
        user = {"id": user[0], "username": user[1], "email": user[2]}

    logger.info(f"Authenticated user: {username}")
    return {
        "id": str(user.get("id")),
        "username": user.get("username"),
        "email": user.get("email", ""),
    }
# 认证路由组
@app.options("/api/auth/register", include_in_schema=False)
async def auth_register_options():
    """Answer an OPTIONS request for the registration route with 204 and CORS headers.

    NOTE(review): the app also registers CORSMiddleware, which normally answers real
    CORS preflight requests itself, so this handler may only see plain OPTIONS
    requests. Its wildcard origin also differs from the middleware's allow-list --
    confirm whether this route is still needed.
    """
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type",
    }
    preflight = Response(status_code=204)
    for header_name, header_value in cors_headers.items():
        preflight.headers[header_name] = header_value
    return preflight
@app.post("/api/auth/register")
async def auth_register(request: Request):
    """Register a backend (admin console) account from form fields.

    Reads the ``username``, ``password`` and ``email`` form fields. Username and email
    are stripped of surrounding whitespace; the password is kept as submitted.

    NOTE(review): this route has no authentication dependency, so any client that can
    reach the API can create a backend account -- confirm that is intended.

    Returns:
        ``{"user_id": str, "username": ..., "email": ...}`` of the new account.

    Raises:
        HTTPException: 422 "缺少必要参数" when a field is missing or blank; 400
            "用户名已存在" when the username is taken; 400 "注册失败" when the account
            cannot be created.
    """
    form_data = await request.form()
    username = form_data.get("username", "").strip()
    email = form_data.get("email", "").strip()
    # Keep the password exactly as submitted. /api/auth/login verifies the raw
    # password, so stripping it here would make a password with leading or trailing
    # whitespace impossible to log in with. Blank passwords are still rejected.
    password = form_data.get("password", "")
    if not all([username, password.strip(), email]):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="缺少必要参数"
        )
    try:
        # Reject usernames that already exist.
        if backend_account_manager.get_user_by_username(username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )
        # Create the backend account.
        user = backend_account_manager.create_account(username, email, password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"]
        }
    except HTTPException:
        # Keep the specific "username already exists" error instead of replacing it
        # with the generic failure below.
        raise
    except Exception as e:
        logger.error(f"注册后台账号失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="注册失败"
        )
@app.post("/api/auth/login")
async def auth_login(request: Request):
    """Authenticate a backend account and issue a JWT access token.

    Expects a JSON object body with ``username`` and ``password``. Each attempt that
    reaches the credential check is recorded through ``op_logger`` together with the
    client IP (status FAILED or SUCCESS).

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``. The token's ``sub`` claim
        is the username, and it expires after ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Raises:
        HTTPException: 400 "无效的JSON格式" when the body is not valid JSON; 422
            "缺少用户名或密码" when a credential is missing; 401 "用户名或密码错误" for an
            unknown user or wrong password; 500 "登录失败" for unexpected errors.
    """
    try:
        try:
            data = await request.json()
        except ValueError:
            # json.JSONDecodeError is a ValueError: malformed JSON is a client error.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的JSON格式"
            ) from None
        # A JSON array/string/number carries no credentials; treat it as missing them
        # instead of crashing on .get().
        if not isinstance(data, dict):
            data = {}
        username = data.get("username")
        password = data.get("password")
        if not username or not password:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="缺少用户名或密码"
            )
        client_ip = request.client.host if request.client else "unknown"
        user = backend_account_manager.get_user_by_username(username)
        if not user or not backend_account_manager.verify_password(password, user["password"], user["password_salt"]):
            op_logger.log_operation(
                user_id=0,
                operation_type="LOGIN_ATTEMPT",
                endpoint="/api/auth/login",
                parameters=f"username={username}, ip={client_ip}",
                status="FAILED"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"},
            )
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={"sub": user["username"]},
            expires_delta=access_token_expires
        )
        op_logger.log_operation(
            user_id=user["id"],
            operation_type="LOGIN",
            endpoint="/api/auth/login",
            parameters=f"ip={client_ip}",
            status="SUCCESS"
        )
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # The 4xx responses raised above must reach the client unchanged; otherwise
        # the generic handler below would report a wrong password as a 500.
        raise
    except Exception as e:
        logger.error(f"登录失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败"
        )
@app.post("/api/user/login")
async def user_login(request: Request):
    """Alias of ``/api/auth/login``: delegate to ``auth_login`` and return its result."""
    login_result = await auth_login(request)
    return login_result
@app.post("/api/auth/refresh")
async def refresh_token(current_user: dict = Depends(get_current_user)):
    """Issue a fresh access token for the already-authenticated caller.

    The new token carries the caller's username as ``sub`` and lasts
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.
    """
    new_token = create_access_token(
        data={"sub": current_user["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": new_token, "token_type": "bearer"}
# 账户管理路由组
@app.post("/api/accounts/")
async def create_account(account: AccountCreate):
    """Create a backend (admin console) account.

    Applies the same duplicate-username check as ``/api/auth/register`` before
    creating the account through ``backend_account_manager``.

    NOTE(review): this route has no authentication dependency, so any client that can
    reach the API can create a backend account -- confirm that is intended.

    Returns:
        ``{"user_id": str, "username", "email", "created_at"}`` of the new account.

    Raises:
        HTTPException: 400 "用户名已存在" when the username is taken; 400
            "创建后台账户失败" when the account cannot be created.
    """
    try:
        if backend_account_manager.get_user_by_username(account.username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )
        user = backend_account_manager.create_account(account.username, account.email, account.password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"],
            "created_at": user["created_at"]
        }
    except HTTPException:
        # Keep the specific duplicate-username error.
        raise
    except Exception as e:
        logger.error(f"创建后台账户失败: {e}")
        raise HTTPException(status_code=400, detail="创建后台账户失败")
@app.post("/api/dify_accounts/")
async def create_dify_account(account: AccountCreate):
    """Create a Dify account through ``AccountManager.create_account``.

    NOTE(review): this route has no authentication dependency -- confirm that
    unauthenticated creation of Dify accounts is intended.

    Returns:
        ``{"user_id": str, "username", "email", "created_at"}`` of the new account.

    Raises:
        HTTPException: 400 "创建Dify账户失败" on any failure (the cause is logged).
    """
    try:
        created = AccountManager.create_account(account.username, account.email, account.password)
        summary = {
            "user_id": str(created["id"]),
            "username": created["username"],
            "email": created["email"],
            "created_at": created["created_at"],
        }
    except Exception as exc:
        logger.error(f"创建Dify账户失败: {exc}")
        raise HTTPException(status_code=400, detail="创建Dify账户失败")
    return summary
@app.get("/api/accounts/search")
async def search_accounts(
    search: str = None,
    page: int = 1,
    page_size: int = 10,
    current_user: dict = Depends(get_current_user)
):
    """Return one page of accounts matching ``search`` via ``AccountManager.search_accounts``.

    Each row is reduced to id (as str), username, email, status (default "active")
    and created_at (default: the current UTC time).

    Raises:
        HTTPException: 400 "搜索账户失败" on any failure (the cause is logged).
    """
    def to_summary(row):
        # Flatten one account row into the response shape.
        return {
            "id": str(row["id"]),
            "username": row["username"],
            "email": row["email"],
            "status": row.get("status", "active"),
            "created_at": row.get("created_at", datetime.now(timezone.utc)),
        }

    try:
        found = AccountManager.search_accounts(search, page, page_size)
        return {
            "accounts": [to_summary(row) for row in found["data"]],
            "total": found["total"],
        }
    except Exception as exc:
        logger.error(f"搜索账户失败: {exc}")
        raise HTTPException(status_code=400, detail="搜索账户失败")
@app.get("/api/dify_accounts/{username}")
async def get_dify_account(username: str, current_user: dict = Depends(get_current_user)):
    """Return a Dify account and its associated tenants.

    Returns:
        ``{"user_id": str, "username", "email", "created_at", "tenants"}``. ``tenants``
        is whatever ``AccountManager.get_account_tenants`` returns for the account id.

    Raises:
        HTTPException: 404 "Dify账户不存在" when the account does not exist; 400
            "参数格式错误" on a ValueError; 500 "服务器内部错误" on any other failure.
    """
    try:
        account = AccountManager.get_user_by_username(username)
        if not account:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Dify账户不存在")
        # Fetch the associated tenant information (already formatted by the manager).
        tenant_info = AccountManager.get_account_tenants(account["id"])
        return {
            "user_id": str(account["id"]),
            "username": account["username"],
            "email": account["email"],
            "created_at": account["created_at"],
            "tenants": tenant_info
        }
    except HTTPException:
        # Pass the deliberate 404 above through unchanged. Matching "404" in the
        # exception text would also misfire for unrelated errors whose message merely
        # contains "404" (e.g. inside an id).
        raise
    except ValueError as e:
        logger.error(f"参数格式错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="参数格式错误"
        )
    except Exception as e:
        logger.error(f"查询Dify账户失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器内部错误"
        )
@app.put("/api/accounts/password")
async def change_password(
    password_change: PasswordChange,
    current_user: dict = Depends(get_current_user)
):
    """Change the current user's password after checking the current one.

    The account is read and updated through ``AccountManager``; ``update_password``
    receives the current user's username and email plus the new password.

    NOTE(review): get_current_user also accepts backend accounts, which may not exist
    in AccountManager -- confirm which account store this route should target.

    Returns:
        ``{"message": "密码修改成功"}`` on success.

    Raises:
        HTTPException: 404 "账号不存在" when AccountManager has no such user; 401
            "当前密码不正确" when the current password is wrong; 400 "修改密码失败"
            for any other failure.
    """
    try:
        user = AccountManager.get_user_by_username(current_user["username"])
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )
        if not AccountManager.verify_password(
            password_change.current_password,
            user["password"],
            user["password_salt"]
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="当前密码不正确"
            )
        AccountManager.update_password(
            current_user["username"],
            current_user["email"],
            password_change.new_password
        )
        return {"message": "密码修改成功"}
    except HTTPException:
        # Without this, the 401 for a wrong current password became a generic 400.
        raise
    except Exception as e:
        logger.error(f"修改密码失败: {e}")
        raise HTTPException(status_code=400, detail="修改密码失败")
@app.post("/api/accounts/{account_id}/reset-password")
async def reset_password(
    account_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Reset an account's password; only the backend user named "admin" may do this.

    ``account_id`` may arrive wrapped in whitespace or quotes. These are stripped, and
    the remainder must be a 36-character hyphenated UUID.

    Returns:
        ``{"message": "密码重置成功"}`` on success.

    Raises:
        HTTPException: 403 "需要管理员权限" for non-admin callers; 400 "账号ID格式无效: ..."
            for a malformed id; 404 "账号不存在" when AccountManager.reset_password
            reports failure; 400 "重置密码失败" for any other error.
    """
    try:
        # Only the backend "admin" user may reset passwords.
        admin_user = backend_account_manager.get_user_by_username(current_user["username"])
        if not admin_user or admin_user.get("username") != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要管理员权限"
            )
        logger.info(f"完整请求参数: account_id={account_id}")
        # Strip surrounding whitespace and quotes.
        clean_account_id = account_id.strip().strip('"').strip("'")
        logger.info(f"清理后的account_id: {clean_account_id}")
        try:
            # Strict UUID format check.
            if len(clean_account_id) != 36 or clean_account_id.count("-") != 4:
                raise ValueError("UUID格式不正确")
            parsed_uuid = uuid.UUID(clean_account_id)
            logger.info(f"成功解析为UUID: {parsed_uuid}")
        except ValueError as e:
            logger.error(f"无效的account_id格式: {clean_account_id}, 原始值: {account_id}, 错误: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"账号ID格式无效: {clean_account_id} (必须是标准的UUID格式)"
            )
        # Use the id that was validated: the raw path value may still carry the
        # quotes or whitespace stripped above.
        success = AccountManager.reset_password(clean_account_id)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )
        return {"message": "密码重置成功"}
    except HTTPException:
        # Keep the deliberate 403/400/404 responses instead of a generic 400.
        raise
    except Exception as e:
        logger.error(f"重置密码失败: {e}")
        raise HTTPException(status_code=400, detail="重置密码失败")
# 租户管理路由组
@app.post("/api/tenants/")
async def create_tenant(tenant: TenantCreate, current_user: dict = Depends(get_current_user)):
    """Create a tenant named ``tenant.name``.

    Only the name is passed to TenantManager.create_tenant. The response echoes the
    request's description, and ``created_at`` is the current UTC time rather than a
    stored value.

    Raises:
        HTTPException: 400 "创建租户失败" on any failure (the cause is logged).
    """
    try:
        new_tenant_id = TenantManager.create_tenant(tenant.name)
        tenant_view = {
            "id": str(new_tenant_id),
            "name": tenant.name,
            "description": tenant.description,
            "created_at": datetime.now(timezone.utc),
        }
        return {"message": "租户创建成功", "tenant": tenant_view}
    except Exception as exc:
        logger.error(f"创建租户失败: {exc}")
        raise HTTPException(status_code=400, detail="创建租户失败")
@app.get("/api/tenants/", response_model=List[TenantResponse])
async def list_tenants(
    search: str = None,
    current_user: dict = Depends(get_current_user)
):
    """List tenants, optionally filtered by ``search``.

    Uses TenantManager.search_tenants(search) when ``search`` is non-empty and
    TenantManager.get_all_tenants() otherwise.

    NOTE(review): an empty result is reported as 404 instead of an empty list --
    confirm that clients expect this.

    Raises:
        HTTPException: 404 "未找到匹配的租户" when nothing matches; 400
            "查询租户列表失败" on any other failure.
    """
    try:
        logger.info(f"Current user accessing tenants: {current_user['username']}")
        tenants = TenantManager.search_tenants(search) if search else TenantManager.get_all_tenants()
        if not tenants:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="未找到匹配的租户"
            )
        return tenants
    except HTTPException:
        # Without this, the intended 404 became a generic 400.
        raise
    except Exception as e:
        logger.error(f"查询租户列表失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户列表失败")
@app.get("/api/tenants/{name}")
async def get_tenant(name: str, current_user: dict = Depends(get_current_user)):
    """Return the tenant with the given name.

    Returns:
        ``{"tenant": {"id": str, "name", "description", "created_at"}}``. A missing
        description defaults to ``""`` and a missing created_at to the current UTC time.

    Raises:
        HTTPException: 404 "租户不存在" when no such tenant exists; 400 "查询租户失败"
            on any other failure.
    """
    try:
        logger.info(f"Current user accessing tenant {name}: {current_user['username']}")
        tenant = TenantManager.get_tenant_by_name(name)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        return {
            "tenant": {
                "id": str(tenant["id"]),
                "name": tenant["name"],
                "description": tenant.get("description", ""),
                "created_at": tenant.get("created_at", datetime.now(timezone.utc))
            }
        }
    except HTTPException:
        # Pass the deliberate 404 through. Matching "404" in the exception text would
        # also misfire for unrelated errors whose message contains "404" (e.g. a
        # tenant name such as "team404").
        raise
    except Exception as e:
        logger.error(f"查询租户失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户失败")
# 模型管理路由组
@app.post("/api/models/upload")
async def upload_model_config(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Validate an uploaded JSON model-configuration file (admin only).

    The file must contain a JSON object with ``model_name``, ``provider_name``,
    ``model_type`` and ``api_key``. This route only validates and echoes the
    configuration back; it does not store it.

    Returns:
        ``{"message": "模型配置上传成功", "config": <parsed object>}``.

    Raises:
        HTTPException: 403 "需要admin用户权限" for non-admin callers; 400 "无效的JSON格式"
            when the file is not valid JSON/UTF-8; 422 when the JSON is not an object
            or a required field is missing; 500 "上传模型配置失败" otherwise.
    """
    # Imported locally so this route does not depend on a module-level json import;
    # both json.loads and the json.JSONDecodeError handling need it.
    import json

    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )
        contents = await file.read()
        config = json.loads(contents)
        # A JSON array or string would satisfy the membership checks below (e.g. a
        # list of the field names), so require an object.
        if not isinstance(config, dict):
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="模型配置必须是JSON对象"
            )
        for field in ("model_name", "provider_name", "model_type", "api_key"):
            if field not in config:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=f"模型配置缺少必要字段: {field}"
                )
        return {"message": "模型配置上传成功", "config": config}
    except HTTPException:
        # Keep the deliberate 403/422 responses instead of a generic 500.
        raise
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError (undecodable bytes) are both
        # ValueError subclasses.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的JSON格式"
        )
    except Exception as e:
        logger.error(f"上传模型配置失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="上传模型配置失败"
        )
@app.post("/api/models/assign")
async def assign_model_to_tenant(
    model_config: ModelConfig,
    tenant_id: Optional[str] = None,
    current_user: dict = Depends(get_current_user)
):
    """Assign a model to one tenant, or to all tenants when ``tenant_id`` is omitted.

    Admin only. For a single tenant, the tenant's ``encrypt_public_key`` and
    ``model_config.dict()`` are passed to ModelManager.add_model_for_tenant.

    NOTE(review): when ``tenant_id`` is omitted, ``model_config`` is not used --
    ModelManager.add_models_for_all_tenants() is called without arguments. Confirm
    that this is intended.

    Raises:
        HTTPException: 403 "需要admin用户权限" for non-admin callers; 404 "租户不存在"
            for an unknown tenant; 500 "分配模型失败" for any other failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )
        # No tenant id: add the models for every tenant.
        if not tenant_id:
            total_added = ModelManager.add_models_for_all_tenants()
            return {"message": f"模型已成功分配给所有租户,共{total_added}"}
        tenant = TenantManager.get_tenant_by_id(tenant_id)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        ModelManager.add_model_for_tenant(tenant_id, tenant["encrypt_public_key"], model_config.dict())
        return {"message": "模型已成功分配给租户"}
    except HTTPException:
        # Keep the deliberate 403/404 responses instead of a generic 500.
        raise
    except Exception as e:
        logger.error(f"分配模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="分配模型失败"
        )
@app.get("/api/models/{tenant_id}")
async def get_tenant_models(
    tenant_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Return ``{"tenant_id": ..., "models": ...}`` using ModelManager.get_tenant_models.

    Raises:
        HTTPException: 500 "获取租户模型失败" if the lookup fails (the cause is logged).
    """
    try:
        tenant_models = ModelManager.get_tenant_models(tenant_id)
    except Exception as exc:
        logger.error(f"获取租户模型失败: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取租户模型失败"
        )
    return {"tenant_id": tenant_id, "models": tenant_models}
@app.delete("/api/models/{model_id}")
async def delete_model(
    model_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete one model by id (admin only).

    Returns:
        ``{"message": "模型删除成功"}`` when ModelManager.delete_model affected a row.

    Raises:
        HTTPException: 403 "需要admin用户权限" for non-admin callers; 404 "模型不存在"
            when no row was deleted; 500 "删除模型失败" for any other failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )
        rows_affected = ModelManager.delete_model(model_id)
        if rows_affected == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="模型不存在"
            )
        return {"message": "模型删除成功"}
    except HTTPException:
        # Keep the deliberate 403/404 responses instead of a generic 500.
        raise
    except Exception as e:
        logger.error(f"删除模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除模型失败"
        )
@app.delete("/api/models/all/{model_name}")
async def delete_model_for_all_tenants(
    model_name: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete the model named ``model_name`` from every tenant (admin only).

    Returns:
        A message containing the count returned by
        ModelManager.delete_specific_model_for_all_tenants.

    Raises:
        HTTPException: 403 "需要管理员权限" for non-admin callers; 500 "批量删除模型失败"
            for any other failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要管理员权限"
            )
        total_deleted = ModelManager.delete_specific_model_for_all_tenants(model_name)
        return {"message": f"模型已从所有租户中删除,共{total_deleted}"}
    except HTTPException:
        # Keep the deliberate 403 instead of a generic 500.
        raise
    except Exception as e:
        logger.error(f"批量删除模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="批量删除模型失败"
        )
if __name__ == "__main__":
    # Run directly (python app.py) to serve the API with uvicorn, listening on
    # every interface at port 8001.
    import uvicorn

    listen_host, listen_port = "0.0.0.0", 8001
    uvicorn.run(app, host=listen_host, port=listen_port)