线性方程组可以利用矩阵的初等行变换求解。利用高斯消元法,先把方程组的增广矩阵化为行阶梯矩阵,再进一步化为简化行阶梯矩阵,此时矩阵的最后一列就是方程组的解。参考资料(高斯消元法)
Java代码
/**
 * Solves systems of linear equations with Gauss-Jordan elimination.
 *
 * <p>{@link LinearEquationGroupBuilder} turns textual equations into an augmented matrix and
 * {@link LinearEquationGroup} reduces that matrix to reduced row echelon form, whose last column
 * holds the solution.
 */
public class FunctionResolver {

    /**
     * An {@code n x (n + 1)} augmented matrix of a linear system together with the names of its
     * unknowns.
     *
     * <p>The system is stored in normalized form: every term containing an unknown is on the left
     * of "=", and all constants are merged on the right. For example
     *
     * <pre>
     *  2a + b -  c =   8
     * -3a - b + 2c = -11
     * -2a + b + 2c =  -3
     * </pre>
     *
     * is stored as
     *
     * <pre>
     *  2  1 -1   8
     * -3 -1  2 -11
     * -2  1  2  -3
     * </pre>
     *
     * <p>The matrix array is neither copied on the way in nor on the way out, and
     * {@link #solve()} reduces it in place.
     */
    public static class LinearEquationGroup {

        private static final String BAD_SHAPE = "输入矩阵有误! 必须为n*(n+1)矩阵";
        private static final String NO_UNIQUE_SOLUTION = "方程组无解或者无唯一解!";

        /** Precision used for a quotient that has no finite decimal expansion (e.g. 1/3). */
        private static final java.math.MathContext WORKING_PRECISION = new java.math.MathContext(50);

        /** Precision that rounded results are reported with; the extra working digits absorb error. */
        private static final java.math.MathContext RESULT_PRECISION = java.math.MathContext.DECIMAL128;

        private BigDecimal[][] matrix;

        /** Names of the unknowns, in matrix column order (a, b, c for the example above). */
        private List<String> variantes;

        /** Creates an empty group; the matrix and the variable names must be supplied before solving. */
        public LinearEquationGroup() {
        }

        /**
         * Creates a group over the given augmented matrix.
         *
         * @param matrix an {@code n x (n + 1)} matrix; the array is used as-is, not copied
         * @throws IllegalArgumentException if the matrix is null, mis-shaped or contains null entries
         */
        public LinearEquationGroup(BigDecimal[][] matrix) {
            this.matrix = matrix;
            check();
        }

        /** Verifies that the matrix is a fully populated {@code n x (n + 1)} matrix. */
        private void check() {
            if (matrix == null) {
                throw new IllegalArgumentException(BAD_SHAPE);
            }
            for (BigDecimal[] row : matrix) {
                if (row == null || row.length != matrix.length + 1) {
                    throw new IllegalArgumentException(BAD_SHAPE);
                }
                for (BigDecimal cell : row) {
                    if (cell == null) {
                        throw new IllegalArgumentException("matrix must not contain null entries");
                    }
                }
            }
        }

        /**
         * Solves the system, reducing the matrix in place.
         *
         * <p>Values that needed a rounded division are reported with at most 34 significant digits.
         *
         * @return a map from variable name to its value
         * @throws IllegalArgumentException if the matrix is mis-shaped, or the system has no solution
         *     or no unique solution
         * @throws IllegalStateException if the variable names are missing or their count does not
         *     match the number of equations
         */
        public Map<String, BigDecimal> solve() {
            check();
            if (variantes == null || variantes.size() != matrix.length) {
                String actual = variantes == null ? "none" : String.valueOf(variantes.size());
                throw new IllegalStateException(
                        "expected " + matrix.length + " variable names but got " + actual);
            }
            sortRows();
            eliminateVarianteDownwards();
            normalize();
            eliminateVarianteUpwards();

            Map<String, BigDecimal> solution = new HashMap<>();
            int constantCol = matrix.length;
            int row = 0;
            for (String name : variantes) {
                solution.put(name, tidy(matrix[row++][constantCol]));
            }
            return solution;
        }

        /**
         * Best-effort pre-pass that swaps a row with a zero on the diagonal for the first row below
         * it that has a non-zero entry in that column.
         *
         * <p>This step never rejects a system: a zero that cannot be repaired here may still become
         * non-zero during elimination, so singularity is decided in
         * {@link #eliminateVarianteDownwards()} and {@link #normalize()}.
         */
        void sortRows() {
            for (int row = 0; row < matrix.length - 1; row++) {
                if (matrix[row][row].signum() != 0) {
                    continue;
                }
                int candidate = firstNonZeroRow(row, row + 1);
                if (candidate >= 0) {
                    swapRows(row, candidate);
                }
            }
        }

        /**
         * Eliminates from top to bottom so that column {@code i} is zero in every row below row
         * {@code i}.
         *
         * <p>If the diagonal entry is zero when its turn comes, the first row below with a non-zero
         * entry in that column is swapped in.
         *
         * @throws IllegalArgumentException if a column has no usable pivot, i.e. the system has no
         *     solution or no unique solution
         */
        void eliminateVarianteDownwards() {
            int rowCnt = matrix.length;
            for (int baseRow = 0; baseRow < rowCnt - 1; baseRow++) {
                int pivotRow = firstNonZeroRow(baseRow, baseRow);
                if (pivotRow < 0) {
                    throw new IllegalArgumentException(NO_UNIQUE_SOLUTION);
                }
                if (pivotRow != baseRow) {
                    swapRows(baseRow, pivotRow);
                }
                for (int targetRow = baseRow + 1; targetRow < rowCnt; targetRow++) {
                    eliminate(targetRow, baseRow, baseRow);
                }
            }
        }

        /**
         * Scales every row so that its diagonal entry becomes 1; afterwards the last column of the
         * reduced matrix is the solution.
         *
         * @throws IllegalArgumentException if a diagonal entry is zero, i.e. the system has no
         *     solution or no unique solution
         */
        void normalize() {
            for (int i = 0; i < matrix.length; i++) {
                BigDecimal[] row = matrix[i];
                BigDecimal pivot = row[i];
                if (pivot.signum() == 0) {
                    throw new IllegalArgumentException(NO_UNIQUE_SOLUTION);
                }
                if (pivot.compareTo(BigDecimal.ONE) == 0) {
                    continue;
                }
                // Entries left of the diagonal are already zero after the downward pass.
                for (int j = i; j < row.length; j++) {
                    row[j] = divide(row[j], pivot);
                }
            }
        }

        /**
         * Eliminates from bottom to top so that column {@code i} is zero in every row above row
         * {@code i}.
         */
        void eliminateVarianteUpwards() {
            for (int baseRow = matrix.length - 1; baseRow > 0; baseRow--) {
                for (int targetRow = baseRow - 1; targetRow >= 0; targetRow--) {
                    eliminate(targetRow, baseRow, baseRow);
                }
            }
        }

        /** Adds a multiple of {@code baseRow} to {@code targetRow} so that the target's {@code col} becomes zero. */
        private void eliminate(int targetRow, int baseRow, int col) {
            BigDecimal[] target = matrix[targetRow];
            if (target[col].signum() == 0) {
                return;
            }
            BigDecimal[] base = matrix[baseRow];
            BigDecimal factor = divide(target[col], base[col]).negate();
            for (int j = 0; j < target.length; j++) {
                target[j] = target[j].add(factor.multiply(base[j]));
            }
            // A rounded factor can leave a tiny residue; the eliminated entry is zero by construction.
            target[col] = BigDecimal.ZERO;
        }

        /** Returns the first row at or after {@code fromRow} whose entry in {@code col} is non-zero, or -1. */
        private int firstNonZeroRow(int col, int fromRow) {
            for (int row = fromRow; row < matrix.length; row++) {
                if (matrix[row][col].signum() != 0) {
                    return row;
                }
            }
            return -1;
        }

        private void swapRows(int first, int second) {
            BigDecimal[] held = matrix[first];
            matrix[first] = matrix[second];
            matrix[second] = held;
        }

        /**
         * Divides exactly when the quotient has a finite decimal expansion and otherwise rounds it
         * to {@link #WORKING_PRECISION}.
         */
        private static BigDecimal divide(BigDecimal dividend, BigDecimal divisor) {
            try {
                return dividend.divide(divisor);
            } catch (ArithmeticException ignored) {
                // BigDecimal reports a non-terminating expansion only by throwing, so the rounded
                // division is the fallback. A zero divisor fails again here and propagates.
                return dividend.divide(divisor, WORKING_PRECISION);
            }
        }

        /** Rounds values longer than {@link #RESULT_PRECISION} and drops the zeros rounding leaves behind. */
        private static BigDecimal tidy(BigDecimal value) {
            if (value.precision() <= RESULT_PRECISION.getPrecision()) {
                return value;
            }
            BigDecimal rounded = value.round(RESULT_PRECISION).stripTrailingZeros();
            return rounded.scale() < 0 ? rounded.setScale(0) : rounded;
        }

        /** Returns the backing matrix itself, not a copy. */
        public BigDecimal[][] getMatrix() {
            return matrix;
        }

        public List<String> getVariantes() {
            return variantes;
        }

        public void setVariantes(List<String> variantes) {
            this.variantes = variantes;
        }

        /** Two groups are equal when their matrices are deeply equal; variable names are not compared. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            LinearEquationGroup that = (LinearEquationGroup) o;
            return Arrays.deepEquals(matrix, that.matrix);
        }

        @Override
        public int hashCode() {
            return Arrays.deepHashCode(matrix);
        }
    }

    /**
     * Parses a system of equations given as a map: each entry is one equation, the key being the
     * text left of "=" and the value the text right of "=". Unknowns are written as letter strings
     * (digits may follow, e.g. {@code x1}); terms are joined with {@code +} and {@code -}.
     *
     * <p>Because the left-hand side is the map key, two equations with identical left-hand text
     * cannot be expressed in one map.
     */
    public static class LinearEquationGroupBuilder {

        /** One parsed term: a coefficient and, unless the term is a constant, a variable name. */
        private static final class Term {
            final BigDecimal coefficient;
            /** Variable name, or null when the term is a plain number. */
            final String variable;

            Term(BigDecimal coefficient, String variable) {
                this.coefficient = coefficient;
                this.variable = variable;
            }
        }

        private final Map<String, List<String>> keyTokens = new HashMap<>();
        private final Map<String, List<String>> valueTokens = new HashMap<>();

        /**
         * Builds the augmented matrix and variable list for the given equations.
         *
         * <p>Variables are numbered in the order they are first met while iterating the map.
         *
         * @param funcMap the equations, left-hand side mapped to right-hand side
         * @return a group ready for {@link LinearEquationGroup#solve()}
         * @throws IllegalArgumentException if the map or one of its sides is null, or a term is empty
         * @throws NumberFormatException if a numeric prefix is not a valid number
         */
        public LinearEquationGroup createFrom(Map<String, String> funcMap) {
            if (funcMap == null) {
                throw new IllegalArgumentException("funcMap must not be null");
            }
            keyTokens.clear();
            valueTokens.clear();
            for (Map.Entry<String, String> equation : funcMap.entrySet()) {
                if (equation.getKey() == null || equation.getValue() == null) {
                    throw new IllegalArgumentException("equation sides must not be null");
                }
                keyTokens.put(equation.getKey(), tokenize(equation.getKey()));
                valueTokens.put(equation.getValue(), tokenize(equation.getValue()));
            }

            Map<String, Integer> vars = findVariantes(funcMap);
            // Place each name at its own index; the map's iteration order is unrelated to the indices.
            String[] names = new String[vars.size()];
            for (Map.Entry<String, Integer> var : vars.entrySet()) {
                names[var.getValue()] = var.getKey();
            }

            BigDecimal[][] matrix = new BigDecimal[funcMap.size()][vars.size() + 1];
            for (BigDecimal[] row : matrix) {
                Arrays.fill(row, BigDecimal.ZERO);
            }
            buildMatrix(matrix, funcMap, vars);

            LinearEquationGroup linearEquationGroup = new LinearEquationGroup();
            linearEquationGroup.matrix = matrix;
            linearEquationGroup.setVariantes(new ArrayList<>(Arrays.asList(names)));
            return linearEquationGroup;
        }

        /** Assigns a column index to every variable, in order of first appearance. */
        private Map<String, Integer> findVariantes(Map<String, String> fucMap) {
            Map<String, Integer> indexByName = new HashMap<>();
            for (Map.Entry<String, String> equation : fucMap.entrySet()) {
                parseVariantes(indexByName, keyTokens.get(equation.getKey()));
                parseVariantes(indexByName, valueTokens.get(equation.getValue()));
            }
            return indexByName;
        }

        private void parseVariantes(Map<String, Integer> indexByName, List<String> tokens) {
            for (String token : tokens) {
                Term term = parseTerm(token);
                if (term.variable != null && !indexByName.containsKey(term.variable)) {
                    indexByName.put(term.variable, indexByName.size());
                }
            }
        }

        /**
         * Splits one side of an equation into signed terms. A {@code +} only separates terms, while
         * a {@code -} stays attached to the term that follows it; blank pieces are dropped.
         */
        private List<String> tokenize(String expression) {
            List<String> terms = new ArrayList<>();
            int start = 0;
            for (int i = 0; i < expression.length(); i++) {
                char c = expression.charAt(i);
                if (c == '+' || c == '-') {
                    addTerm(terms, expression.substring(start, i));
                    start = (c == '+') ? i + 1 : i;
                }
            }
            addTerm(terms, expression.substring(start));
            return terms;
        }

        private static void addTerm(List<String> terms, String raw) {
            String trimmed = raw.trim();
            if (!trimmed.isEmpty()) {
                terms.add(trimmed);
            }
        }

        /**
         * Parses a coefficient and a variable name.
         *
         * @param part a number with a variable, a plain number, or a plain variable, e.g.
         *     {@code 2a}, {@code 1.5b}, {@code 7}, {@code c}, {@code -3d}, {@code -e}, {@code 2x1}
         * @return a two-element array: element 0 is the coefficient as a {@code Double} (1 or -1
         *     when none is written), element 1 is the variable name or null for a plain number
         * @throws IllegalArgumentException if the term is empty
         */
        Object[] splitFactorAndVar(String part) {
            Term term = parseTerm(part);
            return new Object[] {term.coefficient.doubleValue(), term.variable};
        }

        /** Parses one term, keeping the coefficient as the exact decimal that was written. */
        private static Term parseTerm(String part) {
            String text = part.trim();
            boolean negative = !text.isEmpty() && text.charAt(0) == '-';
            if (negative) {
                text = text.substring(1).trim();
            }
            if (text.isEmpty()) {
                throw new IllegalArgumentException("empty term: \"" + part + "\"");
            }
            int numberEnd = numericPrefixLength(text);
            BigDecimal coefficient =
                    numberEnd == 0 ? BigDecimal.ONE : new BigDecimal(text.substring(0, numberEnd));
            String variable = text.substring(numberEnd).trim();
            return new Term(
                    negative ? coefficient.negate() : coefficient,
                    variable.isEmpty() ? null : variable);
        }

        /**
         * Returns the length of the leading number in {@code text}: digits, an optional fraction and
         * an optional unsigned exponent (so {@code 2e3} reads as 2000, not as 2 times {@code e3}).
         */
        private static int numericPrefixLength(String text) {
            int length = text.length();
            int end = skipDigits(text, 0);
            if (end + 1 < length && text.charAt(end) == '.' && isAsciiDigit(text.charAt(end + 1))) {
                end = skipDigits(text, end + 1);
            }
            if (end > 0 && end + 1 < length
                    && (text.charAt(end) == 'e' || text.charAt(end) == 'E')
                    && isAsciiDigit(text.charAt(end + 1))) {
                end = skipDigits(text, end + 1);
            }
            return end;
        }

        private static int skipDigits(String text, int from) {
            int index = from;
            while (index < text.length() && isAsciiDigit(text.charAt(index))) {
                index++;
            }
            return index;
        }

        private static boolean isAsciiDigit(char c) {
            return c >= '0' && c <= '9';
        }

        /** Fills one matrix row per equation, in the map's iteration order. */
        private void buildMatrix(
                BigDecimal[][] matric, Map<String, String> funcMap, Map<String, Integer> varIds) {
            int row = 0;
            for (Map.Entry<String, String> equation : funcMap.entrySet()) {
                processOneFunction(matric[row++], equation.getKey(), equation.getValue(), varIds);
            }
        }

        /**
         * Accumulates one equation into a matrix row.
         *
         * @param matricRow the row to fill; its last element is the constant column
         * @param key the text left of "="
         * @param value the text right of "="
         * @param varIds column index of every variable
         * @throws IllegalArgumentException if a variable is missing from {@code varIds}
         */
        void processOneFunction(
                BigDecimal[] matricRow, String key, String value, Map<String, Integer> varIds) {
            processOneside(matricRow, key, varIds, true);
            processOneside(matricRow, value, varIds, false);
        }

        /**
         * Adds the terms of one side to the row. Variable terms are moved to the left of "=" and
         * constants to the right, so a term changes sign when it has to cross over.
         */
        private void processOneside(
                BigDecimal[] matricRow, String key, Map<String, Integer> varIds, boolean isKey) {
            // Tokens are normally cached by createFrom; compute them for direct callers.
            List<String> parts = (isKey ? keyTokens : valueTokens).computeIfAbsent(key, this::tokenize);
            int constantCol = matricRow.length - 1;
            for (String part : parts) {
                Term term = parseTerm(part);
                int col;
                BigDecimal contribution;
                if (term.variable != null) {
                    Integer idx = varIds.get(term.variable);
                    if (idx == null) {
                        throw new IllegalArgumentException("can't found " + term.variable + " in varIds");
                    }
                    col = idx;
                    contribution = isKey ? term.coefficient : term.coefficient.negate();
                } else {
                    col = constantCol;
                    contribution = isKey ? term.coefficient.negate() : term.coefficient;
                }
                matricRow[col] = matricRow[col] == null ? contribution : matricRow[col].add(contribution);
            }
        }
    }

    /** Solves the three-equation example from the {@link LinearEquationGroup} documentation. */
    public static void main(String[] args) {
        final FunctionResolver.LinearEquationGroupBuilder linearEquationGroupBuilder =
                new FunctionResolver.LinearEquationGroupBuilder();
        Map<String, String> fMap = new HashMap<>();
        fMap.put("2a+b-c", "8");
        fMap.put("-3a-b+2c", "-11");
        fMap.put("-2a+b+2c", "-3");
        FunctionResolver.LinearEquationGroup equationGroup = linearEquationGroupBuilder.createFrom(fMap);
        final Map<String, BigDecimal> solve = equationGroup.solve();
        solve.forEach((k, v) -> System.out.println(k + "=" + v));
    }
}
复杂度分析
该算法的时间复杂度为O(n^3),空间复杂度为O(n^2)。对于维度不高的线性方程组,这样的开销是可以接受的。