import torch.nn as nn
from joblib import dump, load
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import scipy.io as sio
import sys
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import ticker
# Global matplotlib settings: large fonts, and never warn about the number of
# open figures (this script opens a new figure for every plot).
plt.rcParams.update({'font.size': 22, 'figure.max_open_warning': 0})

# interactive mode: makes sure figures are updated when using ipython
plt.interactive(True)

# start from a clean slate: close any figures left over from a previous run
plt.close('all')

# molecular (kinematic) viscosity, m2/s
viscos = 1 / 10000


# (plt.interactive(True) near the top of the script makes sure figures are updated when using ipython)

# Read the mesh. Each file holds the flattened node coordinates followed by a
# final entry giving the number of cells in that direction (ni resp. nj); the
# nodes therefore form an (ni+1) x (nj+1) array.
datax = np.loadtxt("x2d.dat")
x = datax[0:-1]
ni = int(datax[-1])
datay = np.loadtxt("y2d.dat")
y = datay[0:-1]
nj = int(datay[-1])

# node coordinates, shape (ni+1, nj+1) (no need to pre-allocate: reshape
# returns the array directly)
x2d = np.reshape(x, (ni+1, nj+1))
y2d = np.reshape(y, (ni+1, nj+1))

# compute cell centers as the mean of the four surrounding nodes, shape (ni, nj)
xp2d = 0.25*(x2d[0:-1, 0:-1] + x2d[0:-1, 1:] + x2d[1:, 0:-1] + x2d[1:, 1:])
yp2d = 0.25*(y2d[0:-1, 0:-1] + y2d[0:-1, 1:] + y2d[1:, 0:-1] + y2d[1:, 1:])

# wall-normal coordinates along the first i-column:
# y   -> cell centres (nj values), y_s -> nodes (nj+1 values)
y = yp2d[0, :]
y_s = y2d[0, :]


# Load the previously saved RANS fields, one .npy file per quantity, in the
# same order as listed. Note that 'dudy_squared_saved.npy' is bound to the
# name dudy_squared_scaled.
(u2d, p2d, v2d, k2d, om2d, vis2d, vis2d_earsm,
 uu2d, vv2d, ww2d, uv2d,
 beta1, diss_saved, beta2, beta4, pk,
 dudy_squared_scaled) = map(np.load, (
    'u2d_saved.npy',
    'p2d_saved.npy',
    'v2d_saved.npy',
    'k2d_saved.npy',
    'om2d_saved.npy',
    'vis2d_saved.npy',
    'vis2d_earsm_saved.npy',
    'uu2d_saved.npy',
    'vv2d_saved.npy',
    'ww2d_saved.npy',
    'uv2d_saved.npy',
    'beta1_saved.npy',
    'diss_saved.npy',
    'beta2_saved.npy',
    'beta4_saved.npy',
    'pk_saved.npy',
    'dudy_squared_saved.npy',
))


# average in x direction
# NOTE(review): assumes the saved 2-D fields are laid out (ni, nj) like
# xp2d/yp2d, so averaging over axis 0 leaves profiles of length nj along y.
u = np.mean(u2d, axis=0)
v = np.mean(v2d, axis=0)
k = np.mean(k2d, axis=0)
om = np.mean(om2d, axis=0)
vis = np.mean(vis2d, axis=0)
diss_saved = np.mean(diss_saved, axis=0)
beta1 = np.mean(beta1, axis=0)
beta2 = np.mean(beta2, axis=0)
beta4 = np.mean(beta4, axis=0)
vis_earsm = np.mean(vis2d_earsm, axis=0)
pk = np.mean(pk, axis=0)
eps = 0.09*k*om
print('uu2d[0,30],vv2d[0,30]', uu2d[0, 30], vv2d[0, 30])
# Add the isotropic part (2/3)k to the saved normal stresses (which presumably
# hold only the deviatoric part). Use the exact 2/3: the former truncated
# 0.666 biased every normal stress, and 0.5*(uu+vv+ww), by ~0.1% of k.
uu = np.mean(uu2d, axis=0) + 2/3*k
vv = np.mean(vv2d, axis=0) + 2/3*k
ww = np.mean(ww2d, axis=0) + 2/3*k
uv = np.mean(uv2d, axis=0)

# EARSM turbulent viscosity: total minus molecular
vist_earsm = vis_earsm - viscos

# Only dU/dy is evaluated; the other velocity gradients are set to zero so the
# eddy-viscosity terms below keep their general form.
dudx = np.zeros(nj)
dvdy = np.zeros(nj)
dvdx = np.zeros(nj)
dudy = np.gradient(u, y)

# total stresses: EARSM stresses minus the eddy-viscosity contribution
uu_tot = uu - vist_earsm*dudx
vv_tot = vv - vist_earsm*dvdy
uv_tot = uv - vist_earsm*(dudy+dvdx)

# friction velocity from the wall gradient u[0]/y[0] of the first cell centre,
# and wall distance in wall units for cell centres (yplus) and nodes (yplus_s)
ustar = (viscos*u[0]/y[0])**0.5
yplus = y*ustar/viscos
yplus_s = y_s*ustar/viscos

############################ 10000
# load DNS data
# %     y/h             y+             U+           u'+           v'+          w'+           uv'+         dU/dy+
DNS_mean = np.genfromtxt("P10k.txt", comments="%")
y_DNS = DNS_mean[:, 0]
yplus_DNS = DNS_mean[:, 1]
u_DNS = DNS_mean[:, 2]
u2_DNS = DNS_mean[:, 3]**2
v2_DNS = DNS_mean[:, 4]**2
w2_DNS = DNS_mean[:, 5]**2
uv_DNS = DNS_mean[:, 6]
dudy_DNS = np.gradient(u_DNS, yplus_DNS)
k_DNS = 0.5*(u2_DNS + v2_DNS + w2_DNS)

# Budget files, one per normal-stress component:
# %      y/h            y+         dissip        prod         p-strain       p-diff        T-diff        V-diff
DNS_uu = np.genfromtxt("P10k.uu.txt", comments="%")
eps_DNS_uu = abs(DNS_uu[:, 2])
visc_diff_uu = DNS_uu[:, 7]

DNS_vv = np.genfromtxt("P10k.vv.txt", comments="%")
eps_DNS_vv = abs(DNS_vv[:, 2])
visc_diff_vv = DNS_vv[:, 7]

# The w'w' budget comes from its own file, like uu and vv. Reading
# P10k.uu.txt here would make eps_DNS_ww / visc_diff_ww copies of the uu
# values and bias diss_DNS and visc_diff.
# NOTE(review): confirm P10k.ww.txt sits next to P10k.uu.txt / P10k.vv.txt.
DNS_ww = np.genfromtxt("P10k.ww.txt", comments="%")
eps_DNS_ww = abs(DNS_ww[:, 2])
visc_diff_ww = DNS_ww[:, 7]

# half the trace of the component budgets
diss_DNS = (eps_DNS_uu + eps_DNS_vv + eps_DNS_ww)/2
visc_diff = (visc_diff_uu + visc_diff_vv + visc_diff_ww)/2


# fix wall: copy the first off-wall value onto the wall point
diss_DNS[0] = diss_DNS[1]
k_DNS[0] = k_DNS[1]

pk_DNS = -uv_DNS*dudy_DNS
dkdy = np.gradient(k_DNS, yplus_DNS)
d2kdy2 = np.gradient(dkdy, yplus_DNS)
diss_DNS_org = diss_DNS
# subtract viscous diffusion and clip at zero
diss_DNS = np.maximum(diss_DNS - visc_diff, 0)

# NOTE(review): wherever the clipping above produced 0, tau_DNS is inf
tau_DNS = k_DNS/diss_DNS

dudy_squared_DNS = dudy_DNS**2
dudy_squared_scaled_DNS = dudy_DNS**2*tau_DNS**2


# NOTE(review): presumably dU/dy -> 0 at the channel centre, where this is inf
vist_DNS = abs(uv_DNS)/dudy_DNS

omega_DNS_from_vist = k_DNS/vist_DNS
# viscos = m2/s, diss = m2/s3 => omega => diss/viscos)**0.5
# vist = cmu*k2/eps => k = (vist*eps/cmu)**0.5
# omega = eps/k = eps/(vist*eps/cmu)**0.5 = (eps/vist/cmu)**0.5

# Select 40 DNS rows that are roughly equidistant on a logarithmic y+ axis:
# jDNS[i] is the row whose y+ is closest to 10**xx, with xx stepping
# 0, 0.2, 0.4, ... (xx is accumulated, exactly as before, not recomputed).
# These rows are used to thin out the DNS markers in the plots.
jDNS = []
xx = 0.
for i in range(40):
    i1 = np.abs(10.**xx - yplus_DNS).argmin()
    jDNS.append(int(i1))
    xx += 0.2

# Coefficient functions of rt = k/(om*viscos) (the names suggest low-Reynolds-
# number damping functions - confirm against the turbulence model used).
# NOTE(review): fk and fmu from this first set are overwritten a few lines
# below; fk_12 and fom are not read anywhere in the visible part of the script.
rt = k/om/viscos
fk = (0.278+(rt/8)**4)*(1+(rt/8)**4)**(-1)
fk_12 = (0.278+(rt/12)**4)*(1+(rt/12)**4)**(-1)
fmu = (0.025+rt/6)*(1+rt/6)**(-1)
fom = (0.1+rt/2.7)*((1+rt/2.7)*fmu)**(-1)


# Second set: these fk and fmu replace the ones above (rt is recomputed with
# the same expression, so its value does not change).
rt = k/om/viscos
fk=1-0.722*np.exp(-(rt/10)**4)
# dissipation from k and omega
diss=0.09*k*om
# NOTE(review): 1.e-3/rt divides by zero wherever rt == 0 (e.g. k == 0) - confirm k > 0 everywhere
fmu=0.025+(1-np.exp(-(rt/10)**0.75))*(0.975+(1.e-3/rt)*np.exp(-(rt/200)**2))

# turbulent viscosity (total minus molecular) and the production
# nu_t*(dU/dy)**2 scaled by viscos, shown as 'NN-plot' in the P^k figure
vist = vis-viscos
pk_plot = vist*dudy**2*viscos




########################################## time scale
# turbulent time scale k/eps with eps = 0.09*k*om (RANS) versus DNS k/eps
rt = k/om/viscos
fk = 1-0.722*np.exp(-(rt/10)**4)
diss = 0.09*k*om
ttau = k/diss
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, ttau/viscos, 'b-', label='RANS, NN')
ax1.plot(yplus_DNS, tau_DNS, 'r--', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_ylabel(r"$t^+$")
ax1.set_xlabel("$y^+$")
ax1.axis([0, 500, 0, 1100])
plt.savefig('t_10000-channel-zoom.png', bbox_inches='tight')


########################################## dudy_squared_scaled
# Velocity gradient made dimensionless with the turbulent time scale
# tau = k/eps: ttau (RANS, time-scale section) and tau_DNS (DNS). dudy_DNS
# already holds np.gradient(u_DNS, yplus_DNS), so it is not recomputed here.
dudy_DNS_scaled = dudy_DNS*tau_DNS
dudy_scaled = dudy*ttau
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, dudy_scaled**2, 'b-', label='RANS, NN')
ax1.plot(yplus_DNS, dudy_DNS_scaled**2, 'r--', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
# the plotted quantity is (tau*dU/dy)^2, not the bare (dU/dy)^2
ax1.set_ylabel(r"$\left(\tau\,\partial U/\partial y\right)^2$")
ax1.set_xlabel("$y^+$")
ax1.axis([0, 400, 0, 100])
plt.savefig('dudy_squared_scaled_10000-channel-zoom.png', bbox_inches='tight')



########################################## diss
# dissipation profile: saved RANS field (times viscos) versus DNS
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, diss_saved*viscos, 'b--', label='RANS, NN')
ax1.plot(yplus_DNS, diss_DNS, 'r-', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_xlabel("$y^+$")
ax1.set_ylabel(r"$\varepsilon^+$")
ax1.axis([0, 200, 0, 0.5])
plt.savefig('diss.png', bbox_inches='tight')





########################################## pk
# production profiles plotted as P^k (horizontal) against y+ (vertical)
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(pk, yplus, 'b-', label='NN')
ax1.plot(pk_plot, yplus, 'b--', label='NN-plot')
ax1.plot(pk_DNS, yplus_DNS, 'r-', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
# y+ is on the vertical axis; it was previously set as an x-label and then
# immediately overwritten by the P^k x-label, leaving the y-axis unlabelled
ax1.set_ylabel("$y^+$")
ax1.set_xlabel(r"$P^k$")
ax1.axis([0, 0.3, 0, 140])
plt.savefig('pk.png', bbox_inches='tight')



########################################## U
# mean velocity U+ versus y+ on a logarithmic x-axis; the DNS is drawn only
# at the log-spaced rows jDNS
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.semilogx(yplus, u, 'b-')
ax1.yaxis.set_label_coords(-.1, 0.5)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.semilogx(yplus_DNS[jDNS], u_DNS[jDNS], 'b--')
ax1.set_ylabel("$U^+$")
ax1.set_xlabel("$y^+$")
ax1.axis([1, 10000, 0, 30])
plt.savefig('u_log_10000-channel.png', bbox_inches='tight')



########################################## U  lin
# mean velocity U+ versus y+ on a LINEAR x-axis, zoomed to y+ <= 100
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, u, 'b-')
ax1.yaxis.set_label_coords(-.1, 0.5)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
# plain plot, not semilogx: semilogx switches the whole axes to a log
# x-scale, which turned this "linear" figure into a log one
ax1.plot(yplus_DNS[jDNS], u_DNS[jDNS], 'b--')
ax1.set_ylabel("$U^+$")
ax1.set_xlabel("$y^+$")
ax1.axis([1, 100, 0, 30])
plt.savefig('u_lin_10000-channel-zoom.png', bbox_inches='tight')


########################################## uv
# total shear stress (EARSM) and DNS u'v' plotted against y+ (vertical)
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(uv_tot, yplus, 'b-', label='EARSM')
ax1.plot(uv_DNS, yplus_DNS, 'b--')
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.set_xlabel(r"$\overline{u'v'}$")
ax1.set_ylabel("$y^+$")
ax1.axis([-1, 0, 0, 10000])
plt.savefig('uv_10000-channel.png', bbox_inches='tight')


########################################## d(uv)/dy
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
i1 = 0
# compute shear stress and its wall-normal derivative
vist = vis-viscos                        # turbulent viscosity (total minus molecular)
dudy = np.gradient(u, y)
uv_bouss = -vist*dudy                    # eddy-viscosity u'v' (computed, not plotted)
duv_dy = np.gradient(uv_tot, y)
duv_dy_bouss = np.gradient(uv_bouss, y)
duv_dy_DNS = np.gradient(uv_DNS, y_DNS)

ax1.plot(yplus, duv_dy, 'b-', label='EARSM')
ax1.plot(yplus_DNS[jDNS], duv_dy_DNS[jDNS], 'bo')
ax1.set_ylabel(r"$\frac{\partial\overline{u'v'}}{\partial y}$")
ax1.set_xlabel("$y^+$")
ax1.axis([1, 1100, -1, 2])
plt.savefig('duv_dy-10000-channel.png', bbox_inches='tight')


########################################## k
# turbulent kinetic energy across the whole channel half
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, k, 'b-', label='RANS, NN, EARSM')
ax1.plot(yplus_DNS, k_DNS, 'r--', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_ylabel(r"$k^+$")
ax1.set_xlabel("$y^+$")
ax1.axis([1, 10000, -1.1, 6.5])
plt.savefig('k_10000-channel.png', bbox_inches='tight')

########################################## k  zoom
# turbulent kinetic energy, near-wall region only
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(yplus, k, 'b-', label='RANS, NN')
ax1.plot(yplus_DNS, k_DNS, 'r--', label='DNS')
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_ylabel(r"$k^+$")
ax1.set_xlabel("$y^+$")
ax1.axis([1, 200, 0, 5.5])
plt.savefig('k_10000-channel-zoom.png', bbox_inches='tight')

########################################## vis
# turbulent-to-molecular viscosity ratio: EARSM versus k-omega
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
i1 = 0
ax1.plot(yplus, vist_earsm/viscos, 'r-', label='RANS, EARSM')
ax1.plot(yplus, vist/viscos, 'b-', label=r'RANS, $k-\omega$')
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_ylabel(r"$\nu_t/\nu$")
ax1.set_xlabel("$y^+$")
ax1.axis([0, 10000, 0, 1500])
plt.savefig('vis_10000-channel.png', bbox_inches='tight')



########################################## uu, vv, ww
# the three normal stresses (horizontal) against y+ (vertical), RANS vs DNS
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(uu, yplus, 'b-', label=r"$\overline{u'^2}$")
ax1.plot(vv, yplus, 'r-', label=r"$\overline{v'^2}$")
ax1.plot(ww, yplus, 'k-', label=r"$\overline{w'^2}$")
ax1.plot(u2_DNS, yplus_DNS, 'b--', label=r"DNS, $\overline{u'^2}$")
ax1.plot(v2_DNS, yplus_DNS, 'r--', label=r"DNS, $\overline{v'^2}$")
ax1.plot(w2_DNS, yplus_DNS, 'k--', label=r"DNS, $\overline{w'^2}$")
ax1.legend(loc="best", prop=dict(size=14))
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.set_ylabel("$y^+$")
ax1.set_xlabel("normal stresses")
ax1.axis([0, 9.5, 0, 10000])
plt.savefig('uu-vv-ww-10000-channel.png', bbox_inches='tight')

########################################## uu, vv
# streamwise and wall-normal normal stresses against y+, RANS vs DNS
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(uu, yplus, 'b-', label=r"$\overline{u'^2}$")
ax1.plot(vv, yplus, 'r-', label=r"$\overline{v'^2}$")
ax1.plot(u2_DNS, yplus_DNS, 'b--', label=r"DNS, $\overline{u'^2}$")
ax1.plot(v2_DNS, yplus_DNS, 'r--', label=r"DNS, $\overline{v'^2}$")
ax1.legend(loc="best", prop=dict(size=14))
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.set_ylabel("$y^+$")
ax1.set_xlabel("normal stresses")
ax1.axis([0, 9.5, 0, 10000])
plt.savefig('uu-vv-10000-channel.png', bbox_inches='tight')

########################################## (uu+vv+ww)/2 and k
# consistency check: half the trace of the normal stresses against k
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
# the curve is 0.5*(uu+vv+ww), so the legend entry carries the factor 1/2
ax1.plot(0.5*(uu+vv+ww), yplus, 'b-',
         label=r"$\frac{1}{2}\left(\overline{u'^2}+\overline{v'^2}+\overline{w'^2}\right)$")
ax1.plot(k, yplus, 'k-', label="$k$")
ax1.legend(loc="best", prop=dict(size=14))
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.set_ylabel("$y^+$")
ax1.set_xlabel("normal stresses")
ax1.axis([0, 9.5, 0, 10000])
plt.savefig('uu-vv-ww-and-k-10000-channel.png', bbox_inches='tight')


########################################## uu, vv, ww  zoom
# near-wall zoom of the normal stresses (horizontal) against y+ (vertical),
# laid out like the DNS curves and the axis labels below
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
# arguments in (x, y, fmt) order: the former plot(uu, 'b-', yplus) form drew
# uu and yplus as two separate curves against their array index
ax1.plot(uu, yplus, 'b-', label=r"$\overline{u'^2}$")
ax1.plot(vv, yplus, 'r-', label=r"$\overline{v'^2}$")
ax1.plot(ww, yplus, 'k-', label=r"$\overline{w'^2}$")
ax1.plot(u2_DNS, yplus_DNS, 'b--', label=r"DNS, $\overline{u'^2}$")
ax1.plot(v2_DNS, yplus_DNS, 'r--', label=r"DNS, $\overline{v'^2}$")
ax1.plot(w2_DNS, yplus_DNS, 'k--', label=r"DNS, $\overline{w'^2}$")
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_xlabel("normal stresses")
ax1.set_ylabel("$y^+$")
# [xmin, xmax, ymin, ymax]: stresses on x, near-wall y+ range on y
ax1.axis([0, 9.5, 0, 100])
plt.savefig('uu-vv-ww-10000-channel-zoom.png', bbox_inches='tight')


########################################## beta1
# beta1 profile (line plus markers) against the constant -2*C_mu = -0.18
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(beta1, yplus, 'b-', label=r'$\beta_1$')
ax1.plot(beta1, yplus, 'bo')
cmu2 = -0.18*np.ones(len(yplus))
ax1.plot(cmu2, yplus, 'b--', label=r'$-2 C_\mu$')
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.legend(loc="best", prop=dict(size=14))
ax1.set_ylabel("$y^+$")
ax1.axis([-0.2, -0.15, 0, 10000])
plt.savefig('beta1-10000-channel.png', bbox_inches='tight')



########################################## beta
# beta2 profile (horizontal) against y+ (vertical)
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(beta2, yplus, 'b-')
# y+ belongs on the vertical axis; it was set as an x-label and then
# immediately overwritten, leaving the y-axis unlabelled
ax1.set_ylabel("$y^+$")
ax1.set_xlabel(r"$\beta_2$")
ax1.axis([0, 0.2, 0, 10000])
plt.savefig('beta2-10000-channel.png', bbox_inches='tight')

########################################## beta
# beta4 profile (horizontal) against y+ (vertical)
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
ax1.plot(beta4, yplus, 'b-')
# y+ belongs on the vertical axis; it was set as an x-label and then
# immediately overwritten, leaving the y-axis unlabelled
ax1.set_ylabel("$y^+$")
ax1.set_xlabel(r"$\beta_4$")
ax1.axis([-0.1, 0, 0, 10000])
plt.savefig('beta4-10000-channel.png', bbox_inches='tight')


########################################## beta1, 2, 4
# the three EARSM coefficient profiles in one figure
fig1, ax1 = plt.subplots()
fig1.subplots_adjust(left=0.20, bottom=0.20)
i1 = 0
ax1.plot(beta1, yplus, 'b-', label=r'$\beta_1$')
ax1.plot(beta2, yplus, 'r-', label=r'$\beta_2$')
ax1.plot(beta4, yplus, 'k-', label=r'$\beta_4$')
ax1.yaxis.set_label_coords(-.1, 0.6)
M = 3
yticks = ticker.MaxNLocator(M)
ax1.yaxis.set_major_locator(yticks)
ax1.set_ylabel("$y^+$")
ax1.set_xlabel(r"$\beta_1, \quad \beta_2, \quad \beta_4$")
ax1.axis([-0.2, 0.7, 0, 10000])
ax1.legend(loc="best", prop=dict(size=14))
plt.savefig('beta1-2-4-10000-channel.png', bbox_inches='tight')

