Can't replicate O(1) memory with Adjoint method #259

Open
petercmh01 opened this issue Dec 23, 2024 · 0 comments
petercmh01 commented Dec 23, 2024

Hi, I tried the following code to test the memory behavior of the adjoint method:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import time

class LargeODEFunc(nn.Module):
    def __init__(self, state_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, t, y):
        return self.net(y)

state_dim = 1024
hidden_dim = 512
output_dim = 1024

ode_func = LargeODEFunc(state_dim, hidden_dim, output_dim).cuda()

y0 = torch.randn(1, state_dim).cuda()
t_dense = torch.linspace(0, 1, 1000).cuda()
t_sparse = torch.linspace(0, 1, 10).cuda()

optimizer = torch.optim.Adam(ode_func.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

def print_cuda_memory_usage():
    print(f"Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Memory Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

def train_step(all_time_steps, label):
    optimizer.zero_grad()
    torch.cuda.reset_peak_memory_stats()

    print("\nAfter optimizer.zero_grad() and before odeint()")
    print_cuda_memory_usage()

    y_ode = odeint(ode_func, y0, all_time_steps)[-1]
    
    target = torch.randn_like(y_ode).cuda()
    loss = loss_fn(y_ode, target)
    loss.backward()
    print("\nAfter loss_backward()")
    print_cuda_memory_usage()
    optimizer.step()
    torch.cuda.synchronize()
    optimizer.zero_grad()
    #print(f"Peak Memory = {peak_mem:.2f}MB, Decoded Steps = {len(all_time_steps)}")

print("Training with dense time steps:")
for epoch in range(5):
    train_step(t_dense, "dense")
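(t_sparse is defined above but not exercised in this run; presumably the sparse case would just reuse the same step, e.g.:)

print("\nTraining with sparse time steps:")
for epoch in range(5):
    train_step(t_sparse, "sparse")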

For comparison, I seem to get lower memory consumption with the plain odeint than with odeint_adjoint. Can anyone see if there is a problem in my setup?
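For what it's worth, torch.cuda.memory_allocated() reports what is currently allocated at the time of the call, not the high-water mark reached inside the forward solve, so torch.cuda.max_memory_allocated() is probably the more relevant number here. A minimal sketch of the comparison I have in mind, reusing ode_func, y0, t_dense and loss_fn from the snippet above (the compare_peak helper name is just for illustration):

from torchdiffeq import odeint as odeint_plain, odeint_adjoint

def compare_peak(solver, name):
    # One forward + backward pass, reporting the CUDA high-water mark
    # instead of the post-backward allocation.
    ode_func.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    y = solver(ode_func, y0, t_dense)[-1]
    loss_fn(y, torch.randn_like(y)).backward()
    torch.cuda.synchronize()
    print(f"{name}: peak {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

compare_peak(odeint_plain, "odeint")
compare_peak(odeint_adjoint, "odeint_adjoint")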
