From 36983147ace2c61b6361da7cbb7a737296f2bcb3 Mon Sep 17 00:00:00 2001 From: Karoy Lorentey Date: Mon, 28 Nov 2022 23:39:35 -0800 Subject: [PATCH 1/2] [HashTreeCollections] Add `TreeDictionary.combining(_:by:)` --- .../HashNode/_HashNode+Builder.swift | 136 +++- .../HashNode/_HashNode+Lookups.swift | 28 +- .../_HashNode+Structural subtracting.swift | 2 +- .../HashNode/_HashNode+Subtree Removals.swift | 28 + .../HashNode/_Node+Structural combining.swift | 763 ++++++++++++++++++ .../TreeDictionary+Combine.swift | 61 ++ 6 files changed, 1000 insertions(+), 18 deletions(-) create mode 100644 Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift create mode 100644 Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift diff --git a/Sources/HashTreeCollections/HashNode/_HashNode+Builder.swift b/Sources/HashTreeCollections/HashNode/_HashNode+Builder.swift index 13a0a3343..2a7436812 100644 --- a/Sources/HashTreeCollections/HashNode/_HashNode+Builder.swift +++ b/Sources/HashTreeCollections/HashNode/_HashNode+Builder.swift @@ -70,6 +70,16 @@ extension _HashNode.Builder { Self(level, .item(item, at: bucket)) } + @inlinable @inline(__always) + internal static func anyNode( + _ level: _HashLevel, _ node: __owned _HashNode + ) -> Self { + if node.isCollisionNode { + return self.collisionNode(level, node) + } + return self.node(level, node) + } + @inlinable @inline(__always) internal static func node( _ level: _HashLevel, _ node: __owned _HashNode @@ -125,9 +135,33 @@ extension _HashNode.Builder { } } + @inlinable + internal init( + _ level: _HashLevel, + collisions1: __owned Self, + _ hash1: _Hash, + collisions2: __owned Self, + _ hash2: _Hash + ) { + assert(hash1 != hash2) + let b1 = hash1[level] + let b2 = hash2[level] + self = .empty(level) + if b1 == b2 { + let b = Self( + level.descend(), + collisions1: collisions1, hash1, + collisions2: collisions2, hash2) + self.addNewChildBranch(level, b, at: b1) + } else { + self.addNewChildBranch(level, collisions1, at: b1) + self.addNewChildBranch(level, collisions2, at: b2) + } + } + @inlinable internal __consuming func finalize(_ level: _HashLevel) -> _HashNode { - assert(level.isAtRoot && self.level.isAtRoot) + //assert(level.isAtRoot && self.level.isAtRoot) switch kind { case .empty: return ._emptyNode() @@ -193,10 +227,12 @@ extension _HashNode.Builder { _ level: _HashLevel, _ newItem: __owned Element, at newBucket: _Bucket ) { assert(level == self.level) + assert(!newBucket.isInvalid) switch kind { case .empty: kind = .item(newItem, at: newBucket) case .item(let oldItem, let oldBucket): + assert(!oldBucket.isInvalid) assert(oldBucket != newBucket) let node = _HashNode._regularNode(oldItem, oldBucket, newItem, newBucket) kind = .node(node) @@ -215,6 +251,17 @@ extension _HashNode.Builder { } } + @inlinable + internal mutating func addNewItem( + _ level: _HashLevel, + _ key: Key, + _ value: __owned Value?, + at newBucket: _Bucket + ) { + guard let value = value else { return } + addNewItem(level, (key, value), at: newBucket) + } + @inlinable internal mutating func addNewChildNode( _ level: _HashLevel, _ newChild: __owned _HashNode, at newBucket: _Bucket @@ -358,3 +405,90 @@ extension _HashNode.Builder { return mapValues { _ in () } } } + +extension _HashNode.Builder { + @inlinable + internal static func conflictingItems( + _ level: _HashLevel, + _ item1: Element?, + _ item2: Element?, + at bucket: _Bucket + ) -> Self { + switch (item1, item2) { + case (nil, nil): + return .empty(level) + case let (item1?, nil): + return .item(level, item1, at: bucket) + case let (nil, item2?): + return .item(level, item2, at: bucket) + case let (item1?, item2?): + let h1 = _Hash(item1.key) + let h2 = _Hash(item2.key) + guard h1 != h2 else { + return .collisionNode(level, _HashNode._collisionNode(h1, item1, item2)) + } + let n = _HashNode._build( + level: level.descend(), + item1: item1, h1, + item2: { $0.initialize(to: item2) }, h2) + return .node(level, n.top) + } + } + + @inlinable + internal static func mergedUniqueBranch( + _ level: _HashLevel, + _ node: _HashNode, + by merge: (Element) throws -> Value? + ) rethrows -> Self { + try node.read { l in + var result = Self.empty(level) + if l.isCollisionNode { + let hash = l.collisionHash + for lslot: _HashSlot in .zero ..< l.itemsEndSlot { + let lp = l.itemPtr(at: lslot) + if let v = try merge(lp.pointee) { + result.addNewCollision(level, (lp.pointee.key, v), hash) + } + } + return result + } + for (bucket, lslot) in l.itemMap { + let lp = l.itemPtr(at: lslot) + let v = try merge(lp.pointee) + if let v = v { + result.addNewItem(level, (lp.pointee.key, v), at: bucket) + } + } + for (bucket, lslot) in l.childMap { + let b = try Self.mergedUniqueBranch( + level.descend(), l[child: lslot], by: merge) + result.addNewChildBranch(level, b, at: bucket) + } + return result + } + } + + @inlinable + internal mutating func addNewItems( + _ level: _HashLevel, + at bucket: _Bucket, + item1: Element?, + item2: Element? + ) { + switch (item1, item2) { + case (nil, nil): + break + case let (item1?, nil): + self.addNewItem(level, item1, at: bucket) + case let (nil, item2?): + self.addNewItem(level, item2, at: bucket) + case let (item1?, item2?): + let n = _HashNode._build( + level: level, + item1: item1, _Hash(item1.key), + item2: { $0.initialize(to: item2) }, _Hash(item2.key)) + self.addNewChildNode(level, n.top, at: bucket) + } + } +} diff --git a/Sources/HashTreeCollections/HashNode/_HashNode+Lookups.swift b/Sources/HashTreeCollections/HashNode/_HashNode+Lookups.swift index 88c3eca3b..707c54f4c 100644 --- a/Sources/HashTreeCollections/HashNode/_HashNode+Lookups.swift +++ b/Sources/HashTreeCollections/HashNode/_HashNode+Lookups.swift @@ -26,8 +26,9 @@ extension _HashNode.UnsafeHandle { _ level: _HashLevel, _ key: Key, _ hash: _Hash ) -> (descend: Bool, slot: _HashSlot)? { guard !isCollisionNode else { - let r = _findInCollision(level, key, hash) - guard r.code == 0 else { return nil } + guard hash == collisionHash else { return nil } + let r = _findInCollision(key) + guard r.found else { return nil } return (false, r.slot) } let bucket = hash[level] @@ -44,17 +45,12 @@ extension _HashNode.UnsafeHandle { } @inlinable @inline(never) - internal func _findInCollision( - _ level: _HashLevel, _ key: Key, _ hash: _Hash - ) -> (code: Int, slot: _HashSlot) { + internal func _findInCollision(_ key: Key) -> (found: Bool, slot: _HashSlot) { assert(isCollisionNode) - if !level.isAtBottom { - if hash != self.collisionHash { return (2, .zero) } - } // Note: this searches the items in reverse insertion order. guard let slot = reverseItems.firstIndex(where: { $0.key == key }) - else { return (1, self.itemsEndSlot) } - return (0, _HashSlot(itemCount &- 1 &- slot)) + else { return (false, self.itemsEndSlot) } + return (true, _HashSlot(itemCount &- 1 &- slot)) } } @@ -143,15 +139,15 @@ extension _HashNode.UnsafeHandle { _ level: _HashLevel, _ key: Key, _ hash: _Hash ) -> _FindResult { guard !isCollisionNode else { - let r = _findInCollision(level, key, hash) - if r.code == 0 { - return .found(.invalid, r.slot) + if hash != self.collisionHash { + assert(!level.isAtBottom) + return .expansion } - if r.code == 1 { + let r = _findInCollision(key) + guard r.found else { return .appendCollision } - assert(r.code == 2) - return .expansion + return .found(.invalid, r.slot) } let bucket = hash[level] if itemMap.contains(bucket) { diff --git a/Sources/HashTreeCollections/HashNode/_HashNode+Structural subtracting.swift b/Sources/HashTreeCollections/HashNode/_HashNode+Structural subtracting.swift index 43b4e953d..cc4812b5a 100644 --- a/Sources/HashTreeCollections/HashNode/_HashNode+Structural subtracting.swift +++ b/Sources/HashTreeCollections/HashNode/_HashNode+Structural subtracting.swift @@ -124,7 +124,7 @@ extension _HashNode { var removing = false let ritems = r.reverseItems - for lslot: _HashSlot in stride(from: .zero, to: l.itemsEndSlot, by: 1) { + for lslot: _HashSlot in .zero ..< l.itemsEndSlot { let lp = l.itemPtr(at: lslot) let include = !ritems.contains { $0.key == lp.pointee.key } if include, removing { diff --git a/Sources/HashTreeCollections/HashNode/_HashNode+Subtree Removals.swift b/Sources/HashTreeCollections/HashNode/_HashNode+Subtree Removals.swift index 1df9ee610..51703a05e 100644 --- a/Sources/HashTreeCollections/HashNode/_HashNode+Subtree Removals.swift +++ b/Sources/HashTreeCollections/HashNode/_HashNode+Subtree Removals.swift @@ -64,6 +64,34 @@ extension _HashNode { } } +extension _HashNode { + @inlinable + internal func removing( + _ level: _HashLevel, _ bucket: _Bucket + ) -> (removed: Builder, replacement: Builder) { + read { handle in + assert(!handle.isCollisionNode) + if handle.itemMap.contains(bucket) { + let slot = handle.itemMap.slot(of: bucket) + let p = handle.itemPtr(at: slot) + let hash = _Hash(p.pointee.key) + let r = self.removing(level, p.pointee.key, hash)! + return (.item(level, r.removed, at: bucket), r.replacement) + } else if handle.childMap.contains(bucket) { + let slot = handle.childMap.slot(of: bucket) + if hasSingletonChild { + return (.anyNode(level.descend(), handle[child: slot]), .empty(level)) + } + var remainder = self.copy() + let removed = remainder.removeChild(at: bucket, slot) + return (.anyNode(level.descend(), removed), .node(level, remainder)) + } else { + return (.empty(level), .node(level, self)) + } + } + } +} + extension _HashNode { @inlinable internal mutating func remove( diff --git a/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift b/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift new file mode 100644 index 000000000..e14d308e0 --- /dev/null +++ b/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift @@ -0,0 +1,763 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Collections open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +#if swift(>=5.7) +import _CollectionsUtilities + +extension _HashNode { + @inlinable + internal func combining( + _ level: _HashLevel, + _ other: Self, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Self { + assert(level.isAtRoot) + let builder = try Self._combining_node_node( + level, + left: self, + right: other, + by: strategy) + let root = builder.finalize(.top) + root._fullInvariantCheck() + return root + } + + @inlinable + internal static func _combining_node_node( + _ level: _HashLevel, + left: _HashNode, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + if left.raw.storage === right.raw.storage { + switch strategy.commonBehavior { + case .include: + return .node(level, left) + case .discard: + return .empty(level) + case .merge: + return try .mergedUniqueBranch(level, left) { item in + try strategy.merge(item.key, item.value, item.value) + } + } + } + + let lc = left.isCollisionNode + let rc = right.isCollisionNode + if lc { + if rc { + return try Self._combining_collision_collision( + level, left: left, right: right, by: strategy) + } + assert(!level.isAtBottom) // We must be on a compressed path + return try Self._combining_collision_tree( + level, left: left, right: right, by: strategy) + } + if rc { + return try Self._combining_tree_collision( + level, left: left, right: right, by: strategy) + } + + return try left.read { l in + try right.read { r in + var result: Builder = .empty(level) + + let lmap = l.itemMap.union(l.childMap) + let rmap = r.itemMap.union(r.childMap) + + var buckets = lmap.union(rmap) + while let bucket = buckets.popFirst() { + let branch: Builder + if l.itemMap.contains(bucket) { + let lslot = l.itemMap.slot(of: bucket) + if r.itemMap.contains(bucket) { + let rslot = r.itemMap.slot(of: bucket) + branch = try Self._combining_item_item( + level.descend(), + left: l.itemPtr(at: lslot), + right: r.itemPtr(at: rslot), + at: bucket, + by: strategy) + } + else if r.childMap.contains(bucket) { + let rslot = r.childMap.slot(of: bucket) + branch = try Self._combining_item_tree( + level.descend(), + left: l.itemPtr(at: lslot), + right: r[child: rslot], + by: strategy) + } + else { + branch = try Self._combining_item_nil( + level.descend(), + left: l.itemPtr(at: lslot), + at: bucket, + by: strategy) + } + } + else if l.childMap.contains(bucket) { + let lslot = l.childMap.slot(of: bucket) + if r.itemMap.contains(bucket) { + let rslot = r.itemMap.slot(of: bucket) + branch = try Self._combining_tree_item( + level.descend(), + left: l[child: lslot], + right: r.itemPtr(at: rslot), + by: strategy) + } + else if r.childMap.contains(bucket) { + let rslot = r.childMap.slot(of: bucket) + branch = try Self._combining_node_node( + level.descend(), + left: l[child: lslot], + right: r[child: rslot], + by: strategy) + } + else { + branch = try Self._combining_tree_nil( + level.descend(), + left: l[child: lslot], + by: strategy) + } + } + else if r.itemMap.contains(bucket) { + let rslot = r.itemMap.slot(of: bucket) + branch = try Self._combining_nil_item( + level.descend(), + right: r.itemPtr(at: rslot), + at: bucket, + by: strategy) + } + else { + assert(r.childMap.contains(bucket)) + let rslot = r.childMap.slot(of: bucket) + branch = try Self._combining_nil_tree( + level.descend(), + right: r[child: rslot], + by: strategy) + } + result.addNewChildBranch(level, branch, at: bucket) + } + return result + } + } + } + + @inlinable + internal static func _combining_nil_branch( + _ level: _HashLevel, + right: Builder, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + switch right.kind { + case .empty: + return .empty(level) + case .item(let item, at: let bucket): + guard let new = try strategy._processAdd(item) else { + return .empty(level) + } + return .item(level, new, at: bucket) + case .node(let node): + return try _combining_nil_tree(level, right: node, by: strategy) + case .collisionNode(let node): + return try _combining_nil_collision(level, right: node, by: strategy) + } + } + + @inlinable + internal static func _combining_branch_nil( + _ level: _HashLevel, + left: Builder, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + switch left.kind { + case .empty: + return .empty(level) + case .item(let item, at: let bucket): + guard let new = try strategy._processRemove(item) else { + return .empty(level) + } + return .item(level, new, at: bucket) + case .node(let node): + return try _combining_tree_nil(level, left: node, by: strategy) + case .collisionNode(let node): + return try _combining_collision_nil(level, left: node, by: strategy) + } + } + + @inlinable + internal static func _combining_item_item( + _ level: _HashLevel, + left: UnsafePointer, + right: UnsafePointer, + at bucket: _Bucket, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + if left.pointee.key == right.pointee.key { + guard let item = try strategy._processCommon(left, right) else { + return .empty(level) + } + return .item(level, item, at: bucket) + } else { + let item1 = try strategy._processRemove(left.pointee) + let item2 = try strategy._processAdd(right.pointee) + return .conflictingItems(level, item1, item2, at: bucket) + } + } + + @inlinable + internal static func _combining_item_nil( + _ level: _HashLevel, + left: UnsafePointer, + at bucket: _Bucket, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + guard let item = try strategy._processRemove(left.pointee) else { + return .empty(level) + } + return .item(level, item, at: bucket) + } + + @inlinable + internal static func _combining_nil_item( + _ level: _HashLevel, + right: UnsafePointer, + at bucket: _Bucket, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + guard let item = try strategy._processAdd(right.pointee) else { + return .empty(level) + } + return .item(level, item, at: bucket) + } + + @inlinable + internal static func _combining_tree_nil( + _ level: _HashLevel, + left: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + switch strategy.removeBehavior { + case .include: + return .node(level, left) + case .discard: + return .empty(level) + case .merge: + return try Builder.mergedUniqueBranch(level, left) { item in + try strategy.merge(item.key, item.value, nil) + } + } + } + + @inlinable + internal static func _combining_nil_tree( + _ level: _HashLevel, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + switch strategy.addBehavior { + case .include: + return .node(level, right) + case .discard: + return .empty(level) + case .merge: + return try Builder.mergedUniqueBranch(level, right) { item in + try strategy.merge(item.key, nil, item.value) + } + } + } + + @inlinable + internal static func _combining_item_tree( + _ level: _HashLevel, + left: UnsafePointer, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + let hash = _Hash(left.pointee.key) + switch strategy.addBehavior { + case .include: + if var t = right.removing(level, left.pointee.key, hash) { + guard let item = try strategy._processCommon(left, &t.removed) else { + return t.replacement + } + var rnode = t.replacement.finalize(level) + let t2 = rnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, rnode) + } + guard let item = try strategy._processRemove(left.pointee) else { + return .node(level, right) + } + let t2 = right.inserting(level, item, hash) + assert(t2.inserted) + return .node(level, t2.node) + case .discard: + if let t = right.lookup(level, left.pointee.key, hash) { + let item = try UnsafeHandle.read(t.node) { rn in + try strategy._processCommon(left, rn.itemPtr(at: t.slot)) + } + guard let item else { return .empty(level) } + return .item(level, item, at: .invalid) + } + guard let item = try strategy._processRemove(left.pointee) else { + return .empty(level) + } + return .item(level, item, at: .invalid) + case .merge: + if var t = right.removing(level, left.pointee.key, hash) { + var rnode = t.replacement.finalize(level) + let b = try Builder.mergedUniqueBranch(level, rnode) { + try strategy.merge($0.key, nil, $0.value) + } + guard let item = try strategy._processCommon(left, &t.removed) else { + return b + } + rnode = b.finalize(level) + let t2 = rnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, rnode) + } + let b = try Builder.mergedUniqueBranch(level, right) { + try strategy.merge($0.key, nil, $0.value) + } + guard let item = try strategy._processRemove(left.pointee) else { + return b + } + var rnode = b.finalize(level) + let t2 = rnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, rnode) + } + } + + @inlinable + internal static func _combining_tree_item( + _ level: _HashLevel, + left: _HashNode, + right: UnsafePointer, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + let hash = _Hash(right.pointee.key) + switch strategy.removeBehavior { + case .include: + if var t = left.removing(level, right.pointee.key, hash) { + guard let item = try strategy._processCommon(&t.removed, right) else { + return t.replacement + } + var lnode = t.replacement.finalize(level) + let t2 = lnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, lnode) + } + guard let item = try strategy._processAdd(right.pointee) else { + return .node(level, left) + } + let t2 = left.inserting(level, item, hash) + assert(t2.inserted) + return .node(level, t2.node) + case .discard: + if let t = left.lookup(level, right.pointee.key, hash) { + let item = try UnsafeHandle.read(t.node) { ln in + try strategy._processCommon(ln.itemPtr(at: t.slot), right) + } + guard let item else { return .empty(level) } + return .item(level, item, at: .invalid) + } + guard let item = try strategy._processAdd(right.pointee) else { + return .empty(level) + } + return .item(level, item, at: .invalid) + case .merge: + if var t = left.removing(level, right.pointee.key, hash) { + var lnode = t.replacement.finalize(level) + let b = try Builder.mergedUniqueBranch(level, lnode) { + try strategy.merge($0.key, $0.value, nil) + } + guard let item = try strategy._processCommon(&t.removed, right) else { + return b + } + lnode = b.finalize(level) + let t2 = lnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, lnode) + } + let b = try Builder.mergedUniqueBranch(level, left) { + try strategy.merge($0.key, $0.value, nil) + } + guard let item = try strategy._processAdd(right.pointee) else { + return b + } + var lnode = b.finalize(level) + let t2 = lnode.insert(level, item, hash) + assert(t2.inserted) + return .node(level, lnode) + } + } + + @inlinable + internal static func _combining_collision_collision( + _ level: _HashLevel, + left: _HashNode, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(left.isCollisionNode && right.isCollisionNode) + let lhash = left.collisionHash + let rhash = right.collisionHash + guard lhash == rhash else { + let ln = try Self._combining_collision_nil( + level, left: left, by: strategy) + let rn = try Self._combining_nil_collision( + level, right: right, by: strategy) + return Builder(level, collisions1: ln, lhash, collisions2: rn, rhash) + } + + return try left.read { l in + try right.read { r in + var result: Builder = .empty(level) + + let ritems = r.reverseItems + try _UnsafeBitSet.withTemporaryBitSet(capacity: ritems.count) { bitset in + bitset.insertAll(upTo: ritems.count) + let hash = l.collisionHash + for lslot: _HashSlot in .zero ..< l.itemsEndSlot { + let lp = l.itemPtr(at: lslot) + let match = r._findInCollision(lp.pointee.key) + if match.found { + bitset.remove(match.slot.value) + let rp = r.itemPtr(at: match.slot) + if let new = try strategy._processCommon(lp, rp) { + result.addNewCollision(level, new, hash) + } + } else { + if let new = try strategy._processRemove(lp.pointee) { + result.addNewCollision(level, new, hash) + } + } + } + for offset in bitset { + let rslot = _HashSlot(offset) + let rp = r.itemPtr(at: rslot) + if let new = try strategy._processAdd(rp.pointee) { + result.addNewCollision(level, new, hash) + } + } + } + return result + } + } + } + + @inlinable + internal static func _combining_collision_branch( + _ level: _HashLevel, + left: _HashNode, + right: Builder, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(left.isCollisionNode) + switch right.kind { + case .empty: + return try _combining_collision_nil( + level, + left: left, + by: strategy) + case .item(let item, at: _): + return try _combining_collision_item( + level, + left: left, + right: item, + by: strategy) + case .node(let node): + return try _combining_collision_tree( + level, + left: left, + right: node, + by: strategy) + case .collisionNode(let node): + return try _combining_collision_collision( + level, + left: left, + right: node, + by: strategy) + } + } + + @inlinable + internal static func _combining_branch_collision( + _ level: _HashLevel, + left: Builder, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(right.isCollisionNode) + switch left.kind { + case .empty: + return try _combining_nil_collision( + level, + right: right, + by: strategy) + case .item(let item, at: _): + return try _combining_item_collision( + level, + left: item, + right: right, + by: strategy) + case .node(let node): + return try _combining_tree_collision( + level, + left: node, + right: right, + by: strategy) + case .collisionNode(let node): + return try _combining_collision_collision( + level, + left: node, + right: right, + by: strategy) + } + } + + @inlinable + internal static func _combining_collision_nil( + _ level: _HashLevel, + left: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(left.isCollisionNode) + switch strategy.removeBehavior { + case .include: + return .node(level, left) + case .discard: + return .empty(level) + case .merge: + return try .mergedUniqueBranch(level, left) { item in + try strategy.merge(item.key, item.value, nil) + } + } + } + + @inlinable + internal static func _combining_nil_collision( + _ level: _HashLevel, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(right.isCollisionNode) + switch strategy.addBehavior { + case .include: + return .node(level, right) + case .discard: + return .empty(level) + case .merge: + var result: Builder = .empty(level) + try right.read { r in + let hash = r.collisionHash + for rslot: _HashSlot in .zero ..< r.itemsEndSlot { + let rp = r.itemPtr(at: rslot) + if let v = try strategy.merge(rp.pointee.key, nil, rp.pointee.value) { + result.addNewCollision(level, (rp.pointee.key, v), hash) + } + } + } + return result + } + } + + @inlinable + internal static func _combining_collision_tree( + _ level: _HashLevel, + left: _HashNode, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(left.isCollisionNode) + assert(!right.isCollisionNode) + return try left.read { l in + let bucket = l.collisionHash[level] + + let (removed, remainder) = right.removing(level, bucket) + + var result = try _combining_nil_branch( + level, right: remainder, by: strategy) + + let branch = try _combining_collision_branch( + level.descend(), + left: left, + right: removed, + by: strategy) + result.addNewChildBranch(level, branch, at: bucket) + return result + } + } + + @inlinable + internal static func _combining_tree_collision( + _ level: _HashLevel, + left: _HashNode, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(!left.isCollisionNode) + assert(right.isCollisionNode) + return try right.read { r in + let bucket = r.collisionHash[level] + + let (removed, remainder) = left.removing(level, bucket) + + var result = try _combining_branch_nil( + level, left: remainder, by: strategy) + + let branch = try _combining_branch_collision( + level.descend(), + left: removed, + right: right, + by: strategy) + result.addNewChildBranch(level, branch, at: bucket) + return result + } + } + + @inlinable + internal static func _combining_collision_item( + _ level: _HashLevel, + left: _HashNode, + right: Element, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(left.isCollisionNode) + let hash = _Hash(right.key) + if left.collisionHash == hash { + if let r = left.removing(level, right.key, hash) { + var result = try _combining_branch_nil( + level, left: r.replacement, by: strategy) + if let removed = try strategy._processCommon(r.removed, right) { + result.addNewCollision(level, removed, hash) + } + return result + } + var result = try _combining_collision_nil( + level, left: left, by: strategy) + if let new = try strategy._processAdd(right) { + result.addNewCollision(level, new, hash) + } + return result + } + + let branch = try _combining_collision_nil(level, left: left, by: strategy) + guard let item = try strategy._processAdd(right) else { + return branch + } + var result: Builder = .empty(level) + result.addNewItem(level, item, at: hash[level]) + result.addNewChildBranch(level, branch, at: left.collisionHash[level]) + return result + } + + @inlinable + internal static func _combining_item_collision( + _ level: _HashLevel, + left: Element, + right: _HashNode, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Builder { + assert(right.isCollisionNode) + let hash = _Hash(left.key) + if right.collisionHash == hash { + if let r = right.removing(level, left.key, hash) { + var result = try _combining_nil_branch( + level, right: r.replacement, by: strategy) + if let removed = try strategy._processCommon(left, r.removed) { + result.addNewCollision(level, removed, hash) + } + return result + } + var result = try _combining_nil_collision( + level, right: right, by: strategy) + if let new = try strategy._processRemove(left) { + result.addNewCollision(level, new, hash) + } + return result + } + + let branch = try _combining_nil_collision(level, right: right, by: strategy) + guard let item = try strategy._processRemove(left) else { + return branch + } + var result: Builder = .empty(level) + result.addNewItem(level, item, at: hash[level]) + result.addNewChildBranch(level, branch, at: right.collisionHash[level]) + return result + } + + +} + +extension TreeDictionaryCombiningStrategy { + @inlinable + internal func _processCommon( + _ p1: UnsafePointer, + _ p2: UnsafePointer + ) throws -> Element? { + try _processCommon(p1.pointee, p2.pointee) + } + + @inlinable + internal func _processCommon( + _ item1: Element, + _ item2: Element + ) throws -> Element? { + assert(item1.key == item2.key) + let b = commonBehavior + if + b == .merge + || !areEquivalentValues(item1.value, item2.value) + { + let v = try merge(item1.key, item1.value, item2.value) + guard let v = v else { return nil } + return (item1.key, v) + } + if b == .include { + return item1 + } + return nil + } + + @inlinable + internal func _processRemove(_ item: Element) throws -> Element? { + switch addBehavior { + case .include: + return item + case .discard: + return nil + case .merge: + let v = try merge(item.key, item.value, nil) + guard let v = v else { return nil } + return (item.key, v) + } + } + + @inlinable + internal func _processAdd(_ item: Element) throws -> Element? { + switch removeBehavior { + case .include: + return item + case .discard: + return nil + case .merge: + let v = try merge(item.key, nil, item.value) + guard let v = v else { return nil } + return (item.key, v) + } + } +} +#endif diff --git a/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift b/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift new file mode 100644 index 000000000..1a29e94d5 --- /dev/null +++ b/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift @@ -0,0 +1,61 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Collections open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +#if swift(>=5.7) +@frozen +public enum CombiningBehavior { + case include + case discard + case merge +} + +public protocol TreeDictionaryCombiningStrategy { + associatedtype Key: Hashable + associatedtype Value + + var commonBehavior: CombiningBehavior { get } + var addBehavior: CombiningBehavior { get } + var removeBehavior: CombiningBehavior { get } + + func areEquivalentValues(_ a: Value, _ b: Value) -> Bool + func merge(_ key: Key, _ value1: Value?, _ value2: Value?) throws -> Value? +} + +extension TreeDictionaryCombiningStrategy { + public typealias Element = (key: Key, value: Value) +} + +extension TreeDictionaryCombiningStrategy where Value: Equatable { + @inlinable @inline(__always) + public func areEquivalentValues(_ a: Value, _ b: Value) -> Bool { + a == b + } +} + +extension TreeDictionary { + @inlinable + public func combining( + _ other: Self, + by strategy: some TreeDictionaryCombiningStrategy + ) throws -> Self { + let root = try _root.combining(.top, other._root, by: strategy) + return Self(_new: root) + } + + @inlinable + mutating func combine( + _ other: Self, + by strategy: some TreeDictionaryCombiningStrategy + ) throws { + self = try combining(other, by: strategy) + } +} +#endif From f8252436e5e4384eaa1a68560db380d83b7c71eb Mon Sep 17 00:00:00 2001 From: Karoy Lorentey Date: Tue, 29 Nov 2022 20:11:57 -0800 Subject: [PATCH 2/2] [TreeDictionary] Adjust combining API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Spin off two separate combining behaviors for values that are present in both dictionaries, based on whether or not the values are equal. Before this change, unequal values were always merged, which isn’t necessarily what we want. --- .../HashNode/_Node+Structural combining.swift | 43 ++++++++------ .../TreeDictionary+Combine.swift | 8 ++- .../TreeDictionary Smoke Tests.swift | 56 +++++++++++++++++++ .../xcschemes/HashTreeCollections.xcscheme | 14 +++++ 4 files changed, 100 insertions(+), 21 deletions(-) diff --git a/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift b/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift index e14d308e0..725905a5c 100644 --- a/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift +++ b/Sources/HashTreeCollections/HashNode/_Node+Structural combining.swift @@ -38,7 +38,7 @@ extension _HashNode { by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { if left.raw.storage === right.raw.storage { - switch strategy.commonBehavior { + switch strategy.equalValuesInBoth { case .include: return .node(level, left) case .discard: @@ -245,7 +245,7 @@ extension _HashNode { left: _HashNode, by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { - switch strategy.removeBehavior { + switch strategy.valuesOnlyInFirst { case .include: return .node(level, left) case .discard: @@ -263,7 +263,7 @@ extension _HashNode { right: _HashNode, by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { - switch strategy.addBehavior { + switch strategy.valuesOnlyInSecond { case .include: return .node(level, right) case .discard: @@ -283,7 +283,7 @@ extension _HashNode { by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { let hash = _Hash(left.pointee.key) - switch strategy.addBehavior { + switch strategy.valuesOnlyInSecond { case .include: if var t = right.removing(level, left.pointee.key, hash) { guard let item = try strategy._processCommon(left, &t.removed) else { @@ -347,7 +347,7 @@ extension _HashNode { by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { let hash = _Hash(right.pointee.key) - switch strategy.removeBehavior { + switch strategy.valuesOnlyInFirst { case .include: if var t = left.removing(level, right.pointee.key, hash) { guard let item = try strategy._processCommon(&t.removed, right) else { @@ -534,7 +534,7 @@ extension _HashNode { by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { assert(left.isCollisionNode) - switch strategy.removeBehavior { + switch strategy.valuesOnlyInFirst { case .include: return .node(level, left) case .discard: @@ -553,7 +553,7 @@ extension _HashNode { by strategy: some TreeDictionaryCombiningStrategy ) throws -> Builder { assert(right.isCollisionNode) - switch strategy.addBehavior { + switch strategy.valuesOnlyInSecond { case .include: return .node(level, right) case .discard: @@ -717,24 +717,31 @@ extension TreeDictionaryCombiningStrategy { _ item2: Element ) throws -> Element? { assert(item1.key == item2.key) - let b = commonBehavior - if - b == .merge - || !areEquivalentValues(item1.value, item2.value) - { + let equals = self.equalValuesInBoth + let unequals = self.unequalValuesInBoth + + let result: CombiningBehavior + + if equals == unequals || areEquivalentValues(item1.value, item2.value) { + result = equals + } else { + result = unequals + } + switch result { + case .include: + return item1 + case .discard: + return nil + case .merge: let v = try merge(item1.key, item1.value, item2.value) guard let v = v else { return nil } return (item1.key, v) } - if b == .include { - return item1 - } - return nil } @inlinable internal func _processRemove(_ item: Element) throws -> Element? { - switch addBehavior { + switch valuesOnlyInFirst { case .include: return item case .discard: @@ -748,7 +755,7 @@ extension TreeDictionaryCombiningStrategy { @inlinable internal func _processAdd(_ item: Element) throws -> Element? { - switch removeBehavior { + switch valuesOnlyInSecond { case .include: return item case .discard: diff --git a/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift b/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift index 1a29e94d5..915d77121 100644 --- a/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift +++ b/Sources/HashTreeCollections/TreeDictionary/TreeDictionary+Combine.swift @@ -21,11 +21,13 @@ public protocol TreeDictionaryCombiningStrategy { associatedtype Key: Hashable associatedtype Value - var commonBehavior: CombiningBehavior { get } - var addBehavior: CombiningBehavior { get } - var removeBehavior: CombiningBehavior { get } + var valuesOnlyInFirst: CombiningBehavior { get } + var valuesOnlyInSecond: CombiningBehavior { get } + var equalValuesInBoth: CombiningBehavior { get } + var unequalValuesInBoth: CombiningBehavior { get } func areEquivalentValues(_ a: Value, _ b: Value) -> Bool + func merge(_ key: Key, _ value1: Value?, _ value2: Value?) throws -> Value? } diff --git a/Tests/HashTreeCollectionsTests/TreeDictionary Smoke Tests.swift b/Tests/HashTreeCollectionsTests/TreeDictionary Smoke Tests.swift index f9981d815..cd3c390c2 100644 --- a/Tests/HashTreeCollectionsTests/TreeDictionary Smoke Tests.swift +++ b/Tests/HashTreeCollectionsTests/TreeDictionary Smoke Tests.swift @@ -625,4 +625,60 @@ final class TreeDictionarySmokeTests: CollectionTestCase { expectTrue(expectedPositions.isEmpty) } + + func test_combine() { + var d1 = TreeDictionary(uniqueKeysWithValues: (0 ..< 10000).map { ($0, "1") }) + d1[1] = "1" + var d2 = d1 +// for i in 10 ..< 20 { +// d1[i] = nil +// } +// for i in 20 ..< 30 { +// d2[i] = nil +// } + for i in 40 ..< 50 { + d2[i] = "2" + } + + class TestStrategy: TreeDictionaryCombiningStrategy { + typealias Key = Int + typealias Value = String + + var _equalCounter = 0 + var _mergeCounter = 0 + + var valuesOnlyInFirst: CombiningBehavior { .merge } + var valuesOnlyInSecond: CombiningBehavior { .merge } + var equalValuesInBoth: CombiningBehavior { .discard } + var unequalValuesInBoth: CombiningBehavior { .merge } + + func areEquivalentValues(_ a: Value, _ b: Value) -> Bool { + _equalCounter += 1 + return a == b + } + + func merge( + _ key: Key, _ value1: Value?, _ value2: Value? + ) throws -> Value? { + _mergeCounter += 1 + + let s1 = value1 ?? "nil" + let s2 = value2 ?? "nil" + print("key: \(key), value1: \(s1), value2: \(s2)") + + switch (value1, value2) { + case (nil, nil): return "00" + case (_?, nil): return "10" + case (nil, _?): return "01" + case (_?, _?): return "11" + } + } + } + + let strategy = TestStrategy() + let d = try! d1.combining(d2, by: strategy) + print(d.map { ($0.key, $0.value) }.sorted(by: { $0.0 < $1.0 })) + print("Merge count: \(strategy._mergeCounter)") + print("isEqual count: \(strategy._equalCounter)") + } } diff --git a/Utils/swift-collections.xcworkspace/xcshareddata/xcschemes/HashTreeCollections.xcscheme b/Utils/swift-collections.xcworkspace/xcshareddata/xcschemes/HashTreeCollections.xcscheme index 35631f44b..461917f4a 100644 --- a/Utils/swift-collections.xcworkspace/xcshareddata/xcschemes/HashTreeCollections.xcscheme +++ b/Utils/swift-collections.xcworkspace/xcshareddata/xcschemes/HashTreeCollections.xcscheme @@ -20,6 +20,20 @@ ReferencedContainer = "container:.."> + + + +