diff --git a/python_modules/dagster-graphql/dagster_graphql/implementation/context.py b/python_modules/dagster-graphql/dagster_graphql/implementation/context.py
--- a/python_modules/dagster-graphql/dagster_graphql/implementation/context.py
+++ b/python_modules/dagster-graphql/dagster_graphql/implementation/context.py
@@ -1,11 +1,16 @@
 from dagster import check
 from dagster.core.host_representation import PipelineSelector, RepositoryLocation
 from dagster.core.host_representation.external import ExternalPipeline
+from dagster.core.host_representation.grpc_server_state_subscriber import (
+    LocationStateChangeEventType,
+    LocationStateSubscriber,
+)
 from dagster.core.instance import DagsterInstance
 from dagster.grpc.types import ScheduleExecutionDataMode
 from dagster_graphql.implementation.utils import UserFacingGraphQLError
 from dagster_graphql.schema.errors import DauphinInvalidSubsetError
 from dagster_graphql.schema.pipelines import DauphinPipeline
+from rx.subjects import Subject
 
 
 class DagsterGraphQLContext:
@@ -13,6 +18,12 @@
         self._instance = check.inst_param(instance, "instance", DagsterInstance)
         self._workspace = workspace
         self._repository_locations = {}
+
+        self._location_state_events = Subject()
+        self._location_state_subscriber = LocationStateSubscriber(
+            self._location_state_events_handler
+        )
+
         for handle in self._workspace.repository_location_handles:
             check.invariant(
                 self._repository_locations.get(handle.location_name) is None,
@@ -20,9 +31,12 @@
                     name=handle.location_name,
                 ),
             )
+
+            handle.add_state_subscriber(self._location_state_subscriber)
             self._repository_locations[handle.location_name] = RepositoryLocation.from_handle(
                 handle
             )
+
         self.version = version
 
     @property
@@ -37,6 +51,23 @@
     def repository_location_names(self):
         return self._workspace.repository_location_names
 
+    def _location_state_events_handler(self, event):
+        """Reload the location if it was updated or errored, then publish the event."""
+        # If the server was updated or we were not able to reconnect, we immediately reload the
+        # location handle
+
+        if event.event_type == LocationStateChangeEventType.LOCATION_UPDATED:
+            # Reload the handle to get updated repository data and re-attach a subscriber
+            self.reload_repository_location(event.location_name)
+            new_handle = self._workspace.get_repository_location_handle(event.location_name)
+            new_handle.add_state_subscriber(self._location_state_subscriber)
+        elif event.event_type == LocationStateChangeEventType.LOCATION_ERROR:
+            # Just reload the handle in order to update the workspace with the correct
+            # error messages
+            self.reload_repository_location(event.location_name)
+
+        self._location_state_events.on_next(event)
+
     def repository_location_errors(self):
         return self._workspace.repository_location_errors
 
diff --git a/python_modules/dagster/dagster/core/host_representation/grpc_server_state_subscriber.py b/python_modules/dagster/dagster/core/host_representation/grpc_server_state_subscriber.py
new file mode 100644
--- /dev/null
+++ b/python_modules/dagster/dagster/core/host_representation/grpc_server_state_subscriber.py
@@ -0,0 +1,40 @@
+from collections import namedtuple
+from enum import Enum
+
+from dagster import check
+
+
+class LocationStateChangeEventType(Enum):
+    """Kinds of state change a repository location can report."""
+
+    LOCATION_UPDATED = "LOCATION_UPDATED"
+    LOCATION_DISCONNECTED = "LOCATION_DISCONNECTED"
+    LOCATION_RECONNECTED = "LOCATION_RECONNECTED"
+    LOCATION_ERROR = "LOCATION_ERROR"
+
+
+class LocationStateChangeEvent(
+    namedtuple("_LocationStateChangeEvent", "event_type location_name message server_id")
+):
+    """A type-checked state change notification for one named repository location."""
+
+    def __new__(cls, event_type, location_name, message, server_id=None):
+        return super(LocationStateChangeEvent, cls).__new__(
+            cls,
+            check.inst_param(event_type, "event_type", LocationStateChangeEventType),
+            check.str_param(location_name, "location_name"),
+            check.str_param(message, "message"),
+            check.opt_str_param(server_id, "server_id"),
+        )
+
+
+class LocationStateSubscriber(object):
+    """Forwards each received LocationStateChangeEvent to a callback."""
+
+    def __init__(self, callback):
+        check.callable_param(callback, "callback")
+        self._callback = callback
+
+    def handle_event(self, event):
+        check.inst_param(event, "event", LocationStateChangeEvent)
+        self._callback(event)
diff --git a/python_modules/dagster/dagster/core/host_representation/handle.py b/python_modules/dagster/dagster/core/host_representation/handle.py
--- a/python_modules/dagster/dagster/core/host_representation/handle.py
+++ b/python_modules/dagster/dagster/core/host_representation/handle.py
@@ -8,6 +8,10 @@
 from dagster.api.list_repositories import sync_list_repositories_grpc
 from dagster.core.definitions.reconstructable import repository_def_from_pointer
 from dagster.core.errors import DagsterInvariantViolationError
+from dagster.core.host_representation.grpc_server_state_subscriber import (
+    LocationStateChangeEvent,
+    LocationStateChangeEventType,
+)
 from dagster.core.host_representation.origin import (
     ExternalRepositoryOrigin,
     GrpcServerRepositoryLocationOrigin,
@@ -40,6 +44,10 @@
     def cleanup(self):
         pass
 
+    def add_state_subscriber(self, subscriber):
+        """No-op by default; handles that emit state change events override this."""
+        pass
+
     @staticmethod
     def create_from_repository_location_origin(repo_location_origin):
         check.inst_param(repo_location_origin, "repo_location_origin", RepositoryLocationOrigin)
@@ -90,6 +98,7 @@
 
     def __init__(self, origin):
         from dagster.grpc.client import DagsterGrpcClient
+        from dagster.grpc.server_watcher import create_grpc_watch_thread
 
         self.origin = check.inst_param(origin, "origin", GrpcServerRepositoryLocationOrigin)
 
@@ -105,9 +114,47 @@
             symbol.repository_name for symbol in list_repositories_response.repository_symbols
         )
 
+        self._state_subscribers = []
+        # Watch the server on a separate thread and relay its callbacks to subscribers.
+        watch_thread_shutdown_event, watch_thread = create_grpc_watch_thread(
+            self.client,
+            on_updated=lambda: self._send_state_event_to_subscribers(
+                LocationStateChangeEvent(
+                    LocationStateChangeEventType.LOCATION_UPDATED,
+                    location_name=self.location_name,
+                    message="Server has been updated.",
+                )
+            ),
+            on_error=lambda: self._send_state_event_to_subscribers(
+                LocationStateChangeEvent(
+                    LocationStateChangeEventType.LOCATION_ERROR,
+                    location_name=self.location_name,
+                    message="Unable to reconnect to server. In error state",
+                )
+            ),
+        )
+        self._watch_thread_shutdown_event = watch_thread_shutdown_event
+        self._watch_thread = watch_thread
+        self._watch_thread.start()
+
         self.executable_path = list_repositories_response.executable_path
         self.repository_code_pointer_dict = list_repositories_response.repository_code_pointer_dict
 
+    def add_state_subscriber(self, subscriber):
+        """Register a subscriber to be notified of this location's state change events."""
+        self._state_subscribers.append(subscriber)
+
+    def _send_state_event_to_subscribers(self, event):
+        """Deliver ``event`` to every registered subscriber, in registration order."""
+        check.inst_param(event, "event", LocationStateChangeEvent)
+        for subscriber in self._state_subscribers:
+            subscriber.handle_event(event)
+
+    def cleanup(self):
+        """Signal the watch thread to shut down and wait for it to finish."""
+        self._watch_thread_shutdown_event.set()
+        self._watch_thread.join()
+
     @property
     def port(self):
         return self.origin.port