Hi,
I’d like to use joblib.Parallel in a Streamlit application that relies heavily on st.cache, and I’m bumping into something I don’t understand. I’d be really grateful for some hints about how best to debug these seemingly unreasonable cache misses.
In my actual code I use a “Show 10 more thumbnails” button at the end of a thumbnail list. Moving a given thumbnail to a URL where the client can access it is a slow operation. Hence (1) I use st.cache around the s3cp(source, target) networking call, and (2) I wrap the 10 calls in a joblib.Parallel. Each of these two optimizations works in isolation, but they don’t work when both are applied. Below is the smallest stylized reproduction I could come up with:
import time

import streamlit as st
from joblib import Parallel, delayed


# In my real application this fetches the x-th thumbnail file from a given list L of thumbnails.
# When a "Show more" button is pressed, this is re-run for all
# elements of L[:required_number_of_thumbnails].
@st.cache
def work(x):
    print("processing input", x)
    time.sleep(1)
    return x * x


arguments = range(1, 4)
parallel = True

if parallel:
    st.header("Parallel")
    last_time = time.time()
    for i in range(1, 4):
        print("starting parallel batch", i)
        results = Parallel(n_jobs=-1)(delayed(work)(x) for x in arguments)
        st.write(results)
        current_time = time.time()
        print(current_time - last_time, "seconds")
        last_time = current_time
    # That's weird: the second and third batches hit the cache
    # an unpredictable number of times,
    # typically once or twice per batch.
else:
    st.header("Sequential")
    last_time = time.time()
    for i in range(1, 4):
        print("starting sequential batch", i)
        results = [work(x) for x in arguments]
        st.write(results)
        current_time = time.time()
        print(current_time - last_time, "seconds")
        last_time = current_time
The task is sometimes executed, sometimes cached, in a seemingly unpredictable way:
processing input 1
processing input 2
processing input 3
1.4828269481658936 seconds
starting parallel batch 2
processing input 3
processing input 1
1.1675269603729248 seconds
starting parallel batch 3
processing input 3
processing input 2
1.1554582118988037 seconds
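One possible way to dig into this is to log which process and thread actually executes each call. This is only a minimal sketch under assumptions: `log_worker` is a hypothetical helper (it is not part of Streamlit or joblib), and `functools.lru_cache` is used purely as a stand-in for an in-memory cache like st.cache so the snippet runs without Streamlit. If the process IDs printed under joblib.Parallel differ from the ID of the main script, and the cache lives in each process's own memory, that might account for hits that depend on which worker happens to pick up which input.

```python
import functools
import os
import threading


def log_worker(func):
    """Hypothetical helper: report the process and thread each time func really runs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print(
            f"pid={os.getpid()} thread={threading.current_thread().name} "
            f"running {func.__name__}{args}"
        )
        return func(*args, **kwargs)
    return wrapper


# Assumption: lru_cache stands in for an in-memory cache such as st.cache.
@functools.lru_cache(maxsize=None)
@log_worker
def work(x):
    return x * x


# Sequential baseline: the first pass executes work three times,
# the second pass is answered from the cache and logs nothing.
first = [work(x) for x in range(1, 4)]
second = [work(x) for x in range(1, 4)]
```

In the reproduction above, the same kind of logging could be added to the body of `work` in place of the plain `print("processing input", x)`, so that each "processing input" line also shows which worker produced it.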