diff --git a/nimbus/evm/blscurve.nim b/nimbus/evm/blscurve.nim index 8e11dc56a..db8a5aa48 100644 --- a/nimbus/evm/blscurve.nim +++ b/nimbus/evm/blscurve.nim @@ -1,5 +1,5 @@ # Nimbus -# Copyright (c) 2020-2024 Status Research & Development GmbH +# Copyright (c) 2020-2025 Status Research & Development GmbH # Licensed under either of # * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or # http://www.apache.org/licenses/LICENSE-2.0) @@ -60,6 +60,13 @@ template toCC(x: auto): auto = elif x is BLS_G2P: toCC(x, cblst_p2_affine) +func isOverModulus(data: openArray[byte]): bool = + const + fieldModulus = StUint[512].fromHex "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" + var z: StUint[512] + z.initFromBytesBE(data) + z >= fieldModulus + func fromBytes*(ret: var BLS_SCALAR, raw: openArray[byte]): bool = const L = 32 if raw.len < L: @@ -73,6 +80,8 @@ func fromBytes(ret: var BLS_FP, raw: openArray[byte]): bool = if raw.len < L: return false let pa = cast[ptr array[L, byte]](raw[0].unsafeAddr) + if isOverModulus(pa[]): + return false blst_fp_from_bendian(toCV(ret), pa[]) true @@ -150,6 +159,12 @@ func pack(g: var BLS_G2P, x0, x1, y0, y1: BLS_FP): bool = g = blst_p2_affine(x: blst_fp2(fp: [x0, x1]), y: blst_fp2(fp: [y0, y1])) blst_p2_affine_on_curve(toCV(g)).int == 1 +func subgroupCheck*(P: BLS_G1): bool {.inline.} = + blst_p1_in_g1(toCC(P)).int == 1 + +func subgroupCheck*(P: BLS_G2): bool {.inline.} = + blst_p2_in_g2(toCC(P)).int == 1 + func subgroupCheck*(P: BLS_G1P): bool {.inline.} = blst_p1_affine_in_g1(toCC(P)).int == 1 diff --git a/nimbus/evm/precompiles.nim b/nimbus/evm/precompiles.nim index 6dde38beb..1646dcb23 100644 --- a/nimbus/evm/precompiles.nim +++ b/nimbus/evm/precompiles.nim @@ -1,5 +1,5 @@ # Nimbus -# Copyright (c) 2018-2024 Status Research & Development GmbH +# Copyright (c) 2018-2025 Status Research & Development GmbH # Licensed under either of # * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or # http://www.apache.org/licenses/LICENSE-2.0) @@ -481,6 +481,9 @@ func blsG1MultiExp(c: Computation): EvmResultVoid = if not p.decodePoint(input.toOpenArray(off, off+127)): return err(prcErr(PrcInvalidPoint)) + if not p.subgroupCheck: + return err(prcErr(PrcInvalidPoint)) + # Decode scalar value if not s.fromBytes(input.toOpenArray(off+128, off+159)): return err(prcErr(PrcInvalidParam)) @@ -546,6 +549,9 @@ func blsG2MultiExp(c: Computation): EvmResultVoid = if not p.decodePoint(input.toOpenArray(off, off+255)): return err(prcErr(PrcInvalidPoint)) + if not p.subgroupCheck: + return err(prcErr(PrcInvalidPoint)) + # Decode scalar value if not s.fromBytes(input.toOpenArray(off+256, off+287)): return err(prcErr(PrcInvalidParam))