import os.path, glob
from os.path import join
import json

import pandas as pd
from scipy.ndimage import gaussian_filter1d
from scipy.optimize import minimize_scalar, curve_fit
from scipy.signal import find_peaks
from scipy.special import jv
import scipy.constants as c

import numpy as np
import matplotlib.pyplot as pl

from helper import *
from style import *


# Directory holding the supplementary CSV exports for Fig. S3.
path_data_repo = "../data_supplementary"

# Plain literals: these file names contain no interpolated fields.
density = pd.read_csv(join(path_data_repo, "Fig_S3_Josephson_oscillation_density_data.csv"))
phase = pd.read_csv(join(path_data_repo, "Fig_S3_Josephson_oscillation_phase_data.csv"))




# Width, in samples, of the Gaussian smoothing applied to both the density and
# the phase records.
sigma = 3
####################################################################################
####################################################################################
# Density

density.set_index('Unnamed: 0', inplace=True)
# Shift density around zero.
density['delta_N'] -= density['delta_N'].mean()
density['delta_N_s'] = gaussian_filter1d(density['delta_N'], sigma=sigma)

# Backward finite difference of the smoothed imbalance: row i holds
# (N_s[i] - N_s[i-1]) / (t[i] - t[i-1]); row 0 has no predecessor and stays 0.
# Columns are addressed by name (not by iloc position), so the result does not
# depend on the column order of the CSV.
_t_density = density.index.to_numpy()
_dN_dt = np.zeros(len(density))
_dN_dt[1:] = np.diff(density['delta_N_s'].to_numpy()) / np.diff(_t_density)
# Error of the derivative: per-sample error divided by the mean time step.
_dN_dt_err = np.zeros(len(density))
_dN_dt_err[1:] = density['delta_N_err'].to_numpy()[1:] / np.mean(np.diff(_t_density))
density['dN_dt'] = _dN_dt
# NOTE(review): dividing the smoothed error by sigma is an ad-hoc reduction — confirm intended.
density['dN_dt_err'] = gaussian_filter1d(_dN_dt_err, sigma=sigma) / sigma

####################################################################################
####################################################################################
# Phase

phase.set_index('Unnamed: 0', inplace=True)
phase['phase_s'] = gaussian_filter1d(phase['phase'], sigma=sigma)
# TODO remove this random reduction in phase error
phase['phase_err'] /= sigma

# Backward finite difference of the smoothed phase: row i holds
# (phi_s[i] - phi_s[i-1]) / (t[i] - t[i-1]); row 0 has no predecessor and stays 0.
# Columns are addressed by name (not by iloc position), so the result does not
# depend on the column order of the CSV.
_t_phase = phase.index.to_numpy()
_dphi_dt = np.zeros(len(phase))
_dphi_dt[1:] = np.diff(phase['phase_s'].to_numpy()) / np.diff(_t_phase)
# Error of the derivative: (already reduced) phase error over the mean time step.
_dphi_dt_err = np.zeros(len(phase))
_dphi_dt_err[1:] = phase['phase_err'].to_numpy()[1:] / np.mean(np.diff(_t_phase))
phase['dphi_dt'] = _dphi_dt
phase['dphi_dt_err'] = gaussian_filter1d(_dphi_dt_err, sigma=sigma) / sigma

# Smoothed phase without its final sample; the fits below pair it with the
# density rows by position.
# NOTE(review): presumably the phase record has exactly one more sample than the
# density record — confirm the two records are aligned sample by sample.
phases = phase['phase_s'].to_numpy()[:-1]

## Minimize the distance between the curves

def error(i_c):
    """Return the summed squared residual between dN/dt and ``i_c * sin(phi)``.

    Reads the module-level ``density`` frame and ``phases`` array. Only the
    commented-out ``minimize_scalar`` call refers to it, and the name ``error``
    is rebound further down in this script.
    """
    residual = density['dN_dt'] - i_c * np.sin(phases)
    return np.sum(residual ** 2)

def func(p,ic):
    """Linear model ``ic * p`` handed to ``curve_fit`` to fit I_c against sin(phi)."""
    scaled = ic * p
    return scaled

# Fit dN/dt = I_c * sin(phi); 190e3 is the starting guess for I_c.
popt, pcov = curve_fit(func,np.sin(phases),density['dN_dt'],p0=[190e3])
print('Curve fit IC:',popt)
# One-sigma parameter uncertainties are the square roots of the covariance
# diagonal. Take the diagonal first: sqrt of the full matrix yields NaN for
# negative off-diagonal covariances.
print('Curve fit IC error:',np.sqrt(np.diag(pcov)))

# result = minimize_scalar(error, bounds=(100e3, 1e6), method='bounded')
# I_c = result.x
I_c = popt[0]

def error(c):
    """Return the summed squared residual of d(phi)/dt = -N / c for a trial ``c``.

    Samples are paired by position, the same way as in the ``curve_fit`` call
    for C: the phase derivative drops its final sample. Adding the two pandas
    Series directly would instead align them on their time index. Any label
    present in only one series would then become NaN and be silently skipped
    by the sum.

    This definition rebinds the module-level name ``error``. Inside the
    function, ``c`` shadows the ``scipy.constants`` alias.
    """
    dphi_dt = phase['dphi_dt'].to_numpy()[:-1]
    n_s = density['delta_N_s'].to_numpy()
    return np.sum((dphi_dt + n_s / c) ** 2)


# result = minimize_scalar(error, bounds=(1, 200), method='bounded')
# C = result.x


def func_c(N,c):
    """Model d(phi)/dt = -N / c, handed to ``curve_fit`` to fit the capacitance C.

    ``c`` is the fit parameter; inside this function it shadows the
    module-level ``scipy.constants`` alias.
    """
    return -(N / c)

# Fit d(phi)/dt = -N/C; 59 is the starting guess for C. The phase derivative
# drops its final sample, presumably so it matches the density length — confirm.
popt, pcov = curve_fit(func_c,density['delta_N_s'],phase['dphi_dt'].to_numpy()[:-1],p0=[59])
print('Curve fit C:',popt)
# One-sigma uncertainties: square roots of the covariance diagonal.
print('Curve fit C error:',np.sqrt(np.diag(pcov)))
C = popt[0]


print(f'I_c = {I_c} and C = {C}')



fig, ax = pl.subplot_mosaic(
    [["a", "b"]],
    layout="constrained",
    figsize=(5 * 1.61, 3.5),
)

#### Current
# Panel "a": measured particle current dN/dt next to the prediction I_c sin(phi).
# Both are scaled by 1e-3 to match the 10^3/s axis label.

time_density = density.index.to_numpy()
I = density['dN_dt'].to_numpy() * 1e-3
d_I = density['dN_dt_err'].to_numpy() * 1e-3

color = RPTU_COLORS['pflaume']
ax["a"].plot(time_density, I, label='d$N$/d$t$', **get_style(color=color, errorbar=False))
ax["a"].fill_between(time_density, I - d_I, I + d_I, color=color, alpha=0.4)

# Prediction from the smoothed phase. The band is I_c * |cos(phi)| * phase_err,
# i.e. the phase error propagated through the sine.
time_phase = phase.index.to_numpy()
phi_s = phase['phase_s'].to_numpy()
I = I_c * np.sin(phi_s) * 1e-3
d_I = I_c * np.abs(np.cos(phi_s)) * phase['phase_err'].to_numpy() * 1e-3

color = RPTU_COLORS['himbeere']
ax["a"].plot(time_phase, I, label=r'$I_\mathrm{c} \sin(\varphi)$', **get_style(color=color, errorbar=False))
ax["a"].fill_between(time_phase, I - d_I, I + d_I, color=color, alpha=0.5, zorder=2)


## Phase
# Panel "b": chemical-potential difference, from d(phi)/dt and from -N/C.
# 2 pi converts it from angular frequencies to actual frequencies.
two_pi = 2*np.pi

time_phase = phase.index.to_numpy()
U = phase['dphi_dt'].to_numpy() / two_pi
d_U = phase['dphi_dt_err'].to_numpy() / two_pi

color = RPTU_COLORS['nacht']
ax["b"].plot(time_phase, U, label=r'd$\varphi$/d$t$', **get_style(color=color, errorbar=False))
ax["b"].fill_between(time_phase, U - d_U, U + d_U, color=color, alpha=0.4)

# Prediction from the smoothed imbalance and the fitted C.
time_density = density.index.to_numpy()
color = RPTU_COLORS['mango']
U = -density['delta_N_s'].to_numpy() / C / two_pi
d_U = density['delta_N_err'].to_numpy() / C / two_pi
ax["b"].plot(time_density, U, label=r'$-N / C$', **get_style(color=color, errorbar=False))
ax["b"].fill_between(time_density, U - d_U, U + d_U, color=color, alpha=0.6, zorder=2)



ax["a"].legend()
ax["a"].set_xlabel(r'Time [s]')
ax["a"].set_ylabel(r'$I$ [10³/s]')

ax["b"].legend()
ax["b"].set_xlabel(r'Time [s]')
ax["b"].set_ylabel(r'$\Delta \mu$ [Hz]')

# The figure was created with layout="constrained", so no pl.tight_layout() here.
# Calling it would warn and replace the constrained layout engine.
pl.show()