用来描述算法渐近运行时间的记号,根据定义域为自然数集 N={0,1,2,⋯} 的函数来定义。这样的记号对描述最坏情况运行时间函数 T(n) 是方便的,因为该函数通常只定义在整数输入规模上。
对一个给定的函数 g(n) ,用 Θ(g(n)) 来表示一下函数的集合:
Θ(g(n))={f(n):存在正常量c1、c2和c0,使得对所有n≥n0,有0≤c1g(n)≤f(n)≤c2g(n)}
若存在正常量 c1 和 c2 使得对足够大的 n ,函数 f(n) 能“夹入” c1g(n) 与 c2g(n) 之间,则 f(n) 属于集合 Θ(g(n)) ,因为 Θ(g(n)) 是一个集合,所以可以记“ f(n)∈Θ(g(n)) ”,以指出 f(n) 是 Θ(g(n)) 的成员。作为替代,我们通常记“ f(n)=Theta(g(n)) ”以表达相同的概念。
我们称 g(n) 是 f(n) 的一个渐进紧确界(asymptotically tight bound)。
一般来说,对任意多项式 p(n)=∑ni=1aini ,其中 ai 为常量且 ad>0 ,我们有 p(n)=Θ(nd) 。
因为任意常量是一个0阶多项式,所以可以把任意常量函数表示成 Θ(n0) 或者 Θ(1) 。
Θ 记号渐近地给出一个函数的上界和下界。当只有一个渐进上界时,使用 O 记号。对于给定的函数 g(n) ,用 O(g(n)) 来表示以下函数的集合:
O(g(n))={f(n):存在常量c和n0,使得对所有n≥n0时,有0≤f(n)≤cg(n)}
正如 O 记号提供了一个函数的渐近上界, Ω 记号提供了渐近下界。对于给定的函数 g(n) ,用 Ω(g(n)) 来表示以下函数的集合:
Ω(g(n))={f(n):存在正常量c和n0,使得对所有n≥n0时,有0≤cg(n)≤f(n)}
定理3.1:对任意两个函数 f(n) 和 g(n) ,我们有 f(n)=Θg(n) ,当且仅当: f(n)=Og(n) 且 f(n)=Ωg(n) 。
由 O 记号提供的渐近上界可能是也可能不是渐近紧确的。界 2n2=Og(n2) 是渐近紧确的,但是界 2n=Og(n2) 却不是。我们使用 o 记号来表示一个非渐进紧确的上界。
形式地定义 o(g(n)) 为以下集合:
o(g(n))={f(n):对任意正常量c>0,存在常量n0>0,使得对所有n>n0,有0≤f(n)<cg(n)}
O 记号与 o 记号的定义类似。主要区别在于是 f(n)=O(g(n)) 中,界 0≤f(n)≤cg(n) 对某个常量 c>0 成立,但在 f(n)=o(g(n)) 中,界 0≤f(n)<cg(n) 对所有常量 c>0 都成立。直观上,在 o 记号中,当 n 趋于无穷时,函数 f(n) 相对于 g(n) 来说变得微不足道了,即:
ω 记号与 Ω 记号的关系类似于 o 与 O 的关系。我们使用 ω 来定义一个非渐近紧确的下界。定义它的一种方式是:
f(n)∈ωg(n) 当且仅当 g(n)∈o(f(n))
然而我们形式化地定义 f(n)=ω(g(n)) 为以下集合:
ω(g(n))={f(n):对任意正常量c>0,存在常量n0>0,使得对所有的n≥n0时,有0≤cg(n)<f(n)}
关系 f(n)=ω(g(n)) 蕴含着:
比较各种函数
实数的许多关键性质也适用于渐近比较。下面假定 f(n) 和 g(n) 渐近为正。
传递性
三分性 对任意两个实数a和b,下列三种情况恰有一种必须成立: a<b, a=b, a>b
虽然任意两个时序都可以进行比较,但是不是所有的函数都可渐近比较。也就是说,对两个函数 f(n) 和 g(n) ,也许 f(n)=O(g(n)) 和 f(n)=Ω(g(n)) 都不成立。
若 m≤n 蕴含 f(m)≤f(n) ,则函数 f(n) 是单调递增的。类似的,若 m≤n 蕴含 f(m)≥f(n) ,则函数 f(n) 是单调递减的。若 m<n 蕴含 f(m)<f(n) ,则函数 f(n) 是严格单调递增的。类似的,若 m<n 蕴含 f(m)>f(n) ,则函数 f(n) 是严格单调递减的。
对于任意实数 x ,我们用 ⌊x⌋ 表示 x 的向下取整,并用 ⌈x⌉ 表示 x 的向上取整。
对所有实数:
对任意整数 a 和任意正整数 n , a mod n 的值就是商 a/n 的余数。
给定一个非负整数 d ,n的d次多项式为具有以下形式的一个函数 p(n) :
其中常量 a1,a2,⋯,an 是多项式的系数且 ad≠0 。一个多项式渐近正的当且仅当 ad>0 。对于一个 d 次渐近正的多项式 p(n) ,有 p(n)=Θ(nd) 。对任意实常数 a≥0 ,函数 na 单调递增,对任意实常量 a≤0 ,函数 na 单调递减。若对某个常量 k ,有 f(n)=O(nk) ,则称函数 f(n) 是多项式有界的。
对所有实数 a>0 、 m 和 n ,我们有以下恒等式:
使用 e 来表示自然对数函数的底 2.71823⋯ ,对所有实数 x ,我们有
我们使用下面的记号:
对于所有实数 a>0, b>0, c>0 和 n ,有
从公式 logba=logcalogcb 中可以看出,对数的底从一个常量到另一个常量的更换仅使对数的值改变一个常量因子,所以当我们不关心这些常量因子时,例如在 O 记号中,我们经常使用“ lg n ”。计算机科学家发现2是对数的最自然的底,因为非常多的算法和数据结构涉及把一个问题分解成两部分。
若对某个常量 k , f(n)=O(lgk n) ,则称函数 f(n) 是多对数有界的。
多项式与多对数的增长相互关联:
记号 n! ,定义为对整数 n≥0 ,有:
我们使用记号 f(i)n 来表示函数 f(n) 重复 i 次作用与处置n上。形式化地,假设 f(n) 为实数集上的一个函数。对非负整数 i ,我们递归地定义
我们使用记号 lg∗ n 来表示多重对数函数,下面给出它的定义。假设 lg(i)n 定义如上,其中 f(n)=lg n 。因为非正数的对数无定义,所以只有在 lg(i−1)>0 时, lg(i)n 才有定义。定义多重对数函数为
使用下面的递归式来定义菲波那切数:
输入: n个数的一个序列 A=⟨a1,a2,⋯,an⟩ 和一个值 v 。
输出:下标 i 使得 v=A[i] 或者当 v 不在A中出现时, i 为特殊值-1。
LINEAR-SEARCH(A, v)
for i = 1 to A.length
if v == A[i]
return i
return -1
java实现:
public static int linearSearch(int[] srcArr, int val) {
for (int i = 0; i < srcArr.length; i++) {
if (srcArr[i] == val)
return i;
}
return -1;
}
如果数组 A 已经排好序,就可以将该序列的中点与 v 进行比较根据比较的结果,原序列中有一半就可以不用再进一步的考虑了。二分查找算法重复这个过程,每次都将序列剩余部分的规模减半。
非递归方式:
BINARY-SEARCH(A, v)
low = 1
high = A.length
while low <= high
middle = (low + high) / 2
if A[middle] < v
low = middle + 1
elseif A[middle] > v
high = middle - 1
else
return middle
return -1
java实现:
/**
* 利用非递归形式的二分查找法在数组中寻找特定的值
* @param srcArray 被搜索的整形数组
* @param val 待查找的值
* @return 若该值在数组中,返回该值对应的数组索引。否则返回-1
*/
public static int binarySearch(int[]srcArray, int val){
//低位“指针”
int low = 0;
//高位“指针”
int high = srcArray.length - 1;
//如果low ≤ high则进行查找,
// 因为无论数组元素为偶数个还是奇数个,当要查找的值不在数组中时最后一步查找情况是low和high重合,此时middle=low=high,
// 如果srcArray[middle]>val,执行low = middle + 1,此时low>high;
//如果srcArray[middle]high;
while(low <= high){
int middle = (low + high) >> 1;
//当数组中间值小于待查找值,该值“可能”在数组右半侧,并且索引middle处的值已经判断过,所以low=middle+1,
//并且如果low=middle,在[srcArray[low], srcArray[high]]会陷入死循环
if(srcArray[middle] < val){
low = middle + 1;
}else if(srcArray[middle] > val){
high = middle - 1;
//找到待查找值,返回该值对应数组索引
}else{
return middle;
}
}
//当待查找值不在数组中时返回-1
return -1;
}
递归方式:
BINARY-SEARCH(A, low, high, val)
while low <= high
middle = (low + high) / 2
if A[middle] == val
return middle
elseif A[middle] < val
return BINARY-SEARCH(A, middle + 1, high, val)
else
return BINARY-SEARCH(A, low, middle - 1, val)
return -1
java实现:
public static int binarySearch(int[]srcArray, int low, int high, int val){
while (low <= high) {
int middle = (low + high) >> 1;
if (srcArray[middle] == val) {
return middle;
} else if (srcArray[middle] < val)
return binarySearch(srcArray, middle + 1, high, val);
else
return binarySearch(srcArray, low, middle - 1, val);
}
return -1;
}
输入: n个数的一个序列 ⟨a1,a2,⋯,an⟩ 。
输出: ⟨a′1,a′1,⋯,a′n⟩ ,满足 a′1≤a′2≤⋯≤a′n⟩ 。
对于少量元素的排序,它是一个有效的算法。
如图所示,一副完整的牌就像一个数组A,手中的牌是A[1..j - 1]已经按从小到大排好,这时你从桌子上的牌堆A[j..A.length]中取出一张牌A[j],你要做的就是将这张牌插入到手中的牌里。手中原来有牌{2, 4, 5, 10},从桌子上取出一张牌是7,和最大的牌10比,7<10,10往后移一个位置,7和5比,7>5,那么就将7插入刚才10空出的位置,以此类推。
INSERTION-SORT(A)
for j = 2 to A.length
key = A[j]
//Insert A[j] into the sorted sequence A[1..j - 1].
i = j - 1
while i > 0 and A[i] > key
A[i + 1] = A[i]
i = i - 1
A[i + 1] = key
java实现
/**
* 对一个整型数组按从小到大的顺序排序
* @param arr 待排序的数组
*/
public static void insertionSort(int[] arr) {
for (int i = 0; i < arr.length; i++) {
int key = arr[i];
int j = i - 1;
//将当前值key与已排好部分arr[0..i-1]中的值挨个比较大小
//如果key大于已排好数组中arr[j],则将arr[j]往后“移一位”,将当前位置腾出来,保存key或前面移过来的值
while (j >= 0 && arr[j] > key) {
arr[j + 1] = arr[j];
j--;
}
//①如果key不小于已排好数组最大值(即已排好部分最后一个值),将key放到arr[j + 1]即arr[i]
//相当于将arr[i]拿出来比较一下,发现arr[i]不小于已排好数组中最大值,再讲arr[i]放回去;
//②如果key小于已排好数组最大值,那么经过while循环当前的arr[j]是第一个小于key的数,所以将
//key放在arr[j + 1]已腾出来的位置,arr[j + 1..arr[i]]里面的元素已经依次向右移动一个位置。
arr[j + 1] = key;
}
}
插入排序的:
事实上,元素A[1, j - 1]就是原来1到j - 1的元素,但是现在已按顺序排列。我们把A[1..j-1]的这些性质形式地表示为一个循环不变式。
循环不变式主要用来帮助我们理解算法的正确性。关于循环不变式,我们必须证明三条性质:
- 初始化:循环的第一次迭代之前,它为真。
- 保持:如果循环的某次迭代之前它为真,那么下次迭代之前塔仍然为真。
- 终止:在循环终止时,不变式为我们提供一个有用的性质,该性质有助于证明算法是正确的。
归并排序算法完全遵循分治模式,直观上其操作如下:
1. 分解:分解带排序的n个元素的序列成各具n/2个元素的两个子序列。
2. 解决:使用归并排序递归地排序两个子序列。
3. 合并:合并两个已经排序的子序列以产生已排序的答案。
MERGE-SORT(A, p, r)
if p < r
q = ⌊(p + r) / 2⌋
MERGE-SORT(A, p, q)
MERGE-SORT(A, q, r)
MERGE(A, p, q, r)
MEARGE(A, p, q, r)
n1 = q - p + 1
n2 = r - q
let L[1..n1 + 1] and R[1..n2 + 1] be new arrays
for i = 1 to n1
L[i] = A[p + i - 1]
for j = 1 to n2
R[j] = A[q + j]
i = 1
j = 1
for k = p to r
if i != n1 and (j == n2 or L[i] ≤ R[j])
A[k] = L[i]
i = i + 1
else
A[k] = R[j]
j = j + 1
java实现
/**
* 给定一个数组,起始位置,终止位置,对数组[起始位置..终止位置]按从小到大排序
* @param arr 待排序数组
* @param p 排序起始位置
* @param r 排序终止位置
*/
public static void mergeSort(int[] arr, int p, int r){
if (p < r){
//取中间位置,将数组分为左右两部分
int q = (p + r) / 2;
//递归对左数组进行排序
mergeSort(arr, p, q);
//递归对右数组进行排序
mergeSort(arr, q + 1, r);
//调用方法merge将数组中arr[p..r]部分进行排序
merge(arr, p, q, r);
}
}
/**
* 给定一个数组arr,给定三个索引参数,p、q、r,满足p≤q≤r,将arr[p..r]分成两个数组
* arr[p..q]和arr[q+1..r],再融合两个数组的过程中对数组进行排序,融合后的数组即排好序的数组
* @param arr 待排序的数组
* @param p 待排序数组首位索引
* @param q 待排序数组中间索引
* @param r 待排序数组摸位索引
*/
public static void merge(int[] arr, int p, int q, int r){
int lLen = q - p + 1;
int rLen = r - q;
//用于接收左半数组
int[] lArr = new int[lLen];
//用于接收右半数组
int[] rArr = new int[rLen];
//将arr[p..q]复制到lArr
System.arraycopy(arr, p, lArr, 0, lLen);
//将arr[q+1..r]复制到rArr
System.arraycopy(arr, q + 1, rArr, 0, rLen);
int i = 0, j = 0;
for (int k = p; k <= r; k++) {
//取lArr中的值首先需要满足lArr数组索引没有越界,这个前提下有两种情况,
//①rLen索引已到rLen.length,即rLen中的值都被取出
//②两个数组中都有值,并且lArr[i] <= rArr[j]
if (i != lLen && (j == rLen || lArr[i] <= rArr[j])){
arr[k] = lArr[i];
i++;
} else {
arr[k] = rArr[j];
j++;
}
}
}
考虑排序存储在数组A中的n个数:首先找出A中的最小元素并将其与A[1]中的元素进行交换。接着,找出A中次最小元素并将其与A[2]中的元素进行交换。对A中的前n-1个元素按该方式继续。
SELECTION-SORT
for i = 1 to A.length - 1
for j = i to A.lenth
if A[j] < A[i]
//change position
temp = A[i]
A[i] = A[j]
A[j] = temp
java实现:
public static void selectionSort(int[] arr) {
for (int i = 0; i < arr.length - 1; i++) {
for (int j = i; j < arr.length; j++) {
if (arr[j] < arr[i]) {
//①a ^ a = 0;②a ^ 0 = a;③a ^ b ^ c = a ^ (b ^ c) = (a ^ b) ^ c;
arr[i] = arr[i] ^ arr[j];
// arr[j] = arr[i] ^ arr[j] ^ arr[j] = arr[i]
arr[j] = arr[i] ^ arr[j];
// arr[i] = arr[i] ^ arr[j] ^ arr[j] = arr[j] ^ arr[i] ^ arr[i] = arr[j]
arr[i] = arr[i] ^ arr[j];
}
}
}
}
分治法:将原问题分解为几个规模较小但类似于原问题的子问题,递归地求解这些子问题,然后再合并这些子问题的解来建立原问题的解。
分治模式在每层递归时都有三个步骤:
1. 分解原问题为若干子问题,这些子问题是原问题的规模较小的实例。
2. 解决这些子问题,递归地求解各个子问题。然而,若子问题的规模足够小,则直接求解。
3. 合并这些子问题的解成原问题的解。
递归式:递归式与分治方法是紧密相关的,因为使用递归式可以很自然地刻画分治算法的运行时间。一个递归式(recurrence)就是一个等式或一个不等式。
本章介绍三种求解递归式的方法,即得出算法“ Θ ”和“ O ”渐近界的方法。
有一整形数组 A ,找出 A 中和为最大的非空连续子数组。我们称这样的连续子数组为最大连续子数组。
遍历所有可能的数组组合,找出其中和最大的。
FIND-MAXIMUM-SUBARRAY(A, low, high)
sum = -∞
for i = 1 to A.length
tempSum = 0
for j = i to A.length
tempSum = tempSum + A[j]
if tempSum > sum
sum = tempSum
max-left = i
max-right = j
return (max-left, max-right, left-sum + right-sum)
java实现
public static int[] findMaximumSubarray(int[] arr, int low, int high) {
//假设arr[0]就是最大连续子数组
int sum = arr[0];
int maxLeft = 0;
int maxRight = 0;
for (int i = 0; i < arr.length; i++) {
int tempSum = 0;
for (int j = i; j < arr.length; j++) {
tempSum += arr[j];
if (tempSum > sum) {
sum = tempSum;
maxLeft = i;
maxRight = j;
}
}
}
return new int[]{maxLeft, maxRight, sum};
}
过程FIND-MAX-CORSSING-SUBARRAY接受数组 A 和下标 low 、 mid 、 high 作为输入,返回一个下标元祖规定跨越中点的最大子数组的边界,并返回最大子数组中值的和。
FIND-MAXIMUM-SUBARRAY(A, low, high)
if high == low
return (low, high, A[low])
else mid = ⌊(low + high) / 2⌋
(left-low, left-high, left-sum) =
FIND-MAXIMUM-SUBARRAY(A, low, mid)
(right-low, right-high, right-sum) =
FIND-MAXIMUM-SUBARRAY(A, mid + 1, high)
(cross-low, cross-high, cross-sum) =
FIND-MAX-CORSSING-SUBARRAY(A, low, mid, high)
if left-sum ≥ right-sum and left-sum ≥ cross-sum
return (left-low, left-high, left-sum)
elseif right-sum ≥ left-sum and right-sum ≥ cross-sum
return (right-low, right-high, right-sum)
else return (cross-low, cross-high, cross-sum)
FIND-MAX-CORSSING-SUBARRAY(A, low, mid, high)
left-sum = -∞
sum = 0
for i = mid downto low
sum = sum + A[i]
if sum > left-sum
left-sum = sum
max-left = i
right-sum = -∞
sum = 0
for j = mid + 1 to high
sum = sum + A[j]
if sum > right-sum
right-sum = sum
max-right = j
return (max-left, max-right, left-sum + right-sum)
java实现:
/**
* 该方法接收一个数组和low、high下标,找出其范围内的最大子数组
* @param arr 被查找的数组
* @param low 低位下标
* @param high 高位下标
* @return 最大子数组的起始位置,终止位置、和
*/
public static int[] findMaximumSubarray(int[] arr, int low, int high) {
//递归触底反弹,子数组只有一个元素,所以arr[low]本身就是最大子数组
if (low == high)
return new int[]{low, high, arr[low]};
else {
int mid = (low + high) / 2;
int[] leftArr = findMaximumSubarray(arr, low, mid);
int[] rightArr = findMaximumSubarray(arr, mid + 1, high);
int[] crossingArr = findMaxCrossingSubarray(arr, low, mid, high);
if (leftArr[2] >= rightArr[2] && leftArr[2] >= crossingArr[2])
return leftArr;
else if (rightArr[2] >= leftArr[2] && rightArr[2] >= crossingArr[2])
return rightArr;
else
return crossingArr;
}
}
/**
* 该方法接收一个数组arr和下标low,mid,high为输入,返回一个下标元组划定跨越中点的最大子数组的边界,
* 并返回最大子数组中值的和。
* @param arr 被查询的数组
* @param low 低位下标
* @param mid 中间位置下标,最大子数组跨越该点
* @param high 高位下标
* @return 最大子数组的起始位置,终止位置,和
*/
public static int[] findMaxCrossingSubarray(int[] arr, int low, int mid, int high) {
int maxLeft = mid;
int maxRight = mid + 1;
int leftSum = arr[mid];
int sum = 0;
for (int i = mid; i >= low; i--) {
sum += arr[i];
if (sum > leftSum) {
leftSum = sum;
maxLeft = i;
}
}
int rightSum = arr[mid + 1];
sum = 0;
for (int i = mid + 1; i <= high; i++) {
sum += arr[i];
if (sum > rightSum) {
rightSum = sum;
maxRight = i;
}
}
return new int[]{maxLeft, maxRight, leftSum + rightSum};
}
如果子数组A[low..high]包含n个元素,则调用FIND-MAX-CROSSING-SUBARRAY(A, low, mid, high)花费 Θ(n) 时间。
初始调用FIND-MAXIMUM-SUBARRAY(A, 1, A.length)会求出A[1..n]的最大子数组。
从数组的左边界开始,由左至右处理,记录到目前为止已经处理过的最大子数组。若已知A[1..j]的最大子数组,基于如下性质将扩展为A[1..j+1]的最大子数组:A[1..j+1]的最大子数组要么是A[1..j]的最大子数组,要么是某个子数组A[i..j+1](1 ≤ i ≤ j+1)。在已知A[1..j]的最大子数组的情况下,可以在线性时间内找出形如A[i..j+1]的最大子数组。
有问题,但还是先记录下, Θ(n2)
/**
* 该方法接收一个数组和low、high下标,找出其范围内的最大子数组
* @param arr 被查找的数组
* @param low 低位下标
* @param high 高位下标
* @return 最大子数组的起始位置,终止位置、和
*/
public static int[] findMaximumSubarray(int[] arr, int low, int high) {
int maxLeft = low;
int maxRight = low;
int sum = arr[low];
int tempSum;
for (int i = 0; i < arr.length - 1; i++) {
tempSum = 0;
for (int j = i + 1; j >= 0 ; j--) {
tempSum += arr[j];
if (tempSum > sum) {
maxLeft = j;
maxRight = i + 1;
sum = tempSum;
}
}
}
return new int[]{maxLeft, maxRight, sum};
}
若 A=(aij) 和 B=(bij) 是 n×n 的方阵,则对 i,j=1,2,⋯,n ,定义乘积 C=A⋅B 中的元素 cij 为:
SQUARE-MATRIX-MULTIPLAY(A, B)
n = A.rows
for i = 1 to n
cij = 0
for j = 1 to n
for k = 1 to n
cij = cij + aik * bkj
return C
由于三重for循环的每一重都恰好执行n步,而第三重的加法需要常量时间,因此过程SQUARE-MATRIX-MAULPLAY花费 Θ(n3) 时间。
java模拟
模拟矩阵的类,用二维数组存储值,重写了toString()方法:
public class Matrix {
private int rows; //矩阵的行数
private int cols; //矩阵的列数
private double[][] matrixArray; //代表矩阵的二维数组
public Matrix(int rows, int cols) {
this.rows = rows;
this.cols = cols;
matrixArray = new double[rows][cols];
}
public Matrix(double[][] matrixArray) {
rows = matrixArray.length;
cols = matrixArray[0].length;
this.matrixArray = matrixArray;
}
public int getRows() {
return rows;
}
public void setRows(int rows) {
this.rows = rows;
}
public int getCols() {
return cols;
}
public void setCols(int cols) {
this.cols = cols;
}
public double[][] getMatrixArray() {
return matrixArray;
}
public void setMatrixArray(double[][] matrixArray) {
this.matrixArray = matrixArray;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < matrixArray.length; i++) {
for (int j = 0; j < matrixArray[i].length; j++) {
sb.append(matrixArray[i][j] + "\t");
if (j == matrixArray[i].length - 1)
sb.append("\n");
}
}
return sb.toString();
}
}
算法类,包含一个求两个矩阵积的静态方法:
public class Algorithms {
/**
* 接收两个矩阵A, B返回两者的乘积
* @param matrixA 乘数A矩阵
* @param matrixB 乘数B矩阵
* @return 两个矩阵的乘积
*/
public static Matrix squareMatrixMultiply(Matrix matrixA, Matrix matrixB) {
double[][] matrixAArray = matrixA.getMatrixArray();
double[][] matrixBArray = matrixB.getMatrixArray();
int rows = matrixA.getRows();
int cols = matrixB.getCols();
double sum = 0;
double[][] matrixCArray = new double[rows][cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
sum = 0;
for (int k = 0; k < cols; k++) {
sum = sum + matrixAArray[i][k] * matrixBArray[k][j];
}
matrixCArray[i][j] = sum;
}
}
return new Matrix(matrixCArray);
}
}
测试类:
public class TestAlgorithms {
public static void main(String[] args) {
double[][] matrixAArray = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
double[][] matrixBArray = {{12.3, 12, 56}, {34, 456, 234.2}, {93, 3434, 1314}};
Matrix A = new Matrix(matrixAArray);
Matrix B = new Matrix(matrixBArray);
Matrix C = Algorithms.squareMatrixMultiply(A, B);
System.out.println(C);
}
}
输出:
359.3 11226.0 4466.4
777.2 22932.0 9279.0
1195.1 34638.0 14091.6
假定将 A、B 和 C 均分解为4个 n/2×n/2 的子矩阵:
SQUARE-MATRIX-MULPLAY-RECURSIVE(A, B)
n = A.rows
let C be a new n×n matrix
if n == 1
c11 = a11 · b11
else
C11 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A11, B11)
+ SQUARE-MATRIX-MULPLAY-RECURSIVE(A12, B21)
C12 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A11, B12)
+ SQUARE-MATRIX-MULPLAY-RECURSIVE(A12, B22)
C21 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A21, B11)
+ SQUARE-MATRIX-MULPLAY-RECURSIVE(A22, B21)
C22 = SQUARE-MATRIX-MULPLAY-RECURSIVE(A21, B12)
+ SQUARE-MATRIX-MULPLAY-RECURSIVE(A22, B22)
return C
SQUARE-MATRIX-MULPLAY-RECURSIVE运行时间递归式:
Matrix
。
public class Algorithms {
/**
* 接收两个矩阵A, B返回两者的乘积
* @param matrixA 乘数A矩阵
* @param matrixB 乘数B矩阵
* @return 两个矩阵的乘积
*/
public static Matrix squareMatrixMultiplyRecursive(Matrix matrixA, Matrix matrixB) {
double[][] A = matrixA.getMatrixArray();
double[][] B = matrixB.getMatrixArray();
int rows = matrixA.getRows();
double[][] C = new double[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
return new Matrix(C);
}
else {
int count = rows >> 1;
double[][] A11Arr = arrayCopy(A, 0, 0, count);
double[][] A12Arr = arrayCopy(A, 0, count, count);
double[][] A21Arr = arrayCopy(A, count, 0, count);
double[][] A22Arr = arrayCopy(A, count, count, count);
double[][] B11Arr = arrayCopy(B, 0, 0, count);
double[][] B12Arr = arrayCopy(B,0, count, count);
double[][] B21Arr = arrayCopy(B, count, 0, count);
double[][] B22Arr = arrayCopy(B, count, count, count);
Matrix C11 = add(squareMatrixMultiplyRecursive(new Matrix(A11Arr), new Matrix(B11Arr)), squareMatrixMultiplyRecursive(new Matrix(A12Arr), new Matrix(B21Arr)));
Matrix C12 = add(squareMatrixMultiplyRecursive(new Matrix(A11Arr), new Matrix(B12Arr)), squareMatrixMultiplyRecursive(new Matrix(A12Arr), new Matrix(B22Arr)));
Matrix C21 = add(squareMatrixMultiplyRecursive(new Matrix(A21Arr), new Matrix(B11Arr)), squareMatrixMultiplyRecursive(new Matrix(A22Arr), new Matrix(B21Arr)));
Matrix C22 = add(squareMatrixMultiplyRecursive(new Matrix(A21Arr), new Matrix(B12Arr)), squareMatrixMultiplyRecursive(new Matrix(A22Arr), new Matrix(B22Arr)));
return combine(C11, C12, C21, C22);
}
}
/**
* 复制二维数组
* @param srcArr 源数组
* @param x 从二维数组中第几个一维数组开始复制
* @param y 从那个一维数组的第几个元素开始复制
* @param count 连续的作用到几个一维数组,每个一维数组复制几个值
* @return 复制好的二维数组
*/
public static double[][] arrayCopy(double[][] srcArr, int x, int y, int count) {
double[][] destArr = new double[count][count];
for (int i = 0; i < count; i++) {
for (int j = 0; j < count; j++) {
destArr[i][j] = srcArr[x + i][y + j];
}
}
return destArr;
}
/**
* 求两个矩阵的和
* @param A 加数矩阵A
* @param B 加数矩阵B
* @return 两个矩阵的和
*/
public static Matrix add(Matrix A, Matrix B) {
double[][] aArr = A.getMatrixArray();
double[][] bArr = B.getMatrixArray();
double[][] cArr = new double[aArr.length][bArr[0].length];
for (int i = 0; i < aArr.length; i++) {
for (int j = 0; j < aArr[i].length; j++) {
cArr[i][j] = aArr[i][j] + bArr[i][j];
}
}
return new Matrix(cArr);
}
/**
* 将四个矩阵合并为一个矩阵
* @param A11 子矩阵
* @param A12 子矩阵
* @param A21 子矩阵
* @param A22 子矩阵
* @return 合并后的矩阵
*/
public static Matrix combine(Matrix A11, Matrix A12, Matrix A21, Matrix A22) {
double[][] a11Arr = A11.getMatrixArray();
double[][] a12Arr = A12.getMatrixArray();
double[][] a21Arr = A21.getMatrixArray();
double[][] a22Arr = A22.getMatrixArray();
int rowsA = a11Arr.length;
int colsA = a11Arr[0].length;
int rowsB = a12Arr.length;
int colsB = a12Arr[0].length;
int rowsC = a21Arr.length;
int colsC = a21Arr[0].length;
int rowsD = a22Arr.length;
int colsD = a22Arr[0].length;
double[][] resultArr = new double[rowsA + rowsC][colsA + colsB];
for (int i = 0; i < rowsA; i++) {
for (int j = 0; j < colsA; j++) {
resultArr[i][j] = a11Arr[i][j];
}
}
for (int i = 0; i < rowsB; i++) {
for (int j = 0; j < colsB; j++) {
resultArr[i][colsA + j] = a12Arr[i][j];
}
}
for (int i = 0; i < rowsC; i++) {
for (int j = 0; j < colsC; j++) {
resultArr[rowsA + i][j] = a21Arr[i][j];
}
}
for (int i = 0; i < rowsD; i++) {
for (int j = 0; j < colsD; j++) {
resultArr[rowsA + i][colsA + j] = a22Arr[i][j];
}
}
return new Matrix(resultArr);
}
}
测试类:
public class TestAlgorithms {
public static void main(String[] args) {
double[][] matrixAArray = {{1, 2, 3, 3}, {4, 5, 6, 6}, {7, 8, 9, 9}, {1, 1, 1, 1}};
double[][] matrixBArray = {{12.3, 12, 56, 1}, {34, 456, 234.2, 1}, {93, 3434, 1314, 1}, {1, 1, 1, 1}};
Matrix A = new Matrix(matrixAArray);
Matrix B = new Matrix(matrixBArray);
Matrix C = Algorithms.squareMatrixMultiplyRecursive(A, B);
System.out.println(C);
}
}
输出:
362.3 11229.0 4469.4 9.0
783.2 22938.0 9285.0 21.0
1204.1 34647.0 14100.6 33.0
140.3 3903.0 1605.2 4.0