
should we have an extension point for model transforms out of tree? #790

Open · vkuzo opened this issue Jan 15, 2025 · 11 comments
Labels: enhancement (New feature or request)

vkuzo (Contributor) commented Jan 15, 2025

In torchao, we have various low precision training features which are in prototype: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes ~months for a prototype to graduate.

torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?

An example of how this could look:

  1. torchtitan provides a "model transformation" hook that it calls at a specified point in the initialization stage (for quantization, that should be after model init and before parallelization / torch.compile)
  2. user can provide a custom pass to transform the model (such as a prototype low precision training conversion pass)

I'm not entirely sure on how this hook would be implemented since the current interface of torchtitan is CLI based, but wanted to share the request and start the discussion.
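
For concreteness, a minimal sketch of what such a hook could look like; the names register_model_transform and apply_model_transforms are hypothetical, not an existing torchtitan API:

from typing import Callable, Dict, List

import torch.nn as nn

# Registry of out-of-tree model transforms, keyed by name.
_MODEL_TRANSFORMS: Dict[str, Callable[[nn.Module], nn.Module]] = {}

def register_model_transform(name: str, fn: Callable[[nn.Module], nn.Module]) -> None:
    # A registered transform would be called after model init and before
    # parallelization / torch.compile.
    _MODEL_TRANSFORMS[name] = fn

def apply_model_transforms(model: nn.Module, names: List[str]) -> nn.Module:
    # `names` could come from the training config / CLI.
    for name in names:
        model = _MODEL_TRANSFORMS[name](model)
    return model

A prototype low precision conversion pass from torchao could then be registered out of tree and selected by name from the config.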

vkuzo (Contributor, Author) commented Jan 15, 2025

cc @awgu, @tianyu-l, @weifengpy

tianyu-l (Contributor) commented Jan 16, 2025

@vkuzo more than happy to work together on this!

Before we explore solutions, may I ask a few questions to better understand your request?

  • What would you like to achieve ideally? E.g. do you want to use torchtitan as a library without cloning it, or, more specifically, demonstrate your work in torchao by importing torchtitan as a library?
  • (Maybe repeating the question above) what are the "pain points" you are experiencing with the current approaches (PR / branch / fork)?

tianyu-l added the "enhancement" label and self-assigned this issue on Jan 16, 2025

balancap commented Jan 17, 2025

Thanks @vkuzo for opening this issue; I actually wanted to raise the same question!

@tianyu-l We are very interested in having a similar feature. In our research team, some projects use TorchTitan as a git submodule instead of forking/copying the code (making the upgrade to the latest main much cleaner and easier). But for that use case, it is very useful to have clean/simple entry points for modifying models, optimizers, ...

We've implemented a solution internally, and I would be very happy to open a PR for it. In broad terms, we define a simple, general ModelHandler protocol with a registry:

from typing import List, Protocol, Union

import torch.nn as nn

# Import paths may vary with the torchtitan version.
from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class ModelHandler(Protocol):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        ...

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...

def register_model_handler(...):
     pass

This interface generalizes the Float8Handler: it can be used for any quantization handler, but also for model modifications (e.g. FlexAttention, custom fused kernels such as https://github.com/linkedin/Liger-Kernel/, ...).
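
For illustration, a toy handler satisfying this protocol could look like the following (the handler name and behavior are made up, just to show the convert mechanics):

import torch.nn as nn

class SwapActivationHandler:
    # Toy handler: recursively swaps every nn.GELU module for nn.SiLU.
    def __init__(self, job_config, parallel_dims):
        self.job_config = job_config

    def convert(self, model: nn.Module):
        for name, child in model.named_children():
            if isinstance(child, nn.GELU):
                setattr(model, name, nn.SiLU())
            else:
                self.convert(child)

    def pre_optimizer_hook(self, model):
        pass  # nothing to do for this handler

    def post_optimizer_hook(self, model):
        pass  # nothing to do for this handler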

Then in the TOML config, you can define a list of model handlers you want to apply, with every handler having its own (optional) parameters (passed to __init__ via the JobConfig):

[model]
name = "llama3"
flavor = "3B"
handlers = ["Float8Handler", "FusedCrossEntropyLossHandler"]

[float8]
...

[fused_cross_entropy_loss]
...

The collection of model handlers is applied sequentially to the model, and similarly pre/post optimizer hooks are called one after the other.
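
As a rough sketch (the helper names below are assumptions, not existing torchtitan code), the trainer could drive the handlers like this:

from typing import List

import torch.nn as nn

def apply_handlers(model: nn.Module, handlers: List["ModelHandler"]) -> nn.Module:
    # Called after model init and before parallelization / torch.compile.
    for handler in handlers:
        handler.convert(model)
    return model

def optimizer_step(model: nn.Module, optimizer, handlers: List["ModelHandler"]) -> None:
    for handler in handlers:
        # e.g. a float8 handler could sync amax / scale history here.
        handler.pre_optimizer_hook(model)
    optimizer.step()
    for handler in handlers:
        # e.g. a float8 handler could precompute dynamic scales for FSDP here.
        handler.post_optimizer_hook(model)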

Our feeling is that this strikes a good balance: it keeps the TorchTitan codebase simple, while making it easier for users to incorporate their own training logic.

tianyu-l (Contributor) commented:

@balancap
Thank you very much for the suggestions! Generalizing the Float8Handler sounds like a reasonable thing to do.

I have a question though:
I'm assuming convert would be applied before the parallelization & compilation code. It sounds viable as long as the underlying changes in convert stay compatible with that code. However, one might want to modify a (llama) model in a "parallelization-breaking" way, or register a completely new model (see #282). In such cases, do you think it's better to only support the forking way, not the library/submodule way?

Another general question:
Could you share, in your case, which parts of torchtitan you'd keep and which parts you'd modify?
E.g. keeping the model, data loading, and parallelization; modifying the dataset, optimizer / lr scheduler, FlexAttention / kernels.

cc: @vkuzo

balancap commented:

Indeed, I would keep convert before the parallelization & compilation code (like Float8Handler https://github.com/pytorch/torchtitan/blob/main/train.py#L116), and the user should be aware that they need to stay compatible with those passes. Most quantization methods I can think of should satisfy that, as well as swapping simple layers. If some people want to do deeper modifications, I agree forking is the way (and I tend to encourage it in our team if I feel a project needs to deeply modify TT). It should probably be very clear and explicit in the ModelHandler definition what the intended use case is, and what is beyond that scope.

For the kind of projects we have in mind, it is in line with what you're saying:
Keep: parallelization, training loop, checkpointing, Llama model => robust but simple enough base infra with good MFU.
Change: some layers in the model (quantized linear, flex attention, liger kernels, ...), optimizer, LR scheduler and dataset => things which are fairly safe to modify without breaking parallelism or collapsing MFU.

tianyu-l (Contributor) commented:

@vkuzo In the PRs you listed, I do see that sometimes you'd change model.py and parallelize_llama.py in a more intrusive way. I wonder what's your take on the ModelHandler protocol?

vkuzo (Contributor, Author) commented Jan 24, 2025

> what would you like to achieve ideally, e.g. do you want to use torchtitan as a library without cloning it? or more specifically demonstrate your work in torchao by importing torchtitan as a library?

yes, I think that's a good way to describe it!

> Keep: Parallelization, training loop, checkpointing, Llama model; => robust, but simple enough base infra with good MFU.
> Change: Some layers in model (quantized linear, flex attention, liger kernels, ...), optimizer, LR scheduler and dataset. => Things which are fairly safe to modify without breaking parallelism or collapsing MFU.

Yes, that matches 90% of my use cases as well. For the other 10%, forking is fine, as I have been doing!

fegin (Contributor) commented Jan 24, 2025

@balancap I'm curious what pre_optimizer_hook and post_optimizer_hook mean. Are they the pre-hook and post-hook of optimizer.step()? If so, shouldn't these hooks belong to the Optimizer class? Or is ModelHandler closer to a TrainLoop class?

tianyu-l (Contributor) commented:

@fegin I think the float8_handler in train.py is a good example. For other use cases, @balancap please add.
In particular, are there motivations other than quantization to have the pre/post_optimizer_hooks?

> Keep: Parallelization, training loop, checkpointing, Llama model; => robust, but simple enough base infra with good MFU.
> Change: Some layers in model (quantized linear, flex attention, liger kernels, ...), optimizer, LR scheduler and dataset. => Things which are fairly safe to modify without breaking parallelism or collapsing MFU.

Besides, I'd like to understand your needs more precisely:
In order to change the optimizer / lr scheduler, I believe today one has to slightly modify the train loop. (Do you want that part to also be configurable from the config?) Also, if torchtitan is installed as a library, doesn't one still need to copy the train loop anyway?
And if the train loop needs to be modified, you can still write your own handler out of TT, but just can't easily configure things in the toml.

Could you clarify a bit?

balancap commented:

We are using TorchTitan as a git submodule, so our training script looks like this:

import torch
import wandb

# Imported first to monkey patch TorchTitan.
import my_llm_training_module  # noqa: F401
from my_llm_training_module import JobConfig  # Additional JobConfig options.

from torchtitan import train as tt_train
from torchtitan.logging import logger


def main():
    config = JobConfig()
    config.parse_args()
    # Setup wandb...
    # Main TorchTitan training setup & loop
    try:
        tt_train.main(config)
        torch.distributed.destroy_process_group()
    except Exception as e:
        # Error logging before process ends, to record in W&B.
        # Keeping formatting similar to `torchrun` error output.
        logger.error("--- Logging error --")
        logger.error(f"{type(e).__module__}.{type(e).__qualname__}: {e}", exc_info=True)
        raise
    finally:
        # Note keeping W&B init + finish in `main` for clean exception handling.
        wandb.finish()

if __name__ == "__main__":
    main()

So for some projects, it has the benefit of avoiding copy-pasting the train loop (which would mean potentially complex testing to make sure we don't break things). Definitely not a one-size-fits-all approach, but it has some benefits.

At the moment, we are changing the optimizer, LR scheduler, and quantization scheme through fairly hacky monkey patching, which can break at any time when updating TT. It would be much nicer and more sustainable to have clean entry points like the ModelHandler interface I suggested above.

tianyu-l (Contributor) commented:

@balancap Sounds good to me. Please feel free to make a PR on this, I can help review. Please add relevant tests and docs, following https://github.com/pytorch/torchtitan/blob/main/CONTRIBUTING.md

Besides, in addition to the ModelHandler, it'd be good to allow a customized optimizer builder and lr scheduler builder, similar to how we use parallelize_llama. I can make a PR on this.
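
For illustration only, a customized optimizer builder could take a shape along these lines (the builder registry and its signature are assumptions, not an existing torchtitan API):

from typing import Callable, Dict, List

import torch
import torch.nn as nn

# An out-of-tree optimizer builder the trainer would call instead of its default one.
OptimizerBuilder = Callable[[List[nn.Module]], torch.optim.Optimizer]

def build_adamw(model_parts: List[nn.Module]) -> torch.optim.Optimizer:
    params = [p for m in model_parts for p in m.parameters()]
    return torch.optim.AdamW(params, lr=3e-4, weight_decay=0.1)

# A user would register their builder under a name and select it from the
# training config, mirroring how parallelize_llama is selected today.
OPTIMIZER_BUILDERS: Dict[str, OptimizerBuilder] = {"adamw": build_adamw}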
