dify_admin/api/tenant_manager.py

149 lines
4.9 KiB
Python
Executable File

import os
import uuid
import logging
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from psycopg2.extras import RealDictCursor
import psycopg2.extras
from database import get_db_cursor, execute_query
from config import CONFIG_PATHS
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class TenantManager:
    """Tenant (workspace) management helpers.

    All methods are ``@staticmethod``s; the class only groups them. Tenant rows
    are stored in the ``tenants`` table, and each tenant's RSA private key is
    written to ``CONFIG_PATHS['privkeys_dir']/<tenant_id>/private.pem``.
    """

    @staticmethod
    def generate_rsa_key_pair():
        """Generate a 2048-bit RSA key pair (public exponent 65537).

        Returns:
            tuple: ``(public_key_pem, private_key)``. ``public_key_pem`` is the
            public key as a PEM ``SubjectPublicKeyInfo`` string, and
            ``private_key`` is the ``cryptography`` private-key object.

        Raises:
            Exception: any key-generation/serialization error (logged, then
            re-raised).
        """
        try:
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=2048,
                backend=default_backend(),
            )
            # Serialize the public key to PEM text so it can be stored in the DB.
            public_key_pem = private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            ).decode("utf-8")
            return public_key_pem, private_key
        except Exception as e:
            logger.error("生成RSA密钥对失败: %s", e)
            raise

    @staticmethod
    def save_private_key(tenant_id, private_key, file_mode=0o600):
        """Write a tenant's private key to disk as unencrypted PEM.

        The key is written to ``<privkeys_dir>/<tenant_id>/private.pem`` in
        "TraditionalOpenSSL" format with no passphrase. Because the key is not
        encrypted, the file gets restrictive permissions.

        Args:
            tenant_id: Tenant identifier; ``str(tenant_id)`` names the directory.
            private_key: ``cryptography`` RSA private-key object.
            file_mode: Permission bits for the key file. The default ``0o600``
                means owner read/write only. Pass a wider mode if a different
                OS user must read the key.

        Returns:
            str: Path of the written key file.

        Raises:
            Exception: any filesystem/serialization error (logged, then
            re-raised).
        """
        try:
            privkey_dir = os.path.join(CONFIG_PATHS['privkeys_dir'], str(tenant_id))
            os.makedirs(privkey_dir, exist_ok=True)
            privkey_path = os.path.join(privkey_dir, "private.pem")
            # Serialize before touching the file so a failure cannot leave a
            # truncated key behind.
            pem_bytes = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
            # Plain open(..., "wb") would apply umask-derived permissions, which
            # are often world-readable. Create the file with an explicit mode
            # instead. O_BINARY exists only on Windows.
            flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
            fd = os.open(privkey_path, flags, file_mode)
            with os.fdopen(fd, "wb") as key_file:
                # os.open's mode applies only when the file is created, and the
                # umask masks it. Re-apply the mode so a pre-existing file is also
                # tightened before any key material is written.
                os.chmod(privkey_path, file_mode)
                key_file.write(pem_bytes)
            logger.info("私钥已保存到 %s", privkey_path)
            return privkey_path
        except Exception as e:
            logger.error("保存私钥失败: %s", e)
            raise

    @staticmethod
    def _discard_private_key(privkey_path):
        """Best-effort removal of a key file and its (then empty) tenant directory.

        Used to roll back ``save_private_key`` when the tenant row could not be
        inserted. Errors are logged, not raised, so the original failure stays
        the one the caller sees.
        """
        try:
            os.remove(privkey_path)
            # os.rmdir only succeeds on an empty directory, so other files are never touched.
            os.rmdir(os.path.dirname(privkey_path))
        except OSError as cleanup_err:
            logger.warning("清理孤立私钥失败: %s", cleanup_err)

    @staticmethod
    def create_tenant(workspace_name):
        """Create a tenant: key pair, private-key file and ``tenants`` row.

        Args:
            workspace_name: Value stored in ``tenants.name``.

        Returns:
            uuid.UUID: The newly generated tenant id.

        Raises:
            Exception: any key-generation, file or DB error (logged, then
            re-raised). If the DB insert fails, the already-written private-key
            file is removed so no orphaned key is left behind.
        """
        try:
            tenant_id = uuid.uuid4()
            public_key_pem, private_key = TenantManager.generate_rsa_key_pair()
            privkey_path = TenantManager.save_private_key(tenant_id, private_key)
            insert_query = """
                INSERT INTO tenants (id, name, encrypt_public_key)
                VALUES (%s, %s, %s);
            """
            try:
                with get_db_cursor() as cursor:
                    # Lets psycopg2 adapt uuid.UUID parameters. Per psycopg2 docs,
                    # calling this without a connection registers the adapter and
                    # typecaster process-wide, which affects other queries too.
                    psycopg2.extras.register_uuid()
                    cursor.execute(insert_query, (tenant_id, workspace_name, public_key_pem))
            except Exception:
                # Do not leave a private key on disk for a tenant that does not exist.
                TenantManager._discard_private_key(privkey_path)
                raise
            logger.info("租户 '%s' 创建成功!租户 ID: %s", workspace_name, tenant_id)
            return tenant_id
        except Exception as e:
            logger.error("创建租户失败: %s", e)
            raise

    @staticmethod
    def get_tenant_by_name(workspace_name):
        """Look up a single tenant by exact name.

        Args:
            workspace_name: Tenant name to match (``name = %s``).

        Returns:
            The row as a dict-like ``RealDictCursor`` record with keys ``id``
            and ``encrypt_public_key``, or ``None`` if no tenant matches.

        Raises:
            Exception: any query error (logged, then re-raised).
        """
        # NOTE(review): unlike get_all_tenants/search_tenants, this query does not
        # cast id::text. The Python type of "id" therefore depends on whether
        # psycopg2's UUID typecaster has been registered process-wide
        # (create_tenant calls register_uuid()): str before, uuid.UUID after.
        # Confirm what callers expect before normalizing.
        try:
            query = """
                SELECT id, encrypt_public_key FROM tenants WHERE name = %s;
            """
            tenant = execute_query(query, (workspace_name,), cursor_factory=RealDictCursor, fetch_one=True)
            if not tenant:
                logger.warning("未找到名称为 '%s' 的租户。", workspace_name)
                return None
            return tenant
        except Exception as e:
            logger.error("获取租户信息失败: %s", e)
            raise

    @staticmethod
    def get_all_tenants():
        """Return every tenant row.

        Returns:
            list: Dict-like rows with ``id`` (cast to text), ``name``,
            ``encrypt_public_key`` and ``created_at``. The list is empty when
            there are no tenants or when the query fails. A failure is logged
            with its traceback, not raised.
        """
        query = """
            SELECT
                id::text,
                name,
                encrypt_public_key,
                created_at
            FROM tenants;
        """
        try:
            tenants = execute_query(query, cursor_factory=RealDictCursor)
        except Exception as e:
            # Deliberate best-effort listing: keep the [] fallback, but keep the traceback.
            logger.exception("获取所有租户信息失败: %s", e)
            return []
        return tenants or []

    @staticmethod
    def _escape_like(term):
        """Escape LIKE/ILIKE metacharacters so ``term`` matches literally.

        PostgreSQL's default LIKE escape character is backslash, so backslash
        itself is escaped first, then ``%`` and ``_``.
        """
        return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def search_tenants(search_term: str):
        """Case-insensitive substring search on tenant names.

        ``%``, ``_`` and ``\\`` in ``search_term`` are matched literally, not
        as wildcards. ``None`` or ``""`` matches every tenant with a non-NULL
        name.

        Args:
            search_term: Text to look for anywhere in ``tenants.name``.

        Returns:
            list: Dict-like rows with ``id`` (cast to text), ``name``,
            ``encrypt_public_key`` and ``created_at``. The list is empty when
            nothing matches or when the query fails. A failure is logged with
            its traceback, not raised.
        """
        query = """
            SELECT
                id::text,
                name,
                encrypt_public_key,
                created_at
            FROM tenants
            WHERE name ILIKE %s;
        """
        try:
            term = "" if search_term is None else str(search_term)
            pattern = f"%{TenantManager._escape_like(term)}%"
            tenants = execute_query(query, (pattern,), cursor_factory=RealDictCursor)
        except Exception as e:
            logger.exception("搜索租户失败: %s", e)
            return []
        return tenants or []