用Go代码模拟数组,足够硬核

前奏

书接上文:https://blog.csdn.net/dawnto/article/details/135289279
我们对数组下了一个定义:数组是对线性的内存区域的抽象;高维数组和一维数组有着同样的内存布局。

如果我们想要将一片连续的内存区域映射成高维度数组。或者说的更直白一些,想将高维数组映射到一维数组。
完成这个想法的前提是寻找高维数组和一维数组的下标关系;从而计算地址偏移量即可。

关系映射

约定一下,将左侧数组称为arr0,右侧一维数组称为arr1。

一维数组

假设将一个一维数组,映射到另一个一维数组。那么关系十分简单。如下图所示。
用Go代码模拟数组,足够硬核_第1张图片

arr0和 arr1长度相等,下标从0开始滚动直到arr等长度 - 1,假设arr0和arr1的长度为N,则:
i的取值范围是:[0,N)
k的取值范围是:[0,N)
即一维数组到一维数组的下标映射,使下标i和下标index同步即可。

arr1[0] = arr0[1]
arr1[1] = arr0[1]
...
arr1[index] = arr0[i]

公式:
index = i
arr1[index] = arr0[i]
arr1[i] = arr0[i]

二维数组

二维数组的下标向一维数组下标映射。
用Go代码模拟数组,足够硬核_第2张图片

假设左侧arr0的长度是M,arr0[0]中的一维数组长度是N,右侧arr1的长度是W
则下标关系可表示为:
i:[0,M)
j:[0,N)
index:[0,W)
那么,左右两个数组的长度关系可表示为:
W = M * N
若左侧数组arr0从前向后遍历,arr1同步遍历,则i,j与右侧数组的下标index关系可以表示为
index = i * N + j

arr1[0] = arr1[0][0]
arr1[1] = arr1[0][1]
arr1[2] = arr1[1][0]
arr1[3] = arr0[1][1]

用Go代码模拟数组,足够硬核_第3张图片上图数组下标关系映射为

arr1[0] = arr0[0][0]
arr1[1] = arr0[0][1]
arr1[2] = arr0[0][2]
arr1[3] = arr0[1][0]
arr1[4] = arr0[1][1]
...
arr1[8] = arr0[2][2]

公式:
index = i * N + j
arr1[index] = arr0[i*N][j]
arr1[i * N + j] = arr0[i][j]

三维数组

3维数组到1维数组映射关系。

var arr [3][2][3]int
arr[0][0][0]  ==> arr1[0*2*3 + 0*3 + 0*1] =>  arr1[0]
arr[0][0][1]  ==> arr1[0*2*3 + 0*3 + 1*1] =>  arr1[1]
arr[0][0][2]  ==> arr1[0*2*3 + 0*3 + 2*1] =>  arr1[2]
arr[0][1][0]  ==> arr1[0*2*3 + 1*3 + 0*1] =>  arr1[3]
arr[0][1][1]  ==> arr1[0*2*3 + 1*3 + 1*1] =>  arr1[4]
arr[0][1][2]  ==> arr1[0*2*3 + 1*3 + 2*1] =>  arr1[5]
arr[1][0][0]  ==> arr1[1*2*3 + 0*3 + 0*1] =>  arr1[6]
arr[1][0][1]  ==> arr1[1*2*3 + 0*3 + 1*1] =>  arr1[7]
arr[1][0][2]  ==> arr1[1*2*3 + 0*3 + 2*1] =>  arr1[8]
arr[1][1][0]  ==> arr1[1*2*3 + 1*3 + 0*1] =>  arr1[9]
arr[1][1][1]  ==> arr1[1*2*3 + 1*3 + 1*1] =>  arr1[10]
arr[1][1][2]  ==> arr1[1*2*3 + 1*3 + 2*1] =>  arr1[11]
arr[2][0][0]  ==> arr1[2*2*3 + 0*3 + 0*1] =>  arr1[12]
arr[2][0][1]  ==> arr1[2*2*3 + 0*3 + 1*1] =>  arr1[13]
arr[2][0][2]  ==> arr1[2*2*3 + 0*3 + 2*1] =>  arr1[14]
arr[2][1][0]  ==> arr1[2*2*3 + 1*3 + 0*1] =>  arr1[15]
arr[2][1][1]  ==> arr1[2*2*3 + 1*3 + 1*1] =>  arr1[16]
arr[2][1][2]  ==> arr1[2*2*3 + 1*3 + 2*1] =>  arr1[17]

假设有两个数组

var [M][N][P]int arr0
var [M*N*P]int   arr1

假设3维数组下标分别是,i,j,k, 一维数组下标为index则:
index = i * N * P + j * P + k
arr1[index] = arr0[i][j][k]
arr1[i * N * P + j * P + k] = arr0[i][j][k]

总结一下公式

高维数组向一维数组映射,下标关系为:

var [A][B][C]...[Z]int  array0
var [A * B * C *...* Z] array1

array1[a * B*C...*Z + b * C*D*...*Z + z * 1] = array0[a][b]...[z]

假设我们有一个容器:dimensions := []int{2, 3, 5} // var array [2][3][5]int,存储了某数组各个维度的宽度。则可以通过逆序乘积的方式,求各个维度的系数,用下标i,j,k去乘以每个维度的系数就可以得到线性的一维数组下标,fmt.Println(i*tmpArray[0] + j*tmpArray[1] + k*tmpArray[2])

func TestDimension(t *testing.T) {
	dimensions := []int{2, 3, 5} //var array [2][3][5]int
	tmpArray := make([]int, len(dimensions))

	mul := 1 // 总乘积
	for i := len(dimensions) - 1; i >= 0; i-- {
		tmpArray[i] = mul
		mul *= dimensions[i]
	}
	//fmt.Println(dimensions)
	//fmt.Println(tmpArray)
	for i := 0; i < 2; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 5; k++ {
				fmt.Println(i*tmpArray[0] + j*tmpArray[1] + k*tmpArray[2])
			}
		}
	}
}

模拟数组代码

package Go

import (
	"fmt"
	"testing"
	"unsafe"
)

type array[T any] struct {
	data      unsafe.Pointer // 内存区域指针
	n         int            // 数组有n个维度
	dimension map[int]uint   // 每个维度的数据宽度,用于判定数组索引是否越界
	factor    map[int]uint   // 映射到一维数组的偏移量
	elemSize  uintptr        // 类型大小
}

// NewArray create an Array.
// array := NewArray[int](10,5,7,8) equals var array [10][5][7][8]int
func NewArray[T any](dimensions ...uint) *array[T] {
	n := len(dimensions) // n个维度的数组
	if n == 0 {
		panic("invalid dimensions")
	}
	dimensionMap := make(map[int]uint)
	factorMap := make(map[int]uint)

	f := uint(1)
	for i := range dimensions {
		dimensionMap[n-i-1] = dimensions[i]
		factorMap[n-i-1] = f // 反向索引映射
		f *= dimensions[n-i-1]
	}

	return &array[T]{
		data:      malloc[T](f),
		n:         n,
		dimension: dimensionMap,
		factor:    factorMap,
		elemSize:  unsafe.Sizeof(*(*T)(nil)),
	}
}

func malloc[T any](size uint) unsafe.Pointer {
	origin := unsafe.Sizeof(*(*T)(nil))
	memPointer := make([]uint8, uint(origin)*size)
	// defer runtime.KeepAlive(memPointer)
	return *(*unsafe.Pointer)(unsafe.Pointer(&memPointer))
}

func (a *array[T]) Set(value T, index ...uint) {
	if len(index) != a.n {
		panic("out of dimensions")
	}
	location := uint(0)
	for i := range index {
		if index[i] >= a.dimension[i] || index[i] < 0 {
			panic("index out of bounds")
		}
		location += index[i] * a.factor[i]
	}
	*(*T)(unsafe.Add(a.data, uintptr(location)*a.elemSize)) = value
}

func (a *array[T]) Get(index ...uint) T {
	if len(index) != a.n {
		panic("out of dimensions")
	}
	location := uint(0)
	for i := range index {
		if index[i] >= a.dimension[i] || index[i] < 0 {
			panic("index out of bounds")
		}
		location += index[i] * a.factor[i]
	}
	return *(*T)(unsafe.Add(a.data, uintptr(location)*a.elemSize))
}

type student struct {
	Name string
	Age  int
}

func TestArray(t *testing.T) {
	studentArray := NewArray[student](3) // [10]student
	studentArray.Set(student{            // student[0] = student{...}
		Name: "zhang san",
		Age:  10,
	}, 0)

	studentArray.Set(student{ // student[1] = student{...}
		Name: "li si",
		Age:  20,
	}, 1)

	fmt.Println(studentArray.Get(0))
	fmt.Println(studentArray.Get(1))
	fmt.Println(studentArray.Get(2))

	array1 := NewArray[int](2, 2) // [2][2]int
	array1.Set(1, 0, 0)           // arr[0][0] = 1
	array1.Set(2, 0, 1)           // arr[0][1] = 2
	array1.Set(3, 1, 0)           // arr[1][0] = 3
	array1.Set(4, 1, 1)           // arr[1][1] = 4

	fmt.Println(array1.Get(0, 0)) // arr[0][0]
	fmt.Println(array1.Get(0, 1)) // arr[0][1]
	fmt.Println(array1.Get(1, 0)) // arr[1][0]
	fmt.Println(array1.Get(1, 1)) // arr[1][1]

	array2 := NewArray[int](3, 3, 3) // [3][3][3]int
	w := 1
	for i := 0; i < 3; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 3; k++ {
				array2.Set(w, uint(i), uint(j), uint(k))
				w++
			}
		}
	}

	w = 1
	for i := 0; i < 3; i++ {
		for j := 0; j < 3; j++ {
			for k := 0; k < 3; k++ {
				num := array2.Get(uint(i), uint(j), uint(k))
				if num != w {
					fmt.Println("---", num, w)
					panic("oi! 出错了")
				}
				w++
			}
		}
	}

	arr := *(*[27]int)(array2.data)
	fmt.Println(arr) // [1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27

}

Reference
https://blog.csdn.net/dawnto/article/details/135289279

你可能感兴趣的:(#,具象的Go,golang,数组)