
[BUG] Memory Leak in BraxEnv with requires_grad=True #2837

Open
mondeg0 opened this issue Mar 7, 2025 · 1 comment
Labels
bug Something isn't working

Comments

mondeg0 commented Mar 7, 2025

Describe the bug

When using BraxEnv with requires_grad=True, there appears to be a memory leak on the CPU side. Memory usage keeps increasing over time, which can be observed with tools like htop. This happens even when next_td is explicitly detached and cloned and backward() is called to release the graph.

To Reproduce

Simply run the following code and observe the increase in memory usage.

import torch
import torch.nn as nn
import torch.optim as optim
from torchrl.envs.libs.brax import BraxEnv


class Actor(nn.Module):
    def __init__(self, obs_size, action_size, hidden_size, policy_std_init=0.05):
        super().__init__()
        self.mu_net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, obs):
        loc = self.mu_net(obs)
        return torch.tanh(loc)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
env_name = "hopper"
seed = 13
num_envs = 2
epochs = 10000

policy = Actor(11, 3, 64).to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-4)

env = BraxEnv(env_name, batch_size=[num_envs], requires_grad=True, device=device)
env.set_seed(seed)

next_td = env.reset()
for i in range(epochs):
    print(i)
    next_td["action"] = policy(next_td["observation"])
    out_td, next_td = env.step_and_maybe_reset(next_td)

    if out_td["next", "done"].any():
        loss = out_td["next", "observation"].sum()  # just for demonstration purposes
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Detach and clone so the old graph can be released; memory still grows.
        next_td = next_td.detach().clone()
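
To quantify the growth without watching htop, here is a minimal logging sketch (it assumes psutil is installed; the log_rss helper is only for illustration) that can be called at the top of the training loop:

import os
import psutil

proc = psutil.Process(os.getpid())

def log_rss(epoch):
    # Resident set size of this Python process, in MiB. If the leak is
    # present, the value keeps climbing instead of plateauing after the
    # first few episodes.
    rss_mib = proc.memory_info().rss / (1024 ** 2)
    print(f"epoch {epoch}: RSS = {rss_mib:.1f} MiB")

Calling log_rss(i) right after print(i) in the loop makes the trend explicit.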

Expected behavior

The memory usage should remain stable over time instead of continuously increasing.

System info

Installation was done with pip.

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

outputs:

0.7.1 2.2.3 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] linux

Additional context

This issue seems to be related to the way BraxEnv manages memory when requires_grad=True. Even though next_td is detached and cloned, memory continues to accumulate on the CPU.
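
One way to narrow down where the CPU-side allocations accumulate is the standard-library tracemalloc module; the sketch below is only a diagnostic idea, and it only sees Python-level allocations (native buffers held by Brax/JAX would not show up):

import tracemalloc

tracemalloc.start()
baseline = tracemalloc.take_snapshot()

# ... run a few hundred environment steps here ...

snapshot = tracemalloc.take_snapshot()
for stat in snapshot.compare_to(baseline, "lineno")[:10]:
    print(stat)  # top allocation sites that grew since the baseline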

Reason and Possible fixes

Is some list or buffer accumulating tensors or intermediate states under the hood?
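
As a rough check on that hypothesis, the number of live torch.Tensor objects can be counted with gc; this is only a diagnostic sketch (count_live_tensors is a made-up helper, not a TorchRL API), and a count that grows without bound across epochs would suggest that references are being retained somewhere:

import gc
import torch

def count_live_tensors():
    # Count torch.Tensor objects currently tracked by the garbage collector.
    n = 0
    for obj in gc.get_objects():
        try:
            if isinstance(obj, torch.Tensor):
                n += 1
        except ReferenceError:
            pass
    return n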

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)
mondeg0 added the bug label Mar 7, 2025
vmoens commented Mar 7, 2025

Will look into this, thanks for reporting!
