@@ -506,53 +506,54 @@ def benchmark(
506
506
b = torch .cuda .max_memory_allocated (rank )
507
507
max_mem_allocated .append (b // 1024 // 1024 )
508
508
509
- # pyre-ignore[2]
510
- def trace_handler (prof ) -> None :
511
- total_average = prof .profiler .total_average ()
512
- logger .info (f" TOTAL_AVERAGE:\n { name } \n { total_average } " )
513
- dir_path : str = output_dir
514
-
515
- # Don't output trace files if dir_path is empty
516
- # or rank != 0, rank=-1 in no pg case, only 1 rank should output
517
- # in pg case, so rank=0
518
- if dir_path == "" or rank > 0 :
519
- return
520
-
521
- trace_file : str = f"{ dir_path } /trace-{ name } .json"
522
- stacks_cpu_file = f"{ dir_path } /stacks-cpu-{ name } .stacks"
523
- stacks_cuda_file = f"{ dir_path } /stacks-cuda-{ name } .stacks"
524
- logger .info (f" PROFILE[{ name } ].chrome_trace:{ trace_file } " )
525
-
526
- prof .export_chrome_trace (trace_file )
527
- prof .export_stacks (stacks_cpu_file , "self_cpu_time_total" )
528
- prof .export_stacks (stacks_cuda_file , "self_cuda_time_total" )
529
-
530
- # - git clone https://github.com/brendangregg/FlameGraph
531
- # - cd FlameGraph
532
- # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg
533
-
534
- with torch .profiler .profile (
535
- activities = [
536
- torch .profiler .ProfilerActivity .CPU ,
537
- torch .profiler .ProfilerActivity .CUDA ,
538
- ],
539
- record_shapes = True ,
540
- profile_memory = True ,
541
- with_stack = True ,
542
- with_flops = True ,
543
- with_modules = True ,
544
- on_trace_ready = trace_handler ,
545
- ) as p :
546
- for _input in prof_inputs :
547
- with record_function ("## forward ##" ):
548
- model (_input )
549
- p .step ()
550
-
551
- if rank == - 1 :
552
- for di in range (world_size ):
553
- torch .cuda .synchronize (di )
554
- else :
555
- torch .cuda .synchronize (rank )
509
+ if output_dir != "" :
510
+ # Only do profiling if output_dir is set
511
+
512
+ # pyre-ignore[2]
513
+ def trace_handler (prof ) -> None :
514
+ total_average = prof .profiler .total_average ()
515
+ logger .info (f" TOTAL_AVERAGE:\n { name } \n { total_average } " )
516
+ dir_path : str = output_dir
517
+
518
+ # only 1 rank should output in pg case, rank = 0
519
+ if rank > 0 :
520
+ return
521
+
522
+ trace_file : str = f"{ dir_path } /trace-{ name } .json"
523
+ stacks_cpu_file = f"{ dir_path } /stacks-cpu-{ name } .stacks"
524
+ stacks_cuda_file = f"{ dir_path } /stacks-cuda-{ name } .stacks"
525
+ logger .info (f" PROFILE[{ name } ].chrome_trace:{ trace_file } " )
526
+
527
+ prof .export_chrome_trace (trace_file )
528
+ prof .export_stacks (stacks_cpu_file , "self_cpu_time_total" )
529
+ prof .export_stacks (stacks_cuda_file , "self_cuda_time_total" )
530
+
531
+ # - git clone https://github.com/brendangregg/FlameGraph
532
+ # - cd FlameGraph
533
+ # - ./flamegraph.pl --title "CPU time" --countname "us." profiler.stacks > perf_viz.svg
534
+
535
+ with torch .profiler .profile (
536
+ activities = [
537
+ torch .profiler .ProfilerActivity .CPU ,
538
+ torch .profiler .ProfilerActivity .CUDA ,
539
+ ],
540
+ record_shapes = True ,
541
+ profile_memory = True ,
542
+ with_stack = True ,
543
+ with_flops = True ,
544
+ with_modules = True ,
545
+ on_trace_ready = trace_handler ,
546
+ ) as p :
547
+ for _input in prof_inputs :
548
+ with record_function ("## forward ##" ):
549
+ model (_input )
550
+ p .step ()
551
+
552
+ if rank == - 1 :
553
+ for di in range (torch .cuda .device_count ()):
554
+ torch .cuda .synchronize (torch .device (f"cuda:{ di } " ))
555
+ else :
556
+ torch .cuda .synchronize ()
556
557
557
558
return BenchmarkResult (
558
559
short_name = name ,
@@ -754,6 +755,8 @@ def benchmark_module(
754
755
output_dir: Directory to output profiler outputs (traces, stacks)
755
756
pooling_configs: The pooling factor for the tables.
756
757
(Optional; if not set, we'll use 10 as default)
758
+ func_to_benchmark: Custom function to benchmark, check out default_func_to_benchmark for default
759
+ benchmark_func_kwargs: Custom keyword arguments to pass to func_to_benchmark
757
760
758
761
Returns:
759
762
A list of BenchmarkResults
0 commit comments