Source code for bsb.services.pool

"""
Job pooling module.

Jobs derive from the base :class:`.Job` class which can be put on the queue of a
:class:`.JobPool`. In order to submit themselves to the pool Jobs will
:meth:`~.Job.serialize` themselves into a predefined set of variables::

   job.serialize() -> (job_type, args, kwargs)

* ``job_type`` should be a string that is a class name defined in this module
  (e.g. ``"PlacementJob"``).

* ``args`` and ``kwargs`` are the arguments that will be passed to the static
  ``execute`` handler of that class.

The :meth:`.Job.execute` handler interprets ``args`` and ``kwargs`` and performs the
actual work. The execute handler has access to the scaffold on the MPI process, so it
is best to serialize just the name of the relevant part of the configuration rather
than to pickle the complex objects themselves. For example, :class:`.PlacementJob`
uses the first ``args`` element to store the
:class:`~bsb.placement.strategy.PlacementStrategy` name and then retrieves the
strategy from the scaffold:

.. code-block:: python

    @staticmethod
    def execute(job_owner, args, kwargs):
        name, chunk = args
        placement = job_owner.placement[name]
        indicators = placement.get_indicators()
        return placement.place(chunk, indicators, **kwargs)

A job has a couple of display variables that can be set: ``_cname`` for the
class name, ``_name`` for the job name and ``_c`` for the chunk. These are used
to display what the workers are doing during parallel execution. This is an experimental
API and subject to sudden change in the future.
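
The sketch below shows how these pieces fit together. It is an illustrative
assumption, not the public entry point: the pool id ``0``, the already constructed
``scaffold`` object, and the ``my_task`` function are hypothetical, and in practice
pool construction is handled for you by the framework.

.. code-block:: python

    def my_task(scaffold, factor):
        # Queued functions always receive the scaffold as their first argument
        # (see FunctionJob.execute), followed by the queued args and kwargs.
        return factor * 2

    with JobPool(0, scaffold) as pool:
        job = pool.queue(my_task, args=(21,))
        results = pool.execute(return_results=True)
        # `results` maps each successfully completed job to its unpickled result.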
"""

import abc
import concurrent.futures
import contextlib
import functools
import logging
import pickle
import tempfile
import threading
import typing
import warnings
import zlib
from contextlib import ExitStack
from enum import Enum, auto

import numpy as np
from exceptiongroup import ExceptionGroup

from .._util import obj_str_insert
from ..exceptions import (
    JobCancelledError,
    JobPoolContextError,
    JobPoolError,
    JobSchedulingError,
)
from ._util import ErrorModule, MockModule

if typing.TYPE_CHECKING:
    from mpipool import MPIExecutor


class WorkflowError(ExceptionGroup):
    pass


class JobErroredError(Exception):
    def __init__(self, message, error):
        super().__init__(message)
        self.error = error


class JobStatus(Enum):
    # Job has not been queued yet, waiting for dependencies to resolve.
    PENDING = "pending"
    # Job is on the queue.
    QUEUED = "queued"
    # Job is currently running on a worker.
    RUNNING = "running"
    # Job ran successfully.
    SUCCESS = "success"
    # Job failed (an exception was raised).
    FAILED = "failed"
    # Job was cancelled before it started running.
    CANCELLED = "cancelled"
    # Job was killed for some reason.
    ABORTED = "aborted"


class PoolStatus(Enum):
    # Pool has been initialized and jobs can be scheduled.
    SCHEDULING = "scheduling"
    # Pool started execution.
    EXECUTING = "executing"
    # Pool is closing down.
    CLOSING = "closing"


class PoolProgressReason(Enum):
    POOL_STATUS_CHANGE = auto()
    JOB_ADDED = auto()
    JOB_STATUS_CHANGE = auto()
    MAX_TIMEOUT_PING = auto()


class Workflow:
    def __init__(self, phases: list[str]):
        self._phases = phases
        self._phase = 0

    @property
    def phases(self):
        return [*self._phases]

    @property
    def finished(self):
        return self._phase >= len(self._phases)

    @property
    def phase(self):
        if self.finished:
            return "finished"
        else:
            return self._phases[self._phase]

    def next_phase(self):
        self._phase += 1
        return self.phase


class PoolProgress:
    """
    Class used to report pool progression to listeners.
    """

    def __init__(self, pool: "JobPool", reason: PoolProgressReason):
        self._pool = pool
        self._reason = reason

    @property
    def reason(self):
        return self._reason

    @property
    def workflow(self):
        return self._pool.workflow

    @property
    def jobs(self):
        return self._pool.jobs

    @property
    def status(self):
        return self._pool.status


class PoolJobAddedProgress(PoolProgress):
    def __init__(self, pool: "JobPool", job: "Job"):
        super().__init__(pool, PoolProgressReason.JOB_ADDED)
        self._job = job

    @property
    def job(self):
        return self._job


class PoolJobUpdateProgress(PoolProgress):
    def __init__(self, pool: "JobPool", job: "Job", old_status: "JobStatus"):
        super().__init__(pool, PoolProgressReason.JOB_STATUS_CHANGE)
        self._job = job
        self._old_status = old_status

    @property
    def job(self):
        return self._job

    @property
    def old_status(self):
        return self._old_status

    @property
    def status(self):
        return self._job.status


class PoolStatusProgress(PoolProgress):
    def __init__(self, pool: "JobPool", old_status: PoolStatus):
        super().__init__(pool, PoolProgressReason.POOL_STATUS_CHANGE)
        self._old_status = old_status


class _MissingMPIExecutor(ErrorModule):
    pass


class _MPIPoolModule(MockModule):
    @property
    def MPIExecutor(self) -> type["MPIExecutor"]:
        return _MissingMPIExecutor(
            "This is not a public interface. Use `.services.JobPool` instead."
        )

    def enable_serde_logging(self):
        import mpipool

        mpipool.enable_serde_logging()


_MPIPool = _MPIPoolModule("mpipool")


def dispatcher(pool_id, job_args):
    """
    The dispatcher is the function that gets pickled on main, and unpacked "here" on
    the worker. Through class variables on `JobPool` and the given `pool_id` we can
    find the pool and scaffold object, and the job function to run. Before running a
    job, the cache is checked and any cached items that are no longer required are
    freed.
    """
    job_type, args, kwargs = job_args
    # Get the static job execution handler from this module
    handler = globals()[job_type].execute
    # Get the owning scaffold from the JobPool class variables, which act as a
    # registry.
    owner = JobPool.get_owner(pool_id)
    # Check the pool's cache
    pool = JobPool._pools[pool_id]
    required_cache_items = pool._read_required_cache_items()
    # and free any stale cached items
    free_stale_pool_cache(owner, required_cache_items)
    # Execute the job handler.
    return handler(owner, args, kwargs)


class SubmissionContext:
    """
    Context information on who submitted a certain job.
    """

    def __init__(self, submitter, chunks=None, **kwargs):
        self._submitter = submitter
        self._chunks = chunks
        self._context = kwargs

    @property
    def name(self):
        if hasattr(self._submitter, "get_node_name"):
            name = self._submitter.get_node_name()
        else:
            name = str(self._submitter)
        return name

    @property
    def submitter(self):
        return self._submitter

    @property
    def chunks(self):
        from ..storage._chunks import chunklist

        return chunklist(self._chunks) if self._chunks is not None else None

    @property
    def context(self):
        return {**self._context}

    def __getattr__(self, key):
        if key in self._context:
            return self._context[key]
        else:
            return self.__getattribute__(key)


class Job(abc.ABC):
    """
    Dispatches the execution of a function through a JobPool.
    """

    def __init__(
        self,
        pool,
        submission_context: SubmissionContext,
        args,
        kwargs,
        deps=None,
        cache_items=None,
    ):
        self.pool_id = pool.id
        self._comm = pool._comm
        self._args = args
        self._kwargs = kwargs
        self._deps = set(deps or [])
        self._submit_ctx = submission_context
        self._completion_cbs = []
        self._status = JobStatus.PENDING
        self._future: concurrent.futures.Future | None = None
        self._thread: threading.Thread | None = None
        self._res_file = None
        self._error = None
        self._cache_items: list[int] = [] if cache_items is None else cache_items
        for j in self._deps:
            j.on_completion(self._dep_completed)

    @obj_str_insert
    def __str__(self):
        return self.description

    @property
    def name(self):
        return self._submit_ctx.name

    @property
    def description(self):
        descr = self.name
        if self.context:
            descr += (
                " (" + ", ".join(f"{k}={v}" for k, v in self.context.items()) + ")"
            )
        return descr

    @property
    def submitter(self):
        return self._submit_ctx.submitter

    @property
    def context(self):
        return self._submit_ctx.context

    @property
    def status(self):
        return self._status

    @property
    def result(self):
        try:
            with open(self._res_file, "rb") as f:
                return pickle.load(f)
        except Exception:
            raise JobPoolError(f"Result of {self} is not available.") from None

    @property
    def error(self):
        return self._error

    def serialize(self):
        """
        Convert the job to a (de)serializable representation.
        """
        name = self.__class__.__name__
        # First arg is to find the static `execute` method so that we don't have to
        # serialize any of the job objects themselves but can still use different
        # handlers for different job types.
        return (name, self._args, self._kwargs)
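
    # Illustration with hypothetical values: a PlacementJob scheduled for chunk
    # (0, 0, 0) of a strategy named "example_layer" would serialize to
    # ("PlacementJob", ("example_layer", (0, 0, 0)), {}).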

    @staticmethod
    @abc.abstractmethod
    def execute(job_owner, args, kwargs):
        """
        Job handler.
        """
        pass

    def run(self, timeout=None):
        """
        Execute the job on the current process, in a thread, and return whether the
        job is still running.
        """
        if self._thread is None:

            def target():
                try:
                    # Execute the static handler
                    result = self.execute(self._pool.owner, self._args, self._kwargs)
                except Exception as e:
                    self._future.set_exception(e)
                else:
                    self._future.set_result(result)

            self._thread = threading.Thread(target=target, daemon=True)
            self._thread.start()
        self._thread.join(timeout=timeout)
        if not self._thread.is_alive():
            self._completed()
            return False
        return True

    def on_completion(self, cb):
        self._completion_cbs.append(cb)

    def set_result(self, value):
        dirname = JobPool.get_tmp_folder(self.pool_id)
        try:
            with tempfile.NamedTemporaryFile(
                prefix=dirname + "/", delete=False, mode="wb"
            ) as fp:
                pickle.dump(value, fp)
                self._res_file = fp.name
        except FileNotFoundError as e:
            self.set_exception(e)
        else:
            self.change_status(JobStatus.SUCCESS)

    def set_exception(self, e: Exception):
        self._error = e
        self.change_status(JobStatus.FAILED)

    def _completed(self):
        if self._status != JobStatus.CANCELLED:
            try:
                result = self._future.result()
            except Exception as e:
                self.set_exception(e)
            else:
                self.set_result(result)
        for cb in self._completion_cbs:
            cb(self)

    def _dep_completed(self, dep):
        # Earlier we registered this callback on the completion of our dependencies.
        # When a dep completes we end up here, and we discard it as a dependency as it
        # has finished. If the dep returns an error, remove the job from the pool,
        # since the dependency has failed.
        self._deps.discard(dep)
        if dep._status is not JobStatus.SUCCESS:
            self.cancel("Job killed for dependency failure")
        else:
            # When all our dependencies have been discarded we can queue ourselves.
            # Unless the pool is serial, then the pool itself just runs all jobs in
            # order.
            if not self._deps and self._comm.get_size() > 1:
                # self._pool is set when the pool first tried to enqueue us, but we
                # were still waiting for deps, in the `_enqueue` method below.
                self._enqueue(self._pool)

    def _enqueue(self, pool):
        if not self._deps and self._status is JobStatus.PENDING:
            # Go ahead and submit ourselves to the pool, no dependencies to wait for.
            # The dispatcher is run on the remote worker and unpacks the data required
            # to execute the job contents.
            self.change_status(JobStatus.QUEUED)
            self._future = pool._submit(dispatcher, self.pool_id, self.serialize())
        else:
            # We have unfinished dependencies and should wait until we can enqueue
            # ourselves when our dependencies have all notified us of their completion.
            # Store the reference to the pool though, so later in `_dep_completed` we
            # can call `_enqueue` again ourselves!
            self._pool = pool

    def cancel(self, reason: str | None = None):
        self.change_status(JobStatus.CANCELLED)
        self._error = (
            JobCancelledError() if reason is None else JobCancelledError(reason)
        )
        if self._future and not self._future.cancel():
            warnings.warn(
                f"Could not cancel {self}, the job is already running.", stacklevel=2
            )

    def change_status(self, status: JobStatus):
        old_status = self._status
        self._status = status
        try:
            # Closed pools may have been removed from this map already.
            pool = JobPool._pools[self.pool_id]
        except KeyError:
            pass
        else:
            progress = PoolJobUpdateProgress(pool, self, old_status)
            pool.add_notification(progress)


class PlacementJob(Job):
    """
    Dispatches the execution of a chunk of a placement strategy through a JobPool.
    """

    def __init__(self, pool, strategy, chunk, deps=None):
        args = (strategy.name, chunk)
        context = SubmissionContext(strategy, [chunk])
        cache_items = get_node_cache_items(strategy)
        super().__init__(pool, context, args, {}, deps=deps, cache_items=cache_items)

    @staticmethod
    def execute(job_owner, args, kwargs):
        name, chunk = args
        placement = job_owner.placement[name]
        indicators = placement.get_indicators()
        return placement.place(chunk, indicators, **kwargs)


class ConnectivityJob(Job):
    """
    Dispatches the execution of a chunk of a connectivity strategy through a JobPool.
    """

    def __init__(self, pool, strategy, pre_roi, post_roi, deps=None):
        from ..storage._chunks import chunklist

        args = (strategy.name, pre_roi, post_roi)
        context = SubmissionContext(
            strategy, chunks=chunklist((*(pre_roi or []), *(post_roi or [])))
        )
        cache_items = get_node_cache_items(strategy)
        super().__init__(pool, context, args, {}, deps=deps, cache_items=cache_items)

    @staticmethod
    def execute(job_owner, args, kwargs):
        name = args[0]
        connectivity = job_owner.connectivity[name]
        collections = connectivity._get_connect_args_from_job(*args[1:])
        return connectivity.connect(*collections, **kwargs)


class FunctionJob(Job):
    def __init__(self, pool, f, args, kwargs, deps=None, cache_items=None, **context):
        # Pack the function into the args
        args = (f, args)
        # If no submitter was given, set the function as submitter
        context.setdefault("submitter", f)
        super().__init__(
            pool,
            SubmissionContext(**context),
            args,
            kwargs,
            deps=deps,
            cache_items=cache_items,
        )

    @staticmethod
    def execute(job_owner, args, kwargs):
        # Unpack the function from the args
        f, args = args
        return f(job_owner, *args, **kwargs)


class JobPool:
    _next_pool_id = 0
    _pools = {}
    _pool_owners = {}
    _tmp_folders = {}

    def __init__(self, id, scaffold, fail_fast=False, workflow: "Workflow" = None):
        self._schedulers: list[concurrent.futures.Future] = []
        self.id: int = id
        self._scaffold = scaffold
        self._comm = scaffold._comm
        self._unhandled_errors = []
        self._running_futures: list[concurrent.futures.Future] = []
        self._mpipool: MPIExecutor | None = None
        self._job_queue: list[Job] = []
        self._listeners = []
        self._max_wait = 60
        self._status: PoolStatus = None
        self._progress_notifications: list[PoolProgress] = []
        self._workers_raise_unhandled = False
        self._fail_fast = fail_fast
        self._workflow = workflow
        self._cache_buffer = np.zeros(1000, dtype=np.uint64)
        self._cache_window = self._comm.window(self._cache_buffer)

    def __enter__(self):
        self._context = ExitStack()
        tmp_dirname = self._context.enter_context(tempfile.TemporaryDirectory())
        JobPool._pool_owners[self.id] = self._scaffold
        JobPool._pools[self.id] = self
        JobPool._tmp_folders[self.id] = tmp_dirname
        del self._scaffold
        for listener in self._listeners:
            # Pass if listener is not a context manager
            with contextlib.suppress(TypeError, AttributeError):
                self._context.enter_context(listener)
        self.change_status(PoolStatus.SCHEDULING)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._context.__exit__(exc_type, exc_val, exc_tb)
        # Clean up pool/job references
        self._job_queue = []
        del JobPool._pools[self.id]
        del JobPool._pool_owners[self.id]
        del JobPool._tmp_folders[self.id]
        self.id = None

    def add_listener(self, listener, max_wait=None):
        self._max_wait = min(self._max_wait, max_wait or float("+inf"))
        self._listeners.append(listener)

    @property
    def workflow(self):
        return self._workflow

    @property
    def status(self):
        return self._status

    @property
    def jobs(self) -> list[Job]:
        return [*self._job_queue]

    @property
    def parallel(self):
        return self._comm.get_size() > 1

    @classmethod
    def get_owner(cls, id):
        return cls._pool_owners[id]

    @classmethod
    def get_tmp_folder(cls, id):
        return cls._tmp_folders[id]

    @property
    def owner(self):
        return self.get_owner(self.id)

    def is_main(self):
        return self._comm.get_rank() == 0

    def get_submissions_of(self, submitter):
        return [job for job in self._job_queue if job.submitter is submitter]

    def _put(self, job):
        """
        Puts a job onto our internal queue.
        """
        if self._mpipool and not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            self.add_notification(PoolJobAddedProgress(self, job))
            self._job_queue.append(job)
            if self._mpipool:
                # This job was scheduled after the MPIPool was opened, so immediately
                # put it on the MPIPool's queue.
                job._enqueue(self)

    def _submit(self, fn, *args, **kwargs):
        if not self._mpipool or not self._mpipool.open:
            raise JobPoolError("No job pool available for job submission.")
        else:
            future = self._mpipool.submit(fn, *args, **kwargs)
            self._running_futures.append(future)
            return future

    def _schedule(self, future: concurrent.futures.Future, nodes, scheduler):
        _failed_nodes = []
        if not future.set_running_or_notify_cancel():
            return
        try:
            for node in nodes:
                failed_deps = [
                    n for n in getattr(node, "depends_on", []) if n in _failed_nodes
                ]
                if failed_deps:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(
                        node,
                        error=JobSchedulingError(
                            f"Depends on {failed_deps}, which failed."
                        ),
                    )
                    self._unhandled_errors.append(ctx)
                    continue
                try:
                    scheduler(node)
                except Exception as e:
                    _failed_nodes.append(node)
                    ctx = SubmissionContext(node, error=e)
                    self._unhandled_errors.append(ctx)
        finally:
            future.set_result(None)

    def schedule(self, nodes, scheduler=None):
        if scheduler is None:

            def scheduler(node):
                node.queue(self)

        future = concurrent.futures.Future()
        self._schedulers.append(future)
        thread = threading.Thread(
            target=self._schedule, args=(future, nodes, scheduler)
        )
        thread.start()

    @property
    def scheduling(self):
        return any(not f.done() for f in self._schedulers)

    def queue(self, f, args=None, kwargs=None, deps=None, **context):
        job = FunctionJob(self, f, args or [], kwargs or {}, deps, [], **context)
        self._put(job)
        return job

    def queue_placement(self, strategy, chunk, deps=None):
        job = PlacementJob(self, strategy, chunk, deps)
        self._put(job)
        return job

    def queue_connectivity(self, strategy, pre_roi, post_roi, deps=None):
        job = ConnectivityJob(self, strategy, pre_roi, post_roi, deps)
        self._put(job)
        return job

    def execute(self, return_results=False):
        """
        Execute the jobs in the queue.

        In serial execution this runs all the jobs in the queue in First In First Out
        order. In parallel execution this enqueues all jobs into the MPIPool unless
        they have dependencies that need to complete first.
        """
        if self.id is None:
            raise JobPoolContextError("Job pools must use a context manager.")
        if self.parallel:
            self._execute_parallel()
        else:
            self._execute_serial()
        if return_results:
            return {
                job: job.result
                for job in self._job_queue
                if job.status == JobStatus.SUCCESS
            }

    def _execute_parallel(self):
        import bsb.options

        # Enable full mpipool debugging
        if bsb.options.debug_pool:
            _MPIPool.enable_serde_logging()
        # Create the MPI pool
        self._mpipool = _MPIPool.MPIExecutor(
            loglevel=logging.DEBUG if bsb.options.debug_pool else logging.CRITICAL
        )

        if self._mpipool.is_worker():
            # The workers will return out of the pool constructor when they receive
            # the shutdown signal from the master, they return here skipping the
            # master logic.
            # Check if we need to abort our process due to errors etc.
            abort = self._comm.bcast(None)
            if abort:
                raise WorkflowError(
                    "Unhandled exceptions during parallel execution.",
                    [JobPoolError("See main node logs for details.")],
                )
            # Free all cached items
            free_stale_pool_cache(self.owner, set())
            return

        try:
            # Tell the listeners execution is running
            self.change_status(PoolStatus.EXECUTING)
            # Kickstart the workers with the queued jobs
            for job in self._job_queue:
                job._enqueue(self)
            # Add the scheduling futures to the running futures, to await them.
            self._running_futures.extend(self._schedulers)
            # Start tracking cached items
            self._update_cache_window()

            # Keep executing as long as any of the schedulers or jobs aren't done yet.
            while self.scheduling or any(
                job.status == JobStatus.PENDING or job.status == JobStatus.QUEUED
                for job in self._job_queue
            ):
                try:
                    done, not_done = concurrent.futures.wait(
                        self._running_futures,
                        timeout=self._max_wait,
                        return_when="FIRST_COMPLETED",
                    )
                except ValueError:
                    # Sometimes a ValueError is raised here, perhaps because we modify
                    # the list below?
                    continue
                # Complete any jobs that are done
                for job in self._job_queue:
                    if job._future in done:
                        job._completed()
                # If a job finished, update the required cache items
                if len(done):
                    self._update_cache_window()
                # Remove running futures that are done
                for future in done:
                    self._running_futures.remove(future)
                # If nothing finished, post a timeout notification.
                if not len(done):
                    self.ping()
                # Notify all the listeners, and store/raise any unhandled errors
                self.notify()
            # Notify listeners that execution is over
            self.change_status(PoolStatus.CLOSING)
            # Raise any unhandled errors
            self.raise_unhandled()
        except:
            # If any exception (including SystemExit and KeyboardInterrupt) happens on
            # main we should broadcast the abort to all worker nodes.
            self._workers_raise_unhandled = True
            raise
        finally:
            # Shut down our internal pool
            self._mpipool.shutdown(wait=False, cancel_futures=True)
            # Broadcast whether the worker nodes should raise an unhandled error.
            self._comm.bcast(self._workers_raise_unhandled)

    def _execute_serial(self):
        # Wait for jobs to finish scheduling
        while concurrent.futures.wait(
            self._schedulers, timeout=self._max_wait, return_when="FIRST_COMPLETED"
        )[1]:
            self.ping()
            self.notify()
        # Prepare jobs for local execution
        for job in self._job_queue:
            job._future = concurrent.futures.Future()
            job._pool = self
            if job.status != JobStatus.CANCELLED and job.status != JobStatus.ABORTED:
                job._status = JobStatus.QUEUED
            else:
                job._future.cancel()
        self.change_status(PoolStatus.EXECUTING)
        # Just run each job serially
        for job in self._job_queue:
            if not job._future.set_running_or_notify_cancel():
                continue
            job.change_status(JobStatus.RUNNING)
            self.notify()
            while job.run(timeout=self._max_wait):
                self.ping()
                self.notify()
            # After each job, check if any cache items can be freed.
            free_stale_pool_cache(self.owner, self.get_required_cache_items())
            self.notify()
        # Raise any unhandled errors
        self.raise_unhandled()
        self.change_status(PoolStatus.CLOSING)

    def change_status(self, status: PoolStatus):
        old_status = self._status
        self._status = status
        self.add_notification(PoolStatusProgress(self, old_status))
        self.notify()

    def add_notification(self, notification: PoolProgress):
        self._progress_notifications.append(notification)

    def ping(self):
        self.add_notification(PoolProgress(self, PoolProgressReason.MAX_TIMEOUT_PING))

    def notify(self):
        for notification in self._progress_notifications:
            job = getattr(notification, "job", None)
            job_error = getattr(job, "error", None)
            has_error = (
                job_error is not None and type(job_error) is not JobCancelledError
            )
            handled_error = [
                bool(listener(notification)) for listener in self._listeners
            ]
            if has_error and not any(handled_error):
                self._unhandled_errors.append(job)
                if self._fail_fast:
                    self.raise_unhandled()
        self._progress_notifications = []

    def raise_unhandled(self):
        if not self._unhandled_errors:
            return
        errors = []
        # Raise and catch for nicer traceback
        for job in self._unhandled_errors:
            try:
                if isinstance(job, SubmissionContext):
                    raise JobSchedulingError(
                        f"{job.name} failed to schedule its jobs."
                    ) from job.context["error"]
                raise JobErroredError(f"{job} failed", job.error) from job.error
            except (JobErroredError, JobSchedulingError) as e:
                errors.append(e)
        self._unhandled_errors = []
        raise WorkflowError(
            "Your workflow encountered errors.",
            errors,
        )

    def get_required_cache_items(self):
        """
        Returns the set of cache item ids still required by the pending, queued, or
        running jobs in the queue.

        :return: set of cache item ids
        :rtype: set[int]
        """
        items = set()
        for job in self._job_queue:
            if (
                job.status == JobStatus.QUEUED
                or job.status == JobStatus.PENDING
                or job.status == JobStatus.RUNNING
            ):
                items.update(job._cache_items)
        return items

    def _update_cache_window(self):
        """
        Checks whether the cache buffer needs to be updated, based on the job statuses
        in the job queue, and updates it if so. Only call on main.
        """
        # Create a new cache window buffer
        new_buffer = np.zeros(1000, dtype=int)
        for i, item in enumerate(self.get_required_cache_items()):
            new_buffer[i] = item
        # If there are actual cache requirement differences, lock the window
        # and transfer the buffer
        if np.any(new_buffer != self._cache_buffer):
            self._cache_window.Lock(0)
            self._cache_buffer[:] = new_buffer
            self._cache_window.Unlock(0)

    def _read_required_cache_items(self):
        """
        Locks the cache window and reads the still required cache items from rank 0.
        Only call on workers.
        """
        from mpi4py.MPI import UINT64_T

        self._cache_window.Lock(0)
        self._cache_window.Get([self._cache_buffer, UINT64_T], 0)
        self._cache_window.Unlock(0)
        return set(self._cache_buffer)


def get_node_cache_items(node):
    return [
        attr.get_pool_cache_id(node)
        for key in dir(node)
        if hasattr(attr := getattr(node, key), "get_pool_cache_id")
    ]


def free_stale_pool_cache(scaffold, required_cache_items: set[int]):
    for stale_key in set(scaffold._pool_cache.keys()) - required_cache_items:
        # Pop each stale item and execute its registered cleanup function.
        scaffold._pool_cache.pop(stale_key)()


def pool_cache(caching_function):
    @functools.cache
    def decorated(self, *args, **kwargs):
        self.scaffold.register_pool_cached_item(
            decorated.get_pool_cache_id(self), cleanup
        )
        return caching_function(self, *args, **kwargs)

    def get_pool_cache_id(node):
        if not hasattr(node, "get_node_name"):
            raise RuntimeError(
                "Pool caching can only be used on methods of @node decorated classes."
            )
        return _cache_hash(f"{node.get_node_name()}.{caching_function.__name__}")

    def cleanup():
        decorated.cache_clear()

    decorated.get_pool_cache_id = get_pool_cache_id
    return decorated


def _cache_hash(string):
    return zlib.crc32(string.encode())
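

# Illustrative sketch (not part of the module): how ``pool_cache`` is meant to be
# applied. Any object exposing ``get_node_name`` and a ``scaffold`` with a
# ``_pool_cache`` dict and ``register_pool_cached_item`` will do; in practice that is
# a @node decorated configuration class attached to a scaffold. ``DemoScaffold`` and
# ``DemoNode`` below are hypothetical stand-ins used only to show the mechanics.
class DemoScaffold:
    def __init__(self):
        self._pool_cache = {}

    def register_pool_cached_item(self, id, cleanup):
        # Mirrors the id-to-cleanup registry that ``free_stale_pool_cache`` consumes.
        self._pool_cache[id] = cleanup


class DemoNode:
    scaffold = DemoScaffold()

    def get_node_name(self):
        return "placement.example_layer"

    @pool_cache
    def load_dataset(self):
        # Expensive work; the result is cached until ``free_stale_pool_cache`` invokes
        # the registered cleanup, which clears the functools cache.
        return list(range(1000))


# Usage of the sketch above: the first call computes and caches the result, repeated
# calls reuse it, and freeing the pool cache clears it again.
#
#     node = DemoNode()
#     data = node.load_dataset()
#     free_stale_pool_cache(node.scaffold, required_cache_items=set())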