From 4d5e3c21554d3169694da7ea350e4198e2784cb1 Mon Sep 17 00:00:00 2001 From: Daniel Mann Date: Thu, 29 Aug 2024 10:21:54 -0400 Subject: [PATCH] local librasr import + renaming --- i6_models/parts/{fsa.py => rasr_fsa.py} | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) rename i6_models/parts/{fsa.py => rasr_fsa.py} (96%) diff --git a/i6_models/parts/fsa.py b/i6_models/parts/rasr_fsa.py similarity index 96% rename from i6_models/parts/fsa.py rename to i6_models/parts/rasr_fsa.py index f4b2df22..bd328c49 100644 --- a/i6_models/parts/fsa.py +++ b/i6_models/parts/rasr_fsa.py @@ -1,13 +1,12 @@ from __future__ import annotations -__all__ = ["TorchFsaBuilder", "WeightedFsa"] +__all__ = ["RasrFsaBuilder", "WeightedFsa"] from functools import reduce from typing import Iterable, NamedTuple, Tuple, Union import numpy as np import torch -import librasr class WeightedFsa(NamedTuple): @@ -43,11 +42,13 @@ def to(self, device: Union[str, torch.device]) -> WeightedFsa: return WeightedFsa._make(tensor.to(device) for tensor in self) -class TorchFsaBuilder: +class RasrFsaBuilder: """ Builder class that wraps around the librasr.AllophoneStateFsaBuilder, bringing the FSAs into the correct format for the `i6_native_ops.fbw.fbw_loss`. Use of this class requires a working installation of the python package `librasr`. + Hence, the package is locally imported in case other classes are accessed from + this module. This class provides an explicit implementation of the `__getstate__` and `__setstate__` functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder` is not picklable. @@ -57,6 +58,8 @@ class TorchFsaBuilder: """ def __init__(self, config_path: str, tdp_scale: float = 1.0): + import librasr + self.config_path = config_path config = librasr.Configuration() config.set_from_file(self.config_path) @@ -69,6 +72,8 @@ def __getstate__(self): return state def __setstate__(self, state): + import librasr + self.__dict__.update(state) config = librasr.Configuration() config.set_from_file(self.config_path)