219 lines
5.2 KiB
Go
219 lines
5.2 KiB
Go
package db
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
message "m2pool-payment/internal/msg"
|
|
"strings"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
type MySQLPool struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewMySQLPool 初始化连接池
|
|
func NewMySQLPool(cfg message.MysqlConfig) (*MySQLPool, error) {
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 设置连接池参数
|
|
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
db.SetConnMaxLifetime(cfg.ConnMaxLife)
|
|
|
|
// 测试连接
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &MySQLPool{db: db}, nil
|
|
}
|
|
|
|
// splitSQLStatements 将 SQL 内容按分号分割
|
|
func splitSQLStatements(sqlContent string) []string {
|
|
// 使用 bufio 扫描文件内容
|
|
var statements []string
|
|
scanner := bufio.NewScanner(strings.NewReader(sqlContent))
|
|
var queryBuilder strings.Builder
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
// 处理每一行的 SQL 语句
|
|
if strings.TrimSpace(line) == "" {
|
|
continue
|
|
}
|
|
|
|
// 如果行中包含分号,说明是完整的 SQL 语句
|
|
queryBuilder.WriteString(line)
|
|
if strings.HasSuffix(line, ";") {
|
|
statements = append(statements, queryBuilder.String())
|
|
queryBuilder.Reset() // 清空构建器以便准备下一条 SQL 语句
|
|
} else {
|
|
queryBuilder.WriteString("\n")
|
|
}
|
|
}
|
|
|
|
// 处理可能的扫描错误
|
|
if err := scanner.Err(); err != nil {
|
|
log.Fatalf("error reading .sql file: %v\n", err)
|
|
}
|
|
|
|
return statements
|
|
}
|
|
|
|
func (p *MySQLPool) ExecuteSQLFile(filePath string) error {
|
|
// 读取 SQL 文件内容
|
|
sqlContent, err := ioutil.ReadFile(filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read SQL file: %v", err)
|
|
}
|
|
// 将文件内容按分号 (;) 分割成多条 SQL 语句
|
|
queries := splitSQLStatements(string(sqlContent))
|
|
|
|
// 执行每一条 SQL 语句
|
|
for _, query := range queries {
|
|
// 跳过空行或注释
|
|
if strings.TrimSpace(query) == "" || strings.HasPrefix(strings.TrimSpace(query), "--") {
|
|
continue
|
|
}
|
|
|
|
// 执行 SQL 语句
|
|
_, err := p.db.Exec(query)
|
|
if err != nil {
|
|
log.Printf("error executing query: %v\n", err)
|
|
} else {
|
|
// fmt.Println("Executed query:", query)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Exec 执行 INSERT/UPDATE/DELETE
|
|
func (p *MySQLPool) Exec(query string, args ...any) (sql.Result, error) {
|
|
return p.db.Exec(query, args...)
|
|
}
|
|
|
|
// Query 查询多行
|
|
func (p *MySQLPool) Query(query string, args ...any) (*sql.Rows, error) {
|
|
return p.db.Query(query, args...)
|
|
}
|
|
|
|
// QueryRow 查询单行
|
|
func (p *MySQLPool) QueryRow(query string, args ...any) *sql.Row {
|
|
return p.db.QueryRow(query, args...)
|
|
}
|
|
|
|
// Transaction 执行事务
|
|
func (p *MySQLPool) Transaction(fn func(tx *sql.Tx) error) error {
|
|
tx, err := p.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := fn(tx); err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Insert 执行通用插入操作
|
|
func (p *MySQLPool) Insert(query string, values [][]any) (sql.Result, error) {
|
|
// 预处理查询
|
|
stmt, err := p.db.Prepare(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to prepare statement: %v", err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
// 执行批量插入
|
|
var result sql.Result
|
|
for _, row := range values {
|
|
result, err = stmt.Exec(row...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute insert: %v", err)
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Delete 执行通用删除操作
|
|
|
|
// Update 执行通用更新操作
|
|
func (p *MySQLPool) Update(query string, values []any) (sql.Result, error) {
|
|
// 预处理查询
|
|
stmt, err := p.db.Prepare(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to prepare statement: %v", err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
// 执行更新操作
|
|
result, err := stmt.Exec(values...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to execute update: %v", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ExecuteTransactions 执行多条增删改操作,确保事务的原子性
|
|
func (p *MySQLPool) ExecuteTransactions(str_sqls []string, params [][]any) error {
|
|
// 检查 SQL 和参数的数量是否匹配
|
|
if len(str_sqls) != len(params) {
|
|
return fmt.Errorf("sql length != params length")
|
|
}
|
|
|
|
// 开始事务
|
|
tx, err := p.db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to begin transaction: %v", err)
|
|
}
|
|
|
|
// 确保在函数结束时提交或回滚事务
|
|
defer func() {
|
|
if err != nil {
|
|
// 发生错误时回滚事务
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
err = fmt.Errorf("failed to rollback transaction: %v", rollbackErr)
|
|
}
|
|
} else {
|
|
// 如果没有错误,提交事务
|
|
if commitErr := tx.Commit(); commitErr != nil {
|
|
err = fmt.Errorf("failed to commit transaction: %v", commitErr)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// 执行每个 SQL 语句
|
|
for i, sql_str := range str_sqls {
|
|
// 使用事务对象 tx 来执行 SQL
|
|
_, err := tx.Exec(sql_str, params[i]...)
|
|
if err != nil {
|
|
// 如果执行失败,立即返回并且触发回滚
|
|
return fmt.Errorf("failed to execute SQL at index %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
// 如果所有 SQL 执行成功,则返回 nil
|
|
return nil
|
|
}
|
|
|
|
// Close 关闭连接池
|
|
func (p *MySQLPool) Close() error {
|
|
return p.db.Close()
|
|
}
|