
def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Initialise the 3D LES fields from a 2D RANS solution.

   Reads the time-averaged 2D RANS fields u, v, k and eps from the RANS
   case directory below, reduces k to 40 % of its RANS value (RANS -> LES),
   and copies every field uniformly along the third (spanwise) axis; w is
   set to zero. ``om3d`` and ``vis3d`` are returned unchanged.

   Also (re)defines the module-level ``dist3d`` as the distance from the
   south wall: yp2d minus the y-coordinate of the south-wall face centre
   of the same i-column.

   Returns
   -------
   u3d, v3d, w3d, k3d, om3d, eps3d, vis3d, dist3d
   """

   global dist3d

   # re-define dist3d = distance from the south wall; the wall face centre
   # is the midpoint of the two nodes bounding each south face
   ywall_s=0.5*(y2d[0:-1,0]+y2d[1:,0])
   dist_s=yp2d-ywall_s[:,None]
   dist3d=np.repeat(dist_s[:,:,None], repeats=nk, axis=2)

   # start from RANS. The *_averaged.npy fields are divided by the step
   # count stored first in itstep.npy (presumably they hold time sums).
   # The pressure field is not loaded: this hook does not return it.
   name='../hump-2D-RANS-AKN-go4hybrid-mesh-inlet-at-x-m2.1/'
   n_aver,_,_=np.load(name+'itstep.npy')
   u2d=np.load(name+'u_averaged.npy')/n_aver
   v2d=np.load(name+'v_averaged.npy')/n_aver
   k2d=np.load(name+'k_averaged.npy')/n_aver
   eps2d=np.load(name+'eps_averaged.npy')/n_aver

   # diagnostic: peak eddy viscosity 0.09*k**2/eps relative to viscos
   vismax=np.max(0.09*k2d**2/eps2d)
   print('vismax',vismax/viscos)

   # change from RANS to LES: reduce k to 40 % of its RANS value
   k2d=0.4*k2d

   vismax_les=np.max(0.09*k2d**2/eps2d)
   print('vismax_les',vismax_les/viscos)

   # make 2d to 3d: uniform copies along the third axis
   u3d=np.repeat(u2d[:,:,None], repeats=nk, axis=2)
   v3d=np.repeat(v2d[:,:,None], repeats=nk, axis=2)
   k3d=np.repeat(k2d[:,:,None], repeats=nk, axis=2)
   eps3d=np.repeat(eps2d[:,:,None], repeats=nk, axis=2)
   w3d=np.zeros((ni,nj,nk))

   return u3d,v3d,w3d,k3d,om3d,eps3d,vis3d,dist3d

def stress_EARSM(dudy,k_rans,om_rans):
    """Reynolds stresses of the Wallin-Johansson EARSM for a pure shear flow.

    Only the mean velocity gradient du/dy is non-zero, so the normalised
    strain-rate and rotation tensors reduce to
    S12 = S21 = Omega12 = -Omega21 = 0.5*tau*du/dy, with the time scale
    tau = k/eps and eps = 0.09*k*omega, i.e. tau = 1/(0.09*omega).
    The anisotropy is
    a_ij = beta1*S_ij + beta4*(S_ik*Omega_kj - Omega_ik*S_kj)
    and the stresses are k*(a_ij + 2/3*delta_ij).

    Parameters
    ----------
    dudy : array_like, shape (nj,)
        Mean velocity gradient du/dy.
    k_rans, om_rans : array_like, shape (nj,)
        Turbulent kinetic energy and specific dissipation rate omega
        (only the first nj entries are used).

    Returns
    -------
    uu, vv, ww, uv : ndarray, shape (nj,)
        Normal stresses and the shear stress; uu + vv + ww == 2*k.
    """
    dudy = np.asarray(dudy, dtype=float)
    nj = np.size(dudy, 0)
    rk = np.asarray(k_rans, dtype=float)[:nj]
    om = np.asarray(om_rans, dtype=float)[:nj]

    # tau = k/eps = k/(0.09*k*omega); k is cancelled analytically so that
    # k = 0 yields zero stresses instead of 0/0 = nan
    ttau = 1./(0.09*om)

    # normalised strain rate / rotation: S12 = Omega12 = 0.5*tau*du/dy
    s12 = ttau*0.5*dudy
    om12 = s12
    vor = -2.*om12**2     # II_Omega = Omega_ij*Omega_ji
    str1 = 2.*s12**2      # II_S = S_ij*S_ji

    # c1' = 9/4*(C1 - 1) with C1 = 1.8
    cpr1 = 9./4.*(1.8-1.)

    # N is the root of a cubic (Cardano); the sign of P2 selects the form
    p1 = (1./27.*cpr1**2 + 9./20.*str1 - 2./3.*vor)*cpr1
    p2 = p1**2 - (1./9.*cpr1**2 + 9./10.*str1 + 2./3.*vor)**3
    positive = p2 > 0

    # P2 > 0: real-root form; cbrt(x) == sign(x)*|x|**(1/3)
    sqrt_p2 = np.sqrt(np.where(positive, p2, 0.))
    un_real = cpr1/3. + np.cbrt(p1 + sqrt_p2) + np.cbrt(p1 - sqrt_p2)

    # P2 <= 0: trigonometric form. P1 > 0 (all its terms are >= 0 and
    # cpr1 > 0), so P1**2 - P2 >= P1**2 > 0 where this form is used; a dummy
    # value is used elsewhere, and the arccos argument is clipped against
    # round-off
    rad = np.where(positive, 1., p1**2 - p2)
    cos_arg = np.clip(p1/np.sqrt(rad), -1., 1.)
    un_trig = cpr1/3. + 2.*rad**(1./6.)*np.cos(1./3.*np.arccos(cos_arg))

    un = np.where(positive, un_real, un_trig)

    beta1 = -6./5.*un/(un**2 - 2.*vor)
    beta4 = beta1/un

    # with S11 = S22 = 0:
    # a11 = -a22 = beta4*(S12*Omega21 - Omega12*S21) = -2*beta4*S12*Omega12,
    # a33 = 0 and a12 = beta1*S12
    a_nl = 2.*beta4*s12*om12
    uu = 2./3.*rk - rk*a_nl
    vv = 2./3.*rk + rk*a_nl
    ww = 2./3.*rk
    uv = rk*beta1*s12

    return uu,vv,ww,uv

def stress_Normal(dudy,k_rans,om_rans):
    """Reynolds stresses from the linear Boussinesq hypothesis.

    The normal stresses are isotropic, 2/3*k, and the shear stress is
    uv = -nu_t*du/dy with the eddy viscosity nu_t = k/omega.

    Returns
    -------
    uu, vv, ww, uv
        Each is a separate array/value.
    """
    nu_t = k_rans/om_rans
    uu, vv, ww = ((2./3.)*k_rans for _ in range(3))
    return uu, vv, ww, -nu_t*dudy


def modify_inlet():
   """Compute the west (inlet) boundary values for the current time step.

   On the first time step only (``itstep == 0``), it:
   - reads the inlet RANS profile from 'y_u_v_k_om_uv_hump.dat' (columns
     0, 1, 3, 4 and 5 are used as y, u, k, omega and |uv|);
   - interpolates it onto the inlet cell centres ``yp2d[0,:]``;
   - computes the inlet Reynolds stresses with ``stress_EARSM``;
   - builds the (nj, nk) inlet geometry arrays passed to ``STG``;
   - initialises the running statistics.
   All of this is kept in the module-level globals declared below.

   On every time step, it calls ``STG`` for synthetic fluctuations,
   subtracts their area-weighted mean from the u fluctuation, and adds it
   to the RANS mean velocity. It also accumulates inlet statistics, which
   are saved every 100 steps.

   Returns
   -------
   u_bc_west, v_bc_west, w_bc_west
       Inlet velocities on the (nj, nk) west face.
   k_bc_west, eps_bc_west, om_bc_west
       Inlet turbulence quantities (set on the first time step).
   u3d_face_w, convw
       Returned as found at module level; not modified here (the update
       below is commented out).

   NOTE(review): at ``itstep == 0`` STG is called twice. The first call is
   inside the set-up branch, and its result is stored in ``usynt_inlet``
   etc.; the second follows the branch. Confirm this is intended.
   """

   global y_rans,u_rans,v_rans,k_rans,om_rans,uv_rans,zp,usynt_inlet,vsynt_inlet,wsynt_inlet,\
          uu_synt,vv_synt,ww_synt,uv_synt,two_corr,u_time,uv_aver,w_synt,k_bc_west,eps_bc_west,\
          r11,r22,r33,r12,r13,r23,xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,dz2d,uin,k_bc_west,\
          eps_bc_west,om_bc_west

   global usynt,vsynt,wsynt

   # one-time set-up on the first time step
   if itstep == 0:
      # statistics accumulators
      two_corr=np.zeros(2*nk-1)
      uv_aver=0
      u_time=np.zeros(ntstep)
      data=np.loadtxt('y_u_v_k_om_uv_hump.dat')

      # columns used: 0 = y, 1 = u, 3 = k, 4 = omega, 5 = uv (abs value taken)
      y_rans_in=data[:,0]
      u_rans_in=data[:,1]
      k_rans_in=data[:,3]
      om_rans_in=data[:,4]
      uv_rans_in=np.abs(data[:,5])

      # interpolate the profile onto the inlet cell centres
      y_rans=yp2d[0,:]

      u_rans=np.interp(y_rans, y_rans_in, u_rans_in)
      k_rans=np.interp(y_rans, y_rans_in, k_rans_in)
      om_rans=np.interp(y_rans, y_rans_in, om_rans_in)
      uv_rans=np.interp(y_rans, y_rans_in, uv_rans_in)

      eps_rans=0.09*om_rans*k_rans
      eddy_visc = k_rans/om_rans
      dudy=np.gradient(u_rans,y_rans)
      #uv=-eddy_visc*dudy  
      
      # anisotropic inlet Reynolds stresses from the EARSM
      uu_in,vv_in,ww_in,uv_in = stress_EARSM(dudy,k_rans,om_rans)
      
# make it 2D: copy each 1D profile to all nk spanwise positions -> (nj, nk)
      u_rans=np.repeat(u_rans[:,None], repeats=nk, axis=1)
      k_rans=np.repeat(k_rans[:,None], repeats=nk, axis=1)
      eddy_visc=np.repeat(eddy_visc[:,None], repeats=nk, axis=1)
      eps_rans=np.repeat(eps_rans[:,None], repeats=nk, axis=1)
      om_rans=np.repeat(om_rans[:,None], repeats=nk, axis=1)
      uu_in=np.repeat(uu_in[:,None], repeats=nk, axis=1)
      vv_in=np.repeat(vv_in[:,None], repeats=nk, axis=1)
      ww_in=np.repeat(ww_in[:,None], repeats=nk, axis=1)
      uv_in=np.repeat(uv_in[:,None], repeats=nk, axis=1)
 

      # inlet cell sizes for STG: dy and dx of the first i-column, uniform dz
      dy = np.diff(y2d,axis=1)
      tmp = dy[0,:]
      dy2d = np.repeat(tmp[:,None], repeats=nk, axis=1) 
      dx = np.diff(x2d,axis=0)
      tmp = dx[0,0:-1]
      dx2d = np.repeat(tmp[:,None], repeats=nk, axis=1)
      dz2d = np.ones((nj,nk))*dz

      # inlet cell-centre coordinates as (nj, nk) arrays
      zp = np.linspace(0, zmax, nk)
      tmp = xp2d[0,:]
      xp2d_synt = np.repeat(tmp[:,None], repeats=nk, axis=1) 
      tmp = yp2d[0,:]
      yp2d_synt = np.repeat(tmp[:,None], repeats=nk, axis=1) 
      zp2d_synt=np.repeat(zp[None,:], repeats=nj, axis=0)


      # wall distance of the inlet cells, copied spanwise
      dw =dist3d[0,:,0]
      dw2d = np.repeat(dw[:,None], repeats=nk, axis=1) 
 
      # Reynolds stress tensor passed to STG (r13 = r23 = 0)
      r11 = uu_in
      r22 = vv_in
      r33 = ww_in
      r12 = uv_in
      r13 = np.zeros((nj,nk))
      r23 = np.zeros((nj,nk))
      # bulk inlet velocity: inlet flow rate / (inlet height * zmax)
      uin=np.sum(u_rans*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      #usynt,vsynt,wsynt=synt_fluct(nmodes_synt,itstep,L_t_synt,y_rans,zp,uv_rans,viscos,jmirror_synt)
      usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# remove the area-weighted mean of usynt so that its mean is zero
# (easier to converge the p solver)
      uinc=np.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      usynt=usynt-uinc
      usynt_inlet=usynt
      vsynt_inlet=vsynt
      wsynt_inlet=wsynt
      

      # running sums of the spanwise-averaged synthetic stresses
      uu_synt=np.zeros(nj)
      vv_synt=np.zeros(nj)
      ww_synt=np.zeros(nj)
      uv_synt=np.zeros(nj)
      
      k_bc_west=k_rans
      eps_bc_west=eps_rans
      # omega recomputed from eps; equals the interpolated omega if cmu == 0.09
      om_rans=eps_rans/cmu/k_rans
      om_bc_west=om_rans

   # new synthetic fluctuations for this time step
   usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# remove the area-weighted mean of usynt so that its mean is zero
# (easier to converge the p solver)
   uinc=np.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
   usynt=usynt-uinc
   u_bc_west=u_rans+usynt
   v_bc_west=vsynt
   w_bc_west=wsynt

   # time sums of the spanwise-averaged synthetic stresses
   uu_synt=uu_synt+np.mean(usynt**2,axis=1)
   vv_synt=vv_synt+np.mean(vsynt**2,axis=1)
   ww_synt=ww_synt+np.mean(wsynt**2,axis=1)
   uv_synt=uv_synt+np.mean(usynt*vsynt,axis=1)

# update face velocity and convw at inlet (currently disabled)
#  u3d_face_w[0,:,:]=u_bc_west
#  convw[0,:,:]=-u_bc_west*areawx[0,:,:]-v_bc_west*areawy[0,:,:]

# time sum of the spanwise two-point correlation of w at j = 10
   two_corr=two_corr+np.correlate(w_bc_west[10,:],w_bc_west[10,:],'full')

# time sum of the spanwise mean of u*v at j = 58
# NOTE(review): hard-coded index; the printed label says "at max" - confirm
# that j = 58 is still the relevant location for the current mesh
   uv_aver=uv_aver+np.mean(u_bc_west[58,:]*v_bc_west[58,:])

# every 100 steps save the sums / averages
# NOTE(review): u_time is saved but never filled in this function
   if itstep % 100 == 0:
      np.save('two_corr_inlet',two_corr)
      np.save('u_time',u_time)
      np.save('uu_synt',uu_synt/(itstep+1))
      np.save('uv_synt',uv_synt/(itstep+1))
      print('uvmean_aver at max',uv_aver/(itstep+1))

   return u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw

def modify_conv(convw,convs,convl):
   """Zero the first and last j-layers of ``convs`` in place.

   These are presumably the south and north boundary faces. ``convw`` and
   ``convl`` pass through unchanged; all three arrays are returned.
   """
   for j_bound in (0,-1):
      convs[:,j_bound,:]=0
   return convw,convs,convl

def modify_u(su3d,sp3d):
   """Add the west-inlet boundary condition to the u-equation source terms.

   Only the first cell layer (i = 0) is modified, in place:
   - upwind inflow: where the west-face flux convw[0] is positive, the
     inlet value ``u_bc_west`` times the flux is added to su3d and the flux
     is subtracted from sp3d;
   - diffusive coupling: ``(vis3d - viscos)*aw_bound`` is applied in the
     same way.

   Returns
   -------
   su3d, sp3d
   """
   # inflow part of the west-face flux (zero where the flux leaves the domain)
   inflow=np.maximum(convw[0,:,:],0)
   vist=vis3d[0,:,:]-viscos

   su3d[0,:,:]=su3d[0,:,:]+inflow*u_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-inflow
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*u_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d


def modify_v(su3d,sp3d):
   """Impose the west-inlet value ``v_bc_west`` through the v source terms.

   Only the i = 0 layer is touched (in place). The positive part of the
   west-face flux convw[0] and ``(vis3d - viscos)*aw_bound`` each add
   coefficient*v_bc_west to su3d and subtract the coefficient from sp3d.
   """
   flux_in=np.maximum(convw[0,:,:],0)
   su3d[0,:,:]=su3d[0,:,:]+flux_in*v_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-flux_in

   nu_extra=vis3d[0,:,:]-viscos
   coef_diff=nu_extra*aw_bound
   su3d[0,:,:]=su3d[0,:,:]+coef_diff*v_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-coef_diff
   return su3d,sp3d


def modify_w(su3d,sp3d):
   """Impose the west-inlet value ``w_bc_west`` through the w source terms.

   Only the i = 0 layer is touched (in place). The positive part of the
   west-face flux convw[0] and ``(vis3d - viscos)*aw_bound`` each add
   coefficient*w_bc_west to su3d and subtract the coefficient from sp3d.
   """
   flux_in=np.maximum(convw[0,:,:],0)
   su3d[0,:,:]=su3d[0,:,:]+flux_in*w_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-flux_in

   nu_extra=vis3d[0,:,:]-viscos
   coef_diff=nu_extra*aw_bound
   su3d[0,:,:]=su3d[0,:,:]+coef_diff*w_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-coef_diff
   return su3d,sp3d

def modify_k(su3d,sp3d,gen):
   """Add inlet source terms to the k equation, including a PANS commutation sink.

   On the first iteration of each time step (``iter == 0``), it adds the
   spanwise means of u, u**2, v**2 and w**2 in the first cell layer (i = 0)
   to running time sums; the sums are reset at ``itstep == 0``.

   From these sums it forms a total kinetic energy k_tot: half the
   resolved normal stresses plus the spanwise mean of k3d at i = 0.
   It then evaluates the PANS commutation term k_tot*u*(fk - 1)/dx at
   i = 0 and adds its negative part implicitly to sp3d.

   Finally it adds the upwind-inflow and diffusive west-inlet terms with
   ``k_bc_west``.

   Parameters
   ----------
   su3d, sp3d : source-term arrays, modified in place at i = 0.
   gen : used only for the printed diagnostic pk_max = max((vis3d-viscos)*gen).

   Returns
   -------
   su3d, sp3d, comm_term
       NOTE(review): ``comm_term`` is the all-zero (nj, nk) array created
       below, not ``comm_term_pans``, so modify_om receives zeros.
       Confirm whether this is intended.
   """

   global u2prim_i0,v2prim_i0,w2prim_i0,umean_i0

   # statistics are updated once per time step (first iteration only);
   # ``iter`` is a module-level iteration counter here, not the builtin
   if iter == 0:
# set running-averaged inlet values to zero
      if itstep == 0:
         u2prim_i0=np.zeros(nj)
         v2prim_i0=np.zeros(nj)
         w2prim_i0=np.zeros(nj)
         umean_i0=np.zeros(nj)

# time average: accumulate spanwise means at i = 0 (normalised by itstep+1 below)
      u2prim_i0=u2prim_i0+np.mean(u3d[0,:,:]**2,axis=1)
      v2prim_i0=v2prim_i0+np.mean(v3d[0,:,:]**2,axis=1)
      w2prim_i0=w2prim_i0+np.mean(w3d[0,:,:]**2,axis=1)
      umean_i0=umean_i0+np.mean(u3d[0,:,:],axis=1)

   # returned to the caller; never filled in this function
   comm_term=np.zeros((nj,nk))
   umean=umean_i0/(itstep+1)
   # resolved normal stresses; only u has its time mean removed
   u2prim=u2prim_i0/(itstep+1)-umean**2
   v2prim=v2prim_i0/(itstep+1)
   w2prim=w2prim_i0/(itstep+1)
   # total k = resolved part + spanwise mean of the modelled k at i = 0
   k_tot=0.5*(u2prim+v2prim+w2prim)+np.mean(k3d[0,:,:],axis=1)
# make it 2D
   k_tot=np.repeat(k_tot[:,None], repeats=nk, axis=1)

   # fk at the inlet cells from psi = fk3d[0]:
   # ((c_eps_2 - c_eps_1*psi)/(c_eps_2 - c_eps_1))**0.333, floored at 1e-10
   psi_small=fk3d[0,:,:]
   term1=np.maximum((c_eps_2-c_eps_1*psi_small)/(c_eps_2-c_eps_1),1e-10)
   fk2d_from_psi=term1**0.333

#  dfk_dx=u3d[0,:,:]*(0.4-1)/(x2d[1,1]-x2d[0,1])
   # u*dfk/dx as u*(fk - 1)/dx, i.e. assuming fk = 1 upstream of the inlet;
   # dx is taken from x2d at j = 1
   dfk_dx=u3d[0,:,:]*(fk2d_from_psi-1)/(x2d[1,1]-x2d[0,1])

# commutation term 
   comm_term_pans=k_tot*dfk_dx
   # diagnostics (printed only)
   comm_min_pans=np.min(comm_term_pans)
   pk_max=np.max((vis3d-viscos)*gen)

   u2prim_max=np.max(u2prim)
   v2prim_max=np.max(v2prim)
   w2prim_max=np.max(w2prim)

   print(f"\n{'comm_min_pans: '} {comm_min_pans:.2e}, {'pk_max: '}{pk_max:.2e}, {'u2prim_max: '}{u2prim_max:.2e}, {'v2prim_max: '}{v2prim_max:.2e}, {'w2prim_max: '}{w2prim_max:.2e}")
   # only the negative (sink) part is added, implicitly via sp3d (divided by k)
   sp3d[0,:,:]=sp3d[0,:,:]+np.minimum(comm_term_pans,0)*vol[0,:,:]/k3d[0,:,:]

   # upwind inflow and diffusive coupling with the inlet value k_bc_west
   su3d[0,:,:]= su3d[0,:,:]+np.maximum(convw[0,:,:],0)*k_bc_west
   sp3d[0,:,:]= sp3d[0,:,:]-np.maximum(convw[0,:,:],0)
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*k_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d,comm_term

def modify_eps(su3d,sp3d):
   """Add the west-inlet boundary condition to the eps-equation source terms.

   This is the same treatment as for u, v, w, k and omega. Only the first
   cell layer (i = 0) is modified, in place:
   - upwind inflow: where the west-face flux convw[0] is positive,
     ``eps_bc_west`` times the flux is added to su3d and the flux is
     subtracted from sp3d;
   - diffusive coupling: ``(vis3d - viscos)*aw_bound`` is applied in the
     same way.

   Returns
   -------
   su3d, sp3d
   """
   # Use only the inflow part of the flux for BOTH su3d and sp3d. Using the
   # raw convw in sp3d would be inconsistent with su3d and would make sp3d
   # positive wherever convw[0] < 0 (destabilising the linearised source).
   inflow=np.maximum(convw[0,:,:],0)
   su3d[0,:,:]=su3d[0,:,:]+inflow*eps_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-inflow
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*eps_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d

def modify_om(su3d,sp3d,comm_term):
   """Add inlet source terms to the omega equation (i = 0 layer, in place).

   An extra production -omeg/k*comm_term is added to su3d, positive part
   only and multiplied by the cell volume. Here omeg = 1/(0.09*t_scale),
   with t_scale = max(1/(0.09*omega), 6*sqrt(viscos/(0.09*k*omega))).

   Then the upwind inflow of ``om_bc_west`` and the diffusive coupling
   ``(vis3d - viscos)*aw_bound`` are added as for the other variables.
   """
   k_in=k3d[0,:,:]
   om_in=om3d[0,:,:]

   # omega from a time scale bounded below by a Kolmogorov-type scale
   t_kolmog=6.*(viscos/0.09/k_in/om_in)**0.5
   t_scale=np.maximum(1./0.09/om_in,t_kolmog)
   omeg=1./0.09/t_scale

   prod_extra=-omeg/k_in*comm_term
   su3d[0,:,:]=su3d[0,:,:]+np.maximum(prod_extra,0.)*vol[0,:,:]

   flux_in=np.maximum(convw[0,:,:],0)
   su3d[0,:,:]=su3d[0,:,:]+flux_in*om_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-flux_in

   nu_extra=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+nu_extra*aw_bound*om_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-nu_extra*aw_bound

   return su3d,sp3d

def fix_omega():
   """Hook to impose omega at the south wall; currently a no-op.

   The module-level coefficient arrays are returned unchanged. To fix
   omega in the j = 0 cells, enable the template below. It needs a
   module-level ``om_bc_south``, which this function does not define.
   """
#  aw3d[:,0,:]=0
#  ae3d[:,0,:]=0
#  as3d[:,0,:]=0
#  an3d[:,0,:]=0
#  al3d[:,0,:]=0
#  ah3d[:,0,:]=0
#  ap3d[:,0,:]=1
#  su3d[:,0,:]=om_bc_south

   coeffs=(aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d)
   return coeffs


def modify_outlet(convw):
   """Correct the east (outlet) face fluxes for global mass conservation.

   A uniform velocity correction ``uinc`` times the outlet face areas is
   added to convw[-1], in place. After the call, sum(convw[-1]) equals
   sum(convw[0]).

   Returns
   -------
   convw, u_bc_east
       ``u_bc_east`` is returned as found at module level; it is not
       computed here. NOTE(review): confirm the caller defines it.
   """
   # total flux through the first (inlet) and last (outlet) i-faces
   flow_in=np.sum(convw[0,:,:])
   flow_out=np.sum(convw[-1,:,:])
   ares=areaw[-1,:,:]
   area_out=np.sum(ares)

   # uniform correction velocity that makes the outlet flux equal flow_in
   uinc=(flow_in-flow_out)/area_out
   convw[-1,:,:]=convw[-1,:,:]+uinc*ares

   flow_out_new=np.sum(convw[-1,:,:])

   print('flow_in',flow_in,'flow_out',flow_out,'area_out',area_out,'flow_out_new',flow_out_new,'uinc:',uinc)

   return convw,u_bc_east

def modify_fk(fk3d):
   """Compute fk = l_u / l_tilde from an IDDES-type length-scale blending.

   Length scales:
   - l_u = k**1.5/eps is the RANS length scale.
   - l_c = psi*cdes*l_iddes is the LES length scale, with
     l_iddes = min(max(0.15*d, 0.15*delta_max, dy), delta_max).
   - psi is a low-Reynolds-number correction built from the damping
     functions fdampf2 and fmu, capped at 10. Their constants look like
     those of the AKN model (TODO confirm).

   The scales are blended with the IDDES functions f_d, f_e, f_b and f_dt.
   The equation numbers in the comments refer to the reference the code
   was written from. The input ``fk3d`` is not read; it is overwritten.

   Also accumulates spanwise means of f_e, f_b and f_dt into the
   module-level ``f_e_mean``, ``f_b_mean`` and ``f_dt_mean`` and saves
   them with np.save. The saved arrays are time sums, not divided by the
   number of samples.
   """

   global f_e_mean, f_b_mean, f_dt_mean

   # IDDES grid length: min(max(0.15*d, 0.15*delta_max, dy), delta_max)
   l_dist=0.15*dist3d
   l_max=0.15*delta_max
   # cell height in y, taken from the node rows y2d[1:,:]
   dy=np.diff(y2d[1:,:],axis=1)
# make it 3d
   dy=np.repeat(dy[:,:,None],repeats=nk,axis=2)

   l_temp=np.maximum(l_dist,l_max)
   l_temp=np.maximum(l_temp,dy)
   l_iddes=np.minimum(l_temp,delta_max)
#  l_iddes=np.minimum(np.maximum(l_dist,l_max,dy),delta_max)
#  (not equivalent: the third positional argument of np.maximum is ``out``)

   # low-Re damping: ystar from the Kolmogorov velocity (eps*viscos)**0.25,
   # rt = k**2/(eps*viscos) is the turbulent Reynolds number
   ueps=(eps3d*viscos)**0.25
   ystar=ueps*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   fdampf2=((1.-np.exp(-ystar/3.1))**2)*(1.-0.3*np.exp(-(rt/6.5)**2))
   fmu=((1.-np.exp(-ystar/14.))**2)*(1.+5./rt**0.75*np.exp(-(rt/200.)**2))
   fmu=np.minimum(fmu,1.)

   # low-Re correction of the LES length scale, capped at 10
   psi=np.minimum(10,(fdampf2*fmu)**(-0.75))

   l_c=psi*cdes*l_iddes  #eq. 9

   vist=vis3d-viscos
   # gen**0.5 serves as the velocity-gradient magnitude in r_dt and r_dl
   # (presumably gen is the squared strain-rate measure - confirm)
   denom=kappa**2*dist3d**2*gen**0.5

   r_dt=vist/denom  #eq. 22
   r_dl=viscos/denom  #eq. 23

   f_t=np.tanh((c_t**2*r_dt)**3)
   f_l=np.tanh((c_l**2*r_dl)**10)

   f_e2=1.-np.maximum(f_t,f_l) #eq. 19

   alpha=0.25-dist3d/delta_max

   f_e1=np.where(alpha <= 0,2*np.exp(-9*alpha**2),2*np.exp(-11.09*alpha**2))

   f_b=  np.minimum(2.*np.exp(-9*alpha**2),1.)

   f_dt=1.-np.tanh((8.*r_dt)**3)

   # elevating function f_e
   f_e=np.maximum(f_e1-1.,0.)*psi*f_e2

   # blending function: max(1 - f_dt, f_b)
   f_d=np.maximum((1.-f_dt),f_b)

   # RANS length scale
   l_u=k3d**1.5/eps3d

   # hybrid length scale: f_d*(1 + f_e)*l_u + (1 - f_d)*l_c
   l_tilde=f_d*(1+f_e)*l_u+(1-f_d)*l_c

   fk3d=l_u/l_tilde

#  l_tilde=f_d*l_u+(1-f_d)*l_c

#  fk3d=np.minimum(fk3d,1.9/1.5)


   # reset the statistics at the start of the run
   if iter == 0 and itstep ==0:
      f_e_mean=np.zeros((ni,nj))
      f_b_mean=np.zeros((ni,nj))
      f_dt_mean=np.zeros((ni,nj))



   # sample every itstep_stats steps once itstep >= itstep_start
   if iter == 0 and itstep%itstep_stats == 0 and itstep >= itstep_start:
      f_e_mean,f_b_mean,f_dt_mean= aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean)

   # save the (un-normalised) sums at the last step and every itstep_save steps
   if iter == 0 and (itstep == ntstep-1 or itstep%itstep_save == 0):
      print('IDDES functions saved')
      np.save('f_e_mean', f_e_mean)
      np.save('f_b_mean',f_b_mean)
      np.save('f_dt_mean',f_dt_mean)


   return fk3d


def aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean):
   """Add the spanwise (axis 2) means of f_e, f_b and f_dt to their running sums.

   Returns
   -------
   tuple
       The updated (f_e_mean, f_b_mean, f_dt_mean), as new arrays.
   """
   accumulators=((f_e_mean,f_e),(f_b_mean,f_b),(f_dt_mean,f_dt))
   return tuple(total+np.mean(field,axis=2) for total,field in accumulators)

def fix_k():
   """Hook to modify the k-equation coefficients; currently a no-op.

   Returns the module-level coefficient arrays unchanged.
   """
   return (aw3d,ae3d,as3d,an3d,al3d,ah3d,
           ap3d,su3d,sp3d)

def fix_eps():
   """Fix eps in the first cell layer next to the south wall (j = 0).

   The neighbour coefficients of the j = 0 cells are zeroed. The diagonal
   ap3d is set to the largest value found in ap3d, and su3d to that value
   times 2*viscos*k3d/dist3d**2. If sp3d plays no role there, this makes
   eps = 2*viscos*k/dist**2 in those cells. All arrays are module-level
   and are modified in place.
   """
   for coef in (aw3d,ae3d,as3d,an3d,al3d,ah3d):
      coef[:,0,:]=0
   ap_big=np.max(ap3d)
   ap3d[:,0,:]=ap_big
   su3d[:,0,:]=ap_big*2*viscos*k3d[:,0,:]/dist3d[:,0,:]**2

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d


def modify_vis(vis3d):
   """Hook to modify the viscosity field; returns ``vis3d`` unchanged."""
   unchanged=vis3d
   return unchanged


