diff --git a/sparv/core/Snakefile b/sparv/core/Snakefile index 3f44ba55..3e74c377 100644 --- a/sparv/core/Snakefile +++ b/sparv/core/Snakefile @@ -31,7 +31,11 @@ selected_targets = config.get("targets", []) # Explicitly selected rule names # ============================================================================== def make_rules(config_missing: bool) -> None: - """Load all Sparv modules and create Snakemake rules.""" + """Load all Sparv modules and create Snakemake rules. + + Args: + config_missing: Whether the config file is missing. + """ # Get preloader info if config.get("socket") and not config.get("preloader"): from sparv.core import preload @@ -68,7 +72,15 @@ def make_rules(config_missing: bool) -> None: def make_rule(module_name: str, f_name: str, annotator_info: dict, config_missing: bool = False, custom_rule_obj: dict = None) -> None: - """Create single Snakemake rule.""" + """Create single Snakemake rule. + + Args: + module_name: Name of the module. + f_name: Name of the function. + annotator_info: Information about the annotator. + config_missing: Whether the config file is missing. + custom_rule_obj: Custom rule object. + """ # Init rule storage rule_storage = snake_utils.RuleStorage(module_name, f_name, annotator_info) @@ -113,7 +125,11 @@ def make_rule(module_name: str, f_name: str, annotator_info: dict, config_missin def make_all_files_rule(rule_storage: snake_utils.RuleStorage) -> None: - """Create named rule to run an annotation on all input files.""" + """Create named rule to run an annotation on all input files. + + Args: + rule_storage: Rule storage object. + """ # Only create rule when explicitly called if config.get("run_by_sparv") and rule_storage.target_name not in selected_targets: return diff --git a/sparv/core/config.py b/sparv/core/config.py index e0a68d27..159894d0 100644 --- a/sparv/core/config.py +++ b/sparv/core/config.py @@ -53,7 +53,17 @@ class Unset: def read_yaml(yaml_file: str | Path) -> dict: - """Read YAML file and handle errors.""" + """Read YAML file and handle errors. + + Args: + yaml_file: Path to YAML file. + + Returns: + Dictionary with parsed YAML data. + + Raises: + SparvErrorMessage: If the config can't be parsed or read. + """ # Handle dates as strings yaml.constructor.SafeConstructor.yaml_constructors["tag:yaml.org,2002:timestamp"] = ( yaml.constructor.SafeConstructor.yaml_constructors["tag:yaml.org,2002:str"] @@ -77,6 +87,9 @@ def load_config(config_file: str | None, config_dict: dict | None = None) -> Non Args: config_file: Path to corpus config file. If None, only the default config is read. config_dict: Get corpus config from dictionary instead of config file. + + Raises: + SparvErrorMessage: If the config can't be parsed. """ assert not (config_file and config_dict), "config_file and config_dict can not be used together" # Read default config @@ -93,7 +106,15 @@ def load_config(config_file: str | None, config_dict: dict | None = None) -> Non _config_user = read_yaml(config_file) or {} def handle_parents(cfg: dict, current_dir: Path = Path()) -> dict: - """Combine parent configs recursively.""" + """Combine parent configs recursively. + + Args: + cfg: Config dictionary. + current_dir: Current directory. + + Returns: + Combined config. + """ combined_parents = {} if cfg.get(PARENT): parents = cfg[PARENT] @@ -130,14 +151,29 @@ def handle_parents(cfg: dict, current_dir: Path = Path()) -> dict: def _get(name: str, config_dict: dict | None = None) -> Any: - """Try to get value from config, raising an exception if key doesn't exist.""" + """Try to get value from config, raising an exception if key doesn't exist. + + Args: + name: Config key to look up. + config_dict: Dictionary to look up key in. If None, the global config is used. + + Returns: + The value of the config key. If the key is not found, a KeyError is raised. + """ config_dict = config_dict if config_dict is not None else config # Handle dot notation return reduce(lambda c, k: c[k], name.split("."), config_dict) def set_value(name: str, value: Any, overwrite: bool = True, config_dict: dict | None = None) -> None: - """Set value in config, possibly using dot notation.""" + """Set value in config, possibly using dot notation. + + Args: + name: Config key to set. + value: Value to set. + overwrite: If False, only set value if key doesn't exist. + config_dict: Dictionary to set key in. If None, the global config is used. + """ keys = name.split(".") prev = config_dict if config_dict is not None else config for key in keys[:-1]: @@ -150,7 +186,15 @@ def set_value(name: str, value: Any, overwrite: bool = True, config_dict: dict | def get(name: str, default: Any = None) -> Any: - """Get value from config, or return the supplied 'default' if key doesn't exist.""" + """Get value from config, or return the supplied 'default' if key doesn't exist. + + Args: + name: Config key to look up. + default: Value to return if key doesn't exist. + + Returns: + The value of the config key, or the default value if the key is not found. + """ try: return _get(name) except KeyError: @@ -158,8 +202,12 @@ def get(name: str, default: Any = None) -> Any: def set_default(name: str, default: Any = None) -> None: - """Set default value for config variable.""" - # If config variable is already set to None but we get a better default value, replace the existing + """Set config value to default if key is not already set, or if it is set to None. + + Args: + name: Config key. + default: Value to set if key is not already set. + """ if default is not None: try: if _get(name) is None: @@ -171,12 +219,20 @@ def set_default(name: str, default: Any = None) -> None: def extend_config(new_config: dict) -> None: - """Extend existing config with new values for missing keys.""" + """Extend existing config with new values for missing keys. + + Args: + new_config: Dictionary with new config values. + """ _merge_dicts(config, new_config) def update_config(new_config: dict) -> None: - """Update existing config with new values, replacing existing values.""" + """Update existing config with new values, replacing existing values. + + Args: + new_config: Dictionary with new config values. + """ _merge_dicts_replace(config, new_config) @@ -218,7 +274,12 @@ def _merge_dicts_replace(d: dict, new_dict: dict) -> None: def add_to_structure(cfg: Config, annotator: str | None = None) -> None: - """Add config variable to config structure.""" + """Add config variable to config structure. + + Args: + cfg: Config object to add. + annotator: Name of annotator using the config. + """ set_value( cfg.name, { @@ -233,23 +294,46 @@ def add_to_structure(cfg: Config, annotator: str | None = None) -> None: def get_config_description(name: str) -> str | None: - """Get description for config key.""" + """Get description for config key. + + Args: + name: Config key. + + Returns: + Description of the config key. + """ cfg = _get(name, config_structure).get("_cfg") return cfg.description if cfg else None def get_config_object(name: str) -> Config | None: - """Get original Config object for config key.""" + """Get original Config object for config key. + + Args: + name: Config key. + + Returns: + Config object for the config key. + """ return _get(name, config_structure).get("_cfg") def add_config_usage(config_key: str, annotator: str) -> None: - """Add an annotator to the list of annotators that are using a given config key.""" + """Add an annotator to the list of annotators that are using a given config key. + + Args: + config_key: Config key. + annotator: Name of annotator using the config key. + """ config_usage[config_key].add(annotator) def validate_module_config() -> None: - """Make sure that modules don't try to access undeclared config keys.""" + """Make sure that modules don't try to access undeclared config keys. + + Raises: + SparvErrorMessage: If an annotator tries to access a config key that isn't declared anywhere. + """ for config_key in config_usage: try: _get(config_key, config_structure) @@ -262,7 +346,15 @@ def validate_module_config() -> None: def load_presets(lang: str, lang_variety: str | None) -> dict: - """Read presets files and return dictionaries with all available presets annotations and preset classes.""" + """Read presets files and return dictionaries with all available preset annotations and preset classes. + + Args: + lang: Language code. + lang_variety: Language variety. + + Returns: + Dictionary with all available preset annotations and preset classes. + """ class_dict = {} full_lang = lang if lang_variety: @@ -320,7 +412,7 @@ def resolve_presets(annotations: list[str], class_dict: dict) -> tuple[list[str] def apply_presets() -> None: - """Resolve annotations from presets and set preset classes.""" + """Resolve annotations from presets in all annotation lists, and set preset classes.""" # Load annotation presets and classes class_dict = load_presets(get("metadata.language"), get("metadata.variety")) all_preset_classes = {} @@ -345,7 +437,11 @@ def apply_presets() -> None: def handle_text_annotation() -> None: - """Copy text annotation to text class.""" + """Copy text annotation to text class. + + Raises: + SparvErrorMessage: If classes.text and import.text_annotation have different values. + """ text_ann = get("import.text_annotation") # Make sure that if both classes.text and import.text_annotation are set, that they have the same value @@ -360,7 +456,7 @@ def handle_text_annotation() -> None: def inherit_config(source: str, target: str) -> None: - """Let 'target' inherit config values from 'source' for evey key that is supported and not already populated. + """Let 'target' inherit config values from 'source' for every key that is supported and not already populated. Only keys which are either missing or with a value of None in the target will inherit the source's value, meaning that falsy values like empty strings or lists will not be overwritten. diff --git a/sparv/core/io.py b/sparv/core/io.py index 7b4d8ec4..121e3257 100644 --- a/sparv/core/io.py +++ b/sparv/core/io.py @@ -39,7 +39,15 @@ def annotation_exists(annotation: BaseAnnotation, source_file: str | None = None) -> bool: - """Check if an annotation file exists.""" + """Check if an annotation file exists. + + Args: + annotation: Annotation object to check. + source_file: Related source file. + + Returns: + True if the annotation file exists, False otherwise. + """ annotation_path = get_annotation_path(source_file or annotation.source_file, annotation, data=annotation.data) return annotation_path.exists() @@ -57,7 +65,10 @@ def write_annotation( ) -> None: """Write an annotation to one or more files. The file is overwritten if it exists. - The annotation should be a list of values. + Args: + source_file: Source filename. + annotation: Annotation object. + values: List of values to write. """ annotations = annotation.name.split() @@ -82,13 +93,18 @@ def write_annotation( annotation_values[annotation_name], annotation.root) -def _write_single_annotation( - source_file: str, - annotation: str, - values: list, - root: Path -) -> None: - """Write an annotation to a file.""" +def _write_single_annotation(source_file: str, annotation: str, values: list, root: Path) -> None: + """Write an annotation to a file. + + Args: + source_file: Source filename. + annotation: Annotation name. + values: List of values to write. + root: Root directory for the annotation. + + Raises: + SparvErrorMessage: If annotation spans are not sorted. + """ is_span = not split_annotation(annotation)[1] if is_span: @@ -117,7 +133,15 @@ def _write_single_annotation( def get_annotation_size(source_file: str, annotation: BaseAnnotation) -> int: - """Return number of lines in an annotation.""" + """Return number of lines in an annotation. + + Args: + source_file: Source filename. + annotation: Annotation object. + + Returns: + Number of values in the annotation. + """ count = 0 for ann in annotation.name.split(): @@ -133,7 +157,17 @@ def read_annotation_spans( decimals: bool = False, with_annotation_name: bool = False ) -> Iterator[tuple]: - """Iterate over the spans of an annotation.""" + """Iterate over the spans of an annotation. + + Args: + source_file: Source filename. + annotation: Annotation object. + decimals: Whether to return spans as decimals or not. Defaults to False. + with_annotation_name: Whether to yield the annotation name along with the value. + + Yields: + The annotation spans. If with_annotation_name is True, yields a tuple with the value and the annotation name. + """ # Strip any annotation attributes for span in read_annotation(source_file, annotation, with_annotation_name, spans=True): if not decimals: @@ -148,7 +182,17 @@ def read_annotation( with_annotation_name: bool = False, spans: bool = False ) -> Iterator: - """Yield each line from an annotation file.""" + """Yield each line from an annotation file. + + Args: + source_file: Source filename. + annotation: Annotation object. + with_annotation_name: Whether to yield the annotation name along with the value. + spans: Whether to read annotation spans or regular values. + + Yields: + The annotation values. If with_annotation_name is True, yields a tuple with the value and the annotation name. + """ annotations = [split_annotation(ann)[0] for ann in annotation.name.split()] if spans else annotation.name.split() root = annotation.root if len(annotations) == 1: @@ -177,7 +221,16 @@ def read_annotation( def read_annotation_attributes(source_file: str, annotations: list[BaseAnnotation] | tuple[BaseAnnotation, ...], with_annotation_name: bool = False) -> Iterator[tuple]: - """Yield tuples of multiple attributes on the same annotation.""" + """Yield tuples of multiple attributes on the same annotation. + + Args: + source_file: Source filename. + annotations: List of annotation objects. + with_annotation_name: Whether to yield the annotation name along with the value. + + Returns: + An iterator of tuples with the values of the attributes. + """ assert isinstance(annotations, (tuple, list)), "'annotations' argument must be tuple or list" assert len({split_annotation(annotation)[0] for annotation in annotations}) == 1, "All attributes need to be for the same annotation" @@ -191,7 +244,17 @@ def _read_single_annotation( with_annotation_name: bool, root: Path | None = None ) -> Iterator[Any]: - """Read a single annotation file.""" + """Read a single annotation file and yield each value. + + Args: + source_file: Source filename. + annotation: Annotation name. + with_annotation_name: Whether to yield the annotation name along with the value. + root: Root path. + + Yields: + The annotation values. If with_annotation_name is True, yields a tuple with the value and the annotation name. + """ ann_file = get_annotation_path(source_file, annotation, root) ctr = 0 @@ -202,7 +265,13 @@ def _read_single_annotation( def write_data(source_file: str | None, name: BaseAnnotation | str, value: Any) -> None: - """Write arbitrary data to file in workdir directory.""" + """Write arbitrary data to file in workdir directory. + + Args: + source_file: Source filename. + name: Annotation object or name. + value: Data to write. + """ file_path = get_annotation_path(source_file, name, data=True) file_path.parent.mkdir(parents=True, exist_ok=True) @@ -221,7 +290,15 @@ def write_data(source_file: str | None, name: BaseAnnotation | str, value: Any) def read_data(source_file: str | None, name: BaseAnnotation | str) -> Any: - """Read arbitrary data from file in workdir directory.""" + """Read arbitrary data from file in workdir directory. + + Args: + source_file: Source filename. + name: Annotation object or name. + + Returns: + The data read from the annotation. + """ file_path = get_annotation_path(source_file, name, data=True) data = next(read_annotation_file(file_path, is_data=True)) @@ -237,7 +314,14 @@ def read_data(source_file: str | None, name: BaseAnnotation | str) -> Any: def split_annotation(annotation: BaseAnnotation | str) -> tuple[str, str]: - """Split annotation into annotation name and attribute.""" + """Split annotation into annotation name and attribute. + + Args: + annotation: Annotation object or name. + + Returns: + Tuple with annotation name and attribute. + """ if isinstance(annotation, BaseAnnotation): annotation = annotation.name elem, _, attr = annotation.partition(ELEM_ATTR_DELIM) @@ -245,13 +329,31 @@ def split_annotation(annotation: BaseAnnotation | str) -> tuple[str, str]: def join_annotation(name: str, attribute: str | None) -> str: - """Join annotation name and attribute.""" + """Join annotation name and attribute. + + Args: + name: Annotation name. + attribute: Annotation attribute. + + Returns: + Annotation name joined with with attribute. + """ return ELEM_ATTR_DELIM.join((name, attribute)) if attribute else name def get_annotation_path(source_file: str | None, annotation: BaseAnnotation | str, root: Path | None = None, data: bool = False) -> Path: - """Construct a path to an annotation file given a source filename and annotation.""" + """Construct a path to an annotation file given a source filename and annotation. + + Args: + source_file: Source filename. + annotation: Annotation object or name. + root: Root path. + data: Whether the annotation is of the type data or not. + + Returns: + The path to the annotation file. + """ chunk = "" if source_file: source_file, _, chunk = source_file.partition(DOC_CHUNK_DELIM) @@ -273,7 +375,13 @@ def get_annotation_path(source_file: str | None, annotation: BaseAnnotation | st def write_annotation_file(file_path: Path, value: Any, is_data: bool = False) -> None: - """Write annotation data to a file.""" + """Write annotation data to a file. + + Args: + file_path: Path to the file to write. + value: Data to write. + is_data: Whether the value is of the type data. + """ chunk_size = 1000 opener = _compressed_open.get(compression, open) with opener(file_path, mode="wb") as f: @@ -285,7 +393,18 @@ def write_annotation_file(file_path: Path, value: Any, is_data: bool = False) -> def read_annotation_file(file_path: Path, is_data: bool = False) -> Iterator: - """Return an iterator for reading an annotation file.""" + """Return an iterator for reading an annotation file. + + Args: + file_path: Path to the file to read. + is_data: Whether the value is of the type data. + + Yields: + The annotation values. + + Raises: + SparvErrorMessage: If the file is not in the correct format. + """ opener = _compressed_open.get(compression, open) with opener(file_path, mode="rb") as f: try: diff --git a/sparv/core/log_handler.py b/sparv/core/log_handler.py index 24d041e4..5b62b4d1 100644 --- a/sparv/core/log_handler.py +++ b/sparv/core/log_handler.py @@ -121,6 +121,13 @@ class LogLevelCounterHandler(logging.Handler): """Handler that counts the number of log messages per log level.""" def __init__(self, count_dict: dict[str, int], *args: Any, **kwargs: Any) -> None: + """Initialize handler. + + Args: + count_dict: Dictionary to store the count of log messages per log level. + args: Additional arguments. + kwargs: Additional keyword arguments. + """ super().__init__(*args, **kwargs) self.levelcount = count_dict @@ -145,7 +152,14 @@ class InternalFilter(logging.Filter): @staticmethod def filter(record: logging.LogRecord) -> bool: - """Filter out internal records.""" + """Filter out internal records. + + Args: + record: Log record. + + Returns: + True if record is not internal, False otherwise. + """ return record.levelno < INTERNAL @@ -154,7 +168,14 @@ class ProgressInternalFilter(logging.Filter): @staticmethod def filter(record: logging.LogRecord) -> bool: - """Filter out progress and internal records.""" + """Filter out progress and internal records. + + Args: + record: Log record. + + Returns: + True if record is not progress or internal, False otherwise. + """ return record.levelno < PROGRESS @@ -162,6 +183,14 @@ class InternalLogHandler(logging.Handler): """Handler for internal log messages.""" def __init__(self, export_dirs_list: set, progress_: progress.Progress, jobs: OrderedDict, job_ids: dict) -> None: + """Initialize handler. + + Args: + export_dirs_list: Set to be updated with export directories. + progress_: Progress bar object. + jobs: Dictionary of jobs. + job_ids: Translation from (Sparv task name, source file) to Snakemake job ID. + """ self.export_dirs_list = export_dirs_list self.progress: progress.Progress = progress_ self.jobs = jobs @@ -209,13 +238,26 @@ class ProgressWithTable(progress.Progress): """Progress bar with additional table.""" def __init__(self, all_tasks: dict, current_tasks: OrderedDict, max_len: int, *args: Any, **kwargs: Any) -> None: + """Initialize progress bar with table. + + Args: + all_tasks: Dictionary of all tasks. + current_tasks: Currently running tasks. + max_len: Maximum length of task names. + args: Additional arguments. + kwargs: Additional keyword arguments. + """ self.all_tasks = all_tasks self.current_tasks = current_tasks self.task_max_len = max_len super().__init__(*args, **kwargs) def get_renderables(self) -> Iterable[progress.RenderableType]: - """Get a number of renderables for the progress display.""" + """Get a number of renderables for the progress display. + + Yields: + Renderables for the progress display. + """ # Progress bar yield self.make_tasks_table(self.tasks[0:1]) @@ -414,34 +456,57 @@ def setup_bar(self) -> None: self.setup_loggers() def start_bar(self, total: int) -> None: - """Start progress bar.""" + """Start progress bar. + + Args: + total: Total number of tasks. + """ self.progress.update(self.bar, total=total) self.progress.start_task(self.bar) self.bar_started = True def info(self, msg: str) -> None: - """Print info message.""" + """Print info message. + + Args: + msg: Message to print. + """ if self.json: self.logger.log(FINAL, msg) else: console.print(Text(msg, style="green")) def warning(self, msg: str) -> None: - """Print warning message.""" + """Print warning message. + + Args: + msg: Message to print. + """ if self.json: self.logger.log(FINAL, msg) else: console.print(Text(msg, style="yellow")) def error(self, msg: str) -> None: - """Print error message.""" + """Print error message. + + Args: + msg: Message to print. + """ if self.json: self.logger.log(FINAL, msg) else: console.print(Text(msg, style="red")) def log_handler(self, msg: dict) -> None: - """Log handler for Snakemake displaying a progress bar.""" + """Log handler for Snakemake displaying a progress bar. + + Args: + msg: Log message dictionary. + + Raises: + BrokenPipeError: If a missing config variable is detected. This stops Snakemake. + """ def missing_config_message(source: str) -> None: """Create error message when config variables are missing.""" _variables = messages["missing_configs"][source] @@ -830,7 +895,15 @@ def setup_logging( file: str | None = None, job: str | None = None ) -> None: - """Set up logging with socket handler.""" + """Set up logging with socket handler. + + Args: + log_server: Tuple with host and port for logging server. + log_level: Log level for logging to stdout. + log_file_level: Log level for logging to file. + file: Source file name for current job. + job: Current task name. + """ # Set logger to use the lowest selected log level, but never higher than warning (we still want to count warnings) log_level = min(logging.WARNING, getattr(logging, log_level.upper()), getattr(logging, log_file_level.upper())) socket_logger = logging.getLogger("sparv") diff --git a/sparv/core/misc.py b/sparv/core/misc.py index c532ec6b..fda6ed9b 100644 --- a/sparv/core/misc.py +++ b/sparv/core/misc.py @@ -28,7 +28,14 @@ def __init__(self, message: str, module: str = "", function: str = "") -> None: def get_logger(name: str) -> logging.Logger: - """Get a logger that is a child of 'sparv.modules'.""" + """Get a logger that is a child of 'sparv.modules'. + + Args: + name: Name of the logger. + + Returns: + Logger object. + """ if not name.startswith("sparv.modules"): name = "sparv.modules." + name return logging.getLogger(name) @@ -51,6 +58,14 @@ def parse_annotation_list(annotation_names: Iterable[str] | None, all_annotation Plain annotations (without attributes) will be added if needed, unless add_plain_annotations is set to False. Make sure to disable add_plain_annotations if the annotation names may include classes or config variables. + + Args: + annotation_names: List of annotation names. + all_annotations: List of all available annotations. + add_plain_annotations: If True, add plain annotations to the list if they are not already included. + + Returns: + List of tuples with annotation names and export names. """ from sparv.api import Annotation diff --git a/sparv/core/paths.py b/sparv/core/paths.py index 1524722d..5a138318 100644 --- a/sparv/core/paths.py +++ b/sparv/core/paths.py @@ -15,7 +15,11 @@ def read_sparv_config() -> dict: - """Get Sparv data path from config file.""" + """Get Sparv data path from config file. + + Returns: + dict: Sparv config data. + """ data = {} if sparv_config_file.is_file(): try: @@ -27,7 +31,14 @@ def read_sparv_config() -> dict: def get_data_path(subpath: str | Path = "") -> Path | None: - """Get location of directory containing Sparv models, binaries and other files.""" + """Get location of directory containing Sparv models, binaries and other files. + + Args: + subpath: Optional subpath to append to data dir. + + Returns: + Path to data dir or data dir subpath. + """ global data_dir if not data_dir: diff --git a/sparv/core/preload.py b/sparv/core/preload.py index a40ba24a..69adbb45 100644 --- a/sparv/core/preload.py +++ b/sparv/core/preload.py @@ -4,6 +4,7 @@ import logging import multiprocessing +import multiprocessing.synchronize import os import pickle import socket @@ -52,6 +53,7 @@ def __init__( cleanup: Callable, shared: bool ) -> None: + """Initialize a preloader.""" self.function = function self.target = target self.preloader = preloader @@ -62,7 +64,15 @@ def __init__( def connect_to_socket(socket_path: str, timeout: bool = False) -> socket.socket: - """Connect to a socket and return it.""" + """Connect to a socket and return it. + + Args: + socket_path: Path to the socket file. + timeout: Whether to use a 1 second timeout when connecting or not. + + Returns: + A connected socket. + """ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if timeout: s.settimeout(1) @@ -73,7 +83,14 @@ def connect_to_socket(socket_path: str, timeout: bool = False) -> socket.socket: @contextmanager def socketcontext(socket_path: str) -> Iterator[socket.socket]: - """Context manager for socket.""" + """Context manager for socket. + + Args: + socket_path: Path to the socket file. + + Yields: + A connected socket. + """ s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) s.connect(socket_path) try: @@ -83,7 +100,14 @@ def socketcontext(socket_path: str) -> Iterator[socket.socket]: def receive_data(sock: socket.socket) -> Any: - """Receive pickled data from socket and unpickle.""" + """Receive pickled data from socket and unpickle. + + Args: + sock: Socket object. + + Returns: + Unpickled data. + """ # Get data length buf_length = recvall(sock, 4) if not buf_length or len(buf_length) < 4: @@ -98,28 +122,54 @@ def receive_data(sock: socket.socket) -> Any: def send_data(sock: socket.socket, data: Any) -> None: - """Send pickled data over socket.""" + """Send pickled data over socket. + + Args: + sock: Socket object. + data: Data to send. + """ datap = pickle.dumps(data) sock.sendall(struct.pack(">I", len(datap))) sock.sendall(datap) def get_preloader_info(socket_path: str) -> dict: - """Get information about preloaded modules.""" + """Get information about preloaded modules. + + Args: + socket_path: Path to the socket file. + + Returns: + Information about preloaded modules. + """ with socketcontext(socket_path) as sock: send_data(sock, INFO) return receive_data(sock) def get_preloader_status(socket_path: str) -> Any: - """Get preloader status.""" + """Get preloader status. + + Args: + socket_path: Path to the socket file. + + Returns: + Preloader status. + """ with socketcontext(socket_path) as sock: send_data(sock, STATUS) return receive_data(sock) def stop(socket_path: str) -> bool: - """Send stop signal to Sparv preloader.""" + """Send stop signal to Sparv preloader. + + Args: + socket_path: Path to the socket file. + + Returns: + True if the preloader was succesfully stopped, False if the connection was refused. + """ try: with socketcontext(socket_path) as sock: send_data(sock, STOP) @@ -132,6 +182,13 @@ def recvall(sock: socket.socket, size: int) -> bytes | None: """Receive data of a specific size from socket. If 'size' number of bytes are not received, None is returned. + + Args: + sock: Socket object. + size: Number of bytes to receive. + + Returns: + Received data. """ buf = b"" while size: @@ -144,7 +201,15 @@ def recvall(sock: socket.socket, size: int) -> bytes | None: def handle(client_sock: socket.socket, annotators: dict[str, Preloader]) -> bool | None: - """Handle request and execute preloaded function.""" + """Handle request and execute preloaded function. + + Args: + client_sock: Client socket. + annotators: Dictionary of preloaded annotators. + + Returns: + False if stop signal received, otherwise None. + """ # Get data data = receive_data(client_sock) if data is None: @@ -208,7 +273,14 @@ def worker( annotators: dict[str, Preloader], stop_event: multiprocessing.synchronize.Event ) -> None: - """Listen to the socket server and handle incoming requests.""" + """Listen to the socket server and handle incoming requests. + + Args: + worker_no: Worker number. + server_socket: Server socket. + annotators: Dictionary of preloaded annotators. + stop_event: Event to signal when stopping. + """ log.info("Worker %d started", worker_no) # Load any non-shared preloaders @@ -234,8 +306,21 @@ def worker( client_sock.close() -def serve(socket_path: str, processes: int, storage: SnakeStorage, stop_signal: multiprocessing.Event) -> None: - """Start the Sparv preloader socket server.""" +def serve( + socket_path: str, processes: int, storage: SnakeStorage, stop_signal: multiprocessing.synchronize.Event +) -> None: + """Start the Sparv preloader socket server. + + Args: + socket_path: Path to the socket file. + processes: Number of processes to start. + storage: SnakeStorage object. + stop_signal: Event to signal when stopping. + + Raises: + SparvErrorMessage: If the socket already exists, or if an annotator in the preloader config is unknown, or if + the annotator doesn't support preloading. + """ socket_file = Path(socket_path) if socket_file.exists(): raise SparvErrorMessage(f"Socket {socket_path} already exists.") diff --git a/sparv/core/registry.py b/sparv/core/registry.py index b466e51d..e9c7c960 100644 --- a/sparv/core/registry.py +++ b/sparv/core/registry.py @@ -51,6 +51,7 @@ class Module: """Class holding data about Sparv modules.""" def __init__(self, name: str) -> None: + """Initialize module.""" self.name = name self.functions: dict[str, dict] = {} self.description = None @@ -61,6 +62,14 @@ class LanguageRegistry(dict): """Registry for supported languages.""" def add_language(self, lang: str) -> str: + """Add language to registry. + + Args: + lang: Language code plus optional suffix. + + Returns: + The full language name. + """ from sparv.api import util if lang not in self: langcode, _, suffix = lang.partition("-") @@ -121,6 +130,9 @@ def find_modules(no_import: bool = False, find_custom: bool = False) -> list: Returns: A list of available module names. + + Raises: + SparvErrorMessage: If a module cannot be imported due to an error. """ from importlib_metadata import entry_points from packaging.requirements import Requirement @@ -196,7 +208,13 @@ def find_modules(no_import: bool = False, find_custom: bool = False) -> list: def add_module_to_registry(module: ModuleType, module_name: str, skip_language_check: bool = False) -> None: - """Add module and its annotators to registry.""" + """Add module and its annotators to registry. + + Args: + module: The Python module to add. + module_name: The name of the Sparv module. + skip_language_check: Set to True to skip checking of language compatibility. + """ if not skip_language_check and hasattr(module, "__language__"): # Add to set of supported languages... for lang in module.__language__: @@ -233,16 +251,39 @@ def add_module_to_registry(module: ModuleType, module_name: str, skip_language_c def wizard(config_keys: list[str], source_structure: bool = False) -> Callable: - """Return a wizard decorator.""" + """Return a wizard decorator. + + Args: + config_keys: A list of config keys to be set or changed by the decorated function. + source_structure: Set to `True` if the decorated function needs access to a SourceStructureParser instance + (holding information on the structure of the source files). + + Returns: + A decorator that adds the wrapped function to the wizard registry. + """ def decorator(f: Callable) -> Callable: - """Add wrapped function to wizard registry.""" + """Add wrapped function to wizard registry. + + Args: + f: The function to add to the wizard registry. + + Returns: + The function. + """ wizards.append((f, tuple(config_keys), source_structure)) return f return decorator def _get_module_name(module_string: str) -> str: - """Extract module name from dotted path, i.e. 'modulename.submodule' -> 'modulename'.""" + """Extract module name from dotted path, i.e. 'modulename.submodule' -> 'modulename'. + + Args: + module_string: Dotted path to module. + + Returns: + The module name. + """ if module_string.startswith(modules_path): # Built-in Sparv module module_name = module_string[len(modules_path) + 1:].split(".")[0] @@ -279,10 +320,50 @@ def _annotator( preloader_shared: bool = True, uninstaller: str | None = None, ) -> Callable: - """Return a decorator for annotator functions, adding them to annotator registry.""" + """Return a decorator for annotator functions, adding them to annotator registry. + + Args: + description: Description of annotator. + a_type: Type of annotator. + name: Optional name to use instead of the function name. + file_extension: (importer) The file extension of the type of source this importer handles, e.g. "xml" or + "txt". + outputs: (importer) A list of annotations and attributes that the importer is guaranteed to generate. + May also be a Config instance referring to such a list. + It may generate more outputs than listed, but only the annotations listed here will be available + to use as input for annotator functions. + text_annotation: (importer) An annotation from 'outputs' that should be used as the value for the + import.text_annotation config variable, unless it or classes.text has been set manually. + structure: (importer) A class used to parse and return the structure of source files. + language: List of supported languages. + config: List of Config instances defining config options for the annotator. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. + order: If several annotators have the same output, this integer value will help decide which to try to use + first. A lower number indicates higher priority. + abstract: (exporter) Set to True if this exporter has no output. + wildcards: List of wildcards used in the annotator function's arguments. + preloader: Reference to a preloader function, used to preload models or processes. + preloader_params: List of names of parameters for the annotator, which will be used as arguments for the + preloader. + preloader_target: The name of the annotator parameter which should receive the return value of the preloader. + preloader_cleanup: Reference to an optional cleanup function, which will be executed after each annotator use. + preloader_shared: Set to False if the preloader result should not be shared among preloader processes. + uninstaller: (installer) Name of related uninstaller. + + Returns: + A decorator adding the wrapped function to the annotator registry. + """ def decorator(f: Callable) -> Callable: - """Add wrapped function to registry.""" + """Add wrapped function to registry. + + Args: + f: The function to add to the registry. + + Returns: + The function. + """ module_name = _get_module_name(f.__module__) _potential_annotators[module_name].append( { @@ -328,7 +409,28 @@ def annotator( preloader_cleanup: Callable | None = None, preloader_shared: bool = True, ) -> Callable: - """Return a decorator for annotator functions, adding them to the annotator registry.""" + """Return a decorator for annotator functions, adding them to the annotator registry. + + Args: + description: Description of annotator. + name: Optional name to use instead of the function name. + language: List of supported languages. + config: List of Config instances defining config options for the annotator. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. + order: If several annotators have the same output, this integer value will help decide which to try to use + first. A lower number indicates higher priority. + wildcards: List of wildcards used in the annotator function's arguments. + preloader: Reference to a preloader function, used to preload models or processes. + preloader_params: List of names of parameters for the annotator, which will be used as arguments for the + preloader. + preloader_target: The name of the annotator parameter which should receive the return value of the preloader. + preloader_cleanup: Reference to an optional cleanup function, which will be executed after each annotator use. + preloader_shared: Set to False if the preloader result should not be shared among preloader processes. + + Returns: + A decorator. + """ return _annotator( description=description, a_type=Annotator.annotator, @@ -365,7 +467,7 @@ def importer(description: str, file_extension: str, name: str | None = None, out config: List of Config instances defining config options for the importer. Returns: - A decorator + A decorator. """ return _annotator(description=description, a_type=Annotator.importer, name=name, file_extension=file_extension, outputs=outputs, text_annotation=text_annotation, structure=structure, config=config) @@ -387,7 +489,10 @@ def exporter( name: Optional name to use instead of the function name. config: List of Config instances defining config options for the exporter. language: List of supported languages. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. order: If several exporters have the same output, this integer value will help decide which to try to use first. + A lower number indicates higher priority. abstract: Set to True if this exporter has no output. Returns: @@ -413,7 +518,20 @@ def installer( priority: int | None = None, uninstaller: str | None = None, ) -> Callable: - """Return a decorator for installer functions.""" + """Return a decorator for installer functions. + + Args: + description: Description of installer. + name: Optional name to use instead of the function name. + config: List of Config instances defining config options for the installer. + language: List of supported languages. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. + uninstaller: Name of related uninstaller. + + Returns: + A decorator. + """ return _annotator( description=description, a_type=Annotator.installer, @@ -432,7 +550,19 @@ def uninstaller( language: list[str] | None = None, priority: int | None = None, ) -> Callable: - """Return a decorator for uninstaller functions.""" + """Return a decorator for uninstaller functions. + + Args: + description: Description of uninstaller. + name: Optional name to use instead of the function name. + config: List of Config instances defining config options for the uninstaller. + language: List of supported languages. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. + + Returns: + A decorator. + """ return _annotator( description=description, a_type=Annotator.uninstaller, @@ -451,7 +581,21 @@ def modelbuilder( priority: int | None = None, order: int | None = None, ) -> Callable: - """Return a decorator for modelbuilder functions.""" + """Return a decorator for modelbuilder functions. + + Args: + description: Description of modelbuilder. + name: Optional name to use instead of the function name. + config: List of Config instances defining config options for the modelbuilder. + language: List of supported languages. + priority: Functions with higher priority (higher number) will be preferred when scheduling which functions to + run. The default priority is 0. + order: If several modelbuilders have the same output, this integer value will help decide which to try to use + first. A lower number indicates higher priority. + + Returns: + A decorator. + """ return _annotator( description=description, a_type=Annotator.modelbuilder, @@ -464,7 +608,15 @@ def modelbuilder( def _add_to_registry(annotator: dict, skip_language_check: bool = False) -> None: - """Add function to annotator registry. Used by annotator.""" + """Add function to annotator registry. Used by annotator. + + Args: + annotator: Annotator data. + skip_language_check: Set to True to skip checking of language compatibility. + + Raises: + SparvErrorMessage: On any expected errors. + """ module_name = annotator["module_name"] f_name = annotator["function"].__name__ if not annotator["name"] else annotator["name"] rule_name = f"{module_name}:{f_name}" @@ -597,7 +749,18 @@ def handle_config( rule_name: str | None = None, language: list[str] | None = None ) -> None: - """Handle Config instances.""" + """Handle Config instances. + + Args: + cfg: The Config instance. + module_name: The name of the module. + rule_name: The name of the rule using the config variable. + language: List of supported languages. + + Raises: + SparvErrorMessage: If the config variable doesn't include the module name as prefix, or if the config variable + has already been declared, or if the config variable is missing a description. + """ if not cfg.name.startswith(module_name + "."): raise SparvErrorMessage(f"Config option '{cfg.name}' in module '{module_name}' doesn't include module " "name as prefix.") @@ -631,9 +794,15 @@ def handle_config( def _expand_class(cls: str) -> str | None: - """Convert class name to annotation. + """Convert class name to annotation name. Classes from config takes precedence over classes automatically collected from modules. + + Args: + cls: The class name. + + Returns: + The annotation name, or None if the class is not found. """ annotation = None if cls in annotation_classes["config_classes"]: @@ -646,7 +815,15 @@ def _expand_class(cls: str) -> str | None: def find_config_variables(string: str, match_objects: bool = False) -> list[str] | list[re.Match]: - """Find all config variables in a string and return a list of strings or match objects.""" + """Find all config variables in a string and return a list of strings or match objects. + + Args: + string: The string to process. + match_objects: Set to True to return match objects instead of strings. + + Returns: + A list of strings or match objects. + """ if match_objects: result = list(re.finditer(r"\[([^\]=[]+)(?:=([^\][]+))?\]", string)) else: @@ -655,7 +832,15 @@ def find_config_variables(string: str, match_objects: bool = False) -> list[str] def find_classes(string: str, match_objects: bool = False) -> list[str] | list[re.Match]: - """Find all class references in a string and return a list of strings or match objects.""" + """Find all class references in a string and return a list of strings or match objects. + + Args: + string: The string to process. + match_objects: Set to True to return match objects instead of strings. + + Returns: + A list of strings or match objects. + """ if match_objects: result = list(re.finditer(r"<([^>]+)>", string)) else: @@ -745,7 +930,14 @@ def expand_classes(s: str, parents: set[str]) -> tuple[str | None, str | None]: def get_type_hint_type(type_hint: Any) -> tuple[type, bool, bool]: - """Given a type hint, return the type, whether it's contained in a List and whether it's Optional.""" + """Given a type hint, return the type, whether it's contained in a List and whether it's Optional. + + Args: + type_hint: The type hint. + + Returns: + A tuple with the type, a boolean indicating whether it's a list and a boolean indicating whether it's optional. + """ optional = typing_inspect.is_optional_type(type_hint) if optional: type_hint = typing_inspect.get_args(type_hint)[0] @@ -770,6 +962,14 @@ def check_language(corpus_lang: str, langs: list[str], corpus_lang_suffix: str | If langs is empty, always return True. If corpus_lang is "__all__", always return True. + + Args: + corpus_lang: The language of the corpus. + langs: A list of languages to check against. + corpus_lang_suffix: Optional suffix for the corpus language. + + Returns: + True if the corpus language is among the languages, otherwise False. """ if not langs or corpus_lang == "__all__": return True diff --git a/sparv/core/run.py b/sparv/core/run.py index 77486f9f..a4426cbf 100644 --- a/sparv/core/run.py +++ b/sparv/core/run.py @@ -13,7 +13,12 @@ def main(argv: list[str] | None = None, log_level: str = "info") -> None: - """Parse command line arguments and execute the requested Sparv module.""" + """Parse command line arguments and execute the requested Sparv module. + + Args: + argv: List of command line arguments. + log_level: Log level. + """ # Set up logging logging.basicConfig(format=log_handler.LOG_FORMAT, datefmt=log_handler.DATE_FORMAT, level=log_level.upper(), stream=sys.stdout) diff --git a/sparv/core/run_snake.py b/sparv/core/run_snake.py index 5293a3ca..de804053 100644 --- a/sparv/core/run_snake.py +++ b/sparv/core/run_snake.py @@ -30,19 +30,35 @@ class StreamToLogger: """File-like stream object that redirects writes to a logger instance.""" def __init__(self, logger: logging.Logger, log_level: int = logging.INFO) -> None: + """Initialize file-like stream object with a logger and a log level. + + Args: + logger: Logger instance. + log_level: Log level. + """ self.logger = logger self.log_level = log_level def write(self, buf: str) -> None: + """Write to logger. + + Args: + buf: String to write. + """ self.logger.log(self.log_level, buf.rstrip()) @staticmethod def isatty() -> bool: + """Return False to indicate that this is not a terminal. + + Returns: + Always returns False. + """ return False @staticmethod def flush() -> None: - pass + """Do nothing; needed for compatibility with sys.stdout.""" # Set compression diff --git a/sparv/core/schema.py b/sparv/core/schema.py index 86d2e24d..848c97b3 100644 --- a/sparv/core/schema.py +++ b/sparv/core/schema.py @@ -20,6 +20,13 @@ class BaseProperty: """Base class for other types of properties.""" def __init__(self, prop_type: str | None, allow_null: bool | None = False, **kwargs: AnyType) -> None: + """Initialize the class. + + Args: + prop_type: The type of the property. + allow_null: If null values are allowed. + **kwargs: Additional keyword arguments. + """ self.schema = { "type": prop_type if not allow_null else [prop_type, "null"], **kwargs @@ -29,6 +36,11 @@ def __init__(self, prop_type: str | None, allow_null: bool | None = False, **kwa class Any(BaseProperty): """Class representing any type.""" def __init__(self, **kwargs: AnyType) -> None: + """Initialize the class. + + Args: + **kwargs: Additional keyword arguments. + """ super().__init__(None, **kwargs) @@ -43,6 +55,16 @@ def __init__( allow_null: bool = False, **kwargs: AnyType ) -> None: + """Initialize the class. + + Args: + pattern: A regex pattern. + choices: A list of possible choices. + min_len: The minimum length of the string. + max_len: The maximum length of the string. + allow_null: If null values are allowed. + **kwargs: Additional keyword arguments. + """ if pattern: kwargs["pattern"] = pattern if choices: @@ -64,6 +86,13 @@ def __init__( max_value: int | None = None, **kwargs: AnyType ) -> None: + """Initialize the class. + + Args: + min_value: The minimum value. + max_value: The maximum value. + **kwargs: Additional keyword arguments. + """ if min_value is not None: kwargs["minimum"] = min_value if max_value is not None: @@ -79,6 +108,13 @@ def __init__( max_value: int | float | None, **kwargs: AnyType ) -> None: + """Initialize the class. + + Args: + min_value: The minimum value. + max_value: The maximum value. + **kwargs: Additional keyword arguments. + """ if min_value is not None: kwargs["minimum"] = min_value if max_value is not None: @@ -89,12 +125,22 @@ def __init__( class Boolean(BaseProperty): """Class representing a boolean.""" def __init__(self, **kwargs: AnyType) -> None: + """Initialize the class. + + Args: + **kwargs: Additional keyword arguments. + """ super().__init__("boolean", **kwargs) class Null(BaseProperty): """Class representing a null value.""" def __init__(self, **kwargs: AnyType) -> None: + """Initialize the class. + + Args: + **kwargs: Additional keyword arguments. + """ super().__init__("null", **kwargs) @@ -105,6 +151,12 @@ def __init__( items: type[String | Integer | Number | Boolean | Null | Any | Array | Object] | None = None, **kwargs: AnyType ) -> None: + """Initialize the class. + + Args: + items: The type of items in the array. + **kwargs: Additional keyword arguments. + """ if items: if isinstance(items, list): kwargs["items"] = {"type": []} @@ -123,6 +175,13 @@ def __init__( self, additional_properties: dict | bool = True, description: str | None = None, **kwargs: AnyType ) -> None: + """Initialize the class. + + Args: + additional_properties: If additional properties are allowed. + description: A description of the object. + **kwargs: Additional keyword arguments. + """ if additional_properties is False or isinstance(additional_properties, dict): kwargs["additionalProperties"] = additional_properties if description: @@ -133,14 +192,35 @@ def __init__( self.allof: defaultdict[tuple[tuple[Object, ...], tuple[Object, ...]], list] = defaultdict(list) def __hash__(self) -> int: + """Return a hash of the schema. + + Returns: + A hash of the schema. + """ return hash(json.dumps(self.schema, sort_keys=True)) def __eq__(self, other: Object) -> bool: + """Compare two objects based on their schema. + + Args: + other: The object to compare with. + + Returns: + True if the schema of the current object is equal to the schema of the other object. + """ if other is None: return False return json.dumps(self.schema, sort_keys=True) == json.dumps(other.schema, sort_keys=True) def __lt__(self, other: Object) -> bool: + """Compare two objects based on their schema. + + Args: + other: The object to compare with. + + Returns: + True if the schema of the current object is less than the schema of the other object. + """ if other is None: return False return json.dumps(self.schema, sort_keys=True) < json.dumps(other.schema, sort_keys=True) @@ -152,7 +232,17 @@ def add_property( required: bool = False, condition: tuple[tuple[Object, ...], tuple[Object, ...]] | None = None ) -> Object: - """Add a property to the object.""" + """Add a property to the object. + + Args: + name: The name of the property. + prop_obj: The property object. + required: If the property is required. + condition: A tuple with two tuples of conditions (positive and negative). + + Returns: + The object itself. + """ if condition and condition != NO_COND: self.allof[condition].append((name, prop_obj)) else: @@ -227,7 +317,14 @@ def to_json(self) -> str: def get_class_from_type(t: type) -> type: - """Get JSON schema class from Python type.""" + """Get JSON schema class from Python type. + + Args: + t: A Python type. + + Returns: + A JSON schema class. + """ types = { str: String, int: Integer, @@ -242,7 +339,14 @@ def get_class_from_type(t: type) -> type: def build_json_schema(config_structure: dict) -> dict: - """Build a JSON schema based on Sparv's config structure.""" + """Build a JSON schema based on Sparv's config structure. + + Args: + config_structure: A dictionary with info about the structure of the config file. + + Returns: + A dictionary with the JSON schema. + """ schema = JsonSchema() def handle_object( @@ -253,7 +357,18 @@ def handle_object( ) -> defaultdict[tuple[tuple[Object | None, ...], tuple[Object, ...]], list]: """Handle dictionary which will become an object in the JSON schema. - Return a dictionary with conditionals as keys and lists of children to each conditional as values. + Args: + structure: The dictionary to handle. + parent_obj: The parent object. + parent_name: The name of the parent object. + is_condition: If this object is a condition. + + Returns: + A dictionary with conditionals as keys and lists of children to each conditional as values. + + Raises: + ValueError: If the datatype is not supported. + SparvErrorMessage: If an unknown error occurs. """ conditionals: defaultdict[tuple[tuple[Object | None, ...], tuple[Object, ...]], list] = defaultdict(list) @@ -336,6 +451,9 @@ def handle_property( Returns: A tuple with two values. The first is either a datatype object or a list of datatype objects, and the second is a tuple of conditions (possible empty). + + Raises: + ValueError: If the datatype is not supported. """ kwargs = {} if cfg.description: @@ -425,7 +543,15 @@ def handle_property( def validate(cfg: dict, schema: dict) -> None: - """Validate a Sparv config using JSON schema.""" + """Validate a Sparv config using JSON schema. + + Args: + cfg: The config to validate. + schema: The JSON schema to validate against. + + Raises: + SparvErrorMessage: If the config is invalid. + """ import jsonschema # noqa: PLC0415 def build_path_string(path: Sequence) -> str: diff --git a/sparv/core/setup.py b/sparv/core/setup.py index 6605dabd..33cf3ca4 100644 --- a/sparv/core/setup.py +++ b/sparv/core/setup.py @@ -35,7 +35,11 @@ def check_sparv_version() -> bool | None: def copy_resource_files(data_dir: pathlib.Path) -> None: - """Copy resource files to data dir.""" + """Copy resource files to data dir. + + Args: + data_dir: Path to the data directory. + """ resources_dir = importlib.resources.files("sparv") / "resources" with importlib.resources.as_file(resources_dir) as path: for f in path.rglob("*"): @@ -73,7 +77,11 @@ def reset() -> None: def run(sparv_datadir: str | None = None) -> None: - """Query user about data dir path unless provided by argument, and populate path with files.""" + """Query user about data dir path unless provided by argument, and populate path with files. + + Args: + sparv_datadir: Path to the data directory. + """ default_dir = pathlib.Path(appdirs.user_data_dir("sparv")) current_dir = paths.get_data_path() path: pathlib.Path diff --git a/sparv/core/snake_prints.py b/sparv/core/snake_prints.py index b8f81426..833a58b5 100644 --- a/sparv/core/snake_prints.py +++ b/sparv/core/snake_prints.py @@ -16,7 +16,11 @@ def prettyprint_yaml(in_dict: dict) -> None: - """Pretty-print YAML.""" + """Pretty-print YAML. + + Args: + in_dict: Dictionary to print. + """ from rich.syntax import Syntax from sparv.api.util.misc import dump_yaml @@ -26,7 +30,12 @@ def prettyprint_yaml(in_dict: dict) -> None: def print_modules_summary(snake_storage: snake_utils.SnakeStorage, json_output: bool = False) -> None: - """Print a summary of all annotation modules.""" + """Print a summary of all annotation modules. + + Args: + snake_storage: SnakeStorage object. + json_output: Print output as JSON. + """ all_module_types = { "annotators": snake_storage.all_annotators, "importers": snake_storage.all_importers, @@ -80,7 +89,16 @@ def print_modules_info( json_output: bool = False, include_params: bool = False ) -> None: - """Print full info for chosen module_types and module_names.""" + """Print full info for chosen module_types and module_names. + + Args: + module_types: List of module types to print. + module_names: List of module names to print. + snake_storage: SnakeStorage object. + reverse_config_usage: Dictionary with config usage. + json_output: Print output as JSON. + include_params: Include parameters in output. + """ all_module_types = { "annotators": snake_storage.all_annotators, "importers": snake_storage.all_importers, @@ -265,7 +283,11 @@ def tuple_representer(dumper: yaml.Dumper, data: tuple) -> yaml.SequenceNode: def _print_modules(modules_data: dict) -> None: - """Pretty print module information.""" + """Pretty print module information. + + Args: + modules_data: Dictionary with module information. + """ # Box styles left_line = box.Box(" \n┃ \n┃ \n┃ \n┃ \n┃ \n┃ \n ") minimal = box.Box(" \n │ \n╶─┼╴\n │ \n╶─┼╴\n╶─┼╴\n │ \n \n") @@ -493,12 +515,21 @@ def print_languages() -> None: def get_custom_module_description(name: str) -> str: - """Return string with description for custom modules.""" + """Return string with description for custom modules. + + Args: + name: Name of the custom module. + """ return "Custom module from corpus directory ({}.py).".format(name.split(".")[1]) def print_installers(snake_storage: snake_utils.SnakeStorage, uninstall: bool = False) -> None: - """Print list of installers or uninstallers.""" + """Print list of installers or uninstallers. + + Args: + snake_storage: SnakeStorage object. + uninstall: Print uninstallers instead of installers. + """ if uninstall: targets = snake_storage.uninstall_targets prefix = "un" diff --git a/sparv/core/snake_utils.py b/sparv/core/snake_utils.py index 904f7551..9db6d1ef 100644 --- a/sparv/core/snake_utils.py +++ b/sparv/core/snake_utils.py @@ -166,6 +166,12 @@ def rule_helper(rule: RuleStorage, config: dict, storage: SnakeStorage, config_m storage: Object for saving information for all rules. config_missing: True if there is no corpus config file. custom_rule_obj: Custom annotation dictionary from corpus config. + + Returns: + True if a Snakemake rule should be created, otherwise False. + + Raises: + SparvErrorMessage: On assorted errors. """ # Only create certain rules when config is missing if config_missing and not rule.modelbuilder: @@ -616,9 +622,24 @@ def rule_helper(rule: RuleStorage, config: dict, storage: SnakeStorage, config_m def name_custom_rule(rule: RuleStorage, storage: SnakeStorage) -> None: - """Create unique name for custom rule.""" + """Create unique name for custom rule. + + If the rule name already exists, a numerical suffix is added to the name. + + Args: + rule: RuleStorage object. + storage: SnakeStorage object. + """ def get_new_suffix(name: str, existing_names: list[str]) -> str: - """Find a numerical suffix that leads to a unique rule name.""" + """Find a numerical suffix that leads to a unique rule name. + + Args: + name: Base name for the rule. + existing_names: List of existing rule names. + + Returns: + A numerical suffix that leads to a unique rule name. + """ i = 2 new_name = name + str(i) while new_name in existing_names: @@ -636,7 +657,14 @@ def get_new_suffix(name: str, existing_names: list[str]) -> str: def check_ruleorder(storage: SnakeStorage) -> set[tuple[RuleStorage, RuleStorage]]: - """Order rules where necessary and print warning if rule order is missing.""" + """Order rules where necessary and print warning if rule order is missing. + + Args: + storage: SnakeStorage object. + + Returns: + A set of tuples with ordered rules. + """ ruleorder_pairs = set() ordered_rules = set() # Find rules that have common outputs and therefore need to be ordered @@ -664,14 +692,28 @@ def check_ruleorder(storage: SnakeStorage) -> set[tuple[RuleStorage, RuleStorage def file_value(rule_params: RuleStorage) -> Callable: - """Get source filename for use as parameter to rule.""" + """Get source filename for use as parameter to rule. + + Args: + rule_params: RuleStorage object. + + Returns: + Function that returns the source filename. + """ def _file_value(wildcards: snakemake.io.Wildcards) -> str | None: return get_file_value(wildcards, rule_params.annotator) return _file_value def get_parameters(rule_params: RuleStorage) -> Callable: - """Extend function parameters with source filenames and replace wildcards.""" + """Extend function parameters with source filenames and replace wildcards. + + Args: + rule_params: RuleStorage object. + + Returns: + Function that returns the parameters for the rule. + """ def get_params(wildcards: snakemake.io.Wildcards) -> dict: file = get_file_value(wildcards, rule_params.annotator) # We need to make a copy of the parameters, since the rule might be used for multiple source files @@ -712,7 +754,12 @@ def get_params(wildcards: snakemake.io.Wildcards) -> dict: def update_storage(storage: SnakeStorage, rule: RuleStorage) -> None: - """Update info to snake storage with different targets.""" + """Update info to snake storage with different targets. + + Args: + storage: SnakeStorage object. + rule: RuleStorage object. + """ if rule.exporter: storage.export_targets.append((rule.target_name, rule.description, rule.annotator_info["language"])) @@ -732,12 +779,25 @@ def update_storage(storage: SnakeStorage, rule: RuleStorage) -> None: def get_source_path() -> str: - """Get path to source files.""" + """Get path to source files. + + Returns: + Path to source files. + """ return sparv_config.get("import.source_dir") def get_annotation_path(annotation: str | BaseAnnotation, data: bool = False, common: bool = False) -> Path: - """Construct a path to an annotation file given an annotation name.""" + """Construct a path to an annotation file given an annotation name. + + Args: + annotation: Annotation name or BaseAnnotation object. + data: Set to True if the annotation is of the data type. + common: Set to True if the annotation is a common annotation for the whole corpus. + + Returns: + Path to the annotation file. + """ if not isinstance(annotation, BaseAnnotation): annotation = BaseAnnotation(annotation) elem, attr = annotation.split() @@ -754,22 +814,52 @@ def get_annotation_path(annotation: str | BaseAnnotation, data: bool = False, co def get_file_values(config: dict, snake_storage: SnakeStorage) -> list[str]: - """Get a list of files represented by the {file} wildcard.""" + """Get a list of files represented by the {file} wildcard. + + Args: + config: Dictionary containing the corpus configuration. + snake_storage: SnakeStorage object. + + Returns: + List of files represented by the {file} wildcard. + """ return config.get("file") or snake_storage.source_files def get_wildcard_values(config: dict) -> dict: - """Get user-supplied wildcard values.""" + """Get user-supplied wildcard values. + + Args: + config: Dictionary containing the corpus configuration. + + Returns: + Dictionary with wildcard values. + """ return dict(wc.split("=") for wc in config.get("wildcards", [])) def escape_wildcards(s: Path | str) -> str: - """Escape all wildcards other than {file}.""" + """Escape all wildcards other than {file}. + + Args: + s: Path or string to escape. + + Returns: + Escaped string. + """ return re.sub(r"(?!{file})({[^}]+})", r"{\1}", str(s)) def get_file_value(wildcards: snakemake.io.Wildcards, annotator: bool) -> str | None: - """Extract the {file} part from full annotation path.""" + """Extract the {file} part from full annotation path. + + Args: + wildcards: Wildcards object. + annotator: True if the rule is an annotator. + + Returns: + The value of {file}. + """ file = None if hasattr(wildcards, "file"): if annotator: @@ -780,7 +870,14 @@ def get_file_value(wildcards: snakemake.io.Wildcards, annotator: bool) -> str | def load_config(snakemake_config: dict) -> bool: - """Load corpus config and override the corpus language (if needed).""" + """Load corpus config and override the corpus language (if needed). + + Args: + snakemake_config: Snakemake config dictionary. + + Returns: + True if the corpus config is missing. + """ # Find corpus config corpus_config_file = Path.cwd() / paths.config_file if corpus_config_file.is_file(): @@ -806,7 +903,19 @@ def get_install_outputs( install_types: list | None = None, uninstall: bool = False ) -> list[Path]: - """Collect files to be created for all (un)installations given as argument or listed in config.(un)install.""" + """Collect files to be created for all (un)installations given as argument or listed in config.(un)install. + + Args: + snake_storage: SnakeStorage object. + install_types: List of (un)installation types. + uninstall: True if uninstallation files should be collected instead of installation files. + + Returns: + List of files to be created by the selected (un)installations. + + Raises: + SparvErrorMessage: If unknown (un)installation types are given. + """ unknown = [] install_outputs = [] @@ -844,7 +953,20 @@ def get_export_targets( file: list[str], wildcards: dict ) -> list: - """Get export targets from sparv_config.""" + """Get export targets from sparv_config. + + Args: + snake_storage: SnakeStorage object. + workflow: Snakemake workflow object. + file: List of files represented by the {file} wildcard. + wildcards: Dictionary with wildcard values. + + Returns: + List of export targets. + + Raises: + SparvErrorMessage: If unknown output formats are specified in export.default. + """ all_outputs = [] config_exports = set(sparv_config.get("export.default", [])) @@ -867,7 +989,15 @@ def get_export_targets( def make_param_dict(params: OrderedDict[str, inspect.Parameter]) -> dict: - """Make dictionary storing info about a function's parameters.""" + """Make dictionary storing info about a function's parameters. + + Args: + params: OrderedDict of function parameters. + + Returns: + Dictionary with parameter names as keys and tuples with default value, type, whether it is a list and whether + it is optional as values. + """ param_dict = {} for p, v in params.items(): default = v.default if v.default != inspect.Parameter.empty else None @@ -877,7 +1007,11 @@ def make_param_dict(params: OrderedDict[str, inspect.Parameter]) -> dict: def get_reverse_config_usage() -> defaultdict[str, list]: - """Get a dictionary with annotators as keys, and lists of the config variables they use as values.""" + """Get a dictionary with annotators as keys, and lists of the config variables they use as values. + + Returns: + Dictionary with annotators as keys, and lists of the config variables they use as values. + """ reverse_config_usage = defaultdict(list) for config_key in sparv_config.config_usage: for annotator in sparv_config.config_usage[config_key]: @@ -886,10 +1020,18 @@ def get_reverse_config_usage() -> defaultdict[str, list]: def print_sparv_warning(msg: str) -> None: - """Format msg into a Sparv warning message.""" - console.print(f"[yellow]WARNING: {msg}[/yellow]", highlight=False) + """Format msg into a Sparv warning message. + + Args: + msg: Warning message. + """ + console.print(f"[red]WARNING:[/] {msg}") def print_sparv_info(msg: str) -> None: - """Format msg into a Sparv info message.""" + """Format msg into a Sparv info message. + + Args: + msg: Info message. + """ console.print(f"[green]{msg}[/green]", highlight=False)