Hello, team.
We created a performance-tracing utility, shown below, that prints the time spent in each function of our MLX application and helps us identify which component needs improvement.
Do you think it's worth checking the code into the mlx_example repository? If so, I'd be pleased to submit a pull request. Thanks.
The file time_mlx.py:
```python
import time
from typing import Callable, Dict, List, Optional

import mlx.core as mx
import mlx.nn
from tabulate import tabulate


class _Record:
    def __init__(self, msg: str, indentation: int) -> None:
        self.msg = msg
        self.indentation = indentation
        self.timing: List[float] = []  # milliseconds, one entry per call
        self.parent: Optional[_Record] = None


class Ledger:
    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.records: List[_Record] = []
        self.records_dict: Dict[str, _Record] = {}
        self.indentation = -1
        self.key = ""

    def print_table(self):
        table = [
            [
                "-" * r.indentation + "> " + r.msg,
                sum(r.timing) / len(r.timing),
                sum(r.timing),
                # Share of the parent's total time spent in this function.
                sum(r.timing) / sum(r.parent.timing) * 100 if r.parent else 100,
            ]
            for r in self.records
        ]
        print(
            tabulate(
                table,
                headers=[
                    "function",
                    "latency per run (ms)",
                    "latency in total (ms)",
                    "Latency Ratio (%)",
                ],
                tablefmt="psql",
            )
        )

    def print_summary(self):
        for r in self.records:
            print(f"{r.msg} {sum(r.timing):.3f} (ms)")


ledger = Ledger()


def function(msg: str):
    """This decorator times the execution of a function that calls MLX."""

    def decorator(g: Callable):
        def g_wrapped(*args, **kwargs):
            # MLX evaluates lazily. Evaluate each input parameter so it is ready before
            # the clock starts, and evaluate the return value(s) of g so they are ready
            # before the clock stops.
            def eval_arg(arg):
                if isinstance(arg, (mx.array, list, tuple, dict)):
                    mx.eval(arg)
                elif isinstance(arg, mlx.nn.Module):
                    mx.eval(arg.parameters())
                return arg

            for arg in args:
                eval_arg(arg)
            for v in kwargs.values():
                eval_arg(v)

            ledger.indentation += 1
            prev_key = ledger.key
            # Use a separator so nested paths such as "ab" + "c" and "a" + "bc"
            # get distinct keys.
            ledger.key += "/" + msg
            if ledger.key not in ledger.records_dict:
                r = _Record(msg, ledger.indentation)
                ledger.records.append(r)
                ledger.records_dict[ledger.key] = r
                r.parent = ledger.records_dict[prev_key] if len(prev_key) > 0 else None

            tic = time.perf_counter()
            result = g(*args, **kwargs)
            eval_arg(result)
            timing = 1e3 * (time.perf_counter() - tic)
            ledger.records_dict[ledger.key].timing.append(timing)

            ledger.indentation -= 1
            ledger.key = prev_key
            return result

        return g_wrapped

    return decorator
```
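The parent/child bookkeeping behind the "Latency Ratio (%)" column can be illustrated without MLX. The following is a minimal sketch, not the posted utility: `timed`, `force`, `Record`, and `ratio` are hypothetical names, `force` is a no-op standing in for `mx.eval`, and a stack of call paths replaces the `indentation`/`key` counters. Each nested call becomes a child row, and a row's ratio is its total time divided by its parent's total time.

```python
import time
from typing import Callable, Dict, List, Optional


class Record:
    """One report row: a function label, its nesting depth, and per-call timings."""

    def __init__(self, msg: str, depth: int, parent: Optional["Record"]) -> None:
        self.msg = msg
        self.depth = depth
        self.parent = parent
        self.timing: List[float] = []  # milliseconds, one entry per call


records: List[Record] = []      # rows in first-call order
by_key: Dict[str, Record] = {}  # call path such as "/outer/inner" -> row
stack: List[str] = []           # call paths of the timed functions currently running


def force(value):
    # Hypothetical stand-in for mx.eval: plain Python values are already computed.
    return value


def timed(msg: str) -> Callable:
    def decorator(g: Callable) -> Callable:
        def wrapped(*args, **kwargs):
            for a in (*args, *kwargs.values()):
                force(a)  # make inputs ready before the clock starts
            parent_key = stack[-1] if stack else ""
            key = parent_key + "/" + msg
            if key not in by_key:
                row = Record(msg, len(stack), by_key.get(parent_key))
                records.append(row)
                by_key[key] = row
            stack.append(key)
            tic = time.perf_counter()
            try:
                result = force(g(*args, **kwargs))  # make outputs ready before stopping
            finally:
                # Runs even if g raises, so the stack never drifts out of sync.
                by_key[key].timing.append(1e3 * (time.perf_counter() - tic))
                stack.pop()
            return result

        return wrapped

    return decorator


def ratio(row: Record) -> float:
    """Percent of the parent's total time spent in this row (100 for top-level rows)."""
    if row.parent is None:
        return 100.0
    return sum(row.timing) / sum(row.parent.timing) * 100


@timed("inner")
def inner(i: int) -> int:
    time.sleep(0.005)
    return i


@timed("outer")
def outer() -> List[int]:
    return [inner(i) for i in range(3)]


values = outer()
for row in records:
    print("-" * row.depth + "> " + row.msg, f"{sum(row.timing):.1f} ms", f"{ratio(row):.0f}%")
```

Because all three `inner` calls run inside the single `outer` call, their summed time cannot exceed the `outer` total, so a child's ratio stays at or below 100%. Popping the stack in a `finally` block is a design choice of this sketch: an exception in a timed function then does not leave later calls at the wrong nesting depth.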
The unit test, which serves as an example as well.
This discussion was converted from issue #941 on August 16, 2024 17:23.