"""
Job pooling module.
Jobs derive from the base :class:`.Job` class which can be put on the queue of a
:class:`.JobPool`. In order to submit themselves to the pool Jobs will
:meth:`~.Job.serialize` themselves into a predefined set of variables::

    job.serialize() -> (job_type, args, kwargs)

* ``job_type`` should be a string that is a class name defined in this module.
  (e.g. ``"PlacementJob"``)
* ``args`` and ``kwargs`` are the arguments that are passed to the static ``execute``
  handler of that class. A function to run can be packed into ``args``, as the
  :class:`.FunctionJob` does.

The :meth:`.Job.execute` handler interprets ``args`` and ``kwargs`` to perform the
actual work. The execute handler has access to the scaffold on the MPI
process so one best serializes just the name of some part of the configuration,
rather than trying to pickle the complex objects. For example, the
:class:`.PlacementJob` uses the first ``args`` element to store the
:class:`~bsb.placement.strategy.PlacementStrategy` name and then retrieve it from the
scaffold:

.. code-block:: python

    @staticmethod
    def execute(job_owner, args, kwargs):
        name, chunk = args
        placement = job_owner.placement[name]
        indicators = placement.get_indicators()
        return placement.place(chunk, indicators, **kwargs)

A job has a couple of display variables that can be set: ``_cname`` for the
class name, ``_name`` for the job name and ``_c`` for the chunk. These are used
to display what the workers are doing during parallel execution. This is an experimental
API and subject to sudden change in the future.
"""
import abc
import concurrent.futures
import logging
import pickle
import tempfile
import threading
import typing
import warnings
from contextlib import ExitStack
from enum import Enum, auto
from exceptiongroup import ExceptionGroup
from .._util import obj_str_insert
from ..exceptions import (
JobCancelledError,
JobPoolContextError,
JobPoolError,
JobSchedulingError,
)
from . import MPI
from ._util import ErrorModule, MockModule
if typing.TYPE_CHECKING:
from mpipool import MPIExecutor
[docs]
class WorkflowError(ExceptionGroup):
pass
[docs]
class JobErroredError(Exception):
    """
    Error that reports a failed job.

    :param message: Message of this exception.
    :param error: The underlying error; kept on the ``error`` attribute.
    """

    def __init__(self, message, error):
        self.error = error
        super().__init__(message)
[docs]
class JobStatus(Enum):
    """
    The states a job can be in.
    """

    #: Job has not been queued yet, waiting for dependencies to resolve.
    PENDING = "pending"
    #: Job is on the queue.
    QUEUED = "queued"
    #: Job is currently running on a worker.
    RUNNING = "running"
    #: Job ran successfully.
    SUCCESS = "success"
    #: Job failed (an exception was raised).
    FAILED = "failed"
    #: Job was cancelled before it started running.
    CANCELLED = "cancelled"
    #: Job was killed for some reason.
    ABORTED = "aborted"
[docs]
class PoolStatus(Enum):
    """
    The states a job pool can be in.
    """

    #: Pool has been initialized and jobs can be scheduled.
    SCHEDULING = "scheduling"
    #: Pool started execution.
    EXECUTING = "executing"
    #: Pool is closing down.
    CLOSING = "closing"
[docs]
class PoolProgressReason(Enum):
    """
    The reasons for which a pool progress notification can be created.
    """

    #: The status of the pool changed.
    POOL_STATUS_CHANGE = auto()
    #: A job was added to the pool.
    JOB_ADDED = auto()
    #: The status of a job changed.
    JOB_STATUS_CHANGE = auto()
    #: The pool's maximum wait time passed; used by the pool's ``ping``.
    MAX_TIMEOUT_PING = auto()
[docs]
class Workflow:
    """
    Keeps track of the current phase in an ordered list of named phases.
    """

    def __init__(self, phases: list[str]):
        """
        :param phases: Names of the phases, in the order they are gone through.
        """
        self._phases = phases
        # Index of the current phase in `_phases`.
        self._phase = 0

    @property
    def phases(self):
        """
        A new list with the names of all the phases.
        """
        return list(self._phases)

    @property
    def finished(self):
        """
        Whether the workflow has advanced past its last phase.
        """
        return len(self._phases) <= self._phase

    @property
    def phase(self):
        """
        Name of the current phase, or ``"finished"`` when all phases have passed.
        """
        return "finished" if self.finished else self._phases[self._phase]

    def next_phase(self):
        """
        Advance to the next phase and return its name.
        """
        self._phase += 1
        return self.phase
[docs]
class PoolProgress:
    """
    Class used to report pool progression to listeners.
    """

    def __init__(self, pool: "JobPool", reason: PoolProgressReason):
        """
        :param pool: The pool that is reporting progress.
        :param reason: Why the progress is being reported.
        """
        self._reason = reason
        self._pool = pool

    @property
    def reason(self):
        """
        The reason this progress report was created.
        """
        return self._reason

    @property
    def workflow(self):
        """
        The ``workflow`` of the reporting pool.
        """
        return self._pool.workflow

    @property
    def jobs(self):
        """
        The ``jobs`` of the reporting pool.
        """
        return self._pool.jobs

    @property
    def status(self):
        """
        The ``status`` of the reporting pool.
        """
        return self._pool.status
[docs]
class PoolJobAddedProgress(PoolProgress):
    """
    Progress report with reason ``JOB_ADDED``, carrying the job that was added.
    """

    def __init__(self, pool: "JobPool", job: "Job"):
        """
        :param pool: The pool that is reporting progress.
        :param job: The job that was added.
        """
        super().__init__(pool, PoolProgressReason.JOB_ADDED)
        self._job = job

    @property
    def job(self):
        """
        The job that was added.
        """
        return self._job
[docs]
class PoolJobUpdateProgress(PoolProgress):
    """
    Progress report with reason ``JOB_STATUS_CHANGE``, carrying the job whose status
    changed and the status it had before.
    """

    def __init__(self, pool: "JobPool", job: "Job", old_status: "JobStatus"):
        """
        :param pool: The pool that is reporting progress.
        :param job: The job whose status changed.
        :param old_status: The status of the job before the change.
        """
        super().__init__(pool, PoolProgressReason.JOB_STATUS_CHANGE)
        self._old_status = old_status
        self._job = job

    @property
    def job(self):
        """
        The job whose status changed.
        """
        return self._job

    @property
    def old_status(self):
        """
        The status the job had before the change.
        """
        return self._old_status

    @property
    def status(self):
        """
        The status of the job (not of the pool), read from the job on access.
        """
        return self._job.status
[docs]
class PoolStatusProgress(PoolProgress):
    """
    Progress report with reason ``POOL_STATUS_CHANGE``, carrying the status the pool
    had before the change.
    """

    def __init__(self, pool: "JobPool", old_status: PoolStatus):
        """
        :param pool: The pool that is reporting progress.
        :param old_status: The status of the pool before the change.
        """
        super().__init__(pool, PoolProgressReason.POOL_STATUS_CHANGE)
        self._old_status = old_status

    @property
    def old_status(self):
        """
        The status the pool had before the change.
        """
        return self._old_status
class _MissingMPIExecutor(ErrorModule):
    """
    Object returned by ``_MPIPoolModule.MPIExecutor``.

    NOTE(review): ``ErrorModule`` is defined in ``._util``; presumably it raises the
    message it was given when it is used -- confirm there.
    """
class _MPIPoolModule(MockModule):
    """
    Wrapper around the ``mpipool`` module.

    NOTE(review): ``MockModule`` is defined in ``._util``; presumably the attributes
    below are only the fallbacks for when ``mpipool`` is unavailable -- confirm there.
    """

    @property
    def MPIExecutor(self) -> typing.Type["MPIExecutor"]:
        message = "This is not a public interface. Use `.services.JobPool` instead."
        return _MissingMPIExecutor(message)

    def enable_serde_logging(self):
        """
        Call ``mpipool.enable_serde_logging``.
        """
        # Imported here, so that `mpipool` is only required once this is called.
        import mpipool

        mpipool.enable_serde_logging()
# Module-level `_MPIPoolModule` instance for "mpipool"; `JobPool._execute_parallel`
# obtains `MPIExecutor` and `enable_serde_logging` through it.
_MPIPool = _MPIPoolModule("mpipool")
[docs]
def dispatcher(pool_id, job_args):
    """
    Execute a serialized job with the handler of its job class.

    :param pool_id: Id of the pool the job belongs to, used to look up the pool owner.
    :param job_args: Serialized job, as the tuple ``(job_type, args, kwargs)``.
    :returns: Whatever the static ``execute`` handler of the job class returns.
    """
    job_type, args, kwargs = job_args
    # The job type is the name of a class in this module, look it up by name.
    job_class = globals()[job_type]
    return job_class.execute(JobPool.get_owner(pool_id), args, kwargs)
[docs]
class SubmissionContext:
    """
    Context information on who submitted a certain job.

    Any additional keyword arguments are stored as context and can be read back both
    through :attr:`context` and as attributes of the object.
    """

    def __init__(self, submitter, chunks=None, **kwargs):
        """
        :param submitter: The object that submitted the job.
        :param chunks: Optional chunks that the submission relates to.
        :param kwargs: Arbitrary additional context information.
        """
        self._submitter = submitter
        self._chunks = chunks
        self._context = kwargs

    @property
    def name(self):
        """
        The result of the submitter's ``get_node_name()`` if it has that method,
        otherwise the string representation of the submitter.
        """
        if hasattr(self._submitter, "get_node_name"):
            return self._submitter.get_node_name()
        return str(self._submitter)

    @property
    def submitter(self):
        """
        The object that submitted the job.
        """
        return self._submitter

    @property
    def chunks(self):
        """
        The chunks passed through ``chunklist``, or ``None`` if no chunks were given.
        """
        if self._chunks is None:
            return None
        # Imported here rather than at module level, and only when there are chunks.
        from ..storage._chunks import chunklist

        return chunklist(self._chunks)

    @property
    def context(self):
        """
        A copy of the additional context information.
        """
        return dict(self._context)

    def __getattr__(self, key):
        """
        Expose the context entries as attributes.

        :raises AttributeError: If ``key`` is not a context entry.
        """
        # `__getattr__` is only called after the regular lookup failed. Read the
        # context from the instance dict instead of through `self._context`: on an
        # instance whose `__init__` did not run (as created by `copy` and `pickle`)
        # the latter would call `__getattr__` again and recurse without end.
        try:
            return self.__dict__["_context"][key]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{key}'"
            ) from None
[docs]
class Job(abc.ABC):
    """
    Dispatches the execution of a function through a JobPool.

    Subclasses provide the static :meth:`execute` handler. A job keeps track of the
    jobs it depends on, its status, the future of its execution and the file that its
    pickled result is stored in.
    """

    def __init__(
        self, pool, submission_context: SubmissionContext, args, kwargs, deps=None
    ):
        """
        :param pool: Pool that the job belongs to; its ``id`` is stored on the job.
        :param submission_context: Information on who submitted the job.
        :param args: Positional arguments for the ``execute`` handler.
        :param kwargs: Keyword arguments for the ``execute`` handler.
        :param deps: Jobs that have to complete before this job can be enqueued.
        """
        self.pool_id = pool.id
        self._args = args
        self._kwargs = kwargs
        self._deps = set(deps or [])
        self._submit_ctx = submission_context
        self._completion_cbs = []
        self._status = JobStatus.PENDING
        self._future: typing.Optional[concurrent.futures.Future] = None
        self._thread: typing.Optional[threading.Thread] = None
        self._res_file = None
        self._error = None

        # Have each dependency notify us when it completes.
        for j in self._deps:
            j.on_completion(self._dep_completed)

    @obj_str_insert
    def __str__(self):
        return self.description

    @property
    def name(self):
        """
        The name of the submission context.
        """
        return self._submit_ctx.name

    @property
    def description(self):
        """
        The name of the job, followed by its context as ``(key=value, ...)``, if any.
        """
        descr = self.name
        if self.context:
            descr += " (" + ", ".join(f"{k}={v}" for k, v in self.context.items()) + ")"
        return descr

    @property
    def submitter(self):
        """
        The submitter of the submission context.
        """
        return self._submit_ctx.submitter

    @property
    def context(self):
        """
        The additional context of the submission context.
        """
        return self._submit_ctx.context

    @property
    def status(self):
        """
        The current :class:`.JobStatus` of the job.
        """
        return self._status

    @property
    def result(self):
        """
        The result of the job, unpickled from the result file.

        :raises JobPoolError: If the result could not be loaded for any reason.
        """
        try:
            with open(self._res_file, "rb") as f:
                return pickle.load(f)
        except Exception:
            raise JobPoolError(f"Result of {self} is not available.") from None

    @property
    def error(self):
        """
        The error of the job, or ``None`` if there is none.
        """
        return self._error

    def serialize(self):
        """
        Convert the job to a (de)serializable representation

        :returns: The tuple ``(class name, args, kwargs)``.
        """
        name = self.__class__.__name__
        # First arg is to find the static `execute` method so that we don't have to
        # serialize any of the job objects themselves but can still use different handlers
        # for different job types.
        return (name, self._args, self._kwargs)

    @staticmethod
    @abc.abstractmethod
    def execute(job_owner, args, kwargs):
        """
        Job handler

        :param job_owner: The owner of the pool that the job belongs to.
        :param args: The positional arguments the job was created with.
        :param kwargs: The keyword arguments the job was created with.
        """
        pass

    def run(self, timeout=None):
        """
        Execute the job on the current process, in a thread, and return whether the job is still running.

        The thread is started on the first call. Each call waits for at most
        ``timeout`` seconds (indefinitely if ``None``) for the thread to finish; when
        it has, the job is completed.

        :returns: ``True`` if the job is still running, ``False`` if it finished.
        """
        if self._thread is None:

            def target():
                try:
                    # Execute the static handler
                    result = self.execute(self._pool.owner, self._args, self._kwargs)
                except Exception as e:
                    self._future.set_exception(e)
                else:
                    self._future.set_result(result)

            self._thread = threading.Thread(target=target, daemon=True)
            self._thread.start()
        self._thread.join(timeout=timeout)
        if not self._thread.is_alive():
            self._completed()
            return False
        return True

    def on_completion(self, cb):
        """
        Register a callback that is called with the job when the job completes.
        """
        self._completion_cbs.append(cb)

    def set_result(self, value):
        """
        Pickle the result to a file in the temporary folder of the pool and mark the
        job as successful. If the file can not be created because the folder does not
        exist, the job is marked as failed instead.
        """
        dirname = JobPool.get_tmp_folder(self.pool_id)
        try:
            # `delete=False`, the file has to outlive this block so that `result` can
            # read it back later.
            with tempfile.NamedTemporaryFile(
                dir=dirname, delete=False, mode="wb"
            ) as fp:
                pickle.dump(value, fp)
            self._res_file = fp.name
        except FileNotFoundError as e:
            self.set_exception(e)
        else:
            self.change_status(JobStatus.SUCCESS)

    def set_exception(self, e: Exception):
        """
        Store the error on the job and mark the job as failed.
        """
        self._error = e
        self.change_status(JobStatus.FAILED)

    def _completed(self):
        """
        Transfer the outcome of the future to the job, unless the job was cancelled,
        and call the completion callbacks.
        """
        if self._status != JobStatus.CANCELLED:
            try:
                result = self._future.result()
            except Exception as e:
                self.set_exception(e)
            else:
                self.set_result(result)
        for cb in self._completion_cbs:
            cb(self)

    def _dep_completed(self, dep):
        """
        Completion callback for the dependencies of this job.
        """
        # Earlier we registered this callback on the completion of our dependencies.
        # When a dep completes we end up here and we discard it as a dependency as it
        # has finished. If the dep did not succeed, we cancel ourselves, since the
        # dependency has failed.
        self._deps.discard(dep)
        if dep._status is not JobStatus.SUCCESS:
            self.cancel("Job killed for dependency failure")
        else:
            # When all our dependencies have been discarded we can queue ourselves. Unless the
            # pool is serial, then the pool itself just runs all jobs in order.
            if not self._deps and MPI.get_size() > 1:
                # self._pool is set when the pool first tried to enqueue us, but we were still
                # waiting for deps, in the `_enqueue` method below.
                self._enqueue(self._pool)

    def _enqueue(self, pool):
        """
        Submit the job to the pool, or if the job can't be submitted yet, store the
        pool so that the job can submit itself later.
        """
        if not self._deps and self._status is not JobStatus.CANCELLED:
            # Go ahead and submit ourselves to the pool, no dependencies to wait for
            # The dispatcher is run on the remote worker and unpacks the data required
            # to execute the job contents.
            self.change_status(JobStatus.QUEUED)
            self._future = pool._submit(dispatcher, self.pool_id, self.serialize())
        else:
            # We have unfinished dependencies and should wait until we can enqueue
            # ourselves when our dependencies have all notified us of their completion.
            # Store the reference to the pool though, so later in `_dep_completed` we can
            # call `_enqueue` again ourselves!
            self._pool = pool

    def cancel(self, reason: typing.Optional[str] = None):
        """
        Mark the job as cancelled, and try to cancel its future if it has one.

        :param reason: Optional message for the :class:`.JobCancelledError` that is
          stored as the error of the job.
        """
        self.change_status(JobStatus.CANCELLED)
        self._error = JobCancelledError() if reason is None else JobCancelledError(reason)
        if self._future:
            if not self._future.cancel():
                warnings.warn(f"Could not cancel {self}, the job is already running.")

    def change_status(self, status: JobStatus):
        """
        Set the status of the job and, if the pool of the job is still registered, add
        a job status change notification to that pool.
        """
        old_status = self._status
        self._status = status
        try:
            # Closed pools may have been removed from this map already.
            pool = JobPool._pools[self.pool_id]
        except KeyError:
            pass
        else:
            progress = PoolJobUpdateProgress(pool, self, old_status)
            pool.add_notification(progress)
[docs]
class PlacementJob(Job):
    """
    Dispatches the execution of a chunk of a placement strategy through a JobPool.
    """

    def __init__(self, pool, strategy, chunk, deps=None):
        """
        :param pool: Pool that the job belongs to.
        :param strategy: Placement strategy to run; only its name is put in the args.
        :param chunk: Chunk to run the strategy on.
        :param deps: Jobs that have to complete before this job.
        """
        job_args = (strategy.name, chunk)
        submission = SubmissionContext(strategy, [chunk])
        super().__init__(pool, submission, job_args, {}, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        """
        Look the strategy up by name on the job owner and place the chunk with it.
        """
        strategy_name, chunk = args
        strategy = job_owner.placement[strategy_name]
        indicators = strategy.get_indicators()
        return strategy.place(chunk, indicators, **kwargs)
[docs]
class ConnectivityJob(Job):
    """
    Dispatches the execution of a chunk of a connectivity strategy through a JobPool.
    """

    def __init__(self, pool, strategy, pre_roi, post_roi, deps=None):
        """
        :param pool: Pool that the job belongs to.
        :param strategy: Connectivity strategy to run; only its name is put in the
          args.
        :param pre_roi: Presynaptic region of interest, may be ``None``.
        :param post_roi: Postsynaptic region of interest, may be ``None``.
        :param deps: Jobs that have to complete before this job.
        """
        from ..storage._chunks import chunklist

        job_args = (strategy.name, pre_roi, post_roi)
        # The submission is labelled with the chunks of both regions of interest.
        roi_chunks = chunklist((*(pre_roi or []), *(post_roi or [])))
        submission = SubmissionContext(strategy, chunks=roi_chunks)
        super().__init__(pool, submission, job_args, {}, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        """
        Look the strategy up by name on the job owner, let it turn the remaining args
        into its connect arguments, and connect with them.
        """
        strategy = job_owner.connectivity[args[0]]
        connect_args = strategy._get_connect_args_from_job(*args[1:])
        return strategy.connect(*connect_args, **kwargs)
[docs]
class FunctionJob(Job):
    """
    Dispatches the execution of an arbitrary function through a JobPool.
    """

    def __init__(self, pool, f, args, kwargs, deps=None, **context):
        """
        :param pool: Pool that the job belongs to.
        :param f: Function to execute; called with the job owner, followed by ``args``
          and ``kwargs``.
        :param args: Positional arguments for the function.
        :param kwargs: Keyword arguments for the function.
        :param deps: Jobs that have to complete before this job.
        :param context: Arguments for the :class:`.SubmissionContext` of the job.
        """
        # If no submitter was given, set the function as submitter
        context.setdefault("submitter", f)
        submission = SubmissionContext(**context)
        # The function travels along as the first element of the job args.
        super().__init__(pool, submission, (f, args), kwargs, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        """
        Unpack the function from the args and call it.
        """
        func, func_args = args
        return func(job_owner, *func_args, **kwargs)
[docs]
class JobPool:
    """
    Pool that jobs can be queued on, and that executes them either serially on the
    current process, or in parallel through an MPI pool.

    The pool has to be used as a context manager: on entry it is given an id and
    registered, together with its owner and a temporary folder, in the class level
    registries; on exit it is removed from them again.

    Listeners are callables that are called with each progress notification. When a
    notification concerns a job with an error, a truthy return value of a listener
    marks that error as handled.
    """

    _next_pool_id = 0
    # Class level registries, by pool id.
    _pools = {}
    _pool_owners = {}
    _tmp_folders = {}

    def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None):
        """
        :param scaffold: The owner of the pool, passed to the job handlers.
        :param fail_fast: Raise unhandled errors as soon as they are noticed during
          notification, rather than at the end of the execution.
        :param workflow: Optional workflow, exposed through :attr:`workflow`.
        """
        self._schedulers: list[concurrent.futures.Future] = []
        self.id: typing.Optional[int] = None
        self._scaffold = scaffold
        self._unhandled_errors = []
        self._running_futures: list[concurrent.futures.Future] = []
        self._mpipool: typing.Optional["MPIExecutor"] = None
        self._job_queue: list["Job"] = []
        self._listeners = []
        self._max_wait = 60
        self._status: typing.Optional["PoolStatus"] = None
        self._progress_notifications: list["PoolProgress"] = []
        self._workers_raise_unhandled = False
        self._fail_fast = fail_fast
        self._workflow = workflow

    def __enter__(self):
        self._context = ExitStack()
        tmp_dirname = self._context.enter_context(tempfile.TemporaryDirectory())
        self.id = JobPool._next_pool_id
        JobPool._next_pool_id += 1
        JobPool._pool_owners[self.id] = self._scaffold
        JobPool._pools[self.id] = self
        JobPool._tmp_folders[self.id] = tmp_dirname
        # From here on the owner is looked up in the registry, see `owner`.
        del self._scaffold
        for listener in self._listeners:
            try:
                self._context.enter_context(listener)
            except (TypeError, AttributeError):
                # Listener is not a context manager
                pass
        self.change_status(PoolStatus.SCHEDULING)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._context.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Clean up pool/job references, also when exiting the context raised, so
            # that the class level registries don't keep the pool and its owner alive.
            self._job_queue = []
            JobPool._pools.pop(self.id, None)
            JobPool._pool_owners.pop(self.id, None)
            JobPool._tmp_folders.pop(self.id, None)
            self.id = None

    def add_listener(self, listener, max_wait=None):
        """
        Add a listener to the pool.

        :param listener: Callable that is called with each progress notification.
        :param max_wait: Longest time the listener wants to wait between
          notifications. The pool uses the smallest value among its listeners, and at
          most its initial value of 60.
        """
        self._max_wait = min(self._max_wait, max_wait or float("+inf"))
        self._listeners.append(listener)

    @property
    def workflow(self):
        """
        The workflow the pool was created with.
        """
        return self._workflow

    @property
    def status(self):
        """
        The current :class:`.PoolStatus` of the pool.
        """
        return self._status

    @property
    def jobs(self) -> "list[Job]":
        """
        A new list with the jobs on the queue of the pool.
        """
        return [*self._job_queue]

    @property
    def parallel(self):
        """
        Whether there is more than 1 MPI process.
        """
        return MPI.get_size() > 1

    @classmethod
    def get_owner(cls, id):
        """
        Return the owner that is registered for the pool id.
        """
        return cls._pool_owners[id]

    @classmethod
    def get_tmp_folder(cls, id):
        """
        Return the temporary folder that is registered for the pool id.
        """
        return cls._tmp_folders[id]

    @property
    def owner(self):
        """
        The owner that is registered for this pool.
        """
        return self.get_owner(self.id)

    def is_main(self):
        """
        Whether the current process has MPI rank 0.
        """
        return MPI.get_rank() == 0

    def get_submissions_of(self, submitter):
        """
        Return the jobs on the queue that were submitted by the submitter.
        """
        return [job for job in self._job_queue if job.submitter is submitter]

    def _put(self, job):
        """
        Puts a job onto our internal queue.
        """
        if self._mpipool and not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            self.add_notification(PoolJobAddedProgress(self, job))
            self._job_queue.append(job)
            if self._mpipool:
                # This job was scheduled after the MPIPool was opened, so immediately
                # put it on the MPIPool's queue.
                job._enqueue(self)

    def _submit(self, fn, *args, **kwargs):
        """
        Submit a function call to the MPI pool and return its future.

        :raises JobPoolError: If there is no open MPI pool.
        """
        if not self._mpipool or not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            future = self._mpipool.submit(fn, *args, **kwargs)
            self._running_futures.append(future)
            return future

    def _schedule(self, future: concurrent.futures.Future, nodes, scheduler):
        """
        Call the scheduler on each node, and resolve the future when done.

        Nodes for which the scheduler raises an error, and nodes whose ``depends_on``
        contains such a failed node, are stored as unhandled errors.
        """
        _failed_nodes = []
        if not future.set_running_or_notify_cancel():
            return
        try:
            for node in nodes:
                failed_deps = [
                    n for n in getattr(node, "depends_on", []) if n in _failed_nodes
                ]
                if failed_deps:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(
                        node,
                        error=JobSchedulingError(
                            f"Depends on {failed_deps}, whom failed."
                        ),
                    )
                    self._unhandled_errors.append(ctx)
                    continue
                try:
                    scheduler(node)
                except Exception as e:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(node, error=e)
                    self._unhandled_errors.append(ctx)
        finally:
            future.set_result(None)

    def schedule(self, nodes, scheduler=None):
        """
        Schedule the nodes in a separate thread.

        :param nodes: Nodes to schedule, in order.
        :param scheduler: Callable that is called with each node. By default, the
          ``queue`` method of the node is called with the pool.
        """
        if scheduler is None:

            def scheduler(node):
                node.queue(self)

        future = concurrent.futures.Future()
        self._schedulers.append(future)
        thread = threading.Thread(target=self._schedule, args=(future, nodes, scheduler))
        thread.start()

    @property
    def scheduling(self):
        """
        Whether any of the scheduling threads is still busy.
        """
        return any(not f.done() for f in self._schedulers)

    def queue(self, f, args=None, kwargs=None, deps=None, **context):
        """
        Put a :class:`.FunctionJob` for the function on the queue and return it.
        """
        job = FunctionJob(self, f, args or [], kwargs or {}, deps, **context)
        self._put(job)
        return job

    def queue_placement(self, strategy, chunk, deps=None):
        """
        Put a :class:`.PlacementJob` on the queue and return it.
        """
        job = PlacementJob(self, strategy, chunk, deps)
        self._put(job)
        return job

    def queue_connectivity(self, strategy, pre_roi, post_roi, deps=None):
        """
        Put a :class:`.ConnectivityJob` on the queue and return it.
        """
        job = ConnectivityJob(self, strategy, pre_roi, post_roi, deps)
        self._put(job)
        return job

    def execute(self, return_results=False):
        """
        Execute the jobs in the queue

        In serial execution this runs all of the jobs in the queue in First In First Out
        order. In parallel execution this enqueues all jobs into the MPIPool unless they
        have dependencies that need to complete first.

        :param return_results: Return a dictionary with the result of each successful
          job, by job.
        :raises JobPoolContextError: If the pool is used outside of its context.
        """
        if self.id is None:
            raise JobPoolContextError("Job pools must use a context manager.")
        if self.parallel:
            self._execute_parallel()
        else:
            self._execute_serial()
        if return_results:
            return {
                job: job.result
                for job in self._job_queue
                if job.status == JobStatus.SUCCESS
            }

    def _execute_parallel(self):
        """
        Execute the jobs through the MPI pool.
        """
        import bsb.options

        # Enable full mpipool debugging
        if bsb.options.debug_pool:
            _MPIPool.enable_serde_logging()
        # Create the MPI pool
        self._mpipool = _MPIPool.MPIExecutor(
            loglevel=logging.DEBUG if bsb.options.debug_pool else logging.CRITICAL
        )

        if self._mpipool.is_worker():
            # The workers will return out of the pool constructor when they receive
            # the shutdown signal from the master, they return here skipping the
            # master logic.

            # Check if we need to abort our process due to errors etc.
            abort = MPI.bcast(None)
            if abort:
                raise WorkflowError(
                    "Unhandled exceptions during parallel execution.",
                    [JobPoolError("See main node logs for details.")],
                )
            return

        try:
            # Tell the listeners execution is running
            self.change_status(PoolStatus.EXECUTING)
            # Kickstart the workers with the queued jobs
            for job in self._job_queue:
                job._enqueue(self)
            # Add the scheduling futures to the running futures, to await them.
            self._running_futures.extend(self._schedulers)

            # Keep executing as long as any of the schedulers or jobs aren't done yet.
            while self.scheduling or any(
                job.status == JobStatus.PENDING or job.status == JobStatus.QUEUED
                for job in self._job_queue
            ):
                try:
                    done, not_done = concurrent.futures.wait(
                        self._running_futures,
                        timeout=self._max_wait,
                        return_when=concurrent.futures.FIRST_COMPLETED,
                    )
                except ValueError:
                    # Sometimes a ValueError is raised here, perhaps because we modify
                    # the list below?
                    continue
                # Complete any jobs that are done
                for job in self._job_queue:
                    if job._future in done:
                        job._completed()
                # Remove running futures that are done
                for future in done:
                    self._running_futures.remove(future)
                # If nothing finished, post a timeout notification.
                if not len(done):
                    self.ping()
                # Notify all the listeners, and store/raise any unhandled errors
                self.notify()

            # Notify listeners that execution is over
            self.change_status(PoolStatus.CLOSING)
            # Raise any unhandled errors
            self.raise_unhandled()
        except BaseException:
            # If any exception (including SystemExit and KeyboardInterrupt) happen on main, we should
            # broadcast the abort to all worker nodes.
            self._workers_raise_unhandled = True
            raise
        finally:
            # Shut down our internal pool
            self._mpipool.shutdown(wait=False, cancel_futures=True)
            # Broadcast whether the worker nodes should raise an unhandled error.
            MPI.bcast(self._workers_raise_unhandled)

    def _execute_serial(self):
        """
        Execute the jobs one by one on the current process.
        """
        # Wait for jobs to finish scheduling
        while concurrent.futures.wait(
            self._schedulers,
            timeout=self._max_wait,
            return_when=concurrent.futures.FIRST_COMPLETED,
        )[1]:
            self.ping()
            self.notify()
        # Prepare jobs for local execution
        for job in self._job_queue:
            job._future = concurrent.futures.Future()
            job._pool = self
            if job.status != JobStatus.CANCELLED and job.status != JobStatus.ABORTED:
                job._status = JobStatus.QUEUED
            else:
                job._future.cancel()

        self.change_status(PoolStatus.EXECUTING)
        # Just run each job serially
        for job in self._job_queue:
            # Skip the jobs whose future was cancelled.
            if not job._future.set_running_or_notify_cancel():
                continue
            job.change_status(JobStatus.RUNNING)
            self.notify()
            while job.run(timeout=self._max_wait):
                self.ping()
                self.notify()
            self.notify()
        # Raise any unhandled errors
        self.raise_unhandled()
        self.change_status(PoolStatus.CLOSING)

    def change_status(self, status: "PoolStatus"):
        """
        Set the status of the pool and notify the listeners of the change.
        """
        old_status = self._status
        self._status = status
        self.add_notification(PoolStatusProgress(self, old_status))
        self.notify()

    def add_notification(self, notification: "PoolProgress"):
        """
        Add a notification that is sent to the listeners on the next :meth:`notify`.
        """
        self._progress_notifications.append(notification)

    def ping(self):
        """
        Add a ``MAX_TIMEOUT_PING`` notification.
        """
        self.add_notification(PoolProgress(self, PoolProgressReason.MAX_TIMEOUT_PING))

    def notify(self):
        """
        Send the pending notifications to all the listeners.

        When a notification has a job whose error is not a
        :class:`.JobCancelledError`, and none of the listeners returns a truthy value
        for the notification, the job is stored as an unhandled error, and raised
        immediately if the pool is ``fail_fast``.
        """
        for notification in self._progress_notifications:
            job = getattr(notification, "job", None)
            job_error = getattr(job, "error", None)
            has_error = job_error is not None and type(job_error) is not JobCancelledError
            handled_error = [bool(listener(notification)) for listener in self._listeners]
            if has_error and not any(handled_error):
                # Several pending notifications can refer to the same failed job (e.g.
                # its queued and its failed notification), store the job only once.
                if not any(job is known for known in self._unhandled_errors):
                    self._unhandled_errors.append(job)
                if self._fail_fast:
                    self.raise_unhandled()
        self._progress_notifications = []

    def raise_unhandled(self):
        """
        Raise the unhandled errors, if there are any, and forget about them.

        :raises WorkflowError: Group of a :class:`.JobErroredError` per failed job,
          and a :class:`.JobSchedulingError` per failed scheduling.
        """
        if not self._unhandled_errors:
            return
        errors = []
        # Raise and catch for nicer traceback
        for job in self._unhandled_errors:
            try:
                # Scheduling failures are stored as a submission context, rather than
                # as a job, with the error in their context.
                if isinstance(job, SubmissionContext):
                    raise JobSchedulingError(
                        f"{job.name} failed to schedule its jobs."
                    ) from job.context["error"]
                raise JobErroredError(f"{job} failed", job.error) from job.error
            except (JobErroredError, JobSchedulingError) as e:
                errors.append(e)
        self._unhandled_errors = []
        raise WorkflowError(
            "Your workflow encountered errors.",
            errors,
        )