import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from random import randrange
from joblib import dump, load

class MyNet(nn.Module):
    """Fully connected regressor: 2 inputs -> 10 -> 10 -> 1 output, ReLU after the first two layers."""

    def __init__(self):
        super().__init__()
        # NOTE: keep these attribute names and their creation order unchanged.
        # forward() looks them up by name, parameter order/initialisation follows
        # creation order, and the .pth files loaded in modify_PINN are presumably
        # pickled MyNet instances restored by attribute name (TODO confirm).
        self.input = nn.Linear(2, 10)
        self.hidden1 = nn.Linear(10, 10)
        self.hidden2 = nn.Linear(10, 1)

    def forward(self, x):
        """Map a (batch, 2) tensor to a (batch, 1) tensor."""
        relu = nn.functional.relu
        hidden = relu(self.input(x))
        hidden = relu(self.hidden1(hidden))
        return self.hidden2(hidden)

def modify_init(u2d,v2d,k2d,om2d,eps2d,vis2d):
    """Set the initial field for the whole domain.

    The west-inlet velocity profile ``u_bc_west`` is copied to every i-station,
    k and omega are set to one, and vis2d becomes k/omega plus the molecular
    viscosity. v2d and eps2d are passed through unchanged; the module-level
    ``dist`` is appended to the returned tuple.
    """
    # one copy of the inlet profile per i-station -> shape (ni, len(u_bc_west))
    u2d = np.repeat(u_bc_west[np.newaxis, :], ni, axis=0)
    k2d = np.ones((ni, nj))
    om2d = np.ones((ni, nj))

    # total viscosity = turbulent part (k/omega) + molecular part
    vis2d = k2d / om2d + viscos

    return u2d, v2d, k2d, om2d, eps2d, vis2d, dist

def modify_inlet():
    """Return the module-level west-boundary (inlet) quantities unchanged.

    NOTE(review): om_bc_west occupies both the 4th and the 5th slot of the
    returned tuple. If the caller expects a different inlet quantity in one of
    them (e.g. an eps value), confirm this duplication is intended.
    """
    inlet_values = (u_bc_west, v_bc_west, k_bc_west, om_bc_west, om_bc_west, u2d_face_w, convw)
    return inlet_values

def modify_conv(convw,convs):
    """Return the convection arrays unchanged.

    For fully-developed channel flow the convection terms are zero; this hook
    does not alter them (modify_outlet is where convw is replaced by zeros).
    """
    unchanged = (convw, convs)
    return unchanged

def modify_u(su2d,sp2d):
    """Add the driving pressure-gradient source to the u equation and print wall shear stress.

    Args:
        su2d: explicit source array of the u equation.
        sp2d: implicit source array; returned unchanged.

    Returns:
        (su2d + vol, sp2d). The caller's su2d array is not modified in place.
    """
    # Unit driving pressure gradient: its source contribution is `vol`
    # (presumably the cell volumes -- confirm against the solver).
    su2d = su2d + vol

    # NOTE(review): an earlier version assigned zero arrays to *local* names
    # aw2d/ae2d here to "switch off" x-direction convection and diffusion. As
    # locals they were never read and had no effect, so they were removed. If the
    # solver's own coefficient arrays really need zeroing, that must target those
    # arrays (e.g. via `global aw2d, ae2d` and in-place assignment) -- verify the
    # solver's call order before adding it.

    # For this flow the wall shear stress must equal one, because it balances the
    # unit driving pressure gradient. Printing it every iteration is therefore a
    # convenient convergence indicator.
    # x2d[-1,0] is presumably the channel length used to average along the wall.
    tauw_south=viscos*np.sum(as_bound*u2d[:,0])/x2d[-1,0]
    tauw_north=viscos*np.sum(an_bound*u2d[:,-1])/x2d[-1,0]

    print(f"{'tau wall, south: '} {tauw_south:.3f},{'  tau wall, north: '} {tauw_north:.3f}")

    return su2d,sp2d

def modify_v(su2d,sp2d):
    """Hook for the v equation sources; no modification for this case."""
    sources = (su2d, sp2d)
    return sources

def modify_p(su2d,sp2d):
    """Hook for the pressure-correction sources; no modification for this case."""
    sources = (su2d, sp2d)
    return sources

def modify_k(su2d,sp2d):
    """Log k at fixed j-probes on the last i-row to 'k-iteration.dat'; sources unchanged.

    Each call writes one line: the iteration counter followed by
    k2d[-1, j] for j in (5, 10, 20, 30, 40, 50, 58). The file is created or
    overwritten when iter == 0 and appended to on every other call.
    """
    probe_js = (5, 10, 20, 30, 40, 50, 58)
    row = np.c_[(iter, *(k2d[-1, jj] for jj in probe_js))]

    if iter != 0:
        with open('k-iteration.dat','ab') as fh:
            np.savetxt(fh, row)
    else:
        np.savetxt('k-iteration.dat', row)

    return su2d, sp2d

def modify_om(su2d,sp2d):
    """Hook for the omega equation sources; no modification for this case."""
    sources = (su2d, sp2d)
    return sources

def modify_outlet(convw):
    """Return a fresh all-zero ``convw`` array of shape (ni+1, nj).

    The incoming ``convw`` is ignored: for fully-developed channel flow the
    convection terms are zero.
    """
    zero_conv = np.zeros((ni + 1, nj))
    return zero_conv

def fix_omega():
    """Fix omega at the j=0 cells to ``om_bc_south``.

    The neighbour coefficients of the j=0 row are zeroed, ap is set to 1 and
    su to om_bc_south. Presumably this yields omega = om_bc_south there,
    assuming the solver does not re-add sp2d afterwards (sp2d is untouched).
    The module-level arrays are modified in place and also returned.
    """
    for neighbour_coeff in (aw2d, ae2d, as2d, an2d):
        neighbour_coeff[:, 0] = 0
    ap2d[:, 0] = 1
    su2d[:, 0] = om_bc_south

    return aw2d, ae2d, as2d, an2d, ap2d, su2d, sp2d

def modify_vis(vis2d):
    """Hook for the turbulent-viscosity field; returned unchanged for this case."""
    unchanged = vis2d
    return unchanged


def fix_k():
    """Return the module-level coefficient arrays unchanged (no k-specific fix for this case)."""
    coefficients = (aw2d, ae2d, as2d, an2d, ap2d, su2d, sp2d)
    return coefficients


def modify_PINN(prand_k_ML, c_k_ML, c_omega_2_ML):
    """Predict prand_k, c_k and c_omega_2 with neural networks and return their running means.

    First call (``iter == 0``):
        Load the three trained networks, the two input scalers (shared by all
        networks) and the min/max ranges of the training data from ``name``.
        Initialise the running means from the profiles that an earlier run
        saved in ``y-prand_k-c_k-c_omega_2-CFD.txt``.

    Every call:
        1. Compute the network inputs ``uv`` and ``vist_over_y`` from the
           current field and clip them to their training ranges.
        2. Scale the inputs and evaluate the networks.
        3. Clip each prediction to its training range.
        4. Blend each prediction into an exponential running mean.

        Diagnostics are printed on every call. The means at [0, j] (j = 30)
        are written to ``prand_k-c_k-c_omega_2.dat``. Every 500 iterations the
        instantaneous i = 0 profiles are saved for restart.

    The incoming argument values are never read; the names are reused locally.

    Relies on solver globals: iter, ni, nj, viscos, yp2d, u2d, vis2d,
    u2d_face_w, u2d_face_s, dphidy and np.

    Returns:
        prand_k_ML_mean, c_k_ML_mean, c_omega_2_ML_mean: running means, shape (ni, nj).
    """
    import torch
    from joblib import load

    global prand_k_ML_mean, c_k_ML_mean, c_omega_2_ML_mean, a_int, b_int
    global scaler_vist_over_y,scaler_uv
    global prand_k_ML_min,prand_k_ML_max,c_k_ML_min,c_k_ML_max,c_omega_2_ML_min,c_omega_2_ML_max
    global uv_min, uv_max,vist_over_y_min,vist_over_y_max
    global NN_c_k, NN_c_omega_2, NN_prand_k
    global prand_k_ML_old, c_k_ML_old, c_omega_2_ML_old, uv1, vist_over_y1, c_k_ML1, prand_k_ML1, c_omega_2_ML1, \
        ustar1, yplus2d1, jv11

    if iter == 0:
        print('NN modify_NN called')

        # NOTE: torch.load(weights_only=False) and joblib.load unpickle arbitrary
        # objects -- only load trusted files from this directory.
        name = '/chalmers/users/lada/noback/pycalc-les/pytorch-k-eq/'

        # c_k network, the two input scalers (reused for the other networks) and
        # the min/max ranges of uv, vist_over_y and c_k
        NN_c_k = torch.load(name + 'model-neural-k-omega-c_k-vist_over_y-and-uv_tot.pth', weights_only=False)
        scaler_vist_over_y = load(name + 'model-scaler-vist_over_y-neural-k-omega-c_k-vist_over_y-and-uv_tot.bin')
        scaler_uv = load(name + 'model-scaler-uv_tot-neural-k-omega-c_k-vist_over_y-and-uv_tot.bin')
        uv_min, uv_max, vist_over_y_min, vist_over_y_max, c_k_ML_min, c_k_ML_max = \
            np.loadtxt(name + 'min-max-model-k-omega-c_k-vist_over_y-and-uv_tot.txt')

        print('NN_c_k',NN_c_k)
        print('vist_over_y_min, vist_over_y_max',vist_over_y_min,vist_over_y_max)
        print('uv_min, uv_max',uv_min,uv_max)
        print('c_k_ML_min, c_k_ML_max',c_k_ML_min,c_k_ML_max)

        # prand_k network; the first four range entries are discarded and only
        # the prand_k min/max are kept (scalers were loaded above)
        NN_prand_k = torch.load(name + 'model-neural-k-omega-prand_k-vist_over_y-and-uv_tot.pth', weights_only=False)
        _, _, _, _, prand_k_ML_min, prand_k_ML_max = \
            np.loadtxt(name + 'min-max-model-k-omega-prand_k-vist_over_y-and-uv_tot.txt')

        print('NN_prand_k',NN_prand_k)
        print('prand_k_ML_min,prand_k_ML_max',prand_k_ML_min,prand_k_ML_max)

        # c_omega_2 network; likewise only its own min/max are kept
        NN_c_omega_2 = torch.load(name + 'model-neural-k-omega-c_omega_2-vist_over_y-and-uv_tot.pth', weights_only=False)
        _, _, _, _, c_omega_2_ML_min, c_omega_2_ML_max = \
            np.loadtxt(name + 'min-max-model-k-omega-c_omega_2-vist_over_y-and-uv_tot.txt')

        print('NN_c_omega_2',NN_c_omega_2)
        print('c_omega_2_ML_min, c_omega_2_ML_max',c_omega_2_ML_min,c_omega_2_ML_max)

        # Initial profiles from an earlier run. Columns are y, prand_k, c_k,
        # c_omega_2 -- the same layout as the restart np.savetxt at the end of
        # this function.
        data_ML = np.loadtxt('y-prand_k-c_k-c_omega_2-CFD.txt')
        # Fix: y is column 0. Column 1 (used before) is prand_k, which made the
        # interpolation below meaningless.
        y_0 = data_ML[:,0]
        prand_k_ML_0 = data_ML[:,1]
        c_k_ML_0 = data_ML[:,2]
        c_omega_2_ML_0 = data_ML[:,3]

        # np.interp requires y_0 to be increasing -- presumably true for the
        # saved wall-normal coordinate (TODO confirm)
        prand_k_ML = np.interp(yp2d[0,:], y_0, prand_k_ML_0)
        c_k_ML = np.interp(yp2d[0,:], y_0, c_k_ML_0)
        c_omega_2_ML = np.interp(yp2d[0,:], y_0, c_omega_2_ML_0)

        # same wall-normal profile at every i-station
        prand_k_ML = np.repeat(prand_k_ML[None,:], repeats=ni, axis=0)
        c_k_ML = np.repeat(c_k_ML[None,:], repeats=ni, axis=0)
        c_omega_2_ML = np.repeat(c_omega_2_ML[None,:], repeats=ni, axis=0)

        c_k_ML_old = c_k_ML
        prand_k_ML_old = prand_k_ML
        c_omega_2_ML_old = c_omega_2_ML

        prand_k_ML_mean = prand_k_ML
        c_k_ML_mean = c_k_ML
        c_omega_2_ML_mean = c_omega_2_ML

        # Exponential running mean  mean = a_int*mean + b_int*new,  a_int = exp(-dt/t_int).
        # With dt = 1 per call, t_int is the averaging time scale in calls.
        # (Earlier choices: t_int = 100 h/u_b following Laage de Meux, Audebert,
        # Manceau & Perrin, Phys. Fluids; 5*h/U and 100*h/U with h = 0.13, U = 1.)
        t_int = 300
        a_int = np.exp(-1/t_int)
        b_int = 1 - a_int

    # Friction velocity from the wall-adjacent cell, ustar = sqrt(viscos*u/y).
    # It is computed per i-station and broadcast along j.
    yp = yp2d[0,0]
    ustar = (u2d[:,0]*viscos/yp)**0.5
    ustar = np.repeat(ustar[:,None], repeats=nj, axis=1)
    dudy = dphidy(u2d_face_w,u2d_face_s)
    # network inputs, both normalised with ustar
    uv = vis2d*dudy/ustar**2
    vist_over_y = (vis2d - viscos)/yp2d/ustar

    # Sanity check used during development: for uv = 9.804025330740754e-01 and
    # vist_over_y = 3.580336111816978e-01 the networks gave
    #   c_k_ML = 3.978034257888794e-01, c_omega_2_ML = 4.332774877548218e-02,
    #   prand_k_ML = 1.997603535652161e+00

    # Clip the inputs to the training ranges, reporting how many values were out of range.
    uv_min_number= (uv < uv_min).sum()
    uv_max_number= (uv > uv_max).sum()
    print('uv_min_number',uv_min_number)
    print('uv_max_number',uv_max_number)
    print('uv_min-max',np.min(uv),np.max(uv))
    uv = np.minimum(uv,uv_max)
    uv = np.maximum(uv,uv_min)

    vist_over_y_min_number= (vist_over_y < vist_over_y_min).sum()
    vist_over_y_max_number= (vist_over_y > vist_over_y_max).sum()
    print('vist_over_y_min_number',vist_over_y_min_number)
    print('vist_over_y_max_number',vist_over_y_max_number)
    print('vist_over_y_min-max',np.min(vist_over_y),np.max(vist_over_y))
    vist_over_y = np.minimum(vist_over_y,vist_over_y_max)
    vist_over_y = np.maximum(vist_over_y,vist_over_y_min)

    # keep the clipped 2D inputs available as module globals
    vist_over_y1 = vist_over_y
    uv1 = uv

    # Feature matrix: one row per cell, columns = (scaled vist_over_y, scaled uv).
    vist_over_y = vist_over_y.reshape(-1,1)
    uv = uv.reshape(-1,1)
    X = np.zeros((len(uv),2))
    X[:,0] = scaler_vist_over_y.transform(vist_over_y)[:,0]
    X[:,1] = scaler_uv.transform(uv)[:,0]
    X_tensor = torch.tensor(X, dtype=torch.float32)

    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        prand_k_ML = NN_prand_k(X_tensor).numpy()[:,0].reshape(ni, nj)
        c_k_ML = NN_c_k(X_tensor).numpy()[:,0].reshape(ni, nj)
        c_omega_2_ML = NN_c_omega_2(X_tensor).numpy()[:,0].reshape(ni, nj)

    # Clip each prediction to its training range, reporting out-of-range counts.
    print('prand_k_ML[0,0]',prand_k_ML[0,0])
    print('prand_k_ML_min, max',np.min(prand_k_ML),np.max(prand_k_ML))
    prand_k_min_number= (prand_k_ML < prand_k_ML_min).sum()
    prand_k_max_number= (prand_k_ML > prand_k_ML_max).sum()
    print('prand_k_ML_number',prand_k_min_number)
    print('prand_k_ML_max_number',prand_k_max_number)
    prand_k_ML = np.minimum(prand_k_ML,prand_k_ML_max)
    prand_k_ML = np.maximum(prand_k_ML,prand_k_ML_min)

    print('c_k_ML[0,0]',c_k_ML[0,0])
    print('c_k_ML_min, max',np.min(c_k_ML),np.max(c_k_ML))
    c_k_min_number= (c_k_ML < c_k_ML_min).sum()
    c_k_max_number= (c_k_ML > c_k_ML_max).sum()
    print('c_k_ML_number',c_k_min_number)
    print('c_k_ML_max_number',c_k_max_number)
    c_k_ML = np.minimum(c_k_ML,c_k_ML_max)
    c_k_ML = np.maximum(c_k_ML,c_k_ML_min)

    print('c_omega_2_ML[0,0]',c_omega_2_ML[0,0])
    print('c_omega_2_ML_min, max',np.min(c_omega_2_ML),np.max(c_omega_2_ML))
    c_omega_2_min_number= (c_omega_2_ML < c_omega_2_ML_min).sum()
    c_omega_2_max_number= (c_omega_2_ML > c_omega_2_ML_max).sum()
    # these counts were computed but never reported; print them like the others
    print('c_omega_2_ML_number',c_omega_2_min_number)
    print('c_omega_2_ML_max_number',c_omega_2_max_number)
    c_omega_2_ML = np.minimum(c_omega_2_ML,c_omega_2_ML_max)
    c_omega_2_ML = np.maximum(c_omega_2_ML,c_omega_2_ML_min)

    # running means: mean = a_int*mean + b_int*new
    prand_k_ML_mean = a_int*prand_k_ML_mean+b_int*prand_k_ML
    c_k_ML_mean = a_int*c_k_ML_mean+b_int*c_k_ML
    c_omega_2_ML_mean = a_int*c_omega_2_ML_mean+b_int*c_omega_2_ML

    # probe location for the convergence log
    j = 30

    # NOTE(review): the file is rewritten whenever iter is a multiple of 100, so
    # it holds at most the last ~100 values. If the full history is wanted (as in
    # modify_k), the first condition should be `iter == 0` -- confirm intent.
    if iter % 100== 0:
        np.savetxt('prand_k-c_k-c_omega_2.dat', np.c_[prand_k_ML_mean[0,j],c_k_ML_mean[0,j],c_omega_2_ML_mean[0,j]])
    elif iter > 0:
        with open('prand_k-c_k-c_omega_2.dat','ab') as f:
            np.savetxt(f, np.c_[prand_k_ML_mean[0,j],c_k_ML_mean[0,j],c_omega_2_ML_mean[0,j]])

    # Disabled option: override the predictions with constants above a y threshold, e.g.
    #   index_choose = np.nonzero(yp2d[0,:]/viscos > 500)
    #   j1 = index_choose[0][0]
    #   c_k_ML[:,j1:] = 1
    #   c_omega_2_ML[:,j1:] = 3/40   # don't use c_omega_2[0,0]
    #   prand_k_ML[:,j1:] = 2        # don't use prand_k[0,0]

    if iter % 500 == 0:
        # Save data for restart; the column layout is read back on iter == 0.
        # NOTE(review): these are the instantaneous (clipped) predictions, not the
        # running means -- confirm which one a restart should start from.
        np.savetxt('y-prand_k-c_k-c_omega_2-CFD.txt', np.c_[yp2d[0,:],prand_k_ML[0,:],c_k_ML[0,:],c_omega_2_ML[0,:]])

    return prand_k_ML_mean, c_k_ML_mean, c_omega_2_ML_mean
