Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

バグ修正(format predictionの1d_flagのインデックス, dotsの除去が正常にできていない) #258

Open
wants to merge 10 commits into
base: stable
Choose a base branch
from
9 changes: 8 additions & 1 deletion atcodertools/client/models/problem_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def normalize(content: str) -> str:
return content.strip().replace('\r', '') + "\n"


def normalize_soup(content) -> str:
    """Return the normalized text of a parsed HTML node.

    `content` is any object exposing a ``.text`` attribute (the caller passes
    the tag that holds the input format); the text is handed to `normalize`.
    """
    # NOTE: padding every <var> element with surrounding spaces before taking
    # the text was tried here and is currently disabled.
    raw_text = content.text
    return normalize(raw_text)


def is_japanese(ch):
# Thank you!
# http://minus9d.hatenablog.com/entry/2015/07/16/231608
Expand Down Expand Up @@ -92,7 +98,8 @@ def _extract_input_format_and_samples(soup) -> Tuple[str, List[Sample]]:
if input_format_tag is None:
raise InputFormatDetectionError

input_format_text = normalize(input_format_tag.text)
# input_format_text = normalize(input_format_tag.text)
input_format_text = normalize_soup(input_format_tag)
except AttributeError:
raise InputFormatDetectionError

Expand Down
74 changes: 55 additions & 19 deletions atcodertools/fmtprediction/predict_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
search_formats_with_minimum_vars
from atcodertools.fmtprediction.predict_types import predict_types, TypePredictionFailedError
from atcodertools.fmtprediction.models.format_prediction_result import FormatPredictionResult
import re


class NoPredictionResultError(Exception):
Expand All @@ -16,33 +17,68 @@ def __init__(self, cands):
self.cands = cands


def suspect_single_string(input_format: str, samples):
    """Guess the variable name when the input is one string written per character.

    A format such as ``S_1S_2...S_N`` with samples that consist of a single
    token each is taken to describe a single string variable ``S``.

    Args:
        input_format: the raw input format text.
        samples: iterable of objects whose ``get_input()`` returns the sample
            input as a string.

    Returns:
        The part of the format before its first underscore, or ``None`` when
        any of the following holds: the format is not exactly one
        whitespace-separated token, some sample input is not exactly one
        token, the token has no underscore, or ``<name>_`` occurs fewer than
        two times in the token.
    """
    tokens = input_format.split()
    if len(tokens) != 1:
        return None
    token = tokens[0]

    # Every sample must also be a single token, otherwise the input cannot be
    # one plain string.
    if any(len(sample.get_input().split()) != 1 for sample in samples):
        return None

    underscore_at = token.find('_')
    if underscore_at == -1:
        return None

    # The prefix is arbitrary problem-statement text (it may contain
    # backslashes, braces, dots, ...), so it has to be counted literally;
    # treating it as a regular expression can raise or match unrelated text.
    indexed_prefix = token[:underscore_at + 1]
    if token.count(indexed_prefix) < 2:
        return None
    return token[:underscore_at]


def predict_format(content: ProblemContent) -> FormatPredictionResult:
input_format = content.get_input_format()
samples = content.get_samples()
input_format = input_format.replace('\'', 'prime')

if len(samples) == 0:
raise NoPredictionResultError

try:
tokenized_possible_formats = search_formats_with_minimum_vars(
input_format)
except NoFormatFoundError:
raise NoPredictionResultError
for ct in [0, 1]:
tokenized_possible_formats = []
if ct == 0:
try:
tokenized_possible_formats += search_formats_with_minimum_vars(
input_format)
except NoFormatFoundError:
continue
elif ct == 1:
input_format2 = suspect_single_string(input_format, samples)
if input_format2 is not None:
try:
tokenized_possible_formats += search_formats_with_minimum_vars(
input_format2)
except NoFormatFoundError:
raise NoPredictionResultError

output_cands = []
for format in tokenized_possible_formats:
for to_1d_flag in [False, True]:
output_cands = []
for format in tokenized_possible_formats:
try:
simple_format = predict_simple_format(
format.var_tokens, to_1d_flag)
output_cands.append(
FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples)))
break
simple_format_array = predict_simple_format(format.var_tokens)
except (TypePredictionFailedError, SimpleFormatPredictionFailedError):
pass
continue
for simple_format in simple_format_array:
try:
output_cands.append(
FormatPredictionResult.create_typed_format(simple_format, predict_types(simple_format, samples)))
break
except (TypePredictionFailedError, SimpleFormatPredictionFailedError):
pass

if len(output_cands) > 1:
raise MultiplePredictionResultsError(output_cands)
if len(output_cands) == 0:
raise NoPredictionResultError
return output_cands[0]
if len(output_cands) == 1:
return output_cands[0]
elif len(output_cands) > 1:
raise MultiplePredictionResultsError(output_cands)
elif ct == 0:
continue
else:
raise NoPredictionResultError
94 changes: 84 additions & 10 deletions atcodertools/fmtprediction/predict_simple_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def _predict_period(seq: List[int]):
return 1


def _predict_simple_format_main(var_tokens: List[VariableToken], to_1d_flag=False) -> Format[SimpleVariable]:
def _predict_simple_format_main(var_tokens: List[VariableToken], to_1d_flag) -> Format[SimpleVariable]:
var_to_positions = {}
var_to_simple_var = OrderedDict()

Expand Down Expand Up @@ -62,10 +62,14 @@ def _predict_simple_format_main(var_tokens: List[VariableToken], to_1d_flag=Fals

dim = var_token.dim_num()

if dim == 2 and to_1d_flag:
simple_var.first_index = simple_var.second_index
simple_var.second_index = None
dim = 1
if pos in to_1d_flag:
if dim == 2:
# simple_var.first_index = simple_var.second_index
simple_var.second_index = None
dim = 1
elif dim == 1:
simple_var.first_index = None
dim = 0

if dim == 0:
root.push_back(SingularPattern(simple_var))
Expand All @@ -88,8 +92,78 @@ def _predict_simple_format_main(var_tokens: List[VariableToken], to_1d_flag=Fals
return root


def predict_simple_format(var_tokens: List[VariableToken], to_1d_flag=False) -> Format[SimpleVariable]:
try:
return _predict_simple_format_main(var_tokens, to_1d_flag)
except (WrongGroupingError, UnknownPeriodError):
raise SimpleFormatPredictionFailedError
def _predict_1d_flag_pos(var_tokens: List[VariableToken]) -> List[int]:
    """Collect the token positions that are candidates for dimension reduction.

    Args:
        var_tokens: the tokenized variables of the input format, in order.

    Returns:
        Indices into ``var_tokens``. A position is reported for the first
        occurrence of every 2-dimensional variable, and for the first
        occurrence of every 1-dimensional variable whose period (as given by
        `_predict_period`) is 1.

    Raises:
        NotImplementedError: if a token has a dimension other than 0, 1 or 2.
    """
    var_to_positions = {}
    var_to_simple_var = OrderedDict()

    # Pre-computation of the min / max value of each of the first and second
    # indices. Only the variable names and positions are read below; the
    # index ranges themselves are not used by this function.
    for pos, var_token in enumerate(var_tokens):
        var_name = var_token.var_name

        if var_name not in var_to_simple_var:
            var_to_simple_var[var_name] = SimpleVariable.create(
                var_name, var_token.dim_num())
            var_to_positions[var_name] = []

        var_to_positions[var_name].append(pos)

        if var_token.dim_num() >= 2:
            var_to_simple_var[var_name].second_index.update(
                var_token.second_index)
        if var_token.dim_num() >= 1:
            var_to_simple_var[var_name].first_index.update(
                var_token.first_index)

    already_processed_vars = set()
    result = []
    for pos, var_token in enumerate(var_tokens):
        var_name = var_token.var_name

        if var_name in already_processed_vars:
            continue

        dim = var_token.dim_num()

        if dim == 0:
            pass
        elif dim == 1:
            try:
                period = _predict_period(var_to_positions[var_name])
            except UnknownPeriodError:
                # The variable is skipped without being marked as processed,
                # so a later occurrence of the same name is examined again.
                continue
            if period == 1:
                result.append(pos)
            # All variables appearing within one period are handled together
            # and must not be visited again.
            parallel_vars_group = [var_to_simple_var[token.var_name]
                                   for token in var_tokens[pos:pos + period]]
            for var in parallel_vars_group:
                already_processed_vars.add(var.name)
        elif dim == 2:
            result.append(pos)
        else:
            raise NotImplementedError
        already_processed_vars.add(var_name)
    return result


def predict_simple_format(var_tokens: List[VariableToken]) -> List[Format[SimpleVariable]]:
    """Build every simple format obtainable by choosing a subset of flag positions.

    The positions returned by `_predict_1d_flag_pos` are combined in all
    possible subsets; each subset is passed to `_predict_simple_format_main`
    as ``to_1d_flag``.

    Args:
        var_tokens: the tokenized variables of the input format, in order.

    Returns:
        The formats that could be built, ordered by the subset bitmask
        (empty subset first). Subsets for which the builder raises
        `WrongGroupingError` or `UnknownPeriodError` are left out, so the
        list may be empty; no exception is raised in that case.
    """
    flag_pos = _predict_1d_flag_pos(var_tokens)
    candidates = []
    # Bit `bit` of `mask` decides whether flag_pos[bit] is part of the subset.
    for mask in range(1 << len(flag_pos)):
        to_1d_flag = {pos for bit, pos in enumerate(flag_pos)
                      if mask & (1 << bit)}
        try:
            candidates.append(
                _predict_simple_format_main(var_tokens, to_1d_flag))
        except (WrongGroupingError, UnknownPeriodError):
            continue
    return candidates
6 changes: 5 additions & 1 deletion atcodertools/fmtprediction/tokenize_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ def _is_ascii(s):


DOTS_PATTERNS = ["ldots", "cdots", "vdots", "ddots", "dots"]
SPACE_PATTERNS = ["hspace", "vspace"]


def _is_noise(s):
if any(pattern in s for pattern in DOTS_PATTERNS):
return True
if any(pattern in s for pattern in SPACE_PATTERNS):
return True

return s == ":" or s == "...." or s == "..." or s == ".." or s == "."

Expand Down Expand Up @@ -67,7 +70,8 @@ def _remove_spaces_in_curly_brackets(input_format):

def _sanitized_tokens(input_format: str) -> List[str]:
input_format = input_format.replace("\n", " ").replace("…", " ").replace("...", " ").replace(
"..", " ").replace("\\ ", " ").replace("}", "} ").replace(" ", " ").replace(", ", ",")
"..", " ").replace("‥", " ").replace("\\ ", " ").replace("}", "} ").replace(" ", " ").replace(", ", ",")
input_format = input_format.replace(" _ ", "_") # 空白の添字を削除
input_format = _remove_spaces_in_curly_brackets(input_format)
input_format = _divide_consecutive_vars(input_format)
input_format = _normalize_index(input_format)
Expand Down
10 changes: 9 additions & 1 deletion atcodertools/tools/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,22 @@ def main(prog, args, credential_supplier=None, use_local_session_cache=True, cli
return False
code_path = config.submit_config.submit_filename
logger.info(f"changed to submitfile: {code_path}")


recognized = False
for encoding in ['utf8', 'utf-8_sig', 'cp932']:
try:
with open(os.path.join(args.dir, code_path), 'r', encoding=encoding) as f:
source = f.read()
recognized = True
break
except UnicodeDecodeError:
logger.warning("code wasn't recognized as {}".format(encoding))

if not recognized:
import urllib.parse
with open(os.path.join(args.dir, code_path), 'rb') as f:
source = urllib.parse.quote(f.read())

logger.info(
"Submitting {} as {}".format(code_path, metadata.lang.name))
submission = client.submit_source_code(
Expand Down
7 changes: 6 additions & 1 deletion tests/resources/common/download_htmls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def mkdirs(path):
htmls_dir = "./problem_htmls/"
mkdirs(htmls_dir)
for contest in atcoder.download_all_contests():
for problem in atcoder.download_problem_list(contest):
try:
d = atcoder.download_problem_list(contest)
except Exception as e:
print("Failed to download problem list ")
continue
for problem in d:
html_path = os.path.join(htmls_dir, "{contest}-{problem_id}.html".format(
contest=contest.get_id(), problem_id=problem.get_alphabet()))

Expand Down
11 changes: 10 additions & 1 deletion tests/resources/common/download_testcases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import errno
import os
import time

from atcodertools.client.atcoder import AtCoderClient
from atcodertools.client.models.problem_content import SampleDetectionError, InputFormatDetectionError
Expand All @@ -24,7 +25,15 @@ def mkdirs(path):

if __name__ == "__main__":
for contest in atcoder.download_all_contests():
for problem in atcoder.download_problem_list(contest):
if contest.get_id().startswith("asprocon") or contest.get_id().startswith("future-meets-you-contest"):
continue
try:
d = atcoder.download_problem_list(contest)
except Exception:
print("download problem list error for {}".format(contest.get_id()))
time.sleep(120)
continue
for problem in d:
path = "./test_data/{contest}-{problem_id}".format(contest=contest.get_id(),
problem_id=problem.get_alphabet())
if os.path.exists(path) and len(os.listdir(path)) != 0:
Expand Down
Binary file modified tests/resources/common/problem_htmls.tar.gz
Binary file not shown.
Binary file modified tests/resources/common/test_data.tar.gz
Binary file not shown.
Loading