import os.path, glob
from os.path import join
import json

import pandas as pd
import h5py
import numpy as np
import matplotlib.pyplot as pl

from helper import *
from style import *

# Directory holding the absorption-image HDF5 files shown in this figure.
path_data_repo = "../Fig_1/absorption_images"

# One entry per panel: the HDF5 file name and the index of the subplot
# (top to bottom) the image is drawn into.
single_realization_dict = {
    'image_1': {"name": "Exp_absimage_v_34632.03.h5", "ax": 0},
    'image_2': {"name": "Exp_absimage_v_131601.73.h5", "ax": 1},
    'image_3': {"name": "Exp_absimage_v_176623.38.h5", "ax": 2},
}

# Apply the plot style (provided by the star-imported helper/style modules,
# presumably matplotlib rcParams) before the figure is created.
set_plot_style()

# Font sizes for axis labels and for tick labels.
fontsize = 15
fontsize_ticks = 12

# Three vertically stacked panels that share the x axis.
fig, axs = pl.subplots(3, 1, sharex=True, figsize=(6.5, 4))
ax = axs.flatten()  # 1-D array of Axes, indexed by the "ax" entries above

# Normalisation constant for the displayed current (the critical current).
# It does not depend on the image, so compute it once instead of per iteration.
I_c = current(1, v_c=1)  # get the critical current from the func keyword parameter

for val in single_realization_dict.values():
    axis = ax[val["ax"]]  # subplot assigned to this image

    # Read the image and its metadata, closing the HDF5 file as soon as we are done.
    path = os.path.join(path_data_repo, val["name"])
    with h5py.File(path, 'r') as h5_file:
        # [()] loads the whole dataset into memory so it stays usable after the
        # file is closed (an h5py Dataset handle does not).
        img = h5_file['image'][()]

        # get meta_data
        pixel_size = h5_file.attrs['pixel_size']
        magnification = h5_file.attrs['magnification']
        velocity = h5_file.attrs['velocity']  # in Hz/ms
        barrier_position = h5_file.attrs['barrier_position']  # barrier position in µm, with zero at the image center

    # Set the extent: physical image size (pixel size / magnification, in µm per
    # the axis labels), centred on zero in both directions.
    x_max = img.shape[1] * pixel_size / magnification
    y_max = img.shape[0] * pixel_size / magnification
    extent = (-x_max / 2, x_max / 2, -y_max / 2, y_max / 2)

    # plot the image at the specified plot position (fixed upper colour limit of 50)
    axis.imshow(img, cmap=cmcrameri.cm.lipari, extent=extent, vmax=50)

    # Applied current expressed in units of the critical current.
    Cur = current(velocity * scaling, v_c=0.42) / I_c

    axis.set_ylabel("z [µm]", fontsize=fontsize)

    # draw the barrier
    axis.axvline(barrier_position, color=RPTU_COLORS["mango"], ls='--', lw=2)
    # draw the text box (position given in data coordinates, i.e. the extent units)
    axis.text(x=-35, y=-5.5, s=fr"I={Cur:.1f} $I_\mathrm{{c}}$", ha='center', verticalalignment='center',
              fontsize=12,
              bbox={'boxstyle': 'round', 'fc': "white", 'ec': RPTU_COLORS['schiefer'], 'ls': '-',
                    'lw': 1.5}, zorder=5)

    axis.get_xaxis().set_ticks([-40, -20, 0, 20, 40])
    axis.get_xaxis().set_ticklabels([-40, -20, 0, 20, 40], fontsize=fontsize_ticks)
    axis.get_yaxis().set_ticks([-5, 0, 5])
    axis.get_yaxis().set_ticklabels([-5, 0, 5], fontsize=fontsize_ticks)

# Label the x axis only on the bottom panel (the panels share x via sharex=True).
ax[2].set_xlabel("x [µm]", fontsize=fontsize)
fig.tight_layout()
# To save the figure, do it BEFORE pl.show(): after the window is closed pyplot
# releases the figure, and pl.savefig would then write a new, empty figure.
# fig.savefig(join("side_image_mwi", "image_stack_single.png"))
pl.show()

