120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
import asyncio

import aiomysql
|
||
|
||
class DBPool:
    """Async MySQL helper around a lazily created ``aiomysql`` pool.

    The methods mirror a JS ``mysql`` wrapper: a statement that produces a
    result set returns its rows, any other statement returns an
    OkPacket-like ``{"affectedRows": ..., "insertId": ...}`` dict.

    Placeholders use pymysql syntax (``%s`` / ``%(name)s``). Every helper
    passes a parameter container to ``cursor.execute`` (an empty list when
    none is given), so the SQL is always %-formatted: write a literal ``%``
    as ``%%``.
    """

    def __init__(self, coin, options):
        """Store configuration; no connection is opened until first use.

        Args:
            coin: Caller-supplied identifier, stored as ``self.coin`` (this
                class does not read it).
            options: Keyword arguments passed verbatim to
                ``aiomysql.create_pool``.
        """
        self.coin = coin
        self.pool = None
        self.options = options
        # Serialises the first pool creation; built lazily inside the
        # running event loop (see _get_pool).
        self._pool_lock = None

    async def _get_pool(self):
        """Return the shared pool, creating it on the first call.

        Without the lock, concurrent first callers would each
        ``await create_pool`` and overwrite ``self.pool``, leaking every pool
        but the last one (and any connections it had opened).
        """
        if self.pool is None:
            if getattr(self, "_pool_lock", None) is None:
                # No await between this check and the assignment, so only one
                # lock is ever created per instance.
                self._pool_lock = asyncio.Lock()
            async with self._pool_lock:
                # Re-check: another coroutine may have created the pool while
                # this one was waiting for the lock.
                if self.pool is None:
                    self.pool = await aiomysql.create_pool(**self.options)
        return self.pool

    def _is_select(self, sql: str) -> bool:
        """Return True when *sql*, ignoring leading whitespace, starts with SELECT.

        Kept so existing callers of this helper keep working. The execution
        methods use the cursor's result metadata instead (``_build_result``),
        which also covers SHOW, ``WITH ... SELECT``, parenthesised queries and
        leading comments.
        """
        return sql.lstrip().lower().startswith("select")

    @staticmethod
    async def _build_result(cur):
        """Return the JS-style result of the statement just executed on *cur*.

        Per DB-API 2.0, ``cur.description`` is None for operations that do not
        return rows. When it is set the rows are fetched; otherwise the
        affected-row count and last insert id are reported.
        """
        if cur.description is not None:
            return await cur.fetchall()
        return {
            "affectedRows": cur.rowcount,
            "insertId": cur.lastrowid,
        }

    async def exec(self, sql, values=None):
        """Run a single statement outside an explicit transaction.

        Args:
            sql: SQL text with pymysql placeholders.
            values: Placeholder values; ``None`` means an empty list.

        Returns:
            Rows (dicts, via ``DictCursor``) if the statement returned a result
            set, otherwise ``{"affectedRows": ..., "insertId": ...}``.
        """
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, values)
                result = await self._build_result(cur)
                # With autocommit off (aiomysql's default unless ``options``
                # enables it) the statement opened an implicit transaction.
                # Commit it so writes persist and the connection returns to
                # the pool idle instead of being discarded.
                if conn.get_transaction_status():
                    await conn.commit()
                return result

    async def exec_transaction(self, sql, values=None):
        """Run a single statement inside BEGIN/COMMIT, rolling back on error.

        Args:
            sql: SQL text with pymysql placeholders.
            values: Placeholder values; ``None`` means an empty list.

        Returns:
            Same shape as ``exec``.

        Raises:
            Any exception from the driver, after the transaction is rolled back.
        """
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await conn.begin()
                try:
                    await cur.execute(sql, values)
                    result = await self._build_result(cur)
                    await conn.commit()
                except Exception:
                    await conn.rollback()
                    raise
                return result

    async def exec_transaction_together(self, params):
        """Run several statements atomically on one connection.

        Args:
            params: Iterable of dicts, each with key ``"sql"`` and an optional
                ``"param"`` entry holding its placeholder values (default
                ``[]``).

        Returns:
            None; the individual statements' results are discarded.

        Raises:
            Any exception from the driver, after the transaction is rolled back.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await conn.begin()
                try:
                    for item in params:
                        await cur.execute(item["sql"], item.get("param", []))
                    await conn.commit()
                except Exception:
                    await conn.rollback()
                    raise

    async def exec_write_lock(self, sql, values=None, table_name=None):
        """Run one statement while holding ``LOCK TABLES <table_name> WRITE``.

        Args:
            sql: SQL text with pymysql placeholders.
            values: Placeholder values; ``None`` means an empty list.
            table_name: Table to write-lock; required.

        Returns:
            Same shape as ``exec``.

        Raises:
            ValueError: if ``table_name`` is None.
            Any exception from the driver, after rollback; the tables are
            unlocked in every case.
        """
        if table_name is None:
            raise ValueError("exec_write_lock requires a table_name")
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await conn.begin()
                try:
                    # LOCK TABLES implicitly commits the transaction begun
                    # above, so the statement runs in the connection's own
                    # transaction mode; commit()/rollback() below only cover
                    # it when autocommit is off (aiomysql's default; confirm
                    # ``options`` does not enable it).
                    # NOTE(review): table_name is interpolated verbatim
                    # (identifiers cannot be bound as parameters) and must
                    # never come from untrusted input.
                    await cur.execute(f"LOCK TABLES {table_name} WRITE")
                    await cur.execute(sql, values)
                    result = await self._build_result(cur)
                    await conn.commit()
                except Exception:
                    await conn.rollback()
                    raise
                finally:
                    # Always release the locks (as the JS version does), after
                    # commit/rollback. Reuse ``cur``: ``conn.cursor()`` returns
                    # an awaitable context manager, which has no ``execute``.
                    await cur.execute("UNLOCK TABLES")
                return result


__all__ = ["DBPool"]