
def modify_init(u3d,v3d,w3d,k3d,om3d,eps3d,vis3d):
   """Set the case-specific initial fields.

   k3d and om3d are replaced by uniform (ni, nj, nk) arrays filled with
   0.01; all other fields pass through unchanged.  dist3d is read from
   module scope and appended to the returned tuple.
   """
   grid_shape = (ni, nj, nk)
   k3d = xp.full(grid_shape, 0.01)
   om3d = xp.full(grid_shape, 0.01)
   return u3d, v3d, w3d, k3d, om3d, eps3d, vis3d, dist3d

def modify_inlet():
   """Return the west (inlet) boundary data unchanged.

   Every value is read from module scope; this case applies no inlet
   modification.
   """
   west_bc = (u_bc_west, v_bc_west, w_bc_west, k_bc_west, eps_bc_west, om_bc_west)
   return west_bc + (u3d_face_w, convw)

def modify_conv(convw,convs,convl):
   """Hook for adjusting the face convection arrays; this case leaves them as is."""
   unchanged = (convw, convs, convl)
   return unchanged

def modify_u(su3d,sp3d):
   """Add the case-specific u-momentum source terms.

   On every call this
     1. on the first outer iteration of a time step (iter == 0), writes iter
        and u3d[-1,j,0] for j = 5, 10, 20, 30, 40, 50, 100 to
        'u-iteration.dat' (created at itstep 0, appended afterwards);
     2. subtracts a wall-function shear force tauw*area from su3d in the
        last j-row, with ustar = cmu**0.25*k**0.5 and tauw = ustar**2;
     3. updates the driving body force beta so that
        uin = sum(convw[0,:,:])/(y2d[0,-1]-y2d[0,0])/zmax approaches
        uin_target = 1, and adds beta*vol to su3d.

   beta, beta_old, uin and uin_old are module-level globals that carry the
   controller state between calls.  On the very first call (itstep == 0
   and iter == 0) beta is restarted from 'beta.dat'.  Every 5th time step
   beta is written back to 'beta.dat' (and, on the GPU path, uin_old to
   'uin.dat').

   Args:
      su3d: explicit source term array for u.
      sp3d: implicit source term array for u (returned unchanged).

   Returns:
      (su3d, sp3d)
   """

   global beta_old, uin_old, uin, beta

   # probe u3d on the first outer iteration of each time step
   if itstep == 0 and iter == 0:
      xp.savetxt('u-iteration.dat', xp.c_[iter,u3d[-1,5,0],u3d[-1,10,0],u3d[-1,20,0],u3d[-1,30,0],u3d[-1,40,0],\
           u3d[-1,50,0],u3d[-1,100,0]])
   elif itstep > 0 and iter == 0:
      with open('u-iteration.dat','ab') as f:
         xp.savetxt(f,xp.c_[iter,u3d[-1,5,0],u3d[-1,10,0],u3d[-1,20,0],u3d[-1,30,0],u3d[-1,40,0],\
           u3d[-1,50,0],u3d[-1,100,0]])

   # north wall: wall-function shear stress tauw = ustar**2
   ustar=cmu**0.25*k3d[:,-1,:]**0.5
   tauw=ustar**2
   # NOTE(review): this acts on the last j-row but multiplies by
   # areas[:,0,:]; that equals the north-face area only if the first and
   # last j-face areas coincide -- confirm.  tauw carries no sign, i.e.
   # this assumes u3d > 0 next to the wall.
   su3d[:,-1,:]=su3d[:,-1,:]-tauw*areas[:,0,:]

   # compute beta
   if itstep  == 0 and iter == 0:
      # restart the controller from the last saved body force;
      # beta is kept as a length-2 array from here on
      beta = xp.loadtxt('beta.dat')
      beta = xp.ones(2)*beta
      # uin is deliberately not restored from 'uin.dat': it is recomputed
      # from convw just below, and 'uin.dat' is only written on the GPU path,
      # so loading it here could only fail (fresh CPU run) or be discarded.

   uin=xp.sum(convw[0,:,:].flatten())/(y2d[0,-1]-y2d[0,0])/zmax
   uin_target = 1

   # freeze the controller state at the start of each time step
   if iter ==0:
      uin_old = uin
      beta_old = beta

   # NOTE(review): within a time step this equals
   # (uin_target - uin_old) + (uin - uin_old), i.e. it pushes beta further
   # in the direction uin is already moving.  The common mass-flux
   # controller form is uin_target - 2*uin + uin_old (damping on the change
   # of uin).  Confirm the intended sign before changing it.
   beta = beta_old + 0.001*(uin_target - 2*uin_old + uin)

   # checkpoint the controller state every 5th time step
   if itstep % 5 ==0 and iter == 0:
      if gpu:
         # cuda can only save an array
         beta1 = xp.ones(2)*beta
         xp.savetxt('beta.dat', xp.c_[beta1])
         uin1 = xp.ones(2)*uin_old
         xp.savetxt('uin.dat', xp.c_[uin1])
      else:
         xp.savetxt('beta.dat', xp.c_[beta])

   # uniform driving force.  beta has shape (2,) here, so this relies on
   # broadcasting against vol -- TODO confirm vol's last axis has length 1 or 2
   su3d=su3d+beta*vol

   print('uin',uin)
   print('uin_old',uin_old)
   print('beta',beta)

   return su3d,sp3d

def modify_v(su3d,sp3d):
   """Hook for extra v-momentum sources; this case adds none."""
   result = (su3d, sp3d)
   return result


def modify_w(su3d,sp3d):
   """Hook for extra w-momentum sources; this case adds none."""
   result = (su3d, sp3d)
   return result

def modify_k(su3d,sp3d,gen):
   """Hook for extra k sources; this case adds none.

   `gen` is accepted for interface compatibility and ignored.  su3d and
   sp3d are returned unchanged together with a zero-filled (ni, nj, nk)
   array as comm_term.
   """
   return su3d, sp3d, xp.zeros((ni, nj, nk))

def modify_eps(su3d,sp3d):
   """Hook for extra eps sources; this case adds none."""
   result = (su3d, sp3d)
   return result

def modify_om(su3d,sp3d,comm_term):
   """Hook for extra omega sources; this case adds none.

   comm_term is accepted for interface compatibility but not used.
   """
   result = (su3d, sp3d)
   return result

def fix_omega():
   """Fix omega in the first (j=0) and last (j=-1) j-rows.

   In each of those rows all neighbour coefficients (aw, ae, as, an, al, ah)
   are zeroed.  ap3d is set to the global maximum of ap3d, and su3d is set to
   ap_max*om_wall.  The discrete equation there therefore reduces to
   ap*om = su, i.e. om = om_wall.  This assumes the solver adds nothing
   further (e.g. sp3d) to ap3d afterwards -- verify with the caller.

   * j = 0:  om_wall = om_bc_south (module-level value).
   * j = -1: wall-function value derived from the current k:
        ustar = cmu**0.25*k**0.5
        eps   = ustar**3/(kappa*dy)
        k     = cmu**(-0.5)*ustar**2
        om    = eps/(cmu*k) = ustar/(kappa*dy*cmu**0.5)
     which gives vist = k/om = ustar*kappa*dy.
     dy is taken as dist3d[0,-1,0], a single value for the whole row.

   Returns:
      The module-level arrays aw3d, ae3d, as3d, an3d, al3d, ah3d, ap3d,
      su3d, sp3d (modified in place; sp3d untouched).
   """

   if itstep == 0 and iter == 0:
      print('fix_omega called')

   # south wall: prescribed omega
   j = 0
   for coef in (aw3d, ae3d, as3d, an3d, al3d, ah3d):
      coef[:,j,:] = 0
   ap_max = xp.max(ap3d)
   ap3d[:,j,:] = ap_max
   su3d[:,j,:] = ap_max*om_bc_south

   # north wall: wall-function omega = ustar/(kappa*dy*cmu**0.5)
   j = -1
   ustar = cmu**0.25*k3d[:,j,:]**0.5
   dy = dist3d[0,j,0]
   for coef in (aw3d, ae3d, as3d, an3d, al3d, ah3d):
      coef[:,j,:] = 0
   ap_max = xp.max(ap3d)
   ap3d[:,j,:] = ap_max
   su3d[:,j,:] = ap_max*ustar/kappa/dy/cmu**0.5

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d


def modify_outlet(convw):
   """Return convw and the module-level east boundary u value, both unchanged."""
   east_u = u_bc_east
   return convw, east_u

def modify_fk(fk3d):
   """Hook for adjusting fk3d; this case returns it unchanged."""
   result = fk3d
   return result

def aver_iddes(denom,f_e,f_e1,f_e2,f_d,f_dt,f_b,l_c,l_tilde,\
    denom_mean,f_e_mean,f_e1_mean,f_e2_mean,f_d_mean,f_dt_mean,f_b_mean,l_c_mean,l_tilde_mean,\
        f_l_mean,f_t_mean):
   """Accumulate axis-(0,2) averages of the IDDES blending quantities.

   Each *_mean accumulator has xp.mean(field, axis=(0,2)) of its field
   added to it.

   NOTE(review): f_l and f_t are not parameters; they are read from module
   scope.

   Returns the updated accumulators in this order (differs from the
   argument order; callers rely on it):
      f_e_mean, f_d_mean, f_dt_mean, f_b_mean, denom_mean, f_e1_mean,
      f_e2_mean, l_c_mean, l_tilde_mean, f_l_mean, f_t_mean
   """

   def plane_mean(field):
      # average over axes 0 and 2, keeping axis 1
      return xp.mean(field, axis=(0,2))

   f_e_mean = f_e_mean + plane_mean(f_e)
   f_e1_mean = f_e1_mean + plane_mean(f_e1)
   f_e2_mean = f_e2_mean + plane_mean(f_e2)
   f_d_mean = f_d_mean + plane_mean(f_d)
   l_c_mean = l_c_mean + plane_mean(l_c)
   l_tilde_mean = l_tilde_mean + plane_mean(l_tilde)
   f_dt_mean = f_dt_mean + plane_mean(f_dt)
   f_b_mean = f_b_mean + plane_mean(f_b)
   denom_mean = denom_mean + plane_mean(denom)
   f_l_mean = f_l_mean + plane_mean(f_l)
   f_t_mean = f_t_mean + plane_mean(f_t)

   return (f_e_mean, f_d_mean, f_dt_mean, f_b_mean, denom_mean,
           f_e1_mean, f_e2_mean, l_c_mean, l_tilde_mean, f_l_mean, f_t_mean)

def fix_eps():
   """Fix eps = 2*viscos*k/dist3d**2 in the first (j=0) and last (j=-1) j-rows.

   For each of those rows, all neighbour coefficients are zeroed.  ap3d is
   set to the current global maximum of ap3d, and su3d to
   ap_max*2*viscos*k3d/dist3d**2.  The row therefore solves to that value
   (assuming no further sp3d contribution to ap3d -- verify with the
   caller).

   Returns the module-level arrays aw3d, ae3d, as3d, an3d, al3d, ah3d,
   ap3d, su3d, sp3d (modified in place; sp3d untouched).
   """
   for j in (0, -1):   # south wall first, then north wall
      for coef in (aw3d, ae3d, as3d, an3d, al3d, ah3d):
         coef[:,j,:] = 0
      ap_max = xp.max(ap3d)
      ap3d[:,j,:] = ap_max
      su3d[:,j,:] = ap_max*2*viscos*k3d[:,j,:]/dist3d[:,j,:]**2

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d

def fix_k():
   """Fix k in the last j-row (north wall) from a solved friction velocity.

   The starting guess is ustar = cmu**0.25*k**0.5 from the current k3d.  It
   is refined by compute_ustar_reich_solve with |u3d| in that row.  Then
   kwall = ustar**2/cmu**0.5 is imposed by zeroing the neighbour
   coefficients, setting ap3d to the global maximum of ap3d and setting
   su3d = ap_max*kwall.  Min/max of ustar and kwall are printed on every call.

   NOTE(review): the first argument passed to compute_ustar_reich_solve is
   dy/viscos although that helper names it `wdist`; confirm what
   solve_ustar_reich expects.

   Returns the module-level arrays aw3d, ae3d, as3d, an3d, al3d, ah3d,
   ap3d, su3d, sp3d (modified in place; sp3d untouched).
   """
   if itstep == 0 and iter == 0:
      print('fix_k called')

   # north wall
   ustar_guess = cmu**0.25*k3d[:,-1,:]**0.5
   dy = dist3d[0,-1,0]
   velabs = abs(u3d[:,-1,:])
   ustar = compute_ustar_reich_solve(dy/viscos, velabs, ustar_guess)
   kwall = cmu**(-0.5)*ustar**2

   print('north wall: dy,ustar_min,max', dy, xp.min(ustar), xp.max(ustar))
   print('north wall: kwall_min,max', xp.min(kwall), xp.max(kwall))

   for coef in (aw3d, ae3d, as3d, an3d, al3d, ah3d):
      coef[:,-1,:] = 0
   ap_max = xp.max(ap3d)
   ap3d[:,-1,:] = ap_max
   su3d[:,-1,:] = ap_max*kwall

   return aw3d,ae3d,as3d,an3d,al3d,ah3d,ap3d,su3d,sp3d



def modify_vis(vis3d):
   """Hook for adjusting the viscosity field; this case returns it unchanged."""
   result = vis3d
   return result


def compute_ustar_reich_solve(wdist,velabs,ustar):
   """Solve for the friction velocity entry by entry with scipy's newton.

   The residual is solve_ustar_reich(ustar, velabs, wdist), defined
   elsewhere.  ustar (initial guess) and velabs are flattened and handed
   to scipy.optimize.newton.  With an array x0 it iterates all entries
   together, and since no fprime is given it uses the secant method.

   Args:
      wdist: passed through unchanged as the last extra argument of
         solve_ustar_reich.
      velabs: velocity magnitude per wall cell.
      ustar: initial guess with the same number of elements as velabs.

   Returns:
      The solved ustar reshaped to (ni, nk).
   """
   from scipy.optimize import newton

   ustar = newton(solve_ustar_reich, x0=ustar.flatten(),
                  args=(velabs.flatten(), wdist))

   return xp.reshape(ustar, (ni, nk))

def compute_ustar(wdist,velabs,ustar,elog,n):
   """Friction velocity from the log law with a viscous-sublayer fallback.

   Performs n fixed-point updates of
      ustar = kappa*velabs/ln(max(elog*ustar*wdist/viscos, 10))
   Then, wherever yplus = ustar*wdist/viscos is <= 11.69, it replaces ustar
   with the linear-sublayer value sqrt(viscos*velabs/wdist).
   """
   for _ in range(n):
      log_arg = xp.maximum(elog*ustar*wdist/viscos, 10.)
      ustar = kappa*velabs/xp.log(log_arg)

   yplus = ustar*wdist/viscos
   sublayer_ustar = (viscos*velabs/wdist)**0.5
   return xp.where(yplus <= 11.69, sublayer_ustar, ustar)

