今天遇到一个问题。老师让用C++
实现三次样条插值函数,输出多项式系数,其中有一个公式长这样 S ( x ) = m i ( x − x i ) ( x − x i + 1 ) 2 h i 2 + m i + 1 ( x − x i ) 2 ( x − x i + 1 ) h i 2 + y i ( x − x i + 1 ) 2 ( h i + 2 x − 2 x i ) h i 3 + y i + 1 ( x − x i ) 2 ( h i − 2 x + 2 x i + 1 ) h i 3 S(x)=\frac{m_{i} \left(x - x_{i}\right) \left(x - x_{i+1}\right)^{2}}{h_{i}^{2}} + \frac{m_{i+1} \left(x - x_{i}\right)^{2} \left(x - x_{i+1}\right)}{h_{i}^{2}} + \frac{y_{i} \left(x - x_{i+1}\right)^{2} \left(h_{i} + 2 x - 2 x_{i}\right)}{h_{i}^{3}} + \frac{y_{i+1} \left(x - x_{i}\right)^{2} \left(h_{i} - 2 x + 2 x_{i+1}\right)}{h_{i}^{3}} S(x)=hi2mi(x−xi)(x−xi+1)2+hi2mi+1(x−xi)2(x−xi+1)+hi3yi(x−xi+1)2(hi+2x−2xi)+hi3yi+1(x−xi)2(hi−2x+2xi+1)对应的 LaTeX \LaTeX LATEX代码为
S(x)=\frac{m_{i} \left(x - x_{i}\right) \left(x - x_{i+1}\right)^{2}}{h_{i}^{2}} + \frac{m_{i+1} \left(x - x_{i}\right)^{2} \left(x - x_{i+1}\right)}{h_{i}^{2}} + \frac{y_{i} \left(x - x_{i+1}\right)^{2} \left(h_{i} + 2 x - 2 x_{i}\right)}{h_{i}^{3}} + \frac{y_{i+1} \left(x - x_{i}\right)^{2} \left(h_{i} - 2 x + 2 x_{i+1}\right)}{h_{i}^{3}}
要求出各项系数,就必须把上面这个贼长的式子化成关于x
的多项式。……那得累死。怎么办呢?
Python
里有一个模块叫sympy
,是做符号计算的。安装方法为:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple sympy
sympy
里面有一个函数collect
,作用是将关于一个/几个变量的同次项合并在一起。函数定义为:
sympy.simplify.radsimp.collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True)
其中,expr
是要化简的表达式,syms
指定化简成关于谁的多项式(可以是一个变量或是变量的列表),evaluate
表示是否对结果赋值。若evaluate = True
,则直接返回化简后的多项式;若evaluate = False
,则返回一个字典,存储每一项的系数。
例如,要把 a x 2 + 3 x + c + 1 b x 2 + 9 + d 3 x ax^2+3x+c+\frac1b x^2+9+\frac d3x ax2+3x+c+b1x2+9+3dx化为关于 x x x的多项式,可以这么写(下面的结果是在IPython
中运行得到的):
In [4]: from sympy import *
In [5]: a, b, c, d, x = symbols('a, b, c, d, x')
In [6]: collect(a * x ** 2 + 3 * x + c + 1 / b * x ** 2 + 9 + d / 3 * x, x)
Out[6]: c + x**2*(a + 1/b) + x*(d/3 + 3) + 9
结果是 ( a + 1 b ) x 2 + ( d 3 + 3 ) x + c + 9 \left(a+\frac1b\right)x^2+\left(\frac d3+3\right)x+c+9 (a+b1)x2+(3d+3)x+c+9。
若设evaluate = False
:
In [8]: p = collect(a * x ** 2 + 3 * x + c + 1 / b * x ** 2 + 9 + d / 3 * x, x, evaluate = False)
In [9]: p
Out[9]: {x: d/3 + 3, x**2: a + 1/b, 1: c + 9}
In [10]: p[x ** 2] # x平方项的系数
Out[10]: a + 1/b
In [11]: p[1] # 常数项
Out[11]: c + 9
也可以用它来化简多变量的表达式。
In [12]: y = Symbol('y')
In [13]: collect(x ** 2 + y * x ** 2 + x * y + y + a * y, [x, y])
Out[13]: x**2*(y + 1) + x*y + y*(a + 1)
In [14]: collect(x ** 2 + y * x ** 2 + x * y + y + a * y, [x, y], evaluate = False)
Out[14]: {y: a + 1, x**2: y + 1, x: y}
注意这个算法是贪心的,它同一时刻只处理一个变量。
还有更高级的用法:
In [15]: collect(a * sin(2 * x) + b * sin(2 * x), sin(2 * x))
Out[15]: (a + b)*sin(2*x)
In [16]: collect(a * x * log(x) + b * (x * log(x)) + x, x * log(x))
Out[16]: x*(a + b)*log(x) + x
现在解决我们一开始提出的问题。注意,要先调用expand
展开括号中的东西才能继续化简。
In [18]: S, h_i, x_i, x_j, y_i, y_j, m_i, m_j = symbols('S, h_i, x_i, x_j, y_i, y_j, m_i, m_j')
In [19]: S = (h_i + 2 * (x - x_i)) * (x - x_j) ** 2 * y_i / h_i ** 3 + \
...: (h_i - 2 * (x - x_j)) * (x - x_i) ** 2 * y_j / h_i ** 3 + \
...: (x - x_i) * (x - x_j) ** 2 * m_i / h_i ** 2 + \
...: (x - x_j) * (x - x_i) ** 2 * m_j / h_i ** 2
In [20]: collect(expand(S), x)
Out[20]: x**3*(m_i/h_i**2 + m_j/h_i**2 + 2*y_i/h_i**3 - 2*y_j/h_i**3) + x**2*(-m_i*x_i/h_i**2 - 2*m_i*x_j/h_i**2 - 2*m_j*x_i/h_i**2 - m_j*x_j/h_i**2 + y_i/h_i**2 + y_j/h_i**2 - 2*x_i*y_i/h_i**3 + 4*x_i*y_j/h_i**3 - 4*x_j*y_i/h_i**3 + 2*x_j*y_j/h_i**3) + x*(2*m_i*x_i*x_j/h_i**2 + m_i*x_j**2/h_i**2 + m_j*x_i**2/h_i**2 + 2*m_j*x_i*x_j/h_i**2 - 2*x_i*y_j/h_i**2 - 2*x_j*y_i/h_i**2 - 2*x_i**2*y_j/h_i**3 + 4*x_i*x_j*y_i/h_i**3 - 4*x_i*x_j*y_j/h_i**3 + 2*x_j**2*y_i/h_i**3) - m_i*x_i*x_j**2/h_i**2 - m_j*x_i**2*x_j/h_i**2 + x_i**2*y_j/h_i**2 + x_j**2*y_i/h_i**2 + 2*x_i**2*x_j*y_j/h_i**3 - 2*x_i*x_j**2*y_i/h_i**3
稍稍化简之后得 S ( x ) = h i ( m i + m i + 1 ) + 2 y i − 2 y i + 1 h i 3 x 3 + h i ( − m i x i − 2 m i x i + 1 − 2 m i + 1 x i − m i + 1 x i + 1 + y i + y i + 1 ) − 2 x i y i + 4 x i y i + 1 − 4 x i + 1 y i + 2 x i + 1 y i + 1 h i 3 x 2 + h i ( 2 m i x i x i + 1 + m i x i + 1 2 + m i + 1 x i 2 + 2 m i + 1 x i x i + 1 − 2 x i y i + 1 − 2 x i + 1 y i ) − 2 x i 2 y i + 1 + 4 x i x i + 1 y i − 4 x i x i + 1 y i + 1 + 2 x i + 1 2 y i h i 3 x + h i ( − m i x i x i + 1 2 − m i + 1 x i 2 x i + 1 + x i 2 y i + 1 + x i + 1 2 y i ) + 2 x i x i + 1 ( x i y i + 1 − x i + 1 y i ) h i 3 S\left(x\right)=\frac{h_i\left(m_i+m_{i+1}\right)+2y_i-2y_{i+1}}{h_i^3}x^3+\frac{h_i\left(-m_ix_i-2m_ix_{i+1}-2m_{i+1}x_i-m_{i+1}x_{i+1}+y_i+y_{i+1}\right)-2x_iy_i+4x_iy_{i+1}-4x_{i+1}y_i+2x_{i+1}y_{i+1}}{h_i^3}x^2+\frac{h_i\left(2m_ix_ix_{i+1}+m_ix_{i+1}^2+m_{i+1}x_i^2+2m_{i+1}x_ix_{i+1}-2x_iy_{i+1}-2x_{i+1}y_i\right)-2x_i^2y_{i+1}+4x_ix_{i+1}y_i-4x_ix_{i+1}y_{i+1}+2x_{i+1}^2y_i}{h_i^3}x+\frac{h_i\left(-m_ix_ix_{i+1}^2-m_{i+1}x_i^2x_{i+1}+x_i^2y_{i+1}+x_{i+1}^2y_i\right)+2x_ix_{i+1}\left(x_iy_{i+1}-x_{i+1}y_i\right)}{h_i^3} S(x)=hi3hi(mi+mi+1)+2yi−2yi+1x3+hi3hi(−mixi−2mixi+1−2mi+1xi−mi+1xi+1+yi+yi+1)−2xiyi+4xiyi+1−4xi+1yi+2xi+1yi+1x2+hi3hi(2mixixi+1+mixi+12+mi+1xi2+2mi+1xixi+1−2xiyi+1−2xi+1yi)−2xi2yi+1+4xixi+1yi−4xixi+1yi+1+2xi+12yix+hi3hi(−mixixi+12−mi+1xi2xi+1+xi2yi+1+xi+12yi)+2xixi+1(xiyi+1−xi+1yi)
对应的 LaTeX \LaTeX LATEX代码为
S\left(x\right)=\frac{h_i\left(m_i+m_{i+1}\right)+2y_i-2y_{i+1}}{h_i^3}x^3+\frac{h_i\left(-m_ix_i-2m_ix_{i+1}-2m_{i+1}x_i-m_{i+1}x_{i+1}+y_i+y_{i+1}\right)-2x_iy_i+4x_iy_{i+1}-4x_{i+1}y_i+2x_{i+1}y_{i+1}}{h_i^3}x^2+\frac{h_i\left(2m_ix_ix_{i+1}+m_ix_{i+1}^2+m_{i+1}x_i^2+2m_{i+1}x_ix_{i+1}-2x_iy_{i+1}-2x_{i+1}y_i\right)-2x_i^2y_{i+1}+4x_ix_{i+1}y_i-4x_ix_{i+1}y_{i+1}+2x_{i+1}^2y_i}{h_i^3}x+\frac{h_i\left(-m_ix_ix_{i+1}^2-m_{i+1}x_i^2x_{i+1}+x_i^2y_{i+1}+x_{i+1}^2y_i\right)+2x_ix_{i+1}\left(x_iy_{i+1}-x_{i+1}y_i\right)}{h_i^3}
至此,问题圆满解决。