622 lines
22 KiB
Python
622 lines
22 KiB
Python
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response, UploadFile, File
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||
from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse, ModelConfig
|
||
from model_manager import ModelManager
|
||
from account_manager import AccountManager as DifyAccountManager
|
||
from backend_account_manager import BackendAccountManager
|
||
|
||
backend_account_manager = BackendAccountManager()
|
||
from tenant_manager import TenantManager
|
||
from api_user_manager import APIAuthManager
|
||
from operation_logger import OperationLogger
|
||
from jose import JWTError, jwt
|
||
from passlib.context import CryptContext
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional
|
||
import logging
|
||
from account_manager import AccountManager
|
||
from database import get_db_cursor
|
||
|
||
# Logging setup: INFO level, with timestamp, logger name, level and message
# in every record. The module-level logger is used by all handlers below.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||
|
||
# JWT configuration.
import os  # local import so this block stays self-contained

# The signing key is read from the SECRET_KEY environment variable. The
# literal fallback keeps existing deployments working, but it is a publicly
# known placeholder and must be overridden in production.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-here")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440  # 24 hours
|
||
|
||
# Password hashing context (bcrypt scheme).
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# OAuth2 bearer-token scheme used by get_current_user to extract the token.
# NOTE(review): tokenUrl points at "token", but no such route is visible in
# this file -- confirm the login URL advertised to OpenAPI clients.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# The FastAPI application that all routes below are registered on.
app = FastAPI()
|
||
|
||
# Initialise the database tables at import time. The original comment states
# that the schema uses IF NOT EXISTS, so existing tables are not overwritten.
# Any failure is logged and re-raised, which aborts application start-up.
try:
    from database import init_sqlite_db

    init_sqlite_db()
    logger.info("数据库表检查完成,缺失的表已创建")
except Exception as e:
    logger.error(f"数据库表初始化失败: {e}")
    raise
|
||
|
||
# CORS middleware: only the local front-end dev origins are allowed, with
# credentials, any method and any header; preflight results are cached 600 s.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "http://localhost:3001",
        "http://127.0.0.1:3001",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600,
)

# Shared service objects used by the route handlers.
api_auth = APIAuthManager()
op_logger = OperationLogger()
|
||
|
||
def verify_password(plain_password: str, hashed_password: str):
    """Check a plaintext password against a stored hash using ``pwd_context``.

    Returns whatever ``pwd_context.verify`` returns for the pair.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||
|
||
def get_password_hash(password: str):
    """Hash ``password`` with the module-level ``pwd_context`` and return it."""
    hashed = pwd_context.hash(password)
    return hashed
|
||
|
||
def authenticate_user(username: str, password: str):
    """Authenticate a user by username and password.

    Looks the user up through ``AccountManager.get_user_by_username`` and
    checks the password with ``AccountManager.verify_password``.

    Returns:
        The user record on success: a dict built from the row when the lookup
        returned a tuple, otherwise the looked-up object itself. ``False``
        when the user is unknown, the password does not match, or any
        exception occurs (the exception is logged).
    """
    try:
        record = AccountManager.get_user_by_username(username)
        if not record:
            return False

        # Compatibility with test users: a tuple row is mapped positionally
        # onto the dict shape used everywhere else.
        if isinstance(record, tuple):
            columns = ("id", "username", "email", "password", "password_salt")
            record = {column: record[index] for index, column in enumerate(columns)}

        credentials_ok = AccountManager.verify_password(
            password, record["password"], record["password_salt"]
        )
        return record if credentials_ok else False
    except Exception as e:
        logger.error(f"认证失败: {e}")
        return False
|
||
|
||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token; copied, so the caller's dict is
            not mutated.
        expires_delta: Token lifetime. When ``None`` a 15-minute lifetime is
            used. An explicit zero delta is honoured rather than being
            silently replaced by the default.

    Returns:
        The token encoded with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    to_encode = data.copy()
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Always use a timezone-aware UTC timestamp; the original mixed aware
    # datetimes with the naive (and deprecated) datetime.utcnow().
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||
|
||
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token into the current user.

    Decodes the JWT with ``SECRET_KEY``/``ALGORITHM``, reads the ``sub`` claim
    as the username and looks the user up first in
    ``backend_account_manager`` and then in ``AccountManager``.

    Returns:
        A dict with ``id`` (as a string), ``username`` and ``email``.

    Raises:
        HTTPException: 401 with "Token已过期" when the token has expired, or
            401 with "无法验证凭据" when the token is invalid, has no ``sub``
            claim, or the user cannot be found.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Token material is deliberately not logged: the original wrote the first
    # and last 10 characters (which include signature bytes) to the INFO log.
    logger.debug("Validating bearer token")

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token已过期",
            headers={"WWW-Authenticate": "Bearer"},
        )
    except JWTError:
        raise credentials_exception

    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Accept both back-office users and accounts managed by AccountManager.
    user = (
        backend_account_manager.get_user_by_username(username)
        or AccountManager.get_user_by_username(username)
    )
    if not user:
        raise credentials_exception

    logger.info(f"Authenticated user: {username}")
    return {
        "id": str(user.get("id")),
        "username": user.get("username"),
        "email": user.get("email", ""),
    }
|
||
|
||
# 认证路由组
|
||
@app.options("/api/auth/register", include_in_schema=False)
async def auth_register_options():
    """Answer the CORS preflight (OPTIONS) request for the register endpoint.

    Returns an empty 204 response carrying the Access-Control-Allow-* headers.
    """
    preflight_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type",
    }
    response = Response(status_code=204)
    for header_name, header_value in preflight_headers.items():
        response.headers[header_name] = header_value
    return response
|
||
|
||
@app.post("/api/auth/register")
async def auth_register(request: Request):
    """Register a back-office management account from form data.

    Reads ``username``, ``password`` and ``email`` from the submitted form
    (each stripped of surrounding whitespace).

    Returns:
        A dict with ``user_id``, ``username`` and ``email`` of the new account.

    Raises:
        HTTPException: 422 when a field is missing or empty; 400 with
            "用户名已存在" when the username is taken; 400 with "注册失败" on
            any other failure.
    """
    form_data = await request.form()
    username = form_data.get("username", "").strip()
    password = form_data.get("password", "").strip()
    email = form_data.get("email", "").strip()

    if not all([username, password, email]):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="缺少必要参数"
        )

    try:
        # Reject duplicate usernames before creating the account.
        if backend_account_manager.get_user_by_username(username):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="用户名已存在"
            )

        user = backend_account_manager.create_account(username, email, password)
        return {
            "user_id": str(user["id"]),
            "username": user["username"],
            "email": user["email"]
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. duplicate username) must reach the
        # client unchanged instead of being rewritten by the handler below.
        raise
    except Exception as e:
        logger.error(f"注册后台账号失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="注册失败"
        )
|
||
|
||
@app.post("/api/auth/login")
async def auth_login(request: Request):
    """Log a back-office user in from a JSON body and issue a bearer token.

    Expects a JSON object with ``username`` and ``password``. Both failed and
    successful attempts are recorded through ``op_logger``.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 422 when the body is not a JSON object or a credential
            is missing; 401 when the username or password is wrong; 500 on
            any unexpected failure.
    """
    try:
        data = await request.json()
        # A JSON array or scalar has no .get(); treat it as missing credentials.
        if not isinstance(data, dict):
            data = {}
        username = data.get("username")
        password = data.get("password")

        if not username or not password:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="缺少用户名或密码"
            )

        client_ip = request.client.host if request.client else "unknown"
        user = backend_account_manager.get_user_by_username(username)
        if not user or not backend_account_manager.verify_password(
            password, user["password"], user["password_salt"]
        ):
            op_logger.log_operation(
                user_id=0,
                operation_type="LOGIN_ATTEMPT",
                endpoint="/api/auth/login",
                parameters=f"username={username}, ip={client_ip}",
                status="FAILED"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={"sub": user["username"]},
            expires_delta=access_token_expires
        )
        op_logger.log_operation(
            user_id=user["id"],
            operation_type="LOGIN",
            endpoint="/api/auth/login",
            parameters=f"ip={client_ip}",
            status="SUCCESS"
        )
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Without this, the 422/401 raised above were caught by the generic
        # handler and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"登录失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败"
        )
|
||
|
||
@app.post("/api/user/login")
async def user_login(request: Request):
    """Alias of the auth login endpoint; delegates to ``auth_login``."""
    login_result = await auth_login(request)
    return login_result
|
||
|
||
@app.post("/api/auth/refresh")
async def refresh_token(current_user: dict = Depends(get_current_user)):
    """Issue a fresh access token for the already-authenticated user.

    The new token carries the caller's username as ``sub`` and lives for
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.
    """
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {"sub": current_user["username"]}
    renewed_token = create_access_token(data=claims, expires_delta=lifetime)
    return {"access_token": renewed_token, "token_type": "bearer"}
|
||
|
||
# 账户管理路由组
|
||
@app.post("/api/dify_accounts/")
async def create_dify_account(account: AccountCreate):
    """Create a Dify account through ``AccountManager.create_account``.

    Returns the new account's ``user_id``, ``username``, ``email`` and
    ``created_at``. Any failure is logged and reported as a 400.

    NOTE(review): unlike the other account routes, this one has no
    ``get_current_user`` dependency -- confirm that it is meant to be public.
    """
    try:
        created = AccountManager.create_account(
            account.username, account.email, account.password
        )
        summary = {
            "user_id": str(created["id"]),
            "username": created["username"],
            "email": created["email"],
            "created_at": created["created_at"],
        }
        return summary
    except Exception as e:
        logger.error(f"创建Dify账户失败: {e}")
        raise HTTPException(status_code=400, detail="创建Dify账户失败")
|
||
|
||
@app.get("/api/accounts/search")
async def search_accounts(
    search: str = None,
    page: int = 1,
    page_size: int = 10,
    current_user: dict = Depends(get_current_user)
):
    """Search accounts with paging (authenticated users only).

    Delegates to ``AccountManager.search_accounts`` and reshapes its ``data``
    rows. ``status`` defaults to "active" and ``created_at`` to the current
    UTC time when a row lacks them.

    Returns:
        ``{"accounts": [...], "total": <total>}``.

    Raises:
        HTTPException: 400 on any failure (the error is logged).
    """
    try:
        found = AccountManager.search_accounts(search, page, page_size)
        rows = []
        for a in found["data"]:
            rows.append({
                "id": str(a["id"]),
                "username": a["username"],
                "email": a["email"],
                "status": a.get("status", "active"),
                "created_at": a.get("created_at", datetime.now(timezone.utc)),
            })
        return {"accounts": rows, "total": found["total"]}
    except Exception as e:
        logger.error(f"搜索账户失败: {e}")
        raise HTTPException(status_code=400, detail="搜索账户失败")
|
||
|
||
@app.get("/api/dify_accounts/{username}")
async def get_dify_account(username: str, current_user: dict = Depends(get_current_user)):
    """Return a Dify account together with its associated tenants.

    Returns:
        A dict with ``user_id``, ``username``, ``email``, ``created_at`` and
        ``tenants`` (the value of ``AccountManager.get_account_tenants``).

    Raises:
        HTTPException: 404 when the account does not exist; 400 when a
            ``ValueError`` is raised during the lookup; 500 on other failures.
    """
    try:
        account = AccountManager.get_user_by_username(username)
        if not account:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Dify账户不存在")

        # Tenant info is returned as provided by the manager.
        tenant_info = AccountManager.get_account_tenants(account["id"])

        return {
            "user_id": str(account["id"]),
            "username": account["username"],
            "email": account["email"],
            "created_at": account["created_at"],
            "tenants": tenant_info
        }
    except HTTPException:
        # Propagate the explicit 404 directly instead of relying on the
        # string match below, which depends on how HTTPException renders.
        raise
    except ValueError as e:
        logger.error(f"参数格式错误: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="参数格式错误"
        )
    except Exception as e:
        logger.error(f"查询Dify账户失败: {e}", exc_info=True)
        # NOTE(review): kept for backward compatibility -- any error message
        # containing "404" is reported as not-found; confirm this is wanted.
        if "404" in str(e):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Dify账户不存在"
            )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器内部错误"
        )
|
||
|
||
@app.put("/api/accounts/password")
async def change_password(
    password_change: PasswordChange,
    current_user: dict = Depends(get_current_user)
):
    """Change the password of the currently authenticated user.

    Verifies ``password_change.current_password`` against the stored
    credentials before calling ``AccountManager.update_password``.

    Returns:
        ``{"message": "密码修改成功"}``.

    Raises:
        HTTPException: 404 when the current user has no record in
            ``AccountManager``; 401 when the current password is wrong; 400
            on any other failure.
    """
    try:
        user = AccountManager.get_user_by_username(current_user["username"])
        # get_current_user also accepts back-office users, so the record may
        # be absent here; fail clearly instead of with a TypeError.
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )

        if not AccountManager.verify_password(
            password_change.current_password,
            user["password"],
            user["password_salt"]
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="当前密码不正确"
            )

        AccountManager.update_password(
            current_user["username"],
            current_user["email"],
            password_change.new_password
        )
        return {"message": "密码修改成功"}
    except HTTPException:
        # Keep the specific status/detail instead of collapsing it to 400.
        raise
    except Exception as e:
        logger.error(f"修改密码失败: {e}")
        raise HTTPException(status_code=400, detail="修改密码失败")
|
||
|
||
@app.post("/api/accounts/{account_id}/reset-password")
async def reset_password(
    account_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Reset a user's password (admin only).

    The caller must be the back-office user named "admin". ``account_id`` is
    stripped of whitespace and surrounding quotes and must be a canonical
    36-character UUID.

    Returns:
        ``{"message": "密码重置成功"}``.

    Raises:
        HTTPException: 403 when the caller is not admin; 400 when the id is
            not a valid UUID; 404 when ``AccountManager.reset_password``
            reports failure; 400 on any other error.
    """
    try:
        # Only the back-office "admin" user may reset passwords.
        admin_user = backend_account_manager.get_user_by_username(current_user["username"])
        if not admin_user or admin_user.get("username") != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要管理员权限"
            )

        logger.info(f"完整请求参数: account_id={account_id}")

        # Strip stray whitespace and quotes that clients sometimes send.
        clean_account_id = account_id.strip().strip('"').strip("'")
        logger.info(f"清理后的account_id: {clean_account_id}")

        try:
            # Strict check: canonical hyphenated form only.
            if len(clean_account_id) != 36 or clean_account_id.count("-") != 4:
                raise ValueError("UUID格式不正确")

            parsed_uuid = uuid.UUID(clean_account_id)
            logger.info(f"成功解析为UUID: {parsed_uuid}")
        except ValueError as e:
            logger.error(f"无效的account_id格式: {clean_account_id}, 原始值: {account_id}, 错误: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"账号ID格式无效: {clean_account_id} (必须是标准的UUID格式)"
            )

        # Use the validated, cleaned id; the original passed the raw value,
        # so quoted or padded ids passed validation and then missed the lookup.
        success = AccountManager.reset_password(clean_account_id)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="账号不存在"
            )

        return {"message": "密码重置成功"}
    except HTTPException:
        # Preserve 403/400/404 rather than collapsing them to a generic 400.
        raise
    except Exception as e:
        logger.error(f"重置密码失败: {e}")
        raise HTTPException(status_code=400, detail="重置密码失败")
|
||
|
||
# 租户管理路由组
|
||
@app.post("/api/tenants/")
async def create_tenant(tenant: TenantCreate, current_user: dict = Depends(get_current_user)):
    """Create a tenant by name (authenticated users only).

    Only ``tenant.name`` is passed to ``TenantManager.create_tenant``; the
    description is echoed back in the response, and ``created_at`` is the
    current UTC time at response construction.

    Raises:
        HTTPException: 400 on any failure (the error is logged).
    """
    try:
        new_tenant_id = TenantManager.create_tenant(tenant.name)
        tenant_payload = {
            "id": str(new_tenant_id),
            "name": tenant.name,
            "description": tenant.description,
            "created_at": datetime.now(timezone.utc),
        }
        return {"message": "租户创建成功", "tenant": tenant_payload}
    except Exception as e:
        logger.error(f"创建租户失败: {e}")
        raise HTTPException(status_code=400, detail="创建租户失败")
|
||
|
||
@app.get("/api/tenants/")
async def list_tenants(current_user: dict = Depends(get_current_user)):
    """List all tenants (authenticated users only).

    Each tenant from ``TenantManager.get_all_tenants`` is reduced to ``id``,
    ``name``, ``description`` (default "") and ``created_at`` (default: the
    current UTC time).

    Raises:
        HTTPException: 400 on any failure (the error is logged).
    """
    try:
        listing = []
        for t in TenantManager.get_all_tenants():
            listing.append({
                "id": str(t["id"]),
                "name": t["name"],
                "description": t.get("description", ""),
                "created_at": t.get("created_at", datetime.now(timezone.utc)),
            })
        return {"tenants": listing}
    except Exception as e:
        logger.error(f"查询租户列表失败: {e}")
        raise HTTPException(status_code=400, detail="查询租户列表失败")
|
||
|
||
@app.get("/api/tenants/{name}")
async def get_tenant(name: str, current_user: dict = Depends(get_current_user)):
    """Return a single tenant looked up by name.

    Returns:
        ``{"tenant": {...}}`` with ``id``, ``name``, ``description`` (default
        "") and ``created_at`` (default: the current UTC time).

    Raises:
        HTTPException: 404 when the tenant does not exist; 400 on other
            failures.
    """
    try:
        tenant = TenantManager.get_tenant_by_name(name)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        return {
            "tenant": {
                "id": str(tenant["id"]),
                "name": tenant["name"],
                "description": tenant.get("description", ""),
                "created_at": tenant.get("created_at", datetime.now(timezone.utc))
            }
        }
    except HTTPException:
        # Propagate the explicit 404 directly instead of relying on the
        # string match below, which depends on how HTTPException renders.
        raise
    except Exception as e:
        logger.error(f"查询租户失败: {e}")
        # NOTE(review): kept for backward compatibility -- any error message
        # containing "404" is reported as not-found; confirm this is wanted.
        if "404" in str(e):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )
        raise HTTPException(status_code=400, detail="查询租户失败")
|
||
|
||
# 模型管理路由组
|
||
@app.post("/api/models/upload")
async def upload_model_config(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Accept an uploaded JSON model configuration and validate it (admin only).

    The file must contain a JSON object with the keys ``model_name``,
    ``provider_name``, ``model_type`` and ``api_key``.

    Returns:
        ``{"message": "模型配置上传成功", "config": <parsed config>}``.

    Raises:
        HTTPException: 403 when the caller is not "admin"; 400 when the file
            is not valid JSON or not a JSON object; 422 when a required field
            is missing; 500 on any unexpected failure.
    """
    # json is not imported at file level, so the original handler failed with
    # NameError; a local import keeps this block self-contained.
    import json

    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )

        contents = await file.read()
        try:
            config = json.loads(contents)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError covers uploads that are not valid UTF-8/16/32.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的JSON格式"
            )

        # The field check below only makes sense for a JSON object.
        if not isinstance(config, dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的JSON格式"
            )

        required_fields = ["model_name", "provider_name", "model_type", "api_key"]
        for field in required_fields:
            if field not in config:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail=f"模型配置缺少必要字段: {field}"
                )

        # NOTE(review): the response echoes the whole config, including
        # api_key -- confirm that returning the secret is intended.
        return {"message": "模型配置上传成功", "config": config}
    except HTTPException:
        # Keep 403/400/422 intact instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"上传模型配置失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="上传模型配置失败"
        )
|
||
|
||
@app.post("/api/models/assign")
async def assign_model_to_tenant(
    model_config: ModelConfig,
    tenant_id: Optional[str] = None,
    current_user: dict = Depends(get_current_user)
):
    """Assign a model to one tenant, or to all tenants (admin only).

    When ``tenant_id`` is omitted, ``ModelManager.add_models_for_all_tenants``
    is called; otherwise the model config is added for that tenant using the
    tenant's ``encrypt_public_key``.

    Raises:
        HTTPException: 403 when the caller is not "admin"; 404 when the
            tenant does not exist; 500 on any unexpected failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )

        # No tenant given: add models for every tenant.
        # NOTE(review): model_config is not passed along on this path --
        # confirm that add_models_for_all_tenants() needs no arguments.
        if not tenant_id:
            total_added = ModelManager.add_models_for_all_tenants()
            return {"message": f"模型已成功分配给所有租户,共{total_added}个"}

        tenant = TenantManager.get_tenant_by_id(tenant_id)
        if not tenant:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="租户不存在"
            )

        ModelManager.add_model_for_tenant(tenant_id, tenant["encrypt_public_key"], model_config.dict())
        return {"message": "模型已成功分配给租户"}
    except HTTPException:
        # Keep 403/404 intact instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"分配模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="分配模型失败"
        )
|
||
|
||
@app.get("/api/models/{tenant_id}")
async def get_tenant_models(
    tenant_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Return the models of a tenant (authenticated users only).

    Returns ``{"tenant_id": ..., "models": ...}`` where ``models`` is the
    value of ``ModelManager.get_tenant_models``. Any failure is logged and
    reported as a 500.
    """
    try:
        tenant_models = ModelManager.get_tenant_models(tenant_id)
    except Exception as e:
        logger.error(f"获取租户模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取租户模型失败"
        )
    return {"tenant_id": tenant_id, "models": tenant_models}
|
||
|
||
@app.delete("/api/models/{model_id}")
async def delete_model(
    model_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete a single model by id (admin only).

    Returns:
        ``{"message": "模型删除成功"}``.

    Raises:
        HTTPException: 403 when the caller is not "admin"; 404 when
            ``ModelManager.delete_model`` reports zero affected rows; 500 on
            any unexpected failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要admin用户权限"
            )

        rows_affected = ModelManager.delete_model(model_id)
        if rows_affected == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="模型不存在"
            )
        return {"message": "模型删除成功"}
    except HTTPException:
        # Keep 403/404 intact instead of converting them to 500.
        raise
    except Exception as e:
        logger.error(f"删除模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="删除模型失败"
        )
|
||
|
||
@app.delete("/api/models/all/{model_name}")
async def delete_model_for_all_tenants(
    model_name: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete the named model from every tenant (admin only).

    Returns:
        A message containing the count returned by
        ``ModelManager.delete_specific_model_for_all_tenants``.

    Raises:
        HTTPException: 403 when the caller is not "admin"; 500 on any
            unexpected failure.
    """
    try:
        if current_user["username"] != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="需要管理员权限"
            )

        total_deleted = ModelManager.delete_specific_model_for_all_tenants(model_name)
        return {"message": f"模型已从所有租户中删除,共{total_deleted}个"}
    except HTTPException:
        # Keep the 403 intact instead of converting it to 500.
        raise
    except Exception as e:
        logger.error(f"批量删除模型失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="批量删除模型失败"
        )
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8001.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")
|