# -*- coding: utf-8 -*-
import astropy as ap
from astropy.time import Time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import math
import tkinter as tk
from tkinter import ttk
from tkinter import filedialog
import json
import platform

def ins_date_extr_index(dates, date):
    """Locate *date* in the ascending array *dates*, inserting it if absent.

    Parameters
    ----------
    dates : numpy.ndarray
        Epochs in ascending order (Julian days in this script).
    date : float
        Epoch to find or insert.

    Returns
    -------
    tuple
        ``(dates, index)``. If *date* is already present, *dates* is returned
        unchanged and *index* is its first exact match. Otherwise a new array
        with *date* inserted at its sorted position is returned together with
        that position.
    """
    dates = np.asarray(dates)
    matches = np.flatnonzero(dates == date)
    if matches.size:
        return dates, int(matches[0])
    # searchsorted keeps the array sorted in every case. That includes a date
    # later than every entry (the old linear scan fell back to index 0 there
    # and broke the ordering) and an empty array.
    index = int(np.searchsorted(dates, date))
    return np.insert(dates, index, date), index

def find_r_approach_point(r_sep, dist):
    """Return the unit vector from the origin to a straight-line closest approach.

    A body at *r_sep* flies in a straight line that passes at distance *dist*
    from the origin. The closest-approach point P satisfies |P| = dist and
    (r_sep - P) is perpendicular to P, so cos(alpha) = dist / |r_sep|, where
    alpha is the angle between P and r_sep.

    P lies in the plane spanned by r_sep and u_ort. u_ort is the unit vector
    perpendicular to r_sep in the r_sep/z-axis plane, with non-negative z
    component. It is the x axis when r_sep is parallel to z.

    Parameters
    ----------
    r_sep : numpy.ndarray, shape (3,)
        Starting position.
    dist : float
        Required closest-approach distance, same units as *r_sep*.

    Returns
    -------
    numpy.ndarray, shape (3,)
        Unit vector pointing at the closest-approach point.

    Raises
    ------
    ValueError
        If *r_sep* is the zero vector, or *dist* exceeds |r_sep|, since no
        straight line from *r_sep* can then have that miss distance.
    """
    r_norm = np.linalg.norm(r_sep)
    if r_norm == 0:
        raise ValueError("r_sep must be a non-zero vector")
    if dist > r_norm:
        raise ValueError(
            "closest-approach distance %g exceeds the distance of r_sep (%g)"
            % (dist, r_norm))

    u_r = r_sep / r_norm
    z_axis = np.array([0, 0, 1])
    if np.linalg.norm(np.cross(u_r, z_axis)) == 0:
        u_ort = np.array([1, 0, 0])
    else:
        u_ort = np.cross(u_r, np.cross(u_r, z_axis))
        if u_ort[2] < 0:
            u_ort = -u_ort
        u_ort = u_ort / np.linalg.norm(u_ort)

    theta = math.asin(dist / r_norm)
    phi = math.pi / 2 - theta
    # Components of P along r_sep (dist*cos(phi) = dist**2/|r_sep|) and along u_ort.
    d_par = dist * math.cos(phi)
    d_per = dist * math.sin(phi)
    u_ca = d_par * u_r + d_per * u_ort
    return u_ca / np.linalg.norm(u_ca)

def find_r_approach_dir(direction):
    """Return a unit vector perpendicular to *direction*, tilted towards +z.

    The result lies in the plane spanned by *direction* and the z axis, with a
    non-negative z component. If *direction* is parallel to z, that plane is
    undefined and the x axis is returned instead.
    """
    z_axis = np.array([0, 0, 1])
    u_dir = direction / np.linalg.norm(direction)
    across = np.cross(u_dir, z_axis)
    if np.linalg.norm(across) == 0:
        return np.array([1, 0, 0])
    normal = np.cross(u_dir, across)
    normal = normal / np.linalg.norm(normal)
    return -normal if normal[2] < 0 else normal

def rotate_around_axis(vector, axis, angle):
    """Rotate *vector* by *angle* radians about *axis* (right-hand rule).

    Uses Rodrigues' rotation formula. The component of *vector* along the axis
    is kept, and the perpendicular component is rotated in the plane
    orthogonal to the axis, so length and angle to the axis are preserved.

    Parameters
    ----------
    vector : numpy.ndarray, shape (3,)
    axis : numpy.ndarray, shape (3,)
        Rotation axis; its length is irrelevant.
    angle : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray, shape (3,)
        The rotated vector. *vector* itself is returned unchanged when it is
        parallel to *axis*, or when either is the zero vector.
    """
    if np.linalg.norm(np.cross(vector, axis)) == 0:
        # Parallel or zero vectors are invariant under the rotation. This
        # also avoids dividing by a zero-length axis below.
        return vector
    u_axis = axis / np.linalg.norm(axis)
    # Project onto the axis itself. The old code scaled *vector* by the cosine
    # instead, which is only correct when vector is perpendicular to the axis.
    v_par = np.dot(vector, u_axis) * u_axis
    v_perp = vector - v_par
    w = np.cross(u_axis, v_perp)
    return v_par + (np.cos(angle) * v_perp + np.sin(angle) * w)

def search_for_file_path(tkwindow, text):
    """Ask the user to choose a folder through a Tk directory dialog.

    The dialog is parented to *tkwindow*, opens in the current working
    directory and uses *text* as its title. A non-empty choice is echoed to
    stdout.

    Returns whatever ``filedialog.askdirectory`` returns (typically an empty
    string when the dialog is cancelled).
    """
    chosen_dir = filedialog.askdirectory(parent=tkwindow,
                                         initialdir=os.getcwd(),
                                         title=text)
    if chosen_dir:
        print("You chose: %s" % chosen_dir)
    return chosen_dir

def propagate(r_ta, ta, v, dates):
    """Evaluate a constant-velocity (straight-line) trajectory at *dates*.

    Parameters
    ----------
    r_ta : numpy.ndarray, shape (3,)
        Position at the reference epoch *ta*.
    ta : float
        Reference epoch, in days (same scale as *dates*).
    v : numpy.ndarray, shape (3,)
        Constant velocity, in position units per second.
    dates : array_like
        Epochs, in days, at which to evaluate the state.

    Returns
    -------
    numpy.ndarray, shape (len(dates), 7)
        Column 0 is left at zero. Columns 1-3 hold the position and columns
        4-6 the velocity.
    """
    # Days -> seconds, evaluated in the same operation order as a per-epoch loop.
    elapsed_s = (np.asarray(dates, dtype=float) - ta) * 24 * 3600
    states = np.zeros([len(elapsed_s), 7])
    states[:, 4:7] = v
    states[:, 1:4] = r_ta + elapsed_s[:, np.newaxis] * np.asarray(v)
    return states

# Trajectory is defined in the Tuttle ion-tail frame, centred on comet Tuttle.
# Positions are in km, velocities in km/s and angles in radians. Dates are
# ISO-8601 strings (UTC).
##############################################################################
################################ INPUTS ######################################
##############################################################################
relative_velocity = 10        # Speed of spacecraft A relative to Tuttle, km/s

# Simulated interval: the hourly time grid starts at the first date and stops
# before the second.
date_range = ['2035-03-20T00:00:00.0000', '2035-07-05T00:00:00.0000']

# Epochs at which B1 and B2 leave spacecraft A (before these, they follow A).
B1_separation_date = '2035-03-25T03:32:32.0000'
B2_separation_date = '2035-03-25T04:34:32.0000'

# Epochs at which each spacecraft reaches its closest-approach point.
Closest_approach_A_date = '2035-03-27T01:00:00.0000'
Closest_approach_B1_date = '2035-03-27T00:30:00.0000'
Closest_approach_B2_date = '2035-03-27T00:30:00.0000'

Closest_approach_A_v_dir = np.array([1,0, 0])  # Direction of A's velocity; its magnitude is irrelevant because it is normalized
Closest_approach_A_angle = 0*math.pi/6 #rad; rotation of A's closest-approach point about its velocity direction

# Closest-approach (miss) distances from the centre of Tuttle.
Closest_approach_A_dist = 1500   # km
Closest_approach_B1_dist = 850  # km
Closest_approach_B2_dist = 550  # km

# Rotation of each B closest-approach point about the direction of that
# spacecraft's separation position.
Closest_approach_B1_angle = 9*math.pi/6 # rad
Closest_approach_B2_angle = 3*math.pi/6  # rad

# Instrument observation windows [start, end]. They are written into
# ../cosmo/obs_interceptor.json (items 0-8, in this order).
COCA_OBS_range = pd.Series(["2035-03-26T23:40:00.000 UTC", "2035-03-26T23:45:00.000 UTC"])
MIRMIS_NIR_OBS_range = pd.Series(["2035-03-26T23:45:00.000 UTC", "2035-03-26T23:50:00.000 UTC"])
MIRMIS_MIR_OBS_range = pd.Series(["2035-03-26T23:50:00.000 UTC", "2035-03-26T23:55:00.000 UTC"])
MIRMIS_TIRI_OBS_range = pd.Series(["2035-03-26T23:55:00.000 UTC", "2035-03-27T00:00:00.000 UTC"])
HI_OBS_range = pd.Series(["2035-03-26T23:40:00.000 UTC", "2035-03-26T23:45:00.000 UTC"])
NAC_OBS_range = pd.Series(["2035-03-26T23:45:00.000 UTC", "2035-03-26T23:50:00.000 UTC"])
WAC_OBS_range = pd.Series(["2035-03-26T23:50:00.000 UTC", "2035-03-26T23:55:00.000 UTC"])
OPIC_OBS_range = pd.Series(["2035-03-26T23:40:00.000 UTC", "2035-03-26T23:45:00.000 UTC"])
ENVISS_OBS_range = pd.Series(["2035-03-26T04:00:00.000 UTC", "2035-03-26T23:39:39.000 UTC"])

##############################################################################
##############################################################################
##############################################################################


# LOAD COSMOGRAPHIA PATH
# The Cosmographia installation folder is cached in cosmopath.txt, so the user
# is only asked for it once.
if os.path.exists("cosmopath.txt"):
    with open("cosmopath.txt", mode="r") as f:
        # strip() tolerates a trailing newline added when editing the file by hand.
        pathcosmo = f.read().strip()
else:
    window = tk.Tk()
    pathcosmo = search_for_file_path(window, 'Please select the installation folder of Cosmographia')
    # The Tk root only existed to host the dialog; release it.
    window.destroy()
    if pathcosmo:
        with open("cosmopath.txt", mode="w") as f:
            f.write(pathcosmo)
    else:
        # Dialog cancelled: do not cache an empty path, so the next run asks again.
        print("No Cosmographia folder selected; cosmopath.txt was not written")
        
# EDIT OBSERVATION PERIOD OF INSTRUMENTS
misccosmo = "../cosmo/obs_interceptor.json"
with open(misccosmo, 'r') as obs_dates:
    obs_json = json.load(obs_dates)

# Observation windows, in the order of the sensor items in obs_interceptor.json.
# Item i gets obs_windows[i] as the start/end time of its first geometry group.
obs_windows = [COCA_OBS_range, MIRMIS_NIR_OBS_range, MIRMIS_MIR_OBS_range,
               MIRMIS_TIRI_OBS_range, HI_OBS_range, NAC_OBS_range,
               WAC_OBS_range, OPIC_OBS_range, ENVISS_OBS_range]
for item_index, obs_range in enumerate(obs_windows):
    # Index explicitly (rather than zip) so a JSON file with too few items
    # still fails loudly instead of being silently truncated.
    obs_group = obs_json["items"][item_index]["geometry"]["groups"][0]
    obs_group["startTime"] = obs_range[0]
    obs_group["endTime"] = obs_range[1]

with open(misccosmo, 'w') as obs_dates:
    json.dump(obs_json, obs_dates, indent = 3)

### GENERATE TRAJECTORY

X = np.array([1, 0, 0])
Y = np.array([0, 1, 0])
Z = np.array([0, 0, 1])


def _isot_to_jd(isot_value, **time_kwargs):
    """Parse ISO-8601 ('isot') date(s) into an astropy Time displayed as Julian Day."""
    as_time = ap.time.Time(isot_value, format='isot', **time_kwargs)
    as_time.format = 'jd'
    return as_time


# PUT DATES IN JULIAN DAY FORMAT
jd_range = _isot_to_jd(date_range, scale='utc')
jd_separation_b1 = _isot_to_jd(B1_separation_date)
jd_separation_b2 = _isot_to_jd(B2_separation_date)
jd_approach_a = _isot_to_jd(Closest_approach_A_date)
jd_approach_b1 = _isot_to_jd(Closest_approach_B1_date)
jd_approach_b2 = _isot_to_jd(Closest_approach_B2_date)

# CREATE DATES VECTOR WITH A TIME STEP OF 1 HOUR (1/24 day); the end date is excluded
dates = np.arange(jd_range[0].value, jd_range[1].value, 1/24)

# INSERT THE SEPARATION AND APPROACH DATES IN THE DATE VECTOR.
# Each insertion can shift the index of an epoch handled earlier in the same
# sweep. A second sweep inserts nothing (every epoch is already present) and
# only refreshes the indices.
key_epochs = (jd_separation_b1, jd_separation_b2, jd_approach_a,
              jd_approach_b1, jd_approach_b2)
for _sweep in range(2):
    key_indices = []
    for epoch in key_epochs:
        dates, epoch_index = ins_date_extr_index(dates, epoch.value)
        key_indices.append(epoch_index)
index_b1_sep, index_b2_sep, index_a_ap, index_b1_ap, index_b2_ap = key_indices

# CALCULATE CLOSEST APPROACH POSITION OF SPACECRAFT A
# u_caA: unit vector from Tuttle to A's closest-approach point. It starts
# perpendicular to A's velocity (in the plane of the velocity and the z axis)
# and is then rotated about the velocity direction by Closest_approach_A_angle.
u_caA = find_r_approach_dir(Closest_approach_A_v_dir)
u_caA = rotate_around_axis(u_caA, Closest_approach_A_v_dir, Closest_approach_A_angle)
# Constant velocity of A (km/s): the requested direction scaled to relative_velocity.
v = Closest_approach_A_v_dir*relative_velocity/np.linalg.norm(Closest_approach_A_v_dir)

# CALCULATE TRAJECTORY OF SPACECRAFT A
# Straight line passing through u_caA*Closest_approach_A_dist at A's closest-approach epoch.
states_A = propagate(u_caA*Closest_approach_A_dist, jd_approach_a.value, v, dates)

# CALCULATE TRAJECTORY OF B1 AND B2
# Both start as copies of A's trajectory; the rows from separation onward are
# overwritten below.
states_B1 = np.empty_like(states_A)
states_B2 = np.empty_like(states_A)

states_B1[:] = states_A
states_B2[:] = states_A

# A's position at each separation epoch.
# NOTE: these are views into states_B1/states_B2, not copies; they are only
# read before the integration loops below modify those rows.
B1_separation_r = states_B1[index_b1_sep, 1:4]
B2_separation_r = states_B2[index_b2_sep, 1:4]

u_B1_separation_r = B1_separation_r/np.linalg.norm(B1_separation_r)
u_B2_separation_r = B2_separation_r/np.linalg.norm(B2_separation_r)

# Unit vectors to each B closest-approach point. First pick the point where a
# straight line from the separation position passes at the requested distance.
# Then rotate it about the separation direction by the requested angle.
# NOTE(review): the final miss distance equals Closest_approach_B*_dist only if
# the rotation preserves the vector's length and its angle to the axis.
u_r_B1_approach = find_r_approach_point(B1_separation_r, Closest_approach_B1_dist)
u_r_B2_approach = find_r_approach_point(B2_separation_r, Closest_approach_B2_dist)
u_r_B1_approach = rotate_around_axis(u_r_B1_approach, u_B1_separation_r, Closest_approach_B1_angle)
u_r_B2_approach = rotate_around_axis(u_r_B2_approach, u_B2_separation_r, Closest_approach_B2_angle)

# Constant velocity (km/s) that carries each spacecraft from A's position at
# separation to its closest-approach point at its closest-approach epoch
# (day differences converted to seconds).
v_B1 = (u_r_B1_approach*Closest_approach_B1_dist - states_A[index_b1_sep, 1:4])/((dates[index_b1_ap] - dates[index_b1_sep])*24*3600)
v_B2 = (u_r_B2_approach*Closest_approach_B2_dist - states_A[index_b2_sep, 1:4])/((dates[index_b2_ap] - dates[index_b2_sep])*24*3600)

states_B1[index_b1_sep, 4:7] = v_B1
states_B2[index_b2_sep, 4:7] = v_B2

# Step each row forward from the previous one. At the separation row, row i-1
# still holds A's state, so the position stays on A's trajectory while the
# velocity switches to the B velocity; later rows follow the B velocity.
# NOTE(review): a separation index of 0 would make i-1 == -1 read the last row.
for i in range(index_b1_sep, len(dates)):
    states_B1[i, 4:7] = v_B1
    states_B1[i, 1:4] = states_B1[i-1, 1:4] + states_B1[i-1, 4:7]*(dates[i] - dates[i-1])*24*3600

for i in range(index_b2_sep, len(dates)):
    states_B2[i, 4:7] = v_B2
    states_B2[i, 1:4] = states_B2[i-1, 1:4] + states_B2[i-1, 4:7]*(dates[i] - dates[i-1])*24*3600

# Pre-separation segments: rows up to and including the separation row. The
# last row gets the preceding row's velocity, so the segment is pure A motion.
states_B1_nosep = states_B1[0:index_b1_sep+1, :].copy()
states_B2_nosep = states_B2[0:index_b2_sep+1, :].copy()
states_B1_nosep[index_b1_sep, 4:7] = states_B1_nosep[index_b1_sep-1, 4:7].copy()
states_B2_nosep[index_b2_sep, 4:7] = states_B2_nosep[index_b2_sep-1, 4:7].copy()

# Post-separation segments: from the separation row (shared with the
# pre-separation segment) to the end of the time grid.
states_B1_sep = states_B1[index_b1_sep:len(dates)+1, :].copy()
states_B2_sep = states_B2[index_b2_sep:len(dates)+1, :].copy()



# PLOT THE TRAJECTORY IN PYTHON
# Ellipsoid wireframe with semi-axes 60, 40 and 30 km, drawn as Tuttle.
ucomet, vcomet = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
sin_v = np.sin(vcomet)
xcomet = 60*np.cos(ucomet)*sin_v
ycomet = 40*np.sin(ucomet)*sin_v
zcomet = 30*np.cos(vcomet)

# Zoom window around the closest approaches: the extreme approach indices,
# padded by 0.05 % of the samples on each side and clamped to the array.
n_dates = len(dates)
approach_indices = (index_a_ap, index_b1_ap, index_b2_ap)
uplim = max(approach_indices)/n_dates
downlim = min(approach_indices)/n_dates
uplim = math.ceil((uplim + 0.0005)*n_dates) if uplim < 0.9995 else n_dates
downlim = math.floor((downlim - 0.0005)*n_dates) if downlim > 0.0005 else 0

# Symmetric limits; invisible points at +/-limit on every axis make the
# autoscaled axes span at least that range in all three directions.
trajectories = (states_A, states_B1, states_B2)
wl = np.max(np.array([np.max(abs(track)) for track in trajectories]))
widelims = np.array([-wl, wl])
nl = np.max(np.array([np.max(abs(track[downlim:uplim, :])) for track in trajectories]))
narrowlims = np.array([-nl, nl])


def _draw_scene(axes, rows=slice(None)):
    """Draw the SCA/SCB1/SCB2 tracks (only *rows*) and the Tuttle wireframe."""
    for track, colour, name in ((states_A, 'blue', 'SCA'),
                                (states_B1, 'red', 'SCB1'),
                                (states_B2, 'orange', 'SCB2')):
        axes.plot3D(track[rows, 1], track[rows, 2], track[rows, 3], colour, label = name)
    axes.plot_wireframe(xcomet, ycomet, zcomet, color = "black", label = 'TUTTLE')


def _finish_axes(axes, figure, title):
    """Label the axes in km, title the figure and add the legend."""
    axes.set_xlabel('X (Km)')
    axes.set_ylabel('Y (Km)')
    axes.set_zlabel('Z (Km)')
    figure.suptitle(title)
    axes.legend()


fig = plt.figure()
ax = plt.axes(projection='3d')
_draw_scene(ax)
ax.scatter(widelims, widelims, widelims, alpha = 0)
_finish_axes(ax, fig, "Complete trajectory")

fig2 = plt.figure()
ax2 = plt.axes(projection='3d')
_draw_scene(ax2)
_finish_axes(ax2, fig2, "Complete trajectory with distorted axis")

fig3 = plt.figure()
ax3 = plt.axes(projection='3d')
_draw_scene(ax3, slice(downlim, uplim))
# Arrows from Tuttle to each closest-approach point (B1, B2, then A).
for arrow_dir, arrow_len in ((u_r_B1_approach, Closest_approach_B1_dist),
                             (u_r_B2_approach, Closest_approach_B2_dist),
                             (u_caA, Closest_approach_A_dist)):
    ax3.quiver(0, 0, 0, arrow_dir[0]*arrow_len, arrow_dir[1]*arrow_len, arrow_dir[2]*arrow_len)
ax3.scatter(narrowlims, narrowlims, narrowlims, alpha = 0)
_finish_axes(ax3, fig3, "Detail of the closest approach")
plt.show()

# CREATE INPUT FILES FOR MKSPK
# One row per epoch: ISO date, position X/Y/Z (km), velocity vx/vy/vz (km/s).
state_columns = ['Date', 'X', 'Y', 'Z', 'vx', 'vy', 'vz']

labels = ap.time.Time(dates, format='jd')
labels.format = 'isot'

states_A = pd.DataFrame(states_A, columns=state_columns)
states_A['Date'] = labels.value

# Each B spacecraft gets two tables that share the separation epoch: rows up
# to and including it (nosep), and rows from it onward (sep).
states_B1_nosep = pd.DataFrame(states_B1_nosep, columns=state_columns)
states_B1_nosep['Date'] = labels[0: index_b1_sep+1].value
states_B1_sep = pd.DataFrame(states_B1_sep, columns=state_columns)
states_B1_sep['Date'] = labels[index_b1_sep: len(dates)+1].value

states_B2_nosep = pd.DataFrame(states_B2_nosep, columns=state_columns)
states_B2_nosep['Date'] = labels[0: index_b2_sep+1].value
states_B2_sep = pd.DataFrame(states_B2_sep, columns=state_columns)
states_B2_sep['Date'] = labels[index_b2_sep: len(dates)+1].value

#%%
# Delete MKSPK inputs and SPK kernels left over from a previous run
# (presumably so that mkspk can write fresh kernels).
for stale_path in ("input_A.txt",
                   "../../kernels/spk/CI_SCA_parametric.bsp",
                   "input_B1_sep.txt",
                   "input_B1_nosep.txt",
                   "../../kernels/spk/CI_SCB1_parametric.bsp",
                   "input_B2_sep.txt",
                   "input_B2_nosep.txt",
                   "../../kernels/spk/CI_SCB2_parametric.bsp"):
    if os.path.exists(stale_path):
        os.remove(stale_path)
#%%
# Write one MKSPK input file per trajectory segment (no header, no index).
# An explicit float_format keeps full double precision. pandas' default float
# formatting is limited by the display.precision option (6 digits), which
# rounds large coordinates.
mkspk_inputs = [
    ('input_A.txt', states_A),
    ('input_B1_nosep.txt', states_B1_nosep),
    ('input_B1_sep.txt', states_B1_sep),
    ('input_B2_nosep.txt', states_B2_nosep),
    ('input_B2_sep.txt', states_B2_sep),
]
for file_name, table in mkspk_inputs:
    print("Creating " + file_name)
    with open(file_name, 'w') as tfile:
        tfile.write(table.to_string(header = False, index = False,
                                    float_format = '{:.16e}'.format))
        tfile.write("\n")

# CALL MKSPK, CREATE SPK FILES AND LAUNCH COSMOGRAPHIA
import shlex
import subprocess

print("Creating spks")
# NOTE(review): every platform now runs the five setup files the Windows branch
# used. The macOS/Linux branches previously ran only setup_A/B1/B2.txt and so
# skipped setup_B1_sep.txt / setup_B2_sep.txt — confirm this is intended.
mkspk_setups = ['setup_A.txt', 'setup_B1.txt', 'setup_B1_sep.txt',
                'setup_B2.txt', 'setup_B2_sep.txt']
cosmo_catalog = '../cosmo/load_interceptor_parametric.json'
# Start-up view passed to Cosmographia via -u; jd is 0.1 day after the start of date_range.
cosmo_url = ('cosmo:COMET-INTERCEPTOR?frame=bfix&jd=' + str(jd_range[0].value + 0.1)
             + '&x=-0.000705&y=0.004490&z=-0.017913&qw=0.065055&qx=-0.645594'
             + '&qy=0.753558&qz=0.105481&ts=0&fov=50')

system_name = platform.system()
if system_name == 'Windows':
    print("Windows machine")
    cosmo_cmd = [os.path.join(pathcosmo, 'Cosmographia.exe'), cosmo_catalog, '-u', cosmo_url]
elif system_name == 'Darwin':
    # platform.system() reports macOS as 'Darwin'; the old 'MAC' test never matched.
    print("MAC machine")
    # `open -a` launches the app bundle; arguments after --args are passed to the app.
    cosmo_cmd = ['open', '-a', os.path.join(pathcosmo, 'Cosmographia.app'),
                 '--args', cosmo_catalog, '-u', cosmo_url]
elif system_name == 'Linux':
    print("Linux machine")
    # pathcosmo is usually absolute (folder dialog), so it must not get a './' prefix.
    cosmo_cmd = [os.path.join(pathcosmo, 'cosmographia.sh'), cosmo_catalog, '-u', cosmo_url]
else:
    print("OS not recognized")
    cosmo_cmd = None

if cosmo_cmd is not None:
    for setup_file in mkspk_setups:
        os.system('mkspk -setup ' + setup_file)
    if system_name == 'Windows':
        # An argument list bypasses cmd.exe quoting, so install paths with spaces work.
        try:
            subprocess.run(cosmo_cmd)
        except OSError as err:
            print("Could not launch Cosmographia: %s" % err)
    else:
        # Quote every argument for the POSIX shell so paths with spaces survive.
        os.system(' '.join(shlex.quote(arg) for arg in cosmo_cmd))