Efficient workflows and reproducibility are crucial components of every machine learning project. They enable you to:

PyTorch Lightning and Hydra serve as the foundation of this template. This technology stack for deep learning prototyping provides a comprehensive and seamless solution, allowing you to explore different tasks across a variety of hardware, such as CPUs, multiple GPUs, and TPUs. Furthermore, the template includes a curated collection of best practices and extensive documentation.

This template can be used as is for basic tasks such as Classification, Segmentation, or Metric Learning, or easily extended to other tasks thanks to its modular and scalable structure.

As a baseline, I used the excellent Lightning Hydra Template, reshaped and polished it, and implemented additional features that improve workflow efficiency and reproducibility.


Main technologies

PyTorch Lightning - a lightweight deep learning framework / PyTorch wrapper for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale.

Hydra - a framework that simplifies configuring complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line.
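
To make the composition-and-override idea concrete, here is a minimal plain-Python sketch of what it means. It is not how Hydra is implemented: the helpers deep_merge, parse_value, and apply_override are hypothetical, and the dicts stand in for YAML config files. The sketch merges a config group into a base config and then applies dotted command-line overrides such as module.optimizer.weight_decay=0.0001.

```python
import ast
import copy
from typing import Any


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (override wins)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def parse_value(text: str) -> Any:
    """Interpret numbers, booleans and lists literally; keep anything else as a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_override(cfg: dict, override: str) -> dict:
    """Apply one dotted `key.sub.key=value` override and return a new dict."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    result = copy.deepcopy(cfg)
    node = result
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = parse_value(raw_value)
    return result


# "Config files": a base config composed with a module config group.
base = {"trainer": {"max_epochs": 10}, "logger": "wandb"}
module_group = {"module": {"optimizer": {"lr": 0.001, "weight_decay": 0.0}}}
cfg = deep_merge(base, module_group)

# "Command line": dotted overrides applied on top of the composed config.
for arg in ["logger=csv", "module.optimizer.weight_decay=0.0001"]:
    cfg = apply_override(cfg, arg)

print(cfg)
```

One simplification: this sketch lets an override create a key that does not exist yet, whereas Hydra requires a + prefix for that.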

Project structure

The structure of a machine learning project may differ depending on the specific requirements and goals of the project, as well as the tools and frameworks being used. However, a typical machine learning project has a directory structure like this:

In this particular case, the directory structure looks like this:

├── configs                     <- Hydra configuration files
│   ├── callbacks               <- Callbacks configs
│   ├── datamodule              <- Datamodule configs
│   ├── debug                   <- Debugging configs
│   ├── experiment              <- Experiment configs
│   ├── extras                  <- Extra utilities configs
│   ├── hparams_search          <- Hyperparameter search configs
│   ├── hydra                   <- Hydra settings configs
│   ├── local                   <- Local configs
│   ├── logger                  <- Logger configs
│   ├── module                  <- Module configs
│   ├── paths                   <- Project paths configs
│   ├── trainer                 <- Trainer configs
│   │
│   ├── eval.yaml               <- Main config for evaluation
│   └── train.yaml              <- Main config for training
│
├── data                        <- Project data
├── logs                        <- Generated logs
├── notebooks                   <- Jupyter notebooks
├── scripts                     <- Shell scripts
│
├── src                         <- Source code
│   ├── callbacks               <- Additional callbacks
│   ├── datamodules             <- Lightning datamodules
│   ├── modules                 <- Lightning modules
│   ├── utils                   <- Utility scripts
│   │
│   ├── eval.py                 <- Run evaluation
│   └── train.py                <- Run training
│
├── tests                       <- Tests of any kind
│
├── .dockerignore               <- List of files ignored by docker
├── .gitattributes              <- List of attributes to pathnames
├── .gitignore                  <- List of files ignored by git
├── .pre-commit-config.yaml     <- Configuration of pre-commit hooks
├── Dockerfile                  <- Dockerfile
├── Makefile                    <- Makefile
├── pyproject.toml              <- Config for testing and linting
├── requirements.txt            <- Python dependencies
├── setup.py                    <- Setup file
└── README.md

Workflow - how it works

Before starting a project, you should consider the following aspects to ensure the reproducibility of results:

Basic workflow

This template can be used as is for basic tasks such as Classification, Segmentation, or Metric Learning, but if you need to do something more complex, here is the general workflow:

  1. Write your PyTorch Lightning Module (see examples in src/modules/single_module.py)

  2. Write your PyTorch Lightning DataModule (see examples in src/datamodules/datamodules.py)

  3. Fill in your configs; in particular, create experiment configs

  4. Run experiments:

    • Run training with the chosen experiment config:
    python src/train.py experiment=experiment_name.yaml
    
    • Run a hyperparameter search, for example with the Optuna Sweeper via Hydra:
    # using Hydra multirun mode
    python src/train.py -m hparams_search=mnist_optuna
    
    • Run several jobs with manually chosen values of a config parameter:
    python src/train.py -m logger=csv module.optimizer.weight_decay=0.0,0.00001,0.0001
    

  5. Run evaluation with different checkpoints or run prediction on a custom dataset for additional analysis
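
The comma-separated values in the manual sweep command are what turn one command into several runs: in multirun mode, Hydra's basic sweeper launches one job per combination of the swept values. The sketch below, built around a hypothetical expand_sweep helper, illustrates that expansion as a Cartesian product in plain Python. It is not Hydra's override parser and ignores values that themselves contain commas, such as lists.

```python
import itertools


def expand_sweep(overrides: list[str]) -> list[list[str]]:
    """Expand comma-separated override values into one override list per job."""
    choices = []
    for override in overrides:
        key, values = override.split("=", 1)
        choices.append([f"{key}={value}" for value in values.split(",")])
    # Every combination of the per-key choices becomes a separate job.
    return [list(job) for job in itertools.product(*choices)]


jobs = expand_sweep(["logger=csv", "module.optimizer.weight_decay=0.0,0.00001,0.0001"])
for job_number, job in enumerate(jobs):
    print(job_number, " ".join(job))
```

Sweeping two keys with 2 and 3 values each would produce 6 jobs, which is why multiruns over several parameters grow quickly.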

The template includes an MNIST classification example, which is also used by the tests. Running python src/train.py produces terminal output like the example shown in the template documentation.

LightningDataModule

First, you need to create a PyTorch Dataset for your task. It must implement the __getitem__ and __len__ methods. You may be able to use one of the datasets already implemented in the template as is, or adapt it with little effort. See the PyTorch documentation for more details.
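
A minimal sketch of the Dataset contract, assuming a JSON annotation file that is a list of {"path": ..., "label": ...} records. The class name JSONClassificationDataset and the JSON layout are hypothetical and are not the template's ClassificationDataset. To keep the sketch runnable with the standard library only, it does not subclass torch.utils.data.Dataset and returns the image path instead of a loaded image; in a real project you would subclass Dataset and load and transform the image in __getitem__.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


class JSONClassificationDataset:
    """Hypothetical map-style dataset: one annotation record per sample."""

    def __init__(self, json_path: str) -> None:
        # Assumed layout: [{"path": "img_0.png", "label": 0}, ...]
        with open(json_path, encoding="utf-8") as f:
            self.samples: list[dict[str, Any]] = json.load(f)

    def __len__(self) -> int:
        # Number of samples; samplers use this to generate indices.
        return len(self.samples)

    def __getitem__(self, index: int) -> tuple[str, int]:
        # Return one sample by integer index.
        sample = self.samples[index]
        return sample["path"], int(sample["label"])


# Usage with a throwaway annotation file.
records = [{"path": "img_0.png", "label": 0}, {"path": "img_1.png", "label": 1}]
with tempfile.TemporaryDirectory() as tmp:
    json_path = Path(tmp) / "train.json"
    json_path.write_text(json.dumps(records), encoding="utf-8")
    dataset = JSONClassificationDataset(str(json_path))
    print(len(dataset), dataset[1])
```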

It may also be useful to read the Data section on how data can be stored for training and evaluation.

Next, you need to create a DataModule using the PyTorch Lightning DataModule API. By default, the API has the following methods:

See examples of datamodule configs in the configs/datamodule folder.

See the LightningDataModule API in the template documentation.

By default, the template contains the following DataModules:

In the template, DataModules have a _get_dataset_ method that simplifies dataset instantiation.

LightningModule

LightningModule API

Next, you need to create a LightningModule using the PyTorch Lightning LightningModule API. The minimal API has the following methods:

You can also override optional methods for each step to perform additional logic:

See the LightningModule API methods and the order in which they are called in the template documentation.

In the template, LightningModule has a model_step method that encapsulates repeated operations, such as the forward pass or loss calculation, which are required in training_step, validation_step, and test_step.
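
The model_step idea is a plain refactoring pattern, so it can be shown without Lightning. In this sketch, ToyModule is a hypothetical stand-in for a LightningModule, its log method only records values, the "model" is any Python callable, and the loss is mean squared error; metric names such as loss/train are illustrative. The point is that the forward pass and loss computation live in one place, and each *_step method only adds stage-specific logging.

```python
from typing import Callable

Batch = tuple[list[float], list[float]]  # (inputs, targets)


class ToyModule:
    """Plain-Python stand-in for a LightningModule; `model` is any callable."""

    def __init__(self, model: Callable[[float], float]) -> None:
        self.model = model
        self.logged: dict[str, float] = {}

    def log(self, name: str, value: float) -> None:
        # Stand-in for LightningModule.log: remember the last value only.
        self.logged[name] = value

    def model_step(self, batch: Batch) -> tuple[float, list[float], list[float]]:
        # Shared logic: forward pass + loss (mean squared error here).
        inputs, targets = batch
        preds = [self.model(x) for x in inputs]
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)
        return loss, preds, targets

    def training_step(self, batch: Batch, batch_idx: int) -> float:
        loss, _, _ = self.model_step(batch)
        self.log("loss/train", loss)
        return loss  # the training loss is what the optimizer minimizes

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        loss, _, _ = self.model_step(batch)
        self.log("loss/valid", loss)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        loss, _, _ = self.model_step(batch)
        self.log("loss/test", loss)


module = ToyModule(model=lambda x: 2 * x)
batch = ([1.0, 2.0], [2.0, 5.0])
print(module.training_step(batch, batch_idx=0))  # 0.5
```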

Metrics

The template offers the following Metrics API:

Each metric config should contain a _target_ key with the import path of the metric class, plus any other parameters the metric requires. The template allows you to use any metric, for example from torchmetrics or your own implementation. See more details about the torchmetrics API, the implemented Metrics API, and metric configs as part of the network configs in the configs/module/network folder.

Metric config example:

metrics:
  main:
    _target_: "torchmetrics.Accuracy"
    task: "binary"
  valid_best:
    _target_: "torchmetrics.MaxMetric"
  additional:
    AUROC:
      _target_: "torchmetrics.AUROC"
      task: "binary"
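
Every entry with a _target_ key becomes an object by calling the class at that import path with the remaining keys as keyword arguments. In the template this is done by Hydra (hydra.utils.instantiate); the instantiate function below is a simplified, hypothetical re-implementation that only illustrates the idea and omits Hydra features such as _partial_ and _args_. Standard-library classes stand in for the torchmetrics classes so the sketch runs without extra dependencies.

```python
import importlib
from typing import Any


def instantiate(config: Any) -> Any:
    """Recursively build objects from dicts that carry a `_target_` key."""
    if isinstance(config, dict):
        # Build nested values first, then pass them as keyword arguments.
        kwargs = {k: instantiate(v) for k, v in config.items() if k != "_target_"}
        if "_target_" not in config:
            return kwargs  # plain mapping, e.g. the `additional` group
        module_path, _, attr_name = config["_target_"].rpartition(".")
        target = getattr(importlib.import_module(module_path), attr_name)
        return target(**kwargs)
    if isinstance(config, list):
        return [instantiate(item) for item in config]
    return config


# Same shape as the metrics config above, with stdlib classes as stand-ins.
metrics_cfg = {
    "main": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3},
    "valid_best": {"_target_": "collections.Counter"},
    "additional": {"delta": {"_target_": "datetime.timedelta", "days": 2}},
}
metrics = instantiate(metrics_cfg)
print(metrics["main"], metrics["additional"]["delta"])
```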

Loss

The template suggests the following Losses API:

The template allows you to use any losses, for example from PyTorch or implemented by yourself. See more details about implemented Losses API and loss config as a part of network configs in configs/module/network folder.

Loss config examples:

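# Example 1: a loss from PyTorch
loss:
  _target_: "torch.nn.CrossEntropyLoss"

# Example 2: a PyTorch loss with extra parameters
loss:
  _target_: "torch.nn.BCEWithLogitsLoss"
  pos_weight: [0.25]

# Example 3: a loss implemented in the template
loss:
  _target_: "src.modules.losses.VicRegLoss"
  sim_loss_weight: 25.0
  var_loss_weight: 25.0
  cov_loss_weight: 1.0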
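
The pos_weight parameter of BCEWithLogitsLoss scales the loss term of positive targets. Following the formula in the PyTorch documentation, each element contributes -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], where x is the logit, y the target, and p the pos_weight; the default reduction averages over elements. The plain-Python sketch below, with hypothetical helper names, reproduces that formula for a single class. PyTorch itself expects pos_weight as a tensor, so the [0.25] list in the config presumably has to be converted before the loss is built.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits: list[float], targets: list[float], pos_weight: float = 1.0) -> float:
    """Mean binary cross-entropy on raw logits, with positive terms scaled by pos_weight.

    Uses log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x).
    """
    losses = [
        pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


# With logit 0 (probability 0.5), each element costs log(2) before weighting;
# pos_weight=0.25 shrinks only the positive example's contribution.
print(bce_with_logits([0.0, 0.0], [1.0, 0.0]))                   # ~0.693
print(bce_with_logits([0.0, 0.0], [1.0, 0.0], pos_weight=0.25))  # ~0.433
```

A pos_weight below 1, as in the config above, makes errors on positive examples cheaper; values above 1 are the usual choice when positives are rare.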

The template also includes a few manually implemented losses:

Model

The template offers the following Model API. A model config should contain:

By default, a model can be loaded from:

See more details about implemented Model API and model config as a part of network configs in configs/module/network folder.

Model config example:

model:
  _target_: "src.modules.models.classification.Classifier"
  model_name: "torchvision.models/mobilenet_v3_large"
  model_repo: null
  weights: "IMAGENET1K_V2"
  num_classes: 1

Implemented LightningModules

By default, the template comes with the following LightningModules:

See examples of module configs in the configs/module folder. An example LightningModule config:

_target_: src.modules.single_module.MNISTLitModule

defaults:
  - _self_
  - network: mnist.yaml

optimizer:
  _target_: torch.optim.Adam
  lr: 0.001
  weight_decay: 0.0

scheduler:
  scheduler:
    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
    mode: "max"
    factor: 0.1
    min_lr: 1.0e-9
    patience: 10
    verbose: True
  extras:
    monitor: ${replace:"__metric__/valid"}
    interval: "epoch"
    frequency: 1

logging:
  on_step: False
  on_epoch: True
  sync_dist: False
  prog_bar: True
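
The scheduler block is split into the scheduler itself and extras, which matches the learning-rate scheduler dict that Lightning accepts from configure_optimizers (keys scheduler, monitor, interval, frequency). The sketch below shows that mapping in plain Python. How the template's replace resolver fills in __metric__ is not described here, so replace_metric is a hypothetical stand-in that substitutes the name of the main metric; for simplicity the substitution happens when the dict is built rather than at config-resolution time, and strings stand in for the optimizer and scheduler objects. With mode: "max", ReduceLROnPlateau lowers the learning rate when the monitored value stops increasing, which suits an accuracy-like metric.

```python
from typing import Any

METRIC_PLACEHOLDER = "__metric__"


def replace_metric(template: str, metric_name: str) -> str:
    """Hypothetical stand-in for the `replace` resolver: insert the metric name."""
    return template.replace(METRIC_PLACEHOLDER, metric_name)


def build_optim_config(
    optimizer: Any, scheduler: Any, extras: dict[str, Any], metric_name: str
) -> dict[str, Any]:
    """Combine optimizer, scheduler and extras into the dict shape that
    LightningModule.configure_optimizers may return."""
    lr_scheduler = {"scheduler": scheduler, **extras}  # a copy; `extras` is untouched
    if "monitor" in lr_scheduler:
        lr_scheduler["monitor"] = replace_metric(lr_scheduler["monitor"], metric_name)
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}


# Values mirror the `extras` block of the config above; the strings stand in
# for real torch.optim objects.
extras = {"monitor": "__metric__/valid", "interval": "epoch", "frequency": 1}
optim_config = build_optim_config("<Adam>", "<ReduceLROnPlateau>", extras, metric_name="Accuracy")
print(optim_config)
```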

Training loop

The training loop in the template consists of the following stages:

See more details in training loop and configs/train.yaml.

Evaluation and prediction loops

The evaluation loop in the template consists of the following stages:

See more details in evaluation loop and configs/eval.yaml.

The template contains the following Prediction API:

datasets:
  predict:
    dataset1:
      _target_: src.datamodules.datasets.ClassificationDataset
      json_path: ${paths.data_dir}/predict/data1.json
    dataset2:
      _target_: src.datamodules.datasets.ClassificationDataset
      json_path: ${paths.data_dir}/predict/data2.json

See more details about Prediction API and predict_step in LightningModule.

Callbacks

PyTorch Lightning has many built-in callbacks, which, thanks to Hydra, can be used simply by adding them to the callbacks config. See examples in the configs/callbacks folder.

By default, the template contains a few of them:

In addition, there is a custom LightProgressBar callback, which may be a more elegant and useful alternative to RichProgressBar:

Logs

Hydra creates a new output directory in logs/ for every executed run.

Furthermore, the template saves additional metadata for better reproducibility and debugging, including:

Default logging structure:

├── logs
│   ├── task_name
│   │   ├── runs                        <- Logs generated by runs
│   │   │   ├── YYYY-MM-DD_HH-MM-SS     <- Datetime of the run
│   │   │   │   ├── .hydra              <- Hydra logs
│   │   │   │   ├── csv                 <- Csv logs
│   │   │   │   ├── wandb               <- Weights & Biases logs
│   │   │   │   ├── checkpoints         <- Training checkpoints
│   │   │   │   ├── metadata            <- Metadata
│   │   │   │   │   ├── pip.log         <- Pip logs
│   │   │   │   │   ├── git.log         <- Git logs
│   │   │   │   │   ├── env.log         <- Environment logs
│   │   │   │   │   ├── src             <- Full copy of `src/`
│   │   │   │   │   └── configs         <- Full copy of `configs/`
│   │   │   │   └── ...                 <- Any other saved files
│   │   │   └── ...
│   │   │
│   │   └── multiruns                   <- Logs generated by multiruns
│   │       ├── YYYY-MM-DD_HH-MM-SS     <- Datetime of the multirun
│   │       │   ├── 1                   <- Multirun job number
│   │       │   ├── 2
│   │       │   └── ...
│   │       └── ...
│   │
│   └── debugs                          <- Logs generated during debug
│       └── ...
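
The metadata folder in the tree above could be produced by a helper along these lines. This is a hypothetical sketch, not the template's code: the file names follow the tree, but what goes into each file (pip freeze output, the last git commit plus git status, Python and platform info) is an assumption. Commands that fail, for example git outside a repository, simply have their error text recorded.

```python
import os
import platform
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path


def run_command(cmd: list[str]) -> str:
    """Run a command and return stdout + stderr, or an error note if it can't start."""
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)
    except OSError as error:  # e.g. git is not installed
        return f"failed to run {cmd[0]}: {error}\n"
    return result.stdout + result.stderr


def save_metadata(run_dir: Path, project_root: Path) -> Path:
    """Hypothetical helper: write pip/git/env logs and copy src/ and configs/."""
    metadata_dir = run_dir / "metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)

    pip_log = run_command([sys.executable, "-m", "pip", "freeze"])
    (metadata_dir / "pip.log").write_text(pip_log)

    git_log = run_command(["git", "-C", str(project_root), "log", "-1"])
    git_log += run_command(["git", "-C", str(project_root), "status", "--short"])
    (metadata_dir / "git.log").write_text(git_log)

    env_lines = [
        f"python: {sys.version}",
        f"platform: {platform.platform()}",
        f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', '')}",
    ]
    (metadata_dir / "env.log").write_text("\n".join(env_lines) + "\n")

    # Snapshot the code and configs that produced this run.
    for folder in ("src", "configs"):
        source = project_root / folder
        if source.is_dir():
            shutil.copytree(source, metadata_dir / folder, dirs_exist_ok=True)
    return metadata_dir


# Usage in a throwaway project folder that only has configs/.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "configs").mkdir()
    (root / "configs" / "train.yaml").write_text("seed: 42\n")
    metadata = save_metadata(root / "logs" / "run", project_root=root)
    saved = sorted(p.name for p in metadata.iterdir())
    print(saved)
```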

Data

Usually, images or other data files are simply stored on disk in folders. This is a simple and convenient approach.

However, there are other methods. One of them is the Hierarchical Data Format (HDF5), which can be used in Python via the h5py library. There are a few reasons why storing images in HDF5 files might be more beneficial than storing them in plain folders:

This template contains a tool that can be used to create and read HDF5 files easily.

To create an HDF5 file:

from src.datamodules.components.h5_file import H5PyFile

H5PyFile().create(
    filename="/path/to/dataset_train_set_v1.h5",
    content=["/path/to/image_0.png", "/path/to/image_1.png", ...],
    # each content item loads as np.fromfile(filepath, dtype=np.uint8)
)

To read an HDF5 file outside the pipeline, for example in a notebook:

import matplotlib.pyplot as plt
from src.datamodules.components.h5_file import H5PyFile

h5py_file = H5PyFile(filename="/path/to/dataset_train_set_v1.h5")
image = h5py_file[0]

plt.imshow(image)

To read an HDF5 file in Dataset.__getitem__:

import io
from typing import Any

from PIL import Image


def __getitem__(self, index: int) -> Any:
    key = self.keys[index]  # get the image key, e.g. its original path
    source = self.data_file[key]  # raw encoded bytes stored in the HDF5 file
    image = Image.open(io.BytesIO(source))  # decode the bytes into a PIL image
    ...

Hyperparameters search

Hydra offers ready-made hyperparameter sweeper plugins: Optuna, Nevergrad, and Ax.

You can define a hyperparameter search by adding a new config file to configs/hparams_search.

See the example hyperparameter search config. With this approach there is no need to write extra code: everything is specified in a single configuration file. The only requirement is that the task function in the launch script (src/train.py) returns the value of the metric being optimized.

Execute it with:

python src/train.py -m hparams_search=mnist_optuna

The optimization_results.yaml file will be available in the logs/task_name/multiruns folder.
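
The only contract between the sweeper and your code is the return value of the task function. The sketch below makes that concrete: train is a hypothetical stand-in for the function in src/train.py that returns a toy validation score, and random_search is plain random search over a log-uniform learning rate, not the sampler Optuna actually uses. A sweeper does the same thing at a larger scale: it proposes parameter values, runs the task, and keeps the best returned value.

```python
import random
from typing import Callable

Params = dict[str, float]


def train(cfg: Params) -> float:
    """Hypothetical stand-in for the task function in src/train.py.

    A sweeper only relies on one thing: the function returns the metric to
    optimize. This toy "validation score" peaks at lr = 0.01.
    """
    return 1.0 - abs(cfg["lr"] - 0.01) * 10


def random_search(
    objective: Callable[[Params], float], n_trials: int, seed: int = 1234
) -> tuple[Params, float]:
    """Maximize `objective` by sampling lr log-uniformly from [1e-4, 1e-1]."""
    rng = random.Random(seed)
    best_params: Params = {}
    best_value = float("-inf")
    for _ in range(n_trials):
        params = {"lr": 10 ** rng.uniform(-4, -1)}
        value = objective(params)  # the returned metric a sweeper consumes
        if value > best_value:
            best_params, best_value = params, value
    return best_params, best_value


best_params, best_value = random_search(train, n_trials=50)
print(best_params, best_value)
```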

Docker

Docker is an essential part of environment reproducibility: it makes it possible to package a machine learning pipeline and its dependencies into a single container that can be deployed and run on any machine with Docker installed. This is particularly useful because it helps ensure that the code runs consistently, regardless of the environment in which it is deployed.

A Docker image may require additional packages depending on the device it runs on. For example, running on a cluster with NVIDIA GPUs requires the NVIDIA CUDA Toolkit, which provides everything you need to develop GPU-accelerated applications, including GPU-accelerated libraries, a compiler, development tools, and the CUDA runtime.

There are many ways to set this up, but to simplify the process you can use:

Moreover, it can be advantageous to use:

Here is an example of building and running a container based on the provided Dockerfile and .dockerignore:

set -o errexit
export DOCKER_BUILDKIT=1
export PROGRESS_NO_TRUNC=1

docker build --tag <project-name> \
    --build-arg OS_VERSION="22.04" \
    --build-arg CUDA_VERSION="11.7.0" \
    --build-arg PYTHON_VERSION="3.10" \
    --build-arg USER_ID=$(id -u) \
    --build-arg GROUP_ID=$(id -g) \
    --build-arg NAME="<your-name>" \
    --build-arg WORKDIR_PATH=$(pwd) .

docker run \
    --name <task-name> \
    --rm \
    -u $(id -u):$(id -g) \
    -v $(pwd):$(pwd):rw \
    --gpus '"device=0,1,3,4"' \
    --cpuset-cpus "0-47" \
    -it \
    --entrypoint /bin/bash \
    <project-name>:latest

Tests

Tests are an important part of software development in general, and especially in machine learning, where it can be much harder to tell whether code is working correctly without them. Consequently, the template contains some generic tests implemented with pytest.

MNIST is used for this purpose: it is a small dataset, so all tests can run on a CPU. It is also easy to implement tests for your own dataset if needed.

As a baseline, the tests cover:

All of these tests are designed to verify that the main pipeline modules and utils are executable and work as expected. However, this is sometimes not enough to ensure that the code is working correctly, especially for more complex pipelines and models.

To run the tests:

# run all tests
pytest

# run tests from specific file
pytest tests/test_train.py

# run tests from specific test
pytest tests/test_train.py::test_train_ddp_sim

# run all tests except the ones marked as slow
pytest -k "not slow"

Continuous integration

The template contains a few initial CI workflows based on the GitHub Actions platform. GitHub Actions makes it easy to automate and streamline development workflows, which can save time and effort, increase efficiency, and improve the overall quality of the code. In particular, it includes:

Note: you need to enable GitHub Actions in your repository settings.

See more about GitHub Actions for CI.

If you use GitLab, it is easy to set up GitLab CI based on the GitHub Actions workflows; there, CI is configured via the .gitlab-ci.yml file. See more here.

Also published here.