Create new databases and change owners of existing ones during sync. (#153)

* Create new databases and change owners of existing ones during sync.
This commit is contained in:
Oleksii Kliukin 2017-11-02 17:46:33 +01:00 committed by GitHub
parent d3679bfd4a
commit ce960e892a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 116 additions and 51 deletions

View File

@ -255,12 +255,12 @@ func (c *Cluster) Create() error {
if err = c.createRoles(); err != nil {
return fmt.Errorf("could not create users: %v", err)
}
c.logger.Infof("users have been successfully created")
if err = c.createDatabases(); err != nil {
return fmt.Errorf("could not create databases: %v", err)
}
c.logger.Infof("users have been successfully created")
c.logger.Infof("databases have been successfully created")
} else {
if c.masterLess {
c.logger.Warnln("cluster is masterless")

View File

@ -25,8 +25,9 @@ const (
WHERE a.rolname = ANY($1)
ORDER BY 1;`
getDatabasesSQL = `SELECT datname, a.rolname AS owner FROM pg_database d INNER JOIN pg_authid a ON a.oid = d.datdba;`
getDatabasesSQL = `SELECT datname, pg_get_userbyid(datdba) AS owner FROM pg_database;`
createDatabaseSQL = `CREATE DATABASE "%s" OWNER "%s";`
alterDatabaseOwnerSQL = `ALTER DATABASE "%s" OWNER TO "%s";`
)
func (c *Cluster) pgConnectionString() string {
@ -137,6 +138,8 @@ func (c *Cluster) readPgUsersFromDatabase(userNames []string) (users spec.PgUser
return users, nil
}
// getDatabases returns the map of current databases with owners
// The caller is responsible for opening and closing the database connection
func (c *Cluster) getDatabases() (map[string]string, error) {
var (
rows *sql.Rows
@ -144,15 +147,6 @@ func (c *Cluster) getDatabases() (map[string]string, error) {
)
dbs := make(map[string]string)
if err = c.initDbConn(); err != nil {
return nil, fmt.Errorf("could not init db connection")
}
defer func() {
if err = c.closeDbConn(); err != nil {
c.logger.Errorf("could not close db connection: %v", err)
}
}()
if rows, err = c.pgDb.Query(getDatabasesSQL); err != nil {
return nil, fmt.Errorf("could not query database: %v", err)
}
@ -176,49 +170,44 @@ func (c *Cluster) getDatabases() (map[string]string, error) {
return dbs, nil
}
func (c *Cluster) createDatabases() error {
c.setProcessName("creating databases")
newDbs := c.Spec.Databases
curDbs, err := c.getDatabases()
if err != nil {
return fmt.Errorf("could not get current databases: %v", err)
}
for datname := range curDbs {
delete(newDbs, datname)
// executeCreateDatabase creates new database with the given owner.
// The caller is responsible for openinging and closing the database connection.
func (c *Cluster) executeCreateDatabase(datname, owner string) error {
if !c.databaseNameOwnerValid(datname, owner) {
return nil
}
c.logger.Infof("creating database %q with owner %q", datname, owner)
if len(newDbs) == 0 {
if _, err := c.pgDb.Query(fmt.Sprintf(createDatabaseSQL, datname, owner)); err != nil {
return fmt.Errorf("could not execute create database: %v", err)
}
return nil
}
if err = c.initDbConn(); err != nil {
return fmt.Errorf("could not init database connection")
// executeCreateDatabase changes the owner of the given database.
// The caller is responsible for openinging and closing the database connection.
func (c *Cluster) executeAlterDatabaseOwner(datname string, owner string) error {
if !c.databaseNameOwnerValid(datname, owner) {
return nil
}
defer func() {
if err = c.closeDbConn(); err != nil {
c.logger.Errorf("could not close database connection: %v", err)
c.logger.Infof("changing database %q owner to %q", datname, owner)
if _, err := c.pgDb.Query(fmt.Sprintf(alterDatabaseOwnerSQL, datname, owner)); err != nil {
return fmt.Errorf("could not execute alter database owner: %v", err)
}
return nil
}
}()
for datname, owner := range newDbs {
func (c *Cluster) databaseNameOwnerValid(datname, owner string) bool {
if _, ok := c.pgUsers[owner]; !ok {
c.logger.Infof("skipping creation of the %q database, user %q does not exist", datname, owner)
continue
return false
}
if !databaseNameRegexp.MatchString(datname) {
c.logger.Infof("database %q has invalid name", datname)
continue
return false
}
c.logger.Infof("creating database %q with owner %q", datname, owner)
if _, err = c.pgDb.Query(fmt.Sprintf(createDatabaseSQL, datname, owner)); err != nil {
return fmt.Errorf("could not query database: %v", err)
}
}
return nil
return true
}
func makeUserFlags(rolsuper, rolinherit, rolcreaterole, rolcreatedb, rolcanlogin bool) (result []string) {

View File

@ -557,3 +557,27 @@ func (c *Cluster) GetStatefulSet() *v1beta1.StatefulSet {
func (c *Cluster) GetPodDisruptionBudget() *policybeta1.PodDisruptionBudget {
return c.PodDisruptionBudget
}
func (c *Cluster) createDatabases() error {
c.setProcessName("creating databases")
if len(c.Spec.Databases) == 0 {
return nil
}
if err := c.initDbConn(); err != nil {
return fmt.Errorf("could not init database connection")
}
defer func() {
if err := c.closeDbConn(); err != nil {
c.logger.Errorf("could not close database connection: %v", err)
}
}()
for datname, owner := range c.Spec.Databases {
if err := c.executeCreateDatabase(datname, owner); err != nil {
return err
}
}
return nil
}

View File

@ -90,6 +90,11 @@ func (c *Cluster) Sync(newSpec *spec.Postgresql) (err error) {
err = fmt.Errorf("could not sync roles: %v", err)
return
}
c.logger.Debugf("syncing databases")
if err = c.syncDatabases(); err != nil {
err = fmt.Errorf("could not sync databases: %v", err)
return
}
}
c.logger.Debugf("syncing persistent volumes")
@ -292,3 +297,50 @@ func (c *Cluster) samePDBWith(pdb *policybeta1.PodDisruptionBudget) (match bool,
return
}
func (c *Cluster) syncDatabases() error {
c.setProcessName("syncing databases")
createDatabases := make(map[string]string)
alterOwnerDatabases := make(map[string]string)
if err := c.initDbConn(); err != nil {
return fmt.Errorf("could not init database connection")
}
defer func() {
if err := c.closeDbConn(); err != nil {
c.logger.Errorf("could not close database connection: %v", err)
}
}()
currentDatabases, err := c.getDatabases()
if err != nil {
return fmt.Errorf("could not get current databases: %v", err)
}
for datname, newOwner := range c.Spec.Databases {
currentOwner, exists := currentDatabases[datname]
if !exists {
createDatabases[datname] = newOwner
} else if currentOwner != newOwner {
alterOwnerDatabases[datname] = newOwner
}
}
if len(createDatabases)+len(alterOwnerDatabases) == 0 {
return nil
}
for datname, owner := range createDatabases {
if err = c.executeCreateDatabase(datname, owner); err != nil {
return err
}
}
for datname, owner := range alterOwnerDatabases {
if err = c.executeAlterDatabaseOwner(datname, owner); err != nil {
return err
}
}
return nil
}