Issue with Plotly Custom Component resizing after first rendering

Hi all,

I’m trying to make a heatmap that could be queried, meaning that when user clicks on a heatmap cell, it applies a filter on the data therefore changing the heatmap itself.

The issue comes when I try to resize the height of the heatmap depending on the number of remaining lines in the plot. I have tried a bunch of differents ways to do this but none seem to work, here is the code I’m running on React side :

import React, { useEffect, useState, ReactNode } from "react";
import { Streamlit, withStreamlitConnection, ComponentProps, StreamlitComponentBase } from "streamlit-component-lib";
import Plot from "react-plotly.js"  


class CustomHeat extends StreamlitComponentBase {
  click_num= this.props.args.click_num; // Create a new variable of type number and initialize it to 0
  

  public render = (): ReactNode => {

  // let totalHeight: number = this.props.args.height;

  const { data, layout, frames, config } = JSON.parse(this.props.args.spec)
  const override_height = this.props.args["override_height"];

  Streamlit.setFrameHeight(override_height);

  return (
    <Plot
      data={data}
      layout={layout}
      frames={frames}
      config={config}
      onClick={this.handleClick}
      style={{width: '100%', height: override_height}}
      />
    )
  }

  private handleClick = (eventData: any) => {
    this.click_num += 1;
    Streamlit.setComponentValue(
      eventData.points.map((p: any) => {
        return { index: p.pointIndex, x: p.x, y: p.y, click_num: this.click_num}
      })
    )
  }
}

On python side, in init.py :

import os
import streamlit.components.v1 as components

# Define location of the packaged frontend build
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "frontend/build")


# Create a function _component_func which will call the frontend component when run
_component_func = components.declare_component(
    "custom_heatmap"
    # , url="http://localhost:3001"  # Fetch frontend component from local webserver
    , path=build_dir
)

# Define a public function for the package,
# which wraps the caller to the frontend code
def st_custom_plot(fig, click_num, override_height, key):
    component_value = _component_func(spec=fig.to_json(), click_num=click_num, override_height=override_height, key=key)
    return component_value

On python side when actually using the component :

# Define a function to remove an element from the session state
def remove_element_from_session_state():
    if st.session_state.heatmap_filter is not None:
        st.success(f"'{st.session_state.heatmap_filter}' removed from filters.")
        st.session_state.heatmap_filter = None
        # st.session_state.heatmap[0]["y"] = None
        st.session_state.has_been_clicked = False
    else:
        st.warning(f"No filter applied.")

def get_heatmap_data(df, y_col):
    # Format the table used for the heatmap
    heatmap_data = df.loc[:, ["Country", "Cat", y_col]]
    heatmap_data = heatmap_data.groupby(["Country", "Cat"], as_index=False).mean()
    heatmap_pivot = heatmap_data.pivot(index='Country', columns='Cat', values=y_col)
    
    return heatmap_pivot

def get_fig_heatmap(heatmap_pivot_filtered, num_height):
    
    layout = go.Layout(
    # paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(217,217,214,0.1)',
    title='Heatmap',
    xaxis=dict(title='Criticality'),
    yaxis=dict(title='Country'),
    height=int(num_height*50), width=650
)
    fig = go.Figure(data=go.Heatmap(
                z=heatmap_pivot_filtered.values,
                x=heatmap_pivot_filtered.columns,
                y=heatmap_pivot_filtered.index,
                type='heatmap',
                hoverongaps = False,
                colorscale= [
                            [0.0, "#FDFD96"],  # pastel yellow
                            [1.0, "#FF6961"]   # pastel red
                        ]
,
                # colorscale='gray_r',
                zmin=heatmap_pivot_filtered.values.min(),
                zmax=heatmap_pivot_filtered.values.max()), layout=layout)
    
    # fig.update_layout(height=50*num_height)
    return fig

def session_print(x):
    if x in st.session_state:
        st.write(st.session_state[x])
    else:
        st.write(f"{x} is not yet defined")

def update_heatmap_vars():
    if "heatmap" in st.session_state:               # if the heatmap has been rendered at least once
        if st.session_state.heatmap is not None:        # if the heatmap has been clicked on at least once
            if st.session_state.click_num < st.session_state.heatmap[0]["click_num"]:
                st.session_state.click_num = st.session_state.heatmap[0]["click_num"]
                st.session_state.heatmap_filter = st.session_state.heatmap[0]["y"]
            # else:
            #     st.session_state.heatmap_filter = None
        else:
            st.session_state.heatmap_filter = None
    else:
        st.session_state.heatmap_filter = None

def filter_heatmap_data(heatmap_pivot):
    if "heatmap" in st.session_state:     # if the heatmap has been rendered at least once
        if st.session_state.heatmap is not None:    #if the heatmap has been clicked on at least once
            if st.session_state.heatmap_filter is not None:
                heatmap_pivot_filtered = heatmap_pivot.loc[heatmap_pivot.index == st.session_state.heatmap_filter, :]
            else:
                heatmap_pivot_filtered = heatmap_pivot.copy()
        else:
            heatmap_pivot_filtered = heatmap_pivot.copy()
    else:
        heatmap_pivot_filtered = heatmap_pivot.copy()
    return heatmap_pivot_filtered

def filter_all_data(df):
    if "heatmap" in st.session_state:
        if st.session_state.heatmap is not None:
            if st.session_state.heatmap_filter is not None:              
                country = st.session_state.heatmap[0]["y"]
                df = df.loc[df.Country == country, :]

    return df

def display_heatmap(df, y_col):

    # Transform data to the right format
    heatmap_pivot = get_heatmap_data(df, y_col)
    
    update_heatmap_vars()
        
    # Button to remove the applied filter
    if st.button("Remove Heatmap Filter", key="remove_heatmap"):
        remove_element_from_session_state()
    
    # Filter the data 
    heatmap_pivot_filtered = filter_heatmap_data(heatmap_pivot)
    
    num_height_ticks = heatmap_pivot_filtered.shape[0]
    if num_height_ticks != 0:
        # Define the heatmap fig
        fig = get_fig_heatmap(heatmap_pivot_filtered, num_height=num_height_ticks)

        # Send and receive plot from React side
        if "heatmap" in st.session_state:
            del st.session_state.heatmap
        st_custom_plot(fig, click_num=st.session_state.click_num, 
                            override_height=int(num_height_ticks*50) , key="heatmap")
    else:
        st.warning("Current filter selection leads to no data")
        
    # Apply filters
    df = filter_all_data(df)
            
    return df

This code may actually be useful for anyone trying to have a bi-directional plotly heatmap (it was inspired by many other similar use case).

Here is what the heatmap looks like at the end before filtering (many blanks for privacy purposes):

And after filtering without all the code related to num_height and override_height.:

If I now try to resize the fig in the same manner as I’m doing in the script above, the plot is fine at first rendering, however when I try to filter the lines and therefore change the height of the fig, its starts bugging. The plot keeps switching between new height and the old height at a high rate but it can’t seem to decide which one to choose.

Does anyone have any idea why ? Or just in which direction I should research ?

Sorry if the code isn’t perfect, it’s my first streamlit app and I still have a lot to learn !

Hi @Thomas_LE_ROUX, welcome to the community! :raised_hands:

Streamlit supports Plotly out of the box – May I ask why you would need to use Plotly via a custom component?

Thanks,
Charly

Hi Charly,

Thanks for the quick reply!

I need a bi-directional widget, meaning that when a user clicks on the heat map, the streamlit app receives the click coordinates through the componentValue and applies filters on the data before rendering the heat map. This part works like a charm.

The issue comes from the resizing of the plot depending on the number of ticks on the y-axis.

Thanks,
Thomas

Thanks for your feedback, @Thomas_LE_ROUX!

@blackary, I was wondering whether you knew a workaround for this?

This topic was automatically closed 180 days after the last reply. New replies are no longer allowed.