import abc as _abc
import itertools
import typing
from graphlib import TopologicalSorter

from . import _util as _gutil
from .reporting import warn
from .storage._chunks import Chunk

if typing.TYPE_CHECKING:
    from .services import JobPool


def _queue_placement(self, pool: "JobPool", chunk_size):
    """
    Queue a single placement job for the whole strategy.

    Installed as ``queue`` on :class:`NotParallel` placement strategies. The job
    is queued for the chunk ``Chunk([0, 0, 0], None)`` and waits on every job
    already submitted for the strategies this one depends on.

    :param pool: pool where the job will be queued
    :type pool: bsb.services.pool.JobPool
    :param chunk_size: not used by this implementation.
    """
    # Get the queued jobs of all the strategies we depend on.
    deps = set(
        _gutil.ichain(pool.get_submissions_of(strat) for strat in self.get_deps())
    )
    # todo: perhaps pass the volume or partition boundaries as chunk size
    pool.queue_placement(self, Chunk([0, 0, 0], None), deps=deps)


def _all_chunks(iter_):
    """
    Collect the chunks occupied by the placement sets of the given cell types.

    :param iter_: iterable of cell types; each must provide ``get_placement_set()``.
    :returns: the chunks of all placement sets, deduplicated by ``_gutil.unique``.
    """
    return _gutil.unique(
        _gutil.ichain(ct.get_placement_set().get_all_chunks() for ct in iter_)
    )


def _queue_connectivity(self, pool: "JobPool"):
    """
    Queue a single connectivity job covering all pre- and postsynaptic chunks.

    Installed as ``queue`` on :class:`NotParallel` connection strategies. The job
    waits on every job already submitted for the strategies this one depends on.

    :param pool: pool where the job will be queued
    :type pool: bsb.services.pool.JobPool
    """
    # Get the queued jobs of all the strategies we depend on.
    deps = set(
        _gutil.ichain(pool.get_submissions_of(strat) for strat in self.get_deps())
    )
    # Schedule all chunks in 1 job
    pre_chunks = _all_chunks(self.presynaptic.cell_types)
    post_chunks = _all_chunks(self.postsynaptic.cell_types)
    pool.queue_connectivity(self, pre_chunks, post_chunks, deps=deps)


def _raise_na(*args, **kwargs):
    """
    Stand-in ``get_region_of_interest`` for NotParallel connection strategies.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError("NotParallel connection strategies have no RoI.")
class HasDependencies:
    """
    Mixin class to mark that this node may depend on other nodes.

    Classes using this mixin must provide:

    * ``get_deps()`` -- the nodes this node depends on (used by
      :meth:`sort_deps`);
    * ``__lt__`` -- orders nodes whose dependencies are satisfied at the same
      step;
    * ``__hash__`` -- nodes are stored in sets and used as graph keys.
    """

    # NOTE: this class does not use ``abc.ABCMeta``, so ``abstractmethod`` is not
    # enforced at instantiation; these fallbacks raise when actually called.
    @_abc.abstractmethod
    def __lt__(self, other):
        raise NotImplementedError(f"{type(self).__name__} must implement __lt__.")

    @_abc.abstractmethod
    def __hash__(self):
        raise NotImplementedError(f"{type(self).__name__} must implement __hash__.")

    @classmethod
    def sort_deps(cls, objects):
        """
        Order ``objects`` so that every object comes after its dependencies.

        Dependencies are read from each object's ``get_deps()``. Dependencies
        that are not themselves part of ``objects`` are ignored. Objects that
        become ready at the same step of the topological sort are ordered among
        themselves with :func:`sorted`, i.e. via ``__lt__``.

        :param objects: iterable of objects to order; duplicates are collapsed.
        :returns: list of the objects in dependency order.
        :raises graphlib.CycleError: if the dependencies form a cycle.
        """
        objects = set(objects)
        ordered = []
        sorter = TopologicalSorter(
            {o: {d for d in o.get_deps() if d in objects} for o in objects}
        )
        sorter.prepare()
        while sorter.is_active():
            node_group = sorter.get_ready()
            # Sort each ready group so ties are broken by ``__lt__`` rather than
            # by set iteration order.
            ordered.extend(sorted(node_group))
            sorter.done(*node_group)
        return ordered
class NotParallel:
    """
    Mixin that makes a placement or connection strategy queue a single job.

    When a subclass is created, its ``queue`` method is replaced:

    * placement strategies get ``_queue_placement``;
    * connection strategies get ``_queue_connectivity``. If the subclass does
      not define ``get_region_of_interest`` itself, it also gets one that
      raises :class:`NotImplementedError`.

    :raises TypeError: at class creation, if the subclass is neither a
        placement nor a connection strategy.
    """

    def __init_subclass__(cls, **kwargs):
        # Imported at call time rather than module level; NOTE(review): presumably
        # to avoid a circular import with the strategy modules -- confirm.
        from .connectivity import ConnectionStrategy
        from .placement import PlacementStrategy

        super().__init_subclass__(**kwargs)
        if PlacementStrategy in cls.__mro__:
            cls.queue = _queue_placement
        elif ConnectionStrategy in cls.__mro__:
            cls.queue = _queue_connectivity
            # Only check the class's own namespace, so an RoI defined by this
            # subclass is kept while an inherited one is replaced.
            if "get_region_of_interest" not in cls.__dict__:
                cls.get_region_of_interest = _raise_na
        else:
            # TypeError (a subclass of Exception) signals misuse of the mixin
            # while remaining catchable by existing ``except Exception`` code.
            raise TypeError(
                "NotParallel can only be applied to placement or "
                "connectivity strategies"
            )
class InvertedRoI:
    """
    This mixin inverts the perspective of the ``get_region_of_interest``
    interface and lets you find presynaptic regions of interest for a
    postsynaptic chunk.

    Usage:

    .. code-block:: python

        class MyConnStrat(InvertedRoI, ConnectionStrategy):
            def get_region_of_interest(self, post_chunk):
                return [pre_chunk1, pre_chunk2]
    """

    def queue(self, pool):
        """
        Queue one connectivity job per postsynaptic chunk.

        For every chunk populated by the postsynaptic cell types,
        ``get_region_of_interest(chunk)`` gives the presynaptic chunks. A job
        ``(roi, [chunk])`` is queued unless the RoI is empty. A ``None`` RoI
        is passed through as-is. Each job waits on the jobs already submitted
        for the strategies this one depends on. If no chunk is queued, a
        warning is emitted.

        :param pool: pool where the jobs will be queued
        :type pool: bsb.services.pool.JobPool
        """
        # Get the queued jobs of all the strategies we depend on.
        deps = set(
            _gutil.ichain(pool.get_submissions_of(strat) for strat in self.get_deps())
        )
        post_types = self.postsynaptic.cell_types
        # Iterate over each chunk that is populated by our postsynaptic cell types.
        to_chunks = set(
            _gutil.ichain(ct.get_placement_set().get_all_chunks() for ct in post_types)
        )
        # Keep chunks whose RoI is ``None`` or non-empty; chunks with an empty RoI
        # are skipped and get no job.
        rois = {
            chunk: roi
            for chunk in to_chunks
            if (roi := self.get_region_of_interest(chunk)) is None or len(roi)
        }
        if not rois:
            warn(
                f"No overlap found between {[post.name for post in post_types]} and "
                f"{[pre.name for pre in self.presynaptic.cell_types]} "
                f"in '{self.name}'."
            )
        for chunk, roi in rois.items():
            pool.queue_connectivity(self, roi, [chunk], deps=deps)