diff --git a/pkg/util/fs_util_test.go b/pkg/util/fs_util_test.go index 968cacf99..00b480ece 100644 --- a/pkg/util/fs_util_test.go +++ b/pkg/util/fs_util_test.go @@ -535,3 +535,63 @@ func TestExtractFile(t *testing.T) { }) } } + +func TestCopySymlink(t *testing.T) { + type tc struct { + name string + linkTarget string + dest string + beforeLink func(r string) error + } + + tcs := []tc{{ + name: "absolute symlink", + linkTarget: "/abs/dest", + }, { + name: "relative symlink", + linkTarget: "rel", + }, { + name: "symlink copy overwrites existing file", + linkTarget: "/abs/dest", + dest: "overwrite_me", + beforeLink: func(r string) error { + return ioutil.WriteFile(filepath.Join(r, "overwrite_me"), nil, 0644) + }, + }} + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + r, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(r) + + if tc.beforeLink != nil { + if err := tc.beforeLink(r); err != nil { + t.Fatal(err) + } + } + link := filepath.Join(r, "link") + dest := filepath.Join(r, "copy") + if tc.dest != "" { + dest = filepath.Join(r, tc.dest) + } + if err := os.Symlink(tc.linkTarget, link); err != nil { + t.Fatal(err) + } + if err := CopySymlink(link, dest); err != nil { + t.Fatal(err) + } + got, err := os.Readlink(dest) + if err != nil { + t.Fatalf("error reading link %s: %s", link, err) + } + if got != tc.linkTarget { + t.Errorf("link target does not match: %s != %s", got, tc.linkTarget) + } + }) + } +}