diff --git a/pkg/cluster/cluster.go b/pkg/cluster/cluster.go index f5c8a3b01..5f688e44b 100644 --- a/pkg/cluster/cluster.go +++ b/pkg/cluster/cluster.go @@ -71,10 +71,11 @@ type Cluster struct { deleteOptions *metav1.DeleteOptions podEventsQueue *cache.FIFO - teamsAPIClient *teams.API - KubeClient k8sutil.KubernetesClient //TODO: move clients to the better place? - currentProcess spec.Process - processMu sync.RWMutex + teamsAPIClient teams.Interface + oauthTokenGetter OAuthTokenGetter + KubeClient k8sutil.KubernetesClient //TODO: move clients to the better place? + currentProcess spec.Process + processMu sync.RWMutex } type compareStatefulsetResult struct { @@ -112,9 +113,10 @@ func New(cfg Config, kubeClient k8sutil.KubernetesClient, pgSpec spec.Postgresql deleteOptions: &metav1.DeleteOptions{OrphanDependents: &orphanDependents}, podEventsQueue: podEventsQueue, KubeClient: kubeClient, - teamsAPIClient: teams.NewTeamsAPI(cfg.OpConfig.TeamsAPIUrl, logger), } cluster.logger = logger.WithField("pkg", "cluster").WithField("cluster-name", cluster.clusterName()) + cluster.teamsAPIClient = teams.NewTeamsAPI(cfg.OpConfig.TeamsAPIUrl, logger) + cluster.oauthTokenGetter = NewSecretOauthTokenGetter(&kubeClient, cfg.OpConfig.OAuthTokenSecretName) cluster.patroni = patroni.New(cluster.logger) return cluster diff --git a/pkg/cluster/cluster_test.go b/pkg/cluster/cluster_test.go new file mode 100644 index 000000000..71a45a582 --- /dev/null +++ b/pkg/cluster/cluster_test.go @@ -0,0 +1,125 @@ +package cluster + +import ( + "fmt" + "github.com/Sirupsen/logrus" + "github.com/zalando-incubator/postgres-operator/pkg/spec" + "github.com/zalando-incubator/postgres-operator/pkg/util/k8sutil" + "github.com/zalando-incubator/postgres-operator/pkg/util/teams" + "reflect" + "testing" +) + +var logger = logrus.New().WithField("test", "cluster") +var cl = New(Config{}, k8sutil.KubernetesClient{}, spec.Postgresql{}, logger) + +func TestInitRobotUsers(t *testing.T) { + testName := "TestInitRobotUsers" + tests := []struct { + manifestUsers map[string]spec.UserFlags + infraRoles map[string]spec.PgUser + result map[string]spec.PgUser + err error + }{ + { + manifestUsers: map[string]spec.UserFlags{"foo": {"superuser", "createdb"}}, + infraRoles: map[string]spec.PgUser{"foo": {Name: "foo", Password: "bar"}}, + result: map[string]spec.PgUser{"foo": {Name: "foo", Password: "bar", + Flags: []string{"CREATEDB", "LOGIN", "SUPERUSER"}}}, + err: nil, + }, + { + manifestUsers: map[string]spec.UserFlags{"!fooBar": {"superuser", "createdb"}}, + err: fmt.Errorf(`invalid username: "!fooBar"`), + }, + { + manifestUsers: map[string]spec.UserFlags{"foobar": {"!superuser", "createdb"}}, + err: fmt.Errorf(`invalid flags for user "foobar": ` + + `user flag "!superuser" is not alphanumeric`), + }, + { + manifestUsers: map[string]spec.UserFlags{"foobar": {"superuser1", "createdb"}}, + err: fmt.Errorf(`invalid flags for user "foobar": ` + + `user flag "SUPERUSER1" is not valid`), + }, + { + manifestUsers: map[string]spec.UserFlags{"foobar": {"inherit", "noinherit"}}, + err: fmt.Errorf(`invalid flags for user "foobar": ` + + `conflicting user flags: "NOINHERIT" and "INHERIT"`), + }, + } + for _, tt := range tests { + cl.Spec.Users = tt.manifestUsers + cl.pgUsers = tt.infraRoles + if err := cl.initRobotUsers(); err != nil { + if tt.err == nil { + t.Errorf("%s got an unexpected error: %v", testName, err) + } + if err.Error() != tt.err.Error() { + t.Errorf("%s expected error %v, got %v", testName, tt.err, err) + } + } else { + if !reflect.DeepEqual(cl.pgUsers, tt.result) { + t.Errorf("%s expected: %#v, got %#v", testName, tt.result, cl.pgUsers) + } + } + } +} + +type mockOAuthTokenGetter struct { +} + +func (m *mockOAuthTokenGetter) getOAuthToken() (string, error) { + return "", nil +} + +type mockTeamsAPIClient struct { + members []string +} + +func (m *mockTeamsAPIClient) TeamInfo(teamID, token string) (tm *teams.Team, err error) { + return &teams.Team{Members: m.members}, nil +} + +func (m *mockTeamsAPIClient) setMembers(members []string) { + m.members = members +} + +func TestInitHumanUsers(t *testing.T) { + + var mockTeamsAPI mockTeamsAPIClient + cl.oauthTokenGetter = &mockOAuthTokenGetter{} + cl.teamsAPIClient = &mockTeamsAPI + testName := "TestInitHumanUsers" + + cl.OpConfig.EnableTeamSuperuser = true + cl.OpConfig.EnableTeamsAPI = true + cl.OpConfig.PamRoleName = "zalandos" + cl.Spec.TeamID = "test" + + tests := []struct { + existingRoles map[string]spec.PgUser + teamRoles []string + result map[string]spec.PgUser + }{ + { + existingRoles: map[string]spec.PgUser{"foo": {Name: "foo", Flags: []string{"NOLOGIN"}}, + "bar": {Name: "bar", Flags: []string{"NOLOGIN"}}}, + teamRoles: []string{"foo"}, + result: map[string]spec.PgUser{"foo": {Name: "foo", MemberOf: []string{cl.OpConfig.PamRoleName}, Flags: []string{"LOGIN", "SUPERUSER"}}, + "bar": {Name: "bar", Flags: []string{"NOLOGIN"}}}, + }, + } + + for _, tt := range tests { + cl.pgUsers = tt.existingRoles + mockTeamsAPI.setMembers(tt.teamRoles) + if err := cl.initHumanUsers(); err != nil { + t.Errorf("%s got an unexpected error %v", testName, err) + } + + if !reflect.DeepEqual(cl.pgUsers, tt.result) { + t.Errorf("%s expects %#v, got %#v", testName, tt.result, cl.pgUsers) + } + } +} diff --git a/pkg/cluster/util.go b/pkg/cluster/util.go index 257601ed7..3c85017dc 100644 --- a/pkg/cluster/util.go +++ b/pkg/cluster/util.go @@ -16,10 +16,46 @@ import ( "github.com/zalando-incubator/postgres-operator/pkg/spec" "github.com/zalando-incubator/postgres-operator/pkg/util" "github.com/zalando-incubator/postgres-operator/pkg/util/constants" + "github.com/zalando-incubator/postgres-operator/pkg/util/k8sutil" "github.com/zalando-incubator/postgres-operator/pkg/util/retryutil" "sort" ) +// OAuthTokenGetter provides the method for fetching OAuth tokens +type OAuthTokenGetter interface { + getOAuthToken() (string, error) +} + +// OAuthTokenGetter enables fetching OAuth tokens by reading Kubernetes secrets +type SecretOauthTokenGetter struct { + kubeClient *k8sutil.KubernetesClient + OAuthTokenSecretName spec.NamespacedName +} + +func NewSecretOauthTokenGetter(kubeClient *k8sutil.KubernetesClient, + OAuthTokenSecretName spec.NamespacedName) *SecretOauthTokenGetter { + return &SecretOauthTokenGetter{kubeClient, OAuthTokenSecretName} +} + +func (g *SecretOauthTokenGetter) getOAuthToken() (string, error) { + //TODO: we can move this function to the Controller in case it will be needed there. As for now we use it only in the Cluster + // Temporary getting postgresql-operator secret from the NamespaceDefault + credentialsSecret, err := g.kubeClient. + Secrets(g.OAuthTokenSecretName.Namespace). + Get(g.OAuthTokenSecretName.Name, metav1.GetOptions{}) + + if err != nil { + return "", fmt.Errorf("could not get credentials secret: %v", err) + } + data := credentialsSecret.Data + + if string(data["read-only-token-type"]) != "Bearer" { + return "", fmt.Errorf("wrong token type: %v", data["read-only-token-type"]) + } + + return string(data["read-only-token-secret"]), nil +} + func isValidUsername(username string) bool { return userRegexp.MatchString(username) } @@ -150,26 +186,6 @@ func (c *Cluster) logVolumeChanges(old, new spec.Volume) { c.logger.Debugf("diff\n%s\n", util.PrettyDiff(old, new)) } -func (c *Cluster) getOAuthToken() (string, error) { - //TODO: we can move this function to the Controller in case it will be needed there. As for now we use it only in the Cluster - // Temporary getting postgresql-operator secret from the NamespaceDefault - credentialsSecret, err := c.KubeClient. - Secrets(c.OpConfig.OAuthTokenSecretName.Namespace). - Get(c.OpConfig.OAuthTokenSecretName.Name, metav1.GetOptions{}) - - if err != nil { - c.logger.Debugf("oauth token secret name: %q", c.OpConfig.OAuthTokenSecretName) - return "", fmt.Errorf("could not get credentials secret: %v", err) - } - data := credentialsSecret.Data - - if string(data["read-only-token-type"]) != "Bearer" { - return "", fmt.Errorf("wrong token type: %v", data["read-only-token-type"]) - } - - return string(data["read-only-token-secret"]), nil -} - func (c *Cluster) getTeamMembers() ([]string, error) { if c.Spec.TeamID == "" { return nil, fmt.Errorf("no teamId specified") @@ -179,7 +195,7 @@ func (c *Cluster) getTeamMembers() ([]string, error) { return []string{}, nil } - token, err := c.getOAuthToken() + token, err := c.oauthTokenGetter.getOAuthToken() if err != nil { return []string{}, fmt.Errorf("could not get oauth token: %v", err) } diff --git a/pkg/util/teams/teams.go b/pkg/util/teams/teams.go index c371ec85d..8645871ba 100644 --- a/pkg/util/teams/teams.go +++ b/pkg/util/teams/teams.go @@ -43,6 +43,10 @@ type httpClient interface { Do(req *http.Request) (*http.Response, error) } +type Interface interface { + TeamInfo(teamID, token string) (tm *Team, err error) +} + // API describes teams API type API struct { httpClient