diff --git a/backtesting/_plotting.py b/backtesting/_plotting.py index 844318aa..272bfc54 100644 --- a/backtesting/_plotting.py +++ b/backtesting/_plotting.py @@ -27,6 +27,7 @@ DatetimeTickFormatter, WheelZoomTool, LinearColorMapper, + ColorBar ) try: from bokeh.models import CustomJSTickFormatter @@ -682,10 +683,7 @@ def plot_heatmaps(heatmap: pd.Series, agg: Union[Callable, str], ncols: int, dfs = [heatmap.groupby(list(dims)).agg(agg).to_frame(name='_Value') for dims in param_combinations] plots = [] - cmap = LinearColorMapper(palette='Viridis256', - low=min(df.min().min() for df in dfs), - high=max(df.max().max() for df in dfs), - nan_color='white') + for df in dfs: name1, name2 = df.index.names level1 = df.index.levels[0].astype(str).tolist() @@ -694,6 +692,11 @@ def plot_heatmaps(heatmap: pd.Series, agg: Union[Callable, str], ncols: int, df[name1] = df[name1].astype('str') df[name2] = df[name2].astype('str') + cmap = LinearColorMapper(palette='Viridis256', + low=df['_Value'].min(), + high=df['_Value'].max(), + nan_color='white') + fig = _figure(x_range=level1, y_range=level2, x_axis_label=name1, @@ -717,6 +720,8 @@ def plot_heatmaps(heatmap: pd.Series, agg: Union[Callable, str], ncols: int, line_color=None, fill_color=dict(field='_Value', transform=cmap)) + color_bar = ColorBar(color_mapper=cmap, location=(0, 0)) + fig.add_layout(color_bar, 'right') plots.append(fig) fig = gridplot(