realtabformer.rtf_sampler#
This module implements the sampling algorithms used for tabular and relational data generation.
Module Contents#
Classes#
TabularSampler | Sampler class for tabular data generation.
RelationalSampler | Sampler class for relational data generation.
Attributes#
- class realtabformer.rtf_sampler.REaLSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#
- _generate(device: torch.device, as_numpy: bool | None = True, constrain_tokens_gen: bool | None = True, **generate_kwargs) → torch.tensor | numpy.ndarray [source]#
- _processes_sample(sample_outputs: numpy.ndarray, vocab: Dict, relate_ids: List[Any] | None = None, validator: realtabformer.rtf_validators.ObservationValidator | None = None) → pandas.DataFrame [source]#
- _validate_data(synth_df: pandas.DataFrame, validator: realtabformer.rtf_validators.ObservationValidator | None = None) → pandas.DataFrame [source]#
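These private hooks compose into a generate → decode → validate flow. The sketch below is only an illustration of that flow based on the signatures above, not the library's actual orchestration; `sampler` is assumed to be an already-constructed subclass instance (TabularSampler or RelationalSampler) and `vocab` the Dict passed to its constructor.

```python
import torch

# Illustration only: how the private hooks of a REaLSampler subclass fit together.
# `sampler` and `vocab` are assumed to exist (see the note above).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Sample token ids from the fine-tuned model (a numpy array when as_numpy=True).
sample_outputs = sampler._generate(device, as_numpy=True, constrain_tokens_gen=True)

# 2. Decode the token ids back into a pandas DataFrame using the training vocabulary.
synth_df = sampler._processes_sample(sample_outputs, vocab=vocab)

# 3. Optionally drop observations that fail an ObservationValidator (none used here).
synth_df = sampler._validate_data(synth_df, validator=None)
```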
- class realtabformer.rtf_sampler.TabularSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#
Bases: REaLSampler
Sampler class for tabular data generation.
- sample_tabular(n_samples: int, gen_batch: int | None = 128, device: str | None = 'cuda', seed_input: pandas.DataFrame | Dict[str, Any] | None = None, constrain_tokens_gen: bool | None = True, validator: realtabformer.rtf_validators.ObservationValidator | None = None, continuous_empty_limit: int = 10, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, **generate_kwargs) → pandas.DataFrame [source]#
- predict(data: pandas.DataFrame, target_col: str, target_pos_val: Any = None, batch: int = 32, obs_sample: int = 30, fillunk: bool = True, device: str = 'cuda', disable_progress_bar: bool = True, **generate_kwargs) → pandas.Series [source]#
fillunk: Fill unknown tokens with the mode of the batch.
target_pos_val: Categorical value for the positive target. This produces a one-to-many prediction relative to target_pos_val for targets that are multi-categorical.
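A minimal usage sketch for sample_tabular and predict, assuming a TabularSampler instance named `sampler` has already been constructed (in practice the package's model wrapper builds it from the fitted model, vocabulary, and column metadata); `new_df` and the column name "label" are placeholders.

```python
import pandas as pd

# Generate synthetic observations; returns a DataFrame with the training schema.
synth_df: pd.DataFrame = sampler.sample_tabular(
    n_samples=1_000,            # number of synthetic rows to produce
    gen_batch=128,              # rows generated per call to the model
    device="cuda",
    constrain_tokens_gen=True,  # constrain the tokens generated at each step
)

# Score a target column for observed data; returns a pandas Series.
preds: pd.Series = sampler.predict(
    data=new_df,         # placeholder DataFrame of observations to score
    target_col="label",  # placeholder target column name
    target_pos_val=1,    # positive category; one-to-many prediction for multi-categorical targets
    batch=32,
    fillunk=True,        # fill unknown tokens with the mode of the batch
    device="cuda",
)
```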
- class realtabformer.rtf_sampler.RelationalSampler(model_type: str, model: transformers.PreTrainedModel, vocab: Dict, processed_columns: List, max_length: int, col_size: int, col_idx_ids: Dict, columns: List, datetime_columns: List, column_dtypes: Dict, column_has_missing: Dict, drop_na_cols: List, col_transform_data: Dict, in_col_transform_data: Dict, random_state: int | None = 1029, device='cuda')[source]#
Bases: REaLSampler
Sampler class for relational data generation.
- sample_relational(input_unique_ids: pandas.Series | List, input_df: pandas.DataFrame | None = None, input_ids: torch.tensor | None = None, gen_batch: int | None = 128, device: str | None = 'cuda', constrain_tokens_gen: bool | None = True, validator: realtabformer.rtf_validators.ObservationValidator | None = None, continuous_empty_limit: int | None = 10, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, related_num: int | List[int] | None = None, **generate_kwargs) → pandas.DataFrame [source]#
- _sample_input_batch(input_df: pandas.DataFrame | None = None, gen_batch: int | None = 128, device: str | None = 'cuda', constrain_tokens_gen: bool | None = True, suppress_tokens: List[int] | None = None, forced_decoder_ids: List[List[int]] | None = None, **generate_kwargs)[source]#
- _get_relational_col_idx_ids(len_ids: int) → List [source]#
This method returns the true index given the generation step i.
- col_size: The expected number of variables for a single observation. This is equal to the number of columns.
### Generating constrained tokens per step
```
1 -> BOS
2 -> BMEM or EOS
3 -> col 0
…
3 + col_size -> col col_size - 1
3 + col_size + 1 -> EMEM
3 + col_size + 2 -> BMEM or EOS
3 + col_size + 3 -> col 0
```
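One way to read the pattern above is as a repeating block of (BMEM/EOS, col 0 … col col_size - 1, EMEM) following the initial BOS. The sketch below implements that reading; it is an illustration only, not the library's `_get_relational_col_idx_ids`, and the exact offsets are an assumption.

```python
# Illustration only: map a generation step to its constrained-token slot,
# assuming each related observation spans a BMEM/EOS marker, col_size column
# tokens, and an EMEM marker (the exact offsets are an assumption).
def step_to_slot(step: int, col_size: int) -> str:
    if step == 1:
        return "BOS"
    offset = (step - 2) % (col_size + 2)  # position within the repeating block
    if offset == 0:
        return "BMEM or EOS"
    if offset <= col_size:
        return f"col {offset - 1}"        # column index within the observation
    return "EMEM"
```

A usage sketch for sample_relational, assuming a constructed RelationalSampler named `sampler`, a parent table `parent_df`, and a join column "id" (all placeholders); the child rows are generated conditional on each parent observation.

```python
child_df = sampler.sample_relational(
    input_unique_ids=parent_df["id"],         # one unique id per parent row
    input_df=parent_df.drop(columns=["id"]),  # parent rows used as the conditioning input
    gen_batch=64,
    device="cuda",
    related_num=5,  # number of related rows per parent (int or per-parent list; assumed semantics)
)
```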