"""Task Table for the Database."""
import structlog
from sqlalchemy import (
VARCHAR,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
from jobmon.core.exceptions import InvalidStateTransition
from jobmon.core.serializers import SerializeDistributorTask, SerializeSwarmTask
from jobmon.server.web.models import Base
from jobmon.server.web.models.task_instance_status import TaskInstanceStatus
from jobmon.server.web.models.task_status import TaskStatus
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)
class Task(Base):
    """Task Database object.

    ORM model for a task row together with the finite state machine that
    governs its status.  Status changes made through :meth:`transition` (and
    :meth:`transition_after_task_instance_error`, which delegates to it) are
    validated against ``valid_transitions`` and refresh ``status_date``.
    :meth:`reset` writes the status directly without that validation.
    """

    # NOTE(review): no ``__tablename__`` is declared in this class; presumably
    # ``Base`` supplies it -- confirm.

    # Columns (declaration order is kept as-is).
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    workflow_id = Column(Integer, ForeignKey("workflow.id"))
    node_id = Column(Integer, ForeignKey("node.id"))
    task_args_hash = Column(VARCHAR(150), index=True, nullable=False)
    array_id = Column(Integer, ForeignKey("array.id"), default=None)
    name: Mapped[str] = mapped_column(String(255), index=True, nullable=False)
    command: Mapped[str] = mapped_column(Text)
    task_resources_id = Column(Integer, ForeignKey("task_resources.id"), default=None)
    num_attempts: Mapped[int] = mapped_column(Integer, default=0)
    max_attempts: Mapped[int] = mapped_column(Integer, default=1)
    resource_scales = Column(String(1000), default=None)
    fallback_queues = Column(String(1000), default=None)
    status: Mapped[str] = mapped_column(
        String(1), ForeignKey("task_status.id"), nullable=False
    )
    status_date = mapped_column(DateTime, default=func.now(), index=True)

    # ORM relationships
    task_instances = relationship("TaskInstance", back_populates="task")
    task_resources = relationship("TaskResources", foreign_keys=[task_resources_id])
    array = relationship("Array", foreign_keys=[array_id])

    __table_args__ = (
        Index("ix_workflow_id_status_date", "workflow_id", "status_date"),
    )

    # Finite state machine: every permitted (current_status, new_status) pair.
    # Any pair not listed here is rejected by ``_validate_transition``.
    valid_transitions = [
        (TaskStatus.REGISTERING, TaskStatus.QUEUED),
        (TaskStatus.ADJUSTING_RESOURCES, TaskStatus.QUEUED),
        (TaskStatus.QUEUED, TaskStatus.INSTANTIATING),
        (TaskStatus.INSTANTIATING, TaskStatus.LAUNCHED),
        (TaskStatus.INSTANTIATING, TaskStatus.ERROR_RECOVERABLE),
        (TaskStatus.INSTANTIATING, TaskStatus.RUNNING),
        (TaskStatus.LAUNCHED, TaskStatus.RUNNING),
        (TaskStatus.LAUNCHED, TaskStatus.ERROR_RECOVERABLE),
        (TaskStatus.RUNNING, TaskStatus.DONE),
        (TaskStatus.RUNNING, TaskStatus.ERROR_RECOVERABLE),
        (TaskStatus.ERROR_RECOVERABLE, TaskStatus.ADJUSTING_RESOURCES),
        (TaskStatus.ERROR_RECOVERABLE, TaskStatus.QUEUED),
        (TaskStatus.ERROR_RECOVERABLE, TaskStatus.ERROR_FATAL),
        (TaskStatus.ERROR_RECOVERABLE, TaskStatus.REGISTERING),
    ]

    def to_wire_as_distributor_task(self) -> tuple:
        """Serialize executor task object.

        Returns:
            The result of ``SerializeDistributorTask.to_wire`` for this task's
            id, array id, name, command and requested resources.
        """
        # Tasks without an associated array are serialized with array_id=None.
        array_id = self.array.id if self.array is not None else None
        # NOTE(review): assumes ``task_resources`` is populated; a task whose
        # ``task_resources`` is None would fail on the attribute access below.
        return SerializeDistributorTask.to_wire(
            task_id=self.id,
            array_id=array_id,
            name=self.name,
            command=self.command,
            requested_resources=self.task_resources.requested_resources,
        )

    def to_wire_as_swarm_task(self) -> tuple:
        """Serialize swarm task.

        Returns:
            The result of ``SerializeSwarmTask.to_wire`` for this task's id
            and current status.
        """
        return SerializeSwarmTask.to_wire(task_id=self.id, status=self.status)

    def reset(
        self, name: str, command: str, max_attempts: int, reset_if_running: bool
    ) -> None:
        """Reset status and number of attempts on a Task.

        Tasks that are DONE are never touched.  Tasks that are RUNNING are
        only reset when ``reset_if_running`` is true.  The status is written
        directly, without going through ``transition``.

        Args:
            name: new name to store on the task.
            command: new command to store on the task.
            max_attempts: new maximum number of attempts.
            reset_if_running: whether a currently RUNNING task is reset too.
        """
        # only reset undone tasks
        if self.status == TaskStatus.DONE:
            return
        # leave running tasks alone unless explicitly asked to reset them
        if self.status == TaskStatus.RUNNING and not reset_if_running:
            return

        self.status = TaskStatus.REGISTERING
        self.num_attempts = 0
        self.name = name
        self.command = command
        self.max_attempts = max_attempts
        self.status_date = func.now()

    def transition(self, new_state: str) -> None:
        """Transition the Task to a new state.

        A transition to the state the task is already in is a silent no-op.
        Entering QUEUED increments ``num_attempts``.

        Args:
            new_state: the status to move the task into.

        Raises:
            InvalidStateTransition: if (current status, ``new_state``) is not
                listed in ``valid_transitions``.
        """
        if self.status == new_state:
            # do nothing if the task is already in the new state
            return

        logger.info(f"Transitioning task from {self.status} to {new_state}")
        self._validate_transition(new_state)
        if new_state == TaskStatus.QUEUED:
            # every entry into QUEUED counts as one attempt
            self.num_attempts = self.num_attempts + 1
        self.status = new_state
        self.status_date = func.now()

    def transition_after_task_instance_error(
        self, job_instance_error_state: str
    ) -> None:
        """Transition the task to an error state.

        The task is first moved to ERROR_RECOVERABLE.  It then moves on to
        ERROR_FATAL when ``num_attempts`` has reached ``max_attempts``, to
        ADJUSTING_RESOURCES when the task instance failed with a resource
        error, and to REGISTERING otherwise.

        Args:
            job_instance_error_state: the error status of the task instance
                that failed.

        Raises:
            InvalidStateTransition: if any of the transitions involved is not
                listed in ``valid_transitions``.
        """
        logger.info("Transitioning task to ERROR_RECOVERABLE")
        self.transition(TaskStatus.ERROR_RECOVERABLE)

        if self.num_attempts >= self.max_attempts:
            logger.info("Giving up task after max attempts.")
            self.transition(TaskStatus.ERROR_FATAL)
        elif job_instance_error_state == TaskInstanceStatus.RESOURCE_ERROR:
            logger.debug("Adjust resource for task.")
            self.transition(TaskStatus.ADJUSTING_RESOURCES)
        else:
            logger.debug("Retrying Task.")
            self.transition(TaskStatus.REGISTERING)

    def _validate_transition(self, new_state: str) -> None:
        """Ensure the task state transition is valid.

        Args:
            new_state: the status the task is about to move into.

        Raises:
            InvalidStateTransition: if (current status, ``new_state``) is not
                listed in ``valid_transitions``.
        """
        if (self.status, new_state) not in self.valid_transitions:
            raise InvalidStateTransition("Task", self.id, self.status, new_state)