def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Hook for customising the initial fields; nothing is changed here.

   Returns the seven fields in the order they were received, followed by
   ``dist3d``. ``dist3d`` is not a parameter: it is looked up in the
   enclosing (module-level) scope when the function runs.
   """
   unchanged_fields=(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d)
   return unchanged_fields+(dist3d,)

def modify_inlet():
   """Hook for customising the west (inlet) boundary; nothing is changed here.

   Every returned name is read from the enclosing (module-level) scope and
   handed back as-is, in the order: ``u_bc_west``, ``v_bc_west``,
   ``w_bc_west``, ``k_bc_west``, ``eps_bc_west``, ``om_bc_west``,
   ``u3d_face_w``, ``convw``.
   """
   inlet_state=(u_bc_west,v_bc_west,w_bc_west,
                k_bc_west,eps_bc_west,om_bc_west,
                u3d_face_w,convw)
   return inlet_state

def modify_conv(convw,convs,convl):
   """Zero ``convs`` at the first and last index of its second axis.

   ``convs`` is modified in place; ``convw`` and ``convl`` are returned
   untouched. The same three objects are returned in the order received.
   """
# first and last entry along axis 1
   for j_face in (0,-1):
      convs[:,j_face,:]=0

   return convw,convs,convl

def modify_u(su3d,sp3d):
   """Add the driving force and the wall shear stress to the u source term.

   ``vol``, ``cmu``, ``k3d`` and ``areas`` are not parameters; they are read
   from the enclosing (module-level) scope.

   Steps:
     * ``vol`` is added to ``su3d`` (a new array is created, the caller's
       array is not modified in place).
     * For the first and the last row along axis 1 a friction velocity
       ``ustar = cmu**0.25*sqrt(k)`` is formed from ``k3d`` in that row and
       ``ustar**2`` times the wall face area is subtracted from ``su3d``.

   Returns ``(su3d, sp3d)``; ``sp3d`` is passed through unchanged.
   """

# add driving force
   su3d=su3d+vol

# south wall
   ustar=cmu**0.25*k3d[:,0,:]**0.5 # k is fixed at first cell
   tauw=ustar**2
   su3d[:,0,:]=su3d[:,0,:]-tauw*areas[:,0,:]

# north wall
   ustar=cmu**0.25*k3d[:,-1,:]**0.5
   print('modify_u, ustar[-2,-2],[2,2]',ustar[-2,-2],ustar[2,2])
   tauw=ustar**2
# NOTE(review): this used areas[:,0,:] (the south-wall index) although the
# source is written to the north row; the north-wall index is used now.
# Assumes the last index of axis 1 of `areas` is the north wall face -
# confirm against the array layout. On a grid whose face areas do not vary
# along axis 1 the two indices give identical values.
   su3d[:,-1,:]=su3d[:,-1,:]-tauw*areas[:,-1,:]

   return su3d,sp3d

def modify_v(su3d,sp3d):
   """Hook for extra v-equation source terms; none are added here.

   Returns ``(su3d, sp3d)`` exactly as received.
   """
   return (su3d,sp3d)

def modify_w(su3d,sp3d):
   """Hook for extra w-equation source terms; none are added here.

   Returns ``(su3d, sp3d)`` exactly as received.
   """
   return (su3d,sp3d)

def modify_k(su3d,sp3d,gen):
   """Hook for extra k-equation source terms; none are added here.

   ``gen`` is accepted but not used. ``ni``, ``nj`` and ``nk`` are read
   from the enclosing (module-level) scope.

   Returns ``(su3d, sp3d, comm_term)`` where ``comm_term`` is a freshly
   allocated array of zeros with shape ``(ni, nj, nk)``.
   """
   field_shape=(ni,nj,nk)
   comm_term=np.zeros(field_shape)

   return su3d,sp3d,comm_term

def modify_eps(su3d,sp3d):
   """Hook for extra eps-equation source terms; none are added here.

   Returns ``(su3d, sp3d)`` exactly as received.
   """
   return (su3d,sp3d)

def modify_om(su3d,sp3d,comm_term):
   """Hook for extra omega-equation source terms; none are added here.

   ``comm_term`` is accepted but not used. Returns ``(su3d, sp3d)`` exactly
   as received.
   """
   return (su3d,sp3d)

def modify_fk(fk3d):
   """Compute the ratio ``fk3d = l_u/l_tilde`` from the IDDES length scales.

   The incoming ``fk3d`` is not read; it is overwritten and returned.

   Besides ``np`` the following names are read from the enclosing
   (module-level) scope: ``dist3d``, ``delta_max``, ``y2d``, ``nk``, ``nj``,
   ``eps3d``, ``k3d``, ``vis3d``, ``viscos``, ``cdes``, ``kappa``, ``gen``,
   ``c_t``, ``c_l``, ``iter``, ``itstep``, ``itstep_stats``,
   ``itstep_start``, ``itstep_save`` and ``ntstep``. ``iter`` is expected to
   be a counter defined by the calling code (it hides the builtin of the
   same name).

   Side effects:
     * the module-level ``*_mean`` accumulators are created when
       ``iter == 0 and itstep == 0`` and updated through ``aver_iddes`` on
       sampling steps (they hold running sums; no division by the number of
       samples is done here);
     * the accumulators are written with ``np.save`` on save steps.

   The "eq." numbers in the comments are carried over from the original
   comments; the paper they refer to is not identified in this file.
   """

   global f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_b_mean
   global denom_mean,l_c_mean,l_tilde_mean

# cell size along axis 1, taken from y2d and repeated nk times along axis 2
   step_j=np.diff(y2d[1:,:],axis=1)
   step_j=np.repeat(step_j[:,:,None],repeats=nk,axis=2)

# grid length scale: max(0.15*dist, 0.15*delta_max, step_j) capped by delta_max
   l_iddes=np.maximum(np.maximum(0.15*dist3d,0.15*delta_max),step_j)
   l_iddes=np.minimum(l_iddes,delta_max)

# damping functions built from k, eps and the wall distance
   y_star=(eps3d*viscos)**0.25*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   f_2=((1.-np.exp(-y_star/3.1))**2)*(1.-0.3*np.exp(-(rt/6.5)**2))
   f_mu=np.minimum(((1.-np.exp(-y_star/14.))**2)*(1.+5./rt**0.75*np.exp(-(rt/200.)**2)),1.)

   psi=np.minimum(10,(f_2*f_mu)**(-0.75))

   l_c=psi*cdes*l_iddes  #eq. 9

   denom=kappa**2*dist3d**2*gen**0.5

   r_dt=(vis3d-viscos)/denom  #eq. 22
   r_dl=viscos/denom  #eq. 23

   f_e2=1.-np.maximum(np.tanh((c_t**2*r_dt)**3),np.tanh((c_l**2*r_dl)**10)) #eq. 19

   alpha=0.25-dist3d/delta_max

# both branches are evaluated everywhere; np.where selects per cell
   f_e1=np.where(alpha <= 0,2*np.exp(-9*alpha**2),2*np.exp(-11.09*alpha**2))

   f_b=np.minimum(2.*np.exp(-9*alpha**2),1.)

   f_dt=1.-np.tanh((8.*r_dt)**3)

# f_e only enters the statistics below; l_tilde is formed without it
   f_e=np.maximum(f_e1-1.,0.)*psi*f_e2

   f_d=np.maximum((1.-f_dt),f_b)

   l_u=k3d**1.5/eps3d

   l_tilde=f_d*l_u+(1-f_d)*l_c

   fk3d=l_u/l_tilde

# create the accumulators at the very first iteration of the first time step
   if iter == 0 and itstep ==0:
      f_e_mean=np.zeros(nj)
      f_e1_mean=np.zeros(nj)
      f_e2_mean=np.zeros(nj)
      f_d_mean=np.zeros(nj)
      f_dt_mean=np.zeros(nj)
      f_b_mean=np.zeros(nj)
      l_c_mean=np.zeros(nj)
      l_tilde_mean=np.zeros(nj)
      denom_mean=np.zeros(nj)

# sample on every itstep_stats-th time step once itstep_start is reached
   sampling_step=iter == 0 and itstep%itstep_stats == 0 and itstep >= itstep_start
   if sampling_step:
      (f_e_mean,f_d_mean,f_dt_mean,f_b_mean,denom_mean,
       f_e1_mean,f_e2_mean,l_c_mean,l_tilde_mean)=aver_iddes(
          denom,f_e,f_e1,f_e2,f_d,f_dt,f_b,l_c,l_tilde,
          denom_mean,f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,
          f_b_mean,l_c_mean,l_tilde_mean)

# write the accumulators at the last time step and on every itstep_save-th step
   if iter == 0 and (itstep == ntstep-1 or itstep%itstep_save == 0):
      print('IDDES functions saved')
      to_save=(('f_e_mean',f_e_mean),
               ('f_d_mean',f_d_mean),
               ('l_c_mean',l_c_mean),
               ('f_dt_mean',f_dt_mean),
               ('f_b_mean',f_b_mean),
               ('f_e1_mean',f_e1_mean),
               ('f_e2_mean',f_e2_mean),
               ('l_tilde_mean',l_tilde_mean),
               ('denom_mean',denom_mean))
      for file_stem,profile in to_save:
         np.save(file_stem,profile)

   return fk3d

def aver_iddes(denom,f_e,f_e1,f_e2,f_d,f_dt,f_b,l_c,l_tilde,\
    denom_mean,f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_b_mean,l_c_mean,l_tilde_mean):
   """Add the axis-(0,2) mean of each field to its running accumulator.

   Each ``*_mean`` argument is combined with ``np.mean(field, axis=(0,2))``
   of the matching field; new arrays are returned and the inputs are not
   modified in place. No normalisation by a sample count is done.

   Note that the return order differs from the parameter order:
   ``f_e_mean, f_d_mean, f_dt_mean, f_b_mean, denom_mean, f_e1_mean,
   f_e2_mean, l_c_mean, l_tilde_mean``.
   """

   def add_sample(running_sum,field):
      return running_sum+np.mean(field,axis=(0,2))

   return (add_sample(f_e_mean,f_e),
           add_sample(f_d_mean,f_d),
           add_sample(f_dt_mean,f_dt),
           add_sample(f_b_mean,f_b),
           add_sample(denom_mean,denom),
           add_sample(f_e1_mean,f_e1),
           add_sample(f_e2_mean,f_e2),
           add_sample(l_c_mean,l_c),
           add_sample(l_tilde_mean,l_tilde))



def modify_outlet(convw):
   """Shift the last-plane flux of ``convw`` so that it sums to the first-plane flux.

   The difference between the summed flux through ``convw[0,:,:]`` and
   ``convw[-1,:,:]`` is divided by the summed area ``areaw[-1,:,:]`` and the
   resulting increment, times the local face area, is added to
   ``convw[-1,:,:]`` in place.

   ``areaw`` and ``u_bc_east`` are read from the enclosing (module-level)
   scope. Returns ``(convw, u_bc_east)``; ``u_bc_east`` is not modified.
   """

   outlet_face_area=areaw[-1,:,:]

# flux through the first and the last plane before the correction
   flow_in=np.sum(convw[0,:,:])
   flow_out=np.sum(convw[-1,:,:])
   area_out=np.sum(outlet_face_area)

# uniform increment that removes the imbalance
   uinc=(flow_in-flow_out)/area_out
   convw[-1,:,:]=convw[-1,:,:]+uinc*outlet_face_area

   print('area_out',area_out)

   flow_out_new=np.sum(convw[-1,:,:])

   print('flow_in',flow_in,'flow_out',flow_out,'area_out',area_out,'flow_out_new',flow_out_new,'uinc:',uinc)

   return convw,u_bc_east

def modify_vis(vis3d):
   """Hook for customising the viscosity field; it is returned unchanged."""
   return vis3d

def _ml_wall_ustar(wall,ustar_prev,u2d_wall,dy_wall):
   """Predict the friction velocity in the wall-adjacent cells of one wall.

   wall       -- 'south' or 'north'; only used to label the printed output
   ustar_prev -- (ni,nk) friction velocity used to form yplus
   u2d_wall   -- (ni,nk) velocity magnitude in the wall-adjacent cells
   dy_wall    -- wall distance of the wall-adjacent cell (scalar)

   Reads ``viscos``, ``yplus_min``, ``yplus_max``, ``scaler_yplus``,
   ``model``, ``gpu``, ``ni`` and ``nk`` from the module scope.

   Returns ``(ustar, uplus)``, both reshaped to ``(ni,nk)``.
   """
# single pass; with a larger count the new ustar is fed back into yplus
   for n in range(0,1):

      yplus=(ustar_prev*dy_wall/viscos).flatten()

      print(wall+' n,yplus_min,max,mean',n,np.min(yplus),np.max(yplus),np.mean(yplus))

# count values larger/smaller than max/min
      yplus_min_number= (yplus < yplus_min).sum()
      yplus_max_number= (yplus > yplus_max).sum()

      print(wall+': yplus_min_number',yplus_min_number)
      print(wall+': yplus_max_number',yplus_max_number)

# clip to [yplus_min, yplus_max] and make it a column vector
      yplus=np.maximum(np.minimum(yplus,yplus_max),yplus_min)
      yplus=yplus.reshape(-1, 1)

# scaler and model work on host arrays; np.asnumpy is only called when gpu is set
      yplus_host=np.asnumpy(yplus) if gpu else yplus

# NOTE(review): the scaled values are now passed to the model. Previously
# the result of transform() was stored but the model input was filled from
# the unscaled array, so the scaler had no effect. This assumes the model
# was trained on scaled yplus - confirm against the training script.
      features=scaler_yplus.transform(yplus_host)

      uplus_host=model.predict(features) # uplus

      uplus=np.asarray(uplus_host) if gpu else uplus_host
      uplus=np.reshape(uplus,(ni,nk))

      ustar=np.divide(u2d_wall,uplus)
      ustar_prev=ustar

      print('n,'+wall+' ustar_'+wall+',max,mean',n,np.min(ustar),np.max(ustar),np.mean(ustar))

   return ustar,uplus


def _pin_k_wall_row(j,kwall):
   """Force the solution in row ``j`` (axis 1) to ``kwall``.

   All neighbour coefficients of the row are zeroed, ``ap3d`` is set to the
   current maximum of ``ap3d`` and ``su3d`` to that maximum times ``kwall``.
   The module-level coefficient arrays are modified in place.
   """
   for coeff in (aw3d,ae3d,as3d,an3d,al3d,ah3d):
      coeff[:,j,:]=0
   ap_max=np.max(ap3d)
   ap3d[:,j,:]=ap_max
   su3d[:,j,:]=ap_max*kwall


def fix_k():
   """Fix k in the wall-adjacent cells using a machine-learned wall function.

   For the first (south) and last (north) row along axis 1, yplus is formed
   from the friction velocity of the previous call, a loaded regression
   model predicts uplus, ``ustar = |u|/uplus`` follows, and k in that row is
   forced to ``cmu**(-0.5)*ustar**2`` through the discretised coefficients.

   Everything except local temporaries lives in the module scope: the
   coefficient arrays (``aw3d`` ... ``ah3d``, ``ap3d``, ``su3d``, ``sp3d``),
   ``u3d``, ``dist3d``, ``viscos``, ``cmu``, ``ni``, ``nk``, ``gpu``,
   ``iter``, ``itstep`` and the ``time`` module are read from there, and the
   names listed in the ``global`` statements are (re)bound there.

   Side effects:
     * at ``itstep == 0 and iter == 0`` the model, the scaler and the yplus
       limits are loaded from hard-coded file paths;
     * ``ustar_south``/``ustar_north`` are updated on every call;
     * at ``iter == 0`` a line is written to 'ustar-yplus-history.dat' and,
       every 100th time step, the running sums to 'ustar2-ustar.dat'.
       The running sums and the standard deviation use the north-wall
       values only; the two means cover one wall each.

   Returns ``(aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d, sp3d)``.
   """

# joblib.load unpickles the file: only load files from a trusted source
   from joblib import load

   global model,scaler_re,re_max,re_min,scaler_yplus,yplus_max,yplus_min
   global file1,dudy_min,dudy_max,scaler_dudy,ustar_south,ustar_north
   global ustar2_time_aver,ustar_time_aver,uplus2_time_aver,uplus_time_aver,itstep_ustar

   if itstep == 0 and iter == 0:

# a relative '../...' location used to be assigned first and was always
# overridden by this absolute path
     folder='/chalmers/users/lada/noback/pycalc-les/channel-5200-IDDES-96-86-96-ML-aver-xz-database-3cells-low-re-15000-timesteps/'

     filename=folder+'model-low-re-svr-C-10-eps-0.001-yplus-inst-uplus-output-first-cell-1-9-local-cells-300-samples.bin'

     model = load(filename)
     print('model',model)
     scaler_yplus = load(folder+'model-low-ustar-svr-C-10-eps-0.001-yplus-inst-uplus-output-first-cell-1-9-local-cell_scaler-yplus-300-samples.bin')

# NOTE(review): the file is named 'min-max-...' but is unpacked as
# (max, min) - confirm the order stored in the file
     yplus_max,yplus_min= np.loadtxt(folder+'min-max-model-low-re-svr-C-10-eps-0.001-yplus-inst-uplus-output-first-cell-1-9-local-cells-loca-300-samplesl.txt')

     print('yplus_max,yplus_min',yplus_max,yplus_min)

# first guess for the friction velocity on both walls
     ustar_south=np.ones((ni,nk))
     ustar_north=np.ones((ni,nk))

################# south wall
   u2d_wall=abs(u3d[:,0,:])  # first cell
   dy_wall=dist3d[0,0,0]  # first cell

   start_time_wallf = time.time()

   print('u2d_wall,dy_wall',u2d_wall[0,0],dy_wall)

   ustar_south,_=_ml_wall_ustar('south',ustar_south,u2d_wall,dy_wall)
   ustar_mean_s=np.mean(ustar_south)

# fix k at 1st cell
   _pin_k_wall_row(0,cmu**(-0.5)*ustar_south**2)

################# north wall
   u2d_wall=abs(u3d[:,-1,:])  # first cell
   dy_wall=dist3d[0,-1,0]  # first cell

   print('u2d_wall,dy_wall',u2d_wall[0,0],dy_wall)

   ustar_north,uplus=_ml_wall_ustar('north',ustar_north,u2d_wall,dy_wall)
   ustar=ustar_north
   ustar_mean_n=np.mean(ustar)

   print('modify_k, ustar[2,2]',ustar[2,2])

# fix k at 1st cell
   _pin_k_wall_row(-1,cmu**(-0.5)*ustar**2)

   print(f"{'time wall function ML: '}{time.time()-start_time_wallf:.2e}")

# statistics, once per time step; ustar and uplus are the north-wall values
   if iter == 0:
      ustar_std=np.std(ustar)
      if itstep == 0:
         ustar2_time_aver=0
         ustar_time_aver=0
         uplus2_time_aver=0
         uplus_time_aver=0
         itstep_ustar= 0
         print('file1 opened')
# (re)create the history file
         np.savetxt('ustar-yplus-history.dat',\
            np.c_[itstep,ustar_mean_s,ustar_mean_n,ustar_std])

      else:
         print('file1 printed')
         itstep_ustar= itstep_ustar+1
         ustar2_time_aver= ustar2_time_aver+np.sum(ustar**2)
         ustar_time_aver= ustar_time_aver+np.sum(ustar)
         uplus2_time_aver= uplus2_time_aver+np.sum(uplus**2)
         uplus_time_aver= uplus_time_aver+np.sum(uplus)
         if itstep%100 == 0:
            np.savetxt('ustar2-ustar.dat', np.c_[ustar2_time_aver,ustar_time_aver,uplus2_time_aver,uplus_time_aver,itstep_ustar])

         print('file1 printed')
# append to the history file
         with open('ustar-yplus-history.dat','ab') as f:
            np.savetxt(f, \
               np.c_[itstep,ustar_mean_s,ustar_mean_n,ustar_std])

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d

def fix_omega():
   """Hook for fixing omega in selected cells; nothing is changed here.

   The coefficient arrays are read from the enclosing (module-level) scope
   and returned as-is.

   NOTE(review): an identical ``fix_omega`` is defined again further down in
   this file; that later definition is the one in effect after the file has
   been executed.
   """
   coefficients=(aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d)
   return coefficients

def fix_eps():
   """Fix eps in the first and last row along axis 1.

   For each of the two rows the friction velocity is taken as
   ``cmu**0.25*sqrt(k)`` from ``k3d`` in that row, the neighbour
   coefficients are zeroed, ``ap3d`` is set to the current maximum of
   ``ap3d`` and ``su3d`` to that maximum times ``ustar**3/kappa/dy``, where
   ``dy`` is ``dist3d`` sampled at ``[0, row, 0]``.

   All arrays as well as ``cmu`` and ``kappa`` are read from the enclosing
   (module-level) scope; the coefficient arrays are modified in place.

   Returns ``(aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d, sp3d)``.
   """

   neighbour_coeffs=(aw3d,ae3d,as3d,an3d,al3d,ah3d)

# row 0 is the south wall, row -1 the north wall
   for j_wall in (0,-1):
      dy=dist3d[0,j_wall,0] # first cell
      ustar=cmu**0.25*k3d[:,j_wall,:]**0.5 # k is fixed in first cell

      if j_wall == -1:
         print('modify_eps, ustar[-2,-2],[2,2]',ustar[-2,-2],ustar[2,2])

# fix eps at 1st cell
      for coeff in neighbour_coeffs:
         coeff[:,j_wall,:]=0
      ap_max=np.max(ap3d)
      ap3d[:,j_wall,:]=ap_max
      su3d[:,j_wall,:]=ap_max*ustar**3/kappa/dy

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d

def fix_omega():
   """Hook for fixing omega in selected cells; nothing is changed here.

   The coefficient arrays are read from the enclosing (module-level) scope
   and returned as-is.

   NOTE(review): this repeats an identical ``fix_omega`` defined earlier in
   this file and replaces it when the file is executed.
   """
   coefficients=(aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d)
   return coefficients

def compute_ustar(wdist,velabs,ustar,elog,n):
   """Iterate a log-law expression for the friction velocity.

   wdist  -- wall distance
   velabs -- velocity magnitude at that distance
   ustar  -- starting value for the iteration
   elog   -- constant multiplying yplus inside the logarithm
   n      -- number of fixed-point iterations

   Each iteration evaluates ``kappa*velabs/log(max(elog*ustar*wdist/viscos, 10))``.
   If the resulting ``ustar*wdist/viscos`` is below 11.63 the value
   ``sqrt(viscos*velabs/wdist)`` is returned instead. The comparison uses a
   plain ``if`` and the builtin ``max``, so the arguments are expected to be
   scalars. ``viscos`` and ``kappa`` are read from the enclosing scope.

   The printed yplus is the one computed before the possible replacement.
   """
   u_tau=ustar
   for _ in range(n):
# the argument of the logarithm is kept at 10 or above
      log_arg=max(elog*u_tau*wdist/viscos,10.)
      u_tau=kappa*velabs/np.log(log_arg)

   xyplus=u_tau*wdist/viscos
   below_threshold=xyplus < 11.63
   if below_threshold:
      u_tau=(viscos*velabs/wdist)**0.5

   print('yplus,ustar,dist,vel',xyplus,u_tau,wdist,velabs)

   return u_tau

