Optimize load distribution between nodes #719

Open
vishwamartur wants to merge 1 commit into base: main

Related to #701

Implement load distribution between nodes based on flops and memory.

* **Partitioning Strategy**:
  - Add a `flops` attribute to the `Partition` class in `exo/topology/partitioning_strategy.py`.
  - Update the `map_partitions_to_shards` function to take `flops` into account when mapping partitions to shards.
  - Add a `get_flops` method to the `PartitioningStrategy` class to calculate the flops of each partition.

* **Ring Memory Weighted Partitioning Strategy**:
  - Update the `partition` method in `exo/topology/ring_memory_weighted_partitioning_strategy.py` to weigh both memory and flops when partitioning.
  - Add a `calculate_flops_weight` helper function that computes each node's weight from its flops.

* **Node Class**:
  - Update the `Node` class in `exo/orchestration/node.py` to sort nodes by flops for load distribution.
  - Add a `sort_nodes_by_flops` method that performs the sort.

* **Inference Engine**:
  - Add a `get_flops` method to the `InferenceEngine` class in `exo/inference/inference_engine.py` that returns the flops of the current node.

* **Sharded Inference Engine**:
  - Add a `get_flops` method to the `MLXDynamicShardInferenceEngine` class in `exo/inference/mlx/sharded_inference_engine.py` that returns the flops of the current node.
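To make the intended behavior easier to review, here is a minimal, self-contained sketch of partitioning by both memory and flops. It is not the code in this PR: `NodeInfo`, `blended_weights`, `layers_per_partition`, and the `alpha` blend factor are hypothetical names, and the equal blend of memory share and flops share is an assumption, since the description does not say how the two are combined.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class NodeInfo:
    # Hypothetical stand-in for the capabilities a node reports.
    node_id: str
    memory_mb: int
    flops: float  # any consistent unit, e.g. TFLOPS


@dataclass
class Partition:
    node_id: str
    start: float  # fraction of the model, inclusive
    end: float    # fraction of the model, exclusive
    flops: float


def blended_weights(nodes: List[NodeInfo], alpha: float = 0.5) -> List[float]:
    """Weight each node by a blend of its memory share and its flops share.

    alpha is an assumed tuning knob: 1.0 means memory only, 0.0 means flops only.
    """
    total_mem = sum(n.memory_mb for n in nodes)
    total_flops = sum(n.flops for n in nodes)
    if not nodes or total_mem <= 0 or total_flops <= 0:
        raise ValueError("need at least one node with positive memory and flops")
    return [
        alpha * n.memory_mb / total_mem + (1 - alpha) * n.flops / total_flops
        for n in nodes
    ]


def partition(nodes: List[NodeInfo], alpha: float = 0.5) -> List[Partition]:
    """Sort nodes by flops (fastest first) and assign contiguous fractions."""
    ordered = sorted(nodes, key=lambda n: n.flops, reverse=True)
    partitions: List[Partition] = []
    start = 0.0
    for node, weight in zip(ordered, blended_weights(ordered, alpha)):
        end = start + weight
        partitions.append(Partition(node.node_id, start, end, node.flops))
        start = end
    partitions[-1].end = 1.0  # guard against floating-point drift
    return partitions


def layers_per_partition(
    partitions: List[Partition], num_layers: int
) -> List[Tuple[str, int, int]]:
    """Map fractional ranges to inclusive (node_id, first_layer, last_layer)."""
    result = []
    for i, p in enumerate(partitions):
        first = round(p.start * num_layers)
        is_last = i == len(partitions) - 1
        last = num_layers - 1 if is_last else round(p.end * num_layers) - 1
        if first <= last:  # skip nodes whose share rounds down to zero layers
            result.append((p.node_id, first, last))
    return result
```

Under these assumptions, two nodes with equal memory where one reports three times the flops of the other split a 32-layer model 20 to 12, with the faster node taking the larger share. Setting `alpha=1.0` reduces the sketch to memory-only weighting.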
