TRL documentation
SDFT
Self-Distilled Fine-Tuning (SDFT) is described in the paper "Self-Training with On-Policy Self-Distillation for Language Model Alignment".
The TRL implementation adapts SDFT to the experimental trainer API while reusing the shared self-distillation infrastructure also used by SDPO.
In the current TRL implementation:
- the teacher is the model itself (base weights with the adapter disabled for PEFT, or the same model under no_grad for non-PEFT); use sync_ref_model=True for an EMA teacher
- the dataset must provide both prompt and privileged_context
- privileged_context contains only the extra teacher-only information; the trainer combines it with prompt to build the teacher prompt
- teacher_prompt_template controls how prompt and privileged_context are combined into the teacher prompt
- on-policy generation can use either the student prompt or the teacher-conditioned prompt via generate_from_teacher
- num_loss_tokens_to_skip can exclude initial completion tokens from the distillation loss
- SDFT currently supports text-only training and does not support use_vllm=True
- the shared dataset contract is prompt plus privileged_context
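When sync_ref_model=True, the teacher becomes an exponential moving average (EMA) of the student. The sketch below assumes SDFT follows the update rule TRL documents for ref_model_mixup_alpha and ref_model_sync_steps in its other trainers, namely teacher = alpha * student + (1 - alpha) * teacher, applied every ref_model_sync_steps optimizer steps. Plain dictionaries of floats stand in for model parameters, and the helper names ema_update and maybe_sync_teacher are hypothetical.

```python
def ema_update(teacher: dict[str, float], student: dict[str, float], alpha: float) -> dict[str, float]:
    """Blend student weights into the teacher in place: t <- alpha * s + (1 - alpha) * t."""
    for name, student_value in student.items():
        teacher[name] = alpha * student_value + (1.0 - alpha) * teacher[name]
    return teacher


def maybe_sync_teacher(step: int, teacher: dict[str, float], student: dict[str, float],
                       alpha: float = 0.6, sync_steps: int = 512) -> bool:
    """Hypothetical helper: apply the EMA update only on multiples of sync_steps."""
    if step > 0 and step % sync_steps == 0:
        ema_update(teacher, student, alpha)
        return True
    return False


teacher = {"w": 0.0}
student = {"w": 1.0}
maybe_sync_teacher(511, teacher, student)  # not a sync step: teacher unchanged
maybe_sync_teacher(512, teacher, student)  # sync step: teacher["w"] becomes 0.6
```

A small alpha keeps the teacher close to its previous weights, so the distillation target drifts slowly even as the student changes every step.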
Usage
```python
from datasets import Dataset
from trl.experimental.sdft import SDFTConfig, SDFTTrainer

dataset = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": "Solve 2+2."}]],
        "privileged_context": ["Example answer: 4."],
    }
)

training_args = SDFTConfig(
    output_dir="sdft-model",
    distillation_alpha=0.5,
    distillation_topk=5,
    max_completion_length=64,
)

trainer = SDFTTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

To generate from the teacher-conditioned prompt instead of the student prompt, set generate_from_teacher=True.
To customize how the teacher prompt is built, set teacher_prompt_template on SDFTConfig.
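With the default teacher_prompt_template, the teacher prompt is the student prompt followed by a blank line and the privileged context. The sketch below assumes str.format-style substitution of the {prompt} and {privileged_context} placeholders for plain-text prompts. For conversational prompts it rewrites the content of the last user message, which is a hypothetical choice for illustration rather than documented trainer behavior. build_teacher_prompt is a hypothetical helper name.

```python
DEFAULT_TEACHER_PROMPT_TEMPLATE = "{prompt}\n\n{privileged_context}"


def build_teacher_prompt(prompt, privileged_context, template=DEFAULT_TEACHER_PROMPT_TEMPLATE):
    """Combine a student prompt with teacher-only context (hypothetical sketch)."""
    if isinstance(prompt, str):
        # Plain-text prompt: substitute both placeholders directly.
        return template.format(prompt=prompt, privileged_context=privileged_context)
    # Conversational prompt (assumption): copy the messages so the student prompt is
    # left untouched, then apply the template to the last user message only.
    messages = [dict(message) for message in prompt]
    for message in reversed(messages):
        if message["role"] == "user":
            message["content"] = template.format(
                prompt=message["content"], privileged_context=privileged_context
            )
            break
    return messages


print(build_teacher_prompt("Solve 2+2.", "Example answer: 4."))
print(build_teacher_prompt("Solve 2+2.", "Example answer: 4.",
                           template="Context: {privileged_context}\nQuestion: {prompt}"))
```

Copying the message list keeps the student-facing prompt unchanged, which matters because the student and teacher must see different inputs for the same completion.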
Expected dataset columns
Each example must provide:
- prompt: the student-facing prompt
- privileged_context: only the extra teacher-only information, such as a demonstration, hint, or privileged feedback
The trainer's prompt handling supports both standard (plain-text) prompts and conversational (message-list) prompts.
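Before training, it can help to check that every example carries both columns and to see which prompt style it uses. The sketch below treats a prompt as conversational when it is a non-empty list of dictionaries with role and content keys, an assumption modeled on the conversational format in the usage example above. check_sdft_example is a hypothetical helper, not part of TRL.

```python
REQUIRED_COLUMNS = ("prompt", "privileged_context")


def check_sdft_example(example: dict) -> str:
    """Return "standard" or "conversational" for a valid example; raise ValueError otherwise."""
    missing = [column for column in REQUIRED_COLUMNS if column not in example]
    if missing:
        raise ValueError(f"example is missing required columns: {missing}")
    prompt = example["prompt"]
    if isinstance(prompt, str):
        return "standard"
    # Assumed conversational shape: a non-empty list of {"role": ..., "content": ...} messages.
    if isinstance(prompt, list) and prompt and all(
        isinstance(message, dict) and "role" in message and "content" in message
        for message in prompt
    ):
        return "conversational"
    raise ValueError("prompt must be a string or a list of role/content messages")


print(check_sdft_example({"prompt": "Solve 2+2.", "privileged_context": "Example answer: 4."}))
```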
Callbacks
The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows.
Shared self-distillation hooks:
- on_self_distillation_batch_prepared: fired when a self-distillation batch is ready. The payload includes prompt_ids and completion_ids, plus old_per_token_logps when importance-sampling clipping inputs are available.
- on_generation_batch_built: fired when a new buffered generation batch is created. The payload includes generate_every and steps_per_generation.
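The generate_every and steps_per_generation fields suggest GRPO-style buffering, in which one large generation batch is produced and then consumed in slices over several optimizer steps. The sketch below is modeled on that pattern under two assumptions: generate_every equals steps_per_generation * num_iterations, and a new batch is built whenever the step counter is a multiple of generate_every. GenerationBuffer and its method names are hypothetical, and a plain list stands in for the tensor batch.

```python
from typing import Callable, Optional


class GenerationBuffer:
    """Hypothetical sketch: build one generation batch, then serve it in slices."""

    def __init__(self, steps_per_generation: int, num_iterations: int = 1,
                 on_built: Optional[Callable[..., None]] = None):
        self.steps_per_generation = steps_per_generation
        # Assumption: regenerate once every slice has been reused num_iterations times.
        self.generate_every = steps_per_generation * num_iterations
        self.on_built = on_built
        self._step = 0
        self._slices = None

    def next_batch(self, generate_fn: Callable[[], list]) -> list:
        if self._slices is None or self._step % self.generate_every == 0:
            batch = generate_fn()
            if len(batch) % self.steps_per_generation != 0:
                raise ValueError("batch size must be divisible by steps_per_generation")
            size = len(batch) // self.steps_per_generation
            self._slices = [batch[i * size:(i + 1) * size] for i in range(self.steps_per_generation)]
            if self.on_built is not None:
                # Mirrors the payload documented for on_generation_batch_built.
                self.on_built(generate_every=self.generate_every,
                              steps_per_generation=self.steps_per_generation)
        current = self._slices[self._step % self.steps_per_generation]
        self._step += 1
        return current
```

Buffering amortizes the cost of generation: one expensive sampling pass feeds several cheaper optimizer steps.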
SDFT-specific hook:
- on_generation_prompts_selected: fired when SDFT chooses the prompt source for on-policy generation. The payload includes the selected generation_prompts and the corresponding generation_prompt_text.
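A simple way to observe these hooks is a callback that records every event it receives. The sketch below assumes the payload arrives as keyword arguments after the usual args, state, and control positional arguments, the way transformers passes extra data to TrainerCallback methods. In a real run the class would subclass transformers.TrainerCallback and be passed through the callbacks argument of SDFTTrainer; here it is kept dependency-free so it runs on its own. HookRecorder is a hypothetical name.

```python
from collections import defaultdict


class HookRecorder:
    """Hypothetical recorder for the self-distillation hooks listed above."""

    def __init__(self):
        # Maps hook name -> list of sorted payload keys, one entry per call.
        self.events = defaultdict(list)

    def _record(self, hook_name: str, payload: dict) -> None:
        self.events[hook_name].append(sorted(payload))

    def on_self_distillation_batch_prepared(self, args=None, state=None, control=None, **payload):
        self._record("on_self_distillation_batch_prepared", payload)

    def on_generation_batch_built(self, args=None, state=None, control=None, **payload):
        self._record("on_generation_batch_built", payload)

    def on_generation_prompts_selected(self, args=None, state=None, control=None, **payload):
        self._record("on_generation_prompts_selected", payload)


recorder = HookRecorder()
recorder.on_generation_batch_built(None, None, None, generate_every=4, steps_per_generation=2)
print(dict(recorder.events))
```

Recording only the payload keys keeps the log small while still confirming in tests that each hook fired with the expected fields.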
SDFTConfig
class trl.experimental.sdft.SDFTConfig
( output_dir: str | None = None, per_device_train_batch_size: int = 8, num_train_epochs: float = 3.0, max_steps: int = -1, learning_rate: float = 5e-05, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_steps: float = 0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, optim_target_modules: None | str | list[str] = None, gradient_accumulation_steps: int = 1, average_tokens_across_devices: bool = True, max_grad_norm: float = 1.0, label_smoothing_factor: float = 0.0, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, use_cache: bool = False, neftune_noise_alpha: float | None = None, torch_empty_cache_steps: int | None = None, auto_find_batch_size: bool = False, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_steps: float = 10, logging_first_step: bool = False, log_on_each_node: bool = True, logging_nan_inf_filter: bool = True, include_num_input_tokens_seen: str | bool = 'no', log_level: str = 'passive', log_level_replica: str = 'warning', disable_tqdm: bool | None = None, report_to: None | str | list[str] = 'none', run_name: str | None = None, project: str = 'huggingface', trackio_space_id: str | None = 'trackio', eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', eval_steps: float | None = None, eval_delay: float = 0, per_device_eval_batch_size: int = 8, prediction_loss_only: bool = False, eval_on_start: bool = False,
eval_do_concat_batches: bool = True, eval_use_gather_object: bool = False, eval_accumulation_steps: int | None = None, include_for_metrics: list = <factory>, batch_eval_metrics: bool = False, save_only_model: bool = False, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_on_each_node: bool = False, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, push_to_hub: bool = False, hub_token: str | None = None, hub_private_repo: bool | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_always_push: bool = False, hub_revision: str | None = None, load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, restore_callback_states_from_checkpoint: bool = False, full_determinism: bool = False, seed: int = 42, data_seed: int | None = None, use_cpu: bool = False, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, dataloader_drop_last: bool = False, dataloader_num_workers: int = 0, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, dataloader_prefetch_factor: int | None = None, remove_unused_columns: bool = False, label_names: list[str] | None = None, train_sampling_strategy: str = 'random', length_column_name: str = 'length', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, ddp_backend: str | None = None, ddp_timeout: int = 1800, fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, deepspeed: dict | str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', skip_memory_metrics: bool = True, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, resume_from_checkpoint: str | None = None,
warmup_ratio: float | None = None, logging_dir: str | None = None, local_rank: int = -1, model_init_kwargs: dict[str, typing.Any] | None = None, disable_dropout: bool = True, max_prompt_length: int | None = 512, num_generations: int = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool = True, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict[str, typing.Any] | None = None, chat_template_kwargs: dict[str, typing.Any] | None = None, repetition_penalty: float = 1.0, use_transformers_paged: bool = False, cache_implementation: str | None = None, use_vllm: bool = False, beta: float = 0.0, num_iterations: int = 1, epsilon: float = 0.2, epsilon_high: float | None = None, importance_sampling_level: str = 'token', reward_weights: list[float] | None = None, scale_rewards: str | bool = 'group', loss_type: str = 'dapo', mask_truncated_completions: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.6, ref_model_sync_steps: int = 512, top_entropy_quantile: float = 1.0, distillation_alpha: float = 0.5, distillation_topk: int | None = 100, full_logit_distillation: bool = False, distillation_is_clip: float | None = 2.0, distillation_add_tail: bool = False, distillation_weight: float = 1.0, diagnostics_warning_interval: int = 10, diagnostics_flat_tolerance: float = 1e-08, generate_from_teacher: bool = False, teacher_prompt_template: str = '{prompt}\n\n{privileged_context}', num_loss_tokens_to_skip: int = 0 )
Parameters
- disable_dropout (bool, optional, defaults to True): Whether to disable dropout in the student and teacher models.
- generate_from_teacher (bool, optional, defaults to False): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt.
- teacher_prompt_template (str, optional, defaults to "{prompt}\n\n{privileged_context}"): Template used to combine the student prompt and privileged context into the teacher prompt.
- num_loss_tokens_to_skip (int, optional, defaults to 0): Number of initial completion tokens to exclude from the distillation loss.
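The effect of num_loss_tokens_to_skip can be pictured as zeroing the first k positions of each row of the completion mask before the per-token loss is averaged. The sketch below assumes right-padded completions, so position 0 is the first generated token, and uses a plain masked mean for aggregation; the trainer's actual normalization depends on settings such as loss_type and may differ. Both helper names are hypothetical.

```python
def skip_initial_loss_tokens(completion_mask, num_loss_tokens_to_skip):
    """Return a copy of a 0/1 completion mask with the first k positions of each row zeroed."""
    if num_loss_tokens_to_skip < 0:
        raise ValueError("num_loss_tokens_to_skip must be non-negative")
    return [
        [0 if position < num_loss_tokens_to_skip else value for position, value in enumerate(row)]
        for row in completion_mask
    ]


def masked_mean(per_token_loss, mask):
    """Average per-token losses over unmasked positions across the whole batch."""
    total = sum(
        loss * m
        for loss_row, mask_row in zip(per_token_loss, mask)
        for loss, m in zip(loss_row, mask_row)
    )
    count = sum(sum(row) for row in mask)
    return total / count if count > 0 else 0.0


mask = [[1, 1, 1, 0], [1, 1, 0, 0]]  # right-padded completions
loss = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
trimmed = skip_initial_loss_tokens(mask, 1)  # [[0, 1, 1, 0], [0, 1, 0, 0]]
print(masked_mean(loss, trimmed))  # (2 + 3 + 6) / 3
```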
Configuration class for SDFTTrainer.
This adapts the official SDFT implementation to the TRL trainer API while reusing the common self-distillation configuration shared with SDPO.
SDFTTrainer
class trl.experimental.sdft.SDFTTrainer
( model: str | PreTrainedModel | nn.Module, args: SDFTConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None )
Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.
train
( resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None ) → ~trainer_utils.TrainOutput
Parameters
- resume_from_checkpoint (str or bool, optional): If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or dict[str, Any], optional): The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (list[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.
Returns
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
Main training entry point.
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
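When resume_from_checkpoint=True, training resumes from the last checkpoint in args.output_dir. The sketch below shows one way to identify that directory, assuming checkpoints are saved as checkpoint-<step> subdirectories, which is the naming transformers uses; transformers also provides transformers.trainer_utils.get_last_checkpoint for this purpose. find_last_checkpoint is a hypothetical, dependency-free helper.

```python
import os
import re

# Matches directory names such as "checkpoint-500" and captures the step number.
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str):
    """Return the path of the highest-numbered checkpoint-<step> subdirectory, or None."""
    if not os.path.isdir(output_dir):
        return None
    steps = {}
    for name in os.listdir(output_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps[name] = int(match.group(1))
    if not steps:
        return None
    # Compare step numbers numerically: as strings, "checkpoint-99" would sort after "checkpoint-1000".
    last = max(steps, key=steps.get)
    return os.path.join(output_dir, last)
```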
push_to_hub
( commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs )
Parameters
- commit_message (str, optional, defaults to "End of training"): Message to commit while pushing.
- blocking (bool, optional, defaults to True): Whether the function should return only when the git push has finished.
- token (str, optional, defaults to None): Token with write permission to overwrite Trainer’s original args.
- revision (str, optional): The git revision to commit from. Defaults to the head of the “main” branch.
- kwargs (dict[str, Any], optional): Additional keyword arguments passed along to ~Trainer.create_model_card.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.