import os import json import logging import psycopg2.extras from database import get_db_cursor, execute_query, execute_update from encryption import Encryption from config import CONFIG_PATHS from tenant_manager import TenantManager # 配置日志 logger = logging.getLogger(__name__) class ModelManager: """模型管理类""" @staticmethod def load_config(config_path): """加载模型配置文件""" try: if not os.path.exists(config_path): raise FileNotFoundError(f"配置文件 {config_path} 不存在!") with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) return config except Exception as e: logger.error(f"加载配置文件失败: {e}") raise @staticmethod def check_model_exists(tenant_id, model_name): """检查指定租户是否已存在指定的模型""" try: query = """ SELECT COUNT(*) FROM provider_models WHERE tenant_id = %s AND model_name = %s; """ count = execute_query(query, (tenant_id, model_name), fetch_one=True)[0] return count > 0 except Exception as e: logger.error(f"检查模型是否存在时发生错误: {e}") return False @staticmethod def add_model_for_tenant(tenant_id, public_key_pem, model_config): """为指定租户添加模型记录""" try: # 检查必要字段 required_fields = ["model_name", "provider_name", "model_type", "api_key"] for field in required_fields: if field not in model_config: logger.error(f"模型配置缺少必要字段: {field}") raise ValueError(f"模型配置缺少必要字段: {field}") # 检查模型是否已存在 model_name = model_config["model_name"] if ModelManager.check_model_exists(tenant_id, model_name): logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。") return # 加密API密钥 encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["api_key"]) # 根据模型类型构造加密配置 if model_config["model_type"] in ["embeddings", "reranking"]: encrypted_config = { "display_name": model_config.get("display_name", ""), "api_key": encrypted_api_key, "endpoint_url": model_config.get("endpoint_url", ""), "context_size": model_config.get("context_size", "") } elif model_config["model_type"] == "text-generation": encrypted_config = { "display_name": model_config.get("display_name", ""), "api_key": encrypted_api_key, "endpoint_url": model_config.get("endpoint_url", ""), "mode": model_config.get("mode", ""), "context_size": model_config.get("context_size", ""), "max_tokens_to_sample": model_config.get("max_tokens_to_sample", ""), "function_calling_type": model_config.get("function_calling_type", ""), "stream_function_calling": model_config.get("stream_function_calling", ""), "vision_support": model_config.get("vision_support", ""), "stream_mode_delimiter": model_config.get("stream_mode_delimiter", "") } else: raise ValueError(f"不支持的模型类型: {model_config['model_type']}") # 插入模型记录 with get_db_cursor() as cursor: psycopg2.extras.register_uuid() insert_query = """ INSERT INTO provider_models ( id, tenant_id, provider_name, model_name, model_type, encrypted_config, is_valid, created_at, updated_at ) VALUES ( uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW() ); """ cursor.execute( insert_query, ( tenant_id, model_config["provider_name"], model_config["model_name"], model_config["model_type"], json.dumps(encrypted_config) ) ) logger.info(f"为租户 {tenant_id} 添加模型记录成功!") return True except Exception as e: logger.error(f"为租户 {tenant_id} 添加模型记录失败: {e}") raise @staticmethod def add_volc_model_for_tenant(tenant_id, public_key_pem, model_config): """为指定租户添加火山模型记录""" try: # 检查必要字段 required_fields = ["model_name", "provider_name", "model_type", "volc_api_key"] for field in required_fields: if field not in model_config: logger.error(f"火山模型配置缺少必要字段: {field}") raise ValueError(f"火山模型配置缺少必要字段: {field}") # 检查模型是否已存在 model_name = model_config["model_name"] if ModelManager.check_model_exists(tenant_id, model_name): logger.info(f"租户 {tenant_id} 已存在模型 {model_name},跳过添加。") return # 加密API密钥 encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, model_config["volc_api_key"]) # 根据模型类型构造加密配置 if model_config["model_type"] == "embeddings": logger.warning(f"火山模型不支持embeddings类型,跳过添加。") return elif model_config["model_type"] == "text-generation": encrypted_config = { "auth_method": model_config.get("auth_method", "api_key"), "volc_api_key": encrypted_api_key, "volc_region": model_config.get("volc_region", "cn-beijing"), "api_endpoint_host": model_config.get("api_endpoint_host", "https://ark.cn-beijing.volces.com/api/v3"), "endpoint_id": model_config.get("endpoint_id", ""), "base_model_name": model_config.get("base_model_name", ""), } else: raise ValueError(f"不支持的模型类型: {model_config['model_type']}") # 插入模型记录 with get_db_cursor() as cursor: psycopg2.extras.register_uuid() insert_query = """ INSERT INTO provider_models ( id, tenant_id, provider_name, model_name, model_type, encrypted_config, is_valid, created_at, updated_at ) VALUES ( uuid_generate_v4(), %s, %s, %s, %s, %s, TRUE, NOW(), NOW() ); """ cursor.execute( insert_query, ( tenant_id, model_config["provider_name"], model_config["model_name"], model_config["model_type"], json.dumps(encrypted_config) ) ) logger.info(f"为租户 {tenant_id} 添加火山模型记录成功!") return True except Exception as e: logger.error(f"为租户 {tenant_id} 添加火山模型记录失败: {e}") raise @staticmethod def delete_models_for_tenant(tenant_id): """删除指定租户下的所有模型记录""" try: query = """ DELETE FROM provider_models WHERE tenant_id = %s; """ rows_affected = execute_update(query, (tenant_id,)) logger.info(f"租户 {tenant_id} 下的所有模型记录已删除,共 {rows_affected} 条。") return rows_affected except Exception as e: logger.error(f"删除租户 {tenant_id} 的模型记录失败: {e}") raise @staticmethod def delete_model_for_tenant(tenant_id, model_name): """删除指定租户下的特定模型记录""" try: query = """ DELETE FROM provider_models WHERE tenant_id = %s AND model_name = %s; """ rows_affected = execute_update(query, (tenant_id, model_name)) if rows_affected > 0: logger.info(f"租户 {tenant_id} 下的模型 {model_name} 已删除。") else: logger.warning(f"租户 {tenant_id} 下不存在模型 {model_name}。") return rows_affected except Exception as e: logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}") raise @staticmethod def delete_specific_model_for_all_tenants(model_name): """删除所有租户下的特定模型记录""" try: tenants = TenantManager.get_all_tenants() total_deleted = 0 for tenant in tenants: tenant_id = tenant['id'] try: rows_affected = ModelManager.delete_model_for_tenant(tenant_id, model_name) total_deleted += rows_affected except Exception as e: logger.error(f"删除租户 {tenant_id} 下的模型 {model_name} 失败: {e}") logger.info(f"所有租户下的模型 {model_name} 已删除,共 {total_deleted} 条。") return total_deleted except Exception as e: logger.error(f"删除所有租户下的模型 {model_name} 失败: {e}") raise @staticmethod def add_models_for_all_tenants(config_path=CONFIG_PATHS['model_config']): """为所有租户添加模型""" try: # 加载模型配置 config = ModelManager.load_config(config_path) models = config.get("models", []) # 获取所有租户 tenants = TenantManager.get_all_tenants() total_added = 0 # 为每个租户添加模型 for tenant in tenants: tenant_id = tenant['id'] public_key_pem = tenant['encrypt_public_key'] for model_config in models: if ModelManager.add_model_for_tenant(tenant_id, public_key_pem, model_config): total_added += 1 logger.info(f"为所有租户添加模型完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为所有租户添加模型失败: {e}") raise @staticmethod def add_volc_models_for_all_tenants(config_path=CONFIG_PATHS['volc_model_config']): """为所有租户添加火山模型""" try: # 加载模型配置 config = ModelManager.load_config(config_path) models = config.get("models", []) # 获取所有租户 tenants = TenantManager.get_all_tenants() total_added = 0 # 为每个租户添加模型 for tenant in tenants: tenant_id = tenant['id'] public_key_pem = tenant['encrypt_public_key'] for model_config in models: if ModelManager.add_volc_model_for_tenant(tenant_id, public_key_pem, model_config): total_added += 1 logger.info(f"为所有租户添加火山模型完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为所有租户添加火山模型失败: {e}") raise @staticmethod def add_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['model_config']): """为指定租户添加模型""" try: # 加载模型配置 config = ModelManager.load_config(config_path) models = config.get("models", []) # 获取租户信息 tenant = TenantManager.get_tenant_by_name(tenant_name) if not tenant: logger.error(f"未找到租户: {tenant_name}") return 0 # 为租户添加模型 total_added = 0 for model_config in models: if ModelManager.add_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config): total_added += 1 logger.info(f"为租户 {tenant_name} 添加模型完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为租户 {tenant_name} 添加模型失败: {e}") raise @staticmethod def add_volc_models_for_tenant(tenant_name, config_path=CONFIG_PATHS['volc_model_config']): """为指定租户添加火山模型""" try: # 加载模型配置 config = ModelManager.load_config(config_path) models = config.get("models", []) # 获取租户信息 tenant = TenantManager.get_tenant_by_name(tenant_name) if not tenant: logger.error(f"未找到租户: {tenant_name}") return 0 # 为租户添加模型 total_added = 0 for model_config in models: if ModelManager.add_volc_model_for_tenant(tenant['id'], tenant['encrypt_public_key'], model_config): total_added += 1 logger.info(f"为租户 {tenant_name} 添加火山模型完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为租户 {tenant_name} 添加火山模型失败: {e}") raise