# ssm/ssm.py
#
# 208 lines
# 5.9 KiB
# Python

from __future__ import annotations

import os
import pathlib
import sys
import typing

import boto3
import botocore.client
import botocore.exceptions
import botocore.session
import rich.console
import rich.table
import rich.text
import typer
# Typer application; every command below registers itself on it.
app = typer.Typer()

# Regular output is written to stdout; errors and warnings go to stderr.
stdout = rich.console.Console()
stderr = rich.console.Console(stderr=True)

# Profile names from the local AWS configuration, read once at import time.
aws_profiles = boto3.Session().available_profiles
class SSMPath(pathlib.PurePosixPath):
    """A POSIX-style path that is always anchored at ``/``.

    Segments given with or without a leading slash end up rooted, e.g.
    ``SSMPath("a", "b")`` renders as ``/a/b``.
    """

    def __init__(self, *parts: str | os.PathLike):
        # NOTE(review): forwarding segments to PurePath.__init__ relies on the
        # Python 3.12+ pathlib implementation (older versions build paths in
        # __new__ and reject extra __init__ arguments).
        anchored = ("/", *parts)
        super().__init__(*anchored)

    @staticmethod
    def parse(value: str) -> SSMPath:
        """Build a rooted SSMPath from a raw string such as a CLI argument."""
        return SSMPath(value)
def _complete_profile(incomplete: str) -> list[str]:
    """Return known AWS profile names starting with ``incomplete``.

    Used as the shell-completion callback for profile arguments; typer passes
    the partially typed word as ``incomplete``.
    """
    return [name for name in aws_profiles if name.startswith(incomplete)]


# Reusable annotation for CLI arguments that name an AWS profile.
# typer's ``autocompletion`` must be a callable (typer inspects its signature
# to decide what to pass it); handing it the profile list directly breaks.
ProfileArg = typing.Annotated[str, typer.Argument(autocompletion=_complete_profile)]
def get_client(profile: str) -> botocore.client.BaseClient:
    """Create an SSM client for the named AWS profile.

    On failure an error is printed to stderr and the process exits with
    status 1. This happens when the profile does not exist, or when no
    region can be resolved for the client.
    """
    try:
        session = boto3.Session(profile_name=profile)
        return session.client("ssm")
    except botocore.exceptions.ProfileNotFound:
        # Canonical home of the exception; botocore.session only re-exports it.
        stderr.print(f"Invalid profile '{profile}'", style="bold red")
        sys.exit(1)
    except botocore.exceptions.NoRegionError:
        stderr.print(
            f"No AWS region configured for profile '{profile}'",
            style="bold red",
        )
        sys.exit(1)
def get_parameters(
    client: botocore.client.BaseClient, path: SSMPath, recursive: bool
) -> typing.Iterable[dict]:
    """Yield the SSM parameter dicts found at ``path``.

    If ``path`` names a single parameter, only that parameter is yielded.
    Otherwise every parameter returned by ``get_parameters_by_path`` under
    ``path`` is yielded, with ``recursive`` passed through as ``Recursive``.
    Both lookups request ``WithDecryption=True``, so SecureString values are
    returned decrypted either way.

    Entries without a ``Name`` are skipped, because callers key on it.
    """
    try:
        # Decrypt here too; otherwise a single SecureString parameter would
        # come back as ciphertext while the by-path listing is decrypted.
        response = client.get_parameter(Name=str(path), WithDecryption=True)
    except client.exceptions.ParameterNotFound:
        pass
    else:
        yield response.get("Parameter")
        return

    pages = client.get_paginator("get_parameters_by_path").paginate(
        Path=str(path),
        Recursive=recursive,
        WithDecryption=True,
    )
    for page in pages:
        # Tolerate a page without (or with a null) "Parameters" entry.
        for parameter in page.get("Parameters") or []:
            if parameter.get("Name"):
                yield parameter
@app.command()
def profiles() -> None:
    """List available AWS profiles."""
    # One profile name per line, in the order they were read at import time.
    names = aws_profiles
    for name in names:
        print(name)
@app.command("ls")
def list_parameters(
    profile: ProfileArg,
    path: pathlib.Path,
    recursive: bool = True,
) -> None:
    """List parameters and their values at a requested PATH."""
    # Output is a rich table titled with the rooted path and profile. Names are
    # shown relative to PATH. AWS client errors are printed to stderr and the
    # command exits with status 1.
    root = SSMPath(path)
    client = get_client(profile)
    table = rich.table.Table(
        "Name",
        "Value",
        "Description",
        "Type",
        title=f"{root} ({profile})",
    )
    try:
        for parameter in get_parameters(client, root, recursive):
            # PurePosixPath keeps "/" separators on every OS; pathlib.Path
            # would render the relative name with "\\" on Windows.
            name = pathlib.PurePosixPath(parameter["Name"]).relative_to(root)
            table.add_row(
                str(name),
                parameter.get("Value"),
                # NOTE(review): GetParameter/GetParametersByPath responses do
                # not appear to include "Description" (DescribeParameters
                # does); this column may always be empty — confirm.
                parameter.get("Description"),
                parameter.get("Type"),
            )
    except client.exceptions.ClientError as e:
        stderr.print(str(e), style="bold red")
        sys.exit(1)
    stdout.print(table)
@app.command()
def set(
    profile: ProfileArg,
    path: str,
    value: str,
    secure: bool = False,
    overwrite: bool = False,
    description: typing.Annotated[str, typer.Option()] = "",
) -> None:
    """Set a parameter at PATH to VALUE.
    If --secure is used, it will be stored as a SecureString.
    """
    # The function name is also the CLI command name, which is why it
    # shadows the builtin ``set`` at module scope.
    # PATH is normalised through SSMPath, so "a/b" and "/a/b" are the same.
    client = get_client(profile)
    try:
        client.put_parameter(
            Name=str(SSMPath(path)),
            Value=value,
            Type="SecureString" if secure else "String",
            Overwrite=overwrite,
            Description=description,
        )
    except client.exceptions.ParameterAlreadyExists:
        stderr.print(
            "Parameter already exists; use --overwrite to replace it.",
            style="bold red",
        )
        sys.exit(1)
    except client.exceptions.ClientError as e:
        # Other AWS failures are reported like `ls` reports them, rather than
        # surfacing a traceback. This handler must come after the more
        # specific ParameterAlreadyExists handler above.
        stderr.print(str(e), style="bold red")
        sys.exit(1)
@app.command()
def unset(profile: ProfileArg, path: str) -> None:
    """Remove a parameter at PATH."""
    # PATH is normalised through SSMPath, so "a/b" and "/a/b" are the same.
    client = get_client(profile)
    try:
        client.delete_parameter(Name=str(SSMPath(path)))
    except client.exceptions.ParameterNotFound:
        stderr.print("Parameter not found", style="yellow")
        sys.exit(1)
    except client.exceptions.ClientError as e:
        # Other AWS failures are reported like `ls` reports them, rather than
        # surfacing a traceback.
        stderr.print(str(e), style="bold red")
        sys.exit(1)
@app.command("cp")
def copy_tree(
    source_profile: ProfileArg,
    source_path: str,
    dest_path: str,
    dest_profile: ProfileArg | None = None,
    replacement_pairs: list[str] = typer.Option([], "--replace", "-r"),
    recursive: bool = True,
) -> None:
    """Copy parameters from SOURCE_PATH to DEST_PATH."""
    # Each --replace OLD=NEW pair is applied, in order, to every copied value.
    # A preview table of old/new values is printed and confirmation is
    # required. Parameters are then written with Overwrite=True, keeping the
    # source's Type. When --dest-profile is omitted, the source profile is
    # used for both ends.

    def parse_replacements(values: list[str]) -> list[tuple[str, str]]:
        # Split on the first "=" only, so NEW may itself contain "=".
        pairs: list[tuple[str, str]] = []
        for value in values:
            old, sep, new = value.partition("=")
            if not sep or not old:
                # An empty OLD would make str.replace insert NEW between
                # every character, so reject it along with a missing "=".
                raise typer.BadParameter(
                    f"expected OLD=NEW with a non-empty OLD, got '{value}'",
                    param_hint="--replace",
                )
            pairs.append((old, new))
        return pairs

    replacements = parse_replacements(replacement_pairs)

    def replace(value: str) -> str:
        # Applied sequentially: later pairs see the output of earlier ones.
        for old, new in replacements:
            value = value.replace(old, new)
        return value

    source_root = SSMPath(source_path)
    dest_root = SSMPath(dest_path)
    source = get_client(source_profile)
    destination = get_client(dest_profile or source_profile)

    try:
        # Key both sides by name relative to their root so they line up.
        # PurePosixPath keeps "/" separators on every OS.
        sources = {
            str(pathlib.PurePosixPath(p["Name"]).relative_to(source_root)): p
            for p in get_parameters(source, source_root, recursive=recursive)
        }
        targets = {
            str(pathlib.PurePosixPath(p["Name"]).relative_to(dest_root)): p
            for p in get_parameters(destination, dest_root, recursive=recursive)
        }
    except botocore.exceptions.ClientError as e:
        stderr.print(str(e), style="bold red")
        sys.exit(1)

    # Compute each new value once, so the preview and the write cannot differ.
    planned = {name: replace(param["Value"]) for name, param in sources.items()}

    table = rich.table.Table("Path", "Old Value", "New Value")
    for name, new in planned.items():
        target = targets.get(name)
        old = target["Value"] if target is not None else None
        table.add_row(
            name,
            rich.text.Text(old, style="red")
            if old is not None
            else rich.text.Text("Not defined", style="bright_black italic"),
            rich.text.Text(new, style="green"),
        )
    stdout.print(table)

    confirmed = typer.confirm("Are you sure you want to apply the above changes?")
    if not confirmed:
        stderr.print("No changes applied", style="yellow italic")
        sys.exit(1)

    for name, param in sources.items():
        # A relative name of "." (SOURCE_PATH was itself a parameter) joins
        # to DEST_PATH itself.
        new_name = str(dest_root / name)
        stdout.print(f"Writing {new_name}...")
        try:
            destination.put_parameter(
                Name=new_name,
                Value=planned[name],
                Type=param["Type"],
                Overwrite=True,
                Description=param.get("Description", ""),
            )
        except botocore.exceptions.ClientError as e:
            # Parameters written before this failure are left in place.
            stderr.print(str(e), style="bold red")
            sys.exit(1)
if __name__ == "__main__":
    # Running the module directly hands control to the Typer CLI.
    app()