# taken from /chalmers/users/lada/noback/pycalc-les/channel-16000-WF-IDDES-NN-only-yplus/
import numpy as np
import torch
import sys
import time
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
#from sklearn.discriminant_analysis import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from random import randrange
from joblib import dump, load

# log law wall function
def compute_ustar(wdist, velabs, ustar, elog, n):
    """Friction velocity from the log-law wall function.

    Performs ``n`` fixed-point iterations of
        ustar = kappa * velabs / ln(max(elog * ustar * wdist / viscos, 10))
    starting from the supplied ``ustar``.  Afterwards, wherever
    y+ = ustar * wdist / viscos is <= 11.69 the point is treated as lying in
    the viscous sublayer and the linear-law value sqrt(viscos * velabs / wdist)
    is used instead.

    Relies on the module-level constants ``viscos`` and ``kappa``.

    Returns the friction velocity as produced by ``np.where``.
    """
    u_tau = ustar
    for _ in range(n):
        # clip the log argument at 10 so that ln() stays positive
        log_arg = np.maximum(elog * u_tau * wdist / viscos, 10.)
        u_tau = kappa * velabs / np.log(log_arg)

    in_viscous_sublayer = u_tau * wdist / viscos <= 11.69
    sublayer_ustar = (viscos * velabs / wdist) ** 0.5
    return np.where(in_viscous_sublayer, sublayer_ustar, u_tau)


## Reichardt's wall function
def solve_ustar_reich(ustar,velabs,wdist):
    """Residual of Reichardt's law of the wall, for use with a root finder.

    Reichardt's velocity profile is
        U+ = (1/kappa) ln(1 + kappa*y+)
             + C * (1 - exp(-y+/chi) - (y+/chi) * exp(-y+/3))
    with kappa = 0.4, C = 7.8 and chi = 11.  Here y+ = ustar * wdist, so
    ``wdist`` must already be the wall distance divided by the kinematic
    viscosity (the solver loop's commented-out call passes ``dy/viscos``).

    Works element-wise on scalars or numpy arrays.

    Returns ``ustar - velabs / U+(y+)``, which is zero when ``ustar`` is the
    friction velocity consistent with the velocity magnitude ``velabs``.
    """
    yplus = ustar*wdist
    # The last term carries the factor C/chi = 7.8/11.  Without it the law
    # gives U+ ~ 0.71 y+ instead of U+ ~ y+ in the viscous sublayer.
    uplus = (np.log(1 + 0.4*yplus)/0.4
             + 7.8*(1 - np.exp(-yplus/11) - yplus/11*np.exp(-yplus/3)))
    return ustar - velabs/uplus

def compute_ustar_reich_solve(wdist,velabs,ustar):
    """Friction velocity from Reichardt's wall law, via scipy's ``newton``.

    Parameters
    ----------
    wdist : wall distance scaled so that ``ustar*wdist`` equals y+
        (the commented-out call in the solver loop passes ``dy/viscos``).
    velabs : magnitude of the velocity at that wall distance.
    ustar : initial guess for the friction velocity.

    Scalars (including plain Python floats) and arrays of any shape are
    accepted.  All three inputs are converted to flat float64 copies, so the
    caller's arrays are never modified.

    Returns
    -------
    The root of ``solve_ustar_reich`` found by ``scipy.optimize.newton``.
    No derivative is supplied, so newton uses secant iterations.
    """
    from scipy.optimize import newton

    # np.asarray(...).flatten() also accepts Python floats (which have no
    # .flatten() method) and always returns a 1-D copy.
    ustar = np.asarray(ustar, dtype=float).flatten()
    velabs = np.asarray(velabs, dtype=float).flatten()
    wdist = np.asarray(wdist, dtype=float).flatten()

    return newton(solve_ustar_reich, x0=ustar, args=(velabs, wdist))


#########  The neural network modules: start ################################
class ThePredictionMachine(nn.Module):
    """Fully connected regression network: 1 input -> 50 -> 50 -> 1 output.

    Two ReLU-activated linear layers are followed by a linear output layer.

    NOTE(review): the script restores the model with
    ``torch.load(..., weights_only=False)``, which unpickles the whole module.
    Keep the class name and the submodule attribute names (input, hidden1,
    hidden2) unchanged, or previously saved models will presumably no longer
    load.
    """

    def __init__(self):
        super().__init__()
        # creation order also fixes the order of random weight initialisation
        self.input = nn.Linear(1, 50)     # in_features = 1 (columns of X)
        self.hidden1 = nn.Linear(50, 50)
        self.hidden2 = nn.Linear(50, 1)   # out_features = 1 (columns of y)

    def forward(self, x):
        hidden = torch.relu(self.input(x))
        hidden = torch.relu(self.hidden1(hidden))
        return self.hidden2(hidden)

# TDMA solver
def tdma(a,b,c,d,phi,ni):
    """Solve a tridiagonal system with the Thomas algorithm (TDMA), in place.

    For every interior node i = 1 .. ni-1 the system is
        a[i]*phi[i] = b[i]*phi[i+1] + c[i]*phi[i-1] + d[i]
    phi[0] and phi[ni] act as fixed boundary values and are not changed, so
    ``phi`` must have at least ni+1 entries.

    A pivot whose magnitude is below 1e-10 is replaced by +1e-10 to avoid
    division by (almost) zero.

    Returns ``phi`` itself, updated in place.
    """
    coef_p = np.zeros(ni + 1)
    coef_q = np.zeros(ni + 1)
    coef_q[0] = phi[0]

    # forward elimination: phi[i] = coef_p[i]*phi[i+1] + coef_q[i]
    for i in range(1, ni):
        pivot = a[i] - c[i] * coef_p[i - 1]
        if abs(pivot) < 1.e-10:
            pivot = 1.e-10
        coef_p[i] = b[i] / pivot
        coef_q[i] = (d[i] + c[i] * coef_q[i - 1]) / pivot

    # back substitution, from node ni-1 down to node 1
    for i in range(ni - 1, 0, -1):
        phi[i] = coef_p[i] * phi[i + 1] + coef_q[i]

    return phi


# interactive plotting; close any figures left over from a previous run
plt.interactive(True)
plt.close('all')

# Example of a 1D fully developed channel flow with a k-epsilon model,
# Re_tau = u_tau*h/nu = 16000 (h = half channel height).
#
# The discretization is described in detail in
# http://www.tfd.chalmers.se/~lada/comp_fluid_dynamics/

plt.rcParams.update({'font.size': 22})

# maximum number of outer iterations (e.g. 1 or 100 for quick debugging runs)
maxit = 25000

# Scaling: friction velocity u_* = 1 and half channel width h = 1.

# Grid: nj cell faces yc, spaced by dy with the stretching factor yfac
# applied between consecutive cells, then normalised to span [0, 1].
nj = 33
njm1 = nj - 1
yfac = 1.0   # 1.0 gives a uniform grid
dy = 0.1
yc = np.zeros(nj)
delta_y = np.zeros(nj)
for j in range(1, nj):
    yc[j] = yc[j - 1] + dy
    dy *= yfac
ymax = yc[-1]

# cell faces, normalised so that yc[0] = 0 and yc[-1] = 1 exactly
yc /= ymax
yc[-1] = 1.

# cell centres (nodes); yp[0] and yp[nj] sit on the faces y = 0 and y = 1
yp = np.zeros(nj + 1)
yp[1:nj] = 0.5 * (yc[1:] + yc[:-1])
yp[-1] = yc[-1]

# kinematic viscosity; with u_* = 1 and h = 1 this gives Re_tau = 1/viscos = 16000
viscos = 1/16000

# under-relaxation factor
urf = 0.5

# node whose k and eps are recorded every iteration (k_iter, eps_iter)
jmon = 8

# k-epsilon model constants
c_eps_1 = 1.5
c_eps_2 = 1.9
prand_k = 1.4
prand_eps = 1.4
cmu = 0.09

# von Karman constant
kappa = 0.41

small = 1.e-10
great = 1.e10

# initial fields on the nj+1 nodes
u = np.ones(nj + 1)
k = np.ones(nj + 1)
eps = np.ones(nj + 1)
y = np.zeros(nj + 1)
# geometric coefficients (filled for the interior nodes before the solver loop)
dn, ds, dy_s, fy = (np.zeros(nj + 1) for _ in range(4))
# per-iteration history: wall shear stress, and k / eps at node jmon
tau_w, k_iter, eps_iter = (np.zeros(maxit) for _ in range(3))
dudy = np.gradient(u, yp)
vist = cmu * k**2 / eps

# Load the trained wall-function network and its input scalers.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects. This is needed here because the whole nn.Module was saved, not
# just a state_dict. Only load model files from a trusted source.
filename='model-channel-5200-only-yplus-IDDES.pth'
# map_location='cpu': inference in the solver loop uses CPU tensors, so a
# model saved from a GPU run must be mapped to the CPU to be usable here.
neural_net = torch.load(filename, map_location='cpu', weights_only=False)
print('model',neural_net)
scaler_yplus = load('scaler-yplus-channel-5200-only-yplus-IDDES.bin')
scaler_pplus = load('scaler-pplus-channel-5200-only-yplus-IDDES.bin') # dummy: not used by this y+-only model
scaler_dudy = load('scaler-dudy-channel-5200-only-yplus-IDDES.bin') # dummy: not used by this y+-only model

# Min/max bounds (presumably of the training data). The solver loop uses the
# y+ and U+ bounds to clip the NN input and output.
[yplus_min, yplus_max, pplus_min, pplus_max, dudy_min, dudy_max, uplus_min, uplus_max] = \
np.loadtxt('min-max-model-channel-5200-only-yplus-IDDES.txt')

# the label must name the values actually printed (dudy_*, not re_*)
print('yplus_min,yplus_max,pplus_min,pplus_max,dudy_min,dudy_max, uplus_min, uplus_max',\
   yplus_min,yplus_max,pplus_min,pplus_max,dudy_min,dudy_max, uplus_min, uplus_max)

# Geometric coefficients for the interior nodes j = 1 .. nj-1 (boundary nodes
# 0 and nj are excluded):
#   dy_s    distance from node j to its south neighbour
#   ds, dn  inverse distances to the south / north neighbour
#   delta_y cell height between faces yc[j-1] and yc[j]
#   fy      interpolation factor between node j and node j+1 at face yc[j]
interior = slice(1, nj)
dy_s[interior] = yp[1:nj] - yp[0:nj - 1]
delta_y[interior] = yc[1:nj] - yc[0:nj - 1]
ds[interior] = 1. / dy_s[interior]
dn[interior] = 1. / (yp[2:nj + 1] - yp[1:nj])
dist_node_to_face = yc[1:nj] - yp[1:nj]
dist_face_to_north = yp[2:nj + 1] - yc[1:nj]
fy[interior] = dist_node_to_face / (dist_node_to_face + dist_face_to_north)

# boundary nodes: j = 0 at the wall (y = 0) and j = nj at y = 1
vist[0] = 0.
vist[-1] = 0.
k[0] = 0.
k[-1] = 0.


# Work arrays for the discretised equations, reused by the u, k and eps
# solves. For interior node j, tdma solves
#   ap[j]*phi[j] = an[j]*phi[j+1] + as1[j]*phi[j-1] + su[j]
# where ap[j] = an[j] + as1[j] - sp[j] and sp holds the implicit
# (linearised) part of the source term.
su=np.zeros(nj+1)
sp=np.zeros(nj+1)
an=np.zeros(nj+1)
as1=np.zeros(nj+1)
ap=np.zeros(nj+1)
# do max. maxit iterations
# Each outer iteration: (1) solve u, (2) check convergence of the wall shear
# stress, (3) evaluate the NN wall function for ustar, (4) solve k, (5) solve eps.
for n in range(0,maxit):

# solve u

    for j in range(1,nj):

# driving pressure gradient: source = delta_y per cell (i.e. -dp/dx = 1,
# presumably because u_* = 1 and h = 1)
      su[j]=delta_y[j]

      sp[j]=0.

# interpolate turbulent viscosity to faces
      vist_n=fy[j]*vist[j+1]+(1-fy[j])*vist[j]
      vist_s=fy[j-1]*vist[j]+(1-fy[j-1])*vist[j-1]

# compute an & as
      an[j]=(vist_n+viscos)*dn[j]
      as1[j]=(vist_s+viscos)*ds[j]

# wall function: ustar = cmu^(1/4) * k^(1/2) in the first cell. After the
# first iteration k[1] is set to ustar_NN**2/sqrt(cmu) (see the k solve
# below), so this reproduces the NN ustar of the previous iteration.
    ustar=cmu**0.25*k[1]**0.5

# boundary conditions for u
    as1[1]=0 # wall function
    su[1] = su[1] - ustar**2 # wall shear stress
    an[-2]=0 # symmetry

    res_u = 0
    for j in range(1,nj):
# compute ap
      ap[j]=an[j]+as1[j]-sp[j]

# under-relaxate (implicit: ap/urf, with the matching term added to su)
      ap[j]= ap[j]/urf
      su[j]= su[j]+(1.0-urf)*ap[j]*u[j]

# residual of the under-relaxed equation, evaluated before the solve
      res_u += abs(an[j]*u[j+1]+as1[j]*u[j-1]+su[j]-ap[j]*u[j])
# TDMA
    u=tdma(ap,an,as1,su,u,nj)

    u[-1] =  u[-2] # needed when dudy is computed in the production terms

# record the wall shear stress of this iteration
    tau_w[n]=ustar**2

# print iteration info
    tau_target=1
    print(f"\n{'---iter: '}{n:2d}, {'wall shear stress: '}{tau_w[n]:.2e},{'  tau_w_target='}{tau_target:.2e}\n")


# check for convergence (when converged, the wall shear stress must be one)
    ntot=n
    if abs(tau_w[n]-1) < 0.001:
# do at least 1000 iter 
        if n > 1000:
           print('Converged!')
           break

# solve k
    res_k = 0
# monitor the development of k in node jmon
    k_iter[n]=k[jmon]

    dudy=np.gradient(u,yp,edge_order=2)
# fix boundaries
    dudy[0]=dudy[1]
    dudy[-1]=dudy[-2]

    dudy2=dudy**2
    for j in range(1,nj):

# compute viscosity, nu_t = cmu*k^2/eps, under-relaxed with urf
      vist_new = cmu*k[j]**2/eps[j]
      vist[j] = vist_new*urf + (1-urf)*vist[j]

# production term
      su[j]=vist[j]*dudy2[j]*delta_y[j]

# dissipation term (implicit, via sp)
      sp[j]=-eps[j]/k[j]*delta_y[j]

# compute an & as
      vist_n=fy[j]*vist[j+1]+(1.-fy[j])*vist[j]
      an[j]=(vist_n/prand_k+viscos)*dn[j]
      vist_s=fy[j-1]*vist[j]+(1.-fy[j-1])*vist[j-1]
      as1[j]=(vist_s/prand_k+viscos)*ds[j]

    dy=yp[1]
    velabs=abs(u[1])
# alternative analytical wall functions (disabled):
#   ustar=compute_ustar_reich_solve(dy/viscos,velabs,ustar)  # reichard wall function
#   ustar = compute_ustar(dy,velabs,ustar,9,4)  # log wall function

# ---- NN wall function: predict U+ at the first node from y+ ----
# set number of cells in x direction to 1
    ni = 1
# NOTE(review): ustar_mean_s, ustar_mean_p and u2d_wall_2nd are not used below
    ustar_mean_s = np.zeros(ni)
    ustar_mean_p = np.zeros(ni)
    u2d_wall=abs(u[1])  # first cell
    u2d_wall_2nd=abs(u[2])  # 2nd cell
# ubulk only feeds pplus_south, which this y+-only model does not use.
# NOTE(review): np.trapz is deprecated since NumPy 2.0 (np.trapezoid) - confirm
# against the installed NumPy version.
    ubulk = np.trapz(u,yp)/max(yc)

# NOTE(review): this kwall is overwritten below before it is used
    kwall=cmu**(-0.5)*ustar**2
    dy=yp[1] # first cell
#   dy=yp[1] # 2nd cell
    yplus_south = ustar*dy/viscos
    uplus_south = u2d_wall/ustar
# pressure-gradient and dudy inputs are placeholders (zeros); not fed to the model
    dpdx = np.zeros(ni)
    pplus_south = viscos*dpdx/ubulk**3
    dudy_south = np.zeros(ni)

# count values larger/smaller than max/min
    yplus_min_number= (yplus_south < yplus_min).sum()
    yplus_max_number= (yplus_south > yplus_max).sum()
    print('south: yplus_min_number',yplus_min_number)
    print('south: yplus_max_number',yplus_max_number)

# set limits: clip y+ to [yplus_min, yplus_max] before feeding the NN
    yplus_south=np.minimum(yplus_south,yplus_max)
    yplus_south=np.maximum(yplus_south,yplus_min)

    N_predict=ni
# NOTE(review): this rebinds the module-level array y created at initialisation
    y = uplus_south
    y = y.reshape(-1,1)
    yplus = yplus_south.reshape(-1,1)
    #pplus = pplus_south.reshape(-1,1)
    #dudy = dudy_south.reshape(-1,1)

# NN input matrix: one row per wall point, single scaled y+ column
    X=np.zeros((len(yplus),1))

    X[:,0] = scaler_yplus.transform(yplus)[:,0]
#   X[:,1] = scaler_pplus.transform(pplus)[:,0]
#   X[:,2] = scaler_dudy.transform(dudy)[:,0]

    X_tensor = torch.tensor(X, dtype=torch.float32)

# NOTE(review): inference builds an autograd graph that .detach() then drops;
# wrapping this call in torch.no_grad() would avoid that work
    preds = neural_net(X_tensor)

    uplus_NN = preds.detach().numpy()
    uplus_NN = uplus_NN[:,0]

# diagnostic: counts how often the CFD value uplus_south (not the NN output)
# is outside [uplus_min, uplus_max]
    uplus_min_number= (uplus_south < uplus_min).sum()
    uplus_max_number= (uplus_south > uplus_max).sum()
    print('south: uplus_min_number',uplus_min_number)
    print('south: uplus_max_number',uplus_max_number)

# clip the NN prediction to [uplus_min, uplus_max]
    uplus_NN=np.minimum(uplus_NN,uplus_max)
    uplus_NN=np.maximum(uplus_NN,uplus_min)

    uplus_predict=np.reshape(uplus_NN,(ni))


    ustar=u2d_wall/uplus_predict  # ustar = u2d_wall/uplus

    ustar = ustar.item() #convert from array to scalar

# k prescribed in the first cell: k = ustar^2/sqrt(cmu)
    kwall=cmu**(-0.5)*ustar**2

# boundary conditions for k
    an[-2]=0 # symmetry
    as1[1] = 0

    for j in range(1,nj):
# compute ap
      ap[j]=an[j]+as1[j]-sp[j]

# under-relaxate
      ap[j]= ap[j]/urf
      su[j]= su[j]+(1.-urf)*ap[j]*k[j]

      res_k += abs(an[j]*k[j+1]+as1[j]*k[j-1]+su[j]-ap[j]*k[j])

# wall function: ap=1, an=as=0, su=kwall fixes k[1] = kwall in the TDMA
    ap[1]=1
    an[1]=0
    as1[1]=0
    su[1]=kwall

# use TDMA
    k=tdma(ap,an,as1,su,k,nj)

#****** solve eps-eq.
    res_eps = 0
    eps_iter[n]=eps[jmon]

    for j in range(1,nj):
# compute an & as
      vist_n=fy[j]*vist[j+1]+(1.-fy[j])*vist[j]
      an[j]=(vist_n/prand_eps+viscos)*dn[j]
      vist_s=fy[j-1]*vist[j]+(1.-fy[j-1])*vist[j-1]
      as1[j]=(vist_s/prand_eps+viscos)*ds[j]

# production term: c_eps_1*cmu*k*dudy^2, which equals c_eps_1*(eps/k)*P_k
# when vist = cmu*k**2/eps
      su[j]=c_eps_1*cmu*dudy2[j]*k[j]*delta_y[j]

# dissipation term (implicit, via sp)
      sp[j]=-c_eps_2*eps[j]*delta_y[j]/k[j]

# b.c. at the upper boundary (y = 1): symmetry
    an[-2]=0 # symmetry

    for j in range(1,nj):
# compute ap
      ap[j]=an[j]+as1[j]-sp[j]

# under-relaxate
      ap[j]= ap[j]/urf
      su[j]= su[j]+(1.-urf)*ap[j]*eps[j]

# NOTE(review): eps is prescribed only at j=1; also skipping j=nj-1 (the node
# next to the symmetry plane) looks unintended - confirm
      if j != 1 and j != nj-1: # omit j=1 (eps set by the wall function) and j=nj-1
         res_eps += abs(an[j]*eps[j+1]+as1[j]*eps[j-1]+su[j]-ap[j]*eps[j])

# wall function: eps in the first cell = ustar^3/(kappa*y), with kappa
# hard-coded as 0.41 (same value as the module-level kappa)
    dy=yp[1]
    eps_wall=ustar**3/0.41/yp[1]  # cell 0 is outside the domain
    ap[1]=1
    an[1]=0
    as1[1]=0
    su[1]=eps_wall

# TDMA
    eps=tdma(ap,an,as1,su,eps,nj)

# print residuals
    print(f"\n{'---iter: '}{n:2d}, {'res u: '}{res_u:.2e},{'  res k='}{res_k:.2e},{'  res eps='}{res_eps:.2e}\n")

# Reference DNS profiles at Re_tau = 5200 (the "LM" files, presumably Lee & Moser).
# Mean profile columns used here: 0 -> y, 1 -> y+, 2 -> U+.
DNS_mean = np.genfromtxt("LM_Channel_5200_mean_prof.dat", comments="%")
y_DNS, yplus_DNS, u_DNS = DNS_mean[:, 0], DNS_mean[:, 1], DNS_mean[:, 2]

# Fluctuation profile columns 2..5: u'u', v'v', w'w', u'v'.
DNS_stress = np.genfromtxt("LM_Channel_5200_vel_fluc_prof.dat", comments="%")
u2DNS, v2DNS, w2DNS, uvDNS = (DNS_stress[:, col] for col in range(2, 6))

# turbulent kinetic energy from the three normal stresses
k_DNS = 0.5 * (u2DNS + v2DNS + w2DNS)


def _new_figure(left=0.20):
    """Create a figure/axes pair with the margins shared by all plots below."""
    fig, ax = plt.subplots()
    plt.subplots_adjust(left=left, bottom=0.20)
    return fig, ax


# plot u: mean velocity versus y, CFD against DNS
fig1, ax1 = _new_figure()
plt.plot(u,yp,'b-',label="CFD")
plt.plot(u_DNS,y_DNS,'r-',label="DNS")
plt.xlabel("$U^+$")
plt.ylabel("$y$")
plt.legend(loc="best",prop=dict(size=18))
plt.savefig('u_5200-NN-kom-no-batch.png')

# plot u in wall units on a logarithmic y+ axis
fig1, ax1 = _new_figure()
# friction velocity from the wall shear stress of the last completed iteration
ustar=tau_w[ntot]**0.5
uplus=u/ustar
plt.semilogx(yp[1:]*ustar/viscos,uplus[1:],'b-',label="CFD")
plt.semilogx(yplus_DNS,u_DNS,'r-',label="DNS")
plt.ylabel("$U^+$")
plt.xlabel("$y^+$")
plt.axis([1, 5200, 0, 28])
plt.legend(loc="best",prop=dict(size=18))
plt.savefig('u_log-5200-NN-kom-no-batch.png')

# plot turbulent viscosity ratio
fig1, ax1 = _new_figure()
plt.plot(vist/viscos,yp,'b-',label=r"$k-\varepsilon$")
plt.legend(loc="best",prop=dict(size=18))
plt.xlabel(r'$\nu_t/\nu$')
plt.ylabel('y')
plt.savefig('vis_5200-NN-kom-no-batch.png')

# plot eps
fig1, ax1 = _new_figure()
plt.plot(eps,yp,'b-',label=r"$k-\varepsilon$")
plt.legend(loc="best",prop=dict(size=18))
plt.xlabel(r'$\varepsilon$')
plt.ylabel('y')
plt.savefig('eps_5200-NN-kom-no-batch.png')

# plot k
fig1, ax1 = _new_figure()
plt.plot(k,yp,'b-',label="CFD")
plt.plot(k_DNS,y_DNS,'r-',label="DNS")
plt.legend(loc="best",prop=dict(size=18))
plt.xlabel('k')
plt.ylabel('y')
plt.savefig('k_5200-NN-kom-no-batch.png')


# plot uv: eddy-viscosity shear stress -nu_t*dU/dy (vist and dudy from the
# last solver iteration) against the DNS u'v'
fig1, ax1 = _new_figure(left=0.25)
uv = -vist*dudy
# both curves are plotted against y+, so the axis is labelled y+ (not y)
plt.plot(uv,yp*ustar/viscos,'b-',label="CFD")
plt.plot(uvDNS,yplus_DNS,'r-',label="DNS")
plt.legend(loc="best",prop=dict(size=18))
plt.xlabel(r"$\overline{u'v'}$")
plt.ylabel("$y^+$")
plt.savefig('uv_5200-NN-kom-no-batch.png')

