# dify_admin/api/account_manager.py
# (Source-viewer metadata: 319 lines, 11 KiB, Python, executable file.)

import base64
import binascii
import hashlib
import hmac
import logging
import secrets
import uuid
from datetime import datetime, timezone

import psycopg2.extras

from database import get_db_cursor, execute_query, execute_update
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class AccountManager:
    """Static helpers for managing user accounts.

    Reads and writes the ``accounts`` and ``tenant_account_joins`` tables
    (joined with ``tenants`` for lookups) through the ``database`` module
    helpers ``get_db_cursor``, ``execute_query`` and ``execute_update``.
    """

    # Password assigned by reset_password() when the caller does not pass one.
    # NOTE(review): a fixed, well-known reset password is a security risk;
    # prefer passing a random per-reset password (e.g. secrets.token_urlsafe()).
    DEFAULT_RESET_PASSWORD = "Welcome123!"

    # PBKDF2-HMAC-SHA256 iteration count shared by hashing and verification.
    # Changing it would make every previously stored hash fail verification.
    _PBKDF2_ITERATIONS = 10000

    # Length in bytes of the random salt generated by hash_password().
    _SALT_BYTES = 16

    @staticmethod
    def _hash_with_salt(password, salt_bytes):
        """Return the stored-format hash of *password* under *salt_bytes*.

        Stored format: base64(hexlify(PBKDF2-HMAC-SHA256(utf-8 password, salt))).
        """
        dk = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt_bytes,
            AccountManager._PBKDF2_ITERATIONS,
        )
        return base64.b64encode(binascii.hexlify(dk)).decode()

    @staticmethod
    def _escape_like(value):
        """Escape backslash, ``%`` and ``_`` so *value* matches literally in LIKE.

        Backslash is PostgreSQL's default LIKE escape character.
        """
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    @staticmethod
    def _parse_uuid(value):
        """Return *value* as a canonical UUID string, or None if it is not one.

        Accepts ``uuid.UUID`` instances and any string form ``uuid.UUID``
        understands; None and malformed strings yield None instead of raising.
        """
        try:
            return str(uuid.UUID(str(value)))
        except (TypeError, ValueError):
            return None

    @staticmethod
    def hash_password(password, salt=None):
        """Hash *password* with PBKDF2-HMAC-SHA256.

        Args:
            password: Plain-text password (str).
            salt: Raw salt bytes, or a base64 string as returned by this
                method. A new random 16-byte salt is generated when None.

        Returns:
            ``(base64_password_hashed, base64_salt)`` as str values.

        Raises:
            Exception: any hashing/encoding error is logged and re-raised.
        """
        try:
            if salt is None:
                salt = secrets.token_bytes(AccountManager._SALT_BYTES)
            elif isinstance(salt, str):
                # Accept the base64 form this method returns (and callers store).
                salt = base64.b64decode(salt)
            base64_salt = base64.b64encode(salt).decode()
            return AccountManager._hash_with_salt(password, salt), base64_salt
        except Exception as e:
            logger.error(f"密码哈希失败: {e}")
            raise

    @staticmethod
    def create_account(username, email, password):
        """Insert a new active account and return its basic fields.

        The password is hashed with ``hash_password``; interface language,
        theme and timezone are set to ``en-US`` / ``light`` / ``UTC``, and the
        initialized/created/updated/last-active timestamps to the current UTC
        time.

        Args:
            username: Value stored in ``accounts.name``.
            email: Value stored in ``accounts.email``.
            password: Plain-text password.

        Returns:
            dict with ``id``, ``username``, ``email`` and ``created_at``.

        Raises:
            Exception: database errors are logged and re-raised.
        """
        try:
            user_id = uuid.uuid4()
            hashed_password, password_salt = AccountManager.hash_password(password)
            current_time = datetime.now(timezone.utc)
            with get_db_cursor() as cursor:
                # Lets psycopg2 adapt uuid.UUID values passed as parameters.
                psycopg2.extras.register_uuid()
                insert_query = """
                    INSERT INTO accounts (
                        id, name, email, password, password_salt, avatar, interface_language,
                        interface_theme, timezone, last_login_at, last_login_ip, status,
                        initialized_at, created_at, updated_at, last_active_at
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
                    ) RETURNING id, name, email, created_at;
                """
                cursor.execute(insert_query, (
                    user_id, username, email, hashed_password, password_salt,
                    None, "en-US", "light", "UTC", None, None, "active",
                    current_time, current_time, current_time, current_time,
                ))
                result = cursor.fetchone()
                cursor.connection.commit()
                logger.info(f"用户 {username} 邮箱 {email} 注册成功!")
                return {
                    "id": result[0],
                    "username": result[1],
                    "email": result[2],
                    "created_at": result[3],
                }
        except Exception as e:
            logger.error(f"创建账户失败: {e}")
            raise

    @staticmethod
    def get_user_by_username(username):
        """Look up an account by ``accounts.name``.

        Returns:
            dict with ``id``, ``username``, ``email``, ``password``,
            ``password_salt`` and ``created_at``; None when no row matches.

        Raises:
            Exception: database errors are logged and re-raised.
        """
        try:
            query = """
                SELECT id, name, email, password, password_salt, created_at
                FROM accounts WHERE name = %s;
            """
            user = execute_query(query, (username,), fetch_one=True)
            if user:
                return {
                    "id": user[0],
                    "username": user[1],
                    "email": user[2],
                    "password": user[3],
                    "password_salt": user[4],
                    "created_at": user[5],
                }
            logger.warning(f"未找到用户名为 {username} 的用户。")
            return None
        except Exception as e:
            logger.error(f"获取用户信息失败: {e}")
            raise

    @staticmethod
    def search_accounts(search=None, page=1, page_size=10):
        """Return one page of accounts whose name or email contains *search*.

        Args:
            search: Substring to look for in ``name`` or ``email``. LIKE
                wildcards (``%``, ``_``) in it are matched literally. When
                empty/None the pattern is ``%`` (match anything).
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``{"data": [...], "total": int}``. ``data`` is ordered by
            ``created_at`` descending; ``total`` counts all matching rows.

        Raises:
            Exception: database errors are logged and re-raised.
        """
        try:
            # PostgreSQL rejects a negative OFFSET/LIMIT, so clamp the paging input.
            page = max(int(page), 1)
            page_size = max(int(page_size), 0)
            offset = (page - 1) * page_size
            query = """
                SELECT id, name, email, status, created_at
                FROM accounts
                WHERE name LIKE %s OR email LIKE %s
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s;
            """
            count_query = """
                SELECT COUNT(*) FROM accounts
                WHERE name LIKE %s OR email LIKE %s;
            """
            if search:
                search_pattern = f"%{AccountManager._escape_like(search)}%"
            else:
                search_pattern = "%"
            with get_db_cursor() as cursor:
                cursor.execute(query, (search_pattern, search_pattern, page_size, offset))
                accounts = cursor.fetchall()
                cursor.execute(count_query, (search_pattern, search_pattern))
                total = cursor.fetchone()[0]
            return {
                "data": [{
                    "id": a[0],
                    "username": a[1],
                    "email": a[2],
                    "status": a[3],
                    "created_at": a[4],
                } for a in accounts],
                "total": total,
            }
        except Exception as e:
            logger.error(f"搜索账户失败: {e}")
            raise

    @staticmethod
    def verify_password(plain_password: str, hashed_password: str, salt: str):
        """Check *plain_password* against a stored hash and base64 salt.

        Args:
            plain_password: Password supplied by the user.
            hashed_password: Stored hash as produced by ``hash_password``.
            salt: Stored base64-encoded salt.

        Returns:
            True when the password matches. False on mismatch, on a missing
            (None/empty) hash or salt, or on any decoding/hashing error
            (which is logged).
        """
        # Nothing to compare against: fail closed without logging an error.
        if plain_password is None or not hashed_password or not salt:
            return False
        try:
            salt_bytes = base64.b64decode(salt)
            input_hashed = AccountManager._hash_with_salt(plain_password, salt_bytes)
            # Constant-time comparison so response timing does not reveal how
            # much of the hash matched.
            return hmac.compare_digest(
                input_hashed.encode("utf-8"), hashed_password.encode("utf-8")
            )
        except Exception as e:
            logger.error(f"密码验证失败: {e}")
            return False

    @staticmethod
    def get_user_by_email(email):
        """Look up an account by ``accounts.email``.

        Returns:
            dict with ``id``, ``username`` and ``email``; None when no row
            matches.

        Raises:
            Exception: database errors are logged and re-raised.
        """
        try:
            query = """
                SELECT id, name, email FROM accounts WHERE email = %s;
            """
            user = execute_query(query, (email,), fetch_one=True)
            if user:
                return {
                    "id": user[0],
                    "username": user[1],
                    "email": user[2],
                }
            logger.warning(f"未找到邮箱为 {email} 的用户。")
            return None
        except Exception as e:
            logger.error(f"获取用户信息失败: {e}")
            raise

    @staticmethod
    def update_password(username, email, new_password):
        """Set a new password for the account matching both *username* and *email*.

        Returns:
            True if at least one row was updated, False otherwise.

        Raises:
            Exception: hashing or database errors are logged and re-raised.
        """
        try:
            hashed_password, password_salt = AccountManager.hash_password(new_password)
            updated_at = datetime.now(timezone.utc)
            update_query = """
                UPDATE accounts
                SET password = %s, password_salt = %s, updated_at = %s
                WHERE name = %s AND email = %s;
            """
            rows_affected = execute_update(
                update_query, (hashed_password, password_salt, updated_at, username, email)
            )
            if rows_affected > 0:
                logger.info(f"用户 {username} 邮箱 {email} 的密码已成功更新!")
                return True
            logger.warning(f"未找到用户名为 {username} 邮箱为 {email} 的用户。")
            return False
        except Exception as e:
            logger.error(f"更新密码失败: {e}")
            raise

    @staticmethod
    def reset_password(account_id: str, new_password=None):
        """Reset an account's password.

        Args:
            account_id: Account UUID (string or ``uuid.UUID``).
            new_password: Password to set; ``DEFAULT_RESET_PASSWORD`` is used
                when None.

        Returns:
            True if a row was updated. False if *account_id* is not a valid
            UUID or no account has that id.

        Raises:
            Exception: hashing or database errors are logged and re-raised.
        """
        try:
            normalized_id = AccountManager._parse_uuid(account_id)
            if normalized_id is None:
                logger.warning(f"无效的account_id格式: {account_id}")
                return False
            if new_password is None:
                new_password = AccountManager.DEFAULT_RESET_PASSWORD
            hashed_password, password_salt = AccountManager.hash_password(new_password)
            updated_at = datetime.now(timezone.utc)
            update_query = """
                UPDATE accounts
                SET password = %s, password_salt = %s, updated_at = %s
                WHERE id = %s::uuid;
            """
            rows_affected = execute_update(
                update_query, (hashed_password, password_salt, updated_at, normalized_id)
            )
            if rows_affected > 0:
                logger.info(f"账号 {account_id} 的密码已重置!")
                return True
            logger.warning(f"未找到ID为 {account_id} 的账号。")
            return False
        except Exception as e:
            logger.error(f"重置密码失败: {e}")
            raise

    @staticmethod
    def associate_with_tenant(account_id, tenant_id, role="normal", invited_by=None, current=False):
        """Insert a ``tenant_account_joins`` row linking an account to a tenant.

        Args:
            account_id: Account id.
            tenant_id: Tenant id.
            role: Role stored on the join row (default ``"normal"``).
            invited_by: Optional id of the inviting account.
            current: Value stored in the join row's ``current`` column.

        Returns:
            True on success.

        Raises:
            Exception: database errors are logged and re-raised.
        """
        try:
            with get_db_cursor() as cursor:
                # Lets psycopg2 adapt uuid.UUID values passed as parameters.
                psycopg2.extras.register_uuid()
                # UTC, matching the other writers in this class; a naive local
                # time would make these timestamps inconsistent with them.
                current_time = datetime.now(timezone.utc)
                insert_query = """
                    INSERT INTO tenant_account_joins (
                        id, tenant_id, account_id, role, invited_by, created_at, updated_at, current
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s, %s
                    );
                """
                cursor.execute(insert_query, (
                    uuid.uuid4(),
                    tenant_id,
                    account_id,
                    role,
                    invited_by,
                    current_time,
                    current_time,
                    current,
                ))
                # Commit explicitly, as create_account does, so the row is
                # persisted regardless of how get_db_cursor ends the transaction.
                cursor.connection.commit()
            logger.info(f"账户 {account_id} 已成功关联到租户 {tenant_id},角色为 {role}")
            return True
        except Exception as e:
            logger.error(f"关联账户与租户失败: {e}")
            raise

    @staticmethod
    def get_tenant_accounts(tenant_id):
        """Return all accounts joined to *tenant_id*.

        Returns:
            Rows as returned by ``execute_query`` with columns
            ``(id, name, email, role, current)``; an empty list when there
            are none or when the query fails (the error is logged).
        """
        try:
            query = """
                SELECT a.id, a.name, a.email, j.role, j.current
                FROM accounts a
                JOIN tenant_account_joins j ON a.id = j.account_id
                WHERE j.tenant_id = %s;
            """
            accounts = execute_query(query, (tenant_id,))
            # Normalize a None/empty result to an empty list for callers.
            return accounts or []
        except Exception as e:
            # Best-effort lookup: log with traceback and return no accounts.
            logger.error(f"获取租户账户失败: {e}", exc_info=True)
            return []

    @staticmethod
    def get_account_tenants(account_id):
        """Return all tenants the account identified by *account_id* belongs to.

        Args:
            account_id: Account UUID (string or ``uuid.UUID``).

        Returns:
            List of dicts with ``tenant_id`` (str), ``tenant_name``, ``role``
            (lower-cased str) and ``current``. Empty when *account_id* is not
            a valid UUID or no tenant is linked.

        Raises:
            Exception: database errors are logged (with traceback) and re-raised.
        """
        try:
            normalized_id = AccountManager._parse_uuid(account_id)
            if normalized_id is None:
                logger.error(f"无效的account_id格式: {account_id}")
                return []
            query = """
                SELECT t.id, t.name, j.role, j.current
                FROM tenants t
                JOIN tenant_account_joins j ON t.id = j.tenant_id
                WHERE j.account_id = %s::uuid;
            """
            tenants = execute_query(query, (normalized_id,))
            if not tenants:
                logger.info(f"账号 {account_id} 未关联任何租户")
                return []
            return [{
                "tenant_id": str(t[0]),
                "tenant_name": t[1],
                # Normalize role values to lower case for callers.
                "role": str(t[2]).lower(),
                "current": t[3],
            } for t in tenants]
        except Exception as e:
            logger.error(f"获取账户租户失败: {e}", exc_info=True)
            raise