package db import ( "bufio" "database/sql" "fmt" "io/ioutil" "log" message "m2pool-payment/internal/msg" "strings" _ "github.com/go-sql-driver/mysql" ) type MySQLPool struct { db *sql.DB } // NewMySQLPool 初始化连接池 func NewMySQLPool(cfg message.MysqlConfig) (*MySQLPool, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) db, err := sql.Open("mysql", dsn) if err != nil { return nil, err } // 设置连接池参数 db.SetMaxOpenConns(cfg.MaxOpenConns) db.SetMaxIdleConns(cfg.MaxIdleConns) db.SetConnMaxLifetime(cfg.ConnMaxLife) // 测试连接 if err := db.Ping(); err != nil { return nil, err } return &MySQLPool{db: db}, nil } // splitSQLStatements 将 SQL 内容按分号分割 func splitSQLStatements(sqlContent string) []string { // 使用 bufio 扫描文件内容 var statements []string scanner := bufio.NewScanner(strings.NewReader(sqlContent)) var queryBuilder strings.Builder for scanner.Scan() { line := scanner.Text() // 处理每一行的 SQL 语句 if strings.TrimSpace(line) == "" { continue } // 如果行中包含分号,说明是完整的 SQL 语句 queryBuilder.WriteString(line) if strings.HasSuffix(line, ";") { statements = append(statements, queryBuilder.String()) queryBuilder.Reset() // 清空构建器以便准备下一条 SQL 语句 } else { queryBuilder.WriteString("\n") } } // 处理可能的扫描错误 if err := scanner.Err(); err != nil { log.Fatalf("error reading .sql file: %v\n", err) } return statements } func (p *MySQLPool) ExecuteSQLFile(filePath string) error { // 读取 SQL 文件内容 sqlContent, err := ioutil.ReadFile(filePath) if err != nil { return fmt.Errorf("failed to read SQL file: %v", err) } // 将文件内容按分号 (;) 分割成多条 SQL 语句 queries := splitSQLStatements(string(sqlContent)) // 执行每一条 SQL 语句 for _, query := range queries { // 跳过空行或注释 if strings.TrimSpace(query) == "" || strings.HasPrefix(strings.TrimSpace(query), "--") { continue } // 执行 SQL 语句 _, err := p.db.Exec(query) if err != nil { log.Printf("error executing query: %v\n", err) } else { // fmt.Println("Executed query:", query) } } return nil } // Exec 执行 INSERT/UPDATE/DELETE func (p *MySQLPool) Exec(query string, args ...any) (sql.Result, error) { return p.db.Exec(query, args...) } // Query 查询多行 func (p *MySQLPool) Query(query string, args ...any) (*sql.Rows, error) { return p.db.Query(query, args...) } // QueryRow 查询单行 func (p *MySQLPool) QueryRow(query string, args ...any) *sql.Row { return p.db.QueryRow(query, args...) } // Transaction 执行事务 func (p *MySQLPool) Transaction(fn func(tx *sql.Tx) error) error { tx, err := p.db.Begin() if err != nil { return err } if err := fn(tx); err != nil { _ = tx.Rollback() return err } return tx.Commit() } // Insert 执行通用插入操作 func (p *MySQLPool) Insert(query string, values [][]any) (sql.Result, error) { // 预处理查询 stmt, err := p.db.Prepare(query) if err != nil { return nil, fmt.Errorf("failed to prepare statement: %v", err) } defer stmt.Close() // 执行批量插入 var result sql.Result for _, row := range values { result, err = stmt.Exec(row...) if err != nil { return nil, fmt.Errorf("failed to execute insert: %v", err) } } return result, nil } // Delete 执行通用删除操作 // Update 执行通用更新操作 func (p *MySQLPool) Update(query string, values []any) (sql.Result, error) { // 预处理查询 stmt, err := p.db.Prepare(query) if err != nil { return nil, fmt.Errorf("failed to prepare statement: %v", err) } defer stmt.Close() // 执行更新操作 result, err := stmt.Exec(values...) if err != nil { return nil, fmt.Errorf("failed to execute update: %v", err) } return result, nil } // ExecuteTransactions 执行多条增删改操作,确保事务的原子性 func (p *MySQLPool) ExecuteTransactions(str_sqls []string, params [][]any) error { // 检查 SQL 和参数的数量是否匹配 if len(str_sqls) != len(params) { return fmt.Errorf("sql length != params length") } // 开始事务 tx, err := p.db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %v", err) } // 确保在函数结束时提交或回滚事务 defer func() { if err != nil { // 发生错误时回滚事务 if rollbackErr := tx.Rollback(); rollbackErr != nil { err = fmt.Errorf("failed to rollback transaction: %v", rollbackErr) } } else { // 如果没有错误,提交事务 if commitErr := tx.Commit(); commitErr != nil { err = fmt.Errorf("failed to commit transaction: %v", commitErr) } } }() // 执行每个 SQL 语句 for i, sql_str := range str_sqls { // 使用事务对象 tx 来执行 SQL _, err := tx.Exec(sql_str, params[i]...) if err != nil { // 如果执行失败,立即返回并且触发回滚 return fmt.Errorf("failed to execute SQL at index %d: %v", i, err) } } // 如果所有 SQL 执行成功,则返回 nil return nil } // Close 关闭连接池 func (p *MySQLPool) Close() error { return p.db.Close() }