[RMP] Add PyTorch backend in Merlin Models #893

Open
12 of 26 tasks
marcromeyn opened this issue Apr 3, 2023 · 3 comments

@marcromeyn

marcromeyn commented Apr 3, 2023

Problem:

We are currently in a situation where some customers use Merlin Models and others use Transformers4Rec (T4Rec) to train models. The APIs of these two tools have diverged quite dramatically, and some features (like extracting embeddings from models) are only supported in Merlin Models. Both tools need some work before their APIs are easy to use.

On the Merlin Models side, we are in an in-between state where (because of time pressure) there are a bunch of V1 & V2 classes. We would like to migrate all our users to the V2 classes (while removing V2 from the name) & deprecate the old classes.

On the T4Rec side, we would like to keep using this project for session-based models in PyTorch because of the traction we've got. The idea would be to break out the core model-building parts (block-API) in favor of the PyTorch backend of Merlin Models. This roadmap-level ticket focuses on this new PyTorch backend; integration into T4Rec is left for later. The first major deliverable of this backend is the creation of retrieval models, because we typically frame session-based models as retrieval models.

Goal:

Reach feature parity & rough API parity between the TF & PyTorch backends in Merlin Models. This roadmap ticket covers PyTorch; a future roadmap ticket will focus on TF.

New Functionality

  • Models
    • PyTorch: New backend, built from the ground up based on the TF implementation. Port all the retrieval examples.

Constraints:

  • We focus on just retrieval-models. Ranking-models will be tackled in a future roadmap ticket.
  • Migrating T4Rec to the new Block-API is future work and will be captured in another roadmap-level ticket.

Starting Point:

In order to properly plan out the work, a dev-branch was created to answer various design questions around creating retrieval models in PyTorch. This has led to a rough MVP that contains all the major pieces. It has also given us a better idea of how to break things down to turn the MVP into a fully fleshed-out product.

We are planning to have people work in parallel on 4 different major parts: inputs, outputs, models & masking.

Implement base-classes of block-API in PyTorch

People: @marcromeyn

Currently the block-API in T4Rec uses a design similar to Keras to allow for modules that lazily initialize their variables. We would like to deprecate this in favor of a native way to achieve the same thing that was launched recently.
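
The ticket does not name the native mechanism. Assuming it refers to PyTorch's lazy modules (for example `torch.nn.LazyLinear`), a minimal sketch of a block that infers its input width on first use could look like this. `LazyMLPBlock` is a hypothetical name for illustration, not part of Merlin Models or T4Rec.

```python
import torch
from torch import nn


class LazyMLPBlock(nn.Module):
    """Hypothetical block whose input width is inferred on the first forward pass."""

    def __init__(self, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            # nn.LazyLinear defers creating its weight until it sees a batch.
            nn.LazyLinear(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


block = LazyMLPBlock(hidden_dim=16, out_dim=4)
batch = torch.randn(8, 10)  # the input width (10) is never passed to the block
out = block(batch)          # the first call materializes the lazy weight
```

After the first forward pass the lazy layer behaves like a regular `nn.Linear`, so no custom build step is needed in the block itself.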

Masking

People: @sararb, @gabrielspmoreira & @marcromeyn

This work depends on answering the design question of how to handle ragged tensors.

Tasks: TODO
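
To make the ragged-tensor question concrete, here is one possible option, sketched under assumptions: variable-length sequences arrive as flat values plus row offsets, and are converted to a padded dense tensor with a boolean mask marking real positions. The values/offsets layout and the helper name `ragged_to_padded` are assumptions for illustration; the ticket does not decide on a representation.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Assumed ragged layout: three sessions of lengths 3, 1 and 2.
values = torch.tensor([10, 11, 12, 20, 30, 31])
offsets = torch.tensor([0, 3, 4, 6])


def ragged_to_padded(values, offsets, padding_value=0):
    """Convert (values, offsets) into a padded tensor and a validity mask."""
    lengths = (offsets[1:] - offsets[:-1]).tolist()
    rows = list(torch.split(values, lengths))
    padded = pad_sequence(rows, batch_first=True, padding_value=padding_value)
    # mask[i, j] is True where position j is a real item of row i.
    positions = torch.arange(padded.shape[1]).unsqueeze(0)
    mask = positions < torch.tensor(lengths).unsqueeze(1)
    return padded, mask


padded, mask = ragged_to_padded(values, offsets)
```

A masking scheme could then use `mask` to make sure padded positions are never selected as prediction targets.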

Input-blocks

People: @marcromeyn

PyTorch

Starting point: MVP

  • Implement Continuous & Embeddings
  • Implement TabularInputBlock
  • Implement Encoder
  • Add support for sequential-features in input-blocks
  • Do performance testing of holding multiple features in a single embedding-table
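
As an illustration of the last item, holding multiple features in a single embedding table can be done by giving each feature its own slice of one `nn.Embedding` and shifting ids by a per-feature offset. This is a minimal sketch under assumptions; `SharedEmbeddingTable` and the feature names are hypothetical and are not the MVP's implementation.

```python
import torch
from torch import nn


class SharedEmbeddingTable(nn.Module):
    """Hypothetical: stores several categorical features in one nn.Embedding."""

    def __init__(self, cardinalities: dict, dim: int):
        super().__init__()
        self.offsets = {}
        total = 0
        for name, size in cardinalities.items():
            self.offsets[name] = total  # first row of this feature's slice
            total += size
        self.table = nn.Embedding(total, dim)

    def forward(self, inputs: dict) -> dict:
        names = list(inputs)
        # Shift each feature's ids into its own slice, then do one lookup.
        shifted = torch.stack(
            [inputs[name] + self.offsets[name] for name in names], dim=1
        )
        looked_up = self.table(shifted)  # (batch, n_features, dim)
        return {name: looked_up[:, i] for i, name in enumerate(names)}


emb = SharedEmbeddingTable({"item_id": 100, "category": 10}, dim=8)
batch = {"item_id": torch.tensor([3, 99]), "category": torch.tensor([0, 9])}
out = emb(batch)
```

The single stacked lookup is what a performance test would compare against one `nn.Embedding` call per feature.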

Output-blocks

People: @edknv & @marcromeyn

Models

People: @edknv & @marcromeyn

Starting point: MVP

One of the leading questions in the initial experimentation phase was to figure out whether we can leverage PyTorch Lightning for a high-level training API (similar to how we use Keras on the TF side). We are confident that PyTorch Lightning is the right path forward.

  • Implement Model class (using PyTorch Lightning)
  • Create custom Trainer that can handle multi-GPU with data-loader
  • Implement RetrievalModel class
  • Port MatrixFactorizationModel, TwoTowerModel & YoutubeDNNRetrievalModel
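
To give a feel for what a ported retrieval model involves, here is a rough sketch of matrix factorization trained with in-batch negatives in plain PyTorch. The class name, the loss helper, and the training loop are assumptions for illustration; they are not the planned `MatrixFactorizationModel` API, and the PyTorch Lightning wrapper is left out to keep the example dependency-free.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MatrixFactorizationSketch(nn.Module):
    """Hypothetical retrieval model: dot product of user and item embeddings."""

    def __init__(self, n_users: int, n_items: int, dim: int):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, dim)
        self.item_emb = nn.Embedding(n_items, dim)

    def forward(self, user_ids, item_ids):
        users = self.user_emb(user_ids)  # (batch, dim)
        items = self.item_emb(item_ids)  # (batch, dim)
        return users @ items.T           # (batch, batch) score matrix


def in_batch_negative_loss(scores):
    # Row i's positive item sits on the diagonal; the other items in the
    # batch act as negatives.
    targets = torch.arange(scores.shape[0])
    return F.cross_entropy(scores, targets)


torch.manual_seed(0)
model = MatrixFactorizationSketch(n_users=50, n_items=200, dim=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
user_ids = torch.tensor([0, 1, 2, 3])
item_ids = torch.tensor([10, 20, 30, 40])

losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = in_batch_negative_loss(model(user_ids, item_ids))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

A two-tower model follows the same shape, with each embedding lookup replaced by an encoder over that tower's input features.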

Documentation

  • Create a migration guide from Transformers4Rec to Merlin Models session-based PyTorch API
@viswa-nvidia

@marcromeyn, please create the tasks for PyT and create the tickets so that we can assign them.

@marcromeyn marcromeyn changed the title [RMP] Unify and clean up block-API in TensorFlow & PyTorch [RMP] Add PyTorch backend in Merlin Models May 30, 2023
@EvenOldridge
Member

@marcromeyn @gabrielspmoreira can you work to split this up into Ranking, Retrieval and Session-based?

@marcromeyn
Author

Ranking ticket is here: #1044

7 participants