postgres-operator/ui/operator_ui/spiloutils.py
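"""Helper functions for the operator UI: thin wrappers around the Kubernetes
API for Spilo / postgresql resources and around S3 / WAL-E for listing stored
clusters and basebackups."""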

from boto3 import client
from datetime import datetime, timezone
from furl import furl
from json import dumps, loads
from os import environ, getenv
from requests import Session
from urllib.parse import urljoin
from uuid import UUID
from wal_e.cmd import configure_backup_cxt
from .utils import Attrs, defaulting, these
from operator_ui.adapters.logger import logger

session = Session()

AWS_ENDPOINT = getenv('AWS_ENDPOINT')

OPERATOR_CLUSTER_NAME_LABEL = getenv('OPERATOR_CLUSTER_NAME_LABEL', 'cluster-name')
COMMON_CLUSTER_LABEL = getenv('COMMON_CLUSTER_LABEL', '{"application":"spilo"}')
COMMON_POOLER_LABEL = getenv('COMMON_POOLER_LABEL', '{"application":"db-connection-pooler"}')

logger.info("Common Cluster Label: {}".format(COMMON_CLUSTER_LABEL))
logger.info("Common Pooler Label: {}".format(COMMON_POOLER_LABEL))

COMMON_CLUSTER_LABEL = loads(COMMON_CLUSTER_LABEL)
COMMON_POOLER_LABEL = loads(COMMON_POOLER_LABEL)


def request(cluster, path, **kwargs):
    if 'timeout' not in kwargs:
        # sane default timeout
        kwargs['timeout'] = (5, 15)

    if cluster.cert_file and cluster.key_file:
        kwargs['cert'] = (cluster.cert_file, cluster.key_file)

    return session.get(
        urljoin(cluster.api_server_url, path),
        auth=cluster.auth,
        verify=cluster.ssl_ca_cert,
        **kwargs
    )


def request_post(cluster, path, data, **kwargs):
    if 'timeout' not in kwargs:
        # sane default timeout
        kwargs['timeout'] = 5

    if cluster.cert_file and cluster.key_file:
        kwargs['cert'] = (cluster.cert_file, cluster.key_file)

    return session.post(
        urljoin(cluster.api_server_url, path),
        data=data,
        auth=cluster.auth,
        verify=cluster.ssl_ca_cert,
        **kwargs
    )


def request_put(cluster, path, data, **kwargs):
    if 'timeout' not in kwargs:
        # sane default timeout
        kwargs['timeout'] = 5

    if cluster.cert_file and cluster.key_file:
        kwargs['cert'] = (cluster.cert_file, cluster.key_file)

    return session.put(
        urljoin(cluster.api_server_url, path),
        data=data,
        auth=cluster.auth,
        verify=cluster.ssl_ca_cert,
        **kwargs
    )


def request_delete(cluster, path, **kwargs):
    if 'timeout' not in kwargs:
        # sane default timeout
        kwargs['timeout'] = 5

    if cluster.cert_file and cluster.key_file:
        kwargs['cert'] = (cluster.cert_file, cluster.key_file)

    return session.delete(
        urljoin(cluster.api_server_url, path),
        auth=cluster.auth,
        verify=cluster.ssl_ca_cert,
        **kwargs
    )


def resource_api_version(resource_type):
    return {
        'postgresqls': 'apis/acid.zalan.do/v1',
        'statefulsets': 'apis/apps/v1',
        'deployments': 'apis/apps/v1',
    }.get(resource_type, 'api/v1')


def encode_labels(label_selector):
    return ','.join([
        f'{label}={value}'
        for label, value in label_selector.items()
    ])


def cluster_labels(spilo_cluster):
    # copy so the shared COMMON_CLUSTER_LABEL dict is not mutated
    labels = dict(COMMON_CLUSTER_LABEL)
    labels[OPERATOR_CLUSTER_NAME_LABEL] = spilo_cluster
    return labels


def kubernetes_url(
    resource_type,
    namespace='default',
    resource_name=None,
    label_selector=None,
):
    return furl('/').add(
        path=(
            resource_api_version(resource_type).split('/')
            + (
                ['namespaces', namespace]
                if namespace
                else []
            )
            + [resource_type]
            + (
                [resource_name]
                if resource_name
                else []
            )
        ),
        args=(
            {'labelSelector': encode_labels(label_selector)}
            if label_selector
            else {}
        ),
    ).url


def kubernetes_get(cluster, **kwargs):
    response = request(cluster, kubernetes_url(**kwargs))

    if response.status_code == 404:
        return None

    if response.status_code >= 400:
        response.raise_for_status()

    return response.json()


def read_pods(cluster, namespace, spilo_cluster):
    return kubernetes_get(
        cluster=cluster,
        resource_type='pods',
        namespace=namespace,
        label_selector=cluster_labels(spilo_cluster),
    )


def read_pod(cluster, namespace, resource_name):
    return kubernetes_get(
        cluster=cluster,
        resource_type='pods',
        namespace=namespace,
        resource_name=resource_name,
        label_selector=COMMON_CLUSTER_LABEL,
    )


def read_service(cluster, namespace, resource_name):
    return kubernetes_get(
        cluster=cluster,
        resource_type='services',
        namespace=namespace,
        resource_name=resource_name,
        label_selector=COMMON_CLUSTER_LABEL,
    )


def read_pooler(cluster, namespace, resource_name):
    return kubernetes_get(
        cluster=cluster,
        resource_type='deployments',
        namespace=namespace,
        resource_name=resource_name,
        label_selector=COMMON_POOLER_LABEL,
    )


def read_statefulset(cluster, namespace, resource_name):
    return kubernetes_get(
        cluster=cluster,
        resource_type='statefulsets',
        namespace=namespace,
        resource_name=resource_name,
        label_selector=COMMON_CLUSTER_LABEL,
    )


def read_postgresql(cluster, namespace, resource_name):
    return kubernetes_get(
        cluster=cluster,
        resource_type='postgresqls',
        namespace=namespace,
        resource_name=resource_name,
    )


def read_postgresqls(cluster, namespace):
    return kubernetes_get(
        cluster=cluster,
        resource_type='postgresqls',
        namespace=namespace,
    )


def read_namespaces(cluster):
    return kubernetes_get(
        cluster=cluster,
        resource_type='namespaces',
        namespace=None,
    )


def create_postgresql(cluster, namespace, definition):
    path = kubernetes_url(
        resource_type='postgresqls',
        namespace=namespace,
    )

    try:
        r = request_post(cluster, path, dumps(definition))
        r.raise_for_status()
        return True
    except Exception:
        logger.exception("K8s create request failed")
        return False


def apply_postgresql(cluster, namespace, resource_name, definition):
    path = kubernetes_url(
        resource_type='postgresqls',
        namespace=namespace,
        resource_name=resource_name,
    )

    try:
        r = request_put(cluster, path, dumps(definition))
        r.raise_for_status()
        return True
    except Exception:
        logger.exception("K8s apply request failed")
        return False


def remove_postgresql(cluster, namespace, resource_name):
    path = kubernetes_url(
        resource_type='postgresqls',
        namespace=namespace,
        resource_name=resource_name,
    )

    try:
        r = request_delete(cluster, path)
        r.raise_for_status()
        return True
    except Exception:
        logger.exception("K8s delete request failed")
        return False


def read_stored_clusters(bucket, prefix, delimiter='/'):
    return [
        prefix['Prefix'].split('/')[-2]
        for prefix in these(
            client('s3', endpoint_url=AWS_ENDPOINT).list_objects(
                Bucket=bucket,
                Delimiter=delimiter,
                Prefix=prefix,
            ),
            'CommonPrefixes',
        )
    ]


def read_versions(
    pg_cluster,
    bucket,
    s3_endpoint,
    prefix,
    delimiter='/',
    use_aws_instance_profile=False,
):
    return [
        'base' if uid == 'wal' else uid
        for prefix in these(
            client('s3', endpoint_url=AWS_ENDPOINT).list_objects(
                Bucket=bucket,
                Delimiter=delimiter,
                Prefix=prefix + pg_cluster + delimiter,
            ),
            'CommonPrefixes',
        )
        for uid in [prefix['Prefix'].split('/')[-2]]
        if uid == 'wal' or defaulting(lambda: UUID(uid))
    ]
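

# PostgreSQL major-version sub-prefixes under the cluster's .../wal/ path;
# the empty entry covers basebackups stored without a version sub-directory.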
BACKUP_VERSION_PREFIXES = ['', '10/', '11/', '12/', '13/', '14/', '15/', '16/']


def read_basebackups(
    pg_cluster,
    uid,
    bucket,
    s3_endpoint,
    prefix,
    delimiter='/',
    use_aws_instance_profile=False,
):
    environ['WALE_S3_ENDPOINT'] = s3_endpoint

    suffix = '' if uid == 'base' else '/' + uid
    backups = []

    for vp in BACKUP_VERSION_PREFIXES:
        backups = backups + [
            {
                key: value
                for key, value in basebackup.__dict__.items()
                if isinstance(value, (str, int))
            }
            for basebackup in Attrs.call(
                f=configure_backup_cxt,
                aws_instance_profile=use_aws_instance_profile,
                s3_prefix=f's3://{bucket}/{prefix}{pg_cluster}{suffix}/wal/{vp}',
            )._backup_list(detail=True)
        ]

    return backups


def parse_time(s: str):
    return (
        datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ')
        .replace(tzinfo=timezone.utc)
        .timestamp()
    )