Source code for meshparty.skeleton_io

import os
from . import trimesh_io
import h5py
import json
import numpy as np
from meshparty import skeleton


def write_skeleton_h5(sk, filename, overwrite=False):
    '''
    Save a skeleton object and its properties to an hdf5 file.

    Convenience wrapper that unpacks the attributes of the skeleton and
    forwards them to `write_skeleton_h5_by_part`.

    Parameters
    ----------
    sk : :obj:`meshparty.skeleton.Skeleton`
        Skeleton to save. Its `vertices`, `edges`, `mesh_to_skel_map`,
        `vertex_properties` and `root` attributes are written.
    filename : str
        Filename of skeleton file.
    overwrite : bool
        Allows overwriting an existing file (default False).
    '''
    # Collected in the same order the attributes were originally read.
    skeleton_parts = dict(
        vertices=sk.vertices,
        edges=sk.edges,
        mesh_to_skel_map=sk.mesh_to_skel_map,
        vertex_properties=sk.vertex_properties,
        root=sk.root,
    )
    write_skeleton_h5_by_part(filename, overwrite=overwrite, **skeleton_parts)
def write_skeleton_h5_by_part(filename, vertices, edges, mesh_to_skel_map=None,
                              vertex_properties=None, root=None,
                              overwrite=False):
    '''
    Helper function for writing all parts of a skeleton file to an h5.

    If `filename` already exists and `overwrite` is False, the function
    returns without writing anything and without raising.

    Parameters
    ----------
    filename : str
        path to write
    vertices : np.array
        Nx3 numpy array of vertex locations
    edges : np.array
        Kx2 numpy array of vertex indices for edges
    mesh_to_skel_map : np.array
        M long numpy array. M is the number of vertices in a mesh that this
        is associated with. The entries are indices into the N skeleton
        vertices. Not written when None (default).
    vertex_properties : dict
        a dictionary of np.arrays, where keys are descriptive labels and the
        values are arrays that quantify that label at each vertex.
        Examples:
        mesh_index) the mesh index of this vertex
        rs) the sdf (local caliber/thickness) of the mesh at each index
        None (default) or an empty dictionary writes no property group.
    root : int
        which vertex index is root. Not written when None (default).
    overwrite : bool
        whether to overwrite an existing file (default False)
    '''
    if os.path.isfile(filename):
        if not overwrite:
            # Existing files are left untouched unless overwriting is
            # explicitly requested.
            return
        os.remove(filename)

    with h5py.File(filename, 'w') as f:
        f.create_dataset('vertices', data=vertices, compression='gzip')
        f.create_dataset('edges', data=edges, compression='gzip')
        if mesh_to_skel_map is not None:
            f.create_dataset('mesh_to_skel_map',
                             data=mesh_to_skel_map, compression='gzip')
        # None is accepted in place of the former mutable `{}` default.
        if vertex_properties is not None and len(vertex_properties) > 0:
            _write_dict_to_group(f, 'vertex_properties', vertex_properties)
        if root is not None:
            f.create_dataset('root', data=root)
def _write_dict_to_group(f, group_name, data_dict):
    '''
    Store every entry of a dictionary as a JSON string dataset in a new group.

    Parameters
    ----------
    f : h5py file or group object
        Open, writable object on which `create_group` is called.
    group_name : str
        Name of the group to create. The group is created even when
        `data_dict` is empty.
    data_dict : dict
        Maps dataset names to values. Each value is serialized with
        `json.dumps` using `_NumpyEncoder`, so numpy arrays and numpy
        scalars are stored as their JSON equivalents.
    '''
    d_grp = f.create_group(group_name)
    for d_name, d_data in data_dict.items():
        d_grp.create_dataset(d_name,
                             data=json.dumps(d_data, cls=_NumpyEncoder))
def read_skeleton_h5_by_part(filename):
    '''
    Helper function for extracting all parts of a skeleton file from an h5.

    Parameters
    ----------
    filename : str
        path to a h5 file with skeletons. The file must exist.

    Returns
    -------
    np.array
        vertices, Nx3 numpy array of vertex locations
    np.array
        edges, Kx2 numpy array of vertex indices for edges
    np.array
        mesh_to_skel_map, M long numpy array. M is the number of vertices in
        a mesh that this is associated with. The entries are indices into
        the N skeleton vertices. None if the file has no such dataset.
    dict
        vertex_properties, a dictionary keyed by descriptive labels whose
        values are decoded from the JSON stored in the file (dictionary keys
        inside those values are converted to int). Empty if the file has no
        'vertex_properties' group.
        Examples:
        mesh_index) the mesh index of this vertex
        rs) the sdf (local caliber/thickness) of the mesh at each index
    int
        root, which vertex index is root. None if the file has no such
        dataset.
    '''
    assert os.path.isfile(filename)

    with h5py.File(filename, 'r') as h5_file:
        vertices = h5_file['vertices'][()]
        edges = h5_file['edges'][()]

        # Optional dataset: absent means no mesh association was stored.
        mesh_to_skel_map = None
        if 'mesh_to_skel_map' in h5_file.keys():
            mesh_to_skel_map = h5_file['mesh_to_skel_map'][()]

        # Each property is stored as a JSON string and decoded on load.
        vertex_properties = {}
        if 'vertex_properties' in h5_file.keys():
            vp_group = h5_file['vertex_properties']
            vertex_properties = {
                vp_name: json.loads(vp_group[vp_name][()],
                                    object_hook=_convert_keys_to_int)
                for vp_name in vp_group.keys()
            }

        root = None
        if 'root' in h5_file.keys():
            root = h5_file['root'][()]

    return vertices, edges, mesh_to_skel_map, vertex_properties, root
def read_skeleton_h5(filename):
    '''
    Load a skeleton and its properties from an hdf5 file.

    Parameters
    ----------
    filename : str
        path to skeleton file

    Returns
    -------
    :obj:`meshparty.skeleton.Skeleton`
        skeleton object built from the parts read by
        `read_skeleton_h5_by_part`
    '''
    parts = read_skeleton_h5_by_part(filename)
    vertices, edges, mesh_to_skel_map, vertex_properties, root = parts
    skeleton_kwargs = {
        'vertices': vertices,
        'edges': edges,
        'mesh_to_skel_map': mesh_to_skel_map,
        'vertex_properties': vertex_properties,
        'root': root,
    }
    return skeleton.Skeleton(**skeleton_kwargs)
def export_to_swc(skel, filename, node_labels=None, radius=None, header=None,
                  xyz_scaling=1000):
    '''
    Export a skeleton file to an swc file
    (see http://research.mssm.edu/cnic/swc.html for swc definition)

    Parameters
    ----------
    skel : :obj:`meshparty.skeleton.Skeleton`
        Skeleton to export.
    filename : str
        path to the file to save the swc to
    node_labels : iterable
        None (default) or an iterable of ints co-indexed with vertices.
        Corresponds to the swc node categories. Defaults to setting all
        nodes to label 3, dendrite.
    radius : iterable or number
        None (default), a single number applied to every vertex, or an
        iterable of floats co-indexed with vertices. Radius values are
        assumed to be in the same units as the node vertices. When None,
        every vertex gets radius 1000 (before `xyz_scaling` is applied).
    header : dict, default None.
        Each key value pair in the dict becomes a parameter line in the
        swc header.
    xyz_scaling : Number, default 1000.
        Down-scales spatial units from the skeleton's units to whatever is
        desired by the swc. E.g. nm to microns has scaling=1000.
    '''
    if header is None:
        header_string = ''
    else:
        header_string = '\n'.join('{}: {}'.format(k, v)
                                  for k, v in header.items())

    n_vertices = len(skel.vertices)
    if radius is None:
        radius = np.full(n_vertices, 1000)
    elif np.ndim(radius) == 0:
        # Any scalar (int or float, python or numpy) is broadcast to all
        # vertices; the previous check only recognised integer scalars.
        radius = np.full(n_vertices, radius)
    else:
        # Lists/tuples must become arrays so they can be reordered by index.
        radius = np.asarray(radius)

    if node_labels is None:
        node_labels = np.full(n_vertices, 3)

    swc_dat = _build_swc_array(skel, node_labels, radius, xyz_scaling)

    # Columns: id, label, x, y, z, radius, parent id.
    with open(filename, 'w') as f:
        np.savetxt(f, swc_dat, delimiter=' ', header=header_string,
                   comments='#',
                   fmt=['%i', '%i', '%.3f', '%.3f', '%.3f', '%.3f', '%i'])
def _build_swc_array(skel, node_labels, radius, xyz_scaling):
    '''
    Helper function for producing the numpy table for an swc.

    Vertices are renumbered in order of increasing `skel.distance_to_root`,
    and parent references are translated to the new numbering.

    Parameters
    ----------
    skel : object
        Must provide `distance_to_root`, `vertices` and `_parent_node_array`,
        all indexable by a numpy integer array.
    node_labels : iterable
        Labels co-indexed with the vertices.
    radius : iterable
        Radii co-indexed with the vertices (array, list or tuple).
    xyz_scaling : Number
        Divisor applied to both the vertex positions and the radii.

    Returns
    -------
    np.array
        N x 7 array with columns: new id, label, x, y, z, radius, parent id.
        A parent that is not a valid vertex index is written as -1.
    '''
    ds = skel.distance_to_root
    order_old = np.argsort(ds)
    new_ids = np.arange(len(ds))
    # Maps an original vertex index to its position in the sorted output.
    order_map = dict(zip(order_old, new_ids))

    node_labels = np.asarray(node_labels)[order_old]
    xyz = skel.vertices[order_old]
    # Coerce so that plain python sequences can be fancy-indexed too.
    radius = np.asarray(radius)[order_old]
    par_ids = np.array([order_map.get(nid, -1)
                        for nid in skel._parent_node_array[order_old]])

    swc_dat = np.hstack((new_ids[:, np.newaxis],
                         node_labels[:, np.newaxis],
                         xyz / xyz_scaling,
                         radius[:, np.newaxis] / xyz_scaling,
                         par_ids[:, np.newaxis]))
    return swc_dat


class _NumpyEncoder(json.JSONEncoder):
    '''
    JSON encoder that converts numpy scalars and arrays to python builtins.
    '''

    def default(self, obj):
        # The abstract scalar base classes cover every concrete width and
        # avoid aliases such as `np.float_` that newer numpy releases removed.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


def _convert_keys_to_int(x):
    '''
    Return a copy of a dictionary with every key passed through `int`.

    Anything that is not a dictionary is returned unchanged. Used as a
    `json.loads` object hook, since JSON stores all object keys as strings.
    '''
    if isinstance(x, dict):
        return {int(k): v for k, v in x.items()}
    return x