Skip to content

Commit

Permalink
Added prompt for sample app
Browse files Browse the repository at this point in the history
  • Loading branch information
rk-helper committed Oct 8, 2024
1 parent bfb1316 commit 08ff013
Showing 1 changed file with 41 additions and 3 deletions.
44 changes: 41 additions & 3 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct ContentView: View {
@State private var availableModels: [String] = []
@State private var availableLanguages: [String] = []
@State private var disabledModels: [String] = WhisperKit.recommendedModels().disabled

@AppStorage("promptText") private var promptText: String?
@AppStorage("selectedAudioInput") private var selectedAudioInput: String = "No Audio Input"
@AppStorage("selectedModel") private var selectedModel: String = WhisperKit.recommendedModels().default
@AppStorage("selectedTab") private var selectedTab: String = "Transcribe"
Expand Down Expand Up @@ -765,6 +765,11 @@ struct ContentView: View {

var settingsForm: some View {
List {





HStack {
Text("Show Timestamps")
InfoButton("Toggling this will include/exclude timestamps in both the UI and the prefill tokens.\nEither <|notimestamps|> or <|0.00|> will be forced based on this setting unless \"Prompt Prefill\" is de-selected.")
Expand Down Expand Up @@ -817,6 +822,14 @@ struct ContentView: View {
}
.padding(.horizontal)
.padding(.bottom)

TextField("Enter prompt text", text: Binding(
get: { self.promptText ?? "" },
set: { self.promptText = $0.isEmpty ? nil : $0 }
))
.textFieldStyle(.roundedBorder)
.padding(.horizontal)
.padding(.bottom)

VStack {
Text("Starting Temperature:")
Expand Down Expand Up @@ -1303,7 +1316,7 @@ struct ContentView: View {
let task: DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate
let seekClip: [Float] = [lastConfirmedSegmentEndSeconds]

let options = DecodingOptions(
var options = DecodingOptions(
verbose: true,
task: task,
language: languageCode,
Expand All @@ -1318,6 +1331,19 @@ struct ContentView: View {
clipTimestamps: seekClip,
chunkingStrategy: chunkingStrategy
)

// Prompt: when the user has saved prompt text, encode it into prompt tokens and enable prompt prefill
if let promptText = promptText {
guard whisperKit.tokenizer != nil else {
throw WhisperError.tokenizerUnavailable()
}

if promptText.count > 0, let tokenizer = whisperKit.tokenizer {
options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin }
options.usePrefillPrompt = true
}
}


// Early stopping checks
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { (progress: TranscriptionProgress) in
Expand Down Expand Up @@ -1542,7 +1568,7 @@ struct ContentView: View {
print(selectedLanguage)
print(languageCode)

let options = DecodingOptions(
var options = DecodingOptions(
verbose: true,
task: task,
language: languageCode,
Expand All @@ -1556,6 +1582,18 @@ struct ContentView: View {
wordTimestamps: true, // required for eager mode
firstTokenLogProbThreshold: -1.5 // higher threshold to prevent fallbacks from running too often
)

// Prompt: when the user has saved prompt text, encode it into prompt tokens and enable prompt prefill
if let promptText = promptText {
guard whisperKit.tokenizer != nil else {
throw WhisperError.tokenizerUnavailable()
}

if promptText.count > 0, let tokenizer = whisperKit.tokenizer {
options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin }
options.usePrefillPrompt = true
}
}

// Early stopping checks
let decodingCallback: ((TranscriptionProgress) -> Bool?) = { progress in
Expand Down

0 comments on commit 08ff013

Please sign in to comment.