import numpy as np
import torch 
import sys 
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
#from sklearn.discriminant_analysis import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from random import randrange
from joblib import dump, load
from matplotlib import ticker

class ThePredictionMachine(nn.Module):
    """Small fully-connected regressor mapping 2 input features to 2 outputs.

    Layout: Linear(2 -> 50) -> ReLU -> Linear(50 -> 50) -> ReLU -> Linear(50 -> 2).

    The sub-module attribute names ``input``, ``hidden1`` and ``hidden2`` are
    intentionally unchanged: they determine the parameter names (state-dict
    keys) of saved models loaded elsewhere in this script.
    """

    def __init__(self):
        super().__init__()
        n_in, n_hidden, n_out = 2, 50, 2
        # Layers are created in the same order as before so that random
        # initialisation consumes the RNG identically.
        self.input = nn.Linear(n_in, n_hidden)
        self.hidden1 = nn.Linear(n_hidden, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Return the raw (un-activated) output of the last linear layer for ``x``."""
        activate = nn.functional.relu
        h = activate(self.input(x))
        h = activate(self.hidden1(h))
        return self.hidden2(h)

plt.rcParams.update({'font.size': 22})
plt.interactive(True)
plt.close('all')

# presumably 1/Re_tau for the Re_tau = 5200 channel case — confirm
viscos = 1/5200


# load DNS channel data (mean profile)
DNS_mean  = np.genfromtxt("LM_Channel_5200_mean_prof.dat",comments="%").transpose()
y_DNS     = DNS_mean[0]
yplus_DNS = DNS_mean[1]
u_DNS     = DNS_mean[2]
# np.gradient with a coordinate array handles the non-uniform y+ spacing
dudy_DNS  = np.gradient(u_DNS,yplus_DNS)

DNS_stress = np.genfromtxt("LM_Channel_5200_vel_fluc_prof.dat",comments="%").transpose()

# NOTE(review): assumes rows 2..7 are uu, vv, ww, uv, uw, vw — confirm against
# the header of LM_Channel_5200_vel_fluc_prof.dat
uu_DNS = DNS_stress[2]
vv_DNS = DNS_stress[3]
ww_DNS = DNS_stress[4]
uv_DNS = DNS_stress[5]
uw_DNS = DNS_stress[6]
vw_DNS = DNS_stress[7]
# turbulent kinetic energy = half the trace of the normal stresses
k_DNS  = 0.5*(uu_DNS+vv_DNS+ww_DNS)

# %         y/delta                    y^+                   Production          Turbulent_Transport        Viscous_Transport       Pressure_Strain         Pressure_Transport        Viscous_Dissipation           Balance

DNS_RSTE = np.genfromtxt("LM_Channel_5200_RSTE_k_prof.dat",comments="%")

# Column indices follow the header above: 7 = Viscous_Dissipation, 4 = Viscous_Transport
eps_DNS = DNS_RSTE[:,7]
visc_diff =  DNS_RSTE[:,4]


# fix wall: overwrite the first (wall) sample with its neighbour
# (eps_DNS is a view, so this also updates DNS_RSTE[0, 7])
eps_DNS[0]=eps_DNS[1]
k_DNS[0]=k_DNS[1]

# load pytorch model
folder='./'

filename=folder+'model.pth'
# model.pth holds a whole pickled nn.Module, not just a state dict. Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle custom classes, so the flag must be given explicitly. Only load
# trusted files this way — unpickling can execute arbitrary code.
neural_net = torch.load(filename, weights_only=False)
print('model',neural_net)
scaler_dudy2 = load(folder+'model_scaler-dudy2.bin')
scaler_dudy = load(folder+'model_scaler-dudy.bin')

# min/max bounds of the inputs (dudy2, dudy) and outputs (c0, c2), in that order.
# Previously the 3rd and 4th values were both bound to `dudy`, losing the minimum.
dudy2_min, dudy2_max, dudy_min, dudy_max, c0_min, c0_max, c2_min, c2_max = np.loadtxt(folder+'min-max.txt')

print('dudy2_min, dudy2_max, dudy_min, dudy_max',dudy2_min, dudy2_max, dudy_min, dudy_max)


