Following my previous post about setting up a function-level database setup, which is a junior-level solution, we'll now look at a session-level database migration setup.

When starting a DB-coupled application, one of the first goals is to set up a DB connection function or class that spawns a reliable asynchronous connection to our DB.

Let us cover the most important parts of our setup.

The DB session manager class:


from sqlalchemy.ext.asyncio import (
    async_sessionmaker,
    create_async_engine,
    AsyncEngine,
    AsyncSession,
)


class DBSessionManager:
    def __init__(self, postgres_dsn: str):
        self._engine: AsyncEngine = create_async_engine(url=postgres_dsn)
        self._async_sessionmaker = async_sessionmaker(
            bind=self._engine, expire_on_commit=False
        )

    @property
    def asessionmaker(self) -> async_sessionmaker[AsyncSession]:
        return self._async_sessionmaker

    async def close(self):
        await self._engine.dispose()
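
To get a feel for how the manager is consumed by application code, here is a minimal sketch; the check_db_alive helper is hypothetical and only illustrates opening a session from the shared sessionmaker:


from sqlalchemy import text


async def check_db_alive(db_manager: DBSessionManager) -> bool:
    # Open a short-lived session from the shared sessionmaker
    # and run a trivial query to verify connectivity.
    async with db_manager.asessionmaker() as session:
        result = await session.execute(text('SELECT 1'))
        return result.scalar() == 1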

We will reuse the same set of models as in the previous post:


class Product(Base):
    __tablename__ = 'product'

    id: Mapped[UUID] = mapped_column(
        type_=types.UUID,
        primary_key=True,
        server_default=text('gen_random_uuid()'),
    )
    name: Mapped[str] = mapped_column(
        type_=types.VARCHAR(100), server_default=text("''")
    )
    created_at: Mapped[timestamp] = mapped_column(
        type_=types.TIMESTAMP,
        server_default=text('NOW()'),
    )


class Review(Base):
    __tablename__ = 'review'

    id: Mapped[UUID] = mapped_column(
        type_=types.UUID,
        primary_key=True,
        server_default=text('gen_random_uuid()'),
    )
    content: Mapped[str] = mapped_column(
        type_=types.VARCHAR(1000), server_default=text("''")
    )
    rating: Mapped[float] = mapped_column(type_=types.DECIMAL(2, 1))
    created_at: Mapped[timestamp] = mapped_column(
        type_=types.TIMESTAMP,
        server_default=text('NOW()'),
    )
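
A quick aside: the timestamp annotation used for created_at is not defined in this post. As an assumption on my part, a minimal alias that satisfies it could be as simple as the sketch below, since the columns already declare type_=types.TIMESTAMP explicitly:


from datetime import datetime

# A plain alias is enough here: the column type is set explicitly via
# type_=types.TIMESTAMP, so the annotation only drives Python typing.
timestamp = datetime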

Note: the test setup file (typically conftest.py) is still the backbone of our test environment setup.

The Essence of the Fixture Setup

The key fixtures to implement in a Python application with a database connection include:


import os
from typing import AsyncGenerator

import pytest
import pytest_asyncio
from alembic import command
from alembic.config import Config
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

# Base, DBSessionManager, and the os_environ_patch fixture come from
# the application's own modules and conftest plugins.


@pytest_asyncio.fixture(scope='session')
async def create_test_db(os_environ_patch):
    test_db_name = 'example_db_test'
    engine = create_async_engine(
        os.environ['POSTGRES_DSN_ORIGINAL'],
        isolation_level='AUTOCOMMIT',
    )

    create_db_op = text(f'CREATE DATABASE {test_db_name}')
    drop_db_op = text(f'DROP DATABASE IF EXISTS {test_db_name} WITH (FORCE)')
    async with engine.begin() as conn:
        await conn.execute(create_db_op)

    yield
    async with engine.connect() as conn:
        await conn.execute(drop_db_op)
    # Release the helper engine's connection pool.
    await engine.dispose()


@pytest.fixture(scope='session')
def migrate_db(create_test_db):
    config = Config('alembic.ini')
    test_db_url = os.environ['POSTGRES_DSN']
    config.set_main_option('sqlalchemy.url', test_db_url)
    command.upgrade(config, 'head')
    yield
    command.downgrade(config, 'base')


@pytest_asyncio.fixture
async def db(migrate_db) -> AsyncGenerator[DBSessionManager, None]:
    postgres_dsn = os.environ['POSTGRES_DSN']
    db_manager = DBSessionManager(postgres_dsn)
    yield db_manager
    target_metadata = Base.metadata
    tables = target_metadata.tables.keys()
    all_tables_str = ', '.join(f'"{t}"' for t in tables)
    async with db_manager.asessionmaker() as s:
        await s.execute(text(f'TRUNCATE TABLE {all_tables_str} CASCADE'))
        await s.commit()
    await db_manager.close()
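
With these fixtures in place, a regular test only needs to request the db fixture. The test below is a hypothetical sketch (the test name and data are mine, not part of the project), assuming pytest-asyncio in its default strict mode:


import pytest
from sqlalchemy import select

# Product and DBSessionManager come from the application's own modules.


@pytest.mark.asyncio
async def test_create_product(db: DBSessionManager):
    # Write a row through one session...
    async with db.asessionmaker() as session:
        session.add(Product(name='Laptop'))
        await session.commit()

    # ...and read it back through another one.
    async with db.asessionmaker() as session:
        result = await session.execute(select(Product))
        products = result.scalars().all()

    assert len(products) == 1
    assert products[0].name == 'Laptop'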

Now, let’s zoom in on the most important parts.

Migrations

@pytest.fixture(scope='session')
def migrate_db(create_test_db):

Scoping the fixture to the session lets us run the migration step only once per test session.

Table truncation

Here, the db fixture relies on the session manager to execute custom SQL during its teardown phase.

    target_metadata = Base.metadata
    tables = target_metadata.tables.keys()   # dict_keys(['product', 'review'])
    all_tables_str = ', '.join(f'"{t}"' for t in tables)   # '"product", "review"'

The code above collects the registered table names and joins them into a comma-separated, double-quoted string.

After that, TRUNCATE TABLE {all_tables_str} CASCADE deletes all records from the listed tables; the CASCADE option also clears records in tables that depend on them through foreign-key constraints.
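
For the two models above, the statement the fixture ends up executing boils down to the following call:


    await s.execute(text('TRUNCATE TABLE "product", "review" CASCADE'))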

The final step is to dispose of the DB manager instance:

    await db_manager.close()

This way, we can be sure the migration process is set up correctly within our Python application.