From 382c1eea4ab158e9583bb1500e580405ba82510d Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Wed, 17 Jul 2024 11:15:14 +0200 Subject: [PATCH] Add Or() and AndNot() --- README.md | 2 +- and_amd64.go | 18 +++++++ and_amd64.s | 112 ++++++++++++++++++++++++++++++++++++++++---- and_arm64.go | 17 +++++++ and_arm64.s | 49 ++++++++++++++++++- and_stubs_amd64.go | 10 ++++ and_test.go | 84 ++++++++++++++++++++++++++++++--- fallback.go | 8 ++++ internal/asm/src.go | 15 ++++-- lib.go | 60 ++++++++++++++++++++++++ 10 files changed, 353 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 3ed3802..891ab9d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ go-and ====== -Fast bitwise and for `[]byte` slices. +Fast bitwise and, or and andn for `[]byte` slices. ```go import "github.com/bwesterb/go-and" diff --git a/and_amd64.go b/and_amd64.go index 353025a..2b48d63 100644 --- a/and_amd64.go +++ b/and_amd64.go @@ -8,3 +8,21 @@ func and(dst, a, b []byte) { l <<= 8 andGeneric(dst[l:], a[l:], b[l:]) } + +func or(dst, a, b []byte) { + l := uint64(len(a)) >> 8 + if l != 0 { + orAVX2(&dst[0], &a[0], &b[0], l) + } + l <<= 8 + orGeneric(dst[l:], a[l:], b[l:]) +} + +func andNot(dst, a, b []byte) { + l := uint64(len(a)) >> 8 + if l != 0 { + andNotAVX2(&dst[0], &a[0], &b[0], l) + } + l <<= 8 + andNotGeneric(dst[l:], a[l:], b[l:]) +} diff --git a/and_amd64.s b/and_amd64.s index e74c47f..1cd0438 100644 --- a/and_amd64.s +++ b/and_amd64.s @@ -27,14 +27,110 @@ loop: VMOVDQU 192(CX), Y14 VMOVDQU 224(AX), Y7 VMOVDQU 224(CX), Y15 - VPAND Y0, Y8, Y8 - VPAND Y1, Y9, Y9 - VPAND Y2, Y10, Y10 - VPAND Y3, Y11, Y11 - VPAND Y4, Y12, Y12 - VPAND Y5, Y13, Y13 - VPAND Y6, Y14, Y14 - VPAND Y7, Y15, Y15 + VPAND Y8, Y0, Y8 + VPAND Y9, Y1, Y9 + VPAND Y10, Y2, Y10 + VPAND Y11, Y3, Y11 + VPAND Y12, Y4, Y12 + VPAND Y13, Y5, Y13 + VPAND Y14, Y6, Y14 + VPAND Y15, Y7, Y15 + VMOVDQU Y8, (DX) + VMOVDQU Y9, 32(DX) + VMOVDQU Y10, 64(DX) + VMOVDQU Y11, 96(DX) + VMOVDQU Y12, 128(DX) + VMOVDQU Y13, 
160(DX) + VMOVDQU Y14, 192(DX) + VMOVDQU Y15, 224(DX) + ADDQ $0x00000100, AX + ADDQ $0x00000100, CX + ADDQ $0x00000100, DX + SUBQ $0x00000001, BX + JNZ loop + RET + +// func orAVX2(dst *byte, a *byte, b *byte, l uint64) +// Requires: AVX, AVX2 +TEXT ·orAVX2(SB), NOSPLIT, $0-32 + MOVQ a+8(FP), AX + MOVQ b+16(FP), CX + MOVQ dst+0(FP), DX + MOVQ l+24(FP), BX + +loop: + VMOVDQU (AX), Y0 + VMOVDQU (CX), Y8 + VMOVDQU 32(AX), Y1 + VMOVDQU 32(CX), Y9 + VMOVDQU 64(AX), Y2 + VMOVDQU 64(CX), Y10 + VMOVDQU 96(AX), Y3 + VMOVDQU 96(CX), Y11 + VMOVDQU 128(AX), Y4 + VMOVDQU 128(CX), Y12 + VMOVDQU 160(AX), Y5 + VMOVDQU 160(CX), Y13 + VMOVDQU 192(AX), Y6 + VMOVDQU 192(CX), Y14 + VMOVDQU 224(AX), Y7 + VMOVDQU 224(CX), Y15 + VPOR Y8, Y0, Y8 + VPOR Y9, Y1, Y9 + VPOR Y10, Y2, Y10 + VPOR Y11, Y3, Y11 + VPOR Y12, Y4, Y12 + VPOR Y13, Y5, Y13 + VPOR Y14, Y6, Y14 + VPOR Y15, Y7, Y15 + VMOVDQU Y8, (DX) + VMOVDQU Y9, 32(DX) + VMOVDQU Y10, 64(DX) + VMOVDQU Y11, 96(DX) + VMOVDQU Y12, 128(DX) + VMOVDQU Y13, 160(DX) + VMOVDQU Y14, 192(DX) + VMOVDQU Y15, 224(DX) + ADDQ $0x00000100, AX + ADDQ $0x00000100, CX + ADDQ $0x00000100, DX + SUBQ $0x00000001, BX + JNZ loop + RET + +// func andNotAVX2(dst *byte, a *byte, b *byte, l uint64) +// Requires: AVX, AVX2 +TEXT ·andNotAVX2(SB), NOSPLIT, $0-32 + MOVQ a+8(FP), AX + MOVQ b+16(FP), CX + MOVQ dst+0(FP), DX + MOVQ l+24(FP), BX + +loop: + VMOVDQU (AX), Y0 + VMOVDQU (CX), Y8 + VMOVDQU 32(AX), Y1 + VMOVDQU 32(CX), Y9 + VMOVDQU 64(AX), Y2 + VMOVDQU 64(CX), Y10 + VMOVDQU 96(AX), Y3 + VMOVDQU 96(CX), Y11 + VMOVDQU 128(AX), Y4 + VMOVDQU 128(CX), Y12 + VMOVDQU 160(AX), Y5 + VMOVDQU 160(CX), Y13 + VMOVDQU 192(AX), Y6 + VMOVDQU 192(CX), Y14 + VMOVDQU 224(AX), Y7 + VMOVDQU 224(CX), Y15 + VPANDN Y8, Y0, Y8 + VPANDN Y9, Y1, Y9 + VPANDN Y10, Y2, Y10 + VPANDN Y11, Y3, Y11 + VPANDN Y12, Y4, Y12 + VPANDN Y13, Y5, Y13 + VPANDN Y14, Y6, Y14 + VPANDN Y15, Y7, Y15 VMOVDQU Y8, (DX) VMOVDQU Y9, 32(DX) VMOVDQU Y10, 64(DX) diff --git a/and_arm64.go b/and_arm64.go index 
f436470..ffecf2d 100644 --- a/and_arm64.go +++ b/and_arm64.go @@ -3,6 +3,9 @@ package and //go:noescape func andNEON(dst, a, b *byte, len uint64) +//go:noescape +func orNEON(dst, a, b *byte, len uint64) + func and(dst, a, b []byte) { l := uint64(len(a)) >> 8 if l != 0 { @@ -11,3 +14,17 @@ func and(dst, a, b []byte) { l <<= 8 andGeneric(dst[l:], a[l:], b[l:]) } + +func or(dst, a, b []byte) { + l := uint64(len(a)) >> 8 + if l != 0 { + orNEON(&dst[0], &a[0], &b[0], l) + } + l <<= 8 + orGeneric(dst[l:], a[l:], b[l:]) +} + +func andNot(dst, a, b []byte) { + // TODO: Write a NEON version for this + andNotGeneric(dst, a, b) +} diff --git a/and_arm64.s b/and_arm64.s index d3a4a80..9658ab6 100644 --- a/and_arm64.s +++ b/and_arm64.s @@ -46,5 +46,52 @@ loop: SUBS $1, R3, R3 CBNZ R3, loop - + + RET + +// func orNEON(dst *byte, a *byte, b *byte, l uint64) +TEXT ·orNEON(SB), NOSPLIT, $0-32 + MOVD dst+0(FP), R0 + MOVD a+8(FP), R1 + MOVD b+16(FP), R2 + MOVD l+24(FP), R3 + +loop: + VLD1.P 64(R1), [ V0.B16, V1.B16, V2.B16, V3.B16] + VLD1.P 64(R2), [ V4.B16, V5.B16, V6.B16, V7.B16] + VLD1.P 64(R1), [ V8.B16, V9.B16, V10.B16, V11.B16] + VLD1.P 64(R2), [V12.B16, V13.B16, V14.B16, V15.B16] + VLD1.P 64(R1), [V16.B16, V17.B16, V18.B16, V19.B16] + VLD1.P 64(R2), [V20.B16, V21.B16, V22.B16, V23.B16] + VLD1.P 64(R1), [V24.B16, V25.B16, V26.B16, V27.B16] + VLD1.P 64(R2), [V28.B16, V29.B16, V30.B16, V31.B16] + + VORR V0.B16, V4.B16, V0.B16 + VORR V1.B16, V5.B16, V1.B16 + VORR V2.B16, V6.B16, V2.B16 + VORR V3.B16, V7.B16, V3.B16 + + VORR V8.B16, V12.B16, V8.B16 + VORR V9.B16, V13.B16, V9.B16 + VORR V10.B16, V14.B16, V10.B16 + VORR V11.B16, V15.B16, V11.B16 + + VORR V16.B16, V20.B16, V16.B16 + VORR V17.B16, V21.B16, V17.B16 + VORR V18.B16, V22.B16, V18.B16 + VORR V19.B16, V23.B16, V19.B16 + + VORR V24.B16, V28.B16, V24.B16 + VORR V25.B16, V29.B16, V25.B16 + VORR V26.B16, V30.B16, V26.B16 + VORR V27.B16, V31.B16, V27.B16 + + VST1.P [ V0.B16, V1.B16, V2.B16, V3.B16], 64(R0) + VST1.P [ V8.B16, 
V9.B16, V10.B16, V11.B16], 64(R0) + VST1.P [V16.B16, V17.B16, V18.B16, V19.B16], 64(R0) + VST1.P [V24.B16, V25.B16, V26.B16, V27.B16], 64(R0) + + SUBS $1, R3, R3 + CBNZ R3, loop + RET diff --git a/and_stubs_amd64.go b/and_stubs_amd64.go index 395b2fa..32a2071 100644 --- a/and_stubs_amd64.go +++ b/and_stubs_amd64.go @@ -6,3 +6,13 @@ package and // //go:noescape func andAVX2(dst *byte, a *byte, b *byte, l uint64) + +// Sets dst to the bitwise or of a and b assuming all are 256*l bytes +// +//go:noescape +func orAVX2(dst *byte, a *byte, b *byte, l uint64) + +// Sets dst to the bitwise and of not(a) and b assuming all are 256*l bytes +// +//go:noescape +func andNotAVX2(dst *byte, a *byte, b *byte, l uint64) diff --git a/and_test.go b/and_test.go index 55c3fe1..9103da4 100644 --- a/and_test.go +++ b/and_test.go @@ -3,10 +3,12 @@ package and import ( "bytes" "math/rand/v2" + "reflect" + "runtime" "testing" ) -func testAgainstGeneric(t *testing.T, size int) { +func testAgainstGeneric(t *testing.T, fancy, generic func(dst, a, b []byte), size int) { a := make([]byte, size) b := make([]byte, size) c1 := make([]byte, size) @@ -16,19 +18,39 @@ func testAgainstGeneric(t *testing.T, size int) { a[i] = uint8(rng.UintN(256)) b[i] = uint8(rng.UintN(256)) } - And(c1, a, b) - andGeneric(c2, a, b) + fancy(c1, a, b) + generic(c2, a, b) if !bytes.Equal(c1, c2) { - t.Fatalf("And produced a different result from andGeneric at length %d:\n%x\n%x", size, c1, c2) + t.Fatalf("%s produced a different result from %s at length %d:\n%x\n%x", runtime.FuncForPC(reflect.ValueOf(fancy).Pointer()).Name(), runtime.FuncForPC(reflect.ValueOf(generic).Pointer()).Name(), size, c1, c2) } } -func TestAgainstGeneric(t *testing.T) { +func TestAndAgainstGeneric(t *testing.T) { for i := 0; i < 20; i++ { size := 1 << i - testAgainstGeneric(t, size) + testAgainstGeneric(t, And, andGeneric, size) for j := 0; j < 10; j++ { - testAgainstGeneric(t, size+rand.IntN(100)) + testAgainstGeneric(t, And, andGeneric, 
size+rand.IntN(100)) + } + } +} + +func TestOrAgainstGeneric(t *testing.T) { + for i := 0; i < 20; i++ { + size := 1 << i + testAgainstGeneric(t, Or, orGeneric, size) + for j := 0; j < 10; j++ { + testAgainstGeneric(t, Or, orGeneric, size+rand.IntN(100)) + } + } +} + +func TestAndNotAgainstGeneric(t *testing.T) { + for i := 0; i < 20; i++ { + size := 1 << i + testAgainstGeneric(t, AndNot, andNotGeneric, size) + for j := 0; j < 10; j++ { + testAgainstGeneric(t, AndNot, andNotGeneric, size+rand.IntN(100)) } } } @@ -56,3 +78,51 @@ func BenchmarkAndGeneric(b *testing.B) { andGeneric(a, a, bb) } } + +func BenchmarkOr(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + Or(a, a, bb) + } +} + +func BenchmarkOrGeneric(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + orGeneric(a, a, bb) + } +} + +func BenchmarkAndNot(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + AndNot(a, a, bb) + } +} + +func BenchmarkAndNotGeneric(b *testing.B) { + b.StopTimer() + size := 1000000 + a := make([]byte, size) + bb := make([]byte, size) + b.SetBytes(int64(size)) + b.StartTimer() + for i := 0; i < b.N; i++ { + andNotGeneric(a, a, bb) + } +} diff --git a/fallback.go b/fallback.go index a9077a8..f4a7a37 100644 --- a/fallback.go +++ b/fallback.go @@ -5,3 +5,11 @@ package and func and(dst, a, b []byte) { andGeneric(dst, a, b) } + +func or(dst, a, b []byte) { + orGeneric(dst, a, b) +} + +func andNot(dst, a, b []byte) { + andNotGeneric(dst, a, b) +} diff --git a/internal/asm/src.go b/internal/asm/src.go index d4f3e80..79e4402 100644 --- a/internal/asm/src.go +++ b/internal/asm/src.go @@ -7,12 +7,18 @@ import ( ) func main() 
{ - // Must be called on 32 byte aligned a, b, dst. - TEXT("andAVX2", NOSPLIT, "func(dst, a, b *byte, l uint64)") + gen("and", VPAND, "Sets dst to the bitwise and of a and b") + gen("or", VPOR, "Sets dst to the bitwise or of a and b") + gen("andNot", VPANDN, "Sets dst to the bitwise and of not(a) and b") + Generate() +} + +func gen(name string, op func(Op, Op, Op), doc string) { + TEXT(name+"AVX2", NOSPLIT, "func(dst, a, b *byte, l uint64)") Pragma("noescape") - Doc("Sets dst to the bitwise and of a and b assuming all are 256*l bytes") + Doc(doc + " assuming all are 256*l bytes") a := Load(Param("a"), GP64()) b := Load(Param("b"), GP64()) dst := Load(Param("dst"), GP64()) @@ -28,7 +34,7 @@ func main() { VMOVDQU(Mem{Base: b, Disp: 32 * i}, bs[i]) } for i := 0; i < len(as); i++ { - VPAND(as[i], bs[i], bs[i]) + op(bs[i], as[i], bs[i]) } for i := 0; i < len(as); i++ { VMOVDQU(bs[i], Mem{Base: dst, Disp: 32 * i}) @@ -41,5 +47,4 @@ func main() { JNZ(LabelRef("loop")) RET() - Generate() } diff --git a/lib.go b/lib.go index 9e71a3f..4a67f0d 100644 --- a/lib.go +++ b/lib.go @@ -34,3 +34,63 @@ func andGeneric(dst, a, b []byte) { dst[i] = a[i] & b[i] } } + +// Or writes the bitwise or of a and b to dst, i.e. dst[i] = a[i] | b[i]. +// +// Panics if len(a) ≠ len(b), or len(dst) ≠ len(a). +func Or(dst, a, b []byte) { + if len(a) != len(b) || len(b) != len(dst) { + panic("lengths of a, b and dst must be equal") + } + + if hasAVX2() { + or(dst, a, b) + return + } + orGeneric(dst, a, b) +} + +func orGeneric(dst, a, b []byte) { + i := 0 + + for ; i <= len(a)-8; i += 8 { + binary.LittleEndian.PutUint64( + dst[i:], + binary.LittleEndian.Uint64(a[i:])|binary.LittleEndian.Uint64(b[i:]), + ) + } + + for ; i < len(a); i++ { + dst[i] = a[i] | b[i] + } +} + +// AndNot writes the bitwise and of not(a) and b to dst, i.e. dst[i] = ^a[i] & b[i] (the negated operand is a, the reverse of Go's a &^ b). +// +// Panics if len(a) ≠ len(b), or len(dst) ≠ len(a). 
+func AndNot(dst, a, b []byte) {
+	if len(a) != len(b) || len(b) != len(dst) { // all three slices must have the same length
+		panic("lengths of a, b and dst must be equal")
+	}
+
+	if hasAVX2() { // NOTE(review): hasAVX2 is not defined in this patch; presumably a CPU-feature gate — confirm
+		andNot(dst, a, b) // per-architecture implementation (and_amd64.go, and_arm64.go, fallback.go)
+		return
+	}
+	andNotGeneric(dst, a, b) // portable implementation defined below
+}
+
+func andNotGeneric(dst, a, b []byte) { // computes dst[i] = ^a[i] & b[i] for every i < len(a)
+	i := 0
+
+	for ; i <= len(a)-8; i += 8 { // bulk loop: 8 bytes per iteration as one uint64; skipped when len(a) < 8
+		binary.LittleEndian.PutUint64(
+			dst[i:],
+			(^binary.LittleEndian.Uint64(a[i:]))&binary.LittleEndian.Uint64(b[i:]),
+		)
+	}
+
+	for ; i < len(a); i++ { // tail loop: the remaining len(a)%8 bytes, one at a time
+		dst[i] = (^a[i]) & b[i]
+	}
+}