define a shared struct to hold the result of a decoding step
mfuntowicz committed Jul 18, 2024
1 parent a036574 commit a19d318
Showing 4 changed files with 30 additions and 21 deletions.
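
In short: the callback that streams decoding results from the C++ executor into Rust previously took three loose scalars (token id, log-prob, final flag); it now takes a single GenerationStep struct shared across the cxx bridge, so adding a field later only touches the struct definition instead of every signature on both sides. A minimal Rust sketch of the shape change (the alias names are illustrative, not from the commit):

    // Sketch of the signature change only; `OldCallback`/`NewCallback`
    // are illustrative names, not part of the commit.
    pub struct GenerationContext; // opaque context owned by the Rust side

    // Before: one positional argument per field.
    type OldCallback = unsafe fn(*mut GenerationContext, u32, f32, bool);

    // After: a single shared struct crosses the FFI boundary.
    pub struct GenerationStep {
        pub token_id: u32,
        pub log_prob: f32,
        pub is_final: bool,
    }
    type NewCallback = unsafe fn(*mut GenerationContext, GenerationStep);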
3 changes: 2 additions & 1 deletion backends/trtllm/include/ffi.h
@@ -60,7 +60,8 @@ namespace huggingface::tgi::backends {
         size_t StreamTokens(
                 const RequestId requestId,
                 huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
+                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                              huggingface::tgi::backends::GenerationStep)> callback);
     };

     /***
21 changes: 8 additions & 13 deletions backends/trtllm/src/backend.rs
@@ -20,13 +20,11 @@ use tracing::{instrument, Level, span};

 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{
-    Chunk, ValidationError, ValidGenerateRequest, ValidParameters,
-};
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;

 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};

 // Value used to poll the state of the generation stream
 static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
@@ -208,14 +206,11 @@ impl TensorRtLlmBackend {
             executor_w.pin_mut().stream_tokens(
                 request_id,
                 ctx_,
-                |ctx: *mut GenerationContext,
-                 token_id: u32,
-                 logprob: f32,
-                 is_final: bool| {
+                |ctx: *mut GenerationContext, step: GenerationStep| {
                     let inner_ctx = &mut *ctx;

                     // Insert the latest generated token to the tracker
-                    inner_ctx.tokens.push(token_id);
+                    inner_ctx.tokens.push(step.token_id);

                     // Update the timestamp at which the request started effectively
                     // Can be a bit off, would need to be before the callback, let's see
@@ -224,7 +219,7 @@ impl TensorRtLlmBackend {
                     // Decode the token
                     let text = inner_ctx
                         .tokenizer
-                        .decode(&[token_id], true)
+                        .decode(&[step.token_id], true)
                         .expect("Failed to decode token");

                     let special = inner_ctx
@@ -234,13 +229,13 @@ impl TensorRtLlmBackend {

                     // Create the structure holding the token
                     let token = Token {
-                        id: token_id,
+                        id: step.token_id,
                         text,
-                        logprob,
+                        logprob: step.log_prob,
                         special,
                     };

-                    let out = if is_final {
+                    let out = if step.is_final {
                         inner_ctx.done.store(true, Ordering::Relaxed);
                         let generated_text = inner_ctx
                             .tokenizer
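
With the bundled struct, the closure in backend.rs shrinks to a single `step` parameter. A self-contained approximation of what the callback body now does (the tokenizer, atomics, and response channel of the real backend are stubbed out here):

    // Minimal stand-in for the new closure body; not the actual backend code.
    #[derive(Copy, Clone)]
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
    }

    struct Ctx {
        tokens: Vec<u32>,
        done: bool,
    }

    fn on_step(ctx: &mut Ctx, step: GenerationStep) {
        // Track every generated token id, mirroring `inner_ctx.tokens.push(...)`.
        ctx.tokens.push(step.token_id);

        // The real closure decodes `step.token_id` with the tokenizer and builds
        // a `Token { id, text, logprob: step.log_prob, special }` here.
        let _logprob = step.log_prob;

        if step.is_final {
            // backend.rs stores this via an AtomicBool with Ordering::Relaxed.
            ctx.done = true;
        }
    }

    fn main() {
        let mut ctx = Ctx { tokens: Vec::new(), done: false };
        on_step(&mut ctx, GenerationStep { token_id: 42, log_prob: -0.7, is_final: true });
        assert_eq!(ctx.tokens, vec![42]);
        assert!(ctx.done);
    }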
11 changes: 8 additions & 3 deletions backends/trtllm/src/ffi.cpp
@@ -6,6 +6,7 @@
 #include <cmath>
 #include <exception>
 #include <filesystem>
+#include <limits>
 #include <iterator>
 #include <vector>

@@ -36,10 +37,12 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
         huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                      huggingface::tgi::backends::GenerationStep)> callback) {

     size_t numTokens = 0;
     for (const auto &item: Poll(requestId)) {
+        GenerationStep step;
         if (!item.hasError()) {
             SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
             const auto decoded = item.getResult();
@@ -51,13 +54,15 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;

             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rust::Result with error
             SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            callback(std::move(ctx), 0, 0.0, true);
+            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
         }
+
+        callback(std::move(ctx), std::move(step));
     }

     return numTokens;
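
Note the error path: because the bridge cannot yet surface a rust::Result (per the TODO), the C++ side signals failure with a sentinel step whose token_id is u32::MAX and is_final is true. A hypothetical Rust-side guard for that sentinel could look like this (this check is not part of the commit):

    // Hypothetical receiver-side guard; the commit itself does not add this.
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
    }

    // ffi.cpp substitutes `std::numeric_limits<uint32_t>::max()` for the token
    // id when Poll() yields an error, so u32::MAX means "no valid token".
    fn step_result(step: GenerationStep) -> Result<GenerationStep, &'static str> {
        if step.token_id == u32::MAX {
            Err("decoder reported an error for this step")
        } else {
            Ok(step)
        }
    }

    fn main() {
        let err = GenerationStep { token_id: u32::MAX, log_prob: 0.0, is_final: true };
        assert!(step_result(err).is_err());

        let ok = GenerationStep { token_id: 7, log_prob: -1.2, is_final: false };
        assert!(step_result(ok).is_ok());
    }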
16 changes: 12 additions & 4 deletions backends/trtllm/src/lib.rs
@@ -1,12 +1,20 @@
-pub use backend::TensorRtLlmBackend;
-
-use crate::backend::GenerationContext;
+pub use backend::{GenerationContext, TensorRtLlmBackend};

 mod backend;
 pub mod errors;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
+
+    /// Struct used as shared type between rust and C++ to represent the result
+    /// of a single decoding iteration
+    #[derive(Copy, Clone)]
+    pub struct GenerationStep {
+        token_id: u32,
+        log_prob: f32,
+        is_final: bool,
+    }
+
     extern "Rust" {
         type GenerationContext;
     }
@@ -60,7 +68,7 @@ mod ffi {
         self: Pin<&mut TensorRtLlmBackendImpl>,
         request_id: u64,
         ctx: *mut GenerationContext,
-        cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
+        cb: unsafe fn(*mut GenerationContext, GenerationStep),
     ) -> usize;

     // #[rust_name = "shutdown"]
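
For context on the lib.rs change: a struct declared inside a #[cxx::bridge] module becomes a shared type, and cxx emits a matching plain C++ aggregate with the same field layout, which is what lets ffi.cpp brace-initialize GenerationStep. A minimal standalone bridge along the same lines (assumes the cxx crate and its usual build-script setup; the "demo" namespace is illustrative):

    // Standalone illustration of a cxx shared struct, not the TGI bridge itself.
    #[cxx::bridge(namespace = "demo")]
    mod ffi {
        // cxx generates a matching C++ aggregate, roughly:
        //   namespace demo { struct GenerationStep {
        //       uint32_t token_id; float log_prob; bool is_final; }; }
        #[derive(Copy, Clone)]
        pub struct GenerationStep {
            token_id: u32,
            log_prob: f32,
            is_final: bool,
        }
    }

    fn main() {
        // On the Rust side the shared struct is an ordinary struct.
        let step = ffi::GenerationStep { token_id: 1, log_prob: -0.25, is_final: false };
        assert!(!step.is_final);
    }

Deriving Copy keeps the by-value pass across the boundary trivially cheap, since the struct is three plain scalars.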
