diff --git a/.buildkite/dagster-buildkite/dagster_buildkite/steps/trigger.py b/.buildkite/dagster-buildkite/dagster_buildkite/steps/trigger.py index b8636e1ab..bc2c8a3ba 100644 --- a/.buildkite/dagster-buildkite/dagster_buildkite/steps/trigger.py +++ b/.buildkite/dagster-buildkite/dagster_buildkite/steps/trigger.py @@ -1,37 +1,39 @@ import subprocess from typing import Dict, List, Optional def trigger_step( pipeline: str, branches: Optional[List[str]] = None, async_step: bool = False, if_condition: str = None, ) -> Dict: """trigger_step: Trigger a build of another pipeline. See: https://buildkite.com/docs/pipelines/trigger-step Parameters: pipeline (str): The pipeline to trigger branches (List[str]): List of branches to trigger async_step (bool): If set to true, the step will immediately continue, regardless of the success of the triggered build. If set to false, the step will wait for the triggered build to complete and continue only if the triggered build passed. if_condition (str): A boolean expression that omits the step when it evaluates to false. Cannot be used if "branches" is also set. """ - commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() + commit = ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("utf-8").strip() + ) step = { "trigger": pipeline, "label": f":link: {pipeline} from dagster@{commit}", "async": async_step, } if branches: step["branches"] = " ".join(branches) if if_condition: step["if"] = if_condition return step diff --git a/.buildkite/dagster-buildkite/dagster_buildkite/utils.py b/.buildkite/dagster-buildkite/dagster_buildkite/utils.py index f2ae7055e..27d38146a 100644 --- a/.buildkite/dagster-buildkite/dagster_buildkite/utils.py +++ b/.buildkite/dagster-buildkite/dagster_buildkite/utils.py @@ -1,96 +1,96 @@ import os import subprocess import yaml DAGIT_PATH = "js_modules/dagit" def buildkite_yaml_for_steps(steps): return yaml.dump( { "env": { "CI_NAME": "buildkite", "CI_BUILD_NUMBER": "$BUILDKITE_BUILD_NUMBER", "CI_BUILD_URL": "$BUILDKITE_BUILD_URL", "CI_BRANCH": "$BUILDKITE_BRANCH", "CI_PULL_REQUEST": "$BUILDKITE_PULL_REQUEST", }, "steps": steps, }, default_flow_style=False, ) def check_for_release(): try: git_tag = str( subprocess.check_output( ["git", "describe", "--exact-match", "--abbrev=0"], stderr=subprocess.STDOUT ) ).strip("'b\\n") except subprocess.CalledProcessError: return False version = {} with open("python_modules/dagster/dagster/version.py") as fp: exec(fp.read(), version) # pylint: disable=W0122 if git_tag == version["__version__"]: return True return False def is_phab_and_dagit_only(): branch_name = os.getenv("BUILDKITE_BRANCH") if branch_name is None: branch_name = ( subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) .decode("utf-8") .strip() ) if not branch_name.startswith("phabricator"): return False try: base_branch = branch_name.replace("/diff/", "/base/") subprocess.check_call(["git", "fetch", "--tags"]) diff_files = ( subprocess.check_output(["git", "diff", base_branch, branch_name, "--name-only"]) - .decode() + .decode("utf-8") .strip() .split("\n") ) return all(filepath.startswith(DAGIT_PATH) for (filepath) in diff_files) except subprocess.CalledProcessError: return False def network_buildkite_container(network_name): return [ # hold onto your hats, this is docker networking at its best. First, we figure out # the name of the currently running container...
"export CONTAINER_ID=`cut -c9- < /proc/1/cpuset`", r'export CONTAINER_NAME=`docker ps --filter "id=\${CONTAINER_ID}" --format "{{.Names}}"`', # then, we dynamically bind this container into the user-defined bridge # network to make the target containers visible... "docker network connect {network_name} \\${{CONTAINER_NAME}}".format( network_name=network_name ), ] def connect_sibling_docker_container(network_name, container_name, env_variable): return [ # Now, we grab the IP address of the target container from within the target # bridge network and export it; this will let the tox tests talk to the target container. ( "export {env_variable}=`docker inspect --format " "'{{{{ .NetworkSettings.Networks.{network_name}.IPAddress }}}}' " "{container_name}`".format( network_name=network_name, container_name=container_name, env_variable=env_variable ) ) ] diff --git a/docs/test_doc_build.py b/docs/test_doc_build.py index 275a4468c..63b03dcad 100644 --- a/docs/test_doc_build.py +++ b/docs/test_doc_build.py @@ -1,90 +1,89 @@ import json import os import subprocess import pytest -import six from dagster import check from dagster.utils import file_relative_path def git_repo_root(): - return six.ensure_str(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).strip()) + return subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip() def assert_documented_exports(module_name, module, whitelist=None): whitelist = check.opt_set_param(whitelist, "whitelist") all_exports = module.__all__ path_to_export_index = os.path.join(git_repo_root(), "docs/next/src/data/exportindex.json") with open(path_to_export_index, "r") as f: export_index = json.load(f) documented_exports = set(export_index[module_name]) undocumented_exports = set(all_exports).difference(documented_exports).difference(whitelist) assert len(undocumented_exports) == 0, "Top level exports {} are not documented".format( undocumented_exports ) def test_documented_exports(): # If this test is failing, you have added a top level export to a module that is undocumented. # Make sure you include documentation for the new export in the appropriate file under # docs/sections/api/apidocs and add a docblock to the definition. 
import dagster import dagster_gcp import dagster_pandas modules = { "dagster": { "module": dagster, "whitelist": { # NOTE: Do not add any additional entries to this whitelist "ScalarUnion", "DefaultRunLauncher", "build_intermediate_storage_from_object_store", "SolidExecutionContext", "SerializationStrategy", "Materialization", "local_file_manager", "SystemStorageData", }, }, "dagster_gcp": {"module": dagster_gcp}, "dagster_pandas": { "module": dagster_pandas, "whitelist": { # NOTE: Do not add any additional entries to this whitelist "ConstraintWithMetadataException", "all_unique_validator", "ColumnWithMetadataException", "categorical_column_validator_factory", "MultiConstraintWithMetadata", "MultiColumnConstraintWithMetadata", "non_null_validation", "StrictColumnsWithMetadata", "MultiAggregateConstraintWithMetadata", "ConstraintWithMetadata", "dtype_in_set_validation_factory", "nonnull", "create_structured_dataframe_type", "column_range_validation_factory", }, }, } for module_name, value in modules.items(): module = value["module"] whitelist = value.get("whitelist") assert_documented_exports(module_name, module, whitelist) @pytest.mark.docs def test_build_all_docs(): pwd = os.getcwd() try: os.chdir(file_relative_path(__file__, ".")) subprocess.check_output(["make", "clean"]) subprocess.check_output(["make", "html"]) finally: os.chdir(pwd) diff --git a/examples/airline_demo/airline_demo_tests/conftest.py b/examples/airline_demo/airline_demo_tests/conftest.py index d94a16832..e92903c6f 100644 --- a/examples/airline_demo/airline_demo_tests/conftest.py +++ b/examples/airline_demo/airline_demo_tests/conftest.py @@ -1,83 +1,83 @@ import os import subprocess import pytest from dagster.utils import pushd, script_relative_path from dagster_postgres.utils import get_conn_string, wait_for_connection BUILDKITE = bool(os.getenv("BUILDKITE")) def is_postgres_running(): try: output = subprocess.check_output( [ "docker", "container", "ps", "-f", "name=test-postgres-db-airline", "-f", "status=running", ] ) - decoded = output.decode() + decoded = output.decode("utf-8") lines = decoded.split("\n") # header, one line for container, trailing \n return len(lines) == 3 except: # pylint: disable=bare-except return False @pytest.fixture(scope="session") def spark_config(): spark_packages = [ "com.databricks:spark-avro_2.11:3.0.0", "com.databricks:spark-redshift_2.11:2.0.1", "com.databricks:spark-csv_2.11:1.5.0", "org.postgresql:postgresql:42.2.5", "org.apache.hadoop:hadoop-aws:2.6.5", "com.amazonaws:aws-java-sdk:1.7.4", ] return {"spark": {"jars": {"packages": ",".join(spark_packages)}}} @pytest.fixture(scope="session") def pg_hostname(): # In buildkite we get the ip address from this variable (see buildkite code for commentary) # Otherwise assume local development and assume localhost env_name = "POSTGRES_TEST_DB_HOST" if env_name not in os.environ: os.environ[env_name] = "localhost" return os.environ[env_name] @pytest.fixture(scope="function") def postgres(pg_hostname): # pylint: disable=redefined-outer-name if BUILDKITE: yield return script_path = script_relative_path(".") if not is_postgres_running(): with pushd(script_path): try: subprocess.check_output(["docker-compose", "stop", "test-postgres-db-airline"]) subprocess.check_output(["docker-compose", "rm", "-f", "test-postgres-db-airline"]) except Exception: # pylint: disable=broad-except pass subprocess.check_output(["docker-compose", "up", "-d", "test-postgres-db-airline"]) wait_for_connection( get_conn_string(username="test", password="test", 
hostname=pg_hostname, db_name="test") ) yield @pytest.fixture(scope="session") def s3_bucket(): yield "dagster-scratch-80542c2" diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py index eaa443dde..ecf9a247f 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py @@ -1,102 +1,102 @@ import zipfile import boto3 from airline_demo.unzip_file_handle import unzip_file_handle from dagster import ( LocalFileHandle, ModeDefinition, OutputDefinition, ResourceDefinition, execute_pipeline, local_file_manager, pipeline, solid, ) from dagster.utils.test import get_temp_file_name from dagster_aws.s3 import S3FileHandle, S3FileManager, s3_intermediate_storage from moto import mock_s3 # for dep graphs def write_zip_file_to_disk(zip_file_path, archive_member, data): with zipfile.ZipFile(zip_file_path, mode="w") as archive: archive.writestr(data=data, zinfo_or_arcname=archive_member) def test_unzip_file_handle(): - data = "foo".encode() + data = b"foo" with get_temp_file_name() as zip_file_name: write_zip_file_to_disk(zip_file_name, "some_archive_member", data) @solid def to_zip_file_handle(_): return LocalFileHandle(zip_file_name) @pipeline(mode_defs=[ModeDefinition(resource_defs={"file_manager": local_file_manager})]) def do_test_unzip_file_handle(): return unzip_file_handle(to_zip_file_handle()) result = execute_pipeline( do_test_unzip_file_handle, run_config={ "solids": { "unzip_file_handle": { "inputs": {"archive_member": {"value": "some_archive_member"}} } } }, ) assert result.success @mock_s3 def test_unzip_file_handle_on_fake_s3(): - foo_bytes = "foo".encode() + foo_bytes = b"foo" @solid(required_resource_keys={"file_manager"}, output_defs=[OutputDefinition(S3FileHandle)]) def write_zipped_file_to_s3_store(context): with get_temp_file_name() as zip_file_name: write_zip_file_to_disk(zip_file_name, "an_archive_member", foo_bytes) with open(zip_file_name, "rb") as ff: s3_file_handle = context.resources.file_manager.write_data(ff.read()) return s3_file_handle # Uses mock S3 s3 = boto3.client("s3") s3.create_bucket(Bucket="some-bucket") file_manager = S3FileManager(s3_session=s3, s3_bucket="some-bucket", s3_base_key="dagster") @pipeline( mode_defs=[ ModeDefinition( resource_defs={ "s3": ResourceDefinition.hardcoded_resource(s3), "file_manager": ResourceDefinition.hardcoded_resource(file_manager), }, intermediate_storage_defs=[s3_intermediate_storage], ) ] ) def do_test_unzip_file_handle_s3(): return unzip_file_handle(write_zipped_file_to_s3_store()) result = execute_pipeline( do_test_unzip_file_handle_s3, run_config={ "storage": {"s3": {"config": {"s3_bucket": "some-bucket"}}}, "solids": { "unzip_file_handle": {"inputs": {"archive_member": {"value": "an_archive_member"}}} }, }, ) assert result.success zipped_s3_file = result.result_for_solid("write_zipped_file_to_s3_store").output_value() unzipped_s3_file = result.result_for_solid("unzip_file_handle").output_value() bucket_keys = [obj["Key"] for obj in s3.list_objects(Bucket="some-bucket")["Contents"]] assert zipped_s3_file.s3_key in bucket_keys assert unzipped_s3_file.s3_key in bucket_keys diff --git a/examples/dbt_example/dbt_example_tests/conftest.py b/examples/dbt_example/dbt_example_tests/conftest.py index 726c6a160..cc919c8d9 100644 --- a/examples/dbt_example/dbt_example_tests/conftest.py +++ 
b/examples/dbt_example/dbt_example_tests/conftest.py @@ -1,82 +1,82 @@ import os import subprocess import pytest from dagster.utils import file_relative_path, pushd from dagster_postgres.utils import get_conn_string, wait_for_connection BUILDKITE = bool(os.getenv("BUILDKITE")) def is_postgres_running(): try: output = subprocess.check_output( [ "docker", "container", "ps", "-f", "name=dbt_example_postgresql", "-f", "status=running", ] ) - decoded = output.decode() + decoded = output.decode("utf-8") lines = decoded.split("\n") # header, one line for container, trailing \n return len(lines) == 3 except: # pylint: disable=bare-except return False @pytest.fixture(scope="session") def pg_hostname(): # In buildkite we get the ip address from this variable (see buildkite code for commentary) # Otherwise assume local development and assume localhost env_name = "DAGSTER_DBT_EXAMPLE_PGHOST" original_value = os.getenv(env_name) try: if original_value is None: os.environ[env_name] = "localhost" yield os.environ[env_name] finally: if original_value is None: del os.environ[env_name] @pytest.fixture(scope="function") def postgres(pg_hostname): # pylint: disable=redefined-outer-name conn_string = get_conn_string( username="dbt_example", password="dbt_example", hostname=pg_hostname, db_name="dbt_example", ) if not BUILDKITE: script_path = file_relative_path(__file__, ".") if not is_postgres_running(): with pushd(script_path): try: subprocess.check_output(["docker-compose", "stop", "dbt_example_postgresql"]) subprocess.check_output( ["docker-compose", "rm", "-f", "dbt_example_postgresql"] ) except Exception: # pylint: disable=broad-except pass subprocess.check_output(["docker-compose", "up", "-d", "dbt_example_postgresql"]) wait_for_connection(conn_string) old_env = None if os.getenv("DBT_EXAMPLE_CONN_STRING") is not None: old_env = os.getenv("DBT_EXAMPLE_CONN_STRING") try: os.environ["DBT_EXAMPLE_CONN_STRING"] = conn_string yield finally: if old_env is not None: os.environ["DBT_EXAMPLE_CONN_STRING"] = old_env else: del os.environ["DBT_EXAMPLE_CONN_STRING"] diff --git a/examples/docs_snippets/docs_snippets/legacy/data_science/download_file.py b/examples/docs_snippets/docs_snippets/legacy/data_science/download_file.py index 9b7d58e4a..15b0621fb 100644 --- a/examples/docs_snippets/docs_snippets/legacy/data_science/download_file.py +++ b/examples/docs_snippets/docs_snippets/legacy/data_science/download_file.py @@ -1,25 +1,26 @@ +from urllib.request import urlretrieve + from dagster import Field, OutputDefinition, String, solid from dagster.utils import script_relative_path -from six.moves.urllib.request import urlretrieve @solid( name="download_file", config_schema={ "url": Field(String, description="The URL from which to download the file"), "path": Field(String, description="The path to which to download the file"), }, output_defs=[ OutputDefinition( String, name="path", description="The path to which the file was downloaded" ) ], description=( "A simple utility solid that downloads a file from a URL to a path using " "urllib.urlretrieve" ), ) def download_file(context): output_path = script_relative_path(context.solid_config["path"]) urlretrieve(context.solid_config["url"], output_path) return output_path diff --git a/examples/legacy_examples/dagster_examples_tests/event_pipeline_demo_tests/conftest.py b/examples/legacy_examples/dagster_examples_tests/event_pipeline_demo_tests/conftest.py index 1f62216a4..092ca51fe 100644 --- 
a/examples/legacy_examples/dagster_examples_tests/event_pipeline_demo_tests/conftest.py +++ b/examples/legacy_examples/dagster_examples_tests/event_pipeline_demo_tests/conftest.py @@ -1,61 +1,60 @@ import os import subprocess import pytest -import six from dagster.seven import get_system_temp_directory from dagster.utils import mkdir_p @pytest.fixture(scope="session") def events_jar(): - git_repo_root = six.ensure_str( - subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).strip() + git_repo_root = ( + subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip() ) temp_dir = os.path.join( get_system_temp_directory(), "dagster_examples_tests", "event_pipeline_demo_tests" ) mkdir_p(temp_dir) dst = os.path.join(temp_dir, "events.jar") if os.path.exists(dst): print("events jar already exists, skipping") # pylint: disable=print-call else: subprocess.check_call( ["sbt", "events/assembly"], cwd=os.path.join(git_repo_root, "scala_modules") ) src = os.path.join( git_repo_root, "scala_modules", "events/target/scala-2.11/events-assembly-0.1.0-SNAPSHOT.jar", ) subprocess.check_call(["cp", src, dst]) yield dst @pytest.fixture(scope="session") def spark_home(): spark_home_already_set = os.getenv("SPARK_HOME") is not None try: if not spark_home_already_set: try: import pyspark os.environ["SPARK_HOME"] = os.path.dirname(pyspark.__file__) # We don't have pyspark on this machine, and no spark home set, so there's nothing we # can do. Just give up - this fixture will end up yielding None except ModuleNotFoundError: pass yield os.getenv("SPARK_HOME") finally: # we set it, so clean up after ourselves if not spark_home_already_set and "SPARK_HOME" in os.environ: del os.environ["SPARK_HOME"] diff --git a/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/helm.py b/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/helm.py index bae9d5f45..b7155b848 100644 --- a/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/helm.py +++ b/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/helm.py @@ -1,610 +1,609 @@ # pylint: disable=print-call import base64 import os import subprocess import time from contextlib import contextmanager import kubernetes import pytest -import six import yaml from dagster import check from dagster.utils import git_repository_root from dagster_k8s.utils import wait_for_pod from .integration_utils import IS_BUILDKITE, check_output, get_test_namespace, image_pull_policy TEST_AWS_CONFIGMAP_NAME = "test-aws-env-configmap" TEST_CONFIGMAP_NAME = "test-env-configmap" TEST_SECRET_NAME = "test-env-secret" # By default, dagster.workers.fullname is ReleaseName-celery-workers CELERY_WORKER_NAME_PREFIX = "dagster-celery-workers" @contextmanager def _helm_namespace_helper(docker_image, helm_chart_fn, request): """If an existing Helm chart namespace is specified via pytest CLI with the argument --existing-helm-namespace, we will use that chart. Otherwise, provision a test namespace and install Helm chart into that namespace. Yields the Helm chart namespace. 
""" existing_helm_namespace = request.config.getoption("--existing-helm-namespace") if existing_helm_namespace: yield existing_helm_namespace else: # Never bother cleaning up on Buildkite if IS_BUILDKITE: should_cleanup = False # Otherwise, always clean up unless --no-cleanup specified else: should_cleanup = not request.config.getoption("--no-cleanup") with get_helm_test_namespace(should_cleanup) as namespace: with helm_test_resources(namespace, should_cleanup): with helm_chart_fn(namespace, docker_image, should_cleanup): print("Helm chart successfully installed in namespace %s" % namespace) yield namespace @pytest.fixture(scope="session") def helm_namespace_for_user_deployments( dagster_docker_image, cluster_provider, request ): # pylint: disable=unused-argument, redefined-outer-name with _helm_namespace_helper( dagster_docker_image, helm_chart_for_user_deployments, request ) as namespace: yield namespace @pytest.fixture(scope="session") def helm_namespace_for_daemon( dagster_docker_image, cluster_provider, request ): # pylint: disable=unused-argument, redefined-outer-name with _helm_namespace_helper(dagster_docker_image, helm_chart_for_daemon, request) as namespace: yield namespace @pytest.fixture(scope="session") def helm_namespace( dagster_docker_image, cluster_provider, request ): # pylint: disable=unused-argument, redefined-outer-name with _helm_namespace_helper(dagster_docker_image, helm_chart, request) as namespace: yield namespace @pytest.fixture(scope="session") def helm_namespace_for_k8s_run_launcher( dagster_docker_image, cluster_provider, request ): # pylint: disable=unused-argument, redefined-outer-name with _helm_namespace_helper( dagster_docker_image, helm_chart_for_k8s_run_launcher, request ) as namespace: yield namespace @contextmanager def get_helm_test_namespace(should_cleanup=True): # Will be something like dagster-test-3fcd70 to avoid ns collisions in shared test environment namespace = get_test_namespace() print("--- \033[32m:k8s: Creating test namespace %s\033[0m" % namespace) kube_api = kubernetes.client.CoreV1Api() try: print("Creating namespace %s" % namespace) kube_namespace = kubernetes.client.V1Namespace( metadata=kubernetes.client.V1ObjectMeta(name=namespace) ) kube_api.create_namespace(kube_namespace) yield namespace finally: # Can skip this step as a time saver when we're going to destroy the cluster anyway, e.g. # w/ a kind cluster if should_cleanup: print("Deleting namespace %s" % namespace) kube_api.delete_namespace(name=namespace) @contextmanager def helm_test_resources(namespace, should_cleanup=True): """Create a couple of resources to test Helm interaction w/ pre-existing resources. 
""" check.str_param(namespace, "namespace") check.bool_param(should_cleanup, "should_cleanup") try: print( "Creating k8s test objects ConfigMap %s and Secret %s" % (TEST_CONFIGMAP_NAME, TEST_SECRET_NAME) ) kube_api = kubernetes.client.CoreV1Api() configmap = kubernetes.client.V1ConfigMap( api_version="v1", kind="ConfigMap", data={"TEST_ENV_VAR": "foobar"}, metadata=kubernetes.client.V1ObjectMeta(name=TEST_CONFIGMAP_NAME), ) kube_api.create_namespaced_config_map(namespace=namespace, body=configmap) if not IS_BUILDKITE: aws_data = { "AWS_ACCOUNT_ID": os.getenv("AWS_ACCOUNT_ID"), "AWS_ACCESS_KEY_ID": os.getenv("AWS_ACCESS_KEY_ID"), "AWS_SECRET_ACCESS_KEY": os.getenv("AWS_SECRET_ACCESS_KEY"), } if not aws_data["AWS_ACCESS_KEY_ID"] or not aws_data["AWS_SECRET_ACCESS_KEY"]: raise Exception( "Must have AWS credentials set in AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY " "to be able to run Helm tests locally" ) print("Creating ConfigMap %s with AWS credentials" % (TEST_AWS_CONFIGMAP_NAME)) aws_configmap = kubernetes.client.V1ConfigMap( api_version="v1", kind="ConfigMap", data=aws_data, metadata=kubernetes.client.V1ObjectMeta(name=TEST_AWS_CONFIGMAP_NAME), ) kube_api.create_namespaced_config_map(namespace=namespace, body=aws_configmap) # Secret values are expected to be base64 encoded - secret_val = six.ensure_str(base64.b64encode(six.ensure_binary("foobar"))) + secret_val = base64.b64encode(b"foobar").decode("utf-8") secret = kubernetes.client.V1Secret( api_version="v1", kind="Secret", data={"TEST_SECRET_ENV_VAR": secret_val}, metadata=kubernetes.client.V1ObjectMeta(name=TEST_SECRET_NAME), ) kube_api.create_namespaced_secret(namespace=namespace, body=secret) yield finally: # Can skip this step as a time saver when we're going to destroy the cluster anyway, e.g. # w/ a kind cluster if should_cleanup: kube_api.delete_namespaced_config_map(name=TEST_CONFIGMAP_NAME, namespace=namespace) kube_api.delete_namespaced_secret(name=TEST_SECRET_NAME, namespace=namespace) @contextmanager def _helm_chart_helper(namespace, should_cleanup, helm_config, helm_install_name): """Install helm chart. 
""" check.str_param(namespace, "namespace") check.bool_param(should_cleanup, "should_cleanup") check.str_param(helm_install_name, "helm_install_name") print("--- \033[32m:helm: Installing Helm chart {}\033[0m".format(helm_install_name)) try: helm_config_yaml = yaml.dump(helm_config, default_flow_style=False) helm_cmd = [ "helm", "install", "--namespace", namespace, "-f", "-", "dagster", os.path.join(git_repository_root(), "helm", "dagster"), ] print("Running Helm Install: \n", " ".join(helm_cmd), "\nWith config:\n", helm_config_yaml) p = subprocess.Popen( helm_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - stdout, stderr = p.communicate(six.ensure_binary(helm_config_yaml)) + stdout, stderr = p.communicate(helm_config_yaml.encode("utf-8")) print("Helm install completed with stdout: ", stdout) print("Helm install completed with stderr: ", stderr) assert p.returncode == 0 # Wait for Dagit pod to be ready (won't actually stay up w/out js rebuild) kube_api = kubernetes.client.CoreV1Api() print("Waiting for Dagit pod to be ready...") dagit_pod = None while dagit_pod is None: pods = kube_api.list_namespaced_pod(namespace=namespace) pod_names = [p.metadata.name for p in pods.items if "dagit" in p.metadata.name] if pod_names: dagit_pod = pod_names[0] time.sleep(1) # Wait for Celery worker queues to become ready pods = kubernetes.client.CoreV1Api().list_namespaced_pod(namespace=namespace) pod_names = [ p.metadata.name for p in pods.items if CELERY_WORKER_NAME_PREFIX in p.metadata.name ] if helm_config.get("runLauncher").get("type") == "CeleryK8sRunLauncher": worker_queues = ( helm_config.get("runLauncher") .get("config") .get("celeryK8sRunLauncher") .get("workerQueues", []) ) for queue in worker_queues: num_pods_for_queue = len( [ pod_name for pod_name in pod_names if f"{CELERY_WORKER_NAME_PREFIX}-{queue.get('name')}" in pod_name ] ) assert queue.get("replicaCount") == num_pods_for_queue print("Waiting for celery workers") for pod_name in pod_names: print("Waiting for Celery worker pod %s" % pod_name) wait_for_pod(pod_name, namespace=namespace) rabbitmq_enabled = ("rabbitmq" not in helm_config) or helm_config.get("rabbitmq") if rabbitmq_enabled: print("Waiting for rabbitmq pod to exist...") while True: pods = kube_api.list_namespaced_pod(namespace=namespace) pod_names = [ p.metadata.name for p in pods.items if "rabbitmq" in p.metadata.name ] if pod_names: assert len(pod_names) == 1 print("Waiting for rabbitmq pod to be ready: " + str(pod_names[0])) wait_for_pod(pod_names[0], namespace=namespace) break time.sleep(1) else: assert ( len(pod_names) == 0 ), "celery-worker pods {pod_names} exists when celery is not enabled.".format( pod_names=pod_names ) if helm_config.get("userDeployments") and helm_config.get("userDeployments", {}).get( "enabled" ): # Wait for user code deployments to be ready print("Waiting for user code deployments") pods = kubernetes.client.CoreV1Api().list_namespaced_pod(namespace=namespace) pod_names = [ p.metadata.name for p in pods.items if "user-code-deployment" in p.metadata.name ] for pod_name in pod_names: print("Waiting for user code deployment pod %s" % pod_name) wait_for_pod(pod_name, namespace=namespace) yield finally: # Can skip this step as a time saver when we're going to destroy the cluster anyway, e.g. 
# w/ a kind cluster if should_cleanup: print("Uninstalling helm chart") check_output( ["helm", "uninstall", "dagster", "--namespace", namespace], cwd=git_repository_root(), ) @contextmanager def helm_chart(namespace, docker_image, should_cleanup=True): check.str_param(namespace, "namespace") check.str_param(docker_image, "docker_image") check.bool_param(should_cleanup, "should_cleanup") repository, tag = docker_image.split(":") pull_policy = image_pull_policy() helm_config = { "dagit": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "env": {"TEST_SET_ENV_VAR": "test_dagit_env_var"}, "envConfigMaps": [TEST_CONFIGMAP_NAME], "envSecrets": [TEST_SECRET_NAME], "livenessProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "failureThreshold": 6, "periodSeconds": 10, }, }, "flower": { "enabled": True, "livenessProbe": { "tcpSocket": {"port": "flower"}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "tcpSocket": {"port": "flower"}, "failureThreshold": 6, "periodSeconds": 10, }, }, "runLauncher": { "type": "CeleryK8sRunLauncher", "config": { "celeryK8sRunLauncher": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "workerQueues": [ {"name": "dagster", "replicaCount": 2}, {"name": "extra-queue-1", "replicaCount": 1}, ], "livenessProbe": { "initialDelaySeconds": 15, "periodSeconds": 10, "timeoutSeconds": 10, "successThreshold": 1, "failureThreshold": 3, }, }, }, }, "ingress": { "enabled": True, "dagit": {"host": "dagit.example.com"}, "flower": {"flower": "flower.example.com"}, }, "scheduler": { "type": "K8sScheduler", "config": {"k8sScheduler": {"schedulerNamespace": namespace}}, }, "serviceAccount": {"name": "dagit-admin"}, "postgresqlPassword": "test", "postgresqlDatabase": "test", "postgresqlUser": "test", "dagsterDaemon": {"enabled": False}, } with _helm_chart_helper(namespace, should_cleanup, helm_config, helm_install_name="helm_chart"): yield @contextmanager def helm_chart_for_k8s_run_launcher(namespace, docker_image, should_cleanup=True): check.str_param(namespace, "namespace") check.str_param(docker_image, "docker_image") check.bool_param(should_cleanup, "should_cleanup") repository, tag = docker_image.split(":") pull_policy = image_pull_policy() helm_config = { "dagit": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "env": {"TEST_SET_ENV_VAR": "test_dagit_env_var"}, "envConfigMaps": [TEST_CONFIGMAP_NAME], "envSecrets": [TEST_SECRET_NAME], "livenessProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "failureThreshold": 6, "periodSeconds": 10, }, }, "runLauncher": { "type": "K8sRunLauncher", "config": {"k8sRunLauncher": {"jobNamespace": namespace}}, }, "scheduler": { "type": "K8sScheduler", "config": {"k8sScheduler": {"schedulerNamespace": namespace}}, }, "serviceAccount": {"name": "dagit-admin"}, "postgresqlPassword": "test", "postgresqlDatabase": "test", "postgresqlUser": "test", "dagsterDaemon": {"enabled": False}, } with _helm_chart_helper( namespace, should_cleanup, helm_config, helm_install_name="helm_chart_for_k8s_run_launcher" ): yield @contextmanager def helm_chart_for_user_deployments(namespace, docker_image, should_cleanup=True): check.str_param(namespace, "namespace") check.str_param(docker_image, "docker_image") check.bool_param(should_cleanup, 
"should_cleanup") repository, tag = docker_image.split(":") pull_policy = image_pull_policy() helm_config = { "userDeployments": { "enabled": True, "deployments": [ { "name": "user-code-deployment-1", "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "dagsterApiGrpcArgs": [ "-m", "dagster_test.test_project.test_pipelines.repo", "-a", "define_demo_execution_repo", ], "port": 3030, "replicaCount": 1, } ], }, "dagit": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "env": {"TEST_SET_ENV_VAR": "test_dagit_env_var"}, "envConfigMaps": [TEST_CONFIGMAP_NAME], "envSecrets": [TEST_SECRET_NAME], "livenessProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "failureThreshold": 6, "periodSeconds": 10, }, }, "flower": { "livenessProbe": { "tcpSocket": {"port": "flower"}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "tcpSocket": {"port": "flower"}, "failureThreshold": 6, "periodSeconds": 10, }, }, "runLauncher": { "type": "CeleryK8sRunLauncher", "config": { "celeryK8sRunLauncher": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "workerQueues": [ {"name": "dagster", "replicaCount": 2}, {"name": "extra-queue-1", "replicaCount": 1}, ], "livenessProbe": { "initialDelaySeconds": 15, "periodSeconds": 10, "timeoutSeconds": 10, "successThreshold": 1, "failureThreshold": 3, }, "configSource": { "broker_transport_options": {"priority_steps": [9]}, "worker_concurrency": 1, }, } }, }, "scheduler": { "type": "K8sScheduler", "config": {"k8sScheduler": {"schedulerNamespace": namespace}}, }, "serviceAccount": {"name": "dagit-admin"}, "postgresqlPassword": "test", "postgresqlDatabase": "test", "postgresqlUser": "test", "dagsterDaemon": {"enabled": False}, } with _helm_chart_helper( namespace, should_cleanup, helm_config, helm_install_name="helm_chart_for_user_deployments" ): yield @contextmanager def helm_chart_for_daemon(namespace, docker_image, should_cleanup=True): check.str_param(namespace, "namespace") check.str_param(docker_image, "docker_image") check.bool_param(should_cleanup, "should_cleanup") repository, tag = docker_image.split(":") pull_policy = image_pull_policy() helm_config = { "userDeployments": { "enabled": True, "deployments": [ { "name": "user-code-deployment-1", "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "dagsterApiGrpcArgs": [ "-m", "dagster_test.test_project.test_pipelines.repo", "-a", "define_demo_execution_repo", ], "port": 3030, "env": {"BUILDKITE": os.getenv("BUILDKITE")}, "annotations": {"dagster-integration-tests": "ucd-1-pod-annotation"}, "service": { "annotations": {"dagster-integration-tests": "ucd-1-svc-annotation"} }, "replicaCount": 1, } ], }, "dagit": { "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "env": {"TEST_SET_ENV_VAR": "test_dagit_env_var"}, "envConfigMaps": [TEST_CONFIGMAP_NAME], "envSecrets": [TEST_SECRET_NAME], "livenessProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "periodSeconds": 20, "failureThreshold": 3, }, "startupProbe": { "httpGet": {"path": "/dagit_info", "port": 80}, "failureThreshold": 6, "periodSeconds": 10, }, "annotations": {"dagster-integration-tests": "dagit-pod-annotation"}, "service": {"annotations": {"dagster-integration-tests": "dagit-svc-annotation"}}, }, "runLauncher": { "type": "CeleryK8sRunLauncher", "config": { "celeryK8sRunLauncher": { "image": {"repository": 
repository, "tag": tag, "pullPolicy": pull_policy}, "workerQueues": [ {"name": "dagster", "replicaCount": 2}, {"name": "extra-queue-1", "replicaCount": 1}, ], "livenessProbe": { "initialDelaySeconds": 15, "periodSeconds": 10, "timeoutSeconds": 10, "successThreshold": 1, "failureThreshold": 3, }, "configSource": { "broker_transport_options": {"priority_steps": [9]}, "worker_concurrency": 1, }, "annotations": {"dagster-integration-tests": "celery-pod-annotation"}, }, }, }, "scheduler": {"type": "DagsterDaemonScheduler", "config": {}}, "serviceAccount": {"name": "dagit-admin"}, "postgresqlPassword": "test", "postgresqlDatabase": "test", "postgresqlUser": "test", "dagsterDaemon": { "enabled": True, "image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}, "queuedRunCoordinator": {"enabled": True}, "env": {"BUILDKITE": os.getenv("BUILDKITE")}, "annotations": {"dagster-integration-tests": "daemon-pod-annotation"}, }, # Used to set the environment variables in dagster.shared_env that determine the run config "pipelineRun": {"image": {"repository": repository, "tag": tag, "pullPolicy": pull_policy}}, } with _helm_chart_helper( namespace, should_cleanup, helm_config, helm_install_name="helm_chart_for_daemon" ): yield diff --git a/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/integration_utils.py b/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/integration_utils.py index 67ec83131..1aae50b45 100644 --- a/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/integration_utils.py +++ b/integration_tests/python_modules/dagster-k8s-test-infra/dagster_k8s_test_infra/integration_utils.py @@ -1,69 +1,67 @@ import os import random import subprocess -import six - IS_BUILDKITE = os.getenv("BUILDKITE") is not None def image_pull_policy(): # This is because when running local tests, we need to load the image into the kind cluster (and # then not attempt to pull it) because we don't want to require credentials for a private # registry / pollute the private registry / set up and network a local registry as a condition # of running tests if IS_BUILDKITE: return "Always" else: return "IfNotPresent" def check_output(*args, **kwargs): try: return subprocess.check_output(*args, **kwargs) except subprocess.CalledProcessError as exc: - output = exc.output.decode() - six.raise_from(Exception(output), exc) + output = exc.output.decode("utf-8") + raise Exception(output) from exc def which_(exe): """Uses distutils to look for an executable, mimicking unix which""" from distutils import spawn # pylint: disable=no-name-in-module # https://github.com/PyCQA/pylint/issues/73 return spawn.find_executable(exe) def get_test_namespace(): namespace_suffix = hex(random.randint(0, 16 ** 6))[2:] return "dagster-test-%s" % namespace_suffix def within_docker(): """detect if we're running inside of a docker container from: https://stackoverflow.com/a/48710609/11295366 """ cgroup_path = "/proc/self/cgroup" return ( os.path.exists("/.dockerenv") or os.path.isfile(cgroup_path) and any("docker" in line for line in open(cgroup_path)) ) def remove_none_recursively(obj): """Remove None values from a dict. This is used here to support comparing provided config vs. config we retrieve from kubernetes, which returns all fields, even those which have no value configured. 
""" if isinstance(obj, (list, tuple, set)): return type(obj)(remove_none_recursively(x) for x in obj if x is not None) elif isinstance(obj, dict): return type(obj)( (remove_none_recursively(k), remove_none_recursively(v)) for k, v in obj.items() if k is not None and v is not None ) else: return obj diff --git a/python_modules/automation/automation/docs/check_library_docs.py b/python_modules/automation/automation/docs/check_library_docs.py index 7791280ac..64a1afba6 100644 --- a/python_modules/automation/automation/docs/check_library_docs.py +++ b/python_modules/automation/automation/docs/check_library_docs.py @@ -1,82 +1,81 @@ """This script is to ensure that we provide docs for every library that we create """ # pylint: disable=print-call import os import sys -import six from automation.git import git_repo_root EXPECTED_LIBRARY_README_CONTENTS = """ # {library} The docs for `{library}` can be found [here](https://docs.dagster.io/_apidocs/libraries/{library_underscore}). """.strip() def get_library_module_directories(): """List library module directories under python_modules/libraries. Returns: List(os.DirEntry): List of core module directories """ library_module_root_dir = os.path.join(git_repo_root(), "python_modules", "libraries") library_directories = [ dir_.name for dir_ in os.scandir(library_module_root_dir) if dir_.is_dir() and not dir_.name.startswith(".") ] return library_directories def check_readme_exists(readme_file, library_name): """Verify that a README.md is provided for a Dagster package """ exists = os.path.exists(readme_file) and os.path.isfile(readme_file) if not exists: print("Missing README.md for library %s!" % library_name) sys.exit(1) def check_readme_contents(readme_file, library_name): """Ensure README.md files have standardized contents. Some files are whitelisted until we have time to migrate their content to the API docs RST. """ expected = EXPECTED_LIBRARY_README_CONTENTS.format( library=library_name, library_underscore=library_name.replace("-", "_") ) with open(readme_file, "rb") as f: - contents = six.ensure_str(f.read()).strip() + contents = f.read().decode("utf-8").strip() if contents != expected: print("=" * 100) print("Readme %s contents do not match!" % readme_file) print("expected:\n%s" % expected) print("\n\nfound:\n%s" % contents) print("=" * 100) sys.exit(1) def check_api_docs(library_name): api_docs_root = os.path.join(git_repo_root(), "docs", "sections", "api", "apidocs", "libraries") underscore_name = library_name.replace("-", "_") if not os.path.exists(os.path.join(api_docs_root, "%s.rst" % underscore_name)): print("API docs not found for library %s!" 
% library_name) sys.exit(1) def validate_library_readmes(): dirs = get_library_module_directories() for library_name in dirs: library_root = os.path.join(git_repo_root(), "python_modules", "libraries") readme_file = os.path.join(library_root, library_name, "README.md") check_readme_exists(readme_file, library_name) check_readme_contents(readme_file, library_name) check_api_docs(library_name) print(":white_check_mark: All README.md contents exist and content validated!") print(":white_check_mark: All API docs exist!") diff --git a/python_modules/automation/automation/git.py b/python_modules/automation/automation/git.py index 9604a2640..c8a0e805b 100644 --- a/python_modules/automation/automation/git.py +++ b/python_modules/automation/automation/git.py @@ -1,149 +1,150 @@ import os import re import subprocess -import six - from .utils import check_output def git_check_status(): - changes = six.ensure_str(subprocess.check_output(["git", "status", "--porcelain"])) + changes = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip() + if changes != "": raise Exception( "Bailing: Cannot publish with changes present in git repo:\n{changes}".format( changes=changes ) ) def git_user(): - return six.ensure_str( - subprocess.check_output(["git", "config", "--get", "user.name"]).decode("utf-8").strip() - ) + return subprocess.check_output(["git", "config", "--get", "user.name"]).decode("utf-8").strip() def git_repo_root(): - return six.ensure_str(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).strip()) + return subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip() def git_push(tag=None, dry_run=True, cwd=None): github_token = os.getenv("GITHUB_TOKEN") github_username = os.getenv("GITHUB_USERNAME") if github_token and github_username: if tag: check_output( [ "git", "push", "https://{github_username}:{github_token}@github.com/dagster-io/dagster.git".format( github_username=github_username, github_token=github_token ), tag, ], dry_run=dry_run, cwd=cwd, ) check_output( [ "git", "push", "https://{github_username}:{github_token}@github.com/dagster-io/dagster.git".format( github_username=github_username, github_token=github_token ), ], dry_run=dry_run, cwd=cwd, ) else: if tag: check_output(["git", "push", "origin", tag], dry_run=dry_run, cwd=cwd) check_output(["git", "push"], dry_run=dry_run, cwd=cwd) def get_git_tag(): try: git_tag = str( subprocess.check_output( ["git", "describe", "--exact-match", "--abbrev=0"], stderr=subprocess.STDOUT ) ).strip("'b\\n") except subprocess.CalledProcessError as exc_info: match = re.search( "fatal: no tag exactly matches '(?P<commit>[a-z0-9]+)'", str(exc_info.output) ) if match: raise Exception( "Bailing: there is no git tag for the current commit, {commit}".format( commit=match.group("commit") ) ) raise Exception(str(exc_info.output)) return git_tag def get_most_recent_git_tag(): try: - git_tag = str( + git_tag = ( subprocess.check_output(["git", "describe", "--abbrev=0"], stderr=subprocess.STDOUT) - ).strip("'b\\n") + .decode("utf-8") + .strip() + ) except subprocess.CalledProcessError as exc_info: raise Exception(str(exc_info.output)) return git_tag def get_git_repo_branch(cwd=None): - git_branch = six.ensure_str( + git_branch = ( subprocess.check_output(["git", "branch", "--show-current"], cwd=cwd) - ).strip() + .decode("utf-8") + .strip() + ) return git_branch def set_git_tag(tag, signed=False, dry_run=True): try: if signed: if not dry_run: subprocess.check_output( ["git", "tag", "-s", "-m", tag, 
tag], stderr=subprocess.STDOUT ) else: if not dry_run: subprocess.check_output( ["git", "tag", "-a", "-m", tag, tag], stderr=subprocess.STDOUT ) except subprocess.CalledProcessError as exc_info: match = re.search("error: gpg failed to sign the data", str(exc_info.output)) if match: raise Exception( "Bailing: cannot sign tag. You may find " "https://stackoverflow.com/q/39494631/324449 helpful. Original error " "output:\n{output}".format(output=str(exc_info.output)) ) match = re.search( r"fatal: tag \'(?P<tag>[\.a-z0-9]+)\' already exists", str(exc_info.output) ) if match: raise Exception( "Bailing: cannot release version tag {tag}: already exists".format( tag=match.group("tag") ) ) raise Exception(str(exc_info.output)) return tag def git_commit_updates(repo_dir, message): cmds = [ "git add -A", 'git commit -m "{}"'.format(message), ] print( # pylint: disable=print-call "Committing to {} with message {}".format(repo_dir, message) ) for cmd in cmds: subprocess.call(cmd, cwd=repo_dir, shell=True) diff --git a/python_modules/automation/automation/parse_dataproc_configs.py b/python_modules/automation/automation/parse_dataproc_configs.py index 4e472373d..11bd2c077 100644 --- a/python_modules/automation/automation/parse_dataproc_configs.py +++ b/python_modules/automation/automation/parse_dataproc_configs.py @@ -1,257 +1,257 @@ import os import pprint from collections import namedtuple import requests from .printer import IndentingBufferPrinter SCALAR_TYPES = { "string": "String", "boolean": "Bool", "number": "Int", "enumeration": "String", "integer": "Int", } class List: def __init__(self, inner_type): self.inner_type = inner_type class Enum: def __init__(self, name, enum_names, enum_descriptions): self.name = name self.enum_names = enum_names self.enum_descriptions = enum_descriptions def write(self, printer): printer.line(self.name.title() + " = Enum(") with printer.with_indent(): printer.line("name='{}',".format(self.name.title())) printer.line("enum_values=[") with printer.with_indent(): if self.enum_descriptions: for name, value in zip(self.enum_names, self.enum_descriptions): prefix = "EnumValue('{}', description='''".format(name) printer.block(value + "'''),", initial_indent=prefix) else: for name in self.enum_names: printer.line("EnumValue('{}'),".format(name)) printer.line("],") printer.line(")") class Field: """Field represents a field type that we're going to write out as a dagster config field, once we've pre-processed all custom types """ def __init__(self, fields, is_required, description): self.fields = fields self.is_required = is_required self.description = description def __repr__(self): return "Field(%s, %s, %s)" % ( pprint.pformat(self.fields), str(self.is_required), self.description, ) def _print_fields(self, printer): # Scalars if isinstance(self.fields, str): printer.append(self.fields) # Enums elif isinstance(self.fields, Enum): printer.append(self.fields.name) # Lists elif isinstance(self.fields, List): printer.append("[") self.fields.inner_type.write(printer, field_wrapped=False) printer.append("]") # Dicts else: printer.line("Shape(") with printer.with_indent(): printer.line("fields={") with printer.with_indent(): for (k, v) in self.fields.items(): # We need to skip "output" fields which are API responses, not queries if "Output only" in v.description: continue # This v is a terminal scalar type, print directly if isinstance(v, str): printer.line("'{}': {},".format(k, v)) # Recurse nested fields else: with printer.with_indent(): printer.append("'{}': ".format(k)) 
v.write(printer) printer.append(",") printer.line("},") printer.line(")") def write(self, printer, field_wrapped=True): """Use field_wrapped=False for Lists that should not be wrapped in Field() """ if not field_wrapped: self._print_fields(printer) return printer.read() printer.append("Field(") printer.line("") with printer.with_indent(): self._print_fields(printer) printer.append(",") # Print description if self.description: printer.block( self.description.replace("'", "\\'") + "''',", initial_indent="description='''" ) # Print is_required=True/False if defined; if not defined, default to True printer.line( "is_required=%s," % str(self.is_required if self.is_required is not None else True) ) printer.line(")") return printer.read() class ParsedConfig(namedtuple("_ParsedConfig", "name configs enums")): def __new__(cls, name, configs, enums): return super(ParsedConfig, cls).__new__(cls, name, configs, enums) def write_configs(self, base_path): configs_filename = "configs_%s.py" % self.name print("Writing", configs_filename) # pylint: disable=print-call with open(os.path.join(base_path, configs_filename), "wb") as f: f.write(self.configs) enums_filename = "types_%s.py" % self.name with open(os.path.join(base_path, enums_filename), "wb") as f: f.write(self.enums) class ConfigParser: def __init__(self, schemas): self.schemas = schemas # Stashing these in a global so that we can write out after we're done constructing configs self.all_enums = {} def extract_config(self, base_field, suffix): with IndentingBufferPrinter() as printer: printer.write_header() printer.line("from dagster import Bool, Field, Int, Permissive, Shape, String") printer.blank_line() # Optionally write enum includes if self.all_enums: printer.line( "from .types_{} import {}".format(suffix, ", ".join(self.all_enums.keys())) ) printer.blank_line() printer.line("def define_%s_config():" % suffix) with printer.with_indent(): printer.append("return ") base_field.write(printer) - return printer.read().strip().encode() + return printer.read().strip().encode("utf-8") def extract_enums(self): if not self.all_enums: return with IndentingBufferPrinter() as printer: printer.write_header() printer.line("from dagster import Enum, EnumValue") printer.blank_line() for enum in self.all_enums: self.all_enums[enum].write(printer) printer.blank_line() - return printer.read().strip().encode() + return printer.read().strip().encode("utf-8") def parse_object(self, obj, name=None, depth=0, enum_descriptions=None): # This is a reference to another object that we should substitute by recursing if "$ref" in obj: name = obj["$ref"] return self.parse_object(self.schemas.get(name), name, depth + 1) # Print type tree prefix = "|" + ("-" * 4 * depth) + " " if depth > 0 else "" print(prefix + (name or obj.get("type"))) # pylint: disable=print-call # Switch on object type obj_type = obj.get("type") # Handle enums if "enum" in obj: # I think this is a bug in the API JSON spec where enum descriptions are a level higher # than they should be for type "Component" and the name isn't there if name is None: name = "Component" enum = Enum(name, obj["enum"], enum_descriptions or obj.get("enumDescriptions")) self.all_enums[name] = enum fields = enum # Handle dicts / objects elif obj_type == "object": # This is a generic k:v map if "additionalProperties" in obj: fields = "Permissive()" else: fields = { k: self.parse_object(v, k, depth + 1) for k, v in obj["properties"].items() } # Handle arrays elif obj_type == "array": fields = List( self.parse_object( 
obj.get("items"), None, depth + 1, enum_descriptions=obj.get("enumDescriptions") ) ) # Scalars elif obj_type in SCALAR_TYPES: fields = SCALAR_TYPES.get(obj_type) # Should never get here else: raise Exception("unknown type: ", obj) return Field(fields, is_required=None, description=obj.get("description")) def extract_schema_for_object(self, object_name, name): # Reset enums for this object self.all_enums = {} obj = self.parse_object(self.schemas.get(object_name), object_name) return ParsedConfig( name=name, configs=self.extract_config(obj, name), enums=self.extract_enums() ) def main(): api_url = "https://www.googleapis.com/discovery/v1/apis/dataproc/v1/rest" base_path = "../libraries/dagster-gcp/dagster_gcp/dataproc/" json_schema = requests.get(api_url).json().get("schemas") c = ConfigParser(json_schema) parsed = c.extract_schema_for_object("Job", "dataproc_job") parsed.write_configs(base_path) parsed = c.extract_schema_for_object("ClusterConfig", "dataproc_cluster") parsed.write_configs(base_path) if __name__ == "__main__": main() diff --git a/python_modules/automation/automation/parse_spark_configs.py b/python_modules/automation/automation/parse_spark_configs.py index 4715d51ca..b66545251 100644 --- a/python_modules/automation/automation/parse_spark_configs.py +++ b/python_modules/automation/automation/parse_spark_configs.py @@ -1,287 +1,287 @@ """Spark config codegen. This script parses the Spark configuration parameters downloaded from the Spark Github repository, and codegens a file that contains dagster configurations for these parameters. """ import re import sys from collections import namedtuple from enum import Enum import click import requests from automation.printer import IndentingBufferPrinter SPARK_VERSION = "v2.4.0" TABLE_REGEX = r"### (.{,30}?)\n\n(.*?<\/table>)" WHITESPACE_REGEX = r"\s+" class ConfigType(Enum): STRING = "String" INT = "Int" FLOAT = "Float" BOOL = "Bool" MEMORY = "String" # TODO: We should handle memory field types TIME = "String" # TODO: We should handle time field types CONFIG_TYPES = { # # APPLICATION PROPERTIES "spark.app.name": ConfigType.STRING, "spark.driver.cores": ConfigType.INT, "spark.driver.maxResultSize": ConfigType.MEMORY, "spark.driver.memory": ConfigType.MEMORY, "spark.driver.memoryOverhead": ConfigType.MEMORY, "spark.executor.memory": ConfigType.MEMORY, "spark.executor.pyspark.memory": ConfigType.MEMORY, "spark.executor.memoryOverhead": ConfigType.MEMORY, "spark.extraListeners": ConfigType.STRING, "spark.local.dir": ConfigType.STRING, "spark.logConf": ConfigType.BOOL, # TODO: Validate against https://spark.apache.org/docs/latest/submitting-applications.html#master-urls "spark.master": ConfigType.STRING, # TODO: Validate against client/cluster *only*. 
"spark.submit.deployMode": ConfigType.STRING, "spark.log.callerContext": ConfigType.STRING, "spark.driver.supervise": ConfigType.BOOL, # # RUNTIME ENVIRONMENT "spark.driver.extraClassPath": ConfigType.STRING, "spark.driver.extraJavaOptions": ConfigType.STRING, "spark.driver.extraLibraryPath": ConfigType.STRING, "spark.driver.userClassPathFirst": ConfigType.BOOL, "spark.executor.extraClassPath": ConfigType.STRING, "spark.executor.extraJavaOptions": ConfigType.STRING, "spark.executor.extraLibraryPath": ConfigType.STRING, "spark.executor.logs.rolling.maxRetainedFiles": ConfigType.INT, "spark.executor.logs.rolling.enableCompression": ConfigType.BOOL, "spark.executor.logs.rolling.maxSize": ConfigType.INT, # TODO: Can only be 'time' or 'size' "spark.executor.logs.rolling.strategy": ConfigType.STRING, "spark.executor.logs.rolling.time.interval": ConfigType.STRING, "spark.executor.userClassPathFirst": ConfigType.BOOL, "spark.redaction.regex": ConfigType.STRING, "spark.python.profile": ConfigType.BOOL, # TODO: Should be a path? "spark.python.profile.dump": ConfigType.STRING, "spark.python.worker.memory": ConfigType.MEMORY, "spark.python.worker.reuse": ConfigType.BOOL, "spark.files": ConfigType.STRING, "spark.submit.pyFiles": ConfigType.STRING, "spark.jars": ConfigType.STRING, "spark.jars.packages": ConfigType.STRING, "spark.jars.excludes": ConfigType.STRING, "spark.jars.ivy": ConfigType.STRING, "spark.jars.ivySettings": ConfigType.STRING, "spark.jars.repositories": ConfigType.STRING, "spark.pyspark.driver.python": ConfigType.STRING, "spark.pyspark.python": ConfigType.STRING, # # SHUFFLE BEHAVIOR "spark.reducer.maxSizeInFlight": ConfigType.MEMORY, "spark.reducer.maxReqsInFlight": ConfigType.INT, "spark.reducer.maxBlocksInFlightPerAddress": ConfigType.INT, "spark.maxRemoteBlockSizeFetchToMem": ConfigType.INT, "spark.shuffle.compress": ConfigType.BOOL, "spark.shuffle.file.buffer": ConfigType.MEMORY, "spark.shuffle.io.maxRetries": ConfigType.INT, "spark.shuffle.io.numConnectionsPerPeer": ConfigType.INT, "spark.shuffle.io.preferDirectBufs": ConfigType.BOOL, "spark.shuffle.io.retryWait": ConfigType.TIME, "spark.shuffle.service.enabled": ConfigType.BOOL, "spark.shuffle.service.port": ConfigType.INT, "spark.shuffle.service.index.cache.size": ConfigType.MEMORY, "spark.shuffle.maxChunksBeingTransferred": ConfigType.INT, "spark.shuffle.sort.bypassMergeThreshold": ConfigType.INT, "spark.shuffle.spill.compress": ConfigType.BOOL, "spark.shuffle.accurateBlockThreshold": ConfigType.INT, "spark.shuffle.registration.timeout": ConfigType.INT, "spark.shuffle.registration.maxAttempts": ConfigType.INT, # # SPARK UI ### TODO # # COMPRESSION AND SERIALIZATION ### TODO # # MEMORY MANAGEMENT "spark.memory.fraction": ConfigType.FLOAT, "spark.memory.storageFraction": ConfigType.FLOAT, "spark.memory.offHeap.enabled": ConfigType.BOOL, "spark.memory.offHeap.size": ConfigType.INT, "spark.memory.useLegacyMode": ConfigType.BOOL, "spark.shuffle.memoryFraction": ConfigType.FLOAT, "spark.storage.memoryFraction": ConfigType.FLOAT, "spark.storage.unrollFraction": ConfigType.FLOAT, "spark.storage.replication.proactive": ConfigType.BOOL, "spark.cleaner.periodicGC.interval": ConfigType.TIME, "spark.cleaner.referenceTracking": ConfigType.BOOL, "spark.cleaner.referenceTracking.blocking": ConfigType.BOOL, "spark.cleaner.referenceTracking.blocking.shuffle": ConfigType.BOOL, "spark.cleaner.referenceTracking.cleanCheckpoints": ConfigType.BOOL, # # EXECUTION BEHAVIOR "spark.broadcast.blockSize": ConfigType.MEMORY, "spark.executor.cores": 
ConfigType.INT, "spark.default.parallelism": ConfigType.INT, "spark.executor.heartbeatInterval": ConfigType.TIME, "spark.files.fetchTimeout": ConfigType.TIME, "spark.files.useFetchCache": ConfigType.BOOL, "spark.files.overwrite": ConfigType.BOOL, "spark.files.maxPartitionBytes": ConfigType.INT, "spark.files.openCostInBytes": ConfigType.INT, "spark.hadoop.cloneConf": ConfigType.BOOL, "spark.hadoop.validateOutputSpecs": ConfigType.BOOL, "spark.storage.memoryMapThreshold": ConfigType.MEMORY, # TODO: Can only be 1 or 2. "spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version": ConfigType.INT, # # NETWORKING ### TODO # # SCHEDULING ### TODO # # DYNAMIC ALLOCATION ### TODO } class SparkConfig(namedtuple("_SparkConfig", "path default meaning")): def __new__(cls, path, default, meaning): # The original documentation strings include extraneous newlines, spaces return super(SparkConfig, cls).__new__( cls, path, re.sub(WHITESPACE_REGEX, " ", str(default)).strip(), re.sub(WHITESPACE_REGEX, " ", meaning).strip(), ) @property def split_path(self): return self.path.split(".") def write(self, printer): config_type = CONFIG_TYPES.get(self.path, ConfigType.STRING).value printer.append("Field(") with printer.with_indent(): printer.line("") printer.line("{config_type},".format(config_type=config_type)) printer.append('description="""') printer.append(self.meaning) printer.line('""",') # printer.line("default_value='{}',".format(self.default)) printer.line("is_required=False,") printer.append(")") class SparkConfigNode: def __init__(self, value=None): self.value = value self.children = {} def write(self, printer): if not self.children: self.value.write(printer) else: if self.value: retdict = {"root": self.value} retdict.update(self.children) else: retdict = self.children printer.append("Field(") printer.line("") with printer.with_indent(): printer.line("Permissive(") with printer.with_indent(): printer.line("fields={") with printer.with_indent(): for (k, v) in retdict.items(): with printer.with_indent(): printer.append('"{}": '.format(k)) v.write(printer) printer.line(",") printer.line("}") printer.line(")") printer.line(")") return printer.read() def extract(spark_docs_markdown_text): import pytablereader as ptr tables = re.findall(TABLE_REGEX, spark_docs_markdown_text, re.DOTALL | re.MULTILINE) spark_configs = [] for name, table in tables: parsed_table = list(ptr.HtmlTableTextLoader(table).load())[0] df = parsed_table.as_dataframe() for _, row in df.iterrows(): s = SparkConfig(row["Property Name"], row["Default"], name + ": " + row["Meaning"]) spark_configs.append(s) result = SparkConfigNode() for spark_config in spark_configs: # TODO: we should handle this thing if spark_config.path == "spark.executorEnv.[EnvironmentVariableName]": continue # Traverse spark.app.name key paths, creating SparkConfigNode at each tree node. # The leaves of the tree (stored in SparkConfigNode.value) are SparkConfig values. 
print(spark_config.path, file=sys.stderr) # pylint: disable=print-call key_path = spark_config.split_path d = result while key_path: key = key_path.pop(0) if key not in d.children: d.children[key] = SparkConfigNode() d = d.children[key] d.value = spark_config return result def serialize(result): with IndentingBufferPrinter() as printer: printer.write_header() printer.line("from dagster import Bool, Field, Float, Int, Permissive, String") printer.blank_line() printer.blank_line() printer.line("# pylint: disable=line-too-long") printer.line("def spark_config():") with printer.with_indent(): printer.append("return ") result.write(printer) printer.line("# pylint: enable=line-too-long") - return printer.read().strip().encode() + return printer.read().strip().encode("utf-8") @click.command() def run(): r = requests.get( "https://raw.githubusercontent.com/apache/spark/{}/docs/configuration.md".format( SPARK_VERSION ) ) result = extract(r.text) serialized = serialize(result) output_files = [ "python_modules/libraries/dagster-spark/dagster_spark/configs_spark.py", "python_modules/libraries/dagster-aws/dagster_aws/emr/configs_spark.py", ] for output_file in output_files: with open(output_file, "wb") as f: f.write(serialized) if __name__ == "__main__": run() # pylint:disable=E1120 diff --git a/python_modules/automation/automation/printer.py b/python_modules/automation/automation/printer.py index 9d6ca08a7..c954985d7 100644 --- a/python_modules/automation/automation/printer.py +++ b/python_modules/automation/automation/printer.py @@ -1,39 +1,39 @@ import os import sys +from io import StringIO from dagster.utils.indenting_printer import IndentingPrinter -from six import StringIO class IndentingBufferPrinter(IndentingPrinter): """Subclass of IndentingPrinter wrapping a StringIO.""" def __init__(self, indent_level=4, current_indent=0): self.buffer = StringIO() self.printer = lambda x: self.buffer.write(x + "\n") super(IndentingBufferPrinter, self).__init__( indent_level=indent_level, printer=self.printer, current_indent=current_indent ) def __enter__(self): return self def __exit__(self, _exception_type, _exception_value, _traceback): self.buffer.close() def read(self): """Get the value of the backing StringIO.""" return self.buffer.getvalue() def write_header(self): args = [os.path.basename(sys.argv[0])] + sys.argv[1:] self.line("'''NOTE: THIS FILE IS AUTO-GENERATED. 
DO NOT EDIT") self.blank_line() self.line("@generated") self.blank_line() self.line("Produced via:") self.line("\n\t".join("%s \\" % s for s in args if s != "--snapshot-update")) self.blank_line() self.line("'''") self.blank_line() self.blank_line() diff --git a/python_modules/dagit/dagit/cli.py b/python_modules/dagit/dagit/cli.py index 64749737f..6fece4774 100644 --- a/python_modules/dagit/dagit/cli.py +++ b/python_modules/dagit/dagit/cli.py @@ -1,218 +1,212 @@ import os import sys import tempfile import threading from contextlib import contextmanager import click -import six from dagster import check from dagster.cli.workspace import Workspace, get_workspace_from_kwargs, workspace_target_argument from dagster.cli.workspace.cli_target import WORKSPACE_TARGET_WARNING from dagster.core.instance import DagsterInstance from dagster.core.telemetry import ( START_DAGIT_WEBSERVER, log_action, log_workspace_stats, upload_logs, ) from dagster.utils import DEFAULT_WORKSPACE_YAML_FILENAME from gevent import pywsgi from geventwebsocket.handler import WebSocketHandler from .app import create_app_from_workspace from .version import __version__ def create_dagit_cli(): return ui # pylint: disable=no-value-for-parameter DEFAULT_DAGIT_HOST = "127.0.0.1" DEFAULT_DAGIT_PORT = 3000 DEFAULT_DB_STATEMENT_TIMEOUT = 5000 # 5 sec @click.command( name="ui", help=( "Run dagit. Loads a repository or pipeline.\n\n{warning}".format( warning=WORKSPACE_TARGET_WARNING ) + ( "\n\nExamples:" "\n\n1. dagit (works if .{default_filename} exists)" "\n\n2. dagit -w path/to/{default_filename}" "\n\n3. dagit -f path/to/file.py" "\n\n4. dagit -f path/to/file.py -d path/to/working_directory" "\n\n5. dagit -m some_module" "\n\n6. dagit -f path/to/file.py -a define_repo" "\n\n7. dagit -m some_module -a define_repo" "\n\n8. dagit -p 3333" "\n\nOptions can also provide arguments via environment variables prefixed with DAGIT" "\n\nFor example, DAGIT_PORT=3333 dagit" ).format(default_filename=DEFAULT_WORKSPACE_YAML_FILENAME) ), ) @workspace_target_argument @click.option( "--host", "-h", type=click.STRING, default=DEFAULT_DAGIT_HOST, help="Host to run server on", show_default=True, ) @click.option( "--port", "-p", type=click.INT, help="Port to run server on, default is {default_port}".format(default_port=DEFAULT_DAGIT_PORT), ) @click.option( "--path-prefix", "-l", type=click.STRING, default="", help="The path prefix where Dagit will be hosted (eg: /dagit)", show_default=True, ) @click.option( "--storage-fallback", help="Base directory for dagster storage if $DAGSTER_HOME is not set", default=None, type=click.Path(), ) @click.option( "--db-statement-timeout", help="The timeout in milliseconds to set on database statements sent " "to the DagsterInstance. 
Not respected in all configurations.", default=DEFAULT_DB_STATEMENT_TIMEOUT, type=click.INT, show_default=True, ) @click.version_option(version=__version__, prog_name="dagit") def ui(host, port, path_prefix, storage_fallback, db_statement_timeout, **kwargs): # add the path for the cwd so imports in dynamically loaded code work correctly sys.path.append(os.getcwd()) if port is None: port_lookup = True port = DEFAULT_DAGIT_PORT else: port_lookup = False if storage_fallback is None: with tempfile.TemporaryDirectory() as storage_fallback: host_dagit_ui( host, port, path_prefix, storage_fallback, db_statement_timeout, port_lookup, **kwargs, ) else: host_dagit_ui( host, port, path_prefix, storage_fallback, db_statement_timeout, port_lookup, **kwargs ) def host_dagit_ui( host, port, path_prefix, storage_fallback, db_statement_timeout, port_lookup=True, **kwargs ): with DagsterInstance.get(storage_fallback) as instance: # Allow the instance components to change behavior in the context of a long running server process instance.optimize_for_dagit(db_statement_timeout) with get_workspace_from_kwargs(kwargs) as workspace: if not workspace: raise Exception("Unable to load workspace with cli_args: {}".format(kwargs)) host_dagit_ui_with_workspace(instance, workspace, host, port, path_prefix, port_lookup) def host_dagit_ui_with_workspace(instance, workspace, host, port, path_prefix, port_lookup=True): check.inst_param(instance, "instance", DagsterInstance) check.inst_param(workspace, "workspace", Workspace) log_workspace_stats(instance, workspace) app = create_app_from_workspace(workspace, instance, path_prefix) start_server(instance, host, port, path_prefix, app, port_lookup) @contextmanager def uploading_logging_thread(): stop_event = threading.Event() logging_thread = threading.Thread( target=upload_logs, args=([stop_event]), name="telemetry-upload" ) try: logging_thread.start() yield finally: stop_event.set() logging_thread.join() def start_server(instance, host, port, path_prefix, app, port_lookup, port_lookup_attempts=0): server = pywsgi.WSGIServer((host, port), app, handler_class=WebSocketHandler) print( # pylint: disable=print-call "Serving on http://{host}:{port}{path_prefix} in process {pid}".format( host=host, port=port, path_prefix=path_prefix, pid=os.getpid() ) ) log_action(instance, START_DAGIT_WEBSERVER) with uploading_logging_thread(): try: server.serve_forever() except OSError as os_error: if "Address already in use" in str(os_error): if port_lookup and ( port_lookup_attempts > 0 or click.confirm( ( "Another process on your machine is already listening on port {port}. " "Would you like to run the app at another port instead?" ).format(port=port) ) ): port_lookup_attempts += 1 start_server( instance, host, port + port_lookup_attempts, path_prefix, app, True, port_lookup_attempts, ) else: - six.raise_from( - Exception( - ( - "Another process on your machine is already listening on port {port}. " - "It is possible that you have another instance of dagit " - "running somewhere using the same port. Or it could be another " - "random process. Either kill that process or use the -p option to " - "select another port." - ).format(port=port) - ), - os_error, - ) + raise Exception( + f"Another process on your machine is already listening on port {port}. " + "It is possible that you have another instance of dagit " + "running somewhere using the same port. Or it could be another " + "random process. Either kill that process or use the -p option to " + "select another port." 
+ ) from os_error else: raise os_error cli = create_dagit_cli() def main(): # click magic cli(auto_envvar_prefix="DAGIT") # pylint:disable=E1120 diff --git a/python_modules/dagit/dagit/debug.py b/python_modules/dagit/dagit/debug.py index 35b2ba2f4..e02b47368 100644 --- a/python_modules/dagit/dagit/debug.py +++ b/python_modules/dagit/dagit/debug.py @@ -1,53 +1,53 @@ from gzip import GzipFile import click from dagster import DagsterInstance, check from dagster.cli.debug import DebugRunPayload from dagster.cli.workspace import Workspace from dagster.serdes import deserialize_json_to_dagster_namedtuple from .cli import DEFAULT_DAGIT_HOST, DEFAULT_DAGIT_PORT, host_dagit_ui_with_workspace @click.command( name="debug", help="Load dagit with an ephemeral instance loaded from a dagster debug export file.", ) @click.argument("input_files", nargs=-1, type=click.Path(exists=True)) @click.option( "--port", "-p", type=click.INT, help="Port to run server on, default is {default_port}".format(default_port=DEFAULT_DAGIT_PORT), default=DEFAULT_DAGIT_PORT, ) def dagit_debug_command(input_files, port): debug_payloads = [] for input_file in input_files: click.echo("Loading {} ...".format(input_file)) with GzipFile(input_file, "rb") as file: - blob = file.read().decode() + blob = file.read().decode("utf-8") debug_payload = deserialize_json_to_dagster_namedtuple(blob) check.invariant(isinstance(debug_payload, DebugRunPayload)) click.echo( "\trun_id: {} \n\tdagster version: {}".format( debug_payload.pipeline_run.run_id, debug_payload.version ) ) debug_payloads.append(debug_payload) instance = DagsterInstance.ephemeral(preload=debug_payloads) host_dagit_ui_with_workspace( workspace=Workspace([]), instance=instance, port=port, port_lookup=True, host=DEFAULT_DAGIT_HOST, path_prefix="", ) def main(): dagit_debug_command() # pylint: disable=no-value-for-parameter diff --git a/python_modules/dagit/dagit/format_error.py b/python_modules/dagit/dagit/format_error.py index 9f6c9677f..4dcdddd3f 100644 --- a/python_modules/dagit/dagit/format_error.py +++ b/python_modules/dagit/dagit/format_error.py @@ -1,32 +1,29 @@ from typing import Any, Dict from dagster.utils.log import get_stack_trace_array from graphql.error.base import GraphQLError -from six import text_type # based on default_format_error copied and pasted from graphql_server 1.1.1 -def format_error_with_stack_trace(error): +def format_error_with_stack_trace(error: Exception) -> Dict[str, Any]: + formatted_error = {"message": str(error)} # type: Dict[str, Any] - # type: (Exception) -> Dict[str, Any] - - formatted_error = {"message": text_type(error)} # type: Dict[str, Any] if isinstance(error, GraphQLError): if error.locations is not None: formatted_error["locations"] = [ {"line": loc.line, "column": loc.column} for loc in error.locations ] if error.path is not None: formatted_error["path"] = error.path # this is what is different about this implementation # we print out stack traces to ease debugging if hasattr(error, "original_error") and error.original_error: formatted_error["stack_trace"] = get_stack_trace_array(error.original_error) else: formatted_error["stack_trace"] = get_stack_trace_array(error) if hasattr(error, "__cause__") and error.__cause__: formatted_error["cause"] = format_error_with_stack_trace(error.__cause__) return formatted_error diff --git a/python_modules/dagster-test/dagster_test/test_project/__init__.py b/python_modules/dagster-test/dagster_test/test_project/__init__.py index 341af0ae2..4ba9263c8 100644 --- 
a/python_modules/dagster-test/dagster_test/test_project/__init__.py +++ b/python_modules/dagster-test/dagster_test/test_project/__init__.py @@ -1,268 +1,270 @@ import base64 import os import subprocess import sys from dagster import check from dagster.core.code_pointer import FileCodePointer from dagster.core.definitions.reconstructable import ( ReconstructablePipeline, ReconstructableRepository, ) from dagster.core.host_representation import ( ExternalPipeline, ExternalSchedule, InProcessRepositoryLocationOrigin, RepositoryLocation, RepositoryLocationHandle, ) from dagster.core.host_representation.origin import ( ExternalJobOrigin, ExternalPipelineOrigin, ExternalRepositoryOrigin, ) from dagster.core.origin import PipelinePythonOrigin, RepositoryPythonOrigin from dagster.serdes import whitelist_for_serdes from dagster.utils import file_relative_path, git_repository_root IS_BUILDKITE = os.getenv("BUILDKITE") is not None def get_test_repo_path(): return os.path.join( git_repository_root(), "python_modules", "dagster-test", "dagster_test", "test_project" ) def get_test_project_environments_path(): return os.path.join(get_test_repo_path(), "environments") def get_buildkite_registry_config(): import boto3 ecr_client = boto3.client("ecr", region_name="us-west-2") token = ecr_client.get_authorization_token() username, password = ( - base64.b64decode(token["authorizationData"][0]["authorizationToken"]).decode().split(":") + base64.b64decode(token["authorizationData"][0]["authorizationToken"]) + .decode("utf-8") + .split(":") ) registry = token["authorizationData"][0]["proxyEndpoint"] return { "url": registry, "username": username, "password": password, } def find_local_test_image(docker_image): import docker try: client = docker.from_env() client.images.get(docker_image) print( # pylint: disable=print-call "Found existing image tagged {image}, skipping image build. To rebuild, first run: " "docker rmi {image}".format(image=docker_image) ) except docker.errors.ImageNotFound: build_and_tag_test_image(docker_image) def build_and_tag_test_image(tag): check.str_param(tag, "tag") base_python = "3.7.8" # Build and tag local dagster test image return subprocess.check_output(["./build.sh", base_python, tag], cwd=get_test_repo_path()) def get_test_project_recon_pipeline(pipeline_name, container_image=None): return ReOriginatedReconstructablePipelineForTest( ReconstructableRepository.for_file( file_relative_path(__file__, "test_pipelines/repo.py"), "define_demo_execution_repo", container_image=container_image, ).get_reconstructable_pipeline(pipeline_name) ) class ReOriginatedReconstructablePipelineForTest(ReconstructablePipeline): def __new__( # pylint: disable=signature-differs cls, reconstructable_pipeline, ): return super(ReOriginatedReconstructablePipelineForTest, cls).__new__( cls, reconstructable_pipeline.repository, reconstructable_pipeline.pipeline_name, reconstructable_pipeline.solid_selection_str, reconstructable_pipeline.solids_to_execute, ) def get_python_origin(self): """ Hack! Inject origin that the docker-celery images will use. The BK image uses a different directory structure (/workdir/python_modules/dagster-test/dagster_test/test_project) than the test that creates the ReconstructablePipeline. As a result the normal origin won't work, we need to inject this one. 
""" return PipelinePythonOrigin( self.pipeline_name, RepositoryPythonOrigin( executable_path="python", code_pointer=FileCodePointer( "/dagster_test/test_project/test_pipelines/repo.py", "define_demo_execution_repo", ), container_image=self.repository.container_image, ), ) class ReOriginatedExternalPipelineForTest(ExternalPipeline): def __init__( self, external_pipeline, container_image=None, ): self._container_image = container_image super(ReOriginatedExternalPipelineForTest, self).__init__( external_pipeline.external_pipeline_data, external_pipeline.repository_handle, ) def get_python_origin(self): """ Hack! Inject origin that the k8s images will use. The BK image uses a different directory structure (/workdir/python_modules/dagster-test/dagster_test/test_project) than the images inside the kind cluster (/dagster_test/test_project). As a result the normal origin won't work, we need to inject this one. """ return PipelinePythonOrigin( self._pipeline_index.name, RepositoryPythonOrigin( executable_path="python", code_pointer=FileCodePointer( "/dagster_test/test_project/test_pipelines/repo.py", "define_demo_execution_repo", ), container_image=self._container_image, ), ) def get_external_origin(self): """ Hack! Inject origin that the k8s images will use. The BK image uses a different directory structure (/workdir/python_modules/dagster-test/dagster_test/test_project) than the images inside the kind cluster (/dagster_test/test_project). As a result the normal origin won't work, we need to inject this one. """ return ExternalPipelineOrigin( external_repository_origin=ExternalRepositoryOrigin( repository_location_origin=InProcessRepositoryLocationOrigin( recon_repo=ReconstructableRepository( pointer=FileCodePointer( python_file="/dagster_test/test_project/test_pipelines/repo.py", fn_name="define_demo_execution_repo", ) ) ), repository_name="demo_execution_repo", ), pipeline_name=self._pipeline_index.name, ) class ReOriginatedExternalScheduleForTest(ExternalSchedule): def __init__( self, external_schedule, container_image=None, ): self._container_image = container_image super(ReOriginatedExternalScheduleForTest, self).__init__( external_schedule._external_schedule_data, external_schedule.handle.repository_handle, ) def get_external_origin(self): """ Hack! Inject origin that the k8s images will use. The BK image uses a different directory structure (/workdir/python_modules/dagster-test/dagster_test/test_project) than the images inside the kind cluster (/dagster_test/test_project). As a result the normal origin won't work, we need to inject this one. 
""" return ExternalJobOrigin( external_repository_origin=ExternalRepositoryOrigin( repository_location_origin=InProcessRepositoryLocationOrigin( recon_repo=ReconstructableRepository( pointer=FileCodePointer( python_file="/dagster_test/test_project/test_pipelines/repo.py", fn_name="define_demo_execution_repo", ) ) ), repository_name="demo_execution_repo", ), job_name=self.name, ) def get_test_project_external_repo(container_image=None): return RepositoryLocation.from_handle( RepositoryLocationHandle.create_from_repository_location_origin( InProcessRepositoryLocationOrigin( ReconstructableRepository.for_file( file_relative_path(__file__, "test_pipelines/repo.py"), "define_demo_execution_repo", container_image=container_image, ) ) ) ).get_repository("demo_execution_repo") def get_test_project_external_pipeline(pipeline_name, container_image=None): return get_test_project_external_repo( container_image=container_image ).get_full_external_pipeline(pipeline_name) def get_test_project_external_schedule(schedule_name, container_image=None): return get_test_project_external_repo(container_image=container_image).get_external_schedule( schedule_name ) def get_test_project_docker_image(): docker_repository = os.getenv("DAGSTER_DOCKER_REPOSITORY") image_name = os.getenv("DAGSTER_DOCKER_IMAGE", "buildkite-test-image") docker_image_tag = os.getenv("DAGSTER_DOCKER_IMAGE_TAG") if IS_BUILDKITE: assert docker_image_tag is not None, ( "This test requires the environment variable DAGSTER_DOCKER_IMAGE_TAG to be set " "to proceed" ) assert docker_repository is not None, ( "This test requires the environment variable DAGSTER_DOCKER_REPOSITORY to be set " "to proceed" ) # This needs to be a domain name to avoid the k8s machinery automatically prefixing it with # `docker.io/` and attempting to pull images from Docker Hub if not docker_repository: docker_repository = "dagster.io.priv" if not docker_image_tag: # Detect the python version we're running on majmin = str(sys.version_info.major) + str(sys.version_info.minor) docker_image_tag = "py{majmin}-{image_version}".format( majmin=majmin, image_version="latest" ) final_docker_image = "{repository}/{image_name}:{tag}".format( repository=docker_repository, image_name=image_name, tag=docker_image_tag ) print("Using Docker image: %s" % final_docker_image) # pylint: disable=print-call return final_docker_image diff --git a/python_modules/dagster-test/dagster_test/toys/error_monster.py b/python_modules/dagster-test/dagster_test/toys/error_monster.py index 5654d6e60..5499b8c4b 100644 --- a/python_modules/dagster-test/dagster_test/toys/error_monster.py +++ b/python_modules/dagster-test/dagster_test/toys/error_monster.py @@ -1,156 +1,152 @@ -import six from dagster import ( EventMetadataEntry, Failure, Field, InputDefinition, Int, ModeDefinition, OutputDefinition, PresetDefinition, ResourceDefinition, RetryRequested, String, execute_pipeline, pipeline, solid, ) from dagster.utils import segfault class ErrorableResource: pass def resource_init(init_context): if init_context.resource_config["throw_on_resource_init"]: raise Exception("throwing from in resource_fn") return ErrorableResource() def define_errorable_resource(): return ResourceDefinition( resource_fn=resource_init, config_schema={ "throw_on_resource_init": Field(bool, is_required=False, default_value=False) }, ) solid_throw_config = { "throw_in_solid": Field(bool, is_required=False, default_value=False), "crash_in_solid": Field(bool, is_required=False, default_value=False), "return_wrong_type": Field(bool, 
is_required=False, default_value=False), "request_retry": Field(bool, is_required=False, default_value=False), } class ExampleException(Exception): pass def _act_on_config(solid_config): if solid_config["crash_in_solid"]: segfault() if solid_config["throw_in_solid"]: try: raise ExampleException("sample cause exception") except ExampleException as e: - six.raise_from( - Failure( - description="I'm a Failure", - metadata_entries=[ - EventMetadataEntry.text( - label="metadata_label", - text="I am metadata text", - description="metadata_description", - ) - ], - ), - e, - ) + raise Failure( + description="I'm a Failure", + metadata_entries=[ + EventMetadataEntry.text( + label="metadata_label", + text="I am metadata text", + description="metadata_description", + ) + ], + ) from e elif solid_config["request_retry"]: raise RetryRequested() @solid( output_defs=[OutputDefinition(Int)], config_schema=solid_throw_config, required_resource_keys={"errorable_resource"}, ) def emit_num(context): _act_on_config(context.solid_config) if context.solid_config["return_wrong_type"]: return "wow" return 13 @solid( input_defs=[InputDefinition("num", Int)], output_defs=[OutputDefinition(String)], config_schema=solid_throw_config, required_resource_keys={"errorable_resource"}, ) def num_to_str(context, num): _act_on_config(context.solid_config) if context.solid_config["return_wrong_type"]: return num + num return str(num) @solid( input_defs=[InputDefinition("string", String)], output_defs=[OutputDefinition(Int)], config_schema=solid_throw_config, required_resource_keys={"errorable_resource"}, ) def str_to_num(context, string): _act_on_config(context.solid_config) if context.solid_config["return_wrong_type"]: return string + string return int(string) @pipeline( description=( "Demo pipeline that enables configurable types of errors thrown during pipeline execution, " "including solid execution errors, type errors, and resource initialization errors." 
), mode_defs=[ ModeDefinition( name="errorable_mode", resource_defs={"errorable_resource": define_errorable_resource()} ) ], preset_defs=[ PresetDefinition.from_pkg_resources( "passing", pkg_resource_defs=[("dagster_test.toys.environments", "error.yaml")], mode="errorable_mode", ) ], ) def error_monster(): start = emit_num.alias("start")() middle = num_to_str.alias("middle")(num=start) str_to_num.alias("end")(string=middle) if __name__ == "__main__": result = execute_pipeline( error_monster, { "solids": { "start": {"config": {"throw_in_solid": False, "return_wrong_type": False}}, "middle": {"config": {"throw_in_solid": False, "return_wrong_type": True}}, "end": {"config": {"throw_in_solid": False, "return_wrong_type": False}}, }, "resources": {"errorable_resource": {"config": {"throw_on_resource_init": False}}}, }, ) print("Pipeline Success: ", result.success) diff --git a/python_modules/dagster/dagster/config/field_utils.py b/python_modules/dagster/dagster/config/field_utils.py index 36bbc167d..f9b8425ed 100644 --- a/python_modules/dagster/dagster/config/field_utils.py +++ b/python_modules/dagster/dagster/config/field_utils.py @@ -1,306 +1,306 @@ # encoding: utf-8 import hashlib from typing import Any, Dict from dagster import check from dagster.core.errors import DagsterInvalidConfigDefinitionError from .config_type import ConfigType, ConfigTypeKind def all_optional_type(config_type): check.inst_param(config_type, "config_type", ConfigType) if ConfigTypeKind.is_shape(config_type.kind): for field in config_type.fields.values(): if field.is_required: return False return True if ConfigTypeKind.is_selector(config_type.kind): if len(config_type.fields) == 1: for field in config_type.fields.values(): if field.is_required: return False return True return False class __FieldValueSentinel: pass class __InferOptionalCompositeFieldSentinel: pass FIELD_NO_DEFAULT_PROVIDED = __FieldValueSentinel INFER_OPTIONAL_COMPOSITE_FIELD = __InferOptionalCompositeFieldSentinel class _ConfigHasFields(ConfigType): def __init__(self, fields, **kwargs): self.fields = expand_fields_dict(fields) super(_ConfigHasFields, self).__init__(**kwargs) FIELD_HASH_CACHE: Dict[str, Any] = {} def _memoize_inst_in_field_cache(passed_cls, defined_cls, key): if key in FIELD_HASH_CACHE: return FIELD_HASH_CACHE[key] defined_cls_inst = super(defined_cls, passed_cls).__new__(defined_cls) FIELD_HASH_CACHE[key] = defined_cls_inst return defined_cls_inst def _add_hash(m, string): - m.update(string.encode()) + m.update(string.encode("utf-8")) def _compute_fields_hash(fields, description): m = hashlib.sha1() # so that hexdigest is 40, not 64 bytes if description: _add_hash(m, ":description: " + description) for field_name in sorted(list(fields.keys())): field = fields[field_name] _add_hash(m, ":fieldname:" + field_name) if field.default_provided: _add_hash(m, ":default_value: " + field.default_value_as_json_str) _add_hash(m, ":is_required: " + str(field.is_required)) _add_hash(m, ":type_key: " + field.config_type.key) if field.description: _add_hash(m, ":description: " + field.description) return m.hexdigest() def _define_shape_key_hash(fields, description): return "Shape." + _compute_fields_hash(fields, description) class Shape(_ConfigHasFields): """Schema for configuration data with string keys and typed values via :py:class:`Field`. Unlike :py:class:`Permissive`, unspecified fields are not allowed and will throw a :py:class:`~dagster.DagsterInvalidConfigError`. Args: fields (Dict[str, Field]): The specification of the config dict. 
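    **Examples:**

    A minimal sketch of strict-shape behavior; the solid and field names below are
    hypothetical, chosen only for illustration:

    .. code-block:: python

        # assumes: from dagster import solid, Field, Shape, Int
        @solid(
            config_schema=Field(
                Shape({'num_rows': Field(Int, is_required=False, default_value=10)})
            )
        )
        def read_rows(context):
            # 'num_rows' is type checked; any unspecified key raises
            # DagsterInvalidConfigError at config validation time
            return context.solid_config['num_rows']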
""" def __new__( cls, fields, description=None, ): return _memoize_inst_in_field_cache( cls, Shape, _define_shape_key_hash(expand_fields_dict(fields), description), ) def __init__(self, fields, description=None): fields = expand_fields_dict(fields) super(Shape, self).__init__( kind=ConfigTypeKind.STRICT_SHAPE, key=_define_shape_key_hash(fields, description), description=description, fields=fields, ) def _define_permissive_dict_key(fields, description): return ( "Permissive." + _compute_fields_hash(fields, description=description) if fields else "Permissive" ) class Permissive(_ConfigHasFields): """Defines a config dict with a partially specified schema. A permissive dict allows partial specification of the config schema. Any fields with a specified schema will be type checked. Other fields will be allowed, but will be ignored by the type checker. Args: fields (Dict[str, Field]): The partial specification of the config dict. **Examples:** .. code-block:: python @solid(config_schema=Field(Permissive({'required': Field(String)}))) def partially_specified_config(context) -> List: return sorted(list(context.solid_config.items())) """ def __new__(cls, fields=None, description=None): return _memoize_inst_in_field_cache( cls, Permissive, _define_permissive_dict_key( expand_fields_dict(fields) if fields else None, description ), ) def __init__(self, fields=None, description=None): fields = expand_fields_dict(fields) if fields else None super(Permissive, self).__init__( key=_define_permissive_dict_key(fields, description), kind=ConfigTypeKind.PERMISSIVE_SHAPE, fields=fields or dict(), description=description, ) def _define_selector_key(fields, description): return "Selector." + _compute_fields_hash(fields, description=description) class Selector(_ConfigHasFields): """Define a config field requiring the user to select one option. Selectors are used when you want to be able to present several different options in config but allow only one to be selected. For example, a single input might be read in from either a csv file or a parquet file, but not both at once. Note that in some other type systems this might be called an 'input union'. Functionally, a selector is like a :py:class:`Dict`, except that only one key from the dict can be specified in valid config. Args: fields (Dict[str, Field]): The fields from which the user must select. **Examples:** .. 
code-block:: python @solid( config_schema=Field( Selector( { 'haw': {'whom': Field(String, default_value='honua', is_required=False)}, 'cn': {'whom': Field(String, default_value='世界', is_required=False)}, 'en': {'whom': Field(String, default_value='world', is_required=False)}, } ), is_required=False, default_value={'en': {'whom': 'world'}}, ) ) def hello_world_with_default(context): if 'haw' in context.solid_config: return 'Aloha {whom}!'.format(whom=context.solid_config['haw']['whom']) if 'cn' in context.solid_config: return '你好,{whom}!'.format(whom=context.solid_config['cn']['whom']) if 'en' in context.solid_config: return 'Hello, {whom}!'.format(whom=context.solid_config['en']['whom']) """ def __new__(cls, fields, description=None): return _memoize_inst_in_field_cache( cls, Selector, _define_selector_key(expand_fields_dict(fields), description), ) def __init__(self, fields, description=None): fields = expand_fields_dict(fields) super(Selector, self).__init__( key=_define_selector_key(fields, description), kind=ConfigTypeKind.SELECTOR, fields=fields, description=description, ) # Config syntax expansion code below def is_potential_field(potential_field): from .field import Field, resolve_to_config_type return isinstance(potential_field, (Field, dict, list)) or resolve_to_config_type( potential_field ) def convert_fields_to_dict_type(fields): return _convert_fields_to_dict_type(fields, fields, []) def _convert_fields_to_dict_type(original_root, fields, stack): return Shape(_expand_fields_dict(original_root, fields, stack)) def expand_fields_dict(fields): return _expand_fields_dict(fields, fields, []) def _expand_fields_dict(original_root, fields, stack): check.dict_param(fields, "fields") return { name: _convert_potential_field(original_root, value, stack + [name]) for name, value in fields.items() } def expand_list(original_root, the_list, stack): from .config_type import Array if len(the_list) != 1: raise DagsterInvalidConfigDefinitionError( original_root, the_list, stack, "List must be of length 1" ) inner_type = _convert_potential_type(original_root, the_list[0], stack) if not inner_type: raise DagsterInvalidConfigDefinitionError( original_root, the_list, stack, "List have a single item and contain a valid type i.e. [int]. 
Got item {}".format( repr(the_list[0]) ), ) return Array(inner_type) def convert_potential_field(potential_field): return _convert_potential_field(potential_field, potential_field, []) def _convert_potential_type(original_root, potential_type, stack): from .field import resolve_to_config_type if isinstance(potential_type, dict): return Shape(_expand_fields_dict(original_root, potential_type, stack)) if isinstance(potential_type, list): return expand_list(original_root, potential_type, stack) return resolve_to_config_type(potential_type) def _convert_potential_field(original_root, potential_field, stack): from .field import Field if potential_field is None: raise DagsterInvalidConfigDefinitionError( original_root, potential_field, stack, reason="Fields cannot be None" ) if not is_potential_field(potential_field): raise DagsterInvalidConfigDefinitionError(original_root, potential_field, stack) if isinstance(potential_field, Field): return potential_field return Field(_convert_potential_type(original_root, potential_field, stack)) diff --git a/python_modules/dagster/dagster/core/code_pointer.py b/python_modules/dagster/dagster/core/code_pointer.py index 82300e7bd..2ec540f1d 100644 --- a/python_modules/dagster/dagster/core/code_pointer.py +++ b/python_modules/dagster/dagster/core/code_pointer.py @@ -1,416 +1,386 @@ import importlib import inspect import os import sys import warnings from abc import ABC, abstractmethod from collections import namedtuple -import six from dagster import check from dagster.core.errors import DagsterImportError, DagsterInvariantViolationError from dagster.core.types.loadable_target_origin import LoadableTargetOrigin from dagster.serdes import whitelist_for_serdes from dagster.seven import get_import_error_message, import_module_from_path from dagster.utils import alter_sys_path, frozenlist, load_yaml_from_path class CodePointer(ABC): @abstractmethod def load_target(self): pass @abstractmethod def describe(self): pass @abstractmethod def get_loadable_target_origin(self, executable_path): pass @staticmethod def from_module(module_name, definition): check.str_param(module_name, "module_name") check.str_param(definition, "definition") return ModuleCodePointer(module_name, definition) @staticmethod def from_python_package(module_name, attribute): check.str_param(module_name, "module_name") check.str_param(attribute, "attribute") return PackageCodePointer(module_name, attribute) @staticmethod def from_python_file(python_file, definition, working_directory): check.str_param(python_file, "python_file") check.str_param(definition, "definition") check.opt_str_param(working_directory, "working_directory") return FileCodePointer( python_file=python_file, fn_name=definition, working_directory=working_directory ) @staticmethod def from_legacy_repository_yaml(file_path): check.str_param(file_path, "file_path") config = load_yaml_from_path(file_path) repository_config = check.dict_elem(config, "repository") module_name = check.opt_str_elem(repository_config, "module") file_name = check.opt_str_elem(repository_config, "file") fn_name = check.str_elem(repository_config, "fn") return ( CodePointer.from_module(module_name, fn_name) if module_name # rebase file in config off of the path in the config file else CodePointer.from_python_file(rebase_file(file_name, file_path), fn_name, None) ) def rebase_file(relative_path_in_file, file_path_resides_in): """ In config files, you often put file paths that are meant to be relative to the location of that config file. 
This does that calculation. """ check.str_param(relative_path_in_file, "relative_path_in_file") check.str_param(file_path_resides_in, "file_path_resides_in") return os.path.join( os.path.dirname(os.path.abspath(file_path_resides_in)), relative_path_in_file ) def load_python_file(python_file, working_directory): """ Takes a path to a python file and returns a loaded module """ check.str_param(python_file, "python_file") # First verify that the file exists os.stat(python_file) module_name = os.path.splitext(os.path.basename(python_file))[0] cwd = sys.path[0] if working_directory: try: with alter_sys_path(to_add=[working_directory], to_remove=[cwd]): return import_module_from_path(module_name, python_file) except ImportError as ie: msg = get_import_error_message(ie) + python_file = os.path.abspath(os.path.expanduser(python_file)) + if msg == "attempted relative import with no known parent package": - six.raise_from( - DagsterImportError( - ( - "Encountered ImportError: `{msg}` while importing module {module} from " - "file {python_file}. Consider using the module-based options `-m` for " - "CLI-based targets or the `python_package` workspace.yaml target." - ).format( - msg=msg, - module=module_name, - python_file=os.path.abspath(os.path.expanduser(python_file)), - ) - ), - ie, - ) - six.raise_from( - DagsterImportError( - ( - "Encountered ImportError: `{msg}` while importing module {module} from " - "file {python_file}. Local modules were resolved using the working " - "directory `{working_directory}`. If another working directory should be " - "used, please explicitly specify the appropriate path using the `-d` or " - "`--working-directory` for CLI based targets or the `working_directory` " - "configuration option for `python_file`-based workspace.yaml targets. " - ).format( - msg=msg, - module=module_name, - python_file=os.path.abspath(os.path.expanduser(python_file)), - working_directory=os.path.abspath(os.path.expanduser(working_directory)), - ) - ), - ie, - ) + raise DagsterImportError( + f"Encountered ImportError: `{msg}` while importing module {module_name} from " + f"file {python_file}. Consider using the module-based options `-m` for " + "CLI-based targets or the `python_package` workspace.yaml target." + ) from ie + + working_directory = os.path.abspath(os.path.expanduser(working_directory)) + + raise DagsterImportError( + f"Encountered ImportError: `{msg}` while importing module {module_name} from " + f"file {python_file}. Local modules were resolved using the working " + f"directory `{working_directory}`. If another working directory should be " + "used, please explicitly specify the appropriate path using the `-d` or " + "`--working-directory` for CLI based targets or the `working_directory` " + "configuration option for `python_file`-based workspace.yaml targets. " + ) from ie error = None sys_modules = {k: v for k, v in sys.modules.items()} with alter_sys_path(to_add=[], to_remove=[cwd]): try: module = import_module_from_path(module_name, python_file) except ImportError as ie: # importing alters sys.modules in ways that may interfere with the import below, even # if the import has failed. to work around this, we need to manually clear any modules # that have been cached in sys.modules due to the speculative import call # Also, we are mutating sys.modules instead of straight-up assigning to sys_modules, # because some packages will do similar shenanigans to sys.modules (e.g. 
numpy) to_delete = set(sys.modules) - set(sys_modules) for key in to_delete: del sys.modules[key] error = ie if not error: return module try: module = import_module_from_path(module_name, python_file) # if here, we were able to resolve the module with the working directory on the # path, but should warn because we may not always invoke from the same directory # (e.g. from cron) + module_name = error.name if hasattr(error, "name") else module_name + warnings.warn( - ( - "Module `{module}` was resolved using the working directory. The ability to " - "implicitly load modules from the working directory is deprecated and " - "will be removed in a future release. Please explicitly specify the " - "`working_directory` config option in your workspace.yaml or install `{module}` to " - "your python environment." - ).format(module=error.name if hasattr(error, "name") else module_name) + f"Module `{module_name}` was resolved using the working directory. The ability to " + "implicitly load modules from the working directory is deprecated and " + "will be removed in a future release. Please explicitly specify the " + "`working_directory` config option in your workspace.yaml or install " + f"`{module_name}` to your python environment." ) return module except RuntimeError: # We might be here because numpy throws run time errors at import time when being imported # multiple times... we should also use the original import error as the root - six.raise_from( - DagsterImportError( - ( - "Encountered ImportError: `{msg}` while importing module {module} from file " - "{python_file}. If relying on the working directory to resolve modules, please " - "explicitly specify the appropriate path using the `-d` or " - "`--working-directory` for CLI based targets or the `working_directory` " - "configuration option for `python_file`-based workspace.yaml targets. " - + error.msg - ).format( - msg=error.msg, - module=module_name, - python_file=os.path.abspath(os.path.expanduser(python_file)), - ) - ), - error, - ) + python_file = os.path.abspath(os.path.expanduser(python_file)) + + raise DagsterImportError( + f"Encountered ImportError: `{error.msg}` while importing module {module_name} from file" + f" {python_file}. If relying on the working directory to resolve modules, please " + "explicitly specify the appropriate path using the `-d` or " + "`--working-directory` for CLI based targets or the `working_directory` " + "configuration option for `python_file`-based workspace.yaml targets. " + error.msg + ) from error except ImportError: # raise the original import error raise error def load_python_module(module_name, warn_only=False, remove_from_path_fn=None): check.str_param(module_name, "module_name") check.bool_param(warn_only, "warn_only") check.opt_callable_param(remove_from_path_fn, "remove_from_path_fn") error = None remove_paths = remove_from_path_fn() if remove_from_path_fn else [] # hook for tests remove_paths.insert(0, sys.path[0]) # remove the working directory with alter_sys_path(to_add=[], to_remove=remove_paths): try: module = importlib.import_module(module_name) except ImportError as ie: error = ie if error: try: module = importlib.import_module(module_name) # if here, we were able to resolve the module with the working directory on the path, # but should error because we may not always invoke from the same directory (e.g. from # cron) if warn_only: warnings.warn( - ( - "Module {module} was resolved using the working directory. 
The ability to " - "load uninstalled modules from the working directory is deprecated and " - "will be removed in a future release. Please use the python-file based " - "load arguments or install {module} to your python environment." - ).format(module=module_name) + f"Module {module_name} was resolved using the working directory. The ability to" + " load uninstalled modules from the working directory is deprecated and " + "will be removed in a future release. Please use the python-file based " + f"load arguments or install {module_name} to your python environment." ) else: - six.raise_from( - DagsterInvariantViolationError( - ( - "Module {module} not found. Packages must be installed rather than " - "relying on the working directory to resolve module loading." - ).format(module=module_name) - ), - error, - ) + raise DagsterInvariantViolationError( + f"Module {module_name} not found. Packages must be installed rather than " + "relying on the working directory to resolve module loading." + ) from error except RuntimeError: # We might be here because numpy throws run time errors at import time when being # imported multiple times, just raise the original import error raise error except ImportError as ie: raise error return module @whitelist_for_serdes class FileCodePointer( namedtuple("_FileCodePointer", "python_file fn_name working_directory"), CodePointer ): def __new__(cls, python_file, fn_name, working_directory=None): check.opt_str_param(working_directory, "working_directory") return super(FileCodePointer, cls).__new__( cls, check.str_param(python_file, "python_file"), check.str_param(fn_name, "fn_name"), working_directory, ) def load_target(self): module = load_python_file(self.python_file, self.working_directory) if not hasattr(module, self.fn_name): raise DagsterInvariantViolationError( "{name} not found at module scope in file {file}.".format( name=self.fn_name, file=self.python_file ) ) return getattr(module, self.fn_name) def describe(self): if self.working_directory: return "{self.python_file}::{self.fn_name} -- [dir {self.working_directory}]".format( self=self ) else: return "{self.python_file}::{self.fn_name}".format(self=self) def get_cli_args(self): if self.working_directory: return "-f {python_file} -a {fn_name} -d {directory}".format( python_file=os.path.abspath(os.path.expanduser(self.python_file)), fn_name=self.fn_name, directory=os.path.abspath(os.path.expanduser(self.working_directory)), ) else: return "-f {python_file} -a {fn_name}".format( python_file=os.path.abspath(os.path.expanduser(self.python_file)), fn_name=self.fn_name, ) def get_loadable_target_origin(self, executable_path): return LoadableTargetOrigin( executable_path=executable_path, python_file=self.python_file, attribute=self.fn_name, working_directory=self.working_directory, ) @whitelist_for_serdes class ModuleCodePointer(namedtuple("_ModuleCodePointer", "module fn_name"), CodePointer): def __new__(cls, module, fn_name): return super(ModuleCodePointer, cls).__new__( cls, check.str_param(module, "module"), check.str_param(fn_name, "fn_name") ) def load_target(self): module = load_python_module(self.module, warn_only=True) if not hasattr(module, self.fn_name): raise DagsterInvariantViolationError( "{name} not found in module {module}. 
dir: {dir}".format( name=self.fn_name, module=self.module, dir=dir(module) ) ) return getattr(module, self.fn_name) def describe(self): return "from {self.module} import {self.fn_name}".format(self=self) def get_cli_args(self): return "-m {module} -a {fn_name}".format(module=self.module, fn_name=self.fn_name) def get_loadable_target_origin(self, executable_path): return LoadableTargetOrigin( executable_path=executable_path, module_name=self.module, attribute=self.fn_name, ) @whitelist_for_serdes class PackageCodePointer(namedtuple("_PackageCodePointer", "module attribute"), CodePointer): def __new__(cls, module, attribute): return super(PackageCodePointer, cls).__new__( cls, check.str_param(module, "module"), check.str_param(attribute, "attribute") ) def load_target(self): module = load_python_module(self.module) if not hasattr(module, self.attribute): raise DagsterInvariantViolationError( "{name} not found in module {module}. dir: {dir}".format( name=self.attribute, module=self.module, dir=dir(module) ) ) return getattr(module, self.attribute) def describe(self): return "from {self.module} import {self.attribute}".format(self=self) def get_cli_args(self): return "-m {module} -a {attribute}".format(module=self.module, attribute=self.attribute) def get_loadable_target_origin(self, executable_path): return LoadableTargetOrigin( executable_path=executable_path, module_name=self.module, attribute=self.attribute, ) def get_python_file_from_target(target): module = inspect.getmodule(target) python_file = getattr(module, "__file__", None) if not python_file: return None return os.path.abspath(python_file) @whitelist_for_serdes class CustomPointer( namedtuple( "_CustomPointer", "reconstructor_pointer reconstructable_args reconstructable_kwargs" ), CodePointer, ): def __new__(cls, reconstructor_pointer, reconstructable_args, reconstructable_kwargs): check.inst_param(reconstructor_pointer, "reconstructor_pointer", ModuleCodePointer) # These are lists rather than tuples to circumvent the tuple serdes machinery -- since these # are user-provided, they aren't whitelisted for serdes. 
check.list_param(reconstructable_args, "reconstructable_args") check.list_param(reconstructable_kwargs, "reconstructable_kwargs") for reconstructable_kwarg in reconstructable_kwargs: check.list_param(reconstructable_kwarg, "reconstructable_kwarg") check.invariant(isinstance(reconstructable_kwarg[0], str), "Bad kwarg key") check.invariant( len(reconstructable_kwarg) == 2, "Bad kwarg of length {length}, should be 2".format( length=len(reconstructable_kwarg) ), ) # These are frozenlists, rather than lists, so that they can be hashed and the pointer # stored in the lru_cache on the repository and pipeline get_definition methods reconstructable_args = frozenlist(reconstructable_args) reconstructable_kwargs = frozenlist( [frozenlist(reconstructable_kwarg) for reconstructable_kwarg in reconstructable_kwargs] ) return super(CustomPointer, cls).__new__( cls, reconstructor_pointer, reconstructable_args, reconstructable_kwargs, ) def load_target(self): reconstructor = self.reconstructor_pointer.load_target() return reconstructor( *self.reconstructable_args, **{key: value for key, value in self.reconstructable_kwargs} ) def describe(self): return "reconstructable using {module}.{fn_name}".format( module=self.reconstructor_pointer.module, fn_name=self.reconstructor_pointer.fn_name ) def get_loadable_target_origin(self, executable_path): raise NotImplementedError() diff --git a/python_modules/dagster/dagster/core/debug.py b/python_modules/dagster/dagster/core/debug.py index c547f9796..24875bf90 100644 --- a/python_modules/dagster/dagster/core/debug.py +++ b/python_modules/dagster/dagster/core/debug.py @@ -1,48 +1,48 @@ from collections import namedtuple from dagster import check from dagster.core.events.log import EventRecord from dagster.core.snap import ExecutionPlanSnapshot, PipelineSnapshot from dagster.core.storage.pipeline_run import PipelineRun from dagster.serdes import serialize_dagster_namedtuple, whitelist_for_serdes @whitelist_for_serdes class DebugRunPayload( namedtuple( "_DebugRunPayload", "version pipeline_run event_list pipeline_snapshot execution_plan_snapshot", ) ): def __new__( cls, version, pipeline_run, event_list, pipeline_snapshot, execution_plan_snapshot, ): return super(DebugRunPayload, cls).__new__( cls, version=check.str_param(version, "version"), pipeline_run=check.inst_param(pipeline_run, "pipeline_run", PipelineRun), event_list=check.list_param(event_list, "event_list", EventRecord), pipeline_snapshot=check.inst_param( pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot ), execution_plan_snapshot=check.inst_param( execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot ), ) @classmethod def build(cls, instance, run): from dagster import __version__ as dagster_version return cls( version=dagster_version, pipeline_run=run, event_list=instance.all_logs(run.run_id), pipeline_snapshot=instance.get_pipeline_snapshot(run.pipeline_snapshot_id), execution_plan_snapshot=instance.get_execution_plan_snapshot( run.execution_plan_snapshot_id ), ) def write(self, output_file): - return output_file.write(serialize_dagster_namedtuple(self).encode()) + return output_file.write(serialize_dagster_namedtuple(self).encode("utf-8")) diff --git a/python_modules/dagster/dagster/core/definitions/graph.py b/python_modules/dagster/dagster/core/definitions/graph.py index 81de015ee..e9a405c4b 100644 --- a/python_modules/dagster/dagster/core/definitions/graph.py +++ b/python_modules/dagster/dagster/core/definitions/graph.py @@ -1,517 +1,514 @@ from collections import OrderedDict 
-import six from dagster import check from dagster.core.definitions.config import ConfigMapping from dagster.core.errors import DagsterInvalidDefinitionError from dagster.core.types.dagster_type import DagsterTypeKind from toposort import CircularDependencyError, toposort_flatten from .dependency import DependencyStructure, Solid, SolidHandle, SolidInputHandle from .i_solid_definition import NodeDefinition from .input import FanInInputPointer, InputDefinition, InputMapping, InputPointer from .output import OutputDefinition, OutputMapping from .solid_container import create_execution_structure, validate_dependency_dict def _check_node_defs_arg(graph_name, node_defs): if not isinstance(node_defs, list): raise DagsterInvalidDefinitionError( '"solids" arg to "{name}" is not a list. Got {val}.'.format( name=graph_name, val=repr(node_defs) ) ) for node_def in node_defs: if isinstance(node_def, NodeDefinition): continue elif callable(node_def): raise DagsterInvalidDefinitionError( """You have passed a lambda or function {func} into {name} that is not a solid. You have likely forgetten to annotate this function with an @solid or @lambda_solid decorator.' """.format( name=graph_name, func=node_def.__name__ ) ) else: raise DagsterInvalidDefinitionError( "Invalid item in solid list: {item}".format(item=repr(node_def)) ) return node_defs def _create_adjacency_lists(solids, dep_structure): check.list_param(solids, "solids", Solid) check.inst_param(dep_structure, "dep_structure", DependencyStructure) visit_dict = {s.name: False for s in solids} forward_edges = {s.name: set() for s in solids} backward_edges = {s.name: set() for s in solids} def visit(solid_name): if visit_dict[solid_name]: return visit_dict[solid_name] = True for output_handle in dep_structure.all_upstream_outputs_from_solid(solid_name): forward_node = output_handle.solid.name backward_node = solid_name if forward_node in forward_edges: forward_edges[forward_node].add(backward_node) backward_edges[backward_node].add(forward_node) visit(forward_node) for s in solids: visit(s.name) return (forward_edges, backward_edges) class GraphDefinition(NodeDefinition): def __init__( self, name, description, node_defs, dependencies, input_mappings, output_mappings, config_mapping, **kwargs, ): self._node_defs = _check_node_defs_arg(name, node_defs) # TODO: backcompat for now self._solid_defs = self._node_defs self._dependencies = validate_dependency_dict(dependencies) self._dependency_structure, self._solid_dict = create_execution_structure( node_defs, self._dependencies, graph_definition=self ) # List[InputMapping] self._input_mappings, input_defs = _validate_in_mappings( check.opt_list_param(input_mappings, "input_mappings"), self._solid_dict, self._dependency_structure, name, class_name=type(self).__name__, ) # List[OutputMapping] self._output_mappings = _validate_out_mappings( check.opt_list_param(output_mappings, "output_mappings"), self._solid_dict, self._dependency_structure, name, class_name=type(self).__name__, ) self._config_mapping = check.opt_inst_param(config_mapping, "config_mapping", ConfigMapping) super(GraphDefinition, self).__init__( name=name, description=description, input_defs=input_defs, output_defs=[output_mapping.definition for output_mapping in self._output_mappings], **kwargs, ) # must happen after base class construction as properties are assumed to be there # eager computation to detect cycles self.solids_in_topological_order = self._solids_in_topological_order() def _solids_in_topological_order(self): _forward_edges, 
backward_edges = _create_adjacency_lists( self.solids, self.dependency_structure ) try: order = toposort_flatten(backward_edges) except CircularDependencyError as err: - six.raise_from( - DagsterInvalidDefinitionError(str(err)), err, - ) + raise DagsterInvalidDefinitionError(str(err)) from err return [self.solid_named(solid_name) for solid_name in order] @property def solids(self): """List[Solid]: Top-level solids in the graph. """ return list(set(self._solid_dict.values())) def has_solid_named(self, name): """Return whether or not there is a top level solid with this name in the graph. Args: name (str): Name of solid Returns: bool: True if the solid is in the graph. """ check.str_param(name, "name") return name in self._solid_dict def solid_named(self, name): """Return the top level solid named "name". Throws if it does not exist. Args: name (str): Name of solid Returns: Solid: """ check.str_param(name, "name") check.invariant( name in self._solid_dict, "{graph_name} has no solid named {name}.".format(graph_name=self._name, name=name), ) return self._solid_dict[name] def get_solid(self, handle): """Return the solid contained anywhere within the graph via its handle. Args: handle (SolidHandle): The solid's handle Returns: Solid: """ check.inst_param(handle, "handle", SolidHandle) current = handle lineage = [] while current: lineage.append(current.name) current = current.parent name = lineage.pop() solid = self.solid_named(name) while lineage: name = lineage.pop() solid = solid.definition.solid_named(name) return solid def iterate_node_defs(self): yield self for outer_node_def in self._node_defs: yield from outer_node_def.iterate_node_defs() @property def input_mappings(self): return self._input_mappings @property def output_mappings(self): return self._output_mappings @property def config_mapping(self): return self._config_mapping @property def has_config_mapping(self): return self._config_mapping is not None def get_input_mapping(self, input_name): check.str_param(input_name, "input_name") for mapping in self._input_mappings: if mapping.definition.name == input_name: return mapping return None def input_mapping_for_pointer(self, pointer): check.inst_param(pointer, "pointer", (InputPointer, FanInInputPointer)) for mapping in self._input_mappings: if mapping.maps_to == pointer: return mapping return None def get_output_mapping(self, output_name): check.str_param(output_name, "output_name") for mapping in self._output_mappings: if mapping.definition.name == output_name: return mapping return None def resolve_output_to_origin(self, output_name, handle): check.str_param(output_name, "output_name") check.inst_param(handle, "handle", SolidHandle) mapping = self.get_output_mapping(output_name) check.invariant(mapping, "Can only resolve outputs for valid output names") mapped_solid = self.solid_named(mapping.maps_from.solid_name) return mapped_solid.definition.resolve_output_to_origin( mapping.maps_from.output_name, SolidHandle(mapped_solid.name, handle), ) def default_value_for_input(self, input_name): check.str_param(input_name, "input_name") # base case if self.input_def_named(input_name).has_default_value: return self.input_def_named(input_name).default_value mapping = self.get_input_mapping(input_name) check.invariant(mapping, "Can only resolve inputs for valid input names") mapped_solid = self.solid_named(mapping.maps_to.solid_name) return mapped_solid.definition.default_value_for_input(mapping.maps_to.input_name) def input_has_default(self, input_name): check.str_param(input_name, 
"input_name") # base case if self.input_def_named(input_name).has_default_value: return True mapping = self.get_input_mapping(input_name) check.invariant(mapping, "Can only resolve inputs for valid input names") mapped_solid = self.solid_named(mapping.maps_to.solid_name) return mapped_solid.definition.input_has_default(mapping.maps_to.input_name) @property def required_resource_keys(self): required_resource_keys = set() for solid in self.solids: required_resource_keys.update(solid.definition.required_resource_keys) return frozenset(required_resource_keys) @property def has_config_entry(self): has_child_solid_config = any([solid.definition.has_config_entry for solid in self.solids]) return ( self.has_config_mapping or has_child_solid_config or self.has_configurable_inputs or self.has_configurable_outputs ) @property def dependencies(self): return self._dependencies @property def dependency_structure(self): return self._dependency_structure @property def config_schema(self): return self.config_mapping.config_schema if self.has_config_mapping else None def input_supports_dynamic_output_dep(self, input_name): mapping = self.get_input_mapping(input_name) internal_dynamic_handle = self.dependency_structure.get_upstream_dynamic_handle_for_solid( mapping.maps_to.solid_name ) if internal_dynamic_handle: return False return True def _validate_in_mappings(input_mappings, solid_dict, dependency_structure, name, class_name): from .composition import MappedInputPlaceholder input_def_dict = OrderedDict() mapping_keys = set() for mapping in input_mappings: # handle incorrect objects passed in as mappings if not isinstance(mapping, InputMapping): if isinstance(mapping, InputDefinition): raise DagsterInvalidDefinitionError( "In {class_name} '{name}' you passed an InputDefinition " "named '{input_name}' directly in to input_mappings. Return " "an InputMapping by calling mapping_to on the InputDefinition.".format( name=name, input_name=mapping.name, class_name=class_name ) ) else: raise DagsterInvalidDefinitionError( "In {class_name} '{name}' received unexpected type '{type}' in input_mappings. 
" "Provide an OutputMapping using InputDefinition(...).mapping_to(...)".format( type=type(mapping), name=name, class_name=class_name ) ) if input_def_dict.get(mapping.definition.name): if input_def_dict[mapping.definition.name] != mapping.definition: raise DagsterInvalidDefinitionError( "In {class_name} {name} multiple input mappings with same " "definition name but different definitions".format( name=name, class_name=class_name ), ) else: input_def_dict[mapping.definition.name] = mapping.definition target_solid = solid_dict.get(mapping.maps_to.solid_name) if target_solid is None: raise DagsterInvalidDefinitionError( "In {class_name} '{name}' input mapping references solid " "'{solid_name}' which it does not contain.".format( name=name, solid_name=mapping.maps_to.solid_name, class_name=class_name ) ) if not target_solid.has_input(mapping.maps_to.input_name): raise DagsterInvalidDefinitionError( "In {class_name} '{name}' input mapping to solid '{mapping.maps_to.solid_name}' " "which contains no input named '{mapping.maps_to.input_name}'".format( name=name, mapping=mapping, class_name=class_name ) ) target_input = target_solid.input_def_named(mapping.maps_to.input_name) solid_input_handle = SolidInputHandle(target_solid, target_input) if mapping.maps_to_fan_in: if not dependency_structure.has_multi_deps(solid_input_handle): raise DagsterInvalidDefinitionError( 'In {class_name} "{name}" input mapping target ' '"{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}" (index {mapping.maps_to.fan_in_index} of fan-in) ' "is not a MultiDependencyDefinition.".format( name=name, mapping=mapping, class_name=class_name ) ) inner_deps = dependency_structure.get_multi_deps(solid_input_handle) if (mapping.maps_to.fan_in_index >= len(inner_deps)) or ( inner_deps[mapping.maps_to.fan_in_index] is not MappedInputPlaceholder ): raise DagsterInvalidDefinitionError( 'In {class_name} "{name}" input mapping target ' '"{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}" index {mapping.maps_to.fan_in_index} in ' "the MultiDependencyDefinition is not a MappedInputPlaceholder".format( name=name, mapping=mapping, class_name=class_name ) ) mapping_keys.add( "{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}.{mapping.maps_to.fan_in_index}".format( mapping=mapping ) ) target_type = target_input.dagster_type.get_inner_type_for_fan_in() fan_in_msg = " (index {} of fan-in)".format(mapping.maps_to.fan_in_index) else: if dependency_structure.has_deps(solid_input_handle): raise DagsterInvalidDefinitionError( 'In {class_name} "{name}" input mapping target ' '"{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}" ' "is already satisfied by solid output".format( name=name, mapping=mapping, class_name=class_name ) ) mapping_keys.add( "{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}".format(mapping=mapping) ) target_type = target_input.dagster_type fan_in_msg = "" if target_type != mapping.definition.dagster_type: raise DagsterInvalidDefinitionError( "In {class_name} '{name}' input " "'{mapping.definition.name}' of type {mapping.definition.dagster_type.display_name} maps to " "{mapping.maps_to.solid_name}.{mapping.maps_to.input_name}{fan_in_msg} of different type " "{target_type.display_name}. 
InputMapping source and " "destination must have the same type.".format( mapping=mapping, name=name, target_type=target_type, class_name=class_name, fan_in_msg=fan_in_msg, ) ) for input_handle in dependency_structure.input_handles(): if dependency_structure.has_multi_deps(input_handle): for idx, dep in enumerate(dependency_structure.get_multi_deps(input_handle)): if dep is MappedInputPlaceholder: mapping_str = "{input_handle.solid_name}.{input_handle.input_name}.{idx}".format( input_handle=input_handle, idx=idx ) if mapping_str not in mapping_keys: raise DagsterInvalidDefinitionError( "Unsatisfied MappedInputPlaceholder at index {idx} in " "MultiDependencyDefinition for '{input_handle.solid_name}.{input_handle.input_name}'".format( input_handle=input_handle, idx=idx ) ) return input_mappings, input_def_dict.values() def _validate_out_mappings(output_mappings, solid_dict, dependency_structure, name, class_name): for mapping in output_mappings: if isinstance(mapping, OutputMapping): target_solid = solid_dict.get(mapping.maps_from.solid_name) if target_solid is None: raise DagsterInvalidDefinitionError( "In {class_name} '{name}' output mapping references solid " "'{solid_name}' which it does not contain.".format( name=name, solid_name=mapping.maps_from.solid_name, class_name=class_name ) ) if not target_solid.has_output(mapping.maps_from.output_name): raise DagsterInvalidDefinitionError( "In {class_name} {name} output mapping from solid '{mapping.maps_from.solid_name}' " "which contains no output named '{mapping.maps_from.output_name}'".format( name=name, mapping=mapping, class_name=class_name ) ) target_output = target_solid.output_def_named(mapping.maps_from.output_name) if mapping.definition.dagster_type.kind != DagsterTypeKind.ANY and ( target_output.dagster_type != mapping.definition.dagster_type ): raise DagsterInvalidDefinitionError( "In {class_name} '{name}' output " "'{mapping.definition.name}' of type {mapping.definition.dagster_type.display_name} " "maps from {mapping.maps_from.solid_name}.{mapping.maps_from.output_name} of different type " "{target_output.dagster_type.display_name}. OutputMapping source " "and destination must have the same type.".format( class_name=class_name, mapping=mapping, name=name, target_output=target_output, ) ) if target_output.is_dynamic and not mapping.definition.is_dynamic: raise DagsterInvalidDefinitionError( f'In {class_name} "{name}" can not map from {target_output.__class__.__name__} ' f'"{target_output.name}" to {mapping.definition.__class__.__name__} ' f'"{mapping.definition.name}". Definition types must align.' ) dynamic_handle = dependency_structure.get_upstream_dynamic_handle_for_solid( target_solid.name ) if dynamic_handle and not mapping.definition.is_dynamic: raise DagsterInvalidDefinitionError( f'In {class_name} "{name}" output "{mapping.definition.name}" mapping from ' f'solid "{mapping.maps_from.solid_name}" must be a DynamicOutputDefinition since it is ' f'downstream of dynamic output "{dynamic_handle.describe()}".' ) elif isinstance(mapping, OutputDefinition): raise DagsterInvalidDefinitionError( "You passed an OutputDefinition named '{output_name}' directly " "in to output_mappings. Return an OutputMapping by calling " "mapping_from on the OutputDefinition.".format(output_name=mapping.name) ) else: raise DagsterInvalidDefinitionError( "Received unexpected type '{type}' in output_mappings. 
" "Provide an OutputMapping using OutputDefinition(...).mapping_from(...)".format( type=type(mapping) ) ) return output_mappings diff --git a/python_modules/dagster/dagster/core/definitions/inference.py b/python_modules/dagster/dagster/core/definitions/inference.py index e65e84fa0..6951214ad 100644 --- a/python_modules/dagster/dagster/core/definitions/inference.py +++ b/python_modules/dagster/dagster/core/definitions/inference.py @@ -1,130 +1,114 @@ import inspect -import six from dagster.check import CheckError from dagster.core.errors import DagsterInvalidDefinitionError from dagster.seven import funcsigs, is_module_available from .input import InputDefinition from .output import OutputDefinition def _infer_input_description_from_docstring(fn): if not is_module_available("docstring_parser"): return {} from docstring_parser import parse docstring = parse(fn.__doc__) return {p.arg_name: p.description for p in docstring.params} def _infer_output_description_from_docstring(fn): if not is_module_available("docstring_parser"): return from docstring_parser import parse docstring = parse(fn.__doc__) if docstring.returns is None: return return docstring.returns.description def infer_output_definitions(decorator_name, solid_name, fn): signature = funcsigs.signature(fn) try: description = _infer_output_description_from_docstring(fn) return [ OutputDefinition() if signature.return_annotation is funcsigs.Signature.empty else OutputDefinition(signature.return_annotation, description=description) ] except CheckError as type_error: - six.raise_from( - DagsterInvalidDefinitionError( - "Error inferring Dagster type for return type " - '"{type_annotation}" from {decorator} "{solid}". ' - "Correct the issue or explicitly pass definitions to {decorator}.".format( - decorator=decorator_name, - solid=solid_name, - type_annotation=signature.return_annotation, - ) - ), - type_error, - ) + raise DagsterInvalidDefinitionError( + "Error inferring Dagster type for return type " + f'"{signature.return_annotation}" from {decorator_name} "{solid_name}". ' + f"Correct the issue or explicitly pass definitions to {decorator_name}." + ) from type_error def has_explicit_return_type(fn): signature = funcsigs.signature(fn) return not signature.return_annotation is funcsigs.Signature.empty def _input_param_type(type_annotation): if type_annotation is not inspect.Parameter.empty: return type_annotation return None def infer_input_definitions_for_lambda_solid(solid_name, fn): signature = funcsigs.signature(fn) params = list(signature.parameters.values()) descriptions = _infer_input_description_from_docstring(fn) defs = _infer_inputs_from_params(params, "@lambda_solid", solid_name, descriptions=descriptions) return defs def _infer_inputs_from_params(params, decorator_name, solid_name, descriptions=None): descriptions = descriptions or {} input_defs = [] for param in params: try: if param.default is not funcsigs.Parameter.empty: input_def = InputDefinition( param.name, _input_param_type(param.annotation), default_value=param.default, description=descriptions.get(param.name), ) else: input_def = InputDefinition( param.name, _input_param_type(param.annotation), description=descriptions.get(param.name), ) input_defs.append(input_def) except CheckError as type_error: - six.raise_from( - DagsterInvalidDefinitionError( - "Error inferring Dagster type for input name {param} typed as " - '"{type_annotation}" from {decorator} "{solid}". 
' - "Correct the issue or explicitly pass definitions to {decorator}.".format( - decorator=decorator_name, - solid=solid_name, - param=param.name, - type_annotation=param.annotation, - ) - ), - type_error, - ) + raise DagsterInvalidDefinitionError( + f"Error inferring Dagster type for input name {param.name} typed as " + f'"{param.annotation}" from {decorator_name} "{solid_name}". ' + "Correct the issue or explicitly pass definitions to {decorator_name}." + ) from type_error return input_defs def infer_input_definitions_for_graph(decorator_name, solid_name, fn): signature = funcsigs.signature(fn) params = list(signature.parameters.values()) descriptions = _infer_input_description_from_docstring(fn) defs = _infer_inputs_from_params(params, decorator_name, solid_name, descriptions=descriptions) return defs def infer_input_definitions_for_solid(solid_name, fn): signature = funcsigs.signature(fn) params = list(signature.parameters.values()) descriptions = _infer_input_description_from_docstring(fn) defs = _infer_inputs_from_params(params[1:], "@solid", solid_name, descriptions=descriptions) return defs diff --git a/python_modules/dagster/dagster/core/definitions/pipeline.py b/python_modules/dagster/dagster/core/definitions/pipeline.py index 2cb071c0a..9bea0ef8b 100644 --- a/python_modules/dagster/dagster/core/definitions/pipeline.py +++ b/python_modules/dagster/dagster/core/definitions/pipeline.py @@ -1,822 +1,816 @@ import uuid import warnings -import six from dagster import check from dagster.core.definitions.solid import NodeDefinition from dagster.core.errors import ( DagsterInvalidDefinitionError, DagsterInvalidSubsetError, DagsterInvariantViolationError, ) from dagster.core.storage.output_manager import IOutputManagerDefinition from dagster.core.storage.root_input_manager import IInputManagerDefinition from dagster.core.types.dagster_type import DagsterTypeKind, construct_dagster_type_dictionary from dagster.core.utils import str_format_set from dagster.utils.backcompat import experimental_arg_warning from .config import ConfigMapping from .dependency import ( DependencyDefinition, MultiDependencyDefinition, SolidHandle, SolidInvocation, ) from .graph import GraphDefinition from .hook import HookDefinition from .mode import ModeDefinition from .preset import PresetDefinition from .solid import NodeDefinition from .utils import validate_tags def _anonymous_pipeline_name(): return "__pipeline__" + str(uuid.uuid4()).replace("-", "") class PipelineDefinition(GraphDefinition): """Defines a Dagster pipeline. A pipeline is made up of - Solids, each of which is a single functional unit of data computation. - Dependencies, which determine how the values produced by solids as their outputs flow from one solid to another. This tells Dagster how to arrange solids, and potentially multiple aliased instances of solids, into a directed, acyclic graph (DAG) of compute. - Modes, which can be used to attach resources, custom loggers, custom system storage options, and custom executors to a pipeline, and to switch between them. - Presets, which can be used to ship common combinations of pipeline config options in Python code, and to switch between them. Args: solid_defs (List[SolidDefinition]): The set of solids used in this pipeline. name (Optional[str]): The name of the pipeline. Must be unique within any :py:class:`RepositoryDefinition` containing the pipeline. description (Optional[str]): A human-readable description of the pipeline. 
dependencies (Optional[Dict[Union[str, SolidInvocation], Dict[str, DependencyDefinition]]]): A structure that declares the dependencies of each solid's inputs on the outputs of other solids in the pipeline. Keys of the top level dict are either the string names of solids in the pipeline or, in the case of aliased solids, :py:class:`SolidInvocations `. Values of the top level dict are themselves dicts, which map input names belonging to the solid or aliased solid to :py:class:`DependencyDefinitions `. mode_defs (Optional[List[ModeDefinition]]): The set of modes in which this pipeline can operate. Modes are used to attach resources, custom loggers, custom system storage options, and custom executors to a pipeline. Modes can be used, e.g., to vary available resource and logging implementations between local test and production runs. preset_defs (Optional[List[PresetDefinition]]): A set of preset collections of configuration options that may be used to execute a pipeline. A preset consists of an environment dict, an optional subset of solids to execute, and a mode selection. Presets can be used to ship common combinations of options to pipeline end users in Python code, and can be selected by tools like Dagit. tags (Optional[Dict[str, Any]]): Arbitrary metadata for any execution run of the pipeline. Values that are not strings will be json encoded and must meet the criteria that `json.loads(json.dumps(value)) == value`. These tag values may be overwritten by tag values provided at invocation time. hook_defs (Optional[Set[HookDefinition]]): A set of hook definitions applied to the pipeline. When a hook is applied to a pipeline, it will be attached to all solid instances within the pipeline. _parent_pipeline_def (INTERNAL ONLY): Used for tracking pipelines created using solid subsets. Examples: .. code-block:: python @lambda_solid def return_one(): return 1 @solid(input_defs=[InputDefinition('num')], required_resource_keys={'op'}) def apply_op(context, num): return context.resources.op(num) @resource(config_schema=Int) def adder_resource(init_context): return lambda x: x + init_context.resource_config add_mode = ModeDefinition( name='add_mode', resource_defs={'op': adder_resource}, description='Mode that adds things', ) add_three_preset = PresetDefinition( name='add_three_preset', run_config={'resources': {'op': {'config': 3}}}, mode='add_mode', ) pipeline_def = PipelineDefinition( name='basic', solid_defs=[return_one, apply_op], dependencies={'apply_op': {'num': DependencyDefinition('return_one')}}, mode_defs=[add_mode], preset_defs=[add_three_preset], ) """ def __init__( self, solid_defs, name=None, description=None, dependencies=None, mode_defs=None, preset_defs=None, tags=None, hook_defs=None, input_mappings=None, output_mappings=None, config_mapping=None, positional_inputs=None, _parent_pipeline_def=None, # https://github.com/dagster-io/dagster/issues/2115 ): if not name: warnings.warn( "Pipeline must have a name. Names will be required starting in 0.10.0 or later." 
) name = _anonymous_pipeline_name() # For these warnings they check truthiness because they get changed to [] higher # in the stack for the decorator case if input_mappings: experimental_arg_warning("input_mappings", "PipelineDefinition") if output_mappings: experimental_arg_warning("output_mappings", "PipelineDefinition") if config_mapping is not None: experimental_arg_warning("config_mapping", "PipelineDefinition") if positional_inputs: experimental_arg_warning("positional_inputs", "PipelineDefinition") super(PipelineDefinition, self).__init__( name=name, description=description, dependencies=dependencies, node_defs=solid_defs, tags=check.opt_dict_param(tags, "tags", key_type=str), positional_inputs=positional_inputs, input_mappings=input_mappings, output_mappings=output_mappings, config_mapping=config_mapping, ) self._current_level_node_defs = solid_defs self._tags = validate_tags(tags) mode_definitions = check.opt_list_param(mode_defs, "mode_defs", of_type=ModeDefinition) if not mode_definitions: mode_definitions = [ModeDefinition()] self._mode_definitions = mode_definitions seen_modes = set() for mode_def in mode_definitions: if mode_def.name in seen_modes: raise DagsterInvalidDefinitionError( ( 'Two modes seen with the name "{mode_name}" in "{pipeline_name}". ' "Modes must have unique names." ).format(mode_name=mode_def.name, pipeline_name=self._name) ) seen_modes.add(mode_def.name) self._dagster_type_dict = construct_dagster_type_dictionary(self._current_level_node_defs) self._hook_defs = check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition) self._preset_defs = check.opt_list_param(preset_defs, "preset_defs", PresetDefinition) self._preset_dict = {} for preset in self._preset_defs: if preset.name in self._preset_dict: raise DagsterInvalidDefinitionError( ( 'Two PresetDefinitions seen with the name "{name}" in "{pipeline_name}". ' "PresetDefinitions must have unique names." ).format(name=preset.name, pipeline_name=self._name) ) if preset.mode not in seen_modes: raise DagsterInvalidDefinitionError( ( 'PresetDefinition "{name}" in "{pipeline_name}" ' 'references mode "{mode}" which is not defined.' ).format(name=preset.name, pipeline_name=self._name, mode=preset.mode) ) self._preset_dict[preset.name] = preset # Validate solid resource dependencies _validate_resource_dependencies( self._mode_definitions, self._current_level_node_defs, self._dagster_type_dict, self._solid_dict, self._hook_defs, ) # Validate unsatisfied inputs can be materialized from config _validate_inputs(self._dependency_structure, self._solid_dict, self._mode_definitions) # Recursively explore all nodes in the this pipeline self._all_node_defs = _build_all_node_defs(self._current_level_node_defs) self._parent_pipeline_def = check.opt_inst_param( _parent_pipeline_def, "_parent_pipeline_def", PipelineDefinition ) self._cached_run_config_schemas = {} self._cached_external_pipeline = None def copy_for_configured(self, name, description, config_schema, config_or_config_fn): if not self.has_config_mapping: raise DagsterInvalidDefinitionError( "Only pipelines utilizing config mapping can be pre-configured. 
The pipeline " '"{graph_name}" does not have a config mapping, and thus has nothing to be ' "configured.".format(graph_name=self.name) ) return PipelineDefinition( solid_defs=self._solid_defs, name=self._name_for_configured_node(self.name, name, config_or_config_fn), description=description or self.description, dependencies=self._dependencies, mode_defs=self._mode_definitions, preset_defs=self.preset_defs, hook_defs=self.hook_defs, input_mappings=self._input_mappings, output_mappings=self._output_mappings, config_mapping=ConfigMapping( self._config_mapping.config_fn, config_schema=config_schema ), positional_inputs=self.positional_inputs, _parent_pipeline_def=self._parent_pipeline_def, ) def get_run_config_schema(self, mode=None): check.str_param(mode, "mode") mode_def = self.get_mode_definition(mode) if mode_def.name in self._cached_run_config_schemas: return self._cached_run_config_schemas[mode_def.name] self._cached_run_config_schemas[mode_def.name] = _create_run_config_schema(self, mode_def) return self._cached_run_config_schemas[mode_def.name] @property def mode_definitions(self): return self._mode_definitions @property def preset_defs(self): return self._preset_defs def _get_mode_definition(self, mode): check.str_param(mode, "mode") for mode_definition in self._mode_definitions: if mode_definition.name == mode: return mode_definition return None def get_default_mode(self): return self._mode_definitions[0] @property def is_single_mode(self): return len(self._mode_definitions) == 1 @property def is_multi_mode(self): return len(self._mode_definitions) > 1 def has_mode_definition(self, mode): check.str_param(mode, "mode") return bool(self._get_mode_definition(mode)) def get_default_mode_name(self): return self._mode_definitions[0].name def get_mode_definition(self, mode=None): check.opt_str_param(mode, "mode") if mode is None: check.invariant(self.is_single_mode) return self.get_default_mode() mode_def = self._get_mode_definition(mode) check.invariant( mode_def is not None, "Could not find mode {mode} in pipeline {name}".format(mode=mode, name=self._name), ) return mode_def @property def available_modes(self): return [mode_def.name for mode_def in self._mode_definitions] @property def display_name(self): """str: Display name of pipeline. Name suitable for exception messages, logging etc. If pipeline is unnamed the method will return "<>". 
""" return self._name if self._name else "<>" @property def tags(self): return self._tags def has_dagster_type(self, name): check.str_param(name, "name") return name in self._dagster_type_dict def dagster_type_named(self, name): check.str_param(name, "name") return self._dagster_type_dict[name] def all_dagster_types(self): return self._dagster_type_dict.values() @property def all_solid_defs(self): return list(self._all_node_defs.values()) @property def top_level_solid_defs(self): return self._current_level_node_defs def solid_def_named(self, name): check.str_param(name, "name") check.invariant(name in self._all_node_defs, "{} not found".format(name)) return self._all_node_defs[name] def has_solid_def(self, name): check.str_param(name, "name") return name in self._all_node_defs def get_pipeline_subset_def(self, solids_to_execute): return ( self if solids_to_execute is None else _get_pipeline_subset_def(self, solids_to_execute) ) def get_presets(self): return list(self._preset_dict.values()) def has_preset(self, name): check.str_param(name, "name") return name in self._preset_dict def get_preset(self, name): check.str_param(name, "name") if name not in self._preset_dict: raise DagsterInvariantViolationError( ( 'Could not find preset for "{name}". Available presets ' 'for pipeline "{pipeline_name}" are {preset_names}.' ).format( name=name, preset_names=list(self._preset_dict.keys()), pipeline_name=self._name ) ) return self._preset_dict[name] def get_pipeline_snapshot(self): return self.get_pipeline_index().pipeline_snapshot def get_pipeline_snapshot_id(self): return self.get_pipeline_index().pipeline_snapshot_id def get_pipeline_index(self): from dagster.core.snap import PipelineSnapshot from dagster.core.host_representation import PipelineIndex return PipelineIndex( PipelineSnapshot.from_pipeline_def(self), self.get_parent_pipeline_snapshot() ) def get_config_schema_snapshot(self): return self.get_pipeline_snapshot().config_schema_snapshot @property def is_subset_pipeline(self): return False @property def parent_pipeline_def(self): return None def get_parent_pipeline_snapshot(self): return None @property def solids_to_execute(self): return None @property def hook_defs(self): return self._hook_defs def get_all_hooks_for_handle(self, handle): """Gather all the hooks for the given solid from all places possibly attached with a hook. 
A hook can be attached to any of the following objects * Solid (solid invocation) * PipelineDefinition Args: handle (SolidHandle): The solid's handle Returns: FrozeSet[HookDefinition] """ check.inst_param(handle, "handle", SolidHandle) hook_defs = set() current = handle lineage = [] while current: lineage.append(current.name) current = current.parent # hooks on top-level solid name = lineage.pop() solid = self.solid_named(name) hook_defs = hook_defs.union(solid.hook_defs) # hooks on non-top-level solids while lineage: name = lineage.pop() solid = solid.definition.solid_named(name) hook_defs = hook_defs.union(solid.hook_defs) # hooks applied to a pipeline definition will run on every solid hook_defs = hook_defs.union(self.hook_defs) return frozenset(hook_defs) def with_hooks(self, hook_defs): """Apply a set of hooks to all solid instances within the pipeline.""" hook_defs = check.set_param(hook_defs, "hook_defs", of_type=HookDefinition) return PipelineDefinition( solid_defs=self.top_level_solid_defs, name=self.name, description=self.description, dependencies=self.dependencies, mode_defs=self.mode_definitions, preset_defs=self.preset_defs, tags=self.tags, hook_defs=hook_defs.union(self.hook_defs), _parent_pipeline_def=self._parent_pipeline_def, ) class PipelineSubsetDefinition(PipelineDefinition): @property def solids_to_execute(self): return frozenset(self._solid_dict.keys()) @property def solid_selection(self): # we currently don't pass the real solid_selection (the solid query list) down here. # so in the short-term, to make the call sites cleaner, we will convert the solids to execute # to a list return list(self._solid_dict.keys()) @property def parent_pipeline_def(self): return self._parent_pipeline_def def get_parent_pipeline_snapshot(self): return self._parent_pipeline_def.get_pipeline_snapshot() @property def is_subset_pipeline(self): return True def get_pipeline_subset_def(self, solids_to_execute): raise DagsterInvariantViolationError("Pipeline subsets may not be subset again.") def _dep_key_of(solid): return SolidInvocation(solid.definition.name, solid.name) def _get_pipeline_subset_def(pipeline_def, solids_to_execute): """ Build a pipeline which is a subset of another pipeline. Only includes the solids which are in solids_to_execute. 
""" check.inst_param(pipeline_def, "pipeline_def", PipelineDefinition) check.set_param(solids_to_execute, "solids_to_execute", of_type=str) for solid_name in solids_to_execute: if not pipeline_def.has_solid_named(solid_name): raise DagsterInvalidSubsetError( "Pipeline {pipeline_name} has no solid named {name}.".format( pipeline_name=pipeline_def.name, name=solid_name ), ) solids = list(map(pipeline_def.solid_named, solids_to_execute)) deps = {_dep_key_of(solid): {} for solid in solids} for solid in solids: for input_handle in solid.input_handles(): if pipeline_def.dependency_structure.has_singular_dep(input_handle): output_handle = pipeline_def.dependency_structure.get_singular_dep(input_handle) if output_handle.solid.name in solids_to_execute: deps[_dep_key_of(solid)][input_handle.input_def.name] = DependencyDefinition( solid=output_handle.solid.name, output=output_handle.output_def.name ) elif pipeline_def.dependency_structure.has_multi_deps(input_handle): output_handles = pipeline_def.dependency_structure.get_multi_deps(input_handle) deps[_dep_key_of(solid)][input_handle.input_def.name] = MultiDependencyDefinition( [ DependencyDefinition( solid=output_handle.solid.name, output=output_handle.output_def.name ) for output_handle in output_handles if output_handle.solid.name in solids_to_execute ] ) try: sub_pipeline_def = PipelineSubsetDefinition( name=pipeline_def.name, # should we change the name for subsetted pipeline? solid_defs=list({solid.definition for solid in solids}), mode_defs=pipeline_def.mode_definitions, dependencies=deps, _parent_pipeline_def=pipeline_def, tags=pipeline_def.tags, hook_defs=pipeline_def.hook_defs, ) return sub_pipeline_def except DagsterInvalidDefinitionError as exc: # This handles the case when you construct a subset such that an unsatisfied # input cannot be loaded from config. Instead of throwing a DagsterInvalidDefinitionError, # we re-raise a DagsterInvalidSubsetError. - six.raise_from( - DagsterInvalidSubsetError( - "The attempted subset {solids_to_execute} for pipeline {pipeline_name} results in an invalid pipeline".format( - solids_to_execute=str_format_set(solids_to_execute), - pipeline_name=pipeline_def.name, - ) - ), - exc, - ) + raise DagsterInvalidSubsetError( + f"The attempted subset {str_format_set(solids_to_execute)} for pipeline " + f"{pipeline_def.name} results in an invalid pipeline" + ) from exc def _validate_resource_dependencies( mode_definitions, node_defs, dagster_type_dict, solid_dict, pipeline_hook_defs ): """This validation ensures that each pipeline context provides the resources that are required by each solid. """ check.list_param(mode_definitions, "mode_definitions", of_type=ModeDefinition) check.list_param(node_defs, "node_defs", of_type=NodeDefinition) check.dict_param(dagster_type_dict, "dagster_type_dict") check.dict_param(solid_dict, "solid_dict") check.set_param(pipeline_hook_defs, "pipeline_hook_defs", of_type=HookDefinition) for mode_def in mode_definitions: mode_resources = set(mode_def.resource_defs.keys()) for node_def in node_defs: for required_resource in node_def.required_resource_keys: if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by solid def {node_def_name}, but is not ' 'provided by mode "{mode_name}".' 
).format( resource=required_resource, node_def_name=node_def.name, mode_name=mode_def.name, ) ) _validate_type_resource_deps_for_mode(mode_def, mode_resources, dagster_type_dict) for intermediate_storage in mode_def.intermediate_storage_defs or []: for required_resource in intermediate_storage.required_resource_keys: if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( "Resource '{resource}' is required by intermediate storage " "'{storage_name}', but is not provided by mode '{mode_name}'." ).format( resource=required_resource, storage_name=intermediate_storage.name, mode_name=mode_def.name, ) ) for solid in solid_dict.values(): for hook_def in solid.hook_defs: for required_resource in hook_def.required_resource_keys: if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by hook "{hook_name}", but is not ' 'provided by mode "{mode_name}".' ).format( resource=required_resource, hook_name=hook_def.name, mode_name=mode_def.name, ) ) for hook_def in pipeline_hook_defs: for required_resource in hook_def.required_resource_keys: if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by hook "{hook_name}", but is not ' 'provided by mode "{mode_name}".' ).format( resource=required_resource, hook_name=hook_def.name, mode_name=mode_def.name, ) ) def _validate_type_resource_deps_for_mode(mode_def, mode_resources, dagster_type_dict): for dagster_type in dagster_type_dict.values(): for required_resource in dagster_type.required_resource_keys: if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by type "{type_name}", but is not ' 'provided by mode "{mode_name}".' ).format( resource=required_resource, type_name=dagster_type.display_name, mode_name=mode_def.name, ) ) if dagster_type.loader: for required_resource in dagster_type.loader.required_resource_keys(): if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by the loader on type ' '"{type_name}", but is not provided by mode "{mode_name}".' ).format( resource=required_resource, type_name=dagster_type.display_name, mode_name=mode_def.name, ) ) if dagster_type.materializer: for required_resource in dagster_type.materializer.required_resource_keys(): if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by the materializer on type ' '"{type_name}", but is not provided by mode "{mode_name}".' ).format( resource=required_resource, type_name=dagster_type.display_name, mode_name=mode_def.name, ) ) for plugin in dagster_type.auto_plugins: used_by_storage = set( [ intermediate_storage_def.name for intermediate_storage_def in mode_def.intermediate_storage_defs if plugin.compatible_with_storage_def(intermediate_storage_def) ] ) if used_by_storage: for required_resource in plugin.required_resource_keys(): if required_resource not in mode_resources: raise DagsterInvalidDefinitionError( ( 'Resource "{resource}" is required by the plugin "{plugin_name}"' ' on type "{type_name}" (used with storages {storages}), ' 'but is not provided by mode "{mode_name}".' 
).format( resource=required_resource, type_name=dagster_type.display_name, plugin_name=plugin.__name__, mode_name=mode_def.name, storages=used_by_storage, ) ) def _validate_inputs(dependency_structure, solid_dict, mode_definitions): for solid in solid_dict.values(): for handle in solid.input_handles(): if dependency_structure.has_deps(handle): for mode_def in mode_definitions: for source_output_handle in dependency_structure.get_deps_list(handle): output_manager_key = source_output_handle.output_def.manager_key output_manager_def = mode_def.resource_defs[output_manager_key] # TODO: remove the IOutputManagerDefinition check when asset store # API is removed. if isinstance( output_manager_def, IOutputManagerDefinition ) and not isinstance(output_manager_def, IInputManagerDefinition): raise DagsterInvalidDefinitionError( f'Input "{handle.input_def.name}" of solid "{solid.name}" is ' f'connected to output "{source_output_handle.output_def.name}" ' f'of solid "{source_output_handle.solid.name}". In mode ' f'"{mode_def.name}", that output does not have an output ' f"manager that knows how to load inputs, so we don't know how " f"to load the input. To address this, assign an IOManager to " f"the upstream output." ) else: if ( not handle.input_def.dagster_type.loader and not handle.input_def.dagster_type.kind == DagsterTypeKind.NOTHING and not handle.input_def.root_manager_key ): raise DagsterInvalidDefinitionError( 'Input "{input_name}" in solid "{solid_name}" is not connected to ' "the output of a previous solid and can not be loaded from configuration, " "creating an impossible to execute pipeline. " "Possible solutions are:\n" ' * add a dagster_type_loader for the type "{dagster_type}"\n' ' * connect "{input_name}" to the output of another solid\n'.format( solid_name=solid.name, input_name=handle.input_def.name, dagster_type=handle.input_def.dagster_type.display_name, ) ) def _build_all_node_defs(node_defs): all_defs = {} for current_level_node_def in node_defs: for node_def in current_level_node_def.iterate_node_defs(): if node_def.name in all_defs: if all_defs[node_def.name] != node_def: raise DagsterInvalidDefinitionError( 'Detected conflicting solid definitions with the same name "{name}"'.format( name=node_def.name ) ) else: all_defs[node_def.name] = node_def return all_defs def _create_run_config_schema(pipeline_def, mode_definition): from .environment_configs import ( EnvironmentClassCreationData, construct_config_type_dictionary, define_environment_cls, ) from .run_config_schema import RunConfigSchema # When executing with a subset pipeline, include the missing solids # from the original pipeline as ignored to allow execution with # run config that is valid for the original if pipeline_def.is_subset_pipeline: ignored_solids = [ solid for solid in pipeline_def.parent_pipeline_def.solids if not pipeline_def.has_solid_named(solid.name) ] else: ignored_solids = [] environment_type = define_environment_cls( EnvironmentClassCreationData( pipeline_name=pipeline_def.name, solids=pipeline_def.solids, dependency_structure=pipeline_def.dependency_structure, mode_definition=mode_definition, logger_defs=mode_definition.loggers, ignored_solids=ignored_solids, ) ) config_type_dict_by_name, config_type_dict_by_key = construct_config_type_dictionary( pipeline_def.all_solid_defs, environment_type ) return RunConfigSchema( environment_type=environment_type, config_type_dict_by_name=config_type_dict_by_name, config_type_dict_by_key=config_type_dict_by_key, ) diff --git 
a/python_modules/dagster/dagster/core/definitions/preset.py b/python_modules/dagster/dagster/core/definitions/preset.py index 2c2015659..59afaa050 100644 --- a/python_modules/dagster/dagster/core/definitions/preset.py +++ b/python_modules/dagster/dagster/core/definitions/preset.py @@ -1,207 +1,203 @@ from collections import namedtuple import pkg_resources -import six import yaml from dagster import check from dagster.core.definitions.utils import config_from_files, config_from_yaml_strings from dagster.core.errors import DagsterInvariantViolationError from dagster.utils.merger import deep_merge_dicts from .mode import DEFAULT_MODE_NAME from .utils import check_valid_name class PresetDefinition( namedtuple("_PresetDefinition", "name run_config solid_selection mode tags") ): """Defines a preset configuration in which a pipeline can execute. Presets can be used in Dagit to load predefined configurations into the tool. Presets may also be used from the Python API (in a script, or in test) as follows: .. code-block:: python execute_pipeline(pipeline_def, preset='example_preset') Presets may also be used with the command line tools: .. code-block:: shell $ dagster pipeline execute example_pipeline --preset example_preset Args: name (str): The name of this preset. Must be unique in the presets defined on a given pipeline. run_config (Optional[dict]): A dict representing the config to set with the preset. This is equivalent to the ``run_config`` argument to :py:func:`execute_pipeline`. solid_selection (Optional[List[str]]): A list of solid subselection (including single solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']`` mode (Optional[str]): The mode to apply when executing this preset. (default: 'default') tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset. """ def __new__( cls, name, run_config=None, solid_selection=None, mode=None, tags=None, ): return super(PresetDefinition, cls).__new__( cls, name=check_valid_name(name), run_config=run_config, solid_selection=check.opt_nullable_list_param( solid_selection, "solid_selection", of_type=str ), mode=check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME), tags=check.opt_dict_param(tags, "tags", key_type=str), ) @staticmethod def from_files(name, config_files=None, solid_selection=None, mode=None, tags=None): """Static constructor for presets from YAML files. Args: name (str): The name of this preset. Must be unique in the presets defined on a given pipeline. config_files (Optional[List[str]]): List of paths or glob patterns for yaml files to load and parse as the environment config for this preset. solid_selection (Optional[List[str]]): A list of solid subselection (including single solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']`` mode (Optional[str]): The mode to apply when executing this preset. (default: 'default') tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset. Returns: PresetDefinition: A PresetDefinition constructed from the provided YAML files. Raises: DagsterInvariantViolationError: When one of the YAML files is invalid and has a parse error. 
""" check.str_param(name, "name") config_files = check.opt_list_param(config_files, "config_files") solid_selection = check.opt_nullable_list_param( solid_selection, "solid_selection", of_type=str ) mode = check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME) merged = config_from_files(config_files) return PresetDefinition(name, merged, solid_selection, mode, tags) @staticmethod def from_yaml_strings(name, yaml_strings=None, solid_selection=None, mode=None, tags=None): """Static constructor for presets from YAML strings. Args: name (str): The name of this preset. Must be unique in the presets defined on a given pipeline. yaml_strings (Optional[List[str]]): List of yaml strings to parse as the environment config for this preset. solid_selection (Optional[List[str]]): A list of solid subselection (including single solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']`` mode (Optional[str]): The mode to apply when executing this preset. (default: 'default') tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset. Returns: PresetDefinition: A PresetDefinition constructed from the provided YAML strings Raises: DagsterInvariantViolationError: When one of the YAML documents is invalid and has a parse error. """ check.str_param(name, "name") yaml_strings = check.opt_list_param(yaml_strings, "yaml_strings", of_type=str) solid_selection = check.opt_nullable_list_param( solid_selection, "solid_selection", of_type=str ) mode = check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME) merged = config_from_yaml_strings(yaml_strings) return PresetDefinition(name, merged, solid_selection, mode, tags) @staticmethod def from_pkg_resources( name, pkg_resource_defs=None, solid_selection=None, mode=None, tags=None ): """Load a preset from a package resource, using :py:func:`pkg_resources.resource_string`. Example: .. code-block:: python PresetDefinition.from_pkg_resources( name='local', mode='local', pkg_resource_defs=[ ('dagster_examples.airline_demo.environments', 'local_base.yaml'), ('dagster_examples.airline_demo.environments', 'local_warehouse.yaml'), ], ) Args: name (str): The name of this preset. Must be unique in the presets defined on a given pipeline. pkg_resource_defs (Optional[List[(str, str)]]): List of pkg_resource modules/files to load as environment config for this preset. solid_selection (Optional[List[str]]): A list of solid subselection (including single solid names) to execute with this partition. e.g. ``['*some_solid+', 'other_solid']`` mode (Optional[str]): The mode to apply when executing this preset. (default: 'default') tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset. Returns: PresetDefinition: A PresetDefinition constructed from the provided YAML strings Raises: DagsterInvariantViolationError: When one of the YAML documents is invalid and has a parse error. """ pkg_resource_defs = check.opt_list_param( pkg_resource_defs, "pkg_resource_defs", of_type=tuple ) try: yaml_strings = [ - six.ensure_str(pkg_resources.resource_string(*pkg_resource_def)) + pkg_resources.resource_string(*pkg_resource_def).decode("utf-8") for pkg_resource_def in pkg_resource_defs ] except (ModuleNotFoundError, FileNotFoundError, UnicodeDecodeError) as err: - six.raise_from( - DagsterInvariantViolationError( - "Encountered error attempting to parse yaml. 
Loading YAMLs from " - "package resources {pkg_resource_defs} " - 'on preset "{name}".'.format(pkg_resource_defs=pkg_resource_defs, name=name) - ), - err, - ) + raise DagsterInvariantViolationError( + "Encountered error attempting to parse yaml. Loading YAMLs from " + f"package resources {pkg_resource_defs} " + f'on preset "{name}".' + ) from err return PresetDefinition.from_yaml_strings(name, yaml_strings, solid_selection, mode, tags) def get_environment_yaml(self): """Get the environment dict set on a preset as YAML. Returns: str: The environment dict as YAML. """ return yaml.dump(self.run_config or {}, default_flow_style=False) def with_additional_config(self, run_config): """Return a new PresetDefinition with additional config merged in to the existing config.""" check.opt_nullable_dict_param(run_config, "run_config") if run_config is None: return self else: return PresetDefinition( name=self.name, solid_selection=self.solid_selection, mode=self.mode, tags=self.tags, run_config=deep_merge_dicts(self.run_config, run_config), ) diff --git a/python_modules/dagster/dagster/core/definitions/utils.py b/python_modules/dagster/dagster/core/definitions/utils.py index 8ba66605d..e3839c71c 100644 --- a/python_modules/dagster/dagster/core/definitions/utils.py +++ b/python_modules/dagster/dagster/core/definitions/utils.py @@ -1,220 +1,208 @@ import keyword import os import re from glob import glob import pkg_resources -import six import yaml from dagster import check, seven from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError from dagster.utils import frozentags from dagster.utils.yaml_utils import merge_yaml_strings, merge_yamls DEFAULT_OUTPUT = "result" DISALLOWED_NAMES = set( [ "context", "conf", "config", "meta", "arg_dict", "dict", "input_arg_dict", "output_arg_dict", "int", "str", "float", "bool", "input", "output", "type", ] + list(keyword.kwlist) # just disallow all python keywords ) VALID_NAME_REGEX_STR = r"^[A-Za-z0-9_]+$" VALID_NAME_REGEX = re.compile(VALID_NAME_REGEX_STR) def has_valid_name_chars(name): return bool(VALID_NAME_REGEX.match(name)) def check_valid_name(name): check.str_param(name, "name") if name in DISALLOWED_NAMES: raise DagsterInvalidDefinitionError("{name} is not allowed.".format(name=name)) if not has_valid_name_chars(name): raise DagsterInvalidDefinitionError( "{name} must be in regex {regex}".format(name=name, regex=VALID_NAME_REGEX_STR) ) check.invariant(is_valid_name(name)) return name def is_valid_name(name): check.str_param(name, "name") return name not in DISALLOWED_NAMES and has_valid_name_chars(name) def _kv_str(key, value): return '{key}="{value}"'.format(key=key, value=repr(value)) def struct_to_string(name, **kwargs): # Sort the kwargs to ensure consistent representations across Python versions props_str = ", ".join([_kv_str(key, value) for key, value in sorted(kwargs.items())]) return "{name}({props_str})".format(name=name, props_str=props_str) def validate_tags(tags): valid_tags = {} for key, value in check.opt_dict_param(tags, "tags", key_type=str).items(): if not isinstance(value, str): valid = False err_reason = 'Could not JSON encode value "{}"'.format(value) try: str_val = seven.json.dumps(value) err_reason = 'JSON encoding "{json}" of value "{val}" is not equivalent to original value'.format( json=str_val, val=value ) valid = seven.json.loads(str_val) == value except Exception: # pylint: disable=broad-except pass if not valid: raise DagsterInvalidDefinitionError( 'Invalid value for tag "{key}", {err_reason}. 
Tag values must be strings ' "or meet the constraint that json.loads(json.dumps(value)) == value.".format( key=key, err_reason=err_reason ) ) valid_tags[key] = str_val else: valid_tags[key] = value return frozentags(valid_tags) def config_from_files(config_files): """Constructs run config from YAML files. Args: config_files (List[str]): List of paths or glob patterns for yaml files to load and parse as the run config. Returns: Dict[Str, Any]: A run config dictionary constructed from provided YAML files. Raises: FileNotFoundError: When a config file produces no results DagsterInvariantViolationError: When one of the YAML files is invalid and has a parse error. """ config_files = check.opt_list_param(config_files, "config_files") filenames = [] for file_glob in config_files or []: globbed_files = glob(file_glob) if not globbed_files: raise DagsterInvariantViolationError( 'File or glob pattern "{file_glob}" for "config_files"' "produced no results.".format(file_glob=file_glob) ) filenames += [os.path.realpath(globbed_file) for globbed_file in globbed_files] try: run_config = merge_yamls(filenames) except yaml.YAMLError as err: - six.raise_from( - DagsterInvariantViolationError( - "Encountered error attempting to parse yaml. Parsing files {file_set} " - "loaded by file/patterns {files}.".format(file_set=filenames, files=config_files) - ), - err, - ) + raise DagsterInvariantViolationError( + f"Encountered error attempting to parse yaml. Parsing files {filenames} " + f"loaded by file/patterns {config_files}." + ) from err return run_config def config_from_yaml_strings(yaml_strings): """Static constructor for run configs from YAML strings. Args: yaml_strings (List[str]): List of yaml strings to parse as the run config. Returns: Dict[Str, Any]: A run config dictionary constructed from the provided yaml strings Raises: DagsterInvariantViolationError: When one of the YAML documents is invalid and has a parse error. """ yaml_strings = check.opt_list_param(yaml_strings, "yaml_strings", of_type=str) try: run_config = merge_yaml_strings(yaml_strings) except yaml.YAMLError as err: - six.raise_from( - DagsterInvariantViolationError( - "Encountered error attempting to parse yaml. Parsing YAMLs {yaml_strings} ".format( - yaml_strings=yaml_strings - ) - ), - err, - ) + raise DagsterInvariantViolationError( + f"Encountered error attempting to parse yaml. Parsing YAMLs {yaml_strings} " + ) from err return run_config def config_from_pkg_resources(pkg_resource_defs): """Load a run config from a package resource, using :py:func:`pkg_resources.resource_string`. Example: .. code-block:: python config_from_pkg_resources( pkg_resource_defs=[ ('dagster_examples.airline_demo.environments', 'local_base.yaml'), ('dagster_examples.airline_demo.environments', 'local_warehouse.yaml'), ], ) Args: pkg_resource_defs (List[(str, str)]): List of pkg_resource modules/files to load as the run config. Returns: Dict[Str, Any]: A run config dictionary constructed from the provided yaml strings Raises: DagsterInvariantViolationError: When one of the YAML documents is invalid and has a parse error. 
""" pkg_resource_defs = check.opt_list_param(pkg_resource_defs, "pkg_resource_defs", of_type=tuple) try: yaml_strings = [ - six.ensure_str(pkg_resources.resource_string(*pkg_resource_def)) + pkg_resources.resource_string(*pkg_resource_def).decode("utf-8") for pkg_resource_def in pkg_resource_defs ] except (ModuleNotFoundError, FileNotFoundError, UnicodeDecodeError) as err: - six.raise_from( - DagsterInvariantViolationError( - "Encountered error attempting to parse yaml. Loading YAMLs from " - "package resources {pkg_resource_defs}.".format(pkg_resource_defs=pkg_resource_defs) - ), - err, - ) + raise DagsterInvariantViolationError( + "Encountered error attempting to parse yaml. Loading YAMLs from " + f"package resources {pkg_resource_defs}." + ) from err return config_from_yaml_strings(yaml_strings=yaml_strings) diff --git a/python_modules/dagster/dagster/core/errors.py b/python_modules/dagster/dagster/core/errors.py index 444464a8a..5b46aa3c5 100644 --- a/python_modules/dagster/dagster/core/errors.py +++ b/python_modules/dagster/dagster/core/errors.py @@ -1,504 +1,503 @@ """Core Dagster error classes. All errors thrown by the Dagster framework inherit from :py:class:`~dagster.DagsterError`. Users should not subclass this base class for their own exceptions. There is another exception base class, :py:class:`~dagster.DagsterUserCodeExecutionError`, which is used by the framework in concert with the :py:func:`~dagster.core.errors.user_code_error_boundary`. Dagster uses this construct to wrap user code into which it calls. User code can perform arbitrary computations and may itself throw exceptions. The error boundary catches these user code-generated exceptions, and then reraises them wrapped in a subclass of :py:class:`~dagster.DagsterUserCodeExecutionError`. The wrapped exceptions include additional context for the original exceptions, injected by the Dagster runtime. """ import sys import traceback from contextlib import contextmanager from dagster import check from dagster.utils.interrupts import raise_interrupts_as -from future.utils import raise_from class DagsterError(Exception): """Base class for all errors thrown by the Dagster framework. Users should not subclass this base class for their own exceptions.""" @property def is_user_code_error(self): """Returns true if this error is attributable to user code.""" return False class DagsterInvalidDefinitionError(DagsterError): """Indicates that the rules for a definition have been violated by the user.""" class DagsterInvalidSubsetError(DagsterError): """Indicates that a subset of a pipeline is invalid because either: - One or more solids in the specified subset do not exist on the pipeline.' - The subset produces an invalid pipeline. """ CONFIG_ERROR_VERBIAGE = """ This value can be a: - Field - Python primitive types that resolve to dagster config types - int, float, bool, str, list. - A dagster config type: Int, Float, Bool, List, Optional, Selector, Shape, Permissive - A bare python dictionary, which is wrapped in Field(Shape(...)). Any values in the dictionary get resolved by the same rules, recursively. - A python list with a single entry that can resolve to a type, e.g. [int] """ class DagsterInvalidConfigDefinitionError(DagsterError): """Indicates that you have attempted to construct a config with an invalid value Acceptable values for config types are any of: 1. 
A Python primitive type that resolves to a Dagster config type (:py:class:`~python:int`, :py:class:`~python:float`, :py:class:`~python:bool`, :py:class:`~python:str`, or :py:class:`~python:list`). 2. A Dagster config type: :py:data:`~dagster.Int`, :py:data:`~dagster.Float`, :py:data:`~dagster.Bool`, :py:data:`~dagster.String`, :py:data:`~dagster.StringSource`, :py:data:`~dagster.Any`, :py:class:`~dagster.Array`, :py:data:`~dagster.Noneable`, :py:data:`~dagster.Enum`, :py:class:`~dagster.Selector`, :py:class:`~dagster.Shape`, or :py:class:`~dagster.Permissive`. 3. A bare python dictionary, which will be automatically wrapped in :py:class:`~dagster.Shape`. Values of the dictionary are resolved recursively according to the same rules. 4. A bare python list of length one whose single element can itself resolve to a config type. Becomes :py:class:`Array` with the list element as an argument. 5. An instance of :py:class:`~dagster.Field`. """ def __init__(self, original_root, current_value, stack, reason=None, **kwargs): self.original_root = original_root self.current_value = current_value self.stack = stack super(DagsterInvalidConfigDefinitionError, self).__init__( ( "Error defining config. Original value passed: {original_root}. " "{stack_str}{current_value} " "cannot be resolved.{reason_str}" + CONFIG_ERROR_VERBIAGE ).format( original_root=repr(original_root), stack_str="Error at stack path :" + ":".join(stack) + ". " if stack else "", current_value=repr(current_value), reason_str=" Reason: {reason}.".format(reason=reason) if reason else "", ), **kwargs, ) class DagsterInvariantViolationError(DagsterError): """Indicates the user has violated a well-defined invariant that can only be enforced at runtime.""" class DagsterExecutionStepNotFoundError(DagsterError): """Thrown when the user specifies execution step keys that do not exist.""" def __init__(self, *args, **kwargs): self.step_keys = check.list_param(kwargs.pop("step_keys"), "step_keys", str) super(DagsterExecutionStepNotFoundError, self).__init__(*args, **kwargs) class DagsterRunNotFoundError(DagsterError): """Thrown when a run cannot be found in run storage.""" def __init__(self, *args, **kwargs): self.invalid_run_id = check.str_param(kwargs.pop("invalid_run_id"), "invalid_run_id") super(DagsterRunNotFoundError, self).__init__(*args, **kwargs) class DagsterStepOutputNotFoundError(DagsterError): """Indicates that previous step outputs required for an execution step to proceed are not available.""" def __init__(self, *args, **kwargs): self.step_key = check.str_param(kwargs.pop("step_key"), "step_key") self.output_name = check.str_param(kwargs.pop("output_name"), "output_name") super(DagsterStepOutputNotFoundError, self).__init__(*args, **kwargs) @contextmanager def raise_execution_interrupts(): with raise_interrupts_as(DagsterExecutionInterruptedError): yield @contextmanager def user_code_error_boundary(error_cls, msg_fn, control_flow_exceptions=None, **kwargs): """ Wraps the execution of user-space code in an error boundary. This places a uniform policy around any user code invoked by the framework. This ensures that all user errors are wrapped in an exception derived from DagsterUserCodeExecutionError, and that the original stack trace of the user error is preserved, so that it can be reported without confusing framework code in the stack trace, if a tool author wishes to do so. Examples: ..
code-block:: python with user_code_error_boundary( # Pass a class that inherits from DagsterUserCodeExecutionError DagsterExecutionStepExecutionError, # Pass a function that produces a message "Error occurred during step execution" ): call_user_provided_function() """ check.callable_param(msg_fn, "msg_fn") check.subclass_param(error_cls, "error_cls", DagsterUserCodeExecutionError) control_flow_exceptions = tuple( check.opt_list_param(control_flow_exceptions, "control_flow_exceptions") ) with raise_execution_interrupts(): try: yield except control_flow_exceptions as cf: # A control flow exception has occurred and should be propagated raise cf except DagsterError as de: # The system has thrown an error that is part of the user-framework contract raise de except Exception as e: # pylint: disable=W0703 # An exception has been thrown by user code and computation should cease # with the error reported further up the stack - raise_from( - error_cls(msg_fn(), user_exception=e, original_exc_info=sys.exc_info(), **kwargs), e - ) + raise error_cls( + msg_fn(), user_exception=e, original_exc_info=sys.exc_info(), **kwargs + ) from e class DagsterUserCodeExecutionError(DagsterError): """ This is the base class for any exception that is meant to wrap an :py:class:`~python:Exception` thrown by user code. It wraps that existing user exception. The ``original_exc_info`` argument to the constructor is meant to be a tuple of the type returned by :py:func:`sys.exc_info` at the call site of the constructor. Users should not subclass this base class for their own exceptions and should instead throw freely from user code. User exceptions will be automatically wrapped and rethrown. """ def __init__(self, *args, **kwargs): # original_exc_info should be gotten from a sys.exc_info() call at the # callsite inside of the exception handler. This will allow consuming # code to *re-raise* the user error in its original format # for cleaner error reporting that does not have framework code in it user_exception = check.inst_param(kwargs.pop("user_exception"), "user_exception", Exception) original_exc_info = check.tuple_param(kwargs.pop("original_exc_info"), "original_exc_info") check.invariant(original_exc_info[0] is not None) super(DagsterUserCodeExecutionError, self).__init__(args[0], *args[1:], **kwargs) self.user_exception = check.opt_inst_param(user_exception, "user_exception", Exception) self.original_exc_info = original_exc_info @property def is_user_code_error(self): return True class DagsterTypeCheckError(DagsterUserCodeExecutionError): """Indicates an error in the solid type system at runtime. E.g. a solid receives an unexpected input, or produces an output that does not match the type of the output definition.
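Like every :py:class:`DagsterUserCodeExecutionError` subclass, instances are raised from inside :py:func:`user_code_error_boundary` with ``raise ... from``, so the original user exception stays chained for reporting. A minimal sketch of that chaining, using the boundary directly (the step and solid names passed through are illustrative only):

    .. code-block:: python

        import traceback

        from dagster.core.errors import (
            DagsterExecutionStepExecutionError,
            user_code_error_boundary,
        )

        def flaky_user_function():
            raise ValueError("bad value from user code")

        try:
            with user_code_error_boundary(
                DagsterExecutionStepExecutionError,
                lambda: "Error occurred during step execution",
                step_key="my_step",
                solid_name="my_solid",
                solid_def_name="my_solid",
            ):
                flaky_user_function()
        except DagsterExecutionStepExecutionError as err:
            # the wrapped error keeps both the original exception and its traceback
            assert err.user_exception is err.__cause__
            print("".join(traceback.format_exception(*err.original_exc_info)))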
""" class DagsterExecutionLoadInputError(DagsterUserCodeExecutionError): """Indicates an error occurred while loading an input for a step.""" def __init__(self, *args, **kwargs): self.step_key = check.str_param(kwargs.pop("step_key"), "step_key") self.input_name = check.str_param(kwargs.pop("input_name"), "input_name") super(DagsterExecutionLoadInputError, self).__init__(*args, **kwargs) class DagsterExecutionHandleOutputError(DagsterUserCodeExecutionError): """Indicates an error occurred while loading an input for a step.""" def __init__(self, *args, **kwargs): self.step_key = check.str_param(kwargs.pop("step_key"), "step_key") self.output_name = check.str_param(kwargs.pop("output_name"), "output_name") super(DagsterExecutionHandleOutputError, self).__init__(*args, **kwargs) class DagsterExecutionStepExecutionError(DagsterUserCodeExecutionError): """Indicates an error occurred while executing the body of an execution step.""" def __init__(self, *args, **kwargs): self.step_key = check.str_param(kwargs.pop("step_key"), "step_key") self.solid_name = check.str_param(kwargs.pop("solid_name"), "solid_name") self.solid_def_name = check.str_param(kwargs.pop("solid_def_name"), "solid_def_name") super(DagsterExecutionStepExecutionError, self).__init__(*args, **kwargs) class DagsterResourceFunctionError(DagsterUserCodeExecutionError): """ Indicates an error occurred while executing the body of the ``resource_fn`` in a :py:class:`~dagster.ResourceDefinition` during resource initialization. """ class DagsterConfigMappingFunctionError(DagsterUserCodeExecutionError): """ Indicates that an unexpected error occurred while executing the body of a config mapping function defined in a :py:class:`~dagster.CompositeSolidDefinition` during config parsing. """ class DagsterTypeLoadingError(DagsterUserCodeExecutionError): """ Indicates that an unexpected error occurred while executing the body of an type load function defined in a :py:class:`~dagster.DagsterTypeLoader` during loading of a custom type. """ class DagsterTypeMaterializationError(DagsterUserCodeExecutionError): """ Indicates that an unexpected error occurred while executing the body of an output materialization function defined in a :py:class:`~dagster.DagsterTypeMaterializer` during materialization of a custom type. """ class DagsterUnknownResourceError(DagsterError, AttributeError): # inherits from AttributeError as it is raised within a __getattr__ call... used to support # object hasattr method """ Indicates that an unknown resource was accessed in the body of an execution step. May often happen by accessing a resource in the compute function of a solid without first supplying the solid with the correct `required_resource_keys` argument. """ def __init__(self, resource_name, *args, **kwargs): self.resource_name = check.str_param(resource_name, "resource_name") msg = ( "Unknown resource `{resource_name}`. Specify `{resource_name}` as a required resource " "on the compute / config function that accessed it." 
).format(resource_name=resource_name) super(DagsterUnknownResourceError, self).__init__(msg, *args, **kwargs) class DagsterInvalidConfigError(DagsterError): """Thrown when provided config is invalid (does not type check against the relevant config schema).""" def __init__(self, preamble, errors, config_value, *args, **kwargs): from dagster.config.errors import EvaluationError check.str_param(preamble, "preamble") self.errors = check.list_param(errors, "errors", of_type=EvaluationError) self.config_value = config_value error_msg = preamble error_messages = [] for i_error, error in enumerate(self.errors): error_messages.append(error.message) error_msg += "\n Error {i_error}: {error_message}".format( i_error=i_error + 1, error_message=error.message ) self.message = error_msg self.error_messages = error_messages super(DagsterInvalidConfigError, self).__init__(error_msg, *args, **kwargs) class DagsterUnmetExecutorRequirementsError(DagsterError): """Indicates the resolved executor is incompatible with the state of other systems such as the :py:class:`~dagster.core.instance.DagsterInstance` or system storage configuration. """ class DagsterSubprocessError(DagsterError): """An exception has occurred in one or more of the child processes dagster manages. This error forwards the message and stack trace for all of the collected errors. """ def __init__(self, *args, **kwargs): from dagster.utils.error import SerializableErrorInfo self.subprocess_error_infos = check.list_param( kwargs.pop("subprocess_error_infos"), "subprocess_error_infos", SerializableErrorInfo ) super(DagsterSubprocessError, self).__init__(*args, **kwargs) class DagsterUserCodeProcessError(DagsterError): """An exception has occurred in a user code process that the host process raising this error was communicating with.""" def __init__(self, *args, **kwargs): from dagster.utils.error import SerializableErrorInfo self.user_code_process_error_infos = check.list_param( kwargs.pop("user_code_process_error_infos"), "user_code_process_error_infos", SerializableErrorInfo, ) super(DagsterUserCodeProcessError, self).__init__(*args, **kwargs) class DagsterLaunchFailedError(DagsterError): """Indicates an error while attempting to launch a pipeline run. """ def __init__(self, *args, **kwargs): from dagster.utils.error import SerializableErrorInfo self.serializable_error_info = check.opt_inst_param( kwargs.pop("serializable_error_info", None), "serializable_error_info", SerializableErrorInfo, ) super(DagsterLaunchFailedError, self).__init__(*args, **kwargs) class DagsterBackfillFailedError(DagsterError): """Indicates an error while attempting to launch a backfill. """ def __init__(self, *args, **kwargs): from dagster.utils.error import SerializableErrorInfo self.serializable_error_info = check.opt_inst_param( kwargs.pop("serializable_error_info", None), "serializable_error_info", SerializableErrorInfo, ) super(DagsterBackfillFailedError, self).__init__(*args, **kwargs) class DagsterScheduleWipeRequired(DagsterError): """Indicates that the user must wipe their stored schedule state.""" class DagsterInstanceMigrationRequired(DagsterError): """Indicates that the dagster instance must be migrated.""" def __init__(self, msg=None, db_revision=None, head_revision=None, original_exc_info=None): super(DagsterInstanceMigrationRequired, self).__init__( "Instance is out of date and must be migrated{additional_msg}." 
"{revision_clause} Please run `dagster instance migrate`.{original_exception_clause}".format( additional_msg=" ({msg})".format(msg=msg) if msg else "", revision_clause=( " Database is at revision {db_revision}, head is " "{head_revision}.".format(db_revision=db_revision, head_revision=head_revision) if db_revision or head_revision else "" ), original_exception_clause=( "\n\nOriginal exception:\n\n{original_exception}".format( original_exception="".join(traceback.format_exception(*original_exc_info)) ) if original_exc_info else "" ), ) ) class DagsterRunAlreadyExists(DagsterError): """Indicates that a pipeline run already exists in a run storage.""" class DagsterSnapshotDoesNotExist(DagsterError): """Indicates you attempted to create a pipeline run with a nonexistent snapshot id""" class DagsterRunConflict(DagsterError): """Indicates that a conflicting pipeline run exists in a run storage.""" class DagsterTypeCheckDidNotPass(DagsterError): """Indicates that a type check failed. This is raised when ``raise_on_error`` is ``True`` in calls to the synchronous pipeline and solid execution APIs (:py:func:`~dagster.execute_pipeline`, :py:func:`~dagster.execute_solid`, etc.), that is, typically in test, and a :py:class:`~dagster.DagsterType`'s type check fails by returning either ``False`` or an instance of :py:class:`~dagster.TypeCheck` whose ``success`` member is ``False``. """ def __init__(self, description=None, metadata_entries=None, dagster_type=None): from dagster import EventMetadataEntry, DagsterType super(DagsterTypeCheckDidNotPass, self).__init__(description) self.description = check.opt_str_param(description, "description") self.metadata_entries = check.opt_list_param( metadata_entries, "metadata_entries", of_type=EventMetadataEntry ) self.dagster_type = check.opt_inst_param(dagster_type, "dagster_type", DagsterType) class DagsterEventLogInvalidForRun(DagsterError): """Raised when the event logs for a historical run are malformed or invalid.""" def __init__(self, run_id): self.run_id = check.str_param(run_id, "run_id") super(DagsterEventLogInvalidForRun, self).__init__( "Event logs invalid for run id {}".format(run_id) ) class ScheduleExecutionError(DagsterUserCodeExecutionError): """Errors raised in a user process during the execution of schedule.""" class SensorExecutionError(DagsterUserCodeExecutionError): """Errors raised in a user process during the execution of a sensor (or its job).""" class PartitionExecutionError(DagsterUserCodeExecutionError): """Errors raised during the execution of user-provided functions of a partition set schedule.""" class DagsterInvalidAssetKey(DagsterError): """ Error raised by invalid asset key """ class HookExecutionError(DagsterUserCodeExecutionError): """ Error raised during the execution of a user-defined hook. """ class DagsterImportError(DagsterError): """ Import error raised while importing user-code. """ class JobError(DagsterUserCodeExecutionError): """Errors raised during the execution of user-provided functions for a defined Job.""" class DagsterUnknownStepStateError(DagsterError): """When pipeline execution complete with steps in an unknown state""" class DagsterExecutionInterruptedError(DagsterError): """Pipeline execution was interrupted during the execution process.""" class DagsterObjectStoreError(DagsterError): """Errors during an object store operation.""" class DagsterInvalidPropertyError(DagsterError): """Indicates that an invalid property was accessed. 
May often happen by accessing a property that no longer exists after breaking changes.""" diff --git a/python_modules/dagster/dagster/core/execution/stats.py b/python_modules/dagster/dagster/core/execution/stats.py index 03de46ac7..a9e38eb03 100644 --- a/python_modules/dagster/dagster/core/execution/stats.py +++ b/python_modules/dagster/dagster/core/execution/stats.py @@ -1,151 +1,147 @@ from collections import defaultdict, namedtuple from enum import Enum -import six from dagster import check from dagster.core.definitions import AssetMaterialization, ExpectationResult, Materialization from dagster.core.events import DagsterEventType, StepExpectationResultData, StepMaterializationData from dagster.core.events.log import EventRecord from dagster.core.storage.pipeline_run import PipelineRunStatsSnapshot from dagster.serdes import whitelist_for_serdes from dagster.utils import datetime_as_float def build_run_stats_from_events(run_id, records): try: iter(records) except TypeError as exc: - six.raise_from( - check.ParameterCheckError( - "Invariant violation for parameter 'records'. Description: Expected iterable." - ), - from_value=exc, - ) + raise check.ParameterCheckError( + "Invariant violation for parameter 'records'. Description: Expected iterable." + ) from exc for i, record in enumerate(records): - check.inst_param(record, "records[{i}]".format(i=i), EventRecord) + check.inst_param(record, f"records[{i}]", EventRecord) steps_succeeded = 0 steps_failed = 0 materializations = 0 expectations = 0 start_time = None end_time = None for event in records: if not event.is_dagster_event: continue if event.dagster_event.event_type == DagsterEventType.PIPELINE_START: start_time = ( event.timestamp if isinstance(event.timestamp, float) else datetime_as_float(event.timestamp) ) if event.dagster_event.event_type == DagsterEventType.STEP_FAILURE: steps_failed += 1 if event.dagster_event.event_type == DagsterEventType.STEP_SUCCESS: steps_succeeded += 1 if event.dagster_event.event_type == DagsterEventType.STEP_MATERIALIZATION: materializations += 1 if event.dagster_event.event_type == DagsterEventType.STEP_EXPECTATION_RESULT: expectations += 1 if ( event.dagster_event.event_type == DagsterEventType.PIPELINE_SUCCESS or event.dagster_event.event_type == DagsterEventType.PIPELINE_FAILURE or event.dagster_event.event_type == DagsterEventType.PIPELINE_CANCELED ): end_time = ( event.timestamp if isinstance(event.timestamp, float) else datetime_as_float(event.timestamp) ) return PipelineRunStatsSnapshot( run_id, steps_succeeded, steps_failed, materializations, expectations, start_time, end_time ) class StepEventStatus(Enum): SKIPPED = "SKIPPED" SUCCESS = "SUCCESS" FAILURE = "FAILURE" def build_run_step_stats_from_events(run_id, records): by_step_key = defaultdict(dict) for event in records: if not event.is_dagster_event: continue step_key = event.dagster_event.step_key if not step_key: continue if event.dagster_event.event_type == DagsterEventType.STEP_START: by_step_key[step_key]["start_time"] = event.timestamp by_step_key[step_key]["attempts"] = 1 if event.dagster_event.event_type == DagsterEventType.STEP_FAILURE: by_step_key[step_key]["end_time"] = event.timestamp by_step_key[step_key]["status"] = StepEventStatus.FAILURE if event.dagster_event.event_type == DagsterEventType.STEP_RESTARTED: by_step_key[step_key]["attempts"] = by_step_key[step_key].get("attempts") + 1 if event.dagster_event.event_type == DagsterEventType.STEP_SUCCESS: by_step_key[step_key]["end_time"] = event.timestamp 
by_step_key[step_key]["status"] = StepEventStatus.SUCCESS if event.dagster_event.event_type == DagsterEventType.STEP_SKIPPED: by_step_key[step_key]["end_time"] = event.timestamp by_step_key[step_key]["status"] = StepEventStatus.SKIPPED if event.dagster_event.event_type == DagsterEventType.STEP_MATERIALIZATION: check.inst(event.dagster_event.event_specific_data, StepMaterializationData) materialization = event.dagster_event.event_specific_data.materialization step_materializations = by_step_key[step_key].get("materializations", []) step_materializations.append(materialization) by_step_key[step_key]["materializations"] = step_materializations if event.dagster_event.event_type == DagsterEventType.STEP_EXPECTATION_RESULT: check.inst(event.dagster_event.event_specific_data, StepExpectationResultData) expectation_result = event.dagster_event.event_specific_data.expectation_result step_expectation_results = by_step_key[step_key].get("expectation_results", []) step_expectation_results.append(expectation_result) by_step_key[step_key]["expectation_results"] = step_expectation_results return [ RunStepKeyStatsSnapshot(run_id=run_id, step_key=step_key, **value) for step_key, value in by_step_key.items() ] @whitelist_for_serdes class RunStepKeyStatsSnapshot( namedtuple( "_RunStepKeyStatsSnapshot", ( "run_id step_key status start_time end_time materializations expectation_results attempts" ), ) ): def __new__( cls, run_id, step_key, status=None, start_time=None, end_time=None, materializations=None, expectation_results=None, attempts=None, ): return super(RunStepKeyStatsSnapshot, cls).__new__( cls, run_id=check.str_param(run_id, "run_id"), step_key=check.str_param(step_key, "step_key"), status=check.opt_inst_param(status, "status", StepEventStatus), start_time=check.opt_float_param(start_time, "start_time"), end_time=check.opt_float_param(end_time, "end_time"), materializations=check.opt_list_param( materializations, "materializations", (AssetMaterialization, Materialization) ), expectation_results=check.opt_list_param( expectation_results, "expectation_results", ExpectationResult ), attempts=check.opt_int_param(attempts, "attempts"), ) diff --git a/python_modules/dagster/dagster/core/storage/event_log/sql_event_log.py b/python_modules/dagster/dagster/core/storage/event_log/sql_event_log.py index a8de85a49..e7e546e52 100644 --- a/python_modules/dagster/dagster/core/storage/event_log/sql_event_log.py +++ b/python_modules/dagster/dagster/core/storage/event_log/sql_event_log.py @@ -1,719 +1,718 @@ import logging from abc import abstractmethod from collections import defaultdict from datetime import datetime -import six import sqlalchemy as db from dagster import check, seven from dagster.core.definitions.events import AssetKey, Materialization from dagster.core.errors import DagsterEventLogInvalidForRun from dagster.core.events import DagsterEventType from dagster.core.events.log import EventRecord from dagster.core.execution.stats import RunStepKeyStatsSnapshot, StepEventStatus from dagster.serdes import deserialize_json_to_dagster_namedtuple, serialize_dagster_namedtuple from dagster.utils import datetime_as_float, utc_datetime_from_timestamp from ..pipeline_run import PipelineRunStatsSnapshot from .base import AssetAwareEventLogStorage, EventLogStorage from .migration import REINDEX_DATA_MIGRATIONS, SECONDARY_INDEX_ASSET_KEY from .schema import AssetKeyTable, SecondaryIndexMigrationTable, SqlEventLogStorageTable class SqlEventLogStorage(EventLogStorage): """Base class for SQL backed event log storages. 
""" @abstractmethod def connect(self, run_id=None): """Context manager yielding a connection. Args: run_id (Optional[str]): Enables those storages which shard based on run_id, e.g., SqliteEventLogStorage, to connect appropriately. """ @abstractmethod def upgrade(self): """This method should perform any schema migrations necessary to bring an out-of-date instance of the storage up to date. """ def reindex(self, print_fn=lambda _: None, force=False): """Call this method to run any data migrations, reindexing to build summary tables.""" for migration_name, migration_fn in REINDEX_DATA_MIGRATIONS.items(): if self.has_secondary_index(migration_name): if not force: print_fn("Skipping already reindexed summary: {}".format(migration_name)) continue print_fn("Starting reindex: {}".format(migration_name)) migration_fn()(self, print_fn) self.enable_secondary_index(migration_name) print_fn("Finished reindexing: {}".format(migration_name)) def prepare_insert_event(self, event): """ Helper method for preparing the event log SQL insertion statement. Abstracted away to have a single place for the logical table representation of the event, while having a way for SQL backends to implement different execution implementations for `store_event`. See the `dagster-postgres` implementation which overrides the generic SQL implementation of `store_event`. """ dagster_event_type = None asset_key_str = None partition = None step_key = event.step_key if event.is_dagster_event: dagster_event_type = event.dagster_event.event_type_value step_key = event.dagster_event.step_key if event.dagster_event.asset_key: check.inst_param(event.dagster_event.asset_key, "asset_key", AssetKey) asset_key_str = event.dagster_event.asset_key.to_string() if event.dagster_event.partition: partition = event.dagster_event.partition # https://stackoverflow.com/a/54386260/324449 return SqlEventLogStorageTable.insert().values( # pylint: disable=no-value-for-parameter run_id=event.run_id, event=serialize_dagster_namedtuple(event), dagster_event_type=dagster_event_type, timestamp=utc_datetime_from_timestamp(event.timestamp), step_key=step_key, asset_key=asset_key_str, partition=partition, ) def store_asset_key(self, conn, event): check.inst_param(event, "event", EventRecord) if not event.is_dagster_event or not event.dagster_event.asset_key: return try: conn.execute( AssetKeyTable.insert().values( # pylint: disable=no-value-for-parameter asset_key=event.dagster_event.asset_key.to_string() ) ) except db.exc.IntegrityError: pass def store_event(self, event): """Store an event corresponding to a pipeline run. Args: event (EventRecord): The event to store. 
""" check.inst_param(event, "event", EventRecord) insert_event_statement = self.prepare_insert_event(event) run_id = event.run_id with self.connect(run_id) as conn: conn.execute(insert_event_statement) if event.is_dagster_event and event.dagster_event.asset_key: self.store_asset_key(conn, event) def get_logs_for_run_by_log_id(self, run_id, cursor=-1): check.str_param(run_id, "run_id") check.int_param(cursor, "cursor") check.invariant( cursor >= -1, "Don't know what to do with negative cursor {cursor}".format(cursor=cursor), ) # cursor starts at 0 & auto-increment column starts at 1 so adjust cursor = cursor + 1 query = ( db.select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event]) .where(SqlEventLogStorageTable.c.run_id == run_id) .where(SqlEventLogStorageTable.c.id > cursor) .order_by(SqlEventLogStorageTable.c.id.asc()) ) with self.connect(run_id) as conn: results = conn.execute(query).fetchall() events = {} try: for (record_id, json_str,) in results: events[record_id] = check.inst_param( deserialize_json_to_dagster_namedtuple(json_str), "event", EventRecord ) except (seven.JSONDecodeError, check.CheckError) as err: - six.raise_from(DagsterEventLogInvalidForRun(run_id=run_id), err) + raise DagsterEventLogInvalidForRun(run_id=run_id) from err return events def get_logs_for_run(self, run_id, cursor=-1): """Get all of the logs corresponding to a run. Args: run_id (str): The id of the run for which to fetch logs. cursor (Optional[int]): Zero-indexed logs will be returned starting from cursor + 1, i.e., if cursor is -1, all logs will be returned. (default: -1) """ check.str_param(run_id, "run_id") check.int_param(cursor, "cursor") check.invariant( cursor >= -1, "Don't know what to do with negative cursor {cursor}".format(cursor=cursor), ) events_by_id = self.get_logs_for_run_by_log_id(run_id, cursor) return [event for id, event in sorted(events_by_id.items(), key=lambda x: x[0])] def get_stats_for_run(self, run_id): check.str_param(run_id, "run_id") query = ( db.select( [ SqlEventLogStorageTable.c.dagster_event_type, db.func.count().label("n_events_of_type"), db.func.max(SqlEventLogStorageTable.c.timestamp).label("last_event_timestamp"), ] ) .where(SqlEventLogStorageTable.c.run_id == run_id) .group_by("dagster_event_type") ) with self.connect(run_id) as conn: results = conn.execute(query).fetchall() try: counts = {} times = {} for result in results: (dagster_event_type, n_events_of_type, last_event_timestamp) = result if dagster_event_type: counts[dagster_event_type] = n_events_of_type times[dagster_event_type] = last_event_timestamp start_time = times.get(DagsterEventType.PIPELINE_START.value, None) end_time = times.get( DagsterEventType.PIPELINE_SUCCESS.value, times.get( DagsterEventType.PIPELINE_FAILURE.value, times.get(DagsterEventType.PIPELINE_CANCELED.value, None), ), ) return PipelineRunStatsSnapshot( run_id=run_id, steps_succeeded=counts.get(DagsterEventType.STEP_SUCCESS.value, 0), steps_failed=counts.get(DagsterEventType.STEP_FAILURE.value, 0), materializations=counts.get(DagsterEventType.STEP_MATERIALIZATION.value, 0), expectations=counts.get(DagsterEventType.STEP_EXPECTATION_RESULT.value, 0), start_time=datetime_as_float(start_time) if start_time else None, end_time=datetime_as_float(end_time) if end_time else None, ) except (seven.JSONDecodeError, check.CheckError) as err: - six.raise_from(DagsterEventLogInvalidForRun(run_id=run_id), err) + raise DagsterEventLogInvalidForRun(run_id=run_id) from err def get_step_stats_for_run(self, run_id, step_keys=None): 
check.str_param(run_id, "run_id") check.opt_list_param(step_keys, "step_keys", of_type=str) STEP_STATS_EVENT_TYPES = [ DagsterEventType.STEP_START.value, DagsterEventType.STEP_SUCCESS.value, DagsterEventType.STEP_SKIPPED.value, DagsterEventType.STEP_FAILURE.value, DagsterEventType.STEP_RESTARTED.value, ] by_step_query = ( db.select( [ SqlEventLogStorageTable.c.step_key, SqlEventLogStorageTable.c.dagster_event_type, db.func.max(SqlEventLogStorageTable.c.timestamp).label("timestamp"), db.func.count(SqlEventLogStorageTable.c.id).label("count"), ] ) .where(SqlEventLogStorageTable.c.run_id == run_id) .where(SqlEventLogStorageTable.c.step_key != None) .where(SqlEventLogStorageTable.c.dagster_event_type.in_(STEP_STATS_EVENT_TYPES)) ) if step_keys: by_step_query = by_step_query.where(SqlEventLogStorageTable.c.step_key.in_(step_keys)) by_step_query = by_step_query.group_by( SqlEventLogStorageTable.c.step_key, SqlEventLogStorageTable.c.dagster_event_type, ) with self.connect(run_id) as conn: results = conn.execute(by_step_query).fetchall() by_step_key = defaultdict(dict) for result in results: step_key = result.step_key if result.dagster_event_type == DagsterEventType.STEP_START.value: by_step_key[step_key]["start_time"] = ( datetime_as_float(result.timestamp) if result.timestamp else None ) by_step_key[step_key]["attempts"] = by_step_key[step_key].get("attempts", 0) + 1 if result.dagster_event_type == DagsterEventType.STEP_RESTARTED.value: by_step_key[step_key]["attempts"] = ( # In case we see step retarted events but not a step started event, we want to # only count the restarted events, since the attempt count represents # the number of times we have successfully started runnning the step by_step_key[step_key].get("attempts", 0) + result.count ) if result.dagster_event_type == DagsterEventType.STEP_FAILURE.value: by_step_key[step_key]["end_time"] = ( datetime_as_float(result.timestamp) if result.timestamp else None ) by_step_key[step_key]["status"] = StepEventStatus.FAILURE if result.dagster_event_type == DagsterEventType.STEP_SUCCESS.value: by_step_key[step_key]["end_time"] = ( datetime_as_float(result.timestamp) if result.timestamp else None ) by_step_key[step_key]["status"] = StepEventStatus.SUCCESS if result.dagster_event_type == DagsterEventType.STEP_SKIPPED.value: by_step_key[step_key]["end_time"] = ( datetime_as_float(result.timestamp) if result.timestamp else None ) by_step_key[step_key]["status"] = StepEventStatus.SKIPPED materializations = defaultdict(list) expectation_results = defaultdict(list) raw_event_query = ( db.select([SqlEventLogStorageTable.c.event]) .where(SqlEventLogStorageTable.c.run_id == run_id) .where(SqlEventLogStorageTable.c.step_key != None) .where( SqlEventLogStorageTable.c.dagster_event_type.in_( [ DagsterEventType.STEP_MATERIALIZATION.value, DagsterEventType.STEP_EXPECTATION_RESULT.value, ] ) ) .order_by(SqlEventLogStorageTable.c.id.asc()) ) if step_keys: raw_event_query = raw_event_query.where( SqlEventLogStorageTable.c.step_key.in_(step_keys) ) with self.connect(run_id) as conn: results = conn.execute(raw_event_query).fetchall() try: for (json_str,) in results: event = check.inst_param( deserialize_json_to_dagster_namedtuple(json_str), "event", EventRecord ) if event.dagster_event.event_type == DagsterEventType.STEP_MATERIALIZATION: materializations[event.step_key].append( event.dagster_event.event_specific_data.materialization ) elif event.dagster_event.event_type == DagsterEventType.STEP_EXPECTATION_RESULT: expectation_results[event.step_key].append( 
event.dagster_event.event_specific_data.expectation_result ) except (seven.JSONDecodeError, check.CheckError) as err: - six.raise_from(DagsterEventLogInvalidForRun(run_id=run_id), err) + raise DagsterEventLogInvalidForRun(run_id=run_id) from err return [ RunStepKeyStatsSnapshot( run_id=run_id, step_key=step_key, status=value.get("status"), start_time=value.get("start_time"), end_time=value.get("end_time"), materializations=materializations.get(step_key), expectation_results=expectation_results.get(step_key), attempts=value.get("attempts"), ) for step_key, value in by_step_key.items() ] def wipe(self): """Clears the event log storage.""" # Should be overridden by SqliteEventLogStorage and other storages that shard based on # run_id # https://stackoverflow.com/a/54386260/324449 with self.connect() as conn: conn.execute(SqlEventLogStorageTable.delete()) # pylint: disable=no-value-for-parameter conn.execute(AssetKeyTable.delete()) # pylint: disable=no-value-for-parameter def delete_events(self, run_id): check.str_param(run_id, "run_id") delete_statement = SqlEventLogStorageTable.delete().where( # pylint: disable=no-value-for-parameter SqlEventLogStorageTable.c.run_id == run_id ) removed_asset_key_query = ( db.select([SqlEventLogStorageTable.c.asset_key]) .where(SqlEventLogStorageTable.c.run_id == run_id) .where(SqlEventLogStorageTable.c.asset_key != None) .group_by(SqlEventLogStorageTable.c.asset_key) ) with self.connect(run_id) as conn: removed_asset_keys = [ AssetKey.from_db_string(row[0]) for row in conn.execute(removed_asset_key_query).fetchall() ] conn.execute(delete_statement) if len(removed_asset_keys) > 0: keys_to_check = [] keys_to_check.extend([key.to_string() for key in removed_asset_keys]) keys_to_check.extend([key.to_string(legacy=True) for key in removed_asset_keys]) remaining_asset_keys = [ AssetKey.from_db_string(row[0]) for row in conn.execute( db.select([SqlEventLogStorageTable.c.asset_key]) .where(SqlEventLogStorageTable.c.asset_key.in_(keys_to_check)) .group_by(SqlEventLogStorageTable.c.asset_key) ) ] to_remove = set(removed_asset_keys) - set(remaining_asset_keys) if to_remove: keys_to_remove = [] keys_to_remove.extend([key.to_string() for key in to_remove]) keys_to_remove.extend([key.to_string(legacy=True) for key in to_remove]) conn.execute( AssetKeyTable.delete().where( # pylint: disable=no-value-for-parameter AssetKeyTable.c.asset_key.in_(keys_to_remove) ) ) @property def is_persistent(self): return True def update_event_log_record(self, record_id, event): """ Utility method for migration scripts to update SQL representation of event records. """ check.int_param(record_id, "record_id") check.inst_param(event, "event", EventRecord) dagster_event_type = None asset_key_str = None if event.is_dagster_event: dagster_event_type = event.dagster_event.event_type_value if event.dagster_event.asset_key: check.inst_param(event.dagster_event.asset_key, "asset_key", AssetKey) asset_key_str = event.dagster_event.asset_key.to_string() with self.connect(run_id=event.run_id) as conn: conn.execute( SqlEventLogStorageTable.update() # pylint: disable=no-value-for-parameter .where(SqlEventLogStorageTable.c.id == record_id) .values( event=serialize_dagster_namedtuple(event), dagster_event_type=dagster_event_type, timestamp=utc_datetime_from_timestamp(event.timestamp), step_key=event.step_key, asset_key=asset_key_str, ) ) def get_event_log_table_data(self, run_id, record_id): """ Utility method to test representation of the record in the SQL table. 
Returns all of the columns stored in the event log storage (as opposed to the deserialized `EventRecord`). This allows checking that certain fields are extracted to support performant lookups (e.g. extracting `step_key` for fast filtering)""" with self.connect(run_id=run_id) as conn: query = ( db.select([SqlEventLogStorageTable]) .where(SqlEventLogStorageTable.c.id == record_id) .order_by(SqlEventLogStorageTable.c.id.asc()) ) return conn.execute(query).fetchone() def has_secondary_index(self, name, run_id=None): """This method uses a checkpoint migration table to see if summary data has been constructed in a secondary index table. Can be used to checkpoint event_log data migrations. """ query = ( db.select([1]) .where(SecondaryIndexMigrationTable.c.name == name) .where(SecondaryIndexMigrationTable.c.migration_completed != None) .limit(1) ) with self.connect(run_id) as conn: results = conn.execute(query).fetchall() return len(results) > 0 def enable_secondary_index(self, name, run_id=None): """This method marks an event_log data migration as complete, to indicate that a summary data migration is complete. """ query = SecondaryIndexMigrationTable.insert().values( # pylint: disable=no-value-for-parameter name=name, migration_completed=datetime.now(), ) try: with self.connect(run_id) as conn: conn.execute(query) except db.exc.IntegrityError: with self.connect(run_id) as conn: conn.execute( SecondaryIndexMigrationTable.update() # pylint: disable=no-value-for-parameter .where(SecondaryIndexMigrationTable.c.name == name) .values(migration_completed=datetime.now()) ) class AssetAwareSqlEventLogStorage(AssetAwareEventLogStorage, SqlEventLogStorage): @abstractmethod def connect(self, run_id=None): pass @abstractmethod def upgrade(self): pass def _add_cursor_limit_to_query(self, query, cursor, limit, ascending=False): """ Helper function to deal with cursor/limit pagination args """ try: cursor = int(cursor) if cursor else None except ValueError: cursor = None if cursor: cursor_query = db.select([SqlEventLogStorageTable.c.id]).where( SqlEventLogStorageTable.c.id == cursor ) if ascending: query = query.where(SqlEventLogStorageTable.c.id > cursor_query) else: query = query.where(SqlEventLogStorageTable.c.id < cursor_query) if limit: query = query.limit(limit) if ascending: query = query.order_by(SqlEventLogStorageTable.c.timestamp.asc()) else: query = query.order_by(SqlEventLogStorageTable.c.timestamp.desc()) return query def has_asset_key(self, asset_key): check.inst_param(asset_key, "asset_key", AssetKey) if self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY): query = ( db.select([1]) .where( db.or_( AssetKeyTable.c.asset_key == asset_key.to_string(), AssetKeyTable.c.asset_key == asset_key.to_string(legacy=True), ) ) .limit(1) ) else: query = ( db.select([1]) .where( db.or_( SqlEventLogStorageTable.c.asset_key == asset_key.to_string(), SqlEventLogStorageTable.c.asset_key == asset_key.to_string(legacy=True), ) ) .limit(1) ) with self.connect() as conn: results = conn.execute(query).fetchall() return len(results) > 0 def get_all_asset_keys(self, prefix_path=None): lazy_migrate = False if not prefix_path: if self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY): query = db.select([AssetKeyTable.c.asset_key]) else: query = ( db.select([SqlEventLogStorageTable.c.asset_key]) .where(SqlEventLogStorageTable.c.asset_key != None) .distinct() ) # This is in place to migrate everyone to using the secondary index table for asset # keys. 
Performing this migration should result in a big performance boost for # any asset-catalog reads. # After a sufficient amount of time (>= 0.11.0?), we can remove the checks # for has_secondary_index(SECONDARY_INDEX_ASSET_KEY) and always read from the # AssetKeyTable, since we are already writing to the table. Tracking the conditional # check removal here: https://github.com/dagster-io/dagster/issues/3507 lazy_migrate = True else: if self.has_secondary_index(SECONDARY_INDEX_ASSET_KEY): query = db.select([AssetKeyTable.c.asset_key]).where( db.or_( AssetKeyTable.c.asset_key.startswith(AssetKey.get_db_prefix(prefix_path)), AssetKeyTable.c.asset_key.startswith( AssetKey.get_db_prefix(prefix_path, legacy=True) ), ) ) else: query = ( db.select([SqlEventLogStorageTable.c.asset_key]) .where(SqlEventLogStorageTable.c.asset_key != None) .where( db.or_( SqlEventLogStorageTable.c.asset_key.startswith( AssetKey.get_db_prefix(prefix_path) ), SqlEventLogStorageTable.c.asset_key.startswith( AssetKey.get_db_prefix(prefix_path, legacy=True) ), ) ) .distinct() ) with self.connect() as conn: results = conn.execute(query).fetchall() if lazy_migrate: # This is in place to migrate everyone to using the secondary index table for asset # keys. Performing this migration should result in a big performance boost for # any subsequent asset-catalog reads. self._lazy_migrate_secondary_index_asset_key( conn, [asset_key for (asset_key,) in results if asset_key] ) return list( set([AssetKey.from_db_string(asset_key) for (asset_key,) in results if asset_key]) ) def _lazy_migrate_secondary_index_asset_key(self, conn, asset_keys): results = conn.execute(db.select([AssetKeyTable.c.asset_key])).fetchall() existing = [asset_key for (asset_key,) in results if asset_key] to_migrate = set(asset_keys) - set(existing) for asset_key in to_migrate: try: conn.execute( AssetKeyTable.insert().values( # pylint: disable=no-value-for-parameter asset_key=AssetKey.from_db_string(asset_key).to_string() ) ) except db.exc.IntegrityError: # asset key already present pass self.enable_secondary_index(SECONDARY_INDEX_ASSET_KEY) def get_asset_events( self, asset_key, partitions=None, cursor=None, limit=None, ascending=False, include_cursor=False, ): check.inst_param(asset_key, "asset_key", AssetKey) check.opt_list_param(partitions, "partitions", of_type=str) query = db.select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event]).where( db.or_( SqlEventLogStorageTable.c.asset_key == asset_key.to_string(), SqlEventLogStorageTable.c.asset_key == asset_key.to_string(legacy=True), ) ) if partitions: query = query.where(SqlEventLogStorageTable.c.partition.in_(partitions)) query = self._add_cursor_limit_to_query(query, cursor, limit, ascending=ascending) with self.connect() as conn: results = conn.execute(query).fetchall() events = [] for row_id, json_str in results: try: event_record = deserialize_json_to_dagster_namedtuple(json_str) if not isinstance(event_record, EventRecord): logging.warning( "Could not resolve asset event record as EventRecord for id `{}`.".format( row_id ) ) continue if include_cursor: events.append(tuple([row_id, event_record])) else: events.append(event_record) except seven.JSONDecodeError: logging.warning("Could not parse asset event record id `{}`.".format(row_id)) return events def get_asset_run_ids(self, asset_key): check.inst_param(asset_key, "asset_key", AssetKey) query = ( db.select( [SqlEventLogStorageTable.c.run_id, db.func.max(SqlEventLogStorageTable.c.timestamp)] ) .where( db.or_( 
SqlEventLogStorageTable.c.asset_key == asset_key.to_string(), SqlEventLogStorageTable.c.asset_key == asset_key.to_string(legacy=True), ) ) .group_by(SqlEventLogStorageTable.c.run_id,) .order_by(db.func.max(SqlEventLogStorageTable.c.timestamp).desc()) ) with self.connect() as conn: results = conn.execute(query).fetchall() return [run_id for (run_id, _timestamp) in results] def wipe_asset(self, asset_key): check.inst_param(asset_key, "asset_key", AssetKey) event_query = db.select( [SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event] ).where( db.or_( SqlEventLogStorageTable.c.asset_key == asset_key.to_string(), SqlEventLogStorageTable.c.asset_key == asset_key.to_string(legacy=True), ) ) asset_key_delete = AssetKeyTable.delete().where( # pylint: disable=no-value-for-parameter db.or_( AssetKeyTable.c.asset_key == asset_key.to_string(), AssetKeyTable.c.asset_key == asset_key.to_string(legacy=True), ) ) with self.connect() as conn: conn.execute(asset_key_delete) results = conn.execute(event_query).fetchall() for row_id, json_str in results: try: event_record = deserialize_json_to_dagster_namedtuple(json_str) if not isinstance(event_record, EventRecord): continue assert event_record.dagster_event.event_specific_data.materialization.asset_key dagster_event = event_record.dagster_event event_specific_data = dagster_event.event_specific_data materialization = event_specific_data.materialization updated_materialization = Materialization( label=materialization.label, description=materialization.description, metadata_entries=materialization.metadata_entries, asset_key=None, skip_deprecation_warning=True, ) updated_event_specific_data = event_specific_data._replace( materialization=updated_materialization ) updated_dagster_event = dagster_event._replace( event_specific_data=updated_event_specific_data ) updated_record = event_record._replace(dagster_event=updated_dagster_event) # update the event_record here self.update_event_log_record(row_id, updated_record) except seven.JSONDecodeError: logging.warning("Could not parse asset event record id `{}`.".format(row_id)) diff --git a/python_modules/dagster/dagster/core/storage/intermediate_storage.py b/python_modules/dagster/dagster/core/storage/intermediate_storage.py index caeca7c09..4a6a2e2bc 100644 --- a/python_modules/dagster/dagster/core/storage/intermediate_storage.py +++ b/python_modules/dagster/dagster/core/storage/intermediate_storage.py @@ -1,378 +1,371 @@ import warnings from abc import ABC, abstractmethod, abstractproperty -import six from dagster import check from dagster.core.definitions.events import ObjectStoreOperation, ObjectStoreOperationType from dagster.core.errors import DagsterObjectStoreError, DagsterStepOutputNotFoundError from dagster.core.execution.context.system import SystemExecutionContext from dagster.core.execution.plan.outputs import StepOutputHandle from dagster.core.storage.io_manager import IOManager from dagster.core.types.dagster_type import DagsterType, resolve_dagster_type from .object_store import FilesystemObjectStore, InMemoryObjectStore, ObjectStore from .type_storage import TypeStoragePluginRegistry class IntermediateStorage(ABC): # pylint: disable=no-init @abstractmethod def get_intermediate(self, context, dagster_type=None, step_output_handle=None): pass @abstractmethod def set_intermediate( self, context, dagster_type=None, step_output_handle=None, value=None, version=None ): pass @abstractmethod def has_intermediate(self, context, step_output_handle): pass @abstractmethod def 
copy_intermediate_from_run(self, context, run_id, step_output_handle): pass @abstractproperty def is_persistent(self): pass class IntermediateStorageAdapter(IOManager): def __init__(self, intermediate_storage): self.intermediate_storage = check.inst_param( intermediate_storage, "intermediate_storage", IntermediateStorage ) warnings.warn( "Intermediate Storages are deprecated in 0.10.0 and will be removed in 0.11.0. " "Use IO Managers instead, which gives you better control over how inputs and " "outputs are handled and loaded." ) def handle_output(self, context, obj): res = self.intermediate_storage.set_intermediate( context=context.step_context, dagster_type=context.dagster_type, step_output_handle=StepOutputHandle( context.step_key, context.name, context.mapping_key ), value=obj, version=context.version, ) # Stopgap https://github.com/dagster-io/dagster/issues/3368 if isinstance(res, ObjectStoreOperation): context.log.debug( ( 'Stored output "{output_name}" in {object_store_name}object store{serialization_strategy_modifier} ' "at {address}" ).format( output_name=context.name, object_store_name=res.object_store_name, serialization_strategy_modifier=( " using {serialization_strategy_name}".format( serialization_strategy_name=res.serialization_strategy_name ) if res.serialization_strategy_name else "" ), address=res.key, ) ) def load_input(self, context): step_context = context.step_context source_handle = StepOutputHandle( context.upstream_output.step_key, context.upstream_output.name, context.upstream_output.mapping_key, ) if not self.intermediate_storage.has_intermediate(step_context, source_handle): raise DagsterStepOutputNotFoundError( ( "When executing {step}, discovered required output missing " "from previous step: {previous_step}" ).format(previous_step=source_handle.step_key, step=step_context.step.key), step_key=source_handle.step_key, output_name=source_handle.output_name, ) res = self.intermediate_storage.get_intermediate( context=step_context, dagster_type=context.dagster_type, step_output_handle=source_handle, ) # Stopgap https://github.com/dagster-io/dagster/issues/3368 if isinstance(res, ObjectStoreOperation): context.log.debug( ( "Loaded input {input_name} in {object_store_name}object store{serialization_strategy_modifier} " "from {address}" ).format( input_name=context.name, object_store_name=res.object_store_name, serialization_strategy_modifier=( " using {serialization_strategy_name}".format( serialization_strategy_name=res.serialization_strategy_name ) if res.serialization_strategy_name else "" ), address=res.key, ) ) return res.obj else: return res class ObjectStoreIntermediateStorage(IntermediateStorage): def __init__(self, object_store, root_for_run_id, run_id, type_storage_plugin_registry): self.root_for_run_id = check.callable_param(root_for_run_id, "root_for_run_id") self.run_id = check.str_param(run_id, "run_id") self.object_store = check.inst_param(object_store, "object_store", ObjectStore) self.type_storage_plugin_registry = check.inst_param( type_storage_plugin_registry, "type_storage_plugin_registry", TypeStoragePluginRegistry ) def _get_paths(self, step_output_handle): if step_output_handle.mapping_key: return [ "intermediates", step_output_handle.step_key, step_output_handle.output_name, step_output_handle.mapping_key, ] return ["intermediates", step_output_handle.step_key, step_output_handle.output_name] def get_intermediate_object(self, dagster_type, step_output_handle): check.inst_param(dagster_type, "dagster_type", DagsterType) 
check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) paths = self._get_paths(step_output_handle) check.param_invariant(len(paths) > 0, "paths") key = self.object_store.key_for_paths([self.root] + paths) try: obj, uri = self.object_store.get_object( key, serialization_strategy=dagster_type.serialization_strategy ) except Exception as error: # pylint: disable=broad-except - six.raise_from( - DagsterObjectStoreError( - _object_store_operation_error_message( - step_output_handle=step_output_handle, - op=ObjectStoreOperationType.GET_OBJECT, - object_store_name=self.object_store.name, - serialization_strategy_name=dagster_type.serialization_strategy.name, - ) - ), - error, - ) + raise DagsterObjectStoreError( + _object_store_operation_error_message( + step_output_handle=step_output_handle, + op=ObjectStoreOperationType.GET_OBJECT, + object_store_name=self.object_store.name, + serialization_strategy_name=dagster_type.serialization_strategy.name, + ) + ) from error return ObjectStoreOperation( op=ObjectStoreOperationType.GET_OBJECT, key=uri, dest_key=None, obj=obj, serialization_strategy_name=dagster_type.serialization_strategy.name, object_store_name=self.object_store.name, ) def get_intermediate( self, context, dagster_type=None, step_output_handle=None, ): dagster_type = resolve_dagster_type(dagster_type) check.opt_inst_param(context, "context", SystemExecutionContext) check.inst_param(dagster_type, "dagster_type", DagsterType) check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) check.invariant(self.has_intermediate(context, step_output_handle)) if self.type_storage_plugin_registry.is_registered(dagster_type): return self.type_storage_plugin_registry.get( dagster_type.unique_name ).get_intermediate_object(self, context, dagster_type, step_output_handle) elif not dagster_type.has_unique_name: self.type_storage_plugin_registry.check_for_unsupported_composite_overrides( dagster_type ) return self.get_intermediate_object(dagster_type, step_output_handle) def set_intermediate_object(self, dagster_type, step_output_handle, value, version=None): check.inst_param(dagster_type, "dagster_type", DagsterType) check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) paths = self._get_paths(step_output_handle) check.param_invariant(len(paths) > 0, "paths") key = self.object_store.key_for_paths([self.root] + paths) try: uri = self.object_store.set_object( key, value, serialization_strategy=dagster_type.serialization_strategy ) except Exception as error: # pylint: disable=broad-except - six.raise_from( - DagsterObjectStoreError( - _object_store_operation_error_message( - step_output_handle=step_output_handle, - op=ObjectStoreOperationType.SET_OBJECT, - object_store_name=self.object_store.name, - serialization_strategy_name=dagster_type.serialization_strategy.name, - ) - ), - error, - ) + raise DagsterObjectStoreError( + _object_store_operation_error_message( + step_output_handle=step_output_handle, + op=ObjectStoreOperationType.SET_OBJECT, + object_store_name=self.object_store.name, + serialization_strategy_name=dagster_type.serialization_strategy.name, + ) + ) from error return ObjectStoreOperation( op=ObjectStoreOperationType.SET_OBJECT, key=uri, dest_key=None, obj=value, serialization_strategy_name=dagster_type.serialization_strategy.name, object_store_name=self.object_store.name, version=version, ) def set_intermediate( self, context, dagster_type=None, step_output_handle=None, value=None, version=None, ): dagster_type = 
resolve_dagster_type(dagster_type) check.opt_inst_param(context, "context", SystemExecutionContext) check.inst_param(dagster_type, "dagster_type", DagsterType) check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) check.opt_str_param(version, "version") if self.has_intermediate(context, step_output_handle): context.log.warning( "Replacing existing intermediate for %s.%s" % (step_output_handle.step_key, step_output_handle.output_name) ) if self.type_storage_plugin_registry.is_registered(dagster_type): return self.type_storage_plugin_registry.get( dagster_type.unique_name ).set_intermediate_object(self, context, dagster_type, step_output_handle, value) elif not dagster_type.has_unique_name: self.type_storage_plugin_registry.check_for_unsupported_composite_overrides( dagster_type ) return self.set_intermediate_object(dagster_type, step_output_handle, value, version) def has_intermediate(self, context, step_output_handle): check.opt_inst_param(context, "context", SystemExecutionContext) check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) paths = self._get_paths(step_output_handle) check.list_param(paths, "paths", of_type=str) check.param_invariant(len(paths) > 0, "paths") key = self.object_store.key_for_paths([self.root] + paths) return self.object_store.has_object(key) def rm_intermediate(self, context, step_output_handle): check.opt_inst_param(context, "context", SystemExecutionContext) check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) paths = self._get_paths(step_output_handle) check.param_invariant(len(paths) > 0, "paths") key = self.object_store.key_for_paths([self.root] + paths) uri = self.object_store.rm_object(key) return ObjectStoreOperation( op=ObjectStoreOperationType.RM_OBJECT, key=uri, dest_key=None, obj=None, serialization_strategy_name=None, object_store_name=self.object_store.name, ) def copy_intermediate_from_run(self, context, run_id, step_output_handle): check.opt_inst_param(context, "context", SystemExecutionContext) check.str_param(run_id, "run_id") check.inst_param(step_output_handle, "step_output_handle", StepOutputHandle) paths = self._get_paths(step_output_handle) src = self.object_store.key_for_paths([self.root_for_run_id(run_id)] + paths) dst = self.object_store.key_for_paths([self.root] + paths) src_uri, dst_uri = self.object_store.cp_object(src, dst) return ObjectStoreOperation( op=ObjectStoreOperationType.CP_OBJECT, key=src_uri, dest_key=dst_uri, object_store_name=self.object_store.name, ) def uri_for_paths(self, paths, protocol=None): check.list_param(paths, "paths", of_type=str) check.param_invariant(len(paths) > 0, "paths") key = self.key_for_paths(paths) return self.object_store.uri_for_key(key, protocol) def key_for_paths(self, paths): return self.object_store.key_for_paths([self.root] + paths) @property def is_persistent(self): if isinstance(self.object_store, InMemoryObjectStore): return False return True @property def root(self): return self.root_for_run_id(self.run_id) def _object_store_operation_error_message( op, step_output_handle, object_store_name, serialization_strategy_name ): if ObjectStoreOperationType(op) == ObjectStoreOperationType.GET_OBJECT: op_name = "retriving" elif ObjectStoreOperationType(op) == ObjectStoreOperationType.SET_OBJECT: op_name = "storing" else: op_name = "" return ( 'Error occurred during {op_name} output "{output_name}" for step "{step_key}" in ' "{object_store_modifier}object store{serialization_strategy_modifier}." 
).format( op_name=op_name, step_key=step_output_handle.step_key, output_name=step_output_handle.output_name, object_store_modifier=( '"{object_store_name}" '.format(object_store_name=object_store_name) if object_store_name else "" ), serialization_strategy_modifier=( ' using "{serialization_strategy_name}"'.format( serialization_strategy_name=serialization_strategy_name ) if serialization_strategy_name else "" ), ) def build_in_mem_intermediates_storage(run_id, type_storage_plugin_registry=None): return ObjectStoreIntermediateStorage( InMemoryObjectStore(), lambda _: "", run_id, type_storage_plugin_registry if type_storage_plugin_registry else TypeStoragePluginRegistry(types_to_register=[]), ) def build_fs_intermediate_storage(root_for_run_id, run_id, type_storage_plugin_registry=None): return ObjectStoreIntermediateStorage( FilesystemObjectStore(), root_for_run_id, run_id, type_storage_plugin_registry if type_storage_plugin_registry else TypeStoragePluginRegistry(types_to_register=[]), ) diff --git a/python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py b/python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py index 221805f03..7cbd111d8 100644 --- a/python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py +++ b/python_modules/dagster/dagster/core/storage/runs/sql_run_storage.py @@ -1,693 +1,694 @@ import logging import zlib from abc import abstractmethod from collections import defaultdict from datetime import datetime from enum import Enum -import six import sqlalchemy as db from dagster import check from dagster.core.errors import DagsterRunAlreadyExists, DagsterSnapshotDoesNotExist from dagster.core.events import DagsterEvent, DagsterEventType from dagster.core.snap import ( ExecutionPlanSnapshot, PipelineSnapshot, create_execution_plan_snapshot_id, create_pipeline_snapshot_id, ) from dagster.core.storage.tags import PARTITION_NAME_TAG, PARTITION_SET_TAG, ROOT_RUN_ID_TAG from dagster.serdes import deserialize_json_to_dagster_namedtuple, serialize_dagster_namedtuple from dagster.seven import JSONDecodeError from dagster.utils import merge_dicts, utc_datetime_from_timestamp from ..pipeline_run import PipelineRun, PipelineRunStatus, PipelineRunsFilter from .base import RunStorage from .migration import RUN_DATA_MIGRATIONS, RUN_PARTITIONS from .schema import ( DaemonHeartbeatsTable, RunTagsTable, RunsTable, SecondaryIndexMigrationTable, SnapshotsTable, ) class SnapshotType(Enum): PIPELINE = "PIPELINE" EXECUTION_PLAN = "EXECUTION_PLAN" class SqlRunStorage(RunStorage): # pylint: disable=no-init """Base class for SQL based run storages """ @abstractmethod def connect(self): """Context manager yielding a sqlalchemy.engine.Connection.""" @abstractmethod def upgrade(self): """This method should perform any schema or data migrations necessary to bring an out-of-date instance of the storage up to date. 
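The hunks in this file, like those in the intermediate-storage file above, replace six.raise_from with Python 3's native exception chaining. A minimal standalone sketch of that pattern, using the hypothetical names StorageError and _fetch purely for illustration:

class StorageError(Exception):
    pass

def _fetch(store, key):
    try:
        return store[key]
    except KeyError as error:
        # `raise ... from error` keeps the original exception on __cause__ and in the
        # traceback, which is what six.raise_from(StorageError(...), error) provided.
        raise StorageError(f"could not fetch {key}") from error

try:
    _fetch({}, "snapshot")
except StorageError as err:
    assert isinstance(err.__cause__, KeyError)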
""" def fetchall(self, query): with self.connect() as conn: result_proxy = conn.execute(query) res = result_proxy.fetchall() result_proxy.close() return res def fetchone(self, query): with self.connect() as conn: result_proxy = conn.execute(query) row = result_proxy.fetchone() result_proxy.close() return row def add_run(self, pipeline_run): check.inst_param(pipeline_run, "pipeline_run", PipelineRun) if pipeline_run.pipeline_snapshot_id and not self.has_pipeline_snapshot( pipeline_run.pipeline_snapshot_id ): raise DagsterSnapshotDoesNotExist( "Snapshot {ss_id} does not exist in run storage".format( ss_id=pipeline_run.pipeline_snapshot_id ) ) has_tags = pipeline_run.tags and len(pipeline_run.tags) > 0 partition = pipeline_run.tags.get(PARTITION_NAME_TAG) if has_tags else None partition_set = pipeline_run.tags.get(PARTITION_SET_TAG) if has_tags else None with self.connect() as conn: try: runs_insert = RunsTable.insert().values( # pylint: disable=no-value-for-parameter run_id=pipeline_run.run_id, pipeline_name=pipeline_run.pipeline_name, status=pipeline_run.status.value, run_body=serialize_dagster_namedtuple(pipeline_run), snapshot_id=pipeline_run.pipeline_snapshot_id, partition=partition, partition_set=partition_set, ) conn.execute(runs_insert) except db.exc.IntegrityError as exc: - six.raise_from(DagsterRunAlreadyExists, exc) + raise DagsterRunAlreadyExists from exc if pipeline_run.tags and len(pipeline_run.tags) > 0: conn.execute( RunTagsTable.insert(), # pylint: disable=no-value-for-parameter [ dict(run_id=pipeline_run.run_id, key=k, value=v) for k, v in pipeline_run.tags.items() ], ) return pipeline_run def handle_run_event(self, run_id, event): check.str_param(run_id, "run_id") check.inst_param(event, "event", DagsterEvent) lookup = { DagsterEventType.PIPELINE_START: PipelineRunStatus.STARTED, DagsterEventType.PIPELINE_SUCCESS: PipelineRunStatus.SUCCESS, DagsterEventType.PIPELINE_FAILURE: PipelineRunStatus.FAILURE, DagsterEventType.PIPELINE_INIT_FAILURE: PipelineRunStatus.FAILURE, DagsterEventType.PIPELINE_ENQUEUED: PipelineRunStatus.QUEUED, DagsterEventType.PIPELINE_STARTING: PipelineRunStatus.STARTING, DagsterEventType.PIPELINE_CANCELING: PipelineRunStatus.CANCELING, DagsterEventType.PIPELINE_CANCELED: PipelineRunStatus.CANCELED, } if event.event_type not in lookup: return run = self.get_run_by_id(run_id) if not run: # TODO log? 
return new_pipeline_status = lookup[event.event_type] with self.connect() as conn: conn.execute( RunsTable.update() # pylint: disable=no-value-for-parameter .where(RunsTable.c.run_id == run_id) .values( status=new_pipeline_status.value, run_body=serialize_dagster_namedtuple(run.with_status(new_pipeline_status)), update_timestamp=datetime.now(), ) ) def _row_to_run(self, row): return deserialize_json_to_dagster_namedtuple(row[0]) def _rows_to_runs(self, rows): return list(map(self._row_to_run, rows)) def _add_cursor_limit_to_query(self, query, cursor, limit): """ Helper function to deal with cursor/limit pagination args """ if cursor: cursor_query = db.select([RunsTable.c.id]).where(RunsTable.c.run_id == cursor) query = query.where(RunsTable.c.id < cursor_query) if limit: query = query.limit(limit) query = query.order_by(RunsTable.c.id.desc()) return query def _add_filters_to_query(self, query, filters): check.inst_param(filters, "filters", PipelineRunsFilter) if filters.run_ids: query = query.where(RunsTable.c.run_id.in_(filters.run_ids)) if filters.pipeline_name: query = query.where(RunsTable.c.pipeline_name == filters.pipeline_name) if filters.statuses: query = query.where( RunsTable.c.status.in_([status.value for status in filters.statuses]) ) if filters.tags: query = query.where( db.or_( *( db.and_(RunTagsTable.c.key == key, RunTagsTable.c.value == value) for key, value in filters.tags.items() ) ) ).group_by(RunsTable.c.run_body, RunsTable.c.id) if len(filters.tags) > 0: query = query.having(db.func.count(RunsTable.c.run_id) == len(filters.tags)) if filters.snapshot_id: query = query.where(RunsTable.c.snapshot_id == filters.snapshot_id) return query def _runs_query(self, filters=None, cursor=None, limit=None, columns=None): filters = check.opt_inst_param( filters, "filters", PipelineRunsFilter, default=PipelineRunsFilter() ) check.opt_str_param(cursor, "cursor") check.opt_int_param(limit, "limit") check.opt_list_param(columns, "columns") if columns is None: columns = ["run_body"] base_query_columns = [getattr(RunsTable.c, column) for column in columns] # If we have a tags filter, then we need to select from a joined table if filters.tags: base_query = db.select(base_query_columns).select_from( RunsTable.join(RunTagsTable, RunsTable.c.run_id == RunTagsTable.c.run_id) ) else: base_query = db.select(base_query_columns).select_from(RunsTable) query = self._add_filters_to_query(base_query, filters) query = self._add_cursor_limit_to_query(query, cursor, limit) return query def get_runs(self, filters=None, cursor=None, limit=None): query = self._runs_query(filters, cursor, limit) rows = self.fetchall(query) return self._rows_to_runs(rows) def get_runs_count(self, filters=None): subquery = self._runs_query(filters=filters).alias("subquery") # We use an alias here because Postgres requires subqueries to be # aliased. subquery = subquery.alias("subquery") query = db.select([db.func.count()]).select_from(subquery) rows = self.fetchall(query) count = rows[0][0] return count def get_run_by_id(self, run_id): """Get a run by its id. 
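The tag filter built in _add_filters_to_query above relies on a join/group-by/having trick so that only runs carrying every requested tag are returned. A simplified sketch of that query shape, assuming a toy schema and tag values that are illustrative rather than Dagster's actual tables:

import sqlalchemy as db

metadata = db.MetaData()
runs = db.Table("runs", metadata, db.Column("run_id", db.String, primary_key=True))
run_tags = db.Table(
    "run_tags",
    metadata,
    db.Column("run_id", db.String),
    db.Column("key", db.String),
    db.Column("value", db.String),
)

wanted = {"team": "data", "priority": "high"}
query = (
    db.select([runs.c.run_id])
    .select_from(runs.join(run_tags, runs.c.run_id == run_tags.c.run_id))
    # OR together the individual (key, value) pairs...
    .where(
        db.or_(*(db.and_(run_tags.c.key == k, run_tags.c.value == v) for k, v in wanted.items()))
    )
    # ...then keep only run_ids that matched all of them.
    .group_by(runs.c.run_id)
    .having(db.func.count(runs.c.run_id) == len(wanted))
)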
Args: run_id (str): The id of the run Returns: Optional[PipelineRun] """ check.str_param(run_id, "run_id") query = db.select([RunsTable.c.run_body]).where(RunsTable.c.run_id == run_id) rows = self.fetchall(query) return deserialize_json_to_dagster_namedtuple(rows[0][0]) if len(rows) else None def get_run_tags(self): result = defaultdict(set) query = db.select([RunTagsTable.c.key, RunTagsTable.c.value]).distinct( RunTagsTable.c.key, RunTagsTable.c.value ) rows = self.fetchall(query) for r in rows: result[r[0]].add(r[1]) return sorted(list([(k, v) for k, v in result.items()]), key=lambda x: x[0]) def add_run_tags(self, run_id, new_tags): check.str_param(run_id, "run_id") check.dict_param(new_tags, "new_tags", key_type=str, value_type=str) run = self.get_run_by_id(run_id) current_tags = run.tags if run.tags else {} all_tags = merge_dicts(current_tags, new_tags) partition = all_tags.get(PARTITION_NAME_TAG) partition_set = all_tags.get(PARTITION_SET_TAG) with self.connect() as conn: conn.execute( RunsTable.update() # pylint: disable=no-value-for-parameter .where(RunsTable.c.run_id == run_id) .values( run_body=serialize_dagster_namedtuple( run.with_tags(merge_dicts(current_tags, new_tags)) ), partition=partition, partition_set=partition_set, update_timestamp=datetime.now(), ) ) current_tags_set = set(current_tags.keys()) new_tags_set = set(new_tags.keys()) existing_tags = current_tags_set & new_tags_set added_tags = new_tags_set.difference(existing_tags) for tag in existing_tags: conn.execute( RunTagsTable.update() # pylint: disable=no-value-for-parameter .where(db.and_(RunTagsTable.c.run_id == run_id, RunTagsTable.c.key == tag)) .values(value=new_tags[tag]) ) if added_tags: conn.execute( RunTagsTable.insert(), # pylint: disable=no-value-for-parameter [dict(run_id=run_id, key=tag, value=new_tags[tag]) for tag in added_tags], ) def get_run_group(self, run_id): check.str_param(run_id, "run_id") pipeline_run = self.get_run_by_id(run_id) if not pipeline_run: return None # find root_run root_run_id = pipeline_run.root_run_id if pipeline_run.root_run_id else pipeline_run.run_id root_run = self.get_run_by_id(root_run_id) # root_run_id to run_id 1:1 mapping root_to_run = ( db.select( [RunTagsTable.c.value.label("root_run_id"), RunTagsTable.c.run_id.label("run_id")] ) .where( db.and_(RunTagsTable.c.key == ROOT_RUN_ID_TAG, RunTagsTable.c.value == root_run_id) ) .alias("root_to_run") ) # get run group run_group_query = ( db.select([RunsTable.c.run_body]) .select_from( root_to_run.join( RunsTable, root_to_run.c.run_id == RunsTable.c.run_id, isouter=True, ) ) .alias("run_group") ) with self.connect() as conn: res = conn.execute(run_group_query) run_group = self._rows_to_runs(res) return (root_run_id, [root_run] + run_group) def get_run_groups(self, filters=None, cursor=None, limit=None): # The runs that would be returned by calling RunStorage.get_runs with the same arguments runs = self._runs_query( filters=filters, cursor=cursor, limit=limit, columns=["run_body", "run_id"] ).alias("runs") # Gets us the run_id and associated root_run_id for every run in storage that is a # descendant run of some root # # pseudosql: # with all_descendant_runs as ( # select * # from run_tags # where key = @ROOT_RUN_ID_TAG # ) all_descendant_runs = ( db.select([RunTagsTable]) .where(RunTagsTable.c.key == ROOT_RUN_ID_TAG) .alias("all_descendant_runs") ) # Augment the runs in our query, for those runs that are the descendant of some root run, # with the root_run_id # # pseudosql: # # with runs_augmented as ( # select # 
runs.run_id as run_id, # all_descendant_runs.value as root_run_id # from runs # left outer join all_descendant_runs # on all_descendant_runs.run_id = runs.run_id # ) runs_augmented = ( db.select( [runs.c.run_id.label("run_id"), all_descendant_runs.c.value.label("root_run_id"),] ) .select_from( runs.join( all_descendant_runs, all_descendant_runs.c.run_id == RunsTable.c.run_id, isouter=True, ) ) .alias("runs_augmented") ) # Get all the runs our query will return. This includes runs as well as their root runs. # # pseudosql: # # with runs_and_root_runs as ( # select runs.run_id as run_id # from runs, runs_augmented # where # runs.run_id = runs_augmented.run_id or # runs.run_id = runs_augmented.root_run_id # ) runs_and_root_runs = ( db.select([RunsTable.c.run_id.label("run_id")]) .select_from(runs_augmented) .where( db.or_( RunsTable.c.run_id == runs_augmented.c.run_id, RunsTable.c.run_id == runs_augmented.c.root_run_id, ) ) .distinct(RunsTable.c.run_id) ).alias("runs_and_root_runs") # We count the descendants of all of the runs in our query that are roots so that # we can accurately display when a root run has more descendants than are returned by this # query and afford a drill-down. This might be an unnecessary complication, but the # alternative isn't obvious -- we could go and fetch *all* the runs in any group that we're # going to return in this query, and then append those. # # pseudosql: # # select runs.run_body, count(all_descendant_runs.id) as child_counts # from runs # join runs_and_root_runs on runs.run_id = runs_and_root_runs.run_id # left outer join all_descendant_runs # on all_descendant_runs.value = runs_and_root_runs.run_id # group by runs.run_body # order by child_counts desc runs_and_root_runs_with_descendant_counts = ( db.select( [ RunsTable.c.run_body, db.func.count(all_descendant_runs.c.id).label("child_counts"), ] ) .select_from( RunsTable.join( runs_and_root_runs, RunsTable.c.run_id == runs_and_root_runs.c.run_id ).join( all_descendant_runs, all_descendant_runs.c.value == runs_and_root_runs.c.run_id, isouter=True, ) ) .group_by(RunsTable.c.run_body) .order_by(db.desc(db.column("child_counts"))) ) with self.connect() as conn: res = conn.execute(runs_and_root_runs_with_descendant_counts).fetchall() # Postprocess: descendant runs get aggregated with their roots run_groups = defaultdict(lambda: {"runs": [], "count": 0}) for (run_body, count) in res: row = (run_body,) pipeline_run = self._row_to_run(row) root_run_id = pipeline_run.get_root_run_id() if root_run_id is not None: run_groups[root_run_id]["runs"].append(pipeline_run) else: run_groups[pipeline_run.run_id]["runs"].append(pipeline_run) run_groups[pipeline_run.run_id]["count"] = count + 1 return run_groups def has_run(self, run_id): check.str_param(run_id, "run_id") return bool(self.get_run_by_id(run_id)) def delete_run(self, run_id): check.str_param(run_id, "run_id") query = db.delete(RunsTable).where(RunsTable.c.run_id == run_id) with self.connect() as conn: conn.execute(query) def has_pipeline_snapshot(self, pipeline_snapshot_id): check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id") return self._has_snapshot_id(pipeline_snapshot_id) def add_pipeline_snapshot(self, pipeline_snapshot): check.inst_param(pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot) return self._add_snapshot( snapshot_id=create_pipeline_snapshot_id(pipeline_snapshot), snapshot_obj=pipeline_snapshot, snapshot_type=SnapshotType.PIPELINE, ) def get_pipeline_snapshot(self, pipeline_snapshot_id): check.str_param(pipeline_snapshot_id, 
"pipeline_snapshot_id") return self._get_snapshot(pipeline_snapshot_id) def has_execution_plan_snapshot(self, execution_plan_snapshot_id): check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id") return bool(self.get_execution_plan_snapshot(execution_plan_snapshot_id)) def add_execution_plan_snapshot(self, execution_plan_snapshot): check.inst_param(execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot) execution_plan_snapshot_id = create_execution_plan_snapshot_id(execution_plan_snapshot) return self._add_snapshot( snapshot_id=execution_plan_snapshot_id, snapshot_obj=execution_plan_snapshot, snapshot_type=SnapshotType.EXECUTION_PLAN, ) def get_execution_plan_snapshot(self, execution_plan_snapshot_id): check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id") return self._get_snapshot(execution_plan_snapshot_id) def _add_snapshot(self, snapshot_id, snapshot_obj, snapshot_type): check.str_param(snapshot_id, "snapshot_id") check.not_none_param(snapshot_obj, "snapshot_obj") check.inst_param(snapshot_type, "snapshot_type", SnapshotType) with self.connect() as conn: snapshot_insert = SnapshotsTable.insert().values( # pylint: disable=no-value-for-parameter snapshot_id=snapshot_id, - snapshot_body=zlib.compress(serialize_dagster_namedtuple(snapshot_obj).encode()), + snapshot_body=zlib.compress( + serialize_dagster_namedtuple(snapshot_obj).encode("utf-8") + ), snapshot_type=snapshot_type.value, ) conn.execute(snapshot_insert) return snapshot_id def _has_snapshot_id(self, snapshot_id): query = db.select([SnapshotsTable.c.snapshot_id]).where( SnapshotsTable.c.snapshot_id == snapshot_id ) row = self.fetchone(query) return bool(row) def _get_snapshot(self, snapshot_id): query = db.select([SnapshotsTable.c.snapshot_body]).where( SnapshotsTable.c.snapshot_id == snapshot_id ) row = self.fetchone(query) return defensively_unpack_pipeline_snapshot_query(logging, row) if row else None def _get_partition_runs(self, partition_set_name, partition_name): # utility method to help test reads off of the partition column if not self.has_built_index(RUN_PARTITIONS): # query by tags return self.get_runs( filters=PipelineRunsFilter( tags={ PARTITION_SET_TAG: partition_set_name, PARTITION_NAME_TAG: partition_name, } ) ) else: query = ( self._runs_query() .where(RunsTable.c.partition == partition_name) .where(RunsTable.c.partition_set == partition_set_name) ) rows = self.fetchall(query) return self._rows_to_runs(rows) # Tracking data migrations over secondary indexes def build_missing_indexes(self, print_fn=lambda _: None, force_rebuild_all=False): for migration_name, migration_fn in RUN_DATA_MIGRATIONS.items(): if self.has_built_index(migration_name): if not force_rebuild_all: continue print_fn(f"Starting data migration: {migration_name}") migration_fn()(self, print_fn) self.mark_index_built(migration_name) print_fn(f"Finished data migration: {migration_name}") def has_built_index(self, migration_name): query = ( db.select([1]) .where(SecondaryIndexMigrationTable.c.name == migration_name) .where(SecondaryIndexMigrationTable.c.migration_completed != None) .limit(1) ) with self.connect() as conn: results = conn.execute(query).fetchall() return len(results) > 0 def mark_index_built(self, migration_name): query = SecondaryIndexMigrationTable.insert().values( # pylint: disable=no-value-for-parameter name=migration_name, migration_completed=datetime.now(), ) try: with self.connect() as conn: conn.execute(query) except db.exc.IntegrityError: with self.connect() as conn: 
conn.execute( SecondaryIndexMigrationTable.update() # pylint: disable=no-value-for-parameter .where(SecondaryIndexMigrationTable.c.name == migration_name) .values(migration_completed=datetime.now()) ) # Daemon heartbeats def add_daemon_heartbeat(self, daemon_heartbeat): with self.connect() as conn: # insert, or update if already present try: conn.execute( DaemonHeartbeatsTable.insert().values( # pylint: disable=no-value-for-parameter timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp), daemon_type=daemon_heartbeat.daemon_type.value, daemon_id=daemon_heartbeat.daemon_id, body=serialize_dagster_namedtuple(daemon_heartbeat), ) ) except db.exc.IntegrityError: conn.execute( DaemonHeartbeatsTable.update() # pylint: disable=no-value-for-parameter .where( DaemonHeartbeatsTable.c.daemon_type == daemon_heartbeat.daemon_type.value ) .values( # pylint: disable=no-value-for-parameter timestamp=utc_datetime_from_timestamp(daemon_heartbeat.timestamp), daemon_id=daemon_heartbeat.daemon_id, body=serialize_dagster_namedtuple(daemon_heartbeat), ) ) def get_daemon_heartbeats(self): with self.connect() as conn: rows = conn.execute(db.select(DaemonHeartbeatsTable.columns)) heartbeats = [deserialize_json_to_dagster_namedtuple(row.body) for row in rows] return {heartbeat.daemon_type: heartbeat for heartbeat in heartbeats} def wipe(self): """Clears the run storage.""" with self.connect() as conn: # https://stackoverflow.com/a/54386260/324449 conn.execute(RunsTable.delete()) # pylint: disable=no-value-for-parameter conn.execute(RunTagsTable.delete()) # pylint: disable=no-value-for-parameter conn.execute(SnapshotsTable.delete()) # pylint: disable=no-value-for-parameter conn.execute(DaemonHeartbeatsTable.delete()) # pylint: disable=no-value-for-parameter def wipe_daemon_heartbeats(self): with self.connect() as conn: # https://stackoverflow.com/a/54386260/324449 DaemonHeartbeatsTable.drop(conn) # pylint: disable=no-value-for-parameter GET_PIPELINE_SNAPSHOT_QUERY_ID = "get-pipeline-snapshot" def defensively_unpack_pipeline_snapshot_query(logger, row): # no checking here because sqlalchemy returns a special # row proxy and don't want to instance check on an internal # implementation detail def _warn(msg): logger.warning("get-pipeline-snapshot: {msg}".format(msg=msg)) - if not isinstance(row[0], six.binary_type): + if not isinstance(row[0], bytes): _warn("First entry in row is not a binary type.") return None try: uncompressed_bytes = zlib.decompress(row[0]) except zlib.error: _warn("Could not decompress bytes stored in snapshot table.") return None try: - decoded_str = uncompressed_bytes.decode() + decoded_str = uncompressed_bytes.decode("utf-8") except UnicodeDecodeError: _warn("Could not unicode decode decompressed bytes stored in snapshot table.") return None try: return deserialize_json_to_dagster_namedtuple(decoded_str) except JSONDecodeError: _warn("Could not parse json in snapshot table.") return None diff --git a/python_modules/dagster/dagster/core/storage/schedules/sql_schedule_storage.py b/python_modules/dagster/dagster/core/storage/schedules/sql_schedule_storage.py index aaa65d1e3..adea3b7c7 100644 --- a/python_modules/dagster/dagster/core/storage/schedules/sql_schedule_storage.py +++ b/python_modules/dagster/dagster/core/storage/schedules/sql_schedule_storage.py @@ -1,241 +1,232 @@ from abc import abstractmethod from datetime import datetime -import six import sqlalchemy as db from dagster import check from dagster.core.definitions.job import JobType from dagster.core.errors import 
DagsterInvariantViolationError from dagster.core.scheduler.job import ( JobState, JobTick, JobTickData, JobTickStatsSnapshot, JobTickStatus, ) from dagster.serdes import deserialize_json_to_dagster_namedtuple, serialize_dagster_namedtuple from dagster.utils import utc_datetime_from_timestamp from .base import ScheduleStorage from .schema import JobTable, JobTickTable class SqlScheduleStorage(ScheduleStorage): """Base class for SQL backed schedule storage """ @abstractmethod def connect(self): """Context manager yielding a sqlalchemy.engine.Connection.""" def execute(self, query): with self.connect() as conn: result_proxy = conn.execute(query) res = result_proxy.fetchall() result_proxy.close() return res def _deserialize_rows(self, rows): return list(map(lambda r: deserialize_json_to_dagster_namedtuple(r[0]), rows)) def all_stored_job_state(self, repository_origin_id=None, job_type=None): check.opt_inst_param(job_type, "job_type", JobType) base_query = db.select([JobTable.c.job_body, JobTable.c.job_origin_id]).select_from( JobTable ) if repository_origin_id: query = base_query.where(JobTable.c.repository_origin_id == repository_origin_id) else: query = base_query if job_type: query = query.where(JobTable.c.job_type == job_type.value) rows = self.execute(query) return self._deserialize_rows(rows) def get_job_state(self, job_origin_id): check.str_param(job_origin_id, "job_origin_id") query = ( db.select([JobTable.c.job_body]) .select_from(JobTable) .where(JobTable.c.job_origin_id == job_origin_id) ) rows = self.execute(query) return self._deserialize_rows(rows[:1])[0] if len(rows) else None def add_job_state(self, job): check.inst_param(job, "job", JobState) with self.connect() as conn: try: conn.execute( JobTable.insert().values( # pylint: disable=no-value-for-parameter job_origin_id=job.job_origin_id, repository_origin_id=job.repository_origin_id, status=job.status.value, job_type=job.job_type.value, job_body=serialize_dagster_namedtuple(job), ) ) except db.exc.IntegrityError as exc: - six.raise_from( - DagsterInvariantViolationError( - "JobState {id} is already present in storage".format(id=job.job_origin_id,) - ), - exc, - ) + raise DagsterInvariantViolationError( + f"JobState {job.job_origin_id} is already present in storage" + ) from exc return job def update_job_state(self, job): check.inst_param(job, "job", JobState) if not self.get_job_state(job.job_origin_id): raise DagsterInvariantViolationError( "JobState {id} is not present in storage".format(id=job.job_origin_id) ) with self.connect() as conn: conn.execute( JobTable.update() # pylint: disable=no-value-for-parameter .where(JobTable.c.job_origin_id == job.job_origin_id) .values(status=job.status.value, job_body=serialize_dagster_namedtuple(job),) ) def delete_job_state(self, job_origin_id): check.str_param(job_origin_id, "job_origin_id") if not self.get_job_state(job_origin_id): raise DagsterInvariantViolationError( "JobState {id} is not present in storage".format(id=job_origin_id) ) with self.connect() as conn: conn.execute( JobTable.delete().where( # pylint: disable=no-value-for-parameter JobTable.c.job_origin_id == job_origin_id ) ) def get_latest_job_tick(self, job_origin_id): check.str_param(job_origin_id, "job_origin_id") query = ( db.select([JobTickTable.c.id, JobTickTable.c.tick_body]) .select_from(JobTickTable) .where(JobTickTable.c.job_origin_id == job_origin_id) .order_by(JobTickTable.c.timestamp.desc()) .limit(1) ) rows = self.execute(query) if len(rows) == 0: return None return JobTick(rows[0][0], 
deserialize_json_to_dagster_namedtuple(rows[0][1])) def get_job_ticks(self, job_origin_id): check.str_param(job_origin_id, "job_origin_id") query = ( db.select([JobTickTable.c.id, JobTickTable.c.tick_body]) .select_from(JobTickTable) .where(JobTickTable.c.job_origin_id == job_origin_id) .order_by(JobTickTable.c.id.desc()) ) rows = self.execute(query) return list( map(lambda r: JobTick(r[0], deserialize_json_to_dagster_namedtuple(r[1])), rows) ) def create_job_tick(self, job_tick_data): check.inst_param(job_tick_data, "job_tick_data", JobTickData) with self.connect() as conn: try: tick_insert = JobTickTable.insert().values( # pylint: disable=no-value-for-parameter job_origin_id=job_tick_data.job_origin_id, status=job_tick_data.status.value, type=job_tick_data.job_type.value, timestamp=utc_datetime_from_timestamp(job_tick_data.timestamp), tick_body=serialize_dagster_namedtuple(job_tick_data), ) result = conn.execute(tick_insert) tick_id = result.inserted_primary_key[0] return JobTick(tick_id, job_tick_data) except db.exc.IntegrityError as exc: - six.raise_from( - DagsterInvariantViolationError( - "Unable to insert JobTick for job {job_name} in storage".format( - job_name=job_tick_data.job_name, - ) - ), - exc, - ) + raise DagsterInvariantViolationError( + f"Unable to insert JobTick for job {job_tick_data.job_name} in storage" + ) from exc def update_job_tick(self, tick): check.inst_param(tick, "tick", JobTick) with self.connect() as conn: conn.execute( JobTickTable.update() # pylint: disable=no-value-for-parameter .where(JobTickTable.c.id == tick.tick_id) .values( status=tick.status.value, type=tick.job_type.value, timestamp=utc_datetime_from_timestamp(tick.timestamp), tick_body=serialize_dagster_namedtuple(tick.job_tick_data), ) ) return tick def purge_job_ticks(self, job_origin_id, tick_status, before): check.str_param(job_origin_id, "job_origin_id") check.inst_param(tick_status, "tick_status", JobTickStatus) check.inst_param(before, "before", datetime) utc_before = utc_datetime_from_timestamp(before.timestamp()) with self.connect() as conn: conn.execute( JobTickTable.delete() # pylint: disable=no-value-for-parameter .where(JobTickTable.c.status == tick_status.value) .where(JobTickTable.c.timestamp < utc_before) .where(JobTickTable.c.job_origin_id == job_origin_id) ) def get_job_tick_stats(self, job_origin_id): check.str_param(job_origin_id, "job_origin_id") query = ( db.select([JobTickTable.c.status, db.func.count()]) .select_from(JobTickTable) .where(JobTickTable.c.job_origin_id == job_origin_id) .group_by(JobTickTable.c.status) ) rows = self.execute(query) counts = {} for status, count in rows: counts[status] = count return JobTickStatsSnapshot( ticks_started=counts.get(JobTickStatus.STARTED.value, 0), ticks_succeeded=counts.get(JobTickStatus.SUCCESS.value, 0), ticks_skipped=counts.get(JobTickStatus.SKIPPED.value, 0), ticks_failed=counts.get(JobTickStatus.FAILURE.value, 0), ) def wipe(self): """Clears the schedule storage.""" with self.connect() as conn: # https://stackoverflow.com/a/54386260/324449 conn.execute(JobTable.delete()) # pylint: disable=no-value-for-parameter conn.execute(JobTickTable.delete()) # pylint: disable=no-value-for-parameter diff --git a/python_modules/dagster/dagster/core/types/dagster_type.py b/python_modules/dagster/dagster/core/types/dagster_type.py index 7bb68c8e1..fa86922e6 100644 --- a/python_modules/dagster/dagster/core/types/dagster_type.py +++ b/python_modules/dagster/dagster/core/types/dagster_type.py @@ -1,884 +1,883 @@ import typing from abc import 
abstractmethod from enum import Enum as PythonEnum from functools import partial -import six from dagster import check from dagster.builtins import BuiltinEnum from dagster.config.config_type import Array from dagster.config.config_type import Noneable as ConfigNoneable from dagster.core.definitions.events import TypeCheck from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError from dagster.core.storage.type_storage import TypeStoragePlugin from dagster.serdes import whitelist_for_serdes from dagster.utils.backcompat import rename_warning from .builtin_config_schemas import BuiltinSchemas from .config_schema import DagsterTypeLoader, DagsterTypeMaterializer from .marshal import PickleSerializationStrategy, SerializationStrategy @whitelist_for_serdes class DagsterTypeKind(PythonEnum): ANY = "ANY" SCALAR = "SCALAR" LIST = "LIST" NOTHING = "NOTHING" NULLABLE = "NULLABLE" REGULAR = "REGULAR" class DagsterType: """Define a type in dagster. These can be used in the inputs and outputs of solids. Args: type_check_fn (Callable[[TypeCheckContext, Any], [Union[bool, TypeCheck]]]): The function that defines the type check. It takes the value flowing through the input or output of the solid. If it passes, return either ``True`` or a :py:class:`~dagster.TypeCheck` with ``success`` set to ``True``. If it fails, return either ``False`` or a :py:class:`~dagster.TypeCheck` with ``success`` set to ``False``. The first argument must be named ``context`` (or, if unused, ``_``, ``_context``, or ``context_``). Use ``required_resource_keys`` for access to resources. key (Optional[str]): The unique key to identify types programatically. The key property always has a value. If you omit key to the argument to the init function, it instead receives the value of ``name``. If neither ``key`` nor ``name`` is provided, a ``CheckError`` is thrown. In the case of a generic type such as ``List`` or ``Optional``, this is generated programatically based on the type parameters. For most use cases, name should be set and the key argument should not be specified. name (Optional[str]): A unique name given by a user. If ``key`` is ``None``, ``key`` becomes this value. Name is not given in a case where the user does not specify a unique name for this type, such as a generic class. description (Optional[str]): A markdown-formatted string, displayed in tooling. loader (Optional[DagsterTypeLoader]): An instance of a class that inherits from :py:class:`~dagster.DagsterTypeLoader` and can map config data to a value of this type. Specify this argument if you will need to shim values of this type using the config machinery. As a rule, you should use the :py:func:`@dagster_type_loader ` decorator to construct these arguments. materializer (Optional[DagsterTypeMaterializer]): An instance of a class that inherits from :py:class:`~dagster.DagsterTypeMaterializer` and can persist values of this type. As a rule, you should use the :py:func:`@dagster_type_materializer ` decorator to construct these arguments. serialization_strategy (Optional[SerializationStrategy]): An instance of a class that inherits from :py:class:`~dagster.SerializationStrategy`. The default strategy for serializing this value when automatically persisting it between execution steps. You should set this value if the ordinary serialization machinery (e.g., pickle) will not be adequate for this type. 
auto_plugins (Optional[List[Type[TypeStoragePlugin]]]): If types must be serialized differently depending on the storage being used for intermediates, they should specify this argument. In these cases the serialization_strategy argument is not sufficient because serialization requires specialized API calls, e.g. to call an S3 API directly instead of using a generic file object. See ``dagster_pyspark.DataFrame`` for an example. required_resource_keys (Optional[Set[str]]): Resource keys required by the ``type_check_fn``. is_builtin (bool): Defaults to False. This is used by tools to display or filter built-in types (such as :py:class:`~dagster.String`, :py:class:`~dagster.Int`) to visually distinguish them from user-defined types. Meant for internal use. kind (DagsterTypeKind): Defaults to None. This is used to determine the kind of runtime type for InputDefinition and OutputDefinition type checking. """ def __init__( self, type_check_fn, key=None, name=None, is_builtin=False, description=None, loader=None, materializer=None, serialization_strategy=None, auto_plugins=None, required_resource_keys=None, kind=DagsterTypeKind.REGULAR, ): check.opt_str_param(key, "key") check.opt_str_param(name, "name") check.invariant(not (name is None and key is None), "Must set key or name") if name is None: check.param_invariant( bool(key), "key", "If name is not provided, must provide key.", ) self.key, self._name = key, None elif key is None: check.param_invariant( bool(name), "name", "If key is not provided, must provide name.", ) self.key, self._name = name, name else: check.invariant(key and name) self.key, self._name = key, name self.description = check.opt_str_param(description, "description") self.loader = check.opt_inst_param(loader, "loader", DagsterTypeLoader) self.materializer = check.opt_inst_param( materializer, "materializer", DagsterTypeMaterializer ) self.serialization_strategy = check.opt_inst_param( serialization_strategy, "serialization_strategy", SerializationStrategy, PickleSerializationStrategy(), ) self.required_resource_keys = check.opt_set_param( required_resource_keys, "required_resource_keys", ) self._type_check_fn = check.callable_param(type_check_fn, "type_check_fn") _validate_type_check_fn(self._type_check_fn, self._name) auto_plugins = check.opt_list_param(auto_plugins, "auto_plugins", of_type=type) check.param_invariant( all( issubclass(auto_plugin_type, TypeStoragePlugin) for auto_plugin_type in auto_plugins ), "auto_plugins", ) self.auto_plugins = auto_plugins self.is_builtin = check.bool_param(is_builtin, "is_builtin") check.invariant( self.display_name is not None, "All types must have a valid display name, got None for key {}".format(key), ) self.kind = check.inst_param(kind, "kind", DagsterTypeKind) def type_check(self, context, value): retval = self._type_check_fn(context, value) if not isinstance(retval, (bool, TypeCheck)): raise DagsterInvariantViolationError( ( "You have returned {retval} of type {retval_type} from the type " 'check function of type "{type_key}". Return value must be instance ' "of TypeCheck or a bool." 
).format(retval=repr(retval), retval_type=type(retval), type_key=self.key) ) return TypeCheck(success=retval) if isinstance(retval, bool) else retval def __eq__(self, other): return isinstance(other, DagsterType) and self.key == other.key def __ne__(self, other): return not self.__eq__(other) @staticmethod def from_builtin_enum(builtin_enum): check.invariant(BuiltinEnum.contains(builtin_enum), "must be member of BuiltinEnum") return _RUNTIME_MAP[builtin_enum] @property def display_name(self): """Asserted in __init__ to be not None, overridden in many subclasses""" return self._name @property def unique_name(self): """The unique name of this type. Can be None if the type is not unique, such as container types""" check.invariant( self._name is not None, "unique_name requested but is None for type {}".format(self.display_name), ) return self._name @property def has_unique_name(self): return self._name is not None @property def inner_types(self): return [] @property def input_hydration_schema_key(self): rename_warning("loader_schema_key", "input_hydration_schema_key", "0.10.0") return self.loader_schema_key @property def loader_schema_key(self): return self.loader.schema_type.key if self.loader else None @property def output_materialization_schema_key(self): rename_warning("materializer_schema_key", "output_materialization_schema_key", "0.10.0") return self.materializer_schema_key @property def materializer_schema_key(self): return self.materializer.schema_type.key if self.materializer else None @property def type_param_keys(self): return [] @property def is_nothing(self): return self.kind == DagsterTypeKind.NOTHING @property def supports_fan_in(self): return False def get_inner_type_for_fan_in(self): check.invariant( "DagsterType {name} does not support fan-in, should have checked supports_fan_in before calling getter.".format( name=self.display_name ) ) def _validate_type_check_fn(fn, name): from dagster.seven import get_args args = get_args(fn) # py2 doesn't filter out self if len(args) >= 1 and args[0] == "self": args = args[1:] if len(args) == 2: possible_names = { "_", "context", "_context", "context_", } if args[0] not in possible_names: DagsterInvalidDefinitionError( 'type_check function on type "{name}" must have first ' 'argument named "context" (or _, _context, context_).'.format(name=name,) ) return True raise DagsterInvalidDefinitionError( 'type_check_fn argument on type "{name}" must take 2 arguments, ' "received {count}.".format(name=name, count=len(args)) ) class BuiltinScalarDagsterType(DagsterType): def __init__(self, name, type_check_fn, *args, **kwargs): super(BuiltinScalarDagsterType, self).__init__( key=name, name=name, kind=DagsterTypeKind.SCALAR, type_check_fn=type_check_fn, is_builtin=True, *args, **kwargs, ) def type_check_fn(self, _context, value): return self.type_check_scalar_value(value) @abstractmethod def type_check_scalar_value(self, _value): raise NotImplementedError() class _Int(BuiltinScalarDagsterType): def __init__(self): super(_Int, self).__init__( name="Int", loader=BuiltinSchemas.INT_INPUT, materializer=BuiltinSchemas.INT_OUTPUT, type_check_fn=self.type_check_fn, ) def type_check_scalar_value(self, value): - return _fail_if_not_of_type(value, six.integer_types, "int") + return _fail_if_not_of_type(value, int, "int") def _typemismatch_error_str(value, expected_type_desc): return 'Value "{value}" of python type "{python_type}" must be a {type_desc}.'.format( value=value, python_type=type(value).__name__, type_desc=expected_type_desc ) def 
_fail_if_not_of_type(value, value_type, value_type_desc): if not isinstance(value, value_type): return TypeCheck(success=False, description=_typemismatch_error_str(value, value_type_desc)) return TypeCheck(success=True) class _String(BuiltinScalarDagsterType): def __init__(self): super(_String, self).__init__( name="String", loader=BuiltinSchemas.STRING_INPUT, materializer=BuiltinSchemas.STRING_OUTPUT, type_check_fn=self.type_check_fn, ) def type_check_scalar_value(self, value): return _fail_if_not_of_type(value, str, "string") class _Float(BuiltinScalarDagsterType): def __init__(self): super(_Float, self).__init__( name="Float", loader=BuiltinSchemas.FLOAT_INPUT, materializer=BuiltinSchemas.FLOAT_OUTPUT, type_check_fn=self.type_check_fn, ) def type_check_scalar_value(self, value): return _fail_if_not_of_type(value, float, "float") class _Bool(BuiltinScalarDagsterType): def __init__(self): super(_Bool, self).__init__( name="Bool", loader=BuiltinSchemas.BOOL_INPUT, materializer=BuiltinSchemas.BOOL_OUTPUT, type_check_fn=self.type_check_fn, ) def type_check_scalar_value(self, value): return _fail_if_not_of_type(value, bool, "bool") class Anyish(DagsterType): def __init__( self, key, name, loader=None, materializer=None, serialization_strategy=None, is_builtin=False, description=None, auto_plugins=None, ): super(Anyish, self).__init__( key=key, name=name, kind=DagsterTypeKind.ANY, loader=loader, materializer=materializer, serialization_strategy=serialization_strategy, is_builtin=is_builtin, type_check_fn=self.type_check_method, description=description, auto_plugins=auto_plugins, ) def type_check_method(self, _context, _value): return TypeCheck(success=True) @property def supports_fan_in(self): return True def get_inner_type_for_fan_in(self): # Anyish all the way down return self class _Any(Anyish): def __init__(self): super(_Any, self).__init__( key="Any", name="Any", loader=BuiltinSchemas.ANY_INPUT, materializer=BuiltinSchemas.ANY_OUTPUT, is_builtin=True, ) def create_any_type( name, loader=None, materializer=None, serialization_strategy=None, description=None, auto_plugins=None, ): return Anyish( key=name, name=name, description=description, loader=loader, materializer=materializer, serialization_strategy=serialization_strategy, auto_plugins=auto_plugins, ) class _Nothing(DagsterType): def __init__(self): super(_Nothing, self).__init__( key="Nothing", name="Nothing", kind=DagsterTypeKind.NOTHING, loader=None, materializer=None, type_check_fn=self.type_check_method, is_builtin=True, ) def type_check_method(self, _context, value): if value is not None: return TypeCheck( success=False, description="Value must be None, got a {value_type}".format(value_type=type(value)), ) return TypeCheck(success=True) @property def supports_fan_in(self): return True def get_inner_type_for_fan_in(self): return self class PythonObjectDagsterType(DagsterType): """Define a type in dagster whose typecheck is an isinstance check. Specifically, the type can either be a single python type (e.g. int), or a tuple of types (e.g. (int, float)) which is treated as a union. Examples: .. code-block:: python ntype = PythonObjectDagsterType(python_type=int) assert ntype.name == 'int' assert_success(ntype, 1) assert_failure(ntype, 'a') .. 
code-block:: python ntype = PythonObjectDagsterType(python_type=(int, float)) assert ntype.name == 'Union[int, float]' assert_success(ntype, 1) assert_success(ntype, 1.5) assert_failure(ntype, 'a') Args: python_type (Union[Type, Tuple[Type, ...]): The dagster typecheck function calls instanceof on this type. name (Optional[str]): Name the type. Defaults to the name of ``python_type``. key (Optional[str]): Key of the type. Defaults to name. description (Optional[str]): A markdown-formatted string, displayed in tooling. loader (Optional[DagsterTypeLoader]): An instance of a class that inherits from :py:class:`~dagster.DagsterTypeLoader` and can map config data to a value of this type. Specify this argument if you will need to shim values of this type using the config machinery. As a rule, you should use the :py:func:`@dagster_type_loader ` decorator to construct these arguments. materializer (Optional[DagsterTypeMaterializer]): An instance of a class that inherits from :py:class:`~dagster.DagsterTypeMaterializer` and can persist values of this type. As a rule, you should use the :py:func:`@dagster_type_mate ` decorator to construct these arguments. serialization_strategy (Optional[SerializationStrategy]): An instance of a class that inherits from :py:class:`SerializationStrategy`. The default strategy for serializing this value when automatically persisting it between execution steps. You should set this value if the ordinary serialization machinery (e.g., pickle) will not be adequate for this type. auto_plugins (Optional[List[Type[TypeStoragePlugin]]]): If types must be serialized differently depending on the storage being used for intermediates, they should specify this argument. In these cases the serialization_strategy argument is not sufficient because serialization requires specialized API calls, e.g. to call an S3 API directly instead of using a generic file object. See ``dagster_pyspark.DataFrame`` for an example. """ def __init__(self, python_type, key=None, name=None, **kwargs): if isinstance(python_type, tuple): self.python_type = check.tuple_param( python_type, "python_type", of_type=tuple(type for item in python_type) ) self.type_str = "Union[{}]".format( ", ".join(python_type.__name__ for python_type in python_type) ) else: self.python_type = check.type_param(python_type, "python_type") self.type_str = python_type.__name__ name = check.opt_str_param(name, "name", self.type_str) key = check.opt_str_param(key, "key", name) super(PythonObjectDagsterType, self).__init__( key=key, name=name, type_check_fn=self.type_check_method, **kwargs ) def type_check_method(self, _context, value): if not isinstance(value, self.python_type): return TypeCheck( success=False, description=( "Value of type {value_type} failed type check for Dagster type {dagster_type}, " "expected value to be of Python type {expected_type}." 
).format( value_type=type(value), dagster_type=self._name, expected_type=self.type_str, ), ) return TypeCheck(success=True) class NoneableInputSchema(DagsterTypeLoader): def __init__(self, inner_dagster_type): self._inner_dagster_type = check.inst_param( inner_dagster_type, "inner_dagster_type", DagsterType ) check.param_invariant(inner_dagster_type.loader, "inner_dagster_type") self._schema_type = ConfigNoneable(inner_dagster_type.loader.schema_type) @property def schema_type(self): return self._schema_type def construct_from_config_value(self, context, config_value): if config_value is None: return None return self._inner_dagster_type.loader.construct_from_config_value(context, config_value) def _create_nullable_input_schema(inner_type): if not inner_type.loader: return None return NoneableInputSchema(inner_type) class OptionalType(DagsterType): def __init__(self, inner_type): inner_type = resolve_dagster_type(inner_type) if inner_type is Nothing: raise DagsterInvalidDefinitionError( "Type Nothing can not be wrapped in List or Optional" ) key = "Optional." + inner_type.key self.inner_type = inner_type super(OptionalType, self).__init__( key=key, name=None, kind=DagsterTypeKind.NULLABLE, type_check_fn=self.type_check_method, loader=_create_nullable_input_schema(inner_type), ) @property def display_name(self): return self.inner_type.display_name + "?" def type_check_method(self, context, value): return ( TypeCheck(success=True) if value is None else self.inner_type.type_check(context, value) ) @property def inner_types(self): return [self.inner_type] + self.inner_type.inner_types @property def type_param_keys(self): return [self.inner_type.key] @property def supports_fan_in(self): return self.inner_type.supports_fan_in def get_inner_type_for_fan_in(self): return self.inner_type.get_inner_type_for_fan_in() class ListInputSchema(DagsterTypeLoader): def __init__(self, inner_dagster_type): self._inner_dagster_type = check.inst_param( inner_dagster_type, "inner_dagster_type", DagsterType ) check.param_invariant(inner_dagster_type.loader, "inner_dagster_type") self._schema_type = Array(inner_dagster_type.loader.schema_type) @property def schema_type(self): return self._schema_type def construct_from_config_value(self, context, config_value): convert_item = partial(self._inner_dagster_type.loader.construct_from_config_value, context) return list(map(convert_item, config_value)) def _create_list_input_schema(inner_type): if not inner_type.loader: return None return ListInputSchema(inner_type) class ListType(DagsterType): def __init__(self, inner_type): key = "List." 
+ inner_type.key self.inner_type = inner_type super(ListType, self).__init__( key=key, name=None, kind=DagsterTypeKind.LIST, type_check_fn=self.type_check_method, loader=_create_list_input_schema(inner_type), ) @property def display_name(self): return "[" + self.inner_type.display_name + "]" def type_check_method(self, context, value): value_check = _fail_if_not_of_type(value, list, "list") if not value_check.success: return value_check for item in value: item_check = self.inner_type.type_check(context, item) if not item_check.success: return item_check return TypeCheck(success=True) @property def inner_types(self): return [self.inner_type] + self.inner_type.inner_types @property def type_param_keys(self): return [self.inner_type.key] @property def supports_fan_in(self): return True def get_inner_type_for_fan_in(self): return self.inner_type class DagsterListApi: def __getitem__(self, inner_type): check.not_none_param(inner_type, "inner_type") return _List(resolve_dagster_type(inner_type)) def __call__(self, inner_type): check.not_none_param(inner_type, "inner_type") return _List(inner_type) List = DagsterListApi() def _List(inner_type): check.inst_param(inner_type, "inner_type", DagsterType) if inner_type is Nothing: raise DagsterInvalidDefinitionError("Type Nothing can not be wrapped in List or Optional") return ListType(inner_type) class Stringish(DagsterType): def __init__(self, key=None, name=None, **kwargs): name = check.opt_str_param(name, "name", type(self).__name__) key = check.opt_str_param(key, "key", name) super(Stringish, self).__init__( key=key, name=name, kind=DagsterTypeKind.SCALAR, type_check_fn=self.type_check_method, loader=BuiltinSchemas.STRING_INPUT, materializer=BuiltinSchemas.STRING_OUTPUT, **kwargs, ) def type_check_method(self, _context, value): return _fail_if_not_of_type(value, str, "string") def create_string_type(name, description=None): return Stringish(name=name, key=name, description=description) Any = _Any() Bool = _Bool() Float = _Float() Int = _Int() String = _String() Nothing = _Nothing() _RUNTIME_MAP = { BuiltinEnum.ANY: Any, BuiltinEnum.BOOL: Bool, BuiltinEnum.FLOAT: Float, BuiltinEnum.INT: Int, BuiltinEnum.STRING: String, BuiltinEnum.NOTHING: Nothing, } _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY: typing.Dict[type, DagsterType] = {} """Python types corresponding to user-defined RunTime types created using @map_to_dagster_type or as_dagster_type are registered here so that we can remap the Python types to runtime types.""" def make_python_type_usable_as_dagster_type(python_type, dagster_type): """ Take any existing python type and map it to a dagster type (generally created with :py:class:`DagsterType `) This can only be called once on a given python type. """ check.inst_param(dagster_type, "dagster_type", DagsterType) if ( _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY.get(python_type, dagster_type) is not dagster_type ): # This would be just a great place to insert a short URL pointing to the type system # documentation into the error message # https://github.com/dagster-io/dagster/issues/1831 raise DagsterInvalidDefinitionError( ( "A Dagster type has already been registered for the Python type " "{python_type}. make_python_type_usable_as_dagster_type can only " "be called once on a python type as it is registering a 1:1 mapping " "between that python type and a dagster type." 
).format(python_type=python_type) ) _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY[python_type] = dagster_type DAGSTER_INVALID_TYPE_ERROR_MESSAGE = ( "Invalid type: dagster_type must be DagsterType, a python scalar, or a python type " "that has been marked usable as a dagster type via @usable_dagster_type or " "make_python_type_usable_as_dagster_type: got {dagster_type}{additional_msg}" ) def resolve_dagster_type(dagster_type): # circular dep from .python_dict import PythonDict, Dict from .python_set import PythonSet, DagsterSetApi from .python_tuple import PythonTuple, DagsterTupleApi from .transform_typing import transform_typing_type from dagster.config.config_type import ConfigType from dagster.primitive_mapping import ( remap_python_builtin_for_runtime, is_supported_runtime_python_builtin, ) from dagster.utils.typing_api import is_typing_type check.invariant( not (isinstance(dagster_type, type) and issubclass(dagster_type, ConfigType)), "Cannot resolve a config type to a runtime type", ) check.invariant( not (isinstance(dagster_type, type) and issubclass(dagster_type, DagsterType)), "Do not pass runtime type classes. Got {}".format(dagster_type), ) # First check to see if it part of python's typing library if is_typing_type(dagster_type): dagster_type = transform_typing_type(dagster_type) if isinstance(dagster_type, DagsterType): return dagster_type # Test for unhashable objects -- this is if, for instance, someone has passed us an instance of # a dict where they meant to pass dict or Dict, etc. try: hash(dagster_type) except TypeError: raise DagsterInvalidDefinitionError( DAGSTER_INVALID_TYPE_ERROR_MESSAGE.format( additional_msg=( ", which isn't hashable. Did you pass an instance of a type instead of " "the type?" ), dagster_type=str(dagster_type), ) ) if is_supported_runtime_python_builtin(dagster_type): return remap_python_builtin_for_runtime(dagster_type) if dagster_type is None: return Any if dagster_type in _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY: return _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY[dagster_type] if dagster_type is Dict: return PythonDict if isinstance(dagster_type, DagsterTupleApi): return PythonTuple if isinstance(dagster_type, DagsterSetApi): return PythonSet if isinstance(dagster_type, DagsterListApi): return List(Any) if BuiltinEnum.contains(dagster_type): return DagsterType.from_builtin_enum(dagster_type) if not isinstance(dagster_type, type): raise DagsterInvalidDefinitionError( DAGSTER_INVALID_TYPE_ERROR_MESSAGE.format( dagster_type=str(dagster_type), additional_msg="." ) ) raise DagsterInvalidDefinitionError( "{dagster_type} is not a valid dagster type.".format(dagster_type=dagster_type) ) ALL_RUNTIME_BUILTINS = list(_RUNTIME_MAP.values()) def construct_dagster_type_dictionary(solid_defs): type_dict_by_name = {t.unique_name: t for t in ALL_RUNTIME_BUILTINS} type_dict_by_key = {t.key: t for t in ALL_RUNTIME_BUILTINS} for solid_def in solid_defs: for dagster_type in solid_def.all_dagster_types(): # We don't do uniqueness check on key because with classes # like Array, Noneable, etc, those are ephemeral objects # and it is perfectly fine to have many of them. type_dict_by_key[dagster_type.key] = dagster_type if not dagster_type.has_unique_name: continue if dagster_type.unique_name not in type_dict_by_name: type_dict_by_name[dagster_type.unique_name] = dagster_type continue if type_dict_by_name[dagster_type.unique_name] is not dagster_type: raise DagsterInvalidDefinitionError( ( 'You have created two dagster types with the same name "{type_name}". 
' "Dagster types have must have unique names." ).format(type_name=dagster_type.display_name) ) return type_dict_by_key class DagsterOptionalApi: def __getitem__(self, inner_type): check.not_none_param(inner_type, "inner_type") return OptionalType(inner_type) Optional = DagsterOptionalApi() diff --git a/python_modules/dagster/dagster/serdes/__init__.py b/python_modules/dagster/dagster/serdes/__init__.py index 76f27dade..f9f8893bf 100644 --- a/python_modules/dagster/dagster/serdes/__init__.py +++ b/python_modules/dagster/dagster/serdes/__init__.py @@ -1,434 +1,434 @@ """ Serialization & deserialization for Dagster objects. Why have custom serialization? * Default json serialization doesn't work well on namedtuples, which we use extensively to create immutable value types. Namedtuples serialize like tuples as flat lists. * Explicit whitelisting should help ensure we are only persisting or communicating across a serialization boundary the types we expect to. Why not pickle? * This isn't meant to replace pickle in the conditions that pickle is reasonable to use (in memory, not human readable, etc) just handle the json case effectively. """ import hashlib import importlib import sys from abc import ABC, abstractmethod, abstractproperty from collections import namedtuple from enum import Enum from inspect import Parameter, signature import yaml from dagster import check, seven from dagster.utils import compose _WHITELIST_MAP = { "types": {"tuple": {}, "enum": {}}, "persistence": {}, } def create_snapshot_id(snapshot): json_rep = serialize_dagster_namedtuple(snapshot) m = hashlib.sha1() # so that hexdigest is 40, not 64 bytes - m.update(json_rep.encode()) + m.update(json_rep.encode("utf-8")) return m.hexdigest() def serialize_pp(value): return serialize_dagster_namedtuple(value, indent=2, separators=(",", ": ")) def register_serdes_tuple_fallbacks(fallback_map): for class_name, klass in fallback_map.items(): _WHITELIST_MAP["types"]["tuple"][class_name] = klass def _get_dunder_new_params_dict(klass): return signature(klass.__new__).parameters def _get_dunder_new_params(klass): return list(_get_dunder_new_params_dict(klass).values()) class SerdesClassUsageError(Exception): pass class Persistable(ABC): def to_storage_value(self): return default_to_storage_value(self, _WHITELIST_MAP) @classmethod def from_storage_dict(cls, storage_dict): return default_from_storage_dict(cls, storage_dict) def _check_serdes_tuple_class_invariants(klass): dunder_new_params = _get_dunder_new_params(klass) cls_param = dunder_new_params[0] def _with_header(msg): return "For namedtuple {class_name}: {msg}".format(class_name=klass.__name__, msg=msg) if cls_param.name not in {"cls", "_cls"}: raise SerdesClassUsageError( _with_header( 'First parameter must be _cls or cls. Got "{name}".'.format(name=cls_param.name) ) ) value_params = dunder_new_params[1:] for index, field in enumerate(klass._fields): if index >= len(value_params): error_msg = ( "Missing parameters to __new__. You have declared fields " "in the named tuple that are not present as parameters to the " "to the __new__ method. In order for " "both serdes serialization and pickling to work, " "these must match. Missing: {missing_fields}" ).format(missing_fields=repr(list(klass._fields[index:]))) raise SerdesClassUsageError(_with_header(error_msg)) value_param = value_params[index] if value_param.name != field: error_msg = ( "Params to __new__ must match the order of field declaration in the namedtuple. 
" 'Declared field number {one_based_index} in the namedtuple is "{field_name}". ' 'Parameter {one_based_index} in __new__ method is "{param_name}".' ).format(one_based_index=index + 1, field_name=field, param_name=value_param.name) raise SerdesClassUsageError(_with_header(error_msg)) if len(value_params) > len(klass._fields): # Ensure that remaining parameters have default values for extra_param_index in range(len(klass._fields), len(value_params) - 1): if value_params[extra_param_index].default == Parameter.empty: error_msg = ( 'Parameter "{param_name}" is a parameter to the __new__ ' "method but is not a field in this namedtuple. The only " "reason why this should exist is that " "it is a field that used to exist (we refer to this as the graveyard) " "but no longer does. However it might exist in historical storage. This " "parameter existing ensures that serdes continues to work. However these " "must come at the end and have a default value for pickling to work." ).format(param_name=value_params[extra_param_index].name) raise SerdesClassUsageError(_with_header(error_msg)) def _whitelist_for_persistence(whitelist_map): def __whitelist_for_persistence(klass): check.subclass_param(klass, "klass", Persistable) whitelist_map["persistence"][klass.__name__] = klass return klass return __whitelist_for_persistence def _whitelist_for_serdes(whitelist_map): def __whitelist_for_serdes(klass): if issubclass(klass, Enum): whitelist_map["types"]["enum"][klass.__name__] = klass elif issubclass(klass, tuple): _check_serdes_tuple_class_invariants(klass) whitelist_map["types"]["tuple"][klass.__name__] = klass else: check.failed("Can not whitelist class {klass} for serdes".format(klass=klass)) return klass return __whitelist_for_serdes def whitelist_for_serdes(klass): check.class_param(klass, "klass") return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(klass) def whitelist_for_persistence(klass): check.class_param(klass, "klass") return compose( _whitelist_for_persistence(whitelist_map=_WHITELIST_MAP), _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP), )(klass) def pack_value(val): return _pack_value(val, whitelist_map=_WHITELIST_MAP) def _pack_value(val, whitelist_map): if isinstance(val, list): return [_pack_value(i, whitelist_map) for i in val] if isinstance(val, tuple): klass_name = val.__class__.__name__ check.invariant( klass_name in whitelist_map["types"]["tuple"], "Can only serialize whitelisted namedtuples, received tuple {}".format(val), ) if klass_name in whitelist_map["persistence"]: return val.to_storage_value() base_dict = {key: _pack_value(value, whitelist_map) for key, value in val._asdict().items()} base_dict["__class__"] = klass_name return base_dict if isinstance(val, Enum): klass_name = val.__class__.__name__ check.invariant( klass_name in whitelist_map["types"]["enum"], "Can only serialize whitelisted Enums, received {}".format(klass_name), ) return {"__enum__": str(val)} if isinstance(val, set): return {"__set__": [_pack_value(item, whitelist_map) for item in val]} if isinstance(val, frozenset): return {"__frozenset__": [_pack_value(item, whitelist_map) for item in val]} if isinstance(val, dict): return {key: _pack_value(value, whitelist_map) for key, value in val.items()} return val def _serialize_dagster_namedtuple(nt, whitelist_map, **json_kwargs): return seven.json.dumps(_pack_value(nt, whitelist_map), **json_kwargs) def serialize_value(val): return seven.json.dumps(_pack_value(val, whitelist_map=_WHITELIST_MAP)) def deserialize_value(val): return _unpack_value( 
seven.json.loads(check.str_param(val, "val")), whitelist_map=_WHITELIST_MAP, ) def serialize_dagster_namedtuple(nt, **json_kwargs): return _serialize_dagster_namedtuple( check.tuple_param(nt, "nt"), whitelist_map=_WHITELIST_MAP, **json_kwargs ) def unpack_value(val): return _unpack_value(val, whitelist_map=_WHITELIST_MAP,) def _unpack_value(val, whitelist_map): if isinstance(val, list): return [_unpack_value(i, whitelist_map) for i in val] if isinstance(val, dict) and val.get("__class__"): klass_name = val.pop("__class__") if klass_name not in whitelist_map["types"]["tuple"]: check.failed( 'Attempted to deserialize class "{}" which is not in the serdes whitelist.'.format( klass_name ) ) klass = whitelist_map["types"]["tuple"][klass_name] if klass is None: return None unpacked_val = {key: _unpack_value(value, whitelist_map) for key, value in val.items()} if klass_name in whitelist_map["persistence"]: return klass.from_storage_dict(unpacked_val) # Naively implements backwards compatibility by filtering arguments that aren't present in # the constructor. If a property is present in the serialized object, but doesn't exist in # the version of the class loaded into memory, that property will be completely ignored. # The call to seven.get_args turns out to be pretty expensive -- we should probably turn # to, e.g., manually managing the deprecated keys on the serdes constructor. args_for_class = seven.get_args(klass) filtered_val = {k: v for k, v in unpacked_val.items() if k in args_for_class} return klass(**filtered_val) if isinstance(val, dict) and val.get("__enum__"): name, member = val["__enum__"].split(".") return getattr(whitelist_map["types"]["enum"][name], member) if isinstance(val, dict) and val.get("__set__") is not None: return set([_unpack_value(item, whitelist_map) for item in val["__set__"]]) if isinstance(val, dict) and val.get("__frozenset__") is not None: return frozenset([_unpack_value(item, whitelist_map) for item in val["__frozenset__"]]) if isinstance(val, dict): return {key: _unpack_value(value, whitelist_map) for key, value in val.items()} return val def deserialize_json_to_dagster_namedtuple(json_str): dagster_namedtuple = _deserialize_json_to_dagster_namedtuple( check.str_param(json_str, "json_str"), whitelist_map=_WHITELIST_MAP ) check.invariant( isinstance(dagster_namedtuple, tuple), "Output of deserialized json_str was not a namedtuple. Received type {}.".format( type(dagster_namedtuple) ), ) return dagster_namedtuple def _deserialize_json_to_dagster_namedtuple(json_str, whitelist_map): return _unpack_value(seven.json.loads(json_str), whitelist_map=whitelist_map) def default_to_storage_value(value, whitelist_map): base_dict = {key: _pack_value(value, whitelist_map) for key, value in value._asdict().items()} base_dict["__class__"] = value.__class__.__name__ return base_dict def default_from_storage_dict(cls, storage_dict): return cls.__new__(cls, **storage_dict) @whitelist_for_serdes class ConfigurableClassData( namedtuple("_ConfigurableClassData", "module_name class_name config_yaml") ): """Serializable tuple describing where to find a class and the config fragment that should be used to instantiate it. Users should not instantiate this class directly. Classes intended to be serialized in this way should implement the :py:class:`dagster.serdes.ConfigurableClass` mixin. 
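# --- Illustrative sketch, not part of the diff: a minimal round trip through the serdes
# helpers defined above. ExampleRecord is a hypothetical name; the bare namedtuple base
# already satisfies _check_serdes_tuple_class_invariants because its generated __new__
# takes (_cls, name, count), matching the declared fields.
from collections import namedtuple

from dagster.serdes import (
    deserialize_json_to_dagster_namedtuple,
    serialize_dagster_namedtuple,
    whitelist_for_serdes,
)


@whitelist_for_serdes
class ExampleRecord(namedtuple("_ExampleRecord", "name count")):
    pass


record = ExampleRecord(name="foo", count=3)
json_str = serialize_dagster_namedtuple(record)  # JSON object tagged with "__class__": "ExampleRecord"
assert deserialize_json_to_dagster_namedtuple(json_str) == record
# On the way back in, keys that __new__ does not accept are dropped, which is how the
# whitelist provides the naive backwards compatibility described in _unpack_value.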
""" def __new__(cls, module_name, class_name, config_yaml): return super(ConfigurableClassData, cls).__new__( cls, check.str_param(module_name, "module_name"), check.str_param(class_name, "class_name"), check.str_param(config_yaml, "config_yaml"), ) def info_str(self, prefix=""): return ( "{p}module: {module}\n" "{p}class: {cls}\n" "{p}config:\n" "{p} {config}".format( p=prefix, module=self.module_name, cls=self.class_name, config=self.config_yaml ) ) def rehydrate(self): from dagster.core.errors import DagsterInvalidConfigError from dagster.config.field import resolve_to_config_type from dagster.config.validate import process_config try: module = importlib.import_module(self.module_name) except ModuleNotFoundError: check.failed( "Couldn't import module {module_name} when attempting to load the " "configurable class {configurable_class}".format( module_name=self.module_name, configurable_class=self.module_name + "." + self.class_name, ) ) try: klass = getattr(module, self.class_name) except AttributeError: check.failed( "Couldn't find class {class_name} in module when attempting to load the " "configurable class {configurable_class}".format( class_name=self.class_name, configurable_class=self.module_name + "." + self.class_name, ) ) if not issubclass(klass, ConfigurableClass): raise check.CheckError( klass, "class {class_name} in module {module_name}".format( class_name=self.class_name, module_name=self.module_name ), ConfigurableClass, ) config_dict = yaml.safe_load(self.config_yaml) result = process_config(resolve_to_config_type(klass.config_type()), config_dict) if not result.success: raise DagsterInvalidConfigError( "Errors whilst loading configuration for {}.".format(klass.config_type()), result.errors, config_dict, ) return klass.from_config_value(self, result.value) class ConfigurableClass(ABC): """Abstract mixin for classes that can be loaded from config. This supports a powerful plugin pattern which avoids both a) a lengthy, hard-to-synchronize list of conditional imports / optional extras_requires in dagster core and b) a magic directory or file in which third parties can place plugin packages. Instead, the intention is to make, e.g., run storage, pluggable with a config chunk like: .. code-block:: yaml run_storage: module: very_cool_package.run_storage class: SplendidRunStorage config: magic_word: "quux" This same pattern should eventually be viable for other system components, e.g. engines. The ``ConfigurableClass`` mixin provides the necessary hooks for classes to be instantiated from an instance of ``ConfigurableClassData``. Pieces of the Dagster system which we wish to make pluggable in this way should consume a config type such as: .. code-block:: python {'module': str, 'class': str, 'config': Field(Permissive())} """ @abstractproperty def inst_data(self): """ Subclass must be able to return the inst_data as a property if it has been constructed through the from_config_value code path. """ @classmethod @abstractmethod def config_type(cls): """dagster.ConfigType: The config type against which to validate a config yaml fragment serialized in an instance of ``ConfigurableClassData``. """ @staticmethod @abstractmethod def from_config_value(inst_data, config_value): """New up an instance of the ConfigurableClass from a validated config value. Called by ConfigurableClassData.rehydrate. Args: config_value (dict): The validated config value to use. Typically this should be the ``value`` attribute of a :py:class:`~dagster.core.types.evaluator.evaluation.EvaluateValueResult`. 
A common pattern is for the implementation to align the config_value with the signature of the ConfigurableClass's constructor: .. code-block:: python @staticmethod def from_config_value(inst_data, config_value): return MyConfigurableClass(inst_data=inst_data, **config_value) """ diff --git a/python_modules/dagster/dagster/utils/__init__.py b/python_modules/dagster/dagster/utils/__init__.py index d2ba0eac3..6a0571a1f 100644 --- a/python_modules/dagster/dagster/utils/__init__.py +++ b/python_modules/dagster/dagster/utils/__init__.py @@ -1,528 +1,526 @@ import contextlib import datetime import errno import functools import inspect import os -import pickle import re import signal import socket import subprocess import sys import tempfile import threading from collections import namedtuple from enum import Enum from warnings import warn import _thread as thread -import six import yaml from dagster import check, seven from dagster.core.errors import DagsterExecutionInterruptedError, DagsterInvariantViolationError from dagster.seven import IS_WINDOWS, multiprocessing from dagster.seven.abc import Mapping from .merger import merge_dicts from .yaml_utils import load_yaml_from_glob_list, load_yaml_from_globs, load_yaml_from_path if sys.version_info > (3,): from pathlib import Path # pylint: disable=import-error else: from pathlib2 import Path # pylint: disable=import-error EPOCH = datetime.datetime.utcfromtimestamp(0) PICKLE_PROTOCOL = 4 DEFAULT_WORKSPACE_YAML_FILENAME = "workspace.yaml" def file_relative_path(dunderfile, relative_path): """ This function is useful when one needs to load a file that is relative to the position of the current file. (Such as when you encode a configuration file path in source file and want in runnable in any current working directory) It is meant to be used like the following: file_relative_path(__file__, 'path/relative/to/file') """ check.str_param(dunderfile, "dunderfile") check.str_param(relative_path, "relative_path") return os.path.join(os.path.dirname(dunderfile), relative_path) def script_relative_path(file_path): """ Useful for testing with local files. Use a path relative to where the test resides and this function will return the absolute path of that file. Otherwise it will be relative to script that ran the test Note: this is function is very, very expensive (on the order of 1 millisecond per invocation) so this should only be used in performance insensitive contexts. Prefer file_relative_path for anything with performance constraints. """ # from http://bit.ly/2snyC6s check.str_param(file_path, "file_path") scriptdir = inspect.stack()[1][1] return os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(scriptdir)), file_path)) # Adapted from https://github.com/okunishinishi/python-stringcase/blob/master/stringcase.py def camelcase(string): check.str_param(string, "string") string = re.sub(r"^[\-_\.]", "", str(string)) if not string: return string return str(string[0]).upper() + re.sub( r"[\-_\.\s]([a-z])", lambda matched: str(matched.group(1)).upper(), string[1:] ) def ensure_single_item(ddict): check.dict_param(ddict, "ddict") check.param_invariant(len(ddict) == 1, "ddict", "Expected dict with single item") return list(ddict.items())[0] @contextlib.contextmanager def pushd(path): old_cwd = os.getcwd() os.chdir(path) try: yield path finally: os.chdir(old_cwd) def safe_isfile(path): """"Backport of Python 3.8 os.path.isfile behavior. This is intended to backport https://docs.python.org/dev/whatsnew/3.8.html#os-path. 
I'm not sure that there are other ways to provoke this behavior on Unix other than the null byte, but there are certainly other ways to do it on Windows. Afaict, we won't mask other ValueErrors, and the behavior in the status quo ante is rough because we risk throwing an unexpected, uncaught ValueError from very deep in our logic. """ try: return os.path.isfile(path) except ValueError: return False def mkdir_p(path): try: os.makedirs(path) return path except OSError as exc: # Python >2.5 if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise class frozendict(dict): def __readonly__(self, *args, **kwargs): raise RuntimeError("Cannot modify ReadOnlyDict") # https://docs.python.org/3/library/pickle.html#object.__reduce__ # # For a dict, the default behavior for pickle is to iteratively call __setitem__ (see 5th item # in __reduce__ tuple). Since we want to disable __setitem__ and still inherit dict, we # override this behavior by defining __reduce__. We return the 3rd item in the tuple, which is # passed to __setstate__, allowing us to restore the frozendict. def __reduce__(self): return (frozendict, (), dict(self)) def __setstate__(self, state): self.__init__(state) __setitem__ = __readonly__ __delitem__ = __readonly__ pop = __readonly__ # type: ignore[assignment] popitem = __readonly__ clear = __readonly__ update = __readonly__ # type: ignore[assignment] setdefault = __readonly__ del __readonly__ class frozenlist(list): def __readonly__(self, *args, **kwargs): raise RuntimeError("Cannot modify ReadOnlyList") # https://docs.python.org/3/library/pickle.html#object.__reduce__ # # Like frozendict, implement __reduce__ and __setstate__ to handle pickling. # Otherwise, __setstate__ will be called to restore the frozenlist, causing # a RuntimeError because frozenlist is not mutable. 
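# --- Illustrative sketch, not part of the diff: what the __reduce__/__setstate__ overrides
# buy us. dict.__init__ repopulates the object without going through the disabled
# __setitem__, so a frozendict survives a pickle round trip while staying read-only.
import pickle

from dagster.utils import frozendict

tags = frozendict({"team": "data"})
assert pickle.loads(pickle.dumps(tags)) == tags

try:
    tags["team"] = "other"
except RuntimeError:
    pass  # "Cannot modify ReadOnlyDict"
# frozenlist (continued just below) follows the same pattern, restoring from list(self).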
def __reduce__(self): return (frozenlist, (), list(self)) def __setstate__(self, state): self.__init__(state) __setitem__ = __readonly__ # type: ignore[assignment] __delitem__ = __readonly__ append = __readonly__ clear = __readonly__ extend = __readonly__ insert = __readonly__ pop = __readonly__ remove = __readonly__ reverse = __readonly__ sort = __readonly__ # type: ignore[assignment] def __hash__(self): return hash(tuple(self)) def make_readonly_value(value): if isinstance(value, list): return frozenlist(list(map(make_readonly_value, value))) elif isinstance(value, dict): return frozendict({key: make_readonly_value(value) for key, value in value.items()}) else: return value def get_prop_or_key(elem, key): if isinstance(elem, Mapping): return elem.get(key) else: return getattr(elem, key) def list_pull(alist, key): return list(map(lambda elem: get_prop_or_key(elem, key), alist)) def all_none(kwargs): for value in kwargs.values(): if value is not None: return False return True def check_script(path, return_code=0): try: subprocess.check_output([sys.executable, path]) except subprocess.CalledProcessError as exc: if return_code != 0: if exc.returncode == return_code: return raise def check_cli_execute_file_pipeline(path, pipeline_fn_name, env_file=None): from dagster.core.test_utils import instance_for_test with instance_for_test(): cli_cmd = [ sys.executable, "-m", "dagster", "pipeline", "execute", "-f", path, "-a", pipeline_fn_name, ] if env_file: cli_cmd.append("-c") cli_cmd.append(env_file) try: subprocess.check_output(cli_cmd) except subprocess.CalledProcessError as cpe: print(cpe) # pylint: disable=print-call raise cpe def safe_tempfile_path_unmanaged(): # This gets a valid temporary file path in the safest possible way, although there is still no # guarantee that another process will not create a file at this path. The NamedTemporaryFile is # deleted when the context manager exits and the file object is closed. # # This is preferable to using NamedTemporaryFile as a context manager and passing the name # attribute of the file object around because NamedTemporaryFiles cannot be opened a second time # if already open on Windows NT or later: # https://docs.python.org/3.8/library/tempfile.html#tempfile.NamedTemporaryFile # https://github.com/dagster-io/dagster/issues/1582 with tempfile.NamedTemporaryFile() as fd: path = fd.name return Path(path).as_posix() @contextlib.contextmanager def safe_tempfile_path(): try: path = safe_tempfile_path_unmanaged() yield path finally: if os.path.exists(path): os.unlink(path) def ensure_gen(thing_or_gen): if not inspect.isgenerator(thing_or_gen): def _gen_thing(): yield thing_or_gen return _gen_thing() return thing_or_gen def ensure_dir(file_path): try: os.makedirs(file_path) except OSError as e: if e.errno != errno.EEXIST: raise def ensure_file(path): ensure_dir(os.path.dirname(path)) if not os.path.exists(path): touch_file(path) def touch_file(path): ensure_dir(os.path.dirname(path)) with open(path, "a"): os.utime(path, None) def _kill_on_event(termination_event): termination_event.wait() send_interrupt() def send_interrupt(): if IS_WINDOWS: # This will raise a KeyboardInterrupt in python land - meaning this wont be able to # interrupt things like sleep() thread.interrupt_main() else: # If on unix send an os level signal to interrupt any situation we may be stuck in os.kill(os.getpid(), signal.SIGINT) # Function to be invoked by daemon thread in processes which seek to be cancellable. 
# The motivation for this approach is to be able to exit cleanly on Windows. An alternative # path is to change how the processes are opened and send CTRL_BREAK signals, which at # the time of authoring seemed a more costly approach. # # Reading for the curious: # * https://stackoverflow.com/questions/35772001/how-to-handle-the-signal-in-python-on-windows-machine # * https://stefan.sofa-rockers.org/2013/08/15/handling-sub-process-hierarchies-python-linux-os-x/ def start_termination_thread(termination_event): check.inst_param(termination_event, "termination_event", ttype=type(multiprocessing.Event())) int_thread = threading.Thread( target=_kill_on_event, args=(termination_event,), name="kill-on-event" ) int_thread.daemon = True int_thread.start() # Executes the next() function within an instance of the supplied context manager class # (leaving the context before yielding each result) def iterate_with_context(context, iterator): while True: # Allow interrupts during user code so that we can terminate slow/hanging steps with context(): try: next_output = next(iterator) except StopIteration: return yield next_output def datetime_as_float(dt): check.inst_param(dt, "dt", datetime.datetime) return float((dt - EPOCH).total_seconds()) # hashable frozen string to string dict class frozentags(frozendict): def __init__(self, *args, **kwargs): super(frozentags, self).__init__(*args, **kwargs) check.dict_param(self, "self", key_type=str, value_type=str) def __hash__(self): return hash(tuple(sorted(self.items()))) def updated_with(self, new_tags): check.dict_param(new_tags, "new_tags", key_type=str, value_type=str) updated = dict(self) for key, value in new_tags.items(): updated[key] = value return frozentags(updated) class EventGenerationManager: """ Utility class that wraps an event generator function, that also yields a single instance of a typed object. All events yielded before the typed object are yielded through the method `generate_setup_events` and all events yielded after the typed object are yielded through the method `generate_teardown_events`. This is used to help replace the context managers used in pipeline initialization with generators so that we can begin emitting initialization events AND construct a pipeline context object, while managing explicit setup/teardown. This does require calling `generate_setup_events` AND `generate_teardown_events` in order to get the typed object. 
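# --- Illustrative sketch, not part of the diff: the calling convention described above,
# with a hypothetical Payload standing in for the typed object.
from dagster.utils import EventGenerationManager


class Payload:
    pass


def _events_then_payload():
    yield "setup event"
    yield Payload()
    yield "teardown event"


manager = EventGenerationManager(_events_then_payload(), Payload)
setup_events = list(manager.generate_setup_events())  # ["setup event"]
payload = manager.get_object()  # the Payload instance yielded mid-stream
teardown_events = list(manager.generate_teardown_events())  # ["teardown event"]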
""" def __init__(self, generator, object_cls, require_object=True): self.generator = check.generator(generator) self.object_cls = check.type_param(object_cls, "object_cls") self.require_object = check.bool_param(require_object, "require_object") self.object = None self.did_setup = False self.did_teardown = False def generate_setup_events(self): self.did_setup = True try: while self.object is None: obj = next(self.generator) if isinstance(obj, self.object_cls): self.object = obj else: yield obj except StopIteration: if self.require_object: check.inst_param( self.object, "self.object", self.object_cls, "generator never yielded object of type {}".format(self.object_cls.__name__), ) def get_object(self): if not self.did_setup: check.failed("Called `get_object` before `generate_setup_events`") return self.object def generate_teardown_events(self): self.did_teardown = True if self.object: yield from self.generator def utc_datetime_from_timestamp(timestamp): tz = None if sys.version_info.major >= 3 and sys.version_info.minor >= 2: from datetime import timezone tz = timezone.utc else: import pytz tz = pytz.utc return datetime.datetime.fromtimestamp(timestamp, tz=tz) def is_enum_value(value): return False if value is None else issubclass(value.__class__, Enum) def git_repository_root(): - return six.ensure_str(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).strip()) + return subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip() def segfault(): """Reliable cross-Python version segfault. https://bugs.python.org/issue1215#msg143236 """ import ctypes ctypes.string_at(0) def find_free_port(): with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] @contextlib.contextmanager def alter_sys_path(to_add, to_remove): to_restore = [path for path in sys.path] # remove paths for path in to_remove: if path in sys.path: sys.path.remove(path) # add paths for path in to_add: sys.path.insert(0, path) try: yield finally: sys.path = to_restore @contextlib.contextmanager def restore_sys_modules(): sys_modules = {k: v for k, v in sys.modules.items()} try: yield finally: to_delete = set(sys.modules) - set(sys_modules) for key in to_delete: del sys.modules[key] def process_is_alive(pid): if IS_WINDOWS: import psutil # pylint: disable=import-error return psutil.pid_exists(pid=pid) else: try: subprocess.check_output(["ps", str(pid)]) except subprocess.CalledProcessError as exc: assert exc.returncode == 1 return False return True def compose(*args): """ Compose python functions args such that compose(f, g)(x) is equivalent to f(g(x)). 
""" # reduce using functional composition over all the arguments, with the identity function as # initializer return functools.reduce(lambda f, g: lambda x: f(g(x)), args, lambda x: x) def dict_without_keys(ddict, *keys): return {key: value for key, value in ddict.items() if key not in set(keys)} diff --git a/python_modules/dagster/dagster/utils/backoff.py b/python_modules/dagster/dagster/utils/backoff.py index ea36e8846..587028726 100644 --- a/python_modules/dagster/dagster/utils/backoff.py +++ b/python_modules/dagster/dagster/utils/backoff.py @@ -1,65 +1,64 @@ import time -import six from dagster import check def backoff_delay_generator(): i = 0.1 while True: yield i i = i * 2 BACKOFF_MAX_RETRIES = 4 def backoff( fn, retry_on, args=None, kwargs=None, max_retries=BACKOFF_MAX_RETRIES, delay_generator=backoff_delay_generator(), ): """Straightforward backoff implementation. Note that this doesn't implement any jitter on the delays, so probably won't be appropriate for very parallel situations. - + Args: fn (Callable): The function to wrap in a backoff/retry loop. retry_on (Tuple[Exception, ...]): The exception classes on which to retry. Note that we don't (yet) have any support for matching the exception messages. args (Optional[List[Any]]): Positional args to pass to the callable. kwargs (Optional[Dict[str, Any]]): Keyword args to pass to the callable. max_retries (Optional[Int]): The maximum number of times to retry a failed fn call. Set to 0 for no backoff. Default: 4 delay_generator (Generator[float, None, None]): Generates the successive delays between retry attempts. """ check.callable_param(fn, "fn") retry_on = check.tuple_param(retry_on, "retry_on") args = check.opt_list_param(args, "args") kwargs = check.opt_dict_param(kwargs, "kwargs", key_type=str) check.int_param(max_retries, "max_retries") check.generator_param(delay_generator, "delay_generator") retries = 0 to_raise = None try: return fn(*args, **kwargs) except retry_on as exc: to_raise = exc while retries < max_retries: - time.sleep(six.next(delay_generator)) + time.sleep(next(delay_generator)) try: return fn(*args, **kwargs) except retry_on as exc: retries += 1 to_raise = exc continue raise to_raise diff --git a/python_modules/dagster/dagster/utils/indenting_printer.py b/python_modules/dagster/dagster/utils/indenting_printer.py index d7c984882..aca352cec 100644 --- a/python_modules/dagster/dagster/utils/indenting_printer.py +++ b/python_modules/dagster/dagster/utils/indenting_printer.py @@ -1,86 +1,86 @@ from contextlib import contextmanager +from io import StringIO from textwrap import TextWrapper from dagster import check -from six import StringIO LINE_LENGTH = 100 class IndentingPrinter: def __init__(self, indent_level=2, printer=print, current_indent=0, line_length=LINE_LENGTH): self.current_indent = current_indent self.indent_level = check.int_param(indent_level, "indent_level") self.printer = check.callable_param(printer, "printer") self.line_length = line_length self._line_so_far = "" def append(self, text): check.str_param(text, "text") self._line_so_far += text def line(self, text): check.str_param(text, "text") self.printer(self.current_indent_str + self._line_so_far + text) self._line_so_far = "" def block(self, text, prefix="", initial_indent=""): """Automagically wrap a block of text.""" wrapper = TextWrapper( width=self.line_length - len(self.current_indent_str), initial_indent=initial_indent, subsequent_indent=prefix, break_long_words=False, break_on_hyphens=False, ) for line in wrapper.wrap(text): 
self.line(line) def comment(self, text): self.block(text, prefix="# ", initial_indent="# ") @property def current_indent_str(self): return " " * self.current_indent def blank_line(self): check.invariant( not self._line_so_far, "Cannot throw away appended strings by calling blank_line" ) self.printer("") def increase_indent(self): self.current_indent += self.indent_level def decrease_indent(self): if self.indent_level and self.current_indent <= 0: raise Exception("indent cannot be negative") self.current_indent -= self.indent_level @contextmanager def with_indent(self, text=None): if text is not None: self.line(text) self.increase_indent() yield self.decrease_indent() class IndentingStringIoPrinter(IndentingPrinter): """Subclass of IndentingPrinter wrapping a StringIO.""" def __init__(self, **kwargs): self.buffer = StringIO() self.printer = lambda x: self.buffer.write(x + "\n") super(IndentingStringIoPrinter, self).__init__(printer=self.printer, **kwargs) def __enter__(self): return self def __exit__(self, _exception_type, _exception_value, _traceback): self.buffer.close() def read(self): """Get the value of the backing StringIO.""" return self.buffer.getvalue() diff --git a/python_modules/dagster/dagster/utils/test/postgres_instance.py b/python_modules/dagster/dagster/utils/test/postgres_instance.py index e9498b6dd..e9f223047 100644 --- a/python_modules/dagster/dagster/utils/test/postgres_instance.py +++ b/python_modules/dagster/dagster/utils/test/postgres_instance.py @@ -1,234 +1,234 @@ import os import subprocess import tempfile import warnings from contextlib import contextmanager import pytest from dagster import check, file_relative_path from dagster.core.test_utils import instance_for_test_tempdir from dagster.utils import merge_dicts BUILDKITE = bool(os.getenv("BUILDKITE")) @contextmanager def postgres_instance_for_test(dunder_file, container_name, overrides=None): with tempfile.TemporaryDirectory() as temp_dir: with TestPostgresInstance.docker_service_up_or_skip( file_relative_path(dunder_file, "docker-compose.yml"), container_name, ) as pg_conn_string: TestPostgresInstance.clean_run_storage(pg_conn_string) TestPostgresInstance.clean_event_log_storage(pg_conn_string) TestPostgresInstance.clean_schedule_storage(pg_conn_string) with instance_for_test_tempdir( temp_dir, overrides=merge_dicts( { "run_storage": { "module": "dagster_postgres.run_storage.run_storage", "class": "PostgresRunStorage", "config": {"postgres_url": pg_conn_string}, }, "event_log_storage": { "module": "dagster_postgres.event_log.event_log", "class": "PostgresEventLogStorage", "config": {"postgres_url": pg_conn_string}, }, "schedule_storage": { "module": "dagster_postgres.schedule_storage.schedule_storage", "class": "PostgresScheduleStorage", "config": {"postgres_url": pg_conn_string}, }, }, overrides if overrides else {}, ), ) as instance: yield instance class TestPostgresInstance: @staticmethod def dagster_postgres_installed(): try: import dagster_postgres # pylint: disable=unused-import except ImportError: return False return True @staticmethod def get_hostname(env_name="POSTGRES_TEST_DB_HOST"): # In buildkite we get the ip address from this variable (see buildkite code for commentary) # Otherwise assume local development and assume localhost return os.environ.get(env_name, "localhost") @staticmethod def conn_string(**kwargs): check.invariant( TestPostgresInstance.dagster_postgres_installed(), "dagster_postgres must be installed to test with postgres", ) from dagster_postgres.utils import get_conn_string # 
pylint: disable=import-error return get_conn_string( **dict( dict( username="test", password="test", hostname=TestPostgresInstance.get_hostname(), db_name="test", ), **kwargs, ) ) @staticmethod def clean_run_storage(conn_string): check.invariant( TestPostgresInstance.dagster_postgres_installed(), "dagster_postgres must be installed to test with postgres", ) from dagster_postgres.run_storage import PostgresRunStorage # pylint: disable=import-error storage = PostgresRunStorage.create_clean_storage(conn_string) assert storage return storage @staticmethod def clean_event_log_storage(conn_string): check.invariant( TestPostgresInstance.dagster_postgres_installed(), "dagster_postgres must be installed to test with postgres", ) from dagster_postgres.event_log import ( # pylint: disable=import-error PostgresEventLogStorage, ) storage = PostgresEventLogStorage.create_clean_storage(conn_string) assert storage return storage @staticmethod def clean_schedule_storage(conn_string): check.invariant( TestPostgresInstance.dagster_postgres_installed(), "dagster_postgres must be installed to test with postgres", ) from dagster_postgres.schedule_storage.schedule_storage import ( # pylint: disable=import-error PostgresScheduleStorage, ) storage = PostgresScheduleStorage.create_clean_storage(conn_string) assert storage return storage @staticmethod @contextmanager def docker_service_up(docker_compose_file, service_name, conn_args=None): check.invariant( TestPostgresInstance.dagster_postgres_installed(), "dagster_postgres must be installed to test with postgres", ) check.str_param(service_name, "service_name") check.str_param(docker_compose_file, "docker_compose_file") check.invariant( os.path.isfile(docker_compose_file), "docker_compose_file must specify a valid file" ) conn_args = check.opt_dict_param(conn_args, "conn_args") if conn_args else {} from dagster_postgres.utils import wait_for_connection # pylint: disable=import-error if BUILDKITE: yield TestPostgresInstance.conn_string( **conn_args ) # buildkite docker is handled in pipeline setup return try: subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "stop", service_name] ) subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name] ) except subprocess.CalledProcessError: pass try: subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "up", "-d", service_name], stderr=subprocess.STDOUT, # capture STDERR for error handling ) except subprocess.CalledProcessError as ex: - err_text = ex.output.decode() + err_text = ex.output.decode("utf-8") raise PostgresDockerError( "Failed to launch docker container(s) via docker-compose: {}".format(err_text), ex, ) conn_str = TestPostgresInstance.conn_string(**conn_args) wait_for_connection(conn_str, retry_limit=10, retry_wait=3) yield conn_str try: subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "stop", service_name] ) subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name] ) except subprocess.CalledProcessError: pass @staticmethod @contextmanager def docker_service_up_or_skip(docker_compose_file, service_name, conn_args=None): try: with TestPostgresInstance.docker_service_up( docker_compose_file, service_name, conn_args ) as conn_str: yield conn_str except PostgresDockerError as ex: warnings.warn( "Error launching Dockerized Postgres: {}".format(ex), RuntimeWarning, stacklevel=3 ) pytest.skip("Skipping due to error launching Dockerized Postgres: {}".format(ex)) def 
is_postgres_running(service_name): check.str_param(service_name, "service_name") try: output = subprocess.check_output( [ "docker", "container", "ps", "-f", "name={}".format(service_name), "-f", "status=running", ], stderr=subprocess.STDOUT, # capture STDERR for error handling ) except subprocess.CalledProcessError as ex: - lines = ex.output.decode().split("\n") + lines = ex.output.decode("utf-8").split("\n") if len(lines) == 2 and "Cannot connect to the Docker daemon" in lines[0]: raise PostgresDockerError("Cannot connect to the Docker daemon", ex) else: raise PostgresDockerError( "Could not verify postgres container was running as expected", ex ) - decoded = output.decode() + decoded = output.decode("utf-8") lines = decoded.split("\n") # header, one line for container, trailing \n # if container is found, service_name should appear at the end of the second line of output return len(lines) == 3 and lines[1].endswith(service_name) class PostgresDockerError(Exception): def __init__(self, message, subprocess_error): super(PostgresDockerError, self).__init__(check.opt_str_param(message, "message")) self.subprocess_error = check.inst_param( subprocess_error, "subprocess_error", subprocess.CalledProcessError ) diff --git a/python_modules/dagster/dagster_tests/cli_tests/test_version.py b/python_modules/dagster/dagster_tests/cli_tests/test_version.py index 65aac85e7..014118ab3 100644 --- a/python_modules/dagster/dagster_tests/cli_tests/test_version.py +++ b/python_modules/dagster/dagster_tests/cli_tests/test_version.py @@ -1,9 +1,9 @@ import subprocess from dagster.version import __version__ def test_version(): - assert subprocess.check_output(["dagster", "--version"]).decode("utf-8").strip("\n").strip( - "\r" - ) == "dagster, version {version}".format(version=__version__) + assert subprocess.check_output(["dagster", "--version"]).decode( + "utf-8" + ).strip() == "dagster, version {version}".format(version=__version__) diff --git a/python_modules/dagster/dagster_tests/conftest.py b/python_modules/dagster/dagster_tests/conftest.py index c6f8655bb..f805ea1d9 100644 --- a/python_modules/dagster/dagster_tests/conftest.py +++ b/python_modules/dagster/dagster_tests/conftest.py @@ -1,138 +1,138 @@ import os import subprocess import sys import time from contextlib import contextmanager import docker import grpc import pytest from dagster import check, seven from dagster.grpc.client import DagsterGrpcClient from dagster.seven import nullcontext from dagster.utils import file_relative_path from dagster_test.dagster_core_docker_buildkite import ( build_and_tag_test_image, get_test_project_docker_image, ) IS_BUILDKITE = os.getenv("BUILDKITE") is not None HARDCODED_PORT = 8090 # Suggested workaround in https://bugs.python.org/issue37380 for subprocesses # failing to open sporadically on windows after other subprocesses were closed. # Fixed in later versions of Python but never back-ported, see the bug for details. if seven.IS_WINDOWS and sys.version_info[0] == 3 and sys.version_info[1] == 6: subprocess._cleanup = lambda: None # type: ignore # pylint: disable=protected-access @pytest.fixture(scope="session") def dagster_docker_image(): docker_image = get_test_project_docker_image() if not IS_BUILDKITE: # Being conservative here when first introducing this. 
This could fail # if the Docker daemon is not running, so for now we just skip the tests using this # fixture if the build fails, and warn with the output from the build command try: client = docker.from_env() client.images.get(docker_image) print( # pylint: disable=print-call "Found existing image tagged {image}, skipping image build. To rebuild, first run: " "docker rmi {image}".format(image=docker_image) ) except docker.errors.ImageNotFound: try: build_and_tag_test_image(docker_image) except subprocess.CalledProcessError as exc_info: pytest.skip( "Skipped container tests due to a failure when trying to build the image. " "Most likely, the docker deamon is not running.\n" - "Output:\n{}".format(exc_info.output.decode()) + "Output:\n{}".format(exc_info.output.decode("utf-8")) ) return docker_image def wait_for_connection(host, port): retry_limit = 20 while retry_limit: try: if DagsterGrpcClient(host=host, port=port).ping("ready") == "ready": return True except grpc.RpcError: pass time.sleep(0.2) retry_limit -= 1 pytest.skip( "Skipped grpc container tests due to a failure when trying to connect to the GRPC server " "at {host}:{port}".format(host=host, port=port) ) @contextmanager def docker_service_up(docker_compose_file, service_name): check.str_param(service_name, "service_name") check.str_param(docker_compose_file, "docker_compose_file") check.invariant( os.path.isfile(docker_compose_file), "docker_compose_file must specify a valid file" ) if not IS_BUILDKITE: env = os.environ.copy() env["IMAGE_NAME"] = get_test_project_docker_image() try: subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "stop", service_name], env=env, ) subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name], env=env, ) except Exception: # pylint: disable=broad-except pass subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "up", "-d", service_name], env=env, ) yield try: subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "stop", service_name], env=env, ) subprocess.check_output( ["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name], env=env, ) except Exception: # pylint: disable=broad-except pass @pytest.fixture(scope="session") def grpc_host(): # In buildkite we get the ip address from this variable (see buildkite code for commentary) # Otherwise assume local development and assume localhost env_name = "GRPC_SERVER_HOST" if env_name not in os.environ: os.environ[env_name] = "localhost" return os.environ[env_name] @pytest.fixture(scope="session") def grpc_port(): yield HARDCODED_PORT @pytest.fixture(scope="session") def docker_grpc_client( dagster_docker_image, grpc_host, grpc_port ): # pylint: disable=redefined-outer-name, unused-argument with docker_service_up( file_relative_path(__file__, "docker-compose.yml"), "dagster-grpc-server" ) if not IS_BUILDKITE else nullcontext(): wait_for_connection(grpc_host, grpc_port) yield DagsterGrpcClient(port=grpc_port, host=grpc_host) diff --git a/python_modules/dagster/dagster_tests/core_tests/runtime_types_tests/config_schema_tests/test_config_schema.py b/python_modules/dagster/dagster_tests/core_tests/runtime_types_tests/config_schema_tests/test_config_schema.py index 03ad52928..f7e5faaf9 100644 --- a/python_modules/dagster/dagster_tests/core_tests/runtime_types_tests/config_schema_tests/test_config_schema.py +++ b/python_modules/dagster/dagster_tests/core_tests/runtime_types_tests/config_schema_tests/test_config_schema.py @@ -1,56 +1,53 @@ import hashlib 
import pytest from dagster import String from dagster.core.errors import DagsterInvalidDefinitionError from dagster.core.types.config_schema import dagster_type_loader def test_dagster_type_loader_one(): @dagster_type_loader(String) def _foo(_, hello): return hello def test_dagster_type_loader_missing_context(): with pytest.raises(DagsterInvalidDefinitionError): @dagster_type_loader(String) def _foo(hello): return hello def test_dagster_type_loader_missing_variable(): with pytest.raises(DagsterInvalidDefinitionError): @dagster_type_loader(String) def _foo(_): return 1 def test_dagster_type_loader_default_version(): @dagster_type_loader(String) def _foo(_, hello): return hello assert _foo.loader_version == None assert _foo.compute_loaded_input_version({}) == None def test_dagster_type_loader_provided_version(): def _get_ext_version(dict_param): return dict_param["version"] @dagster_type_loader(String, loader_version="5", external_version_fn=_get_ext_version) def _foo(_, hello): return hello dict_param = {"version": "42"} assert _foo.loader_version == "5" - assert ( - _foo.compute_loaded_input_version(dict_param) - == hashlib.sha1("542".encode("utf-8")).hexdigest() - ) + assert _foo.compute_loaded_input_version(dict_param) == hashlib.sha1(b"542").hexdigest() diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_defensive_row_unpack.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_defensive_row_unpack.py index a4a751550..e26a782db 100644 --- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_defensive_row_unpack.py +++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_defensive_row_unpack.py @@ -1,99 +1,98 @@ import sys import zlib from dagster import pipeline, solid from dagster.core.storage.runs.sql_run_storage import defensively_unpack_pipeline_snapshot_query from dagster.serdes import serialize_dagster_namedtuple from dagster.seven import mock def test_defensive_pipeline_not_a_string(): mock_logger = mock.MagicMock() assert defensively_unpack_pipeline_snapshot_query(mock_logger, [234]) is None assert mock_logger.warning.call_count == 1 mock_logger.warning.assert_called_with( "get-pipeline-snapshot: First entry in row is not a binary type." ) def test_defensive_pipeline_not_bytes(): mock_logger = mock.MagicMock() assert defensively_unpack_pipeline_snapshot_query(mock_logger, ["notbytes"]) is None assert mock_logger.warning.call_count == 1 if sys.version_info.major == 2: # this error is not detected in python and instead fails on decompress # the joys of the python 2/3 unicode debacle mock_logger.warning.assert_called_with( "get-pipeline-snapshot: Could not decompress bytes stored in snapshot table." ) else: mock_logger.warning.assert_called_with( "get-pipeline-snapshot: First entry in row is not a binary type." ) def test_defensive_pipelines_cannot_decompress(): mock_logger = mock.MagicMock() - assert defensively_unpack_pipeline_snapshot_query(mock_logger, ["notbytes".encode()]) is None + assert defensively_unpack_pipeline_snapshot_query(mock_logger, [b"notbytes"]) is None assert mock_logger.warning.call_count == 1 mock_logger.warning.assert_called_with( "get-pipeline-snapshot: Could not decompress bytes stored in snapshot table." ) def test_defensive_pipelines_cannot_decode_post_decompress(): mock_logger = mock.MagicMock() # guarantee that we cannot decode by double compressing bytes. 
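# --- Clarifying note, not part of the diff: zlib output is arbitrary binary, so stripping
# the outer layer of compression leaves the inner compressed stream, which is almost
# certainly not valid UTF-8 -- exactly the "could not unicode decode" branch this test
# exercises.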
assert ( defensively_unpack_pipeline_snapshot_query( - mock_logger, [zlib.compress(zlib.compress("notbytes".encode()))] + mock_logger, [zlib.compress(zlib.compress(b"notbytes"))] ) is None ) assert mock_logger.warning.call_count == 1 mock_logger.warning.assert_called_with( "get-pipeline-snapshot: Could not unicode decode decompressed bytes " "stored in snapshot table." ) def test_defensive_pipelines_cannot_parse_json(): mock_logger = mock.MagicMock() assert ( - defensively_unpack_pipeline_snapshot_query(mock_logger, [zlib.compress("notjson".encode())]) - is None + defensively_unpack_pipeline_snapshot_query(mock_logger, [zlib.compress(b"notjson")]) is None ) assert mock_logger.warning.call_count == 1 mock_logger.warning.assert_called_with( "get-pipeline-snapshot: Could not parse json in snapshot table." ) def test_correctly_fetch_decompress_parse_snapshot(): @solid def noop_solid(_): pass @pipeline def noop_pipeline(): noop_solid() noop_pipeline_snapshot = noop_pipeline.get_pipeline_snapshot() mock_logger = mock.MagicMock() assert ( defensively_unpack_pipeline_snapshot_query( mock_logger, - [zlib.compress(serialize_dagster_namedtuple(noop_pipeline_snapshot).encode())], + [zlib.compress(serialize_dagster_namedtuple(noop_pipeline_snapshot).encode("utf-8"))], ) == noop_pipeline_snapshot ) assert mock_logger.warning.call_count == 0 diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_cache.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_cache.py index 84fe3b253..1347b6c6e 100644 --- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_cache.py +++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_cache.py @@ -1,33 +1,33 @@ import io import os from dagster import LocalFileHandle from dagster.core.storage.file_cache import FSFileCache from dagster.utils.temp_file import get_temp_dir def test_fs_file_cache_write_data(): - bytes_object = io.BytesIO("bar".encode()) + bytes_object = io.BytesIO(b"bar") with get_temp_dir() as temp_dir: file_cache = FSFileCache(temp_dir) assert not file_cache.has_file_object("foo") assert file_cache.write_file_object("foo", bytes_object) file_handle = file_cache.get_file_handle("foo") assert isinstance(file_handle, LocalFileHandle) assert file_handle.path_desc == os.path.join(temp_dir, "foo") def test_fs_file_cache_write_binary_data(): with get_temp_dir() as temp_dir: file_store = FSFileCache(temp_dir) assert not file_store.has_file_object("foo") - assert file_store.write_binary_data("foo", "bar".encode()) + assert file_store.write_binary_data("foo", b"bar") file_handle = file_store.get_file_handle("foo") assert isinstance(file_handle, LocalFileHandle) assert file_handle.path_desc == os.path.join(temp_dir, "foo") def test_empty_file_cache(): with get_temp_dir() as temp_dir: file_cache = FSFileCache(temp_dir) assert not file_cache.has_file_object("kjdfkd") diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py index 48e42beac..3b1d20f75 100644 --- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py +++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py @@ -1,68 +1,68 @@ import tempfile from contextlib import contextmanager from dagster import LocalFileHandle, ModeDefinition, execute_pipeline, pipeline, solid from 
dagster.core.instance import DagsterInstance from dagster.core.storage.file_manager import LocalFileManager, local_file_manager from dagster.utils.temp_file import get_temp_file_handle_with_data @contextmanager def my_local_file_manager(instance, run_id): manager = None try: manager = LocalFileManager.for_instance(instance, run_id) yield manager finally: if manager: manager.delete_local_temp() def test_basic_file_manager_copy_handle_to_local_temp(): instance = DagsterInstance.ephemeral() - foo_data = "foo".encode() + foo_data = b"foo" with get_temp_file_handle_with_data(foo_data) as foo_handle: with my_local_file_manager(instance, "0") as manager: local_temp = manager.copy_handle_to_local_temp(foo_handle) assert local_temp != foo_handle.path with open(local_temp, "rb") as ff: assert ff.read() == foo_data def test_basic_file_manager_execute(): called = {} @solid(required_resource_keys={"file_manager"}) def file_handle(context): - foo_bytes = "foo".encode() + foo_bytes = b"foo" file_handle = context.resources.file_manager.write_data(foo_bytes) assert isinstance(file_handle, LocalFileHandle) with open(file_handle.path, "rb") as handle_obj: assert foo_bytes == handle_obj.read() with context.resources.file_manager.read(file_handle) as handle_obj: assert foo_bytes == handle_obj.read() file_handle = context.resources.file_manager.write_data(foo_bytes, ext="foo") assert isinstance(file_handle, LocalFileHandle) assert file_handle.path[-4:] == ".foo" with open(file_handle.path, "rb") as handle_obj: assert foo_bytes == handle_obj.read() with context.resources.file_manager.read(file_handle) as handle_obj: assert foo_bytes == handle_obj.read() called["yup"] = True @pipeline(mode_defs=[ModeDefinition(resource_defs={"file_manager": local_file_manager})]) def pipe(): return file_handle() with tempfile.TemporaryDirectory() as temp_dir: result = execute_pipeline( pipe, run_config={"resources": {"file_manager": {"config": {"base_dir": temp_dir}}}} ) assert result.success assert called["yup"] diff --git a/python_modules/dagster/dagster_tests/core_tests/test_versioned_execution_plan.py b/python_modules/dagster/dagster_tests/core_tests/test_versioned_execution_plan.py index 914137272..035bfa630 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_versioned_execution_plan.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_versioned_execution_plan.py @@ -1,531 +1,531 @@ import hashlib import pytest from dagster import ( Bool, DagsterInstance, Field, Float, Int, ModeDefinition, Output, String, composite_solid, dagster_type_loader, io_manager, pipeline, resource, solid, usable_as_dagster_type, ) from dagster.core.definitions import InputDefinition from dagster.core.errors import DagsterInvariantViolationError from dagster.core.execution.api import create_execution_plan from dagster.core.execution.plan.outputs import StepOutputHandle from dagster.core.execution.resolve_versions import ( join_and_hash, resolve_config_version, resolve_memoized_execution_plan, resolve_resource_versions, ) from dagster.core.storage.memoizable_io_manager import MemoizableIOManager from dagster.core.storage.tags import MEMOIZED_RUN_TAG class VersionedInMemoryIOManager(MemoizableIOManager): def __init__(self): self.values = {} def _get_keys(self, context): return (context.step_key, context.name, context.version) def handle_output(self, context, obj): keys = self._get_keys(context) self.values[keys] = obj def load_input(self, context): keys = self._get_keys(context.upstream_output) return self.values[keys] def 
has_output(self, context): keys = self._get_keys(context) return keys in self.values def io_manager_factory(manager): @io_manager def _io_manager_resource(_): return manager return _io_manager_resource def test_join_and_hash(): - assert join_and_hash("foo") == hashlib.sha1("foo".encode("utf-8")).hexdigest() + assert join_and_hash("foo") == hashlib.sha1(b"foo").hexdigest() assert join_and_hash("foo", None, "bar") == None - assert join_and_hash("foo", "bar") == hashlib.sha1("barfoo".encode("utf-8")).hexdigest() + assert join_and_hash("foo", "bar") == hashlib.sha1(b"barfoo").hexdigest() assert join_and_hash("foo", "bar", "zab") == join_and_hash("zab", "bar", "foo") def test_resolve_config_version(): assert resolve_config_version({}) == join_and_hash() assert resolve_config_version({"a": "b", "c": "d"}) == join_and_hash( "a" + join_and_hash("b"), "c" + join_and_hash("d") ) assert resolve_config_version({"a": "b", "c": "d"}) == resolve_config_version( {"c": "d", "a": "b"} ) assert resolve_config_version({"a": {"b": "c"}, "d": "e"}) == join_and_hash( "a" + join_and_hash("b" + join_and_hash("c")), "d" + join_and_hash("e") ) @solid(version="42") def versioned_solid_no_input(_): return 4 @solid(version="5") def versioned_solid_takes_input(_, intput): return 2 * intput def versioned_pipeline_factory(manager=None): @pipeline( mode_defs=[ ModeDefinition( name="main", resource_defs=({"io_manager": io_manager_factory(manager)} if manager else {}), ) ], tags={MEMOIZED_RUN_TAG: "true"}, ) def versioned_pipeline(): versioned_solid_takes_input(versioned_solid_no_input()) return versioned_pipeline @solid def solid_takes_input(_, intput): return 2 * intput def partially_versioned_pipeline_factory(manager=None): @pipeline( mode_defs=[ ModeDefinition( name="main", resource_defs=({"io_manager": io_manager_factory(manager)} if manager else {}), ) ], tags={MEMOIZED_RUN_TAG: "true"}, ) def partially_versioned_pipeline(): solid_takes_input(versioned_solid_no_input()) return partially_versioned_pipeline def versioned_pipeline_expected_step1_version(): solid1_def_version = versioned_solid_no_input.version solid1_config_version = resolve_config_version(None) solid1_resources_version = join_and_hash() solid1_version = join_and_hash( solid1_def_version, solid1_config_version, solid1_resources_version ) return join_and_hash(solid1_version) def versioned_pipeline_expected_step1_output_version(): step1_version = versioned_pipeline_expected_step1_version() return join_and_hash(step1_version, "result") def versioned_pipeline_expected_step2_version(): solid2_def_version = versioned_solid_takes_input.version solid2_config_version = resolve_config_version(None) solid2_resources_version = join_and_hash() solid2_version = join_and_hash( solid2_def_version, solid2_config_version, solid2_resources_version ) step1_outputs_hash = versioned_pipeline_expected_step1_output_version() step2_version = join_and_hash(step1_outputs_hash, solid2_version) return step2_version def versioned_pipeline_expected_step2_output_version(): step2_version = versioned_pipeline_expected_step2_version() return join_and_hash(step2_version + "result") def test_resolve_step_versions_no_external_dependencies(): versioned_pipeline = versioned_pipeline_factory() speculative_execution_plan = create_execution_plan(versioned_pipeline) versions = speculative_execution_plan.resolve_step_versions() assert versions["versioned_solid_no_input"] == versioned_pipeline_expected_step1_version() assert versions["versioned_solid_takes_input"] == 
versioned_pipeline_expected_step2_version() def test_resolve_step_output_versions_no_external_dependencies(): versioned_pipeline = versioned_pipeline_factory() speculative_execution_plan = create_execution_plan( versioned_pipeline, run_config={}, mode="main" ) versions = speculative_execution_plan.resolve_step_output_versions() assert ( versions[StepOutputHandle("versioned_solid_no_input", "result")] == versioned_pipeline_expected_step1_output_version() ) assert ( versions[StepOutputHandle("versioned_solid_takes_input", "result")] == versioned_pipeline_expected_step2_output_version() ) @solid def basic_solid(_): return 5 @solid def basic_takes_input_solid(_, intpt): return intpt * 4 @pipeline def no_version_pipeline(): basic_takes_input_solid(basic_solid()) def test_resolve_memoized_execution_plan_no_stored_results(): versioned_pipeline = versioned_pipeline_factory(VersionedInMemoryIOManager()) speculative_execution_plan = create_execution_plan(versioned_pipeline) memoized_execution_plan = resolve_memoized_execution_plan(speculative_execution_plan) assert set(memoized_execution_plan.step_keys_to_execute) == { "versioned_solid_no_input", "versioned_solid_takes_input", } def test_resolve_memoized_execution_plan_yes_stored_results(): manager = VersionedInMemoryIOManager() versioned_pipeline = versioned_pipeline_factory(manager) speculative_execution_plan = create_execution_plan(versioned_pipeline) step_output_handle = StepOutputHandle("versioned_solid_no_input", "result") step_output_version = speculative_execution_plan.resolve_step_output_versions()[ step_output_handle ] manager.values[ (step_output_handle.step_key, step_output_handle.output_name, step_output_version) ] = 4 memoized_execution_plan = resolve_memoized_execution_plan(speculative_execution_plan) assert memoized_execution_plan.step_keys_to_execute == ["versioned_solid_takes_input"] expected_handle = StepOutputHandle(step_key="versioned_solid_no_input", output_name="result") assert ( memoized_execution_plan.get_step_by_key("versioned_solid_takes_input") .step_input_dict["intput"] .source.step_output_handle == expected_handle ) def test_resolve_memoized_execution_plan_partial_versioning(): manager = VersionedInMemoryIOManager() partially_versioned_pipeline = partially_versioned_pipeline_factory(manager) speculative_execution_plan = create_execution_plan(partially_versioned_pipeline) step_output_handle = StepOutputHandle("versioned_solid_no_input", "result") step_output_version = speculative_execution_plan.resolve_step_output_versions()[ step_output_handle ] manager.values[ (step_output_handle.step_key, step_output_handle.output_name, step_output_version) ] = 4 assert resolve_memoized_execution_plan(speculative_execution_plan).step_keys_to_execute == [ "solid_takes_input" ] def _get_ext_version(config_value): return join_and_hash(str(config_value)) @dagster_type_loader(String, loader_version="97", external_version_fn=_get_ext_version) def InputHydration(_, _hello): return "Hello" @usable_as_dagster_type(loader=InputHydration) class CustomType(str): pass def test_externally_loaded_inputs(): for type_to_test, loader_version, type_value in [ (String, "String", "foo"), (Int, "Int", int(42)), (Float, "Float", float(5.42)), (Bool, "Bool", False), (CustomType, "97", "bar"), ]: run_test_with_builtin_type(type_to_test, loader_version, type_value) def run_test_with_builtin_type(type_to_test, loader_version, type_value): @solid(version="42", input_defs=[InputDefinition("_builtin_type", type_to_test)]) def 
versioned_solid_ext_input_builtin_type(_, _builtin_type): pass @pipeline def versioned_pipeline_ext_input_builtin_type(): versioned_solid_takes_input(versioned_solid_ext_input_builtin_type()) run_config = { "solids": { "versioned_solid_ext_input_builtin_type": {"inputs": {"_builtin_type": type_value}} } } speculative_execution_plan = create_execution_plan( versioned_pipeline_ext_input_builtin_type, run_config=run_config, ) versions = speculative_execution_plan.resolve_step_versions() ext_input_version = join_and_hash(str(type_value)) input_version = join_and_hash(loader_version + ext_input_version) solid1_def_version = versioned_solid_ext_input_builtin_type.version solid1_config_version = resolve_config_version(None) solid1_resources_version = join_and_hash() solid1_version = join_and_hash( solid1_def_version, solid1_config_version, solid1_resources_version ) step1_version = join_and_hash(input_version, solid1_version) assert versions["versioned_solid_ext_input_builtin_type"] == step1_version output_version = join_and_hash(step1_version, "result") hashed_input2 = output_version solid2_def_version = versioned_solid_takes_input.version solid2_config_version = resolve_config_version(None) solid2_resources_version = join_and_hash() solid2_version = join_and_hash( solid2_def_version, solid2_config_version, solid2_resources_version ) step2_version = join_and_hash(hashed_input2, solid2_version) assert versions["versioned_solid_takes_input"] == step2_version @solid( version="42", input_defs=[InputDefinition("default_input", String, default_value="DEFAULTVAL")], ) def versioned_solid_default_value(_, default_input): return default_input * 4 @pipeline def versioned_pipeline_default_value(): versioned_solid_default_value() def test_resolve_step_versions_default_value(): speculative_execution_plan = create_execution_plan(versioned_pipeline_default_value) versions = speculative_execution_plan.resolve_step_versions() input_version = join_and_hash(repr("DEFAULTVAL")) solid_def_version = versioned_solid_default_value.version solid_config_version = resolve_config_version(None) solid_resources_version = join_and_hash() solid_version = join_and_hash(solid_def_version, solid_config_version, solid_resources_version) step_version = join_and_hash(input_version, solid_version) assert versions["versioned_solid_default_value"] == step_version def test_step_keys_already_provided(): with pytest.raises( DagsterInvariantViolationError, match="step_keys_to_execute parameter " "cannot be used in conjunction with memoized pipeline runs.", ): instance = DagsterInstance.ephemeral() instance.create_run_for_pipeline( pipeline_def=no_version_pipeline, tags={MEMOIZED_RUN_TAG: "true"}, step_keys_to_execute=["basic_takes_input_solid"], ) @resource(config_schema={"input_str": Field(String)}, version="5") def test_resource(context): return context.resource_config["input_str"] @resource(config_schema={"input_str": Field(String)}) def test_resource_no_version(context): return context.resource_config["input_str"] @resource(version="42") def test_resource_no_config(_): return "Hello" @solid( required_resource_keys={"test_resource", "test_resource_no_version", "test_resource_no_config"}, ) def fake_solid_resources(context): return ( "solidified_" + context.resources.test_resource + context.resources.test_resource_no_version + context.resources.test_resource_no_config ) @pipeline( mode_defs=[ ModeDefinition( name="fakemode", resource_defs={ "test_resource": test_resource, "test_resource_no_version": test_resource_no_version, 
"test_resource_no_config": test_resource_no_config, }, ), ModeDefinition( name="fakemode2", resource_defs={ "test_resource": test_resource, "test_resource_no_version": test_resource_no_version, "test_resource_no_config": test_resource_no_config, }, ), ] ) def modes_pipeline(): fake_solid_resources() def test_resource_versions(): run_config = { "resources": { "test_resource": {"config": {"input_str": "apple"},}, "test_resource_no_version": {"config": {"input_str": "banana"}}, } } execution_plan = create_execution_plan(modes_pipeline, run_config=run_config, mode="fakemode") resource_versions_by_key = resolve_resource_versions( execution_plan.environment_config, execution_plan.pipeline.get_definition() ) assert resource_versions_by_key["test_resource"] == join_and_hash( resolve_config_version({"config": {"input_str": "apple"}}), test_resource.version ) assert resource_versions_by_key["test_resource_no_version"] == None assert resource_versions_by_key["test_resource_no_config"] == join_and_hash( join_and_hash(), "42" ) @solid(required_resource_keys={"test_resource", "test_resource_no_config"}, version="39") def fake_solid_resources_versioned(context): return ( "solidified_" + context.resources.test_resource + context.resources.test_resource_no_config ) @pipeline( mode_defs=[ ModeDefinition( name="fakemode", resource_defs={ "test_resource": test_resource, "test_resource_no_config": test_resource_no_config, }, ), ] ) def versioned_modes_pipeline(): fake_solid_resources_versioned() def test_step_versions_with_resources(): run_config = {"resources": {"test_resource": {"config": {"input_str": "apple"}}}} speculative_execution_plan = create_execution_plan( versioned_modes_pipeline, run_config=run_config, mode="fakemode" ) versions = speculative_execution_plan.resolve_step_versions() solid_def_version = fake_solid_resources_versioned.version solid_config_version = resolve_config_version(None) resource_versions_by_key = resolve_resource_versions( speculative_execution_plan.environment_config, speculative_execution_plan.pipeline.get_definition(), ) solid_resources_version = join_and_hash( *[ resource_versions_by_key[resource_key] for resource_key in fake_solid_resources_versioned.required_resource_keys ] ) solid_version = join_and_hash(solid_def_version, solid_config_version, solid_resources_version) step_version = join_and_hash(solid_version) assert versions["fake_solid_resources_versioned"] == step_version def test_step_versions_composite_solid(): @solid(config_schema=Field(String, is_required=False)) def scalar_config_solid(context): yield Output(context.solid_config) @composite_solid( config_schema={"override_str": Field(String)}, config_fn=lambda cfg: {"scalar_config_solid": {"config": cfg["override_str"]}}, ) def wrap(): return scalar_config_solid() @pipeline def wrap_pipeline(): wrap.alias("do_stuff")() run_config = { "solids": {"do_stuff": {"config": {"override_str": "override"}}}, "loggers": {"console": {"config": {"log_level": "ERROR"}}}, } speculative_execution_plan = create_execution_plan(wrap_pipeline, run_config=run_config,) versions = speculative_execution_plan.resolve_step_versions() assert versions["do_stuff.scalar_config_solid"] == None diff --git a/python_modules/dagster/dagster_tests/general_tests/compat_tests/test_back_compat.py b/python_modules/dagster/dagster_tests/general_tests/compat_tests/test_back_compat.py index b083e52e6..21e8de763 100644 --- a/python_modules/dagster/dagster_tests/general_tests/compat_tests/test_back_compat.py +++ 
b/python_modules/dagster/dagster_tests/general_tests/compat_tests/test_back_compat.py @@ -1,357 +1,357 @@ # pylint: disable=protected-access import os import re import sqlite3 from gzip import GzipFile import pytest from dagster import check, execute_pipeline, file_relative_path, pipeline, solid from dagster.cli.debug import DebugRunPayload from dagster.core.errors import DagsterInstanceMigrationRequired from dagster.core.instance import DagsterInstance, InstanceRef from dagster.core.storage.event_log.migration import migrate_event_log_data from dagster.core.storage.event_log.sql_event_log import SqlEventLogStorage from dagster.serdes import deserialize_json_to_dagster_namedtuple from dagster.utils.test import copy_directory def _migration_regex(warning, current_revision, expected_revision=None): instruction = re.escape("Please run `dagster instance migrate`.") if expected_revision: revision = re.escape( "Database is at revision {}, head is {}.".format(current_revision, expected_revision) ) else: revision = "Database is at revision {}, head is [a-z0-9]+.".format(current_revision) return "{} {} {}".format(warning, revision, instruction) def _run_storage_migration_regex(current_revision, expected_revision=None): warning = re.escape( "Instance is out of date and must be migrated (Sqlite run storage requires migration)." ) return _migration_regex(warning, current_revision, expected_revision) def _schedule_storage_migration_regex(current_revision, expected_revision=None): warning = re.escape( "Instance is out of date and must be migrated (Sqlite schedule storage requires migration)." ) return _migration_regex(warning, current_revision, expected_revision) def _event_log_migration_regex(run_id, current_revision, expected_revision=None): warning = re.escape( "Instance is out of date and must be migrated (SqliteEventLogStorage for run {}).".format( run_id ) ) return _migration_regex(warning, current_revision, expected_revision) def test_event_log_step_key_migration(): src_dir = file_relative_path(__file__, "snapshot_0_7_6_pre_event_log_migration/sqlite") with copy_directory(src_dir) as test_dir: instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) # Make sure the schema is migrated instance.upgrade() runs = instance.get_runs() assert len(runs) == 1 run_ids = instance._event_storage.get_all_run_ids() assert run_ids == ["6405c4a0-3ccc-4600-af81-b5ee197f8528"] assert isinstance(instance._event_storage, SqlEventLogStorage) events_by_id = instance._event_storage.get_logs_for_run_by_log_id( "6405c4a0-3ccc-4600-af81-b5ee197f8528" ) assert len(events_by_id) == 40 step_key_records = [] for record_id, _event in events_by_id.items(): row_data = instance._event_storage.get_event_log_table_data( "6405c4a0-3ccc-4600-af81-b5ee197f8528", record_id ) if row_data.step_key is not None: step_key_records.append(row_data) assert len(step_key_records) == 0 # run the event_log backfill migration migrate_event_log_data(instance=instance) step_key_records = [] for record_id, _event in events_by_id.items(): row_data = instance._event_storage.get_event_log_table_data( "6405c4a0-3ccc-4600-af81-b5ee197f8528", record_id ) if row_data.step_key is not None: step_key_records.append(row_data) assert len(step_key_records) > 0 def get_sqlite3_tables(db_path): con = sqlite3.connect(db_path) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") return [r[0] for r in cursor.fetchall()] def get_current_alembic_version(db_path): con = sqlite3.connect(db_path) cursor = con.cursor() 
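# Alembic tracks the current schema revision in a single-row "alembic_version" table,
# so the query below returns the revision string that the back-compat tests assert against.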
cursor.execute("SELECT * FROM alembic_version") return cursor.fetchall()[0][0] def get_sqlite3_columns(db_path, table_name): con = sqlite3.connect(db_path) cursor = con.cursor() cursor.execute('PRAGMA table_info("{}");'.format(table_name)) return [r[1] for r in cursor.fetchall()] def test_snapshot_0_7_6_pre_add_pipeline_snapshot(): run_id = "fb0b3905-068b-4444-8f00-76fcbaef7e8b" src_dir = file_relative_path(__file__, "snapshot_0_7_6_pre_add_pipeline_snapshot/sqlite") with copy_directory(src_dir) as test_dir: # invariant check to make sure migration has not been run yet db_path = os.path.join(test_dir, "history", "runs.db") assert get_current_alembic_version(db_path) == "9fe9e746268c" assert "snapshots" not in get_sqlite3_tables(db_path) instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) @solid def noop_solid(_): pass @pipeline def noop_pipeline(): noop_solid() with pytest.raises( DagsterInstanceMigrationRequired, match=_run_storage_migration_regex(current_revision="9fe9e746268c"), ): execute_pipeline(noop_pipeline, instance=instance) assert len(instance.get_runs()) == 1 # Make sure the schema is migrated instance.upgrade() assert "snapshots" in get_sqlite3_tables(db_path) assert {"id", "snapshot_id", "snapshot_body", "snapshot_type"} == set( get_sqlite3_columns(db_path, "snapshots") ) assert len(instance.get_runs()) == 1 run = instance.get_run_by_id(run_id) assert run.run_id == run_id assert run.pipeline_snapshot_id is None result = execute_pipeline(noop_pipeline, instance=instance) assert result.success runs = instance.get_runs() assert len(runs) == 2 new_run_id = result.run_id new_run = instance.get_run_by_id(new_run_id) assert new_run.pipeline_snapshot_id def test_downgrade_and_upgrade(): src_dir = file_relative_path(__file__, "snapshot_0_7_6_pre_add_pipeline_snapshot/sqlite") with copy_directory(src_dir) as test_dir: # invariant check to make sure migration has not been run yet db_path = os.path.join(test_dir, "history", "runs.db") assert get_current_alembic_version(db_path) == "9fe9e746268c" assert "snapshots" not in get_sqlite3_tables(db_path) instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) assert len(instance.get_runs()) == 1 # Make sure the schema is migrated instance.upgrade() assert "snapshots" in get_sqlite3_tables(db_path) assert {"id", "snapshot_id", "snapshot_body", "snapshot_type"} == set( get_sqlite3_columns(db_path, "snapshots") ) assert len(instance.get_runs()) == 1 instance._run_storage._alembic_downgrade(rev="9fe9e746268c") assert get_current_alembic_version(db_path) == "9fe9e746268c" assert "snapshots" not in get_sqlite3_tables(db_path) instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) assert len(instance.get_runs()) == 1 instance.upgrade() assert "snapshots" in get_sqlite3_tables(db_path) assert {"id", "snapshot_id", "snapshot_body", "snapshot_type"} == set( get_sqlite3_columns(db_path, "snapshots") ) assert len(instance.get_runs()) == 1 def test_event_log_asset_key_migration(): src_dir = file_relative_path(__file__, "snapshot_0_7_8_pre_asset_key_migration/sqlite") with copy_directory(src_dir) as test_dir: db_path = os.path.join( test_dir, "history", "runs", "722183e4-119f-4a00-853f-e1257be82ddb.db" ) assert get_current_alembic_version(db_path) == "3b1e175a2be3" assert "asset_key" not in set(get_sqlite3_columns(db_path, "event_logs")) # Make sure the schema is migrated instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) instance.upgrade() assert "asset_key" in set(get_sqlite3_columns(db_path, 
"event_logs")) def instance_from_debug_payloads(payload_files): debug_payloads = [] for input_file in payload_files: with GzipFile(input_file, "rb") as file: - blob = file.read().decode() + blob = file.read().decode("utf-8") debug_payload = deserialize_json_to_dagster_namedtuple(blob) check.invariant(isinstance(debug_payload, DebugRunPayload)) debug_payloads.append(debug_payload) return DagsterInstance.ephemeral(preload=debug_payloads) def test_object_store_operation_result_data_new_fields(): """We added address and version fields to ObjectStoreOperationResultData. Make sure we can still deserialize old ObjectStoreOperationResultData without those fields.""" instance_from_debug_payloads([file_relative_path(__file__, "0_9_12_nothing_fs_storage.gz")]) def test_event_log_asset_partition_migration(): src_dir = file_relative_path(__file__, "snapshot_0_9_22_pre_asset_partition/sqlite") with copy_directory(src_dir) as test_dir: db_path = os.path.join( test_dir, "history", "runs", "1a1d3c4b-1284-4c74-830c-c8988bd4d779.db" ) assert get_current_alembic_version(db_path) == "c34498c29964" assert "partition" not in set(get_sqlite3_columns(db_path, "event_logs")) # Make sure the schema is migrated instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) instance.upgrade() assert "partition" in set(get_sqlite3_columns(db_path, "event_logs")) def test_run_partition_migration(): src_dir = file_relative_path(__file__, "snapshot_0_9_22_pre_run_partition/sqlite") with copy_directory(src_dir) as test_dir: db_path = os.path.join(test_dir, "history", "runs.db") assert get_current_alembic_version(db_path) == "224640159acf" assert "partition" not in set(get_sqlite3_columns(db_path, "runs")) assert "partition_set" not in set(get_sqlite3_columns(db_path, "runs")) # Make sure the schema is migrated instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) instance.upgrade() assert "partition" in set(get_sqlite3_columns(db_path, "runs")) assert "partition_set" in set(get_sqlite3_columns(db_path, "runs")) instance._run_storage._alembic_downgrade(rev="224640159acf") assert get_current_alembic_version(db_path) == "224640159acf" assert "partition" not in set(get_sqlite3_columns(db_path, "runs")) assert "partition_set" not in set(get_sqlite3_columns(db_path, "runs")) def test_run_partition_data_migration(): src_dir = file_relative_path(__file__, "snapshot_0_9_22_post_schema_pre_data_partition/sqlite") with copy_directory(src_dir) as test_dir: from dagster.core.storage.runs.sql_run_storage import SqlRunStorage from dagster.core.storage.runs.migration import RUN_PARTITIONS # load db that has migrated schema, but not populated data for run partitions db_path = os.path.join(test_dir, "history", "runs.db") assert get_current_alembic_version(db_path) == "375e95bad550" # Make sure the schema is migrated assert "partition" in set(get_sqlite3_columns(db_path, "runs")) assert "partition_set" in set(get_sqlite3_columns(db_path, "runs")) instance = DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) run_storage = instance._run_storage assert isinstance(run_storage, SqlRunStorage) partition_set_name = "ingest_and_train" partition_name = "2020-01-02" # ensure old tag-based reads are working assert not run_storage.has_built_index(RUN_PARTITIONS) assert len(run_storage._get_partition_runs(partition_set_name, partition_name)) == 2 # turn on reads for the partition column, without migrating the data run_storage.mark_index_built(RUN_PARTITIONS) # ensure that no runs are returned because the data has not been 
migrated assert run_storage.has_built_index(RUN_PARTITIONS) assert len(run_storage._get_partition_runs(partition_set_name, partition_name)) == 0 # actually migrate the data run_storage.build_missing_indexes(force_rebuild_all=True) # ensure that we get the same partitioned runs returned assert run_storage.has_built_index(RUN_PARTITIONS) assert len(run_storage._get_partition_runs(partition_set_name, partition_name)) == 2 def test_0_10_0_schedule_wipe(): src_dir = file_relative_path(__file__, "snapshot_0_10_0_wipe_schedules/sqlite") with copy_directory(src_dir) as test_dir: db_path = os.path.join(test_dir, "schedules", "schedules.db") assert get_current_alembic_version(db_path) == "b22f16781a7c" assert "schedules" in get_sqlite3_tables(db_path) assert "schedule_ticks" in get_sqlite3_tables(db_path) assert "jobs" not in get_sqlite3_tables(db_path) assert "job_ticks" not in get_sqlite3_tables(db_path) with DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) as instance: instance.upgrade() assert get_current_alembic_version(db_path) == "140198fdfe65" assert "schedules" not in get_sqlite3_tables(db_path) assert "schedule_ticks" not in get_sqlite3_tables(db_path) assert "jobs" in get_sqlite3_tables(db_path) assert "job_ticks" in get_sqlite3_tables(db_path) with DagsterInstance.from_ref(InstanceRef.from_dir(test_dir)) as upgraded_instance: assert len(upgraded_instance.all_stored_job_state()) == 0 diff --git a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_backoff.py b/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_backoff.py index 4233c42ef..ec234b01c 100644 --- a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_backoff.py +++ b/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_backoff.py @@ -1,75 +1,74 @@ import pytest -import six from dagster.utils.backoff import backoff, backoff_delay_generator class UnretryableException(Exception): pass class RetryableException(Exception): pass class RetryableExceptionB(Exception): pass class Failer: def __init__(self, fails=0, exception=RetryableException): self.fails = fails self.exception = exception self.call_count = 0 self.args = [] self.kwargs = [] def __call__(self, *args, **kwargs): self.call_count += 1 self.args.append(args) self.kwargs.append(kwargs) if self.call_count <= self.fails: raise self.exception return True def test_backoff_delay_generator(): gen = backoff_delay_generator() vals = [] for _ in range(10): - vals.append(six.next(gen)) + vals.append(next(gen)) assert vals == [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 12.8, 25.6, 51.2] def test_backoff(): fn = Failer(fails=100) with pytest.raises(RetryableException): backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={"foo": "bar"}) assert fn.call_count == 5 assert all([args == (3, 2, 1) for args in fn.args]) assert all([kwargs == {"foo": "bar"} for kwargs in fn.kwargs]) fn = Failer() assert backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={"foo": "bar"}) assert fn.call_count == 1 fn = Failer(fails=1) assert backoff(fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={"foo": "bar"}) assert fn.call_count == 2 fn = Failer(fails=1) with pytest.raises(RetryableException): backoff( fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={"foo": "bar"}, max_retries=0 ) assert fn.call_count == 1 fn = Failer(fails=2) with pytest.raises(RetryableException): backoff( fn, retry_on=(RetryableException,), args=[3, 2, 1], kwargs={"foo": "bar"}, max_retries=1 ) assert fn.call_count == 2 diff 
--git a/python_modules/dagster/setup.py b/python_modules/dagster/setup.py index fc3053ebe..e00cfe43f 100644 --- a/python_modules/dagster/setup.py +++ b/python_modules/dagster/setup.py @@ -1,92 +1,91 @@ from setuptools import find_packages, setup def long_description(): return """ ## Dagster Dagster is a data orchestrator for machine learning, analytics, and ETL. Dagster lets you define pipelines in terms of the data flow between reusable, logical components, then test locally and run anywhere. With a unified view of pipelines and the assets they produce, Dagster can schedule and orchestrate Pandas, Spark, SQL, or anything else that Python can invoke. Dagster is designed for data platform engineers, data engineers, and full-stack data scientists. Building a data platform with Dagster makes your stakeholders more independent and your systems more robust. Developing data pipelines with Dagster makes testing easier and deploying faster. """.strip() def get_version(): version = {} with open("dagster/version.py") as fp: exec(fp.read(), version) # pylint: disable=W0122 return version["__version__"] if __name__ == "__main__": setup( name="dagster", version=get_version(), author="Elementl", author_email="hello@elementl.com", license="Apache-2.0", description="A data orchestrator for machine learning, analytics, and ETL.", long_description=long_description(), long_description_content_type="text/markdown", url="https://github.com/dagster-io/dagster", classifiers=[ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], packages=find_packages(exclude=["dagster_tests"]), package_data={ "dagster": [ "dagster/core/storage/event_log/sqlite/alembic/*", "dagster/core/storage/runs/sqlite/alembic/*", "dagster/core/storage/schedules/sqlite/alembic/*", "dagster/grpc/protos/*", ] }, include_package_data=True, install_requires=[ "future", # cli "click>=5.0", "coloredlogs>=6.1, <=14.0", "PyYAML", # core (not explicitly expressed atm) "alembic>=1.2.1", "croniter>=0.3.34", "grpcio>=1.32.0", # ensure version we require is >= that with which we generated the grpc code (set in dev-requirements) "grpcio-health-checking>=1.32.0", "pendulum==1.4.4", # pinned to match airflow, can upgrade to 2.0 once airflow 1.10.13 is released "protobuf>=3.13.0", # ensure version we require is >= that with which we generated the proto code (set in dev-requirements) "pyrsistent>=0.14.8", "python-dateutil", "requests", "rx<=1.6.1", # 3.0 was a breaking change. 
- "six", "tabulate", "tqdm", "sqlalchemy>=1.0", "toposort>=1.0", "watchdog>=0.8.3", 'psutil >= 1.0; platform_system=="Windows"', # https://github.com/mhammond/pywin32/issues/1439 'pywin32 != 226; platform_system=="Windows"', "pytz", "docstring-parser==0.7.1", ], extras_require={"docker": ["docker"],}, entry_points={ "console_scripts": [ "dagster = dagster.cli:main", "dagster-scheduler = dagster.scheduler.cli:main", "dagster-daemon = dagster.daemon.cli:main", ] }, ) diff --git a/python_modules/libraries/dagster-airflow/dagster_airflow/cli.py b/python_modules/libraries/dagster-airflow/dagster_airflow/cli.py index dbc248aeb..098219a47 100644 --- a/python_modules/libraries/dagster-airflow/dagster_airflow/cli.py +++ b/python_modules/libraries/dagster-airflow/dagster_airflow/cli.py @@ -1,170 +1,169 @@ import os from datetime import datetime, timedelta import click -import six import yaml from dagster import check, seven from dagster.cli.load_handle import recon_repo_for_cli_args from dagster.utils import load_yaml_from_glob_list from dagster.utils.indenting_printer import IndentingStringIoPrinter def construct_environment_yaml(preset_name, config, pipeline_name, module_name): # Load environment dict from either a preset or yaml file globs if preset_name: if config: raise click.UsageError("Can not use --preset with --config.") cli_args = { "fn_name": pipeline_name, "pipeline_name": pipeline_name, "module_name": module_name, } pipeline = recon_repo_for_cli_args(cli_args).get_definition().get_pipeline(pipeline_name) run_config = pipeline.get_preset(preset_name).run_config else: config = list(config) run_config = load_yaml_from_glob_list(config) if config else {} # If not provided by the user, ensure we have storage location defined if "intermediate_storage" not in run_config: system_tmp_path = seven.get_system_temp_directory() dagster_tmp_path = os.path.join(system_tmp_path, "dagster-airflow", pipeline_name) run_config["intermediate_storage"] = { - "filesystem": {"config": {"base_dir": six.ensure_str(dagster_tmp_path)}} + "filesystem": {"config": {"base_dir": dagster_tmp_path}} } return run_config def construct_scaffolded_file_contents(module_name, pipeline_name, run_config): yesterday = datetime.now() - timedelta(1) printer = IndentingStringIoPrinter(indent_level=4) printer.line("'''") printer.line( "The airflow DAG scaffold for {module_name}.{pipeline_name}".format( module_name=module_name, pipeline_name=pipeline_name ) ) printer.blank_line() printer.line('Note that this docstring must contain the strings "airflow" and "DAG" for') printer.line("Airflow to properly detect it as a DAG") printer.line("See: http://bit.ly/307VMum") printer.line("'''") printer.line("import datetime") printer.blank_line() printer.line("import yaml") printer.line("from dagster_airflow.factory import make_airflow_dag") printer.blank_line() printer.line("#" * 80) printer.comment("#") printer.comment("# This environment is auto-generated from your configs and/or presets") printer.comment("#") printer.line("#" * 80) printer.line("ENVIRONMENT = '''") printer.line(yaml.dump(run_config, default_flow_style=False)) printer.line("'''") printer.blank_line() printer.blank_line() printer.line("#" * 80) printer.comment("#") printer.comment("# NOTE: these arguments should be edited for your environment") printer.comment("#") printer.line("#" * 80) printer.line("DEFAULT_ARGS = {") with printer.with_indent(): printer.line("'owner': 'airflow',") printer.line("'depends_on_past': False,") # start date -> yesterday printer.line( 
"'start_date': datetime.datetime(%s, %d, %d)," % (yesterday.year, yesterday.month, yesterday.day) ) printer.line("'email': ['airflow@example.com'],") printer.line("'email_on_failure': False,") printer.line("'email_on_retry': False,") printer.line("}") printer.blank_line() printer.line("dag, tasks = make_airflow_dag(") with printer.with_indent(): printer.comment( "NOTE: you must ensure that {module_name} is ".format(module_name=module_name) ) printer.comment("installed or available on sys.path, otherwise, this import will fail.") printer.line("module_name='{module_name}',".format(module_name=module_name)) printer.line("pipeline_name='{pipeline_name}',".format(pipeline_name=pipeline_name)) printer.line("run_config=yaml.safe_load(ENVIRONMENT),") printer.line("dag_kwargs={'default_args': DEFAULT_ARGS, 'max_active_runs': 1}") printer.line(")") - return printer.read().encode() + return printer.read().encode("utf-8") @click.group() def main(): pass @main.command() @click.option( "--module-name", "-m", type=click.STRING, help="The name of the source module", required=True ) @click.option("--pipeline-name", type=click.STRING, help="The name of the pipeline", required=True) @click.option( "--output-path", "-o", type=click.Path(), help="Optional. If unset, $AIRFLOW_HOME will be used.", default=os.getenv("AIRFLOW_HOME"), ) @click.option( "-c", "--config", type=click.STRING, multiple=True, help=( "Specify one or more run config files. These can also be file patterns. " "If more than one run config file is captured then those files are merged. " "Files listed first take precendence. They will smash the values of subsequent " "files at the key-level granularity. If the file is a pattern then you must " "enclose it in double quotes" ), ) @click.option( "-p", "--preset", type=click.STRING, help="Specify a preset to use for this pipeline. 
Presets are defined on pipelines under " "preset_defs.", ) def scaffold(module_name, pipeline_name, output_path, config, preset): """Creates a DAG file for a specified dagster pipeline""" check.tuple_param(config, "config", of_type=str) check.invariant(isinstance(config, tuple)) check.invariant( output_path is not None, "You must specify --output-path or set AIRFLOW_HOME to use this script.", ) run_config = construct_environment_yaml(preset, config, pipeline_name, module_name) file_contents = construct_scaffolded_file_contents(module_name, pipeline_name, run_config) # Ensure output_path/dags exists dags_path = os.path.join(os.path.expanduser(output_path), "dags") if not os.path.isdir(dags_path): os.makedirs(dags_path) dag_file = os.path.join(os.path.expanduser(output_path), "dags", pipeline_name + ".py") click.echo("Wrote DAG scaffold to file: %s" % dag_file) with open(dag_file, "wb") as f: f.write(file_contents) if __name__ == "__main__": main() # pylint:disable=no-value-for-parameter diff --git a/python_modules/libraries/dagster-airflow/setup.py b/python_modules/libraries/dagster-airflow/setup.py index 521db97b9..ee8508115 100644 --- a/python_modules/libraries/dagster-airflow/setup.py +++ b/python_modules/libraries/dagster-airflow/setup.py @@ -1,41 +1,40 @@ from setuptools import find_packages, setup def get_version(): version = {} with open("dagster_airflow/version.py") as fp: exec(fp.read(), version) # pylint: disable=W0122 return version["__version__"] if __name__ == "__main__": ver = get_version() setup( name="dagster-airflow", version=ver, author="Elementl", author_email="hello@elementl.com", license="Apache-2.0", description="Airflow plugin for Dagster", url="https://github.com/dagster-io/dagster", classifiers=[ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], packages=find_packages(exclude=["dagster_airflow_tests"]), install_requires=[ - "six", "dagster=={ver}".format(ver=ver), "docker", "python-dateutil>=2.8.0", "lazy_object_proxy", "pendulum==1.4.4", # https://issues.apache.org/jira/browse/AIRFLOW-6854 'typing_extensions; python_version>="3.8"', ], extras_require={"kubernetes": ["kubernetes>=3.0.0", "cryptography>=2.0.0"]}, entry_points={"console_scripts": ["dagster-airflow = dagster_airflow.cli:main"]}, ) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py b/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py index 9f18c1ebb..20c355d12 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py @@ -1,267 +1,267 @@ import csv import io import os import time import uuid from urllib.parse import urlparse import boto3 from botocore.stub import Stubber from dagster import Field, StringSource, check, resource class AthenaError(Exception): pass class AthenaTimeout(AthenaError): pass class AthenaResource: def __init__(self, client, workgroup="primary", polling_interval=5, max_polls=120): check.invariant( polling_interval >= 0, "polling_interval must be greater than or equal to 0" ) check.invariant(max_polls > 0, "max_polls must be greater than 0") self.client = client self.workgroup = workgroup self.max_polls = max_polls self.polling_interval = polling_interval def execute_query(self, query, fetch_results=False): """Synchronously execute a single query against Athena. 
If fetch_results is set to true, will return a list of rows, where each row is a tuple of stringified values, e.g. SELECT 1 will return [("1",)]. Args: query (str): The query to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. Returns: Optional[List[Tuple[Optional[str], ...]]]: Results of the query, as a list of tuples, when fetch_results is set. Otherwise, return None. All items in the tuple are represented as strings except for empty columns which are represented as None. """ check.str_param(query, "query") check.bool_param(fetch_results, "fetch_results") execution_id = self.client.start_query_execution( QueryString=query, WorkGroup=self.workgroup )["QueryExecutionId"] self._poll(execution_id) if fetch_results: return self._results(execution_id) def _poll(self, execution_id): retries = self.max_polls state = "QUEUED" while retries > 0: execution = self.client.get_query_execution(QueryExecutionId=execution_id)[ "QueryExecution" ] state = execution["Status"]["State"] if state not in ["QUEUED", "RUNNING"]: break retries -= 1 time.sleep(self.polling_interval) if retries <= 0: raise AthenaTimeout() if state != "SUCCEEDED": raise AthenaError(execution["Status"]["StateChangeReason"]) def _results(self, execution_id): execution = self.client.get_query_execution(QueryExecutionId=execution_id)["QueryExecution"] s3 = boto3.resource("s3") output_location = execution["ResultConfiguration"]["OutputLocation"] bucket = urlparse(output_location).netloc prefix = urlparse(output_location).path.lstrip("/") results = [] - rows = s3.Bucket(bucket).Object(prefix).get()["Body"].read().decode().splitlines() + rows = s3.Bucket(bucket).Object(prefix).get()["Body"].read().decode("utf-8").splitlines() reader = csv.reader(rows) next(reader) # Skip the CSV's header row for row in reader: results.append(tuple(row)) return results class FakeAthenaResource(AthenaResource): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.polling_interval = 0 self.stubber = Stubber(self.client) s3 = boto3.resource("s3", region_name="us-east-1") self.bucket = s3.Bucket("fake-athena-results-bucket") self.bucket.create() def execute_query( self, query, fetch_results=False, expected_states=None, expected_results=None ): # pylint: disable=arguments-differ """Fake for execute_query; stubs the expected Athena endpoints, polls against the provided expected query execution states, and returns the provided results as a list of tuples. Args: query (str): The query to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. expected_states (list[str]): The expected query execution states. Defaults to successfully passing through QUEUED, RUNNING, and SUCCEEDED. expected_results ([List[Tuple[Any, ...]]]): The expected results. All non-None items are cast to strings. Defaults to [(1,)]. Returns: Optional[List[Tuple[Optional[str], ...]]]: The expected_results when fetch_results is set. Otherwise, return None. All items in the tuple are represented as strings except for empty columns which are represented as None. 
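Example (a minimal sketch; the test name is illustrative, and it assumes S3 access is
mocked, e.g. with ``moto``, because the fake round-trips its result CSV through a
fake S3 bucket):

.. code-block:: python

    import boto3
    from moto import mock_s3  # assumption: moto provides the fake S3 backend

    from dagster_aws.athena.resources import FakeAthenaResource

    @mock_s3
    def test_fake_athena_query():
        # The Athena client itself is stubbed by this class; only S3 needs mocking.
        fake = FakeAthenaResource(client=boto3.client("athena", region_name="us-east-1"))
        # Default states pass through QUEUED -> RUNNING -> SUCCEEDED; results are stringified.
        assert fake.execute_query(
            "SELECT 1", fetch_results=True, expected_results=[(1,)]
        ) == [("1",)]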
""" if not expected_states: expected_states = ["QUEUED", "RUNNING", "SUCCEEDED"] if not expected_results: expected_results = [("1",)] self.stubber.activate() execution_id = str(uuid.uuid4()) self._stub_start_query_execution(execution_id, query) self._stub_get_query_execution(execution_id, expected_states) if expected_states[-1] == "SUCCEEDED" and fetch_results: self._fake_results(execution_id, expected_results) result = super().execute_query(query, fetch_results=fetch_results) self.stubber.deactivate() self.stubber.assert_no_pending_responses() return result def _stub_start_query_execution(self, execution_id, query): self.stubber.add_response( method="start_query_execution", service_response={"QueryExecutionId": execution_id}, expected_params={"QueryString": query, "WorkGroup": self.workgroup}, ) def _stub_get_query_execution(self, execution_id, states): for state in states: self.stubber.add_response( method="get_query_execution", service_response={ "QueryExecution": { "Status": {"State": state, "StateChangeReason": "state change reason"}, } }, expected_params={"QueryExecutionId": execution_id}, ) def _fake_results(self, execution_id, expected_results): with io.StringIO() as results: writer = csv.writer(results) # Athena adds a header row to its CSV output writer.writerow([]) for row in expected_results: # Athena writes all non-null columns as strings in its CSV output stringified = tuple([str(item) for item in row if item]) writer.writerow(stringified) results.seek(0) self.bucket.Object(execution_id + ".csv").put(Body=results.read()) self.stubber.add_response( method="get_query_execution", service_response={ "QueryExecution": { "ResultConfiguration": { "OutputLocation": os.path.join( "s3://", self.bucket.name, execution_id + ".csv" ) } } }, expected_params={"QueryExecutionId": execution_id}, ) def athena_config(): """Athena configuration.""" return { "workgroup": Field( str, description="The Athena WorkGroup. https://docs.aws.amazon.com/athena/latest/ug/manage-queries-control-costs-with-workgroups.html", is_required=False, default_value="primary", ), "polling_interval": Field( int, description="Time in seconds between checks to see if a query execution is finished. 5 seconds by default. Must be non-negative.", is_required=False, default_value=5, ), "max_polls": Field( int, description="Number of times to poll before timing out. 120 attempts by default. When coupled with the default polling_interval, queries will timeout after 10 minutes (120 * 5 seconds). Must be greater than 0.", is_required=False, default_value=120, ), "aws_access_key_id": Field(StringSource, is_required=False), "aws_secret_access_key": Field(StringSource, is_required=False), } @resource( config_schema=athena_config(), description="Resource for connecting to AWS Athena", ) def athena_resource(context): """This resource enables connecting to AWS Athena and issuing queries against it. Example: .. 
code-block:: python from dagster import ModeDefinition, execute_solid, solid from dagster_aws.athena import athena_resource @solid(required_resource_keys={"athena"}) def example_athena_solid(context): return context.resources.athena.execute_query("SELECT 1", fetch_results=True) result = execute_solid( example_athena_solid, mode_def=ModeDefinition(resource_defs={"athena": athena_resource}), ) assert result.output_value() == [("1",)] """ client = boto3.client( "athena", aws_access_key_id=context.resource_config.get("aws_access_key_id"), aws_secret_access_key=context.resource_config.get("aws_secret_access_key"), ) return AthenaResource( client=client, workgroup=context.resource_config.get("workgroup"), polling_interval=context.resource_config.get("polling_interval"), max_polls=context.resource_config.get("max_polls"), ) @resource( config_schema=athena_config(), description="Fake resource for connecting to AWS Athena", ) def fake_athena_resource(context): return FakeAthenaResource( client=boto3.client("athena", region_name="us-east-1"), workgroup=context.resource_config.get("workgroup"), polling_interval=context.resource_config.get("polling_interval"), max_polls=context.resource_config.get("max_polls"), ) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py b/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py index 628990547..6a3accf74 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py @@ -1,425 +1,423 @@ # Portions of this file are copied from the Yelp MRJob project: # # https://github.com/Yelp/mrjob # # # Copyright 2009-2013 Yelp, David Marin # Copyright 2015 Yelp # Copyright 2017 Yelp # Copyright 2018 Contributors # Copyright 2019 Yelp and Contributors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gzip import re from io import BytesIO from urllib.parse import urlparse import boto3 import dagster -import six from botocore.exceptions import WaiterError from dagster import check from dagster_aws.utils.mrjob.utils import _boto3_now, _wrap_aws_client, strip_microseconds from .types import EMR_CLUSTER_TERMINATED_STATES, EmrClusterState, EmrStepState # if we can't create or find our own service role, use the one # created by the AWS console and CLI _FALLBACK_SERVICE_ROLE = "EMR_DefaultRole" # if we can't create or find our own instance profile, use the one # created by the AWS console and CLI _FALLBACK_INSTANCE_PROFILE = "EMR_EC2_DefaultRole" class EmrError(Exception): pass class EmrJobRunner: def __init__( self, region, check_cluster_every=30, aws_access_key_id=None, aws_secret_access_key=None, ): """This object encapsulates various utilities for interacting with EMR clusters and invoking steps (jobs) on them. See also :py:class:`~dagster_aws.emr.EmrPySparkResource`, which wraps this job runner in a resource for pyspark workloads. Args: region (str): AWS region to use check_cluster_every (int, optional): How frequently to poll boto3 APIs for updates. Defaults to 30 seconds. 
aws_access_key_id ([type], optional): AWS access key ID. Defaults to None, which will use the default boto3 credentials chain. aws_secret_access_key ([type], optional): AWS secret access key. Defaults to None, which will use the default boto3 credentials chain. """ self.region = check.str_param(region, "region") # This is in seconds self.check_cluster_every = check.int_param(check_cluster_every, "check_cluster_every") self.aws_access_key_id = check.opt_str_param(aws_access_key_id, "aws_access_key_id") self.aws_secret_access_key = check.opt_str_param( aws_secret_access_key, "aws_secret_access_key" ) def make_emr_client(self): """Creates a boto3 EMR client. Construction is wrapped in retries in case client connection fails transiently. Returns: botocore.client.EMR: An EMR client """ raw_emr_client = boto3.client( "emr", aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, region_name=self.region, ) return _wrap_aws_client(raw_emr_client, min_backoff=self.check_cluster_every) def cluster_id_from_name(self, cluster_name): """Get a cluster ID in the format "j-123ABC123ABC1" given a cluster name "my cool cluster". Args: cluster_name (str): The name of the cluster for which to find an ID Returns: str: The ID of the cluster Raises: EmrError: No cluster with the specified name exists """ check.str_param(cluster_name, "cluster_name") response = self.make_emr_client().list_clusters().get("Clusters", []) for cluster in response: if cluster["Name"] == cluster_name: return cluster["Id"] raise EmrError( "cluster {cluster_name} not found in region {region}".format( cluster_name=cluster_name, region=self.region ) ) @staticmethod def construct_step_dict_for_command(step_name, command, action_on_failure="CONTINUE"): """Construct an EMR step definition which uses command-runner.jar to execute a shell command on the EMR master. Args: step_name (str): The name of the EMR step (will show up in the EMR UI) command (str): The shell command to execute with command-runner.jar action_on_failure (str, optional): Configure action on failure (e.g., continue, or terminate the cluster). Defaults to 'CONTINUE'. Returns: dict: Step definition dict """ check.str_param(step_name, "step_name") check.list_param(command, "command", of_type=str) check.str_param(action_on_failure, "action_on_failure") return { "Name": step_name, "ActionOnFailure": action_on_failure, "HadoopJarStep": {"Jar": "command-runner.jar", "Args": command}, } def add_tags(self, log, tags, cluster_id): """Add tags in the dict tags to cluster cluster_id. Args: log (DagsterLogManager): Log manager, for logging tags (dict): Dictionary of {'key': 'value'} tags cluster_id (str): The ID of the cluster to tag """ check.dict_param(tags, "tags") check.str_param(cluster_id, "cluster_id") tags_items = sorted(tags.items()) self.make_emr_client().add_tags( ResourceId=cluster_id, Tags=[dict(Key=k, Value=v) for k, v in tags_items] ) log.info( "Added EMR tags to cluster %s: %s" % (cluster_id, ", ".join("%s=%s" % (tag, value) for tag, value in tags_items)) ) def run_job_flow(self, log, cluster_config): """Create an empty cluster on EMR, and return the ID of that job flow. Args: log (DagsterLogManager): Log manager, for logging cluster_config (dict): Configuration for this EMR job flow. See: https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html Returns: str: The cluster ID, e.g. 
"j-ZKIY4CKQRX72" """ check.dict_param(cluster_config, "cluster_config") log.debug("Creating Elastic MapReduce cluster") emr_client = self.make_emr_client() log.debug( "Calling run_job_flow(%s)" % (", ".join("%s=%r" % (k, v) for k, v in sorted(cluster_config.items()))) ) cluster_id = emr_client.run_job_flow(**cluster_config)["JobFlowId"] log.info("Created new cluster %s" % cluster_id) # set EMR tags for the cluster tags = cluster_config.get("Tags", {}) tags["__dagster_version"] = dagster.__version__ self.add_tags(log, tags, cluster_id) return cluster_id def describe_cluster(self, cluster_id): """Thin wrapper over boto3 describe_cluster. Args: cluster_id (str): Cluster to inspect Returns: dict: The cluster info. See: https://docs.aws.amazon.com/emr/latest/APIReference/API_DescribeCluster.html """ check.str_param(cluster_id, "cluster_id") emr_client = self.make_emr_client() return emr_client.describe_cluster(ClusterId=cluster_id) def describe_step(self, cluster_id, step_id): """Thin wrapper over boto3 describe_step. Args: cluster_id (str): Cluster to inspect step_id (str): Step ID to describe Returns: dict: The step info. See: https://docs.aws.amazon.com/emr/latest/APIReference/API_DescribeStep.html """ check.str_param(cluster_id, "cluster_id") check.str_param(step_id, "step_id") emr_client = self.make_emr_client() return emr_client.describe_step(ClusterId=cluster_id, StepId=step_id) def add_job_flow_steps(self, log, cluster_id, step_defs): """Submit the constructed job flow steps to EMR for execution. Args: log (DagsterLogManager): Log manager, for logging cluster_id (str): The ID of the cluster step_defs (List[dict]): List of steps; see also `construct_step_dict_for_command` Returns: List[str]: list of step IDs. """ check.str_param(cluster_id, "cluster_id") check.list_param(step_defs, "step_defs", of_type=dict) emr_client = self.make_emr_client() steps_kwargs = dict(JobFlowId=cluster_id, Steps=step_defs) log.debug( "Calling add_job_flow_steps(%s)" % ",".join(("%s=%r" % (k, v)) for k, v in steps_kwargs.items()) ) return emr_client.add_job_flow_steps(**steps_kwargs)["StepIds"] def is_emr_step_complete(self, log, cluster_id, emr_step_id): step = self.describe_step(cluster_id, emr_step_id)["Step"] step_state = EmrStepState(step["Status"]["State"]) if step_state == EmrStepState.Pending: cluster = self.describe_cluster(cluster_id)["Cluster"] reason = _get_reason(cluster) reason_desc = (": %s" % reason) if reason else "" log.info("PENDING (cluster is %s%s)" % (cluster["Status"]["State"], reason_desc)) return False elif step_state == EmrStepState.Running: time_running_desc = "" start = step["Status"]["Timeline"].get("StartDateTime") if start: time_running_desc = " for %s" % strip_microseconds(_boto3_now() - start) log.info("RUNNING%s" % time_running_desc) return False # we're done, will return at the end of this elif step_state == EmrStepState.Completed: log.info("COMPLETED") return True else: # step has failed somehow. *reason* seems to only be set # when job is cancelled (e.g. 
'Job terminated') reason = _get_reason(step) reason_desc = (" (%s)" % reason) if reason else "" log.info("%s%s" % (step_state.value, reason_desc)) # print cluster status; this might give more context # why step didn't succeed cluster = self.describe_cluster(cluster_id)["Cluster"] reason = _get_reason(cluster) reason_desc = (": %s" % reason) if reason else "" log.info( "Cluster %s %s %s%s" % ( cluster["Id"], "was" if "ED" in cluster["Status"]["State"] else "is", cluster["Status"]["State"], reason_desc, ) ) if EmrClusterState(cluster["Status"]["State"]) in EMR_CLUSTER_TERMINATED_STATES: # was it caused by IAM roles? self._check_for_missing_default_iam_roles(log, cluster) # TODO: extract logs here to surface failure reason # See: https://github.com/dagster-io/dagster/issues/1954 if step_state == EmrStepState.Failed: log.error("EMR step %s failed" % emr_step_id) raise EmrError("EMR step %s failed" % emr_step_id) def _check_for_missing_default_iam_roles(self, log, cluster): """If cluster couldn't start due to missing IAM roles, tell user what to do.""" check.dict_param(cluster, "cluster") reason = _get_reason(cluster) if any( reason.endswith("/%s is invalid" % role) for role in (_FALLBACK_INSTANCE_PROFILE, _FALLBACK_SERVICE_ROLE) ): log.warning( "IAM roles are missing. See documentation for IAM roles on EMR here: " "https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-iam-roles.html" ) def log_location_for_cluster(self, cluster_id): """EMR clusters are typically launched with S3 logging configured. This method inspects a cluster using boto3 describe_cluster to retrieve the log URI. Args: cluster_id (str): The cluster to inspect. Raises: EmrError: the log URI was missing (S3 log mirroring not enabled for this cluster) Returns: (str, str): log bucket and key """ check.str_param(cluster_id, "cluster_id") # The S3 log URI is specified per job flow (cluster) log_uri = self.describe_cluster(cluster_id)["Cluster"].get("LogUri", None) # ugh, seriously boto3?! This will come back as string "None" if log_uri == "None" or log_uri is None: raise EmrError("Log URI not specified, cannot retrieve step execution logs") # For some reason the API returns an s3n:// protocol log URI instead of s3:// log_uri = re.sub("^s3n", "s3", log_uri) log_uri_parsed = urlparse(log_uri) log_bucket = log_uri_parsed.netloc log_key_prefix = log_uri_parsed.path.lstrip("/") return log_bucket, log_key_prefix def retrieve_logs_for_step_id(self, log, cluster_id, step_id): """Retrieves stdout and stderr logs for the given step ID. Args: log (DagsterLogManager): Log manager, for logging cluster_id (str): EMR cluster ID step_id (str): EMR step ID for the job that was submitted. Returns (str, str): Tuple of stdout log string contents, and stderr log string contents """ check.str_param(cluster_id, "cluster_id") check.str_param(step_id, "step_id") log_bucket, log_key_prefix = self.log_location_for_cluster(cluster_id) prefix = "{log_key_prefix}{cluster_id}/steps/{step_id}".format( log_key_prefix=log_key_prefix, cluster_id=cluster_id, step_id=step_id ) stdout_log = self.wait_for_log(log, log_bucket, "{prefix}/stdout.gz".format(prefix=prefix)) stderr_log = self.wait_for_log(log, log_bucket, "{prefix}/stderr.gz".format(prefix=prefix)) return stdout_log, stderr_log def wait_for_log(self, log, log_bucket, log_key, waiter_delay=30, waiter_max_attempts=20): """Wait for gzipped EMR logs to appear on S3. Note that EMR syncs logs to S3 every 5 minutes, so this may take a long time. 
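With the defaults (``waiter_delay=30``, ``waiter_max_attempts=20``), the waiter gives up after roughly 10 minutes (20 attempts * 30 seconds).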
Args: log_bucket (str): S3 bucket where log is expected to appear log_key (str): S3 key for the log file waiter_delay (int): How long to wait between attempts to check S3 for the log file waiter_max_attempts (int): Number of attempts before giving up on waiting Raises: EmrError: Raised if we waited the full duration and the logs did not appear Returns: str: contents of the log file """ check.str_param(log_bucket, "log_bucket") check.str_param(log_key, "log_key") check.int_param(waiter_delay, "waiter_delay") check.int_param(waiter_max_attempts, "waiter_max_attempts") log.info( "Attempting to get log: s3://{log_bucket}/{log_key}".format( log_bucket=log_bucket, log_key=log_key ) ) s3 = _wrap_aws_client(boto3.client("s3"), min_backoff=self.check_cluster_every) waiter = s3.get_waiter("object_exists") try: waiter.wait( Bucket=log_bucket, Key=log_key, WaiterConfig={"Delay": waiter_delay, "MaxAttempts": waiter_max_attempts}, ) except WaiterError as err: - six.raise_from( - EmrError("EMR log file did not appear on S3 after waiting"), err, - ) + raise EmrError("EMR log file did not appear on S3 after waiting") from err + obj = BytesIO(s3.get_object(Bucket=log_bucket, Key=log_key)["Body"].read()) gzip_file = gzip.GzipFile(fileobj=obj) return gzip_file.read().decode("utf-8") def _get_reason(cluster_or_step): """Get state change reason message.""" # StateChangeReason is {} before the first state change return cluster_or_step["Status"]["StateChangeReason"].get("Message", "") diff --git a/python_modules/libraries/dagster-aws/dagster_aws/redshift/resources.py b/python_modules/libraries/dagster-aws/dagster_aws/redshift/resources.py index 6e7d4df45..fe123de2c 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/redshift/resources.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/redshift/resources.py @@ -1,363 +1,358 @@ import abc from contextlib import contextmanager import psycopg2 import psycopg2.extensions -import six from dagster import Field, IntSource, StringSource, check, resource class RedshiftError(Exception): pass class _BaseRedshiftResource(abc.ABC): def __init__(self, context): # pylint: disable=too-many-locals # Extract parameters from resource config self.conn_args = { k: context.resource_config.get(k) for k in ( "host", "port", "user", "password", "database", "schema", "connect_timeout", "sslmode", ) if context.resource_config.get(k) is not None } self.autocommit = context.resource_config.get("autocommit") self.log = context.log_manager @abc.abstractmethod def execute_query(self, query, fetch_results=False, cursor_factory=None, error_callback=None): pass @abc.abstractmethod def execute_queries( self, queries, fetch_results=False, cursor_factory=None, error_callback=None ): pass class RedshiftResource(_BaseRedshiftResource): def execute_query(self, query, fetch_results=False, cursor_factory=None, error_callback=None): """Synchronously execute a single query against Redshift. Will return a list of rows, where each row is a tuple of values, e.g. SELECT 1 will return [(1,)]. Args: query (str): The query to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. cursor_factory (Optional[:py:class:`psycopg2.extensions.cursor`]): An alternative cursor_factory; defaults to None. Will be used when constructing the cursor. 
error_callback (Optional[Callable[[Exception, Cursor, DagsterLogManager], None]]): A callback function, invoked when an exception is encountered during query execution; this is intended to support executing additional queries to provide diagnostic information, e.g. by querying ``stl_load_errors`` using ``pg_last_copy_id()``. If no function is provided, exceptions during query execution will be raised directly. Returns: Optional[List[Tuple[Any, ...]]]: Results of the query, as a list of tuples, when fetch_results is set. Otherwise return None. """ check.str_param(query, "query") check.bool_param(fetch_results, "fetch_results") check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor) check.opt_callable_param(error_callback, "error_callback") with self._get_conn() as conn: with self._get_cursor(conn, cursor_factory=cursor_factory) as cursor: try: - six.ensure_str(query) - self.log.info("Executing query '{query}'".format(query=query)) cursor.execute(query) if fetch_results and cursor.rowcount > 0: return cursor.fetchall() else: self.log.info("Empty result from query") except Exception as e: # pylint: disable=broad-except # If autocommit is disabled or not set (it is disabled by default), Redshift # will be in the middle of a transaction at exception time, and because of # the failure the current transaction will not accept any further queries. # # This conn.commit() call closes the open transaction before handing off # control to the error callback, so that the user can issue additional # queries. Notably, for e.g. pg_last_copy_id() to work, it requires you to # use the same conn/cursor, so you have to do this conn.commit() to ensure # things are in a usable state in the error callback. if not self.autocommit: conn.commit() if error_callback is not None: error_callback(e, cursor, self.log) else: raise def execute_queries( self, queries, fetch_results=False, cursor_factory=None, error_callback=None ): """Synchronously execute a list of queries against Redshift. Will return a list of list of rows, where each row is a tuple of values, e.g. ['SELECT 1', 'SELECT 1'] will return [[(1,)], [(1,)]]. Args: queries (List[str]): The queries to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. cursor_factory (Optional[:py:class:`psycopg2.extensions.cursor`]): An alternative cursor_factory; defaults to None. Will be used when constructing the cursor. error_callback (Optional[Callable[[Exception, Cursor, DagsterLogManager], None]]): A callback function, invoked when an exception is encountered during query execution; this is intended to support executing additional queries to provide diagnostic information, e.g. by querying ``stl_load_errors`` using ``pg_last_copy_id()``. If no function is provided, exceptions during query execution will be raised directly. Returns: Optional[List[List[Tuple[Any, ...]]]]: Results of the query, as a list of list of tuples, when fetch_results is set. Otherwise return None. 
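Example (a minimal sketch of an ``error_callback``; the diagnostic query is illustrative):

.. code-block:: python

    def error_callback(error, cursor, log):
        # When autocommit is disabled, the open transaction has already been
        # committed before this callback runs, so the same cursor can issue
        # diagnostic queries before re-raising.
        cursor.execute("SELECT * FROM stl_load_errors WHERE query = pg_last_copy_id()")
        log.error(str(cursor.fetchall()))
        raise error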
""" check.list_param(queries, "queries", of_type=str) check.bool_param(fetch_results, "fetch_results") check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor) check.opt_callable_param(error_callback, "error_callback") results = [] with self._get_conn() as conn: with self._get_cursor(conn, cursor_factory=cursor_factory) as cursor: for query in queries: - six.ensure_str(query) - try: self.log.info("Executing query '{query}'".format(query=query)) cursor.execute(query) if fetch_results and cursor.rowcount > 0: results.append(cursor.fetchall()) else: results.append([]) self.log.info("Empty result from query") except Exception as e: # pylint: disable=broad-except # If autocommit is disabled or not set (it is disabled by default), Redshift # will be in the middle of a transaction at exception time, and because of # the failure the current transaction will not accept any further queries. # # This conn.commit() call closes the open transaction before handing off # control to the error callback, so that the user can issue additional # queries. Notably, for e.g. pg_last_copy_id() to work, it requires you to # use the same conn/cursor, so you have to do this conn.commit() to ensure # things are in a usable state in the error callback. if not self.autocommit: conn.commit() if error_callback is not None: error_callback(e, cursor, self.log) else: raise if fetch_results: return results @contextmanager def _get_conn(self): try: conn = psycopg2.connect(**self.conn_args) yield conn finally: conn.close() @contextmanager def _get_cursor(self, conn, cursor_factory=None): check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor) # Could be none, in which case we should respect the connection default. Otherwise # explicitly set to true/false. if self.autocommit is not None: conn.autocommit = self.autocommit with conn: with conn.cursor(cursor_factory=cursor_factory) as cursor: yield cursor # If autocommit is set, we'll commit after each and every query execution. Otherwise, we # want to do a final commit after we're wrapped up executing the full set of one or more # queries. if not self.autocommit: conn.commit() class FakeRedshiftResource(_BaseRedshiftResource): QUERY_RESULT = [(1,)] def execute_query(self, query, fetch_results=False, cursor_factory=None, error_callback=None): """Fake for execute_query; returns [self.QUERY_RESULT] Args: query (str): The query to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. cursor_factory (Optional[:py:class:`psycopg2.extensions.cursor`]): An alternative cursor_factory; defaults to None. Will be used when constructing the cursor. error_callback (Optional[Callable[[Exception, Cursor, DagsterLogManager], None]]): A callback function, invoked when an exception is encountered during query execution; this is intended to support executing additional queries to provide diagnostic information, e.g. by querying ``stl_load_errors`` using ``pg_last_copy_id()``. If no function is provided, exceptions during query execution will be raised directly. Returns: Optional[List[Tuple[Any, ...]]]: Results of the query, as a list of tuples, when fetch_results is set. Otherwise return None. 
""" check.str_param(query, "query") check.bool_param(fetch_results, "fetch_results") check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor) check.opt_callable_param(error_callback, "error_callback") self.log.info("Executing query '{query}'".format(query=query)) if fetch_results: return self.QUERY_RESULT def execute_queries( self, queries, fetch_results=False, cursor_factory=None, error_callback=None ): """Fake for execute_queries; returns [self.QUERY_RESULT] * 3 Args: queries (List[str]): The queries to execute. fetch_results (Optional[bool]): Whether to return the results of executing the query. Defaults to False, in which case the query will be executed without retrieving the results. cursor_factory (Optional[:py:class:`psycopg2.extensions.cursor`]): An alternative cursor_factory; defaults to None. Will be used when constructing the cursor. error_callback (Optional[Callable[[Exception, Cursor, DagsterLogManager], None]]): A callback function, invoked when an exception is encountered during query execution; this is intended to support executing additional queries to provide diagnostic information, e.g. by querying ``stl_load_errors`` using ``pg_last_copy_id()``. If no function is provided, exceptions during query execution will be raised directly. Returns: Optional[List[List[Tuple[Any, ...]]]]: Results of the query, as a list of list of tuples, when fetch_results is set. Otherwise return None. """ check.list_param(queries, "queries", of_type=str) check.bool_param(fetch_results, "fetch_results") check.opt_subclass_param(cursor_factory, "cursor_factory", psycopg2.extensions.cursor) check.opt_callable_param(error_callback, "error_callback") for query in queries: self.log.info("Executing query '{query}'".format(query=query)) if fetch_results: return [self.QUERY_RESULT] * 3 def define_redshift_config(): """Redshift configuration. See the Redshift documentation for reference: https://docs.aws.amazon.com/redshift/latest/mgmt/connecting-to-cluster.html """ return { "host": Field(StringSource, description="Redshift host", is_required=True), "port": Field( IntSource, description="Redshift port", is_required=False, default_value=5439 ), "user": Field( StringSource, description="Username for Redshift connection", is_required=False, ), "password": Field( StringSource, description="Password for Redshift connection", is_required=False, ), "database": Field( StringSource, description="Name of the default database to use. After login, you can use USE DATABASE" " to change the database.", is_required=False, ), "schema": Field( StringSource, description="Name of the default schema to use. After login, you can use USE SCHEMA to " "change the schema.", is_required=False, ), "autocommit": Field( bool, description="None by default, which honors the Redshift parameter AUTOCOMMIT. Set to " "True or False to enable or disable autocommit mode in the session, respectively.", is_required=False, ), "connect_timeout": Field( int, description="Connection timeout in seconds. 5 seconds by default", is_required=False, default_value=5, ), "sslmode": Field( str, description="SSL mode to use. 
See the Redshift documentation for more information on " "usage: https://docs.aws.amazon.com/redshift/latest/mgmt/connecting-ssl-support.html", is_required=False, default_value="require", ), } @resource( config_schema=define_redshift_config(), description="Resource for connecting to the Redshift data warehouse", ) def redshift_resource(context): """This resource enables connecting to a Redshift cluster and issuing queries against that cluster. Example: .. code-block:: python from dagster import ModeDefinition, execute_solid, solid from dagster_aws.redshift import redshift_resource @solid(required_resource_keys={'redshift'}) def example_redshift_solid(context): return context.resources.redshift.execute_query('SELECT 1', fetch_results=True) result = execute_solid( example_redshift_solid, run_config={ 'resources': { 'redshift': { 'config': { 'host': 'my-redshift-cluster.us-east-1.redshift.amazonaws.com', 'port': 5439, 'user': 'dagster', 'password': 'dagster', 'database': 'dev', } } } }, mode_def=ModeDefinition(resource_defs={'redshift': redshift_resource}), ) assert result.output_value() == [(1,)] """ return RedshiftResource(context) @resource( config_schema=define_redshift_config(), description="Fake resource for connecting to the Redshift data warehouse. Usage is identical " "to the real redshift_resource. Will always return [(1,)] for the single query case and " "[[(1,)], [(1,)], [(1,)]] for the multi query case.", ) def fake_redshift_resource(context): return FakeRedshiftResource(context) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/utils/mrjob/log4j.py b/python_modules/libraries/dagster-aws/dagster_aws/utils/mrjob/log4j.py index c5b86d7f5..d4b0da9f5 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/utils/mrjob/log4j.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/utils/mrjob/log4j.py @@ -1,132 +1,130 @@ # -*- coding: utf-8 -*- # Copyright 2015-2016 Yelp # Copyright 2019 Yelp # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Parse the log4j syslog format used by Hadoop.""" import re from collections import namedtuple -import six - # log line format output by hadoop jar command _HADOOP_LOG4J_LINE_RE = re.compile( r"^\s*(?P<timestamp>\d\d\/\d\d\/\d\d \d\d\:\d\d\:\d\d)" r"\s+(?P<level>[A-Z]+)" r"\s+(?P<logger>\S+)" r"(\s+\((?P<caller_location>.*?)\))?" r"( - ?|: ?)" r"(?P<message>.*?)$" ) # log line format output to Hadoop syslog _HADOOP_LOG4J_LINE_ALTERNATE_RE = re.compile( r"^\s*(?P<timestamp>\d\d\/\d\d\/\d\d \d\d\:\d\d\:\d\d)" r"\s+(?P<level>[A-Z]+)" r"(\s+\[(?P<thread>.*?)\])" r"\s+(?P<logger>\S+)" r"(\s+\((?P<caller_location>\S+)\))?" r"( - ?|: ?)" r"(?P<message>.*?)$" ) class Log4jRecord( namedtuple( "_Log4jRecord", "caller_location level logger message num_lines start_line thread timestamp" ) ): """Represents a Log4J log record. caller_location -- e.g. 'YarnClientImpl.java:submitApplication(251)' level -- e.g. 'INFO' logger -- e.g. 'amazon.emr.metrics.MetricsSaver' message -- the actual message. If this is a multi-line message (e.g.
for counters), the lines will be joined by '\n' num_lines -- how many lines made up the message start_line -- which line the message started on (0-indexed) thread -- e.g. 'main'. Defaults to '' timestamp -- unparsed timestamp, e.g. '15/12/07 20:49:28' """ def __new__( cls, caller_location, level, logger, message, num_lines, start_line, thread, timestamp ): return super(Log4jRecord, cls).__new__( cls, caller_location, level, logger, message, num_lines, start_line, thread, timestamp ) @staticmethod def fake_record(line, line_num): """Used to represent a leading Log4J line that doesn't conform to the regular expressions we expect. """ return Log4jRecord( caller_location="", level="", logger="", message=line, num_lines=1, start_line=line_num, thread="", timestamp="", ) def parse_hadoop_log4j_records(lines): """Parse lines from a hadoop log into log4j records. Yield Log4jRecords. Lines will be converted to unicode, and trailing \r and \n will be stripped from lines. Also yields fake records for leading non-log4j lines (trailing non-log4j lines are assumed to be part of a multiline message if not pre-filtered). """ last_record = None line_num = 0 for line_num, line in enumerate(lines.split("\n")): # convert from bytes to unicode, if needed, and strip trailing newlines - line = six.ensure_str(line).rstrip("\r\n") + line = line.rstrip("\r\n") m = _HADOOP_LOG4J_LINE_RE.match(line) or _HADOOP_LOG4J_LINE_ALTERNATE_RE.match(line) if m: if last_record: last_record = last_record._replace(num_lines=line_num - last_record.start_line) yield last_record matches = m.groupdict() last_record = Log4jRecord( caller_location=matches.get("caller_location", ""), level=matches["level"], logger=matches["logger"], message=matches["message"], num_lines=1, start_line=line_num, thread=matches.get("thread", ""), timestamp=matches["timestamp"], ) else: # add on to previous record if last_record: last_record = last_record._replace(message=last_record.message + "\n" + line) else: yield Log4jRecord.fake_record(line, line_num) if last_record: last_record = last_record._replace(num_lines=line_num + 1 - last_record.start_line) yield last_record diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/emr_tests/test_emr.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/emr_tests/test_emr.py index 40e830240..875939a8b 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/emr_tests/test_emr.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/emr_tests/test_emr.py @@ -1,208 +1,208 @@ import copy import gzip import io import threading import time import pytest from dagster.seven import mock from dagster.utils.test import create_test_pipeline_execution_context from dagster_aws.emr import EmrClusterState, EmrError, EmrJobRunner from dagster_aws.utils.mrjob.utils import _boto3_now from moto import mock_emr REGION = "us-west-1" @mock_emr def test_emr_create_cluster(emr_cluster_config): context = create_test_pipeline_execution_context() cluster = EmrJobRunner(region=REGION) cluster_id = cluster.run_job_flow(context.log, emr_cluster_config) assert cluster_id.startswith("j-") @mock_emr def test_emr_add_tags_and_describe_cluster(emr_cluster_config): context = create_test_pipeline_execution_context() emr = EmrJobRunner(region=REGION) cluster_id = emr.run_job_flow(context.log, emr_cluster_config) emr.add_tags(context.log, {"foobar": "v1", "baz": "123"}, cluster_id) tags = emr.describe_cluster(cluster_id)["Cluster"]["Tags"] assert {"Key": "baz", "Value": "123"} in tags assert {"Key": "foobar", 
"Value": "v1"} in tags @mock_emr def test_emr_describe_cluster(emr_cluster_config): context = create_test_pipeline_execution_context() cluster = EmrJobRunner(region=REGION) cluster_id = cluster.run_job_flow(context.log, emr_cluster_config) cluster_info = cluster.describe_cluster(cluster_id)["Cluster"] assert cluster_info["Name"] == "test-emr" assert EmrClusterState(cluster_info["Status"]["State"]) == EmrClusterState.Waiting @mock_emr def test_emr_id_from_name(emr_cluster_config): context = create_test_pipeline_execution_context() cluster = EmrJobRunner(region=REGION) cluster_id = cluster.run_job_flow(context.log, emr_cluster_config) assert cluster.cluster_id_from_name("test-emr") == cluster_id with pytest.raises(EmrError) as exc_info: cluster.cluster_id_from_name("cluster-doesnt-exist") assert "cluster cluster-doesnt-exist not found in region us-west-1" in str(exc_info.value) def test_emr_construct_step_dict(): cmd = ["pip", "install", "dagster"] assert EmrJobRunner.construct_step_dict_for_command("test_step", cmd) == { "Name": "test_step", "ActionOnFailure": "CONTINUE", "HadoopJarStep": {"Jar": "command-runner.jar", "Args": cmd}, } assert EmrJobRunner.construct_step_dict_for_command( "test_second_step", cmd, action_on_failure="CANCEL_AND_WAIT" ) == { "Name": "test_second_step", "ActionOnFailure": "CANCEL_AND_WAIT", "HadoopJarStep": {"Jar": "command-runner.jar", "Args": cmd}, } @mock_emr def test_emr_log_location_for_cluster(emr_cluster_config, mock_s3_bucket): context = create_test_pipeline_execution_context() emr = EmrJobRunner(region=REGION) cluster_id = emr.run_job_flow(context.log, emr_cluster_config) assert emr.log_location_for_cluster(cluster_id) == (mock_s3_bucket.name, "elasticmapreduce/") # Should raise when the log URI is missing emr_cluster_config = copy.deepcopy(emr_cluster_config) del emr_cluster_config["LogUri"] cluster_id = emr.run_job_flow(context.log, emr_cluster_config) with pytest.raises(EmrError) as exc_info: emr.log_location_for_cluster(cluster_id) assert "Log URI not specified, cannot retrieve step execution logs" in str(exc_info.value) @mock_emr def test_emr_retrieve_logs(emr_cluster_config, mock_s3_bucket): context = create_test_pipeline_execution_context() emr = EmrJobRunner(region=REGION) cluster_id = emr.run_job_flow(context.log, emr_cluster_config) assert emr.log_location_for_cluster(cluster_id) == (mock_s3_bucket.name, "elasticmapreduce/") def create_log(): time.sleep(0.5) out = io.BytesIO() with gzip.GzipFile(fileobj=out, mode="w") as fo: - fo.write("some log".encode()) + fo.write(b"some log") prefix = "elasticmapreduce/{cluster_id}/steps/{step_id}".format( cluster_id=cluster_id, step_id="s-123456123456" ) for name in ["stdout.gz", "stderr.gz"]: mock_s3_bucket.Object(prefix + "/" + name).put( # pylint: disable=no-member Body=out.getvalue() ) thread = threading.Thread(target=create_log, args=()) thread.daemon = True thread.start() stdout_log, stderr_log = emr.retrieve_logs_for_step_id( context.log, cluster_id, "s-123456123456" ) assert stdout_log == "some log" assert stderr_log == "some log" def test_wait_for_log(mock_s3_bucket): def create_log(): time.sleep(0.5) out = io.BytesIO() with gzip.GzipFile(fileobj=out, mode="w") as fo: - fo.write("foo bar".encode()) + fo.write(b"foo bar") mock_s3_bucket.Object("some_log_file").put(Body=out.getvalue()) # pylint: disable=no-member thread = threading.Thread(target=create_log, args=()) thread.daemon = True thread.start() context = create_test_pipeline_execution_context() emr = EmrJobRunner(region=REGION) res = 
emr.wait_for_log( context.log, log_bucket=mock_s3_bucket.name, log_key="some_log_file", waiter_delay=1, waiter_max_attempts=2, ) assert res == "foo bar" with pytest.raises(EmrError) as exc_info: emr.wait_for_log( context.log, log_bucket=mock_s3_bucket.name, log_key="does_not_exist", waiter_delay=1, waiter_max_attempts=1, ) assert "EMR log file did not appear on S3 after waiting" in str(exc_info.value) @mock_emr def test_is_emr_step_complete(emr_cluster_config): context = create_test_pipeline_execution_context() emr = EmrJobRunner(region=REGION, check_cluster_every=1) cluster_id = emr.run_job_flow(context.log, emr_cluster_config) step_name = "test_step" step_cmd = ["ls", "/"] step_ids = emr.add_job_flow_steps( context.log, cluster_id, [emr.construct_step_dict_for_command(step_name, step_cmd)] ) def get_step_dict(step_id, step_state): return { "Step": { "Id": step_id, "Name": step_name, "Config": {"Jar": "command-runner.jar", "Properties": {}, "Args": step_cmd}, "ActionOnFailure": "CONTINUE", "Status": { "State": step_state, "StateChangeReason": {"Message": "everything is hosed"}, "Timeline": {"StartDateTime": _boto3_now()}, }, }, } emr_step_id = step_ids[0] describe_step_returns = [ get_step_dict(emr_step_id, "PENDING"), get_step_dict(emr_step_id, "RUNNING"), get_step_dict(emr_step_id, "COMPLETED"), get_step_dict(emr_step_id, "FAILED"), ] with mock.patch.object(EmrJobRunner, "describe_step", side_effect=describe_step_returns): assert not emr.is_emr_step_complete(context.log, cluster_id, emr_step_id) assert not emr.is_emr_step_complete(context.log, cluster_id, emr_step_id) assert emr.is_emr_step_complete(context.log, cluster_id, emr_step_id) with pytest.raises(EmrError) as exc_info: emr.is_emr_step_complete(context.log, cluster_id, emr_step_id) assert "step failed" in str(exc_info.value) diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py index caa9325ef..354803f4b 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py @@ -1,114 +1,113 @@ import os import sys import tempfile -import six from dagster import DagsterEventType, execute_pipeline, pipeline, solid from dagster.core.instance import DagsterInstance, InstanceType from dagster.core.launcher import DefaultRunLauncher from dagster.core.run_coordinator import DefaultRunCoordinator from dagster.core.storage.compute_log_manager import ComputeIOType from dagster.core.storage.event_log import SqliteEventLogStorage from dagster.core.storage.root import LocalArtifactStorage from dagster.core.storage.runs import SqliteRunStorage from dagster_aws.s3 import S3ComputeLogManager HELLO_WORLD = "Hello World" SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n" EXPECTED_LOGS = [ 'STEP_START - Started execution of step "easy".', 'STEP_OUTPUT - Yielded output "result" of type "Any"', 'STEP_SUCCESS - Finished execution of step "easy"', ] def test_compute_log_manager(mock_s3_bucket): @pipeline def simple(): @solid def easy(context): context.log.info("easy") print(HELLO_WORLD) # pylint: disable=print-call return "easy" easy() with tempfile.TemporaryDirectory() as temp_dir: run_store = SqliteRunStorage.from_local(temp_dir) event_store = SqliteEventLogStorage(temp_dir) manager = S3ComputeLogManager( bucket=mock_s3_bucket.name, 
prefix="my_prefix", local_dir=temp_dir ) instance = DagsterInstance( instance_type=InstanceType.PERSISTENT, local_artifact_storage=LocalArtifactStorage(temp_dir), run_storage=run_store, event_storage=event_store, compute_log_manager=manager, run_coordinator=DefaultRunCoordinator(), run_launcher=DefaultRunLauncher(), ) result = execute_pipeline(simple, instance=instance) compute_steps = [ event.step_key for event in result.step_event_list if event.event_type == DagsterEventType.STEP_START ] assert len(compute_steps) == 1 step_key = compute_steps[0] stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data # Check S3 directly s3_object = mock_s3_bucket.Object( key="{prefix}/storage/{run_id}/compute_logs/easy.err".format( prefix="my_prefix", run_id=result.run_id ), ) - stderr_s3 = six.ensure_str(s3_object.get()["Body"].read()) + stderr_s3 = s3_object.get()["Body"].read().decode("utf-8") for expected in EXPECTED_LOGS: assert expected in stderr_s3 # Check download behavior by deleting locally cached logs compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs") for filename in os.listdir(compute_logs_dir): os.unlink(os.path.join(compute_logs_dir, filename)) stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data def test_compute_log_manager_from_config(mock_s3_bucket): s3_prefix = "foobar" dagster_yaml = """ compute_logs: module: dagster_aws.s3.compute_log_manager class: S3ComputeLogManager config: bucket: "{s3_bucket}" local_dir: "/tmp/cool" prefix: "{s3_prefix}" """.format( s3_bucket=mock_s3_bucket.name, s3_prefix=s3_prefix ) with tempfile.TemporaryDirectory() as tempdir: with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f: - f.write(six.ensure_binary(dagster_yaml)) + f.write(dagster_yaml.encode("utf-8")) instance = DagsterInstance.from_config(tempdir) assert ( instance.compute_log_manager._s3_bucket # pylint: disable=protected-access == mock_s3_bucket.name ) assert instance.compute_log_manager._s3_prefix == s3_prefix # pylint: disable=protected-access diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_file_handle_to_s3.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_file_handle_to_s3.py index fbeb8864e..b550ba7a2 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_file_handle_to_s3.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_file_handle_to_s3.py @@ -1,47 +1,47 @@ from dagster import ModeDefinition, execute_pipeline, pipeline, solid from dagster_aws.s3 import S3FileHandle, file_handle_to_s3, s3_file_manager, s3_resource def create_file_handle_pipeline(temp_file_handle): @solid def emit_temp_handle(_): return temp_file_handle @pipeline( mode_defs=[ ModeDefinition(resource_defs={"s3": s3_resource, "file_manager": s3_file_manager}) ] ) def test(): return file_handle_to_s3(emit_temp_handle()) return test def test_successful_file_handle_to_s3(mock_s3_bucket): - foo_bytes = "foo".encode() + foo_bytes = b"foo" remote_s3_object = mock_s3_bucket.Object("some-key/foo") remote_s3_object.put(Body=foo_bytes) file_handle = 
S3FileHandle(mock_s3_bucket.name, "some-key/foo") result = execute_pipeline( create_file_handle_pipeline(file_handle), run_config={ "solids": { "file_handle_to_s3": {"config": {"Bucket": mock_s3_bucket.name, "Key": "some-key"}} }, "resources": {"file_manager": {"config": {"s3_bucket": mock_s3_bucket.name}}}, }, ) assert result.success assert mock_s3_bucket.Object(key="some-key").get()["Body"].read() == foo_bytes materializations = result.result_for_solid("file_handle_to_s3").materializations_during_compute assert len(materializations) == 1 assert len(materializations[0].metadata_entries) == 1 assert materializations[0].metadata_entries[ 0 ].entry_data.path == "s3://{bucket}/some-key".format(bucket=mock_s3_bucket.name) assert materializations[0].metadata_entries[0].label == "some-key" diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_cache.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_cache.py index 82a2f8458..3f2a877c2 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_cache.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_cache.py @@ -1,52 +1,52 @@ import io from dagster_aws.s3 import S3FileCache, S3FileHandle def test_s3_file_cache_file_not_present(mock_s3_resource, mock_s3_bucket): file_store = S3FileCache( s3_bucket=mock_s3_bucket.name, s3_key="some-key", s3_session=mock_s3_resource.meta.client, overwrite=False, ) assert not file_store.has_file_object("foo") def test_s3_file_cache_file_present(mock_s3_resource, mock_s3_bucket): file_store = S3FileCache( s3_bucket=mock_s3_bucket.name, s3_key="some-key", s3_session=mock_s3_resource.meta.client, overwrite=False, ) assert not file_store.has_file_object("foo") - file_store.write_binary_data("foo", "bar".encode()) + file_store.write_binary_data("foo", b"bar") assert file_store.has_file_object("foo") def test_s3_file_cache_correct_handle(mock_s3_resource, mock_s3_bucket): file_store = S3FileCache( s3_bucket=mock_s3_bucket.name, s3_key="some-key", s3_session=mock_s3_resource.meta.client, overwrite=False, ) assert isinstance(file_store.get_file_handle("foo"), S3FileHandle) def test_s3_file_cache_write_file_object(mock_s3_resource, mock_s3_bucket): file_store = S3FileCache( s3_bucket=mock_s3_bucket.name, s3_key="some-key", s3_session=mock_s3_resource.meta.client, overwrite=False, ) - stream = io.BytesIO("content".encode()) + stream = io.BytesIO(b"content") file_store.write_file_object("foo", stream) diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_manager.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_manager.py index bbb2c808a..5cab752a7 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_manager.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_s3_file_manager.py @@ -1,209 +1,209 @@ import uuid from dagster import ( InputDefinition, Int, ModeDefinition, OutputDefinition, configured, execute_pipeline, pipeline, solid, ) from dagster.seven import mock from dagster_aws.s3 import ( S3FileHandle, S3FileManager, s3_file_manager, s3_plus_default_intermediate_storage_defs, s3_resource, ) def build_key(run_id, step_key, output_name): return "dagster/storage/{run_id}/intermediates/{step_key}/{output_name}".format( run_id=run_id, step_key=step_key, output_name=output_name ) def test_s3_file_manager_write(mock_s3_resource, mock_s3_bucket): file_manager = 
S3FileManager(mock_s3_resource.meta.client, mock_s3_bucket.name, "some-key") - body = "foo".encode() + body = b"foo" file_handle = file_manager.write_data(body) assert mock_s3_bucket.Object(file_handle.s3_key).get()["Body"].read() == body file_handle = file_manager.write_data(body, ext="foo") assert file_handle.s3_key.endswith(".foo") assert mock_s3_bucket.Object(file_handle.s3_key).get()["Body"].read() == body def test_s3_file_manager_read(mock_s3_resource, mock_s3_bucket): - body = "bar".encode() + body = b"bar" remote_s3_object = mock_s3_bucket.Object("some-key/foo") remote_s3_object.put(Body=body) file_manager = S3FileManager(mock_s3_resource.meta.client, mock_s3_bucket.name, "some-key") file_handle = S3FileHandle(mock_s3_bucket.name, "some-key/foo") with file_manager.read(file_handle) as file_obj: assert file_obj.read() == body # read again. cached remote_s3_object.delete() with file_manager.read(file_handle) as file_obj: assert file_obj.read() == body def test_depends_on_s3_resource_intermediates(mock_s3_bucket): @solid( input_defs=[InputDefinition("num_one", Int), InputDefinition("num_two", Int)], output_defs=[OutputDefinition(Int)], ) def add_numbers(_, num_one, num_two): return num_one + num_two @pipeline( mode_defs=[ ModeDefinition( intermediate_storage_defs=s3_plus_default_intermediate_storage_defs, resource_defs={"s3": s3_resource}, ) ] ) def s3_internal_pipeline(): return add_numbers() result = execute_pipeline( s3_internal_pipeline, run_config={ "solids": { "add_numbers": {"inputs": {"num_one": {"value": 2}, "num_two": {"value": 4}}} }, "intermediate_storage": {"s3": {"config": {"s3_bucket": mock_s3_bucket.name}}}, }, ) keys_in_bucket = [obj.key for obj in mock_s3_bucket.objects.all()] assert result.success assert result.result_for_solid("add_numbers").output_value() == 6 keys = set() for step_key, output_name in [("add_numbers", "result")]: keys.add(build_key(result.run_id, step_key, output_name)) assert set(keys_in_bucket) == keys def test_depends_on_s3_resource_file_manager(mock_s3_bucket): - bar_bytes = "bar".encode() + bar_bytes = b"bar" @solid(output_defs=[OutputDefinition(S3FileHandle)], required_resource_keys={"file_manager"}) def emit_file(context): return context.resources.file_manager.write_data(bar_bytes) @solid( input_defs=[InputDefinition("file_handle", S3FileHandle)], required_resource_keys={"file_manager"}, ) def accept_file(context, file_handle): local_path = context.resources.file_manager.copy_handle_to_local_temp(file_handle) assert isinstance(local_path, str) assert open(local_path, "rb").read() == bar_bytes @pipeline( mode_defs=[ ModeDefinition( intermediate_storage_defs=s3_plus_default_intermediate_storage_defs, resource_defs={"s3": s3_resource, "file_manager": s3_file_manager}, ) ] ) def s3_file_manager_test(): accept_file(emit_file()) result = execute_pipeline( s3_file_manager_test, run_config={ "resources": { "file_manager": { "config": {"s3_bucket": mock_s3_bucket.name, "s3_prefix": "some-prefix"} } }, "intermediate_storage": {"s3": {"config": {"s3_bucket": mock_s3_bucket.name}}}, }, ) assert result.success keys_in_bucket = [obj.key for obj in mock_s3_bucket.objects.all()] for step_key, output_name in [ ("emit_file", "result"), ("accept_file", "result"), ]: keys_in_bucket.remove(build_key(result.run_id, step_key, output_name)) assert len(keys_in_bucket) == 1 file_key = list(keys_in_bucket)[0] comps = file_key.split("/") assert "/".join(comps[:-1]) == "some-prefix" assert uuid.UUID(comps[-1]) @mock.patch("boto3.resource") 
@mock.patch("dagster_aws.s3.resources.S3FileManager") def test_s3_file_manager_resource(MockS3FileManager, mock_boto3_resource): did_it_run = dict(it_ran=False) resource_config = { "use_unsigned_session": True, "region_name": "us-west-1", "endpoint_url": "http://alternate-s3-host.io", "s3_bucket": "some-bucket", "s3_prefix": "some-prefix", } mock_s3_session = mock_boto3_resource.return_value.meta.client @solid(required_resource_keys={"file_manager"}) def test_solid(context): # test that we got back a S3FileManager assert context.resources.file_manager == MockS3FileManager.return_value # make sure the file manager was initalized with the config we are supplying MockS3FileManager.assert_called_once_with( s3_session=mock_s3_session, s3_bucket=resource_config["s3_bucket"], s3_base_key=resource_config["s3_prefix"], ) _, call_kwargs = mock_boto3_resource.call_args mock_boto3_resource.assert_called_once_with( "s3", region_name=resource_config["region_name"], endpoint_url=resource_config["endpoint_url"], use_ssl=True, config=call_kwargs["config"], ) assert call_kwargs["config"].retries["max_attempts"] == 5 did_it_run["it_ran"] = True @pipeline( mode_defs=[ ModeDefinition( resource_defs={"file_manager": configured(s3_file_manager)(resource_config)}, ) ] ) def test_pipeline(): test_solid() execute_pipeline(test_pipeline) assert did_it_run["it_ran"] diff --git a/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_cache.py b/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_cache.py index 1a254b8fb..ce809de68 100644 --- a/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_cache.py +++ b/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_cache.py @@ -1,60 +1,60 @@ import io from dagster_azure.adls2 import ADLS2FileCache, ADLS2FileHandle, FakeADLS2ServiceClient def test_adls2_file_cache_file_not_present(storage_account, file_system, credential): fake_client = FakeADLS2ServiceClient(storage_account, credential) file_store = ADLS2FileCache( storage_account=storage_account, file_system=file_system, prefix="some-prefix", client=fake_client, overwrite=False, ) assert not file_store.has_file_object("foo") def test_adls2_file_cache_file_present(storage_account, file_system, credential): fake_client = FakeADLS2ServiceClient(storage_account, credential) file_store = ADLS2FileCache( storage_account=storage_account, file_system=file_system, prefix="some-prefix", client=fake_client, overwrite=False, ) assert not file_store.has_file_object("foo") - file_store.write_binary_data("foo", "bar".encode()) + file_store.write_binary_data("foo", b"bar") assert file_store.has_file_object("foo") def test_adls2_file_cache_correct_handle(storage_account, file_system, credential): fake_client = FakeADLS2ServiceClient(storage_account, credential) file_store = ADLS2FileCache( storage_account=storage_account, file_system=file_system, prefix="some-prefix", client=fake_client, overwrite=False, ) assert isinstance(file_store.get_file_handle("foo"), ADLS2FileHandle) def test_adls2_file_cache_write_file_object(storage_account, file_system, credential): fake_client = FakeADLS2ServiceClient(storage_account, credential) file_store = ADLS2FileCache( storage_account=storage_account, file_system=file_system, prefix="some-prefix", client=fake_client, overwrite=False, ) - stream = io.BytesIO("content".encode()) + stream = io.BytesIO(b"content") file_store.write_file_object("foo", stream) diff --git 
a/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_manager.py b/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_manager.py index e53242855..6778c5421 100644 --- a/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_manager.py +++ b/python_modules/libraries/dagster-azure/dagster_azure_tests/adls2_tests/test_adls2_file_manager.py @@ -1,247 +1,247 @@ import uuid from dagster import ( InputDefinition, Int, ModeDefinition, OutputDefinition, ResourceDefinition, configured, execute_pipeline, pipeline, solid, ) from dagster.seven import mock from dagster_azure.adls2 import ( ADLS2FileHandle, ADLS2FileManager, FakeADLS2Resource, adls2_file_manager, adls2_plus_default_intermediate_storage_defs, ) # For deps def test_adls2_file_manager_write(storage_account, file_system): file_mock = mock.MagicMock() adls2_mock = mock.MagicMock() adls2_mock.get_file_client.return_value = file_mock adls2_mock.account_name = storage_account file_manager = ADLS2FileManager(adls2_mock, file_system, "some-key") - foo_bytes = "foo".encode() + foo_bytes = b"foo" file_handle = file_manager.write_data(foo_bytes) assert isinstance(file_handle, ADLS2FileHandle) assert file_handle.account == storage_account assert file_handle.file_system == file_system assert file_handle.key.startswith("some-key/") assert file_mock.upload_data.call_count == 1 file_handle = file_manager.write_data(foo_bytes, ext="foo") assert isinstance(file_handle, ADLS2FileHandle) assert file_handle.account == storage_account assert file_handle.file_system == file_system assert file_handle.key.startswith("some-key/") assert file_handle.key[-4:] == ".foo" assert file_mock.upload_data.call_count == 2 def test_adls2_file_manager_read(storage_account, file_system): state = {"called": 0} - bar_bytes = "bar".encode() + bar_bytes = b"bar" class DownloadMock(mock.MagicMock): def readinto(self, fileobj): fileobj.write(bar_bytes) class FileMock(mock.MagicMock): def download_file(self): state["called"] += 1 assert state["called"] == 1 return DownloadMock(file=self) class ADLS2Mock(mock.MagicMock): def get_file_client(self, *_args, **kwargs): state["file_system"] = kwargs["file_system"] file_path = kwargs["file_path"] state["file_path"] = kwargs["file_path"] return FileMock(file_path=file_path) adls2_mock = ADLS2Mock() file_manager = ADLS2FileManager(adls2_mock, file_system, "some-key") file_handle = ADLS2FileHandle(storage_account, file_system, "some-key/kdjfkjdkfjkd") with file_manager.read(file_handle) as file_obj: assert file_obj.read() == bar_bytes assert state["file_system"] == file_handle.file_system assert state["file_path"] == file_handle.key # read again. 
cached with file_manager.read(file_handle) as file_obj: assert file_obj.read() == bar_bytes file_manager.delete_local_temp() def test_depends_on_adls2_resource_intermediates(storage_account, file_system): @solid( input_defs=[InputDefinition("num_one", Int), InputDefinition("num_two", Int)], output_defs=[OutputDefinition(Int)], ) def add_numbers(_, num_one, num_two): return num_one + num_two adls2_fake_resource = FakeADLS2Resource(storage_account) @pipeline( mode_defs=[ ModeDefinition( intermediate_storage_defs=adls2_plus_default_intermediate_storage_defs, resource_defs={"adls2": ResourceDefinition.hardcoded_resource(adls2_fake_resource)}, ) ] ) def adls2_internal_pipeline(): return add_numbers() result = execute_pipeline( adls2_internal_pipeline, run_config={ "solids": { "add_numbers": {"inputs": {"num_one": {"value": 2}, "num_two": {"value": 4}}} }, "intermediate_storage": {"adls2": {"config": {"adls2_file_system": file_system}}}, }, ) assert result.success assert result.result_for_solid("add_numbers").output_value() == 6 assert file_system in adls2_fake_resource.adls2_client.file_systems keys = set() for step_key, output_name in [("add_numbers", "result")]: keys.add(create_adls2_key(result.run_id, step_key, output_name)) assert set(adls2_fake_resource.adls2_client.file_systems[file_system].keys()) == keys def create_adls2_key(run_id, step_key, output_name): return "dagster/storage/{run_id}/intermediates/{step_key}/{output_name}".format( run_id=run_id, step_key=step_key, output_name=output_name ) def test_depends_on_adls2_resource_file_manager(storage_account, file_system): - bar_bytes = "bar".encode() + bar_bytes = b"bar" @solid(output_defs=[OutputDefinition(ADLS2FileHandle)], required_resource_keys={"file_manager"}) def emit_file(context): return context.resources.file_manager.write_data(bar_bytes) @solid( input_defs=[InputDefinition("file_handle", ADLS2FileHandle)], required_resource_keys={"file_manager"}, ) def accept_file(context, file_handle): local_path = context.resources.file_manager.copy_handle_to_local_temp(file_handle) assert isinstance(local_path, str) assert open(local_path, "rb").read() == bar_bytes adls2_fake_resource = FakeADLS2Resource(storage_account) adls2_fake_file_manager = ADLS2FileManager( adls2_client=adls2_fake_resource.adls2_client, file_system=file_system, prefix="some-prefix", ) @pipeline( mode_defs=[ ModeDefinition( intermediate_storage_defs=adls2_plus_default_intermediate_storage_defs, resource_defs={ "adls2": ResourceDefinition.hardcoded_resource(adls2_fake_resource), "file_manager": ResourceDefinition.hardcoded_resource(adls2_fake_file_manager), }, ) ] ) def adls2_file_manager_test(): accept_file(emit_file()) result = execute_pipeline( adls2_file_manager_test, run_config={ "intermediate_storage": {"adls2": {"config": {"adls2_file_system": file_system}}} }, ) assert result.success keys_in_bucket = set(adls2_fake_resource.adls2_client.file_systems[file_system].keys()) for step_key, output_name in [ ("emit_file", "result"), ("accept_file", "result"), ]: keys_in_bucket.remove(create_adls2_key(result.run_id, step_key, output_name)) assert len(keys_in_bucket) == 1 file_key = list(keys_in_bucket)[0] comps = file_key.split("/") assert "/".join(comps[:-1]) == "some-prefix" assert uuid.UUID(comps[-1]) @mock.patch("dagster_azure.adls2.resources.ADLS2Resource") @mock.patch("dagster_azure.adls2.resources.ADLS2FileManager") def test_adls_file_manager_resource(MockADLS2FileManager, MockADLS2Resource): did_it_run = dict(it_ran=False) resource_config = { "storage_account": 
"some-storage-account", "credential": {"key": "some-key",}, "adls2_file_system": "some-file-system", "adls2_prefix": "some-prefix", } @solid(required_resource_keys={"file_manager"}) def test_solid(context): # test that we got back a ADLS2FileManager assert context.resources.file_manager == MockADLS2FileManager.return_value # make sure the file manager was initalized with the config we are supplying MockADLS2FileManager.assert_called_once_with( adls2_client=MockADLS2Resource.return_value.adls2_client, file_system=resource_config["adls2_file_system"], prefix=resource_config["adls2_prefix"], ) MockADLS2Resource.assert_called_once_with( resource_config["storage_account"], resource_config["credential"]["key"] ) did_it_run["it_ran"] = True @pipeline( mode_defs=[ ModeDefinition( resource_defs={"file_manager": configured(adls2_file_manager)(resource_config)}, ) ] ) def test_pipeline(): test_solid() execute_pipeline(test_pipeline) assert did_it_run["it_ran"] diff --git a/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py index 0a8f79832..e9fa980cf 100644 --- a/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py +++ b/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py @@ -1,131 +1,130 @@ import os import sys import tempfile -import six from dagster import DagsterEventType, execute_pipeline, pipeline, solid from dagster.core.instance import DagsterInstance, InstanceType from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher from dagster.core.run_coordinator import DefaultRunCoordinator from dagster.core.storage.compute_log_manager import ComputeIOType from dagster.core.storage.event_log import SqliteEventLogStorage from dagster.core.storage.root import LocalArtifactStorage from dagster.core.storage.runs import SqliteRunStorage from dagster.seven import mock from dagster_azure.blob import AzureBlobComputeLogManager, FakeBlobServiceClient HELLO_WORLD = "Hello World" SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n" EXPECTED_LOGS = [ 'STEP_START - Started execution of step "easy".', 'STEP_OUTPUT - Yielded output "result" of type "Any"', 'STEP_SUCCESS - Finished execution of step "easy"', ] @mock.patch("dagster_azure.blob.compute_log_manager.generate_blob_sas") @mock.patch("dagster_azure.blob.compute_log_manager.create_blob_client") def test_compute_log_manager( mock_create_blob_client, mock_generate_blob_sas, storage_account, container, credential ): mock_generate_blob_sas.return_value = "fake-url" fake_client = FakeBlobServiceClient(storage_account) mock_create_blob_client.return_value = fake_client @pipeline def simple(): @solid def easy(context): context.log.info("easy") print(HELLO_WORLD) # pylint: disable=print-call return "easy" easy() with tempfile.TemporaryDirectory() as temp_dir: run_store = SqliteRunStorage.from_local(temp_dir) event_store = SqliteEventLogStorage(temp_dir) manager = AzureBlobComputeLogManager( storage_account=storage_account, container=container, prefix="my_prefix", local_dir=temp_dir, secret_key=credential, ) instance = DagsterInstance( instance_type=InstanceType.PERSISTENT, local_artifact_storage=LocalArtifactStorage(temp_dir), run_storage=run_store, event_storage=event_store, compute_log_manager=manager, run_coordinator=DefaultRunCoordinator(), 
run_launcher=SyncInMemoryRunLauncher(), ) result = execute_pipeline(simple, instance=instance) compute_steps = [ event.step_key for event in result.step_event_list if event.event_type == DagsterEventType.STEP_START ] assert len(compute_steps) == 1 step_key = compute_steps[0] stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data # Check ADLS2 directly adls2_object = fake_client.get_blob_client( container=container, blob="{prefix}/storage/{run_id}/compute_logs/easy.err".format( prefix="my_prefix", run_id=result.run_id ), ) - adls2_stderr = six.ensure_str(adls2_object.download_blob().readall()) + adls2_stderr = adls2_object.download_blob().readall().decode("utf-8") for expected in EXPECTED_LOGS: assert expected in adls2_stderr # Check download behavior by deleting locally cached logs compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs") for filename in os.listdir(compute_logs_dir): os.unlink(os.path.join(compute_logs_dir, filename)) stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data def test_compute_log_manager_from_config(storage_account, container, credential): prefix = "foobar" dagster_yaml = """ compute_logs: module: dagster_azure.blob.compute_log_manager class: AzureBlobComputeLogManager config: storage_account: "{storage_account}" container: {container} secret_key: {credential} local_dir: "/tmp/cool" prefix: "{prefix}" """.format( storage_account=storage_account, container=container, credential=credential, prefix=prefix ) with tempfile.TemporaryDirectory() as tempdir: with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f: - f.write(six.ensure_binary(dagster_yaml)) + f.write(dagster_yaml.encode("utf-8")) instance = DagsterInstance.from_config(tempdir) assert ( instance.compute_log_manager._storage_account # pylint: disable=protected-access == storage_account ) assert instance.compute_log_manager._container == container # pylint: disable=protected-access assert instance.compute_log_manager._blob_prefix == prefix # pylint: disable=protected-access diff --git a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_config.py b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_config.py index 8d874113d..126fe0a3f 100644 --- a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_config.py +++ b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_config.py @@ -1,61 +1,61 @@ import os from dagster.core.test_utils import environ from dagster.seven import tempfile from dagster_celery.cli import get_config_dir CONFIG_YAML = """ execution: celery: broker: "pyampqp://foo@bar:1234//" config_source: foo: "bar" """ ENV_CONFIG_YAML = """ execution: celery: broker: env: BROKER_URL config_source: foo: "bar" """ CONFIG_PY = """broker_url = \'pyampqp://foo@bar:1234//\' result_backend = \'rpc://\' foo = \'bar\' """ CONFIG_PYTHON_FILE = "{config_module_name}.py".format(config_module_name="dagster_celery_config") def test_config_value_from_yaml(): with tempfile.NamedTemporaryFile() as tmp: tmp.write(CONFIG_YAML.encode("utf-8")) tmp.seek(0) python_path = get_config_dir(config_yaml=tmp.name) with 
open(os.path.join(python_path, CONFIG_PYTHON_FILE), "r") as fd: assert str(fd.read()) == CONFIG_PY def test_config_value_from_empty_yaml(): with tempfile.NamedTemporaryFile() as tmp: - tmp.write("".encode("utf-8")) + tmp.write(b"") tmp.seek(0) python_path = get_config_dir(config_yaml=tmp.name) with open(os.path.join(python_path, CONFIG_PYTHON_FILE), "r") as fd: assert str(fd.read()) == "result_backend = 'rpc://'\n" def test_config_value_from_env_yaml(): with environ({"BROKER_URL": "pyampqp://foo@bar:1234//"}): with tempfile.NamedTemporaryFile() as tmp: tmp.write(CONFIG_YAML.encode("utf-8")) tmp.seek(0) python_path = get_config_dir(config_yaml=tmp.name) with open(os.path.join(python_path, CONFIG_PYTHON_FILE), "r") as fd: assert str(fd.read()) == CONFIG_PY diff --git a/python_modules/libraries/dagster-cron/dagster_cron/cron_scheduler.py b/python_modules/libraries/dagster-cron/dagster_cron/cron_scheduler.py index ddc37e86b..43023fba0 100644 --- a/python_modules/libraries/dagster-cron/dagster_cron/cron_scheduler.py +++ b/python_modules/libraries/dagster-cron/dagster_cron/cron_scheduler.py @@ -1,235 +1,234 @@ import io import os import shutil import stat import sys -import six from crontab import CronTab from dagster import DagsterInstance, check, utils from dagster.core.host_representation import ExternalSchedule from dagster.core.scheduler import DagsterSchedulerError, Scheduler from dagster.serdes import ConfigurableClass class SystemCronScheduler(Scheduler, ConfigurableClass): """Scheduler implementation that uses the local systems cron. Only works on unix systems that have cron. Enable this scheduler by adding it to your ``dagster.yaml`` in ``$DAGSTER_HOME``. """ def __init__( self, inst_data=None, ): self._inst_data = inst_data @property def inst_data(self): return self._inst_data @classmethod def config_type(cls): return {} @staticmethod def from_config_value(inst_data, config_value): return SystemCronScheduler(inst_data=inst_data) def get_cron_tab(self): return CronTab(user=True) def debug_info(self): return "Running Cron Jobs:\n{jobs}\n".format( jobs="\n".join( [str(job) for job in self.get_cron_tab() if "dagster-schedule:" in job.comment] ) ) def start_schedule(self, instance, external_schedule): check.inst_param(instance, "instance", DagsterInstance) check.inst_param(external_schedule, "external_schedule", ExternalSchedule) schedule_origin_id = external_schedule.get_external_origin_id() # If the cron job already exists, remove it. This prevents duplicate entries. # Then, add a new cron job to the cron tab. if self.running_schedule_count(instance, external_schedule.get_external_origin_id()) > 0: self._end_cron_job(instance, schedule_origin_id) self._start_cron_job(instance, external_schedule) # Verify that the cron job is running running_schedule_count = self.running_schedule_count(instance, schedule_origin_id) if running_schedule_count == 0: raise DagsterSchedulerError( "Attempted to write cron job for schedule " "{schedule_name}, but failed. " "The scheduler is not running {schedule_name}.".format( schedule_name=external_schedule.name ) ) elif running_schedule_count > 1: raise DagsterSchedulerError( "Attempted to write cron job for schedule " "{schedule_name}, but duplicate cron jobs were found. " "There are {running_schedule_count} jobs running for the schedule." 
"To resolve, run `dagster schedule up`, or edit the cron tab to " "remove duplicate schedules".format( schedule_name=external_schedule.name, running_schedule_count=running_schedule_count, ) ) def stop_schedule(self, instance, schedule_origin_id): check.inst_param(instance, "instance", DagsterInstance) check.str_param(schedule_origin_id, "schedule_origin_id") schedule = self._get_schedule_state(instance, schedule_origin_id) self._end_cron_job(instance, schedule_origin_id) # Verify that the cron job has been removed running_schedule_count = self.running_schedule_count(instance, schedule_origin_id) if running_schedule_count > 0: raise DagsterSchedulerError( "Attempted to remove existing cron job for schedule " "{schedule_name}, but failed. " "There are still {running_schedule_count} jobs running for the schedule.".format( schedule_name=schedule.name, running_schedule_count=running_schedule_count ) ) def wipe(self, instance): # Note: This method deletes schedules from ALL repositories check.inst_param(instance, "instance", DagsterInstance) # Delete all script files script_directory = os.path.join(instance.schedules_directory(), "scripts") if os.path.isdir(script_directory): shutil.rmtree(script_directory) # Delete all logs logs_directory = os.path.join(instance.schedules_directory(), "logs") if os.path.isdir(logs_directory): shutil.rmtree(logs_directory) # Remove all cron jobs with self.get_cron_tab() as cron_tab: for job in cron_tab: if "dagster-schedule:" in job.comment: cron_tab.remove_all(comment=job.comment) def _get_bash_script_file_path(self, instance, schedule_origin_id): check.inst_param(instance, "instance", DagsterInstance) check.str_param(schedule_origin_id, "schedule_origin_id") script_directory = os.path.join(instance.schedules_directory(), "scripts") utils.mkdir_p(script_directory) script_file_name = "{}.sh".format(schedule_origin_id) return os.path.join(script_directory, script_file_name) def _cron_tag_for_schedule(self, schedule_origin_id): return "dagster-schedule: {schedule_origin_id}".format( schedule_origin_id=schedule_origin_id ) def _get_command(self, script_file, instance, schedule_origin_id): schedule_log_file_path = self.get_logs_path(instance, schedule_origin_id) command = "{script_file} > {schedule_log_file_path} 2>&1".format( script_file=script_file, schedule_log_file_path=schedule_log_file_path ) return command def _start_cron_job(self, instance, external_schedule): schedule_origin_id = external_schedule.get_external_origin_id() script_file = self._write_bash_script_to_file(instance, external_schedule) command = self._get_command(script_file, instance, schedule_origin_id) with self.get_cron_tab() as cron_tab: job = cron_tab.new( command=command, comment="dagster-schedule: {schedule_origin_id}".format( schedule_origin_id=schedule_origin_id ), ) job.setall(external_schedule.cron_schedule) def _end_cron_job(self, instance, schedule_origin_id): with self.get_cron_tab() as cron_tab: cron_tab.remove_all(comment=self._cron_tag_for_schedule(schedule_origin_id)) script_file = self._get_bash_script_file_path(instance, schedule_origin_id) if os.path.isfile(script_file): os.remove(script_file) def running_schedule_count(self, instance, schedule_origin_id): matching_jobs = self.get_cron_tab().find_comment( self._cron_tag_for_schedule(schedule_origin_id) ) return len(list(matching_jobs)) def _get_or_create_logs_directory(self, instance, schedule_origin_id): check.inst_param(instance, "instance", DagsterInstance) check.str_param(schedule_origin_id, "schedule_origin_id") 
logs_directory = os.path.join(instance.schedules_directory(), "logs", schedule_origin_id) if not os.path.isdir(logs_directory): utils.mkdir_p(logs_directory) return logs_directory def get_logs_path(self, instance, schedule_origin_id): check.inst_param(instance, "instance", DagsterInstance) check.str_param(schedule_origin_id, "schedule_origin_id") logs_directory = self._get_or_create_logs_directory(instance, schedule_origin_id) return os.path.join(logs_directory, "scheduler.log") def _write_bash_script_to_file(self, instance, external_schedule): # Get path to store bash script schedule_origin_id = external_schedule.get_external_origin_id() script_file = self._get_bash_script_file_path(instance, schedule_origin_id) # Get path to store schedule attempt logs logs_directory = self._get_or_create_logs_directory(instance, schedule_origin_id) schedule_log_file_name = "{}_{}.result".format("${RUN_DATE}", schedule_origin_id) schedule_log_file_path = os.path.join(logs_directory, schedule_log_file_name) local_target = external_schedule.get_external_origin() # Environment information needed for execution dagster_home = os.getenv("DAGSTER_HOME") script_contents = """ #!/bin/bash export DAGSTER_HOME={dagster_home} export LANG=en_US.UTF-8 {env_vars} export RUN_DATE=$(date "+%Y%m%dT%H%M%S") {python_exe} -m dagster api launch_scheduled_execution --schedule_name {schedule_name} {repo_cli_args} "{result_file}" """.format( python_exe=sys.executable, schedule_name=external_schedule.name, repo_cli_args=local_target.get_repo_cli_args(), result_file=schedule_log_file_path, dagster_home=dagster_home, env_vars="\n".join( [ "export {key}={value}".format(key=key, value=value) for key, value in external_schedule.environment_vars.items() ] ), ) with io.open(script_file, "w", encoding="utf-8") as f: - f.write(six.text_type(script_contents)) + f.write(script_contents) st = os.stat(script_file) os.chmod(script_file, st.st_mode | stat.S_IEXEC) return script_file diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt/cli/utils.py b/python_modules/libraries/dagster-dbt/dagster_dbt/cli/utils.py index 8e526c9bc..2ec19a825 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt/cli/utils.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt/cli/utils.py @@ -1,135 +1,135 @@ import json import os import re import subprocess from typing import Any, Dict, List, Tuple from dagster import check from ..errors import ( DagsterDbtCliFatalRuntimeError, DagsterDbtCliHandledRuntimeError, DagsterDbtCliOutputsNotFoundError, ) def execute_cli( executable: str, command: Tuple[str, ...], flags_dict: Dict[str, Any], log: Any, warn_error: bool, ignore_handled_error: bool, ) -> Dict[str, Any]: """Executes a command on the dbt CLI in a subprocess.""" check.str_param(executable, "executable") check.tuple_param(command, "command", of_type=str) check.dict_param(flags_dict, "flags_dict", key_type=str) check.bool_param(warn_error, "warn_error") check.bool_param(ignore_handled_error, "ignore_handled_error") # Format the dbt CLI flags in the command.. warn_error = ["--warn-error"] if warn_error else [] command_list = [executable, "--log-format", "json", *warn_error, *command] for flag, value in flags_dict.items(): if not value: continue command_list.append(f"--{flag}") if isinstance(value, bool): # If a bool flag (and is True), the presence of the flag itself is enough. 
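        # (Editor's note, illustrative values only:) a flags_dict such as
        #     {"project-dir": "my_project", "models": ["my_model+"], "full-refresh": True}
        # combined with command=("run",) builds a command_list roughly like
        #     [executable, "--log-format", "json", "run",
        #      "--project-dir", "my_project", "--models", "my_model+", "--full-refresh"]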
continue if isinstance(value, list): check.list_param(value, f"config.{flag}", of_type=str) command_list += value continue if isinstance(value, dict): command_list.append(json.dumps(value)) continue command_list.append(str(value)) # Execute the dbt CLI command in a subprocess. command = " ".join(command_list) log.info(f"Executing command: {command}") return_code = 0 process = subprocess.Popen(command_list, stdout=subprocess.PIPE) logs = [] output = [] for raw_line in process.stdout: - line = raw_line.decode() + line = raw_line.decode("utf-8") output.append(line) try: json_line = json.loads(line) except json.JSONDecodeError: log.info(line.rstrip()) else: logs.append(json_line) level = json_line.get("levelname", "").lower() if hasattr(log, level): getattr(log, level)(json_line.get("message", "")) else: log.info(line.rstrip()) process.wait() return_code = process.returncode log.info("dbt exited with return code {return_code}".format(return_code=return_code)) raw_output = "\n".join(output) if return_code == 2: raise DagsterDbtCliFatalRuntimeError(logs=logs, raw_output=raw_output) if return_code == 1 and not ignore_handled_error: raise DagsterDbtCliHandledRuntimeError(logs=logs, raw_output=raw_output) return { "command": command, "return_code": return_code, "logs": logs, "raw_output": raw_output, "summary": extract_summary(logs), } SUMMARY_RE = re.compile(r"PASS=(\d+) WARN=(\d+) ERROR=(\d+) SKIP=(\d+) TOTAL=(\d+)") SUMMARY_LABELS = ("num_pass", "num_warn", "num_error", "num_skip", "num_total") def extract_summary(logs: List[Dict[str, str]]): """Extracts the summary statistics from dbt CLI output.""" check.list_param(logs, "logs", dict) summary = [None] * 5 if len(logs) > 0: # Attempt to extract summary results from the last log's message. last_line = logs[-1] message = last_line["message"].strip() try: summary = next(SUMMARY_RE.finditer(message)).groups() except StopIteration: # Failed to match regex. 
pass else: summary = map(int, summary) return dict(zip(SUMMARY_LABELS, summary)) def parse_run_results(path: str) -> Dict[str, Any]: """Parses the `target/run_results.json` artifact that is produced by a dbt process.""" run_results_path = os.path.join(path, "target", "run_results.json") try: with open(run_results_path) as file: return json.load(file) except FileNotFoundError: raise DagsterDbtCliOutputsNotFoundError(path=run_results_path) diff --git a/python_modules/libraries/dagster-dbt/dagster_dbt_tests/rpc/conftest.py b/python_modules/libraries/dagster-dbt/dagster_dbt_tests/rpc/conftest.py index 87b1e035f..ab010c3a9 100644 --- a/python_modules/libraries/dagster-dbt/dagster_dbt_tests/rpc/conftest.py +++ b/python_modules/libraries/dagster-dbt/dagster_dbt_tests/rpc/conftest.py @@ -1,260 +1,260 @@ import atexit import json import subprocess import time +from urllib import request +from urllib.error import URLError import pytest import responses from dagster_dbt import DbtRpcClient -from six.moves.urllib import request -from six.moves.urllib.error import URLError TEST_HOSTNAME = "127.0.0.1" TEST_PORT = 8580 RPC_ESTABLISH_RETRIES = 4 RPC_ESTABLISH_RETRY_INTERVAL_S = 1.5 RPC_ENDPOINT = "http://{hostname}:{port}/jsonrpc".format(hostname=TEST_HOSTNAME, port=TEST_PORT) # ======= SOLIDS I ======== def get_rpc_server_status(): status_request_body = b'{"jsonrpc": "2.0", "method": "status", "id": 1}' req = request.Request( RPC_ENDPOINT, data=status_request_body, headers={"Content-type": "application/json"}, ) resp = request.urlopen(req) return json.load(resp) all_subprocs = set() def kill_all_subprocs(): for proc in all_subprocs: proc.kill() atexit.register(kill_all_subprocs) @pytest.fixture(scope="class") def dbt_rpc_server( dbt_seed, dbt_executable, dbt_config_dir ): # pylint: disable=unused-argument, redefined-outer-name proc = subprocess.Popen( [ dbt_executable, "rpc", "--host", TEST_HOSTNAME, "--port", str(TEST_PORT), "--profiles-dir", dbt_config_dir, ], ) # schedule to be killed in case of abort all_subprocs.add(proc) tries_remaining = RPC_ESTABLISH_RETRIES while True: poll_result = proc.poll() # check on the child if poll_result != None: raise Exception("DBT subprocess terminated before test could start.") try: status_json = get_rpc_server_status() if status_json["result"]["state"] == "ready": break except URLError: pass if tries_remaining <= 0: raise Exception("Exceeded max tries waiting for DBT RPC server to be ready.") tries_remaining -= 1 time.sleep(RPC_ESTABLISH_RETRY_INTERVAL_S) yield proc.terminate() # clean up after ourself proc.wait(timeout=0.2) if proc.poll() == None: # still running proc.kill() all_subprocs.remove(proc) # ======= SOLIDS II ======== @pytest.fixture(scope="session") def rpc_endpoint(): return RPC_ENDPOINT @pytest.fixture def rsps(): with responses.RequestsMock() as req_mock: yield req_mock @pytest.fixture def non_terminal_poll_result(rpc_logs): # pylint: disable=redefined-outer-name result = { "result": { "state": "running", "start": "2020-03-10T17:49:39.095678Z", "end": None, "elapsed": 1.471953, "logs": rpc_logs, "tags": {}, }, "id": "846157fe-62f7-11ea-9cd7-acde48001122", "jsonrpc": "2.0", } return result @pytest.fixture def terminal_poll_result(rpc_logs): # pylint: disable=redefined-outer-name result = { "result": { "state": "success", "start": "2020-03-10T17:52:19.254197Z", "end": "2020-03-10T17:53:06.195224Z", "elapsed": 46.941027, "logs": rpc_logs, "tags": {}, "results": [ { "node": { "raw_sql": "\n\n{{\n config(\n unique_key='ds',\n strategy='check',\n 
check_cols='all'\n )\n}}\n\nselect src.*\nfrom {{ source('dagster', 'daily_fulfillment_forecast') }} src\n\n", "database": "snapshots_david_wallace", "schema": "dagster", "fqn": [ "dataland_dbt", "dagster", "daily_fulfillment_forecast_snapshot", "daily_fulfillment_forecast_snapshot", ], "unique_id": "snapshot.dataland_dbt.daily_fulfillment_forecast_snapshot", "package_name": "dataland_dbt", "root_path": "/Users/dwall/repos/dataland-dbt", "path": "dagster/daily_fulfillment_forecast_snapshot.sql", "original_file_path": "snapshots/dagster/daily_fulfillment_forecast_snapshot.sql", "name": "daily_fulfillment_forecast_snapshot", "resource_type": "snapshot", "alias": "daily_fulfillment_forecast_snapshot", "config": { "enabled": True, "materialized": "snapshot", "persist_docs": {}, "post-hook": [], "pre-hook": [], "vars": {}, "quoting": {}, "column_types": {}, "tags": [], "unique_key": "ds", "target_schema": "dagster", "target_database": "snapshots_david_wallace", "strategy": "check", "check_cols": "all", "transient": False, }, "tags": [], "refs": [], "sources": [["dagster", "daily_fulfillment_forecast"]], "depends_on": { "nodes": ["source.dataland_dbt.dagster.daily_fulfillment_forecast"], "macros": [], }, "docrefs": [], "description": "", "columns": {}, "patch_path": None, "build_path": "target/run/dataland_dbt/dagster/daily_fulfillment_forecast_snapshot.sql", "compiled": True, "compiled_sql": "\n\n\n\nselect src.*\nfrom ingest_dev.dagster.daily_fulfillment_forecast src\n", "extra_ctes_injected": True, "extra_ctes": [], "injected_sql": "\n\n\n\nselect src.*\nfrom ingest_dev.dagster.daily_fulfillment_forecast src\n", "wrapped_sql": "None", }, "error": None, "status": "SUCCESS 0", "execution_time": 14.527844190597534, "thread_id": "Thread-1", "timing": [ { "name": "compile", "started_at": "2020-03-10T17:52:50.519541Z", "completed_at": "2020-03-10T17:52:50.533709Z", }, { "name": "execute", "started_at": "2020-03-10T17:52:50.533986Z", "completed_at": "2020-03-10T17:53:05.046646Z", }, ], "fail": None, "warn": None, "skip": False, } ], "generated_at": "2020-03-10T17:53:06.001341Z", "elapsed_time": 44.305715799331665, }, "id": "016d2822-62f8-11ea-906b-acde48001122", "jsonrpc": "2.0", } return result # ======= CLIENT ======== @pytest.fixture(scope="session") def client(): return DbtRpcClient(host="0.0.0.0", port=8580) # ======= UTILS ======== @pytest.fixture def rpc_logs(): return [ { "timestamp": "2020-03-10T18:19:06.726848Z", "message": "finished collecting timing info", "channel": "dbt", "level": 10, "levelname": "DEBUG", "thread_name": "Thread-1", "process": 18546, "extra": { "timing_info": { "name": "execute", "started_at": "2020-03-10T18:18:54.823894Z", "completed_at": "2020-03-10T18:19:06.726805Z", }, "json_only": True, "unique_id": "snapshot.dataland_dbt.daily_fulfillment_forecast_snapshot", "run_state": "running", "context": "server", }, "exc_info": None, }, { "timestamp": "2020-03-10T18:19:06.727723Z", "message": "11:19:06 | 1 of 1 OK snapshotted snapshots_david_wallace.dagster.daily_fulfillment_forecast_snapshot [\u001b[32mSUCCESS 0\u001b[0m in 11.92s]", "channel": "dbt", "level": 11, "levelname": "INFO", "thread_name": "Thread-1", "process": 18546, "extra": { "unique_id": "snapshot.dataland_dbt.daily_fulfillment_forecast_snapshot", "run_state": "running", "context": "server", }, "exc_info": None, }, ] diff --git a/python_modules/libraries/dagster-gcp/dagster_gcp/bigquery/solids.py b/python_modules/libraries/dagster-gcp/dagster_gcp/bigquery/solids.py index f5744feee..8c926e7bd 100644 --- 
a/python_modules/libraries/dagster-gcp/dagster_gcp/bigquery/solids.py +++ b/python_modules/libraries/dagster-gcp/dagster_gcp/bigquery/solids.py @@ -1,176 +1,176 @@ import hashlib from dagster import InputDefinition, List, Nothing, OutputDefinition, check, solid from dagster_pandas import DataFrame from google.cloud.bigquery.job import LoadJobConfig, QueryJobConfig from google.cloud.bigquery.table import EncryptionConfiguration, TimePartitioning from .configs import ( define_bigquery_create_dataset_config, define_bigquery_delete_dataset_config, define_bigquery_load_config, define_bigquery_query_config, ) from .types import BigQueryLoadSource _START = "start" def _preprocess_config(cfg): destination_encryption_configuration = cfg.get("destination_encryption_configuration") time_partitioning = cfg.get("time_partitioning") if destination_encryption_configuration is not None: cfg["destination_encryption_configuration"] = EncryptionConfiguration( kms_key_name=destination_encryption_configuration ) if time_partitioning is not None: cfg["time_partitioning"] = TimePartitioning(**time_partitioning) return cfg def bq_solid_for_queries(sql_queries): """ Executes BigQuery SQL queries. Expects a BQ client to be provisioned in resources as context.resources.bigquery. """ sql_queries = check.list_param(sql_queries, "sql queries", of_type=str) m = hashlib.sha1() for query in sql_queries: - m.update(query.encode()) + m.update(query.encode("utf-8")) name = "bq_solid_{hash}".format(hash=m.hexdigest()[:10]) @solid( name=name, input_defs=[InputDefinition(_START, Nothing)], output_defs=[OutputDefinition(List[DataFrame])], config_schema=define_bigquery_query_config(), required_resource_keys={"bigquery"}, tags={"kind": "sql", "sql": "\n".join(sql_queries)}, ) def _solid(context): # pylint: disable=unused-argument query_job_config = _preprocess_config(context.solid_config.get("query_job_config", {})) # Retrieve results as pandas DataFrames results = [] for sql_query in sql_queries: # We need to construct a new QueryJobConfig for each query. 
# See: https://bit.ly/2VjD6sl cfg = QueryJobConfig(**query_job_config) if query_job_config else None context.log.info( "executing query %s with config: %s" % (sql_query, cfg.to_api_repr() if cfg else "(no config provided)") ) results.append( context.resources.bigquery.query(sql_query, job_config=cfg).to_dataframe() ) return results return _solid BIGQUERY_LOAD_CONFIG = define_bigquery_load_config() @solid( input_defs=[InputDefinition("paths", List[str])], output_defs=[OutputDefinition(Nothing)], config_schema=BIGQUERY_LOAD_CONFIG, required_resource_keys={"bigquery"}, ) def import_gcs_paths_to_bq(context, paths): return _execute_load_in_source(context, paths, BigQueryLoadSource.GCS) @solid( input_defs=[InputDefinition("df", DataFrame)], output_defs=[OutputDefinition(Nothing)], config_schema=BIGQUERY_LOAD_CONFIG, required_resource_keys={"bigquery"}, ) def import_df_to_bq(context, df): return _execute_load_in_source(context, df, BigQueryLoadSource.DataFrame) @solid( input_defs=[InputDefinition("path", str)], output_defs=[OutputDefinition(Nothing)], config_schema=BIGQUERY_LOAD_CONFIG, required_resource_keys={"bigquery"}, ) def import_file_to_bq(context, path): return _execute_load_in_source(context, path, BigQueryLoadSource.File) def _execute_load_in_source(context, source, source_name): destination = context.solid_config.get("destination") load_job_config = _preprocess_config(context.solid_config.get("load_job_config", {})) cfg = LoadJobConfig(**load_job_config) if load_job_config else None context.log.info( "executing BQ load with config: %s for source %s" % (cfg.to_api_repr() if cfg else "(no config provided)", source) ) if source_name == BigQueryLoadSource.DataFrame: context.resources.bigquery.load_table_from_dataframe( source, destination, job_config=cfg ).result() # Load from file. See: https://cloud.google.com/bigquery/docs/loading-data-local elif source_name == BigQueryLoadSource.File: with open(source, "rb") as file_obj: context.resources.bigquery.load_table_from_file( file_obj, destination, job_config=cfg ).result() # Load from GCS. See: https://cloud.google.com/bigquery/docs/loading-data-cloud-storage elif source_name == BigQueryLoadSource.GCS: context.resources.bigquery.load_table_from_uri(source, destination, job_config=cfg).result() @solid( input_defs=[InputDefinition(_START, Nothing)], config_schema=define_bigquery_create_dataset_config(), required_resource_keys={"bigquery"}, ) def bq_create_dataset(context): """BigQuery Create Dataset. This solid encapsulates creating a BigQuery dataset. Expects a BQ client to be provisioned in resources as context.resources.bigquery. """ (dataset, exists_ok) = [context.solid_config.get(k) for k in ("dataset", "exists_ok")] context.log.info("executing BQ create_dataset for dataset %s" % (dataset)) context.resources.bigquery.create_dataset(dataset, exists_ok) @solid( input_defs=[InputDefinition(_START, Nothing)], config_schema=define_bigquery_delete_dataset_config(), required_resource_keys={"bigquery"}, ) def bq_delete_dataset(context): """BigQuery Delete Dataset. This solid encapsulates deleting a BigQuery dataset. Expects a BQ client to be provisioned in resources as context.resources.bigquery. 
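    Example (illustrative only; the dataset name is hypothetical, and the config keys mirror
    the ``dataset``/``delete_contents``/``not_found_ok`` lookups in the solid body below):

        run_config = {
            "solids": {
                "bq_delete_dataset": {
                    "config": {"dataset": "my_scratch_dataset", "not_found_ok": True}
                }
            }
        }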
""" (dataset, delete_contents, not_found_ok) = [ context.solid_config.get(k) for k in ("dataset", "delete_contents", "not_found_ok") ] context.log.info("executing BQ delete_dataset for dataset %s" % dataset) context.resources.bigquery.delete_dataset( dataset, delete_contents=delete_contents, not_found_ok=not_found_ok ) diff --git a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/dataproc_tests/test_resources.py b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/dataproc_tests/test_resources.py index 50978db72..539890334 100644 --- a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/dataproc_tests/test_resources.py +++ b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/dataproc_tests/test_resources.py @@ -1,126 +1,129 @@ import os import re import uuid import httplib2 from dagster import ModeDefinition, PipelineDefinition, execute_pipeline, seven from dagster.seven import mock from dagster_gcp import dataproc_resource, dataproc_solid PROJECT_ID = os.getenv("GCP_PROJECT_ID", "default_project") CLUSTER_NAME = "test-%s" % uuid.uuid4().hex REGION = "us-west1" DATAPROC_BASE_URI = "https://dataproc.googleapis.com/v1/projects/{project}/regions/{region}".format( project=PROJECT_ID, region=REGION ) DATAPROC_CLUSTERS_URI = "{base_uri}/clusters".format(base_uri=DATAPROC_BASE_URI) DATAPROC_JOBS_URI = "{base_uri}/jobs".format(base_uri=DATAPROC_BASE_URI) DATAPROC_SCHEMA_URI = "https://www.googleapis.com/discovery/v1/apis/dataproc/v1/rest" EXPECTED_RESULTS = [ # OAuth authorize credentials (re.escape("https://oauth2.googleapis.com/token"), "POST", {"access_token": "foo"}), # Cluster create (re.escape(DATAPROC_CLUSTERS_URI + "?alt=json"), "POST", {}), # Cluster get ( re.escape(DATAPROC_CLUSTERS_URI + "/{}?alt=json".format(CLUSTER_NAME)), "GET", {"status": {"state": "RUNNING"}}, ), # Jobs submit ( re.escape(DATAPROC_JOBS_URI + ":submit?alt=json"), "POST", {"reference": {"jobId": "some job ID"}}, ), # Jobs get (re.escape(DATAPROC_JOBS_URI) + r".*?\?alt=json", "GET", {"status": {"state": "DONE"}}), # Cluster delete (re.escape(DATAPROC_CLUSTERS_URI + "/{}?alt=json".format(CLUSTER_NAME)), "DELETE", {}), ] class HttpSnooper(httplib2.Http): def __init__(self, *args, **kwargs): super(HttpSnooper, self).__init__(*args, **kwargs) def request( self, uri, method="GET", body=None, headers=None, redirections=5, connection_type=None ): for expected_uri, expected_method, result in EXPECTED_RESULTS: if re.match(expected_uri, uri) and method == expected_method: - return (httplib2.Response({"status": "200"}), seven.json.dumps(result).encode()) + return ( + httplib2.Response({"status": "200"}), + seven.json.dumps(result).encode("utf-8"), + ) # Pass this one through since its the entire JSON schema used for dynamic object creation if uri == DATAPROC_SCHEMA_URI: response, content = super(HttpSnooper, self).request( uri, method=method, body=body, headers=headers, redirections=redirections, connection_type=connection_type, ) return response, content def test_dataproc_resource(): """Tests dataproc cluster creation/deletion. Requests are captured by the responses library, so no actual HTTP requests are made here. 
Note that inspecting the HTTP requests can be useful for debugging, which can be done by adding: import httplib2 httplib2.debuglevel = 4 """ with mock.patch("httplib2.Http", new=HttpSnooper): pipeline = PipelineDefinition( name="test_dataproc_resource", solid_defs=[dataproc_solid], mode_defs=[ModeDefinition(resource_defs={"dataproc": dataproc_resource})], ) result = execute_pipeline( pipeline, { "solids": { "dataproc_solid": { "config": { "job_config": { "projectId": PROJECT_ID, "region": REGION, "job": { "reference": {"projectId": PROJECT_ID}, "placement": {"clusterName": CLUSTER_NAME}, "hiveJob": {"queryList": {"queries": ["SHOW DATABASES"]}}, }, }, "job_scoped_cluster": True, } } }, "resources": { "dataproc": { "config": { "projectId": PROJECT_ID, "clusterName": CLUSTER_NAME, "region": REGION, "cluster_config": { "softwareConfig": { "properties": { # Create a single-node cluster # This needs to be the string "true" when # serialized, not a boolean true "dataproc:dataproc.allow.zero.workers": "true" } } }, } } }, }, ) assert result.success diff --git a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py index 92796348d..a954e725b 100644 --- a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py +++ b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py @@ -1,114 +1,115 @@ import os import sys import tempfile -import six from dagster import DagsterEventType, execute_pipeline, pipeline, solid from dagster.core.instance import DagsterInstance, InstanceType from dagster.core.launcher import DefaultRunLauncher from dagster.core.run_coordinator import DefaultRunCoordinator from dagster.core.storage.compute_log_manager import ComputeIOType from dagster.core.storage.event_log import SqliteEventLogStorage from dagster.core.storage.root import LocalArtifactStorage from dagster.core.storage.runs import SqliteRunStorage from dagster_gcp.gcs import GCSComputeLogManager from google.cloud import storage HELLO_WORLD = "Hello World" SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n" EXPECTED_LOGS = [ 'STEP_START - Started execution of step "easy".', 'STEP_OUTPUT - Yielded output "result" of type "Any"', 'STEP_SUCCESS - Finished execution of step "easy"', ] def test_compute_log_manager(gcs_bucket): @pipeline def simple(): @solid def easy(context): context.log.info("easy") print(HELLO_WORLD) # pylint: disable=print-call return "easy" easy() with tempfile.TemporaryDirectory() as temp_dir: run_store = SqliteRunStorage.from_local(temp_dir) event_store = SqliteEventLogStorage(temp_dir) manager = GCSComputeLogManager(bucket=gcs_bucket, prefix="my_prefix", local_dir=temp_dir) instance = DagsterInstance( instance_type=InstanceType.PERSISTENT, local_artifact_storage=LocalArtifactStorage(temp_dir), run_storage=run_store, event_storage=event_store, compute_log_manager=manager, run_coordinator=DefaultRunCoordinator(), run_launcher=DefaultRunLauncher(), ) result = execute_pipeline(simple, instance=instance) compute_steps = [ event.step_key for event in result.step_event_list if event.event_type == DagsterEventType.STEP_START ] assert len(compute_steps) == 1 step_key = compute_steps[0] stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, 
ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data # Check GCS directly - stderr_gcs = six.ensure_str( + stderr_gcs = ( storage.Client() .get_bucket(gcs_bucket) .blob( "{prefix}/storage/{run_id}/compute_logs/easy.err".format( prefix="my_prefix", run_id=result.run_id ) ) - .download_as_string() + .download_as_bytes() + .decode("utf-8") ) + for expected in EXPECTED_LOGS: assert expected in stderr_gcs # Check download behavior by deleting locally cached logs compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs") for filename in os.listdir(compute_logs_dir): os.unlink(os.path.join(compute_logs_dir, filename)) stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT) assert stdout.data == HELLO_WORLD + SEPARATOR stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR) for expected in EXPECTED_LOGS: assert expected in stderr.data def test_compute_log_manager_from_config(gcs_bucket): s3_prefix = "foobar" dagster_yaml = """ compute_logs: module: dagster_gcp.gcs.compute_log_manager class: GCSComputeLogManager config: bucket: "{bucket}" local_dir: "/tmp/cool" prefix: "{prefix}" """.format( bucket=gcs_bucket, prefix=s3_prefix ) with tempfile.TemporaryDirectory() as tempdir: with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f: - f.write(six.ensure_binary(dagster_yaml)) + f.write(dagster_yaml.encode("utf-8")) instance = DagsterInstance.from_config(tempdir) assert isinstance(instance.compute_log_manager, GCSComputeLogManager) diff --git a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_gcs_file_manager.py b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_gcs_file_manager.py index 144473f95..5aa83b2c8 100644 --- a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_gcs_file_manager.py +++ b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_gcs_file_manager.py @@ -1,72 +1,72 @@ from dagster import ModeDefinition, configured, execute_pipeline, pipeline, solid from dagster.seven import mock from dagster_gcp.gcs.file_manager import GCSFileHandle, GCSFileManager from dagster_gcp.gcs.resources import gcs_file_manager from google.cloud import storage def test_gcs_file_manager_write(): gcs_mock = mock.MagicMock() file_manager = GCSFileManager(storage.client.Client(), "some-bucket", "some-key") file_manager._client = gcs_mock # pylint:disable=protected-access - foo_bytes = "foo".encode() + foo_bytes = b"foo" file_handle = file_manager.write_data(foo_bytes) assert isinstance(file_handle, GCSFileHandle) assert file_handle.gcs_bucket == "some-bucket" assert file_handle.gcs_key.startswith("some-key/") assert gcs_mock.get_bucket().blob().upload_from_file.call_count == 1 file_handle = file_manager.write_data(foo_bytes, ext="foo") assert isinstance(file_handle, GCSFileHandle) assert file_handle.gcs_bucket == "some-bucket" assert file_handle.gcs_key.startswith("some-key/") assert file_handle.gcs_key[-4:] == ".foo" assert gcs_mock.get_bucket().blob().upload_from_file.call_count == 2 @mock.patch("dagster_gcp.gcs.resources.storage.client.Client") @mock.patch("dagster_gcp.gcs.resources.GCSFileManager") def test_gcs_file_manger_resource(MockGCSFileManager, mock_storage_client_Client): did_it_run = dict(it_ran=False) resource_config = { "project": "some-project", "gcs_bucket": "some-bucket", "gcs_prefix": "some-prefix", } @solid(required_resource_keys={"file_manager"}) def test_solid(context): # test that we got back a GCSFileManager assert 
context.resources.file_manager == MockGCSFileManager.return_value # make sure the file manager was initalized with the config we are supplying MockGCSFileManager.assert_called_once_with( client=mock_storage_client_Client.return_value, gcs_bucket=resource_config["gcs_bucket"], gcs_base_key=resource_config["gcs_prefix"], ) mock_storage_client_Client.assert_called_once_with(project=resource_config["project"]) did_it_run["it_ran"] = True @pipeline( mode_defs=[ ModeDefinition( resource_defs={"file_manager": configured(gcs_file_manager)(resource_config)}, ) ] ) def test_pipeline(): test_solid() execute_pipeline(test_pipeline) assert did_it_run["it_ran"] diff --git a/python_modules/libraries/dagster-github/dagster_github_tests/test_resources.py b/python_modules/libraries/dagster-github/dagster_github_tests/test_resources.py index a646c3ca4..918bab893 100644 --- a/python_modules/libraries/dagster-github/dagster_github_tests/test_resources.py +++ b/python_modules/libraries/dagster-github/dagster_github_tests/test_resources.py @@ -1,189 +1,189 @@ import time import requests import responses from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from dagster import ModeDefinition, execute_solid, solid from dagster_github import github_resource from dagster_github.resources import GithubResource FAKE_PRIVATE_RSA_KEY = ( rsa.generate_private_key(public_exponent=65537, key_size=1024, backend=default_backend()) .private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), ) - .decode() + .decode("utf-8") ) @responses.activate def test_github_resource_get_installations(): @solid(required_resource_keys={"github"}) def github_solid(context): assert context.resources.github with responses.RequestsMock() as rsps: rsps.add( rsps.GET, "https://api.github.com/app/installations", status=200, json={}, ) context.resources.github.get_installations() result = execute_solid( github_solid, run_config={ "resources": { "github": { "config": { "github_app_id": 123, "github_app_private_rsa_key": FAKE_PRIVATE_RSA_KEY, "github_installation_id": 123, } } } }, mode_def=ModeDefinition(resource_defs={"github": github_resource}), ) assert result.success @responses.activate def test_github_resource_create_issue(): @solid(required_resource_keys={"github"}) def github_solid(context): assert context.resources.github with responses.RequestsMock() as rsps: rsps.add( rsps.POST, "https://api.github.com/app/installations/123/access_tokens", status=201, json={"token": "fake_token", "expires_at": "2016-07-11T22:14:10Z",}, ) rsps.add( rsps.POST, "https://api.github.com/graphql", status=200, json={"data": {"repository": {"id": 123}},}, ) rsps.add( rsps.POST, "https://api.github.com/graphql", status=200, json={}, ) context.resources.github.create_issue( repo_name="dagster", repo_owner="dagster-io", title="test", body="body", ) result = execute_solid( github_solid, run_config={ "resources": { "github": { "config": { "github_app_id": 123, "github_app_private_rsa_key": FAKE_PRIVATE_RSA_KEY, "github_installation_id": 123, } } } }, mode_def=ModeDefinition(resource_defs={"github": github_resource}), ) assert result.success @responses.activate def test_github_resource_execute(): @solid(required_resource_keys={"github"}) def github_solid(context): assert context.resources.github with responses.RequestsMock() as rsps: rsps.add( rsps.POST, 
"https://api.github.com/app/installations/123/access_tokens", status=201, json={"token": "fake_token", "expires_at": "2016-07-11T22:14:10Z",}, ) rsps.add( rsps.POST, "https://api.github.com/graphql", status=200, json={"data": {"repository": {"id": 123}},}, ) context.resources.github.execute( query=""" query get_repo_id($repo_name: String!, $repo_owner: String!) { repository(name: $repo_name, owner: $repo_owner) { id } }""", variables={"repo_name": "dagster", "repo_owner": "dagster-io"}, ) result = execute_solid( github_solid, run_config={ "resources": { "github": { "config": { "github_app_id": 123, # Do not be alarmed, this is a fake key "github_app_private_rsa_key": FAKE_PRIVATE_RSA_KEY, "github_installation_id": 123, } } } }, mode_def=ModeDefinition(resource_defs={"github": github_resource}), ) assert result.success @responses.activate def test_github_resource_token_expiration(): class GithubResourceTesting(GithubResource): def __init__(self, client, app_id, app_private_rsa_key, default_installation_id): GithubResource.__init__( self, client=client, app_id=app_id, app_private_rsa_key=app_private_rsa_key, default_installation_id=default_installation_id, ) self.installation_tokens = { "123": {"value": "test", "expires": int(time.time()) - 1000} } self.app_token = { "value": "test", "expires": int(time.time()) - 1000, } resource = GithubResourceTesting( client=requests.Session(), app_id="abc", app_private_rsa_key=FAKE_PRIVATE_RSA_KEY, default_installation_id="123", ) with responses.RequestsMock() as rsps: rsps.add( rsps.POST, "https://api.github.com/app/installations/123/access_tokens", status=201, json={"token": "fake_token", "expires_at": "2016-07-11T22:14:10Z",}, ) rsps.add( rsps.POST, "https://api.github.com/graphql", status=200, json={"data": {"repository": {"id": 123}},}, ) res = resource.execute( query=""" query get_repo_id($repo_name: String!, $repo_owner: String!) 
{ repository(name: $repo_name, owner: $repo_owner) { id } }""", variables={"repo_name": "dagster", "repo_owner": "dagster-io"}, ) assert res["data"]["repository"]["id"] == 123 diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/client.py b/python_modules/libraries/dagster-k8s/dagster_k8s/client.py index 5966b459f..2ff84ab4f 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/client.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/client.py @@ -1,504 +1,494 @@ import logging import sys import time from enum import Enum import kubernetes -import six from dagster import DagsterInstance, check from dagster.core.storage.pipeline_run import PipelineRunStatus -from six import raise_from DEFAULT_WAIT_TIMEOUT = 86400.0 # 1 day DEFAULT_WAIT_BETWEEN_ATTEMPTS = 10.0 # 10 seconds DEFAULT_JOB_POD_COUNT = 1 # expect job:pod to be 1:1 by default class WaitForPodState(Enum): Ready = "READY" Terminated = "TERMINATED" class DagsterK8sError(Exception): pass class DagsterK8sTimeoutError(DagsterK8sError): pass class DagsterK8sAPIRetryLimitExceeded(Exception): def __init__(self, *args, **kwargs): k8s_api_exception = check.inst_param( kwargs.pop("k8s_api_exception"), "k8s_api_exception", Exception ) original_exc_info = check.tuple_param(kwargs.pop("original_exc_info"), "original_exc_info") max_retries = check.int_param(kwargs.pop("max_retries"), "max_retries") check.invariant(original_exc_info[0] is not None) super(DagsterK8sAPIRetryLimitExceeded, self).__init__( f"Retry limit of {max_retries} exceeded: " + args[0], *args[1:], **kwargs, ) self.k8s_api_exception = check.opt_inst_param( k8s_api_exception, "k8s_api_exception", Exception ) self.original_exc_info = original_exc_info class DagsterK8sUnrecoverableAPIError(Exception): def __init__(self, *args, **kwargs): k8s_api_exception = check.inst_param( kwargs.pop("k8s_api_exception"), "k8s_api_exception", Exception ) original_exc_info = check.tuple_param(kwargs.pop("original_exc_info"), "original_exc_info") check.invariant(original_exc_info[0] is not None) super(DagsterK8sUnrecoverableAPIError, self).__init__(args[0], *args[1:], **kwargs) self.k8s_api_exception = check.opt_inst_param( k8s_api_exception, "k8s_api_exception", Exception ) self.original_exc_info = original_exc_info class DagsterK8sPipelineStatusException(Exception): pass WHITELISTED_TRANSIENT_K8S_STATUS_CODES = [ 503, # Service unavailable 504, # Gateway timeout ] def k8s_api_retry( fn, max_retries, timeout, msg_fn=lambda: "Unexpected error encountered in Kubernetes API Client.", ): check.callable_param(fn, "fn") check.int_param(max_retries, "max_retries") check.numeric_param(timeout, "timeout") remaining_attempts = 1 + max_retries while remaining_attempts > 0: remaining_attempts -= 1 try: return fn() except kubernetes.client.rest.ApiException as e: # Only catch whitelisted ApiExceptions status = e.status # Check if the status code is generally whitelisted whitelisted = status in WHITELISTED_TRANSIENT_K8S_STATUS_CODES # If there are remaining attempts, swallow the error if whitelisted and remaining_attempts > 0: time.sleep(timeout) elif whitelisted and remaining_attempts == 0: - raise_from( - DagsterK8sAPIRetryLimitExceeded( - msg_fn(), - k8s_api_exception=e, - max_retries=max_retries, - original_exc_info=sys.exc_info(), - ), - e, - ) + raise DagsterK8sAPIRetryLimitExceeded( + msg_fn(), + k8s_api_exception=e, + max_retries=max_retries, + original_exc_info=sys.exc_info(), + ) from e else: - raise_from( - DagsterK8sUnrecoverableAPIError( - msg_fn(), k8s_api_exception=e, 
original_exc_info=sys.exc_info(), - ), - e, - ) + raise DagsterK8sUnrecoverableAPIError( + msg_fn(), k8s_api_exception=e, original_exc_info=sys.exc_info(), + ) from e class KubernetesWaitingReasons: PodInitializing = "PodInitializing" ContainerCreating = "ContainerCreating" ErrImagePull = "ErrImagePull" ImagePullBackOff = "ImagePullBackOff" CrashLoopBackOff = "CrashLoopBackOff" RunContainerError = "RunContainerError" class DagsterKubernetesClient: def __init__(self, batch_api, core_api, logger, sleeper, timer): self.batch_api = batch_api self.core_api = core_api self.logger = logger self.sleeper = sleeper self.timer = timer @staticmethod def production_client(): return DagsterKubernetesClient( kubernetes.client.BatchV1Api(), kubernetes.client.CoreV1Api(), logging.info, time.sleep, time.time, ) ### Job operations ### def wait_for_job( self, job_name, namespace, wait_timeout=DEFAULT_WAIT_TIMEOUT, wait_time_between_attempts=DEFAULT_WAIT_BETWEEN_ATTEMPTS, start_time=None, ): """ Wait for a job to launch and be running. Args: job_name (str): Name of the job to wait for. namespace (str): Namespace in which the job is located. wait_timeout (numeric, optional): Timeout after which to give up and raise exception. Defaults to DEFAULT_WAIT_TIMEOUT. wait_time_between_attempts (numeric, optional): Wait time between polling attempts. Defaults to DEFAULT_WAIT_BETWEEN_ATTEMPTS. Raises: DagsterK8sError: Raised when wait_timeout is exceeded or an error is encountered. """ check.str_param(job_name, "job_name") check.str_param(namespace, "namespace") check.numeric_param(wait_timeout, "wait_timeout") check.numeric_param(wait_time_between_attempts, "wait_time_between_attempts") job = None start = start_time or self.timer() while not job: if self.timer() - start > wait_timeout: raise DagsterK8sTimeoutError( "Timed out while waiting for job {job_name}" " to launch".format(job_name=job_name) ) # Get all jobs in the namespace and find the matching job def _get_jobs_for_namespace(): jobs = self.batch_api.list_namespaced_job( namespace=namespace, field_selector="metadata.name={}".format(job_name) ) if jobs.items: check.invariant( len(jobs.items) == 1, 'There should only be one k8s job with name "{}", but got multiple jobs:" {}'.format( job_name, jobs.items ), ) return jobs.items[0] else: return None job = k8s_api_retry( _get_jobs_for_namespace, max_retries=3, timeout=wait_time_between_attempts ) if not job: self.logger('Job "{job_name}" not yet launched, waiting'.format(job_name=job_name)) self.sleeper(wait_time_between_attempts) def wait_for_job_success( self, job_name, namespace, instance=None, run_id=None, wait_timeout=DEFAULT_WAIT_TIMEOUT, wait_time_between_attempts=DEFAULT_WAIT_BETWEEN_ATTEMPTS, num_pods_to_wait_for=DEFAULT_JOB_POD_COUNT, ): """Poll a job for successful completion. Args: job_name (str): Name of the job to wait for. namespace (str): Namespace in which the job is located. wait_timeout (numeric, optional): Timeout after which to give up and raise exception. Defaults to DEFAULT_WAIT_TIMEOUT. wait_time_between_attempts (numeric, optional): Wait time between polling attempts. Defaults to DEFAULT_WAIT_BETWEEN_ATTEMPTS. Raises: DagsterK8sError: Raised when wait_timeout is exceeded or an error is encountered. 
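        Example (illustrative; the job name and namespace below are hypothetical):

            DagsterKubernetesClient.production_client().wait_for_job_success(
                job_name="dagster-run-1234", namespace="dagster"
            )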
""" check.str_param(job_name, "job_name") check.str_param(namespace, "namespace") check.opt_inst_param(instance, "instance", DagsterInstance) check.opt_str_param(run_id, "run_id") check.numeric_param(wait_timeout, "wait_timeout") check.numeric_param(wait_time_between_attempts, "wait_time_between_attempts") check.int_param(num_pods_to_wait_for, "num_pods_to_wait_for") start = self.timer() # Wait for job to be running self.wait_for_job( job_name, namespace, wait_timeout=wait_timeout, wait_time_between_attempts=wait_time_between_attempts, start_time=start, ) # Wait for the job status to be completed. We check the status every # wait_time_between_attempts seconds while True: if self.timer() - start > wait_timeout: raise DagsterK8sTimeoutError( "Timed out while waiting for job {job_name}" " to complete".format(job_name=job_name) ) # Reads the status of the specified job. Returns a V1Job object that # we need to read the status off of. status = None def _get_job_status(): job = self.batch_api.read_namespaced_job_status(job_name, namespace=namespace) return job.status status = k8s_api_retry( _get_job_status, max_retries=3, timeout=wait_time_between_attempts ) # status.succeeded represents the number of pods which reached phase Succeeded. if status.succeeded == num_pods_to_wait_for: break # status.failed represents the number of pods which reached phase Failed. if status.failed and status.failed > 0: raise DagsterK8sError( "Encountered failed job pods for job {job_name} with status: {status}, " "in namespace {namespace}".format( job_name=job_name, status=status, namespace=namespace ) ) if instance and run_id: pipeline_run = instance.get_run_by_id(run_id) if not pipeline_run: raise DagsterK8sPipelineStatusException() pipeline_run_status = pipeline_run.status if pipeline_run_status != PipelineRunStatus.STARTED: raise DagsterK8sPipelineStatusException() self.sleeper(wait_time_between_attempts) def delete_job( self, job_name, namespace, ): """Delete Kubernetes Job. We also need to delete corresponding pods due to: https://github.com/kubernetes-client/python/issues/234 Args: job_name (str): Name of the job to wait for. namespace (str): Namespace in which the job is located. """ check.str_param(job_name, "job_name") check.str_param(namespace, "namespace") try: pod_names = self.get_pod_names_in_job(job_name, namespace) # Collect all the errors so that we can post-process before raising pod_names = self.get_pod_names_in_job(job_name, namespace) errors = [] try: self.batch_api.delete_namespaced_job(name=job_name, namespace=namespace) except Exception as e: # pylint: disable=broad-except errors.append(e) for pod_name in pod_names: try: self.core_api.delete_namespaced_pod(name=pod_name, namespace=namespace) except Exception as e: # pylint: disable=broad-except errors.append(e) if len(errors) > 0: # Raise first non-expected error. Else, raise first error. for error in errors: if not ( isinstance(error, kubernetes.client.rest.ApiException) and error.reason == "Not Found" ): raise error raise errors[0] return True except kubernetes.client.rest.ApiException as e: if e.reason == "Not Found": return False raise e ### Pod operations ### def get_pod_names_in_job(self, job_name, namespace): """Get the names of pods launched by the job ``job_name``. Args: job_name (str): Name of the job to inspect. namespace (str): Namespace in which the job is located. Returns: List[str]: List of all pod names that have been launched by the job ``job_name``. 
""" check.str_param(job_name, "job_name") check.str_param(namespace, "namespace") pods = self.core_api.list_namespaced_pod( namespace=namespace, label_selector="job-name={}".format(job_name) ).items return [p.metadata.name for p in pods] def wait_for_pod( self, pod_name, namespace, wait_for_state=WaitForPodState.Ready, wait_timeout=DEFAULT_WAIT_TIMEOUT, wait_time_between_attempts=DEFAULT_WAIT_BETWEEN_ATTEMPTS, ): """Wait for a pod to launch and be running, or wait for termination (useful for job pods). Args: pod_name (str): Name of the pod to wait for. namespace (str): Namespace in which the pod is located. wait_for_state (WaitForPodState, optional): Whether to wait for pod readiness or termination. Defaults to waiting for readiness. wait_timeout (numeric, optional): Timeout after which to give up and raise exception. Defaults to DEFAULT_WAIT_TIMEOUT. wait_time_between_attempts (numeric, optional): Wait time between polling attempts. Defaults to DEFAULT_WAIT_BETWEEN_ATTEMPTS. Raises: DagsterK8sError: Raised when wait_timeout is exceeded or an error is encountered """ check.str_param(pod_name, "pod_name") check.str_param(namespace, "namespace") check.inst_param(wait_for_state, "wait_for_state", WaitForPodState) check.numeric_param(wait_timeout, "wait_timeout") check.numeric_param(wait_time_between_attempts, "wait_time_between_attempts") self.logger('Waiting for pod "%s"' % pod_name) start = self.timer() while True: pods = self.core_api.list_namespaced_pod( namespace=namespace, field_selector="metadata.name=%s" % pod_name ).items pod = pods[0] if pods else None if self.timer() - start > wait_timeout: raise DagsterK8sError( "Timed out while waiting for pod to become ready with pod info: %s" % str(pod) ) if pod is None: self.logger('Waiting for pod "%s" to launch...' % pod_name) self.sleeper(wait_time_between_attempts) continue if not pod.status.container_statuses: self.logger("Waiting for pod container status to be set by kubernetes...") self.sleeper(wait_time_between_attempts) continue # https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#containerstatus-v1-core container_status = pod.status.container_statuses[0] # State checks below, see: # https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#containerstate-v1-core state = container_status.state if state.running is not None: if wait_for_state == WaitForPodState.Ready: # ready is boolean field of container status ready = container_status.ready if not ready: self.logger('Waiting for pod "%s" to become ready...' % pod_name) self.sleeper(wait_time_between_attempts) continue else: self.logger('Pod "%s" is ready, done waiting' % pod_name) break else: check.invariant( wait_for_state == WaitForPodState.Terminated, "New invalid WaitForPodState" ) self.sleeper(wait_time_between_attempts) continue elif state.waiting is not None: # https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#containerstatewaiting-v1-core if state.waiting.reason == KubernetesWaitingReasons.PodInitializing: self.logger('Waiting for pod "%s" to initialize...' 
% pod_name) self.sleeper(wait_time_between_attempts) continue elif state.waiting.reason == KubernetesWaitingReasons.ContainerCreating: self.logger("Waiting for container creation...") self.sleeper(wait_time_between_attempts) continue elif state.waiting.reason in [ KubernetesWaitingReasons.ErrImagePull, KubernetesWaitingReasons.ImagePullBackOff, KubernetesWaitingReasons.CrashLoopBackOff, KubernetesWaitingReasons.RunContainerError, ]: raise DagsterK8sError( 'Failed: Reason="{reason}" Message="{message}"'.format( reason=state.waiting.reason, message=state.waiting.message ) ) else: raise DagsterK8sError("Unknown issue: %s" % state.waiting) # https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#containerstateterminated-v1-core elif state.terminated is not None: if not state.terminated.exit_code == 0: raw_logs = self.retrieve_pod_logs(pod_name, namespace) raise DagsterK8sError( 'Pod did not exit successfully. Failed with message: "%s" and pod logs: "%s"' % (state.terminated.message, str(raw_logs)) ) else: self.logger("Pod {pod_name} exitted successfully".format(pod_name=pod_name)) break else: raise DagsterK8sError("Should not get here, unknown pod state") def retrieve_pod_logs(self, pod_name, namespace): """Retrieves the raw pod logs for the pod named `pod_name` from Kubernetes. Args: pod_name (str): The name of the pod from which to retrieve logs. namespace (str): The namespace of the pod. Returns: str: The raw logs retrieved from the pod. """ check.str_param(pod_name, "pod_name") check.str_param(namespace, "namespace") # We set _preload_content to False here to prevent the k8 python api from processing the response. # If the logs happen to be JSON - it will parse in to a dict and then coerce back to a str leaving # us with invalid JSON as the quotes have been switched to ' # # https://github.com/kubernetes-client/python/issues/811 - return six.ensure_str( - self.core_api.read_namespaced_pod_log( - name=pod_name, namespace=namespace, _preload_content=False - ).data - ) + return self.core_api.read_namespaced_pod_log( + name=pod_name, namespace=namespace, _preload_content=False + ).data diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py index 9f2a66f10..80d35af30 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py @@ -1,522 +1,521 @@ import hashlib import json import random import string from collections import namedtuple import kubernetes -import six from dagster import Array, Field, Noneable, StringSource from dagster import __version__ as dagster_version from dagster import check from dagster.config.field_utils import Permissive, Shape from dagster.config.validate import validate_config from dagster.core.errors import DagsterInvalidConfigError from dagster.serdes import whitelist_for_serdes from dagster.utils import frozentags, merge_dicts # To retry step job, users should raise RetryRequested() so that the dagster system is aware of the # retry. As an example, see retry_pipeline in dagster_test.test_project.test_pipelines.repo # To override this config, user can specify UserDefinedDagsterK8sConfig. K8S_JOB_BACKOFF_LIMIT = 0 K8S_JOB_TTL_SECONDS_AFTER_FINISHED = 24 * 60 * 60 # 1 day DAGSTER_HOME_DEFAULT = "/opt/dagster/dagster_home" # The Kubernetes Secret containing the PG password will be exposed as this env var in the job # container. 
DAGSTER_PG_PASSWORD_ENV_VAR = "DAGSTER_PG_PASSWORD" # We expect the PG secret to have this key. # # For an example, see: # helm/dagster/templates/secret-postgres.yaml DAGSTER_PG_PASSWORD_SECRET_KEY = "postgresql-password" # Kubernetes Job object names cannot be longer than 63 characters MAX_K8S_NAME_LEN = 63 # TODO: Deprecate this tag K8S_RESOURCE_REQUIREMENTS_KEY = "dagster-k8s/resource_requirements" K8S_RESOURCE_REQUIREMENTS_SCHEMA = Shape({"limits": Permissive(), "requests": Permissive()}) USER_DEFINED_K8S_CONFIG_KEY = "dagster-k8s/config" USER_DEFINED_K8S_CONFIG_SCHEMA = Shape( { "container_config": Permissive(), "pod_template_spec_metadata": Permissive(), "pod_spec_config": Permissive(), "job_config": Permissive(), "job_metadata": Permissive(), "job_spec_config": Permissive(), } ) class UserDefinedDagsterK8sConfig( namedtuple( "_UserDefinedDagsterK8sConfig", "container_config pod_template_spec_metadata pod_spec_config job_config job_metadata job_spec_config", ) ): def __new__( cls, container_config=None, pod_template_spec_metadata=None, pod_spec_config=None, job_config=None, job_metadata=None, job_spec_config=None, ): container_config = check.opt_dict_param(container_config, "container_config", key_type=str) pod_template_spec_metadata = check.opt_dict_param( pod_template_spec_metadata, "pod_template_spec_metadata", key_type=str ) pod_spec_config = check.opt_dict_param(pod_spec_config, "pod_spec_config", key_type=str) job_config = check.opt_dict_param(job_config, "job_config", key_type=str) job_metadata = check.opt_dict_param(job_metadata, "job_metadata", key_type=str) job_spec_config = check.opt_dict_param(job_spec_config, "job_spec_config", key_type=str) return super(UserDefinedDagsterK8sConfig, cls).__new__( cls, container_config=container_config, pod_template_spec_metadata=pod_template_spec_metadata, pod_spec_config=pod_spec_config, job_config=job_config, job_metadata=job_metadata, job_spec_config=job_spec_config, ) def to_dict(self): return { "container_config": self.container_config, "pod_template_spec_metadata": self.pod_template_spec_metadata, "pod_spec_config": self.pod_spec_config, "job_config": self.job_config, "job_metadata": self.job_metadata, "job_spec_config": self.job_spec_config, } @classmethod def from_dict(self, config_dict): return UserDefinedDagsterK8sConfig( container_config=config_dict.get("container_config"), pod_template_spec_metadata=config_dict.get("pod_template_spec_metadata"), pod_spec_config=config_dict.get("pod_spec_config"), job_config=config_dict.get("job_config"), job_metadata=config_dict.get("job_metadata"), job_spec_config=config_dict.get("job_spec_config"), ) def get_k8s_resource_requirements(tags): check.inst_param(tags, "tags", frozentags) check.invariant(K8S_RESOURCE_REQUIREMENTS_KEY in tags) resource_requirements = json.loads(tags[K8S_RESOURCE_REQUIREMENTS_KEY]) result = validate_config(K8S_RESOURCE_REQUIREMENTS_SCHEMA, resource_requirements) if not result.success: raise DagsterInvalidConfigError( "Error in tags for {}".format(K8S_RESOURCE_REQUIREMENTS_KEY), result.errors, result, ) return result.value def get_user_defined_k8s_config(tags): check.inst_param(tags, "tags", frozentags) if not any(key in tags for key in [K8S_RESOURCE_REQUIREMENTS_KEY, USER_DEFINED_K8S_CONFIG_KEY]): return UserDefinedDagsterK8sConfig() user_defined_k8s_config = {} if USER_DEFINED_K8S_CONFIG_KEY in tags: user_defined_k8s_config_value = json.loads(tags[USER_DEFINED_K8S_CONFIG_KEY]) result = validate_config(USER_DEFINED_K8S_CONFIG_SCHEMA, user_defined_k8s_config_value) 
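        # For reference (illustrative, not taken from this diff): a "dagster-k8s/config" tag
        # value that validates against USER_DEFINED_K8S_CONFIG_SCHEMA could look like the
        # following; the resource figures are placeholders.
        #
        #     {
        #         "container_config": {
        #             "resources": {
        #                 "requests": {"cpu": "250m", "memory": "64Mi"},
        #                 "limits": {"cpu": "500m", "memory": "2560Mi"},
        #             }
        #         }
        #     }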
if not result.success: raise DagsterInvalidConfigError( "Error in tags for {}".format(USER_DEFINED_K8S_CONFIG_KEY), result.errors, result, ) user_defined_k8s_config = result.value container_config = user_defined_k8s_config.get("container_config", {}) # Backcompat for resource requirements key if K8S_RESOURCE_REQUIREMENTS_KEY in tags: resource_requirements_config = get_k8s_resource_requirements(tags) container_config = merge_dicts( container_config, {"resources": resource_requirements_config} ) return UserDefinedDagsterK8sConfig( container_config=container_config, pod_template_spec_metadata=user_defined_k8s_config.get("pod_template_spec_metadata"), pod_spec_config=user_defined_k8s_config.get("pod_spec_config"), job_config=user_defined_k8s_config.get("job_config"), job_spec_config=user_defined_k8s_config.get("job_spec_config"), ) def get_job_name_from_run_id(run_id): return "dagster-run-{}".format(run_id) @whitelist_for_serdes class DagsterK8sJobConfig( namedtuple( "_K8sJobTaskConfig", "job_image dagster_home image_pull_policy image_pull_secrets service_account_name " "instance_config_map postgres_password_secret env_config_maps env_secrets", ) ): """Configuration parameters for launching Dagster Jobs on Kubernetes. Params: job_image (str): The docker image to use. The Job container will be launched with this image. dagster_home (str): The location of DAGSTER_HOME in the Job container; this is where the ``dagster.yaml`` file will be mounted from the instance ConfigMap specified here. image_pull_policy (Optional[str]): Allows the image pull policy to be overridden, e.g. to facilitate local testing with `kind `_. Default: ``"Always"``. See: https://kubernetes.io/docs/concepts/containers/images/#updating-images. image_pull_secrets (Optional[List[Dict[str, str]]]): Optionally, a list of dicts, each of which corresponds to a Kubernetes ``LocalObjectReference`` (e.g., ``{'name': 'myRegistryName'}``). This allows you to specify the ```imagePullSecrets`` on a pod basis. Typically, these will be provided through the service account, when needed, and you will not need to pass this argument. See: https://kubernetes.io/docs/concepts/containers/images/#specifying-imagepullsecrets-on-a-pod and https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.17/#podspec-v1-core service_account_name (Optional[str]): The name of the Kubernetes service account under which to run the Job. Defaults to "default" instance_config_map (str): The ``name`` of an existing Volume to mount into the pod in order to provide a ConfigMap for the Dagster instance. This Volume should contain a ``dagster.yaml`` with appropriate values for run storage, event log storage, etc. postgres_password_secret (str): The name of the Kubernetes Secret where the postgres password can be retrieved. Will be mounted and supplied as an environment variable to the Job Pod. env_config_maps (Optional[List[str]]): A list of custom ConfigMapEnvSource names from which to draw environment variables (using ``envFrom``) for the Job. Default: ``[]``. See: https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/#define-an-environment-variable-for-a-container env_secrets (Optional[List[str]]): A list of custom Secret names from which to draw environment variables (using ``envFrom``) for the Job. Default: ``[]``. 
See: https://kubernetes.io/docs/tasks/inject-data-application/distribute-credentials-secure/#configure-all-key-value-pairs-in-a-secret-as-container-environment-variables """ def __new__( cls, job_image=None, dagster_home=None, image_pull_policy=None, image_pull_secrets=None, service_account_name=None, instance_config_map=None, postgres_password_secret=None, env_config_maps=None, env_secrets=None, ): return super(DagsterK8sJobConfig, cls).__new__( cls, job_image=check.opt_str_param(job_image, "job_image"), dagster_home=check.opt_str_param( dagster_home, "dagster_home", default=DAGSTER_HOME_DEFAULT ), image_pull_policy=check.opt_str_param(image_pull_policy, "image_pull_policy"), image_pull_secrets=check.opt_list_param( image_pull_secrets, "image_pull_secrets", of_type=dict ), service_account_name=check.opt_str_param(service_account_name, "service_account_name"), instance_config_map=check.str_param(instance_config_map, "instance_config_map"), postgres_password_secret=check.str_param( postgres_password_secret, "postgres_password_secret" ), env_config_maps=check.opt_list_param(env_config_maps, "env_config_maps", of_type=str), env_secrets=check.opt_list_param(env_secrets, "env_secrets", of_type=str), ) @classmethod def config_type(cls): """Combined config type which includes both run launcher and pipeline run config. """ cfg_run_launcher = DagsterK8sJobConfig.config_type_run_launcher() cfg_pipeline_run = DagsterK8sJobConfig.config_type_pipeline_run() return merge_dicts(cfg_run_launcher, cfg_pipeline_run) @classmethod def config_type_run_launcher(cls): """Configuration intended to be set on the Dagster instance. """ return { "instance_config_map": Field( StringSource, is_required=True, description="The ``name`` of an existing Volume to mount into the pod in order to " "provide a ConfigMap for the Dagster instance. This Volume should contain a " "``dagster.yaml`` with appropriate values for run storage, event log storage, etc.", ), "postgres_password_secret": Field( StringSource, is_required=True, description="The name of the Kubernetes Secret where the postgres password can be " "retrieved. Will be mounted and supplied as an environment variable to the Job Pod." 'Secret must contain the key ``"postgresql-password"`` which will be exposed in ' "the Job environment as the environment variable ``DAGSTER_PG_PASSWORD``.", ), "dagster_home": Field( StringSource, is_required=False, default_value=DAGSTER_HOME_DEFAULT, description="The location of DAGSTER_HOME in the Job container; this is where the " "``dagster.yaml`` file will be mounted from the instance ConfigMap specified here. " "Defaults to /opt/dagster/dagster_home.", ), } @classmethod def config_type_pipeline_run(cls): """Configuration intended to be set at pipeline execution time. """ return { "job_image": Field( Noneable(StringSource), is_required=False, description="Docker image to use for launched task Jobs. If the repository is not " "loaded from a GRPC server, then this field is required. If the repository is " "loaded from a GRPC server, then leave this field empty." '(Ex: "mycompany.com/dagster-k8s-image:latest").', ), "image_pull_policy": Field( StringSource, is_required=False, default_value="IfNotPresent", description="Image pull policy to set on the launched task Job Pods. 
Defaults to " '"IfNotPresent".', ), "image_pull_secrets": Field( Noneable(Array(Shape({"name": StringSource}))), is_required=False, description="(Advanced) Specifies that Kubernetes should get the credentials from " "the Secrets named in this list.", ), "service_account_name": Field( Noneable(StringSource), is_required=False, description="(Advanced) Override the name of the Kubernetes service account under " "which to run the Job.", ), "env_config_maps": Field( Noneable(Array(StringSource)), is_required=False, description="A list of custom ConfigMapEnvSource names from which to draw " "environment variables (using ``envFrom``) for the Job. Default: ``[]``. See:" "https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/#define-an-environment-variable-for-a-container", ), "env_secrets": Field( Noneable(Array(StringSource)), is_required=False, description="A list of custom Secret names from which to draw environment " "variables (using ``envFrom``) for the Job. Default: ``[]``. See:" "https://kubernetes.io/docs/tasks/inject-data-application/distribute-credentials-secure/#configure-all-key-value-pairs-in-a-secret-as-container-environment-variables", ), } @property def env_from_sources(self): """This constructs a list of env_from sources. Along with a default base environment config map which we always load, the ConfigMaps and Secrets specified via env_config_maps and env_secrets will be pulled into the job construction here. """ config_maps = [ kubernetes.client.V1EnvFromSource( config_map_ref=kubernetes.client.V1ConfigMapEnvSource(name=config_map) ) for config_map in self.env_config_maps ] secrets = [ kubernetes.client.V1EnvFromSource( secret_ref=kubernetes.client.V1SecretEnvSource(name=secret) ) for secret in self.env_secrets ] return config_maps + secrets def to_dict(self): return self._asdict() @staticmethod def from_dict(config=None): check.opt_dict_param(config, "config") return DagsterK8sJobConfig(**config) def construct_dagster_k8s_job( job_config, args, job_name, user_defined_k8s_config=None, pod_name=None, component=None, env_vars=None, ): """Constructs a Kubernetes Job object for a dagster-graphql invocation. Args: job_config (DagsterK8sJobConfig): Job configuration to use for constructing the Kubernetes Job object. args (List[str]): CLI arguments to use with dagster-graphql in this Job. job_name (str): The name of the Job. Note that this name must be <= 63 characters in length. resources (Dict[str, Dict[str, str]]): The resource requirements for the container pod_name (str, optional): The name of the Pod. Note that this name must be <= 63 characters in length. Defaults to "-pod". component (str, optional): The name of the component, used to provide the Job label app.kubernetes.io/component. Defaults to None. env_vars(Dict[str, str]): Additional environment variables to add to the K8s Container. Returns: kubernetes.client.V1Job: A Kubernetes Job object. 
""" check.inst_param(job_config, "job_config", DagsterK8sJobConfig) check.list_param(args, "args", of_type=str) check.str_param(job_name, "job_name") user_defined_k8s_config = check.opt_inst_param( user_defined_k8s_config, "user_defined_k8s_config", UserDefinedDagsterK8sConfig, UserDefinedDagsterK8sConfig(), ) pod_name = check.opt_str_param(pod_name, "pod_name", default=job_name + "-pod") check.opt_str_param(component, "component") check.opt_dict_param(env_vars, "env_vars", key_type=str, value_type=str) check.invariant( len(job_name) <= MAX_K8S_NAME_LEN, "job_name is %d in length; Kubernetes Jobs cannot be longer than %d characters." % (len(job_name), MAX_K8S_NAME_LEN), ) check.invariant( len(pod_name) <= MAX_K8S_NAME_LEN, "job_name is %d in length; Kubernetes Pods cannot be longer than %d characters." % (len(pod_name), MAX_K8S_NAME_LEN), ) # See: https://kubernetes.io/docs/concepts/overview/working-with-objects/common-labels/ dagster_labels = { "app.kubernetes.io/name": "dagster", "app.kubernetes.io/instance": "dagster", "app.kubernetes.io/version": dagster_version, "app.kubernetes.io/part-of": "dagster", } if component: dagster_labels["app.kubernetes.io/component"] = component additional_k8s_env_vars = [] if env_vars: for key, value in env_vars.items(): additional_k8s_env_vars.append(kubernetes.client.V1EnvVar(name=key, value=value)) job_container = kubernetes.client.V1Container( name=job_name, image=job_config.job_image, args=args, image_pull_policy=job_config.image_pull_policy, env=[ kubernetes.client.V1EnvVar(name="DAGSTER_HOME", value=job_config.dagster_home), kubernetes.client.V1EnvVar( name=DAGSTER_PG_PASSWORD_ENV_VAR, value_from=kubernetes.client.V1EnvVarSource( secret_key_ref=kubernetes.client.V1SecretKeySelector( name=job_config.postgres_password_secret, key=DAGSTER_PG_PASSWORD_SECRET_KEY ) ), ), ] + additional_k8s_env_vars, env_from=job_config.env_from_sources, volume_mounts=[ kubernetes.client.V1VolumeMount( name="dagster-instance", mount_path="{dagster_home}/dagster.yaml".format( dagster_home=job_config.dagster_home ), sub_path="dagster.yaml", ) ], **user_defined_k8s_config.container_config, ) config_map_volume = kubernetes.client.V1Volume( name="dagster-instance", config_map=kubernetes.client.V1ConfigMapVolumeSource(name=job_config.instance_config_map), ) # If the user has defined custom labels, remove them from the pod_template_spec_metadata # key and merge them with the dagster labels user_defined_pod_template_labels = user_defined_k8s_config.pod_template_spec_metadata.pop( "labels", {} ) template = kubernetes.client.V1PodTemplateSpec( metadata=kubernetes.client.V1ObjectMeta( name=pod_name, labels=merge_dicts(dagster_labels, user_defined_pod_template_labels), **user_defined_k8s_config.pod_template_spec_metadata, ), spec=kubernetes.client.V1PodSpec( image_pull_secrets=[ kubernetes.client.V1LocalObjectReference(name=x["name"]) for x in job_config.image_pull_secrets ], service_account_name=job_config.service_account_name, restart_policy="Never", containers=[job_container], volumes=[config_map_volume], **user_defined_k8s_config.pod_spec_config, ), ) job = kubernetes.client.V1Job( api_version="batch/v1", kind="Job", metadata=kubernetes.client.V1ObjectMeta( name=job_name, labels=dagster_labels, **user_defined_k8s_config.job_metadata ), spec=kubernetes.client.V1JobSpec( template=template, backoff_limit=K8S_JOB_BACKOFF_LIMIT, ttl_seconds_after_finished=K8S_JOB_TTL_SECONDS_AFTER_FINISHED, **user_defined_k8s_config.job_spec_config, ), **user_defined_k8s_config.job_config, ) return 
job def get_k8s_job_name(input_1, input_2=None): """Creates a unique (short!) identifier to name k8s objects based on run ID and step key(s). K8s Job names are limited to 63 characters, because they are used as labels. For more info, see: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/ """ check.str_param(input_1, "input_1") check.opt_str_param(input_2, "input_2") if not input_2: letters = string.ascii_lowercase input_2 = "".join(random.choice(letters) for i in range(20)) # Creates 32-bit signed int, so could be negative - name_hash = hashlib.md5(six.ensure_binary(input_1 + input_2)) + name_hash = hashlib.md5((input_1 + input_2).encode("utf-8")) return name_hash.hexdigest() diff --git a/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py b/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py index da2b0df53..a2a8e0960 100644 --- a/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py +++ b/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py @@ -1,154 +1,153 @@ import logging import time from contextlib import contextmanager from urllib.parse import quote_plus as urlquote import psycopg2 import psycopg2.errorcodes -import six import sqlalchemy from dagster import Field, IntSource, Selector, StringSource, check from dagster.core.storage.sql import get_alembic_config, handle_schema_errors class DagsterPostgresException(Exception): pass def get_conn(conn_string): conn = psycopg2.connect(conn_string) conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) return conn def pg_config(): return Selector( { "postgres_url": str, "postgres_db": { "username": StringSource, "password": StringSource, "hostname": StringSource, "db_name": StringSource, "port": Field(IntSource, is_required=False, default_value=5432), }, } ) def pg_url_from_config(config_value): if config_value.get("postgres_url"): return config_value["postgres_url"] return get_conn_string(**config_value["postgres_db"]) def get_conn_string(username, password, hostname, db_name, port="5432"): return "postgresql://{username}:{password}@{hostname}:{port}/{db_name}".format( username=username, password=urlquote(password), hostname=hostname, db_name=db_name, port=port, ) def retry_pg_creation_fn(fn, retry_limit=5, retry_wait=0.2): # Retry logic to recover from the case where two processes are creating # tables at the same time using sqlalchemy check.callable_param(fn, "fn") check.int_param(retry_limit, "retry_limit") check.numeric_param(retry_wait, "retry_wait") while True: try: return fn() except ( psycopg2.ProgrammingError, psycopg2.IntegrityError, sqlalchemy.exc.ProgrammingError, sqlalchemy.exc.IntegrityError, ) as exc: # Only programming error we want to retry on is the DuplicateTable error if ( isinstance(exc, sqlalchemy.exc.ProgrammingError) and exc.orig and exc.orig.pgcode != psycopg2.errorcodes.DUPLICATE_TABLE ) or ( isinstance(exc, psycopg2.ProgrammingError) and exc.pgcode != psycopg2.errorcodes.DUPLICATE_TABLE ): raise logging.warning("Retrying failed database creation") if retry_limit == 0: - six.raise_from(DagsterPostgresException("too many retries for DB creation"), exc) + raise DagsterPostgresException("too many retries for DB creation") from exc time.sleep(retry_wait) retry_limit -= 1 def retry_pg_connection_fn(fn, retry_limit=5, retry_wait=0.2): """Reusable retry logic for any psycopg2/sqlalchemy PG connection functions that may fail. Intended to be used anywhere we connect to PG, to gracefully handle transient connection issues. 
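    Example (a minimal sketch; the connection string is a placeholder):

        conn = retry_pg_connection_fn(
            lambda: psycopg2.connect("postgresql://user:password@localhost:5432/dagster"),
            retry_limit=3,
            retry_wait=0.5,
        )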
""" check.callable_param(fn, "fn") check.int_param(retry_limit, "retry_limit") check.numeric_param(retry_wait, "retry_wait") while True: try: return fn() except ( # See: https://www.psycopg.org/docs/errors.html # These are broad, we may want to list out specific exceptions to capture psycopg2.DatabaseError, psycopg2.OperationalError, sqlalchemy.exc.DatabaseError, sqlalchemy.exc.OperationalError, ) as exc: logging.warning("Retrying failed database connection") if retry_limit == 0: - six.raise_from(DagsterPostgresException("too many retries for DB connection"), exc) + raise DagsterPostgresException("too many retries for DB connection") from exc time.sleep(retry_wait) retry_limit -= 1 def wait_for_connection(conn_string, retry_limit=5, retry_wait=0.2): retry_pg_connection_fn( lambda: psycopg2.connect(conn_string), retry_limit=retry_limit, retry_wait=retry_wait ) return True @contextmanager def create_pg_connection(engine, dunder_file, storage_type_desc=None): check.inst_param(engine, "engine", sqlalchemy.engine.Engine) check.str_param(dunder_file, "dunder_file") check.opt_str_param(storage_type_desc, "storage_type_desc", "") if storage_type_desc: storage_type_desc += " " else: storage_type_desc = "" conn = None try: # Retry connection to gracefully handle transient connection issues conn = retry_pg_connection_fn(engine.connect) with handle_schema_errors( conn, get_alembic_config(dunder_file), msg="Postgres {}storage requires migration".format(storage_type_desc), ): yield conn finally: if conn: conn.close() def pg_statement_timeout(millis): check.int_param(millis, "millis") return "-c statement_timeout={}".format(millis) diff --git a/python_modules/libraries/dagster-shell/dagster_shell/utils.py b/python_modules/libraries/dagster-shell/dagster_shell/utils.py index 5544c9e4a..2ed3707d1 100644 --- a/python_modules/libraries/dagster-shell/dagster_shell/utils.py +++ b/python_modules/libraries/dagster-shell/dagster_shell/utils.py @@ -1,151 +1,150 @@ # # NOTE: This file is based on the bash operator from Apache Airflow, which can be found here: # https://github.com/apache/airflow/blob/master/airflow/operators/bash.py # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import os import signal from subprocess import PIPE, STDOUT, Popen -import six from dagster import check from dagster.utils import safe_tempfile_path def execute_script_file(shell_script_path, output_logging, log, cwd=None, env=None): """Execute a shell script file specified by the argument ``shell_command``. The script will be invoked via ``subprocess.Popen(['bash', shell_script_path], ...)``. In the Popen invocation, ``stdout=PIPE, stderr=STDOUT`` is used, and the combined stdout/stderr output is retrieved. Args: shell_command (str): The shell command to execute output_logging (str): The logging mode to use. 
Supports STREAM, BUFFER, and NONE. log (Union[logging.Logger, DagsterLogManager]): Any logger which responds to .info() cwd (str, optional): Working directory for the shell command to use. Defaults to the temporary path where we store the shell command in a script file. env (Dict[str, str], optional): Environment dictionary to pass to ``subprocess.Popen``. Unused by default. Raises: Exception: When an invalid output_logging is selected. Unreachable from solid-based invocation since the config system will check output_logging against the config enum. Returns: str: The combined stdout/stderr output of running the shell script. """ check.str_param(shell_script_path, "shell_script_path") check.str_param(output_logging, "output_logging") check.opt_str_param(cwd, "cwd", default=os.path.dirname(shell_script_path)) env = check.opt_dict_param(env, "env") def pre_exec(): # Restore default signal disposition and invoke setsid for sig in ("SIGPIPE", "SIGXFZ", "SIGXFSZ"): if hasattr(signal, sig): signal.signal(getattr(signal, sig), signal.SIG_DFL) os.setsid() with open(shell_script_path, "rb") as f: - shell_command = six.ensure_str(f.read()) + shell_command = f.read().decode("utf-8") log.info("Running command:\n{command}".format(command=shell_command)) # pylint: disable=subprocess-popen-preexec-fn sub_process = Popen( ["bash", shell_script_path], stdout=PIPE, stderr=STDOUT, cwd=cwd, env=env, preexec_fn=pre_exec, ) # Will return the string result of reading stdout of the shell command output = "" if output_logging not in ["STREAM", "BUFFER", "NONE"]: raise Exception("Unrecognized output_logging %s" % output_logging) # Stream back logs as they are emitted if output_logging == "STREAM": for raw_line in iter(sub_process.stdout.readline, b""): - line = six.ensure_str(raw_line) + line = raw_line.decode("utf-8") log.info(line.rstrip()) output += line sub_process.wait() # Collect and buffer all logs, then emit if output_logging == "BUFFER": output = "".join( - [six.ensure_str(raw_line) for raw_line in iter(sub_process.stdout.readline, b"")] + [raw_line.decode("utf-8") for raw_line in iter(sub_process.stdout.readline, b"")] ) log.info(output) # no logging in this case elif output_logging == "NONE": pass log.info("Command exited with return code {retcode}".format(retcode=sub_process.returncode)) return output, sub_process.returncode def execute(shell_command, output_logging, log, cwd=None, env=None): """Execute a shell script specified by the argument ``shell_command``. The script will be written to a temporary file first and invoked via ``subprocess.Popen(['bash', shell_script_path], ...)``. In the Popen invocation, ``stdout=PIPE, stderr=STDOUT`` is used, and the combined stdout/stderr output is retrieved. Args: shell_command (str): The shell command to execute output_logging (str): The logging mode to use. Supports STREAM, BUFFER, and NONE. log (Union[logging.Logger, DagsterLogManager]): Any logger which responds to .info() cwd (str, optional): Working directory for the shell command to use. Defaults to the temporary path where we store the shell command in a script file. env (Dict[str, str], optional): Environment dictionary to pass to ``subprocess.Popen``. Unused by default. Returns: str: The combined stdout/stderr output of running the shell command. 
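    Example (a minimal sketch using a stdlib logger; a solid would normally pass its
    ``DagsterLogManager`` instead):

        import logging

        output, return_code = execute(
            shell_command="echo 'hello shell'",
            output_logging="BUFFER",
            log=logging.getLogger("dagster-shell"),
        )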
""" check.str_param(shell_command, "shell_command") # other args checked in execute_file with safe_tempfile_path() as tmp_file_path: tmp_path = os.path.dirname(tmp_file_path) log.info("Using temporary directory: %s" % tmp_path) with open(tmp_file_path, "wb") as tmp_file: - tmp_file.write(six.ensure_binary(shell_command)) + tmp_file.write(shell_command.encode("utf-8")) tmp_file.flush() script_location = os.path.abspath(tmp_file.name) log.info("Temporary script location: {location}".format(location=script_location)) return execute_script_file( shell_script_path=tmp_file.name, output_logging=output_logging, log=log, cwd=(cwd or tmp_path), env=env, ) diff --git a/python_modules/libraries/dagster-shell/dagster_shell_tests/conftest.py b/python_modules/libraries/dagster-shell/dagster_shell_tests/conftest.py index 8bd272a52..54eb1e1b0 100644 --- a/python_modules/libraries/dagster-shell/dagster_shell_tests/conftest.py +++ b/python_modules/libraries/dagster-shell/dagster_shell_tests/conftest.py @@ -1,17 +1,16 @@ import contextlib from tempfile import NamedTemporaryFile import pytest -import six @pytest.fixture(scope="function") def tmp_file(tmpdir): @contextlib.contextmanager def _tmp_file_cm(file_contents): with NamedTemporaryFile(dir=str(tmpdir)) as f: - f.write(six.ensure_binary(file_contents)) + f.write(file_contents.encode("utf-8")) f.flush() yield str(tmpdir), f.name return _tmp_file_cm diff --git a/python_modules/libraries/dagster-ssh/dagster_ssh/resources.py b/python_modules/libraries/dagster-ssh/dagster_ssh/resources.py index 92f0467ef..f63decf8e 100644 --- a/python_modules/libraries/dagster-ssh/dagster_ssh/resources.py +++ b/python_modules/libraries/dagster-ssh/dagster_ssh/resources.py @@ -1,256 +1,256 @@ import getpass import os +from io import StringIO import paramiko from dagster import Field, StringSource, check, resource from dagster.utils import merge_dicts, mkdir_p from paramiko.config import SSH_PORT -from six import StringIO from sshtunnel import SSHTunnelForwarder def key_from_str(key_str): """Creates a paramiko SSH key from a string.""" check.str_param(key_str, "key_str") # py2 StringIO doesn't support with key_file = StringIO(key_str) result = paramiko.RSAKey.from_private_key(key_file) key_file.close() return result class SSHResource: """ Resource for ssh remote execution using Paramiko. 
ref: https://github.com/paramiko/paramiko """ def __init__( self, remote_host, remote_port, username=None, password=None, key_file=None, key_string=None, timeout=10, keepalive_interval=30, compress=True, no_host_key_check=True, allow_host_key_change=False, logger=None, ): self.remote_host = check.str_param(remote_host, "remote_host") self.remote_port = check.opt_int_param(remote_port, "remote_port") self.username = check.opt_str_param(username, "username") self.password = check.opt_str_param(password, "password") self.key_file = check.opt_str_param(key_file, "key_file") self.timeout = check.opt_int_param(timeout, "timeout") self.keepalive_interval = check.opt_int_param(keepalive_interval, "keepalive_interval") self.compress = check.opt_bool_param(compress, "compress") self.no_host_key_check = check.opt_bool_param(no_host_key_check, "no_host_key_check") self.allow_host_key_change = check.opt_bool_param( allow_host_key_change, "allow_host_key_change" ) self.log = logger self.host_proxy = None # Create RSAKey object from private key string self.key_obj = key_from_str(key_string) if key_string is not None else None # Auto detecting username values from system if not self.username: logger.debug( "username to ssh to host: %s is not specified. Using system's default provided by" " getpass.getuser()" % self.remote_host ) self.username = getpass.getuser() user_ssh_config_filename = os.path.expanduser("~/.ssh/config") if os.path.isfile(user_ssh_config_filename): ssh_conf = paramiko.SSHConfig() ssh_conf.parse(open(user_ssh_config_filename)) host_info = ssh_conf.lookup(self.remote_host) if host_info and host_info.get("proxycommand"): self.host_proxy = paramiko.ProxyCommand(host_info.get("proxycommand")) if not (self.password or self.key_file): if host_info and host_info.get("identityfile"): self.key_file = host_info.get("identityfile")[0] def get_connection(self): """ Opens a SSH connection to the remote host. :rtype: paramiko.client.SSHClient """ client = paramiko.SSHClient() if not self.allow_host_key_change: self.log.warning( "Remote Identification Change is not verified. This won't protect against " "Man-In-The-Middle attacks" ) client.load_system_host_keys() if self.no_host_key_check: self.log.warning( "No Host Key Verification. 
This won't protect against Man-In-The-Middle attacks" ) # Default is RejectPolicy client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if self.password and self.password.strip(): client.connect( hostname=self.remote_host, username=self.username, password=self.password, key_filename=self.key_file, pkey=self.key_obj, timeout=self.timeout, compress=self.compress, port=self.remote_port, sock=self.host_proxy, look_for_keys=False, ) else: client.connect( hostname=self.remote_host, username=self.username, key_filename=self.key_file, pkey=self.key_obj, timeout=self.timeout, compress=self.compress, port=self.remote_port, sock=self.host_proxy, ) if self.keepalive_interval: client.get_transport().set_keepalive(self.keepalive_interval) return client def get_tunnel(self, remote_port, remote_host="localhost", local_port=None): check.int_param(remote_port, "remote_port") check.str_param(remote_host, "remote_host") check.opt_int_param(local_port, "local_port") if local_port is not None: local_bind_address = ("localhost", local_port) else: local_bind_address = ("localhost",) # Will prefer key string if specified, otherwise use the key file pkey = self.key_obj if self.key_obj else self.key_file if self.password and self.password.strip(): client = SSHTunnelForwarder( self.remote_host, ssh_port=self.remote_port, ssh_username=self.username, ssh_password=self.password, ssh_pkey=pkey, ssh_proxy=self.host_proxy, local_bind_address=local_bind_address, remote_bind_address=(remote_host, remote_port), logger=self.log, ) else: client = SSHTunnelForwarder( self.remote_host, ssh_port=self.remote_port, ssh_username=self.username, ssh_pkey=pkey, ssh_proxy=self.host_proxy, local_bind_address=local_bind_address, remote_bind_address=(remote_host, remote_port), host_pkey_directories=[], logger=self.log, ) return client def sftp_get(self, remote_filepath, local_filepath): check.str_param(remote_filepath, "remote_filepath") check.str_param(local_filepath, "local_filepath") conn = self.get_connection() with conn.open_sftp() as sftp_client: local_folder = os.path.dirname(local_filepath) # Create intermediate directories if they don't exist mkdir_p(local_folder) self.log.info( "Starting to transfer from {0} to {1}".format(remote_filepath, local_filepath) ) sftp_client.get(remote_filepath, local_filepath) conn.close() return local_filepath def sftp_put(self, remote_filepath, local_filepath, confirm=True): check.str_param(remote_filepath, "remote_filepath") check.str_param(local_filepath, "local_filepath") conn = self.get_connection() with conn.open_sftp() as sftp_client: self.log.info( "Starting to transfer file from {0} to {1}".format(local_filepath, remote_filepath) ) sftp_client.put(local_filepath, remote_filepath, confirm=confirm) conn.close() return local_filepath @resource( { "remote_host": Field( StringSource, description="remote host to connect to", is_required=True ), "remote_port": Field( int, description="port of remote host to connect (Default is paramiko SSH_PORT)", is_required=False, default_value=SSH_PORT, ), "username": Field( StringSource, description="username to connect to the remote_host", is_required=False ), "password": Field( StringSource, description="password of the username to connect to the remote_host", is_required=False, ), "key_file": Field( StringSource, description="key file to use to connect to the remote_host.", is_required=False, ), "key_string": Field( StringSource, description="key string to use to connect to remote_host", is_required=False, ), "timeout": Field( int, description="timeout 
for the attempt to connect to the remote_host.", is_required=False, default_value=10, ), "keepalive_interval": Field( int, description="send a keepalive packet to remote host every keepalive_interval seconds", is_required=False, default_value=30, ), "compress": Field(bool, is_required=False, default_value=True), "no_host_key_check": Field(bool, is_required=False, default_value=True), "allow_host_key_change": Field(bool, is_required=False, default_value=False), } ) def ssh_resource(init_context): args = init_context.resource_config args = merge_dicts(init_context.resource_config, {"logger": init_context.log_manager}) return SSHResource(**args) diff --git a/python_modules/libraries/dagster-ssh/dagster_ssh_tests/test_resources.py b/python_modules/libraries/dagster-ssh/dagster_ssh_tests/test_resources.py index 4fce259e8..08261d5f3 100644 --- a/python_modules/libraries/dagster-ssh/dagster_ssh_tests/test_resources.py +++ b/python_modules/libraries/dagster-ssh/dagster_ssh_tests/test_resources.py @@ -1,257 +1,254 @@ import logging import os -import six from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from dagster import Field, ModeDefinition, execute_solid, solid from dagster.seven import get_system_temp_directory, mock from dagster_ssh.resources import SSHResource, key_from_str from dagster_ssh.resources import ssh_resource as sshresource def generate_ssh_key(): # generate private/public key pair key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048) # get private key in PEM container format - return six.ensure_str( - key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption(), - ) - ) + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") @mock.patch("paramiko.SSHClient") def test_ssh_connection_with_password(ssh_mock): ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password="password", key_file="fake.file", timeout=10, keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_connection(): ssh_mock.return_value.connect.assert_called_once_with( hostname="remote_host", username="username", password="password", pkey=None, key_filename="fake.file", timeout=10, compress=True, port=12345, sock=None, look_for_keys=False, ) @mock.patch("paramiko.SSHClient") def test_ssh_connection_without_password(ssh_mock): ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password=None, timeout=10, key_file="fake.file", keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_connection(): ssh_mock.return_value.connect.assert_called_once_with( hostname="remote_host", username="username", pkey=None, key_filename="fake.file", timeout=10, compress=True, port=12345, sock=None, ) @mock.patch("paramiko.SSHClient") def test_ssh_connection_with_key_string(ssh_mock): ssh_key = generate_ssh_key() ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password=None, timeout=10, - 
key_string=six.ensure_str(ssh_key), + key_string=ssh_key, keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_connection(): ssh_mock.return_value.connect.assert_called_once_with( hostname="remote_host", username="username", key_filename=None, pkey=key_from_str(ssh_key), timeout=10, compress=True, port=12345, sock=None, ) @mock.patch("dagster_ssh.resources.SSHTunnelForwarder") def test_tunnel_with_password(ssh_mock): ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password="password", timeout=10, key_file="fake.file", keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_tunnel(1234): ssh_mock.assert_called_once_with( "remote_host", ssh_port=12345, ssh_username="username", ssh_password="password", ssh_pkey="fake.file", ssh_proxy=None, local_bind_address=("localhost",), remote_bind_address=("localhost", 1234), logger=ssh_resource.log, ) @mock.patch("dagster_ssh.resources.SSHTunnelForwarder") def test_tunnel_without_password(ssh_mock): ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password=None, timeout=10, key_file="fake.file", keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_tunnel(1234): ssh_mock.assert_called_once_with( "remote_host", ssh_port=12345, ssh_username="username", ssh_pkey="fake.file", ssh_proxy=None, local_bind_address=("localhost",), remote_bind_address=("localhost", 1234), host_pkey_directories=[], logger=ssh_resource.log, ) @mock.patch("dagster_ssh.resources.SSHTunnelForwarder") def test_tunnel_with_string_key(ssh_mock): ssh_key = generate_ssh_key() ssh_resource = SSHResource( remote_host="remote_host", remote_port=12345, username="username", password=None, timeout=10, - key_string=six.ensure_str(ssh_key), + key_string=ssh_key, keepalive_interval=30, compress=True, no_host_key_check=False, allow_host_key_change=False, logger=logging.root.getChild("test_resources"), ) with ssh_resource.get_tunnel(1234): ssh_mock.assert_called_once_with( "remote_host", ssh_port=12345, ssh_username="username", ssh_pkey=key_from_str(ssh_key), ssh_proxy=None, local_bind_address=("localhost",), remote_bind_address=("localhost", 1234), host_pkey_directories=[], logger=ssh_resource.log, ) def test_ssh_sftp(sftpserver): tmp_path = get_system_temp_directory() readme_file = os.path.join(tmp_path, "readme.txt") @solid( config_schema={ "local_filepath": Field(str, is_required=True, description="local file path to get"), "remote_filepath": Field(str, is_required=True, description="remote file path to get"), }, required_resource_keys={"ssh_resource"}, ) def sftp_solid_get(context): local_filepath = context.solid_config.get("local_filepath") remote_filepath = context.solid_config.get("remote_filepath") return context.resources.ssh_resource.sftp_get(remote_filepath, local_filepath) with sftpserver.serve_content({"a_dir": {"readme.txt": "hello, world"}}): result = execute_solid( sftp_solid_get, ModeDefinition(resource_defs={"ssh_resource": sshresource}), run_config={ "solids": { "sftp_solid_get": { "config": { "local_filepath": readme_file, "remote_filepath": "a_dir/readme.txt", } } }, "resources": { "ssh_resource": { "config": { "remote_host": sftpserver.host, "remote_port": 
sftpserver.port, "username": "user", "password": "pw", "no_host_key_check": True, } } }, }, ) assert result.success with open(readme_file, "rb") as f: contents = f.read() assert b"hello, world" in contents diff --git a/python_modules/libraries/dagstermill/dagstermill/manager.py b/python_modules/libraries/dagstermill/dagstermill/manager.py index 2818885cd..594dc20a2 100644 --- a/python_modules/libraries/dagstermill/dagstermill/manager.py +++ b/python_modules/libraries/dagstermill/dagstermill/manager.py @@ -1,319 +1,315 @@ import os import pickle import uuid -import six from dagster import ( AssetMaterialization, ExpectationResult, Failure, Materialization, ModeDefinition, PipelineDefinition, SolidDefinition, TypeCheck, check, seven, ) from dagster.core.definitions.dependency import SolidHandle from dagster.core.definitions.reconstructable import ReconstructablePipeline from dagster.core.definitions.resource import ScopedResourcesBuilder from dagster.core.execution.api import create_execution_plan, scoped_pipeline_context from dagster.core.execution.resources_init import ( get_required_resource_keys_to_init, resource_initialization_event_generator, ) from dagster.core.instance import DagsterInstance from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus from dagster.core.utils import make_new_run_id from dagster.loggers import colored_console_logger from dagster.serdes import unpack_value from dagster.utils import EventGenerationManager from .context import DagstermillExecutionContext, DagstermillRuntimeExecutionContext from .errors import DagstermillError from .serialize import PICKLE_PROTOCOL, read_value, write_value class DagstermillResourceEventGenerationManager(EventGenerationManager): """ Utility class to explicitly manage setup/teardown of resource events. Overrides the default `generate_teardown_events` method so that teardown is deferred until explicitly called by the dagstermill Manager """ def generate_teardown_events(self): return iter(()) def teardown(self): return [ teardown_event for teardown_event in super( DagstermillResourceEventGenerationManager, self ).generate_teardown_events() ] class Manager: def __init__(self): self.pipeline = None self.solid_def = None self.in_pipeline = False self.marshal_dir = None self.context = None self.resource_manager = None def _setup_resources( self, execution_plan, environment_config, pipeline_run, log_manager, resource_keys_to_init, instance, resource_instances_to_override, ): """ Drop-in replacement for `dagster.core.execution.resources_init.resource_initialization_manager`. It uses a `DagstermillResourceEventGenerationManager` and explicitly calls `teardown` on it """ generator = resource_initialization_event_generator( execution_plan, environment_config, pipeline_run, log_manager, resource_keys_to_init, instance, resource_instances_to_override, ) self.resource_manager = DagstermillResourceEventGenerationManager( generator, ScopedResourcesBuilder ) return self.resource_manager def reconstitute_pipeline_context( self, output_log_path=None, marshal_dir=None, run_config=None, executable_dict=None, pipeline_run_dict=None, solid_handle_kwargs=None, instance_ref_dict=None, ): """Reconstitutes a context for dagstermill-managed execution. You'll see this function called to reconstruct a pipeline context within the ``injected parameters`` cell of a dagstermill output notebook. Users should not call this function interactively except when debugging output notebooks. 
Use :func:`dagstermill.get_context` in the ``parameters`` cell of your notebook to define a context for interactive exploration and development. This call will be replaced by one to :func:`dagstermill.reconstitute_pipeline_context` when the notebook is executed by dagstermill. """ check.opt_str_param(output_log_path, "output_log_path") check.opt_str_param(marshal_dir, "marshal_dir") run_config = check.opt_dict_param(run_config, "run_config", key_type=str) check.dict_param(pipeline_run_dict, "pipeline_run_dict") check.dict_param(executable_dict, "executable_dict") check.dict_param(solid_handle_kwargs, "solid_handle_kwargs") check.dict_param(instance_ref_dict, "instance_ref_dict") pipeline = ReconstructablePipeline.from_dict(executable_dict) pipeline_def = pipeline.get_definition() try: instance_ref = unpack_value(instance_ref_dict) instance = DagsterInstance.from_ref(instance_ref) except Exception as err: # pylint: disable=broad-except - six.raise_from( - DagstermillError( - "Error when attempting to resolve DagsterInstance from serialized InstanceRef" - ), - err, - ) + raise DagstermillError( + "Error when attempting to resolve DagsterInstance from serialized InstanceRef" + ) from err pipeline_run = unpack_value(pipeline_run_dict) solid_handle = SolidHandle.from_dict(solid_handle_kwargs) solid_def = pipeline_def.get_solid(solid_handle).definition self.marshal_dir = marshal_dir self.in_pipeline = True self.solid_def = solid_def self.pipeline = pipeline execution_plan = create_execution_plan( self.pipeline, run_config, mode=pipeline_run.mode, step_keys_to_execute=pipeline_run.step_keys_to_execute, ) with scoped_pipeline_context( execution_plan, run_config, pipeline_run, instance, scoped_resources_builder_cm=self._setup_resources, # Set this flag even though we're not in test for clearer error reporting raise_on_error=True, ) as pipeline_context: self.context = DagstermillRuntimeExecutionContext( pipeline_context=pipeline_context, solid_config=run_config.get("solids", {}).get(solid_def.name, {}).get("config"), resource_keys_to_init=get_required_resource_keys_to_init( execution_plan, pipeline_context.intermediate_storage_def, ), solid_name=solid_def.name, ) return self.context def get_context(self, solid_config=None, mode_def=None, run_config=None): """Get a dagstermill execution context for interactive exploration and development. Args: solid_config (Optional[Any]): If specified, this value will be made available on the context as its ``solid_config`` property. mode_def (Optional[:class:`dagster.ModeDefinition`]): If specified, defines the mode to use to construct the context. Specify this if you would like a context constructed with specific ``resource_defs`` or ``logger_defs``. By default, an ephemeral mode with a console logger will be constructed. run_config(Optional[dict]): The environment config dict with which to construct the context. Returns: :py:class:`~dagstermill.DagstermillExecutionContext` """ check.opt_inst_param(mode_def, "mode_def", ModeDefinition) run_config = check.opt_dict_param(run_config, "run_config", key_type=str) # If we are running non-interactively, and there is already a context reconstituted, return # that context rather than overwriting it. 
if self.context is not None and isinstance( self.context, DagstermillRuntimeExecutionContext ): return self.context if not mode_def: mode_def = ModeDefinition(logger_defs={"dagstermill": colored_console_logger}) run_config["loggers"] = {"dagstermill": {}} solid_def = SolidDefinition( name="this_solid", input_defs=[], compute_fn=lambda *args, **kwargs: None, output_defs=[], description="Ephemeral solid constructed by dagstermill.get_context()", required_resource_keys=mode_def.resource_key_set, ) pipeline_def = PipelineDefinition( [solid_def], mode_defs=[mode_def], name="ephemeral_dagstermill_pipeline" ) run_id = make_new_run_id() # construct stubbed PipelineRun for notebook exploration... # The actual pipeline run during pipeline execution will be serialized and reconstituted # in the `reconstitute_pipeline_context` call pipeline_run = PipelineRun( pipeline_name=pipeline_def.name, run_id=run_id, run_config=run_config, mode=mode_def.name, step_keys_to_execute=None, status=PipelineRunStatus.NOT_STARTED, tags=None, ) self.in_pipeline = False self.solid_def = solid_def self.pipeline = pipeline_def execution_plan = create_execution_plan(self.pipeline, run_config, mode=mode_def.name) with scoped_pipeline_context( execution_plan, run_config, pipeline_run, DagsterInstance.ephemeral(), scoped_resources_builder_cm=self._setup_resources, ) as pipeline_context: self.context = DagstermillExecutionContext( pipeline_context=pipeline_context, solid_config=solid_config, resource_keys_to_init=get_required_resource_keys_to_init( execution_plan, pipeline_context.intermediate_storage_def, ), solid_name=solid_def.name, ) return self.context def yield_result(self, value, output_name="result"): """Yield a result directly from notebook code. When called interactively or in development, returns its input. Args: value (Any): The value to yield. output_name (Optional[str]): The name of the result to yield (default: ``'result'``). """ if not self.in_pipeline: return value # deferred import for perf import scrapbook if not self.solid_def.has_output(output_name): raise DagstermillError( f"Solid {self.solid_def.name} does not have output named {output_name}." f"Expected one of {[str(output_def.name) for output_def in self.solid_def.output_defs]}" ) dagster_type = self.solid_def.output_def_named(output_name).dagster_type out_file = os.path.join(self.marshal_dir, f"output-{output_name}") scrapbook.glue(output_name, write_value(dagster_type, value, out_file)) def yield_event(self, dagster_event): """Yield a dagster event directly from notebook code. When called interactively or in development, returns its input. Args: dagster_event (Union[:class:`dagster.Materialization`, :class:`dagster.ExpectationResult`, :class:`dagster.TypeCheck`, :class:`dagster.Failure`]): An event to yield back to Dagster. 
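        Example (a minimal sketch, assuming the notebook calls the top-level
        ``dagstermill.yield_event`` helper, which delegates to this method):

            import dagstermill
            from dagster import AssetMaterialization

            dagstermill.yield_event(
                AssetMaterialization(asset_key="my_dataset", description="Wrote my_dataset")
            )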
""" check.inst_param( dagster_event, "dagster_event", (AssetMaterialization, Materialization, ExpectationResult, TypeCheck, Failure), ) if not self.in_pipeline: return dagster_event # deferred import for perf import scrapbook event_id = "event-{event_uuid}".format(event_uuid=str(uuid.uuid4())) out_file_path = os.path.join(self.marshal_dir, event_id) with open(out_file_path, "wb") as fd: fd.write(pickle.dumps(dagster_event, PICKLE_PROTOCOL)) scrapbook.glue(event_id, out_file_path) def teardown_resources(self): if self.resource_manager is not None: self.resource_manager.teardown() def load_parameter(self, input_name, input_value): input_def = self.solid_def.input_def_named(input_name) return read_value(input_def.dagster_type, seven.json.loads(input_value)) MANAGER_FOR_NOTEBOOK_INSTANCE = Manager() diff --git a/python_modules/libraries/dagstermill/setup.py b/python_modules/libraries/dagstermill/setup.py index ada2252e4..b79eaec76 100644 --- a/python_modules/libraries/dagstermill/setup.py +++ b/python_modules/libraries/dagstermill/setup.py @@ -1,38 +1,37 @@ from setuptools import find_packages, setup def get_version(): version = {} with open("dagstermill/version.py") as fp: exec(fp.read(), version) # pylint: disable=W0122 return version["__version__"] if __name__ == "__main__": setup( name="dagstermill", version=get_version(), author="Elementl", author_email="hello@elementl.com", license="Apache-2.0", packages=find_packages(exclude=["dagstermill_tests"]), classifiers=[ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], install_requires=[ "dagster", # ipykernel pinned until https://github.com/nteract/papermill/issues/519 is resolved. # See https://github.com/dagster-io/dagster/issues/3401 "ipykernel>=4.9.0,<=5.3.4", "nbconvert>=5.4.0,<6.0.0", "nteract-scrapbook>=0.2.0", "papermill>=1.0.0,<2.0.0", - "six", ], entry_points={"console_scripts": ["dagstermill = dagstermill.cli:main"]}, ) diff --git a/scripts/install_dev_python_modules.py b/scripts/install_dev_python_modules.py index 995ec77a1..10ef26bd3 100644 --- a/scripts/install_dev_python_modules.py +++ b/scripts/install_dev_python_modules.py @@ -1,133 +1,133 @@ import os import subprocess import sys def is_39(): return sys.version_info >= (3, 9) def main(quiet): """ Python 3.9 Notes ================ Especially on macOS, there are still many missing wheels for Python 3.9, which means that some dependencies may have to be built from source. You may find yourself needing to install system packages such as freetype, gfortran, etc.; on macOS, Homebrew should suffice. Tensorflow is still not available for 3.9 (2020-12-10), so we have put conditional logic in place around examples, etc., that make use of it. https://github.com/tensorflow/tensorflow/issues/44485 Pyarrow is still not available for 3.9 (2020-12-10). https://github.com/apache/arrow/pull/8386 As a consequence of pyarrow, the snowflake connector also is not yet avaialble for 3.9 (2020-12-10). https://github.com/snowflakedb/snowflake-connector-python/issues/562 """ # Previously, we did a pip install --upgrade pip here. We have removed that and instead # depend on the user to ensure an up-to-date pip is installed and available. For context, there # is a lengthy discussion here: https://github.com/pypa/pip/issues/5599 # On machines with less memory, pyspark install will fail... 
see: # https://stackoverflow.com/a/31526029/11295366 cmd = ["pip", "--no-cache-dir", "install", "pyspark>=3.0.1"] if quiet: cmd.append(quiet) p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) print(" ".join(cmd)) # pylint: disable=print-call while True: output = p.stdout.readline() if p.poll() is not None: break if output: - print(output.decode().strip()) # pylint: disable=print-call + print(output.decode("utf-8").strip()) # pylint: disable=print-call install_targets = [] # Need to do this for 3.9 compat # This is to ensure we can build Pandas on 3.9 # See: https://github.com/numpy/numpy/issues/17784, if is_39(): install_targets += ["Cython==0.29.21", "numpy==1.18.5"] install_targets += [ "awscli", "-e python_modules/dagster", "-e python_modules/dagster-graphql", "-e python_modules/dagster-test", "-e python_modules/dagit", "-e python_modules/automation", "-e python_modules/libraries/dagster-pandas", "-e python_modules/libraries/dagster-aws", "-e python_modules/libraries/dagster-celery", "-e python_modules/libraries/dagster-celery-docker", "-e python_modules/libraries/dagster-cron", '-e "python_modules/libraries/dagster-dask[yarn,pbs,kube]"', "-e python_modules/libraries/dagster-datadog", "-e python_modules/libraries/dagster-dbt", "-e python_modules/libraries/dagster-docker", "-e python_modules/libraries/dagster-gcp", "-e python_modules/libraries/dagster-k8s", "-e python_modules/libraries/dagster-celery-k8s", "-e python_modules/libraries/dagster-pagerduty", "-e python_modules/libraries/dagster-papertrail", "-e python_modules/libraries/dagster-postgres", "-e python_modules/libraries/dagster-prometheus", "-e python_modules/libraries/dagster-spark", "-e python_modules/libraries/dagster-pyspark", "-e python_modules/libraries/dagster-databricks", "-e python_modules/libraries/dagster-shell", "-e python_modules/libraries/dagster-slack", "-e python_modules/libraries/dagster-ssh", "-e python_modules/libraries/dagster-twilio", "-e python_modules/libraries/lakehouse", "-r python_modules/dagster/dev-requirements.txt", "-r python_modules/libraries/dagster-aws/dev-requirements.txt", "-e integration_tests/python_modules/dagster-k8s-test-infra", "-r scala_modules/scripts/requirements.txt", # # https://github.com/dagster-io/dagster/issues/3488 # "-e python_modules/libraries/dagster-airflow", # # https://github.com/dagster-io/dagster/pull/2483#issuecomment-635174157 # Uncomment only when snowflake-connector-python can be installed with optional (or # compatible) Azure dependencies. # "-e python_modules/libraries/dagster-azure", ] # dagster-ge depends on a great_expectations version that does not install on Windows # https://github.com/dagster-io/dagster/issues/3319 if not os.name == "nt": install_targets += ["-e python_modules/libraries/dagster-ge"] if not is_39(): install_targets += [ "-e python_modules/libraries/dagster-snowflake", "-e python_modules/libraries/dagstermill", '-e "examples/legacy_examples[full]"', '-e "examples/airline_demo[full]"', "-r docs-requirements.txt", ] # NOTE: These need to be installed as one long pip install command, otherwise pip will install # conflicting dependencies, which will break pip freeze snapshot creation during the integration # image build! 
cmd = ["pip", "install"] + install_targets if quiet: cmd.append(quiet) p = subprocess.Popen( " ".join(cmd), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True ) print(" ".join(cmd)) # pylint: disable=print-call while True: output = p.stdout.readline() if p.poll() is not None: break if output: - print(output.decode().strip()) # pylint: disable=print-call + print(output.decode("utf-8").strip()) # pylint: disable=print-call if __name__ == "__main__": main(quiet=sys.argv[1] if len(sys.argv) > 1 else "")