To add multi-tenancy to an existing application I settled on the following:
- Create a `tenant` table which holds a row for every tenant. Tenants are externally identified by a unique key.
- Add a `tenant_id` column to every existing table which references `tenant(id)`.
- Create a new role (`application`) without superuser privileges and grant it all privileges on all tables. Do not make it the owner: table owners bypass row level security unless `FORCE ROW LEVEL SECURITY` is set.
- Create a row level security policy on each table to check that the `tenant_id` column is equal to `current_setting('application.tenant.current_id')`.
- Make sure the role is set to `application` and `application.tenant.current_id` is set before performing any queries.
To set up the database:
-- Set up application role
-- Creates the non-superuser `application` role (once) and grants it access to every existing
-- user schema, its tables and sequences, plus default privileges for objects created later.
do
$$
declare
    current_db    text;
    target_schema name;
begin
    -- Create the role only if it does not exist yet, so the script can be re-run.
    if not exists(select oid from pg_roles where rolname = 'application') then
        create role application with nologin;
        select current_database() into current_db;
        execute format('grant all privileges on database %I to application', current_db);
    end if;
    -- Without FOR ROLE, default privileges apply to objects later created by the role running this script.
    execute 'alter default privileges grant usage on schemas to application';
    execute 'alter default privileges grant create on schemas to application';
    -- Grant per schema, one statement at a time: an empty schema list is a no-op (instead of
    -- EXECUTE on a NULL string built by string_agg), and a failing grant is reported on its own.
    for target_schema in
        select nspname
        from pg_namespace
        where nspname <> 'information_schema' -- exclude information schema and ...
          and nspname not like 'pg\_%'        -- ... system schemas
    loop
        execute format('grant usage on schema %1$I to %2$I', target_schema, 'application');
        execute format('grant create on schema %1$I to %2$I', target_schema, 'application');
        execute format('grant all privileges on all tables in schema %1$I to %2$I', target_schema, 'application');
        execute format('grant all privileges on all sequences in schema %1$I to %2$I', target_schema, 'application');
        execute format('alter default privileges in schema %1$I grant all privileges on tables to %2$I',
                       target_schema, 'application');
        execute format('alter default privileges in schema %1$I grant all privileges on sequences to %2$I',
                       target_schema, 'application');
    end loop;
end;
$$ language plpgsql;
-- Utility function to set the role from application code
-- Switches the current role from application code.
--   role:  role name to switch to; NULL issues `set ... role none`.
--   local: true -> `set local role` (transaction-scoped); false or NULL -> `set session role`.
create or replace function set_role(role text, local boolean)
    returns void as
$$
declare
    -- Scope keyword: anything other than an explicit true falls back to the session scope.
    scope_keyword text := case when local then 'local' else 'session' end;
    -- quote_ident yields the same quoting as format's %I; NULL role maps to the keyword `none`.
    role_target   text := coalesce(quote_ident(role), 'none');
begin
    execute format('set %s role %s', scope_keyword, role_target);
end;
$$ language plpgsql;
-- One row per tenant; `key` is the unique external identifier of a tenant.
-- (The original definition had a trailing comma after the last column, which PostgreSQL rejects.)
create table tenant
(
    id  bigint primary key,
    key text   not null unique
);
-- Create the default tenant (id 1); existing rows are backfilled with this id when tenant_id is added.
insert into tenant(id, key)
values (1, 'default');
-- Helper functions to get and set tenant from application code
-- Returns the tenant id stored in the `application.tenant.current_id` setting.
-- Returns NULL when the setting is absent (missing_ok = true) or empty; a non-numeric value raises a cast error.
create or replace function get_current_tenant()
    returns bigint as
$$
declare
    raw_tenant text := current_setting('application.tenant.current_id', true);
begin
    if raw_tenant is null or raw_tenant = '' then
        return null;
    end if;
    return raw_tenant::bigint;
end
$$ language plpgsql stable;
-- Stores tenant_id in the `application.tenant.current_id` setting with is_local = true,
-- so the value only lasts until the end of the current transaction.
-- NOTE(review): a NULL tenant_id is passed to set_config as NULL, which presumably resets the setting - confirm.
create or replace function set_current_tenant(tenant_id bigint)
    returns void as
$$
declare
    tenant_text text := tenant_id::text;
begin
    perform set_config('application.tenant.current_id', tenant_text, true);
end
$$ language plpgsql volatile;
-- Add tenant id to (almost) all tables and enable row level security
-- For every base table in `public` (except tenant and schema_version): add a NOT NULL tenant_id
-- referencing tenant(id), index it, enable row level security and add a tenant isolation policy.
do
$$
declare
    item record;
begin
    for item in (
        select table_schema, table_name
        from information_schema.tables
        where table_schema = 'public'
          and table_type = 'BASE TABLE' -- views/foreign tables cannot take ADD COLUMN or ENABLE ROW LEVEL SECURITY
          and table_name != all (array ['tenant', 'schema_version'])
          and table_name not like 'pg\_%' -- postgres specific tables
        )
    loop
        -- Existing rows are backfilled with the default tenant (1); afterwards new rows default to
        -- the tenant of the current transaction.
        execute format('alter table %1$I.%2$I add column tenant_id bigint default 1 references tenant(id);',
                       item.table_schema, item.table_name);
        execute format('alter table %1$I.%2$I alter column tenant_id set default get_current_tenant();',
                       item.table_schema, item.table_name);
        execute format('alter table %1$I.%2$I alter column tenant_id set not null;',
                       item.table_schema, item.table_name);
        execute format('create index on %1$I.%2$I(tenant_id);', item.table_schema, item.table_name);
        -- NOTE(review): RLS is enabled but not forced, so the table owner still bypasses the policy.
        execute format('alter table %1$I.%2$I enable row level security;', item.table_schema, item.table_name);
        -- No WITH CHECK clause: the USING expression is also applied to inserted/updated rows.
        execute format('create policy tenant on %1$I.%2$I for all to public using (tenant_id = get_current_tenant());',
                       item.table_schema, item.table_name);
    end loop;
end;
$$ language plpgsql;
-- Recreate all indices (in particular unique indices) to include tenant_id as the second column
--- ...omitted...
To make sure the role and tenant id are set for every transaction, I added a custom `PlatformTransactionManager` that delegates to the original `PlatformTransactionManager` and runs my transaction initialization code whenever a new transaction is started (Kotlin/Spring Boot/jOOQ):
/**
 * One step that prepares a freshly started transaction before application code runs in it,
 * e.g. switching the database role or setting the current tenant.
 */
fun interface TransactionInitializer {
    /** Performs the step; any exception propagates to the caller. */
    operator fun invoke(): Unit
}
/**
 * A [PlatformTransactionManager] that delegates all work to [baseTransactionManager]. Whenever the
 * delegate reports a new transaction, it runs [initializer] before handing the status back.
 *
 * If [initializer] throws, the freshly started transaction is rolled back and the initializer's exception
 * is rethrown. A failure during that rollback is attached to it as a suppressed exception rather than
 * replacing it.
 *
 * NOTE(review): the initializer only runs when [TransactionStatus.isNewTransaction] is true. Calls that
 * join an existing transaction are not initialized again. Presumably, calls that run without an actual
 * transaction are not initialized at all - confirm for the propagation modes in use.
 */
open class TransactionInitializingTransactionManager(
    private val baseTransactionManager: PlatformTransactionManager,
    private val initializer: TransactionInitializer,
) : PlatformTransactionManager {
    override fun getTransaction(transactionDefinition: TransactionDefinition): TransactionStatus {
        val status = baseTransactionManager.getTransaction(transactionDefinition)
        if (status.isNewTransaction) {
            try {
                initializer()
            } catch (t: Throwable) {
                try {
                    rollback(status)
                } catch (rollbackFailure: Throwable) {
                    // Keep the initializer failure as the primary error; it is why we rolled back.
                    t.addSuppressed(rollbackFailure)
                }
                throw t
            }
        }
        return status
    }

    override fun commit(status: TransactionStatus) {
        baseTransactionManager.commit(status)
    }

    override fun rollback(status: TransactionStatus) {
        baseTransactionManager.rollback(status)
    }
}
/**
 * Supplies the database role for the current transaction.
 * A `null` result is passed to the database as a NULL role (for the SQL `set_role` function that means `set ... role none`).
 */
typealias RoleSupplier = () -> String?

/**
 * Switches the database role for the current transaction to the value returned by [roleSupplier].
 *
 * Annotated with `@Tx` like the other initializers in this file. The original used `@Transaction`,
 * which is not a Spring annotation.
 * NOTE(review): the annotation only takes effect on a Spring-managed proxy; makeTransactionInitializer
 * constructs this class directly.
 */
@Tx(propagation = Propagation.MANDATORY)
open class RoleSettingTransactionInitializer(
    private val ctx: DSLContext,
    private val roleSupplier: RoleSupplier,
) : TransactionInitializer {
    override fun invoke() {
        val role = roleSupplier()
        // Third argument `true` = local: presumably maps to `set local role`, scoped to the current transaction.
        Routines.setRole(ctx.configuration(), role, true)
    }
}
/** Supplies the id of the tenant the current transaction is scoped to; `null` is passed through unchanged. */
typealias TenantIdSupplier = () -> Long?

/**
 * Sets the current tenant for the running transaction to the id returned by [tenantIdSupplier],
 * via the generated `Routines.setCurrentTenant` wrapper.
 */
@Tx(propagation = Propagation.MANDATORY)
open class TenantByIdSettingTransactionInitializer(
    private val ctx: DSLContext,
    private val tenantIdSupplier: TenantIdSupplier,
) : TransactionInitializer {
    override fun invoke() {
        Routines.setCurrentTenant(ctx.configuration(), tenantIdSupplier())
    }
}
/**
 * Runs each of [initializers] in list order.
 * An exception from one step propagates immediately, and the remaining steps are skipped.
 */
@Tx(propagation = Propagation.MANDATORY)
open class CompositeTransactionInitializer(
    private val initializers: List<TransactionInitializer>,
) : TransactionInitializer {
    override fun invoke() {
        initializers.forEach { step -> step() }
    }
}
/**
 * Builds the initializer run at the start of every new transaction.
 *
 * Steps are included only for non-null suppliers and run in this order: role first, then tenant.
 * With both suppliers `null`, the result is a no-op initializer.
 */
fun makeTransactionInitializer(
    ctx: DSLContext,
    roleSupplier: RoleSupplier?,
    tenantIdSupplier: TenantIdSupplier?,
): TransactionInitializer {
    val steps: List<TransactionInitializer> = listOfNotNull(
        roleSupplier?.let { supplier ->
            RoleSettingTransactionInitializer(ctx = ctx, roleSupplier = supplier)
        },
        tenantIdSupplier?.let { supplier ->
            TenantByIdSettingTransactionInitializer(ctx = ctx, tenantIdSupplier = supplier)
        },
    )
    return CompositeTransactionInitializer(initializers = steps)
}
To make sure nobody forgets to add an RLS policy to a newly created table, I added a test that introspects the schema and ensures that every table (except for some that have been explicitly excluded) has an RLS policy applied.