dify_admin/api/account_manager.py

343 lines
13 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import uuid
import secrets
import binascii
import hashlib
import base64
import logging
from datetime import datetime, timezone
import psycopg2.extras
from database import get_db_cursor, execute_query, execute_update
from tenant_manager import TenantManager # Import TenantManager for tenant operations
# Configure logging: module-level logger used by AccountManager for its info/warning/error messages.
logger = logging.getLogger(__name__)
class AccountManager:
    """Account management helpers operating on the ``accounts`` table.

    Every method is a ``@staticmethod``; the class is used as a namespace.
    Database access goes through ``get_db_cursor`` / ``execute_query`` /
    ``execute_update`` from the ``database`` module, and tenant creation and
    lookup go through ``TenantManager``.
    """

    # PBKDF2-HMAC-SHA256 iteration count used for every stored password.
    # NOTE(review): presumably this must match the scheme of the application
    # that verifies these hashes (Dify); changing it would invalidate every
    # password already stored.
    _PBKDF2_ITERATIONS = 10000

    @staticmethod
    def _derive_hash(password, salt_bytes):
        """Return base64(hex(PBKDF2-HMAC-SHA256(password, salt_bytes))) as a str."""
        dk = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt_bytes,
            AccountManager._PBKDF2_ITERATIONS,
        )
        return base64.b64encode(binascii.hexlify(dk)).decode()

    @staticmethod
    def _escape_like(text):
        """Escape ``\\``, ``%`` and ``_`` so *text* matches literally inside LIKE.

        Relies on PostgreSQL's default LIKE escape character, the backslash.
        """
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _normalize_uuid(value):
        """Return *value* as a canonical UUID string, or ``None`` if it is not a UUID.

        Accepts ``uuid.UUID`` instances and any string form ``uuid.UUID``
        understands (braces, ``urn:uuid:`` prefix, no hyphens). Returns ``None``
        instead of raising for ``None``, non-string input or malformed strings.
        """
        if isinstance(value, uuid.UUID):
            return str(value)
        try:
            return str(uuid.UUID(value))
        except (ValueError, TypeError, AttributeError):
            return None

    @staticmethod
    def hash_password(password, salt=None):
        """Hash *password* with PBKDF2-HMAC-SHA256.

        Args:
            password: Plain-text password (str).
            salt: Raw salt bytes. When ``None``, 16 random bytes from
                ``secrets.token_bytes`` are generated.

        Returns:
            Tuple ``(base64_password_hash, base64_salt)``, both str.

        Raises:
            Exception: any hashing/encoding error is logged and re-raised.
        """
        try:
            if salt is None:
                salt = secrets.token_bytes(16)
            base64_salt = base64.b64encode(salt).decode()
            return AccountManager._derive_hash(password, salt), base64_salt
        except Exception as e:
            logger.error(f"密码哈希失败: {e}")
            raise

    @staticmethod
    def create_account(username, email, password):
        """Create a new account and, best-effort, its workspace memberships.

        Inserts an ``active`` row into ``accounts`` (language ``en-US``, theme
        ``light``, timezone ``UTC``) and commits it. Then it creates a tenant
        named ``"<username>'s Workspace"`` with the new account as ``owner``.
        If a tenant named ``"ucas's Workspace"`` exists, it also adds the
        account there with the ``normal`` role. Failures in this second phase
        are logged and do not undo the committed account.

        Args:
            username: Stored in ``accounts.name``.
            email: Stored in ``accounts.email``.
            password: Plain-text password; only its hash and salt are stored.

        Returns:
            dict with ``id``, ``username``, ``email`` and ``created_at`` taken
            from the inserted row.

        Raises:
            Exception: errors from hashing or the account insert are logged
            and re-raised.
        """
        try:
            user_id = uuid.uuid4()
            hashed_password, password_salt = AccountManager.hash_password(password)
            current_time = datetime.now(timezone.utc)

            with get_db_cursor() as cursor:
                # Let psycopg2 adapt uuid.UUID parameters (user_id) natively.
                psycopg2.extras.register_uuid()
                insert_query = """
                    INSERT INTO accounts (
                        id, name, email, password, password_salt, avatar, interface_language,
                        interface_theme, timezone, last_login_at, last_login_ip, status,
                        initialized_at, created_at, updated_at, last_active_at
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                    ) RETURNING id, name, email, created_at;
                """
                cursor.execute(insert_query, (
                    user_id, username, email, hashed_password, password_salt,
                    None, "en-US", "light", "UTC", None, None, "active",
                    current_time, current_time, current_time, current_time,
                ))
                result = cursor.fetchone()
                # Commit before the tenant phase so the account persists even
                # if tenant setup below fails.
                cursor.connection.commit()
            logger.info(f"用户 {username} 邮箱 {email} 注册成功!")

            account_id = result[0]
            try:
                # Personal workspace named "<username>'s Workspace", owned by the new account.
                tenant_name = f"{username}'s Workspace"
                tenant_id = TenantManager.create_tenant(tenant_name)
                AccountManager.associate_with_tenant(account_id, tenant_id, "owner")

                # Also join the shared "ucas's Workspace" as a normal member, if it exists.
                ucas_tenant = TenantManager.get_tenant_by_name("ucas's Workspace")
                if ucas_tenant:
                    AccountManager.associate_with_tenant(account_id, ucas_tenant['id'], "normal")
                else:
                    logger.warning("ucas's Workspace not found - skipping normal role association")
            except Exception as e:
                # Deliberately best-effort: the account is already committed.
                logger.error(f"租户创建或关联失败: {e}", exc_info=True)

            return {
                "id": result[0],
                "username": result[1],
                "email": result[2],
                "created_at": result[3],
            }
        except Exception as e:
            logger.error(f"创建账户失败: {e}")
            raise

    @staticmethod
    def get_user_by_username(username):
        """Look up an account by exact ``accounts.name``.

        Returns:
            dict with ``id``, ``username``, ``email``, ``password`` (stored
            hash), ``password_salt`` and ``created_at``; ``None`` (with a
            warning) when no row matches.

        Raises:
            Exception: query errors are logged and re-raised.
        """
        try:
            query = """
                SELECT id, name, email, password, password_salt, created_at
                FROM accounts WHERE name = %s;
            """
            user = execute_query(query, (username,), fetch_one=True)
            if not user:
                logger.warning(f"未找到用户名为 {username} 的用户。")
                return None
            return {
                "id": user[0],
                "username": user[1],
                "email": user[2],
                "password": user[3],
                "password_salt": user[4],
                "created_at": user[5],
            }
        except Exception as e:
            logger.error(f"获取用户信息失败: {e}")
            raise

    @staticmethod
    def search_accounts(search=None, page=1, page_size=10):
        """Return one page of accounts whose name or email contains *search*.

        *search* is matched literally (LIKE wildcards in it are escaped). An
        empty or ``None`` search matches every row with a non-NULL name or
        email. Results are ordered by ``created_at`` newest first.

        Args:
            search: Substring to look for; ``None``/empty means no filter.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            dict ``{"data": [...], "total": int}``. Each item has ``id``,
            ``username``, ``email``, ``status`` and ``created_at``; ``total``
            is the number of matching rows across all pages.

        Raises:
            Exception: query errors are logged and re-raised.
        """
        try:
            # PostgreSQL rejects negative LIMIT/OFFSET, so clamp paging inputs.
            page = max(page, 1)
            page_size = max(page_size, 0)
            offset = (page - 1) * page_size

            query = """
                SELECT id, name, email, status, created_at
                FROM accounts
                WHERE name LIKE %s OR email LIKE %s
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s;
            """
            count_query = """
                SELECT COUNT(*) FROM accounts
                WHERE name LIKE %s OR email LIKE %s;
            """
            if search:
                search_pattern = f"%{AccountManager._escape_like(str(search))}%"
            else:
                search_pattern = "%"

            with get_db_cursor() as cursor:
                cursor.execute(query, (search_pattern, search_pattern, page_size, offset))
                accounts = cursor.fetchall()
                cursor.execute(count_query, (search_pattern, search_pattern))
                total = cursor.fetchone()[0]

            return {
                "data": [{
                    "id": a[0],
                    "username": a[1],
                    "email": a[2],
                    "status": a[3],
                    "created_at": a[4],
                } for a in accounts],
                "total": total,
            }
        except Exception as e:
            logger.error(f"搜索账户失败: {e}")
            raise

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str, salt: str):
        """Check *plain_password* against a stored hash produced by ``hash_password``.

        Args:
            plain_password: Candidate password.
            hashed_password: Stored base64 hash.
            salt: Stored base64-encoded salt.

        Returns:
            True on match, False on mismatch or on any decoding/hashing error
            (the error is logged). The comparison uses
            ``secrets.compare_digest``, so its timing does not reveal how much
            of the hash matched.
        """
        try:
            salt_bytes = base64.b64decode(salt)
            input_hashed = AccountManager._derive_hash(plain_password, salt_bytes)
            return secrets.compare_digest(input_hashed, hashed_password)
        except Exception as e:
            logger.error(f"密码验证失败: {e}")
            return False

    @staticmethod
    def get_user_by_email(email):
        """Look up an account by exact ``accounts.email``.

        Returns:
            dict with ``id``, ``username`` and ``email``; ``None`` (with a
            warning) when no row matches.

        Raises:
            Exception: query errors are logged and re-raised.
        """
        try:
            query = """
                SELECT id, name, email FROM accounts WHERE email = %s;
            """
            user = execute_query(query, (email,), fetch_one=True)
            if not user:
                logger.warning(f"未找到邮箱为 {email} 的用户。")
                return None
            return {
                "id": user[0],
                "username": user[1],
                "email": user[2],
            }
        except Exception as e:
            logger.error(f"获取用户信息失败: {e}")
            raise

    @staticmethod
    def update_password(username, email, new_password):
        """Set a new password for the account matching both *username* and *email*.

        Returns:
            True if at least one row was updated, otherwise False (with a
            warning).

        Raises:
            Exception: hashing or update errors are logged and re-raised.
        """
        try:
            hashed_password, password_salt = AccountManager.hash_password(new_password)
            updated_at = datetime.now(timezone.utc)
            update_query = """
                UPDATE accounts
                SET password = %s, password_salt = %s, updated_at = %s
                WHERE name = %s AND email = %s;
            """
            rows_affected = execute_update(
                update_query, (hashed_password, password_salt, updated_at, username, email)
            )
            if rows_affected > 0:
                logger.info(f"用户 {username} 邮箱 {email} 的密码已成功更新!")
                return True
            logger.warning(f"未找到用户名为 {username} 邮箱为 {email} 的用户。")
            return False
        except Exception as e:
            logger.error(f"更新密码失败: {e}")
            raise

    @staticmethod
    def reset_password(account_id: str, new_password: str = "Welcome123!"):
        """Reset the password of the account with id *account_id*.

        Args:
            account_id: Account UUID, as a string or ``uuid.UUID``.
            new_password: Password to set; defaults to the fixed
                ``"Welcome123!"`` used by the original implementation.

        Returns:
            True if the account was updated. False if *account_id* is not a
            valid UUID (logged as a warning) or no account has that id.

        Raises:
            Exception: hashing or update errors are logged and re-raised.
        """
        try:
            normalized_id = AccountManager._normalize_uuid(account_id)
            if normalized_id is None:
                logger.warning(f"无效的account_id格式: {account_id}")
                return False

            hashed_password, password_salt = AccountManager.hash_password(new_password)
            updated_at = datetime.now(timezone.utc)
            update_query = """
                UPDATE accounts
                SET password = %s, password_salt = %s, updated_at = %s
                WHERE id = %s::uuid;
            """
            rows_affected = execute_update(
                update_query, (hashed_password, password_salt, updated_at, normalized_id)
            )
            if rows_affected > 0:
                logger.info(f"账号 {account_id} 的密码已重置!")
                return True
            logger.warning(f"未找到ID为 {account_id} 的账号。")
            return False
        except Exception as e:
            logger.error(f"重置密码失败: {e}")
            raise

    @staticmethod
    def associate_with_tenant(account_id, tenant_id, role="normal", invited_by=None, current=False):
        """Insert a ``tenant_account_joins`` row linking *account_id* to *tenant_id*.

        Args:
            account_id: Account id (``uuid.UUID`` or string).
            tenant_id: Tenant id (``uuid.UUID`` or string).
            role: Role string stored on the join (e.g. ``"owner"``, ``"normal"``).
            invited_by: Optional id of the inviting account.
            current: Value stored in the join's ``current`` column.

        Returns:
            True after the insert has been committed.

        Raises:
            Exception: insert/commit errors are logged and re-raised.
        """
        try:
            # UTC, consistent with the timestamps written by the other methods.
            current_time = datetime.now(timezone.utc)
            insert_query = """
                INSERT INTO tenant_account_joins (
                    id, tenant_id, account_id, role, invited_by, created_at, updated_at, current
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s
                );
            """
            with get_db_cursor() as cursor:
                # Let psycopg2 adapt uuid.UUID parameters natively.
                psycopg2.extras.register_uuid()
                cursor.execute(insert_query, (
                    uuid.uuid4(),
                    tenant_id,
                    account_id,
                    role,
                    invited_by,
                    current_time,
                    current_time,
                    current,
                ))
                # Commit explicitly, as create_account does; otherwise the
                # join may be discarded when the cursor's connection is released.
                cursor.connection.commit()
            logger.info(f"账户 {account_id} 已成功关联到租户 {tenant_id},角色为 {role}")
            return True
        except Exception as e:
            logger.error(f"关联账户与租户失败: {e}")
            raise

    @staticmethod
    def get_tenant_accounts(tenant_id):
        """Return the raw rows ``(id, name, email, role, current)`` of a tenant's accounts.

        Best-effort: on any query error the error is logged and ``[]`` is returned.
        """
        try:
            query = """
                SELECT a.id, a.name, a.email, j.role, j.current
                FROM accounts a
                JOIN tenant_account_joins j ON a.id = j.account_id
                WHERE j.tenant_id = %s;
            """
            return execute_query(query, (tenant_id,))
        except Exception as e:
            logger.error(f"获取租户账户失败: {e}", exc_info=True)
            return []

    @staticmethod
    def get_account_tenants(account_id):
        """Return the tenants that *account_id* belongs to.

        Args:
            account_id: Account UUID, as a string or ``uuid.UUID``.

        Returns:
            List of dicts with ``tenant_id`` (str), ``tenant_name``, ``role``
            (lower-cased str) and ``current``. Returns ``[]`` if *account_id*
            is not a valid UUID (logged as an error) or has no tenants.

        Raises:
            Exception: query errors are logged with traceback and re-raised.
        """
        try:
            account_id_str = AccountManager._normalize_uuid(account_id)
            if account_id_str is None:
                logger.error(f"无效的account_id格式: {account_id}")
                return []

            query = """
                SELECT t.id, t.name, j.role, j.current
                FROM tenants t
                JOIN tenant_account_joins j ON t.id = j.tenant_id
                WHERE j.account_id = %s::uuid;
            """
            tenants = execute_query(query, (account_id_str,))
            if not tenants:
                logger.info(f"账号 {account_id} 未关联任何租户")
                return []

            return [{
                "tenant_id": str(t[0]),
                "tenant_name": t[1],
                # Normalize role values to lower case.
                "role": str(t[2]).lower(),
                "current": t[3],
            } for t in tenants]
        except Exception as e:
            logger.error(f"获取账户租户失败: {e}", exc_info=True)
            raise