From b5fa1b0d637dac8e2975b59fd067fbc7b61bf6be Mon Sep 17 00:00:00 2001 From: Xin Date: Wed, 7 May 2025 19:08:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dtoken=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E9=97=AE=E9=A2=98=EF=BC=8C=E7=A7=9F=E6=88=B7=E7=AE=A1?= =?UTF-8?q?=E7=90=86API=E5=AD=98=E5=9C=A8401=E6=9C=AA=E6=8E=88=E6=9D=83?= =?UTF-8?q?=E9=97=AE=E9=A2=98=E9=9C=80=E8=A6=81=E8=BF=9B=E4=B8=80=E6=AD=A5?= =?UTF-8?q?=E6=8E=92=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app.py | 238 +++++++++++++++++++++++++++++----- api/model_manager.py | 72 +++++++--- api/models.py | 27 +++- web/src/api/auth/index.ts | 6 +- web/src/api/model/index.ts | 50 +++++++ web/src/api/model/types.ts | 30 +++++ web/src/api/tenant/types.ts | 7 + web/src/axios/service.ts | 8 +- web/src/router/index.ts | 3 +- web/src/views/Auth/Login.vue | 22 ++-- web/src/views/Model/index.vue | 237 ++++++++++++++++++++++++++++++++- 11 files changed, 631 insertions(+), 69 deletions(-) create mode 100644 web/src/api/model/index.ts create mode 100644 web/src/api/model/types.ts diff --git a/api/app.py b/api/app.py index cc07e1b..c983fae 100644 --- a/api/app.py +++ b/api/app.py @@ -1,9 +1,10 @@ import uuid from datetime import datetime, timedelta, timezone -from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response +from fastapi import FastAPI, Depends, HTTPException, status, Request, Body, Response, UploadFile, File from fastapi.middleware.cors import CORSMiddleware from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse +from models import AccountCreate, AccountResponse, PasswordChange, TenantCreate, TenantResponse, ModelConfig +from model_manager import ModelManager from account_manager import AccountManager as DifyAccountManager from backend_account_manager import BackendAccountManager @@ -48,11 +49,12 @@ except Exception as e: # 添加CORS中间件 app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "http://localhost:3001", "http://127.0.0.1:3001"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], - expose_headers=["*"] + expose_headers=["*"], + max_age=600 ) api_auth = APIAuthManager() op_logger = OperationLogger() @@ -111,6 +113,9 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): headers={"WWW-Authenticate": "Bearer"}, ) try: + # 打印接收到的token用于调试 + logger.info(f"Received token: {token[:10]}...{token[-10:]}") + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: @@ -122,11 +127,18 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): if not user: raise credentials_exception + logger.info(f"Authenticated user: {username}") return { "id": str(user.get("id")), "username": user.get("username"), "email": user.get("email", "") } + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token已过期", + headers={"WWW-Authenticate": "Bearer"}, + ) except JWTError: raise credentials_exception @@ -176,41 +188,59 @@ async def auth_register(request: Request): ) @app.post("/api/auth/login") -async def auth_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()): +async def auth_login(request: Request): """用户登录(auth)""" - client_ip = request.client.host if request.client else "unknown" - user = backend_account_manager.get_user_by_username(form_data.username) - if not user or not backend_account_manager.verify_password(form_data.password, user["password"], user["password_salt"]): + try: + data = await request.json() + username = data.get("username") + password = data.get("password") + + if not username or not password: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="缺少用户名或密码" + ) + + client_ip = request.client.host if request.client else "unknown" + user = backend_account_manager.get_user_by_username(username) + if not user or not backend_account_manager.verify_password(password, user["password"], user["password_salt"]): + op_logger.log_operation( + user_id=0, + operation_type="LOGIN_ATTEMPT", + endpoint="/api/auth/login", + parameters=f"username={username}, ip={client_ip}", + status="FAILED" + ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="用户名或密码错误", + headers={"WWW-Authenticate": "Bearer"}, + ) + + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": user["username"]}, + expires_delta=access_token_expires + ) op_logger.log_operation( - user_id=0, - operation_type="LOGIN_ATTEMPT", + user_id=user["id"], + operation_type="LOGIN", endpoint="/api/auth/login", - parameters=f"username={form_data.username}, ip={client_ip}", - status="FAILED" + parameters=f"ip={client_ip}", + status="SUCCESS" ) + return {"access_token": access_token, "token_type": "bearer"} + except Exception as e: + logger.error(f"登录失败: {e}") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="用户名或密码错误", - headers={"WWW-Authenticate": "Bearer"}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="登录失败" ) - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": user["username"]}, - expires_delta=access_token_expires - ) - op_logger.log_operation( - user_id=user["id"], - operation_type="LOGIN", - endpoint="/api/auth/login", - parameters=f"ip={client_ip}", - status="SUCCESS" - ) - return {"access_token": access_token, "token_type": "bearer"} @app.post("/api/user/login") -async def user_login(request: Request, form_data: OAuth2PasswordRequestForm = Depends()): +async def user_login(request: Request): """用户登录(user)""" - return await auth_login(request, form_data) + return await auth_login(request) @app.post("/api/auth/refresh") async def refresh_token(current_user: dict = Depends(get_current_user)): @@ -438,6 +468,154 @@ async def get_tenant(name: str, current_user: dict = Depends(get_current_user)): ) raise HTTPException(status_code=400, detail="查询租户失败") +# 模型管理路由组 +@app.post("/api/models/upload") +async def upload_model_config( + file: UploadFile = File(...), + current_user: dict = Depends(get_current_user) +): + """上传模型配置文件""" + try: + # 检查是否为admin用户 + if current_user["username"] != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要admin用户权限" + ) + + # 读取上传的文件内容 + contents = await file.read() + config = json.loads(contents) + + # 验证模型配置 + required_fields = ["model_name", "provider_name", "model_type", "api_key"] + for field in required_fields: + if field not in config: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"模型配置缺少必要字段: {field}" + ) + + return {"message": "模型配置上传成功", "config": config} + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的JSON格式" + ) + except Exception as e: + logger.error(f"上传模型配置失败: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="上传模型配置失败" + ) + +@app.post("/api/models/assign") +async def assign_model_to_tenant( + model_config: ModelConfig, + tenant_id: Optional[str] = None, + current_user: dict = Depends(get_current_user) +): + """分配模型给租户""" + try: + # 检查是否为admin用户 + if current_user["username"] != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要admin用户权限" + ) + + # 如果未指定租户ID,则为所有租户添加模型 + if not tenant_id: + total_added = ModelManager.add_models_for_all_tenants() + return {"message": f"模型已成功分配给所有租户,共{total_added}个"} + + # 为指定租户添加模型 + tenant = TenantManager.get_tenant_by_id(tenant_id) + if not tenant: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="租户不存在" + ) + + ModelManager.add_model_for_tenant(tenant_id, tenant["encrypt_public_key"], model_config.dict()) + return {"message": "模型已成功分配给租户"} + except Exception as e: + logger.error(f"分配模型失败: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="分配模型失败" + ) + +@app.get("/api/models/{tenant_id}") +async def get_tenant_models( + tenant_id: str, + current_user: dict = Depends(get_current_user) +): + """获取租户的模型列表""" + try: + models = ModelManager.get_tenant_models(tenant_id) + return { + "tenant_id": tenant_id, + "models": models + } + except Exception as e: + logger.error(f"获取租户模型失败: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="获取租户模型失败" + ) + +@app.delete("/api/models/{model_id}") +async def delete_model( + model_id: str, + current_user: dict = Depends(get_current_user) +): + """删除特定模型""" + try: + # 检查是否为admin用户 + if current_user["username"] != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要admin用户权限" + ) + + rows_affected = ModelManager.delete_model(model_id) + if rows_affected == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="模型不存在" + ) + return {"message": "模型删除成功"} + except Exception as e: + logger.error(f"删除模型失败: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="删除模型失败" + ) + +@app.delete("/api/models/all/{model_name}") +async def delete_model_for_all_tenants( + model_name: str, + current_user: dict = Depends(get_current_user) +): + """删除所有租户的特定模型""" + try: + # 检查管理员权限 + if current_user["username"] != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="需要管理员权限" + ) + + total_deleted = ModelManager.delete_specific_model_for_all_tenants(model_name) + return {"message": f"模型已从所有租户中删除,共{total_deleted}个"} + except Exception as e: + logger.error(f"批量删除模型失败: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="批量删除模型失败" + ) + if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/api/model_manager.py b/api/model_manager.py index 6de3f62..183dd6b 100644 --- a/api/model_manager.py +++ b/api/model_manager.py @@ -14,32 +14,68 @@ class ModelManager: """模型管理类""" @staticmethod - def load_config(config_path): - """加载模型配置文件""" + def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']): + """为指定租户添加火山模型""" try: - if not os.path.exists(config_path): - raise FileNotFoundError(f"配置文件 {config_path} 不存在!") - with open(config_path, "r", encoding="utf-8") as f: - config = json.load(f) - return config + # 加载模型配置 + config = ModelManager.load_config(config_path) + models = config.get("models", []) + + # 获取租户信息 + tenant = TenantManager.get_tenant_by_name(tenant_name) + if not tenant: + logger.error(f"未找到租户: {tenant_name}") + return 0 + + # 为租户添加模型 + total_added = 0 + for model_config in models: + if ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config): + total_added += 1 + + logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {total_added} 条记录。") + return total_added except Exception as e: - logger.error(f"加载配置文件失败: {e}") + logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}") raise - + @staticmethod - def check_model_exists(tenant_id, model_name): - """检查指定租户是否已存在指定的模型""" + def get_tenant_models(tenant_id: str): + """获取租户的模型列表""" try: query = """ - SELECT COUNT(*) FROM provider_models - WHERE tenant_id = %s AND model_name = %s; + SELECT + id, + provider_name, + model_name, + model_type, + encrypted_config, + is_valid, + created_at, + updated_at + FROM provider_models + WHERE tenant_id = %s + ORDER BY created_at DESC; """ - count = execute_query(query, (tenant_id, model_name), fetch_one=True)[0] - return count > 0 + with get_db_cursor() as cursor: + cursor.execute(query, (tenant_id,)) + models = cursor.fetchall() + + return [{ + "id": str(model[0]), + "provider_name": model[1], + "model_name": model[2], + "model_type": model[3], + "encrypted_config": json.loads(model[4]), + "is_valid": model[5], + "created_at": model[6], + "updated_at": model[7] + } for model in models] except Exception as e: - logger.error(f"检查模型是否存在时发生错误: {e}") - return False - + logger.error(f"获取租户模型失败: {e}") + raise + + @staticmethod def add_model_for_tenant(tenant_id, public_key_pem, model_config): """为指定租户添加模型记录""" diff --git a/api/models.py b/api/models.py index 3d6fff6..c9b81db 100644 --- a/api/models.py +++ b/api/models.py @@ -1,7 +1,32 @@ from pydantic import BaseModel, EmailStr -from typing import Optional +from typing import Optional, List from datetime import datetime +class ModelConfig(BaseModel): + """模型配置上传模型""" + model_name: str + provider_name: str + model_type: str + api_key: str + endpoint_url: Optional[str] = None + display_name: Optional[str] = None + context_size: Optional[int] = None + max_tokens_to_sample: Optional[int] = None + +class ModelResponse(BaseModel): + """模型响应模型""" + id: str + model_name: str + provider_name: str + model_type: str + created_at: datetime + +class TenantModelResponse(BaseModel): + """租户模型响应模型""" + tenant_id: str + tenant_name: str + models: List[ModelResponse] + class AccountCreate(BaseModel): """创建账户请求模型""" username: str diff --git a/web/src/api/auth/index.ts b/web/src/api/auth/index.ts index 72de31b..efa7502 100644 --- a/web/src/api/auth/index.ts +++ b/web/src/api/auth/index.ts @@ -1,13 +1,13 @@ import { request } from '../../axios/service' import type { LoginParams, RegisterParams, LoginForm } from '@/api/auth/types' -export const login = (formData: FormData | LoginForm) => +export const login = (data: { username: string; password: string }) => request<{ access_token: string }>({ method: 'POST', url: '/api/auth/login', - data: formData, + data: data, headers: { - 'Content-Type': 'multipart/form-data' + 'Content-Type': 'application/json' } }) diff --git a/web/src/api/model/index.ts b/web/src/api/model/index.ts new file mode 100644 index 0000000..fb89532 --- /dev/null +++ b/web/src/api/model/index.ts @@ -0,0 +1,50 @@ +import { request } from '@/axios/service' +import type { ApiResponse, ModelConfig, ModelResponse, TenantModelResponse } from './types.ts' + +export function uploadModelConfig(file: File): Promise> { + const formData = new FormData() + formData.append('file', file) + return request({ + url: '/api/models/upload', + method: 'post', + data: formData, + headers: { + 'Content-Type': 'multipart/form-data' + } + }) +} + +export function assignModelToTenant( + modelConfig: ModelConfig, + tenantId?: string +): Promise> { + return request({ + url: '/api/models/assign', + method: 'post', + data: { + model_config: modelConfig, + tenant_id: tenantId + } + }) +} + +export function getTenantModels(tenantId: string): Promise> { + return request({ + url: `/api/models/${tenantId}`, + method: 'get' + }) +} + +export function deleteModel(modelId: string): Promise> { + return request({ + url: `/api/models/${modelId}`, + method: 'delete' + }) +} + +export function deleteModelForAllTenants(modelName: string): Promise> { + return request({ + url: `/api/models/all/${modelName}`, + method: 'delete' + }) +} diff --git a/web/src/api/model/types.ts b/web/src/api/model/types.ts new file mode 100644 index 0000000..8d6ee6d --- /dev/null +++ b/web/src/api/model/types.ts @@ -0,0 +1,30 @@ +export interface ApiResponse { + code: number + message: string + data: T +} + +export interface ModelConfig { + model_name: string + provider_name: string + model_type: string + api_key: string + endpoint_url?: string + display_name?: string + context_size?: number + max_tokens_to_sample?: number +} + +export interface ModelResponse { + id: string + model_name: string + provider_name: string + model_type: string + created_at: string +} + +export interface TenantModelResponse { + tenant_id: string + tenant_name: string + models: ModelResponse[] +} diff --git a/web/src/api/tenant/types.ts b/web/src/api/tenant/types.ts index c474da8..e20871a 100644 --- a/web/src/api/tenant/types.ts +++ b/web/src/api/tenant/types.ts @@ -10,8 +10,15 @@ export interface TenantForm { description: string } +export interface TenantListParams { + search?: string + page?: number + pageSize?: number +} + export interface TenantListResponse { tenants: TenantItem[] + total?: number } export interface TenantDetailResponse { diff --git a/web/src/axios/service.ts b/web/src/axios/service.ts index 910daa7..1177ec4 100644 --- a/web/src/axios/service.ts +++ b/web/src/axios/service.ts @@ -7,10 +7,12 @@ const service = createAxios() // 请求拦截器 service.interceptors.request.use( (config) => { - const userStore = useUserStore() - if (userStore.token) { - config.headers.Authorization = `Bearer ${userStore.token}` + const token = localStorage.getItem('access_token') + if (token) { + config.headers.Authorization = `Bearer ${token}` } + config.headers['Content-Type'] = 'application/json' + config.withCredentials = true return config }, (error) => { diff --git a/web/src/router/index.ts b/web/src/router/index.ts index babb015..a84e5ac 100644 --- a/web/src/router/index.ts +++ b/web/src/router/index.ts @@ -32,7 +32,8 @@ const router = createRouter({ }, { path: 'model', - component: () => import('../views/Model/index.vue') + component: () => import('../views/Model/index.vue'), + meta: { requiresAuth: true } } ] } diff --git a/web/src/views/Auth/Login.vue b/web/src/views/Auth/Login.vue index aab476b..59f96dc 100644 --- a/web/src/views/Auth/Login.vue +++ b/web/src/views/Auth/Login.vue @@ -72,20 +72,20 @@ const loginRules = { const loading = ref(false) const loginFormRef = ref() -const handleLogin = async () => { - try { - loading.value = true - await loginFormRef.value.validate() - const encryptedPassword = await encryptPassword(loginForm.value.password) - const formData = new FormData() - formData.append('username', loginForm.value.username) - formData.append('password', encryptedPassword) - - const res = await login(formData as LoginFormData) + const handleLogin = async () => { + try { + loading.value = true + await loginFormRef.value.validate() + const encryptedPassword = await encryptPassword(loginForm.value.password) + + const res = await login({ + username: loginForm.value.username, + password: encryptedPassword + }) console.log('登录成功:', res) // 确保token存储完成后再跳转 - const token = res.access_token || res.token + const token = res.access_token setToken(token) localStorage.setItem('token', token) setupTokenRefresh() diff --git a/web/src/views/Model/index.vue b/web/src/views/Model/index.vue index bdf0352..f19c56c 100644 --- a/web/src/views/Model/index.vue +++ b/web/src/views/Model/index.vue @@ -1,16 +1,249 @@