Source code for meshparty.skeletonize

from scipy import sparse, spatial
import numpy as np
import time
from meshparty import trimesh_vtk, utils, mesh_filters
import pandas as pd
from tqdm import trange, tqdm
from meshparty.trimesh_io import Mesh
from meshparty.skeleton import Skeleton
from collections import defaultdict
from pykdtree.kdtree import KDTree as pyKDTree
import trimesh.ray
from trimesh.ray import ray_pyembree
import logging

def skeletonize_mesh(mesh, soma_pt=None, soma_radius=7500,
                     collapse_soma=True, invalidation_d=12000,
                     smooth_vertices=False, compute_radius=True,
                     compute_original_index=True, verbose=True):
    '''
    Build skeleton object from mesh skeletonization

    Parameters
    ----------
    mesh : meshparty.trimesh_io.Mesh
        the mesh to skeletonize, defaults assume vertices in nm
    soma_pt : np.array
        a length 3 array specifying the soma location to make the root,
        in units of mesh vertex coordinates. default=None, in which case
        a heuristic root will be chosen.
    soma_radius : float
        distance in mesh vertex units over which to consider mesh vertices
        close to soma_pt to belong to the soma. These vertices will
        automatically be invalidated and no skeleton branches will attempt
        to reach them. This distance will also be used to collapse all
        skeleton points within this distance to the soma_pt root if
        collapse_soma is True. (default 7500 (nm))
    collapse_soma : bool
        whether to collapse the skeleton around the soma point (default True)
    invalidation_d : float
        the distance along the mesh to invalidate when applying the
        TEASAR-like algorithm. Controls how detailed a structure the
        skeleton algorithm reaches. (default 12000 (nm))
    smooth_vertices : bool
        whether to smooth the vertices of the skeleton
    compute_radius : bool
        whether to calculate the radius of the skeleton at each point
        on the skeleton (default True)
    compute_original_index : bool
        whether to calculate how each of the mesh nodes maps onto the
        skeleton (default True)
    verbose : bool
        whether to print verbose logging

    Returns
    -------
    :obj:`meshparty.skeleton.Skeleton`
        a Skeleton object for this mesh
    '''
    skel_verts, skel_edges, smooth_verts, orig_skel_index, skel_map = \
        calculate_skeleton_paths_on_mesh(mesh,
                                         soma_pt=soma_pt,
                                         soma_thresh=soma_radius,
                                         invalidation_d=invalidation_d,
                                         return_map=True)

    if smooth_vertices is True:
        skel_verts = smooth_verts

    if collapse_soma is True and soma_pt is not None:
        soma_verts = mesh_filters.filter_spatial_distance_from_points(
            mesh, [soma_pt], soma_radius)
        new_v, new_e, new_skel_map, vert_filter, root_ind = collapse_soma_skeleton(
            soma_pt, skel_verts, skel_edges,
            soma_d_thresh=soma_radius,
            mesh_to_skeleton_map=skel_map,
            soma_mesh_indices=soma_verts,
            return_filter=True,
            return_soma_ind=True)
    else:
        new_v, new_e, new_skel_map = skel_verts, skel_edges, skel_map
        vert_filter = np.arange(len(orig_skel_index))
        if soma_pt is None:
            sk_graph = utils.create_csgraph(new_v, new_e)
            root_ind = utils.find_far_points_graph(sk_graph)[0]
        else:
            # Still try to root close to the soma
            _, qry_inds = pyKDTree(new_v).query(soma_pt[np.newaxis, :])
            root_ind = qry_inds[0]

    skel_map_full_mesh = np.full(mesh.node_mask.shape, -1, dtype=np.int64)
    skel_map_full_mesh[mesh.node_mask] = new_skel_map
    ind_to_fix = mesh.map_boolean_to_unmasked(np.isnan(new_skel_map))
    skel_map_full_mesh[ind_to_fix] = -1

    props = {}
    if compute_original_index is True:
        props['mesh_index'] = np.append(
            mesh.map_indices_to_unmasked(orig_skel_index[vert_filter]), -1)
    if compute_radius is True:
        rs = ray_trace_distance(orig_skel_index[vert_filter], mesh)
        rs = np.append(rs, soma_radius)
        props['rs'] = rs

    sk = Skeleton(new_v, new_e, mesh_to_skel_map=skel_map_full_mesh,
                  vertex_properties=props, root=root_ind)
    return sk

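# Hedged usage sketch for skeletonize_mesh (not part of the module). The
# file name and soma coordinates are hypothetical placeholders; any
# meshparty Mesh with vertices in nm should work.
#
#   from meshparty import trimesh_io, skeletonize
#   import numpy as np
#
#   mm = trimesh_io.MeshMeta()
#   mesh = mm.mesh(filename='my_neuron.h5')        # hypothetical file
#   soma_pt = np.array([35820., 21954., 2184.])    # hypothetical nm location
#   sk = skeletonize.skeletonize_mesh(mesh, soma_pt=soma_pt,
#                                     soma_radius=7500,
#                                     collapse_soma=True,
#                                     invalidation_d=12000)
#   print(len(sk.vertices), 'skeleton vertices,', len(sk.edges), 'edges')
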
def calculate_skeleton_paths_on_mesh(mesh, soma_pt=None, soma_thresh=7500,
                                     invalidation_d=10000,
                                     smooth_neighborhood=5,
                                     large_skel_path_threshold=5000,
                                     cc_vertex_thresh=100, return_map=False):
    """ function to turn a trimesh object of a neuron into a skeleton,
    without running soma collapse or recasting the result into a Skeleton.
    Used by :func:`meshparty.skeletonize.skeletonize_mesh` and makes use of
    :func:`meshparty.skeletonize.skeletonize_components`

    Parameters
    ----------
    mesh : meshparty.trimesh_io.Mesh
        the mesh to skeletonize, defaults assume vertices in nm
    soma_pt : np.array
        a length 3 array specifying the soma location to make the root,
        in units of mesh vertex coordinates. default=None, in which case
        a heuristic root will be chosen.
    soma_thresh : float
        distance in mesh vertex units over which to consider mesh vertices
        close to soma_pt to belong to the soma. These vertices will
        automatically be invalidated and no skeleton branches will attempt
        to reach them. This distance will also be used to collapse all
        skeleton points within this distance to the soma_pt root if
        collapse_soma is True. (default 7500 (nm))
    invalidation_d : float
        the distance along the mesh to invalidate when applying the
        TEASAR-like algorithm. Controls how detailed a structure the
        skeleton algorithm reaches. (default 10000 (nm))
    smooth_neighborhood : int
        the neighborhood in edge hops over which to smooth skeleton
        locations. This controls the smoothing of the skeleton (default 5)
    large_skel_path_threshold : int
        the threshold in terms of skeleton vertices at which skeletons will
        be nominated for tip merging. Smaller skeleton fragments will not
        be merged at their tips (default 5000)
    cc_vertex_thresh : int
        the threshold in terms of vertex numbers at which connected
        components of the mesh will be considered for skeletonization.
        Mesh connected components with fewer than this number of vertices
        will be ignored by the skeletonization algorithm. (default 100)
    return_map : bool
        whether to return a map of how each mesh vertex maps onto each
        skeleton vertex, based upon how it was invalidated.

    Returns
    -------
    skel_verts : np.array
        a Nx3 matrix of skeleton vertex positions
    skel_edges : np.array
        a Kx2 matrix of skeleton edge indices into skel_verts
    smooth_verts : np.array
        a Nx3 matrix of vertex positions after smoothing
    skel_verts_orig : np.array
        a N long index of skeleton vertices in the original mesh vertex index
    (mesh_to_skeleton_map) : np.array
        a M long map of mesh vertex indices to skeleton vertex indices,
        returned only if return_map is True
    """
    skeletonize_output = skeletonize_components(mesh,
                                                soma_pt=soma_pt,
                                                soma_thresh=soma_thresh,
                                                invalidation_d=invalidation_d,
                                                cc_vertex_thresh=cc_vertex_thresh,
                                                return_map=return_map)
    if return_map is True:
        all_paths, roots, tot_path_lengths, mesh_to_skeleton_map = skeletonize_output
    else:
        all_paths, roots, tot_path_lengths = skeletonize_output

    all_edges = []
    for comp_paths in all_paths:
        all_edges.append(utils.paths_to_edges(comp_paths))
    tot_edges = np.vstack(all_edges)

    skel_verts, skel_edges, skel_verts_orig = reduce_verts(mesh.vertices, tot_edges)
    smooth_verts = smooth_graph(skel_verts, skel_edges,
                                neighborhood=smooth_neighborhood)

    if return_map:
        mesh_to_skeleton_map = utils.nanfilter_shapes(
            np.unique(tot_edges.ravel()), mesh_to_skeleton_map)
    else:
        mesh_to_skeleton_map = None

    output_tuple = (skel_verts, skel_edges, smooth_verts, skel_verts_orig)
    if return_map:
        output_tuple = output_tuple + (mesh_to_skeleton_map,)

    return output_tuple

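# Hedged sketch of calling calculate_skeleton_paths_on_mesh directly and
# unpacking its return_map output (names below are illustrative only):
#
#   verts, edges, smooth_verts, orig_inds, m2s_map = \
#       calculate_skeleton_paths_on_mesh(mesh, soma_pt=soma_pt,
#                                        invalidation_d=10000,
#                                        return_map=True)
#   # m2s_map[i] gives the skeleton vertex index that invalidated mesh
#   # vertex i (NaN where no skeleton vertex claimed the mesh vertex)
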
def reduce_verts(verts, faces):
    """removes unused vertices from a graph or mesh

    Parameters
    ----------
    verts : numpy.array
        NxD numpy array of vertex locations
    faces : numpy.array
        MxK numpy array of connected shapes (i.e. edges or tris)
        (entries are indices into verts)

    Returns
    -------
    np.array
        new_verts, a filtered set of vertices
    np.array
        new_face, a reindexed set of faces (or edges)
    np.array
        used_verts, the indices of the new_verts in the old verts
    """
    used_verts = np.unique(faces.ravel())
    new_verts = verts[used_verts, :]
    new_face = np.zeros(faces.shape, dtype=faces.dtype)
    for i in range(faces.shape[1]):
        new_face[:, i] = np.searchsorted(used_verts, faces[:, i])
    return new_verts, new_face, used_verts

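# A self-contained toy example of reduce_verts (illustrative, not part of
# the module): vertex 1 appears in no edge, so it is dropped and the edge
# indices are compacted.
def _example_reduce_verts():
    verts = np.array([[0., 0., 0.],
                      [1., 0., 0.],   # unused by any edge; will be removed
                      [0., 1., 0.],
                      [0., 0., 1.]])
    edges = np.array([[0, 2],
                      [2, 3]])
    new_verts, new_edges, used_verts = reduce_verts(verts, edges)
    # new_verts has 3 rows; new_edges == [[0, 1], [1, 2]]; used_verts == [0, 2, 3]
    return new_verts, new_edges, used_verts
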
def skeletonize_components(mesh, soma_pt=None, soma_thresh=10000,
                           invalidation_d=10000, cc_vertex_thresh=100,
                           return_map=False):
    """ core skeletonization routine, used by
    :func:`meshparty.skeletonize.calculate_skeleton_paths_on_mesh`
    to calculate a skeleton on all components of the mesh,
    with no post processing """
    # find all the connected components in the mesh
    n_components, labels = sparse.csgraph.connected_components(mesh.csgraph,
                                                               directed=False,
                                                               return_labels=True)
    comp_labels, comp_counts = np.unique(labels, return_counts=True)

    if return_map:
        mesh_to_skeleton_map = np.full(len(mesh.vertices), np.nan)

    # variables to collect the paths, roots and path lengths
    all_paths = []
    roots = []
    tot_path_lengths = []

    if soma_pt is not None:
        soma_d = mesh.vertices - soma_pt[np.newaxis, :]
        soma_d = np.linalg.norm(soma_d, axis=1)
        is_soma_pt = soma_d < soma_thresh
    else:
        is_soma_pt = None
        soma_d = None

    # loop over the components
    for k in trange(n_components):
        if comp_counts[k] > cc_vertex_thresh:

            # find the root using a soma position if you have it;
            # it will fall back to a heuristic if the soma
            # is too far away for this component
            root, root_ds, pred, valid = setup_root(mesh,
                                                    is_soma_pt,
                                                    soma_d,
                                                    labels == k)
            # run teasar on this component
            teasar_output = mesh_teasar(mesh,
                                        root=root,
                                        root_ds=root_ds,
                                        root_pred=pred,
                                        valid=valid,
                                        invalidation_d=invalidation_d,
                                        return_map=return_map)
            if return_map is False:
                paths, path_lengths = teasar_output
            else:
                paths, path_lengths, mesh_to_skeleton_map_single = teasar_output
                single_valid = ~np.isnan(mesh_to_skeleton_map_single)
                mesh_to_skeleton_map[single_valid] = \
                    mesh_to_skeleton_map_single[single_valid]

            if len(path_lengths) > 0:
                # collect the results in lists
                tot_path_lengths.append(path_lengths)
                all_paths.append(paths)
                roots.append(root)

    if return_map:
        return all_paths, roots, tot_path_lengths, mesh_to_skeleton_map
    else:
        return all_paths, roots, tot_path_lengths

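# Hedged sketch of consuming skeletonize_components output directly
# (illustrative; normally this is driven via calculate_skeleton_paths_on_mesh):
#
#   all_paths, roots, tot_path_lengths = skeletonize_components(
#       mesh, invalidation_d=10000, cc_vertex_thresh=100)
#   for comp_paths, comp_root in zip(all_paths, roots):
#       # comp_paths is a list of index paths (mesh vertex indices) and
#       # comp_root is the mesh index used as the root of that component
#       print(comp_root, len(comp_paths))
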
def setup_root(mesh, is_soma_pt=None, soma_d=None, is_valid=None):
    """ function to find the root index to use for this mesh """
    if is_valid is not None:
        valid = np.copy(is_valid)
    else:
        # note: np.bool is deprecated/removed in recent numpy; use plain bool
        valid = np.ones(len(mesh.vertices), bool)
    assert(len(valid) == mesh.vertices.shape[0])

    root = None
    # soma mode
    if is_soma_pt is not None:
        # pick the first soma as root
        assert(len(soma_d) == mesh.vertices.shape[0])
        assert(len(is_soma_pt) == mesh.vertices.shape[0])

        is_valid_root = is_soma_pt & valid
        valid_root_inds = np.where(is_valid_root)[0]
        if len(valid_root_inds) > 0:
            min_valid_root = np.nanargmin(soma_d[valid_root_inds])
            root = valid_root_inds[min_valid_root]
            root_ds, pred = sparse.csgraph.dijkstra(mesh.csgraph,
                                                    directed=False,
                                                    indices=root,
                                                    return_predecessors=True)
        else:
            start_ind = np.where(valid)[0][0]
            root, target, pred, dm, root_ds = utils.find_far_points(
                mesh, start_ind=start_ind)
        valid[is_soma_pt] = False

    if root is None:
        # there is no soma close, so use the far point heuristic
        start_ind = np.where(valid)[0][0]
        root, target, pred, dm, root_ds = utils.find_far_points(
            mesh, start_ind=start_ind)
    valid[root] = False

    assert(np.all(~np.isinf(root_ds[valid])))

    return root, root_ds, pred, valid

def mesh_teasar(mesh, root=None, valid=None, root_ds=None, root_pred=None,
                soma_pt=None, soma_thresh=7500, invalidation_d=10000,
                return_timing=False, return_map=False):
    """core skeletonization function used to skeletonize
    a single component of a mesh"""
    # if no root was passed, then calculate one.
    # setup_root expects a per-vertex soma mask and soma distances,
    # so derive them from soma_pt/soma_thresh here
    if root is None:
        if soma_pt is not None:
            soma_d = np.linalg.norm(mesh.vertices - soma_pt[np.newaxis, :],
                                    axis=1)
            is_soma_pt = soma_d < soma_thresh
        else:
            soma_d = None
            is_soma_pt = None
        root, root_ds, root_pred, valid = setup_root(mesh,
                                                     is_soma_pt=is_soma_pt,
                                                     soma_d=soma_d)
    # if root_ds have not been precalculated, do so
    if root_ds is None:
        root_ds, root_pred = sparse.csgraph.dijkstra(mesh.csgraph,
                                                     False,
                                                     root,
                                                     return_predecessors=True)
    # if certain vertices haven't been pre-invalidated, start with just
    # the root vertex invalidated
    if valid is None:
        valid = np.ones(len(mesh.vertices), bool)
        valid[root] = False
    else:
        if len(valid) != len(mesh.vertices):
            raise Exception("valid must be length of vertices")

    if return_map is True:
        mesh_to_skeleton_dist = np.full(len(mesh.vertices), np.inf)
        mesh_to_skeleton_map = np.full(len(mesh.vertices), np.nan)

    total_to_visit = np.sum(valid)
    if np.sum(np.isinf(root_ds) & valid) != 0:
        print(np.where(np.isinf(root_ds) & valid))
        raise Exception("all valid vertices should be reachable from root")

    # vector to store each branch result
    paths = []

    # vector to store each path's total length
    path_lengths = []

    # keep track of the nodes that have been visited
    visited_nodes = [root]

    # counter to track how many branches have been counted
    k = 0

    # arrays to track timing
    start = time.time()
    time_arrays = [[], [], [], [], []]

    with tqdm(total=total_to_visit) as pbar:
        # keep looping till all vertices have been invalidated
        while np.sum(valid) > 0:
            k += 1
            t = time.time()
            # find the next target: the farthest vertex from root
            # that has not been invalidated
            target = np.nanargmax(root_ds * valid)
            if np.isinf(root_ds[target]):
                raise Exception('target cannot be reached')
            time_arrays[0].append(time.time() - t)

            t = time.time()
            # figure out the longest this branch could be
            # by following the route from target to the root
            # and finding the first already visited node (max_branch).
            # The dist(root->target) - dist(root->max_branch)
            # is the maximum distance the shortest route to a branch
            # point from the target could possibly be;
            # use this bound to reduce the dijkstra search radius for this target
            max_branch = target
            while max_branch not in visited_nodes:
                max_branch = root_pred[max_branch]
            max_path_length = root_ds[target] - root_ds[max_branch]

            # calculate the shortest path to that vertex
            # from all other vertices
            # up till the distance to the root
            ds, pred_t = sparse.csgraph.dijkstra(
                mesh.csgraph, False, target,
                limit=max_path_length,
                return_predecessors=True)

            # pick out the vertex that has already been visited
            # which has the shortest path to target
            min_node = np.argmin(ds[visited_nodes])
            # reindex to get its absolute index
            branch = visited_nodes[min_node]
            # this is the index of the point on the skeleton
            # we want this branch to connect to
            time_arrays[1].append(time.time() - t)

            t = time.time()
            # get the path from the target to the branch point
            path = utils.get_path(target, branch, pred_t)
            visited_nodes += path[0:-1]
            # record its length
            assert(~np.isinf(ds[branch]))
            path_lengths.append(ds[branch])
            # record the path
            paths.append(path)
            time_arrays[2].append(time.time() - t)

            t = time.time()
            # get the distance to all points along the new path
            # within the invalidation distance
            dm, _, sources = sparse.csgraph.dijkstra(
                mesh.csgraph, False, path,
                limit=invalidation_d,
                min_only=True,
                return_predecessors=True)
            time_arrays[3].append(time.time() - t)

            t = time.time()
            # all such non-infinite distances are within the invalidation
            # zone and should be marked invalid
            nodes_to_update = ~np.isinf(dm)
            marked = np.sum(valid & ~np.isinf(dm))
            if return_map is True:
                new_sources_closer = (
                    dm[nodes_to_update] < mesh_to_skeleton_dist[nodes_to_update])
                mesh_to_skeleton_map[nodes_to_update] = np.where(
                    new_sources_closer,
                    sources[nodes_to_update],
                    mesh_to_skeleton_map[nodes_to_update])
                mesh_to_skeleton_dist[nodes_to_update] = np.where(
                    new_sources_closer,
                    dm[nodes_to_update],
                    mesh_to_skeleton_dist[nodes_to_update])
            valid[~np.isinf(dm)] = False

            # update the progress bar with how many vertices were invalidated
            pbar.update(marked)
            time_arrays[4].append(time.time() - t)
    # record the total time
    dt = time.time() - start

    out_tuple = (paths, path_lengths)
    if return_map:
        out_tuple = out_tuple + (mesh_to_skeleton_map,)
    if return_timing:
        out_tuple = out_tuple + (time_arrays, dt)

    return out_tuple

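# Hedged sketch of running mesh_teasar on a single pre-rooted component
# (root/valid as produced by setup_root; names are illustrative):
#
#   root, root_ds, pred, valid = setup_root(mesh)
#   paths, path_lengths = mesh_teasar(mesh, root=root, root_ds=root_ds,
#                                     root_pred=pred, valid=valid,
#                                     invalidation_d=10000)
#   # each entry of paths is a list of mesh vertex indices tracing one
#   # branch from a tip back to an already-visited skeleton vertex
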
def smooth_graph(values, edges, mask=None, neighborhood=2, iterations=100, r=.1):
    """ smooths a spatial graph via iterative local averaging:
    calculates the average value of neighboring values
    and relaxes the values toward that average

    Parameters
    ----------
    values : numpy.array
        a NxK numpy array of values, for example xyz positions
    edges : numpy.array
        a Mx2 numpy array of indices into values that are edges
    mask : numpy.array
        NOT yet implemented
        optional N boolean vector of values to mask
        the vert locations. the result will return a result at every vert
        but the values that are false in this mask will be ignored
        and not factored into the smoothing.
    neighborhood : int
        an integer of how far in the graph to relax over
        as being local to any vertex (default = 2)
    iterations : int
        number of relaxation iterations (default = 100)
    r : float
        relaxation factor at each iteration
        new_vertex = (1-r)*old_vertex*mask + (r+(1-r)*(1-mask))*(local_avg)
        default = .1

    Returns
    -------
    np.array
        new_verts, a Nx3 list of new smoothed vertex positions

    """
    N = len(values)
    E = len(edges)
    # setup a sparse matrix with the edges
    sm = sparse.csc_matrix(
        (np.ones(E), (edges[:, 0], edges[:, 1])), shape=(N, N))

    # an identity matrix
    eye = sparse.csc_matrix((np.ones(N, dtype=np.float32),
                             (np.arange(0, N), np.arange(0, N))),
                            shape=(N, N))
    # for undirected graphs we want it symmetric
    sm = sm + sm.T

    # this will store our relaxation matrix
    C = sparse.csc_matrix(eye)
    # multiply the matrix and add to itself
    # to spread connectivity along the graph
    for i in range(neighborhood):
        C = C + sm @ C
    # zero out the diagonal elements
    C.setdiag(np.zeros(N))
    # don't overweight things that are connected in more than one way
    C = C.sign()
    # measure total effective neighbors per node
    neighbors = np.sum(C, axis=1)

    # normalize the weights of neighbors according to number of neighbors
    neighbors = 1/neighbors
    C = C.multiply(neighbors)
    # convert back to csc
    C = C.tocsc()

    # multiply weights by relaxation term
    C *= r

    # construct the relaxation matrix, adding on identity with complementary weight
    A = C + (1-r)*eye

    # make a copy of original vertices to not destroy the input
    new_values = np.copy(values)

    # iteratively relax the vertices
    for i in range(iterations):
        new_values = A*new_values
    return new_values

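# A self-contained toy example of smooth_graph (illustrative, not part of
# the module): a zig-zag chain graph is relaxed toward a straight line
# while keeping the same number of vertices.
def _example_smooth_graph():
    verts = np.array([[0., 0., 0.],
                      [1., 1., 0.],
                      [2., 0., 0.],
                      [3., 1., 0.],
                      [4., 0., 0.]])
    edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
    smoothed = smooth_graph(verts, edges, neighborhood=1,
                            iterations=50, r=0.1)
    # the y-values relax toward the local average, flattening the zig-zag
    return smoothed
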
def collapse_soma_skeleton(soma_pt, verts, edges, soma_d_thresh=12000,
                           mesh_to_skeleton_map=None, soma_mesh_indices=None,
                           return_filter=False, only_soma_component=True,
                           return_soma_ind=False):
    """function to adjust skeleton result to move root to soma_pt

    Parameters
    ----------
    soma_pt : numpy.array
        a 3 long vector of xyz locations of the soma
        (None to just remove duplicates)
    verts : numpy.array
        a Nx3 array of xyz vertex locations
    edges : numpy.array
        a Kx2 array of edges of the skeleton
    soma_d_thresh : float
        distance from soma_pt within which to collapse skeleton nodes
    mesh_to_skeleton_map : np.array
        a M long array of how each mesh index maps to a skeleton vertex
        (default None). The function will update this as it collapses
        vertices to root.
    soma_mesh_indices : np.array
        a K long array of indices in the mesh that should be considered soma.
        Any skeleton vertex on these vertices will be collapsed to root.
    return_filter : bool
        whether to return a list of which skeleton vertices were used in the
        end for the reduced set of skeleton vertices
    only_soma_component : bool
        whether to collapse only the skeleton connected component which is
        closest to the soma_pt (default True)
    return_soma_ind : bool
        whether to return the skeleton index that is the soma_pt

    Returns
    -------
    np.array
        verts, Px3 array of xyz skeleton vertices
    np.array
        edges, Qx2 array of skeleton edges
    (np.array)
        new_mesh_to_skeleton_map, returned if mesh_to_skeleton_map
        and soma_pt passed
    (np.array)
        used_vertices, if return_filter this contains the indices into the
        passed verts which the returned verts are using
    int
        an index into the returned verts that is the root of the skeleton
        node, only returned if return_soma_ind is True

    """
    if soma_pt is not None:
        # define the row-vector form up front; the else branch below uses it
        soma_pt_m = soma_pt[np.newaxis, :]
        if only_soma_component:
            closest_soma_ind = np.argmin(np.linalg.norm(verts - soma_pt, axis=1))
            close_inds = np.linalg.norm(verts - soma_pt, axis=1) < soma_d_thresh
            orig_graph = utils.create_csgraph(verts, edges, euclidean_weight=False)
            speye = sparse.diags(close_inds.astype(int))
            _, compids = sparse.csgraph.connected_components(orig_graph * speye)
            soma_verts = np.flatnonzero(compids[closest_soma_ind] == compids)
        else:
            dv = np.linalg.norm(verts - soma_pt_m, axis=1)
            soma_verts = np.where(dv < soma_d_thresh)[0]

        new_verts = np.vstack((verts, soma_pt_m))
        soma_i = verts.shape[0]
        edges_m = edges.copy()
        edges_m[np.isin(edges, soma_verts)] = soma_i

        simple_verts, simple_edges = trimesh_vtk.remove_unused_verts(new_verts, edges_m)
        good_edges = ~(simple_edges[:, 0] == simple_edges[:, 1])

        if mesh_to_skeleton_map is not None:
            new_mesh_to_skeleton_map = mesh_to_skeleton_map.copy()
            remap_rows = np.isin(mesh_to_skeleton_map, soma_verts)
            new_mesh_to_skeleton_map[remap_rows] = soma_i
            new_mesh_to_skeleton_map = utils.nanfilter_shapes(
                np.unique(edges_m.ravel()), new_mesh_to_skeleton_map)
            if soma_mesh_indices is not None:
                new_mesh_to_skeleton_map[soma_mesh_indices] = len(simple_verts) - 1

        output = [simple_verts, simple_edges[good_edges]]
        if mesh_to_skeleton_map is not None:
            output.append(new_mesh_to_skeleton_map)
        if return_filter:
            # remove the largest value, which is soma_i
            used_vertices = np.unique(edges_m.ravel())[:-1]
            output.append(used_vertices)
        if return_soma_ind:
            output.append(len(simple_verts) - 1)
        return output

    else:
        simple_verts, simple_edges = trimesh_vtk.remove_unused_verts(verts, edges)
        return simple_verts, simple_edges

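# Hedged toy sketch of collapse_soma_skeleton (illustrative values): five
# skeleton vertices along a line, with the vertices nearest the origin
# collapsed onto a soma point at the origin.
#
#   verts = np.arange(5)[:, np.newaxis] * np.array([[1000., 0., 0.]])
#   edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
#   soma_pt = np.array([0., 0., 0.])
#   new_verts, new_edges = collapse_soma_skeleton(soma_pt, verts, edges,
#                                                 soma_d_thresh=1500)
#   # vertices within 1500 of soma_pt merge into a single root vertex placed
#   # at soma_pt, and self-edges created by the merge are dropped
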
def ray_trace_distance(vertex_inds, mesh, max_iter=10, rand_jitter=0.001,
                       verbose=False, ray_inter=None):
    '''
    Compute the distance to the opposite side of the mesh for specified
    vertex indices on the mesh.

    Parameters
    ----------
    vertex_inds : np.array
        a K long set of indices into mesh.vertices on which to perform
        ray tracing
    mesh : :obj:`meshparty.trimesh_io.Mesh`
        mesh to perform ray tracing on
    max_iter : int
        maximum retries to attempt in order to get a proper
        sdf measure (default 10)
    rand_jitter : float
        the amplitude of random jitter on the vertex normal to add
        on each iteration (default .001)
    verbose : bool
        whether to print debug statements (default False)
    ray_inter : ray_pyembree.RayMeshIntersector
        a ray intersector object pre-initialized with a mesh, in case you are
        doing this many times and want to avoid paying initialization costs.
        (default None will initialize it for you)

    Returns
    -------
    np.array
        rs, a K long array of sdf values. rays with no result after
        max_iter attempts will contain zeros.

    '''
    if not trimesh.ray.has_embree:
        logging.warning(
            "calculating rays without pyembree, conda install pyembree for large speedup")

    if ray_inter is None:
        ray_inter = ray_pyembree.RayMeshIntersector(mesh)

    rs = np.zeros(len(vertex_inds))
    good_rs = np.full(len(rs), False)

    it = 0
    while not np.all(good_rs):
        if verbose:
            print(np.sum(~good_rs))
        blank_inds = np.where(~good_rs)[0]
        starts = (mesh.vertices - mesh.vertex_normals)[vertex_inds, :][~good_rs, :]
        vs = -mesh.vertex_normals[vertex_inds, :] \
            + (1.2**it) * rand_jitter * np.random.rand(*mesh.vertex_normals[vertex_inds, :].shape)
        vs = vs[~good_rs, :]

        rtrace = ray_inter.intersects_location(starts, vs, multiple_hits=False)

        if len(rtrace[0]) > 0:
            # rtrace[1] indexes the rays cast this round, which correspond to
            # blank_inds, so map hits back through blank_inds for radius values
            rs[blank_inds[rtrace[1]]] = np.linalg.norm(
                mesh.vertices[vertex_inds, :][blank_inds[rtrace[1]]] - rtrace[0],
                axis=1)
            good_rs[blank_inds[rtrace[1]]] = True
        it += 1
        if it > max_iter:
            break
    return rs

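# Hedged usage sketch for ray_trace_distance (the mesh variable is assumed;
# reusing the intersector, as shown, is optional but avoids re-paying its
# initialization cost across calls):
#
#   inds = np.arange(len(mesh.vertices))
#   ray_inter = ray_pyembree.RayMeshIntersector(mesh)
#   sdf = ray_trace_distance(inds, mesh, ray_inter=ray_inter)
#   # sdf[i] approximates the local thickness of the mesh at vertex i,
#   # i.e. the distance from vertex i to the opposite side along -normal
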