Source code for realtabformer.realtabformer
"""The REaLTabFormer implements the model training and data processing
for tabular and relational data.
"""
import json
import logging
import math
import os
import random
import shutil
import time
import warnings
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics.pairwise import manhattan_distances
# from sklearn.metrics import accuracy_score
from transformers import (
EarlyStoppingCallback,
EncoderDecoderConfig,
EncoderDecoderModel,
PreTrainedModel,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
Trainer,
TrainingArguments,
)
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
import realtabformer
from .data_utils import (
ModelFileName,
ModelType,
SpecialTokens,
TabularArtefact,
build_vocab,
make_dataset,
make_relational_dataset,
process_data,
)
from .rtf_analyze import SyntheticDataBench
from .rtf_datacollator import RelationalDataCollator
from .rtf_exceptions import SampleEmptyLimitError
from .rtf_sampler import RelationalSampler, TabularSampler
from .rtf_trainer import ResumableTrainer
from .rtf_validators import ObservationValidator
def _normalize_gpt2_state_dict(state_dict):
state = []
for key, value in state_dict.items():
if key.startswith("transformer."):
# The saved state prefixes the weight names
# with `transformer.` whereas the
# encoder expects the weight names to not
# have the prefix.
key = key.replace("transformer.", "")
state.append((key, value))
return OrderedDict(state)
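# Illustrative sketch (keys shown follow the standard GPT2LMHeadModel weight naming):
#
#   state = OrderedDict({"transformer.wte.weight": torch.zeros(4, 2)})
#   _normalize_gpt2_state_dict(state)
#   # -> OrderedDict([("wte.weight", <tensor of shape (4, 2)>)])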
def _validate_get_device(device: str) -> str:
if (device == "cuda") and (torch.cuda.device_count() == 0):
if torch.backends.mps.is_available():
_device = "mps"
else:
_device = "cpu"
warnings.warn(
f"The device={device} is not available, using device={_device} instead."
)
device = _device
return device
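# Example behaviour (sketch): on a machine without CUDA devices,
# _validate_get_device("cuda") warns and returns "mps" if available, otherwise
# "cpu"; any other requested device string is returned unchanged.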
class REaLTabFormer:
def __init__(
self,
model_type: str,
tabular_config: Optional[GPT2Config] = None,
relational_config: Optional[EncoderDecoderConfig] = None,
parent_realtabformer_path: Optional[Path] = None,
freeze_parent_model: Optional[bool] = True,
checkpoints_dir: str = "rtf_checkpoints",
samples_save_dir: str = "rtf_samples",
epochs: int = 1000,
batch_size: int = 8,
random_state: int = 1029,
train_size: float = 1,
output_max_length: int = 512,
early_stopping_patience: int = 5,
early_stopping_threshold: float = 0,
mask_rate: float = 0,
numeric_nparts: int = 1,
numeric_precision: int = 4,
numeric_max_len: int = 10,
**training_args_kwargs,
) -> None:
"""Set up a REaLTabFormer instance.
Args:
model_type: Explicit declaration of which type of model will be used.
Can take `tabular` and `relational` as valid values.
tabular_config: GPT2Config instance to customize the GPT2 model for tabular data
generation.
relational_config: EncoderDecoderConfig instance that defines the encoder and decoder
configs for the encoder-decoder model used for the relational data generation. See
link for example: https://huggingface.co/docs/transformers/model_doc/encoder-decoder
parent_realtabformer_path: Path to a saved tabular REaLTabFormer model trained on the
parent table of a relational tabular data.
freeze_parent_model: Boolean flag indicating whether the parent-based encoder will be
frozen or not.
checkpoints_dir: Directory where the training checkpoints will be saved.
samples_save_dir: Save the samples generated by this model in this directory.
epochs: Number of epochs for training the GPT2LM model. Use a large number of epochs
to take advantage of the framework's optimal termination feature for the
non-relational tabular data model. Defaults to 1000.
batch_size: Batch size used for training. Must be adjusted based on the available
compute resource. TrainingArguments is set to use `gradient_accumulation_steps=4`
which will have an effective batch_size of 32 for the default value.
train_size: Fraction of the data that will be passed to the `.fit` method that will
be used for training. The remaining will be used as validation data.
output_max_length: Truncation length for the number of output token ids in the
relational model. This limit applies to the processed data and not the raw number
of variables. This is not used in the tabular data model.
early_stopping_patience: Number of evaluation rounds without improvement before
stopping the training.
early_stopping_threshold: See
https://huggingface.co/docs/transformers/main_classes/callback#transformers.EarlyStoppingCallback.early_stopping_threshold
mask_rate: The rate of tokens in the transformed observation that will be replaced
with the [RMASK] token for regularization during training.
training_args_kwargs: Keyword arguments for the `TrainingArguments` used in training
the model. Arguments such as `output_dir`, `num_train_epochs`,
`per_device_train_batch_size`, and `per_device_eval_batch_size`, if passed, will be
replaced by the `checkpoints_dir`, `epochs`, and `batch_size` arguments. The comprehensive
set of options can be found in
https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
"""
self.model: PreTrainedModel = None
# This will be set during training and deleted after training.
self.dataset = None
if model_type not in ModelType.types():
self._invalid_model_type(model_type)
self.model_type = model_type
if self.model_type == ModelType.tabular:
self._init_tabular(tabular_config)
elif self.model_type == ModelType.relational:
self.parent_vocab = None
self.parent_gpt2_config = None
self.parent_gpt2_state_dict = None
self.parent_col_transform_data = None
self.freeze_parent_model = freeze_parent_model
if parent_realtabformer_path is not None:
parent_realtabformer_path = Path(parent_realtabformer_path)
parent_config = json.loads(
(
parent_realtabformer_path / ModelFileName.rtf_config_json
).read_text()
)
self.parent_col_transform_data = parent_config["col_transform_data"]
self.parent_vocab = parent_config["vocab"]
self.parent_gpt2_config = parent_config["tabular_config"]
self.parent_gpt2_state_dict = _normalize_gpt2_state_dict(
torch.load(parent_realtabformer_path / ModelFileName.rtf_model_pt)
)
if output_max_length is None:
warnings.warn(
"The `output_max_length` is None. This could result to extended model training if the output length has large variations."
)
self.output_max_length = output_max_length
self._init_relational(relational_config)
else:
self._invalid_model_type(self.model_type)
self.checkpoints_dir = Path(checkpoints_dir)
self.samples_save_dir = Path(samples_save_dir)
self.epochs = epochs
self.batch_size = batch_size
self.early_stopping_patience = early_stopping_patience
self.early_stopping_threshold = early_stopping_threshold
self.training_args_kwargs = dict(
evaluation_strategy="steps",
output_dir=self.checkpoints_dir.as_posix(),
metric_for_best_model="loss", # This will be replaced with "eval_loss" if `train_size` < 1
overwrite_output_dir=True,
num_train_epochs=self.epochs,
per_device_train_batch_size=self.batch_size,
per_device_eval_batch_size=self.batch_size,
gradient_accumulation_steps=4,
remove_unused_columns=True,
logging_steps=100,
save_steps=100,
eval_steps=100,
load_best_model_at_end=True,
save_total_limit=early_stopping_patience + 1,
optim="adamw_torch",
)
# Remove training arguments that are overridden by explicit constructor parameters.
for p in [
"output_dir",
"num_train_epochs",
"per_device_train_batch_size",
"per_device_eval_batch_size",
]:
if p in training_args_kwargs:
warnings.warn(
f"Argument {p} was passed in training_args_kwargs but will be ignored..."
)
training_args_kwargs.pop(p)
self.training_args_kwargs.update(training_args_kwargs)
self.train_size = train_size
self.mask_rate = mask_rate
self.columns: List[str] = []
self.column_dtypes: Dict[str, type] = {}
self.column_has_missing: Dict[str, bool] = {}
self.drop_na_cols: List[str] = []
self.processed_columns: List[str] = []
self.numeric_columns: List[str] = []
self.datetime_columns: List[str] = []
self.vocab: Dict[str, dict] = {}
# Output length for generator model
# including special tokens.
self.tabular_max_length = None
self.relational_max_length = None
# Number of derived columns for the relational
# and tabular data after performing the data transformation.
# This will be used as record size validator in the
# sampling stage.
self.tabular_col_size = None
self.relational_col_size = None
# This stores the transformation
# parameters for numeric columns.
self.col_transform_data: Optional[Dict] = None
# This is the col_transform_data
# for the relational model's in_df.
self.in_col_transform_data: Optional[Dict] = None
self.col_idx_ids: Dict[int, list] = {}
self.random_state = random_state
self.numeric_nparts = numeric_nparts
self.numeric_precision = numeric_precision
self.numeric_max_len = numeric_max_len
# A unique identifier for the experiment set after the
# model is trained.
self.experiment_id = None
self.trainer_state = None
# Target column. When set, a copy of the column values will be
# implicitly placed at the beginning of the dataframe.
self.target_col = None
self.realtabformer_version = realtabformer.__version__
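# Usage sketch (argument values are illustrative, not recommended defaults):
#
#   from realtabformer import REaLTabFormer
#
#   rtf_model = REaLTabFormer(
#       model_type="tabular",
#       epochs=200,
#       batch_size=8,
#       logging_steps=50,  # forwarded to TrainingArguments via **training_args_kwargs
#   )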
def _invalid_model_type(self, model_type):
raise ValueError(
f"Model type: {model_type} is not valid. REaLTabFormer only supports \
`tabular` and `relational` values."
)
def _init_tabular(self, tabular_config):
if tabular_config is not None:
warnings.warn(
"The `bos_token_id`, `eos_token_id`, and `vocab_size` attributes will \
be replaced when the `.fit` method is run."
)
else:
# GPT2Config defaults to n_layer=12; use 6 (as in distilgpt2) by default.
tabular_config = GPT2Config(n_layer=6)
self.tabular_config = tabular_config
self.model = None
def _init_relational(self, relational_config):
if relational_config is not None:
warnings.warn(
"The `bos_token_id`, `eos_token_id`, and `vocab_size` attributes for the \
encoder and decoder will be replaced when the `.fit` method is run."
)
else:
# GPT2Config defaults to n_layer=12; use 6 (as in distilgpt2) by default.
relational_config = EncoderDecoderConfig(
encoder=GPT2Config(n_layer=6).to_dict(),
decoder=GPT2Config(n_layer=6).to_dict(),
)
if self.parent_gpt2_config is not None:
warnings.warn(
"A trained model for the parent table is available. The encoder will use the \
pretrained config and weights."
)
relational_config.encoder = GPT2Config(**self.parent_gpt2_config)
self.relational_config = relational_config
self.model = None
def _extract_column_info(self, df: pd.DataFrame) -> None:
# Track the column order of the original data
self.columns = df.columns.to_list()
# Store the dtypes of the columns
self.column_dtypes = df.dtypes.astype(str).to_dict()
# Track which variables have missing values
self.column_has_missing = (df.isnull().sum() > 0).to_dict()
# Get the columns where there should be no missing values
self.drop_na_cols = [
col for col, has_na in self.column_has_missing.items() if not has_na
]
# Identify the numeric columns. These will undergo
# special preprocessing.
self.numeric_columns = df.select_dtypes(include=np.number).columns.to_list()
# Identify the datetime columns. These will undergo
# special preprocessing.
self.datetime_columns = df.select_dtypes(include="datetime").columns.to_list()
def _generate_vocab(self, df: pd.DataFrame) -> dict:
return build_vocab(df, special_tokens=SpecialTokens.tokens(), add_columns=False)
def _check_model(self):
assert self.model is not None, "Model is None. Train the model first!"
def _split_train_eval_dataset(self, dataset: Dataset):
test_size = 1 - self.train_size
if test_size > 0:
dataset = dataset.train_test_split(
test_size=test_size, seed=self.random_state
)
dataset["train_dataset"] = dataset.pop("train")
dataset["eval_dataset"] = dataset.pop("test")
# Override `metric_for_best_model` from "loss" to "eval_loss"
self.training_args_kwargs["metric_for_best_model"] = "eval_loss"
# Make this explicit so that no assumption is made on the
# direction of the metric improvement.
self.training_args_kwargs["greater_is_better"] = False
else:
dataset = dict(train_dataset=dataset)
self.training_args_kwargs["evaluation_strategy"] = "no"
self.training_args_kwargs["load_best_model_at_end"] = False
return dataset
def fit(
self,
df: pd.DataFrame,
in_df: Optional[pd.DataFrame] = None,
join_on: Optional[str] = None,
resume_from_checkpoint: Union[bool, str] = False,
device="cuda",
num_bootstrap: int = 500,
frac: float = 0.165,
frac_max_data: int = 10000,
qt_max: Union[str, float] = 0.05,
qt_max_default: float = 0.05,
qt_interval: int = 100,
qt_interval_unique: int = 100,
distance: manhattan_distances = manhattan_distances,
quantile: float = 0.95,
n_critic: int = 5,
n_critic_stop: int = 2,
gen_rounds: int = 3,
sensitivity_max_col_nums: int = 20,
use_ks: bool = False,
full_sensitivity: bool = False,
sensitivity_orig_frac_multiple: int = 4,
orig_samples_rounds: int = 5,
load_from_best_mean_sensitivity: bool = False,
target_col: str = None,
):
"""Train the REaLTabFormer model on the tabular data.
Args:
df: Pandas DataFrame containing the tabular data that will be generated during sampling.
This data goes to the decoder for the relational model.
in_df: Pandas DataFrame containing observations related to `df`, and from which the
model will generate data. This data goes to the encoder for the relational model.
join_on: Column name that links the `df` and the `in_df` tables.
resume_from_checkpoint: If True, resumes training from the latest checkpoint in the
checkpoints_dir. If path, resumes the training from the given checkpoint.
device: Device where the model and the training will be run.
Use torch devices, e.g., `cpu`, `cuda`, `mps` (experimental)
num_bootstrap: Number of bootstrap samples.
frac: The fraction of the data used for training.
frac_max_data: The maximum number of rows that the training data will have.
qt_max: The maximum quantile for the discriminator.
qt_max_default: The default maximum quantile for the discriminator.
qt_interval: Interval for the quantile check during the training process.
qt_interval_unique: Interval for the quantile check used when the data contains no duplicate rows.
distance: Distance metric used for discriminator.
quantile: The quantile value that the discriminator will be trained to.
n_critic: Interval between epochs to perform a discriminator assessment.
n_critic_stop: The number of critic rounds without improvement after which the training
will be stopped.
gen_rounds: The number of generator rounds.
sensitivity_max_col_nums: The maximum number of columns used to compute sensitivity.
use_ks: Whether to use KS test or not.
full_sensitivity: Whether to use full sensitivity or not.
sensitivity_orig_frac_multiple: The size of the training data relative to the chosen
`frac` that will be used in computing the sensitivity. The larger this value is, the
more robust the sensitivity threshold will be. However,
`(sensitivity_orig_frac_multiple + 2)` multiplied by `frac` must be less than 1.
orig_samples_rounds: This is the number of train/hold-out samples that will be used to
compute the epoch sensitivity value.
load_from_best_mean_sensitivity: Whether to load from best mean sensitivity or not.
target_col: The target column name.
Returns:
Trainer
"""
device = _validate_get_device(device)
# Set target col for teacher forcing
self.target_col = target_col
# Set the seed for, *hopefully*, replicability.
# This may cause an unexpected behavior when using
# the resume_from_checkpoint option.
if self.random_state:
random.seed(self.random_state)
np.random.seed(self.random_state)
torch.manual_seed(self.random_state)
torch.cuda.manual_seed_all(self.random_state)
if self.model_type == ModelType.tabular:
if n_critic <= 0:
trainer = self._fit_tabular(df, device=device)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer = self._train_with_sensitivity(
df,
device,
num_bootstrap=num_bootstrap,
frac=frac,
frac_max_data=frac_max_data,
qt_max=qt_max,
qt_max_default=qt_max_default,
qt_interval=qt_interval,
qt_interval_unique=qt_interval_unique,
distance=distance,
quantile=quantile,
n_critic=n_critic,
n_critic_stop=n_critic_stop,
gen_rounds=gen_rounds,
resume_from_checkpoint=resume_from_checkpoint,
sensitivity_max_col_nums=sensitivity_max_col_nums,
use_ks=use_ks,
full_sensitivity=full_sensitivity,
sensitivity_orig_frac_multiple=sensitivity_orig_frac_multiple,
orig_samples_rounds=orig_samples_rounds,
load_from_best_mean_sensitivity=load_from_best_mean_sensitivity,
)
del self.dataset
elif self.model_type == ModelType.relational:
assert (
in_df is not None
), "The REaLTabFormer for relational data requires two tables for training."
assert join_on is not None, "The column to join the data must not be None."
trainer = self._fit_relational(df, in_df, join_on=join_on, device=device)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
self._invalid_model_type(self.model_type)
try:
self.experiment_id = f"id{int((time.time() * 10 ** 10)):024}"
torch.cuda.empty_cache()
return trainer
except Exception as exception:
if device == torch.device("cuda"):
del self.model
torch.cuda.empty_cache()
self.model = None
raise exception
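# Usage sketch for `.fit` (all data frames and the "family_id" key below are
# illustrative; a saved parent model is only needed for the relational case):
#
#   parent_model = REaLTabFormer(model_type="tabular")
#   parent_model.fit(parent_df.drop("family_id", axis=1))
#   parent_model.save("rtf_parent/")
#
#   child_model = REaLTabFormer(
#       model_type="relational",
#       parent_realtabformer_path=Path("rtf_parent") / parent_model.experiment_id,
#       output_max_length=None,
#   )
#   child_model.fit(df=child_df, in_df=parent_df, join_on="family_id")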
def _train_with_sensitivity(
self,
df: pd.DataFrame,
device: str = "cuda",
num_bootstrap: int = 500,
frac: float = 0.165,
frac_max_data: int = 10000,
qt_max: Union[str, float] = 0.05,
qt_max_default: float = 0.05,
qt_interval: int = 100,
qt_interval_unique: int = 100,
distance: manhattan_distances = manhattan_distances,
quantile: float = 0.95,
n_critic: int = 5,
n_critic_stop: int = 2,
gen_rounds: int = 3,
sensitivity_max_col_nums: int = 20,
use_ks: bool = False,
resume_from_checkpoint: Union[bool, str] = False,
full_sensitivity: bool = False,
sensitivity_orig_frac_multiple: int = 4,
orig_samples_rounds: int = 5,
load_from_best_mean_sensitivity: bool = False,
):
assert gen_rounds >= 1
_frac = min(frac, frac_max_data / len(df))
if frac != _frac:
warnings.warn(
f"The frac ({frac}) set results to a sample larger than \
frac_max_data={frac_max_data}. Setting frac to {_frac}."
)
frac = _frac
trainer: Trainer = None
dup_rate = df.duplicated().mean()
if isinstance(qt_max, str):
if qt_max == "compute":
# The idea behind this is that if the empirical data has
# natural duplicates, we can use that as a basis for the
# typical duplicate rate a random sample should have. Any
# significant excess over this indicates overfitting.
# The choice of dividing the duplicate rate by 2
# is arbitrary but reasonable to prevent delayed
# stopping when overfitting.
dup_rate = dup_rate / 2
qt_max = dup_rate if dup_rate > 0 else qt_max_default
else:
raise ValueError(f"Unexpected qt_max value: {qt_max}")
elif not isinstance(qt_max, str) and dup_rate >= qt_max:
warnings.warn(
f'The qt_max ({qt_max}) set is lower than the duplicate rate ({dup_rate}) in \
the data. This will not give a reliable early stopping condition. Consider \
using the qt_max="compute" argument.'
)
if dup_rate == 0:
# We do this because for data without unique values, we
# expect that a generated sample should have equal likelihood
# in the minimum distance with the hold out.
warnings.warn(
f"Duplicate rate ({dup_rate}) in the data is zero. The `qt_interval` will be set \
to qt_interval_unique={qt_interval_unique}."
)
qt_interval = qt_interval_unique
# Estimate the sensitivity threshold
print("Computing the sensitivity threshold...")
if not full_sensitivity:
# Dynamically compute the qt_interval to fit the data
# if the resulting sample has lower resolution.
# For example, we can't use qt_interval=1000 if the number
# of samples left at qt_max of the distance matrix is less than
# 1000.
# The formula means:
# - 2 -> accounts for the fact that we concatenate the rows and columns
# of the distance matrix.
# - frac -> the proportion of the training data that is used to compute
# the distance matrix.
# - qt_max -> the maximum quantile of assessment.
# We divide by 2 to increase the resolution a bit
_qt_interval = min(qt_interval, (2 * frac * len(df) * qt_max) // 2)
_qt_interval = max(_qt_interval, 2)
_qt_interval = int(_qt_interval)
if _qt_interval < qt_interval:
warnings.warn(
f"qt_interval adjusted from {qt_interval} to {_qt_interval}..."
)
qt_interval = _qt_interval
# Computing this here before splitting may have some data
# leakage issue, but it should be almost negligible. Doing
# the computation of the threshold on the full data with the
# train size aligned will give a more reliable estimate of
# the sensitivity threshold.
sensitivity_values = SyntheticDataBench.compute_sensitivity_threshold(
train_data=df,
num_bootstrap=num_bootstrap,
# Divide by two so that the train_data in this computation matches the size
# of the final df used to train the model. This is essential so that the
# sensitivity_threshold value is consistent with the val_sensitivity.
# Concretely, the computation of the distribution of min distances is
# relative to the number of training observations.
# The `frac` in this method corresponds to the size of both the test and the
# synthetic samples.
frac=frac / 2,
qt_max=qt_max,
qt_interval=qt_interval,
distance=distance,
return_values=True,
quantile=quantile,
max_col_nums=sensitivity_max_col_nums,
use_ks=use_ks,
full_sensitivity=full_sensitivity,
sensitivity_orig_frac_multiple=sensitivity_orig_frac_multiple,
)
sensitivity_threshold = np.quantile(sensitivity_values, quantile)
mean_sensitivity_value = np.mean(sensitivity_values)
best_mean_sensitivity_value = np.inf
assert isinstance(sensitivity_threshold, float)
print("Sensitivity threshold:", sensitivity_threshold, "qt_max:", qt_max)
# # Create a hold out sample for the discriminator model
# hold_df = df.sample(frac=frac, random_state=self.random_state)
# df = df.loc[df.index.difference(hold_df.index)]
# Start training
logging.info("Start training...")
# Remove existing checkpoints
for chkp in self.checkpoints_dir.glob("checkpoint-*"):
shutil.rmtree(chkp, ignore_errors=True)
sensitivity_scores = []
bdm_path = self.checkpoints_dir / TabularArtefact.best_disc_model
mean_closest_bdm_path = (
self.checkpoints_dir / TabularArtefact.mean_best_disc_model
)
not_bdm_path = self.checkpoints_dir / TabularArtefact.not_best_disc_model
last_epoch_path = self.checkpoints_dir / TabularArtefact.last_epoch_model
# Remove existing artefacts in the best model dir
shutil.rmtree(bdm_path, ignore_errors=True)
bdm_path.mkdir(parents=True, exist_ok=True)
shutil.rmtree(mean_closest_bdm_path, ignore_errors=True)
mean_closest_bdm_path.mkdir(parents=True, exist_ok=True)
shutil.rmtree(not_bdm_path, ignore_errors=True)
not_bdm_path.mkdir(parents=True, exist_ok=True)
shutil.rmtree(last_epoch_path, ignore_errors=True)
last_epoch_path.mkdir(parents=True, exist_ok=True)
last_epoch = 0
not_best_val_sensitivity = np.inf
if resume_from_checkpoint:
chkp_list = sorted(
self.checkpoints_dir.glob("checkpoint-*"), key=os.path.getmtime
)
if chkp_list:
# Get the most recent checkpoint based on
# creation time.
chkp = chkp_list[-1]
trainer_state = json.loads((chkp / "trainer_state.json").read_text())
last_epoch = math.ceil(trainer_state["epoch"])
trainer = self._fit_tabular(
df,
device=device,
num_train_epochs=last_epoch,
target_epochs=self.epochs,
)
np.random.seed(self.random_state)
random.seed(self.random_state)
for p_epoch in range(last_epoch, self.epochs, n_critic):
gen_total = int(len(df) * frac)
num_train_epochs = min(p_epoch + n_critic, self.epochs)
# Perform the discriminator sampling every `n_critic` epochs
if trainer is None:
trainer = self._fit_tabular(
df,
device=device,
num_train_epochs=num_train_epochs,
target_epochs=self.epochs,
)
trainer.train(resume_from_checkpoint=False)
else:
trainer = self._build_tabular_trainer(
device=device,
num_train_epochs=num_train_epochs,
target_epochs=self.epochs,
)
trainer.train(resume_from_checkpoint=True)
try:
# Generate samples
gen_df = self.sample(n_samples=gen_rounds * gen_total, device=device)
except SampleEmptyLimitError:
# Continue training if the model is still not
# able to generate stable observations.
continue
val_sensitivities = []
if full_sensitivity:
for _ in range(gen_rounds):
hold_df = df.sample(n=gen_total)
for g_idx in range(gen_rounds):
val_sensitivities.append(
SyntheticDataBench.compute_sensitivity_metric(
original=df.loc[df.index.difference(hold_df.index)],
synthetic=gen_df.iloc[
g_idx * gen_total : (g_idx + 1) * gen_total
],
test=hold_df,
qt_max=qt_max,
qt_interval=qt_interval,
distance=distance,
max_col_nums=sensitivity_max_col_nums,
use_ks=use_ks,
)
)
else:
for g_idx in range(gen_rounds):
for _ in range(orig_samples_rounds):
original_df = df.sample(
n=sensitivity_orig_frac_multiple * gen_total, replace=False
)
hold_df = df.loc[df.index.difference(original_df.index)].sample(
n=gen_total, replace=False
)
val_sensitivities.append(
SyntheticDataBench.compute_sensitivity_metric(
original=original_df,
synthetic=gen_df.iloc[
g_idx * gen_total : (g_idx + 1) * gen_total
],
test=hold_df,
qt_max=qt_max,
qt_interval=qt_interval,
distance=distance,
max_col_nums=sensitivity_max_col_nums,
use_ks=use_ks,
)
)
val_sensitivity = np.mean(val_sensitivities)
sensitivity_scores.append(val_sensitivity)
if val_sensitivity < sensitivity_threshold:
# Just save the model while the
# validation sensitivity is still within
# the accepted range.
# This way we can load the acceptable
# model back when the threshold is breached.
trainer.save_model(bdm_path.as_posix())
trainer.state.save_to_json((bdm_path / "trainer_state.json").as_posix())
elif not_best_val_sensitivity > (val_sensitivity - sensitivity_threshold):
print("Saving not-best model...")
trainer.save_model(not_bdm_path.as_posix())
trainer.state.save_to_json(
(not_bdm_path / "trainer_state.json").as_posix()
)
not_best_val_sensitivity = val_sensitivity - sensitivity_threshold
_delta_mean_sensitivity_value = abs(
mean_sensitivity_value - val_sensitivity
)
if _delta_mean_sensitivity_value < best_mean_sensitivity_value:
best_mean_sensitivity_value = _delta_mean_sensitivity_value
trainer.save_model(mean_closest_bdm_path.as_posix())
trainer.state.save_to_json(
(mean_closest_bdm_path / "trainer_state.json").as_posix()
)
print(
f"Critic round: {p_epoch + n_critic}, \
sensitivity_threshold: {sensitivity_threshold}, \
val_sensitivity: {val_sensitivity}, \
val_sensitivities: {val_sensitivities}"
)
if len(sensitivity_scores) > n_critic_stop:
n_no_improve = 0
for sensitivity_score in sensitivity_scores[-n_critic_stop:]:
# We count a round as having no improvement if its
# sensitivity score is above the sensitivity threshold.
if sensitivity_score > sensitivity_threshold:
n_no_improve += 1
if n_no_improve == n_critic_stop:
print("Stopping training, no improvement in critic...")
break
# Save last epoch artefacts before loading the best model.
trainer.save_model(last_epoch_path.as_posix())
trainer.state.save_to_json((last_epoch_path / "trainer_state.json").as_posix())
loaded_model_path = None
if not load_from_best_mean_sensitivity:
if (bdm_path / "pytorch_model.bin").exists() or (bdm_path / "model.safetensors").exists():
loaded_model_path = bdm_path
else:
if (mean_closest_bdm_path / "pytorch_model.bin").exists() or (mean_closest_bdm_path / "model.safetensors").exists():
loaded_model_path = mean_closest_bdm_path
if loaded_model_path is None:
# There should always be at least one `mean_closest_bdm_path` but
# in case it doesn't exist, try loading from `not_bdm_path`.
warnings.warn(
"No best model was saved. Loading the closest model to the sensitivity_threshold."
)
loaded_model_path = not_bdm_path
self.model = self.model.from_pretrained(loaded_model_path.as_posix())
self.trainer_state = json.loads(
(loaded_model_path / "trainer_state.json").read_text()
)
return trainer
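# Reading sketch of `_train_with_sensitivity` above (summary, not executable code):
#   1. Estimate a sensitivity threshold from bootstrap distance statistics on the
#      real data (SyntheticDataBench.compute_sensitivity_threshold).
#   2. Every `n_critic` epochs, sample `gen_rounds * gen_total` synthetic rows and
#      compute the mean val_sensitivity against train/hold-out splits.
#   3. Keep saving the model while val_sensitivity stays below the threshold; stop
#      once the last `n_critic_stop` critic rounds all exceed it, then reload the
#      best (or mean-closest) saved model.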
def _set_up_relational_coder_configs(self) -> None:
def _get_coder(coder_name) -> GPT2Config:
return getattr(self.relational_config, coder_name)
for coder_name in ["encoder", "decoder"]:
coder = _get_coder(coder_name)
coder.bos_token_id = self.vocab[coder_name]["token2id"][SpecialTokens.BOS]
coder.eos_token_id = self.vocab[coder_name]["token2id"][SpecialTokens.EOS]
coder.pad_token_id = self.vocab[coder_name]["token2id"][SpecialTokens.PAD]
coder.vocab_size = len(self.vocab[coder_name]["id2token"])
if coder_name == "decoder":
self.relational_config.bos_token_id = coder.bos_token_id
self.relational_config.eos_token_id = coder.eos_token_id
self.relational_config.pad_token_id = coder.pad_token_id
self.relational_config.decoder_start_token_id = coder.eos_token_id
# Make sure that we have at least the number of
# columns in the transformed data as positions.
# This will prevent runtime error.
# `RuntimeError: CUDA error: device-side assert triggered`
assert self.relational_max_length
if (
coder_name == "decoder"
and coder.n_positions < self.relational_max_length
):
coder.n_positions = 128 + self.relational_max_length
elif coder_name == "encoder" and getattr(
coder, "n_positions", getattr(coder, "max_position_embeddings")
) < len(self.vocab[coder_name]["column_token_ids"]):
positions = 128 + len(self.vocab[coder_name]["column_token_ids"])
try:
coder.n_positions = positions
except:
coder.max_position_embeddings = positions
# This must be set to True for the EncoderDecoderModel to work at least
# with GPT2 as the decoder.
self.relational_config.decoder.add_cross_attention = True
def _fit_relational(
self, out_df: pd.DataFrame, in_df: pd.DataFrame, join_on: str, device="cuda"
):
# bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
# bert2bert.config.eos_token_id = tokenizer.sep_token_id
# bert2bert.config.pad_token_id = tokenizer.pad_token_id
# bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
# All join values in the out_df must be present in the in_df.
assert len(set(out_df[join_on].unique()).difference(in_df[join_on])) == 0
# Get the list of index of observations that are related based on
# the join_on variable.
common_out_idx = (
out_df.reset_index(drop=True)
.groupby(join_on)
.apply(lambda x: x.index.to_list())
)
# Track the mapping of index from input to the list of output indices.
in_out_idx = pd.Series(
# Reset the index so that we are sure that the index ids are set properly.
dict(in_df[join_on].reset_index(drop=True).items())
).map(lambda x: common_out_idx.get(x, []))
# Remove the unique id column from the in_df and the out_df
in_df = in_df.drop(join_on, axis=1)
out_df = out_df.drop(join_on, axis=1)
self._extract_column_info(out_df)
out_df, self.col_transform_data = process_data(
out_df,
numeric_max_len=self.numeric_max_len,
numeric_precision=self.numeric_precision,
numeric_nparts=self.numeric_nparts,
)
self.processed_columns = out_df.columns.to_list()
self.vocab["decoder"] = self._generate_vocab(out_df)
self.relational_col_size = out_df.shape[1]
# NOTE: the index starts at zero, but should be adjusted
# to account for the special tokens. For relational data,
# the index should start at 3 ([[EOS], [BOS], [BMEM]]).
self.col_idx_ids = {
ix: self.vocab["decoder"]["column_token_ids"][col]
for ix, col in enumerate(self.processed_columns)
}
# Add these special tokens at specific key values
# which are used in `REaLSampler._get_relational_col_idx_ids`
self.col_idx_ids[-1] = [
self.vocab["decoder"]["token2id"][SpecialTokens.BMEM],
self.vocab["decoder"]["token2id"][SpecialTokens.EOS],
]
self.col_idx_ids[-2] = [self.vocab["decoder"]["token2id"][SpecialTokens.EMEM]]
# TODO: handle the col_transform_data from the in_df as well.
in_df, self.in_col_transform_data = process_data(
in_df,
numeric_max_len=self.numeric_max_len,
numeric_precision=self.numeric_precision,
numeric_nparts=self.numeric_nparts,
col_transform_data=self.parent_col_transform_data,
)
if self.parent_vocab is None:
self.vocab["encoder"] = self._generate_vocab(in_df)
else:
self.vocab["encoder"] = self.parent_vocab
# Load the dataframe into a HuggingFace Dataset
dataset = make_relational_dataset(
in_df=in_df,
out_df=out_df,
vocab=self.vocab,
in_out_idx=in_out_idx,
output_max_length=self.output_max_length,
mask_rate=self.mask_rate,
return_token_type_ids=False,
)
# Compute the longest sequence of labels in the dataset and add a buffer of 1.
self.relational_max_length = (
max(
dataset.map(lambda example: dict(length=len(example["labels"])))[
"length"
]
)
+ 1
)
# Create train-eval split if specified
dataset = self._split_train_eval_dataset(dataset)
# Set up the config and the model
self._set_up_relational_coder_configs()
# Build the model.
self.model = EncoderDecoderModel(self.relational_config)
if self.parent_gpt2_state_dict is not None:
pretrain_load = self.model.encoder.load_state_dict(
self.parent_gpt2_state_dict, strict=False
)
assert (
not pretrain_load.missing_keys
), "There should be no missing_keys after loading the pretrained GPT2 state!"
if self.freeze_parent_model:
# We freeze the weights if we use the pretrained
# parent table model.
for param in self.model.encoder.parameters():
param.requires_grad = False
# Tell pytorch to run this model on the GPU.
device = torch.device(device)
if device == torch.device("cuda"):
self.model.cuda()
# Set TrainingArguments and the Seq2SeqTrainer
training_args_kwargs = dict(self.training_args_kwargs)
default_args_kwargs = dict(
# predict_with_generate=True,
# warmup_steps=2000,
fp16=(
device == torch.device("cuda")
), # Use fp16 by default if using cuda device
)
for k, v in default_args_kwargs.items():
if k not in training_args_kwargs:
training_args_kwargs[k] = v
callbacks = None
if training_args_kwargs["load_best_model_at_end"]:
callbacks = [
EarlyStoppingCallback(
self.early_stopping_patience, self.early_stopping_threshold
)
]
# instantiate trainer
trainer = Seq2SeqTrainer(
model=self.model,
args=Seq2SeqTrainingArguments(**training_args_kwargs),
callbacks=callbacks,
data_collator=RelationalDataCollator(),
**dataset,
)
return trainer
def _fit_tabular(
self,
df: pd.DataFrame,
device="cuda",
num_train_epochs: int = None,
target_epochs: int = None,
) -> Trainer:
self._extract_column_info(df)
df, self.col_transform_data = process_data(
df,
numeric_max_len=self.numeric_max_len,
numeric_precision=self.numeric_precision,
numeric_nparts=self.numeric_nparts,
target_col=self.target_col,
)
self.processed_columns = df.columns.to_list()
self.vocab = self._generate_vocab(df)
self.tabular_col_size = df.shape[1]
# NOTE: the index starts at zero, but should be adjusted
# to account for the special tokens. For tabular data,
# the index should start at 1.
self.col_idx_ids = {
ix: self.vocab["column_token_ids"][col]
for ix, col in enumerate(self.processed_columns)
}
# Load the dataframe into a HuggingFace Dataset
dataset = make_dataset(
df, self.vocab, mask_rate=self.mask_rate, return_token_type_ids=False
)
# Store the sequence length for the processed data
self.tabular_max_length = len(dataset[0]["input_ids"])
# Create train-eval split if specified
dataset = self._split_train_eval_dataset(dataset)
self.dataset = dataset
# Set up the config and the model
self.tabular_config.bos_token_id = self.vocab["token2id"][SpecialTokens.BOS]
self.tabular_config.eos_token_id = self.vocab["token2id"][SpecialTokens.EOS]
self.tabular_config.vocab_size = len(self.vocab["id2token"])
# Make sure that we have at least the number of
# columns in the transformed data as positions.
if self.tabular_config.n_positions < len(self.vocab["column_token_ids"]):
self.tabular_config.n_positions = 128 + len(self.vocab["column_token_ids"])
self.model = GPT2LMHeadModel(self.tabular_config)
# Tell pytorch to run this model on the GPU.
device = torch.device(device)
if device == torch.device("cuda"):
self.model.cuda()
return self._build_tabular_trainer(
device=device,
num_train_epochs=num_train_epochs,
target_epochs=target_epochs,
)
def _build_tabular_trainer(
self,
device="cuda",
num_train_epochs: int = None,
target_epochs: int = None,
) -> Trainer:
device = torch.device(device)
# Set TrainingArguments and the Trainer
logging.info("Set up the TrainingArguments and the Trainer...")
training_args_kwargs: Dict[str, Any] = dict(self.training_args_kwargs)
default_args_kwargs = dict(
fp16=(
device == torch.device("cuda")
), # Use fp16 by default if using cuda device
)
for k, v in default_args_kwargs.items():
if k not in training_args_kwargs:
training_args_kwargs[k] = v
if num_train_epochs is not None:
training_args_kwargs["num_train_epochs"] = num_train_epochs
# # NOTE: The `ResumableTrainer` will default to its original
# # behavior (Trainer) if `target_epochs`` is None.
# # Set the `target_epochs` to `num_train_epochs` if not specified.
# if target_epochs is None:
# target_epochs = training_args_kwargs.get("num_train_epochs")
callbacks = None
if training_args_kwargs["load_best_model_at_end"]:
callbacks = [
EarlyStoppingCallback(
self.early_stopping_patience, self.early_stopping_threshold
)
]
assert self.dataset
trainer = ResumableTrainer(
target_epochs=target_epochs,
save_epochs=None,
model=self.model,
args=TrainingArguments(**training_args_kwargs),
data_collator=None, # Use the default_data_collator
callbacks=callbacks,
**self.dataset,
)
return trainer
def sample(
self,
n_samples: int = None,
input_unique_ids: Optional[Union[pd.Series, List]] = None,
input_df: Optional[pd.DataFrame] = None,
input_ids: Optional[torch.tensor] = None,
gen_batch: Optional[int] = 128,
device: str = "cuda",
seed_input: Optional[Union[pd.DataFrame, Dict[str, Any]]] = None,
save_samples: Optional[bool] = False,
constrain_tokens_gen: Optional[bool] = True,
validator: Optional[ObservationValidator] = None,
continuous_empty_limit: int = 10,
suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
related_num: Optional[Union[int, List[int]]] = None,
**generate_kwargs,
) -> pd.DataFrame:
"""Generate synthetic tabular data samples
Args:
n_samples: Number of synthetic samples to generate for the tabular data.
input_unique_ids: The unique identifier that will be used to link the input
data to the generated values when sampling for relational data.
input_df: Pandas DataFrame containing the tabular input data.
input_ids: (NOTE: the `input_df` argument is the preferred input)
The input_ids that conditions the generation of the relational data.
gen_batch: Controls the batch size of the data generation process. This parameter
should be adjusted based on the compute resources.
device: The device used by the generator.
Use torch devices, e.g., `cpu`, `cuda`, `mps` (experimental)
seed_input: A dictionary of `col_name:values` for the seed data. Only `col_names`
that are actually in the first sequence of the training input will be used.
save_samples: Whether to save the generated samples in `samples_save_dir`.
constrain_tokens_gen: Whether to constrain each step of the generation to only the
tokens that are valid for the corresponding column.
validator: An instance of `ObservationValidator` for validating the generated samples.
The validators are applied to observations only, and don't support inter-observation
validation. See `ObservationValidator` docs on how to set up a validator.
continuous_empty_limit: The sampling will raise an exception if this many empty
sample batches have been produced consecutively. This will prevent an infinite loop
if the quality of the data generated is not good and always produces invalid
observations.
suppress_tokens: (from docs) A list of tokens that will be suppressed at generation.
The SuppressTokens logit processor will set their log probs to -inf so that they are
not sampled. This is a useful feature for imputing missing values.
forced_decoder_ids: (from docs) A list of pairs of integers which indicates a mapping
from generation indices to token indices that will be forced before sampling. For
example, [[1, 123]] means the second generated token will always be a token of
index 123. This is a useful feature for constraining the model to generate only
specific stratification variables in surveys, e.g., GEO1, URBAN/RURAL variables.
related_num: A column name in the input_df containing the number of observations that the child
table is expected to have for the parent observation. It can also be an integer if the input_df
corresponds to a set of observations having the same number of expected observations.
This parameter is only valid for the relational model.
generate_kwargs: Additional keyword arguments that will be supplied to the `.generate`
method. For a comprehensive list of arguments, see:
https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate
Returns:
DataFrame with n_samples rows of generated data
"""
self._check_model()
device = _validate_get_device(device)
# Clear the cache
torch.cuda.empty_cache()
if self.model_type == ModelType.tabular:
assert n_samples
assert self.tabular_max_length is not None
assert self.tabular_col_size is not None
assert self.col_transform_data is not None
tabular_sampler = TabularSampler.sampler_from_model(
rtf_model=self, device=device
)
# (
# model_type=self.model_type,
# model=self.model,
# vocab=self.vocab,
# processed_columns=self.processed_columns,
# max_length=self.tabular_max_length,
# col_size=self.tabular_col_size,
# col_idx_ids=self.col_idx_ids,
# columns=self.columns,
# datetime_columns=self.datetime_columns,
# column_dtypes=self.column_dtypes,
# drop_na_cols=self.drop_na_cols,
# col_transform_data=self.col_transform_data,
# random_state=self.random_state,
# device=device,
# )
synth_df = tabular_sampler.sample_tabular(
n_samples=n_samples,
gen_batch=gen_batch,
device=device,
seed_input=seed_input,
constrain_tokens_gen=constrain_tokens_gen,
validator=validator,
continuous_empty_limit=continuous_empty_limit,
suppress_tokens=suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
**generate_kwargs,
)
elif self.model_type == ModelType.relational:
assert (input_ids is not None) or (input_df is not None)
assert self.relational_max_length is not None
assert self.relational_col_size is not None
assert self.col_transform_data is not None
assert self.in_col_transform_data is not None
relational_sampler = RelationalSampler.sampler_from_model(
rtf_model=self, device=device
)
# (
# model_type=self.model_type,
# model=self.model,
# vocab=self.vocab,
# processed_columns=self.processed_columns,
# max_length=self.relational_max_length,
# col_size=self.relational_col_size,
# col_idx_ids=self.col_idx_ids,
# columns=self.columns,
# datetime_columns=self.datetime_columns,
# column_dtypes=self.column_dtypes,
# drop_na_cols=self.drop_na_cols,
# col_transform_data=self.col_transform_data,
# in_col_transform_data=self.in_col_transform_data,
# random_state=self.random_state,
# device=device,
# )
synth_df = relational_sampler.sample_relational(
input_unique_ids=input_unique_ids,
input_df=input_df,
input_ids=input_ids,
device=device,
gen_batch=gen_batch,
constrain_tokens_gen=constrain_tokens_gen,
validator=validator,
continuous_empty_limit=continuous_empty_limit,
suppress_tokens=suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
related_num=related_num,
**generate_kwargs,
)
if save_samples:
samples_fname = (
self.samples_save_dir
/ f"rtf_{self.model_type}-exp_{self.experiment_id}-{int(time.time())}-samples_{synth_df.shape[0]}.pkl"
)
samples_fname.parent.mkdir(parents=True, exist_ok=True)
synth_df.to_pickle(samples_fname)
return synth_df
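# Usage sketch for `.sample` (model, data frame, and column names are illustrative):
#
#   # Tabular model: generate 1,000 synthetic observations.
#   parent_samples = parent_model.sample(n_samples=1_000, gen_batch=64)
#
#   # Relational model: condition the generation on parent observations.
#   child_samples = child_model.sample(
#       input_unique_ids=parent_df["family_id"],
#       input_df=parent_df.drop("family_id", axis=1),
#       gen_batch=64,
#   )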
def predict(
self,
data: pd.DataFrame,
target_col: str,
target_pos_val: Any = None,
batch: int = 32,
obs_sample: int = 30,
fillunk: bool = True,
device: str = "cuda",
disable_progress_bar: bool = True,
**generate_kwargs,
) -> pd.Series:
"""
Use the trained model to make predictions on a given dataframe.
Args:
data: The data to make predictions on, in the form of a Pandas dataframe.
target_col: The name of the target column in the data to predict.
target_pos_val: The positive value in the target column to use for binary
classification. This produces a one-to-many prediction relative to
`target_pos_val` for targets that are multi-categorical.
batch: The batch size to use when making predictions.
obs_sample: The number of observations to sample from the data when making predictions.
fillunk: If True, fill unknown tokens with the mode of the batch at the given step
before making predictions.
device: The device to use for prediction. Can be either "cpu" or "cuda".
disable_progress_bar: If True, suppress the progress bar during prediction.
**generate_kwargs: Additional keyword arguments to pass to the model's `generate`
method.
Returns:
A Pandas series containing the predicted values for the target column.
"""
assert (
self.model_type == ModelType.tabular
), "The predict method is only implemented for tabular data..."
self._check_model()
device = _validate_get_device(device)
batch = min(batch, data.shape[0])
# Clear the cache
torch.cuda.empty_cache()
# assert self.tabular_max_length is not None
# assert self.tabular_col_size is not None
# assert self.col_transform_data is not None
tabular_sampler = TabularSampler.sampler_from_model(self, device=device)
# TabularSampler(
# model_type=self.model_type,
# model=self.model,
# vocab=self.vocab,
# processed_columns=self.processed_columns,
# max_length=self.tabular_max_length,
# col_size=self.tabular_col_size,
# col_idx_ids=self.col_idx_ids,
# columns=self.columns,
# datetime_columns=self.datetime_columns,
# column_dtypes=self.column_dtypes,
# drop_na_cols=self.drop_na_cols,
# col_transform_data=self.col_transform_data,
# random_state=self.random_state,
# device=device,
# )
return tabular_sampler.predict(
data=data,
target_col=target_col,
target_pos_val=target_pos_val,
batch=batch,
obs_sample=obs_sample,
fillunk=fillunk,
device=device,
disable_progress_bar=disable_progress_bar,
**generate_kwargs,
)
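# Usage sketch for `.predict` (assumes a trained tabular model whose training data
# included a hypothetical "label" column; `eval_df` is illustrative):
#
#   preds = parent_model.predict(data=eval_df, target_col="label", batch=32)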
def save(self, path: Union[str, Path], allow_overwrite: Optional[bool] = False):
"""Save REaLTabFormer Model
Saves the model weights and a configuration file in the given directory.
Args:
path: Path where to save the model.
allow_overwrite: If True, allow overwriting model artefacts that already exist in the directory.
"""
self._check_model()
assert self.experiment_id is not None
if isinstance(path, str):
path = Path(path)
# Add experiment id to the save path
path = path / self.experiment_id
config_file = path / ModelFileName.rtf_config_json
model_file = path / ModelFileName.rtf_model_pt
if path.is_dir() and not allow_overwrite:
if config_file.exists() or model_file.exists():
raise ValueError(
"This directory is not empty, and contains either a config or a model."
" Consider setting `allow_overwrite=True` if you want to overwrite these."
)
else:
warnings.warn(
f"Directory {path} exists, but `allow_overwrite=False`."
" This will raise an error next time when the model artifacts \
exist on this directory"
)
path.mkdir(parents=True, exist_ok=True)
# Save attributes
rtf_attrs = self.__dict__.copy()
rtf_attrs.pop("model")
# We don't need to store the `parent_config`
# since a saved model should have the weights loaded from
# the trained model already.
for ignore_key in [
"parent_vocab",
"parent_gpt2_config",
"parent_gpt2_state_dict",
"parent_col_transform_data",
]:
if ignore_key in rtf_attrs:
rtf_attrs.pop(ignore_key)
# GPT2Config is not JSON serializable, let us manually
# extract the attributes.
if rtf_attrs.get("tabular_config"):
rtf_attrs["tabular_config"] = rtf_attrs["tabular_config"].to_dict()
if rtf_attrs.get("relational_config"):
rtf_attrs["relational_config"] = rtf_attrs["relational_config"].to_dict()
rtf_attrs["checkpoints_dir"] = rtf_attrs["checkpoints_dir"].as_posix()
rtf_attrs["samples_save_dir"] = rtf_attrs["samples_save_dir"].as_posix()
config_file.write_text(json.dumps(rtf_attrs))
# Save model weights
torch.save(self.model.state_dict(), model_file.as_posix())
if self.model_type == ModelType.tabular:
# Copy the special model checkpoints for
# tabular models.
for artefact in TabularArtefact.artefacts():
print("Copying artefacts from:", artefact)
if (self.checkpoints_dir / artefact).exists():
shutil.copytree(
self.checkpoints_dir / artefact,
path / artefact,
dirs_exist_ok=True,
)
@classmethod
def load_from_dir(cls, path: Union[str, Path]):
"""Load a saved REaLTabFormer model
Load trained REaLTabFormer model from directory.
Args:
path: Directory where REaLTabFormer model is saved
Returns:
REaLTabFormer instance
"""
if isinstance(path, str):
path = Path(path)
config_file = path / ModelFileName.rtf_config_json
model_file = path / ModelFileName.rtf_model_pt
assert path.is_dir(), f"Directory {path} does not exist."
assert config_file.exists(), f"Config file {config_file} does not exist."
assert model_file.exists(), f"Model file {model_file} does not exist."
# Load the saved attributes
rtf_attrs = json.loads(config_file.read_text())
# Create new REaLTabFormer model instance
try:
realtf = cls(model_type=rtf_attrs["model_type"])
except KeyError:
# Back-compatibility for saved models
# before the support for relational data
# was implemented.
realtf = cls(model_type="tabular")
# Set all attributes and handle the
# special case for the GPT2Config.
for k, v in rtf_attrs.items():
if k == "gpt_config":
# Back-compatibility for saved models
# before the support for relational data
# was implemented.
v = GPT2Config.from_dict(v)
k = "tabular_config"
elif k == "tabular_config":
v = GPT2Config.from_dict(v)
elif k == "relational_config":
v = EncoderDecoderConfig.from_dict(v)
elif k in ["checkpoints_dir", "samples_save_dir"]:
v = Path(v)
elif k == "vocab":
if realtf.model_type == ModelType.tabular:
# Cast id back to int since JSON converts them to string.
v["id2token"] = {int(ii): vv for ii, vv in v["id2token"].items()}
elif realtf.model_type == ModelType.relational:
v["encoder"]["id2token"] = {
int(ii): vv for ii, vv in v["encoder"]["id2token"].items()
}
v["decoder"]["id2token"] = {
int(ii): vv for ii, vv in v["decoder"]["id2token"].items()
}
else:
raise ValueError(f"Invalid model_type: {realtf.model_type}")
elif k == "col_idx_ids":
v = {int(ii): vv for ii, vv in v.items()}
setattr(realtf, k, v)
# Implement back-compatibility for REaLTabFormer version < 0.0.1.8.2
# since the attribute `col_idx_ids` is not implemented before.
if "col_idx_ids" not in rtf_attrs:
if realtf.model_type == ModelType.tabular:
realtf.col_idx_ids = {
ix: realtf.vocab["column_token_ids"][col]
for ix, col in enumerate(realtf.processed_columns)
}
elif realtf.model_type == ModelType.relational:
# NOTE: the index starts at zero, but should be adjusted
# to account for the special tokens. For relational data,
# the index should start at 3 ([[EOS], [BOS], [BMEM]]).
realtf.col_idx_ids = {
ix: realtf.vocab["decoder"]["column_token_ids"][col]
for ix, col in enumerate(realtf.processed_columns)
}
# Add these special tokens at specific key values
# which are used in `REaLSampler._get_relational_col_idx_ids`
realtf.col_idx_ids[-1] = [
realtf.vocab["decoder"]["token2id"][SpecialTokens.BMEM],
realtf.vocab["decoder"]["token2id"][SpecialTokens.EOS],
]
realtf.col_idx_ids[-2] = [
realtf.vocab["decoder"]["token2id"][SpecialTokens.EMEM]
]
# Load model weights
if realtf.model_type == ModelType.tabular:
realtf.model = GPT2LMHeadModel(realtf.tabular_config)
elif realtf.model_type == ModelType.relational:
realtf.model = EncoderDecoderModel(realtf.relational_config)
else:
raise ValueError(f"Invalid model_type: {realtf.model_type}")
realtf.model.load_state_dict(
torch.load(model_file.as_posix(), map_location="cpu")
)
return realtf
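# Usage sketch for saving and loading (directory names are illustrative): `save`
# writes the artefacts under `path / experiment_id`, so `load_from_dir` must point
# at that experiment subdirectory.
#
#   parent_model.save("rtf_parent/")
#   saved_path = sorted(Path("rtf_parent/").glob("id*"))[-1]
#   restored = REaLTabFormer.load_from_dir(saved_path)
#   restored.sample(n_samples=10)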