Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 58 additions & 108 deletions cmd/update/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@ import (
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"

"github.com/anyproto/anytype-cli/core"
"github.com/anyproto/anytype-cli/core/output"
"github.com/minio/selfupdate"
"github.com/spf13/cobra"
)

const (
githubOwner = "anyproto"
githubRepo = "anytype-cli"
"github.com/anyproto/anytype-cli/core"
"github.com/anyproto/anytype-cli/core/config"
"github.com/anyproto/anytype-cli/core/output"
)

func NewUpdateCmd() *cobra.Command {
Expand Down Expand Up @@ -98,22 +95,19 @@ func downloadAndInstall(version string) error {
return err
}

if err := extractArchive(archivePath, tempDir); err != nil {
return output.Error("Failed to extract: %w", err)
}

binaryName := "anytype"
if runtime.GOOS == "windows" {
binaryName = "anytype.exe"
}

newBinary := filepath.Join(tempDir, binaryName)
if _, err := os.Stat(newBinary); err != nil {
return output.Error("binary not found in archive (expected %s)", binaryName)
binaryReader, err := extractBinary(archivePath, binaryName)
if err != nil {
return output.Error("Failed to extract: %w", err)
}
defer binaryReader.Close()

if err := replaceBinary(newBinary); err != nil {
return output.Error("Failed to install: %w", err)
if err := selfupdate.Apply(binaryReader, selfupdate.Options{}); err != nil {
return output.Error("failed to apply update: %w", err)
}

return nil
Expand All @@ -135,8 +129,8 @@ func downloadRelease(version, destination string) error {
return downloadViaAPI(version, archiveName, destination)
}

url := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
githubOwner, githubRepo, version, archiveName)
url := fmt.Sprintf("%s/releases/download/%s/%s",
config.GitHubBaseURL, version, archiveName)

return downloadFile(url, destination, "")
}
Expand Down Expand Up @@ -207,25 +201,24 @@ func downloadFile(url, destination, token string) error {
return err
}

func extractArchive(archivePath, destDir string) error {
func extractBinary(archivePath, binaryName string) (io.ReadCloser, error) {
if strings.HasSuffix(archivePath, ".zip") {
return extractZip(archivePath, destDir)
return extractBinaryFromZip(archivePath, binaryName)
}
return extractTarGz(archivePath, destDir)
return extractBinaryFromTarGz(archivePath, binaryName)
}

func extractTarGz(archivePath, destDir string) error {
func extractBinaryFromTarGz(archivePath, binaryName string) (io.ReadCloser, error) {
file, err := os.Open(archivePath)
if err != nil {
return err
return nil, err
}
defer file.Close()

gz, err := gzip.NewReader(file)
if err != nil {
return err
file.Close()
return nil, err
}
defer gz.Close()

tr := tar.NewReader(gz)
for {
Expand All @@ -234,108 +227,65 @@ func extractTarGz(archivePath, destDir string) error {
break
}
if err != nil {
return err
file.Close()
gz.Close()
return nil, err
}

if strings.Contains(header.Name, "..") {
return fmt.Errorf("illegal file path: %s", header.Name)
}
target := filepath.Join(destDir, header.Name)

switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(target, 0755); err != nil {
return err
}
case tar.TypeReg:
if err := writeFile(target, tr, header.FileInfo().Mode()); err != nil {
return err
}
if header.Typeflag == tar.TypeReg && header.Name == binaryName {
return &tarGzReader{Reader: tr, gz: gz, file: file}, nil
}
}
return nil
}

func extractZip(archivePath, destDir string) error {
r, err := zip.OpenReader(archivePath)
if err != nil {
return err
}
defer r.Close()

for _, f := range r.File {
if strings.Contains(f.Name, "..") {
return fmt.Errorf("illegal file path: %s", f.Name)
}
target := filepath.Join(destDir, f.Name)

if f.FileInfo().IsDir() {
_ = os.MkdirAll(target, f.Mode())
continue
}

rc, err := f.Open()
if err != nil {
return err
}

if err := writeFile(target, rc, f.Mode()); err != nil {
rc.Close()
return err
}
rc.Close()
}
return nil
file.Close()
gz.Close()
return nil, fmt.Errorf("binary %s not found in archive", binaryName)
}

func writeFile(path string, r io.Reader, mode os.FileMode) error {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return err
}

out, err := os.Create(path)
if err != nil {
return err
}
defer out.Close()

if _, err := io.Copy(out, r); err != nil {
return err
}

return os.Chmod(path, mode)
type tarGzReader struct {
io.Reader
gz *gzip.Reader
file *os.File
}

func replaceBinary(newBinary string) error {
if err := os.Chmod(newBinary, 0755); err != nil {
return err
}
func (r *tarGzReader) Close() error {
r.gz.Close()
return r.file.Close()
}

currentBinary, err := os.Executable()
if err != nil {
return err
}
currentBinary, err = filepath.EvalSymlinks(currentBinary)
func extractBinaryFromZip(archivePath, binaryName string) (io.ReadCloser, error) {
r, err := zip.OpenReader(archivePath)
if err != nil {
return err
return nil, err
}

if err := os.Rename(newBinary, currentBinary); err != nil {
if runtime.GOOS != "windows" {
cmd := exec.Command("mv", newBinary, currentBinary)
if err := cmd.Run(); err != nil {
return output.Error("failed to replace binary at %s (permission denied). To get the latest version, reinstall from repository: https://github.com/%s/%s", currentBinary, githubOwner, githubRepo)
for _, f := range r.File {
if f.Name == binaryName {
rc, err := f.Open()
if err != nil {
r.Close()
return nil, err
}
} else {
return output.Error("failed to replace binary at %s. To get the latest version, reinstall from repository: https://github.com/%s/%s", currentBinary, githubOwner, githubRepo)
return &zipReader{ReadCloser: rc, zipReader: r}, nil
}
}

return nil
r.Close()
return nil, fmt.Errorf("binary %s not found in archive", binaryName)
}

type zipReader struct {
io.ReadCloser
zipReader *zip.ReadCloser
}

func (r *zipReader) Close() error {
r.ReadCloser.Close()
return r.zipReader.Close()
}

func githubAPI(method, endpoint string, body io.Reader) (*http.Response, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s%s", githubOwner, githubRepo, endpoint)
url := config.GitHubAPIURL + endpoint

req, err := http.NewRequest(method, url, body)
if err != nil {
Expand Down
Loading