Flask and Vault


When using Vault's dynamic database credentials with Flask, we need to make sure that the Flask instance picks up the right credentials, renews them when necessary, and uses the right database role.

My Flask code is closely intertwined with the database changes here, so pardon the clutter, but I think it's relatively easy to follow.

Configuration parameters are either from the config file or they are taken from environment variables.

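| Parameter | Required | Purpose | Default |
|---|---|---|---|
| VAULT_ROLE | No | Vault dynamic database role to request credentials from | None |
| DB_ROLE | No | Database role to assume (`SET ROLE`) on each new connection | None |
| SQLALCHEMY_DATABASE_URI | Yes | URI for the database | no default |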

The application below is named telemetry_ingest and uses TELEMETRY_INGEST as the prefix for any environment variables used for configuration; for example, VAULT_ROLE is read from the TELEMETRY_INGEST_VAULT_ROLE environment variable. This mostly matters if you adapt this code elsewhere, since you need to remember to change that prefix.

Vault use is triggered by the presence of the VAULT_ROLE parameter, since Vault credentials may or may not be needed depending on the environment. If the Vault connection settings are present, this code passes them to the libraries; otherwise they come through as None and hvac falls back to its own defaults (from the environment or built in).

Authentication data is stored in the auth global in this module and is initialized when the application starts. The logic to get and renew the authentication data is in get_vault_credentials().

Of particular interest is the event handling at the bottom, in the with app.app_context() stanza. This adds event handlers for do_connect (called before a database connection is opened, so we can inject the current credentials), checkout (called when a connection is checked out of the pool, where we refresh the credentials once they are due for renewal), and connect (where we set the database role if one was requested). Finally, the standard setup is done: the blueprints are registered and the database tables are created.

import datetime
import os
from typing import Optional

import hvac
from flask import Flask
from sqlalchemy import event
from sqlalchemy.exc import DisconnectionError

# Module-wide Vault credential state read by the engine event handlers.
# create_app() overwrites it with get_vault_credentials()'s result, which is
# a dict of user/password/lease data, or None when VAULT_ROLE is not set.
auth: Optional[dict] = dict()


def create_app(test_config=None):
    """Create and configure the telemetry_ingest Flask application.

    Configuration is loaded from ``telemetry_ingest.default_settings`` and
    then either from ``TELEMETRY_INGEST_*`` environment variables or, when
    given, from ``test_config``.  When ``VAULT_ROLE`` is configured, database
    credentials are generated from Vault and refreshed by SQLAlchemy engine
    event handlers.  When ``DB_ROLE`` is configured, every new DBAPI
    connection issues ``SET ROLE`` for it.

    Args:
        test_config: Optional mapping used instead of the prefixed
            environment variables (for tests).

    Returns:
        The configured ``Flask`` app with its blueprints registered and the
        database tables created.

    Raises:
        RuntimeError: If Vault is in use and the Vault client is not
            authenticated.
    """
    app = Flask(__name__)
    app.config.from_object("telemetry_ingest.default_settings")
    if test_config is None:
        # To read a settings .py file instead of prefixed variables, use:
        #     app.config.from_envvar("TELEMETRY_INGEST_SETTINGS")
        app.config.from_prefixed_env("TELEMETRY_INGEST")
    else:
        # Load the test config if passed in.
        app.config.from_mapping(test_config)

    from telemetry_ingest.models import db

    db.init_app(app)

    def requested_role() -> Optional[str]:
        """Return the configured DB_ROLE to assume on connect, or None."""
        return app.config.get("DB_ROLE")

    def requested_credential() -> Optional[str]:
        """Return the configured VAULT_ROLE to request credentials for, or None."""
        return app.config.get("VAULT_ROLE")

    def use_vault() -> bool:
        """Return True when VAULT_ROLE is set, i.e. Vault supplies credentials."""
        return requested_credential() is not None

    def get_vault_credentials(existing: Optional[dict] = None) -> Optional[dict]:
        """Return current Vault credentials, renewing ``existing`` if possible.

        Returns None when Vault is not configured.  If ``existing`` holds a
        renewable lease that has not yet expired, the lease is renewed and
        ``existing`` is updated in place with new expiry/renewal times (the
        username and password are unchanged).  If renewal is not possible or
        fails, a new set of credentials is generated for VAULT_ROLE.

        The returned dict has keys ``user``, ``password``, ``vault_lease_id``,
        ``vault_expire`` (lease expiry), ``vault_renew`` (the point after
        which the checkout handler refreshes) and ``response`` (the raw Vault
        reply from credential generation).

        Raises:
            RuntimeError: If the Vault client is not authenticated.
        """
        if not use_vault():
            return None
        # .get() passes None for unset variables so hvac applies its own
        # defaults instead of this code failing with a KeyError.
        client = hvac.Client(
            url=os.environ.get("VAULT_ADDR"), token=os.environ.get("VAULT_TOKEN")
        )
        # Explicit check rather than `assert`, which `python -O` strips out.
        if not client.is_authenticated():
            raise RuntimeError("Vault client is not authenticated")

        now = datetime.datetime.now()
        if (
            existing is not None
            and existing["response"]["renewable"]
            and now < existing["vault_expire"]
        ):
            try:
                renew_response = client.sys.renew_lease(existing["vault_lease_id"])
            except hvac.exceptions.VaultError:
                # Fall through and generate a brand-new set of credentials.
                app.logger.debug("lease renewal failed")
            else:
                lease_seconds = renew_response["lease_duration"]
                # Updated in place so other references to this dict see the
                # new timestamps; user/password are still valid.
                existing["vault_expire"] = now + datetime.timedelta(
                    seconds=lease_seconds
                )
                existing["vault_renew"] = now + datetime.timedelta(
                    seconds=lease_seconds / 2
                )
                return existing

        read_response = client.secrets.database.generate_credentials(
            requested_credential()
        )
        app.logger.debug("new lease")
        issued = datetime.datetime.now()
        lease_seconds = read_response["lease_duration"]
        return {
            "user": read_response["data"]["username"],
            "password": read_response["data"]["password"],
            "vault_lease_id": read_response["lease_id"],
            "vault_expire": issued + datetime.timedelta(seconds=lease_seconds),
            # Halfway through the lease, so there is slack before it expires.
            "vault_renew": issued + datetime.timedelta(seconds=lease_seconds / 2),
            "response": read_response,
        }

    global auth
    auth = get_vault_credentials()

    with app.app_context():
        # https://docs.sqlalchemy.org/en/20/core/engines.html#custom-dbapi-args
        @event.listens_for(db.engine, "do_connect")
        def provide_credentials(dialect, conn_rec, cargs, cparams):
            """Inject the current Vault username/password into a new connection."""
            global auth
            if use_vault():
                cparams["user"] = auth["user"]
                cparams["password"] = auth["password"]

        @event.listens_for(db.engine, "checkout")
        def validate_checkout(dbapi_connection, connection_record, connection_proxy):
            """Refresh the credentials once they pass their renewal point.

            If the refresh produced different credentials, DisconnectionError
            makes the pool discard this connection and retry the checkout with
            a new connection, which provide_credentials opens with the fresh
            user/password.  A plain lease renewal keeps the same
            user/password, so the connection is kept.
            """
            global auth
            if not use_vault():
                return
            if datetime.datetime.now() > auth["vault_renew"]:
                app.logger.debug("credentials expired")
                previous = (auth["user"], auth["password"])
                auth = get_vault_credentials(auth)
                if (auth["user"], auth["password"]) != previous:
                    raise DisconnectionError()

        @event.listens_for(db.engine, "connect")
        def set_role_on_connect(dbapi_connection, connection_record):
            """Assume DB_ROLE on every new DBAPI connection, if configured."""
            role = requested_role()
            if role is None:
                return
            # Double any embedded single quotes so the role name remains one
            # SQL string literal regardless of its contents.
            quoted_role = "'" + str(role).replace("'", "''") + "'"
            with dbapi_connection.cursor() as cursor:
                cursor.execute("SET ROLE " + quoted_role)
            # Without autocommit the SET runs inside an implicit transaction;
            # commit it so a later rollback (e.g. the pool's reset-on-return)
            # does not silently undo the role change.
            dbapi_connection.commit()

        from telemetry_ingest.routes import telemetry, redirection

        app.register_blueprint(telemetry)
        app.register_blueprint(redirection)
        db.create_all()
        return app