"""
An example illustrating graph manipulation and display with Mayavi.

Starting from random point positions around a plane, we first use VTK
to create the Delaunay graph, and also to plot it. We then create a
minimum spanning tree of the data with Prim's algorithm. We display it
by adding connections to the Mayavi points.

The visualization clearly shows that the minimum spanning tree of the
points, considering all possible connections, is included in the Delaunay
graph.
"""
# Author: Gary Ruben
#         Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# Copyright (c) 2009, Enthought, Inc.
# License: BSD style.

from enthought.mayavi import mlab
import numpy as np

def compute_delaunay_edges(xyz, visualize=False):
    """ Given 3-D points, returns the edges of their
        Delaunay triangulation.

        Parameters
        -----------
        xyz: ndarray
            3 rows x N columns of x, y, z coordinates of the points
        visualize: bool, optional
            If True, the triangulation is also added to the current
            Mayavi figure as thin, semi-transparent tubes colored by z.
            If False, the pipeline is built without attaching it to a
            figure.

        Returns
        ---------
        edges: 2D ndarray.
            The indices of the edges of the Delaunay triangulation as a
            (2, n_edges) array [[pair1_index1, pair2_index1, ...],
                                [pair1_index2, pair2_index2, .. ]]
    """
    x, y, z = xyz
    if visualize:
        vtk_source = mlab.pipeline.scalar_scatter(x, y, z, z,
                                                  name='Delaunay graph')
    else:
        vtk_source = mlab.pipeline.scalar_scatter(x, y, z, figure=False)
    delaunay = mlab.pipeline.delaunay3d(vtk_source)
    edges = mlab.pipeline.extract_edges(delaunay)
    if visualize:
        # np.ptp() rather than the ndarray.ptp() method: the method was
        # removed in NumPy 2.0, and the function also accepts sequences.
        z_range = np.ptp(z)
        # The color range is deliberately much wider than the data range so
        # that only a narrow band of the colormap is used for the graph.
        mlab.pipeline.surface(mlab.pipeline.tube(edges,
                                                 tube_radius=0.001),
                              colormap='gist_yarg',
                              vmin=-0.5*z_range,
                              vmax=2*z_range,
                              opacity=0.3)
    lines = edges.outputs[0].lines.to_array()
    # The flat array is read in groups of three values; the first value of
    # each group is skipped and the next two are the end points of an edge
    # (presumably VTK's [n_points, id0, id1] cell layout with n_points == 2
    # for every line -- confirm against the VTK cell array documentation).
    return np.array([lines[1::3], lines[2::3]])


def minimum_spanning_tree(adjacency):
    """ Form a Minimum Spanning Tree using Prim's algorithm.

        Parameters
        -----------
        adjacency: 2D array-like, square
            Matrix of edge weights: adjacency[i, j] is the weight of the
            edge between vertices i and j. Entries equal to zero (which
            includes the diagonal of a distance matrix) are treated as
            "no edge". The input is not modified.

        Returns
        ---------
        edges: ndarray
            One (u, v) row per edge of the tree, in the order in which the
            edges were added; u was already in the tree, v is the vertex
            that the edge attached to it.
        parents: 1D ndarray of ints
            parents[v] is the vertex u through which v was attached.
            The tree is grown from vertex 0, whose entry stays 0.

        Notes
        ------
        If no finite-weight edge links the current tree to the remaining
        vertices, an infinite-weight edge is still added (argmin simply
        returns the first candidate), so the result always spans all
        the vertices.
    """
    edge_list = list()
    # Work on a floating point copy: the caller's data stays untouched and
    # integer weight matrices can hold the np.inf "no edge" marker.
    adjacency = np.array(adjacency, dtype=float)
    adjacency[adjacency == 0] = np.inf
    nb_vertices = adjacency.shape[0]
    parents = np.zeros(nb_vertices, dtype=int)
    if nb_vertices == 0:
        # Nothing to span; avoid indexing vertex 0 of an empty graph.
        return np.array(edge_list), parents
    indices = np.arange(nb_vertices)
    unselected = np.ones(nb_vertices, dtype=bool)
    # The tree is seeded with vertex 0.
    unselected[0] = False
    while np.any(unselected):
        selected = np.logical_not(unselected)
        # Weights of all the edges going from the tree (rows) to the
        # vertices that are not in the tree yet (columns).
        this_adjacency = adjacency[selected][:, unselected]
        u, v = np.unravel_index(np.argmin(this_adjacency),
                                this_adjacency.shape)
        # Map the positions in the sub-matrix back to vertex indices.
        u = indices[selected][u]
        v = indices[unselected][v]
        parents[v] = u
        unselected[v] = False
        edge_list.append((u, v))

    return np.array(edge_list), parents


def get_generation(parents):
    """ Given parent indices, return ranks in generation.

        Parameters
        -----------
        parents: 1D sequence of ints
            parents[v] is the index of the parent of vertex v. Vertex 0 is
            the root of the tree; its own entry is ignored.

        Returns
        ---------
        ranks: 1D ndarray
            Generation rank of each vertex. The root and its direct
            children both have rank 0, their children rank 1, and so on.

        Raises
        -------
        ValueError
            If some vertex cannot be reached from vertex 0 by following
            the parent links (for instance because of a cycle).
    """
    ranks = np.zeros_like(parents)
    rank = 0
    descendants = set(range(1, len(parents)))
    # The great ancestor is the first one
    this_generation = {0}
    while descendants:
        # A set makes each membership test O(1) instead of a list scan.
        next_generation = set(vertex for vertex in descendants
                              if parents[vertex] in this_generation)
        if not next_generation:
            # No remaining vertex hangs off the current generation: without
            # this check the loop would never terminate.
            raise ValueError('%d vertices are not descendants of vertex 0'
                             % len(descendants))
        for vertex in next_generation:
            ranks[vertex] = rank
        descendants -= next_generation
        rank += 1
        this_generation = next_generation
    return ranks

################################################################################
if __name__ == '__main__':
    mlab.figure(1, size=(720, 450), bgcolor=(1, 1, 1))
    mlab.clf()

    # Draw reproducible random points in the unit cube, then squash the
    # last coordinate so that the cloud lies close to a plane.
    n_points = 800
    np.random.seed(10)
    xyz = np.random.random((3, n_points))
    xyz[-1, :] *= 0.1

    # Build the Delaunay tessellation of the points and draw it.
    delaunay_edges = compute_delaunay_edges(xyz, visualize=True)

    # Pairwise squared distances between all the points, used as the
    # weights of the complete graph from which the tree is extracted.
    x, y, z = xyz
    dx = x - x[:, np.newaxis]
    dy = y - y[:, np.newaxis]
    dz = z - z[:, np.newaxis]
    distances = dx**2 + dy**2 + dz**2
    edges, parents = minimum_spanning_tree(distances)

    # The generation of each point in the tree is used as its color.
    ranks = get_generation(parents)

    # Draw the points as fixed-size glyphs colored by generation.
    pts = mlab.points3d(x, y, z, ranks,
                        scale_factor=0.01,
                        scale_mode='none',
                        colormap='jet',
                        name='Minimum Spanning Tree')

    # Attach the tree edges to the point dataset and render them as tubes.
    pts.mlab_source.dataset.lines = edges
    tree_tubes = mlab.pipeline.tube(pts, tube_radius=0.004)
    mlab.pipeline.surface(tree_tubes, colormap='jet')

    mlab.view(180, -30, 0.7)
    mlab.show()
