import torch.nn as nn

class ThePredictionMachine(nn.Module):
    """Three-stage feed-forward network: two ReLU-activated stages and a plain output stage.

    The callables ``input``, ``hidden1`` and ``hidden2`` are not created by
    this class; they must already be present as attributes of the instance
    when ``forward`` is called.
    """

    def forward(self, x):
        """Return ``hidden2(relu(hidden1(relu(input(x)))))``."""
        activation = nn.functional.relu
        # the first two stages are each followed by a ReLU ...
        for stage_name in ("input", "hidden1"):
            x = activation(getattr(self, stage_name)(x))
        # ... the last stage is returned without activation
        return self.hidden2(x)

def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Return the seven fields that were passed in, unchanged, followed by ``dist3d``.

   ``dist3d`` is not a parameter; it is read from module scope.
   """
   fields_and_distance = (u3d, v3d, w3d, k3d, om3d, eps3d, vis3d, dist3d)
   return fields_and_distance


def modify_inlet():
   """Return the west-boundary values, ``u3d_face_w`` and ``convw`` without changing them.

   Every returned name is read from module scope; the function takes no
   arguments and performs no computation.
   """
   west_boundary = (u_bc_west, v_bc_west, w_bc_west, k_bc_west, eps_bc_west, om_bc_west)
   return (*west_boundary, u3d_face_w, convw)

def modify_conv(convw,convs,convl):
   """Zero ``convs`` at the first and last index of its second axis, in place.

   ``convw`` and ``convl`` are returned untouched.
   """
   for j_boundary in (0, -1):
      convs[:, j_boundary, :] = 0

   return convw, convs, convl


def modify_u(su3d,sp3d):
    """Add ``vol`` and a wall shear-stress sink to the source term ``su3d``.

    The friction velocity is computed from ``k3d`` in the j=0 cells as
    ``cmu**0.25*sqrt(k)``; the resulting stress ``ustar**2`` times the wall
    area, signed against ``u3d``, is subtracted from ``su3d`` at j=0 and
    at j=-1.  ``vol``, ``cmu``, ``k3d``, ``areas``, ``u3d`` and ``xp`` are read
    from module scope.

    Side effect: the module-level ``ustar_old`` is set to the friction
    velocity computed here.

    Returns ``(su3d, sp3d)``; ``su3d`` is a new array (the input is not
    modified in place by the first addition), ``sp3d`` is returned unchanged.
    """
    # Without this declaration the assignment to ustar_old below would only
    # bind a local name that is thrown away on return, so readers of the
    # module-level ustar_old (fix_k uses it as its "old ustar") would never
    # see a value computed from the current k field.
    global ustar_old

    su3d=su3d + vol

# south wall
    ustar=cmu**0.25*k3d[:,0,:]**0.5 # k is fixed at first cell
    tauw=ustar**2
    su3d[:,0,:]=su3d[:,0,:]-tauw*areas[:,0,:]*xp.sign(u3d[:,0,:])
    ustar_old = ustar

# north wall
    # the same tauw (computed from the j=0 cells) is applied here
    su3d[:,-1,:]=su3d[:,-1,:]-tauw*areas[:,-1,:]*xp.sign(u3d[:,-1,:])

    return su3d,sp3d

def modify_v(su3d,sp3d):
    """Pass-through: return ``(su3d, sp3d)`` exactly as received."""
    unchanged_sources = (su3d, sp3d)
    return unchanged_sources


def modify_w(su3d,sp3d):
   """Pass-through: return ``(su3d, sp3d)`` exactly as received."""
   unchanged_sources = (su3d, sp3d)
   return unchanged_sources

def modify_k(su3d,sp3d,gen):
   """Return ``su3d`` and ``sp3d`` unchanged together with a zero ``(nj, nk)`` array.

   ``gen`` is accepted but not used.  ``nj``, ``nk`` and ``xp`` come from
   module scope.
   """
   zero_comm_term = xp.zeros((nj, nk))
   return su3d, sp3d, zero_comm_term

def modify_eps(su3d,sp3d):
   """Pass-through: return ``(su3d, sp3d)`` exactly as received."""
   unchanged_sources = (su3d, sp3d)
   return unchanged_sources

def modify_om(su3d,sp3d,comm_term):
   """Return ``(su3d, sp3d)`` exactly as received; ``comm_term`` is ignored."""
   unchanged_sources = (su3d, sp3d)
   return unchanged_sources

def fix_omega():
   """Return the module-level coefficient and source arrays without touching them."""
   neighbour_coeffs = (aw3d, ae3d, as3d, an3d, al3d, ah3d)
   return (*neighbour_coeffs, ap3d, su3d, sp3d)


def modify_outlet(convw):
   """Shift ``convw`` on the last i-plane so that its sum matches the first i-plane.

   The imbalance between the summed ``convw`` on plane 0 and plane -1 is
   divided by the summed ``areaw`` on plane -1 and added back, weighted by
   ``areaw``, to plane -1 of ``convw`` (modified in place).  ``areaw``,
   ``u_bc_east`` and ``xp`` are read from module scope.

   Returns ``(convw, u_bc_east)``; ``u_bc_east`` is returned unchanged.
   """
   # totals on the first and last i-plane
   total_in = xp.sum(convw[0,:,:])
   total_out = xp.sum(convw[-1,:,:])
   outlet_area = xp.sum(areaw[-1,:,:])

   # uniform increment that closes the imbalance
   increment = (total_in - total_out)/outlet_area
   outlet_face_areas = areaw[-1,:,:]
   convw[-1,:,:] = convw[-1,:,:] + increment*outlet_face_areas

   print('area_out',outlet_area)

   total_out_corrected = xp.sum(convw[-1,:,:])

   print('flow_in',total_in,'flow_out',total_out,'area_out',outlet_area,'flow_out_new',total_out_corrected,'uinc:',increment)

   return convw, u_bc_east

def modify_fk(fk3d):
   """Compute ``fk3d = l_u/l_tilde`` from IDDES-type length scales and blending functions.

   The incoming ``fk3d`` argument is never read; it is replaced by the newly
   computed array, which is returned.  All fields and constants used here
   (dist3d, delta_max, y2d, nk, eps3d, k3d, vis3d, viscos, cdes, kappa, c_t,
   c_l, gen, iter, itstep, ...) are read from module scope.

   Side effects on module-level state and disk:
     * when iter == 0 and itstep == 0 the ``*_mean`` accumulators are created
       as zero arrays of length nj;
     * when iter == 0, itstep >= itstep_start and itstep is a multiple of
       itstep_stats, the accumulators are advanced through aver_iddes (they
       are running sums; no normalisation is done here);
     * when iter == 0 and itstep is ntstep-1 or a multiple of itstep_save,
       every accumulator is written with xp.save.
   """

   # NOTE(review): f_dt_mean is listed twice in this declaration (harmless).
   global f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_dt_mean,f_b_mean,denom_mean,l_c_mean,l_tilde_mean

   # candidate length scales: 0.15*wall distance, 0.15*delta_max and the spacing dy
   l_dist=0.15*dist3d
   l_max=0.15*delta_max
   # difference of y2d along its second axis, with the first row of the first axis dropped
   dy=xp.diff(y2d[1:,:],axis=1)
# make it 3d
   dy=xp.repeat(dy[:,:,None],repeats=nk,axis=2)

   # l_iddes = min(max(l_dist, l_max, dy), delta_max), built in steps
   l_temp=xp.maximum(l_dist,l_max)
   l_temp=xp.maximum(l_temp,dy)
   l_iddes=xp.minimum(l_temp,delta_max)
#  l_iddes=xp.minimum(xp.maximum(l_dist,l_max,dy),delta_max)

   # damping functions built from ystar=(eps*viscos)**0.25*dist/viscos and rt=k**2/(eps*viscos)
   ueps=(eps3d*viscos)**0.25
   ystar=ueps*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   fdampf2=((1.-xp.exp(-ystar/3.1))**2)*(1.-0.3*xp.exp(-(rt/6.5)**2))
   fmu=((1.-xp.exp(-ystar/14.))**2)*(1.+5./rt**0.75*xp.exp(-(rt/200.)**2))
   fmu=xp.minimum(fmu,1.)

   # psi is capped at 10
   psi=xp.minimum(10,(fdampf2*fmu)**(-0.75))

   l_c=psi*cdes*l_iddes  #eq. 9

   # vis3d minus viscos (presumably the turbulent part of the viscosity -- TODO confirm)
   vist=vis3d-viscos
   # NOTE(review): gen is neither a parameter nor assigned in this function;
   # it must exist at module scope when this line runs -- confirm against caller.
   denom=kappa**2*dist3d**2*gen**0.5

   r_dt=vist/denom  #eq. 22
   r_dl=viscos/denom  #eq. 23

   f_t=xp.tanh((c_t**2*r_dt)**3)
   f_l=xp.tanh((c_l**2*r_dl)**10)

   f_e2=1.-xp.maximum(f_t,f_l) #eq. 19

   alpha=0.25-dist3d/delta_max

   # element-wise form of the if/else kept in the comment below
   f_e1=xp.where(alpha <= 0,2*xp.exp(-9*alpha**2),2*xp.exp(-11.09*alpha**2))
#  if alpha <= 0:
#     f_e1=2.*xp.exp(-9*alpha**2)
#  else:
#     f_e1=2.*xp.exp(-11.09*alpha**2)

   f_b=  xp.minimum(2.*xp.exp(-9*alpha**2),1.)

   f_dt=1.-xp.tanh((8.*r_dt)**3)

   # f_e only feeds the statistics below; it is not part of l_tilde
   # (the variant that used it is the commented-out line further down)
   f_e=xp.maximum(f_e1-1.,0.)*psi*f_e2

   f_d=xp.maximum((1.-f_dt),f_b)

   l_u=k3d**1.5/eps3d

#  l_tilde=f_d*(1+f_e)*l_u+(1-f_d)*l_c
   # blend of l_u and l_c weighted by f_d
   l_tilde=f_d*l_u+(1-f_d)*l_c

   fk3d=l_u/l_tilde
#  fk3d=xp.minimum(fk3d,1.9/1.5)

   # NOTE(review): iter is not assigned in this function, so it is looked up at
   # module scope (presumably an iteration counter that shadows the builtin) -- confirm.
   if iter == 0 and itstep ==0:
      # first call: create the running sums (one value per j index)
      f_e_mean=xp.zeros(nj)
      f_e1_mean=xp.zeros(nj)
      f_e2_mean=xp.zeros(nj)
      f_d_mean=xp.zeros(nj)
      f_dt_mean=xp.zeros(nj)
      # NOTE(review): duplicate of the line above
      f_dt_mean=xp.zeros(nj)
      f_b_mean=xp.zeros(nj)
      l_c_mean=xp.zeros(nj)
      l_tilde_mean=xp.zeros(nj)
      denom_mean=xp.zeros(nj)

   if iter == 0 and itstep%itstep_stats == 0 and itstep >= itstep_start:
      # the unpacking order follows aver_iddes' return order, not its parameter order
      f_e_mean,f_d_mean,f_dt_mean,f_b_mean,denom_mean,f_e1_mean,f_e2_mean,l_c_mean,l_tilde_mean=\
       aver_iddes(denom,f_e,f_e1,f_e2,f_d,f_dt,f_b,l_c,l_tilde,\
       denom_mean,f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_b_mean,l_c_mean,l_tilde_mean)

   if iter == 0 and (itstep == ntstep-1 or itstep%itstep_save == 0):
      # write every accumulator to disk
      print('IDDES functions saved')
      xp.save('f_e_mean', f_e_mean)
      xp.save('f_d_mean', f_d_mean)
      xp.save('l_c_mean', l_c_mean)
      xp.save('f_dt_mean',f_dt_mean)
      xp.save('f_b_mean',f_b_mean)
      xp.save('f_e1_mean',f_e1_mean)
      xp.save('f_e2_mean',f_e2_mean)
      xp.save('l_tilde_mean',l_tilde_mean)
      xp.save('denom_mean',denom_mean)

   return fk3d

def aver_iddes(denom,f_e,f_e1,f_e2,f_d,f_dt,f_b,l_c,l_tilde,\
    denom_mean,f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_b_mean,l_c_mean,l_tilde_mean):
   """Add the current axis-(0,2) mean of each field to its running sum.

   New arrays are returned; the incoming ``*_mean`` arrays are not modified
   in place.  Note that the return order differs from the parameter order:
   ``f_e, f_d, f_dt, f_b, denom, f_e1, f_e2, l_c, l_tilde``.
   """

   def _accumulate(running_sum, field):
      # averaging over the first and last axes leaves one value per index of axis 1
      return running_sum + xp.mean(field, axis=(0, 2))

   sum_f_e = _accumulate(f_e_mean, f_e)
   sum_f_e1 = _accumulate(f_e1_mean, f_e1)
   sum_f_e2 = _accumulate(f_e2_mean, f_e2)
   sum_f_d = _accumulate(f_d_mean, f_d)
   sum_l_c = _accumulate(l_c_mean, l_c)
   sum_l_tilde = _accumulate(l_tilde_mean, l_tilde)
   sum_f_dt = _accumulate(f_dt_mean, f_dt)
   sum_f_b = _accumulate(f_b_mean, f_b)
   sum_denom = _accumulate(denom_mean, denom)

   return (sum_f_e, sum_f_d, sum_f_dt, sum_f_b, sum_denom,
           sum_f_e1, sum_f_e2, sum_l_c, sum_l_tilde)



def fix_k():
    """Pin k in the j=0 and j=-1 cells to a value derived from a neural-network wall law.

    Steps, as coded:
      1. When iter == 0 and itstep == 0: load the pickled network, three
         scalers and eight min/max limits from '../pytorch-diffusor/', and
         initialise the module-level running sums, ustar_old and ubulk.
      2. Build y+, p+ and a scaled du/dy for the j=0 cells, print diagnostics
         and clip each quantity to its loaded [min, max] range.
      3. Feed the scaled y+ (the only active input; the other two are
         commented out) to the network, clip the predicted u+ and form
         ustar = |u|/u+.
      4. Zero the neighbour coefficients at j=0 and j=-1 and set
         ap3d = max(ap3d), su3d = ap3d*kwall there, with
         kwall = cmu**(-0.5)*ustar**2 (the same kwall on both walls).
      5. When iter == 0 and itstep > itstep_start: accumulate ustar and write
         'ustar-vs-x.dat'.

    Takes no arguments; every array it reads or modifies (u3d, dist3d, ap3d,
    su3d, ...) comes from module scope.  Returns the module-level arrays
    ``aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d, sp3d``.
    """

    # NOTE(review): dump, nn, optim, TensorDataset and DataLoader are imported
    # here but not referenced anywhere in this function.
    from joblib import dump, load
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import TensorDataset, DataLoader

    global yplus_min,yplus_max,pplus_min,pplus_max,scaler_yplus,scaler_pplus,neural_net,ubulk,u_wall_mean,uplus_min,uplus_max,u_wall_mean_2nd
    global ustar_mean_s,dpdx_ustar_corr,counter_ustar,dpdx_mean,ustar_mean_p,ustar_old,count,u_mean,scaler_dudy,dudy_min,dudy_max



# NN


# load data (first call only)
    if iter == 0 and itstep == 0:
       folder = '../pytorch-diffusor/'
       filename=str(folder)+'model-channel-5200-only-yplus-IDDES.pth'
       # NOTE(review): torch.load unpickles a complete model object here, so the
       # model's class must be importable and the file must be trusted.  Recent
       # PyTorch releases changed the default of weights_only -- TODO confirm
       # this call still works with the installed version.
       neural_net = torch.load(filename)
       print('model',neural_net)
       scaler_yplus = load(str(folder)+'scaler-yplus-channel-5200-only-yplus-IDDES.bin')
       scaler_pplus = load(str(folder)+'scaler-pplus-channel-5200-only-yplus-IDDES.bin') # dummy
       scaler_dudy = load(str(folder)+'scaler-dudy-channel-5200-only-yplus-IDDES.bin') # dummy

       # the text file must hold exactly eight values, in this order
       [yplus_min, yplus_max, pplus_min, pplus_max, dudy_min, dudy_max, uplus_min, uplus_max] = \
       xp.loadtxt(str(folder)+'min-max-model-channel-5200-only-yplus-IDDES.txt')

       # NOTE(review): the label says re_min,re_max but the values printed in
       # those positions are dudy_min,dudy_max.
       print('yplus_min,yplus_max,pplus_min,pplus_max,re_min,re_max, uplus_min, uplus_max',\
              yplus_min,yplus_max,pplus_min,pplus_max,dudy_min,dudy_max, uplus_min, uplus_max)

       ustar_mean_s = xp.zeros(ni)
       ustar_mean_p = xp.zeros(ni)
       u2d_wall=abs(u3d[:,0,:])  # first cell
       u2d_wall_2nd=abs(u3d[:,1,:])  # 2nd cell
       u_mean_k=xp.mean(u2d_wall,axis=1)
       u_mean_k_2nd=xp.mean(u2d_wall_2nd,axis=1)
       # NOTE(review): the running sums start from the current sample and the
       # same sample is added again further down in this call (iter == 0), so
       # the first sample is counted twice.  Same for dpdx_mean.
       u_wall_mean = u_mean_k
       u_wall_mean_2nd = u_mean_k_2nd
       dpdx = dphidx(p3d_face_w,p3d_face_s)
       dpdx_mean_k =xp.mean(dpdx[:,0,:],axis=1)
       dpdx_mean = dpdx_mean_k
       # NOTE(review): ustar_old is not assigned anywhere else in this function.
       ustar_old = xp.ones((ni,nk))
       # velocity profile at i=0 (mean over the last axis), used for ubulk
       u1d = xp.mean(u3d[0,:,:],axis=1)
#      ubulk = xp.trapz(u1d,yp2d[0,:])/max(y2d[0,:])
       ubulk= xp.sum(u1d*xp.diff(y2d[0,:]))/max(y2d[0,:])
       print('ubulk',ubulk)


# take old ustar
    ustar=ustar_old
# average ustar to be used in p+
    # NOTE(review): this sum grows on every call (it is not guarded by
    # iter == 0) yet is divided by itstep+1; ustar_p is not used below.
    ustar_mean_k=xp.mean(ustar,axis=1)
    ustar_mean_p=ustar_mean_p + ustar_mean_k
    ustar_p = ustar_mean_p/(itstep+1)
# make it 2D
    ustar_p=xp.repeat(ustar_p[:,None], repeats=nk, axis=1)

    # running average of dpdx in the j=0 cells (mean over the last axis)
    dpdx = dphidx(p3d_face_w,p3d_face_s)
    dpdx_mean_k =xp.mean(dpdx[:,0,:],axis=1)
    if iter == 0:
       dpdx_mean=dpdx_mean + dpdx_mean_k
    dpdx_p=dpdx_mean/(itstep+1)
# make it 2D
    dpdx_p=xp.repeat(dpdx_p[:,None], repeats=nk, axis=1)

    # running average of |u| in the first and second cell from j=0
    u2d_wall=abs(u3d[:,0,:])  # first cell
    u2d_wall_2nd=abs(u3d[:,1,:])  # 2nd cell
    u_mean_k=xp.mean(u2d_wall,axis=1)
    u_mean_k_2nd=xp.mean(u2d_wall_2nd,axis=1)
    if iter == 0:
       u_wall_mean=u_wall_mean + u_mean_k
       u_wall_mean_2nd=u_wall_mean_2nd + u_mean_k_2nd
    u_mean=u_wall_mean/(itstep+1)
    u_mean_2nd=u_wall_mean_2nd/(itstep+1)

    dy=dist3d[:,0,:] # first cell
    yplus_south = ustar*dy/viscos
    # velocity difference between the 2nd and 1st cell over their yp2d spacing, scaled by viscos/ubulk**2
    dudy_south=(u_mean_2nd - u_mean)/(yp2d[:,1] - yp2d[:,0])*viscos/ubulk**2

# make it 2D
    dudy_south=xp.repeat(dudy_south[:,None], repeats=nk, axis=1)

    pplus_south = viscos*dpdx_p/ubulk**3
#   if iter == 0 and itstep % 1 == 0:
    if iter == 0 and itstep % 100 == 0:
       print('dpdx_p[0:20,0],yplus_south[0,0]',dpdx_p[0:20,0],yplus_south[0,0])
       print('yplus_south[0:20,0]',yplus_south[0:20,0])
       print('pplus_south[0:20,0]',pplus_south[0:20,0])
       print('dudy_south[0:20,0]',dudy_south[0:20,0])
#   pplus_south = -xp.ones((ni,nk))
    # uplus_south is only used for the out-of-range counts printed further down;
    # ustar_south is not used afterwards
    uplus_south = u2d_wall/ustar
    ustar_south = ustar

# count values larger/smaller than max/min
    yplus_min_number= (yplus_south < yplus_min).sum()
    yplus_max_number= (yplus_south > yplus_max).sum()
    print('south: yplus_min_number',yplus_min_number)
    print('south: yplus_max_number',yplus_max_number)

    pplus_min_number= (pplus_south < pplus_min).sum()
    pplus_max_number= (pplus_south > pplus_max).sum()
    print('south: pplus_min_number',pplus_min_number)
    print('south: pplus_max_number',pplus_max_number)

    dudy_min_number= (dudy_south < dudy_min).sum()
    dudy_max_number= (dudy_south > dudy_max).sum()
    print('south: dudy_min_number',dudy_min_number)
    print('south: dudy_max_number',dudy_max_number)

    ustar_min =xp.min(ustar)
    ustar_max =xp.max(ustar)
    print('dy,ustar_min,max,mean',dy[0,0],ustar_min,ustar_max,xp.mean(ustar))
    print('pplus, min,max,mean',xp.min(pplus_south),\
          xp.max(pplus_south),xp.mean(pplus_south))
    print('yplus, min,max,mean',xp.min(yplus_south),\
          xp.max(yplus_south),xp.mean(yplus_south))

# set limits
    yplus_south=xp.minimum(yplus_south,yplus_max)
    yplus_south=xp.maximum(yplus_south,yplus_min)

    pplus_south=xp.minimum(pplus_south,pplus_max)
    pplus_south=xp.maximum(pplus_south,pplus_min)

    dudy_south=xp.minimum(dudy_south,dudy_max)
    dudy_south=xp.maximum(dudy_south,dudy_min)

    print('u2d_wall, min, max, mean',xp.min(u2d_wall),xp.max(u2d_wall),xp.mean(u2d_wall))

    # NOTE(review): N_predict and y are assigned but not used below
    N_predict=ni*nk
    y = uplus_south
    y = y.reshape(-1,1)
    # column vectors, one row per (i,k) cell
    yplus = yplus_south.reshape(-1,1)
    pplus = pplus_south.reshape(-1,1)
    dudy = dudy_south.reshape(-1,1)

    # single-feature input matrix (only y+ is filled in below)
    X=xp.zeros((len(yplus),1))

    # the scalers and the network are fed host arrays: convert with xp.asnumpy when gpu is set
    if gpu:
      yplus_np = xp.asnumpy(yplus)
      pplus_np = xp.asnumpy(pplus)
      dudy_np = xp.asnumpy(dudy)
      X_np = xp.asnumpy(X)
    else:
      yplus_np = yplus
      pplus_np = pplus
      dudy_np = dudy
      X_np = X

    X_np[:,0] = scaler_yplus.transform(yplus_np)[:,0]
#   X_np[:,1] = scaler_pplus.transform(pplus_np)[:,0]
#   X_np[:,2] = scaler_dudy.transform(dudy_np)[:,0]

    X_tensor = torch.tensor(X_np, dtype=torch.float32)

    preds = neural_net(X_tensor)

    # first output column of the network is taken as u+
    uplus_NN = preds.detach().numpy()
    uplus_NN = uplus_NN[:,0]
    print('type(uplus_NN)',type(uplus_NN))

    if gpu:
       uplus_NN = xp.asarray(uplus_NN)

    uplus_min_number= (uplus_south < uplus_min).sum()
    uplus_max_number= (uplus_south > uplus_max).sum()
    print('south: uplus_min_number',uplus_min_number)
    print('south: uplus_max_number',uplus_max_number)


    # clip the prediction to the loaded u+ range
    uplus_NN=xp.minimum(uplus_NN,uplus_max)
    uplus_NN=xp.maximum(uplus_NN,uplus_min)

    uplus_predict=xp.reshape(uplus_NN,(ni,nk))

    if iter == 0 and itstep % 100 == 0:
       print('uplus_predict[0:20,0]',uplus_predict[0:20,0])

    ustar=xp.divide(u2d_wall,uplus_predict)  # ustar = u2d_wall/uplus


    kwall=cmu**(-0.5)*ustar**2

# fix k at 1st cell
    # neighbour coefficients zeroed; ap and su set from the global maximum of ap3d
    # (presumably so the solver returns k = kwall in these cells -- TODO confirm)
    aw3d[:,0,:]=0
    ae3d[:,0,:]=0
    as3d[:,0,:]=0
    an3d[:,0,:]=0
    al3d[:,0,:]=0
    ah3d[:,0,:]=0
    ap_max=xp.max(ap3d)
    ap3d[:,0,:]=ap_max
    su3d[:,0,:]=ap_max*kwall

# north wall
    # same kwall as above (computed from the j=0 cells)
    aw3d[:,-1,:]=0
    ae3d[:,-1,:]=0
    as3d[:,-1,:]=0
    an3d[:,-1,:]=0
    al3d[:,-1,:]=0
    ah3d[:,-1,:]=0
    ap_max=xp.max(ap3d)
    ap3d[:,-1,:]=ap_max
    su3d[:,-1,:]=ap_max*kwall



    if iter == 0:
      if itstep == 0:
         ustar_mean_s = xp.zeros(ni)
         count = 0
      if itstep > itstep_start:
         count = count +1
         ustar_mean_k=xp.mean(ustar,axis=1)
         ustar_mean_old = ustar_mean_s
         ustar_mean_s=ustar_mean_s + ustar_mean_k
         print('ustar_mean_old,ustar_mean_s[0],ustar_mean_s[0]/step',ustar_mean_old[0],ustar_mean_s[0],ustar_mean_s[0]/count)
         # NOTE(review): the next two lines overwrite the last entry of the
         # running sum dpdx_mean and one entry of xp2d so that ntstep and count
         # end up in the output file; both arrays keep the overwritten values
         # on later calls.
         dpdx_mean[-1] = ntstep
         xp2d[-1,1] = count
#        if itstep%100 == 0:
         # itstep%1 is always 0, so the file is rewritten on every qualifying time step
         if itstep%1 == 0:
            xp.savetxt('ustar-vs-x.dat', xp.c_[xp2d[:,1],ustar_mean_s,dpdx_mean,u_mean,u_mean_2nd,dudy_south[:,0]])

    return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d


def fix_eps():
   """Pin the solved variable in the j=0 and j=-1 cells to ``ustar**3/(kappa*dy)``.

   For each of the two walls the friction velocity is taken as
   ``cmu**0.25*sqrt(k3d)`` in that wall's cells, the six neighbour
   coefficients are zeroed, and ``ap3d``/``su3d`` are overwritten using the
   current maximum of ``ap3d``.  All arrays and constants come from module
   scope and are modified in place.

   Returns ``aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d, sp3d``.
   """
   # j=0 is the south wall, j=-1 the north wall; the south wall is handled first
   for j_wall in (0, -1):
      wall_distance = dist3d[0, j_wall, 0]   # taken from one cell only
      friction_vel = cmu**0.25*k3d[:, j_wall, :]**0.5   # k is fixed in this cell

      # cut the cell off from its neighbours
      for neighbour_coeff in (aw3d, ae3d, as3d, an3d, al3d, ah3d):
         neighbour_coeff[:, j_wall, :] = 0

      # the maximum is re-evaluated for each wall, after the previous wall was written
      largest_ap = xp.max(ap3d)
      ap3d[:, j_wall, :] = largest_ap
      su3d[:, j_wall, :] = largest_ap*friction_vel**3/kappa/wall_distance

   return aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d, sp3d

def modify_vis(vis3d):
   """Pass-through: return ``vis3d`` exactly as received."""
   unchanged_viscosity = vis3d
   return unchanged_viscosity

def solve_ustar_reich(ustar,velabs,wdist):
    """Residual ``ustar - velabs/uplus`` with ``uplus`` from Reichardt's wall law.

    Reichardt's law:
    ``u+ = ln(1+0.4*y+)/0.4 + 7.8*(1 - exp(-y+/11) - (y+/11)*exp(-y+/3))``.
    The product ``ustar*wdist`` is used as y+, so ``wdist`` must already be
    expressed in viscous-scaled form (wall distance divided by the
    viscosity) for the law to be evaluated at the physical y+.

    Parameters
    ----------
    ustar : float or array
        Trial friction velocity (the unknown of the root search).
    velabs : float or array
        Velocity magnitude at the point where the law is applied.
    wdist : float or array
        Wall distance, see the note on scaling above.

    Returns
    -------
    Residual with the broadcast shape of the inputs; zero where ``ustar``
    satisfies the wall law.  ``ustar*wdist == 0`` gives ``uplus == 0`` and
    hence a division by zero.
    """
    # NumPy is imported locally because this function is evaluated on host arrays
    import numpy as xp

    yplus = ustar*wdist
    # The last term belongs inside the 7.8 bracket with a 1/11 factor; the
    # previous code subtracted yplus*exp(-yplus/3) with unit weight instead.
    uplus = 1/0.4*xp.log(1+0.4*yplus) \
            + 7.8*(1-xp.exp(-yplus/11)-yplus/11*xp.exp(-yplus/3))

    return ustar-velabs/uplus

def compute_ustar_reich_solve(wdist,velabs,ustar):
   """Solve the Reichardt wall-law residual for the friction velocity with Newton's method.

   Parameters
   ----------
   wdist : array
       Wall distance (dimensional); it is divided by the module-level
       ``viscos`` before being handed to the residual.
   velabs : array
       Velocity magnitude at the same points.
   ustar : array
       Initial guess for the friction velocity.

   Returns
   -------
   The converged friction velocity as an ``xp`` array reshaped to
   ``(ni, nk)``; the inputs must therefore hold ``ni*nk`` values.

   ``viscos``, ``gpu``, ``xp``, ``ni`` and ``nk`` are read from module scope.
   """
   from scipy.optimize import newton

   ustar=ustar.flatten()
   velabs=velabs.flatten()
   # solve_ustar_reich evaluates the wall law at ustar*wdist.  Dividing the
   # wall distance by the viscosity makes that product y+ = ustar*wdist/viscos;
   # previously y+ was computed here but never used and the raw distance was
   # passed on.
   wdist=(wdist/viscos).flatten()
   if gpu:
      # newton is run on NumPy arrays
      velabs_np = xp.asnumpy(velabs)
      ustar_np = xp.asnumpy(ustar)
      wdist_np = xp.asnumpy(wdist)
   else:
      velabs_np = velabs
      ustar_np = ustar
      wdist_np = wdist
   # array x0: every point is iterated as an independent scalar problem
   ustar = newton(solve_ustar_reich,x0=ustar_np,args=(velabs_np,wdist_np))
   ustar = xp.asarray(ustar)
   ustar=xp.reshape(ustar,(ni,nk))

   return ustar


