From 36e27d5883928b288f91aefef3c22cd13ed04e39 Mon Sep 17 00:00:00 2001 From: Nikita Orlov Date: Mon, 1 Jul 2024 16:46:31 +0200 Subject: [PATCH] #519 improve set integration bench (#527) * #519 improve set integration bench closes #519 issue * fix + improved perf * fix error name --------- Co-authored-by: lanaivina <31368580+lana-shanghai@users.noreply.github.com> --- build.zig.zon | 4 +- src/hint_processor/set.zig | 17 +++++- src/vm/core_test.zig | 4 +- src/vm/memory/memory.zig | 101 +++++++++++++++++++--------------- src/vm/memory/relocatable.zig | 11 +++- 5 files changed, 87 insertions(+), 50 deletions(-) diff --git a/build.zig.zon b/build.zig.zon index 6c1a1c83..db74dd41 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -14,8 +14,8 @@ .hash = "1220ab73fb7cc11b2308edc3364988e05efcddbcac31b707f55e6216d1b9c0da13f1", }, .starknet = .{ - .url = "https://github.com/StringNick/starknet-zig/archive/8cfb4286ffda4ad2781647c3d96b2aec8ccfeb32.zip", - .hash = "122026eaa24834fd2e2cc7e8b6c4eefb03dda08158a2844615f189758fa24d32fc44", + .url = "https://github.com/StringNick/starknet-zig/archive/57810b7a64364f1bf12725ba823385c2a213bfa5.zip", + .hash = "1220d848be799ff21a80c6751c088ea619891ec450f20017cc7aa5cbbeb5904ae8b8", }, }, } diff --git a/src/hint_processor/set.zig b/src/hint_processor/set.zig index 23bfd51e..cfb5b732 100644 --- a/src/hint_processor/set.zig +++ b/src/hint_processor/set.zig @@ -12,6 +12,7 @@ const HintProcessor = @import("hint_processor_def.zig").CairoVMHintProcessor; const HintData = @import("hint_processor_def.zig").HintData; const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable; const MaybeRelocatable = @import("../vm/memory/relocatable.zig").MaybeRelocatable; +const MemoryCell = @import("../vm/memory/memory.zig").MemoryCell; const Felt252 = @import("../math/fields/starknet.zig").Felt252; const hint_codes = @import("builtin_hint_codes.zig"); const MathError = @import("../vm/error.zig").MathError; @@ -60,10 +61,22 @@ pub fn setAdd( // Calculate the range limit. const range_limit = (try set_end_ptr.sub(set_ptr)).offset; + // load all list, and then we compare elements + var elm_segment = vm.segments.memory.getSegmentAtIndex(elm_ptr.segment_index) orelse return HintError.InvalidSetRange; + + if (elm_segment.len < elm_ptr.offset + elm_size) return HintError.InvalidSetRange; + + var set_segment = vm.segments.memory.getSegmentAtIndex(set_ptr.segment_index) orelse return HintError.InvalidSetRange; + + if (set_ptr.offset + range_limit > set_segment.len) return HintError.InvalidSetRange; + + elm_segment = elm_segment[elm_ptr.offset .. elm_ptr.offset + elm_size]; + set_segment = set_segment[set_ptr.offset .. set_ptr.offset + range_limit]; + // Iterate over the set elements. - for (0..range_limit) |i| { + for (0..range_limit / elm_size) |i| { // Check if the element is in the set. - if (try vm.memEq(elm_ptr, try set_ptr.addUint(elm_size * i), elm_size)) { + if (MemoryCell.eqlSlice(elm_segment, set_segment[i * elm_size .. (i + 1) * elm_size])) { // Insert index of the element into the virtual machine. try hint_utils.insertValueFromVarName( allocator, diff --git a/src/vm/core_test.zig b/src/vm/core_test.zig index ff677121..1af05223 100644 --- a/src/vm/core_test.zig +++ b/src/vm/core_test.zig @@ -3666,7 +3666,7 @@ test "CairoVM: runInstruction without any insertion in the memory" { // Compare each cell in VM's memory with the corresponding cell in the expected memory. for (vm.segments.memory.data.items, 0..) |d, i| { for (d.items, 0..) |cell, j| { - try expect(cell.eql(expected_memory.data.items[i].items[j])); + try expect(cell.eql(&expected_memory.data.items[i].items[j])); } } } @@ -3839,7 +3839,7 @@ test "CairoVM: runInstruction with Op0 being deduced" { // Compare each cell in VM's memory with the corresponding cell in the expected memory. for (vm.segments.memory.data.items, 0..) |d, i| { for (d.items, 0..) |cell, j| { - try expect(cell.eql(expected_memory.data.items[i].items[j])); + try expect(cell.eql(&expected_memory.data.items[i].items[j])); } } } diff --git a/src/vm/memory/memory.zig b/src/vm/memory/memory.zig index a022ce46..03d38d2f 100644 --- a/src/vm/memory/memory.zig +++ b/src/vm/memory/memory.zig @@ -23,7 +23,7 @@ const RangeCheckBuiltinRunner = @import("../builtins/builtin_runner/range_check. // Function that validates a memory address and returns a list of validated adresses pub const validation_rule = *const fn (Allocator, *Memory, Relocatable) anyerror!std.ArrayList(Relocatable); -pub const MemoryCell = struct { +pub const MemoryCell = extern struct { /// Represents a memory cell that holds relocation information and access status. const Self = @This(); const ACCESS_MASK: u64 = 1 << 62; @@ -103,8 +103,12 @@ pub const MemoryCell = struct { /// # Returns /// /// Returns `true` if both MemoryCell instances are equal, otherwise `false`. - pub fn eql(self: Self, other: Self) bool { - return std.mem.eql(u64, self.data[0..], other.data[0..]); + pub fn eql(self: *const Self, other: *const Self) bool { + inline for (0..4) |i| { + if (self.data[i] != other.data[i]) return false; + } + + return true; } /// Checks equality between slices of MemoryCell instances. @@ -124,7 +128,7 @@ pub const MemoryCell = struct { if (a.len != b.len) return false; if (a.ptr == b.ptr) return true; - for (a, b) |a_elem, b_elem| { + for (a, b) |*a_elem, *b_elem| { if (!a_elem.eql(b_elem)) return false; } @@ -609,20 +613,11 @@ pub const Memory = struct { /// # Returns /// /// Returns the segment of MemoryCell items if it exists, or `null` if not found. - fn getSegmentAtIndex(self: *Self, idx: i64) ?[]MemoryCell { - return switch (idx < 0) { - true => blk: { - const i: usize = @intCast(-(idx + 1)); - break :blk if (i < self.temp_data.items.len) - self.temp_data.items[i].items - else - null; - }, - false => if (idx < self.data.items.len) - self.data.items[@intCast(idx)].items - else - null, - }; + pub inline fn getSegmentAtIndex(self: *const Self, idx: i64) ?[]MemoryCell { + return if (idx < 0) { + const i: usize = @bitCast(-(idx + 1)); + return if (i >= self.temp_data.items.len) null else self.temp_data.items[i].items; + } else if (idx >= self.data.items.len) null else self.data.items[@intCast(idx)].items; } /// Compares two memory segments within the VM's memory starting from specified addresses @@ -663,12 +658,6 @@ pub const Memory = struct { const l_idx = lhs.offset + i; const r_idx = rhs.offset + i; - // std.log.err("lhs: {any}, rhs: {any}, i: {any}, {any}", .{ - // if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE, if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE, i, MemoryCell.cmp( - // if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE, - // if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE, - // ), - // }); return switch (MemoryCell.cmp( if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE, if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE, @@ -700,7 +689,7 @@ pub const Memory = struct { /// # Returns /// /// Returns `true` if segments are equal up to the specified length, otherwise `false`. - pub fn memEq(self: *Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool { + pub fn memEq(self: *const Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool { // Check if the left and right addresses are the same, in which case the segments are equal. if (lhs.eq(rhs)) return true; @@ -714,29 +703,25 @@ pub const Memory = struct { // Get the segment starting from the right-hand address. const r: ?[]MemoryCell = if (self.getSegmentAtIndex(rhs.segment_index)) |s| // Check if the offset is within the bounds of the segment. - if (rhs.offset < s.len) s[rhs.offset..] else if (l == null) return true else return false - else if (l == null) return true else return false; + if (rhs.offset < s.len) s[rhs.offset..] else return l == null + else + return l == null; // If the left segment exists, perform further checks. if (l) |ls| { // If the right segment also exists, compare the segments up to the specified length. - if (r) |rs| { - // Determine the actual lengths to compare. - const lhs_len = @min(ls.len, len); - const rhs_len = @min(rs.len, len); + // Determine the actual lengths to compare. + const lhs_len = @min(ls.len, len); + const rhs_len = @min(r.?.len, len); - // Compare slices of MemoryCell items up to the specified length. - if (lhs_len != rhs_len) return false; - - return MemoryCell.eqlSlice(ls[0..lhs_len], rs[0..rhs_len]); - } + // Compare slices of MemoryCell items up to the specified length. + if (lhs_len != rhs_len) return false; - // If only the left segment exists, return false. - return false; + return MemoryCell.eqlSlice(ls[0..lhs_len], r.?[0..rhs_len]); } - // If the left segment does not exist, return true only if the right segment is also null. - return r == null; + // If only the left segment exists, return false. + return false; } /// Retrieves a range of memory values starting from a specified address. @@ -769,6 +754,36 @@ pub const Memory = struct { return values; } + /// Retrieves a range of memory values starting from a specified address. + /// + /// # Arguments + /// + /// * `allocator`: The allocator used for the memory allocation of the returned list. + /// * `address`: The starting address in the memory from which the range is retrieved. + /// * `size`: The size of the range to be retrieved. + /// + /// # Returns + /// + /// Returns a list containing memory values retrieved from the specified range starting at the given address. + /// The list may contain `MemoryCell.NONE` elements for inaccessible memory positions. + /// + /// # Errors + /// + /// Returns an error if there are any issues encountered during the retrieval of the memory range. + pub fn getRangeRaw( + self: *Self, + allocator: Allocator, + address: Relocatable, + size: usize, + ) !std.ArrayList(?MaybeRelocatable) { + var values = std.ArrayList(?MaybeRelocatable).init(allocator); + errdefer values.deinit(); + for (0..size) |i| { + try values.append(self.get(try address.addUint(i))); + } + return values; + } + /// Counts the number of accessed addresses within a specified segment in the VM memory. /// /// # Arguments @@ -2426,9 +2441,9 @@ test "MemoryCell: eql function" { memoryCell4.markAccessed(); // Test checks - try expect(memoryCell1.eql(memoryCell2)); - try expect(!memoryCell1.eql(memoryCell3)); - try expect(!memoryCell1.eql(memoryCell4)); + try expect(memoryCell1.eql(&memoryCell2)); + try expect(!memoryCell1.eql(&memoryCell3)); + try expect(!memoryCell1.eql(&memoryCell4)); } test "MemoryCell: eqlSlice should return false if slice len are not the same" { diff --git a/src/vm/memory/relocatable.zig b/src/vm/memory/relocatable.zig index 96bd4672..a9e538df 100644 --- a/src/vm/memory/relocatable.zig +++ b/src/vm/memory/relocatable.zig @@ -312,7 +312,16 @@ pub const MaybeRelocatable = union(enum) { /// * `true` if the two instances are equal. /// * `false` otherwise. pub fn eq(self: Self, other: Self) bool { - return std.meta.eql(self, other); + return switch (self) { + inline .felt => |f| switch (other) { + inline .felt => |f1| f.eql(f1), + else => false, + }, + inline .relocatable => |r| switch (other) { + inline .relocatable => |r1| r.eq(r1), + else => false, + }, + }; } /// Determines if self is less than other.