should we have an extension point for model transforms out of tree? #790
cc @awgu, @tianyu-l, @weifengpy
@vkuzo More than happy to work together on this! Before we explore solutions, may I ask a few questions to better understand your request?
Thanks @vkuzo for opening this issue, I wanted to actually raise the same question! @tianyu-l We are very interested in having a similar feature. In our research team, some of our projects use TorchTitan as a git submodule instead of forking/copying the code (making upgrades to the latest version easier). We've implemented a solution internally, and I would be very happy to open a PR for it. Broadly, we define a simple general interface:

```python
from typing import List, Protocol, Union

import torch.nn as nn

# JobConfig and ParallelDims are the existing torchtitan types.


class ModelHandler(Protocol):
    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        ...

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        ...


def register_model_handler(...):
    pass
```

This interface generalizes the existing `Float8Handler`. Then in the TOML config, you can define the list of model handlers you want to apply, with every handler having its own (optional) parameters (passed to the handler via the job config):

```toml
[model]
name = "llama3"
flavor = "3B"
handlers = ["Float8Handler", "FusedCrossEntropyLossHandler"]
[float8]
...
[fused_cross_entropy_loss]
...
```

The collection of model handlers is applied sequentially to the model, and similarly the pre/post optimizer hooks are called one after the other. Our feeling is that this strikes a good balance: it keeps the TorchTitan codebase simple, while making it easier for users to incorporate their own training logic.
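To make the intended flow a bit more concrete, here is a minimal sketch of how the registry and sequential application could work; `_HANDLER_REGISTRY`, `build_model_handlers`, and `apply_model_handlers` are hypothetical names for this discussion (not existing torchtitan APIs), and the config access pattern is assumed:

```python
from typing import Dict, List, Type

import torch.nn as nn

# Hypothetical global registry, filled by register_model_handler(...) at import time.
_HANDLER_REGISTRY: Dict[str, Type] = {}


def register_model_handler(name: str, handler_cls: Type) -> None:
    """Register an (out-of-tree) handler under the name used in the TOML config."""
    _HANDLER_REGISTRY[name] = handler_cls


def build_model_handlers(job_config, parallel_dims) -> List:
    """Instantiate the handlers listed under [model].handlers, in config order."""
    return [
        _HANDLER_REGISTRY[name](job_config, parallel_dims)
        for name in job_config.model.handlers
    ]


def apply_model_handlers(handlers: List, model: nn.Module) -> None:
    """Apply each handler's in-place model conversion sequentially."""
    for handler in handlers:
        handler.convert(model)
```

The pre/post optimizer hooks would be looped over in the same way around `optimizer.step()` inside the training loop.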
@balancap I have a question though, and another general question as well. cc @vkuzo
Indeed, I would keep ... For the kind of projects we have in mind, it is in line with what you're saying.
@vkuzo In the PRs you listed, I do see that sometimes you'd change ...
yes, I think that's a good way to describe it!
yes, that matches 90% of my use cases as well. For the other 10%, forking is fine, as I have been doing!
@balancap I'm curious: why do ...
@fegin I think the ...
Besides, I'd like to understand your needs more precisely: could you clarify a bit?
We are using TorchTitan as a git submodule, so our training script looks like this:

```python
import torch
import wandb
# First import for MonkeyPatching TorchTitan.
import my_llm_training_module # noqa: F401
from my_llm_training_module import JobConfig  # Additional JobConfig options.
from torchtitan import train as tt_train
from torchtitan.logging import logger
def main():
    config = JobConfig()
    config.parse_args()
    # Setup wandb...

    # Main TorchTitan training setup & loop
    try:
        tt_train.main(config)
        torch.distributed.destroy_process_group()
    except Exception as e:
        # Error logging before process ends, to record in W&B.
        # Keeping formatting similar to `torchrun` error output.
        logger.error("--- Logging error --")
        logger.error(f"{type(e).__module__}.{type(e).__qualname__}: {e}", exc_info=True)
        raise
    finally:
        # Note keeping W&B init + finish in `main` for clean exception handling.
        wandb.finish()


if __name__ == "__main__":
    main()
```

So for some projects, it has the benefit of avoiding copy-pasting the train loop (which would mean potentially complex testing to make sure we don't break things). Definitely not a one-size-fits-all approach, but it has some benefits. At the moment, we are changing the optimizer, LR scheduler, and quantization scheme through fairly hacky monkey patching, which can break at any time when updating TT. It would be much nicer & more sustainable to have clean entry points like the `ModelHandler` interface discussed above.
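As an illustration of how such entry points could replace the monkey patching, here is a rough sketch of an out-of-tree handler written against the `ModelHandler` protocol proposed above; `FusedCrossEntropyLossHandler` and the config lookup are hypothetical, and the actual conversion logic is left out:

```python
from typing import List, Union

import torch.nn as nn


class FusedCrossEntropyLossHandler:
    """Illustrative out-of-tree handler following the proposed ModelHandler protocol."""

    def __init__(self, job_config, parallel_dims):
        # Handler-specific options would come from a [fused_cross_entropy_loss]
        # section of the job config (assumed attribute name).
        self.config = getattr(job_config, "fused_cross_entropy_loss", None)

    def convert(self, model: nn.Module):
        # Swap in a fused cross-entropy implementation here (omitted in this sketch).
        pass

    def pre_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        pass
```

Registration would then happen at import time in `my_llm_training_module` (e.g. something like `register_model_handler("FusedCrossEntropyLossHandler", FusedCrossEntropyLossHandler)`), matching the name listed in the TOML config, with no patching of TorchTitan internals.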
@balancap Sounds good to me. Please feel free to make a PR on this; I can help review. Please add relevant tests and docs, following https://github.com/pytorch/torchtitan/blob/main/CONTRIBUTING.md

Besides, in addition to the ...
In torchao, we have various low precision training features which are in prototype: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes ~months for a prototype to graduate.
torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?
An example of what this could look like:
I'm not entirely sure how this hook would be implemented, since the current interface of torchtitan is CLI-based, but I wanted to share the request and start the discussion.
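For concreteness, here is a purely illustrative sketch of how a torchao prototype might plug into such a hook, assuming torchtitan exposed a registration function along the lines sketched earlier in this thread; all names are hypothetical (not existing torchtitan or torchao APIs) and the conversion body is a stub:

```python
from typing import Callable, Dict

import torch.nn as nn

# Stand-in for a torchtitan-provided registration hook (hypothetical, not a real API).
_MODEL_CONVERTERS: Dict[str, Callable[[nn.Module], nn.Module]] = {}


def register_model_converter(name: str, fn: Callable[[nn.Module], nn.Module]) -> None:
    _MODEL_CONVERTERS[name] = fn


# Out-of-tree torchao prototype code: convert the model to an experimental
# low-precision format (body omitted; illustrative only).
def convert_to_mx_prototype(model: nn.Module) -> nn.Module:
    return model


# Registered at import time; the converter is then selected by name in the
# torchtitan job config, so the prototype never needs to live in-tree.
register_model_converter("mx_prototype", convert_to_mx_prototype)
```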