import itertools
import os
import sys
import typing
import numpy as np
from ._util import obj_str_insert
from .config._config import Configuration
from .connectivity import ConnectionStrategy
from .exceptions import (
InputError,
MissingActiveConfigError,
NodeNotFoundError,
RedoError,
)
from .placement import PlacementStrategy
from .profiling import meter
from .reporting import report
from .services import MPI, JobPool
from .services._pool_listeners import NonTTYTerminalListener, TTYTerminalListener
from .services.pool import Job, Workflow
from .simulation import get_simulation_adapter
from .storage import Storage, open_storage
from .storage._chunks import Chunk
if typing.TYPE_CHECKING:
from .cell_types import CellType
from .config._config import NetworkNode as Network
from .postprocessing import AfterPlacementHook
from .simulation.simulation import Simulation
from .storage.interfaces import (
ConnectivitySet,
FileStore,
MorphologyRepository,
PlacementSet,
)
from .topology import Partition, Region
[docs]
@meter()
def from_storage(root):
    """
    Load a :class:`.core.Scaffold` from an existing storage object.

    :param root: Root (usually a path) pointing to the storage object.
    :returns: A network scaffold
    :rtype: :class:`Scaffold`
    """
    storage = open_storage(root)
    return storage.load()
# Names of the root configuration sections that are mirrored as read/write
# attributes on :class:`Scaffold`, each proxying to `self.configuration`.
_cfg_props = (
    "network",
    "regions",
    "partitions",
    "cell_types",
    "placement",
    "after_placement",
    "connectivity",
    "after_connectivity",
    "simulations",
)
def _config_property(name):
def fget(self):
return getattr(self.configuration, name)
def fset(self, value):
setattr(self.configuration, name, value)
prop = property(fget)
return prop.setter(fset)
def _get_linked_config(storage=None):
    """
    Find the configuration file linked to a storage object and parse it.

    Falls back to the globally configured config path (``bsb.options.config``)
    when the storage has no active config, or when no storage is given.

    :returns: The freshly parsed configuration, or ``None`` when no linked
        file exists on disk.
    """
    import bsb.config

    try:
        # NOTE: `storage` may be None; the resulting AttributeError is caught
        # below and treated the same as a storage without an active config.
        cfg = storage.load_active_config()
    except Exception:
        import bsb.options

        path = bsb.options.config
    else:
        path = cfg._meta.get("path", None)
    if path and os.path.exists(path):
        # Prefer the file on disk over the stored copy: reparse it.
        with open(path, "r") as f:
            cfg = bsb.config.parse_configuration_file(f)
        return cfg
    else:
        return None
def _bad_flag(flag: bool):
return flag is not None and bool(flag) is not flag
[docs]
class Scaffold:
    """
    This is the main object of the bsb package, it represents a network and puts together
    all the pieces that make up the model description such as the
    :class:`~.config.Configuration` with the technical side like the
    :class:`~.storage.Storage`.
    """

    # Root configuration sections, proxied to `self.configuration` through
    # `_config_property` assignments at the end of the class body.
    network: "Network"
    regions: typing.Dict[str, "Region"]
    partitions: typing.Dict[str, "Partition"]
    cell_types: typing.Dict[str, "CellType"]
    placement: typing.Dict[str, "PlacementStrategy"]
    after_placement: typing.Dict[str, "AfterPlacementHook"]
    connectivity: typing.Dict[str, "ConnectionStrategy"]
    # NOTE(review): annotated with the after-*placement* hook type; presumably
    # a dedicated after-connectivity hook type exists — confirm.
    after_connectivity: typing.Dict[str, "AfterPlacementHook"]
    simulations: typing.Dict[str, "Simulation"]
def __init__(self, config=None, storage=None, clear=False, comm=None):
    """
    Bootstraps a network object.

    :param config: The configuration to use for this network. If it is omitted the
        :ref:`default configuration <default-config>` is used.
    :type config: :class:`~.config.Configuration`
    :param storage: The storage to use to read and write data for this network. If it
        is omitted the configuration's ``Storage`` node is used to construct one.
    :type storage: :class:`~.storage.Storage`
    :param clear: Start with a new network, clearing any previously stored information
    :type clear: bool
    :param comm: MPI communicator to use; defaults to the module's ``MPI`` service.
    :returns: A network object
    :rtype: :class:`~.core.Scaffold`
    """
    # Listeners registered through `register_listener`, as (callable, max_wait) pairs.
    self._pool_listeners: list[tuple[typing.Callable[[list["Job"]], None], float]] = (
        []
    )
    self._configuration = None
    self._storage = None
    self._comm = comm or MPI
    # Resolve config/storage defaults and wire everything together.
    self._bootstrap(config, storage, clear=clear)
def __contains__(self, component):
    """A component is "in" the network when its ``scaffold`` attribute is this object."""
    owner = getattr(component, "scaffold", None)
    return owner is self
@obj_str_insert
def __repr__(self):
    """Summarize the storage location and the cell/connection type counts."""
    root_path = os.path.abspath(self.storage.root)
    type_count = len(self.cell_types)
    conn_count = len(self.connectivity)
    return (
        f"'{root_path}' with {type_count} cell types, "
        f"and {conn_count} connection_types"
    )
[docs]
def is_main_process(self) -> bool:
    """Return whether this process is the main (rank 0) MPI process."""
    rank = MPI.get_rank()
    return not rank
[docs]
def is_worker_process(self) -> bool:
    """Return whether this process is a worker (non-zero MPI rank)."""
    rank = MPI.get_rank()
    return bool(rank)
def _bootstrap(self, config, storage, clear=False):
    """
    Resolve the config and storage defaults and wire scaffold, config and
    storage together. Called once from ``__init__``; statement order matters.
    """
    if config is None:
        # No config given, check for linked configs, or stored configs, otherwise
        # make default config.
        linked = _get_linked_config(storage)
        if linked:
            report(f"Pulling configuration from linked {linked}.", level=2)
            config = linked
        elif storage is not None:
            try:
                config = storage.load_active_config()
            except MissingActiveConfigError:
                config = Configuration.default()
        else:
            config = Configuration.default()
    if not storage:
        # No storage given, create one from the config's storage node.
        report("Creating storage from config.", level=4)
        storage = Storage(config.storage.engine, config.storage.root)
    if clear:
        # Storage given, but asked to clear it before use.
        storage.remove()
        storage.create()
    # Synchronize the scaffold, config and storage objects for use together
    self._configuration = config
    # Make sure the storage config node reflects the storage we are using
    config._update_storage_node(storage)
    # Give the scaffold access to the uninitialized storage object (for use during
    # config bootstrapping).
    self._storage = storage
    # First, the scaffold is passed to each config node, and their boot methods called
    self._configuration._bootstrap(self)
    # Then, `storage` is initted for the scaffold, and `config` is stored (happens
    # inside the `storage` property setter).
    self.storage = storage
# Expose the configuration's storage node under a non-clashing name
# (`storage` itself is the Storage object property defined below).
storage_cfg = _config_property("storage")
# Mirror each root configuration section (see `_cfg_props`) as a read/write
# attribute proxying to `self.configuration`.
for attr in _cfg_props:
    vars()[attr] = _config_property(attr)
@property
def configuration(self) -> Configuration:
    # The model description (network, cell types, strategies, ...) of this scaffold.
    return self._configuration

@configuration.setter
def configuration(self, cfg: Configuration):
    # Replacing the configuration re-bootstraps all its nodes against this
    # scaffold and persists the new config in the storage.
    self._configuration = cfg
    cfg._bootstrap(self)
    self.storage.store_active_config(cfg)
@property
def storage(self) -> Storage:
    # The data backend (placement/connectivity sets, morphologies, files).
    return self._storage

@storage.setter
def storage(self, storage: Storage):
    # Assigning a storage initializes it for this scaffold.
    self._storage = storage
    storage.init(self)
@property
def morphologies(self) -> "MorphologyRepository":
    # Shortcut to the storage's morphology repository.
    return self.storage.morphologies

@property
def files(self) -> "FileStore":
    # Shortcut to the storage's file store.
    return self.storage.files
[docs]
def clear(self):
    """
    Wipe and recreate the storage. Any previously stored network data is lost!
    """
    self.storage.renew(self)
[docs]
def clear_placement(self):
    """
    Delete all placement data from the storage.
    """
    self.storage.clear_placement(self)
[docs]
def clear_connectivity(self):
    """
    Delete all connectivity data from the storage.
    """
    self.storage.clear_connectivity()
[docs]
def resize(self, x=None, y=None, z=None):
    """
    Update the topology boundary indicators. Use before placement; this only
    updates the abstract topology tree and does not rescale, prune or
    otherwise alter already existing placement data.
    """
    from .topology._layout import box_layout

    # Only overwrite the axes that were explicitly given.
    for axis, size in (("x", x), ("y", y), ("z", z)):
        if size is not None:
            setattr(self.network, axis, size)
    origin = np.array(self.network.origin)
    extents = [self.network.x, self.network.y, self.network.z]
    self.topology.do_layout(box_layout(self.network.origin, origin + extents))
[docs]
@meter()
def run_placement(self, strategies=None, fail_fast=True, pipelines=True):
    """
    Run placement strategies.

    :param strategies: Strategies to run; defaults to all configured
        placement strategies.
    :param fail_fast: Abort the job pool on the first error.
    :type fail_fast: bool
    :param pipelines: Run the dependency pipelines first.
    :type pipelines: bool
    """
    if pipelines:
        self.run_pipelines()
    if strategies is None:
        strategies = [*self.placement.values()]
    # Order the strategies so dependencies run before their dependents.
    strategies = PlacementStrategy.sort_deps(strategies)
    with self.create_job_pool(fail_fast=fail_fast) as pool:
        if pool.is_main():

            def scheduler(strategy):
                # Queue each strategy's jobs, chunked by the network chunk size.
                strategy.queue(pool, self.network.chunk_size)

            pool.schedule(strategies, scheduler)
        pool.execute()
[docs]
@meter()
def run_connectivity(self, strategies=None, fail_fast=True, pipelines=True):
    """
    Run connection strategies.

    :param strategies: Strategies to run; defaults to all configured
        connectivity strategies.
    :param fail_fast: Abort the job pool on the first error.
    :type fail_fast: bool
    :param pipelines: Run the dependency pipelines first.
    :type pipelines: bool
    """
    if pipelines:
        self.run_pipelines()
    if strategies is None:
        strategies = set(self.connectivity.values())
    # Order the strategies so dependencies run before their dependents.
    strategies = ConnectionStrategy.sort_deps(strategies)
    with self.create_job_pool(fail_fast=fail_fast) as pool:
        if pool.is_main():
            pool.schedule(strategies)
        pool.execute()
[docs]
@meter()
def run_placement_strategy(self, strategy):
    """
    Convenience wrapper to run a single placement strategy.
    """
    self.run_placement(strategies=[strategy])
[docs]
@meter()
def run_after_placement(self, hooks=None, fail_fast=None, pipelines=True):
    """
    Run after placement hooks.

    :param hooks: After-placement hooks to run; defaults to all configured
        ``after_placement`` nodes.
    :param fail_fast: Abort the job pool on the first error.
    :type fail_fast: bool
    :param pipelines: Run the dependency pipelines first.
    :type pipelines: bool
    """
    # FIX: honor the `pipelines` flag like the sibling `run_*` phases do;
    # it was previously accepted but silently ignored.
    if pipelines:
        self.run_pipelines()
    if hooks is None:
        hooks = self.after_placement
    with self.create_job_pool(fail_fast) as pool:
        if pool.is_main():
            pool.schedule(hooks)
        pool.execute()
[docs]
@meter()
def run_after_connectivity(self, hooks=None, fail_fast=None, pipelines=True):
    """
    Run after connectivity hooks.

    :param hooks: After-connectivity hooks to run; defaults to all configured
        ``after_connectivity`` nodes.
    :param fail_fast: Abort the job pool on the first error.
    :type fail_fast: bool
    :param pipelines: Run the dependency pipelines first.
    :type pipelines: bool
    """
    # FIX: honor the `pipelines` flag like the sibling `run_*` phases do.
    if pipelines:
        self.run_pipelines()
    if hooks is None:
        # BUG FIX: previously defaulted to `self.after_placement`, re-running
        # the after-placement hooks instead of the after-connectivity hooks.
        hooks = self.after_connectivity
    with self.create_job_pool(fail_fast) as pool:
        if pool.is_main():
            pool.schedule(hooks)
        pool.execute()
[docs]
@meter()
def compile(
    self,
    skip_placement=False,
    skip_connectivity=False,
    skip_after_placement=False,
    skip_after_connectivity=False,
    only=None,
    skip=None,
    clear=False,
    append=False,
    redo=False,
    force=False,
    fail_fast=True,
):
    """
    Run reconstruction steps in the scaffold sequence to obtain a full network.

    :param skip_placement: Skip the placement phase (likewise for the other
        ``skip_*`` flags and their phases).
    :param only: Names of the only strategies to run.
    :param skip: Names of strategies to skip.
    :param clear: Clear existing placement and connectivity data first.
    :type clear: bool
    :param append: Append to existing data instead of erroring out.
    :type append: bool
    :param redo: Redo the given strategies, wiping and re-deriving any data
        that depends on them (see ``_redo_chain``).
    :type redo: bool
    :param force: Allow ``redo`` to wipe more than explicitly requested.
    :type force: bool
    :param fail_fast: Abort the job pools on the first error.
    :type fail_fast: bool
    :raises InputError: when ``clear``/``redo``/``append`` are not booleans,
        or more than one of them is set.
    :raises FileExistsError: when the storage already exists and none of
        ``clear``/``append``/``redo`` is given.
    """
    existed = self.storage.preexisted
    if skip_placement:
        p_strats = []
    else:
        p_strats = self.get_placement(skip=skip, only=only)
    if skip_connectivity:
        c_strats = []
    else:
        c_strats = self.get_connectivity(skip=skip, only=only)
    todo_list_str = ", ".join(s.name for s in itertools.chain(p_strats, c_strats))
    report(f"Compiling the following strategies: {todo_list_str}", level=2)
    # Guard against truthy non-bool values (e.g. a strategy list) being passed
    # into the mode flags.
    if _bad_flag(clear) or _bad_flag(redo) or _bad_flag(append):
        raise InputError(
            "`clear`, `redo` and `append` are strictly boolean flags. "
            "Pass the strategies to run to the skip/only options instead."
        )
    if sum((bool(clear), bool(redo), bool(append))) > 1:
        raise InputError("`clear`, `redo` and `append` are mutually exclusive.")
    if existed:
        if not (clear or append or redo):
            raise FileExistsError(
                f"The `{self.storage.format}` storage"
                + f" at `{self.storage.root}` already exists. Either move/delete it,"
                + " or pass one of the `clear`, `append` or `redo` arguments"
                + " to pick what to do with the existing data."
            )
        if clear:
            report("Clearing data", level=2)
            # Clear the placement and connectivity data, but leave any cached files
            # and morphologies intact.
            self.clear_placement()
            self.clear_connectivity()
        elif redo:
            # In order to properly redo things, we clear some placement and connection
            # data, but since multiple placement/connection strategies can contribute
            # to the same sets we might be wiping their data too, and they will need
            # to be cleared and reran as well.
            p_strats, c_strats = self._redo_chain(p_strats, c_strats, skip, force)
        # else:
        #   append mode is luckily simpler, just don't clear anything :)
    # Build the workflow phase list so job-pool listeners can show progress.
    phases = ["pipelines"]
    if not skip_placement:
        phases.append("placement")
    if not skip_after_placement:
        phases.append("after_placement")
    if not skip_connectivity:
        phases.append("connectivity")
    if not skip_after_connectivity:
        phases.append("after_connectivity")
    self._workflow = Workflow(phases)
    try:
        self.run_pipelines(fail_fast=fail_fast)
        self._workflow.next_phase()
        if not skip_placement:
            placement_todo = ", ".join(s.name for s in p_strats)
            report(f"Starting placement strategies: {placement_todo}", level=2)
            self.run_placement(p_strats, fail_fast=fail_fast, pipelines=False)
        self._workflow.next_phase()
        if not skip_after_placement:
            self.run_after_placement(pipelines=False, fail_fast=fail_fast)
        self._workflow.next_phase()
        if not skip_connectivity:
            connectivity_todo = ", ".join(s.name for s in c_strats)
            report(f"Starting connectivity strategies: {connectivity_todo}", level=2)
            self.run_connectivity(c_strats, fail_fast=fail_fast, pipelines=False)
        self._workflow.next_phase()
        if not skip_after_connectivity:
            # NOTE(review): unlike the other phases, `fail_fast` is not
            # forwarded here — confirm whether that is intentional.
            self.run_after_connectivity(pipelines=False)
        self._workflow.next_phase()
    finally:
        # After compilation we should flag the storage as having existed before so that
        # the `clear`, `redo` and `append` flags take effect on a second `compile` pass.
        self.storage._preexisted = True
        del self._workflow
[docs]
@meter()
def run_pipelines(self, fail_fast=True, pipelines=None):
    """
    Run the data dependency pipelines (e.g. morphology imports).

    :param fail_fast: Abort the job pool on the first error.
    :type fail_fast: bool
    :param pipelines: Pipelines to run; defaults to the configured
        dependency pipelines.
    """
    todo = self.get_dependency_pipelines() if pipelines is None else pipelines
    with self.create_job_pool(fail_fast=fail_fast) as pool:
        if pool.is_main():
            pool.schedule(todo)
        pool.execute()
[docs]
@meter()
def run_simulation(self, simulation_name: str):
    """
    Run a simulation starting from the default single-instance adapter.

    :param simulation_name: Name of the simulation in the configuration.
    :type simulation_name: str
    """
    sim = self.get_simulation(simulation_name)
    results = get_simulation_adapter(sim.simulator).simulate(sim)
    return results[0]
[docs]
def get_simulation(self, sim_name: str) -> "Simulation":
    """
    Retrieve a simulation configuration node by name.

    :param sim_name: Name of the simulation in the configuration.
    :raises NodeNotFoundError: when no simulation with that name is configured.
    """
    if sim_name in self.simulations:
        return self.configuration.simulations[sim_name]
    simstr = ", ".join(f"'{s}'" for s in self.simulations.keys())
    raise NodeNotFoundError(
        f"Unknown simulation '{sim_name}', choose from: {simstr}"
    )
[docs]
def place_cells(
    self,
    cell_type,
    positions,
    morphologies=None,
    rotations=None,
    additional=None,
    chunk=None,
):
    """
    Place cells inside of the scaffold

    .. code-block:: python

        # Add one granule cell at position 0, 0, 0
        cell_type = scaffold.get_cell_type("granule_cell")
        scaffold.place_cells(cell_type, [[0., 0., 0.]])

    :param cell_type: The type of the cells to place.
    :type cell_type: ~bsb.cell_types.CellType
    :param positions: A collection of xyz positions to place the cells on.
    :type positions: Any `np.concatenate` type of shape (N, 3).
    :param morphologies: Morphologies to assign to the placed cells, appended
        to the placement set as-is.
    :param rotations: Rotations to assign to the placed cells.
    :param additional: Additional per-cell data to store.
    :param chunk: Chunk to store the data in; defaults to chunk (0, 0, 0).
    """
    if chunk is None:
        chunk = Chunk([0, 0, 0], self.network.chunk_size)
    # Backfill missing chunk dimensions from the network's chunk size.
    if hasattr(chunk, "dimensions") and np.any(np.isnan(chunk.dimensions)):
        chunk.dimensions = self.network.chunk_size
    self.get_placement_set(cell_type).append_data(
        chunk,
        positions=positions,
        morphologies=morphologies,
        rotations=rotations,
        additional=additional,
    )
[docs]
def create_entities(self, cell_type, count):
    """
    Create entities in the simulation space.

    Entities differ from cells in that they have no positional data and do
    not influence the placement step, but they do have a representation in
    the connection and simulation steps.

    :param cell_type: The cell type of the entities
    :type cell_type: ~bsb.cell_types.CellType
    :param count: Number of entities to place
    :type count: int
    :todo: Allow `additional` data for entities
    """
    if count == 0:
        return
    placement_set = self.get_placement_set(cell_type)
    # Entity data always lands in the default chunk (0, 0, 0).
    default_chunk = Chunk([0, 0, 0], self.network.chunk_size)
    placement_set.append_entities(default_chunk, count)
[docs]
def get_placement(
    self, cell_types=None, skip=None, only=None
) -> typing.List["PlacementStrategy"]:
    """
    Return the configured placement strategies, optionally filtered by the
    cell types they place and by name (`skip`/`only`).
    """
    if cell_types is not None:
        # Resolve cell type names to their config nodes.
        cell_types = [
            self.cell_types[ct] if isinstance(ct, str) else ct
            for ct in cell_types
        ]
    selected = []
    for name, strategy in self.placement.items():
        if only is not None and name not in only:
            continue
        if skip is not None and name in skip:
            continue
        if cell_types is not None and not any(
            ct in cell_types for ct in strategy.cell_types
        ):
            continue
        selected.append(strategy)
    return selected
[docs]
def get_placement_of(self, *cell_types):
    """
    Find all placement strategies that place the given cell types.

    :param cell_types: Cell types (or their names) of interest.
    :type cell_types: Union[~bsb.cell_types.CellType, str]
    """
    return self.get_placement(cell_types=cell_types)
[docs]
def get_placement_set(
    self, type, chunks=None, labels=None, morphology_labels=None
) -> "PlacementSet":
    """
    Return a cell type's placement set from the output formatter.

    :param type: Cell type, or name of the cell type, whose placement set to load.
    :type type: Union[~bsb.cell_types.CellType, str]
    :param chunks: Chunks to load data from.
    :param labels: Labels to filter the placement set by.
    :type labels: list[str]
    :param morphology_labels: Subcellular labels to apply to the morphologies.
    :type morphology_labels: list[str]
    :returns: A placement set
    :rtype: :class:`~.storage.interfaces.PlacementSet`
    """
    if isinstance(type, str):
        type = self.cell_types[type]
    return self.storage.get_placement_set(
        type, chunks=chunks, labels=labels, morphology_labels=morphology_labels
    )
[docs]
def get_placement_sets(self) -> typing.List["PlacementSet"]:
    """
    Return the placement sets of all cell types in the network.

    :rtype: List[~bsb.storage.interfaces.PlacementSet]
    """
    return [ct.get_placement_set() for ct in self.cell_types.values()]
[docs]
def get_connectivity(
    self, anywhere=None, presynaptic=None, postsynaptic=None, skip=None, only=None
) -> typing.List["ConnectivitySet"]:
    """
    Return the configured connection strategies, optionally filtered by the
    cell types they connect (`anywhere`/`presynaptic`/`postsynaptic`) and by
    name (`skip`/`only`).
    """
    matches = self._connectivity_query(
        any_query=set(self._sanitize_ct(anywhere)),
        pre_query=set(self._sanitize_ct(presynaptic)),
        post_query=set(self._sanitize_ct(postsynaptic)),
    )
    selected = []
    for conn_type in matches:
        if only is not None and conn_type.name not in only:
            continue
        if skip is not None and conn_type.name in skip:
            continue
        selected.append(conn_type)
    return selected
[docs]
def get_connectivity_sets(self) -> typing.List["ConnectivitySet"]:
    """
    Return all connectivity sets from the output formatter, with their pre/post
    cell types resolved against this network's configuration.

    :returns: All connectivity sets
    """
    return [self._load_cs_types(cs) for cs in self.storage.get_connectivity_sets()]
[docs]
def require_connectivity_set(self, pre, post, tag=None) -> "ConnectivitySet":
    """
    Return the connectivity set between `pre` and `post`, creating it in the
    storage if it does not exist yet.

    :param pre: Presynaptic cell type.
    :param post: Postsynaptic cell type.
    :param tag: Optional unique identifier; passed through to the storage engine.
    :rtype: :class:`~.storage.interfaces.ConnectivitySet`
    """
    return self._load_cs_types(
        self.storage.require_connectivity_set(pre, post, tag), pre, post
    )
[docs]
def get_connectivity_set(self, tag=None, pre=None, post=None) -> "ConnectivitySet":
    """
    Return a connectivity set from the output formatter.

    :param tag: Unique identifier of the connectivity set in the output formatter
    :type tag: str
    :param pre: Presynaptic cell type, used together with ``post`` to derive
        the tag when ``tag`` is omitted.
    :param post: Postsynaptic cell type.
    :returns: A connectivity set
    :rtype: :class:`~.storage.interfaces.ConnectivitySet`
    """
    if tag is None:
        # Derive the canonical `<pre>_to_<post>` tag from the cell types.
        try:
            tag = f"{pre.name}_to_{post.name}"
        except Exception:
            raise ValueError("Supply either `tag` or a valid pre and post cell type.")
    connectivity_set = self.storage.get_connectivity_set(tag)
    return self._load_cs_types(connectivity_set, pre, post)
[docs]
def get_cell_types(self) -> typing.List["CellType"]:
    """
    Return a list of all cell types in the network.
    """
    return list(self.configuration.cell_types.values())
[docs]
def merge(self, other, label=None):
    """
    Merge another network into this one. Not implemented yet.
    """
    raise NotImplementedError("Revisit: merge CT, PS & CS, done?")
def _sanitize_ct(self, seq_str_or_none):
    """
    Normalize a cell type argument (None, a name, a cell type, or a sequence
    of either) into a list of cell type nodes.

    :raises NodeNotFoundError: when a name is not a configured cell type.
    """
    if seq_str_or_none is None:
        return []
    if isinstance(seq_str_or_none, str):
        seq_str_or_none = [seq_str_or_none]
    resolved = []
    try:
        for item in seq_str_or_none:
            resolved.append(self.cell_types[item] if isinstance(item, str) else item)
    except KeyError as e:
        raise NodeNotFoundError(f"Cell type `{e.args[0]}` not found.")
    return resolved
def _connectivity_query(
    self, any_query=frozenset(), pre_query=frozenset(), post_query=frozenset()
):
    """
    Filter the network's connection types for any type that satisfies the
    presynaptic, postsynaptic and `any` queries. Empty queries satisfy all
    types. The presynaptic query is satisfied if the conn type contains any
    of the queried cell types presynaptically, and same for post. The any
    query is satisfied if a cell type is found either pre or post.
    """
    # FIX: immutable (frozenset) defaults instead of the mutable-default
    # `=set()` anti-pattern; they are only read, never mutated.

    def partial_query(types, query):
        # An empty query matches every connection type.
        return not query or any(cell_type in query for cell_type in types)

    def query(conn_type):
        pre_match = partial_query(conn_type.presynaptic.cell_types, pre_query)
        post_match = partial_query(conn_type.postsynaptic.cell_types, post_query)
        any_match = partial_query(
            conn_type.presynaptic.cell_types, any_query
        ) or partial_query(conn_type.postsynaptic.cell_types, any_query)
        # BUG FIX: this used `any_match or (pre_match and post_match)`, but an
        # empty query is always "satisfied", so whenever any one query was
        # omitted the `or` made every connection type match and the filter was
        # a no-op. All three constraints must hold simultaneously.
        return any_match and pre_match and post_match

    types = self.connectivity.values()
    return [*filter(query, types)]
def _redo_chain(self, p_strats, c_strats, skip, force):
    """
    Expand the requested redo set with every placement/connectivity strategy
    that shares data with it, then wipe the affected data.

    :param p_strats: Placement strategies requested for redo.
    :param c_strats: Connectivity strategies requested for redo.
    :param skip: Names of strategies the user asked to skip, or None.
    :param force: Allow redoing/wiping more than the user explicitly asked for.
    :returns: The expanded placement and connectivity strategy sets.
    :raises RedoError: when (without `force`) the redo chain includes skipped
        strategies or grows beyond what the user requested.
    """
    p_contrib = set(p_strats)
    while True:
        # Get all the placement strategies that affect the current set of
        # cell types.
        full_wipe = set(itertools.chain(*(ps.cell_types for ps in p_contrib)))
        contrib = set(self.get_placement(full_wipe))
        # Keep repeating until no new contributors are fished up.
        if contrib.issubset(p_contrib):
            break
        # Grow the placement chain
        p_contrib.update(contrib)
    report(
        "Redo-affected placement: " + " ".join(ps.name for ps in p_contrib), level=2
    )
    c_contrib = set(c_strats)
    conn_wipe = full_wipe.copy()
    if full_wipe:
        # Same fixed-point expansion for the connectivity strategies touching
        # any of the wiped cell types.
        while True:
            contrib = set(self.get_connectivity(anywhere=conn_wipe))
            conn_wipe.update(
                itertools.chain(*(ct.get_cell_types() for ct in contrib))
            )
            if contrib.issubset(c_contrib):
                break
            c_contrib.update(contrib)
    report(
        "Redo-affected connectivity: " + " ".join(cs.name for cs in c_contrib),
        level=2,
    )
    # Don't do greedy things without `force`
    if not force:
        # Error if we need to redo things the user asked to skip
        if skip is not None:
            unskipped = [p.name for p in p_contrib if p.name in skip]
            if unskipped:
                chainstr = ", ".join(f"'{s.name}'" for s in (p_strats + c_strats))
                # BUG FIX: `unskipped` holds strategy *names* (strings), so the
                # previous `s.name` join raised AttributeError whenever this
                # error path was hit; join the names directly.
                skipstr = ", ".join(f"'{s}'" for s in unskipped)
                raise RedoError(
                    f"Can't skip {skipstr}. Redoing {chainstr} requires to redo them."
                    + f" Omit {skipstr} from `skip` or use `force` (not recommended)."
                )
        # Error if we need to redo things the user didn't ask for
        for label, chain, og in zip(
            ("placement", "connection"), (p_contrib, c_contrib), (p_strats, c_strats)
        ):
            if len(chain) > len(og):
                new = chain.difference(og)
                raise RedoError(
                    f"Need to redo additional {label} strategies: "
                    + ", ".join(n.name for n in new)
                    + ". Include them or use `force` (not recommended)."
                )
    for ct in full_wipe:
        report(f"Clearing all data of {ct.name}", level=2)
        ct.clear()
    for ct in conn_wipe:
        report(f"Clearing connectivity data of {ct.name}", level=2)
        ct.clear_connections()
    return p_contrib, c_contrib
[docs]
def get_dependency_pipelines(self):
    """
    Return the data dependency pipelines (the configured morphology nodes).
    """
    return list(self.configuration.morphologies)
[docs]
def get_config_diagram(self):
    """
    Return a graphviz (dot) diagram of this network's configuration.
    """
    from .config import make_configuration_diagram

    return make_configuration_diagram(self.configuration)
[docs]
def get_storage_diagram(self):
    """
    Return a graphviz (dot) diagram of the stored placement and connectivity
    sets: one node per placement set, one edge per connectivity set.
    """
    dot = f'digraph "{self.configuration.name or "network"}" {{'
    for ps in self.get_placement_sets():
        dot += f'\n {ps.tag}[label="{ps.tag} ({len(ps)} {ps.cell_type.name})"]'
    for conn in self.get_connectivity_sets():
        dot += f"\n {conn.pre_type.name} -> {conn.post_type.name}"
        dot += f'[label="{conn.tag} ({len(conn)})"];'
    dot += "\n}\n"
    return dot
def _load_cs_types(
    self, cs: "ConnectivitySet", pre=None, post=None
) -> "ConnectivitySet":
    """
    Resolve and attach the pre/post cell type nodes of a connectivity set,
    verifying them against the optionally given `pre`/`post` types.

    :raises ValueError: when a given type does not match the stored type name.
    :raises NodeNotFoundError: when a stored type name is not configured.
    """
    for given, stored in ((pre, cs.pre_type_name), (post, cs.post_type_name)):
        if given and given.name != stored:
            raise ValueError(
                "Given and stored type mismatch:" + f" {given.name} vs {stored}"
            )
    try:
        cs.pre_type = self.cell_types[cs.pre_type_name]
        cs.post_type = self.cell_types[cs.post_type_name]
    except KeyError as e:
        raise NodeNotFoundError(
            f"Couldn't load '{cs.tag}' connections, missing cell type '{e.args[0]}'."
        ) from None
    return cs
[docs]
def create_job_pool(self, fail_fast=None, quiet=False):
    """
    Create a job pool for this network, attaching either the registered pool
    listeners or a terminal progress listener.

    :param fail_fast: Abort the pool on the first error.
    :type fail_fast: bool
    :param quiet: Don't attach the default terminal listener.
    :type quiet: bool
    """
    pool = JobPool(
        self, fail_fast=fail_fast, workflow=getattr(self, "_workflow", None)
    )
    try:
        # Check whether stdout is a TTY, and that it is larger than 0x0
        # (e.g. MPI sets it to 0x0 unless an xterm is emulated).
        tty = os.isatty(sys.stdout.fileno()) and sum(os.get_terminal_size())
    except Exception:
        tty = False
    if tty:
        # Interactive terminal: animated listener capped at `fps` updates/sec.
        fps = 25
        default_listener = TTYTerminalListener(fps)
        default_max_wait = 1 / fps
    else:
        default_listener = NonTTYTerminalListener()
        default_max_wait = None
    if self._pool_listeners:
        # User-registered listeners take precedence over the default one.
        for listener, max_wait in self._pool_listeners:
            pool.add_listener(listener, max_wait=max_wait)
    elif not quiet:
        pool.add_listener(default_listener, max_wait=default_max_wait)
    return pool
[docs]
def register_listener(self, listener, max_wait=None):
    """
    Register a job pool listener to be attached to every pool this network creates.

    :param listener: Callable receiving the pool's job updates.
    :param max_wait: Maximum wait between listener invocations.
    """
    entry = (listener, max_wait)
    self._pool_listeners.append(entry)
[docs]
def remove_listener(self, listener):
    """
    Remove the first registration of a previously registered pool listener.
    """
    for idx, (registered, _) in enumerate(self._pool_listeners):
        if registered is listener:
            del self._pool_listeners[idx]
            break
[docs]
class ReportListener:
    """
    Listener that forwards simulation progress to :func:`.reporting.report`.
    """

    def __init__(self, scaffold, file):
        # Target file and owning network of the reports.
        self.file = file
        self.scaffold = scaffold

    def __call__(self, progress):
        message = "+".join(
            (str(progress.progression), str(progress.duration), str(progress.time))
        )
        report(message, token="simulation_progress")
__all__ = ["ReportListener", "Scaffold", "from_storage"]