dify_admin/api/api_user_manager.py

77 lines
3.0 KiB
Python
Executable File

import base64
import logging
import os
from typing import Optional

from passlib.context import CryptContext

from database import get_db_cursor
from libs.exception import APIUserExistsError, APIUserNotFoundError
from libs.password import hash_password, compare_password
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# passlib context restricted to the bcrypt scheme.
# NOTE(review): nothing in the visible part of this file references
# pwd_context; hashing below goes through libs.password instead. Confirm
# whether it is still needed.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class APIAuthManager:
    """Create, authenticate and update rows of the ``api_users`` table.

    Every operation opens its own cursor through
    ``get_db_cursor(db_type='sqlite')``.

    Passwords are hashed with ``hash_password`` using ``self.salt`` and
    verified with ``compare_password`` using the same salt. The salt is
    therefore part of what makes a stored hash verifiable: a hash written
    with one salt can only be checked by an instance holding that same salt.

    When no salt is passed to the constructor a random one is generated, so
    hashes created by such an instance cannot be verified by a different
    instance (or after a process restart). Pass a persistent ``salt`` to
    avoid that.
    """

    # Columns update_user() may modify. Column names are interpolated into
    # the UPDATE statement, so they must only ever come from this fixed set
    # and never from caller-supplied keys.
    _UPDATABLE_FIELDS = ("email", "is_active")

    # Length, in bytes, of an automatically generated salt.
    _SALT_BYTES = 16

    def __init__(self, salt: Optional[bytes] = None):
        """Initialise the manager.

        Args:
            salt: Salt used for hashing and verifying passwords. When
                omitted, a random 16-byte salt is generated for this
                instance only.

        Raises:
            TypeError: If ``salt`` is given but is not bytes-like.
        """
        if salt is None:
            # Random 16-byte salt, unique to this instance.
            salt = os.urandom(self._SALT_BYTES)
        elif not isinstance(salt, (bytes, bytearray)):
            # authenticate() base64-encodes the salt, which needs bytes.
            raise TypeError("salt must be bytes")
        self.salt = bytes(salt)

    def create_user(
        self, username: str, password: str, email: Optional[str] = None
    ) -> Optional[int]:
        """Create an API user.

        Args:
            username: Name of the new user.
            password: Plain-text password; only its hash is stored.
            email: Optional e-mail address.

        Returns:
            ``cursor.lastrowid`` after the INSERT.

        Raises:
            APIUserExistsError: If a row with this username already exists.
            Exception: Any other error is logged and re-raised unchanged.
        """
        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                # Refuse to create a second user with the same name.
                cursor.execute(
                    "SELECT 1 FROM api_users WHERE username=?", (username,)
                )
                if cursor.fetchone():
                    raise APIUserExistsError(f"API用户 {username} 已存在")

                # Store only the salted hash, never the plain password.
                password_hash = hash_password(password, self.salt)
                cursor.execute(
                    "INSERT INTO api_users (username, password_hash, email) VALUES (?, ?, ?)",
                    (username, password_hash, email),
                )
                return cursor.lastrowid
            except Exception as e:
                logger.error("创建API用户失败: %s", e)
                raise

    def authenticate(self, username: str, password: str) -> Optional[dict]:
        """Authenticate an API user.

        Only rows with ``is_active=1`` are considered.

        Args:
            username: Name of the user to authenticate.
            password: Plain-text password to verify.

        Returns:
            ``{'id': ..., 'username': ...}`` when the password matches,
            ``None`` when it does not.

        Raises:
            APIUserNotFoundError: If no active user has this username.
            Exception: Any other error is logged and re-raised unchanged.
        """
        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                cursor.execute(
                    "SELECT id, username, password_hash FROM api_users WHERE username=? AND is_active=1",
                    (username,),
                )
                user = cursor.fetchone()
                if not user:
                    raise APIUserNotFoundError(f"API用户 {username} 不存在")

                # compare_password receives the salt base64-encoded, while
                # hash_password in create_user() receives the raw bytes.
                # NOTE(review): the stored hash is passed through as-is;
                # confirm it is in the format compare_password expects.
                salt_b64 = base64.b64encode(self.salt).decode()
                if not compare_password(password, user[2], salt_b64):
                    return None
                return {'id': user[0], 'username': user[1]}
            except Exception as e:
                logger.error("API用户认证失败: %s", e)
                raise

    def update_user(self, user_id: int, **kwargs) -> bool:
        """Update a user's editable fields.

        Only ``email`` and ``is_active`` are honoured; any other keyword
        argument is ignored. ``updated_at`` is set to the current timestamp
        on every update.

        Args:
            user_id: Primary key of the user to update.
            **kwargs: New values, keyed by column name.

        Returns:
            ``False`` if no updatable field was supplied (the database is
            not touched), otherwise whether the UPDATE affected a row.

        Raises:
            Exception: Any database error is logged and re-raised unchanged.
        """
        updates = {
            field: value
            for field, value in kwargs.items()
            if field in self._UPDATABLE_FIELDS
        }
        if not updates:
            return False

        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                # Column names come from the whitelist above; the values
                # themselves are always bound as parameters.
                set_clause = ", ".join(f"{field}=?" for field in updates)
                values = [*updates.values(), user_id]
                cursor.execute(
                    f"UPDATE api_users SET {set_clause}, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                    values,
                )
                return cursor.rowcount > 0
            except Exception as e:
                logger.error("更新API用户失败: %s", e)
                raise