
Callbacks Guide

This guide explains how to use and create callbacks in Dream Trainer.

Table of Contents

- Basic Usage
- Built-in Callbacks
- Creating Callbacks
- Callback Collection
- Best Practices
- Next Steps

Basic Usage

Callbacks are a way to extend the trainer's functionality without modifying its code. They are called at specific points during training.

Adding Callbacks

Add callbacks to your trainer configuration:

from dream_trainer import DreamTrainerConfig
from dream_trainer.callbacks import (
    LoggerCallback,
    ProgressBar,
    CallbackCollection
)

config = DreamTrainerConfig(
    # ... other settings ...
    callbacks=CallbackCollection([
        LoggerCallback(),  # Logs metrics to console/WandB
        ProgressBar(),    # Shows training progress
    ])
)

Callback Order

Callbacks execute in the order they appear in the collection, so arranging the list controls execution order:

callbacks = CallbackCollection([
    LoggerCallback(),     # First: log metrics
    ProgressBar(),       # Second: show progress
    CheckpointCallback() # Third: save checkpoints
])
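Internally, a collection can simply forward each hook to its callbacks in sequence. The following is a minimal sketch of that dispatch pattern (the actual CallbackCollection implementation may differ):

class SimpleCallbackCollection:
    """Illustrative only: forwards each hook call to every callback in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        # Dispatch in insertion order, so earlier callbacks run first
        for callback in self.callbacks:
            callback.on_train_batch_end(trainer, outputs, batch, batch_idx)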

Built-in Callbacks

LoggerCallback

Logs metrics to console and/or WandB:

from dream_trainer.callbacks import LoggerCallback

logger = LoggerCallback(
    log_every_n_steps=100,  # Log every 100 steps
    log_every_n_epochs=1,   # Log every epoch
    log_metrics=True,       # Log metrics
    log_gradients=False,    # Don't log gradients
    log_parameters=False    # Don't log parameters
)

ProgressBar

Shows training progress:

from dream_trainer.callbacks import ProgressBar

progress = ProgressBar(
    refresh_rate=10,        # Update every 10 steps
    show_epoch=True,        # Show epoch number
    show_step=True,         # Show step number
    show_metrics=True       # Show metrics
)

CheckpointCallback

Saves model checkpoints:

from dream_trainer.callbacks import CheckpointCallback

checkpoint = CheckpointCallback(
    monitor="val_loss",     # Metric to monitor
    mode="min",            # Minimize metric
    save_top_k=3,          # Keep best 3 checkpoints
    save_last=True,        # Always save latest
    every_n_epochs=1       # Save every epoch
)

EarlyStoppingCallback

Stops training when a monitored metric stops improving:

from dream_trainer.callbacks import EarlyStoppingCallback

early_stopping = EarlyStoppingCallback(
    monitor="val_loss",     # Metric to monitor
    mode="min",            # Minimize metric
    patience=5,            # Wait 5 epochs
    min_delta=0.001        # Minimum change
)

LearningRateMonitor

Logs learning rate changes:

from dream_trainer.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(
    logging_interval="step",  # Log every step
    log_momentum=True        # Log momentum too
)
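The built-in callbacks compose like any others. For example, a typical run might combine several of them in one collection:

from dream_trainer.callbacks import (
    CallbackCollection,
    LoggerCallback,
    ProgressBar,
    CheckpointCallback,
    EarlyStoppingCallback,
    LearningRateMonitor,
)

callbacks = CallbackCollection([
    LoggerCallback(log_every_n_steps=100),
    ProgressBar(refresh_rate=10),
    CheckpointCallback(monitor="val_loss", mode="min", save_top_k=3),
    EarlyStoppingCallback(monitor="val_loss", mode="min", patience=5),
    LearningRateMonitor(logging_interval="step"),
])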

Creating Callbacks

Basic Callback

Create a custom callback by extending Callback:

from dream_trainer.callbacks import Callback

class MyCallback(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        """Called after each training batch"""
        # Access trainer state
        current_epoch = trainer.current_epoch
        current_step = trainer.current_step

        # Access outputs
        loss = outputs["loss"]

        # Do something
        if loss > 10.0:
            print(f"High loss detected: {loss}")
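Then register it like any built-in callback:

from dream_trainer.callbacks import CallbackCollection, LoggerCallback

callbacks = CallbackCollection([
    LoggerCallback(),
    MyCallback(),
])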

Training Hooks

Available training hooks:

class MyCallback(Callback):
    def on_train_start(self, trainer):
        """Called when training starts"""
        pass

    def on_train_epoch_start(self, trainer):
        """Called at the start of each training epoch"""
        pass

    def on_train_batch_start(self, trainer, batch, batch_idx):
        """Called before each training batch"""
        pass

    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        """Called after each training batch"""
        pass

    def on_train_epoch_end(self, trainer):
        """Called at the end of each training epoch"""
        pass

    def on_train_end(self, trainer):
        """Called when training ends"""
        pass
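For example, the epoch hooks pair naturally for timing. A minimal sketch, using the trainer.log method shown later in this guide:

import time

class EpochTimer(Callback):
    """Logs how long each training epoch takes."""

    def on_train_epoch_start(self, trainer):
        self._epoch_start = time.monotonic()

    def on_train_epoch_end(self, trainer):
        elapsed = time.monotonic() - self._epoch_start
        trainer.log("epoch_seconds", elapsed)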

Validation Hooks

Available validation hooks:

class MyCallback(Callback):
    def on_validation_start(self, trainer):
        """Called when validation starts"""
        pass

    def on_validation_epoch_start(self, trainer):
        """Called at the start of each validation epoch"""
        pass

    def on_validation_batch_start(self, trainer, batch, batch_idx):
        """Called before each validation batch"""
        pass

    def on_validation_batch_end(self, trainer, outputs, batch, batch_idx):
        """Called after each validation batch"""
        pass

    def on_validation_epoch_end(self, trainer):
        """Called at the end of each validation epoch"""
        pass

    def on_validation_end(self, trainer):
        """Called when validation ends"""
        pass
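The batch and epoch hooks pair the same way on the validation side. A sketch that accumulates per-batch loss and logs the mean once per epoch (it assumes outputs contains a "loss" entry, as in the examples above):

class ValidationLossTracker(Callback):
    """Averages validation loss across batches and logs it once per epoch."""

    def on_validation_epoch_start(self, trainer):
        self._losses = []

    def on_validation_batch_end(self, trainer, outputs, batch, batch_idx):
        self._losses.append(float(outputs["loss"]))

    def on_validation_epoch_end(self, trainer):
        if self._losses:
            trainer.log("val_loss_mean", sum(self._losses) / len(self._losses))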

State Management

Callbacks can maintain their own state:

class StatefulCallback(Callback):
    def __init__(self):
        super().__init__()
        self.best_metric = float('inf')
        self.patience_counter = 0

    def on_validation_epoch_end(self, trainer):
        # Get current metric
        current_metric = trainer.get_metric("val_loss")

        # Update state
        if current_metric < self.best_metric:
            self.best_metric = current_metric
            self.patience_counter = 0
        else:
            self.patience_counter += 1

        # Check patience
        if self.patience_counter >= 5:
            trainer.should_stop = True

Accessing Trainer

Callbacks have access to the trainer instance:

class TrainerAwareCallback(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        # Access trainer attributes
        model = trainer.model
        optimizer = trainer.optimizer
        current_epoch = trainer.current_epoch

        # Access trainer methods
        trainer.log("custom_metric", 42)
        trainer.save_checkpoint("path/to/checkpoint.pt")
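For example, access to trainer.model makes it easy to monitor gradient norms. A sketch, assuming gradients are still populated when the hook runs:

import torch

class GradientNormLogger(Callback):
    """Logs the global L2 norm of all parameter gradients."""

    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        grads = [p.grad for p in trainer.model.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
            trainer.log("grad_norm", total_norm.item())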

Callback Collection

Adding Callbacks

Add callbacks to a collection:

from dream_trainer.callbacks import CallbackCollection

callbacks = CallbackCollection([
    LoggerCallback(),
    ProgressBar(),
    MyCustomCallback()
])

Removing Callbacks

Remove callbacks from a collection:

# Remove by type
callbacks.remove(LoggerCallback)

# Remove by instance
callbacks.remove(my_callback)

Reordering Callbacks

Change callback order:

# Move to front
callbacks.move_to_front(my_callback)

# Move to back
callbacks.move_to_back(my_callback)

# Move to specific position
callbacks.move_to_position(my_callback, 2)

Best Practices

1. Keep Callbacks Focused

Each callback should do one thing well:

# Good: Single responsibility
class LossMonitor(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        if outputs["loss"] > 10.0:
            print("High loss detected")

# Bad: Multiple responsibilities
class BadCallback(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        # Monitoring
        if outputs["loss"] > 10.0:
            print("High loss detected")
        # Logging
        trainer.log("custom_metric", 42)
        # Checkpointing
        trainer.save_checkpoint("checkpoint.pt")

2. Use Type Hints

Add type hints for better IDE support:

from typing import Dict

import torch

from dream_trainer.callbacks import Callback

class TypedCallback(Callback):
    def on_train_batch_end(
        self,
        trainer: "DreamTrainer",
        outputs: Dict[str, torch.Tensor],
        batch: torch.Tensor,
        batch_idx: int
    ) -> None:
        pass

3. Document Callbacks

Add docstrings to explain functionality:

class DocumentedCallback(Callback):
    """Monitors training metrics and logs warnings.

    This callback watches for:
    - High loss values
    - NaN gradients
    - Learning rate spikes

    Args:
        loss_threshold: Threshold for high loss warning
        lr_threshold: Threshold for learning rate warning
    """

    def __init__(self, loss_threshold: float = 10.0, lr_threshold: float = 1e-2):
        super().__init__()
        self.loss_threshold = loss_threshold
        self.lr_threshold = lr_threshold

4. Handle Errors

Add proper error handling:

class ErrorHandlingCallback(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        try:
            # Risky operation
            self.process_outputs(outputs)
        except Exception as e:
            # Log error but don't crash
            trainer.log("callback_error", str(e))

5. Test Callbacks

Write unit tests for your callbacks. A lightweight mock trainer lets you exercise hooks in isolation:

import torch

class MockTrainer:
    """Minimal stand-in that records logged metrics."""

    def __init__(self):
        self.logged_metrics = {}
        self.model = None
        self.optimizer = None
        self.current_epoch = 0

    def log(self, name, value):
        self.logged_metrics[name] = value

    def save_checkpoint(self, path):
        pass

def test_trainer_aware_callback():
    trainer = MockTrainer()
    callback = TrainerAwareCallback()

    # Exercise the hook with fake data
    callback.on_train_batch_end(
        trainer,
        outputs={"loss": torch.tensor(5.0)},
        batch=torch.randn(32, 10),
        batch_idx=0,
    )

    # The assertion matches what TrainerAwareCallback logs above
    assert trainer.logged_metrics["custom_metric"] == 42

6. Use Callback Priority

Set callback priority for execution order:

class HighPriorityCallback(Callback):
    priority = 100  # Higher number = earlier execution

class LowPriorityCallback(Callback):
    priority = 0    # Lower number = later execution

7. Avoid Side Effects

Minimize side effects in callbacks:

class CleanCallback(Callback):
    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        # Good: Only logging
        trainer.log("metric", outputs["loss"])

        # Bad: Modifying trainer state
        trainer.model.requires_grad_(False)  # Don't do this

8. Use Callback Groups

Group related callbacks:

class MonitoringGroup(Callback):
    """Group of monitoring callbacks"""

    def __init__(self):
        super().__init__()
        self.callbacks = [
            LossMonitor(),          # defined earlier in this guide
            GradientMonitor(),      # hypothetical; any Callback works here
            LearningRateMonitor(),
        ]

    def on_train_batch_end(self, trainer, outputs, batch, batch_idx):
        for callback in self.callbacks:
            callback.on_train_batch_end(trainer, outputs, batch, batch_idx)
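A group then registers like any single callback:

callbacks = CallbackCollection([MonitoringGroup(), ProgressBar()])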

Next Steps