Interaction between joblib and streamlit.cache

Hi,

I’d like to use joblib.Parallel in a Streamlit application that relies heavily on st.cache, and I’m bumping into behavior I don’t understand. I’d be grateful for hints on how best to debug these seemingly unreasonable cache misses.

In my actual code I use a “Show 10 more thumbnails” button at the end of a thumbnail list. Moving a given thumbnail to a URL where the client can access it is a slow operation. Hence (1) I wrap the s3cp(source, target) networking call in st.cache, and (2) I run the 10 calls through joblib.Parallel. Each of these two optimizations works in isolation, but they break when combined. Below is the smallest stylized reproduction I could manage:

import streamlit as st
from joblib import Parallel, delayed
import time


# In my real application this fetches the x-th thumbnail file from a given list L of thumbnails.
# When a "Show more" button is pressed, this is re-run for all
# elements of L[:required_number_of_thumbnails].
@st.cache
def work(x):
    print("processing input", x)
    time.sleep(1)
    return x * x

arguments = range(1, 4)

parallel = True

if parallel:
    st.header("Parallel")
    last_time = time.time()
    for i in range(1, 4):
        print("starting parallel batch", i)
        results = Parallel(n_jobs=-1)(delayed(work)(x) for x in arguments)
        st.write(results)
        current_time = time.time()
        print(current_time - last_time, "seconds")
        last_time = current_time
    # That's weird: the second and third batches hit the cache only an
    # unpredictable number of times, typically once or twice per batch.
else:
    st.header("Sequential")
    last_time = time.time()
    for i in range(1, 4):
        print("starting sequential batch", i)
        results = [work(x) for x in arguments]
        st.write(results)
        current_time = time.time()
        print(current_time - last_time, "seconds")
        last_time = current_time

Each task is sometimes executed, sometimes served from cache, in a seemingly unpredictable way:

starting parallel batch 1
processing input 1
processing input 2
processing input 3
1.4828269481658936 seconds
starting parallel batch 2
processing input 3
processing input 1
1.1675269603729248 seconds
starting parallel batch 3
processing input 3
processing input 2
1.1554582118988037 seconds

Hi @danielvarga,
thanks a lot for your question. I am not an expert on joblib, but here is a theory.
st.cache stores its results in a per-process object. If the backend behaves like multiprocessing (I know the default is loky, but for simplicity I am assuming it is not very different), then each worker process holds its own cache, and, since workers do not share the parent process's memory, each of those caches starts out empty. work(i) gets scheduled on one of the worker processes; the first time it runs there it misses, and afterwards it should hit in that process. But work(i) is not bound to a specific process, so depending on where it gets scheduled the second (and third) time, it may hit or miss.
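
One way to test this theory (a minimal sketch, not something I have run against your app) is to log the worker PID inside work. If the misses line up with calls landing on a process that has not seen that argument yet, the per-process-cache explanation holds:

import os
import time

import streamlit as st
from joblib import Parallel, delayed


@st.cache
def work(x):
    # The PID reveals which worker process handled this call; with a
    # per-process cache, a repeated argument should only miss when it
    # lands on a worker that has not computed it before.
    print(f"processing input {x} in process {os.getpid()}")
    time.sleep(1)
    return x * x


for batch in range(1, 4):
    print("starting parallel batch", batch)
    results = Parallel(n_jobs=-1)(delayed(work)(x) for x in range(1, 4))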

Let me know if this makes sense to you. I am going to circulate this answer within Streamlit and see if people have different ideas.
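
In the meantime, if the theory is right, a possible workaround (again just a sketch, assuming your s3cp call is I/O-bound so threads can overlap it) is to force joblib onto the threading backend. All calls then run in the Streamlit process and share the same st.cache, so after the first batch every call should hit:

from joblib import Parallel, delayed

# prefer="threads" keeps every call in the parent process, so all of
# them share one st.cache store instead of one cache per worker.
results = Parallel(n_jobs=-1, prefer="threads")(delayed(work)(x) for x in arguments)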

Matteo