Neuron TRL Trainers
TRL-compatible trainers for AWS Trainium accelerators.
NeuronSFTTrainer
NeuronSFTConfig
class optimum.neuron.NeuronSFTConfig
( output_dir: str | None = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' per_device_train_batch_size: int = 1 per_device_eval_batch_size: int = 1 gradient_accumulation_steps: int = 1 learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict[str, typing.Any] | str | None = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'info' log_level_replica: str = 'silent' logging_dir: str | None = None logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_first_step: bool = False logging_steps: float = 500 save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_total_limit: int | None = None save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False seed: int = 42 bf16: bool = False dataloader_drop_last: bool = False eval_steps: float | None = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: int | None = None run_name: str | None = None disable_tqdm: bool | None = None remove_unused_columns: bool | None = True label_names: list[str] | None = None accelerator_config: dict | str | None = None label_smoothing_factor: float = 0.0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch' optim_args: str | None = None report_to: None | str | list[str] = None resume_from_checkpoint: str | None = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None use_liger_kernel: bool | None = False average_tokens_across_devices: bool | None = False dataloader_prefetch_size: int = None skip_cache_push: bool = False use_autocast: bool = False zero_1: bool = True stochastic_rounding_enabled: bool = True optimizer_use_master_weights: bool = True optimizer_use_fp32_grad_acc: bool = True optimizer_save_master_weights_in_ckpt: bool = False tensor_parallel_size: int = 1 disable_sequence_parallel: bool = False pipeline_parallel_size: int = 1 pipeline_parallel_num_microbatches: int = -1 kv_size_multiplier: int | None = None num_local_ranks_per_step: int = 8 use_xser: bool = True async_save: bool = False fuse_qkv: bool = False recompute_causal_mask: bool = True )
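A minimal configuration sketch follows; the output directory and hyperparameter values are illustrative, and only arguments listed in the signature above are used (assuming a model sharded over 8 Neuron cores with tensor parallelism):

```python
from optimum.neuron import NeuronSFTConfig

# Illustrative values only; tune them for your model and instance size.
sft_config = NeuronSFTConfig(
    output_dir="llama_sft_output",   # hypothetical output path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=100,
    bf16=True,                       # train in bfloat16 on Trainium
    tensor_parallel_size=8,          # shard the model across 8 Neuron cores
    pipeline_parallel_size=1,
    zero_1=True,                     # ZeRO-1 optimizer state sharding (default)
    gradient_checkpointing=True,
)
```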
NeuronSFTTrainer
class optimum.neuron.NeuronSFTTrainer
( model: transformers.modeling_utils.PreTrainedModel | torch.nn.modules.module.Module | str args: optimum.neuron.trainers.sft_trainer.SFTConfig | None = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: Dataset | IterableDataset | datasets.Dataset | None = None eval_dataset: Dataset | dict[str, Dataset] | datasets.Dataset | None = None processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | transformers.processing_utils.ProcessorMixin | None = None callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None optimizers: tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None) optimizer_cls_and_kwargs: tuple[type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]] | None = None tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase | None = None peft_config: peft.config.PeftConfig | None = None formatting_func: typing.Optional[typing.Callable] = None )
SFTTrainer adapted for Neuron.
It differs from the original SFTTrainer by:
- Using _TrainerForNeuron.__init__() instead of Trainer.__init__()
- Using _TrainerForNeuron.train() instead of Trainer.train()
- Adapting _prepare_non_packed_dataloader to pad every example to the maximum length. In the original SFTTrainer, examples are not padded, which is an issue on Neuron because each new sequence length triggers a recompilation.
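Putting the config and trainer together, a minimal end-to-end sketch is shown below; the model checkpoint, dataset, and prompt template are illustrative assumptions rather than part of this API reference, and the padding behavior described above means every example is padded to a fixed length so a single compiled graph can be reused:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

model_id = "meta-llama/Meta-Llama-3-8B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # a padding token is required since examples are padded

# Illustrative instruction dataset with "instruction" and "response" columns.
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def formatting_func(examples):
    # Receives a batch of examples and returns one prompt string per example.
    return [
        f"### Instruction\n{instruction}\n\n### Response\n{response}"
        for instruction, response in zip(examples["instruction"], examples["response"])
    ]

sft_config = NeuronSFTConfig(
    output_dir="llama_sft_output",   # hypothetical output path, as in the sketch above
    bf16=True,
    tensor_parallel_size=8,
)

trainer = NeuronSFTTrainer(
    model=model_id,                  # a checkpoint name or an already-loaded PreTrainedModel
    args=sft_config,
    train_dataset=dataset,
    tokenizer=tokenizer,
    formatting_func=formatting_func,
)
trainer.train()
```

Such a script is typically launched with torchrun so that one worker per Neuron core is spawned, for example torchrun --nproc_per_node=32 sft.py on a trn1.32xlarge instance.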