Files
m2pool_payment/internal/db/mysql.go
2025-11-18 11:10:16 +08:00

219 lines
5.2 KiB
Go

package db
import (
"bufio"
"database/sql"
"fmt"
"io/ioutil"
"log"
message "m2pool-payment/internal/msg"
"strings"
_ "github.com/go-sql-driver/mysql"
)
type MySQLPool struct {
db *sql.DB
}
// NewMySQLPool 初始化连接池
func NewMySQLPool(cfg message.MysqlConfig) (*MySQLPool, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
}
// 设置连接池参数
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
db.SetConnMaxLifetime(cfg.ConnMaxLife)
// 测试连接
if err := db.Ping(); err != nil {
return nil, err
}
return &MySQLPool{db: db}, nil
}
// splitSQLStatements 将 SQL 内容按分号分割
func splitSQLStatements(sqlContent string) []string {
// 使用 bufio 扫描文件内容
var statements []string
scanner := bufio.NewScanner(strings.NewReader(sqlContent))
var queryBuilder strings.Builder
for scanner.Scan() {
line := scanner.Text()
// 处理每一行的 SQL 语句
if strings.TrimSpace(line) == "" {
continue
}
// 如果行中包含分号,说明是完整的 SQL 语句
queryBuilder.WriteString(line)
if strings.HasSuffix(line, ";") {
statements = append(statements, queryBuilder.String())
queryBuilder.Reset() // 清空构建器以便准备下一条 SQL 语句
} else {
queryBuilder.WriteString("\n")
}
}
// 处理可能的扫描错误
if err := scanner.Err(); err != nil {
log.Fatalf("error reading .sql file: %v\n", err)
}
return statements
}
func (p *MySQLPool) ExecuteSQLFile(filePath string) error {
// 读取 SQL 文件内容
sqlContent, err := ioutil.ReadFile(filePath)
if err != nil {
return fmt.Errorf("failed to read SQL file: %v", err)
}
// 将文件内容按分号 (;) 分割成多条 SQL 语句
queries := splitSQLStatements(string(sqlContent))
// 执行每一条 SQL 语句
for _, query := range queries {
// 跳过空行或注释
if strings.TrimSpace(query) == "" || strings.HasPrefix(strings.TrimSpace(query), "--") {
continue
}
// 执行 SQL 语句
_, err := p.db.Exec(query)
if err != nil {
log.Printf("error executing query: %v\n", err)
} else {
// fmt.Println("Executed query:", query)
}
}
return nil
}
// Exec 执行 INSERT/UPDATE/DELETE
func (p *MySQLPool) Exec(query string, args ...any) (sql.Result, error) {
return p.db.Exec(query, args...)
}
// Query 查询多行
func (p *MySQLPool) Query(query string, args ...any) (*sql.Rows, error) {
return p.db.Query(query, args...)
}
// QueryRow 查询单行
func (p *MySQLPool) QueryRow(query string, args ...any) *sql.Row {
return p.db.QueryRow(query, args...)
}
// Transaction 执行事务
func (p *MySQLPool) Transaction(fn func(tx *sql.Tx) error) error {
tx, err := p.db.Begin()
if err != nil {
return err
}
if err := fn(tx); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
// Insert 执行通用插入操作
func (p *MySQLPool) Insert(query string, values [][]any) (sql.Result, error) {
// 预处理查询
stmt, err := p.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("failed to prepare statement: %v", err)
}
defer stmt.Close()
// 执行批量插入
var result sql.Result
for _, row := range values {
result, err = stmt.Exec(row...)
if err != nil {
return nil, fmt.Errorf("failed to execute insert: %v", err)
}
}
return result, nil
}
// Delete 执行通用删除操作
// Update 执行通用更新操作
func (p *MySQLPool) Update(query string, values []any) (sql.Result, error) {
// 预处理查询
stmt, err := p.db.Prepare(query)
if err != nil {
return nil, fmt.Errorf("failed to prepare statement: %v", err)
}
defer stmt.Close()
// 执行更新操作
result, err := stmt.Exec(values...)
if err != nil {
return nil, fmt.Errorf("failed to execute update: %v", err)
}
return result, nil
}
// ExecuteTransactions 执行多条增删改操作,确保事务的原子性
func (p *MySQLPool) ExecuteTransactions(str_sqls []string, params [][]any) error {
// 检查 SQL 和参数的数量是否匹配
if len(str_sqls) != len(params) {
return fmt.Errorf("sql length != params length")
}
// 开始事务
tx, err := p.db.Begin()
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
// 确保在函数结束时提交或回滚事务
defer func() {
if err != nil {
// 发生错误时回滚事务
if rollbackErr := tx.Rollback(); rollbackErr != nil {
err = fmt.Errorf("failed to rollback transaction: %v", rollbackErr)
}
} else {
// 如果没有错误,提交事务
if commitErr := tx.Commit(); commitErr != nil {
err = fmt.Errorf("failed to commit transaction: %v", commitErr)
}
}
}()
// 执行每个 SQL 语句
for i, sql_str := range str_sqls {
// 使用事务对象 tx 来执行 SQL
_, err := tx.Exec(sql_str, params[i]...)
if err != nil {
// 如果执行失败,立即返回并且触发回滚
return fmt.Errorf("failed to execute SQL at index %d: %v", i, err)
}
}
// 如果所有 SQL 执行成功,则返回 nil
return nil
}
// Close 关闭连接池
func (p *MySQLPool) Close() error {
return p.db.Close()
}