Skip to content

Commit

Permalink
Sketched set api
Browse files Browse the repository at this point in the history
  • Loading branch information
JaumeAmoresDS committed Sep 29, 2024
1 parent 9ac2bbc commit 9a9b0c0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
22 changes: 22 additions & 0 deletions nbmodular/cell2func.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def __init__(
self.is_class = False
self.is_pipeline = False
self.pipeline_name_or_default = None
self.api = True
self.keep_original_in_documentation = False
super().__init__(
**kwargs,
)
Expand Down Expand Up @@ -1808,6 +1810,18 @@ def __init__(
default=None,
help="Restrict inputs to those specified.",
)
self.parser.add_argument(
"--api",
action="store_true",
default=None,
help="Export the function as a library (API) function.",
)
self.parser.add_argument(
"--keep-original-in-documentation",
action="store_true",
default=None,
help="Keep the original cell text in the documentation.",
)
self.add_opposite_actions()

def add_opposite_actions(self):
Expand Down Expand Up @@ -1926,6 +1940,10 @@ def set_value(self, attr, value, convert=True):
f"Cell processor has no attribute {attr}. Existing attributes are:\n{list_of_attributes}"
)

def set_api (self, value):
self.api = value
self.default_api = value

def debug_function(
self, call_history=None, idx=None, name=None, test=False, data=False, **kwargs
):
Expand Down Expand Up @@ -3474,6 +3492,10 @@ def set(self, line):
attr, value = values
self.processor.set_value(attr, value)

@line_magic
def keep_original(self, line):
self.processor.set_api (False)


# %% ../nbs/cell2func.ipynb 74
def load_ipython_extension(ipython):
Expand Down
59 changes: 47 additions & 12 deletions nbmodular/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def set_paths_nb_processor(
nb_processor: "NbMagicExporter | Bunch",
path: str | Path,
code_cells_path: str | Path = ".nbmodular",
changed_path: str | Path | None = None,
) -> None:
"""
Set the paths for the notebook processor.
Expand Down Expand Up @@ -278,14 +279,16 @@ def __init__(
log_level="INFO",
from_notebook=False,
restrict_inputs=False,
api=True,
):
nb = read_nb(path) if nb is None else nb
super().__init__(nb)
self.logger = create_or_get_logger(logger_name, log_level)
self.logger.info(f"Analyzing code from notebook {path}")
self.cell_processor = CellProcessor(
path=path, run=from_notebook, restrict_inputs=restrict_inputs
path=path, run=from_notebook, restrict_inputs=restrict_inputs, api=api
)
self.cell_processor.change_file_name = False
self.cell_processor.set_run_tests(False)
self.from_notebook = from_notebook

Expand All @@ -311,6 +314,16 @@ def cell(self, cell):
add_call=True,
is_class=command == "class",
)
elif len(source_lines) > 0 and source_lines[0].strip().startswith("%"):
line = source_lines[0]
source = "\n".join(source_lines[1:])
command, remaining_line = line.split()[0]
# function_name, kwargs = self.cell_processor.parse_signature(remaining_line)
if command == "keep_original":
self.cell_processor.api = False
elif command == "file_path":
self.cell_processor.set_file_path(remaining_line)
self.cell_processor.change_file_name = True


# %% ../nbs/export.ipynb 33
Expand Down Expand Up @@ -370,26 +383,38 @@ def __init__(
tab_size=4,
from_notebook=False,
restrict_inputs=False,
api=True,
):
nb = read_nb(path) if nb is None else nb
super().__init__(nb)
self.logger = create_or_get_logger(logger_name, log_level)
set_paths_nb_processor(self, path, code_cells_path=code_cells_path)
code_cells_file_name = (
self.file_name_without_extension
if code_cells_file_name is None
else code_cells_file_name
)

self.logger.info(f"Analyzing code from notebook {self.path}")
self.logger.info(f"Analyzing code from notebook {path}")
self.nb_magic_processor = NbMagicProcessor(
path,
nb=nb,
logger_name=logger_name,
log_level=log_level,
from_notebook=from_notebook,
restrict_inputs=restrict_inputs,
api=api,
)
set_paths_nb_processor(
self,
path,
code_cells_path=code_cells_path,
changed_path=(
self.nb_magic_processor.file_path
if self.nb_magic_processor.change_file_name
else None
),
)
code_cells_file_name = (
self.file_name_without_extension
if code_cells_file_name is None
else code_cells_file_name
)

NBProcessor(path, self.nb_magic_processor, rm_directives=False, nb=nb).process()

self.function_names = {}
Expand Down Expand Up @@ -434,8 +459,12 @@ def cell(self, cell):
source_lines = cell.source.splitlines() if cell.cell_type == "code" else []
is_test = False
cell_type = "original"

if len(source_lines) > 0 and source_lines[0].strip().startswith("%%"):
keep_original_in_documentation = False
if (
self.nb_magic_processor.api
and len(source_lines) > 0
and source_lines[0].strip().startswith("%%")
):
line = source_lines[0]
source = "\n".join(source_lines[1:])
to_export = False
Expand Down Expand Up @@ -470,7 +499,10 @@ def cell(self, cell):
code_cell = code_cells[idx]
self.logger.debug("code:")
self.logger.debug(f"{code_cell.code}valid: {code_cell.valid}")
if code_cell.valid:
keep_original_in_documentation = (
code_cell.keep_original_in_documentation
)
if code_cell.valid and code_cell.api:
source = code_cell.code
to_export = True
elif line.startswith("%%include") or line.startswith("%%class"):
Expand All @@ -491,7 +523,10 @@ def cell(self, cell):
self.cells.append(new_cell)
cell_type = "code"
else:
doc_source = source # doc_source does not include first line with %% (? to think about)
if keep_original_in_documentation:
doc_source = cell.source
else:
doc_source = source # doc_source does not include first line with %% (? to think about)
if is_test:
doc_source = transform_test_source_for_docs(
code_cell.code, idx, self.tab_size
Expand Down
4 changes: 3 additions & 1 deletion nbmodular/jupynbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from logging import warn
import sys
from pathlib import Path
from typing import Tuple, Optional
from typing import List, Tuple, Optional
import os
import glob
import warnings
Expand Down Expand Up @@ -229,6 +229,8 @@ def parse_argv_and_run_jupynbm(argv: List[str]):


def jupynbm_export_cli():
# import pdb
# pdb.set_trace()
parse_argv_and_run_jupynbm(sys.argv[1:])


Expand Down

0 comments on commit 9a9b0c0

Please sign in to comment.