import abc as _abc
import itertools
import typing
from graphlib import TopologicalSorter
from . import _util as _gutil
from .reporting import warn
from .storage._chunks import Chunk
if typing.TYPE_CHECKING:
from .services import JobPool
def _queue_placement(self, pool: "JobPool", chunk_size):
    """
    Queue the whole placement strategy as one placement job.

    Installed as ``queue`` on placement strategies that mix in :class:`NotParallel`.
    A single job is queued for ``Chunk([0, 0, 0], None)``. Its ``deps`` are every
    queued job of the strategies this one depends on.

    :param self: The placement strategy being queued.
    :param pool: Job pool to queue the job on.
    :param chunk_size: Not used by this implementation.
    """
    # Gather each upstream strategy's submissions first, then flatten them.
    submissions = [pool.get_submissions_of(upstream) for upstream in self.get_deps()]
    deps = set(itertools.chain.from_iterable(submissions))
    # TODO: perhaps pass the volume or partition boundaries as chunk size
    zero_chunk = Chunk([0, 0, 0], None)
    pool.queue_placement(self, zero_chunk, deps=deps)
def _all_chunks(iter_):
    """
    Collect the chunks populated by any of the given cell types.

    :param iter_: Iterable of cell types; each must provide ``get_placement_set()``
        whose result provides ``get_all_chunks()``.
    :returns: The flattened chunk lists, deduplicated with ``_gutil.unique``.
    """
    chunk_lists = (
        cell_type.get_placement_set().get_all_chunks() for cell_type in iter_
    )
    flattened = _gutil.ichain(chunk_lists)
    return _gutil.unique(flattened)
def _queue_connectivity(self, pool: "JobPool"):
    """
    Queue the whole connection strategy as one connectivity job.

    Installed as ``queue`` on connection strategies that mix in
    :class:`NotParallel`. All chunks of the presynaptic and of the postsynaptic
    cell types are passed to a single ``pool.queue_connectivity`` call. Its
    ``deps`` are every queued job of the strategies this one depends on.

    :param self: The connection strategy being queued.
    :param pool: Job pool to queue the job on.
    """
    # Get the queued jobs of all the strategies we depend on.
    deps = set(_gutil.ichain(pool.get_submissions_of(strat) for strat in self.get_deps()))
    # Schedule all chunks in 1 job.
    pre_chunks = _all_chunks(self.presynaptic.cell_types)
    post_chunks = _all_chunks(self.postsynaptic.cell_types)
    pool.queue_connectivity(self, pre_chunks, post_chunks, deps=deps)
def _raise_na(*args, **kwargs):
raise NotImplementedError("NotParallel connection strategies have no RoI.")
class HasDependencies:
    """
    Mixin class to mark that this node may depend on other nodes.

    Subclasses provide :meth:`get_deps` plus ``__lt__`` and ``__hash__``.
    :meth:`sort_deps` needs ``__hash__`` to put nodes in sets/dicts and ``__lt__``
    to order nodes that become ready at the same time.

    Note: ``abc.abstractmethod`` only blocks instantiation when the concrete class's
    metaclass is ``abc.ABCMeta``; this mixin does not set a metaclass itself.
    """

    @_abc.abstractmethod
    def get_deps(self):
        """Return the nodes this node depends on."""

    @_abc.abstractmethod
    def __lt__(self, other):
        class_name = type(self).__name__
        raise NotImplementedError(f"{class_name} must implement __lt__.")

    @_abc.abstractmethod
    def __hash__(self):
        class_name = type(self).__name__
        raise NotImplementedError(f"{class_name} must implement __hash__.")

    @classmethod
    def sort_deps(cls, objects):
        """
        Order ``objects`` so that every object comes after its dependencies.

        Dependencies are read from each object's :meth:`get_deps`. Dependencies
        that are not themselves part of ``objects`` are ignored. Objects that become
        ready at the same time are ordered among themselves with ``sorted`` (their
        ``__lt__``). Duplicates in ``objects`` are collapsed.

        :param objects: Iterable of hashable nodes implementing this mixin.
        :returns: A new list containing the ordered objects.
        :raises graphlib.CycleError: If the dependencies form a cycle.
        """
        candidates = set(objects)
        graph = {}
        for node in candidates:
            # Only keep edges to nodes that are part of this sort.
            graph[node] = {dep for dep in node.get_deps() if dep in candidates}
        sorter = TopologicalSorter(graph)
        sorter.prepare()
        result = []
        while sorter.is_active():
            ready = sorter.get_ready()
            result += sorted(ready)
            sorter.done(*ready)
        return result
class NotParallel:
    """
    Mixin that makes a placement or connectivity strategy queue all its work as a
    single job.

    On subclass creation, ``queue`` is replaced:

    * placement strategies get a ``queue`` that submits one placement job;
    * connection strategies get a ``queue`` that submits one connectivity job over
      all pre- and postsynaptic chunks. They also get a ``get_region_of_interest``
      that raises :class:`NotImplementedError`, unless a class above
      ``ConnectionStrategy`` in the MRO already defines one.

    :raises TypeError: At class creation, if the subclass is neither a placement
        nor a connection strategy.
    """

    def __init_subclass__(cls, **kwargs):
        # Local imports, presumably to avoid a circular import with these modules.
        from .connectivity import ConnectionStrategy
        from .placement import PlacementStrategy

        super().__init_subclass__(**kwargs)
        mro = cls.__mro__
        if PlacementStrategy in mro:
            cls.queue = _queue_placement
        elif ConnectionStrategy in mro:
            cls.queue = _queue_connectivity
            # Look at every class above ConnectionStrategy, not only ``cls`` itself.
            # Otherwise a subclass of a NotParallel strategy would have an inherited
            # get_region_of_interest implementation overwritten by the placeholder.
            has_roi = any(
                "get_region_of_interest" in vars(klass)
                for klass in mro[: mro.index(ConnectionStrategy)]
            )
            if not has_roi:
                cls.get_region_of_interest = _raise_na
        else:
            # TypeError subclasses Exception, so existing handlers still catch it.
            raise TypeError(
                "NotParallel can only be applied to placement or "
                "connectivity strategies"
            )
class InvertedRoI:
    """
    This mixin inverts the perspective of the ``get_region_of_interest`` interface and
    lets you find presynaptic regions of interest for a postsynaptic chunk.

    Usage:

    .. code-block:: python

        class MyConnStrat(InvertedRoI, ConnectionStrategy):
            def get_region_of_interest(self, post_chunk):
                return [pre_chunk1, pre_chunk2]
    """

    def queue(self, pool: "JobPool"):
        """
        Queue one connectivity job per populated postsynaptic chunk.

        For every chunk populated by a postsynaptic cell type,
        ``get_region_of_interest(chunk)`` supplies the presynaptic chunks.

        * Chunks whose region of interest is empty are skipped.
        * A ``None`` result is kept and passed unchanged to
          ``pool.queue_connectivity``.
        * Each job's ``deps`` are the queued jobs of this strategy's dependencies.
        * If no chunk is kept, a warning is emitted and nothing is queued.

        :param pool: Job pool to queue the jobs on.
        """
        # Get the queued jobs of all the strategies we depend on.
        deps = set(
            _gutil.ichain(pool.get_submissions_of(strat) for strat in self.get_deps())
        )
        post_types = self.postsynaptic.cell_types
        # Iterate over each chunk that is populated by our postsynaptic cell types.
        to_chunks = set(
            _gutil.ichain(ct.get_placement_set().get_all_chunks() for ct in post_types)
        )
        # Keep chunks whose RoI is None or non-empty; drop empty RoIs.
        rois = {
            chunk: roi
            for chunk in to_chunks
            if (roi := self.get_region_of_interest(chunk)) is None or len(roi)
        }
        if not rois:
            warn(
                f"No overlap found between {[post.name for post in post_types]} and "
                f"{[pre.name for pre in self.presynaptic.cell_types]} "
                f"in '{self.name}'."
            )
        for chunk, roi in rois.items():
            pool.queue_connectivity(self, roi, [chunk], deps=deps)
# Public API of this module; the queueing helpers above are private.
__all__ = [
    "HasDependencies",
    "InvertedRoI",
    "NotParallel",
]