def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Initialise the whole domain from the 1-D RANS inlet profile.

   The columns y, U, k and omega of 'y_u_k_om_uv_16000-RANS-code.txt' are
   interpolated onto the cell-centre y coordinates ``yp2d[0,:]`` and copied
   to every spanwise (nk) and streamwise (ni) cell.  u3d is set to U,
   k3d to 0.2*k, eps3d to cmu*k*omega (using the unscaled k) and vis3d to
   cmu*k3d**2/eps3d + viscos.  v3d, w3d and om3d are returned as passed in;
   the module-level ``dist3d`` is appended to the returned tuple.
   """
   profile = xp.loadtxt('y_u_k_om_uv_16000-RANS-code.txt')
   y_source = profile[:,0]
   y_target = yp2d[0,:]

   def _column_as_2d(col):
      # interpolate one profile column to the cell centres and copy it over z
      values = xp.interp(y_target, y_source, profile[:,col])
      return xp.repeat(values[:,None], repeats=nk, axis=1)

   def _extrude_in_x(field_2d):
      # copy a (nj,nk) plane to all ni streamwise positions
      return xp.repeat(field_2d[None,:,:], repeats=ni, axis=0)

   u_plane = _column_as_2d(1)
   k_plane = _column_as_2d(2)
   om_plane = _column_as_2d(3)
   eps_plane = cmu*k_plane*om_plane

   # the inlet field is imposed in the entire domain
   u3d = _extrude_in_x(u_plane)
   k3d = 0.2*_extrude_in_x(k_plane)
   eps3d = _extrude_in_x(eps_plane)
   vis3d = cmu*k3d**2/eps3d+viscos

   return u3d,v3d,w3d,k3d,om3d,eps3d,vis3d,dist3d

def stress_EARSM(dudy,k_rans,om_rans):
    """Reynolds stresses of a 2-D shear flow from an explicit algebraic model.

    Evaluates a Wallin-Johansson-type EARSM for a flow whose only mean
    velocity gradient is dU/dy.  The anisotropy is
    a = beta1*S + beta4*(S*O - O*S), with S and O the strain and rotation
    tensors made non-dimensional by tau = 1/(0.09*omega) (= k/eps with
    eps = 0.09*k*omega).  N, the solution of the model's cubic equation, is
    taken from its closed form whose expression depends on the sign of P2.

    Parameters
    ----------
    dudy : array_like
        Mean velocity gradient dU/dy.
    k_rans, om_rans : array_like
        Turbulent kinetic energy and specific dissipation, same shape as dudy.

    Returns
    -------
    uu, vv, ww, uv : arrays with the (broadcast) shape of the inputs.
        ww is isotropic (2/3*k); where k = 0 all stresses are 0.
    """
    dudy = xp.asarray(dudy, dtype=float)
    rk = xp.asarray(k_rans, dtype=float)
    om = xp.asarray(om_rans, dtype=float)

    # tau = k/eps with eps = 0.09*k*omega, simplified so that k = 0 does not
    # produce 0/0 (the stresses are then simply zero)
    ttau = 1./(0.09*om)

    # non-dimensional strain (S_12 = S_21) and rotation (O_12 = -O_21);
    # all other components vanish when dU/dy is the only gradient
    s12 = ttau*0.5*dudy
    om12 = s12
    vor = -2.*om12**2          # II_Omega = tr(O O)
    str1 = 2.*s12**2           # II_S = tr(S S)

    cpr1 = 9./4.*(1.8-1.)      # c1' = 9/4*(C1 - 1) with C1 = 1.8

    p1 = (1./27.*cpr1**2+9./20.*str1-2./3.*vor)*cpr1
    p2 = p1**2-(1./9.*cpr1**2+9./10.*str1+2./3.*vor)**3
    pos = p2 > 0

    # P2 > 0: N = c1'/3 + cbrt(P1 + sqrt(P2)) + cbrt(P1 - sqrt(P2))
    # (feed the discarded branch harmless values so no NaNs/warnings arise)
    sq_p2 = xp.sqrt(xp.where(pos, p2, 0.))
    un_pos = cpr1/3.+xp.cbrt(p1+sq_p2)+xp.cbrt(p1-sq_p2)

    # P2 <= 0: N = c1'/3 + 2*(P1^2-P2)^(1/6)*cos(arccos(P1/sqrt(P1^2-P2))/3)
    # the arccos argument is clipped: for P2 ~ 0 rounding can push it just
    # above 1, which would make N (and then every stress) NaN
    q = xp.where(pos, 1., p1**2-p2)
    ratio = xp.clip(xp.where(pos, 0., p1/xp.sqrt(q)), -1., 1.)
    un_neg = cpr1/3.+2.*q**(1./6.)*xp.cos(1./3.*xp.arccos(ratio))

    un = xp.where(pos, un_pos, un_neg)

    beta1 = -6./5.*un/(un**2-2.*vor)
    beta4 = beta1/un

    # (S O - O S)_11 = -2*S_12*O_12 and (S O - O S)_22 = +2*S_12*O_12;
    # the 12-component of the commutator and S_11, S_22 are zero
    anis = 2.*beta4*s12*om12
    uu = 2./3.*rk-rk*anis
    vv = 2./3.*rk+rk*anis
    ww = 2./3.*rk
    uv = rk*beta1*s12

    return uu,vv,ww,uv

def stress_Normal(dudy,k_rans,om_rans):
    """Reynolds stresses from the Boussinesq (eddy-viscosity) hypothesis.

    The normal stresses are isotropic, 2/3*k each (returned as three separate
    arrays), and the shear stress is uv = -nu_t*dU/dy with nu_t = k/omega.
    """
    nu_t = k_rans/om_rans
    normal_xx, normal_yy, normal_zz = ((2./3.)*k_rans for _ in range(3))
    shear_xy = -nu_t*dudy
    return normal_xx, normal_yy, normal_zz, shear_xy


def modify_inlet():
   """Set the west (inlet) boundary values for the current time step.

   First call (``itstep == 0``):
   - The 1-D RANS profiles of U, k and omega are read from file and
     interpolated onto the inlet cell centres ``yp2d[0,:]``.
   - The Reynolds stresses are computed with ``stress_EARSM`` from dU/dy.
   - All profiles are repeated over the nk spanwise cells.
   - The geometric arrays and stress components needed by ``STG`` are stored
     as module globals, together with the running-average accumulators.

   Every call:
   - ``STG`` generates synthetic fluctuations (usynt, vsynt, wsynt).
   - usynt is corrected by subtracting sum(usynt*areaw[0])/((y_top-y_bottom)*zmax),
     and the result is superimposed on the RANS mean velocity.
   - Running sums of the spanwise-averaged synthetic stresses are accumulated,
     together with a spanwise two-point correlation of w at j=10 and the
     spanwise mean of u*v at j=58.
   - These are saved/printed whenever ``itstep % 100 == 0``.

   Returns
   -------
   u_bc_west, v_bc_west, w_bc_west
       Inlet velocities (u is RANS mean + corrected fluctuation).
   k_bc_west, eps_bc_west, om_bc_west
       Inlet k, eps and omega, repeated over the spanwise direction.
   u3d_face_w, convw
       Read from module scope and returned unchanged (the lines that would
       update them are commented out).
   """

   global y_rans,u_rans,v_rans,k_rans,om_rans,uv_rans,zp,usynt_inlet,vsynt_inlet,wsynt_inlet,\
          uu_synt,vv_synt,ww_synt,uv_synt,two_corr,u_time,uv_aver,w_synt,k_bc_west,eps_bc_west,\
          r11,r22,r33,r12,r13,r23,xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,dz2d,uin,k_bc_west,\
          eps_bc_west,om_bc_west

   global usynt,vsynt,wsynt

   if itstep == 0:
      # accumulators for the inlet statistics
      two_corr=xp.zeros(2*nk-1)
      uv_aver=0
      # NOTE(review): u_time is never filled in this function, so the saved
      # 'u_time' file contains zeros unless it is updated elsewhere
      u_time=xp.zeros(ntstep)
      # NOTE(review): absolute, user-specific path; modify_init reads the
      # same file name relative to the working directory
      y_u_k_om  =xp.loadtxt('/chalmers/users/lada/pythons-rans-code-RANS/channel-16000-cyclicx/y_u_k_om_uv_16000-RANS-code.txt')
      y_rans_in=y_u_k_om[:,0]
      y_rans=yp2d[0,:]
      u_rans_in=y_u_k_om[:,1]
      u_rans=xp.interp(y_rans, y_rans_in, u_rans_in)

      k_rans_in=y_u_k_om[:,2]
      k_rans=xp.interp(y_rans, y_rans_in, k_rans_in)

      om_rans_in=y_u_k_om[:,3]
      om_rans=xp.interp(y_rans, y_rans_in, om_rans_in)

      eps_rans=cmu*k_rans*om_rans
      eddy_visc = k_rans/om_rans
      dudy=xp.gradient(u_rans,y_rans)
      #uv=-eddy_visc*dudy



      # Reynolds stresses used by the synthetic turbulence generator
      uu_in,vv_in,ww_in,uv_in = stress_EARSM(dudy,k_rans,om_rans)

# make it 2D
      u_rans=xp.repeat(u_rans[:,None], repeats=nk, axis=1)
      k_rans=xp.repeat(k_rans[:,None], repeats=nk, axis=1)
      eddy_visc=xp.repeat(eddy_visc[:,None], repeats=nk, axis=1)
      eps_rans=xp.repeat(eps_rans[:,None], repeats=nk, axis=1)
      om_rans=xp.repeat(om_rans[:,None], repeats=nk, axis=1)
      uu_in=xp.repeat(uu_in[:,None], repeats=nk, axis=1)
      vv_in=xp.repeat(vv_in[:,None], repeats=nk, axis=1)
      ww_in=xp.repeat(ww_in[:,None], repeats=nk, axis=1)
      uv_in=xp.repeat(uv_in[:,None], repeats=nk, axis=1)


      # inlet-plane cell sizes in y, x and z, each shaped (nj,nk)
      dy = xp.diff(y2d,axis=1)
      tmp = dy[0,:]
      dy2d = xp.repeat(tmp[:,None], repeats=nk, axis=1)
      dx = xp.diff(x2d,axis=0)
      tmp = dx[0,0:-1]
      dx2d = xp.repeat(tmp[:,None], repeats=nk, axis=1)
      dz2d = xp.ones((nj,nk))*dz

        # Cell Centers
      zp = xp.linspace(0, zmax, nk)
      tmp = xp2d[0,:]
      xp2d_synt = xp.repeat(tmp[:,None], repeats=nk, axis=1)
      tmp = yp2d[0,:]
      yp2d_synt = xp.repeat(tmp[:,None], repeats=nk, axis=1)
      zp2d_synt=xp.repeat(zp[None,:], repeats=nj, axis=0)

      # compute wall distance
      dw =dist3d[0,:,0]
      dw2d = xp.repeat(dw[:,None], repeats=nk, axis=1)

        # # Get/Compute Reynolds Stress Tensor
      r11 = uu_in
      r22 = vv_in
      r33 = ww_in
      r12 = uv_in
      r13 = xp.zeros((nj,nk))
      r23 = xp.zeros((nj,nk))
      # inlet bulk velocity: flux of u_rans divided by (y_top-y_bottom)*zmax
      uin=xp.sum(u_rans*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      #usynt,vsynt,wsynt=synt_fluct(nmodes_synt,itstep,L_t_synt,y_rans,zp,uv_rans,viscos,jmirror_synt)
      # NOTE(review): at itstep 0 STG is called here and again below; only
      # usynt_inlet/vsynt_inlet/wsynt_inlet keep this first result
      usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# correct usynt so that it is = 0 (easier to converge the p solver)
      uinc=xp.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      usynt=usynt-uinc
      usynt_inlet=usynt
      vsynt_inlet=vsynt
      wsynt_inlet=wsynt


      # running sums of the spanwise-averaged synthetic stresses
      uu_synt=xp.zeros(nj)
      vv_synt=xp.zeros(nj)
      ww_synt=xp.zeros(nj)
      uv_synt=xp.zeros(nj)

      k_bc_west=k_rans
      eps_bc_west=eps_rans
      # omega recomputed from eps = cmu*k*omega
      om_rans=eps_rans/cmu/k_rans
      om_bc_west=om_rans

   # synthetic fluctuations for the current time step
   usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# correct usynt so that it is = 0 (easier to converge the p solver)
   uinc=xp.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
   usynt=usynt-uinc
   u_bc_west=u_rans+usynt
   v_bc_west=vsynt
   w_bc_west=wsynt

   # accumulate spanwise-averaged synthetic stresses (divided by itstep+1 when saved)
   uu_synt=uu_synt+xp.mean(usynt**2,axis=1)
   vv_synt=vv_synt+xp.mean(vsynt**2,axis=1)
   ww_synt=ww_synt+xp.mean(wsynt**2,axis=1)
   uv_synt=uv_synt+xp.mean(usynt*vsynt,axis=1)

# update face velocity and convw at inlet
#  u3d_face_w[0,:,:]=u_bc_west
#  convw[0,:,:]=-u_bc_west*areawx[0,:,:]-v_bc_west*areawy[0,:,:]

# compute two-point corr in node 10
   two_corr=two_corr+xp.correlate(w_bc_west[10,:],w_bc_west[10,:],'full')

# sum over timesteps
   uv_aver=uv_aver+xp.mean(u_bc_west[58,:]*v_bc_west[58,:])

# compute average
   if itstep % 100 == 0:
      xp.save('two_corr_inlet',two_corr)
      xp.save('u_time',u_time)
      xp.save('uu_synt',uu_synt/(itstep+1))
      xp.save('uv_synt',uv_synt/(itstep+1))
      print('uvmean_aver at max',uv_aver/(itstep+1))

   return u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw

def modify_conv(convw,convs,convl):
   """Zero the convective flux through the south and north boundary faces.

   ``convs`` is modified in place; the three flux arrays are returned as given.
   """
   for j_boundary in (0,-1):
      convs[:,j_boundary,:]=0

   return convw,convs,convl

def modify_u(su3d,sp3d):
   """Add wall-function shear and inlet boundary terms to the u source terms.

   South (j=0) and north (j=-1) walls: tauw = ustar**2 with
   ustar = cmu**0.25*k**0.5 from the first-cell k.  tauw times the wall-face
   area is subtracted from su3d.  Each wall uses its own boundary face area:
   areas[:,0,:] for the south wall and areas[:,-1,:] for the north wall.

   Inlet (i=0): upwind convection of ``u_bc_west`` (explicit part
   max(convw,0)*u_bc_west in su3d, implicit part in sp3d) and turbulent
   diffusion through the west face with (vis3d-viscos)*aw_bound.

   NOTE(review): the wall friction is applied as a sink regardless of the
   sign of u (modify_v uses xp.sign(v)); confirm this is intended for
   regions with backflow.
   """
# south wall
   ustar=cmu**0.25*k3d[:,0,:]**0.5 # k is fixed at first cell
   tauw=ustar**2
   su3d[:,0,:]=su3d[:,0,:]-tauw*areas[:,0,:]
# north wall: use the north boundary face area (was areas[:,0,:], the south face)
   ustar=cmu**0.25*k3d[:,-1,:]**0.5
   print('modify_u, ustar[-2,-2],[2,2]',ustar[-2,-2],ustar[2,2])
   tauw=ustar**2
   su3d[:,-1,:]=su3d[:,-1,:]-tauw*areas[:,-1,:]

# inlet
   su3d[0,:,:]= su3d[0,:,:]+xp.maximum(convw[0,:,:],0)*u_bc_west
   sp3d[0,:,:]= sp3d[0,:,:]-xp.maximum(convw[0,:,:],0)
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*u_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d

def modify_v(su3d,sp3d):
   """Add south-wall shear and inlet boundary terms to the v source terms.

   South wall:
   - The wall shear tauw = ustar**2 (ustar = cmu**0.25*k**0.5 from the
     first-cell k) is projected on the v direction with |dy|/|ds| of each
     south-wall segment.
   - It is applied against the sign of v in the first cell.

   Inlet (i=0):
   - Upwind convection of ``v_bc_west``.
   - Turbulent diffusion through the west face with (vis3d-viscos)*aw_bound.
   """
   # friction velocity and wall shear from the fixed first-cell k
   friction_vel=cmu**0.25*k3d[:,0,:]**0.5
   tau_wall=friction_vel**2

   # |sin| of the local south-wall inclination for every i
   dx_wall=xp.diff(x2d[:,0])
   dy_wall=xp.diff(y2d[:,0])
   seg_len=(dx_wall**2+dy_wall**2)**0.5
   sin_wall=abs(dy_wall)/seg_len
   if itstep == 0 and iter == 0:
      print('yt',sin_wall)
   sin_wall_2d=xp.repeat(sin_wall[:,None], repeats=nk, axis=1)

   # the shear always opposes v: V < 0 gives a positive source and vice versa
   su3d[:,0,:]=su3d[:,0,:]-tau_wall*abs(sin_wall_2d)*areas[:,0,:]*xp.sign(v3d[:,0,:])

   # inlet boundary contributions
   inflow=xp.maximum(convw[0,:,:],0)
   nu_t=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+inflow*v_bc_west
   su3d[0,:,:]=su3d[0,:,:]+nu_t*aw_bound*v_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-inflow
   sp3d[0,:,:]=sp3d[0,:,:]-nu_t*aw_bound

   return su3d,sp3d


def modify_w(su3d,sp3d):
   """Add the west (inlet) boundary contributions to the w equation.

   Upwind convection of ``w_bc_west`` through the inflow flux
   max(convw,0) (explicit in su3d, implicit in sp3d) plus turbulent
   diffusion through the west face with (vis3d-viscos)*aw_bound.
   """
   inflow=xp.maximum(convw[0,:,:],0)
   nu_t=vis3d[0,:,:]-viscos

   su3d[0,:,:]=su3d[0,:,:]+inflow*w_bc_west
   su3d[0,:,:]=su3d[0,:,:]+nu_t*aw_bound*w_bc_west

   sp3d[0,:,:]=sp3d[0,:,:]-inflow
   sp3d[0,:,:]=sp3d[0,:,:]-nu_t*aw_bound

   return su3d,sp3d

def modify_k(su3d,sp3d,gen):
   """Add the inlet commutation term and inlet boundary terms to the k equation.

   Running averages (first iteration of each time step only):
   - Spanwise-averaged inlet (i=0) values of u**2, v**2, w**2 and u are
     added to module-level running sums.
   - Their time averages give resolved inlet stresses.
   - Together with the spanwise-averaged modelled k they form ``k_tot``.

   Commutation term:
   - comm_term_pans = k_tot*u*(fk-1)/dx, where fk is derived from
     ``fk3d[0,:,:]`` as ((c_eps_2-c_eps_1*fk3d)/(c_eps_2-c_eps_1))**0.333.
   - Only its negative part is added, implicitly, to sp3d.
   - Upwind convection and turbulent diffusion of ``k_bc_west`` through the
     west face are then added.

   Returns
   -------
   su3d, sp3d, comm_term
       NOTE(review): ``comm_term`` is the all-zero (nj,nk) array created
       below, not ``comm_term_pans``; any consumer of it (e.g. an omega
       source) therefore receives zeros.  Confirm this is intended.
   """

   global u2prim_i0,v2prim_i0,w2prim_i0,umean_i0

   if iter == 0:
# set running-averaged inlet values to zero
      if itstep == 0:
         u2prim_i0=xp.zeros(nj)
         v2prim_i0=xp.zeros(nj)
         w2prim_i0=xp.zeros(nj)
         umean_i0=xp.zeros(nj)

# time average
      u2prim_i0=u2prim_i0+xp.mean(u3d[0,:,:]**2,axis=1)
      v2prim_i0=v2prim_i0+xp.mean(v3d[0,:,:]**2,axis=1)
      w2prim_i0=w2prim_i0+xp.mean(w3d[0,:,:]**2,axis=1)
      umean_i0=umean_i0+xp.mean(u3d[0,:,:],axis=1)

   comm_term=xp.zeros((nj,nk))
   # one sample per time step was added above, hence the division by itstep+1
   umean=umean_i0/(itstep+1)
   u2prim=u2prim_i0/(itstep+1)-umean**2
   # NOTE(review): no mean is subtracted for v and w (mean assumed zero)
   v2prim=v2prim_i0/(itstep+1)
   w2prim=w2prim_i0/(itstep+1)
   # total k = resolved (time-averaged) + modelled (spanwise-averaged)
   k_tot=0.5*(u2prim+v2prim+w2prim)+xp.mean(k3d[0,:,:],axis=1)
# make it 2D
   k_tot=xp.repeat(k_tot[:,None], repeats=nk, axis=1)

   psi_small=fk3d[0,:,:]
   term1=xp.maximum((c_eps_2-c_eps_1*psi_small)/(c_eps_2-c_eps_1),1e-10)
   fk2d_from_psi=term1**0.333

#  dfk_dx=u3d[0,:,:]*(0.4-1)/(x2d[1,1]-x2d[0,1])
   # one-sided difference between fk in the first cell and fk = 1 upstream
   dfk_dx=u3d[0,:,:]*(fk2d_from_psi-1)/(x2d[1,1]-x2d[0,1])

# commutation term
   comm_term_pans=k_tot*dfk_dx
   comm_min_pans=xp.min(comm_term_pans)
   pk_max=xp.max((vis3d-viscos)*gen)

   u2prim_max=xp.max(u2prim)
   v2prim_max=xp.max(v2prim)
   w2prim_max=xp.max(w2prim)

   print(f"\n{'comm_min_pans: '} {comm_min_pans:.2e}, {'pk_max: '}{pk_max:.2e}, {'u2prim_max: '}{u2prim_max:.2e}, {'v2prim_max: '}{v2prim_max:.2e}, {'w2prim_max: '}{w2prim_max:.2e}")
   # only the negative (sink) part is treated implicitly, keeping sp3d <= its old value
   sp3d[0,:,:]=sp3d[0,:,:]+xp.minimum(comm_term_pans,0)*vol[0,:,:]/k3d[0,:,:]

   su3d[0,:,:]= su3d[0,:,:]+xp.maximum(convw[0,:,:],0)*k_bc_west
   sp3d[0,:,:]= sp3d[0,:,:]-xp.maximum(convw[0,:,:],0)
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*k_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d,comm_term

def modify_eps(su3d,sp3d):
   """Add the west (inlet) boundary contributions to the eps equation.

   Upwind convection: the inflow flux max(convw,0) carries ``eps_bc_west``
   into the first cell.  The explicit part goes to su3d and the implicit
   part to sp3d.  Turbulent diffusion through the west face uses
   (vis3d-viscos)*aw_bound.
   """
   inflow=xp.maximum(convw[0,:,:],0)
   su3d[0,:,:]= su3d[0,:,:]+inflow*eps_bc_west
   # the implicit part must use the same clipped inflow flux as su3d (as in
   # the other modify_* routines); subtracting a negative convw would add a
   # positive, destabilising contribution to sp3d
   sp3d[0,:,:]= sp3d[0,:,:]-inflow
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*eps_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d

def modify_om(su3d,sp3d,comm_term):
   """Add the inlet commutation source and inlet boundary terms to omega.

   Extra production:
   - A limited time scale max(1/(0.09*omega), 6*sqrt(nu/(0.09*k*omega)))
     gives an effective omega.
   - -omega_eff/k*comm_term is the extra production.
   - Only its positive part is added (times the cell volume) to su3d.

   The west-face upwind convection and turbulent diffusion of
   ``om_bc_west`` follow.
   """
   k_in=k3d[0,:,:]
   om_in=om3d[0,:,:]

   t_kolmog=6.*(viscos/0.09/k_in/om_in)**0.5
   t_scale=xp.maximum(1./0.09/om_in,t_kolmog)
   omeg=1./0.09/t_scale

   prod_extra=-omeg/k_in*comm_term
   su3d[0,:,:]=su3d[0,:,:]+xp.maximum(prod_extra,0.)*vol[0,:,:]

   # inlet boundary contributions
   inflow=xp.maximum(convw[0,:,:],0)
   nu_t=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+inflow*om_bc_west
   su3d[0,:,:]=su3d[0,:,:]+nu_t*aw_bound*om_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-inflow
   sp3d[0,:,:]=sp3d[0,:,:]-nu_t*aw_bound

   return su3d,sp3d

def fix_omega():
   """Return the module-level coefficient arrays without modification.

   No omega boundary fixing is applied for this case (a south-wall
   Dirichlet variant used to be kept here as commented-out code).
   """
   coefficients=(aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d)
   return coefficients


def modify_outlet(convw):
   """Set the outlet fluxes from a zero-gradient copy plus a mass correction.

   The outlet flux convw[-1] is set to convw[-2] plus a uniform velocity
   correction (inflow - outflow)/outlet_area times the outlet face areas.
   The outflow is measured at the last interior face (convw[-2]).  This makes
   the outlet flux balance the inlet flux convw[0].  Diagnostics are printed.

   Returns the updated ``convw`` and the module-level ``u_bc_east``.
   """
   inflow=xp.sum(convw[0,:,:])
   outflow=xp.sum(convw[-2,:,:])
   outlet_faces=areaw[-1,:,:]
   outlet_area=xp.sum(outlet_faces)

   uinc=(inflow-outflow)/outlet_area
   convw[-1,:,:]=convw[-2,:,:]+uinc*outlet_faces

   print('area_out',outlet_area)
   print('flow_in',inflow,'flow_out',outflow,'area_out',outlet_area,'flow_out_new',xp.sum(convw[-1,:,:]),'uinc:',uinc)

   return convw,u_bc_east

def modify_fk(fk3d):
   """Compute fk as the ratio of the RANS length scale to an IDDES hybrid length.

   Length scales:
   - The IDDES grid length min(max(0.15*d_w, 0.15*delta_max, dy), delta_max)
     is scaled by psi*cdes to give the LES length l_c.
   - l_u = k**1.5/eps is the RANS length.
   - The IDDES blending functions (f_t, f_l, f_e1, f_e2, f_b, f_dt, f_e,
     f_d) are evaluated as in the code below.

   The result is fk = l_u / (f_d*(1+f_e)*l_u + (1-f_d)*l_c).  The input
   ``fk3d`` is not used; it is overwritten.

   At the first iteration of selected time steps the spanwise means of f_e,
   f_b and f_dt are accumulated (via aver_iddes) into module-level arrays
   and periodically saved with xp.save.
   """

   global f_e_mean, f_b_mean, f_dt_mean

   # grid-based length: 0.15*wall distance, 0.15*delta_max and the local y spacing
   l_dist=0.15*dist3d
   l_max=0.15*delta_max
   dy=xp.diff(y2d[1:,:],axis=1)
# make it 3d
   dy=xp.repeat(dy[:,:,None],repeats=nk,axis=2)

   l_temp=xp.maximum(l_dist,l_max)
   l_temp=xp.maximum(l_temp,dy)
   l_iddes=xp.minimum(l_temp,delta_max)
#  l_iddes=xp.minimum(xp.maximum(l_dist,l_max,dy),delta_max)

   # low-Reynolds-number damping functions and the correction psi (capped at 10)
   ueps=(eps3d*viscos)**0.25
   ystar=ueps*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   fdampf2=((1.-xp.exp(-ystar/3.1))**2)*(1.-0.3*xp.exp(-(rt/6.5)**2))
   fmu=((1.-xp.exp(-ystar/14.))**2)*(1.+5./rt**0.75*xp.exp(-(rt/200.)**2))
   fmu=xp.minimum(fmu,1.)

   psi=xp.minimum(10,(fdampf2*fmu)**(-0.75))

   l_c=psi*cdes*l_iddes  #eq. 9

   # NOTE(review): denom is zero where gen == 0 or at dist3d == 0
   vist=vis3d-viscos
   denom=kappa**2*dist3d**2*gen**0.5

   r_dt=vist/denom  #eq. 22
   r_dl=viscos/denom  #eq. 23

   f_t=xp.tanh((c_t**2*r_dt)**3)
   f_l=xp.tanh((c_l**2*r_dl)**10)

   f_e2=1.-xp.maximum(f_t,f_l) #eq. 19

   alpha=0.25-dist3d/delta_max

   # two Gaussians in alpha; both give 2 at alpha == 0
   f_e1=xp.where(alpha <= 0,2*xp.exp(-9*alpha**2),2*xp.exp(-11.09*alpha**2))

   f_b=  xp.minimum(2.*xp.exp(-9*alpha**2),1.)

   f_dt=1.-xp.tanh((8.*r_dt)**3)

   f_e=xp.maximum(f_e1-1.,0.)*psi*f_e2

   f_d=xp.maximum((1.-f_dt),f_b)

   l_u=k3d**1.5/eps3d

   # hybrid length: blend of the (elevated) RANS length and the LES length
   l_tilde=f_d*(1+f_e)*l_u+(1-f_d)*l_c

   fk3d=l_u/l_tilde

   if iter == 0 and itstep ==0:
      f_e_mean=xp.zeros((ni,nj))
      f_b_mean=xp.zeros((ni,nj))
      f_dt_mean=xp.zeros((ni,nj))



   if iter == 0 and itstep%itstep_stats == 0 and itstep >= itstep_start:
      f_e_mean,f_b_mean,f_dt_mean= aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean)

   # the saved arrays are running sums, not yet divided by the sample count
   if iter == 0 and (itstep == ntstep-1 or itstep%itstep_save == 0):
      print('IDDES functions saved')
      xp.save('f_e_mean', f_e_mean)
      xp.save('f_b_mean',f_b_mean)
      xp.save('f_dt_mean',f_dt_mean)


   return fk3d


def aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean):
   """Add the spanwise (axis 2) means of f_e, f_b and f_dt to their running sums.

   Returns the three updated sums as a tuple (f_e_mean, f_b_mean, f_dt_mean).
   """
   pairs=((f_e_mean,f_e),(f_b_mean,f_b),(f_dt_mean,f_dt))
   return tuple(running+xp.mean(sample,axis=2) for running,sample in pairs)

def fix_k():
    """Fix k in the wall-adjacent cells using a nearest-neighbour wall law.

    First call (``iter == 0 and itstep == 0``):
    - Loads a database of (x, y+, u+) samples.
    - Fits MinMax scalers on u+ and y+.
    - Builds a KDTree on the scaled (u+, y+) pairs.

    Every call:
    - y+ and u+ of the first cells are formed from ustar = cmu**0.25*k**0.5.
    - They are clipped to the database range and scaled.
    - The 5 nearest database samples are found; their y+ values are
      inverse-distance weighted to predict y+.
    - This gives ustar = y+_pred*nu/dy.  Where the predicted y+ < 5 the
      linear law ustar = sqrt(nu*|u|/dy) is used instead.
    - k is then fixed to cmu**(-0.5)*ustar**2 in the first cells at j=0 and
      j=-1: neighbour coefficients are zeroed, ap3d is set to max(ap3d),
      and su3d = ap*kwall.

    Statistics of ustar and of the matched database x locations are
    accumulated and written to text files periodically.  The module-level
    coefficient arrays are modified in place and returned.
    """


    from joblib import dump, load
    from sklearn.preprocessing import MinMaxScaler
    from scipy.spatial import KDTree


    global yplus_min,yplus_max,scaler_yplus,scaler_pplus,neural_net,ubulk
    global ustar_mean_s,counter_ustar,count
    global uplus_min,uplus_max,scaler_uplus, tree, X
    global yplus_np_target,uplus_np_target,ustar2_mean_s,x_target,yplus_target
    global inds, ds_np,ds, nn_np, nn, ds,x_location,y_location,x2_location
# NN
# load data
# inlet
    if iter == 0 and itstep == 0:
#      folder = '../hump-IDDES-ni-583-nk-64-go4hybrid-mesh-STG-dt-0.002-full-GPU/'
       folder = './'
       # columns: x, y+, u+ (y+ taken as absolute value below)
       data_10 = xp.loadtxt(str(folder)+'x-yplus-uplus-ustar-ALL-i-j-hump-incl-backflow.txt')

       x_target = data_10[:,0]
       yplus_target = abs(data_10[:,1])
       uplus_target = data_10[:,2]

       # database range, used below to clip the query values
       yplus_max = xp.max(yplus_target)
       yplus_min = xp.min(yplus_target)
       uplus_min = xp.min(uplus_target)
       uplus_max = xp.max(uplus_target)

       print('yplus_max, yplus_min, uplus_min, uplus_max',\
              yplus_max, yplus_min, uplus_min, uplus_max)

       uplus_target = uplus_target.reshape(-1,1)
       yplus_target = yplus_target.reshape(-1,1)
# use MinMax scaler
       scaler_yplus = MinMaxScaler()
       scaler_uplus = MinMaxScaler()

       X=xp.zeros((len(yplus_target),2))

       # sklearn/scipy need host (numpy) arrays
       if gpu:
         yplus_np_target = xp.asnumpy(yplus_target)
         uplus_np_target = xp.asnumpy(uplus_target)
         X_np = xp.asnumpy(X)
       else:
         yplus_np_target = yplus_target
         uplus_np_target = uplus_target
         X_np = X

       X_np[:,0] = scaler_uplus.fit_transform(uplus_np_target)[:,0]
       X_np[:,1] = scaler_yplus.fit_transform(yplus_np_target)[:,0]

       ustar_mean_s = xp.zeros(ni)
       ustar2_mean_s = xp.zeros(ni)
       u2d_wall=u3d[:,0,:] # first cell
       u_mean_k=xp.mean(u2d_wall,axis=1)
       u_wall_mean = u_mean_k
       ustar=cmu**0.25*k3d[:,0,:]**0.5
       u1d = xp.mean(u3d[0,:,:],axis=1)
#      ubulk = xp.trapz(u1d,yp2d[0,:])/max(y2d[0,:])
       ubulk= xp.sum(u1d*xp.diff(y2d[0,:]))/max(y2d[0,:])
       print('ubulk',ubulk)

       tree = KDTree(X_np) # data we want to match

       x_location=xp.zeros((ni*nk))
       x2_location=xp.zeros((ni*nk))
       y_location=xp.zeros((ni*nk))

# take old ustar from k=cmu**(-0.5)*ustar**2
    ustar=cmu**0.25*k3d[:,0,:]**0.5

    j_wall = 0 # 1st cell
    u2d_wall=u3d[:,j_wall,:]  # first cell

# average ustar to be used in p+
    ustar_mean_k=xp.mean(ustar,axis=1)

    dy=dist3d[:,j_wall,:] # first cell
    yplus_south = ustar*dy/viscos

    uplus_south = u2d_wall/ustar
    ustar_south = ustar

# count values larger/smaller than max/min
    yplus_min_number= (yplus_south < yplus_min).sum()
    yplus_max_number= (yplus_south > yplus_max).sum()
    print('south: yplus_min_number',yplus_min_number)
    print('south: yplus_max_number',yplus_max_number)

    uplus_min_number= (uplus_south < uplus_min).sum()
    uplus_max_number= (uplus_south > uplus_max).sum()
    print('south: uplus_min_number',uplus_min_number)
    print('south: uplus_max_number',uplus_max_number)

    ustar_min =xp.min(ustar)
    ustar_max =xp.max(ustar)
    print('dy,ustar_min,max,mean',dy[0,0],ustar_min,ustar_max,xp.mean(ustar))
    print('yplus, min,max,mean',xp.min(yplus_south),\
          xp.max(yplus_south),xp.mean(yplus_south))
    print('uplus, min,max,mean',xp.min(uplus_south),\
          xp.max(uplus_south),xp.mean(uplus_south))

# set limits
    uplus_south=xp.minimum(uplus_south,uplus_max)
    uplus_south=xp.maximum(uplus_south,uplus_min)

    yplus_south=xp.minimum(yplus_south,yplus_max)
    yplus_south=xp.maximum(yplus_south,yplus_min)

    N_predict=ni*nk
    y = uplus_south
    y = y.reshape(-1,1)
    yplus = yplus_south.reshape(-1,1)
    uplus = uplus_south.reshape(-1,1)
    n_test = len(uplus)
    x=xp.zeros((n_test,2))
    if gpu:
       yplus_np = xp.asnumpy(yplus)
       uplus_np = xp.asnumpy(uplus)
       x_np = xp.asnumpy(x)
    else:
       yplus_np = yplus
       uplus_np = uplus
       x_np = x

    # scale the queries with the scalers fitted on the database
    x_np[:,0] = scaler_uplus.transform(uplus_np)[:,0]
    x_np[:,1] = scaler_yplus.transform(yplus_np)[:,0]
#ds, inds =  tree.query(x, 1) # finds one nearest neighbors at n_test samples  at distance d
    ds, inds =  tree.query(x_np, 5) # finds the 5 nearest neighbours of each of the n_test samples (distances ds)
    if iter == 0 and itstep >= itstep_start:
       x_location = x_location + x_target[inds[:,0]]
       x2_location = x2_location + x_target[inds[:,0]]**2
       y_location = y_location + yplus_target[inds[:,0],0]
    # NOTE(review): despite the _np suffix, ds_np/inds_np are xp (device)
    # arrays when gpu is True
    if gpu:
       ds_np = xp.asarray(ds)
       inds_np = xp.asarray(inds)
    else:
       ds_np = ds
       inds_np = inds
    n_len0 = len(inds_np[:,0])
    yplus_kdtree =xp.zeros(n_len0)
    uplus_kdtree =xp.zeros(n_len0)
    n_len = len(inds_np[0,:])
    print('type(inds_np),type(ds_np)',type(inds_np),type(ds_np))
    # inverse-distance weighting over the neighbours
    for nn in range(n_len):
       temp1 =   yplus_np_target[inds[:,nn],0]/ds[:,nn]
       tempu =   uplus_np_target[inds[:,nn],0]/ds[:,nn]
       if gpu:
          temp1_np = xp.asarray(temp1)
          tempu_np = xp.asarray(tempu)
       else:
          temp1_np = temp1
          tempu_np = tempu
       yplus_kdtree = yplus_kdtree +  temp1_np
       uplus_kdtree = uplus_kdtree +  tempu_np
    # NOTE(review): a zero distance in ds (query coinciding with a database
    # point) gives 1/0 and then inf/inf = NaN in the two lines below
    yplus_kdtree = yplus_kdtree/xp.sum(1/ds_np,axis=1)
    uplus_kdtree = uplus_kdtree/xp.sum(1/ds_np,axis=1)

    if gpu:
       yplus_kdtree_cp = xp.asarray(yplus_kdtree)
       uplus_kdtree_cp = xp.asarray(uplus_kdtree)
    else:
       yplus_kdtree_cp = yplus_kdtree
       uplus_kdtree_cp = uplus_kdtree
#yplus_kdtree = (yplus_np[inds[:,0]]/ds[:,0]+yplus_np[inds[:,1]]/ds[:,1] \
#              +yplus_np[inds[:,2]]/ds[:,2]+yplus_np[inds[:,3]]/ds[:,3])/xp.sum(1/ds,axis=1)
    print('yplus_kdtree.shape,inds.shape',yplus_kdtree.shape,inds.shape)
    yplus_predict = xp.reshape(yplus_kdtree_cp,(ni,nk))
    uplus_predict = xp.reshape(uplus_kdtree_cp,(ni,nk))
    # new ustar from the predicted y+; linear-sublayer law where y+ < 5
    ustar=yplus_predict*viscos/dy
    yplus = yplus_predict
    ustar_5 = (viscos*abs(u2d_wall)/dy)**0.5
    ustar=xp.where(yplus< 5,ustar_5,ustar)

    number_yplus_linear= (yplus< 5).sum()

    print('number y+ < 5',number_yplus_linear)

#   if itstep%100 == 0:
#      xp.savetxt('umean.dat', u_mean)

# use Reichardd the first 5000 time steps
#   if itstep < 1000:
#      if itstep ==  997:
#      if itstep ==  0:
#        xp.savetxt('xp-Cf-yplus-re.dat', xp.c_[xp2d[:,1],Cf_south[:,0],yplus[:,0],re_south[:,0],\
#            ustar[:,0],u2d_wall[:,0],yplus_predict[:,0]])
#      ustar=xp.ones((ni,nk))
#      velabs=abs(u3d[:,0,:])
#      ustar=compute_ustar_reich_solve(dy/viscos,velabs,ustar)
    # first call: reset statistics and write the initial i=58 diagnostics
    if iter == 0:
      if itstep == 0:
         ustar_mean_s = xp.zeros(ni)
         ustar2_mean_s = xp.zeros(ni)
         count = 0
         inds_2d = xp.reshape(inds[:,0],(ni,nk))
         i = 58
         cf0=ustar[i,:]**2*xp.sign(u2d_wall[i,:])
         xp.savetxt('min-cf-hist-58.txt', \
        xp.c_[yplus_south[i,:],yplus_predict[i,:],uplus_south[i,:],uplus_predict[i,:],x_target[inds_2d[i,:]],inds_2d[i,:],cf0])


    # NOTE(review): the itstep == 0 reset below repeats the one above
    if iter == 0:
      if itstep == 0:
         ustar_mean_s = xp.zeros(ni)
         ustar2_mean_s = xp.zeros(ni)
         count = 0
      if itstep > itstep_start:
         count = count +1
         # spanwise means of signed ustar and signed ustar**2 (sign of u)
         ustar_mean_k=xp.mean(ustar*xp.sign(u2d_wall),axis=1)
         ustar2_mean_k=xp.mean(ustar**2*xp.sign(u2d_wall),axis=1)
         ustar_mean_old = ustar_mean_s
         ustar_mean_s=ustar_mean_s + ustar_mean_k
         ustar2_mean_s=ustar2_mean_s + ustar2_mean_k
         print('ustar_mean_old,ustar_mean_s[0],ustar_mean_s[0]/step',ustar_mean_old[0],ustar_mean_s[0],ustar_mean_s[0]/count)
         # NOTE(review): overwrites the geometry array xp2d (last x, j=1) with
         # the sample count, presumably so it appears in ustar-vs-x.dat;
         # confirm no other code reads xp2d[-1,1] afterwards
         xp2d[-1,1] = count
         if itstep%100 == 0:
#        if itstep%1 == 0:
            xp.savetxt('ustar-vs-x.dat', xp.c_[xp2d[:,1],ustar_mean_s,ustar2_mean_s])
            xp.savetxt('x-y-target-location.dat', xp.c_[x_location,y_location,x2_location])

            inds_2d = xp.reshape(inds[:,0],(ni,nk))

            i = 58
            cf0=ustar[i,:]**2*xp.sign(u2d_wall[i,:])
            f = open('min-cf-hist-58.txt','ab')
            xp.savetxt(f,\
        xp.c_[yplus_south[i,:],yplus_predict[i,:],uplus_south[i,:],uplus_predict[i,:],x_target[inds_2d[i,:]],inds_2d[i,:],cf0])
            f.close()




    kwall=cmu**(-0.5)*ustar**2

    # impose k = kwall in the south-wall cells (j=0)
    aw3d[:,0,:]=0
    ae3d[:,0,:]=0
    as3d[:,0,:]=0
    an3d[:,0,:]=0
    al3d[:,0,:]=0
    ah3d[:,0,:]=0
    ap_max=xp.max(ap3d)
    ap3d[:,0,:]=ap_max
    su3d[:,0,:]=ap_max*kwall

    # NOTE(review): the north-wall cells (j=-1) also receive kwall computed
    # from the south-wall ustar
    aw3d[:,-1,:]=0
    ae3d[:,-1,:]=0
    as3d[:,-1,:]=0
    an3d[:,-1,:]=0
    al3d[:,-1,:]=0
    ah3d[:,-1,:]=0
    ap3d[:,-1,:]=ap_max
    su3d[:,-1,:]=ap_max*kwall

    return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d

def fix_eps():
   """Fix epsilon in the wall-adjacent cells from the log-law relation.

   In the first cell next to the south (j=0) and north (j=-1) walls,
   eps = ustar**3/(kappa*y).  Here ustar = cmu**0.25*k**0.5 comes from the
   first-cell k, and y is the wall distance of that same cell (taken per
   cell, as fix_k does, since it varies along a non-flat wall).  The value is
   imposed by zeroing the neighbour coefficients, setting ap3d to the global
   maximum of ap3d and su3d = ap*eps_wall.

   The module-level coefficient arrays are modified in place and returned.
   """
# south wall
   dy=dist3d[:,0,:] # wall distance of every first cell, (ni,nk)
   ustar=cmu**0.25*k3d[:,0,:]**0.5 # k is fixed in first cell

# fix eps at 1st cell
   aw3d[:,0,:]=0
   ae3d[:,0,:]=0
   as3d[:,0,:]=0
   an3d[:,0,:]=0
   al3d[:,0,:]=0
   ah3d[:,0,:]=0
   ap_max=xp.max(ap3d)
   ap3d[:,0,:]=ap_max
   su3d[:,0,:]=ap_max*ustar**3/kappa/dy

# north wall
   dy=dist3d[:,-1,:]
   ustar=cmu**0.25*k3d[:,-1,:]**0.5  # k is fixed in first cell

   print('modify_eps, ustar[-2,-2],[2,2]',ustar[-2,-2],ustar[2,2])

# fix eps at 1st cell
   aw3d[:,-1,:]=0
   ae3d[:,-1,:]=0
   as3d[:,-1,:]=0
   an3d[:,-1,:]=0
   al3d[:,-1,:]=0
   ah3d[:,-1,:]=0
   ap_max=xp.max(ap3d)
   ap3d[:,-1,:]=ap_max
   su3d[:,-1,:]=ap_max*ustar**3/kappa/dy

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d


def modify_vis(vis3d):
   """Hook for case-specific viscosity changes; this case leaves vis3d as is."""
   unchanged=vis3d
   return unchanged

def solve_ustar_reich(ustar,velabs,wdist):
    """Residual of Reichardt's law of the wall for a trial friction velocity.

    Reichardt's law, with kappa = 0.4, C = 7.8, chi = 11:
        U+ = ln(1 + kappa*y+)/kappa
             + C*(1 - exp(-y+/chi) - (y+/chi)*exp(-y+/3))

    Here y+ = ustar*wdist, so ``wdist`` is expected to be the wall distance
    divided by the kinematic viscosity.

    Returns ustar - velabs/U+(y+), which vanishes at the solution.  Near the
    wall U+ -> y+, so the residual reduces to ustar - velabs/y+.
    """
    import numpy as xp

    yplus = ustar*wdist
    uplus = (1/0.4*xp.log(1+0.4*yplus)
             + 7.8*(1-xp.exp(-yplus/11)-yplus/11*xp.exp(-yplus/3)))
    return ustar-velabs/uplus

def compute_ustar_reich_solve(wdist,velabs,ustar):
   """Solve Reichardt's law for the friction velocity in every wall cell.

   Parameters
   ----------
   wdist : array
      Used by ``solve_ustar_reich`` as y+/ustar.  Presumably this is wall
      distance divided by viscosity (the commented-out caller passes
      dy/viscos) — TODO confirm.
   velabs : array
      Velocity magnitude in the wall cells.
   ustar : array
      Initial guess for the friction velocity.

   Returns
   -------
   The converged friction velocity as an ``xp`` array of shape (ni,nk).
   """
   from scipy.optimize import newton

   ustar_flat=ustar.flatten()
   velabs_flat=velabs.flatten()
   wdist_flat=wdist.flatten()
   # scipy needs host (numpy) arrays; convert when running on the GPU
   if gpu:
      velabs_np = xp.asnumpy(velabs_flat)
      ustar_np = xp.asnumpy(ustar_flat)
      wdist_np = xp.asnumpy(wdist_flat)
   else:
      velabs_np = velabs_flat
      ustar_np = ustar_flat
      wdist_np = wdist_flat
   # vectorised Newton iteration: one independent root per wall cell
   root = newton(solve_ustar_reich,x0=ustar_np,args=(velabs_np,wdist_np))
   ustar = xp.asarray(root)
   ustar=xp.reshape(ustar,(ni,nk))

   return ustar

