dify_admin/api/database.py

155 lines
5.6 KiB
Python
Executable File

import psycopg2
import psycopg2.extras
import sqlite3
from psycopg2.pool import SimpleConnectionPool
from contextlib import contextmanager
from config import DB_CONFIG, SQLITE_CONFIG
import logging
# Logging setup.
# NOTE(review): basicConfig() configures the *root* logger at import time, so
# importing this module affects logging for the whole process (it is a no-op
# if the root logger already has handlers).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def _create_pg_pool():
    """Build the shared PostgreSQL connection pool from ``DB_CONFIG``.

    Missing ``DB_CONFIG`` keys fall back to localhost:5432, database and user
    ``postgres``, and an empty password. The pool is created with
    ``minconn=1`` and ``maxconn=10``.

    Returns:
        A ``SimpleConnectionPool`` on success, or ``None`` if building it
        raised (the error is logged, not propagated). With ``None``,
        ``get_db_connection('postgres')`` refuses to hand out connections.

    NOTE(review): psycopg2 documents ``SimpleConnectionPool`` as not shareable
    across threads; if this API is served by a multi-threaded server,
    ``ThreadedConnectionPool`` is the thread-safe variant. Confirm the
    deployment model.
    """
    try:
        connect_kwargs = {
            'host': DB_CONFIG.get('host', 'localhost'),
            'port': DB_CONFIG.get('port', '5432'),
            'database': DB_CONFIG.get('database', 'postgres'),
            'user': DB_CONFIG.get('user', 'postgres'),
            'password': DB_CONFIG.get('password', ''),
        }
        pool = SimpleConnectionPool(1, 10, **connect_kwargs)
        logger.info("PostgreSQL连接池创建成功")
        return pool
    except Exception as err:
        logger.error(f"PostgreSQL连接池创建失败: {err}")
        return None


# Module-level PostgreSQL pool used by get_db_connection(); None if creation failed.
pg_pool = _create_pg_pool()
# SQLite database connection
def get_sqlite_conn():
    """Open and return a new SQLite connection configured by ``SQLITE_CONFIG``.

    Reads ``SQLITE_CONFIG['database']`` (the database path passed to
    ``sqlite3.connect``) and ``SQLITE_CONFIG['timeout']`` (how long sqlite3
    waits on a locked database before raising). Every call opens a fresh
    connection; the caller is responsible for closing it.

    Raises:
        Exception: any error from the config lookup or ``sqlite3.connect`` is
            logged and re-raised unchanged.
    """
    try:
        db_path = SQLITE_CONFIG['database']
        busy_timeout = SQLITE_CONFIG['timeout']
        connection = sqlite3.connect(db_path, timeout=busy_timeout)
    except Exception as err:
        logger.error(f"SQLite数据库连接失败: {err}")
        raise
    else:
        logger.info("SQLite数据库连接成功")
        return connection
@contextmanager
def get_db_connection(db_type='postgres'):
    """Yield a database connection and release it when the ``with`` block exits.

    Args:
        db_type: ``'postgres'`` borrows a connection from the module-level
            ``pg_pool``; any other value opens a new SQLite connection via
            ``get_sqlite_conn()``.

    Yields:
        The open connection object.

    Raises:
        RuntimeError: if ``db_type == 'postgres'`` and ``pg_pool`` is ``None``
            (pool creation failed at import time). Still an ``Exception``
            subclass, so existing ``except Exception`` callers are unaffected.
        Exception: any error while acquiring the connection is logged as a
            connection failure and re-raised. Exceptions raised inside the
            caller's ``with`` body propagate unchanged, without being
            mislabelled as connection failures, after the connection is
            released.
    """
    use_postgres = db_type == 'postgres'

    # Acquisition phase: only failures here are genuine connection failures.
    try:
        if use_postgres:
            if pg_pool is None:
                raise RuntimeError("PostgreSQL连接池未初始化")
            conn = pg_pool.getconn()
            logger.info("PostgreSQL数据库连接成功")
        else:
            # get_sqlite_conn() already logs its own success message, so no
            # second "connected" log line is emitted here.
            conn = get_sqlite_conn()
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        raise

    try:
        yield conn
    finally:
        # Release through the same branch that acquired the connection.
        # Previously a db_type other than 'postgres'/'sqlite' opened a SQLite
        # connection that was never closed.
        if use_postgres and pg_pool is not None:
            pg_pool.putconn(conn)
        else:
            conn.close()
@contextmanager
def get_db_cursor(cursor_factory=None, db_type='postgres'):
    """Yield a cursor inside a transaction: commit on success, roll back on error.

    Args:
        cursor_factory: psycopg2 cursor class (e.g. from ``psycopg2.extras``),
            used only for PostgreSQL; ignored for SQLite.
        db_type: ``'postgres'`` for the PostgreSQL pool; any other value is
            SQLite, matching how ``get_db_connection`` dispatches.

    Yields:
        An open cursor, closed when the ``with`` block exits.

    Raises:
        Exception: anything raised by cursor creation, the caller's body, or
            ``commit()`` is logged and re-raised after a rollback attempt. If
            the rollback itself fails, that failure is logged and the
            *original* exception is still the one propagated.
    """
    with get_db_connection(db_type) as conn:
        cursor = None
        try:
            # Branch on 'postgres' exactly like get_db_connection does:
            # sqlite3 cursors do not accept a ``cursor_factory`` keyword, so
            # every non-postgres db_type must take the plain cursor() path.
            if db_type == 'postgres':
                cursor = conn.cursor(cursor_factory=cursor_factory)
            else:
                cursor = conn.cursor()
            yield cursor
            conn.commit()
        except Exception as e:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # A dead connection makes rollback() raise too; don't let that
                # replace the error the caller actually needs to see.
                logger.error(f"回滚失败: {rollback_error}")
            logger.error(f"数据库操作失败: {e}")
            raise
        finally:
            if cursor is not None:
                cursor.close()
def execute_query(query, params=None, cursor_factory=None, fetch_one=False, db_type='postgres'):
    """Execute ``query`` and return the fetched row(s).

    The statement runs inside ``get_db_cursor``, which commits on success and
    rolls back on error.

    Args:
        query: SQL text in the driver's placeholder style (``%s`` for
            psycopg2, ``?`` for sqlite3).
        params: Bind parameters; ``None`` or any falsy value means "no
            parameters" and an empty tuple is sent instead.
        cursor_factory: psycopg2 cursor class; ignored for SQLite.
        fetch_one: When true, return only the first row (``None`` if empty).
        db_type: ``'postgres'`` or ``'sqlite'``.

    Returns:
        One row if ``fetch_one`` is set, otherwise all rows from ``fetchall()``.
    """
    with get_db_cursor(cursor_factory=cursor_factory, db_type=db_type) as cur:
        bind_params = params if params else ()
        cur.execute(query, bind_params)
        result = cur.fetchone() if fetch_one else cur.fetchall()
    return result
def execute_update(query, params=None, db_type='postgres'):
    """Execute a data-modifying statement and return the affected row count.

    The statement runs inside ``get_db_cursor``, which commits on success and
    rolls back on error.

    Args:
        query: SQL text in the driver's placeholder style.
        params: Bind parameters; ``None`` or any falsy value sends an empty tuple.
        db_type: ``'postgres'`` or ``'sqlite'``.

    Returns:
        ``cursor.rowcount`` as reported by the driver for this statement.
    """
    with get_db_cursor(db_type=db_type) as cur:
        cur.execute(query, params if params else ())
        affected_rows = cur.rowcount
    return affected_rows
# CREATE TABLE statements for the SQLite schema, executed in this order by
# init_sqlite_db(). Order matters for readability of the FOREIGN KEY targets:
# api_endpoints precedes api_logs, and api_users precedes api_operations.
_SQLITE_SCHEMA_DDL = (
    # API endpoint table
    '''
CREATE TABLE IF NOT EXISTS api_endpoints (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL UNIQUE,
method TEXT NOT NULL,
description TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''',
    # API request log table
    '''
CREATE TABLE IF NOT EXISTS api_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
endpoint_id INTEGER,
request_data TEXT,
response_data TEXT,
status_code INTEGER,
duration REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(endpoint_id) REFERENCES api_endpoints(id)
)
''',
    # API user table. NOTE(review): updated_at only receives its default on
    # INSERT; nothing in this schema refreshes it on UPDATE.
    '''
CREATE TABLE IF NOT EXISTS api_users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
email TEXT,
is_active BOOLEAN DEFAULT 1,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''',
    # API operation audit table
    '''
CREATE TABLE IF NOT EXISTS api_operations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
operation_type TEXT NOT NULL,
endpoint TEXT NOT NULL,
parameters TEXT,
status TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(user_id) REFERENCES api_users(id)
)
''',
)


def init_sqlite_db():
    """Create the SQLite tables if they do not already exist.

    Runs every statement in ``_SQLITE_SCHEMA_DDL`` on one SQLite connection
    and commits once at the end. Because each statement uses
    ``CREATE TABLE IF NOT EXISTS``, calling this repeatedly leaves existing
    tables untouched.

    NOTE(review): SQLite only enforces the FOREIGN KEY clauses when
    ``PRAGMA foreign_keys = ON`` is set on the connection, and
    ``get_sqlite_conn()`` does not set it. Confirm whether enforcement is
    intended.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        with get_db_connection('sqlite') as conn:
            cur = conn.cursor()
            for ddl in _SQLITE_SCHEMA_DDL:
                cur.execute(ddl)
            conn.commit()
        logger.info("SQLite数据库表初始化成功")
    except Exception as e:
        logger.error(f"SQLite数据库表初始化失败: {e}")
        raise