I had a problem with ballooning GPU memory usage caused by an image data structure produced by a PyTorch-based class. It's fixed now, but I wish I understood better why the fix was necessary.
The problem I saw was that the GPU memory usage of my app was going up by roughly 1 GB per run. After about 12 GB, CUDA failed with an out-of-memory error.
With help from a colleague, I narrowed it down to these lines:
```python
def img2disp(img):
    return img.detach().clamp_(min=-1, max=1).add(1).div_(2).permute(0, 2, 3, 1).to('cpu').numpy()

img_t = sg2(wlist)
img = img2disp(img_t)
st.image(img)
```
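For reference, here is what the transform in `img2disp` does, sketched with NumPy in place of the Torch tensor ops (the name `img2disp_np` and the tiny test input are mine, not from the original code):

```python
import numpy as np

def img2disp_np(img):
    # NumPy sketch of img2disp: clamp to [-1, 1], rescale to [0, 1],
    # and move channels last (NCHW -> NHWC) for display.
    img = np.clip(img, -1, 1)
    img = (img + 1) / 2
    return np.transpose(img, (0, 2, 3, 1))

x = np.array([[[[-2.0]], [[0.0]], [[2.0]]]])  # shape (1, 3, 1, 1)
out = img2disp_np(x)
assert out.shape == (1, 1, 1, 3)              # channels moved last
assert out.min() == 0.0 and out.max() == 1.0  # values rescaled to [0, 1]
```

Unlike the Torch version, this sketch never touches the GPU, so it can't reproduce the leak; it just documents the shape and value transformation.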
None of the above operations has an @st.cache applied to it.
sg2 is a Torch data structure, and img_t is a 1024×1024×3 image (a Torch tensor, IIRC). One way to sidestep the memory problem is to not save the output image at all, which isn't that useful, but it helps isolate the problem.
Adding `del img_t` after `st.image(img)` fixed the problem. So now my code works fine, but it's worrying that I had to do that, because I don't know why img_t didn't get deleted/reclaimed. It isn't referenced anywhere else in my code. I don't even know whether Streamlit is responsible (it's not important enough to me to write a standalone script to test this).
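My best guess at why `del` helped: CUDA memory backing a tensor is only returned to PyTorch's allocator once the last Python reference to the tensor goes away, so anything that keeps the name alive (for example, a framework holding onto the script's globals between reruns) keeps the GPU memory alive too. A minimal stand-in illustrating the reference-counting behavior (`FakeTensor` is a hypothetical placeholder, not a real Torch class):

```python
import gc
import weakref

class FakeTensor:
    """Stand-in for a CUDA tensor: its memory is freed only when
    no Python references to it remain."""
    pass

t = FakeTensor()
r = weakref.ref(t)        # observe the object without keeping it alive
assert r() is not None    # alive: the name `t` still references it

del t                     # drop the last reference, like `del img_t`
gc.collect()              # not strictly needed in CPython, but explicit
assert r() is None        # reclaimed as soon as the reference is gone
```

If something outside the script were still holding a reference to `img_t` (a cache, an exception traceback, a retained frame), the `del` alone wouldn't have freed it, which suggests the script's own name binding was what kept the tensor, and its ~1 GB of GPU memory, alive between runs.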