diff --git a/pkg/util/command_util.go b/pkg/util/command_util.go index 448a59236..f71459171 100644 --- a/pkg/util/command_util.go +++ b/pkg/util/command_util.go @@ -387,41 +387,48 @@ func getUIDAndGID(userStr string, groupStr string, fallbackToUID bool) (uint32, return 0, 0, err } - gid, err := getGIDFromName(groupStr, fallbackToUID) - if err != nil { - if errors.Is(err, fallbackToUIDError) { - return uid32, uid32, nil + if groupStr != "" { + gid32, err := getGIDFromName(groupStr) + if err != nil { + if errors.Is(err, fallbackToUIDError) { + return uid32, uid32, nil + } + return 0, 0, err } - return 0, 0, err + return uid32, gid32, nil } - return uid32, gid, nil + + if fallbackToUID { + return uid32, uid32, nil + } + + return uid32, 0, nil } -// getGID tries to parse the gid or falls back to getGroupFromName if it's not an id -func getGID(groupStr string, fallbackToUID bool) (uint32, error) { +// getGID tries to parse the gid +func getGID(groupStr string) (uint32, error) { gid, err := strconv.ParseUint(groupStr, 10, 32) if err != nil { - return 0, fallbackToUIDOrError(err, fallbackToUID) + return 0, err } return uint32(gid), nil } // getGIDFromName tries to parse the groupStr into an existing group. -// if the group doesn't exist, fallback to getGID to parse non-existing valid GIDs. -func getGIDFromName(groupStr string, fallbackToUID bool) (uint32, error) { +func getGIDFromName(groupStr string) (uint32, error) { group, err := user.LookupGroup(groupStr) if err != nil { // unknown group error could relate to a non existing group - var groupErr *user.UnknownGroupError - if errors.Is(err, groupErr) { - return getGID(groupStr, fallbackToUID) + var groupErr user.UnknownGroupError + if errors.As(err, &groupErr) { + return getGID(groupStr) } group, err = user.LookupGroupId(groupStr) if err != nil { - return getGID(groupStr, fallbackToUID) + return getGID(groupStr) } } - return getGID(group.Gid, fallbackToUID) + return getGID(group.Gid) } var fallbackToUIDError = new(fallbackToUIDErrorType) @@ -432,13 +439,6 @@ func (e fallbackToUIDErrorType) Error() string { return "fallback to uid" } -func fallbackToUIDOrError(err error, fallbackToUID bool) error { - if fallbackToUID { - return fallbackToUIDError - } - return err -} - // LookupUser will try to lookup the userStr inside the passwd file. // If the user does not exists, the function will fallback to parsing the userStr as an uid. func LookupUser(userStr string) (*user.User, error) { diff --git a/pkg/util/command_util_test.go b/pkg/util/command_util_test.go index 3d8384771..f811f5138 100644 --- a/pkg/util/command_util_test.go +++ b/pkg/util/command_util_test.go @@ -705,7 +705,7 @@ func Test_GetUIDAndGIDFromString(t *testing.T) { }, expected: expected{ userID: 1001, - groupID: uint32(currentUserGID), + groupID: expectedCurrentUser.groupID, }, }, { @@ -714,15 +714,13 @@ func Test_GetUIDAndGIDFromString(t *testing.T) { userGroupStr: fmt.Sprintf("%d:%s", 1001, "hello-world-group"), fallbackToUID: true, }, - expected: expected{ - userID: 1001, - groupID: 1001, - }, + wantErr: true, }, { - testname: "uid and non existing group-name", + testname: "uid and non existing group-name without fallbackToUID", args: args{ - userGroupStr: fmt.Sprintf("%d:%s", 1001, "hello-world-group"), + userGroupStr: fmt.Sprintf("%d:%s", 1001, "hello-world-group"), + fallbackToUID: false, }, wantErr: true, }, @@ -742,7 +740,10 @@ func Test_GetUIDAndGIDFromString(t *testing.T) { userGroupStr: fmt.Sprintf("%d", currentUserUID), fallbackToUID: false, }, - wantErr: true, + expected: expected{ + userID: expectedCurrentUser.userID, + groupID: 0, + }, }, { testname: "only uid and fallback is true", diff --git a/pkg/util/syscall_credentials.go b/pkg/util/syscall_credentials.go index a316ea004..2177abd52 100644 --- a/pkg/util/syscall_credentials.go +++ b/pkg/util/syscall_credentials.go @@ -19,6 +19,7 @@ package util import ( "fmt" "strconv" + "strings" "syscall" "github.com/pkg/errors" @@ -54,6 +55,12 @@ func SyscallCredentials(userStr string) (*syscall.Credential, error) { groups = append(groups, uint32(i)) } + if !(len(strings.Split(userStr, ":")) > 1) { + if u.Gid != "" { + gid, _ = getGID(u.Gid) + } + } + return &syscall.Credential{ Uid: uid, Gid: gid,