dify_admin/api/tests/conftest.py

165 lines
5.4 KiB
Python
Executable File

import pytest
from fastapi.testclient import TestClient
from app import app
import sys
import os
import uuid
from datetime import datetime, timedelta
from datetime import timezone
import jwt
from account_manager import AccountManager
from tenant_manager import TenantManager
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from ..database import get_db_connection, get_db_cursor
import sqlite3
from contextlib import contextmanager
import os
@pytest.fixture(scope="session")
def test_db():
    """Create a session-wide in-memory SQLite database seeded with one account.

    Creates the ``accounts`` table and inserts a single active ``testuser``
    row. The same connection is shared by every test in the session. It is
    closed on session teardown, and also immediately if schema creation or
    seeding fails (so a failed setup does not leak the connection).

    Yields:
        sqlite3.Connection: the seeded in-memory connection.
    """
    # check_same_thread=False lets code running on other threads use this
    # connection (presumably the app's request handlers under TestClient --
    # verify).
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        # Create the table structure.
        conn.execute('''
CREATE TABLE accounts (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
email TEXT,
password TEXT NOT NULL,
password_salt TEXT NOT NULL,
avatar TEXT,
interface_language TEXT,
interface_theme TEXT,
timezone TEXT,
last_login_at TIMESTAMP,
last_login_ip TEXT,
status TEXT,
initialized_at TIMESTAMP,
created_at TIMESTAMP,
updated_at TIMESTAMP,
last_active_at TIMESTAMP
)
''')
        # Insert the test account. The four timestamp columns are filled by
        # SQLite's datetime('now') rather than bound parameters.
        conn.execute('''
INSERT INTO accounts (id, name, email, password, password_salt,
interface_language, interface_theme, timezone,
status, initialized_at, created_at, updated_at, last_active_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'), datetime('now'), datetime('now'))
''', (
            '550e8400-e29b-41d4-a716-446655440000',
            'testuser',
            'test@example.com',
            # bcrypt hash. The original note said the plaintext is 'testpass'.
            # NOTE(review): this is the well-known example hash from the
            # FastAPI OAuth2/JWT tutorial, where the plaintext is 'secret'.
            # Verify before relying on real password verification.
            '$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW',
            '$2b$12$EixZaYVK1fsbw1ZfbX3OXe',  # bcrypt salt (prefix of the hash above)
            'en-US',
            'light',
            'UTC',
            'active',
        ))
        conn.commit()
        yield conn
    finally:
        # Runs on normal teardown and also when setup above raises.
        conn.close()
@pytest.fixture
def override_db(test_db):
    """Point the app's database dependencies at the in-memory test database.

    For the duration of one test, installs FastAPI dependency overrides for
    ``get_db_connection`` and ``get_db_cursor``. On teardown it restores
    whatever overrides were registered for exactly those two dependencies
    beforehand. Unrelated overrides installed elsewhere are left alone
    instead of being wiped.
    """
    def _override_db(db_type='sqlite'):
        # Keeps the original override's ``db_type`` parameter. It is ignored;
        # the shared test connection is always returned.
        return test_db

    def _override_cursor():
        # Yield-style dependency: FastAPI runs the code after ``yield`` when
        # the request is finished, so each request's cursor gets closed
        # instead of leaking.
        cursor = test_db.cursor()
        try:
            yield cursor
        finally:
            cursor.close()

    overrides = {
        get_db_connection: _override_db,
        get_db_cursor: _override_cursor,
    }
    # Remember prior overrides for these keys (None means "not overridden").
    previous = {dep: app.dependency_overrides.get(dep) for dep in overrides}
    app.dependency_overrides.update(overrides)
    try:
        yield
    finally:
        for dep, old in previous.items():
            if old is None:
                app.dependency_overrides.pop(dep, None)
            else:
                app.dependency_overrides[dep] = old
@pytest.fixture
def mocker_fixture(mocker):
    """Patch AccountManager and TenantManager methods with canned test data.

    The patches let tests exercise the API layer without real account or
    tenant storage.

    Returns:
        dict: the mock user record served by the patched AccountManager
        methods, so tests can assert against the same values.
    """
    mock_user = {
        "id": "550e8400-e29b-41d4-a716-446655440000",
        "username": "testuser",
        "password": "mock_hash",
        "password_salt": "mock_salt",
        "email": "test@example.com",
        "status": "active",
        "created_at": "2025-04-27T00:00:00Z",
        "updated_at": "2025-04-27T00:00:00Z",
        "last_active_at": "2025-04-27T00:00:00Z",
    }

    # AccountManager: account creation returns the mock user. Password
    # update and verification always succeed.
    mocker.patch.object(AccountManager, 'create_account', return_value=mock_user)

    def get_user_side_effect(username):
        # Branches on the username argument, not on call order:
        # "nonexistent" simulates a missing user, and any other name
        # resolves to mock_user.
        if username == "nonexistent":
            return None
        return mock_user

    mocker.patch.object(AccountManager, 'get_user_by_username',
                        side_effect=get_user_side_effect)
    mocker.patch.object(AccountManager, 'update_password', return_value=True)
    mocker.patch.object(AccountManager, 'verify_password', return_value=True)

    # TenantManager: one fixed tenant identified by the all-zero UUID.
    tenant_id = uuid.UUID(int=0)
    mocker.patch.object(TenantManager, 'create_tenant', return_value=tenant_id)
    mocker.patch.object(TenantManager, 'get_tenant_by_name', return_value={
        "id": tenant_id,
        "name": "testtenant",
        "description": "测试租户",
        "created_at": datetime.now(timezone.utc),
    })
    return mock_user
@pytest.fixture
def client(test_db, override_db):
    """Yield a FastAPI ``TestClient`` backed by the in-memory test database.

    Requesting ``override_db`` guarantees the app's DB dependencies already
    point at ``test_db``. Before the client starts, this fixture ensures the
    ``api_operations`` table exists. ``IF NOT EXISTS`` keeps the step
    idempotent across tests that share the session-scoped database.
    """
    schema_sql = '''
CREATE TABLE IF NOT EXISTS api_operations (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
operation_type TEXT NOT NULL,
endpoint TEXT NOT NULL,
parameters TEXT,
status TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
'''
    test_db.executescript(schema_sql)

    # Using TestClient as a context manager runs the app's startup/shutdown
    # (lifespan) handlers around the test.
    with TestClient(app) as test_client:
        yield test_client
@pytest.fixture
def auth_headers(mocker):
    """Build request headers that authenticate as ``testuser``.

    Signs a token that expires in 30 minutes. It then patches
    ``app.create_access_token`` to return that token and ``jwt.decode`` to
    return its claims, so the patched decode yields these claims for the
    rest of the test.

    Returns:
        dict: ``Authorization`` (bearer token), ``X-User-Id`` and
        ``X-Username`` headers.
    """
    # Intended to match the configuration in app.py (duplicated here, not
    # imported).
    SECRET_KEY = "your-secret-key-here"
    ALGORITHM = "HS256"
    user_id = "550e8400-e29b-41d4-a716-446655440000"

    expires_at = datetime.now(timezone.utc) + timedelta(minutes=30)
    payload = {
        "sub": "testuser",
        "user_id": user_id,
        # A real jwt.decode returns "exp" as integer seconds since the epoch
        # (JWT NumericDate), not a datetime. Store it that way so the mocked
        # decode below yields the same claim types the real one would. The
        # signed token is unchanged, because PyJWT encodes a datetime "exp"
        # to this same integer.
        "exp": int(expires_at.timestamp()),
    }
    mock_token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

    # Mock token creation and JWT verification.
    mocker.patch('app.create_access_token', return_value=mock_token)
    mocker.patch('jwt.decode', return_value=payload)

    return {
        "Authorization": f"Bearer {mock_token}",
        "X-User-Id": user_id,
        "X-Username": "testuser",
    }