375 lines
16 KiB
Python
Executable File
375 lines
16 KiB
Python
Executable File
import os
|
||
import json
|
||
import logging
|
||
import psycopg2.extras
|
||
from database import get_db_cursor, execute_query, execute_update
|
||
from encryption import Encryption
|
||
from config import CONFIG_PATHS
|
||
from tenant_manager import TenantManager
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class ModelManager:
|
||
"""模型管理类"""
|
||
|
||
@staticmethod
|
||
def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']):
|
||
"""为指定租户添加火山模型"""
|
||
try:
|
||
# 加载模型配置
|
||
config = ModelManager.load_config(config_path)
|
||
models = config.get("models", [])
|
||
|
||
# 获取租户信息
|
||
tenant = TenantManager.get_tenant_by_name(tenant_name)
|
||
if not tenant:
|
||
logger.error(f"未找到租户: {tenant_name}")
|
||
return 0
|
||
|
||
# 为租户添加模型
|
||
total_added = 0
|
||
for model_config in models:
|
||
if ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def get_tenant_models(tenant_id: str):
|
||
"""获取租户的模型列表"""
|
||
try:
|
||
query = """
|
||
SELECT
|
||
id,
|
||
provider_name,
|
||
model_name,
|
||
model_type,
|
||
encrypted_config,
|
||
is_valid,
|
||
created_at,
|
||
updated_at
|
||
FROM provider_models
|
||
WHERE tenant_id = %s
|
||
ORDER BY created_at DESC;
|
||
"""
|
||
with get_db_cursor() as cursor:
|
||
cursor.execute(query, (tenant_id,))
|
||
models = cursor.fetchall()
|
||
|
||
return {
|
||
"tenant_id": tenant_id,
|
||
"models": [{
|
||
"id": str(model[0]),
|
||
"provider_name": model[1],
|
||
"model_name": model[2],
|
||
"model_type": model[3],
|
||
"encrypted_config": json.loads(model[4]),
|
||
"is_valid": model[5],
|
||
"created_at": model[6],
|
||
"updated_at": model[7]
|
||
} for model in models]
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"获取租户模型失败: {e}")
|
||
raise
|
||
|
||
|
||
@staticmethod
|
||
def add_model_for_tenant(tenant_id, public_key_pem, model_config):
|
||
"""为指定租户添加模型记录"""
|
||
try:
|
||
# 检查必要字段
|
||
required_fields = ["model_name", "provider_name", "model_type", "api_key"]
|
||
for field in required_fields:
|
||
if field not in model_config:
|
||
logger.error(f"模型配置缺少必要字段: {field}")
|
||
raise ValueError(f"模型配置缺少必要字段: {field}")
|
||
|
||
# 检查模型是否已存在
|
||
model_name = model_config["model_name"]
|
||
if ModelManager.check_model_exists(tenant_id, model_name):
|
||
logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
|
||
return
|
||
|
||
# 加密API密钥
|
||
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["api_key"])
|
||
|
||
# 根据模型类型构造加密配置
|
||
if model_config["model_type"] in ["embeddings", "reranking"]:
|
||
encrypted_config = {
|
||
"display_name": model_config.get("display_name", ""),
|
||
"api_key": encrypted_api_key,
|
||
"endpoint_url": model_config.get("endpoint_url", ""),
|
||
"context_size": model_config.get("context_size", "")
|
||
}
|
||
elif model_config["model_type"] == "text-generation":
|
||
encrypted_config = {
|
||
"display_name": model_config.get("display_name", ""),
|
||
"api_key": encrypted_api_key,
|
||
"endpoint_url": model_config.get("endpoint_url", ""),
|
||
"mode": model_config.get("mode", ""),
|
||
"context_size": model_config.get("context_size", ""),
|
||
"max_tokens_to_sample": model_config.get("max_tokens_to_sample", ""),
|
||
"function_calling_type": model_config.get("function_calling_type", ""),
|
||
"stream_function_calling": model_config.get("stream_function_calling", ""),
|
||
"vision_support": model_config.get("vision_support", ""),
|
||
"stream_mode_delimiter": model_config.get("stream_mode_delimiter", "")
|
||
}
|
||
else:
|
||
raise ValueError(f"不支持的模型类型: {model_config['model_type']}")
|
||
|
||
# 插入模型记录
|
||
with get_db_cursor() as cursor:
|
||
psycopg2.extras.register_uuid()
|
||
insert_query = """
|
||
INSERT INTO provider_models (
|
||
id, tenant_id, provider_name, model_name, model_type,
|
||
encrypted_config, is_valid, created_at, updated_at
|
||
) VALUES (
|
||
uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()
|
||
);
|
||
"""
|
||
cursor.execute(
|
||
insert_query, (
|
||
tenant_id,
|
||
model_config["provider_name"],
|
||
model_config["model_name"],
|
||
model_config["model_type"],
|
||
json.dumps(encrypted_config)
|
||
)
|
||
)
|
||
|
||
logger.info(f"为租户 {tenant_id} 添加模型记录成功!")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_id} 添加模型记录失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_volc_model_for_tenant(tenant_id, public_key_pem, model_config):
|
||
"""为指定租户添加火山模型记录"""
|
||
try:
|
||
# 检查必要字段
|
||
required_fields = ["model_name", "provider_name", "model_type", "volc_api_key"]
|
||
for field in required_fields:
|
||
if field not in model_config:
|
||
logger.error(f"火山模型配置缺少必要字段: {field}")
|
||
raise ValueError(f"火山模型配置缺少必要字段: {field}")
|
||
|
||
# 检查模型是否已存在
|
||
model_name = model_config["model_name"]
|
||
if ModelManager.check_model_exists(tenant_id, model_name):
|
||
logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。")
|
||
return
|
||
|
||
# 加密API密钥
|
||
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["volc_api_key"])
|
||
|
||
# 根据模型类型构造加密配置
|
||
if model_config["model_type"] == "embeddings":
|
||
logger.warning(f"火山模型不支持embeddings类型,跳过添加。")
|
||
return
|
||
elif model_config["model_type"] == "text-generation":
|
||
encrypted_config = {
|
||
"auth_method": model_config.get("auth_method", "api_key"),
|
||
"volc_api_key": encrypted_api_key,
|
||
"volc_region": model_config.get("volc_region", "cn-beijing"),
|
||
"api_endpoint_host": model_config.get("api_endpoint_host", "https://ark.cn-beijing.volces.com/api/v3"),
|
||
"endpoint_id": model_config.get("endpoint_id", ""),
|
||
"base_model_name": model_config.get("base_model_name", ""),
|
||
}
|
||
else:
|
||
raise ValueError(f"不支持的模型类型: {model_config['model_type']}")
|
||
|
||
# 插入模型记录
|
||
with get_db_cursor() as cursor:
|
||
psycopg2.extras.register_uuid()
|
||
insert_query = """
|
||
INSERT INTO provider_models (
|
||
id, tenant_id, provider_name, model_name, model_type,
|
||
encrypted_config, is_valid, created_at, updated_at
|
||
) VALUES (
|
||
uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW()
|
||
);
|
||
"""
|
||
cursor.execute(
|
||
insert_query, (
|
||
tenant_id,
|
||
model_config["provider_name"],
|
||
model_config["model_name"],
|
||
model_config["model_type"],
|
||
json.dumps(encrypted_config)
|
||
)
|
||
)
|
||
|
||
logger.info(f"为租户 {tenant_id} 添加火山模型记录成功!")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_id} 添加火山模型记录失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def delete_models_for_tenant(tenant_id):
|
||
"""删除指定租户下的所有模型记录"""
|
||
try:
|
||
query = """
|
||
DELETE FROM provider_models
|
||
WHERE tenant_id = %s;
|
||
"""
|
||
rows_affected = execute_update(query, (tenant_id,))
|
||
logger.info(f"租户 {tenant_id} 下的所有模型记录已删除,共 {rows_affected} 条。")
|
||
return rows_affected
|
||
except Exception as e:
|
||
logger.error(f"删除租户 {tenant_id} 的模型记录失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def delete_model_for_tenant(tenant_id, model_name):
|
||
"""删除指定租户下的特定模型记录"""
|
||
try:
|
||
query = """
|
||
DELETE FROM provider_models
|
||
WHERE tenant_id = %s AND model_name = %s;
|
||
"""
|
||
rows_affected = execute_update(query, (tenant_id, model_name))
|
||
|
||
if rows_affected > 0:
|
||
logger.info(f"租户 {tenant_id} 下的模型 {model_name} 已删除。")
|
||
else:
|
||
logger.warning(f"租户 {tenant_id} 下不存在模型 {model_name}。")
|
||
|
||
return rows_affected
|
||
except Exception as e:
|
||
logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def delete_specific_model_for_all_tenants(model_name):
|
||
"""删除所有租户下的特定模型记录"""
|
||
try:
|
||
tenants = TenantManager.get_all_tenants()
|
||
total_deleted = 0
|
||
|
||
for tenant in tenants:
|
||
tenant_id = tenant['id']
|
||
try:
|
||
rows_affected = ModelManager.delete_model_for_tenant(tenant_id, model_name)
|
||
total_deleted += rows_affected
|
||
except Exception as e:
|
||
logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}")
|
||
|
||
logger.info(f"所有租户下的模型 {model_name} 已删除,共 {total_deleted} 条。")
|
||
return total_deleted
|
||
except Exception as e:
|
||
logger.error(f"删除所有租户下的模型 {model_name} 失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_models_for_all_tenants(config_path=CONFIG_PATHS['model_config']):
|
||
"""为所有租户添加模型"""
|
||
try:
|
||
# 加载模型配置
|
||
config = ModelManager.load_config(config_path)
|
||
models = config.get("models", [])
|
||
|
||
# 获取所有租户
|
||
tenants = TenantManager.get_all_tenants()
|
||
total_added = 0
|
||
|
||
# 为每个租户添加模型
|
||
for tenant in tenants:
|
||
tenant_id = tenant['id']
|
||
public_key_pem = tenant['encrypt_public_key']
|
||
for model_config in models:
|
||
if ModelManager.add_model_for_tenant(tenant_id, public_key_pem, model_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为所有租户添加模型完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为所有租户添加模型失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_volc_models_for_all_tenants(config_path=CONFIG_PATHS['volc_model_config']):
|
||
"""为所有租户添加火山模型"""
|
||
try:
|
||
# 加载模型配置
|
||
config = ModelManager.load_config(config_path)
|
||
models = config.get("models", [])
|
||
|
||
# 获取所有租户
|
||
tenants = TenantManager.get_all_tenants()
|
||
total_added = 0
|
||
|
||
# 为每个租户添加模型
|
||
for tenant in tenants:
|
||
tenant_id = tenant['id']
|
||
public_key_pem = tenant['encrypt_public_key']
|
||
for model_config in models:
|
||
if ModelManager.add_volc_model_for_tenant(tenant_id, public_key_pem, model_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为所有租户添加火山模型完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为所有租户添加火山模型失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['model_config']):
|
||
"""为指定租户添加模型"""
|
||
try:
|
||
# 加载模型配置
|
||
config = ModelManager.load_config(config_path)
|
||
models = config.get("models", [])
|
||
|
||
# 获取租户信息
|
||
tenant = TenantManager.get_tenant_by_name(tenant_name)
|
||
if not tenant:
|
||
logger.error(f"未找到租户: {tenant_name}")
|
||
return 0
|
||
|
||
# 为租户添加模型
|
||
total_added = 0
|
||
for model_config in models:
|
||
if ModelManager.add_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为租户 {tenant_name} 添加模型完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_name} 添加模型失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']):
|
||
"""为指定租户添加火山模型"""
|
||
try:
|
||
# 加载模型配置
|
||
config = ModelManager.load_config(config_path)
|
||
models = config.get("models", [])
|
||
|
||
# 获取租户信息
|
||
tenant = TenantManager.get_tenant_by_name(tenant_name)
|
||
if not tenant:
|
||
logger.error(f"未找到租户: {tenant_name}")
|
||
return 0
|
||
|
||
# 为租户添加模型
|
||
total_added = 0
|
||
for model_config in models:
|
||
if ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}")
|
||
raise
|