# Countries iterated over by the per-country modelling loop.
countries = [
    "solomon_islands",
    "palau",
    "samoa",
    "vanuatu",
    "tonga",
]

# Candidate ARIMA orders: p and q each take 0, 1, 2; d takes 0, 1.
p = range(3)
d = range(2)
q = range(3)

# Every (p, d, q) combination as a list of tuples (3 * 2 * 3 = 18 entries).
pdq = [*itertools.product(p, d, q)]
# For each country: load the merged data, decompose "total" with STL, let
# auto_arima choose an ARIMA order on the deseasonalised series, print the
# in-sample evaluation, then forecast 12 steps with STLForecast and plot the
# history and forecast together.
for country in countries:
    # The pipeline is used here only to load and merge the data. The
    # exogenous columns are handed to it, but this loop itself reads just
    # the "total" and "date" columns of `mod.data`.
    mod = SARIMAXPipeline(
        country=country,
        y_var="total",
        data=None,
        exog_var=["covid", "stringency_index", str(country) + "_travel"],
        transform_method="scaledlogit",
        training_ratio=1,
        verbose=False,
    )
    mod.read_and_merge()
    total = mod.data["total"]

    # STL with a 12-observation seasonal period; the ARIMA order is chosen on
    # trend + residual, i.e. the series with the seasonal component removed.
    stl = STL(total, period=12)
    stl_res = stl.fit()
    mod_val = stl_res.resid + stl_res.trend
    auto = pm.auto_arima(mod_val, trend="t")

    # In-sample check: add the seasonal component back onto the ARIMA fitted
    # values and compare against the observed series.
    print(calculate_evaluation(total, stl_res.seasonal + auto.fittedvalues()))

    # Refit as an STL + ARIMA forecaster using the order selected above.
    stlf = STLForecast(
        total,
        ARIMA,
        period=12,
        model_kwargs=dict(order=auto.order, trend="t"),
    )
    stlf_res = stlf.fit()

    # Name the forecast explicitly. Passing a Series that already carries a
    # different name to pd.DataFrame(..., columns=[...]) would produce an
    # all-NaN column; this form yields the same frame for an unnamed Series
    # or ndarray and stays correct if the result is named.
    temp = pd.Series(stlf_res.forecast(12), name="forecast").to_frame()
    temp = pd.concat([mod.data[["total"]], temp], axis=1)

    # Month-start date axis spanning every row of the combined frame.
    # `.iloc[0]` takes the first date by position, so it does not depend on
    # the frame's index containing the label 0.
    # NOTE(review): assumes the observed data are contiguous monthly values
    # starting at that first date - confirm against read_and_merge().
    temp["date"] = pd.date_range(
        start=mod.data["date"].iloc[0], periods=len(temp), freq="MS"
    )
    temp_viz = temp.melt(id_vars="date")

    # Draw on the axes created here rather than on implicit pyplot state.
    fig, ax = plt.subplots(figsize=(12, 9))
    sns.lineplot(data=temp_viz, x="date", y="value", hue="variable", ax=ax)
    ax.set_ylabel("Number of Tourists")
    ax.set_title(f"{str(country).upper()}'s STL + ARIMA {auto.order}")
    plt.show()