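# Evaluate the AC excitation measurements: track the density dips in each time-evolution image, fit their velocities and plot the result.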
import os.path, glob
from os.path import join
import json
import pandas as pd
from scipy.ndimage import gaussian_filter
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as pl

import h5py
from helper import *
from style import *


# Directory that holds the supplementary HDF5 measurement files.
path_data_repo = "../data_supplementary"

# Settings for each AC excitation measurement.
#   path          HDF5 file name inside ``path_data_repo``
#   x_min, x_max  pixel columns kept when cropping the image
#   label         panel label drawn in the top-left corner
#   dips          dips to track; for each one:
#     slice         (start, stop) pixel-column window of the cropped image
#                   that is searched for the density minimum
#     y_min, y_max  time window (ms) whose points enter the linear fit
#     p0            initial guess [m, b] for the fit model t = x / m + b
#     sigma         Gaussian smoothing width used before searching the minimum
# The remaining keys (mod_amp, mod_freq, velocity, color, v_max) are not read
# by the evaluation loop in this file.
meas_list_ac = [
    dict(
        path="Excitation_vel_exp_AC_density_time_evolution_I_0.88.h5",
        mod_amp=0.6,
        mod_freq=90,
        velocity=0.37,
        color="purple",
        v_max=0.8,
        x_min=20,
        x_max=375,
        label=r"$\mathrm{II}$",
        dips=[
            # y_min was previously 16
            dict(slice=(65, 105), y_max=32, y_min=19, p0=[-1.5, 25], sigma=3),
        ],
    ),
    dict(
        path="Excitation_vel_exp_AC_density_time_evolution_I_1.31.h5",
        mod_amp=0.6,
        mod_freq=90,
        velocity=0.55,
        color="purple",
        v_max=0.8,
        x_min=10,
        x_max=365,
        label=r"$\mathrm{III}$",
        dips=[
            dict(slice=(0, 80), y_max=23.8, y_min=9, p0=[-0.3, -25], sigma=3),
            dict(slice=(65, 105), y_max=32, y_min=23, p0=[-1.5, 25], sigma=3),
        ],
    ),
]

# Apply the shared plot style before any figure is created.
set_plot_style()

# Two side-by-side panels, one per evaluated measurement, sharing the time axis.
fig, ax = pl.subplots(nrows=1, ncols=2, sharey=True, figsize=(3.5 * 1.61, 3), dpi=300)
axs = ax.flatten()

# Speeds of sound in mm/s for the right-moving (_R) and left-moving (_L)
# reference trajectories.
c_s_R, c_s_L = 1.69, -1.38

# Colour map and symmetric colour limits (-v_max .. v_max) for the density images.
cmap = cmcrameri.cm.vik
v_max = 1000
def linear(x, m, b):
    """Return t = x / m + b, a straight trajectory moving with velocity ``m``."""
    return x / m + b


def linear_c_s(x, m, x0, b):
    """Return t = (x - x0) / m + b, a trajectory of velocity ``m`` through (x0, b)."""
    return (x - x0) / m + b


def _track_dip_minimum(image, dip, c, t_max):
    """Locate the per-row density minimum inside ``dip["slice"]`` of a smoothed image.

    Parameters
    ----------
    image : 2-D array
        Cropped density image; rows are time steps, columns are pixels.
    dip : dict
        Dip settings; uses ``"sigma"`` (smoothing width) and ``"slice"``
        ((start, stop) pixel-column window).
    c : float
        Pixel-to-position conversion factor (the file's ``conversion_to_um``).
    t_max : float
        Time of the last row boundary; row ``i`` is mapped to ``i / n_rows * t_max``.

    Returns
    -------
    (x_vals, t_vals) : tuple of 1-D arrays
        Position of the minimum and the time of every row.
    """
    smoothed = gaussian_filter(image, sigma=dip["sigma"])
    cols = slice(*dip["slice"])
    offset = dip["slice"][0]
    # argmin along axis 1 returns the first minimum of each row, like a per-row argmin.
    x_vals = c * (np.argmin(smoothed[:, cols], axis=1) + offset)
    n_rows = smoothed.shape[0]
    t_vals = np.arange(n_rows) / n_rows * t_max
    return x_vals, t_vals


# Toggle the per-dip velocity fit; the density images are drawn either way.
fit = True

# Evaluate the last two measurements, one panel each.
for a, meas in enumerate(meas_list_ac[-2:]):
    path = join(path_data_repo, meas["path"])
    # The context manager closes the HDF5 file once everything has been read.
    with h5py.File(path, 'r') as h5_file:
        print(h5_file.attrs.keys())
        pixel_size = h5_file.attrs['pixel_size']
        magnification = h5_file.attrs['magnification']
        c = h5_file.attrs['conversion_to_um']
        t_max = h5_file.attrs['t_max']
        # transform the atomic density back to atom space [N/µm]; Pixel size = 6.45 µm/Pixel
        image = np.asarray(h5_file['image']) / pixel_size * magnification

    # Crop to the configured pixel columns.
    image = image[:, meas['x_min']:meas['x_max']]
    print(np.shape(image))

    extent = (0, c * image.shape[1], 0, t_max)

    panel = axs[a]
    panel.imshow(gaussian_filter(image, sigma=1), cmap=cmap, vmin=-v_max, vmax=v_max,
                 aspect='auto', origin='lower', extent=extent)
    panel.set_xlabel('x [µm]', fontsize=12)
    panel.text(x=38, y=36, s=meas['label'], ha='center', verticalalignment='center', fontsize=10,
               bbox={'boxstyle': 'round', 'fc': "white", 'ec': 'orange', 'ls': '-', 'lw': 1.5},
               zorder=5)

    if not fit:
        continue

    velocities = []
    velocities_err = []
    for dip in meas["dips"]:
        x_vals, t_vals = _track_dip_minimum(image, dip, c, t_max)

        # Fit only the time window in which this dip is tracked reliably.
        window = (t_vals < dip["y_max"]) & (t_vals > dip["y_min"])
        p_opt, p_cov = curve_fit(linear, x_vals[window], t_vals[window], p0=dip["p0"])
        p_err = np.sqrt(np.diag(p_cov))

        panel.errorbar(x_vals[window], t_vals[window], **get_style(color='orange'), markersize=5)
        panel.plot(x_vals[window], linear(x_vals[window], *p_opt), color='forestgreen', lw=2.5)

        # float() keeps the printed lists free of NumPy scalar reprs (e.g. np.float64(...)).
        velocities.append(round(float(p_opt[0]), 2))
        velocities_err.append(round(float(abs(p_err[0])), 2))

    print(rf'Excitation velocities: {velocities} $\pm$ {velocities_err} mm/s')

# Speed-of-sound reference trajectories starting at x0 = 20 at t = 0, drawn once
# on the left panel.
x_speed_of_sound_R = np.linspace(20, 40, 1000)
x_speed_of_sound_L = np.linspace(3, 20, 1000)
for x_ref, c_s in ((x_speed_of_sound_R, c_s_R), (x_speed_of_sound_L, c_s_L)):
    axs[0].plot(x_ref, linear_c_s(x_ref, c_s, x0=20, b=0),
                color=RPTU_COLORS['himbeere'], lw=2.5)


# The panels share the time axis, so only the left one carries the y label.
axs[0].set_ylabel(ylabel="Time [ms]", fontsize=12)

fig.tight_layout()
# Uncomment to also write the figure to disk:
# fig.savefig("soliton_velocity.png")

pl.show()