diff --git a/pkg/cluster/database.go b/pkg/cluster/database.go index ff08c6c24..bb640ade0 100644 --- a/pkg/cluster/database.go +++ b/pkg/cluster/database.go @@ -39,7 +39,12 @@ const ( createExtensionSQL = `CREATE EXTENSION IF NOT EXISTS "%s" SCHEMA "%s"` alterExtensionSQL = `ALTER EXTENSION "%s" SET SCHEMA "%s"` + getPublicationsSQL = `SELECT p.pubname, string_agg(pt.schemaname || . || pt.tablename, ', ' ORDER BY pt.schemaname, pt.tablename) + FROM pg_publication p + JOIN pg_publication_tables pt ON pt.pubname = p.pubname + GROUP BY p.pubname;` createPublicationSQL = `CREATE PUBLICATION "%s" FOR TABLE %s WITH (publish = 'insert, update');` + alterPublicationSQL = `ALTER PUBLICATION "%s" SET TABLE %s;` globalDefaultPrivilegesSQL = `SET ROLE TO "%s"; ALTER DEFAULT PRIVILEGES GRANT USAGE ON SCHEMAS TO "%s","%s"; @@ -508,6 +513,67 @@ func (c *Cluster) execCreateOrAlterExtension(extName, schemaName, statement, doi return nil } +// getPublications returns the list of current database publications with tables +// The caller is responsible for opening and closing the database connection +func (c *Cluster) getPublications() (publications map[string]string, err error) { + var ( + rows *sql.Rows + dbPublications map[string]string + ) + + if rows, err = c.pgDb.Query(getPublicationsSQL); err != nil { + return nil, fmt.Errorf("could not query database publications: %v", err) + } + + defer func() { + if err2 := rows.Close(); err2 != nil { + if err != nil { + err = fmt.Errorf("error when closing query cursor: %v, previous error: %v", err2, err) + } else { + err = fmt.Errorf("error when closing query cursor: %v", err2) + } + } + }() + + for rows.Next() { + var ( + dbPublication string + dbPublicationTables string + ) + + if err = rows.Scan(&dbPublication, &dbPublicationTables); err != nil { + return nil, fmt.Errorf("error when processing row: %v", err) + } + dbPublications[dbPublication] = dbPublicationTables + } + + return dbPublications, err +} + +// executeCreatePublication creates new publication for given tables +// The caller is responsible for opening and closing the database connection. +func (c *Cluster) executeCreatePublication(pubName, tableList string) error { + return c.execCreateOrAlterPublication(pubName, tableList, createPublicationSQL, + "creating publication", "create publication") +} + +// executeAlterExtension changes the table list of the given publication. +// The caller is responsible for opening and closing the database connection. +func (c *Cluster) executeAlterPublication(pubName, tableList string) error { + return c.execCreateOrAlterPublication(pubName, tableList, alterPublicationSQL, + "changing table list of publication", "alter publication tables") +} + +func (c *Cluster) execCreateOrAlterPublication(pubName, tableList, statement, doing, operation string) error { + + c.logger.Infof("%s %q table list %q", doing, pubName, tableList) + if _, err := c.pgDb.Exec(fmt.Sprintf(statement, pubName, tableList)); err != nil { + return fmt.Errorf("could not execute %s: %v", operation, err) + } + + return nil +} + // Creates a connection pool credentials lookup function in every database to // perform remote authentication. func (c *Cluster) installLookupFunction(poolerSchema, poolerUser string) error { @@ -612,23 +678,3 @@ func (c *Cluster) installLookupFunction(poolerSchema, poolerUser string) error { return nil } - -// getExtension returns the list of current database extensions -// The caller is responsible for opening and closing the database connection -func (c *Cluster) createPublication(dbName, publication, tables string) (err error) { - - if err := c.initDbConnWithName(dbName); err != nil { - return fmt.Errorf("could not init connection to database %q", dbName) - } - defer func() { - if err = c.closeDbConn(); err != nil { - err = fmt.Errorf("could not close connection to database %q: %v", dbName, err) - } - }() - - if _, err := c.pgDb.Exec(fmt.Sprintf(createPublicationSQL, publication, tables)); err != nil { - return fmt.Errorf("could not create publication %s: %v", publication, err) - } - - return err -} diff --git a/pkg/cluster/streams.go b/pkg/cluster/streams.go index ea8381eee..ebd110c32 100644 --- a/pkg/cluster/streams.go +++ b/pkg/cluster/streams.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" acidv1 "github.com/zalando/postgres-operator/pkg/apis/acid.zalan.do/v1" @@ -77,6 +78,15 @@ func (c *Cluster) syncPostgresConfig() error { slots := make(map[string]map[string]string) publications := make(map[string]map[string]acidv1.StreamTable) + createPublications := make(map[string]string) + alterPublications := make(map[string]string) + + defer func() { + if err := c.closeDbConn(); err != nil { + c.logger.Errorf("could not close database connection: %v", err) + } + }() + desiredPatroniConfig := c.Spec.Patroni if len(desiredPatroniConfig.Slots) > 0 { slots = desiredPatroniConfig.Slots @@ -135,9 +145,19 @@ func (c *Cluster) syncPostgresConfig() error { } } - // next create publications to each created slot + // next, create publications to each created slot for publication, tables := range publications { + // but first check for existing publications dbName := slots[publication]["database"] + if err := c.initDbConnWithName(dbName); err != nil { + return fmt.Errorf("could not init database connection") + } + + currentPublications, err := c.getPublications() + if err != nil { + return fmt.Errorf("could not get current publications: %v", err) + } + tableNames := make([]string, len(tables)) i := 0 for t := range tables { @@ -145,10 +165,29 @@ func (c *Cluster) syncPostgresConfig() error { tableNames[i] = fmt.Sprintf("%q.%q", schemaName, tableName) i++ } + sort.Strings(tableNames) tableList := strings.Join(tableNames, ", ") - c.logger.Debugf("creating publication %q in database %q for tables %s", publication, dbName, tableList) - if err := c.createPublication(dbName, publication, tableList); err != nil { - c.logger.Warningf("%v", err) + + currentTables, exists := currentPublications[publication] + if !exists { + createPublications[publication] = tableList + } else if currentTables != tableList { + alterPublications[publication] = tableList + } + + if len(createPublications)+len(alterPublications) == 0 { + return nil + } + + for publicationName, tables := range createPublications { + if err = c.executeCreatePublication(publicationName, tables); err != nil { + c.logger.Warningf("%v", err) + } + } + for publicationName, tables := range alterPublications { + if err = c.executeAlterPublication(publicationName, tables); err != nil { + c.logger.Warningf("%v", err) + } } }