Source code for meshparty.skeleton_io

import os
import h5py
import orjson
import json

from .utils_io import NumpyEncoder

import numpy as np
from meshparty import skeleton
from dataclasses import asdict

FILE_VERSION = 2


def write_skeleton_h5(sk, filename, overwrite=False):
    """
    Write a skeleton and its properties to an hdf5 file.

    Parameters
    ----------
    sk : :obj:`meshparty.skeleton.Skeleton`
        Skeleton object to save.
    filename : str
        Filename of skeleton file.
    overwrite : bool
        Allows overwriting. (default False)
    """
    if not hasattr(sk, "meta"):
        sk.meta = {}
    write_skeleton_h5_by_part(
        filename,
        vertices=sk.vertices,
        edges=sk.edges,
        meta=sk.meta,
        mesh_to_skel_map=sk.mesh_to_skel_map,
        vertex_properties=sk.vertex_properties,
        root=sk.root,
        overwrite=overwrite,
    )
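# Usage sketch (illustrative, not part of the module): given an existing
# meshparty.skeleton.Skeleton `sk` (assumed), write it to a hypothetical path.
#
#     from meshparty import skeleton_io
#     skeleton_io.write_skeleton_h5(sk, "my_skeleton.h5", overwrite=True)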
def write_skeleton_h5_by_part(
    filename,
    vertices,
    edges,
    meta,
    mesh_to_skel_map=None,
    vertex_properties={},
    root=None,
    overwrite=False,
):
    """
    Helper function for writing all parts of a skeleton file to an h5.

    Parameters
    ----------
    filename : str
        Path to write.
    vertices : np.array
        Nx3 numpy array of vertex locations.
    edges : np.array
        Kx2 numpy array of vertex indices for edges.
    meta : dataclass
        Skeleton metadata; must be a dataclass instance, since it is serialized
        via dataclasses.asdict.
    mesh_to_skel_map : np.array
        M-long numpy array, where M is the number of vertices in the mesh this
        skeleton is associated with. The entries are indices into the N skeleton
        vertices.
    vertex_properties : dict
        Dictionary of np.arrays, where keys are descriptive labels and the values
        are arrays that quantify that label at each vertex. Examples: mesh_index
        (the mesh index of each vertex), rs (the sdf, or local caliber/thickness,
        of the mesh at each vertex).
    root : int
        Which vertex index is root.
    overwrite : bool
        Whether to overwrite the file if it already exists.
    """
    if os.path.isfile(filename):
        if overwrite:
            os.remove(filename)
        else:
            return

    with h5py.File(filename, "w") as f:
        f.attrs["file_version"] = FILE_VERSION
        f.create_dataset("vertices", data=vertices, compression="gzip")
        f.create_dataset("edges", data=edges, compression="gzip")
        f.create_dataset(
            "meta",
            data=np.string_(
                orjson.dumps(asdict(meta), option=orjson.OPT_SERIALIZE_NUMPY)
            ),
        )
        if mesh_to_skel_map is not None:
            f.create_dataset(
                "mesh_to_skel_map", data=mesh_to_skel_map, compression="gzip"
            )
        if len(vertex_properties) > 0:
            _write_dict_to_group(f, "vertex_properties", vertex_properties)
        if root is not None:
            f.create_dataset("root", data=root)
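# Usage sketch (illustrative, not part of the module): write a minimal skeleton
# directly from arrays. The output path is hypothetical, and `meta` is assumed to
# be a dataclass instance (an existing Skeleton's `meta` attribute satisfies this,
# since it is passed through dataclasses.asdict above).
#
#     import numpy as np
#     from meshparty import skeleton_io
#
#     verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
#     edges = np.array([[0, 1], [1, 2]])
#     skeleton_io.write_skeleton_h5_by_part(
#         "parts_example.h5",
#         vertices=verts,
#         edges=edges,
#         meta=existing_sk.meta,  # assumed: metadata dataclass from another skeleton
#         root=0,
#         overwrite=True,
#     )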
def _write_dict_to_group(f, group_name, data_dict):
    d_grp = f.create_group(group_name)
    for d_name, d_data in data_dict.items():
        d_grp.create_dataset(d_name, data=json.dumps(d_data, cls=NumpyEncoder))
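# Note (illustrative): each vertex_properties entry is stored as a JSON string in
# its own dataset within the group. On read, _convert_keys_to_int (defined below)
# restores integer dictionary keys, since JSON serialization turns them into strings.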
def read_skeleton_h5_by_part(filename):
    """
    Helper function for extracting all parts of a skeleton file from an h5.

    Parameters
    ----------
    filename : str
        Path to an h5 file with skeletons.

    Returns
    -------
    np.array
        vertices, Nx3 numpy array of vertex locations.
    np.array
        edges, Kx2 numpy array of vertex indices for edges.
    dict
        meta, metadata dictionary loaded from the file (empty if not present).
    np.array
        mesh_to_skel_map, M-long numpy array, where M is the number of vertices in
        the mesh this skeleton is associated with. The entries are indices into the
        N skeleton vertices. None if not present.
    dict
        vertex_properties, a dictionary of np.arrays, where keys are descriptive
        labels and the values are arrays that quantify that label at each vertex.
        Examples: mesh_index (the mesh index of each vertex), rs (the sdf, or local
        caliber/thickness, of the mesh at each vertex).
    int
        root, which vertex index is root. None if not present.
    """
    assert os.path.isfile(filename)
    with h5py.File(filename, "r") as f:
        vertices = f["vertices"][()]
        edges = f["edges"][()]

        if "mesh_to_skel_map" in f.keys():
            mesh_to_skel_map = f["mesh_to_skel_map"][()]
        else:
            mesh_to_skel_map = None

        vertex_properties = {}
        if "vertex_properties" in f.keys():
            for vp_key in f["vertex_properties"].keys():
                vertex_properties[vp_key] = json.loads(
                    f["vertex_properties"][vp_key][()],
                    object_hook=_convert_keys_to_int,
                )

        if "meta" in f.keys():
            dat = f["meta"][()].tobytes()
            meta = orjson.loads(dat)
        else:
            meta = {}

        if "root" in f.keys():
            root = f["root"][()]
        else:
            root = None

    return vertices, edges, meta, mesh_to_skel_map, vertex_properties, root
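# Usage sketch (illustrative, not part of the module): read the raw components
# back out of a previously written file (path is hypothetical).
#
#     vertices, edges, meta, mesh_to_skel_map, vertex_properties, root = \
#         read_skeleton_h5_by_part("my_skeleton.h5")
#     print(vertices.shape, edges.shape, root)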
def read_skeleton_h5(filename, remove_zero_length_edges=False):
    """
    Reads a skeleton and its properties from an hdf5 file.

    Parameters
    ----------
    filename : str
        Path to skeleton file.
    remove_zero_length_edges : bool, optional
        If True, post-processes the skeleton data to remove any zero-length edges.
        Default is False.

    Returns
    -------
    :obj:`meshparty.skeleton.Skeleton`
        Skeleton object loaded from the h5 file.
    """
    (
        vertices,
        edges,
        meta,
        mesh_to_skel_map,
        vertex_properties,
        root,
    ) = read_skeleton_h5_by_part(filename)
    return skeleton.Skeleton(
        vertices=vertices,
        edges=edges,
        mesh_to_skel_map=mesh_to_skel_map,
        vertex_properties=vertex_properties,
        root=root,
        remove_zero_length_edges=remove_zero_length_edges,
        meta=meta,
    )
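# Usage sketch (illustrative, not part of the module): load a full Skeleton object,
# optionally dropping zero-length edges, from a hypothetical path.
#
#     from meshparty import skeleton_io
#     sk = skeleton_io.read_skeleton_h5("my_skeleton.h5", remove_zero_length_edges=True)
#     print(sk.n_vertices, sk.root)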
def swc_node_labels(
    sk,
    dendrite_indices=None,
    apical_indices=None,
    soma_indices=None,
    axon_indices=None,
    dendrite_default=True,
):
    """Assemble swc node labels based on compartment labels.

    By default, unlabeled indices are considered dendrite.

    Parameters
    ----------
    sk : Skeleton
        Neuron skeleton object with N vertices.
    dendrite_indices : array, optional
        Array of indices associated with the dendrites (or basal dendrites if
        apicals are distinct), by default None.
    apical_indices : array, optional
        Array of indices (or boolean mask) for the apical dendrite, by default None.
    soma_indices : array, optional
        Array of indices (or boolean mask) for the soma, by default None.
    axon_indices : array, optional
        Array of indices (or boolean mask) for the axon, by default None.
    dendrite_default : bool, optional
        If True, unlabeled vertices are assumed to be dendrite. Otherwise, they are
        given a label of 0.

    Returns
    -------
    node_labels : np.array
        N-length vector with the appropriate SWC label for each compartment.
        Unlabeled vertices are given a default label: 3 (basal dendrite) if
        dendrite_default is True, else 0.
    """
    SOMA_LABEL = 1
    AXON_LABEL = 2
    DENDRITE_LABEL = 3
    APICAL_LABEL = 4

    inds = [dendrite_indices, apical_indices, soma_indices, axon_indices]
    labels = [DENDRITE_LABEL, APICAL_LABEL, SOMA_LABEL, AXON_LABEL]

    if dendrite_default:
        val = DENDRITE_LABEL
    else:
        val = 0
    node_labels = np.full(len(sk.vertices), val)
    for ii, label in zip(inds, labels):
        if ii is not None:
            node_labels[np.array(ii)] = label
    return node_labels
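# Usage sketch (illustrative, not part of the module): build SWC labels from
# hypothetical compartment masks. `axon_mask` and `soma_mask` are assumed boolean
# arrays of length sk.n_vertices; unlabeled vertices fall back to dendrite (3).
#
#     node_labels = swc_node_labels(sk, axon_indices=axon_mask, soma_indices=soma_mask)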
def export_to_swc(
    skel,
    filename,
    node_labels=None,
    radius=None,
    header=None,
    xyz_scaling=1000,
    resample_spacing=None,
    interp_kind="linear",
    tip_length_ratio=0.5,
    avoid_root=True,
):
    """
    Export a skeleton file to an swc file (see
    http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
    for the swc specification).

    Parameters
    ----------
    skel : :obj:`meshparty.skeleton.Skeleton`
        Skeleton to export.
    filename : str
        Path to the file to save the swc to.
    node_labels : iterable
        None (default) or an iterable of ints co-indexed with vertices,
        corresponding to the swc node categories. Defaults to setting all nodes
        to label 0.
    radius : iterable
        None (default) or an iterable of floats co-indexed with vertices.
        Radius values are assumed to be in the same units as the node vertices.
    header : str or list, default None
        An optional header for the file. Each element of a list is written as a
        separate header line.
    xyz_scaling : Number, default 1000
        Down-scales spatial units from the skeleton's units to whatever is desired
        by the swc, e.g. nm to microns has scaling=1000.
    resample_spacing : Number, optional
        If set, resamples the skeleton to this vertex spacing before export.
    interp_kind : str, default "linear"
        Interpolation method passed to skeleton.resample.
    tip_length_ratio : float, default 0.5
        Tip-length ratio passed to skeleton.resample.
    avoid_root : bool, default True
        Passed to skeleton.resample.
    """
    if header is None:
        header_string = ""
    else:
        if isinstance(header, str):
            header = [header]
        header[0] = " ".join(["#", header[0]])
        header_string = "\n# ".join(header)

    if radius is None:
        radius = np.full(len(skel.vertices), xyz_scaling)
    elif np.issubdtype(type(radius), int):
        radius = np.full(len(skel.vertices), radius)

    if node_labels is None:
        node_labels = np.full(len(skel.vertices), 0)

    if resample_spacing is not None:
        skel, output_map = skeleton.resample(
            skel,
            spacing=resample_spacing,
            tip_length_ratio=tip_length_ratio,
            kind=interp_kind,
            avoid_root=avoid_root,
        )
        node_labels = node_labels[output_map]
        radius = radius[output_map]

    swc_dat = _build_swc_array(skel, node_labels, radius, xyz_scaling)

    np.savetxt(
        filename,
        swc_dat,
        delimiter=" ",
        header=header_string,
        comments="#",
        fmt=["%i", "%i", "%.3f", "%.3f", "%.3f", "%.3f", "%i"],
    )
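# Usage sketch (illustrative, not part of the module): export to SWC using the
# labels built above, scaling nanometer vertices down to microns. The radius array
# `radius_nm` and the output path are assumptions.
#
#     export_to_swc(
#         sk,
#         "my_skeleton.swc",
#         node_labels=node_labels,
#         radius=radius_nm,
#         header="exported with meshparty",
#         xyz_scaling=1000,
#     )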
def _build_swc_array(skel, node_labels, radius, xyz_scaling):
    """
    Helper function for producing the numpy table for an swc.
    """
    # Reorder vertices along reversed cover paths so each vertex's parent appears
    # earlier in the output, as SWC parent references require, then assign new
    # sequential ids in that order.
    order_old = np.concatenate([p[::-1] for p in skel.cover_paths])
    new_ids = np.arange(skel.n_vertices)
    order_map = dict(zip(order_old, new_ids))

    node_labels = np.array(node_labels)[order_old]
    xyz = skel.vertices[order_old]
    radius = radius[order_old]
    par_ids = np.array(
        [order_map.get(nid, -1) for nid in skel.parent_nodes(order_old)]
    )

    # Columns: node id, SWC structure label, x, y, z, radius, parent id (-1 for root).
    swc_dat = np.hstack(
        (
            new_ids[:, np.newaxis],
            node_labels[:, np.newaxis],
            xyz / xyz_scaling,
            radius[:, np.newaxis] / xyz_scaling,
            par_ids[:, np.newaxis],
        )
    )
    return swc_dat


def _convert_keys_to_int(x):
    if type(x) is dict:
        return {int(k): v for k, v in x.items()}
    else:
        return x