Weird memory usage behavior

I had a problem with ballooning GPU memory usage from an image tensor produced by a PyTorch-based class. It's fixed now, but I wish I understood better why the fix was necessary.

The problem I saw was that the GPU memory usage of my app went up by roughly 1 GB per run. After about 12 GB, CUDA failed with an out-of-memory error.

With help from a colleague, I narrowed it down to these lines:

def img2disp(img):
    # Detach from the autograd graph, map [-1, 1] -> [0, 1],
    # reorder NCHW -> NHWC, move to CPU, and take the first batch item.
    return (img.detach()
               .clamp_(min=-1, max=1)
               .add(1)
               .div_(2)
               .permute(0, 2, 3, 1)
               .to('cpu')
               .numpy()[0])

img_t = sg2(wlist)
img = img2disp(img_t)
st.image(img)

None of the above operations has an @st.cache applied to it. sg2 is a Torch object, and img_t is a 1024x1024x3 image (a Torch tensor, IIRC). One way to make the memory problem go away is to not save the output image at all and just run:

sg2(wlist)

which isn’t that useful, but it helps isolate the problem.

Instead, adding del img_t after st.image(img) fixed the problem. So now my code works fine, but it's worrying that I had to do that, because I don't know why img_t wasn't deleted/reclaimed on its own: it isn't referenced anywhere else in my code. I don't even know whether Streamlit is responsible (it's not important enough to me to write a standalone script to test this).
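For anyone puzzling over why del would matter: CPython frees objects by reference counting, and a CUDA tensor's memory goes back to PyTorch's allocator only when the last reference to the tensor dies. del only removes one name; if some other reference (a cached frame, a saved traceback, the previous run's namespace) still points at the object, it survives. A pure-Python sketch of the mechanism (BigBuffer is a hypothetical stand-in for a large tensor):

```python
import weakref

class BigBuffer:
    """Stand-in for a large tensor; the real object would hold GPU memory."""
    pass

buf = BigBuffer()
finalized = []
# weakref.finalize fires when the object's last strong reference dies
weakref.finalize(buf, finalized.append, "freed")

hidden = buf   # a second, easy-to-miss reference
del buf        # the name is gone, but the object survives via `hidden`
assert finalized == []

del hidden     # last reference gone -> CPython frees it immediately
assert finalized == ["freed"]
```

So del img_t helping suggests the name img_t really was the last (or last easily droppable) reference, and deleting it let the tensor, and its GPU allocation, go at that point rather than at some later, less predictable time.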

Any thoughts?


Same here: GPU usage goes up every time I run inference.

Hey @hertzmann -

Without seeing the entire script, it’s hard to know what’s going on here (and possibly hard to know even with the full context!).

My hunch is that this is either a bug in PyTorch (that GPU memory is almost surely being allocated by non-Python FFI code within PyTorch), or an error in PyTorch usage. Since Streamlit apps are “long-lived” in the sense that the Python interpreter does not shut down between app runs, they can surface memory leaks that wouldn’t otherwise be noticeable in a comparable script that runs to completion and then exits completely.
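One concrete way long-lived interpreters surface this kind of thing (a general mechanism, not a diagnosis of this particular app): a saved exception keeps its traceback alive, and the traceback keeps every frame's locals alive, so a large object created before the exception can outlive the function that made it. A minimal sketch, with Big standing in for a large tensor:

```python
import weakref

class Big:
    """Stand-in for a large tensor held in a function's locals."""
    pass

freed = []

def run_once():
    tensor = Big()  # imagine a big GPU tensor
    weakref.finalize(tensor, freed.append, "freed")
    raise RuntimeError("inference failed")

err = None
try:
    run_once()
except RuntimeError as e:
    err = e  # keeping the exception object keeps its traceback...

# ...and the traceback's frames keep `tensor` alive, long after run_once returned
assert freed == []

err = None  # drop the exception (and its traceback)
assert freed == ["freed"]
```

In a script that runs to completion and exits, this kind of lingering reference never matters; in a long-lived server process it accumulates.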

st.image - and most other st.foo functions - do create some data that's persisted across runs, but it's specifically serialized protobuf objects, which by definition can't hold pointers to PyTorch objects or anything else. So I'm fairly sure there's nothing inside st.image that's causing this - but never say never when it comes to memory leaks 🙂

It would be useful to know: Do you still see the behavior if you remove the call to st.image? If not, I don’t think this is anything to do with Streamlit. (But if so, we’d love to know!)