How to specify hash_funcs of @st.cache for all subtypes of a specific class

I want to override the hashing function used by @st.cache for all subclasses of a specific base class.

Although st.cache now lets us pass custom hash_funcs, they seem to apply only to the exact types used as keys in that dictionary, not to their subtypes.

Is there any way to do that?
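My guess is that st.cache looks up the hash function by the object's exact type (something like hash_funcs.get(type(obj))) rather than walking the class hierarchy. I have not confirmed this against Streamlit's source. The plain-Python sketch below only illustrates the difference between the two lookup strategies; lookup_exact and lookup_mro are hypothetical helpers, not Streamlit code.

```python
# Hypothetical illustration of two lookup strategies; neither function is Streamlit code.


class Base:
    id = 42


class Sub(Base):
    pass


def lookup_exact(hash_funcs, obj):
    # Exact-type lookup: only matches the class that was used as the dict key.
    return hash_funcs.get(type(obj))


def lookup_mro(hash_funcs, obj):
    # MRO walk: check the object's own class first, then each base class in order.
    for cls in type(obj).__mro__:
        if cls in hash_funcs:
            return hash_funcs[cls]
    return None


hash_funcs = {Base: lambda o: o.id}

print(lookup_exact(hash_funcs, Base()) is not None)  # True
print(lookup_exact(hash_funcs, Sub()) is not None)   # False
print(lookup_mro(hash_funcs, Sub())(Sub()))          # 42
```

With an MRO walk, registering Base alone would be enough to cover Sub, which is the behavior I am looking for.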


For example, in the sample code below, some_func (decorated with @st.cache) correctly caches its result for an instance of Base, but not for an instance of Sub.

import streamlit as st
import random


class Base:
    id = 42
    foo = "foo"


class Sub(Base):
    pass


@st.cache(hash_funcs={Base: lambda o: o.id})
def some_func(obj):
    return random.randint(0, 1000)


base = Base()
first = some_func(base)
base.foo = "bar"  # Only the .foo attr is changed; the .id attr, which is used for hashing, is not.
second = some_func(base)
cached = first == second
print("Is an instance of Base class cached: ", cached)

# Do the same thing on Sub class
sub = Sub()
first = some_func(sub)
sub.foo = "bar"
second = some_func(sub)
cached = first == second
print("Is an instance of Sub class cached: ", cached)

Running this sample prints:

Is an instance of Base class cached:  True
Is an instance of Sub class cached:  False
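Another workaround I considered is registering the same hash function for Base and for every subclass that exists when the decorator runs. The helpers all_subclasses and hash_funcs_for below are hypothetical (not part of Streamlit), and SubSub is an extra class added only to show that indirect subclasses are collected too. The approach relies on cls.__subclasses__(), so it only sees subclasses that are already defined at that moment.

```python
# Hypothetical helpers; not part of Streamlit's API.
from typing import Callable, Dict, Set


class Base:
    id = 42
    foo = "foo"


class Sub(Base):
    pass


class SubSub(Sub):  # extra class, only to show indirect subclasses are found
    pass


def all_subclasses(cls: type) -> Set[type]:
    # Recursively collect every direct and indirect subclass defined so far.
    found: Set[type] = set()
    for sub in cls.__subclasses__():
        found.add(sub)
        found |= all_subclasses(sub)
    return found


def hash_funcs_for(cls: type, func: Callable) -> Dict[type, Callable]:
    # Map the base class and each of its known subclasses to the same hash function.
    return {c: func for c in {cls} | all_subclasses(cls)}


hash_funcs = hash_funcs_for(Base, lambda o: o.id)
print(sorted(c.__name__ for c in hash_funcs))  # ['Base', 'Sub', 'SubSub']

# Intended use with Streamlit (not executed here):
# @st.cache(hash_funcs=hash_funcs_for(Base, lambda o: o.id))
# def some_func(obj): ...
```

A subclass created later (for example, in a module imported after some_func is decorated) would still fall back to the default hashing, so this is fragile too.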

Currently, I work around this problem by introducing a wrapper class, ObjectHashWrapper, as in the example below.

# Continues the first example (reuses st, random, Base and Sub).
from typing import Generic, TypeVar


HashedObjT = TypeVar("HashedObjT")


class ObjectHashWrapper(Generic[HashedObjT]):
    def __init__(self, obj: HashedObjT, hash_value) -> None:
        self.obj = obj
        self.hash = hash_value  # precomputed value that st.cache should use as the hash


@st.cache(hash_funcs={ObjectHashWrapper: lambda o: o.hash})
def _inner_func(obj: ObjectHashWrapper[Base]):
    return random.randint(0, 1000)


def workaround_func(obj):
    hash_wrapper = ObjectHashWrapper(obj, obj.id)
    return _inner_func(hash_wrapper)


base = Base()
first = workaround_func(base)
base.foo = "bar"  # Only the .foo attr is changed; the .id attr, which is used for hashing, is not.
second = workaround_func(base)
cached = first == second
print("Is an instance of Base class cached: ", cached)

# Do the same thing on Sub class
sub = Sub()
first = workaround_func(sub)
sub.foo = "bar"
second = workaround_func(sub)
cached = first == second
print("Is an instance of Sub class cached: ", cached)

As the result below shows, this hack works, but it does not feel like the idiomatic way to do it.

Is an instance of Base class cached:  True
Is an instance of Sub class cached:  True
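For comparison, the same wrapper idea can be expressed in plain Python by giving the wrapper __hash__ and __eq__ and caching with functools.lru_cache. This swaps in a different caching technique and is not st.cache: HashKey and cache_by are hypothetical names, and passing st.cache(hash_funcs={HashKey: lambda w: w.key}) as the cache argument is an untested assumption.

```python
# Swapped-in technique: functools.lru_cache instead of st.cache, for illustration only.
import functools
import random
from typing import Any, Callable, Hashable, Optional


class Base:
    id = 42
    foo = "foo"


class Sub(Base):
    pass


class HashKey:
    """Hypothetical wrapper whose hash and equality depend only on `key`."""

    def __init__(self, obj: Any, key: Hashable) -> None:
        self.obj = obj
        self.key = key

    def __hash__(self) -> int:
        return hash(self.key)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashKey) and self.key == other.key


def cache_by(key_fn: Callable[[Any], Hashable], cache: Optional[Callable] = None):
    """Hypothetical decorator factory: caches a one-argument function by key_fn(obj).

    `cache` defaults to functools.lru_cache(maxsize=None); plugging in
    st.cache(hash_funcs={HashKey: lambda w: w.key}) is an untested assumption.
    """
    if cache is None:
        cache = functools.lru_cache(maxsize=None)

    def decorator(func):
        @cache
        def inner(wrapped: HashKey):
            # Only wrapped.key takes part in the lookup; the real object is passed through.
            return func(wrapped.obj)

        @functools.wraps(func)
        def outer(obj):
            return inner(HashKey(obj, key_fn(obj)))

        return outer

    return decorator


@cache_by(lambda o: o.id)
def some_func(obj):
    return random.randint(0, 1000)


sub = Sub()
first = some_func(sub)
sub.foo = "bar"  # .id is unchanged, so the cached value is reused
print("Is an instance of Sub class cached: ", first == some_func(sub))
```

Because only the key takes part in the lookup, a Base instance and a Sub instance with the same id share one cache entry in this sketch.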