# Source code for meshparty.iterator
__doc__ = """
Mesh Iterator Classes
"""
import random
import time
import numpy as np
# Accepted values for the ``order`` argument of LocalViewIterator.
ORDERS = ["random", "sequential"]
class LocalViewIterator(object):
    """
    Iterator class which samples local views that cover an entire mesh.

    Each mesh vertex is counted as "covered" if it's included in at
    least one local view across the iterator. Every call to ``next``
    picks up to ``batch_size`` centers among the vertices that are not
    covered yet, requests their local views from the mesh, and removes
    every node id returned for those views from the pool of candidate
    centers. Iteration stops once that pool is empty.

    Parameters
    ----------
    mesh
        Object exposing ``vertices`` (only ``vertices.shape[0]`` is read)
        and ``get_local_views(center_node_ids=..., **kwargs)`` returning
        a 3-tuple ``(views, <ignored>, node_ids)``.
    n_points
        Forwarded to ``get_local_views``.
    batch_size
        Maximum number of centers sampled per ``next`` call (>= 1).
    order
        ``"random"`` draws centers without replacement from the uncovered
        vertices; ``"sequential"`` always takes the first uncovered
        vertex.
    pc_align, pc_norm, verbose
        Forwarded to every ``get_local_views`` call.
    adaptnorm
        Forwarded as ``adapt_unit_sphere_norm`` to the deactivation call.
    fisheye
        Forwarded to the sampling call.
    sample_n_points
        Forwarded to the sampling call. When it is not ``None`` a second
        call with ``sample_n_points=None`` is made and its node ids are
        the ones used for deactivation.
    """

    def __init__(self, mesh, n_points, batch_size=1, order="random",
                 pc_align=False, pc_norm=False, adaptnorm=False,
                 fisheye=False, sample_n_points=None, verbose=False):
        assert order in ORDERS, f"invalid order {order} not in {ORDERS}"
        # A batch size below one can never make progress: the iterator
        # would either fail inside NumPy or yield empty batches forever.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self._active_inds = list(range(mesh.vertices.shape[0]))
        self._order = order
        self._mesh = mesh
        self._batch_size = batch_size

        # arguments for local view method calls
        # NOTE(review): ``fisheye`` is only forwarded here and
        # ``adaptnorm`` only in the deactivation kwargs below -- confirm
        # this asymmetry is intended.
        self._kwargs = dict(n_points=n_points, pc_align=pc_align,
                            fisheye=fisheye,
                            verbose=verbose, sample_n_points=sample_n_points,
                            return_node_ids=True, pc_norm=pc_norm)
        # arguments for the extra call that decides which nodes to retire
        # when the sampling call sub-samples points (sample_n_points set)
        self._deact_kwargs = dict(n_points=n_points, pc_align=pc_align,
                                  adapt_unit_sphere_norm=adaptnorm,
                                  verbose=verbose, sample_n_points=None,
                                  return_node_ids=True, pc_norm=pc_norm)

    def __iter__(self):
        return self

    def __next__(self):
        """
        Return the next batch as ``(views, centers)``.

        ``views`` is a float32 array built from the sampled views and
        ``centers`` a uint32 array of the center node ids used.

        Raises
        ------
        StopIteration
            When no uncovered vertex is left.
        ValueError
            If the stored order is not a supported one.
        """
        time_start = time.time()

        # stopping condition: no more indices to sample
        if not self._active_inds:
            raise StopIteration

        n_samples = min(self._batch_size, len(self._active_inds))
        resample_for_deactivation = self._kwargs["sample_n_points"] is not None

        if self._order == "random":
            # Centers come from NumPy's global RNG; it is deliberately
            # not reseeded here so callers keep control of
            # reproducibility through ``np.random.seed``.
            centers = np.random.choice(self._active_inds, n_samples,
                                       replace=False)
            views, _, node_ids = self._mesh.get_local_views(
                center_node_ids=centers, **self._kwargs)

            if resample_for_deactivation:
                _, _, node_ids = self._mesh.get_local_views(
                    center_node_ids=centers, **self._deact_kwargs)

            self._deactivate_nodes(node_ids)

        elif self._order == "sequential":
            centers = []
            views = []

            # Views are requested one at a time because each view retires
            # nodes and thereby changes which vertex is "first uncovered".
            # The batch is capped at ``n_samples`` like in random mode.
            while self._active_inds and len(centers) < n_samples:
                center = self._active_inds[0]
                centers.append(center)

                view, _, node_ids = self._mesh.get_local_views(
                    center_node_ids=[center], **self._kwargs)
                views.extend(view)

                if resample_for_deactivation:
                    # Only the newest center is queried: nodes of earlier
                    # centers were already retired in previous rounds.
                    _, _, node_ids = self._mesh.get_local_views(
                        center_node_ids=[center], **self._deact_kwargs)

                self._deactivate_nodes(node_ids)

        else:
            raise ValueError(f"invalid order {self._order!r}")

        print("Views took %.3fs" % (time.time() - time_start))

        return np.array(views, dtype=np.float32), \
            np.array(centers, dtype=np.uint32)

    def _deactivate_nodes(self, node_ids):
        """
        Removes nodes from consideration which have been sampled
        in the last patch

        ``node_ids`` may be a flat or nested sequence or array of node
        ids; it is flattened before use so that per-view id arrays
        (one row per center) are accepted as well.
        """
        to_deactivate = set(np.asarray(node_ids).ravel().tolist())
        self._active_inds = [ind for ind in self._active_inds
                             if ind not in to_deactivate]