import sys

import simpful as sf
from numpy import zeros, array, argmin, argmax, savetxt, hstack, newaxis, linspace, meshgrid, insert, stack
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import seaborn as sns
import networkx as nx
import os
from datetime import datetime
import numpy as np

# Global seaborn configuration, applied as a side effect of importing this
# module: font scale 1.2 and the "white" style.
sns.set(font_scale=1.2)
sns.set_style("white")


class SnapshotError(Exception):
    """Signals that a snapshot cannot be used and has to be dropped by the caller."""


class Arrow3D(FancyArrowPatch):
    """
    Arrow patch whose two end points are given in 3D data coordinates.

    The end points are projected to 2D with proj3d.proj_transform whenever
    the patch is drawn or its 3D projection is requested.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D end points are placeholders: they are overwritten at projection time.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _set_projected_positions(self, matrix):
        """Project the stored 3D end points with 'matrix', move the arrow there, return the projected z values."""
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, matrix)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def draw(self, renderer):
        # Uses the projection matrix carried by the renderer.
        self._set_projected_positions(renderer.M)
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # Uses the projection matrix of the owning axes; the smallest
        # projected z value is returned.
        return np.min(self._set_projected_positions(self.axes.M))


class InwakeDetector:

    def __init__(self):
        """
        Create the fuzzy reasoner, the shared figure and the (initially empty)
        containers that import_data() fills and the export_* methods read.
        """

        # model implementing FIS
        self._reasoner = self._create_model()

        # working data structures
        self._all_ids = []  # list of names
        self._corrs = None  # dict <name:integer> (inverse mapping for self._all_ids)
        self._header = None  # column names of the dataset (timestamp column excluded)

        # (SxB) matrix storing the information about which bird
        # is providing upwash to each ibis in each snapshot,
        # where S = snapshots, B = birds
        # (negative entries are status codes written by import_data)
        self._relationship_data = []

        # per snapshot, the strongest in-wake firing strength of each ibis
        self._strength_data = []

        # save timestamps for output files
        self._timestamp_data = []

        # (SxB) matrix storing the information about the relative position of the bird
        # that is providing upwash to each ibis in each snapshot,
        # where S = snapshots, B = birds
        self._strongest_positions = []

        # base plot (cleared and re-populated by _clear_figure for every snapshot)
        self._figure = plt.figure(figsize=(16, 14), dpi=300)

        # last valid (east, north, vertical) triple of each pair;
        # used by _process_shapshot when fix_errors is enabled
        self._backup = {}

        # in case we have gaps in the snapshots:
        # per snapshot, the list of ibises whose data is entirely missing
        self._missing_ibises_list = []

        # information about "grouping of inwake birds"
        # (per snapshot, the weakly connected components of the in-wake graph)
        self._info_grouping = []

    def _clear_figure(self):
        """ Helper method: wipes the shared figure and recreates its four 3D panels before a snapshot is drawn. """
        figure = self._figure
        figure.clf()
        # (attribute name, subplot position), created in exactly this order
        layout = (("_ax", 224), ("_top", 223), ("_back", 221), ("_side", 222))
        for attribute, position in layout:
            setattr(self, attribute, figure.add_subplot(position, projection='3d'))

    def plot_surfaces(self, detail=50, colormap="hot", filename="surface.pdf"):
        """
            Plots the firing strength of the in-wake rule over three 2D slices
            of the relative-position space and saves the figure.
            Panels: (a) e|w against n|s, evaluated with vertical = 0;
                    (b) e|w against u|d, evaluated with north = -1;
                    (c) n|s against u|d, evaluated with east = 1.
            Arguments:  'detail' is the number of samples taken along each axis
                        (each panel requires detail^2 evaluations).
                        'colormap' is the colormap used to colour the samples.
                        'filename' is the path where the figure is saved.
        """

        fig = plt.figure(figsize=(15, 5))
        panels = [fig.add_subplot(position) for position in (131, 132, 133)]
        for panel, title in zip(panels, ("(a)", "(b)", "(c)")):
            panel.set_title(title)

        lateral = linspace(-3, 3, detail)  # east/west
        frontal = linspace(-7, 0, detail)  # north/south
        updown = linspace(-3, 3, detail)  # up/down

        # For every panel: axis labels, the two sampled axes (outer loop first)
        # and the mapping of a sample onto _assess_inwake(north, east, vertical).
        slices = (
            ("e|w", "n|s", lateral, frontal, lambda x, y: self._assess_inwake(y, x, 0)),
            ("e|w", "u|d", lateral, updown, lambda x, z: self._assess_inwake(-1, x, z)),
            ("n|s", "u|d", frontal, updown, lambda y, z: self._assess_inwake(y, 1, z)),
        )

        for panel, (xlabel, ylabel, outer, inner, strength) in zip(panels, slices):
            panel.set_xlabel(xlabel)
            panel.set_ylabel(ylabel)
            samples = [(p, q, strength(p, q)) for p in outer for q in inner]
            horizontal = [sample[0] for sample in samples]
            vertical = [sample[1] for sample in samples]
            colours = [sample[2] for sample in samples]
            panel.scatter(horizontal, vertical, c=colours, cmap=colormap, s=2, marker="s")

        fig.tight_layout()
        fig.savefig(filename)

    def _create_model(self):

        """
            Assembles and returns the FIS used to evaluate membership to in-wake flight.
            Three linguistic variables (bird_ew, bird_ns, bird_plane) feed five rules whose
            crisp outputs are in_wake = 1 and not_in_wake = 0.
            Side effect: the fuzzy sets are plotted to "fuzzysets.pdf".
        """

        fis = sf.FuzzySystem()

        # east-west offset
        aligned = sf.FuzzySet(points=[[-1.8, 0], [-1.3, 1.], [-.8, 0], [.8, 0], [+1.3, 1.], [+1.8, 0]],
                              term="wing_tip_aligned")
        misaligned = sf.FuzzySet(points=[[-1.8, 1], [-1.3, 0.], [-.8, 1.], [.8, 1], [+1.3, 0], [+1.8, 1.]],
                                 term="wing_tip_misaligned")
        fis.add_linguistic_variable("bird_ew", sf.LinguisticVariable([aligned, misaligned]))

        # north-south offset
        too_close = sf.FuzzySet(points=[[-.1, 0], [0., 1]], term="too_close")
        close = sf.FuzzySet(points=[[-5., 0], [-.1, 1.], [0., 0]], term="close")
        distant = sf.FuzzySet(points=[[-5., 1.], [-.1, 0]], term="distant")
        fis.add_linguistic_variable("bird_ns", sf.LinguisticVariable([close, too_close, distant]))

        # vertical offset
        same_plane = sf.FuzzySet(points=[[-0.75, 0], [0, 1.], [0.75, 0]], term="same_plane")
        different_plane = sf.FuzzySet(points=[[-0.75, 1.], [0, 0], [0.75, 1.]], term="different_plane")
        fis.add_linguistic_variable("bird_plane", sf.LinguisticVariable([same_plane, different_plane]))

        # one rule for in-wake flight, one "veto" rule per violated condition
        rule_base = [
            "IF ((bird_ew IS wing_tip_aligned) AND (bird_ns IS close) AND (bird_plane IS same_plane)) THEN flying IS in_wake",
            "IF bird_ew IS wing_tip_misaligned THEN flying IS not_in_wake",
            "IF bird_ns IS too_close THEN flying IS not_in_wake",
            "IF bird_ns IS distant THEN flying IS not_in_wake",
            "IF bird_plane IS different_plane THEN flying IS not_in_wake",
        ]

        fis.set_crisp_output_value('in_wake', 1)
        fis.set_crisp_output_value('not_in_wake', 0)
        fis.add_rules(rule_base)

        fis.produce_figure("fuzzysets.pdf")

        return fis

    def _get_all_IDs(self, listids):
        """
            Returns the list of all distinct IDs used in the header columns.
            Each element of 'listids' is expected to look like "<id1>_<id2><c>", where <c>
            is a single trailing character that is discarded.
            The IDs are returned in order of first appearance, so that the result
            (and therefore the column order of every exported file) is the same
            on every run instead of depending on the iteration order of a set.
        """
        ordered = {}  # used as an insertion-ordered set
        for element in listids:
            separator = element.find('_')
            ordered.setdefault(element[:separator])
            ordered.setdefault(element[separator + 1:-1])
        return list(ordered)

    def _process_shapshot(self, line, verbose=False, fix_errors=False, forbidden_intervals=None):
        """
            This method processes a snapshot (i.e., the 'line' string coming from dataset).
            Arguments:
                - 'line': comma-separated string; the first field is the timestamp, followed
                  by one (east, north, vertical) triple per pair listed in the header;
                - 'verbose': print details about every pair;
                - 'fix_errors': if True, a triple that cannot be parsed is replaced by the last
                  valid triple seen for the same pair (stored in self._backup). When no previous
                  triple exists, the pair is recorded as missing, as with fix_errors=False;
                - 'forbidden_intervals': optional array whose two columns are the begin and end
                  timestamps of the intervals to be skipped.
            Returns:
                - a total_ibides*total_ibides numpy matrix of pair-wise firing strengths
                of the "in-wake flight" fuzzy rule (-1 marks a missing pair);
                - the list of ordered bird pairs that are present in the dataset file
                (calculated according to the 'header' string);
                - the numeric values of the line as a numpy array (None if any value is not a number);
                - all relative positions of Ibides;
                - the list of Ibides whose data is entirely missing in this snapshot.
            Raises:
                SnapshotError if the timestamp falls inside a forbidden interval.
            Side effects: appends to self._timestamp_data and self._missing_ibises_list.
        """

        total_ibides = len(self._all_ids)

        # pair-wise firing strengths
        matrix = zeros((total_ibides, total_ibides))

        # flags the pairs whose triple could not be read in this snapshot
        matrix_missing = zeros((total_ibides, total_ibides))

        # pair-wise relative positions (an empty list means "no information")
        relposition = [[[] for _ in range(total_ibides)] for _ in range(total_ibides)]

        # split line into timestamp and rest of data
        split_line = line.strip().split(",")
        timestamp = split_line[0]
        print("Timestamp:", timestamp)

        if forbidden_intervals is not None:
            if self.is_circling(timestamp, forbidden_intervals.T[0], forbidden_intervals.T[1]):
                print("Snasphot %s in forbidden interval, skipping" % timestamp)
                raise SnapshotError("Forbidden snapshot (flying in circle)")

        split_line = split_line[1:]

        pairs = []

        # needed later for plotting; None when at least one value is not numeric
        try:
            splitline = array(list(map(float, split_line)))
        except ValueError:
            splitline = None

        for i in range(len(self._header[::3])):

            # extract and store ibis names (the trailing character of the column name is dropped)
            ibis1, ibis2 = self._header[i * 3][:-1].split("_")
            pairs.append([ibis1, ibis2])
            idx1 = self._corrs[ibis1]
            idx2 = self._corrs[ibis2]

            # extract relative position; float() and the unpacking both raise
            # ValueError when values are empty, malformed or fewer than three
            try:
                east, north, vertical = map(float, split_line[i * 3:i * 3 + 3])
            except ValueError:
                previous = None
                if fix_errors:
                    # use the previous data (if any) to fill the gap
                    print("ERROR: cannot split (%s, %s)" % (ibis1, ibis2))
                    previous = self._backup.get((ibis1, ibis2))
                if previous is None:
                    # nothing to fall back to: do not raise, store the info instead
                    if verbose: print("WARNING: broken tuple: (%s, %s)" % (ibis1, ibis2))
                    matrix_missing[idx1][idx2] = True
                    matrix[idx2][idx1] = -1
                    relposition[idx2][idx1] = [0, 0, 0]
                    continue
                east, north, vertical = previous

            if fix_errors:
                # remember the triple actually used for this pair
                self._backup[(ibis1, ibis2)] = (east, north, vertical)

            # assess and store in-wakeness
            inwake = self._assess_inwake(north, east, vertical)

            if verbose:
                print(ibis1, ibis2, east, north, vertical)
                if inwake > 0:
                    print("! Ibis %s is inwake of %s (%.3f)" % (ibis2, ibis1, inwake))
                    print(" ", east, north, vertical, "\n")

            # ibis2, if n<0, could be flying inwake wrt ibis1 (who is leading)
            # in any case:
            # - the ibis2-th row of 'matrix', at the ibis1-th column, will contain the inwake strength
            # - the ibis2-th row of relpositions, at the ibis1-th column, will contain the position of ibis2
            matrix[idx2][idx1] = inwake
            relposition[idx2][idx1] = [east, north, vertical]

        # an ibis is considered missing when all the pairs of its row are flagged
        missing_ibises = []
        for n, row in enumerate(matrix_missing):
            if sum(row) == len(self._all_ids) - 1:
                if verbose: print("WARNING: all data for ibis", self._all_ids[n], "is missing")
                missing_ibises.append(self._all_ids[n])
        print("WARNING: missing ibises in this snapshot:", missing_ibises)

        self._timestamp_data.append(timestamp)

        self._missing_ibises_list.append(missing_ibises)  # list for all snapshots

        return matrix, pairs, splitline, relposition, missing_ibises

    def is_circling(self, timestamp, list1, list2):
        """
            Tells whether 'timestamp' falls inside one of the given intervals (bounds included).
            Arguments:  'timestamp' is a string formatted as '%Y-%m-%d %H:%M:%S.%f';
                        'list1' and 'list2' contain the begin and end timestamps of the
                        intervals (same format), paired position by position.
            Returns:    True or False (always a bool, False when there are no intervals).
            Raises:     ValueError if a timestamp that gets parsed does not match the format.
        """
        time_format = '%Y-%m-%d %H:%M:%S.%f'
        timepoint = datetime.strptime(timestamp, time_format)
        return any(
            datetime.strptime(begin, time_format) <= timepoint <= datetime.strptime(end, time_format)
            for begin, end in zip(list1, list2)
        )

    def _build_graph(self, matrix):
        """ Builds the directed in-wake graph: one integer node per known Ibis, plus the weighted edges given by _get_edges(). """
        graph = nx.DiGraph()
        graph.add_nodes_from(int(name) for name in self._all_ids)
        graph.add_edges_from(self._get_edges(matrix, self._all_ids))
        return graph

    def _get_edges(self, matrix, all_ids):
        """
            Collects the weighted edges of the in-wake graph.
            For every row of 'matrix' whose maximum is strictly positive, one edge
            (id of the arg-max column, id of the row, {'weight': maximum}) is produced;
            the ids are taken from 'all_ids' and converted to int.
        """
        candidates = (
            (all_ids[argmax(row)], all_ids[index], max(row))
            for index, row in enumerate(matrix)
        )
        return [
            (int(first), int(second), {'weight': weight})
            for first, second, weight in candidates
            if weight > 0
        ]

    def import_data(self, path, use_snapshots=None, outputdir='output', plot=True,
                    verbose=False, forbidden_intervals=None, highlight=None, img_format="jpg", print_names=(),
                    plot_arrow=True):
        """
            Reads the dataset at 'path', evaluates every snapshot and stores the results
            in the containers read by the export_* methods.
            Arguments:
                - 'path': dataset file; the first line is the header, every other line is a snapshot;
                - 'use_snapshots': optional collection of snapshot indices (0-based, header excluded)
                  to be processed; None means all of them;
                - 'outputdir': directory where the figures are saved (created if needed);
                - 'plot': whether a figure must be generated for each snapshot;
                - 'verbose': print details during the processing;
                - 'forbidden_intervals': optional sequence of (begin, end) timestamps; the snapshots
                  falling inside one of these intervals are dropped;
                - 'highlight': optional collection of Ibis names to be circled in the figures;
                - 'img_format': extension of the saved figures;
                - 'print_names': collection of (integer) Ibis names to be labelled in the figures;
                - 'plot_arrow': whether the orientation hint must be drawn in the figures.
            Codes stored in the relationship data: -3 = Ibis skipped (all of its data is missing),
            -2 = no usable bird in front of the Ibis, -1 = not in-wake of any bird,
            any other value = name of the leading Ibis.
            Every call starts from scratch: the results of a previous call are discarded.
        """
        from time import time

        if forbidden_intervals is not None:
            forbidden_intervals = array(forbidden_intervals)

        # constant-time membership tests for plain sequences of indices
        if isinstance(use_snapshots, (list, tuple)):
            use_snapshots = frozenset(use_snapshots)

        # the plotting routine iterates over 'highlight'
        if highlight is None:
            highlight = []

        if outputdir:
            os.makedirs(outputdir, exist_ok=True)

        # distance used for the birds that cannot provide upwash
        unreachable = sys.float_info.max

        with open(path) as fi:
            self._header = fi.readline().strip().split(",")[1:]
            self._all_ids = self._get_all_IDs(self._header)
            total_ibides = len(self._all_ids)
            print(" * Detected %d ibides" % total_ibides)
            self._corrs = dict(zip(self._all_ids, range(total_ibides)))
            print(" * Decoding dict:", self._corrs)

            # All the per-snapshot containers are reset together: they are
            # zipped by the export methods and must stay aligned.
            self._relationship_data = []
            self._strength_data = []
            self._timestamp_data = []
            self._strongest_positions = []
            self._missing_ibises_list = []
            self._info_grouping = []

            # process all snapshots
            for n, line in enumerate(fi):

                if use_snapshots is not None and n not in use_snapshots:
                    continue
                print(" * Processing snapshot %d..." % n)

                start = time()

                timestamp = line.split(",")[0]

                try:
                    matrix, pairs, splitline, relpositions, missing_ibises = self._process_shapshot(
                        line, verbose=verbose, forbidden_intervals=forbidden_intervals)
                except SnapshotError:
                    print("WARNING: snapshot was dropped")
                    continue

                # Matrix of distances (rows = followers, columns = leaders).
                # Only the leaders with valid data and a negative north
                # component of the relative position are kept.
                filtered_distances = zeros((total_ibides, total_ibides))
                for follower_idx, row_positions in enumerate(relpositions):
                    for leader_idx, position in enumerate(row_positions):
                        usable = (str(self._all_ids[leader_idx]) not in missing_ibises
                                  and len(position) > 1
                                  and position[1] < 0)
                        if usable:
                            filtered_distances[follower_idx][leader_idx] = np.linalg.norm(position, 2)
                        else:
                            filtered_distances[follower_idx][leader_idx] = unreachable

                # results of this snapshot, one entry per Ibis
                relationships = zeros(total_ibides, dtype=int)
                strengths = zeros(total_ibides, dtype=float)
                positions = [[] for _ in range(total_ibides)]
                self._relationship_data.append(relationships)
                self._strength_data.append(strengths)
                self._strongest_positions.append(positions)

                for m, row in enumerate(matrix):

                    # name of the following Ibis
                    back = self._all_ids[m]
                    if verbose: print("   Processing Ibis %d (name: %s)" % (m, back))

                    if str(back) in missing_ibises:
                        print("   Ibis %s in forbidden list, skipping..." % back)
                        relationships[m] = -3  # code for skipped ibis
                        positions[m] = [0, 0, 0]
                        continue

                    max_row = max(row)

                    # if the highest strength is zero, then the Ibis is not flying inwake
                    if max_row == 0:

                        if (filtered_distances[m] == unreachable).all():
                            # no usable bird in front of this Ibis
                            if verbose: print(
                                "   Ibis %s seems to be flying in front of anyone else" % self._all_ids[m])
                            relationships[m] = -2
                            positions[m] = [0, 0, 0]
                        else:
                            # store the position relative to the closest usable bird
                            if verbose: print("   Ibis %s seems to be flying in-wake of none" % self._all_ids[m])
                            relationships[m] = -1
                            positions[m] = relpositions[m][np.argmin(filtered_distances[m])]
                        continue

                    front_id = argmax(row)  # index of the leading Ibis

                    # the value in column m is the NAME of the Ibis that is leading the m-th Ibis
                    relationships[m] = int(self._all_ids[front_id])

                    # the value in column m is the highest firing strength of inwake flying
                    # for the m-th Ibis; internal invariant: it cannot be <= 0 here
                    strengths[m] = max_row
                    assert (max_row > 0)

                    # the value in column m is the relative position with respect to the leading Ibis;
                    # internal invariant: it is a triple (x,y,z) with y < 0
                    positions[m] = relpositions[m][front_id]
                    assert (positions[m][1] < 0)

                # build inwake graph
                G = self._build_graph(matrix)

                self._info_grouping.append(list(nx.weakly_connected_components(G)))

                # plot figure (impossible when the line contains non-numeric values)
                if plot and splitline is not None:
                    self._plot_birds_with_vectors(splitline, pairs, matrix, self._header,
                                                  timestamp=timestamp, inwake_graph=G,
                                                  filename=os.path.join(outputdir, "ibis%d.%s" % (n, img_format)),
                                                  highlight=highlight,
                                                  print_names=print_names, plot_arrow=plot_arrow)

                end = time()

                print("[%.3f s]" % (end - start))

        self._relationship_data = array(self._relationship_data)
        self._strongest_positions = array(self._strongest_positions)

    """
	def _remove_multiple_inwake(self, G):
		print(" * Detecting multiple inwake")
		for indegree in G.in_degree(G.nodes):
			if indegree[1]>1:
				print(" * Ibis %d is in-wake of %d birds, reducing to 1" % (indegree[0], indegree[1]))
				index_best = 0; value_best = 0
				for n, edge in enumerate(G.in_edges(indegree[0], data=True)):					
					print (edge)
					if edge[2]['weight']>value_best:
						value_best = edge[2]['weight']
						index_best = n
				to_be_removed = []
				for n, edge in enumerate(G.in_edges(indegree[0])):
					if n!=index_best:
						to_be_removed.append(edge)
				G.remove_edges_from(to_be_removed)
		return G
	"""

    def _plot_birds_with_vectors(self, splitline, pairs, matrix, header, timestamp, inwake_graph, filename, highlight,
                                 print_names, plot_arrow=True):
        """
            Draws the flock of one snapshot, together with the detected in-wake relationships,
            in the four panels of the shared figure and saves the figure (once).
            Arguments:
                - 'splitline': numeric values of the snapshot (one triple per pair);
                - 'pairs': ordered bird pairs, as returned by _process_shapshot; the positions are
                  expressed with respect to the first bird of the first pair, placed in the origin;
                - 'matrix', 'header': not used, kept for compatibility with the callers;
                - 'timestamp': used for the title of the figure;
                - 'inwake_graph': graph whose weighted edges are drawn as arrows;
                - 'filename': path of the output image;
                - 'highlight': collection of Ibis names to be circled (None means nobody);
                - 'print_names': collection of integer Ibis names to be labelled;
                - 'plot_arrow': whether the green orientation hint must be drawn.
            A failure while saving the figure is reported with a warning and does not
            interrupt the processing.
        """
        line = array(splitline)
        line = line.reshape((len(line) // 3, 3))

        G = inwake_graph

        # generate list of names and birds coordinates
        # (only the leading run of pairs that start with the pivot is used)
        pivot = pairs[0][0]
        coords = [None] * (len(self._all_ids))
        names = [None] * (len(self._all_ids))
        for i, p in enumerate(pairs):
            if p[0] != pivot: break
            indx = self._corrs[p[1]]
            coords[indx] = line[i]
            names[indx] = p[1]
        coords[self._corrs[pivot]] = [0, 0, 0]
        names[self._corrs[pivot]] = pivot
        coords = array(coords).T  # rows: x, y, z
        points = coords.T  # one (x, y, z) row per Ibis

        self._clear_figure()
        fig = self._figure
        fig.suptitle("Detected in-wake flying at timestamp: %s" % (timestamp))

        ax = self._ax
        top = self._top
        back = self._back
        side = self._side
        panels = (ax, top, side, back)

        top.view_init(elev=90., azim=-90)
        back.view_init(elev=0., azim=270)
        side.view_init(elev=0., azim=0)

        min_x, max_x = min(coords[0]), max(coords[0])
        min_y, max_y = min(coords[1]), max(coords[1])
        min_z, max_z = min(coords[2]), max(coords[2])

        # orientation hint: short green arrow along the y axis
        if plot_arrow:
            deltay = (max_y - min_y) * .02
            hints = ((top, 0, 20), (side, max_z + .5, 20), (ax, max_z, 30))
            for panel, height, scale in hints:
                panel.add_artist(
                    Arrow3D([max_x, max_x], [max_y - deltay, max_y + deltay], [height, height],
                            mutation_scale=scale, lw=3, arrowstyle="-|>", color="g"))

        # draw ibides
        for panel in panels:
            panel.scatter(*coords, s=50, color='slategrey')

        # highlighted ibides are circled in the main panel only
        if highlight is not None:
            for j in highlight:
                ax.scatter(*points[self._corrs[str(j)]], s=200, lw=3, alpha=0.5, edgecolor="red", color="none")

        top.set_xlim(min_x - 5, max_x + 5)
        top.set_ylim(min_y - 2, max_y + 2)
        top.set_zlim(min_z - 2, max_z + 2)
        side.set_xlim(min_x - 5, max_x + 5)
        side.set_ylim(min_y - 2, max_y + 2)
        side.set_zlim(min_z - 1, max_z + 1)
        back.set_xlim(min_x - 3, max_x + 5)
        back.set_ylim(min_y - 2, max_y + 2)
        back.set_zlim(min_z - 0.5, max_z + 0.5)

        back.set_yticks([])
        back.tick_params(axis='z', which='major', pad=10)
        side.set_xticks([])
        top.set_zticks([])
        top.tick_params(axis='both', which='major', pad=10)

        # one arrow per edge in every panel, from the position of the second
        # node of the edge to the position of the first one
        for edge in G.edges(data=True):
            weight = edge[2]['weight']
            from_n = points[self._corrs[str(edge[1])]]
            to_n = points[self._corrs[str(edge[0])]]
            for panel in panels:
                panel.add_artist(
                    Arrow3D([from_n[0], to_n[0]],
                            [from_n[1], to_n[1]],
                            [from_n[2], to_n[2]],
                            mutation_scale=10, lw=weight * 3, arrowstyle="-|>",
                            color=plt.cm.coolwarm(weight)))

        fontsize = 11
        for name, c in zip(names, points):
            if int(name) not in print_names: continue
            for panel in panels:
                panel.text(*c, name, alpha=0.9, fontsize=fontsize)

        ax.set_xlabel("e/w", labelpad=10)
        ax.set_ylabel("n/s", labelpad=10)
        ax.set_zlabel("u/d", labelpad=10)

        top.set_xlabel("e/w", labelpad=10)
        top.set_ylabel("n/s", labelpad=15)
        top.set_zlabel("u/d", labelpad=1)

        back.set_xlabel("e/w", labelpad=10)
        back.set_ylabel("n/s", labelpad=1)
        back.set_zlabel("u/d", labelpad=25)

        side.set_xlabel("e/w", labelpad=1)
        side.set_ylabel("n/s", labelpad=10)
        side.set_zlabel("u/d", labelpad=10)

        fig.tight_layout()

        # The figure is saved exactly once, inside the guard, so that a failure
        # produces the warning below instead of aborting the whole run.
        try:
            fig.savefig(filename)
            print(" * Figure %s saved" % filename)
        except Exception as error:
            print("WARNING: * Figure %s could not be saved" % filename)
            print("         reason:", error)

    def _assess_inwake(self, n, e, v):
        """
            Calculates the firing of the inwake flying rule for a relative
            position given as n(orth), e(ast) and v(ertical) components.
            A strictly positive n always yields 0, without querying the FIS;
            in any other case the 'flying' output of the inference is returned.
        """

        if not n > 0:
            # feed the three components to the reasoner and run the inference
            for variable, value in (("bird_ew", e), ("bird_ns", n), ("bird_plane", v)):
                self._reasoner.set_variable(variable, value)
            return self._reasoner.inference()['flying']

        # strictly positive n: firing is 0 by definition
        return 0

    def export_diads(self, file):
        with open(file, "w") as fo:
            fo.write("#timestamp\t" + "\t".join(map(str, self._all_ids)) + "\n")
            for ts, reldata in zip(self._timestamp_data, self._relationship_data):
                fo.write(ts + "\t" + "\t".join(map(str, reldata)) + "\n")

    def export_strengths(self, file):
        with open(file, "w") as fo:
            fo.write("#timestamp\t" + "\t".join(map(str, self._all_ids)) + "\n")
            for ts, rest in zip(self._timestamp_data, self._strength_data):
                fo.write("%s\t" % ts)
                fo.write("\t".join(map(str, rest)) + "\n")

    def export_4Elisa(self, file, bird=None):
        if bird is None:
            raise Exception("Specify a bird code")

        absopath = os.path.dirname(file)
        absofile = file[len(absopath) + 1:]
        absofile = absopath + '\\' + str(bird) + '_' + absofile
        index = self._corrs[bird]

        output = []
        with open(absofile, "w") as fo:

            fo.write("timestamp\tleader\tx\ty\tz\n")

            snap = 0
            for ts, indici, triple in zip(self._timestamp_data, self._relationship_data, self._strongest_positions):
                if bird in self._missing_ibises_list[snap]:
                    print(
                        "WARNING: skipping snapshot at time %s because ibis %s's sensors were not working properly." % (
                        ts, bird))
                    snap += 1;
                    continue
                fo.write("%s\t" % ts)
                fo.write("%s\t" % indici[index])
                fo.write("\t".join(map(str, triple[index])) + "\n")
                snap += 1

    def export_groups(self, file_grouping):
        with open(file_grouping, "w") as fo:
            for ts, info in zip(self._timestamp_data, self._info_grouping):
                fo.write("%s\t" % ts)
                for ss in info:
                    # print(ss)
                    for ibis in ss:
                        fo.write("%d " % ibis)
                    fo.write("\t")
                fo.write("\n")


# Running this module as a script does nothing.
if __name__ == "__main__":
    pass