# Module-level CPU/GPU switch. setup_case() assigns the same global again, and
# fix_k()/compute_ustar_reich_solve() read it to decide whether arrays have to
# be converted with xp.asnumpy()/xp.asarray().
gpu = True
import numpy
from cupyx.scipy import sparse
import sys
import time
import pyamg
from cupyx.scipy.sparse import spdiags,linalg,eye

def setup_case():
   """Set the case parameters (schemes, model constants, time stepping, BCs).

   Every value that other code needs is published through the ``global``
   statement below; the function itself returns nothing.  It reads (but does
   not assign) ``x2d``, ``y2d``, ``xp2d``, ``yp2d``, ``zmax``, ``dz``, ``ni``,
   ``nj``, ``nk`` and the array module ``xp``, so those must already exist in
   the module namespace when this function is called.
   """
# NOTE(review): 't_bc_sotth' in the list below looks like a typo for
# 't_bc_south'; neither name is assigned in this function, so it has no
# effect here -- confirm against the code that consumes these globals.
   global  acrank,acrank_conv, acrank_conv_keps, acrank_conv_kom,amg_cycle, amg_cycle_phi, amg_relax, amg_relax_phi, \
   blend,cdes,c_eps,c_eps_1,c_eps_2, c_l,c_omega_1, c_omega_2, cmu,c_t, coeff_v, coeff_w,\
   c_omega_1_sst_1, c_omega_2_sst_1, c_omega_1_sst_2,c_omega_2_sst_2, \
   convergence_limit_eps, convergence_limit_k, convergence_limit_om, convergence_limit_p, convergence_limit_t, convergence_limit_u, \
   convergence_limit_v, convergence_limit_w,cr_sst, convergence_limit_gpu, \
   cyclic_x, cyclic_z, dt, dist3d,dz, dmin_synt,embedded,eps_bc_east,eps_bc_east_type,eps_bc_north,eps_bc_north_type, eps_bc_south, \
   eps_bc_south_type, eps_bc_west, eps_bc_west_type, eps_bc_low, eps_bc_high, eps_bc_low_type,eps_bc_high_type,eps_min,fkmin_limit,\
   fx, fy,fz,imon,itstep_stats, gpu, save_average_z, i_s_fair, k_s_fair, i_fair,i_fair_embedd, \
   i_block_start,i_block_end,j_block_start,j_block_end,\
   itstep_save,itstep_start,jl0,jmirror_synt,jmon,kappa,k_bc_east,k_bc_east_type,k_bc_north,k_bc_north_type,k_bc_south,\
   k_bc_south_type,k_bc_west,k_bc_west_type,k_bc_low,k_bc_high,k_bc_low_type,k_bc_high_type,keps,keps_des,k_eq_les,kmon,kom, k_min,\
   kom_peng,kom_des,launder,L_t_synt,\
   maxit,min_iter, norm_order, ni,nj,nk,nmodes_synt,nsweep_keps,nsweep_kom, nsweep_t, \
   nsweep_vel, ntstep, om_bc_east, om_bc_east_type, om_bc_north, om_bc_north_type, \
   om_bc_south, om_bc_south_type, om_bc_west, om_bc_west_type, om_bc_low,om_bc_high, om_bc_low_type, om_bc_high_type, om_min,\
   p_bc_east, p_bc_east_type, p_bc_north, p_bc_north_type, p_bc_south, p_bc_south_type, p_bc_west, p_bc_west_type, p_bc_low,\
   p_bc_high, p_bc_low_type, p_bc_high_type,pans, prand_lam,prand_eps,\
   prand_k_sst_1, prand_k_sst_2, prand_omega_sst_1, prand_omega_sst_2, prand_k,prand_omega,prand_t, \
   resnorm_p,resnorm_vel,resnorm_t,restart,save,save_vtk_movie,scheme,scheme_turb,scheme_t, smag,solver_vel, solver_p, \
   solver_t,t_bc_east, t_bc_east_type, t_bc_north, t_bc_north_type, t_bc_sotth, t_bc_south_type, t_bc_west, t_bc_west_type, \
   t_bc_low,t_bc_high, t_bc_low_type, t_bc_high_type, temp, \
   solver_turb,solverx,sormax, s2,sst,sst_sst_uv_limit,u_bc_east, u_bc_east_type, u_bc_north, u_bc_north_type, \
   u_bc_south, u_bc_south_type, u_bc_west, u_bc_west_type, \
   u_bc_low,u_bc_high,u_bc_low_type,u_bc_high_type,urfvis, v_bc_east, v_bc_east_type, v_bc_north, v_bc_north_type, v_bc_south, \
   v_bc_south_type,v_bc_west,v_bc_west_type,v_bc_low,v_bc_high,v_bc_low_type,v_bc_high_type,viscos, vol,vtk,vtk_save,\
   vtk_file_name,w_bc_east,w_bc_east_type,w_bc_north,w_bc_north_type,\
   w_bc_south, w_bc_south_type, w_bc_west, w_bc_west_type, w_bc_low, w_bc_high, w_bc_low_type, w_bc_high_type,\
   wale, x_embed, x2d, xp2d, y2d, yp2d, z,zp,zmax

# local imports; neither 'np' nor 'sys' is used in the rest of this function
   import numpy as np
   import sys


# N.B. A variable set in this function is visible to other code only if it is
# listed in the 'global' statement above (the 'return' on the last line
# returns nothing).

########### section 0 choice of CPU or GPU ###################
   gpu = True

########### section 1 choice of differencing scheme ###########
# NOTE(review): the trailing remarks on the next two lines both say "hybrid"
# although the codes differ ('c' vs 'u'); the meaning of the codes is defined
# by the solver, not here -- confirm.
   scheme='c'  #hybrid
   scheme_turb='u'  #hybrid upwind-central 
   acrank=1.0  # for pressure gradient
   acrank_conv=0.5  # for convection-diffusion
   acrank_conv_kom=1  # for convection-diffusion
   acrank_conv_keps=1  # for convection-diffusion
   jl0=0
#  scheme_turb='h'  #hybrid upwind-central 


########### section 2 turbulence models ###########
# model switches: only keps_des is True in this case
   cmu=0.09
   pans = False
   fkmin_limit=0
   keps = False
   kom_des = False
   keps_des = True
   kom = False
   wale = False
   smag = False
   sst = False
   c_eps_1=1.5
   c_eps_2=1.9
# cmu and (below) jl0 are assigned a second time with the same values
   cmu=0.09
   c_omega_1= 5./9.
   c_omega_2=3./40.
   prand_omega=2.0
   prand_eps=1.4
   prand_k=1.4
   jl0=0
   cdes=0.67
   kappa=0.4
   c_t=1.87
   c_l=5

# model-dependent overrides; with the switches above none of them is active
   if keps:
      prand_k=1.4

   if pans: #pitm
      prand_k=1.4 
      prand_eps=1.4 

   if kom or kom_des:
      prand_k=2.0
   if smag:
      cmu=0.1

########### section 3 restart/save ###########
   restart = False
   save = True

########### section 4 fluid properties ###########
   viscos=1/16000

########### section 5 relaxation factors ###########
   urfvis=0.5

########### section 6 number of iterations and convergence criteria ###########
   maxit=5
   min_iter=2
   sormax=1e-3

########### section 7 all variables are printed during the iteration at node ###########
   imon=0
   jmon=0
   kmon=0

########### section 8 time-averaging ###########
   ntstep=20000
# 'uin' is not in the global list, so it is local to this function
   uin=20
# constant time step stored as an array with one entry per time step:
# 0.25 * (first streamwise grid spacing) / uin
   dt=0.25*(x2d[1,0]-x2d[0,0])*xp.ones(ntstep)/uin
   itstep_start=ntstep-10000
   itstep_save=2000  # save every itstep_save timestep
   itstep_stats=1 # time average every itstep_stats timestep
   vtk=False

########### section 9 residual scaling parameters ###########
   resnorm_p=uin*zmax*y2d[1,-1]
   resnorm_vel=uin**2*zmax*y2d[1,-1]

########### Section 10 boundary conditions ###########
   cyclic_x = False
   cyclic_z = True

# synthetic inlet fluctuations
   L_t_synt=0.2
   nmodes_synt=150
   jmirror_synt=int(nj/2) # mirror vsynt at node jmirror; jmirror=0 means no mirroring
   dmin_synt=dz/4

# NOTE(review): the '*_bc_z' and '*_bc_z_type' names assigned in the sections
# below are not in the global list (which has '*_bc_low'/'*_bc_high'), so they
# stay local to this function and are discarded on return.

# boundary conditions for u
   u_bc_west=xp.zeros((nj,nk))
   u_bc_east=xp.zeros((nj,nk))
   u_bc_south=xp.zeros((ni,nk))
   u_bc_north=xp.zeros((ni,nk))
   u_bc_z=0

   u_bc_west_type='d' 
   u_bc_east_type='n' 
   u_bc_south_type='n' # wall functions
   u_bc_north_type='n' # wall functions
   u_bc_z_type='n'

# boundary conditions for v
   v_bc_west=xp.zeros((nj,nk))
   v_bc_east=xp.zeros((nj,nk))
   v_bc_south=xp.zeros((ni,nk))
   v_bc_north=xp.zeros((ni,nk))
   v_bc_z=0

   v_bc_west_type='d' 
   v_bc_east_type='n' 
   v_bc_south_type='d'
   v_bc_north_type='d'
   v_bc_z_type='n'

# boundary conditions for w
   w_bc_west=xp.zeros((nj,nk))
   w_bc_east=xp.zeros((nj,nk))
   w_bc_south=xp.zeros((ni,nk))
   w_bc_north=xp.zeros((ni,nk))
   w_bc_z=0

   w_bc_west_type='d' 
   w_bc_east_type='n' 
   w_bc_south_type='d'
   w_bc_north_type='d'
   w_bc_z_type='d'

# boundary conditions for p
   p_bc_west=xp.zeros((nj,nk))
   p_bc_east=xp.zeros((nj,nk))
   p_bc_south=xp.zeros((ni,nk))
   p_bc_north=xp.zeros((ni,nk))
   p_bc_z=0

   p_bc_west_type='n'
   p_bc_east_type='n'
   p_bc_south_type='n'
   p_bc_north_type='n'
   p_bc_z_type='n'

# boundary conditions for k
   k_bc_west=xp.zeros((nj,nk))
   k_bc_east=xp.zeros((nj,nk))
   k_bc_south=xp.zeros((ni,nk))
   k_bc_north=xp.zeros((ni,nk))
   k_bc_z=0

   k_bc_west_type='d'
   k_bc_east_type='n'
   k_bc_south_type='d'
   k_bc_north_type='d'
   k_bc_z_type='n'

# boundary conditions for eps
   eps_bc_west=xp.zeros((nj,nk))
   eps_bc_east=xp.zeros((nj,nk))
   eps_bc_south=xp.zeros((ni,nk))
   eps_bc_north=xp.zeros((ni,nk))
   eps_bc_z=0

   eps_bc_west_type='d'
   eps_bc_east_type='n'
   eps_bc_south_type='d' 
   eps_bc_north_type='d' 
   eps_bc_z_type='n'

# boundary conditions for omega
   om_bc_west=xp.zeros((nj,nk))
   om_bc_east=xp.zeros((nj,nk))
   om_bc_south=xp.zeros((ni,nk))
   om_bc_north=xp.zeros((ni,nk))

# south wall: face midpoints from the j=0 grid line, squared distance from the
# face midpoint to the first cell centre, then omega = 10*6*viscos/(0.075*d^2)
# (this overwrites the zero array assigned to om_bc_south above)
   xwall_s=0.5*(x2d[0:-1,0]+x2d[1:,0])
   ywall_s=0.5*(y2d[0:-1,0]+y2d[1:,0])
   dist2_s=(yp2d[:,0]-ywall_s)**2+(xp2d[:,0]-xwall_s)**2
   om_bc_south=10*6*viscos/0.075/dist2_s

# make it 2D (copy the 1D wall profile to all nk spanwise planes)
   om_bc_south=xp.repeat(om_bc_south[:,None], repeats=nk, axis=1)

# north wall: same construction using the last grid line and last cell row
   xwall_n=0.5*(x2d[0:-1,-1]+x2d[1:,-1])
   ywall_n=0.5*(y2d[0:-1,-1]+y2d[1:,-1])
   dist2_n=(yp2d[:,-1]-ywall_n)**2+(xp2d[:,-1]-xwall_n)**2
   om_bc_north=10*6*viscos/0.075/dist2_n

# make it 2D
   om_bc_north=xp.repeat(om_bc_north[:,None], repeats=nk, axis=1)
   om_bc_z=0

   om_bc_west_type='d'
   om_bc_east_type='n'
   om_bc_south_type='d'
   om_bc_north_type='d'
   om_bc_z_type='n'

# nothing is returned: all results are passed on through the globals
   return 


def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Initialise u, k, eps and the viscosity from a tabulated RANS profile.

   The table 'y_u_k_om_uv_16000-RANS-code.txt' (columns used: 0 = y, 1 = u,
   2 = k, 3 = omega) is interpolated onto yp2d[0,:], copied to all nk
   spanwise planes and then to all ni streamwise planes.  k is scaled by 0.2,
   eps is cmu*k*omega of the unscaled profile, and vis3d = cmu*k^2/eps+viscos.
   v3d, w3d and om3d are returned as received; the global dist3d is appended
   to the returned tuple.
   """
   rans_table = xp.loadtxt('y_u_k_om_uv_16000-RANS-code.txt')
   y_table = rans_table[:,0]
   y_cells = yp2d[0,:]

   def spanwise_profile(column):
      # interpolate one table column to the cell centres, then make it 2D
      profile = xp.interp(y_cells, y_table, rans_table[:,column])
      return xp.repeat(profile[:,None], repeats=nk, axis=1)

   def fill_domain(plane):
      # set the inlet plane in the entire domain
      return xp.repeat(plane[None,:,:], repeats=ni, axis=0)

   u_inlet = spanwise_profile(1)
   k_inlet = spanwise_profile(2)
   om_inlet = spanwise_profile(3)
   eps_inlet = cmu*k_inlet*om_inlet

   u3d = fill_domain(u_inlet)
   k3d = 0.2*fill_domain(k_inlet)
   eps3d = fill_domain(eps_inlet)

   vis3d = cmu*k3d**2/eps3d+viscos

   return u3d,v3d,w3d,k3d,om3d,eps3d,vis3d,dist3d

def stress_EARSM(dudy,k_rans,om_rans):
    """Return (uu, vv, ww, uv) from an explicit algebraic stress relation.

    The only mean-flow gradient considered is dudy, so the normalised strain
    and rotation tensors have the single independent component
    s12 = om12 = 0.5*tau*dudy with tau = k/eps and eps = 0.09*k*omega.

    Parameters
    ----------
    dudy, k_rans, om_rans : 1D arrays of equal length (indexed with [j]).

    Returns
    -------
    uu, vv, ww, uv : 1D arrays created with xp.zeros, one value per cell.
    """
    n_cells = xp.size(dudy,0)
    uu=xp.zeros(n_cells)
    vv=xp.zeros(n_cells)
    ww=xp.zeros(n_cells)
    uv=xp.zeros(n_cells)

    # loop invariants (the original re-evaluated them for every cell)
    cpr1=9./4.*(1.8-1.)
    const=6./5.

    # diagonal strain components are zero for this one-gradient flow
    s11=0.
    s22=0.

    for j in range(n_cells):

        rk=k_rans[j]
        diss=.09*rk*om_rans[j]
        ttau=rk/diss

        # normalised rotation and strain components
        om12=ttau*0.5*dudy[j]
        om21=-om12
        s12=ttau*0.5*dudy[j]
        s21=s12

        # invariants of the rotation and strain tensors
        vor=(-2.*om12**2)
        str1=(s11**2+s12**2+s21**2+s22**2)

        # closed-form solution of the cubic equation for 'un'
        p1=(1./27.*cpr1**2+9./20.*str1-2./3.*vor)*cpr1
        p2=p1**2-(1./9.*cpr1**2+9./10.*str1+2./3.*vor)**3

        if p2 >  0:
           # sign-preserving cube root of (p1 - sqrt(p2))
           if p1-p2**0.5 >= 0:
              sigg=1.
           else:
              sigg=-1.
           un=cpr1/3.+(p1+p2**0.5)**(1./3.)+sigg*(abs(p1-p2**0.5))**(1./3.)
        else:
           un=cpr1/3.+2.*(p1**2-p2)**(1./6.)*xp.cos(1./3.*xp.arccos(p1/(xp.sqrt(p1**2-p2))))

        beta1=-const*un/(un**2-2.*vor)
        beta4=beta1/un

        uu[j]=2./3.*rk+rk*beta1*s11+rk*beta4*(s12*om21-om12*s21)
        vv[j]=2./3.*rk+rk*beta1*s22+rk*beta4*(s21*om12-om21*s12)
        ww[j]=2./3.*rk
        uv[j]=rk*(beta1*s12+beta4*(s11*om12-om12*s22))

    return uu,vv,ww,uv

def stress_Normal(dudy,k_rans,om_rans):
    """Return (uu, vv, ww, uv) from an eddy-viscosity relation.

    The three normal stresses are each 2k/3 (three separate objects) and the
    shear stress is -(k/omega)*dudy.
    """
    nu_t = k_rans/om_rans
    uu, vv, ww = [(2./3.)*k_rans for _ in range(3)]
    uv = -nu_t*dudy
    return uu,vv,ww,uv


def modify_inlet():
   """Build the west (inlet) boundary values: RANS profile + synthetic fluctuations.

   At itstep == 0 the accumulators are reset, the RANS profiles are read from
   a text file and interpolated onto yp2d[0,:], the Reynolds stresses are
   computed with stress_EARSM(), the 2D helper arrays needed by STG() are
   built and k/eps/omega inlet values are set.  On every call STG() is then
   evaluated, the area-weighted mean of usynt is removed, and running sums of
   the synthetic stresses and a two-point correlation are updated.

   Returns u_bc_west, v_bc_west, w_bc_west, k_bc_west, eps_bc_west,
   om_bc_west plus the globals u3d_face_w and convw (returned unmodified).

   Note: u_bc_west, v_bc_west and w_bc_west are not in the global statements,
   so the values computed here reach other code only via the return value.
   """

   global y_rans,u_rans,v_rans,k_rans,om_rans,uv_rans,zp,usynt_inlet,vsynt_inlet,wsynt_inlet,\
          uu_synt,vv_synt,ww_synt,uv_synt,two_corr,u_time,uv_aver,w_synt,k_bc_west,eps_bc_west,\
          r11,r22,r33,r12,r13,r23,xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,dz2d,uin,k_bc_west,\
          eps_bc_west,om_bc_west
  
   global usynt,vsynt,wsynt

# one-time initialisation at the first time step
   if itstep == 0:
      two_corr=xp.zeros(2*nk-1)
      uv_aver=0
# u_time is allocated here and saved below but never filled in this function
      u_time=xp.zeros(ntstep)
# NOTE(review): absolute, machine-specific path; modify_init() reads a file
# with the same name from the current directory -- confirm which is intended
      y_u_k_om  =xp.loadtxt('/chalmers/users/lada/pythons-rans-code-RANS/channel-16000-cyclicx/y_u_k_om_uv_16000-RANS-code.txt')
# table columns used: 0 = y, 1 = u, 2 = k, 3 = omega
      y_rans_in=y_u_k_om[:,0]
      y_rans=yp2d[0,:]
      u_rans_in=y_u_k_om[:,1]
      u_rans=xp.interp(y_rans, y_rans_in, u_rans_in)

      k_rans_in=y_u_k_om[:,2]
      k_rans=xp.interp(y_rans, y_rans_in, k_rans_in)

      om_rans_in=y_u_k_om[:,3]
      om_rans=xp.interp(y_rans, y_rans_in, om_rans_in)

      eps_rans=cmu*k_rans*om_rans
# eddy_visc is computed and made 2D below but not used afterwards
      eddy_visc = k_rans/om_rans
      dudy=xp.gradient(u_rans,y_rans)
      #uv=-eddy_visc*dudy  


      
# Reynolds stresses from the 1D profiles (stress_Normal() is the alternative)
      uu_in,vv_in,ww_in,uv_in = stress_EARSM(dudy,k_rans,om_rans)
      
# make it 2D (copy every 1D profile to all nk spanwise planes)
      u_rans=xp.repeat(u_rans[:,None], repeats=nk, axis=1)
      k_rans=xp.repeat(k_rans[:,None], repeats=nk, axis=1)
      eddy_visc=xp.repeat(eddy_visc[:,None], repeats=nk, axis=1)
      eps_rans=xp.repeat(eps_rans[:,None], repeats=nk, axis=1)
      om_rans=xp.repeat(om_rans[:,None], repeats=nk, axis=1)
      uu_in=xp.repeat(uu_in[:,None], repeats=nk, axis=1)
      vv_in=xp.repeat(vv_in[:,None], repeats=nk, axis=1)
      ww_in=xp.repeat(ww_in[:,None], repeats=nk, axis=1)
      uv_in=xp.repeat(uv_in[:,None], repeats=nk, axis=1)
 

# cell sizes on the inlet plane, taken from the i=0 grid line
      dy = xp.diff(y2d,axis=1)
      tmp = dy[0,:]
      dy2d = xp.repeat(tmp[:,None], repeats=nk, axis=1) 
      dx = xp.diff(x2d,axis=0)
      tmp = dx[0,0:-1]
      dx2d = xp.repeat(tmp[:,None], repeats=nk, axis=1)
      dz2d = xp.ones((nj,nk))*dz

        # Cell Centers
      zp = xp.linspace(0, zmax, nk)
      tmp = xp2d[0,:]
      xp2d_synt = xp.repeat(tmp[:,None], repeats=nk, axis=1) 
      tmp = yp2d[0,:]
      yp2d_synt = xp.repeat(tmp[:,None], repeats=nk, axis=1) 
      zp2d_synt=xp.repeat(zp[None,:], repeats=nj, axis=0)

      # wall distance on the inlet plane, taken from the global dist3d
      dw =dist3d[0,:,0]
      dw2d = xp.repeat(dw[:,None], repeats=nk, axis=1) 
 
        # # Get/Compute Reynolds Stress Tensor  
      r11 = uu_in
      r22 = vv_in
      r33 = ww_in
      r12 = uv_in
      r13 = xp.zeros((nj,nk))
      r23 = xp.zeros((nj,nk))
# area-weighted mean of u_rans over the inlet plane
      uin=xp.sum(u_rans*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      #usynt,vsynt,wsynt=synt_fluct(nmodes_synt,itstep,L_t_synt,y_rans,zp,uv_rans,viscos,jmirror_synt)
# first STG() call of the run; STG() is called once more after this if-block,
# so it is evaluated twice at itstep == 0
      usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# remove the area-weighted mean of usynt so that its net contribution is zero
# (easier to converge the p solver)
      uinc=xp.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
      usynt=usynt-uinc
      usynt_inlet=usynt
      vsynt_inlet=vsynt
      wsynt_inlet=wsynt
      

# running sums of the synthetic stresses (one value per j)
      uu_synt=xp.zeros(nj)
      vv_synt=xp.zeros(nj)
      ww_synt=xp.zeros(nj)
      uv_synt=xp.zeros(nj)
      
# inlet values for the turbulence quantities; om_rans is recomputed from
# eps_rans = cmu*k_rans*om_rans
      k_bc_west=k_rans
      eps_bc_west=eps_rans
      om_rans=eps_rans/cmu/k_rans
      om_bc_west=om_rans

# every call (including itstep == 0): new synthetic fluctuations
   usynt,vsynt,wsynt=STG(uin,viscos,itstep,dt[itstep],xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,\
                      dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans)
# remove the area-weighted mean of usynt so that its net contribution is zero
# (easier to converge the p solver)
   uinc=xp.sum(usynt*areaw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax
   usynt=usynt-uinc
   u_bc_west=u_rans+usynt
   v_bc_west=vsynt
   w_bc_west=wsynt

# accumulate spanwise-averaged products of the synthetic fluctuations
   uu_synt=uu_synt+xp.mean(usynt**2,axis=1)
   vv_synt=vv_synt+xp.mean(vsynt**2,axis=1)
   ww_synt=ww_synt+xp.mean(wsynt**2,axis=1)
   uv_synt=uv_synt+xp.mean(usynt*vsynt,axis=1)

# update face velocity and convw at inlet
#  u3d_face_w[0,:,:]=u_bc_west
#  convw[0,:,:]=-u_bc_west*areawx[0,:,:]-v_bc_west*areawy[0,:,:]

# compute two-point corr in node 10 (hard-coded j index; requires nj > 10)
   two_corr=two_corr+xp.correlate(w_bc_west[10,:],w_bc_west[10,:],'full')

# sum over timesteps (hard-coded j index 58; requires nj > 58)
   uv_aver=uv_aver+xp.mean(u_bc_west[58,:]*v_bc_west[58,:])

# save the running statistics every 100 time steps
   if itstep % 100 == 0:
      xp.save('two_corr_inlet',two_corr)
      xp.save('u_time',u_time)
      xp.save('uu_synt',uu_synt/(itstep+1))
      xp.save('uv_synt',uv_synt/(itstep+1))
      print('uvmean_aver at max',uv_aver/(itstep+1))

# u3d_face_w and convw are read from the module globals and returned unchanged
   return u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw

def modify_conv(convw,convs,convl):
   """Zero convs on the first and last j-planes; return all three arrays.

   convs is modified in place; convw and convl are passed through untouched.
   """
   for j_plane in (0, -1):
      convs[:,j_plane,:]=0

   return convw,convs,convl

def modify_u(su3d,sp3d):
   """Add wall-shear and inlet source terms to the u equation.

   su3d and sp3d are modified in place and returned.  The wall shear stress
   is taken as ustar^2 with ustar = cmu^0.25*sqrt(k) in the wall-adjacent
   cell; the inlet terms add inflow convection (max(convw,0)) and turbulent
   diffusion ((vis3d-viscos)*aw_bound) towards u_bc_west.
   """
   wall_face_area = areas[:,0,:]

# south wall (k is fixed at first cell)
   ustar_south = cmu**0.25*k3d[:,0,:]**0.5
   tauw_south = ustar_south**2
   su3d[:,0,:] = su3d[:,0,:]-tauw_south*wall_face_area

# north wall
# NOTE(review): the north wall also uses areas[:,0,:] (the j=0 faces), exactly
# as the original code does -- confirm this is intended and not areas[:,-1,:]
   ustar_north = cmu**0.25*k3d[:,-1,:]**0.5
   print('modify_u, ustar[-2,-2],[2,2]',ustar_north[-2,-2],ustar_north[2,2])
   tauw_north = ustar_north**2
   su3d[:,-1,:] = su3d[:,-1,:]-tauw_north*wall_face_area

# inlet: convective part
   inflow = xp.maximum(convw[0,:,:],0)
   su3d[0,:,:] = su3d[0,:,:]+inflow*u_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-inflow
# inlet: turbulent diffusion part
   turb_diff = (vis3d[0,:,:]-viscos)*aw_bound
   su3d[0,:,:] = su3d[0,:,:]+turb_diff*u_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-turb_diff

   return su3d,sp3d

def modify_v(su3d,sp3d):
   """Add wall-shear and inlet source terms to the v equation.

   South wall: the wall shear ustar^2 is scaled with |dy|/length of each
   wall segment (from the j=0 grid line) and applied against the sign of
   v3d in the first cell.  Inlet: same convective/diffusive treatment as for
   the other velocity components, using v_bc_west.  su3d and sp3d are
   modified in place and returned.
   """
# south wall (k is fixed at first cell)
   ustar_wall = cmu**0.25*k3d[:,0,:]**0.5
   tauw = ustar_wall**2

# direction of each wall segment
   seg_dx = xp.diff(x2d[:,0])
   seg_dy = xp.diff(y2d[:,0])
   seg_len = (seg_dx**2+seg_dy**2)**0.5
   yt = abs(seg_dy)/seg_len
   if itstep == 0:
      if iter == 0:
         print('yt',yt)
# make it 2D
   yt = xp.repeat(yt[:,None], repeats=nk, axis=1)
   tauw_v = tauw*abs(yt)
# downstream part of hump: assume V < 0 => su should be > 0.
#  tauw_v > 0, su = -tauw_v*(-1) = tauw_v  > 0 as it should
   su3d[:,0,:] = su3d[:,0,:]-tauw_v*areas[:,0,:]*xp.sign(v3d[:,0,:])

# inlet: convective part
   inflow = xp.maximum(convw[0,:,:],0)
   su3d[0,:,:] = su3d[0,:,:]+inflow*v_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-inflow
# inlet: turbulent diffusion part
   turb_diff = (vis3d[0,:,:]-viscos)*aw_bound
   su3d[0,:,:] = su3d[0,:,:]+turb_diff*v_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-turb_diff

   return su3d,sp3d


def modify_w(su3d,sp3d):
   """Add the inlet source terms to the w equation (in place) and return them.

   The i=0 plane receives inflow convection (max(convw,0)) and turbulent
   diffusion ((vis3d-viscos)*aw_bound) towards w_bc_west.
   """
# convective part
   inflow = xp.maximum(convw[0,:,:],0)
   su3d[0,:,:] = su3d[0,:,:]+inflow*w_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-inflow
# turbulent diffusion part
   turb_diff = (vis3d[0,:,:]-viscos)*aw_bound
   su3d[0,:,:] = su3d[0,:,:]+turb_diff*w_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-turb_diff

   return su3d,sp3d

def modify_k(su3d,sp3d,gen):
   """Add inlet and commutation source terms to the k equation.

   Running sums of the inlet-plane velocities are kept in module globals
   (reset at itstep == 0, updated once per time step when iter == 0).  From
   them a total k and a commutation term k_tot*u*(f_k-1)/dx are formed; its
   negative part is added to sp3d on the i=0 plane.  The usual inlet
   convection/diffusion terms towards k_bc_west are then added.

   su3d and sp3d are modified in place.  Returns su3d, sp3d and comm_term.

   NOTE(review): the returned comm_term is the all-zero array created below;
   the computed comm_term_pans is not returned -- confirm this is intended.
   """

   global u2prim_i0,v2prim_i0,w2prim_i0,umean_i0

   if iter == 0:
# set running-averaged inlet values to zero
      if itstep == 0:
         u2prim_i0=xp.zeros(nj)
         v2prim_i0=xp.zeros(nj)
         w2prim_i0=xp.zeros(nj)
         umean_i0=xp.zeros(nj)

# time average (sums of spanwise means on the i=0 plane)
      u2prim_i0=u2prim_i0+xp.mean(u3d[0,:,:]**2,axis=1)
      v2prim_i0=v2prim_i0+xp.mean(v3d[0,:,:]**2,axis=1)
      w2prim_i0=w2prim_i0+xp.mean(w3d[0,:,:]**2,axis=1)
      umean_i0=umean_i0+xp.mean(u3d[0,:,:],axis=1)

   comm_term=xp.zeros((nj,nk))
# only the u component has its squared mean subtracted; v2prim and w2prim are
# plain mean squares
   umean=umean_i0/(itstep+1)
   u2prim=u2prim_i0/(itstep+1)-umean**2
   v2prim=v2prim_i0/(itstep+1)
   w2prim=w2prim_i0/(itstep+1)
# resolved part plus the spanwise mean of the modelled k on the inlet plane
   k_tot=0.5*(u2prim+v2prim+w2prim)+xp.mean(k3d[0,:,:],axis=1)
# make it 2D
   k_tot=xp.repeat(k_tot[:,None], repeats=nk, axis=1)

# f_k estimated from the inlet-plane values of the global fk3d; the exponent
# 0.333 approximates a cube root and the argument is floored at 1e-10
   psi_small=fk3d[0,:,:]
   term1=xp.maximum((c_eps_2-c_eps_1*psi_small)/(c_eps_2-c_eps_1),1e-10)
   fk2d_from_psi=term1**0.333

#  dfk_dx=u3d[0,:,:]*(0.4-1)/(x2d[1,1]-x2d[0,1])
# u*(f_k - 1)/dx, with dx the first streamwise grid spacing at j=1
   dfk_dx=u3d[0,:,:]*(fk2d_from_psi-1)/(x2d[1,1]-x2d[0,1])

# commutation term 
   comm_term_pans=k_tot*dfk_dx
# diagnostics only (printed below)
   comm_min_pans=xp.min(comm_term_pans)
   pk_max=xp.max((vis3d-viscos)*gen)

   u2prim_max=xp.max(u2prim)
   v2prim_max=xp.max(v2prim)
   w2prim_max=xp.max(w2prim)

   print(f"\n{'comm_min_pans: '} {comm_min_pans:.2e}, {'pk_max: '}{pk_max:.2e}, {'u2prim_max: '}{u2prim_max:.2e}, {'v2prim_max: '}{v2prim_max:.2e}, {'w2prim_max: '}{w2prim_max:.2e}")
# only the negative part of the commutation term is used, as an implicit sink
   sp3d[0,:,:]=sp3d[0,:,:]+xp.minimum(comm_term_pans,0)*vol[0,:,:]/k3d[0,:,:]

# inlet: inflow convection and turbulent diffusion towards k_bc_west
   su3d[0,:,:]= su3d[0,:,:]+xp.maximum(convw[0,:,:],0)*k_bc_west
   sp3d[0,:,:]= sp3d[0,:,:]-xp.maximum(convw[0,:,:],0)
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*k_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d,comm_term

def modify_eps(su3d,sp3d):
   """Add the inlet source terms to the eps equation (in place) and return them.

   The i=0 plane receives inflow convection and turbulent diffusion
   ((vis3d-viscos)*aw_bound) towards eps_bc_west.

   The implicit convective term now uses max(convw,0), like the explicit term
   in this function and like modify_u/v/w/k/om.  The original subtracted the
   raw convw, which for negative convw (backflow) added a positive
   contribution to sp3d that was not matched in su3d.
   """
# convective part (inflow only)
   inflow = xp.maximum(convw[0,:,:],0)
   su3d[0,:,:]= su3d[0,:,:]+inflow*eps_bc_west
   sp3d[0,:,:]= sp3d[0,:,:]-inflow
# turbulent diffusion part
   vist=vis3d[0,:,:]-viscos
   su3d[0,:,:]=su3d[0,:,:]+vist*aw_bound*eps_bc_west
   sp3d[0,:,:]=sp3d[0,:,:]-vist*aw_bound

   return su3d,sp3d

def modify_om(su3d,sp3d,comm_term):
   """Add commutation and inlet source terms to the omega equation.

   On the i=0 plane a time scale max(1/(0.09*omega), 6*sqrt(viscos/(0.09*k*omega)))
   is converted back to an omega value, which scales -comm_term/k; only the
   positive part is added to su3d.  The usual inlet convection/diffusion
   terms towards om_bc_west follow.  su3d and sp3d are modified in place.
   """
   k_inlet = k3d[0,:,:]
   om_inlet = om3d[0,:,:]

# limited time scale and the corresponding omega
   t_kolmog = 6.*(viscos/0.09/k_inlet/om_inlet)**0.5
   t_scale = xp.maximum(1./0.09/om_inlet,t_kolmog)
   omeg = 1./0.09/t_scale

# extra production from the commutation term (positive part only)
   prod_extra = -omeg/k_inlet*comm_term
   su3d[0,:,:] = su3d[0,:,:]+xp.maximum(prod_extra,0.)*vol[0,:,:]

# inlet: convective part
   inflow = xp.maximum(convw[0,:,:],0)
   su3d[0,:,:] = su3d[0,:,:]+inflow*om_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-inflow
# inlet: turbulent diffusion part
   turb_diff = (vis3d[0,:,:]-viscos)*aw_bound
   su3d[0,:,:] = su3d[0,:,:]+turb_diff*om_bc_west
   sp3d[0,:,:] = sp3d[0,:,:]-turb_diff

   return su3d,sp3d

def fix_omega():
   """Hook for fixing omega in selected cells; currently a no-op.

   Returns the global coefficient and source arrays unchanged.  To fix omega
   in the first cell row one would zero the neighbour coefficients there,
   set ap3d to 1 and su3d to om_bc_south (left disabled, as in the original):

      aw3d[:,0,:]=0; ae3d[:,0,:]=0; as3d[:,0,:]=0; an3d[:,0,:]=0
      al3d[:,0,:]=0; ah3d[:,0,:]=0; ap3d[:,0,:]=1; su3d[:,0,:]=om_bc_south
   """
   coefficients = (aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d)
   return coefficients


def modify_outlet(convw):
   """Set the last i-plane of convw so that it balances the first plane.

   The flux on plane -2 is copied to plane -1 and a uniform correction
   (sum(convw[0]) - sum(convw[-2]))/sum(areaw[-1]) times the face area is
   added.  convw is modified in place; returns convw and the global u_bc_east.
   """
# total flux through the first plane and through the last-but-one plane
   flow_in = xp.sum(convw[0,:,:])
   flow_out = xp.sum(convw[-2,:,:])

   outlet_faces = areaw[-1,:,:]
   area_out = xp.sum(outlet_faces)

# uniform correction that removes the imbalance
   uinc = (flow_in-flow_out)/area_out
   convw[-1,:,:] = convw[-2,:,:]+uinc*outlet_faces

   print('area_out',area_out)

   flow_out_new = xp.sum(convw[-1,:,:])

   print('flow_in',flow_in,'flow_out',flow_out,'area_out',area_out,'flow_out_new',flow_out_new,'uinc:',uinc)

   return convw,u_bc_east

def modify_fk(fk3d):
   """Compute fk3d as the ratio of the length scales l_u and l_tilde.

   The incoming fk3d argument is not read; it is overwritten with
   l_u/l_tilde, where l_u = k^1.5/eps and l_tilde blends l_u with the length
   scale l_c through the shielding functions f_d and f_e.  Besides the
   argument, the function reads the globals dist3d, delta_max, y2d, eps3d,
   k3d, vis3d, gen, viscos, cdes, kappa, c_t, c_l, ni, nj, nk, iter, itstep,
   itstep_stats, itstep_start, itstep_save and ntstep.

   Spanwise-averaged sums of f_e, f_b and f_dt are kept in module globals
   and saved to disk periodically.
   """

   global f_e_mean, f_b_mean, f_dt_mean

# grid length scale: max(0.15*wall distance, 0.15*delta_max, dy), limited
# from above by delta_max
   l_dist=0.15*dist3d
   l_max=0.15*delta_max
# NOTE(review): y2d[1:,:] drops the first i grid line -- presumably so that
# dy matches the cell-centred shape of dist3d; confirm
   dy=xp.diff(y2d[1:,:],axis=1)
# make it 3d
   dy=xp.repeat(dy[:,:,None],repeats=nk,axis=2)

   l_temp=xp.maximum(l_dist,l_max)
   l_temp=xp.maximum(l_temp,dy)
   l_iddes=xp.minimum(l_temp,delta_max)
#  l_iddes=xp.minimum(xp.maximum(l_dist,l_max,dy),delta_max)

# low-Reynolds-number damping functions based on ystar and rt
   ueps=(eps3d*viscos)**0.25
   ystar=ueps*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   fdampf2=((1.-xp.exp(-ystar/3.1))**2)*(1.-0.3*xp.exp(-(rt/6.5)**2))
   fmu=((1.-xp.exp(-ystar/14.))**2)*(1.+5./rt**0.75*xp.exp(-(rt/200.)**2))
   fmu=xp.minimum(fmu,1.)

# psi is limited to at most 10
   psi=xp.minimum(10,(fdampf2*fmu)**(-0.75))

# (the "eq." numbers below refer to a publication that is not identified in
# this file)
   l_c=psi*cdes*l_iddes  #eq. 9

   vist=vis3d-viscos
# 'gen' is read from the module globals here (it is not a parameter)
   denom=kappa**2*dist3d**2*gen**0.5

   r_dt=vist/denom  #eq. 22
   r_dl=viscos/denom  #eq. 23

   f_t=xp.tanh((c_t**2*r_dt)**3)
   f_l=xp.tanh((c_l**2*r_dl)**10)

   f_e2=1.-xp.maximum(f_t,f_l) #eq. 19

   alpha=0.25-dist3d/delta_max

# different exponent on either side of alpha = 0
   f_e1=xp.where(alpha <= 0,2*xp.exp(-9*alpha**2),2*xp.exp(-11.09*alpha**2))

   f_b=  xp.minimum(2.*xp.exp(-9*alpha**2),1.)

   f_dt=1.-xp.tanh((8.*r_dt)**3)

   f_e=xp.maximum(f_e1-1.,0.)*psi*f_e2

   f_d=xp.maximum((1.-f_dt),f_b)

   l_u=k3d**1.5/eps3d

   l_tilde=f_d*(1+f_e)*l_u+(1-f_d)*l_c

   fk3d=l_u/l_tilde

# reset the accumulators on the very first iteration of the run
   if iter == 0 and itstep ==0:
      f_e_mean=xp.zeros((ni,nj))
      f_b_mean=xp.zeros((ni,nj))
      f_dt_mean=xp.zeros((ni,nj))



# accumulate once per sampled time step after itstep_start
   if iter == 0 and itstep%itstep_stats == 0 and itstep >= itstep_start:
      f_e_mean,f_b_mean,f_dt_mean= aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean)

# save the (un-normalised) sums at the last time step and every itstep_save steps
   if iter == 0 and (itstep == ntstep-1 or itstep%itstep_save == 0):
      print('IDDES functions saved')
      xp.save('f_e_mean', f_e_mean)
      xp.save('f_b_mean',f_b_mean)
      xp.save('f_dt_mean',f_dt_mean)


   return fk3d


def aver_iddes(f_e,f_b,f_dt,f_e_mean,f_b_mean,f_dt_mean):
   """Add the axis-2 mean of each field to its running sum.

   Returns the three updated sums as a tuple (f_e_mean, f_b_mean, f_dt_mean);
   the input arrays are not modified.
   """
   fields = (f_e, f_b, f_dt)
   running_sums = (f_e_mean, f_b_mean, f_dt_mean)
   return tuple(total + xp.mean(field,axis=2)
                for total, field in zip(running_sums, fields))

def fix_k():
    """Fix k in the wall-adjacent cells from a table-lookup friction velocity.

    First call (iter == 0 and itstep == 0): a table with columns x, y+, u+
    is loaded, u+ and y+ are min-max scaled and a KD-tree is built on the
    scaled (u+, y+) pairs.

    Every call: u+ and y+ in the first cell row (j = 0) are computed from the
    current k, clipped to the table range and scaled; the 5 nearest table
    entries are found and combined with inverse-distance weights to give a
    predicted y+, from which ustar = y+*viscos/dy.  Where the predicted
    y+ < 5, ustar = sqrt(viscos*|u|/dy) is used instead.  Finally
    k = cmu^(-0.5)*ustar^2 is imposed in the first and last j rows by
    zeroing the neighbour coefficients and setting ap3d/su3d.

    Returns the global arrays aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d, su3d
    and sp3d (the first eight are modified in place).
    """


# NOTE(review): dump, load are imported but not used in this function
    from joblib import dump, load
    from sklearn.preprocessing import MinMaxScaler
    from scipy.spatial import KDTree
 
  
    global yplus_min,yplus_max,scaler_yplus,scaler_pplus,neural_net,ubulk
    global ustar_mean_s,counter_ustar,count
    global uplus_min,uplus_max,scaler_uplus, tree, X
    global yplus_np_target,uplus_np_target,ustar2_mean_s,x_target,yplus_target
    global inds, ds_np,ds, nn_np, nn, ds,x_location,y_location,x2_location
# one-time set-up: load the target table and build the KD-tree
    if iter == 0 and itstep == 0: 
#      folder = '../hump-IDDES-ni-583-nk-64-go4hybrid-mesh-STG-dt-0.002-full-GPU/'
       folder = './'
       data_10 = xp.loadtxt(str(folder)+'x-yplus-uplus-ustar-ALL-i-j-hump-incl-backflow.txt')

# table columns used: 0 = x, 1 = y+ (absolute value taken), 2 = u+
       x_target = data_10[:,0]
       yplus_target = abs(data_10[:,1])
       uplus_target = data_10[:,2]

# table range, used below to clip the values looked up
       yplus_max = xp.max(yplus_target)
       yplus_min = xp.min(yplus_target)
       uplus_min = xp.min(uplus_target)
       uplus_max = xp.max(uplus_target)

       print('yplus_max, yplus_min, uplus_min, uplus_max',\
              yplus_max, yplus_min, uplus_min, uplus_max)

       uplus_target = uplus_target.reshape(-1,1)
       yplus_target = yplus_target.reshape(-1,1)
# use MinMax scaler
       scaler_yplus = MinMaxScaler()
       scaler_uplus = MinMaxScaler()

       X=xp.zeros((len(yplus_target),2))

# the scalers and the KD-tree get host (NumPy) arrays
       if gpu:
         yplus_np_target = xp.asnumpy(yplus_target)
         uplus_np_target = xp.asnumpy(uplus_target)
         X_np = xp.asnumpy(X)
       else:
         yplus_np_target = yplus_target
         uplus_np_target = uplus_target
         X_np = X

# column 0 = scaled u+, column 1 = scaled y+
       X_np[:,0] = scaler_uplus.fit_transform(uplus_np_target)[:,0]
       X_np[:,1] = scaler_yplus.fit_transform(yplus_np_target)[:,0]

       ustar_mean_s = xp.zeros(ni)
       ustar2_mean_s = xp.zeros(ni)
# u2d_wall, u_mean_k, u_wall_mean and ustar set here are not used inside this
# if-block; u2d_wall and ustar are recomputed below
       u2d_wall=u3d[:,0,:] # first cell
       u_mean_k=xp.mean(u2d_wall,axis=1)
       u_wall_mean = u_mean_k
       ustar=cmu**0.25*k3d[:,0,:]**0.5
       u1d = xp.mean(u3d[0,:,:],axis=1)
#      ubulk = xp.trapz(u1d,yp2d[0,:])/max(y2d[0,:])
       ubulk= xp.sum(u1d*xp.diff(y2d[0,:]))/max(y2d[0,:])
       print('ubulk',ubulk)

       tree = KDTree(X_np) # data we want to match

# accumulators for the table locations selected by the nearest neighbour
       x_location=xp.zeros((ni*nk))
       x2_location=xp.zeros((ni*nk))
       y_location=xp.zeros((ni*nk))

# take old ustar from k=cmu**(-0.5)*ustar**2
    ustar=cmu**0.25*k3d[:,0,:]**0.5

    j_wall = 0 # 1st cell
    u2d_wall=u3d[:,j_wall,:]  # first cell

# spanwise-averaged ustar (recomputed further down before it is used)
    ustar_mean_k=xp.mean(ustar,axis=1)

    dy=dist3d[:,j_wall,:] # first cell
    yplus_south = ustar*dy/viscos

    uplus_south = u2d_wall/ustar
    ustar_south = ustar

# count values larger/smaller than max/min
    yplus_min_number= (yplus_south < yplus_min).sum()
    yplus_max_number= (yplus_south > yplus_max).sum()
    print('south: yplus_min_number',yplus_min_number)
    print('south: yplus_max_number',yplus_max_number)

    uplus_min_number= (uplus_south < uplus_min).sum()
    uplus_max_number= (uplus_south > uplus_max).sum()
    print('south: uplus_min_number',uplus_min_number)
    print('south: uplus_max_number',uplus_max_number)

    ustar_min =xp.min(ustar)
    ustar_max =xp.max(ustar)
    print('dy,ustar_min,max,mean',dy[0,0],ustar_min,ustar_max,xp.mean(ustar))
    print('yplus, min,max,mean',xp.min(yplus_south),\
          xp.max(yplus_south),xp.mean(yplus_south))
    print('uplus, min,max,mean',xp.min(uplus_south),\
          xp.max(uplus_south),xp.mean(uplus_south))

# set limits (clip to the range covered by the table)
    uplus_south=xp.minimum(uplus_south,uplus_max)
    uplus_south=xp.maximum(uplus_south,uplus_min)

    yplus_south=xp.minimum(yplus_south,yplus_max)
    yplus_south=xp.maximum(yplus_south,yplus_min)

# build the scaled query points (one per first-row cell, ni*nk in total)
    N_predict=ni*nk
    y = uplus_south
    y = y.reshape(-1,1)
    yplus = yplus_south.reshape(-1,1)
    uplus = uplus_south.reshape(-1,1)
    n_test = len(uplus)
    x=xp.zeros((n_test,2))
    if gpu:
       yplus_np = xp.asnumpy(yplus)
       uplus_np = xp.asnumpy(uplus)
       x_np = xp.asnumpy(x)
    else:
       yplus_np = yplus
       uplus_np = uplus
       x_np = x

    x_np[:,0] = scaler_uplus.transform(uplus_np)[:,0]
    x_np[:,1] = scaler_yplus.transform(yplus_np)[:,0]
#ds, inds =  tree.query(x, 1) # finds one nearest neighbors at n_test samples  at distance d
    ds, inds =  tree.query(x_np, 5) # finds the five nearest neighbours of each query point; ds = distances, inds = table rows
# accumulate the table x and y+ of the closest neighbour
    if iter == 0 and itstep >= itstep_start:
       x_location = x_location + x_target[inds[:,0]]
       x2_location = x2_location + x_target[inds[:,0]]**2
       y_location = y_location + yplus_target[inds[:,0],0]
# despite the '_np' suffix, with gpu these are converted with xp.asarray
    if gpu:
       ds_np = xp.asarray(ds)
       inds_np = xp.asarray(inds)
    else:
       ds_np = ds
       inds_np = inds
    n_len0 = len(inds_np[:,0])
    yplus_kdtree =xp.zeros(n_len0)
    uplus_kdtree =xp.zeros(n_len0)
    n_len = len(inds_np[0,:])
    print('type(inds_np),type(ds_np)',type(inds_np),type(ds_np))
# inverse-distance weighting: sum(value/d) over the neighbours ...
# NOTE(review): a query point that coincides with a table entry gives
# ds == 0 and hence a division by zero here -- confirm this cannot happen
    for nn in range(n_len):
       temp1 =   yplus_np_target[inds[:,nn],0]/ds[:,nn]
       tempu =   uplus_np_target[inds[:,nn],0]/ds[:,nn]
       if gpu:
          temp1_np = xp.asarray(temp1)
          tempu_np = xp.asarray(tempu)
       else:
          temp1_np = temp1
          tempu_np = tempu
       yplus_kdtree = yplus_kdtree +  temp1_np
       uplus_kdtree = uplus_kdtree +  tempu_np
# ... divided by sum(1/d)
    yplus_kdtree = yplus_kdtree/xp.sum(1/ds_np,axis=1)
    uplus_kdtree = uplus_kdtree/xp.sum(1/ds_np,axis=1)

    if gpu:
       yplus_kdtree_cp = xp.asarray(yplus_kdtree)
       uplus_kdtree_cp = xp.asarray(uplus_kdtree)
    else:
       yplus_kdtree_cp = yplus_kdtree
       uplus_kdtree_cp = uplus_kdtree
#yplus_kdtree = (yplus_np[inds[:,0]]/ds[:,0]+yplus_np[inds[:,1]]/ds[:,1] \
#              +yplus_np[inds[:,2]]/ds[:,2]+yplus_np[inds[:,3]]/ds[:,3])/xp.sum(1/ds,axis=1)
    print('yplus_kdtree.shape,inds.shape',yplus_kdtree.shape,inds.shape)
    yplus_predict = xp.reshape(yplus_kdtree_cp,(ni,nk))
    uplus_predict = xp.reshape(uplus_kdtree_cp,(ni,nk))
# friction velocity from the predicted y+; where y+ < 5 use
# ustar = sqrt(viscos*|u|/dy) instead
    ustar=yplus_predict*viscos/dy
    yplus = yplus_predict
    ustar_5 = (viscos*abs(u2d_wall)/dy)**0.5
    ustar=xp.where(yplus< 5,ustar_5,ustar)

    number_yplus_linear= (yplus< 5).sum()

    print('number y+ < 5',number_yplus_linear)

#   if itstep%100 == 0:
#      xp.savetxt('umean.dat', u_mean)

# (disabled) use Reichardt's law for the first time steps
#   if itstep < 1000:
#      if itstep ==  997:
#      if itstep ==  0:
#        xp.savetxt('xp-Cf-yplus-re.dat', xp.c_[xp2d[:,1],Cf_south[:,0],yplus[:,0],re_south[:,0],\
#            ustar[:,0],u2d_wall[:,0],yplus_predict[:,0]])
#      ustar=xp.ones((ni,nk))
#      velabs=abs(u3d[:,0,:])
#      ustar=compute_ustar_reich_solve(dy/viscos,velabs,ustar)
# first time step: reset the statistics and start the history file for the
# hard-coded streamwise index i = 58 (requires ni > 58)
    if iter == 0:
      if itstep == 0:
         ustar_mean_s = xp.zeros(ni)
         ustar2_mean_s = xp.zeros(ni)
         count = 0
         inds_2d = xp.reshape(inds[:,0],(ni,nk))
         i = 58
         cf0=ustar[i,:]**2*xp.sign(u2d_wall[i,:])
         xp.savetxt('min-cf-hist-58.txt', \
        xp.c_[yplus_south[i,:],yplus_predict[i,:],uplus_south[i,:],uplus_predict[i,:],x_target[inds_2d[i,:]],inds_2d[i,:],cf0])


    if iter == 0:
# (same reset as in the block just above)
      if itstep == 0:
         ustar_mean_s = xp.zeros(ni)
         ustar2_mean_s = xp.zeros(ni)
         count = 0
# signed, spanwise-averaged ustar and ustar^2 are summed after itstep_start
# (strictly greater, whereas x_location above uses >=)
      if itstep > itstep_start:
         count = count +1
         ustar_mean_k=xp.mean(ustar*xp.sign(u2d_wall),axis=1)
         ustar2_mean_k=xp.mean(ustar**2*xp.sign(u2d_wall),axis=1)
         ustar_mean_old = ustar_mean_s
         ustar_mean_s=ustar_mean_s + ustar_mean_k
         ustar2_mean_s=ustar2_mean_s + ustar2_mean_k
         print('ustar_mean_old,ustar_mean_s[0],ustar_mean_s[0]/step',ustar_mean_old[0],ustar_mean_s[0],ustar_mean_s[0]/count)
# NOTE(review): this overwrites an entry of the global cell-centre array xp2d
# with the sample count (it ends up in the last row of 'ustar-vs-x.dat');
# any later use of xp2d[-1,1] as a coordinate is affected -- confirm
         xp2d[-1,1] = count
         if itstep%100 == 0:
#        if itstep%1 == 0:
            xp.savetxt('ustar-vs-x.dat', xp.c_[xp2d[:,1],ustar_mean_s,ustar2_mean_s])
            xp.savetxt('x-y-target-location.dat', xp.c_[x_location,y_location,x2_location])

            inds_2d = xp.reshape(inds[:,0],(ni,nk))

# append the current values at i = 58 to the history file
            i = 58
            cf0=ustar[i,:]**2*xp.sign(u2d_wall[i,:])
            f = open('min-cf-hist-58.txt','ab')
            xp.savetxt(f,\
        xp.c_[yplus_south[i,:],yplus_predict[i,:],uplus_south[i,:],uplus_predict[i,:],x_target[inds_2d[i,:]],inds_2d[i,:],cf0])
            f.close()




# wall value of k from the friction velocity
    kwall=cmu**(-0.5)*ustar**2

# first row (j = 0): decouple from the neighbours and impose k = kwall with a
# large diagonal coefficient
    aw3d[:,0,:]=0
    ae3d[:,0,:]=0
    as3d[:,0,:]=0
    an3d[:,0,:]=0
    al3d[:,0,:]=0
    ah3d[:,0,:]=0
    ap_max=xp.max(ap3d)
    ap3d[:,0,:]=ap_max
    su3d[:,0,:]=ap_max*kwall

# last row (j = -1)
# NOTE(review): the same kwall, computed from the j = 0 row, is imposed here
# -- confirm this is intended for the north wall
    aw3d[:,-1,:]=0
    ae3d[:,-1,:]=0
    as3d[:,-1,:]=0
    an3d[:,-1,:]=0
    al3d[:,-1,:]=0
    ah3d[:,-1,:]=0
    ap3d[:,-1,:]=ap_max
    su3d[:,-1,:]=ap_max*kwall

    return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d

def fix_eps():
   """Fix eps in the first and last j rows to ustar^3/(kappa*dy).

   For each wall row, ustar = cmu^0.25*sqrt(k) in that row and dy is the
   single value dist3d[0,j,0].  The neighbour coefficients of the row are
   zeroed, ap3d is set to the current maximum of ap3d and su3d to that
   maximum times the target eps.  The global coefficient arrays are modified
   in place and returned together with sp3d.
   """
   neighbour_coeffs = (aw3d,ae3d,as3d,an3d,al3d,ah3d)

# j_row = 0: south wall, j_row = -1: north wall
   for j_row in (0,-1):
      wall_dist = dist3d[0,j_row,0]
      ustar_row = cmu**0.25*k3d[:,j_row,:]**0.5  # k is fixed in first cell

      if j_row == -1:
         print('modify_eps, ustar[-2,-2],[2,2]',ustar_row[-2,-2],ustar_row[2,2])

# fix eps at 1st cell
      for coeff in neighbour_coeffs:
         coeff[:,j_row,:]=0
      ap_max = xp.max(ap3d)
      ap3d[:,j_row,:] = ap_max
      su3d[:,j_row,:] = ap_max*ustar_row**3/kappa/wall_dist

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d


def modify_vis(vis3d):
   """Hook for modifying the viscosity field; returns vis3d unchanged."""
   unchanged = vis3d
   return unchanged

def solve_ustar_reich(ustar,velabs,wdist):
    """Residual of Reichardt's law of the wall, ustar - velabs/u+(y+).

    y+ is taken as ustar*wdist, so ``wdist`` is expected to be the wall
    distance already divided by the viscosity (presumably dy/viscos, as in
    the commented-out call in fix_k -- confirm against the caller).

    Reichardt's profile is
        u+ = ln(1+0.4*y+)/0.4 + 7.8*(1 - exp(-y+/11) - (y+/11)*exp(-y+/3)).
    The original code had the last term outside the 7.8*(...) bracket and
    without the 1/11 factor, so the profile did not reduce to u+ = y+ close
    to the wall.

    NumPy is imported locally under the name ``xp`` so that the function
    operates on host arrays regardless of the module-level ``xp``.
    """
    import numpy as xp

    yplus = ustar*wdist
    uplus = 1/0.4*xp.log(1+0.4*yplus) \
          + 7.8*(1-xp.exp(-yplus/11)-yplus/11*xp.exp(-yplus/3))

    return ustar-velabs/uplus

def compute_ustar_reich_solve(wdist,velabs,ustar):
   """Solve the wall law ``solve_ustar_reich`` for the friction velocity.

   The three inputs are flattened and passed to ``scipy.optimize.newton``
   (vectorised over the elements) with ``ustar`` as the initial guess.
   When ``gpu`` is set the arrays are first copied to the host, because the
   SciPy solver and ``solve_ustar_reich`` operate on NumPy arrays.

   Parameters
   ----------
   wdist, velabs : arrays passed as the extra arguments of the residual
      (wall distance and velocity magnitude, judging by the names).
   ustar : array with the initial guess.

   Returns
   -------
   The solution as an ``xp`` array reshaped to ``(ni,nk)``.
   """
   from scipy.optimize import newton

   ustar=ustar.flatten()
   velabs=velabs.flatten()
   wdist=wdist.flatten()

   if gpu:
      # device -> host copies
      velabs_np = xp.asnumpy(velabs)
      ustar_np = xp.asnumpy(ustar)
      wdist_np = xp.asnumpy(wdist)
   else:
      velabs_np = velabs
      ustar_np = ustar
      wdist_np = wdist

   ustar = newton(solve_ustar_reich,x0=ustar_np,args=(velabs_np,wdist_np))

   # back to the active array module and to the (i,k) wall-plane layout
   ustar = xp.asarray(ustar)
   ustar=xp.reshape(ustar,(ni,nk))

   return ustar

def STG(U0,viscos,itstep,deltat,xp2d_synt,yp2d_synt,zp2d_synt,dw2d,dx2d,dy2d,dz2d,r11,r12,r13,r22,r23,r33,k_rans,om_rans):
   """Synthetic turbulence generator: velocity fluctuations on a 2-D plane as a sum of Fourier modes.

   On the call with ``itstep == 0`` the mode set is created and stored in
   module globals: number of modes ``N``, wavenumber magnitudes ``kn_wave``,
   wavenumber vectors ``kx,ky,kz``, phases ``psi``, unit directions
   ``sxio,syio,szio``, normalised mode amplitudes ``qnNorm``, ``le_max`` and
   the Cholesky factors ``a11,a21,a22,a31,a32,a33`` of the stress tensor.
   Every call (the first one included) evaluates the modes at time
   ``deltat*itstep`` and returns the fluctuations.

   The first call must therefore use ``itstep == 0``; later calls only read
   the stored globals. The random generator is seeded with a fixed value, so
   the mode set is the same in every run.

   Parameters
   ----------
   U0 : velocity used to shift the x coordinate with time.
   viscos : viscosity used in the viscous length scale.
   itstep, deltat : time step number and time step size.
   xp2d_synt, yp2d_synt, zp2d_synt : 2-D coordinate arrays of the plane
      (their common shape defines nj x nk).
   dw2d : 2-D array used in the length-scale limiters (presumably the wall
      distance).
   dx2d, dy2d, dz2d : 2-D arrays of cell sizes.
   r11, r12, r13, r22, r23, r33 : stress tensor components that are
      Cholesky-factorised.
   k_rans, om_rans : 2-D arrays; ``0.09*om_rans*k_rans`` is used as the
      dissipation rate.

   Returns
   -------
   usynt_aniso, vsynt_aniso, wsynt_aniso : 2-D arrays (nj x nk).
   """
   global N,kn_wave,le_max,kx,ky,kz,psi,qnNorm,sxio,syio,szio,a11,a21,a22,a31,a32,a33

   if itstep == 0:

      nj = int(xp.size(yp2d_synt,0))
      nk = int(xp.size(yp2d_synt,1))

      # length scales le, lcut, lnu, lt
      eps = 0.09*om_rans*k_rans

      lt  = k_rans**0.5/(0.09*om_rans)
      le  = xp.minimum(2.0*dw2d,3.0*lt)
      lnu = viscos**0.75/eps**0.25
      le_max = xp.amax(le)

      tmp = xp.maximum(dy2d,dz2d)
      hmax = xp.maximum(tmp,dx2d)
      tmpNew = xp.maximum(tmp,0.3*hmax) + 0.1*dw2d
      lcut = 2.0*xp.minimum(tmpNew,hmax)

      # corresponding wavenumbers ke, kcut, knu
      ke = 2.0*xp.pi/le
      kcut = 2.0*xp.pi/lcut
      knu = 2*xp.pi/lnu

      ke_min = 2*xp.pi/le_max
      k_min_STG = 0.5*ke_min
      k_max = 1.5*xp.amax(kcut)
      alpha = 0.01

      # geometric series of wavenumbers from k_min_STG to (at least) k_max
      N = int(xp.ceil(xp.log(k_max/k_min_STG)/xp.log(1+alpha) + 1))
      n = xp.linspace(1,N,N)
      kn = k_min_STG*(1+alpha)**(n-1)

      # wavenumber band widths (central differences, one-sided at the ends)
      dkn = xp.zeros(N)
      dkn[0] = 0.5*(kn[1]-kn[0])
      dkn[1:-1] = 0.5*(kn[2:]-kn[0:-2])
      dkn[-1] = 0.5*(kn[-1]-kn[-2])

      # fixed seed: the same mode set is generated in every run
      xp.random.seed(2)

      # random angles (the order of the draws fixes the random sequence)
      fi = xp.random.uniform(0.,2.*math.pi,N)
      psi = xp.random.uniform(0.,2.*math.pi,N)
      alfa = xp.random.uniform(0.,2.*math.pi,N)
      ang = xp.random.uniform(0.,1,N)
      teta=xp.arccos(1.-ang/0.5)

      # unit wavenumber vector from the random angles
      kxio=xp.sin(teta)*xp.cos(fi)
      kyio=xp.sin(teta)*xp.sin(fi)
      kzio=xp.cos(teta)

      # sigma (s=sigma) from random angles. sigma is the unit direction which
      # gives the direction of the synthetic velocity vector (u, v, w)
      sxio=xp.cos(fi)*xp.cos(teta)*xp.cos(alfa)-xp.sin(fi)*xp.sin(alfa)
      syio=xp.sin(fi)*xp.cos(teta)*xp.cos(alfa)+xp.cos(fi)*xp.sin(alfa)
      szio=-xp.sin(teta)*xp.cos(alfa)

      # expand the 2-D fields and the 1-D wavenumbers to (nj,nk,N)
      kcut_wave =  xp.repeat(kcut[:,:,None], repeats=N, axis=2)
      knu_wave  =  xp.repeat(knu[:,:,None], repeats=N, axis=2)
      ke_wave  =  xp.repeat(ke[:,:,None], repeats=N, axis=2)
      kn_wave = xp.repeat(kn[:,None],repeats=nj,axis=1)
      kn_wave = xp.repeat(kn_wave[:,:,None],repeats=nk,axis=2)
      kn_wave = xp.transpose(kn_wave,(1,2,0))

      # energy spectrum with viscous (fnu) and grid cut-off (fcut) damping
      fnu = xp.exp(-(12.0*kn_wave/knu_wave)**2)
      fcut = xp.exp(-(4.0*xp.maximum(kn_wave-0.9*kcut_wave,0.0)/kcut_wave)**3)

      E = (kn_wave/ke_wave)**4/(1.0+2.4*(kn_wave/ke_wave)**2)**(17/6)*fnu*fcut

      # mode amplitudes, normalised so that they sum to one at every point
      qn = E*dkn
      qnSum = xp.sum(qn,axis=2)

      qnNorm = qn/xp.repeat(qnSum[:,:,None], repeats=N, axis=2)

      kx=kxio*kn
      ky=kyio*kn
      kz=kzio*kn

      # Cholesky decomposition of the stress tensor
      a11 = xp.sqrt(r11)
      a21 = r12/(a11)
      a22 = xp.sqrt(r22 - a21*a21)
      a31 = r13/(a11)
      a32 = (r23 - a21*a31)/(a22)
      a33 = xp.sqrt(r33 - a31*a31 - a32*a32)

      print('le_min :  ' + str(xp.amin(le)))
      print('lcut_min :  ' + str(xp.amin(lcut)))
      print('lnu_min :  ' + str(xp.amin(lnu)))
      print('k_min_STG :  ' + str(k_min_STG))
      print('k_max :  ' + str(k_max))
      print('NModes STG:  ' + str(N))

   # phase of every mode at every point; x is shifted with U0*t and rescaled
   xp2d_wave=xp.repeat(xp2d_synt[:,:,None], repeats=N, axis=2)
   arg1=(2*xp.pi/(kn_wave*le_max))*(xp2d_wave-U0*deltat*itstep)*kx

   yp2d_wave=xp.repeat(yp2d_synt[:,:,None], repeats=N, axis=2)
   arg2=yp2d_wave*ky

   zp2d_wave=xp.repeat(zp2d_synt[:,:,None], repeats=N, axis=2)
   arg3=zp2d_wave*kz

   arg=arg1+arg2+arg3+psi

   tfunk=xp.cos(arg)

   # sum over all wavenumbers => synthetic velocity field
   usynt= xp.sqrt(6.0)*xp.sum(xp.sqrt(qnNorm)*tfunk*sxio,axis=2)
   vsynt= xp.sqrt(6.0)*xp.sum(xp.sqrt(qnNorm)*tfunk*syio,axis=2)
   wsynt= xp.sqrt(6.0)*xp.sum(xp.sqrt(qnNorm)*tfunk*szio,axis=2)

   # impose the prescribed stresses with the Cholesky factors
   usynt_aniso = a11*usynt
   vsynt_aniso = a21*usynt+a22*vsynt
   wsynt_aniso = a31*usynt+a32*vsynt+a33*wsynt

   return usynt_aniso,vsynt_aniso,wsynt_aniso

import cupy as xp
import time,random,sys
#from scipy.signal import welch, hann
import math


def synt_fluct(nmodes,it,sli,yp,zp,uv_rans,visc,jmirror,dmin_synt):
   """Generate one realisation of anisotropic synthetic fluctuations on a y-z plane.

   On the call with ``it == 0`` the set-up is done and stored in module
   globals: the eigenvectors ``r11..r33`` are read from ``R.dat`` and the
   eigenvalues ``a11,a22,a33`` from ``a.dat``, the plane coordinates are
   rotated to the principal directions, and the wavenumber range and the
   non-dimensional ``uv_rans`` profile are computed. The first call must
   therefore use ``it == 0``.

   Every call draws new random angles, sums the Fourier modes weighted with a
   von Karman spectrum, rotates the result back to x-y-z and scales it with
   the ``uv_rans`` profile so that the running mean of the synthetic shear
   stress matches ``max(|uv_rans|)``.

   Parameters
   ----------
   nmodes : number of modes.
   it : time step number (0 on the first call).
   sli : length scale.
   yp, zp : 1-D coordinates of the plane.
   uv_rans : 1-D shear stress profile along ``yp`` (its absolute value is used).
   visc : viscosity.
   jmirror : if > 0, the v fluctuation is sign-flipped for j >= jmirror.
   dmin_synt : lower limit of the grid step that sets the highest wavenumber.

   Returns
   -------
   usynt_aniso, vsynt_aniso, wsynt_aniso : 2-D arrays of shape (len(yp), len(zp)).
   """
   global dxmin,amp,epsm,wnr1,wew1fct,yp2d_synt,zp2d_synt,xp2d_synt,nj,up
   global e,kxio,kyio,sxio,syio,utn,tfunk,wnre,dkn,arg1,arg2,arg3,arg,fi,psi,teta,alfa,wnr,kx,ky,kz,wnrn,\
          r11,r12,r13,r21,r22,r23,r31,r32,r33,a11,a22,a33,wnreta,uv_synt_mean,uvmean_check,a11i,a22i,a33i,\
          xp2d_wave,yp2d_wave,zp2d_wave,uv_rans_non,uv_rans_max
   global sx,rk,kxi,usynt,sy,vsynt,yp2d_org,usynt_aniso,scale_2

   uv_rans=xp.abs(uv_rans)

#=========================== chapter 1: set-up (first call only) ===================
   if it == 0:

      # running sums used for the time averages
      uvmean_check=0
      uv_synt_mean=0

      # anisotropic fluctuations: eigenvectors ...
      R=xp.genfromtxt("R.dat", dtype=None,comments="%")
      r11=R[0,0]
      r12=R[0,1]
      r13=R[0,2]
      r21=R[1,0]
      r22=R[1,1]
      r23=R[1,2]
      r31=R[2,0]
      r32=R[2,1]
      r33=R[2,2]

      # ... and eigenvalues
      A=xp.genfromtxt("a.dat", dtype=None,comments="%")
      a11=A[0]
      a22=A[1]
      a33=A[2]

      amp=1.452762113
      wew1fct=2

      # in log region:  k/uvmax=3.3  => up=(3.3*uvmax)**0.5
      if xp.all(uv_rans==1):
         up=1
      else:
         up=(3.3*xp.max(uv_rans))**0.5

      epsm=up**3/sli

      # number of grid points in y, z
      nj=len(yp)
      nk=len(zp)

      # untransformed 2-D coordinates of the plane (located at x=xpp)
      zp2d_org=xp.repeat(zp[None,:], repeats=nj, axis=0)
      yp2d_org=xp.repeat(yp[:,None], repeats=nk, axis=1)
      xpp=0.5

      # transform to principal coord. directions. All three lines must use
      # the untransformed coordinates on the right-hand side.
      xp2d_synt=r11*xpp+r21*yp2d_org+r31*zp2d_org
      yp2d_synt=r12*xpp+r22*yp2d_org+r32*zp2d_org
      zp2d_synt=r13*xpp+r23*yp2d_org+r33*zp2d_org

      # search min grid step
      dminy=xp.min(xp.diff(yp))
      dminz=xp.min(xp.diff(zp))
      dxmin=min(dminy,dminz)

      # don't let it be smaller than dmin_synt
      dxmin=max(dxmin,dmin_synt)

      # fixed seed: the same random sequence is used in every run
      xp.random.seed(2)

      # highest wave number
      wnrn=2.*math.pi/dxmin

      # k_e (related to peak energy wave number)
      wnre=9.*math.pi*amp/(55.*sli)

      # wavenumber used in the viscous expression (high wavenumbers) in the von Karman spectrum
      wnreta=(epsm/visc**3)**0.25

      # smallest wavenumber
      wnr1=wnre/wew1fct

      # wavenumber step
      dkn=(wnrn-wnr1)/nmodes

      # wavenumbers
      wnr=xp.linspace(wnr1,wnrn,nmodes)

      # invert the eigenvalue matrix (anisotropic)
      a11i=1/a11
      a22i=1/a22
      a33i=1/a33

      # make a non-dimensional uv_rans profile
      uv_rans_max=xp.max(uv_rans)
      uv_rans_non=(uv_rans/(uv_rans_max+1e-10))**0.5
      # make it 2D
      uv_rans_non=xp.repeat(uv_rans_non[:,None], repeats=nk, axis=1)

      print(f"\n{'sli: '} {sli}, {'visc: '}{visc:.2e}, {'nmodes: '}{nmodes}, {'dxmin: '}{dxmin:.3e}, {'dkn: '}{dkn:.3e}, {'dmin_synt: '}{dmin_synt:.3e}")

      print(f"\n{'wnre: '} {wnre:.2e}, {'wnr1: '}{wnr1:.2e}, {'epsm: '}{epsm:.3e}, {'wnrn: '}{wnrn:.3e}")

      print(f"\n{'eigenvalue 1, 2 and 3: '}{a11:.3e}, {a22:.3e}, {a33:.3e}")
      print(f"\n{'eigenvector R11, R12 and R13: '}{r11:.3e}, {r12:.3e}, {r13:.3e}")
      print(f"\n{'eigenvector R21, R22 and R23: '}{r21:.3e}, {r22:.3e}, {r23:.3e}")
      print(f"\n{'eigenvector R31, R32 and R33: '}{r31:.3e}, {r32:.3e}, {r33:.3e}\n")

#=========================== chapter 2: random angles ==============================

   # the order of the draws fixes the random sequence
   fi = xp.random.uniform(0.,2.*math.pi,nmodes)
   psi = xp.random.uniform(0.,2.*math.pi,nmodes)
   alfa = xp.random.uniform(0.,2.*math.pi,nmodes)
   ang = xp.random.uniform(0.,1,nmodes)
   teta=xp.arccos(1.-ang/0.5)

   print('time step no,',it)

   # wavenumber vector from random angles
   kxio=xp.sin(teta)*xp.cos(fi)
   kyio=xp.sin(teta)*xp.sin(fi)
   kzio=xp.cos(teta)

   # sigma (s=sigma) from random angles. sigma is the unit direction which
   # gives the direction of the synthetic velocity vector (u, v, w)
   sxio=xp.cos(fi)*xp.cos(teta)*xp.cos(alfa)-xp.sin(fi)*xp.sin(alfa)
   syio=xp.sin(fi)*xp.cos(teta)*xp.cos(alfa)+xp.cos(fi)*xp.sin(alfa)
   szio=-xp.sin(teta)*xp.cos(alfa)

#=========================== chapter 3: sum the modes ==============================

   # rotate wavenumber and direction vectors to the principal directions
   kxi=r11*kxio+r21*kyio+r31*kzio
   kyi=r12*kxio+r22*kyio+r32*kzio
   kzi=r13*kxio+r23*kyio+r33*kzio

   sxi=r11*sxio+r21*syio+r31*szio
   syi=r12*sxio+r22*syio+r32*szio
   szi=r13*sxio+r23*syio+r33*szio

   # stretch with the eigenvalues
   sx=a11**0.5*sxi
   sy=a22**0.5*syi
   sz=a33**0.5*szi

   kx=kxi*wnr*a11i**0.5
   ky=kyi*wnr*a22i**0.5
   kz=kzi*wnr*a33i**0.5
   rk=xp.sqrt(kx**2+ky**2+kz**2)

   # phase of every mode at every point of the plane
   xp2d_wave=xp.repeat(xp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg1=xp2d_wave*kx

   yp2d_wave=xp.repeat(yp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg2=yp2d_wave*ky

   zp2d_wave=xp.repeat(zp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg3=zp2d_wave*kz

   arg=arg1+arg2+arg3+psi

   tfunk=xp.cos(arg)

   # von Karman spectrum
   e=amp/wnre*(wnr/wnre)**4/((1.+(wnr/wnre)**2)**(17./6.))*xp.exp(-2*(wnr/wnreta)**2)

   # include only wavenumber for which rk < wnrn
   e=xp.where(rk < wnrn,e,0)

   utn=xp.sqrt(e*up**2*dkn)

   # sum over all wavenumbers => synthetic velocity field
   usynt=xp.sum(2.*utn*tfunk*sx,axis=2)
   vsynt=xp.sum(2.*utn*tfunk*sy,axis=2)
   wsynt=xp.sum(2.*utn*tfunk*sz,axis=2)

   # transform back to x-y-z  => anisotropic fluct
   usynt_aniso=r11*usynt+r12*vsynt+r13*wsynt
   vsynt_aniso=r21*usynt+r22*vsynt+r23*wsynt
   wsynt_aniso=r31*usynt+r32*vsynt+r33*wsynt

   # mean shear stress (must be computed before mirroring)
   uv=xp.mean(usynt_aniso*vsynt_aniso)

   # mirror vfluct
   if jmirror > 0:
      vsynt_aniso[jmirror:,:]=-vsynt_aniso[jmirror:,:]

   # sum over timesteps
   uv_synt_mean=uv_synt_mean+uv

   # compute average
   uvmean_time=xp.abs(uv_synt_mean)/(it+1)
   print('uvmean_time',uvmean_time)

   scale_2=(uv_rans_max/uvmean_time)**0.5*uv_rans_non

   # scale all fluctuations with uv_rans
   usynt_aniso=usynt_aniso*scale_2
   vsynt_aniso=vsynt_aniso*scale_2
   wsynt_aniso=wsynt_aniso*scale_2

   # compute mean of synt fluct
   uvmean_check=uvmean_check+xp.mean(usynt_aniso*vsynt_aniso,axis=1)

   # peak of uv_rans
   j=xp.where(uv_rans == xp.amax(uv_rans))

   # check peak
   print('synt: uvmean',xp.abs(uvmean_check[j])/(it+1),'uv_rans_max=',xp.max(xp.abs(uv_rans)),'at j=',j)

   return usynt_aniso,vsynt_aniso,wsynt_aniso
if gpu:
   from cupyx.scipy import sparse
   import cupy as xp
   from cupyx.scipy.sparse import spdiags,linalg,eye
else:
   from scipy import sparse
   import numpy as xp
   from scipy.sparse import spdiags,linalg,eye
import sys
import time
import pyamg
import pyamgx

import socket

def init():
   """Compute the geometric quantities of the grid from the module-level grid arrays.

   Reads ``x2d, y2d`` (grid corner coordinates), ``xp2d, yp2d`` (cell centre
   coordinates), ``z, zp, dz3d``, the sizes ``ni, nj, nk`` and the flags
   ``cyclic_x, cyclic_z``.

   Returns
   -------
   areaw,areawx,areawy : magnitude and x, y components of the west face area
   areas,areasx,areasy : magnitude and x, y components of the south face area
   areal : area of the low/high faces (nk+1 identical planes)
   vol : cell volume
   fx,fy,fz : geometric interpolation weights to the west, south and low faces
   aw_bound,ae_bound,as_bound,an_bound,al_bound,ah_bound :
      area**2/(0.5*vol) at the boundaries (coefficients without viscosity)
   dist3d : distance to the nearest of the south and north walls
   """
   print('hostname: ',socket.gethostname())

   # ---- distance to nearest wall (south wall at j=0, north wall at j=-1)
   y_wall_south=0.5*(y2d[0:-1,0]+y2d[1:,0])
   y_wall_north=0.5*(y2d[0:-1,-1]+y2d[1:,-1])
   to_south=yp2d-y_wall_south[:,None]
   to_north=y_wall_north[:,None] -yp2d
   dist3d=xp.repeat(xp.minimum(to_south,to_north)[:,:,None], repeats=nk, axis=2)

   # ---- fx: weight from the distances west face -> cell centre and
   #      west face -> west neighbour centre
   x_face_w=0.5*(x2d[0:-1,0:-1]+x2d[0:-1,1:])
   y_face_w=0.5*(y2d[0:-1,0:-1]+y2d[0:-1,1:])
   d_cell=((x_face_w-xp2d)**2+(y_face_w-yp2d)**2)**0.5
   d_west=((x_face_w-xp.roll(xp2d,1,axis=0))**2+(y_face_w-xp.roll(yp2d,1,axis=0))**2)**0.5
   fx=xp.dstack([d_west/(d_cell+d_west)]*nk)
   if cyclic_x:
      fx[0,:,:]=0.5

   # ---- fy: same for the south face
   x_face_s=0.5*(x2d[0:-1,0:-1]+x2d[1:,0:-1])
   y_face_s=0.5*(y2d[0:-1,0:-1]+y2d[1:,0:-1])
   d_cell=((x_face_s-xp2d)**2+(y_face_s-yp2d)**2)**0.5
   d_south=((x_face_s-xp.roll(xp2d,1,axis=1))**2+(y_face_s-xp.roll(yp2d,1,axis=1))**2)**0.5
   fy=xp.dstack([d_south/(d_cell+d_south)]*nk)

   # ---- fz: same for the low face (1-D in z, then expanded to 3-D)
   z_face_l=z[0:-1]
   d_cell_z=zp-z_face_l
   d_low_z=xp.abs(z_face_l-xp.roll(zp,1))
   fz=d_low_z/(d_cell_z+d_low_z)
   fz=xp.repeat(fz[None,:], repeats=nj, axis=0)
   fz=xp.repeat(fz[None,:,:], repeats=ni, axis=0)
   if cyclic_z:
      fz[:,:,0]=0.5

   # ---- grid line increments, used for both face areas and cell volumes
   dx_j=xp.diff(x2d,axis=1)
   dy_j=xp.diff(y2d,axis=1)
   dx_i=xp.diff(x2d,axis=0)
   dy_i=xp.diff(y2d,axis=0)

   # ---- west and south face area vectors, made 3-D and multiplied by dz
   areawx=xp.dstack([-dy_j]*nk)*dz3d[:,1:,:]
   areawy=xp.dstack([dx_j]*nk)*dz3d[:,1:,:]
   areasx=xp.dstack([dy_i]*nk)*dz3d[1:,:,:]
   areasy=xp.dstack([-dx_i]*nk)*dz3d[1:,:,:]

   areaw=(areawx**2+areawy**2)**0.5
   areas=(areasx**2+areasy**2)**0.5

   # ---- x-y area of a cell approximated as the sum of two triangles
   #      (each from the vector product of two cell edges)
   triangle_1=0.5*xp.absolute(dx_j[0:-1,:]*dy_i[:,0:-1]-dy_j[0:-1,:]*dx_i[:,0:-1])
   triangle_2=0.5*xp.absolute(dx_j[1:,:]*dy_i[:,0:-1]-dy_j[1:,:]*dx_i[:,0:-1])
   areal=xp.dstack([triangle_1+triangle_2]*(nk+1))
   vol=areal[:,:,1:]*dz3d[0:-1,0:-1,:]

   # ---- boundary coefficients (without viscosity)
   as_bound=areas[:,0,:]**2/(0.5*vol[:,0,:])
   an_bound=areas[:,-1,:]**2/(0.5*vol[:,-1,:])

   if cyclic_x:
      aw_bound=areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
   else:
      aw_bound=areaw[0,:,:]**2/(0.5*vol[0,:,:])

   # east coefficient: per the original note this one is never used
   ae_bound=areaw[-1,:,:]**2/(0.5*vol[-1,:,:])

   if cyclic_z:
      al_bound=areal[:,:,0]**2/(0.5*(vol[:,:,0]+vol[:,:,-1]))
   else:
      al_bound=areal[:,:,0]**2/(0.5*vol[:,:,0])

   ah_bound=areal[:,:,-1]**2/(0.5*vol[:,:,-1])

   return areaw,areawx,areawy,areas,areasx,areasy,areal,vol,fx,fy,fz,aw_bound,ae_bound,as_bound,an_bound,al_bound,ah_bound,dist3d

def _print_bc_pair(var, sides, value_spec=''):
   """Print the b.c. types of ``var`` at the given sides, then the values at [0,0] for Dirichlet ('d') sides.

   The module-level names ``<var>_bc_<side>_type`` and ``<var>_bc_<side>``
   are looked up by name; a value is only read when its type is 'd'.
   ``value_spec`` is the format specification applied to the value.
   """
   names = globals()
   for side in sides:
      type_name = f"{var}_bc_{side}_type"
      label = type_name + ': '
      print(f"{' ':<5}{label:<29} {names[type_name]}")
   for side in sides:
      if names[f"{var}_bc_{side}_type"] == 'd':
         value_name = f"{var}_bc_{side}"
         label = value_name + '[0,0]: '
         print(f"{' ':<5}{label:<29} {format(names[value_name][0,0], value_spec)}")


def _print_bc_section(var, title, value_spec=''):
   """Print all boundary conditions of ``var``; cyclic directions are skipped.

   ``value_spec`` is applied to the west/east/south/north values only.
   """
   print('------boundary conditions for ' + title)
   if not cyclic_x:
      _print_bc_pair(var, ('west', 'east'), value_spec)
   _print_bc_pair(var, ('south', 'north'), value_spec)
   if not cyclic_z:
      _print_bc_pair(var, ('low', 'high'))


def print_indata():
   """Print the input data of the case (module-level settings), section by section."""

   print('////////////////// Start of input data ////////////////// \n\n\n')

   print('\n\n########### section 0 section 0 choice of CPU or GPU ##########')
   print(f"{'GPU: ':<29} {gpu}")

   print('\n\n########### section 1 choice of differencing scheme ###########')
   print(f"{'scheme: ':<29}   {scheme}")
   print(f"{'scheme_turb: ':<29}   {scheme_turb}")
   print(f"{'acrank: ':<29}   {acrank}")
   print(f"{'acrank_conv: ':<29}   {acrank_conv}")
   print(f"{'acrank_conv_keps: ':<29}   {acrank_conv_keps}")
   print(f"{'acrank_conv_kom: ':<29}   {acrank_conv_kom}")
   if scheme == 'm':
      print(f"{'blend: ':<29}   {blend}")


   print('\n\n########### section 2 turbulence models ###########')

   print(f"{'cmu: ':<29} {cmu}")
   print(f"{'pans: ':<29} {pans}")
   print(f"{'keps: ':<29} {keps}")
   print(f"{'kom_des: ':<29} {kom_des}")
   print(f"{'keps_des: ':<29} {keps_des}")
   print(f"{'k_eq_les: ':<29} {k_eq_les}")
   if jl0 < 0:
      print(f"{'jl0: ':<29} {jl0}")
   print(f"{'kom: ':<29} {kom}")
   print(f"{'sst: ':<29} {sst}")
   print(f"{'smag: ':<29} {smag}")
   print(f"{'wale: ':<29} {wale}")
   if sst:
      print(f"{'prand_k_sst_1: ':<29} {prand_k_sst_1}")
      print(f"{'prand_k_sst_2: ':<29} {prand_k_sst_2}")
      print(f"{'prand_omega_sst_1: ':<29} {prand_omega_sst_1}")
      print(f"{'prand_omega_sst_2: ':<29} {prand_omega_sst_2}")
      print(f"{'cdes: ':<29} {cdes}")
   if k_eq_les:
      print(f"{'c_eps: ':<29} {c_eps}")
   if pans:
      print(f"{'fkmin_limit: ':<29} {fkmin_limit}")
   if sst :
      print(f"{'c_omega_1_sst_1: ':<29} {c_omega_1_sst_1:.3f}")
      print(f"{'c_omega_1_sst_2: ':<29} {c_omega_1_sst_2:.3f}")
      print(f"{'c_omega_2_sst_1: ':<29} {c_omega_2_sst_1}")
      print(f"{'c_omega_2_sst_2: ':<29} {c_omega_2_sst_2}")
   if keps or pans or keps_des:
      print(f"{'c_eps_1: ':<29} {c_eps_1}")
      print(f"{'c_eps_2: ':<29} {c_eps_2}")
      print(f"{'prand_k: ':<29} {prand_k}")
      print(f"{'prand_eps: ':<29} {prand_eps}")
   if kom or kom_des:
      print(f"{'c_omega_1: ':<29} {c_omega_1:.3f}")
      print(f"{'c_omega_2: ':<29} {c_omega_2}")
      print(f"{'prand_k: ':<29} {prand_k}")
      print(f"{'prand_omega: ':<29} {prand_omega}")

   if keps_des or kom_des:
      print(f"{'cdes: ':<29} {cdes}")


   print('\n\n########### section 3 restart/save ###########')
   print(f"{'restart: ':<29} {restart}")
   print(f"{'save: ':<29} {save}")


   print('\n\n########### section 4 fluid properties ###########')
   print(f"{'viscos: ':<29} {viscos:.2e}")

   print('\n\n########### section 5 relaxation factors ###########')
   print(f"{'urfvis: ':<29} {urfvis}")


   print('\n\n########### section 6 number of iteration and convergence criterira ###########')
   print(f"{'sormax: ':<29} {sormax}")
   print(f"{'norm_order: ':<29} {norm_order}")
   print(f"{'maxit: ':<29} {maxit}")
   print(f"{'min_iter: ':<29} {min_iter}")
   print(f"{'solver_vel: ':<29} {solver_vel}")
   print(f"{'solver_p: ':<29} {solver_p}")
   print(f"{'coeff_v: ':<29} {coeff_v}")
   print(f"{'coeff_w: ':<29} {coeff_w}")
   print(f"{'embedded: ':<29} {embedded}")
   if embedded:
      print(f"{'x_embed: ':<29} {x_embed}")
   print(f"{'solver_turb: ':<29} {solver_turb}")
   print(f"{'amg_relax: ':<29} {amg_relax}")
   print(f"{'amg_cycle: ':<29} {amg_cycle}")
   if solver_vel == 'pyamg' or solver_turb == 'pyamg':
      if amg_relax_phi != 'default':
         print(f"{'amg_relax_phi: ':<29} {amg_relax_phi}")
         print(f"{'amg_cycle_phi: ':<29} {amg_cycle_phi}")
   print(f"{'nsweep_vel: ':<29} {nsweep_vel}")
   print(f"{'nsweep_keps: ':<29} {nsweep_keps}")
   print(f"{'nsweep_kom: ':<29} {nsweep_kom}")
   if gpu:
      print(f"{'convergence_limit_gpu: ':<29} {convergence_limit_gpu}")
   else:
      print(f"{'convergence_limit_u: ':<29} {convergence_limit_u}")
      print(f"{'convergence_limit_v: ':<29} {convergence_limit_v}")
      print(f"{'convergence_limit_w: ':<29} {convergence_limit_w}")
      print(f"{'convergence_limit_p: ':<29} {convergence_limit_p}")
      if keps or pans or sst or kom or kom_des or keps_des:
         print(f"{'convergence_limit_k: ':<29} {convergence_limit_k}")
      if keps or pans or keps_des:
         print(f"{'convergence_limit_eps: ':<29} {convergence_limit_eps}")
      if sst or kom or kom_des:
         print(f"{'convergence_limit_om: ':<29} {convergence_limit_om}")

   print('\n\n########### section 7 all variables are printed during the iteration at node ###########')
   print(f"{'imon: ':<29} {imon}")
   print(f"{'jmon: ':<29} {jmon}")
   print(f"{'kmon: ':<29} {kmon}")


   print('\n\n########### section 8 time-averaging ###########')
   print(f"{'ntstep: ':<29} {ntstep}")
   print(f"{'dt[0]: ':<29} {dt[0]:.2e}")
   print(f"{'itstep_start: ':<29} {itstep_start}")
   print(f"{'itstep_save: ':<29} {itstep_save}")
   print(f"{'itstep_stats: ':<29} {itstep_stats}")


   print('\n\n########### section 9 residual scaling parameters ###########')
   print(f"{'resnorm_p: ':<29} {resnorm_p:.1f}")
   print(f"{'resnorm_vel: ':<29} {resnorm_vel:.1f}")


   print('\n\n########### Section 10 grid and boundary conditions ###########')
   print(f"{'ni: ':<29} {ni}")
   print(f"{'nj: ':<29} {nj}")
   print(f"{'nk: ':<29} {nk}")
   if xp.isscalar(dz):
      print(f"{'dz: ':<29} {dz:.2e}")
   print('\n')
   print(f"{'cyclic_x: ':<29} {cyclic_x}")
   print(f"{'cyclic_z: ':<29} {cyclic_z}")
   print('\n')
   if not cyclic_x:
      print(f"{'L_t_synt: ':<29} {L_t_synt}")
      print(f"{'nmodes_synt: ':<29} {nmodes_synt}")
      print(f"{'dmin_synt: ':<29} {dmin_synt}")
      print(f"{'jmirror_synt: ':<29} {jmirror_synt}")
      print('\n')

   # blockage is active when any of the four block limits is non-zero
   if i_block_start !=0 or i_block_end !=0 or j_block_start !=0 or j_block_end !=0:
      print('------blockage')
      print(f"{' ':<5}{'i_block_start: ':<29} {i_block_start}")
      print(f"{' ':<5}{'i_block_end: ':<29} {i_block_end}")
      print(f"{' ':<5}{'j_block_start: ':<29} {j_block_start}")
      print(f"{' ':<5}{'j_block_end: ':<29} {j_block_end}")
      print('\n')

   # velocities and pressure are always printed
   _print_bc_section('u', 'u')
   _print_bc_section('v', 'v')
   _print_bc_section('w', 'w')
   _print_bc_section('p', 'p')

   # turbulence quantities only for the models that solve for them
   if sst or kom or kom_des or keps or pans or keps_des or k_eq_les:
      _print_bc_section('k', 'k')

   if keps or pans or keps_des:
      _print_bc_section('eps', 'eps')

   if sst or kom or kom_des:
      # omega values are printed with one decimal (west/east/south/north)
      _print_bc_section('om', 'omega', '.1f')


   print('\n\n\n ////////////////// End of input data //////////////////\n\n\n')

   return

def compute_face_phi(phi3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_low,phi_bc_high,\
    phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_low_type,phi_bc_high_type,variable):
   """Interpolate the cell-centre field ``phi3d`` to the west, south and low faces.

   Interior faces use the weights ``fx, fy, fz``. Boundary faces are first
   set to the given boundary value and then overridden depending on the type:
   'n' (Neumann) copies the adjacent cell value, '2' (south/north only)
   extrapolates linearly (d2phi/dy2=0). With ``cyclic_x``/``cyclic_z`` the
   boundary faces in that direction are the average of the first and last
   cells, which takes precedence over the type.

   When blockage is used the faces are finally passed through
   ``modify_face`` together with ``variable``.

   Returns
   -------
   phi3d_face_w : (ni+1,nj,nk), phi3d_face_s : (ni,nj+1,nk), phi3d_face_l : (ni,nj,nk+1)
   """
   if gpu:
      import cupy as xp
   else:
      import numpy as xp

   phi3d_face_w=xp.empty((ni+1,nj,nk))
   phi3d_face_s=xp.empty((ni,nj+1,nk))
   phi3d_face_l=xp.empty((ni,nj,nk+1))
   # interior faces; index 0 is overwritten by the boundary treatment below
   phi3d_face_w[0:-1,:,:]=fx*phi3d+(1-fx)*xp.roll(phi3d,1,axis=0)
   phi3d_face_s[:,0:-1,:]=fy*phi3d+(1-fy)*xp.roll(phi3d,1,axis=1)
   phi3d_face_l[:,:,0:-1]=fz*phi3d+(1-fz)*xp.roll(phi3d,1,axis=2)


# west boundary 
   phi3d_face_w[0,:,:]=phi_bc_west
   if phi_bc_west_type == 'n': 
# neumann
      phi3d_face_w[0,:,:]=phi3d[0,:,:]
   if cyclic_x:
      phi3d_face_w[0,:,:]=0.5*(phi3d[0,:,:]+phi3d[-1,:,:])

# east boundary 
   phi3d_face_w[-1,:,:]=phi_bc_east
   if phi_bc_east_type == 'n': 
# neumann
      phi3d_face_w[-1,:,:]=phi3d[-1,:,:]
   if cyclic_x:
      phi3d_face_w[-1,:,:]=0.5*(phi3d[0,:,:]+phi3d[-1,:,:])

# south boundary 
   phi3d_face_s[:,0,:]=phi_bc_south
   if phi_bc_south_type == 'n': 
# neumann
      phi3d_face_s[:,0,:]=phi3d[:,0,:]
# d2phidy2=0
   if phi_bc_south_type == '2': 
      phi3d_face_s[:,0,:]=1.5*phi3d[:,0,:]-0.5*phi3d[:,1,:]

# north boundary 
   phi3d_face_s[:,-1,:]=phi_bc_north
   if phi_bc_north_type == 'n': 
# neumann
      phi3d_face_s[:,-1,:]=phi3d[:,-1,:]
   if phi_bc_north_type == '2': 
# d2phidy2=0
      phi3d_face_s[:,-1,:]=1.5*phi3d[:,-1,:]-0.5*phi3d[:,-2,:]

# low boundary 
   phi3d_face_l[:,:,0]=phi_bc_low
# high boundary 
   phi3d_face_l[:,:,-1]=phi_bc_high
   if phi_bc_low_type == 'n': 
# neumann
# low boundary 
      phi3d_face_l[:,:,0]= phi3d[:,:,0]
   if phi_bc_high_type == 'n': 
# high boundary 
      phi3d_face_l[:,:,-1]= phi3d[:,:,-1]
   if cyclic_z:
# low boundary 
      phi3d_face_l[:,:,0]=  0.5*(phi3d[:,:,-1]+phi3d[:,:,0])
# high boundary 
      phi3d_face_l[:,:,-1]= 0.5*(phi3d[:,:,-1]+phi3d[:,:,0])

# this is needed only when blockage is used, i.e. when any of the four
# block limits is non-zero
   if i_block_start !=0 or i_block_end !=0 or j_block_start !=0 or j_block_end !=0:
      phi3d_face_w,phi3d_face_s,phi3d_face_l = modify_face(phi3d_face_w,phi3d_face_s,phi3d_face_l,variable)
   
   return phi3d_face_w,phi3d_face_s,phi3d_face_l

def dphidx(phi_face_w,phi_face_s):
   """Return the cell-centred x-derivative built from face values.

   Each face value is weighted with the x-component of its face area
   (module arrays ``areawx`` for the faces along axis 0 and ``areasx`` for
   the faces along axis 1). For every cell the lower-index face
   contribution minus the upper-index face contribution is summed over
   both directions and divided by the module-level cell volume ``vol``.

   phi_face_w: face values along axis 0 (one more entry than cells on that axis).
   phi_face_s: face values along axis 1 (one more entry than cells on that axis).
   """
   # area-weighted face values for all faces at once
   flux_w=phi_face_w*areawx
   flux_s=phi_face_s*areasx

   # lower-index face minus upper-index face, first axis 0 then axis 1
   return (flux_w[0:-1,:,:]-flux_w[1:,:,:]+flux_s[:,0:-1,:]-flux_s[:,1:,:])/vol

def dphidy(phi_face_w,phi_face_s):
   """Return the cell-centred y-derivative built from face values.

   Same construction as the x-derivative, but the face values are weighted
   with the y-components of the face areas (module arrays ``areawy`` for the
   faces along axis 0 and ``areasy`` for the faces along axis 1). The
   per-cell sum is divided by the module-level cell volume ``vol``.

   phi_face_w: face values along axis 0 (one more entry than cells on that axis).
   phi_face_s: face values along axis 1 (one more entry than cells on that axis).
   """
   # area-weighted face values for all faces at once
   flux_w=phi_face_w*areawy
   flux_s=phi_face_s*areasy

   # lower-index face minus upper-index face, first axis 0 then axis 1
   return (flux_w[0:-1,:,:]-flux_w[1:,:,:]+flux_s[:,0:-1,:]-flux_s[:,1:,:])/vol

def dphidz(phi_face_l):
   """Return the cell-centred z-derivative built from face values.

   The face values along axis 2 are weighted with the module array
   ``areal``; the upper-index face minus the lower-index face of every cell
   is divided by the module-level cell volume ``vol``.

   phi_face_l: face values along axis 2 (one more entry than cells on that axis).
   """
   flux_l=phi_face_l*areal
   return (flux_l[:,:,1:]-flux_l[:,:,0:-1])/vol

def coeff_m(convw,convs,convl,vis3d,u3d,v3d,w3d):
   """Build momentum coefficients and explicit sources for the 'muscle' scheme.

   The implicit neighbour coefficients are first-order upwind plus diffusion.
   The higher-order part is returned as explicit sources computed by
   ``muscle_source`` and scaled with ``1-blend``. When the module-level
   ``blend`` is positive, a deferred-correction term (central-differencing
   coefficient minus upwind coefficient, weighted in time with
   ``acrank_conv``) is added to the sources. If in addition ``embedded`` is
   set, the cells with index >= i1 along axis 0 (i1 = index where
   ``xp2d[:,0]`` is closest to ``x_embed``) get central-differencing
   coefficients and zero explicit sources.

   Parameters
   ----------
   convw, convs, convl : face convection fluxes; they are sliced here like
      (ni+1,nj,nk), (ni,nj+1,nk) and (ni,nj,nk+1) arrays.
   vis3d : total viscosity; ``viscos`` is subtracted to obtain the part that
      is interpolated to the faces.
   u3d, v3d, w3d : velocity components used for the explicit sources.

   Returns
   -------
   aw3d, ae3d, as3d, an3d, al3d, ah3d : neighbour coefficients.
   apo3d : ``vol/dt[itstep]``.
   su3d, su3d_v, su3d_w : explicit sources for the u, v and w equations.

   Besides its arguments the function reads module-level state (grid sizes,
   interpolation factors ``fx/fy/fz``, areas, ``vol``, ``viscos``, ``dt``,
   ``itstep``, ``iter``, ``cyclic_x``, ``cyclic_z``, ``blend``, ``embedded``,
   ``acrank_conv`` and the old-time velocities ``u3d_old/v3d_old/w3d_old``).
   """

   if itstep == 0 and iter == 0:
      print('muscle scheme')

   # face viscosities: interpolate the turbulent part, then add viscos back
   visw=xp.zeros((ni+1,nj,nk))
   viss=xp.zeros((ni,nj+1,nk))
   visl=xp.zeros((ni,nj,nk+1))
   vis_turb=(vis3d-viscos)

   visw[0:-1,:,:]=fx*vis_turb+(1-fx)*xp.roll(vis_turb,1,axis=0)+viscos
   viss[:,0:-1,:]=fy*vis_turb+(1-fy)*xp.roll(vis_turb,1,axis=1)+viscos
   visl[:,:,0:-1]=fz*vis_turb+(1-fz)*xp.roll(vis_turb,1,axis=2)+viscos

   # face volumes; index 0 keeps the small placeholder value 1e-10
   volw=xp.ones((ni+1,nj,nk))*1e-10
   vols=xp.ones((ni,nj+1,nk))*1e-10
   voll=xp.ones((ni,nj,nk+1))*1e-10
   volw[1:,:,:]=0.5*xp.roll(vol,-1,axis=0)+0.5*vol
   diffw=visw[0:-1,:,:]*areaw[0:-1,:,:]**2/volw[0:-1,:,:]
   vols[:,1:,:]=0.5*xp.roll(vol,-1,axis=1)+0.5*vol
   diffs=viss[:,0:-1,:]*areas[:,0:-1,:]**2/vols[:,0:-1,:]
   voll[:,:,1:]=0.5*xp.roll(vol,-1,axis=2)+0.5*vol
   diffl=visl[:,:,0:-1]*areal[:,:,0:-1]**2/voll[:,:,0:-1]

   # cyclic directions: the first face couples the first and last cell, so
   # both the face viscosity and the diffusion coefficient are recomputed
   if cyclic_x:
      visw[0,:,:]=0.5*(vis_turb[0,:,:]+vis_turb[-1,:,:])+viscos
      diffw[0,:,:]=visw[0,:,:]*areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
   if cyclic_z:
      visl[:,:,0]=0.5*(vis_turb[:,:,0]+vis_turb[:,:,-1])+viscos
      # without this line diffl[:,:,0] keeps the value divided by the 1e-10
      # placeholder volume (the scalar routine coeff() applies the same fix)
      diffl[:,:,0]=visl[:,:,0]*areal[:,:,0]**2/(0.5*(vol[:,:,0]+vol[:,:,-1]))

   # flow-direction switches (1 or 0, resp. 0 or -1) for each face
   #        cep=0.5+sign(0.5,conve(i,j,k))
   #        cwp=0.5+sign(0.5,conve(i-1,j,k))
   #        cem=sign(0.5,conve(i,j,k))-0.5
   cwp=0.5+0.5*xp.sign(convw[0:-1,:,:])
   cep=0.5+0.5*xp.sign(convw[1:,:,:])
   cwm=0.5*xp.sign(convw[0:-1,:,:])-0.5
   cem=0.5*xp.sign(convw[1:,:,:])-0.5

   csp=0.5+0.5*xp.sign(convs[:,0:-1,:])
   cnp=0.5+0.5*xp.sign(convs[:,1:,:])
   csm=0.5*xp.sign(convs[:,0:-1,:])-0.5
   cnm=0.5*xp.sign(convs[:,1:,:])-0.5

   clp=0.5+0.5*xp.sign(convl[:,:,0:-1])
   chp=0.5+0.5*xp.sign(convl[:,:,1:])
   clm=0.5*xp.sign(convl[:,:,0:-1])-0.5
   chm=0.5*xp.sign(convl[:,:,1:])-0.5

   # fix boundaries: no contribution
   if not cyclic_x:
      cem[-1,:,:]=0
      cwp[0,:,:]=0
   cnm[:,-1,:]=0
   csp[:,0,:]=0
   if not cyclic_z:
      chm[:,:,-1]=0
      clp[:,:,0]=0

   # first-order upwind in left-hand side
   aw3d=xp.maximum(convw[0:-1,:,:],0)+diffw
   ae3d=xp.maximum(-convw[1:,:,:],-0)+xp.roll(diffw,-1,axis=0)
   as3d=xp.maximum(convs[:,0:-1,:],0)+diffs
   an3d=xp.maximum(-convs[:,1:,:],0)+xp.roll(diffs,-1,axis=1)
   al3d=xp.maximum(convl[:,:,0:-1],0)+diffl
   ah3d=xp.maximum(-convl[:,:,1:],0)+xp.roll(diffl,-1,axis=2)

   # explicit higher-order sources, one per velocity component
   su3d  =muscle_source(u3d,cep,cem,cwp,cwm,cnp,cnm,csp,csm,chp,chm,clp,clm)
   su3d_v=muscle_source(v3d,cep,cem,cwp,cwm,cnp,cnm,csp,csm,chp,chm,clp,clm)
   su3d_w=muscle_source(w3d,cep,cem,cwp,cwm,cnp,cnm,csp,csm,chp,chm,clp,clm)

   su3d=(1-blend)*su3d
   su3d_v=(1-blend)*su3d_v
   su3d_w=(1-blend)*su3d_w

   if blend > 0:

      # central differencing
      aw3d_c=diffw+(1-fx)*convw[0:-1,:,:]
      ae3d_c=xp.roll(diffw,-1,axis=0)-xp.roll(fx,-1,axis=0)*convw[1:,:,:]

      as3d_c=diffs+(1-fy)*convs[:,0:-1,:]
      an3d_c=xp.roll(diffs,-1,axis=1)-xp.roll(fy,-1,axis=1)*convs[:,1:,:]

      al3d_c=diffl+(1-fz)*convl[:,:,0:-1]
      ah3d_c=xp.roll(diffl,-1,axis=2)-xp.roll(fz,-1,axis=2)*convl[:,:,1:]

      # boundary faces must not contribute to the deferred correction below
      if not cyclic_x:
         aw3d_c[0,:,:]=0
         ae3d_c[-1,:,:]=0
         aw3d[0,:,:]=0
         ae3d[-1,:,:]=0

      as3d_c[:,0,:]=0
      an3d_c[:,-1,:]=0
      as3d[:,0,:]=0
      an3d[:,-1,:]=0

      if not cyclic_z:
         al3d_c[:,:,0]=0
         ah3d_c[:,:,-1]=0
         al3d[:,:,0]=0
         ah3d[:,:,-1]=0

      # blend of CDS and muscle. blend=1 means fully CDS, deferred
      #        su_u=c*(
      #    . (ae_c-ae(i,j,k))*(acr*(phi(i+1,j,k,n)-phi(i,j,k,n))
      #    .            +(1.-acr)*(phio(i+1,j,k,n)-phio(i,j,k,n)))
      #    .+(aw_c-aw(i,j,k))*(acr*(phi(i-1,j,k,n)-phi(i,j,k,n))
      #    .            +(1.-acr)*(phio(i-1,j,k,n)-phio(i,j,k,n)))

      a=acrank_conv

      # u equation
      su3d=su3d+blend*\
                 ((ae3d_c-ae3d)*(a*(xp.roll(u3d,-1,axis=0)-u3d)+(1-a)*(xp.roll(u3d_old,-1,axis=0)-u3d_old))
                 +(aw3d_c-aw3d)*(a*(xp.roll(u3d, 1,axis=0)-u3d)+(1-a)*(xp.roll(u3d_old, 1,axis=0)-u3d_old))
                 +(an3d_c-an3d)*(a*(xp.roll(u3d,-1,axis=1)-u3d)+(1-a)*(xp.roll(u3d_old,-1,axis=1)-u3d_old))
                 +(as3d_c-as3d)*(a*(xp.roll(u3d, 1,axis=1)-u3d)+(1-a)*(xp.roll(u3d_old, 1,axis=1)-u3d_old))
                 +(ah3d_c-ah3d)*(a*(xp.roll(u3d,-1,axis=2)-u3d)+(1-a)*(xp.roll(u3d_old,-1,axis=2)-u3d_old))
                 +(al3d_c-al3d)*(a*(xp.roll(u3d, 1,axis=2)-u3d)+(1-a)*(xp.roll(u3d_old, 1,axis=2)-u3d_old)))

      # v equation
      su3d_v=su3d_v+blend* \
                 ((ae3d_c-ae3d)*(a*(xp.roll(v3d,-1,axis=0)-v3d)+(1-a)*(xp.roll(v3d_old,-1,axis=0)-v3d_old))
                 +(aw3d_c-aw3d)*(a*(xp.roll(v3d, 1,axis=0)-v3d)+(1-a)*(xp.roll(v3d_old, 1,axis=0)-v3d_old))
                 +(an3d_c-an3d)*(a*(xp.roll(v3d,-1,axis=1)-v3d)+(1-a)*(xp.roll(v3d_old,-1,axis=1)-v3d_old))
                 +(as3d_c-as3d)*(a*(xp.roll(v3d, 1,axis=1)-v3d)+(1-a)*(xp.roll(v3d_old, 1,axis=1)-v3d_old))
                 +(ah3d_c-ah3d)*(a*(xp.roll(v3d,-1,axis=2)-v3d)+(1-a)*(xp.roll(v3d_old,-1,axis=2)-v3d_old))
                 +(al3d_c-al3d)*(a*(xp.roll(v3d, 1,axis=2)-v3d)+(1-a)*(xp.roll(v3d_old, 1,axis=2)-v3d_old)))

      # w equation
      su3d_w=su3d_w+blend*\
                 ((ae3d_c-ae3d)*(a*(xp.roll(w3d,-1,axis=0)-w3d)+(1-a)*(xp.roll(w3d_old,-1,axis=0)-w3d_old))
                 +(aw3d_c-aw3d)*(a*(xp.roll(w3d, 1,axis=0)-w3d)+(1-a)*(xp.roll(w3d_old, 1,axis=0)-w3d_old))
                 +(an3d_c-an3d)*(a*(xp.roll(w3d,-1,axis=1)-w3d)+(1-a)*(xp.roll(w3d_old,-1,axis=1)-w3d_old))
                 +(as3d_c-as3d)*(a*(xp.roll(w3d, 1,axis=1)-w3d)+(1-a)*(xp.roll(w3d_old, 1,axis=1)-w3d_old))
                 +(ah3d_c-ah3d)*(a*(xp.roll(w3d,-1,axis=2)-w3d)+(1-a)*(xp.roll(w3d_old,-1,axis=2)-w3d_old))
                 +(al3d_c-al3d)*(a*(xp.roll(w3d, 1,axis=2)-w3d)+(1-a)*(xp.roll(w3d_old, 1,axis=2)-w3d_old)))

      # embedded region: cells with i >= i1 use central-differencing
      # coefficients in the left-hand side and no explicit source; cells
      # with i < i1 keep the first-order upwind coefficients set above
      if embedded:
         i1 = (xp.abs(x_embed-xp2d[:,0])).argmin()  # find index which closest
         if itstep == 0 and iter == 0:
            print('embedded: i1',i1)

         aw3d[i1:,:,:]=diffw[i1:,:,:]+(1-fx[i1:,:,:])*convw[i1:-1,:,:]
         ae3d[i1:,:,:]=xp.roll(diffw[i1:,:,:],-1,axis=0)-xp.roll(fx[i1:,:,:],-1,axis=0)*convw[i1+1:,:,:]

         as3d[i1:,:,:]=diffs[i1:,:,:]+(1-fy[i1:,:,:])*convs[i1:,0:-1,:]
         an3d[i1:,:,:]=xp.roll(diffs[i1:,:,:],-1,axis=1)-xp.roll(fy[i1:,:,:],-1,axis=1)*convs[i1:,1:,:]

         al3d[i1:,:,:]=diffl[i1:,:,:]+(1-fz[i1:,:,:])*convl[i1:,:,0:-1]
         ah3d[i1:,:,:]=xp.roll(diffl[i1:,:,:],-1,axis=2)-xp.roll(fz[i1:,:,:],-1,axis=2)*convl[i1:,:,1:]

         su3d[i1:,:,:]=0
         su3d_v[i1:,:,:]=0
         su3d_w[i1:,:,:]=0

         if itstep == 0 and iter == 0:
            print('emnedded scheme; i_embed=',i1)

   # unsteady-term coefficient
   apo3d=vol/dt[itstep]

   # no neighbour coupling across non-cyclic boundaries
   if not cyclic_x:
      aw3d[0,:,:]=0
      ae3d[-1,:,:]=0
   as3d[:,0,:]=0
   an3d[:,-1,:]=0
   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,su3d_v,su3d_w

def minmo(a,b):
   """Element-wise minmod-type limiter of ``a`` against ``b``.

   Returns ``a`` clipped in magnitude by ``b`` where both have the same
   sign, and zero where their signs differ or ``a`` is zero. Array version
   of the scalar formula

      asign=sign(1.,a)
      rminmo=asign*max(0.,min(abs(a),b*asign))
   """
   sgn_a=xp.sign(a)
   # magnitude limited by b measured in the direction of a
   limited=xp.minimum(abs(a),b*sgn_a)
   return sgn_a*xp.maximum(0,limited)

def muscle_source(phi3d,cep,cem,cwp,cwm,cnp,cnm,csp,csm,chp,chm,clp,clm):
    """Return the explicit limited higher-order convection source for ``phi3d``.

    For each of the three index directions the limited slopes (``minmo``) of
    ``phi3d`` are combined with the flow-direction switches and the face
    convection fluxes:

          su(i,j,k)=su(i,j,k)-0.5*
                   (conve(i,j,k)*cep*rminmo(phie-phip,phip-phiw)
                   -conve(i,j,k)*cem*rminmo(phie-phip,phiee-phie)
                   -conve(i-1,j,k)*cwp*rminmo(phip-phiw,phiw-phiww)
                   +conve(i-1,j,k)*cwm*rminmo(phip-phiw,phie-phip)

    phi3d: cell-centred variable.
    cep,cem,cwp,cwm: switches for the upper (e) and lower (w) face along axis 0.
    cnp,cnm,csp,csm: the same for axis 1.
    chp,chm,clp,clm: the same for axis 2.

    NOTE: the convection fluxes ``convw``, ``convs`` and ``convl`` are not
    parameters; they are read from module-level globals. Neighbour values are
    obtained with ``roll`` and therefore wrap around at the array ends.
    """

    def _direction_term(axis,conv_hi,conv_lo,c_hi_p,c_hi_m,c_lo_p,c_lo_m):
        # neighbours one and two cells away along `axis`
        ahead=xp.roll(phi3d,-1,axis=axis)
        behind=xp.roll(phi3d,1,axis=axis)
        ahead2=xp.roll(phi3d,-2,axis=axis)
        behind2=xp.roll(phi3d,2,axis=axis)

        # limited corrections at the upper-index and lower-index face
        hi_face=c_hi_p*minmo(ahead-phi3d,phi3d-behind)-c_hi_m*minmo(ahead-phi3d,ahead2-ahead)
        lo_face=c_lo_p*minmo(phi3d-behind,behind-behind2)-c_lo_m*minmo(phi3d-behind,ahead-phi3d)
        return -0.5*(conv_hi*hi_face-conv_lo*lo_face)

    ss=_direction_term(0,convw[1:,:,:],convw[0:-1,:,:],cep,cem,cwp,cwm)
    ss=ss+_direction_term(1,convs[:,1:,:],convs[:,0:-1,:],cnp,cnm,csp,csm)
    ss=ss+_direction_term(2,convl[:,:,1:],convl[:,:,0:-1],chp,chm,clp,clm)

    return ss

def minmo(a,b):
   """Limit ``a`` by ``b`` element-wise (minmod-type limiter).

   Scalar reference formula:
      asign=sign(1.,a)
      rminmo=asign*max(0.,min(abs(a),b*asign))

   The result has the sign of ``a``, a magnitude no larger than ``abs(a)``
   or ``b`` taken along the sign of ``a``, and is zero when that bound is
   negative.

   NOTE: this definition repeats an identical ``minmo`` defined earlier in
   the file; being later, this is the one bound to the name at run time.
   """
   direction=xp.sign(a)
   bound=b*direction
   magnitude=xp.maximum(0,xp.minimum(abs(a),bound))
   return direction*magnitude

def coeff(convw,convs,convl,vis3d,prand_1,prand_2,f1_sst,scheme_local):
   """Build the neighbour coefficients of a transport equation.

   Parameters
   ----------
   convw, convs, convl : face convection fluxes; they are sliced here like
      (ni+1,nj,nk), (ni,nj+1,nk) and (ni,nj,nk+1) arrays.
   vis3d : total viscosity; ``viscos`` is subtracted before scaling.
   prand_1, prand_2 : Prandtl-type numbers. If they differ they are blended
      with ``f1_sst``. If ``prand_1`` is not positive, the turbulent
      viscosity is additionally divided by ``fk3d**2`` (module global,
      limited from below by ``fkmin_limit``) and ``abs(prand)`` is used.
   f1_sst : blending function, only used when ``prand_1 != prand_2``.
   scheme_local : 'h' (hybrid), 'u' (first-order upwind) or 'c' (central).

   Returns
   -------
   aw3d, ae3d, as3d, an3d, al3d, ah3d : neighbour coefficients.
   apo3d : ``vol/dt[itstep]``.
   su3d, sp3d : NOTE: these are not computed here; the module-level globals
      of that name are returned unchanged.

   Raises
   ------
   ValueError : if ``scheme_local`` is not one of 'h', 'u', 'c' (previously
      this surfaced later as an UnboundLocalError on ``aw3d``).
   """

   if scheme_local not in ('h','u','c'):
      raise ValueError("coeff: unknown scheme_local %r (expected 'h', 'u' or 'c')" % (scheme_local,))

   if prand_1 == prand_2:
      prand = xp.ones((ni,nj,nk))*prand_1
   else:
      prand = f1_sst*prand_1+(1-f1_sst)*prand_2

   visw=xp.zeros((ni+1,nj,nk))
   viss=xp.zeros((ni,nj+1,nk))
   visl=xp.zeros((ni,nj,nk+1))
   if prand_1 > 0:
      vis_turb=(vis3d-viscos)/prand
   else:
      fk3d_local=xp.maximum(fk3d,fkmin_limit)  #this limit is used only in the diffusion
      vis_turb=(vis3d-viscos)/xp.abs(prand)/fk3d_local**2

   # face viscosities: interpolate the scaled turbulent part, add viscos back
   visw[0:-1,:,:]=fx*vis_turb+(1-fx)*xp.roll(vis_turb,1,axis=0)+viscos
   viss[:,0:-1,:]=fy*vis_turb+(1-fy)*xp.roll(vis_turb,1,axis=1)+viscos
   visl[:,:,0:-1]=fz*vis_turb+(1-fz)*xp.roll(vis_turb,1,axis=2)+viscos

   # face volumes; index 0 keeps the small placeholder value 1e-10
   volw=xp.ones((ni+1,nj,nk))*1e-10
   vols=xp.ones((ni,nj+1,nk))*1e-10
   voll=xp.ones((ni,nj,nk+1))*1e-10
   volw[1:,:,:]=0.5*xp.roll(vol,-1,axis=0)+0.5*vol
   diffw=visw[0:-1,:,:]*areaw[0:-1,:,:]**2/volw[0:-1,:,:]
   vols[:,1:,:]=0.5*xp.roll(vol,-1,axis=1)+0.5*vol
   diffs=viss[:,0:-1,:]*areas[:,0:-1,:]**2/vols[:,0:-1,:]
   voll[:,:,1:]=0.5*xp.roll(vol,-1,axis=2)+0.5*vol
   diffl=visl[:,:,0:-1]*areal[:,:,0:-1]**2/voll[:,:,0:-1]

   # cyclic directions: the first face couples the first and last cell
   if cyclic_x:
      visw[0,:,:]=0.5*(vis_turb[0,:,:]+vis_turb[-1,:,:])+viscos
      diffw[0,:,:]=visw[0,:,:]*areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
   if cyclic_z:
      visl[:,:,0]=0.5*(vis_turb[:,:,0]+vis_turb[:,:,-1])+viscos
      diffl[:,:,0]=visl[:,:,0]*areal[:,:,0]**2/(0.5*(vol[:,:,0]+vol[:,:,-1]))

   if scheme_local == 'h':
      if itstep == 0 and iter == 0:
         print('hybrid scheme, prand=',prand_1)

      # hybrid: larger of the upwind and the central coefficient, never negative
      aw3d=xp.maximum(convw[0:-1,:,:],diffw+(1-fx)*convw[0:-1,:,:])
      aw3d=xp.maximum(aw3d,0.)

      ae3d=xp.maximum(-convw[1:,:,:],xp.roll(diffw,-1,axis=0)-xp.roll(fx,-1,axis=0)*convw[1:,:,:])
      ae3d=xp.maximum(ae3d,0.)

      as3d=xp.maximum(convs[:,0:-1,:],diffs+(1-fy)*convs[:,0:-1,:])
      as3d=xp.maximum(as3d,0.)

      an3d=xp.maximum(-convs[:,1:,:],xp.roll(diffs,-1,axis=1)-xp.roll(fy,-1,axis=1)*convs[:,1:,:])
      an3d=xp.maximum(an3d,0.)

      al3d=xp.maximum(convl[:,:,0:-1],diffl+(1-fz)*convl[:,:,0:-1])
      al3d=xp.maximum(al3d,0.)

      ah3d=xp.maximum(-convl[:,:,1:],xp.roll(diffl,-1,axis=2)-xp.roll(fz,-1,axis=2)*convl[:,:,1:])
      ah3d=xp.maximum(ah3d,0.)

   elif scheme_local == 'u':
      if itstep == 0 and iter == 0:
         print('upwind scheme, prand=',prand_1)

      aw3d=xp.maximum(convw[0:-1,:,:],0)+diffw
      ae3d=xp.maximum(-convw[1:,:,:],-0)+xp.roll(diffw,-1,axis=0)
      as3d=xp.maximum(convs[:,0:-1,:],0)+diffs
      an3d=xp.maximum(-convs[:,1:,:],0)+xp.roll(diffs,-1,axis=1)
      al3d=xp.maximum(convl[:,:,0:-1],0)+diffl
      ah3d=xp.maximum(-convl[:,:,1:],0)+xp.roll(diffl,-1,axis=2)

   else:
      # scheme_local == 'c' (validated above)
      if itstep == 0 and iter == 0:
         print('CDS scheme, prand=',prand_1)
      aw3d=diffw+(1-fx)*convw[0:-1,:,:]
      ae3d=xp.roll(diffw,-1,axis=0)-xp.roll(fx,-1,axis=0)*convw[1:,:,:]

      as3d=diffs+(1-fy)*convs[:,0:-1,:]
      an3d=xp.roll(diffs,-1,axis=1)-xp.roll(fy,-1,axis=1)*convs[:,1:,:]

      al3d=diffl+(1-fz)*convl[:,:,0:-1]
      ah3d=xp.roll(diffl,-1,axis=2)-xp.roll(fz,-1,axis=2)*convl[:,:,1:]

   # unsteady-term coefficient
   apo3d=vol/dt[itstep]

   # no neighbour coupling across non-cyclic boundaries
   if not cyclic_x:
      aw3d[0,:,:]=0
      ae3d[-1,:,:]=0
   as3d[:,0,:]=0
   an3d[:,-1,:]=0
   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d

def bc(su3d,sp3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_low,phi_bc_high,\
       phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_low_type,phi_bc_high_type):
   """Return fresh source arrays holding the Dirichlet boundary contributions.

   NOTE: the incoming ``su3d`` and ``sp3d`` are not used; both are replaced
   by zero arrays of shape (ni,nj,nk) before anything is added.

   For every boundary whose type is 'd', the boundary-adjacent cell layer
   gets ``-viscos*a_bound`` added to ``sp3d`` and
   ``viscos*a_bound*phi_bc`` added to ``su3d``, where ``a_bound`` is the
   module-level array for that side (``as_bound``, ``an_bound``,
   ``aw_bound``, ``ae_bound``, ``al_bound``, ``ah_bound``). The west/east
   sides are skipped when ``cyclic_x`` is set and the low/high sides when
   ``cyclic_z`` is set.

   Returns the tuple ``(su3d, sp3d)``.
   """

   su3d=xp.zeros((ni,nj,nk))
   sp3d=xp.zeros((ni,nj,nk))

   # south and north are never cyclic
   if phi_bc_south_type == 'd':
      sp3d[:,0,:]-=viscos*as_bound
      su3d[:,0,:]+=viscos*as_bound*phi_bc_south

   if phi_bc_north_type == 'd':
      sp3d[:,-1,:]-=viscos*an_bound
      su3d[:,-1,:]+=viscos*an_bound*phi_bc_north

   # west and east only when the x direction is not cyclic
   if not cyclic_x:
      if phi_bc_west_type == 'd':
         sp3d[0,:,:]-=viscos*aw_bound
         su3d[0,:,:]+=viscos*aw_bound*phi_bc_west
      if phi_bc_east_type == 'd':
         sp3d[-1,:,:]-=viscos*ae_bound
         su3d[-1,:,:]+=viscos*ae_bound*phi_bc_east

   # low and high only when the z direction is not cyclic
   if not cyclic_z:
      if phi_bc_low_type == 'd':
         sp3d[:,:,0]-=viscos*al_bound
         su3d[:,:,0]+=viscos*al_bound*phi_bc_low
      if phi_bc_high_type == 'd':
         sp3d[:,:,-1]-=viscos*ah_bound
         su3d[:,:,-1]+=viscos*ah_bound*phi_bc_high

   return su3d,sp3d

def conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l):
   """Compute the face convection fluxes ``convw``, ``convs``, ``convl``.

   The cell velocities are first modified by adding the pressure gradient
   (from the pressure face values) times ``dt[itstep]*acrank``. These
   velocities are interpolated to the faces with ``compute_face_phi`` using
   the velocity boundary conditions held in module globals, and the face
   velocities are combined with the face-area components:

      convw = -u_w*areawx - v_w*areawy
      convs = -u_s*areasx - v_s*areasy
      convl =  w_l*areal

   The fluxes are finally passed through ``modify_conv`` and returned.
   """

   if itstep == 0 and iter == 0:
      print('conv called')

   # time step weighted with the time-discretization factor
   dt_eff=dt[itstep]*acrank
   u_star=u3d+dphidx(p3d_face_w,p3d_face_s)*dt_eff
   v_star=v3d+dphidy(p3d_face_w,p3d_face_s)*dt_eff
   w_star=w3d+dphidz(p3d_face_l)*dt_eff

   # face values; only the faces needed for the fluxes below are kept
   u_bcs=(u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_low,u_bc_high,
          u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_low_type,u_bc_high_type)
   u_w,u_s,_=compute_face_phi(u_star,*u_bcs,'u')

   v_bcs=(v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_low,v_bc_high,
          v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_low_type,v_bc_high_type)
   v_w,v_s,_=compute_face_phi(v_star,*v_bcs,'v')

   w_bcs=(w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_low,w_bc_high,
          w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_low_type,w_bc_high_type)
   _,_,w_l=compute_face_phi(w_star,*w_bcs,'w')

   flux_w=-u_w*areawx-v_w*areawy
   flux_s=-u_s*areasx-v_s*areasy
   flux_l=w_l*areal

   flux_w,flux_s,flux_l=modify_conv(flux_w,flux_s,flux_l)

   return flux_w,flux_s,flux_l

def solve_3d(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,nmax,acrank_conv_local,solver_local,variable):
   """Solve the 7-point discretized equation with a Krylov solver.

   Parameters
   ----------
   phi3d : current field, used as initial guess.
   aw3d..ah3d : neighbour coefficients; they are multiplied by
      ``acrank_conv_local`` before they enter the matrix.
   su3d : right-hand side.
   ap3d : diagonal coefficient (used as is).
   tol_conv : convergence tolerance, passed as ``rtol``. A negative value
      selects an absolute criterion: ``atol = abs(tol_conv) * initial
      residual`` (norm of order ``norm_order``); otherwise ``atol`` is 1e-10.
   nmax : maximum number of solver iterations.
   solver_local : 'cgs', 'cg', 'gmres', 'qmr' or 'lgmres'; the function of
      that name is looked up in the module-level ``linalg``.
   variable : label used in the printed messages only.

   Returns
   -------
   (phi3d, resid) : the solution reshaped to (ni,nj,nk) and the final
   residual ``|A*phi - su|`` (norm of order ``norm_order``).

   Raises
   ------
   ValueError : if ``solver_local`` is not one of the names above
      (previously this ended in a NameError on ``info``).
   """
   first_call = itstep == 0 and iter == 0
   if first_call:
      print('solve_3d called')
      print('nmax,acrank_conv_local',nmax,acrank_conv_local)

   if solver_local not in ('cgs','cg','gmres','qmr','lgmres'):
      raise ValueError("solve_3d: unknown solver_local %r" % (solver_local,))

   start_time_solver = time.time()

   # flattened coefficients (fresh arrays, so the edits below do not touch the inputs)
   aw=aw3d.flatten()*acrank_conv_local
   ae=ae3d.flatten()*acrank_conv_local
   as1=as3d.flatten()*acrank_conv_local
   an=an3d.flatten()*acrank_conv_local
   al=al3d.flatten()*acrank_conv_local
   ah=ah3d.flatten()*acrank_conv_local
   ap=ap3d.flatten()

   m=ni*nj*nk

   # Assemble the diagonals. In the flattened ordering neighbours along the
   # last axis are 1 apart, along axis 1 nk apart and along axis 0 nk*nj apart.
   diagonals=[ap, -ah[:-1], -al[1:]]
   offsets=[0, 1, -1]
   if cyclic_z:
      # move the couplings of the first/last cell along the last axis to
      # separate wrap-around diagonals at +-(nk-1)
      al_cyc=xp.zeros(m)
      al_cyc[0:-1:nk]= al[0:-1:nk]
      al[0:-1:nk]=0
      ah_cyc=xp.zeros(m)
      ah_cyc[nk-1::nk]=ah[nk-1::nk]
      ah[nk-1::nk]=0
      diagonals=[ap, -ah[:-1], -al[1:], -al_cyc, -ah_cyc[nk-1:]]
      offsets=[0, 1, -1, nk-1, -(nk-1)]
   diagonals+=[-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
   offsets+=[nk, -nk, nk*nj, -nk*nj]
   if cyclic_x:
      # wrap-around coupling between the first and last plane along axis 0
      diagonals+=[-aw, -ae[nj*nk*(ni-1):]]
      offsets+=[nj*nk*(ni-1), -nj*nk*(ni-1)]
   A = sparse.diags(diagonals, offsets, format='csr')

   su=su3d.flatten()
   phi=phi3d.flatten()

   resid_orig=xp.linalg.norm(A*phi - su)

   # the solvers return a new array, so this keeps the initial guess
   phi_org=phi

   resid=xp.linalg.norm(A*phi - su,ord=norm_order)
   tol=tol_conv
   abs_tol=1e-10
   if tol_conv < 0:
# use absolute convergence criterium
      abs_tol =abs(tol_conv)*resid
      tol_conv=0

   if first_call:
      if solver_local == 'lgmres':
         print('solver in solve_3d: lgmres,tol,atol',tol,tol_conv)
      else:
         print('solver in solve_3d: '+solver_local)

   # every solver gets the same tolerances (cgs used to be handed tol_conv
   # as atol, which ignored the absolute criterion computed above)
   krylov=getattr(linalg,solver_local)
   phi,info=krylov(A,su,x0=phi, atol=abs_tol, rtol=tol,  maxiter=nmax)

   if info > 0:
      print('warning in module solve_3d: convergence in sparse matrix solver not reached')
   print('solve_3d, into',info)
# compute residual without normalizing with |b|=|su3d|
   resid=xp.linalg.norm(A*phi - su,ord=norm_order)

   delta_phi=xp.max(xp.abs(phi-phi_org))

   phi3d=xp.reshape(phi,(ni,nj,nk))

   print('variable=',variable)
   print(f"{'residual history in solve_3d. Variable '}{variable}{': initial residual: '} {resid_orig:.2e}{'final residual: ':>20}{resid:.2e}\
      {'delta_phi: ':>15}{delta_phi:.2e}")

   print(f"{'time solve_3d: '}{time.time()-start_time_solver:.2e}")

   return phi3d,resid

def solve_pyamg(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,acrank_conv_local):
   """Solve the 7-point discretized equation with pyamg (or a direct solve).

   Parameters
   ----------
   phi3d : current field, used as initial guess.
   aw3d..ah3d : neighbour coefficients; they are multiplied by
      ``acrank_conv_local`` before they enter the matrix.
   su3d : right-hand side.
   ap3d : diagonal coefficient (used as is).
   tol_conv : tolerance handed to the AMG solver.

   The module-level ``amg_relax_phi`` selects the method: 'default' uses
   the Ruge-Stuben solver with its defaults, 'direct' calls
   ``linalg.spsolve`` and any other value is passed as ``accel`` together
   with ``cycle=amg_cycle_phi``. The multigrid hierarchy is only built when
   an AMG path is taken.

   Returns
   -------
   (phi3d, resid) : the solution reshaped to (ni,nj,nk) and the residual
   ``|A*phi - su|`` (norm of order ``norm_order``).
   """

   first_call = itstep == 0 and iter == 0
   if first_call:
      print('solve_pyamg called,tol_conv=',tol_conv,'acrank_conv_local=',acrank_conv_local)
      print('relation method_phi, amg_cycle_phi=',amg_relax_phi,amg_cycle_phi)

   # flattened coefficients (fresh arrays, so the edits below do not touch the inputs)
   aw=aw3d.flatten()*acrank_conv_local
   ae=ae3d.flatten()*acrank_conv_local
   as1=as3d.flatten()*acrank_conv_local
   an=an3d.flatten()*acrank_conv_local
   al=al3d.flatten()*acrank_conv_local
   ah=ah3d.flatten()*acrank_conv_local
   ap=ap3d.flatten()

   m=ni*nj*nk

   if first_call:
      if cyclic_x and cyclic_z:
         print('cyclic_x cyclic_z')
      elif cyclic_x:
         print('cyclic_x and not cyclic_z')
      elif not cyclic_z:
         print('not cyclic_z and not cyclic_x')

   # Assemble the diagonals. In the flattened ordering neighbours along the
   # last axis are 1 apart, along axis 1 nk apart and along axis 0 nk*nj apart.
   diagonals=[ap, -ah[:-1], -al[1:]]
   offsets=[0, 1, -1]
   if cyclic_z:
      # move the couplings of the first/last cell along the last axis to
      # separate wrap-around diagonals at +-(nk-1)
      al_cyc=xp.zeros(m)
      al_cyc[0:-1:nk]= al[0:-1:nk]
      al[0:-1:nk]=0
      ah_cyc=xp.zeros(m)
      ah_cyc[nk-1::nk]=ah[nk-1::nk]
      ah[nk-1::nk]=0
      diagonals=[ap, -ah[:-1], -al[1:], -al_cyc, -ah_cyc[nk-1:]]
      offsets=[0, 1, -1, nk-1, -(nk-1)]
   diagonals+=[-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
   offsets+=[nk, -nk, nk*nj, -nk*nj]
   if cyclic_x:
      # wrap-around coupling between the first and last plane along axis 0
      diagonals+=[-aw, -ae[nj*nk*(ni-1):]]
      offsets+=[nj*nk*(ni-1), -nj*nk*(ni-1)]
   A = sparse.diags(diagonals, offsets, format='csr')

   print('in solve_pyamg')

   phi=phi3d.flatten()
   su=su3d.flatten()
   # the solvers return a new array, so this keeps the initial guess
   phi_org=phi
   res_amg = []
   if amg_relax_phi == 'direct':
      # no multigrid hierarchy is needed for the direct solve
      phi = linalg.spsolve(A,su)
      res_amg=xp.zeros(1)
   else:
      App = pyamg.ruge_stuben_solver(A)                    # construct the multigrid hierarchy
      if amg_relax_phi == 'default':
         phi = App.solve(su, tol=tol_conv, x0=phi, residuals=res_amg)
      else:
         phi = App.solve(su, tol=tol_conv, x0=phi,accel=amg_relax_phi,cycle=amg_cycle_phi, residuals=res_amg)

   print('Residual history in pyAMG', ["%0.4e" % i for i in res_amg])

   delta_phi=xp.max(xp.abs(phi-phi_org))

   print(f"{'Residual history in solve_pyAMG: delta_phi: ':>25}{delta_phi:.2e}")

   resid=xp.linalg.norm(A*phi - su,ord=norm_order)

   phi3d=xp.reshape(phi,(ni,nj,nk))

   return phi3d,resid

def solve_pyamgx(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,acrank_conv_local,variable):
   """Solve the 7-point discretized equation with the AMGX solver ``solverx``.

   Parameters
   ----------
   phi3d : current field, used as initial guess.
   aw3d..ah3d : neighbour coefficients; they are multiplied by
      ``acrank_conv_local`` before they enter the matrix.
   su3d : right-hand side.
   ap3d : diagonal coefficient (used as is).
   tol_conv : only printed here; it is not passed to the solver.
   variable : name of the solved variable. It decides whether the matrix
      and the AMGX objects are (re)built and whether they are released.

   The matrix and the AMGX matrix/vectors are kept in module globals
   (``A``, ``A_x``, ``b_x``, ``x_x``). They are rebuilt for 'p', 'u', 'k',
   'eps', 'om', for 'v' when ``coeff_v`` is set and for 'w' when
   ``coeff_w`` is set; otherwise the objects of the previous call are
   reused. They are released after the call when the following solve is
   expected to rebuild them.
   NOTE(review): the release rule for 'u' and 'v' assumes the callers solve
   u, then v, then w — confirm against the call sites.

   Returns
   -------
   (phi3d, resid) : the solution reshaped to (ni,nj,nk) and the residual
   ``|A*phi - su|`` (norm of order ``norm_order``).
   """

   global A_x,b_x,x_x,rsc_x,cfg_x,A

   start_time_solver = time.time()

   if itstep == 0 and iter == 0:
      print('solve_pyamgx called,tol_conv=',tol_conv,'acrank_conv_local=',acrank_conv_local,'variable=',variable)

   su=su3d.flatten()
   phi=phi3d.flatten()

   if variable == 'p' or variable == 'u' or variable == 'k' or variable == 'eps' or variable == 'om' or \
     (variable == 'v' and coeff_v)  or (variable == 'w' and coeff_w):
# the coefficient matrix is the same for u, v, w. It is assumed that sp3d=0  (e.g. there must be no symmetry b.c.) 

      # flattened coefficients (fresh arrays, so the edits below do not touch the inputs)
      aw=aw3d.flatten()*acrank_conv_local
      ae=ae3d.flatten()*acrank_conv_local
      as1=as3d.flatten()*acrank_conv_local
      an=an3d.flatten()*acrank_conv_local
      al=al3d.flatten()*acrank_conv_local
      ah=ah3d.flatten()*acrank_conv_local
      ap=ap3d.flatten()

      m=ni*nj*nk

      if cyclic_x and cyclic_z:
         print('cyclic_x cyclic_z')
      elif cyclic_x:
         print('cyclic_x and not cyclic_z')
      elif not cyclic_z:
         print('not cyclic_z and not cyclic_x')

      # Assemble the diagonals. In the flattened ordering neighbours along the
      # last axis are 1 apart, along axis 1 nk apart and along axis 0 nk*nj apart.
      diagonals=[ap, -ah[:-1], -al[1:]]
      offsets=[0, 1, -1]
      if cyclic_z:
         # move the couplings of the first/last cell along the last axis to
         # separate wrap-around diagonals at +-(nk-1)
         al_cyc=xp.zeros(m)
         al_cyc[0:-1:nk]= al[0:-1:nk]
         al[0:-1:nk]=0
         ah_cyc=xp.zeros(m)
         ah_cyc[nk-1::nk]=ah[nk-1::nk]
         ah[nk-1::nk]=0
         diagonals=[ap, -ah[:-1], -al[1:], -al_cyc, -ah_cyc[nk-1:]]
         offsets=[0, 1, -1, nk-1, -(nk-1)]
      diagonals+=[-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
      offsets+=[nk, -nk, nk*nj, -nk*nj]
      if cyclic_x:
         # wrap-around coupling between the first and last plane along axis 0
         diagonals+=[-aw, -ae[nj*nk*(ni-1):]]
         offsets+=[nj*nk*(ni-1), -nj*nk*(ni-1)]
      A = sparse.diags(diagonals, offsets, format='csr')

#  Create matrices and vectors:
      A_x = pyamgx.Matrix().create(rsc)
      b_x = pyamgx.Vector().create(rsc)
      x_x = pyamgx.Vector().create(rsc)

      A_x.upload_CSR(A) 
      b_x.upload(su)     #upload pyamg and rhs
      x_x.upload(phi)
      solverx.setup(A_x)
      print('solvers setup')


   b_x.upload(su)     #upload pyamg and rhs
   x_x.upload(phi)

# solve system:
   solverx.solve(b_x, x_x)

# Download solution
   if xp is numpy:
      x_x.download(phi)
   else:
      x_x.download_raw(phi.data)

   # Release the AMGX objects when the next solve rebuilds them. After 'u'
   # that is the case when 'v' rebuilds (coeff_v); after 'v' when 'w'
   # rebuilds (coeff_w) -- the latter case used to be missing, so the
   # objects were never destroyed before 'w' created new ones.
   if   (variable == 'p' or variable == 'w' or variable == 'k' or variable == 'eps' or variable == 'om')\
     or (variable == 'u' and coeff_v) or (variable == 'v' and coeff_w):
      A_x.destroy()
      b_x.destroy()
      x_x.destroy()

   resid=xp.linalg.norm(A*phi - su,ord=norm_order)

   phi3d=xp.reshape(phi,(ni,nj,nk))

   print(f"{'resid in solve_pyamgx: '}{resid:.2e}")

   print(f"{'time solve_pyamgx: '}{time.time()-start_time_solver:.2e}")

   return phi3d,resid

def solve_p(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv):
   """Solve the pressure-type equation with a cached pyamg hierarchy.

   On the very first call (``itstep == 0`` and ``iter == 0``) the matrix is
   assembled from the coefficients and the Ruge-Stuben multigrid hierarchy
   is stored in the module global ``Ap_solve_p``; later calls reuse that
   hierarchy and only the right-hand side ``su3d`` and the initial guess
   ``phi3d`` change.

   The module-level ``amg_relax`` selects the method: 'default' uses the
   solver defaults, any other value is passed as ``accel`` together with
   ``cycle=amg_cycle``.

   Returns the solution reshaped to (ni,nj,nk).
   """

   global Ap_solve_p

   start_time_solver = time.time()

   first_call = itstep == 0 and iter == 0

   if first_call:
      print('solve_p called')
      print('relation method=',amg_relax)
      print('A and M computed,tol_conv=',tol_conv)

      # flattened coefficients (flatten returns copies, so the edits below
      # do not touch the inputs)
      ap=ap3d.flatten()
      aw=aw3d.flatten()
      ae=ae3d.flatten()
      as1=as3d.flatten()
      an=an3d.flatten()
      al=al3d.flatten()
      ah=ah3d.flatten()

      m=ni*nj*nk

      if cyclic_x and cyclic_z:
         print('cyclic_x cyclic_z')
      elif cyclic_x:
         print('cyclic_x and not cyclic_z')
      elif cyclic_z:
         print('cyclic_z and not cyclic_x')
      else:
         print('not cyclic_z and not cyclic_x')

      # In the flattened ordering neighbours along the last axis are 1
      # apart, along axis 1 nk apart and along axis 0 nk*nj apart.
      wrap_z=[]
      wrap_z_offsets=[]
      if cyclic_z:
         # move the couplings of the first/last cell along the last axis to
         # separate wrap-around diagonals at +-(nk-1)
         al_cyc=xp.zeros(m)
         al_cyc[0:-1:nk]= al[0:-1:nk]
         al[0:-1:nk]=0
         ah_cyc=xp.zeros(m)
         ah_cyc[nk-1::nk]=ah[nk-1::nk]
         ah[nk-1::nk]=0
         wrap_z=[-al_cyc, -ah_cyc[nk-1:]]
         wrap_z_offsets=[nk-1, -(nk-1)]

      wrap_x=[]
      wrap_x_offsets=[]
      if cyclic_x:
         # wrap-around coupling between the first and last plane along axis 0
         wrap_x=[-aw, -ae[nj*nk*(ni-1):]]
         wrap_x_offsets=[nj*nk*(ni-1), -nj*nk*(ni-1)]

      diagonals=[ap, -ah[:-1], -al[1:]]+wrap_z+[-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]+wrap_x
      offsets=[0, 1, -1]+wrap_z_offsets+[nk, -nk, nk*nj, -nk*nj]+wrap_x_offsets
      Ap = sparse.diags(diagonals, offsets, format='csr')

      Ap_solve_p = pyamg.ruge_stuben_solver(Ap)                    # construct the multigrid hierarchy

   print('in solve_p')

   phi=phi3d.flatten()
   su=su3d.flatten()
   res_amg = []
   # extra solver options only when a non-default method is requested
   extra={}
   if amg_relax != 'default':
      extra={'accel':amg_relax,'cycle':amg_cycle}
   phi = Ap_solve_p.solve(su, tol=tol_conv, x0=phi, residuals=res_amg, **extra)

   print('Residual history in solve_p', ["%0.4e" % i for i in res_amg])

   phi3d=xp.reshape(phi,(ni,nj,nk))

   print(f"{'time solve_p: '}{time.time()-start_time_solver:.2e}")

   return phi3d


def solve_px(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv):
   """Solve the pressure-type equation with a cached AMGX solver.

   On the very first call (``itstep == 0`` and ``iter == 0``) the matrix
   ``Ap`` is assembled, the AMGX matrix/vectors and the solver
   ``solverx_p`` (created from ``rsc`` and ``cfg_p``) are set up and all of
   them are stored in module globals. Every call then uploads the
   right-hand side ``su3d`` and the initial guess ``phi3d``, solves, and
   downloads the solution. ``tol_conv`` is only printed.

   Returns the solution reshaped to (ni,nj,nk); the residual
   ``|Ap*phi - su|`` (norm of order ``norm_order``) is printed.
   """

   global Ap,A_xp,b_xp,x_xp,rsc,cfg,solverx_p

   first_call = itstep == 0 and iter == 0

   if first_call:
      print('solve_px called')

   phi=phi3d.flatten()
   su=su3d.flatten()

   if first_call:
      print('A and M computed,tol_conv=',tol_conv)

      # flattened coefficients (flatten returns copies, so the edits below
      # do not touch the inputs)
      ap=ap3d.flatten()
      aw=aw3d.flatten()
      ae=ae3d.flatten()
      as1=as3d.flatten()
      an=an3d.flatten()
      al=al3d.flatten()
      ah=ah3d.flatten()

      m=ni*nj*nk

      if cyclic_x and cyclic_z:
         print('cyclic_x cyclic_z')
      elif cyclic_x:
         print('cyclic_x and not cyclic_z')
      elif not cyclic_z:
         print('not cyclic_z and not cyclic_x')

      # In the flattened ordering neighbours along the last axis are 1
      # apart, along axis 1 nk apart and along axis 0 nk*nj apart.
      wrap_z=[]
      wrap_z_offsets=[]
      if cyclic_z:
         # move the couplings of the first/last cell along the last axis to
         # separate wrap-around diagonals at +-(nk-1)
         al_cyc=xp.zeros(m)
         al_cyc[0:-1:nk]= al[0:-1:nk]
         al[0:-1:nk]=0
         ah_cyc=xp.zeros(m)
         ah_cyc[nk-1::nk]=ah[nk-1::nk]
         ah[nk-1::nk]=0
         wrap_z=[-al_cyc, -ah_cyc[nk-1:]]
         wrap_z_offsets=[nk-1, -(nk-1)]

      wrap_x=[]
      wrap_x_offsets=[]
      if cyclic_x:
         # wrap-around coupling between the first and last plane along axis 0
         wrap_x=[-aw, -ae[nj*nk*(ni-1):]]
         wrap_x_offsets=[nj*nk*(ni-1), -nj*nk*(ni-1)]

      diagonals=[ap, -ah[:-1], -al[1:]]+wrap_z+[-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]+wrap_x
      offsets=[0, 1, -1]+wrap_z_offsets+[nk, -nk, nk*nj, -nk*nj]+wrap_x_offsets
      Ap = sparse.diags(diagonals, offsets, format='csr')

      # AMGX matrix, vectors and solver, kept for all later calls
      A_xp = pyamgx.Matrix().create(rsc)
      b_xp = pyamgx.Vector().create(rsc)
      x_xp = pyamgx.Vector().create(rsc)

      solverx_p = pyamgx.Solver().create(rsc, cfg_p)

      A_xp.upload_CSR(Ap) #upload pyamg poisson problem
      b_xp.upload(su)     #upload pyamg and rhs
      x_xp.upload(phi)
      solverx_p.setup(A_xp)

   # right-hand side and initial guess for this call
   b_xp.upload(su)     #upload pyamg and rhs
   x_xp.upload(phi)

   solverx_p.solve(b_xp, x_xp)

   # fetch the solution into phi
   if xp is numpy:
      x_xp.download(phi)
   else:
      x_xp.download_raw(phi.data)

   phi3d=xp.reshape(phi,(ni,nj,nk))

   resid=xp.linalg.norm(Ap*phi - su,ord=norm_order)

   print(f"{'resid in solve_px: '}{resid:.2e}")

   return phi3d

def calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s):
   """Add the pressure-gradient source to the u equation.

   The value returned by ``dphidx(p3d_face_w,p3d_face_s)`` is blended with
   ``dpdx_old`` using the module-level weight ``acrank``; the blend times the
   module-level ``vol`` is subtracted from ``su3d``.  Both source arrays are
   then handed to the ``modify_u`` hook.

   Returns:
      The pair ``(su3d, sp3d)`` produced by ``modify_u``.
   """
   if itstep == 0 and iter == 0:
      print('calcu called')

   # time-weighted pressure gradient: acrank*current + (1-acrank)*previous
   grad_now=dphidx(p3d_face_w,p3d_face_s)
   dpdx=acrank*grad_now+(1-acrank)*dpdx_old

   # sp3d is passed on untouched; the unsteady term is not added here
   # (the original note says it is added in crank_nicol)
   su_out,sp_out=modify_u(su3d-dpdx*vol,sp3d)
   return su_out,sp_out

# find max index
# ind = xp.unravel_index(xp.argmax(yplus_south, axis=None), yplus_south.shape)



def calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s):
   """Add the pressure-gradient source to the v equation.

   The value returned by ``dphidy(p3d_face_w,p3d_face_s)`` is blended with
   ``dpdy_old`` using the module-level weight ``acrank``; the blend times the
   module-level ``vol`` is subtracted from ``su3d``.  Both source arrays are
   then handed to the ``modify_v`` hook.

   Returns:
      The pair ``(su3d, sp3d)`` produced by ``modify_v``.
   """
   if itstep == 0 and iter == 0:
      print('calcv called')

   # time-weighted pressure gradient: acrank*current + (1-acrank)*previous
   grad_now=dphidy(p3d_face_w,p3d_face_s)
   dpdy=acrank*grad_now+(1-acrank)*dpdy_old

   # sp3d is passed on untouched
   su_out,sp_out=modify_v(su3d-dpdy*vol,sp3d)
   return su_out,sp_out

def compute_fk(k3d,eps3d,fk3d):
   """Return the ``fk3d`` field for the active model.

   * ``pans``: ``fk3d = max(1-(psi-1)/(c_eps_2-c_eps_1), 0)`` with
     ``psi = max(1, (k3d**1.5/eps3d)/(cdes*delta_max))``.
   * ``keps_des``: ``fk3d = max(1, (k3d**1.5/eps3d)/(0.67*delta_max))``
     (evaluated after, and therefore overriding, the ``pans`` value).
   * ``jl0 < 0``: the first ``abs(jl0)`` slabs along axis 1 are set to 1.
     When neither model flag is set this writes into the array passed in.

   The result is finally passed through the ``modify_fk`` hook.
   """
   if itstep == 0 and iter == 0:
      print('compute_fk called')

   if pans:
      # ratio of the k-eps length scale to cdes*delta_max, floored at 1
      psi=xp.maximum(1,k3d**1.5/eps3d/(cdes*delta_max))
      fk3d=xp.maximum(1.-(psi-1.)/(c_eps_2-c_eps_1),0)

   if keps_des:
      length_scale=k3d**1.5/eps3d
      fk3d=xp.maximum(1.,length_scale/(0.67*delta_max))

   if jl0 < 0:
      n_slabs=xp.abs(jl0)
      fk3d[:,0:n_slabs,:]=1

   return modify_fk(fk3d)

def calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
   """Add production and dissipation sources to the k equation (k-omega family).

   ``gen`` is assembled from the nine gradients returned by ``dphidx``,
   ``dphidy`` and ``dphidz``.  Production ``max(vis3d-viscos,1e-10)*gen*vol``
   is added to ``su3d``; ``cmu*om3d*vol`` (times the module-level ``fk3d``
   when ``sst`` or ``kom_des`` is set) is subtracted from ``sp3d``.  The
   ``modify_k`` hook is applied last.

   Note that ``w3d_face_w``, ``w3d_face_s``, ``u3d_face_l`` and ``v3d_face_l``
   are not parameters; they are read from module scope.

   Returns:
      ``(su3d, sp3d, gen, comm_term)`` where ``comm_term`` is the third value
      returned by ``modify_k``.
   """
   if itstep == 0 and iter == 0:
      print('calck_kom called')

   # gradients, grouped per velocity component
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dudy=dphidy(u3d_face_w,u3d_face_s)
   dudz=dphidz(u3d_face_l)

   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dvdz=dphidz(v3d_face_l)

   dwdx=dphidx(w3d_face_w,w3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)
   dwdz=dphidz(w3d_face_l)

   # production term
   normal_part=2.*(dudx**2+dvdy**2+dwdz**2)
   gen=normal_part+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2
   vist=xp.maximum(vis3d-viscos,1e-10)
   su3d=su3d+vist*gen*vol

   # dissipation term
   if not (sst or kom_des):
      sp3d=sp3d-cmu*om3d*vol
   else:
      sp3d=sp3d-fk3d*cmu*om3d*vol

   # case hook; the unsteady term is not added here
   su3d,sp3d,comm_term=modify_k(su3d,sp3d,gen)

   return su3d,sp3d,gen,comm_term

def calcom(su3d,sp3d,om3d,gen,comm_term):
   """Add production and dissipation sources to the omega equation.

   ``c_omega_1*gen*vol`` is added to ``su3d`` and ``c_omega_2*om3d*vol`` is
   subtracted from ``sp3d``; the result is passed, together with
   ``comm_term``, through the ``modify_om`` hook.

   Returns:
      The pair ``(su3d, sp3d)`` produced by ``modify_om``.
   """
   if itstep == 0 and iter == 0:
      print('calcom called')

   production=c_omega_1*gen*vol
   dissipation=c_omega_2*om3d*vol

   su_out,sp_out=modify_om(su3d+production,sp3d-dissipation,comm_term)
   return su_out,sp_out

def calcom_sst(su3d,sp3d,om3d,gen,comm_term):
   """Assemble the omega-equation sources for the SST model.

   Steps, in order:
     1. face values of k and omega via ``compute_face_phi`` (``k3d`` is not a
        parameter; it is read from module scope),
     2. the cross term ``2*cr_sst*(grad k . grad omega)/om3d``,
     3. the blending arrays ``f1_sst = tanh(zeta**4)`` and
        ``f2_sst = tanh(zeta**2)``,
     4. cross term, production and dissipation added to ``su3d``/``sp3d``
        with constants interpolated by ``f1_sst``,
     5. the ``modify_om`` hook.

   Returns:
      ``(su3d, sp3d, f1_sst, f2_sst)``.
   """
   if itstep == 0 and iter == 0:
      print('calcom_sst called')

   k3d_face_w,k3d_face_s,k3d_face_l=compute_face_phi(k3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_low,k_bc_high,\
     k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_low_type,k_bc_high_type,'k')
   # the last type argument was spelled 'om_bc_highlow_type', a name that is
   # not defined anywhere; the 'high' boundary type is om_bc_high_type
   om3d_face_w,om3d_face_s,om3d_face_l=compute_face_phi(om3d,om_bc_west,om_bc_east,om_bc_south,om_bc_north,om_bc_low,om_bc_high,\
     om_bc_west_type,om_bc_east_type,om_bc_south_type,om_bc_north_type,om_bc_low_type,om_bc_high_type,'om')

#-------- cross term
   dkdx=dphidx(k3d_face_w,k3d_face_s)
   dkdy=dphidy(k3d_face_w,k3d_face_s)
   dkdz=dphidz(k3d_face_l)

   domdx=dphidx(om3d_face_w,om3d_face_s)
   domdy=dphidy(om3d_face_w,om3d_face_s)
   domdz=dphidz(om3d_face_l)

#--- south boundary with prescribed omega: replace the computed y-gradient
   if om_bc_south_type == 'd':
      domdy[:,0,:]=-2*om_bc_south/dist3d[:,0,:]

#--- north boundary with prescribed omega
   if om_bc_north_type == 'd':
      domdy[:,-1,:]=2*om_bc_north/dist3d[:,-1,:]

   crosv=dkdx*domdx+dkdy*domdy+dkdz*domdz
   cross_term=2*cr_sst*crosv/om3d

# f1_sst
   term1=xp.maximum(cross_term,1.e-10)   # floor keeps term3 finite
   term1b=500*viscos/(om3d*dist3d**2)
   term2=xp.maximum(k3d**0.5/(cmu*om3d*dist3d),term1b)
   term3=4*cr_sst*k3d/(term1*dist3d**2)
   zeta=xp.minimum(term2,term3)
   f1_sst=xp.tanh(zeta**4)

# f2_sst
   zeta=xp.maximum(2*k3d**0.5/(cmu*om3d*dist3d),term1b)
   f2_sst=xp.tanh(zeta**2)

   cross_term=cross_term*(1-f1_sst)

   # negative part goes into sp3d (divided by om3d), positive part into su3d
   sp3d=sp3d+xp.minimum(cross_term,0)/om3d*vol
   su3d=su3d+xp.maximum(cross_term,0)*vol

#------- interpolate constants between the two coefficient sets with f1_sst
   c_sst_1 = f1_sst*c_omega_1_sst_1+(1-f1_sst)*c_omega_1_sst_2
   c_sst_2 = f1_sst*c_omega_2_sst_1+(1-f1_sst)*c_omega_2_sst_2

#--------production term
   su3d=su3d+c_sst_1*gen*vol

#--------dissipation term
   sp3d=sp3d-c_sst_2*om3d*vol

# modify su & sp
   su3d,sp3d=modify_om(su3d,sp3d,comm_term)

   return su3d,sp3d,f1_sst,f2_sst

def calck(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
   """Add production and dissipation sources to the k equation (k-eps family).

   Production ``max(vis3d-viscos,1e-10)*gen*vol`` is added to ``su3d``.  The
   term subtracted from ``sp3d`` depends on the model flag:

   * ``keps_des``: ``fk3d*eps3d/k3d*vol`` (``fk3d`` from module scope),
   * ``k_eq_les``: ``c_eps*k3d**0.5/vol**0.3333*vol``,
   * otherwise:    ``eps3d/k3d*vol``.

   ``w3d_face_w``, ``w3d_face_s``, ``u3d_face_l`` and ``v3d_face_l`` are not
   parameters; they are read from module scope.

   Returns:
      ``(su3d, sp3d, gen, dudx, dudy)``; the third value returned by
      ``modify_k`` is discarded.
   """
   if itstep == 0 and iter == 0:
      print('calck called')

   # gradients, grouped per velocity component
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dudy=dphidy(u3d_face_w,u3d_face_s)
   dudz=dphidz(u3d_face_l)

   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dvdz=dphidz(v3d_face_l)

   dwdx=dphidx(w3d_face_w,w3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)
   dwdz=dphidz(w3d_face_l)

   # production term
   normal_part=2.*(dudx**2+dvdy**2+dwdz**2)
   gen=normal_part+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2
   vist=xp.maximum(vis3d-viscos,1e-10)
   su3d=su3d+vist*gen*vol

   # dissipation term
   if keps_des:
      sink=fk3d*eps3d/k3d*vol
   elif k_eq_les:
      delta_les=vol**0.3333
      sink=c_eps*k3d**0.5/delta_les*vol
   else:
      sink=eps3d/k3d*vol
   sp3d=sp3d-sink

   # case hook; the unsteady term is not added here
   su3d,sp3d,_=modify_k(su3d,sp3d,gen)

   return su3d,sp3d,gen,dudx,dudy

def calceps(su3d,sp3d,k3d,eps3d,vis3d,gen):
   """Add production and dissipation sources to the eps equation.

   Two damping arrays are built from ``ystar = (eps3d*viscos)**0.25*dist3d/viscos``
   and ``rt = k3d**2/eps3d/viscos``: ``fdampf2`` (used in the dissipation
   coefficient) and ``fmu3d`` (capped at 1, used in the production term and
   returned).  With ``pans`` set the dissipation coefficient is additionally
   blended with the module-level ``fk3d``.  ``vis3d`` is not used in the body.

   Returns:
      ``(su3d, sp3d, fmu3d)`` after the ``modify_eps`` hook.
   """
   if itstep == 0 and iter == 0:
      print('calceps called')

   # damping functions
   ystar=(eps3d*viscos)**0.25*dist3d/viscos
   rt=k3d**2/eps3d/viscos
   fdampf2=((1.-xp.exp(-ystar/3.1))**2)*(1.-0.3*xp.exp(-(rt/6.5)**2))
   fmu3d=xp.minimum(((1.-xp.exp(-ystar/14.))**2)*(1.+5./rt**0.75*xp.exp(-(rt/200.)**2)),1.)

   # production term
   su3d=su3d+c_eps_1*cmu*fmu3d*gen*k3d*vol

   # dissipation term
   c2u=c_eps_1+fk3d*(fdampf2*c_eps_2-c_eps_1) if pans else fdampf2*c_eps_2
   sp3d=sp3d-c2u*eps3d*vol/k3d

   # case hook
   su3d,sp3d=modify_eps(su3d,sp3d)

   return su3d,sp3d,fmu3d

def calcw(su3d,sp3d,dpdz_old,p3d_face_l):
   """Add the pressure-gradient source to the w equation.

   The value returned by ``dphidz(p3d_face_l)`` is blended with ``dpdz_old``
   using the module-level weight ``acrank``; the blend times the module-level
   ``vol`` is subtracted from ``su3d``.  Both source arrays are then handed to
   the ``modify_w`` hook.

   Returns:
      The pair ``(su3d, sp3d)`` produced by ``modify_w``.
   """
   if itstep == 0 and iter == 0:
      print('calcw called')

   # time-weighted pressure gradient: acrank*current + (1-acrank)*previous
   grad_now=dphidz(p3d_face_l)
   dpdz=acrank*grad_now+(1-acrank)*dpdz_old

   # sp3d is passed on untouched; the unsteady term is not added here
   su_out,sp_out=modify_w(su3d-dpdz*vol,sp3d)
   return su_out,sp_out

def calcp(convw,convs,convl):
   """Build the neighbour coefficients of the pressure equation.

   Every face coefficient is ``area**2`` divided by the mean volume of the two
   cells sharing the face.  West/south/low coefficients are computed first;
   east/north/high are the same arrays shifted by one cell with ``xp.roll``.
   Boundary faces are then zeroed, except in cyclic directions where the
   coefficient couples the first and the last cell.

   The three arguments are not used in the body.  ``su3d`` in the returned
   tuple is not computed here: it is the module-level array of that name.

   Returns:
      ``(aw3d, ae3d, as3d, an3d, al3d, ah3d, su3d, ap3d)``.
   """
   if itstep == 0 and iter == 0:
      print('calcp called')

   # --- x direction; the first slab of the face volume keeps the 1e-10
   # placeholder and the resulting aw3d[0] is overwritten further down
   vol_face_w=xp.full((ni+1,nj,nk),1e-10)
   vol_face_w[1:,:,:]=0.5*xp.roll(vol,-1,axis=0)+0.5*vol
   aw3d=areaw[0:-1,:,:]**2/vol_face_w[0:-1,:,:]
   ae3d=xp.roll(aw3d,-1,axis=0)

   # --- y direction
   vol_face_s=xp.full((ni,nj+1,nk),1e-10)
   vol_face_s[:,1:,:]=0.5*xp.roll(vol,-1,axis=1)+0.5*vol
   as3d=areas[:,0:-1,:]**2/vol_face_s[:,0:-1,:]
   an3d=xp.roll(as3d,-1,axis=1)

   # --- z direction
   vol_face_l=xp.full((ni,nj,nk+1),1e-10)
   vol_face_l[:,:,1:]=0.5*xp.roll(vol,-1,axis=2)+0.5*vol
   al3d=areal[:,:,0:-1]**2/vol_face_l[:,:,0:-1]
   ah3d=xp.roll(al3d,-1,axis=2)

   # --- boundary faces
   if not cyclic_x:
      aw3d[0,:,:]=0
      ae3d[-1,:,:]=0
   else:
      aw3d[0,:,:]=areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
      ae3d[-1,:,:]=aw3d[0,:,:]

   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0
   else:
      al3d[:,:,0]=areal[:,:,0]**2/(0.5*(vol[:,:,0]+vol[:,:,-1]))
      ah3d[:,:,-1]=al3d[:,:,0]

   as3d[:,0,:]=0
   an3d[:,-1,:]=0

   ap3d=aw3d+ae3d+as3d+an3d+al3d+ah3d

   # a huge diagonal in cell [0,0,0] pins the solution there (original note:
   # "set p3d=0 in [0,0,0] to make it non-singular")
   ap3d[0,0,0]=1e10

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d


def correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p):
   """Correct the face fluxes with the field ``p3d`` and return the new imbalance.

   ``convw``, ``convs`` and ``convl`` are NOT parameters: the module-level
   arrays are updated in place.  Each face flux receives
   ``coefficient*(p_upstream_cell - p_cell)*dtt`` with
   ``dtt = dt[itstep]*acrank``.  ``u3d``, ``v3d``, ``w3d`` and ``p3d`` are
   returned unchanged.

   Returns:
      ``(convw, convs, convl, p3d, u3d, v3d, w3d, su3d)`` where ``su3d`` is the
      sum of inflow minus outflow over the three face pairs of every cell.
   """
# correct convections
# a zero slab is prepended to p3d below so that index i of the padded array
# holds the value of cell i-1 (original note: ghost cells with Neumann b.c.)
   p3d_w=p3d
   p3d_s=p3d
   p3d_l=p3d
   dtt=dt[itstep]*acrank
# ------------- west face
# prepend a slab of zeros before row 0
   p3d_w=xp.concatenate((xp.zeros((1,nj,nk)),p3d_w),axis=0)
   if cyclic_x:
      # interior faces, then the cyclic face (cell -1 <-> cell 0), which is
      # stored both at index 0 and at index -1
      convw[1:-1,:,:]=convw[1:-1,:,:]+aw3d_p[1:,:,:]*(p3d_w[1:-1,:,:]-p3d[1:,:,:])*dtt
      convw[0,:,:]=convw[0,:,:]+aw3d_p[0,:,:]*(p3d[-1,:,:]-p3d[0,:,:])*dtt
      convw[-1,:,:]=convw[0,:,:]
   else:
      # all faces except the last one; face 0 uses the zero ghost value
      convw[0:-1,:,:]=convw[0:-1,:,:]+aw3d_p*(p3d_w[0:-1,:,:]-p3d)*dtt


# ------------- south face
# prepend a slab of zeros before column 0
   p3d_s=xp.concatenate((xp.zeros((ni,1,nk)),p3d_s),axis=1)
   convs[:,0:-1,:]=convs[:,0:-1,:]+as3d_p*(p3d_s[:,0:-1,:]-p3d)*dtt

# ------------- low face
# prepend a slab of zeros before plane 0
   p3d_l=xp.concatenate((xp.zeros((ni,nj,1)),p3d_l),axis=2)
   if cyclic_z:
      convl[:,:,1:-1]=convl[:,:,1:-1]+al3d_p[:,:,1:]*(p3d_l[:,:,1:-1]-p3d[:,:,1:])*dtt
      # NOTE(review): the cyclic-x branch above uses coefficient index 0 and
      # updates face 0 first; here coefficient index -1 is used and face -1 is
      # updated first.  Presumably equivalent only when al3d_p does not vary
      # along z -- confirm.
      convl[:,:,-1]=convl[:,:,-1]+al3d_p[:,:,-1]*(p3d[:,:,-1]-p3d[:,:,0])*dtt
      convl[:,:,0]=convl[:,:,-1]
   else:
      convl[:,:,0:-1]=convl[:,:,0:-1]+al3d_p*(p3d_l[:,:,0:-1]-p3d)*dtt
# boundary: the low face flux is reset from the module-level w3d_face_l
      convl[:,:,0]=w3d_face_l[:,:,0]*areal[:,:,0]

# continuity error: inflow minus outflow through each pair of opposite faces
   su3d=convw[0:-1,:,:]-convw[1:,:,:]\
       +convs[:,0:-1,:]-convs[:,1:,:]\
       +convl[:,:,0:-1]-convl[:,:,1:]

   return convw,convs,convl,p3d,u3d,v3d,w3d,su3d


def update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l):
    """Return the "old" fields for the next time step.

    The six solution fields are returned as they are (the same objects, no
    copies are made), followed by the three gradients obtained from
    ``dphidx``, ``dphidy`` and ``dphidz`` applied to the face values given.

    Returns:
        ``(u3d_old, v3d_old, w3d_old, k3d_old, eps3d_old, om3d_old,
        dpdx_old, dpdy_old, dpdz_old)``.
    """
    old_gradients=(
        dphidx(p3d_face_w,p3d_face_s),
        dphidy(p3d_face_w,p3d_face_s),
        dphidz(p3d_face_l),
    )
    return (u3d,v3d,w3d,k3d,eps3d,om3d)+old_gradients

def time_stats(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,\
                fk3d_mean,vis3d_mean,gen_mean):
    """Add the current fields to the running sums used for time averaging.

    The instantaneous fields (``u3d``, ``v3d``, ``w3d``, ``p3d``, ``k3d``,
    ``eps3d``, ``om3d``, ``fk3d``, ``vis3d``, ``gen``) are read from module
    scope.  The module-level counter ``itstep_stats_counter`` is incremented
    by one.  New arrays are returned; the arguments are not modified in place.
    The sums are not divided by the number of samples here.

    Returns:
        The sixteen updated sums, in the same order as the parameters.
    """
    global itstep_stats_counter

    itstep_stats_counter=itstep_stats_counter+1
    print('time_stats called: itstep_stats_counter=',itstep_stats_counter)

    return (
        # first moments
        u3d_mean+u3d,
        v3d_mean+v3d,
        w3d_mean+w3d,
        p3d_mean+p3d,
        k3d_mean+k3d,
        eps3d_mean+eps3d,
        om3d_mean+om3d,
        # second moments of the velocities
        uu3d_stress+u3d**2,
        vv3d_stress+v3d**2,
        ww3d_stress+w3d**2,
        uv3d_stress+u3d*v3d,
        uw3d_stress+u3d*w3d,
        vw3d_stress+v3d*w3d,
        # remaining fields
        fk3d_mean+fk3d,
        vis3d_mean+vis3d,
        gen_mean+gen,
    )

def crank_nicol(phi3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_local):
    """Blend implicit and explicit (old time level) neighbour contributions.

    With ``w = acrank_conv_local`` and ``S`` the sum of the six neighbour
    coefficients:

    * diagonal: ``ap3d = apo3d + w*S - sp3d``
    * source:   ``su3d + (apo3d - (1-w)*S)*phi3d_old
      + (1-w)*sum(a_nb*phi3d_old shifted to neighbour nb)``

    Neighbour values are obtained with ``xp.roll``, i.e. they wrap around at
    the array ends.

    Returns:
        ``(ap3d, su3d)``.
    """
    explicit_weight=1-acrank_conv_local
    coeff_sum=aw3d+ae3d+as3d+an3d+al3d+ah3d

    # old-time-level diagonal part moved to the source
    su3d=su3d+(apo3d-explicit_weight*coeff_sum)*phi3d_old
    ap3d=apo3d+acrank_conv_local*coeff_sum-sp3d

    # old-time-level neighbour part, accumulated east, west, north, south, high, low
    neighbours=ae3d*xp.roll(phi3d_old,-1,axis=0)
    neighbours=neighbours+aw3d*xp.roll(phi3d_old,1,axis=0)
    neighbours=neighbours+an3d*xp.roll(phi3d_old,-1,axis=1)
    neighbours=neighbours+as3d*xp.roll(phi3d_old,1,axis=1)
    neighbours=neighbours+ah3d*xp.roll(phi3d_old,-1,axis=2)
    neighbours=neighbours+al3d*xp.roll(phi3d_old,1,axis=2)

    return ap3d,su3d+explicit_weight*neighbours

def vist_kom(vis3d,k3d,om3d):
   """Return the under-relaxed viscosity ``k3d/om3d + viscos``.

   The new field is passed through the ``modify_vis`` hook and then blended
   with the incoming ``vis3d`` using the module-level factor ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_kom called')

   vis_new=modify_vis(k3d/om3d+viscos)

   # under-relaxation against the field passed in
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_k_eq(vis3d,k3d):
   """Return the under-relaxed viscosity ``cmu*k3d**0.5*vol**0.3333 + viscos``.

   The new field is passed through the ``modify_vis`` hook and then blended
   with the incoming ``vis3d`` using the module-level factor ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_k_eq called')

   delta_les=vol**0.3333
   vis_new=modify_vis(cmu*k3d**0.5*delta_les+viscos)

   # under-relaxation against the field passed in
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_sst(vis3d,k3d,om3d,f2_sst,gen):
   """Return the under-relaxed SST viscosity.

   With ``a1 = cmu**0.5`` the new field is
   ``a1*k3d/max(a1*om3d, gen**0.5*f2_sst) + viscos``.  It is passed through
   the ``modify_vis`` hook and then blended with the incoming ``vis3d`` using
   the module-level factor ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_sst called')

   a1=cmu**0.5
   limiter=xp.maximum(a1*om3d,gen**0.5*f2_sst)
   vis_new=modify_vis(a1*k3d/limiter+viscos)

   # under-relaxation against the field passed in
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_pans(vis3d,k3d,eps3d,fmu3d):
   """Return the under-relaxed viscosity ``cmu*fmu3d*k3d**2/eps3d + viscos``.

   The new field is passed through the ``modify_vis`` hook and then blended
   with the incoming ``vis3d`` using the module-level factor ``urfvis``.

   The incoming field is saved *before* it is overwritten.  Previously it was
   saved after the new value had been computed, so the blend combined the new
   field with itself and ``urfvis`` had no effect; the other ``vist_*``
   functions all save the old field first.
   """
   if itstep == 0 and iter == 0:
      print('vist_pans called')

   # old field for the under-relaxation below
   visold= vis3d
   vis3d= cmu*fmu3d*k3d**2/eps3d+viscos

# modify viscosity
   vis3d=modify_vis(vis3d)

#            under-relax viscosity
   vis3d= urfvis*vis3d+(1.-urfvis)*visold

   return vis3d

def vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed viscosity ``rl**2*gen**0.5 + viscos``.

   ``gen`` is assembled from the nine gradients returned by ``dphidx``,
   ``dphidy`` and ``dphidz``.  The length ``rl`` is a blend, weighted by
   ``f_b``, of ``0.41*dist3d`` and ``cmu*l_les`` (the original code labels
   this blend "IDDES").  The new field is passed through the ``modify_vis``
   hook and blended with the incoming ``vis3d`` using ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_smag called')

   # gradients, grouped per velocity component
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dudy=dphidy(u3d_face_w,u3d_face_s)
   dudz=dphidz(u3d_face_l)

   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dvdz=dphidz(v3d_face_l)

   dwdx=dphidx(w3d_face_w,w3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)
   dwdz=dphidz(w3d_face_l)

   normal_part=2.*(dudx**2+dvdy**2+dwdz**2)
   gen=normal_part+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2

   # length scales
   rl_rans=0.41*dist3d
   rl_les=cmu*vol**0.3333333   # computed but not used below

   # difference of y2d along its second axis (first row dropped), replicated nk times along z
   dy=xp.repeat(xp.diff(y2d[1:,:],axis=1)[:,:,None],repeats=nk,axis=2)

   # l_les = min(max(0.15*dist3d, 0.15*delta_max, dy), delta_max)
   l_les=xp.minimum(xp.maximum(xp.maximum(0.15*dist3d,0.15*delta_max),dy),delta_max)

   # blending function and blended length
   alpha=0.25-dist3d/delta_max
   f_b=xp.minimum(2.*xp.exp(-9*alpha**2),1.)
   rl=f_b*rl_rans+(1-f_b)*cmu*l_les

   vis_new=modify_vis(rl**2*gen**0.5+viscos)

   # under-relaxation against the field passed in
   return urfvis*vis_new+(1.-urfvis)*vis3d

def save_vtk():
   """Write the current flow state to an ASCII legacy-VTK structured-grid file.

   Everything is read from module scope: the grid (``x2d``, ``y2d``, ``dz``,
   ``ni``, ``nj``, ``nk``), the fields (``u3d``, ``v3d``, ``w3d``, ``p3d`` and,
   depending on the model flags, ``k3d``, ``eps3d``, ``om3d``) and the output
   settings (``vtk_file_name``, ``save_vtk_movie``, ``itstep``).

   With ``save_vtk_movie`` set the time-step number is part of the file name,
   otherwise the same file is overwritten on every call.

   The file is written inside a ``with`` block so that the handle is closed
   even if one of the writes raises.
   """
   scalar_names = ['pressure']
   scalar_variables = [p3d]
   if keps or pans or keps_des:
      scalar_names.append('turb_kin')
      scalar_names.append('epsilon')
      scalar_variables.append(k3d)
      scalar_variables.append(eps3d)
   if k_eq_les:
      scalar_names.append('turb_kin')
      scalar_variables.append(k3d)
   if sst or kom or kom_des:
      scalar_names.append('turb_kin')
      scalar_names.append('omega')
      scalar_variables.append(k3d)
      scalar_variables.append(om3d)

   if save_vtk_movie:
      file_name = '%s.%d.vtk' % (vtk_file_name, itstep)
   else:
      file_name = '%s.vtk' % (vtk_file_name)

   with open(file_name,'w') as f:
      f.write('# vtk DataFile Version 3.0\npyCALC-LES Data\nASCII\nDATASET STRUCTURED_GRID\n')
      f.write('DIMENSIONS %d %d %d\nPOINTS %d double\n' % (nk+1,nj+1,ni+1,(ni+1)*(nj+1)*(nk+1)))
      # grid nodes, k fastest
      # NOTE(review): the z coordinate is dz*k, which presumably requires dz
      # to be a scalar (equidistant z grid) -- confirm for stretched grids.
      for i in range(ni+1):
         for j in range(nj+1):
            for k in range(nk+1):
               f.write('%.5f %.5f %.5f\n' % (x2d[i,j],y2d[i,j],dz*k))
      f.write('\nCELL_DATA %d\n' % (ni*nj*nk))

      f.write('\nVECTORS velocity double\n')
      for i in range(ni):
         for j in range(nj):
            for k in range(nk):
               f.write('%.12e %.12e %.12e\n' % (u3d[i,j,k],v3d[i,j,k],w3d[i,j,k]))

      # one SCALARS section per selected field, same cell ordering as above
      for var_name, var in zip(scalar_names, scalar_variables):
         f.write('\nSCALARS %s double 1\nLOOKUP_TABLE default\n' % (var_name))
         for i in range(ni):
            for j in range(nj):
               for k in range(nk):
                  f.write('%.10e\n' % (var[i,j,k]))

   print('Flow state save into VTK format to file %s\n' % (file_name))

def save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress, \
                        vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,gen_mean):
   """Save the accumulated statistics to ``.npy`` files in the working directory.

   With the module-level flag ``save_average_z`` set, every array is averaged
   over axis 2 before it is written; otherwise the full 3D arrays are written
   to ``*_3d`` files and the turbulence quantities are written only for the
   model flags listed below.  In both cases the file ``itstep`` receives
   ``[itstep_stats_counter, nk, dz3d[0,0,0]]``.

   The arrays are written as passed in; no division by the sample count is
   done here.

   Changes with respect to the previous version:
     * the unused local ``umm``, the second (identical) write of
       ``k_averaged`` and the unreachable second ``return`` are removed;
     * the ``itstep`` record is built with ``xp.array`` in both branches.
       The non-averaged branch used a plain Python list containing
       ``dz3d[0,0,0]``; NOTE(review): with CuPy that element is a device
       array and the list form presumably fails on implicit host
       conversion -- confirm.
   """
   print('save_time_aver_data called')

   if save_average_z:
      # averaged along axis 2 (original note: requires an equidistant z mesh)
      xp.save('u_averaged', xp.mean(u3d_mean,axis=2))
      xp.save('v_averaged', xp.mean(v3d_mean,axis=2))
      xp.save('w_averaged', xp.mean(w3d_mean,axis=2))
      xp.save('p_averaged', xp.mean(p3d_mean,axis=2))
      xp.save('k_averaged', xp.mean(k3d_mean,axis=2))
      xp.save('fk_averaged', xp.mean(fk3d_mean,axis=2))
      xp.save('om_averaged', xp.mean(om3d_mean,axis=2))
      xp.save('vis_averaged', xp.mean(vis3d_mean,axis=2))
      xp.save('eps_averaged', xp.mean(eps3d_mean,axis=2))
      xp.save('gen_averaged', xp.mean(gen_mean,axis=2))
      xp.save('uu_stress', xp.mean(uu3d_stress,axis=2))
      xp.save('vv_stress', xp.mean(vv3d_stress,axis=2))
      xp.save('ww_stress', xp.mean(ww3d_stress,axis=2))
      xp.save('uv_stress', xp.mean(uv3d_stress,axis=2))
      xp.save('uw_stress', xp.mean(uw3d_stress,axis=2))
      xp.save('vw_stress', xp.mean(vw3d_stress,axis=2))
   else:
      xp.save('u_averaged_3d', u3d_mean)
      xp.save('v_averaged_3d', v3d_mean)
      xp.save('w_averaged_3d', w3d_mean)
      xp.save('p_averaged_3d', p3d_mean)
      xp.save('vis_averaged_3d', vis3d_mean)
      # NOTE(review): the model-flag conditions below are kept exactly as
      # they were; whether e.g. sst should also write omega is not clear
      # from this function -- confirm.
      if keps or kom_des  or keps_des  or kom  or sst or k_eq_les:
         xp.save('k_averaged_3d', k3d_mean)
      if kom_des  or kom:
         xp.save('om_averaged_3d', om3d_mean)
      if keps or keps_des or sst:
         xp.save('eps_averaged_3d', eps3d_mean)
      if kom_des  or keps_des  or sst:
         xp.save('fk_averaged_3d', fk3d_mean)
      xp.save('uu_stress_3d', uu3d_stress)
      xp.save('vv_stress_3d', vv3d_stress)
      xp.save('ww_stress_3d', ww3d_stress)
      xp.save('uv_stress_3d', uv3d_stress)
      xp.save('uw_stress_3d', uw3d_stress)
      xp.save('vw_stress_3d', vw3d_stress)

   # bookkeeping record, identical for both modes
   xp.save('itstep',xp.array([xp.array(itstep_stats_counter),xp.array(nk),dz3d[0,0,0]]))
   print('itstep_stats_counter,nk,dz',itstep_stats_counter,nk,dz3d[0,0,0])
   if save_average_z:
      print('data averaged in z')
   else:
      print('data not averaged in z')
   return

def read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Load the restart fields from the ``*_saved.npy`` files.

   ``u3d``, ``v3d``, ``w3d`` and ``p3d`` are always loaded.  ``k3d``, ``eps3d``
   and ``om3d`` are loaded only when one of the corresponding model flags is
   set; otherwise the value passed in is returned unchanged.

   Returns:
      ``(u3d, v3d, w3d, p3d, k3d, eps3d, om3d)``.
   """
   print('read_restart_data called')

   u3d,v3d,w3d,p3d=(xp.load(name) for name in
                    ('u3d_saved.npy','v3d_saved.npy','w3d_saved.npy','p3d_saved.npy'))

   uses_eps=keps or pans or keps_des
   uses_omega=sst or kom or kom_des

   if uses_eps or k_eq_les:
      k3d=xp.load('k3d_saved.npy')
   if uses_eps:
      eps3d=xp.load('eps3d_saved.npy')
   if uses_omega:
      k3d=xp.load('k3d_saved.npy')
      om3d=xp.load('om3d_saved.npy')

   return u3d,v3d,w3d,p3d,k3d,eps3d,om3d

def save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Write the restart fields to ``*_saved`` files with ``xp.save``.

   ``u3d``, ``v3d``, ``w3d`` and ``p3d`` are always written.  ``k3d``,
   ``eps3d`` and ``om3d`` are written only when one of the corresponding
   model flags is set.  Returns ``None``.
   """
   print('save_data called')

   for name,field in (('u3d_saved',u3d),('v3d_saved',v3d),('w3d_saved',w3d),('p3d_saved',p3d)):
      xp.save(name, field)

   uses_eps=keps or pans or keps_des
   uses_omega=sst or kom or kom_des

   if uses_eps or k_eq_les:
      xp.save('k3d_saved', k3d)
   if uses_eps:
      xp.save('eps3d_saved', eps3d)
   if uses_omega:
      xp.save('k3d_saved', k3d)
      xp.save('om3d_saved', om3d)

   return

def vist_wale(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed viscosity ``rl**2*term1 + viscos``.

   ``term1 = (sd_ij sd_ij)**1.5 / max((s_ij s_ij)**2.5 + (sd_ij sd_ij)**1.25, 1e-10)``
   where ``s`` is the symmetric part of the gradient tensor ``g`` and ``sd``
   is the symmetric part of ``g.g`` with one third of its trace removed from
   the diagonal.  The length ``rl`` is ``min(0.41*dist3d, cm*l_les)`` with
   ``cm = (10.6*0.1**2)**0.5``.  The new field is passed through the
   ``modify_vis`` hook and blended with the incoming ``vis3d`` using
   ``urfvis``.
   """
   if itstep == 0 and iter == 0:
      print('vist_wale called')

   def contract(t):
      # t_ij*t_ij, added up row by row, left to right
      squares=[c*c for row in t for c in row]
      total=squares[0]
      for sq in squares[1:]:
         total=total+sq
      return total

   # gradient tensor: row = velocity component (u,v,w), column = direction (x,y,z)
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)

   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)

   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

   g=[[dudx,dudy,dudz],
      [dvdx,dvdy,dvdz],
      [dwdx,dwdy,dwdz]]

   # symmetric part of g
   s12=0.5*(g[0][1]+g[1][0])
   s13=0.5*(g[0][2]+g[2][0])
   s23=0.5*(g[1][2]+g[2][1])
   s=[[g[0][0],s12,s13],
      [s12,g[1][1],s23],
      [s13,s23,g[2][2]]]

   # square of g_ij = g_ik g_kj
   g2=[[g[i][0]*g[0][j]+g[i][1]*g[1][j]+g[i][2]*g[2][j] for j in range(3)] for i in range(3)]
   gkk_2=(g2[0][0]+g2[1][1]+g2[2][2])/3.

   # symmetric part of g2 with gkk_2 removed from the diagonal
   sd12=0.5*(g2[0][1]+g2[1][0])
   sd13=0.5*(g2[0][2]+g2[2][0])
   sd23=0.5*(g2[1][2]+g2[2][1])
   sd=[[g2[0][0]-gkk_2,sd12,sd13],
       [sd12,g2[1][1]-gkk_2,sd23],
       [sd13,sd23,g2[2][2]-gkk_2]]

   sijsij=contract(s)
   sdijsdij=contract(sd)

   cm=(10.6*0.1**2)**0.5

   # the floor on the denominator avoids a division by zero
   term1=sdijsdij**1.5/xp.maximum((sijsij**2.5+sdijsdij**1.25),1e-10)

   # length scales
   rl_rans=0.41*dist3d
   rl_les=cmu*vol**0.3333333   # computed but not used below

   # difference of y2d along its second axis (first row dropped), replicated nk times along z
   dy=xp.repeat(xp.diff(y2d[1:,:],axis=1)[:,:,None],repeats=nk,axis=2)

   # l_les = min(max(0.15*dist3d, 0.15*delta_max, dy), delta_max)
   l_les=xp.minimum(xp.maximum(xp.maximum(0.15*dist3d,0.15*delta_max),dy),delta_max)

   rl=xp.minimum(rl_rans,cm*l_les)

   delta=vol**0.333333   # computed but not used below

   vis_new=modify_vis(rl**2*term1+viscos)

   # under-relaxation against the field passed in
   return urfvis*vis_new+(1.-urfvis)*vis3d

######################### the execution of the code starts here #############################

########### grid specification ###########
# x2d.dat / y2d.dat: the node coordinates followed by one trailing number,
# which is read as ni (x file) and nj (y file)
datax=xp.loadtxt("x2d.dat")
ni=int(datax[-1])
x=datax[0:-1]

datay=xp.loadtxt("y2d.dat")
nj=int(datay[-1])
y=datay[0:-1]

# node coordinates as (ni+1, nj+1) arrays
x2d=xp.reshape(x,(ni+1,nj+1))
y2d=xp.reshape(y,(ni+1,nj+1))

# z.dat: a negative second entry selects the non-equidistant layout
dataz=xp.loadtxt('z.dat')
if dataz[1] < 0:
   # non-equidistant grid in z: |second entry| is nk, node positions follow
   nk=int(xp.abs(dataz[1]))
   z=dataz[2:]
   zmax=dataz[-1]
   dz=xp.diff(z)
   # replicate dz over the first two dimensions: (nk,) -> (nj+1,nk) -> (ni+1,nj+1,nk)
   dz3d=xp.repeat(xp.repeat(dz[None,:], repeats=nj+1, axis=0)[None,:,:], repeats=ni+1, axis=0)
else:
   # equidistant grid in z: first entry is zmax, second is nk
   zmax=dataz[0]
   nk=int(dataz[1])
   dz=zmax/nk
   z=xp.linspace(0, zmax, nk+1)
   dz3d=xp.ones((ni+1,nj+1,nk))*dz

# cell centres: mean of the four surrounding nodes in x-y, midpoints in z
xp2d=0.25*(x2d[0:-1,0:-1]+x2d[0:-1,1:]+x2d[1:,0:-1]+x2d[1:,1:])
yp2d=0.25*(y2d[0:-1,0:-1]+y2d[0:-1,1:]+y2d[1:,0:-1]+y2d[1:,1:])
zp=0.5*(z[0:-1]+z[1:])


# initialize geometric arrays

# zero-filled placeholders, grouped by shape; every name gets its own array

# cell-sized arrays
vol,fx,fy,fz=(xp.zeros((ni,nj,nk)) for _ in range(4))

# face-sized arrays (one extra entry along the face-normal index)
areas,areasx,areasy=(xp.zeros((ni,nj+1,nk)) for _ in range(3))
areaw,areawx,areawy=(xp.zeros((ni+1,nj,nk)) for _ in range(3))
areal=xp.zeros((ni,nj,nk+1))

# boundary-plane arrays
as_bound,an_bound=(xp.zeros((ni,nk)) for _ in range(2))
aw_bound,ae_bound=(xp.zeros((nj,nk)) for _ in range(2))
al_bound,ah_bound=(xp.zeros((ni,nj)) for _ in range(2))

# ---- default settings (assigned before setup_case() is called further down)

# output
embedded=False
save_vtk_movie=False
vtk_file_name='my-movie'
save_average_z=True

# lower limits
k_min=eps_min=1e-6
om_min=1e-10

# linear solvers: the pyamgx variants are selected when gpu is set
solver_turb='pyamgx' if gpu else 'gmres'
solver_vel='pyamgx' if gpu else 'gmres'
solver_p='pyamgx_p' if gpu else 'pyamg'
amg_relax_phi=amg_relax='default'
amg_cycle_phi=amg_cycle='V'

# turbulence-model switches, all off
sst=k_eq_les=pans=keps=kom_des=keps_des=kom=wale=smag=False
fkmin_limit=0.1
cdes=0.67

# blocked region, all indices zero
i_block_start=i_block_end=j_block_start=j_block_end=0

norm_order=2
coeff_v=coeff_w=False

# ---- boundary conditions: zero-valued arrays, one separate array per name.
# west/east are (nj,nk), south/north are (ni,nk), low/high are (ni,nj).

# u
u_bc_west,u_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
u_bc_south,u_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
u_bc_low,u_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
u_bc_west_type=u_bc_east_type=u_bc_south_type=u_bc_north_type=u_bc_low_type=u_bc_high_type='d'

# v
v_bc_west,v_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
v_bc_south,v_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
v_bc_low,v_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
v_bc_west_type=v_bc_east_type=v_bc_south_type=v_bc_north_type=v_bc_low_type=v_bc_high_type='d'

# w
w_bc_west,w_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
w_bc_south,w_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
w_bc_low,w_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
w_bc_west_type=w_bc_east_type=w_bc_south_type=w_bc_north_type=w_bc_low_type=w_bc_high_type='d'

# p (the only variable whose default type is 'n')
p_bc_west,p_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
p_bc_south,p_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
p_bc_low,p_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
p_bc_west_type=p_bc_east_type=p_bc_south_type=p_bc_north_type=p_bc_low_type=p_bc_high_type='n'

# k
k_bc_west,k_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
k_bc_south,k_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
k_bc_low,k_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
k_bc_west_type=k_bc_east_type=k_bc_south_type=k_bc_north_type=k_bc_low_type=k_bc_high_type='d'

# eps
eps_bc_west,eps_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
eps_bc_south,eps_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
eps_bc_low,eps_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
eps_bc_west_type=eps_bc_east_type=eps_bc_south_type=eps_bc_north_type=eps_bc_low_type=eps_bc_high_type='d'

# om
om_bc_west,om_bc_east=xp.zeros((nj,nk)),xp.zeros((nj,nk))
om_bc_south,om_bc_north=xp.zeros((ni,nk)),xp.zeros((ni,nk))
om_bc_low,om_bc_high=xp.zeros((ni,nj)),xp.zeros((ni,nj))
om_bc_west_type=om_bc_east_type=om_bc_south_type=om_bc_north_type=om_bc_low_type=om_bc_high_type='d'

# ---- solver sweeps and convergence limits
nsweep_vel=nsweep_keps=nsweep_kom=50
convergence_limit_u=convergence_limit_v=convergence_limit_w=1e-5
convergence_limit_eps=convergence_limit_k=-1e-4
convergence_limit_p=convergence_limit_gpu=5e-5

# case set-up, echo of the input, then geometry
setup_case()

print_indata()

areaw,areawx,areawy,areas,areasx,areasy,areal,vol,fx,fy,fz,aw_bound,ae_bound,as_bound,an_bound,al_bound,ah_bound,dist3d=init()


# ---- initialization; every name receives its own freshly allocated array
itstep_stats_counter=0 # counter for time averaging

# solution fields
u3d,v3d,w3d,p3d=(xp.full((ni,nj,nk),1e-20) for _ in range(4))
k3d,eps3d,om3d=(xp.ones((ni,nj,nk)) for _ in range(3))
vis3d=xp.ones((ni,nj,nk))*viscos

# model functions and the production term start at one
fk3d,fmu3d,f1_sst,f2_sst,gen=(xp.ones((ni,nj,nk)) for _ in range(5))

# gradients kept from the previous time level
dpdx_old,dpdy_old,dpdz_old=(xp.full((ni,nj,nk),1e-20) for _ in range(3))

# face fluxes (one extra entry along the face-normal index)
convw=xp.full((ni+1,nj,nk),1e-20)
convs=xp.full((ni,nj+1,nk),1e-20)
convl=xp.full((ni,nj,nk+1),1e-20)

# accumulators for the statistics
u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,om3d_mean,eps3d_mean=(xp.full((ni,nj,nk),1e-20) for _ in range(7))
uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress=(xp.full((ni,nj,nk),1e-20) for _ in range(6))
fk3d_mean,vis3d_mean,gen_mean=(xp.full((ni,nj,nk),1e-20) for _ in range(3))

# coefficients, sources and two stored gradients
aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,apo3d=(xp.full((ni,nj,nk),1e-20) for _ in range(8))
su3d,sp3d,dudx,dudy=(xp.full((ni,nj,nk),1e-20) for _ in range(4))

# (nj,nk) inlet arrays
usynt_inlet,vsynt_inlet,wsynt_inlet=(xp.full((nj,nk),1e-20) for _ in range(3))


# delta_max for the LES/DES/PANS models: largest of vol/areas, vol/areaw and dz
delta_max=xp.maximum(xp.maximum(vol/areas[:,1:],vol/areaw[1:,:]),dz)

if solver_vel == 'pyamgx' or solver_p in ('pyamgx','pyamgx_p'):
   print('pymgx initialized')
   pyamgx.initialize()
   max_amg_iters = 2 #max iters for the amg preconditioners

   # The configuration is assembled inside-out: smoother -> AMG
   # preconditioner -> outer FGMRES solver.  Keys and values are unchanged.
   _cfg_p_smoother = {
       "relaxation_factor": 1,
       "scope": "jacobi",
       "solver": "JACOBI_L1"
   }

   _cfg_p_preconditioner = {
       "interpolator": "D2",
       "print_grid_stats": 1,
       "aggressive_levels": 1,
       "solver": "AMG",
       "smoother": _cfg_p_smoother,
       "presweeps": 1,
       "selector": "PMIS",
       "coarsest_sweeps": 1,
       "coarse_solver": "NOSOLVER",
       "max_iters": 1,
       "max_row_sum": 0.9,
       "strength_threshold": 0.25,
       "min_coarse_rows": 2,
       "scope": "amg_solver",
       "max_levels": 24,
       "cycle": "V",
       "postsweeps": 1
   }

   cfg_p = pyamgx.Config().create_from_dict({
       "config_version": 2,
       "solver": {
           "print_grid_stats": 1,
           "store_res_history": 1,
           "solver": "FGMRES",
           "print_solve_stats": 1,
           "obtain_timings": 1,
           "preconditioner": _cfg_p_preconditioner,
           "max_iters": 100,
           "monitor_residual": 1,
           "gmres_n_restart": 10,
           "convergence": "RELATIVE_INI_CORE",
           # the tolerance is the module-level convergence_limit_gpu
           "tolerance": convergence_limit_gpu,
           "norm": "L2"
       }
   })

   rsc = pyamgx.Resources().create_simple(cfg_p)
   solverx = pyamgx.Solver().create(rsc, cfg_p)

itstep=0
iter=0



# initialize
u3d,v3d,w3d,k3d,om3d,eps3d,vis3d,dist3d=modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d)


# read data for restart
if restart: 
   u3d,v3d,w3d,p3d,k3d,eps3d,om3d= read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

eps3d=xp.maximum(eps3d,eps_min)
k3d=xp.maximum(k3d,k_min)

# set inlet b.c. so the dudx is correctly computed  when vist_smag is called
u3d_face_w = xp.zeros((ni+1,nj,nk))
if not cyclic_x:
   u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw = modify_inlet()

u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_low,u_bc_high,\
    u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_low_type,u_bc_high_type,'u')
v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_low,v_bc_high,\
    v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_low_type,v_bc_high_type,'v')
w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_low,w_bc_high,\
    w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_low_type,w_bc_high_type,'w')
p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_low,p_bc_high,\
    p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_low_type,p_bc_high_type,'p')

if not cyclic_x:
   u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw = modify_inlet()


convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)

u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,dpdx_old,dpdy_old,dpdz_old=update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)

itstep=0
iter=0

if sst:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_sst(vis3d,k3d,om3d,f2_sst,gen)
   urfvis=urf_temp

if kom or kom_des:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_kom(vis3d,k3d,om3d)
   urfvis=urf_temp

if k_eq_les:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_k_eq(vis3d,k3d)
   urfvis=urf_temp

if pans or keps_des:
   fk3d=compute_fk(k3d,eps3d,fk3d)
# compute fmu3d
   gen=xp.zeros((ni,nj,nk))
   itstep=1
   itstep=0
   su3d,sp3d,fmu3d= calceps(su3d,sp3d,k3d,eps3d,vis3d,gen)
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
   urfvis=urf_temp

if smag:
   urf_temp=urfvis # no under-relaxation
   urfvis=1
   vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
   urfvis=urf_temp


iter=0
itstep=0


# find max index
#sumax=xp.max(su3d.flatten())
#print('[i,j,k]', xp.where(su3d == xp.amax(su3d)) 

# All residuals start at zero. Residuals of equations that are never solved for
# the chosen turbulence model keep this value and are still printed every iteration.
residual_u = residual_v = residual_w = residual_p = 0
residual_k = residual_eps = residual_om = 0


######################### start of time stepping  #############################

# Main solution loop.
#
# Outer loop: time steps itstep = 0 .. ntstep-1.
# Inner loop: at most maxit global iterations per time step. Each iteration solves,
# in turn, u3d, v3d, w3d and p3d, corrects the velocities and face fluxes, solves the
# equations of the active turbulence model and prints residuals, monitor values and
# timings. The inner loop is left as soon as at least min_iter iterations have been
# made and the largest scaled residual of u, v, w and p is below sormax.
# After the inner loop the "old" fields are updated, data are saved on request and
# the time-averaged statistics are accumulated.

# wall-clock timer, reported every 10th time step
time_10min = time.time()

# fix_block() is needed only when a blockage is used, i.e. when any of the four block
# limits is non-zero. The original test was repeated before every solve and checked
# i_block_end twice instead of j_block_end; it is evaluated once here.
# NOTE(review): assumes the block limits are not changed during the run -- confirm.
blockage_active = i_block_start != 0 or i_block_end != 0 or j_block_start != 0 or j_block_end != 0

for itstep in range(0,ntstep):

   ######################### start of global iteration process #############################

   for iter in range(0,maxit):

      start_time_iter = time.time()
      # coefficients for velocities
      start_time = time.time()
      # compute inlet fluctuations (first iteration of each time step only)
      if iter == 0 and not cyclic_x:
         u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw = modify_inlet()
      # boundary conditions for u3d
      su3d,sp3d=bc(su3d,sp3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_low,u_bc_high, \
                   u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_low_type,u_bc_high_type)
      if scheme == 'm':
         # coeff_m also returns the source contributions added to the v and w equations below
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d_local,su3d_v,su3d_w=coeff_m(convw,convs,convl,vis3d,u3d,v3d,w3d)
         su3d=su3d+su3d_local
      else:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,1,1,f1_sst,scheme)

      # ---------------- u3d ----------------
      su3d,sp3d=calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(u3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)

      # this is needed only when blockage is used
      if blockage_active:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('u')

      if solver_vel == 'pyamg':
         u3d,residual_u=solve_pyamg(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_u,acrank_conv)
      elif solver_vel == 'pyamgx':
         u3d,residual_u=solve_pyamgx(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_u,acrank_conv,'u')
      else:
         u3d,residual_u=solve_3d(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_u,nsweep_vel,acrank_conv,solver_vel,'u')
      print(f"{'time u: '}{time.time()-start_time:.2e}")
      start_time = time.time()

      # ---------------- v3d ----------------
      # boundary conditions for v3d
      su3d,sp3d=bc(su3d,sp3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_low,v_bc_high, \
                   v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_low_type,v_bc_high_type)
      if scheme == 'm':
         su3d=su3d+su3d_v
      su3d,sp3d=calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(v3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)

      # this is needed only when blockage is used
      if blockage_active:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('v')

      if solver_vel == 'pyamg':
         v3d,residual_v=solve_pyamg(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_v,acrank_conv)
      elif solver_vel == 'pyamgx':
         v3d,residual_v=solve_pyamgx(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_v,acrank_conv,'v')
      else:
         v3d,residual_v=solve_3d(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_v,nsweep_vel,acrank_conv,solver_vel,'v')
      print(f"{'time v: '}{time.time()-start_time:.2e}")

      start_time = time.time()

      # ---------------- w3d ----------------
      # boundary conditions for w3d
      su3d,sp3d=bc(su3d,sp3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_low,w_bc_high, \
                   w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_low_type,w_bc_high_type)
      if scheme == 'm':
         su3d=su3d+su3d_w
      su3d,sp3d=calcw(su3d,sp3d,dpdz_old,p3d_face_l)
      ap3d,su3d=crank_nicol(w3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)

      # this is needed only when blockage is used
      if blockage_active:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('w')

      if solver_vel == 'pyamg':
         w3d,residual_w=solve_pyamg(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_w,acrank_conv)
      elif solver_vel == 'pyamgx':
         w3d,residual_w=solve_pyamgx(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_w,acrank_conv,'w')
      else:
         w3d,residual_w=solve_3d(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_w,nsweep_vel,acrank_conv,solver_vel,'w')
      print(f"{'time w: '}{time.time()-start_time:.2e}")

      start_time = time.time()

      # ---------------- p3d ----------------
      convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)

      if not cyclic_x:
         convw,u_bc_east = modify_outlet(convw)


      # RHS: continuity error, i.e. the net face flux of every cell, divided by acrank and the time step
      su3d=(convw[0:-1,:,:]-convw[1:,:,:]\
           +convs[:,0:-1,:]-convs[:,1:,:]\
           +convl[:,:,0:-1]-convl[:,:,1:])/acrank/dt[itstep]

      # the pressure coefficients are computed only in the very first iteration of the run
      if iter == 0 and itstep == 0:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d=calcp(convw,convs,convl)
         # this is needed only when blockage is used
         if blockage_active:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('p')
         # coefficients are (maybe) modified in 'fix_block': re-compute ap3d
         ap3d=aw3d+ae3d+as3d+an3d+al3d+ah3d

         # keep the pressure coefficients that correct_conv() needs in every iteration
         aw3d_p=aw3d
         as3d_p=as3d
         al3d_p=al3d

      print('xp.sum(su3d)',xp.sum(su3d))

      if solver_p == 'pyamg':
         p3d=solve_p(p3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_p)
      elif solver_p == 'pyamgx':
         # the A matrix is re-computed in solve_pyamgx every time
         aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d=calcp(convw,convs,convl)
         p3d,dummy=solve_pyamgx(p3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_p,1,'p')

      elif solver_p == 'pyamgx_p':
         # the A matrix is computed only once (requires more GPU memory)
         p3d=solve_px(p3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_p)
      else:
         print('solver_p = ',solver_p)
         print('no such solver')
         sys.exit()

      # correct u, v, w, p
      convw,convs,convl,p3d,u3d,v3d,w3d,su3d= correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p)

      # pressure residual: L1 norm of the su3d array returned by correct_conv()
      res_1d=su3d.flatten()
      residual_p=xp.linalg.norm(res_1d,ord=1)

      # face values of the corrected fields
      u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_low,u_bc_high,\
          u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_low_type,u_bc_high_type,'u')
      v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_low,v_bc_high,\
          v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_low_type,v_bc_high_type,'v')
      w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_low,w_bc_high,\
          w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_low_type,w_bc_high_type,'w')
      p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_low,p_bc_high,\
          p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_low_type,p_bc_high_type,'p')


      print(f"{'time p: '}{time.time()-start_time:.2e}")

      start_time = time.time()

      # ---------------- k and omega (sst, kom, kom_des) ----------------
      if sst or kom or kom_des:
         if sst:
            vis3d=vist_sst(vis3d,k3d,om3d,f2_sst,gen)
         else:
            vis3d=vist_kom(vis3d,k3d,om3d)
         # boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_low,k_bc_high, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_low_type,k_bc_high_type)
         # coefficients
         if sst:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k_sst_1,prand_k_sst_2,f1_sst,scheme_turb)
         else:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,prand_k,f1_sst,scheme_turb)
         # k
         su3d,sp3d,gen,comm_term=calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)


         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         # this is needed only when blockage is used
         if blockage_active:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('k')

         if solver_turb == 'pyamg':
            k3d,residual_k=solve_pyamg(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_kom)
         elif solver_turb == 'pyamgx':
            k3d,residual_k=solve_pyamgx(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_kom,'k')
         else:
            k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_kom,acrank_conv_kom,solver_turb,'k')
         print(f"{'time k: '}{time.time()-start_time:.2e}")

         # clip k from below
         k3d=xp.maximum(k3d,k_min)

         start_time = time.time()
         # omega
         # boundary conditions for om3d
         su3d,sp3d=bc(su3d,sp3d,om_bc_west,om_bc_east,om_bc_south,om_bc_north,om_bc_low,om_bc_high, \
                   om_bc_west_type,om_bc_east_type,om_bc_south_type,om_bc_north_type,om_bc_low_type,om_bc_high_type)
         if sst:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_omega_sst_1,prand_omega_sst_2,f1_sst,scheme_turb)
         else:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_omega,prand_omega,f1_sst,scheme_turb)
         if sst:
            su3d,sp3d,f1_sst,f2_sst= calcom_sst(su3d,sp3d,om3d,gen,comm_term)
         else:
            su3d,sp3d= calcom(su3d,sp3d,om3d,gen,comm_term)
         ap3d,su3d=crank_nicol(om3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         # this is needed only when blockage is used
         if blockage_active:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('om')

         aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_omega()

         if solver_turb == 'pyamg':
            om3d,residual_om=solve_pyamg(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,acrank_conv_kom)
         elif solver_turb == 'pyamgx':
            om3d,residual_om=solve_pyamgx(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,acrank_conv_kom,'om')
         else:
            om3d,residual_om=solve_3d(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,nsweep_kom,acrank_conv_kom,solver_turb,'om')
         # clip omega from below
         om3d=xp.maximum(om3d,om_min)

         print(f"{'time omega: '}{time.time()-start_time:.2e}")

         start_time = time.time()

      # ---------------- algebraic viscosity models ----------------
      if smag:
         vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
      if wale:
         vis3d=vist_wale(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)

      # printed in every iteration, also when neither smag nor wale is active
      print(f"{'time Smag or Wale: '}{time.time()-start_time:.2e}")
      start_time = time.time()

      # ---------------- k and eps (pans, keps, keps_des) or k only (k_eq_les) ----------------
      if pans or keps or keps_des or k_eq_les:
         if k_eq_les:
            vis3d=vist_k_eq(vis3d,k3d)
         else:
            vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
         # coefficients
         start_time = time.time()
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,prand_k,f1_sst,scheme_turb)
         # k
         # boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_low,k_bc_high, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_low_type,k_bc_high_type)
         su3d,sp3d,gen,dudx,dudy=calck(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)

         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

         # this is needed only when blockage is used
         if blockage_active:
            aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('k')

         aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_k()

         if solver_turb == 'pyamg':
            k3d,residual_k=solve_pyamg(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_keps)
         elif solver_turb == 'pyamgx':
            k3d,residual_k=solve_pyamgx(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_keps,'k')
         else:
            k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_keps,acrank_conv_keps,solver_turb,'k')


         # clip k from below
         k3d=xp.maximum(k3d,k_min)
         print(f"{'time k: '}{time.time()-start_time:.2e}")
         start_time = time.time()

         # the eps equation is skipped for the one-equation model (k_eq_les)
         if not k_eq_les:
            start_time  = time.time()
            # boundary conditions for eps3d
            su3d,sp3d=bc(su3d,sp3d,eps_bc_west,eps_bc_east,eps_bc_south,eps_bc_north,eps_bc_low,eps_bc_high, \
                   eps_bc_west_type,eps_bc_east_type,eps_bc_south_type,eps_bc_north_type,eps_bc_low_type,eps_bc_high_type)
            # eps
            aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_eps,prand_eps,f1_sst,scheme_turb)
            su3d,sp3d,fmu3d= calceps(su3d,sp3d,k3d,eps3d,vis3d,gen)

            ap3d,su3d=crank_nicol(eps3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

            # this is needed only when blockage is used
            if blockage_active:
               aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_block('eps')

            aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d=fix_eps()

            if solver_turb == 'pyamg':
               eps3d,residual_eps=solve_pyamg(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,acrank_conv_keps)
            elif solver_turb == 'pyamgx':
               eps3d,residual_eps=solve_pyamgx(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,acrank_conv_keps,'eps')
            else:
               eps3d,residual_eps=solve_3d(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,nsweep_keps,acrank_conv_keps,solver_turb,'eps')

            # clip eps from below
            eps3d=xp.maximum(eps3d,eps_min)

            print(f"{'time eps: '}{time.time()-start_time:.2e}")

      if pans or sst or kom_des or keps_des:
         fk3d=compute_fk(k3d,eps3d,fk3d)


      # scale residuals
      residual_u=residual_u/resnorm_vel
      residual_v=residual_v/resnorm_vel
      residual_w=residual_w/resnorm_vel
      residual_p=residual_p/resnorm_p
      residual_k=residual_k/resnorm_vel**2
      residual_eps=residual_eps/resnorm_vel**3
      residual_om=residual_om/resnorm_vel

      # only u, v, w and p enter the convergence test; k, eps and omega are reported only
#     resmax=xp.max([residual_u ,residual_v,residual_p,residual_k,residual_eps,residual_om])
      resmax=xp.max(xp.array([residual_u ,residual_v,residual_w,residual_p]))

      print(f"\n{'--time step:'}{itstep:4d}, {'iter:'}{iter:d}, {'max residual:'}{resmax:.2e}, {'u:'}{residual_u:.2e}\
, {'v:'}{residual_v:.2e}, {'w:'}{residual_w:.2e}, {'p:'}{residual_p:.2e}, {'k:'}{residual_k:.2e}\
, {'eps:'}{residual_eps:.2e}, {'om:'}{residual_om:.2e}\n")

      # values at the monitor cell (imon,jmon,kmon)
      print(f"\n{'monitor time step:'}{itstep:4d}, {'iter:'}{iter:1d}, {'u:'}{u3d[imon,jmon,kmon]: .2e}\
, {'v:'}{v3d[imon,jmon,kmon]: .2e}, {'w:'}{w3d[imon,jmon,kmon]: .2e}, {'p:'}{p3d[imon,jmon,kmon]: .2e}\
, {'k:'}{k3d[imon,jmon,kmon]: .2e}, {'eps:'}{eps3d[imon,jmon,kmon]: .2e}, {'om:'}{om3d[imon,jmon,kmon]: .2e}\n")



      # global extrema printed below (vismax is normalised with viscos)
      vismax=xp.max(vis3d.flatten())/viscos
      umax=xp.max(u3d.flatten())
      fkmax=xp.max(fk3d.flatten())
      epsmin=xp.min(eps3d.flatten())
      ommin=xp.min(om3d.flatten())


      kmin=xp.min(k3d.flatten())

      # CFL numbers in x and y, reported every 10th time step
      if itstep%10 == 0:
         cfl_x=xp.abs(u3d)*dt[itstep]*areaw[1:,:,:]/vol
         cfl_y=xp.abs(v3d)*dt[itstep]*areas[:,1:,:]/vol
         cfl_x_max=xp.max(cfl_x)
         cfl_y_max=xp.max(cfl_y)

         # number of points larger than one
         cfl_y_no=cfl_y[xp.where( cfl_y > 1 )]
         cfl_x_no=cfl_x[xp.where( cfl_x > 1 )]

         # location of the largest cfl_y
         [i,j,k]= xp.where(cfl_y== xp.amax(cfl_y))
         # check if i is an array (if several cells share the maximum, report the first one)
         if xp.size(i):
            ii=i[0]
            jj=j[0]
         else:
            ii=i
            jj=j
         print(f"\n{'-- cfl_max_x: '}{cfl_x_max:.2e}, {'cfl_max_y: '}{cfl_y_max:.2e}, {'at i= '}{ii:2d}, {'j= '}{jj:2d}\n")

         print(f"\n{'No of points cfl_x > 1: '}{len(cfl_x_no):2d}, {'No of points cfl_y > 1: '}{len(cfl_y_no):2d}\n")


      print(f"\n{'--time step: '} {dt[itstep]:.2e}, {'iter: '}{iter:2d}, {'umax: '}{umax:.2e}, {'vismax: '}{vismax:.2e}, {'kmin: '}{kmin:.2e}, {'epsmin: '}{epsmin:.2e}, {'ommin: '}{ommin:.2e}, {'fkmax: '}{fkmax:.2e}\n")

      print(f"{'time one iteration: '}{time.time()-start_time_iter:.2e}")

      # every 10th time step
      if itstep%10 == 0 and iter == 0:
         print(f"{'time every 10th time steps: '}{time.time()-time_10min:.2e}")
         time_10min = time.time()

      # converged: at least min_iter iterations done and max residual below sormax
      if iter >= min_iter-1 and resmax < sormax:

         break

   ######################### end of global iteration process #############################

   # the converged fields become the "old" fields of the next time step
   u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,dpdx_old,dpdy_old,dpdz_old=\
     update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)
   # save time-averaged data every itstep_save time step, once itstep_start is reached
   if itstep%itstep_save == 0 and itstep >= itstep_start:
      save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                          vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,gen_mean)
   # save instantaneous data every itstep_save time step (not at the first time step)
   if save and itstep%itstep_save == 0 and itstep > 0:
      save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)
      if vtk and save_vtk_movie:
         save_vtk()
   # every second time step: save data on request, i.e. if the file 'save.file' contains 1
   if save and itstep%2 == 0:
      isave = xp.loadtxt('save.file')
      if isave == 1:
         save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)
         save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                          vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,gen_mean)

   # accumulate time-averaged statistics every itstep_stats time step, once itstep_start is reached
   if itstep >= itstep_start and itstep % itstep_stats == 0:
      u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,\
          fk3d_mean,vis3d_mean,gen_mean= time_stats(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,\
          vv3d_stress,ww3d_stress,uv3d_stress,uw3d_stress,vw3d_stress,fk3d_mean,vis3d_mean,gen_mean)

   vismean_mean=xp.max(vis3d_mean.flatten())/viscos/(itstep+1)

   print('vismean_mean',vismean_mean)

######################### end of time stepping  #############################
      
# Write the instantaneous fields so that the run can be restarted.
if save:
   save_data(
      u3d, v3d, w3d, p3d,
      k3d, eps3d, om3d,
   )

# Final VTK output; itstep is set to ntstep before save_vtk() is called.
if vtk:
   itstep = ntstep
   save_vtk()

# Time-averaged fields and stresses are always written, regardless of the 'save' flag.
save_time_aver_data(
   u3d_mean, v3d_mean, w3d_mean, p3d_mean,
   eps3d_mean, om3d_mean, fk3d_mean, vis3d_mean, k3d_mean,
   uu3d_stress, vv3d_stress, ww3d_stress,
   uv3d_stress, uw3d_stress, vw3d_stress,
   gen_mean,
)
print('program reached normal stop')

