Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

テストケースが1ファイルにつき複数ある場合の対応 #255

Open
wants to merge 4 commits into
base: stable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions atcodertools/client/models/problem_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def normalize(content: str) -> str:
return content.strip().replace('\r', '') + "\n"


def normalize_soup(content) -> str:
for a in content.findAll('var'):
a.replace_with(' ' + a.text + ' ')
return normalize(content.text)


def is_japanese(ch):
# Thank you!
# http://minus9d.hatenablog.com/entry/2015/07/16/231608
Expand All @@ -34,26 +40,45 @@ class InputFormatDetectionError(Exception):
pass


class InputFormat:
def __init__(self, input_format: list[str]):
self.type = None
self.loop_length_var = None
self.input_format = input_format


def _strip_case_vars(s):
result = []
for line in s.split("\n"):
if line.find("case") == -1 and line.find("Case") == -1:
result.append(line)
return "\n".join(result)


class ProblemContent:

def __init__(self, input_format_text: Optional[str] = None,
samples: Optional[List[Sample]] = None,
original_html: Optional[str] = None,
):
self.samples = samples
self.input_format_text = input_format_text
if type(input_format_text) is str:
self.input_format_text = [input_format_text]
else:
self.input_format_text = input_format_text
self.original_html = original_html
self.input_format_data = InputFormat(self.input_format_text)

@classmethod
def from_html(cls, html: str):
res = ProblemContent(original_html=html)
soup = BeautifulSoup(html, "html.parser")
res.input_format_text, res.samples = res._extract_input_format_and_samples(
res.input_format_text, res.input_format_data, res.samples = res._extract_input_format_and_samples(
soup)
return res

def get_input_format(self) -> str:
return self.input_format_text
def get_input_format(self) -> list[str]:
return self.input_format_data

def get_samples(self) -> List[Sample]:
return self.samples
Expand Down Expand Up @@ -91,12 +116,17 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]:

if input_format_tag is None:
raise InputFormatDetectionError

input_format_text = normalize(input_format_tag.text)
input_format_text = list(
map(lambda x: normalize_soup(x), input_format_tag))
except AttributeError:
raise InputFormatDetectionError

return input_format_text, res
if len(input_format_text) == 2:
input_format_text[0] = _strip_case_vars(input_format_text[0])

input_format_data = InputFormat(input_format_text)

return input_format_text, input_format_data, res

@staticmethod
def _primary_strategy(soup):
Expand All @@ -114,17 +144,18 @@ def _primary_strategy(soup):
if section_title.startswith("入力例"):
input_tags.append(tag.find('pre'))
elif section_title.startswith("入力"):
input_format_tag = tag.find('pre')
input_format_tag = tag.findAll('pre')

if section_title.startswith("出力例"):
output_tags.append(tag.find('pre'))
return input_format_tag, input_tags, output_tags

# TODO: こっちのタイプはmulti caseに未対応!!!
@staticmethod
def _secondary_strategy(soup): # TODO: more descriptive name
pre_tags = soup.select('pre')
sample_tags = pre_tags[1:]
input_tags = sample_tags[0::2]
output_tags = sample_tags[1::2]
input_format_tag = pre_tags[0]
input_format_tag = [pre_tags[0]]
return input_format_tag, input_tags, output_tags
46 changes: 40 additions & 6 deletions atcodertools/codegen/code_generators/universal_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,31 @@ def _insert_space_around_operators(self, code: str):

def _global_declaration(self) -> str:
lines = []
for pattern in self._format.sequence:
for pattern in self._format[0].sequence:
for var in pattern.all_vars():
self._append(
lines, self.info["global_prefix"] + self._generate_declaration(var))
return "\n".join(lines)

def generate_parameters(self) -> Dict[str, Any]:
if self._format is None:
if self._format[0] is None:
return dict(prediction_success=False)

self._input_part_data = self._input_part(global_mode=False)

return dict(formal_arguments=self._formal_arguments(),
actual_arguments=self._actual_arguments(),
input_part=self._input_part(global_mode=False),
input_part=self._input_part_data,
global_declaration=self._global_declaration(),
global_input_part=self._input_part(global_mode=True),
input_part_with_solve_function=self._input_part_with_solve_function(),
prediction_success=True)

def _input_part(self, global_mode):
t = len(self._format) - 1
return self._get_input_part(global_mode, self._format[t])

def _get_input_part(self, global_mode, format):
lines = []
newline_after_input = False
if "newline_after_input" in self.info and self.info["newline_after_input"]:
Expand All @@ -80,7 +87,7 @@ def _input_part(self, global_mode):
lines.append(line)
if newline_after_input:
lines.append("")
for pattern in self._format.sequence:
for pattern in format.sequence:
lines += self._render_pattern(pattern, global_mode)
if newline_after_input:
lines.append("")
Expand All @@ -98,6 +105,31 @@ def _input_part(self, global_mode):
start = False
return result

def _input_part_with_solve_function(self):
prefix = "{indent}".format(
indent=self._indent(self.info["base_indent"]))
result = self._get_input_part(False, self._format[0]) + "\n"
solve_function = self.info["solve_function"].format(
actual_arguments=self._actual_arguments())
if len(self._format) == 1:
result += prefix + solve_function
elif len(self._format) == 2:
result += prefix
result += self.info["loop"]["header"].format(
loop_var="case_index",
length="T" # TODO
)
result += "\n"
t = (prefix + self._get_input_part(False,
self._format[1])).split("\n")
t = list(map(lambda x: self._indent(1) + x, t))
t.append(prefix + self._indent(1) + solve_function)
result += "\n".join(t)
footer = self.info["loop"]["footer"].format()
if footer != '':
result += "\n" + prefix + footer
return result

def _convert_type(self, type_: Type) -> str:
return self.info["type"][type_.value]

Expand Down Expand Up @@ -143,8 +175,9 @@ def _actual_arguments(self) -> str:
"""
:return the string form of actual arguments e.g. "N, K, a"
"""
t = len(self._format) - 1
ret = []
for v in self._format.all_vars():
for v in self._format[t].all_vars():
if v.dim_num() == 0:
ret.append(v.name)
else:
Expand All @@ -160,7 +193,8 @@ def _formal_arguments(self):
"""
:return the string form of formal arguments e.g. "int N, int K, std::vector<int> a"
"""
return ", ".join([self._get_argument(v) for v in self._format.all_vars()])
t = len(self._format) - 1
return ", ".join([self._get_argument(v) for v in self._format[t].all_vars()])

def _generate_declaration(self, var: Variable):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "solve({actual_arguments});"

# ループ
[loop]
Expand Down Expand Up @@ -65,4 +66,3 @@ int = "std::scanf(\"%lld\", &{name});"
float = "std::scanf(\"%Lf\", &{name});"
str = "std::cin >> {name};"


Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "new Program().Solve({actual_arguments});"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ insert_space_around_operators = false
global_prefix = ""
input_part_prefix = "auto input = stdin.byLine.map!split.joiner;"
newline_after_input = true
solve_function = "solve({actual_arguments});"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "solve({actual_arguments})"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ insert_space_around_operators = false
# global変数宣言時の接頭辞
global_prefix = "static "
input_part_prefix = "final Scanner sc = new Scanner(System.in);"
solve_function = "solve({actual_arguments});"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "solve({actual_arguments})"

# インデックス
[index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "solve({actual_arguments})"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = true

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "solve({actual_arguments})"

# インデックス
[index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ insert_space_around_operators = false
# global変数宣言時の接頭辞
global_prefix = ""
input_part_prefix = "let con = read_string();\nlet mut scanner = Scanner::new(&con);"
solve_function = "solve({actual_arguments});"

# ループ
[loop]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ insert_space_around_operators = false

# global変数宣言時の接頭辞
global_prefix = ""
solve_function = "_ = solve({actual_arguments})"

# ループ
[loop]
Expand Down
11 changes: 9 additions & 2 deletions atcodertools/fmtprediction/models/format_prediction_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class FormatPredictionResult:
def __init__(self, format_: Optional[Format[Variable]] = None):
self.format = format_

@classmethod
def create_typed_format(cls, simple_format: Format[SimpleVariable], var_to_type: Dict[str, Type]):
def _create_typed_format(self, simple_format: Format[SimpleVariable], var_to_type: Dict[str, Type]):
var_to_info = {}
for var in simple_format.all_vars():
assert var.name not in var_to_info
Expand All @@ -29,6 +28,14 @@ def create_typed_format(cls, simple_format: Format[SimpleVariable], var_to_type:

return FormatPredictionResult(fmt)

@classmethod
def create_typed_format(cls, simple_formats: list[Format[SimpleVariable]], var_to_type: Dict[str, Type]):
result = []
for simple_format in simple_formats:
result.append(cls._create_typed_format(
cls, simple_format, var_to_type))
return result

@classmethod
def empty_result(cls):
return FormatPredictionResult()
39 changes: 32 additions & 7 deletions atcodertools/fmtprediction/predict_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, cands):
self.cands = cands


def predict_format(content: ProblemContent) -> FormatPredictionResult:
def predict_format(content: ProblemContent) -> list[FormatPredictionResult]:
input_format = content.get_input_format()
samples = content.get_samples()

Expand All @@ -29,20 +29,45 @@ def predict_format(content: ProblemContent) -> FormatPredictionResult:
except NoFormatFoundError:
raise NoPredictionResultError

# tokenized_possible_formats = [input1, input2, ...]
output_cands = []
for format in tokenized_possible_formats:

simple_formats = []
for tokenized_possible_format in tokenized_possible_formats:
simple_format = []
for to_1d_flag in [False, True]:
for format in tokenized_possible_format:
try:
simple_format.append(predict_simple_format(
format.var_tokens, to_1d_flag))
except (TypePredictionFailedError, SimpleFormatPredictionFailedError):
pass
simple_formats.append(simple_format)

if len(simple_formats) == 1:
for a in simple_formats[0]:
simple_format = [a]
try:
simple_format = predict_simple_format(
format.var_tokens, to_1d_flag)
output_cands.append(
FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples)))
FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples, input_format.loop_length_var)))
break
except (TypePredictionFailedError, SimpleFormatPredictionFailedError):
pass
elif len(simple_formats) == 2:
for a in simple_formats[0]:
for b in simple_formats[1]:
simple_format = [a, b]
try:
output_cands.append(
FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples, input_format.loop_length_var)))
break
except (TypePredictionFailedError, SimpleFormatPredictionFailedError):
pass

if len(output_cands) > 1:
raise MultiplePredictionResultsError(output_cands)
# TODO: ここをコメントアウトしたが大丈夫か?
# if len(output_cands) > 1:
# raise MultiplePredictionResultsError(output_cands)
if len(output_cands) == 0:
raise NoPredictionResultError

return output_cands[0]
Loading