diff --git a/DocumentUnderstanding/VGT/object_detection/create_grid_input.py b/DocumentUnderstanding/VGT/object_detection/create_grid_input.py
index f04d85e..c424a7f 100644
--- a/DocumentUnderstanding/VGT/object_detection/create_grid_input.py
+++ b/DocumentUnderstanding/VGT/object_detection/create_grid_input.py
@@ -134,7 +134,7 @@ def create_grid_dict(tokenizer, page_data):
     return grid
 
 
-def save_pkl_file(grid, output_dir, output_file, model="doclaynet"):
+def save_pkl_file(grid, output_dir, filename, model):
     """Save the grid dictionary as a pickle file.
 
     Parameters
@@ -153,18 +153,13 @@ def save_pkl_file(grid, output_dir, output_file, model="doclaynet"):
     -------
     None
     """
-    if model == "doclaynet" or model == "publaynet":
-        extension = "pdf.pkl"
+    #os.makedirs(output_dir, exist_ok=True)
+    if model == 'doclaynet':
+        with open(os.path.join(output_dir, f"{filename}.pkl"), 'wb') as f:
+            pickle.dump(grid, f)
     else:
-        extension = "pkl"
-
-    pkl_save_location = os.path.join(
-        output_dir,
-        f'{output_file}.{extension}')
-
-    with open(pkl_save_location, 'wb') as handle:
-        pickle.dump(grid, handle)
-
+        with open(os.path.join(output_dir, f"{filename}.pdf.pkl"), 'wb') as f:
+            pickle.dump(grid, f)
 
 def select_tokenizer(tokenizer):
     """Select the tokenizer to be used.
@@ -212,4 +207,4 @@ def select_tokenizer(tokenizer):
 
     for page in range(len(word_grid)):
         grid = create_grid_dict(tokenizer, word_grid[page])
-        save_pkl_file(grid, args.output, f"page_{page}", page, args.model)
+        save_pkl_file(grid, args.output, f"page_{page}", args.model)
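
For context, the revised save_pkl_file makes model a required argument, branches on it to pick the output extension, and the call site drops the stray page positional argument that previously shifted args.model into the wrong parameter. Below is a minimal, self-contained sketch of the revised behavior; the grid payload is a hypothetical stand-in for the dictionary returned by create_grid_dict, and tempfile is used only so the sketch runs anywhere:

import os
import pickle
import tempfile

def save_pkl_file(grid, output_dir, filename, model):
    """Save the grid dictionary as a pickle file (logic as revised in this diff)."""
    if model == 'doclaynet':
        # DocLayNet grids are written as <filename>.pkl ...
        with open(os.path.join(output_dir, f"{filename}.pkl"), 'wb') as f:
            pickle.dump(grid, f)
    else:
        # ... while any other model gets the <filename>.pdf.pkl extension.
        with open(os.path.join(output_dir, f"{filename}.pdf.pkl"), 'wb') as f:
            pickle.dump(grid, f)

# Hypothetical stand-in for a grid produced by create_grid_dict.
grid = {"input_ids": [101, 102], "bbox": [[0, 0, 10, 10], [10, 0, 20, 10]]}

with tempfile.TemporaryDirectory() as out_dir:
    save_pkl_file(grid, out_dir, "page_0", "doclaynet")
    save_pkl_file(grid, out_dir, "page_0", "publaynet")
    print(sorted(os.listdir(out_dir)))  # ['page_0.pdf.pkl', 'page_0.pkl']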