
Loss Reduction in Regression

Choosing between sum and mean reduction in PyTorch affects gradient magnitude, learning stability, and batch size dependence.


Loss Reduction in PyTorch: Sum vs. Mean

PyTorch’s loss functions provide three options for reduction:
  • None (reduction='none') – Returns individual loss values without any reduction.
  • Sum (reduction='sum') – Computes the total sum of all loss values.
  • Mean (reduction='mean') – Averages the loss values.

While sum aggregates all loss values and mean normalizes them, the choice between these two affects how gradients behave during training.
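
The three modes are directly related: 'sum' and 'mean' are simply the sum and the mean of the per-element losses that 'none' returns. A minimal sketch (tensor shapes chosen arbitrarily for illustration):

import torch
import torch.nn as nn

pred = torch.randn(4, 3)
target = torch.randn(4, 3)

per_element = nn.MSELoss(reduction='none')(pred, target)   # shape (4, 3)
total = nn.MSELoss(reduction='sum')(pred, target)          # scalar
average = nn.MSELoss(reduction='mean')(pred, target)       # scalar

print(per_element.shape)                          # torch.Size([4, 3])
print(total.item(), per_element.sum().item())     # equal up to float precision
print(average.item(), per_element.mean().item())  # equal up to float precision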

Impact on Training

1. Gradient Magnitude Differences

Using sum produces a loss that scales with the number of elements in the batch, leading to proportionally larger gradients. Conversely, mean normalizes the loss by the element count, keeping gradient magnitudes consistent.

Example: Comparing Gradients for Sum vs. Mean

import torch
import torch.nn as nn

# Define a simple model
model = nn.Linear(10, 10)
x = torch.randn(10, 10)
y = torch.randn(10, 10)

# Compute loss with mean reduction
criterion = nn.MSELoss(reduction='mean')
out = model(x)
loss = criterion(out, y)
loss.backward()
print("Gradient sum with mean:", model.weight.grad.abs().sum().item())

# Reset gradients
model.zero_grad()

# Compute loss with sum reduction
criterion = nn.MSELoss(reduction='sum')
out = model(x)
loss = criterion(out, y)
loss.backward()
print("Gradient sum with sum:", model.weight.grad.abs().sum().item())

Expected Output (approximate values):

Gradient sum with mean: 5.61
Gradient sum with sum: 561.42

The gradient with sum is roughly 100× larger here because MSELoss with reduction='mean' divides by the number of elements (10 × 10 = 100 in this example). If sum is used, the learning rate typically needs to be scaled down to compensate.
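
One way to compensate, assuming plain SGD (adaptive optimizers such as Adam largely rescale gradients on their own), is to divide the learning rate by the same factor the loss was scaled up by. A rough sketch, reusing model and y from the example above with a hypothetical base learning rate:

# Gradients under 'sum' are numel times larger than under 'mean' for MSELoss,
# so dividing the learning rate by numel gives roughly equivalent SGD updates.
numel = y.numel()          # 10 * 10 = 100 in the example above
lr_mean = 1e-3             # hypothetical learning rate tuned for reduction='mean'
lr_sum = lr_mean / numel   # roughly equivalent learning rate for reduction='sum'

optimizer = torch.optim.SGD(model.parameters(), lr=lr_sum)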


2. Dependency on Batch Size

With sum, the loss grows in proportion to the batch size, so changing the batch size changes the effective gradient scale and, with it, training dynamics. Mean avoids this by normalizing the loss.

Example: Effect of Batch Size on Loss

batch_size_10 = torch.randn(10, 5)
batch_size_20 = torch.randn(20, 5)

# Compute loss with sum reduction
loss_fn = nn.L1Loss(reduction='sum')
loss_10 = loss_fn(batch_size_10, torch.zeros_like(batch_size_10))
loss_20 = loss_fn(batch_size_20, torch.zeros_like(batch_size_20))

print("Loss with batch size 10 (sum reduction):", loss_10.item())
print("Loss with batch size 20 (sum reduction):", loss_20.item())

loss_fn = nn.L1Loss(reduction='mean')
loss_10 = loss_fn(batch_size_10, torch.zeros_like(batch_size_10))
loss_20 = loss_fn(batch_size_20, torch.zeros_like(batch_size_20))

print("Loss with batch size 10 (mean reduction):", loss_10.item())
print("Loss with batch size 20 (mean reduction):", loss_20.item())

Since sum aggregates loss values, doubling the batch size results in approximately twice the loss. However, using mean keeps the loss value consistent across different batch sizes.

Loss with batch size 10 (sum reduction): 42.75994110107422
Loss with batch size 20 (sum reduction): 84.79824829101562
Loss with batch size 10 (mean reduction): 0.8551988005638123
Loss with batch size 20 (mean reduction): 0.8479824662208557

3. Custom Reduction with None

Using reduction='none' allows manual normalization, offering flexibility for custom loss scaling.

Example: Custom Reduction

loss_fn = nn.MSELoss(reduction='none')
loss_values = loss_fn(x, y)  # No reduction applied
custom_loss = loss_values.sum() / x.shape[0]  # Normalize by batch size

print("Shape of loss_values:", loss_values.shape)
print("Custom reduced loss:", custom_loss.item())

This approach is beneficial when implementing specialized loss functions or adjusting loss scaling dynamically. With reduction set to 'none', you get individual loss values per sample, allowing for custom aggregation and weighting strategies.

Shape of loss_values: torch.Size([10, 10])
Custom reduced loss: 15.885004997253418
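
One such weighting strategy is per-sample reweighting before aggregation. The sketch below is illustrative only: sample_weights is an assumed, made-up weighting (in practice it might come from class imbalance or sample importance), and x and y are reused from the example above.

loss_fn = nn.MSELoss(reduction='none')
per_element = loss_fn(x, y)                  # shape (10, 10): one loss per element
per_sample = per_element.mean(dim=1)         # shape (10,): one loss per sample

sample_weights = torch.linspace(0.5, 1.5, steps=10)   # hypothetical per-sample weights
weighted_loss = (per_sample * sample_weights).sum() / sample_weights.sum()

print("Weighted loss:", weighted_loss.item())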



Choosing the Right Reduction Mode

  • Use mean for stable training and batch size-independent loss values.
  • Use sum when the loss should scale with the number of elements, such as in segmentation tasks where every pixel should contribute to the total.
  • Use none for full control over loss computation and custom normalization.