Skip to content

Commit

Permalink
funasr1.0 uniasr
Browse files Browse the repository at this point in the history
  • Loading branch information
LauraGPT committed Jan 24, 2024
1 parent 3ead3cd commit 6cf6524
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions funasr/models/uniasr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,12 @@ def init_beam_search(self,
from funasr.models.uniasr.beam_search import BeamSearchScama
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus


decoding_mode = kwargs.get("decoding_mode", "model1")
if decoding_mode == "model1":
decoder = self.decoder
else:
decoder = self.decoder2
# 1. Build ASR model
scorers = {}

Expand All @@ -813,7 +818,7 @@ def init_beam_search(self,
)
token_list = kwargs.get("token_list")
scorers.update(
decoder=self.decoder,
decoder=decoder,
length_bonus=LengthBonus(len(token_list)),
)

Expand All @@ -830,7 +835,7 @@ def init_beam_search(self,
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearchScama(
beam_size=kwargs.get("beam_size", 2),
beam_size=kwargs.get("beam_size", 5),
weights=weights,
scorers=scorers,
sos=self.sos,
Expand Down Expand Up @@ -866,7 +871,7 @@ def inference(self,

if self.beam_search is None:
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
self.nbest = kwargs.get("nbest", 1)

meta_data = {}
Expand Down

0 comments on commit 6cf6524

Please sign in to comment.