#!/usr/bin/env python
# coding: utf-8

# # Setup 🏗️
# 

# In[1]:


import numpy as np
import torch 
import sys 
import time
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
#from sklearn.discriminant_analysis import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from random import randrange
from joblib import dump, load
from matplotlib import ticker


# Plot defaults: larger font, interactive mode on, and no figures left open
# from a previous run.
plt.rcParams.update({'font.size': 13})
plt.interactive(True)
plt.close('all')

# Viscosity used below to convert y+ to a dimensional wall distance.
# NOTE(review): presumably 1/Re_tau for the Re_tau = 5200 channel — confirm.
viscos = 1/5200

init_time = time.time()

# Load the channel data set. Column roles, as inferred from the variable
# names they are assigned to: 0 -> y+, 1 -> U+, 2 -> u*, 3 -> p+, 4 -> dU/dy.
data_10 = np.loadtxt('yplus-uplus-ustar-many-j-points-channel-5200.txt')
yplus_10, uplus_10, ustar_10, pplus_10, dudy = (data_10[:, col] for col in range(5))


# ---- network target: U+ ----
uplus = uplus_10
ustar = ustar_10
u2d = uplus*ustar          # dimensional velocity, U = U+ * u*

# ---- candidate network inputs ----
yplus = yplus_10
pplus = pplus_10
y2d = yplus*viscos/ustar   # dimensional wall distance (computed from 1-D y+)

# The network is trained to predict U+.
y = uplus

# Column vectors of shape (n, 1) for scikit-learn and torch.
# (uplus itself stays 1-D; only the target alias y is reshaped.)
yplus, pplus, dudy, y = (column.reshape(-1, 1) for column in (yplus, pplus, dudy, y))

# One MinMax scaler per candidate feature; only y+ is fitted and used below.
scaler_yplus, scaler_pplus, scaler_dudy = MinMaxScaler(), MinMaxScaler(), MinMaxScaler()

X = np.zeros((len(yplus), 1))
print('type(yplus)',type(yplus))
print('yplus.shape',yplus.shape)

# Feature matrix: the scaled y+ column only (the p+ and dU/dy features are disabled).
X[:, 0] = scaler_yplus.fit_transform(yplus).ravel()

# Not used by the split below (which has a fixed seed); pass random_state=random
# instead to get a different split on every run.
random = randrange(100)

# Split features, target and row indices: 20 % of the rows go to validation.
# random_state=46 fixes the shuffle so the split is reproducible.
indices = np.arange(len(X))
X_train, X_test, y_train, y_test, index_train, index_test = train_test_split(
    X, y, indices, test_size=0.2, shuffle=True, random_state=46)

# The same split applied to the raw (unscaled) quantities.
yplus_train, yplus_test = yplus[index_train], yplus[index_test]
uplus_train, uplus_test = uplus[index_train], uplus[index_test]
pplus_train, pplus_test = pplus[index_train], pplus[index_test]
u2d_train, u2d_test = u2d[index_train], u2d[index_test]
dudy_train, dudy_test = dudy[index_train], dudy[index_test]


# Hyperparameters (only the final values are used).
# Learning rates tried before: 3e-1, 5e-1, 0.005, 0.002,
# 0.01 (after 1000 epochs: loss 4.11e-02), 0.1 (gets stuck on 1).
learning_rate = 0.001

# Batch sizes tried before: 17, 30, 3.
my_batch_size = 1

# Epoch counts tried before: 3, 30, 1000, 5000, 10000.
epochs = 3000





# numpy arrays -> float32 torch tensors.
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    torch.tensor(array, dtype=torch.float32)
    for array in (X_train, y_train, X_test, y_test)
)

# Pair each feature tensor with its target tensor, then batch it. Neither loader
# shuffles, so every epoch visits the samples in the same order (the
# train/validation split itself was shuffled once).
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=my_batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=my_batch_size)

class ThePredictionMachine(nn.Module):
    """Fully connected ReLU regression network: input -> hidden -> output.

    The defaults reproduce the 1 -> 50 -> 50 -> 1 architecture used by this
    script (one scaled y+ feature in, U+ out). Pass ``n_inputs`` to train on
    more feature columns of X, e.g. when the p+ / dU/dy features are enabled.

    Args:
        n_inputs: number of feature columns in X (last dimension of the input).
        n_hidden: width of the two hidden layers.
        n_outputs: number of target columns in y.
    """

    def __init__(self, *, n_inputs=1, n_hidden=50, n_outputs=1):
        super().__init__()
        # Attribute names are unchanged, so previously saved models and
        # state_dicts keep matching this class.
        self.input = nn.Linear(n_inputs, n_hidden)
        self.hidden1 = nn.Linear(n_hidden, n_hidden)
        self.hidden2 = nn.Linear(n_hidden, n_outputs)

    def forward(self, x):
        """Map a (batch, n_inputs) tensor to a (batch, n_outputs) tensor."""
        x = nn.functional.relu(self.input(x))
        x = nn.functional.relu(self.hidden1(x))
        # No activation on the output layer: this is an unbounded regression.
        return self.hidden2(x)

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch, updating *model*'s parameters in place.

    Each batch goes through a forward pass, a loss evaluation, backpropagation
    and one optimizer step.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``len(dataloader)`` is
            printed once per call.
        model: callable mapping a batch ``X`` to predictions comparable with ``y``.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer built over ``model``'s parameters.
    """
    print('in train_loop: len(dataloader)',len(dataloader))
    for X, y in dataloader:
        pred = model(X)
        loss = loss_fn(pred, y)

        # Gradients accumulate across backward() calls, so clear the previous
        # batch's gradients first.
        # https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    print('in test_loop: len(dataloader)',len(dataloader))

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
#transform from tensor to numpy
            pred_numpy = pred.detach().numpy()

    test_loss /= num_batches

    print(f"Avg loss: {test_loss:>.2e} \n")

    return test_loss

start_time = time.time()

# Network, mean-squared-error loss and plain stochastic gradient descent.
# Other optimizers: https://pytorch.org/docs/stable/optim.html
neural_net = ThePredictionMachine()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(neural_net.parameters(), lr=learning_rate)

# Validation loss recorded after every epoch.
loss_v = np.zeros(epochs)

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_loop(train_loader, neural_net, loss_fn, optimizer)
    loss_v[epoch] = test_loss = test_loop(test_loader, neural_net, loss_fn)
print("Done!")

# Predict U+ for the whole test set at once; no gradients are needed here.
with torch.no_grad():
    preds = neural_net(X_test_tensor)

print(f"{'time ML: '}{time.time()-start_time:.2e}")

# transform from tensor to numpy: shape (n_test, 1), like y_test
uplus_NN = preds.detach().numpy()

# 1-D copy of the predictions. uplus_test and u2d_test are 1-D, and mixing an
# (n,) array with an (n, 1) array silently broadcasts to (n, n), so every
# comparison below is done in 1-D.
uplus_NN_1d = uplus_NN[:, 0]

# Standard deviation of the U+ error, normalised by the RMS of the target.
uplus_std = np.std(uplus_test - uplus_NN_1d)/(np.mean(uplus_test.flatten()**2))**0.5

error_all_uplus = abs(uplus_test - uplus_NN_1d)
error_index_uplus = error_all_uplus.argsort()
error_sorted_uplus = error_all_uplus[error_index_uplus]

uplus_max = np.max(error_all_uplus)  # largest absolute U+ error
# Largest error relative to the target at the same test point (a fraction, not
# multiplied by 100). error_index_uplus indexes the *test* arrays, so the
# denominator must come from uplus_test as well.
worst_test_point = error_index_uplus[-1]
largest_error_percent_uplus = error_all_uplus[worst_test_point]/uplus_test[worst_test_point]
print('largest_error_percent in uplus',largest_error_percent_uplus)

# u2d = uplus*ustar  =>  ustar = u2d/uplus, with the target or the NN U+.
# NOTE(review): a test point with U+ == 0 would make these inf/nan.
ustar_test = u2d_test/uplus_test
ustar_NN = u2d_test/uplus_NN_1d
ustar_std = np.std(ustar_test - ustar_NN)/(np.mean(ustar_test.flatten()**2))**0.5
ustar_max = np.max(abs(ustar_test - ustar_NN))

print('\nuplus_error_std',uplus_std)
print('\nustar_error_std',ustar_std)


# First test point: its row index in the full data set.
ind = index_test[0]

# Validation-loss history, plus y+, p+ and the NN prediction at that first test point.
np.savetxt('loss-v-channel-5200-only-yplus-IDDES.txt', loss_v)
np.savetxt('yplus-pplus-index-1-channel-5200-only-yplus-IDDES.txt',
           np.c_[yplus[ind, 0], pplus[ind, 0], uplus_NN[0], ind])

# Error summary. At this point uplus_max still holds the largest absolute U+
# error; it is reassigned to the U+ range further down, so this write must
# stay before that reassignment.
np.savetxt('error-channel-5200-only-yplus-IDDES.txt',
           np.c_[test_loss, uplus_std, uplus_max, ustar_std, ustar_max,
                 largest_error_percent_uplus])

# The whole module is pickled (not just its state_dict), so loading it needs
# the ThePredictionMachine class to be importable.
filename = 'model-channel-5200-only-yplus-IDDES.pth'
torch.save(neural_net, filename)

# NOTE(review): only scaler_yplus is fitted in this script; the pplus and dudy
# scalers are written out unfitted.
for scaler, scaler_path in (
        (scaler_yplus, 'scaler-yplus-channel-5200-only-yplus-IDDES.bin'),
        (scaler_pplus, 'scaler-pplus-channel-5200-only-yplus-IDDES.bin'),
        (scaler_dudy, 'scaler-dudy-channel-5200-only-yplus-IDDES.bin')):
    dump(scaler, scaler_path)

# Value ranges of inputs and target.
# NOTE(review): taken over the test split only, not the training data — confirm
# that this is the range intended for the saved model.
yplus_min, yplus_max = np.min(yplus_test), np.max(yplus_test)
pplus_min, pplus_max = np.min(pplus_test), np.max(pplus_test)
dudy_min, dudy_max = np.min(dudy_test), np.max(dudy_test)
uplus_min, uplus_max = np.min(uplus_test), np.max(uplus_test)

np.savetxt('min-max-model-channel-5200-only-yplus-IDDES.txt',
           np.c_[yplus_min, yplus_max, pplus_min, pplus_max,
                 dudy_min, dudy_max, uplus_min, uplus_max])




#################### 3D scatter of the full data set: y+, U+, p+ (coloured by p+)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter3D(yplus_10, uplus_10, pplus_10, marker='o', s=3.2, c=pplus_10)
ax.view_init(elev=41, azim=127)

ax.set_xlim3d(0, 150)
ax.set_ylim3d(0, 12)
ax.set_zlim3d(0, 0.04)

# At most M major ticks per axis, with a transparent background pane and
# invisible grid lines on every axis.
M = 3
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_major_locator(ticker.MaxNLocator(M))
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    axis._axinfo["grid"]['color'] = (1, 1, 1, 0)

ax.set_xlabel('$y^+$')
ax.set_ylabel(r'$u^+$')
ax.set_zlabel(r'$p^+$')
plt.savefig('3Dscatter-yplus-uplus-pplus-many-points-also-dudy-alpha-10-IDDES.png', bbox_inches='tight')

########################## uplus: NN prediction (*) vs target (o) on the test set, coloured by p+
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.scatter(yplus_test, uplus_NN, marker='*', s=40.2, c=pplus_test, label='NN')
ax1.scatter(yplus_test, uplus_test, marker='o', s=40.2, c=pplus_test, label='target')
ax1.set_xlabel("$y^+$")
ax1.set_ylabel("$U^+$")
ax1.legend(loc="best", fontsize=12)
fig1.savefig('uplus-channel-5200-only-yplus-IDDES.png')

########################## uplus, error: |U+_target - U+_NN| per test point, coloured by p+
fig1, ax1 = plt.subplots()
plt.subplots_adjust(left=0.20, bottom=0.20)
error_all = abs(uplus_test - uplus_NN[:, 0])
error_index = error_all.argsort()
cs = plt.scatter(yplus_test, error_all, marker='o', s=40.2, c=pplus_test, label='NN')
# NOTE(review): the colour is pplus_test; confirm that p+ is the pressure
# gradient named in this label.
plt.colorbar(cs, label=r'$\partial \bar{p}/\partial x$')
plt.xlabel("$y^+$")
# The subscript must sit inside the $...$ math delimiters; written as
# "$U^+$_{error}", the "_{error}" part renders as literal text.
plt.ylabel(r"$U^+_{error}$")
plt.savefig('error-channel-5200-only-yplus-IDDES.png')



########################## loss: validation loss per epoch
fig1, ax1 = plt.subplots()
plt.subplots_adjust(left=0.20, bottom=0.20)
plt.plot(loss_v, 'b-')
# Zoom on the tail of the history by skipping the first 1000 epochs, but only
# when there are more than 1000 epochs. Shorter runs show the whole curve
# instead of raising IndexError. The y-range fits the window actually shown.
first_epoch = 1000 if len(loss_v) > 1000 else 0
shown_loss = loss_v[first_epoch:]
plt.axis([first_epoch, len(loss_v), min(shown_loss), max(shown_loss)])
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig('loss-channel-5200-only-yplus-IDDES.png')


