"""Array object used by swarm to create task instance batches."""
from __future__ import annotations
from typing import Optional, Set
import structlog
from jobmon.client.swarm.swarm_task import SwarmTask
# structlog logger for this module, created with this module's ``__name__``.
# NOTE(review): not referenced by the code in this module; kept in case other
# modules import it — confirm before removing.
logger = structlog.get_logger(__name__)
class SwarmArray:
    """In-memory grouping of the SwarmTasks that belong to a single array.

    Instances hash, compare equal, and sort by ``array_id``.
    """

    def __init__(
        self,
        array_id: int,
        max_concurrently_running: int,
        array_name: Optional[str] = None,
    ) -> None:
        """Initialization of the SwarmArray.

        Args:
            array_id: identifier of the array; drives hashing, equality and
                ordering of SwarmArray instances.
            max_concurrently_running: stored unchanged on the instance.
            array_name: optional display name. Any falsy value (``None`` or
                ``""``) falls back to ``f"array_{array_id}"``.
        """
        self.array_id = array_id
        self.array_name = array_name or f"array_{array_id}"
        # Tasks registered via add_task(); every member shares this array_id.
        self.tasks: Set[SwarmTask] = set()
        self.max_concurrently_running = max_concurrently_running

    def add_task(self, task: SwarmTask) -> None:
        """Register ``task`` as a member of this array.

        Args:
            task: the task to add; its ``array_id`` must match this array's.

        Raises:
            ValueError: if ``task.array_id`` differs from ``self.array_id``.
        """
        if task.array_id != self.array_id:
            raise ValueError(
                f"array_id mismatch. SwarmTask={task.array_id}. Array={self.array_id}."
            )
        self.tasks.add(task)

    def __hash__(self) -> int:
        """Returns the array ID."""
        return self.array_id

    def __eq__(self, other: object) -> bool:
        """Two arrays are equal when their array IDs are equal.

        ``array_id`` is compared directly rather than via ``hash()``, because
        ``hash()`` may alter what ``__hash__`` returns (CPython maps -1 to -2
        and reduces ints outside the ``Py_ssize_t`` range). Comparing hashes
        would therefore make some distinct ids compare equal.
        """
        if not isinstance(other, SwarmArray):
            # Let Python try the reflected operation / fall back to identity.
            return NotImplemented
        return self.array_id == other.array_id

    def __lt__(self, other: SwarmArray) -> bool:
        """Order arrays by array ID; other operand types are unsupported."""
        if not isinstance(other, SwarmArray):
            # Makes ``array < 5`` raise TypeError instead of silently
            # comparing against the hash of an unrelated object.
            return NotImplemented
        return self.array_id < other.array_id

    def __repr__(self) -> str:
        """Debug representation showing the identifying fields."""
        return (
            f"{type(self).__name__}(array_id={self.array_id!r}, "
            f"array_name={self.array_name!r})"
        )