# Python PDE solver: 2-D parabolic finite-difference scheme using the ADI (alternating direction implicit) method

#coding:utf-8
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
import time

def createMatrix( m, n):
	"""Build the initial (n + 2) x (m + 2) grid with fixed boundary values.

	The m x n interior starts at 0.  Row 0 is set to 100 and row n + 1 to 0;
	afterwards column 0 is set to 75 and column m + 1 to 50, so the four
	corner cells end up holding the left/right column values.

	Args:
		m: number of interior columns.
		n: number of interior rows.

	Returns:
		A float64 ndarray of shape (n + 2, m + 2).
	"""
	grid = np.zeros((n + 2, m + 2))
	# Rows first, columns second: the column writes win at the corners.
	grid[0, :] = 100
	grid[n + 1, :] = 0
	grid[:, 0] = 75
	grid[:, m + 1] = 50
	return grid

def oneIter(A,  r_lf, r_rt):
	"""Perform one ADI half-step: implicit along rows, explicit along columns.

	A is a 2-D grid of shape (n + 2, m + 2) whose first/last row and
	first/last column hold boundary values that are carried over unchanged.
	For every interior row j the tridiagonal system

	    (1 + r_lf) u[j, i] - r_lf/2 * (u[j, i-1] + u[j, i+1])
	        = r_rt/2 * A[j-1, i] + (1 - r_rt) * A[j, i] + r_rt/2 * A[j+1, i]

	is solved for u[j, 1..m], with the known boundary values A[j, 0] and
	A[j, m+1] moved to the right-hand side.

	Args:
		A: 2-D array of grid values; it is not modified.
		r_lf: mesh ratio for the implicit (row) direction.
		r_rt: mesh ratio for the explicit (column) direction.

	Returns:
		A new float array with A's shape: boundary copied from A, interior
		replaced by the half-step solution.
	"""
	B = np.array(A, dtype=float)  # float copy; boundary cells stay as in A
	n = B.shape[0] - 2
	m = B.shape[1] - 2
	if m <= 0 or n <= 0:
		return B  # no interior points to update

	# Tridiagonal implicit operator; identical for every interior row.
	M = (np.diag(np.full(m, 1.0 + r_lf))
		+ np.diag(np.full(m - 1, -0.5 * r_lf), 1)
		+ np.diag(np.full(m - 1, -0.5 * r_lf), -1))

	# Explicit part for all interior rows at once. It reads only the input
	# values, so the rows are independent and can be solved together.
	rhs = (0.5 * r_rt * B[:-2, 1:-1]
		+ (1.0 - r_rt) * B[1:-1, 1:-1]
		+ 0.5 * r_rt * B[2:, 1:-1])
	# Known boundary neighbours of the first and last interior column.
	# The right one is column m + 1; the old code wrongly used column m - 1.
	rhs[:, 0] += 0.5 * r_lf * B[1:-1, 0]
	rhs[:, -1] += 0.5 * r_lf * B[1:-1, -1]

	# Each column of rhs.T is one row's right-hand side.
	B[1:-1, 1:-1] = np.linalg.solve(M, rhs.T).T
	return B

def computeA(m, n , rx, ry, iter):
	"""Advance the initial grid by `iter` full ADI time steps.

	Starts from createMatrix(m, n).  Each step is an x-implicit half-step,
	oneIter(A, rx, ry), followed by a y-implicit half-step performed on the
	transposed grid, oneIter(A.T, ry, rx).

	Args:
		m: number of interior columns.
		n: number of interior rows.
		rx: mesh ratio used for the row-implicit half-step.
		ry: mesh ratio used for the column-implicit half-step.
		iter: number of full time steps.  The name shadows the builtin but
			is kept so keyword callers keep working.

	Returns:
		The grid after `iter` steps, shape (n + 2, m + 2).
	"""
	A = createMatrix(m,n)
	print('total iter=%s' % (iter))
	# 1-based step counter that runs exactly `iter` steps. The old
	# range(1, iter) ran one step fewer than the announced total.
	for i in range(1, iter + 1):
		print('iter num=%s' % (i))
		A = oneIter(A, rx, ry)
		B = oneIter(np.transpose(A), ry, rx)
		A = np.transpose(B)

	return A

def computeOneIter(A, m, n , rx, ry):
	"""Advance grid A by one full alternating-direction step.

	The first half-step is implicit along rows (ratio rx) and explicit along
	columns (ratio ry).  The second half-step applies the same routine to
	the transposed grid with the ratios swapped.  `m` and `n` are accepted
	for interface compatibility but are not used.

	Returns:
		The updated grid, with the same shape as A.
	"""
	half_step = oneIter(A, rx, ry)
	return oneIter(half_step.T, ry, rx).T

def getStart():
	"""Run and animate the 2-D ADI simulation.

	Domain [0, 20] x [0, 30], time interval [0, 10], steps deltax = 0.5 and
	deltay = 0.3, time step tao = min(deltax, deltay)**2 / 3.  The initial
	grid and its boundary values come from createMatrix.  After every time
	step the 3-D wireframe is redrawn.

	Returns:
		(A, x, y): the final grid, shape (len(y), len(x)), and the node
		coordinates along x and y, boundary nodes included.
	"""
	X_INTERVAL = [0,20]
	Y_INTERVAL = [0,30]
	T = [0,10]
	deltax = 0.5
	deltay = 0.3
	tao = 1.0 / 3 * min(deltax, deltay) * min(deltax, deltay)
	# Number of interior nodes per axis.  Round rather than truncate so a
	# quotient such as 99.99999999 still counts as 100 intervals.
	m = int(round((X_INTERVAL[1] - X_INTERVAL[0]) / deltax)) - 1
	n = int(round((Y_INTERVAL[1] - Y_INTERVAL[0]) / deltay)) - 1
	print('m=%s,n=%s' % (m,n))
	# m interior nodes plus 2 boundary nodes, so the spacing is exactly deltax
	# (likewise deltay).  This is consistent with the mesh ratios below.
	x = np.linspace(X_INTERVAL[0], X_INTERVAL[1], m + 2)
	y = np.linspace(Y_INTERVAL[0], Y_INTERVAL[1], n + 2)
	rx = tao / deltax / deltax
	ry = tao / deltay / deltay

	#animation
	fig = plt.figure()
	ax = fig.add_subplot(111, projection='3d')
	X, Y = np.meshgrid(x, y)  # shape (n + 2, m + 2), matching A

	wframe = None

	n_steps = int((T[1] - T[0])/tao)
	A = createMatrix(m, n)
	for i in range(n_steps):
		A = computeOneIter(A, m, n, rx, ry)
		if wframe is not None:
			# Artist.remove() works across Matplotlib versions;
			# ax.collections.remove() does not in newer releases.
			wframe.remove()

		wframe = ax.plot_wireframe(X, Y, A, rstride=2, cstride=2)
		plt.pause(0.01)
		print('iter= %s' % i)

	return A,x,y



if __name__ == '__main__':
	# Run the animated simulation only when executed as a script, not on import.
	getStart()


# (Web-page footer residue: "You may also be interested in: interest")