Skip to content

Commit

Permalink
refactor: improve error messages in lossless source codes
Browse files Browse the repository at this point in the history
  • Loading branch information
rwnobrega committed Dec 19, 2024
1 parent 3268ce3 commit fc1baf2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions src/komm/_lossless_coding/FixedToVariableCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def encode(self, source_symbols: npt.ArrayLike) -> npt.NDArray[np.integer]:

def decode(self, target_symbols: npt.ArrayLike) -> npt.NDArray[np.integer]:
r"""
Decodes a sequence of target symbols using the code. Only works if the code is prefix-free.
Decodes a sequence of target symbols using the code. Only implemented for prefix-free codes.
Parameters:
target_symbols: The sequence of symbols to be decoded. Must be a 1D-array with elements in $[0:T)$, where $T$ is the target cardinality of the code.
Expand All @@ -241,9 +241,11 @@ def decode(self, target_symbols: npt.ArrayLike) -> npt.NDArray[np.integer]:
>>> code.decode([1, 0, 0, 1, 0, 0, 1, 1, 0])
Traceback (most recent call last):
...
ValueError: code is not prefix-free
NotImplementedError: decoding is not implemented for non-prefix-free codes
"""
if not self.is_prefix_free():
raise ValueError("code is not prefix-free")
raise NotImplementedError(
"decoding is not implemented for non-prefix-free codes"
)
target_symbols = np.asarray(target_symbols)
return parse_prefix_free(target_symbols, self.inv_enc_mapping)
8 changes: 5 additions & 3 deletions src/komm/_lossless_coding/VariableToFixedCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def rate(self, pmf: npt.ArrayLike) -> float:

def encode(self, source_symbols: npt.ArrayLike) -> npt.NDArray[np.integer]:
r"""
Encodes a sequence of source symbols using the code.
Encodes a sequence of source symbols using the code. Only implemented for prefix-free codes.
Parameters:
source_symbols: The sequence of symbols to be encoded. Must be a 1D-array with elements in $[0:S)$, where $S$ is the source cardinality of the code.
Expand All @@ -198,10 +198,12 @@ def encode(self, source_symbols: npt.ArrayLike) -> npt.NDArray[np.integer]:
>>> code.encode([0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0])
Traceback (most recent call last):
...
ValueError: code is not prefix-free
NotImplementedError: encoding is not implemented for non-prefix-free codes
"""
if not self.is_prefix_free():
raise ValueError("code is not prefix-free")
raise NotImplementedError(
"encoding is not implemented for non-prefix-free codes"
)
source_symbols = np.asarray(source_symbols)
return parse_prefix_free(source_symbols, self.inv_dec_mapping)

Expand Down

0 comments on commit fc1baf2

Please sign in to comment.