import os import json import logging import psycopg2.extras from .database import get_db_cursor, execute_query, execute_update from encryption import Encryption from .config import CONFIG_PATHS from tenant_manager import TenantManager # 配置日志 logger = logging.getLogger(__name__) class ProviderManager: """提供商管理类""" @staticmethod def load_config(config_path): """加载提供商配置文件""" try: if not os.path.exists(config_path): raise FileNotFoundError(f"配置文件 {config_path} 不存在!") with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) return config except Exception as e: logger.error(f"加载配置文件失败: {e}") raise @staticmethod def check_provider_exists(tenant_id, provider_name): """检查指定租户是否已存在指定的提供商""" try: query = """ SELECT COUNT(*) FROM providers WHERE tenant_id = %s AND provider_name = %s; """ count = execute_query(query, (tenant_id, provider_name), fetch_one=True)[0] return count > 0 except Exception as e: logger.error(f"检查提供商是否存在时发生错误: {e}") return False @staticmethod def add_provider_for_tenant(tenant_id, public_key_pem, provider_config): """为指定租户添加提供商记录""" try: # 检查必要字段 if "provider_name" not in provider_config: logger.error("提供商配置缺少必要字段: provider_name") raise ValueError("提供商配置缺少必要字段: provider_name") if "config" not in provider_config: logger.error("提供商配置缺少必要字段: config") raise ValueError("提供商配置缺少必要字段: config") # 检查提供商是否已存在 provider_name = provider_config["provider_name"] if ProviderManager.check_provider_exists(tenant_id, provider_name): logger.info(f"租户 {tenant_id} 已存在提供商 {provider_name},跳过添加。") return # 加密API密钥 api_key = provider_config["config"].get("dashscope_api_key", "") if api_key: encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, api_key) # 构造加密配置 encrypted_config = { "dashscope_api_key": encrypted_api_key } # 插入提供商记录 with get_db_cursor() as cursor: psycopg2.extras.register_uuid() insert_query = """ INSERT INTO providers ( id, tenant_id, provider_name, provider_type, encrypted_config, is_valid, created_at, updated_at ) VALUES ( uuid_generate_v4(), %s, %s, %s, %s, TRUE, NOW(), NOW() ); """ cursor.execute( insert_query, ( tenant_id, provider_name, provider_config.get("provider_type", "custom"), json.dumps(encrypted_config) ) ) logger.info(f"为租户 {tenant_id} 添加提供商记录成功!") return True else: logger.warning(f"提供商 {provider_name} 配置中缺少API密钥,跳过添加。") return False except Exception as e: logger.error(f"为租户 {tenant_id} 添加提供商记录失败: {e}") raise @staticmethod def delete_providers_for_tenant(tenant_id): """删除指定租户下的所有提供商记录""" try: query = """ DELETE FROM providers WHERE tenant_id = %s; """ rows_affected = execute_update(query, (tenant_id,)) logger.info(f"租户 {tenant_id} 下的所有提供商记录已删除,共 {rows_affected} 条。") return rows_affected except Exception as e: logger.error(f"删除租户 {tenant_id} 的提供商记录失败: {e}") raise @staticmethod def delete_provider_for_tenant(tenant_id, provider_name): """删除指定租户下的特定提供商记录""" try: query = """ DELETE FROM providers WHERE tenant_id = %s AND provider_name = %s; """ rows_affected = execute_update(query, (tenant_id, provider_name)) if rows_affected > 0: logger.info(f"租户 {tenant_id} 下的提供商 {provider_name} 已删除。") else: logger.warning(f"租户 {tenant_id} 下不存在提供商 {provider_name}。") return rows_affected except Exception as e: logger.error(f"删除租户 {tenant_id} 下的提供商 {provider_name} 失败: {e}") raise @staticmethod def add_providers_for_all_tenants(config_path=CONFIG_PATHS['provider_config']): """为所有租户添加提供商""" try: # 加载提供商配置 config = ProviderManager.load_config(config_path) providers = config.get("providers", []) # 获取所有租户 tenants = TenantManager.get_all_tenants() total_added = 0 # 为每个租户添加提供商 for tenant in tenants: tenant_id = tenant['id'] public_key_pem = tenant['encrypt_public_key'] for provider_config in providers: if ProviderManager.add_provider_for_tenant(tenant_id, public_key_pem, provider_config): total_added += 1 logger.info(f"为所有租户添加提供商完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为所有租户添加提供商失败: {e}") raise @staticmethod def add_providers_for_tenant(tenant_name, config_path=CONFIG_PATHS['provider_config']): """为指定租户添加提供商""" try: # 加载提供商配置 config = ProviderManager.load_config(config_path) providers = config.get("providers", []) # 获取租户信息 tenant = TenantManager.get_tenant_by_name(tenant_name) if not tenant: logger.error(f"未找到租户: {tenant_name}") return 0 # 为租户添加提供商 total_added = 0 for provider_config in providers: if ProviderManager.add_provider_for_tenant(tenant['id'], tenant['encrypt_public_key'], provider_config): total_added += 1 logger.info(f"为租户 {tenant_name} 添加提供商完成,共添加 {total_added} 条记录。") return total_added except Exception as e: logger.error(f"为租户 {tenant_name} 添加提供商失败: {e}") raise @staticmethod def get_providers_for_tenant(tenant_id): """获取指定租户下的所有提供商""" try: query = """ SELECT id, provider_name, provider_type, is_valid FROM providers WHERE tenant_id = %s; """ providers = execute_query(query, (tenant_id,)) return providers except Exception as e: logger.error(f"获取租户 {tenant_id} 的提供商失败: {e}") return []