A Template for Custom and Distributed Training

Custom training loops offer great flexibility. You can quickly add new functionality and gain deep insight into how your algorithm works under the hood. However, setting up custom algorithms over and over is tedious. The general layout is often the same; only small parts change.

(Header image: photo by Chris Ried on Unsplash.)

This is where the following template comes into play: It outlines a custom and distributed training loop. All places that you have to modify to fit your task are highlighted with TODO notes.

The general layout of custom distributed loops

A custom training loop is a mechanism that manually iterates over the dataset, updates the model’s weights, and computes any metrics.

Before iterating over any dataset, be it the train, validation, or test split, the dataset must be made distribution ready. This is done with the help of TensorFlow’s distribution strategy objects.

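We first create our strategy object, which is responsible for all distributed computation. By choosing a different distribution strategy, we can run the same algorithm in all sorts of computing environments: for example, MirroredStrategy for several GPUs on one machine, MultiWorkerMirroredStrategy for several machines, or TPUStrategy for TPUs. This makes custom loops highly flexible.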

After creating the strategy, we make our dataset distribution ready; TensorFlow handles all internal details.
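A minimal sketch of these two steps follows. It assumes MirroredStrategy (all visible GPUs of one machine, falling back to a single CPU replica) and a toy in-memory dataset; the per-replica batch size of 8 and the data shapes are illustrative assumptions, not the template’s values.

```python
import tensorflow as tf

# MirroredStrategy replicates computation on every visible GPU of one machine;
# without GPUs it falls back to a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

# Assumption: 8 samples per replica. The global batch spans all replicas.
BATCH_SIZE_PER_REPLICA = 8
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Toy features and labels standing in for your real data (hypothetical shapes).
features = tf.random.uniform((64, 4))
labels = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(GLOBAL_BATCH_SIZE)

# Make the dataset distribution ready: each global batch is split across the replicas.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```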

With our distributed dataset, we can then iterate over it using a plain “for i in x” loop.

This loop is the same for all subsets (train, validation, test); the main difference is the step function that is called. For training, this is distributed_train_step, which handles distributing the data to all accelerators. TensorFlow takes care of most of the internal device-to-device communication: we only have to tell it what we want to do, and the how is done for us.
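As a sketch of the loop itself, the example below iterates over a distributed toy dataset. The hypothetical distributed_count_step stands in for distributed_train_step: it runs a tiny function on every replica via strategy.run and combines the per-replica results with strategy.reduce, so the loop has the same shape as in training.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

# Toy data: 40 samples with 4 features each (illustrative only).
features = tf.random.uniform((40, 4))
labels = tf.random.uniform((40,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def distributed_count_step(dataset_inputs):
    """Hypothetical stand-in for distributed_train_step: counts samples per replica."""

    def count_step(inputs):
        batch_features, _ = inputs
        return tf.shape(batch_features)[0]

    per_replica_counts = strategy.run(count_step, args=(dataset_inputs,))
    # Combine the per-replica results into one value on the current device.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_counts, axis=None)


samples_seen = 0
for batch in dist_dataset:  # the "for i in x" loop over the distributed dataset
    samples_seen += int(distributed_count_step(batch).numpy())
print(samples_seen)
```

For validation or testing, only the step function called inside the loop changes.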

As part of the train, test, or validation steps, we also update any metrics that we want to track during training. We have to do this manually.

In the template, the training step takes a single data batch, unpacks it into features and labels, calculates the gradients, updates the model, and computes any training metrics. The validation and test steps work similarly; they only skip the model update.
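To make such a step concrete, here is a minimal non-distributed sketch. The model, loss, optimizer, and metric are hypothetical placeholders for a 3-class toy problem; the distributed version with per-replica loss scaling is discussed further below.

```python
import tensorflow as tf

# Hypothetical model, loss, optimizer, and metric for a 3-class toy problem.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")


@tf.function
def train_step(inputs):
    features, labels = inputs                    # unpack the batch
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = loss_object(labels, logits)       # mean loss over this batch
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(labels, logits)  # metrics are updated manually
    return loss


batch = (tf.random.uniform((8, 4)), tf.random.uniform((8,), maxval=3, dtype=tf.int32))
loss = train_step(batch)
print(float(loss), float(train_accuracy.result()))
```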

To summarize the general layout:

  • Make required objects (models, datasets, optimizers, metrics) distribution ready
  • Iterate over the datasets
  • Call the distributed train/test/validation steps to update the model and calculate the metrics

That covers the general layout. In the following, we’ll go over the template and see what you have to modify for your task at hand.

Template for distributed custom training

We begin with the necessary imports and some global definitions.

For this template, I’ve decided to make all objects and variables used within the training, test, and validation steps global. This way, we don’t have to pass them around all the time, which keeps the code cleaner. However, this is only one way of doing it, and there might be better ones.

All models and their optimizers are made available globally, as are any metrics and losses. I have not included test and validation losses in the template, but I highly recommend you add them. Beyond that, we also register the global batch size and the distribution strategy (introduced before) as globals.

Our training script accepts a few command-line arguments, so we also import the argparse library. The code that kicks off the complete script parses these arguments and then calls the primary method, main.
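A sketch of such an entry point is shown below. The flag names (--epochs, --batch_size_per_replica) and their defaults are assumptions for illustration; the template’s own arguments may differ.

```python
import argparse


def parse_args(argv=None):
    """Parse command-line arguments (flag names are hypothetical; use your own)."""
    parser = argparse.ArgumentParser(description="Custom distributed training template")
    parser.add_argument("--epochs", type=int, default=10, help="number of training epochs")
    parser.add_argument("--batch_size_per_replica", type=int, default=32,
                        help="batch size processed by each accelerator")
    return parser.parse_args(argv)


def main(args):
    # Placeholder: in the template, main instantiates the globals, then trains and tests.
    print(f"Training for {args.epochs} epochs, "
          f"{args.batch_size_per_replica} samples per replica")


if __name__ == "__main__":
    main(parse_args())
```

Passing an explicit argv list to parse_args, for example parse_args(["--epochs", "3"]), makes the parser easy to exercise without a shell.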

Main method

The main method starts by instantiating most of the global objects introduced before. So far, they have only been placeholders, which is why we create them now. We begin with the distribution strategy and follow with the losses, metrics, and models. We will cover each of the called methods in detail soon.

I have skipped the dataset creation part in this template. There are just too many ways to do it, and you know best what suits your problem. I have thus only used None as the initial value for the datasets; replace this part with code that creates and distributes your datasets (using the strategy, as described earlier).

Finally, we train the model using the distributed train and validation datasets and then evaluate it on the hold-out test dataset.


The first method, create_loss_objects, is responsible for creating any losses that we use.

Within the scope of the distribution strategy, we create any losses we need. In the sample code, this is only a dummy loss; modify it to suit your needs. Whichever loss you choose, don’t forget to set its reduction parameter to “none”, because we reduce the loss manually in a later method. Whether you use a single loss or multiple ones, create them all here and return the objects. Register all returned loss objects globally, as I have done with the sample train_loss_object1, and repeat this for every loss.
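A minimal sketch of create_loss_objects follows. It assumes a classification task with SparseCategoricalCrossentropy as the placeholder loss and MirroredStrategy as the strategy. Passing reduction="none" keeps one loss value per example, which the later scaling step needs.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def create_loss_objects():
    """Create every loss under the strategy scope (the loss type is a placeholder)."""
    with strategy.scope():
        # reduction="none" returns one loss value per example; the loss is reduced
        # manually later, using the global batch size instead of the local one.
        train_loss_object1 = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction="none")
    return train_loss_object1


train_loss_object1 = create_loss_objects()
per_example_loss = train_loss_object1(
    tf.constant([0, 2]), tf.constant([[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]))
print(per_example_loss.shape)  # one value per example, no reduction applied
```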


The code for creating the metrics used to track our model’s progress follows the same scheme.

Under the scope of the distribution strategy, we instantiate all metrics. In the template, this is only the training accuracy; metrics for validation (during training) and testing (after training) are missing. I recommend giving the metrics meaningful names so you can identify them when printing. As before, register all returned metrics globally. In the main method, I have done this with train_metric1.
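A sketch of create_metrics that also adds the missing validation and test accuracies is shown below. The metric types and names are illustrative assumptions, and MirroredStrategy stands in for whichever strategy you choose.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def create_metrics():
    """Create all metrics under the strategy scope, with names that read well when printed."""
    with strategy.scope():
        train_metric1 = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")
        # Hypothetical additions: the template leaves validation and test metrics to you.
        val_metric1 = tf.keras.metrics.SparseCategoricalAccuracy(name="val_accuracy")
        test_metric1 = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")
    return train_metric1, val_metric1, test_metric1


train_metric1, val_metric1, test_metric1 = create_metrics()
for metric in (train_metric1, val_metric1, test_metric1):
    print(metric.name, float(metric.result()))
```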

Models and optimizers

After instantiating the metrics, we now focus on the model(s) and optimizer(s). The template method for this follows the previous two.

As before, we create everything under the scope of the chosen strategy. This step is mandatory: it makes the internal variables of the models and optimizers distribution ready. In the sample code, I have not selected any specific model or optimizer; modify this part to suit your task. Also, don’t forget to register the returned objects globally, as I have done with model and optimizer in the main function.
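A sketch of create_model_and_optimizer follows. It assumes a small Keras classifier and Adam as stand-ins for your own choices. Because both are created inside strategy.scope(), their variables are mirrored across replicas.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()


def create_model_and_optimizer():
    """Build model and optimizer under the strategy scope so their variables are mirrored."""
    with strategy.scope():
        # Placeholders: a tiny classifier and Adam, both hypothetical choices.
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(4,)),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(3),
        ])
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    return model, optimizer


model, optimizer = create_model_and_optimizer()
print(len(model.trainable_variables))  # kernel and bias of both Dense layers
```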


After all required objects are instantiated and made distribution ready, it’s time to create the datasets. I have left this part blank intentionally, as there are multiple ways to create a dataset: from tf.data.Dataset objects to custom generators or hybrid approaches. You can find an overview here, but since this template is geared towards more experienced coders, I doubt that you’ll need it. Once you have created a dataset, distribute it with the strategy’s experimental_distribute_dataset method, just as in the general layout.
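For in-memory arrays, one possible sketch is shown below. The helper make_dataset, the random toy splits, and the batch size are hypothetical; only the training split is shuffled, and each split is then distributed with experimental_distribute_dataset.

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync


def make_dataset(features, labels, shuffle):
    """Hypothetical helper: turn in-memory arrays into a batched tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(features))
    return dataset.batch(GLOBAL_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


# Random toy splits standing in for your real data.
rng = np.random.default_rng(0)
splits = {
    name: (rng.random((n, 4), dtype=np.float32), rng.integers(0, 3, size=n))
    for name, n in (("train", 64), ("val", 16), ("test", 16))
}

# Only the training split is shuffled; every split is then distributed.
train_dist_dataset = strategy.experimental_distribute_dataset(
    make_dataset(*splits["train"], shuffle=True))
val_dist_dataset = strategy.experimental_distribute_dataset(
    make_dataset(*splits["val"], shuffle=False))
test_dist_dataset = strategy.experimental_distribute_dataset(
    make_dataset(*splits["test"], shuffle=False))
```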

Once the datasets are prepared, we can move on to the training loop, which main calls next. The template uses both the train and validation datasets, but you are free to drop the latter.

Train and validation loop

The training loop is at the heart of our distributed algorithm.

We iterate over the datasets for the given number of epochs. In each epoch, we loop over the training data and feed every batch to the distributed_train_step method. After a full pass over the training data, we calculate the training loss and repeat the process for the validation data.

After this is done, we query all (global) metrics and loss objects for their current values, which we print and then reset for the next epoch. As the comments point out, modify this code to account for all your metrics and loss objects. If you do not reset the metrics after each epoch, they accumulate over all epochs seen so far and report running values instead of per-epoch values.
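The effect of resetting is easy to see in a small sketch. It skips the model entirely: hypothetical predictions are fed straight into a metric in place of the train step, all correct in the first epoch and all wrong in the second. With reset_state() after each epoch, the two epochs report 1.00 and 0.00; without it, the second epoch would report the running value 0.50.

```python
import tensorflow as tf

EPOCHS = 2
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

# Hypothetical per-epoch batches: every prediction is right in epoch 1, wrong in epoch 2.
labels = tf.constant([0, 1, 2, 1])
right_logits = tf.one_hot(labels, depth=3)
wrong_logits = tf.one_hot((labels + 1) % 3, depth=3)
batches_per_epoch = [[(labels, right_logits)], [(labels, wrong_logits)]]

history = []
for epoch in range(EPOCHS):
    for batch_labels, batch_logits in batches_per_epoch[epoch]:
        accuracy.update_state(batch_labels, batch_logits)  # stands in for the train step
    history.append(float(accuracy.result()))
    print(f"Epoch {epoch + 1}: {accuracy.name} = {history[-1]:.2f}")
    accuracy.reset_state()  # without this, epoch 2 would report the running value 0.5
```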

Distributed train step

The training loop calls the distributed train step for every batch.

We don’t have much to do here: we get a data batch (dataset_inputs), tell the strategy to run a single train step, and reduce the returned loss. Now, why don’t we call train_step directly? Because we work with distributed data. The strategy.run call accounts for this; it executes train_step on each computing device (replica). So, for example, if we have 5 GPUs connected, TensorFlow automatically calls the train step five times in parallel, once per GPU.
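The sketch below isolates what strategy.run does. It assumes MirroredStrategy and a toy dataset of numbers. The hypothetical per-replica function replica_step prints which replica it runs on and returns its local batch size; summing the per-replica results with strategy.reduce recovers the global batch size. On a CPU-only machine there is a single replica, so the function runs once per batch.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 4 * strategy.num_replicas_in_sync

# Three global batches of toy numbers (illustrative only).
dataset = tf.data.Dataset.from_tensor_slices(tf.range(3 * GLOBAL_BATCH_SIZE, dtype=tf.float32))
dist_dataset = strategy.experimental_distribute_dataset(dataset.batch(GLOBAL_BATCH_SIZE))


def replica_step(local_batch):
    """Runs once per replica; returns the size of the replica's share of the batch."""
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    tf.print("replica", replica_id, "received", tf.size(local_batch), "samples")
    return tf.cast(tf.size(local_batch), tf.float32)


@tf.function
def distributed_step(dataset_inputs):
    per_replica_sizes = strategy.run(replica_step, args=(dataset_inputs,))
    # Aggregate the per-replica values onto the current device.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_sizes, axis=None)


totals = [float(distributed_step(batch)) for batch in dist_dataset]
print(totals)
```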

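We then have to aggregate (think “combine” or “merge”) the train losses from all replicas onto the current device (where we run the script). This is done with strategy.reduce. In the common setup, where each replica’s loss is already scaled by the global batch size (see below), summing the per-replica losses with tf.distribute.ReduceOp.SUM yields the average loss over the whole global batch. Have a look at the documentation of tf.distribute.Strategy.reduce for more information.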

Train step

If we were not distributing our workload, we could directly call the actual train step. However, since we use a distributed setup, we have to let TensorFlow know. We did this with the previous method, which uses strategy.run to handle the distribution. Inside this call, we specified that the actual training method is train_step.

Since TensorFlow splits each global batch across the replicas, train_step receives a single per-replica batch. We first unpack it into features and labels; this is only an example, so adapt the code to your needs. The following lines are a standard custom loop: we calculate the loss, calculate and apply the gradients, and update any training metrics. As marked with the TODO comments, you must adapt this to account for all models, losses, and metrics you use.
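Putting the pieces together, here is a minimal sketch of train_step and distributed_train_step for a toy classification problem. The model, SGD optimizer, loss, metric, and helper compute_loss are placeholders; compute_loss uses tf.nn.compute_average_loss to divide by the global batch size, which is one standard way to implement the scaling discussed in the next section.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    train_loss_object1 = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none")
    train_metric1 = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")


def compute_loss(labels, logits):
    per_example_loss = train_loss_object1(labels, logits)
    # Divide by the global (not the local) batch size.
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)


def train_step(inputs):
    features, labels = inputs  # adapt the unpacking to your data
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = compute_loss(labels, logits)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_metric1.update_state(labels, logits)
    return loss


@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


features = tf.random.uniform((32, 4))
labels = tf.random.uniform((32,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(GLOBAL_BATCH_SIZE)

total_loss, num_batches = 0.0, 0
for batch in strategy.experimental_distribute_dataset(dataset):
    total_loss += float(distributed_train_step(batch))
    num_batches += 1
train_loss = total_loss / num_batches
print(f"loss={train_loss:.4f}, accuracy={float(train_metric1.result()):.2f}")
```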

Compute the (training) loss

To calculate any loss, we use a short template method.

We first query our loss object, which gives us the per-example losses of the current replica. Let me explain why these need special treatment: our custom algorithm runs on multiple computing devices, or replicas. On each device, the train step is called independently, leading to n losses in total. Each loss is used to calculate gradients, which are then synced across the replicas by summing. If we did not scale the loss, the summed gradients would be too large.

Still not convinced? On a single machine, the loss is divided by the number of samples in a mini-batch. In a multi-GPU setup, we have to divide the loss not by the local batch size but by the global batch size. For example, the local batch size might be 8, and the global batch size 32 (= 4 * 8 for four GPUs). If we divided by 8, we would assume that a forward pass saw only 8 samples in total, which is not correct. Therefore, we divide the summed per-example losses by the global batch size.
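The arithmetic can be checked directly. The sketch below simulates the example from the text (four replicas, local batch size 8, global batch size 32) without real devices, using made-up per-example losses 0 to 31. tf.nn.compute_average_loss divides the summed losses by the given global batch size.

```python
import tensorflow as tf

# Simulated setup from the text: four replicas with a local batch size of 8.
NUM_REPLICAS, LOCAL_BATCH_SIZE = 4, 8
GLOBAL_BATCH_SIZE = NUM_REPLICAS * LOCAL_BATCH_SIZE  # 32

# Made-up per-example losses 0, 1, ..., 31, split into one chunk per replica.
per_example_loss = tf.range(GLOBAL_BATCH_SIZE, dtype=tf.float32)
replica_chunks = tf.split(per_example_loss, NUM_REPLICAS)

# Correct: each replica divides by the global batch size; summing across replicas
# then yields the mean over all 32 samples.
correct = sum(tf.nn.compute_average_loss(chunk, global_batch_size=GLOBAL_BATCH_SIZE)
              for chunk in replica_chunks)

# Wrong: dividing by the local batch size makes the sum NUM_REPLICAS times too large.
wrong = sum(tf.reduce_sum(chunk) / LOCAL_BATCH_SIZE for chunk in replica_chunks)

print(float(correct), float(tf.reduce_mean(per_example_loss)), float(wrong))
# prints 15.5 15.5 62.0
```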

That is it for the training procedure. To summarize: we first call distributed_train_step, which distributes the calculation to each replica. Then, on each replica, train_step is called, which takes a single batch and updates the model’s weights.

Validation step

The validation procedure is highly similar to the training procedure.

There are only slight differences: we don’t update the model’s weights, and we update the validation metrics instead of the training metrics. Modify these parts to suit your needs. The distributed_validation_step follows the same layout as distributed_train_step: get the scaled per-replica losses and aggregate them.
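A sketch of validation_step and distributed_validation_step follows, under toy assumptions: MirroredStrategy, a one-layer model, and the hypothetical names val_loss_object1 and val_metric1. There is no GradientTape and no apply_gradients call, which the final check confirms by comparing the weights before and after validation.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8 * strategy.num_replicas_in_sync

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
    val_loss_object1 = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction="none")
    val_metric1 = tf.keras.metrics.SparseCategoricalAccuracy(name="val_accuracy")


def validation_step(inputs):
    features, labels = inputs
    logits = model(features, training=False)  # forward pass only: no tape, no weight update
    per_example_loss = val_loss_object1(labels, logits)
    val_metric1.update_state(labels, logits)  # validation metrics instead of training metrics
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)


@tf.function
def distributed_validation_step(dataset_inputs):
    per_replica_losses = strategy.run(validation_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


weights_before = [w.numpy().copy() for w in model.trainable_variables]
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((16, 4)), tf.random.uniform((16,), maxval=3, dtype=tf.int32))
).batch(GLOBAL_BATCH_SIZE)
val_losses = [float(distributed_validation_step(b))
              for b in strategy.experimental_distribute_dataset(dataset)]
weights_unchanged = all((w.numpy() == before).all()
                        for w, before in zip(model.trainable_variables, weights_before))
print(val_losses, float(val_metric1.result()), weights_unchanged)
```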

Test loop

The test loop is similar to the previous two.

We iterate over the distributed test dataset, sum the losses, and average them. After that, we collect the results of all metrics. I have not included any separate test metrics; modify this part to suit your needs.

Looking at distributed_test_step, we notice that it is once again highly similar.

It collects the scaled per-replica losses and aggregates them. The actual test_step takes a single batch, unpacks it, and calculates the loss and any test metrics. In this snippet, I have not included any separate test metrics; you have to adapt it to fit your problem. Generally, you can create all metrics in the create_metrics method. Remember to register them globally so you can access them easily.


We have covered the essential parts of a custom and distributed algorithm. We can use it in all kinds of computing environments by choosing an appropriate distribution strategy, including single-GPU setups. In general, all places that have to be modified are marked with TODO comments. While the template should cover most use cases, it serves as a starting point only: practitioners like you and me should adapt the crucial parts to meet their requirements. The complete code is available on GitHub.