dify_admin/api/model_manager.py

375 lines
16 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
import psycopg2.extras
from database import get_db_cursor, execute_query, execute_update
from encryption import Encryption
from config import CONFIG_PATHS
from tenant_manager import TenantManager
# Module-level logger named after this module; no handlers are configured here.
logger = logging.getLogger(name=__name__)
class ModelManager:
"""模型管理类"""
# NOTE: A duplicate definition of ``add_volc_models_for_tenant`` used to sit
# here. The class defines that method again further down with the same
# behaviour, and in a class body only the last binding of a name survives,
# so this earlier copy was unreachable dead code and has been removed.
@staticmethod
def get_tenant_models(tenant_id: str):
    """Return all ``provider_models`` rows of a tenant, newest first.

    Args:
        tenant_id: Value bound to the ``WHERE tenant_id = %s`` filter.

    Returns:
        ``{"tenant_id": tenant_id, "models": [...]}``. Each model entry maps the
        selected columns by name: ``id`` is converted with ``str()``, and
        ``encrypted_config`` is JSON-decoded. A NULL or empty stored config
        yields ``None`` instead of aborting the whole listing.

    Raises:
        Any database or JSON decoding error is logged and re-raised.
    """
    query = (
        "SELECT id, provider_name, model_name, model_type, encrypted_config, "
        "is_valid, created_at, updated_at "
        "FROM provider_models "
        "WHERE tenant_id = %s "
        "ORDER BY created_at DESC;"
    )
    try:
        with get_db_cursor() as cursor:
            cursor.execute(query, (tenant_id,))
            rows = cursor.fetchall()
            return {
                "tenant_id": tenant_id,
                "models": [
                    {
                        "id": str(row[0]),
                        "provider_name": row[1],
                        "model_name": row[2],
                        "model_type": row[3],
                        # json.loads(None) raises TypeError, so a single row with a
                        # NULL/empty config used to break the listing for every
                        # model. (Assumes the column is nullable — TODO confirm.)
                        "encrypted_config": json.loads(row[4]) if row[4] else None,
                        "is_valid": row[5],
                        "created_at": row[6],
                        "updated_at": row[7],
                    }
                    for row in rows
                ],
            }
    except Exception as e:
        logger.error(f"获取租户模型失败: {e}")
        raise
@staticmethod
def add_model_for_tenant(tenant_id, public_key_pem, model_config):
    """Insert one model record into ``provider_models`` for a tenant.

    The API key is encrypted with the tenant's public key. It is then stored,
    together with the type-specific optional settings, as the JSON
    ``encrypted_config`` column.

    Args:
        tenant_id: Tenant the record belongs to.
        public_key_pem: Key passed to ``Encryption.encrypt_api_key``.
        model_config: Mapping that must contain ``model_name``,
            ``provider_name``, ``model_type`` and ``api_key``. Supported
            ``model_type`` values are ``"embeddings"``, ``"reranking"`` and
            ``"text-generation"``. Missing optional settings default to ``""``.

    Returns:
        ``True`` when a row was inserted. ``None`` when the tenant already has
        a model with this name, in which case nothing is written.

    Raises:
        ValueError: A required field is missing or ``model_type`` is unsupported.
        Any encryption or database error is logged and re-raised.
    """
    try:
        for field in ("model_name", "provider_name", "model_type", "api_key"):
            if field not in model_config:
                logger.error(f"模型配置缺少必要字段: {field}")
                raise ValueError(f"模型配置缺少必要字段: {field}")

        model_name = model_config["model_name"]
        if ModelManager.check_model_exists(tenant_id, model_name):
            logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
            return None

        # Resolve the type-specific settings *before* encrypting, so an
        # unsupported model_type fails fast without an encryption call.
        model_type = model_config["model_type"]
        if model_type in ("embeddings", "reranking"):
            extra_fields = ("endpoint_url", "context_size")
        elif model_type == "text-generation":
            extra_fields = (
                "endpoint_url",
                "mode",
                "context_size",
                "max_tokens_to_sample",
                "function_calling_type",
                "stream_function_calling",
                "vision_support",
                "stream_mode_delimiter",
            )
        else:
            raise ValueError(f"不支持的模型类型: {model_type}")

        encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["api_key"])
        # Key order: display_name, api_key, then the type-specific extras.
        encrypted_config = {
            "display_name": model_config.get("display_name", ""),
            "api_key": encrypted_api_key,
        }
        encrypted_config.update((key, model_config.get(key, "")) for key in extra_fields)

        insert_query = (
            "INSERT INTO provider_models ("
            "id, tenant_id, provider_name, model_name, model_type, "
            "encrypted_config, is_valid, created_at, updated_at"
            ") VALUES ("
            "uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()"
            ");"
        )
        with get_db_cursor() as cursor:
            # register_uuid installs psycopg2's uuid.UUID adapter. Presumably
            # tenant_id may be a uuid.UUID here — TODO confirm with TenantManager.
            psycopg2.extras.register_uuid()
            cursor.execute(
                insert_query,
                (
                    tenant_id,
                    model_config["provider_name"],
                    model_name,
                    model_type,
                    json.dumps(encrypted_config),
                ),
            )
        logger.info(f"为租户 {tenant_id} 添加模型记录成功!")
        return True
    except Exception as e:
        logger.error(f"为租户 {tenant_id} 添加模型记录失败: {e}")
        raise
@staticmethod
def add_volc_model_for_tenant(tenant_id, public_key_pem, model_config):
    """Insert one Volcengine (火山) model record into ``provider_models``.

    Only ``"text-generation"`` models are stored. ``"embeddings"`` models are
    skipped with a warning, and any other type is rejected.

    Args:
        tenant_id: Tenant the record belongs to.
        public_key_pem: Key passed to ``Encryption.encrypt_api_key``.
        model_config: Mapping that must contain ``model_name``,
            ``provider_name``, ``model_type`` and ``volc_api_key``. Optional
            keys fall back to the defaults written below.

    Returns:
        ``True`` when a row was inserted. ``None`` when the model already exists
        or is an embeddings model, in which case nothing is written.

    Raises:
        ValueError: A required field is missing or ``model_type`` is unsupported.
        Any encryption or database error is logged and re-raised.
    """
    try:
        for field in ("model_name", "provider_name", "model_type", "volc_api_key"):
            if field not in model_config:
                logger.error(f"火山模型配置缺少必要字段: {field}")
                raise ValueError(f"火山模型配置缺少必要字段: {field}")

        model_name = model_config["model_name"]
        if ModelManager.check_model_exists(tenant_id, model_name):
            logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
            return None

        # Decide on the model type before encrypting. Skipped or unsupported
        # models no longer trigger (or fail on) an unnecessary encryption call.
        model_type = model_config["model_type"]
        if model_type == "embeddings":
            logger.warning("火山模型不支持embeddings类型跳过添加。")
            return None
        if model_type != "text-generation":
            raise ValueError(f"不支持的模型类型: {model_type}")

        encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["volc_api_key"])
        encrypted_config = {
            "auth_method": model_config.get("auth_method", "api_key"),
            "volc_api_key": encrypted_api_key,
            "volc_region": model_config.get("volc_region", "cn-beijing"),
            "api_endpoint_host": model_config.get("api_endpoint_host", "https://ark.cn-beijing.volces.com/api/v3"),
            "endpoint_id": model_config.get("endpoint_id", ""),
            "base_model_name": model_config.get("base_model_name", ""),
        }

        insert_query = (
            "INSERT INTO provider_models ("
            "id, tenant_id, provider_name, model_name, model_type, "
            "encrypted_config, is_valid, created_at, updated_at"
            ") VALUES ("
            "uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()"
            ");"
        )
        with get_db_cursor() as cursor:
            # register_uuid installs psycopg2's uuid.UUID adapter. Presumably
            # tenant_id may be a uuid.UUID here — TODO confirm with TenantManager.
            psycopg2.extras.register_uuid()
            cursor.execute(
                insert_query,
                (
                    tenant_id,
                    model_config["provider_name"],
                    model_name,
                    model_type,
                    json.dumps(encrypted_config),
                ),
            )
        logger.info(f"为租户 {tenant_id} 添加火山模型记录成功!")
        return True
    except Exception as e:
        logger.error(f"为租户 {tenant_id} 添加火山模型记录失败: {e}")
        raise
@staticmethod
def delete_models_for_tenant(tenant_id):
    """Remove every ``provider_models`` row owned by ``tenant_id``.

    Returns:
        The count reported by ``execute_update`` for the DELETE statement.

    Raises:
        Any error from ``execute_update`` is logged and re-raised.
    """
    sql = (
        "\n"
        "DELETE FROM provider_models\n"
        "WHERE tenant_id = %s;\n"
    )
    try:
        removed = execute_update(sql, (tenant_id,))
        logger.info(f"租户 {tenant_id} 下的所有模型记录已删除,共 {removed} 条。")
        return removed
    except Exception as e:
        logger.error(f"删除租户 {tenant_id} 的模型记录失败: {e}")
        raise
@staticmethod
def delete_model_for_tenant(tenant_id, model_name):
    """Delete the rows named ``model_name`` that belong to one tenant.

    Returns:
        The count reported by ``execute_update``. An info line is logged when
        it is positive, otherwise a warning that the model was not found.

    Raises:
        Any error is logged and re-raised.
    """
    sql = (
        "\n"
        "DELETE FROM provider_models\n"
        "WHERE tenant_id = %s AND model_name = %s;\n"
    )
    try:
        removed = execute_update(sql, (tenant_id, model_name))
        found = removed > 0
        if found:
            logger.info(f"租户 {tenant_id} 下的模型 {model_name} 已删除。")
        else:
            logger.warning(f"租户 {tenant_id} 下不存在模型 {model_name}")
        return removed
    except Exception as e:
        logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}")
        raise
@staticmethod
def delete_specific_model_for_all_tenants(model_name):
    """Delete the model named ``model_name`` from every tenant.

    The operation is best-effort per tenant. A failure for one tenant is
    remembered and the loop continues with the next. After the loop, one
    warning lists every tenant that failed.

    Returns:
        Total rows deleted across the tenants that succeeded.

    Raises:
        Errors outside the per-tenant deletes (e.g. listing tenants) are logged
        and re-raised.
    """
    try:
        total_deleted = 0
        failed_tenants = []
        for tenant in TenantManager.get_all_tenants():
            tenant_id = tenant['id']
            try:
                total_deleted += ModelManager.delete_model_for_tenant(tenant_id, model_name)
            except Exception:
                # delete_model_for_tenant already logs this failure with the same
                # message before re-raising, so logging it again here only
                # duplicated the line. Just record the tenant and move on.
                failed_tenants.append(tenant_id)
        logger.info(f"所有租户下的模型 {model_name} 已删除,共 {total_deleted} 条。")
        if failed_tenants:
            logger.warning(f"删除模型 {model_name} 时有 {len(failed_tenants)} 个租户失败: {failed_tenants}")
        return total_deleted
    except Exception as e:
        logger.error(f"删除所有租户下的模型 {model_name} 失败: {e}")
        raise
@staticmethod
def add_models_for_all_tenants(config_path=CONFIG_PATHS['model_config']):
    """Add every model listed in a config file to every tenant.

    Args:
        config_path: Path passed to ``ModelManager.load_config``. Models are
            read from the config's ``"models"`` list (empty when absent).

    Returns:
        Number of calls to ``add_model_for_tenant`` that returned a truthy value.
        Skipped duplicates are not counted.

    Raises:
        The first error from loading the config, listing tenants or adding a
        model is logged and re-raised, aborting the remaining work.
    """
    try:
        model_configs = ModelManager.load_config(config_path).get("models", [])
        inserted = 0
        for tenant in TenantManager.get_all_tenants():
            tenant_id, public_key_pem = tenant['id'], tenant['encrypt_public_key']
            inserted += sum(
                1
                for model_config in model_configs
                if ModelManager.add_model_for_tenant(tenant_id, public_key_pem, model_config)
            )
        logger.info(f"为所有租户添加模型完成,共添加 {inserted} 条记录。")
        return inserted
    except Exception as e:
        logger.error(f"为所有租户添加模型失败: {e}")
        raise
@staticmethod
def add_volc_models_for_all_tenants(config_path=CONFIG_PATHS['volc_model_config']):
    """Add every Volcengine (火山) model from a config file to every tenant.

    Args:
        config_path: Path passed to ``ModelManager.load_config``. Models are
            read from the config's ``"models"`` list (empty when absent).

    Returns:
        Number of calls to ``add_volc_model_for_tenant`` that returned a truthy
        value. Skipped models are not counted.

    Raises:
        The first error encountered is logged and re-raised, aborting the
        remaining work.
    """
    try:
        model_configs = ModelManager.load_config(config_path).get("models", [])
        inserted = 0
        for tenant in TenantManager.get_all_tenants():
            tenant_id = tenant['id']
            public_key_pem = tenant['encrypt_public_key']
            for model_config in model_configs:
                added = ModelManager.add_volc_model_for_tenant(tenant_id, public_key_pem, model_config)
                inserted += 1 if added else 0
        logger.info(f"为所有租户添加火山模型完成,共添加 {inserted} 条记录。")
        return inserted
    except Exception as e:
        logger.error(f"为所有租户添加火山模型失败: {e}")
        raise
@staticmethod
def add_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['model_config']):
    """Add every model listed in a config file to the tenant named ``tenant_name``.

    Args:
        tenant_name: Name looked up via ``TenantManager.get_tenant_by_name``.
        config_path: Path passed to ``ModelManager.load_config``.

    Returns:
        Number of calls to ``add_model_for_tenant`` that returned a truthy value.
        Returns ``0`` (after logging an error) when the tenant is not found.

    Raises:
        Any other error is logged and re-raised.
    """
    try:
        model_configs = ModelManager.load_config(config_path).get("models", [])
        tenant = TenantManager.get_tenant_by_name(tenant_name)
        if not tenant:
            logger.error(f"未找到租户: {tenant_name}")
            return 0
        inserted = sum(
            1
            for model_config in model_configs
            if ModelManager.add_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config)
        )
        logger.info(f"为租户 {tenant_name} 添加模型完成,共添加 {inserted} 条记录。")
        return inserted
    except Exception as e:
        logger.error(f"为租户 {tenant_name} 添加模型失败: {e}")
        raise
@staticmethod
def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']):
    """Add every Volcengine (火山) model from a config file to one tenant.

    Args:
        tenant_name: Name looked up via ``TenantManager.get_tenant_by_name``.
        config_path: Path passed to ``ModelManager.load_config``.

    Returns:
        Number of calls to ``add_volc_model_for_tenant`` that returned a truthy
        value. Returns ``0`` (after logging an error) when the tenant is not
        found.

    Raises:
        Any other error is logged and re-raised.
    """
    try:
        model_configs = ModelManager.load_config(config_path).get("models", [])
        tenant = TenantManager.get_tenant_by_name(tenant_name)
        if not tenant:
            logger.error(f"未找到租户: {tenant_name}")
            return 0
        inserted = 0
        for model_config in model_configs:
            if not ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
                continue
            inserted += 1
        logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {inserted} 条记录。")
        return inserted
    except Exception as e:
        logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}")
        raise