dify_admin/api/tenant_manager.py
xh.xin 96480a27a9 初始化项目仓库,包含基础结构和开发计划
1. 添加README说明项目结构
2. 配置Python和Node.js的.gitignore
3. 包含认证模块和账号管理的前后端基础代码
4. 开发计划文档记录当前阶段任务
2025-05-02 18:33:06 +08:00

116 lines
4.1 KiB
Python

import os
import uuid
import logging
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
from psycopg2.extras import RealDictCursor
import psycopg2.extras
from database import get_db_cursor, execute_query
from config import CONFIG_PATHS
# Module-level logger, named after this module so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class TenantManager:
    """Tenant management helpers: RSA key generation, private-key storage and tenant rows.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    """

    @staticmethod
    def generate_rsa_key_pair(key_size=2048):
        """Generate an RSA key pair for a tenant.

        Args:
            key_size: RSA modulus size in bits. Defaults to 2048, the size this
                method always used before the parameter existed.

        Returns:
            A ``(public_key_pem, private_key)`` tuple: the public key as a
            PEM-encoded SubjectPublicKeyInfo ``str`` and the private key as a
            ``cryptography`` RSA private-key object.

        Raises:
            Exception: Whatever ``cryptography`` raises; it is logged first.
        """
        try:
            private_key = rsa.generate_private_key(
                public_exponent=65537,
                key_size=key_size,
                backend=default_backend(),
            )
            # Serialize the public half as PEM text for storage in the DB row.
            public_key_pem = private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            ).decode("utf-8")
            return public_key_pem, private_key
        except Exception as e:
            logger.exception("生成RSA密钥对失败: %s", e)
            raise

    @staticmethod
    def save_private_key(tenant_id, private_key, mode=0o600):
        """Write a tenant's private key to ``<privkeys_dir>/<tenant_id>/private.pem``.

        The key is written unencrypted (PEM, TraditionalOpenSSL format). For that
        reason the file is restricted to its owner by default.

        Args:
            tenant_id: Tenant identifier; its ``str()`` names the sub-directory.
            private_key: ``cryptography`` private-key object to serialize.
            mode: Permission bits for the key file (default ``0o600``). Pass a
                wider mode only if another OS user must read the key file.

        Returns:
            The path of the written key file.

        Raises:
            Exception: Any serialization or filesystem error; it is logged first.
        """
        try:
            # Serialize before touching the filesystem, so a failure here cannot
            # leave an empty or truncated key file behind.
            pem = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
            privkey_dir = os.path.join(CONFIG_PATHS['privkeys_dir'], str(tenant_id))
            os.makedirs(privkey_dir, exist_ok=True)
            privkey_path = os.path.join(privkey_dir, "private.pem")
            # Create the file with restricted permissions from the start, instead
            # of the umask-derived default (often world-readable). O_BINARY stops
            # newline translation on Windows; it is 0 elsewhere.
            flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
            fd = os.open(privkey_path, flags, mode)
            with os.fdopen(fd, "wb") as key_file:
                # os.open's mode only applies when the file is created; also
                # tighten a pre-existing file before any secret bytes are written.
                os.chmod(privkey_path, mode)
                key_file.write(pem)
            logger.info("私钥已保存到 %s", privkey_path)
            return privkey_path
        except Exception as e:
            logger.exception("保存私钥失败: %s", e)
            raise

    @staticmethod
    def _discard_private_key(privkey_path):
        """Best-effort removal of a key file written by ``save_private_key``.

        Also removes the per-tenant directory if it is then empty. Errors are
        logged, never raised, so they cannot mask the error that triggered cleanup.
        """
        try:
            os.remove(privkey_path)
            os.rmdir(os.path.dirname(privkey_path))
        except OSError as err:
            logger.warning("清理私钥文件失败 %s: %s", privkey_path, err)

    @staticmethod
    def create_tenant(workspace_name):
        """Create a tenant: generate its key pair, store the private key, insert the row.

        Args:
            workspace_name: Value stored in ``tenants.name``.

        Returns:
            The new tenant's ``uuid.UUID`` (randomly generated with ``uuid4``).

        Raises:
            Exception: Any key-generation, filesystem or database error; it is
                logged first. If the database step fails, the private key file
                just written for this tenant is removed (best effort), so no
                orphaned key is left on disk.
        """
        try:
            tenant_id = uuid.uuid4()
            public_key_pem, private_key = TenantManager.generate_rsa_key_pair()
            privkey_path = TenantManager.save_private_key(tenant_id, private_key)
            try:
                with get_db_cursor() as cursor:
                    # Lets psycopg2 adapt the uuid.UUID query parameter.
                    psycopg2.extras.register_uuid()
                    insert_query = """
                    INSERT INTO tenants (id, name, encrypt_public_key)
                    VALUES (%s, %s, %s);
                    """
                    cursor.execute(insert_query, (tenant_id, workspace_name, public_key_pem))
            except Exception:
                # The insert (or the cursor context's exit) failed, so drop the key
                # that was saved for a tenant row that presumably does not exist.
                TenantManager._discard_private_key(privkey_path)
                raise
            # Log success only after the cursor context has exited cleanly.
            logger.info("租户 '%s' 创建成功!租户 ID: %s", workspace_name, tenant_id)
            return tenant_id
        except Exception as e:
            logger.exception("创建租户失败: %s", e)
            raise

    @staticmethod
    def get_tenant_by_name(workspace_name):
        """Look up a tenant by name.

        Args:
            workspace_name: Value matched against ``tenants.name``.

        Returns:
            The row returned by ``execute_query(..., fetch_one=True)``: dict-like
            (``RealDictCursor``), with ``id`` and ``encrypt_public_key`` keys.
            Returns ``None`` (and logs a warning) when no tenant matches. If
            several tenants share the name, only one row comes back.

        Raises:
            Exception: Any database error; it is logged first.
        """
        query = """
        SELECT id, encrypt_public_key FROM tenants WHERE name = %s;
        """
        try:
            tenant = execute_query(query, (workspace_name,), cursor_factory=RealDictCursor, fetch_one=True)
        except Exception as e:
            logger.exception("获取租户信息失败: %s", e)
            raise
        if not tenant:
            logger.warning("未找到名称为 '%s' 的租户。", workspace_name)
            return None
        return tenant

    @staticmethod
    def get_all_tenants():
        """Return the ``id`` and ``encrypt_public_key`` of every tenant.

        Unlike the other lookups, this one is best effort. On any error it logs
        the failure, with traceback, and returns an empty list instead of raising.
        """
        query = "SELECT id, encrypt_public_key FROM tenants;"
        try:
            return execute_query(query, cursor_factory=RealDictCursor)
        except Exception as e:
            logger.exception("获取所有租户信息失败: %s", e)
            return []