from scipy import sparse
import numpy as np
import sys
import time
import pyamg
from scipy.sparse import spdiags,linalg,eye

def init():
    """Compute the geometric quantities of the grid from module-level data.

    Reads the globals ``x2d``/``y2d`` and ``xp2d``/``yp2d`` (2D coordinate
    arrays; judging from the slicing, the former are corner values and the
    latter cell-centre values), plus ``nk``, ``dz`` and ``cyclic_x``.

    Returns:
        A 16-tuple ``(areaw, areawx, areawy, areas, areasx, areasy, areaz,
        vol, fx, fy, aw_bound, ae_bound, as_bound, an_bound, az_bound,
        dist3d)``.  The 2D fields are replicated ``nk`` times along a third
        axis (``areaz``: ``nk+1`` times); ``az_bound`` stays 2D.
    """

    def extrude(field2d, layers):
        # replicate a 2D field `layers` times along a new third axis
        return np.dstack([field2d] * layers)

    # ---- distance to the nearest wall (south: first j-line, north: last)
    wall_south = 0.5 * (y2d[0:-1, 0] + y2d[1:, 0])
    wall_north = 0.5 * (y2d[0:-1, -1] + y2d[1:, -1])
    dist_south = yp2d - wall_south[:, None]
    dist_north = wall_north[:, None] - yp2d
    dist3d = np.repeat(np.minimum(dist_south, dist_north)[:, :, None],
                       repeats=nk, axis=2)

    # ---- west faces: mid-point coordinates and interpolation weight fx
    x_west = 0.5 * (x2d[0:-1, 0:-1] + x2d[0:-1, 1:])
    y_west = 0.5 * (y2d[0:-1, 0:-1] + y2d[0:-1, 1:])
    to_own_w = ((x_west - xp2d)**2 + (y_west - yp2d)**2)**0.5
    to_nbr_w = ((x_west - np.roll(xp2d, 1, axis=0))**2
                + (y_west - np.roll(yp2d, 1, axis=0))**2)**0.5
    fx = extrude(to_nbr_w / (to_own_w + to_nbr_w), nk)
    if cyclic_x:
        # the rolled neighbour of the first row is the last row; use 0.5
        fx[0, :, :] = 0.5

    # ---- south faces: mid-point coordinates and interpolation weight fy
    x_south = 0.5 * (x2d[0:-1, 0:-1] + x2d[1:, 0:-1])
    y_south = 0.5 * (y2d[0:-1, 0:-1] + y2d[1:, 0:-1])
    to_own_s = ((x_south - xp2d)**2 + (y_south - yp2d)**2)**0.5
    to_nbr_s = ((x_south - np.roll(xp2d, 1, axis=1))**2
                + (y_south - np.roll(yp2d, 1, axis=1))**2)**0.5
    fy = extrude(to_nbr_s / (to_own_s + to_nbr_s), nk)

    # ---- coordinate increments along each grid direction
    ax = np.diff(x2d, axis=1)
    ay = np.diff(y2d, axis=1)
    bx = np.diff(x2d, axis=0)
    by = np.diff(y2d, axis=0)

    # ---- face-area components (extruded to 3D) and magnitudes
    areawx = extrude(-ay * dz, nk)
    areawy = extrude(ax * dz, nk)
    areasx = extrude(by * dz, nk)
    areasy = extrude(-bx * dz, nk)
    areaw = (areawx**2 + areawy**2)**0.5
    areas = (areasx**2 + areasy**2)**0.5

    # ---- cell area in the x-y plane, approximated as the sum of two
    #      triangles (half the magnitude of a cross product each)
    tri_1 = 0.5 * np.absolute(ax[0:-1, :] * by[:, 0:-1] - ay[0:-1, :] * bx[:, 0:-1])
    tri_2 = 0.5 * np.absolute(ax[1:, :] * by[:, 0:-1] - ay[1:, :] * bx[:, 0:-1])
    area_xy = tri_1 + tri_2

    vol = extrude(area_xy * dz, nk)
    areaz = extrude(area_xy, nk + 1)

    # ---- boundary coefficients (without viscosity): area**2 / (half volume)
    as_bound = areas[:, 0, :]**2 / (0.5 * vol[:, 0, :])
    an_bound = areas[:, -1, :]**2 / (0.5 * vol[:, -1, :])
    if cyclic_x:
        aw_bound = areaw[0, :, :]**2 / (0.5 * (vol[0, :, :] + vol[-1, :, :]))
    else:
        aw_bound = areaw[0, :, :]**2 / (0.5 * vol[0, :, :])
    # east coefficient (the original comment notes it is never used)
    ae_bound = areaw[-1, :, :]**2 / (0.5 * vol[-1, :, :])
    # z boundaries: 2D, wall node located at the boundary
    az_bound = areaz[:, :, 0] / (0.5 * dz)

    return (areaw, areawx, areawy, areas, areasx, areasy, areaz, vol, fx, fy,
            aw_bound, ae_bound, as_bound, an_bound, az_bound, dist3d)

def print_indata():
   """Print a formatted summary of the module-level input parameters.

   Every value printed is read from a global of the same name as its
   label; nothing is returned.  Several groups are only printed when the
   corresponding switch (``keps``, ``pans``, ``kom``, ``kom_des``,
   ``cyclic_x``, ``cyclic_z`` or a ``*_bc_*_type == 'd'`` test) is set, so
   the guarded globals need not exist when their guard is false.
   """

   print('Start of input data\n\n\n')

   print('\n\n########### section 1 choice of differencing scheme ###########')
   print(f"{'scheme: ':<29}   {scheme}")
   print(f"{'scheme_turb: ':<29}   {scheme_turb}")
   print(f"{'acrank: ':<29}   {acrank}")
   print(f"{'acrank_conv: ':<29}   {acrank_conv}")
   print(f"{'acrank_conv_keps: ':<29}   {acrank_conv_keps}")
   print(f"{'acrank_conv_kom: ':<29}   {acrank_conv_kom}")

   print('\n\n########### section 2 turbulence models ###########')

   print(f"{'cmu: ':<29} {cmu}")
   print(f"{'pans: ':<29} {pans}")
   print(f"{'keps: ':<29} {keps}")
   print(f"{'kom_des: ':<29} {kom_des}")
   # jl0 is only shown when it is negative
   if jl0 < 0:
      print(f"{'jl0: ':<29} {jl0}")
   print(f"{'kom: ':<29} {kom}")
   print(f"{'wale: ':<29} {wale}")
   print(f"{'smag: ':<29} {smag}")
   if keps or pans:
      print(f"{'c_eps_1: ':<29} {c_eps_1}")
      print(f"{'c_eps_2: ':<29} {c_eps_2}")
      print(f"{'prand_k: ':<29} {prand_k}")
   # NOTE(review): prand_eps is printed unconditionally (outside the
   # `keps or pans` guard above) -- confirm this is intended.
   print(f"{'prand_eps: ':<29} {prand_eps}")
   if kom or kom_des:
      print(f"{'c_omega_1: ':<29} {c_omega_1:.3f}")
      print(f"{'c_omega_2: ':<29} {c_omega_2}")
      print(f"{'prand_k: ':<29} {prand_k}")
      print(f"{'prand_omega: ':<29} {prand_omega}")


   print('\n\n########### section 3 restart/save ###########')
   print(f"{'restart: ':<29} {restart}")
   print(f"{'save: ':<29} {save}")


   print('\n\n########### section 4 fluid properties ###########')
   print(f"{'viscos: ':<29} {viscos}")

   print('\n\n########### section 5 relaxation factors ###########')
   print(f"{'urfvis: ':<29} {urfvis}")


   print('\n\n########### section 6 number of iteration and convergence criterira ###########')
   print(f"{'sormax: ':<29} {sormax}")
   print(f"{'maxit: ':<29} {maxit}")
   print(f"{'min_iter: ':<29} {min_iter}")
   print(f"{'amg_relax: ':<29} {amg_relax}")
   print(f"{'amg_cycle: ':<29} {amg_cycle}")
   print(f"{'solver_vel: ':<29} {solver_vel}")
   print(f"{'solver_turb: ':<29} {solver_turb}")
   print(f"{'nsweep_vel: ':<29} {nsweep_vel}")
   print(f"{'nsweep_keps: ':<29} {nsweep_keps}")
   print(f"{'nsweep_kom: ':<29} {nsweep_kom}")
   print(f"{'convergence_limit_u: ':<29} {convergence_limit_u}")
   print(f"{'convergence_limit_v: ':<29} {convergence_limit_v}")
   print(f"{'convergence_limit_w: ':<29} {convergence_limit_w}")
   print(f"{'convergence_limit_p: ':<29} {convergence_limit_p}")
   # turbulence-model limits are shown only for the active model(s)
   if keps or pans or kom or kom_des:
      print(f"{'convergence_limit_k: ':<29} {convergence_limit_k}")
   if keps or pans:
      print(f"{'convergence_limit_eps: ':<29} {convergence_limit_eps}")
   if kom or kom_des:
      print(f"{'convergence_limit_om: ':<29} {convergence_limit_om}")


   print('\n\n########### section 7 all variables are printed during the iteration at node ###########')
   print(f"{'imon: ':<29} {imon}")
   print(f"{'jmon: ':<29} {jmon}")
   print(f"{'kmon: ':<29} {kmon}")


   print('\n\n########### section 8 time-averaging ###########')
   print(f"{'ntstep: ':<29} {ntstep}")
   # only the first entry of dt is shown
   print(f"{'dt[0]: ':<29} {dt[0]:.2e}")
   print(f"{'itstep_start: ':<29} {itstep_start}")
   print(f"{'itstep_save: ':<29} {itstep_save}")
   print(f"{'itstep_stats: ':<29} {itstep_stats}")


   print('\n\n########### section 9 residual scaling parameters ###########')
   print(f"{'resnorm_p: ':<29} {resnorm_p:.1f}")
   print(f"{'resnorm_vel: ':<29} {resnorm_vel:.1f}")


   print('\n\n########### Section 10 grid and boundary conditions ###########')
   print(f"{'ni: ':<29} {ni}")
   print(f"{'nj: ':<29} {nj}")
   print(f"{'nk: ':<29} {nk}")
   print('\n')
   print(f"{'cyclic_x: ':<29} {cyclic_x}")
   print(f"{'cyclic_z: ':<29} {cyclic_z}")
   print('\n')
   print(f"{'L_t_synt: ':<29} {L_t_synt}")
   print(f"{'nmodes_synt: ':<29} {nmodes_synt}")
   print(f"{'jmirror_synt: ':<29} {jmirror_synt}")
   print('\n')

   # For every variable below the same pattern is used: west/east entries
   # only when not cyclic_x, z entries only when not cyclic_z, and a
   # boundary value is shown only when the boundary type equals 'd'.
   # For the west/east/south/north arrays only element [0,0] is printed.
   print('------boundary conditions for u')
   if not cyclic_x:
      print(f"{' ':<5}{'u_bc_west_type: ':<29} {u_bc_west_type}")
      print(f"{' ':<5}{'u_bc_east_type: ':<29} {u_bc_east_type}")
      if u_bc_west_type == 'd':
         print(f"{' ':<5}{'u_bc_west[0,0]: ':<29} {u_bc_west[0,0]}")
      if u_bc_east_type == 'd':
         print(f"{' ':<5}{'u_bc_east[0,0]: ':<29} {u_bc_east[0,0]}")


   print(f"{' ':<5}{'u_bc_south_type: ':<29} {u_bc_south_type}")
   print(f"{' ':<5}{'u_bc_north_type: ':<29} {u_bc_north_type}")

   if u_bc_south_type == 'd':
      print(f"{' ':<5}{'u_bc_south[0,0]: ':<29} {u_bc_south[0,0]}")
   if u_bc_north_type == 'd':
      print(f"{' ':<5}{'u_bc_north[0,0]: ':<29} {u_bc_north[0,0]}")

   if not cyclic_z:
      print(f"{' ':<5}{'u_bc_z_type: ':<29} {u_bc_z_type}")
      if u_bc_z_type == 'd':
         print(f"{' ':<5}{'u_bc_z: ':<29} {u_bc_z}")

   print('------boundary conditions for v')
   if not cyclic_x:
      print(f"{' ':<5}{'v_bc_west_type: ':<29} {v_bc_west_type}")
      print(f"{' ':<5}{'v_bc_east_type: ':<29} {v_bc_east_type}")
      if v_bc_west_type == 'd':
         print(f"{' ':<5}{'v_bc_west[0,0]: ':<29} {v_bc_west[0,0]}")
      if v_bc_east_type == 'd':
         print(f"{' ':<5}{'v_bc_east[0,0]: ':<29} {v_bc_east[0,0]}")


   print(f"{' ':<5}{'v_bc_south_type: ':<29} {v_bc_south_type}")
   print(f"{' ':<5}{'v_bc_north_type: ':<29} {v_bc_north_type}")

   if v_bc_south_type == 'd':
      print(f"{' ':<5}{'v_bc_south[0,0]: ':<29} {v_bc_south[0,0]}")
   if v_bc_north_type == 'd':
      print(f"{' ':<5}{'v_bc_north[0,0]: ':<29} {v_bc_north[0,0]}")

   if not cyclic_z:
      print(f"{' ':<5}{'v_bc_z_type: ':<29} {v_bc_z_type}")
      if v_bc_z_type == 'd':
         print(f"{' ':<5}{'v_bc_z: ':<29} {v_bc_z}")

   print('------boundary conditions for w')
   if not cyclic_x:
      print(f"{' ':<5}{'w_bc_west_type: ':<29} {w_bc_west_type}")
      print(f"{' ':<5}{'w_bc_east_type: ':<29} {w_bc_east_type}")
      if w_bc_west_type == 'd':
         print(f"{' ':<5}{'w_bc_west[0,0]: ':<29} {w_bc_west[0,0]}")
      if w_bc_east_type == 'd':
         print(f"{' ':<5}{'w_bc_east[0,0]: ':<29} {w_bc_east[0,0]}")


   print(f"{' ':<5}{'w_bc_south_type: ':<29} {w_bc_south_type}")
   print(f"{' ':<5}{'w_bc_north_type: ':<29} {w_bc_north_type}")

   if w_bc_south_type == 'd':
      print(f"{' ':<5}{'w_bc_south[0,0]: ':<29} {w_bc_south[0,0]}")
   if w_bc_north_type == 'd':
      print(f"{' ':<5}{'w_bc_north[0,0]: ':<29} {w_bc_north[0,0]}")

   if not cyclic_z:
      print(f"{' ':<5}{'w_bc_z_type: ':<29} {w_bc_z_type}")
      if w_bc_z_type == 'd':
         print(f"{' ':<5}{'w_bc_z: ':<29} {w_bc_z}")

   print('------boundary conditions for p')
   if not cyclic_x:
      print(f"{' ':<5}{'p_bc_west_type: ':<29} {p_bc_west_type}")
      print(f"{' ':<5}{'p_bc_east_type: ':<29} {p_bc_east_type}")
      if p_bc_west_type == 'd':
         print(f"{' ':<5}{'p_bc_west[0,0]: ':<29} {p_bc_west[0,0]}")
      if p_bc_east_type == 'd':
         print(f"{' ':<5}{'p_bc_east[0,0]: ':<29} {p_bc_east[0,0]}")


   print(f"{' ':<5}{'p_bc_south_type: ':<29} {p_bc_south_type}")
   print(f"{' ':<5}{'p_bc_north_type: ':<29} {p_bc_north_type}")

   if p_bc_south_type == 'd':
      print(f"{' ':<5}{'p_bc_south[0,0]: ':<29} {p_bc_south[0,0]}")
   if p_bc_north_type == 'd':
      print(f"{' ':<5}{'p_bc_north[0,0]: ':<29} {p_bc_north[0,0]}")

   if not cyclic_z:
      print(f"{' ':<5}{'p_bc_z_type: ':<29} {p_bc_z_type}")
      if p_bc_z_type == 'd':
         print(f"{' ':<5}{'p_bc_z: ':<29} {p_bc_z}")

   # NOTE(review): unlike the eps and omega headers below, the k header is
   # printed even when no k-based model is active -- confirm intended.
   print('------boundary conditions for k')
   if kom or kom_des or keps or pans:
      if not cyclic_x:
         print(f"{' ':<5}{'k_bc_west_type: ':<29} {k_bc_west_type}")
         print(f"{' ':<5}{'k_bc_east_type: ':<29} {k_bc_east_type}")
         if k_bc_west_type == 'd':
            print(f"{' ':<5}{'k_bc_west[0,0]: ':<29} {k_bc_west[0,0]}")
         if k_bc_east_type == 'd':
            print(f"{' ':<5}{'k_bc_east[0,0]: ':<29} {k_bc_east[0,0]}")
   
   
      print(f"{' ':<5}{'k_bc_south_type: ':<29} {k_bc_south_type}")
      print(f"{' ':<5}{'k_bc_north_type: ':<29} {k_bc_north_type}")
   
      if k_bc_south_type == 'd':
         print(f"{' ':<5}{'k_bc_south[0,0]: ':<29} {k_bc_south[0,0]}")
      if k_bc_north_type == 'd':
         print(f"{' ':<5}{'k_bc_north[0,0]: ':<29} {k_bc_north[0,0]}")
   
      if not cyclic_z:
         print(f"{' ':<5}{'k_bc_z_type: ':<29} {k_bc_z_type}")
         if k_bc_z_type == 'd':
            print(f"{' ':<5}{'k_bc_z: ':<29} {k_bc_z}")


   if keps or pans:
      print('------boundary conditions for eps')
      if not cyclic_x:
         print(f"{' ':<5}{'eps_bc_west_type: ':<29} {eps_bc_west_type}")
         print(f"{' ':<5}{'eps_bc_east_type: ':<29} {eps_bc_east_type}")
         if eps_bc_west_type == 'd':
            print(f"{' ':<5}{'eps_bc_west[0,0]: ':<29} {eps_bc_west[0,0]}")
         if eps_bc_east_type == 'd':
            print(f"{' ':<5}{'eps_bc_east[0,0]: ':<29} {eps_bc_east[0,0]}")
   
   
      print(f"{' ':<5}{'eps_bc_south_type: ':<29} {eps_bc_south_type}")
      print(f"{' ':<5}{'eps_bc_north_type: ':<29} {eps_bc_north_type}")
   
      if eps_bc_south_type == 'd':
         print(f"{' ':<5}{'eps_bc_south[0,0]: ':<29} {eps_bc_south[0,0]}")
      if eps_bc_north_type == 'd':
         print(f"{' ':<5}{'eps_bc_north[0,0]: ':<29} {eps_bc_north[0,0]}")
   
      if not cyclic_z:
         print(f"{' ':<5}{'eps_bc_z_type: ':<29} {eps_bc_z_type}")
         if eps_bc_z_type == 'd':
            # the .1f format requires eps_bc_z to be a scalar number here
            print(f"{' ':<5}{'eps_bc_z: ':<29} {eps_bc_z:.1f}")

   if kom or kom_des:
      print('------boundary conditions for omega')
      if not cyclic_x:
         print(f"{' ':<5}{'om_bc_west_type: ':<29} {om_bc_west_type}")
         print(f"{' ':<5}{'om_bc_east_type: ':<29} {om_bc_east_type}")
         if om_bc_west_type == 'd':
            print(f"{' ':<5}{'om_bc_west[0,0]: ':<29} {om_bc_west[0,0]:.1f}")
         if om_bc_east_type == 'd':
            print(f"{' ':<5}{'om_bc_east[0,0]: ':<29} {om_bc_east[0,0]:.1f}")
   
   
      print(f"{' ':<5}{'om_bc_south_type: ':<29} {om_bc_south_type}")
      print(f"{' ':<5}{'om_bc_north_type: ':<29} {om_bc_north_type}")
   
      if om_bc_south_type == 'd':
         print(f"{' ':<5}{'om_bc_south[0,0]: ':<29} {om_bc_south[0,0]:.1f}")
      if om_bc_north_type == 'd':
         print(f"{' ':<5}{'om_bc_north[0,0]: ':<29} {om_bc_north[0,0]:.1f}")
   
      if not cyclic_z:
         print(f"{' ':<5}{'om_bc_z_type: ':<29} {om_bc_z_type}")
         if om_bc_z_type == 'd':
            print(f"{' ':<5}{'om_bc_z: ':<29} {om_bc_z:.1f}")


   print('\n\n\nEnd of input data')



   return 

def compute_face_phi(phi3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_z,
    phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_z_type):
    """Interpolate a cell-centred field to the west, south and low faces.

    Uses the module-level sizes ``ni``, ``nj``, ``nk``, the interpolation
    weights ``fx``, ``fy`` and the switches ``cyclic_x``, ``cyclic_z``.

    Args:
        phi3d: cell-centred field of shape ``(ni, nj, nk)``.
        phi_bc_west, phi_bc_east, phi_bc_south, phi_bc_north, phi_bc_z:
            boundary values, assigned to the boundary faces unless the
            matching type selects another treatment.
        phi_bc_*_type: ``'n'`` copies the adjacent cell value to the face
            (zero gradient); for south/north ``'2'`` extrapolates linearly
            (``1.5*phi[first] - 0.5*phi[second]``); any other value keeps
            the prescribed boundary value.  ``phi_bc_z_type`` applies to
            both the low and the high boundary.

    Returns:
        ``(phi3d_face_w, phi3d_face_s, phi3d_face_l)`` with shapes
        ``(ni+1, nj, nk)``, ``(ni, nj+1, nk)`` and ``(ni, nj, nk+1)``.
    """
    phi3d_face_w = np.empty((ni+1, nj, nk))
    phi3d_face_s = np.empty((ni, nj+1, nk))
    phi3d_face_l = np.empty((ni, nj, nk+1))

    # all faces except the last one in each direction; np.roll wraps
    # around, so the first face of each direction is overwritten below
    phi3d_face_w[0:-1, :, :] = fx*phi3d + (1-fx)*np.roll(phi3d, 1, axis=0)
    phi3d_face_s[:, 0:-1, :] = fy*phi3d + (1-fy)*np.roll(phi3d, 1, axis=1)
    phi3d_face_l[:, :, 0:-1] = 0.5*np.roll(phi3d, 1, axis=2) + 0.5*phi3d

    # west boundary
    phi3d_face_w[0, :, :] = phi_bc_west
    if phi_bc_west_type == 'n':
        phi3d_face_w[0, :, :] = phi3d[0, :, :]
    if cyclic_x:
        phi3d_face_w[0, :, :] = 0.5*(phi3d[0, :, :] + phi3d[-1, :, :])

    # east boundary
    phi3d_face_w[-1, :, :] = phi_bc_east
    if phi_bc_east_type == 'n':
        phi3d_face_w[-1, :, :] = phi3d[-1, :, :]
    if cyclic_x:
        phi3d_face_w[-1, :, :] = 0.5*(phi3d[0, :, :] + phi3d[-1, :, :])

    # south boundary
    phi3d_face_s[:, 0, :] = phi_bc_south
    if phi_bc_south_type == 'n':
        phi3d_face_s[:, 0, :] = phi3d[:, 0, :]
    if phi_bc_south_type == '2':
        # d2phi/dy2 = 0
        phi3d_face_s[:, 0, :] = 1.5*phi3d[:, 0, :] - 0.5*phi3d[:, 1, :]

    # north boundary
    phi3d_face_s[:, -1, :] = phi_bc_north
    if phi_bc_north_type == 'n':
        # Neumann: face value equals the adjacent cell value, as on every
        # other boundary (this branch previously used the '2' formula)
        phi3d_face_s[:, -1, :] = phi3d[:, -1, :]
    if phi_bc_north_type == '2':
        # d2phi/dy2 = 0
        phi3d_face_s[:, -1, :] = 1.5*phi3d[:, -1, :] - 0.5*phi3d[:, -2, :]

    # low and high boundaries share one boundary value and one type
    phi3d_face_l[:, :, 0] = phi_bc_z
    phi3d_face_l[:, :, -1] = phi_bc_z
    if phi_bc_z_type == 'n':
        phi3d_face_l[:, :, 0] = phi3d[:, :, 0]
        phi3d_face_l[:, :, -1] = phi3d[:, :, -1]
    if cyclic_z:
        phi3d_face_l[:, :, 0] = 0.5*(phi3d[:, :, -1] + phi3d[:, :, 0])
        phi3d_face_l[:, :, -1] = 0.5*(phi3d[:, :, -1] + phi3d[:, :, 0])

    return phi3d_face_w, phi3d_face_s, phi3d_face_l

def dphidx(phi_face_w,phi_face_s):
    """Return the x-derivative of a field from its west/south face values.

    Face values are weighted with the module-level x-components of the face
    areas (``areawx``, ``areasx``); the signed sum over the four faces of
    each cell is divided by the cell volume ``vol``.
    """
    west = phi_face_w[:-1, :, :] * areawx[:-1, :, :]
    east = phi_face_w[1:, :, :] * areawx[1:, :, :]
    south = phi_face_s[:, :-1, :] * areasx[:, :-1, :]
    north = phi_face_s[:, 1:, :] * areasx[:, 1:, :]
    # west and south faces enter positively, east and north negatively
    balance = west - east + south - north
    return balance / vol

def dphidy(phi_face_w,phi_face_s):
    """Return the y-derivative of a field from its west/south face values.

    Face values are weighted with the module-level y-components of the face
    areas (``areawy``, ``areasy``); the signed sum over the four faces of
    each cell is divided by the cell volume ``vol``.
    """
    west = phi_face_w[:-1, :, :] * areawy[:-1, :, :]
    east = phi_face_w[1:, :, :] * areawy[1:, :, :]
    south = phi_face_s[:, :-1, :] * areasy[:, :-1, :]
    north = phi_face_s[:, 1:, :] * areasy[:, 1:, :]
    # west and south faces enter positively, east and north negatively
    balance = west - east + south - north
    return balance / vol

def dphidz(phi_face_l):
    """Return the z-derivative of a field from its low/high face values.

    The difference between consecutive faces along the third axis is
    divided by the module-level spacing ``dz``.
    """
    # high face minus low face for every cell
    return np.diff(phi_face_l, axis=2) / dz

def coeff(convw,convs,convl,vis3d,prand,scheme_local):
    """Build the neighbour coefficients of the discretised transport equation.

    Relies on module-level data: grid sizes ``ni``, ``nj``, ``nk``; geometry
    ``fx``, ``fy``, ``vol``, ``areaw``, ``areas``, ``areaz``, ``dz``; the
    viscosity ``viscos``; switches ``pans``, ``cyclic_x``, ``cyclic_z``;
    ``fk3d`` (PANS branch only); ``dt``, ``itstep`` and ``iter``.

    Args:
        convw, convs, convl: convective fluxes through the west, south and
            low faces, shapes ``(ni+1, nj, nk)``, ``(ni, nj+1, nk)`` and
            ``(ni, nj, nk+1)``.
        vis3d: total viscosity at the cell centres.
        prand: Prandtl number dividing the turbulent part of the viscosity.
            A non-positive value is only accepted when ``pans`` is set; then
            ``abs(prand) * fk3d**2`` is used as divisor.
        scheme_local: ``'h'`` (hybrid), ``'u'`` (upwind) or ``'c'`` (central).

    Returns:
        ``(aw3d, ae3d, as3d, an3d, al3d, ah3d, apo3d, su3d, sp3d)``.
        ``su3d`` and ``sp3d`` are not computed here: the module-level objects
        of those names are returned unchanged.

    Raises:
        ValueError: for an unknown ``scheme_local`` or for ``prand <= 0``
            without ``pans`` (both used to surface later as an
            ``UnboundLocalError``).
    """
    if scheme_local not in ('h', 'u', 'c'):
        raise ValueError(
            f"coeff: unknown differencing scheme {scheme_local!r}; expected 'h', 'u' or 'c'")

    if prand > 0:
        vis_turb = (vis3d-viscos)/prand
    elif pans:  # k and eps in PANS
        vis_turb = (vis3d-viscos)/np.abs(prand)/fk3d**2
    else:
        raise ValueError(
            f"coeff: prand={prand!r} must be positive unless pans is enabled")

    # face viscosities: interpolated turbulent part plus molecular part
    visw = np.zeros((ni+1, nj, nk))
    viss = np.zeros((ni, nj+1, nk))
    visl = np.zeros((ni, nj, nk+1))
    visw[0:-1, :, :] = fx*vis_turb + (1-fx)*np.roll(vis_turb, 1, axis=0) + viscos
    viss[:, 0:-1, :] = fy*vis_turb + (1-fy)*np.roll(vis_turb, 1, axis=1) + viscos
    visl[:, :, 0:-1] = 0.5*vis_turb + 0.5*np.roll(vis_turb, 1, axis=2) + viscos

    if cyclic_z:
        visl[:, :, 0] = 0.5*(vis_turb[:, :, 0] + vis_turb[:, :, -1]) + viscos

    # face-centred volumes; the first face keeps the tiny placeholder 1e-10
    # (its coefficient is zeroed or recomputed further down)
    volw = np.ones((ni+1, nj, nk))*1e-10
    vols = np.ones((ni, nj+1, nk))*1e-10
    volw[1:, :, :] = 0.5*np.roll(vol, -1, axis=0) + 0.5*vol
    vols[:, 1:, :] = 0.5*np.roll(vol, -1, axis=1) + 0.5*vol

    # diffusion coefficients at west, south and low faces
    diffw = visw[0:-1, :, :]*areaw[0:-1, :, :]**2/volw[0:-1, :, :]
    diffs = viss[:, 0:-1, :]*areas[:, 0:-1, :]**2/vols[:, 0:-1, :]
    diffl = visl[:, :, 0:-1]*areaz[:, :, 0:-1]/dz

    if cyclic_x:
        visw[0, :, :] = 0.5*(vis_turb[0, :, :] + vis_turb[-1, :, :]) + viscos
        diffw[0, :, :] = visw[0, :, :]*areaw[0, :, :]**2/(0.5*(vol[0, :, :] + vol[-1, :, :]))

    # the east/north/high coefficient of a cell uses the diffusion of the
    # next cell's west/south/low face
    diffe = np.roll(diffw, -1, axis=0)
    diffn = np.roll(diffs, -1, axis=1)
    diffh = np.roll(diffl, -1, axis=2)
    fx_e = np.roll(fx, -1, axis=0)
    fy_n = np.roll(fy, -1, axis=1)

    first_call = itstep == 0 and iter == 0

    if scheme_local == 'h':
        if first_call:
            print('hybrid scheme, prand=',prand)

        aw3d = np.maximum(convw[0:-1, :, :], diffw + (1-fx)*convw[0:-1, :, :])
        aw3d = np.maximum(aw3d, 0.)
        ae3d = np.maximum(-convw[1:, :, :], diffe - fx_e*convw[1:, :, :])
        ae3d = np.maximum(ae3d, 0.)

        as3d = np.maximum(convs[:, 0:-1, :], diffs + (1-fy)*convs[:, 0:-1, :])
        as3d = np.maximum(as3d, 0.)
        an3d = np.maximum(-convs[:, 1:, :], diffn - fy_n*convs[:, 1:, :])
        an3d = np.maximum(an3d, 0.)

        al3d = np.maximum(convl[:, :, 0:-1], diffl + 0.5*convl[:, :, 0:-1])
        al3d = np.maximum(al3d, 0.)
        ah3d = np.maximum(-convl[:, :, 1:], diffh - 0.5*convl[:, :, 1:])
        ah3d = np.maximum(ah3d, 0.)

    elif scheme_local == 'u':
        if first_call:
            print('upwind scheme, prand=',prand)

        aw3d = np.maximum(convw[0:-1, :, :], 0) + diffw
        ae3d = np.maximum(-convw[1:, :, :], 0) + diffe
        as3d = np.maximum(convs[:, 0:-1, :], 0) + diffs
        an3d = np.maximum(-convs[:, 1:, :], 0) + diffn
        al3d = np.maximum(convl[:, :, 0:-1], 0) + diffl
        ah3d = np.maximum(-convl[:, :, 1:], 0) + diffh

    else:  # 'c'
        if first_call:
            print('CDS scheme, prand=',prand)

        aw3d = diffw + (1-fx)*convw[0:-1, :, :]
        ae3d = diffe - fx_e*convw[1:, :, :]
        as3d = diffs + (1-fy)*convs[:, 0:-1, :]
        an3d = diffn - fy_n*convs[:, 1:, :]
        al3d = diffl + 0.5*convl[:, :, 0:-1]
        ah3d = diffh - 0.5*convl[:, :, 1:]

    # volume divided by the current time step
    apo3d = vol/dt[itstep]

    # no neighbour across a non-cyclic boundary
    if not cyclic_x:
        aw3d[0, :, :] = 0
        ae3d[-1, :, :] = 0
    as3d[:, 0, :] = 0
    an3d[:, -1, :] = 0
    if not cyclic_z:
        al3d[:, :, 0] = 0
        ah3d[:, :, -1] = 0

    return aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d

def bc(su3d,sp3d,phi_bc_west,phi_bc_east,phi_bc_south,phi_bc_north,phi_bc_z,
       phi_bc_west_type,phi_bc_east_type,phi_bc_south_type,phi_bc_north_type,phi_bc_z_type):
    """Return source terms that impose the Dirichlet boundary conditions.

    The incoming ``su3d``/``sp3d`` are ignored: both are rebuilt as zero
    arrays of shape ``(ni, nj, nk)``.  For every boundary whose type is
    ``'d'`` the product ``viscos * <boundary coefficient>`` is subtracted
    from ``sp3d`` and, multiplied by the boundary value, added to ``su3d``
    in the cells next to that boundary.  West/east are skipped when
    ``cyclic_x`` is set, low/high when ``cyclic_z`` is set.

    Returns:
        ``(su3d, sp3d)``.
    """
    su3d = np.zeros((ni, nj, nk))
    sp3d = np.zeros((ni, nj, nk))

    # south
    if phi_bc_south_type == 'd':
        wall = viscos*as_bound
        sp3d[:, 0, :] -= wall
        su3d[:, 0, :] += wall*phi_bc_south

    # north
    if phi_bc_north_type == 'd':
        wall = viscos*an_bound
        sp3d[:, -1, :] -= wall
        su3d[:, -1, :] += wall*phi_bc_north

    if not cyclic_x:
        # west
        if phi_bc_west_type == 'd':
            wall = viscos*aw_bound
            sp3d[0, :, :] -= wall
            su3d[0, :, :] += wall*phi_bc_west
        # east
        if phi_bc_east_type == 'd':
            wall = viscos*ae_bound
            sp3d[-1, :, :] -= wall
            su3d[-1, :, :] += wall*phi_bc_east

    # low & high share one coefficient and one boundary value
    if phi_bc_z_type == 'd' and not cyclic_z:
        wall = viscos*az_bound
        sp3d[:, :, 0] -= wall
        sp3d[:, :, -1] -= wall
        su3d[:, :, 0] += wall*phi_bc_z
        su3d[:, :, -1] += wall*phi_bc_z

    return su3d, sp3d

def conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l):
    """Compute the convective fluxes through the west, south and low faces.

    The pressure gradient (from the given pressure face values), multiplied
    by ``dt[itstep]*acrank``, is first added to each velocity component.
    These modified velocities are interpolated to the faces with
    ``compute_face_phi`` using the module-level velocity boundary data and
    combined with the face areas.  The result is passed through
    ``modify_conv`` before being returned.

    Returns:
        ``(convw, convs, convl)``.
    """
    step = dt[itstep]*acrank

    # velocities with the scaled pressure gradient added
    u_star = u3d + dphidx(p3d_face_w, p3d_face_s)*step
    v_star = v3d + dphidy(p3d_face_w, p3d_face_s)*step
    w_star = w3d + dphidz(p3d_face_l)*step

    u_w, u_s, _ = compute_face_phi(
        u_star, u_bc_west, u_bc_east, u_bc_south, u_bc_north, u_bc_z,
        u_bc_west_type, u_bc_east_type, u_bc_south_type, u_bc_north_type, u_bc_z_type)
    v_w, v_s, _ = compute_face_phi(
        v_star, v_bc_west, v_bc_east, v_bc_south, v_bc_north, v_bc_z,
        v_bc_west_type, v_bc_east_type, v_bc_south_type, v_bc_north_type, v_bc_z_type)
    _, _, w_l = compute_face_phi(
        w_star, w_bc_west, w_bc_east, w_bc_south, w_bc_north, w_bc_z,
        w_bc_west_type, w_bc_east_type, w_bc_south_type, w_bc_north_type, w_bc_z_type)

    # flux = face velocity times face-area components
    flux_w = -u_w*areawx - v_w*areawy
    flux_s = -u_s*areasx - v_s*areasy
    flux_l = w_l*areaz

    convw, convs, convl = modify_conv(flux_w, flux_s, flux_l)
    return convw, convs, convl
   
def solve_3d(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,nmax,acrank_conv_local,solver_local):
    """Solve the 7-point discretised equation with a SciPy Krylov solver.

    The 3D coefficient arrays are flattened and assembled into a sparse
    matrix (diagonal ``ap3d``; the six neighbour coefficients, scaled by
    ``acrank_conv_local``, enter with a minus sign on the off-diagonals).
    Extra diagonals are added for ``cyclic_x`` and/or ``cyclic_z``
    (module-level switches).  Grid sizes ``ni``, ``nj``, ``nk`` are read
    from the module.

    Args:
        phi3d: initial guess, shape ``(ni, nj, nk)``.
        aw3d, ae3d, as3d, an3d, al3d, ah3d: neighbour coefficients.
        su3d: right-hand side.
        ap3d: diagonal coefficient.
        tol_conv: passed to the solver as ``atol``.
        nmax: passed to the solver as ``maxiter``.
        acrank_conv_local: factor applied to all neighbour coefficients.
        solver_local: ``'gmres'`` or ``'lgmres'``.

    Returns:
        ``(phi3d, resid)``: the solution reshaped to ``(ni, nj, nk)`` and the
        2-norm of ``A*phi - su`` (not normalised by ``|su|``).

    Raises:
        ValueError: if ``solver_local`` is not a supported solver (this used
            to fail with an unbound ``info`` after the matrix was built).
    """
    if solver_local not in ('gmres', 'lgmres'):
        raise ValueError(
            f"solve_3d: unsupported solver {solver_local!r}; expected 'gmres' or 'lgmres'")

    first_call = itstep == 0 and iter == 0
    if first_call:
        print('solve_3d called')
        print('nmax,acrank_conv_local',nmax,acrank_conv_local)

    # flatten() always copies, so the in-place edits below never touch the
    # caller's arrays
    aw = aw3d.flatten()*acrank_conv_local
    ae = ae3d.flatten()*acrank_conv_local
    as1 = as3d.flatten()*acrank_conv_local
    an = an3d.flatten()*acrank_conv_local
    al = al3d.flatten()*acrank_conv_local
    ah = ah3d.flatten()*acrank_conv_local
    ap = ap3d.flatten()

    m = ni*nj*nk

    # sparse.diags truncates over-long diagonals to the length that fits
    # their offset, so several full-length arrays are passed as they are
    diagonals = [ap, -ah[:-1], -al[1:]]
    offsets = [0, 1, -1]

    if cyclic_z:
        # move the low coefficient of the first k-plane and the high
        # coefficient of the last k-plane onto the +-(nk-1) diagonals
        al_cyc = np.zeros(m)
        al_cyc[0:-1:nk] = al[0:-1:nk]
        al[0:-1:nk] = 0
        ah_cyc = np.zeros(m)
        ah_cyc[nk-1::nk] = ah[nk-1::nk]
        ah[nk-1::nk] = 0
        diagonals += [-al_cyc, -ah_cyc[nk-1:]]
        offsets += [nk-1, -(nk-1)]

    diagonals += [-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
    offsets += [nk, -nk, nk*nj, -nk*nj]

    if cyclic_x:
        # couple the first and the last i-plane
        diagonals += [-aw, -ae[nj*nk*(ni-1):]]
        offsets += [nj*nk*(ni-1), -nj*nk*(ni-1)]

    A = sparse.diags(diagonals, offsets, format='csr')

    su = su3d.flatten()
    phi = phi3d.flatten()

    resid_orig = np.linalg.norm(A*phi - su)
    phi_org = phi

    if first_call:
        print(f'solver in solve_3d: {solver_local}')
    krylov = getattr(linalg, solver_local)
    phi, info = krylov(A, su, x0=phi, atol=tol_conv, maxiter=nmax)
    if info > 0:
        print('warning in module solve_3d: convergence in sparse matrix solver not reached')

    # residual without normalising with |b|=|su3d|
    resid = np.linalg.norm(A*phi - su)
    delta_phi = np.max(np.abs(phi - phi_org))

    phi3d = np.reshape(phi, (ni, nj, nk))

    print(f"{'residual history in solve_3d: initial residual: '} {resid_orig:.2e}{'final residual: ':>30}{resid:.2e}      {'delta_phi: ':>25}{delta_phi:.2e}")

    return phi3d,resid

def solve_pyamg(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv,acrank_conv_local):
    """Solve the 7-point discretised equation with a pyAMG Ruge-Stuben solver.

    The sparse matrix is assembled from the flattened coefficient arrays
    (neighbour coefficients scaled by ``acrank_conv_local`` and negated, the
    diagonal taken from ``ap3d``), with extra diagonals for the module-level
    switches ``cyclic_x`` / ``cyclic_z``.  A new multigrid hierarchy is
    built on every call.  ``amg_relax`` and ``amg_cycle`` (module level)
    select the ``accel`` and ``cycle`` arguments unless ``amg_relax`` is
    ``'default'``.

    Returns:
        ``(phi3d, last_residual)``: the solution reshaped to ``(ni, nj, nk)``
        and the last entry of pyAMG's residual history.
    """
    if itstep == 0 and iter == 0:
        print('solve_pyamg called,tol_conv=',tol_conv,'acrank_conv_local=',acrank_conv_local)
        print('relation method=',amg_relax)

    # flattened copies; neighbour coefficients carry the extra factor
    aw = aw3d.flatten()*acrank_conv_local
    ae = ae3d.flatten()*acrank_conv_local
    as1 = as3d.flatten()*acrank_conv_local
    an = an3d.flatten()*acrank_conv_local
    al = al3d.flatten()*acrank_conv_local
    ah = ah3d.flatten()*acrank_conv_local
    ap = ap3d.flatten()

    m = ni*nj*nk

    # report the periodicity combination (nothing is printed for the
    # "cyclic in z only" case)
    if cyclic_x and cyclic_z:
        print('cyclic_x cyclic_z')
    elif cyclic_x:
        print('cyclic_x and not cyclic_z')
    elif not cyclic_z:
        print('not cyclic_z and not cyclic_x')

    diagonals = [ap, -ah[:-1], -al[1:]]
    offsets = [0, 1, -1]

    if cyclic_z:
        # move the wrap-around low/high coefficients to their own diagonals
        al_cyc = np.zeros(m)
        al_cyc[0:-1:nk] = al[0:-1:nk]
        al[0:-1:nk] = 0
        ah_cyc = np.zeros(m)
        ah_cyc[nk-1::nk] = ah[nk-1::nk]
        ah[nk-1::nk] = 0
        diagonals += [-al_cyc, -ah_cyc[nk-1:]]
        offsets += [nk-1, -(nk-1)]

    diagonals += [-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
    offsets += [nk, -nk, nk*nj, -nk*nj]

    if cyclic_x:
        # couple the first and the last i-plane
        diagonals += [-aw, -ae[nj*nk*(ni-1):]]
        offsets += [nj*nk*(ni-1), -nj*nk*(ni-1)]

    A = sparse.diags(diagonals, offsets, format='csr')

    # construct the multigrid hierarchy
    App = pyamg.ruge_stuben_solver(A)

    print('in solve_pyamg')

    phi = phi3d.flatten()
    su = su3d.flatten()
    start = phi
    res_amg = []
    if amg_relax == 'default':
        phi = App.solve(su, tol=tol_conv, x0=phi, residuals=res_amg)
    else:
        phi = App.solve(su, tol=tol_conv, x0=phi,accel=amg_relax,cycle=amg_cycle, residuals=res_amg)

    print('Residual history in pyAMG', ["%0.4e" % i for i in res_amg])

    # largest change of any unknown during this solve
    delta_phi = np.max(np.abs(phi - start))

    print(f"{'Residual history in solve_pyAMG: delta_phi: ':>25}{delta_phi:.2e}")

    phi3d = np.reshape(phi, (ni, nj, nk))

    return phi3d,res_amg[-1]

def solve_p(phi3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,tol_conv):
    """Solve the pressure equation with a cached pyAMG Ruge-Stuben solver.

    On the very first call (``itstep == 0 and iter == 0``) the sparse matrix
    is assembled from the coefficient arrays and the multigrid hierarchy is
    stored in the module-level ``Ap``.  Every later call reuses ``Ap`` and
    ignores the coefficient arguments.  ``amg_relax`` / ``amg_cycle``
    (module level) select the ``accel`` and ``cycle`` arguments unless
    ``amg_relax`` is ``'default'``.

    Returns:
        The solution reshaped to ``(ni, nj, nk)``.
    """
    global Ap,Mp

    if itstep == 0 and iter == 0:
        print('solve_p called')
        print('relation method=',amg_relax)
        print('A and M computed,tol_conv=',tol_conv)

        # flattened copies of the coefficients (no extra scaling here)
        aw = aw3d.flatten()
        ae = ae3d.flatten()
        as1 = as3d.flatten()
        an = an3d.flatten()
        al = al3d.flatten()
        ah = ah3d.flatten()
        ap = ap3d.flatten()

        m = ni*nj*nk

        # report the periodicity combination (nothing is printed for the
        # "cyclic in z only" case)
        if cyclic_x and cyclic_z:
            print('cyclic_x cyclic_z')
        elif cyclic_x:
            print('cyclic_x and not cyclic_z')
        elif not cyclic_z:
            print('not cyclic_z and not cyclic_x')

        diagonals = [ap, -ah[:-1], -al[1:]]
        offsets = [0, 1, -1]

        if cyclic_z:
            # move the wrap-around low/high coefficients to own diagonals
            al_cyc = np.zeros(m)
            al_cyc[0:-1:nk] = al[0:-1:nk]
            al[0:-1:nk] = 0
            ah_cyc = np.zeros(m)
            ah_cyc[nk-1::nk] = ah[nk-1::nk]
            ah[nk-1::nk] = 0
            diagonals += [-al_cyc, -ah_cyc[nk-1:]]
            offsets += [nk-1, -(nk-1)]

        diagonals += [-an[0:-nk], -as1[nk:], -ae, -aw[nj*nk:]]
        offsets += [nk, -nk, nk*nj, -nk*nj]

        if cyclic_x:
            # couple the first and the last i-plane
            diagonals += [-aw, -ae[nj*nk*(ni-1):]]
            offsets += [nj*nk*(ni-1), -nj*nk*(ni-1)]

        Ap = sparse.diags(diagonals, offsets, format='csr')

        # construct the multigrid hierarchy (replaces the matrix in Ap)
        Ap = pyamg.ruge_stuben_solver(Ap)

    print('in solve_p')

    start = phi3d.flatten()
    rhs = su3d.flatten()
    if amg_relax == 'default':
        solution = Ap.solve(rhs, tol=tol_conv, x0=start)
    else:
        solution = Ap.solve(rhs, tol=tol_conv, x0=start,accel=amg_relax,cycle=amg_cycle)

    return np.reshape(solution, (ni, nj, nk))


def calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s):
    """Add the pressure-gradient source to the u-equation.

    The x-gradient of the pressure is blended with ``dpdx_old`` using the
    module-level ``acrank`` and, multiplied by ``vol``, subtracted from
    ``su3d`` (a new array is created; the argument is not modified in
    place).  The sources are then passed through ``modify_u``.

    Returns:
        ``(su3d, sp3d)`` as returned by ``modify_u``.
    """
    if itstep == 0 and iter == 0:
        print('calcu called')

    # blend of new and old pressure gradient
    gradient_now = dphidx(p3d_face_w, p3d_face_s)
    dpdx = acrank*gradient_now + (1-acrank)*dpdx_old
    source = su3d - dpdx*vol

    # user hook for extra sources (unsteady term is added elsewhere)
    source, sink = modify_u(source, sp3d)
    return source, sink

def calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s):
    """Add the pressure-gradient source to the v-equation.

    The y-gradient of the pressure is blended with ``dpdy_old`` using the
    module-level ``acrank`` and, multiplied by ``vol``, subtracted from
    ``su3d`` (a new array is created; the argument is not modified in
    place).  The sources are then passed through ``modify_v``.

    Returns:
        ``(su3d, sp3d)`` as returned by ``modify_v``.
    """
    if itstep == 0 and iter == 0:
        print('calcv called')

    # blend of new and old pressure gradient
    gradient_now = dphidy(p3d_face_w, p3d_face_s)
    dpdy = acrank*gradient_now + (1-acrank)*dpdy_old
    source = su3d - dpdy*vol

    # user hook for extra sources (unsteady term is added elsewhere)
    source, sink = modify_v(source, sp3d)
    return source, sink

def compute_fk(k3d,eps3d):
    """Return the field ``fk3d`` computed from ``k3d`` and ``eps3d``.

    With ``L_t = k**1.5/eps`` and ``psi = max(1, L_t/(0.67*delta_max))`` the
    result is ``max(1 - (psi-1)/(c_eps_2-c_eps_1), 0.2)``.  ``delta_max``,
    ``c_eps_1`` and ``c_eps_2`` are module-level values.
    """
    if itstep == 0 and iter == 0:
        print('compute_fk called')

    c_des = 0.67      # constant multiplying delta_max
    fk_floor = 0.2    # lower bound of the returned field

    length_scale = k3d**1.5/eps3d
    psi = np.maximum(1, length_scale/(c_des*delta_max))
    unclipped = 1.-(psi-1.)/(c_eps_2-c_eps_1)

    return np.maximum(unclipped, fk_floor)


def calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
    """Add production and dissipation sources to the k-equation (k-omega).

    Note that ``w3d_face_w``, ``w3d_face_s``, ``u3d_face_l`` and
    ``v3d_face_l`` are not parameters: they are read from module level.
    Other module-level inputs: ``viscos``, ``vol``, ``cmu``, ``kom_des``,
    ``delta_max`` and ``jl0``.

    Returns:
        ``(su3d, sp3d, gen, fk3d, comm_term)``.  ``gen`` is the squared
        velocity-gradient expression used in the production term; ``fk3d``
        is the factor multiplying the dissipation term (the scalar ``1``
        unless ``kom_des`` is set); ``comm_term`` comes from ``modify_k``.
    """
    if itstep == 0 and iter == 0:
        print('calck_kom called')

    # velocity gradients, one velocity component at a time
    dudx = dphidx(u3d_face_w, u3d_face_s)
    dudy = dphidy(u3d_face_w, u3d_face_s)
    dudz = dphidz(u3d_face_l)

    dvdx = dphidx(v3d_face_w, v3d_face_s)
    dvdy = dphidy(v3d_face_w, v3d_face_s)
    dvdz = dphidz(v3d_face_l)

    dwdx = dphidx(w3d_face_w, w3d_face_s)
    dwdy = dphidy(w3d_face_w, w3d_face_s)
    dwdz = dphidz(w3d_face_l)

    # production term: turbulent viscosity (floored at 1e-10) times gen
    gen = (2.*(dudx**2+dvdy**2+dwdz**2)+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2)
    vist = np.maximum(vis3d-viscos, 1e-10)
    su3d = su3d + vist*gen*vol

    rl = k3d**0.5/(cmu*om3d)

    if not kom_des:
        fk3d = 1
    else:
        fk3d = np.maximum(1., rl/(0.67*delta_max))
        if jl0 < 0:
            # force the factor to one in the first |jl0| j-rows
            fk3d[:, 0:np.abs(jl0), :] = 1

    # dissipation term
    sp3d = sp3d - fk3d*cmu*om3d*vol

    # user hook for extra sources (unsteady term is added elsewhere)
    su3d, sp3d, comm_term = modify_k(su3d, sp3d, gen)

    return su3d,sp3d,gen,fk3d,comm_term

def calcom(su3d,sp3d,om3d,gen,comm_term):
    """Add production and dissipation sources to the omega-equation.

    ``c_omega_1*gen*vol`` is added to ``su3d`` and ``c_omega_2*om3d*vol`` is
    subtracted from ``sp3d`` (new arrays; the arguments are not modified in
    place).  The result is then passed, together with ``comm_term``, through
    ``modify_om``.

    Returns:
        ``(su3d, sp3d)`` as returned by ``modify_om``.
    """
    if itstep == 0 and iter == 0:
        print('calcom called')

    production = c_omega_1*gen*vol
    dissipation = c_omega_2*om3d*vol

    source = su3d + production
    sink = sp3d - dissipation

    # user hook for extra sources
    source, sink = modify_om(source, sink, comm_term)
    return source, sink

def calck_ls(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l):
    """Add production, dissipation and D-term sources to the k-equation.

    Note that ``w3d_face_w``, ``w3d_face_s``, ``u3d_face_l`` and
    ``v3d_face_l`` are not parameters: they are read from module level, as
    are ``viscos``, ``vol`` and the ``k_bc_*`` boundary data.

    Returns:
        ``(su3d, sp3d, gen, dudx, dudy)``.  The third value returned by
        ``modify_k`` is discarded.
    """
    if itstep == 0 and iter == 0:
        print('calck_ls called')

    # velocity gradients, one velocity component at a time
    dudx = dphidx(u3d_face_w, u3d_face_s)
    dudy = dphidy(u3d_face_w, u3d_face_s)
    dudz = dphidz(u3d_face_l)

    dvdx = dphidx(v3d_face_w, v3d_face_s)
    dvdy = dphidy(v3d_face_w, v3d_face_s)
    dvdz = dphidz(v3d_face_l)

    dwdx = dphidx(w3d_face_w, w3d_face_s)
    dwdy = dphidy(w3d_face_w, w3d_face_s)
    dwdz = dphidz(w3d_face_l)

    # production term: turbulent viscosity (floored at 1e-10) times gen
    gen = (2.*(dudx**2+dvdy**2+dwdz**2)+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2)
    vist = np.maximum(vis3d-viscos, 1e-10)
    su3d = su3d + vist*gen*vol

    # dissipation term
    sp3d = sp3d - eps3d/k3d*vol

    # D term: 2*viscos*(d sqrt(k)/dy)**2, with sqrt(k) interpolated to the
    # faces using the k boundary data
    sqrt_k = k3d**0.5
    sqrt_k_w, sqrt_k_s, _ = compute_face_phi(
        sqrt_k, k_bc_west, k_bc_east, k_bc_south, k_bc_north, k_bc_z,
        k_bc_west_type, k_bc_east_type, k_bc_south_type, k_bc_north_type, k_bc_z_type)
    dk05dy = dphidy(sqrt_k_w, sqrt_k_s)
    dterm = 2.*viscos*dk05dy**2
    sp3d = sp3d - dterm/k3d*vol

    # user hook for extra sources (unsteady term is added elsewhere)
    su3d, sp3d, _ = modify_k(su3d, sp3d, gen)

    return su3d,sp3d,gen,dudx,dudy

def calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy):
   """Assemble the source terms of the eps equation and the damping function.

   su3d, sp3d  -- source arrays to be extended
   k3d, eps3d  -- current k and eps fields
   vis3d       -- total viscosity; viscos is subtracted to get the turbulent part
   gen         -- production field (as returned by the k-equation routine)
   dudx, dudy  -- cell-centre gradients of u, differentiated once more here

   Reads the module-level viscos, cmu, c_eps_1, c_eps_2, fk3d, vol and the
   u_bc_* boundary values.

   Returns (su3d, sp3d, fmu3d).
   """
   if itstep == 0 and iter == 0:
      print('calceps_ls called')

   # damping functions based on rt = k**2/(eps*viscos)
   rt=k3d**2/eps3d/viscos
   fdampf2=1.-0.3*np.exp(-rt**2)
   fmu3d=np.exp(-3.4/(1.+rt/50.)**2)
   fmu3d=np.minimum(fmu3d,1.)

   # production term
   su3d=su3d+c_eps_1*cmu*fmu3d*gen*k3d*vol
   c2u=c_eps_1+fk3d*(fdampf2*c_eps_2-c_eps_1)

   # dissipation term
   sp3d=sp3d-c2u*eps3d*vol/k3d

   # E term: second derivatives of u. The u_bc_* values are passed only to
   # satisfy the signature; all five boundary types are 'n', so Neumann
   # conditions are prescribed everywhere.
   dudy_face_w,dudy_face_s,dudy_face_l=compute_face_phi(dudy,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
     'n','n','n','n','n')
   dudx_face_w,dudx_face_s,dudx_face_l=compute_face_phi(dudx,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
     'n','n','n','n','n')
   d2udy2=dphidy(dudy_face_w,dudy_face_s)
   d2udx2=dphidx(dudx_face_w,dudx_face_s)
   # unlike in the k equation, the turbulent viscosity is not clipped here
   vist=vis3d-viscos

   eterm=2.*viscos*vist*(d2udx2**2+d2udy2**2)
   su3d=su3d+eterm*vol

   # final adjustments of su & sp
   su3d,sp3d=modify_eps(su3d,sp3d)

   return su3d,sp3d,fmu3d


def calcw(su3d,sp3d,dpdz_old,p3d_face_l):
   """Add the pressure-gradient source of the w-momentum equation.

   su3d, sp3d  -- source arrays
   dpdz_old    -- z pressure gradient of the previous time step
   p3d_face_l  -- pressure at the low faces, differentiated with dphidz

   The new and old gradients are blended with the module-level weight acrank.
   Returns (su3d, sp3d) as delivered by modify_w; the unsteady term is added
   later in crank_nicol.
   """
   if itstep == 0 and iter == 0:
      print('calcw called')

   # time-blended pressure gradient
   dpdz_new=dphidz(p3d_face_l)
   dpdz_blend=acrank*dpdz_new+(1-acrank)*dpdz_old

   # subtract it from the source and hand over to modify_w
   su_out,sp_out=modify_w(su3d-dpdz_blend*vol,sp3d)

   return su_out,sp_out

def calcp(convw,convs,convl):
   """Build the coefficients of the pressure equation.

   The three arguments are accepted but not used. All geometry (ni, nj, nk,
   vol, areaw, areas, areaz, dz) and the flags cyclic_x / cyclic_z are read
   from module level.

   Returns (aw3d, ae3d, as3d, an3d, al3d, ah3d, su3d, ap3d), where su3d is the
   module-level su3d passed through untouched.
   """
   if itstep == 0 and iter == 0:
      print('calcp called')

   # west: face area squared over the mean volume of the two cells sharing it
   aw3d=np.zeros((ni,nj,nk))
   aw3d[1:,:,:]=areaw[1:-1,:,:]**2/(0.5*vol[1:,:,:]+0.5*vol[0:-1,:,:])

   # east coefficient of cell i equals the west coefficient of cell i+1
   ae3d=np.zeros((ni,nj,nk))
   ae3d[0:-1,:,:]=aw3d[1:,:,:]

   # outermost x faces: coupled when cyclic, otherwise left at zero
   if cyclic_x:
      aw3d[0,:,:]=areaw[0,:,:]**2/(0.5*(vol[0,:,:]+vol[-1,:,:]))
      ae3d[-1,:,:]=aw3d[0,:,:]

   # south/north: same construction, boundary faces always zero
   as3d=np.zeros((ni,nj,nk))
   as3d[:,1:,:]=areas[:,1:-1,:]**2/(0.5*vol[:,1:,:]+0.5*vol[:,0:-1,:])
   an3d=np.zeros((ni,nj,nk))
   an3d[:,0:-1,:]=as3d[:,1:,:]

   # low/high: face area over dz, wrapped around in z
   al3d=areaz[:,:,0:-1]/dz
   ah3d=np.roll(al3d,-1,axis=2)
   if not cyclic_z:
      al3d[:,:,0]=0
      ah3d[:,:,-1]=0

   ap3d=aw3d+ae3d+as3d+an3d+al3d+ah3d

   # huge diagonal in cell [0,0,0] pins p3d there to ~0 and makes the
   # system non-singular
   ap3d[0,0,0]=1e10

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d


def correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p):
   """Correct the face fluxes with the pressure field and return the continuity error.

   u3d, v3d, w3d -- passed through and returned unchanged
   p3d           -- pressure field used for the correction; returned unchanged
   aw3d_p, as3d_p, al3d_p -- west/south/low coefficients of the pressure equation

   convw, convs and convl are NOT parameters: the slice assignments below
   update the module-level arrays in place. dt, itstep, acrank, ni, nj, nk,
   cyclic_x, cyclic_z, w3d_face_l and areaz are also read from module level.

   Returns (convw, convs, convl, p3d, u3d, v3d, w3d, su3d) where su3d is the
   net flux sum per cell (west-east + south-north + low-high).
   """
# correct convections
# the padded copies below get a zero slab in front of index 0 along one axis,
# so that index i of the padded array holds the value of cell i-1
   p3d_w=p3d
   p3d_s=p3d
   p3d_l=p3d
# effective time step of the correction
   dtt=dt[itstep]*acrank
#\\\\\\\\\\\\\ west face
# insert a slab of zeros before row 0 (np.insert returns a new array, p3d is untouched)
   p3d_w=np.insert(p3d_w,0,np.zeros((nj,nk)),axis=0)
   if cyclic_x:
# interior faces: pressure difference between the west neighbour and the cell
      convw[1:-1,:,:]=convw[1:-1,:,:]+aw3d_p[1:,:,:]*(p3d_w[1:-1,:,:]-p3d[1:,:,:])*dtt
# first face: the west neighbour of cell 0 is the last cell
      convw[0,:,:]=convw[0,:,:]+aw3d_p[0,:,:]*(p3d[-1,:,:]-p3d[0,:,:])*dtt
# last face is a copy of the (already corrected) first face
      convw[-1,:,:]=convw[0,:,:]
   else:
# all faces except the last one; the last face is left uncorrected here
      convw[0:-1,:,:]=convw[0:-1,:,:]+aw3d_p*(p3d_w[0:-1,:,:]-p3d)*dtt


#\\\\\\\\\\\\\ south face
# insert a slab of zeros before column 0
   p3d_s=np.insert(p3d_s,0,np.zeros((ni,nk)),axis=1)
# all faces except the last one; the last face is left uncorrected
   convs[:,0:-1,:]=convs[:,0:-1,:]+as3d_p*(p3d_s[:,0:-1,:]-p3d)*dtt

#\\\\\\\\\\\\\ low face
# insert a slab of zeros before plane 0 of the third axis
   p3d_l=np.insert(p3d_l,0,np.zeros((ni,nj)),axis=2)
   if cyclic_z:
# interior faces
      convl[:,:,1:-1]=convl[:,:,1:-1]+al3d_p[:,:,1:]*(p3d_l[:,:,1:-1]-p3d[:,:,1:])*dtt
# last face: difference between the last and the first cell
      convl[:,:,-1]=convl[:,:,-1]+al3d_p[:,:,-1]*(p3d[:,:,-1]-p3d[:,:,0])*dtt
# first face is a copy of the (already corrected) last face
      convl[:,:,0]=convl[:,:,-1]
   else:
      convl[:,:,0:-1]=convl[:,:,0:-1]+al3d_p*(p3d_l[:,:,0:-1]-p3d)*dtt
# boundary: the first low face is overwritten with the face velocity times the face area
      convl[:,:,0]=w3d_face_l[:,:,0]*areaz[:,:,0]

# continuity error: inflow minus outflow through the three face pairs of each cell
   su3d=convw[0:-1,:,:]-convw[1:,:,:]\
       +convs[:,0:-1,:]-convs[:,1:,:]\
       +convl[:,:,0:-1]-convl[:,:,1:]

   return convw,convs,convl,p3d,u3d,v3d,w3d,su3d


def update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l):
    """Return the state that is carried over to the next time step as "old".

    The six fields are returned as the very same objects that were passed in
    (no copies are made); the three pressure gradients are computed from the
    face pressures with dphidx, dphidy and dphidz.

    Returns (u3d_old, v3d_old, w3d_old, k3d_old, eps3d_old, om3d_old,
             dpdx_old, dpdy_old, dpdz_old).
    """
    old_pressure_gradients=(
        dphidx(p3d_face_w,p3d_face_s),
        dphidy(p3d_face_w,p3d_face_s),
        dphidz(p3d_face_l),
    )

    return (u3d,v3d,w3d,k3d,eps3d,om3d)+old_pressure_gradients

def time_stats(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,\
                fk3d_mean,vis3d_mean):
    """Add the current instantaneous fields to the running sums.

    The instantaneous fields (u3d, v3d, w3d, p3d, k3d, eps3d, om3d, fk3d,
    vis3d) are read from module level. Each accumulator is returned as a new
    array (sum so far + current sample); the stress accumulators receive
    u3d**2, v3d**2, w3d**2 and u3d*v3d. No division by the sample count is
    done here; the module-level itstep_stats_counter is incremented by one.

    Returns the accumulators in the same order as the parameters.
    """
    global itstep_stats_counter

    itstep_stats_counter+=1

    return (
        u3d_mean+u3d,
        v3d_mean+v3d,
        w3d_mean+w3d,
        p3d_mean+p3d,
        k3d_mean+k3d,
        eps3d_mean+eps3d,
        om3d_mean+om3d,
        uu3d_stress+u3d**2,
        vv3d_stress+v3d**2,
        ww3d_stress+w3d**2,
        uv3d_stress+u3d*v3d,
        fk3d_mean+fk3d,
        vis3d_mean+vis3d,
    )

def crank_nicol(phi3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_local):
    """Blend old and new time levels into the diagonal and the source.

    phi3d_old          -- field at the previous time step
    aw3d ... ah3d      -- the six neighbour coefficients
    apo3d              -- coefficient multiplying phi3d_old in the source
    su3d, sp3d         -- source arrays (sp3d is subtracted from the diagonal)
    acrank_conv_local  -- weight of the new time level; (1 - weight) goes to
                          the old level

    Neighbour values of phi3d_old are taken with np.roll, i.e. with
    wrap-around along every axis.

    Returns (ap3d, su3d).
    """
    old_weight=1-acrank_conv_local
    nb_sum=aw3d+ae3d+as3d+an3d+al3d+ah3d

    # diagonal of the new time level
    ap3d=apo3d+acrank_conv_local*nb_sum-sp3d

    # old-level contribution of the cell itself
    su3d=su3d+(apo3d-old_weight*nb_sum)*phi3d_old

    # old-level contribution of the six neighbours
    east=np.roll(phi3d_old,-1,axis=0)
    west=np.roll(phi3d_old,1,axis=0)
    north=np.roll(phi3d_old,-1,axis=1)
    south=np.roll(phi3d_old,1,axis=1)
    high=np.roll(phi3d_old,-1,axis=2)
    low=np.roll(phi3d_old,1,axis=2)
    old_neighbours=ae3d*east+aw3d*west+an3d*north+as3d*south+ah3d*high+al3d*low
    su3d=su3d+old_weight*old_neighbours

    return ap3d,su3d

def vist_kom(vis3d,k3d,om3d):
   """Return the under-relaxed total viscosity k/omega + viscos.

   vis3d      -- previous viscosity, used as the old value in the relaxation
   k3d, om3d  -- current k and omega fields

   Reads the module-level viscos and the under-relaxation factor urfvis.
   """
   if itstep == 0 and iter == 0:
      print('vist_kom called')

   vis_new=k3d/om3d+viscos

   # under-relax against the incoming viscosity
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_pans(vis3d,k3d,eps3d,fmu3d):
   """Return the under-relaxed total viscosity cmu*fmu*k**2/eps + viscos.

   vis3d       -- previous viscosity, used as the old value in the relaxation
   k3d, eps3d  -- current k and eps fields
   fmu3d       -- damping function multiplying cmu

   Reads the module-level cmu, viscos and the under-relaxation factor urfvis.
   """
   if itstep == 0 and iter == 0:
      print('vist_pans called')

   vis_new=cmu*fmu3d*k3d**2/eps3d+viscos

   # under-relax against the incoming viscosity
   return urfvis*vis_new+(1.-urfvis)*vis3d

def vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed viscosity length**2 * sqrt(gen) + viscos.

   *_face_w/_s/_l -- face values of u, v, w used to form the nine gradients
   vis3d          -- previous viscosity, used as the old value in the relaxation

   The length scale is the smaller of 0.41 times min(yp2d, yp2d[1,-1]-yp2d)
   and cmu*vol**0.3333333. Reads the module-level yp2d, nk, cmu, vol,
   viscos and urfvis.
   """
   if itstep == 0 and iter == 0:
      print('vist_smag called')

   # velocity gradients, x- then y- then z-direction
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)

   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)

   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

   normal_part=2.*(dudx**2+dvdy**2+dwdz**2)
   gen=normal_part+(dudz+dwdx)**2+(dvdz+dwdy)**2+(dudy+dvdx)**2

   # RANS length scale (2-D), repeated nk times along the third axis
   length_rans_2d=0.41*np.minimum(yp2d,yp2d[1,-1]-yp2d)
   length_rans=np.dstack([length_rans_2d]*nk)

   # LES length scale from the cell volume; the smaller of the two is used
   length_les=cmu*vol**0.3333333
   length=np.minimum(length_rans,length_les)

   vis_new=length**2*gen**0.5+viscos

   # under-relax against the incoming viscosity
   return urfvis*vis_new+(1.-urfvis)*vis3d

def save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress, \
                        vv3d_stress,ww3d_stress,uv3d_stress):
   """Write the accumulated statistics to .npy files in the working directory.

   Every accumulator is averaged over its third axis (np.mean(..., axis=2))
   before it is saved. No division by the number of samples is done here;
   instead [itstep_stats_counter, nk, dz] is stored in 'itstep.npy',
   presumably so the time average can be formed when post-processing.

   k3d_mean is written under both 'k_averaged' and 'k3d_averaged'; both file
   names are kept since either may be read by existing post-processing.
   """
   # (file name, accumulator) pairs, written in this order
   outputs=(
      ('u_averaged',u3d_mean),
      ('v_averaged',v3d_mean),
      ('w_averaged',w3d_mean),
      ('p_averaged',p3d_mean),
      ('k_averaged',k3d_mean),
      ('fk_averaged',fk3d_mean),
      ('om_averaged',om3d_mean),
      ('vis_averaged',vis3d_mean),
      ('eps_averaged',eps3d_mean),
      ('k3d_averaged',k3d_mean),
      ('uu_stress',uu3d_stress),
      ('vv_stress',vv3d_stress),
      ('ww_stress',ww3d_stress),
      ('uv_stress',uv3d_stress),
   )
   for file_name,accumulated in outputs:
      np.save(file_name, np.mean(accumulated,axis=2))

   # sample counter and spanwise grid data
   np.save('itstep',[itstep_stats_counter,nk,dz])
   print('itstep_stats_counter,nk,dz',itstep_stats_counter,nk,dz)

   return

def read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Load the restart fields from the '*_saved.npy' files.

   u3d, v3d, w3d and p3d are always loaded. k3d and eps3d are loaded when
   keps or pans is set, k3d and om3d when kom or kom_des is set; fields that
   are not loaded are returned as passed in.

   Returns (u3d, v3d, w3d, p3d, k3d, eps3d, om3d).
   """
   always_loaded=('u3d_saved.npy','v3d_saved.npy','w3d_saved.npy','p3d_saved.npy')
   u3d,v3d,w3d,p3d=[np.load(file_name) for file_name in always_loaded]

   # turbulence quantities depend on the active model
   if keps or pans:
      k3d,eps3d=np.load('k3d_saved.npy'),np.load('eps3d_saved.npy')
   if kom or kom_des:
      k3d,om3d=np.load('k3d_saved.npy'),np.load('om3d_saved.npy')

   return u3d,v3d,w3d,p3d,k3d,eps3d,om3d

def save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d):
   """Write the restart fields to '*_saved.npy' files.

   u3d, v3d, w3d and p3d are always written. k3d and eps3d are written when
   keps or pans is set, k3d and om3d when kom or kom_des is set.
   """
   always_saved=(('u3d_saved',u3d),('v3d_saved',v3d),('w3d_saved',w3d),('p3d_saved',p3d))
   for file_name,field in always_saved:
      np.save(file_name, field)

   # turbulence quantities depend on the active model
   if keps or pans:
      for file_name,field in (('k3d_saved',k3d),('eps3d_saved',eps3d)):
         np.save(file_name, field)
   if kom or kom_des:
      for file_name,field in (('k3d_saved',k3d),('om3d_saved',om3d)):
         np.save(file_name, field)

   return

def vist_wale(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d):
   """Return the under-relaxed viscosity of the model coded as "wale".

   *_face_w/_s/_l -- face values of u, v, w used to form the nine gradients
   vis3d          -- previous viscosity, used as the old value in the relaxation

   With g the velocity-gradient tensor, s its symmetric part and sd the
   symmetric part of g*g with a third of its trace removed from the diagonal,
   the new viscosity is

      (cm*delta)**2 * (sd:sd)**1.5 / ((s:s)**2.5 + (sd:sd)**1.25) + viscos

   with delta = vol**0.333333. Reads the module-level vol, viscos and urfvis.
   """
   if itstep == 0 and iter == 0:
      print('vist_wale called')

   def _double_contraction(a11,a12,a13,a22,a23,a33):
      # a_ij*a_ij of a symmetric tensor; all nine products summed row by row
      return a11*a11+a12*a12+a13*a13+\
             a12*a12+a22*a22+a23*a23+\
             a13*a13+a23*a23+a33*a33

   # velocity gradients, x- then y- then z-direction
   dudx=dphidx(u3d_face_w,u3d_face_s)
   dvdx=dphidx(v3d_face_w,v3d_face_s)
   dwdx=dphidx(w3d_face_w,w3d_face_s)

   dudy=dphidy(u3d_face_w,u3d_face_s)
   dvdy=dphidy(v3d_face_w,v3d_face_s)
   dwdy=dphidy(w3d_face_w,w3d_face_s)

   dudz=dphidz(u3d_face_l)
   dvdz=dphidz(v3d_face_l)
   dwdz=dphidz(w3d_face_l)

   # gradient tensor g[i][j] = d(u_i)/d(x_j)
   grad=((dudx,dudy,dudz),
         (dvdx,dvdy,dvdz),
         (dwdx,dwdy,dwdz))

   # strain-rate tensor (upper triangle; it is symmetric)
   s11=grad[0][0]
   s22=grad[1][1]
   s33=grad[2][2]
   s12=0.5*(grad[0][1]+grad[1][0])
   s13=0.5*(grad[0][2]+grad[2][0])
   s23=0.5*(grad[1][2]+grad[2][1])

   # square of the gradient tensor, gsq[i][j] = g_ik g_kj
   gsq=[[grad[i][0]*grad[0][j]+grad[i][1]*grad[1][j]+grad[i][2]*grad[2][j]
         for j in range(3)] for i in range(3)]

   # symmetric part of gsq with a third of the trace removed from the diagonal
   third_trace=(gsq[0][0]+gsq[1][1]+gsq[2][2])/3.
   sd11=gsq[0][0]-third_trace
   sd22=gsq[1][1]-third_trace
   sd33=gsq[2][2]-third_trace
   sd12=0.5*(gsq[0][1]+gsq[1][0])
   sd13=0.5*(gsq[0][2]+gsq[2][0])
   sd23=0.5*(gsq[1][2]+gsq[2][1])

   sijsij=_double_contraction(s11,s12,s13,s22,s23,s33)
   sdijsdij=_double_contraction(sd11,sd12,sd13,sd22,sd23,sd33)

   # NOTE(review): cm is already 10.6*0.1**2 and is squared once more in
   # (cm*delta)**2 below -- confirm that this is the intended model constant
   cm=10.6*0.1**2

   ratio=sdijsdij**1.5/(sijsij**2.5+sdijsdij**1.25)

   delta=vol**0.333333
   vis_new=(cm*delta)**2*ratio+viscos

   # under-relax against the incoming viscosity
   return urfvis*vis_new+(1.-urfvis)*vis3d

######################### the execution of the code starts here #############################

########### grid specification ###########
# x2d.dat / y2d.dat: the (ni+1)*(nj+1) node coordinates followed by one
# trailing entry that holds ni (resp. nj)
datax= np.loadtxt("x2d.dat")
x=datax[0:-1]
ni=int(datax[-1])
datay= np.loadtxt("y2d.dat")
y=datay[0:-1]
nj=int(datay[-1])

# node coordinates as (ni+1,nj+1) arrays
x2d=np.reshape(x,(ni+1,nj+1))
y2d=np.reshape(y,(ni+1,nj+1))

# compute cell centers as the average of the four surrounding nodes
xp2d=0.25*(x2d[0:-1,0:-1]+x2d[0:-1,1:]+x2d[1:,0:-1]+x2d[1:,1:])
yp2d=0.25*(y2d[0:-1,0:-1]+y2d[0:-1,1:]+y2d[1:,0:-1]+y2d[1:,1:])

# z.dat: extent in z and number of cells in z; the spacing dz is uniform
zmax, nk=np.loadtxt('z.dat')
# built-in int: the np.int alias no longer exists in current NumPy releases
nk=int(nk)
dz=zmax/nk

# array shapes used throughout the initialization
_cells=(ni,nj,nk)           # one value per cell
_west_faces=(ni+1,nj,nk)    # one value per west/east face
_south_faces=(ni,nj+1,nk)   # one value per south/north face
_low_faces=(ni,nj,nk+1)     # one value per low/high face

# geometric arrays, zero-filled until init() provides the real values
vol=np.zeros(_cells)
areas,areasx,areasy=(np.zeros(_south_faces) for _ in range(3))
areaw,areawx,areawy=(np.zeros(_west_faces) for _ in range(3))
areaz=np.zeros(_low_faces)
as_bound,an_bound=(np.zeros((ni,nk)) for _ in range(2))
aw_bound,ae_bound=(np.zeros((nj,nk)) for _ in range(2))
az_bound=np.zeros((ni,nj))
fx,fy=(np.zeros(_cells) for _ in range(2))

setup_case()

print_indata()

(areaw,areawx,areawy,areas,areasx,areasy,areaz,vol,fx,fy,
 aw_bound,ae_bound,as_bound,an_bound,az_bound,dist3d)=init()


# counter for time averaging
itstep_stats_counter=0

# flow variables: velocities and pressure start at 1e-20, turbulence at one
u3d,v3d,w3d,p3d=(np.full(_cells,1e-20) for _ in range(4))
k3d,eps3d,om3d=(np.ones(_cells) for _ in range(3))
vis3d=np.ones(_cells)*viscos

fk3d=np.ones(_cells)

dpdx_old,dpdy_old,dpdz_old=(np.full(_cells,1e-20) for _ in range(3))

# face fluxes
convw=np.full(_west_faces,1e-20)
convs=np.full(_south_faces,1e-20)
convl=np.full(_low_faces,1e-20)

# accumulators for the statistics
(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,om3d_mean,eps3d_mean,
 uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,
 fk3d_mean,vis3d_mean)=(np.full(_cells,1e-20) for _ in range(13))

# coefficients, sources and gradients
(aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,apo3d,
 su3d,sp3d,dudx,dudy)=(np.full(_cells,1e-20) for _ in range(12))

# inlet fluctuation planes
usynt_inlet,vsynt_inlet,wsynt_inlet=(np.full((nj,nk),1e-20) for _ in range(3))

# compute Delta_max for LES/DES/PANS models
delta_max=np.maximum(np.maximum(vol/areas[:,1:],vol/areaw[1:,:]),dz)

itstep=iter=0

# user-defined initial fields
u3d,v3d,w3d,k3d,om3d,eps3d,vis3d=modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d)

# read data for restart
if restart:
   u3d,v3d,w3d,p3d,k3d,eps3d,om3d= read_restart_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# keep eps and k away from zero
eps3d=np.maximum(eps3d,1e-6)
k3d=np.maximum(k3d,1e-6)

# face values of the velocities and the pressure
u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(
    u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,
    u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(
    v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z,
    v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(
    w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z,
    w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(
    p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_z,
    p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_z_type)

# inlet values are only needed when the x direction is not cyclic
if not cyclic_x:
   u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw = modify_inlet()

convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)

(u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,
 dpdx_old,dpdy_old,dpdz_old)=update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)

itstep=iter=0

# initial viscosity for the active model; urfvis is set to 1 for the call
# (no under-relaxation) and restored right after
if kom or kom_des:
   urf_temp=urfvis
   urfvis=1
   vis3d=vist_kom(vis3d,k3d,om3d)
   urfvis=urf_temp

if pans or keps:
   if pans:
     fk3d=compute_fk(k3d,eps3d)
   # calceps_ls is called with zero production only to obtain fmu3d
   gen=np.zeros(_cells)
   itstep=0
   su3d,sp3d,fmu3d= calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy)
   urf_temp=urfvis
   urfvis=1
   vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
   urfvis=urf_temp

if smag:
   urf_temp=urfvis
   urfvis=1
   vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
   urfvis=urf_temp

iter=itstep=0

# residuals reported in the iteration log
residual_u=residual_v=residual_w=residual_p=0
residual_k=residual_eps=residual_om=0

######################### start of time stepping  #############################
# Each time step runs up to maxit global iterations of: u, v, w momentum ->
# pressure -> flux correction -> turbulence model, and leaves the iteration
# loop early once the scaled residuals drop below sormax.

for itstep in range(0,ntstep):

######################### start of global iteration process #############################

   for iter in range(0,maxit):

      start_time_iter = time.time()
# coefficients for velocities
      start_time = time.time()
# compute inlet fluctuations (once per time step, non-cyclic x only)
      if iter == 0 and not cyclic_x:
         u_bc_west,v_bc_west,w_bc_west,k_bc_west,eps_bc_west,om_bc_west,u3d_face_w,convw = modify_inlet()
      aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,1,scheme)
# u3d
# boundary conditions for u3d
      su3d,sp3d=bc(su3d,sp3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z, \
                   u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
      su3d,sp3d=calcu(su3d,sp3d,dpdx_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(u3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)

      if solver_vel == 'pyamg':
         u3d,residual_u=solve_pyamg(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_u,acrank_conv)
      else:
         u3d,residual_u=solve_3d(u3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_u,nsweep_vel,acrank_conv,solver_vel)
      print('time u',time.time()-start_time)


      start_time = time.time()
# v3d
# boundary conditions for v3d
      su3d,sp3d=bc(su3d,sp3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z, \
                   v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
      su3d,sp3d=calcv(su3d,sp3d,dpdy_old,p3d_face_w,p3d_face_s)
      ap3d,su3d=crank_nicol(v3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)
      if solver_vel == 'pyamg':
         v3d,residual_v=solve_pyamg(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_v,acrank_conv)
      else:
         v3d,residual_v=solve_3d(v3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_v,nsweep_vel,acrank_conv,solver_vel)
      print('time v',time.time()-start_time)


      start_time = time.time()
# w3d
# boundary conditions for w3d
      su3d,sp3d=bc(su3d,sp3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z, \
                   w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
      su3d,sp3d=calcw(su3d,sp3d,dpdz_old,p3d_face_l)

      ap3d,su3d=crank_nicol(w3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv)
      if solver_vel == 'pyamg':
         w3d,residual_w=solve_pyamg(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_w,acrank_conv)
      else:
         w3d,residual_w=solve_3d(w3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_w,nsweep_vel,acrank_conv,solver_vel)
      print('time w',time.time()-start_time)

# p3d
# restart the timer so that 'time p' does not include the w solve
      start_time = time.time()
      convw,convs,convl=conv(u3d,v3d,w3d,p3d_face_w,p3d_face_s,p3d_face_l)

      if not cyclic_x:
         convw,u_bc_east =modify_outlet(convw)


# RHS
# continuity error
      su3d=(convw[0:-1,:,:]-convw[1:,:,:]\
           +convs[:,0:-1,:]-convs[:,1:,:]\
           +convl[:,:,0:-1]-convl[:,:,1:])/acrank/dt[itstep]
# the pressure coefficients are built only once, on the very first iteration
      if iter == 0 and itstep == 0:
         aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d=calcp(convw,convs,convl)
         aw3d_p=aw3d
         as3d_p=as3d
         al3d_p=al3d

# NOTE(review): after the first iteration aw3d..ap3d hold the coefficients of
# the w equation when they are passed here; presumably solve_p keeps the
# pressure matrix from its first call -- confirm
      p3d=solve_p(p3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_p)
      print('time p',time.time()-start_time)

# correct u, v, w, p
      convw,convs,convl,p3d,u3d,v3d,w3d,su3d= correct_conv(u3d,v3d,w3d,p3d,aw3d_p,as3d_p,al3d_p)
# continuity residual: L1 norm of the remaining flux imbalance
      res_1d=su3d.flatten()
      residual_p=np.linalg.norm(res_1d,ord=1)

      u3d_face_w,u3d_face_s,u3d_face_l=compute_face_phi(u3d,u_bc_west,u_bc_east,u_bc_south,u_bc_north,u_bc_z,\
        u_bc_west_type,u_bc_east_type,u_bc_south_type,u_bc_north_type,u_bc_z_type)
      v3d_face_w,v3d_face_s,v3d_face_l=compute_face_phi(v3d,v_bc_west,v_bc_east,v_bc_south,v_bc_north,v_bc_z,\
        v_bc_west_type,v_bc_east_type,v_bc_south_type,v_bc_north_type,v_bc_z_type)
      w3d_face_w,w3d_face_s,w3d_face_l=compute_face_phi(w3d,w_bc_west,w_bc_east,w_bc_south,w_bc_north,w_bc_z,\
        w_bc_west_type,w_bc_east_type,w_bc_south_type,w_bc_north_type,w_bc_z_type)
      p3d_face_w,p3d_face_s,p3d_face_l=compute_face_phi(p3d,p_bc_west,p_bc_east,p_bc_south,p_bc_north,p_bc_z,\
        p_bc_west_type,p_bc_east_type,p_bc_south_type,p_bc_north_type,p_bc_z_type)

      if kom or kom_des:
         start_time = time.time()
         vis3d=vist_kom(vis3d,k3d,om3d)
# coefficients
         start_time = time.time()
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_turb)
# k
# boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_z, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_z_type)
         su3d,sp3d,gen,fk3d,comm_term=calck_kom(su3d,sp3d,k3d,om3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)

         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         if solver_turb == 'pyamg':
            k3d,residual_k=solve_pyamg(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_kom)
         else:
            k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_kom,acrank_conv_kom,solver_turb)
         k3d=np.maximum(k3d,1e-10)
         print('time k',time.time()-start_time)


         start_time = time.time()
# omega
# boundary conditions for om3d
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_turb)
         su3d,sp3d=bc(su3d,sp3d,om_bc_west,om_bc_east,om_bc_south,om_bc_north,om_bc_z, \
                   om_bc_west_type,om_bc_east_type,om_bc_south_type,om_bc_north_type,om_bc_z_type)
         su3d,sp3d= calcom(su3d,sp3d,om3d,gen,comm_term)
         ap3d,su3d=crank_nicol(om3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_kom)

         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=fix_omega()

         if solver_turb == 'pyamg':
            om3d,residual_om=solve_pyamg(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,acrank_conv_kom)
         else:
            om3d,residual_om=solve_3d(om3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_om,nsweep_kom,acrank_conv_kom,solver_turb)
         om3d=np.maximum(om3d,1e-10)

         print('time omega',time.time()-start_time)

      if smag:
         vis3d=vist_smag(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
      if wale:
         vis3d=vist_wale(u3d_face_w,u3d_face_s,u3d_face_l,v3d_face_w,v3d_face_s,v3d_face_l,w3d_face_w,w3d_face_s,w3d_face_l,vis3d)
      if pans or keps:
         start_time = time.time()
         vis3d=vist_pans(vis3d,k3d,eps3d,fmu3d)
# coefficients
         start_time = time.time()
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_k,scheme_turb)
# k
# boundary conditions for k3d
         su3d,sp3d=bc(su3d,sp3d,k_bc_west,k_bc_east,k_bc_south,k_bc_north,k_bc_z, \
                   k_bc_west_type,k_bc_east_type,k_bc_south_type,k_bc_north_type,k_bc_z_type)
         su3d,sp3d,gen,dudx,dudy=calck_ls(su3d,sp3d,k3d,eps3d,vis3d,u3d_face_w,u3d_face_s,v3d_face_w,v3d_face_s,w3d_face_l)

         ap3d,su3d=crank_nicol(k3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

# the solver gets the same time-blending weight that crank_nicol used to
# assemble ap3d and su3d (acrank_conv_keps)
         if solver_turb == 'pyamg':
            k3d,residual_k=solve_pyamg(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,acrank_conv_keps)
         else:
            k3d,residual_k=solve_3d(k3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_k,nsweep_keps,acrank_conv_keps,solver_turb)
# keep k away from zero
         k3d=np.maximum(k3d,1e-6)
         print('time k',time.time()-start_time)

         start_time  = time.time()


# eps
         aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d=coeff(convw,convs,convl,vis3d,prand_eps,scheme_turb)
# boundary conditions for eps3d
         su3d,sp3d=bc(su3d,sp3d,eps_bc_west,eps_bc_east,eps_bc_south,eps_bc_north,eps_bc_z, \
                   eps_bc_west_type,eps_bc_east_type,eps_bc_south_type,eps_bc_north_type,eps_bc_z_type)
         su3d,sp3d,fmu3d= calceps_ls(su3d,sp3d,k3d,eps3d,vis3d,gen,dudx,dudy)

         ap3d,su3d=crank_nicol(eps3d_old,aw3d,ae3d,as3d,an3d,al3d,ah3d,apo3d,su3d,sp3d,acrank_conv_keps)

# both solver branches use the eps convergence limit
         if solver_turb == 'pyamg':
            eps3d,residual_eps=solve_pyamg(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,acrank_conv_keps)
         else:
            eps3d,residual_eps=solve_3d(eps3d,aw3d,ae3d,as3d,an3d,al3d,ah3d,su3d,ap3d,convergence_limit_eps,nsweep_keps,acrank_conv_keps,solver_turb)

         print('time eps',time.time()-start_time)

# a negative eps is fatal: report where it occurs and stop
         if np.min(eps3d) < 0:
            print('eps < 0')
            print('[i,j,k]', np.where(eps3d < 0))
            sys.exit()


         if pans:
            fk3d=compute_fk(k3d,eps3d)


# scale residuals
      residual_u=residual_u/resnorm_vel
      residual_v=residual_v/resnorm_vel
      residual_w=residual_w/resnorm_vel
      residual_p=residual_p/resnorm_p
      residual_k=residual_k/resnorm_vel**2
      residual_eps=residual_eps/resnorm_vel**3
      residual_om=residual_om/resnorm_vel

# only u, v and continuity enter the convergence check
      resmax=np.max([residual_u ,residual_v,residual_p])

      print('-time step: %d, iter: %d, max residul=%10.2E, u=%10.2E, v=%10.2E,\
w=%10.2E, cont=%10.2E, k=%10.2E, eps=%10.2E, om =%10.2E\n\n'\
      % (itstep,iter, resmax,residual_u, residual_v, residual_w, residual_p, residual_k, residual_eps, residual_om))

      print('monitor --- -time step: %d, iter: %d, u=%10.2E, v=%10.2E, w=%10.2E, p=%10.2E, \
k=%10.2E, eps=%10.2E, om=%10.2E, vis=%10.2E,\n\n'\
      % (itstep,iter,u3d[imon,jmon,kmon],v3d[imon,jmon,kmon],w3d[imon,jmon,kmon],p3d[imon,jmon,kmon],\
          k3d[imon,jmon,kmon],eps3d[imon,jmon,kmon],om3d[imon,jmon,kmon],vis3d[imon,jmon,kmon]))


# extreme values for the log line below
      vismax=np.max(vis3d.flatten())/viscos
      umax=np.max(u3d.flatten())
      epsmin=np.min(eps3d.flatten())
      ommin=np.min(om3d.flatten())

      kmin=np.min(k3d.flatten())

# CFL numbers are reported on every tenth time step
      if itstep%10 == 0:
         cfl_x=np.abs(u3d)*dt[itstep]*areaw[1:,:,:]/vol
         cfl_y=np.abs(v3d)*dt[itstep]*areas[:,1:,:]/vol
         cfl_x_max=np.max(cfl_x)
         cfl_y_max=np.max(cfl_y)
         print('-time step: %d, cfl_x_max: %8.2E, cfl_y_max: %8.2E\n\n' % (itstep,cfl_x_max,cfl_y_max))


      print('-time step: %d, dt: %8.2E, iter: %d, umax=%8.2E, vismax = %8.2E, kmin = %8.2E, epsmin  %8.2E, ommin  %8.2E\n\n'\
      % (itstep,dt[itstep],iter, umax,vismax,kmin,epsmin,ommin))


# converged: at least min_iter iterations done and residual below sormax
      if iter >= min_iter-1 and resmax < sormax:

         break

######################### end of global iteration process #############################

# store the converged fields as the old time level
   u3d_old,v3d_old,w3d_old,k3d_old,eps3d_old,om3d_old,dpdx_old,dpdy_old,dpdz_old=\
     update(u3d,v3d,w3d,k3d,eps3d,om3d,p3d_face_w,p3d_face_s,p3d_face_l)
# save data every itstep_save time step
   if itstep%itstep_save == 0 and itstep >= itstep_start:
      save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                          vv3d_stress,ww3d_stress,uv3d_stress)
   if save and itstep%itstep_save == 0 and itstep > 0:
      save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# sample the statistics every itstep_stats time step once itstep_start is reached
   if itstep >= itstep_start and itstep % itstep_stats == 0:
      u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,vv3d_stress,ww3d_stress,uv3d_stress,\
          fk3d_mean,vis3d_mean= time_stats(u3d_mean,v3d_mean,w3d_mean,p3d_mean,k3d_mean,eps3d_mean,om3d_mean,uu3d_stress,\
          vv3d_stress,ww3d_stress,uv3d_stress,fk3d_mean,vis3d_mean)
# measured from the start of the last global iteration of this time step
   print('time one iteration',time.time()-start_time_iter)

######################### end of time stepping  #############################

# save data for restart
if save:
   save_data(u3d,v3d,w3d,p3d,k3d,eps3d,om3d)

# save time-averaged data
save_time_aver_data(u3d_mean,v3d_mean,w3d_mean,p3d_mean,eps3d_mean,om3d_mean,fk3d_mean,vis3d_mean,k3d_mean,uu3d_stress,\
                    vv3d_stress,ww3d_stress,uv3d_stress)

print('program reached normal stop')

