Skip to content

Commit 4c8099e

Browse files
authored
feat: merkle tree verifier implementation to support all numbers of leaves (#253)
## Overview Closes #249 The implementation is taken from: https://github.com/celestiaorg/celestia-core/blob/0498541b8db00c7fefa918d906877ef2ee0a3710/crypto/merkle/proof.go#L166-L197 ## Checklist <!-- Please complete the checklist to ensure that the PR is ready to be reviewed. IMPORTANT: PRs should be left in Draft until the below checklist is completed. --> - [ ] New and updated code has appropriate documentation - [ ] New and updated code has new and/or updated testing - [ ] Required CI checks are passing - [ ] Visual proof for any user facing features like CLI or documentation updates - [ ] Linked issues closed with keywords <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit The existing bullet-point list is still valid based on the provided information. No changes are required. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5b34595 commit 4c8099e

File tree

5 files changed

+248
-85
lines changed

5 files changed

+248
-85
lines changed

src/lib/tree/Utils.sol

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,34 @@ function pathLengthFromKey(uint256 key, uint256 numLeaves) pure returns (uint256
4242
return 1 + pathLengthFromKey(key - numLeavesLeftSubTree, numLeaves - numLeavesLeftSubTree);
4343
}
4444
}
45+
46+
/// @notice Returns the minimum number of bits required to represent `x`; the
47+
/// result is 0 for `x` == 0.
48+
/// @param x Number.
49+
function _bitsLen(uint256 x) pure returns (uint256) {
50+
uint256 count = 0;
51+
52+
while (x != 0) {
53+
count++;
54+
x >>= 1;
55+
}
56+
57+
return count;
58+
}
59+
60+
/// @notice Returns the largest power of 2 less than `x`.
61+
/// @param x Number.
62+
function _getSplitPoint(uint256 x) pure returns (uint256) {
63+
// Note: since `x` is always an unsigned int * 2, the only way for this
64+
// to be violated is if the input == 0. Since the input is the end
65+
// index exclusive, an input of 0 is guaranteed to be invalid (it would
66+
// be a proof of inclusion of nothing, which is vacuous).
67+
require(x >= 1);
68+
69+
uint256 bitLen = _bitsLen(x);
70+
uint256 k = 1 << (bitLen - 1);
71+
if (k == x) {
72+
k >>= 1;
73+
}
74+
return k;
75+
}

src/lib/tree/binary/BinaryMerkleTree.sol

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -42,65 +42,59 @@ library BinaryMerkleTree {
4242
}
4343
}
4444

45-
uint256 height = 1;
46-
uint256 stableEnd = proof.key;
45+
bytes32 computedHash = computeRootHash(proof.key, proof.numLeaves, digest, proof.sideNodes);
4746

48-
// While the current subtree (of height 'height') is complete, determine
49-
// the position of the next sibling using the complete subtree algorithm.
50-
// 'stableEnd' tells us the ending index of the last full subtree. It gets
51-
// initialized to 'key' because the first full subtree was the
52-
// subtree of height 1, created above (and had an ending index of
53-
// 'key').
54-
55-
while (true) {
56-
// Determine if the subtree is complete. This is accomplished by
57-
// rounding down the key to the nearest 1 << 'height', adding 1
58-
// << 'height', and comparing the result to the number of leaves in the
59-
// Merkle tree.
60-
61-
uint256 subTreeStartIndex = (proof.key / (1 << height)) * (1 << height);
62-
uint256 subTreeEndIndex = subTreeStartIndex + (1 << height) - 1;
63-
64-
// If the Merkle tree does not have a leaf at index
65-
// 'subTreeEndIndex', then the subtree of the current height is not
66-
// a complete subtree.
67-
if (subTreeEndIndex >= proof.numLeaves) {
68-
break;
69-
}
70-
stableEnd = subTreeEndIndex;
71-
72-
// Determine if the key is in the first or the second half of
73-
// the subtree.
74-
if (proof.sideNodes.length <= height - 1) {
75-
return false;
76-
}
77-
if (proof.key - subTreeStartIndex < (1 << (height - 1))) {
78-
digest = nodeDigest(digest, proof.sideNodes[height - 1]);
79-
} else {
80-
digest = nodeDigest(proof.sideNodes[height - 1], digest);
81-
}
47+
return (computedHash == root);
48+
}
8249

83-
height += 1;
50+
/// @notice Use the leafHash and innerHashes to get the root merkle hash.
51+
/// If the length of the innerHashes slice isn't exactly correct, the result is nil.
52+
/// Recursive impl.
53+
function computeRootHash(uint256 key, uint256 numLeaves, bytes32 leafHash, bytes32[] memory sideNodes)
54+
private
55+
pure
56+
returns (bytes32)
57+
{
58+
if (numLeaves == 0) {
59+
revert("cannot call computeRootHash with 0 number of leaves");
8460
}
85-
86-
// Determine if the next hash belongs to an orphan that was elevated. This
87-
// is the case IFF 'stableEnd' (the last index of the largest full subtree)
88-
// is equal to the number of leaves in the Merkle tree.
89-
if (stableEnd != proof.numLeaves - 1) {
90-
if (proof.sideNodes.length <= height - 1) {
91-
return false;
61+
if (numLeaves == 1) {
62+
if (sideNodes.length != 0) {
63+
revert("unexpected inner hashes");
9264
}
93-
digest = nodeDigest(digest, proof.sideNodes[height - 1]);
94-
height += 1;
65+
return leafHash;
9566
}
96-
97-
// All remaining elements in the proof set will belong to a left sibling\
98-
// i.e proof sideNodes are hashed in "from the left"
99-
while (height - 1 < proof.sideNodes.length) {
100-
digest = nodeDigest(proof.sideNodes[height - 1], digest);
101-
height += 1;
67+
if (sideNodes.length == 0) {
68+
revert("expected at least one inner hash");
69+
}
70+
uint256 numLeft = _getSplitPoint(numLeaves);
71+
bytes32[] memory sideNodesLeft = slice(sideNodes, 0, sideNodes.length - 1);
72+
if (key < numLeft) {
73+
bytes32 leftHash = computeRootHash(key, numLeft, leafHash, sideNodesLeft);
74+
return nodeDigest(leftHash, sideNodes[sideNodes.length - 1]);
10275
}
76+
bytes32 rightHash = computeRootHash(key - numLeft, numLeaves - numLeft, leafHash, sideNodesLeft);
77+
return nodeDigest(sideNodes[sideNodes.length - 1], rightHash);
78+
}
10379

104-
return (digest == root);
80+
/// @notice creates a slice of bytes32 from the data slice of bytes32 containing the elements
81+
/// that correspond to the provided range.
82+
/// It selects a half-open range which includes the begin element, but excludes the end one.
83+
/// @param _data The slice that we want to select data from.
84+
/// @param _begin The beginning of the range (inclusive).
85+
/// @param _end The ending of the range (exclusive).
86+
/// @return _ the sliced data.
87+
function slice(bytes32[] memory _data, uint256 _begin, uint256 _end) internal pure returns (bytes32[] memory) {
88+
if (_begin > _end) {
89+
revert("Invalid range: _begin is greater than _end");
90+
}
91+
if (_begin > _data.length || _end > _data.length) {
92+
revert("Invalid range: _begin or _end are out of bounds");
93+
}
94+
bytes32[] memory out = new bytes32[](_end-_begin);
95+
for (uint256 i = _begin; i < _end; i++) {
96+
out[i - _begin] = _data[i];
97+
}
98+
return out;
10599
}
106100
}

src/lib/tree/binary/test/BinaryMerkleTree.t.sol

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
pragma solidity ^0.8.22;
33

44
import "ds-test/test.sol";
5+
import "forge-std/Vm.sol";
56

67
import "../BinaryMerkleProof.sol";
78
import "../BinaryMerkleTree.sol";
@@ -40,6 +41,8 @@ import "../BinaryMerkleTree.sol";
4041
*/
4142

4243
contract BinaryMerkleProofTest is DSTest {
44+
Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));
45+
4346
function setUp() external {}
4447

4548
function testVerifyNone() external {
@@ -101,6 +104,36 @@ contract BinaryMerkleProofTest is DSTest {
101104
assertTrue(isValid);
102105
}
103106

107+
function testVerifyLeafTwoOfEight() external {
108+
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
109+
bytes32[] memory sideNodes = new bytes32[](3);
110+
sideNodes[0] = 0xb413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2;
111+
sideNodes[1] = 0x78850a5ab36238b076dd99fd258c70d523168704247988a94caa8c9ccd056b8d;
112+
sideNodes[2] = 0x4301a067262bbb18b4919742326f6f6d706099f9c0e8b0f2db7b88f204b2cf09;
113+
114+
uint256 key = 1;
115+
uint256 numLeaves = 8;
116+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
117+
bytes memory data = hex"02";
118+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
119+
assertTrue(isValid);
120+
}
121+
122+
function testVerifyLeafThreeOfEight() external {
123+
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
124+
bytes32[] memory sideNodes = new bytes32[](3);
125+
sideNodes[0] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
126+
sideNodes[1] = 0x6bcf0e2e93e0a18e22789aee965e6553f4fbe93f0acfc4a705d691c8311c4965;
127+
sideNodes[2] = 0x4301a067262bbb18b4919742326f6f6d706099f9c0e8b0f2db7b88f204b2cf09;
128+
129+
uint256 key = 2;
130+
uint256 numLeaves = 8;
131+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
132+
bytes memory data = hex"03";
133+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
134+
assertTrue(isValid);
135+
}
136+
104137
function testVerifyLeafSevenOfEight() external {
105138
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
106139
bytes32[] memory sideNodes = new bytes32[](3);
@@ -130,4 +163,140 @@ contract BinaryMerkleProofTest is DSTest {
130163
bool isValid = BinaryMerkleTree.verify(root, proof, data);
131164
assertTrue(isValid);
132165
}
166+
167+
// Test vectors:
168+
// 0x00
169+
// 0x01
170+
// 0x02
171+
// 0x03
172+
// 0x04
173+
function testVerifyProofOfFiveLeaves() external {
174+
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
175+
bytes32[] memory sideNodes = new bytes32[](3);
176+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
177+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
178+
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
179+
180+
uint256 key = 1;
181+
uint256 numLeaves = 5;
182+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
183+
bytes memory data = bytes(hex"01");
184+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
185+
assertTrue(isValid);
186+
}
187+
188+
function testVerifyInvalidProofRoot() external {
189+
// correct root: 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
190+
bytes32 root = 0xc855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
191+
bytes32[] memory sideNodes = new bytes32[](3);
192+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
193+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
194+
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
195+
196+
uint256 key = 1;
197+
uint256 numLeaves = 5;
198+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
199+
bytes memory data = bytes(hex"01");
200+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
201+
assertTrue(!isValid);
202+
}
203+
204+
function testVerifyInvalidProofKey() external {
205+
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
206+
bytes32[] memory sideNodes = new bytes32[](3);
207+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
208+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
209+
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
210+
211+
// correct key: 1
212+
uint256 key = 2;
213+
uint256 numLeaves = 5;
214+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
215+
bytes memory data = bytes(hex"01");
216+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
217+
assertTrue(!isValid);
218+
}
219+
220+
function testVerifyInvalidProofNumberOfLeaves() external {
221+
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
222+
bytes32[] memory sideNodes = new bytes32[](3);
223+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
224+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
225+
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
226+
227+
uint256 key = 1;
228+
// correct numLeaves: 5
229+
uint256 numLeaves = 200;
230+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
231+
bytes memory data = bytes(hex"01");
232+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
233+
assertTrue(!isValid);
234+
}
235+
236+
function testVerifyInvalidProofSideNodes() external {
237+
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
238+
bytes32[] memory sideNodes = new bytes32[](3);
239+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
240+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
241+
// correct side node: 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
242+
sideNodes[2] = 0x5f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
243+
244+
uint256 key = 1;
245+
uint256 numLeaves = 5;
246+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
247+
bytes memory data = bytes(hex"01");
248+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
249+
assertTrue(!isValid);
250+
}
251+
252+
function testVerifyInvalidProofData() external {
253+
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
254+
bytes32[] memory sideNodes = new bytes32[](3);
255+
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
256+
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
257+
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
258+
259+
uint256 key = 1;
260+
uint256 numLeaves = 5;
261+
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
262+
// correct data: 01
263+
bytes memory data = bytes(hex"012345");
264+
bool isValid = BinaryMerkleTree.verify(root, proof, data);
265+
assertTrue(!isValid);
266+
}
267+
268+
function testValidSlice() public {
269+
bytes32[] memory data = new bytes32[](4);
270+
data[0] = "a";
271+
data[1] = "b";
272+
data[2] = "c";
273+
data[3] = "d";
274+
275+
bytes32[] memory result = BinaryMerkleTree.slice(data, 1, 3);
276+
277+
assertEq(result[0], data[1]);
278+
assertEq(result[1], data[2]);
279+
}
280+
281+
function testInvalidSliceBeginEnd() public {
282+
bytes32[] memory data = new bytes32[](4);
283+
data[0] = "a";
284+
data[1] = "b";
285+
data[2] = "c";
286+
data[3] = "d";
287+
288+
vm.expectRevert("Invalid range: _begin is greater than _end");
289+
BinaryMerkleTree.slice(data, 2, 1);
290+
}
291+
292+
function testOutOfBoundsSlice() public {
293+
bytes32[] memory data = new bytes32[](4);
294+
data[0] = "a";
295+
data[1] = "b";
296+
data[2] = "c";
297+
data[3] = "d";
298+
299+
vm.expectRevert("Invalid range: _begin or _end are out of bounds");
300+
BinaryMerkleTree.slice(data, 2, 5);
301+
}
133302
}

src/lib/tree/namespace/NamespaceMerkleTree.sol

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -218,37 +218,6 @@ library NamespaceMerkleTree {
218218
return count;
219219
}
220220

221-
/// @notice Returns the minimum number of bits required to represent `x`; the
222-
/// result is 0 for `x` == 0.
223-
/// @param x Number.
224-
function _bitsLen(uint256 x) private pure returns (uint256) {
225-
uint256 count = 0;
226-
227-
while (x != 0) {
228-
count++;
229-
x >>= 1;
230-
}
231-
232-
return count;
233-
}
234-
235-
/// @notice Returns the largest power of 2 less than `x`.
236-
/// @param x Number.
237-
function _getSplitPoint(uint256 x) private pure returns (uint256) {
238-
// Note: since `x` is always an unsigned int * 2, the only way for this
239-
// to be violated is if the input == 0. Since the input is the end
240-
// index exclusive, an input of 0 is guaranteed to be invalid (it would
241-
// be a proof of inclusion of nothing, which is vacuous).
242-
require(x >= 1);
243-
244-
uint256 bitLen = _bitsLen(x);
245-
uint256 k = 1 << (bitLen - 1);
246-
if (k == x) {
247-
k >>= 1;
248-
}
249-
return k;
250-
}
251-
252221
/// @notice Computes the NMT root recursively.
253222
/// @param proof Namespace Merkle multiproof for the leaves.
254223
/// @param leafNodes Leaf nodes for which inclusion is proven.

0 commit comments

Comments
 (0)