Source code for meshparty.skeleton

import numpy as np
from meshparty import utils
from scipy import spatial, sparse
from pykdtree.kdtree import KDTree as pyKDTree
from copy import copy
import json
from meshparty import skeleton_io
from collections.abc import Iterable

class Skeleton:
    """Class to manage skeleton data.

    A skeleton is stored as a set of vertices plus an edge list that is
    oriented ``[child, parent]`` with respect to a root vertex. Most derived
    quantities (graphs, segments, cover paths, branch and end points) are
    computed lazily and cached; the caches are cleared whenever the skeleton
    is rerooted.

    Parameters
    ----------
    vertices : np.array
        a Nx3 list of xyz locations of skeleton nodes
    edges : np.array
        a Kx2 list of edges in the skeleton where each row is a connected
        pair of vertex indices.
    mesh_to_skel_map : np.array
        For a skeleton derived from a mesh with N_mesh vertices, a length
        N_mesh array that gives the closest on-graph skeleton vertex for
        each mesh vertex (or -1 if None). (default is None)
    vertex_properties : dict
        a dictionary of keys with strings where each value is a numpy.array
        of len(N) of properties of skeleton vertices. (default None creates
        a new empty dictionary for this instance)
    root : None
        what vertex index should be root (default None will find a vertex
        far from others)
    """

    def __init__(self, vertices, edges, mesh_to_skel_map=None,
                 vertex_properties=None, root=None):
        self._vertices = np.array(vertices)
        self._edges = np.vstack(edges).astype(int)
        # A fresh dict per instance: a literal ``{}`` default would be
        # shared by every Skeleton built without explicit properties.
        self.vertex_properties = (
            {} if vertex_properties is None else vertex_properties
        )
        self._mesh_to_skel_map = mesh_to_skel_map

        self._root = None
        self._cover_paths = None
        self._segments = None
        self._segment_map = None
        self._parent_node_array = None
        self._kdtree = None
        self._pykdtree = None
        self._csgraph = None
        self._csgraph_binary = None
        self._branch_points = None
        self._end_points = None
        self._reset_derived_objects()

        if root is None:
            self._create_default_root()
        else:
            self.reroot(root)

    @property
    def vertices(self):
        """ numpy.array : Nx3 set of xyz coordinates of skeletons"""
        return self._vertices

    @property
    def edges(self):
        """ numpy.array : Mx2 set of edges as indices into vertices """
        return self._edges

    @property
    def mesh_to_skel_map(self):
        """ numpy.array : N_mesh length array giving the associated skeleton
        vertex for each mesh vertex"""
        return self._mesh_to_skel_map

    @property
    def segments(self):
        """ list : A list of numpy.array indices of segments, paths from each
        branch or end point (inclusive) to the next rootward branch/root
        point (exclusive), that cover the skeleton"""
        if self._segments is None:
            self._segments, self._segment_map = self._compute_segments()
        return self._segments

    @property
    def segment_map(self):
        """ np.array : N set of indices between 0 and len(self.segments)-1,
        denoting which segment a given skeleton vertex is in.
        """
        if self._segment_map is None:
            self._segments, self._segment_map = self._compute_segments()
        return self._segment_map

    @property
    def csgraph(self):
        """ scipy.sparse.csr.csr_matrix : Directed sparse graph
        representation of the skeleton vertex connectivity. The i,jth
        element is the distance from a vertex i to its parent j, 0 otherwise.
        A copy of the cached graph is returned.
        """
        if self._csgraph is None:
            self._csgraph = self._create_csgraph()
        return self._csgraph.copy()

    @property
    def csgraph_binary(self):
        """ scipy.sparse.csr.csr_matrix : Directed binary sparse graph
        representation of skeleton vertex connectivity. The i,jth element is
        1 if vertex i has parent j, 0 otherwise.
        """
        if self._csgraph_binary is None:
            self._csgraph_binary = self._create_csgraph(euclidean_weight=False)
        return self._csgraph_binary

    @property
    def csgraph_undirected(self):
        """ scipy.sparse.csr_matrix : Undirected sparse graph representation
        of skeleton vertex connectivity. The i,jth element is the distance
        from vertex i to any connected vertex j, 0 otherwise.
        """
        return self.csgraph + self.csgraph.T

    @property
    def csgraph_binary_undirected(self):
        """ scipy.sparse.csr_matrix : Undirected sparse graph representation
        of skeleton vertex connectivity. The i,jth element is 1 if vertex i
        is connected to vertex j, 0 otherwise.
        """
        return self.csgraph_binary + self.csgraph_binary.T

    @property
    def n_vertices(self):
        """ int : Number of vertices in the skeleton """
        return len(self.vertices)

    @property
    def root(self):
        """ int : Index of the skeleton root """
        if self._root is None:
            self._create_default_root()
        return copy(self._root)

    @property
    def pykdtree(self):
        """ pykdtree.pyKDTree object : k-D tree from pykdtree (a bit faster
        but fewer functions)
        """
        if self._pykdtree is None:
            self._pykdtree = pyKDTree(self.vertices)
        return self._pykdtree

    @property
    def kdtree(self):
        """ scipy.spatial.kdtree : k-D tree from scipy.spatial. """
        if self._kdtree is None:
            self._kdtree = spatial.cKDTree(self.vertices)
        return self._kdtree

    @property
    def branch_points(self):
        """ numpy.array : Indices of branch points on the skeleton
        (potentially including root)"""
        if self._branch_points is None:
            self._create_branch_and_end_points()
        return self._branch_points.copy()

    @property
    def n_branch_points(self):
        """ int : Number of branch points on the skeleton """
        if self._branch_points is None:
            self._create_branch_and_end_points()
        return len(self._branch_points)

    @property
    def end_points(self):
        """ numpy.array : Indices of end points on the skeleton
        (potentially including root)"""
        if self._end_points is None:
            self._create_branch_and_end_points()
        return self._end_points.copy()

    @property
    def n_end_points(self):
        """ int : Number of end points on the skeleton """
        if self._end_points is None:
            self._create_branch_and_end_points()
        return len(self._end_points)

    @property
    def cover_paths(self):
        """ list : List of numpy.array objects with self.n_end_points
        elements, each a rootward path (ordered set of indices) starting from
        an endpoint and continuing until it reaches a point on a path earlier
        on the list. Paths are ordered by end point distance from root,
        starting with the most distal. When traversed from beginning to end,
        gives the longest rootward path down each major branch first. When
        traversed from end to beginning, guarantees every branch point is
        visited at the end of all its downstream paths before being traversed
        along a path.
        """
        if self._cover_paths is None:
            self._cover_paths = self._compute_cover_paths()
        return self._cover_paths

    @property
    def distance_to_root(self):
        """ np.array : N length array with the distance to the root node
        along the skeleton. Recomputed on every access.
        """
        ds = sparse.csgraph.dijkstra(self.csgraph, directed=False,
                                     indices=self.root)
        return ds

    def path_to_root(self, v_ind):
        '''
        Gives the path to root from a specified vertex.

        Parameters
        ----------
        v_ind : int
            Vertex index

        Returns
        -------
        numpy.array : Ordered set of indices from v_ind to root, inclusive
        of both.
        '''
        path = [v_ind]
        # The parent array holds None for the root, which ends the walk.
        ind = self.parent_node(v_ind)
        while ind is not None:
            path.append(ind)
            ind = self.parent_node(ind)
        return np.array(path)

    def path_length(self, paths):
        """Returns the length of a path (described as an ordered collection
        of connected indices)

        Parameters
        ----------
        paths : Path or collection of paths, a path being an ordered list of
            linked vertex indices.

        Returns
        -------
        Float or list of floats : The length of each path. An empty input
            returns 0.
        """
        if len(paths) == 0:
            return 0

        # A collection of paths is recognised by its first element being
        # iterable itself; otherwise the input is treated as a single path.
        if isinstance(paths[0], Iterable):
            return [self._single_path_length(path) for path in paths]
        return self._single_path_length(paths)

    def reroot(self, new_root):
        """Change the skeleton root index.

        Re-orients the edge list as [child, parent] relative to the new root,
        rebuilds the parent lookup and clears all cached derived objects.

        Parameters
        ----------
        new_root : Int
            Skeleton vertex index to be the new root.

        Raises
        ------
        ValueError
            If new_root is not smaller than the number of vertices.
        """
        # Valid indices stop at n_vertices - 1, hence >= rather than >.
        if new_root >= self.n_vertices:
            raise ValueError(
                'New root must correspond to a skeleton vertex index')
        self._root = new_root
        self._parent_node_array = np.full(self.n_vertices, None)

        # The edge list has to be treated like an undirected graph
        d = sparse.csgraph.dijkstra(self.csgraph_binary,
                                    directed=False,
                                    indices=new_root)

        # Make edges in edge list orient as [child, parent]
        # Where each child only has one parent
        # And the root has no parent. (Thus parent is closer than child)
        edges = self.edges
        is_ordered = d[edges[:, 0]] > d[edges[:, 1]]
        e1 = np.where(is_ordered, edges[:, 0], edges[:, 1])
        e2 = np.where(is_ordered, edges[:, 1], edges[:, 0])

        self._edges = np.stack((e1, e2)).T
        self._parent_node_array[e1] = e2
        self._reset_derived_objects()

    def cut_graph(self, vinds, directed=True, euclidean_weight=True):
        """Return a csgraph for the skeleton with specified vertices cut off
        from their parent vertex.

        Parameters
        ----------
        vinds :
            Collection of indices to cut off from their parent.
        directed : bool, optional
            Return the graph as directed, by default True
        euclidean_weight : bool, optional
            Return the graph with euclidean weights, by default True. If
            false, the graph is unweighted.

        Returns
        -------
        scipy.sparse.csr.csr_matrix
            Graph with vertices in vinds cut off from parents.
        """
        # Edges are [child, parent], so dropping rows whose child is in
        # vinds removes exactly the link from each such vertex to its parent.
        e_keep = ~np.isin(self.edges[:, 0], vinds)
        es_new = self.edges[e_keep]
        return utils.create_csgraph(self.vertices, es_new,
                                    euclidean_weight=euclidean_weight,
                                    directed=directed)

    def downstream_nodes(self, vinds):
        """Get a list of all nodes downstream of a collection of indices

        Parameters
        ----------
        vinds : Collection of ints
            Collection of vertex indices

        Returns
        -------
        List of arrays
            List whose ith element is an array of all vertices downstream of
            the ith element of vinds. If vinds is a scalar, a single array is
            returned instead of a list.
        """
        return_single = np.isscalar(vinds)
        if return_single:
            vinds = [vinds]

        dns = []
        for vind in vinds:
            g = self.cut_graph(vind)
            # The graph points child -> parent; searching its transpose
            # from vind reaches vind and everything below it.
            d = sparse.csgraph.dijkstra(g.T, indices=[vind])
            dns.append(np.flatnonzero(~np.isinf(d[0])))

        if return_single:
            dns = dns[0]
        return dns

    def child_nodes(self, vinds):
        """Get a list of all immediate children of list of indices

        Parameters
        ----------
        vinds : Collection of ints
            Collection of vertex indices

        Returns
        -------
        List of arrays
            A list whose ith element is the children of the ith element of
            vinds. If vinds is a scalar, a single array is returned instead
            of a list.
        """
        return_single = np.isscalar(vinds)
        if return_single:
            vinds = [vinds]

        # Edges are [child, parent]: children are the first column of every
        # row whose second column is the queried vertex.
        cinds = [self.edges[self.edges[:, 1] == vind, 0] for vind in vinds]

        if return_single:
            cinds = cinds[0]
        return cinds

    def _create_default_root(self):
        """Root the skeleton at the first vertex returned by
        utils.find_far_points_graph on the undirected graph.
        """
        r = utils.find_far_points_graph(self._create_csgraph(directed=False))
        self.reroot(r[0])

    def parent_node(self, vinds):
        """Get a list of parent nodes for specified vertices

        Parameters
        ----------
        vinds : Collection of ints
            Collection of vertex indices

        Returns
        -------
        numpy.array
            The parent node of each vertex index in vinds. A scalar index
            returns a single value; the root's parent is None.
        """
        if isinstance(vinds, list):
            vinds = np.array(vinds)
        return self._parent_node_array[vinds]

    def _reset_derived_objects(self):
        """Reset all properties that could change when skeletons are
        rerooted.
        """
        self._csgraph = None
        self._csgraph_binary = None
        self._cover_paths = None
        self._segments = None
        self._segment_map = None
        # Branch and end points are computed from the directed graph, so
        # they depend on the root and must not survive a reroot.
        self._branch_points = None
        self._end_points = None

    def _create_csgraph(self, directed=True, euclidean_weight=True):
        """Create the csgraph for the skeleton.
        """
        return utils.create_csgraph(self.vertices, self.edges,
                                    euclidean_weight=euclidean_weight,
                                    directed=directed)

    def _create_branch_and_end_points(self):
        """Pre-compute branch and end points from the graph
        """
        # Column j of the directed graph holds the edges arriving at j from
        # its children. The sparse sum is flattened explicitly to a plain
        # 1-D array before comparison.
        n_children = np.asarray(
            np.sum(self.csgraph_binary > 0, axis=0)).ravel()
        self._branch_points = np.flatnonzero(n_children > 1)
        self._end_points = np.flatnonzero(n_children == 0)

    def _single_path_length(self, path):
        """Compute the length of a single path
        """
        xs = self.vertices[path[:-1]]
        ys = self.vertices[path[1:]]
        return np.sum(np.linalg.norm(ys - xs, axis=1))

    def _compute_cover_paths(self):
        '''Compute the list of cover paths along the skeleton
        '''
        cover_paths = []
        seen = np.full(self.n_vertices, False)
        end_points = self.end_points
        # Most distal end points first.
        ep_order = np.argsort(self.distance_to_root[end_points])[::-1]
        for ep in end_points[ep_order]:
            ptr = np.array(self.path_to_root(ep))
            # Keep only the vertices not already claimed by an earlier path.
            cover_paths.append(ptr[~seen[ptr]])
            seen[ptr] = True
        return cover_paths

    def _compute_segments(self):
        """Precompute segments between branches and end points

        Returns
        -------
        segments : list of numpy.array
            Vertex indices of each segment, ordered rootward.
        segment_map : numpy.array
            For each vertex, the index of its segment in segments (-1 if the
            vertex was not assigned to any segment).
        """
        segments = []
        segment_map = np.zeros(len(self.vertices)) - 1

        # Used as a stack: end points first, then branch points as they are
        # discovered.
        path_queue = self.end_points.tolist()
        # Sets give constant-time membership checks in the inner loop.
        bp_all = set(self.branch_points.tolist())
        bp_seen = set()

        while len(path_queue) > 0:
            ind = path_queue.pop()
            segment = [ind]
            for pind in self.path_to_root(ind)[1:]:
                pind = int(pind)
                if pind in bp_all:
                    # The segment stops just short of the next rootward
                    # branch point; that point starts its own segment once.
                    if pind not in bp_seen:
                        path_queue.append(pind)
                        bp_seen.add(pind)
                    break
                segment.append(pind)
            # The segment index is its position in the list.
            segment_map[segment] = len(segments)
            segments.append(np.array(segment))
        return segments, segment_map.astype(int)

    def export_to_swc(self, filename, node_labels=None, radius=None,
                      header=None, xyz_scaling=1000):
        '''
        Export a skeleton file to an swc file
        (see http://research.mssm.edu/cnic/swc.html for swc definition)

        Parameters
        ----------
        filename : str
            Name of the file to save the swc to
        node_labels : np.array
            None (default) or an iterable of ints co-indexed with vertices.
            Corresponds to the swc node categories. Defaults to setting all
            nodes to label 3, dendrite.
        radius : iterable
            None (default) or an iterable of floats. This should be
            co-indexed with vertices. Radius values are assumed to be in the
            same units as the node vertices.
        header : dict
            default None. Each key value pair in the dict becomes a
            parameter line in the swc header.
        xyz_scaling : Number
            default 1000. Down-scales spatial units from the skeleton's
            units to whatever is desired by the swc. E.g. nm to microns has
            scaling=1000.
        '''
        skeleton_io.export_to_swc(self, filename, node_labels=node_labels,
                                  radius=radius, header=header,
                                  xyz_scaling=xyz_scaling)