import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import fsolve
class odessolver:
    """Fixed-step integrator for first-order ODE systems y'(x) = f(x, Y).

    The numerical solution is kept in ``self.Y`` with shape ``(n, len(X))``:
    row ``k`` is the k-th component of the system, column ``j`` is its value
    at grid point ``X[j]``.

    Parameters
    ----------
    f : callable
        ``f(x, Y)`` returning the derivative of ``Y`` as an array of length n.
    Y_start : array_like, optional
        Initial state at ``X_start``; defaults to ``[0, 1]``.
    dY_start : array_like, optional
        Initial derivative, used only to pre-fill the second grid column with
        a forward-Euler guess (both solvers overwrite that column).
        Defaults to zeros of the same length as ``Y_start``.
    X_start, X_end : float
        Integration interval. The grid is ``np.arange(X_start, X_end, h)``,
        so ``X_end`` itself is excluded.
    h : float
        Step size.
    """

    def __init__(self, f, Y_start=None, dY_start=None,
                 X_start=0, X_end=1, h=0.01):
        # None defaults instead of ndarray defaults, so no array object is
        # shared between instances; asarray also accepts lists/tuples.
        Y_start = np.array([0, 1]) if Y_start is None else np.asarray(Y_start)
        dY_start = (np.zeros(Y_start.size) if dY_start is None
                    else np.asarray(dY_start))
        self.f = f
        self.h = h
        self.X = np.arange(X_start, X_end, self.h)
        self.n = Y_start.size
        # First index: component of the system; second index: grid point.
        self.Y = np.zeros((self.n, self.X.size))
        self.Y[:, 0] = Y_start
        if self.X.size > 1:
            # A one-point grid has no second column to pre-fill.
            self.Y[:, 1] = Y_start + self.h * dY_start
        # NOTE(review): not read by RK4/IRK4 (IRK4 keeps fsolve's default,
        # tighter xtol); kept because callers may rely on the attribute.
        self.tol = 1e-6

    def __str__(self):
        return f"y'(x) = f(x) = ({self.f}) variables"

    def RK4(self):
        """Integrate with the classical explicit 4th-order Runge-Kutta method.

        Fills ``self.Y`` in place and returns an independent copy of it, so
        results of successive solver calls on one instance do not overwrite
        each other.
        """
        h = self.h
        for i in range(1, self.X.size):
            x, y = self.X[i-1], self.Y[:, i-1]
            k1 = self.f(x, y)
            k2 = self.f(x + h/2, y + h/2*k1)
            k3 = self.f(x + h/2, y + h/2*k2)
            k4 = self.f(x + h, y + h*k3)
            self.Y[:, i] = y + h/6 * (k1 + 2*k2 + 2*k3 + k4)
        return self.Y.copy()

    def IRK4(self):
        """Integrate with the implicit 2-stage Gauss-Legendre Runge-Kutta method.

        Each step solves the coupled stage equations
            k1 = f(x + c1*h, y + h*(a11*k1 + a12*k2))
            k2 = f(x + c2*h, y + h*(a21*k1 + a22*k2))
        for (k1, k2) with ``scipy.optimize.fsolve`` and then advances
        ``y <- y + h/2*(k1 + k2)``.

        Fills ``self.Y`` in place and returns an independent copy of it.
        """
        h, n = self.h, self.n
        r3 = 3**0.5
        # Butcher tableau of the 2-stage Gauss-Legendre method (order 4).
        c1, c2 = (3 - r3)/6, (3 + r3)/6
        a11, a12 = 1/4, (3 - 2*r3)/12
        a21, a22 = (3 + 2*r3)/12, 1/4

        def stage_residual(k, x, y):
            # k stacks the two stage slopes: k1 = k[:n], k2 = k[n:].
            k1, k2 = k[:n], k[n:]
            f1 = np.asarray(self.f(x + c1*h, y + h*(a11*k1 + a12*k2)))[:n]
            f2 = np.asarray(self.f(x + c2*h, y + h*(a21*k1 + a22*k2)))[:n]
            return np.concatenate([k1 - f1, k2 - f2])

        for i in range(1, self.X.size):
            x, y = self.X[i-1], self.Y[:, i-1]
            # The explicit slope at the start of the step is usually a closer
            # starting point for the nonlinear solve than zeros.
            slope = np.asarray(self.f(x, y), dtype=float)[:n]
            k = fsolve(stage_residual, np.concatenate([slope, slope]),
                       args=(x, y))
            self.Y[:, i] = y + h/2 * (k[:n] + k[n:])
        return self.Y.copy()
# Model problem from reference [1]: y'' = Lambda**2 * Q(x) * y with
# y(0) = A, y'(0) = B. It is integrated numerically with RK4 and compared
# with its leading-order WKB approximation.
A = 0
B = 1
Lambda = 1
Q = lambda x: (1 + x**2)**2
Y0 = np.array([A, B])


def test_fun(x, Y):
    """Right-hand side of the first-order system for Y = [y, y']."""
    return np.array([Y[1], Lambda**2 * Q(x) * Y[0]])


c = odessolver(test_fun, Y_start=Y0)
y3 = c.RK4()
# Use the solver's own grid so x and y3 always have matching lengths.
x = c.X
plt.plot(x, y3[0, :], label="RK4")
# Optional comparison with the implicit scheme:
# plt.plot(x, c.IRK4()[0, :], label="IRK4")


def WKB(x):
    """Leading-order WKB approximation of y(x) with y(0) = A, y'(0) = B.

    Here sqrt(Q) = 1 + x**2, so the phase is S(x) = Lambda*(x + x**3/3) and
    the amplitude is Q**(-1/4) = 1/sqrt(1 + x**2). The amplitude's
    derivative vanishes at x = 0, so the initial conditions fix the
    cosh/sinh coefficients directly. Lambda must appear in the exponent;
    for A = 0, B = 1 this reduces to sinh(S)/(Lambda*sqrt(1 + x**2)).
    """
    S = Lambda * (x + x**3 / 3)
    return (A * np.cosh(S) + B / Lambda * np.sinh(S)) / np.sqrt(1 + x**2)


plt.plot(x, WKB(x), label="WKB")
plt.legend()
# plt.pause(0.01) displayed the figure for only ~10 ms before a plain script
# exited; show() keeps the window open until it is closed.
plt.show()
# [1] 李家春, 周显初. 数学物理中的渐近方法 (Asymptotic Methods in Mathematical Physics). 科学出版社 (Science Press).