import numpy as np
import torch
import sys
import time
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt

# TDMA solver
def tdma(a, b, c, d, phi, ni):
    """Solve a tridiagonal system with the Thomas algorithm (TDMA), in place.

    For the interior nodes i = 1 .. ni-1 the equation solved is

        a[i]*phi[i] = b[i]*phi[i+1] + c[i]*phi[i-1] + d[i]

    phi[0] and phi[ni] are read as fixed boundary values and never written.

    Parameters
    ----------
    a, b, c, d : indexable sequences of float
        Diagonal coefficient, coefficient of phi[i+1], coefficient of
        phi[i-1] and source term; only indices 1 .. ni-1 are read.
    phi : array supporting item assignment, length >= ni+1
        Holds the boundary values on entry; entries 1 .. ni-1 are overwritten.
    ni : int
        Index of the upper boundary node.

    Returns
    -------
    The same ``phi`` object, updated in place.
    """
    # Forward sweep: build the recurrence phi[i] = coef[i]*phi[i+1] + rhs[i].
    # coef[0] stays 0 and rhs[0] is the lower boundary value, so the first
    # step picks up phi[0] automatically.
    coef = np.zeros(ni + 1)
    rhs = np.zeros(ni + 1)
    rhs[0] = phi[0]
    for i in range(1, ni):
        denom = a[i] - c[i] * coef[i - 1]
        # Guard against a (near) zero pivot; note that the replacement value
        # is +1e-10 whatever the sign of the original denominator.
        if abs(denom) < 1.e-10:
            denom = 1.e-10
        coef[i] = b[i] / denom
        rhs[i] = (d[i] + c[i] * rhs[i - 1]) / denom

    # Back substitution, starting next to the upper boundary phi[ni].
    for i in reversed(range(1, ni)):
        phi[i] = coef[i] * phi[i + 1] + rhs[i]

    return phi


# ---------------------------------------------------------------------------
# 1D Poisson equation for k, discretised on a cell-centred grid.
#
# The discretisation is described in detail at
# http://www.tfd.chalmers.se/~lada/comp_fluid_dynamics/
# ---------------------------------------------------------------------------

# Interactive plotting, starting from a clean slate of figures.
plt.interactive(True)
plt.close('all')

# Upper bound for the outer iteration loop over k.
maxit = 200

# Large default font for every figure created below.
plt.rcParams.update({'font.size': 22})

# Grid for the half channel, normalised to unit width.
#
#   yc[0 .. nj-1]  cell faces, running from 0 to 1
#   yp[0 .. nj]    nodes: yp[0] and yp[nj] lie on the boundaries,
#                  yp[1 .. nj-1] are the cell centres

# number of faces (30 and 60 have also been used)
nj = 11
njm1 = nj - 1
# growth factor between successive cells (1.15 has also been used;
# 1.0 gives a uniform grid)
yfac = 1.
# width of the first cell before normalisation
dy = 0.1

yc = np.zeros(nj)
# cell widths; zero here and filled in by the geometry loop further down
delta_y = np.zeros(nj)

for face in range(1, nj):
    yc[face] = yc[face - 1] + dy
    dy *= yfac

# rescale so that the last face sits at y = 1
ymax = yc[-1]
yc = yc / ymax

# nodes: boundary values at both ends, cell centres in between
yp = np.zeros(nj + 1)
yp[1:nj] = 0.5 * (yc[1:] + yc[:-1])
yp[-1] = yc[-1]


# --- parameters -------------------------------------------------------------
viscos = 1/100   # viscosity (uniform)
urf = 1          # under-relaxation factor for k (0.5 has also been used)
jmon = 8         # node whose k value is recorded every iteration (k_iter)

# NOTE(review): small, great and y are not referenced elsewhere in this script
small = 1.e-10
great = 1.e10

# --- node arrays, one entry per node yp[0 .. nj] -----------------------------
k_old = np.zeros(nj + 1)            # k from the previous iteration
k = np.full(nj + 1, 1.e-4)          # initial guess for k
y = np.zeros(nj + 1)
vist = np.full(nj + 1, viscos)      # viscosity at every node
dn = np.zeros(nj + 1)               # 1/(yp[j+1]-yp[j]), set in the geometry loop
ds = np.zeros(nj + 1)               # 1/(yp[j]-yp[j-1]), set in the geometry loop
dy_s = np.zeros(nj + 1)             # yp[j]-yp[j-1], set in the geometry loop
fy = np.zeros(nj + 1)               # face interpolation factor, set in the geometry loop


# Geometric quantities for the interior nodes j = 1 .. nj-1 (boundary
# entries are left untouched), computed slice-wise:
#   dy_s[j]    = yp[j] - yp[j-1]          distance to the south node
#   ds[j]      = 1 / dy_s[j]
#   dn[j]      = 1 / (yp[j+1] - yp[j])    inverse distance to the north node
#   delta_y[j] = yc[j] - yc[j-1]          cell width
#   fy[j]      linear-interpolation weight of node j+1 at the face yc[j]
dy_s[1:nj] = yp[1:nj] - yp[0:nj - 1]
ds[1:nj] = 1. / dy_s[1:nj]
dn[1:nj] = 1. / (yp[2:nj + 1] - yp[1:nj])
delta_y[1:nj] = yc[1:nj] - yc[0:nj - 1]

dist_p_to_face = yc[1:nj] - yp[1:nj]        # node j -> north face
dist_face_to_n = yp[2:nj + 1] - yc[1:nj]    # north face -> node j+1
fy[1:nj] = dist_p_to_face / (dist_p_to_face + dist_face_to_n)

# Boundary conditions: k is zero at both boundary nodes.
k[0] = k[-1] = 0.

# node whose k value is printed during the iterations
nj2 = nj // 2

# coefficient arrays of  ap*k[j] = an*k[j+1] + as1*k[j-1] + su
su, an, as1, ap = (np.zeros(nj + 1) for _ in range(4))
# NOTE(review): sp is not referenced elsewhere in this script
sp = np.zeros(nj)
# history of k[jmon], one entry per iteration
k_iter = np.zeros(maxit)
# do max. maxit iterations
#
# Each outer iteration assembles, for the interior nodes j = 1 .. nj-1,
#   ap[j]*k[j] = an[j]*k[j+1] + as1[j]*k[j-1] + su[j]
# and solves it with TDMA. The loop stops as soon as the largest ABSOLUTE
# change of k between two iterations falls below tol_k.

# convergence tolerance on max|k - k_old|
tol_k = 1.e-4
# defined up front so the non-convergence message below is always valid
k_diff = np.inf
for n in range(1, maxit):

    # solve k
    # sum of |residual| of the (under-relaxed) equations, evaluated with the
    # k of the previous iteration
    # NOTE(review): res_k is not used anywhere else in this script
    res_k = 0
    # monitor the development of k in node jmon
    k_iter[n] = k[jmon]

    for j in range(1, nj):
        # source term: the constant 2 integrated over the cell width
        su[j] = 2*delta_y[j]

        # face coefficients: viscosity linearly interpolated to the north and
        # south faces, times the inverse node distance
        vist_n = fy[j]*vist[j+1] + (1.-fy[j])*vist[j]
        an[j] = vist_n*dn[j]
        vist_s = fy[j-1]*vist[j] + (1.-fy[j-1])*vist[j-1]
        as1[j] = vist_s*ds[j]

    for j in range(1, nj):
        # compute ap
        ap[j] = an[j] + as1[j]

        # under-relaxation (no effect when urf == 1)
        ap[j] = ap[j]/urf
        su[j] = su[j] + (1.-urf)*ap[j]*k[j]

        res_k += abs(an[j]*k[j+1] + as1[j]*k[j-1] + su[j] - ap[j]*k[j])

    # Keep a copy of the previous iterate, boundary nodes included. tdma
    # overwrites k in place, so this must happen before the solve.
    k_old[:] = k

    # use TDMA (Gauss-Seidel alternative, per node:
    #   k[j] = (an[j]*k[j+1] + as1[j]*k[j-1] + su[j])/ap[j])
    k = tdma(ap, an, as1, su, k, nj)

    # terminate iteration. Use the largest absolute change: the signed
    # np.max(k - k_old) is small or negative whenever k decreases, which
    # would end the iteration before it has converged.
    k_diff = np.max(np.abs(k - k_old))
    print(f"iter: {n}, k[nj2]: {k[nj2]:.2e}, max(k-k_old): {k_diff:.2e}")
    if k_diff < tol_k:
        break
else:
    # the loop ran out of iterations without reaching the tolerance
    print(f"k not converged after {maxit - 1} iterations, max(k-k_old): {k_diff:.2e}")


# compute and save dkdy and d2kdy2
#   dkdy[j]    cell gradient from face-interpolated k: (k_n - k_s)/delta_y[j]
#   dkdy_n[j]  gradient between node j and node j+1
#   dkdy_s[j]  gradient between node j-1 and node j
#   d2kdy2[j]  (dkdy_n[j] - dkdy_s[j])/delta_y[j]
dkdy = np.zeros(nj)
dkdy_s = np.zeros(nj)
dkdy_n = np.zeros(nj)
d2kdy2 = np.zeros(nj)

for j in range(1, nj):
    k_n = fy[j]*k[j+1] + (1.-fy[j])*k[j]
    k_s = fy[j-1]*k[j] + (1.-fy[j-1])*k[j-1]
    dkdy[j] = (k_n - k_s)/delta_y[j]
    dkdy_n[j] = (k[j+1] - k[j])*dn[j]
    dkdy_s[j] = (k[j] - k[j-1])*ds[j]

# In the first and last interior cell, overwrite dkdy with the one-sided
# difference between the boundary node and its neighbour. Explicit indices
# are used (the top one used to depend on the loop variable j leaking out of
# the loop above), and the boundary values of k and yp are no longer assumed
# to be zero, so this stays valid for other boundary conditions. For k = 0,
# yp[0] = 0 and yp[-1] = yc[-1] the values are identical to before.
dkdy[-1] = (k[-1] - k[-2])/(yp[-1] - yp[-2])
dkdy[1] = (k[1] - k[0])/(yp[1] - yp[0])

for j in range(1, nj):
    d2kdy2[j] = (dkdy_n[j] - dkdy_s[j])/delta_y[j]


# Save the solution at the interior nodes (boundary nodes excluded).
np.savetxt('y-k-vist-poisson-1D.txt', np.c_[yp[1:-1], k[1:-1], vist[1:-1]])  # save only internal nodes
np.savetxt('dkdy-d2kdy2-1D.txt', np.c_[dkdy[1:], d2kdy2[1:]])  # save only internal nodes

# plot k: numerical solution against the analytical one
fig1, ax1 = plt.subplots()
plt.subplots_adjust(left=0.20, bottom=0.20)
plt.plot(k, yp, 'b-', label="CFD")

# Analytical solution. vist is uniform in this script, so the discretised
# equation corresponds to d/dy(vist*dk/dy) + 2 = 0 with k = 0 at y = 0 and
# y = 1, i.e. k = (y - y**2)/vist. It was previously computed but never drawn.
k_analytical = (-yp**2 + yp)/vist[2]
plt.plot(k_analytical, yp, 'ro', label="analytical")
# without a legend the curve labels were never shown
plt.legend(loc="best")
plt.xlabel(r'$k$')
plt.ylabel('y')
plt.savefig('k-poisson-1d.png')

