diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index c5cdb2ba..1a25f28d 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -1,10 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # Standard
-from copy import deepcopy
 from pathlib import Path
 import argparse
-import math
 import os
 import re
 import subprocess
@@ -12,13 +10,10 @@
 
 # Third Party
 from accelerate import Accelerator
-from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 from deepspeed.runtime.zero.utils import ZeRORuntimeException
 
 # pylint: disable=no-name-in-module
-from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, get_scheduler
 import torch
 import torch.distributed
 
@@ -34,18 +29,12 @@
 from instructlab.training.multipack_sampler import (
     find_packing_max_batch_len_and_grad_accum,
 )
-from instructlab.training.setup_accelerator import setup_accelerator
+from instructlab.training.setup_model import setup_model
 from instructlab.training.token_dataset import setup_dataloader, setup_dataset
 from instructlab.training.tokenizer_utils import setup_tokenizer
 from instructlab.training.utils import (
     StreamablePopen,
-    add_noisy_embeddings,
-    apply_gradient_checkpointing,
-    convert_loss_to_reduce_sum,
-    ensure_loadable_granite_checkpoint,
-    get_projection_layer_names,
     load_latest_full_state,
-    prepare_peft_model,
     prepare_universal_checkpoint_from_latest,
     retrieve_chat_template,
     save_checkpoint,
@@ -56,205 +45,6 @@
 import instructlab.training.data_process as dp
 
 
-def setup_optimizer(args, model):
-    if args.distributed_training_framework == DistributedBackend.FSDP.value:
-        optimizer = torch.optim.AdamW(
-            model.parameters(),
-            lr=args.learning_rate,
-            betas=(0.9, 0.95),
-            weight_decay=0.0,
-        )
-    elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
-        # need to use this only when the CPU offload optimizer is enabled
-        if args.cpu_offload_optimizer:
-            print(
-                "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
-            )
-            optimizer = DeepSpeedCPUAdam(
-                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
-            )
-        else:
-            optimizer = FusedAdam(
-                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
-            )
-    else:
-        raise ValueError(
-            f"Sharding framework {args.distributed_training_framework} is not supported."
-        )
-    return optimizer
-
-
-def setup_model(args, tokenizer, train_loader, grad_accum):
-    bnb_config = None
-    if args.lora_r > 0 and args.lora_quant_bits == 4:
-        # Third Party
-        from transformers import BitsAndBytesConfig
-
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
-        )
-
-    base_model_args = {
-        "pretrained_model_name_or_path": args.model_name_or_path,
-        "torch_dtype": torch.bfloat16,
-        "quantization_config": bnb_config,
-    }
-    if not args.disable_flash_attn:
-        base_model_args["attn_implementation"] = "flash_attention_2"
-    elif args.is_granite:
-        raise RuntimeError(
-            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
-        )
-
-    if args.is_granite:
-        with ensure_loadable_granite_checkpoint(
-            args.model_name_or_path, args.output_dir
-        ) as path:
-            base_model_args["pretrained_model_name_or_path"] = path
-            model = GPTDolomiteForCausalLM.from_pretrained(
-                **base_model_args,
-                use_padding_free_transformer=True,
-            )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(**base_model_args)
-
-    if len(tokenizer) > model.config.vocab_size:
-        print(
-            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
-        )
-        model.resize_token_embeddings(
-            int(8 * math.ceil(len(tokenizer) / 8.0))
-        )  # make the vocab size multiple of 8 for sharding the embedding layer.
-
-    # Fix any discrepancy between model and tokenizer
-    if (
-        model.config.pad_token_id is not None
-        and tokenizer.pad_token_id is not None
-        and model.config.pad_token_id != tokenizer.pad_token_id
-    ):
-        print(
-            f"WARNING: There is a mismatch between pad token id of model ({model.config.pad_token_id}) and tokenizer({tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
-        )
-        model.config.pad_token_id = tokenizer.pad_token_id
-    if (
-        model.config.bos_token_id is not None
-        and tokenizer.bos_token_id is not None
-        and model.config.bos_token_id != tokenizer.bos_token_id
-    ):
-        print(
-            f"WARNING: There is a mismatch between bos token id of model({model.config.bos_token_id}) and tokenizer({tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
-        )
-        model.config.bos_token_id = tokenizer.bos_token_id
-    if (
-        model.config.eos_token_id is not None
-        and tokenizer.eos_token_id
-        and model.config.eos_token_id != tokenizer.eos_token_id
-    ):
-        print(
-            f"WARNING: There is a mismatch between eos token id of model({model.config.eos_token_id}) and tokenizer({tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
-        )
-        model.config.eos_token_id = tokenizer.eos_token_id
-
-    assert model.__class__.__name__ in [
-        "MistralForCausalLM",
-        "GPTDolomiteForCausalLM",
-        "LlamaForCausalLM",
-        "Starcoder2ForCausalLM",
-        "GemmaForCausalLM",
-        "MixtralForCausalLM",
-    ], f"Model class name: {model.__class__.__name__} is not supported."
-
-    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
-    model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
-
-    # handling of gradient checkpointing
-    # it is handled differently for lora and full
-    # - with the exception of granite, which handles it
-    #   in the later stanza
-    if args.lora_r > 0:
-        # if lora
-        # Third Party
-        from peft import LoraConfig
-
-        # ensure we select only the modules that exist in the model
-        proj_layers = get_projection_layer_names(model)
-        if not args.lora_target_modules:
-            print(
-                f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
-            )
-            if not proj_layers:
-                raise RuntimeError("could not find any projection layers in the model")
-            args.__dict__["lora_target_modules"] = proj_layers
-        else:
-            # when the user specifies the module, we should verify that they align with what's in the model
-            lora_target_modules_set = set(args.lora_target_modules)
-            diff = lora_target_modules_set - set(proj_layers)
-            layers_to_target = lora_target_modules_set - diff
-            if len(diff) == len(args.lora_target_modules):
-                raise ValueError(
-                    f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
-                )
-            if diff:
-                print(
-                    f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
-                )
-            args.__dict__["lora_target_modules"] = list(layers_to_target)
-
-        peft_config = LoraConfig(
-            lora_alpha=args.lora_alpha,
-            lora_dropout=args.lora_dropout,
-            r=args.lora_r,
-            bias="none",
-            task_type="CAUSAL_LM",
-            target_modules=args.lora_target_modules,
-        )
-        model = prepare_peft_model(
-            model, peft_config, gradient_checkpointing=not args.is_granite
-        )
-
-    elif not args.is_granite:
-        model.gradient_checkpointing_enable()
-
-    # granite gradient checkpointing is handled uniformly
-    # for both lora and full here
-    if args.is_granite:
-        block_name = model._no_split_modules[0]
-        apply_gradient_checkpointing(
-            model,
-            block_name=block_name,
-            use_reentrant=True,  # this should be the HF default mode
-        )
-
-        if args.lora_r > 0:
-
-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
-
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-    accelerator = setup_accelerator(args, model, grad_accum)
-    if args.distributed_training_framework == DistributedBackend.FSDP.value:
-        model = accelerator.prepare(model)
-    optimizer = setup_optimizer(args, model)
-
-    lr_scheduler = get_scheduler(
-        name=args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.num_warmup_steps,
-        num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
-    )
-    model, optimizer, _, lr_scheduler = accelerator.prepare(
-        model,
-        optimizer,
-        deepcopy(train_loader),
-        lr_scheduler,
-    )
-    return model, lr_scheduler, optimizer, accelerator
-
-
 # this function is to check if the checkpoint provided can be resumed
 def maybe_resume_training(args, model):
     local_rank = int(os.environ["LOCAL_RANK"])
diff --git a/src/instructlab/training/setup_model.py b/src/instructlab/training/setup_model.py
new file mode 100644
index 00000000..b9ef820b
--- /dev/null
+++ b/src/instructlab/training/setup_model.py
@@ -0,0 +1,211 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Standard
+from copy import deepcopy
+from typing import Any, Tuple
+import math
+
+# Third Party
+# pylint: disable=no-name-in-module
+from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
+from transformers import AutoModelForCausalLM, get_scheduler
+import torch
+
+# First Party
+from instructlab.training.config import DistributedBackend
+from instructlab.training.setup_accelerator import setup_accelerator
+from instructlab.training.setup_optimizer import setup_optimizer
+from instructlab.training.utils import (
+    add_noisy_embeddings,
+    apply_gradient_checkpointing,
+    convert_loss_to_reduce_sum,
+    ensure_loadable_granite_checkpoint,
+    get_projection_layer_names,
+    prepare_peft_model,
+)
+
+
+def setup_model(
+    args: Any, tokenizer: Any, train_loader: Any, grad_accum: int
+) -> Tuple[torch.nn.Module, Any, torch.optim.Optimizer, Any]:
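+    """Load the base model and prepare it for distributed training.
+
+    Loads the causal LM (4-bit quantized when LoRA quantization is requested,
+    or the Dolomite padding-free variant for granite models), reconciles the
+    special token ids with the tokenizer, converts the loss to reduce-sum,
+    applies NEFTune noisy embeddings, and configures gradient checkpointing
+    or LoRA. The model is then wrapped with the accelerator, optimizer, and
+    LR scheduler.
+
+    Returns a tuple of (model, lr_scheduler, optimizer, accelerator).
+    """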
+    bnb_config = None
+    if args.lora_r > 0 and args.lora_quant_bits == 4:
+        # Third Party
+        from transformers import BitsAndBytesConfig
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_compute_dtype=torch.float16,  # if not set will throw a warning about slow speeds when training
+        )
+
+    base_model_args = {
+        "pretrained_model_name_or_path": args.model_name_or_path,
+        "torch_dtype": torch.bfloat16,
+        "quantization_config": bnb_config,
+    }
+    if not args.disable_flash_attn:
+        base_model_args["attn_implementation"] = "flash_attention_2"
+    elif args.is_granite:
+        raise RuntimeError(
+            "ERROR: Trying to use padding-free transformer without flash attention is not supported"
+        )
+
+    if args.is_granite:
+        with ensure_loadable_granite_checkpoint(
+            args.model_name_or_path, args.output_dir
+        ) as path:
+            base_model_args["pretrained_model_name_or_path"] = path
+            model = GPTDolomiteForCausalLM.from_pretrained(
+                **base_model_args,
+                use_padding_free_transformer=True,
+            )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(**base_model_args)
+
+    if len(tokenizer) > model.config.vocab_size:
+        print(
+            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
+        )
+        model.resize_token_embeddings(
+            int(8 * math.ceil(len(tokenizer) / 8.0))
+        )  # make the vocab size multiple of 8 for sharding the embedding layer.
+
+    # Fix any discrepancy between model and tokenizer
+    if (
+        model.config.pad_token_id is not None
+        and tokenizer.pad_token_id is not None
+        and model.config.pad_token_id != tokenizer.pad_token_id
+    ):
+        print(
+            f"WARNING: There is a mismatch between pad token id of model ({model.config.pad_token_id}) and tokenizer({tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
+        )
+        model.config.pad_token_id = tokenizer.pad_token_id
+    if (
+        model.config.bos_token_id is not None
+        and tokenizer.bos_token_id is not None
+        and model.config.bos_token_id != tokenizer.bos_token_id
+    ):
+        print(
+            f"WARNING: There is a mismatch between bos token id of model({model.config.bos_token_id}) and tokenizer({tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
+        )
+        model.config.bos_token_id = tokenizer.bos_token_id
+    if (
+        model.config.eos_token_id is not None
+        and tokenizer.eos_token_id is not None
+        and model.config.eos_token_id != tokenizer.eos_token_id
+    ):
+        print(
+            f"WARNING: There is a mismatch between eos token id of model({model.config.eos_token_id}) and tokenizer({tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
+        )
+        model.config.eos_token_id = tokenizer.eos_token_id
+
+    assert model.__class__.__name__ in [
+        "MistralForCausalLM",
+        "GPTDolomiteForCausalLM",
+        "LlamaForCausalLM",
+        "Starcoder2ForCausalLM",
+        "GemmaForCausalLM",
+        "MixtralForCausalLM",
+    ], f"Model class name: {model.__class__.__name__} is not supported."
+
+    model = convert_loss_to_reduce_sum(model, is_granite=args.is_granite)
+    model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
+
+    # handling of gradient checkpointing
+    # it is handled differently for lora and full
+    # - with the exception of granite, which handles it
+    #   in the later stanza
+    if args.lora_r > 0:
+        # if lora
+        # Third Party
+        from peft import LoraConfig
+
+        # ensure we select only the modules that exist in the model
+        proj_layers = get_projection_layer_names(model)
+        if not args.lora_target_modules:
+            print(
+                f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
+            )
+            if not proj_layers:
+                raise RuntimeError("could not find any projection layers in the model")
+            args.__dict__["lora_target_modules"] = proj_layers
+        else:
+            # when the user specifies the module, we should verify that they align with what's in the model
+            lora_target_modules_set = set(args.lora_target_modules)
+            diff = lora_target_modules_set - set(proj_layers)
+            layers_to_target = lora_target_modules_set - diff
+            if len(diff) == len(args.lora_target_modules):
+                raise ValueError(
+                    f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
+                )
+            if diff:
+                print(
+                    f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
+                )
+            args.__dict__["lora_target_modules"] = list(layers_to_target)
+
+        peft_config = LoraConfig(
+            lora_alpha=args.lora_alpha,
+            lora_dropout=args.lora_dropout,
+            r=args.lora_r,
+            bias="none",
+            task_type="CAUSAL_LM",
+            target_modules=args.lora_target_modules,
+        )
+        model = prepare_peft_model(
+            model, peft_config, gradient_checkpointing=not args.is_granite
+        )
+
+    elif not args.is_granite:
+        model.gradient_checkpointing_enable()
+
+    # granite gradient checkpointing is handled uniformly
+    # for both lora and full here
+    if args.is_granite:
+        block_name = model._no_split_modules[0]
+        apply_gradient_checkpointing(
+            model,
+            block_name=block_name,
+            use_reentrant=True,  # this should be the HF default mode
+        )
+
+        if args.lora_r > 0:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
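+    # With FSDP the model is wrapped by the accelerator before the optimizer is
+    # created, so that the optimizer is constructed over the sharded parameters.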
+    accelerator = setup_accelerator(args, model, grad_accum)
+    if args.distributed_training_framework == DistributedBackend.FSDP.value:
+        model = accelerator.prepare(model)
+    optimizer = setup_optimizer(args, model)
+
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps,
+        num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
+    )
+    model, optimizer, _, lr_scheduler = accelerator.prepare(
+        model,
+        optimizer,
+        deepcopy(train_loader),
+        lr_scheduler,
+    )
+    return model, lr_scheduler, optimizer, accelerator
diff --git a/src/instructlab/training/setup_optimizer.py b/src/instructlab/training/setup_optimizer.py
new file mode 100644
index 00000000..337e8f02
--- /dev/null
+++ b/src/instructlab/training/setup_optimizer.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Standard
+from typing import Any
+
+# Third Party
+from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+import torch
+
+# First Party
+from instructlab.training.config import DistributedBackend
+
+
+def setup_optimizer(args: Any, model: torch.nn.Module) -> torch.optim.Optimizer:
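+    """Build the optimizer for the selected distributed backend.
+
+    FSDP uses torch.optim.AdamW; DeepSpeed uses DeepSpeedCPUAdam when the
+    CPU offload optimizer is enabled and FusedAdam otherwise.
+    """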
+    if args.distributed_training_framework == DistributedBackend.FSDP.value:
+        optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=args.learning_rate,
+            betas=(0.9, 0.95),
+            weight_decay=0.0,
+        )
+    elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
+        # DeepSpeedCPUAdam is only needed when the CPU offload optimizer is enabled
+        if args.cpu_offload_optimizer:
+            print(
+                "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
+            )
+            optimizer = DeepSpeedCPUAdam(
+                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
+            )
+        else:
+            optimizer = FusedAdam(
+                model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
+            )
+    else:
+        raise ValueError(
+            f"Sharding framework {args.distributed_training_framework} is not supported."
+        )
+    return optimizer