import numpy as np
import torch 
import sys 
import time
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
#from sklearn.discriminant_analysis import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from random import randrange
from joblib import dump, load

# Session-wide matplotlib setup: larger font, interactive mode, and no figures
# left open from a previous run.
plt.rcParams.update({'font.size': 22})
plt.interactive(True)
plt.close('all')

# Viscosity used to convert y to plus units and to build the total shear
# stress below.  NOTE(review): presumably the non-dimensional viscosity of
# the Re_tau = 5200 channel case -- confirm.
viscos_5200 = 1.0 / 5200.0


# load k-omega grid with correct k
# NOTE(review): columns 0..3 are taken to be y, u, k, omega (this matches the
# file name 'y_u_k_om_uv...') -- confirm against the file writer.
name = '../pytorch-solve-k-omega-with-vist_ML-c_k_ML-c_omega_2_ml-from-balance-half-channel-5200-plus-units-prand-max.eq.2-10-omega-bc-works-nj120/'
kom_data = np.loadtxt(name + 'y_u_k_om_uv_5200-RANS-half-channel.txt')
y_5200, u_5200, k_5200, om_5200 = (kom_data[:, col] for col in range(4))

# Quantities derived from the loaded profiles.
vist_5200 = k_5200 / om_5200
vist_over_y_5200 = vist_5200 / y_5200   # NOTE(review): inf/nan if any y is 0
yplus_5200 = y_5200 / viscos_5200
dudy = np.gradient(u_5200, y_5200)
dudy_plus = np.gradient(u_5200, yplus_5200)

# Magnitude of the modelled and of the total (modelled + viscous) shear
# stress; the sign is irrelevant because the absolute value is taken.
uv_5200 = np.abs(vist_5200 * dudy)
uv_tot_5200 = np.abs((vist_5200 + viscos_5200) * dudy)

# The two network inputs (before scaling).
uv = uv_tot_5200
vist_over_y = vist_over_y_5200

# The target c_k is stored on a different grid (suffix nj70); interpolate it
# onto the y_5200 grid loaded above.
name_nj70 = '../channel-5200-half-channel/'
kom_data_nj70 = np.loadtxt(name_nj70 + 'y_u_k_om_uv_5200-RANS-half-channel.txt')
y_5200_nj70 = kom_data_nj70[:, 0]
c_k_5200_nj70 = np.loadtxt('../PINN/c_k_pred_5200-plus-units-from-balance.txt')
c_k_5200 = np.interp(y_5200, y_5200_nj70, c_k_5200_nj70)

# Training target.
c_k = c_k_5200



# Wall-normal coordinate, kept alongside the samples for plotting.
y_L = y_5200

# Target as a column vector of shape (n, 1): wrap c_k in a 1 x n array and
# transpose it.
c = np.array([c_k])
y = c.transpose()

# MinMaxScaler expects 2-D input, so turn each feature into an (n, 1) column.
uv = uv.reshape(-1, 1)
vist_over_y = vist_over_y.reshape(-1, 1)

# One scaler per feature; both are saved later in the script so the same
# scaling can be re-applied at inference time.
scaler_uv = MinMaxScaler()
scaler_vist_over_y = MinMaxScaler()

# Feature matrix: column 0 = scaled vist/y, column 1 = scaled total uv.
X = np.zeros((len(uv), 2))
X[:, 0] = scaler_vist_over_y.fit_transform(vist_over_y)[:, 0]
X[:, 1] = scaler_uv.fit_transform(uv)[:, 0]

# Split the feature matrix and target vector into training and validation
# sets.  test_size=0.2 reserves 20% of the samples for validation.
# The seed is fixed so that the split is reproducible.  (Previously
# `random_state` was drawn twice with randrange() and then ignored, because
# the call below hard-coded 42; the value is now defined once and used.)
random_state = 42

# The sample indices are split together with X and y so that the unscaled
# quantities can be sliced consistently afterwards.
indices = np.arange(len(X))
X_train, X_test, y_train, y_test, index_train, index_test = train_test_split(
    X, y, indices, test_size=0.2, shuffle=True, random_state=random_state)

# Unscaled quantities for the training samples.
uv_train = uv[index_train]
vist_over_y_train = vist_over_y[index_train]
y_L_train = y_L[index_train]
c_k_train = c_k[index_train]

# Unscaled quantities for the validation samples.
uv_test = uv[index_test]
vist_over_y_test = vist_over_y[index_test]
c_k_test = c_k[index_test]
y_L_test = y_L[index_test]

# Hyperparameters
learning_rate = 0.04    # initial SGD step size
my_batch_size = 1       # one sample per optimizer step
N_epochs = 1000         # number of passes over the training set

# Convert the numpy splits to float32 PyTorch tensors (one copy each).
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    torch.tensor(array, dtype=torch.float32)
    for array in (X_train, y_train, X_test, y_test)
)

# A TensorDataset pairs each feature row with its target row; a DataLoader
# serves the pairs in batches.  shuffle=False keeps the sample order fixed
# from epoch to epoch for both loaders.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(dataset=train_dataset, batch_size=my_batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=my_batch_size, shuffle=False)
#test_loader = DataLoader(test_dataset, shuffle=False)

class MyNet(nn.Module):
    """Fully connected network 2 -> 10 -> 10 -> 1.

    ReLU follows the first two linear layers; the last layer is linear, so
    the output is unbounded.
    """

    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=2, out_features=10)
        self.hidden1 = nn.Linear(in_features=10, out_features=10)
        self.hidden2 = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Map a (..., 2) input tensor to a (..., 1) prediction."""
        first = torch.relu(self.input(x))
        second = torch.relu(self.hidden1(first))
        return self.hidden2(second)

class MyNet1(nn.Module):
    """Fully connected network 2 -> 10 -> 10 -> 10 -> 10 -> 1.

    Each of the four 10-unit layers is followed by its own ReLU module; the
    output layer is linear.
    """

    def __init__(self):
        super().__init__()
        width = 10
        self.ll1 = nn.Linear(2, width)
        self.ac1 = nn.ReLU()
        self.ll2 = nn.Linear(width, width)
        self.ac2 = nn.ReLU()
        self.ll3 = nn.Linear(width, width)
        self.ac3 = nn.ReLU()
        self.ll4 = nn.Linear(width, width)
        self.ac4 = nn.ReLU()
        self.output = nn.Linear(width, 1)

    def forward(self, X):
        """Apply the layers in registration order and return the result."""
        pipeline = (
            self.ll1, self.ac1,
            self.ll2, self.ac2,
            self.ll3, self.ac3,
            self.ll4, self.ac4,
            self.output,
        )
        for stage in pipeline:
            X = stage(X)
        return X


class MyNet2(nn.Module):
    """Fully connected network 1 -> 10 -> 10 -> 10 -> 1 with tanh.

    Unlike the other networks in this file it takes a single input feature.
    A shared Tanh module is applied after the first and second linear
    layers only; the third layer feeds the output layer directly.
    """

    def __init__(self):
        super().__init__()
        self.ll1 = nn.Linear(1, 10)
        self.tanh = nn.Tanh()
        self.ll2 = nn.Linear(10, 10)
        self.ll3 = nn.Linear(10, 10)
        self.output = nn.Linear(10, 1)

    def forward(self, x):
        """Map a (..., 1) input tensor to a (..., 1) prediction."""
        hidden = self.tanh(self.ll1(x))
        hidden = self.tanh(self.ll2(hidden))
        # No activation between ll3 and the output layer.
        return self.output(self.ll3(hidden))


class MyNet3(nn.Module):
    """Fully connected network 2 -> 10 -> 10 -> 1 with sigmoid everywhere.

    The sigmoid is applied after every layer, including the last one, so the
    output lies strictly between 0 and 1.
    """

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=2, out_features=10)
        self.layer_2 = nn.Linear(in_features=10, out_features=10)
        self.layer_3 = nn.Linear(in_features=10, out_features=1)
        # Registered as a sub-module but never called in forward().
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Pass x through the three layers, squashing after each one."""
        squash = torch.nn.functional.sigmoid
        for layer in (self.layer_1, self.layer_2, self.layer_3):
            x = squash(layer(x))
        return x

class MyNet4(nn.Module):
    """Fully connected network 2 -> 50 -> 50 -> 1 with sigmoid everywhere.

    Same layout as MyNet3 but with 50 units per hidden layer.  The sigmoid
    after the last layer keeps the output strictly between 0 and 1.
    """

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(in_features=2, out_features=50)
        self.layer_2 = nn.Linear(in_features=50, out_features=50)
        self.layer_3 = nn.Linear(in_features=50, out_features=1)
        # Registered as a sub-module but never called in forward().
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Pass x through the three layers, squashing after each one."""
        squash = torch.nn.functional.sigmoid
        for layer in (self.layer_1, self.layer_2, self.layer_3):
            x = squash(layer(x))
        return x

class MyNet5(nn.Module):
    """Fully connected network 2 -> 50 -> 50 -> 1.

    Sigmoid follows the first two linear layers; the last layer is linear,
    so the output is unbounded.
    """

    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=2, out_features=50)
        self.hidden1 = nn.Linear(in_features=50, out_features=50)
        self.hidden2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        """Map a (..., 2) input tensor to a (..., 1) prediction."""
        squash = nn.functional.sigmoid
        for layer in (self.input, self.hidden1):
            x = squash(layer(x))
        return self.hidden2(x)

def train_loop(dataloader, model, loss_fn, optimizer):
    """Train `model` for one pass over `dataloader`.

    For every batch the prediction and loss are computed, the gradients are
    reset, back-propagated, and one optimizer step is taken.

    Args:
        dataloader: iterable of (features, target) batches that supports len().
        model: callable mapping a feature batch to predictions.
        loss_fn: callable (prediction, target) -> scalar loss tensor.
        optimizer: optimizer holding the parameters of `model`.

    Returns:
        Mean of the per-batch losses seen during the pass (float), or NaN
        when the dataloader yields no batches.  The original version
        returned None; callers that ignore the result are unaffected.
    """
    num_batches = len(dataloader)
    print('in train_loop: len(dataloader)', num_batches)

    running_loss = 0.0
    for X, y in dataloader:
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, then update the weights.
        # https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches


def test_loop(dataloader, model, loss_fn, loss_min):
    """Evaluate `model` on `dataloader` and report the mean batch loss.

    Besides its arguments, the function reads the module-level names
    `scheduler` (for the current learning rate) and `epoch` (for the printed
    epoch number), so both must exist when it is called.  It also writes the
    module-level names `size1` (number of samples in the dataset) and
    `pred_numpy` (predictions of the last batch, as a numpy array).

    Args:
        dataloader: iterable of (features, target) batches with a `.dataset`.
        model: callable mapping a feature batch to predictions.
        loss_fn: callable (prediction, target) -> scalar loss tensor.
        loss_min: smallest loss seen so far.

    Returns:
        Tuple (test_loss, loss_min): the mean per-batch loss of this pass
        and the updated running minimum.
    """
    global pred_numpy,pred1,size1
    size1 = len(dataloader.dataset)
    n_batches = len(dataloader)
    print('in test_loop: len(dataloader)',len(dataloader))

    batch_losses = []
    with torch.no_grad():   # evaluation only, no gradients needed
        for features, target in dataloader:
            pred = model(features)
            batch_losses.append(loss_fn(pred, target).item())
            # transform from tensor to numpy (only the last batch survives)
            pred_numpy = pred.detach().numpy()

    test_loss = sum(batch_losses) / n_batches

    # Print the loss every epoch
    loss_min = np.minimum(test_loss,loss_min)
    torch.set_printoptions(precision=4)
    lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch+1}, Learning Rate: {lr}, Loss: {test_loss}, Loss_min: {loss_min}")

    return test_loss, loss_min


# In[7]:

# Wall-clock reference taken just before training starts (not read anywhere
# in the code below).
start_time = time.time()

# Instantiate a neural network
neural_net = MyNet()

# Initialize the loss function
loss_fn = nn.MSELoss()

# Choose the optimizer, see https://pytorch.org/docs/stable/optim.html for
# more info.  Here: Stochastic Gradient Descent, with a learning rate that
# is scaled linearly from 1.0 * learning_rate down to 0.5 * learning_rate
# over N_epochs scheduler steps.
optimizer = torch.optim.SGD(neural_net.parameters(), lr=learning_rate)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=N_epochs)

# Validation loss of every epoch, for the loss plot.
loss_v = np.zeros(N_epochs)

# Running minimum of the validation loss; starts at a huge sentinel value.
loss_min = 1e30
for epoch in range(N_epochs):
    train_loop(train_loader, neural_net, loss_fn, optimizer)
    test_loss, loss_min = test_loop(test_loader, neural_net, loss_fn, loss_min)
    loss_v[epoch] = test_loss
    # Advance the learning-rate schedule once per epoch.  Without this call
    # the scheduler never changes the learning rate and every epoch runs
    # (and prints) the initial value.  Stepping after test_loop means the
    # rate it prints is the one that was used for this epoch's training.
    scheduler.step()

print("Done!")

# Predict c_k for the whole validation set in one call.
preds = neural_net(X_test_tensor)

# transform from tensor to numpy
c_k_NN = preds.detach().numpy()

# Keep a reference to the 2-D (n_test, 1) array before flattening.
c_k_NN_old = c_k_NN

# 1-D view so it lines up element-wise with c_k_test.
c_k_NN=c_k_NN[:,0]

# Standard deviation of the prediction error, normalised by the RMS of the
# target values.
c_k_std=np.std(c_k_NN-c_k_test)/(np.mean(c_k_test.flatten()**2))**0.5

print('\nc_k_error_std',c_k_std)

# Absolute error per validation sample, and the ordering from smallest to
# largest error.
error_all=abs(c_k_NN-c_k_test)
error_index= error_all.argsort()

error_sorted = error_all[error_index]
# largest error, relative to the target at that sample (a ratio; it is not
# multiplied by 100 despite the variable name):
largest_error_percent = error_all[error_index[-1]]/c_k_test[error_index[-1]]
print('largest_error in c_k',largest_error_percent)

# Last-epoch validation loss and the normalised error.
np.savetxt('error-neural-k-omega-c_k-vist_over_y-and-uv_tot.txt', [test_loss,c_k_std] )

# Save the complete module object (not just its state_dict), together with
# the two fitted feature scalers.
filename = 'model-neural-k-omega-c_k-vist_over_y-and-uv_tot.pth'
torch.save(neural_net, filename)
#torch.save(neural_net.state_dict(), filename)
dump(scaler_uv,'model-scaler-uv_tot-neural-k-omega-c_k-vist_over_y-and-uv_tot.bin')
dump(scaler_vist_over_y,'model-scaler-vist_over_y-neural-k-omega-c_k-vist_over_y-and-uv_tot.bin')

# Value ranges of the inputs and the target, written next to the model.
uv_min = np.min(uv_tot_5200)
uv_max = np.max(uv_tot_5200)
# uv_max should not exceed one
uv_max =0.995
vist_over_y_min = np.min(vist_over_y_5200)
vist_over_y_max = np.max(vist_over_y_5200)
c_k_min = np.min(c_k)
c_k_max = np.max(c_k)
# Clamp the lower bound at zero.  This used to be np.max(c_k_min, 0), where
# the 0 is interpreted as the reduction axis, so the value was returned
# unchanged; np.maximum is the element-wise maximum of its two arguments.
c_k_min = np.maximum(c_k_min, 0)

np.savetxt('min-max-model-k-omega-c_k-vist_over_y-and-uv_tot.txt', \
 [uv_min, uv_max,vist_over_y_min,vist_over_y_max,c_k_min,c_k_max] )

# One sample (unscaled inputs and the predicted c_k) from the validation set.
np.savetxt('uv_tot-vist_over_y-c_k_NN',[uv_test[0,0],vist_over_y_test[0,0],c_k_NN[0]])


########################## uv vs y
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(y_5200,uv_tot_5200, 'b-',label='target')
plt.ylabel(r"$\overline{u'v'}$")
plt.xlabel("$y$")
plt.savefig('uv-vist_over_y-and-uv_tot.png')

########################## uv vs y zoom
# Same curve, restricted to 0 <= y <= 0.05.
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(y_5200,uv_tot_5200, 'b-',label='target')
plt.ylabel(r"$\overline{u'v'}$")
plt.xlabel("$y$")
plt.xlim(0,0.05)
plt.savefig('uv-vist_over_y-and-uv_tot-zoom.png')

########################## c_k vs y
# Target c_k on the full grid (blue) against the network prediction at the
# validation points (red).
# NOTE(review): the y-label reads sigma_t although c_k is plotted -- confirm.
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(y_L,c_k, 'bo',label='target')
plt.plot(y_L_test,c_k_NN, 'ro',label='NN')
plt.ylabel(r"$\sigma_{t NN}$")
plt.xlabel("$y$")
plt.legend(loc="best",fontsize=12)
plt.savefig('c_k-omega-c_k-vist_over_y-and-uv_tot.png')

########################## c_k zoom
# Same comparison, restricted to 0 <= y <= 0.1.
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(y_L,c_k, 'bo',label='target')
plt.plot(y_L_test,c_k_NN, 'ro',label='NN')
plt.ylabel(r"$\sigma_{t NN}$")
plt.xlabel("$y$")
plt.xlim(0,0.1)
plt.legend(loc="best",fontsize=12)
plt.savefig('c_k-omega-c_k-vist_over_y-and-uv_tot-zoom.png')

########################## vist_over_y
# NOTE(review): the y-label reads L/y although vist/y is plotted -- confirm.
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(y_L,vist_over_y, 'bo',label='target')
plt.ylabel(r"$L/y$")
plt.xlabel("$y$")
plt.legend(loc="best",fontsize=12)
plt.savefig('vist_over_y-k-omega-c_k-vist_over_y-and-uv_tot.png')



########################## loss, error
# Validation loss per epoch.  The view skips the first 100 epochs so the
# large initial losses do not dominate the scale.
fig1,ax1 = plt.subplots()
plt.subplots_adjust(left=0.20,bottom=0.20)
plt.plot(loss_v, 'b-')
zoom_start = 100
# loss_v[zoom_start] only exists when more than zoom_start epochs were run;
# for shorter runs keep matplotlib's automatic limits instead of raising
# IndexError.
if len(loss_v) > zoom_start:
    plt.axis([zoom_start, len(loss_v), min(loss_v), loss_v[zoom_start]])
# The horizontal axis is the epoch index (loss_v is filled once per epoch);
# it was previously mislabelled as y+.
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig('loss-omega-c_k-vist_over_y-and-uv_tot-zoom.png')


