Skip to content

Commit

Permalink
Merge pull request #471 from vacantron/t2c/jalr
Browse files Browse the repository at this point in the history
Improve `JALR` execution with JIT-cache
  • Loading branch information
jserv authored Aug 7, 2024
2 parents 9759ad2 + f5d04fb commit 34c3db0
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .ci/riscv-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e -u -o pipefail

# Install RISCOF
python3 -m pip install git+https://github.com/riscv/riscof
pip3 install git+https://github.com/riscv/riscof.git@d38859f85fe407bcacddd2efcd355ada4683aee4

set -x

Expand Down
Binary file added build/fibonacci.elf
Binary file not shown.
1 change: 1 addition & 0 deletions src/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,7 @@ static void code_cache_flush(struct jit_state *state, riscv_t *rv)
state->offset = state->org_size;
state->n_blocks = 0;
set_reset(&state->set);
jit_cache_clear(rv->jit_cache);
clear_cache_hot(rv->block_cache, (clear_func_t) clear_hot);
return;
}
Expand Down
22 changes: 21 additions & 1 deletion src/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ void jit_translate(riscv_t *rv, block_t *block);
typedef void (*exec_block_func_t)(riscv_t *rv, uintptr_t);

#if RV32_HAS(T2C)
void t2c_compile(block_t *block, uint64_t mem_base);
void t2c_compile(riscv_t *, block_t *);
typedef void (*exec_t2c_func_t)(riscv_t *);

/* The jit-cache records program counters together with the entry points of
 * the executable code generated by T2C for them. Like a hardware cache, an
 * old record is replaced by the new one that maps to the same slot.
 */

/* The number of jit-cache entries must be a power of 2 so that a slot can be
 * selected simply by masking the program counter.
 */
#define N_JIT_CACHE_ENTRIES (1 << 12)

/* One slot of the jit-cache table.
 *
 * t2c_compile() describes this layout to LLVM as a two-member struct
 * (a 64-bit integer followed by a pointer), so the member order and widths
 * here must stay in sync with that description.
 */
struct jit_cache {
uint64_t pc; /* program counter, easy to build LLVM IR with 64-bit width */
void *entry; /* entry of JIT-ed code */
};

/* Allocate a zero-filled table of N_JIT_CACHE_ENTRIES slots.
 * Returns NULL if the allocation fails; release it with jit_cache_exit().
 */
struct jit_cache *jit_cache_init(void);

/* Free a table obtained from jit_cache_init(). */
void jit_cache_exit(struct jit_cache *cache);

/* Store (pc, entry) in the slot selected by the low bits of pc, replacing
 * whatever record occupied that slot before.
 */
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry);

/* Reset every slot of the table to all-zero bytes. */
void jit_cache_clear(struct jit_cache *cache);
#endif
5 changes: 3 additions & 2 deletions src/riscv.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ static void *t2c_runloop(void *arg)
pthread_mutex_lock(&rv->wait_queue_lock);
list_del_init(&entry->list);
pthread_mutex_unlock(&rv->wait_queue_lock);
t2c_compile(entry->block,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base);
t2c_compile(rv, entry->block);
free(entry);
}
}
Expand Down Expand Up @@ -291,6 +290,7 @@ riscv_t *rv_create(riscv_user_t rv_attr)
mpool_create(sizeof(chain_entry_t) << BLOCK_IR_MAP_CAPACITY_BITS,
sizeof(chain_entry_t));
rv->jit_state = jit_state_init(CODE_CACHE_SIZE);
rv->jit_cache = jit_cache_init();
rv->block_cache = cache_create(BLOCK_MAP_CAPACITY_BITS);
assert(rv->block_cache);
#if RV32_HAS(T2C)
Expand Down Expand Up @@ -392,6 +392,7 @@ void rv_delete(riscv_t *rv)
#endif
mpool_destroy(rv->chain_entry_mp);
jit_state_exit(rv->jit_state);
jit_cache_exit(rv->jit_cache);
cache_free(rv->block_cache);
#endif
free(rv);
Expand Down
1 change: 1 addition & 0 deletions src/riscv_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct riscv_internal {
struct mpool *block_mp, *block_ir_mp;

void *jit_state;
void *jit_cache;
#if RV32_HAS(GDBSTUB)
/* gdbstub instance */
gdbstub_t gdbstub;
Expand Down
72 changes: 55 additions & 17 deletions src/t2c.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ FORCE_INLINE LLVMBasicBlockRef t2c_block_map_search(struct LLVM_block_map *map,
return NULL;
}

#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, uint64_t mem_base UNUSED, \
rv_insn_t *ir UNUSED) \
{ \
code; \
/* Define the static code generator t2c_<inst> for one instruction.
 *
 * Every generator shares the same parameter list, and each parameter is
 * tagged UNUSED because a given `code` body typically touches only some of
 * them. `code` is pasted verbatim as the function body.
 */
#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, riscv_t *rv UNUSED, \
uint64_t mem_base UNUSED, rv_insn_t *ir UNUSED) \
{ \
code; \
}

#define T2C_LLVM_GEN_ADDR(reg, rv_member, ir_member) \
Expand Down Expand Up @@ -135,6 +135,9 @@ FORCE_INLINE void t2c_gen_call_io_func(LLVMValueRef start,
&io_param, 1, "");
}

static LLVMTypeRef t2c_jit_cache_func_type;
static LLVMTypeRef t2c_jit_cache_struct_type;

#include "t2c_template.c"
#undef T2C_OP

Expand Down Expand Up @@ -174,14 +177,15 @@ typedef void (*t2c_codegen_block_func_t)(LLVMBuilderRef *builder UNUSED,
LLVMBasicBlockRef *entry UNUSED,
LLVMBuilderRef *taken_builder UNUSED,
LLVMBuilderRef *untaken_builder UNUSED,
riscv_t *rv UNUSED,
uint64_t mem_base UNUSED,
rv_insn_t *ir UNUSED);

static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMTypeRef *param_types UNUSED,
LLVMValueRef start,
LLVMBasicBlockRef *entry,
uint64_t mem_base,
riscv_t *rv,
rv_insn_t *ir,
set_t *set,
struct LLVM_block_map *map)
Expand All @@ -194,7 +198,8 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,

while (1) {
((t2c_codegen_block_func_t) dispatch_table[ir->opcode])(
builder, param_types, start, entry, &tk, &utk, mem_base, ir);
builder, param_types, start, entry, &tk, &utk, rv,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base, ir);
if (!ir->next)
break;
ir = ir->next;
Expand All @@ -214,8 +219,7 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(untaken_builder, untaken_entry);
LLVMBuildBr(utk, untaken_entry);
t2c_trace_ebb(&untaken_builder, param_types, start,
&untaken_entry, mem_base, ir->branch_untaken, set,
map);
&untaken_entry, rv, ir->branch_untaken, set, map);
}
}
if (ir->branch_taken) {
Expand All @@ -230,13 +234,13 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(taken_builder, taken_entry);
LLVMBuildBr(tk, taken_entry);
t2c_trace_ebb(&taken_builder, param_types, start, &taken_entry,
mem_base, ir->branch_taken, set, map);
rv, ir->branch_taken, set, map);
}
}
}
}

void t2c_compile(block_t *block, uint64_t mem_base)
void t2c_compile(riscv_t *rv, block_t *block)
{
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMTypeRef io_members[] = {
Expand All @@ -254,6 +258,16 @@ void t2c_compile(block_t *block, uint64_t mem_base)
LLVMTypeRef param_types[] = {LLVMPointerType(struct_rv, 0)};
LLVMValueRef start = LLVMAddFunction(
module, "start", LLVMFunctionType(LLVMVoidType(), param_types, 1, 0));

LLVMTypeRef t2c_args[1] = {LLVMInt64Type()};
t2c_jit_cache_func_type =
LLVMFunctionType(LLVMVoidType(), t2c_args, 1, false);

/* Mind the alignment: this layout must match struct jit_cache */
LLVMTypeRef jit_cache_memb[2] = {LLVMInt64Type(),
LLVMPointerType(LLVMVoidType(), 0)};
t2c_jit_cache_struct_type = LLVMStructType(jit_cache_memb, 2, false);

LLVMBasicBlockRef first_block = LLVMAppendBasicBlock(start, "first_block");
LLVMBuilderRef first_builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(first_builder, first_block);
Expand All @@ -266,8 +280,8 @@ void t2c_compile(block_t *block, uint64_t mem_base)
struct LLVM_block_map map;
map.count = 0;
/* Translate custom IR into LLVM IR */
t2c_trace_ebb(&builder, param_types, start, &entry, mem_base,
block->ir_head, &set, &map);
t2c_trace_ebb(&builder, param_types, start, &entry, rv, block->ir_head,
&set, &map);
/* Offload LLVM IR to LLVM backend */
char *error = NULL, *triple = LLVMGetDefaultTargetTriple();
LLVMExecutionEngineRef engine;
Expand Down Expand Up @@ -298,5 +312,29 @@ void t2c_compile(block_t *block, uint64_t mem_base)

/* Return the function pointer of T2C generated machine code */
block->func = (exec_t2c_func_t) LLVMGetPointerToGlobal(engine, start);
jit_cache_update(rv->jit_cache, block->pc_start, block->func);
block->hot2 = true;
}

/* Allocate the jit-cache table.
 *
 * The table has N_JIT_CACHE_ENTRIES slots and is zero-filled by calloc().
 * Returns NULL when the allocation fails; the caller owns the table and
 * releases it with jit_cache_exit().
 *
 * The parameter list is spelled (void) so that this is a real prototype:
 * before C23 an empty list declares a function with unspecified parameters.
 */
struct jit_cache *jit_cache_init(void)
{
    return calloc(N_JIT_CACHE_ENTRIES, sizeof(struct jit_cache));
}

/* Release a jit-cache table.
 *
 * @cache: table previously returned by jit_cache_init(); it is handed
 *         straight to free(), so a NULL pointer is accepted.
 */
void jit_cache_exit(struct jit_cache *cache)
{
    free(cache);
}

/* Record a freshly compiled entry point in the jit-cache.
 *
 * @cache: table of N_JIT_CACHE_ENTRIES slots
 * @pc:    guest program counter the code belongs to
 * @entry: address of the JIT-ed code
 *
 * The slot is picked by masking pc with (N_JIT_CACHE_ENTRIES - 1); any record
 * already stored there is overwritten.
 */
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry)
{
    struct jit_cache *slot = &cache[pc & (N_JIT_CACHE_ENTRIES - 1)];

    slot->pc = pc;
    slot->entry = entry;
}

/* Wipe the whole jit-cache table.
 *
 * Every slot is overwritten with zero bytes, covering all
 * N_JIT_CACHE_ENTRIES entries of @cache.
 */
void jit_cache_clear(struct jit_cache *cache)
{
    const size_t table_bytes = sizeof(struct jit_cache) * N_JIT_CACHE_ENTRIES;

    memset(cache, 0, table_bytes);
}
72 changes: 63 additions & 9 deletions src/t2c_template.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,63 @@ T2C_OP(jal, {
}
})

/* Emit the tail of an indirect jump: look the target up in the jit-cache and
 * either chain directly into T2C generated code or fall back to the
 * interpreter.
 *
 * @builder: builder positioned at the end of the block performing the jump
 * @start:   LLVM function being generated; the two new blocks are appended
 *           to it
 * @addr:    32-bit LLVM value holding the computed jump destination
 * @rv:      emulator instance; its address and the address of its jit-cache
 *           table are baked into the generated code as constants
 * @ir:      instruction being translated (used to locate the PC field)
 *
 * Generated control flow:
 *   slot = &jit_cache[addr & (N_JIT_CACHE_ENTRIES - 1)];
 *   if (slot->pc == addr) { slot->entry(rv); return; }   -- cache hit
 *   else                  { PC = addr;       return; }   -- cache miss
 *
 * NOTE(review): a never-written slot is all zero, so a jump to address 0
 * compares equal to slot 0 and would call a NULL entry -- confirm guests
 * cannot reach this with addr == 0.
 */
FORCE_INLINE void t2c_jit_cache_helper(LLVMBuilderRef *builder,
                                       LLVMValueRef start,
                                       LLVMValueRef addr,
                                       riscv_t *rv,
                                       rv_insn_t *ir)
{
    LLVMBasicBlockRef true_path = LLVMAppendBasicBlock(start, "");
    LLVMBuilderRef true_builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(true_builder, true_path);

    LLVMBasicBlockRef false_path = LLVMAppendBasicBlock(start, "");
    LLVMBuilderRef false_builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(false_builder, false_path);

    /* get jit-cache base address; go through uintptr_t so the host pointer
     * is not truncated where `long` is narrower than a pointer
     */
    LLVMValueRef base = LLVMConstIntToPtr(
        LLVMConstInt(LLVMInt64Type(), (uintptr_t) rv->jit_cache, false),
        LLVMPointerType(t2c_jit_cache_struct_type, 0));

    /* get index */
    LLVMValueRef hash = LLVMBuildAnd(
        *builder, addr,
        LLVMConstInt(LLVMInt32Type(), N_JIT_CACHE_ENTRIES - 1, false), "");

    /* get jit_cache_t::pc */
    LLVMValueRef cast =
        LLVMBuildIntCast2(*builder, hash, LLVMInt64Type(), false, "");
    LLVMValueRef element_ptr = LLVMBuildInBoundsGEP2(
        *builder, t2c_jit_cache_struct_type, base, &cast, 1, "");
    LLVMValueRef pc_ptr = LLVMBuildStructGEP2(
        *builder, t2c_jit_cache_struct_type, element_ptr, 0, "");
    /* the field is 64 bits wide: load all of it instead of relying on the
     * low half being stored first (little-endian only)
     */
    LLVMValueRef pc = LLVMBuildLoad2(*builder, LLVMInt64Type(), pc_ptr, "");

    /* compare with calculated destination, zero-extended to the field width */
    LLVMValueRef addr64 =
        LLVMBuildIntCast2(*builder, addr, LLVMInt64Type(), false, "");
    LLVMValueRef cmp = LLVMBuildICmp(*builder, LLVMIntEQ, pc, addr64, "");

    LLVMBuildCondBr(*builder, cmp, true_path, false_path);

    /* get jit_cache_t::entry */
    LLVMValueRef entry_ptr = LLVMBuildStructGEP2(
        true_builder, t2c_jit_cache_struct_type, element_ptr, 1, "");

    /* the callee of a call must be a pointer value, so load it as one */
    LLVMValueRef target =
        LLVMBuildLoad2(true_builder,
                       LLVMPointerType(t2c_jit_cache_func_type, 0), entry_ptr,
                       "");

    /* invoke T2C JIT-ed code */
    LLVMValueRef t2c_args[1] = {
        LLVMConstInt(LLVMInt64Type(), (uintptr_t) rv, false)};

    LLVMBuildCall2(true_builder, t2c_jit_cache_func_type, target, t2c_args, 1,
                   "");
    LLVMBuildRetVoid(true_builder);

    /* return to interpreter if cache-miss */
    LLVMBuildStore(false_builder, addr,
                   t2c_gen_PC_addr(start, &false_builder, ir));
    LLVMBuildRetVoid(false_builder);

    /* both blocks are terminated; the emitted instructions belong to the
     * function, so the temporary builders can be released here
     */
    LLVMDisposeBuilder(true_builder);
    LLVMDisposeBuilder(false_builder);
}

T2C_OP(jalr, {
if (ir->rd)
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 4,
Expand All @@ -40,8 +97,7 @@ T2C_OP(jalr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(Add, val_rs1, ir->imm);
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(And, val_rs1, ~1U);
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

#define BRANCH_FUNC(type, cond) \
Expand Down Expand Up @@ -672,8 +728,7 @@ T2C_OP(clwsp, {

T2C_OP(cjr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cmv, {
Expand All @@ -692,8 +747,7 @@ T2C_OP(cjalr, {
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 2,
t2c_gen_ra_addr(start, builder, ir));
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cadd, {
Expand Down Expand Up @@ -785,15 +839,15 @@ T2C_OP(fuse5, {
switch (fuse[i].opcode) {
case rv_insn_slli:
t2c_slli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srli:
t2c_srli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srai:
t2c_srai(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
default:
__UNREACHABLE;
Expand Down
43 changes: 43 additions & 0 deletions tests/fibonacci.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# fib(a0 = n) -> a0
# Recursive Fibonacci with fib(0) = fib(1) = 1.
# Both recursive calls are made indirectly: the address of fib is loaded
# into t0 and reached with `jalr ra, 0(t0)` (presumably to exercise the
# JALR translation path -- confirm).
fib:
li a5, 1
bleu a0, a5, .L3 # n <= 1 (unsigned compare): base case
addi sp, sp, -16 # open a 16-byte stack frame
sw ra, 12(sp) # save return address
sw s0, 8(sp) # save callee-saved s0 (will hold n)
sw s1, 4(sp) # save callee-saved s1 (will hold fib(n - 1))
mv s0, a0 # s0 = n
addi a0, a0, -1 # a0 = n - 1
la t0, fib
jalr ra, 0(t0) # indirect call: fib(n - 1)
mv s1, a0 # s1 = fib(n - 1)
addi a0, s0, -2 # a0 = n - 2
la t0, fib
jalr ra, 0(t0) # indirect call: fib(n - 2)
add a0, s1, a0 # result = fib(n - 1) + fib(n - 2)
lw ra, 12(sp) # restore saved registers
lw s0, 8(sp)
lw s1, 4(sp)
addi sp, sp, 16 # close the stack frame
jr ra
.L3:
li a0, 1 # base case returns 1
ret
# Format string handed to printf by main.
.LC0:
.string "%d\n"
.text
.align 1
.globl main
.type main, @function
# main: computes fib(42), prints the result with printf("%d\n", ...) and
# returns 0.
main:
addi sp, sp, -16 # open a 16-byte stack frame
sw ra, 12(sp) # save return address across the calls
li a0, 42 # argument: n = 42
call fib
mv a1, a0 # second printf argument: fib(42)
lui a0, %hi(.LC0) # first printf argument: address of .LC0
addi a0, a0, %lo(.LC0)
call printf
li a0, 0 # return value 0
lw ra, 12(sp) # restore return address
addi sp, sp, 16 # close the stack frame
jr ra

0 comments on commit 34c3db0

Please sign in to comment.