Skip to content

Commit

Permalink
Support --dtype in mistralrs bench (#911)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Nov 16, 2024
1 parent 9d647a9 commit 6c70800
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions mistralrs-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
initialize_logging, paged_attn_supported, parse_isq_value, Constraint, DefaultSchedulerMethod,
DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, IsqType, Loader, LoaderBuilder,
MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelDType, ModelSelected, NormalRequest,
PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams, SchedulerConfig,
TokenSource, Usage,
get_model_dtype, initialize_logging, paged_attn_supported, parse_isq_value, Constraint,
DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata, DrySamplingParams, IsqType,
Loader, LoaderBuilder, MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected,
NormalRequest, PagedAttentionConfig, Request, RequestMessage, Response, SamplingParams,
SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
Expand Down Expand Up @@ -348,6 +348,8 @@ fn main() -> anyhow::Result<()> {
None => None,
};

let dtype = get_model_dtype(&args.model)?;

let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
.with_use_flash_attn(use_flash_attn)
.with_prompt_batchsize(prompt_batchsize)
Expand Down Expand Up @@ -477,7 +479,7 @@ fn main() -> anyhow::Result<()> {
let pipeline = loader.load_model_from_hf(
None,
token_source,
&ModelDType::Auto,
&dtype,
&device,
false,
mapper,
Expand Down

0 comments on commit 6c70800

Please sign in to comment.