diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index eb757d7527..f6912d1543 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -444,7 +444,7 @@ class ShuffleType(Enum): ARRAY_RECHUNK = "ArrayRechunk" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class ShuffleRunSpec(Generic[_T_partition_id]): run_id: int = field(init=False, default_factory=partial(next, itertools.count(1))) spec: ShuffleSpec @@ -456,7 +456,7 @@ def id(self) -> ShuffleId: return self.spec.id -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class ShuffleSpec(abc.ABC, Generic[_T_partition_id]): id: ShuffleId disk: bool @@ -491,7 +491,7 @@ def create_run_on_worker( """Create the new shuffle run on the worker.""" -@dataclass(eq=False) +@dataclass(eq=False, slots=True) class SchedulerShuffleState(Generic[_T_partition_id]): run_spec: ShuffleRunSpec participating_workers: set[str] diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 5e54388631..a0e91f4933 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -1107,7 +1107,7 @@ def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class ArrayRechunkSpec(ShuffleSpec[NDIndex]): new: ChunkedAxes old: ChunkedAxes diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index aec5d75947..bb5958d26a 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -327,7 +327,7 @@ def validate_data(self, data: pd.DataFrame) -> None: raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.") -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class DataFrameShuffleSpec(ShuffleSpec[int]): npartitions: int column: str diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3ae0d16e9b..886cae37d9 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -281,6 +281,7 @@ class TaskState: _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet() # Support for weakrefs to a class with __slots__ + # TODO use @dataclass(weakref_slot=True) (requires Python >=3.11) __weakref__: Any = field(init=False) def __post_init__(self) -> None: @@ -339,11 +340,10 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: return {k: v for k, v in out.items() if v and k != "prefix"} -@dataclass +@dataclass(slots=True) class Instruction: """Command from the worker state machine to the Worker, in response to an event""" - __slots__ = ("stimulus_id",) stimulus_id: str @classmethod @@ -399,23 +399,20 @@ def __eq__(self, other: object) -> bool: return all(getattr(other, k) == v for k, v in self.kwargs.items()) -@dataclass +@dataclass(slots=True) class GatherDep(Instruction): - __slots__ = ("worker", "to_gather", "total_nbytes") worker: str to_gather: set[Key] total_nbytes: int -@dataclass +@dataclass(slots=True) class Execute(Instruction): - __slots__ = ("key",) key: Key -@dataclass +@dataclass(slots=True) class RetryBusyWorkerLater(Instruction): - __slots__ = ("worker",) worker: str @@ -432,7 +429,7 @@ def to_dict(self) -> dict[str, Any]: return d -@dataclass +@dataclass(slots=True) class TaskFinishedMsg(SendMessageToScheduler): op = "task-finished" @@ -444,7 +441,6 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] - __slots__ = tuple(__annotations__) def to_dict(self) -> dict[str, Any]: d = super().to_dict() @@ -452,7 +448,7 @@ def to_dict(self) -> dict[str, Any]: return d -@dataclass +@dataclass(slots=True) class TaskErredMsg(SendMessageToScheduler): op = "task-erred" @@ -464,7 +460,6 @@ class TaskErredMsg(SendMessageToScheduler): traceback_text: str thread: int | None startstops: list[StartStop] - __slots__ = tuple(__annotations__) def to_dict(self) -> dict[str, Any]: d = super().to_dict() @@ -489,42 +484,38 @@ def from_task( ) -@dataclass +@dataclass(slots=True) class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" - __slots__ = ("key",) key: Key # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception -@dataclass +@dataclass(slots=True) class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - __slots__ = ("key",) key: Key -@dataclass +@dataclass(slots=True) class LongRunningMsg(SendMessageToScheduler): op = "long-running" - __slots__ = ("key", "run_id", "compute_duration") key: Key run_id: int compute_duration: float | None -@dataclass +@dataclass(slots=True) class AddKeysMsg(SendMessageToScheduler): op = "add-keys" - __slots__ = ("keys",) keys: Collection[Key] -@dataclass +@dataclass(slots=True) class RequestRefreshWhoHasMsg(SendMessageToScheduler): """Worker -> Scheduler asynchronous request for updated who_has information. Not to be confused with the scheduler.who_has synchronous RPC call, which is used @@ -540,11 +531,10 @@ class RequestRefreshWhoHasMsg(SendMessageToScheduler): op = "request-refresh-who-has" - __slots__ = ("keys",) keys: Collection[Key] -@dataclass +@dataclass(slots=True) class StealResponseMsg(SendMessageToScheduler): """Worker->Scheduler response to ``{op: steal-request}`` @@ -555,30 +545,20 @@ class StealResponseMsg(SendMessageToScheduler): op = "steal-response" - __slots__ = ("key", "state") key: Key state: TaskStateState | None -@dataclass +@dataclass(slots=True) class StateMachineEvent: """Base abstract class for all stimuli that can modify the worker state""" - __slots__ = ("stimulus_id", "handled") #: Unique ID of the event stimulus_id: str #: timestamp of when the event was handled by the worker - # TODO Switch to @dataclass(slots=True), uncomment the line below, and remove the - # __new__ method (requires Python >=3.10) - # handled: float | None = field(init=False, default=None) + handled: float | None = field(init=False, default=None) _classes: ClassVar[dict[str, type[StateMachineEvent]]] = {} - def __new__(cls, *args: Any, **kwargs: Any) -> StateMachineEvent: - """Hack to initialize the ``handled`` attribute in Python <3.10""" - self = object.__new__(cls) - self.handled = None - return self - def __init_subclass__(cls) -> None: StateMachineEvent._classes[cls.__name__] = cls @@ -625,39 +605,35 @@ def _after_from_dict(self) -> None: """Optional post-processing after an instance is created by ``from_dict``""" -@dataclass +@dataclass(slots=True) class PauseEvent(StateMachineEvent): - __slots__ = () + pass -@dataclass +@dataclass(slots=True) class UnpauseEvent(StateMachineEvent): - __slots__ = () + pass -@dataclass +@dataclass(slots=True) class RetryBusyWorkerEvent(StateMachineEvent): - __slots__ = ("worker",) worker: str -@dataclass +@dataclass(slots=True) class GatherDepDoneEvent(StateMachineEvent): """:class:`GatherDep` instruction terminated (abstract base class)""" - __slots__ = ("worker", "total_nbytes") worker: str total_nbytes: int # Must be the same as in GatherDep instruction -@dataclass +@dataclass(slots=True) class GatherDepSuccessEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: remote worker fetched successfully """ - __slots__ = ("data",) - data: dict[Key, object] # There may be fewer keys than in GatherDep def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -670,25 +646,21 @@ def _after_from_dict(self) -> None: self.data = {k: None for k in self.data} -@dataclass +@dataclass(slots=True) class GatherDepBusyEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: remote worker is busy """ - __slots__ = () - -@dataclass +@dataclass(slots=True) class GatherDepNetworkFailureEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: network failure while trying to communicate with remote worker """ - __slots__ = () - -@dataclass +@dataclass(slots=True) class GatherDepFailureEvent(GatherDepDoneEvent): """class:`GatherDep` instruction terminated: generic error raised (not a network failure); e.g. data failed to deserialize. @@ -698,7 +670,6 @@ class GatherDepFailureEvent(GatherDepDoneEvent): traceback: Serialize | None exception_text: str traceback_text: str - __slots__ = tuple(__annotations__) def _after_from_dict(self) -> None: self.exception = Serialize(Exception()) @@ -725,13 +696,12 @@ def from_exception( ) -@dataclass +@dataclass(slots=True) class RemoveWorkerEvent(StateMachineEvent): worker: str - __slots__ = ("worker",) -@dataclass +@dataclass(slots=True) class ComputeTaskEvent(StateMachineEvent): key: Key run_id: int @@ -744,8 +714,6 @@ class ComputeTaskEvent(StateMachineEvent): annotations: dict span_id: str | None - __slots__ = tuple(__annotations__) - def __post_init__(self) -> None: # Fixes after msgpack decode if isinstance(self.priority, list): # type: ignore[unreachable] @@ -806,17 +774,16 @@ def dummy( ) -@dataclass +@dataclass(slots=True) class ExecuteDoneEvent(StateMachineEvent): """Abstract base event for all the possible outcomes of a :class:`Compute` instruction """ key: Key - __slots__ = ("key",) -@dataclass +@dataclass(slots=True) class ExecuteSuccessEvent(ExecuteDoneEvent): run_id: int # FIXME: Utilize the run ID in all ExecuteDoneEvents value: object @@ -824,7 +791,6 @@ class ExecuteSuccessEvent(ExecuteDoneEvent): stop: float nbytes: int type: type | None - __slots__ = tuple(__annotations__) def to_loggable(self, *, handled: float) -> StateMachineEvent: out = copy(self) @@ -867,7 +833,7 @@ def dummy( ) -@dataclass +@dataclass(slots=True) class ExecuteFailureEvent(ExecuteDoneEvent): run_id: int # FIXME: Utilize the run ID in all ExecuteDoneEvents start: float | None @@ -876,7 +842,6 @@ class ExecuteFailureEvent(ExecuteDoneEvent): traceback: Serialize | None exception_text: str traceback_text: str - __slots__ = tuple(__annotations__) def _after_from_dict(self) -> None: self.exception = Serialize(Exception()) @@ -934,9 +899,8 @@ def dummy( # Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception -@dataclass +@dataclass(slots=True) class RescheduleEvent(ExecuteDoneEvent): - __slots__ = () @staticmethod def dummy(key: Key, *, stimulus_id: str) -> RescheduleEvent: @@ -946,18 +910,17 @@ def dummy(key: Key, *, stimulus_id: str) -> RescheduleEvent: return RescheduleEvent(key=key, stimulus_id=stimulus_id) -@dataclass +@dataclass(slots=True) class CancelComputeEvent(StateMachineEvent): - __slots__ = ("key",) key: Key -@dataclass +@dataclass(slots=True) class FindMissingEvent(StateMachineEvent): - __slots__ = () + pass -@dataclass +@dataclass(slots=True) class RefreshWhoHasEvent(StateMachineEvent): """Scheduler -> Worker message containing updated who_has information. @@ -966,31 +929,27 @@ class RefreshWhoHasEvent(StateMachineEvent): RequestRefreshWhoHasMsg """ - __slots__ = ("who_has",) # {key: [worker address, ...]} who_has: dict[Key, Collection[str]] -@dataclass +@dataclass(slots=True) class AcquireReplicasEvent(StateMachineEvent): - __slots__ = ("who_has", "nbytes") who_has: dict[Key, Collection[str]] nbytes: dict[Key, int] -@dataclass +@dataclass(slots=True) class RemoveReplicasEvent(StateMachineEvent): - __slots__ = ("keys",) keys: Collection[Key] -@dataclass +@dataclass(slots=True) class FreeKeysEvent(StateMachineEvent): - __slots__ = ("keys",) keys: Collection[Key] -@dataclass +@dataclass(slots=True) class StealRequestEvent(StateMachineEvent): """Event that requests a worker to release a key because it's now being computed somewhere else. @@ -1000,13 +959,11 @@ class StealRequestEvent(StateMachineEvent): StealResponseMsg """ - __slots__ = ("key",) key: Key -@dataclass +@dataclass(slots=True) class UpdateDataEvent(StateMachineEvent): - __slots__ = ("data",) data: dict[Key, object] def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -1016,15 +973,15 @@ def to_loggable(self, *, handled: float) -> StateMachineEvent: return out -@dataclass +@dataclass(slots=True) class SecedeEvent(StateMachineEvent): - __slots__ = ("key", "compute_duration") key: Key compute_duration: float # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} # Not to be confused with distributed.scheduler.Recs +# TODO replace with `type` statement (requires Python >=3.12) Recs: TypeAlias = dict[TaskState, TaskStateState | tuple] Instructions: TypeAlias = list[Instruction] RecsInstrs: TypeAlias = tuple[Recs, Instructions]