
Fix edge case in PyTorchPredictor.deserialize #2994

Merged
1 commit merged into awslabs:dev on Sep 5, 2023

Conversation

@lostella (Contributor) commented Sep 4, 2023

Description of changes: Since #2965 (possibly earlier), if a PyTorchPredictor object was serialized from CPU memory, then deserializing it with device="cuda" would not actually work. The following would happen:

  1. The predictor object is created, with the parameters of its torch.nn.Module model allocated on CPU (since that was the device the predictor was serialized with).
  2. When deserializing with device="cuda", torch.load with map_location="cuda" puts the state_dict values on GPU, as expected.
  3. torch.nn.Module.load_state_dict, however, copies the parameters back to CPU.
  4. Since the predictor has its .device attribute set to "cpu", prediction does not complain (data and model end up on the same device) but is really slow.

This PR makes sure that the predictor object being created is moved with .to(device) after step 1, so that step 3 actually keeps the parameters on GPU; see the sketch below.
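To illustrate why the order matters (a minimal sketch, not the actual GluonTS code; deserialize_sketch and its arguments are hypothetical): load_state_dict copies tensor values into the module's existing parameters in place, so the parameters stay on whatever device the module already lives on, and moving the module first is what makes map_location effective.

import torch

def deserialize_sketch(module: torch.nn.Module, state_dict_path, device) -> torch.nn.Module:
    device = torch.device(device)
    # The fix: move the freshly constructed module to the target device
    # *before* loading the state dict, so its parameters already live there.
    module.to(device)
    # map_location puts the loaded tensors on the target device...
    state_dict = torch.load(state_dict_path, map_location=device)
    # ...and load_state_dict copies values into the existing (on-device)
    # parameters, so they stay on `device` instead of being pulled back
    # to wherever the module was allocated.
    module.load_state_dict(state_dict)
    return module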

The same issue happens with CPU and GPU inverted, as in the following example:

import logging
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd

from gluonts.model import Predictor
from gluonts.torch.model.wavenet import WaveNetEstimator

logging.basicConfig(level=logging.INFO)

# Toy dataset; the start period's freq matches the estimator's freq.
data = [
    {
        "start": pd.Period("2012-02-04", freq="H"),
        "target": np.ones(2000),
    }
]

estimator = WaveNetEstimator(
    freq="H",
    prediction_length=24,
    num_batches_per_epoch=2,
    trainer_kwargs=dict(max_epochs=1),
)

# Train (on GPU, when one is available), then round-trip through
# serialization asking for CPU.
predictor = estimator.train(data)

with tempfile.TemporaryDirectory() as td:
    predictor.serialize(Path(td))
    predictor = Predictor.deserialize(Path(td), device="cpu")
    print(f"predictor device is {predictor.device}")
    print(f"module parameters on {next(predictor.prediction_net.parameters()).device}")

The resulting model is expected to be on CPU. Output before the PR:

predictor device is cuda
module parameters on cuda:0

Output after the PR:

predictor device is cpu
module parameters on cpu
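Continuing the example above, a quick sanity check for this kind of regression (a hedged sketch, not the exact assertion from the PR's tests) is to verify after deserialization that the predictor's declared device matches where its parameters actually live:

import torch

# Hypothetical post-deserialization check: the declared device should match
# the device the module's parameters actually live on (comparing .type
# ignores the device index, e.g. "cuda" vs "cuda:0").
param_device = next(predictor.prediction_net.parameters()).device
assert torch.device(predictor.device).type == param_device.type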

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella lostella requested a review from abdulfatir September 4, 2023 11:29
@lostella lostella added the labels torch (This concerns the PyTorch side of GluonTS) and bug fix (one of PR required labels) on Sep 4, 2023
@lostella lostella marked this pull request as draft September 4, 2023 13:43
@lostella lostella marked this pull request as ready for review September 4, 2023 15:19
@lostella (Contributor, Author) commented Sep 4, 2023

@abdulfatir now I'm wondering whether it makes sense to have a device option in PyTorchPredictor.deserialize at all, since doing

predictor = Predictor.deserialize(path, device=device)

is the same as doing

predictor = Predictor.deserialize(path)
predictor.to(device)

@lostella lostella added this to the v0.14 milestone Sep 4, 2023
@lostella lostella merged commit 25c76a2 into awslabs:dev Sep 5, 2023
@lostella lostella deleted the predictor-to-device branch September 5, 2023 15:46