77 lines
3.0 KiB
Python
Executable File
77 lines
3.0 KiB
Python
Executable File
import logging
|
|
from passlib.context import CryptContext
|
|
from database import get_db_cursor
|
|
import os
|
|
import base64
|
|
from libs.password import hash_password, compare_password
|
|
from libs.exception import APIUserExistsError, APIUserNotFoundError
|
|
|
|
# Logger named after this module's import path.
logger = logging.getLogger(__name__)

# passlib hashing context restricted to bcrypt. With deprecated="auto",
# passlib treats every scheme other than the default as deprecated; bcrypt
# is the only (and therefore default) scheme configured here.
pwd_context = CryptContext(
    deprecated="auto",
    schemes=["bcrypt"],
)
|
|
|
|
class APIAuthManager:
    """Create, authenticate and update rows of the ``api_users`` SQLite table.

    New passwords are hashed with the module-level passlib ``pwd_context``
    (bcrypt). A bcrypt hash carries its own random per-hash salt, so every
    user gets a distinct salt and the stored hash remains verifiable across
    process restarts and across ``APIAuthManager`` instances.
    """

    # Columns callers may change through ``update_user``. Only these names are
    # ever interpolated into SQL, which keeps the dynamic SET clause safe.
    _UPDATABLE_FIELDS = ('email', 'is_active')

    def __init__(self):
        # Legacy per-instance random salt (16 bytes). It is never persisted, so
        # hashes made with it by the old ``hash_password`` scheme can presumably
        # only be verified by this same instance. Kept solely so such legacy
        # rows are still checked exactly as before (see ``_verify_password``).
        self.salt = os.urandom(16)

    def create_user(self, username: str, password: str, email: str = None):
        """Create an API user and return the new row id.

        Args:
            username: Login name; must not already exist in ``api_users``.
            password: Plain-text password; only its bcrypt hash is stored.
            email: Optional e-mail address stored with the user.

        Returns:
            ``cursor.lastrowid`` of the inserted row.

        Raises:
            APIUserExistsError: if a row with ``username`` already exists.
            Exception: any other (database) error is logged and re-raised.
        """
        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                # NOTE(review): check-then-insert is racy under concurrency;
                # only a UNIQUE constraint on api_users.username (not visible
                # here) would prevent duplicates — confirm the schema.
                cursor.execute("SELECT 1 FROM api_users WHERE username=?", (username,))
                if cursor.fetchone():
                    raise APIUserExistsError(f"API用户 {username} 已存在")

                # bcrypt generates a fresh salt on every call and embeds it in
                # the returned hash, so nothing else needs to be stored.
                password_hash = pwd_context.hash(password)
                cursor.execute(
                    "INSERT INTO api_users (username, password_hash, email) VALUES (?, ?, ?)",
                    (username, password_hash, email),
                )
                return cursor.lastrowid
            except APIUserExistsError:
                # Expected validation error: propagate without logging a failure.
                raise
            except Exception as e:
                logger.exception("创建API用户失败: %s", e)
                raise

    def authenticate(self, username: str, password: str):
        """Verify the credentials of an active API user.

        Args:
            username: Login name to look up (only rows with ``is_active=1``).
            password: Plain-text password to check.

        Returns:
            ``{'id': ..., 'username': ...}`` when the password matches,
            otherwise ``None``.

        Raises:
            APIUserNotFoundError: if no active user with ``username`` exists.
            Exception: any other (database) error is logged and re-raised.
        """
        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                cursor.execute(
                    "SELECT id, username, password_hash FROM api_users WHERE username=? AND is_active=1",
                    (username,),
                )
                user = cursor.fetchone()
                if not user:
                    raise APIUserNotFoundError(f"API用户 {username} 不存在")

                if not self._verify_password(password, user[2]):
                    return None
                return {'id': user[0], 'username': user[1]}
            except APIUserNotFoundError:
                # Expected lookup miss: propagate without logging a failure.
                raise
            except Exception as e:
                logger.exception("API用户认证失败: %s", e)
                raise

    def _verify_password(self, password: str, stored_hash) -> bool:
        """Return whether ``password`` matches ``stored_hash``.

        Hashes recognised by ``pwd_context`` (bcrypt) are verified with
        passlib. Anything else is treated as a legacy hash from the previous
        ``hash_password`` scheme and checked with the original
        ``compare_password`` call and this instance's salt.
        """
        if stored_hash and pwd_context.identify(stored_hash):
            return pwd_context.verify(password, stored_hash)
        # Legacy path: same call the original implementation made.
        return bool(compare_password(password, stored_hash, base64.b64encode(self.salt).decode()))

    def update_user(self, user_id: int, **kwargs):
        """Update whitelisted columns of an API user.

        Only keyword arguments named in ``_UPDATABLE_FIELDS`` (``email``,
        ``is_active``) are applied; all others are ignored. ``updated_at`` is
        set to ``CURRENT_TIMESTAMP`` whenever an update is issued.

        Args:
            user_id: ``id`` of the row to update.
            **kwargs: Column values to set.

        Returns:
            ``True`` if a row was updated; ``False`` if no updatable field was
            given or no row matched ``user_id``.

        Raises:
            Exception: any database error is logged and re-raised.
        """
        updates = {k: v for k, v in kwargs.items() if k in self._UPDATABLE_FIELDS}
        if not updates:
            return False

        # Safe to interpolate: every key was checked against the whitelist.
        set_clause = ", ".join(f"{field}=?" for field in updates)
        values = [*updates.values(), user_id]

        with get_db_cursor(db_type='sqlite') as cursor:
            try:
                cursor.execute(
                    f"UPDATE api_users SET {set_clause}, updated_at=CURRENT_TIMESTAMP WHERE id=?",
                    values,
                )
                return cursor.rowcount > 0
            except Exception as e:
                logger.exception("更新API用户失败: %s", e)
                raise
|