on_select for st.dataframe doesn't work with MultiIndex columns

I have a DataFrame with MultiIndex columns, and I am trying to use the `on_select` parameter of `st.dataframe`. When I select a column, `st.session_state.get("styled_expiry_df")` contains only the labels from the second level of the column index. There seems to be no way to tell which first-level label the selected column belongs to, so I can't distinguish between a user selecting column ('a', 1) and column ('b', 1).

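```python
import pandas as pd
import streamlit as st

# DataFrame whose columns are a two-level MultiIndex: ('a', 1), ('a', 2), ('b', 1)
df = pd.DataFrame(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    columns=pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]),
)

calculated_height = 150  # placeholder value; the real height is computed elsewhere in my app

def callback():
    # The selection state is stored in session state under the widget's key
    st.write(st.session_state.get("styled_expiry_df"))

st.dataframe(
    df,
    on_select=callback,
    height=calculated_height,
    key="styled_expiry_df",
    selection_mode=["multi-row", "multi-column"],
)
```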

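After selecting two columns, the callback prints the following. Only the second-level labels appear, and the first level ('a' or 'b') is lost:

```
{"selection":{"rows":[],"columns":["1","2"]}}
```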
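One possible workaround, sketched here without confirming how Streamlit names MultiIndex columns internally, is to flatten the MultiIndex into single string labels before passing the frame to `st.dataframe`, keep a dictionary from each flat label back to its original tuple, and translate the selected names inside the callback. The helper names `flatten_columns` and `selected_tuples` and the `" | "` separator are hypothetical choices of mine. The code also assumes the selection state supports dict-style access (`state["selection"]["columns"]`), as the printed output suggests.

```python
import pandas as pd


def flatten_columns(df: pd.DataFrame, sep: str = " | ") -> tuple[pd.DataFrame, dict[str, tuple]]:
    """Return a copy of df whose MultiIndex columns are joined into single strings,
    plus a mapping from each flat label back to the original column tuple."""
    labels = []
    mapping = {}
    for col in df.columns:
        # With MultiIndex columns, each col is a tuple such as ('a', 1)
        label = sep.join(str(level) for level in col)
        labels.append(label)
        mapping[label] = col
    # Refuse to continue if two different tuples would collapse to the same label
    if len(set(labels)) != len(labels):
        raise ValueError("separator produces duplicate flat column labels")
    flat = df.copy()
    flat.columns = labels
    return flat, mapping


def selected_tuples(state: dict, mapping: dict[str, tuple]) -> list[tuple]:
    """Translate the selected flat column names back to the original MultiIndex tuples.
    Assumes state has the shape {"selection": {"rows": [...], "columns": [...]}}."""
    return [mapping[name] for name in state["selection"]["columns"]]
```

In the app, `flat` would be passed to `st.dataframe(...)` instead of `df`, and the callback would call `selected_tuples(st.session_state.get("styled_expiry_df"), mapping)`. The trade-off is that the two-level header is displayed as a single row of joined labels such as `a | 1`.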