#coding:utf-8
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
import time
def createMatrix(m, n, up=100.0, down=0.0, left=75.0, right=50.0):
    """Build the initial grid: zero interior surrounded by fixed boundary values.

    The grid has ``n + 2`` rows and ``m + 2`` columns, i.e. ``n`` x ``m``
    interior points (initialised to 0) plus one ring of boundary values.

    Args:
        m: number of interior points per row (interior columns).
        n: number of interior rows.
        up: value written to the first row (row 0).
        down: value written to the last row (row n + 1).
        left: value written to the first column (column 0).
        right: value written to the last column (column m + 1).

    Returns:
        float ndarray of shape (n + 2, m + 2). The side columns are written
        after the top/bottom rows, so the four corners take the left/right
        values.
    """
    A = np.zeros((n + 2, m + 2))
    A[0, :] = up
    A[n + 1, :] = down
    # Columns are written last on purpose: corners take the left/right values.
    A[:, 0] = left
    A[:, m + 1] = right
    return A
def oneIter(A, r_lf, r_rt):
    """Advance the grid by one ADI half-step (implicit along rows).

    The interior points ``A[1:-1, 1:-1]`` are the unknowns; the outer ring is
    fixed boundary data and is copied unchanged. For every interior row ``j``
    this solves the tridiagonal system

        (1 + r_lf) u[i] - r_lf/2 * (u[i-1] + u[i+1])
            = r_rt/2 * A[j-1, i] + (1 - r_rt) * A[j, i] + r_rt/2 * A[j+1, i]

    i.e. implicit along the row and explicit across rows. The boundary values
    ``A[j, 0]`` and ``A[j, m+1]`` enter the right-hand side as known terms.

    Args:
        A: 2-D array of shape (n + 2, m + 2).
        r_lf: mesh ratio for the implicit (along-row) direction.
        r_rt: mesh ratio for the explicit (across-row) direction.

    Returns:
        A new float array of the same shape as ``A``; ``A`` is not modified.
    """
    n_rows, n_cols = A.shape
    m = n_cols - 2  # interior points per row
    n = n_rows - 2  # interior rows
    B = np.array(A, dtype=float)  # always a copy
    if m < 1 or n < 1:
        # No interior unknowns: nothing to solve.
        return B

    # Tridiagonal implicit operator: 1 + r_lf on the diagonal, -r_lf/2 beside it.
    off = np.full(m - 1, -0.5 * r_lf)
    M = np.diag(np.full(m, 1.0 + r_lf)) + np.diag(off, 1) + np.diag(off, -1)

    # Explicit part for all interior rows at once; row j-1 of rhs is row j of A.
    inner = slice(1, m + 1)
    rhs = (r_rt / 2.0 * A[0:n, inner]
           + (1 - r_rt) * A[1:n + 1, inner]
           + r_rt / 2.0 * A[2:n + 2, inner])
    # Known boundary neighbours of the first and last unknown in each row.
    rhs[:, 0] += A[1:n + 1, 0] * r_lf / 2.0
    rhs[:, m - 1] += A[1:n + 1, m + 1] * r_lf / 2.0

    # One solve with n right-hand sides (columns of rhs.T) instead of n solves.
    B[1:n + 1, inner] = np.linalg.solve(M, rhs.T).T
    return B
def computeA(m, n, rx, ry, iter):
    """Run ``iter`` full ADI time steps from the initial grid (no plotting).

    Args:
        m: number of interior points per row, passed to ``createMatrix``.
        n: number of interior rows, passed to ``createMatrix``.
        rx: mesh ratio along rows (``getStart`` uses tao / deltax**2).
        ry: mesh ratio along columns (``getStart`` uses tao / deltay**2).
        iter: number of full steps (each two half-steps) to take.

    Returns:
        ndarray of shape (n + 2, m + 2) after ``iter`` steps.
    """
    A = createMatrix(m, n)
    print('total iter=%s' % (iter,))
    # Take exactly `iter` steps, numbered 1..iter, matching the total printed above.
    for step in range(1, iter + 1):
        print('iter num=%s' % (step,))
        A = computeOneIter(A, m, n, rx, ry)
    return A
def computeOneIter(A, m, n, rx, ry):
    """Take one full ADI step made of two alternating half-steps.

    The first half-step is implicit along each row of ``A`` (ratio ``rx``).
    The grid is then transposed, so the second half-step is implicit along
    the original columns (ratio ``ry``), and transposed back.

    ``m`` and ``n`` are accepted for call compatibility but are not used.
    Returns an array with the same shape as ``A``.
    """
    row_sweep = oneIter(A, rx, ry)
    col_sweep = oneIter(row_sweep.T, ry, rx)
    return col_sweep.T
def getStart():
    """Animate the ADI solution of the heat problem on [0, 20] x [0, 30].

    The grid spacings are 0.5 (x) and 0.3 (y), and the time step is
    ``tao = min(dx, dy)**2 / 3``. The grid is initialised by ``createMatrix``
    and advanced ``int(10 / tao)`` full ADI steps. A 3-D wireframe is redrawn
    after each step. For a non-animated run, ``computeA(m, n, rx, ry, n_steps)``
    takes the same parameters.

    Returns:
        (A, x, y): the final grid of shape (len(y), len(x)), and the node
        coordinates along x and y, boundary nodes included.
    """
    X_INTERVAL = [0, 20]
    Y_INTERVAL = [0, 30]
    T = [0, 10]
    deltax = 0.5
    deltay = 0.3
    h = min(deltax, deltay)
    tao = 1.0 / 3 * h * h
    # Interior node counts. round() guards against quotients like 99.999...
    m = int(round((X_INTERVAL[1] - X_INTERVAL[0]) / deltax)) - 1
    n = int(round((Y_INTERVAL[1] - Y_INTERVAL[0]) / deltay)) - 1
    print('m=%s,n=%s' % (m, n))
    # m interior + 2 boundary nodes gives a node spacing of exactly deltax
    # (likewise for y), which is what rx / ry below assume.
    x = np.linspace(X_INTERVAL[0], X_INTERVAL[1], m + 2)
    y = np.linspace(Y_INTERVAL[0], Y_INTERVAL[1], n + 2)
    X, Y = np.meshgrid(x, y)  # shape (n + 2, m + 2), same as A
    rx = tao / deltax / deltax
    ry = tao / deltay / deltay

    # animation
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    wframe = None
    n_steps = int((T[1] - T[0]) / tao)
    A = createMatrix(m, n)
    for i in range(n_steps):
        A = computeOneIter(A, m, n, rx, ry)
        if wframe is not None:
            # Artist.remove() works across matplotlib versions; newer versions
            # no longer allow mutating ax.collections directly.
            wframe.remove()
        wframe = ax.plot_wireframe(X, Y, A, rstride=2, cstride=2)
        plt.pause(0.01)
        print('iter= %s' % (i,))
    return A, x, y
if __name__ == '__main__':
    # Script entry point: run the animated simulation; the returned
    # (A, x, y) tuple is discarded.
    getStart()