Measure compute needed in FLOPs #75

Open
jonas-becker opened this issue Jun 19, 2024 · 2 comments
Assignees
Labels
feature request A requested feature

Comments

@jonas-becker
Collaborator

Is your feature request related to a problem? Please describe.
Currently, we only log the compute times of MALLM.

Describe the solution you'd like
The proper way to compare MALLM and a single-LLM baseline in terms of compute would be by FLOPs.
Thus, we want to track the FLOPs and write them to the results, similar to the compute times.
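
A minimal sketch of what such a comparison could look like, assuming per-call token counts are available. The names `CallRecord`, `total_flops`, and `flops_ratio` are hypothetical and not part of MALLM. The factor of 2 FLOPs per parameter per token is the common forward-pass approximation; it ignores attention cost over the context, so treat the result as a rough estimate.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class CallRecord:
    """Token counts for one LLM call (hypothetical record, not a MALLM type)."""
    prompt_tokens: int
    completion_tokens: int


def total_flops(calls: Iterable[CallRecord], num_params: int, flops_per_param: int = 2) -> int:
    """Sum approximate forward-pass FLOPs over all calls in a run."""
    return sum(
        flops_per_param * num_params * (c.prompt_tokens + c.completion_tokens)
        for c in calls
    )


def flops_ratio(mallm_calls, baseline_calls, num_params: int) -> float:
    """How many times more compute the multi-agent run used than the baseline."""
    return total_flops(mallm_calls, num_params) / total_flops(baseline_calls, num_params)


# Illustrative numbers only: three agent turns versus one baseline call.
mallm_run = [CallRecord(100, 50), CallRecord(200, 50), CallRecord(300, 50)]
baseline_run = [CallRecord(100, 50)]
ratio = flops_ratio(mallm_run, baseline_run, num_params=7_000_000_000)
print(f"MALLM / baseline FLOPs: {ratio:.1f}x")
```

Unlike compute time, a ratio like this does not depend on the hardware or on how busy the inference server is, which is what makes it usable for comparing against the baseline.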

jonas-becker added the "feature request" label on Jun 19, 2024
@lkaesberg
Collaborator

lkaesberg commented Jun 26, 2024

def estimate_flops(model_params, input_length, output_length, flops_per_param=2):
    """
    Estimate the number of FLOPs for one generation call to a TGI-served model.

    Counts one forward pass over every input and output token and ignores
    attention cost over the context, so the result is a rough estimate.

    :param model_params: Number of parameters in the model
    :param input_length: Length of the input sequence in tokens
    :param output_length: Length of the generated output sequence in tokens
    :param flops_per_param: Estimated FLOPs per parameter per token
        (default: 2, one multiply and one add per weight in a forward pass)
    :return: Estimated total FLOPs
    """
    total_length = input_length + output_length
    flops_per_token = model_params * flops_per_param
    total_flops = flops_per_token * total_length

    return total_flops

def format_flops(flops):
    """
    Format FLOPs into a human-readable string.
    """
    if flops < 1e3:
        return f"{flops:.2f} FLOPs"
    elif flops < 1e6:
        return f"{flops/1e3:.2f} KFLOPs"
    elif flops < 1e9:
        return f"{flops/1e6:.2f} MFLOPs"
    elif flops < 1e12:
        return f"{flops/1e9:.2f} GFLOPs"
    else:
        return f"{flops/1e12:.2f} TFLOPs"

# Example usage
model_params = 7e9  # 7B parameter model (e.g., LLaMA-7B)
input_length = 100
output_length = 50

estimated_flops = estimate_flops(model_params, input_length, output_length)
print(f"Estimated compute cost: {format_flops(estimated_flops)}")

Something like this could be used to estimate it. It is difficult to get the actual data from TGI, and we need a way to get the model's parameter count.
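
For the parameter count, one option is to estimate it from the model's architecture hyperparameters. The sketch below is an assumption, not something TGI provides: `estimate_param_count` is a hypothetical helper, and the formula assumes a LLaMA-style decoder with no bias terms, full multi-head attention (no grouped-query attention), a gated MLP, and RMSNorm. Other architectures need a different formula.

```python
def estimate_param_count(hidden_size, num_layers, vocab_size,
                         intermediate_size, tied_embeddings=False):
    """Approximate parameter count of a LLaMA-style decoder-only transformer.

    Assumes no bias terms, full multi-head attention (no grouped-query
    attention), a gated MLP with three weight matrices, and RMSNorm.
    """
    attention = 4 * hidden_size * hidden_size    # Q, K, V and output projections
    mlp = 3 * hidden_size * intermediate_size    # gate, up and down projections
    norms = 2 * hidden_size                      # two RMSNorm weight vectors per layer
    per_layer = attention + mlp + norms

    embeddings = vocab_size * hidden_size        # input token embeddings
    lm_head = 0 if tied_embeddings else vocab_size * hidden_size
    final_norm = hidden_size

    return num_layers * per_layer + embeddings + lm_head + final_norm


# Hyperparameters as published for LLaMA-7B.
llama_7b_params = estimate_param_count(
    hidden_size=4096, num_layers=32, vocab_size=32000, intermediate_size=11008
)
print(f"{llama_7b_params:,} parameters")
```

For LLaMA-7B this gives roughly 6.74 billion parameters, slightly below the nominal 7e9 used in the example above. If the weights are available locally, counting them directly would be more reliable than a formula.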

@jpwahle
Collaborator

jpwahle commented Jul 3, 2024

I think this is great, but @flowun will start a project in mid-August to look specifically at this, so I will assign him to the issue then :)

4 participants