Caching PyTorch variables?

Is there a recommended way to cache variables from PyTorch? I’m currently getting the error “UnhashableType: Cannot hash object of type torch._C._TensorBase”.

I’m trying to cache the initialization of a StyleGAN2 class data structure (which includes loading the model weights from a file or from the web). I don’t want the model weights to be reloaded (and the other data structures re-initialized) every time a new image is created.

I tried refactoring my code so that @st.cache decorates a function that does only one thing, torch.load(filename), and the same error occurs.

Hi @hertzmann and welcome to the community! :balloon:

Have you tried adding allow_output_mutation=True to the st.cache() call on the function that loads the weights?


That works, thanks!

Hey @hertzmann :wave:,

We’d like to support this natively in Streamlit. Is there any way you could reproduce the issue for us after upgrading to 0.57.0 and share the updated error message? Also, for reference, here is the GitHub issue we created to track this. Thanks for the help!

Hey @hertzmann,

  • Type torch._C._TensorBase is now natively supported in Streamlit via the nightly release. It will be in a general release soon.
  • Type torch.Tensor is also now natively supported in Streamlit as of 0.59.0.
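Before a type is supported natively, st.cache also accepts a hash_funcs argument: a dict mapping a type to a function used to hash objects of that type. The sketch below illustrates that idea, and loosely what "native support" means, with a plain-Python hasher that dispatches on type. FakeTensorBase, FakeTensor, toy_hash and DEFAULT_HASHERS are hypothetical and this is not Streamlit's actual hasher. Because the lookup walks the class's MRO, a handler registered for a base class also covers its subclasses, much as torch.Tensor is a Python subclass of torch._C._TensorBase in PyTorch 1.x.

```python
import hashlib
import struct


class FakeTensorBase:
    """Hypothetical stand-in for a C-level tensor base class."""

    def __init__(self, values):
        self.values = list(values)


class FakeTensor(FakeTensorBase):
    """Hypothetical Python-level subclass, analogous to a tensor wrapper."""


def hash_fake_tensor(t):
    # Hash the contents, so tensors with equal values give equal cache keys.
    return struct.pack(f"{len(t.values)}d", *t.values)


# Type -> function returning bytes. "Native support" for a type amounts to
# having an entry here; hash_funcs-style overrides extend it per call.
DEFAULT_HASHERS = {
    int: lambda x: str(x).encode(),
    str: lambda x: x.encode(),
}


def toy_hash(obj, hash_funcs=None):
    # Per-call overrides take precedence over the built-in table.
    table = {**DEFAULT_HASHERS, **(hash_funcs or {})}
    # Walk the MRO so a handler registered for a base class covers subclasses.
    for tp in type(obj).__mro__:
        if tp in table:
            return hashlib.sha256(table[tp](obj)).hexdigest()
    raise TypeError(f"Cannot hash object of type {type(obj).__name__}")
```

Without an entry, toy_hash(FakeTensor([1.0, 2.0])) raises TypeError; passing hash_funcs={FakeTensorBase: hash_fake_tensor} makes two tensors with the same values hash identically.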

I will update the thread once more when TensorBase support is in a general release. Thanks again for providing all of the info for these!


TensorBase is now native in 0.60.0 :partying_face:


Amazing how easily this one-liner solves the issue :slight_smile:
