diff --git a/go.mod b/go.mod index 4fb2b55f..f76c3738 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( go.uber.org/zap v1.9.1 golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf // indirect golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 // indirect + golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect google.golang.org/api v0.13.0 // indirect diff --git a/pkg/tmpl/context_funcs.go b/pkg/tmpl/context_funcs.go index 9b563348..79ddd70b 100644 --- a/pkg/tmpl/context_funcs.go +++ b/pkg/tmpl/context_funcs.go @@ -2,6 +2,7 @@ package tmpl import ( "fmt" + "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" "io" "os" @@ -58,60 +59,54 @@ func (c *Context) Exec(command string, args []interface{}, inputs ...string) (st cmd := exec.Command(command, strArgs...) cmd.Dir = c.basePath - writeErrs := make(chan error) - cmdErrs := make(chan error) - cmdOuts := make(chan []byte) + g := errgroup.Group{} if len(input) > 0 { stdin, err := cmd.StdinPipe() if err != nil { return "", err } - go func(input string, stdin io.WriteCloser) { + + g.Go(func() error { defer stdin.Close() - defer close(writeErrs) size := len(input) - var n int - var err error i := 0 + for { - n, err = io.WriteString(stdin, input[i:]) + n, err := io.WriteString(stdin, input[i:]) if err != nil { - writeErrs <- fmt.Errorf("failed while writing %d bytes to stdin of \"%s\": %v", len(input), command, err) - break + return fmt.Errorf("failed while writing %d bytes to stdin of \"%s\": %v", len(input), command, err) } + i += n - if n == size { - break + + if i == size { + return nil } } - }(input, stdin) + }) } - go func() { - defer close(cmdOuts) - defer close(cmdErrs) + var bytes []byte - bytes, err := cmd.Output() + g.Go(func() error { + bs, err := cmd.Output() if err != nil { - cmdErrs <- fmt.Errorf("exec cmd=%s args=[%s] failed: %v", command, strings.Join(strArgs, ", "), err) - } else { - cmdOuts <- bytes + return fmt.Errorf("exec cmd=%s args=[%s] failed: %v", command, strings.Join(strArgs, ", "), err) } - }() - for { - select { - case bytes := <-cmdOuts: - return string(bytes), nil - case err := <-cmdErrs: - return "", err - case err := <-writeErrs: - return "", err - } + bytes = bs + + return nil + }) + + if err := g.Wait(); err != nil { + return "", err } + + return string(bytes), nil } func (c *Context) ReadFile(filename string) (string, error) {