Combination vs. Individual Models

Combination vs. Individual Models

import os
# NOTE(review): moves two directories up so the "data/tourism/..." paths built
# from os.getcwd() below resolve; presumably the notebook lives two levels
# below the repository root — confirm before running from elsewhere.
os.chdir("../../")
import numpy as np
import pandas as pd


from bokeh.palettes import Spectral10
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool, CustomJS, Label
from bokeh.io import output_notebook
from bokeh.models.widgets import Select
from bokeh.layouts import column

# Render Bokeh plots inline in the notebook output.
output_notebook()
Loading BokehJS ...
Hide code cell source
# Load every country's combined-forecast table. The raw frames are kept
# alongside the ColumnDataSources so the model list can be read from a
# specific country instead of whichever frame the loop touched last.
countries = ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]
frames = {}
data = {}
for country in countries:
    filepath = os.path.join(os.getcwd(), "data", "tourism", country,
                            "model", "forecast_combo.csv")
    df = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
    df["date"] = pd.to_datetime(df["date"])
    frames[country] = df
    data[country] = ColumnDataSource(df)

# Create the dropdown menu
country_select = Select(title="Select Country", width=200,
                        options=countries, value=countries[0])

# Create the initial plot
current_country = country_select.value
current_data = data[current_country]

# The renderers draw from this dedicated source, not from one of the
# per-country sources in ``data``. The JS callback below replaces this
# source's ``data``. If the renderers used data[<first country>] directly,
# the callback would overwrite that country's columns, and switching back
# to it would keep showing whichever country was selected last.
plot_source = ColumnDataSource(frames[current_country])
p = figure(width=800, height=500, x_axis_type="datetime")

# Every non-date column is one model's forecast. Take the list from the
# country shown first, not from the loop's leftover ``df``.
# NOTE(review): the dropdown swap assumes every country has these same columns.
models = [col for col in frames[current_country].columns if col != "date"]

# Cycle the palette so that models beyond len(Spectral10) are still drawn
# (zip() would silently drop them).
for i, name in enumerate(models):
    color = Spectral10[i % len(Spectral10)]
    p.line(x="date", y=name, line_width=1.5, source=plot_source,
           color=color, alpha=0.5, name=name,
           muted_color=color, muted_alpha=0.2, legend_label=name)

# "$name" is the hovered renderer's name (the model); "$y" is the cursor's
# y position in data space. "@total" expects a "total" column in the data.
hover = HoverTool(tooltips=[("date", "@date{%F}"),
                            ("method", "$name"),
                            ('value', '$y{f}'),
                            ('total', "@total")],
                  formatters={"@date": "datetime"})

p.add_tools(hover)

p.legend.location = "top_left"
p.legend.title_text_font_style = "bold"
p.legend.title_text_font_size = "20px"
p.legend.click_policy = "mute"

js_callback = CustomJS(args=dict(data=data, source=plot_source,
                                 country_select=country_select), code="""
    // Point the shared plotting source at the selected country's columns.
    // Only plot_source is reassigned; the per-country sources are left
    // untouched, so returning to a country shows its own data again.
    const currentData = data[country_select.value];
    source.data = currentData.data;
""")

# Add the JavaScript callback to the dropdown menu
country_select.js_on_change('value', js_callback)

# Stack the dropdown above the plot
layout = column(country_select, p)

# Show the layout with the dropdown menu and the plot
show(layout)
countries = ["palau", "samoa", "tonga", "solomon_islands", "vanuatu"]

# Gather each country's relative-performance-weight ("rpw") columns, prefixed
# with the country name (e.g. "palau_<col>"). The frames are joined side by
# side in one concat instead of re-concatenating inside the loop.
country_weights = []
for country in countries:
    filepath = os.path.join(os.getcwd(), "data", "tourism", country,
                            "model", "forecast_combo_weights.csv")
    country_weight = pd.read_csv(filepath).drop("Unnamed: 0", axis=1)
    rpw_cols = [col for col in country_weight.columns if "rpw" in col]
    # Only "rpw" columns survive the filter, so every column gets the prefix.
    country_weights.append(country_weight[rpw_cols].add_prefix(f"{country}_"))

weights = pd.concat(country_weights, axis=1)

# NOTE(review): any date column in the weights CSVs is dropped by the "rpw"
# filter above. Dates are rebuilt as monthly rows starting 2019-01; confirm
# this matches how forecast_combo_weights.csv was produced.
weights["date"] = pd.date_range(start="2019-01-01", periods=len(weights), freq="MS")
Hide code cell source
# One shared data source, read by every per-country weights figure built below.
weights_src = ColumnDataSource(weights)

def quick_generate(country):
    """Build a line chart of one country's relative performance weights.

    Draws one line per column of the module-level ``weights`` frame whose name
    starts with ``"<country>_"``, reading the values from ``weights_src``.

    Parameters
    ----------
    country : str
        Country key used as the column-name prefix, e.g. ``"palau"``.

    Returns
    -------
    The Bokeh figure, with a hover tool and a top-right legend.
    """
    s = figure(width=800, height=300, x_axis_type="datetime",
               title=str(country).upper() + "'s Relative Performance Weights")

    # Match on "<country>_" (the prefix the loading cell adds), not the bare
    # name, so a country whose name prefixes another's cannot grab its columns.
    prefix = f"{country}_"
    con_mod_vars = [col for col in weights.columns if col.startswith(prefix)]

    # Cycle the palette so no weight series is silently dropped past 10 lines.
    for i, var in enumerate(con_mod_vars):
        color = Spectral10[i % len(Spectral10)]
        s.line(x="date", y=var, source=weights_src, alpha=0.7, color=color,
               name=var, legend_label=var)

    # "$y" is the cursor's y position in data space.
    hover = HoverTool(tooltips=[("date", "@date{%F}"),
                                ("method", "$name"),
                                ("value", "$y")],
                      formatters={"@date": "datetime"})

    s.add_tools(hover)
    s.legend.location = "top_right"

    return s

# One weights figure per country, in this display order, stacked vertically.
s1, s2, s3, s4, s5 = [
    quick_generate(name)
    for name in ("palau", "tonga", "samoa", "solomon_islands", "vanuatu")
]

show(column(s1, s2, s3, s4, s5))