Source code for meshparty.skeletonize

from scipy import sparse, spatial, optimize, signal
import numpy as np
import time
from meshparty import utils

try:
    from pykdtree.kdtree import KDTree
except ImportError:
    KDTree = spatial.cKDTree
from tqdm import tqdm
from meshparty.skeleton import Skeleton

import fastremap
import logging
from . import skeleton_utils


def skeletonize_mesh(
    mesh,
    soma_pt=None,
    soma_radius=7500,
    collapse_soma=True,
    collapse_function="sphere",
    invalidation_d=12000,
    smooth_vertices=False,
    compute_radius=True,
    shape_function="single",
    compute_original_index=True,
    verbose=True,
    smooth_iterations=12,
    smooth_neighborhood=2,
    smooth_r=0.1,
    cc_vertex_thresh=100,
    root_index=None,
    remove_zero_length_edges=True,
    collapse_params={},
    meta={},
):
    """
    Build skeleton object from mesh skeletonization

    Parameters
    ----------
    mesh: meshparty.trimesh_io.Mesh
        the mesh to skeletonize, defaults assume vertices in nm
    soma_pt: np.array
        a length 3 array specifying the soma location to make the root,
        in units of mesh vertices. default=None, in which case a heuristic
        root will be chosen.
    soma_radius: float
        distance in mesh vertex units over which to consider mesh vertices
        close to soma_pt to belong to the soma. These vertices will
        automatically be invalidated and no skeleton branches will attempt
        to reach them. This distance will also be used to collapse all
        skeleton points within this distance to the soma_pt root if
        collapse_soma is true. (default=7500 (nm))
    collapse_soma: bool
        whether to collapse the skeleton around the soma point (default True)
    collapse_function: 'sphere' or 'branch'
        Determines which soma collapse function to use. Sphere uses the soma_radius
        and collapses all vertices within that radius to the soma. Branch is an
        experimental approach that tries to compute the right boundary for each
        branch into the soma.
    invalidation_d: float
        the distance along the mesh to invalidate when applying the TEASAR-like
        algorithm. Controls how detailed a structure the skeleton algorithm
        reaches. default (12000 (nm))
    smooth_vertices: bool
        whether to smooth the vertices of the skeleton
    compute_radius: bool
        whether to calculate the radius of the skeleton at each point on the
        skeleton (default True)
    shape_function: 'single' or 'cone'
        Selects how to compute the radius, either with a single ray or a cone
        of rays. Default is 'single'.
    compute_original_index: bool
        whether to calculate how each of the mesh nodes maps onto the skeleton
        (default True)
    smooth_iterations: int, optional
        Number of iterations to smooth (default is 12)
    smooth_neighborhood: int, optional
        Size of neighborhood to look at for smoothing
    smooth_r: float, optional
        Weight of update step in smoothing algorithm, default is 0.1
    root_index: int or None, optional
        A precise mesh vertex to use as the skeleton root. If provided, the
        vertex location overrides soma_pt. By default, None.
    remove_zero_length_edges: bool
        If True, removes vertices involved in zero length edges, which can
        disrupt graph computations. Default True.
    collapse_params: dict
        Extra keyword arguments for the collapse function. See soma_via_sphere
        and soma_via_branch_starts for specifics.
    cc_vertex_thresh : int, optional
        Smallest number of vertices in a connected component to skeletonize.
    verbose: bool
        whether to print verbose logging
    meta: dict
        Skeletonization metadata to add to the skeleton. See
        skeleton.SkeletonMetadata for keys.

    Returns
    -------
    :obj:`meshparty.skeleton.Skeleton`
        a Skeleton object for this mesh
    """
    (
        skel_verts,
        skel_edges,
        orig_skel_index,
        skel_map,
    ) = calculate_skeleton_paths_on_mesh(
        mesh,
        invalidation_d=invalidation_d,
        cc_vertex_thresh=cc_vertex_thresh,
        root_index=root_index,
        return_map=True,
    )

    if smooth_vertices is True:
        smooth_verts = smooth_graph(
            skel_verts,
            skel_edges,
            neighborhood=smooth_neighborhood,
            iterations=smooth_iterations,
            r=smooth_r,
        )
        skel_verts = smooth_verts

    if root_index is not None and soma_pt is None:
        soma_pt = mesh.vertices[root_index]

    if soma_pt is not None:
        soma_pt = np.array(soma_pt).reshape(1, 3)

    rs = None

    if collapse_soma is True and soma_pt is not None:
        temp_sk = Skeleton(
            skel_verts,
            skel_edges,
            mesh_index=mesh.map_indices_to_unmasked(orig_skel_index),
            mesh_to_skel_map=skel_map,
        )
        _, close_ind = temp_sk.kdtree.query(soma_pt)
        temp_sk.reroot(close_ind[0])

        if collapse_function == "sphere":
            soma_verts, soma_r = soma_via_sphere(
                soma_pt, temp_sk.vertices, temp_sk.edges, soma_radius
            )
        elif collapse_function == "branch":
            try:
                from .ray_tracing import ray_trace_distance, shape_diameter_function
            except ImportError:
                raise ImportError("Could not import pyembree for ray tracing")
            if shape_function == "single":
                rs = ray_trace_distance(
                    mesh.filter_unmasked_indices_padded(temp_sk.mesh_index), mesh
                )
            elif shape_function == "cone":
                rs = shape_diameter_function(
                    mesh.filter_unmasked_indices_padded(temp_sk.mesh_index),
                    mesh,
                    num_points=30,
                    cone_angle=np.pi / 3,
                )

            soma_verts, soma_r = soma_via_branch_starts(
                temp_sk,
                mesh,
                soma_pt,
                rs,
                search_radius=collapse_params.get("search_radius", 25000),
                fallback_radius=collapse_params.get("fallback_radius", soma_radius),
                cutoff_threshold=collapse_params.get("cutoff_threshold", 0.4),
                min_cutoff=collapse_params.get("min_cutoff", 0.1),
                dynamic_range=collapse_params.get("dynamic_range", 1),
                dynamic_threshold=collapse_params.get("dynamic_threshold", False),
            )
        if root_index is not None:
            collapse_index = np.flatnonzero(orig_skel_index == root_index)[0]
        else:
            collapse_index = None
        new_v, new_e, new_skel_map, vert_filter, root_ind = collapse_soma_skeleton(
            soma_verts,
            soma_pt,
            temp_sk.vertices,
            temp_sk.edges,
            mesh_to_skeleton_map=temp_sk.mesh_to_skel_map,
            collapse_index=collapse_index,
            return_filter=True,
            return_soma_ind=True,
        )
    else:
        new_v, new_e, new_skel_map = skel_verts, skel_edges, skel_map
        vert_filter = np.arange(len(orig_skel_index))
        if root_index is not None:
            root_ind = np.flatnonzero(orig_skel_index == root_index)[0]
        elif soma_pt is None:
            sk_graph = utils.create_csgraph(new_v, new_e)
            root_ind = utils.find_far_points_graph(sk_graph)[0]
        else:
            # Still try to root close to the soma
            _, qry_inds = spatial.cKDTree(new_v, balanced_tree=False).query(
                soma_pt[np.newaxis, :]
            )
            root_ind = qry_inds[0]

    skel_map_full_mesh = np.full(mesh.node_mask.shape, -1, dtype=np.int64)
    skel_map_full_mesh[mesh.node_mask] = new_skel_map
    ind_to_fix = mesh.map_boolean_to_unmasked(np.isnan(new_skel_map))
    skel_map_full_mesh[ind_to_fix] = -1

    props = {}

    if compute_original_index is True:
        if collapse_soma is True and soma_pt is not None:
            mesh_index = temp_sk.mesh_index[vert_filter]
            if root_index is None:
                mesh_index = np.append(mesh_index, -1)
        else:
            mesh_index = orig_skel_index[vert_filter]
        props["mesh_index"] = mesh_index

    if compute_radius is True:
        if rs is None:
            try:
                from .ray_tracing import ray_trace_distance, shape_diameter_function
            except ImportError:
                raise ImportError("Could not import pyembree for ray tracing")
            if shape_function == "single":
                rs = ray_trace_distance(orig_skel_index[vert_filter], mesh)
            elif shape_function == "cone":
                rs = shape_diameter_function(orig_skel_index[vert_filter], mesh)
        else:
            rs = rs[vert_filter]
        if collapse_soma is True and soma_pt is not None:
            if root_index is None:
                rs = np.append(rs, soma_r)
            else:
                rs[root_ind] = soma_r
        props["rs"] = rs

    sk_params = {
        "soma_pt_x": soma_pt[0, 0] if soma_pt is not None else None,
        "soma_pt_y": soma_pt[0, 1] if soma_pt is not None else None,
        "soma_pt_z": soma_pt[0, 2] if soma_pt is not None else None,
        "soma_radius": soma_radius,
        "collapse_soma": collapse_soma,
        "collapse_function": collapse_function,
        "invalidation_d": invalidation_d,
        "smooth_vertices": smooth_vertices,
        "compute_radius": compute_radius,
        "shape_function": shape_function,
        "smooth_iterations": smooth_iterations,
        "smooth_neighborhood": smooth_neighborhood,
        "smooth_r": smooth_r,
        "cc_vertex_thresh": cc_vertex_thresh,
        "remove_zero_length_edges": remove_zero_length_edges,
        "collapse_params": collapse_params,
        "timestamp": time.time(),
    }
    sk_params.update(meta)

    sk = Skeleton(
        new_v,
        new_e,
        mesh_to_skel_map=skel_map_full_mesh,
        mesh_index=props.get("mesh_index", None),
        radius=props.get("rs", None),
        root=root_ind,
        remove_zero_length_edges=remove_zero_length_edges,
        meta=sk_params,
    )

    if compute_radius is True:
        _remove_nan_radius(sk)

    return sk
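
# Usage sketch (illustrative, not part of the original module): assuming
# `mesh` is a meshparty.trimesh_io.Mesh with vertices in nm, and the soma
# coordinate below is a hypothetical location on that mesh:
#
#     sk = skeletonize_mesh(
#         mesh,
#         soma_pt=np.array([105000.0, 98000.0, 21000.0]),  # hypothetical xyz (nm)
#         soma_radius=7500,
#         invalidation_d=12000,
#         compute_radius=True,
#     )
#     print(sk.n_vertices, sk.root)
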
def _remove_nan_radius(sk, set_unfixed_to_lowest=True):
    """Iteratively replace NaN radii with the mean radius of graph neighbors;
    any NaNs that cannot be fixed this way fall back to the minimum radius."""
    last_numnans = np.inf
    nanlocs = np.flatnonzero(np.isnan(sk.radius))
    numnans = len(nanlocs)
    # keep sweeping while progress is being made
    while numnans > 0 and last_numnans > numnans:
        for nanloc in nanlocs:
            # binary adjacency row of this vertex in the undirected skeleton graph
            sparse_row = sk.csgraph_binary_undirected[nanloc].toarray().ravel()
            prod = sparse_row * sk.radius
            with np.errstate(divide="ignore", invalid="ignore"):
                # mean radius over neighbors with finite radii
                new_rad = np.nansum(prod) / np.nansum(prod > 0)
            sk._rooted.radius[nanloc] = new_rad
        last_numnans = numnans
        nanlocs = np.flatnonzero(np.isnan(sk.radius))
        numnans = len(nanlocs)
    if numnans > 0 and set_unfixed_to_lowest:
        sk._rooted.radius[nanlocs] = np.nanmin(sk.radius)
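
# The loop above fills each NaN radius with the mean radius of its graph
# neighbors. A minimal standalone sketch of that update on toy values (not
# from the original source):
#
#     row = np.array([0, 1, 0, 1])             # binary adjacency row of one vertex
#     radius = np.array([5.0, 2.0, np.nan, 4.0])
#     prod = row * radius                      # neighbor radii; zeros/NaNs elsewhere
#     new_rad = np.nansum(prod) / np.nansum(prod > 0)  # mean of finite neighbors -> 3.0
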
def calculate_skeleton_paths_on_mesh(
    mesh,
    soma_pt=None,
    soma_thresh=7500,
    invalidation_d=10000,
    smooth_neighborhood=2,
    smooth_iterations=12,
    cc_vertex_thresh=100,
    large_skel_path_threshold=5000,
    return_map=False,
    root_index=None,
):
    """function to turn a trimesh object of a neuron into a skeleton, without
    running soma collapse or recasting the result into a Skeleton.
    Used by :func:`meshparty.skeletonize.skeletonize_mesh` and
    makes use of :func:`meshparty.skeletonize.skeletonize_components`

    Parameters
    ----------
    mesh: meshparty.trimesh_io.Mesh
        the mesh to skeletonize, defaults assume vertices in nm
    soma_pt: np.array
        a length 3 array specifying the soma location to make the root,
        in units of mesh vertices. default=None, in which case a heuristic
        root will be chosen.
    soma_thresh: float
        distance in mesh vertex units over which to consider mesh vertices
        close to soma_pt to belong to the soma. These vertices will
        automatically be invalidated and no skeleton branches will attempt
        to reach them. This distance will also be used to collapse all
        skeleton points within this distance to the soma_pt root if
        collapse_soma is true. (default=7500 (nm))
    invalidation_d: float
        the distance along the mesh to invalidate when applying the TEASAR-like
        algorithm. Controls how detailed a structure the skeleton algorithm
        reaches. default (10000 (nm))
    smooth_neighborhood: int
        the neighborhood in edge hops over which to smooth skeleton locations.
        This controls the smoothing of the skeleton (default 2)
    large_skel_path_threshold: int
        the threshold in terms of skeleton vertices for which skeletons will be
        nominated for tip merging. Smaller skeleton fragments will not be
        merged at their tips (default 5000)
    cc_vertex_thresh: int
        the threshold in terms of vertex numbers for which connected components
        of the mesh will be considered for skeletonization. Mesh connected
        components with fewer than this number of vertices will be ignored
        by the skeletonization algorithm. (default 100)
    return_map: bool
        whether to return a map of how each mesh vertex maps onto each skeleton
        vertex based upon how it was invalidated.
    root_index: int or None
        Mesh vertex to set as initial root node. Overrides soma_pt if provided.
        Default is None.

    Returns
    -------
    skel_verts: np.array
        a Nx3 matrix of skeleton vertex positions
    skel_edges: np.array
        a Kx2 matrix of skeleton edge indices into skel_verts
    skel_verts_orig: np.array
        a N long index of skeleton vertices in the original mesh vertex index
    (mesh_to_skeleton_map): np.array
        a M long map of mesh vertex indices to skeleton vertex indices,
        returned only if return_map is True
    """
    skeletonize_output = skeletonize_components(
        mesh,
        soma_pt=soma_pt,
        soma_thresh=soma_thresh,
        invalidation_d=invalidation_d,
        cc_vertex_thresh=cc_vertex_thresh,
        return_map=return_map,
        root_index=root_index,
    )
    if return_map is True:
        all_paths, roots, tot_path_lengths, mesh_to_skeleton_map = skeletonize_output
    else:
        all_paths, roots, tot_path_lengths = skeletonize_output

    all_edges = []
    for comp_paths in all_paths:
        all_edges.append(utils.paths_to_edges(comp_paths))
    if len(all_edges) > 0:
        tot_edges = np.vstack(all_edges)
    else:
        tot_edges = np.zeros((0, 2), dtype=np.int64)

    skel_verts, skel_edges, skel_verts_orig = reduce_verts(mesh.vertices, tot_edges)

    if return_map:
        mesh_to_skeleton_map = utils.nanfilter_shapes(
            np.unique(tot_edges.ravel()), mesh_to_skeleton_map
        )
        mesh_to_skeleton_map[np.isnan(mesh_to_skeleton_map)] = -1
    else:
        mesh_to_skeleton_map = None

    output_tuple = (skel_verts, skel_edges, skel_verts_orig)

    if return_map:
        output_tuple = output_tuple + (mesh_to_skeleton_map.astype(int),)

    return output_tuple
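
# Usage sketch (illustrative): get the raw skeleton graph without soma
# collapse or Skeleton construction, assuming `mesh` is a
# meshparty.trimesh_io.Mesh:
#
#     verts, edges, orig_inds, m2s = calculate_skeleton_paths_on_mesh(
#         mesh, invalidation_d=10000, return_map=True
#     )
#     # orig_inds gives each skeleton vertex's index into mesh.vertices;
#     # m2s maps every mesh vertex to a skeleton vertex (-1 if unmapped).
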
def reduce_verts(verts, faces):
    """removes unused vertices from a graph or mesh

    Parameters
    ----------
    verts : numpy.array
        NxD numpy array of vertex locations
    faces : numpy.array
        MxK numpy array of connected shapes (i.e. edges or tris)
        (entries are indices into verts)

    Returns
    -------
    np.array
        new_verts, a filtered set of vertices
    np.array
        new_face, a reindexed set of faces (or edges)
    np.array
        used_verts, the index of the new_verts in the old verts
    """
    used_verts = np.unique(faces.ravel())
    new_verts = verts[used_verts, :]
    new_face = np.zeros(faces.shape, dtype=faces.dtype)
    for i in range(faces.shape[1]):
        new_face[:, i] = np.searchsorted(used_verts, faces[:, i])
    return new_verts, new_face, used_verts
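
# Worked example (toy data): vertices 0 and 2 are unused by the edges, so
# they are dropped and the edges are reindexed against the survivors:
#
#     verts = np.arange(15.0).reshape(5, 3)
#     edges = np.array([[1, 3], [3, 4]])
#     new_verts, new_edges, used = reduce_verts(verts, edges)
#     # used      -> array([1, 3, 4])
#     # new_edges -> array([[0, 1], [1, 2]])
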
def skeletonize_components(
    mesh,
    soma_pt=None,
    soma_thresh=10000,
    invalidation_d=10000,
    cc_vertex_thresh=100,
    return_map=False,
    root_index=None,
):
    """core skeletonization routine, used by
    :func:`meshparty.skeletonize.calculate_skeleton_paths_on_mesh`
    to calculate the skeleton on all components of the mesh,
    with no post processing"""
    # find all the connected components in the mesh
    n_components, labels = sparse.csgraph.connected_components(
        mesh.csgraph, directed=False, return_labels=True
    )
    comp_labels, comp_counts = np.unique(labels, return_counts=True)

    if return_map:
        mesh_to_skeleton_map = np.full(len(mesh.vertices), np.nan)

    # variables to collect the paths, roots and path lengths
    all_paths = []
    roots = []
    tot_path_lengths = []

    if root_index is not None:
        soma_d = np.linalg.norm(mesh.vertices - mesh.vertices[root_index], axis=1)
        is_soma_pt = np.arange(len(mesh.vertices)) == root_index
    elif soma_pt is not None:
        soma_d = mesh.vertices - soma_pt.reshape(1, 3)
        soma_d = np.linalg.norm(soma_d, axis=1)
        is_soma_pt = soma_d < soma_thresh
    else:
        is_soma_pt = None
        soma_d = None

    # loop over the components
    for k in range(n_components):
        if comp_counts[k] > cc_vertex_thresh:
            # find the root using a soma position if you have it;
            # it will fall back to a heuristic if the soma
            # is too far away for this component
            root, root_ds, pred, valid = setup_root(
                mesh, is_soma_pt, soma_d, labels == k
            )

            # run teasar on this component
            teasar_output = mesh_teasar(
                mesh,
                root=root,
                root_ds=root_ds,
                root_pred=pred,
                valid=valid,
                invalidation_d=invalidation_d,
                return_map=return_map,
            )
            if return_map is False:
                paths, path_lengths = teasar_output
            else:
                paths, path_lengths, mesh_to_skeleton_map_single = teasar_output
                mesh_to_skeleton_map[
                    ~np.isnan(mesh_to_skeleton_map_single)
                ] = mesh_to_skeleton_map_single[~np.isnan(mesh_to_skeleton_map_single)]

            if len(path_lengths) > 0:
                # collect the results in lists
                tot_path_lengths.append(path_lengths)
                all_paths.append(paths)
                roots.append(root)

    if return_map:
        return all_paths, roots, tot_path_lengths, mesh_to_skeleton_map
    else:
        return all_paths, roots, tot_path_lengths

def setup_root(mesh, is_soma_pt=None, soma_d=None, is_valid=None):
    """function to find the root index to use for this mesh"""
    if is_valid is not None:
        valid = np.copy(is_valid)
    else:
        valid = np.ones(len(mesh.vertices), bool)
    assert len(valid) == mesh.vertices.shape[0]

    root = None
    # soma mode
    if is_soma_pt is not None:
        # pick the closest valid soma vertex as root
        assert len(soma_d) == mesh.vertices.shape[0]
        assert len(is_soma_pt) == mesh.vertices.shape[0]
        is_valid_root = is_soma_pt & valid
        valid_root_inds = np.where(is_valid_root)[0]
        if len(valid_root_inds) > 0:
            min_valid_root = np.nanargmin(soma_d[valid_root_inds])
            root = valid_root_inds[min_valid_root]
            root_ds, pred = sparse.csgraph.dijkstra(
                mesh.csgraph, directed=False, indices=root, return_predecessors=True
            )
        else:
            start_ind = np.where(valid)[0][0]
            root, target, pred, dm, root_ds = utils.find_far_points(
                mesh, start_ind=start_ind
            )
        valid[is_soma_pt] = False
    if root is None:
        # there is no soma close, so use the far point heuristic
        start_ind = np.where(valid)[0][0]
        root, target, pred, dm, root_ds = utils.find_far_points(
            mesh, start_ind=start_ind
        )
    valid[root] = False

    assert np.all(~np.isinf(root_ds[valid]))
    return root, root_ds, pred, valid
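
# Usage sketch (illustrative): with no soma information, setup_root falls
# back to the far-point heuristic over the valid vertices:
#
#     root, root_ds, pred, valid = setup_root(mesh)
#     # root_ds holds graph distances from the chosen root to every vertex;
#     # pred holds the dijkstra predecessors reused by mesh_teasar.
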
def mesh_teasar(
    mesh,
    root=None,
    valid=None,
    root_ds=None,
    root_pred=None,
    soma_pt=None,
    soma_thresh=7500,
    invalidation_d=10000,
    return_timing=False,
    return_map=False,
    exclude_edges_sigma=None,
):
    """core skeletonization function used to skeletonize a single component
    of a mesh"""
    # if no root was passed, then calculate one, using the soma point
    # (if available) to build the masks that setup_root expects
    if root is None:
        if soma_pt is not None:
            soma_d = np.linalg.norm(mesh.vertices - soma_pt.reshape(1, 3), axis=1)
            is_soma_pt = soma_d < soma_thresh
        else:
            soma_d = None
            is_soma_pt = None
        root, root_ds, root_pred, valid = setup_root(mesh, is_soma_pt, soma_d)
    # if root_ds has not been precalculated, do so
    if root_ds is None:
        root_ds, root_pred = sparse.csgraph.dijkstra(
            mesh.csgraph, False, root, return_predecessors=True
        )
    # if certain vertices haven't been pre-invalidated start with just
    # the root vertex invalidated
    if valid is None:
        valid = np.ones(len(mesh.vertices), bool)
        valid[root] = False
    else:
        if len(valid) != len(mesh.vertices):
            raise Exception("valid must be length of vertices")

    if return_map is True:
        mesh_to_skeleton_dist = np.full(len(mesh.vertices), np.inf)
        mesh_to_skeleton_map = np.full(len(mesh.vertices), np.nan)

    total_to_visit = np.sum(valid)
    if np.sum(np.isinf(root_ds) & valid) != 0:
        print(np.where(np.isinf(root_ds) & valid))
        raise Exception("all valid vertices should be reachable from root")

    # vector to store each branch result
    paths = []

    # vector to store each path's total length
    path_lengths = []

    # keep track of the nodes that have been visited
    visited_nodes = [root]

    # counter to track how many branches have been counted
    k = 0

    # arrays to track timing
    start = time.time()
    time_arrays = [[], [], [], [], []]

    with tqdm(total=total_to_visit, disable=True) as pbar:
        # keep looping until all vertices have been invalidated
        while np.sum(valid) > 0:
            k += 1
            t = time.time()
            # find the next target, the farthest vertex from root
            # that has not been invalidated
            target = np.nanargmax(root_ds * valid)
            if np.isinf(root_ds[target]):
                raise Exception("target cannot be reached")
            time_arrays[0].append(time.time() - t)

            t = time.time()
            # figure out the longest this branch could be
            # by following the route from target to the root
            # and finding the first already visited node (max_branch).
            # The dist(root->target) - dist(root->max_branch)
            # is the maximum distance the shortest route to a branch
            # point from the target could possibly be;
            # use this bound to reduce the dijkstra search radius for this target
            max_branch = target
            while max_branch not in visited_nodes:
                max_branch = root_pred[max_branch]
            max_path_length = root_ds[target] - root_ds[max_branch]

            # calculate the shortest path to that vertex
            # from all other vertices
            # up until the distance to the root
            ds, pred_t = sparse.csgraph.dijkstra(
                mesh.csgraph,
                False,
                target,
                limit=max_path_length,
                return_predecessors=True,
            )

            # pick out the vertex that has already been visited
            # which has the shortest path to target
            min_node = np.argmin(ds[visited_nodes])
            # reindex to get its absolute index
            branch = visited_nodes[min_node]
            # this is the index of the point on the skeleton
            # we want this branch to connect to
            time_arrays[1].append(time.time() - t)

            t = time.time()
            # get the path from the target to the branch point
            path = utils.get_path(target, branch, pred_t)
            visited_nodes += path[0:-1]
            # record its length
            assert ~np.isinf(ds[branch])
            path_lengths.append(ds[branch])
            # record the path
            paths.append(path)
            time_arrays[2].append(time.time() - t)

            t = time.time()
            # get the distance to all points along the new path
            # within the invalidation distance
            dm, _, sources = sparse.csgraph.dijkstra(
                mesh.csgraph,
                False,
                path,
                limit=invalidation_d,
                min_only=True,
                return_predecessors=True,
            )
            time_arrays[3].append(time.time() - t)

            t = time.time()
            # all such non-infinite distances are within the invalidation
            # zone and should be marked invalid
            nodes_to_update = ~np.isinf(dm)
            marked = np.sum(valid & ~np.isinf(dm))
            if return_map is True:
                new_sources_closer = (
                    dm[nodes_to_update] < mesh_to_skeleton_dist[nodes_to_update]
                )
                mesh_to_skeleton_map[nodes_to_update] = np.where(
                    new_sources_closer,
                    sources[nodes_to_update],
                    mesh_to_skeleton_map[nodes_to_update],
                )
                mesh_to_skeleton_dist[nodes_to_update] = np.where(
                    new_sources_closer,
                    dm[nodes_to_update],
                    mesh_to_skeleton_dist[nodes_to_update],
                )

            valid[~np.isinf(dm)] = False

            # update the progress bar with how many vertices were just invalidated
            pbar.update(marked)
            time_arrays[4].append(time.time() - t)

    # record the total time
    dt = time.time() - start

    out_tuple = (paths, path_lengths)
    if return_map:
        out_tuple = out_tuple + (mesh_to_skeleton_map,)
    if return_timing:
        out_tuple = out_tuple + (time_arrays, dt)

    return out_tuple
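
# Usage sketch (illustrative): run TEASAR directly on a single connected
# mesh component, assuming `mesh` is a connected meshparty.trimesh_io.Mesh:
#
#     paths, path_lengths, m2s = mesh_teasar(
#         mesh, invalidation_d=10000, return_map=True
#     )
#     # each entry of `paths` is a list of mesh vertex indices running from
#     # a branch tip back to its attachment point on the growing skeleton.
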
def smooth_graph(values, edges, mask=None, neighborhood=2, iterations=100, r=0.1):
    """smooths a spatial graph via iterative local averaging
    calculates the average value of neighboring values
    and relaxes the values toward that average

    Parameters
    ----------
    values : numpy.array
        a NxK numpy array of values, for example xyz positions
    edges : numpy.array
        a Mx2 numpy array of indices into values that are edges
    mask : numpy.array
        NOT yet implemented
        optional N boolean vector of values to mask
        the vert locations. The result will return a result at every vert
        but the values that are false in this mask will be ignored and not
        factored into the smoothing.
    neighborhood : int
        an integer of how far in the graph to relax over
        as being local to any vertex (default = 2)
    iterations : int
        number of relaxation iterations (default = 100)
    r : float
        relaxation factor at each iteration
        new_vertex = (1-r)*old_vertex*mask + (r+(1-r)*(1-mask))*(local_avg)
        default = .1

    Returns
    -------
    np.array
        new_verts, a Nx3 list of new smoothed vertex positions
    """
    N = len(values)
    E = len(edges)

    # setup a sparse matrix with the edges
    sm = sparse.csc_matrix((np.ones(E), (edges[:, 0], edges[:, 1])), shape=(N, N))

    # an identity matrix
    eye = sparse.csc_matrix(
        (np.ones(N, dtype=np.float32), (np.arange(0, N), np.arange(0, N))),
        shape=(N, N),
    )
    # for undirected graphs we want it symmetric
    sm = sm + sm.T

    # this will store our relaxation matrix
    C = sparse.csc_matrix(eye)
    # multiply the matrix and add to itself
    # to spread connectivity along the graph
    for i in range(neighborhood):
        C = C + sm @ C
    # zero out the diagonal elements
    C.setdiag(np.zeros(N))
    # don't overweight things that are connected in more than one way
    C = C.sign()
    # measure total effective neighbors per node
    neighbors = np.sum(C, axis=1)

    # normalize the weights of neighbors according to number of neighbors
    neighbors = 1 / neighbors
    C = C.multiply(neighbors)
    # convert back to csc
    C = C.tocsc()

    # multiply weights by relaxation term
    C *= r

    # construct relaxation matrix, adding on identity with complementary weight
    A = C + (1 - r) * eye

    # make a copy of the original vertices to not destroy the input
    new_values = np.copy(values)

    # iteratively relax the vertices
    for i in range(iterations):
        new_values = A * new_values
    return new_values
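
# Worked example (toy data): relax a zigzag chain of five 2D points toward
# a straight line; the edges form the path graph 0-1-2-3-4:
#
#     pts = np.array([[0, 0], [1, 1], [2, 0], [3, 1], [4, 0]], dtype=float)
#     chain = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
#     smoothed = smooth_graph(pts, chain, neighborhood=1, iterations=50, r=0.1)
#     # interior y-values contract toward their local averages each iteration.
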
def soma_via_sphere(soma_pt, verts, edges, soma_d_thresh):
    """Get indices within soma_d_thresh of a soma_pt.
    Exclude vertices that left and come back."""
    closest_soma_ind = np.argmin(np.linalg.norm(verts - soma_pt, axis=1))
    close_inds = np.linalg.norm(verts - soma_pt, axis=1) < soma_d_thresh
    orig_graph = utils.create_csgraph(verts, edges, euclidean_weight=False)
    speye = sparse.diags(close_inds.astype(int))
    _, compids = sparse.csgraph.connected_components(orig_graph * speye)
    return np.flatnonzero(compids[closest_soma_ind] == compids), soma_d_thresh
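
# Usage sketch (illustrative): given skeleton verts/edges in nm and a
# hypothetical soma point, collect the vertex indices to collapse:
#
#     soma_verts, soma_r = soma_via_sphere(
#         np.array([[105000.0, 98000.0, 21000.0]]), sk_verts, sk_edges, 7500
#     )
#     # soma_verts are vertices within 7500 nm that stay graph-connected to the
#     # vertex nearest the soma point; soma_r simply echoes the threshold.
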
def soma_via_branch_starts(
    sk,
    mesh,
    soma_pt,
    rs,
    search_radius=20000,
    fallback_radius=15000,
    cutoff_threshold=0.2,
    min_cutoff=0.1,
    dynamic_range=1,
    dynamic_threshold=False,
):
    """Runs down paths into the soma region and finds the onset of each branch."""
    is_close = np.linalg.norm(sk.vertices - soma_pt, axis=1) < search_radius
    is_close_fallback = np.linalg.norm(sk.vertices - soma_pt, axis=1) < fallback_radius

    # Find segments that emerge from the soma region
    close_segs = []
    for seg in sk.segments:
        seg = seg[np.argsort(sk.distance_to_root[seg])]
        if is_close[seg[0]]:
            if np.all(is_close[sk.path_to_root(seg[0])]):
                close_segs.append(seg)
    close_seg_inds = np.concatenate(close_segs)
    close_inds = close_seg_inds[is_close[close_seg_inds]]

    # From those segments, find the tips that come out of the search_radius
    is_close_specific = np.full(sk.n_vertices, False)
    is_close_specific[close_inds] = True
    close_parent_edges = sk.edges[is_close_specific[sk.edges[:, 1]]]
    tip_inds = close_parent_edges[~is_close_specific[close_parent_edges[:, 0]], 0]

    rs = rs[close_inds]
    rs_long = np.full(sk.n_vertices, np.nan)
    rs_long[close_inds] = rs
    sk.reroot(close_inds[np.argmin(np.abs(rs - np.percentile(rs, 98)))])

    # Fit a logistic curve
    def log_func(x, h, k, xh, a):
        return h / (1 + np.exp(-k * (xh - x))) + a

    all_params = []
    soma_votes = []
    for tip_ind in tip_inds:
        ptr = sk.path_to_root(tip_ind)
        path_inds = ptr[1:]

        xdata = sk.distance_to_root[path_inds] / 1000
        ydata = rs_long[path_inds] / 1000

        good_rows = np.invert(np.logical_or(np.isnan(ydata), np.isinf(ydata)))
        ydata = ydata[good_rows]
        xdata = xdata[good_rows]
        ydata_filt = np.maximum.accumulate(signal.medfilt(ydata, 21))
        try:
            # sig = ydata_filt * log_func( (np.max(xdata)-xdata), 1, 2, 3, 1)
            sig = np.where(ydata_filt < 2, 2, ydata_filt) * log_func(
                (np.max(xdata) - xdata), 2, 1, 5, 1
            )
            params, _ = optimize.curve_fit(
                log_func,
                xdata,
                ydata_filt,
                sigma=sig,
                bounds=([0, 0.5, 0, 0], [np.inf, 5, np.inf, np.inf]),
                p0=(10, 1, 10, 1),
                method="trf",
            )
            all_params.append(params)

            if dynamic_threshold:
                cutoff_threshold_eff = np.min(
                    [
                        cutoff_threshold,
                        min_cutoff
                        + (cutoff_threshold - min_cutoff)
                        * np.max([0, (params[1] - 0.5) / dynamic_range]),
                    ]
                )
            else:
                cutoff_threshold_eff = cutoff_threshold

            def f(x):
                return log_func(x, *params) - (
                    params[3] + cutoff_threshold_eff * (params[0] - params[3])
                )

            opt_sol = optimize.root_scalar(f, bracket=[0, xdata.max()])
            if opt_sol.converged:
                root = opt_sol.root
                use_fallback = False
            else:
                use_fallback = True
        except Exception:
            use_fallback = True
            all_params.append(None)

        if not use_fallback:
            base_path_ind = np.argmin(np.abs(xdata - root))
        else:
            d_to_soma_pt = np.linalg.norm(sk.vertices[path_inds] - soma_pt, axis=1)
            base_path_ind = np.argmin(np.abs(d_to_soma_pt - fallback_radius))

        soma_vote = np.full(sk.n_vertices, np.nan)
        soma_vote[path_inds[base_path_ind:]] = 1
        soma_vote[path_inds[:base_path_ind]] = 0
        soma_votes.append(soma_vote)

    # Any tips whose path to root is entirely in the close zone also vote as 'somatic'
    for ep in sk.end_points[is_close[sk.end_points]]:
        ptr = sk.path_to_root(ep)
        if np.all(is_close[ptr]):
            soma_vote = np.full(sk.n_vertices, np.nan)
            soma_vote[ptr] = 1
            soma_vote[ptr[is_close_fallback[ptr]]] = np.inf
            soma_votes.append(soma_vote)

    # Get soma region
    soma_votes = np.vstack(soma_votes)
    num_votes = len(soma_votes) - np.sum(np.isnan(soma_votes), axis=0)
    num_yes = np.nansum(soma_votes, axis=0)
    with np.errstate(all="ignore"):
        is_soma = (num_yes / num_votes) > 0.5

    last_nonsoma = []
    for tip_ind in tip_inds:
        ptr = sk.path_to_root(tip_ind)
        last_nonsoma.append(ptr[np.flatnonzero(np.diff(is_soma[ptr]) == 1)[0]])
    last_nonsoma = np.unique(last_nonsoma)

    keep_binds = []
    for bind in last_nonsoma:
        bind_ptr = sk.path_to_root(bind)
        if np.any(np.isin(last_nonsoma, bind_ptr[1:])):
            keep_binds.append(False)
        else:
            keep_binds.append(True)
    keep_binds = np.array(keep_binds)

    g = sk.cut_graph(last_nonsoma[keep_binds])
    _, comps = sparse.csgraph.connected_components(g)
    root_comp = comps[sk.root]
    return (
        np.flatnonzero(comps == root_comp),
        np.nanmedian(rs_long[comps == root_comp]),
    )
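
# Illustrative sketch of the branch-onset idea above (synthetic data, not
# from the original source): radii near the soma follow a roughly logistic
# profile along a path, and the branch onset is taken where the fit crosses
# a fixed fraction of its plateau height above baseline:
#
#     xdata = np.linspace(0, 30, 200)                       # distance to root (um)
#     ydata = 8 / (1 + np.exp(-1.0 * (10 - xdata))) + 1.0   # synthetic radii (um)
#     def log_func(x, h, k, xh, a):
#         return h / (1 + np.exp(-k * (xh - x))) + a
#     params, _ = optimize.curve_fit(log_func, xdata, ydata, p0=(10, 1, 10, 1))
#     cutoff = 0.4
#     def f(x):
#         return log_func(x, *params) - (params[3] + cutoff * (params[0] - params[3]))
#     onset = optimize.root_scalar(f, bracket=[0, xdata.max()]).root  # ~10.4 um
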
def collapse_soma_skeleton(
    soma_verts,
    soma_pt,
    verts,
    edges,
    mesh_to_skeleton_map=None,
    collapse_index=None,
    return_filter=False,
    return_soma_ind=False,
):
    """function to adjust skeleton result to move root to soma_pt

    Parameters
    ----------
    soma_verts : numpy.array
        a K long array of skeleton vertex indices that belong to the soma
        and should be collapsed to a single root
        (None to just remove duplicate edges)
    soma_pt : numpy.array
        a 3 long vector of the xyz location of the soma
    verts : numpy.array
        a Nx3 array of xyz vertex locations
    edges : numpy.array
        a Kx2 array of edges of the skeleton
    mesh_to_skeleton_map : np.array
        a M long array of how each mesh index maps to a skeleton vertex
        (default None). The function will update this as it collapses
        vertices to root.
    collapse_index : int
        an existing skeleton vertex index to use as the collapsed root
        instead of appending soma_pt as a new vertex (default None)
    return_filter : bool
        whether to return a list of which skeleton vertices were used
        in the end for the reduced set of skeleton vertices
    return_soma_ind : bool
        whether to return the skeleton index of the soma_pt

    Returns
    -------
    np.array
        verts, Px3 array of xyz skeleton vertices
    np.array
        edges, Qx2 array of skeleton edges
    (np.array)
        new_mesh_to_skeleton_map, returned if mesh_to_skeleton_map is passed
    (np.array)
        used_vertices, if return_filter this contains the indices into the
        passed verts which the returned verts are using
    (int)
        an index into the returned verts that is the root of the skeleton,
        only returned if return_soma_ind is True
    """
    if soma_verts is not None:
        soma_pt_m = soma_pt.reshape(1, 3)
        if collapse_index is None:
            new_verts = np.vstack((verts, soma_pt_m))
            soma_i = verts.shape[0]
        else:
            new_verts = verts
            soma_i = collapse_index
        soma_verts = soma_verts[soma_verts != soma_i]

        edges_m = edges.copy()
        edges_m[np.isin(edges, soma_verts)] = soma_i

        simple_verts, simple_edges = utils.remove_unused_verts(new_verts, edges_m)
        good_edges = ~(simple_edges[:, 0] == simple_edges[:, 1])

        # the remap dictionary is needed both for the mesh map
        # and for the returned soma index
        consolidate_dict = {v: soma_i for v in soma_verts}
        new_index_dict, _ = utils.remap_dict(len(new_verts), consolidate_dict)
        new_index_dict[-1] = -1

        if mesh_to_skeleton_map is not None:
            mesh_to_skeleton_map[np.isnan(mesh_to_skeleton_map)] = -1
            new_mesh_to_skeleton_map = fastremap.remap(
                mesh_to_skeleton_map, new_index_dict
            )

        output = [simple_verts, simple_edges[good_edges]]
        if mesh_to_skeleton_map is not None:
            output.append(new_mesh_to_skeleton_map)
        if return_filter:
            used_vertices = np.unique(edges_m.ravel())
            if collapse_index is None:
                # Remove the largest value, which is the appended soma_i
                used_vertices = used_vertices[:-1]
            output.append(used_vertices)
        if return_soma_ind:
            output.append(new_index_dict[soma_i])
        return output
    else:
        simple_verts, simple_edges = utils.remove_unused_verts(verts, edges)
        return simple_verts, simple_edges
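
# Usage sketch (illustrative): collapse a set of somatic skeleton vertices
# onto a single root appended at the soma point. `m2s` is a float array
# mapping mesh vertices to skeleton vertices (NaN where unmapped):
#
#     new_v, new_e, new_m2s, used, root_i = collapse_soma_skeleton(
#         soma_verts,
#         soma_pt,                      # length-3 xyz of the soma
#         sk_verts,
#         sk_edges,
#         mesh_to_skeleton_map=m2s,
#         return_filter=True,
#         return_soma_ind=True,
#     )
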
def resample(sk, spacing, kind="linear", tip_length_ratio=0.25):
    """Resample a skeleton's vertices

    Parameters
    ----------
    sk : Skeleton
        Input skeleton
    spacing : numeric
        Desired spacing in nanometers
    kind : str, optional
        Type of interpolation to use when resampling. Options follow
        scipy.interpolate.interp1d. By default "linear"
    tip_length_ratio : float, optional
        The ratio of spacing to branch tip length that a branch tip must have
        in order to be included in the final skeleton. For example, if spacing
        is 10 and the final tip is length 8, a tip_length_ratio of .75 would
        include the tip while the default of 0.25 would not.

    Returns
    -------
    Skeleton
        New skeleton with resampled vertices.
    resample_map
        Array where the ith index corresponds to the ith vertex of the
        resampled skeleton and the value is the associated index in the
        original skeleton. To assign vertices, we assign a "domain" to each
        vertex in the original skeleton that is halfway between the vertex
        and its neighbors. Resampled vertices that fall within that domain
        (based on topology and distance-to-root) are then associated with
        the original vertex.
    """
    path_counter = 0
    branch_d = {}
    vert_list = []
    edge_list = []
    output_map_list = []

    for path in sk.cover_paths:
        new_verts, new_edges, output_map_path, branch_d = skeleton_utils.resample_path(
            path,
            sk,
            path_counter,
            spacing,
            kind,
            tip_length_ratio,
            branch_d,
        )
        vert_list.append(new_verts)
        edge_list.append(new_edges)
        output_map_list.append(output_map_path)
        path_counter += len(new_verts)

    new_verts = np.vstack(vert_list)
    new_edges = np.vstack(edge_list)
    resample_map = np.concatenate(output_map_list)

    return (
        Skeleton(
            new_verts,
            new_edges,
            root=branch_d[int(sk.root)],
            remove_zero_length_edges=False,
        ),
        resample_map,
    )
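
# Usage sketch (illustrative): resample a skeleton to roughly 500 nm spacing:
#
#     new_sk, resample_map = resample(sk, spacing=500, tip_length_ratio=0.25)
#     # resample_map[i] is the index of the original vertex whose "domain"
#     # contains vertex i of the resampled skeleton.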