import numpy as np
from ..strategy import ConnectionStrategy
from .shared import MorphologyStrategy
from ...helpers import (
DistributionConfiguration,
assert_attr_in,
)
from ...reporting import report, warn
from random import sample as sample_elements
class TouchInformation:
    """
    Mutable record describing one "from" cell type / "to" cell type pairing.

    The constructor stores the two cell types and the compartment selection
    that goes with each of them. Instances are ordinary objects, so callers
    can attach further attributes to them after construction.

    :param from_cell_type: Cell type on the source side of the pairing.
    :param from_cell_compartments: Compartment selection for the source side.
    :param to_cell_type: Cell type on the target side of the pairing.
    :param to_cell_compartments: Compartment selection for the target side.
    """

    def __init__(
        self, from_cell_type, from_cell_compartments, to_cell_type, to_cell_compartments
    ):
        # Source ("from") side of the pairing.
        self.from_cell_type, self.from_cell_compartments = (
            from_cell_type,
            from_cell_compartments,
        )
        # Target ("to") side of the pairing.
        self.to_cell_type, self.to_cell_compartments = (
            to_cell_type,
            to_cell_compartments,
        )
class TouchDetector(ConnectionStrategy, MorphologyStrategy):
    """
    Connectivity based on intersection of detailed morphologies.

    For every configured "from"/"to" cell type pair the strategy works in two
    stages:

    1. :meth:`intersect_cells` selects candidate cell pairs with a radius
       search on the planar cell trees of both cell types.
    2. :meth:`intersect_compartments` assigns a random morphology to each
       candidate cell, looks for compartments of the two morphologies that
       lie within ``compartment_intersection_radius`` of each other and
       samples a number of those compartment pairs as synapses.
    """

    casts = {
        "compartment_intersection_radius": float,
        "cell_intersection_radius": float,
        "synapses": DistributionConfiguration.cast,
        "allow_zero_synapses": bool,
    }

    defaults = {
        "cell_intersection_plane": "xyz",
        "compartment_intersection_plane": "xyz",
        "compartment_intersection_radius": 5.0,
        "synapses": DistributionConfiguration.cast(1),
        "allow_zero_synapses": False,
    }

    required = [
        "cell_intersection_plane",
        "compartment_intersection_plane",
        "compartment_intersection_radius",
    ]

    def validate(self):
        """
        Check that both intersection planes are one of the supported planes.

        The check is delegated to ``assert_attr_in``, which receives the
        instance dictionary, the attribute name, the list of allowed values
        and the configuration section name of this connection type.
        """
        planes = ["xyz", "xy", "xz", "yz", "x", "y", "z"]
        section = "connection_types.{}".format(self.name)
        for attribute in ("cell_intersection_plane", "compartment_intersection_plane"):
            assert_attr_in(self.__dict__, attribute, planes, section)

    def connect(self):
        """
        Detect touches between all "from"/"to" cell type pairs and store them.

        For each pair the placement data of both cell types is collected in a
        :class:`TouchInformation`, candidate cells are intersected, their
        compartments are intersected and the result is handed to
        ``self.scaffold.connect_cells`` together with the morphology names and
        the selected compartments.

        ``self.morphology_cache`` is set to an empty dict for the duration of
        the call and reset to ``None`` afterwards, also when an error occurs.
        """
        labels_pre = None if self.label_pre is None else [self.label_pre]
        labels_post = None if self.label_post is None else [self.label_post]
        self.morphology_cache = {}
        try:
            for from_index, from_cell_type in enumerate(self.from_cell_types):
                from_cell_compartments = self.from_cell_compartments[from_index]
                # The source side does not depend on the target type, so it is
                # loaded once per source type instead of once per pair.
                from_placement = self.scaffold.get_placement_set(
                    from_cell_type, labels=labels_pre
                )
                from_positions = list(from_placement.positions)
                from_identifiers = list(from_placement.identifiers)
                for to_index, to_cell_type in enumerate(self.to_cell_types):
                    to_cell_compartments = self.to_cell_compartments[to_index]
                    touch_info = TouchInformation(
                        from_cell_type,
                        from_cell_compartments,
                        to_cell_type,
                        to_cell_compartments,
                    )
                    touch_info.from_placement = from_placement
                    touch_info.from_positions = from_positions
                    touch_info.from_identifiers = from_identifiers
                    touch_info.to_placement = self.scaffold.get_placement_set(
                        to_cell_type, labels=labels_post
                    )
                    touch_info.to_identifiers = list(touch_info.to_placement.identifiers)
                    touch_info.to_positions = list(touch_info.to_placement.positions)
                    # NOTE(review): the candidate indices produced below refer to
                    # the planar cell trees, while the identifiers and positions
                    # above come from label-filtered placement sets. Confirm that
                    # both use the same ordering when labels are set.
                    # Intersect cells on the widest possible search radius.
                    candidates = self.intersect_cells(touch_info)
                    # Intersect cell compartments between matched cells.
                    (
                        connections,
                        morphology_names,
                        compartments,
                    ) = self.intersect_compartments(touch_info, candidates)
                    # Connect the cells and store the morphologies and selected
                    # compartments that connect them.
                    self.scaffold.connect_cells(
                        self,
                        connections,
                        morphologies=morphology_names,
                        compartments=compartments,
                    )
        finally:
            # Remove the morphology cache, even if connecting failed halfway.
            self.morphology_cache = None

    def intersect_cells(self, touch_info):
        """
        Find, for every "from" cell, the "to" cells within the search radius.

        The search radius is ``self.cell_intersection_radius`` when that
        attribute exists, otherwise the sum of :meth:`get_search_radius` of
        both cell types.

        :param touch_info: Pairing whose ``from_cell_type`` and
            ``to_cell_type`` are intersected.
        :returns: A sequence with one entry per point of the "from" tree; each
            entry holds the indices of the matching points of the "to" tree.
            An empty list is returned when either planar tree is ``None``.
        """
        from_cell_type = touch_info.from_cell_type
        to_cell_type = touch_info.to_cell_type
        cell_trees = self.scaffold.trees.cells
        plane = self.cell_intersection_plane
        from_cell_tree = cell_trees.get_planar_tree(from_cell_type.name, plane=plane)
        to_cell_tree = cell_trees.get_planar_tree(to_cell_type.name, plane=plane)
        if from_cell_tree is None or to_cell_tree is None:
            return []

        from_count = self.scaffold.get_placed_count(from_cell_type.name)
        to_count = self.scaffold.get_placed_count(to_cell_type.name)
        if hasattr(self, "cell_intersection_radius"):
            radius = self.cell_intersection_radius
        else:
            radius = self.get_search_radius(from_cell_type) + self.get_search_radius(
                to_cell_type
            )

        # TODO: Profile whether the reverse lookup with the smaller tree and then
        # reversing the matches array gains us any speed.
        if from_count < to_count:
            return to_cell_tree.query_radius(from_cell_tree.get_arrays()[0], radius)

        # Query the "from" tree with the "to" points, then invert the result so
        # that it is indexed by "from" cell again.
        reversed_matches = from_cell_tree.query_radius(
            to_cell_tree.get_arrays()[0], radius
        )
        matches = [[] for _ in range(len(from_cell_tree.get_arrays()[0]))]
        for to_index, from_indices in enumerate(reversed_matches):
            for from_index in from_indices:
                matches[from_index].append(to_index)
        return matches

    def intersect_compartments(self, touch_info, candidate_map):
        """
        Turn candidate cell pairs into synapses between compartments.

        Every "from" cell and every candidate "to" cell gets a morphology from
        ``self.get_random_morphology``. For each pair with at least one
        compartment intersection, the number of synapses is a sample of
        ``self.synapses`` capped at the number of intersections and raised to
        1 unless ``allow_zero_synapses`` is set; that many intersections are
        then drawn without replacement.

        :param touch_info: Pairing with ``from_identifiers``,
            ``to_identifiers``, ``from_positions`` and ``to_positions`` set.
            Its ``from_morphology`` and ``to_morphology`` attributes are
            overwritten while iterating.
        :param candidate_map: Per "from" cell index, the indices of the
            candidate "to" cells.
        :returns: Tuple of three arrays with one row per synapse: the
            ``[from_id, to_id]`` cell pairs (int), the morphology name pairs
            (bytes) and the intersecting compartment pairs (int).
        """
        connected_cells = []
        morphology_names = []
        connected_compartments = []
        checked_pairs = 0
        touching_pairs = 0
        total = len(candidate_map)
        minimum_synapses = 0 if self.allow_zero_synapses else 1

        for i, candidates in enumerate(candidate_map):
            if i % 100 == 0:
                percentage = 100 * float(i) / float(total)
                report(
                    f"Connection progress: {percentage:.2f}%...",
                    level=2,
                    ongoing=True,
                )
            from_id = touch_info.from_identifiers[i]
            touch_info.from_morphology = self.get_random_morphology(
                touch_info.from_cell_type
            )
            for j in candidates:
                checked_pairs += 1
                to_id = touch_info.to_identifiers[j]
                touch_info.to_morphology = self.get_random_morphology(
                    touch_info.to_cell_type
                )
                intersections = self.get_compartment_intersections(
                    touch_info, touch_info.from_positions[i], touch_info.to_positions[j]
                )
                if len(intersections) == 0:
                    continue
                touching_pairs += 1
                # Never ask for more synapses than there are intersections,
                # and never for fewer than the configured minimum.
                number_of_synapses = max(
                    min(int(self.synapses.sample()), len(intersections)),
                    minimum_synapses,
                )
                compartment_connections = sample_elements(
                    intersections, k=number_of_synapses
                )
                connected_cells.extend(
                    [from_id, to_id] for _ in range(number_of_synapses)
                )
                connected_compartments.extend(compartment_connections)
                # Pad the morphology names with the right names for the amount
                # of compartment connections made.
                from_name = touch_info.from_morphology.morphology_name
                to_name = touch_info.to_morphology.morphology_name
                morphology_names.extend(
                    [from_name, to_name] for _ in range(len(compartment_connections))
                )

        report(
            "Checked {} candidate cell pairs from {} to {}".format(
                checked_pairs,
                touch_info.from_cell_type.name,
                touch_info.to_cell_type.name,
            ),
            level=2,
        )
        report(
            "Touch connection results: \n* Touching pairs: {} \n* Synapses: {}".format(
                touching_pairs, len(connected_compartments)
            ),
            level=2,
        )
        return (
            np.array(connected_cells, dtype=int),
            # `np.string_` was removed in NumPy 2.0; `np.bytes_` is the same
            # type under the name that exists in both NumPy 1.x and 2.x.
            np.array(morphology_names, dtype=np.bytes_),
            np.array(connected_compartments, dtype=int),
        )

    def get_compartment_intersections(self, touch_info, from_pos, to_pos):
        """
        List the compartment pairs of two placed morphologies that touch.

        The "to" compartment positions are shifted by ``to_pos - from_pos`` so
        that they can be queried against the compartment tree of the "from"
        morphology, using ``self.compartment_intersection_radius``.

        :param touch_info: Pairing with ``from_morphology`` and
            ``to_morphology`` set.
        :param from_pos: Position of the "from" cell.
        :param to_pos: Position of the "to" cell.
        :returns: List of ``[from_compartment, to_compartment]`` pairs, mapped
            through the compartment submasks of both morphologies. Empty when
            the "from" tree is ``None`` or there are no "to" compartments.
        """
        from_morpho = touch_info.from_morphology
        to_morpho = touch_info.to_morphology
        to_comps = to_morpho.get_compartment_positions(touch_info.to_cell_compartments)
        from_tree = from_morpho.get_compartment_tree(touch_info.from_cell_compartments)
        if from_tree is None or not len(to_comps):
            return []

        query_points = to_comps + to_pos - from_pos
        compartment_hits = from_tree.query_radius(
            query_points, self.compartment_intersection_radius
        )
        from_map = from_morpho.get_compartment_submask(touch_info.from_cell_compartments)
        to_map = to_morpho.get_compartment_submask(touch_info.to_cell_compartments)

        intersections = []
        for to_index, hits in enumerate(compartment_hits):
            intersections.extend([from_map[hit], to_map[to_index]] for hit in hits)
        return intersections

    def get_search_radius(self, cell_type):
        """
        Return the largest point norm found in the morphologies of a cell type.

        For every morphology returned by ``self.get_all_morphologies`` the
        Euclidean norm of each point stored in its ``compartment_tree`` is
        computed; the maximum over all morphologies is returned.
        Presumably this is the farthest reach of any compartment from the
        morphology's origin; confirm against the morphology coordinate frame.

        :param cell_type: Cell type whose morphologies are inspected.
        :returns: The maximum norm, or ``0.0`` when there are no morphologies
            or none of them contains any points.
        """
        max_radius = 0.0
        for morphology in self.get_all_morphologies(cell_type):
            points = morphology.compartment_tree.get_arrays()[0]
            if len(points) == 0:
                # `np.max` raises on a zero-size array; an empty morphology
                # cannot extend the search radius anyway.
                continue
            distances = np.sqrt(np.sum(np.power(points, 2), axis=1))
            max_radius = max(max_radius, np.max(distances))
        return max_radius