dify_admin/api/model_manager.py
xh.xin 96480a27a9 初始化项目仓库,包含基础结构和开发计划
1. 添加README说明项目结构
2. 配置Python和Node.js的.gitignore
3. 包含认证模块和账号管理的前后端基础代码
4. 开发计划文档记录当前阶段任务
2025-05-02 18:33:06 +08:00

336 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
import psycopg2.extras
from database import get_db_cursor, execute_query, execute_update
from encryption import Encryption
from config import CONFIG_PATHS
from tenant_manager import TenantManager
# 配置日志
logger = logging.getLogger(__name__)
class ModelManager:
"""模型管理类"""
@staticmethod
def load_config(config_path):
"""加载模型配置文件"""
try:
if not os.path.exists(config_path):
raise FileNotFoundError(f"配置文件 {config_path} 不存在!")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
raise
@staticmethod
def check_model_exists(tenant_id, model_name):
"""检查指定租户是否已存在指定的模型"""
try:
query = """
SELECT COUNT(*) FROM provider_models
WHERE tenant_id = %s AND model_name = %s;
"""
count = execute_query(query, (tenant_id, model_name), fetch_one=True)[0]
return count > 0
except Exception as e:
logger.error(f"检查模型是否存在时发生错误: {e}")
return False
@staticmethod
def add_model_for_tenant(tenant_id, public_key_pem, model_config):
"""为指定租户添加模型记录"""
try:
# 检查必要字段
required_fields = ["model_name", "provider_name", "model_type", "api_key"]
for field in required_fields:
if field not in model_config:
logger.error(f"模型配置缺少必要字段: {field}")
raise ValueError(f"模型配置缺少必要字段: {field}")
# 检查模型是否已存在
model_name = model_config["model_name"]
if ModelManager.check_model_exists(tenant_id, model_name):
logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
return
# 加密API密钥
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["api_key"])
# 根据模型类型构造加密配置
if model_config["model_type"] in ["embeddings", "reranking"]:
encrypted_config = {
"display_name": model_config.get("display_name", ""),
"api_key": encrypted_api_key,
"endpoint_url": model_config.get("endpoint_url", ""),
"context_size": model_config.get("context_size", "")
}
elif model_config["model_type"] == "text-generation":
encrypted_config = {
"display_name": model_config.get("display_name", ""),
"api_key": encrypted_api_key,
"endpoint_url": model_config.get("endpoint_url", ""),
"mode": model_config.get("mode", ""),
"context_size": model_config.get("context_size", ""),
"max_tokens_to_sample": model_config.get("max_tokens_to_sample", ""),
"function_calling_type": model_config.get("function_calling_type", ""),
"stream_function_calling": model_config.get("stream_function_calling", ""),
"vision_support": model_config.get("vision_support", ""),
"stream_mode_delimiter": model_config.get("stream_mode_delimiter", "")
}
else:
raise ValueError(f"不支持的模型类型: {model_config['model_type']}")
# 插入模型记录
with get_db_cursor() as cursor:
psycopg2.extras.register_uuid()
insert_query = """
INSERT INTO provider_models (
id, tenant_id, provider_name, model_name, model_type,
encrypted_config, is_valid, created_at, updated_at
) VALUES (
uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()
);
"""
cursor.execute(
insert_query, (
tenant_id,
model_config["provider_name"],
model_config["model_name"],
model_config["model_type"],
json.dumps(encrypted_config)
)
)
logger.info(f"为租户 {tenant_id} 添加模型记录成功!")
return True
except Exception as e:
logger.error(f"为租户 {tenant_id} 添加模型记录失败: {e}")
raise
@staticmethod
def add_volc_model_for_tenant(tenant_id, public_key_pem, model_config):
"""为指定租户添加火山模型记录"""
try:
# 检查必要字段
required_fields = ["model_name", "provider_name", "model_type", "volc_api_key"]
for field in required_fields:
if field not in model_config:
logger.error(f"火山模型配置缺少必要字段: {field}")
raise ValueError(f"火山模型配置缺少必要字段: {field}")
# 检查模型是否已存在
model_name = model_config["model_name"]
if ModelManager.check_model_exists(tenant_id, model_name):
logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
return
# 加密API密钥
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["volc_api_key"])
# 根据模型类型构造加密配置
if model_config["model_type"] == "embeddings":
logger.warning(f"火山模型不支持embeddings类型跳过添加。")
return
elif model_config["model_type"] == "text-generation":
encrypted_config = {
"auth_method": model_config.get("auth_method", "api_key"),
"volc_api_key": encrypted_api_key,
"volc_region": model_config.get("volc_region", "cn-beijing"),
"api_endpoint_host": model_config.get("api_endpoint_host", "https://ark.cn-beijing.volces.com/api/v3"),
"endpoint_id": model_config.get("endpoint_id", ""),
"base_model_name": model_config.get("base_model_name", ""),
}
else:
raise ValueError(f"不支持的模型类型: {model_config['model_type']}")
# 插入模型记录
with get_db_cursor() as cursor:
psycopg2.extras.register_uuid()
insert_query = """
INSERT INTO provider_models (
id, tenant_id, provider_name, model_name, model_type,
encrypted_config, is_valid, created_at, updated_at
) VALUES (
uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()
);
"""
cursor.execute(
insert_query, (
tenant_id,
model_config["provider_name"],
model_config["model_name"],
model_config["model_type"],
json.dumps(encrypted_config)
)
)
logger.info(f"为租户 {tenant_id} 添加火山模型记录成功!")
return True
except Exception as e:
logger.error(f"为租户 {tenant_id} 添加火山模型记录失败: {e}")
raise
@staticmethod
def delete_models_for_tenant(tenant_id):
"""删除指定租户下的所有模型记录"""
try:
query = """
DELETE FROM provider_models
WHERE tenant_id = %s;
"""
rows_affected = execute_update(query, (tenant_id,))
logger.info(f"租户 {tenant_id} 下的所有模型记录已删除,共 {rows_affected} 条。")
return rows_affected
except Exception as e:
logger.error(f"删除租户 {tenant_id} 的模型记录失败: {e}")
raise
@staticmethod
def delete_model_for_tenant(tenant_id, model_name):
"""删除指定租户下的特定模型记录"""
try:
query = """
DELETE FROM provider_models
WHERE tenant_id = %s AND model_name = %s;
"""
rows_affected = execute_update(query, (tenant_id, model_name))
if rows_affected > 0:
logger.info(f"租户 {tenant_id} 下的模型 {model_name} 已删除。")
else:
logger.warning(f"租户 {tenant_id} 下不存在模型 {model_name}")
return rows_affected
except Exception as e:
logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}")
raise
@staticmethod
def delete_specific_model_for_all_tenants(model_name):
"""删除所有租户下的特定模型记录"""
try:
tenants = TenantManager.get_all_tenants()
total_deleted = 0
for tenant in tenants:
tenant_id = tenant['id']
try:
rows_affected = ModelManager.delete_model_for_tenant(tenant_id, model_name)
total_deleted += rows_affected
except Exception as e:
logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}")
logger.info(f"所有租户下的模型 {model_name} 已删除,共 {total_deleted} 条。")
return total_deleted
except Exception as e:
logger.error(f"删除所有租户下的模型 {model_name} 失败: {e}")
raise
@staticmethod
def add_models_for_all_tenants(config_path=CONFIG_PATHS['model_config']):
"""为所有租户添加模型"""
try:
# 加载模型配置
config = ModelManager.load_config(config_path)
models = config.get("models", [])
# 获取所有租户
tenants = TenantManager.get_all_tenants()
total_added = 0
# 为每个租户添加模型
for tenant in tenants:
tenant_id = tenant['id']
public_key_pem = tenant['encrypt_public_key']
for model_config in models:
if ModelManager.add_model_for_tenant(tenant_id, public_key_pem, model_config):
total_added += 1
logger.info(f"为所有租户添加模型完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为所有租户添加模型失败: {e}")
raise
@staticmethod
def add_volc_models_for_all_tenants(config_path=CONFIG_PATHS['volc_model_config']):
"""为所有租户添加火山模型"""
try:
# 加载模型配置
config = ModelManager.load_config(config_path)
models = config.get("models", [])
# 获取所有租户
tenants = TenantManager.get_all_tenants()
total_added = 0
# 为每个租户添加模型
for tenant in tenants:
tenant_id = tenant['id']
public_key_pem = tenant['encrypt_public_key']
for model_config in models:
if ModelManager.add_volc_model_for_tenant(tenant_id, public_key_pem, model_config):
total_added += 1
logger.info(f"为所有租户添加火山模型完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为所有租户添加火山模型失败: {e}")
raise
@staticmethod
def add_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['model_config']):
"""为指定租户添加模型"""
try:
# 加载模型配置
config = ModelManager.load_config(config_path)
models = config.get("models", [])
# 获取租户信息
tenant = TenantManager.get_tenant_by_name(tenant_name)
if not tenant:
logger.error(f"未找到租户: {tenant_name}")
return 0
# 为租户添加模型
total_added = 0
for model_config in models:
if ModelManager.add_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
total_added += 1
logger.info(f"为租户 {tenant_name} 添加模型完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为租户 {tenant_name} 添加模型失败: {e}")
raise
@staticmethod
def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']):
"""为指定租户添加火山模型"""
try:
# 加载模型配置
config = ModelManager.load_config(config_path)
models = config.get("models", [])
# 获取租户信息
tenant = TenantManager.get_tenant_by_name(tenant_name)
if not tenant:
logger.error(f"未找到租户: {tenant_name}")
return 0
# 为租户添加模型
total_added = 0
for model_config in models:
if ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
total_added += 1
logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}")
raise