# Contributing Custom Models for Training

This guide explains how to add custom model implementations to the `optimum/neuron/models/training/` directory. Custom models are needed to support distributed training features like tensor parallelism, pipeline parallelism, and sequence parallelism on AWS Trainium devices.

## Architecture Components

### 1. NeuronModelMixin

The `NeuronModelMixin` class provides core functionality:
- `from_pretrained()`: Loads regular Transformers weights into custom implementations
- `save_pretrained()`: Saves sharded checkpoints with consolidation metadata
- Pipeline parallelism support through `PIPELINE_*` attributes

### 2. Weight Transformation Specs

Transformation specs handle converting weights between:
- Original Transformers format → Custom parallel format (during loading)
- Custom parallel format → Original Transformers format (during checkpoint consolidation)

Key transformation spec types:
- `FusedLinearsSpec`: Handles fused linear layers (e.g., `gate_up_proj`)
- `GQAQKVColumnParallelLinearSpec`: Handles grouped query attention projections when key-value heads cannot be evenly distributed across tensor parallel ranks (for example, when the tensor parallel size is greater than the number of key-value heads)

For complete API documentation of all transformation specs and utility functions, see the [Model Weight Transformation Specs API Reference](../training_api/transformations).
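To make the two directions concrete, here is a minimal pure-Python sketch of what fusing along the output dimension and splitting back by `original_dims` amounts to. The helpers `fuse_rows` and `unfuse_rows` are hypothetical and are not part of `transformations_utils.py`; the real specs also deal with sharding across ranks, which this sketch leaves out.

```python
def fuse_rows(weights):
    """Concatenate weight matrices (lists of rows) along the output dimension."""
    fused = []
    for weight in weights:
        fused.extend(weight)
    return fused


def unfuse_rows(fused, original_dims):
    """Split a fused matrix back into pieces with the given row counts."""
    assert sum(original_dims) == len(fused), "original_dims must cover every row"
    pieces, start = [], 0
    for dim in original_dims:
        pieces.append(fused[start:start + dim])
        start += dim
    return pieces


# Toy 2x3 weights; each inner list is one output feature (one row).
gate_proj = [[1, 2, 3], [4, 5, 6]]
up_proj = [[7, 8, 9], [10, 11, 12]]

# Loading direction: Transformers layout -> fused layout.
gate_up_proj = fuse_rows([gate_proj, up_proj])

# Consolidation direction: fused layout -> Transformers layout.
restored = unfuse_rows(gate_up_proj, original_dims=[2, 2])
```

The round trip returns the original matrices unchanged, which is the property checkpoint consolidation relies on.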

### 3. Parallel Layers

Use these parallel layers from `neuronx_distributed`:
- `ColumnParallelLinear`: Splits weight matrix along output dimension
- `RowParallelLinear`: Splits weight matrix along input dimension
- `ParallelEmbedding`: Splits embedding table across ranks
- `GQAQKVColumnParallelLinear`: Specialized for grouped query attention projections when key-value heads cannot be evenly distributed across tensor parallel ranks (for example, when the tensor parallel size is greater than the number of key-value heads)
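As a rough guide to what each layer stores per rank, the sketch below computes per-rank weight shapes using the `(out_features, in_features)` layout of `nn.Linear`. The helper names and the example sizes are hypothetical, and the embedding case assumes the table is sharded along the vocabulary dimension.

```python
def column_parallel_shape(in_features, out_features, tp_size):
    """Per-rank weight shape when the output dimension is split."""
    assert out_features % tp_size == 0, "output dimension must divide evenly"
    return (out_features // tp_size, in_features)


def row_parallel_shape(in_features, out_features, tp_size):
    """Per-rank weight shape when the input dimension is split."""
    assert in_features % tp_size == 0, "input dimension must divide evenly"
    return (out_features, in_features // tp_size)


def embedding_shape(vocab_size, hidden_size, tp_size):
    """Per-rank embedding table shape, assuming vocabulary-dimension sharding."""
    assert vocab_size % tp_size == 0, "vocabulary must divide evenly"
    return (vocab_size // tp_size, hidden_size)


# Hypothetical model sizes chosen only for illustration.
hidden, intermediate, vocab, tp = 512, 2048, 8000, 8

up_shape = column_parallel_shape(hidden, intermediate, tp)   # (256, 512)
down_shape = row_parallel_shape(intermediate, hidden, tp)    # (512, 256)
embed_shape = embedding_shape(vocab, hidden, tp)             # (1000, 512)
```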

## Implementation Steps

### Step 1: Create Model Structure

Create a new directory: `optimum/neuron/models/training/your_model/`

**`__init__.py`**
```python
from .modeling_your_model import YourModelForCausalLM, YourModel

__all__ = ["YourModelForCausalLM", "YourModel"]
```

### Step 2: Implement the Model Building Blocks

**`modeling_your_model.py`**

#### Imports and Dependencies

```python
import torch
from torch import nn
from neuronx_distributed.parallel_layers.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    ParallelEmbedding,
)
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
from transformers import PreTrainedModel
from transformers.models.your_model import YourModelConfig

from ..config import TrainingNeuronConfig
from ..modeling_utils import NeuronModelMixin
from ..transformations_utils import (
    CustomModule,
    FusedLinearsSpec,
    GQAQKVColumnParallelLinearSpec,
    ModelWeightTransformationSpecs,
)
```

#### Embedding Layer
```python
class YourModelEmbeddings(nn.Module):
    def __init__(self, config, trn_config):
        super().__init__()
        self.embed_tokens = ParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.dtype,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
        )
```

#### MLP Layer with Fused Linears

**Important**: Any module that has transformation specs must inherit from `CustomModule` to ensure proper handling of weight transformations, and the transformation specs must be defined in the `self.specs` attribute.

```python
class YourModelMLP(nn.Module, CustomModule):
    def __init__(self, config, trn_config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        
        # Fused gate and up projections
        self.gate_up_proj = ColumnParallelLinear(
            self.hidden_size,
            2 * self.intermediate_size,
            stride=2,  # Important for proper sharding
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        
        # Define transformation specs
        self.specs = ModelWeightTransformationSpecs()
        self.specs.add_spec(
            FusedLinearsSpec(
                fused_linear_name="gate_up_proj",
                linear_names=["gate_proj", "up_proj"],
                bias=False,
                fuse_axis="column",  # Fuse along output dimension
                original_dims=[self.intermediate_size, self.intermediate_size],
            )
        )
```
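The `stride=2` argument matters because the fused weight stacks the `gate_proj` rows on top of the `up_proj` rows. The pure-Python sketch below, which assumes `neuronx_distributed` follows the Megatron-style strided partitioning scheme, shows which rows rank 0 would keep with and without the stride. `shard_rows` is a hypothetical helper written for this illustration only.

```python
def shard_rows(row_labels, tp_size, rank, stride=1):
    """Return the rows one rank keeps under strided partitioning."""
    per_partition = len(row_labels) // tp_size
    chunk = per_partition // stride
    # Cut the full weight into equal chunks, then deal them out round-robin.
    chunks = [row_labels[i:i + chunk] for i in range(0, len(row_labels), chunk)]
    kept = []
    for piece in chunks[rank::tp_size]:
        kept.extend(piece)
    return kept


# Fused weight: four gate rows followed by four up rows.
fused = ["gate0", "gate1", "gate2", "gate3", "up0", "up1", "up2", "up3"]

with_stride = shard_rows(fused, tp_size=2, rank=0, stride=2)
without_stride = shard_rows(fused, tp_size=2, rank=0, stride=1)
```

With `stride=2`, rank 0 holds a matching slice of both projections (`gate0`, `gate1`, `up0`, `up1`). With `stride=1`, it would hold only gate rows, so the local output could not be split into a gate half and an up half.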

#### Attention Layer

The attention layer implementation depends on the model's architecture and tensor parallel configuration. There are three main variants:

**1. Separate Q, K, V Projections (Default)**
```python
class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        
        # Separate projections for Q, K, V
        self.q_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        self.k_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        self.v_proj = ColumnParallelLinear(
            config.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            input_is_parallel=True,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            dtype=config.dtype,
        )
        
        # No transformation specs needed - regular parallel layers
        self.specs = ModelWeightTransformationSpecs()
```

**2. Fused QKV Projection (Multi-Head Attention)**
```python
class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Only use fused QKV when num_heads == num_key_value_heads (no GQA)
        if trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
            self.qkv_proj = ColumnParallelLinear(
                config.hidden_size,
                3 * self.num_heads * self.head_dim,  # Q + K + V
                stride=3,  # Important for proper sharding
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                dtype=config.dtype,
            )
            
            # Define transformation specs for fused QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                FusedLinearsSpec(
                    fused_linear_name="qkv_proj",
                    linear_names=["q_proj", "k_proj", "v_proj"],
                    bias=False,
                    fuse_axis="column",
                    original_dims=[self.num_heads * self.head_dim] * 3,
                )
            )
            self.split_size = self.num_heads * self.head_dim // tp_size
```
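The forward pass is not shown above, so the following is only a sketch of how `split_size` could be used. It assumes that, with `stride=3`, each rank's fused output is laid out as its local Q features, then K, then V. `split_local_qkv` is a hypothetical helper operating on a plain list; with tensors, the equivalent would be a split of size `split_size` along the last dimension.

```python
def split_local_qkv(local_output, split_size):
    """Split one rank's fused output features into (q, k, v) chunks."""
    assert len(local_output) == 3 * split_size, "expected Q, K and V of equal size"
    q = local_output[:split_size]
    k = local_output[split_size:2 * split_size]
    v = local_output[2 * split_size:]
    return q, k, v


# Hypothetical sizes: 8 heads of dimension 4, sharded across 2 ranks.
num_heads, head_dim, tp_size = 8, 4, 2
split_size = num_heads * head_dim // tp_size

# Stand-in for the feature dimension of the local qkv_proj output.
local_output = list(range(3 * split_size))
q, k, v = split_local_qkv(local_output, split_size)
```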

**3. GQA QKV Projection (Required for Challenging TP Configurations)**
```python
class YourModelAttention(nn.Module, CustomModule):
    def __init__(self, config, trn_config, layer_idx):
        super().__init__()
        # ... (same setup as above)
        
        tp_size = get_tensor_model_parallel_size()
        
        # Use GQA QKV when KV heads can't be evenly distributed across TP ranks
        # This happens when: num_key_value_heads < tp_size or num_key_value_heads % tp_size != 0
        self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
        
        if self.qkv_linear:
            # Calculate KV size multiplier to ensure even distribution
            if trn_config.kv_size_multiplier is None:
                self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
            else:
                self.kv_size_multiplier = trn_config.kv_size_multiplier
                
            self.qkv_proj = GQAQKVColumnParallelLinear(
                config.hidden_size,
                [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
                bias=False,
                gather_output=False,
                sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
                kv_size_multiplier=self.kv_size_multiplier,
                fuse_qkv=trn_config.fuse_qkv,
                dtype=config.dtype,
            )
            
            # Define transformation specs for GQA QKV
            self.specs = ModelWeightTransformationSpecs()
            self.specs.add_spec(
                GQAQKVColumnParallelLinearSpec(
                    gqa_qkv_projection_name="qkv_proj",
                    query_projection_name="q_proj",
                    key_projection_name="k_proj", 
                    value_projection_name="v_proj",
                    output_projection_name="o_proj",
                    num_attention_heads=self.num_heads,
                    num_key_value_heads=self.num_key_value_heads,
                    kv_size_multiplier=self.kv_size_multiplier,
                    q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
                    kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
                    fuse_qkv=trn_config.fuse_qkv,
                )
            )
```
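The idea behind the KV size multiplier is to replicate key-value heads until they divide evenly across tensor parallel ranks. The sketch below computes the smallest such multiplier. It is an illustration of that idea only: `smallest_kv_size_multiplier` is a hypothetical helper, and it is not claimed to match what `trn_config.auto_kv_size_multiplier` returns.

```python
from math import gcd


def smallest_kv_size_multiplier(num_key_value_heads, tp_size):
    """Smallest m such that num_key_value_heads * m is divisible by tp_size."""
    return tp_size // gcd(num_key_value_heads, tp_size)


# 4 KV heads on 8 ranks: each head must be replicated twice.
m_small = smallest_kv_size_multiplier(num_key_value_heads=4, tp_size=8)

# 8 KV heads on 32 ranks: each head must be replicated four times.
m_large = smallest_kv_size_multiplier(num_key_value_heads=8, tp_size=32)

# 8 KV heads on 8 ranks: already evenly distributed, no replication needed.
m_even = smallest_kv_size_multiplier(num_key_value_heads=8, tp_size=8)
```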

**When to Use Each Variant:**

- **Separate Q, K, V**: Default approach, works for all configurations but may be less efficient
- **Fused QKV**: Use when `num_heads == num_key_value_heads` (no grouped query attention) and `fuse_qkv=True`
- **GQA QKV**: Required when using grouped query attention with challenging tensor parallel configurations where KV heads cannot be evenly distributed across TP ranks

The choice is typically determined by:
```python
tp_size = get_tensor_model_parallel_size()
use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
use_fused_qkv = trn_config.fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv
```
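The same conditions can be wrapped in a small standalone function to check a configuration before writing any model code. `choose_qkv_variant` is a hypothetical helper that takes the tensor parallel size as an argument instead of calling `get_tensor_model_parallel_size()`, so it runs without a distributed setup.

```python
def choose_qkv_variant(num_heads, num_key_value_heads, tp_size, fuse_qkv):
    """Return which attention projection variant the conditions above select."""
    use_gqa_qkv = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
    use_fused_qkv = fuse_qkv and (num_heads == num_key_value_heads) and not use_gqa_qkv
    if use_gqa_qkv:
        return "gqa_qkv"
    if use_fused_qkv:
        return "fused_qkv"
    return "separate"


# Fewer KV heads than ranks: GQA QKV is required.
gqa_case = choose_qkv_variant(num_heads=32, num_key_value_heads=8, tp_size=32, fuse_qkv=True)

# Multi-head attention with fusing enabled: fused QKV.
fused_case = choose_qkv_variant(num_heads=32, num_key_value_heads=32, tp_size=8, fuse_qkv=True)

# GQA with KV heads that divide evenly across ranks: separate projections.
separate_case = choose_qkv_variant(num_heads=32, num_key_value_heads=8, tp_size=8, fuse_qkv=True)
```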

### Step 3: Implement Main Model Classes

#### Base Model
```python
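# NeuronModelMixin is mixed in by the concrete classes below; listing it here as well
# would make their base-class order inconsistent (Python raises an MRO TypeError).
class YourPreTrainedModel(PreTrainedModel):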
    config_class = YourModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YourModelDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

class YourModel(NeuronModelMixin, YourPreTrainedModel):
    def __init__(self, config: YourModelConfig, trn_config: TrainingNeuronConfig):
        YourPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.trn_config = trn_config
        
        self.embed_tokens = ParallelEmbedding(...)
        self.layers = nn.ModuleList([
            YourModelDecoderLayer(config, trn_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = YourModelRMSNorm(...)
        
        self.post_init()
```

#### CausalLM Model
```python
class YourModelForCausalLM(NeuronModelMixin, YourPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    
    # Pipeline parallelism support
    SUPPORTS_PIPELINE_PARALLELISM = True
    PIPELINE_TRANSFORMER_LAYER_CLS = YourModelDecoderLayer
    PIPELINE_INPUT_NAMES = ["input_ids", "attention_mask"]
    
    def __init__(self, config, trn_config):
        super().__init__(config)
        self.trn_config = trn_config
        self.model = YourModel(config, trn_config)
        self.vocab_size = config.vocab_size
        
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
            dtype=config.dtype,
        )
        
        self.post_init()
```

### Step 4: Register Model

Update `optimum/neuron/models/training/__init__.py`:
```python
from .your_model import YourModelForCausalLM, YourModel

__all__ = [..., "YourModelForCausalLM", "YourModel"]
```

Update `optimum/neuron/models/training/auto_models.py`:
```python
from .your_model.modeling_your_model import YourModelForCausalLM, YourModel

# Register the base model (without head)
register_neuron_model_for_training("your_model", "model")(YourModel)

# Register the CausalLM model
register_neuron_model_for_training("your_model", "text-generation")(YourModelForCausalLM)
```

Here, `"your_model"` corresponds to the `model_type` attribute of your model's configuration class.

## Best Practices

### 1. Parallel Layer Configuration
- Use `gather_output=False` for intermediate layers
- Set `input_is_parallel=True` for layers that receive parallel input
- Configure `sequence_parallel_enabled` consistently across layers
- Use appropriate `stride` values for proper weight sharding

### 2. Weight Transformation Specs
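- Always define specs for modules whose weight layout differs from the original Transformers checkpoint (fused linears, GQA QKV projections); modules that only use regular parallel layers can keep an empty `ModelWeightTransformationSpecs()`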
- Use `CustomModule` mixin for any module with transformation specs
- Ensure spec parameter names match the actual module structure
- Test both regular and LoRA weight transformations

### 3. Pipeline Parallelism
- Set `SUPPORTS_PIPELINE_PARALLELISM = True` for supported models
- Define `PIPELINE_TRANSFORMER_LAYER_CLS` as your decoder layer class
- List all input names in `PIPELINE_INPUT_NAMES`

### 4. Flash Attention Support
- Set `_supports_flash_attn_2 = True` if your model supports it
- Implement both eager and flash attention paths
- Use appropriate attention function dispatching

## Testing Your Implementation

The training tests in `tests/training/` provide a comprehensive testing framework that validates numerical correctness, distributed training scenarios, and checkpoint compatibility.
Most of the tests are not designed to be run on every custom modeling implementation, but rather to validate the core functionality of the Optimum Neuron training infrastructure.
With that in mind, here is what you need to add for your custom model:

### 1. Custom Modeling Validation

The `test_custom_modeling.py` file validates that your custom implementation produces identical outputs to the original Transformers model.

Update `tests/training/test_custom_modeling.py`:
```python
CUSTOM_MODELINGS_TO_TEST = [
    # ... existing models ...
    ("YourModelForCausalLM", "your-org/your-model-name"),
]
```

**Important**: For custom modeling validation tests, use small/tiny models to ensure CI efficiency. The test models should have:
- Small vocabulary size (e.g., 1000-8000 tokens)
- Few layers (e.g., 2-4 layers)
- Small hidden dimensions (e.g., 128-512)
- Minimal attention heads (e.g., 4-8 heads)

Examples of good test models for custom modeling validation:
- `"michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random"` - 4 layers, 4 KV heads
- `"michaelbenayoun/granite-tiny-4kv-heads-4layers-random"` - Tiny Granite model
- `"michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random"` - Tiny Qwen3 model

Key test your model must pass:
```python
def test_custom_modeling_matches_original(): ...  # Output matching (arguments omitted)
```

- **Numerical Correctness**: Ensures custom models match Transformers outputs exactly
- **Parallelization Support**: Tests various QKV implementations (regular, fused, GQA)

### 2. End-to-End Training Validation

The `test_overfit.py` file validates training convergence. To include your model in end-to-end training validation, you must add it to the parametrized test cases:

Update `tests/training/test_overfit.py`:
```python
@pytest.mark.parametrize(
    "model_class_name,model_name_or_path,learning_rate,warmup_ratio,training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps",
    [
        # ... existing models ...
        [
            "YourModelForCausalLM",
            "your-org/your-model-name",
            1e-4,
            0.03,
            {},
            True,
            0.5,
            2048,
            50,
        ],
    ],
    ids=[
        # ... existing model IDs ...
        "your-org/your-model-name",
    ],
)
```
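Because the test case is a positional list, it is easy to put a value in the wrong slot. The short sketch below pairs the parameter names from the `parametrize` string with the example values so each one can be read by name. It is a standalone check, not part of the test suite.

```python
# Parameter names exactly as they appear in the parametrize decorator.
PARAM_NAMES = (
    "model_class_name,model_name_or_path,learning_rate,warmup_ratio,"
    "training_kwargs,use_flash_attention_2,max_expected_loss,max_length,num_steps"
)

# The example test case from above, in the same positional order.
case = [
    "YourModelForCausalLM",
    "your-org/your-model-name",
    1e-4,
    0.03,
    {},
    True,
    0.5,
    2048,
    50,
]

names = PARAM_NAMES.split(",")
assert len(names) == len(case), "every parameter needs exactly one value"

# Name -> value view of the test case.
labeled = dict(zip(names, case))
```

Read this way, the example trains for 50 steps at a learning rate of `1e-4` with sequences of up to 2048 tokens, and `max_expected_loss` is 0.5.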

This validates:
- **Convergence Validation**: Ensures models can overfit simple datasets

Your model will be tested in:
```python
def test_overfit_custom_modeling_causal_lm(): ...  # Basic training (your model included; arguments omitted)
```

### 3. Auto Model Loading

The `test_modeling_auto.py` file validates that your model can be loaded using the `NeuronModel` and `NeuronModelForCausalLM` auto classes. To include your model in these tests, you must add it to the test cases:

Update `tests/training/test_modeling_auto.py`:
```python
@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random", 
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random",
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic

@pytest.mark.parametrize("from_pretrained", [False, True], ids=["from_config", "from_pretrained"])
@distributed_test(world_size=1)
@is_trainium_test
def test_auto_model_for_causal_lm_with_supported_architecture(from_pretrained):
    trn_config = TrainingNeuronConfig()
    kwargs = {"dtype": torch.bfloat16}
    for model_name_or_path in [
        "michaelbenayoun/llama-2-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/granite-tiny-4kv-heads-4layers-random",
        "michaelbenayoun/qwen3-tiny-4kv-heads-4layers-random", 
        "your-org/your-model-name",  # Add your model here
    ]:
        # ... rest of test logic
```

This validates:
- **Auto Model Loading**: Tests that `NeuronModel.from_pretrained()` and `NeuronModel.from_config()` work correctly
- **Auto CausalLM Loading**: Tests that `NeuronModelForCausalLM.from_pretrained()` and `NeuronModelForCausalLM.from_config()` work correctly

### 4. Running Tests

Tests require AWS Trainium instances. Run specific test categories:

```bash
# Run all custom modeling tests
pytest tests/training/test_custom_modeling.py -v

# Run specific model tests
pytest tests/training/test_custom_modeling.py -v -k "your_model"

# Run end-to-end training validation
pytest tests/training/test_overfit.py -v
```

### 5. Test Requirements

Your implementation must:

1. **Pass numerical correctness tests** against original Transformers implementation
2. **Support parallelization strategies** (at minimum DP and TP; PP support recommended)
3. **Handle various QKV implementations** (regular, fused, GQA)
4. **Support checkpoint consolidation** for distributed training
5. **Support LoRA training** if applicable
6. **Demonstrate convergence** through overfitting tests

The testing framework ensures your custom model maintains compatibility with the existing Optimum Neuron training infrastructure while delivering expected performance and correctness guarantees.

## Common Issues

- **Weight Shape Mismatches**: Ensure transformation specs handle tensor shapes correctly
- **Pipeline Parallelism Errors**: Check that all required attributes are set
- **Memory Issues**: Consider gradient checkpointing and activation recomputation
- **Attention Compatibility**: Verify attention implementations work with your model architecture

## Additional Resources

This guide provides the foundation for implementing custom models. For complete examples and advanced patterns, reference these existing implementations:

- **LLaMA**: `optimum/neuron/models/training/llama/modeling_llama.py` - Complete implementation with regular, fused, and GQA attention variants and a fused MLP
- **Qwen3**: `optimum/neuron/models/training/qwen3/modeling_qwen3.py` - Demonstrates how to adapt the Llama implementation for Qwen3 with `q_norm` and `k_norm` layers

Key files to study:
- `optimum/neuron/models/training/modeling_utils.py` - Base `NeuronModelMixin` class
- `optimum/neuron/models/training/transformations_utils.py` - Weight transformation specifications
- `optimum/neuron/models/training/config.py` - `TrainingNeuronConfig` for parallelism settings