1. 添加README说明项目结构 2. 配置Python和Node.js的.gitignore 3. 包含认证模块和账号管理的前后端基础代码 4. 开发计划文档记录当前阶段任务
202 lines
8.0 KiB
Python
202 lines
8.0 KiB
Python
import os
|
||
import json
|
||
import logging
|
||
import psycopg2.extras
|
||
from database import get_db_cursor, execute_query, execute_update
|
||
from encryption import Encryption
|
||
from config import CONFIG_PATHS
|
||
from tenant_manager import TenantManager
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class ProviderManager:
|
||
"""提供商管理类"""
|
||
|
||
@staticmethod
|
||
def load_config(config_path):
|
||
"""加载提供商配置文件"""
|
||
try:
|
||
if not os.path.exists(config_path):
|
||
raise FileNotFoundError(f"配置文件 {config_path} 不存在!")
|
||
with open(config_path, "r", encoding="utf-8") as f:
|
||
config = json.load(f)
|
||
return config
|
||
except Exception as e:
|
||
logger.error(f"加载配置文件失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def check_provider_exists(tenant_id, provider_name):
|
||
"""检查指定租户是否已存在指定的提供商"""
|
||
try:
|
||
query = """
|
||
SELECT COUNT(*) FROM providers
|
||
WHERE tenant_id = %s AND provider_name = %s;
|
||
"""
|
||
count = execute_query(query, (tenant_id, provider_name), fetch_one=True)[0]
|
||
return count > 0
|
||
except Exception as e:
|
||
logger.error(f"检查提供商是否存在时发生错误: {e}")
|
||
return False
|
||
|
||
@staticmethod
|
||
def add_provider_for_tenant(tenant_id, public_key_pem, provider_config):
|
||
"""为指定租户添加提供商记录"""
|
||
try:
|
||
# 检查必要字段
|
||
if "provider_name" not in provider_config:
|
||
logger.error("提供商配置缺少必要字段: provider_name")
|
||
raise ValueError("提供商配置缺少必要字段: provider_name")
|
||
|
||
if "config" not in provider_config:
|
||
logger.error("提供商配置缺少必要字段: config")
|
||
raise ValueError("提供商配置缺少必要字段: config")
|
||
|
||
# 检查提供商是否已存在
|
||
provider_name = provider_config["provider_name"]
|
||
if ProviderManager.check_provider_exists(tenant_id, provider_name):
|
||
logger.info(f"租户 {tenant_id} 已存在提供商 {provider_name},跳过添加。")
|
||
return
|
||
|
||
# 加密API密钥
|
||
api_key = provider_config["config"].get("dashscope_api_key", "")
|
||
if api_key:
|
||
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, api_key)
|
||
|
||
# 构造加密配置
|
||
encrypted_config = {
|
||
"dashscope_api_key": encrypted_api_key
|
||
}
|
||
|
||
# 插入提供商记录
|
||
with get_db_cursor() as cursor:
|
||
psycopg2.extras.register_uuid()
|
||
insert_query = """
|
||
INSERT INTO providers (
|
||
id, tenant_id, provider_name, provider_type,
|
||
encrypted_config, is_valid, created_at, updated_at
|
||
) VALUES (
|
||
uuid_generate_v4(), %s, %s, %s, %s, TRUE, NOW(), NOW()
|
||
);
|
||
"""
|
||
cursor.execute(
|
||
insert_query, (
|
||
tenant_id,
|
||
provider_name,
|
||
provider_config.get("provider_type", "custom"),
|
||
json.dumps(encrypted_config)
|
||
)
|
||
)
|
||
|
||
logger.info(f"为租户 {tenant_id} 添加提供商记录成功!")
|
||
return True
|
||
else:
|
||
logger.warning(f"提供商 {provider_name} 配置中缺少API密钥,跳过添加。")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_id} 添加提供商记录失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def delete_providers_for_tenant(tenant_id):
|
||
"""删除指定租户下的所有提供商记录"""
|
||
try:
|
||
query = """
|
||
DELETE FROM providers
|
||
WHERE tenant_id = %s;
|
||
"""
|
||
rows_affected = execute_update(query, (tenant_id,))
|
||
logger.info(f"租户 {tenant_id} 下的所有提供商记录已删除,共 {rows_affected} 条。")
|
||
return rows_affected
|
||
except Exception as e:
|
||
logger.error(f"删除租户 {tenant_id} 的提供商记录失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def delete_provider_for_tenant(tenant_id, provider_name):
|
||
"""删除指定租户下的特定提供商记录"""
|
||
try:
|
||
query = """
|
||
DELETE FROM providers
|
||
WHERE tenant_id = %s AND provider_name = %s;
|
||
"""
|
||
rows_affected = execute_update(query, (tenant_id, provider_name))
|
||
|
||
if rows_affected > 0:
|
||
logger.info(f"租户 {tenant_id} 下的提供商 {provider_name} 已删除。")
|
||
else:
|
||
logger.warning(f"租户 {tenant_id} 下不存在提供商 {provider_name}。")
|
||
|
||
return rows_affected
|
||
except Exception as e:
|
||
logger.error(f"删除租户 {tenant_id} 下的提供商 {provider_name} 失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_providers_for_all_tenants(config_path=CONFIG_PATHS['provider_config']):
|
||
"""为所有租户添加提供商"""
|
||
try:
|
||
# 加载提供商配置
|
||
config = ProviderManager.load_config(config_path)
|
||
providers = config.get("providers", [])
|
||
|
||
# 获取所有租户
|
||
tenants = TenantManager.get_all_tenants()
|
||
total_added = 0
|
||
|
||
# 为每个租户添加提供商
|
||
for tenant in tenants:
|
||
tenant_id = tenant['id']
|
||
public_key_pem = tenant['encrypt_public_key']
|
||
for provider_config in providers:
|
||
if ProviderManager.add_provider_for_tenant(tenant_id, public_key_pem, provider_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为所有租户添加提供商完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为所有租户添加提供商失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def add_providers_for_tenant(tenant_name, config_path=CONFIG_PATHS['provider_config']):
|
||
"""为指定租户添加提供商"""
|
||
try:
|
||
# 加载提供商配置
|
||
config = ProviderManager.load_config(config_path)
|
||
providers = config.get("providers", [])
|
||
|
||
# 获取租户信息
|
||
tenant = TenantManager.get_tenant_by_name(tenant_name)
|
||
if not tenant:
|
||
logger.error(f"未找到租户: {tenant_name}")
|
||
return 0
|
||
|
||
# 为租户添加提供商
|
||
total_added = 0
|
||
for provider_config in providers:
|
||
if ProviderManager.add_provider_for_tenant(tenant['id'], tenant['encrypt_public_key'], provider_config):
|
||
total_added += 1
|
||
|
||
logger.info(f"为租户 {tenant_name} 添加提供商完成,共添加 {total_added} 条记录。")
|
||
return total_added
|
||
except Exception as e:
|
||
logger.error(f"为租户 {tenant_name} 添加提供商失败: {e}")
|
||
raise
|
||
|
||
@staticmethod
|
||
def get_providers_for_tenant(tenant_id):
|
||
"""获取指定租户下的所有提供商"""
|
||
try:
|
||
query = """
|
||
SELECT id, provider_name, provider_type, is_valid
|
||
FROM providers
|
||
WHERE tenant_id = %s;
|
||
"""
|
||
providers = execute_query(query, (tenant_id,))
|
||
return providers
|
||
except Exception as e:
|
||
logger.error(f"获取租户 {tenant_id} 的提供商失败: {e}")
|
||
return []
|