# Source code for server.web.routes.v3.fsm.node

"""Routes used by the main jobmon client."""

import asyncio
from http import HTTPStatus as StatusCodes
from time import sleep
from typing import Any, Dict, cast

import structlog
from fastapi import Depends, Request
from sqlalchemy import insert, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse

from jobmon.server.web.db import get_dialect_name
from jobmon.server.web.db.deps import get_db
from jobmon.server.web.models.node import Node
from jobmon.server.web.models.node_arg import NodeArg
from jobmon.server.web.routes.v3.fsm import fsm_router as api_v3_router
from jobmon.server.web.server_side_exception import ServerError

# Module-level structured logger.
logger = structlog.get_logger(__name__)
# SQL dialect name, resolved once at import time. It selects the
# dialect-specific "ignore duplicate keys" INSERT prefix used below.
DIALECT = get_dialect_name()
_NODE_INSERT_MAX_RETRIES = 5


def _ignore_duplicates(stmt: Any) -> Any:
    """Prefix an INSERT statement so rows with an existing unique key are skipped.

    Args:
        stmt: A SQLAlchemy ``insert(...)`` statement.

    Returns:
        The statement with ``IGNORE`` (MySQL) or ``OR IGNORE`` (SQLite) prefixed.

    Raises:
        ServerError: If ``DIALECT`` is neither ``"mysql"`` nor ``"sqlite"``.
    """
    if DIALECT == "mysql":
        return stmt.prefix_with("IGNORE")
    if DIALECT == "sqlite":
        return stmt.prefix_with("OR IGNORE")
    raise ServerError(f"Unsupported SQL dialect '{DIALECT}'")


async def _execute_node_insert_with_retry(db: Session, stmt: Any) -> None:
    """Execute and flush the node INSERT, retrying on ``OperationalError``.

    This retry is intended for transient failures such as deadlocks. After
    each failed attempt the session is rolled back. The function then waits
    with exponential backoff (2, 4, 8, 16 ms) before trying again. If the
    final attempt also fails, its error is re-raised instead of being swallowed.

    Args:
        db: The database session.
        stmt: The INSERT statement to execute.

    Raises:
        OperationalError: If every attempt fails.
        Exception: Any other error, after logging and rolling back the session.
    """
    for attempt in range(1, _NODE_INSERT_MAX_RETRIES + 1):
        try:
            db.execute(stmt)
            db.flush()
            return
        except OperationalError as e:
            db.rollback()  # Clear the corrupted session state
            if attempt == _NODE_INSERT_MAX_RETRIES:
                logger.error(f"Failed to insert nodes after {attempt} attempts: {e}")
                raise
            logger.warning(
                f"Database error detected for node insert, retrying attempt "
                f"{attempt}/{_NODE_INSERT_MAX_RETRIES}. {e}"
            )
            # Non-blocking wait: time.sleep would stall the whole event loop.
            await asyncio.sleep(0.001 * (2**attempt))
        except Exception as e:
            logger.error(f"Failed to insert nodes: {e}")
            db.rollback()
            raise


@api_v3_router.post("/nodes")
async def add_nodes(request: Request, db: Session = Depends(get_db)) -> Any:
    """Add a chunk of nodes, and their node args, to the database.

    The JSON body is read as ``{"nodes": [{"task_template_version_id": ...,
    "node_args_hash": ..., "node_args": {arg_id: val, ...}}, ...]}``. That is
    the shape of the keys accessed below; the caller-side schema is not
    visible here.

    Duplicate keys are ignored on insert, so nodes that already exist are
    looked up rather than re-created.

    Args:
        request: The request object.
        db: The database session.

    Returns:
        A JSONResponse ``{"nodes": {"<ttv_id>:<node_args_hash>": node_id}}``
        containing exactly the requested nodes.

    Raises:
        ServerError: If the database dialect is unsupported.
        OperationalError: If the node insert still fails after all retries.
    """
    data = cast(Dict, await request.json())

    node_keys = [
        (n["task_template_version_id"], n["node_args_hash"]) for n in data["nodes"]
    ]
    if not node_keys:
        # Nothing to do. Returning here also avoids unpacking
        # zip(*node_keys) below, which fails on an empty list.
        return JSONResponse(content={"nodes": {}}, status_code=StatusCodes.OK)

    # Bulk insert the nodes in one statement, for performance.
    node_insert_stmt = _ignore_duplicates(
        insert(Node).values(
            [
                {"task_template_version_id": ttv, "node_args_hash": arghash}
                for ttv, arghash in node_keys
            ]
        )
    )
    await _execute_node_insert_with_retry(db, node_insert_stmt)

    # Retrieve ids for both newly inserted and pre-existing nodes. The two
    # independent IN clauses match a superset (any ttv with any hash), so
    # only the requested (ttv, hash) pairs are kept.
    requested = set(node_keys)
    ttvids, node_arg_hashes = zip(*requested)
    select_stmt = select(Node).where(
        Node.task_template_version_id.in_(ttvids),
        Node.node_args_hash.in_(node_arg_hashes),
    )
    nodes = db.execute(select_stmt).scalars().all()
    node_id_dict = {}
    for n in nodes:
        key = (n.task_template_version_id, n.node_args_hash)
        if key in requested:
            node_id_dict[key] = n.id

    # Gather node args per node. Keys are compared as-is against what the DB
    # returned. NOTE(review): this assumes the client sends node_args_hash
    # with the same type as the DB column; confirm.
    node_args = {
        (n["task_template_version_id"], n["node_args_hash"]): n["node_args"]
        for n in data["nodes"]
    }
    node_args_list = []
    for node_key, args in node_args.items():
        node_id = node_id_dict[node_key]
        for arg_id, val in args.items():
            logger.debug("Adding node_arg", node_id=node_id, arg_id=arg_id, val=val)
            node_args_list.append({"node_id": node_id, "arg_id": arg_id, "val": val})

    # Bulk insert the node args on the same session.
    _insert_node_args(node_args_list, db)

    return_nodes = {
        ":".join(str(i) for i in key): val for key, val in node_id_dict.items()
    }
    return JSONResponse(content={"nodes": return_nodes}, status_code=StatusCodes.OK)
def _insert_node_args(node_args_list: list, db: Session) -> None:
    """Bulk insert node-arg rows, skipping rows whose key already exists.

    Args:
        node_args_list: Row dicts for the NodeArg table. ``add_nodes`` builds
            them with "node_id", "arg_id" and "val" keys. An empty list is a
            no-op.
        db: The session the INSERT is executed on. This function only calls
            ``db.execute``; it does not flush or commit.

    Raises:
        ServerError: If there are rows to insert and ``DIALECT`` is neither
            mysql nor sqlite.
    """
    if not node_args_list:
        return

    statement = insert(NodeArg).values(node_args_list)
    ignore_prefixes = {"mysql": "IGNORE", "sqlite": "OR IGNORE"}
    prefix = ignore_prefixes.get(DIALECT)
    if prefix is None:
        raise ServerError(f"Unsupported SQL dialect '{DIALECT}'")
    db.execute(statement.prefix_with(prefix))