
import cupy as xp
import time,random,sys
#from scipy.signal import welch, hann
import math


def synt_fluct(nmodes,it,sli,yp,zp,uv_rans,visc,jmirror,dmin_synt):
   """Generate one time step of anisotropic synthetic velocity fluctuations.

   The field is a sum of ``nmodes`` random Fourier modes whose amplitudes
   follow a von Karman spectrum. The random wavenumber/direction unit vectors
   are rotated with the eigenvector matrix read from ``R.dat`` and stretched
   with the eigenvalues read from ``a.dat``. The resulting (u, v, w) are
   rotated back to x-y-z and scaled point-wise by
   ``sqrt(|uv_rans| / uvmean_time)``, where ``uvmean_time`` is the running
   time average of the unscaled domain-mean u*v. This makes the shear-stress
   magnitude follow the RANS profile.

   State lives in module-level globals. Everything that does not change
   between time steps is set up only when ``it == 0``: the R/a data, the 2D
   grid, the spectrum parameters and the normalised uv_rans profile. The
   first call must therefore use ``it == 0``. The running sums used for the
   time averages (``uv_synt_mean``, ``uvmean_check``) are globals and are
   reset at ``it == 0``.

   Parameters
   ----------
   nmodes : int
       Number of Fourier modes.
   it : int
       Time-step index. When ``it == 0`` the function:

       - reads ``R.dat`` and ``a.dat`` from the working directory;
       - builds the grid;
       - reseeds the RNG with seed 2.

       ``it + 1`` is used as the number of samples in the running averages.
   sli : float
       Turbulent length scale. It sets ``epsm = up**3/sli`` and the
       wavenumber ``wnre``.
   yp, zp : 1-D arrays
       Grid coordinates in y and z. Their smallest spacing, bounded below by
       ``dmin_synt``, sets the highest wavenumber ``2*pi/dxmin``.
   uv_rans : 1-D array of length ``len(yp)``
       RANS shear-stress profile in y. Only its absolute value is used. If
       every entry equals 1, the velocity scale ``up`` is set to 1.
   visc : float
       Kinematic viscosity. It enters the high-wavenumber cut-off ``wnreta``.
   jmirror : int
       If > 0, the v fluctuation changes sign for rows ``j >= jmirror``. The
       mean shear stress is computed before this flip.
   dmin_synt : float
       Lower bound for the grid step used to compute the highest wavenumber.

   Returns
   -------
   tuple of three arrays of shape (len(yp), len(zp))
       ``(usynt_aniso, vsynt_aniso, wsynt_aniso)``, the scaled anisotropic
       fluctuations in the x-y-z frame.
   """
   global dxmin,amp,epsm,wnr1,wew1fct,yp2d_synt,zp2d_synt,xp2d_synt,nj,up
   global e,kxio,kyio,sxio,syio,utn,tfunk,wnre,dkn,arg1,arg2,arg3,arg,fi,psi,teta,alfa,wnr,kx,ky,kz,wnrn,\
          r11,r12,r13,r21,r22,r23,r31,r32,r33,a11,a22,a33,wnreta,uv_synt_mean,uvmean_check,a11i,a22i,a33i,\
          xp2d_wave,yp2d_wave,zp2d_wave,uv_rans_non,uv_rans_max
   global sx,rk,kxi,usynt_wave,usynt1,usynt,sy,vsynt,yp2d_org,usynt_aniso,scale_2

   # only the magnitude of the RANS shear stress is used
   uv_rans=xp.abs(uv_rans)

   #=========================== chapter 1: set-up (first time step only) =========
   if it == 0:

      # reset the running sums used for the time averages further down
      uvmean_check=0
      uv_synt_mean=0

      # anisotropic fluctuations: R.dat holds the 3x3 eigenvector matrix,
      # a.dat the three eigenvalues ('%' starts a comment in both files)
      R=xp.genfromtxt("R.dat", dtype=None,comments="%")
      r11=R[0,0]
      r12=R[0,1]
      r13=R[0,2]
      r21=R[1,0]
      r22=R[1,1]
      r23=R[1,2]
      r31=R[2,0]
      r32=R[2,1]
      r33=R[2,2]

      A=xp.genfromtxt("a.dat", dtype=None,comments="%")
      a11=A[0]
      a22=A[1]
      a33=A[2]

      amp=1.452762113
      # ratio of k_e and the smallest wavenumber
      wew1fct=2

      # in log region:  k/uvmax=3.3  => up=(3.3*uvmax)**0.5
      if xp.all(uv_rans==1):
         up=1
      else:
         up=(3.3*xp.max(uv_rans))**0.5

      # dissipation rate
      epsm=up**3/sli

      # number of grid points in y, z
      nj=len(yp)
      nk=len(zp)

      # 2D (nj, nk) coordinate planes; x is held at the constant xpp
      zp2d_synt=xp.repeat(zp[None,:], repeats=nj, axis=0)
      yp2d_synt=xp.repeat(yp[:,None], repeats=nk, axis=1)
      xpp=0.5

      # transform to principal coord. directions. Every rotated coordinate
      # must be built from the ORIGINAL y and z, so keep untouched copies
      # before yp2d_synt/zp2d_synt are overwritten.
      yp2d_org=yp2d_synt
      zp2d_org=zp2d_synt
      xp2d_synt=r11*xpp+r21*yp2d_org+r31*zp2d_org
      yp2d_synt=r12*xpp+r22*yp2d_org+r32*zp2d_org
      zp2d_synt=r13*xpp+r23*yp2d_org+r33*zp2d_org

      # smallest grid step, but don't let it be smaller than dmin_synt
      dminy=xp.min(xp.diff(yp))
      dminz=xp.min(xp.diff(zp))
      dxmin=min(dminy,dminz)
      dxmin=max(dxmin,dmin_synt)

      # fixed seed: the sequence of random modes is the same in every run
      xp.random.seed(2)

      # highest wave number
      wnrn=2.*math.pi/dxmin

      # k_e (related to peak energy wave number)
      wnre=9.*math.pi*amp/(55.*sli)

      # wavenumber used in the viscous expression (high wavenumbers) in the von Karman spectrum
      wnreta=(epsm/visc**3)**0.25

      # smallest wavenumber
      wnr1=wnre/wew1fct

      # wavenumber step
      # NOTE(review): the linspace below has spacing (wnrn-wnr1)/(nmodes-1),
      # not dkn. Kept as-is because the amplitude factor is cancelled by the
      # scale_2 normalisation further down -- confirm this is intended.
      dkn=(wnrn-wnr1)/nmodes

      # nmodes equally spaced wavenumbers from wnr1 to wnrn (inclusive)
      wnr=xp.linspace(wnr1,wnrn,nmodes)

      # invert the eigenvalue matrix (anisotropic)
      a11i=1/a11
      a22i=1/a22
      a33i=1/a33

      # non-dimensional profile sqrt(uv_rans/uv_rans_max), made 2D (nj, nk);
      # 1e-10 guards against division by zero
      uv_rans_max=xp.max(uv_rans)
      uv_rans_non=(uv_rans/(uv_rans_max+1e-10))**0.5
      uv_rans_non=xp.repeat(uv_rans_non[:,None], repeats=nk, axis=1)


      print(f"\n{'sli: '} {sli}, {'visc: '}{visc:.2e}, {'nmodes: '}{nmodes}, {'dxmin: '}{dxmin:.3e}, {'dkn: '}{dkn:.3e}, {'dmin_synt: '}{dmin_synt:.3e}")

      print(f"\n{'wnre: '} {wnre:.2e}, {'wnr1: '}{wnr1:.2e}, {'epsm: '}{epsm:.3e}, {'wnrn: '}{wnrn:.3e}")

      print(f"\n{'eigenvalue 1, 2 and 3: '}{a11:.3e}, {a22:.3e}, {a33:.3e}")
      print(f"\n{'eigenvector R11, R12 and R13: '}{r11:.3e}, {r12:.3e}, {r13:.3e}")
      print(f"\n{'eigenvector R21, R22 and R23: '}{r21:.3e}, {r22:.3e}, {r23:.3e}")
      print(f"\n{'eigenvector R31, R32 and R33: '}{r31:.3e}, {r32:.3e}, {r33:.3e}\n")

   #=========================== chapter 2: random angles (every time step) =======

   # one set of random angles per mode (call order fixes the RNG sequence)
   fi = xp.random.uniform(0.,2.*math.pi,nmodes)
   psi = xp.random.uniform(0.,2.*math.pi,nmodes)
   alfa = xp.random.uniform(0.,2.*math.pi,nmodes)
   ang = xp.random.uniform(0.,1,nmodes)
   # cos(teta)=1-2*ang is uniform in [-1,1], so the wavenumber directions are
   # uniformly distributed on the unit sphere
   teta=xp.arccos(1.-ang/0.5)

   print('time step no,',it)

   # unit wavenumber vector from the random angles
   kxio=xp.sin(teta)*xp.cos(fi)
   kyio=xp.sin(teta)*xp.sin(fi)
   kzio=xp.cos(teta)

   # sigma (s=sigma) from random angles. sigma is the unit direction which gives the direction
   # of the synthetic velocity vector (u, v, w)
   sxio=xp.cos(fi)*xp.cos(teta)*xp.cos(alfa)-xp.sin(fi)*xp.sin(alfa)
   syio=xp.sin(fi)*xp.cos(teta)*xp.cos(alfa)+xp.cos(fi)*xp.sin(alfa)
   szio=-xp.sin(teta)*xp.cos(alfa)

   #=========================== chapter 3: synthetic field =======================

   # rotate wavenumber and sigma unit vectors to the principal directions
   # (vectorized over all modes, no explicit loop)
   kxi=r11*kxio+r21*kyio+r31*kzio
   kyi=r12*kxio+r22*kyio+r32*kzio
   kzi=r13*kxio+r23*kyio+r33*kzio

   sxi=r11*sxio+r21*syio+r31*szio
   syi=r12*sxio+r22*syio+r32*szio
   szi=r13*sxio+r23*syio+r33*szio

   # stretch with the eigenvalues: sigma by sqrt(a), wavenumber by 1/sqrt(a)
   sx=a11**0.5*sxi
   sy=a22**0.5*syi
   sz=a33**0.5*szi

   kx=kxi*wnr*a11i**0.5
   ky=kyi*wnr*a22i**0.5
   kz=kzi*wnr*a33i**0.5
   rk=xp.sqrt(kx**2+ky**2+kz**2)

   # phase k.x + psi on the (nj, nk, nmodes) grid
   xp2d_wave=xp.repeat(xp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg1=xp2d_wave*kx

   yp2d_wave=xp.repeat(yp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg2=yp2d_wave*ky

   zp2d_wave=xp.repeat(zp2d_synt[:,:,None], repeats=nmodes, axis=2)
   arg3=zp2d_wave*kz

   arg=arg1+arg2+arg3+psi

   tfunk=xp.cos(arg)

   # von Karman spectrum
   e=amp/wnre*(wnr/wnre)**4/((1.+(wnr/wnre)**2)**(17./6.))*xp.exp(-2*(wnr/wnreta)**2)

   # include only wavenumbers for which rk < wnrn
   e=xp.where(rk < wnrn,e,0)

   utn=xp.sqrt(e*up**2*dkn)

   # sum over all wavenumbers => synthetic velocity field
   usynt=xp.sum(2.*utn*tfunk*sx,axis=2)
   vsynt=xp.sum(2.*utn*tfunk*sy,axis=2)
   wsynt=xp.sum(2.*utn*tfunk*sz,axis=2)

   # transform back to x-y-z  => anisotropic fluct
   usynt_aniso=r11*usynt+r12*vsynt+r13*wsynt
   vsynt_aniso=r21*usynt+r22*vsynt+r23*wsynt
   wsynt_aniso=r31*usynt+r32*vsynt+r33*wsynt

   # mean shear stress (must be computed before mirroring)
   uv=xp.mean(usynt_aniso*vsynt_aniso)

   # mirror vfluct
   if jmirror > 0:
      vsynt_aniso[jmirror:,:]=-vsynt_aniso[jmirror:,:]

   # running time average of the unscaled mean shear stress
   uv_synt_mean=uv_synt_mean+uv
   uvmean_time=xp.abs(uv_synt_mean)/(it+1)
   print('uvmean_time',uvmean_time)

   # scale all fluctuations with uv_rans
   scale_2=(uv_rans_max/uvmean_time)**0.5*uv_rans_non
   usynt_aniso=usynt_aniso*scale_2
   vsynt_aniso=vsynt_aniso*scale_2
   wsynt_aniso=wsynt_aniso*scale_2

   # running sum of the scaled shear stress, averaged over z
   uvmean_check=uvmean_check+xp.mean(usynt_aniso*vsynt_aniso,axis=1)

   # peak of uv_rans
   j=xp.where(uv_rans == xp.amax(uv_rans))

   # check peak
   print('synt: uvmean',xp.abs(uvmean_check[j])/(it+1),'uv_rans_max=',xp.max(xp.abs(uv_rans)),'at j=',j)

   return usynt_aniso,vsynt_aniso,wsynt_aniso
