120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
|
|
import aiomysql
|
|||
|
|
|
|||
|
|
class DBPool:
    """Lazily created aiomysql connection pool with query/transaction helpers.

    Result convention used by ``exec``, ``exec_transaction`` and
    ``exec_write_lock`` (the original comments describe it as matching the
    JS client this code mirrors):

    * statements recognised by ``_is_select`` return the fetched rows;
    * any other statement returns
      ``{"affectedRows": <cursor.rowcount>, "insertId": <cursor.lastrowid>}``.
    """

    def __init__(self, coin, options):
        """Store the configuration; no connection is opened here.

        Args:
            coin: Identifier kept on the instance as ``self.coin``. It is not
                read by any of the methods below.
            options: Mapping of keyword arguments forwarded verbatim to
                ``aiomysql.create_pool`` the first time a pool is needed.
        """
        self.coin = coin
        self.pool = None
        self.options = options

    async def _get_pool(self):
        """Return the shared pool, creating it on first use."""
        if self.pool is None:
            created = await aiomysql.create_pool(**self.options)
            if self.pool is None:
                self.pool = created
            else:
                # Another coroutine finished creating the pool while this one
                # was suspended inside ``create_pool``. Keep the pool that is
                # already published and dispose of ours so that its
                # connections are not leaked.
                created.close()
                await created.wait_closed()
        return self.pool

    # -------------------------
    # Utility: detect SELECT statements
    # -------------------------
    def _is_select(self, sql: str) -> bool:
        """Return True when ``sql`` starts with ``select``.

        Leading whitespace and letter case are ignored.
        NOTE(review): other row-returning statements (``SHOW``, ``WITH``,
        a parenthesised ``SELECT``) are reported as non-SELECT by this check.
        """
        return sql.lstrip().lower().startswith("select")

    async def _collect_result(self, cur, sql):
        """Build the return value for ``sql`` after ``cur`` has executed it."""
        if self._is_select(sql):
            # SELECT: hand back the rows.
            return await cur.fetchall()
        # Non-SELECT: hand back an OkPacket-like summary.
        return {
            "affectedRows": cur.rowcount,
            "insertId": cur.lastrowid,
        }

    # -------------------------
    # Non-transactional SQL
    # -------------------------
    async def exec(self, sql, values=None):
        """Execute one statement without explicit transaction handling.

        Args:
            sql: Statement text passed to ``cursor.execute``.
            values: Statement parameters; defaults to an empty list.

        Returns:
            Rows for a SELECT, otherwise the affectedRows/insertId dict.

        NOTE(review): no commit is issued here, so writes persist only if the
        connection runs with autocommit enabled (e.g. through ``options``) --
        confirm against the pool configuration.
        """
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql, values)
                return await self._collect_result(cur, sql)

    # -------------------------
    # Single-statement transaction
    # -------------------------
    async def exec_transaction(self, sql, values=None):
        """Execute one statement inside its own transaction.

        The transaction is committed on success. On any ``Exception`` it is
        rolled back and the exception is re-raised.

        Args:
            sql: Statement text passed to ``cursor.execute``.
            values: Statement parameters; defaults to an empty list.

        Returns:
            Rows for a SELECT, otherwise the affectedRows/insertId dict.
        """
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await conn.begin()
                try:
                    await cur.execute(sql, values)
                    result = await self._collect_result(cur, sql)
                    await conn.commit()
                    return result
                except Exception:
                    await conn.rollback()
                    raise

    # -------------------------
    # Several statements in one transaction
    # -------------------------
    async def exec_transaction_together(self, params):
        """Execute several statements atomically in a single transaction.

        Args:
            params: Iterable of dicts, each with a required ``"sql"`` key and
                an optional ``"param"`` key (defaults to an empty list).

        Returns:
            None. Statement results are not collected.

        Raises:
            Any exception raised while executing or committing, after the
            transaction has been rolled back.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await conn.begin()
                try:
                    for item in params:
                        await cur.execute(item["sql"], item.get("param", []))
                    await conn.commit()
                except Exception:
                    await conn.rollback()
                    raise

    # -------------------------
    # Write-lock transaction
    # -------------------------
    async def exec_write_lock(self, sql, values=None, table_name=None):
        """Execute one statement while holding a WRITE lock on a table.

        Args:
            sql: Statement text passed to ``cursor.execute``.
            values: Statement parameters; defaults to an empty list.
            table_name: Table to lock. It is interpolated into the
                ``LOCK TABLES`` statement as-is (identifiers cannot be bound
                as parameters), so it must come from trusted code and never
                from user input.

        Returns:
            Rows for a SELECT, otherwise the affectedRows/insertId dict.

        Raises:
            ValueError: If ``table_name`` is missing or empty.
            Exception: Whatever the statement or commit raised, after a
                rollback. ``UNLOCK TABLES`` is always attempted.
        """
        if not table_name:
            # Without this guard the statement would read
            # "LOCK TABLES None WRITE".
            raise ValueError("table_name is required for exec_write_lock")
        if values is None:
            values = []

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await conn.begin()
                try:
                    # NOTE(review): per the MySQL manual LOCK TABLES
                    # implicitly commits the active transaction, so the
                    # begin() above does not span the statements below --
                    # confirm this ordering is intended.
                    await cur.execute(f"LOCK TABLES {table_name} WRITE")
                    await cur.execute(sql, values)
                    result = await self._collect_result(cur, sql)
                    await conn.commit()
                    return result
                except Exception:
                    await conn.rollback()
                    raise
                finally:
                    # Always unlock. Reuse the cursor that is already open:
                    # ``conn.cursor()`` returns an awaitable/context-manager
                    # wrapper rather than a cursor, so calling ``.execute``
                    # on it directly fails before UNLOCK TABLES is sent.
                    await cur.execute("UNLOCK TABLES")
|
|||
|
|
|
|||
|
|
# Names exported by `from <module> import *`: only the DBPool class.
__all__ = ["DBPool"]
|