add log-system, bug fixed
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
message "m2pool-payment/internal/msg"
|
||||
"strings"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
@@ -34,6 +38,67 @@ func NewMySQLPool(cfg message.MysqlConfig) (*MySQLPool, error) {
|
||||
return &MySQLPool{db: db}, nil
|
||||
}
|
||||
|
||||
// splitSQLStatements 将 SQL 内容按分号分割
|
||||
func splitSQLStatements(sqlContent string) []string {
|
||||
// 使用 bufio 扫描文件内容
|
||||
var statements []string
|
||||
scanner := bufio.NewScanner(strings.NewReader(sqlContent))
|
||||
var queryBuilder strings.Builder
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 处理每一行的 SQL 语句
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果行中包含分号,说明是完整的 SQL 语句
|
||||
queryBuilder.WriteString(line)
|
||||
if strings.HasSuffix(line, ";") {
|
||||
statements = append(statements, queryBuilder.String())
|
||||
queryBuilder.Reset() // 清空构建器以便准备下一条 SQL 语句
|
||||
} else {
|
||||
queryBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// 处理可能的扫描错误
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Fatalf("error reading .sql file: %v\n", err)
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
func (p *MySQLPool) ExecuteSQLFile(filePath string) error {
|
||||
// 读取 SQL 文件内容
|
||||
sqlContent, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read SQL file: %v", err)
|
||||
}
|
||||
// 将文件内容按分号 (;) 分割成多条 SQL 语句
|
||||
queries := splitSQLStatements(string(sqlContent))
|
||||
|
||||
// 执行每一条 SQL 语句
|
||||
for _, query := range queries {
|
||||
// 跳过空行或注释
|
||||
if strings.TrimSpace(query) == "" || strings.HasPrefix(strings.TrimSpace(query), "--") {
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行 SQL 语句
|
||||
_, err := p.db.Exec(query)
|
||||
if err != nil {
|
||||
log.Printf("error executing query: %v\n", err)
|
||||
} else {
|
||||
// fmt.Println("Executed query:", query)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec 执行 INSERT/UPDATE/DELETE
|
||||
func (p *MySQLPool) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return p.db.Exec(query, args...)
|
||||
@@ -64,6 +129,89 @@ func (p *MySQLPool) Transaction(fn func(tx *sql.Tx) error) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// Insert 执行通用插入操作
|
||||
func (p *MySQLPool) Insert(query string, values [][]any) (sql.Result, error) {
|
||||
// 预处理查询
|
||||
stmt, err := p.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare statement: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
// 执行批量插入
|
||||
var result sql.Result
|
||||
for _, row := range values {
|
||||
result, err = stmt.Exec(row...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute insert: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Delete 执行通用删除操作
|
||||
|
||||
// Update 执行通用更新操作
|
||||
func (p *MySQLPool) Update(query string, values []any) (sql.Result, error) {
|
||||
// 预处理查询
|
||||
stmt, err := p.db.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to prepare statement: %v", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
// 执行更新操作
|
||||
result, err := stmt.Exec(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute update: %v", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteTransactions 执行多条增删改操作,确保事务的原子性
|
||||
func (p *MySQLPool) ExecuteTransactions(str_sqls []string, params [][]any) error {
|
||||
// 检查 SQL 和参数的数量是否匹配
|
||||
if len(str_sqls) != len(params) {
|
||||
return fmt.Errorf("sql length != params length")
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx, err := p.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %v", err)
|
||||
}
|
||||
|
||||
// 确保在函数结束时提交或回滚事务
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// 发生错误时回滚事务
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
err = fmt.Errorf("failed to rollback transaction: %v", rollbackErr)
|
||||
}
|
||||
} else {
|
||||
// 如果没有错误,提交事务
|
||||
if commitErr := tx.Commit(); commitErr != nil {
|
||||
err = fmt.Errorf("failed to commit transaction: %v", commitErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 执行每个 SQL 语句
|
||||
for i, sql_str := range str_sqls {
|
||||
// 使用事务对象 tx 来执行 SQL
|
||||
_, err := tx.Exec(sql_str, params[i]...)
|
||||
if err != nil {
|
||||
// 如果执行失败,立即返回并且触发回滚
|
||||
return fmt.Errorf("failed to execute SQL at index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有 SQL 执行成功,则返回 nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭连接池
|
||||
func (p *MySQLPool) Close() error {
|
||||
return p.db.Close()
|
||||
|
||||
Reference in New Issue
Block a user