from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
import pandas as pd
import structlog
from sqlalchemy import Select, func, select, text, update
from sqlalchemy.orm import Session
from jobmon.core.constants import WorkflowStatus as Statuses
from jobmon.server.web.models.node import Node
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_template import TaskTemplate
from jobmon.server.web.models.task_template_version import TaskTemplateVersion
from jobmon.server.web.models.tool import Tool
from jobmon.server.web.models.tool_version import ToolVersion
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_run import WorkflowRun
from jobmon.server.web.models.workflow_status import WorkflowStatus
from jobmon.server.web.schemas.workflow import (
TaskTableItem,
TaskTableResponse,
WorkflowDetailsItem,
WorkflowOverviewItem,
WorkflowOverviewResponse,
WorkflowRunForResetResponse,
WorkflowStatusResponse,
WorkflowTasksResponse,
WorkflowUserValidationResponse,
WorkflowValidationResponse,
)
[docs]
logger = structlog.get_logger(__name__)
[docs]
_cli_label_mapping = {
"A": "PENDING",
"G": "PENDING",
"Q": "PENDING",
"I": "PENDING",
"E": "PENDING",
"O": "SCHEDULED",
"R": "RUNNING",
"F": "FATAL",
"D": "DONE",
}
[docs]
_reversed_cli_label_mapping = {
"SCHEDULED": ["O"],
"PENDING": ["A", "G", "Q", "E", "I"],
"RUNNING": ["R"],
"FATAL": ["F"],
"DONE": ["D"],
}
[docs]
_cli_order = ["PENDING", "SCHEDULED", "RUNNING", "DONE", "FATAL"]
[docs]
class WorkflowRepository:
def __init__(self, session: Session) -> None:
"""Initialize the workflow repository."""
[docs]
def get_workflow_validation_status(
self, task_ids: List[int]
) -> WorkflowValidationResponse:
"""Check if workflow is valid."""
# if the given list is empty, return True
if len(task_ids) == 0:
return WorkflowValidationResponse(validation=True)
# execute query
query_filter = [Task.workflow_id == Workflow.id, Task.id.in_(task_ids)]
sql = (
select(Task.workflow_id, Workflow.status).where(*query_filter)
).distinct()
rows = self.session.execute(sql).all()
res = [ti[1] for ti in rows]
# Validate if all tasks are in the same workflow and the workflow status is dead
if len(res) == 1 and res[0] in (
Statuses.FAILED,
Statuses.DONE,
Statuses.ABORTED,
Statuses.HALTED,
):
validation = True
else:
validation = False
return WorkflowValidationResponse(
validation=validation, workflow_status=res[0] if res else None
)
[docs]
def get_workflow_tasks(
self, workflow_id: int, limit: int, status: Optional[List[str]] = None
) -> WorkflowTasksResponse:
"""Get the tasks for a given workflow."""
logger.debug(f"Get tasks for workflow in status {status}")
if status:
query_filter = [
Workflow.id == Task.workflow_id,
Task.status.in_(
[i for arg in status for i in _reversed_cli_label_mapping[arg]]
),
Workflow.id == int(workflow_id),
]
else:
query_filter = [
Workflow.id == Task.workflow_id,
Workflow.id == int(workflow_id),
]
sql = (
select(Task.id, Task.name, Task.status, Task.num_attempts).where(
*query_filter
)
).order_by(Task.id.desc())
rows = self.session.execute(sql).all()
column_names = ("TASK_ID", "TASK_NAME", "STATUS", "RETRIES")
res = [dict(zip(column_names, ti)) for ti in rows]
for r in res:
r["RETRIES"] = 0 if r["RETRIES"] <= 1 else r["RETRIES"] - 1
if limit:
res = res[: int(limit)]
logger.debug(f"The following tasks of workflow are in status {status}:\n{res}")
if res:
# assign to dataframe for serialization
df = pd.DataFrame(res, columns=list(res[0].keys()))
# remap to jobmon_cli statuses
df.STATUS.replace(to_replace=_cli_label_mapping, inplace=True)
df_json = df.to_json()
else:
df = pd.DataFrame({}, columns=["TASK_ID", "TASK_NAME", "STATUS", "RETRIES"])
df_json = df.to_json()
return WorkflowTasksResponse(workflow_tasks=df_json)
[docs]
def get_workflow_user_validation(
self, workflow_id: int, username: str
) -> WorkflowUserValidationResponse:
"""Return all usernames associated with a given workflow_id's workflow runs."""
logger.debug(f"Validate user name {username} for workflow")
query_filter = [WorkflowRun.workflow_id == workflow_id]
sql = (select(WorkflowRun.user).where(*query_filter)).distinct()
rows = self.session.execute(sql).all()
usernames = [row[0] for row in rows]
return WorkflowUserValidationResponse(validation=username in usernames)
[docs]
def get_workflow_run_for_reset(
self, workflow_id: int, username: str
) -> WorkflowRunForResetResponse:
"""Get last workflow_run_id for workflow reset validation."""
query_filter = [
WorkflowRun.workflow_id == workflow_id,
WorkflowRun.status == "E",
]
sql = (select(WorkflowRun.id, WorkflowRun.user).where(*query_filter)).order_by(
WorkflowRun.created_date.desc()
)
rows = self.session.execute(sql).all()
result = None if len(rows) <= 0 else rows[0]
if result is not None and result[1] == username:
workflow_run_id = result[0]
else:
workflow_run_id = None
return WorkflowRunForResetResponse(workflow_run_id=workflow_run_id)
[docs]
def reset_workflow(self, workflow_id: int, partial_reset: bool = False) -> None:
"""Update the workflow's status, all its tasks' statuses to 'G'."""
current_time = self.session.query(func.now()).scalar()
workflow_query = select(Workflow).where(Workflow.id == workflow_id)
workflow = self.session.execute(workflow_query).scalars().one_or_none()
if workflow:
workflow.reset(current_time=current_time)
self.session.flush()
# Update task statuses associated with the workflow
# Default behavior is a full workflow reset, all tasks to registered state
# User can optionally request only a partial reset if they want to resume
invalid_statuses = ["G"]
if partial_reset:
invalid_statuses.append("D")
update_filter = [
Task.workflow_id == workflow_id,
Task.status.notin_(invalid_statuses),
]
update_stmt = (
update(Task)
.where(*update_filter)
.values(status="G", status_date=func.now(), num_attempts=0)
)
self.session.execute(update_stmt)
self.session.commit()
[docs]
def get_workflow_status(
self,
workflow_id: Optional[Union[int, str, List[Union[int, str]]]] = None,
limit: Optional[int] = None,
user: Optional[List[str]] = None,
) -> WorkflowStatusResponse:
"""Get the status of the workflow."""
user_request = user
if user_request == "all": # specifying all is equivalent to None
user_request = []
if isinstance(workflow_id, int):
workflow_request = [workflow_id]
elif isinstance(workflow_id, str) and workflow_id == "all":
workflow_request = []
else:
workflow_request = workflow_id # type: ignore
logger.debug(f"Query for wf {workflow_request} status.")
# set default to 5 to match status_commands
limit = int(limit) if limit else 5
# convert workflow request into sql filter
if workflow_request:
workflow_request = [int(w) for w in workflow_request]
else: # if we don't specify workflow then we use the users
# convert user request into sql filter
# directly producing workflow_ids, and thus where_clause
if user_request:
query_filter = [WorkflowRun.user.in_(user_request)]
sql = (
(select(WorkflowRun.workflow_id).where(*query_filter))
.distinct()
.order_by(WorkflowRun.workflow_id.desc())
.limit(limit)
)
rows = self.session.execute(sql).all()
workflow_request = [int(row[0]) for row in rows]
# performance improvement one: only query the limited number of workflows
workflow_request = workflow_request[:limit]
# performance improvement two: split query
query_filter = [
Workflow.id.in_(workflow_request), # type: ignore
WorkflowStatus.id == Workflow.status, # type: ignore
]
sql1: Select[
Tuple[Optional[int], Optional[str], Optional[str], Optional[datetime]]
] = (
select(
Workflow.id,
Workflow.name,
WorkflowStatus.label,
Workflow.created_date,
)
).where(
*query_filter
)
rows1 = self.session.execute(sql1).all()
row_map = dict()
for r in rows1:
row_map[r[0]] = r
query_filter = [
Task.workflow_id.in_(workflow_request),
]
sql2: Select[Tuple[int, int, str]] = (
select(
Task.workflow_id,
func.count(Task.status),
Task.status,
).where(*query_filter)
).group_by(Task.workflow_id, Task.status)
rows2 = self.session.execute(sql2).all()
res = []
for r in rows2: # type: ignore
d = dict()
d["WF_ID"] = r[0]
d["WF_NAME"] = row_map[r[0]][1]
d["WF_STATUS"] = row_map[r[0]][2]
d["TASKS"] = r[1]
d["STATUS"] = r[2]
d["CREATED_DATE"] = row_map[r[0]][3]
q_filter = [Task.workflow_id == d["WF_ID"], Task.status == d["STATUS"]]
q = select(Task.num_attempts).where(*q_filter)
query_result = self.session.execute(q).all()
retries = 0
for rr in query_result:
retries += 0 if int(rr[0]) <= 1 else int(rr[0]) - 1
d["RETRIES"] = retries
res.append(d)
if res is not None and len(res) > 0:
# assign to dataframe for aggregation
df = pd.DataFrame(res, columns=res[0].keys())
# remap to jobmon_cli statuses
df.STATUS.replace(to_replace=_cli_label_mapping, inplace=True)
# aggregate totals by workflow and status
df = df.groupby(
["WF_ID", "WF_NAME", "WF_STATUS", "STATUS", "CREATED_DATE"]
).agg({"TASKS": "sum", "RETRIES": "sum"})
# pivot wide by task status
tasks = df.pivot_table(
values="TASKS",
index=["WF_ID", "WF_NAME", "WF_STATUS", "CREATED_DATE"],
columns="STATUS",
fill_value=0,
)
for col in _cli_order:
if col not in tasks.columns:
tasks[col] = 0
tasks = tasks[_cli_order]
# aggregate again without status to get the totals by workflow
retries = df.groupby(["WF_ID", "WF_NAME", "WF_STATUS", "CREATED_DATE"]).agg(
{"TASKS": "sum", "RETRIES": "sum"}
)
# combine datasets
df = pd.concat([tasks, retries], axis=1)
# compute pcts and format
for col in _cli_order:
df[col + "_pct"] = (
df[col].astype(float) / df["TASKS"].astype(float)
) * 100
df[col + "_pct"] = df[[col + "_pct"]].round(1)
df[col] = (
df[col].astype(int).astype(str)
+ " ("
+ df[col + "_pct"].astype(str)
+ "%)"
)
# final order
df = df[["TASKS"] + _cli_order + ["RETRIES"]]
df = df.reset_index()
df_json = df.to_json()
else:
df = pd.DataFrame(
{},
columns=[
"WF_ID",
"WF_NAME",
"WF_STATUS",
"CREATED_DATE",
"TASKS",
"PENDING",
"RUNNING",
"DONE",
"FATAL",
"RETRIES",
],
).to_json()
df_json = df
return WorkflowStatusResponse(workflows=df_json)
[docs]
def get_workflow_status_viz(self, workflow_ids: List[int]) -> Dict[int, Any]:
"""Get the status of the workflows for GUI."""
wf_ids = workflow_ids
# return DS
return_dic: Dict[int, Any] = dict()
for wf_id in wf_ids:
attempts_sql = select(
func.coalesce(func.min(Task.num_attempts), 0).label("min"),
func.coalesce(func.max(Task.num_attempts), 0).label("max"),
func.coalesce(func.avg(Task.num_attempts), 0.0).label("mean"),
).where(Task.workflow_id == wf_id)
attempts = self.session.execute(attempts_sql).first()
return_dic[int(wf_id)] = {
"id": int(wf_id),
"tasks": 0,
"PENDING": 0,
"SCHEDULED": 0,
"RUNNING": 0,
"DONE": 0,
"FATAL": 0,
"MAXC": 0,
"num_attempts_avg": float(attempts.mean), # type: ignore
"num_attempts_min": int(attempts.min), # type: ignore
"num_attempts_max": int(attempts.max), # type: ignore
}
query_filter = [
Task.workflow_id.in_(wf_ids),
Task.workflow_id == Workflow.id,
]
status_sql: Select[Tuple[int, str, int]] = select(
Task.workflow_id, Task.status, Workflow.max_concurrently_running
).where(*query_filter)
rows = self.session.execute(status_sql).all()
for row in rows:
return_dic[row[0]]["tasks"] += 1
return_dic[row[0]][_cli_label_mapping[row[1]]] += 1
return_dic[row[0]]["MAXC"] = row[2]
return return_dic
[docs]
def _add_multi_value_filter(
self,
value: Optional[str],
column: str,
param_name: str,
where_clauses: list,
substitution_dict: dict,
) -> None:
"""Add a filter that supports comma-separated values with OR logic."""
if not value:
return
value_list = [v.strip() for v in value.split(",") if v.strip()]
if not value_list:
return
if len(value_list) == 1:
where_clauses.append(f"{column} = :{param_name}")
substitution_dict[param_name] = value_list[0]
else:
placeholders = ",".join(
[f":{param_name}_{i}" for i in range(len(value_list))]
)
where_clauses.append(f"{column} IN ({placeholders})")
for i, v in enumerate(value_list):
substitution_dict[f"{param_name}_{i}"] = v
[docs]
def get_workflow_overview(
self,
user: Optional[str] = None,
tool: Optional[str] = None,
wf_name: Optional[str] = None,
wf_args: Optional[str] = None,
wf_attribute_value: Optional[str] = None,
wf_attribute_key: Optional[str] = None,
wf_id: Optional[str] = None,
date_submitted: Optional[str] = None,
date_submitted_end: Optional[str] = None,
status: Optional[str] = None,
) -> WorkflowOverviewResponse:
"""Fetch associated workflows and workflow runs by username."""
where_clauses: list[str] = []
substitution_dict: dict[str, Any] = {}
self._add_multi_value_filter(
user, "workflow_run.user", "user", where_clauses, substitution_dict
)
self._add_multi_value_filter(
tool, "tool.name", "tool", where_clauses, substitution_dict
)
self._add_multi_value_filter(
status, "workflow.status", "status", where_clauses, substitution_dict
)
if wf_name:
where_clauses.append("workflow.name = :wf_name")
substitution_dict["wf_name"] = wf_name
if wf_args:
where_clauses.append("workflow.workflow_args = :wf_args")
substitution_dict["wf_args"] = wf_args
if wf_attribute_key:
where_clauses.append("workflow_attribute_type.name = :wf_attribute_key")
substitution_dict["wf_attribute_key"] = wf_attribute_key
if wf_attribute_value:
where_clauses.append("workflow_attribute.value = :wf_attribute_value")
substitution_dict["wf_attribute_value"] = wf_attribute_value
if wf_id:
where_clauses.append("workflow.id = :wf_id")
substitution_dict["wf_id"] = wf_id # type: ignore
if date_submitted:
where_clauses.append("workflow.created_date >= :date_submitted")
substitution_dict["date_submitted"] = date_submitted
if date_submitted_end:
where_clauses.append("workflow.created_date <= :date_submitted_end")
substitution_dict["date_submitted_end"] = date_submitted_end
if where_clauses:
inner_where_clause = " WHERE " + (" AND ".join(where_clauses))
else:
inner_where_clause = ""
query = text(
f"""
SELECT
workflow.id,
workflow.name,
workflow.created_date,
workflow.status_date,
workflow.workflow_args,
count(distinct workflow_run.id) as num_attempts,
workflow_status.label,
tool.name
FROM
workflow
JOIN (
SELECT
distinct queue_id,
workflow_id
FROM
task
JOIN task_resources ON task_resources.id = task.task_resources_id
WHERE
task.workflow_id IN (
SELECT
workflow_run.workflow_id
FROM
workflow
JOIN tool_version ON
workflow.tool_version_id = tool_version.id
JOIN tool ON tool.id = tool_version.tool_id
JOIN workflow_run ON workflow.id = workflow_run.workflow_id
LEFT JOIN workflow_attribute
ON workflow.id = workflow_attribute.workflow_id
LEFT JOIN workflow_attribute_type
ON workflow_attribute.workflow_attribute_type_id =
workflow_attribute_type.id
{inner_where_clause}
)
GROUP BY
workflow_id, queue_id
) workflow_queue ON workflow.id = workflow_queue.workflow_id
JOIN queue ON queue.id = workflow_queue.queue_id
JOIN workflow_run ON workflow.id = workflow_run.workflow_id
JOIN tool_version ON workflow.tool_version_id = tool_version.id
JOIN tool ON tool.id = tool_version.tool_id
JOIN workflow_status ON workflow.status = workflow_status.id
WHERE
cluster_id != 1
GROUP BY
workflow.id
ORDER BY
workflow.id DESC
"""
)
rows = self.session.execute(query, substitution_dict).all()
def serialize_datetime(obj: Union[datetime, str]) -> str:
"""Serialize datetime objects into string format."""
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, str):
# Handle case where database returns datetime as string (e.g., SQLite)
return obj
raise TypeError(f"Type {obj.__class__.__name__} not serializable")
column_names = (
"wf_id",
"wf_name",
"wf_submitted_date",
"wf_status_date",
"wf_args",
"wfr_count",
"wf_status",
"wf_tool",
)
# Initialize all possible states as 0.
# No need to return data since it will be refreshed
# on demand anyways.
initial_status_counts = {
label_mapping: 0 for label_mapping in set(_cli_label_mapping.values())
}
workflows = []
for row in rows:
workflow_data = dict(zip(column_names, row))
workflow_data.update(initial_status_counts)
workflow_data["wf_submitted_date"] = serialize_datetime(row[2])
workflow_data["wf_status_date"] = serialize_datetime(row[3])
workflows.append(WorkflowOverviewItem(**workflow_data))
return WorkflowOverviewResponse(workflows=workflows)
[docs]
def get_task_details_by_workflow_id(
self, workflow_id: int, tt_name: str
) -> TaskTableResponse:
"""Fetch Task details associated with Workflow ID and TaskTemplate name."""
task_template_name = tt_name
sql = (
select(
Task.id,
Task.name,
Task.status,
Task.command,
Task.num_attempts,
Task.status_date,
Task.max_attempts,
)
.where(
Task.workflow_id == workflow_id,
Task.node_id == Node.id,
Node.task_template_version_id == TaskTemplateVersion.id,
TaskTemplateVersion.task_template_id == TaskTemplate.id,
TaskTemplate.name == task_template_name,
)
.order_by(Task.id.asc())
)
rows = self.session.execute(sql).all()
column_names = (
"task_id",
"task_name",
"task_status",
"task_command",
"task_num_attempts",
"task_status_date",
"task_max_attempts",
)
tasks = []
for row in rows:
task_data = dict(zip(column_names, row))
task_data["task_status"] = _cli_label_mapping[task_data["task_status"]]
task_data["task_status_date"] = str(task_data["task_status_date"])
tasks.append(TaskTableItem(**task_data))
return TaskTableResponse(tasks=tasks)
[docs]
def get_workflow_details_by_id(self, workflow_id: int) -> List[WorkflowDetailsItem]:
"""Fetch name, args, dates, tool for a Workflow provided WF ID."""
latest_workflow_run_subquery = (
self.session.query(
WorkflowRun.workflow_id, func.max(WorkflowRun.heartbeat_date)
)
.group_by(WorkflowRun.workflow_id)
.subquery()
)
sql = (
select(
Workflow.name,
Workflow.workflow_args,
Workflow.created_date,
Workflow.status_date,
Tool.name,
Workflow.status,
WorkflowStatus.description,
WorkflowRun.jobmon_version,
WorkflowRun.heartbeat_date,
WorkflowRun.user,
)
.select_from(Workflow)
.join(ToolVersion, Workflow.tool_version_id == ToolVersion.id)
.join(Tool, ToolVersion.tool_id == Tool.id)
.join(WorkflowStatus, WorkflowStatus.id == Workflow.status)
.join(WorkflowRun, WorkflowRun.workflow_id == Workflow.id)
.join(
latest_workflow_run_subquery,
)
.where(
Workflow.id == workflow_id,
)
)
rows = self.session.execute(sql).all()
column_names = (
"wf_name",
"wf_args",
"wf_created_date",
"wf_status_date",
"tool_name",
"wf_status",
"wf_status_desc",
"wfr_jobmon_version",
"wfr_heartbeat_date",
"wfr_user",
)
result = [dict(zip(column_names, row)) for row in rows]
date_fields = ["wf_status_date", "wf_created_date", "wfr_heartbeat_date"]
for row in result:
for field in date_fields:
if field in row and isinstance(row[field], datetime):
row[field] = row[field].isoformat()
# Convert to Pydantic models
workflow_details = [WorkflowDetailsItem(**row) for row in result]
return workflow_details