" MicromOne: How to Write a Training Loop in PyTorch: A Practical Guide

Pagine

How to Write a Training Loop in PyTorch: A Practical Guide

Getting Started: Model, Data, and Training Function

To keep our code clean and reusable, we wrap the training loop inside a function called train_model.
This function takes the following parameters:

  • the model

  • the number of epochs (with a default value of 8)

  • the learning rate (default set to 0.05)

Inside this function, we can either define the optimizer and loss function directly or pass them in as parameters. In this example, we use:

  • Stochastic Gradient Descent (SGD) with momentum = 0.9

  • Cross Entropy Loss, implemented with nn.CrossEntropyLoss

This setup is commonly used for classification tasks.
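
Here is a minimal sketch of that setup, assuming the rest of the notebook provides the model class and the data loaders. The signature and hyperparameters come from the description above; the name criterion for the loss function is our own choice, and the body is filled in step by step in the sections below.

  import torch
  import torch.nn as nn
  import torch.optim as optim

  def train_model(model, epochs=8, lr=0.05):
      # Optimizer and loss are defined inside the function here,
      # but they could just as well be passed in as parameters.
      optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
      criterion = nn.CrossEntropyLoss()

      # Move the model to the GPU once, if one is available.
      use_cuda = torch.cuda.is_available()
      if use_cuda:
          model = model.cuda()

      # The per-epoch training and validation logic is filled in below.
      ...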

The Core of the Training Loop

Now let’s dive into the most important part: the training loop.

For each epoch (a complete code sketch of these steps follows the list):

  1. Set the model to training mode using model.train()
    This puts layers such as Dropout and Batch Normalization into their training behavior (dropout is active and batch statistics are updated); it does not affect gradient computation.

  2. Initialize tracking variables

    • epoch loss

    • number of correct predictions

  3. Iterate over the training data
    We use enumerate(train_loader) to keep track of the number of mini-batches processed. This is especially useful when working with large datasets or long training times, as it allows us to monitor running loss during each epoch.

  4. GPU handling
    We check whether a GPU is available with torch.cuda.is_available().
    If so, we move both inputs and labels to the GPU using the .cuda() method; the model itself must also be moved to the GPU, which we do once before the loop starts.

  5. Forward pass

    • reset the gradients accumulated from the previous batch with optimizer.zero_grad() (PyTorch accumulates gradients by default)

    • pass the inputs through the model to obtain outputs

  6. Loss computation
    We compute the loss by passing the model outputs and the labels to the loss function (nn.CrossEntropyLoss).

  7. Backward pass and parameter update

    • compute gradients using loss.backward()

    • update model parameters with optimizer.step()

  8. Accuracy calculation
    We obtain predictions with torch.max(outputs, 1), which returns the index of the highest-scoring class for each sample, and compare them with the ground-truth labels to count the number of correct predictions.
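
Putting steps 1 through 8 together, the body of train_model might continue as follows. Here train_loader is an assumed name for the training DataLoader, and the periodic running-loss printout is illustrative rather than prescribed:

      for epoch in range(epochs):
          model.train()                   # step 1: training mode
          running_loss, correct = 0.0, 0  # step 2: tracking variables

          # step 3: iterate over mini-batches, keeping the batch index i
          for i, (inputs, labels) in enumerate(train_loader):
              if use_cuda:                # step 4: move the batch to the GPU
                  inputs, labels = inputs.cuda(), labels.cuda()

              optimizer.zero_grad()       # step 5: reset gradients, then
              outputs = model(inputs)     #         run the forward pass
              loss = criterion(outputs, labels)  # step 6: loss computation

              loss.backward()             # step 7: backward pass and
              optimizer.step()            #         parameter update

              running_loss += loss.item()
              _, predicted = torch.max(outputs, 1)  # step 8: accuracy count
              correct += (predicted == labels).sum().item()

              if (i + 1) % 200 == 0:      # optional: monitor the running loss
                  print(f"  batch {i + 1}: avg loss {running_loss / (i + 1):.4f}")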

Model Validation

After completing the training phase for an epoch, we move on to validation (sketched in code after this list):

  • initialize validation loss and correct prediction counters

  • set the model to evaluation mode with model.eval() and wrap the forward passes in torch.no_grad()
    model.eval() switches layers such as Dropout and Batch Normalization to inference behavior, while torch.no_grad() is what actually disables gradient tracking, saving memory and computation

  • perform a forward pass on the validation dataset

  • compute validation loss and accuracy in the same way as during training

Finally, we print the training and validation metrics for the epoch.
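
Continuing inside the epoch loop, the validation phase and the final printout might look like this; val_loader is an assumed name for the validation DataLoader:

          model.eval()            # inference behavior for Dropout/BatchNorm
          val_loss, val_correct = 0.0, 0

          with torch.no_grad():   # gradient tracking off during validation
              for inputs, labels in val_loader:
                  if use_cuda:
                      inputs, labels = inputs.cuda(), labels.cuda()
                  outputs = model(inputs)
                  val_loss += criterion(outputs, labels).item()
                  _, predicted = torch.max(outputs, 1)
                  val_correct += (predicted == labels).sum().item()

          print(f"Epoch {epoch + 1}/{epochs} | "
                f"train loss {running_loss / len(train_loader):.4f}, "
                f"train acc {correct / len(train_loader.dataset):.2%} | "
                f"val loss {val_loss / len(val_loader):.4f}, "
                f"val acc {val_correct / len(val_loader.dataset):.2%}")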

Running the Training Process

At this point, all that’s left to do is call the train_model function and pass in the instantiated model (for example, net).
Depending on the dataset size and model complexity, training may take some time.
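
Assuming a model class was defined earlier in the notebook (called Net here purely for illustration), kicking off training is a two-liner:

  net = Net()        # Net stands in for the model class defined earlier
  train_model(net)   # uses the defaults: 8 epochs, learning rate 0.05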