Skip to content

Commit

Permalink
[cases] add more config for eval.mmm_mem*
Browse files Browse the repository at this point in the history
  • Loading branch information
SharzyL committed Nov 22, 2024
1 parent ab6e4fa commit 2401b73
Show file tree
Hide file tree
Showing 7 changed files with 1,414 additions and 567 deletions.
25 changes: 16 additions & 9 deletions tests/eval/_mmm_mem/default.nix
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
{ linkerScript
{ lib
, linkerScript
, makeBuilder
, t1main
}:

let
builder = makeBuilder { casePrefix = "eval"; };
build_ntt = caseName /* must be consistent with attr name */ : len: kernel_src:
build_mmm = caseName /* must be consistent with attr name */: bn: vl:
builder {
caseName = caseName;

Expand All @@ -16,8 +17,8 @@ let
buildPhase = ''
runHook preBuild
$CC -T${linkerScript} -DLEN=${toString len} \
${./mmm_main.c} ${kernel_src} \
$CC -T${linkerScript} -DLEN=${toString bn} \
${./mmm_main.c} ./mmm_${toString bn}_vl${toString vl}.S \
${t1main} \
-o $pname.elf
Expand All @@ -26,8 +27,14 @@ let

meta.description = "test case 'ntt'";
};

in {
mmm_mem_512_vl4096 = build_ntt "mmm_mem_512_vl4096" 4096 ./mmm_512_vl4096.S;
mmm_mem_256_vl4096 = build_ntt "mmm_mem_256_vl4096" 4096 ./mmm_256_vl4096.S;
}
# Every { bn, vl } combination of the values listed below (2 * 3 = 6 configs).
configs = lib.cartesianProduct { bn = [ 256 512 ]; vl = [ 256 512 4096 ]; };
in
# Turn the config list into an attribute set with one test case per config.
# The attribute name "mmm_mem_<bn>_vl<vl>" is also handed to build_mmm as its
# caseName, so the two cannot drift apart.
# NOTE(review): build_mmm compiles ./mmm_<bn>_vl<vl>.S, so a matching assembly
# file presumably has to exist for every combination listed above -- confirm.
builtins.listToAttrs (builtins.map
(
{ bn, vl }:
let name = "mmm_mem_${toString bn}_vl${toString vl}"; in
lib.nameValuePair
name
(build_mmm name bn vl)
)
configs)
298 changes: 298 additions & 0 deletions tests/eval/_mmm_mem/mmm_256_vl128.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
# -----------------------------------------------------------------------------
# mmm -- RVV kernel operating on 20-word big-number buffers (BN = 256).
# The name and the structure (multiply-accumulate, reduce, shift down by one
# word) suggest Montgomery modular multiplication -- TODO confirm with caller.
#
# Register usage as visible in this routine:
#   a0      AB, the accumulator buffer; read and written in place
#   a1      vector operand of the first macc ("A" in the comments below)
#   a2      scalar operand stream ("B"); one 32-bit word is read per outer
#           iteration and a2 is advanced by 4, so a2 is clobbered
#   a3      vector operand of the second macc (presumably the modulus P --
#           confirm)
#   a4      scalar multiplied with AB[0] to form the second macc multiplier
#           (presumably the Montgomery constant -- confirm)
#   t0      scalar multiplier and general scratch
#   t1      byte stride (20)
#   t2      offset scratch / inner loop counter
#   t3      address scratch
#   t4      outer loop counter
#   t5      group index, always 0 in this variant
#   v0-v4   operand group loaded from a3
#   v10-v14 operand group loaded from a1
#   v20-v24 AB group
#   v30     carry ("TV" in the comments below)
#   v31     mask 0xFFFF / slide temporary ("TV2" in the comments below)
#
# Memory layout: each buffer is accessed as 4 segments of 5 consecutive 32-bit
# words (stride 20 bytes, vl = 4).  After a 5-field segment load into v20..v24,
# element i of register v(20+j) is word 5*i + j of the buffer.
# -----------------------------------------------------------------------------
.text
.balign 16
.globl mmm
.type mmm,@function
# assume VLEN >= 128, BN = 256, SEW = 16 * 2 = 32
# (values are reduced to 16 bits per word by the propagate steps below and are
# kept in 32-bit elements)
# we only support LMUL = 1 for now
# P, A, B, AB should have 20 elements
mmm:
# vector configuration: vl = 4 elements, SEW = 32, LMUL = 1
li t0, 4 # vl is passed in a register; presumably so the same pattern also works when vl > 31 -- confirm
vsetvli zero, t0, e32, m1, ta, ma
# byte stride between segments: 5 fields * 4 bytes
li t1, 20
# outer loop counter; the loop body runs 17 times (see the bne at the end)
li t4,0
1:
# AB = B_i*A + AB
# !!!!!! important: lw here assumes SEW = 32
# T0 is used in vmacc, do not use for temp now!
lw t0, 0(a2)
addi a2, a2, 4 # advance B by a SEW

# carry for ABV_0
# (v30 is zeroed again at the top of the propagate loop before it is read)
vmv.v.i v30,0
# loop variable (group index; stays 0 in this variant)
li t5,0

# ---
# macc (V=a1, VV=v10, VVN=10, ngroupreg=5)
# ---

# load one group of values from arg
# offset of one group
# NOTE: the generator comment below (nreg = 8, shift by 5) does not match
# ngroupreg = 5; it is harmless here because t5 is 0, so the offset is 0.
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
slli t2,t5,5
add t3,t2,a1
vlsseg5e32.v v10, (t3), t1
add t3,t2,a0
vlsseg5e32.v v20, (t3), t1
# v20..v24 += t0 * v10..v14
vmacc.vx v20, t0, v10
vmacc.vx v21, t0, v11
vmacc.vx v22, t0, v12
vmacc.vx v23, t0, v13
vmacc.vx v24, t0, v14
# store one group of AB
vssseg5e32.v v20, (t3), t1

# ---
# propagate_niter
# ---

# carry propagation loop; runs 4 times (t2 = 0..3, see the bne below) so that
# a carry can ripple through all 4 elements of a register
# use T2 as outer loop index
li t2,0
9:
# mask
# TV2 (v31) = 0xFFFF in every element
# it must be set on every iteration because v31 is overwritten by the
# vslide1up at the end of this loop body
li t0,65535
vmv.v.x v31,t0

# carry for ABV_0
vmv.v.i v30,0

# loop variable (group index; stays 0 in this variant)
li t5,0

# load last group of values from arg
# offset of last group
# NOTE: generator comment below; with t5 = 0 the offset is 0.
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t3,t5,5
add t3,t3,a0
vlsseg5e32.v v20, (t3), t1

# ---
# propagate (j=0, ngroupreg=5)
# ---

vadd.vv v20, v20, v30
# save carry in TV
vsrl.vi v30, v20, 16
# mod 2 ** 16
vand.vv v20, v20, v31
# add the carry to the next word
vadd.vv v21, v21, v30

# ---
# propagate (j=1, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v21, 16
# mod 2 ** 16
vand.vv v21, v21, v31
# add the carry to the next word
vadd.vv v22, v22, v30

# ---
# propagate (j=2, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v22, 16
# mod 2 ** 16
vand.vv v22, v22, v31
# add the carry to the next word
vadd.vv v23, v23, v30

# ---
# propagate (j=3, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v23, 16
# mod 2 ** 16
vand.vv v23, v23, v31
# add the carry to the next word
vadd.vv v24, v24, v30

# ---
# propagate (j=4, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v24, 16
# mod 2 ** 16
vand.vv v24, v24, v31
# store last group of AB
vssseg5e32.v v20, (t3), t1

# update carry of AB_{ntotalreg - 1} to AB_0
# the carry out of the last word of segment i belongs to the first word of
# segment i + 1: reload the first words (stride 20), shift the carries up by
# one element (element 0 receives 0) and add them
vlse32.v v20, (a0), t1
vslide1up.vx v31, v30, zero
vadd.vv v20, v20, v31
vsse32.v v20, (a0), t1
addi t2,t2,1
li t0,4
bne t2,t0,9b
# second multiplier: t0 = (AB[0] * a4) mod 2 ** 16
# !!!!!! important: lw here assumes SEW = 32
# T0 is used in vmacc, do not use for temp now!
lw t0, 0(a0)
mul t0, t0, a4
# mod 2 ** 16
# !!!! important: here we assume SEW = 32 and XLEN = 64
# NOTE(review): a left/right shift pair by 16 keeps only the low 16 bits when
# XLEN = 32; with XLEN = 64 it would keep the low 48 bits.  The XLEN = 64
# statement above looks inconsistent with the shift amounts -- confirm.
sll t0, t0, 16
srl t0, t0, 16

# loop variable (group index; stays 0 in this variant)
li t5,0

# ---
# macc (V=a3, VV=v0, VVN=0, ngroupreg=5)
# ---

# load one group of values from arg
# offset of one group
# NOTE: generator comment below; with t5 = 0 the offset is 0.
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
slli t2,t5,5
add t3,t2,a3
vlsseg5e32.v v0, (t3), t1
add t3,t2,a0
vlsseg5e32.v v20, (t3), t1
# v20..v24 += t0 * v0..v4
vmacc.vx v20, t0, v0
vmacc.vx v21, t0, v1
vmacc.vx v22, t0, v2
vmacc.vx v23, t0, v3
vmacc.vx v24, t0, v4
# store one group of AB
vssseg5e32.v v20, (t3), t1

# ---
# propagate_niter
# ---

# carry propagation loop, same as the first one; runs 4 times (t2 = 0..3)
# use T2 as outer loop index
li t2,0
9:
# mask
# TV2 (v31) = 0xFFFF in every element
# it must be set on every iteration because v31 is overwritten by the
# vslide1up at the end of this loop body
li t0,65535
vmv.v.x v31,t0

# carry for ABV_0
vmv.v.i v30,0

# loop variable (group index; stays 0 in this variant)
li t5,0

# load last group of values from arg
# offset of last group
# NOTE: generator comment below; with t5 = 0 the offset is 0.
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t3,t5,5
add t3,t3,a0
vlsseg5e32.v v20, (t3), t1

# ---
# propagate (j=0, ngroupreg=5)
# ---

vadd.vv v20, v20, v30
# save carry in TV
vsrl.vi v30, v20, 16
# mod 2 ** 16
vand.vv v20, v20, v31
# add the carry to the next word
vadd.vv v21, v21, v30

# ---
# propagate (j=1, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v21, 16
# mod 2 ** 16
vand.vv v21, v21, v31
# add the carry to the next word
vadd.vv v22, v22, v30

# ---
# propagate (j=2, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v22, 16
# mod 2 ** 16
vand.vv v22, v22, v31
# add the carry to the next word
vadd.vv v23, v23, v30

# ---
# propagate (j=3, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v23, 16
# mod 2 ** 16
vand.vv v23, v23, v31
# add the carry to the next word
vadd.vv v24, v24, v30

# ---
# propagate (j=4, ngroupreg=5)
# ---

# save carry in TV
vsrl.vi v30, v24, 16
# mod 2 ** 16
vand.vv v24, v24, v31
# store last group of AB
vssseg5e32.v v20, (t3), t1

# update carry of AB_{ntotalreg - 1} to AB_0
# same cross-segment carry hand-over as in the first propagate loop
vlse32.v v20, (a0), t1
vslide1up.vx v31, v30, zero
vadd.vv v20, v20, v31
vsse32.v v20, (a0), t1
addi t2,t2,1
li t0,4
bne t2,t0,9b

# update carry of AB_4 to AB_0
# since we need to subtract AB_0
# reload the first word of every segment and shift it down by one element
# (the last element receives 0); the result becomes the new last word of
# every segment in the move below
vlse32.v v20, (a0), t1
# AB / word
vslide1down.vx v30, v20, zero
# do not need vsse now
# just store it in TV for move

# -----
# move
# -----

# move AB_1 to AB_0, AB_2 to AB_1, ... , AB_0 (in TV now) to AB_4
# net effect on memory: every word of AB moves down by one position and the
# top word becomes 0
# loop variable (group index; stays 0 in this variant)
li t5,0
# load last group of values from arg
# offset of last group
# NOTE: generator comment below; with t5 = 0 the offset is 0.
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t2,t5,5
# then offset by 1 element
addi t2,t2,4
add t3,t2,a0
# v20..v23 = words 1..4 of every segment
vlsseg4e32.v v20, (t3), t1
# move AB_0 to AB_4
vmv.v.v v24, v30

# back to original offset
addi t3,t3,-4
vssseg5e32.v v20, (t3), t1

# next outer iteration; 17 iterations in total
addi t4,t4,1
li t0,17

bne t4,t0,1b

ret
Loading

0 comments on commit 2401b73

Please sign in to comment.