Source code for bsb.services.pool

"""
Job pooling module.

Jobs derive from the base :class:`.Job` class and can be put on the queue of a
:class:`.JobPool`. In order to submit themselves to the pool, jobs
:meth:`~.Job.serialize` themselves into a predefined set of variables::

   job.serialize() -> (job_type, args, kwargs)

* ``job_type`` should be a string that is a class name defined in this module
  (e.g. ``"PlacementJob"``).

* ``args`` and ``kwargs`` are the arguments to be passed to that job type's static
  :meth:`~.Job.execute` handler; a serialized job is sketched below.
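
For example (a purely hypothetical illustration; the strategy name is made up), a
:class:`.PlacementJob` queued for a placement strategy named ``"example_placement"``
on some chunk object ``chunk`` serializes to:

.. code-block:: python

    ("PlacementJob", ("example_placement", chunk), {})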

The static :meth:`.Job.execute` handler interprets ``args`` and ``kwargs`` and runs
the actual work. The handler has access to the scaffold on the MPI process, so it is
best to serialize just the name of the relevant part of the configuration, rather
than trying to pickle complex objects. For example, :class:`.PlacementJob` stores the
name of the :class:`~bsb.placement.strategy.PlacementStrategy` in the first ``args``
element and retrieves the strategy from the scaffold:

.. code-block:: python

    @staticmethod
    def execute(job_owner, args, kwargs):
        name, chunk = args
        placement = job_owner.placement[name]
        indicators = placement.get_indicators()
        return placement.place(chunk, indicators, **kwargs)

A job has a couple of display variables that can be set: ``_cname`` for the
class name, ``_name`` for the job name and ``_c`` for the chunk. These are used
to display what the workers are doing during parallel execution. This is an experimental
API and subject to sudden change in the future.
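
As a minimal usage sketch (assuming a ``scaffold`` object is at hand; the task
function and its return value are made up for illustration):

.. code-block:: python

    def task(scaffold):
        # Runs through the pool; receives the pool's owner (the scaffold).
        return "done"

    with JobPool(scaffold) as pool:
        job = pool.queue(task)
        results = pool.execute(return_results=True)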

"""

import abc
import concurrent.futures
import logging
import pickle
import tempfile
import threading
import typing
import warnings
from contextlib import ExitStack
from enum import Enum, auto

from exceptiongroup import ExceptionGroup

from .._util import obj_str_insert
from ..exceptions import (
    JobCancelledError,
    JobPoolContextError,
    JobPoolError,
    JobSchedulingError,
)
from . import MPI
from ._util import ErrorModule, MockModule

if typing.TYPE_CHECKING:
    from mpipool import MPIExecutor



class WorkflowError(ExceptionGroup):
    pass


class JobErroredError(Exception):
    def __init__(self, message, error):
        super().__init__(message)
        self.error = error


class JobStatus(Enum):
    # Job has not been queued yet, waiting for dependencies to resolve.
    PENDING = "pending"
    # Job is on the queue.
    QUEUED = "queued"
    # Job is currently running on a worker.
    RUNNING = "running"
    # Job ran successfully.
    SUCCESS = "success"
    # Job failed (an exception was raised).
    FAILED = "failed"
    # Job was cancelled before it started running.
    CANCELLED = "cancelled"
    # Job was killed for some reason.
    ABORTED = "aborted"


class PoolStatus(Enum):
    # Pool has been initialized and jobs can be scheduled.
    SCHEDULING = "scheduling"
    # Pool started execution.
    EXECUTING = "executing"
    # Pool is closing down.
    CLOSING = "closing"


class PoolProgressReason(Enum):
    POOL_STATUS_CHANGE = auto()
    JOB_ADDED = auto()
    JOB_STATUS_CHANGE = auto()
    MAX_TIMEOUT_PING = auto()


class Workflow:
    def __init__(self, phases: list[str]):
        self._phases = phases
        self._phase = 0

    @property
    def phases(self):
        return [*self._phases]

    @property
    def finished(self):
        return self._phase >= len(self._phases)

    @property
    def phase(self):
        if self.finished:
            return "finished"
        else:
            return self._phases[self._phase]

    def next_phase(self):
        self._phase += 1
        return self.phase


class PoolProgress:
    """
    Class used to report pool progression to listeners.
    """

    def __init__(self, pool: "JobPool", reason: PoolProgressReason):
        self._pool = pool
        self._reason = reason

    @property
    def reason(self):
        return self._reason

    @property
    def workflow(self):
        return self._pool.workflow

    @property
    def jobs(self):
        return self._pool.jobs

    @property
    def status(self):
        return self._pool.status


class PoolJobAddedProgress(PoolProgress):
    def __init__(self, pool: "JobPool", job: "Job"):
        super().__init__(pool, PoolProgressReason.JOB_ADDED)
        self._job = job

    @property
    def job(self):
        return self._job


class PoolJobUpdateProgress(PoolProgress):
    def __init__(self, pool: "JobPool", job: "Job", old_status: "JobStatus"):
        super().__init__(pool, PoolProgressReason.JOB_STATUS_CHANGE)
        self._job = job
        self._old_status = old_status

    @property
    def job(self):
        return self._job

    @property
    def old_status(self):
        return self._old_status

    @property
    def status(self):
        return self._job.status


class PoolStatusProgress(PoolProgress):
    def __init__(self, pool: "JobPool", old_status: PoolStatus):
        super().__init__(pool, PoolProgressReason.POOL_STATUS_CHANGE)
        self._old_status = old_status


class _MissingMPIExecutor(ErrorModule):
    pass


class _MPIPoolModule(MockModule):
    @property
    def MPIExecutor(self) -> typing.Type["MPIExecutor"]:
        return _MissingMPIExecutor(
            "This is not a public interface. Use `.services.JobPool` instead."
        )

    def enable_serde_logging(self):
        import mpipool

        mpipool.enable_serde_logging()


_MPIPool = _MPIPoolModule("mpipool")


def dispatcher(pool_id, job_args):
    job_type, args, kwargs = job_args
    # Get the static job execution handler from this module.
    handler = globals()[job_type].execute
    owner = JobPool.get_owner(pool_id)
    # Execute it.
    return handler(owner, args, kwargs)


class SubmissionContext:
    """
    Context information on who submitted a certain job.
    """

    def __init__(self, submitter, chunks=None, **kwargs):
        self._submitter = submitter
        self._chunks = chunks
        self._context = kwargs

    @property
    def name(self):
        if hasattr(self._submitter, "get_node_name"):
            name = self._submitter.get_node_name()
        else:
            name = str(self._submitter)
        return name

    @property
    def submitter(self):
        return self._submitter

    @property
    def chunks(self):
        from ..storage._chunks import chunklist

        return chunklist(self._chunks) if self._chunks is not None else None

    @property
    def context(self):
        return {**self._context}

    def __getattr__(self, key):
        if key in self._context:
            return self._context[key]
        else:
            return self.__getattribute__(key)


class Job(abc.ABC):
    """
    Dispatches the execution of a function through a JobPool.
    """

    def __init__(
        self, pool, submission_context: SubmissionContext, args, kwargs, deps=None
    ):
        self.pool_id = pool.id
        self._args = args
        self._kwargs = kwargs
        self._deps = set(deps or [])
        self._submit_ctx = submission_context
        self._completion_cbs = []
        self._status = JobStatus.PENDING
        self._future: typing.Optional[concurrent.futures.Future] = None
        self._thread: typing.Optional[threading.Thread] = None
        self._res_file = None
        self._error = None
        for j in self._deps:
            j.on_completion(self._dep_completed)

    @obj_str_insert
    def __str__(self):
        return self.description

    @property
    def name(self):
        return self._submit_ctx.name

    @property
    def description(self):
        descr = self.name
        if self.context:
            descr += " (" + ", ".join(f"{k}={v}" for k, v in self.context.items()) + ")"
        return descr

    @property
    def submitter(self):
        return self._submit_ctx.submitter

    @property
    def context(self):
        return self._submit_ctx.context

    @property
    def status(self):
        return self._status

    @property
    def result(self):
        try:
            with open(self._res_file, "rb") as f:
                return pickle.load(f)
        except Exception:
            raise JobPoolError(f"Result of {self} is not available.") from None

    @property
    def error(self):
        return self._error

    def serialize(self):
        """
        Convert the job to a (de)serializable representation.
        """
        name = self.__class__.__name__
        # The first element is used to look up the static `execute` method, so that we
        # don't have to serialize any of the job objects themselves, but can still use
        # different handlers for different job types.
        return (name, self._args, self._kwargs)

    @staticmethod
    @abc.abstractmethod
    def execute(job_owner, args, kwargs):
        """
        Job handler.
        """
        pass

    def run(self, timeout=None):
        """
        Execute the job on the current process, in a thread, and return whether the
        job is still running.
        """
        if self._thread is None:

            def target():
                try:
                    # Execute the static handler.
                    result = self.execute(self._pool.owner, self._args, self._kwargs)
                except Exception as e:
                    self._future.set_exception(e)
                else:
                    self._future.set_result(result)

            self._thread = threading.Thread(target=target, daemon=True)
            self._thread.start()
        self._thread.join(timeout=timeout)
        if not self._thread.is_alive():
            self._completed()
            return False
        return True

    def on_completion(self, cb):
        self._completion_cbs.append(cb)

    def set_result(self, value):
        dirname = JobPool.get_tmp_folder(self.pool_id)
        try:
            with tempfile.NamedTemporaryFile(
                prefix=dirname + "/", delete=False, mode="wb"
            ) as fp:
                pickle.dump(value, fp)
                self._res_file = fp.name
        except FileNotFoundError as e:
            self.set_exception(e)
        else:
            self.change_status(JobStatus.SUCCESS)

    def set_exception(self, e: Exception):
        self._error = e
        self.change_status(JobStatus.FAILED)

    def _completed(self):
        if self._status != JobStatus.CANCELLED:
            try:
                result = self._future.result()
            except Exception as e:
                self.set_exception(e)
            else:
                self.set_result(result)
        for cb in self._completion_cbs:
            cb(self)

    def _dep_completed(self, dep):
        # Earlier we registered this callback on the completion of our dependencies.
        # When a dep completes we end up here and discard it as a dependency, since it
        # has finished. If the dep errored, we cancel this job, since its dependency
        # has failed.
        self._deps.discard(dep)
        if dep._status is not JobStatus.SUCCESS:
            self.cancel("Job killed for dependency failure")
        else:
            # When all our dependencies have been discarded we can queue ourselves.
            # Unless the pool is serial; then the pool itself just runs all jobs in
            # order.
            if not self._deps and MPI.get_size() > 1:
                # self._pool is set when the pool first tried to enqueue us, but we
                # were still waiting for deps, in the `_enqueue` method below.
                self._enqueue(self._pool)

    def _enqueue(self, pool):
        if not self._deps and self._status is not JobStatus.CANCELLED:
            # Go ahead and submit ourselves to the pool; no dependencies to wait for.
            # The dispatcher is run on the remote worker and unpacks the data required
            # to execute the job contents.
            self.change_status(JobStatus.QUEUED)
            self._future = pool._submit(dispatcher, self.pool_id, self.serialize())
        else:
            # We have unfinished dependencies and should wait until we can enqueue
            # ourselves, when our dependencies have all notified us of their
            # completion. Store the reference to the pool though, so later in
            # `_dep_completed` we can call `_enqueue` again ourselves!
            self._pool = pool

    def cancel(self, reason: typing.Optional[str] = None):
        self.change_status(JobStatus.CANCELLED)
        self._error = JobCancelledError() if reason is None else JobCancelledError(reason)
        if self._future:
            if not self._future.cancel():
                warnings.warn(f"Could not cancel {self}, the job is already running.")

    def change_status(self, status: JobStatus):
        old_status = self._status
        self._status = status
        try:
            # Closed pools may have been removed from this map already.
            pool = JobPool._pools[self.pool_id]
        except KeyError:
            pass
        else:
            progress = PoolJobUpdateProgress(pool, self, old_status)
            pool.add_notification(progress)


class PlacementJob(Job):
    """
    Dispatches the execution of a chunk of a placement strategy through a JobPool.
    """

    def __init__(self, pool, strategy, chunk, deps=None):
        args = (strategy.name, chunk)
        context = SubmissionContext(strategy, [chunk])
        super().__init__(pool, context, args, {}, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        name, chunk = args
        placement = job_owner.placement[name]
        indicators = placement.get_indicators()
        return placement.place(chunk, indicators, **kwargs)


class ConnectivityJob(Job):
    """
    Dispatches the execution of a chunk of a connectivity strategy through a JobPool.
    """

    def __init__(self, pool, strategy, pre_roi, post_roi, deps=None):
        from ..storage._chunks import chunklist

        args = (strategy.name, pre_roi, post_roi)
        context = SubmissionContext(
            strategy, chunks=chunklist((*(pre_roi or []), *(post_roi or [])))
        )
        super().__init__(pool, context, args, {}, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        name = args[0]
        connectivity = job_owner.connectivity[name]
        collections = connectivity._get_connect_args_from_job(*args[1:])
        return connectivity.connect(*collections, **kwargs)


class FunctionJob(Job):
    def __init__(self, pool, f, args, kwargs, deps=None, **context):
        # Pack the function into the args.
        args = (f, args)
        # If no submitter was given, set the function as submitter.
        context.setdefault("submitter", f)
        super().__init__(pool, SubmissionContext(**context), args, kwargs, deps=deps)

    @staticmethod
    def execute(job_owner, args, kwargs):
        # Unpack the function from the args.
        f, args = args
        return f(job_owner, *args, **kwargs)


class JobPool:
    _next_pool_id = 0
    _pools = {}
    _pool_owners = {}
    _tmp_folders = {}

    def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None):
        self._schedulers: list[concurrent.futures.Future] = []
        self.id: int = None
        self._scaffold = scaffold
        self._unhandled_errors = []
        self._running_futures: list[concurrent.futures.Future] = []
        self._mpipool: typing.Optional["MPIExecutor"] = None
        self._job_queue: list[Job] = []
        self._listeners = []
        self._max_wait = 60
        self._status: PoolStatus = None
        self._progress_notifications: list["PoolProgress"] = []
        self._workers_raise_unhandled = False
        self._fail_fast = fail_fast
        self._workflow = workflow

    def __enter__(self):
        self._context = ExitStack()
        tmp_dirname = self._context.enter_context(tempfile.TemporaryDirectory())
        self.id = JobPool._next_pool_id
        JobPool._next_pool_id += 1
        JobPool._pool_owners[self.id] = self._scaffold
        JobPool._pools[self.id] = self
        JobPool._tmp_folders[self.id] = tmp_dirname
        del self._scaffold
        for listener in self._listeners:
            try:
                self._context.enter_context(listener)
            except (TypeError, AttributeError):
                # Listener is not a context manager.
                pass
        self.change_status(PoolStatus.SCHEDULING)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._context.__exit__(exc_type, exc_val, exc_tb)
        # Clean up pool/job references.
        self._job_queue = []
        del JobPool._pools[self.id]
        del JobPool._pool_owners[self.id]
        del JobPool._tmp_folders[self.id]
        self.id = None

    def add_listener(self, listener, max_wait=None):
        self._max_wait = min(self._max_wait, max_wait or float("+inf"))
        self._listeners.append(listener)

    @property
    def workflow(self):
        return self._workflow

    @property
    def status(self):
        return self._status

    @property
    def jobs(self) -> list[Job]:
        return [*self._job_queue]

    @property
    def parallel(self):
        return MPI.get_size() > 1

    @classmethod
    def get_owner(cls, id):
        return cls._pool_owners[id]

    @classmethod
    def get_tmp_folder(cls, id):
        return cls._tmp_folders[id]

    @property
    def owner(self):
        return self.get_owner(self.id)

    def is_main(self):
        return MPI.get_rank() == 0

    def get_submissions_of(self, submitter):
        return [job for job in self._job_queue if job.submitter is submitter]

    def _put(self, job):
        """
        Puts a job onto our internal queue.
        """
        if self._mpipool and not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            self.add_notification(PoolJobAddedProgress(self, job))
            self._job_queue.append(job)
            if self._mpipool:
                # This job was scheduled after the MPIPool was opened, so immediately
                # put it on the MPIPool's queue.
                job._enqueue(self)

    def _submit(self, fn, *args, **kwargs):
        if not self._mpipool or not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            future = self._mpipool.submit(fn, *args, **kwargs)
            self._running_futures.append(future)
            return future

    def _schedule(self, future: concurrent.futures.Future, nodes, scheduler):
        _failed_nodes = []
        if not future.set_running_or_notify_cancel():
            return
        try:
            for node in nodes:
                failed_deps = [
                    n for n in getattr(node, "depends_on", []) if n in _failed_nodes
                ]
                if failed_deps:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(
                        node,
                        error=JobSchedulingError(
                            f"Depends on {failed_deps}, which failed."
                        ),
                    )
                    self._unhandled_errors.append(ctx)
                    continue
                try:
                    scheduler(node)
                except Exception as e:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(node, error=e)
                    self._unhandled_errors.append(ctx)
        finally:
            future.set_result(None)

    def schedule(self, nodes, scheduler=None):
        if scheduler is None:

            def scheduler(node):
                node.queue(self)

        future = concurrent.futures.Future()
        self._schedulers.append(future)
        thread = threading.Thread(target=self._schedule, args=(future, nodes, scheduler))
        thread.start()

    @property
    def scheduling(self):
        return any(not f.done() for f in self._schedulers)

    def queue(self, f, args=None, kwargs=None, deps=None, **context):
        job = FunctionJob(self, f, args or [], kwargs or {}, deps, **context)
        self._put(job)
        return job

    def queue_placement(self, strategy, chunk, deps=None):
        job = PlacementJob(self, strategy, chunk, deps)
        self._put(job)
        return job

    def queue_connectivity(self, strategy, pre_roi, post_roi, deps=None):
        job = ConnectivityJob(self, strategy, pre_roi, post_roi, deps)
        self._put(job)
        return job

    def execute(self, return_results=False):
        """
        Execute the jobs in the queue.

        In serial execution this runs all of the jobs in the queue in first-in,
        first-out order. In parallel execution this enqueues all jobs into the
        MPIPool, unless they have dependencies that need to complete first.
        """
        if self.id is None:
            raise JobPoolContextError("Job pools must use a context manager.")
        if self.parallel:
            self._execute_parallel()
        else:
            self._execute_serial()
        if return_results:
            return {
                job: job.result
                for job in self._job_queue
                if job.status == JobStatus.SUCCESS
            }

    def _execute_parallel(self):
        import bsb.options

        # Enable full mpipool debugging.
        if bsb.options.debug_pool:
            _MPIPool.enable_serde_logging()
        # Create the MPI pool.
        self._mpipool = _MPIPool.MPIExecutor(
            loglevel=logging.DEBUG if bsb.options.debug_pool else logging.CRITICAL
        )
        if self._mpipool.is_worker():
            # The workers return out of the pool constructor when they receive the
            # shutdown signal from the master, and end up here, skipping the master
            # logic. Check if we need to abort our process due to errors etc.
            abort = MPI.bcast(None)
            if abort:
                raise WorkflowError(
                    "Unhandled exceptions during parallel execution.",
                    [JobPoolError("See main node logs for details.")],
                )
            return
        try:
            # Tell the listeners execution is running.
            self.change_status(PoolStatus.EXECUTING)
            # Kickstart the workers with the queued jobs.
            for job in self._job_queue:
                job._enqueue(self)
            # Add the scheduling futures to the running futures, to await them.
            self._running_futures.extend(self._schedulers)
            # Keep executing as long as any of the schedulers or jobs aren't done yet.
            while self.scheduling or any(
                job.status == JobStatus.PENDING or job.status == JobStatus.QUEUED
                for job in self._job_queue
            ):
                try:
                    done, not_done = concurrent.futures.wait(
                        self._running_futures,
                        timeout=self._max_wait,
                        return_when="FIRST_COMPLETED",
                    )
                except ValueError:
                    # Sometimes a ValueError is raised here, perhaps because we
                    # modify the list below?
                    continue
                # Complete any jobs that are done.
                for job in self._job_queue:
                    if job._future in done:
                        job._completed()
                # Remove running futures that are done.
                for future in done:
                    self._running_futures.remove(future)
                # If nothing finished, post a timeout notification.
                if not len(done):
                    self.ping()
                # Notify all the listeners, and store/raise any unhandled errors.
                self.notify()
            # Notify listeners that execution is over.
            self.change_status(PoolStatus.CLOSING)
            # Raise any unhandled errors.
            self.raise_unhandled()
        except:
            # If any exception (including SystemExit and KeyboardInterrupt) happens
            # on main, we should broadcast the abort to all worker nodes.
            self._workers_raise_unhandled = True
            raise
        finally:
            # Shut down our internal pool.
            self._mpipool.shutdown(wait=False, cancel_futures=True)
            # Broadcast whether the worker nodes should raise an unhandled error.
            MPI.bcast(self._workers_raise_unhandled)

    def _execute_serial(self):
        # Wait for jobs to finish scheduling.
        while concurrent.futures.wait(
            self._schedulers, timeout=self._max_wait, return_when="FIRST_COMPLETED"
        )[1]:
            self.ping()
            self.notify()
        # Prepare jobs for local execution.
        for job in self._job_queue:
            job._future = concurrent.futures.Future()
            job._pool = self
            if job.status != JobStatus.CANCELLED and job.status != JobStatus.ABORTED:
                job._status = JobStatus.QUEUED
            else:
                job._future.cancel()
        self.change_status(PoolStatus.EXECUTING)
        # Just run each job serially.
        for job in self._job_queue:
            if not job._future.set_running_or_notify_cancel():
                continue
            job.change_status(JobStatus.RUNNING)
            self.notify()
            while job.run(timeout=self._max_wait):
                self.ping()
                self.notify()
            self.notify()
        # Raise any unhandled errors.
        self.raise_unhandled()
        self.change_status(PoolStatus.CLOSING)

    def change_status(self, status: PoolStatus):
        old_status = self._status
        self._status = status
        self.add_notification(PoolStatusProgress(self, old_status))
        self.notify()

    def add_notification(self, notification: PoolProgress):
        self._progress_notifications.append(notification)

    def ping(self):
        self.add_notification(PoolProgress(self, PoolProgressReason.MAX_TIMEOUT_PING))

    def notify(self):
        for notification in self._progress_notifications:
            job = getattr(notification, "job", None)
            job_error = getattr(job, "error", None)
            has_error = job_error is not None and type(job_error) is not JobCancelledError
            handled_error = [bool(listener(notification)) for listener in self._listeners]
            if has_error and not any(handled_error):
                self._unhandled_errors.append(job)
                if self._fail_fast:
                    self.raise_unhandled()
        self._progress_notifications = []

    def raise_unhandled(self):
        if not self._unhandled_errors:
            return
        errors = []
        # Raise and catch for a nicer traceback.
        for job in self._unhandled_errors:
            try:
                if isinstance(job, SubmissionContext):
                    raise JobSchedulingError(
                        f"{job.name} failed to schedule its jobs."
                    ) from job.context["error"]
                raise JobErroredError(f"{job} failed", job.error) from job.error
            except (JobErroredError, JobSchedulingError) as e:
                errors.append(e)
        self._unhandled_errors = []
        raise WorkflowError(
            "Your workflow encountered errors.",
            errors,
        )