256 lines
7.0 KiB
Go
256 lines
7.0 KiB
Go
package updater
|
||
|
||
import (
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
"strings"
|
||
)
|
||
|
||
// CheckAndUpdate 检查远程版本并更新
|
||
// remoteBaseURL: 远程服务器基础URL,例如 "http://example.com/update"
|
||
// currentVersion: 当前版本号
|
||
// 返回是否需要重启(如果更新了文件)
|
||
func CheckAndUpdate(remoteBaseURL string, currentVersion string) (bool, error) {
|
||
// 先检查是否有待应用的更新(Windows 下上次下载但未应用的新版本)
|
||
if runtime.GOOS == "windows" {
|
||
exePath, err := os.Executable()
|
||
if err == nil {
|
||
exeDir := filepath.Dir(exePath)
|
||
updateFlag := filepath.Join(exeDir, ".update_pending")
|
||
if data, err := os.ReadFile(updateFlag); err == nil {
|
||
newPath := strings.TrimSpace(string(data))
|
||
if err := applyPendingUpdate(newPath, exePath); err == nil {
|
||
os.Remove(updateFlag)
|
||
log.Println("已应用待处理的更新")
|
||
return true, nil
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取远程版本号
|
||
remoteVersion, err := fetchRemoteVersion(remoteBaseURL)
|
||
if err != nil {
|
||
log.Printf("获取远程版本失败:%v", err)
|
||
return false, err
|
||
}
|
||
|
||
log.Printf("当前版本:%s,远程版本:%s", currentVersion, remoteVersion)
|
||
|
||
// 比较版本
|
||
if strings.TrimSpace(remoteVersion) == strings.TrimSpace(currentVersion) {
|
||
log.Println("版本一致,无需更新")
|
||
return false, nil
|
||
}
|
||
|
||
log.Printf("检测到新版本:%s,开始下载更新...", remoteVersion)
|
||
|
||
// 下载新的可执行文件
|
||
remoteExeName := getRemoteExecutableName() // 远程文件名
|
||
localExeName := getLocalExecutableName() // 本地可执行文件名
|
||
remoteExeURL := fmt.Sprintf("%s/%s", remoteBaseURL, remoteExeName)
|
||
|
||
err = downloadAndReplace(remoteExeURL, localExeName)
|
||
if err != nil {
|
||
log.Printf("下载更新失败:%v", err)
|
||
return false, err
|
||
}
|
||
|
||
log.Printf("更新成功!新版本:%s", remoteVersion)
|
||
return true, nil
|
||
}
|
||
|
||
// fetchRemoteVersion 从远程获取版本号
|
||
func fetchRemoteVersion(remoteBaseURL string) (string, error) {
|
||
versionURL := fmt.Sprintf("%s/current_version", remoteBaseURL)
|
||
|
||
resp, err := http.Get(versionURL)
|
||
if err != nil {
|
||
return "", fmt.Errorf("请求远程版本文件失败:%v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return "", fmt.Errorf("获取远程版本文件失败,状态码:%d", resp.StatusCode)
|
||
}
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取远程版本文件失败:%v", err)
|
||
}
|
||
|
||
return strings.TrimSpace(string(body)), nil
|
||
}
|
||
|
||
// downloadAndReplace 下载并替换可执行文件
|
||
func downloadAndReplace(remoteURL string, exeName string) error {
|
||
// 获取当前可执行文件的路径
|
||
exePath, err := os.Executable()
|
||
if err != nil {
|
||
return fmt.Errorf("获取可执行文件路径失败:%v", err)
|
||
}
|
||
|
||
exeDir := filepath.Dir(exePath)
|
||
newExePath := filepath.Join(exeDir, exeName)
|
||
backupPath := exePath + ".backup"
|
||
|
||
// 1. 备份当前可执行文件
|
||
log.Printf("备份当前可执行文件到:%s", backupPath)
|
||
err = copyFile(exePath, backupPath)
|
||
if err != nil {
|
||
return fmt.Errorf("备份文件失败:%v", err)
|
||
}
|
||
|
||
// 2. 下载新的可执行文件
|
||
log.Printf("从 %s 下载新版本...", remoteURL)
|
||
resp, err := http.Get(remoteURL)
|
||
if err != nil {
|
||
// 如果下载失败,恢复备份
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("下载新版本失败:%v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("下载新版本失败,状态码:%d", resp.StatusCode)
|
||
}
|
||
|
||
// 3. 创建临时文件
|
||
tempPath := newExePath + ".tmp"
|
||
out, err := os.Create(tempPath)
|
||
if err != nil {
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("创建临时文件失败:%v", err)
|
||
}
|
||
defer out.Close()
|
||
|
||
// 4. 写入新文件
|
||
_, err = io.Copy(out, resp.Body)
|
||
if err != nil {
|
||
os.Remove(tempPath)
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("写入新文件失败:%v", err)
|
||
}
|
||
out.Close()
|
||
|
||
// 5. 设置可执行权限(Linux)
|
||
if runtime.GOOS != "windows" {
|
||
err = os.Chmod(tempPath, 0755)
|
||
if err != nil {
|
||
os.Remove(tempPath)
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("设置可执行权限失败:%v", err)
|
||
}
|
||
}
|
||
|
||
// 6. 尝试替换原文件
|
||
// Windows 下如果文件正在运行,无法直接替换,需要特殊处理
|
||
if runtime.GOOS == "windows" {
|
||
// Windows: 先尝试直接替换,如果失败则保存为 .new 文件,下次启动时替换
|
||
err = os.Rename(tempPath, exePath)
|
||
if err != nil {
|
||
// 如果替换失败(文件正在使用),保存为 .new 文件
|
||
newPath := exePath + ".new"
|
||
err = os.Rename(tempPath, newPath)
|
||
if err != nil {
|
||
os.Remove(tempPath)
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("保存新版本文件失败:%v", err)
|
||
}
|
||
log.Printf("当前程序正在运行,新版本已保存为:%s,程序重启后将自动应用更新", newPath)
|
||
// 创建更新标记文件,下次启动时检测并应用
|
||
updateFlag := filepath.Join(exeDir, ".update_pending")
|
||
os.WriteFile(updateFlag, []byte(newPath), 0644)
|
||
return nil
|
||
}
|
||
} else {
|
||
// Linux: 直接替换
|
||
err = os.Rename(tempPath, exePath)
|
||
if err != nil {
|
||
os.Remove(tempPath)
|
||
restoreBackup(backupPath, exePath)
|
||
return fmt.Errorf("替换文件失败:%v", err)
|
||
}
|
||
}
|
||
|
||
// 7. 清理备份文件(可选,更新成功后删除)
|
||
// os.Remove(backupPath)
|
||
|
||
log.Println("文件更新完成")
|
||
return nil
|
||
}
|
||
|
||
// copyFile 复制文件
|
||
func copyFile(src, dst string) error {
|
||
sourceFile, err := os.Open(src)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer sourceFile.Close()
|
||
|
||
destFile, err := os.Create(dst)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer destFile.Close()
|
||
|
||
_, err = io.Copy(destFile, sourceFile)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 复制文件权限
|
||
sourceInfo, err := os.Stat(src)
|
||
if err == nil {
|
||
os.Chmod(dst, sourceInfo.Mode())
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// restoreBackup 恢复备份文件
|
||
func restoreBackup(backupPath, exePath string) {
|
||
if _, err := os.Stat(backupPath); err == nil {
|
||
log.Printf("恢复备份文件:%s -> %s", backupPath, exePath)
|
||
os.Rename(backupPath, exePath)
|
||
}
|
||
}
|
||
|
||
// applyPendingUpdate 应用待处理的更新(Windows 下使用)
|
||
func applyPendingUpdate(newPath, exePath string) error {
|
||
if _, err := os.Stat(newPath); os.IsNotExist(err) {
|
||
return fmt.Errorf("新版本文件不存在:%s", newPath)
|
||
}
|
||
|
||
// 尝试替换
|
||
err := os.Rename(newPath, exePath)
|
||
if err != nil {
|
||
return fmt.Errorf("应用更新失败:%v(可能程序仍在运行)", err)
|
||
}
|
||
|
||
log.Printf("成功应用更新:%s -> %s", newPath, exePath)
|
||
return nil
|
||
}
|
||
|
||
// getRemoteExecutableName 根据操作系统获取远程可执行文件名(用于下载)
|
||
func getRemoteExecutableName() string {
|
||
if runtime.GOOS == "windows" {
|
||
return "client_windows.exe"
|
||
}
|
||
return "client_linux"
|
||
}
|
||
|
||
// getLocalExecutableName 根据操作系统获取本地可执行文件名(用于替换)
|
||
func getLocalExecutableName() string {
|
||
if runtime.GOOS == "windows" {
|
||
return "client.exe"
|
||
}
|
||
return "client"
|
||
}
|