Commit 2dafa91

Merge branch 'mlc-ai:main' into main

JackWeiw authored May 16, 2024
2 parents 9e791ed + 56ea156
Showing 66 changed files with 3,255 additions and 2,120 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 95 files
+1 −1 gallery/how_to/deploy_models/deploy_prequantized.py
+29 −0 include/tvm/ir/expr.h
+13 −0 include/tvm/relax/analysis.h
+1 −1 include/tvm/relax/dataflow_pattern.h
+30 −0 include/tvm/relax/expr.h
+2 −0 include/tvm/runtime/disco/session.h
+30 −0 include/tvm/tir/var.h
+19 −10 jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
+10 −1 python/tvm/dlight/base/transform.py
+1 −0 python/tvm/relax/analysis/__init__.py
+27 −0 python/tvm/relax/analysis/analysis.py
+1 −1 python/tvm/relay/expr_functor.py
+18 −1 python/tvm/runtime/disco/session.py
+1 −1 python/tvm/testing/utils.py
+14 −17 python/tvm/topi/cuda/sort.py
+1 −1 src/arith/const_int_bound.cc
+5 −5 src/arith/iter_affine_map.cc
+1 −1 src/arith/modular_set.cc
+1 −1 src/arith/rewrite_simplify.h
+1 −1 src/contrib/msc/core/transform/set_expr_layout.cc
+1 −1 src/meta_schedule/feature_extractor/per_store_feature.cc
+3 −3 src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+1 −1 src/relax/analysis/computable_at_compile_time.cc
+2 −2 src/relax/analysis/layout_transformation.cc
+47 −2 src/relax/analysis/struct_info_analysis.cc
+1 −1 src/relax/analysis/udchain.cc
+7 −9 src/relax/analysis/well_formed.cc
+1 −1 src/relax/backend/vm/codegen_vm.cc
+1 −1 src/relax/backend/vm/codegen_vm_tir.cc
+2 −2 src/relax/distributed/transform/lower_global_view_to_local_view.cc
+11 −0 src/relax/ir/expr_functor.cc
+65 −21 src/relax/transform/adjust_matmul_order.cc
+2 −2 src/relax/transform/canonicalize_bindings.cc
+1 −1 src/relax/transform/convert_layout.cc
+10 −14 src/relax/transform/dataflow_inplace.cc
+3 −4 src/relax/transform/dead_code_elimination.cc
+1 −2 src/relax/transform/expand_matmul_of_sum.cc
+29 −2 src/relax/transform/fuse_ops.cc
+4 −5 src/relax/transform/fuse_tir.cc
+1 −1 src/relax/transform/infer_amp_utils.h
+2 −2 src/relax/transform/lambda_lift.cc
+6 −6 src/relax/transform/lazy_transform_params.cc
+89 −23 src/relax/transform/legalize_ops.cc
+8 −13 src/relax/transform/lift_transform_params.cc
+1 −1 src/relax/transform/merge_composite_functions.cc
+1 −1 src/relax/transform/split_call_tir_by_pattern.cc
+1 −1 src/relax/transform/topological_sort.cc
+2 −2 src/relax/transform/update_param_struct_info.cc
+2 −4 src/relay/analysis/call_graph.h
+2 −2 src/runtime/contrib/vllm/attention_kernels.cu
+18 −0 src/runtime/cuda/cuda_device_api.cc
+11 −1 src/runtime/disco/builtin.cc
+16 −6 src/runtime/disco/nccl/nccl.cc
+36 −4 src/runtime/disco/process_session.cc
+4 −0 src/runtime/disco/session.cc
+2 −0 src/runtime/disco/threaded_session.cc
+1 −1 src/runtime/relax_vm/kv_state.h
+13 −12 src/runtime/relax_vm/paged_kv_cache.cc
+24 −5 src/support/pipe.h
+1 −1 src/target/llvm/codegen_llvm.h
+2 −2 src/target/source/codegen_c.h
+1 −1 src/target/source/codegen_webgpu.cc
+1 −1 src/target/spirv/codegen_spirv.h
+1 −1 src/tir/analysis/is_pure_function.cc
+1 −1 src/tir/analysis/verify_ssa.cc
+3 −3 src/tir/analysis/verify_well_formed.cc
+1 −1 src/tir/ir/specialize.cc
+1 −1 src/tir/ir/tir_visitor_with_path.cc
+2 −2 src/tir/schedule/analysis/analysis.cc
+6 −6 src/tir/schedule/primitive/cache_read_write.cc
+2 −2 src/tir/schedule/primitive/reduction.cc
+7 −13 src/tir/transforms/compact_buffer_region.cc
+1 −1 src/tir/transforms/inject_permuted_layout.cc
+1 −1 src/tir/transforms/inject_software_pipeline.cc
+40 −5 src/tir/transforms/ir_utils.cc
+1 −2 src/tir/transforms/ir_utils.h
+1 −1 src/tir/transforms/lower_custom_datatypes.cc
+2 −2 src/tir/transforms/lower_opaque_block.cc
+1 −1 src/tir/transforms/storage_flatten.cc
+1 −1 src/tir/transforms/texture_flatten.cc
+1 −1 src/tir/transforms/thread_storage_sync.cc
+1 −1 src/tir/transforms/transform_mma_buffer_layout.cc
+3 −4 src/tir/transforms/unroll_loop.cc
+7 −10 src/tir/transforms/unsupported_dtype_legalize.cc
+1 −1 src/tir/transforms/vectorize_loop.cc
+1 −1 src/tir/usmp/analysis/extract_buffer_info.cc
+4 −4 src/tir/usmp/transform/create_io_allocates.cc
+17 −5 tests/python/disco/test_ccl.py
+7 −0 tests/python/disco/test_session.py
+78 −0 tests/python/dlight/test_gpu_fallback.py
+43 −0 tests/python/relax/test_analysis_struct_info_analysis.py
+164 −0 tests/python/relax/test_transform_adjust_matmul_order.py
+26 −0 tests/python/relax/test_transform_fuse_ops_by_pattern.py
+81 −0 tests/python/relax/test_transform_legalize_ops.py
+11 −0 tests/python/testing/test_tvm_testing_features.py
299 changes: 190 additions & 109 deletions cpp/json_ffi/conv_template.cc
@@ -131,7 +131,7 @@ ModelConfig ModelConfig::FromJSON(const picojson::object& json_obj) {

/****************** Conversation template ******************/

std::map<MessagePlaceholders, std::string> PLACEHOLDERS = {
std::unordered_map<MessagePlaceholders, std::string> PLACEHOLDERS = {
{MessagePlaceholders::SYSTEM, "{system_message}"},
{MessagePlaceholders::USER, "{user_message}"},
{MessagePlaceholders::ASSISTANT, "{assistant_message}"},
@@ -153,120 +153,213 @@ Conversation::Conversation()
{"assistant", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},
{"tool", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}

Result<std::vector<Data>> Conversation::AsPrompt(ModelConfig config, DLDevice device) {
using TResult = Result<std::vector<Data>>;
// Get the system message
std::string system_msg = system_template;
size_t pos = system_msg.find(PLACEHOLDERS[MessagePlaceholders::SYSTEM]);
std::string Conversation::GetSystemText(const std::string& system_msg) const {
std::string system_text = this->system_template;
static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM];
size_t pos = system_text.find(system_placeholder);
if (pos != std::string::npos) {
system_msg.replace(pos, PLACEHOLDERS[MessagePlaceholders::SYSTEM].length(),
this->system_message);
system_text.replace(pos, system_placeholder.length(), system_msg);
}
return system_text;
}
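
GetSystemText above (and GetRoleText below) rely on the same replace-first-occurrence pattern over std::string. A minimal self-contained sketch of that pattern, with a hypothetical ReplaceFirst helper that is not part of conv_template.cc:

#include <iostream>
#include <string>

// Hypothetical helper (not in this commit) mirroring the pattern above:
// replace the first occurrence of `placeholder` in `text`, or leave the
// text unchanged if the placeholder is absent.
std::string ReplaceFirst(std::string text, const std::string& placeholder,
                         const std::string& value) {
  size_t pos = text.find(placeholder);
  if (pos != std::string::npos) {
    text.replace(pos, placeholder.length(), value);
  }
  return text;
}

int main() {
  std::string out = ReplaceFirst("[INST] {system_message} [/INST]",
                                 "{system_message}", "You are a helpful assistant.");
  std::cout << out << "\n";  // [INST] You are a helpful assistant. [/INST]
}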

// Get the message strings
std::vector<Data> message_list;
std::vector<std::string> separators = seps;
if (separators.size() == 1) {
separators.push_back(separators[0]);
std::string Conversation::GetRoleText(const std::string& role, const std::string& content,
const std::optional<std::string>& fn_call_string) const {
std::string role_text = this->role_templates.at(role);
std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
size_t pos = role_text.find(placeholder);
if (pos != std::string::npos) {
role_text.replace(pos, placeholder.length(), content);
}
if (fn_call_string) {
// Replace placeholder[FUNCTION] with the function string.
// This assumes function calling is used in a single-request scenario only.
pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
if (pos != std::string::npos) {
role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
fn_call_string.value());
}
}
return role_text;
}

if (!system_msg.empty()) {
system_msg += separators[0];
message_list.push_back(TextData(system_message));
/// Detect whether function calling is needed; if so, return the function calling string.
Result<std::optional<std::string>> TryGetFunctionCallingString(
const Conversation& conv, const ChatCompletionRequest& request) {
using TResult = Result<std::optional<std::string>>;
if (!request.tools.has_value() ||
(request.tool_choice.has_value() && request.tool_choice.value() == "none")) {
return TResult::Ok(std::nullopt);
}
std::vector<ChatTool> tools_ = request.tools.value();
std::string tool_choice_ = request.tool_choice.value();

// TODO: support tool_choice given as a dict
for (const auto& tool : tools_) {
if (tool.function.name == tool_choice_) {
picojson::value function_str(tool.function.AsJSON());
return TResult::Ok(function_str.serialize());
}
}

for (int i = 0; i < messages.size(); i++) {
std::string role = messages[i].role;
// Todo(mlc-team): support content to be a single string.
std::optional<std::vector<std::unordered_map<std::string, std::string>>> content =
messages[i].content;
if (roles.find(role) == roles.end()) {
return TResult::Error("Role \"" + role + "\" is not supported");
}
if (tool_choice_ != "auto") {
return TResult::Error("Invalid tool_choice value in the request: " + tool_choice_);
}

picojson::array function_list;
for (const auto& tool : tools_) {
function_list.push_back(picojson::value(tool.function.AsJSON()));
}

std::string separator = separators[role == "assistant"]; // check assistant role
picojson::value function_list_json(function_list);
return TResult::Ok(function_list_json.serialize());
};

// If content is empty, add the role and separator
// assistant's turn to generate text
if (!content.has_value()) {
message_list.push_back(TextData(roles[role] + role_empty_sep));
continue;
}
Result<std::vector<Data>> CreatePrompt(const Conversation& conv,
const ChatCompletionRequest& request,
const ModelConfig& config, DLDevice device) {
using TResult = Result<std::vector<Data>>;

Result<std::optional<std::string>> fn_call_str_tmp = TryGetFunctionCallingString(conv, request);
if (fn_call_str_tmp.IsErr()) {
return TResult::Error(fn_call_str_tmp.UnwrapErr());
}
std::optional<std::string> fn_call_string = fn_call_str_tmp.Unwrap();

// Handle system message
bool has_custom_system = false;
std::string custom_system_inputs;

std::string message = "";
std::string role_prefix = "";
// Do not append role prefix if this is the first message and there
// is already a system message
if (add_role_after_system_message || system_msg.empty() || i != 0) {
role_prefix = roles[role] + role_content_sep;
auto f_populate_system_message = [&](const std::vector<ChatCompletionMessage>& msg_vec) {
for (ChatCompletionMessage msg : msg_vec) {
if (msg.role == "system") {
ICHECK(msg.content.IsText()) << "System message must be text";
custom_system_inputs += msg.content.Text();
has_custom_system = true;
}
}
};
// Go through both the template's messages and those passed in the request.
f_populate_system_message(conv.messages);
f_populate_system_message(request.messages);

message += role_prefix;
// pending_text records the text to be put into Data.
// We accumulate it lazily to reduce the number of segments in the Data vector.
std::string pending_text =
conv.GetSystemText(has_custom_system ? custom_system_inputs : conv.system_message);

for (const auto& item : content.value()) {
auto it_type = item.find("type");
if (it_type == item.end()) {
return TResult::Error("The content of a message does not have \"type\" field");
// the separator after the system message.
if (!pending_text.empty()) {
pending_text += conv.seps[0];
}

// Get the message strings
std::vector<Data> message_list;
size_t non_system_msg_count = 0;

// Returns an error Result if processing fails, std::nullopt on success.
auto f_process_messages =
[&](const std::vector<ChatCompletionMessage>& msg_vec) -> std::optional<TResult> {
for (size_t i = 0; i < msg_vec.size(); ++i) {
const ChatCompletionMessage& msg = msg_vec[i];
auto role_it = conv.roles.find(msg.role);
if (role_it == conv.roles.end()) {
return TResult::Error("Role \"" + msg.role + "\" is not supported");
}
if (it_type->second == "text") {
auto it_text = item.find("text");
if (it_text == item.end()) {
return TResult::Error("The text type content of a message does not have \"text\" field");
}
// replace placeholder[ROLE] with input message from role
std::string role_text = role_templates[role];
std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];
size_t pos = role_text.find(placeholder);
if (pos != std::string::npos) {
role_text.replace(pos, placeholder.length(), it_text->second);
}
if (use_function_calling) {
// replace placeholder[FUNCTION] with function_string
// this assumes function calling is used for a single request scenario only
if (!function_string.has_value()) {
return TResult::Error(
"The function string in conversation template is not defined for function "
"calling.");
const std::string& role_name = role_it->second;
// skip system message as it is already processed
if (msg.role == "system") continue;
// skip when content is empty
if (msg.content.IsNull()) {
pending_text += role_name + conv.role_empty_sep;
continue;
}
++non_system_msg_count;
// assistant uses conv.seps[1] if there are two seps
int sep_offset = msg.role == "assistant" ? 1 : 0;
const std::string& separator = conv.seps[sep_offset % conv.seps.size()];
// set up the role prefix
std::string role_prefix = "";
// Do not append role prefix if this is the first message and there is already a system
// message
if (conv.add_role_after_system_message || pending_text.empty() || non_system_msg_count != 1) {
role_prefix = role_name + conv.role_content_sep;
}
pending_text += role_prefix;

if (msg.content.IsParts()) {
for (const auto& item : msg.content.Parts()) {
auto it_type = item.find("type");
if (it_type == item.end()) {
return TResult::Error("The content of a message does not have \"type\" field");
}
pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);
if (pos != std::string::npos) {
role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),
function_string.value());
if (it_type->second == "text") {
auto it_text = item.find("text");
if (it_text == item.end()) {
return TResult::Error(
"The text type content of a message does not have \"text\" field");
}
// replace placeholder[ROLE] with input message from role
pending_text += conv.GetRoleText(msg.role, it_text->second, fn_call_string);
} else if (it_type->second == "image_url") {
if (item.find("image_url") == item.end()) {
return TResult::Error("Content should have an image_url field");
}
std::string image_url =
item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this
// should be a map, with a "url" key containing the URL, but
// we are just assuming this as the URL for now
std::string base64_image = image_url.substr(image_url.find(",") + 1);
Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
if (image_data_res.IsErr()) {
return TResult::Error(image_data_res.UnwrapErr());
}
if (!config.vision_config.has_value()) {
return TResult::Error("Vision config is required for image input");
}
int image_size = config.vision_config.value().image_size;
int patch_size = config.vision_config.value().patch_size;

int embed_size = (image_size * image_size) / (patch_size * patch_size);

auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
// lazily commit text data
if (pending_text.length() != 0) {
message_list.push_back(TextData(pending_text));
pending_text = "";
}
message_list.push_back(ImageData(image_ndarray, embed_size));
} else {
return TResult::Error("Unsupported content type: " + it_type->second);
}
}
message += role_text;
} else if (it_type->second == "image_url") {
if (item.find("image_url") == item.end()) {
return TResult::Error("Content should have an image_url field");
}
std::string image_url =
item.at("image_url"); // TODO(mlc-team): According to OpenAI API reference this
// should be a map, with a "url" key containing the URL, but
// we are just assuming this as the URL for now
std::string base64_image = image_url.substr(image_url.find(",") + 1);
Result<NDArray> image_data_res = LoadImageFromBase64(base64_image);
if (image_data_res.IsErr()) {
return TResult::Error(image_data_res.UnwrapErr());
}
if (!config.vision_config.has_value()) {
return TResult::Error("Vision config is required for image input");
}
int image_size = config.vision_config.value().image_size;
int patch_size = config.vision_config.value().patch_size;

int embed_size = (image_size * image_size) / (patch_size * patch_size);

auto image_ndarray = ClipPreprocessor(image_data_res.Unwrap(), image_size, device);
message_list.push_back(ImageData(image_ndarray, embed_size));
} else {
return TResult::Error("Unsupported content type: " + it_type->second);
ICHECK(msg.content.IsText());
pending_text += conv.GetRoleText(msg.role, msg.content.Text(), fn_call_string);
}
pending_text += separator;
}
return std::nullopt;
};

message += separator;
message_list.push_back(TextData(message));
if (auto err = f_process_messages(conv.messages)) {
return err.value();
}
if (auto err = f_process_messages(request.messages)) {
return err.value();
}
// append last assistant begin message
ChatCompletionMessage last_assistant_begin;
last_assistant_begin.role = "assistant";
last_assistant_begin.content = std::nullopt;
if (auto err = f_process_messages({last_assistant_begin})) {
return err.value();
}
if (pending_text.length() != 0) {
message_list.push_back(TextData(pending_text));
}

return TResult::Ok(message_list);
}
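
The rewritten CreatePrompt coalesces adjacent text into a single pending buffer and flushes it only when an image segment arrives or processing ends, keeping the returned Data vector short. A self-contained sketch of that coalescing idea, using hypothetical stand-ins (Segment, Coalesce) rather than the repo's Data/TextData/ImageData types:

#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Hypothetical segment type: text, or an image identified by an int id.
using Segment = std::variant<std::string, int>;

std::vector<Segment> Coalesce(const std::vector<Segment>& input) {
  std::vector<Segment> out;
  std::string pending_text;
  auto flush = [&] {
    if (!pending_text.empty()) {
      out.emplace_back(pending_text);
      pending_text.clear();
    }
  };
  for (const Segment& seg : input) {
    if (auto* text = std::get_if<std::string>(&seg)) {
      pending_text += *text;  // accumulate text lazily
    } else {
      flush();                // commit pending text before the image
      out.push_back(seg);
    }
  }
  flush();                    // commit any trailing text
  return out;
}

int main() {
  std::vector<Segment> segs{std::string("a"), std::string("b"), 42, std::string("c")};
  std::cout << Coalesce(segs).size() << "\n";  // 3 segments: "ab", image 42, "c"
}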

@@ -383,7 +476,10 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
content.push_back(std::move(item_map));
}
}
conv.messages.push_back({role_res.Unwrap(), content});
ChatCompletionMessage msg;
msg.role = role_res.Unwrap();
msg.content = content;
conv.messages.push_back(msg);
}

Result<picojson::array> seps_arr_res =
@@ -438,21 +534,6 @@ Result<Conversation> Conversation::FromJSON(const picojson::object& json_obj) {
}
conv.stop_token_ids.push_back(stop.get<int64_t>());
}

Result<std::optional<std::string>> function_string_res =
json::LookupOptionalWithResultReturn<std::string>(json_obj, "function_string");
if (function_string_res.IsErr()) {
return TResult::Error(function_string_res.UnwrapErr());
}
conv.function_string = function_string_res.Unwrap();

Result<bool> use_function_calling_res = json::LookupOrDefaultWithResultReturn<bool>(
json_obj, "use_function_calling", conv.use_function_calling);
if (use_function_calling_res.IsErr()) {
return TResult::Error(use_function_calling_res.UnwrapErr());
}
conv.use_function_calling = use_function_calling_res.Unwrap();

return TResult::Ok(conv);
}
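
For context, a hypothetical caller-side sketch (not part of this commit) of how the updated Conversation::FromJSON might be driven; it assumes only picojson's string-parse overload and the Result methods (IsErr/UnwrapErr/Unwrap) that appear in the diff:

// Hypothetical wrapper, assuming the repo's Result<T> and Conversation types.
Result<Conversation> ParseConversationTemplate(const std::string& json_text) {
  picojson::value v;
  std::string err = picojson::parse(v, json_text);
  if (!err.empty() || !v.is<picojson::object>()) {
    return Result<Conversation>::Error("Malformed conversation template JSON: " + err);
  }
  return Conversation::FromJSON(v.get<picojson::object>());
}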
