Skip to content

Commit 60a3d53

Browse files
authored
Ensure temporary directory symlinks do not cause errors (#560)
* Add test for validating git subdirectory symlink failure Fix the test * Replace go-safetemp as it is superfluous to our needs Remove go-safetemp dependency
1 parent d2e581e commit 60a3d53

File tree

10 files changed

+209
-11
lines changed

10 files changed

+209
-11
lines changed

client.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"strings"
1414

1515
urlhelper "github.com/hashicorp/go-getter/helper/url"
16-
safetemp "github.com/hashicorp/go-safetemp"
1716
)
1817

1918
// ErrSymlinkCopy means that a copy of a symlink was encountered on a request with DisableSymlinks enabled.
@@ -143,7 +142,7 @@ func (c *Client) Get() error {
143142
subDir = subDir[1:]
144143
}
145144

146-
td, tdcloser, err := safetemp.Dir("", "getter")
145+
td, tdcloser, err := mkdirTemp("", "getter")
147146
if err != nil {
148147
return err
149148
}

get_git.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"time"
2020

2121
urlhelper "github.com/hashicorp/go-getter/helper/url"
22-
safetemp "github.com/hashicorp/go-safetemp"
2322
version "github.com/hashicorp/go-version"
2423
)
2524

@@ -148,7 +147,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error {
148147
// GetFile for Git doesn't support updating at this time. It will download
149148
// the file every time.
150149
func (g *GitGetter) GetFile(dst string, u *url.URL) error {
151-
td, tdcloser, err := safetemp.Dir("", "getter")
150+
td, tdcloser, err := mkdirTemp("", "getter")
152151
if err != nil {
153152
return err
154153
}

get_git_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,53 @@ func TestGitGetter_subdirectory(t *testing.T) {
894894
}
895895
}
896896

897+
func TestGitGetter_subdirectory_ok(t *testing.T) {
898+
if !testHasGit {
899+
t.Skip("git not found, skipping")
900+
}
901+
902+
g := new(GitGetter)
903+
dst := filepath.Join(t.TempDir(), "target")
904+
905+
repo := testGitRepo(t, "basic")
906+
err := os.Mkdir(filepath.Join(repo.dir, "nested"), os.ModePerm)
907+
if err != nil {
908+
t.Fatal(err)
909+
}
910+
repo.commitFile("nested/foo.txt", "hello")
911+
912+
u, err := url.Parse(fmt.Sprintf("git::%s//nested", repo.url.String()))
913+
if err != nil {
914+
t.Fatal(err)
915+
}
916+
917+
client := &Client{
918+
Src: u.String(),
919+
Dst: dst,
920+
Pwd: ".",
921+
922+
Mode: ClientModeDir,
923+
924+
Detectors: []Detector{
925+
new(GitDetector),
926+
},
927+
Getters: map[string]Getter{
928+
"git": g,
929+
},
930+
}
931+
932+
err = client.Get()
933+
if err != nil {
934+
t.Fatalf("err: %s", err)
935+
}
936+
937+
// Verify the main file exists
938+
mainPath := filepath.Join(dst, "foo.txt")
939+
if _, err := os.Stat(mainPath); err != nil {
940+
t.Fatalf("err: %s", err)
941+
}
942+
}
943+
897944
func TestGitGetter_BadRemoteUrl(t *testing.T) {
898945

899946
if !testHasGit {

get_hg.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"time"
1515

1616
urlhelper "github.com/hashicorp/go-getter/helper/url"
17-
safetemp "github.com/hashicorp/go-safetemp"
1817
)
1918

2019
// HgGetter is a Getter implementation that will download a module from
@@ -85,7 +84,7 @@ func (g *HgGetter) Get(dst string, u *url.URL) error {
8584
func (g *HgGetter) GetFile(dst string, u *url.URL) error {
8685
// Create a temporary directory to store the full source. This has to be
8786
// a non-existent directory.
88-
td, tdcloser, err := safetemp.Dir("", "getter")
87+
td, tdcloser, err := mkdirTemp("", "getter")
8988
if err != nil {
9089
return err
9190
}

get_http.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"time"
1919

2020
"github.com/hashicorp/go-cleanhttp"
21-
safetemp "github.com/hashicorp/go-safetemp"
2221
)
2322

2423
// HttpGetter is a Getter implementation that will download from an HTTP
@@ -514,7 +513,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {
514513
func (g *HttpGetter) getSubdir(ctx context.Context, dst, source, subDir string, opts ...ClientOption) error {
515514
// Create a temporary directory to store the full source. This has to be
516515
// a non-existent directory.
517-
td, tdcloser, err := safetemp.Dir("", "getter")
516+
td, tdcloser, err := mkdirTemp("", "getter")
518517
if err != nil {
519518
return err
520519
}

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ require (
1111
github.com/cheggaaa/pb v1.0.27
1212
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.65
1313
github.com/hashicorp/go-cleanhttp v0.5.2
14-
github.com/hashicorp/go-safetemp v1.0.0
1514
github.com/hashicorp/go-version v1.6.0
1615
github.com/klauspost/compress v1.15.11
1716
github.com/mitchellh/go-homedir v1.1.0

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.65 h1:81+kWbE1yErFBMjME0I5k3
101101
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.65/go.mod h1:WtMzv9T++tfWVea+qB2MXoaqxw33S8bpJslzUike2mQ=
102102
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
103103
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
104-
github.com/hashicorp/go-safetemp v1.0.0 h1:2HR189eFNrjHQyENnQMMpCiBAsRxzbTMIgBhEyExpmo=
105-
github.com/hashicorp/go-safetemp v1.0.0/go.mod h1:oaerMy3BhqiTbVye6QuFhFtIceqFoDHxNAB65b+Rj1I=
106104
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
107105
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
108106
github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c=

tempdir.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package getter
5+
6+
import (
7+
"io"
8+
"os"
9+
"path/filepath"
10+
)
11+
12+
// mkdirTemp creates a new temporary directory that isn't yet created. This
13+
// can be used with calls that expect a non-existent directory.
14+
//
15+
// The temporary directory is also evaluated for symlinks upon creation
16+
// as some operating systems provide symlinks by default when created.
17+
//
18+
// The directory is created as a child of a temporary directory created
19+
// within the directory dir starting with prefix. The temporary directory
20+
// returned is always named "temp". The parent directory has the specified
21+
// prefix.
22+
//
23+
// The returned io.Closer should be used to clean up the returned directory.
24+
// This will properly remove the returned directory and any other temporary
25+
// files created.
26+
//
27+
// If an error is returned, the Closer does not need to be called (and will
28+
// be nil).
29+
func mkdirTemp(dir, prefix string) (string, io.Closer, error) {
30+
// Create the temporary directory
31+
td, err := os.MkdirTemp(dir, prefix)
32+
if err != nil {
33+
return "", nil, err
34+
}
35+
36+
// we evaluate symlinks as some operating systems (eg: MacOS), that
37+
// actually has any temporary directory created as a symlink.
38+
// As we have only just created the temporary directory, this is a safe
39+
// evaluation to make at this time.
40+
td, err = filepath.EvalSymlinks(td)
41+
if err != nil {
42+
return "", nil, err
43+
}
44+
45+
return filepath.Join(td, "temp"), pathCloser(td), nil
46+
}
47+
48+
// pathCloser implements io.Closer to remove the given path on Close.
49+
type pathCloser string
50+
51+
// Close deletes this path.
52+
func (p pathCloser) Close() error {
53+
return os.RemoveAll(string(p))
54+
}

tempdir_unix_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//go:build !windows
2+
// +build !windows
3+
4+
// Copyright (c) HashiCorp, Inc.
5+
// SPDX-License-Identifier: MPL-2.0
6+
7+
package getter
8+
9+
import (
10+
"os"
11+
"path/filepath"
12+
"testing"
13+
)
14+
15+
func Test_mkdirTemp(t *testing.T) {
16+
d, c, err := mkdirTemp("", "test")
17+
if err != nil {
18+
t.Fatalf("err: %s", err)
19+
}
20+
21+
if _, err := os.Stat(d); err == nil || !os.IsNotExist(err) {
22+
t.Fatalf("directory %q should not exist", d)
23+
}
24+
25+
parent := filepath.Dir(d)
26+
fi, err := os.Stat(parent)
27+
if err != nil {
28+
t.Fatalf("parent directory error: %s", err)
29+
}
30+
if v := fi.Mode().Perm(); v != 0700 {
31+
t.Fatalf("parent directory should be 0700: %s", v)
32+
}
33+
34+
// Create the directory
35+
if err := os.MkdirAll(d, 0755); err != nil {
36+
t.Fatalf("err: %s", err)
37+
}
38+
if _, err := os.Stat(d); err != nil {
39+
t.Fatalf("directory %q should exist", d)
40+
}
41+
42+
// Close should remove it
43+
if err := c.Close(); err != nil {
44+
t.Fatalf("err: %s", err)
45+
}
46+
if _, err := os.Stat(d); err == nil || !os.IsNotExist(err) {
47+
t.Fatalf("directory %q should not exist", d)
48+
}
49+
if _, err := os.Stat(parent); err == nil || !os.IsNotExist(err) {
50+
t.Fatalf("directory %q should not exist", parent)
51+
}
52+
}

tempdir_windows_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//go:build windows
2+
// +build windows
3+
4+
// Copyright (c) HashiCorp, Inc.
5+
// SPDX-License-Identifier: MPL-2.0
6+
7+
package getter
8+
9+
import (
10+
"os"
11+
"path/filepath"
12+
"testing"
13+
)
14+
15+
func Test_mkdirTemp(t *testing.T) {
16+
d, c, err := mkdirTemp("", "test")
17+
if err != nil {
18+
t.Fatalf("err: %s", err)
19+
}
20+
21+
if _, err := os.Stat(d); err == nil || !os.IsNotExist(err) {
22+
t.Fatalf("directory %q should not exist", d)
23+
}
24+
25+
parent := filepath.Dir(d)
26+
fi, err := os.Stat(parent)
27+
if err != nil {
28+
t.Fatalf("parent directory error: %s", err)
29+
}
30+
if v := fi.Mode().Perm(); v != 0777 {
31+
t.Fatalf("parent directory should be 0777: %s", v)
32+
}
33+
34+
// Create the directory
35+
if err := os.MkdirAll(d, 0755); err != nil {
36+
t.Fatalf("err: %s", err)
37+
}
38+
if _, err := os.Stat(d); err != nil {
39+
t.Fatalf("directory %q should exist", d)
40+
}
41+
42+
// Close should remove it
43+
if err := c.Close(); err != nil {
44+
t.Fatalf("err: %s", err)
45+
}
46+
if _, err := os.Stat(d); err == nil || !os.IsNotExist(err) {
47+
t.Fatalf("directory %q should not exist", d)
48+
}
49+
if _, err := os.Stat(parent); err == nil || !os.IsNotExist(err) {
50+
t.Fatalf("directory %q should not exist", parent)
51+
}
52+
}

0 commit comments

Comments
 (0)