[openzeppelin]:v4.8.3,[forge-std]:v1.5.6
Github: https://github.com/OpenZeppelin/openzeppelin-contracts/blob/v4.8.3/contracts/utils/structs/EnumerableSet.sol
EnumerableSet库提供了Bytes32Set、AddressSet和UintSet三种类型的set,分别用于bytes32、address和uint256类型的元素。 每种set都提供了对应的增添元素、删除元素、检查目标元素是否处于set中、查询当前set中元素个数等操作。几乎所有操作的时间复杂度均为O(1)。
封装EnumerableSet library成为一个可调用合约:
Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/src/utils/structs/MockEnumerableSet.sol
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.0;
import "openzeppelin-contracts/contracts/utils/structs/EnumerableSet.sol";
contract MockBytes32Set {
using EnumerableSet for EnumerableSet.Bytes32Set;
EnumerableSet.Bytes32Set _bytes32Set;
function add(bytes32 value) external returns (bool) {
return _bytes32Set.add(value);
}
function remove(bytes32 value) external returns (bool){
return _bytes32Set.remove(value);
}
function contains(bytes32 value) external view returns (bool) {
return _bytes32Set.contains(value);
}
function length() external view returns (uint) {
return _bytes32Set.length();
}
function at(uint index) external view returns (bytes32){
return _bytes32Set.at(index);
}
function values() external view returns (bytes32[] memory){
return _bytes32Set.values();
}
}
contract MockAddressSet {
using EnumerableSet for EnumerableSet.AddressSet;
EnumerableSet.AddressSet _addressSet;
function add(address value) external returns (bool) {
return _addressSet.add(value);
}
function remove(address value) external returns (bool){
return _addressSet.remove(value);
}
function contains(address value) external view returns (bool) {
return _addressSet.contains(value);
}
function length() external view returns (uint) {
return _addressSet.length();
}
function at(uint index) external view returns (address){
return _addressSet.at(index);
}
function values() external view returns (address[] memory){
return _addressSet.values();
}
}
contract MockUintSet {
using EnumerableSet for EnumerableSet.UintSet;
EnumerableSet.UintSet _uintSet;
function add(uint value) external returns (bool) {
return _uintSet.add(value);
}
function remove(uint value) external returns (bool){
return _uintSet.remove(value);
}
function contains(uint value) external view returns (bool) {
return _uintSet.contains(value);
}
function length() external view returns (uint) {
return _uintSet.length();
}
function at(uint index) external view returns (uint){
return _uintSet.at(index);
}
function values() external view returns (uint[] memory){
return _uintSet.values();
}
}
全部foundry测试合约:
Github: https://github.com/RevelationOfTuring/foundry-openzeppelin-contracts/blob/master/test/utils/structs/EnumerableSet.t.sol
结构体Set是由一个存储set中元素值的bytes32数组和一个用于记录元素值在元素数组中的index的mapping构成:
struct Set {
// 用于存放set内元素值的数组。存储类型为bytes32,根据需求可以对此进行适当修改
bytes32[] _values;
// 用于记录元素值在_values数组中的index的mapping。如果一个元素值的index记录值为0,表示该元素不存在于set中
mapping(bytes32 => uint256) _indexes;
}
结构体Set和它对应的方法都不对外开放。
_contains(Set storage set, bytes32 value)
:查看元素value是否存在于set中。如果存在,返回true。时间复杂度为O(1);_length(Set storage set)
:返回当前set中的元素个数。时间复杂度为O(1);_at(Set storage set, uint256 index)
:返回当前set中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < set中的元素总个数;_values(Set storage set)
:返回当前set中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。 function _contains(Set storage set, bytes32 value) private view returns (bool) {
// 如果记录的value元素对应index为0表示不存在,不为0表示存在
return set._indexes[value] != 0;
}
function _length(Set storage set) private view returns (uint256) {
// 返回Set._values的长度
return set._values.length;
}
function _at(Set storage set, uint256 index) private view returns (bytes32) {
// 直接从Set._values中用index取值
return set._values[index];
}
function _values(Set storage set) private view returns (bytes32[] memory) {
// 直接将整个bytes32[] storage复制到memory中返回
return set._values;
}
向set中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1)
function _add(Set storage set, bytes32 value) private returns (bool) {
if (!_contains(set, value)) {
// 如果元素value不存在于当前set中,向Set._values中添加该元素
set._values.push(value);
// 在Set._indexes中记录该元素value位于Set._values数组中的index——即当前Set._values数组的长度。
// 注:按照传统编程思想该元素位于Set._values数组最后,其index应该为总长度-1。这里对所有的元素的index记录值都+1,其目的是为了将index 0作为非set元素的flag。如果不这么设计,第一个添加的元素的index就是0,这将导致查询该元素是否处于set中的结果不符合预期
set._indexes[value] = set._values.length;
// 增添了新元素返回true
return true;
} else {
// 如果元素value已存在于set中,不进行任何添加操作并返回false
return false;
}
}
从set中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)
function _remove(Set storage set, bytes32 value) private returns (bool) {
// 获取元素value位于set中的index
uint256 valueIndex = set._indexes[value];
if (valueIndex != 0) {
// 如果valueIndex不为0,表示该元素处于当前set中
// 从一个数组删除某给位置的元素的思路:将数组最后一个元素复制到待删除元素的位置上,然后将最后一个元素pop。该操作的时间复杂度为O(1)
// valueIndex-1为待删除元素在Set._values数组中的真实index
uint256 toDeleteIndex = valueIndex - 1;
// lastIndex为当前数组最后一个元素的真实index
uint256 lastIndex = set._values.length - 1;
if (lastIndex != toDeleteIndex) {
// 如果待删除元素非数组内最后一个元素,取出数组最后一个元素的值
bytes32 lastValue = set._values[lastIndex];
// 数组待删除元素位置上的值替换为当前数组最后一个元素。其实此时已经实现了目标元素真正意义上的删除
set._values[toDeleteIndex] = lastValue;
// 由于数组最后一个元素已经换了位置,更新Set._indexes中最后一个元素的index为valueIndex,即待删除元素在Set._indexes中记录的index值
set._indexes[lastValue] = valueIndex;
}
// 直接pop掉Set._values中的最后一个元素
set._values.pop();
// 删除Set._indexes中关于待删除元素的记录
delete set._indexes[value];
// 返回true
return true;
} else {
// 如果待删除元素value非set中的元素,直接返回false
return false;
}
}
如果要存储的元素类型为bytes32,可以采用该体系
struct Bytes32Set {
// 封装了一个Set
Set _inner;
}
add(Bytes32Set storage set, bytes32 value)
:向Bytes32Set中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);remove(Bytes32Set storage set, bytes32 value)
:从Bytes32Set中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。 function add(Bytes32Set storage set, bytes32 value) internal returns (bool) {
// 直接调用Set._add()方法
return _add(set._inner, value);
}
function remove(Bytes32Set storage set, bytes32 value) internal returns (bool) {
// 直接调用Set._remove()方法
return _remove(set._inner, value);
}
contains(Bytes32Set storage set, bytes32 value)
:查看元素value是否存在于Bytes32Set中。如果存在,返回true。时间复杂度为O(1);length(Bytes32Set storage set)
:返回当前Bytes32Set中的元素个数。时间复杂度为O(1);at(Bytes32Set storage set, uint256 index)
:返回当前Bytes32Set中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < Bytes32Set中的元素总个数;values(Bytes32Set storage set)
:返回当前Bytes32Set中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。 function contains(Bytes32Set storage set, bytes32 value) internal view returns (bool) {
// 直接调用Set._contains()方法
return _contains(set._inner, value);
}
function length(Bytes32Set storage set) internal view returns (uint256) {
// 直接调用Set._length()方法
return _length(set._inner);
}
function at(Bytes32Set storage set, uint256 index) internal view returns (bytes32) {
// 直接调用Set._at()方法
return _at(set._inner, index);
}
function values(Bytes32Set storage set) internal view returns (bytes32[] memory) {
// 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
bytes32[] memory store = _values(set._inner);
// 将store转换成Bytes32Set的外层封装类型bytes32[]
// 注:个人认为对于Bytes32Set的values()方法可以直接返回store,不需要再做类型转换。由于本库的代码是由js代码生成的,所以此处没与后面的uint[]和address[]做差别处理
bytes32[] memory result;
/// @solidity memory-safe-assembly
assembly {
// 内联汇编中,直接在memory中进行bytes32[]->bytes32[]的类型转换
result := store
}
// 返回类型转换后的bytes32[]
return result;
}
contract EnumerableSetTest is Test {
MockBytes32Set mbs = new MockBytes32Set();
function test_Bytes32Set_Operations() external {
// empty
assertEq(mbs.length(), 0);
assertEq(mbs.values().length, 0);
assertFalse(mbs.contains('a'));
// add
assertTrue(mbs.add('a'));
assertTrue(mbs.contains('a'));
assertEq(mbs.length(), 1);
assertTrue(mbs.add('b'));
assertEq(mbs.length(), 2);
// add 'a' again
assertFalse(mbs.add('a'));
assertEq(mbs.length(), 2);
bytes32[] memory values = mbs.values();
assertEq('a', values[0]);
assertEq('b', values[1]);
assertTrue(mbs.add('c'));
assertTrue(mbs.add('d'));
assertEq(mbs.length(), 4);
// remove
// inner array: ['a','b','c','d']
assertTrue(mbs.contains('b'));
assertTrue(mbs.remove('b'));
assertFalse(mbs.contains('b'));
assertEq(mbs.length(), 3);
// remove 'b' again
assertFalse(mbs.remove('b'));
assertEq(mbs.length(), 3);
// inner array after remove: ['a','d','c']
assertEq(mbs.at(0), 'a');
assertEq(mbs.at(1), 'd');
assertEq(mbs.at(2), 'c');
// check values()
values = mbs.values();
assertEq('a', values[0]);
assertEq('d', values[1]);
assertEq('c', values[2]);
// revert if out of bounds
vm.expectRevert();
mbs.at(1024);
}
}
如果要存储的元素类型为address,可以采用该体系
struct AddressSet {
// 封装了一个Set
Set _inner;
}
add(AddressSet storage set, address value)
:向AddressSet中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);remove(AddressSet storage set, address value)
: 从AddressSet中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。 function add(AddressSet storage set, address value) internal returns (bool) {
// 直接调用Set._add()方法,参数value做了address->bytes32的类型转换
return _add(set._inner, bytes32(uint256(uint160(value))));
}
function remove(AddressSet storage set, address value) internal returns (bool) {
// 直接调用Set._remove()方法,参数value做了address->bytes32的类型转换
return _remove(set._inner, bytes32(uint256(uint160(value))));
}
contains(AddressSet storage set, address value)
:查看元素value是否存在于AddressSet中。如果存在,返回true。时间复杂度为O(1);length(AddressSet storage set)
:返回当前AddressSet中的元素个数。时间复杂度为O(1);at(AddressSet storage set, uint256 index)
:返回当前AddressSet中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < AddressSet中的元素总个数;values(AddressSet storage set)
:返回当前AddressSet中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。 function contains(AddressSet storage set, address value) internal view returns (bool) {
// 直接调用Set._contains()方法,参数value做了address->bytes32的类型转换
return _contains(set._inner, bytes32(uint256(uint160(value))));
}
function length(AddressSet storage set) internal view returns (uint256) {
// 直接调用Set._length()方法
return _length(set._inner);
}
function at(AddressSet storage set, uint256 index) internal view returns (address) {
// 直接调用Set._at()方法,并将bytes32类型的返回值转换为address类型返回
return address(uint160(uint256(_at(set._inner, index))));
}
function values(AddressSet storage set) internal view returns (address[] memory) {
// 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
bytes32[] memory store = _values(set._inner);
// 将store转换成AddressSet的外层封装类型address[]
address[] memory result;
/// @solidity memory-safe-assembly
assembly {
// 内联汇编中,直接在memory中进行bytes32[]->address[]的类型转换
result := store
}
// 返回类型转换后的address[]
return result;
}
contract EnumerableSetTest is Test {
MockAddressSet mas = new MockAddressSet();
function test_AddressSet_Operations() external {
// empty
assertEq(mas.length(), 0);
assertEq(mas.values().length, 0);
assertFalse(mas.contains(address(1)));
// add
assertTrue(mas.add(address(1)));
assertTrue(mas.contains(address(1)));
assertEq(mas.length(), 1);
assertTrue(mas.add(address(2)));
assertEq(mas.length(), 2);
// add address(1) again
assertFalse(mas.add(address(1)));
assertEq(mas.length(), 2);
address[] memory values = mas.values();
assertEq(address(1), values[0]);
assertEq(address(2), values[1]);
assertTrue(mas.add(address(4)));
assertTrue(mas.add(address(8)));
assertEq(mas.length(), 4);
// remove
// inner array: [address(1),address(2),address(4),address(8)]
assertTrue(mas.contains(address(2)));
assertTrue(mas.remove(address(2)));
assertFalse(mas.contains(address(2)));
assertEq(mas.length(), 3);
// remove address(2) again
assertFalse(mas.remove(address(2)));
assertEq(mas.length(), 3);
// inner array after remove: [address(1),address(8),address(4)]
assertEq(mas.at(0), address(1));
assertEq(mas.at(1), address(8));
assertEq(mas.at(2), address(4));
// check values()
values = mas.values();
assertEq(address(1), values[0]);
assertEq(address(8), values[1]);
assertEq(address(4), values[2]);
// revert if out of bounds
vm.expectRevert();
mas.at(1024);
}
}
如果要存储的元素类型为uint256,可以采用该体系
struct UintSet {
// 封装了一个Set
Set _inner;
}
add(UintSet storage set, uint256 value)
:向UintSet中增添元素。如果该元素为非set元素返回true,否则返回false。时间复杂度为O(1);remove(UintSet storage set, uint256 value)
:从UintSet中移除元素。如果该元素为当前set元素返回true,否则返回false。时间复杂度为O(1)。 function add(UintSet storage set, uint256 value) internal returns (bool) {
// 直接调用Set._add()方法,参数value做了uint256->bytes32的类型转换
return _add(set._inner, bytes32(value));
}
function remove(UintSet storage set, uint256 value) internal returns (bool) {
// 直接调用Set._remove()方法,参数value做了uint256->bytes32的类型转换
return _remove(set._inner, bytes32(value));
}
contains(UintSet storage set, uint256 value)
:查看元素value是否存在于UintSet中。如果存在,返回true。时间复杂度为O(1);length(UintSet storage set)
:返回当前UintSet中的元素个数。时间复杂度为O(1);at(UintSet storage set, uint256 index)
:返回当前UintSet中对应index位置上的元素值。时间复杂度为O(1)。注意:方法内无索引越界检查,所以使用时需要保证传入的index < UintSet中的元素总个数;values(UintSet storage set)
:返回当前UintSet中全部的元素值(无序)。注意:该方法内部会将storage数组中的全部元素复制到memory中,这将消耗大量gas。所以请不要在非view方法中调用该方法。 function contains(UintSet storage set, uint256 value) internal view returns (bool) {
// 直接调用Set._contains()方法,参数value做了uint256->bytes32的类型转换
return _contains(set._inner, bytes32(value));
}
function length(UintSet storage set) internal view returns (uint256) {
// 直接调用Set._length()方法
return _length(set._inner);
}
function at(UintSet storage set, uint256 index) internal view returns (uint256) {
// 直接调用Set._at()方法,并将bytes32类型的返回值转换为uint256类型返回
return uint256(_at(set._inner, index));
}
function values(UintSet storage set) internal view returns (uint256[] memory) {
// 直接调用Set._values()方法得到底层set中存储的元素总集(是一个bytes32[])
bytes32[] memory store = _values(set._inner);
// 将store转换成UintSet的外层封装类型uint256[]
uint256[] memory result;
/// @solidity memory-safe-assembly
assembly {
// 内联汇编中,直接在memory中进行bytes32[]->uint256[]的类型转换
result := store
}
// 返回类型转换后的uint256[]
return result;
}
contract EnumerableSetTest is Test {
MockUintSet mus = new MockUintSet();
function test_UintSet_Operations() external {
// empty
assertEq(mus.length(), 0);
assertEq(mus.values().length, 0);
assertFalse(mus.contains(1));
// add
assertTrue(mus.add(1));
assertTrue(mus.contains(1));
assertEq(mus.length(), 1);
assertTrue(mus.add(2));
assertEq(mus.length(), 2);
// add 1 again
assertFalse(mus.add(1));
assertEq(mus.length(), 2);
uint[] memory values = mus.values();
assertEq(1, values[0]);
assertEq(2, values[1]);
assertTrue(mus.add(4));
assertTrue(mus.add(8));
assertEq(mus.length(), 4);
// remove
// inner array: [1,2,4,8]
assertTrue(mus.contains(2));
assertTrue(mus.remove(2));
assertFalse(mus.contains(2));
assertEq(mus.length(), 3);
// remove 2 again
assertFalse(mus.remove(2));
assertEq(mus.length(), 3);
// inner array after remove: [1,8,4]
assertEq(mus.at(0), 1);
assertEq(mus.at(1), 8);
assertEq(mus.at(2), 4);
// check values()
values = mus.values();
assertEq(1, values[0]);
assertEq(8, values[1]);
assertEq(4, values[2]);
// revert if out of bounds
vm.expectRevert();
mus.at(1024);
}
}
ps:
本人热爱图灵,热爱中本聪,热爱V神。
以下是我个人的公众号,如果有技术问题可以关注我的公众号来跟我交流。
同时我也会在这个公众号上每周更新我的原创文章,喜欢的小伙伴或者老伙计可以支持一下!
如果需要转发,麻烦注明作者。十分感谢!
公众号名称:后现代泼痞浪漫主义奠基人