diff --git a/benchmarks/operator_benchmark/pt/qobserver_test.py b/benchmarks/operator_benchmark/pt/qobserver_test.py index 74c982136b4560..fb528eac4ca5bd 100644 --- a/benchmarks/operator_benchmark/pt/qobserver_test.py +++ b/benchmarks/operator_benchmark/pt/qobserver_test.py @@ -70,6 +70,13 @@ ] ) +qobserver_calculate_qparams_list = op_bench.op_list( + attr_names=['op_name', 'op_func'], + attrs=[ + ['HistogramObserverCalculateQparams', obs.HistogramObserver], + ] +) + class QObserverBenchmark(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): @@ -79,6 +86,15 @@ def init(self, C, M, N, dtype, qscheme, op_func, device): def forward(self): return self.op_func(self.f_input) +class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase): + def init(self, C, M, N, dtype, qscheme, op_func, device): + self.f_input = torch.rand(C, M, N, device=device) + self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device) + self.q_observer(self.f_input) + + def forward(self): + return self.q_observer.calculate_qparams() + op_bench.generate_pt_tests_from_op_list( qobserver_per_tensor_list, @@ -90,6 +106,11 @@ def forward(self): qobserver_per_channel_configs_short + qobserver_per_channel_configs_long, QObserverBenchmark) +op_bench.generate_pt_tests_from_op_list( + qobserver_calculate_qparams_list, + qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long, + QObserverBenchmarkCalculateQparams) + if __name__ == "__main__": op_bench.benchmark_runner.main()