realtabformer.rtf_sampler

This module contains the implementation of the sampling algorithms used for tabular and relational data generation.

Module Contents

Classes

REaLSampler

Base sampler class shared by TabularSampler and RelationalSampler.

TabularSampler

Sampler class for tabular data generation.

RelationalSampler

Sampler class for relational data generation.

Attributes

NQ_COL

realtabformer.rtf_sampler.NQ_COL = '_nq_ds_'[source]#
class realtabformer.rtf_sampler.REaLSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#
abstract _prefix_allowed_tokens_fn(batch_id, input_ids) List[source]#
_convert_to_table(synth_df: pandas.DataFrame) pandas.DataFrame[source]#
_generate(device: torch.device, as_numpy: bool | None = True, constrain_tokens_gen: bool | None = True, **generate_kwargs) torch.tensor | numpy.ndarray[source]#
_validate_synth_sample(synth_sample: pandas.DataFrame) pandas.DataFrame[source]#
_recover_data_values(synth_sample: pandas.DataFrame) pandas.DataFrame[source]#
_processes_sample(sample_outputs: numpy.ndarray, vocab: Dict, relate_ids: List[Any] | None = None, validator: realtabformer.rtf_validators.ObservationValidator | None = None) pandas.DataFrame[source]#
_validate_data(synth_df: pandas.DataFrame, validator: realtabformer.rtf_validators.ObservationValidator | None = None) pandas.DataFrame[source]#
_validate_missing(synth_df: pandas.DataFrame) pandas.DataFrame[source]#
class realtabformer.rtf_sampler.TabularSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#

Bases: REaLSampler

Sampler class for tabular data generation.

static sampler_from_model(rtf_model, device: str = 'cuda')[source]#
_prefix_allowed_tokens_fn(batch_id, input_ids) List[source]#
_process_seed_input(seed_input: pandas.DataFrame | Dict[str, Any]) torch.Tensor[source]#
sample_tabular(n_samples: int, gen_batch: int | None = 128, device: str | None = 'cuda', seed_input: pandas.DataFrame | Dict[str, Any] | None = None, constrain_tokens_gen: bool | None = True, validator: realtabformer.rtf_validators.ObservationValidator | None = None, continuous_empty_limit: int = 10, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, **generate_kwargs) pandas.DataFrame[source]#
predict(data: pandas.DataFrame, target_col: str, target_pos_val: Any = None, batch: int = 32, obs_sample: int = 30, fillunk: bool = True, device: str = 'cuda', disable_progress_bar: bool = True, **generate_kwargs) pandas.Series[source]#

fillunk: Fill unknown tokens with the mode of the batch.

target_pos_val: Categorical value for the positive target. For targets that are multi-categorical, this produces a one-to-many prediction relative to target_pos_val.

class realtabformer.rtf_sampler.RelationalSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, in_col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#

Bases: REaLSampler

Sampler class for relational data generation.

static sampler_from_model(rtf_model, device: str = 'cuda')[source]#
sample_relational(input_unique_ids: pandas.Series | List, input_df: pandas.DataFrame | None = None, input_ids: torch.tensor | None = None, gen_batch: int | None = 128, device: str | None = 'cuda', constrain_tokens_gen: bool | None = True, validator: realtabformer.rtf_validators.ObservationValidator | None = None, continuous_empty_limit: int | None = 10, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, related_num: int | List[int] | None = None, **generate_kwargs) pandas.DataFrame[source]#
_get_min_max_length(related_num)[source]#
_sample_input_batch(input_df: pandas.DataFrame | None = None, gen_batch: int | None = 128, device: str | None = 'cuda', constrain_tokens_gen: bool | None = True, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, **generate_kwargs)[source]#
_get_relational_col_idx_ids(len_ids: int) List[source]#

This method returns the true index given the generation step i.

col_size: The expected number of variables for a single observation. This is equal to the number of columns.

Generating constrained tokens per step:

```
1 -> BOS
2 -> BMEM or EOS
3 -> col 0
…
3 + col_size -> col col_size - 1
3 + col_size + 1 -> EMEM
3 + col_size + 2 -> BMEM or EOS
3 + col_size + 3 -> col 0
```

_prefix_allowed_tokens_fn(batch_id, input_ids) List[source]#