Summary
I want to download an image generated with matplotlib.pyplot. I have tried several solutions, but none of them work in my case. I have a function that assembles a plot, and I then want to save that figure so it can be downloaded. I am running the app on Streamlit Cloud (deployed from GitHub).
Steps to reproduce
Code snippet:
import io

import matplotlib.pyplot as plt
import streamlit as st

# Code inspired by:
# https://stackoverflow.com/questions/71713951/how-to-download-matplotlib-graphs-generated-in-a-streamlit-app
# This function takes a list of images and, optionally, a list of captions,
# and should plot them together in one figure
def download_images(images, captions, cols=2, rows=2):
    n_img = len(images)
    figure, axis = plt.subplots(rows, cols)
    ax = axis.flatten()
    for i in range(n_img):
        ax[i].imshow(images[i])
        ax[i].axis('off')
        if captions is not None:
            ax[i].set_title(str(captions[i]))
    plt.tight_layout()
    plt.show()

# Here I create a button to decide whether to save the images;
# it should then call the function above and save the figure as PNG
save_all_imgs = st.button('save all images', key='1')
if save_all_imgs:
    download_images(imgs, captions=None, cols=3, rows=3)
    plt.savefig(img, format='png')
    fn = 'scatter.png'
    img = io.BytesIO()
    btn = st.download_button(
        label="Download image",
        data=img,
        file_name=fn,
        mime="image/png"
    )
Expected behavior:
The image should be saved and downloaded.
Actual behavior:
There is no error: the function runs, but no image is downloaded.
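Judging only from the snippet, a likely cause is the ordering: `plt.savefig(img, ...)` runs before `img = io.BytesIO()` is created, so the buffer handed to `st.download_button` is a fresh, empty one. In addition, `plt.show()` does not display anything in a Streamlit app. Below is a minimal sketch of one way to restructure it, not a confirmed fix. It swaps pyplot for matplotlib's object-oriented `Figure` API to avoid global "current figure" state, and it creates the buffer before saving into it. `build_figure` and `figure_to_png_bytes` are hypothetical helper names, and the random arrays are placeholder data standing in for `imgs`.

```python
import io

import numpy as np
from matplotlib.figure import Figure


def build_figure(images, captions=None, cols=2, rows=2):
    """Arrange up to rows*cols images in a grid and return the Figure.

    Uses the object-oriented Figure API instead of pyplot, so there is no
    global "current figure" and no need for plt.show().
    """
    fig = Figure()
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid
    axes = fig.subplots(rows, cols, squeeze=False).flatten()
    for i, ax in enumerate(axes):
        ax.axis("off")  # hide every frame, including unused grid cells
        if i < len(images):
            ax.imshow(images[i])
            if captions is not None:
                ax.set_title(str(captions[i]))
    fig.tight_layout()
    return fig


def figure_to_png_bytes(fig):
    """Render a Figure into an in-memory PNG and return the raw bytes."""
    buf = io.BytesIO()  # create the buffer *before* saving into it
    fig.savefig(buf, format="png")
    return buf.getvalue()  # getvalue() returns everything, regardless of stream position


# Placeholder data (assumption): nine random RGB images for a 3x3 grid
imgs = [np.random.rand(8, 8, 3) for _ in range(9)]
fig = build_figure(imgs, captions=None, cols=3, rows=3)
png_bytes = figure_to_png_bytes(fig)
```

In the Streamlit script, `st.pyplot(fig)` would display the figure in place of `plt.show()`, and `png_bytes` could be passed as `data=` to `st.download_button(label="Download image", data=png_bytes, file_name="scatter.png", mime="image/png")`. Passing bytes rather than the `BytesIO` object sidesteps any question of where the stream position is left after saving.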
Debug info
I am using Streamlit Cloud.
Requirements file
torch
matplotlib
torchvision
scikit-learn
Thank you very much for your help!