dify_admin/api/provider_manager.py

202 lines
8.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
import psycopg2.extras
from .database import get_db_cursor, execute_query, execute_update
from encryption import Encryption
from .config import CONFIG_PATHS
from tenant_manager import TenantManager
# 配置日志
logger = logging.getLogger(__name__)
class ProviderManager:
"""提供商管理类"""
@staticmethod
def load_config(config_path):
"""加载提供商配置文件"""
try:
if not os.path.exists(config_path):
raise FileNotFoundError(f"配置文件 {config_path} 不存在!")
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
raise
@staticmethod
def check_provider_exists(tenant_id, provider_name):
"""检查指定租户是否已存在指定的提供商"""
try:
query = """
SELECT COUNT(*) FROM providers
WHERE tenant_id = %s AND provider_name = %s;
"""
count = execute_query(query, (tenant_id, provider_name), fetch_one=True)[0]
return count > 0
except Exception as e:
logger.error(f"检查提供商是否存在时发生错误: {e}")
return False
@staticmethod
def add_provider_for_tenant(tenant_id, public_key_pem, provider_config):
"""为指定租户添加提供商记录"""
try:
# 检查必要字段
if "provider_name" not in provider_config:
logger.error("提供商配置缺少必要字段: provider_name")
raise ValueError("提供商配置缺少必要字段: provider_name")
if "config" not in provider_config:
logger.error("提供商配置缺少必要字段: config")
raise ValueError("提供商配置缺少必要字段: config")
# 检查提供商是否已存在
provider_name = provider_config["provider_name"]
if ProviderManager.check_provider_exists(tenant_id, provider_name):
logger.info(f"租户 {tenant_id} 已存在提供商 {provider_name},跳过添加。")
return
# 加密API密钥
api_key = provider_config["config"].get("dashscope_api_key", "")
if api_key:
encrypted_api_key = Encryption.encrypt_api_key(public_key_pem, api_key)
# 构造加密配置
encrypted_config = {
"dashscope_api_key": encrypted_api_key
}
# 插入提供商记录
with get_db_cursor() as cursor:
psycopg2.extras.register_uuid()
insert_query = """
INSERT INTO providers (
id, tenant_id, provider_name, provider_type,
encrypted_config, is_valid, created_at, updated_at
) VALUES (
uuid_generate_v4(), %s, %s, %s, %s, TRUE, NOW(), NOW()
);
"""
cursor.execute(
insert_query, (
tenant_id,
provider_name,
provider_config.get("provider_type", "custom"),
json.dumps(encrypted_config)
)
)
logger.info(f"为租户 {tenant_id} 添加提供商记录成功!")
return True
else:
logger.warning(f"提供商 {provider_name} 配置中缺少API密钥跳过添加。")
return False
except Exception as e:
logger.error(f"为租户 {tenant_id} 添加提供商记录失败: {e}")
raise
@staticmethod
def delete_providers_for_tenant(tenant_id):
"""删除指定租户下的所有提供商记录"""
try:
query = """
DELETE FROM providers
WHERE tenant_id = %s;
"""
rows_affected = execute_update(query, (tenant_id,))
logger.info(f"租户 {tenant_id} 下的所有提供商记录已删除,共 {rows_affected} 条。")
return rows_affected
except Exception as e:
logger.error(f"删除租户 {tenant_id} 的提供商记录失败: {e}")
raise
@staticmethod
def delete_provider_for_tenant(tenant_id, provider_name):
"""删除指定租户下的特定提供商记录"""
try:
query = """
DELETE FROM providers
WHERE tenant_id = %s AND provider_name = %s;
"""
rows_affected = execute_update(query, (tenant_id, provider_name))
if rows_affected > 0:
logger.info(f"租户 {tenant_id} 下的提供商 {provider_name} 已删除。")
else:
logger.warning(f"租户 {tenant_id} 下不存在提供商 {provider_name}")
return rows_affected
except Exception as e:
logger.error(f"删除租户 {tenant_id} 下的提供商 {provider_name} 失败: {e}")
raise
@staticmethod
def add_providers_for_all_tenants(config_path=CONFIG_PATHS['provider_config']):
"""为所有租户添加提供商"""
try:
# 加载提供商配置
config = ProviderManager.load_config(config_path)
providers = config.get("providers", [])
# 获取所有租户
tenants = TenantManager.get_all_tenants()
total_added = 0
# 为每个租户添加提供商
for tenant in tenants:
tenant_id = tenant['id']
public_key_pem = tenant['encrypt_public_key']
for provider_config in providers:
if ProviderManager.add_provider_for_tenant(tenant_id, public_key_pem, provider_config):
total_added += 1
logger.info(f"为所有租户添加提供商完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为所有租户添加提供商失败: {e}")
raise
@staticmethod
def add_providers_for_tenant(tenant_name, config_path=CONFIG_PATHS['provider_config']):
"""为指定租户添加提供商"""
try:
# 加载提供商配置
config = ProviderManager.load_config(config_path)
providers = config.get("providers", [])
# 获取租户信息
tenant = TenantManager.get_tenant_by_name(tenant_name)
if not tenant:
logger.error(f"未找到租户: {tenant_name}")
return 0
# 为租户添加提供商
total_added = 0
for provider_config in providers:
if ProviderManager.add_provider_for_tenant(tenant['id'], tenant['encrypt_public_key'], provider_config):
total_added += 1
logger.info(f"为租户 {tenant_name} 添加提供商完成,共添加 {total_added} 条记录。")
return total_added
except Exception as e:
logger.error(f"为租户 {tenant_name} 添加提供商失败: {e}")
raise
@staticmethod
def get_providers_for_tenant(tenant_id):
"""获取指定租户下的所有提供商"""
try:
query = """
SELECT id, provider_name, provider_type, is_valid
FROM providers
WHERE tenant_id = %s;
"""
providers = execute_query(query, (tenant_id,))
return providers
except Exception as e:
logger.error(f"获取租户 {tenant_id} 的提供商失败: {e}")
return []