realtabformer#


Package Contents#

Classes#

REaLTabFormer

class realtabformer.REaLTabFormer(model_type: str, tabular_config: transformers.models.gpt2.GPT2Config | None = None, relational_config: transformers.EncoderDecoderConfig | None = None, parent_realtabformer_path: pathlib.Path | None = None, freeze_parent_model: bool | None = True, checkpoints_dir: str = 'rtf_checkpoints', samples_save_dir: str = 'rtf_samples', epochs: int = 100, batch_size: int = 8, random_state: int = 1029, train_size: float = 1, output_max_length: int = 512, early_stopping_patience: int = 5, early_stopping_threshold: float = 0, mask_rate: float = 0, numeric_nparts: int = 1, numeric_precision: int = 4, numeric_max_len: int = 10, **training_args_kwargs)[source]#
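A minimal instantiation sketch for a non-relational (tabular) model, adapted from the package README. Keyword arguments such as gradient_accumulation_steps and logging_steps are forwarded through **training_args_kwargs to the underlying Hugging Face TrainingArguments:

    import pandas as pd
    from realtabformer import REaLTabFormer

    # Hypothetical training data.
    df = pd.read_csv("foo.csv")

    # Non-relational or parent-table model.
    rtf_model = REaLTabFormer(
        model_type="tabular",
        gradient_accumulation_steps=4,  # forwarded via **training_args_kwargs
        logging_steps=100,
    )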
_invalid_model_type(model_type)[source]#
_init_tabular(tabular_config)[source]#
_init_relational(relational_config)[source]#
_extract_column_info(df: pandas.DataFrame) None[source]#
_generate_vocab(df: pandas.DataFrame) dict[source]#
_check_model()[source]#
_split_train_eval_dataset(dataset: datasets.Dataset)[source]#
fit(df: pandas.DataFrame, in_df: pandas.DataFrame | None = None, join_on: str | None = None, resume_from_checkpoint: bool | str = False, device='cuda', num_bootstrap: int = 500, frac: float = 0.165, frac_max_data: int = 10000, qt_max: str | float = 0.05, qt_max_default: float = 0.05, qt_interval: int = 100, qt_interval_unique: int = 100, distance: sklearn.metrics.pairwise.manhattan_distances = manhattan_distances, quantile: float = 0.95, n_critic: int = 5, n_critic_stop: int = 2, gen_rounds: int = 3, sensitivity_max_col_nums: int = 20, use_ks: bool = False, full_sensitivity: bool = False, sensitivity_orig_frac_multiple: int = 4, orig_samples_rounds: int = 5, load_from_best_mean_sensitivity: bool = False, target_col: str = None)[source]#

Train the REaLTabFormer model on the tabular data.

Parameters:
  • df – Pandas DataFrame containing the tabular data that will be generated during sampling. This data goes to the decoder for the relational model.

  • in_df – Pandas DataFrame containing observations related to df, and from which the model will generate data. This data goes to the encoder for the relational model.

  • join_on – Column name that links the df and the in_df tables.

  • resume_from_checkpoint – If True, resumes training from the latest checkpoint in the checkpoints_dir. If a path is given, resumes training from that checkpoint.

  • device – Device where the model and the training will be run. Use torch devices, e.g., cpu, cuda, mps (experimental).

  • num_bootstrap – Number of bootstrap samples.

  • frac – The fraction of the data used for training.

  • frac_max_data – The maximum number of rows that the training data will have.

  • qt_max – The maximum quantile for the discriminator.

  • qt_max_default – The default maximum quantile for the discriminator.

  • qt_interval – Interval for the quantile check during the training process.

  • qt_interval_unique – Interval for the quantile check during the training process.

  • distance – Distance metric used by the discriminator.

  • quantile – The quantile value to which the discriminator will be trained.

  • n_critic – Interval between epochs to perform a discriminator assessment.

  • n_critic_stop – The number of critic rounds without improvement after which the training will be stopped.

  • gen_rounds – The number of generator rounds.

  • sensitivity_max_col_nums – The maximum number of columns used to compute sensitivity.

  • use_ks – Whether to use KS test or not.

  • full_sensitivity – Whether to use full sensitivity or not.

  • sensitivity_orig_frac_multiple – The size of the training data relative to the chosen frac that will be used in computing the sensitivity. The larger this value is, the more robust the sensitivity threshold will be. However, (sensitivity_orig_frac_multiple + 2) multiplied by frac must be less than 1.

  • orig_samples_rounds – The number of train/hold-out samples that will be used to compute the epoch sensitivity value.

  • load_from_best_mean_sensitivity – Whether to load from best mean sensitivity or not.

  • target_col – The target column name.

Returns:

Trainer
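A minimal training sketch for both model types, adapted from the package README. The child_df and parent_df frames, the unique_id join column, and the parent-model path are placeholders:

    # Fit a tabular (parent) model on the dataset.
    trainer = rtf_model.fit(df)

    # Fit a relational (child) model conditioned on a trained, saved
    # parent model; the parent's weights are frozen by default.
    child_model = REaLTabFormer(
        model_type="relational",
        parent_realtabformer_path="rtf_parent/idXXXX",  # placeholder path
        output_max_length=None,
        train_size=0.8,
    )
    child_model.fit(df=child_df, in_df=parent_df, join_on="unique_id")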

_train_with_sensitivity(df: pandas.DataFrame, device: str = 'cuda', num_bootstrap: int = 500, frac: float = 0.165, frac_max_data: int = 10000, qt_max: str | float = 0.05, qt_max_default: float = 0.05, qt_interval: int = 100, qt_interval_unique: int = 100, distance: sklearn.metrics.pairwise.manhattan_distances = manhattan_distances, quantile: float = 0.95, n_critic: int = 5, n_critic_stop: int = 2, gen_rounds: int = 3, sensitivity_max_col_nums: int = 20, use_ks: bool = False, resume_from_checkpoint: bool | str = False, full_sensitivity: bool = False, sensitivity_orig_frac_multiple: int = 4, orig_samples_rounds: int = 5, load_from_best_mean_sensitivity: bool = False)[source]#
_set_up_relational_coder_configs() None[source]#
_fit_relational(out_df: pandas.DataFrame, in_df: pandas.DataFrame, join_on: str, device='cuda')[source]#
_fit_tabular(df: pandas.DataFrame, device='cuda', num_train_epochs: int = None, target_epochs: int = None) transformers.Trainer[source]#
_build_tabular_trainer(device='cuda', num_train_epochs: int = None, target_epochs: int = None) transformers.Trainer[source]#
sample(n_samples: int = None, input_unique_ids: pandas.Series | List | None = None, input_df: pandas.DataFrame | None = None, input_ids: torch.tensor | None = None, gen_batch: int | None = 128, device: str = 'cuda', seed_input: pandas.DataFrame | Dict[str, Any] | None = None, save_samples: bool | None = False, constrain_tokens_gen: bool | None = True, validator: realtabformer.rtf_validators.ObservationValidator | None = None, continuous_empty_limit: int = 10, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, related_num: int | List[int] | None = None, **generate_kwargs) pandas.DataFrame[source]#

Generate synthetic tabular data samples

Parameters:
  • n_samples – Number of synthetic samples to generate for the tabular data.

  • input_unique_ids – The unique identifier that will be used to link the input data to the generated values when sampling for relational data.

  • input_df – Pandas DataFrame containing the tabular input data.

  • input_ids – (NOTE: the input_df argument is the preferred input) The input_ids that condition the generation of the relational data.

  • gen_batch – Controls the batch size of the data generation process. This parameter should be adjusted based on the compute resources.

  • device – The device used by the generator. Use torch devices, e.g., cpu, cuda, mps (experimental).

  • seed_input – A dictionary of col_name:values for the seed data. Only col_names that are actually in the first sequence of the training input will be used.

  • constrain_tokens_gen – Whether to constrain each generation step to only the tokens that are valid for the current column.

  • validator – An instance of ObservationValidator for validating the generated samples. The validators are applied to observations only, and don’t support inter-observation validation. See ObservationValidator docs on how to set up a validator.

  • continuous_empty_limit – The sampling will raise an exception if continuous_empty_limit empty sample batches have been produced consecutively. This prevents an infinite loop when the quality of the generated data is poor and always produces invalid observations.

  • suppress_tokens – (from docs) A list of tokens that will be suppressed at generation. The SuppressTokens logit processor will set their log probs to -inf so that they are not sampled. This is a useful feature for imputing missing values.

  • forced_decoder_ids – (from docs) A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, [[1, 123]] means the second generated token will always be a token of index 123. This is a useful feature for constraining the model to generate only specific stratification variables in surveys, e.g., GEO1, URBAN/RURAL variables.

  • related_num – A column name in the input_df containing the number of observations that the child table is expected to have for the parent observation. It can also be an integer if the input_df corresponds to a set of observations having the same number of expected observations. This parameter is only valid for the relational model.

  • generate_kwargs – Additional keyword arguments that will be supplied to the .generate method. For a comprehensive list of arguments, see: https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate

Returns:

DataFrame with n_samples rows of generated data
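A minimal sampling sketch, adapted from the package README. For the relational case, parent_df and the unique_id column are placeholders and child_model is a trained relational instance:

    # Generate 1000 synthetic observations from a trained tabular model.
    samples = rtf_model.sample(n_samples=1000)

    # Generate child observations conditioned on parent observations
    # (relational model); unique_id links the two tables.
    child_samples = child_model.sample(
        input_unique_ids=parent_df["unique_id"],
        input_df=parent_df.drop("unique_id", axis=1),
        gen_batch=64,
    )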

predict(data: pandas.DataFrame, target_col: str, target_pos_val: Any = None, batch: int = 32, obs_sample: int = 30, fillunk: bool = True, device: str = 'cuda', disable_progress_bar: bool = True, **generate_kwargs) pandas.Series[source]#

Use the trained model to make predictions on a given dataframe.

Parameters:
  • data – The data to make predictions on, in the form of a Pandas dataframe.

  • target_col – The name of the target column in the data to predict.

  • target_pos_val – The positive value in the target column to use for binary classification. For multi-categorical targets, this produces a one-vs-rest prediction relative to target_pos_val.

  • batch – The batch size to use when making predictions.

  • obs_sample – The number of observations to sample from the data when making predictions.

  • fillunk – If True, unknown tokens generated at a given step are filled with the mode of the batch at that step before making predictions.

  • device – The device to use for prediction. Can be either “cpu” or “cuda”.

  • **generate_kwargs – Additional keyword arguments to pass to the model’s generate method.

Returns:

A Pandas series containing the predicted values for the target column.
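A minimal prediction sketch; the test_df frame and its label column are hypothetical, and the model is assumed to have been trained on data containing that column:

    # Predict the "label" column for held-out observations.
    preds = rtf_model.predict(
        data=test_df,           # hypothetical held-out DataFrame
        target_col="label",     # hypothetical target column
        target_pos_val=1,       # positive class for binary scoring
        batch=32,
    )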

save(path: str | pathlib.Path, allow_overwrite: bool | None = False)[source]#

Save REaLTabFormer Model

Saves the model weights and a configuration file in the given directory.

Parameters:
  • path – Path where the model will be saved.

  • allow_overwrite – Whether existing model files at path may be overwritten.
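From the package README:

    # Save the model to the current directory. A new directory
    # rtf_model/ is created, containing a subdirectory named after
    # the model's experiment id (idXXXXXXXXXX) with the artefacts.
    rtf_model.save("rtf_model/")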

classmethod load_from_dir(path: str | pathlib.Path)[source]#

Load a saved REaLTabFormer model

Load trained REaLTabFormer model from directory.

Parameters:
  • path – Directory where the REaLTabFormer model is saved.

Returns:

REaLTabFormer instance
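Usage, adapted from the package README (the experiment-id directory name is a placeholder):

    # Load the saved model; the path must point to the
    # specific experiment directory.
    rtf_model2 = REaLTabFormer.load_from_dir(path="rtf_model/idXXXX")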