# Forecast Combination
import os
os.chdir("../../")
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from scripts.python.tsa.mtsmodel import *
from scripts.python.tsa.ts_eval import *
import seaborn as sns
sns.set_style("whitegrid")
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
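As summarized by Timmermann (2004), Stock and Watson (1998) propose relative performance weights. Let $e_{\tau,i} = y_\tau - \hat{y}_{\tau,i}$ be the forecast error of model $i$ at time $\tau$, and let
$$m_{t,i} = \frac{1}{t}\sum_{\tau=1}^{t} e_{\tau,i}^{2}$$
be its mean squared error over the first $t$ observations. The weight on model $i$ at time $t$ is
$$\omega_{t,i} = \frac{m_{t,i}^{-1}}{\sum_{j=1}^{N} m_{t,j}^{-1}},$$
and the combined forecast is $\hat{y}^{\mathrm{rpw}}_{t} = \sum_{i=1}^{N} \omega_{t,i}\,\hat{y}_{t,i}$.
Below are the functions that calculate the relative performance weights, where $N$ is the number of models being combined (here `sarimax`, `var`, and `lf`). The implementation uses an expanding window that weights every past error equally. The factor $1/t$ cancels in $\omega_{t,i}$.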
Show code cell source
def calculate_mse(predictions_df: pd.DataFrame, method: str) -> pd.Series:
    """Return the expanding-window mean squared error of one model.

    The value at the k-th row (0-based position) is the mean of
    ``(total - prediction) ** 2`` over rows ``0..k``.

    Parameters
    ----------
    predictions_df : DataFrame with a ``"total"`` column (actual values)
        and a prediction column named ``method``.
    method : name of the prediction column to score.

    Returns
    -------
    Series of running MSE values aligned with ``predictions_df.index``.
    """
    squared_errors = np.square(predictions_df["total"] - predictions_df[method])
    # Count observations by *position*, not by index label: dividing by
    # ``index + 1`` is only correct for a 0-based RangeIndex and fails for
    # offset/filtered indexes or a DatetimeIndex.
    counts = np.arange(1, len(predictions_df) + 1)
    return squared_errors.cumsum() / counts
def calculate_rpw(predictions_df: pd.DataFrame, methods: list) -> pd.Series:
    """Return Stock-Watson relative performance weights for each model.

    For every row t and model i the weight is

        w[t, i] = (1 / mse[t, i]) / sum_j (1 / mse[t, j])

    where ``mse`` is the running mean squared error from ``calculate_mse``.
    The weights in a row therefore sum to one.

    Parameters
    ----------
    predictions_df : DataFrame with a ``"total"`` column and one prediction
        column per entry of ``methods``.
    methods : names of the prediction columns to weight.

    Returns
    -------
    Series indexed by method name. Each value is itself a Series of per-row
    weights aligned with ``predictions_df.index``.

    Notes
    -----
    A running MSE of exactly zero has an infinite inverse. This makes that
    row's weights NaN (inf / inf) or 0 (finite / inf).
    """
    inverse_mse = {}
    for name in methods:
        inverse_mse[name] = 1 / calculate_mse(predictions_df, name)

    normaliser = sum(inverse_mse[name] for name in methods)
    return pd.Series({name: inverse_mse[name] / normaliser for name in methods})
def get_rpw(pred_df: pd.DataFrame,
            methods: "list[str] | None" = None,
            lag: int = 0) -> "tuple[pd.Series, pd.DataFrame]":
    """Combine model predictions using relative performance weights (RPW).

    Parameters
    ----------
    pred_df : DataFrame with a ``"total"`` column and one prediction column
        per entry of ``methods``. It is copied and never modified.
    methods : prediction columns to combine. Defaults to
        ``["sarimax", "var", "lf"]``. ``None`` replaces the former mutable
        list default; passing an explicit list behaves exactly as before.
    lag : number of rows to shift the weights down before applying them.
        With the default ``0`` the weight at row t is computed from squared
        errors up to *and including* row t. It therefore already uses the
        actual value being forecast (look-ahead). ``lag=1`` uses only errors
        from earlier rows. The first ``lag`` rows then get equal weights
        ``1 / len(methods)``.

    Returns
    -------
    rpw_pred : Series with the per-row sum of ``prediction * weight`` over
        ``methods``. NaN terms are skipped, as in ``DataFrame.sum``.
    rpw : DataFrame of weights with one ``"rpw_<method>"`` column per method.

    Raises
    ------
    ValueError
        If ``lag`` is negative.
    """
    if methods is None:
        methods = ["sarimax", "var", "lf"]
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")

    predictions_df = pred_df.copy()
    # calculate_rpw returns a Series whose values are per-row weight Series.
    # Spread it into a DataFrame with one column per method.
    rpw = pd.DataFrame(calculate_rpw(predictions_df, methods).to_dict())
    if lag > 0 and len(methods) > 0:
        rpw = rpw.shift(lag)
        # No past errors exist yet for the first rows: fall back to equal weights.
        rpw.iloc[:lag] = 1.0 / len(methods)

    combo_cols = []
    for method in methods:
        weight_col = f"{method}_weight"
        predictions_df[weight_col] = predictions_df[method] * rpw[method]
        combo_cols.append(weight_col)
    rpw_pred = predictions_df[combo_cols].sum(axis=1)

    rpw.columns = [f"rpw_{col}" for col in rpw.columns]
    return rpw_pred, rpw
def get_constrained_ls(y: "pd.Series | pd.DataFrame",
                       X: pd.DataFrame) -> "tuple[np.ndarray, np.ndarray]":
    """Fit convex combination weights by constrained least squares.

    Solves ``min_w ||X @ w - y||_2`` subject to ``w >= 0`` and
    ``sum(w) == 1``. The solver is SLSQP, started from the non-negative
    least-squares solution.

    Parameters
    ----------
    y : actual values, as a Series, 1-D array or single-column DataFrame.
    X : one column per model's predictions, with rows aligned with ``y``.

    Returns
    -------
    estimates : ndarray of shape ``(n_models,)`` holding the fitted weights.
    pred : ndarray of shape ``(n_obs,)``, equal to ``X @ estimates``.

    Warns
    -----
    RuntimeWarning
        If SLSQP reports that it did not converge. The last iterate is
        still returned.
    """
    import warnings

    from scipy.optimize import minimize, nnls

    A = np.asarray(X, dtype=float)
    # nnls needs a 1-D target; ravel also accepts a single-column DataFrame.
    b = np.asarray(y, dtype=float).ravel()
    n_models = A.shape[1]

    # NNLS respects the lower bounds but ignores sum-to-one. It is only used
    # as the starting point; SLSQP enforces the equality constraint.
    x0, _ = nnls(A, b)

    def objective(weights):
        return np.linalg.norm(A.dot(weights) - b)

    result = minimize(
        objective, x0, method="SLSQP",
        # One (lower, upper) pair per weight. A single shared pair is only
        # broadcast by newer SciPy releases; older ones raise on the length
        # mismatch.
        bounds=[(0, np.inf)] * n_models,
        constraints={"type": "eq", "fun": lambda w: np.sum(w) - 1},
    )
    if not result.success:
        warnings.warn(f"SLSQP did not converge: {result.message}",
                      RuntimeWarning)

    estimates = result.x
    pred = A.dot(estimates)
    return estimates, pred
# Build, save and evaluate forecast combinations for each country.
# For every country: load each model's prediction CSV, merge them on
# (date, total), fit several combination schemes, write the predictions and
# weights into the country's model folder, and display an error table.
for country in ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]:
    # Per-country model outputs live under data/tourism/<country>/model/,
    # relative to the current working directory.
    folderpath = os.getcwd() + "/data/tourism/" + str(country) + "/model/"
    # (model name, name of that model's prediction column in its CSV)
    mappings = [("var", "pred_total"),
                ("sarimax", "train_pred"),
                ("lf", "pred_mean")]
    country_pred = pd.DataFrame()
    for mapping in mappings:
        model, column = mapping
        filepath = folderpath + str(model) + "_" + str(country) + ".csv"
        # "Unnamed: 0" is presumably a pandas index written by to_csv; dropped.
        pred_df = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
        pred_df["date"] = pd.to_datetime(pred_df["date"])
        # Keep date, actual ("total") and this model's prediction, renamed to
        # the model name so each model becomes one column.
        model_col = (pred_df[["date", "total", column]]
                     .rename({column: model}, axis=1))
        if country_pred.empty:
            country_pred = model_col
        else:
            # The left merge keeps only the (date, total) rows of the first
            # file loaded ("var"); rows found only in later files are dropped.
            # NOTE(review): fillna(0) turns a missing prediction into a
            # forecast of 0, which then enters every combination below —
            # confirm this is intended rather than dropping those rows.
            country_pred = country_pred.merge(
                model_col, how="left", on=["date", "total"]).fillna(0)
    # Mean: equal-weight average of the three model predictions.
    country_pred["mean_ensemble"] = (
        country_pred[["sarimax", "var", "lf"]].mean(axis=1))
    # Relative Performance Weights: get_rpw returns (combined prediction,
    # table of per-row weights). Each row's weight is built from squared
    # errors up to and including that row.
    country_pred["rpw"], weights = get_rpw(country_pred)
    # OLS (regularized) without intercept ("- 1" drops the constant term).
    # NOTE(review): fit_regularized() is called with statsmodels' default
    # penalty, documented by statsmodels as alpha=0.0 (no penalty), so
    # "ols_lasso" is presumably not actually penalized — confirm the alpha.
    ols = smf.ols("total~sarimax+var+lf - 1", data=country_pred)
    ols_res = ols.fit()
    ols_lasso = ols.fit_regularized()
    country_pred["ols"] = ols_res.fittedvalues
    country_pred["ols_lasso"] = ols_lasso.fittedvalues
    # Constrained least squares: non-negative weights that sum to one.
    methods = ["sarimax", "var", "lf"]
    constrained_ls, cls_pred = get_constrained_ls(y=country_pred["total"],
                                                  X=country_pred[methods])
    country_pred["cls"] = cls_pred
    # Store the static OLS / lasso / CLS coefficients as constant columns
    # alongside the time-varying RPW weights.
    for method, cls in zip(methods, constrained_ls):
        weights["ols_"+str(method)] = ols_res.params[method]
        weights["ols_lasso_"+str(method)] = ols_lasso.params[method]
        weights["cls_"+str(method)] = cls
    weights["date"] = country_pred["date"]
    # Sort columns: move "date" (appended last) to the front.
    cols = weights.columns.tolist()
    cols = cols[-1:] + cols[:-1]
    weights = weights[cols]
    # Save Combination weights
    weights.to_csv(folderpath+"forecast_combo_weights.csv",
                   encoding="utf-8")
    # Save Forecast Prediction
    country_pred.to_csv(folderpath+"forecast_combo.csv",
                        encoding="utf-8")
    # Score every single model and every combination against the actuals.
    # NOTE(review): OLS, lasso and CLS are fitted on the same rows they are
    # scored on, so this is an in-sample comparison.
    evals = pd.DataFrame()
    for col in ["sarimax", "var", "lf", "mean_ensemble", "rpw", "ols", "cls", "ols_lasso"]:
        # calculate_evaluation is not defined in this section; presumably it
        # comes from the star import of scripts.python.tsa.ts_eval.
        mod_eval = pd.DataFrame(calculate_evaluation(country_pred["total"], country_pred[col]),
                                index=[col])
        evals = pd.concat([evals, mod_eval], axis=0)
    evals.columns.name = str(country)
    # Highlight the smallest value in each column (Styler.apply is column-wise by default).
    evals = evals.style.apply(
        lambda x: ['background-color: yellow' if v == x.min() else '' for v in x])
    # display() is not defined here; presumably the IPython/Jupyter builtin.
    display(evals)
palau | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 1586348.581999 | 1259.503308 | 701.765452 | 53.857541 |
var | 1127006.129408 | 1061.605449 | 554.935555 | 38.243616 |
lf | 522059.184389 | 722.536632 | 403.159217 | 41.454946 |
mean_ensemble | 538679.103245 | 733.947616 | 415.191645 | 34.057742 |
rpw | 416042.585625 | 645.013632 | 355.022438 | 33.879718 |
ols | 459520.198387 | 677.879192 | 374.512959 | 34.509414 |
cls | 462504.925790 | 680.077147 | 377.897454 | 34.589740 |
ols_lasso | 485185.924697 | 696.552887 | 387.646868 | 33.639519 |
samoa | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 8076301.733757 | 2841.883483 | 1410.434675 | 141.993878 |
var | 10290887.267182 | 3207.941282 | 1757.557409 | 141.614714 |
lf | 2107650.783058 | 1451.775046 | 763.990597 | 131.108537 |
mean_ensemble | 3794368.839523 | 1947.913971 | 1093.282593 | 135.822954 |
rpw | 2242890.336299 | 1497.628237 | 746.332632 | 131.773854 |
ols | 1964126.186335 | 1401.472863 | 784.266056 | 131.370686 |
cls | 2009637.625508 | 1417.616882 | 763.104201 | 131.248177 |
ols_lasso | 2003163.146708 | 1415.331462 | 819.357748 | 131.860632 |
tonga | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 677589.950307 | 823.158521 | 473.356102 | 81.354328 |
var | 547847.120126 | 740.166954 | 354.330235 | 61.940427 |
lf | 203622.901860 | 451.245944 | 258.523902 | 87.345914 |
mean_ensemble | 194568.858596 | 441.099602 | 239.402169 | 78.543566 |
rpw | 155098.313391 | 393.825232 | 216.180497 | 83.251701 |
ols | 110277.230689 | 332.080157 | 198.679691 | 83.294941 |
cls | 167576.289342 | 409.360830 | 225.881663 | 84.660888 |
ols_lasso | 125878.110314 | 354.793053 | 197.297385 | 77.645737 |
solomon_islands | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 46105.079450 | 214.720934 | 147.806140 | 31.941491 |
var | 65021.669854 | 254.993470 | 154.338247 | 17.694166 |
lf | 48753.426027 | 220.801780 | 150.232442 | 26.614403 |
mean_ensemble | 31432.381706 | 177.291798 | 126.564826 | 22.784182 |
rpw | 26373.763079 | 162.400009 | 118.723610 | 22.831037 |
ols | 31219.515216 | 176.690450 | 128.227540 | 23.614795 |
cls | 31242.949442 | 176.756752 | 128.315830 | 23.459240 |
ols_lasso | 35808.420435 | 189.231130 | 135.363359 | 27.983984 |
vanuatu | MSE | RMSE | MAE | SMAPE |
---|---|---|---|---|
sarimax | 669250.974064 | 818.077609 | 372.363525 | 133.369731 |
var | 1600258.286965 | 1265.013157 | 554.204122 | 134.493067 |
lf | 642113.393395 | 801.319782 | 497.354972 | 128.023752 |
mean_ensemble | 496754.494910 | 704.808126 | 351.729479 | 131.547897 |
rpw | 333825.379320 | 577.776236 | 289.982721 | 130.586670 |
ols | 394337.435274 | 627.962925 | 342.287171 | 131.085132 |
cls | 413561.135366 | 643.087191 | 354.292010 | 131.225479 |
ols_lasso | 419935.486543 | 648.024295 | 330.751929 | 131.571070 |