Skip to content

Commit

Permalink
feat: optimize if stmt lazy eval
Browse files Browse the repository at this point in the history
Signed-off-by: peefy <[email protected]>
  • Loading branch information
Peefy committed Jun 20, 2024
1 parent edecd59 commit 3fc5f3e
Show file tree
Hide file tree
Showing 12 changed files with 312 additions and 117 deletions.
19 changes: 19 additions & 0 deletions kclvm/compiler/src/codegen/llvm/backtrack.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright The KCL Authors. All rights reserved.

use super::context::LLVMCodeGenContext;
use crate::codegen::llvm::context::BacktrackKind;
use crate::codegen::traits::BuilderMethods;
use inkwell::values::BasicValueEnum;

Expand All @@ -22,4 +23,22 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
false
}

#[inline]
pub(crate) fn is_backtrack_only_if(&self) -> bool {
if let Some(backtrack_meta) = self.backtrack_meta.borrow_mut().as_ref() {
matches!(backtrack_meta.kind, BacktrackKind::If)
} else {
false
}
}

#[inline]
pub(crate) fn is_backtrack_only_or_else(&self) -> bool {
if let Some(backtrack_meta) = self.backtrack_meta.borrow_mut().as_ref() {
matches!(backtrack_meta.kind, BacktrackKind::OrElse)
} else {
false
}
}
}
13 changes: 13 additions & 0 deletions kclvm/compiler/src/codegen/llvm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,25 @@ pub struct Scope<'ctx> {
pub arguments: RefCell<IndexSet<String>>,
}

/// Backtrack kind.
/// - If it is a normal kind, traverse all statements in the setter.
/// - If it is an if type, only traverse the if statement in the if stmt, skipping the else stmt.
/// - If it is an orelse type, only traverse the else statement, and make conditional judgments based on the inverse of the if stmt's cond.
#[derive(Default, Debug, Clone, PartialEq)]
pub enum BacktrackKind {
#[default]
Normal,
If,
OrElse,
}

/// Schema or Global internal order independent computation backtracking meta information.
pub struct BacktrackMeta {
pub target: String,
pub level: usize,
pub count: usize,
pub stop: bool,
pub kind: BacktrackKind,
}

/// The LLVM code generator
Expand Down
136 changes: 86 additions & 50 deletions kclvm/compiler/src/codegen/llvm/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use kclvm_runtime::ApiFunc;
use kclvm_sema::pkgpath_without_prefix;

use super::context::{BacktrackMeta, LLVMCodeGenContext};
use crate::codegen::llvm::context::BacktrackKind;
use crate::codegen::traits::{BuilderMethods, ProgramCodeGen, ValueMethods};
use crate::codegen::{error as kcl_error, ENTRY_NAME};
use crate::value;
Expand Down Expand Up @@ -64,7 +65,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
self.br(global_setter_block);
self.builder.position_at_end(global_setter_block);
let mut place_holder_map: IndexMap<String, Vec<FunctionValue<'ctx>>> = IndexMap::new();
let mut body_map: IndexMap<String, Vec<&ast::Node<ast::Stmt>>> = IndexMap::new();
let mut body_map: IndexMap<String, Vec<(&ast::Node<ast::Stmt>, BacktrackKind)>> =
IndexMap::new();
let pkgpath = &self.current_pkgpath();
// Setter function name format: "$set.<pkg_path>.$<var_name>"
self.emit_global_setters(
Expand All @@ -83,7 +85,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
let stmt_list = body_map.get(k).expect(kcl_error::INTERNAL_ERROR_MSG);
let mut if_level = 0;
for (attr_func, stmt) in functions.iter().zip(stmt_list) {
for (attr_func, (stmt, kind)) in functions.iter().zip(stmt_list) {
let function = *attr_func;
let name = function
.get_name()
Expand All @@ -102,6 +104,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
level: if_level,
count: 0,
stop: false,
kind: kind.clone(),
});
} else {
if_level = 0;
Expand Down Expand Up @@ -207,56 +210,65 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
pkgpath: &str,
is_in_if: bool,
place_holder_map: &mut IndexMap<String, Vec<FunctionValue<'ctx>>>,
body_map: &mut IndexMap<String, Vec<&'ctx ast::Node<ast::Stmt>>>,
body_map: &mut IndexMap<String, Vec<(&'ctx ast::Node<ast::Stmt>, BacktrackKind)>>,
in_if_names: &mut Vec<String>,
) {
let add_stmt =
|name: &str,
stmt: &'ctx ast::Node<ast::Stmt>,
place_holder_map: &mut IndexMap<String, Vec<FunctionValue<'ctx>>>,
body_map: &mut IndexMap<String, Vec<&'ctx ast::Node<ast::Stmt>>>| {
// The function form e.g., $set.__main__.a(&Context, &LazyScope, &ValueRef, &ValueRef)
let var_key = format!("{}.{name}", pkgpath_without_prefix!(pkgpath));
let function =
self.add_setter_function(&format!("{}.{}", value::GLOBAL_SETTER, var_key));
let lambda_fn_ptr = self.builder.build_bitcast(
function.as_global_value().as_pointer_value(),
self.context.i64_type().ptr_type(AddressSpace::default()),
"",
);
if !place_holder_map.contains_key(name) {
place_holder_map.insert(name.to_string(), vec![]);
}
let name_vec = place_holder_map
.get_mut(name)
.expect(kcl_error::INTERNAL_ERROR_MSG);
name_vec.push(function);
self.build_void_call(
&ApiFunc::kclvm_scope_add_setter.name(),
&[
self.current_runtime_ctx_ptr(),
self.current_scope_ptr(),
self.native_global_string(pkgpath, "").into(),
self.native_global_string(name, "").into(),
lambda_fn_ptr,
],
);
let key = format!("{}.{name}", pkgpath_without_prefix!(pkgpath));
self.setter_keys.borrow_mut().insert(key);
if !body_map.contains_key(name) {
body_map.insert(name.to_string(), vec![]);
}
let body_vec = body_map.get_mut(name).expect(kcl_error::INTERNAL_ERROR_MSG);
body_vec.push(stmt);
};
let add_stmt = |name: &str,
stmt: &'ctx ast::Node<ast::Stmt>,
kind: BacktrackKind,
place_holder_map: &mut IndexMap<String, Vec<FunctionValue<'ctx>>>,
body_map: &mut IndexMap<
String,
Vec<(&'ctx ast::Node<ast::Stmt>, BacktrackKind)>,
>| {
// The function form e.g., $set.__main__.a(&Context, &LazyScope, &ValueRef, &ValueRef)
let var_key = format!("{}.{name}", pkgpath_without_prefix!(pkgpath));
let function =
self.add_setter_function(&format!("{}.{}", value::GLOBAL_SETTER, var_key));
let lambda_fn_ptr = self.builder.build_bitcast(
function.as_global_value().as_pointer_value(),
self.context.i64_type().ptr_type(AddressSpace::default()),
"",
);
if !place_holder_map.contains_key(name) {
place_holder_map.insert(name.to_string(), vec![]);
}
let name_vec = place_holder_map
.get_mut(name)
.expect(kcl_error::INTERNAL_ERROR_MSG);
name_vec.push(function);
self.build_void_call(
&ApiFunc::kclvm_scope_add_setter.name(),
&[
self.current_runtime_ctx_ptr(),
self.current_scope_ptr(),
self.native_global_string(pkgpath, "").into(),
self.native_global_string(name, "").into(),
lambda_fn_ptr,
],
);
let key = format!("{}.{name}", pkgpath_without_prefix!(pkgpath));
self.setter_keys.borrow_mut().insert(key);
if !body_map.contains_key(name) {
body_map.insert(name.to_string(), vec![]);
}
let body_vec = body_map.get_mut(name).expect(kcl_error::INTERNAL_ERROR_MSG);
body_vec.push((stmt, kind));
};
for stmt in body {
match &stmt.node {
ast::Stmt::Unification(unification_stmt) => {
let name = &unification_stmt.target.node.names[0].node;
if is_in_if {
in_if_names.push(name.to_string());
} else {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(
name,
stmt,
BacktrackKind::Normal,
place_holder_map,
body_map,
);
}
}
ast::Stmt::Assign(assign_stmt) => {
Expand All @@ -265,7 +277,13 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
if is_in_if {
in_if_names.push(name.to_string());
} else {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(
name,
stmt,
BacktrackKind::Normal,
place_holder_map,
body_map,
);
}
}
}
Expand All @@ -275,7 +293,13 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
if is_in_if {
in_if_names.push(name.to_string());
} else {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(
name,
stmt,
BacktrackKind::Normal,
place_holder_map,
body_map,
);
}
}
ast::Stmt::If(if_stmt) => {
Expand All @@ -294,10 +318,10 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
} else {
for name in &names {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(name, stmt, BacktrackKind::If, place_holder_map, body_map);
}
names.clear();
}
names.clear();
self.emit_global_setters(
&if_stmt.orelse,
pkgpath,
Expand All @@ -312,17 +336,29 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
}
} else {
for name in &names {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(
name,
stmt,
BacktrackKind::OrElse,
place_holder_map,
body_map,
);
}
names.clear();
}
names.clear();
}
ast::Stmt::SchemaAttr(schema_attr) => {
let name = schema_attr.name.node.as_str();
if is_in_if {
in_if_names.push(name.to_string());
} else {
add_stmt(name, stmt, place_holder_map, body_map);
add_stmt(
name,
stmt,
BacktrackKind::Normal,
place_holder_map,
body_map,
);
}
}
_ => {}
Expand Down
38 changes: 29 additions & 9 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use kclvm_sema::ty::{ANY_TYPE_STR, STR_TYPE_STR};

use crate::check_backtrack_stop;
use crate::codegen::error as kcl_error;
use crate::codegen::llvm::context::BacktrackKind;
use crate::codegen::llvm::context::BacktrackMeta;
use crate::codegen::llvm::utils;
use crate::codegen::traits::*;
Expand Down Expand Up @@ -262,21 +263,38 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {

fn walk_if_stmt(&self, if_stmt: &'ctx ast::IfStmt) -> Self::Result {
check_backtrack_stop!(self);
let cond = self
.walk_expr(&if_stmt.cond)
.expect(kcl_error::COMPILE_ERROR_MSG);
let cond = self.walk_expr(&if_stmt.cond)?;
let then_block = self.append_block("");
let else_block = self.append_block("");
let end_block = self.append_block("");
let is_truth = self.value_is_truthy(cond);
self.cond_br(is_truth, then_block, else_block);
self.builder.position_at_end(then_block);
self.walk_stmts(&if_stmt.body)
.expect(kcl_error::COMPILE_ERROR_MSG);
// Is backtrack only orelse stmt?
if self.is_backtrack_only_or_else() {
self.ok_result()?;
self.br(end_block);
self.builder.position_at_end(else_block);
self.walk_stmts(&if_stmt.orelse)?;
self.br(end_block);
self.builder.position_at_end(end_block);
return Ok(self.none_value());
}
// Is backtrack only if stmt?
if self.is_backtrack_only_if() {
self.walk_stmts(&if_stmt.body)?;
self.br(end_block);
self.builder.position_at_end(else_block);
self.ok_result()?;
self.br(end_block);
self.builder.position_at_end(end_block);
return Ok(self.none_value());
}
// Normal full if stmt.
self.walk_stmts(&if_stmt.body)?;
self.br(end_block);
self.builder.position_at_end(else_block);
self.walk_stmts(&if_stmt.orelse)
.expect(kcl_error::COMPILE_ERROR_MSG);
self.walk_stmts(&if_stmt.orelse)?;
self.br(end_block);
self.builder.position_at_end(end_block);
Ok(self.none_value())
Expand Down Expand Up @@ -432,7 +450,8 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
pkgpath_without_prefix!(runtime_type),
));
let mut place_holder_map: HashMap<String, Vec<FunctionValue<'ctx>>> = HashMap::new();
let mut body_map: HashMap<String, Vec<&ast::Node<ast::Stmt>>> = HashMap::new();
let mut body_map: HashMap<String, (Vec<&ast::Node<ast::Stmt>>, BacktrackKind)> =
HashMap::new();
// Enter the function
self.push_function(function);
// Lambda function body
Expand Down Expand Up @@ -885,7 +904,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
if k == kclvm_runtime::CAL_MAP_INDEX_SIGNATURE {
continue;
}
let stmt_list = body_map.get(k).expect(kcl_error::INTERNAL_ERROR_MSG);
let (stmt_list, kind) = body_map.get(k).expect(kcl_error::INTERNAL_ERROR_MSG);
let mut if_level = 0;
for (attr_func, stmt) in functions.iter().zip(stmt_list) {
let function = *attr_func;
Expand Down Expand Up @@ -944,6 +963,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
level: if_level,
count: 0,
stop: false,
kind: kind.clone(),
});
} else {
if_level = 0;
Expand Down
Loading

0 comments on commit 3fc5f3e

Please sign in to comment.