# Source code for core.cluster

"""Cluster objects define where a user wants their tasks run. e.g. UGE, Azure, Seq."""

from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from jobmon.core.cluster_protocol import (
    ClusterDistributor,
    ClusterQueue,
    ClusterWorkerNode,
)
from jobmon.core.cluster_type import ClusterType
from jobmon.core.requester import Requester
from jobmon.core.serializers import SerializeCluster, SerializeQueue


# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(__name__)
class Cluster:
    """Cluster objects define where a user wants their tasks run. e.g. UGE, Azure, Seq.

    A Cluster starts out unbound. Calling ``bind`` (or constructing the object
    through ``get_cluster``) requests the cluster record from the server and
    stores its id, cluster type and connection parameters on the instance.
    Accessors that depend on that state raise ``AttributeError`` until the
    Cluster has been bound.
    """

    def __init__(
        self, cluster_name: str, requester: Optional[Requester] = None
    ) -> None:
        """Initialization of Cluster.

        Args:
            cluster_name: the name of the cluster.
            requester: requester object used to send requests to the server.
                When omitted, one is created with ``Requester.from_defaults()``.
        """
        self.cluster_name = cluster_name
        if requester is None:
            requester = Requester.from_defaults()
        self.requester = requester
        # Cache of queue objects keyed by queue name; filled lazily by get_queue.
        self.queues: Dict[str, ClusterQueue] = {}

    @classmethod
    def get_cluster(
        cls: Any, cluster_name: str, requester: Optional[Requester] = None
    ) -> Cluster:
        """Get a bound instance of a Cluster.

        Args:
            cluster_name: the name of the cluster
            requester (Requester): requester object to connect to Flask service.

        Returns:
            A Cluster on which ``bind`` has already been called.
        """
        cluster = cls(cluster_name, requester)
        cluster.bind()
        return cluster

    def bind(self) -> None:
        """Bind Cluster to the database, getting an id back.

        Sends a ``get`` request to ``/cluster/<cluster_name>`` and stores the id,
        cluster type and connection parameters found in the response.
        """
        app_route = f"/cluster/{self.cluster_name}"
        _, response = self.requester.send_request(
            app_route=app_route, message={}, request_type="get"
        )
        cluster_kwargs = SerializeCluster.kwargs_from_wire(response["cluster"])

        self._cluster_id = cluster_kwargs["id"]
        self._cluster_type = ClusterType(cluster_kwargs["cluster_type_name"])
        self._connection_parameters = cluster_kwargs["connection_parameters"]

    def _bound_attribute(self, name: str) -> Any:
        """Return a private attribute that ``bind`` sets, or fail with a clear message.

        Args:
            name: the private attribute name, e.g. ``"_cluster_type"``.

        Raises:
            AttributeError: if the attribute has not been set on this instance.
        """
        try:
            return getattr(self, name)
        except AttributeError:
            # Same exception type Python would raise on its own, but with a
            # message that tells the caller what to do, matching ``id``.
            raise AttributeError(
                f"Cannot access {name.lstrip('_')} until Cluster is bound to database"
            ) from None

    @property
    def connection_parameters(self) -> Dict:
        """The connection parameters.

        Raises:
            AttributeError: if the Cluster has not been bound yet.
        """
        return self._bound_attribute("_connection_parameters")

    @property
    def is_bound(self) -> bool:
        """If the Cluster has been bound to the database."""
        return hasattr(self, "_cluster_id")

    @property
    def id(self) -> int:
        """Unique id from database if Cluster has been bound.

        Raises:
            AttributeError: if the Cluster has not been bound yet.
        """
        if not self.is_bound:
            raise AttributeError("Cannot access id until Cluster is bound to database")
        return self._cluster_id

    def get_worker_node(self) -> ClusterWorkerNode:
        """Get the cluster specific worker_node interface.

        Raises:
            AttributeError: if the Cluster has not been bound yet.
        """
        cluster_worker_node_class = self._bound_attribute(
            "_cluster_type"
        ).cluster_worker_node_class
        return cluster_worker_node_class()

    def get_distributor(self) -> ClusterDistributor:
        """Get the cluster specific distributor interface.

        The distributor is constructed with the cluster name and the connection
        parameters passed as keyword arguments.

        Raises:
            AttributeError: if the Cluster has not been bound yet.
        """
        # TODO: read in cluster args from config here?
        distributor_class = self._bound_attribute(
            "_cluster_type"
        ).cluster_distributor_class
        return distributor_class(
            self.cluster_name, **self._bound_attribute("_connection_parameters")
        )

    def get_queue(self, queue_name: str) -> ClusterQueue:
        """Get the ClusterQueue object associated with a given queue_name.

        Checks if queue object is in the cache, if it's not it will query the
        database and add the queue object to the cache.

        Args:
            queue_name: name of the queue you want.

        Raises:
            AttributeError: if the queue is not cached and the Cluster has not
                been bound yet.
        """
        # this is cached so should be fast
        if queue_name in self.queues:
            return self.queues[queue_name]

        # Cache miss: done outside any ``except`` handler so that a failure in
        # the request or in queue construction is not chained to a KeyError.
        queue_class = self._bound_attribute("_cluster_type").cluster_queue_class
        app_route = f"/cluster/{self.id}/queue/{queue_name}"
        _, response = self.requester.send_request(
            app_route=app_route, message={}, request_type="get"
        )
        queue_kwargs = SerializeQueue.kwargs_from_wire(response["queue"])
        queue = queue_class(**queue_kwargs)
        self.queues[queue_name] = queue
        return queue