Parallelism Guide

This guide explains the different types of parallelism supported by Dream Trainer and how to use them effectively.

For a complete example, see dream-trainer/examples/llama3/setup.py.

Table of Contents

  • Overview
  • Data Parallelism
  • FSDP
  • Tensor Parallelism
  • Context Parallelism
  • Pipeline Parallelism
  • Combining Parallelism
  • Common Issues
  • Next Steps

Overview

Dream Trainer is built around PyTorch's DTensor abstractions, providing a unified interface for all parallelism schemes. Each type of parallelism serves a specific purpose:

  • Data Parallelism: Scale training across multiple GPUs by replicating the model
  • FSDP2: Second-generation Fully-Sharded Data Parallel built on DTensor
  • Tensor Parallelism: Split model parameters across GPUs for larger models
  • Context Parallelism: Handle long sequences by splitting across GPUs
  • Pipeline Parallelism: Split model layers across GPUs for efficient memory usage

Data Parallelism

Data Parallelism is the simplest form of parallelism, where the model is replicated across GPUs and each GPU processes a different batch of data.

Using PyTorch's Replicate API

from torch.distributed._composable.replicate import replicate  # composable DDP wrapper
from torch.distributed.device_mesh import DeviceMesh

from dream_trainer import DreamTrainerConfig
from dream_trainer.configs import DeviceParameters
from dream_trainer.trainer.mixins import ModelSetupMixin

config = DreamTrainerConfig(
    device_parameters=DeviceParameters.DDP(
        compile_model=True,
        checkpoint_activations=False,
    )
)

class MyTrainer(ModelSetupMixin):
    def apply_replicate(self, dp_replicate_mesh: DeviceMesh):
        # Wraps the model in place with Distributed Data Parallel
        replicate(self.model, device_mesh=dp_replicate_mesh, bucket_cap_mb=100)

Key Features

  • Simple Setup: Just specify the number of GPUs
  • Near-Linear Scaling: Training throughput scales close to linearly with GPU count
  • Memory Cost: Each GPU holds a complete copy of the model, so per-GPU memory is not reduced
  • Gradient Synchronization: Automatic gradient averaging across GPUs

Best Practices

  1. Use when:

     • The model fits in GPU memory
     • The batch size can be increased
     • Training speed is the priority

  2. Avoid when:

     • The model is too large for a single GPU
     • Memory efficiency is critical
     • A more advanced parallelism scheme is needed

FSDP

Fully Sharded Data Parallel (FSDP) reduces memory usage by sharding model parameters across GPUs.

Using FSDP2 API

from typing import Any

from torch.distributed.fsdp import fully_shard  # FSDP2 API

from dream_trainer import DreamTrainerConfig
from dream_trainer.configs import DeviceParameters
from dream_trainer.trainer.mixins import ModelSetupMixin

config = DreamTrainerConfig(
    device_parameters=DeviceParameters.FSDP(
        tensor_parallel=1,
        dp_shard="auto",
        compile_model=True,
        cpu_offload=False,
        checkpoint_activations=False,
    )
)

class MyTrainer(ModelSetupMixin):
    def apply_fully_shard(self, config: dict[str, Any]) -> None:
        # NOTE: if using Pipeline Parallelism, make sure to set reshard_after_forward=False on all layers for optimal performance.
        for layer in self.model.layers:
            fully_shard(layer, **config)
        fully_shard(self.model, **config, reshard_after_forward=False)

Alternatively, we can define the sharding strategy directly on the model with fsdp2-utils for simpler usage: apply_fully_shard will recursively call the fully_shard method of every submodule of the model that conforms to the FullyShard protocol.

Note:
All of the model's layers need to be wrapped with fully_shard so that inputs and parameters are properly cast to the dtype/device specified by the MixedPrecisionPolicy. This casting and device movement is handled internally by FSDP.
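
The config dict passed to apply_fully_shard simply forwards FSDP2 keyword arguments. As a rough sketch of what it might contain (the exact keys Dream Trainer builds are an assumption here, but mesh and mp_policy are the standard fully_shard arguments):

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy  # torch.distributed._composable.fsdp in older PyTorch


def make_fsdp_kwargs(dp_mesh: DeviceMesh) -> dict:
    # Hypothetical contents of the `config` dict forwarded to fully_shard:
    # parameters are computed in bf16 while gradient reductions stay in fp32.
    return {
        "mesh": dp_mesh,
        "mp_policy": MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        ),
    }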

import torch.nn as nn
from typing import Any

from torch.distributed.fsdp import fully_shard
from fsdp2_utils import apply_fully_shard, FullyShard

class TransformerBlock(nn.Module, FullyShard):
    attention: Attention
    feed_forward: FeedForward
    attention_norm: nn.RMSNorm
    ffn_norm: nn.RMSNorm

    def fully_shard(self, config: dict[str, Any]):
        fully_shard(self.attention, **config)
        fully_shard(self.feed_forward, **config)
        fully_shard(self.attention_norm, **config)
        fully_shard(self.ffn_norm, **config)

class Transformer(nn.Module, FullyShard):
    input: nn.Linear
    layers: nn.ModuleList
    output: nn.Linear

    def fully_shard(self, config: dict[str, Any]):
        fully_shard(self.input, **config)
        fully_shard(self.layers, **config)
        fully_shard(self.output, **config, reshard_after_forward=False)

class MyTrainer(ModelSetupMixin):
    def apply_fully_shard(self, config: dict[str, Any]) -> None:
        # apply_fully_shard will override reshard_after_forward to `False` for all blocks when Pipeline Parallelism is enabled.
        apply_fully_shard(self.model, config, pp_enabled=self.world.pp_enabled)

Key Features

  • Memory Efficiency: Parameters are sharded across GPUs
  • Mixed Precision: Native support for FP16/BF16 mixed precision training
  • Gradient Sharding: Reduces memory during backward pass

Best Practices

  1. Use when:

     • The model is too large for a single GPU
     • Memory efficiency is important
     • Some training speed can be traded for memory savings

  2. Configuration Tips:

     • Choose the sharding strategy based on memory constraints
     • Enable mixed precision for better performance
     • Use activation checkpointing for very large models

Read more about Fully Sharded Data Parallel (FSDP) in the PyTorch documentation.

Tensor Parallelism

Tensor Parallelism splits model parameters across GPUs, allowing for even larger models. Again, we'll use fsdp2-utils to simplify how we apply tensor parallelism.

Configuration

import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
from fsdp2_utils import (  # parallel-plan helpers provided by fsdp2-utils
    FullyShard,
    ParallelPlan,
    apply_tensor_parallel,
    colwise_parallel,
    prepare_module_input,
    rowwise_parallel,
    sequence_parallel,
)

from dream_trainer import DreamTrainerConfig
from dream_trainer.configs import DeviceParameters
from dream_trainer.trainer.mixins import ModelSetupMixin

config = DreamTrainerConfig(
    device_parameters=DeviceParameters.FSDP(
        tensor_parallel="auto",
        dp_shard=1,    # no FSDP
        compile_model=True,
        cpu_offload=False,
        checkpoint_activations=False,
    )
)

class TransformerBlockParallel(ParallelPlan):
    attention_norm: nn.RMSNorm
    attention: "Attention"
    feed_forward: "FeedForward"
    ffn_norm: nn.RMSNorm

    def parallel_plan(self, _):
        return {
            "attention_norm": sequence_parallel(self.attention_norm),
            "attention": prepare_module_input(
                self.attention,
                input_layouts=(Shard(1), None),
                desired_input_layouts=(Replicate(), None),
            ),
            "attention.wq": colwise_parallel(self.attention.wq),
            "attention.wk": colwise_parallel(self.attention.wk),
            "attention.wv": colwise_parallel(self.attention.wv),
            "attention.wo": rowwise_parallel(self.attention.wo, output_layouts=Shard(1)),
            "ffn_norm": sequence_parallel(self.ffn_norm),
            "feed_forward": prepare_module_input(
                self.feed_forward,
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(self.feed_forward.w1),
            "feed_forward.w2": rowwise_parallel(self.feed_forward.w2, output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(self.feed_forward.w3),
        }


class TransformerParallel(FullyShard, ParallelPlan):
    tok_embeddings: nn.Embedding
    norm: nn.RMSNorm
    output: nn.Linear
    layers: nn.ModuleDict

    def parallel_plan(self, loss_parallel: bool):
        return {
            "tok_embeddings": rowwise_parallel(
                self.tok_embeddings,
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": sequence_parallel(self.norm),
            "output": colwise_parallel(
                self.output,
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        }

class MyTrainer(ModelSetupMixin):
    def apply_tensor_parallel(self, tp_mesh: DeviceMesh) -> None:
        apply_tensor_parallel(self.model, tp_mesh=tp_mesh, loss_parallel=self.world.loss_parallel_enabled)

Note:
Using fsdp2-utils greatly simplifies the construction of a parallel plan.

Without fsdp2-utils, you would need to manually build a parallel plan using PyTorch's classes like ColwiseParallel or RowwiseParallel for each layer. This process can become complex, especially if you want to use features like fp8 quantization, which would require using Fp8ColwiseParallel or similar classes for the affected layers.

With fsdp2-utils, you only need to define a parallel_plan function for your model or block. The utility will automatically generate the correct plan at runtime, choosing the appropriate parallelization strategy (including fp8 support) for each layer.
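
For comparison, a hand-written plan for a single block using PyTorch's tensor-parallel classes directly would look roughly like this (a sketch only; layer names follow the example model above, and this is not Dream Trainer code):

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def manually_parallelize_block(block, tp_mesh: DeviceMesh) -> None:
    # Hand-written equivalent of the block-level parallel_plan shown above.
    plan = {
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
    }
    parallelize_module(block, tp_mesh, plan)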

Key Features

  • Parameter Sharding: Split large tensors across GPUs
  • Communication Efficiency: Minimizes cross-GPU communication
  • Flexible Sharding: Choose which dimensions to split

Best Practices

  1. Use when:

     • The model has large parameter tensors
     • You need more memory efficiency than FSDP alone
     • You want to combine it with other parallelism

  2. Configuration Tips:

     • Choose the parallel dimension carefully
     • Consider communication overhead
     • Use with FSDP for maximum memory efficiency

Context Parallelism

Context Parallelism splits sequences across GPUs, useful for long-context models.
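
Conceptually, each rank keeps only its contiguous slice of the sequence dimension, and the attention kernels exchange keys/values across ranks as needed. A framework-agnostic sketch of the splitting step (illustration only, not Dream Trainer's internal implementation):

import torch


def split_sequence_for_rank(batch: torch.Tensor, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # batch: (batch_size, seq_len, hidden_dim); seq_len must divide evenly by cp_world_size.
    chunks = batch.chunk(cp_world_size, dim=1)  # split along the sequence dimension
    return chunks[cp_rank].contiguous()


# Example: a 32k-token sequence split across 4 context-parallel ranks -> 8k tokens per rank.
x = torch.randn(2, 32_768, 1_024)
local_x = split_sequence_for_rank(x, cp_rank=0, cp_world_size=4)
assert local_x.shape == (2, 8_192, 1_024)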

Configuration

from dream_trainer import DreamTrainerConfig
from dream_trainer.configs import DeviceParameters

config = DreamTrainerConfig(
    device_parameters=DeviceParameters.FSDP(
        context_parallel=2,  # split sequences across 2 GPUs (parameter name assumed; check your DeviceParameters)
        dp_shard="auto",
    )
)

Key Features

  • Sequence Splitting: Distribute long sequences across GPUs
  • Efficient Attention: Optimized attention computation
  • Overlap Support: Optional computation overlap
  • Memory Efficiency: Reduces memory per GPU

Best Practices

  1. Use when:

     • Working with long sequences
     • Attention computation is memory-intensive
     • You need to process longer contexts

  2. Configuration Tips:

     • Choose an appropriate split dimension
     • Enable overlap for better performance
     • Consider communication overhead

Pipeline Parallelism

Pipeline Parallelism splits model layers across GPUs, enabling efficient memory usage.

Configuration

from dream_trainer import DreamTrainer, DreamTrainerConfig
from dream_trainer.configs import DeviceParameters

config = DreamTrainerConfig(
    device_parameters=DeviceParameters(
        pipeline_parallel_size=2,  # Split across 2 GPUs
        pipeline_parallel_config={
            "num_microbatches": 4,  # Number of microbatches
            "schedule": "1F1B"  # Pipeline schedule
        }
    )
)

class MyTrainer(DreamTrainer):
    def configure_models(self):
        # Model is automatically split into pipeline stages
        self.model = self.model.to(self.device)
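
For intuition, num_microbatches and schedule map onto PyTorch's torch.distributed.pipelining primitives. A rough sketch of a 1F1B schedule built directly in PyTorch (assumes a torchrun launch; stage_module stands in for the slice of the model owned by this rank and is not Dream Trainer's own splitting logic):

import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, Schedule1F1B


def build_pipeline_schedule(stage_module: torch.nn.Module, loss_fn):
    # Assumes torch.distributed is already initialized (e.g. via torchrun), one stage per rank.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    stage = PipelineStage(stage_module, stage_index=rank, num_stages=world_size, device=device)

    # 1F1B interleaves one forward with one backward per microbatch to bound activation memory.
    return Schedule1F1B(stage, n_microbatches=4, loss_fn=loss_fn)


# During a training step, the first rank feeds inputs and the last rank supplies targets:
#   schedule.step(inputs)          # first stage
#   schedule.step()                # middle stages
#   schedule.step(target=labels)   # last stage (computes loss_fn)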

Key Features

  • Layer Splitting: Distribute model layers across GPUs
  • Microbatch Support: Process multiple batches in pipeline
  • Efficient Scheduling: Various pipeline schedules available
  • Memory Efficiency: Each GPU holds only its layers

Best Practices

  1. Use when:

     • The model has many layers
     • You need to maximize GPU utilization
     • Memory efficiency is critical

  2. Configuration Tips:

     • Choose an appropriate number of microbatches
     • Select the pipeline schedule based on the model
     • Balance the pipeline stages

Combining Parallelism

Dream Trainer makes it easy to combine different types of parallelism.

Example: FSDP + Tensor Parallel

config = DreamTrainerConfig(
    device_parameters=DeviceParameters.FSDP(
        tensor_parallel=2,   # 2-way tensor parallel
        dp_shard="auto",     # shard the remaining data-parallel ranks with FSDP
        compile_model=True,
        checkpoint_activations=False,
    )
)

Example: Pipeline + Context Parallel

config = DreamTrainerConfig(
    device_parameters=DeviceParameters(
        pipeline_parallel_size=2,  # 2-way pipeline parallel
        context_parallel_size=2,  # 2-way context parallel
        pipeline_parallel_config={
            "num_microbatches": 4,
            "schedule": "1F1B"
        }
    )
)
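
Since Dream Trainer is built on DTensor, combining schemes boils down to laying the GPUs out on a multi-dimensional DeviceMesh, with one mesh dimension per parallelism type. A sketch of the underlying PyTorch call for 8 GPUs with 2-way pipeline, 2-way FSDP sharding, and 2-way tensor parallelism (dimension names are illustrative, not Dream Trainer's exact internals):

from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs laid out as pp x dp_shard x tp = 2 x 2 x 2.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp_shard", "tp"))

pp_mesh = mesh["pp"]        # pipeline stages
dp_mesh = mesh["dp_shard"]  # FSDP sharding group
tp_mesh = mesh["tp"]        # tensor-parallel group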

Best Practices for Combining

  1. Start Simple: Begin with one type of parallelism
  2. Add Gradually: Add more parallelism as needed
  3. Monitor Performance: Watch for communication overhead
  4. Balance Resources: Ensure even distribution of work
  5. Consider Memory: Account for memory requirements

Common Issues

Memory Issues

  • Out of Memory: Reduce the per-GPU batch size, increase sharding, or enable mixed precision
  • Uneven Memory: Balance pipeline stages or tensor sharding
  • Gradient/Activation Memory: Use activation checkpointing or FSDP (see the sketch below)
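
A minimal sketch of applying activation checkpointing to each transformer block with PyTorch's checkpoint_wrapper (block layout assumed to match the example model above; Dream Trainer's checkpoint_activations flag is the configured way to enable this):

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def apply_activation_checkpointing(model) -> None:
    # Recompute each block's activations during backward instead of storing them,
    # trading extra compute for a large reduction in activation memory.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, checkpoint_wrapper(block))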

Performance Issues

  • Slow Training: Check communication overhead
  • Poor Scaling: Verify batch size and parallelism configuration
  • Bottlenecks: Profile to identify communication bottlenecks

Debugging Tips

  1. Start with small models and data
  2. Enable detailed logging
  3. Use the PyTorch profiler (see the sketch below)
  4. Monitor GPU utilization
  5. Check communication patterns
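
For example, a short torch.profiler snippet for spotting communication bottlenecks (train_step is a placeholder for your own step function):

from torch.profiler import ProfilerActivity, profile


def profile_training(train_step, num_steps: int = 5) -> None:
    # Capture CPU and CUDA activity (including collectives) for a few steps.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
    ) as prof:
        for _ in range(num_steps):
            train_step()
            prof.step()

    # Rank the most expensive ops, then open the trace in chrome://tracing or Perfetto.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
    prof.export_chrome_trace("trace.json")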

Next Steps