import sys
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from collections import Counter
import os
import copy
from matplotlib.patches import Patch
import seaborn as sns
import datetime


# 1) Difference between the fuzzy-logic (FIS) and nearest-neighbour (FNN) leader
#    assignments.
# Look for the files: every per-bird file name starts with a 3-digit bird ID.
os.chdir(" ")  # NOTE: placeholder -- set to the data directory before running
files = [f for f in os.listdir() if f[0:3].isnumeric()]  # list files
# Each bird has several matching files (FIS and FNN), so de-duplicate the IDs;
# otherwise every bird would be loaded and counted once per file.
individuals = sorted({b[0:3] for b in files})

similar_solutions = []  # (bird, fraction of shared timestamps where FIS and FNN agree)
datapoints = []         # (bird, number of rows in the FIS file)
LEAD = {}               # bird (int) -> Counter of FIS leader values
merged_frames = []      # per-bird merged frames, concatenated once after the loop
nn_offsets = ['Nearest_neigh_n', 'Nearest_neigh_e', 'Nearest_neigh_u']

# Open the fuzzy-logic and nearest-neighbour files of each bird and join them.
for ind in individuals:
    print(ind)
    df_ff = pd.read_table(f"{ind}_inwake_flying_09082019.txt")
    df_nn = pd.read_csv(f"{ind}_nearest_neighbour_09082019.csv")
    datapoints.append((ind, len(df_ff)))  # length of the FIS dataframe
    df_ff = df_ff.rename(columns={'timestamp': 'Timestamp'})
    df_ff['Timestamp'] = pd.to_datetime(df_ff['Timestamp'])
    df_nn['Timestamp'] = pd.to_datetime(df_nn['Timestamp'])
    df_nn = df_nn.dropna(how='any')

    # Recode the FNN leader: a distance below float max keeps the neighbour ID;
    # otherwise -2 when all three offsets are 0, else -1.
    has_neighbour = df_nn['Dist_nearest_neigh'] < sys.float_info.max
    zero_offset = (df_nn[nn_offsets] == 0).all(axis=1)
    df_nn['new_leader'] = np.select([has_neighbour, zero_offset],
                                    [df_nn['ID_nearest_neigh'], -2], default=-1)

    # Join the two leader assignments on their shared timestamps.
    leaders = pd.merge(df_ff[['Timestamp', 'leader']], df_nn[['Timestamp', 'new_leader']], on='Timestamp')
    leaders['similar'] = leaders['leader'] == leaders['new_leader']
    percentage_similar = leaders['similar'].sum() / len(leaders)  # fraction of agreeing timestamps
    similar_solutions.append((ind, percentage_similar))
    LEAD[int(ind)] = Counter(df_ff['leader'])  # count of each FIS leader value
    df_nn = pd.merge(df_nn, leaders[['Timestamp', 'similar']], on='Timestamp')
    merged = pd.merge(df_ff, df_nn[['Timestamp'] + nn_offsets + ['similar']], on=['Timestamp'])
    merged['bird'] = ind
    merged_frames.append(merged)

# DataFrame.append was removed in pandas 2.0 (and is quadratic in a loop);
# concatenate all birds once instead.
plotting = pd.concat(merged_frames) if merged_frames else pd.DataFrame()

# Turn the per-bird result tuples into dataframes and summarise them with describe().
similar_solutions = pd.DataFrame.from_records(similar_solutions, columns=['bird', 'similar_sol'])
similar_stats = similar_solutions.similar_sol.describe()

datapoints = pd.DataFrame.from_records(datapoints, columns=['bird', 'number'])
datapoints_stats = datapoints.number.describe()

# Prepare the dataframes for plotting. Rows whose FNN north offset is exactly 0
# are dropped from both; plotting_f additionally keeps only the timestamps where
# the FIS and FNN leaders disagree (similar == False).
plotting = plotting.reset_index(drop=True)
plotting_f = plotting.loc[plotting['similar'].eq(False)].reset_index(drop=True)
plotting_f = plotting_f.loc[plotting_f['Nearest_neigh_n'].ne(0)]
plotting = plotting.loc[plotting['Nearest_neigh_n'].ne(0)]

# Plot all solutions (top row) and only the divergent ones (bottom row) as
# histograms of the FIS (blue) and FNN (orange) relative positions.
fig2, axs = plt.subplots(2, 3, figsize=(15, 12))
# One entry per column: (FIS column, FNN column, bin range, x-axis label).
panel_specs = (
    ('x', 'Nearest_neigh_e', (-2, 2), 'e|w (m)'),
    ('y', 'Nearest_neigh_n', (-5, 0), 'n|s (m)'),
    ('z', 'Nearest_neigh_u', (-.75, .75), 'u|d (m)'),
)
for row_axes, row_data in zip(axs, (plotting, plotting_f)):
    for panel, (fis_col, fnn_col, limits, xlabel) in zip(row_axes, panel_specs):
        # Blue (FIS) is drawn before orange (FNN) on every panel.
        for column, colour in ((fis_col, 'blue'), (fnn_col, 'darkorange')):
            sns.histplot(ax=panel, x=column, data=row_data, color=colour, stat='frequency', alpha=.7,
                         bins=40, binrange=limits)
        panel.set_axisbelow(True)
        panel.grid(color='lightgray', linestyle='dashed')
        panel.set_xlabel(xlabel, fontsize='large', labelpad=4)
        panel.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
    # Only the left panel of each row carries a y label; the legend sits on the right one.
    row_axes[0].set_ylabel('Frequency', fontsize='large', labelpad=4)
    row_axes[1].set(ylabel=None)
    row_axes[2].set(ylabel=None)
    custom_legend = [Patch(facecolor='blue', edgecolor='blue', label='FIS', alpha=.7),
                     Patch(facecolor='darkorange', edgecolor='darkorange', label='FNN', alpha=.7)]
    row_axes[2].legend(handles=custom_legend, loc='upper right')

plt.subplots_adjust(top=0.95, bottom=0.07, left=0.05, right=0.95, hspace=0.2, wspace=0.2)


# For each bird, compute how its in-wake samples (leader codes -1 and -2 excluded)
# are shared among leaders, and the share of its most-followed leader.
following_values = []      # every (bird, leader) share, pooled over all birds
max_following_values = []  # share of the most-followed leader, one value per bird
for L, leader_counts in LEAD.items():
    print('bird is: ', L)
    in_wake = {k: v for k, v in leader_counts.items() if k not in (-1, -2)}
    calc_mainlead_sum = sum(in_wake.values())
    if calc_mainlead_sum == 0:
        # No in-wake samples: there is no leader share, and max() would fail on an empty dict.
        print('no in-wake samples for bird', L)
        continue
    freq_leader_dict = {k: v / calc_mainlead_sum for k, v in in_wake.items()}
    following_values.extend(freq_leader_dict.values())
    max_freq_leader = max(freq_leader_dict, key=freq_leader_dict.get)
    print('max leader is: ', max_freq_leader, freq_leader_dict[max_freq_leader])
    max_following_values.append(freq_leader_dict[max_freq_leader])

sd_follower, mean_follower = np.std(following_values), np.mean(following_values)
min_f, max_f, mean_f = min(max_following_values), max(max_following_values), np.mean(max_following_values)
# Open a new figure: plt.hist otherwise draws on the current axes, which is the
# last panel of the 2x3 histogram figure created with plt.subplots above.
plt.figure()
plt.hist(max_following_values, bins=20)

# Proportion of each bird's FIS samples with a solo leader code (-1 or -2).
alone = []
for L in LEAD:
    print('bird is: ', L)
    calc_alone = LEAD[L]
    calc_alone_sum = sum(calc_alone.values())
    freq_alone_dict = {k: v / calc_alone_sum for k, v in calc_alone.items()}
    # A code that never occurs contributes 0, so a bird without any -1/-2 sample
    # yields 0 instead of raising KeyError.
    prop_alone = freq_alone_dict.get(-1, 0) + freq_alone_dict.get(-2, 0)
    alone.append(prop_alone)
    print(prop_alone)

mean_alone = np.mean(alone)
std_alone = np.std(alone)


# Resampling of the proportion of in-wake flying behind a specific individual
# with bootstrapping.

# These lines open the file and reshape it to long format: one row per
# (timestamp, follower column) with the recorded leader as the value.
df = pd.read_table('diads_09082019.txt')
df = df.set_index('#timestamp')
df = df.stack().reset_index()
df.columns = ['timestamp', 'follower', 'leader']
df['timestamp'] = pd.to_datetime(df['timestamp'])

# Split each follower's series into bouts of constant leader.
df = df.set_index('timestamp')
duration_bouts = []  # (follower, leader, duration in s, number of samples in the bout)
for follower, foll in df.groupby('follower'):
    leader_series = foll['leader']
    lead = leader_series.iloc[0]
    start = leader_series.index[0]
    count = 1
    for i, l in leader_series.iloc[1:].items():
        if l == lead:
            count += 1
            continue
        # Duration runs from the bout's first sample to the first sample of the next bout.
        delta = (i - start).total_seconds()
        duration_bouts.append((follower, lead, delta, count))
        lead, start, count = l, i, 1
    # Close the final bout too, otherwise its samples are missing from the
    # per-leader counts; with no following bout its duration ends at the last sample.
    duration_bouts.append((follower, lead, (leader_series.index[-1] - start).total_seconds(), count))

duration_bouts = pd.DataFrame(duration_bouts, columns=['follower', 'leader', 'duration', 'counts'])

# Observed number of samples each follower spent behind each leader, summed over bouts.
original_counts = (
    duration_bouts
    .groupby(['follower', 'leader'], as_index=False)['counts']
    .sum()
)

# Bootstrap: for each follower, reassign its bouts to leaders drawn with replacement
# from the leaders it actually had, then sum the samples per leader. Repeat n_rep times.
n_rep = 1000
boot_seed = None  # set to an int to make the bootstrap reproducible
rng = np.random.default_rng(boot_seed)
boot_frames = []
for follower, foll in duration_bouts.groupby('follower'):
    print(follower)
    observed = foll[['follower', 'counts', 'duration']].reset_index(drop=True)
    leaders = pd.Series(foll.leader.unique())
    for _ in range(n_rep):
        rep = leaders.sample(n=len(foll), replace=True, ignore_index=True, random_state=rng)
        rep = pd.concat([observed, rep.rename('leader')], axis=1)
        boot_frames.append(rep.groupby(['follower', 'leader'])['counts'].sum().reset_index())

# DataFrame.append was removed in pandas 2.0 (and is quadratic in a loop);
# concatenate all replicates once instead.
if boot_frames:
    bootstrapped_counts = pd.concat(boot_frames)
else:
    bootstrapped_counts = pd.DataFrame(columns=['follower', 'leader', 'counts'])
# Drop the non-individual leader codes.
bootstrapped_counts = bootstrapped_counts[~bootstrapped_counts.leader.isin([-1, -2, -3])]

# 2.5 %, 50 % and 97.5 % order statistics of the bootstrapped counts for each
# (follower, leader) pair (stored as lim25 / lim50 / lim75).
lo_idx = max(int(0.025 * n_rep) - 1, 0)
mid_idx = max(int(0.5 * n_rep) - 1, 0)
hi_idx = max(int(0.975 * n_rep) - 1, 0)
to_plot = []
for (follower, leader), grp in bootstrapped_counts.groupby(['follower', 'leader']):
    counts = np.sort(grp['counts'].to_numpy())
    # A replicate in which this leader was never drawn has no row for it. Those
    # replicates are zero counts and must be included; otherwise the group has
    # fewer than n_rep values and the fixed ranks are biased or out of range.
    n_missing = n_rep - len(counts)
    if n_missing > 0:
        counts = np.concatenate([np.zeros(n_missing, dtype=counts.dtype), counts])
    to_plot.append((follower, leader, counts[lo_idx], counts[mid_idx], counts[hi_idx]))

to_plot = pd.DataFrame(to_plot, columns=['follower', 'leader', 'lim25', 'lim50', 'lim75'])


## PLOTTING RESULTS 5 INDIVIDUALS (main manuscript)
# One panel per sampled follower (up to 5). Green bars show the observed proportion
# of samples behind each leader. The black line and dot show the bootstrapped
# lim25-lim75 range and lim50. The star marks the leader with the largest observed share.
followers = pd.Series(df.follower.unique())
individuals_plot = followers.sample(n=min(5, len(followers)), random_state=123).tolist()
bootstrapping_plot = to_plot.loc[to_plot.follower.isin(individuals_plot)]
bootstrapping_plot = bootstrapping_plot.groupby('follower')
original_counts_plot = original_counts.loc[original_counts.follower.isin(individuals_plot)]
original_counts_plot = original_counts_plot[~original_counts_plot.leader.isin([-1, -2, -3])]
original_counts_plot = original_counts_plot.groupby('follower')

# Pair the two tables by follower key; zipping the two groupbys would silently
# misalign panels if a follower were missing from one of them.
boot_groups = dict(list(bootstrapping_plot))
ori_groups = dict(list(original_counts_plot))
shared_followers = [f for f in boot_groups if f in ori_groups]

plt.figure(figsize=(16, 5))
for i, follower in enumerate(shared_followers):
    print(follower, follower)
    # sort_values returns new frames, so the columns added below are not written
    # into a groupby slice (no chained-assignment warning).
    boot = boot_groups[follower].sort_values(by='leader')
    ori = ori_groups[follower].sort_values(by='leader')
    total = ori.counts.sum()
    boot['prop25'] = boot.lim25 / total
    boot['prop50'] = boot.lim50 / total
    boot['prop75'] = boot.lim75 / total
    ori['prop'] = ori.counts / total
    ori = ori.astype({'leader': str})
    boot = boot.astype({'leader': str})
    ax = plt.subplot(1, 5, i + 1)
    ax.set_axisbelow(True)
    ax.grid(color='lightgray', linestyle='dashed', axis='x')
    plt.hlines(y=boot.leader, xmin=boot.prop25, xmax=boot.prop75, color='black', alpha=.7, zorder=5)
    plt.barh(ori.leader, ori.prop, color='forestgreen', zorder=1)
    plt.scatter(y=boot.leader, x=boot.prop50, c='black', alpha=.7, s=5, zorder=10)
    # Leaders are string categories registered in sorted order, so the row position of
    # the maximum in `ori` is its y coordinate (assumes both tables hold the same leaders).
    max_value = ori.prop.max()
    index = list(ori['prop']).index(max_value)
    plt.text(max_value, index - 0.3, '★')
    plt.title(follower)
plt.tight_layout()

## PLOTTING RESULTS ALL INDIVIDUALS (supplementary)
# Same panel layout as the main-manuscript figure, for every follower.
bootstrapping = to_plot[~to_plot.leader.isin([-1, -2, -3])]
bootstrapping_plot = bootstrapping.groupby('follower')
original_counts = original_counts[~original_counts.leader.isin([-1, -2, -3])]
original_counts_plot = original_counts.groupby('follower')

# Pair the two tables by follower key; zipping the two groupbys would silently
# misalign panels if a follower were missing from one of them.
boot_groups = dict(list(bootstrapping_plot))
ori_groups = dict(list(original_counts_plot))
shared_followers = [f for f in boot_groups if f in ori_groups]
n_cols = 10
n_rows = max(3, -(-len(shared_followers) // n_cols))  # 3 x 10 grid, extended if more followers

plt.figure(figsize=(28.5/2.54, 20/2.54))
for i, follower in enumerate(shared_followers):
    print(follower, follower)
    # sort_values returns new frames, so the columns added below are not written
    # into a groupby slice (no chained-assignment warning).
    boot = boot_groups[follower].sort_values(by='leader')
    ori = ori_groups[follower].sort_values(by='leader')
    total = ori.counts.sum()
    boot['prop25'] = boot.lim25 / total
    boot['prop50'] = boot.lim50 / total
    boot['prop75'] = boot.lim75 / total
    ori['prop'] = ori.counts / total
    ori = ori.astype({'leader': str})
    boot = boot.astype({'leader': str})
    ax = plt.subplot(n_rows, n_cols, i + 1)
    ax.set_axisbelow(True)
    ax.grid(color='lightgray', linestyle='dashed', axis='x')
    plt.hlines(y=boot.leader, xmin=boot.prop25, xmax=boot.prop75, color='black', alpha=.7, zorder=5)
    plt.barh(ori.leader, ori.prop, color='forestgreen', alpha=.7, zorder=0)
    plt.scatter(y=boot.leader, x=boot.prop50, c='black', alpha=.7, s=5, zorder=10)
    # Leaders are string categories registered in sorted order, so the row position of
    # the maximum in `ori` is its y coordinate (assumes both tables hold the same leaders).
    max_value = ori.prop.max()
    index = list(ori['prop']).index(max_value)
    plt.yticks(fontsize=6)
    plt.xticks(fontsize=6)
    plt.xlim(0, max_value + 0.04)
    ax.xaxis.set_ticks_position('none')
    ax.tick_params(axis='x', which='major', pad=-1)
    plt.text(max_value, index - 0.3, '★')  # drawn once per panel
    plt.title(follower, fontsize=7, pad=0.5, fontweight="bold")
plt.subplots_adjust(top=0.985, bottom=0.015, left=0.04, right=0.985, hspace=0.1, wspace=0.3)


# Calculate bout lengths (s) from each FIS in-wake file. A bout is a run of
# constant leader; it counts as solo when its leader is -1 or -2, else as in-wake.
# NOTE(review): files are listed from ' ' but read from the hard-coded F:\ path
# below; both placeholders must point at the same directory.
files = os.listdir(' ')
files = [f for f in files if 'flying' in f]

length_bouts_solo = []
length_bouts_inwake = []

for f in files:
    print(f)
    df = pd.read_table(r'F:\\Data_collection_2019\\Elisa\\GNSS_data\\09082019\\analysis_IEEE_TFS\\files_analysis\\' + f)
    df['timestamp'] = pd.to_datetime(df['timestamp'])
    df = df.set_index('timestamp')
    leader = df['leader']
    # Positional access: Series[0] on a DatetimeIndex is deprecated positional fallback.
    lead = leader.iloc[0]
    start = leader.index[0]
    # Series.iteritems was removed in pandas 2.0; items() is the replacement.
    for i, l in leader.iloc[1:].items():
        if l == lead:
            continue
        duration = (i - start).total_seconds()
        # The bout that just ended belongs to `lead`; `l` is the leader of the next bout.
        if lead in [-1, -2]:
            length_bouts_solo.append(duration)
        else:
            length_bouts_inwake.append(duration)
        start = i
        lead = l
    # The trailing bout of each file is never closed, so it is not counted.
    print(len(length_bouts_inwake))
    print(len(length_bouts_solo))

# Print mean, standard deviation and median of the bout lengths (s) for solo
# bouts, in-wake bouts, and both pooled together.
bout_groups = (
    ('solo', length_bouts_solo),
    ('inwake', length_bouts_inwake),
    ('both', length_bouts_inwake + length_bouts_solo),
)
for title, lengths in bout_groups:
    print(title)
    for summary in (np.mean, np.std, np.median):
        print(summary(lengths))

