diff --git a/atcodertools/client/models/problem_content.py b/atcodertools/client/models/problem_content.py index 6b4248a9..a32430f5 100644 --- a/atcodertools/client/models/problem_content.py +++ b/atcodertools/client/models/problem_content.py @@ -14,6 +14,12 @@ def normalize(content: str) -> str: return content.strip().replace('\r', '') + "\n" +def normalize_soup(content) -> str: + for a in content.findAll('var'): + a.replace_with(' ' + a.text + ' ') + return normalize(content.text) + + def is_japanese(ch): # Thank you! # http://minus9d.hatenablog.com/entry/2015/07/16/231608 @@ -34,6 +40,21 @@ class InputFormatDetectionError(Exception): pass +class InputFormat: + def __init__(self, input_format: list[str]): + self.type = None + self.loop_length_var = None + self.input_format = input_format + + +def _strip_case_vars(s): + result = [] + for line in s.split("\n"): + if line.find("case") == -1 and line.find("Case") == -1: + result.append(line) + return "\n".join(result) + + class ProblemContent: def __init__(self, input_format_text: Optional[str] = None, @@ -41,19 +62,23 @@ def __init__(self, input_format_text: Optional[str] = None, original_html: Optional[str] = None, ): self.samples = samples - self.input_format_text = input_format_text + if type(input_format_text) is str: + self.input_format_text = [input_format_text] + else: + self.input_format_text = input_format_text self.original_html = original_html + self.input_format_data = InputFormat(self.input_format_text) @classmethod def from_html(cls, html: str): res = ProblemContent(original_html=html) soup = BeautifulSoup(html, "html.parser") - res.input_format_text, res.samples = res._extract_input_format_and_samples( + res.input_format_text, res.input_format_data, res.samples = res._extract_input_format_and_samples( soup) return res - def get_input_format(self) -> str: - return self.input_format_text + def get_input_format(self) -> list[str]: + return self.input_format_data def get_samples(self) -> List[Sample]: return self.samples @@ -91,12 +116,17 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]: if input_format_tag is None: raise InputFormatDetectionError - - input_format_text = normalize(input_format_tag.text) + input_format_text = list( + map(lambda x: normalize_soup(x), input_format_tag)) except AttributeError: raise InputFormatDetectionError - return input_format_text, res + if len(input_format_text) == 2: + input_format_text[0] = _strip_case_vars(input_format_text[0]) + + input_format_data = InputFormat(input_format_text) + + return input_format_text, input_format_data, res @staticmethod def _primary_strategy(soup): @@ -114,17 +144,18 @@ def _primary_strategy(soup): if section_title.startswith("入力例"): input_tags.append(tag.find('pre')) elif section_title.startswith("入力"): - input_format_tag = tag.find('pre') + input_format_tag = tag.findAll('pre') if section_title.startswith("出力例"): output_tags.append(tag.find('pre')) return input_format_tag, input_tags, output_tags + # TODO: こっちのタイプはmulti caseに未対応!!! @staticmethod def _secondary_strategy(soup): # TODO: more descriptive name pre_tags = soup.select('pre') sample_tags = pre_tags[1:] input_tags = sample_tags[0::2] output_tags = sample_tags[1::2] - input_format_tag = pre_tags[0] + input_format_tag = [pre_tags[0]] return input_format_tag, input_tags, output_tags diff --git a/atcodertools/codegen/code_generators/universal_code_generator.py b/atcodertools/codegen/code_generators/universal_code_generator.py index 3bbf6d5f..1d273b3c 100644 --- a/atcodertools/codegen/code_generators/universal_code_generator.py +++ b/atcodertools/codegen/code_generators/universal_code_generator.py @@ -52,24 +52,31 @@ def _insert_space_around_operators(self, code: str): def _global_declaration(self) -> str: lines = [] - for pattern in self._format.sequence: + for pattern in self._format[0].sequence: for var in pattern.all_vars(): self._append( lines, self.info["global_prefix"] + self._generate_declaration(var)) return "\n".join(lines) def generate_parameters(self) -> Dict[str, Any]: - if self._format is None: + if self._format[0] is None: return dict(prediction_success=False) + self._input_part_data = self._input_part(global_mode=False) + return dict(formal_arguments=self._formal_arguments(), actual_arguments=self._actual_arguments(), - input_part=self._input_part(global_mode=False), + input_part=self._input_part_data, global_declaration=self._global_declaration(), global_input_part=self._input_part(global_mode=True), + input_part_with_solve_function=self._input_part_with_solve_function(), prediction_success=True) def _input_part(self, global_mode): + t = len(self._format) - 1 + return self._get_input_part(global_mode, self._format[t]) + + def _get_input_part(self, global_mode, format): lines = [] newline_after_input = False if "newline_after_input" in self.info and self.info["newline_after_input"]: @@ -80,7 +87,7 @@ def _input_part(self, global_mode): lines.append(line) if newline_after_input: lines.append("") - for pattern in self._format.sequence: + for pattern in format.sequence: lines += self._render_pattern(pattern, global_mode) if newline_after_input: lines.append("") @@ -98,6 +105,31 @@ def _input_part(self, global_mode): start = False return result + def _input_part_with_solve_function(self): + prefix = "{indent}".format( + indent=self._indent(self.info["base_indent"])) + result = self._get_input_part(False, self._format[0]) + "\n" + solve_function = self.info["solve_function"].format( + actual_arguments=self._actual_arguments()) + if len(self._format) == 1: + result += prefix + solve_function + elif len(self._format) == 2: + result += prefix + result += self.info["loop"]["header"].format( + loop_var="case_index", + length="T" # TODO + ) + result += "\n" + t = (prefix + self._get_input_part(False, + self._format[1])).split("\n") + t = list(map(lambda x: self._indent(1) + x, t)) + t.append(prefix + self._indent(1) + solve_function) + result += "\n".join(t) + footer = self.info["loop"]["footer"].format() + if footer != '': + result += "\n" + prefix + footer + return result + def _convert_type(self, type_: Type) -> str: return self.info["type"][type_.value] @@ -143,8 +175,9 @@ def _actual_arguments(self) -> str: """ :return the string form of actual arguments e.g. "N, K, a" """ + t = len(self._format) - 1 ret = [] - for v in self._format.all_vars(): + for v in self._format[t].all_vars(): if v.dim_num() == 0: ret.append(v.name) else: @@ -160,7 +193,8 @@ def _formal_arguments(self): """ :return the string form of formal arguments e.g. "int N, int K, std::vector a" """ - return ", ".join([self._get_argument(v) for v in self._format.all_vars()]) + t = len(self._format) - 1 + return ", ".join([self._get_argument(v) for v in self._format[t].all_vars()]) def _generate_declaration(self, var: Variable): """ diff --git a/atcodertools/codegen/code_generators/universal_generator/cpp.toml b/atcodertools/codegen/code_generators/universal_generator/cpp.toml index 9dcd0555..bc29cfdb 100644 --- a/atcodertools/codegen/code_generators/universal_generator/cpp.toml +++ b/atcodertools/codegen/code_generators/universal_generator/cpp.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments});" # ループ [loop] @@ -65,4 +66,3 @@ int = "std::scanf(\"%lld\", &{name});" float = "std::scanf(\"%Lf\", &{name});" str = "std::cin >> {name};" - diff --git a/atcodertools/codegen/code_generators/universal_generator/cs.toml b/atcodertools/codegen/code_generators/universal_generator/cs.toml index 2e0bcf13..af558b0a 100644 --- a/atcodertools/codegen/code_generators/universal_generator/cs.toml +++ b/atcodertools/codegen/code_generators/universal_generator/cs.toml @@ -4,6 +4,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "new Program().Solve({actual_arguments});" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/d.toml b/atcodertools/codegen/code_generators/universal_generator/d.toml index 41d15c05..23de4d2d 100644 --- a/atcodertools/codegen/code_generators/universal_generator/d.toml +++ b/atcodertools/codegen/code_generators/universal_generator/d.toml @@ -5,6 +5,7 @@ insert_space_around_operators = false global_prefix = "" input_part_prefix = "auto input = stdin.byLine.map!split.joiner;" newline_after_input = true +solve_function = "solve({actual_arguments});" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/go.toml b/atcodertools/codegen/code_generators/universal_generator/go.toml index fb584b52..53974697 100644 --- a/atcodertools/codegen/code_generators/universal_generator/go.toml +++ b/atcodertools/codegen/code_generators/universal_generator/go.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/java.toml b/atcodertools/codegen/code_generators/universal_generator/java.toml index 0f5e683b..9e117a75 100644 --- a/atcodertools/codegen/code_generators/universal_generator/java.toml +++ b/atcodertools/codegen/code_generators/universal_generator/java.toml @@ -4,6 +4,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "static " input_part_prefix = "final Scanner sc = new Scanner(System.in);" +solve_function = "solve({actual_arguments});" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/julia.toml b/atcodertools/codegen/code_generators/universal_generator/julia.toml index 50a04328..3c993608 100644 --- a/atcodertools/codegen/code_generators/universal_generator/julia.toml +++ b/atcodertools/codegen/code_generators/universal_generator/julia.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # インデックス [index] diff --git a/atcodertools/codegen/code_generators/universal_generator/nim.toml b/atcodertools/codegen/code_generators/universal_generator/nim.toml index e28e59e9..6c1a37ff 100644 --- a/atcodertools/codegen/code_generators/universal_generator/nim.toml +++ b/atcodertools/codegen/code_generators/universal_generator/nim.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/python.toml b/atcodertools/codegen/code_generators/universal_generator/python.toml index a39fe77a..c7809326 100644 --- a/atcodertools/codegen/code_generators/universal_generator/python.toml +++ b/atcodertools/codegen/code_generators/universal_generator/python.toml @@ -3,6 +3,7 @@ insert_space_around_operators = true # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # インデックス [index] diff --git a/atcodertools/codegen/code_generators/universal_generator/rust.toml b/atcodertools/codegen/code_generators/universal_generator/rust.toml index fdd3c83e..1f37d9af 100644 --- a/atcodertools/codegen/code_generators/universal_generator/rust.toml +++ b/atcodertools/codegen/code_generators/universal_generator/rust.toml @@ -4,6 +4,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" input_part_prefix = "let con = read_string();\nlet mut scanner = Scanner::new(&con);" +solve_function = "solve({actual_arguments});" # ループ [loop] diff --git a/atcodertools/codegen/code_generators/universal_generator/swift.toml b/atcodertools/codegen/code_generators/universal_generator/swift.toml index 07c27ad7..0e27383e 100644 --- a/atcodertools/codegen/code_generators/universal_generator/swift.toml +++ b/atcodertools/codegen/code_generators/universal_generator/swift.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "_ = solve({actual_arguments})" # ループ [loop] diff --git a/atcodertools/fmtprediction/models/format_prediction_result.py b/atcodertools/fmtprediction/models/format_prediction_result.py index 60f016cc..110400a3 100644 --- a/atcodertools/fmtprediction/models/format_prediction_result.py +++ b/atcodertools/fmtprediction/models/format_prediction_result.py @@ -10,8 +10,7 @@ class FormatPredictionResult: def __init__(self, format_: Optional[Format[Variable]] = None): self.format = format_ - @classmethod - def create_typed_format(cls, simple_format: Format[SimpleVariable], var_to_type: Dict[str, Type]): + def _create_typed_format(self, simple_format: Format[SimpleVariable], var_to_type: Dict[str, Type]): var_to_info = {} for var in simple_format.all_vars(): assert var.name not in var_to_info @@ -29,6 +28,14 @@ def create_typed_format(cls, simple_format: Format[SimpleVariable], var_to_type: return FormatPredictionResult(fmt) + @classmethod + def create_typed_format(cls, simple_formats: list[Format[SimpleVariable]], var_to_type: Dict[str, Type]): + result = [] + for simple_format in simple_formats: + result.append(cls._create_typed_format( + cls, simple_format, var_to_type)) + return result + @classmethod def empty_result(cls): return FormatPredictionResult() diff --git a/atcodertools/fmtprediction/predict_format.py b/atcodertools/fmtprediction/predict_format.py index d53cdaef..133d3229 100644 --- a/atcodertools/fmtprediction/predict_format.py +++ b/atcodertools/fmtprediction/predict_format.py @@ -16,7 +16,7 @@ def __init__(self, cands): self.cands = cands -def predict_format(content: ProblemContent) -> FormatPredictionResult: +def predict_format(content: ProblemContent) -> list[FormatPredictionResult]: input_format = content.get_input_format() samples = content.get_samples() @@ -29,20 +29,45 @@ def predict_format(content: ProblemContent) -> FormatPredictionResult: except NoFormatFoundError: raise NoPredictionResultError + # tokenized_possible_formats = [input1, input2, ...] output_cands = [] - for format in tokenized_possible_formats: + + simple_formats = [] + for tokenized_possible_format in tokenized_possible_formats: + simple_format = [] for to_1d_flag in [False, True]: + for format in tokenized_possible_format: + try: + simple_format.append(predict_simple_format( + format.var_tokens, to_1d_flag)) + except (TypePredictionFailedError, SimpleFormatPredictionFailedError): + pass + simple_formats.append(simple_format) + + if len(simple_formats) == 1: + for a in simple_formats[0]: + simple_format = [a] try: - simple_format = predict_simple_format( - format.var_tokens, to_1d_flag) output_cands.append( - FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples))) + FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples, input_format.loop_length_var))) break except (TypePredictionFailedError, SimpleFormatPredictionFailedError): pass + elif len(simple_formats) == 2: + for a in simple_formats[0]: + for b in simple_formats[1]: + simple_format = [a, b] + try: + output_cands.append( + FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples, input_format.loop_length_var))) + break + except (TypePredictionFailedError, SimpleFormatPredictionFailedError): + pass - if len(output_cands) > 1: - raise MultiplePredictionResultsError(output_cands) + # TODO: ここをコメントアウトしたが大丈夫か? + # if len(output_cands) > 1: + # raise MultiplePredictionResultsError(output_cands) if len(output_cands) == 0: raise NoPredictionResultError + return output_cands[0] diff --git a/atcodertools/fmtprediction/predict_types.py b/atcodertools/fmtprediction/predict_types.py index d7bfdfb9..8ed41698 100644 --- a/atcodertools/fmtprediction/predict_types.py +++ b/atcodertools/fmtprediction/predict_types.py @@ -141,21 +141,58 @@ def merge_type_dicts(to_dict: Dict[str, Type], src_dict: Dict[str, Type]): return to_dict -def predict_types(simple_format: Format[SimpleVariable], samples: List[Sample]) -> Dict[str, Type]: +def predict_types(simple_format: list[Format[SimpleVariable]], samples: List[Sample], loop_length_var: str) -> Dict[str, Type]: res_type_dict = {} - for sample in samples: - token_manager = TokenManager(sample.get_input().split()) - predictor = TypePredictor(simple_format) - try: - while not token_manager.is_terminal(): - predictor.feed(token_manager.next()) - predictor.ensure_terminal() + if len(simple_format) == 1: + for sample in samples: + token_manager = TokenManager(sample.get_input().split()) + predictor = TypePredictor(simple_format[0]) + try: + while not token_manager.is_terminal(): + predictor.feed(token_manager.next()) + predictor.ensure_terminal() + res_type_dict = merge_type_dicts( + res_type_dict, + predictor.get_typing_result()) + except ( + TooLessFetchesError, TooManyFetchesError, KeyError, InvalidLoopSizeError, + InvalidLoopIndexError, EvaluateError): + raise TypePredictionFailedError + else: + for sample in samples: + token_manager = TokenManager(sample.get_input().split()) + predictor = TypePredictor(simple_format[0]) + loop_length_var = "T" + loop_length = -1 + while True: + try: + var = predictor._fetch() + s = token_manager.next() + if var.name == loop_length_var: + loop_length = int(s) + predictor._refresh(var, _convert_to_proper_type(s)) + except TooManyFetchesError: + break res_type_dict = merge_type_dicts( res_type_dict, predictor.get_typing_result()) - except ( - TooLessFetchesError, TooManyFetchesError, KeyError, InvalidLoopSizeError, - InvalidLoopIndexError, EvaluateError): - raise TypePredictionFailedError + assert loop_length >= 0 + try: + for ct in range(loop_length): + predictor = TypePredictor(simple_format[1]) + while True: + try: + var = predictor._fetch() + s = token_manager.next() + predictor._refresh(var, _convert_to_proper_type(s)) + except TooManyFetchesError: + break + res_type_dict = merge_type_dicts( + res_type_dict, + predictor.get_typing_result()) + except ( + TooLessFetchesError, TooManyFetchesError, KeyError, InvalidLoopSizeError, + InvalidLoopIndexError, EvaluateError): + raise TypePredictionFailedError return res_type_dict diff --git a/atcodertools/fmtprediction/tokenize_format.py b/atcodertools/fmtprediction/tokenize_format.py index 1d7e9ce8..01417294 100644 --- a/atcodertools/fmtprediction/tokenize_format.py +++ b/atcodertools/fmtprediction/tokenize_format.py @@ -5,6 +5,7 @@ from atcodertools.fmtprediction.models.variable_token import VariableToken, TokenizedFormat from atcodertools.fmtprediction.token_manager import TokenManager +from atcodertools.client.models.problem_content import InputFormat def _is_ascii(s): @@ -155,18 +156,23 @@ def check_if_possible(var_token: VariableToken): return [var_token for var_token in var_token_candidates if check_if_possible(var_token)] -def search_formats_with_minimum_vars(input_format: str) -> List[TokenizedFormat]: +def search_formats_with_minimum_vars(input_format: InputFormat) -> List[TokenizedFormat]: """ Fast enough for realistic instances. This method returns possible formats with the smallest number of variables. """ - tokens = _sanitized_tokens(input_format) - searcher = FormatSearcher(tokens) - for max_variable_length in range(1, 20): - result = searcher.search(max_variable_length) - if result: - return result - raise NoFormatFoundError + tokens = list(map(_sanitized_tokens, input_format.input_format)) + a = [] + for token in tokens: + searcher = FormatSearcher(token) + for max_variable_length in range(1, 20): + result = searcher.search(max_variable_length) + if result: + a.append(result) + break + else: + raise NoFormatFoundError + return a class NoFormatFoundError(Exception): diff --git a/atcodertools/tools/codegen.py b/atcodertools/tools/codegen.py index 098687e6..2f0b7784 100755 --- a/atcodertools/tools/codegen.py +++ b/atcodertools/tools/codegen.py @@ -111,7 +111,7 @@ def emit_info(text): output_file.write(code_generator( CodeGenArgs( template, - prediction_result.format, + list(map(lambda x: x.format, prediction_result)), constants, config.code_style_config ))) diff --git a/atcodertools/tools/envgen.py b/atcodertools/tools/envgen.py index 69fe76b6..8a2091c4 100755 --- a/atcodertools/tools/envgen.py +++ b/atcodertools/tools/envgen.py @@ -122,7 +122,7 @@ def emit_info(text): emit_info( with_color("Format prediction succeeded", Fore.LIGHTGREEN_EX)) except (NoPredictionResultError, MultiplePredictionResultsError) as e: - prediction_result = FormatPredictionResult.empty_result() + prediction_result = [FormatPredictionResult.empty_result()] if isinstance(e, NoPredictionResultError): msg = "No prediction -- Failed to understand the input format" else: @@ -134,10 +134,12 @@ def emit_info(text): with open(template_code_path, "r") as f: template = f.read() + prediction_result_format = list(map(lambda x: x.format, prediction_result)) + create_code(code_generator( CodeGenArgs( template, - prediction_result.format, + prediction_result_format, constants, config.code_style_config )), diff --git a/tests/resources/common/problem_htmls.tar.gz b/tests/resources/common/problem_htmls.tar.gz index 85b82abb..2bf41f08 100644 Binary files a/tests/resources/common/problem_htmls.tar.gz and b/tests/resources/common/problem_htmls.tar.gz differ diff --git a/tests/resources/common/test_data.tar.gz b/tests/resources/common/test_data.tar.gz index c45294ad..05401130 100644 Binary files a/tests/resources/common/test_data.tar.gz and b/tests/resources/common/test_data.tar.gz differ diff --git a/tests/resources/test_codegen/template_jinja_with_solve.cpp b/tests/resources/test_codegen/template_jinja_with_solve.cpp new file mode 100644 index 00000000..85643af2 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.cpp @@ -0,0 +1,37 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; + +{% if mod is not none %} +const int mod = {{ mod }}; +{% endif %} +{% if yes_str is not none %} +const string YES = "{{ yes_str }}"; +{% endif %} +{% if no_str is not none %} +const string NO = "{{ no_str }}"; +{% endif %} +void solve({{ formal_arguments }}){ + +} +int main(){ + {{input_part_with_solve_function}} + return 0; +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.cs b/tests/resources/test_codegen/template_jinja_with_solve.cs new file mode 100644 index 00000000..e17cafff --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.cs @@ -0,0 +1,55 @@ +using System; +using System.Text; +using System.Linq; +using System.Collections; +using System.Collections.Generic; +using static System.Console; +using static System.Math; + +public class Program{ + {% if mod %} + const long MOD = {{ mod }}; + {% endif %} + {% if yes_str %} + const string YES = "{{ yes_str }}"; + {% endif %} + {% if no_str %} + const string NO = "{{ no_str }}"; + {% endif %} + + public static void Main(string[] args){ + ConsoleInput cin = new ConsoleInput(Console.In, ' '); + {{ input_part_with_solve_function }} + } + + public void Solve({{ formal_arguments }}){ + + } +} + +public class ConsoleInput{ + private readonly System.IO.TextReader _stream; + private char _separator = ' '; + private Queue inputStream; + public ConsoleInput(System.IO.TextReader stream, char separator = ' '){ + this._separator = separator; + this._stream = stream; + inputStream = new Queue(); + } + public string Read{ + get{ + if (inputStream.Count != 0) return inputStream.Dequeue(); + string[] tmp = _stream.ReadLine().Split(_separator); + for (int i = 0; i < tmp.Length; ++i) + inputStream.Enqueue(tmp[i]); + return inputStream.Dequeue(); + } + } + public string ReadLine { get { return _stream.ReadLine(); } } + public int ReadInt { get { return int.Parse(Read); } } + public long ReadLong { get { return long.Parse(Read); } } + public double ReadDouble { get { return double.Parse(Read); } } + public string[] ReadStrArray(long N) { var ret = new string[N]; for (long i = 0; i < N; ++i) ret[i] = Read; return ret;} + public int[] ReadIntArray(long N) { var ret = new int[N]; for (long i = 0; i < N; ++i) ret[i] = ReadInt; return ret;} + public long[] ReadLongArray(long N) { var ret = new long[N]; for (long i = 0; i < N; ++i) ret[i] = ReadLong; return ret;} +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.d b/tests/resources/test_codegen/template_jinja_with_solve.d new file mode 100644 index 00000000..a050a702 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.d @@ -0,0 +1,32 @@ +{% if prediction_success %} +import std.algorithm; +import std.conv; +import std.stdio; +import std.string; +{% endif %} +{% if mod or yes_str or no_str %} + +{% endif %} +{% if mod %} +immutable long MOD = {{ mod }}; +{% endif %} +{% if yes_str %} +immutable string YES = "{{ yes_str }}"; +{% endif %} +{% if no_str %} +immutable string NO = "{{ no_str }}"; +{% endif %} +{% if prediction_success %} + +void solve({{ formal_arguments }}){ + +} + +{% endif %} +int main(){ + {% if prediction_success %} + {{ input_part_with_solve_function }} + {% else %} + {% endif %} + return 0; +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.go b/tests/resources/test_codegen/template_jinja_with_solve.go new file mode 100644 index 00000000..eba989b0 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.go @@ -0,0 +1,40 @@ +package main +{% if prediction_success %} + +import ( + "bufio" + "os" + "strconv" +) +{% endif %} +{% if mod or yes_str or no_str %} + +{% endif %} +{% if mod %} +const MOD = {{mod}} +{% endif %} +{% if yes_str %} +const YES = "{{ yes_str }}" +{% endif %} +{% if no_str %} +const NO = "{{ no_str }}" +{% endif %} +{% if prediction_success %} + +func solve({{ formal_arguments }}) { + +} +{% endif %} + +func main() { + {% if prediction_success %} + scanner := bufio.NewScanner(os.Stdin) + const initialBufSize = 4096 + const maxBufSize = 1000000 + scanner.Buffer(make([]byte, initialBufSize), maxBufSize) + scanner.Split(bufio.ScanWords) + {{ input_part_with_solve_function }} + {% else %} + // Failed to predict input format + {% endif %} +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.java b/tests/resources/test_codegen/template_jinja_with_solve.java new file mode 100644 index 00000000..1e237b62 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.java @@ -0,0 +1,21 @@ +import java.io.*; +import java.util.*; + +class Main { + {% if mod is not none %} + static final int mod = {{ mod }}; + {% endif %} + {% if yes_str is not none %} + static final String YES = "{{ yes_str }}"; + {% endif %} + {% if no_str is not none %} + static final String NO = "{{ no_str }}"; + {% endif %} + public static void main(String[] args) throws Exception { + {{ input_part_with_solve_function }} + } + + static void solve({{ formal_arguments }}){ + + } +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.jl b/tests/resources/test_codegen/template_jinja_with_solve.jl new file mode 100644 index 00000000..9deb6b42 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.jl @@ -0,0 +1,39 @@ +#!/usr/bin/env julia +{% if prediction_success %} +{% endif %} +{% if mod or yes_str or no_str %} +{% endif %} +{% if mod %} +const MOD = {{ mod }} +{% endif %} +{% if yes_str %} +const YES = "{{ yes_str }}" +{% endif %} +{% if no_str %} +const NO = "{{ no_str }}" +{% endif %} +{% if prediction_success %} + +function solve({{ formal_arguments }}) + +end +{% endif %} + +function main() + {% if prediction_success %} + tokens = Channel{String}(32) + Task() do + for line in eachline(@static VERSION < v"0.6" ? STDIN : stdin) + for token in split(chomp(line)) + put!(tokens, token) + end + end + close(tokens) + end |> schedule + {{ input_part_with_solve_function }} + {% else %} + # Failed to predict input format + {% endif %} +end + +isempty(ARGS) && main() diff --git a/tests/resources/test_codegen/template_jinja_with_solve.nim b/tests/resources/test_codegen/template_jinja_with_solve.nim new file mode 100644 index 00000000..7b26909d --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.nim @@ -0,0 +1,34 @@ +import sequtils +proc scanf(formatstr: cstring){.header: "", varargs.} +proc getchar(): char {.header: "", varargs.} +proc nextInt(): int = scanf("%lld",addr result) +proc nextFloat(): float = scanf("%lf",addr result) +proc nextString(): string = + var get = false + result = "" + while true: + var c = getchar() + if int(c) > int(' '): + get = true + result.add(c) + else: + if get: break + +{% if mod %} +let MOD = {{ mod }} +{% endif %} +{% if yes_str %} +let YES = "{{ yes_str }}" +{% endif %} +{% if no_str %} +let NO = "{{ no_str }}" +{% endif %} + +proc solve({{ formal_arguments }}):void = + return + +proc main():void = + {{input_part_with_solve_function}} + return + +main() diff --git a/tests/resources/test_codegen/template_jinja_with_solve.py b/tests/resources/test_codegen/template_jinja_with_solve.py new file mode 100644 index 00000000..1cc83ba3 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +{% if prediction_success %} +import sys +{% endif %} +{% if mod or yes_str or no_str %} + +{% endif %} +{% if mod %} +MOD = {{ mod }} # type: int +{% endif %} +{% if yes_str %} +YES = "{{ yes_str }}" # type: str +{% endif %} +{% if no_str %} +NO = "{{ no_str }}" # type: str +{% endif %} +{% if prediction_success %} + + +def solve({{ formal_arguments }}): + return +{% endif %} + + +def main(): + {% if prediction_success %} + def iterate_tokens(): + for line in sys.stdin: + for word in line.split(): + yield word + tokens = iterate_tokens() + {{ input_part_with_solve_function }} + {% else %} + # Failed to predict input format + pass + {% endif %} + +if __name__ == '__main__': + main() diff --git a/tests/resources/test_codegen/template_jinja_with_solve.rust b/tests/resources/test_codegen/template_jinja_with_solve.rust new file mode 100644 index 00000000..3d9c4ec6 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.rust @@ -0,0 +1,73 @@ +use std::*; + +{% if mod %} +const MOD: i64 = {{ mod }}; +{% endif %} +{% if yes_str %} +const YES: String = "{{ yes_str }}"; +{% endif %} +{% if no_str %} +const NO: String = "{{ no_str }}"; +{% endif %} +{% if prediction_success %} +fn solve({{ formal_arguments }}) { + +} +{% endif %} + +fn main() { + {% if prediction_success %} + {{input_part_with_solve_function}} + {% else %} + // Failed to predict input format + {% endif %} +} + +pub mod io { + use std; + use std::str::FromStr; + + pub struct Scanner<'a> { + iter: std::str::SplitWhitespace<'a>, + } + + impl<'a> Scanner<'a> { + pub fn new(s: &'a str) -> Scanner<'a> { + Scanner { + iter: s.split_whitespace(), + } + } + + pub fn next(&mut self) -> T { + let s = self.iter.next().unwrap(); + if let Ok(v) = s.parse::() { + v + } else { + panic!("Parse error") + } + } + + pub fn next_vec_len(&mut self) -> Vec { + let n: usize = self.next(); + self.next_vec(n) + } + + pub fn next_vec(&mut self, n: usize) -> Vec { + (0..n).map(|_| self.next()).collect() + } + } + + pub fn read_string() -> String { + use std::io::Read; + + let mut s = String::new(); + std::io::stdin().read_to_string(&mut s).unwrap(); + s + } + + pub fn read_line() -> String { + let mut s = String::new(); + std::io::stdin().read_line(&mut s).unwrap(); + s.trim_right().to_owned() + } +} diff --git a/tests/resources/test_codegen/template_jinja_with_solve.swift b/tests/resources/test_codegen/template_jinja_with_solve.swift new file mode 100644 index 00000000..e9a6c914 --- /dev/null +++ b/tests/resources/test_codegen/template_jinja_with_solve.swift @@ -0,0 +1,51 @@ +import Foundation + +{% if mod %} +let MOD = {{ mod }} +{% endif %} +{% if yes_str %} +let YES = "{{ yes_str }}" +{% endif %} +{% if no_str %} +let NO = "{{ no_str }}" +{% endif %} +{% if prediction_success %} + +func solve({{ formal_arguments }}) { + {% if yes_str %} + var ans = false + + print(ans ? YES : NO) + {% else %} + var ans = 0 + + print(ans) + {% endif %} +} +{% endif %} + +func main() { + {% if prediction_success %} + var tokenIndex = 0, tokenBuffer = [String]() + func readString() -> String { + if tokenIndex >= tokenBuffer.count { + tokenIndex = 0 + tokenBuffer = readLine()!.split(separator: " ").map { String($0) } + } + defer { tokenIndex += 1 } + return tokenBuffer[tokenIndex] + } + func readInt() -> Int { Int(readString())! } + func readDouble() -> Double { Double(readString())! } + {{input_part_with_solve_function}} + {% else %} + // Failed to predict input format + {% endif %} +} + +#if DEBUG +let caseNumber = 1 +_ = freopen("in_\(caseNumber).txt", "r", stdin) +#endif + +main() diff --git a/tests/resources/test_codegen/test_float_case/intermediate_format.txt b/tests/resources/test_codegen/test_float_case/intermediate_format.txt index 4720d4fd..d03b265d 100644 --- a/tests/resources/test_codegen/test_float_case/intermediate_format.txt +++ b/tests/resources/test_codegen/test_float_case/intermediate_format.txt @@ -1 +1 @@ -[(Singular: L),(Singular: N),(Singular: M),(Parallel: K | 1 to L),(Parallel: A,S | 1 to N)] \ No newline at end of file +[[(Singular: L),(Singular: N),(Singular: M),(Parallel: K | 1 to L),(Parallel: A,S | 1 to N)]] diff --git a/tests/resources/test_codegen/test_long_case/intermediate_format.txt b/tests/resources/test_codegen/test_long_case/intermediate_format.txt index 63ceb874..de8d6f9c 100644 --- a/tests/resources/test_codegen/test_long_case/intermediate_format.txt +++ b/tests/resources/test_codegen/test_long_case/intermediate_format.txt @@ -1 +1 @@ -[(Singular: H),(Singular: W),(Singular: K),(Singular: sr),(Singular: sc),(Parallel: s | 1 to H),(Singular: N),(Parallel: fr,fc,F,D | 1 to N)] \ No newline at end of file +[[(Singular: H),(Singular: W),(Singular: K),(Singular: sr),(Singular: sc),(Parallel: s | 1 to H),(Singular: N),(Parallel: fr,fc,F,D | 1 to N)]] diff --git a/tests/resources/test_codegen/test_mod_case/intermediate_format.txt b/tests/resources/test_codegen/test_mod_case/intermediate_format.txt index ede137d4..0f43b293 100644 --- a/tests/resources/test_codegen/test_mod_case/intermediate_format.txt +++ b/tests/resources/test_codegen/test_mod_case/intermediate_format.txt @@ -1 +1 @@ -[(Singular: A),(Singular: B)] \ No newline at end of file +[[(Singular: A),(Singular: B)]] diff --git a/tests/resources/test_codegen/test_two_dimensional_case/intermediate_format.txt b/tests/resources/test_codegen/test_two_dimensional_case/intermediate_format.txt index b9f3d04d..e88214f6 100644 --- a/tests/resources/test_codegen/test_two_dimensional_case/intermediate_format.txt +++ b/tests/resources/test_codegen/test_two_dimensional_case/intermediate_format.txt @@ -1 +1 @@ -[(Singular: H),(Singular: W),(TwoDimensional: c),(TwoDimensional: A)] \ No newline at end of file +[[(Singular: H),(Singular: W),(TwoDimensional: c),(TwoDimensional: A)]] diff --git a/tests/resources/test_codegen/test_yes_no_case/intermediate_format.txt b/tests/resources/test_codegen/test_yes_no_case/intermediate_format.txt index cf1b0f42..9d32fdc5 100644 --- a/tests/resources/test_codegen/test_yes_no_case/intermediate_format.txt +++ b/tests/resources/test_codegen/test_yes_no_case/intermediate_format.txt @@ -1 +1 @@ -[(Singular: N),(Singular: M),(Singular: A),(Singular: B)] \ No newline at end of file +[[(Singular: N),(Singular: M),(Singular: A),(Singular: B)]] diff --git a/tests/resources/test_config/test_custom_codegen_toml/nim.toml b/tests/resources/test_config/test_custom_codegen_toml/nim.toml index e28e59e9..6c1a37ff 100644 --- a/tests/resources/test_config/test_custom_codegen_toml/nim.toml +++ b/tests/resources/test_config/test_custom_codegen_toml/nim.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # ループ [loop] diff --git a/tests/resources/test_config/test_custom_codegen_toml/nim_custom.toml b/tests/resources/test_config/test_custom_codegen_toml/nim_custom.toml index 337ff18b..36aa1242 100644 --- a/tests/resources/test_config/test_custom_codegen_toml/nim_custom.toml +++ b/tests/resources/test_config/test_custom_codegen_toml/nim_custom.toml @@ -3,6 +3,7 @@ insert_space_around_operators = false # global変数宣言時の接頭辞 global_prefix = "" +solve_function = "solve({actual_arguments})" # ループ [loop] diff --git a/tests/resources/test_fmtprediction/answer.txt b/tests/resources/test_fmtprediction/answer.txt index d78b5939..605da8dc 100644 --- a/tests/resources/test_fmtprediction/answer.txt +++ b/tests/resources/test_fmtprediction/answer.txt @@ -472,6 +472,7 @@ abc115-A OK [(Singular: D)] [( abc115-B OK [(Singular: N),(Parallel: p | 1 to N)] [('N', ), ('p', )] abc115-C OK [(Singular: N),(Singular: K),(Parallel: h | 1 to N)] [('N', ), ('K', ), ('h', )] abc115-D OK [(Singular: N),(Singular: X)] [('N', ), ('X', )] +abc214-E OK [(Singular: T)] [('T', ), ('N', ), ('L', ), ('R', )] agc001-A OK [(Singular: N),(Parallel: L | 1 to 2*N)] [('N', ), ('L', )] agc001-B OK [(Singular: N),(Singular: X)] [('N', ), ('X', )] agc001-C OK [(Singular: N),(Singular: K),(Parallel: A,B | 1 to N-1)] [('N', ), ('K', ), ('A', ), ('B', )] diff --git a/tests/resources/test_setter/ans/main.java b/tests/resources/test_setter/ans/main.java index 658bc694..a243c118 100644 --- a/tests/resources/test_setter/ans/main.java +++ b/tests/resources/test_setter/ans/main.java @@ -3,7 +3,7 @@ class Main { - // Generated by 2.9.0 https://github.com/kyuridenamida/atcoder-tools (tips: You use the default template now. You can remove this line by using your custom template) + // Generated by 2.11.0 https://github.com/kyuridenamida/atcoder-tools (tips: You use the default template now. You can remove this line by using your custom template) public static void main(String[] args) throws Exception { final Scanner sc = new Scanner(System.in); long T; diff --git a/tests/resources/test_tester/config_common.toml b/tests/resources/test_tester/config_common.toml new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_atcoder_client_real.py b/tests/test_atcoder_client_real.py index b3903df5..5bb3c983 100644 --- a/tests/test_atcoder_client_real.py +++ b/tests/test_atcoder_client_real.py @@ -27,7 +27,8 @@ def test_submit_source_code(self): def test_download_problem_content(self): content = self.client.download_problem_content( Problem(Contest("arc002"), "C", "arc002_3")) - self.assertEqual("N\nc_{1}c_{2}...c_{N}\n", content.input_format_text) + self.assertEqual(["N \n c_{1}c_{2}...c_{N}\n"], + content.input_format_text) self.assertEqual(3, len(content.samples)) def test_login_failed(self): diff --git a/tests/test_codegen.py b/tests/test_codegen.py index 8eb186f9..51ddc8d2 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -35,12 +35,12 @@ def load_generated_code(py_test_name, lang): def load_intermediate_types(py_test_name): with open(os.path.join(RESOURCE_DIR, py_test_name, "intermediate_types.txt"), 'r') as f: - return f.read() + return f.read().strip() def load_intermediate_format(py_test_name): with open(os.path.join(RESOURCE_DIR, py_test_name, "intermediate_format.txt"), 'r') as f: - return f.read() + return f.read().strip() class TestCodeGenerator(unittest.TestCase): @@ -55,42 +55,52 @@ def setUp(self): CPP: { "old": "template.cpp", "jinja": "template_jinja.cpp", + "with_solve": "template_jinja_with_solve.cpp", }, JAVA: { "old": "template.java", "jinja": "template_jinja.java", + "with_solve": "template_jinja_with_solve.java", }, RUST: { "old": "template.rust", "jinja": "template_jinja.rust", + "with_solve": "template_jinja_with_solve.rust", }, PYTHON: { "old": "template.py", "jinja": "template_jinja.py", + "with_solve": "template_jinja_with_solve.py", }, NIM: { "old": "template.nim", "jinja": "template_jinja.nim", + "with_solve": "template_jinja_with_solve.nim", }, DLANG: { "old": "template.d", "jinja": "template_jinja.d", + "with_solve": "template_jinja_with_solve.d", }, CSHARP: { "old": "template.cs", "jinja": "template_jinja.cs", + "with_solve": "template_jinja_with_solve.cs", }, SWIFT: { "old": "template.swift", "jinja": "template_jinja.swift", + "with_solve": "template_jinja_with_solve.swift", }, GO: { "old": "template.go", "jinja": "template_jinja.go", + "with_solve": "template_jinja_with_solve.go", }, JULIA: { "old": "template.jl", "jinja": "template_jinja.jl", + "with_solve": "template_jinja_with_solve.jl", }, } self.lang_to_code_generator_func = { @@ -186,7 +196,7 @@ def _full_path(filename): self._compile_and_run( lang, - pred_result.format, + list(map(lambda x: x.format, pred_result)), lang.default_template_path, expected_default_generated_code_file, input_file @@ -196,7 +206,7 @@ def _full_path(filename): exec_result = self._compile_and_run( lang, - pred_result.format, + list(map(lambda x: x.format, pred_result)), _full_path(os.path.join( lang.name, lang.source_code_name("echo_template"))), _full_path(os.path.join(lang.name, lang.source_code_name( @@ -206,6 +216,17 @@ def _full_path(filename): self.assertEqual(load_text_file( expected_output_file), exec_result.output) + def test_with_solve_template(self): + # test_mod_caseがちゃんと展開されるか + response = self.runner.run('agc019-E') + for lang in ALL_LANGUAGES: + self.verify(response, "test_mod_case", + lang, "with_solve", ProblemConstantSet(mod=998244353)) + response = self.runner.run('abc214-E') + for lang in ALL_LANGUAGES: + self.verify(response, "test_multiple_testcases", + lang, "with_solve", ProblemConstantSet()) + def _compile_command(self, lang: Language, code_file: str): if lang == CPP: return "g++ {} -o a.out -std=c++14".format(code_file) @@ -299,8 +320,9 @@ def _compile_and_run(self, lang, format, template_file, expected_generated_code_ print(run_command(compile_cmd, self.temp_dir)) print("Run program:", [exec_file] + exec_args) + # TODO: なぜかjuliaだけ時間がかかる。。。 exec_result = run_program( - exec_file, input_file, 2, exec_args, self.temp_dir) + exec_file, input_file, 100, exec_args, self.temp_dir) finally: self._clean_up(lang) @@ -328,18 +350,21 @@ def verify(self, lang: Language, template_type: str = "old", constants: ProblemConstantSet = ProblemConstantSet()): + response_simple_format = "[" + \ + ",".join(list(map(str, response.simple_format))) + "]" self.assertEqual( load_intermediate_format(py_test_name), - str(response.simple_format)) + response_simple_format) self.assertEqual( load_intermediate_types(py_test_name), str(response.types)) + self.assertEqual( load_generated_code(py_test_name, lang), self.lang_to_code_generator_func[lang]( CodeGenArgs( self.get_template(lang, template_type), - response.original_result.format, + list(map(lambda x: x.format, response.original_result)), constants, CodeStyleConfig(lang=lang.name)) )) diff --git a/tests/test_config.py b/tests/test_config.py index ac06e07c..7b978497 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -113,7 +113,7 @@ def test_custom_codegen_toml(self): code = config.code_generator( CodeGenArgs( template, - response.original_result.format, + list(map(lambda x: x.format, response.original_result)), ProblemConstantSet(), config) ) diff --git a/tests/test_fmtprediction.py b/tests/test_fmtprediction.py index 900d45cd..bcd3d0ef 100644 --- a/tests/test_fmtprediction.py +++ b/tests/test_fmtprediction.py @@ -38,7 +38,9 @@ def test_overall(self): response = runner.run(case) if response.status == "OK": - output_text += "{:40} {:20} {} {}\n".format(case, response.status, response.simple_format, + # TODO: インデックス0しか見ていない。本当はansの方を配列に変える方が適切 + first_format = response.simple_format[0] + output_text += "{:40} {:20} {} {}\n".format(case, response.status, first_format, response.types) else: output_text += "{:40} {}\n".format(case, response.status) diff --git a/tests/test_tester.py b/tests/test_tester.py index 3356d731..dcc44783 100755 --- a/tests/test_tester.py +++ b/tests/test_tester.py @@ -23,48 +23,65 @@ class TestTester(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp() + self.common_config_path = os.path.join( + RESOURCE_DIR, "config_common.toml") def test_multiple_exec_files(self): all_ok = tester.main( - '', ['-d', os.path.join(RESOURCE_DIR, "test_multiple_exec_files")]) + '', ['-d', os.path.join(RESOURCE_DIR, "test_multiple_exec_files"), + "--config", self.common_config_path]) self.assertTrue(all_ok) def test_run_single_test(self): test_dir = os.path.join(RESOURCE_DIR, "test_run_single_test") - self.assertTrue(tester.main('', ['-d', test_dir, "-n", "1"])) - self.assertFalse(tester.main('', ['-d', test_dir, "-n", "2"])) + self.assertTrue(tester.main( + '', ['-d', test_dir, "-n", "1", "--config", self.common_config_path])) + self.assertFalse(tester.main( + '', ['-d', test_dir, "-n", "2", "--config", self.common_config_path])) def test_run_single_test_decimal_addition(self): test_dir = os.path.join( RESOURCE_DIR, "test_run_single_test_decimal_addition") self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute_or_relative"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute_or_relative", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "2", "-v", "0.01", "--judge-type", "absolute_or_relative"])) + '', ['-d', test_dir, "-n", "2", "-v", "0.01", "--judge-type", "absolute_or_relative", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "absolute"])) + '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "absolute", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "relative"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "relative", + "--config", self.common_config_path])) self.assertFalse(tester.main( - '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "relative"])) + '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "relative", + "--config", self.common_config_path])) def test_run_single_test_decimal_multiplication(self): test_dir = os.path.join( RESOURCE_DIR, "test_run_single_test_decimal_multiplication") self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute_or_relative"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute_or_relative", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "2", "--error-value", "0.01", "-j", "absolute_or_relative"])) + '', ['-d', test_dir, "-n", "2", "--error-value", "0.01", "-j", "absolute_or_relative", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "absolute", + "--config", self.common_config_path])) self.assertFalse(tester.main( - '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "absolute"])) + '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "absolute", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "relative"])) + '', ['-d', test_dir, "-n", "1", "-v", "0.01", "-j", "relative", + "--config", self.common_config_path])) self.assertTrue(tester.main( - '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "relative"])) + '', ['-d', test_dir, "-n", "2", "-v", "0.01", "-j", "relative", + "--config", self.common_config_path])) @patch('os.access', return_value=True) @patch('pathlib.Path.is_file', return_value=True) @@ -211,7 +228,9 @@ def test_compiler_and_tester(self): self.assertTrue(tester.main( '', ['-d', test_dir, "-n", "{:d}".format(i), "--compile-before-testing", "-j", "normal", - "--compile-command", "g++ main.cpp -o main && touch compile{}".format(i)])) + "--compile-command", "g++ main.cpp -o main && touch compile{}".format( + i), + "--config", self.common_config_path])) lst = os.listdir(test_dir) self.assertTrue("compile1" in lst) self.assertTrue("compile2" in lst) diff --git a/tests/utils/fmtprediction_test_runner.py b/tests/utils/fmtprediction_test_runner.py index 5df6e1b9..eec5d00b 100644 --- a/tests/utils/fmtprediction_test_runner.py +++ b/tests/utils/fmtprediction_test_runner.py @@ -14,9 +14,12 @@ def __init__(self, result: Optional[FormatPredictionResult], status): self.status = status if result: self.original_result = result - self.simple_format = result.format - var_info = [(var.name, var.type) - for var in result.format.all_vars()] + self.simple_format = list(map(lambda x: x.format, result)) + var_info = [] + for r in result: + var_info += [(var.name, var.type) + for var in r.format.all_vars()] + self.types = [(name, type.to_py_type()) for name, type in var_info] @@ -35,10 +38,15 @@ def load_problem_content(self, case_name: str) -> ProblemContent: case_dir = self._get_test_case_dir(case_name) format_file = os.path.join(case_dir, FORMAT_FILE_NAME) example_files = [os.path.join(case_dir, file) - for file in os.listdir(case_dir) if file != FORMAT_FILE_NAME] + for file in os.listdir(case_dir) if file.startswith("ex")] with open(format_file, 'r', encoding="utf-8") as f: - input_format = f.read() + input_format = [f.read()] + + second_format_file = os.path.join(case_dir, "format_2.txt") + if os.path.exists(second_format_file): + with open(second_format_file, 'r', encoding="utf-8") as f: + input_format.append(f.read()) examples = [] for ex_file in example_files: