Source code for mpire.async_result

import collections
import itertools
import queue
import threading
from typing import Any, Callable, Dict, List, Optional, Union

from mpire.comms import EXIT_FUNC, INIT_FUNC

job_counter = itertools.count()


[docs]class AsyncResult: """ Adapted from ``multiprocessing.pool.ApplyResult``. """
[docs] def __init__(self, cache: Dict, callback: Optional[Callable], error_callback: Optional[Callable], job_id: Optional[int] = None, delete_from_cache: bool = True, timeout: Optional[float] = None) -> None: """ :param cache: Cache for storing intermediate results :param callback: Callback function to call when the task is finished. The callback function receives the output of the function as its argument :param error_callback: Callback function to call when the task has failed. The callback function receives the exception as its argument :param job_id: Job ID of the task. If None, a new job ID is generated :param delete_from_cache: If True, the result is deleted from the cache when the task is finished :param timeout: Timeout in seconds for a single task. When the timeout is exceeded, MPIRE will raise a ``TimeoutError``. Use ``None`` to disable (default) """ self._cache = cache self._callback = callback self._error_callback = error_callback self._delete_from_cache = delete_from_cache self._timeout = timeout self.job_id = next(job_counter) if job_id is None else job_id self._ready_event = threading.Event() self._success = None self._value = None if self.job_id in self._cache: raise ValueError(f"Job ID {job_id} already exists in cache") self._cache[self.job_id] = self
[docs] def ready(self) -> bool: """ :return: Returns True if the task is finished """ return self._ready_event.is_set()
[docs] def successful(self) -> bool: """ :return: Returns True if the task has finished successfully :raises: ValueError if the task is not finished yet """ if not self.ready(): raise ValueError(f"{self.job_id} is not ready") return self._success
[docs] def wait(self, timeout: Optional[float] = None) -> None: """ Wait until the task is finished :param timeout: Timeout in seconds. If None, wait indefinitely """ self._ready_event.wait(timeout)
[docs] def get(self, timeout: Optional[float] = None) -> Any: """ Wait until the task is finished and return the output of the function :param timeout: Timeout in seconds. If None, wait indefinitely :return: Output of the function :raises: TimeoutError if the task is not finished within the timeout. When the task has failed, the exception raised by the function is re-raised """ self.wait(timeout) if not self.ready(): raise TimeoutError if self._success: return self._value else: raise self._value
def _set(self, success: bool, result: Any) -> None: """ Set the result of the task and call any callbacks, when provided. This also removes the task from the cache, as it's no longer needed there. The user should store a reference to the result object :param success: True if the task has finished successfully :param result: Output of the function or the exception raised by the function """ self._success = success self._value = result if self._callback and self._success: self._callback(self._value) if self._error_callback and not self._success: self._error_callback(self._value) self._ready_event.set() if self._delete_from_cache: del self._cache[self.job_id]
class UnorderedAsyncResultIterator: """ Stores results of a task and provides an iterator to obtain the results in an unordered fashion """ def __init__(self, cache: Dict, n_tasks: Optional[int], job_id: Optional[int] = None, timeout: Optional[float] = None) -> None: """ :param cache: Cache for storing intermediate results :param n_tasks: Number of tasks that will be executed. If None, we don't know the lenght yet :param job_id: Job ID of the task. If None, a new job ID is generated :param timeout: Timeout in seconds for a single task. When the timeout is exceeded, MPIRE will raise a ``TimeoutError``. Use ``None`` to disable (default) """ self._cache = cache self._n_tasks = None self._timeout = timeout self.job_id = next(job_counter) if job_id is None else job_id self._items = collections.deque() self._condition = threading.Condition(lock=threading.Lock()) self._n_received = 0 self._n_returned = 0 self._exception = None self._got_exception = threading.Event() if self.job_id in self._cache: raise ValueError(f"Job ID {job_id} already exists in cache") self._cache[self.job_id] = self if n_tasks is not None: self.set_length(n_tasks) def __iter__(self) -> "UnorderedAsyncResultIterator": return self def next(self, block: bool = True, timeout: Optional[float] = None) -> Any: """ Obtain the next unordered result for the task :param block: If True, wait until the next result is available. If False, raise queue.Empty if no result is available :param timeout: Timeout in seconds. If None, wait indefinitely :return: The next result """ if self._items: self._n_returned += 1 return self._items.popleft() if self._n_tasks is not None and self._n_returned == self._n_tasks: raise StopIteration if not block: raise queue.Empty # We still expect results. Wait until the next result is available with self._condition: while not self._items: timed_out = not self._condition.wait(timeout=timeout) if timed_out: raise queue.Empty if self._n_tasks is not None and self._n_returned == self._n_tasks: raise StopIteration self._n_returned += 1 return self._items.popleft() __next__ = next def wait(self) -> None: """ Wait until all results are available """ with self._condition: while self._n_tasks is None or self._n_received < self._n_tasks: self._condition.wait() def _set(self, success: bool, result: Any) -> None: """ Set the result of the task :param success: True if the task has finished successfully :param result: Output of the function or the exception raised by the function """ if success: # Add the result to the queue and notify the iterator self._n_received += 1 self._items.append(result) with self._condition: self._condition.notify() else: self._exception = result self._got_exception.set() def set_length(self, length: int) -> None: """ Set the length of the iterator :param length: Length of the iterator """ if self._n_tasks is not None: if self._n_tasks != length: raise ValueError(f"Length of iterator has already been set to {self._n_tasks}, " f"but is now set to {length}") # Length has already been set. No need to do anything return with self._condition: self._n_tasks = length self._condition.notify() def get_exception(self) -> Exception: """ :return: The exception raised by the function """ self._got_exception.wait() return self._exception def remove_from_cache(self) -> None: """ Remove the iterator from the cache """ del self._cache[self.job_id] class AsyncResultWithExceptionGetter(AsyncResult): def __init__(self, cache: Dict, job_id: int) -> None: super().__init__(cache, callback=None, error_callback=None, job_id=job_id, delete_from_cache=False, timeout=None) def get_exception(self) -> Exception: """ :return: The exception raised by the function """ self.wait() return self._value def reset(self) -> None: """ Reset the result object """ self._success = None self._value = None self._ready_event.clear() class UnorderedAsyncExitResultIterator(UnorderedAsyncResultIterator): def __init__(self, cache: Dict) -> None: super().__init__(cache, n_tasks=None, job_id=EXIT_FUNC, timeout=None) def get_results(self) -> List[Any]: """ :return: List of exit results """ return list(self._items) def reset(self) -> None: """ Reset the result object """ self._n_tasks = None self._items.clear() self._n_received = 0 self._n_returned = 0 self._exception = None self._got_exception.clear() AsyncResultType = Union[AsyncResult, AsyncResultWithExceptionGetter, UnorderedAsyncResultIterator, UnorderedAsyncExitResultIterator]