diff --git a/pkg/util/fs_util.go b/pkg/util/fs_util.go index ba36461b2..8b4829844 100644 --- a/pkg/util/fs_util.go +++ b/pkg/util/fs_util.go @@ -509,8 +509,12 @@ func CopySymlink(src, dest string) error { if err != nil { return err } - linkDst := filepath.Join(dest, link) - return os.Symlink(linkDst, dest) + if FilepathExists(dest) { + if err := os.RemoveAll(dest); err != nil { + return err + } + } + return os.Symlink(link, dest) } // CopyFile copies the file at src to dest diff --git a/pkg/util/fs_util_test.go b/pkg/util/fs_util_test.go index 3ca29aeb8..b259552e1 100644 --- a/pkg/util/fs_util_test.go +++ b/pkg/util/fs_util_test.go @@ -536,3 +536,63 @@ func TestExtractFile(t *testing.T) { }) } } + +func TestCopySymlink(t *testing.T) { + type tc struct { + name string + linkTarget string + dest string + beforeLink func(r string) error + } + + tcs := []tc{{ + name: "absolute symlink", + linkTarget: "/abs/dest", + }, { + name: "relative symlink", + linkTarget: "rel", + }, { + name: "symlink copy overwrites existing file", + linkTarget: "/abs/dest", + dest: "overwrite_me", + beforeLink: func(r string) error { + return ioutil.WriteFile(filepath.Join(r, "overwrite_me"), nil, 0644) + }, + }} + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + r, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(r) + + if tc.beforeLink != nil { + if err := tc.beforeLink(r); err != nil { + t.Fatal(err) + } + } + link := filepath.Join(r, "link") + dest := filepath.Join(r, "copy") + if tc.dest != "" { + dest = filepath.Join(r, tc.dest) + } + if err := os.Symlink(tc.linkTarget, link); err != nil { + t.Fatal(err) + } + if err := CopySymlink(link, dest); err != nil { + t.Fatal(err) + } + got, err := os.Readlink(dest) + if err != nil { + t.Fatalf("error reading link %s: %s", link, err) + } + if got != tc.linkTarget { + t.Errorf("link target does not match: %s != %s", got, tc.linkTarget) + } + }) + } +}