# coding: utf-8

# In[1]:
import os
import time
import itertools
import mako.template
import numpy as np
import matplotlib
matplotlib.use('Qt4Agg')

import matplotlib.pyplot as plt
#%matplotlib inline


# In[2]:

# Grid-generation settings (lengths in metres, matching the 'X (m)' / 'Zb (m)' axes).
# dxmin, dxmax: cross-shore spacing bounds; dxmax is the spacing at the depth of
#               the first profile point, and cells never get narrower than dxmin.
dxmin, dxmax = 5, 20
# dy, ny: alongshore spacing and cell count (the grid has ny + 1 alongshore rows).
dy, ny = 20, 50
# Depth floor used when sizing cells over shallow or dry points.
hmin = .1
# Bed level assigned to grid nodes inside a user-drawn polygon.
damheight = 2

# Presumably XBeach forcing parameters; not referenced elsewhere in this script
# (TODO confirm). Note zs0 is rebound later from the model state.
Hrms, Tm01, dir0, zs0 = 1.5, 10, 250, 0


# In[3]:

# Input cross-shore profile: one (x, z) pair per row, x increasing while the
# bed level z rises from -10 to +5.
xz = np.array([(-500, -10), (-200, -5), (0, 0), (100, 5)])


# In[4]:

# Split the profile into its x coordinates and bed levels, then seed the
# generated grid with the first profile point.
xdata, zdata = xz[:, 0], xz[:, 1]

xgrid = [xdata[0]]
zgrid = [zdata[0]]
xgrid  # notebook display of the seeded grid


# In[5]:

# Build a variable-resolution cross-shore grid. The spacing scales with the
# square root of the local depth relative to the depth at the first profile
# point, so cells are dxmax wide at that depth and narrower in shallower water
# (never below dxmin). Points are appended until the grid reaches xdata[-1].
#
# The reference depth is floored at hmin, like the local depth below. Without
# that, a profile whose first point lies at or above z = 0 gives h0 <= 0, which
# makes sqrt(h / h0) NaN (or divides by zero); max(nan, dxmin) stays NaN and the
# loop silently ends with a corrupt grid. float() also avoids integer division
# on Python 2, since the profile array is integer-typed.
h0 = max(-float(zdata[0]), hmin)

while xgrid[-1] < xdata[-1]:
    # Local depth at the previous grid point, floored for shallow/dry points.
    h = max(-float(zgrid[-1]), hmin)
    dx = max(dxmax * np.sqrt(h / h0), dxmin)
    xgrid.append(xgrid[-1] + dx)
    # The final point may overshoot xdata[-1]; np.interp then clamps z to zdata[-1].
    zgrid.append(np.interp(xgrid[-1], xdata, zdata))


# In[7]:

# Overlay the input profile points (+) and the generated grid nodes (.).
fig, ax = plt.subplots()
ax.plot(xdata, zdata, '+')
ax.plot(xgrid, zgrid, '.')
ax.set(xlabel='X (m)', ylabel='Zb (m)', title='Cross-shore profile and grid')
plt.show()


# In[11]:




# In[84]:

# Extrude the cross-shore profile alongshore: every column j of the
# (len(xgrid), ny + 1) grid shares the same x and bed level, with y = j * dy.
x = np.tile(np.asarray(xgrid, dtype=float)[:, np.newaxis], (1, ny + 1))
y = np.tile(np.arange(ny + 1, dtype=float) * dy, (len(xgrid), 1))
zb = np.tile(np.asarray(zgrid, dtype=float)[:, np.newaxis], (1, ny + 1))

# Uniform 10.0 everywhere; presumably the XBeach non-erodible layer thickness
# (ne_layer) — TODO confirm against params_template.txt.
ne_layer = np.full_like(zb, 10.0)

# Model directory. Set XBEACH_MODEL_DIR to point elsewhere instead of editing
# the hard-coded default; every path below is derived from it.
dirname = os.environ.get('XBEACH_MODEL_DIR', '/Users/baart_f/models/xbeach/interactive2')

# Each grid is written transposed, i.e. one text line per alongshore index j.
for dep_name, dep_values in (('x', x), ('y', y), ('zb', zb), ('ne_layer', ne_layer)):
    np.savetxt(os.path.join(dirname, dep_name + '.dep'), dep_values.T)

# Render params.txt from the Mako template. nx/ny are the number of grid nodes
# minus one along each axis. The forcing constants from the configuration cell
# are passed too, so the template can use them; Mako ignores keyword arguments
# the template does not reference.
template = mako.template.Template(filename=os.path.join(dirname, 'params_template.txt'))
with open(os.path.join(dirname, 'params.txt'), 'w') as params_file:
    params_file.write(template.render(
        nx=x.shape[0] - 1, ny=x.shape[1] - 1,
        Hrms=Hrms, Tm01=Tm01, dir0=dir0, zs0=zs0))

from bmi.wrapper import BMIWrapper

# Load the XBeach engine through the BMI wrapper, using the params.txt rendered
# into the model directory (same `dirname` the grid files were written to).
wrapper = BMIWrapper(engine="xbeach", configfile=os.path.join(dirname, 'params.txt'))
wrapper.initialize()

# Snapshot the initial model state. .copy() detaches these from the arrays
# returned by get_var, which may share memory with the running model (confirm).
zs0 = wrapper.get_var('zs').copy()
zb0 = wrapper.get_var('zb').copy()
uu0 = wrapper.get_var('uu').copy()
vv0 = wrapper.get_var('vv').copy()

# Spin the model up before plotting; update(-1) presumably lets the engine
# choose its own time step (TODO confirm for the xbeach library).
for _ in range(500):
    wrapper.update(-1)

zb = wrapper.get_var('zb')
u = wrapper.get_var('u')
v = wrapper.get_var('v')
H = wrapper.get_var('H')

plt.ion()
fig, ax = plt.subplots()
# Model fields are plotted against the transposed grid, presumably because
# get_var returns them indexed (alongshore, cross-shore) — see the shape print.
contour = ax.contour(x.T, y.T, zb)
pc_H = ax.pcolor(x.T, y.T, H, cmap='Blues')

print(u.shape, v.shape, x.shape, y.shape)
qv = ax.quiver(x.T, y.T, u, v, units='xy')
ax.axis('equal')
ax.set_xlabel('X (m)')
ax.set_ylabel('Y (m)')
fig.colorbar(pc_H, ax=ax)
plt.show()


import shapely.geometry


def _polygon_mask(points, xs, ys):
    """Return a bool array shaped like ``xs`` marking nodes strictly inside a polygon.

    points : sequence of at least 3 (x, y) vertices.
    xs, ys : equally shaped arrays of node coordinates.
    Nodes exactly on the outline are excluded (shapely ``contains`` semantics).
    """
    poly = shapely.geometry.Polygon(points)
    mask = np.zeros_like(xs, dtype=bool)
    for idx, x_node in np.ndenumerate(xs):
        mask[idx] = poly.contains(shapely.geometry.Point(x_node, ys[idx]))
    return mask


plt.draw()

# Interactive loop: advance the model, refresh the plot, and let the user draw
# polygons whose enclosed grid nodes get their bed level set to damheight.
# Runs until the figure window is closed.
while plt.fignum_exists(fig.number):
    for _ in range(10):
        wrapper.update(-1)

    # Re-fetch the fields each frame instead of relying on the arrays obtained
    # earlier staying in sync with the model.
    zb = wrapper.get_var('zb')
    H = wrapper.get_var('H')
    u = wrapper.get_var('u')
    v = wrapper.get_var('v')

    ax.set_title('time: %.2f' % (wrapper.get_current_time() / (24*3600.0)))

    # ContourSet.set_array does not recompute contour lines, so drop the old set
    # and re-contour the current bed with the initial levels. Newer matplotlib
    # has ContourSet.remove(); older versions need each collection removed.
    levels = contour.levels
    if hasattr(contour, 'remove'):
        contour.remove()
    else:
        for coll in contour.collections:
            coll.remove()
    contour = ax.contour(x.T, y.T, zb, levels=levels)

    # Matches the last row/column that pcolor drops for flat shading.
    pc_H.set_array(H[:-1, :-1].flatten())
    qv.set_UVC(u, v)
    fig.canvas.draw()

    # Quick poll; a click here starts a polygon and is kept as its first vertex.
    points = list(plt.ginput(timeout=0.01))
    if not points:
        continue

    outline, = ax.plot([p[0] for p in points], [p[1] for p in points], '+-')
    while True:
        # Left click adds a vertex; right click (or the ginput timeout) closes it.
        click = plt.ginput(1, mouse_stop=3, mouse_pop=2)
        if not click:
            break
        points.append(click[0])
        outline.set_data([p[0] for p in points], [p[1] for p in points])
        fig.canvas.draw_idle()

    if len(points) < 3:
        # Not a polygon (shapely would raise ValueError); discard the outline.
        outline.remove()
        fig.canvas.draw_idle()
        continue

    inpoly = _polygon_mask(points, x, y)
    zb = wrapper.get_var('zb')
    # Written in place with no set_var call: this relies on get_var returning
    # the model's own array (presumably true for bmi.wrapper — confirm).
    zb[inpoly.T] = damheight

wrapper.finalize()
