Forecast Combination

# Forecast Combination

import os
os.chdir("../../")

import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from scripts.python.tsa.mtsmodel import *
from scripts.python.tsa.ts_eval import *

import seaborn as sns
sns.set_style("whitegrid")
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

Timmermann (2004) summarizes the relative performance weights of Stock and Watson (1998) as follows. Let \(MSE_{t+h,t,i} = (1/v)\sum_{\tau=t-v}^{t} e^{2}_{\tau,\tau-h,i}\) be the \(i\)th forecasting model’s MSE at time \(t\), computed over a window of the previous \(v\) periods. Then

\[ \hat{y}_{t+h,t} = \sum_{i=1}^{N} \hat{\omega}_{t+h,t,i} \hat{y}_{t+h,t,i}, \text{ where } \hat{\omega}_{t+h,t,i} = \frac{(1/MSE_{t+h,t,i})}{\sum_{j=1}^{N} (1/MSE_{t+h,t,j})}\]

Below are the functions to calculate the relative performance weight where \(i \in \{sarimax, lf, var\}\).

Hide code cell source
def calculate_mse(predictions_df: pd.DataFrame, method: str) -> pd.Series:
    """Return the expanding (cumulative) mean squared error of one model.

    Element ``k`` (0-based position) of the result is the mean of the squared
    errors ``(total - prediction) ** 2`` over rows ``0..k`` of the frame.

    Args:
        predictions_df: Frame holding the observed values in a ``"total"``
            column and the model's predictions in the ``method`` column.
        method: Name of the prediction column to score.

    Returns:
        A Series of running MSE values carrying ``predictions_df``'s index.
    """
    squared_error = np.square(predictions_df["total"] - predictions_df[method])
    # Divide by the positional observation count (1, 2, ..., n) instead of
    # ``index + 1``: the latter is only right for a default 0-based integer
    # index and breaks for date, filtered or otherwise shifted indexes.
    n_obs = np.arange(1, len(predictions_df) + 1)
    return squared_error.cumsum() / n_obs


def calculate_rpw(predictions_df: pd.DataFrame, methods: list) -> pd.Series:
    """Compute the relative performance weight of every method.

    A method's weight is its inverse MSE (as returned by ``calculate_mse``)
    divided by the sum of the inverse MSEs of all ``methods``.

    Args:
        predictions_df: Frame with a ``"total"`` column and one prediction
            column per entry of ``methods``.
        methods: Names of the prediction columns to weight.

    Returns:
        A Series keyed by method name; each value is that method's weight
        Series (one weight per row of ``predictions_df``).
    """
    # Inverse MSE per method, computed once and reused for both the
    # numerator and the normalising sum.
    inverse_mse = {name: 1 / calculate_mse(predictions_df, name)
                   for name in methods}
    normaliser = sum(inverse_mse[name] for name in methods)
    return pd.Series({name: inverse_mse[name] / normaliser
                      for name in methods})


def get_rpw(pred_df: pd.DataFrame,
            methods: "list | None" = None
            ) -> "tuple[pd.Series, pd.DataFrame]":
    """Combine model forecasts using relative performance weights.

    Args:
        pred_df: Frame with a ``"total"`` column and one prediction column per
            entry of ``methods``. It is copied, never modified.
        methods: Names of the prediction columns to combine. Defaults to
            ``["sarimax", "var", "lf"]``.

    Returns:
        A ``(rpw_pred, rpw)`` tuple: ``rpw_pred`` is the row-wise sum of each
        prediction multiplied by its weight, and ``rpw`` is a DataFrame of the
        weights with one ``"rpw_<method>"`` column per method.
    """
    # Build the default inside the call so no list object is shared between
    # invocations (the original used a mutable default argument).
    if methods is None:
        methods = ["sarimax", "var", "lf"]

    predictions_df = pred_df.copy()
    rpw_series = calculate_rpw(predictions_df, methods)

    # "<method>_weight" columns hold the weighted predictions, not the weights.
    combo_cols = []
    for method in methods:
        weighted_col = str(method) + "_weight"
        predictions_df[weighted_col] = (
            predictions_df[method] * rpw_series[method])
        combo_cols.append(weighted_col)

    rpw_pred = predictions_df[combo_cols].sum(axis=1)
    rpw = pd.DataFrame(rpw_series.to_dict())
    rpw.columns = ["rpw_" + str(col) for col in rpw.columns]

    return rpw_pred, rpw


def get_constrained_ls(y: pd.DataFrame,
                       X: pd.DataFrame) -> "tuple[np.ndarray, np.ndarray]":
    """Fit least-squares weights that are non-negative and sum to one.

    Minimises ``||X w - y||`` with SLSQP subject to ``w >= 0`` and
    ``sum(w) == 1``, starting from the non-negative least-squares solution.

    Args:
        y: Target values (a Series, a single-column frame or a 1-D array).
        X: Regressors, one column per weight to estimate.

    Returns:
        A ``(estimates, pred)`` tuple: the fitted weights and the fitted
        values ``X @ estimates``. The optimiser's success flag is not
        checked; the last iterate is returned as-is.
    """
    from scipy.optimize import nnls, minimize

    A = np.asarray(X, dtype=float)
    # nnls needs a 1-D target; flatten so a single-column frame also works.
    b = np.asarray(y, dtype=float).ravel()

    # The NNLS solution is a feasible (non-negative) starting point; it does
    # not have to sum to one, the equality constraint below enforces that.
    x0, _ = nnls(A, b)

    def residual_norm(x, A, b):
        return np.linalg.norm(A.dot(x) - b)

    # One (lower, upper) pair per parameter. A single pair only works on SciPy
    # versions that broadcast bounds; other versions reject the length mismatch.
    lower_bounded = [(0, None)] * A.shape[1]
    result = minimize(residual_norm, x0, args=(A, b), method='SLSQP',
                      bounds=lower_bounded,
                      constraints={'type': 'eq',
                                   'fun': lambda x: np.sum(x) - 1})
    estimates = result.x
    pred = A.dot(estimates)

    return estimates, pred
# For every country: load the per-model prediction files, build several
# combined forecasts (simple mean, relative performance weights, OLS,
# regularized OLS, constrained least squares), save the combination weights
# and forecasts to CSV, and show an evaluation table.
for country in ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]:

    # Model outputs are read from <cwd>/data/tourism/<country>/model/.
    folderpath = os.getcwd() + "/data/tourism/" + str(country) + "/model/"
    # (model name, column in that model's CSV that holds its prediction)
    mappings = [("var", "pred_total"),
                ("sarimax", "train_pred"),
                ("lf", "pred_mean")]

    country_pred = pd.DataFrame()
    for mapping in mappings:
        model, column = mapping
        # File name pattern: <model>_<country>.csv
        filepath = folderpath + str(model) + "_" + str(country) + ".csv"
        # Raises KeyError if the file has no "Unnamed: 0" column.
        pred_df = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
        pred_df["date"] = pd.to_datetime(pred_df["date"])

        # Keep only date, observed total and this model's prediction,
        # renaming the prediction column to the model name.
        model_col = (pred_df[["date", "total", column]]
                     .rename({column: model}, axis=1))

        if country_pred.empty:
            country_pred = model_col
        else:
            # Left join on (date, total): rows of the first model ("var") are
            # kept. fillna(0) is applied to the whole merged frame, so a
            # model with no matching row gets a prediction of 0.
            country_pred = country_pred.merge(
                model_col, how="left", on=["date", "total"]).fillna(0)

    # Equal-weight combination: row-wise mean of the three model predictions.
    country_pred["mean_ensemble"] = (
        country_pred[["sarimax", "var", "lf"]].mean(axis=1))

    # Relative performance weights: combined forecast plus the weights frame
    # (columns "rpw_<method>") that the other weights are appended to below.
    country_pred["rpw"], weights = get_rpw(country_pred)

    # OLS of total on the three predictions; "- 1" removes the intercept.
    ols = smf.ols("total~sarimax+var+lf - 1", data=country_pred)
    ols_res = ols.fit()
    # NOTE(review): fit_regularized() is called with its defaults (no alpha
    # passed); confirm the default penalty strength is what "lasso" intends.
    ols_lasso = ols.fit_regularized()
    country_pred["ols"] = ols_res.fittedvalues
    country_pred["ols_lasso"] = ols_lasso.fittedvalues

    # Constrained least squares via get_constrained_ls.
    methods = ["sarimax", "var", "lf"]
    constrained_ls, cls_pred = get_constrained_ls(y=country_pred["total"],
                                                  X=country_pred[methods])
    country_pred["cls"] = cls_pred

    # The OLS / lasso / CLS coefficients are scalars, so each assignment
    # fills its whole column with one constant value.
    for method, cls in zip(methods, constrained_ls):
        weights["ols_"+str(method)] = ols_res.params[method]
        weights["ols_lasso_"+str(method)] = ols_lasso.params[method]
        weights["cls_"+str(method)] = cls

    weights["date"] = country_pred["date"]

    # Move the last column ("date", just added) to the front.
    cols = weights.columns.tolist()
    cols = cols[-1:] + cols[:-1]
    weights = weights[cols]

    # Save the combination weights next to the model outputs.
    weights.to_csv(folderpath+"forecast_combo_weights.csv",
                   encoding="utf-8")

    # Save the individual and combined forecasts.
    country_pred.to_csv(folderpath+"forecast_combo.csv",
                        encoding="utf-8")

    # One evaluation row per individual model and per combination.
    # calculate_evaluation is not defined in the visible code (presumably it
    # comes from one of the project star-imports -- confirm); its result must
    # be something pd.DataFrame(..., index=[col]) accepts.
    evals = pd.DataFrame()
    for col in ["sarimax", "var", "lf", "mean_ensemble", "rpw", "ols", "cls", "ols_lasso"]:
        mod_eval = pd.DataFrame(calculate_evaluation(country_pred["total"], country_pred[col]),
                                index=[col])
        evals = pd.concat([evals, mod_eval], axis=0)

    # The country name is used as the header of the columns axis.
    evals.columns.name = str(country)
    # Highlight in yellow each cell equal to the minimum of its group
    # (Styler.apply with default arguments; column-wise per pandas docs).
    evals = evals.style.apply(
        lambda x: ['background-color: yellow' if v == x.min() else '' for v in x])
    # display is not defined in the visible code; presumably the
    # IPython/Jupyter display helper -- confirm.
    display(evals)
palau MSE RMSE MAE SMAPE
sarimax 1586348.581999 1259.503308 701.765452 53.857541
var 1127006.129408 1061.605449 554.935555 38.243616
lf 522059.184389 722.536632 403.159217 41.454946
mean_ensemble 538679.103245 733.947616 415.191645 34.057742
rpw 416042.585625 645.013632 355.022438 33.879718
ols 459520.198387 677.879192 374.512959 34.509414
cls 462504.925790 680.077147 377.897454 34.589740
ols_lasso 485185.924697 696.552887 387.646868 33.639519
samoa MSE RMSE MAE SMAPE
sarimax 8076301.733757 2841.883483 1410.434675 141.993878
var 10290887.267182 3207.941282 1757.557409 141.614714
lf 2107650.783058 1451.775046 763.990597 131.108537
mean_ensemble 3794368.839523 1947.913971 1093.282593 135.822954
rpw 2242890.336299 1497.628237 746.332632 131.773854
ols 1964126.186335 1401.472863 784.266056 131.370686
cls 2009637.625508 1417.616882 763.104201 131.248177
ols_lasso 2003163.146708 1415.331462 819.357748 131.860632
tonga MSE RMSE MAE SMAPE
sarimax 677589.950307 823.158521 473.356102 81.354328
var 547847.120126 740.166954 354.330235 61.940427
lf 203622.901860 451.245944 258.523902 87.345914
mean_ensemble 194568.858596 441.099602 239.402169 78.543566
rpw 155098.313391 393.825232 216.180497 83.251701
ols 110277.230689 332.080157 198.679691 83.294941
cls 167576.289342 409.360830 225.881663 84.660888
ols_lasso 125878.110314 354.793053 197.297385 77.645737
solomon_islands MSE RMSE MAE SMAPE
sarimax 46105.079450 214.720934 147.806140 31.941491
var 65021.669854 254.993470 154.338247 17.694166
lf 48753.426027 220.801780 150.232442 26.614403
mean_ensemble 31432.381706 177.291798 126.564826 22.784182
rpw 26373.763079 162.400009 118.723610 22.831037
ols 31219.515216 176.690450 128.227540 23.614795
cls 31242.949442 176.756752 128.315830 23.459240
ols_lasso 35808.420435 189.231130 135.363359 27.983984
vanuatu MSE RMSE MAE SMAPE
sarimax 669250.974064 818.077609 372.363525 133.369731
var 1600258.286965 1265.013157 554.204122 134.493067
lf 642113.393395 801.319782 497.354972 128.023752
mean_ensemble 496754.494910 704.808126 351.729479 131.547897
rpw 333825.379320 577.776236 289.982721 130.586670
ols 394337.435274 627.962925 342.287171 131.085132
cls 413561.135366 643.087191 354.292010 131.225479
ols_lasso 419935.486543 648.024295 330.751929 131.571070