Source code for bsb.connectivity.strategy

import abc
import typing
from itertools import chain

from .. import config
from .._util import ichain, obj_str_insert
from ..config import refs, types
from ..exceptions import ConnectivityError
from ..mixins import HasDependencies
from ..profiling import node_meter
from ..reporting import warn

if typing.TYPE_CHECKING:
    from ..cell_types import CellType
    from ..core import Scaffold
    from ..morphologies import MorphologySet
    from ..services import JobPool
    from ..storage.interfaces import PlacementSet


@config.node
class Hemitype:
    """
    Class used to represent one (pre- or postsynaptic) side of a connection rule.
    """

    scaffold: "Scaffold"
    cell_types: list["CellType"] = config.reflist(refs.cell_type_ref, required=True)
    """List of cell types to use in the connection."""
    labels: list[str] = config.attr(type=types.list())
    """List of labels to filter the placement set by."""
    morphology_labels: list[str] = config.attr(type=types.list())
    """List of labels to filter the morphologies by."""
    morpho_loader: typing.Callable[["PlacementSet"], "MorphologySet"] = config.attr(
        type=types.function_(),
        required=False,
        call_default=False,
        default=(lambda ps: ps.load_morphologies()),
    )
    """
    Function to load the morphologies (a MorphologySet) from a PlacementSet.

    Overriding it allows temporary, dynamic morphology generation during the
    connectivity phase from a much smaller, or even empty, MorphologySet. This is
    useful, for example, when storing the full set would otherwise take too much
    disk space or time.
    """

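# Illustrative sketch (not part of this module): how a hemitype could be declared
# inside a connectivity block of a configuration, based on the attributes above.
# The cell type name and labels are hypothetical placeholders.
#
#   "presynaptic": {
#       "cell_types": ["granule_cell"],
#       "labels": ["batch_1"],
#       "morphology_labels": ["axon"]
#   }
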
class HemitypeCollection:
    """
    Iterable view on the cell types of a Hemitype, restricted to a region of
    interest (a list of chunks).
    """

    def __init__(self, hemitype, roi):
        self.hemitype = hemitype
        self.roi = roi

    def __iter__(self):
        return iter(self.hemitype.cell_types)

    @property
    def placement(self):
        # One placement set per cell type, limited to the ROI chunks and filtered
        # by the hemitype's cell and morphology labels.
        return [
            ct.get_placement_set(
                chunks=self.roi,
                labels=self.hemitype.labels,
                morphology_labels=self.hemitype.morphology_labels,
            )
            for ct in self.hemitype.cell_types
        ]

@config.dynamic(attr_name="strategy", required=True)
class ConnectionStrategy(abc.ABC, HasDependencies):
    scaffold: "Scaffold"
    name: str = config.attr(key=True)
    """Name used to refer to the connectivity strategy."""
    presynaptic: Hemitype = config.attr(type=Hemitype, required=True)
    """Presynaptic (source) neuron population."""
    postsynaptic: Hemitype = config.attr(type=Hemitype, required=True)
    """Postsynaptic (target) neuron population."""
    depends_on: list["ConnectionStrategy"] = config.reflist(refs.connectivity_ref)
    """The list of strategies that must run before this one."""
    output_naming: typing.Union[
        str,
        None,
        list[str],
        dict[str, dict[str, typing.Union[str, list[str], None]]],
    ] = config.attr(
        type=types.or_(
            types.str(),
            types.dict(
                type=types.dict(
                    type=types.or_(
                        types.str(), types.list(type=types.str()), types.none()
                    )
                )
            ),
            types.list(type=types.str()),
        )
    )
    """
    Specifies how to name the output ConnectivitySets in which the connections
    between cell type pairs are stored.
    """

    def __init_subclass__(cls, **kwargs):
        super(cls, cls).__init_subclass__(**kwargs)
        # Decorate subclasses to measure the performance of their `connect`.
        node_meter("connect")(cls)

    def __hash__(self):
        return id(self)

    def __lt__(self, other):
        # This comparison sorts connection strategies by name, via __repr__ below.
        return str(self) < str(other)

    @obj_str_insert
    def __repr__(self):
        if not hasattr(self, "scaffold"):
            return f"'{self.name}'"
        pre = [ct.name for ct in self.presynaptic.cell_types]
        post = [ct.name for ct in self.postsynaptic.cell_types]
        return f"'{self.name}', connecting {pre} to {post}"

    @abc.abstractmethod
    def connect(self, presyn_collection, postsyn_collection):
        """
        Create the connections between the presynaptic and postsynaptic
        collections of a job. Must be implemented by each strategy.
        """
        pass

    def get_deps(self):
        return set(self.depends_on)

    def _get_connect_args_from_job(self, pre_roi, post_roi):
        pre = HemitypeCollection(self.presynaptic, pre_roi)
        post = HemitypeCollection(self.postsynaptic, post_roi)
        return pre, post

    def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None):
        names = self.get_output_names(pre_set.cell_type, post_set.cell_type)
        between_msg = f"between {pre_set.cell_type.name} and {post_set.cell_type.name}"
        if len(names) == 0:
            raise ConnectivityError(
                f"Connections {between_msg} have been disabled by output naming."
            )
        elif len(names) == 1:
            name = names[0]
            if tag is not None and tag != name:
                raise ConnectivityError(
                    f"Tag ('{tag}') and output name ('{name}') mismatch."
                )
        else:
            names_msg = f"{between_msg} (names: {', '.join(names)})."
            if tag is None:
                raise ConnectivityError(
                    "No tag was given to decide between multiple output names "
                    f"{names_msg}"
                )
            elif tag not in names:
                raise ConnectivityError(
                    f"Tag '{tag}' is not a valid output name {names_msg}"
                )
            else:
                name = tag
        cs = self.scaffold.require_connectivity_set(
            pre_set.cell_type, post_set.cell_type, name
        )
        cs.connect(pre_set, post_set, src_locs, dest_locs)

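    # Illustrative call (a sketch, assuming the (N, 3) location format of
    # [cell index, branch id, point id], with -1 where no morphology point is
    # targeted; `pre_ps`, `post_ps` and the index arrays are hypothetical):
    #
    #   src_locs = np.array([[0, -1, -1], [1, -1, -1]])
    #   dest_locs = np.array([[3, -1, -1], [3, -1, -1]])
    #   self.connect_cells(pre_ps, post_ps, src_locs, dest_locs)
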
    def get_region_of_interest(self, chunk):
        """
        Return the region of interest (a list of chunks) that may contain
        postsynaptic partners for the given presynaptic chunk, or None to leave
        the search unrestricted.
        """
        pass

    def queue(self, pool: "JobPool"):
        """
        Specifies how to queue this connectivity strategy into a job pool.

        Can be overridden; the default implementation queues one connectivity job
        per chunk populated by the presynaptic cell types, together with that
        chunk's region of interest.
        """
        # Get the queued jobs of all the strategies we depend on.
        dep_jobs = set(
            chain.from_iterable(
                pool.get_submissions_of(strat) for strat in self.get_deps()
            )
        )
        pre_types = self.presynaptic.cell_types
        # Iterate over each chunk that is populated by our presynaptic cell types.
        from_chunks = set(
            chain.from_iterable(
                ct.get_placement_set().get_all_chunks() for ct in pre_types
            )
        )
        # Keep chunks whose region of interest is unrestricted (None) or non-empty.
        rois = {
            chunk: roi
            for chunk in from_chunks
            if (roi := self.get_region_of_interest(chunk)) is None or len(roi)
        }
        if not rois:
            warn(
                f"No overlap found between {[pre.name for pre in pre_types]} and "
                f"{[post.name for post in self.postsynaptic.cell_types]} "
                f"in '{self.name}'."
            )
        for chunk, roi in rois.items():
            pool.queue_connectivity(self, [chunk], roi, deps=dep_jobs)

    def get_cell_types(self):
        return set(self.presynaptic.cell_types) | set(self.postsynaptic.cell_types)

    def get_all_pre_chunks(self):
        all_ps = (ct.get_placement_set() for ct in self.presynaptic.cell_types)
        chunks = set(ichain(ps.get_all_chunks() for ps in all_ps))
        return list(chunks)

    def get_all_post_chunks(self):
        all_ps = (ct.get_placement_set() for ct in self.postsynaptic.cell_types)
        chunks = set(ichain(ps.get_all_chunks() for ps in all_ps))
        return list(chunks)

    def get_output_names(self, pre=None, post=None):
        """
        Return the names of the output ConnectivitySets for the given pre/post
        cell type pair, or for all pairs when both arguments are omitted.
        """
        if (pre is None) != (post is None):
            raise RuntimeError("pre and post must be specified or omitted together.")
        if pre is not None and (
            pre not in self.presynaptic.cell_types
            or post not in self.postsynaptic.cell_types
        ):
            raise ValueError(
                f"'{pre.name}' and '{post.name}' are not a valid cell type pair "
                "for this connectivity strategy."
            )
        if self.output_naming is None or isinstance(self.output_naming, str):
            return self._infer_output_name(self.output_naming or self.name, pre, post)
        elif isinstance(self.output_naming, list):
            # Call `_infer_output_name` for each given `base` in the list, and
            # chain the results together.
            return [
                *ichain(
                    self._infer_output_name(base, pre, post)
                    for base in self.output_naming
                )
            ]
        else:
            return self._get_output_name(pre, post)

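    # Naming example (hypothetical cell types A, B and X): with
    # `presynaptic.cell_types = [A, B]`, `postsynaptic.cell_types = [X]` and no
    # `output_naming`, a strategy named "my_conn" infers the output names
    # "my_conn_A_to_X" and "my_conn_B_to_X"; with a single cell type on each
    # side it simply uses "my_conn".
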
    def _infer_output_name(self, base, pre, post):
        if (
            len(self.presynaptic.cell_types) > 1
            or len(self.postsynaptic.cell_types) > 1
        ):
            if pre is None:
                # All output names
                return [
                    *ichain(
                        self._infer_output_name(base, pre_ct, post_ct)
                        for pre_ct in self.presynaptic.cell_types
                        for post_ct in self.postsynaptic.cell_types
                    )
                ]
            else:
                # Pair specific output name
                return [f"{base}_{pre.name}_to_{post.name}"]
        else:
            # Single output name
            return [base]

    def _get_output_name(self, pre, post):
        if pre is None:
            # All output names
            return [
                *ichain(
                    self._get_output_name(pre_ct, post_ct)
                    for pre_ct in self.presynaptic.cell_types
                    for post_ct in self.postsynaptic.cell_types
                )
            ]
        else:
            # Pair specific output name
            MISSING = type("MISSING", (), {"get": lambda *args: MISSING})()
            spec = self.output_naming.get(pre.name, MISSING).get(post.name, MISSING)
            if spec is MISSING:
                return self._infer_output_name(self.name, pre, post)
            elif spec is None:
                # An explicit None disables connections between this pair.
                return []
            elif isinstance(spec, str):
                return [spec]
            else:
                return spec

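# A minimal sketch of a concrete strategy, for illustration only (not part of
# this module). It assumes numpy is available, uses the (N, 3) location format
# noted above, and connects every presynaptic cell to every postsynaptic cell;
# the class name `AllPairsSketch` is hypothetical.
#
#   import numpy as np
#
#   from bsb import config
#   from bsb.connectivity.strategy import ConnectionStrategy
#
#   @config.node
#   class AllPairsSketch(ConnectionStrategy):
#       def connect(self, pre, post):
#           # `pre` and `post` are HemitypeCollection objects limited to the
#           # chunks of this job.
#           for pre_ps in pre.placement:
#               for post_ps in post.placement:
#                   n_pre, n_post = len(pre_ps), len(post_ps)
#                   src = np.full((n_pre * n_post, 3), -1)
#                   dest = np.full((n_pre * n_post, 3), -1)
#                   src[:, 0] = np.repeat(np.arange(n_pre), n_post)
#                   dest[:, 0] = np.tile(np.arange(n_post), n_pre)
#                   self.connect_cells(pre_ps, post_ps, src, dest)
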
__all__ = ["ConnectionStrategy", "Hemitype", "HemitypeCollection"]