realtabformer.rtf_trainer

Module Contents

Classes

SaveEpochEndCallback

This callback forces a checkpoint save at each epoch end.

ResumableTrainer

This trainer keeps the learning-rate scheduler consistent across pauses in training.

Attributes

logger

realtabformer.rtf_trainer.logger
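As a minimal sketch of controlling this module's log output, the snippet below assumes the logger is created with the usual `getLogger(__name__)` convention, so that it is registered under the name `realtabformer.rtf_trainer`. That naming is an assumption, not something this page states.

```python
import logging

# Assumption: the module builds its logger via getLogger(__name__),
# so it is registered under the module's dotted path.
rtf_logger = logging.getLogger("realtabformer.rtf_trainer")

# Show only warnings and errors coming from this module.
rtf_logger.setLevel(logging.WARNING)
print(logging.getLevelName(rtf_logger.getEffectiveLevel()))  # WARNING
```

If the naming assumption holds, this is the same object as `realtabformer.rtf_trainer.logger`, so setting the level on either one has the same effect.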
class realtabformer.rtf_trainer.SaveEpochEndCallback(save_epochs: int = None)

Bases: transformers.TrainerCallback

This callback forces a checkpoint save at each epoch end.

on_epoch_end(args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs)
class realtabformer.rtf_trainer.ResumableTrainer(target_epochs: int = None, save_epochs: int = None, model: transformers.PreTrainedModel | torch.nn.Module = None, args: transformers.TrainingArguments = None, data_collator: transformers.DataCollator | None = None, train_dataset: datasets.Dataset | None = None, eval_dataset: datasets.Dataset | None = None, tokenizer: transformers.PreTrainedTokenizerBase | None = None, model_init: Callable[[], transformers.PreTrainedModel] = None, compute_metrics: Callable[[transformers.EvalPrediction], Dict] | None = None, callbacks: List[transformers.TrainerCallback] | None = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None)

Bases: transformers.Trainer

This trainer keeps the learning-rate scheduler consistent across pauses in training. The scheduler should produce similar values whether training runs continuously over target_epochs or is paused and resumed intermittently.
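The sketch below illustrates this goal in pure Python. It is not `ResumableTrainer`'s actual code. `linear_lr_factor` is a hypothetical helper with the same shape as a linear warmup/decay multiplier, and the step counts are made up. The point it demonstrates is that scheduling every session against the full `target_epochs` horizon, indexed by global step, reproduces the continuous schedule. Rebuilding the schedule per session does not.

```python
def linear_lr_factor(step: int, total_steps: int, warmup_steps: int = 0) -> float:
    """LR multiplier with the shape of a linear warmup/decay schedule."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical run: 10 optimizer steps per epoch, target_epochs = 4.
steps_per_epoch = 10
target_epochs = 4
total_steps = steps_per_epoch * target_epochs  # 40

# Continuous training: one session covering all 40 steps.
continuous = [linear_lr_factor(s, total_steps) for s in range(total_steps)]

# Intermittent training: two 2-epoch sessions. The second resumes at global
# step 20 but still schedules against the full target_epochs horizon.
resumed = [linear_lr_factor(s, total_steps) for s in range(0, 20)] + [
    linear_lr_factor(s, total_steps) for s in range(20, 40)
]

# Naive alternative: each session builds a fresh schedule over its own 20 steps,
# so the learning rate jumps back up when training resumes.
naive = [linear_lr_factor(s, 20) for s in range(20)] * 2

print(resumed == continuous)      # True
print(continuous[20], naive[20])  # 0.5 1.0
```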

create_scheduler(num_training_steps: int, optimizer: torch.optim.Optimizer = None) → torch.optim.lr_scheduler.LambdaLR

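Set up the scheduler. The trainer's optimizer must either be set up before this method is called or be passed in as an argument.

Parameters:

num_training_steps (int): The number of training steps to do.

optimizer (torch.optim.Optimizer, optional): The optimizer to schedule. If omitted, the trainer's already configured optimizer is used.

Returns:

torch.optim.lr_scheduler.LambdaLR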
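One plausible way to derive `num_training_steps` for a full `target_epochs` run is sketched below. `training_steps_for_target_epochs` is a hypothetical helper, not part of this module. It assumes a single device, keeps the last partial batch, and counts one optimizer step per `gradient_accumulation_steps` batches with at least one step per epoch. That last rule mirrors the floor-division convention used in many transformers versions, but this is an assumption, so check it against your installed version.

```python
import math


def training_steps_for_target_epochs(
    num_examples: int,
    batch_size: int,
    gradient_accumulation_steps: int,
    target_epochs: int,
) -> int:
    """Hypothetical helper: optimizer steps needed to cover target_epochs."""
    # Batches per epoch on one device, keeping the last partial batch.
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One optimizer step per gradient_accumulation_steps batches (at least one).
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    return updates_per_epoch * target_epochs


print(training_steps_for_target_epochs(1000, 32, 4, 10))  # 80
```

With a trainer whose optimizer already exists, the result could then be passed as `trainer.create_scheduler(num_training_steps=n, optimizer=trainer.optimizer)`.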