Intro

One of the first hurdles a developer faces when initializing a new project is setting up the unit test configuration and a basic set of unit tests.

The first step is to define a general testing strategy:

Migration management

Approaches

When dealing with a DB schema, there are two common approaches to managing migrations within a project:

The (0) "migrations per function" approach is rarely justified in mature projects, because migration operations are inherently expensive due to disk and metadata changes. Running a test suite of hundreds of test functions this way eventually becomes painfully slow.

However, for a project in its initial state I find it great, as it is almost identical to creating the DB from scratch. This approach is also good for testing downgrades against a DB that already contains some data.

Codebase highlights

If we have a session manager class defined in the classic way:

import contextlib
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine, AsyncSession


class DBSessionManager:
    """Own one async SQLAlchemy engine and hand out sessions bound to it."""

    def __init__(self, postgres_dsn: str):
        engine: AsyncEngine = create_async_engine(url=postgres_dsn)
        self._engine = engine

    @contextlib.asynccontextmanager
    async def asessionmaker(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a fresh AsyncSession on the shared engine.

        The session is used as an async context manager, so it is closed when
        the caller's ``async with`` block exits.
        """
        session = AsyncSession(self._engine)
        async with session:
            yield session

    async def close(self):
        """Dispose of the engine (and its connection pool)."""
        engine = self._engine
        await engine.dispose()

, then we’re ready to run DB operations against our tables.

Next, let us define our test DB tables, keeping them simple:

from decimal import Decimal


class Product(Base):
    """A product row; every column has a server-side default."""

    __tablename__ = 'product'

    # Primary key generated by the database via gen_random_uuid().
    id: Mapped[UUID] = mapped_column(
        type_=types.UUID,
        primary_key=True,
        server_default=text('gen_random_uuid()'),
    )
    # Up to 100 characters; defaults to the empty string.
    name: Mapped[str] = mapped_column(
        type_=types.VARCHAR(100), server_default=text("''")
    )
    # Set by the server to NOW() when the row is inserted.
    created_at: Mapped[timestamp] = mapped_column(
        type_=types.TIMESTAMP,
        server_default=text('NOW()'),
    )


class Review(Base):
    """A review row with free-text content and a one-decimal numeric rating."""

    # NOTE(review): no foreign key to `product` is declared here — confirm
    # this is intentional for the example.
    __tablename__ = 'review'

    # Primary key generated by the database via gen_random_uuid().
    id: Mapped[UUID] = mapped_column(
        type_=types.UUID,
        primary_key=True,
        server_default=text('gen_random_uuid()'),
    )
    # Up to 1000 characters; defaults to the empty string.
    content: Mapped[str] = mapped_column(
        type_=types.VARCHAR(1000), server_default=text("''")
    )
    # DECIMAL(2, 1) holds values such as 4.5 and SQLAlchemy loads it as
    # decimal.Decimal, so the annotation is Decimal (not int). The DDL is
    # unchanged; with a non-Optional annotation and no default, a value must
    # be supplied on insert.
    rating: Mapped[Decimal] = mapped_column(type_=types.DECIMAL(2, 1))
    # Set by the server to NOW() when the row is inserted.
    created_at: Mapped[timestamp] = mapped_column(
        type_=types.TIMESTAMP,
        server_default=text('NOW()'),
    )

With that in place, our conftest.py file is fairly simple:

from typing import AsyncGenerator
from unittest import mock
import os


import pytest
import pytest_asyncio
from asyncpg.exceptions import DuplicateDatabaseError
from alembic import command
from alembic.config import Config
from sqlalchemy import text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.asyncio import create_async_engine


from project.db.session_manager import DBSessionManager
from project.db.models import Product


@pytest.fixture(autouse=True, scope='session')
def os_environ_patch():
    """Point ``POSTGRES_DSN`` at a dedicated ``<database>_test`` DB for the session.

    The untouched DSN is exposed as ``POSTGRES_DSN_ORIGINAL`` so fixtures can
    connect to the real server to create/drop the test database. Both
    variables are restored when the session ends (``mock.patch.dict``).

    Raises:
        KeyError: if ``POSTGRES_DSN`` is not set.
        RuntimeError: if the DSN names no database.
    """
    from sqlalchemy.engine import make_url

    original_connection_string = os.environ['POSTGRES_DSN']
    original_url = make_url(original_connection_string)
    if not original_url.database:
        raise RuntimeError('POSTGRES_DSN must include a database name')
    # Rewrite only the database component: appending '_test' to the raw string
    # would corrupt DSNs with a query string (e.g. '?ssl=require').
    test_url = original_url.set(database=f'{original_url.database}_test')
    new_environ = {
        'POSTGRES_DSN': test_url.render_as_string(hide_password=False),
        'POSTGRES_DSN_ORIGINAL': original_connection_string,
    }
    with mock.patch.dict(os.environ, new_environ, clear=False):
        yield


@pytest_asyncio.fixture(scope='session')
async def create_test_db(os_environ_patch):
    """Create an empty test database for the session and drop it afterwards.

    The database name is read from the patched ``POSTGRES_DSN``, so it always
    matches the database the other fixtures connect to. DDL is issued over a
    connection to ``POSTGRES_DSN_ORIGINAL``, because the test database does not
    exist yet (and PostgreSQL cannot drop the database you are connected to).

    Raises:
        RuntimeError: if the test DB name is missing or equals the original
            DB name (guards against dropping the real database).
    """
    from sqlalchemy.engine import make_url

    test_db_name = make_url(os.environ['POSTGRES_DSN']).database
    original_db_name = make_url(os.environ['POSTGRES_DSN_ORIGINAL']).database
    if not test_db_name or test_db_name == original_db_name:
        raise RuntimeError(
            f'Refusing to manage test database {test_db_name!r}: it must be '
            f'set and differ from the original database {original_db_name!r}'
        )

    engine = create_async_engine(
        os.environ['POSTGRES_DSN_ORIGINAL'],
        # CREATE/DROP DATABASE cannot run inside a transaction block.
        isolation_level='AUTOCOMMIT',
    )
    # Identifiers cannot be bound parameters; quote manually, doubling any
    # embedded double quotes.
    quoted_name = '"{}"'.format(test_db_name.replace('"', '""'))
    create_db_op = text(f'CREATE DATABASE {quoted_name}')
    drop_db_op = text(f'DROP DATABASE IF EXISTS {quoted_name}')

    try:
        try:
            async with engine.begin() as conn:
                await conn.execute(create_db_op)
        except ProgrammingError as err:
            # A previous run may have crashed before teardown and left the
            # database behind: recreate it so the session starts clean. Any
            # other error is real and must not be swallowed.
            is_duplicate = err.orig is not None and (
                getattr(err.orig, 'pgcode', None)
                == DuplicateDatabaseError.sqlstate
            )
            if not is_duplicate:
                raise
            async with engine.begin() as conn:
                await conn.execute(drop_db_op)
                await conn.execute(create_db_op)

        yield

        async with engine.begin() as conn:
            await conn.execute(drop_db_op)
    finally:
        # Release pooled connections to the server even if setup/teardown fails.
        await engine.dispose()


@pytest.fixture
def migrate_db(create_test_db):
    """Migrate the test DB to ``head`` before a test and back to ``base`` after.

    Function-scoped, so every test starts from a freshly migrated schema and
    the downgrade path is exercised on each teardown.
    """
    # NOTE(review): 'alembic.ini' is resolved relative to the working directory,
    # and this assumes env.py takes the URL from the Alembic config — confirm.
    config = Config('alembic.ini')
    test_db_url = os.environ['POSTGRES_DSN']
    # Alembic options go through ConfigParser interpolation, so a literal '%'
    # (e.g. from a URL-encoded password) must be escaped as '%%'.
    config.set_main_option('sqlalchemy.url', test_db_url.replace('%', '%%'))
    command.upgrade(config, 'head')
    yield
    command.downgrade(config, 'base')


@pytest_asyncio.fixture
async def db(migrate_db) -> AsyncGenerator[DBSessionManager, None]:
    """Provide a session manager bound to the migrated test database.

    The manager's engine is disposed during teardown.
    """
    manager = DBSessionManager(os.environ['POSTGRES_DSN'])
    yield manager
    await manager.close()


@pytest_asyncio.fixture
async def product_fixture(db: DBSessionManager):
    """Insert a single Product named 'Test product' and return it."""
    new_product = Product(name='Test product')
    async with db.asessionmaker() as session:
        session.add(new_product)
        await session.commit()
        # The commit expires loaded attributes (the session is created with
        # default settings). Reload them, including the server-generated id
        # and created_at, while the session is still open so they stay
        # readable afterwards.
        await session.refresh(new_product)
    return new_product

The trickiest parts here are:

Finally, our tests can look like this:

@pytest.mark.asyncio
async def test_get_record(db: DBSessionManager, product_fixture: Product):
    """Reading the table yields exactly the Product inserted by the fixture."""
    # Prepare: nothing beyond the fixture row.

    # Do
    async with db.asessionmaker() as session:
        fetched = await session.execute(select(Product))
    found = fetched.scalar_one_or_none()

    # Check
    assert found is not None
    assert (found.id, found.name) == (product_fixture.id, product_fixture.name)
    count_stmt = select(func.count(Product.id)).select_from(Product)
    async with db.asessionmaker() as session:
        total = await session.execute(count_stmt)
    assert total.scalar_one() == 1


@pytest.mark.asyncio
async def test_create_record(db: DBSessionManager, product_fixture: Product):
    """Inserting a Product adds exactly one row with the requested name."""
    count_stmt = select(func.count(Product.id)).select_from(Product)

    # Prepare
    async with db.asessionmaker() as s:
        result = await s.execute(count_stmt)
    assert result.scalar_one() == 1
    new_product_name = 'New product'

    # Do: RETURNING gives the server-generated id, so the new row can be found
    # exactly instead of relying on created_at ordering (NOW() is the
    # transaction start time, so ties could select the wrong row).
    insert_op = (
        insert(Product).values(name=new_product_name).returning(Product.id)
    )
    async with db.asessionmaker() as s:
        result = await s.execute(insert_op)
        new_product_id = result.scalar_one()
        await s.commit()

    # Check
    async with db.asessionmaker() as s:
        result = await s.execute(count_stmt)
    assert result.scalar_one() == 2
    async with db.asessionmaker() as s:
        result = await s.execute(
            select(Product).where(Product.id == new_product_id)
        )
    new_product = result.scalar_one()
    assert new_product.id != product_fixture.id
    assert new_product.name == new_product_name