Skip to content

Commit

Permalink
add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
john committed Mar 6, 2024
1 parent c45daeb commit 27aef92
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 89 deletions.
30 changes: 16 additions & 14 deletions apps/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

class Classifier(object):

def __init__(self, params, audio_rate):
def __init__(self, params: dict[str, bool], audio_rate: int):

self.params = params
self.audio_rate = audio_rate

if all(value == False for value in self.params.values()):
if all(value is False for value in self.params.values()):
raise Exception('user does not want to classify audio')

path = Path('model/model_1.tflite')
Expand All @@ -40,31 +40,31 @@ def __init__(self, params, audio_rate):
except:
raise

def load_model(self, path):
def load_model(self, path: Path) -> None:
self.model = tf.lite.Interpreter(model_path=path.absolute().as_posix())
self.model.allocate_tensors()
self.input_details = self.model.get_input_details()
self.output_details = self.model.get_output_details()

def is_wanted(self, file):
def is_wanted(self, file: str) -> str | None:

spectrogram = self.get_spectrogram(file)
detected_as = self.predict(spectrogram)

wanted = detected_as if self.params[detected_as] else False
wanted = detected_as if detected_as and self.params[detected_as] else None
return wanted

# convert the waveform into a spectrogram
def get_spectrogram(self, file):
def get_spectrogram(self, file: str) -> tf.Tensor:

audio_binary = tf.io.read_file(file)
audio_binary: tf.Tensor = tf.io.read_file(file)
try:
waveform = self.decode_audio(audio_binary)
waveform: tf.Tensor = self.decode_audio(audio_binary)
except Exception as error:
logging.error(f'could not decode audio (try "-b 16" option or disable classification): {error}')
raise

total_samples = tf.size(waveform).numpy()
total_samples: int = tf.size(waveform).numpy()
sample_rate = self.audio_rate
#length = round(total_samples/sample_rate, 3)

Expand All @@ -84,28 +84,30 @@ def get_spectrogram(self, file):
#logging.debug('length: %d [ start: %d middle: %d end: %d]', tf.size(waveform).numpy(), start, middle, end)

# Padding for files with less than 16000 samples (2 seconds of 8 Khz sample rate)
zero_padding = tf.zeros([target_samples] - tf.shape(waveform), dtype=tf.float32)
zero_padding: tf.Tensor = tf.zeros([target_samples] - tf.shape(waveform), dtype=tf.float32)

# Concatenate audio with padding so that all audio clips will be of the
# same length
waveform = tf.cast(waveform, tf.float32)
equal_length = tf.concat([waveform, zero_padding], 0)
equal_length: tf.Tensor = tf.concat([waveform, zero_padding], 0)
spectrogram = tf.signal.stft(
equal_length, frame_length=255, frame_step=128)

spectrogram = tf.abs(spectrogram)

spectro = []
spectro: tf.Tensor = []
spectro.append(spectrogram.numpy())
# logging.debug(f'{spectro =}')
spectro = np.expand_dims(spectro, axis=-1) # TODO: what is this dimension for?
# logging.debug(f' after: {spectro =}')

return spectro

def decode_audio(self, audio_binary):
def decode_audio(self, audio_binary: tf.Tensor) -> tf.Tensor:
audio, _ = tf.audio.decode_wav(audio_binary)
return tf.squeeze(audio, axis=-1)

def predict(self, spectrogram):
def predict(self, spectrogram: tf.Tensor) -> str | None:
if spectrogram is None:
return None

Expand Down
3 changes: 2 additions & 1 deletion apps/cursesgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import numpy as np
import logging
from h2h_types import Channel
from h2m_types import Channel

locale.setlocale(locale.LC_ALL, '')
class SpectrumWindow(object):
Expand Down Expand Up @@ -378,6 +378,7 @@ def __init__(self, screen):
self.priority_file_name = ""
self.channel_log_file_name = ""
self.channel_log_timeout = 15
self.gains = None
# nothing other than file logging defined
if (self.channel_log_file_name != ""):
self.log_mode = "file"
Expand Down
8 changes: 4 additions & 4 deletions apps/h2m_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def __init__(self):
parser.add_argument("-x", "--mix_gain", type=float, dest="mix_gain_db",
default=5, help="Hardware MIX gain index")

parser.add_argument("-s", "--squelch", type=float,
parser.add_argument("-s", "--squelch", type=int,
dest="squelch_db", default=-60,
help="Squelch in dB")

parser.add_argument("-v", "--volume", type=float,
parser.add_argument("-v", "--volume", type=int,
dest="volume_db", default=0,
help="Volume in dB")

Expand Down Expand Up @@ -196,8 +196,8 @@ def __init__(self):
{ "name": "PGA", "value": float(options.pga_gain_db) },
{ "name": "LB", "value": float(options.lb_gain_db) }
]
self.squelch_db = float(options.squelch_db)
self.volume_db = float(options.volume_db)
self.squelch_db = int(options.squelch_db)
self.volume_db = int(options.volume_db)
self.threshold_db = float(options.threshold_db)
self.record = bool(options.record)
self.play = bool(options.play)
Expand Down
File renamed without changes.
38 changes: 20 additions & 18 deletions apps/ham2mon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from curses import ERR, KEY_RESIZE, curs_set, wrapper, echo, nocbreak, endwin
import cursesgui
import h2m_parser as h2m_parser
import time
import asyncio
import errors as err
import logging
Expand All @@ -21,7 +20,7 @@

class MyDisplay():

def __init__(self, stdscr: "_curses._CursesWindow"):
def __init__(self, stdscr: "_curses._CursesWindow") -> None:
self.stdscr = stdscr

async def run(self) -> None:
Expand Down Expand Up @@ -85,10 +84,13 @@ def make_display(self) -> None:

self.specwin.max_db = PARSER.max_db
self.specwin.min_db = PARSER.min_db
self.rxwin.classifier_params = self.classifier_params
self.rxwin.classifier_params = { 'V': PARSER.voice,
'D': PARSER.data,
'S': PARSER.skip
}
self.specwin.threshold_db = self.scanner.threshold_db

async def cycle(self):
async def cycle(self) -> None:
# Initiate a scan cycle

# No need to go faster than 10 Hz rate of GNU Radio probe
Expand All @@ -105,7 +107,7 @@ async def cycle(self):
# Update physical screen
self.stdscr.refresh()

def init_scanner(self) -> object:
def init_scanner(self) -> scnr.Scanner:
# Create scanner object
ask_samp_rate = PARSER.ask_samp_rate
num_demod = PARSER.num_demod
Expand All @@ -123,17 +125,17 @@ def init_scanner(self) -> object:
center_freq = PARSER.center_freq
min_recording = PARSER.min_recording
max_recording = PARSER.max_recording
self.classifier_params = { 'V': PARSER.voice,
'D': PARSER.data,
'S': PARSER.skip
}
classifier_params = { 'V': PARSER.voice,
'D': PARSER.data,
'S': PARSER.skip
}

scanner = scnr.Scanner(ask_samp_rate, num_demod, type_demod, hw_args,
freq_correction, record, lockout_file_name,
priority_file_name, channel_log_file_name, channel_log_timeout,
play, audio_bps, channel_spacing,
center_freq,
min_recording, max_recording, self.classifier_params)
min_recording, max_recording, classifier_params)

# Set the paramaters
scanner.set_center_freq(PARSER.center_freq)
Expand Down Expand Up @@ -173,7 +175,7 @@ def handle_char(self, keyb: int) -> None:
if self.lockoutwin.proc_keyb_clear_lockout(keyb):
self.scanner.clear_lockout()

async def display_main(stdscr):
async def display_main(stdscr) -> None:
display = MyDisplay(stdscr)
await display.run()

Expand All @@ -194,24 +196,24 @@ def main(stdscr) -> None:
wrapper(main)
except KeyboardInterrupt:
pass
except RuntimeError as err:
except RuntimeError as error:
print("")
print("RuntimeError: SDR hardware not detected or insufficient USB permissions. Try running as root.")
print("RuntimeError: SDR hardware not detected or insufficient USB permissions. Try running as root or with --debug option.")
print("")
print("RuntimeError: {err=}, {type(err)=}")
print(f'RuntimeError: {error=}, {type(error)=}')
logging.debug(traceback.format_exc())
print("")
except err.LogError:
print("")
print("LogError: database logging not active, to be expanded.")
print("")
except OSError as err:
except OSError as error:
print("")
print("OS error: {0}".format(err))
print(f'OS error: {error=}, {type(error)=}')
print("")
except BaseException as err:
except BaseException as error:
print("")
print("Unexpected: {err=}, {type(err)=}", err, type(err))
print(f'Unexpected: {error=}, {type(error)=}')
logging.debug(traceback.format_exc())
print("")

Expand Down
Loading

0 comments on commit 27aef92

Please sign in to comment.