diff --git a/problog/tasks/probability.py b/problog/tasks/probability.py index 69c6918b..a13b8453 100644 --- a/problog/tasks/probability.py +++ b/problog/tasks/probability.py @@ -361,6 +361,11 @@ class OutputFile(str): return parser +def main_result(argv): + """ Equivalent to the main method, but it returns the result rather than success/failure. """ + return main(argv, result_handler=lambda result, output: result) + + def main(argv, result_handler=None): parser = argparser() args = parser.parse_args(argv) @@ -433,5 +438,6 @@ def main(argv, result_handler=None): stop_timer() return retcode + if __name__ == "__main__": main(sys.argv[1:]) diff --git a/problog/test/test_tasks.py b/problog/test/test_tasks.py index 3ce7e27c..78e07e35 100644 --- a/problog/test/test_tasks.py +++ b/problog/test/test_tasks.py @@ -94,7 +94,7 @@ def check_probability(self, expected, result): def test_probability_some_heads(self): file_name = test_folder / "tasks" / "some_heads.pl" - result = probability.main([str(file_name)]) + result = probability.main_result([str(file_name)]) self.check_probability({Term("someHeads"): 0.8}, result) def check_probability_probabilistic_graph(self, result): @@ -112,35 +112,35 @@ def check_probability_probabilistic_graph(self, result): def test_probability_pgraph(self): file_name = test_folder / "tasks" / "map_probabilistic_graph.pl" - self.check_probability_probabilistic_graph(probability.main([str(file_name)])) + self.check_probability_probabilistic_graph(probability.main_result([str(file_name)])) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--combine"]) + probability.main_result([str(file_name), "--combine"]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--nologspace"]) + probability.main_result([str(file_name), "--nologspace"]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--propagate-evidence"]) + probability.main_result([str(file_name), "--propagate-evidence"]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--propagate-weights"]) + probability.main_result([str(file_name), "--propagate-weights"]) ) self.check_probability_probabilistic_graph( - probability.main( + probability.main_result( [str(file_name), "--propagate-evidence", "--propagate-weights"] ) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--unbuffered"]) + probability.main_result([str(file_name), "--unbuffered"]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--convergence", str(0.00000001)]) + probability.main_result([str(file_name), "--convergence", str(0.00000001)]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--format", "prolog"]) + probability.main_result([str(file_name), "--format", "prolog"]) ) self.check_probability_probabilistic_graph( - probability.main([str(file_name), "--web"]) + probability.main_result([str(file_name), "--web"]) ) def check_ground_result(self, expected, result):