diff --git a/docs-requirements.txt b/docs-requirements.txt
index ae87784d1..7ace5a78e 100644
--- a/docs-requirements.txt
+++ b/docs-requirements.txt
@@ -1,18 +1,18 @@
-Sphinx==2.2.2; python_version >= '3.6'
+Sphinx==2.2.2
sphinx-autobuild==0.7.1
-e ./python_modules/dagster
-e ./python_modules/dagster-graphql
-e ./python_modules/dagit
-e ./python_modules/libraries/dagstermill
-e ./python_modules/libraries/dagster-airflow
-e ./python_modules/libraries/dagster-aws
-e ./python_modules/libraries/dagster-celery
-e ./python_modules/libraries/dagster-celery-docker
-e ./python_modules/libraries/dagster-cron
-e ./python_modules/libraries/dagster-dask
-e ./python_modules/libraries/dagster-docker
-e ./python_modules/libraries/dagster-gcp
-e ./python_modules/libraries/dagster-k8s
-r ./python_modules/dagster/dev-requirements.txt
sphinx-click==2.3.1
recommonmark==0.4.0
diff --git a/examples/airline_demo/airline_demo_tests/test_types.py b/examples/airline_demo/airline_demo_tests/test_types.py
index 0eb4fe5c1..97048fe0d 100644
--- a/examples/airline_demo/airline_demo_tests/test_types.py
+++ b/examples/airline_demo/airline_demo_tests/test_types.py
@@ -1,177 +1,177 @@
import os
+from tempfile import TemporaryDirectory
import botocore
import pyspark
from airline_demo.solids import ingest_csv_file_handle_to_spark
from dagster import (
InputDefinition,
LocalFileHandle,
ModeDefinition,
OutputDefinition,
execute_pipeline,
file_relative_path,
local_file_manager,
pipeline,
- seven,
solid,
)
from dagster.core.definitions.no_step_launcher import no_step_launcher
from dagster.core.instance import DagsterInstance
from dagster.core.storage.intermediate_storage import build_fs_intermediate_storage
from dagster.core.storage.temp_file_manager import tempfile_resource
from dagster_aws.s3 import s3_file_manager, s3_plus_default_intermediate_storage_defs, s3_resource
from dagster_aws.s3.intermediate_storage import S3IntermediateStorage
from dagster_pyspark import DataFrame, pyspark_resource
from pyspark.sql import Row, SparkSession
spark_local_fs_mode = ModeDefinition(
name="spark",
resource_defs={
"pyspark": pyspark_resource,
"tempfile": tempfile_resource,
"s3": s3_resource,
"pyspark_step_launcher": no_step_launcher,
"file_manager": local_file_manager,
},
intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
)
spark_s3_mode = ModeDefinition(
name="spark",
resource_defs={
"pyspark": pyspark_resource,
"tempfile": tempfile_resource,
"s3": s3_resource,
"pyspark_step_launcher": no_step_launcher,
"file_manager": s3_file_manager,
},
intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
)
def test_spark_data_frame_serialization_file_system_file_handle(spark_config):
@solid
def nonce(_):
return LocalFileHandle(file_relative_path(__file__, "data/test.csv"))
@pipeline(mode_defs=[spark_local_fs_mode])
def spark_df_test_pipeline():
ingest_csv_file_handle_to_spark(nonce())
instance = DagsterInstance.ephemeral()
result = execute_pipeline(
spark_df_test_pipeline,
mode="spark",
run_config={
"intermediate_storage": {"filesystem": {}},
"resources": {"pyspark": {"config": {"spark_conf": spark_config}}},
},
instance=instance,
)
intermediate_storage = build_fs_intermediate_storage(
instance.intermediates_directory, run_id=result.run_id
)
assert result.success
result_dir = os.path.join(
intermediate_storage.root, "intermediates", "ingest_csv_file_handle_to_spark", "result",
)
assert "_SUCCESS" in os.listdir(result_dir)
spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet(result_dir)
assert isinstance(df, pyspark.sql.dataframe.DataFrame)
assert df.head()[0] == "1"
def test_spark_data_frame_serialization_s3_file_handle(s3_bucket, spark_config):
@solid(required_resource_keys={"file_manager"})
def nonce(context):
with open(os.path.join(os.path.dirname(__file__), "data/test.csv"), "rb") as fd:
return context.resources.file_manager.write_data(fd.read())
@pipeline(mode_defs=[spark_s3_mode])
def spark_df_test_pipeline():
ingest_csv_file_handle_to_spark(nonce())
result = execute_pipeline(
spark_df_test_pipeline,
run_config={
"intermediate_storage": {"s3": {"config": {"s3_bucket": s3_bucket}}},
"resources": {
"pyspark": {"config": {"spark_conf": spark_config}},
"file_manager": {"config": {"s3_bucket": s3_bucket}},
},
},
mode="spark",
)
assert result.success
intermediate_storage = S3IntermediateStorage(s3_bucket=s3_bucket, run_id=result.run_id)
success_key = "/".join(
[
intermediate_storage.root.strip("/"),
"intermediates",
"ingest_csv_file_handle_to_spark",
"result",
"_SUCCESS",
]
)
try:
assert intermediate_storage.object_store.s3.get_object(
Bucket=intermediate_storage.object_store.bucket, Key=success_key
)
except botocore.exceptions.ClientError:
raise Exception("Couldn't find object at {success_key}".format(success_key=success_key))
def test_spark_dataframe_output_csv(spark_config):
spark = SparkSession.builder.getOrCreate()
num_df = (
spark.read.format("csv")
.options(header="true", inferSchema="true")
.load(file_relative_path(__file__, "num.csv"))
)
assert num_df.collect() == [Row(num1=1, num2=2)]
@solid
def emit(_):
return num_df
@solid(input_defs=[InputDefinition("df", DataFrame)], output_defs=[OutputDefinition(DataFrame)])
def passthrough_df(_context, df):
return df
@pipeline(mode_defs=[spark_local_fs_mode])
def passthrough():
passthrough_df(emit())
- with seven.TemporaryDirectory() as tempdir:
+ with TemporaryDirectory() as tempdir:
file_name = os.path.join(tempdir, "output.csv")
result = execute_pipeline(
passthrough,
run_config={
"resources": {"pyspark": {"config": {"spark_conf": spark_config}}},
"solids": {
"passthrough_df": {
"outputs": [{"result": {"csv": {"path": file_name, "header": True}}}]
}
},
},
)
from_file_df = (
spark.read.format("csv").options(header="true", inferSchema="true").load(file_name)
)
assert (
result.result_for_solid("passthrough_df").output_value().collect()
== from_file_df.collect()
)
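Once Python 2 support is dropped, the dagster.seven compatibility shim for TemporaryDirectory is unnecessary and the standard-library context manager is a drop-in replacement. A minimal sketch of the pattern this file now relies on (the file name below is illustrative only):

    import os
    import tempfile

    # tempfile.TemporaryDirectory removes the directory and its contents when
    # the context manager exits, matching the old seven.TemporaryDirectory behavior.
    with tempfile.TemporaryDirectory() as tmpdir:
        out_path = os.path.join(tmpdir, "output.csv")
        with open(out_path, "w") as f:
            f.write("num1,num2\n1,2\n")
        assert os.path.isfile(out_path)

    # The directory is gone after the block exits.
    assert not os.path.exists(tmpdir)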
diff --git a/examples/airline_demo/setup.py b/examples/airline_demo/setup.py
index 84b489953..7b9d3f319 100644
--- a/examples/airline_demo/setup.py
+++ b/examples/airline_demo/setup.py
@@ -1,54 +1,53 @@
from setuptools import find_packages, setup
setup(
name="airline_demo",
version="dev",
author="Elementl",
license="Apache-2.0",
description="Dagster Examples",
url="https://github.com/dagster-io/dagster/tree/master/examples/airline_demo",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["test"]),
# default supports basic tutorial & toy examples
install_requires=["dagster"],
extras_require={
# full is for running the more realistic demos
"full": [
"dagstermill",
"dagster-aws",
"dagster-cron",
"dagster-postgres",
"dagster-pyspark",
- "dagster-slack; python_version >= '3'",
+ "dagster-slack",
"dagster-snowflake",
# These two packages, descartes and geopandas, are used in the airline demo notebooks
"descartes",
'geopandas; "win" not in sys_platform',
"google-api-python-client",
"google-cloud-storage",
"keras",
"lakehouse",
- 'matplotlib==3.0.2; python_version >= "3.5"',
- 'matplotlib==2.2.4; python_version < "3.5"',
+ "matplotlib==3.0.2",
"mock",
"moto>=1.3.7",
"pandas>=1.0.0",
"pytest-mock",
# Pyspark 2.x is incompatible with Python 3.8+
'pyspark>=3.0.0; python_version >= "3.8"',
'pyspark>=2.0.2; python_version < "3.8"',
"seaborn",
"sqlalchemy-redshift>=0.7.2",
"SQLAlchemy-Utils==0.33.8",
'tensorflow; python_version < "3.9"',
],
"airflow": ["dagster_airflow", "docker-compose==1.23.2"],
},
include_package_data=True,
)
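The version markers that remain in this extras list (the pyspark and tensorflow pins, and the geopandas platform guard) are PEP 508 environment markers that pip evaluates against the installing interpreter; only the Python-2-era markers are dropped. A small sketch of how such a marker is parsed and evaluated, assuming the packaging library is available (pip and setuptools depend on it):

    from packaging.requirements import Requirement

    # The same requirement string used in extras_require above.
    req = Requirement('pyspark>=3.0.0; python_version >= "3.8"')
    print(req.name)       # pyspark
    print(req.specifier)  # >=3.0.0

    # Marker.evaluate() checks the marker against the running interpreter,
    # so this prints True on Python 3.8+ and False otherwise.
    print(req.marker.evaluate())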
diff --git a/examples/asset_store/tests/test_asset_store.py b/examples/asset_store/tests/test_asset_store.py
index 8c7a4042c..b13765e34 100644
--- a/examples/asset_store/tests/test_asset_store.py
+++ b/examples/asset_store/tests/test_asset_store.py
@@ -1,102 +1,103 @@
import os
import pickle
+from tempfile import TemporaryDirectory
-from dagster import DagsterInstance, execute_pipeline, reexecute_pipeline, seven
+from dagster import DagsterInstance, execute_pipeline, reexecute_pipeline
from ..builtin_custom_path import custom_path_pipeline
from ..builtin_default import model_pipeline
from ..builtin_pipeline import asset_store_pipeline
def test_builtin_default():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.ephemeral()
run_config = {
"resources": {"fs_asset_store": {"config": {"base_dir": tmpdir_path}}},
}
result = execute_pipeline(
model_pipeline, run_config=run_config, mode="test", instance=instance
)
assert result.success
filepath_call_api = os.path.join(tmpdir_path, result.run_id, "call_api", "result")
assert os.path.isfile(filepath_call_api)
with open(filepath_call_api, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
filepath_parse_df = os.path.join(tmpdir_path, result.run_id, "parse_df", "result")
assert os.path.isfile(filepath_parse_df)
with open(filepath_parse_df, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5]
assert reexecute_pipeline(
model_pipeline,
result.run_id,
run_config=run_config,
mode="test",
instance=instance,
step_selection=["parse_df"],
).success
def test_custom_path_asset_store():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.ephemeral()
run_config = {
"resources": {"fs_asset_store": {"config": {"base_dir": tmpdir_path}}},
}
result = execute_pipeline(
custom_path_pipeline, run_config=run_config, mode="test", instance=instance
)
assert result.success
filepath_call_api = os.path.join(tmpdir_path, "call_api_output")
assert os.path.isfile(filepath_call_api)
with open(filepath_call_api, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
filepath_parse_df = os.path.join(tmpdir_path, "parse_df_output")
assert os.path.isfile(filepath_parse_df)
with open(filepath_parse_df, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5]
assert reexecute_pipeline(
custom_path_pipeline,
result.run_id,
run_config=run_config,
mode="test",
instance=instance,
step_selection=["parse_df*"],
).success
def test_builtin_pipeline():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.ephemeral()
run_config = {
"resources": {"object_manager": {"config": {"base_dir": tmpdir_path}}},
}
result = execute_pipeline(
asset_store_pipeline, run_config=run_config, mode="test", instance=instance
)
assert result.success
filepath_call_api = os.path.join(tmpdir_path, result.run_id, "call_api", "result")
assert os.path.isfile(filepath_call_api)
with open(filepath_call_api, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
filepath_parse_df = os.path.join(tmpdir_path, result.run_id, "parse_df", "result")
assert os.path.isfile(filepath_parse_df)
with open(filepath_parse_df, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3, 4, 5]
diff --git a/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_default_object_manager.py b/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_default_object_manager.py
index 7a3222666..c2f002d97 100644
--- a/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_default_object_manager.py
+++ b/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_default_object_manager.py
@@ -1,11 +1,12 @@
+from tempfile import TemporaryDirectory
+
from dagster import execute_pipeline
-from dagster.seven import TemporaryDirectory
from docs_snippets.overview.object_managers.default_object_manager import my_pipeline
def test_default_object_manager():
with TemporaryDirectory() as tmpdir:
execute_pipeline(
my_pipeline,
run_config={"resources": {"object_manager": {"config": {"base_dir": tmpdir}}}},
)
diff --git a/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_object_manager_per_output.py b/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_object_manager_per_output.py
index f2ab773bb..520802052 100644
--- a/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_object_manager_per_output.py
+++ b/examples/docs_snippets/docs_snippets_tests/overview_tests/object_manager_tests/test_object_manager_per_output.py
@@ -1,10 +1,11 @@
+from tempfile import TemporaryDirectory
+
from dagster import execute_pipeline
-from dagster.seven import TemporaryDirectory
from docs_snippets.overview.object_managers.object_manager_per_output import my_pipeline
def test_object_manager_per_output():
with TemporaryDirectory() as tmpdir:
execute_pipeline(
my_pipeline, run_config={"resources": {"fs": {"config": {"base_dir": tmpdir}}}},
)
diff --git a/examples/legacy_examples/dagster_examples/bay_bikes/resources.py b/examples/legacy_examples/dagster_examples/bay_bikes/resources.py
index c4f2aeb44..a1f291c91 100644
--- a/examples/legacy_examples/dagster_examples/bay_bikes/resources.py
+++ b/examples/legacy_examples/dagster_examples/bay_bikes/resources.py
@@ -1,103 +1,103 @@
import os
import shutil
import tempfile
-from dagster import check, resource, seven
+from dagster import check, resource
from dagster.utils import mkdir_p
from google.cloud import storage
class CredentialsVault:
def __init__(self, credentials):
self.credentials = credentials
@classmethod
def instantiate_vault_from_environment_variables(cls, environment_variable_names):
"""Will clobber creds that are already in the vault"""
credentials = {}
for environment_variable_name in environment_variable_names:
credential = os.environ.get(environment_variable_name)
if not credential:
raise ValueError(
"Global Variable {} Not Set in Environment".format(environment_variable_name)
)
credentials[environment_variable_name] = credential
return cls(credentials)
@resource(config_schema={"environment_variable_names": [str]})
def credentials_vault(context):
return CredentialsVault.instantiate_vault_from_environment_variables(
context.resource_config["environment_variable_names"]
)
@resource
def temporary_directory_mount(_):
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
yield tmpdir_path
@resource(config_schema={"mount_location": str})
def mount(context):
mount_location = context.resource_config["mount_location"]
if os.path.exists(mount_location):
return context.resource_config["mount_location"]
raise NotADirectoryError("Cant mount files on this resource. Make sure it exists!")
@resource
def gcs_client(_):
return storage.Client()
class LocalBlob:
def __init__(self, key, bucket_location):
self.key = check.str_param(key, "key")
self.location = os.path.join(check.str_param(bucket_location, "bucket_location"), key)
def upload_from_file(self, file_buffer):
if os.path.exists(self.location):
os.remove(self.location)
with open(self.location, "w+b") as fdest:
shutil.copyfileobj(file_buffer, fdest)
class LocalBucket:
def __init__(self, bucket_name, volume):
self.bucket_name = check.str_param(bucket_name, "bucket_name")
# Setup bucket
self.volume = os.path.join(tempfile.gettempdir(), check.str_param(volume, "volume"))
bucket_location = os.path.join(self.volume, self.bucket_name)
if not os.path.exists(bucket_location):
mkdir_p(bucket_location)
self.location = bucket_location
self.blobs = {}
def blob(self, key):
check.str_param(key, "key")
if key not in self.blobs:
self.blobs[key] = LocalBlob(key, self.location)
return self.blobs[key]
class LocalClient:
def __init__(self, volume="storage"):
self.buckets = {}
self.volume = check.str_param(volume, "volume")
def get_bucket(self, bucket_name):
check.str_param(bucket_name, "bucket_name")
if bucket_name not in self.buckets:
self.buckets[bucket_name] = LocalBucket(bucket_name, self.volume)
return self.buckets[bucket_name]
@resource
def local_client(_):
return LocalClient()
@resource
def testing_client(_):
return LocalClient(volume="testing-storage")
diff --git a/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_resource.py b/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_resource.py
index 72a0c7fc2..30af7a9a1 100644
--- a/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_resource.py
+++ b/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_resource.py
@@ -1,42 +1,41 @@
import os
import shutil
import tempfile
-from dagster import seven
from dagster_examples.bay_bikes.resources import LocalBlob, LocalBucket, LocalClient
def test_local_blob_upload():
- with seven.TemporaryDirectory() as bucket_dir:
+ with tempfile.TemporaryDirectory() as bucket_dir:
target_key = os.path.join(bucket_dir, "foo.txt")
# TODO: Make this windows safe
with tempfile.NamedTemporaryFile() as local_fp:
blob = LocalBlob("foo.txt", bucket_dir)
local_fp.write(b"hello")
local_fp.seek(0)
blob.upload_from_file(local_fp)
assert os.path.exists(target_key)
with open(target_key) as key_fp:
assert key_fp.read() == "hello"
def test_local_bucket():
bucket = LocalBucket("foo", "mountain")
assert os.path.exists(os.path.join(tempfile.gettempdir(), "mountain", "foo"))
bucket.blob("bar.txt")
assert isinstance(bucket.blobs["bar.txt"], LocalBlob)
assert bucket.blobs["bar.txt"].key == "bar.txt"
assert bucket.blobs["bar.txt"].location == os.path.join(
tempfile.gettempdir(), "mountain", "foo", "bar.txt"
)
shutil.rmtree(os.path.join(tempfile.gettempdir(), "mountain"), ignore_errors=True)
def test_local_client():
client = LocalClient(volume="mountain")
client.get_bucket("foo")
assert isinstance(client.buckets["foo"], LocalBucket)
assert client.buckets["foo"].bucket_name == "foo"
assert client.buckets["foo"].volume == os.path.join(tempfile.gettempdir(), "mountain")
shutil.rmtree(os.path.join(tempfile.gettempdir(), "mountain"), ignore_errors=True)
diff --git a/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_solids.py b/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_solids.py
index a91bfbfeb..9ef8f680a 100644
--- a/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_solids.py
+++ b/examples/legacy_examples/dagster_examples_tests/bay_bikes_tests/test_solids.py
@@ -1,453 +1,453 @@
# pylint: disable=redefined-outer-name, W0613
import os
import shutil
import tempfile
from datetime import date
from functools import partial
import pytest
-from dagster import ModeDefinition, execute_pipeline, execute_solid, pipeline, seven
+from dagster import ModeDefinition, execute_pipeline, execute_solid, pipeline
from dagster_examples.bay_bikes.resources import credentials_vault, mount, testing_client
from dagster_examples.bay_bikes.solids import (
MultivariateTimeseries,
Timeseries,
download_weather_report_from_weather_api,
produce_training_set,
produce_trip_dataset,
produce_weather_dataset,
transform_into_traffic_dataset,
trip_etl,
upload_pickled_object_to_gcs_bucket,
)
from dagster_examples.common.resources import postgres_db_info_resource
from dagster_examples_tests.bay_bikes_tests.test_data import FAKE_TRIP_DATA, FAKE_WEATHER_DATA
from numpy import array, array_equal
from numpy.testing import assert_array_equal
from pandas import DataFrame, Timestamp
from requests import HTTPError
START_TIME = 1514793600
VOLUME_TARGET_DIRECTORY = "/tmp/bar"
FAKE_ZIPFILE_NAME = "data.csv.zip"
BUCKET_NAME = "dagster-scratch-ccdfe1e"
TRAINING_FILE_NAME = "training_data"
class MockResponse:
def __init__(self, return_status_code, json_data):
self.status_code = return_status_code
self.json_data = json_data
def raise_for_status(self):
if self.status_code != 200:
raise HTTPError("BAD")
return self.status_code
def json(self):
return self.json_data
@pytest.fixture
def mock_response_ok():
return MockResponse(200, {"daily": {"data": [{"time": START_TIME}]}})
@pytest.fixture
def simple_timeseries():
return Timeseries([1, 2, 3, 4, 5])
@pytest.mark.parametrize(
"timeseries, memory_length, expected_snapshot_sequence",
[
(Timeseries([1, 2, 3, 4, 5]), 1, [[1, 2], [2, 3], [3, 4], [4, 5]]),
(Timeseries([1, 2, 3, 4, 5]), 2, [[1, 2, 3], [2, 3, 4], [3, 4, 5]]),
(Timeseries([1, 2, 3, 4, 5]), 4, [[1, 2, 3, 4, 5]]),
],
)
def test_timeseries_conversion_ok(timeseries, memory_length, expected_snapshot_sequence):
assert timeseries.convert_to_snapshot_sequence(memory_length) == expected_snapshot_sequence
def test_timeseries_conversion_no_sequence():
with pytest.raises(ValueError):
empty_timeseries = Timeseries([])
empty_timeseries.convert_to_snapshot_sequence(3)
@pytest.mark.parametrize("memory_length", [0, -1, 5, 6])
def test_timeseries_bad_memory_lengths(simple_timeseries, memory_length):
with pytest.raises(ValueError):
simple_timeseries.convert_to_snapshot_sequence(memory_length)
def test_multivariate_timeseries_transformation_ok():
mv_timeseries = MultivariateTimeseries(
[[1, 2, 3, 4, 5], [0, 1, 2, 3, 4]], [6, 7, 8, 9, 10], ["foo", "bar"], "baz"
)
matrix, output = mv_timeseries.convert_to_snapshot_matrix(2)
assert_array_equal(
matrix,
array([[[1, 0], [2, 1], [3, 2]], [[2, 1], [3, 2], [4, 3]], [[3, 2], [4, 3], [5, 4]]]),
)
assert_array_equal(output, array([8, 9, 10]))
def test_mutlivariate_timeseries_transformation_from_dataframe_ok():
mv_timeseries_df = DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6], "baz": [0, 0, 0]})
mv_timeseries = MultivariateTimeseries.from_dataframe(mv_timeseries_df, ["foo", "bar"], "baz")
assert mv_timeseries
assert mv_timeseries.input_timeseries_collection[0].sequence == [1, 2, 3]
assert mv_timeseries.input_timeseries_collection[1].sequence == [4, 5, 6]
assert mv_timeseries.output_timeseries.sequence == [0, 0, 0]
@pytest.fixture
def setup_dark_sky():
old_api_key = os.environ.get("DARK_SKY_API_KEY", "")
yield
# pytest swallows errors and throws them in the request.node object. This is akin to a finally.
# https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
os.environ["DAKR_SKY_API_KEY"] = old_api_key
def test_download_weather_report_from_weather_api_200(mocker, setup_dark_sky, mock_response_ok):
mock_get = mocker.patch(
"dagster_examples.bay_bikes.solids.requests.get", return_value=mock_response_ok
)
# Clobber api key for test so we don't expose creds
os.environ["DARK_SKY_API_KEY"] = "uuids-will-never-collide"
solid_result = execute_solid(
download_weather_report_from_weather_api,
ModeDefinition(resource_defs={"credentials_vault": credentials_vault,}),
run_config={
"resources": {
"credentials_vault": {
"config": {"environment_variable_names": ["DARK_SKY_API_KEY"]}
},
},
"solids": {
"download_weather_report_from_weather_api": {
"inputs": {"epoch_date": {"value": START_TIME}}
}
},
},
).output_value()
mock_get.assert_called_with(
"https://api.darksky.net/forecast/uuids-will-never-collide/37.8267,-122.4233,1514793600?exclude=currently,minutely,hourly,alerts,flags"
)
assert isinstance(solid_result, DataFrame)
assert solid_result["time"][0] == START_TIME
# pylint: disable=unused-argument
def mock_read_sql(table_name, _engine, index_col=None):
if table_name == "weather":
return DataFrame(FAKE_WEATHER_DATA)
elif table_name == "trips":
return DataFrame(FAKE_TRIP_DATA)
return DataFrame()
def compose_training_data_env_dict():
return {
"resources": {
"postgres_db": {
"config": {
"db_name": "test",
"hostname": "localhost",
"password": "test",
"username": "test",
}
},
"volume": {"config": {"mount_location": "/tmp"}},
},
"solids": {
"produce_training_set": {"config": {"memory_length": 1}},
"produce_trip_dataset": {"inputs": {"trip_table_name": "trips"},},
"produce_weather_dataset": {
"solids": {"load_entire_weather_table": {"config": {"subsets": ["time"]},}},
"inputs": {"weather_table_name": "weather"},
},
"upload_training_set_to_gcs": {
"inputs": {"bucket_name": BUCKET_NAME, "file_name": TRAINING_FILE_NAME,}
},
},
}
@pipeline(
mode_defs=[
ModeDefinition(
name="testing",
resource_defs={
"postgres_db": postgres_db_info_resource,
"gcs_client": testing_client,
"volume": mount,
},
description="Mode to be used during testing. Allows us to clean up test artifacts without interfearing with local artifacts.",
),
],
)
def generate_test_training_set_pipeline():
upload_training_set_to_gcs = upload_pickled_object_to_gcs_bucket.alias(
"upload_training_set_to_gcs"
)
return upload_training_set_to_gcs(
produce_training_set(
transform_into_traffic_dataset(produce_trip_dataset()), produce_weather_dataset(),
)
)
def test_generate_training_set(mocker):
mocker.patch("dagster_examples.bay_bikes.solids.read_sql_table", side_effect=mock_read_sql)
# Execute Pipeline
test_pipeline_result = execute_pipeline(
pipeline=generate_test_training_set_pipeline,
mode="testing",
run_config=compose_training_data_env_dict(),
)
assert test_pipeline_result.success
# Check solids
EXPECTED_TRAFFIC_RECORDS = [
{
"interval_date": date(2019, 7, 31),
"peak_traffic_load": 1,
"time": Timestamp("2019-07-31 00:00:00"),
},
{
"interval_date": date(2019, 8, 31),
"peak_traffic_load": 1,
"time": Timestamp("2019-08-31 00:00:00"),
},
]
traffic_dataset = test_pipeline_result.output_for_solid(
"transform_into_traffic_dataset", output_name="traffic_dataframe"
).to_dict("records")
assert all(record in EXPECTED_TRAFFIC_RECORDS for record in traffic_dataset)
EXPECTED_WEATHER_RECORDS = [
{
"time": Timestamp("2019-08-31 00:00:00"),
"summary": "Clear throughout the day.",
"icon": "clear-day",
"sunriseTime": 1546269960,
"sunsetTime": 1546304520,
"precipIntensity": 0.0007,
"precipIntensityMax": 0.0019,
"precipProbability": 0.05,
"precipType": "rain",
"temperatureHigh": 56.71,
"temperatureHighTime": 1546294020,
"temperatureLow": 44.75,
"temperatureLowTime": 1546358040,
"dewPoint": 28.34,
"humidity": 0.43,
"pressure": 1017.7,
"windSpeed": 12.46,
"windGust": 26.85,
"windGustTime": 1546289220,
"windBearing": 0,
"cloudCover": 0.11,
"uvIndex": 2,
"uvIndexTime": 1546287180,
"visibility": 10,
"ozone": 314.4,
},
{
"time": Timestamp("2019-07-31 00:00:00"),
"summary": "Clear throughout the day.",
"icon": "clear-day",
"sunriseTime": 1546356420,
"sunsetTime": 1546390920,
"precipIntensity": 0.0005,
"precipIntensityMax": 0.0016,
"precipProbability": 0.02,
"precipType": "sunny",
"temperatureHigh": 55.91,
"temperatureHighTime": 1546382040,
"temperatureLow": 41.18,
"temperatureLowTime": 1546437660,
"dewPoint": 20.95,
"humidity": 0.33,
"pressure": 1023.3,
"windSpeed": 6.77,
"windGust": 22.08,
"windGustTime": 1546343340,
"windBearing": 22,
"cloudCover": 0.1,
"uvIndex": 2,
"uvIndexTime": 1546373580,
"visibility": 10,
"ozone": 305.3,
},
]
weather_dataset = test_pipeline_result.output_for_solid(
"produce_weather_dataset", output_name="weather_dataframe"
).to_dict("records")
assert all(record in EXPECTED_WEATHER_RECORDS for record in weather_dataset)
# Ensure we are generating the expected training set
training_set, labels = test_pipeline_result.output_for_solid("produce_training_set")
assert len(labels) == 1 and labels[0] == 1
assert array_equal(
training_set,
[
[
[
1546356420.0,
1546390920.0,
0.0005,
0.0016,
0.02,
55.91,
1546382040.0,
41.18,
1546437660.0,
20.95,
0.33,
1023.3,
6.77,
22.08,
1546343340.0,
22.0,
0.1,
2.0,
1546373580.0,
10.0,
305.3,
],
[
1546269960.0,
1546304520.0,
0.0007,
0.0019,
0.05,
56.71,
1546294020.0,
44.75,
1546358040.0,
28.34,
0.43,
1017.7,
12.46,
26.85,
1546289220.0,
0.0,
0.11,
2.0,
1546287180.0,
10.0,
314.4,
],
]
],
)
materialization_events = [
event
for event in test_pipeline_result.step_event_list
if event.solid_name == "upload_training_set_to_gcs"
and event.event_type_value == "STEP_MATERIALIZATION"
]
assert len(materialization_events) == 1
materialization = materialization_events[0].event_specific_data.materialization
assert materialization.asset_key.path[0:5] == [
"gs",
"dagster",
"scratch",
"ccdfe1e",
"training_data",
]
materialization_event_metadata = materialization.metadata_entries
assert len(materialization_event_metadata) == 1
assert materialization_event_metadata[0].label == "google cloud storage URI"
assert materialization_event_metadata[0].entry_data.text.startswith(
"gs://dagster-scratch-ccdfe1e/training_data"
)
# Clean up
shutil.rmtree(os.path.join(tempfile.gettempdir(), "testing-storage"), ignore_errors=True)
def run_config_dict():
return {
"resources": {
"postgres_db": {
"config": {
"db_name": "test",
"hostname": "localhost",
"password": "test",
"username": "test",
}
},
"volume": {"config": {"mount_location": ""}},
},
"solids": {
"trip_etl": {
"solids": {
"download_baybike_zipfile_from_url": {
"inputs": {
"file_name": {"value": FAKE_ZIPFILE_NAME},
"base_url": {"value": "https://foo.com"},
}
},
"load_baybike_data_into_dataframe": {
"inputs": {"target_csv_file_in_archive": {"value": "",}}
},
"insert_trip_data_into_table": {"inputs": {"table_name": "test_trips"},},
}
}
},
}
def mock_download_zipfile(tmp_dir, fake_trip_data, _url, _target, _chunk_size):
data_zip_file_path = os.path.join(tmp_dir, FAKE_ZIPFILE_NAME)
DataFrame(fake_trip_data).to_csv(data_zip_file_path, compression="zip")
def test_monthly_trip_pipeline(mocker):
env_dictionary = run_config_dict()
- with seven.TemporaryDirectory() as tmp_dir:
+ with tempfile.TemporaryDirectory() as tmp_dir:
# Run pipeline
download_zipfile = mocker.patch(
"dagster_examples.bay_bikes.solids._download_zipfile_from_url",
side_effect=partial(mock_download_zipfile, tmp_dir, FAKE_TRIP_DATA),
)
mocker.patch("dagster_examples.bay_bikes.solids._create_and_load_staging_table")
env_dictionary["resources"]["volume"]["config"]["mount_location"] = tmp_dir
# Done because we are zipping the file in the tmpdir
env_dictionary["solids"]["trip_etl"]["solids"]["load_baybike_data_into_dataframe"][
"inputs"
]["target_csv_file_in_archive"]["value"] = os.path.join(tmp_dir, FAKE_ZIPFILE_NAME)
result = execute_solid(
trip_etl,
run_config=env_dictionary,
mode_def=ModeDefinition(
name="trip_testing",
resource_defs={"postgres_db": postgres_db_info_resource, "volume": mount},
description="Mode to be used during local demo.",
),
)
assert result.success
download_zipfile.assert_called_with(
"https://foo.com/data.csv.zip", os.path.join(tmp_dir, FAKE_ZIPFILE_NAME), 8192
)
materialization_events = result.result_for_solid(
"insert_trip_data_into_table"
).materialization_events_during_compute
assert len(materialization_events) == 1
assert materialization_events[0].event_specific_data.materialization.label == "test_trips"
assert (
materialization_events[0].event_specific_data.materialization.description
== "Table test_trips created in database test"
)
metadata_entries = materialization_events[
0
].event_specific_data.materialization.metadata_entries
assert len(metadata_entries) == 1
assert metadata_entries[0].label == "num rows inserted"
assert metadata_entries[0].entry_data.text == "2"
diff --git a/examples/legacy_examples/dagster_examples_tests/data_science_guide_tests/test_jupyter_notebooks.py b/examples/legacy_examples/dagster_examples_tests/data_science_guide_tests/test_jupyter_notebooks.py
index 472f3c12f..d66e268fb 100644
--- a/examples/legacy_examples/dagster_examples_tests/data_science_guide_tests/test_jupyter_notebooks.py
+++ b/examples/legacy_examples/dagster_examples_tests/data_science_guide_tests/test_jupyter_notebooks.py
@@ -1,159 +1,158 @@
import re
import nbformat
import psycopg2
import pytest
-from dagster import seven
from dagster.utils import script_relative_path
from nbconvert.preprocessors import ExecutePreprocessor
from nbconvert.preprocessors.execute import CellExecutionError
valid_notebook_paths = [
(
"../../../python_modules/libraries/dagster-pandas/dagster_pandas/examples/notebooks/papermill_pandas_hello_world.ipynb"
),
("../../../python_modules/dagit/dagit_tests/render_uuid_notebook.ipynb"),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world_output.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/tutorial_LR.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/tutorial_RF.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/clean_data.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/add_two_numbers.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/mult_two_numbers.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_logging.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world_explicit_yield.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/bad_kernel.ipynb"
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world_resource.ipynb"
),
("../../../python_modules/libraries/dagstermill/dagstermill_tests/notebooks/retroactive.ipynb"),
("../../../docs/sections/learn/guides/data_science/iris-kmeans.ipynb"),
("../../../docs/sections/learn/guides/data_science/iris-kmeans_2.ipynb"),
("../../../docs/sections/learn/guides/data_science/iris-kmeans_3.ipynb"),
]
invalid_notebook_paths = [
(
"../../../python_modules/libraries/dagster-pandas/dagster_pandas/examples/pandas_hello_world/scratch.ipynb",
["cells", 0, "outputs", 0],
- seven.ModuleNotFoundError.__name__,
+ ModuleNotFoundError.__name__,
"No module named 'dagster_contrib'",
"error",
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/error_notebook.ipynb",
["cells", 1, "outputs", 0],
Exception.__name__,
"Someone set up us the bomb",
"error",
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world_config.ipynb",
["cells", 1, "outputs", 0],
TypeError.__name__,
"got an unexpected keyword argument 'solid_name'",
"error",
),
(
"../../../python_modules/libraries/dagstermill/dagstermill/examples/notebooks/hello_world_resource_with_exception.ipynb",
["cells", 4, "outputs", 0],
Exception.__name__,
"",
"error",
),
(
"../../../airline_demo/airline_demo/notebooks/Delays_by_Geography.ipynb",
["cells", 6, "outputs", 0],
psycopg2.OperationalError.__name__,
"could not connect to server:",
"error",
),
(
"../../../airline_demo/airline_demo/notebooks/SFO_Delays_by_Destination.ipynb",
["cells", 6, "outputs", 0],
psycopg2.OperationalError.__name__,
"could not connect to server:",
"error",
),
(
"../../../airline_demo/airline_demo/notebooks/Fares_vs_Delays.ipynb",
["cells", 6, "outputs", 0],
psycopg2.OperationalError.__name__,
"could not connect to server:",
"error",
),
]
def get_dict_value(cell_dict, keys):
if not keys:
return cell_dict
return get_dict_value(cell_dict[keys[0]], keys[1:])
@pytest.mark.skip
@pytest.mark.parametrize("valid_notebook_path", valid_notebook_paths)
def test_valid_notebooks(valid_notebook_path):
notebook_filename = script_relative_path(valid_notebook_path)
with open(notebook_filename) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
ep.preprocess(
nb,
{
"metadata": {
"path": script_relative_path(notebook_filename[: notebook_filename.rfind("/")])
}
},
)
@pytest.mark.skip
@pytest.mark.parametrize(
"invalid_notebook_path, cell_location, error_name, error_value, error_output_type",
invalid_notebook_paths,
)
def test_invalid_notebooks(
invalid_notebook_path, cell_location, error_name, error_value, error_output_type
):
notebook_filename = script_relative_path(invalid_notebook_path)
with open(notebook_filename) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
try:
ep.preprocess(
nb,
{
"metadata": {
"path": script_relative_path(
notebook_filename[: notebook_filename.rfind("/")]
)
}
},
)
except CellExecutionError:
error_message = get_dict_value(nb, cell_location)
assert error_message.ename == error_name
assert bool(re.search(error_value, error_message.evalue))
assert error_message.output_type == error_output_type
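The seven.ModuleNotFoundError alias existed only for Python 2 compatibility; on Python 3.6+ ModuleNotFoundError is a builtin subclass of ImportError, which is all this test needs when comparing against ename. A minimal illustration (dagster_contrib is the module the scratch notebook expects to be missing):

    try:
        import dagster_contrib  # expected to be absent, as in the scratch notebook
    except ModuleNotFoundError as exc:
        # The builtin exception name matches what the notebook output records as ename.
        assert type(exc).__name__ == ModuleNotFoundError.__name__
        print(exc)  # e.g. No module named 'dagster_contrib'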
diff --git a/examples/legacy_examples/setup.py b/examples/legacy_examples/setup.py
index c8839245a..d7b1316c4 100644
--- a/examples/legacy_examples/setup.py
+++ b/examples/legacy_examples/setup.py
@@ -1,56 +1,55 @@
from setuptools import find_packages, setup
setup(
name="dagster_examples",
version="dev",
author="Elementl",
author_email="hello@elementl.com",
license="Apache-2.0",
description="Dagster Examples",
url="https://github.com/dagster-io/dagster",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["test"]),
# default supports basic tutorial & toy examples
install_requires=["dagster"],
extras_require={
# full is for running the more realistic demos
"full": [
"dagstermill",
"dagster-aws",
"dagster-cron",
"dagster-postgres",
"dagster-pyspark",
- "dagster-slack; python_version >= '3'",
+ "dagster-slack",
"dagster-snowflake",
# These two packages, descartes and geopandas, are used in the airline demo notebooks
"descartes",
'geopandas; "win" not in sys_platform',
"google-api-python-client",
"google-cloud-storage",
"keras; python_version < '3.9'",
"lakehouse",
- 'matplotlib==3.0.2; python_version >= "3.5"',
- 'matplotlib==2.2.4; python_version < "3.5"',
+ "matplotlib==3.0.2",
"mock",
"moto>=1.3.7",
"pandas>=1.0.0",
"pytest-mock",
# Pyspark 2.x is incompatible with Python 3.8+
'pyspark>=3.0.0; python_version >= "3.8"',
'pyspark>=2.0.2; python_version < "3.8"',
"sqlalchemy-redshift>=0.7.2",
"SQLAlchemy-Utils==0.33.8",
'tensorflow; python_version < "3.9"',
"dagster-gcp",
],
"dbt": ["dbt-postgres"],
"airflow": ["dagster_airflow", "docker-compose==1.23.2"],
},
include_package_data=True,
)
diff --git a/integration_tests/test_suites/k8s-integration-test-suite/conftest.py b/integration_tests/test_suites/k8s-integration-test-suite/conftest.py
index 16e37fd76..88e1e08ce 100644
--- a/integration_tests/test_suites/k8s-integration-test-suite/conftest.py
+++ b/integration_tests/test_suites/k8s-integration-test-suite/conftest.py
@@ -1,143 +1,143 @@
# pylint: disable=unused-import
import os
+import tempfile
import docker
import kubernetes
import pytest
-from dagster import seven
from dagster.core.instance import DagsterInstance
from dagster_k8s.launcher import K8sRunLauncher
from dagster_k8s.scheduler import K8sScheduler
from dagster_k8s_test_infra.cluster import (
dagster_instance_for_k8s_run_launcher,
dagster_instance_with_k8s_scheduler,
define_cluster_provider_fixture,
)
from dagster_k8s_test_infra.helm import TEST_AWS_CONFIGMAP_NAME, helm_namespace_for_k8s_run_launcher
from dagster_k8s_test_infra.integration_utils import image_pull_policy
from dagster_test.test_project import build_and_tag_test_image, get_test_project_docker_image
IS_BUILDKITE = os.getenv("BUILDKITE") is not None
@pytest.fixture(scope="session", autouse=True)
def dagster_home():
old_env = os.getenv("DAGSTER_HOME")
os.environ["DAGSTER_HOME"] = "/opt/dagster/dagster_home"
yield
if old_env is not None:
os.environ["DAGSTER_HOME"] = old_env
cluster_provider = define_cluster_provider_fixture(
additional_kind_images=["docker.io/bitnami/rabbitmq", "docker.io/bitnami/postgresql"]
)
@pytest.yield_fixture
def schedule_tempdir():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
yield tempdir
@pytest.fixture
def k8s_scheduler(
cluster_provider, helm_namespace_for_k8s_run_launcher
): # pylint: disable=redefined-outer-name,unused-argument
return K8sScheduler(
scheduler_namespace=helm_namespace_for_k8s_run_launcher,
image_pull_secrets=[{"name": "element-dev-key"}],
service_account_name="dagit-admin",
instance_config_map="dagster-instance",
postgres_password_secret="dagster-postgresql-secret",
dagster_home="/opt/dagster/dagster_home",
job_image=get_test_project_docker_image(),
load_incluster_config=False,
kubeconfig_file=cluster_provider.kubeconfig_file,
image_pull_policy=image_pull_policy(),
env_config_maps=["dagster-pipeline-env", "test-env-configmap"],
env_secrets=["test-env-secret"],
)
@pytest.fixture(scope="function")
def restore_k8s_cron_tab(
helm_namespace_for_k8s_run_launcher,
): # pylint: disable=redefined-outer-name
kube_api = kubernetes.client.BatchV1beta1Api()
# Doubly make sure CronJobs are deleted pre-test and post-test
kube_api.delete_collection_namespaced_cron_job(namespace=helm_namespace_for_k8s_run_launcher)
yield
kube_api.delete_collection_namespaced_cron_job(namespace=helm_namespace_for_k8s_run_launcher)
@pytest.fixture(scope="session")
def run_launcher(
cluster_provider, helm_namespace_for_k8s_run_launcher
): # pylint: disable=redefined-outer-name,unused-argument
return K8sRunLauncher(
image_pull_secrets=[{"name": "element-dev-key"}],
service_account_name="dagit-admin",
instance_config_map="dagster-instance",
postgres_password_secret="dagster-postgresql-secret",
dagster_home="/opt/dagster/dagster_home",
job_image=get_test_project_docker_image(),
load_incluster_config=False,
kubeconfig_file=cluster_provider.kubeconfig_file,
image_pull_policy=image_pull_policy(),
job_namespace=helm_namespace_for_k8s_run_launcher,
env_config_maps=["dagster-pipeline-env", "test-env-configmap"]
+ ([TEST_AWS_CONFIGMAP_NAME] if not IS_BUILDKITE else []),
env_secrets=["test-env-secret"],
)
@pytest.fixture(scope="session")
def dagster_docker_image():
docker_image = get_test_project_docker_image()
if not IS_BUILDKITE:
try:
client = docker.from_env()
client.images.get(docker_image)
print( # pylint: disable=print-call
"Found existing image tagged {image}, skipping image build. To rebuild, first run: "
"docker rmi {image}".format(image=docker_image)
)
except docker.errors.ImageNotFound:
build_and_tag_test_image(docker_image)
return docker_image
# See: https://stackoverflow.com/a/31526934/324449
def pytest_addoption(parser):
# We catch the ValueError to support cases where we are loading multiple test suites, e.g., in
# the VSCode test explorer. When pytest tries to add an option twice, we get, e.g.
#
# ValueError: option names {'--cluster-provider'} already added
# Use kind or some other cluster provider?
try:
parser.addoption("--cluster-provider", action="store", default="kind")
except ValueError:
pass
# Specify an existing kind cluster name to use
try:
parser.addoption("--kind-cluster", action="store")
except ValueError:
pass
# Keep resources around after tests are done
try:
parser.addoption("--no-cleanup", action="store_true", default=False)
except ValueError:
pass
# Use existing Helm chart/namespace
try:
parser.addoption("--existing-helm-namespace", action="store")
except ValueError:
pass
diff --git a/integration_tests/test_suites/k8s-integration-test-suite/test_scheduler.py b/integration_tests/test_suites/k8s-integration-test-suite/test_scheduler.py
index 602dcc465..726b4aa85 100644
--- a/integration_tests/test_suites/k8s-integration-test-suite/test_scheduler.py
+++ b/integration_tests/test_suites/k8s-integration-test-suite/test_scheduler.py
@@ -1,720 +1,721 @@
import os
import subprocess
import sys
+import tempfile
from contextlib import contextmanager
import kubernetes
import pytest
-from dagster import DagsterInstance, ScheduleDefinition, seven
+from dagster import DagsterInstance, ScheduleDefinition
from dagster.core.definitions import lambda_solid, pipeline, repository
from dagster.core.host_representation import (
ManagedGrpcPythonEnvRepositoryLocationOrigin,
RepositoryLocation,
RepositoryLocationHandle,
)
from dagster.core.scheduler.job import JobStatus, JobType
from dagster.core.scheduler.scheduler import (
DagsterScheduleDoesNotExist,
DagsterScheduleReconciliationError,
DagsterSchedulerError,
)
from dagster.core.storage.pipeline_run import PipelineRunStatus
from dagster.core.test_utils import environ
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from marks import mark_scheduler
@pytest.fixture(scope="function")
def unset_dagster_home():
old_env = os.getenv("DAGSTER_HOME")
if old_env is not None:
del os.environ["DAGSTER_HOME"]
yield
if old_env is not None:
os.environ["DAGSTER_HOME"] = old_env
@pipeline
def no_config_pipeline():
@lambda_solid
def return_hello():
return "Hello"
return return_hello()
schedules_dict = {
"no_config_pipeline_daily_schedule": ScheduleDefinition(
name="no_config_pipeline_daily_schedule",
cron_schedule="0 0 * * *",
pipeline_name="no_config_pipeline",
run_config={"storage": {"filesystem": None}},
),
"no_config_pipeline_every_min_schedule": ScheduleDefinition(
name="no_config_pipeline_every_min_schedule",
cron_schedule="* * * * *",
pipeline_name="no_config_pipeline",
run_config={"storage": {"filesystem": None}},
),
"default_config_pipeline_every_min_schedule": ScheduleDefinition(
name="default_config_pipeline_every_min_schedule",
cron_schedule="* * * * *",
pipeline_name="no_config_pipeline",
),
}
def define_schedules():
return list(schedules_dict.values())
@repository
def test_repository():
if os.getenv("DAGSTER_TEST_SMALL_REPO"):
return [no_config_pipeline] + list(
filter(
lambda x: not x.name == "default_config_pipeline_every_min_schedule",
define_schedules(),
)
)
return [no_config_pipeline] + define_schedules()
@contextmanager
def get_test_external_repo():
with RepositoryLocationHandle.create_from_repository_location_origin(
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable, python_file=__file__, attribute="test_repository",
),
location_name="test_location",
),
) as handle:
yield RepositoryLocation.from_handle(handle).get_repository("test_repository")
@contextmanager
def get_smaller_external_repo():
with environ({"DAGSTER_TEST_SMALL_REPO": "1"}):
with get_test_external_repo() as repo:
yield repo
@mark_scheduler
def test_init(
dagster_instance_with_k8s_scheduler,
schedule_tempdir,
helm_namespace_for_k8s_run_launcher,
restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repository:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repository)
# Check schedules are saved to disk
assert "schedules" in os.listdir(schedule_tempdir)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
schedules = instance.all_stored_job_state(job_type=JobType.SCHEDULE)
for schedule in schedules:
assert schedule.status == JobStatus.STOPPED
@mark_scheduler
def test_re_init(
dagster_instance_with_k8s_scheduler,
schedule_tempdir,
helm_namespace_for_k8s_run_launcher,
restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Start schedule
schedule_state = instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# Re-initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Check schedules are saved to disk
assert "schedules" in os.listdir(schedule_tempdir)
schedule_states = instance.all_stored_job_state(job_type=JobType.SCHEDULE)
for state in schedule_states:
if state.name == "no_config_pipeline_every_min_schedule":
assert state == schedule_state
@mark_scheduler
def test_start_and_stop_schedule(
dagster_instance_with_k8s_scheduler,
schedule_tempdir,
helm_namespace_for_k8s_run_launcher,
restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
schedule = external_repo.get_external_schedule(
schedule_name="no_config_pipeline_every_min_schedule"
)
schedule_origin_id = schedule.get_external_origin_id()
instance.start_schedule_and_update_storage_state(external_schedule=schedule)
assert "schedules" in os.listdir(schedule_tempdir)
assert instance.scheduler.get_cron_job(schedule_origin_id=schedule_origin_id)
instance.stop_schedule_and_update_storage_state(schedule_origin_id=schedule_origin_id)
assert not instance.scheduler.get_cron_job(schedule_origin_id=schedule_origin_id)
@mark_scheduler
def test_start_non_existent_schedule(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with pytest.raises(DagsterScheduleDoesNotExist):
# Initialize scheduler
instance.stop_schedule_and_update_storage_state("asdf")
@mark_scheduler
def test_start_schedule_cron_job(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
# Inspect the cron tab
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 3
external_schedules_dict = {
external_repo.get_external_schedule(name).get_external_origin_id(): schedule_def
for name, schedule_def in schedules_dict.items()
}
for cron_job in cron_jobs:
cron_schedule = cron_job.spec.schedule
command = cron_job.spec.job_template.spec.template.spec.containers[0].command
args = cron_job.spec.job_template.spec.template.spec.containers[0].args
schedule_origin_id = cron_job.metadata.name
schedule_def = external_schedules_dict[schedule_origin_id]
assert cron_schedule == schedule_def.cron_schedule
assert command == None
assert args[:5] == [
"dagster",
"api",
"launch_scheduled_execution",
"/tmp/launch_scheduled_execution_output",
"--schedule_name",
]
@mark_scheduler
def test_remove_schedule_def(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
with get_smaller_external_repo() as smaller_repo:
instance.reconcile_scheduler_state(smaller_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 2
@mark_scheduler
def test_add_schedule_def(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
with get_smaller_external_repo() as smaller_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(smaller_repo)
# Start all schedule and verify cron tab, schedule storage, and errors
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 2
assert len(instance.scheduler.get_all_cron_jobs()) == 2
assert len(instance.scheduler_debug_info().errors) == 0
# Reconcile with an additional schedule added
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
assert len(instance.scheduler.get_all_cron_jobs()) == 2
assert len(instance.scheduler_debug_info().errors) == 0
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
assert len(instance.scheduler.get_all_cron_jobs()) == 3
assert len(instance.scheduler_debug_info().errors) == 0
@mark_scheduler
def test_start_and_stop_schedule_cron_tab(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Start schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 1
# Try starting it again
with pytest.raises(DagsterSchedulerError):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 1
# Start another schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 2
# Stop second schedule
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 1
# Try stopping second schedule again
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 1
# Start second schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 2
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 2
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 3
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 3
# Stop all schedules
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"default_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 0
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = instance.scheduler.get_all_cron_jobs()
assert len(cron_jobs) == 0
@mark_scheduler
def test_script_execution(
dagster_instance_with_k8s_scheduler,
unset_dagster_home,
helm_namespace_for_k8s_run_launcher,
restore_k8s_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with environ({"DAGSTER_HOME": tempdir}):
local_instance = DagsterInstance.get()
with get_test_external_repo() as external_repo:
# Initialize scheduler
dagster_instance_with_k8s_scheduler.reconcile_scheduler_state(external_repo)
dagster_instance_with_k8s_scheduler.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
local_runs = local_instance.get_runs()
assert len(local_runs) == 0
cron_job_name = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
batch_v1beta1_api = kubernetes.client.BatchV1beta1Api()
cron_job = batch_v1beta1_api.read_namespaced_cron_job(
cron_job_name, helm_namespace_for_k8s_run_launcher
)
container = cron_job.spec.job_template.spec.template.spec.containers[0]
args = container.args
cli_cmd = [sys.executable, "-m"] + args
p = subprocess.Popen(
cli_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env={
"DAGSTER_HOME": tempdir,
"LC_ALL": "C.UTF-8",
"LANG": "C.UTF-8",
}, # https://stackoverflow.com/questions/36651680/click-will-abort-further-execution-because-python-3-was-configured-to-use-ascii
)
stdout, stderr = p.communicate()
print("Command completed with stdout: ", stdout) # pylint: disable=print-call
print("Command completed with stderr: ", stderr) # pylint: disable=print-call
assert p.returncode == 0
local_runs = local_instance.get_runs()
assert len(local_runs) == 1
run_id = local_runs[0].run_id
pipeline_run = local_instance.get_run_by_id(run_id)
assert pipeline_run
assert pipeline_run.status == PipelineRunStatus.SUCCESS
@mark_scheduler
def test_start_schedule_fails(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
def raises(*args, **kwargs):
raise Exception("Patch")
instance._scheduler._api.create_namespaced_cron_job = ( # pylint: disable=protected-access
raises
)
with pytest.raises(Exception, match="Patch"):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
schedule = instance.get_job_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
assert schedule.status == JobStatus.STOPPED
@mark_scheduler
def test_start_schedule_unsuccessful(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
def do_nothing(**_):
pass
instance._scheduler._api.create_namespaced_cron_job = ( # pylint: disable=protected-access
do_nothing
)
# Start schedule
with pytest.raises(
DagsterSchedulerError,
match="Attempted to add K8s CronJob for schedule no_config_pipeline_every_min_schedule, "
"but failed. The schedule no_config_pipeline_every_min_schedule is not running.",
):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
@mark_scheduler
def test_start_schedule_manual_delete_debug(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
instance.scheduler.get_all_cron_jobs()
# Manually delete the schedule
instance.scheduler._end_cron_job( # pylint: disable=protected-access
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id(),
)
# Check debug command
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 1
# Reconcile should fix error
instance.reconcile_scheduler_state(external_repo)
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 0
@mark_scheduler
def test_start_schedule_manual_add_debug(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Manually add the schedule to the crontab
instance.scheduler._start_cron_job( # pylint: disable=protected-access
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# Check debug command
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 1
# Reconcile should fix error
instance.reconcile_scheduler_state(external_repo)
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 0
@mark_scheduler
def test_stop_schedule_fails(
dagster_instance_with_k8s_scheduler,
schedule_tempdir,
helm_namespace_for_k8s_run_launcher,
restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
external_schedule = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
)
schedule_origin_id = external_schedule.get_external_origin_id()
def raises(*args, **kwargs):
raise Exception("Patch")
instance._scheduler._end_cron_job = raises # pylint: disable=protected-access
instance.start_schedule_and_update_storage_state(external_schedule)
assert "schedules" in os.listdir(schedule_tempdir)
# End schedule
with pytest.raises(Exception, match="Patch"):
instance.stop_schedule_and_update_storage_state(schedule_origin_id)
schedule = instance.get_job_state(schedule_origin_id)
assert schedule.status == JobStatus.RUNNING
@mark_scheduler
def test_stop_schedule_unsuccessful(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
def do_nothing(**_):
pass
instance._scheduler._end_cron_job = do_nothing # pylint: disable=protected-access
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# End schedule
with pytest.raises(
DagsterSchedulerError,
match="Attempted to remove existing K8s CronJob for schedule "
"no_config_pipeline_every_min_schedule, but failed. Schedule is still running.",
):
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
@mark_scheduler
def test_wipe(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Start schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# Wipe scheduler
instance.wipe_all_schedules()
# Check schedules are wiped
assert instance.all_stored_job_state(job_type=JobType.SCHEDULE) == []
@mark_scheduler
def test_reconcile_failure(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
instance.reconcile_scheduler_state(external_repo)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
def failed_start_job(*_):
raise DagsterSchedulerError("Failed to start")
def failed_refresh_job(*_):
raise DagsterSchedulerError("Failed to refresh")
def failed_end_job(*_):
raise DagsterSchedulerError("Failed to stop")
instance._scheduler.start_schedule = failed_start_job # pylint: disable=protected-access
instance._scheduler.refresh_schedule = ( # pylint: disable=protected-access
failed_refresh_job
)
instance._scheduler.stop_schedule = failed_end_job # pylint: disable=protected-access
with pytest.raises(
DagsterScheduleReconciliationError,
match="Error 1: Failed to stop\n Error 2: Failed to stop\n Error 3: Failed to refresh",
):
instance.reconcile_scheduler_state(external_repo)
@mark_scheduler
def test_reconcile_failure_when_deleting_schedule_def(
dagster_instance_with_k8s_scheduler, helm_namespace_for_k8s_run_launcher, restore_k8s_cron_tab,
): # pylint:disable=unused-argument
instance = dagster_instance_with_k8s_scheduler
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
def failed_end_job(*_):
raise DagsterSchedulerError("Failed to stop")
instance._scheduler.stop_schedule_and_delete_from_storage = ( # pylint: disable=protected-access
failed_end_job
)
with pytest.raises(
DagsterScheduleReconciliationError, match="Error 1: Failed to stop",
):
with get_smaller_external_repo() as smaller_repo:
instance.reconcile_scheduler_state(smaller_repo)
diff --git a/python_modules/automation/setup.py b/python_modules/automation/setup.py
index 45242b67e..3ff01553b 100644
--- a/python_modules/automation/setup.py
+++ b/python_modules/automation/setup.py
@@ -1,43 +1,43 @@
from setuptools import find_packages, setup
setup(
name="automation",
version="0.0.1",
author="Elementl",
author_email="hello@elementl.com",
license="Apache-2.0",
description="Tools for infrastructure automation",
url="https://github.com/dagster-io/dagster/tree/master/python_modules/automation",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["test"]),
install_requires=[
"autoflake",
"boto3",
"click>=6.7",
"dagster",
"packaging==18.0",
"pandas",
- 'pydantic; python_version >="3"',
- 'pytablereader; python_version >="3"',
+ "pydantic",
+ "pytablereader",
"requests",
- 'slackclient>=2,<3; python_version >="3"',
+ "slackclient>=2,<3",
"twine==1.15.0",
"virtualenv==16.5.0",
"wheel==0.33.6",
"urllib3==1.25.9",
],
entry_points={
"console_scripts": [
"dagster-docs = automation.docs.cli:main",
"dagster-image = automation.docker.cli:main",
"dagster-release = automation.release.cli:main",
"dagster-scaffold = automation.scaffold.cli:main",
"dagster-helm = automation.helm.cli:main",
]
},
)
diff --git a/python_modules/dagit/dagit/app.py b/python_modules/dagit/dagit/app.py
index 1803e23c0..4b86c09ca 100644
--- a/python_modules/dagit/dagit/app.py
+++ b/python_modules/dagit/dagit/app.py
@@ -1,242 +1,242 @@
from __future__ import absolute_import
import gzip
import io
import os
import uuid
import nbformat
from dagster import __version__ as dagster_version
-from dagster import check, seven
+from dagster import check
from dagster.cli.workspace import Workspace
from dagster.core.debug import DebugRunPayload
from dagster.core.execution.compute_logs import warn_if_compute_logs_disabled
from dagster.core.instance import DagsterInstance
from dagster.core.storage.compute_log_manager import ComputeIOType
from dagster_graphql.implementation.context import DagsterGraphQLContext
from dagster_graphql.schema import create_schema
from dagster_graphql.version import __version__ as dagster_graphql_version
from flask import Blueprint, Flask, jsonify, redirect, request, send_file
from flask_cors import CORS
from flask_graphql import GraphQLView
from flask_sockets import Sockets
from graphql.execution.executors.gevent import GeventExecutor as Executor
from nbconvert import HTMLExporter
from .format_error import format_error_with_stack_trace
from .subscription_server import DagsterSubscriptionServer
from .templates.playground import TEMPLATE as PLAYGROUND_TEMPLATE
from .version import __version__
MISSING_SCHEDULER_WARNING = (
"You have defined ScheduleDefinitions for this repository, but have "
"not defined a scheduler on the instance"
)
class DagsterGraphQLView(GraphQLView):
def __init__(self, context, **kwargs):
super(DagsterGraphQLView, self).__init__(**kwargs)
self.context = check.inst_param(context, "context", DagsterGraphQLContext)
def get_context(self):
return self.context
format_error = staticmethod(format_error_with_stack_trace)
def dagster_graphql_subscription_view(subscription_server, context):
context = check.inst_param(context, "context", DagsterGraphQLContext)
def view(ws):
subscription_server.handle(ws, request_context=context)
return []
return view
def info_view():
return (
jsonify(
dagit_version=__version__,
dagster_graphql_version=dagster_graphql_version,
dagster_version=dagster_version,
),
200,
)
def notebook_view(request_args):
check.dict_param(request_args, "request_args")
# This currently provides open access to your file system - the very least we can
# do is limit it to notebook files until we create a more permanent solution.
path = request_args["path"]
if not path.endswith(".ipynb"):
return "Invalid Path", 400
with open(os.path.abspath(path)) as f:
read_data = f.read()
notebook = nbformat.reads(read_data, as_version=4)
html_exporter = HTMLExporter()
html_exporter.template_file = "basic"
(body, resources) = html_exporter.from_notebook_node(notebook)
return "" + body, 200
def download_log_view(context):
context = check.inst_param(context, "context", DagsterGraphQLContext)
def view(run_id, step_key, file_type):
run_id = str(uuid.UUID(run_id)) # raises if not valid run_id
step_key = step_key.split("/")[-1]  # make sure we're not diving deep into the filesystem
out_name = "{}_{}.{}".format(run_id, step_key, file_type)
manager = context.instance.compute_log_manager
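# Serve the captured compute logs for a step: if the local log file exists it is
# streamed back; cache_timeout stays 0 while the step is still being watched so
# clients re-fetch fresh output, and falls back to the Flask default once the
# watch has completed or the requested file type is invalid.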
try:
io_type = ComputeIOType(file_type)
result = manager.get_local_path(run_id, step_key, io_type)
if not os.path.exists(result):
result = io.BytesIO()
timeout = None if manager.is_watch_completed(run_id, step_key) else 0
except ValueError:
result = io.BytesIO()
timeout = 0
if not result:
result = io.BytesIO()
return send_file(
result, as_attachment=True, attachment_filename=out_name, cache_timeout=timeout
)
return view
def download_dump_view(context):
context = check.inst_param(context, "context", DagsterGraphQLContext)
def view(run_id):
run = context.instance.get_run_by_id(run_id)
debug_payload = DebugRunPayload.build(context.instance, run)
check.invariant(run is not None)
out_name = "{}.gzip".format(run_id)
result = io.BytesIO()
with gzip.GzipFile(fileobj=result, mode="wb") as file:
debug_payload.write(file)
result.seek(0) # be kind, please rewind
return send_file(result, as_attachment=True, attachment_filename=out_name)
return view
def instantiate_app_with_views(context, app_path_prefix):
app = Flask(
"dagster-ui",
static_url_path=app_path_prefix,
static_folder=os.path.join(os.path.dirname(__file__), "./webapp/build"),
)
schema = create_schema()
subscription_server = DagsterSubscriptionServer(schema=schema)
# Websocket routes
sockets = Sockets(app)
sockets.add_url_rule(
"{}/graphql".format(app_path_prefix),
"graphql",
dagster_graphql_subscription_view(subscription_server, context),
)
# HTTP routes
bp = Blueprint("routes", __name__, url_prefix=app_path_prefix)
bp.add_url_rule(
"/graphiql", "graphiql", lambda: redirect("{}/graphql".format(app_path_prefix), 301)
)
bp.add_url_rule(
"/graphql",
"graphql",
DagsterGraphQLView.as_view(
"graphql",
schema=schema,
graphiql=True,
graphiql_template=PLAYGROUND_TEMPLATE.replace("APP_PATH_PREFIX", app_path_prefix),
executor=Executor(),
context=context,
),
)
bp.add_url_rule(
# should match the `build_local_download_url`
"/download///",
"download_view",
download_log_view(context),
)
bp.add_url_rule(
"/download_debug/", "download_dump_view", download_dump_view(context),
)
# these routes are specifically for the Dagit UI and are not part of the graphql
# API that we want other people to consume, so they're separate for now.
# Also grabbing the magic global request args dict so that notebook_view is testable
bp.add_url_rule("/dagit/notebook", "notebook", lambda: notebook_view(request.args))
bp.add_url_rule("/dagit_info", "sanity_view", info_view)
index_path = os.path.join(os.path.dirname(__file__), "./webapp/build/index.html")
def index_view(_path):
try:
with open(index_path) as f:
return (
f.read()
.replace('href="/', 'href="{}/'.format(app_path_prefix))
.replace('src="/', 'src="{}/'.format(app_path_prefix))
.replace(
' 0
or click.confirm(
(
"Another process on your machine is already listening on port {port}. "
"Would you like to run the app at another port instead?"
).format(port=port)
)
):
port_lookup_attempts += 1
start_server(
instance,
host,
port + port_lookup_attempts,
path_prefix,
app,
True,
port_lookup_attempts,
)
else:
six.raise_from(
Exception(
(
"Another process on your machine is already listening on port {port}. "
"It is possible that you have another instance of dagit "
"running somewhere using the same port. Or it could be another "
"random process. Either kill that process or use the -p option to "
"select another port."
).format(port=port)
),
os_error,
)
else:
raise os_error
cli = create_dagit_cli()
def main():
# click magic
cli(auto_envvar_prefix="DAGIT") # pylint:disable=E1120
diff --git a/python_modules/dagit/dagit_tests/test_app.py b/python_modules/dagit/dagit_tests/test_app.py
index 6dce035ea..a210431dc 100644
--- a/python_modules/dagit/dagit_tests/test_app.py
+++ b/python_modules/dagit/dagit_tests/test_app.py
@@ -1,308 +1,308 @@
import json
+import tempfile
import pytest
from click.testing import CliRunner
from dagit.app import create_app_from_workspace
from dagit.cli import host_dagit_ui_with_workspace, ui
-from dagster import seven
from dagster.cli.workspace.load import load_workspace_from_yaml_paths
from dagster.core.instance import DagsterInstance
from dagster.core.telemetry import START_DAGIT_WEBSERVER, UPDATE_REPO_STATS, hash_name
from dagster.core.test_utils import instance_for_test_tempdir
from dagster.seven import mock
from dagster.utils import file_relative_path
def test_create_app_with_workspace():
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./workspace.yaml")],
) as workspace:
assert create_app_from_workspace(workspace, DagsterInstance.ephemeral())
def test_create_app_with_multiple_workspace_files():
with load_workspace_from_yaml_paths(
[
file_relative_path(__file__, "./workspace.yaml"),
file_relative_path(__file__, "./override.yaml"),
],
) as workspace:
assert create_app_from_workspace(workspace, DagsterInstance.ephemeral())
def test_create_app_with_workspace_and_scheduler():
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./workspace.yaml")]
) as workspace:
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
}
},
) as instance:
assert create_app_from_workspace(workspace, instance)
def test_notebook_view():
notebook_path = file_relative_path(__file__, "render_uuid_notebook.ipynb")
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./workspace.yaml")]
) as workspace:
with create_app_from_workspace(
workspace, DagsterInstance.ephemeral(),
).test_client() as client:
res = client.get("/dagit/notebook?path={}".format(notebook_path))
assert res.status_code == 200
# This magic guid is hardcoded in the notebook
assert b"6cac0c38-2c97-49ca-887c-4ac43f141213" in res.data
def test_index_view():
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./workspace.yaml")]
) as workspace:
with create_app_from_workspace(
workspace, DagsterInstance.ephemeral(),
).test_client() as client:
res = client.get("/")
assert res.status_code == 200, res.data
assert b"You need to enable JavaScript to run this app" in res.data
def test_index_view_at_path_prefix():
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./workspace.yaml")]
) as workspace:
with create_app_from_workspace(
workspace, DagsterInstance.ephemeral(), "/dagster-path"
).test_client() as client:
# / redirects to prefixed path
res = client.get("/")
assert res.status_code == 301
# index contains the path meta tag
res = client.get("/dagster-path")
assert res.status_code == 200
assert b"You need to enable JavaScript to run this app" in res.data
assert b'=7.0",
"dagster=={ver}".format(ver=ver),
"dagster-graphql=={ver}".format(ver=ver),
# graphql
"graphql-core>=2.1,<3",
# server
"flask-cors>=3.0.6",
"Flask-GraphQL>=2.0.0",
"Flask-Sockets>=0.2.1",
"flask>=0.12.4",
"gevent-websocket>=0.10.1",
"gevent",
"graphql-ws>=0.3.0",
# watchdog
"watchdog>=0.8.3",
# notebooks support
"nbconvert>=5.4.0,<6.0.0",
],
entry_points={
"console_scripts": ["dagit = dagit.cli:main", "dagit-debug = dagit.debug:main"]
},
)
diff --git a/python_modules/dagster-graphql/dagster_graphql/cli.py b/python_modules/dagster-graphql/dagster_graphql/cli.py
index d5bebe9d7..51ffded06 100644
--- a/python_modules/dagster-graphql/dagster_graphql/cli.py
+++ b/python_modules/dagster-graphql/dagster_graphql/cli.py
@@ -1,228 +1,228 @@
from future.standard_library import install_aliases # isort:skip
install_aliases() # isort:skip
import signal
import threading
import warnings
+from urllib.parse import urljoin, urlparse
import click
import requests
from dagster import check, seven
from dagster.cli.workspace import workspace_target_argument
from dagster.cli.workspace.cli_target import WORKSPACE_TARGET_WARNING, get_workspace_from_kwargs
from dagster.cli.workspace.workspace import Workspace
from dagster.core.instance import DagsterInstance
-from dagster.seven import urljoin, urlparse
from dagster.utils import DEFAULT_REPOSITORY_YAML_FILENAME
from dagster.utils.log import get_stack_trace_array
from graphql import graphql
from graphql.execution.executors.gevent import GeventExecutor
from graphql.execution.executors.sync import SyncExecutor
from .client.query import LAUNCH_PIPELINE_EXECUTION_MUTATION
from .implementation.context import DagsterGraphQLContext
from .schema import create_schema
from .version import __version__
def create_dagster_graphql_cli():
return ui
def execute_query(workspace, query, variables=None, use_sync_executor=False, instance=None):
check.inst_param(workspace, "workspace", Workspace)
check.str_param(query, "query")
check.opt_dict_param(variables, "variables")
instance = (
check.inst_param(instance, "instance", DagsterInstance)
if instance
else DagsterInstance.get()
)
check.bool_param(use_sync_executor, "use_sync_executor")
query = query.strip("'\" \n\t")
context = DagsterGraphQLContext(workspace=workspace, instance=instance, version=__version__,)
executor = SyncExecutor() if use_sync_executor else GeventExecutor()
result = graphql(
request_string=query,
schema=create_schema(),
context_value=context,
variable_values=variables,
executor=executor,
)
result_dict = result.to_dict()
# Here we detect if this is in fact an error response
# If so, we iterate over the result_dict and the original result
# which contains a GraphQLError. If that GraphQL error contains
# an original_error property (which is the exception the resolver
# has thrown, typically) we serialize the stack trace of that exception
# in the 'stack_trace' property of each error to ease debugging
if "errors" in result_dict:
check.invariant(len(result_dict["errors"]) == len(result.errors))
for python_error, error_dict in zip(result.errors, result_dict["errors"]):
if hasattr(python_error, "original_error") and python_error.original_error:
error_dict["stack_trace"] = get_stack_trace_array(python_error.original_error)
return result_dict
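# Example (illustrative sketch): calling execute_query directly against a workspace
# loaded from a workspace.yaml. The file path and query string below are placeholders,
# not values taken from this diff.
#
#   from dagster.cli.workspace.load import load_workspace_from_yaml_paths
#   from dagster.core.instance import DagsterInstance
#
#   with load_workspace_from_yaml_paths(["workspace.yaml"]) as workspace:
#       result = execute_query(
#           workspace,
#           "{ __typename }",
#           instance=DagsterInstance.ephemeral(),
#       )
#       for error in result.get("errors", []):
#           # stack_trace is attached by the error handling above
#           print(error.get("stack_trace"))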
def execute_query_from_cli(workspace, query, instance, variables=None, output=None):
check.inst_param(workspace, "workspace", Workspace)
check.str_param(query, "query")
check.inst_param(instance, "instance", DagsterInstance)
check.opt_str_param(variables, "variables")
check.opt_str_param(output, "output")
query = query.strip("'\" \n\t")
result_dict = execute_query(
workspace,
query,
instance=instance,
variables=seven.json.loads(variables) if variables else None,
)
str_res = seven.json.dumps(result_dict)
# Since this is the entry point for CLI execution, some tests depend on us putting the result on
# stdout
if output:
check.str_param(output, "output")
with open(output, "w") as f:
f.write(str_res + "\n")
else:
print(str_res) # pylint: disable=print-call
return str_res
def execute_query_against_remote(host, query, variables):
parsed_url = urlparse(host)
if not (parsed_url.scheme and parsed_url.netloc):
raise click.UsageError(
"Host {host} is not a valid URL. Host URL should include scheme ie http://localhost".format(
host=host
)
)
sanity_check = requests.get(urljoin(host, "/dagit_info"))
sanity_check.raise_for_status()
if "dagit" not in sanity_check.text:
raise click.UsageError(
"Host {host} failed sanity check. It is not a dagit server.".format(host=host)
)
response = requests.post(
urljoin(host, "/graphql"), params={"query": query, "variables": variables}
)
response.raise_for_status()
str_res = response.json()
return str_res
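# Example (illustrative sketch): querying a dagit server that is already running
# locally; the URL and query here are placeholders. The host must serve /dagit_info,
# since the sanity check above requires it.
#
#   execute_query_against_remote(
#       "http://localhost:3000", "{ __typename }", variables=None
#   )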
PREDEFINED_QUERIES = {
"launchPipelineExecution": LAUNCH_PIPELINE_EXECUTION_MUTATION,
}
@workspace_target_argument
@click.command(
name="ui",
help=(
"Run a GraphQL query against the dagster interface to a specified repository or pipeline."
"\n\n{warning}".format(warning=WORKSPACE_TARGET_WARNING)
)
+ (
"\n\nExamples:"
"\n\n1. dagster-graphql"
"\n\n2. dagster-graphql -y path/to/{default_filename}"
"\n\n3. dagster-graphql -f path/to/file.py -a define_repo"
"\n\n4. dagster-graphql -m some_module -a define_repo"
"\n\n5. dagster-graphql -f path/to/file.py -a define_pipeline"
"\n\n6. dagster-graphql -m some_module -a define_pipeline"
).format(default_filename=DEFAULT_REPOSITORY_YAML_FILENAME),
)
@click.version_option(version=__version__)
@click.option(
"--text", "-t", type=click.STRING, help="GraphQL document to execute passed as a string"
)
@click.option(
"--file", "-f", type=click.File(), help="GraphQL document to execute passed as a file"
)
@click.option(
"--predefined",
"-p",
type=click.Choice(PREDEFINED_QUERIES.keys()),
help="GraphQL document to execute, from a predefined set provided by dagster-graphql.",
)
@click.option(
"--variables",
"-v",
type=click.STRING,
help="A JSON encoded string containing the variables for GraphQL execution.",
)
@click.option(
"--remote",
"-r",
type=click.STRING,
help="A URL for a remote instance running dagit server to send the GraphQL request to.",
)
@click.option(
"--output",
"-o",
type=click.STRING,
help="A file path to store the GraphQL response to. This flag is useful when making pipeline "
"execution queries, since pipeline execution causes logs to print to stdout and stderr.",
)
@click.option(
"--remap-sigterm", is_flag=True, default=False, help="Remap SIGTERM signal to SIGINT handler",
)
def ui(text, file, predefined, variables, remote, output, remap_sigterm, **kwargs):
query = None
if text is not None and file is None and predefined is None:
query = text.strip("'\" \n\t")
elif file is not None and text is None and predefined is None:
query = file.read()
elif predefined is not None and text is None and file is None:
query = PREDEFINED_QUERIES[predefined]
else:
raise click.UsageError(
"Must select one and only one of text (-t), file (-f), or predefined (-p) "
"to select GraphQL document to execute."
)
if remap_sigterm:
try:
signal.signal(signal.SIGTERM, signal.getsignal(signal.SIGINT))
except ValueError:
warnings.warn(
(
"Unexpected error attempting to manage signal handling on thread {thread_name}. "
"You should not invoke this API (ui) from threads "
"other than the main thread."
).format(thread_name=threading.current_thread().name)
)
if remote:
res = execute_query_against_remote(remote, query, variables)
print(res) # pylint: disable=print-call
else:
instance = DagsterInstance.get()
with get_workspace_from_kwargs(kwargs) as workspace:
execute_query_from_cli(
workspace, query, instance, variables, output,
)
cli = create_dagster_graphql_cli()
def main():
# click magic
cli(obj={}) # pylint:disable=E1120
diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/pipelines.py b/python_modules/dagster-graphql/dagster_graphql/schema/pipelines.py
index 59c6a2fa1..281c7badd 100644
--- a/python_modules/dagster-graphql/dagster_graphql/schema/pipelines.py
+++ b/python_modules/dagster-graphql/dagster_graphql/schema/pipelines.py
@@ -1,412 +1,411 @@
-from __future__ import absolute_import
+from functools import lru_cache
import yaml
from dagster import check
from dagster.core.host_representation import (
ExternalPipeline,
ExternalPresetData,
RepresentedPipeline,
)
from dagster.core.snap import ConfigSchemaSnapshot, LoggerDefSnap, ModeDefSnap, ResourceDefSnap
from dagster.core.storage.pipeline_run import PipelineRunsFilter
from dagster.core.storage.tags import TagType, get_tag_type
-from dagster.seven import lru_cache
from dagster_graphql import dauphin
from dagster_graphql.implementation.fetch_runs import get_runs
from dagster_graphql.implementation.fetch_schedules import get_schedules_for_pipeline
from dagster_graphql.implementation.utils import UserFacingGraphQLError, capture_dauphin_error
from .config_types import DauphinConfigTypeField
from .dagster_types import to_dauphin_dagster_type
from .solids import DauphinSolidContainer, build_dauphin_solid_handles, build_dauphin_solids
class DauphinPipelineReference(dauphin.Interface):
"""This interface supports the case where we can look up a pipeline successfully in the
repository available to the DagsterInstance/graphql context, as well as the case where we know
that a pipeline exists/existed thanks to materialized data such as logs and run metadata, but
where we can't look the concrete pipeline up."""
class Meta:
name = "PipelineReference"
name = dauphin.NonNull(dauphin.String)
solidSelection = dauphin.List(dauphin.NonNull(dauphin.String))
class DauphinUnknownPipeline(dauphin.ObjectType):
class Meta:
name = "UnknownPipeline"
interfaces = (DauphinPipelineReference,)
name = dauphin.NonNull(dauphin.String)
solidSelection = dauphin.List(dauphin.NonNull(dauphin.String))
class DauphinIPipelineSnapshotMixin:
# Mix this class in to implement IPipelineSnapshot
#
# Graphene has some strange properties that make it so that you cannot
# implement ABCs nor use properties in an overridable way. So the way
# the mixin works is that the target classes have to have a method
# get_represented_pipeline()
#
def get_represented_pipeline(self):
raise NotImplementedError()
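# Sketch of the pattern described above (the class name here is illustrative):
# a concrete Dauphin type mixes this class in and supplies
# get_represented_pipeline(), exactly as DauphinPipeline and
# DauphinPipelineSnapshot do further below.
#
#   class DauphinSomePipelineView(DauphinIPipelineSnapshotMixin, dauphin.ObjectType):
#       class Meta:
#           name = "SomePipelineView"
#
#       def __init__(self, represented_pipeline):
#           self._represented_pipeline = represented_pipeline
#
#       def get_represented_pipeline(self):
#           return self._represented_pipeline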
name = dauphin.NonNull(dauphin.String)
description = dauphin.String()
id = dauphin.NonNull(dauphin.ID)
pipeline_snapshot_id = dauphin.NonNull(dauphin.String)
dagster_types = dauphin.non_null_list("DagsterType")
dagster_type_or_error = dauphin.Field(
dauphin.NonNull("DagsterTypeOrError"),
dagsterTypeName=dauphin.Argument(dauphin.NonNull(dauphin.String)),
)
solids = dauphin.non_null_list("Solid")
modes = dauphin.non_null_list("Mode")
solid_handles = dauphin.Field(
dauphin.non_null_list("SolidHandle"), parentHandleID=dauphin.String()
)
solid_handle = dauphin.Field(
"SolidHandle", handleID=dauphin.Argument(dauphin.NonNull(dauphin.String)),
)
tags = dauphin.non_null_list("PipelineTag")
runs = dauphin.Field(
dauphin.non_null_list("PipelineRun"), cursor=dauphin.String(), limit=dauphin.Int(),
)
schedules = dauphin.non_null_list("Schedule")
parent_snapshot_id = dauphin.String()
def resolve_pipeline_snapshot_id(self, _):
return self.get_represented_pipeline().identifying_pipeline_snapshot_id
def resolve_id(self, _):
return self.get_represented_pipeline().identifying_pipeline_snapshot_id
def resolve_name(self, _):
return self.get_represented_pipeline().name
def resolve_description(self, _):
return self.get_represented_pipeline().description
def resolve_dagster_types(self, _graphene_info):
represented_pipeline = self.get_represented_pipeline()
return sorted(
list(
map(
lambda dt: to_dauphin_dagster_type(
represented_pipeline.pipeline_snapshot, dt.key
),
[t for t in represented_pipeline.dagster_type_snaps if t.name],
)
),
key=lambda dagster_type: dagster_type.name,
)
@capture_dauphin_error
def resolve_dagster_type_or_error(self, _, **kwargs):
type_name = kwargs["dagsterTypeName"]
represented_pipeline = self.get_represented_pipeline()
if not represented_pipeline.has_dagster_type_named(type_name):
from .errors import DauphinDagsterTypeNotFoundError
raise UserFacingGraphQLError(
DauphinDagsterTypeNotFoundError(dagster_type_name=type_name)
)
return to_dauphin_dagster_type(
represented_pipeline.pipeline_snapshot,
represented_pipeline.get_dagster_type_by_name(type_name).key,
)
def resolve_solids(self, _graphene_info):
represented_pipeline = self.get_represented_pipeline()
return build_dauphin_solids(represented_pipeline, represented_pipeline.dep_structure_index,)
def resolve_modes(self, _):
represented_pipeline = self.get_represented_pipeline()
return [
DauphinMode(represented_pipeline.config_schema_snapshot, mode_def_snap)
for mode_def_snap in sorted(
represented_pipeline.mode_def_snaps, key=lambda item: item.name
)
]
def resolve_solid_handle(self, _graphene_info, handleID):
return _get_solid_handles(self.get_represented_pipeline()).get(handleID)
def resolve_solid_handles(self, _graphene_info, **kwargs):
handles = _get_solid_handles(self.get_represented_pipeline())
parentHandleID = kwargs.get("parentHandleID")
if parentHandleID == "":
handles = {key: handle for key, handle in handles.items() if not handle.parent}
elif parentHandleID is not None:
handles = {
key: handle
for key, handle in handles.items()
if handle.parent and handle.parent.handleID.to_string() == parentHandleID
}
return [handles[key] for key in sorted(handles)]
def resolve_tags(self, graphene_info):
represented_pipeline = self.get_represented_pipeline()
return [
graphene_info.schema.type_named("PipelineTag")(key=key, value=value)
for key, value in represented_pipeline.pipeline_snapshot.tags.items()
]
def resolve_solidSelection(self, _graphene_info):
return self.get_represented_pipeline().solid_selection
def resolve_runs(self, graphene_info, **kwargs):
runs_filter = PipelineRunsFilter(pipeline_name=self.get_represented_pipeline().name)
return get_runs(graphene_info, runs_filter, kwargs.get("cursor"), kwargs.get("limit"))
def resolve_schedules(self, graphene_info):
represented_pipeline = self.get_represented_pipeline()
if not isinstance(represented_pipeline, ExternalPipeline):
# this is an historical pipeline snapshot, so there are not any associated running
# schedules
return []
pipeline_selector = represented_pipeline.handle.to_selector()
schedules = get_schedules_for_pipeline(graphene_info, pipeline_selector)
return schedules
def resolve_parent_snapshot_id(self, _graphene_info):
lineage_snapshot = self.get_represented_pipeline().pipeline_snapshot.lineage_snapshot
if lineage_snapshot:
return lineage_snapshot.parent_snapshot_id
else:
return None
class DauphinIPipelineSnapshot(dauphin.Interface):
class Meta:
name = "IPipelineSnapshot"
name = dauphin.NonNull(dauphin.String)
description = dauphin.String()
pipeline_snapshot_id = dauphin.NonNull(dauphin.String)
dagster_types = dauphin.non_null_list("DagsterType")
dagster_type_or_error = dauphin.Field(
dauphin.NonNull("DagsterTypeOrError"),
dagsterTypeName=dauphin.Argument(dauphin.NonNull(dauphin.String)),
)
solids = dauphin.non_null_list("Solid")
modes = dauphin.non_null_list("Mode")
solid_handles = dauphin.Field(
dauphin.non_null_list("SolidHandle"), parentHandleID=dauphin.String()
)
solid_handle = dauphin.Field(
"SolidHandle", handleID=dauphin.Argument(dauphin.NonNull(dauphin.String)),
)
tags = dauphin.non_null_list("PipelineTag")
class DauphinPipeline(DauphinIPipelineSnapshotMixin, dauphin.ObjectType):
class Meta:
name = "Pipeline"
interfaces = (DauphinSolidContainer, DauphinIPipelineSnapshot)
id = dauphin.NonNull(dauphin.ID)
presets = dauphin.non_null_list("PipelinePreset")
runs = dauphin.Field(
dauphin.non_null_list("PipelineRun"), cursor=dauphin.String(), limit=dauphin.Int(),
)
def __init__(self, external_pipeline):
self._external_pipeline = check.inst_param(
external_pipeline, "external_pipeline", ExternalPipeline
)
def resolve_id(self, _graphene_info):
return self._external_pipeline.get_external_origin_id()
def get_represented_pipeline(self):
return self._external_pipeline
def resolve_presets(self, _graphene_info):
return [
DauphinPipelinePreset(preset, self._external_pipeline.name)
for preset in sorted(self._external_pipeline.active_presets, key=lambda item: item.name)
]
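# Cache the handle lookup per represented pipeline: lru_cache keys on the
# RepresentedPipeline argument, so repeated resolve_solid_handle(s) calls for the
# same snapshot reuse the dict built below instead of rebuilding it.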
@lru_cache(maxsize=32)
def _get_solid_handles(represented_pipeline):
check.inst_param(represented_pipeline, "represented_pipeline", RepresentedPipeline)
return {
str(item.handleID): item
for item in build_dauphin_solid_handles(
represented_pipeline, represented_pipeline.dep_structure_index
)
}
class DauphinResource(dauphin.ObjectType):
class Meta:
name = "Resource"
def __init__(self, config_schema_snapshot, resource_def_snap):
self._config_schema_snapshot = check.inst_param(
config_schema_snapshot, "config_schema_snapshot", ConfigSchemaSnapshot
)
self._resource_dep_snap = check.inst_param(
resource_def_snap, "resource_def_snap", ResourceDefSnap
)
self.name = resource_def_snap.name
self.description = resource_def_snap.description
name = dauphin.NonNull(dauphin.String)
description = dauphin.String()
configField = dauphin.Field("ConfigTypeField")
def resolve_configField(self, _):
return (
DauphinConfigTypeField(
config_schema_snapshot=self._config_schema_snapshot,
field_snap=self._resource_dep_snap.config_field_snap,
)
if self._resource_dep_snap.config_field_snap
else None
)
class DauphinLogger(dauphin.ObjectType):
class Meta:
name = "Logger"
def __init__(self, config_schema_snapshot, logger_def_snap):
self._config_schema_snapshot = check.inst_param(
config_schema_snapshot, "config_schema_snapshot", ConfigSchemaSnapshot
)
self._logger_def_snap = check.inst_param(logger_def_snap, "logger_def_snap", LoggerDefSnap)
self.name = logger_def_snap.name
self.description = logger_def_snap.description
name = dauphin.NonNull(dauphin.String)
description = dauphin.String()
configField = dauphin.Field("ConfigTypeField")
def resolve_configField(self, _):
return (
DauphinConfigTypeField(
config_schema_snapshot=self._config_schema_snapshot,
field_snap=self._logger_def_snap.config_field_snap,
)
if self._logger_def_snap.config_field_snap
else None
)
class DauphinMode(dauphin.ObjectType):
def __init__(self, config_schema_snapshot, mode_def_snap):
self._mode_def_snap = check.inst_param(mode_def_snap, "mode_def_snap", ModeDefSnap)
self._config_schema_snapshot = check.inst_param(
config_schema_snapshot, "config_schema_snapshot", ConfigSchemaSnapshot
)
class Meta:
name = "Mode"
name = dauphin.NonNull(dauphin.String)
description = dauphin.String()
resources = dauphin.non_null_list("Resource")
loggers = dauphin.non_null_list("Logger")
def resolve_name(self, _graphene_info):
return self._mode_def_snap.name
def resolve_description(self, _graphene_info):
return self._mode_def_snap.description
def resolve_resources(self, _graphene_info):
return [
DauphinResource(self._config_schema_snapshot, resource_def_snap)
for resource_def_snap in sorted(self._mode_def_snap.resource_def_snaps)
]
def resolve_loggers(self, _graphene_info):
return [
DauphinLogger(self._config_schema_snapshot, logger_def_snap)
for logger_def_snap in sorted(self._mode_def_snap.logger_def_snaps)
]
class DauphinMetadataItemDefinition(dauphin.ObjectType):
class Meta:
name = "MetadataItemDefinition"
key = dauphin.NonNull(dauphin.String)
value = dauphin.NonNull(dauphin.String)
class DauphinPipelinePreset(dauphin.ObjectType):
class Meta:
name = "PipelinePreset"
name = dauphin.NonNull(dauphin.String)
solidSelection = dauphin.List(dauphin.NonNull(dauphin.String))
runConfigYaml = dauphin.NonNull(dauphin.String)
mode = dauphin.NonNull(dauphin.String)
tags = dauphin.non_null_list("PipelineTag")
def __init__(self, active_preset_data, pipeline_name):
self._active_preset_data = check.inst_param(
active_preset_data, "active_preset_data", ExternalPresetData
)
self._pipeline_name = check.str_param(pipeline_name, "pipeline_name")
def resolve_name(self, _graphene_info):
return self._active_preset_data.name
def resolve_solidSelection(self, _graphene_info):
return self._active_preset_data.solid_selection
def resolve_runConfigYaml(self, _graphene_info):
yaml_str = yaml.safe_dump(
self._active_preset_data.run_config, default_flow_style=False, allow_unicode=True
)
return yaml_str if yaml_str else ""
def resolve_mode(self, _graphene_info):
return self._active_preset_data.mode
def resolve_tags(self, graphene_info):
return [
graphene_info.schema.type_named("PipelineTag")(key=key, value=value)
for key, value in self._active_preset_data.tags.items()
if get_tag_type(key) != TagType.HIDDEN
]
class DauphinPipelineSnapshot(DauphinIPipelineSnapshotMixin, dauphin.ObjectType):
def __init__(self, represented_pipeline):
self._represented_pipeline = check.inst_param(
represented_pipeline, "represented_pipeline", RepresentedPipeline
)
class Meta:
name = "PipelineSnapshot"
interfaces = (DauphinIPipelineSnapshot, DauphinPipelineReference)
def get_represented_pipeline(self):
return self._represented_pipeline
class DauphinPipelineSnapshotOrError(dauphin.Union):
class Meta:
name = "PipelineSnapshotOrError"
types = (
"PipelineSnapshot",
"PipelineSnapshotNotFoundError",
"PipelineNotFoundError",
"PythonError",
)
diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/conftest.py b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/conftest.py
index e9b30f2b8..71a93f210 100644
--- a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/conftest.py
+++ b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/conftest.py
@@ -1,38 +1,39 @@
+import tempfile
+
import pytest
-from dagster import seven
from dagster.core.test_utils import instance_for_test_tempdir
from .setup import define_test_in_process_context, define_test_out_of_process_context
@pytest.yield_fixture(scope="function")
def graphql_context():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
}
},
) as instance:
with define_test_out_of_process_context(instance) as context:
yield context
@pytest.yield_fixture(scope="function")
def graphql_in_process_context():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
}
},
) as instance:
yield define_test_in_process_context(instance)
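# Sketch (illustrative test name): tests in this package consume these fixtures
# by name, e.g.
#
#   def test_uses_graphql_context(graphql_context):
#       assert graphql_context.instance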
diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
index 064306f27..04882a937 100644
--- a/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
+++ b/python_modules/dagster-graphql/dagster_graphql_tests/graphql/graphql_context_test_suite.py
@@ -1,828 +1,829 @@
import sys
+import tempfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
import pytest
-from dagster import check, file_relative_path, seven
+from dagster import check, file_relative_path
from dagster.cli.workspace import Workspace
from dagster.core.definitions.reconstructable import ReconstructableRepository
from dagster.core.host_representation import (
GrpcServerRepositoryLocationOrigin,
InProcessRepositoryLocationOrigin,
ManagedGrpcPythonEnvRepositoryLocationOrigin,
)
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.event_log import InMemoryEventLogStorage
from dagster.core.storage.event_log.sqlite import ConsolidatedSqliteEventLogStorage
from dagster.core.storage.local_compute_log_manager import LocalComputeLogManager
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import InMemoryRunStorage
from dagster.core.storage.schedules.sqlite.sqlite_schedule_storage import SqliteScheduleStorage
from dagster.core.test_utils import ExplodingRunLauncher, instance_for_test_tempdir
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.grpc.server import GrpcServerProcess
from dagster.utils import merge_dicts
from dagster.utils.test.postgres_instance import TestPostgresInstance
from dagster_graphql.implementation.context import DagsterGraphQLContext
def get_main_recon_repo():
return ReconstructableRepository.for_file(file_relative_path(__file__, "setup.py"), "test_repo")
@contextmanager
def graphql_postgres_instance(overrides):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with TestPostgresInstance.docker_service_up_or_skip(
file_relative_path(__file__, "docker-compose.yml"), "test-postgres-db-graphql",
) as pg_conn_string:
TestPostgresInstance.clean_run_storage(pg_conn_string)
TestPostgresInstance.clean_event_log_storage(pg_conn_string)
TestPostgresInstance.clean_schedule_storage(pg_conn_string)
with instance_for_test_tempdir(
temp_dir,
overrides=merge_dicts(
{
"run_storage": {
"module": "dagster_postgres.run_storage.run_storage",
"class": "PostgresRunStorage",
"config": {"postgres_url": pg_conn_string},
},
"event_log_storage": {
"module": "dagster_postgres.event_log.event_log",
"class": "PostgresEventLogStorage",
"config": {"postgres_url": pg_conn_string},
},
"schedule_storage": {
"module": "dagster_postgres.schedule_storage.schedule_storage",
"class": "PostgresScheduleStorage",
"config": {"postgres_url": pg_conn_string},
},
},
overrides if overrides else {},
),
) as instance:
yield instance
class MarkedManager:
"""
MarkedManagers are passed to GraphQLContextVariants. They contain
a contextmanager function "manager_fn" that yields the relevant
instance, along with marks that will be applied to any
context-variant-driven test case that includes this MarkedManager.
See InstanceManagers for an example construction.
See GraphQLContextVariant for further information
"""
def __init__(self, manager_fn, marks):
self.manager_fn = check.callable_param(manager_fn, "manager_fn")
self.marks = check.list_param(marks, "marks")
class InstanceManagers:
@staticmethod
def in_memory_instance():
@contextmanager
def _in_memory_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
yield DagsterInstance(
instance_type=InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=InMemoryRunStorage(),
event_storage=InMemoryEventLogStorage(),
compute_log_manager=LocalComputeLogManager(temp_dir),
run_launcher=SyncInMemoryRunLauncher(),
run_coordinator=DefaultRunCoordinator(),
schedule_storage=SqliteScheduleStorage.from_local(temp_dir),
)
return MarkedManager(_in_memory_instance, [Marks.in_memory_instance])
@staticmethod
def readonly_in_memory_instance():
@contextmanager
def _readonly_in_memory_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
yield DagsterInstance(
instance_type=InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=InMemoryRunStorage(),
event_storage=InMemoryEventLogStorage(),
compute_log_manager=LocalComputeLogManager(temp_dir),
run_launcher=ExplodingRunLauncher(),
run_coordinator=DefaultRunCoordinator(),
schedule_storage=SqliteScheduleStorage.from_local(temp_dir),
)
return MarkedManager(
_readonly_in_memory_instance, [Marks.in_memory_instance, Marks.readonly],
)
@staticmethod
def readonly_sqlite_instance():
@contextmanager
def _readonly_sqlite_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
},
"run_launcher": {
"module": "dagster.core.test_utils",
"class": "ExplodingRunLauncher",
},
},
) as instance:
yield instance
return MarkedManager(_readonly_sqlite_instance, [Marks.sqlite_instance, Marks.readonly])
@staticmethod
def readonly_postgres_instance():
@contextmanager
def _readonly_postgres_instance():
with graphql_postgres_instance(
overrides={
"run_launcher": {
"module": "dagster.core.test_utils",
"class": "ExplodingRunLauncher",
}
}
) as instance:
yield instance
return MarkedManager(
_readonly_postgres_instance, [Marks.postgres_instance, Marks.readonly],
)
@staticmethod
def sqlite_instance_with_sync_run_launcher():
@contextmanager
def _sqlite_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
},
"run_launcher": {
"module": "dagster.core.launcher.sync_in_memory_run_launcher",
"class": "SyncInMemoryRunLauncher",
},
},
) as instance:
yield instance
return MarkedManager(_sqlite_instance, [Marks.sqlite_instance, Marks.sync_run_launcher])
# Runs launched with this instance won't actually execute since the graphql test suite
# doesn't run the daemon process that launches queued runs
@staticmethod
def sqlite_instance_with_queued_run_coordinator():
@contextmanager
def _sqlite_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"run_coordinator": {
"module": "dagster.core.run_coordinator.queued_run_coordinator",
"class": "QueuedRunCoordinator",
},
},
) as instance:
yield instance
return MarkedManager(
_sqlite_instance, [Marks.sqlite_instance, Marks.queued_run_coordinator]
)
@staticmethod
def sqlite_instance_with_default_run_launcher():
@contextmanager
def _sqlite_instance_with_default_hijack():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
},
"run_launcher": {"module": "dagster", "class": "DefaultRunLauncher",},
},
) as instance:
yield instance
return MarkedManager(
_sqlite_instance_with_default_hijack,
[Marks.sqlite_instance, Marks.default_run_launcher],
)
@staticmethod
def postgres_instance_with_sync_run_launcher():
@contextmanager
def _postgres_instance():
with graphql_postgres_instance(
overrides={
"run_launcher": {
"module": "dagster.core.launcher.sync_in_memory_run_launcher",
"class": "SyncInMemoryRunLauncher",
}
}
) as instance:
yield instance
return MarkedManager(
_postgres_instance, [Marks.postgres_instance, Marks.sync_run_launcher],
)
@staticmethod
def postgres_instance_with_default_run_launcher():
@contextmanager
def _postgres_instance_with_default_hijack():
with graphql_postgres_instance(
overrides={"run_launcher": {"module": "dagster", "class": "DefaultRunLauncher",},}
) as instance:
yield instance
return MarkedManager(
_postgres_instance_with_default_hijack,
[Marks.postgres_instance, Marks.default_run_launcher],
)
@staticmethod
def asset_aware_sqlite_instance():
@contextmanager
def _sqlite_asset_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
instance = DagsterInstance(
instance_type=InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=InMemoryRunStorage(),
event_storage=ConsolidatedSqliteEventLogStorage(temp_dir),
compute_log_manager=LocalComputeLogManager(temp_dir),
run_coordinator=DefaultRunCoordinator(),
run_launcher=SyncInMemoryRunLauncher(),
)
yield instance
return MarkedManager(_sqlite_asset_instance, [Marks.asset_aware_instance])
class EnvironmentManagers:
@staticmethod
def user_code_in_host_process():
@contextmanager
def _mgr_fn(recon_repo):
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
with Workspace([InProcessRepositoryLocationOrigin(recon_repo)]) as workspace:
yield workspace
return MarkedManager(_mgr_fn, [Marks.hosted_user_process_env])
@staticmethod
def managed_grpc():
@contextmanager
def _mgr_fn(recon_repo):
"""Goes out of process via grpc"""
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
loadable_target_origin = recon_repo.get_python_origin().loadable_target_origin
with Workspace(
[
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=loadable_target_origin, location_name="test",
)
]
) as workspace:
yield workspace
return MarkedManager(_mgr_fn, [Marks.managed_grpc_env])
@staticmethod
def deployed_grpc():
@contextmanager
def _mgr_fn(recon_repo):
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
loadable_target_origin = recon_repo.get_python_origin().loadable_target_origin
server_process = GrpcServerProcess(loadable_target_origin=loadable_target_origin)
try:
with server_process.create_ephemeral_client() as api_client:
with Workspace(
[
GrpcServerRepositoryLocationOrigin(
port=api_client.port,
socket=api_client.socket,
host=api_client.host,
location_name="test",
)
]
) as workspace:
yield workspace
finally:
server_process.wait()
return MarkedManager(_mgr_fn, [Marks.deployed_grpc_env])
@staticmethod
def multi_location():
@contextmanager
def _mgr_fn(recon_repo):
"""Goes out of process but same process as host process"""
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
with Workspace(
[
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "setup.py"),
attribute="test_repo",
),
location_name="test",
),
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "setup.py"),
attribute="empty_repo",
),
location_name="empty_repo",
),
]
) as workspace:
yield workspace
return MarkedManager(_mgr_fn, [Marks.multi_location])
class Marks:
# Instance type marks
in_memory_instance = pytest.mark.in_memory_instance
sqlite_instance = pytest.mark.sqlite_instance
postgres_instance = pytest.mark.postgres_instance
# Run launcher variants
sync_run_launcher = pytest.mark.sync_run_launcher
default_run_launcher = pytest.mark.default_run_launcher
queued_run_coordinator = pytest.mark.queued_run_coordinator
readonly = pytest.mark.readonly
# Repository Location marks
hosted_user_process_env = pytest.mark.hosted_user_process_env
multi_location = pytest.mark.multi_location
managed_grpc_env = pytest.mark.managed_grpc_env
deployed_grpc_env = pytest.mark.deployed_grpc_env
# Asset-aware sqlite variants
asset_aware_instance = pytest.mark.asset_aware_instance
# Common mark to all test suite tests
graphql_context_test_suite = pytest.mark.graphql_context_test_suite
def none_manager():
@contextmanager
def _yield_none(*_args, **_kwargs):
yield None
return MarkedManager(_yield_none, [])
class GraphQLContextVariant:
"""
An instance of this class represents a context variant that will be run
against *every* method in the test class, defined as a class
created by inheriting from make_graphql_context_test_suite.
It comes with a number of static methods with prebuilt context variants.
e.g. in_memory_in_process_start
One can also make bespoke context variants, provided you configure it properly
with MarkedManagers that produce its members (see the sketch below this docstring).
Args:
marked_instance_mgr (MarkedManager): The manager_fn
within it must be a contextmanager that takes zero arguments and yields
a DagsterInstance
See InstanceManagers for examples
marked_environment_mgr (MarkedManager): The manager_fn within it
must be a contextmanager that takes a default ReconstructableRepository and
yields a Workspace.
See EnvironmentManagers for examples
test_id [Optional] (str): This assigns a test_id to test parameterized with this
variant. This is highly convenient for running a particular variant across
the entire test suite, without running all the other variants.
e.g.
pytest python_modules/dagster-graphql/dagster_graphql_tests/ -s -k in_memory_in_process_start
Will run all tests that use the in_memory_in_process_start variant, which will get a lot
of code coverage while being very fast to run.
All tests managed by this system are marked with "graphql_context_test_suite".
"""
def __init__(self, marked_instance_mgr, marked_environment_mgr, test_id=None):
self.marked_instance_mgr = check.inst_param(
marked_instance_mgr, "marked_instance_mgr", MarkedManager
)
self.marked_environment_mgr = check.inst_param(
marked_environment_mgr, "marked_environment_mgr", MarkedManager
)
self.test_id = check.opt_str_param(test_id, "test_id")
self.marks = marked_instance_mgr.marks + marked_environment_mgr.marks
@property
def instance_mgr(self):
return self.marked_instance_mgr.manager_fn
@property
def environment_mgr(self):
return self.marked_environment_mgr.manager_fn
@staticmethod
def in_memory_instance_in_process_env():
"""
Good for tests with read-only metadata queries. Does not work
if you have to go through the run launcher.
"""
return GraphQLContextVariant(
InstanceManagers.in_memory_instance(),
EnvironmentManagers.user_code_in_host_process(),
test_id="in_memory_instance_in_process_env",
)
@staticmethod
def in_memory_instance_managed_grpc_env():
"""
Good for tests with read-only metadata queries. Does not work
if you have to go through the run launcher.
"""
return GraphQLContextVariant(
InstanceManagers.in_memory_instance(),
EnvironmentManagers.managed_grpc(),
test_id="in_memory_instance_managed_grpc_env",
)
@staticmethod
def sqlite_with_sync_run_launcher_in_process_env():
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_sync_run_launcher(),
EnvironmentManagers.user_code_in_host_process(),
test_id="sqlite_with_sync_run_launcher_in_process_env",
)
@staticmethod
def sqlite_with_default_run_launcher_in_process_env():
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_default_run_launcher(),
EnvironmentManagers.user_code_in_host_process(),
test_id="sqlite_with_default_run_launcher_in_process_env",
)
@staticmethod
def sqlite_with_queued_run_coordinator_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_queued_run_coordinator(),
EnvironmentManagers.managed_grpc(),
test_id="sqlite_with_queued_run_coordinator_managed_grpc_env",
)
@staticmethod
def sqlite_with_default_run_launcher_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_default_run_launcher(),
EnvironmentManagers.managed_grpc(),
test_id="sqlite_with_default_run_launcher_managed_grpc_env",
)
@staticmethod
def sqlite_with_default_run_launcher_deployed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.sqlite_instance_with_default_run_launcher(),
EnvironmentManagers.deployed_grpc(),
test_id="sqlite_with_default_run_launcher_deployed_grpc_env",
)
@staticmethod
def postgres_with_sync_run_launcher_in_process_env():
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_sync_run_launcher(),
EnvironmentManagers.user_code_in_host_process(),
test_id="postgres_with_sync_run_launcher_in_process_env",
)
@staticmethod
def postgres_with_default_run_launcher_in_process_env():
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_default_run_launcher(),
EnvironmentManagers.user_code_in_host_process(),
test_id="postgres_with_default_run_launcher_in_process_env",
)
@staticmethod
def postgres_with_default_run_launcher_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_default_run_launcher(),
EnvironmentManagers.managed_grpc(),
test_id="postgres_with_default_run_launcher_managed_grpc_env",
)
@staticmethod
def postgres_with_default_run_launcher_deployed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.postgres_instance_with_default_run_launcher(),
EnvironmentManagers.deployed_grpc(),
test_id="postgres_with_default_run_launcher_deployed_grpc_env",
)
@staticmethod
def readonly_sqlite_instance_in_process_env():
return GraphQLContextVariant(
InstanceManagers.readonly_sqlite_instance(),
EnvironmentManagers.user_code_in_host_process(),
test_id="readonly_sqlite_instance_in_process_env",
)
@staticmethod
def readonly_sqlite_instance_multi_location():
return GraphQLContextVariant(
InstanceManagers.readonly_sqlite_instance(),
EnvironmentManagers.multi_location(),
test_id="readonly_sqlite_instance_multi_location",
)
@staticmethod
def readonly_sqlite_instance_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.readonly_sqlite_instance(),
EnvironmentManagers.managed_grpc(),
test_id="readonly_sqlite_instance_managed_grpc_env",
)
@staticmethod
def readonly_sqlite_instance_deployed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.readonly_sqlite_instance(),
EnvironmentManagers.deployed_grpc(),
test_id="readonly_sqlite_instance_deployed_grpc_env",
)
@staticmethod
def readonly_postgres_instance_in_process_env():
return GraphQLContextVariant(
InstanceManagers.readonly_postgres_instance(),
EnvironmentManagers.user_code_in_host_process(),
test_id="readonly_postgres_instance_in_process_env",
)
@staticmethod
def readonly_postgres_instance_multi_location():
return GraphQLContextVariant(
InstanceManagers.readonly_postgres_instance(),
EnvironmentManagers.multi_location(),
test_id="readonly_postgres_instance_multi_location",
)
@staticmethod
def readonly_postgres_instance_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.readonly_postgres_instance(),
EnvironmentManagers.managed_grpc(),
test_id="readonly_postgres_instance_managed_grpc_env",
)
@staticmethod
def readonly_in_memory_instance_in_process_env():
return GraphQLContextVariant(
InstanceManagers.readonly_in_memory_instance(),
EnvironmentManagers.user_code_in_host_process(),
test_id="readonly_in_memory_instance_in_process_env",
)
@staticmethod
def readonly_in_memory_instance_multi_location():
return GraphQLContextVariant(
InstanceManagers.readonly_in_memory_instance(),
EnvironmentManagers.multi_location(),
test_id="readonly_in_memory_instance_multi_location",
)
@staticmethod
def readonly_in_memory_instance_managed_grpc_env():
return GraphQLContextVariant(
InstanceManagers.readonly_in_memory_instance(),
EnvironmentManagers.managed_grpc(),
test_id="readonly_in_memory_instance_managed_grpc_env",
)
@staticmethod
def asset_aware_sqlite_instance_in_process_env():
return GraphQLContextVariant(
InstanceManagers.asset_aware_sqlite_instance(),
EnvironmentManagers.user_code_in_host_process(),
test_id="asset_aware_instance_in_process_env",
)
@staticmethod
def all_variants():
"""
There is a test case that keeps this list up-to-date. If you add a static
method that returns a GraphQLContextVariant, you must add it to this
list for tests to pass.
"""
return [
GraphQLContextVariant.in_memory_instance_in_process_env(),
GraphQLContextVariant.in_memory_instance_managed_grpc_env(),
GraphQLContextVariant.sqlite_with_sync_run_launcher_in_process_env(),
GraphQLContextVariant.sqlite_with_default_run_launcher_in_process_env(),
GraphQLContextVariant.sqlite_with_default_run_launcher_managed_grpc_env(),
GraphQLContextVariant.sqlite_with_default_run_launcher_deployed_grpc_env(),
GraphQLContextVariant.sqlite_with_queued_run_coordinator_managed_grpc_env(),
GraphQLContextVariant.postgres_with_sync_run_launcher_in_process_env(),
GraphQLContextVariant.postgres_with_default_run_launcher_in_process_env(),
GraphQLContextVariant.postgres_with_default_run_launcher_managed_grpc_env(),
GraphQLContextVariant.postgres_with_default_run_launcher_deployed_grpc_env(),
GraphQLContextVariant.readonly_in_memory_instance_in_process_env(),
GraphQLContextVariant.readonly_in_memory_instance_multi_location(),
GraphQLContextVariant.readonly_in_memory_instance_managed_grpc_env(),
GraphQLContextVariant.readonly_sqlite_instance_in_process_env(),
GraphQLContextVariant.readonly_sqlite_instance_multi_location(),
GraphQLContextVariant.readonly_sqlite_instance_managed_grpc_env(),
GraphQLContextVariant.readonly_sqlite_instance_deployed_grpc_env(),
GraphQLContextVariant.readonly_postgres_instance_in_process_env(),
GraphQLContextVariant.readonly_postgres_instance_multi_location(),
GraphQLContextVariant.readonly_postgres_instance_managed_grpc_env(),
GraphQLContextVariant.asset_aware_sqlite_instance_in_process_env(),
]
@staticmethod
def all_executing_variants():
return [
GraphQLContextVariant.in_memory_instance_in_process_env(),
GraphQLContextVariant.sqlite_with_sync_run_launcher_in_process_env(),
] + GraphQLContextVariant.all_out_of_process_executing_variants()
@staticmethod
def all_out_of_process_executing_variants():
return [
GraphQLContextVariant.sqlite_with_default_run_launcher_managed_grpc_env(),
GraphQLContextVariant.sqlite_with_default_run_launcher_deployed_grpc_env(),
GraphQLContextVariant.postgres_with_default_run_launcher_managed_grpc_env(),
GraphQLContextVariant.postgres_with_default_run_launcher_deployed_grpc_env(),
]
@staticmethod
def all_readonly_variants():
"""
Return all readonly variants. Attempting to start or launch a run against these will error.
"""
return _variants_with_mark(GraphQLContextVariant.all_variants(), pytest.mark.readonly)
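# Illustrative sketch (an assumption, not part of this change): a bespoke variant simply
# pairs any of the MarkedManagers defined above; the particular combination and test_id
# here are hypothetical.
def example_bespoke_variant():
    return GraphQLContextVariant(
        InstanceManagers.sqlite_instance_with_default_run_launcher(),
        EnvironmentManagers.multi_location(),
        test_id="sqlite_with_default_run_launcher_multi_location",
    )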
def _variants_with_mark(variants, mark):
def _yield_all():
for variant in variants:
if mark in variant.marks:
yield variant
return list(_yield_all())
def _variants_without_marks(variants, marks):
def _yield_all():
for variant in variants:
if all(mark not in variant.marks for mark in marks):
yield variant
return list(_yield_all())
@contextmanager
def manage_graphql_context(context_variant, recon_repo=None):
recon_repo = recon_repo if recon_repo else get_main_recon_repo()
with context_variant.instance_mgr() as instance:
with context_variant.environment_mgr(recon_repo) as workspace:
yield DagsterGraphQLContext(instance=instance, workspace=workspace)
class _GraphQLContextTestSuite(ABC):
@abstractmethod
def yield_graphql_context(self, request):
pass
@abstractmethod
def recon_repo(self):
pass
@contextmanager
def graphql_context_for_request(self, request):
check.param_invariant(
isinstance(request.param, GraphQLContextVariant),
"request",
"params in fixture must be List[GraphQLContextVariant]",
)
with manage_graphql_context(request.param, self.recon_repo()) as graphql_context:
yield graphql_context
def graphql_context_variants_fixture(context_variants):
check.list_param(context_variants, "context_variants", of_type=GraphQLContextVariant)
def _wrap(fn):
return pytest.fixture(
name="graphql_context",
params=[
pytest.param(
context_variant,
id=context_variant.test_id,
marks=context_variant.marks + [Marks.graphql_context_test_suite],
)
for context_variant in context_variants
],
)(fn)
return _wrap
def make_graphql_context_test_suite(context_variants, recon_repo=None):
"""
Arguments:
context_variants (List[GraphQLContextVariant]): List of context variants to run each
test in this class against.
recon_repo (ReconstructableRepository): Repository to run against. Defaults
to "define_repository" in setup.py.
This is the base class factory for test suites in the dagster-graphql tests.
The goal of this suite is to make it straightforward to run tests against
multiple graphql_contexts and to give those contexts a coherent lifecycle.
GraphQLContextVariant has a number of static methods that provide common run
configurations as well as common groups of run configurations.
One can also make bespoke GraphQLContextVariants with specific implementations
of DagsterInstance, RepositoryLocation, and so forth. See that class
for more details.
Example:
class TestAThing(
make_graphql_context_test_suite(
context_variants=[GraphQLContextVariant.in_memory_instance_in_process_env()]
)
):
def test_graphql_context_exists(self, graphql_context):
assert graphql_context
"""
check.list_param(context_variants, "context_variants", of_type=GraphQLContextVariant)
recon_repo = check.inst_param(
recon_repo if recon_repo else get_main_recon_repo(), "recon_repo", ReconstructableRepository
)
class _SpecificTestSuiteBase(_GraphQLContextTestSuite):
@graphql_context_variants_fixture(context_variants=context_variants)
def yield_graphql_context(self, request):
with self.graphql_context_for_request(request) as graphql_context:
yield graphql_context
def recon_repo(self):
return recon_repo
return _SpecificTestSuiteBase
ReadonlyGraphQLContextTestMatrix = make_graphql_context_test_suite(
context_variants=GraphQLContextVariant.all_readonly_variants()
)
ExecutingGraphQLContextTestMatrix = make_graphql_context_test_suite(
context_variants=GraphQLContextVariant.all_executing_variants()
)
OutOfProcessExecutingGraphQLContextTestMatrix = make_graphql_context_test_suite(
context_variants=GraphQLContextVariant.all_out_of_process_executing_variants()
)
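# Hedged usage sketch (an assumption, not part of this change): a test module subclasses
# one of the prebuilt matrices above and receives the parameterized graphql_context
# fixture on every test method.
class TestExampleAcrossExecutingVariants(ExecutingGraphQLContextTestMatrix):
    def test_graphql_context_exists(self, graphql_context):
        assert graphql_context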
diff --git a/python_modules/dagster-graphql/dagster_graphql_tests/test_cli.py b/python_modules/dagster-graphql/dagster_graphql_tests/test_cli.py
index 524d3cfbb..b1e202fde 100644
--- a/python_modules/dagster-graphql/dagster_graphql_tests/test_cli.py
+++ b/python_modules/dagster-graphql/dagster_graphql_tests/test_cli.py
@@ -1,328 +1,329 @@
import json
import os
+import tempfile
import time
from contextlib import contextmanager
from click.testing import CliRunner
from dagster import seven
from dagster.core.storage.pipeline_run import PipelineRunStatus
from dagster.core.test_utils import instance_for_test_tempdir
from dagster.utils import file_relative_path
from dagster_graphql.cli import ui
@contextmanager
def dagster_cli_runner():
- with seven.TemporaryDirectory() as dagster_home_temp:
+ with tempfile.TemporaryDirectory() as dagster_home_temp:
with instance_for_test_tempdir(
dagster_home_temp,
overrides={
"run_launcher": {
"module": "dagster.core.launcher.sync_in_memory_run_launcher",
"class": "SyncInMemoryRunLauncher",
}
},
):
yield CliRunner(env={"DAGSTER_HOME": dagster_home_temp})
def test_basic_introspection():
query = "{ __schema { types { name } } }"
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(ui, ["-w", workspace_path, "-t", query])
assert result.exit_code == 0
result_data = json.loads(result.output)
assert result_data["data"]
def test_basic_repositories():
query = "{ repositoriesOrError { ... on RepositoryConnection { nodes { name } } } }"
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(ui, ["-w", workspace_path, "-t", query])
assert result.exit_code == 0
result_data = json.loads(result.output)
assert result_data["data"]["repositoriesOrError"]["nodes"]
def test_basic_repository_locations():
query = "{ repositoryLocationsOrError { ... on RepositoryLocationConnection { nodes { ... on RepositoryLocation { __typename, name } ... on RepositoryLocationLoadFailure { __typename, name, error { message } } } } } }"
workspace_path = file_relative_path(__file__, "./cli_test_error_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(ui, ["-w", workspace_path, "-t", query])
assert result.exit_code == 0
result_data = json.loads(result.output)
nodes = result_data["data"]["repositoryLocationsOrError"]["nodes"]
assert len(nodes) == 2
assert nodes[0]["__typename"] == "RepositoryLocation"
assert nodes[0]["name"] == "test_cli_location"
assert nodes[1]["__typename"] == "RepositoryLocationLoadFailure"
assert nodes[1]["name"] == "test_cli_location_error"
assert "No module named" in nodes[1]["error"]["message"]
def test_basic_variables():
query = """
query FooBar($pipelineName: String! $repositoryName: String! $repositoryLocationName: String!){
pipelineOrError(params:{pipelineName: $pipelineName repositoryName: $repositoryName repositoryLocationName: $repositoryLocationName})
{ ... on Pipeline { name } }
}
"""
variables = '{"pipelineName": "math", "repositoryName": "test", "repositoryLocationName": "test_cli_location"}'
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(ui, ["-w", workspace_path, "-v", variables, "-t", query])
assert result.exit_code == 0
result_data = json.loads(result.output)
assert result_data["data"]["pipelineOrError"]["name"] == "math"
LAUNCH_PIPELINE_EXECUTION_QUERY = """
mutation ($executionParams: ExecutionParams!) {
launchPipelineExecution(executionParams: $executionParams) {
__typename
... on LaunchPipelineRunSuccess {
run {
runId
pipeline { ...on PipelineReference { name } }
}
}
... on PipelineConfigValidationInvalid {
pipelineName
errors { message }
}
... on PipelineNotFoundError {
pipelineName
}
... on PythonError {
message
stack
}
}
}
"""
def test_start_execution_text():
variables = seven.json.dumps(
{
"executionParams": {
"selector": {
"repositoryLocationName": "test_cli_location",
"repositoryName": "test",
"pipelineName": "math",
},
"runConfigData": {"solids": {"add_one": {"inputs": {"num": {"value": 123}}}}},
"mode": "default",
}
}
)
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(
ui, ["-w", workspace_path, "-v", variables, "-t", LAUNCH_PIPELINE_EXECUTION_QUERY]
)
assert result.exit_code == 0
try:
result_data = json.loads(result.output.strip("\n").split("\n")[-1])
assert (
result_data["data"]["launchPipelineExecution"]["__typename"]
== "LaunchPipelineRunSuccess"
)
except Exception as e:
raise Exception("Failed with {} Exception: {}".format(result.output, e))
def test_start_execution_file():
variables = seven.json.dumps(
{
"executionParams": {
"selector": {
"pipelineName": "math",
"repositoryLocationName": "test_cli_location",
"repositoryName": "test",
},
"runConfigData": {"solids": {"add_one": {"inputs": {"num": {"value": 123}}}}},
"mode": "default",
}
}
)
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(
ui,
[
"-w",
workspace_path,
"-v",
variables,
"--file",
file_relative_path(__file__, "./execute.graphql"),
],
)
assert result.exit_code == 0
result_data = json.loads(result.output.strip("\n").split("\n")[-1])
assert (
result_data["data"]["launchPipelineExecution"]["__typename"]
== "LaunchPipelineRunSuccess"
)
def test_start_execution_save_output():
"""
Test that the --output flag saves the GraphQL response to the specified file
"""
variables = seven.json.dumps(
{
"executionParams": {
"selector": {
"repositoryLocationName": "test_cli_location",
"repositoryName": "test",
"pipelineName": "math",
},
"runConfigData": {"solids": {"add_one": {"inputs": {"num": {"value": 123}}}}},
"mode": "default",
}
}
)
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
file_name = os.path.join(temp_dir, "output_file")
result = runner.invoke(
ui,
[
"-w",
workspace_path,
"-v",
variables,
"--file",
file_relative_path(__file__, "./execute.graphql"),
"--output",
file_name,
],
)
assert result.exit_code == 0
assert os.path.isfile(file_name)
with open(file_name, "r") as f:
lines = f.readlines()
result_data = json.loads(lines[-1])
assert (
result_data["data"]["launchPipelineExecution"]["__typename"]
== "LaunchPipelineRunSuccess"
)
def test_start_execution_predefined():
variables = seven.json.dumps(
{
"executionParams": {
"selector": {
"repositoryLocationName": "test_cli_location",
"repositoryName": "test",
"pipelineName": "math",
},
"runConfigData": {"solids": {"add_one": {"inputs": {"num": {"value": 123}}}}},
"mode": "default",
}
}
)
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
with dagster_cli_runner() as runner:
result = runner.invoke(
ui, ["-w", workspace_path, "-v", variables, "-p", "launchPipelineExecution"]
)
assert result.exit_code == 0
result_data = json.loads(result.output.strip("\n").split("\n")[-1])
if not result_data.get("data"):
raise Exception(result_data)
assert (
result_data["data"]["launchPipelineExecution"]["__typename"]
== "LaunchPipelineRunSuccess"
)
def test_logs_in_start_execution_predefined():
variables = seven.json.dumps(
{
"executionParams": {
"selector": {
"repositoryLocationName": "test_cli_location",
"repositoryName": "test",
"pipelineName": "math",
},
"runConfigData": {"solids": {"add_one": {"inputs": {"num": {"value": 123}}}}},
"mode": "default",
}
}
)
workspace_path = file_relative_path(__file__, "./cli_test_workspace.yaml")
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(
temp_dir,
overrides={
"run_launcher": {
"module": "dagster.core.launcher.sync_in_memory_run_launcher",
"class": "SyncInMemoryRunLauncher",
}
},
) as instance:
runner = CliRunner(env={"DAGSTER_HOME": temp_dir})
result = runner.invoke(
ui, ["-w", workspace_path, "-v", variables, "-p", "launchPipelineExecution"]
)
assert result.exit_code == 0
result_data = json.loads(result.output.strip("\n").split("\n")[-1])
assert (
result_data["data"]["launchPipelineExecution"]["__typename"]
== "LaunchPipelineRunSuccess"
)
run_id = result_data["data"]["launchPipelineExecution"]["run"]["runId"]
# allow FS events to flush
retries = 5
while retries != 0 and not _is_done(instance, run_id):
time.sleep(0.333)
retries -= 1
# assert that the watching run storage captured the run correctly from the other process
run = instance.get_run_by_id(run_id)
assert run.status == PipelineRunStatus.SUCCESS
def _is_done(instance, run_id):
return instance.has_run(run_id) and instance.get_run_by_id(run_id).is_finished
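# Hedged note (the console-script name is an assumption): the CliRunner invocations above
# mirror running the dagster-graphql CLI directly, using only the flags exercised in these
# tests, e.g.:
#
#     dagster-graphql -w cli_test_workspace.yaml -t "{ __schema { types { name } } }"
#     dagster-graphql -w cli_test_workspace.yaml -v "$VARIABLES" -p launchPipelineExecution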
diff --git a/python_modules/dagster/dagster/builtins.py b/python_modules/dagster/dagster/builtins.py
index c5956a02a..c74a74209 100644
--- a/python_modules/dagster/dagster/builtins.py
+++ b/python_modules/dagster/dagster/builtins.py
@@ -1,27 +1,28 @@
import typing
class BuiltinEnum:
ANY = typing.Any
- BOOL = typing.NewType("Bool", bool)
- FLOAT = typing.NewType("Float", float)
- INT = typing.NewType("Int", int)
- STRING = typing.NewType("String", str)
- NOTHING = typing.NewType("Nothing", None)
+ # mypy doesn't like the mismatch between BOOL and "Bool"
+ BOOL = typing.NewType("Bool", bool) # type: ignore[misc]
+ FLOAT = typing.NewType("Float", float) # type: ignore[misc]
+ INT = typing.NewType("Int", int) # type: ignore[misc]
+ STRING = typing.NewType("String", str) # type: ignore[misc]
+ NOTHING = typing.NewType("Nothing", None) # type: ignore[misc]
@classmethod
def contains(cls, value):
for ttype in [cls.ANY, cls.BOOL, cls.FLOAT, cls.INT, cls.STRING, cls.NOTHING]:
if value == ttype:
return True
return False
Any = BuiltinEnum.ANY
String = BuiltinEnum.STRING
Int = BuiltinEnum.INT
Bool = BuiltinEnum.BOOL
Float = BuiltinEnum.FLOAT
Nothing = BuiltinEnum.NOTHING
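# Hedged usage sketch (an assumption, kept as a comment to avoid a circular import here):
# user code consumes these aliases via the top-level dagster package when typing solids,
# along the lines of:
#
#     from dagster import InputDefinition, Int, solid
#
#     @solid(input_defs=[InputDefinition("num", Int)])
#     def add_one(_, num):
#         return num + 1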
diff --git a/python_modules/dagster/dagster/check/__init__.py b/python_modules/dagster/dagster/check/__init__.py
index 9ce242af8..d209ee6f0 100644
--- a/python_modules/dagster/dagster/check/__init__.py
+++ b/python_modules/dagster/dagster/check/__init__.py
@@ -1,847 +1,839 @@
import inspect
import sys
+from inspect import Parameter
from future.utils import raise_with_traceback
-from six import integer_types
-
-if sys.version_info[0] >= 3:
- type_types = type
-else:
- # These shenanigans are to support old-style classes in py27
- import new # pylint: disable=import-error
-
- type_types = (type, new.classobj) # pylint: disable=undefined-variable
class CheckError(Exception):
pass
class ParameterCheckError(CheckError):
pass
class ElementCheckError(CheckError):
pass
class NotImplementedCheckError(CheckError):
pass
def _param_type_mismatch_exception(obj, ttype, param_name, additional_message=None):
if isinstance(ttype, tuple):
type_names = sorted([t.__name__ for t in ttype])
return ParameterCheckError(
'Param "{name}" is not one of {type_names}. Got {obj} which is type {obj_type}.'
"{additional_message}".format(
name=param_name,
obj=repr(obj),
type_names=type_names,
obj_type=type(obj),
additional_message=" " + additional_message if additional_message else "",
)
)
else:
return ParameterCheckError(
'Param "{name}" is not a {type}. Got {obj} which is type {obj_type}.'
"{additional_message}".format(
name=param_name,
obj=repr(obj),
type=ttype.__name__,
obj_type=type(obj),
additional_message=" " + additional_message if additional_message else "",
)
)
def _not_type_param_subclass_mismatch_exception(obj, param_name):
return ParameterCheckError(
'Param "{name}" was supposed to be a type. Got {obj} of type {obj_type}'.format(
name=param_name, obj=repr(obj), obj_type=type(obj)
)
)
def _param_subclass_mismatch_exception(obj, superclass, param_name):
return ParameterCheckError(
'Param "{name}" is a type but not a subclass of {superclass}. Got {obj} instead'.format(
name=param_name, superclass=superclass, obj=obj
)
)
def _type_mismatch_error(obj, ttype, desc=None):
type_message = (
f"not one of {sorted([t.__name__ for t in ttype])}"
if isinstance(ttype, tuple)
else f"not a {ttype.__name__}"
)
repr_obj = repr(obj)
desc_str = f" Desc: {desc}" if desc else ""
return CheckError(
f"Object {repr_obj} is {type_message}. Got {repr_obj} with type {type(obj)}.{desc_str}"
)
def _not_callable_exception(obj, param_name):
return ParameterCheckError(
'Param "{name}" is not callable. Got {obj} with type {obj_type}.'.format(
name=param_name, obj=repr(obj), obj_type=type(obj)
)
)
def _param_invariant_exception(param_name, desc):
return ParameterCheckError(
"Invariant violation for parameter {param_name}. Description: {desc}".format(
param_name=param_name, desc=desc
)
)
def failed(desc):
if not isinstance(desc, str):
raise_with_traceback(CheckError("desc argument must be a string"))
raise_with_traceback(CheckError("Failure condition: {desc}".format(desc=desc)))
def not_implemented(desc):
if not isinstance(desc, str):
raise_with_traceback(CheckError("desc argument must be a string"))
raise_with_traceback(NotImplementedCheckError("Not implemented: {desc}".format(desc=desc)))
def inst(obj, ttype, desc=None):
if not isinstance(obj, ttype):
raise_with_traceback(_type_mismatch_error(obj, ttype, desc))
return obj
def subclass(obj, superclass, desc=None):
if not issubclass(obj, superclass):
raise_with_traceback(_type_mismatch_error(obj, superclass, desc))
return obj
def is_callable(obj, desc=None):
if not callable(obj):
if desc:
raise_with_traceback(
CheckError(
"Must be callable. Got {obj}. Description: {desc}".format(
obj=repr(obj), desc=desc
)
)
)
else:
raise_with_traceback(
CheckError(
"Must be callable. Got {obj}. Description: {desc}".format(obj=obj, desc=desc)
)
)
return obj
def not_none_param(obj, param_name):
if obj is None:
raise_with_traceback(
_param_invariant_exception(
param_name, "Param {param_name} cannot be none".format(param_name=param_name)
)
)
return obj
def invariant(condition, desc=None):
if not condition:
if desc:
raise_with_traceback(
CheckError("Invariant failed. Description: {desc}".format(desc=desc))
)
else:
raise_with_traceback(CheckError("Invariant failed."))
return True
def param_invariant(condition, param_name, desc=None):
if not condition:
raise_with_traceback(_param_invariant_exception(param_name, desc))
def inst_param(obj, param_name, ttype, additional_message=None):
if not isinstance(obj, ttype):
raise_with_traceback(
_param_type_mismatch_exception(
obj, ttype, param_name, additional_message=additional_message
)
)
return obj
def opt_inst_param(obj, param_name, ttype, default=None):
if obj is not None and not isinstance(obj, ttype):
raise_with_traceback(_param_type_mismatch_exception(obj, ttype, param_name))
return default if obj is None else obj
def callable_param(obj, param_name):
if not callable(obj):
raise_with_traceback(_not_callable_exception(obj, param_name))
return obj
def opt_callable_param(obj, param_name, default=None):
if obj is not None and not callable(obj):
raise_with_traceback(_not_callable_exception(obj, param_name))
return default if obj is None else obj
def int_param(obj, param_name):
- if not isinstance(obj, integer_types):
+ if not isinstance(obj, int):
raise_with_traceback(_param_type_mismatch_exception(obj, int, param_name))
return obj
def int_value_param(obj, value, param_name):
- if not isinstance(obj, integer_types):
+ if not isinstance(obj, int):
raise_with_traceback(_param_type_mismatch_exception(obj, int, param_name))
if obj != value:
raise_with_traceback(
_param_invariant_exception(param_name, "Should be equal to {value}".format(value=value))
)
return obj
def opt_int_param(obj, param_name, default=None):
- if obj is not None and not isinstance(obj, integer_types):
+ if obj is not None and not isinstance(obj, int):
raise_with_traceback(_param_type_mismatch_exception(obj, int, param_name))
return default if obj is None else obj
def float_param(obj, param_name):
if not isinstance(obj, float):
raise_with_traceback(_param_type_mismatch_exception(obj, float, param_name))
return obj
def opt_numeric_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, (int, float)):
raise_with_traceback(_param_type_mismatch_exception(obj, (int, float), param_name))
return default if obj is None else obj
def numeric_param(obj, param_name):
if not isinstance(obj, (int, float)):
raise_with_traceback(_param_type_mismatch_exception(obj, (int, float), param_name))
return obj
def opt_float_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, float):
raise_with_traceback(_param_type_mismatch_exception(obj, float, param_name))
return default if obj is None else obj
def str_param(obj, param_name):
if not isinstance(obj, str):
raise_with_traceback(_param_type_mismatch_exception(obj, str, param_name))
return obj
def opt_str_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, str):
raise_with_traceback(_param_type_mismatch_exception(obj, str, param_name))
return default if obj is None else obj
def opt_nonempty_str_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, str):
raise_with_traceback(_param_type_mismatch_exception(obj, str, param_name))
return default if obj is None or obj == "" else obj
def bool_param(obj, param_name):
if not isinstance(obj, bool):
raise_with_traceback(_param_type_mismatch_exception(obj, bool, param_name))
return obj
def opt_bool_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, bool):
raise_with_traceback(_param_type_mismatch_exception(obj, bool, param_name))
return default if obj is None else obj
def is_list(obj_list, of_type=None, desc=None):
if not isinstance(obj_list, list):
raise_with_traceback(_type_mismatch_error(obj_list, list, desc))
if not of_type:
return obj_list
return _check_list_items(obj_list, of_type)
def is_tuple(obj_tuple, of_type=None, desc=None):
if not isinstance(obj_tuple, tuple):
raise_with_traceback(_type_mismatch_error(obj_tuple, tuple, desc))
if not of_type:
return obj_tuple
return _check_tuple_items(obj_tuple, of_type)
def list_param(obj_list, param_name, of_type=None):
from dagster.utils import frozenlist
if not isinstance(obj_list, (frozenlist, list)):
raise_with_traceback(
_param_type_mismatch_exception(obj_list, (frozenlist, list), param_name)
)
if not of_type:
return obj_list
return _check_list_items(obj_list, of_type)
def set_param(obj_set, param_name, of_type=None):
if not isinstance(obj_set, (frozenset, set)):
raise_with_traceback(_param_type_mismatch_exception(obj_set, (frozenset, set), param_name))
if not of_type:
return obj_set
return _check_set_items(obj_set, of_type)
def tuple_param(obj, param_name, of_type=None):
if not isinstance(obj, tuple):
raise_with_traceback(_param_type_mismatch_exception(obj, tuple, param_name))
if of_type is None:
return obj
return _check_tuple_items(obj, of_type)
def matrix_param(matrix, param_name, of_type=None):
matrix = list_param(matrix, param_name, of_type=list)
if not matrix:
raise_with_traceback(CheckError("You must pass a list of lists. Received an empty list."))
for sublist in matrix:
sublist = list_param(sublist, "sublist_{}".format(param_name), of_type=of_type)
if len(sublist) != len(matrix[0]):
raise_with_traceback(CheckError("All sublists in matrix must have the same length"))
return matrix
def opt_tuple_param(obj, param_name, default=None, of_type=None):
if obj is not None and not isinstance(obj, tuple):
raise_with_traceback(_param_type_mismatch_exception(obj, tuple, param_name))
if obj is None:
return default
if of_type is None:
return obj
return _check_tuple_items(obj, of_type)
def _check_list_items(obj_list, of_type):
for obj in obj_list:
if not isinstance(obj, of_type):
if isinstance(obj, type):
additional_message = (
" Did you pass a class where you were expecting an instance of the class?"
)
else:
additional_message = ""
raise_with_traceback(
CheckError(
"Member of list mismatches type. Expected {of_type}. Got {obj_repr} of type "
"{obj_type}.{additional_message}".format(
of_type=of_type,
obj_repr=repr(obj),
obj_type=type(obj),
additional_message=additional_message,
)
)
)
return obj_list
def _check_set_items(obj_set, of_type):
for obj in obj_set:
if not isinstance(obj, of_type):
if isinstance(obj, type):
additional_message = (
" Did you pass a class where you were expecting an instance of the class?"
)
else:
additional_message = ""
raise_with_traceback(
CheckError(
"Member of set mismatches type. Expected {of_type}. Got {obj_repr} of type "
"{obj_type}.{additional_message}".format(
of_type=of_type,
obj_repr=repr(obj),
obj_type=type(obj),
additional_message=additional_message,
)
)
)
return obj_set
def _check_tuple_items(obj_tuple, of_type):
if isinstance(of_type, tuple):
len_tuple = len(obj_tuple)
len_type = len(of_type)
if not len_tuple == len_type:
raise_with_traceback(
CheckError(
"Tuple mismatches type: tuple had {len_tuple} members but type had "
"{len_type}".format(len_tuple=len_tuple, len_type=len_type)
)
)
for (i, obj) in enumerate(obj_tuple):
of_type_i = of_type[i]
if not isinstance(obj, of_type_i):
if isinstance(obj, type):
additional_message = (
" Did you pass a class where you were expecting an instance of the class?"
)
else:
additional_message = ""
raise_with_traceback(
CheckError(
"Member of tuple mismatches type at index {index}. Expected {of_type}. Got "
"{obj_repr} of type {obj_type}.{additional_message}".format(
index=i,
of_type=of_type_i,
obj_repr=repr(obj),
obj_type=type(obj),
additional_message=additional_message,
)
)
)
else:
for (i, obj) in enumerate(obj_tuple):
if not isinstance(obj, of_type):
if isinstance(obj, type):
additional_message = (
" Did you pass a class where you were expecting an instance of the class?"
)
else:
additional_message = ""
raise_with_traceback(
CheckError(
"Member of tuple mismatches type at index {index}. Expected {of_type}. Got "
"{obj_repr} of type {obj_type}.{additional_message}".format(
index=i,
of_type=of_type,
obj_repr=repr(obj),
obj_type=type(obj),
additional_message=additional_message,
)
)
)
return obj_tuple
def opt_list_param(obj_list, param_name, of_type=None):
"""Ensures argument obj_list is a list or None; in the latter case, instantiates an empty list
and returns it.
If the of_type argument is provided, also ensures that list items conform to the type specified
by of_type.
"""
from dagster.utils import frozenlist
if obj_list is not None and not isinstance(obj_list, (frozenlist, list)):
raise_with_traceback(
_param_type_mismatch_exception(obj_list, (frozenlist, list), param_name)
)
if not obj_list:
return []
if not of_type:
return obj_list
return _check_list_items(obj_list, of_type)
def opt_set_param(obj_set, param_name, of_type=None):
"""Ensures argument obj_set is a set or None; in the latter case, instantiates an empty set
and returns it.
If the of_type argument is provided, also ensures that list items conform to the type specified
by of_type.
"""
if obj_set is not None and not isinstance(obj_set, (frozenset, set)):
raise_with_traceback(_param_type_mismatch_exception(obj_set, (frozenset, set), param_name))
if not obj_set:
return set()
if not of_type:
return obj_set
return _check_set_items(obj_set, of_type)
def opt_nullable_list_param(obj_list, param_name, of_type=None):
"""Ensures argument obj_list is a list or None. Returns None if input is None.
If the of_type argument is provided, also ensures that list items conform to the type specified
by of_type.
"""
from dagster.utils import frozenlist
if obj_list is not None and not isinstance(obj_list, (frozenlist, list)):
raise_with_traceback(
_param_type_mismatch_exception(obj_list, (frozenlist, list), param_name)
)
if not obj_list:
return None if obj_list is None else []
if not of_type:
return obj_list
return _check_list_items(obj_list, of_type)
def _check_key_value_types(obj, key_type, value_type, key_check=isinstance, value_check=isinstance):
"""Ensures argument obj is a dictionary, and enforces that the keys/values conform to the types
specified by key_type, value_type.
"""
if not isinstance(obj, dict):
raise_with_traceback(_type_mismatch_error(obj, dict))
for key, value in obj.items():
if key_type and not key_check(key, key_type):
raise_with_traceback(
CheckError(
"Key in dictionary mismatches type. Expected {key_type}. Got {obj_repr}".format(
key_type=repr(key_type), obj_repr=repr(key)
)
)
)
if value_type and not value_check(value, value_type):
raise_with_traceback(
CheckError(
"Value in dictionary mismatches expected type for key {key}. Expected value "
"of type {vtype}. Got value {value} of type {obj_type}.".format(
vtype=repr(value_type), obj_type=type(value), key=key, value=value
)
)
)
return obj
def dict_param(obj, param_name, key_type=None, value_type=None, additional_message=None):
"""Ensures argument obj is a native Python dictionary, raises an exception if not, and otherwise
returns obj.
"""
from dagster.utils import frozendict
if not isinstance(obj, (frozendict, dict)):
raise_with_traceback(
_param_type_mismatch_exception(
obj, (frozendict, dict), param_name, additional_message=additional_message
)
)
if not (key_type or value_type):
return obj
return _check_key_value_types(obj, key_type, value_type)
def opt_dict_param(obj, param_name, key_type=None, value_type=None, value_class=None):
"""Ensures argument obj is either a dictionary or None; if the latter, instantiates an empty
dictionary.
"""
from dagster.utils import frozendict
if obj is not None and not isinstance(obj, (frozendict, dict)):
raise_with_traceback(_param_type_mismatch_exception(obj, (frozendict, dict), param_name))
if not obj:
return {}
if value_class:
return _check_key_value_types(obj, key_type, value_type=value_class, value_check=issubclass)
return _check_key_value_types(obj, key_type, value_type)
def opt_nullable_dict_param(obj, param_name, key_type=None, value_type=None, value_class=None):
"""Ensures argument obj is either a dictionary or None;
"""
from dagster.utils import frozendict
if obj is not None and not isinstance(obj, (frozendict, dict)):
raise_with_traceback(_param_type_mismatch_exception(obj, (frozendict, dict), param_name))
if not obj:
return None if obj is None else {}
if value_class:
return _check_key_value_types(obj, key_type, value_type=value_class, value_check=issubclass)
return _check_key_value_types(obj, key_type, value_type)
def _check_two_dim_key_value_types(obj, key_type, _param_name, value_type):
_check_key_value_types(obj, key_type, dict) # check level one
for level_two_dict in obj.values():
_check_key_value_types(level_two_dict, key_type, value_type) # check level two
return obj
def two_dim_dict_param(obj, param_name, key_type=str, value_type=None):
if not isinstance(obj, dict):
raise_with_traceback(_param_type_mismatch_exception(obj, dict, param_name))
return _check_two_dim_key_value_types(obj, key_type, param_name, value_type)
def opt_two_dim_dict_param(obj, param_name, key_type=str, value_type=None):
if obj is not None and not isinstance(obj, dict):
raise_with_traceback(_param_type_mismatch_exception(obj, dict, param_name))
if not obj:
return {}
return _check_two_dim_key_value_types(obj, key_type, param_name, value_type)
def type_param(obj, param_name):
- if not isinstance(obj, type_types):
+ if not isinstance(obj, type):
raise_with_traceback(_not_type_param_subclass_mismatch_exception(obj, param_name))
return obj
def opt_type_param(obj, param_name, default=None):
if obj is not None and not isinstance(obj, type):
raise_with_traceback(_not_type_param_subclass_mismatch_exception(obj, param_name))
return obj if obj is not None else default
def subclass_param(obj, param_name, superclass):
type_param(obj, param_name)
if not issubclass(obj, superclass):
raise_with_traceback(_param_subclass_mismatch_exception(obj, superclass, param_name))
return obj
def opt_subclass_param(obj, param_name, superclass):
opt_type_param(obj, param_name)
if obj is not None and not issubclass(obj, superclass):
raise_with_traceback(_param_subclass_mismatch_exception(obj, superclass, param_name))
return obj
def _element_check_error(key, value, ddict, ttype):
return ElementCheckError(
"Value {value} from key {key} is not a {ttype}. Dict: {ddict}".format(
key=key, value=repr(value), ddict=repr(ddict), ttype=repr(ttype)
)
)
def generator(obj):
if not inspect.isgenerator(obj):
raise ParameterCheckError(
"Not a generator (return value of function that yields) Got {obj} instead".format(
obj=obj
)
)
return obj
def opt_generator(obj):
if obj is not None and not inspect.isgenerator(obj):
raise ParameterCheckError(
"Not a generator (return value of function that yields) Got {obj} instead".format(
obj=obj
)
)
return obj
def generator_param(obj, param_name):
if not inspect.isgenerator(obj):
raise ParameterCheckError(
(
'Param "{name}" is not a generator (return value of function that yields) Got '
"{obj} instead"
).format(name=param_name, obj=obj)
)
return obj
def opt_generator_param(obj, param_name):
if obj is not None and not inspect.isgenerator(obj):
raise ParameterCheckError(
(
'Param "{name}" is not a generator (return value of function that yields) Got '
"{obj} instead"
).format(name=param_name, obj=obj)
)
return obj
def list_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if not isinstance(value, list):
raise_with_traceback(_element_check_error(key, value, ddict, list))
return value
def opt_list_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if value is None:
return []
if not isinstance(value, list):
raise_with_traceback(_element_check_error(key, value, ddict, list))
return value
def dict_elem(ddict, key):
from dagster.utils import frozendict
dict_param(ddict, "ddict")
str_param(key, "key")
if key not in ddict:
raise_with_traceback(
CheckError("{key} not present in dictionary {ddict}".format(key=key, ddict=ddict))
)
value = ddict[key]
if not isinstance(value, (frozendict, dict)):
raise_with_traceback(_element_check_error(key, value, ddict, (frozendict, dict)))
return value
def opt_dict_elem(ddict, key):
from dagster.utils import frozendict
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if value is None:
return {}
if not isinstance(value, (frozendict, dict)):
raise_with_traceback(_element_check_error(key, value, ddict, (frozendict, dict)))
return value
def bool_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict[key]
if not isinstance(value, bool):
raise_with_traceback(_element_check_error(key, value, ddict, bool))
return value
def opt_float_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if value is None:
return None
if not isinstance(value, float):
raise_with_traceback(_element_check_error(key, value, ddict, float))
return value
def float_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict[key]
if not isinstance(value, float):
raise_with_traceback(_element_check_error(key, value, ddict, float))
return value
def opt_int_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if value is None:
return None
- if not isinstance(value, integer_types):
+ if not isinstance(value, int):
raise_with_traceback(_element_check_error(key, value, ddict, int))
return value
def int_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict[key]
- if not isinstance(value, integer_types):
+ if not isinstance(value, int):
raise_with_traceback(_element_check_error(key, value, ddict, int))
return value
def opt_str_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict.get(key)
if value is None:
return None
if not isinstance(value, str):
raise_with_traceback(_element_check_error(key, value, ddict, str))
return value
def str_elem(ddict, key):
dict_param(ddict, "ddict")
str_param(key, "key")
value = ddict[key]
if not isinstance(value, str):
raise_with_traceback(_element_check_error(key, value, ddict, str))
return value
def class_param(obj, param_name):
if not inspect.isclass(obj):
raise_with_traceback(
ParameterCheckError(
'Param "{name}" is not a class. Got {obj} which is type {obj_type}.'.format(
name=param_name, obj=repr(obj), obj_type=type(obj),
)
)
)
return obj
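# Hedged usage sketch (an assumption, not part of this change): the typical calling pattern
# is to validate constructor arguments and bind the returned values via the public
# `check` namespace:
#
#     from dagster import check
#
#     class ExampleConfig:
#         def __init__(self, name, tags=None, retries=None):
#             self.name = check.str_param(name, "name")
#             self.tags = check.opt_dict_param(tags, "tags", key_type=str)
#             self.retries = check.opt_int_param(retries, "retries", default=0)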
diff --git a/python_modules/dagster/dagster/core/definitions/preset.py b/python_modules/dagster/dagster/core/definitions/preset.py
index d87fc09a7..107f204da 100644
--- a/python_modules/dagster/dagster/core/definitions/preset.py
+++ b/python_modules/dagster/dagster/core/definitions/preset.py
@@ -1,214 +1,213 @@
from collections import namedtuple
import pkg_resources
import six
import yaml
from dagster import check
from dagster.core.definitions.utils import config_from_files, config_from_yaml_strings
from dagster.core.errors import DagsterInvariantViolationError
-from dagster.seven import FileNotFoundError, ModuleNotFoundError # pylint:disable=redefined-builtin
from dagster.utils.backcompat import canonicalize_backcompat_args
from dagster.utils.merger import deep_merge_dicts
from .mode import DEFAULT_MODE_NAME
from .utils import check_valid_name
class PresetDefinition(
namedtuple("_PresetDefinition", "name run_config solid_selection mode tags")
):
"""Defines a preset configuration in which a pipeline can execute.
Presets can be used in Dagit to load predefined configurations into the tool.
Presets may also be used from the Python API (in a script, or in test) as follows:
.. code-block:: python
execute_pipeline(pipeline_def, preset='example_preset')
Presets may also be used with the command line tools:
.. code-block:: shell
$ dagster pipeline execute example_pipeline --preset example_preset
Args:
name (str): The name of this preset. Must be unique in the presets defined on a given
pipeline.
run_config (Optional[dict]): A dict representing the config to set with the preset.
This is equivalent to the ``run_config`` argument to :py:func:`execute_pipeline`.
solid_selection (Optional[List[str]]): A list of solid subselection (including single
solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']``
mode (Optional[str]): The mode to apply when executing this preset. (default: 'default')
tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset.
"""
def __new__(
cls, name, run_config=None, solid_selection=None, mode=None, tags=None,
):
return super(PresetDefinition, cls).__new__(
cls,
name=check_valid_name(name),
run_config=run_config,
solid_selection=check.opt_nullable_list_param(
solid_selection, "solid_selection", of_type=str
),
mode=check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME),
tags=check.opt_dict_param(tags, "tags", key_type=str),
)
@staticmethod
def from_files(
name, environment_files=None, config_files=None, solid_selection=None, mode=None, tags=None
):
"""Static constructor for presets from YAML files.
Args:
name (str): The name of this preset. Must be unique in the presets defined on a given
pipeline.
config_files (Optional[List[str]]): List of paths or glob patterns for yaml files
to load and parse as the environment config for this preset.
solid_selection (Optional[List[str]]): A list of solid subselection (including single
solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']``
mode (Optional[str]): The mode to apply when executing this preset. (default:
'default')
tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset.
Returns:
PresetDefinition: A PresetDefinition constructed from the provided YAML files.
Raises:
DagsterInvariantViolationError: When one of the YAML files is invalid and has a parse
error.
"""
check.str_param(name, "name")
config_files = canonicalize_backcompat_args(
config_files, "config_files", environment_files, "environment_files", "0.9.0"
)
config_files = check.opt_list_param(config_files, "config_files")
solid_selection = check.opt_nullable_list_param(
solid_selection, "solid_selection", of_type=str
)
mode = check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME)
merged = config_from_files(config_files)
return PresetDefinition(name, merged, solid_selection, mode, tags)
@staticmethod
def from_yaml_strings(name, yaml_strings=None, solid_selection=None, mode=None, tags=None):
"""Static constructor for presets from YAML strings.
Args:
name (str): The name of this preset. Must be unique in the presets defined on a given
pipeline.
yaml_strings (Optional[List[str]]): List of yaml strings to parse as the environment
config for this preset.
solid_selection (Optional[List[str]]): A list of solid subselection (including single
solid names) to execute with the preset. e.g. ``['*some_solid+', 'other_solid']``
mode (Optional[str]): The mode to apply when executing this preset. (default:
'default')
tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset.
Returns:
PresetDefinition: A PresetDefinition constructed from the provided YAML strings
Raises:
DagsterInvariantViolationError: When one of the YAML documents is invalid and has a
parse error.
"""
check.str_param(name, "name")
yaml_strings = check.opt_list_param(yaml_strings, "yaml_strings", of_type=str)
solid_selection = check.opt_nullable_list_param(
solid_selection, "solid_selection", of_type=str
)
mode = check.opt_str_param(mode, "mode", DEFAULT_MODE_NAME)
merged = config_from_yaml_strings(yaml_strings)
return PresetDefinition(name, merged, solid_selection, mode, tags)
@staticmethod
def from_pkg_resources(
name, pkg_resource_defs=None, solid_selection=None, mode=None, tags=None
):
"""Load a preset from a package resource, using :py:func:`pkg_resources.resource_string`.
Example:
.. code-block:: python
PresetDefinition.from_pkg_resources(
name='local',
mode='local',
pkg_resource_defs=[
('dagster_examples.airline_demo.environments', 'local_base.yaml'),
('dagster_examples.airline_demo.environments', 'local_warehouse.yaml'),
],
)
Args:
name (str): The name of this preset. Must be unique in the presets defined on a given
pipeline.
pkg_resource_defs (Optional[List[(str, str)]]): List of pkg_resource modules/files to
load as environment config for this preset.
solid_selection (Optional[List[str]]): A list of solid subselection (including single
solid names) to execute with this partition. e.g.
``['*some_solid+', 'other_solid']``
mode (Optional[str]): The mode to apply when executing this preset. (default:
'default')
tags (Optional[Dict[str, Any]]): The tags to apply when executing this preset.
Returns:
PresetDefinition: A PresetDefinition constructed from the YAML strings loaded from the provided package resources.
Raises:
DagsterInvariantViolationError: When one of the YAML documents is invalid and has a
parse error.
"""
pkg_resource_defs = check.opt_list_param(
pkg_resource_defs, "pkg_resource_defs", of_type=tuple
)
try:
yaml_strings = [
six.ensure_str(pkg_resources.resource_string(*pkg_resource_def))
for pkg_resource_def in pkg_resource_defs
]
except (ModuleNotFoundError, FileNotFoundError, UnicodeDecodeError) as err:
six.raise_from(
DagsterInvariantViolationError(
"Encountered error attempting to parse yaml. Loading YAMLs from "
"package resources {pkg_resource_defs} "
'on preset "{name}".'.format(pkg_resource_defs=pkg_resource_defs, name=name)
),
err,
)
return PresetDefinition.from_yaml_strings(name, yaml_strings, solid_selection, mode, tags)
def get_environment_yaml(self):
"""Get the environment dict set on a preset as YAML.
Returns:
str: The environment dict as YAML.
"""
return yaml.dump(self.run_config or {}, default_flow_style=False)
def with_additional_config(self, run_config):
"""Return a new PresetDefinition with additional config merged in to the existing config."""
check.opt_nullable_dict_param(run_config, "run_config")
if run_config is None:
return self
else:
return PresetDefinition(
name=self.name,
solid_selection=self.solid_selection,
mode=self.mode,
tags=self.tags,
run_config=deep_merge_dicts(self.run_config, run_config),
)
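# Hedged usage sketch (an assumption; the yaml file name is hypothetical): presets are
# attached to a pipeline via preset_defs and selected by name at execution time:
#
#     @pipeline(
#         preset_defs=[
#             PresetDefinition.from_files(
#                 "local", config_files=[file_relative_path(__file__, "local.yaml")]
#             )
#         ]
#     )
#     def example_pipeline():
#         ...
#
#     execute_pipeline(example_pipeline, preset="local")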
diff --git a/python_modules/dagster/dagster/core/definitions/reconstructable.py b/python_modules/dagster/dagster/core/definitions/reconstructable.py
index bd6b8ebe6..fbafedd4a 100644
--- a/python_modules/dagster/dagster/core/definitions/reconstructable.py
+++ b/python_modules/dagster/dagster/core/definitions/reconstructable.py
@@ -1,506 +1,506 @@
import inspect
import os
import sys
from collections import namedtuple
+from functools import lru_cache
from dagster import check, seven
from dagster.core.code_pointer import (
CodePointer,
CustomPointer,
FileCodePointer,
ModuleCodePointer,
get_python_file_from_target,
)
from dagster.core.errors import DagsterInvalidSubsetError, DagsterInvariantViolationError
from dagster.core.origin import PipelinePythonOrigin, RepositoryPythonOrigin, SchedulePythonOrigin
from dagster.core.selector import parse_solid_selection
from dagster.serdes import pack_value, unpack_value, whitelist_for_serdes
-from dagster.seven import lru_cache
from dagster.utils.backcompat import experimental
from .pipeline_base import IPipeline
def get_ephemeral_repository_name(pipeline_name):
check.str_param(pipeline_name, "pipeline_name")
return "__repository__{pipeline_name}".format(pipeline_name=pipeline_name)
@whitelist_for_serdes
class ReconstructableRepository(
namedtuple("_ReconstructableRepository", "pointer container_image")
):
def __new__(
cls, pointer, container_image=None,
):
return super(ReconstructableRepository, cls).__new__(
cls,
pointer=check.inst_param(pointer, "pointer", CodePointer),
container_image=check.opt_str_param(container_image, "container_image"),
)
@lru_cache(maxsize=1)
def get_definition(self):
return repository_def_from_pointer(self.pointer)
def get_reconstructable_pipeline(self, name):
return ReconstructablePipeline(self, name)
def get_reconstructable_schedule(self, name):
return ReconstructableSchedule(self, name)
@classmethod
def for_file(cls, file, fn_name, working_directory=None, container_image=None):
if not working_directory:
working_directory = os.getcwd()
return cls(FileCodePointer(file, fn_name, working_directory), container_image)
@classmethod
def for_module(cls, module, fn_name, container_image=None):
return cls(ModuleCodePointer(module, fn_name), container_image)
def get_cli_args(self):
return self.pointer.get_cli_args()
@classmethod
def from_legacy_repository_yaml(cls, file_path):
check.str_param(file_path, "file_path")
absolute_file_path = os.path.abspath(os.path.expanduser(file_path))
return cls(pointer=CodePointer.from_legacy_repository_yaml(absolute_file_path))
def get_python_origin(self):
return RepositoryPythonOrigin(
executable_path=sys.executable,
code_pointer=self.pointer,
container_image=self.container_image,
)
def get_python_origin_id(self):
return self.get_python_origin().get_id()
@whitelist_for_serdes
class ReconstructablePipeline(
namedtuple(
"_ReconstructablePipeline",
"repository pipeline_name solid_selection_str solids_to_execute",
),
IPipeline,
):
def __new__(
cls, repository, pipeline_name, solid_selection_str=None, solids_to_execute=None,
):
check.opt_set_param(solids_to_execute, "solids_to_execute", of_type=str)
return super(ReconstructablePipeline, cls).__new__(
cls,
repository=check.inst_param(repository, "repository", ReconstructableRepository),
pipeline_name=check.str_param(pipeline_name, "pipeline_name"),
solid_selection_str=check.opt_str_param(solid_selection_str, "solid_selection_str"),
solids_to_execute=solids_to_execute,
)
@property
def solid_selection(self):
return seven.json.loads(self.solid_selection_str) if self.solid_selection_str else None
@lru_cache(maxsize=1)
def get_definition(self):
return (
self.repository.get_definition()
.get_pipeline(self.pipeline_name)
.get_pipeline_subset_def(self.solids_to_execute)
)
def _resolve_solid_selection(self, solid_selection):
# resolve a list of solid selection queries to a frozenset of qualified solid names
# e.g. ['foo_solid+'] to {'foo_solid', 'bar_solid'}
check.list_param(solid_selection, "solid_selection", of_type=str)
solids_to_execute = parse_solid_selection(self.get_definition(), solid_selection)
if len(solids_to_execute) == 0:
raise DagsterInvalidSubsetError(
"No qualified solids to execute found for solid_selection={requested}".format(
requested=solid_selection
)
)
return solids_to_execute
def get_reconstructable_repository(self):
return self.repository
def _subset_for_execution(self, solids_to_execute, solid_selection=None):
if solids_to_execute:
pipe = ReconstructablePipeline(
repository=self.repository,
pipeline_name=self.pipeline_name,
solid_selection_str=seven.json.dumps(solid_selection) if solid_selection else None,
solids_to_execute=frozenset(solids_to_execute),
)
else:
pipe = ReconstructablePipeline(
repository=self.repository, pipeline_name=self.pipeline_name,
)
pipe.get_definition() # verify the subset is correct
return pipe
def subset_for_execution(self, solid_selection):
# take a list of solid queries and resolve the queries to names of solids to execute
check.opt_list_param(solid_selection, "solid_selection", of_type=str)
solids_to_execute = (
self._resolve_solid_selection(solid_selection) if solid_selection else None
)
return self._subset_for_execution(solids_to_execute, solid_selection)
def subset_for_execution_from_existing_pipeline(self, solids_to_execute):
# take a frozenset of resolved solid names from an existing pipeline
# so there's no need to parse the selection
check.opt_set_param(solids_to_execute, "solids_to_execute", of_type=str)
return self._subset_for_execution(solids_to_execute)
def describe(self):
return '"{name}" in repository ({repo})'.format(
repo=self.repository.pointer.describe(), name=self.pipeline_name
)
@staticmethod
def for_file(python_file, fn_name):
return bootstrap_standalone_recon_pipeline(
FileCodePointer(python_file, fn_name, os.getcwd())
)
@staticmethod
def for_module(module, fn_name):
return bootstrap_standalone_recon_pipeline(ModuleCodePointer(module, fn_name))
def to_dict(self):
return pack_value(self)
@staticmethod
def from_dict(val):
check.dict_param(val, "val")
inst = unpack_value(val)
check.invariant(
isinstance(inst, ReconstructablePipeline),
"Deserialized object is not instance of ReconstructablePipeline, got {type}".format(
type=type(inst)
),
)
return inst
def get_python_origin(self):
return PipelinePythonOrigin(self.pipeline_name, self.repository.get_python_origin())
def get_python_origin_id(self):
return self.get_python_origin().get_id()
@whitelist_for_serdes
class ReconstructableSchedule(namedtuple("_ReconstructableSchedule", "repository schedule_name",)):
def __new__(
cls, repository, schedule_name,
):
return super(ReconstructableSchedule, cls).__new__(
cls,
repository=check.inst_param(repository, "repository", ReconstructableRepository),
schedule_name=check.str_param(schedule_name, "schedule_name"),
)
def get_python_origin(self):
return SchedulePythonOrigin(self.schedule_name, self.repository.get_python_origin())
def get_python_origin_id(self):
return self.get_python_origin().get_id()
@lru_cache(maxsize=1)
def get_definition(self):
return self.repository.get_definition().get_schedule_def(self.schedule_name)
def reconstructable(target):
"""
Create a ReconstructablePipeline from a function that returns a PipelineDefinition, or a
function decorated with :py:func:`@pipeline <pipeline>`.
When your pipeline must cross process boundaries, e.g., for execution on multiple nodes or
in different systems (like dagstermill), Dagster must know how to reconstruct the pipeline
on the other side of the process boundary.
This function implements a very conservative strategy for reconstructing pipelines, so that
its behavior is easy to predict, but as a consequence it is not able to reconstruct certain
kinds of pipelines, such as those defined by lambdas, in nested scopes (e.g., dynamically
within a method call), or in interactive environments such as the Python REPL or Jupyter
notebooks.
If you need to reconstruct pipelines constructed in these ways, you should use
:py:func:`build_reconstructable_pipeline` instead, which allows you to specify your own
strategy for reconstructing a pipeline.
Examples:
.. code-block:: python
from dagster import PipelineDefinition, pipeline, reconstructable
@pipeline
def foo_pipeline():
...
reconstructable_foo_pipeline = reconstructable(foo_pipeline)
def make_bar_pipeline():
return PipelineDefinition(...)
reconstructable_bar_pipeline = reconstructable(make_bar_pipeline)
"""
from dagster.core.definitions import PipelineDefinition
if not seven.is_function_or_decorator_instance_of(target, PipelineDefinition):
raise DagsterInvariantViolationError(
"Reconstructable target should be a function or definition produced "
"by a decorated function, got {type}.".format(type=type(target)),
)
if seven.is_lambda(target):
raise DagsterInvariantViolationError(
"Reconstructable target can not be a lambda. Use a function or "
"decorated function defined at module scope instead, or use "
"build_reconstructable_pipeline."
)
if seven.qualname_differs(target):
raise DagsterInvariantViolationError(
'Reconstructable target "{target.__name__}" has a different '
'__qualname__ "{target.__qualname__}" indicating it is not '
"defined at module scope. Use a function or decorated function "
"defined at module scope instead, or use build_reconstructable_pipeline.".format(
target=target
)
)
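# Prefer a module-based code pointer when the target is importable from a real
# module (i.e. not defined in __main__); if that fails, fall back to the
# file-based pointer constructed below.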
try:
if (
hasattr(target, "__module__")
and hasattr(target, "__name__")
and inspect.getmodule(target).__name__ != "__main__"
):
return ReconstructablePipeline.for_module(target.__module__, target.__name__)
except: # pylint: disable=bare-except
pass
python_file = get_python_file_from_target(target)
if not python_file:
raise DagsterInvariantViolationError(
"reconstructable() can not reconstruct pipelines defined in interactive environments "
"like , IPython, or Jupyter notebooks. "
"Use a pipeline defined in a module or file instead, or "
"use build_reconstructable_pipeline."
)
pointer = FileCodePointer(
python_file=python_file, fn_name=target.__name__, working_directory=os.getcwd()
)
return bootstrap_standalone_recon_pipeline(pointer)
@experimental
def build_reconstructable_pipeline(
reconstructor_module_name,
reconstructor_function_name,
reconstructable_args=None,
reconstructable_kwargs=None,
):
"""
Create a ReconstructablePipeline.
When your pipeline must cross process boundaries, e.g., for execution on multiple nodes or
in different systems (like dagstermill), Dagster must know how to reconstruct the pipeline
on the other side of the process boundary.
This function allows you to use the strategy of your choice for reconstructing pipelines, so
that you can reconstruct certain kinds of pipelines that are not supported by
:py:func:`reconstructable`, such as those defined by lambdas, in nested scopes (e.g.,
dynamically within a method call), or in interactive environments such as the Python REPL or
Jupyter notebooks.
If you need to reconstruct pipelines constructed in these ways, use this function instead of
:py:func:`reconstructable`.
Args:
reconstructor_module_name (str): The name of the module containing the function to use to
reconstruct the pipeline.
reconstructor_function_name (str): The name of the function to use to reconstruct the
pipeline.
reconstructable_args (Tuple): Args to the function to use to reconstruct the pipeline.
Values of the tuple must be JSON serializable.
reconstructable_kwargs (Dict[str, Any]): Kwargs to the function to use to reconstruct the
pipeline. Values of the dict must be JSON serializable.
Examples:
.. code-block:: python
# module: mymodule
from dagster import PipelineDefinition, pipeline, build_reconstructable_pipeline
class PipelineFactory:
def make_pipeline(self, *args, **kwargs):
@pipeline
def _pipeline(...):
...
return _pipeline
def reconstruct_pipeline(*args, **kwargs):
    factory = PipelineFactory()
    return factory.make_pipeline(*args, **kwargs)
factory = PipelineFactory()
foo_pipeline_args = (...,...)
foo_pipeline_kwargs = {...:...}
foo_pipeline = factory.make_pipeline(*foo_pipeline_args, **foo_pipeline_kwargs)
reconstructable_foo_pipeline = build_reconstructable_pipeline(
'mymodule',
'reconstruct_pipeline',
foo_pipeline_args,
foo_pipeline_kwargs,
)
"""
check.str_param(reconstructor_module_name, "reconstructor_module_name")
check.str_param(reconstructor_function_name, "reconstructor_function_name")
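# Normalize args to a list and kwargs to [key, value] pairs before packing them
# into the CustomPointer, which is what gets serialized across the process boundary.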
reconstructable_args = list(check.opt_tuple_param(reconstructable_args, "reconstructable_args"))
reconstructable_kwargs = list(
(
[key, value]
for key, value in check.opt_dict_param(
reconstructable_kwargs, "reconstructable_kwargs", key_type=str
).items()
)
)
reconstructor_pointer = ModuleCodePointer(
reconstructor_module_name, reconstructor_function_name
)
pointer = CustomPointer(reconstructor_pointer, reconstructable_args, reconstructable_kwargs)
pipeline_def = pipeline_def_from_pointer(pointer)
return ReconstructablePipeline(
repository=ReconstructableRepository(pointer), # creates ephemeral repo
pipeline_name=pipeline_def.name,
)
def bootstrap_standalone_recon_pipeline(pointer):
# This loads the full pipeline definition for the sole purpose of getting the
# pipeline name. If we changed ReconstructablePipeline to get the pipeline on
# demand in order to get the name, we could avoid this.
pipeline_def = pipeline_def_from_pointer(pointer)
return ReconstructablePipeline(
repository=ReconstructableRepository(pointer), # creates ephemeral repo
pipeline_name=pipeline_def.name,
)
def _check_is_loadable(definition):
from .pipeline import PipelineDefinition
from .repository import RepositoryDefinition
if not isinstance(definition, (PipelineDefinition, RepositoryDefinition)):
raise DagsterInvariantViolationError(
(
"Loadable attributes must be either a PipelineDefinition or a "
"RepositoryDefinition. Got {definition}."
).format(definition=repr(definition))
)
return definition
def load_def_in_module(module_name, attribute):
return def_from_pointer(CodePointer.from_module(module_name, attribute))
def load_def_in_package(package_name, attribute):
return def_from_pointer(CodePointer.from_python_package(package_name, attribute))
def load_def_in_python_file(python_file, attribute, working_directory):
return def_from_pointer(CodePointer.from_python_file(python_file, attribute, working_directory))
def def_from_pointer(pointer):
target = pointer.load_target()
from .pipeline import PipelineDefinition
from .repository import RepositoryDefinition
if isinstance(target, (PipelineDefinition, RepositoryDefinition)) or not callable(target):
return _check_is_loadable(target)
# if it's a function, invoke it - otherwise we are pointing to an
# artifact in module scope, likely decorator output
if seven.get_args(target):
raise DagsterInvariantViolationError(
"Error invoking function at {target} with no arguments. "
"Reconstructable target must be callable with no arguments".format(
target=pointer.describe()
)
)
return _check_is_loadable(target())
def pipeline_def_from_pointer(pointer):
from .pipeline import PipelineDefinition
target = def_from_pointer(pointer)
if isinstance(target, PipelineDefinition):
return target
raise DagsterInvariantViolationError(
"CodePointer ({str}) must resolve to a PipelineDefinition. "
"Received a {type}".format(str=pointer.describe(), type=type(target))
)
def repository_def_from_target_def(target):
from .pipeline import PipelineDefinition
from .repository import RepositoryData, RepositoryDefinition
# special case - we can wrap a single pipeline in a repository
if isinstance(target, PipelineDefinition):
# consider including pipeline name in generated repo name
return RepositoryDefinition(
name=get_ephemeral_repository_name(target.name),
repository_data=RepositoryData.from_list([target]),
)
elif isinstance(target, RepositoryDefinition):
return target
else:
return None
def repository_def_from_pointer(pointer):
target = def_from_pointer(pointer)
repo_def = repository_def_from_target_def(target)
if not repo_def:
raise DagsterInvariantViolationError(
"CodePointer ({str}) must resolve to a "
"RepositoryDefinition or a PipelineDefinition. "
"Received a {type}".format(str=pointer.describe(), type=type(target))
)
return repo_def
diff --git a/python_modules/dagster/dagster/core/definitions/utils.py b/python_modules/dagster/dagster/core/definitions/utils.py
index 10aeccdcc..05799df05 100644
--- a/python_modules/dagster/dagster/core/definitions/utils.py
+++ b/python_modules/dagster/dagster/core/definitions/utils.py
@@ -1,221 +1,220 @@
import keyword
import os
import re
from glob import glob
import pkg_resources
import six
import yaml
from dagster import check, seven
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
-from dagster.seven import FileNotFoundError, ModuleNotFoundError # pylint:disable=redefined-builtin
from dagster.utils import frozentags
from dagster.utils.yaml_utils import merge_yaml_strings, merge_yamls
DEFAULT_OUTPUT = "result"
DISALLOWED_NAMES = set(
[
"context",
"conf",
"config",
"meta",
"arg_dict",
"dict",
"input_arg_dict",
"output_arg_dict",
"int",
"str",
"float",
"bool",
"input",
"output",
"type",
]
+ keyword.kwlist # just disallow all python keywords
)
VALID_NAME_REGEX_STR = r"^[A-Za-z0-9_]+$"
VALID_NAME_REGEX = re.compile(VALID_NAME_REGEX_STR)
def has_valid_name_chars(name):
return bool(VALID_NAME_REGEX.match(name))
def check_valid_name(name):
check.str_param(name, "name")
if name in DISALLOWED_NAMES:
raise DagsterInvalidDefinitionError("{name} is not allowed.".format(name=name))
if not has_valid_name_chars(name):
raise DagsterInvalidDefinitionError(
"{name} must be in regex {regex}".format(name=name, regex=VALID_NAME_REGEX_STR)
)
check.invariant(is_valid_name(name))
return name
def is_valid_name(name):
check.str_param(name, "name")
return name not in DISALLOWED_NAMES and has_valid_name_chars(name)
def _kv_str(key, value):
return '{key}="{value}"'.format(key=key, value=repr(value))
def struct_to_string(name, **kwargs):
# Sort the kwargs to ensure consistent representations across Python versions
props_str = ", ".join([_kv_str(key, value) for key, value in sorted(kwargs.items())])
return "{name}({props_str})".format(name=name, props_str=props_str)
def validate_tags(tags):
valid_tags = {}
for key, value in check.opt_dict_param(tags, "tags", key_type=str).items():
if not isinstance(value, str):
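# Non-string tag values must survive a JSON round trip unchanged; the
# JSON-encoded string is what gets stored as the tag value.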
valid = False
err_reason = 'Could not JSON encode value "{}"'.format(value)
try:
str_val = seven.json.dumps(value)
err_reason = 'JSON encoding "{json}" of value "{val}" is not equivalent to original value'.format(
json=str_val, val=value
)
valid = seven.json.loads(str_val) == value
except Exception: # pylint: disable=broad-except
pass
if not valid:
raise DagsterInvalidDefinitionError(
'Invalid value for tag "{key}", {err_reason}. Tag values must be strings '
"or meet the constraint that json.loads(json.dumps(value)) == value.".format(
key=key, err_reason=err_reason
)
)
valid_tags[key] = str_val
else:
valid_tags[key] = value
return frozentags(valid_tags)
def config_from_files(config_files):
"""Constructs run config from YAML files.
Args:
config_files (List[str]): List of paths or glob patterns for yaml files
to load and parse as the run config.
Returns:
Dict[Str, Any]: A run config dictionary constructed from provided YAML files.
Raises:
FileNotFoundError: When a config file produces no results
DagsterInvariantViolationError: When one of the YAML files is invalid and has a parse
error.
"""
config_files = check.opt_list_param(config_files, "config_files")
filenames = []
for file_glob in config_files or []:
globbed_files = glob(file_glob)
if not globbed_files:
raise DagsterInvariantViolationError(
'File or glob pattern "{file_glob}" for "config_files" '
"produced no results.".format(file_glob=file_glob)
)
filenames += [os.path.realpath(globbed_file) for globbed_file in globbed_files]
try:
run_config = merge_yamls(filenames)
except yaml.YAMLError as err:
six.raise_from(
DagsterInvariantViolationError(
"Encountered error attempting to parse yaml. Parsing files {file_set} "
"loaded by file/patterns {files}.".format(file_set=filenames, files=config_files)
),
err,
)
return run_config
def config_from_yaml_strings(yaml_strings):
"""Static constructor for run configs from YAML strings.
Args:
yaml_strings (List[str]): List of yaml strings to parse as the run config.
Returns:
Dict[Str, Any]: A run config dictionary constructed from the provided yaml strings
Raises:
DagsterInvariantViolationError: When one of the YAML documents is invalid and has a
parse error.
"""
yaml_strings = check.opt_list_param(yaml_strings, "yaml_strings", of_type=str)
try:
run_config = merge_yaml_strings(yaml_strings)
except yaml.YAMLError as err:
six.raise_from(
DagsterInvariantViolationError(
"Encountered error attempting to parse yaml. Parsing YAMLs {yaml_strings} ".format(
yaml_strings=yaml_strings
)
),
err,
)
return run_config
def config_from_pkg_resources(pkg_resource_defs):
"""Load a run config from a package resource, using :py:func:`pkg_resources.resource_string`.
Example:
.. code-block:: python
config_from_pkg_resources(
pkg_resource_defs=[
('dagster_examples.airline_demo.environments', 'local_base.yaml'),
('dagster_examples.airline_demo.environments', 'local_warehouse.yaml'),
],
)
Args:
pkg_resource_defs (List[(str, str)]): List of pkg_resource modules/files to
load as the run config.
Returns:
Dict[Str, Any]: A run config dictionary constructed from the provided yaml strings
Raises:
DagsterInvariantViolationError: When one of the YAML documents is invalid and has a
parse error.
"""
pkg_resource_defs = check.opt_list_param(pkg_resource_defs, "pkg_resource_defs", of_type=tuple)
try:
yaml_strings = [
six.ensure_str(pkg_resources.resource_string(*pkg_resource_def))
for pkg_resource_def in pkg_resource_defs
]
except (ModuleNotFoundError, FileNotFoundError, UnicodeDecodeError) as err:
six.raise_from(
DagsterInvariantViolationError(
"Encountered error attempting to parse yaml. Loading YAMLs from "
"package resources {pkg_resource_defs}.".format(pkg_resource_defs=pkg_resource_defs)
),
err,
)
return config_from_yaml_strings(yaml_strings=yaml_strings)
diff --git a/python_modules/dagster/dagster/core/execution/compute_logs.py b/python_modules/dagster/dagster/core/execution/compute_logs.py
index ae7bdffdc..32224f30e 100644
--- a/python_modules/dagster/dagster/core/execution/compute_logs.py
+++ b/python_modules/dagster/dagster/core/execution/compute_logs.py
@@ -1,168 +1,168 @@
from __future__ import print_function
import io
import os
import subprocess
import sys
+import tempfile
import time
import uuid
import warnings
from contextlib import contextmanager
-from dagster import seven
from dagster.core.execution import poll_compute_logs, watch_orphans
from dagster.serdes.ipc import interrupt_ipc_subprocess, open_ipc_subprocess
from dagster.seven import IS_WINDOWS, wait_for_process
from dagster.utils import ensure_file
WIN_PY36_COMPUTE_LOG_DISABLED_MSG = """\u001b[33mWARNING: Compute log capture is disabled for the current environment. Set the environment variable `PYTHONLEGACYWINDOWSSTDIO` to enable.\n\u001b[0m"""
@contextmanager
def redirect_to_file(stream, filepath):
with open(filepath, "a+", buffering=1) as file_stream:
with redirect_stream(file_stream, stream):
yield
@contextmanager
def mirror_stream_to_file(stream, filepath):
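# Capture writes to `stream` in `filepath` while a tail subprocess echoes the file
# back to the stream's original destination; yields the pids of the spawned
# tail/watcher processes.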
ensure_file(filepath)
with tail_to_stream(filepath, stream) as pids:
with redirect_to_file(stream, filepath):
yield pids
def should_disable_io_stream_redirect():
# See https://stackoverflow.com/a/52377087
# https://www.python.org/dev/peps/pep-0528/
return (
os.name == "nt"
and sys.version_info.major == 3
and sys.version_info.minor >= 6
and not os.environ.get("PYTHONLEGACYWINDOWSSTDIO")
)
def warn_if_compute_logs_disabled():
if should_disable_io_stream_redirect():
warnings.warn(WIN_PY36_COMPUTE_LOG_DISABLED_MSG)
@contextmanager
def redirect_stream(to_stream=os.devnull, from_stream=sys.stdout):
# swap the file descriptors to capture system-level output in the process
# From https://stackoverflow.com/questions/4675728/redirect-stdout-to-a-file-in-python/22434262#22434262
from_fd = _fileno(from_stream)
to_fd = _fileno(to_stream)
if not from_fd or not to_fd or should_disable_io_stream_redirect():
yield
return
with os.fdopen(os.dup(from_fd), "wb") as copied:
from_stream.flush()
try:
os.dup2(_fileno(to_stream), from_fd)
except ValueError:
with open(to_stream, "wb") as to_file:
os.dup2(to_file.fileno(), from_fd)
try:
yield from_stream
finally:
from_stream.flush()
to_stream.flush()
os.dup2(copied.fileno(), from_fd)
@contextmanager
def tail_to_stream(path, stream):
if IS_WINDOWS:
with execute_windows_tail(path, stream) as pids:
yield pids
else:
with execute_posix_tail(path, stream) as pids:
yield pids
@contextmanager
def execute_windows_tail(path, stream):
# Cannot use multiprocessing here because we already may be in a daemonized process
# Instead, invoke a thin script to poll a file and dump output to stdout. We pass the current
# pid so that the poll process kills itself if it becomes orphaned
poll_file = os.path.abspath(poll_compute_logs.__file__)
stream = stream if _fileno(stream) else None
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
ipc_output_file = os.path.join(
temp_dir, "execute-windows-tail-{uuid}".format(uuid=uuid.uuid4().hex)
)
try:
tail_process = open_ipc_subprocess(
[sys.executable, poll_file, path, str(os.getpid()), ipc_output_file], stdout=stream
)
yield (tail_process.pid, None)
finally:
if tail_process:
start_time = time.time()
while not os.path.isfile(ipc_output_file):
if time.time() - start_time > 15:
raise Exception("Timed out waiting for tail process to start")
time.sleep(1)
# Now that we know the tail process has started, tell it to terminate once there is
# nothing more to output
interrupt_ipc_subprocess(tail_process)
wait_for_process(tail_process)
@contextmanager
def execute_posix_tail(path, stream):
# open a subprocess to tail the file and print to stdout
tail_cmd = "tail -F -c +0 {}".format(path).split(" ")
stream = stream if _fileno(stream) else None
try:
tail_process = None
watcher_process = None
tail_process = subprocess.Popen(tail_cmd, stdout=stream)
# open a watcher process to check for the orphaning of the tail process (e.g. when the
# current process is suddenly killed)
watcher_file = os.path.abspath(watch_orphans.__file__)
watcher_process = subprocess.Popen(
[sys.executable, watcher_file, str(os.getpid()), str(tail_process.pid),]
)
yield (tail_process.pid, watcher_process.pid)
finally:
if tail_process:
_clean_up_subprocess(tail_process)
if watcher_process:
_clean_up_subprocess(watcher_process)
def _clean_up_subprocess(subprocess_obj):
try:
if subprocess_obj:
subprocess_obj.terminate()
wait_for_process(subprocess_obj)
except OSError:
pass
def _fileno(stream):
try:
fd = getattr(stream, "fileno", lambda: stream)()
except io.UnsupportedOperation:
# Test CLI runners will stub out stdout to a non-file stream, which will raise an
# UnsupportedOperation if `fileno` is accessed. We need to make sure we do not error out,
# or tests will fail
return None
if isinstance(fd, int):
return fd
return None
diff --git a/python_modules/dagster/dagster/core/instance/__init__.py b/python_modules/dagster/dagster/core/instance/__init__.py
index da7a5804d..6c2e886ad 100644
--- a/python_modules/dagster/dagster/core/instance/__init__.py
+++ b/python_modules/dagster/dagster/core/instance/__init__.py
@@ -1,1382 +1,1383 @@
import logging
import os
import sys
+import tempfile
import time
import warnings
from collections import defaultdict
from datetime import datetime
from enum import Enum
import yaml
-from dagster import check, seven
+from dagster import check
from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.pipeline import PipelineDefinition, PipelineSubsetDefinition
from dagster.core.errors import (
DagsterInvariantViolationError,
DagsterRunAlreadyExists,
DagsterRunConflict,
)
from dagster.core.storage.migration.utils import upgrading_instance
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus
from dagster.core.storage.tags import MEMOIZED_RUN_TAG
from dagster.core.system_config.objects import EnvironmentConfig
from dagster.core.utils import str_format_list
from dagster.serdes import ConfigurableClass
from dagster.seven import get_current_datetime_in_utc
from dagster.utils.error import serializable_error_info_from_exc_info
from .config import DAGSTER_CONFIG_YAML_FILENAME
from .ref import InstanceRef
# 'airflow_execution_date' and 'is_airflow_ingest_pipeline' are hardcoded tags used in the
# airflow ingestion logic (see: dagster_pipeline_factory.py). 'airflow_execution_date' stores the
# 'execution_date' used in Airflow operator execution and 'is_airflow_ingest_pipeline' determines
# whether 'airflow_execution_date' is needed.
# https://github.com/dagster-io/dagster/issues/2403
AIRFLOW_EXECUTION_DATE_STR = "airflow_execution_date"
IS_AIRFLOW_INGEST_PIPELINE_STR = "is_airflow_ingest_pipeline"
def _is_dagster_home_set():
return bool(os.getenv("DAGSTER_HOME"))
def is_memoized_run(tags):
return tags is not None and MEMOIZED_RUN_TAG in tags and tags.get(MEMOIZED_RUN_TAG) == "true"
def _dagster_home():
dagster_home_path = os.getenv("DAGSTER_HOME")
if not dagster_home_path:
raise DagsterInvariantViolationError(
(
"The environment variable $DAGSTER_HOME is not set. Dagster requires this "
"environment variable to be set to an existing directory in your filesystem "
"that contains your dagster instance configuration file (dagster.yaml).\n"
"You can resolve this error by exporting the environment variable."
"For example, you can run the following command in your shell or "
"include it in your shell configuration file:\n"
'\texport DAGSTER_HOME="~/dagster_home"'
)
)
dagster_home_path = os.path.expanduser(dagster_home_path)
if not os.path.isabs(dagster_home_path):
raise DagsterInvariantViolationError(
(
'$DAGSTER_HOME "{}" must be an absolute path. Dagster requires this '
"environment variable to be set to an existing directory in your filesystem that"
"contains your dagster instance configuration file (dagster.yaml)."
).format(dagster_home_path)
)
if not (os.path.exists(dagster_home_path) and os.path.isdir(dagster_home_path)):
raise DagsterInvariantViolationError(
(
'$DAGSTER_HOME "{}" is not a directory or does not exist. Dagster requires this '
"environment variable to be set to an existing directory in your filesystem that "
"contains your dagster instance configuration file (dagster.yaml)."
).format(dagster_home_path)
)
return dagster_home_path
def _check_run_equality(pipeline_run, candidate_run):
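# Compare two runs field by field, returning {field_name: (expected, candidate)}
# for every field that differs (an empty dict means the runs are equal).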
check.inst_param(pipeline_run, "pipeline_run", PipelineRun)
check.inst_param(candidate_run, "candidate_run", PipelineRun)
field_diff = {}
for field in pipeline_run._fields:
expected_value = getattr(pipeline_run, field)
candidate_value = getattr(candidate_run, field)
if expected_value != candidate_value:
field_diff[field] = (expected_value, candidate_value)
return field_diff
def _format_field_diff(field_diff):
return "\n".join(
[
(
" {field_name}:\n"
+ " Expected: {expected_value}\n"
+ " Received: {candidate_value}"
).format(
field_name=field_name,
expected_value=expected_value,
candidate_value=candidate_value,
)
for field_name, (expected_value, candidate_value,) in field_diff.items()
]
)
class _EventListenerLogHandler(logging.Handler):
def __init__(self, instance):
self._instance = instance
super(_EventListenerLogHandler, self).__init__()
def emit(self, record):
from dagster.core.events.log import construct_event_record, StructuredLoggerMessage
try:
event = construct_event_record(
StructuredLoggerMessage(
name=record.name,
message=record.msg,
level=record.levelno,
meta=record.dagster_meta,
record=record,
)
)
self._instance.handle_new_event(event)
except Exception as e: # pylint: disable=W0703
logging.critical("Error during instance event listen")
logging.exception(str(e))
raise
class InstanceType(Enum):
PERSISTENT = "PERSISTENT"
EPHEMERAL = "EPHEMERAL"
class DagsterInstance:
"""Core abstraction for managing Dagster's access to storage and other resources.
Use DagsterInstance.get() to grab the current DagsterInstance which will load based on
the values in the ``dagster.yaml`` file in ``$DAGSTER_HOME`` if set, otherwise fallback
to using an ephemeral in-memory set of components.
Configuration of this class should be done by setting values in ``$DAGSTER_HOME/dagster.yaml``.
For example, to use Postgres for run and event log storage, you can write a ``dagster.yaml``
such as the following:
.. literalinclude:: ../../../../docs/sections/deploying/postgres_dagster.yaml
:caption: dagster.yaml
:language: YAML
Args:
instance_type (InstanceType): Indicates whether the instance is ephemeral or persistent.
Users should not attempt to set this value directly or in their ``dagster.yaml`` files.
local_artifact_storage (LocalArtifactStorage): The local artifact storage is used to
configure storage for any artifacts that require a local disk, such as schedules, or
when using the filesystem system storage to manage files and intermediates. By default,
this will be a :py:class:`dagster.core.storage.root.LocalArtifactStorage`. Configurable
in ``dagster.yaml`` using the :py:class:`~dagster.serdes.ConfigurableClass`
machinery.
run_storage (RunStorage): The run storage is used to store metadata about ongoing and past
pipeline runs. By default, this will be a
:py:class:`dagster.core.storage.runs.SqliteRunStorage`. Configurable in ``dagster.yaml``
using the :py:class:`~dagster.serdes.ConfigurableClass` machinery.
event_storage (EventLogStorage): Used to store the structured event logs generated by
pipeline runs. By default, this will be a
:py:class:`dagster.core.storage.event_log.SqliteEventLogStorage`. Configurable in
``dagster.yaml`` using the :py:class:`~dagster.serdes.ConfigurableClass` machinery.
compute_log_manager (ComputeLogManager): The compute log manager handles stdout and stderr
logging for solid compute functions. By default, this will be a
:py:class:`dagster.core.storage.local_compute_log_manager.LocalComputeLogManager`.
Configurable in ``dagster.yaml`` using the
:py:class:`~dagster.serdes.ConfigurableClass` machinery.
run_coordinator (RunCoordinator): A runs coordinator may be used to manage the execution
of pipeline runs.
run_launcher (Optional[RunLauncher]): Optionally, a run launcher may be used to enable
a Dagster instance to launch pipeline runs, e.g. on a remote Kubernetes cluster, in
addition to running them locally.
settings (Optional[Dict]): Specifies certain per-instance settings,
such as feature flags. These are set in the ``dagster.yaml`` under a set of whitelisted
keys.
ref (Optional[InstanceRef]): Used by internal machinery to pass instances across process
boundaries.
"""
_PROCESS_TEMPDIR = None
def __init__(
self,
instance_type,
local_artifact_storage,
run_storage,
event_storage,
compute_log_manager,
schedule_storage=None,
scheduler=None,
run_coordinator=None,
run_launcher=None,
settings=None,
ref=None,
):
from dagster.core.storage.compute_log_manager import ComputeLogManager
from dagster.core.storage.event_log import EventLogStorage
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import RunStorage
from dagster.core.storage.schedules import ScheduleStorage
from dagster.core.scheduler import Scheduler
from dagster.core.run_coordinator import RunCoordinator
from dagster.core.launcher import RunLauncher
self._instance_type = check.inst_param(instance_type, "instance_type", InstanceType)
self._local_artifact_storage = check.inst_param(
local_artifact_storage, "local_artifact_storage", LocalArtifactStorage
)
self._event_storage = check.inst_param(event_storage, "event_storage", EventLogStorage)
self._run_storage = check.inst_param(run_storage, "run_storage", RunStorage)
self._compute_log_manager = check.inst_param(
compute_log_manager, "compute_log_manager", ComputeLogManager
)
self._schedule_storage = check.opt_inst_param(
schedule_storage, "schedule_storage", ScheduleStorage
)
self._scheduler = check.opt_inst_param(scheduler, "scheduler", Scheduler)
self._run_coordinator = check.inst_param(run_coordinator, "run_coordinator", RunCoordinator)
self._run_coordinator.initialize(self)
self._run_launcher = check.inst_param(run_launcher, "run_launcher", RunLauncher)
self._run_launcher.initialize(self)
self._settings = check.opt_dict_param(settings, "settings")
self._ref = check.opt_inst_param(ref, "ref", InstanceRef)
self._subscribers = defaultdict(list)
# ctors
@staticmethod
def ephemeral(tempdir=None, preload=None):
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher
from dagster.core.storage.event_log import InMemoryEventLogStorage
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import InMemoryRunStorage
from dagster.core.storage.noop_compute_log_manager import NoOpComputeLogManager
if tempdir is None:
tempdir = DagsterInstance.temp_storage()
return DagsterInstance(
InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(tempdir),
run_storage=InMemoryRunStorage(preload=preload),
event_storage=InMemoryEventLogStorage(preload=preload),
compute_log_manager=NoOpComputeLogManager(),
run_coordinator=DefaultRunCoordinator(),
run_launcher=SyncInMemoryRunLauncher(),
)
@staticmethod
def get(fallback_storage=None):
# 1. Use $DAGSTER_HOME to determine instance if set.
if _is_dagster_home_set():
return DagsterInstance.from_config(_dagster_home())
# 2. If that is not set use the fallback storage directory if provided.
# This allows us to have a nice out of the box dagit experience where runs are persisted
# across restarts in a tempdir that gets cleaned up when the dagit watchdog process exits.
elif fallback_storage is not None:
return DagsterInstance.from_config(fallback_storage)
# 3. If all else fails create an ephemeral in memory instance.
else:
return DagsterInstance.ephemeral(fallback_storage)
@staticmethod
def local_temp(tempdir=None, overrides=None):
warnings.warn(
"To create a local DagsterInstance for a test, use the instance_for_test "
"context manager instead, which ensures that resoures are cleaned up afterwards"
)
if tempdir is None:
tempdir = DagsterInstance.temp_storage()
return DagsterInstance.from_ref(InstanceRef.from_dir(tempdir, overrides=overrides))
@staticmethod
def from_config(config_dir, config_filename=DAGSTER_CONFIG_YAML_FILENAME):
instance_ref = InstanceRef.from_dir(config_dir, config_filename=config_filename)
return DagsterInstance.from_ref(instance_ref)
@staticmethod
def from_ref(instance_ref):
check.inst_param(instance_ref, "instance_ref", InstanceRef)
return DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=instance_ref.local_artifact_storage,
run_storage=instance_ref.run_storage,
event_storage=instance_ref.event_storage,
compute_log_manager=instance_ref.compute_log_manager,
schedule_storage=instance_ref.schedule_storage,
scheduler=instance_ref.scheduler,
run_coordinator=instance_ref.run_coordinator,
run_launcher=instance_ref.run_launcher,
settings=instance_ref.settings,
ref=instance_ref,
)
# flags
@property
def is_persistent(self):
return self._instance_type == InstanceType.PERSISTENT
@property
def is_ephemeral(self):
return self._instance_type == InstanceType.EPHEMERAL
def get_ref(self):
if self._ref:
return self._ref
check.failed(
"Attempted to prepare an ineligible DagsterInstance ({inst_type}) for cross "
"process communication.{dagster_home_msg}".format(
inst_type=self._instance_type,
dagster_home_msg="\nDAGSTER_HOME environment variable is not set, set it to "
"a directory on the filesystem for dagster to use for storage and cross "
"process coordination."
if os.getenv("DAGSTER_HOME") is None
else "",
)
)
@property
def root_directory(self):
return self._local_artifact_storage.base_dir
@staticmethod
def temp_storage():
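# Cache the TemporaryDirectory object on the class so the backing directory
# persists for the life of the process rather than being cleaned up when this
# call returns.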
if DagsterInstance._PROCESS_TEMPDIR is None:
- DagsterInstance._PROCESS_TEMPDIR = seven.TemporaryDirectory()
+ DagsterInstance._PROCESS_TEMPDIR = tempfile.TemporaryDirectory()
return DagsterInstance._PROCESS_TEMPDIR.name
def _info(self, component):
prefix = " "
# ConfigurableClass may not have inst_data if it's a direct instantiation
# which happens for ephemeral instances
if isinstance(component, ConfigurableClass) and component.inst_data:
return component.inst_data.info_str(prefix)
if type(component) is dict:
return prefix + yaml.dump(component, default_flow_style=False).replace(
"\n", "\n" + prefix
)
return "{}{}\n".format(prefix, component.__class__.__name__)
def info_str_for_component(self, component_name, component):
return "{component_name}:\n{component}\n".format(
component_name=component_name, component=self._info(component)
)
def info_str(self):
settings = self._settings if self._settings else {}
return (
"local_artifact_storage:\n{artifact}\n"
"run_storage:\n{run}\n"
"event_log_storage:\n{event}\n"
"compute_logs:\n{compute}\n"
"schedule_storage:\n{schedule_storage}\n"
"scheduler:\n{scheduler}\n"
"run_coordinator:\n{run_coordinator}\n"
"run_launcher:\n{run_launcher}\n"
"".format(
artifact=self._info(self._local_artifact_storage),
run=self._info(self._run_storage),
event=self._info(self._event_storage),
compute=self._info(self._compute_log_manager),
schedule_storage=self._info(self._schedule_storage),
scheduler=self._info(self._scheduler),
run_coordinator=self._info(self._run_coordinator),
run_launcher=self._info(self._run_launcher),
)
+ "\n".join(
[
"{settings_key}:\n{settings_value}".format(
settings_key=settings_key, settings_value=self._info(settings_value)
)
for settings_key, settings_value in settings.items()
]
)
)
# schedule storage
@property
def schedule_storage(self):
return self._schedule_storage
# schedule storage
@property
def scheduler(self):
return self._scheduler
# run coordinator
@property
def run_coordinator(self):
return self._run_coordinator
# run launcher
@property
def run_launcher(self):
return self._run_launcher
# compute logs
@property
def compute_log_manager(self):
return self._compute_log_manager
def get_settings(self, settings_key):
check.str_param(settings_key, "settings_key")
if self._settings and settings_key in self._settings:
return self._settings.get(settings_key)
return {}
@property
def telemetry_enabled(self):
if self.is_ephemeral:
return False
dagster_telemetry_enabled_default = True
telemetry_settings = self.get_settings("telemetry")
if not telemetry_settings:
return dagster_telemetry_enabled_default
if "enabled" in telemetry_settings:
return telemetry_settings["enabled"]
else:
return dagster_telemetry_enabled_default
def upgrade(self, print_fn=lambda _: None):
with upgrading_instance(self):
print_fn("Updating run storage...")
self._run_storage.upgrade()
print_fn("Updating event storage...")
self._event_storage.upgrade()
print_fn("Updating schedule storage...")
self._schedule_storage.upgrade()
def optimize_for_dagit(self, statement_timeout):
self._run_storage.optimize_for_dagit(statement_timeout=statement_timeout)
self._event_storage.optimize_for_dagit(statement_timeout=statement_timeout)
if self._schedule_storage:
self._schedule_storage.optimize_for_dagit(statement_timeout=statement_timeout)
def reindex(self, print_fn=lambda _: None):
print_fn("Checking for reindexing...")
self._event_storage.reindex(print_fn)
print_fn("Done.")
def dispose(self):
self._run_storage.dispose()
self.run_coordinator.dispose()
self._run_launcher.dispose()
self._event_storage.dispose()
self._compute_log_manager.dispose()
# run storage
def get_run_by_id(self, run_id):
return self._run_storage.get_run_by_id(run_id)
def get_pipeline_snapshot(self, snapshot_id):
return self._run_storage.get_pipeline_snapshot(snapshot_id)
def has_pipeline_snapshot(self, snapshot_id):
return self._run_storage.has_pipeline_snapshot(snapshot_id)
def get_historical_pipeline(self, snapshot_id):
from dagster.core.host_representation import HistoricalPipeline
snapshot = self._run_storage.get_pipeline_snapshot(snapshot_id)
parent_snapshot = (
self._run_storage.get_pipeline_snapshot(snapshot.lineage_snapshot.parent_snapshot_id)
if snapshot.lineage_snapshot
else None
)
return HistoricalPipeline(
self._run_storage.get_pipeline_snapshot(snapshot_id), snapshot_id, parent_snapshot
)
def has_historical_pipeline(self, snapshot_id):
return self._run_storage.has_pipeline_snapshot(snapshot_id)
def get_execution_plan_snapshot(self, snapshot_id):
return self._run_storage.get_execution_plan_snapshot(snapshot_id)
def get_run_stats(self, run_id):
return self._event_storage.get_stats_for_run(run_id)
def get_run_step_stats(self, run_id, step_keys=None):
return self._event_storage.get_step_stats_for_run(run_id, step_keys)
def get_run_tags(self):
return self._run_storage.get_run_tags()
def get_run_group(self, run_id):
return self._run_storage.get_run_group(run_id)
def create_run_for_pipeline(
self,
pipeline_def,
execution_plan=None,
run_id=None,
run_config=None,
mode=None,
solids_to_execute=None,
step_keys_to_execute=None,
status=None,
tags=None,
root_run_id=None,
parent_run_id=None,
solid_selection=None,
):
from dagster.core.execution.api import create_execution_plan
from dagster.core.execution.plan.plan import ExecutionPlan
from dagster.core.snap import snapshot_from_execution_plan
check.inst_param(pipeline_def, "pipeline_def", PipelineDefinition)
check.opt_inst_param(execution_plan, "execution_plan", ExecutionPlan)
# note that solids_to_execute is required to execute the solid subset, which is the
# frozenset version of the previous solid_subset.
# solid_selection is not required and will not be converted to solids_to_execute here.
# i.e. this function doesn't handle solid queries.
# solid_selection is only used to pass the user queries further down.
check.opt_set_param(solids_to_execute, "solids_to_execute", of_type=str)
check.opt_list_param(solid_selection, "solid_selection", of_type=str)
if solids_to_execute:
if isinstance(pipeline_def, PipelineSubsetDefinition):
# for the case when pipeline_def is created by IPipeline or ExternalPipeline
check.invariant(
solids_to_execute == pipeline_def.solids_to_execute,
"Cannot create a PipelineRun from pipeline subset {pipeline_solids_to_execute} "
"that conflicts with solids_to_execute arg {solids_to_execute}".format(
pipeline_solids_to_execute=str_format_list(pipeline_def.solids_to_execute),
solids_to_execute=str_format_list(solids_to_execute),
),
)
else:
# for cases when `create_run_for_pipeline` is directly called
pipeline_def = pipeline_def.get_pipeline_subset_def(
solids_to_execute=solids_to_execute
)
full_execution_plan = execution_plan or create_execution_plan(
pipeline_def, run_config=run_config, mode=mode,
)
check.invariant(
len(full_execution_plan.step_keys_to_execute) == len(full_execution_plan.steps)
)
if is_memoized_run(tags):
from dagster.core.execution.resolve_versions import resolve_memoized_execution_plan
if step_keys_to_execute:
raise DagsterInvariantViolationError(
"step_keys_to_execute parameter cannot be used in conjunction with memoized "
"pipeline runs."
)
subsetted_execution_plan = resolve_memoized_execution_plan(
full_execution_plan
) # TODO: tighter integration with existing step_keys_to_execute functionality
step_keys_to_execute = subsetted_execution_plan.step_keys_to_execute
else:
subsetted_execution_plan = (
full_execution_plan.build_subset_plan(step_keys_to_execute)
if step_keys_to_execute
else full_execution_plan
)
return self.create_run(
pipeline_name=pipeline_def.name,
run_id=run_id,
run_config=run_config,
mode=check.opt_str_param(mode, "mode", default=pipeline_def.get_default_mode_name()),
solid_selection=solid_selection,
solids_to_execute=solids_to_execute,
step_keys_to_execute=step_keys_to_execute,
status=status,
tags=tags,
root_run_id=root_run_id,
parent_run_id=parent_run_id,
pipeline_snapshot=pipeline_def.get_pipeline_snapshot(),
execution_plan_snapshot=snapshot_from_execution_plan(
subsetted_execution_plan, pipeline_def.get_pipeline_snapshot_id()
),
parent_pipeline_snapshot=pipeline_def.get_parent_pipeline_snapshot(),
)
def _construct_run_with_snapshots(
self,
pipeline_name,
run_id,
run_config,
mode,
solids_to_execute,
step_keys_to_execute,
status,
tags,
root_run_id,
parent_run_id,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
solid_selection=None,
external_pipeline_origin=None,
):
# https://github.com/dagster-io/dagster/issues/2403
if tags and IS_AIRFLOW_INGEST_PIPELINE_STR in tags:
if AIRFLOW_EXECUTION_DATE_STR not in tags:
tags[AIRFLOW_EXECUTION_DATE_STR] = get_current_datetime_in_utc().isoformat()
check.invariant(
not (not pipeline_snapshot and execution_plan_snapshot),
"It is illegal to have an execution plan snapshot and not have a pipeline snapshot. "
"It is possible to have no execution plan snapshot since we persist runs "
"that do not successfully compile execution plans in the scheduled case.",
)
pipeline_snapshot_id = (
self._ensure_persisted_pipeline_snapshot(pipeline_snapshot, parent_pipeline_snapshot)
if pipeline_snapshot
else None
)
execution_plan_snapshot_id = (
self._ensure_persisted_execution_plan_snapshot(
execution_plan_snapshot, pipeline_snapshot_id, step_keys_to_execute
)
if execution_plan_snapshot and pipeline_snapshot_id
else None
)
return PipelineRun(
pipeline_name=pipeline_name,
run_id=run_id,
run_config=run_config,
mode=mode,
solid_selection=solid_selection,
solids_to_execute=solids_to_execute,
step_keys_to_execute=step_keys_to_execute,
status=status,
tags=tags,
root_run_id=root_run_id,
parent_run_id=parent_run_id,
pipeline_snapshot_id=pipeline_snapshot_id,
execution_plan_snapshot_id=execution_plan_snapshot_id,
external_pipeline_origin=external_pipeline_origin,
)
def _ensure_persisted_pipeline_snapshot(self, pipeline_snapshot, parent_pipeline_snapshot):
from dagster.core.snap import create_pipeline_snapshot_id, PipelineSnapshot
check.inst_param(pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot)
check.opt_inst_param(parent_pipeline_snapshot, "parent_pipeline_snapshot", PipelineSnapshot)
if pipeline_snapshot.lineage_snapshot:
if not self._run_storage.has_pipeline_snapshot(
pipeline_snapshot.lineage_snapshot.parent_snapshot_id
):
check.invariant(
create_pipeline_snapshot_id(parent_pipeline_snapshot)
== pipeline_snapshot.lineage_snapshot.parent_snapshot_id,
"Parent pipeline snapshot id out of sync with passed parent pipeline snapshot",
)
returned_pipeline_snapshot_id = self._run_storage.add_pipeline_snapshot(
parent_pipeline_snapshot
)
check.invariant(
pipeline_snapshot.lineage_snapshot.parent_snapshot_id
== returned_pipeline_snapshot_id
)
pipeline_snapshot_id = create_pipeline_snapshot_id(pipeline_snapshot)
if not self._run_storage.has_pipeline_snapshot(pipeline_snapshot_id):
returned_pipeline_snapshot_id = self._run_storage.add_pipeline_snapshot(
pipeline_snapshot
)
check.invariant(pipeline_snapshot_id == returned_pipeline_snapshot_id)
return pipeline_snapshot_id
def _ensure_persisted_execution_plan_snapshot(
self, execution_plan_snapshot, pipeline_snapshot_id, step_keys_to_execute
):
from dagster.core.snap.execution_plan_snapshot import (
ExecutionPlanSnapshot,
create_execution_plan_snapshot_id,
)
check.inst_param(execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot)
check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
check.opt_list_param(step_keys_to_execute, "step_keys_to_execute", of_type=str)
check.invariant(
execution_plan_snapshot.pipeline_snapshot_id == pipeline_snapshot_id,
(
"Snapshot mismatch: Snapshot ID in execution plan snapshot is "
'"{ep_pipeline_snapshot_id}" and snapshot_id created in memory is '
'"{pipeline_snapshot_id}"'
).format(
ep_pipeline_snapshot_id=execution_plan_snapshot.pipeline_snapshot_id,
pipeline_snapshot_id=pipeline_snapshot_id,
),
)
check.invariant(
set(step_keys_to_execute) == set(execution_plan_snapshot.step_keys_to_execute)
if step_keys_to_execute
else set(execution_plan_snapshot.step_keys_to_execute)
== set([step.key for step in execution_plan_snapshot.steps]),
"We encode step_keys_to_execute twice in our stack, unfortunately. This check "
"ensures that they are consistent. We check that step_keys_to_execute in the plan "
"matches the step_keys_to_execute params if it is set. If it is not, this indicates "
"a full execution plan, and so we verify that.",
)
execution_plan_snapshot_id = create_execution_plan_snapshot_id(execution_plan_snapshot)
if not self._run_storage.has_execution_plan_snapshot(execution_plan_snapshot_id):
returned_execution_plan_snapshot_id = self._run_storage.add_execution_plan_snapshot(
execution_plan_snapshot
)
check.invariant(execution_plan_snapshot_id == returned_execution_plan_snapshot_id)
return execution_plan_snapshot_id
def create_run(
self,
pipeline_name,
run_id,
run_config,
mode,
solids_to_execute,
step_keys_to_execute,
status,
tags,
root_run_id,
parent_run_id,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
solid_selection=None,
external_pipeline_origin=None,
):
pipeline_run = self._construct_run_with_snapshots(
pipeline_name=pipeline_name,
run_id=run_id,
run_config=run_config,
mode=mode,
solid_selection=solid_selection,
solids_to_execute=solids_to_execute,
step_keys_to_execute=step_keys_to_execute,
status=status,
tags=tags,
root_run_id=root_run_id,
parent_run_id=parent_run_id,
pipeline_snapshot=pipeline_snapshot,
execution_plan_snapshot=execution_plan_snapshot,
parent_pipeline_snapshot=parent_pipeline_snapshot,
external_pipeline_origin=external_pipeline_origin,
)
return self._run_storage.add_run(pipeline_run)
def register_managed_run(
self,
pipeline_name,
run_id,
run_config,
mode,
solids_to_execute,
step_keys_to_execute,
tags,
root_run_id,
parent_run_id,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
solid_selection=None,
):
# The usage of this method is limited to dagster-airflow, specifically in Dagster
# Operators that are executed in Airflow. Because a common workflow in Airflow is to
# retry dags from arbitrary tasks, we need any node to be capable of creating a
# PipelineRun.
#
# The try-except DagsterRunAlreadyExists block handles the race when multiple "root" tasks
# simultaneously execute self._run_storage.add_run(pipeline_run). When this happens, only
# one task succeeds in creating the run, while the others get DagsterRunAlreadyExists
# error; at this point, the failed tasks try again to fetch the existing run.
# https://github.com/dagster-io/dagster/issues/2412
pipeline_run = self._construct_run_with_snapshots(
pipeline_name=pipeline_name,
run_id=run_id,
run_config=run_config,
mode=mode,
solid_selection=solid_selection,
solids_to_execute=solids_to_execute,
step_keys_to_execute=step_keys_to_execute,
status=PipelineRunStatus.MANAGED,
tags=tags,
root_run_id=root_run_id,
parent_run_id=parent_run_id,
pipeline_snapshot=pipeline_snapshot,
execution_plan_snapshot=execution_plan_snapshot,
parent_pipeline_snapshot=parent_pipeline_snapshot,
)
def get_run():
candidate_run = self.get_run_by_id(pipeline_run.run_id)
field_diff = _check_run_equality(pipeline_run, candidate_run)
if field_diff:
raise DagsterRunConflict(
"Found conflicting existing run with same id {run_id}. Runs differ in:"
"\n{field_diff}".format(
run_id=pipeline_run.run_id, field_diff=_format_field_diff(field_diff),
),
)
return candidate_run
if self.has_run(pipeline_run.run_id):
return get_run()
try:
return self._run_storage.add_run(pipeline_run)
except DagsterRunAlreadyExists:
return get_run()
def add_run(self, pipeline_run):
return self._run_storage.add_run(pipeline_run)
def handle_run_event(self, run_id, event):
return self._run_storage.handle_run_event(run_id, event)
def add_run_tags(self, run_id, new_tags):
return self._run_storage.add_run_tags(run_id, new_tags)
def has_run(self, run_id):
return self._run_storage.has_run(run_id)
def get_runs(self, filters=None, cursor=None, limit=None):
return self._run_storage.get_runs(filters, cursor, limit)
def get_runs_count(self, filters=None):
return self._run_storage.get_runs_count(filters)
def get_run_groups(self, filters=None, cursor=None, limit=None):
return self._run_storage.get_run_groups(filters=filters, cursor=cursor, limit=limit)
def wipe(self):
self._run_storage.wipe()
self._event_storage.wipe()
def delete_run(self, run_id):
self._run_storage.delete_run(run_id)
self._event_storage.delete_events(run_id)
# event storage
def logs_after(self, run_id, cursor):
return self._event_storage.get_logs_for_run(run_id, cursor=cursor)
def all_logs(self, run_id):
return self._event_storage.get_logs_for_run(run_id)
def watch_event_logs(self, run_id, cursor, cb):
return self._event_storage.watch(run_id, cursor, cb)
# asset storage
@property
def is_asset_aware(self):
return self._event_storage.is_asset_aware
def check_asset_aware(self):
check.invariant(
self.is_asset_aware,
(
"Asset queries can only be performed on instances with asset-aware event log "
"storage. Use `instance.is_asset_aware` to verify that the instance is configured "
"with an EventLogStorage that implements `AssetAwareEventLogStorage`"
),
)
def all_asset_keys(self, prefix_path=None):
self.check_asset_aware()
return self._event_storage.get_all_asset_keys(prefix_path)
def has_asset_key(self, asset_key):
self.check_asset_aware()
return self._event_storage.has_asset_key(asset_key)
def events_for_asset_key(self, asset_key, cursor=None, limit=None):
check.inst_param(asset_key, "asset_key", AssetKey)
self.check_asset_aware()
return self._event_storage.get_asset_events(asset_key, cursor, limit)
def run_ids_for_asset_key(self, asset_key):
check.inst_param(asset_key, "asset_key", AssetKey)
self.check_asset_aware()
return self._event_storage.get_asset_run_ids(asset_key)
def wipe_assets(self, asset_keys):
check.list_param(asset_keys, "asset_keys", of_type=AssetKey)
self.check_asset_aware()
for asset_key in asset_keys:
self._event_storage.wipe_asset(asset_key)
# event subscriptions
def get_logger(self):
logger = logging.Logger("__event_listener")
logger.addHandler(_EventListenerLogHandler(self))
logger.setLevel(10)
return logger
def handle_new_event(self, event):
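# Persist the event, update run-level status for pipeline events, and fan the
# event out to any subscribers registered for this run.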
run_id = event.run_id
self._event_storage.store_event(event)
if event.is_dagster_event and event.dagster_event.is_pipeline_event:
self._run_storage.handle_run_event(run_id, event.dagster_event)
for sub in self._subscribers[run_id]:
sub(event)
def add_event_listener(self, run_id, cb):
self._subscribers[run_id].append(cb)
def report_engine_event(
self, message, pipeline_run, engine_event_data=None, cls=None, step_key=None,
):
"""
Report an EngineEvent that occurred outside of a pipeline execution context.
"""
from dagster.core.events import EngineEventData, DagsterEvent, DagsterEventType
from dagster.core.events.log import DagsterEventRecord
check.class_param(cls, "cls")
check.str_param(message, "message")
check.inst_param(pipeline_run, "pipeline_run", PipelineRun)
engine_event_data = check.opt_inst_param(
engine_event_data, "engine_event_data", EngineEventData, EngineEventData([]),
)
if cls:
message = "[{}] {}".format(cls.__name__, message)
log_level = logging.INFO
if engine_event_data and engine_event_data.error:
log_level = logging.ERROR
dagster_event = DagsterEvent(
event_type_value=DagsterEventType.ENGINE_EVENT.value,
pipeline_name=pipeline_run.pipeline_name,
message=message,
event_specific_data=engine_event_data,
)
event_record = DagsterEventRecord(
message=message,
user_message=message,
level=log_level,
pipeline_name=pipeline_run.pipeline_name,
run_id=pipeline_run.run_id,
error_info=None,
timestamp=time.time(),
step_key=step_key,
dagster_event=dagster_event,
)
self.handle_new_event(event_record)
return dagster_event
def report_run_canceling(self, run, message=None):
from dagster.core.events import DagsterEvent, DagsterEventType
from dagster.core.events.log import DagsterEventRecord
check.inst_param(run, "run", PipelineRun)
message = check.opt_str_param(message, "message", "Sending pipeline termination request.",)
canceling_event = DagsterEvent(
event_type_value=DagsterEventType.PIPELINE_CANCELING.value,
pipeline_name=run.pipeline_name,
message=message,
)
event_record = DagsterEventRecord(
message=message,
user_message="",
level=logging.INFO,
pipeline_name=run.pipeline_name,
run_id=run.run_id,
error_info=None,
timestamp=time.time(),
dagster_event=canceling_event,
)
self.handle_new_event(event_record)
def report_run_canceled(
self, pipeline_run, message=None,
):
from dagster.core.events import DagsterEvent, DagsterEventType
from dagster.core.events.log import DagsterEventRecord
check.inst_param(pipeline_run, "pipeline_run", PipelineRun)
message = check.opt_str_param(
message,
"mesage",
"This pipeline run has been marked as canceled from outside the execution context.",
)
dagster_event = DagsterEvent(
event_type_value=DagsterEventType.PIPELINE_CANCELED.value,
pipeline_name=pipeline_run.pipeline_name,
message=message,
)
event_record = DagsterEventRecord(
message=message,
user_message=message,
level=logging.ERROR,
pipeline_name=pipeline_run.pipeline_name,
run_id=pipeline_run.run_id,
error_info=None,
timestamp=time.time(),
dagster_event=dagster_event,
)
self.handle_new_event(event_record)
return dagster_event
def report_run_failed(self, pipeline_run, message=None):
from dagster.core.events import DagsterEvent, DagsterEventType
from dagster.core.events.log import DagsterEventRecord
check.inst_param(pipeline_run, "pipeline_run", PipelineRun)
message = check.opt_str_param(
message,
"message",
"This pipeline run has been marked as failed from outside the execution context.",
)
dagster_event = DagsterEvent(
event_type_value=DagsterEventType.PIPELINE_FAILURE.value,
pipeline_name=pipeline_run.pipeline_name,
message=message,
)
event_record = DagsterEventRecord(
message=message,
user_message=message,
level=logging.ERROR,
pipeline_name=pipeline_run.pipeline_name,
run_id=pipeline_run.run_id,
error_info=None,
timestamp=time.time(),
dagster_event=dagster_event,
)
self.handle_new_event(event_record)
return dagster_event
# directories
def file_manager_directory(self, run_id):
return self._local_artifact_storage.file_manager_dir(run_id)
def intermediates_directory(self, run_id):
return self._local_artifact_storage.intermediates_dir(run_id)
def schedules_directory(self):
return self._local_artifact_storage.schedules_dir
# Runs coordinator
def submit_run(self, run_id, external_pipeline):
"""Submit a pipeline run to the coordinator.
This method delegates to the ``RunCoordinator``, configured on the instance, and will
call its implementation of ``RunCoordinator.submit_run()`` to send the run to the
coordinator for execution. Runs should be created in the instance (e.g., by calling
``DagsterInstance.create_run()``) *before* this method is called, and
should be in the ``PipelineRunStatus.NOT_STARTED`` state. They also must have a non-null
ExternalPipelineOrigin.
Args:
run_id (str): The id of the run.
"""
from dagster.core.host_representation import ExternalPipelineOrigin
run = self.get_run_by_id(run_id)
check.inst(
run.external_pipeline_origin,
ExternalPipelineOrigin,
"External pipeline origin must be set for submitted runs",
)
try:
submitted_run = self._run_coordinator.submit_run(
run, external_pipeline=external_pipeline
)
except:
from dagster.core.events import EngineEventData
error = serializable_error_info_from_exc_info(sys.exc_info())
self.report_engine_event(
error.message, run, EngineEventData.engine_error(error),
)
self.report_run_failed(run)
raise
return submitted_run
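# Illustrative sketch (comment only, not part of this change): once a run has been created
# on the instance with a non-null ExternalPipelineOrigin, it can be handed to the configured
# RunCoordinator; `instance`, `run`, and `external_pipeline` below are hypothetical.
#
#   submitted = instance.submit_run(run.run_id, external_pipeline)
#   assert submitted.run_id == run.run_id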
# Run launcher
def launch_run(self, run_id, external_pipeline):
"""Launch a pipeline run.
This method is typically called using `instance.submit_run` rather than being invoked
directly. This method delegates to the ``RunLauncher``, if any, configured on the instance,
and will call its implementation of ``RunLauncher.launch_run()`` to begin the execution of
the specified run. Runs should be created in the instance (e.g., by calling
``DagsterInstance.create_run()``) *before* this method is called, and should be in the
``PipelineRunStatus.NOT_STARTED`` state.
Args:
run_id (str): The id of the run to launch.
"""
run = self.get_run_by_id(run_id)
from dagster.core.events import EngineEventData, DagsterEvent, DagsterEventType
from dagster.core.events.log import DagsterEventRecord
launch_started_event = DagsterEvent(
event_type_value=DagsterEventType.PIPELINE_STARTING.value,
pipeline_name=run.pipeline_name,
)
event_record = DagsterEventRecord(
message="",
user_message="",
level=logging.INFO,
pipeline_name=run.pipeline_name,
run_id=run.run_id,
error_info=None,
timestamp=time.time(),
dagster_event=launch_started_event,
)
self.handle_new_event(event_record)
run = self.get_run_by_id(run_id)
try:
self._run_launcher.launch_run(self, run, external_pipeline=external_pipeline)
except:
error = serializable_error_info_from_exc_info(sys.exc_info())
self.report_engine_event(
error.message, run, EngineEventData.engine_error(error),
)
self.report_run_failed(run)
raise
return run
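# Illustrative sketch (comment only): launching a run directly via the configured
# RunLauncher; most callers should go through `instance.submit_run` instead.
# `instance`, `run`, and `external_pipeline` are hypothetical.
#
#   launched = instance.launch_run(run.run_id, external_pipeline)
#   assert launched.run_id == run.run_id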
# Scheduler
def reconcile_scheduler_state(self, external_repository):
return self._scheduler.reconcile_scheduler_state(self, external_repository)
def start_schedule_and_update_storage_state(self, external_schedule):
return self._scheduler.start_schedule_and_update_storage_state(self, external_schedule)
def stop_schedule_and_update_storage_state(self, schedule_origin_id):
return self._scheduler.stop_schedule_and_update_storage_state(self, schedule_origin_id)
def stop_schedule_and_delete_from_storage(self, schedule_origin_id):
return self._scheduler.stop_schedule_and_delete_from_storage(self, schedule_origin_id)
def running_schedule_count(self, schedule_origin_id):
if self._scheduler:
return self._scheduler.running_schedule_count(self, schedule_origin_id)
return 0
def scheduler_debug_info(self):
from dagster.core.scheduler import SchedulerDebugInfo
from dagster.core.definitions.job import JobType
from dagster.core.scheduler.job import JobStatus
errors = []
schedules = []
for schedule_state in self.all_stored_job_state(job_type=JobType.SCHEDULE):
if schedule_state.status == JobStatus.RUNNING and not self.running_schedule_count(
schedule_state.job_origin_id
):
errors.append(
"Schedule {schedule_name} is set to be running, but the scheduler is not "
"running the schedule.".format(schedule_name=schedule_state.job_name)
)
elif schedule_state.status == JobStatus.STOPPED and self.running_schedule_count(
schedule_state.job_origin_id
):
errors.append(
"Schedule {schedule_name} is set to be stopped, but the scheduler is still running "
"the schedule.".format(schedule_name=schedule_state.job_name)
)
if self.running_schedule_count(schedule_state.job_origin_id) > 1:
errors.append(
"Duplicate jobs found: More than one job for schedule {schedule_name} are "
"running on the scheduler.".format(schedule_name=schedule_state.job_name)
)
schedule_info = {
schedule_state.job_name: {
"status": schedule_state.status.value,
"cron_schedule": schedule_state.job_specific_data.cron_schedule,
"repository_pointer": schedule_state.origin.get_repo_cli_args(),
"schedule_origin_id": schedule_state.job_origin_id,
"repository_origin_id": schedule_state.repository_origin_id,
}
}
schedules.append(yaml.safe_dump(schedule_info, default_flow_style=False))
return SchedulerDebugInfo(
scheduler_config_info=self.info_str_for_component("Scheduler", self.scheduler),
scheduler_info=self.scheduler.debug_info(),
schedule_storage=schedules,
errors=errors,
)
# Schedule Storage
def start_sensor(self, external_sensor):
from dagster.core.scheduler.job import JobState, JobStatus, SensorJobData
from dagster.core.definitions.job import JobType
job_state = self.get_job_state(external_sensor.get_external_origin_id())
if not job_state:
self.add_job_state(
JobState(
external_sensor.get_external_origin(),
JobType.SENSOR,
JobStatus.RUNNING,
SensorJobData(datetime.utcnow().timestamp()),
)
)
elif job_state.status != JobStatus.RUNNING:
# set the last completed time to the modified state time
self.update_job_state(
job_state.with_status(JobStatus.RUNNING).with_data(
SensorJobData(datetime.utcnow().timestamp())
)
)
def stop_sensor(self, job_origin_id):
job_state = self.get_job_state(job_origin_id)
if job_state:
self.delete_job_state(job_origin_id)
def all_stored_job_state(self, repository_origin_id=None, job_type=None):
return self._schedule_storage.all_stored_job_state(repository_origin_id, job_type)
def get_job_state(self, job_origin_id):
return self._schedule_storage.get_job_state(job_origin_id)
def add_job_state(self, job_state):
return self._schedule_storage.add_job_state(job_state)
def update_job_state(self, job_state):
return self._schedule_storage.update_job_state(job_state)
def delete_job_state(self, job_origin_id):
return self._schedule_storage.delete_job_state(job_origin_id)
def get_job_ticks(self, job_origin_id):
return self._schedule_storage.get_job_ticks(job_origin_id)
def get_latest_job_tick(self, job_origin_id):
return self._schedule_storage.get_latest_job_tick(job_origin_id)
def create_job_tick(self, job_tick_data):
return self._schedule_storage.create_job_tick(job_tick_data)
def update_job_tick(self, tick):
return self._schedule_storage.update_job_tick(tick)
def get_job_tick_stats(self, job_origin_id):
return self._schedule_storage.get_job_tick_stats(job_origin_id)
def purge_job_ticks(self, job_origin_id, tick_status, before):
self._schedule_storage.purge_job_ticks(job_origin_id, tick_status, before)
def wipe_all_schedules(self):
if self._scheduler:
self._scheduler.wipe(self)
self._schedule_storage.wipe()
def logs_path_for_schedule(self, schedule_origin_id):
return self._scheduler.get_logs_path(self, schedule_origin_id)
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
self.dispose()
def get_addresses_for_step_output_versions(self, step_output_versions):
"""
For each given step output, finds whether an output exists with the given
version, and returns its address if it does.
Args:
step_output_versions (Dict[(str, StepOutputHandle), str]):
(pipeline name, step output handle) -> version.
Returns:
Dict[(str, StepOutputHandle), str]: (pipeline name, step output handle) -> address.
For each step output, an address if there is one and None otherwise.
"""
return self._event_storage.get_addresses_for_step_output_versions(step_output_versions)
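# Illustrative sketch (comment only): the expected shape of the argument and return value;
# `step_output_handle` is a hypothetical StepOutputHandle for a step in "my_pipeline".
#
#   versions = {("my_pipeline", step_output_handle): "abc123"}
#   addresses = instance.get_addresses_for_step_output_versions(versions)
#   # addresses[("my_pipeline", step_output_handle)] is an address string, or None if
#   # no output with that version exists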
# dagster daemon
def add_daemon_heartbeat(self, daemon_heartbeat):
"""Called on a regular interval by the daemon"""
self._run_storage.add_daemon_heartbeat(daemon_heartbeat)
def get_daemon_heartbeats(self):
"""Latest heartbeats of all daemon types"""
return self._run_storage.get_daemon_heartbeats()
def wipe_daemon_heartbeats(self):
self._run_storage.wipe_daemon_heartbeats()
diff --git a/python_modules/dagster/dagster/core/storage/runs/sqlite/sqlite_run_storage.py b/python_modules/dagster/dagster/core/storage/runs/sqlite/sqlite_run_storage.py
index 030b22a78..1b0a50e5f 100644
--- a/python_modules/dagster/dagster/core/storage/runs/sqlite/sqlite_run_storage.py
+++ b/python_modules/dagster/dagster/core/storage/runs/sqlite/sqlite_run_storage.py
@@ -1,132 +1,132 @@
import os
from contextlib import contextmanager
+from urllib.parse import urljoin, urlparse
import sqlalchemy as db
from dagster import StringSource, check
from dagster.core.storage.sql import (
check_alembic_revision,
create_engine,
get_alembic_config,
handle_schema_errors,
run_alembic_downgrade,
run_alembic_upgrade,
stamp_alembic_rev,
)
from dagster.core.storage.sqlite import create_db_conn_string
from dagster.serdes import ConfigurableClass, ConfigurableClassData
-from dagster.seven import urljoin, urlparse
from dagster.utils import mkdir_p
from sqlalchemy.pool import NullPool
from ..schema import RunStorageSqlMetadata, RunTagsTable, RunsTable
from ..sql_run_storage import SqlRunStorage
class SqliteRunStorage(SqlRunStorage, ConfigurableClass):
"""SQLite-backed run storage.
Users should not directly instantiate this class; it is instantiated by internal machinery when
``dagit`` and ``dagster-graphql`` load, based on the values in the ``dagster.yaml`` file in
``$DAGSTER_HOME``. Configuration of this class should be done by setting values in that file.
This is the default run storage when none is specified in the ``dagster.yaml``.
To explicitly specify SQLite for run storage, you can add a block such as the following to your
``dagster.yaml``:
.. code-block:: YAML
run_storage:
module: dagster.core.storage.runs
class: SqliteRunStorage
config:
base_dir: /path/to/dir
The ``base_dir`` param tells the run storage where on disk to store the database.
"""
def __init__(self, conn_string, inst_data=None):
check.str_param(conn_string, "conn_string")
self._conn_string = conn_string
self._inst_data = check.opt_inst_param(inst_data, "inst_data", ConfigurableClassData)
@property
def inst_data(self):
return self._inst_data
@classmethod
def config_type(cls):
return {"base_dir": StringSource}
@staticmethod
def from_config_value(inst_data, config_value):
return SqliteRunStorage.from_local(inst_data=inst_data, **config_value)
@staticmethod
def from_local(base_dir, inst_data=None):
check.str_param(base_dir, "base_dir")
mkdir_p(base_dir)
conn_string = create_db_conn_string(base_dir, "runs")
engine = create_engine(conn_string, poolclass=NullPool)
engine.execute("PRAGMA journal_mode=WAL;")
RunStorageSqlMetadata.create_all(engine)
alembic_config = get_alembic_config(__file__)
connection = engine.connect()
db_revision, head_revision = check_alembic_revision(alembic_config, connection)
if not (db_revision and head_revision):
stamp_alembic_rev(alembic_config, engine)
return SqliteRunStorage(conn_string, inst_data)
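# Illustrative sketch (comment only): constructing SQLite-backed run storage
# programmatically rather than through dagster.yaml; the directory below is hypothetical.
#
#   storage = SqliteRunStorage.from_local("/tmp/dagster-runs")
#   runs = storage.get_runs()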
@contextmanager
def connect(self):
engine = create_engine(self._conn_string, poolclass=NullPool)
conn = engine.connect()
try:
with handle_schema_errors(
conn, get_alembic_config(__file__), msg="Sqlite run storage requires migration",
):
yield conn
finally:
conn.close()
def _alembic_upgrade(self, rev="head"):
alembic_config = get_alembic_config(__file__)
with self.connect() as conn:
run_alembic_upgrade(alembic_config, conn, rev=rev)
def _alembic_downgrade(self, rev="head"):
alembic_config = get_alembic_config(__file__)
with self.connect() as conn:
run_alembic_downgrade(alembic_config, conn, rev=rev)
def upgrade(self):
self._check_for_version_066_migration_and_perform()
self._alembic_upgrade()
# In version 0.6.6, we changed the layout of the sqlite dbs on disk,
# moving them from DAGSTER_HOME/runs.db to DAGSTER_HOME/history/runs.db.
# This function checks for that condition and does the move
def _check_for_version_066_migration_and_perform(self):
old_conn_string = "sqlite://" + urljoin(urlparse(self._conn_string).path, "../runs.db")
path_to_old_db = urlparse(old_conn_string).path
# sqlite URLs look like `sqlite:///foo/bar/baz` on Unix/Mac, but on Windows they look like
# `sqlite:///D:/foo/bar/baz` (or `sqlite:///D:\foo\bar\baz`)
if os.name == "nt":
path_to_old_db = path_to_old_db.lstrip("/")
if os.path.exists(path_to_old_db):
old_storage = SqliteRunStorage(old_conn_string)
old_runs = old_storage.get_runs()
for run in old_runs:
self.add_run(run)
os.unlink(path_to_old_db)
def delete_run(self, run_id):
""" Override the default sql delete run implementation until we can get full
support on cascading deletes """
check.str_param(run_id, "run_id")
remove_tags = db.delete(RunTagsTable).where(RunTagsTable.c.run_id == run_id)
remove_run = db.delete(RunsTable).where(RunsTable.c.run_id == run_id)
with self.connect() as conn:
conn.execute(remove_tags)
conn.execute(remove_run)
diff --git a/python_modules/dagster/dagster/core/storage/sql.py b/python_modules/dagster/dagster/core/storage/sql.py
index 2f99590f9..a233785e1 100644
--- a/python_modules/dagster/dagster/core/storage/sql.py
+++ b/python_modules/dagster/dagster/core/storage/sql.py
@@ -1,128 +1,128 @@
# pylint chokes on the perfectly ok import from alembic.migration
import sys
from contextlib import contextmanager
+from functools import lru_cache
import sqlalchemy as db
from alembic.command import downgrade, stamp, upgrade
from alembic.config import Config
from alembic.migration import MigrationContext # pylint: disable=import-error
from alembic.script import ScriptDirectory
from dagster.core.errors import DagsterInstanceMigrationRequired
-from dagster.seven import lru_cache
from dagster.utils import file_relative_path
from dagster.utils.log import quieten
create_engine = db.create_engine # exported
@lru_cache(maxsize=3) # run, event, and schedule storages
def get_alembic_config(dunder_file, config_path="alembic/alembic.ini", script_path="alembic/"):
alembic_config = Config(file_relative_path(dunder_file, config_path))
alembic_config.set_main_option("script_location", file_relative_path(dunder_file, script_path))
return alembic_config
def run_alembic_upgrade(alembic_config, conn, run_id=None, rev="head"):
alembic_config.attributes["connection"] = conn
alembic_config.attributes["run_id"] = run_id
upgrade(alembic_config, rev)
def run_alembic_downgrade(alembic_config, conn, rev, run_id=None):
alembic_config.attributes["connection"] = conn
alembic_config.attributes["run_id"] = run_id
downgrade(alembic_config, rev)
def stamp_alembic_rev(alembic_config, conn, rev="head", quiet=True):
with quieten(quiet):
alembic_config.attributes["connection"] = conn
stamp(alembic_config, rev)
def check_alembic_revision(alembic_config, conn):
migration_context = MigrationContext.configure(conn)
db_revision = migration_context.get_current_revision()
script = ScriptDirectory.from_config(alembic_config)
head_revision = script.as_revision_number("head")
return (db_revision, head_revision)
@contextmanager
def handle_schema_errors(conn, alembic_config, msg=None):
try:
yield
except (db.exc.OperationalError, db.exc.ProgrammingError, db.exc.StatementError):
db_revision, head_revision = (None, None)
try:
with quieten():
db_revision, head_revision = check_alembic_revision(alembic_config, conn)
# If exceptions were raised during the revision check, we want to swallow them and
# allow the original exception to fall through
except Exception: # pylint: disable=broad-except
pass
if db_revision != head_revision:
raise DagsterInstanceMigrationRequired(
msg=msg,
db_revision=db_revision,
head_revision=head_revision,
original_exc_info=sys.exc_info(),
)
raise
def run_migrations_offline(context, config, target_metadata):
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
connectable = config.attributes.get("connection", None)
if connectable is None:
raise Exception(
"No connection set in alembic config. If you are trying to run this script from the "
"command line, STOP and read the README."
)
context.configure(
url=connectable.url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online(context, config, target_metadata):
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = config.attributes.get("connection", None)
if connectable is None:
raise Exception(
"No connection set in alembic config. If you are trying to run this script from the "
"command line, STOP and read the README."
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
diff --git a/python_modules/dagster/dagster/core/test_utils.py b/python_modules/dagster/dagster/core/test_utils.py
index 0fab3110e..1a760a954 100644
--- a/python_modules/dagster/dagster/core/test_utils.py
+++ b/python_modules/dagster/dagster/core/test_utils.py
@@ -1,375 +1,376 @@
import os
import signal
import sys
+import tempfile
import time
from contextlib import contextmanager
import pendulum
import yaml
-from dagster import Shape, check, composite_solid, pipeline, seven, solid
+from dagster import Shape, check, composite_solid, pipeline, solid
from dagster.core.host_representation import ExternalPipeline
from dagster.core.host_representation.origin import ExternalPipelineOrigin
from dagster.core.instance import DagsterInstance
from dagster.core.launcher import RunLauncher
from dagster.core.launcher.default_run_launcher import DefaultRunLauncher
from dagster.core.run_coordinator import RunCoordinator
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus
from dagster.core.telemetry import cleanup_telemetry_logger
from dagster.serdes import ConfigurableClass
from dagster.utils.error import serializable_error_info_from_exc_info
def step_output_event_filter(pipe_iterator):
for step_event in pipe_iterator:
if step_event.is_successful_output:
yield step_event
def nesting_composite_pipeline(depth, num_children, *args, **kwargs):
"""Creates a pipeline of nested composite solids up to "depth" layers, with a fan-out of
num_children at each layer.
Total number of solids will be num_children ^ depth
"""
@solid
def leaf_node(_):
return 1
def create_wrap(inner, name):
@composite_solid(name=name)
def wrap():
for i in range(num_children):
solid_alias = "%s_node_%d" % (name, i)
inner.alias(solid_alias)()
return wrap
@pipeline(*args, **kwargs)
def nested_pipeline():
comp_solid = create_wrap(leaf_node, "layer_%d" % depth)
for i in range(depth):
comp_solid = create_wrap(comp_solid, "layer_%d" % (depth - (i + 1)))
comp_solid.alias("outer")()
return nested_pipeline
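# Illustrative sketch (comment only): a depth-2 pipeline with a fan-out of 3 yields
# 3 ^ 2 = 9 leaf solids.
#
#   from dagster import execute_pipeline
#   result = execute_pipeline(nesting_composite_pipeline(depth=2, num_children=3))
#   assert result.success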
@contextmanager
def environ(env):
"""Temporarily set environment variables inside the context manager and
fully restore previous environment afterwards
"""
previous_values = {key: os.getenv(key) for key in env}
for key, value in env.items():
if value is None:
if key in os.environ:
del os.environ[key]
else:
os.environ[key] = value
try:
yield
finally:
for key, value in previous_values.items():
if value is None:
if key in os.environ:
del os.environ[key]
else:
os.environ[key] = value
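# Illustrative sketch (comment only): temporarily overriding an environment variable; the
# value below is hypothetical.
#
#   with environ({"DAGSTER_HOME": "/tmp/dagster-home"}):
#       assert os.getenv("DAGSTER_HOME") == "/tmp/dagster-home"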
@contextmanager
def instance_for_test(overrides=None):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(temp_dir, overrides) as instance:
yield instance
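# Illustrative sketch (comment only): using a throwaway instance in a test; the overrides
# dict mirrors the structure of dagster.yaml and is hypothetical here.
#
#   with instance_for_test(overrides={"telemetry": {"enabled": False}}) as instance:
#       assert instance.get_runs() == []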
@contextmanager
def instance_for_test_tempdir(temp_dir, overrides=None):
# Write any overrides to disk and set DAGSTER_HOME so that they will still apply when
# DagsterInstance.get() is called from a different process
with environ({"DAGSTER_HOME": temp_dir}):
with open(os.path.join(temp_dir, "dagster.yaml"), "w") as fd:
yaml.dump(overrides, fd, default_flow_style=False)
with DagsterInstance.get() as instance:
try:
yield instance
except:
sys.stderr.write(
"Test raised an exception, attempting to clean up instance:"
+ serializable_error_info_from_exc_info(sys.exc_info()).to_string()
+ "\n"
)
raise
finally:
cleanup_test_instance(instance)
def cleanup_test_instance(instance):
check.inst_param(instance, "instance", DagsterInstance)
# To avoid filesystem contention when we close the temporary directory, wait for
# all runs to reach a terminal state, and close any subprocesses or threads
# that might be accessing the run history DB.
instance.run_launcher.join()
if isinstance(instance.run_launcher, DefaultRunLauncher):
instance.run_launcher.cleanup_managed_grpc_servers()
cleanup_telemetry_logger()
def create_run_for_test(
instance,
pipeline_name=None,
run_id=None,
run_config=None,
mode=None,
solids_to_execute=None,
step_keys_to_execute=None,
status=None,
tags=None,
root_run_id=None,
parent_run_id=None,
pipeline_snapshot=None,
execution_plan_snapshot=None,
parent_pipeline_snapshot=None,
external_pipeline_origin=None,
):
return instance.create_run(
pipeline_name,
run_id,
run_config,
mode,
solids_to_execute,
step_keys_to_execute,
status,
tags,
root_run_id,
parent_run_id,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
external_pipeline_origin=external_pipeline_origin,
)
def register_managed_run_for_test(
instance,
pipeline_name=None,
run_id=None,
run_config=None,
mode=None,
solids_to_execute=None,
step_keys_to_execute=None,
tags=None,
root_run_id=None,
parent_run_id=None,
pipeline_snapshot=None,
execution_plan_snapshot=None,
parent_pipeline_snapshot=None,
):
return instance.register_managed_run(
pipeline_name,
run_id,
run_config,
mode,
solids_to_execute,
step_keys_to_execute,
tags,
root_run_id,
parent_run_id,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
)
def poll_for_finished_run(instance, run_id, timeout=20):
total_time = 0
interval = 0.01
while True:
run = instance.get_run_by_id(run_id)
if run.is_finished:
return run
else:
time.sleep(interval)
total_time += interval
if total_time > timeout:
raise Exception("Timed out")
def poll_for_step_start(instance, run_id, timeout=30):
poll_for_event(instance, run_id, event_type="STEP_START", message=None, timeout=timeout)
def poll_for_event(instance, run_id, event_type, message, timeout=30):
total_time = 0
backoff = 0.01
while True:
time.sleep(backoff)
logs = instance.all_logs(run_id)
matching_events = [
log_record.dagster_event
for log_record in logs
if log_record.dagster_event.event_type_value == event_type
]
if matching_events:
if message is None:
return
for matching_message in (event.message for event in matching_events):
if message in matching_message:
return
total_time += backoff
backoff = backoff * 2
if total_time > timeout:
raise Exception("Timed out")
@contextmanager
def new_cwd(path):
old = os.getcwd()
try:
os.chdir(path)
yield
finally:
os.chdir(old)
def today_at_midnight(timezone_name=None):
now = pendulum.now(timezone_name)
return pendulum.create(now.year, now.month, now.day, tz=now.timezone.name)
class ExplodingRunLauncher(RunLauncher, ConfigurableClass):
def __init__(self, inst_data=None):
self._inst_data = inst_data
@property
def inst_data(self):
return self._inst_data
@classmethod
def config_type(cls):
return {}
@staticmethod
def from_config_value(inst_data, config_value):
return ExplodingRunLauncher(inst_data=inst_data)
def launch_run(self, instance, run, external_pipeline):
raise NotImplementedError("The entire purpose of this is to throw on launch")
def join(self, timeout=30):
"""Nothing to join on since all executions are synchronous."""
def can_terminate(self, run_id):
return False
def terminate(self, run_id):
check.not_implemented("Termination not supported")
class MockedRunLauncher(RunLauncher, ConfigurableClass):
def __init__(self, inst_data=None):
self._inst_data = inst_data
self._queue = []
def launch_run(self, instance, run, external_pipeline):
check.inst_param(instance, "instance", DagsterInstance)
check.inst_param(run, "run", PipelineRun)
check.inst_param(external_pipeline, "external_pipeline", ExternalPipeline)
check.invariant(run.status == PipelineRunStatus.STARTING)
self._queue.append(run)
return run
def queue(self):
return self._queue
@classmethod
def config_type(cls):
return Shape({})
@classmethod
def from_config_value(cls, inst_data, config_value):
return cls(inst_data=inst_data,)
@property
def inst_data(self):
return self._inst_data
def can_terminate(self, run_id):
return False
def terminate(self, run_id):
check.not_implemented("Termintation not supported")
class MockedRunCoordinator(RunCoordinator, ConfigurableClass):
def __init__(self, inst_data=None):
self._inst_data = inst_data
self._queue = []
def submit_run(self, pipeline_run, external_pipeline):
check.inst_param(pipeline_run, "run", PipelineRun)
check.opt_inst_param(external_pipeline, "external_pipeline", ExternalPipeline)
check.inst(pipeline_run.external_pipeline_origin, ExternalPipelineOrigin)
self._queue.append(pipeline_run)
return pipeline_run
def queue(self):
return self._queue
@classmethod
def config_type(cls):
return Shape({})
@classmethod
def from_config_value(cls, inst_data, config_value):
return cls(inst_data=inst_data,)
@property
def inst_data(self):
return self._inst_data
def can_cancel_run(self, run_id):
check.not_implemented("Cancellation not supported")
def cancel_run(self, run_id):
check.not_implemented("Cancellation not supported")
def get_terminate_signal():
if sys.platform == "win32":
return signal.SIGTERM
return signal.SIGKILL
def get_crash_signals():
if sys.platform == "win32":
return [
get_terminate_signal()
] # Windows keeps resources open after termination in a way that messes up tests
else:
return [get_terminate_signal(), signal.SIGINT]
_mocked_system_timezone = {"timezone": None}
@contextmanager
def mock_system_timezone(override_timezone):
with pendulum.tz.LocalTimezone.test(pendulum.Timezone.load(override_timezone)):
try:
_mocked_system_timezone["timezone"] = override_timezone
yield
finally:
_mocked_system_timezone["timezone"] = None
def get_mocked_system_timezone():
return _mocked_system_timezone["timezone"]
diff --git a/python_modules/dagster/dagster/core/types/dagster_type.py b/python_modules/dagster/dagster/core/types/dagster_type.py
index f3e55a1d9..5f6996952 100644
--- a/python_modules/dagster/dagster/core/types/dagster_type.py
+++ b/python_modules/dagster/dagster/core/types/dagster_type.py
@@ -1,883 +1,883 @@
from abc import abstractmethod
from enum import Enum as PythonEnum
from functools import partial
import six
from dagster import check
from dagster.builtins import BuiltinEnum
from dagster.config.config_type import Array
from dagster.config.config_type import Noneable as ConfigNoneable
from dagster.core.definitions.events import TypeCheck
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster.core.storage.type_storage import TypeStoragePlugin
from dagster.serdes import whitelist_for_serdes
from dagster.utils.backcompat import rename_warning
from .builtin_config_schemas import BuiltinSchemas
from .config_schema import DagsterTypeLoader, DagsterTypeMaterializer
from .marshal import PickleSerializationStrategy, SerializationStrategy
@whitelist_for_serdes
class DagsterTypeKind(PythonEnum):
ANY = "ANY"
SCALAR = "SCALAR"
LIST = "LIST"
NOTHING = "NOTHING"
NULLABLE = "NULLABLE"
REGULAR = "REGULAR"
class DagsterType:
"""Define a type in dagster. These can be used in the inputs and outputs of solids.
Args:
type_check_fn (Callable[[TypeCheckContext, Any], [Union[bool, TypeCheck]]]):
The function that defines the type check. It takes the value flowing
through the input or output of the solid. If it passes, return either
``True`` or a :py:class:`~dagster.TypeCheck` with ``success`` set to ``True``. If it fails,
return either ``False`` or a :py:class:`~dagster.TypeCheck` with ``success`` set to ``False``.
The first argument must be named ``context`` (or, if unused, ``_``, ``_context``, or ``context_``).
Use ``required_resource_keys`` for access to resources.
key (Optional[str]): The unique key used to identify types programmatically.
The key property always has a value. If you omit the ``key`` argument
to the init function, it instead receives the value of ``name``. If
neither ``key`` nor ``name`` is provided, a ``CheckError`` is thrown.
In the case of a generic type such as ``List`` or ``Optional``, the key is
generated programmatically based on the type parameters.
For most use cases, name should be set and the key argument should
not be specified.
name (Optional[str]): A unique name given by a user. If ``key`` is ``None``, ``key``
becomes this value. Name is not given in a case where the user does
not specify a unique name for this type, such as a generic class.
description (Optional[str]): A markdown-formatted string, displayed in tooling.
loader (Optional[DagsterTypeLoader]): An instance of a class that
inherits from :py:class:`~dagster.DagsterTypeLoader` and can map config data to a value of
this type. Specify this argument if you will need to shim values of this type using the
config machinery. As a rule, you should use the
:py:func:`@dagster_type_loader ` decorator to construct
these arguments.
materializer (Optional[DagsterTypeMaterializer]): An instance of a class
that inherits from :py:class:`~dagster.DagsterTypeMaterializer` and can persist values of
this type. As a rule, you should use the
:py:func:`@dagster_type_materializer `
decorator to construct these arguments.
serialization_strategy (Optional[SerializationStrategy]): An instance of a class that
inherits from :py:class:`~dagster.SerializationStrategy`. The default strategy for serializing
this value when automatically persisting it between execution steps. You should set
this value if the ordinary serialization machinery (e.g., pickle) will not be adequate
for this type.
auto_plugins (Optional[List[Type[TypeStoragePlugin]]]): If types must be serialized differently
depending on the storage being used for intermediates, they should specify this
argument. In these cases the serialization_strategy argument is not sufficient because
serialization requires specialized API calls, e.g. to call an S3 API directly instead
of using a generic file object. See ``dagster_pyspark.DataFrame`` for an example.
required_resource_keys (Optional[Set[str]]): Resource keys required by the ``type_check_fn``.
is_builtin (bool): Defaults to False. This is used by tools to display or
filter built-in types (such as :py:class:`~dagster.String`, :py:class:`~dagster.Int`) to visually distinguish
them from user-defined types. Meant for internal use.
kind (DagsterTypeKind): Defaults to None. This is used to determine the kind of runtime type
for InputDefinition and OutputDefinition type checking.
"""
def __init__(
self,
type_check_fn,
key=None,
name=None,
is_builtin=False,
description=None,
loader=None,
materializer=None,
serialization_strategy=None,
auto_plugins=None,
required_resource_keys=None,
kind=DagsterTypeKind.REGULAR,
):
check.opt_str_param(key, "key")
check.opt_str_param(name, "name")
check.invariant(not (name is None and key is None), "Must set key or name")
if name is None:
check.param_invariant(
bool(key), "key", "If name is not provided, must provide key.",
)
self.key, self._name = key, None
elif key is None:
check.param_invariant(
bool(name), "name", "If key is not provided, must provide name.",
)
self.key, self._name = name, name
else:
check.invariant(key and name)
self.key, self._name = key, name
self.description = check.opt_str_param(description, "description")
self.loader = check.opt_inst_param(loader, "loader", DagsterTypeLoader)
self.materializer = check.opt_inst_param(
materializer, "materializer", DagsterTypeMaterializer
)
self.serialization_strategy = check.opt_inst_param(
serialization_strategy,
"serialization_strategy",
SerializationStrategy,
PickleSerializationStrategy(),
)
self.required_resource_keys = check.opt_set_param(
required_resource_keys, "required_resource_keys",
)
self._type_check_fn = check.callable_param(type_check_fn, "type_check_fn")
_validate_type_check_fn(self._type_check_fn, self._name)
auto_plugins = check.opt_list_param(auto_plugins, "auto_plugins", of_type=type)
check.param_invariant(
all(
issubclass(auto_plugin_type, TypeStoragePlugin) for auto_plugin_type in auto_plugins
),
"auto_plugins",
)
self.auto_plugins = auto_plugins
self.is_builtin = check.bool_param(is_builtin, "is_builtin")
check.invariant(
self.display_name is not None,
"All types must have a valid display name, got None for key {}".format(key),
)
self.kind = check.inst_param(kind, "kind", DagsterTypeKind)
def type_check(self, context, value):
retval = self._type_check_fn(context, value)
if not isinstance(retval, (bool, TypeCheck)):
raise DagsterInvariantViolationError(
(
"You have returned {retval} of type {retval_type} from the type "
'check function of type "{type_key}". Return value must be an instance '
"of TypeCheck or a bool."
).format(retval=repr(retval), retval_type=type(retval), type_key=self.key)
)
return TypeCheck(success=retval) if isinstance(retval, bool) else retval
def __eq__(self, other):
return isinstance(other, DagsterType) and self.key == other.key
def __ne__(self, other):
return not self.__eq__(other)
@staticmethod
def from_builtin_enum(builtin_enum):
check.invariant(BuiltinEnum.contains(builtin_enum), "must be member of BuiltinEnum")
return _RUNTIME_MAP[builtin_enum]
@property
def display_name(self):
"""Asserted in __init__ to be not None, overridden in many subclasses"""
return self._name
@property
def unique_name(self):
"""The unique name of this type. Can be None if the type is not unique, such as container types"""
check.invariant(
self._name is not None,
"unique_name requested but is None for type {}".format(self.display_name),
)
return self._name
@property
def has_unique_name(self):
return self._name is not None
@property
def inner_types(self):
return []
@property
def input_hydration_schema_key(self):
rename_warning("loader_schema_key", "input_hydration_schema_key", "0.10.0")
return self.loader_schema_key
@property
def loader_schema_key(self):
return self.loader.schema_type.key if self.loader else None
@property
def output_materialization_schema_key(self):
rename_warning("materializer_schema_key", "output_materialization_schema_key", "0.10.0")
return self.materializer_schema_key
@property
def materializer_schema_key(self):
return self.materializer.schema_type.key if self.materializer else None
@property
def type_param_keys(self):
return []
@property
def is_nothing(self):
return self.kind == DagsterTypeKind.NOTHING
@property
def supports_fan_in(self):
return False
def get_inner_type_for_fan_in(self):
check.failed(
"DagsterType {name} does not support fan-in, should have checked supports_fan_in before calling getter.".format(
name=self.display_name
)
)
def _validate_type_check_fn(fn, name):
from dagster.seven import get_args
args = get_args(fn)
# py2 doesn't filter out self
if len(args) >= 1 and args[0] == "self":
args = args[1:]
if len(args) == 2:
possible_names = {
"_",
"context",
"_context",
"context_",
}
if args[0] not in possible_names:
raise DagsterInvalidDefinitionError(
'type_check function on type "{name}" must have first '
'argument named "context" (or _, _context, context_).'.format(name=name,)
)
return True
raise DagsterInvalidDefinitionError(
'type_check_fn argument on type "{name}" must take 2 arguments, '
"received {count}.".format(name=name, count=len(args))
)
class BuiltinScalarDagsterType(DagsterType):
def __init__(self, name, type_check_fn, *args, **kwargs):
super(BuiltinScalarDagsterType, self).__init__(
key=name,
name=name,
kind=DagsterTypeKind.SCALAR,
type_check_fn=type_check_fn,
is_builtin=True,
*args,
**kwargs,
)
def type_check_fn(self, _context, value):
return self.type_check_scalar_value(value)
@abstractmethod
def type_check_scalar_value(self, _value):
raise NotImplementedError()
class _Int(BuiltinScalarDagsterType):
def __init__(self):
super(_Int, self).__init__(
name="Int",
loader=BuiltinSchemas.INT_INPUT,
materializer=BuiltinSchemas.INT_OUTPUT,
type_check_fn=self.type_check_fn,
)
def type_check_scalar_value(self, value):
return _fail_if_not_of_type(value, six.integer_types, "int")
def _typemismatch_error_str(value, expected_type_desc):
return 'Value "{value}" of python type "{python_type}" must be a {type_desc}.'.format(
value=value, python_type=type(value).__name__, type_desc=expected_type_desc
)
def _fail_if_not_of_type(value, value_type, value_type_desc):
if not isinstance(value, value_type):
return TypeCheck(success=False, description=_typemismatch_error_str(value, value_type_desc))
return TypeCheck(success=True)
class _String(BuiltinScalarDagsterType):
def __init__(self):
super(_String, self).__init__(
name="String",
loader=BuiltinSchemas.STRING_INPUT,
materializer=BuiltinSchemas.STRING_OUTPUT,
type_check_fn=self.type_check_fn,
)
def type_check_scalar_value(self, value):
return _fail_if_not_of_type(value, str, "string")
class _Float(BuiltinScalarDagsterType):
def __init__(self):
super(_Float, self).__init__(
name="Float",
loader=BuiltinSchemas.FLOAT_INPUT,
materializer=BuiltinSchemas.FLOAT_OUTPUT,
type_check_fn=self.type_check_fn,
)
def type_check_scalar_value(self, value):
return _fail_if_not_of_type(value, float, "float")
class _Bool(BuiltinScalarDagsterType):
def __init__(self):
super(_Bool, self).__init__(
name="Bool",
loader=BuiltinSchemas.BOOL_INPUT,
materializer=BuiltinSchemas.BOOL_OUTPUT,
type_check_fn=self.type_check_fn,
)
def type_check_scalar_value(self, value):
return _fail_if_not_of_type(value, bool, "bool")
class Anyish(DagsterType):
def __init__(
self,
key,
name,
loader=None,
materializer=None,
serialization_strategy=None,
is_builtin=False,
description=None,
auto_plugins=None,
):
super(Anyish, self).__init__(
key=key,
name=name,
kind=DagsterTypeKind.ANY,
loader=loader,
materializer=materializer,
serialization_strategy=serialization_strategy,
is_builtin=is_builtin,
type_check_fn=self.type_check_method,
description=description,
auto_plugins=auto_plugins,
)
def type_check_method(self, _context, _value):
return TypeCheck(success=True)
@property
def supports_fan_in(self):
return True
def get_inner_type_for_fan_in(self):
# Anyish all the way down
return self
class _Any(Anyish):
def __init__(self):
super(_Any, self).__init__(
key="Any",
name="Any",
loader=BuiltinSchemas.ANY_INPUT,
materializer=BuiltinSchemas.ANY_OUTPUT,
is_builtin=True,
)
def create_any_type(
name,
loader=None,
materializer=None,
serialization_strategy=None,
description=None,
auto_plugins=None,
):
return Anyish(
key=name,
name=name,
description=description,
loader=loader,
materializer=materializer,
serialization_strategy=serialization_strategy,
auto_plugins=auto_plugins,
)
class _Nothing(DagsterType):
def __init__(self):
super(_Nothing, self).__init__(
key="Nothing",
name="Nothing",
kind=DagsterTypeKind.NOTHING,
loader=None,
materializer=None,
type_check_fn=self.type_check_method,
is_builtin=True,
)
def type_check_method(self, _context, value):
if value is not None:
return TypeCheck(
success=False,
description="Value must be None, got a {value_type}".format(value_type=type(value)),
)
return TypeCheck(success=True)
@property
def supports_fan_in(self):
return True
def get_inner_type_for_fan_in(self):
return self
class PythonObjectDagsterType(DagsterType):
"""Define a type in dagster whose typecheck is an isinstance check.
Specifically, the type can either be a single python type (e.g. int),
or a tuple of types (e.g. (int, float)) which is treated as a union.
Examples:
.. code-block:: python
ntype = PythonObjectDagsterType(python_type=int)
assert ntype.name == 'int'
assert_success(ntype, 1)
assert_failure(ntype, 'a')
.. code-block:: python
ntype = PythonObjectDagsterType(python_type=(int, float))
assert ntype.name == 'Union[int, float]'
assert_success(ntype, 1)
assert_success(ntype, 1.5)
assert_failure(ntype, 'a')
Args:
python_type (Union[Type, Tuple[Type, ...]]): The dagster typecheck function calls isinstance on
this type.
name (Optional[str]): Name the type. Defaults to the name of ``python_type``.
key (Optional[str]): Key of the type. Defaults to name.
description (Optional[str]): A markdown-formatted string, displayed in tooling.
loader (Optional[DagsterTypeLoader]): An instance of a class that
inherits from :py:class:`~dagster.DagsterTypeLoader` and can map config data to a value of
this type. Specify this argument if you will need to shim values of this type using the
config machinery. As a rule, you should use the
:py:func:`@dagster_type_loader ` decorator to construct
these arguments.
materializer (Optional[DagsterTypeMaterializer]): An instance of a class
that inherits from :py:class:`~dagster.DagsterTypeMaterializer` and can persist values of
this type. As a rule, you should use the
:py:func:`@dagster_type_materializer `
decorator to construct these arguments.
serialization_strategy (Optional[SerializationStrategy]): An instance of a class that
inherits from :py:class:`SerializationStrategy`. The default strategy for serializing
this value when automatically persisting it between execution steps. You should set
this value if the ordinary serialization machinery (e.g., pickle) will not be adequate
for this type.
auto_plugins (Optional[List[Type[TypeStoragePlugin]]]): If types must be serialized differently
depending on the storage being used for intermediates, they should specify this
argument. In these cases the serialization_strategy argument is not sufficient because
serialization requires specialized API calls, e.g. to call an S3 API directly instead
of using a generic file object. See ``dagster_pyspark.DataFrame`` for an example.
"""
def __init__(self, python_type, key=None, name=None, **kwargs):
if isinstance(python_type, tuple):
self.python_type = check.tuple_param(
- python_type, "python_type", of_type=tuple(check.type_types for item in python_type)
+ python_type, "python_type", of_type=tuple(type for item in python_type)
)
self.type_str = "Union[{}]".format(
", ".join(python_type.__name__ for python_type in python_type)
)
else:
self.python_type = check.type_param(python_type, "python_type")
self.type_str = python_type.__name__
name = check.opt_str_param(name, "name", self.type_str)
key = check.opt_str_param(key, "key", name)
super(PythonObjectDagsterType, self).__init__(
key=key, name=name, type_check_fn=self.type_check_method, **kwargs
)
def type_check_method(self, _context, value):
if not isinstance(value, self.python_type):
return TypeCheck(
success=False,
description=(
"Value of type {value_type} failed type check for Dagster type {dagster_type}, "
"expected value to be of Python type {expected_type}."
).format(
value_type=type(value), dagster_type=self._name, expected_type=self.type_str,
),
)
return TypeCheck(success=True)
class NoneableInputSchema(DagsterTypeLoader):
def __init__(self, inner_dagster_type):
self._inner_dagster_type = check.inst_param(
inner_dagster_type, "inner_dagster_type", DagsterType
)
check.param_invariant(inner_dagster_type.loader, "inner_dagster_type")
self._schema_type = ConfigNoneable(inner_dagster_type.loader.schema_type)
@property
def schema_type(self):
return self._schema_type
def construct_from_config_value(self, context, config_value):
if config_value is None:
return None
return self._inner_dagster_type.loader.construct_from_config_value(context, config_value)
def _create_nullable_input_schema(inner_type):
if not inner_type.loader:
return None
return NoneableInputSchema(inner_type)
class OptionalType(DagsterType):
def __init__(self, inner_type):
inner_type = resolve_dagster_type(inner_type)
if inner_type is Nothing:
raise DagsterInvalidDefinitionError(
"Type Nothing can not be wrapped in List or Optional"
)
key = "Optional." + inner_type.key
self.inner_type = inner_type
super(OptionalType, self).__init__(
key=key,
name=None,
kind=DagsterTypeKind.NULLABLE,
type_check_fn=self.type_check_method,
loader=_create_nullable_input_schema(inner_type),
)
@property
def display_name(self):
return self.inner_type.display_name + "?"
def type_check_method(self, context, value):
return (
TypeCheck(success=True) if value is None else self.inner_type.type_check(context, value)
)
@property
def inner_types(self):
return [self.inner_type] + self.inner_type.inner_types
@property
def type_param_keys(self):
return [self.inner_type.key]
@property
def supports_fan_in(self):
return self.inner_type.supports_fan_in
def get_inner_type_for_fan_in(self):
return self.inner_type.get_inner_type_for_fan_in()
class ListInputSchema(DagsterTypeLoader):
def __init__(self, inner_dagster_type):
self._inner_dagster_type = check.inst_param(
inner_dagster_type, "inner_dagster_type", DagsterType
)
check.param_invariant(inner_dagster_type.loader, "inner_dagster_type")
self._schema_type = Array(inner_dagster_type.loader.schema_type)
@property
def schema_type(self):
return self._schema_type
def construct_from_config_value(self, context, config_value):
convert_item = partial(self._inner_dagster_type.loader.construct_from_config_value, context)
return list(map(convert_item, config_value))
def _create_list_input_schema(inner_type):
if not inner_type.loader:
return None
return ListInputSchema(inner_type)
class ListType(DagsterType):
def __init__(self, inner_type):
key = "List." + inner_type.key
self.inner_type = inner_type
super(ListType, self).__init__(
key=key,
name=None,
kind=DagsterTypeKind.LIST,
type_check_fn=self.type_check_method,
loader=_create_list_input_schema(inner_type),
)
@property
def display_name(self):
return "[" + self.inner_type.display_name + "]"
def type_check_method(self, context, value):
value_check = _fail_if_not_of_type(value, list, "list")
if not value_check.success:
return value_check
for item in value:
item_check = self.inner_type.type_check(context, item)
if not item_check.success:
return item_check
return TypeCheck(success=True)
@property
def inner_types(self):
return [self.inner_type] + self.inner_type.inner_types
@property
def type_param_keys(self):
return [self.inner_type.key]
@property
def supports_fan_in(self):
return True
def get_inner_type_for_fan_in(self):
return self.inner_type
class DagsterListApi:
def __getitem__(self, inner_type):
check.not_none_param(inner_type, "inner_type")
return _List(resolve_dagster_type(inner_type))
def __call__(self, inner_type):
check.not_none_param(inner_type, "inner_type")
return _List(inner_type)
List = DagsterListApi()
def _List(inner_type):
check.inst_param(inner_type, "inner_type", DagsterType)
if inner_type is Nothing:
raise DagsterInvalidDefinitionError("Type Nothing can not be wrapped in List or Optional")
return ListType(inner_type)
class Stringish(DagsterType):
def __init__(self, key=None, name=None, **kwargs):
name = check.opt_str_param(name, "name", type(self).__name__)
key = check.opt_str_param(key, "key", name)
super(Stringish, self).__init__(
key=key,
name=name,
kind=DagsterTypeKind.SCALAR,
type_check_fn=self.type_check_method,
loader=BuiltinSchemas.STRING_INPUT,
materializer=BuiltinSchemas.STRING_OUTPUT,
**kwargs,
)
def type_check_method(self, _context, value):
return _fail_if_not_of_type(value, str, "string")
def create_string_type(name, description=None):
return Stringish(name=name, key=name, description=description)
Any = _Any()
Bool = _Bool()
Float = _Float()
Int = _Int()
String = _String()
Nothing = _Nothing()
_RUNTIME_MAP = {
BuiltinEnum.ANY: Any,
BuiltinEnum.BOOL: Bool,
BuiltinEnum.FLOAT: Float,
BuiltinEnum.INT: Int,
BuiltinEnum.STRING: String,
BuiltinEnum.NOTHING: Nothing,
}
_PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY = {}
"""Python types corresponding to user-defined RunTime types created using @map_to_dagster_type or
as_dagster_type are registered here so that we can remap the Python types to runtime types."""
def make_python_type_usable_as_dagster_type(python_type, dagster_type):
"""
Take any existing python type and map it to a dagster type (generally created with
:py:class:`DagsterType `) This can only be called once
on a given python type.
"""
check.inst_param(dagster_type, "dagster_type", DagsterType)
if (
_PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY.get(python_type, dagster_type)
is not dagster_type
):
# This would be just a great place to insert a short URL pointing to the type system
# documentation into the error message
# https://github.com/dagster-io/dagster/issues/1831
raise DagsterInvalidDefinitionError(
(
"A Dagster type has already been registered for the Python type "
"{python_type}. make_python_type_usable_as_dagster_type can only "
"be called once on a python type as it is registering a 1:1 mapping "
"between that python type and a dagster type."
).format(python_type=python_type)
)
_PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY[python_type] = dagster_type
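# Illustrative sketch (comment only): registering an existing Python class so that annotations
# using it resolve to a Dagster type; `MyClass` and `MyDagsterType` are hypothetical.
#
#   class MyClass:
#       pass
#
#   MyDagsterType = PythonObjectDagsterType(python_type=MyClass, name="MyDagsterType")
#   make_python_type_usable_as_dagster_type(MyClass, MyDagsterType)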
DAGSTER_INVALID_TYPE_ERROR_MESSAGE = (
"Invalid type: dagster_type must be DagsterType, a python scalar, or a python type "
"that has been marked usable as a dagster type via @usable_dagster_type or "
"make_python_type_usable_as_dagster_type: got {dagster_type}{additional_msg}"
)
def resolve_dagster_type(dagster_type):
# circular dep
from .python_dict import PythonDict, Dict
from .python_set import PythonSet, DagsterSetApi
from .python_tuple import PythonTuple, DagsterTupleApi
from .transform_typing import transform_typing_type
from dagster.config.config_type import ConfigType
from dagster.primitive_mapping import (
remap_python_builtin_for_runtime,
is_supported_runtime_python_builtin,
)
from dagster.utils.typing_api import is_typing_type
check.invariant(
not (isinstance(dagster_type, type) and issubclass(dagster_type, ConfigType)),
"Cannot resolve a config type to a runtime type",
)
check.invariant(
not (isinstance(dagster_type, type) and issubclass(dagster_type, DagsterType)),
"Do not pass runtime type classes. Got {}".format(dagster_type),
)
# First check to see if it part of python's typing library
if is_typing_type(dagster_type):
dagster_type = transform_typing_type(dagster_type)
if isinstance(dagster_type, DagsterType):
return dagster_type
# Test for unhashable objects -- this is if, for instance, someone has passed us an instance of
# a dict where they meant to pass dict or Dict, etc.
try:
hash(dagster_type)
except TypeError:
raise DagsterInvalidDefinitionError(
DAGSTER_INVALID_TYPE_ERROR_MESSAGE.format(
additional_msg=(
", which isn't hashable. Did you pass an instance of a type instead of "
"the type?"
),
dagster_type=str(dagster_type),
)
)
if is_supported_runtime_python_builtin(dagster_type):
return remap_python_builtin_for_runtime(dagster_type)
if dagster_type is None:
return Any
if dagster_type in _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY:
return _PYTHON_TYPE_TO_DAGSTER_TYPE_MAPPING_REGISTRY[dagster_type]
if dagster_type is Dict:
return PythonDict
if isinstance(dagster_type, DagsterTupleApi):
return PythonTuple
if isinstance(dagster_type, DagsterSetApi):
return PythonSet
if isinstance(dagster_type, DagsterListApi):
return List(Any)
if BuiltinEnum.contains(dagster_type):
return DagsterType.from_builtin_enum(dagster_type)
if not isinstance(dagster_type, type):
raise DagsterInvalidDefinitionError(
DAGSTER_INVALID_TYPE_ERROR_MESSAGE.format(
dagster_type=str(dagster_type), additional_msg="."
)
)
raise DagsterInvalidDefinitionError(
"{dagster_type} is not a valid dagster type.".format(dagster_type=dagster_type)
)
ALL_RUNTIME_BUILTINS = list(_RUNTIME_MAP.values())
def construct_dagster_type_dictionary(solid_defs):
type_dict_by_name = {t.unique_name: t for t in ALL_RUNTIME_BUILTINS}
type_dict_by_key = {t.key: t for t in ALL_RUNTIME_BUILTINS}
for solid_def in solid_defs:
for dagster_type in solid_def.all_dagster_types():
# We don't do uniqueness check on key because with classes
# like Array, Noneable, etc, those are ephemeral objects
# and it is perfectly fine to have many of them.
type_dict_by_key[dagster_type.key] = dagster_type
if not dagster_type.has_unique_name:
continue
if dagster_type.unique_name not in type_dict_by_name:
type_dict_by_name[dagster_type.unique_name] = dagster_type
continue
if type_dict_by_name[dagster_type.unique_name] is not dagster_type:
raise DagsterInvalidDefinitionError(
(
'You have created two dagster types with the same name "{type_name}". '
"Dagster types have must have unique names."
).format(type_name=dagster_type.display_name)
)
return type_dict_by_key
class DagsterOptionalApi:
def __getitem__(self, inner_type):
check.not_none_param(inner_type, "inner_type")
return OptionalType(inner_type)
Optional = DagsterOptionalApi()
diff --git a/python_modules/dagster/dagster/grpc/server.py b/python_modules/dagster/dagster/grpc/server.py
index 33f3a005a..7839cb23d 100644
--- a/python_modules/dagster/dagster/grpc/server.py
+++ b/python_modules/dagster/dagster/grpc/server.py
@@ -1,1056 +1,1058 @@
import math
import os
import queue
import sys
+import tempfile
import threading
import time
import uuid
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
+from threading import Event as ThreadingEventType
import grpc
from dagster import check, seven
from dagster.core.code_pointer import CodePointer
from dagster.core.definitions.reconstructable import (
ReconstructableRepository,
repository_def_from_target_def,
)
from dagster.core.host_representation import ExternalPipelineOrigin, ExternalRepositoryOrigin
from dagster.core.host_representation.external_data import external_repository_data_from_def
from dagster.core.instance import DagsterInstance
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.serdes import (
deserialize_json_to_dagster_namedtuple,
serialize_dagster_namedtuple,
whitelist_for_serdes,
)
from dagster.serdes.ipc import (
IPCErrorMessage,
ipc_write_stream,
open_ipc_subprocess,
read_unary_response,
)
from dagster.seven import multiprocessing
from dagster.utils import find_free_port, safe_tempfile_path_unmanaged
from dagster.utils.error import serializable_error_info_from_exc_info
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from .__generated__ import api_pb2
from .__generated__.api_pb2_grpc import DagsterApiServicer, add_DagsterApiServicer_to_server
from .impl import (
RunInSubprocessComplete,
StartRunInSubprocessSuccessful,
get_external_execution_plan_snapshot,
get_external_pipeline_subset_result,
get_external_schedule_execution,
get_external_sensor_execution,
get_partition_config,
get_partition_names,
get_partition_set_execution_param_data,
get_partition_tags,
start_run_in_subprocess,
)
from .types import (
CanCancelExecutionRequest,
CanCancelExecutionResult,
CancelExecutionRequest,
CancelExecutionResult,
ExecuteExternalPipelineArgs,
ExecutionPlanSnapshotArgs,
ExternalScheduleExecutionArgs,
GetCurrentImageResult,
ListRepositoriesResponse,
LoadableRepositorySymbol,
PartitionArgs,
PartitionNamesArgs,
PartitionSetExecutionParamArgs,
PipelineSubsetSnapshotArgs,
SensorExecutionArgs,
ShutdownServerResult,
StartRunResult,
)
from .utils import get_loadable_targets
EVENT_QUEUE_POLL_INTERVAL = 0.1
CLEANUP_TICK = 0.5
STREAMING_EXTERNAL_REPOSITORY_CHUNK_SIZE = 4000000
class CouldNotBindGrpcServerToAddress(Exception):
pass
class LazyRepositorySymbolsAndCodePointers:
"""Enables lazily loading user code at RPC-time so that it doesn't interrupt startup and
we can gracefully handle user code errors."""
def __init__(self, loadable_target_origin):
self._loadable_target_origin = loadable_target_origin
self._loadable_repository_symbols = None
self._code_pointers_by_repo_name = None
def load(self):
self._loadable_repository_symbols = load_loadable_repository_symbols(
self._loadable_target_origin
)
self._code_pointers_by_repo_name = build_code_pointers_by_repo_name(
self._loadable_target_origin, self._loadable_repository_symbols
)
@property
def loadable_repository_symbols(self):
if self._loadable_repository_symbols is None:
self.load()
return self._loadable_repository_symbols
@property
def code_pointers_by_repo_name(self):
if self._code_pointers_by_repo_name is None:
self.load()
return self._code_pointers_by_repo_name
def load_loadable_repository_symbols(loadable_target_origin):
if loadable_target_origin:
loadable_targets = get_loadable_targets(
loadable_target_origin.python_file,
loadable_target_origin.module_name,
loadable_target_origin.package_name,
loadable_target_origin.working_directory,
loadable_target_origin.attribute,
)
return [
LoadableRepositorySymbol(
attribute=loadable_target.attribute,
repository_name=repository_def_from_target_def(
loadable_target.target_definition
).name,
)
for loadable_target in loadable_targets
]
else:
return []
def build_code_pointers_by_repo_name(loadable_target_origin, loadable_repository_symbols):
repository_code_pointer_dict = {}
for loadable_repository_symbol in loadable_repository_symbols:
if loadable_target_origin.python_file:
repository_code_pointer_dict[
loadable_repository_symbol.repository_name
] = CodePointer.from_python_file(
loadable_target_origin.python_file,
loadable_repository_symbol.attribute,
loadable_target_origin.working_directory,
)
elif loadable_target_origin.package_name:
repository_code_pointer_dict[
loadable_repository_symbol.repository_name
] = CodePointer.from_python_package(
loadable_target_origin.package_name, loadable_repository_symbol.attribute,
)
else:
repository_code_pointer_dict[
loadable_repository_symbol.repository_name
] = CodePointer.from_module(
loadable_target_origin.module_name, loadable_repository_symbol.attribute,
)
return repository_code_pointer_dict
class DagsterApiServer(DagsterApiServicer):
    # The loadable_target_origin is currently Noneable to support instantiating a server.
    # This helps us test the ping methods, and incrementally migrate each method to
    # the target passed in here instead of passing a target in each call's arguments.
def __init__(
self,
server_termination_event,
loadable_target_origin=None,
heartbeat=False,
heartbeat_timeout=30,
lazy_load_user_code=False,
fixed_server_id=None,
):
super(DagsterApiServer, self).__init__()
check.bool_param(heartbeat, "heartbeat")
check.int_param(heartbeat_timeout, "heartbeat_timeout")
check.invariant(heartbeat_timeout > 0, "heartbeat_timeout must be greater than 0")
self._server_termination_event = check.inst_param(
- server_termination_event, "server_termination_event", seven.ThreadingEventType
+ server_termination_event, "server_termination_event", ThreadingEventType
)
self._loadable_target_origin = check.opt_inst_param(
loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin
)
# Each server is initialized with a unique UUID. This UUID is used by clients to track when
# servers are replaced and is used for cache invalidation and reloading.
self._server_id = check.opt_str_param(fixed_server_id, "fixed_server_id", str(uuid.uuid4()))
        # The client tells the server to shut down by calling ShutdownServer (or by failing to
        # send a heartbeat), at which point this event is set. The cleanup thread will then set
        # the server termination event once all current executions have finished, which will
        # stop the server.
self._shutdown_once_executions_finish_event = threading.Event()
        # Dict[str, (multiprocessing.Process, InstanceRef)]
self._executions = {}
# Dict[str, multiprocessing.Event]
self._termination_events = {}
self._termination_times = {}
self._execution_lock = threading.Lock()
self._repository_symbols_and_code_pointers = LazyRepositorySymbolsAndCodePointers(
loadable_target_origin
)
if not lazy_load_user_code:
self._repository_symbols_and_code_pointers.load()
self.__last_heartbeat_time = time.time()
if heartbeat:
self.__heartbeat_thread = threading.Thread(
target=self._heartbeat_thread,
args=(heartbeat_timeout,),
name="grpc-server-heartbeat",
)
self.__heartbeat_thread.daemon = True
self.__heartbeat_thread.start()
else:
self.__heartbeat_thread = None
self.__cleanup_thread = threading.Thread(
target=self._cleanup_thread, args=(), name="grpc-server-cleanup"
)
self.__cleanup_thread.daemon = True
self.__cleanup_thread.start()
def cleanup(self):
if self.__heartbeat_thread:
self.__heartbeat_thread.join()
self.__cleanup_thread.join()
def _heartbeat_thread(self, heartbeat_timeout):
while True:
self._shutdown_once_executions_finish_event.wait(heartbeat_timeout)
if self._shutdown_once_executions_finish_event.is_set():
break
if self.__last_heartbeat_time < time.time() - heartbeat_timeout:
self._shutdown_once_executions_finish_event.set()
def _cleanup_thread(self):
while True:
self._server_termination_event.wait(CLEANUP_TICK)
if self._server_termination_event.is_set():
break
self._check_for_orphaned_runs()
def _check_for_orphaned_runs(self):
with self._execution_lock:
runs_to_clear = []
for run_id, (process, instance_ref) in self._executions.items():
if not process.is_alive():
with DagsterInstance.from_ref(instance_ref) as instance:
runs_to_clear.append(run_id)
run = instance.get_run_by_id(run_id)
if not run or run.is_finished:
continue
# the process died in an unexpected manner. inform the system
message = "Pipeline execution process for {run_id} unexpectedly exited.".format(
run_id=run.run_id
)
instance.report_engine_event(message, run, cls=self.__class__)
instance.report_run_failed(run)
for run_id in runs_to_clear:
self._clear_run(run_id)
# Once there are no more running executions after we have received a request to
# shut down, terminate the server
if self._shutdown_once_executions_finish_event.is_set():
if len(self._executions) == 0:
self._server_termination_event.set()
# Assumes execution lock is being held
def _clear_run(self, run_id):
del self._executions[run_id]
del self._termination_events[run_id]
if run_id in self._termination_times:
del self._termination_times[run_id]
def _recon_repository_from_origin(self, external_repository_origin):
check.inst_param(
external_repository_origin, "external_repository_origin", ExternalRepositoryOrigin,
)
return ReconstructableRepository(
self._repository_symbols_and_code_pointers.code_pointers_by_repo_name[
external_repository_origin.repository_name
],
self._get_current_image(),
)
def _recon_pipeline_from_origin(self, external_pipeline_origin):
check.inst_param(
external_pipeline_origin, "external_pipeline_origin", ExternalPipelineOrigin
)
recon_repo = self._recon_repository_from_origin(
external_pipeline_origin.external_repository_origin
)
return recon_repo.get_reconstructable_pipeline(external_pipeline_origin.pipeline_name)
def Ping(self, request, _context):
echo = request.echo
return api_pb2.PingReply(echo=echo)
def StreamingPing(self, request, _context):
sequence_length = request.sequence_length
echo = request.echo
for sequence_number in range(sequence_length):
yield api_pb2.StreamingPingEvent(sequence_number=sequence_number, echo=echo)
def Heartbeat(self, request, _context):
self.__last_heartbeat_time = time.time()
echo = request.echo
return api_pb2.PingReply(echo=echo)
def GetServerId(self, _request, _context):
return api_pb2.GetServerIdReply(server_id=self._server_id)
def ExecutionPlanSnapshot(self, request, _context):
execution_plan_args = deserialize_json_to_dagster_namedtuple(
request.serialized_execution_plan_snapshot_args
)
check.inst_param(execution_plan_args, "execution_plan_args", ExecutionPlanSnapshotArgs)
recon_pipeline = self._recon_pipeline_from_origin(execution_plan_args.pipeline_origin)
execution_plan_snapshot_or_error = get_external_execution_plan_snapshot(
recon_pipeline, execution_plan_args
)
return api_pb2.ExecutionPlanSnapshotReply(
serialized_execution_plan_snapshot=serialize_dagster_namedtuple(
execution_plan_snapshot_or_error
)
)
def ListRepositories(self, request, _context):
try:
response = ListRepositoriesResponse(
self._repository_symbols_and_code_pointers.loadable_repository_symbols,
executable_path=self._loadable_target_origin.executable_path
if self._loadable_target_origin
else None,
repository_code_pointer_dict=(
self._repository_symbols_and_code_pointers.code_pointers_by_repo_name
),
)
except Exception: # pylint: disable=broad-except
response = serializable_error_info_from_exc_info(sys.exc_info())
return api_pb2.ListRepositoriesReply(
serialized_list_repositories_response_or_error=serialize_dagster_namedtuple(response)
)
def ExternalPartitionNames(self, request, _context):
partition_names_args = deserialize_json_to_dagster_namedtuple(
request.serialized_partition_names_args
)
check.inst_param(partition_names_args, "partition_names_args", PartitionNamesArgs)
recon_repo = self._recon_repository_from_origin(partition_names_args.repository_origin)
return api_pb2.ExternalPartitionNamesReply(
serialized_external_partition_names_or_external_partition_execution_error=serialize_dagster_namedtuple(
get_partition_names(recon_repo, partition_names_args.partition_set_name,)
)
)
def ExternalPartitionSetExecutionParams(self, request, _context):
args = deserialize_json_to_dagster_namedtuple(
request.serialized_partition_set_execution_param_args
)
check.inst_param(
args, "args", PartitionSetExecutionParamArgs,
)
recon_repo = self._recon_repository_from_origin(args.repository_origin)
return api_pb2.ExternalPartitionSetExecutionParamsReply(
serialized_external_partition_set_execution_param_data_or_external_partition_execution_error=serialize_dagster_namedtuple(
get_partition_set_execution_param_data(
recon_repo=recon_repo,
partition_set_name=args.partition_set_name,
partition_names=args.partition_names,
)
)
)
def ExternalPartitionConfig(self, request, _context):
args = deserialize_json_to_dagster_namedtuple(request.serialized_partition_args)
check.inst_param(args, "args", PartitionArgs)
recon_repo = self._recon_repository_from_origin(args.repository_origin)
return api_pb2.ExternalPartitionConfigReply(
serialized_external_partition_config_or_external_partition_execution_error=serialize_dagster_namedtuple(
get_partition_config(recon_repo, args.partition_set_name, args.partition_name)
)
)
def ExternalPartitionTags(self, request, _context):
partition_args = deserialize_json_to_dagster_namedtuple(request.serialized_partition_args)
check.inst_param(partition_args, "partition_args", PartitionArgs)
recon_repo = self._recon_repository_from_origin(partition_args.repository_origin)
return api_pb2.ExternalPartitionTagsReply(
serialized_external_partition_tags_or_external_partition_execution_error=serialize_dagster_namedtuple(
get_partition_tags(
recon_repo, partition_args.partition_set_name, partition_args.partition_name
)
)
)
def ExternalPipelineSubsetSnapshot(self, request, _context):
pipeline_subset_snapshot_args = deserialize_json_to_dagster_namedtuple(
request.serialized_pipeline_subset_snapshot_args
)
check.inst_param(
pipeline_subset_snapshot_args,
"pipeline_subset_snapshot_args",
PipelineSubsetSnapshotArgs,
)
return api_pb2.ExternalPipelineSubsetSnapshotReply(
serialized_external_pipeline_subset_result=serialize_dagster_namedtuple(
get_external_pipeline_subset_result(
self._recon_pipeline_from_origin(pipeline_subset_snapshot_args.pipeline_origin),
pipeline_subset_snapshot_args.solid_selection,
)
)
)
def _get_serialized_external_repository_data(self, request):
repository_origin = deserialize_json_to_dagster_namedtuple(
request.serialized_repository_python_origin
)
check.inst_param(repository_origin, "repository_origin", ExternalRepositoryOrigin)
recon_repo = self._recon_repository_from_origin(repository_origin)
return serialize_dagster_namedtuple(
external_repository_data_from_def(recon_repo.get_definition())
)
def ExternalRepository(self, request, _context):
serialized_external_repository_data = self._get_serialized_external_repository_data(request)
return api_pb2.ExternalRepositoryReply(
serialized_external_repository_data=serialized_external_repository_data,
)
def StreamingExternalRepository(self, request, _context):
serialized_external_repository_data = self._get_serialized_external_repository_data(request)
num_chunks = int(
math.ceil(
float(len(serialized_external_repository_data))
/ STREAMING_EXTERNAL_REPOSITORY_CHUNK_SIZE
)
)
for i in range(num_chunks):
start_index = i * STREAMING_EXTERNAL_REPOSITORY_CHUNK_SIZE
end_index = min(
(i + 1) * STREAMING_EXTERNAL_REPOSITORY_CHUNK_SIZE,
len(serialized_external_repository_data),
)
yield api_pb2.StreamingExternalRepositoryEvent(
sequence_number=i,
serialized_external_repository_chunk=serialized_external_repository_data[
start_index:end_index
],
)
def ExternalScheduleExecution(self, request, _context):
args = deserialize_json_to_dagster_namedtuple(
request.serialized_external_schedule_execution_args
)
check.inst_param(
args, "args", ExternalScheduleExecutionArgs,
)
recon_repo = self._recon_repository_from_origin(args.repository_origin)
return api_pb2.ExternalScheduleExecutionReply(
serialized_external_schedule_execution_data_or_external_schedule_execution_error=serialize_dagster_namedtuple(
get_external_schedule_execution(
recon_repo,
args.instance_ref,
args.schedule_name,
args.scheduled_execution_timestamp,
args.scheduled_execution_timezone,
)
)
)
def ExternalSensorExecution(self, request, _context):
args = deserialize_json_to_dagster_namedtuple(
request.serialized_external_sensor_execution_args
)
check.inst_param(args, "args", SensorExecutionArgs)
recon_repo = self._recon_repository_from_origin(args.repository_origin)
return api_pb2.ExternalSensorExecutionReply(
serialized_external_sensor_execution_data_or_external_sensor_execution_error=serialize_dagster_namedtuple(
get_external_sensor_execution(
recon_repo,
args.instance_ref,
args.sensor_name,
args.last_completion_time,
args.last_run_key,
)
)
)
def ShutdownServer(self, request, _context):
try:
self._shutdown_once_executions_finish_event.set()
return api_pb2.ShutdownServerReply(
serialized_shutdown_server_result=serialize_dagster_namedtuple(
ShutdownServerResult(success=True, serializable_error_info=None)
)
)
except: # pylint: disable=bare-except
return api_pb2.ShutdownServerReply(
serialized_shutdown_server_result=serialize_dagster_namedtuple(
ShutdownServerResult(
success=False,
serializable_error_info=serializable_error_info_from_exc_info(
sys.exc_info()
),
)
)
)
def CancelExecution(self, request, _context):
success = False
message = None
serializable_error_info = None
try:
cancel_execution_request = check.inst(
deserialize_json_to_dagster_namedtuple(request.serialized_cancel_execution_request),
CancelExecutionRequest,
)
with self._execution_lock:
if cancel_execution_request.run_id in self._executions:
self._termination_events[cancel_execution_request.run_id].set()
self._termination_times[cancel_execution_request.run_id] = time.time()
success = True
except: # pylint: disable=bare-except
serializable_error_info = serializable_error_info_from_exc_info(sys.exc_info())
return api_pb2.CancelExecutionReply(
serialized_cancel_execution_result=serialize_dagster_namedtuple(
CancelExecutionResult(
success=success,
message=message,
serializable_error_info=serializable_error_info,
)
)
)
def CanCancelExecution(self, request, _context):
can_cancel_execution_request = check.inst(
deserialize_json_to_dagster_namedtuple(request.serialized_can_cancel_execution_request),
CanCancelExecutionRequest,
)
with self._execution_lock:
run_id = can_cancel_execution_request.run_id
can_cancel = (
run_id in self._executions and not self._termination_events[run_id].is_set()
)
return api_pb2.CanCancelExecutionReply(
serialized_can_cancel_execution_result=serialize_dagster_namedtuple(
CanCancelExecutionResult(can_cancel=can_cancel)
)
)
def StartRun(self, request, _context):
if self._shutdown_once_executions_finish_event.is_set():
return api_pb2.StartRunReply(
serialized_start_run_result=serialize_dagster_namedtuple(
StartRunResult(
success=False,
message="Tried to start a run on a server after telling it to shut down",
serializable_error_info=None,
)
)
)
try:
execute_run_args = check.inst(
deserialize_json_to_dagster_namedtuple(request.serialized_execute_run_args),
ExecuteExternalPipelineArgs,
)
run_id = execute_run_args.pipeline_run_id
recon_pipeline = self._recon_pipeline_from_origin(execute_run_args.pipeline_origin)
except: # pylint: disable=bare-except
return api_pb2.StartRunReply(
serialized_start_run_result=serialize_dagster_namedtuple(
StartRunResult(
success=False,
message=None,
serializable_error_info=serializable_error_info_from_exc_info(
sys.exc_info()
),
)
)
)
event_queue = multiprocessing.Queue()
termination_event = multiprocessing.Event()
execution_process = multiprocessing.Process(
target=start_run_in_subprocess,
args=[
request.serialized_execute_run_args,
recon_pipeline,
event_queue,
termination_event,
],
)
with self._execution_lock:
execution_process.start()
self._executions[run_id] = (
execution_process,
execute_run_args.instance_ref,
)
self._termination_events[run_id] = termination_event
success = None
message = None
serializable_error_info = None
while success is None:
time.sleep(EVENT_QUEUE_POLL_INTERVAL)
# We use `get_nowait()` instead of `get()` so that we can handle the case where the
# execution process has died unexpectedly -- `get()` would hang forever in that case
try:
dagster_event_or_ipc_error_message_or_done = event_queue.get_nowait()
except queue.Empty:
if not execution_process.is_alive():
# subprocess died unexpectedly
success = False
message = (
"GRPC server: Subprocess for {run_id} terminated unexpectedly with "
"exit code {exit_code}".format(
run_id=run_id, exit_code=execution_process.exitcode,
)
)
serializable_error_info = serializable_error_info_from_exc_info(sys.exc_info())
else:
if isinstance(
dagster_event_or_ipc_error_message_or_done, StartRunInSubprocessSuccessful
):
success = True
elif isinstance(
dagster_event_or_ipc_error_message_or_done, RunInSubprocessComplete
):
continue
if isinstance(dagster_event_or_ipc_error_message_or_done, IPCErrorMessage):
success = False
message = dagster_event_or_ipc_error_message_or_done.message
serializable_error_info = (
dagster_event_or_ipc_error_message_or_done.serializable_error_info
)
# Ensure that if the run failed, we remove it from the executions map before
# returning so that CanCancel will never return True
if not success:
with self._execution_lock:
self._clear_run(run_id)
return api_pb2.StartRunReply(
serialized_start_run_result=serialize_dagster_namedtuple(
StartRunResult(
success=success,
message=message,
serializable_error_info=serializable_error_info,
)
)
)
def _get_current_image(self):
return os.getenv("DAGSTER_CURRENT_IMAGE")
def GetCurrentImage(self, request, _context):
return api_pb2.GetCurrentImageReply(
serialized_current_image=serialize_dagster_namedtuple(
GetCurrentImageResult(
current_image=self._get_current_image(), serializable_error_info=None
)
)
)
@whitelist_for_serdes
class GrpcServerStartedEvent(namedtuple("GrpcServerStartedEvent", "")):
pass
@whitelist_for_serdes
class GrpcServerFailedToBindEvent(namedtuple("GrpcServerFailedToBindEvent", "")):
pass
def server_termination_target(termination_event, server):
termination_event.wait()
# We could make this grace period configurable if we set it in the ShutdownServer handler
server.stop(grace=5)
class DagsterGrpcServer:
def __init__(
self,
host="localhost",
port=None,
socket=None,
max_workers=1,
loadable_target_origin=None,
heartbeat=False,
heartbeat_timeout=30,
lazy_load_user_code=False,
ipc_output_file=None,
fixed_server_id=None,
):
check.opt_str_param(host, "host")
check.opt_int_param(port, "port")
check.opt_str_param(socket, "socket")
check.int_param(max_workers, "max_workers")
check.opt_inst_param(loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin)
check.invariant(
port is not None if seven.IS_WINDOWS else True,
"You must pass a valid `port` on Windows: `socket` not supported.",
)
check.invariant(
(port or socket) and not (port and socket),
"You must pass one and only one of `port` or `socket`.",
)
check.invariant(
host is not None if port else True, "Must provide a host when serving on a port",
)
check.bool_param(heartbeat, "heartbeat")
check.int_param(heartbeat_timeout, "heartbeat_timeout")
self._ipc_output_file = check.opt_str_param(ipc_output_file, "ipc_output_file")
check.opt_str_param(fixed_server_id, "fixed_server_id")
check.invariant(heartbeat_timeout > 0, "heartbeat_timeout must be greater than 0")
check.invariant(
max_workers > 1 if heartbeat else True,
"max_workers must be greater than 1 if heartbeat is True",
)
self.server = grpc.server(ThreadPoolExecutor(max_workers=max_workers))
self._server_termination_event = threading.Event()
self._api_servicer = DagsterApiServer(
server_termination_event=self._server_termination_event,
loadable_target_origin=loadable_target_origin,
heartbeat=heartbeat,
heartbeat_timeout=heartbeat_timeout,
lazy_load_user_code=lazy_load_user_code,
fixed_server_id=fixed_server_id,
)
# Create a health check servicer
self._health_servicer = health.HealthServicer()
health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server)
add_DagsterApiServicer_to_server(self._api_servicer, self.server)
if port:
server_address = host + ":" + str(port)
else:
server_address = "unix:" + os.path.abspath(socket)
# grpc.Server.add_insecure_port returns:
# - 0 on failure
# - port number when a port is successfully bound
# - 1 when a UDS is successfully bound
res = self.server.add_insecure_port(server_address)
if socket and res != 1:
if self._ipc_output_file:
with ipc_write_stream(self._ipc_output_file) as ipc_stream:
ipc_stream.send(GrpcServerFailedToBindEvent())
raise CouldNotBindGrpcServerToAddress(socket)
if port and res != port:
if self._ipc_output_file:
with ipc_write_stream(self._ipc_output_file) as ipc_stream:
ipc_stream.send(GrpcServerFailedToBindEvent())
raise CouldNotBindGrpcServerToAddress(port)
def serve(self):
# Unfortunately it looks like ports bind late (here) and so this can fail with an error
# from C++ like:
#
# E0625 08:46:56.180112000 4697443776 server_chttp2.cc:40]
# {"created":"@1593089216.180085000","description":"Only 1 addresses added out of total
# 2 resolved","file":"src/core/ext/transport/chttp2/server/chttp2_server.cc",
# "file_line":406,"referenced_errors":[{"created":"@1593089216.180083000","description":
# "Unable to configure socket","fd":6,"file":
# "src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":217,
# "referenced_errors":[{"created":"@1593089216.180079000",
# "description":"Address already in use","errno":48,"file":
# "src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":190,"os_error":
# "Address already in use","syscall":"bind"}]}]}
#
# This is printed to stdout and there is no return value from server.start or exception
# raised in Python that we can use to handle this. The standard recipes for hijacking C
# stdout (so we could inspect this output and respond accordingly), e.g.
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/, don't seem
# to work (at least on Mac OS X) against grpc, and in any case would involve a huge
# cross-version and cross-platform maintenance burden. We have an issue open against grpc,
# https://github.com/grpc/grpc/issues/23315, and our own tracking issue at
self.server.start()
# Note: currently this is hardcoded as serving, since both services are cohosted
# pylint: disable=no-member
self._health_servicer.set("DagsterApi", health_pb2.HealthCheckResponse.SERVING)
if self._ipc_output_file:
with ipc_write_stream(self._ipc_output_file) as ipc_stream:
ipc_stream.send(GrpcServerStartedEvent())
server_termination_thread = threading.Thread(
target=server_termination_target,
args=[self._server_termination_event, self.server],
name="grpc-server-termination",
)
server_termination_thread.daemon = True
server_termination_thread.start()
self.server.wait_for_termination()
server_termination_thread.join()
self._api_servicer.cleanup()
class CouldNotStartServerProcess(Exception):
def __init__(self, port=None, socket=None):
super(CouldNotStartServerProcess, self).__init__(
"Could not start server with "
+ (
"port {port}".format(port=port)
if port is not None
else "socket {socket}".format(socket=socket)
)
)
def wait_for_grpc_server(server_process, ipc_output_file, timeout=15):
event = read_unary_response(ipc_output_file, timeout=timeout, ipc_process=server_process)
if isinstance(event, GrpcServerFailedToBindEvent):
raise CouldNotBindGrpcServerToAddress()
elif isinstance(event, GrpcServerStartedEvent):
return True
else:
raise Exception(
"Received unexpected IPC event from gRPC Server: {event}".format(event=event)
)
def open_server_process(
port,
socket,
loadable_target_origin=None,
max_workers=1,
heartbeat=False,
heartbeat_timeout=30,
lazy_load_user_code=False,
fixed_server_id=None,
):
check.invariant((port or socket) and not (port and socket), "Set only port or socket")
check.opt_inst_param(loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin)
check.int_param(max_workers, "max_workers")
from dagster.core.test_utils import get_mocked_system_timezone
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
output_file = os.path.join(
temp_dir, "grpc-server-startup-{uuid}".format(uuid=uuid.uuid4().hex)
)
mocked_system_timezone = get_mocked_system_timezone()
subprocess_args = (
[
loadable_target_origin.executable_path
if loadable_target_origin and loadable_target_origin.executable_path
else sys.executable,
"-m",
"dagster.grpc",
]
+ (["--port", str(port)] if port else [])
+ (["--socket", socket] if socket else [])
+ ["-n", str(max_workers)]
+ (["--heartbeat"] if heartbeat else [])
+ (["--heartbeat-timeout", str(heartbeat_timeout)] if heartbeat_timeout else [])
+ (["--lazy-load-user-code"] if lazy_load_user_code else [])
+ (["--ipc-output-file", output_file])
+ (["--fixed-server-id", fixed_server_id] if fixed_server_id else [])
+ (
["--override-system-timezone", mocked_system_timezone]
if mocked_system_timezone
else []
)
)
if loadable_target_origin:
subprocess_args += loadable_target_origin.get_cli_args()
server_process = open_ipc_subprocess(subprocess_args)
try:
wait_for_grpc_server(server_process, output_file)
except:
if server_process.poll() is None:
server_process.terminate()
raise
return server_process
def open_server_process_on_dynamic_port(
max_retries=10,
loadable_target_origin=None,
max_workers=1,
heartbeat=False,
heartbeat_timeout=30,
lazy_load_user_code=False,
fixed_server_id=None,
):
server_process = None
retries = 0
while server_process is None and retries < max_retries:
port = find_free_port()
try:
server_process = open_server_process(
port=port,
socket=None,
loadable_target_origin=loadable_target_origin,
max_workers=max_workers,
heartbeat=heartbeat,
heartbeat_timeout=heartbeat_timeout,
lazy_load_user_code=lazy_load_user_code,
fixed_server_id=fixed_server_id,
)
except CouldNotBindGrpcServerToAddress:
pass
retries += 1
return server_process, port
def cleanup_server_process(server_process, timeout=3):
start_time = time.time()
while server_process.poll() is None and (time.time() - start_time) < timeout:
time.sleep(0.05)
if server_process.poll() is None:
server_process.terminate()
server_process.wait()
class GrpcServerProcess:
def __init__(
self,
loadable_target_origin=None,
force_port=False,
max_retries=10,
max_workers=1,
heartbeat=False,
heartbeat_timeout=30,
lazy_load_user_code=False,
fixed_server_id=None,
):
self.port = None
self.socket = None
self.server_process = None
check.opt_inst_param(loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin)
check.bool_param(force_port, "force_port")
check.int_param(max_retries, "max_retries")
check.int_param(max_workers, "max_workers")
check.bool_param(heartbeat, "heartbeat")
check.int_param(heartbeat_timeout, "heartbeat_timeout")
check.invariant(heartbeat_timeout > 0, "heartbeat_timeout must be greater than 0")
check.bool_param(lazy_load_user_code, "lazy_load_user_code")
check.opt_str_param(fixed_server_id, "fixed_server_id")
check.invariant(
max_workers > 1 if heartbeat else True,
"max_workers must be greater than 1 if heartbeat is True",
)
if seven.IS_WINDOWS or force_port:
self.server_process, self.port = open_server_process_on_dynamic_port(
max_retries=max_retries,
loadable_target_origin=loadable_target_origin,
max_workers=max_workers,
heartbeat=heartbeat,
heartbeat_timeout=heartbeat_timeout,
lazy_load_user_code=lazy_load_user_code,
fixed_server_id=fixed_server_id,
)
else:
self.socket = safe_tempfile_path_unmanaged()
self.server_process = open_server_process(
port=None,
socket=self.socket,
loadable_target_origin=loadable_target_origin,
max_workers=max_workers,
heartbeat=heartbeat,
heartbeat_timeout=heartbeat_timeout,
lazy_load_user_code=lazy_load_user_code,
fixed_server_id=fixed_server_id,
)
if self.server_process is None:
raise CouldNotStartServerProcess(port=self.port, socket=self.socket)
def wait(self, timeout=30):
if self.server_process.poll() is None:
seven.wait_for_process(self.server_process, timeout=timeout)
def create_ephemeral_client(self):
from dagster.grpc.client import EphemeralDagsterGrpcClient
return EphemeralDagsterGrpcClient(
port=self.port, socket=self.socket, server_process=self.server_process
)
diff --git a/python_modules/dagster/dagster/serdes/__init__.py b/python_modules/dagster/dagster/serdes/__init__.py
index fc189ebca..39c559bc3 100644
--- a/python_modules/dagster/dagster/serdes/__init__.py
+++ b/python_modules/dagster/dagster/serdes/__init__.py
@@ -1,448 +1,448 @@
"""
Serialization & deserialization for Dagster objects.
Why have custom serialization?
* Default json serialization doesn't work well on namedtuples, which we use extensively to create
immutable value types. Namedtuples serialize like tuples as flat lists.
* Explicit whitelisting should help ensure we are only persisting or communicating across a
serialization boundary the types we expect to.
Why not pickle?
* This isn't meant to replace pickle in the conditions where pickle is reasonable to use
(in memory, not human readable, etc.); it just handles the json case effectively.
"""
import hashlib
import importlib
import sys
from abc import ABC, abstractmethod, abstractproperty
from collections import namedtuple
from enum import Enum
import six
import yaml
from dagster import check, seven
from dagster.utils import compose
_WHITELIST_MAP = {
"types": {"tuple": {}, "enum": {}},
"persistence": {},
}
def create_snapshot_id(snapshot):
json_rep = serialize_dagster_namedtuple(snapshot)
m = hashlib.sha1() # so that hexdigest is 40, not 64 bytes
m.update(json_rep.encode())
return m.hexdigest()
def serialize_pp(value):
return serialize_dagster_namedtuple(value, indent=2, separators=(",", ": "))
def register_serdes_tuple_fallbacks(fallback_map):
for class_name, klass in fallback_map.items():
_WHITELIST_MAP["types"]["tuple"][class_name] = klass
def _get_dunder_new_params_dict(klass):
check.invariant(sys.version_info.major >= 3, "This function can only be run in python 3")
# only pulled in by python 3
from inspect import signature
return signature(klass.__new__).parameters
def _get_dunder_new_params(klass):
return list(_get_dunder_new_params_dict(klass).values())
class SerdesClassUsageError(Exception):
pass
class Persistable(ABC):
def to_storage_value(self):
return default_to_storage_value(self, _WHITELIST_MAP)
@classmethod
def from_storage_dict(cls, storage_dict):
return default_from_storage_dict(cls, storage_dict)
def _check_serdes_tuple_class_invariants(klass):
check.invariant(sys.version_info.major >= 3, "This function can only be run in python 3")
# pull this in dynamically because this method is only called in python 3 contexts
from inspect import Parameter
dunder_new_params = _get_dunder_new_params(klass)
cls_param = dunder_new_params[0]
def _with_header(msg):
return "For namedtuple {class_name}: {msg}".format(class_name=klass.__name__, msg=msg)
if cls_param.name not in {"cls", "_cls"}:
raise SerdesClassUsageError(
_with_header(
'First parameter must be _cls or cls. Got "{name}".'.format(name=cls_param.name)
)
)
value_params = dunder_new_params[1:]
for index, field in enumerate(klass._fields):
if index >= len(value_params):
            error_msg = (
                "Missing parameters to __new__. You have declared fields "
                "in the named tuple that are not present as parameters to the "
                "__new__ method. In order for "
"both serdes serialization and pickling to work, "
"these must match. Missing: {missing_fields}"
).format(missing_fields=repr(list(klass._fields[index:])))
raise SerdesClassUsageError(_with_header(error_msg))
value_param = value_params[index]
if value_param.name != field:
error_msg = (
"Params to __new__ must match the order of field declaration in the namedtuple. "
'Declared field number {one_based_index} in the namedtuple is "{field_name}". '
'Parameter {one_based_index} in __new__ method is "{param_name}".'
).format(one_based_index=index + 1, field_name=field, param_name=value_param.name)
raise SerdesClassUsageError(_with_header(error_msg))
if len(value_params) > len(klass._fields):
# Ensure that remaining parameters have default values
        for extra_param_index in range(len(klass._fields), len(value_params)):
if value_params[extra_param_index].default == Parameter.empty:
error_msg = (
'Parameter "{param_name}" is a parameter to the __new__ '
"method but is not a field in this namedtuple. The only "
"reason why this should exist is that "
"it is a field that used to exist (we refer to this as the graveyard) "
"but no longer does. However it might exist in historical storage. This "
"parameter existing ensures that serdes continues to work. However these "
"must come at the end and have a default value for pickling to work."
).format(param_name=value_params[extra_param_index].name)
raise SerdesClassUsageError(_with_header(error_msg))
def _whitelist_for_persistence(whitelist_map):
def __whitelist_for_persistence(klass):
check.subclass_param(klass, "klass", Persistable)
whitelist_map["persistence"][klass.__name__] = klass
return klass
return __whitelist_for_persistence
def _whitelist_for_serdes(whitelist_map):
def __whitelist_for_serdes(klass):
if issubclass(klass, Enum):
whitelist_map["types"]["enum"][klass.__name__] = klass
elif issubclass(klass, tuple):
# only catch this in python 3 dev environments
# no need to do backwards compat since this is
# only for development time
if sys.version_info.major >= 3:
_check_serdes_tuple_class_invariants(klass)
whitelist_map["types"]["tuple"][klass.__name__] = klass
else:
check.failed("Can not whitelist class {klass} for serdes".format(klass=klass))
return klass
return __whitelist_for_serdes
def whitelist_for_serdes(klass):
check.class_param(klass, "klass")
return _whitelist_for_serdes(whitelist_map=_WHITELIST_MAP)(klass)
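# Illustrative sketch (not part of this change): whitelisting a namedtuple and round-tripping it
# through the json serdes layer. The class and field names are hypothetical.
@whitelist_for_serdes
class _ExampleEvent(namedtuple("_ExampleEvent", "run_id")):
    pass


def _example_serdes_round_trip():
    original = _ExampleEvent(run_id="abc123")
    json_str = serialize_dagster_namedtuple(original)
    # The payload carries a "__class__" key so that _unpack_value can look up the whitelisted
    # class and reconstruct the namedtuple.
    return deserialize_json_to_dagster_namedtuple(json_str) == original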
def whitelist_for_persistence(klass):
check.class_param(klass, "klass")
return compose(
_whitelist_for_persistence(whitelist_map=_WHITELIST_MAP),
_whitelist_for_serdes(whitelist_map=_WHITELIST_MAP),
)(klass)
def pack_value(val):
return _pack_value(val, whitelist_map=_WHITELIST_MAP)
def _pack_value(val, whitelist_map):
if isinstance(val, list):
return [_pack_value(i, whitelist_map) for i in val]
if isinstance(val, tuple):
klass_name = val.__class__.__name__
check.invariant(
klass_name in whitelist_map["types"]["tuple"],
"Can only serialize whitelisted namedtuples, received tuple {}".format(val),
)
if klass_name in whitelist_map["persistence"]:
return val.to_storage_value()
base_dict = {key: _pack_value(value, whitelist_map) for key, value in val._asdict().items()}
base_dict["__class__"] = klass_name
return base_dict
if isinstance(val, Enum):
klass_name = val.__class__.__name__
check.invariant(
klass_name in whitelist_map["types"]["enum"],
"Can only serialize whitelisted Enums, received {}".format(klass_name),
)
return {"__enum__": str(val)}
if isinstance(val, set):
return {"__set__": [_pack_value(item, whitelist_map) for item in val]}
if isinstance(val, frozenset):
return {"__frozenset__": [_pack_value(item, whitelist_map) for item in val]}
if isinstance(val, dict):
return {key: _pack_value(value, whitelist_map) for key, value in val.items()}
return val
def _serialize_dagster_namedtuple(nt, whitelist_map, **json_kwargs):
return seven.json.dumps(_pack_value(nt, whitelist_map), **json_kwargs)
def serialize_value(val):
return seven.json.dumps(_pack_value(val, whitelist_map=_WHITELIST_MAP))
def deserialize_value(val):
return _unpack_value(
seven.json.loads(check.str_param(val, "val")), whitelist_map=_WHITELIST_MAP,
)
def serialize_dagster_namedtuple(nt, **json_kwargs):
return _serialize_dagster_namedtuple(
check.tuple_param(nt, "nt"), whitelist_map=_WHITELIST_MAP, **json_kwargs
)
def unpack_value(val):
return _unpack_value(val, whitelist_map=_WHITELIST_MAP,)
def _unpack_value(val, whitelist_map):
if isinstance(val, list):
return [_unpack_value(i, whitelist_map) for i in val]
if isinstance(val, dict) and val.get("__class__"):
klass_name = val.pop("__class__")
if klass_name not in whitelist_map["types"]["tuple"]:
check.failed(
'Attempted to deserialize class "{}" which is not in the serdes whitelist.'.format(
klass_name
)
)
klass = whitelist_map["types"]["tuple"][klass_name]
if klass is None:
return None
unpacked_val = {key: _unpack_value(value, whitelist_map) for key, value in val.items()}
if klass_name in whitelist_map["persistence"]:
return klass.from_storage_dict(unpacked_val)
# Naively implements backwards compatibility by filtering arguments that aren't present in
# the constructor. If a property is present in the serialized object, but doesn't exist in
# the version of the class loaded into memory, that property will be completely ignored.
# The call to seven.get_args turns out to be pretty expensive -- we should probably turn
# to, e.g., manually managing the deprecated keys on the serdes constructor.
args_for_class = seven.get_args(klass)
filtered_val = {k: v for k, v in unpacked_val.items() if k in args_for_class}
return klass(**filtered_val)
if isinstance(val, dict) and val.get("__enum__"):
name, member = val["__enum__"].split(".")
return getattr(whitelist_map["types"]["enum"][name], member)
if isinstance(val, dict) and val.get("__set__") is not None:
return set([_unpack_value(item, whitelist_map) for item in val["__set__"]])
if isinstance(val, dict) and val.get("__frozenset__") is not None:
return frozenset([_unpack_value(item, whitelist_map) for item in val["__frozenset__"]])
if isinstance(val, dict):
return {key: _unpack_value(value, whitelist_map) for key, value in val.items()}
return val
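# Illustrative sketch (not part of this change) of the backwards-compatibility behavior noted in
# _unpack_value above: keys in a stored payload that the in-memory class no longer declares are
# silently dropped. The class name and the retired field are hypothetical.
def _example_forward_compatible_deserialization():
    @whitelist_for_serdes
    class _ExampleRecord(namedtuple("_ExampleRecord", "run_id")):
        pass

    stored = '{"__class__": "_ExampleRecord", "run_id": "abc123", "retired_field": 1}'
    # "retired_field" is filtered out because it is not an argument to _ExampleRecord.__new__.
    return deserialize_json_to_dagster_namedtuple(stored)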
def deserialize_json_to_dagster_namedtuple(json_str):
dagster_namedtuple = _deserialize_json_to_dagster_namedtuple(
check.str_param(json_str, "json_str"), whitelist_map=_WHITELIST_MAP
)
check.invariant(
isinstance(dagster_namedtuple, tuple),
"Output of deserialized json_str was not a namedtuple. Received type {}.".format(
type(dagster_namedtuple)
),
)
return dagster_namedtuple
def _deserialize_json_to_dagster_namedtuple(json_str, whitelist_map):
return _unpack_value(seven.json.loads(json_str), whitelist_map=whitelist_map)
def default_to_storage_value(value, whitelist_map):
base_dict = {key: _pack_value(value, whitelist_map) for key, value in value._asdict().items()}
base_dict["__class__"] = value.__class__.__name__
return base_dict
def default_from_storage_dict(cls, storage_dict):
return cls.__new__(cls, **storage_dict)
@whitelist_for_serdes
class ConfigurableClassData(
namedtuple("_ConfigurableClassData", "module_name class_name config_yaml")
):
"""Serializable tuple describing where to find a class and the config fragment that should
be used to instantiate it.
Users should not instantiate this class directly.
Classes intended to be serialized in this way should implement the
:py:class:`dagster.serdes.ConfigurableClass` mixin.
"""
def __new__(cls, module_name, class_name, config_yaml):
return super(ConfigurableClassData, cls).__new__(
cls,
check.str_param(module_name, "module_name"),
check.str_param(class_name, "class_name"),
check.str_param(config_yaml, "config_yaml"),
)
def info_str(self, prefix=""):
return (
"{p}module: {module}\n"
"{p}class: {cls}\n"
"{p}config:\n"
"{p} {config}".format(
p=prefix, module=self.module_name, cls=self.class_name, config=self.config_yaml
)
)
def rehydrate(self):
from dagster.core.errors import DagsterInvalidConfigError
from dagster.config.field import resolve_to_config_type
from dagster.config.validate import process_config
try:
module = importlib.import_module(self.module_name)
- except seven.ModuleNotFoundError:
+ except ModuleNotFoundError:
check.failed(
"Couldn't import module {module_name} when attempting to load the "
"configurable class {configurable_class}".format(
module_name=self.module_name,
configurable_class=self.module_name + "." + self.class_name,
)
)
try:
klass = getattr(module, self.class_name)
except AttributeError:
check.failed(
"Couldn't find class {class_name} in module when attempting to load the "
"configurable class {configurable_class}".format(
class_name=self.class_name,
configurable_class=self.module_name + "." + self.class_name,
)
)
if not issubclass(klass, ConfigurableClass):
raise check.CheckError(
klass,
"class {class_name} in module {module_name}".format(
class_name=self.class_name, module_name=self.module_name
),
ConfigurableClass,
)
config_dict = yaml.safe_load(self.config_yaml)
result = process_config(resolve_to_config_type(klass.config_type()), config_dict)
if not result.success:
raise DagsterInvalidConfigError(
"Errors whilst loading configuration for {}.".format(klass.config_type()),
result.errors,
config_dict,
)
return klass.from_config_value(self, result.value)
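# Illustrative sketch (not part of this change): a minimal implementation of the ConfigurableClass
# mixin defined below, matching the yaml fragment in its docstring. The class name and config key
# are hypothetical, and it is assumed that a plain dict of field names to types is accepted by
# resolve_to_config_type. The class is defined inside a function only so that this sketch can sit
# ahead of the ConfigurableClass definition.
def _example_configurable_class():
    class _SplendidRunStorage(ConfigurableClass):
        def __init__(self, magic_word, inst_data=None):
            self._magic_word = magic_word
            self._inst_data = inst_data

        @property
        def inst_data(self):
            return self._inst_data

        @classmethod
        def config_type(cls):
            return {"magic_word": str}

        @staticmethod
        def from_config_value(inst_data, config_value):
            # Aligns the validated config value with the constructor signature.
            return _SplendidRunStorage(inst_data=inst_data, **config_value)

    return _SplendidRunStorage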
class ConfigurableClass(ABC):
"""Abstract mixin for classes that can be loaded from config.
This supports a powerful plugin pattern which avoids both a) a lengthy, hard-to-synchronize list
of conditional imports / optional extras_requires in dagster core and b) a magic directory or
file in which third parties can place plugin packages. Instead, the intention is to make, e.g.,
run storage, pluggable with a config chunk like:
.. code-block:: yaml
run_storage:
module: very_cool_package.run_storage
class: SplendidRunStorage
config:
magic_word: "quux"
This same pattern should eventually be viable for other system components, e.g. engines.
The ``ConfigurableClass`` mixin provides the necessary hooks for classes to be instantiated from
an instance of ``ConfigurableClassData``.
Pieces of the Dagster system which we wish to make pluggable in this way should consume a config
type such as:
.. code-block:: python
{'module': str, 'class': str, 'config': Field(Permissive())}
"""
@abstractproperty
def inst_data(self):
"""
Subclass must be able to return the inst_data as a property if it has been constructed
through the from_config_value code path.
"""
@classmethod
@abstractmethod
def config_type(cls):
"""dagster.ConfigType: The config type against which to validate a config yaml fragment
serialized in an instance of ``ConfigurableClassData``.
"""
@staticmethod
@abstractmethod
def from_config_value(inst_data, config_value):
"""New up an instance of the ConfigurableClass from a validated config value.
Called by ConfigurableClassData.rehydrate.
Args:
config_value (dict): The validated config value to use. Typically this should be the
``value`` attribute of a
:py:class:`~dagster.core.types.evaluator.evaluation.EvaluateValueResult`.
A common pattern is for the implementation to align the config_value with the signature
of the ConfigurableClass's constructor:
.. code-block:: python
@staticmethod
def from_config_value(inst_data, config_value):
return MyConfigurableClass(inst_data=inst_data, **config_value)
"""
diff --git a/python_modules/dagster/dagster/seven/__init__.py b/python_modules/dagster/dagster/seven/__init__.py
index 9f316b9b6..d3f11510a 100644
--- a/python_modules/dagster/dagster/seven/__init__.py
+++ b/python_modules/dagster/dagster/seven/__init__.py
@@ -1,325 +1,179 @@
"""Internal py2/3 compatibility library. A little more than six."""
import datetime
import inspect
import multiprocessing
import os
import shlex
import signal
import sys
import tempfile
import threading
import time
from contextlib import contextmanager
from types import MethodType
+from unittest import mock
import pendulum
from .json import JSONDecodeError, dump, dumps
from .temp_dir import get_system_temp_directory
IS_WINDOWS = os.name == "nt"
-if hasattr(inspect, "signature"):
- funcsigs = inspect
-else:
- import funcsigs
-
-# pylint: disable=no-name-in-module,import-error,no-member
-if sys.version_info < (3, 0):
- # Python 2 tempfile doesn't have tempfile.TemporaryDirectory
- import backports.tempfile
-
- TemporaryDirectory = backports.tempfile.TemporaryDirectory
-
-else:
- TemporaryDirectory = tempfile.TemporaryDirectory
-
-try:
- # pylint:disable=redefined-builtin,self-assigning-variable
- FileNotFoundError = FileNotFoundError
-except NameError:
- FileNotFoundError = IOError
-
-try:
- from functools import lru_cache
-except ImportError:
- from functools32 import lru_cache
-
-try:
- # pylint:disable=redefined-builtin,self-assigning-variable
- ModuleNotFoundError = ModuleNotFoundError
-except NameError:
- ModuleNotFoundError = ImportError
-
-try:
- import _thread as thread
-except ImportError:
- import thread
-
-try:
- from urllib.parse import urljoin, urlparse, urlunparse, quote_plus
-except ImportError:
- from urlparse import urljoin, urlparse, urlunparse
- from urllib import quote_plus
-
-try:
- from itertools import zip_longest
-except ImportError:
- from itertools import izip_longest as zip_longest
-
-if sys.version_info > (3,):
- from pathlib import Path # pylint: disable=import-error
-else:
- from pathlib2 import Path # pylint: disable=import-error
-
-if sys.version_info > (3,):
- from contextlib import ExitStack # pylint: disable=import-error
-else:
- from contextlib2 import ExitStack # pylint: disable=import-error
-
-if sys.version_info > (3,):
- from threading import Event as ThreadingEventType # pylint: disable=import-error
-else:
- from threading import _Event as ThreadingEventType # pylint: disable=import-error
-
-# Set execution method to spawn, to avoid fork and to have same behavior between platforms.
-# Older versions are stuck with whatever is the default on their platform (fork on
-# Unix-like and spawn on windows)
-#
-# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context
-if hasattr(multiprocessing, "get_context"):
- multiprocessing = multiprocessing.get_context("spawn")
+funcsigs = inspect
+multiprocessing = multiprocessing.get_context("spawn")
IS_WINDOWS = os.name == "nt"
# TODO implement a generic import by name -- see https://stackoverflow.com/questions/301134/how-to-import-a-module-given-its-name
# https://stackoverflow.com/a/67692/324449
def import_module_from_path(module_name, path_to_file):
- version = sys.version_info
- if version.major >= 3 and version.minor >= 5:
- import importlib.util
-
- spec = importlib.util.spec_from_file_location(module_name, path_to_file)
- if spec is None:
- raise Exception(
- "Can not import module {module_name} from path {path_to_file}, unable to load spec.".format(
- module_name=module_name, path_to_file=path_to_file
- )
+ import importlib.util
+
+ spec = importlib.util.spec_from_file_location(module_name, path_to_file)
+ if spec is None:
+ raise Exception(
+ "Can not import module {module_name} from path {path_to_file}, unable to load spec.".format(
+ module_name=module_name, path_to_file=path_to_file
)
- if sys.modules.get(spec.name) and sys.modules[spec.name].__file__ == os.path.abspath(
- spec.origin
- ):
- module = sys.modules[spec.name]
- else:
- module = importlib.util.module_from_spec(spec)
- sys.modules[spec.name] = module
- spec.loader.exec_module(module)
- elif version.major >= 3 and version.minor >= 3:
- from importlib.machinery import SourceFileLoader
-
- # pylint:disable=deprecated-method, no-value-for-parameter
- module = SourceFileLoader(module_name, path_to_file).load_module()
+ )
+ if sys.modules.get(spec.name) and sys.modules[spec.name].__file__ == os.path.abspath(
+ spec.origin
+ ):
+ module = sys.modules[spec.name]
else:
- from imp import load_source
-
- module = load_source(module_name, path_to_file)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
return module
def is_ascii(str_):
- if sys.version_info.major < 3:
- try:
- str_.decode("ascii")
- return True
- except UnicodeEncodeError:
- return False
- elif sys.version_info.major == 3 and sys.version_info.minor < 7:
+ if sys.version_info.major == 3 and sys.version_info.minor < 7:
try:
str_.encode("ascii")
return True
except UnicodeEncodeError:
return False
else:
return str_.isascii()
-if sys.version_info.major >= 3 and sys.version_info.minor >= 3:
- time_fn = time.perf_counter
-elif IS_WINDOWS:
- time_fn = time.clock
-else:
- time_fn = time.time
-
-try:
- from unittest import mock
-except ImportError:
- # Because this dependency is not encoded setup.py deliberately
- # (we do not want to override or conflict with our users mocks)
- # we never fail when importing this.
-
- # This will only be used within *our* test environment of which
- # we have total control
- try:
- import mock
- except ImportError:
- pass
+time_fn = time.perf_counter
def get_args(callable_):
- if sys.version_info.major >= 3:
- return [
- parameter.name
- for parameter in inspect.signature(callable_).parameters.values()
- if parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
- ]
- else:
- if inspect.isclass(callable_):
- if issubclass(callable_, tuple):
- arg_spec = inspect.getargspec( # pylint: disable=deprecated-method
- callable_.__new__
- )
- else:
- arg_spec = inspect.getargspec( # pylint: disable=deprecated-method
- callable_.__init__
- )
- else:
- arg_spec = inspect.getargspec(callable_) # pylint: disable=deprecated-method
- return arg_spec.args
+ return [
+ parameter.name
+ for parameter in inspect.signature(callable_).parameters.values()
+ if parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+ ]
def wait_for_process(process, timeout=30):
# Using Popen.communicate instead of Popen.wait since the latter
# can deadlock, see https://docs.python.org/3/library/subprocess.html#subprocess.Popen.wait
if not timeout:
process.communicate()
elif sys.version_info.major >= 3:
process.communicate(timeout=timeout)
else:
timed_out_event = threading.Event()
def _wait_timeout():
timed_out_event.set()
process.kill()
timer = threading.Timer(timeout, _wait_timeout)
try:
timer.start()
process.wait()
finally:
timer.cancel()
if timed_out_event.is_set():
raise Exception("Timed out waiting for process to finish")
def kill_process(process):
if not isinstance(process, multiprocessing.Process):
raise Exception("invalid process argument passed to kill_process")
if sys.version_info >= (3, 7):
# Kill added in 3.7
process.kill()
else:
process.terminate()
# https://stackoverflow.com/a/58437485/324449
def is_module_available(module_name):
- if sys.version_info <= (3, 3):
- # python 3.3 and below
- import pkgutil
-
- loader = pkgutil.find_loader(module_name)
- elif sys.version_info >= (3, 4):
- # python 3.4 and above
- import importlib
+ # python 3.4 and above
+ import importlib
- loader = importlib.util.find_spec(module_name)
+ loader = importlib.util.find_spec(module_name)
return loader is not None
def builtin_print():
- if sys.version_info.major >= 3:
- return "builtins.print"
-
- else:
- return "sys.stdout"
+ return "builtins.print"
def print_single_line_str(single_line_str):
- if sys.version_info.major >= 3:
- return [
- mock.call(single_line_str),
- ]
- else:
- return [
- mock.call.write(single_line_str),
- mock.call.write("\n"),
- ]
+ return [
+ mock.call(single_line_str),
+ ]
def get_utc_timezone():
- if sys.version_info.major >= 3 and sys.version_info.minor >= 2:
- from datetime import timezone
-
- return timezone.utc
- else:
- import pytz
+ from datetime import timezone
- return pytz.utc
+ return timezone.utc
def get_current_datetime_in_utc():
return pendulum.now("UTC")
def get_timestamp_from_utc_datetime(utc_datetime):
if isinstance(utc_datetime, pendulum.Pendulum):
return utc_datetime.timestamp()
if utc_datetime.tzinfo != get_utc_timezone():
raise Exception("Must pass in a UTC timezone to compute UNIX timestamp")
- if sys.version_info.major >= 3 and sys.version_info.minor >= 2:
- return utc_datetime.timestamp()
- else:
- import pytz
-
- return (utc_datetime - datetime.datetime(1970, 1, 1, tzinfo=pytz.utc)).total_seconds()
+ return utc_datetime.timestamp()
def is_lambda(target):
    return callable(target) and (hasattr(target, "__name__") and target.__name__ == "<lambda>")
def is_function_or_decorator_instance_of(target, kls):
return inspect.isfunction(target) or (isinstance(target, kls) and hasattr(target, "__name__"))
def qualname_differs(target):
return hasattr(target, "__qualname__") and (target.__qualname__ != target.__name__)
def xplat_shlex_split(s):
if IS_WINDOWS:
return shlex.split(s, posix=False)
return shlex.split(s)
def get_import_error_message(import_error):
- if sys.version_info.major >= 3:
- return import_error.msg
- else:
- return str(import_error)
+ return import_error.msg
# Stand-in for contextlib.nullcontext, but available in python 3.6
@contextmanager
def nullcontext():
yield
diff --git a/python_modules/dagster/dagster/utils/__init__.py b/python_modules/dagster/dagster/utils/__init__.py
index 6d3349352..20275d525 100644
--- a/python_modules/dagster/dagster/utils/__init__.py
+++ b/python_modules/dagster/dagster/utils/__init__.py
@@ -1,529 +1,530 @@
import contextlib
import datetime
import errno
import functools
import inspect
import os
import pickle
import re
import signal
import socket
import subprocess
import sys
import tempfile
import threading
from collections import namedtuple
from enum import Enum
from warnings import warn
+import _thread as thread
import six
import yaml
from dagster import check, seven
from dagster.core.errors import DagsterExecutionInterruptedError, DagsterInvariantViolationError
-from dagster.seven import IS_WINDOWS, TemporaryDirectory, multiprocessing, thread
+from dagster.seven import IS_WINDOWS, multiprocessing
from dagster.seven.abc import Mapping
from six.moves import configparser
from .merger import merge_dicts
from .yaml_utils import load_yaml_from_glob_list, load_yaml_from_globs, load_yaml_from_path
if sys.version_info > (3,):
from pathlib import Path # pylint: disable=import-error
else:
from pathlib2 import Path # pylint: disable=import-error
EPOCH = datetime.datetime.utcfromtimestamp(0)
PICKLE_PROTOCOL = 4
DEFAULT_REPOSITORY_YAML_FILENAME = "repository.yaml"
DEFAULT_WORKSPACE_YAML_FILENAME = "workspace.yaml"
def file_relative_path(dunderfile, relative_path):
"""
This function is useful when one needs to load a file that is
    relative to the position of the current file. (Such as when
    you encode a configuration file path in a source file and want
    it to be runnable from any current working directory.)
It is meant to be used like the following:
file_relative_path(__file__, 'path/relative/to/file')
"""
check.str_param(dunderfile, "dunderfile")
check.str_param(relative_path, "relative_path")
return os.path.join(os.path.dirname(dunderfile), relative_path)
def script_relative_path(file_path):
"""
Useful for testing with local files. Use a path relative to where the
test resides and this function will return the absolute path
of that file. Otherwise it will be relative to script that
ran the test
    Note: this function is very, very expensive (on the order of 1
millisecond per invocation) so this should only be used in performance
insensitive contexts. Prefer file_relative_path for anything with
performance constraints.
"""
# from http://bit.ly/2snyC6s
check.str_param(file_path, "file_path")
scriptdir = inspect.stack()[1][1]
return os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(scriptdir)), file_path))
# Adapted from https://github.com/okunishinishi/python-stringcase/blob/master/stringcase.py
def camelcase(string):
check.str_param(string, "string")
string = re.sub(r"^[\-_\.]", "", str(string))
if not string:
return string
return str(string[0]).upper() + re.sub(
r"[\-_\.\s]([a-z])", lambda matched: str(matched.group(1)).upper(), string[1:]
)
def ensure_single_item(ddict):
check.dict_param(ddict, "ddict")
check.param_invariant(len(ddict) == 1, "ddict", "Expected dict with single item")
return list(ddict.items())[0]
@contextlib.contextmanager
def pushd(path):
old_cwd = os.getcwd()
os.chdir(path)
try:
yield path
finally:
os.chdir(old_cwd)
def safe_isfile(path):
    """Backport of Python 3.8 os.path.isfile behavior.
This is intended to backport https://docs.python.org/dev/whatsnew/3.8.html#os-path. I'm not
sure that there are other ways to provoke this behavior on Unix other than the null byte,
but there are certainly other ways to do it on Windows. Afaict, we won't mask other
ValueErrors, and the behavior in the status quo ante is rough because we risk throwing an
unexpected, uncaught ValueError from very deep in our logic.
"""
try:
return os.path.isfile(path)
except ValueError:
return False
def mkdir_p(path):
try:
os.makedirs(path)
return path
except OSError as exc: # Python >2.5
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
class frozendict(dict):
def __readonly__(self, *args, **kwargs):
raise RuntimeError("Cannot modify ReadOnlyDict")
# https://docs.python.org/3/library/pickle.html#object.__reduce__
#
# For a dict, the default behavior for pickle is to iteratively call __setitem__ (see 5th item
# in __reduce__ tuple). Since we want to disable __setitem__ and still inherit dict, we
# override this behavior by defining __reduce__. We return the 3rd item in the tuple, which is
# passed to __setstate__, allowing us to restore the frozendict.
def __reduce__(self):
return (frozendict, (), dict(self))
def __setstate__(self, state):
self.__init__(state)
__setitem__ = __readonly__
__delitem__ = __readonly__
pop = __readonly__
popitem = __readonly__
clear = __readonly__
update = __readonly__
setdefault = __readonly__
del __readonly__
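# Illustrative behavior of `frozendict` (a minimal sketch; the helper below is
# hypothetical): reads and pickling behave like a normal dict, mutation raises.
def _example_frozendict_usage():
    import pickle

    fd = frozendict({"a": 1})
    assert fd["a"] == 1
    assert pickle.loads(pickle.dumps(fd)) == fd  # round-trips via __reduce__
    try:
        fd["b"] = 2
    except RuntimeError:
        pass  # "Cannot modify ReadOnlyDict"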
class frozenlist(list):
def __readonly__(self, *args, **kwargs):
raise RuntimeError("Cannot modify ReadOnlyList")
# https://docs.python.org/3/library/pickle.html#object.__reduce__
#
# Like frozendict, implement __reduce__ and __setstate__ to handle pickling.
# Without them, the default pickle behavior would rebuild the list by calling the
# (disabled) mutation methods, raising a RuntimeError because frozenlist is not mutable.
def __reduce__(self):
return (frozenlist, (), list(self))
def __setstate__(self, state):
self.__init__(state)
__setitem__ = __readonly__
__delitem__ = __readonly__
append = __readonly__
clear = __readonly__
extend = __readonly__
insert = __readonly__
pop = __readonly__
remove = __readonly__
reverse = __readonly__
sort = __readonly__
def __hash__(self):
return hash(tuple(self))
def make_readonly_value(value):
if isinstance(value, list):
return frozenlist(list(map(make_readonly_value, value)))
elif isinstance(value, dict):
return frozendict({key: make_readonly_value(value) for key, value in value.items()})
else:
return value
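# Illustrative behavior of `make_readonly_value` (a minimal sketch; the helper
# below is hypothetical): nested containers are converted recursively.
def _example_make_readonly_value_usage():
    value = make_readonly_value({"tags": ["a", "b"], "count": 3})
    assert isinstance(value, frozendict)
    assert isinstance(value["tags"], frozenlist)
    assert value["count"] == 3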
def get_prop_or_key(elem, key):
if isinstance(elem, Mapping):
return elem.get(key)
else:
return getattr(elem, key)
def list_pull(alist, key):
return list(map(lambda elem: get_prop_or_key(elem, key), alist))
def all_none(kwargs):
for value in kwargs.values():
if value is not None:
return False
return True
def check_script(path, return_code=0):
try:
subprocess.check_output([sys.executable, path])
except subprocess.CalledProcessError as exc:
if return_code != 0:
if exc.returncode == return_code:
return
raise
def check_cli_execute_file_pipeline(path, pipeline_fn_name, env_file=None):
from dagster.core.test_utils import instance_for_test
with instance_for_test():
cli_cmd = [
sys.executable,
"-m",
"dagster",
"pipeline",
"execute",
"-f",
path,
"-a",
pipeline_fn_name,
]
if env_file:
cli_cmd.append("-c")
cli_cmd.append(env_file)
try:
subprocess.check_output(cli_cmd)
except subprocess.CalledProcessError as cpe:
print(cpe) # pylint: disable=print-call
raise cpe
def safe_tempfile_path_unmanaged():
# This gets a valid temporary file path in the safest possible way, although there is still no
# guarantee that another process will not create a file at this path. The NamedTemporaryFile is
# deleted when the context manager exits and the file object is closed.
#
# This is preferable to using NamedTemporaryFile as a context manager and passing the name
# attribute of the file object around because NamedTemporaryFiles cannot be opened a second time
# if already open on Windows NT or later:
# https://docs.python.org/3.8/library/tempfile.html#tempfile.NamedTemporaryFile
# https://github.com/dagster-io/dagster/issues/1582
with tempfile.NamedTemporaryFile() as fd:
path = fd.name
return Path(path).as_posix()
@contextlib.contextmanager
def safe_tempfile_path():
try:
path = safe_tempfile_path_unmanaged()
yield path
finally:
if os.path.exists(path):
os.unlink(path)
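# Illustrative usage of `safe_tempfile_path` (a minimal sketch; the helper below
# is hypothetical): the yielded path can be handed to other processes and is
# unlinked on exit if anything was written to it.
def _example_safe_tempfile_path_usage():
    with safe_tempfile_path() as path:
        with open(path, "w") as f:
            f.write("scratch data")
        assert os.path.exists(path)
    assert not os.path.exists(path)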
def ensure_gen(thing_or_gen):
if not inspect.isgenerator(thing_or_gen):
def _gen_thing():
yield thing_or_gen
return _gen_thing()
return thing_or_gen
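# Illustrative behavior of `ensure_gen` (a minimal sketch; the helper below is
# hypothetical): plain values are wrapped in a one-item generator, existing
# generators pass through unchanged.
def _example_ensure_gen_usage():
    assert list(ensure_gen(5)) == [5]
    gen = (x for x in range(3))
    assert ensure_gen(gen) is gen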
def ensure_dir(file_path):
try:
os.makedirs(file_path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def ensure_file(path):
ensure_dir(os.path.dirname(path))
if not os.path.exists(path):
touch_file(path)
def touch_file(path):
ensure_dir(os.path.dirname(path))
with open(path, "a"):
os.utime(path, None)
def _kill_on_event(termination_event):
termination_event.wait()
send_interrupt()
def send_interrupt():
if IS_WINDOWS:
# This will raise a KeyboardInterrupt in Python land - meaning this won't be able to
# interrupt things like sleep()
thread.interrupt_main()
else:
# If on Unix, send an OS-level signal to interrupt whatever situation we may be stuck in
os.kill(os.getpid(), signal.SIGINT)
# Function to be invoked by daemon thread in processes which seek to be cancellable.
# The motivation for this approach is to be able to exit cleanly on Windows. An alternative
# path is to change how the processes are opened and send CTRL_BREAK signals, which at
# the time of authoring seemed a more costly approach.
#
# Reading for the curious:
# * https://stackoverflow.com/questions/35772001/how-to-handle-the-signal-in-python-on-windows-machine
# * https://stefan.sofa-rockers.org/2013/08/15/handling-sub-process-hierarchies-python-linux-os-x/
def start_termination_thread(termination_event):
check.inst_param(termination_event, "termination_event", ttype=type(multiprocessing.Event()))
int_thread = threading.Thread(
target=_kill_on_event, args=(termination_event,), name="kill-on-event"
)
int_thread.daemon = True
int_thread.start()
# Executes the next() function of the iterator within a fresh instance of the supplied
# context manager (any callable returning a context manager), leaving the context before
# yielding each result.
def iterate_with_context(context, iterator):
while True:
# Allow interrupts during user code so that we can terminate slow/hanging steps
with context():
try:
next_output = next(iterator)
except StopIteration:
return
yield next_output
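# Illustrative usage of `iterate_with_context` (a minimal sketch; the helper and
# tracking context below are hypothetical): each next() call runs inside a fresh
# instance of the supplied context manager.
def _example_iterate_with_context_usage():
    entered = []

    @contextlib.contextmanager
    def _tracking_context():
        entered.append(True)
        yield

    assert list(iterate_with_context(_tracking_context, iter([1, 2, 3]))) == [1, 2, 3]
    assert len(entered) == 4  # three values plus the final StopIteration check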
def datetime_as_float(dt):
check.inst_param(dt, "dt", datetime.datetime)
return float((dt - EPOCH).total_seconds())
# Hashable, frozen string-to-string dict
class frozentags(frozendict):
def __init__(self, *args, **kwargs):
super(frozentags, self).__init__(*args, **kwargs)
check.dict_param(self, "self", key_type=str, value_type=str)
def __hash__(self):
return hash(tuple(sorted(self.items())))
def updated_with(self, new_tags):
check.dict_param(new_tags, "new_tags", key_type=str, value_type=str)
updated = dict(self)
for key, value in new_tags.items():
updated[key] = value
return frozentags(updated)
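# Illustrative usage of `frozentags` (a minimal sketch; the helper below is
# hypothetical): a hashable str -> str dict whose `updated_with` returns a new
# frozentags instead of mutating in place.
def _example_frozentags_usage():
    tags = frozentags({"team": "data"})
    merged = tags.updated_with({"env": "prod"})
    assert merged == {"team": "data", "env": "prod"}
    assert len({tags, merged}) == 2  # both hashable, and distinct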
class EventGenerationManager:
""" Utility class that wraps an event generator function, that also yields a single instance of
a typed object. All events yielded before the typed object are yielded through the method
`generate_setup_events` and all events yielded after the typed object are yielded through the
method `generate_teardown_events`.
This is used to help replace the context managers used in pipeline initialization with
generators so that we can begin emitting initialization events AND construct a pipeline context
object, while managing explicit setup/teardown.
This does require calling `generate_setup_events` AND `generate_teardown_events` in order to
get the typed object.
"""
def __init__(self, generator, object_cls, require_object=True):
self.generator = check.generator(generator)
self.object_cls = check.type_param(object_cls, "object_cls")
self.require_object = check.bool_param(require_object, "require_object")
self.object = None
self.did_setup = False
self.did_teardown = False
def generate_setup_events(self):
self.did_setup = True
try:
while self.object is None:
obj = next(self.generator)
if isinstance(obj, self.object_cls):
self.object = obj
else:
yield obj
except StopIteration:
if self.require_object:
check.inst_param(
self.object,
"self.object",
self.object_cls,
"generator never yielded object of type {}".format(self.object_cls.__name__),
)
def get_object(self):
if not self.did_setup:
check.failed("Called `get_object` before `generate_setup_events`")
return self.object
def generate_teardown_events(self):
self.did_teardown = True
if self.object:
yield from self.generator
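# Illustrative usage of `EventGenerationManager` (a minimal sketch; the toy
# generator and target class below are hypothetical stand-ins for pipeline
# initialization): setup events are drained until the typed object appears,
# then teardown events are drained afterwards.
def _example_event_generation_manager_usage():
    class _Target:
        pass

    def _events():
        yield "setup_event"
        yield _Target()
        yield "teardown_event"

    manager = EventGenerationManager(_events(), _Target)
    assert list(manager.generate_setup_events()) == ["setup_event"]
    assert isinstance(manager.get_object(), _Target)
    assert list(manager.generate_teardown_events()) == ["teardown_event"]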
def utc_datetime_from_timestamp(timestamp):
tz = None
if sys.version_info.major >= 3 and sys.version_info.minor >= 2:
from datetime import timezone
tz = timezone.utc
else:
import pytz
tz = pytz.utc
return datetime.datetime.fromtimestamp(timestamp, tz=tz)
def is_enum_value(value):
return False if value is None else issubclass(value.__class__, Enum)
def git_repository_root():
return six.ensure_str(subprocess.check_output(["git", "rev-parse", "--show-toplevel"]).strip())
def segfault():
"""Reliable cross-Python version segfault.
https://bugs.python.org/issue1215#msg143236
"""
import ctypes
ctypes.string_at(0)
def find_free_port():
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
@contextlib.contextmanager
def alter_sys_path(to_add, to_remove):
to_restore = [path for path in sys.path]
# remove paths
for path in to_remove:
if path in sys.path:
sys.path.remove(path)
# add paths
for path in to_add:
sys.path.insert(0, path)
try:
yield
finally:
sys.path = to_restore
@contextlib.contextmanager
def restore_sys_modules():
sys_modules = {k: v for k, v in sys.modules.items()}
try:
yield
finally:
to_delete = set(sys.modules) - set(sys_modules)
for key in to_delete:
del sys.modules[key]
def process_is_alive(pid):
if IS_WINDOWS:
import psutil # pylint: disable=import-error
return psutil.pid_exists(pid=pid)
else:
try:
subprocess.check_output(["ps", str(pid)])
except subprocess.CalledProcessError as exc:
assert exc.returncode == 1
return False
return True
def compose(*args):
"""
Compose Python functions such that compose(f, g)(x) is equivalent to f(g(x)).
"""
# reduce using functional composition over all the arguments, with the identity function as
# initializer
return functools.reduce(lambda f, g: lambda x: f(g(x)), args, lambda x: x)
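# Illustrative behavior of `compose` (a minimal sketch; the helper below is
# hypothetical): functions are applied right-to-left, and composing nothing
# yields the identity function.
def _example_compose_usage():
    def add_one(x):
        return x + 1

    def double(x):
        return x * 2

    assert compose(add_one, double)(3) == 7  # add_one(double(3))
    assert compose()(42) == 42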
def dict_without_keys(ddict, *keys):
return {key: value for key, value in ddict.items() if key not in set(keys)}
diff --git a/python_modules/dagster/dagster/utils/net.py b/python_modules/dagster/dagster/utils/net.py
index 9c4e17cbb..99381d914 100644
--- a/python_modules/dagster/dagster/utils/net.py
+++ b/python_modules/dagster/dagster/utils/net.py
@@ -1,46 +1,46 @@
import socket
import struct
+from urllib.parse import urlparse
from dagster import check
-from dagster.seven import urlparse
def is_loopback(host):
addr_info = socket.getaddrinfo(host, None, socket.AF_INET, socket.SOCK_STREAM)[0]
sockaddr = addr_info[4][0]
return struct.unpack("!I", socket.inet_aton(sockaddr))[0] >> (32 - 8) == 127
def is_local_uri(address):
"""Determine if an address (full URI, DNS or IP) is local.
Args:
address (str): The URI or IP address to evaluate
Returns:
bool: Whether the address appears to represent a local interface.
"""
check.str_param(address, "address")
# Handle the simple cases with no protocol specified. Per
# https://docs.python.org/3/library/urllib.parse.html, urlparse recognizes a netloc only if it
# is properly introduced by '//' (e.g. has a scheme specified).
hostname = urlparse(address).hostname if "//" in address else address.split(":")[0]
# Empty protocol only specified as URI, e.g. "rpc://"
if hostname is None:
return True
# Get the IPv4 address from the hostname. Returns a triple (hostname, aliaslist, ipaddrlist), so
# we grab the 0th element of ipaddrlist.
try:
ip_addr_str = socket.gethostbyname_ex(hostname)[-1][0]
except socket.gaierror:
# Invalid hostname, so assume not local host
return False
# Special case this since it isn't technically loopback
if ip_addr_str == "0.0.0.0":
return True
return is_loopback(ip_addr_str)
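# Illustrative usage of `is_local_uri` (a minimal sketch; the helper below is
# hypothetical, and results depend on local DNS resolution):
def _example_is_local_uri_usage():
    assert is_local_uri("localhost:3000")
    assert is_local_uri("127.0.0.1")
    assert is_local_uri("grpc://")  # empty-host URI is treated as local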
diff --git a/python_modules/dagster/dagster/utils/test/__init__.py b/python_modules/dagster/dagster/utils/test/__init__.py
index 7ad5a0fe0..c620d31ce 100644
--- a/python_modules/dagster/dagster/utils/test/__init__.py
+++ b/python_modules/dagster/dagster/utils/test/__init__.py
@@ -1,419 +1,419 @@
import os
import shutil
+import tempfile
import uuid
from collections import defaultdict
from contextlib import contextmanager
# top-level include is dangerous in terms of incurring circular deps
from dagster import (
DagsterInvariantViolationError,
DependencyDefinition,
Failure,
ModeDefinition,
PipelineDefinition,
RepositoryDefinition,
SolidInvocation,
TypeCheck,
check,
execute_pipeline,
lambda_solid,
- seven,
)
from dagster.core.definitions.logger import LoggerDefinition
from dagster.core.definitions.pipeline_base import InMemoryPipeline
from dagster.core.definitions.resource import ScopedResourcesBuilder
from dagster.core.definitions.solid import NodeDefinition
from dagster.core.execution.api import create_execution_plan, scoped_pipeline_context
from dagster.core.execution.context_creation_pipeline import (
SystemPipelineExecutionContext,
construct_execution_context_data,
create_context_creation_data,
create_executor,
create_log_manager,
)
from dagster.core.instance import DagsterInstance
from dagster.core.scheduler import Scheduler
from dagster.core.scheduler.scheduler import DagsterScheduleDoesNotExist, DagsterSchedulerError
from dagster.core.snap import snapshot_from_execution_plan
from dagster.core.storage.file_manager import LocalFileManager
from dagster.core.storage.pipeline_run import PipelineRun
from dagster.core.types.dagster_type import resolve_dagster_type
from dagster.core.utility_solids import define_stub_solid
from dagster.core.utils import make_new_run_id
from dagster.serdes import ConfigurableClass
# pylint: disable=unused-import
from ..temp_file import (
get_temp_dir,
get_temp_file_handle,
get_temp_file_handle_with_data,
get_temp_file_name,
get_temp_file_name_with_data,
get_temp_file_names,
)
from ..typing_api import is_typing_type
def create_test_pipeline_execution_context(logger_defs=None):
from dagster.core.storage.intermediate_storage import build_in_mem_intermediates_storage
loggers = check.opt_dict_param(
logger_defs, "logger_defs", key_type=str, value_type=LoggerDefinition
)
mode_def = ModeDefinition(logger_defs=loggers)
pipeline_def = PipelineDefinition(
name="test_legacy_context", solid_defs=[], mode_defs=[mode_def]
)
run_config = {"loggers": {key: {} for key in loggers}}
pipeline_run = PipelineRun(pipeline_name="test_legacy_context", run_config=run_config)
instance = DagsterInstance.ephemeral()
execution_plan = create_execution_plan(pipeline=pipeline_def, run_config=run_config)
creation_data = create_context_creation_data(execution_plan, run_config, pipeline_run, instance)
log_manager = create_log_manager(creation_data)
scoped_resources_builder = ScopedResourcesBuilder()
executor = create_executor(creation_data)
return SystemPipelineExecutionContext(
construct_execution_context_data(
context_creation_data=creation_data,
scoped_resources_builder=scoped_resources_builder,
intermediate_storage=build_in_mem_intermediates_storage(pipeline_run.run_id),
log_manager=log_manager,
retries=executor.retries,
raise_on_error=True,
),
executor=executor,
log_manager=log_manager,
)
def _dep_key_of(solid):
return SolidInvocation(solid.definition.name, solid.name)
def build_pipeline_with_input_stubs(pipeline_def, inputs):
check.inst_param(pipeline_def, "pipeline_def", PipelineDefinition)
check.dict_param(inputs, "inputs", key_type=str, value_type=dict)
deps = defaultdict(dict)
for solid_name, dep_dict in pipeline_def.dependencies.items():
for input_name, dep in dep_dict.items():
deps[solid_name][input_name] = dep
stub_solid_defs = []
for solid_name, input_dict in inputs.items():
if not pipeline_def.has_solid_named(solid_name):
raise DagsterInvariantViolationError(
(
"You are injecting an input value for solid {solid_name} "
"into pipeline {pipeline_name} but that solid was not found"
).format(solid_name=solid_name, pipeline_name=pipeline_def.name)
)
solid = pipeline_def.solid_named(solid_name)
for input_name, input_value in input_dict.items():
stub_solid_def = define_stub_solid(
"__stub_{solid_name}_{input_name}".format(
solid_name=solid_name, input_name=input_name
),
input_value,
)
stub_solid_defs.append(stub_solid_def)
deps[_dep_key_of(solid)][input_name] = DependencyDefinition(stub_solid_def.name)
return PipelineDefinition(
name=pipeline_def.name + "_stubbed",
solid_defs=pipeline_def.top_level_solid_defs + stub_solid_defs,
mode_defs=pipeline_def.mode_definitions,
dependencies=deps,
)
def execute_solids_within_pipeline(
pipeline_def,
solid_names,
inputs=None,
run_config=None,
mode=None,
preset=None,
tags=None,
instance=None,
):
"""Execute a set of solids within an existing pipeline.
Intended to support tests. Input values may be passed directly.
Args:
pipeline_def (PipelineDefinition): The pipeline within which to execute the solid.
solid_names (FrozenSet[str]): A set of the solid names, or the aliased solids, to execute.
inputs (Optional[Dict[str, Dict[str, Any]]]): A dict keyed on solid names, whose values are
dicts of input names to input values, used to pass input values to the solids directly.
You may also use the ``run_config`` to configure any inputs that are configurable.
run_config (Optional[dict]): The environment configuration that parameterized this
execution, as a dict.
mode (Optional[str]): The name of the pipeline mode to use. You may not set both ``mode``
and ``preset``.
preset (Optional[str]): The name of the pipeline preset to use. You may not set both
``mode`` and ``preset``.
tags (Optional[Dict[str, Any]]): Arbitrary key-value pairs that will be added to pipeline
logs.
instance (Optional[DagsterInstance]): The instance to execute against. If this is ``None``,
an ephemeral instance will be used, and no artifacts will be persisted from the run.
Returns:
Dict[str, Union[CompositeSolidExecutionResult, SolidExecutionResult]]: The results of
executing the solids, keyed by solid name.
"""
check.inst_param(pipeline_def, "pipeline_def", PipelineDefinition)
check.set_param(solid_names, "solid_names", of_type=str)
inputs = check.opt_dict_param(inputs, "inputs", key_type=str, value_type=dict)
sub_pipeline = pipeline_def.get_pipeline_subset_def(solid_names)
stubbed_pipeline = build_pipeline_with_input_stubs(sub_pipeline, inputs)
result = execute_pipeline(
stubbed_pipeline,
run_config=run_config,
mode=mode,
preset=preset,
tags=tags,
instance=instance,
)
return {sr.solid.name: sr for sr in result.solid_result_list}
def execute_solid_within_pipeline(
pipeline_def,
solid_name,
inputs=None,
run_config=None,
mode=None,
preset=None,
tags=None,
instance=None,
):
"""Execute a single solid within an existing pipeline.
Intended to support tests. Input values may be passed directly.
Args:
pipeline_def (PipelineDefinition): The pipeline within which to execute the solid.
solid_name (str): The name of the solid, or the aliased solid, to execute.
inputs (Optional[Dict[str, Any]]): A dict of input names to input values, used to
pass input values to the solid directly. You may also use the ``run_config`` to
configure any inputs that are configurable.
run_config (Optional[dict]): The environment configuration that parameterized this
execution, as a dict.
mode (Optional[str]): The name of the pipeline mode to use. You may not set both ``mode``
and ``preset``.
preset (Optional[str]): The name of the pipeline preset to use. You may not set both
``mode`` and ``preset``.
tags (Optional[Dict[str, Any]]): Arbitrary key-value pairs that will be added to pipeline
logs.
instance (Optional[DagsterInstance]): The instance to execute against. If this is ``None``,
an ephemeral instance will be used, and no artifacts will be persisted from the run.
Returns:
Union[CompositeSolidExecutionResult, SolidExecutionResult]: The result of executing the
solid.
"""
return execute_solids_within_pipeline(
pipeline_def,
solid_names={solid_name},
inputs={solid_name: inputs} if inputs else None,
run_config=run_config,
mode=mode,
preset=preset,
tags=tags,
instance=instance,
)[solid_name]
@contextmanager
def yield_empty_pipeline_context(run_id=None, instance=None):
pipeline = InMemoryPipeline(PipelineDefinition([]))
pipeline_def = pipeline.get_definition()
instance = check.opt_inst_param(
instance, "instance", DagsterInstance, default=DagsterInstance.ephemeral()
)
execution_plan = create_execution_plan(pipeline)
pipeline_run = instance.create_run(
pipeline_name="",
run_id=run_id,
run_config=None,
mode=None,
solids_to_execute=None,
step_keys_to_execute=None,
status=None,
tags=None,
root_run_id=None,
parent_run_id=None,
pipeline_snapshot=pipeline_def.get_pipeline_snapshot(),
execution_plan_snapshot=snapshot_from_execution_plan(
execution_plan, pipeline_def.get_pipeline_snapshot_id()
),
parent_pipeline_snapshot=pipeline_def.get_parent_pipeline_snapshot(),
)
with scoped_pipeline_context(execution_plan, {}, pipeline_run, instance) as context:
yield context
def execute_solid(
solid_def, mode_def=None, input_values=None, tags=None, run_config=None, raise_on_error=True,
):
"""Execute a single solid in an ephemeral pipeline.
Intended to support unit tests. Input values may be passed directly, and no pipeline need be
specified -- an ephemeral pipeline will be constructed.
Args:
solid_def (SolidDefinition): The solid to execute.
mode_def (Optional[ModeDefinition]): The mode within which to execute the solid. Use this
if, e.g., custom resources, loggers, or executors are desired.
input_values (Optional[Dict[str, Any]]): A dict of input names to input values, used to
pass inputs to the solid directly. You may also use the ``run_config`` to
configure any inputs that are configurable.
tags (Optional[Dict[str, Any]]): Arbitrary key-value pairs that will be added to pipeline
logs.
run_config (Optional[dict]): The environment configuration that parameterized this
execution, as a dict.
raise_on_error (Optional[bool]): Whether or not to raise exceptions when they occur.
Defaults to ``True``, since this is the most useful behavior in test.
Returns:
Union[CompositeSolidExecutionResult, SolidExecutionResult]: The result of executing the
solid.
"""
check.inst_param(solid_def, "solid_def", NodeDefinition)
check.opt_inst_param(mode_def, "mode_def", ModeDefinition)
input_values = check.opt_dict_param(input_values, "input_values", key_type=str)
solid_defs = [solid_def]
def create_value_solid(input_name, input_value):
@lambda_solid(name=input_name)
def input_solid():
return input_value
return input_solid
dependencies = defaultdict(dict)
for input_name, input_value in input_values.items():
dependencies[solid_def.name][input_name] = DependencyDefinition(input_name)
solid_defs.append(create_value_solid(input_name, input_value))
result = execute_pipeline(
PipelineDefinition(
name="ephemeral_{}_solid_pipeline".format(solid_def.name),
solid_defs=solid_defs,
dependencies=dependencies,
mode_defs=[mode_def] if mode_def else None,
),
run_config=run_config,
mode=mode_def.name if mode_def else None,
tags=tags,
raise_on_error=raise_on_error,
)
return result.result_for_handle(solid_def.name)
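# Illustrative usage of `execute_solid` (a minimal sketch; the solid and helper
# below are hypothetical): unit-test a single solid by passing input values
# directly, without defining a pipeline.
def _example_execute_solid_usage():
    from dagster import solid

    @solid
    def add_one(_, num):
        return num + 1

    result = execute_solid(add_one, input_values={"num": 2})
    assert result.success
    assert result.output_value() == 3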
def check_dagster_type(dagster_type, value):
"""Test a custom Dagster type.
Args:
dagster_type (Any): The Dagster type to test. Should be one of the
:ref:`built-in types `, a dagster type explicitly constructed with
:py:func:`as_dagster_type`, :py:func:`@usable_as_dagster_type `, or
:py:func:`PythonObjectDagsterType`, or a Python type.
value (Any): The runtime value to test.
Returns:
TypeCheck: The result of the type check.
Examples:
.. code-block:: python
assert check_dagster_type(Dict[Any, Any], {'foo': 'bar'}).success
"""
if is_typing_type(dagster_type):
raise DagsterInvariantViolationError(
(
"Must pass in a type from dagster module. You passed {dagster_type} "
"which is part of python's typing module."
).format(dagster_type=dagster_type)
)
dagster_type = resolve_dagster_type(dagster_type)
with yield_empty_pipeline_context() as pipeline_context:
context = pipeline_context.for_type(dagster_type)
try:
type_check = dagster_type.type_check(context, value)
except Failure as failure:
return TypeCheck(success=False, description=failure.description)
if not isinstance(type_check, TypeCheck):
raise DagsterInvariantViolationError(
"Type checks can only return TypeCheck. Type {type_name} returned {value}.".format(
type_name=dagster_type.display_name, value=repr(type_check)
)
)
return type_check
@contextmanager
def copy_directory(src):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
dst = os.path.join(temp_dir, os.path.basename(src))
shutil.copytree(src, dst)
yield dst
class FilesystemTestScheduler(Scheduler, ConfigurableClass):
"""This class is used in dagster core and dagster_graphql to test the scheduler's interactions
with schedule storage, which are implemented in the methods defined on the base Scheduler class.
Therefore, the following methods used to actually schedule jobs (e.g. create and remove cron jobs
on a cron tab) are left unimplemented.
"""
def __init__(self, artifacts_dir, inst_data=None):
check.str_param(artifacts_dir, "artifacts_dir")
self._artifacts_dir = artifacts_dir
self._inst_data = inst_data
@property
def inst_data(self):
return self._inst_data
@classmethod
def config_type(cls):
return {"base_dir": str}
@staticmethod
def from_config_value(inst_data, config_value):
return FilesystemTestScheduler(artifacts_dir=config_value["base_dir"], inst_data=inst_data)
def debug_info(self):
return ""
def start_schedule(self, instance, external_schedule):
pass
def stop_schedule(self, instance, schedule_origin_id):
pass
def running_schedule_count(self, instance, schedule_origin_id):
return 0
def get_logs_path(self, _instance, schedule_origin_id):
check.str_param(schedule_origin_id, "schedule_origin_id")
return os.path.join(self._artifacts_dir, "logs", schedule_origin_id, "scheduler.log")
def wipe(self, instance):
pass
diff --git a/python_modules/dagster/dagster/utils/test/postgres_instance.py b/python_modules/dagster/dagster/utils/test/postgres_instance.py
index b2dd53d58..e9498b6dd 100644
--- a/python_modules/dagster/dagster/utils/test/postgres_instance.py
+++ b/python_modules/dagster/dagster/utils/test/postgres_instance.py
@@ -1,233 +1,234 @@
import os
import subprocess
+import tempfile
import warnings
from contextlib import contextmanager
import pytest
-from dagster import check, file_relative_path, seven
+from dagster import check, file_relative_path
from dagster.core.test_utils import instance_for_test_tempdir
from dagster.utils import merge_dicts
BUILDKITE = bool(os.getenv("BUILDKITE"))
@contextmanager
def postgres_instance_for_test(dunder_file, container_name, overrides=None):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with TestPostgresInstance.docker_service_up_or_skip(
file_relative_path(dunder_file, "docker-compose.yml"), container_name,
) as pg_conn_string:
TestPostgresInstance.clean_run_storage(pg_conn_string)
TestPostgresInstance.clean_event_log_storage(pg_conn_string)
TestPostgresInstance.clean_schedule_storage(pg_conn_string)
with instance_for_test_tempdir(
temp_dir,
overrides=merge_dicts(
{
"run_storage": {
"module": "dagster_postgres.run_storage.run_storage",
"class": "PostgresRunStorage",
"config": {"postgres_url": pg_conn_string},
},
"event_log_storage": {
"module": "dagster_postgres.event_log.event_log",
"class": "PostgresEventLogStorage",
"config": {"postgres_url": pg_conn_string},
},
"schedule_storage": {
"module": "dagster_postgres.schedule_storage.schedule_storage",
"class": "PostgresScheduleStorage",
"config": {"postgres_url": pg_conn_string},
},
},
overrides if overrides else {},
),
) as instance:
yield instance
class TestPostgresInstance:
@staticmethod
def dagster_postgres_installed():
try:
import dagster_postgres # pylint: disable=unused-import
except ImportError:
return False
return True
@staticmethod
def get_hostname(env_name="POSTGRES_TEST_DB_HOST"):
# In Buildkite we get the IP address from this variable (see Buildkite code for commentary).
# Otherwise, assume local development and use localhost.
return os.environ.get(env_name, "localhost")
@staticmethod
def conn_string(**kwargs):
check.invariant(
TestPostgresInstance.dagster_postgres_installed(),
"dagster_postgres must be installed to test with postgres",
)
from dagster_postgres.utils import get_conn_string # pylint: disable=import-error
return get_conn_string(
**dict(
dict(
username="test",
password="test",
hostname=TestPostgresInstance.get_hostname(),
db_name="test",
),
**kwargs,
)
)
@staticmethod
def clean_run_storage(conn_string):
check.invariant(
TestPostgresInstance.dagster_postgres_installed(),
"dagster_postgres must be installed to test with postgres",
)
from dagster_postgres.run_storage import PostgresRunStorage # pylint: disable=import-error
storage = PostgresRunStorage.create_clean_storage(conn_string)
assert storage
return storage
@staticmethod
def clean_event_log_storage(conn_string):
check.invariant(
TestPostgresInstance.dagster_postgres_installed(),
"dagster_postgres must be installed to test with postgres",
)
from dagster_postgres.event_log import ( # pylint: disable=import-error
PostgresEventLogStorage,
)
storage = PostgresEventLogStorage.create_clean_storage(conn_string)
assert storage
return storage
@staticmethod
def clean_schedule_storage(conn_string):
check.invariant(
TestPostgresInstance.dagster_postgres_installed(),
"dagster_postgres must be installed to test with postgres",
)
from dagster_postgres.schedule_storage.schedule_storage import ( # pylint: disable=import-error
PostgresScheduleStorage,
)
storage = PostgresScheduleStorage.create_clean_storage(conn_string)
assert storage
return storage
@staticmethod
@contextmanager
def docker_service_up(docker_compose_file, service_name, conn_args=None):
check.invariant(
TestPostgresInstance.dagster_postgres_installed(),
"dagster_postgres must be installed to test with postgres",
)
check.str_param(service_name, "service_name")
check.str_param(docker_compose_file, "docker_compose_file")
check.invariant(
os.path.isfile(docker_compose_file), "docker_compose_file must specify a valid file"
)
conn_args = check.opt_dict_param(conn_args, "conn_args") if conn_args else {}
from dagster_postgres.utils import wait_for_connection # pylint: disable=import-error
if BUILDKITE:
yield TestPostgresInstance.conn_string(
**conn_args
) # buildkite docker is handled in pipeline setup
return
try:
subprocess.check_output(
["docker-compose", "-f", docker_compose_file, "stop", service_name]
)
subprocess.check_output(
["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name]
)
except subprocess.CalledProcessError:
pass
try:
subprocess.check_output(
["docker-compose", "-f", docker_compose_file, "up", "-d", service_name],
stderr=subprocess.STDOUT, # capture STDERR for error handling
)
except subprocess.CalledProcessError as ex:
err_text = ex.output.decode()
raise PostgresDockerError(
"Failed to launch docker container(s) via docker-compose: {}".format(err_text), ex,
)
conn_str = TestPostgresInstance.conn_string(**conn_args)
wait_for_connection(conn_str, retry_limit=10, retry_wait=3)
yield conn_str
try:
subprocess.check_output(
["docker-compose", "-f", docker_compose_file, "stop", service_name]
)
subprocess.check_output(
["docker-compose", "-f", docker_compose_file, "rm", "-f", service_name]
)
except subprocess.CalledProcessError:
pass
@staticmethod
@contextmanager
def docker_service_up_or_skip(docker_compose_file, service_name, conn_args=None):
try:
with TestPostgresInstance.docker_service_up(
docker_compose_file, service_name, conn_args
) as conn_str:
yield conn_str
except PostgresDockerError as ex:
warnings.warn(
"Error launching Dockerized Postgres: {}".format(ex), RuntimeWarning, stacklevel=3
)
pytest.skip("Skipping due to error launching Dockerized Postgres: {}".format(ex))
def is_postgres_running(service_name):
check.str_param(service_name, "service_name")
try:
output = subprocess.check_output(
[
"docker",
"container",
"ps",
"-f",
"name={}".format(service_name),
"-f",
"status=running",
],
stderr=subprocess.STDOUT, # capture STDERR for error handling
)
except subprocess.CalledProcessError as ex:
lines = ex.output.decode().split("\n")
if len(lines) == 2 and "Cannot connect to the Docker daemon" in lines[0]:
raise PostgresDockerError("Cannot connect to the Docker daemon", ex)
else:
raise PostgresDockerError(
"Could not verify postgres container was running as expected", ex
)
decoded = output.decode()
lines = decoded.split("\n")
# header, one line for container, trailing \n
# if container is found, service_name should appear at the end of the second line of output
return len(lines) == 3 and lines[1].endswith(service_name)
class PostgresDockerError(Exception):
def __init__(self, message, subprocess_error):
super(PostgresDockerError, self).__init__(check.opt_str_param(message, "message"))
self.subprocess_error = check.inst_param(
subprocess_error, "subprocess_error", subprocess.CalledProcessError
)
diff --git a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_cli_commands.py b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_cli_commands.py
index 184ca0a6c..db189b12b 100644
--- a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_cli_commands.py
+++ b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_cli_commands.py
@@ -1,612 +1,612 @@
from __future__ import print_function
import json
import os
import string
import sys
+import tempfile
from contextlib import contextmanager
import mock
import pytest
from click.testing import CliRunner
from dagster import (
PartitionSetDefinition,
PresetDefinition,
ScheduleDefinition,
lambda_solid,
pipeline,
repository,
- seven,
solid,
)
from dagster.cli import ENV_PREFIX, cli
from dagster.cli.pipeline import pipeline_execute_command
from dagster.cli.run import run_list_command, run_wipe_command
from dagster.core.definitions.decorators.sensor import sensor
from dagster.core.definitions.sensor import RunRequest
from dagster.core.test_utils import instance_for_test, instance_for_test_tempdir
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.grpc.server import GrpcServerProcess
from dagster.utils import file_relative_path, merge_dicts
from dagster.version import __version__
def no_print(_):
return None
@lambda_solid
def do_something():
return 1
@lambda_solid
def do_input(x):
return x
@pipeline(
name="foo", preset_defs=[PresetDefinition(name="test", tags={"foo": "bar"}),],
)
def foo_pipeline():
do_input(do_something())
def define_foo_pipeline():
return foo_pipeline
@pipeline(name="baz", description="Not much tbh")
def baz_pipeline():
do_input()
def not_a_repo_or_pipeline_fn():
return "kdjfkjdf"
not_a_repo_or_pipeline = 123
@pipeline
def partitioned_scheduled_pipeline():
do_something()
def define_bar_schedules():
partition_set = PartitionSetDefinition(
name="scheduled_partitions",
pipeline_name="partitioned_scheduled_pipeline",
partition_fn=lambda: string.digits,
)
return {
"foo_schedule": ScheduleDefinition(
"foo_schedule", cron_schedule="* * * * *", pipeline_name="test_pipeline", run_config={},
),
"partitioned_schedule": partition_set.create_schedule_definition(
schedule_name="partitioned_schedule", cron_schedule="* * * * *"
),
}
def define_bar_partitions():
def error_name():
raise Exception("womp womp")
def error_config(_):
raise Exception("womp womp")
return {
"baz_partitions": PartitionSetDefinition(
name="baz_partitions",
pipeline_name="baz",
partition_fn=lambda: string.digits,
run_config_fn_for_partition=lambda partition: {
"solids": {"do_input": {"inputs": {"x": {"value": partition.value}}}}
},
),
"error_name_partitions": PartitionSetDefinition(
name="error_name_partitions", pipeline_name="baz", partition_fn=error_name,
),
"error_config_partitions": PartitionSetDefinition(
name="error_config_partitions", pipeline_name="baz", partition_fn=error_config,
),
}
def define_bar_sensors():
@sensor(pipeline_name="baz")
def foo_sensor(context):
run_config = {"foo": "FOO"}
if context.last_completion_time:
run_config["since"] = context.last_completion_time
return RunRequest(run_key=None, run_config=run_config)
return {"foo_sensor": foo_sensor}
@repository
def bar():
return {
"pipelines": {
"foo": foo_pipeline,
"baz": baz_pipeline,
"partitioned_scheduled_pipeline": partitioned_scheduled_pipeline,
},
"schedules": define_bar_schedules(),
"partition_sets": define_bar_partitions(),
"jobs": define_bar_sensors(),
}
@solid
def spew(context):
context.log.info("HELLO WORLD")
@solid
def fail(context):
raise Exception("I AM SUPPOSED TO FAIL")
@pipeline
def stdout_pipeline():
spew()
@pipeline
def stderr_pipeline():
fail()
@contextmanager
def _default_cli_test_instance_tempdir(temp_dir, overrides=None):
default_overrides = {
"run_launcher": {"module": "dagster.core.test_utils", "class": "MockedRunLauncher",}
}
with instance_for_test_tempdir(
temp_dir, overrides=merge_dicts(default_overrides, (overrides if overrides else {}))
) as instance:
with mock.patch("dagster.core.instance.DagsterInstance.get") as _instance:
_instance.return_value = instance
yield instance
@contextmanager
def default_cli_test_instance(overrides=None):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with _default_cli_test_instance_tempdir(temp_dir, overrides) as instance:
yield instance
@contextmanager
def args_with_instance(gen_instance, *args):
with gen_instance as instance:
yield args + (instance,)
def args_with_default_cli_test_instance(*args):
return args_with_instance(default_cli_test_instance(), *args)
@contextmanager
def grpc_server_bar_kwargs(pipeline_name=None):
server_process = GrpcServerProcess(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "test_cli_commands.py"),
attribute="bar",
),
)
with server_process.create_ephemeral_client() as client:
args = {"grpc_host": client.host}
if pipeline_name:
args["pipeline"] = "foo"
if client.port:
args["grpc_port"] = client.port
if client.socket:
args["grpc_socket"] = client.socket
yield args
server_process.wait()
@contextmanager
def python_bar_cli_args(pipeline_name=None):
args = [
"-m",
"dagster_tests.cli_tests.command_tests.test_cli_commands",
"-a",
"bar",
]
if pipeline_name:
args.append("-p")
args.append(pipeline_name)
yield args
@contextmanager
def grpc_server_bar_cli_args(pipeline_name=None):
server_process = GrpcServerProcess(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "test_cli_commands.py"),
attribute="bar",
),
)
with server_process.create_ephemeral_client() as client:
args = ["--grpc-host", client.host]
if client.port:
args.append("--grpc-port")
args.append(client.port)
if client.socket:
args.append("--grpc-socket")
args.append(client.socket)
if pipeline_name:
args.append("--pipeline")
args.append(pipeline_name)
yield args
server_process.wait()
@contextmanager
def grpc_server_bar_pipeline_args():
with default_cli_test_instance() as instance:
with grpc_server_bar_kwargs(pipeline_name="foo") as kwargs:
yield kwargs, instance
# This iterates over a list of contextmanagers that can be used to construct
# (cli_args, instance) tuples
def launch_command_contexts():
for pipeline_target_args in valid_external_pipeline_target_args():
yield args_with_default_cli_test_instance(pipeline_target_args)
yield pytest.param(grpc_server_bar_pipeline_args())
def pipeline_python_origin_contexts():
return [
args_with_default_cli_test_instance(pipeline_target_args)
for pipeline_target_args in valid_pipeline_python_origin_target_args()
]
@contextmanager
def scheduler_instance(overrides=None):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with _default_cli_test_instance_tempdir(
temp_dir,
overrides=merge_dicts(
{
"scheduler": {
"module": "dagster.utils.test",
"class": "FilesystemTestScheduler",
"config": {"base_dir": temp_dir},
}
},
overrides if overrides else {},
),
) as instance:
yield instance
@contextmanager
def grpc_server_scheduler_cli_args(overrides=None):
with scheduler_instance(overrides) as instance:
with grpc_server_bar_cli_args() as args:
yield args, instance
# Returns a list of contextmanagers that can be used to construct
# (cli_args, instance) tuples for schedule calls
def schedule_command_contexts():
return [
args_with_instance(
scheduler_instance(), ["-w", file_relative_path(__file__, "workspace.yaml")]
),
grpc_server_scheduler_cli_args(),
]
def sensor_command_contexts():
return [
args_with_instance(
scheduler_instance(), ["-w", file_relative_path(__file__, "workspace.yaml")],
),
grpc_server_scheduler_cli_args(),
]
# This iterates over a list of contextmanagers that can be used to construct
# (cli_args, instance) tuples for backfill calls
def backfill_command_contexts():
repo_args = {
"noprompt": True,
"workspace": (file_relative_path(__file__, "repository_file.yaml"),),
}
return [
args_with_instance(default_cli_test_instance(), repo_args),
grpc_server_backfill_args(),
]
@contextmanager
def grpc_server_backfill_args():
with default_cli_test_instance() as instance:
with grpc_server_bar_kwargs() as args:
yield merge_dicts(args, {"noprompt": True}), instance
def non_existant_python_origin_target_args():
return {
"workspace": None,
"pipeline": "foo",
"python_file": file_relative_path(__file__, "made_up_file.py"),
"module_name": None,
"attribute": "bar",
}
def valid_pipeline_python_origin_target_args():
return [
{
"workspace": None,
"pipeline": "foo",
"python_file": file_relative_path(__file__, "test_cli_commands.py"),
"module_name": None,
"attribute": "bar",
},
{
"workspace": None,
"pipeline": "foo",
"python_file": file_relative_path(__file__, "test_cli_commands.py"),
"module_name": None,
"attribute": "bar",
"working_directory": os.path.dirname(__file__),
},
{
"workspace": None,
"pipeline": "foo",
"python_file": None,
"module_name": "dagster_tests.cli_tests.command_tests.test_cli_commands",
"attribute": "bar",
},
{
"workspace": None,
"pipeline": "foo",
"python_file": None,
"package_name": "dagster_tests.cli_tests.command_tests.test_cli_commands",
"attribute": "bar",
},
{
"workspace": None,
"pipeline": None,
"python_file": None,
"module_name": "dagster_tests.cli_tests.command_tests.test_cli_commands",
"attribute": "foo_pipeline",
},
{
"workspace": None,
"pipeline": None,
"python_file": None,
"package_name": "dagster_tests.cli_tests.command_tests.test_cli_commands",
"attribute": "foo_pipeline",
},
{
"workspace": None,
"pipeline": None,
"python_file": file_relative_path(__file__, "test_cli_commands.py"),
"module_name": None,
"attribute": "define_foo_pipeline",
},
{
"workspace": None,
"pipeline": None,
"python_file": file_relative_path(__file__, "test_cli_commands.py"),
"module_name": None,
"attribute": "define_foo_pipeline",
"working_directory": os.path.dirname(__file__),
},
{
"workspace": None,
"pipeline": None,
"python_file": file_relative_path(__file__, "test_cli_commands.py"),
"module_name": None,
"attribute": "foo_pipeline",
},
]
def valid_external_pipeline_target_args():
return [
{
"workspace": (file_relative_path(__file__, "repository_file.yaml"),),
"pipeline": "foo",
"python_file": None,
"module_name": None,
"attribute": None,
},
{
"workspace": (file_relative_path(__file__, "repository_module.yaml"),),
"pipeline": "foo",
"python_file": None,
"module_name": None,
"attribute": None,
},
] + [args for args in valid_pipeline_python_origin_target_args()]
def valid_pipeline_python_origin_target_cli_args():
return [
["-f", file_relative_path(__file__, "test_cli_commands.py"), "-a", "bar", "-p", "foo"],
[
"-f",
file_relative_path(__file__, "test_cli_commands.py"),
"-d",
os.path.dirname(__file__),
"-a",
"bar",
"-p",
"foo",
],
[
"-m",
"dagster_tests.cli_tests.command_tests.test_cli_commands",
"-a",
"bar",
"-p",
"foo",
],
["-m", "dagster_tests.cli_tests.command_tests.test_cli_commands", "-a", "foo_pipeline"],
["-f", file_relative_path(__file__, "test_cli_commands.py"), "-a", "define_foo_pipeline",],
[
"-f",
file_relative_path(__file__, "test_cli_commands.py"),
"-d",
os.path.dirname(__file__),
"-a",
"define_foo_pipeline",
],
]
def valid_external_pipeline_target_cli_args_no_preset():
return [
["-w", file_relative_path(__file__, "repository_file.yaml"), "-p", "foo"],
["-w", file_relative_path(__file__, "repository_module.yaml"), "-p", "foo"],
["-w", file_relative_path(__file__, "workspace.yaml"), "-p", "foo"],
[
"-w",
file_relative_path(__file__, "override.yaml"),
"-w",
file_relative_path(__file__, "workspace.yaml"),
"-p",
"foo",
],
] + [args for args in valid_pipeline_python_origin_target_cli_args()]
def valid_external_pipeline_target_cli_args_with_preset():
run_config = {"storage": {"filesystem": {"config": {"base_dir": "/tmp"}}}}
return valid_external_pipeline_target_cli_args_no_preset() + [
[
"-f",
file_relative_path(__file__, "test_cli_commands.py"),
"-d",
os.path.dirname(__file__),
"-a",
"define_foo_pipeline",
"--preset",
"test",
],
[
"-f",
file_relative_path(__file__, "test_cli_commands.py"),
"-d",
os.path.dirname(__file__),
"-a",
"define_foo_pipeline",
"--config-json",
json.dumps(run_config),
],
]
def test_run_list():
with instance_for_test():
runner = CliRunner()
result = runner.invoke(run_list_command)
assert result.exit_code == 0
def test_run_wipe_correct_delete_message():
with instance_for_test():
runner = CliRunner()
result = runner.invoke(run_wipe_command, input="DELETE\n")
assert "Deleted all run history and event logs" in result.output
assert result.exit_code == 0
def test_run_wipe_incorrect_delete_message():
with instance_for_test():
runner = CliRunner()
result = runner.invoke(run_wipe_command, input="WRONG\n")
assert "Exiting without deleting all run history and event logs" in result.output
assert result.exit_code == 0
def test_run_list_limit():
with instance_for_test():
runner = CliRunner()
runner_pipeline_execute(
runner,
[
"-f",
file_relative_path(__file__, "../../general_tests/test_repository.py"),
"-a",
"dagster_test_repository",
"--preset",
"add",
"-p",
"multi_mode_with_resources", # pipeline name
],
)
runner_pipeline_execute(
runner,
[
"-f",
file_relative_path(__file__, "../../general_tests/test_repository.py"),
"-a",
"dagster_test_repository",
"--preset",
"add",
"-p",
"multi_mode_with_resources", # pipeline name
],
)
# Should only show one run because of the limit argument
result = runner.invoke(run_list_command, args="--limit 1")
assert result.exit_code == 0
assert result.output.count("Run: ") == 1
assert result.output.count("Pipeline: multi_mode_with_resources") == 1
# Shows two runs because the limit argument is now 2
two_results = runner.invoke(run_list_command, args="--limit 2")
assert two_results.exit_code == 0
assert two_results.output.count("Run: ") == 2
assert two_results.output.count("Pipeline: multi_mode_with_resources") == 2
# Should only show two runs, even though the limit argument is 3, because there are only 2 runs
shows_two_results = runner.invoke(run_list_command, args="--limit 3")
assert shows_two_results.exit_code == 0
assert shows_two_results.output.count("Run: ") == 2
assert shows_two_results.output.count("Pipeline: multi_mode_with_resources") == 2
def runner_pipeline_execute(runner, cli_args):
result = runner.invoke(pipeline_execute_command, cli_args)
if result.exit_code != 0:
# CliRunner captures stdout, so print it out here
raise Exception(
(
"dagster pipeline execute commands with cli_args {cli_args} "
'returned exit_code {exit_code} with stdout:\n"{stdout}" and '
'\nresult as string: "{result}"'
).format(
cli_args=cli_args, exit_code=result.exit_code, stdout=result.stdout, result=result
)
)
return result
def test_use_env_vars_for_cli_option():
env_key = "{}_VERSION".format(ENV_PREFIX)
runner = CliRunner(env={env_key: "1"})
# use `debug` subcommand to trigger the cli group option flag `--version`
# see issue: https://github.com/pallets/click/issues/1694
result = runner.invoke(cli, ["debug"], auto_envvar_prefix=ENV_PREFIX)
assert __version__ in result.output
assert result.exit_code == 0
diff --git a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_memoized_development_cli.py b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_memoized_development_cli.py
index 21029a1fd..f877f4b09 100644
--- a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_memoized_development_cli.py
+++ b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_memoized_development_cli.py
@@ -1,88 +1,89 @@
import os
import sys
+import tempfile
from io import BytesIO
import yaml
-from dagster import execute_pipeline, seven
+from dagster import execute_pipeline
from dagster.cli.pipeline import execute_list_versions_command
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher import DefaultRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.event_log import ConsolidatedSqliteEventLogStorage
from dagster.core.storage.local_compute_log_manager import LocalComputeLogManager
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
from dagster.utils import file_relative_path
from ...core_tests.execution_tests.memoized_dev_loop_pipeline import asset_pipeline
class Capturing(list):
def __enter__(self):
self._stdout = sys.stdout # pylint: disable=W0201
self._stringio = BytesIO() # pylint: disable=W0201
sys.stdout = self._stringio
return self
def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
del self._stringio # free up some memory
sys.stdout = self._stdout
def test_execute_display_command():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = ConsolidatedSqliteEventLogStorage(temp_dir)
compute_log_manager = LocalComputeLogManager(temp_dir)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=compute_log_manager,
run_coordinator=DefaultRunCoordinator(),
run_launcher=DefaultRunLauncher(),
)
run_config = {
"solids": {
"create_string_1_asset": {"config": {"input_str": "apple"}},
"take_string_1_asset": {"config": {"input_str": "apple"}},
},
"resources": {"object_manager": {"config": {"base_dir": temp_dir}}},
}
# write run config to temp file
# file is temp because intermediate storage directory is temporary
with open(os.path.join(temp_dir, "pipeline_config.yaml"), "w") as f:
f.write(yaml.dump(run_config))
kwargs = {
"config": (os.path.join(temp_dir, "pipeline_config.yaml"),),
"pipeline": "asset_pipeline",
"python_file": file_relative_path(
__file__, "../../core_tests/execution_tests/memoized_dev_loop_pipeline.py"
),
"tags": '{"dagster/is_memoized_run": "true"}',
}
with Capturing() as output:
execute_list_versions_command(kwargs=kwargs, instance=instance)
assert output
# execute the pipeline once so that addresses have been populated.
result = execute_pipeline(
asset_pipeline,
run_config=run_config,
mode="only_mode",
tags={"dagster/is_memoized_run": "true"},
instance=instance,
)
assert result.success
with Capturing() as output:
execute_list_versions_command(kwargs=kwargs, instance=instance)
assert output
diff --git a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_telemetry.py b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_telemetry.py
index 572d380b3..9dcb5f726 100644
--- a/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_telemetry.py
+++ b/python_modules/dagster/dagster_tests/cli_tests/command_tests/test_telemetry.py
@@ -1,230 +1,230 @@
import json
import logging
import os
+import tempfile
from difflib import SequenceMatcher
import mock
import pytest
import responses
from click.testing import CliRunner
-from dagster import seven
from dagster.cli.pipeline import pipeline_execute_command
from dagster.cli.workspace.load import load_workspace_from_yaml_paths
from dagster.core.definitions.reconstructable import get_ephemeral_repository_name
from dagster.core.telemetry import (
DAGSTER_TELEMETRY_URL,
UPDATE_REPO_STATS,
cleanup_telemetry_logger,
get_dir_from_dagster_home,
hash_name,
log_workspace_stats,
upload_logs,
)
from dagster.core.test_utils import environ, instance_for_test, instance_for_test_tempdir
from dagster.utils import file_relative_path, pushd, script_relative_path
EXPECTED_KEYS = set(
[
"action",
"client_time",
"elapsed_time",
"event_id",
"instance_id",
"pipeline_name_hash",
"num_pipelines_in_repo",
"repo_hash",
"python_version",
"metadata",
"version",
]
)
def path_to_file(path):
return script_relative_path(os.path.join("./", path))
def test_dagster_telemetry_enabled(caplog):
with instance_for_test(overrides={"telemetry": {"enabled": True}}):
runner = CliRunner()
with pushd(path_to_file("")):
pipeline_attribute = "foo_pipeline"
pipeline_name = "foo"
result = runner.invoke(
pipeline_execute_command,
["-f", path_to_file("test_cli_commands.py"), "-a", pipeline_attribute,],
)
for record in caplog.records:
message = json.loads(record.getMessage())
if message.get("action") == UPDATE_REPO_STATS:
assert message.get("pipeline_name_hash") == hash_name(pipeline_name)
assert message.get("num_pipelines_in_repo") == str(1)
assert message.get("repo_hash") == hash_name(
get_ephemeral_repository_name(pipeline_name)
)
assert set(message.keys()) == EXPECTED_KEYS
assert len(caplog.records) == 5
assert result.exit_code == 0
def test_dagster_telemetry_disabled(caplog):
with instance_for_test(overrides={"telemetry": {"enabled": False}}):
runner = CliRunner()
with pushd(path_to_file("")):
pipeline_name = "foo_pipeline"
result = runner.invoke(
pipeline_execute_command,
["-f", path_to_file("test_cli_commands.py"), "-a", pipeline_name,],
)
assert not os.path.exists(os.path.join(get_dir_from_dagster_home("logs"), "event.log"))
assert len(caplog.records) == 0
assert result.exit_code == 0
def test_dagster_telemetry_unset(caplog):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(temp_dir):
runner = CliRunner(env={"DAGSTER_HOME": temp_dir})
with pushd(path_to_file("")):
pipeline_attribute = "foo_pipeline"
pipeline_name = "foo"
result = runner.invoke(
pipeline_execute_command,
["-f", path_to_file("test_cli_commands.py"), "-a", pipeline_attribute],
)
for record in caplog.records:
message = json.loads(record.getMessage())
if message.get("action") == UPDATE_REPO_STATS:
assert message.get("pipeline_name_hash") == hash_name(pipeline_name)
assert message.get("num_pipelines_in_repo") == str(1)
assert message.get("repo_hash") == hash_name(
get_ephemeral_repository_name(pipeline_name)
)
assert set(message.keys()) == EXPECTED_KEYS
assert len(caplog.records) == 5
assert result.exit_code == 0
def test_repo_stats(caplog):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with instance_for_test_tempdir(temp_dir):
runner = CliRunner(env={"DAGSTER_HOME": temp_dir})
with pushd(path_to_file("")):
pipeline_name = "multi_mode_with_resources"
result = runner.invoke(
pipeline_execute_command,
[
"-f",
file_relative_path(__file__, "../../general_tests/test_repository.py"),
"-a",
"dagster_test_repository",
"-p",
pipeline_name,
"--preset",
"add",
"--tags",
'{ "foo": "bar" }',
],
)
assert result.exit_code == 0, result.stdout
for record in caplog.records:
message = json.loads(record.getMessage())
if message.get("action") == UPDATE_REPO_STATS:
assert message.get("pipeline_name_hash") == hash_name(pipeline_name)
assert message.get("num_pipelines_in_repo") == str(4)
assert message.get("repo_hash") == hash_name("dagster_test_repository")
assert set(message.keys()) == EXPECTED_KEYS
assert len(caplog.records) == 5
assert result.exit_code == 0
def test_log_workspace_stats(caplog):
with instance_for_test() as instance:
with load_workspace_from_yaml_paths(
[file_relative_path(__file__, "./multi_env_telemetry_workspace.yaml")]
) as workspace:
log_workspace_stats(instance, workspace)
for record in caplog.records:
message = json.loads(record.getMessage())
assert message.get("action") == UPDATE_REPO_STATS
assert set(message.keys()) == EXPECTED_KEYS
assert len(caplog.records) == 2
# Note that both environment variables must be set together. Otherwise, if env={"BUILDKITE": None} ran in the
# Azure pipeline, this test would fail, because TF_BUILD would be set implicitly, resulting in
# no logs being uploaded. The same applies in reverse if only TF_BUILD is set to None.
@pytest.mark.parametrize("env", [{"BUILDKITE": None, "TF_BUILD": None}])
@responses.activate
def test_dagster_telemetry_upload(env):
logger = logging.getLogger("dagster_telemetry_logger")
for handler in logger.handlers:
logger.removeHandler(handler)
responses.add(responses.POST, DAGSTER_TELEMETRY_URL)
with environ(env):
with instance_for_test():
runner = CliRunner()
with pushd(path_to_file("")):
pipeline_attribute = "foo_pipeline"
runner.invoke(
pipeline_execute_command,
["-f", path_to_file("test_cli_commands.py"), "-a", pipeline_attribute],
)
mock_stop_event = mock.MagicMock()
mock_stop_event.is_set.return_value = False
def side_effect(_):
mock_stop_event.is_set.return_value = True
mock_stop_event.wait.side_effect = side_effect
# Needed to avoid file contention issues on windows with the telemetry log file
cleanup_telemetry_logger()
upload_logs(mock_stop_event, raise_errors=True)
assert responses.assert_call_count(DAGSTER_TELEMETRY_URL, 1)
@pytest.mark.parametrize("env", [{"BUILDKITE": "True"}, {"TF_BUILD": "True"}])
@responses.activate
def test_dagster_telemetry_no_test_env_upload(env):
with environ(env):
with instance_for_test():
runner = CliRunner()
with pushd(path_to_file("")):
pipeline_attribute = "foo_pipeline"
runner.invoke(
pipeline_execute_command,
["-f", path_to_file("test_cli_commands.py"), "-a", pipeline_attribute],
)
upload_logs(mock.MagicMock())
assert responses.assert_call_count(DAGSTER_TELEMETRY_URL, 0)
# Sanity check that the hash function maps these similar names to sufficiently dissimilar strings
# From the docs, SequenceMatcher `does not yield minimal edit sequences, but does tend to yield
# matches that "look right" to people. As a rule of thumb, a .ratio() value over 0.6 means the
# sequences are close matches`
# Beyond that guidance, the 0.4 threshold below was picked arbitrarily.
def test_hash_name():
pipelines = ["pipeline_1", "pipeline_2", "pipeline_3"]
hashes = [hash_name(p) for p in pipelines]
for h in hashes:
assert len(h) == 64
assert SequenceMatcher(None, hashes[0], hashes[1]).ratio() < 0.4
assert SequenceMatcher(None, hashes[0], hashes[2]).ratio() < 0.4
assert SequenceMatcher(None, hashes[1], hashes[2]).ratio() < 0.4
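The test above only assumes that `hash_name` yields 64-character digests that look unrelated even for near-identical names. A hypothetical stand-in with that shape, assuming a plain SHA-256 hex digest (the real helper may salt or derive the hash differently):

```python
import hashlib
from difflib import SequenceMatcher

def hash_name_sketch(name: str) -> str:
    # Hypothetical: anonymize a pipeline/repository name as a SHA-256 hex digest.
    return hashlib.sha256(name.encode("utf-8")).hexdigest()

hashes = [hash_name_sketch(p) for p in ("pipeline_1", "pipeline_2", "pipeline_3")]
assert all(len(h) == 64 for h in hashes)  # a SHA-256 hexdigest is always 64 characters
# SequenceMatcher ratios for such digests stay low even though the inputs differ by one character.
print(SequenceMatcher(None, hashes[0], hashes[1]).ratio())
```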
diff --git a/python_modules/dagster/dagster_tests/cli_tests/workspace_tests/test_workspace_load.py b/python_modules/dagster/dagster_tests/cli_tests/workspace_tests/test_workspace_load.py
index fe7df9874..b1bec5c9a 100644
--- a/python_modules/dagster/dagster_tests/cli_tests/workspace_tests/test_workspace_load.py
+++ b/python_modules/dagster/dagster_tests/cli_tests/workspace_tests/test_workspace_load.py
@@ -1,22 +1,22 @@
import os
+from tempfile import TemporaryDirectory
import pytest
from dagster.check import CheckError
from dagster.cli.workspace.load import load_workspace_from_yaml_paths
-from dagster.seven import TemporaryDirectory
from dagster.utils import touch_file
def test_bad_workspace_yaml_load():
with TemporaryDirectory() as temp_dir:
touch_file(os.path.join(temp_dir, "foo.yaml"))
with pytest.raises(
CheckError,
match=(
"Invariant failed. Description: Could not parse a workspace config from the "
"yaml file at"
),
):
with load_workspace_from_yaml_paths([os.path.join(temp_dir, "foo.yaml")]):
pass
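This hunk is representative of the change repeated across these files: the py2/py3 `dagster.seven.TemporaryDirectory` shim is replaced by the standard library's context manager. A minimal stdlib-only sketch of the pattern the tests now rely on (plain `open(...).close()` stands in for `dagster.utils.touch_file`):

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as temp_dir:
    # The directory exists for the duration of the with-block.
    path = os.path.join(temp_dir, "foo.yaml")
    open(path, "a").close()  # rough stand-in for dagster.utils.touch_file
    assert os.path.exists(path)

# On exit the directory and everything inside it are removed.
assert not os.path.exists(temp_dir)
```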
diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_plan_tests/test_external_step.py b/python_modules/dagster/dagster_tests/core_tests/execution_plan_tests/test_external_step.py
index 73a81c02b..905e8d131 100644
--- a/python_modules/dagster/dagster_tests/core_tests/execution_plan_tests/test_external_step.py
+++ b/python_modules/dagster/dagster_tests/core_tests/execution_plan_tests/test_external_step.py
@@ -1,286 +1,284 @@
import os
+import tempfile
import time
import uuid
from threading import Thread
import pytest
from dagster import (
Field,
ModeDefinition,
RetryRequested,
String,
execute_pipeline,
execute_pipeline_iterator,
pipeline,
reconstructable,
resource,
- seven,
solid,
)
from dagster.core.definitions.no_step_launcher import no_step_launcher
from dagster.core.errors import DagsterExecutionInterruptedError
from dagster.core.events import DagsterEventType
from dagster.core.execution.api import create_execution_plan
from dagster.core.execution.context_creation_pipeline import PipelineExecutionContextManager
from dagster.core.execution.plan.external_step import (
LocalExternalStepLauncher,
local_external_step_launcher,
step_context_to_step_run_ref,
step_run_ref_to_step_context,
)
from dagster.core.instance import DagsterInstance
from dagster.core.storage.pipeline_run import PipelineRun
from dagster.utils import safe_tempfile_path, send_interrupt
from dagster.utils.merger import deep_merge_dicts
RUN_CONFIG_BASE = {"solids": {"return_two": {"config": {"a": "b"}}}}
def make_run_config(scratch_dir, mode):
if mode in ["external", "request_retry"]:
step_launcher_resource_keys = ["first_step_launcher", "second_step_launcher"]
else:
step_launcher_resource_keys = ["second_step_launcher"]
return deep_merge_dicts(
RUN_CONFIG_BASE,
{
"resources": {
step_launcher_resource_key: {"config": {"scratch_dir": scratch_dir}}
for step_launcher_resource_key in step_launcher_resource_keys
},
"intermediate_storage": {"filesystem": {"config": {"base_dir": scratch_dir}}},
},
)
class RequestRetryLocalExternalStepLauncher(LocalExternalStepLauncher):
def launch_step(self, step_context, prior_attempts_count):
if prior_attempts_count == 0:
raise RetryRequested()
else:
return super(RequestRetryLocalExternalStepLauncher, self).launch_step(
step_context, prior_attempts_count
)
@resource(config_schema=local_external_step_launcher.config_schema)
def request_retry_local_external_step_launcher(context):
return RequestRetryLocalExternalStepLauncher(**context.resource_config)
def define_basic_pipeline():
@solid(required_resource_keys=set(["first_step_launcher"]), config_schema={"a": Field(str)})
def return_two(_):
return 2
@solid(required_resource_keys=set(["second_step_launcher"]))
def add_one(_, num):
return num + 1
@pipeline(
mode_defs=[
ModeDefinition(
"external",
resource_defs={
"first_step_launcher": local_external_step_launcher,
"second_step_launcher": local_external_step_launcher,
},
),
ModeDefinition(
"internal_and_external",
resource_defs={
"first_step_launcher": no_step_launcher,
"second_step_launcher": local_external_step_launcher,
},
),
ModeDefinition(
"request_retry",
resource_defs={
"first_step_launcher": request_retry_local_external_step_launcher,
"second_step_launcher": request_retry_local_external_step_launcher,
},
),
]
)
def basic_pipeline():
add_one(return_two())
return basic_pipeline
def define_sleepy_pipeline():
@solid(
config_schema={"tempfile": Field(String)},
required_resource_keys=set(["first_step_launcher"]),
)
def sleepy_solid(context):
with open(context.solid_config["tempfile"], "w") as ff:
ff.write("yup")
start_time = time.time()
while True:
time.sleep(0.1)
if time.time() - start_time > 120:
raise Exception("Timed out")
@pipeline(
mode_defs=[
ModeDefinition(
"external", resource_defs={"first_step_launcher": local_external_step_launcher,},
),
]
)
def sleepy_pipeline():
sleepy_solid()
return sleepy_pipeline
def initialize_step_context(scratch_dir, instance):
pipeline_run = PipelineRun(
pipeline_name="foo_pipeline",
run_id=str(uuid.uuid4()),
run_config=make_run_config(scratch_dir, "external"),
mode="external",
)
plan = create_execution_plan(
reconstructable(define_basic_pipeline), pipeline_run.run_config, mode="external"
)
initialization_manager = PipelineExecutionContextManager(
plan, pipeline_run.run_config, pipeline_run, instance,
)
for _ in initialization_manager.prepare_context():
pass
pipeline_context = initialization_manager.get_context()
step_context = pipeline_context.for_step(plan.get_step_by_key("return_two"))
return step_context
def test_step_context_to_step_run_ref():
with DagsterInstance.ephemeral() as instance:
step_context = initialize_step_context("", instance)
step = step_context.step
step_run_ref = step_context_to_step_run_ref(step_context, 0)
assert step_run_ref.run_config == step_context.pipeline_run.run_config
assert step_run_ref.run_id == step_context.pipeline_run.run_id
rehydrated_step_context = step_run_ref_to_step_context(step_run_ref, instance)
assert rehydrated_step_context.required_resource_keys == step_context.required_resource_keys
rehydrated_step = rehydrated_step_context.step
assert rehydrated_step.pipeline_name == step.pipeline_name
assert rehydrated_step.step_inputs == step.step_inputs
assert rehydrated_step.step_outputs == step.step_outputs
assert rehydrated_step.kind == step.kind
assert rehydrated_step.solid_handle.name == step.solid_handle.name
assert rehydrated_step.logging_tags == step.logging_tags
assert rehydrated_step.tags == step.tags
def test_local_external_step_launcher():
- with seven.TemporaryDirectory() as tmpdir:
+ with tempfile.TemporaryDirectory() as tmpdir:
with DagsterInstance.ephemeral() as instance:
step_context = initialize_step_context(tmpdir, instance)
step_launcher = LocalExternalStepLauncher(tmpdir)
events = list(step_launcher.launch_step(step_context, 0))
event_types = [event.event_type for event in events]
assert DagsterEventType.STEP_START in event_types
assert DagsterEventType.STEP_SUCCESS in event_types
assert DagsterEventType.STEP_FAILURE not in event_types
@pytest.mark.parametrize("mode", ["external", "internal_and_external"])
def test_pipeline(mode):
- with seven.TemporaryDirectory() as tmpdir:
+ with tempfile.TemporaryDirectory() as tmpdir:
result = execute_pipeline(
pipeline=reconstructable(define_basic_pipeline),
mode=mode,
run_config=make_run_config(tmpdir, mode),
)
assert result.result_for_solid("return_two").output_value() == 2
assert result.result_for_solid("add_one").output_value() == 3
def test_launcher_requests_retry():
mode = "request_retry"
- with seven.TemporaryDirectory() as tmpdir:
+ with tempfile.TemporaryDirectory() as tmpdir:
result = execute_pipeline(
pipeline=reconstructable(define_basic_pipeline),
mode=mode,
run_config=make_run_config(tmpdir, mode),
)
assert result.success
assert result.result_for_solid("return_two").output_value() == 2
assert result.result_for_solid("add_one").output_value() == 3
for step_key, events in result.events_by_step_key.items():
if step_key:
event_types = [event.event_type for event in events]
assert DagsterEventType.STEP_UP_FOR_RETRY in event_types
assert DagsterEventType.STEP_RESTARTED in event_types
def _send_interrupt_thread(temp_file):
while not os.path.exists(temp_file):
time.sleep(0.1)
send_interrupt()
@pytest.mark.parametrize("mode", ["external"])
def test_interrupt_step_launcher(mode):
- with seven.TemporaryDirectory() as tmpdir:
-
+ with tempfile.TemporaryDirectory() as tmpdir:
with safe_tempfile_path() as success_tempfile:
-
sleepy_run_config = {
"resources": {"first_step_launcher": {"config": {"scratch_dir": tmpdir}}},
"intermediate_storage": {"filesystem": {"config": {"base_dir": tmpdir}}},
"solids": {"sleepy_solid": {"config": {"tempfile": success_tempfile}}},
}
interrupt_thread = Thread(target=_send_interrupt_thread, args=(success_tempfile,))
interrupt_thread.start()
results = []
received_interrupt = False
try:
for result in execute_pipeline_iterator(
pipeline=reconstructable(define_sleepy_pipeline),
mode=mode,
run_config=sleepy_run_config,
):
results.append(result.event_type)
except DagsterExecutionInterruptedError:
received_interrupt = True
assert received_interrupt
assert DagsterEventType.STEP_FAILURE in results
assert DagsterEventType.PIPELINE_FAILURE in results
interrupt_thread.join()
def test_multiproc_launcher_requests_retry():
mode = "request_retry"
- with seven.TemporaryDirectory() as tmpdir:
+ with tempfile.TemporaryDirectory() as tmpdir:
run_config = make_run_config(tmpdir, mode)
run_config["execution"] = {"multiprocess": {}}
result = execute_pipeline(
instance=DagsterInstance.local_temp(tmpdir),
pipeline=reconstructable(define_basic_pipeline),
mode=mode,
run_config=run_config,
)
assert result.success
assert result.result_for_solid("return_two").output_value() == 2
assert result.result_for_solid("add_one").output_value() == 3
for step_key, events in result.events_by_step_key.items():
if step_key:
event_types = [event.event_type for event in events]
assert DagsterEventType.STEP_UP_FOR_RETRY in event_types
assert DagsterEventType.STEP_RESTARTED in event_types
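`make_run_config` above builds run config by layering scratch-dir resources and filesystem intermediate storage onto `RUN_CONFIG_BASE` with `deep_merge_dicts`. A hypothetical recursive merge with roughly that behavior (dagster's actual helper may handle conflicts differently):

```python
def deep_merge(base: dict, overrides: dict) -> dict:
    # Merge `overrides` into a copy of `base`; nested dicts are merged, other values are replaced.
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

run_config = deep_merge(
    {"solids": {"return_two": {"config": {"a": "b"}}}},
    {"intermediate_storage": {"filesystem": {"config": {"base_dir": "/tmp/scratch"}}}},
)
assert run_config["solids"]["return_two"]["config"]["a"] == "b"
assert run_config["intermediate_storage"]["filesystem"]["config"]["base_dir"] == "/tmp/scratch"
```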
diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_interrupt.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_interrupt.py
index 704cbe081..683c8af9f 100644
--- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_interrupt.py
+++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_interrupt.py
@@ -1,318 +1,319 @@
import os
import signal
+import tempfile
import time
from threading import Thread
import pytest
from dagster import (
DagsterEventType,
Field,
ModeDefinition,
String,
execute_pipeline_iterator,
pipeline,
reconstructable,
resource,
seven,
solid,
)
from dagster.core.errors import DagsterExecutionInterruptedError, raise_execution_interrupts
from dagster.core.test_utils import instance_for_test_tempdir
from dagster.utils import safe_tempfile_path, send_interrupt
from dagster.utils.interrupts import capture_interrupts, check_captured_interrupt
def _send_kbd_int(temp_files):
while not all([os.path.exists(temp_file) for temp_file in temp_files]):
time.sleep(0.1)
send_interrupt()
@solid(config_schema={"tempfile": Field(String)})
def write_a_file(context):
with open(context.solid_config["tempfile"], "w") as ff:
ff.write("yup")
start_time = time.time()
while (time.time() - start_time) < 30:
time.sleep(0.1)
raise Exception("Timed out")
@solid
def should_not_start(_context):
assert False
@pipeline
def write_files_pipeline():
write_a_file.alias("write_1")()
write_a_file.alias("write_2")()
write_a_file.alias("write_3")()
write_a_file.alias("write_4")()
should_not_start.alias("x_should_not_start")()
should_not_start.alias("y_should_not_start")()
should_not_start.alias("z_should_not_start")()
def test_single_proc_interrupt():
@pipeline
def write_a_file_pipeline():
write_a_file()
with safe_tempfile_path() as success_tempfile:
# launch a thread that waits until the file is written to launch an interrupt
Thread(target=_send_kbd_int, args=([success_tempfile],)).start()
result_types = []
result_messages = []
received_interrupt = False
try:
# launch a pipeline that writes a file and loops infinitely
# next time the launched thread wakes up it will send a keyboard
# interrupt
for result in execute_pipeline_iterator(
write_a_file_pipeline,
run_config={"solids": {"write_a_file": {"config": {"tempfile": success_tempfile}}}},
):
result_types.append(result.event_type)
result_messages.append(result.message)
assert False # should never reach
except DagsterExecutionInterruptedError:
received_interrupt = True
assert received_interrupt
assert DagsterEventType.STEP_FAILURE in result_types
assert DagsterEventType.PIPELINE_FAILURE in result_types
assert any(
[
"Execution was interrupted unexpectedly. "
"No user initiated termination request was found, treating as failure." in message
for message in result_messages
]
)
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_interrupt_multiproc():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with instance_for_test_tempdir(tempdir) as instance:
file_1 = os.path.join(tempdir, "file_1")
file_2 = os.path.join(tempdir, "file_2")
file_3 = os.path.join(tempdir, "file_3")
file_4 = os.path.join(tempdir, "file_4")
# launch a thread that waits until the file is written to launch an interrupt
Thread(target=_send_kbd_int, args=([file_1, file_2, file_3, file_4],)).start()
results = []
received_interrupt = False
try:
# launch a pipeline that writes a file and loops infinitely
# next time the launched thread wakes up it will send a keyboard
# interrupt
for result in execute_pipeline_iterator(
reconstructable(write_files_pipeline),
run_config={
"solids": {
"write_1": {"config": {"tempfile": file_1}},
"write_2": {"config": {"tempfile": file_2}},
"write_3": {"config": {"tempfile": file_3}},
"write_4": {"config": {"tempfile": file_4}},
},
"execution": {"multiprocess": {"config": {"max_concurrent": 4}}},
"intermediate_storage": {"filesystem": {}},
},
instance=instance,
):
results.append(result)
assert False # should never reach
except DagsterExecutionInterruptedError:
received_interrupt = True
assert received_interrupt
assert [result.event_type for result in results].count(
DagsterEventType.STEP_FAILURE
) == 4
assert DagsterEventType.PIPELINE_FAILURE in [result.event_type for result in results]
def test_interrupt_resource_teardown():
called = []
cleaned = []
@resource
def resource_a(_):
try:
called.append("A")
yield "A"
finally:
cleaned.append("A")
@solid(config_schema={"tempfile": Field(String)}, required_resource_keys={"a"})
def write_a_file_resource_solid(context):
with open(context.solid_config["tempfile"], "w") as ff:
ff.write("yup")
while True:
time.sleep(0.1)
@pipeline(mode_defs=[ModeDefinition(resource_defs={"a": resource_a})])
def write_a_file_pipeline():
write_a_file_resource_solid()
with safe_tempfile_path() as success_tempfile:
# launch a thread that waits until the file is written to launch an interrupt
Thread(target=_send_kbd_int, args=([success_tempfile],)).start()
results = []
received_interrupt = False
try:
# launch a pipeline that writes a file and loops infinitely
# next time the launched thread wakes up it will send an interrupt
for result in execute_pipeline_iterator(
write_a_file_pipeline,
run_config={
"solids": {
"write_a_file_resource_solid": {"config": {"tempfile": success_tempfile}}
}
},
):
results.append(result.event_type)
assert False # should never reach
except DagsterExecutionInterruptedError:
received_interrupt = True
assert received_interrupt
assert DagsterEventType.STEP_FAILURE in results
assert DagsterEventType.PIPELINE_FAILURE in results
assert "A" in cleaned
def _send_interrupt_to_self():
os.kill(os.getpid(), signal.SIGINT)
start_time = time.time()
while not check_captured_interrupt():
time.sleep(1)
if time.time() - start_time > 15:
raise Exception("Timed out waiting for interrupt to be received")
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_capture_interrupt():
outer_interrupt = False
inner_interrupt = False
with capture_interrupts():
try:
_send_interrupt_to_self()
except: # pylint: disable=bare-except
inner_interrupt = True
assert not inner_interrupt
# Verify standard interrupt handler is restored
standard_interrupt = False
try:
_send_interrupt_to_self()
except KeyboardInterrupt:
standard_interrupt = True
assert standard_interrupt
outer_interrupt = False
inner_interrupt = False
# No exception if no signal is thrown
try:
with capture_interrupts():
try:
time.sleep(5)
except: # pylint: disable=bare-except
inner_interrupt = True
except: # pylint: disable=bare-except
outer_interrupt = True
assert not outer_interrupt
assert not inner_interrupt
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_raise_execution_interrupts():
with raise_execution_interrupts():
try:
_send_interrupt_to_self()
except DagsterExecutionInterruptedError:
standard_interrupt = True
assert standard_interrupt
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_interrupt_inside_nested_delay_and_raise():
interrupt_inside_nested_raise = False
interrupt_after_delay = False
try:
with capture_interrupts():
with raise_execution_interrupts():
try:
_send_interrupt_to_self()
except DagsterExecutionInterruptedError:
interrupt_inside_nested_raise = True
except: # pylint: disable=bare-except
interrupt_after_delay = True
assert interrupt_inside_nested_raise
assert not interrupt_after_delay
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_no_interrupt_after_nested_delay_and_raise():
interrupt_inside_nested_raise = False
interrupt_after_delay = False
try:
with capture_interrupts():
with raise_execution_interrupts():
try:
time.sleep(5)
except: # pylint: disable=bare-except
interrupt_inside_nested_raise = True
_send_interrupt_to_self()
except: # pylint: disable=bare-except
interrupt_after_delay = True
assert not interrupt_inside_nested_raise
assert not interrupt_after_delay
@pytest.mark.skipif(seven.IS_WINDOWS, reason="Interrupts handled differently on windows")
def test_calling_raise_execution_interrupts_also_raises_any_captured_interrupts():
interrupt_from_raise_execution_interrupts = False
interrupt_after_delay = False
try:
with capture_interrupts():
_send_interrupt_to_self()
try:
with raise_execution_interrupts():
pass
except DagsterExecutionInterruptedError:
interrupt_from_raise_execution_interrupts = True
except: # pylint: disable=bare-except
interrupt_after_delay = True
assert interrupt_from_raise_execution_interrupts
assert not interrupt_after_delay
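The interrupt tests above revolve around `capture_interrupts` / `raise_execution_interrupts`, which temporarily swap the SIGINT handler so an interrupt is recorded instead of raised immediately. A rough, Unix-oriented stdlib sketch of that idea (not dagster's implementation; Windows handles signals differently, which is why these tests are skipped there):

```python
import os
import signal
import time
from contextlib import contextmanager

@contextmanager
def capture_sigint():
    # Record SIGINT in a flag instead of raising KeyboardInterrupt; restore the old handler on exit.
    received = {"flag": False}

    def handler(signum, frame):
        received["flag"] = True

    previous = signal.signal(signal.SIGINT, handler)
    try:
        yield received
    finally:
        signal.signal(signal.SIGINT, previous)

with capture_sigint() as received:
    os.kill(os.getpid(), signal.SIGINT)  # send ourselves an interrupt
    for _ in range(50):  # give the Python-level handler a moment to run
        if received["flag"]:
            break
        time.sleep(0.1)
    assert received["flag"]  # captured rather than raised
```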
diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_memoized_dev_loop.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_memoized_dev_loop.py
index cc6a675b3..3e9a8c73d 100644
--- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_memoized_dev_loop.py
+++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_memoized_dev_loop.py
@@ -1,68 +1,70 @@
-from dagster import execute_pipeline, seven
+import tempfile
+
+from dagster import execute_pipeline
from dagster.core.execution.api import create_execution_plan
from dagster.core.execution.resolve_versions import resolve_memoized_execution_plan
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher import DefaultRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.event_log import ConsolidatedSqliteEventLogStorage
from dagster.core.storage.local_compute_log_manager import LocalComputeLogManager
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
from .memoized_dev_loop_pipeline import asset_pipeline
def get_step_keys_to_execute(pipeline, run_config, mode):
memoized_execution_plan = resolve_memoized_execution_plan(
create_execution_plan(pipeline, run_config=run_config, mode=mode)
)
return memoized_execution_plan.step_keys_to_execute
def test_dev_loop_changing_versions():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = ConsolidatedSqliteEventLogStorage(temp_dir)
compute_log_manager = LocalComputeLogManager(temp_dir)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=compute_log_manager,
run_launcher=DefaultRunLauncher(),
run_coordinator=DefaultRunCoordinator(),
)
run_config = {
"solids": {
"create_string_1_asset": {"config": {"input_str": "apple"}},
"take_string_1_asset": {"config": {"input_str": "apple"}},
},
"resources": {"object_manager": {"config": {"base_dir": temp_dir}}},
}
result = execute_pipeline(
asset_pipeline,
run_config=run_config,
mode="only_mode",
tags={"dagster/is_memoized_run": "true"},
instance=instance,
)
assert result.success
assert not get_step_keys_to_execute(asset_pipeline, run_config, "only_mode")
run_config["solids"]["take_string_1_asset"]["config"]["input_str"] = "banana"
assert get_step_keys_to_execute(asset_pipeline, run_config, "only_mode") == [
"take_string_1_asset"
]
result = execute_pipeline(
asset_pipeline,
run_config=run_config,
mode="only_mode",
tags={"dagster/is_memoized_run": "true"},
instance=instance,
)
assert result.success
assert not get_step_keys_to_execute(asset_pipeline, run_config, "only_mode")
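The dev-loop test re-executes only the step whose config changed (`take_string_1_asset`, once its input flips from `apple` to `banana`). A toy sketch of version-based memoization in that spirit, assuming a step's version is just a hash of its config (dagster's real resolution also folds in code versions and upstream outputs):

```python
import hashlib
import json

def step_version(config: dict) -> str:
    # Toy version: a hash of the step's config alone.
    return hashlib.sha256(json.dumps(config, sort_keys=True).encode()).hexdigest()

def steps_to_execute(solids_config: dict, memoized_versions: dict) -> list:
    # A step needs to run if its current version differs from the memoized one.
    return [
        step for step, config in solids_config.items()
        if memoized_versions.get(step) != step_version(config)
    ]

solids_config = {
    "create_string_1_asset": {"input_str": "apple"},
    "take_string_1_asset": {"input_str": "apple"},
}
memoized = {}
assert steps_to_execute(solids_config, memoized) == list(solids_config)  # first run executes everything

memoized = {step: step_version(cfg) for step, cfg in solids_config.items()}
solids_config["take_string_1_asset"]["input_str"] = "banana"
assert steps_to_execute(solids_config, memoized) == ["take_string_1_asset"]
```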
diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_retries.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_retries.py
index 22f8e8d2a..e05ca4225 100644
--- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_retries.py
+++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_retries.py
@@ -1,246 +1,246 @@
import os
+import tempfile
import time
from collections import defaultdict
import pytest
from dagster import (
DagsterEventType,
Output,
OutputDefinition,
PipelineRun,
RetryRequested,
execute_pipeline,
execute_pipeline_iterator,
lambda_solid,
pipeline,
reconstructable,
reexecute_pipeline,
- seven,
solid,
)
from dagster.core.execution.api import create_execution_plan, execute_plan
from dagster.core.execution.retries import Retries, RetryMode
from dagster.core.test_utils import instance_for_test
executors = pytest.mark.parametrize(
"environment",
[
{"intermediate_storage": {"filesystem": {}}},
{"intermediate_storage": {"filesystem": {}}, "execution": {"multiprocess": {}}},
],
)
def define_run_retry_pipeline():
@solid(config_schema={"fail": bool})
def can_fail(context, _start_fail):
if context.solid_config["fail"]:
raise Exception("blah")
return "okay perfect"
@solid(
output_defs=[
OutputDefinition(bool, "start_fail", is_required=False),
OutputDefinition(bool, "start_skip", is_required=False),
]
)
def two_outputs(_):
yield Output(True, "start_fail")
# won't yield start_skip
@solid
def will_be_skipped(_, _start_skip):
pass # doesn't matter
@solid
def downstream_of_failed(_, input_str):
return input_str
@pipeline
def pipe():
start_fail, start_skip = two_outputs()
downstream_of_failed(can_fail(start_fail))
will_be_skipped(will_be_skipped(start_skip))
return pipe
@executors
def test_retries(environment):
with instance_for_test() as instance:
pipe = reconstructable(define_run_retry_pipeline)
fails = dict(environment)
fails["solids"] = {"can_fail": {"config": {"fail": True}}}
result = execute_pipeline(pipe, run_config=fails, instance=instance, raise_on_error=False,)
assert not result.success
passes = dict(environment)
passes["solids"] = {"can_fail": {"config": {"fail": False}}}
second_result = reexecute_pipeline(
pipe, parent_run_id=result.run_id, run_config=passes, instance=instance,
)
assert second_result.success
downstream_of_failed = second_result.result_for_solid("downstream_of_failed").output_value()
assert downstream_of_failed == "okay perfect"
will_be_skipped = [
e for e in second_result.event_list if "will_be_skipped" in str(e.solid_handle)
]
assert str(will_be_skipped[0].event_type_value) == "STEP_SKIPPED"
assert str(will_be_skipped[1].event_type_value) == "STEP_SKIPPED"
def define_step_retry_pipeline():
@solid(config_schema=str)
def fail_first_time(context):
file = os.path.join(context.solid_config, "i_threw_up")
if os.path.exists(file):
return "okay perfect"
else:
open(file, "a").close()
raise RetryRequested()
@pipeline
def step_retry():
fail_first_time()
return step_retry
@executors
def test_step_retry(environment):
with instance_for_test() as instance:
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
env = dict(environment)
env["solids"] = {"fail_first_time": {"config": tempdir}}
result = execute_pipeline(
reconstructable(define_step_retry_pipeline), run_config=env, instance=instance,
)
assert result.success
events = defaultdict(list)
for ev in result.event_list:
events[ev.event_type].append(ev)
assert len(events[DagsterEventType.STEP_START]) == 1
assert len(events[DagsterEventType.STEP_UP_FOR_RETRY]) == 1
assert len(events[DagsterEventType.STEP_RESTARTED]) == 1
assert len(events[DagsterEventType.STEP_SUCCESS]) == 1
def define_retry_limit_pipeline():
@lambda_solid
def default_max():
raise RetryRequested()
@lambda_solid
def three_max():
raise RetryRequested(max_retries=3)
@pipeline
def retry_limits():
default_max()
three_max()
return retry_limits
@executors
def test_step_retry_limit(environment):
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(define_retry_limit_pipeline),
run_config=environment,
raise_on_error=False,
instance=instance,
)
assert not result.success
events = defaultdict(list)
for ev in result.events_by_step_key["default_max"]:
events[ev.event_type].append(ev)
assert len(events[DagsterEventType.STEP_START]) == 1
assert len(events[DagsterEventType.STEP_UP_FOR_RETRY]) == 1
assert len(events[DagsterEventType.STEP_RESTARTED]) == 1
assert len(events[DagsterEventType.STEP_FAILURE]) == 1
events = defaultdict(list)
for ev in result.events_by_step_key["three_max"]:
events[ev.event_type].append(ev)
assert len(events[DagsterEventType.STEP_START]) == 1
assert len(events[DagsterEventType.STEP_UP_FOR_RETRY]) == 3
assert len(events[DagsterEventType.STEP_RESTARTED]) == 3
assert len(events[DagsterEventType.STEP_FAILURE]) == 1
def test_retry_deferral():
with instance_for_test() as instance:
events = execute_plan(
create_execution_plan(define_retry_limit_pipeline()),
pipeline_run=PipelineRun(pipeline_name="retry_limits", run_id="42"),
retries=Retries(RetryMode.DEFERRED),
instance=instance,
)
events_by_type = defaultdict(list)
for ev in events:
events_by_type[ev.event_type].append(ev)
assert len(events_by_type[DagsterEventType.STEP_START]) == 2
assert len(events_by_type[DagsterEventType.STEP_UP_FOR_RETRY]) == 2
assert DagsterEventType.STEP_RESTARTED not in events
assert DagsterEventType.STEP_SUCCESS not in events
DELAY = 2
def define_retry_wait_fixed_pipeline():
@solid(config_schema=str)
def fail_first_and_wait(context):
file = os.path.join(context.solid_config, "i_threw_up")
if os.path.exists(file):
return "okay perfect"
else:
open(file, "a").close()
raise RetryRequested(seconds_to_wait=DELAY)
@pipeline
def step_retry():
fail_first_and_wait()
return step_retry
@executors
def test_step_retry_fixed_wait(environment):
with instance_for_test() as instance:
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
env = dict(environment)
env["solids"] = {"fail_first_and_wait": {"config": tempdir}}
event_iter = execute_pipeline_iterator(
reconstructable(define_retry_wait_fixed_pipeline),
run_config=env,
instance=instance,
)
start_wait = None
end_wait = None
success = None
for event in event_iter:
if event.is_step_up_for_retry:
start_wait = time.time()
if event.is_step_restarted:
end_wait = time.time()
if event.is_pipeline_success:
success = True
assert success
assert start_wait is not None
assert end_wait is not None
delay = end_wait - start_wait
assert delay > DELAY
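`test_step_retry_fixed_wait` only checks that at least `DELAY` seconds pass between the retry request and the restart. A generic retry loop with a fixed wait, in the same spirit but not dagster's executor logic:

```python
import time

def run_with_retries(fn, max_retries=1, seconds_to_wait=0):
    # Call fn(); on failure, sleep the requested delay and retry up to max_retries times.
    attempts = 0
    while True:
        try:
            return fn()
        except Exception:
            attempts += 1
            if attempts > max_retries:
                raise
            time.sleep(seconds_to_wait)

calls = {"count": 0}

def fail_first_time():
    calls["count"] += 1
    if calls["count"] == 1:
        raise RuntimeError("first attempt fails")
    return "okay perfect"

start = time.monotonic()
assert run_with_retries(fail_first_time, max_retries=1, seconds_to_wait=2) == "okay perfect"
assert time.monotonic() - start >= 2  # the fixed wait between attempts was honored
```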
diff --git a/python_modules/dagster/dagster_tests/core_tests/launcher_tests/test_default_run_launcher.py b/python_modules/dagster/dagster_tests/core_tests/launcher_tests/test_default_run_launcher.py
index eaaada7b7..7a04ccba7 100644
--- a/python_modules/dagster/dagster_tests/core_tests/launcher_tests/test_default_run_launcher.py
+++ b/python_modules/dagster/dagster_tests/core_tests/launcher_tests/test_default_run_launcher.py
@@ -1,546 +1,547 @@
import os
import re
import sys
+import tempfile
import time
from contextlib import contextmanager
import pytest
from dagster import DefaultRunLauncher, file_relative_path, pipeline, repository, seven, solid
from dagster.core.errors import DagsterLaunchFailedError
from dagster.core.host_representation import (
GrpcServerRepositoryLocationOrigin,
ManagedGrpcPythonEnvRepositoryLocationOrigin,
)
from dagster.core.host_representation.handle import RepositoryLocationHandle
from dagster.core.host_representation.repository_location import GrpcServerRepositoryLocation
from dagster.core.storage.pipeline_run import PipelineRunStatus
from dagster.core.test_utils import (
environ,
instance_for_test,
instance_for_test_tempdir,
poll_for_event,
poll_for_finished_run,
poll_for_step_start,
)
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.grpc.server import GrpcServerProcess
@solid
def noop_solid(_):
pass
@pipeline
def noop_pipeline():
pass
@solid
def crashy_solid(_):
os._exit(1) # pylint: disable=W0212
@pipeline
def crashy_pipeline():
crashy_solid()
@solid
def sleepy_solid(_):
while True:
time.sleep(0.1)
@pipeline
def sleepy_pipeline():
sleepy_solid()
@solid
def slow_solid(_):
time.sleep(4)
@pipeline
def slow_pipeline():
slow_solid()
@solid
def return_one(_):
return 1
@solid
def multiply_by_2(_, num):
return num * 2
@solid
def multiply_by_3(_, num):
return num * 3
@solid
def add(_, num1, num2):
return num1 + num2
@pipeline
def math_diamond():
one = return_one()
add(multiply_by_2(one), multiply_by_3(one))
@repository
def nope():
return [noop_pipeline, crashy_pipeline, sleepy_pipeline, slow_pipeline, math_diamond]
@contextmanager
def get_external_pipeline_from_grpc_server_repository(pipeline_name):
loadable_target_origin = LoadableTargetOrigin(
executable_path=sys.executable,
attribute="nope",
python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
)
server_process = GrpcServerProcess(loadable_target_origin=loadable_target_origin)
try:
with server_process.create_ephemeral_client() as api_client:
repository_location = GrpcServerRepositoryLocation(
RepositoryLocationHandle.create_from_repository_location_origin(
GrpcServerRepositoryLocationOrigin(
location_name="test",
port=api_client.port,
socket=api_client.socket,
host=api_client.host,
)
)
)
yield repository_location.get_repository("nope").get_full_external_pipeline(
pipeline_name
)
finally:
server_process.wait()
@contextmanager
def get_external_pipeline_from_managed_grpc_python_env_repository(pipeline_name):
with RepositoryLocationHandle.create_from_repository_location_origin(
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
attribute="nope",
python_file=file_relative_path(__file__, "test_default_run_launcher.py"),
),
location_name="nope",
)
) as repository_location_handle:
repository_location = GrpcServerRepositoryLocation(repository_location_handle)
yield repository_location.get_repository("nope").get_full_external_pipeline(pipeline_name)
def run_configs():
return [
None,
{"execution": {"multiprocess": {}}, "intermediate_storage": {"filesystem": {}}},
]
def _is_multiprocess(run_config):
return run_config and "execution" in run_config and "multiprocess" in run_config["execution"]
def _check_event_log_contains(event_log, expected_type_and_message):
types_and_messages = [(e.dagster_event.event_type_value, e.message) for e in event_log]
for expected_event_type, expected_message_fragment in expected_type_and_message:
assert any(
event_type == expected_event_type and expected_message_fragment in message
for event_type, message in types_and_messages
), "Missing {expected_event_type}:{expected_message_fragment}".format(
expected_event_type=expected_event_type,
expected_message_fragment=expected_message_fragment,
)
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
def test_successful_run(get_external_pipeline, run_config): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=noop_pipeline, run_config=run_config
)
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
instance.launch_run(run_id=pipeline_run.run_id, external_pipeline=external_pipeline)
pipeline_run = instance.get_run_by_id(run_id)
assert pipeline_run
assert pipeline_run.run_id == run_id
pipeline_run = poll_for_finished_run(instance, run_id)
assert pipeline_run.status == PipelineRunStatus.SUCCESS
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
def test_invalid_instance_run(get_external_pipeline):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
correct_run_storage_dir = os.path.join(temp_dir, "history", "")
wrong_run_storage_dir = os.path.join(temp_dir, "wrong", "")
with environ({"RUN_STORAGE_ENV": correct_run_storage_dir}):
with instance_for_test_tempdir(
temp_dir,
overrides={
"run_storage": {
"module": "dagster.core.storage.runs",
"class": "SqliteRunStorage",
"config": {"base_dir": {"env": "RUN_STORAGE_ENV"}},
}
},
) as instance:
pipeline_run = instance.create_run_for_pipeline(pipeline_def=noop_pipeline,)
# Server won't be able to load the run from run storage
with environ({"RUN_STORAGE_ENV": wrong_run_storage_dir}):
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
with pytest.raises(
DagsterLaunchFailedError,
match=re.escape(
"gRPC server could not load run {run_id} in order to execute it".format(
run_id=pipeline_run.run_id
)
),
):
instance.launch_run(
run_id=pipeline_run.run_id, external_pipeline=external_pipeline,
)
failed_run = instance.get_run_by_id(pipeline_run.run_id)
assert failed_run.status == PipelineRunStatus.FAILURE
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
@pytest.mark.skipif(
seven.IS_WINDOWS,
reason="Crashy pipelines leave resources open on windows, causing filesystem contention",
)
def test_crashy_run(get_external_pipeline, run_config): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=crashy_pipeline, run_config=run_config,
)
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
instance.launch_run(pipeline_run.run_id, external_pipeline)
failed_pipeline_run = instance.get_run_by_id(run_id)
assert failed_pipeline_run
assert failed_pipeline_run.run_id == run_id
failed_pipeline_run = poll_for_finished_run(instance, run_id, timeout=5)
assert failed_pipeline_run.status == PipelineRunStatus.FAILURE
event_records = instance.all_logs(run_id)
if _is_multiprocess(run_config):
message = (
"Multiprocess executor: child process for "
"step crashy_solid unexpectedly exited"
)
else:
message = "Pipeline execution process for {run_id} unexpectedly exited.".format(
run_id=run_id
)
assert _message_exists(event_records, message)
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
def test_terminated_run(get_external_pipeline, run_config): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=sleepy_pipeline, run_config=run_config,
)
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
instance.launch_run(pipeline_run.run_id, external_pipeline)
poll_for_step_start(instance, run_id)
launcher = instance.run_launcher
assert launcher.can_terminate(run_id)
assert launcher.terminate(run_id)
terminated_pipeline_run = poll_for_finished_run(instance, run_id, timeout=30)
terminated_pipeline_run = instance.get_run_by_id(run_id)
assert terminated_pipeline_run.status == PipelineRunStatus.CANCELED
poll_for_event(
instance, run_id, event_type="ENGINE_EVENT", message="Process for pipeline exited",
)
run_logs = instance.all_logs(run_id)
if _is_multiprocess(run_config):
_check_event_log_contains(
run_logs,
[
("PIPELINE_CANCELING", "Sending pipeline termination request."),
(
"ENGINE_EVENT",
"Multiprocess executor: received termination signal - forwarding to active child process",
),
(
"ENGINE_EVENT",
"Multiprocess executor: interrupted all active child processes",
),
("STEP_FAILURE", 'Execution of step "sleepy_solid" failed.'),
("PIPELINE_CANCELED", 'Execution of pipeline "sleepy_pipeline" canceled.',),
("ENGINE_EVENT", "Process for pipeline exited"),
],
)
else:
_check_event_log_contains(
run_logs,
[
("PIPELINE_CANCELING", "Sending pipeline termination request."),
("STEP_FAILURE", 'Execution of step "sleepy_solid" failed.'),
("PIPELINE_CANCELED", 'Execution of pipeline "sleepy_pipeline" canceled.',),
("ENGINE_EVENT", "Pipeline execution terminated by interrupt"),
("ENGINE_EVENT", "Process for pipeline exited"),
],
)
def _get_engine_events(event_records):
return [er for er in event_records if er.dagster_event and er.dagster_event.is_engine_event]
def _get_successful_step_keys(event_records):
step_keys = set()
for er in event_records:
if er.dagster_event and er.dagster_event.is_step_success:
step_keys.add(er.dagster_event.step_key)
return step_keys
def _message_exists(event_records, message_text):
for event_record in event_records:
if message_text in event_record.message:
return True
return False
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
def test_single_solid_selection_execution(
get_external_pipeline, run_config,
): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=math_diamond, run_config=run_config, solids_to_execute={"return_one"}
)
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
instance.launch_run(pipeline_run.run_id, external_pipeline)
finished_pipeline_run = poll_for_finished_run(instance, run_id)
event_records = instance.all_logs(run_id)
assert finished_pipeline_run
assert finished_pipeline_run.run_id == run_id
assert finished_pipeline_run.status == PipelineRunStatus.SUCCESS
assert _get_successful_step_keys(event_records) == {"return_one"}
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
def test_multi_solid_selection_execution(
get_external_pipeline, run_config,
): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=math_diamond,
run_config=run_config,
solids_to_execute={"return_one", "multiply_by_2"},
)
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
instance.launch_run(pipeline_run.run_id, external_pipeline)
finished_pipeline_run = poll_for_finished_run(instance, run_id)
event_records = instance.all_logs(run_id)
assert finished_pipeline_run
assert finished_pipeline_run.run_id == run_id
assert finished_pipeline_run.status == PipelineRunStatus.SUCCESS
assert _get_successful_step_keys(event_records) == {
"return_one",
"multiply_by_2",
}
@pytest.mark.parametrize(
"get_external_pipeline",
[
get_external_pipeline_from_grpc_server_repository,
get_external_pipeline_from_managed_grpc_python_env_repository,
],
)
@pytest.mark.parametrize(
"run_config", run_configs(),
)
def test_engine_events(get_external_pipeline, run_config): # pylint: disable=redefined-outer-name
with instance_for_test() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=math_diamond, run_config=run_config
)
run_id = pipeline_run.run_id
assert instance.get_run_by_id(run_id).status == PipelineRunStatus.NOT_STARTED
with get_external_pipeline(pipeline_run.pipeline_name) as external_pipeline:
instance.launch_run(pipeline_run.run_id, external_pipeline)
finished_pipeline_run = poll_for_finished_run(instance, run_id)
assert finished_pipeline_run
assert finished_pipeline_run.run_id == run_id
assert finished_pipeline_run.status == PipelineRunStatus.SUCCESS
poll_for_event(
instance, run_id, event_type="ENGINE_EVENT", message="Process for pipeline exited"
)
event_records = instance.all_logs(run_id)
engine_events = _get_engine_events(event_records)
if _is_multiprocess(run_config):
messages = [
"Started process for pipeline",
"Starting initialization of resources",
"Finished initialization of resources",
"Executing steps using multiprocess executor",
"Launching subprocess for return_one",
"Executing step return_one in subprocess",
"Starting initialization of resources",
"Finished initialization of resources",
# multiply_by_2 and multiply_by_3 launch and execute in non-deterministic order
"",
"",
"",
"",
"",
"",
"",
"",
"Launching subprocess for add",
"Executing step add in subprocess",
"Starting initialization of resources",
"Finished initialization of resources",
"Multiprocess executor: parent process exiting",
"Process for pipeline exited",
]
else:
messages = [
"Started process for pipeline",
"Starting initialization of resources",
"Finished initialization of resources",
"Executing steps in process",
"Finished steps in process",
"Process for pipeline exited",
]
events_iter = iter(engine_events)
assert len(engine_events) == len(messages)
for message in messages:
next_log = next(events_iter)
assert message in next_log.message
def test_not_initialized(): # pylint: disable=redefined-outer-name
run_launcher = DefaultRunLauncher()
run_id = "dummy"
assert run_launcher.join() is None
assert run_launcher.can_terminate(run_id) is False
assert run_launcher.terminate(run_id) is False
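Several of these launcher tests block on helpers such as `poll_for_finished_run` and `poll_for_event`. A generic polling helper with the same shape, assuming only a predicate and a timeout (the real helpers query the instance's run and event storage):

```python
import time

def poll_until(predicate, timeout=30, interval=0.1):
    # Re-evaluate `predicate` until it returns a truthy value or the timeout elapses.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        result = predicate()
        if result:
            return result
        time.sleep(interval)
    raise TimeoutError("condition was not met within {} seconds".format(timeout))

# Trivial usage: wait for a counter to cross a threshold.
state = {"ticks": 0}

def tick():
    state["ticks"] += 1
    return state["ticks"] if state["ticks"] >= 3 else None

assert poll_until(tick, timeout=5) == 3
```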
diff --git a/python_modules/dagster/dagster_tests/core_tests/serdes_tests/test_ipc.py b/python_modules/dagster/dagster_tests/core_tests/serdes_tests/test_ipc.py
index fb74888a7..36b2e9fa2 100644
--- a/python_modules/dagster/dagster_tests/core_tests/serdes_tests/test_ipc.py
+++ b/python_modules/dagster/dagster_tests/core_tests/serdes_tests/test_ipc.py
@@ -1,293 +1,293 @@
import os
import sys
import time
+from contextlib import ExitStack
import pytest
from dagster.serdes.ipc import (
interrupt_ipc_subprocess,
interrupt_ipc_subprocess_pid,
open_ipc_subprocess,
)
-from dagster.seven import ExitStack
from dagster.utils import file_relative_path, process_is_alive, safe_tempfile_path
def wait_for_file(path, timeout=5):
interval = 0.1
total_time = 0
while not os.path.exists(path) and total_time < timeout:
time.sleep(interval)
total_time += interval
if total_time >= timeout:
raise Exception("wait_for_file: timeout")
time.sleep(interval)
def wait_for_process(pid, timeout=5):
interval = 0.1
total_time = 0
while process_is_alive(pid) and total_time < timeout:
time.sleep(interval)
total_time += interval
if total_time >= timeout:
raise Exception("wait_for_process: timeout")
# The following line can be removed to reliably provoke failures on Windows -- hypothesis
# is that there's a race in psutil.Process which tells us a process is gone before it stops
# contending for files
time.sleep(interval)
@pytest.fixture(scope="function")
def windows_legacy_stdio_env():
old_env_value = os.getenv("PYTHONLEGACYWINDOWSSTDIO")
try:
os.environ["PYTHONLEGACYWINDOWSSTDIO"] = "1"
yield
finally:
if old_env_value is not None:
os.environ["PYTHONLEGACYWINDOWSSTDIO"] = old_env_value
else:
del os.environ["PYTHONLEGACYWINDOWSSTDIO"]
def test_interrupt_ipc_subprocess():
with safe_tempfile_path() as started_sentinel:
with safe_tempfile_path() as interrupt_sentinel:
sleepy_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "subprocess_with_interrupt_support.py"),
started_sentinel,
interrupt_sentinel,
]
)
wait_for_file(started_sentinel)
interrupt_ipc_subprocess(sleepy_process)
wait_for_file(interrupt_sentinel)
with open(interrupt_sentinel, "r") as fd:
assert fd.read().startswith("received_keyboard_interrupt")
def test_interrupt_ipc_subprocess_by_pid():
with safe_tempfile_path() as started_sentinel:
with safe_tempfile_path() as interrupt_sentinel:
sleepy_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "subprocess_with_interrupt_support.py"),
started_sentinel,
interrupt_sentinel,
]
)
wait_for_file(started_sentinel)
interrupt_ipc_subprocess_pid(sleepy_process.pid)
wait_for_file(interrupt_sentinel)
with open(interrupt_sentinel, "r") as fd:
assert fd.read().startswith("received_keyboard_interrupt")
def test_interrupt_ipc_subprocess_grandchild():
with ExitStack() as context_stack:
(
child_opened_sentinel,
parent_interrupt_sentinel,
child_started_sentinel,
child_interrupt_sentinel,
) = [context_stack.enter_context(safe_tempfile_path()) for _ in range(4)]
child_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "parent_subprocess_with_interrupt_support.py"),
child_opened_sentinel,
parent_interrupt_sentinel,
child_started_sentinel,
child_interrupt_sentinel,
]
)
wait_for_file(child_opened_sentinel)
wait_for_file(child_started_sentinel)
interrupt_ipc_subprocess(child_process)
wait_for_file(child_interrupt_sentinel)
with open(child_interrupt_sentinel, "r") as fd:
assert fd.read().startswith("received_keyboard_interrupt")
wait_for_file(parent_interrupt_sentinel)
with open(parent_interrupt_sentinel, "r") as fd:
assert fd.read().startswith("parent_received_keyboard_interrupt")
def test_interrupt_compute_log_tail_child(
windows_legacy_stdio_env, # pylint: disable=redefined-outer-name, unused-argument
):
with ExitStack() as context_stack:
(stdout_pids_file, stderr_pids_file, opened_sentinel, interrupt_sentinel) = [
context_stack.enter_context(safe_tempfile_path()) for _ in range(4)
]
child_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "compute_log_subprocess.py"),
stdout_pids_file,
stderr_pids_file,
opened_sentinel,
interrupt_sentinel,
]
)
wait_for_file(opened_sentinel)
wait_for_file(stdout_pids_file)
wait_for_file(stderr_pids_file)
with open(opened_sentinel, "r") as opened_sentinel_fd:
assert opened_sentinel_fd.read().startswith("opened_compute_log_subprocess")
with open(stdout_pids_file, "r") as stdout_pids_fd:
stdout_pids_str = stdout_pids_fd.read()
assert stdout_pids_str.startswith("stdout pids:")
stdout_pids = list(
map(
lambda x: int(x) if x != "None" else None,
[x.strip("(),") for x in stdout_pids_str.split(" ")[2:]],
)
)
with open(stderr_pids_file, "r") as stderr_pids_fd:
stderr_pids_str = stderr_pids_fd.read()
assert stderr_pids_str.startswith("stderr pids:")
stderr_pids = list(
map(
lambda x: int(x) if x != "None" else None,
[x.strip("(),") for x in stderr_pids_str.split(" ")[2:]],
)
)
interrupt_ipc_subprocess(child_process)
for stdout_pid in stdout_pids:
if stdout_pid is not None:
wait_for_process(stdout_pid)
for stderr_pid in stderr_pids:
if stderr_pid is not None:
wait_for_process(stderr_pid)
wait_for_file(interrupt_sentinel)
with open(interrupt_sentinel, "r") as fd:
assert fd.read().startswith("compute_log_subprocess_interrupt")
def test_segfault_compute_log_tail(
windows_legacy_stdio_env, # pylint: disable=redefined-outer-name, unused-argument
):
with safe_tempfile_path() as stdout_pids_file:
with safe_tempfile_path() as stderr_pids_file:
child_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "compute_log_subprocess_segfault.py"),
stdout_pids_file,
stderr_pids_file,
]
)
child_process.wait()
wait_for_file(stdout_pids_file)
with open(stdout_pids_file, "r") as stdout_pids_fd:
stdout_pids_str = stdout_pids_fd.read()
assert stdout_pids_str.startswith("stdout pids:")
stdout_pids = list(
map(
lambda x: int(x) if x != "None" else None,
stdout_pids_str.split(" ")[-1].strip("()").split(","),
)
)
wait_for_file(stderr_pids_file)
with open(stderr_pids_file, "r") as stderr_pids_fd:
stderr_pids_str = stderr_pids_fd.read()
assert stderr_pids_str.startswith("stderr pids:")
stderr_pids = list(
map(
lambda x: int(x) if x != "None" else None,
stderr_pids_str.split(" ")[-1].strip("()").split(","),
)
)
for stdout_pid in stdout_pids:
if stdout_pid is not None:
wait_for_process(stdout_pid)
for stderr_pid in stderr_pids:
if stderr_pid is not None:
wait_for_process(stderr_pid)
def test_interrupt_compute_log_tail_grandchild(
windows_legacy_stdio_env, # pylint: disable=redefined-outer-name, unused-argument
):
with ExitStack() as context_stack:
(
child_opened_sentinel,
parent_interrupt_sentinel,
child_started_sentinel,
stdout_pids_file,
stderr_pids_file,
child_interrupt_sentinel,
) = [context_stack.enter_context(safe_tempfile_path()) for _ in range(6)]
parent_process = open_ipc_subprocess(
[
sys.executable,
file_relative_path(__file__, "parent_compute_log_subprocess.py"),
child_opened_sentinel,
parent_interrupt_sentinel,
child_started_sentinel,
stdout_pids_file,
stderr_pids_file,
child_interrupt_sentinel,
]
)
wait_for_file(child_opened_sentinel)
wait_for_file(child_started_sentinel)
wait_for_file(stdout_pids_file)
with open(stdout_pids_file, "r") as stdout_pids_fd:
stdout_pids_str = stdout_pids_fd.read()
assert stdout_pids_str.startswith("stdout pids:")
stdout_pids = list(
map(
lambda x: int(x) if x != "None" else None,
[x.strip("(),") for x in stdout_pids_str.split(" ")[2:]],
)
)
wait_for_file(stderr_pids_file)
with open(stderr_pids_file, "r") as stderr_pids_fd:
stderr_pids_str = stderr_pids_fd.read()
assert stderr_pids_str.startswith("stderr pids:")
stderr_pids = list(
map(
lambda x: int(x) if x != "None" else None,
[x.strip("(),") for x in stderr_pids_str.split(" ")[2:]],
)
)
interrupt_ipc_subprocess(parent_process)
wait_for_file(child_interrupt_sentinel)
with open(child_interrupt_sentinel, "r") as fd:
assert fd.read().startswith("compute_log_subprocess_interrupt")
wait_for_file(parent_interrupt_sentinel)
with open(parent_interrupt_sentinel, "r") as fd:
assert fd.read().startswith("parent_received_keyboard_interrupt")
for stdout_pid in stdout_pids:
if stdout_pid is not None:
wait_for_process(stdout_pid)
for stderr_pid in stderr_pids:
if stderr_pid is not None:
wait_for_process(stderr_pid)
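As with `TemporaryDirectory`, `ExitStack` now comes straight from `contextlib` rather than `dagster.seven`. A small stdlib example of the pattern used above to enter a variable number of temporary-path contexts at once (`NamedTemporaryFile` stands in for dagster's `safe_tempfile_path`):

```python
import tempfile
from contextlib import ExitStack

with ExitStack() as stack:
    # Enter four independent temporary-file contexts; all are closed when the stack unwinds.
    files = [stack.enter_context(tempfile.NamedTemporaryFile()) for _ in range(4)]
    paths = [f.name for f in files]
    assert len(paths) == 4
```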
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_asset_store.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_asset_store.py
index ea9947b60..857db931a 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_asset_store.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_asset_store.py
@@ -1,352 +1,352 @@
import os
import pickle
+import tempfile
import pytest
from dagster import (
DagsterInstance,
DagsterInvariantViolationError,
ModeDefinition,
Output,
OutputDefinition,
execute_pipeline,
pipeline,
reexecute_pipeline,
resource,
- seven,
solid,
)
from dagster.core.definitions.events import AssetMaterialization, AssetStoreOperationType
from dagster.core.execution.api import create_execution_plan, execute_plan
from dagster.core.storage.asset_store import (
AssetStore,
custom_path_fs_asset_store,
fs_asset_store,
mem_asset_store,
)
def define_asset_pipeline(asset_store, asset_metadata_dict):
@solid(output_defs=[OutputDefinition(asset_metadata=asset_metadata_dict.get("solid_a"),)],)
def solid_a(_context):
return [1, 2, 3]
@solid(output_defs=[OutputDefinition(asset_metadata=asset_metadata_dict.get("solid_b"),)],)
def solid_b(_context, _df):
return 1
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"object_manager": asset_store})])
def asset_pipeline():
solid_b(solid_a())
return asset_pipeline
def test_result_output():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_asset_store.configured({"base_dir": tmpdir_path})
pipeline_def = define_asset_pipeline(asset_store, {})
result = execute_pipeline(pipeline_def)
assert result.success
# test output_value
assert result.result_for_solid("solid_a").output_value() == [1, 2, 3]
assert result.result_for_solid("solid_b").output_value() == 1
def test_fs_asset_store():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_asset_store.configured({"base_dir": tmpdir_path})
pipeline_def = define_asset_pipeline(asset_store, {})
result = execute_pipeline(pipeline_def)
assert result.success
asset_store_operation_events = list(
filter(lambda evt: evt.is_asset_store_operation, result.event_list)
)
assert len(asset_store_operation_events) == 3
# SET ASSET for step "solid_a" output "result"
assert (
asset_store_operation_events[0].event_specific_data.op
== AssetStoreOperationType.SET_ASSET
)
filepath_a = os.path.join(tmpdir_path, result.run_id, "solid_a", "result")
assert os.path.isfile(filepath_a)
with open(filepath_a, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3]
# GET ASSET for step "solid_b" input "_df"
assert (
asset_store_operation_events[1].event_specific_data.op
== AssetStoreOperationType.GET_ASSET
)
assert "solid_a" == asset_store_operation_events[1].event_specific_data.step_key
# SET ASSET for step "solid_b" output "result"
assert (
asset_store_operation_events[2].event_specific_data.op
== AssetStoreOperationType.SET_ASSET
)
filepath_b = os.path.join(tmpdir_path, result.run_id, "solid_b", "result")
assert os.path.isfile(filepath_b)
with open(filepath_b, "rb") as read_obj:
assert pickle.load(read_obj) == 1
def test_default_asset_store_reexecution():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
default_asset_store = fs_asset_store.configured({"base_dir": tmpdir_path})
pipeline_def = define_asset_pipeline(default_asset_store, {})
instance = DagsterInstance.ephemeral()
result = execute_pipeline(pipeline_def, instance=instance)
assert result.success
re_result = reexecute_pipeline(
pipeline_def, result.run_id, instance=instance, step_selection=["solid_b"],
)
# re-execution should yield asset_store_operation events instead of intermediate events
get_asset_events = list(
filter(
lambda evt: evt.is_asset_store_operation
and AssetStoreOperationType(evt.event_specific_data.op)
== AssetStoreOperationType.GET_ASSET,
re_result.event_list,
)
)
assert len(get_asset_events) == 1
assert get_asset_events[0].event_specific_data.step_key == "solid_a"
def execute_pipeline_with_steps(pipeline_def, step_keys_to_execute=None):
plan = create_execution_plan(pipeline_def, step_keys_to_execute=step_keys_to_execute)
with DagsterInstance.ephemeral() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=pipeline_def, step_keys_to_execute=step_keys_to_execute,
)
return execute_plan(plan, instance, pipeline_run)
def test_step_subset_with_custom_paths():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = custom_path_fs_asset_store
# pass hardcoded file path via asset_metadata
test_asset_metadata_dict = {
"solid_a": {"path": os.path.join(tmpdir_path, "a")},
"solid_b": {"path": os.path.join(tmpdir_path, "b")},
}
pipeline_def = define_asset_pipeline(asset_store, test_asset_metadata_dict)
events = execute_pipeline_with_steps(pipeline_def)
for evt in events:
assert not evt.is_failure
# when a path is provided via the asset store, a step subset can be executed from an
# execution plan even when the upstream outputs were not previously created by
# dagster-controlled computations
step_subset_events = execute_pipeline_with_steps(
pipeline_def, step_keys_to_execute=["solid_b"]
)
for evt in step_subset_events:
assert not evt.is_failure
# only the selected step subset was executed
assert set([evt.step_key for evt in step_subset_events]) == {"solid_b"}
# Asset Materialization events
step_materialization_events = list(
filter(lambda evt: evt.is_step_materialization, step_subset_events)
)
assert len(step_materialization_events) == 1
assert test_asset_metadata_dict["solid_b"]["path"] == (
step_materialization_events[0]
.event_specific_data.materialization.metadata_entries[0]
.entry_data.path
)
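+# set_asset may yield multiple AssetMaterializations; the test below checks that each yielded
+# materialization surfaces as its own step materialization event.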
def test_asset_store_multi_materialization():
class DummyAssetStore(AssetStore):
def __init__(self):
self.values = {}
def set_asset(self, context, obj):
keys = tuple(context.get_run_scoped_output_identifier())
self.values[keys] = obj
yield AssetMaterialization(asset_key="yield_one")
yield AssetMaterialization(asset_key="yield_two")
def get_asset(self, context):
keys = tuple(context.get_run_scoped_output_identifier())
return self.values[keys]
def has_asset(self, context):
keys = tuple(context.get_run_scoped_output_identifier())
return keys in self.values
@resource
def dummy_asset_store(_):
return DummyAssetStore()
@solid(output_defs=[OutputDefinition(manager_key="store")])
def solid_a(_context):
return 1
@solid()
def solid_b(_context, a):
assert a == 1
@pipeline(mode_defs=[ModeDefinition(resource_defs={"store": dummy_asset_store})])
def asset_pipeline():
solid_b(solid_a())
result = execute_pipeline(asset_pipeline)
assert result.success
# Asset Materialization events
step_materialization_events = list(
filter(lambda evt: evt.is_step_materialization, result.event_list)
)
assert len(step_materialization_events) == 2
def test_different_asset_stores():
@solid(output_defs=[OutputDefinition(manager_key="store")],)
def solid_a(_context):
return 1
@solid()
def solid_b(_context, a):
assert a == 1
@pipeline(mode_defs=[ModeDefinition(resource_defs={"store": mem_asset_store})])
def asset_pipeline():
solid_b(solid_a())
assert execute_pipeline(asset_pipeline).success
@resource
def my_asset_store(_):
pass
def test_set_asset_store_and_intermediate_storage():
from dagster import intermediate_storage, fs_intermediate_storage
@intermediate_storage()
def my_intermediate_storage(_):
pass
with pytest.raises(DagsterInvariantViolationError):
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={"object_manager": my_asset_store},
intermediate_storage_defs=[my_intermediate_storage, fs_intermediate_storage],
)
]
)
def my_pipeline():
pass
execute_pipeline(my_pipeline)
def test_set_asset_store_configure_intermediate_storage():
with pytest.raises(DagsterInvariantViolationError):
@pipeline(mode_defs=[ModeDefinition(resource_defs={"object_manager": my_asset_store})])
def my_pipeline():
pass
execute_pipeline(my_pipeline, run_config={"intermediate_storage": {"filesystem": {}}})
def test_fan_in():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_asset_store.configured({"base_dir": tmpdir_path})
@solid
def input_solid1(_):
return 1
@solid
def input_solid2(_):
return 2
@solid
def solid1(_, input1):
assert input1 == [1, 2]
@pipeline(mode_defs=[ModeDefinition(resource_defs={"object_manager": asset_store})])
def my_pipeline():
solid1(input1=[input_solid1(), input_solid2()])
execute_pipeline(my_pipeline)
def get_fake_solid():
@solid
def fake_solid(_):
pass
return fake_solid
def test_asset_store_optional_output():
- with seven.TemporaryDirectory() as tmpdir_dir:
+ with tempfile.TemporaryDirectory() as tmpdir_dir:
asset_store = fs_asset_store.configured({"base_dir": tmpdir_dir})
skip = True
@solid(output_defs=[OutputDefinition(is_required=False)])
def solid_a(_context):
if not skip:
yield Output([1, 2])
@solid
def solid_skipped(_context, array):
return array
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"asset_store": asset_store})])
def asset_pipeline_optional_output():
solid_skipped(solid_a())
result = execute_pipeline(asset_pipeline_optional_output)
assert result.success
assert result.result_for_solid("solid_skipped").skipped
def test_asset_store_optional_output_path_exists():
- with seven.TemporaryDirectory() as tmpdir_dir:
+ with tempfile.TemporaryDirectory() as tmpdir_dir:
asset_store = custom_path_fs_asset_store.configured({"base_dir": tmpdir_dir})
filepath = os.path.join(tmpdir_dir, "foo")
# file exists already
with open(filepath, "wb") as write_obj:
pickle.dump([1], write_obj)
assert os.path.exists(filepath)
skip = True
@solid(output_defs=[OutputDefinition(is_required=False, asset_metadata={"path": filepath})])
def solid_a(_context):
if not skip:
yield Output([1, 2])
@solid(output_defs=[OutputDefinition(asset_metadata={"path": "bar"})])
def solid_b(_context, array):
return array
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"asset_store": asset_store})])
def asset_pipeline_optional_output_path_exists():
solid_b(solid_a())
result = execute_pipeline(asset_pipeline_optional_output_path_exists)
assert result.success
# solid_b is skipped because its upstream optional output was not yielded, even though a file
# already exists at the custom path
assert result.result_for_solid("solid_b").skipped
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_assets.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_assets.py
index 4d1ebe1ab..866c8df65 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_assets.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_assets.py
@@ -1,277 +1,277 @@
+import tempfile
import time
from contextlib import contextmanager
import pytest
from dagster import (
AssetKey,
AssetMaterialization,
DagsterEventType,
Output,
execute_pipeline,
file_relative_path,
pipeline,
- seven,
solid,
)
from dagster.core.definitions.events import parse_asset_key_string, validate_asset_key_string
from dagster.core.errors import DagsterInvalidAssetKey
from dagster.core.events import DagsterEvent, StepMaterializationData
from dagster.core.events.log import DagsterEventRecord, EventRecord
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.event_log import (
ConsolidatedSqliteEventLogStorage,
InMemoryEventLogStorage,
)
from dagster.core.storage.event_log.migration import migrate_asset_key_data
from dagster.core.storage.noop_compute_log_manager import NoOpComputeLogManager
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import InMemoryRunStorage
from dagster.utils.test import copy_directory
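+# Helper: constructs an ephemeral DagsterInstance backed by the given event log storage, with
+# in-memory run storage and no-op compute log management.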
def get_instance(temp_dir, event_log_storage):
return DagsterInstance(
instance_type=InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=InMemoryRunStorage(),
event_storage=event_log_storage,
compute_log_manager=NoOpComputeLogManager(),
run_coordinator=DefaultRunCoordinator(),
run_launcher=SyncInMemoryRunLauncher(),
)
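+# Context managers yielding (instance, event_log_storage) pairs; the asset-aware tests below
+# are parametrized over both storages.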
@contextmanager
def create_in_memory_event_log_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
asset_storage = InMemoryEventLogStorage()
instance = get_instance(temp_dir, asset_storage)
yield [instance, asset_storage]
@contextmanager
def create_consolidated_sqlite_event_log_instance():
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
asset_storage = ConsolidatedSqliteEventLogStorage(temp_dir)
instance = get_instance(temp_dir, asset_storage)
yield [instance, asset_storage]
asset_test = pytest.mark.parametrize(
"asset_aware_context",
[create_in_memory_event_log_instance, create_consolidated_sqlite_event_log_instance,],
)
@solid
def solid_one(_):
yield AssetMaterialization(asset_key=AssetKey("asset_1"))
yield Output(1)
@solid
def solid_two(_):
yield AssetMaterialization(asset_key=AssetKey("asset_2"))
yield AssetMaterialization(asset_key=AssetKey(["path", "to", "asset_3"]))
yield Output(1)
@solid
def solid_normalization(_):
yield AssetMaterialization(asset_key="path/to-asset_4")
yield Output(1)
@pipeline
def pipeline_one():
solid_one()
@pipeline
def pipeline_two():
solid_one()
solid_two()
@pipeline
def pipeline_normalization():
solid_normalization()
def test_validate_asset_key_string():
assert validate_asset_key_string("H3_lL0.h-1") == "H3_lL0.h-1"
with pytest.raises(DagsterInvalidAssetKey):
validate_asset_key_string("(Hello)")
def test_structured_asset_key():
asset_parsed = AssetKey(parse_asset_key_string("(Hello)"))
assert len(asset_parsed.path) == 1
assert asset_parsed.path[0] == "Hello"
asset_structured = AssetKey(["(Hello)"])
assert len(asset_structured.path) == 1
assert asset_structured.path[0] == "(Hello)"
def test_parse_asset_key_string():
assert parse_asset_key_string("foo.bar_b-az") == ["foo", "bar_b", "az"]
@asset_test
def test_asset_keys(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
execute_pipeline(pipeline_one, instance=instance)
execute_pipeline(pipeline_two, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 3
assert set([asset_key.to_string() for asset_key in asset_keys]) == set(
['["asset_1"]', '["asset_2"]', '["path", "to", "asset_3"]']
)
prefixed_keys = event_log_storage.get_all_asset_keys(prefix_path=["asset"])
assert len(prefixed_keys) == 2
@asset_test
def test_has_asset_key(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
execute_pipeline(pipeline_one, instance=instance)
execute_pipeline(pipeline_two, instance=instance)
assert event_log_storage.has_asset_key(AssetKey(["path", "to", "asset_3"]))
assert not event_log_storage.has_asset_key(AssetKey(["path", "to", "bogus", "asset"]))
@asset_test
def test_asset_events(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
execute_pipeline(pipeline_one, instance=instance)
execute_pipeline(pipeline_two, instance=instance)
asset_events = event_log_storage.get_asset_events(AssetKey("asset_1"))
assert len(asset_events) == 2
for event in asset_events:
assert isinstance(event, EventRecord)
assert event.is_dagster_event
assert event.dagster_event.event_type == DagsterEventType.STEP_MATERIALIZATION
assert event.dagster_event.asset_key
asset_events = event_log_storage.get_asset_events(AssetKey(["path", "to", "asset_3"]))
assert len(asset_events) == 1
@asset_test
def test_asset_run_ids(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
one = execute_pipeline(pipeline_one, instance=instance)
two = execute_pipeline(pipeline_two, instance=instance)
run_ids = event_log_storage.get_asset_run_ids(AssetKey("asset_1"))
assert set(run_ids) == set([one.run_id, two.run_id])
@asset_test
def test_asset_normalization(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
execute_pipeline(pipeline_normalization, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 1
asset_key = asset_keys[0]
assert asset_key.to_string() == '["path", "to", "asset_4"]'
assert asset_key.path == ["path", "to", "asset_4"]
@asset_test
def test_asset_wipe(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
one = execute_pipeline(pipeline_one, instance=instance)
execute_pipeline(pipeline_two, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 3
log_count = len(event_log_storage.get_logs_for_run(one.run_id))
instance.wipe_assets(asset_keys)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 0
assert log_count == len(event_log_storage.get_logs_for_run(one.run_id))
execute_pipeline(pipeline_one, instance=instance)
execute_pipeline(pipeline_two, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 3
instance.wipe_assets([AssetKey(["path", "to", "asset_3"])])
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 2
@asset_test
def test_asset_secondary_index(asset_aware_context):
with asset_aware_context() as ctx:
instance, event_log_storage = ctx
execute_pipeline(pipeline_one, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 1
migrate_asset_key_data(event_log_storage)
two = execute_pipeline(pipeline_two, instance=instance)
two_two = execute_pipeline(pipeline_two, instance=instance)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 3
event_log_storage.delete_events(two.run_id)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 3
event_log_storage.delete_events(two_two.run_id)
asset_keys = event_log_storage.get_all_asset_keys()
assert len(asset_keys) == 1
def test_asset_key_structure():
src_dir = file_relative_path(__file__, "compat_tests/snapshot_0_9_16_asset_key_structure")
with copy_directory(src_dir) as test_dir:
asset_storage = ConsolidatedSqliteEventLogStorage(test_dir)
asset_keys = asset_storage.get_all_asset_keys()
assert len(asset_keys) == 5
# get a structured asset key
asset_key = AssetKey(["dashboards", "cost_dashboard"])
# check that backcompat events are read
assert asset_storage.has_asset_key(asset_key)
events = asset_storage.get_asset_events(asset_key)
assert len(events) == 1
run_ids = asset_storage.get_asset_run_ids(asset_key)
assert len(run_ids) == 1
# check that backcompat events are merged with newly stored events
run_id = "fake_run_id"
asset_storage.store_event(_materialization_event_record(run_id, asset_key))
assert asset_storage.has_asset_key(asset_key)
events = asset_storage.get_asset_events(asset_key)
assert len(events) == 2
run_ids = asset_storage.get_asset_run_ids(asset_key)
assert len(run_ids) == 2
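+# Helper: builds a synthetic STEP_MATERIALIZATION event record for the given run and asset key.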
def _materialization_event_record(run_id, asset_key):
return DagsterEventRecord(
None,
"",
"debug",
"",
run_id,
time.time() - 25,
step_key="my_step_key",
pipeline_name="my_pipeline",
dagster_event=DagsterEvent(
DagsterEventType.STEP_MATERIALIZATION.value,
"my_pipeline",
step_key="my_step_key",
event_specific_data=StepMaterializationData(AssetMaterialization(asset_key=asset_key)),
),
)
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_event_log.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_event_log.py
index a44b508a0..6407bb23d 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_event_log.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_event_log.py
@@ -1,454 +1,454 @@
import os
import sys
+import tempfile
import time
import traceback
from contextlib import contextmanager
import pytest
import sqlalchemy
-from dagster import seven
from dagster.core.definitions import AssetMaterialization, ExpectationResult
from dagster.core.errors import DagsterEventLogInvalidForRun
from dagster.core.events import (
DagsterEvent,
DagsterEventType,
EngineEventData,
StepExpectationResultData,
StepMaterializationData,
)
from dagster.core.events.log import DagsterEventRecord
from dagster.core.execution.plan.objects import StepFailureData, StepSuccessData
from dagster.core.storage.event_log import (
ConsolidatedSqliteEventLogStorage,
InMemoryEventLogStorage,
SqlEventLogStorageMetadata,
SqlEventLogStorageTable,
SqliteEventLogStorage,
)
from dagster.core.storage.sql import create_engine
from dagster.seven import multiprocessing
@contextmanager
def create_in_memory_event_log_storage():
yield InMemoryEventLogStorage()
@contextmanager
def create_sqlite_run_event_logstorage():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
yield SqliteEventLogStorage(tmpdir_path)
@contextmanager
def create_consolidated_sqlite_run_event_log_storage():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
yield ConsolidatedSqliteEventLogStorage(tmpdir_path)
event_storage_test = pytest.mark.parametrize(
"event_storage_factory_cm_fn",
[
create_in_memory_event_log_storage,
create_sqlite_run_event_logstorage,
create_consolidated_sqlite_run_event_log_storage,
],
)
@event_storage_test
def test_init_log_storage(event_storage_factory_cm_fn):
with event_storage_factory_cm_fn() as storage:
if isinstance(storage, InMemoryEventLogStorage):
assert not storage.is_persistent
elif isinstance(storage, (SqliteEventLogStorage, ConsolidatedSqliteEventLogStorage)):
assert storage.is_persistent
else:
raise Exception("Invalid event storage type")
@event_storage_test
def test_log_storage_run_not_found(event_storage_factory_cm_fn):
with event_storage_factory_cm_fn() as storage:
assert storage.get_logs_for_run("bar") == []
@event_storage_test
def test_event_log_storage_store_events_and_wipe(event_storage_factory_cm_fn):
with event_storage_factory_cm_fn() as storage:
assert len(storage.get_logs_for_run("foo")) == 0
storage.store_event(
DagsterEventRecord(
None,
"Message2",
"debug",
"",
"foo",
time.time(),
dagster_event=DagsterEvent(
DagsterEventType.ENGINE_EVENT.value,
"nonce",
event_specific_data=EngineEventData.in_process(999),
),
)
)
assert len(storage.get_logs_for_run("foo")) == 1
assert storage.get_stats_for_run("foo")
storage.wipe()
assert len(storage.get_logs_for_run("foo")) == 0
@event_storage_test
def test_event_log_storage_store_with_multiple_runs(event_storage_factory_cm_fn):
with event_storage_factory_cm_fn() as storage:
runs = ["foo", "bar", "baz"]
for run_id in runs:
assert len(storage.get_logs_for_run(run_id)) == 0
storage.store_event(
DagsterEventRecord(
None,
"Message2",
"debug",
"",
run_id,
time.time(),
dagster_event=DagsterEvent(
DagsterEventType.STEP_SUCCESS.value,
"nonce",
event_specific_data=StepSuccessData(duration_ms=100.0),
),
)
)
for run_id in runs:
assert len(storage.get_logs_for_run(run_id)) == 1
assert storage.get_stats_for_run(run_id).steps_succeeded == 1
storage.wipe()
for run_id in runs:
assert len(storage.get_logs_for_run(run_id)) == 0
@event_storage_test
def test_event_log_storage_watch(event_storage_factory_cm_fn):
def evt(name):
return DagsterEventRecord(
None,
name,
"debug",
"",
"foo",
time.time(),
dagster_event=DagsterEvent(
DagsterEventType.ENGINE_EVENT.value,
"nonce",
event_specific_data=EngineEventData.in_process(999),
),
)
with event_storage_factory_cm_fn() as storage:
watched = []
watcher = lambda x: watched.append(x) # pylint: disable=unnecessary-lambda
assert len(storage.get_logs_for_run("foo")) == 0
storage.store_event(evt("Message1"))
assert len(storage.get_logs_for_run("foo")) == 1
assert len(watched) == 0
storage.watch("foo", 0, watcher)
storage.store_event(evt("Message2"))
storage.store_event(evt("Message3"))
storage.store_event(evt("Message4"))
attempts = 10
while len(watched) < 3 and attempts > 0:
time.sleep(0.1)
attempts -= 1
storage.end_watch("foo", watcher)
time.sleep(0.3) # this value scientifically selected from a range of attractive values
storage.store_event(evt("Message5"))
assert len(storage.get_logs_for_run("foo")) == 5
assert len(watched) == 3
storage.delete_events("foo")
assert len(storage.get_logs_for_run("foo")) == 0
assert len(watched) == 3
@event_storage_test
def test_event_log_storage_pagination(event_storage_factory_cm_fn):
def evt(name):
return DagsterEventRecord(
None,
name,
"debug",
"",
"foo",
time.time(),
dagster_event=DagsterEvent(
DagsterEventType.ENGINE_EVENT.value,
"nonce",
event_specific_data=EngineEventData.in_process(999),
),
)
with event_storage_factory_cm_fn() as storage:
storage.store_event(evt("Message_0"))
storage.store_event(evt("Message_1"))
storage.store_event(evt("Message_2"))
assert len(storage.get_logs_for_run("foo")) == 3
assert len(storage.get_logs_for_run("foo", -1)) == 3
assert len(storage.get_logs_for_run("foo", 0)) == 2
assert len(storage.get_logs_for_run("foo", 1)) == 1
assert len(storage.get_logs_for_run("foo", 2)) == 0
@event_storage_test
def test_event_log_delete(event_storage_factory_cm_fn):
with event_storage_factory_cm_fn() as storage:
assert len(storage.get_logs_for_run("foo")) == 0
storage.store_event(
DagsterEventRecord(
None,
"Message2",
"debug",
"",
"foo",
time.time(),
dagster_event=DagsterEvent(
DagsterEventType.ENGINE_EVENT.value,
"nonce",
event_specific_data=EngineEventData.in_process(999),
),
)
)
assert len(storage.get_logs_for_run("foo")) == 1
assert storage.get_stats_for_run("foo")
storage.delete_events("foo")
assert len(storage.get_logs_for_run("foo")) == 0
@event_storage_test
def test_event_log_get_stats_without_start_and_success(event_storage_factory_cm_fn):
# When an event log doesn't have a PIPELINE_START or PIPELINE_SUCCESS | PIPELINE_FAILURE event,
# we want to ensure storage.get_stats_for_run(...) doesn't throw an error.
with event_storage_factory_cm_fn() as storage:
assert len(storage.get_logs_for_run("foo")) == 0
assert storage.get_stats_for_run("foo")
def test_filesystem_event_log_storage_run_corrupted():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
storage = SqliteEventLogStorage(tmpdir_path)
# the connection string begins with "sqlite:///" (10 characters), which is stripped off to get the db file path
# pylint: disable=protected-access
with open(os.path.abspath(storage.conn_string_for_run_id("foo")[10:]), "w") as fd:
fd.write("some nonsense")
with pytest.raises(sqlalchemy.exc.DatabaseError):
storage.get_logs_for_run("foo")
def test_filesystem_event_log_storage_run_corrupted_bad_data():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
storage = SqliteEventLogStorage(tmpdir_path)
SqlEventLogStorageMetadata.create_all(create_engine(storage.conn_string_for_run_id("foo")))
with storage.connect("foo") as conn:
event_insert = SqlEventLogStorageTable.insert().values( # pylint: disable=no-value-for-parameter
run_id="foo", event="{bar}", dagster_event_type=None, timestamp=None
)
conn.execute(event_insert)
with pytest.raises(DagsterEventLogInvalidForRun):
storage.get_logs_for_run("foo")
SqlEventLogStorageMetadata.create_all(create_engine(storage.conn_string_for_run_id("bar")))
with storage.connect("bar") as conn: # pylint: disable=protected-access
event_insert = SqlEventLogStorageTable.insert().values( # pylint: disable=no-value-for-parameter
run_id="bar", event="3", dagster_event_type=None, timestamp=None
)
conn.execute(event_insert)
with pytest.raises(DagsterEventLogInvalidForRun):
storage.get_logs_for_run("bar")
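+# Worker for the concurrency test below: opens the SQLite event log storage and pushes any
+# exception it hits onto the shared queue.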
def cmd(exceptions, tmpdir_path):
storage = SqliteEventLogStorage(tmpdir_path)
try:
storage.get_logs_for_run_by_log_id("foo")
except Exception as exc: # pylint: disable=broad-except
exceptions.put(exc)
exc_info = sys.exc_info()
traceback.print_tb(exc_info[2])
def test_concurrent_sqlite_event_log_connections():
exceptions = multiprocessing.Queue()
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
ps = []
for _ in range(5):
ps.append(multiprocessing.Process(target=cmd, args=(exceptions, tmpdir_path)))
for p in ps:
p.start()
j = 0
for p in ps:
p.join()
j += 1
assert j == 5
excs = []
while not exceptions.empty():
excs.append(exceptions.get())
assert not excs, excs
@event_storage_test
def test_event_log_step_stats(event_storage_factory_cm_fn):
# Store a synthetic event stream and verify that storage.get_step_stats_for_run(...) computes
# per-step status, duration, materializations, and expectation results correctly.
run_id = "foo"
with event_storage_factory_cm_fn() as storage:
for record in _stats_records(run_id=run_id):
storage.store_event(record)
step_stats = storage.get_step_stats_for_run(run_id)
assert len(step_stats) == 4
a_stats = [stats for stats in step_stats if stats.step_key == "A"][0]
assert a_stats.step_key == "A"
assert a_stats.status.value == "SUCCESS"
assert a_stats.end_time - a_stats.start_time == 100
b_stats = [stats for stats in step_stats if stats.step_key == "B"][0]
assert b_stats.step_key == "B"
assert b_stats.status.value == "FAILURE"
assert b_stats.end_time - b_stats.start_time == 50
c_stats = [stats for stats in step_stats if stats.step_key == "C"][0]
assert c_stats.step_key == "C"
assert c_stats.status.value == "SKIPPED"
assert c_stats.end_time - c_stats.start_time == 25
d_stats = [stats for stats in step_stats if stats.step_key == "D"][0]
assert d_stats.step_key == "D"
assert d_stats.status.value == "SUCCESS"
assert d_stats.end_time - d_stats.start_time == 150
assert len(d_stats.materializations) == 3
assert len(d_stats.expectation_results) == 2
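+# Helper: synthesizes an event stream for four steps: A (success), B (failure), C (skipped),
+# and D (success, with three materializations and two expectation results).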
def _stats_records(run_id):
now = time.time()
return [
_event_record(run_id, "A", now - 325, DagsterEventType.STEP_START),
_event_record(
run_id,
"A",
now - 225,
DagsterEventType.STEP_SUCCESS,
StepSuccessData(duration_ms=100000.0),
),
_event_record(run_id, "B", now - 225, DagsterEventType.STEP_START),
_event_record(
run_id,
"B",
now - 175,
DagsterEventType.STEP_FAILURE,
StepFailureData(error=None, user_failure_data=None),
),
_event_record(run_id, "C", now - 175, DagsterEventType.STEP_START),
_event_record(run_id, "C", now - 150, DagsterEventType.STEP_SKIPPED),
_event_record(run_id, "D", now - 150, DagsterEventType.STEP_START),
_event_record(
run_id,
"D",
now - 125,
DagsterEventType.STEP_MATERIALIZATION,
StepMaterializationData(AssetMaterialization(asset_key="mat_1")),
),
_event_record(
run_id,
"D",
now - 100,
DagsterEventType.STEP_EXPECTATION_RESULT,
StepExpectationResultData(ExpectationResult(success=True, label="exp 1")),
),
_event_record(
run_id,
"D",
now - 75,
DagsterEventType.STEP_MATERIALIZATION,
StepMaterializationData(AssetMaterialization(asset_key="mat_2")),
),
_event_record(
run_id,
"D",
now - 50,
DagsterEventType.STEP_EXPECTATION_RESULT,
StepExpectationResultData(ExpectationResult(success=False, label="exp 2")),
),
_event_record(
run_id,
"D",
now - 25,
DagsterEventType.STEP_MATERIALIZATION,
StepMaterializationData(AssetMaterialization(asset_key="mat_3")),
),
_event_record(
run_id, "D", now, DagsterEventType.STEP_SUCCESS, StepSuccessData(duration_ms=150000.0)
),
]
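+# Helper: builds a DagsterEventRecord of the given event type for a single step at a given timestamp.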
def _event_record(run_id, step_key, timestamp, event_type, event_specific_data=None):
pipeline_name = "pipeline_name"
return DagsterEventRecord(
None,
"",
"debug",
"",
run_id,
timestamp,
step_key=step_key,
pipeline_name=pipeline_name,
dagster_event=DagsterEvent(
event_type.value,
pipeline_name,
step_key=step_key,
event_specific_data=event_specific_data,
),
)
def test_secondary_index():
with create_consolidated_sqlite_run_event_log_storage() as storage:
# Only the consolidated sqlite and postgres event log storages support secondary indexes
assert not storage.has_secondary_index("A")
assert not storage.has_secondary_index("B")
assert "A" in storage._secondary_index_cache # pylint: disable=protected-access
assert "B" in storage._secondary_index_cache # pylint: disable=protected-access
storage.enable_secondary_index("A")
assert "A" not in storage._secondary_index_cache # pylint: disable=protected-access
assert "B" in storage._secondary_index_cache # pylint: disable=protected-access
assert storage.has_secondary_index("A")
assert "A" in storage._secondary_index_cache # pylint: disable=protected-access
assert "B" in storage._secondary_index_cache # pylint: disable=protected-access
assert not storage.has_secondary_index("B")
storage.enable_secondary_index("B")
assert "A" in storage._secondary_index_cache # pylint: disable=protected-access
assert "B" not in storage._secondary_index_cache # pylint: disable=protected-access
assert storage.has_secondary_index("A")
assert storage.has_secondary_index("B")
assert "A" in storage._secondary_index_cache # pylint: disable=protected-access
assert "B" in storage._secondary_index_cache # pylint: disable=protected-access
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_fs_object_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_fs_object_manager.py
index a0ee40ac5..e9fef4190 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_fs_object_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_fs_object_manager.py
@@ -1,63 +1,64 @@
import os
import pickle
+import tempfile
-from dagster import ModeDefinition, execute_pipeline, pipeline, seven, solid
+from dagster import ModeDefinition, execute_pipeline, pipeline, solid
from dagster.core.definitions.events import AssetStoreOperationType
from dagster.core.storage.fs_object_manager import fs_object_manager
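+# Helper: builds a two-solid linear pipeline whose outputs are handled by the supplied object manager.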
def define_pipeline(object_manager):
@solid
def solid_a(_context):
return [1, 2, 3]
@solid
def solid_b(_context, _df):
return 1
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"object_manager": object_manager})])
def asset_pipeline():
solid_b(solid_a())
return asset_pipeline
def test_fs_object_manager():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_object_manager.configured({"base_dir": tmpdir_path})
pipeline_def = define_pipeline(asset_store)
result = execute_pipeline(pipeline_def)
assert result.success
asset_store_operation_events = list(
filter(lambda evt: evt.is_asset_store_operation, result.event_list)
)
assert len(asset_store_operation_events) == 3
# SET ASSET for step "solid_a" output "result"
assert (
asset_store_operation_events[0].event_specific_data.op
== AssetStoreOperationType.SET_ASSET
)
filepath_a = os.path.join(tmpdir_path, result.run_id, "solid_a", "result")
assert os.path.isfile(filepath_a)
with open(filepath_a, "rb") as read_obj:
assert pickle.load(read_obj) == [1, 2, 3]
# GET ASSET for step "solid_b" input "_df"
assert (
asset_store_operation_events[1].event_specific_data.op
== AssetStoreOperationType.GET_ASSET
)
assert "solid_a" == asset_store_operation_events[1].event_specific_data.step_key
# SET ASSET for step "solid_b" output "result"
assert (
asset_store_operation_events[2].event_specific_data.op
== AssetStoreOperationType.SET_ASSET
)
filepath_b = os.path.join(tmpdir_path, result.run_id, "solid_b", "result")
assert os.path.isfile(filepath_b)
with open(filepath_b, "rb") as read_obj:
assert pickle.load(read_obj) == 1
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_input_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_input_manager.py
index 35a090625..f46ab658a 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_input_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_input_manager.py
@@ -1,272 +1,273 @@
+import tempfile
+
from dagster import (
DagsterInstance,
EventMetadataEntry,
InputDefinition,
InputManagerDefinition,
ModeDefinition,
ObjectManager,
OutputDefinition,
PythonObjectDagsterType,
execute_pipeline,
pipeline,
resource,
- seven,
solid,
)
from dagster.core.definitions.events import Failure, RetryRequested
from dagster.core.instance import InstanceRef
from dagster.core.storage.input_manager import input_manager
def test_validate_inputs():
@input_manager
def my_loader(_, _resource_config):
return 5
@solid(
input_defs=[
InputDefinition(
"input1", dagster_type=PythonObjectDagsterType(int), manager_key="my_loader"
)
]
)
def my_solid(_, input1):
return input1
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_loader": my_loader})])
def my_pipeline():
my_solid()
execute_pipeline(my_pipeline)
def test_root_input_manager():
@input_manager
def my_hardcoded_csv_loader(_context, _resource_config):
return 5
@solid(input_defs=[InputDefinition("input1", manager_key="my_loader")])
def solid1(_, input1):
assert input1 == 5
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_loader": my_hardcoded_csv_loader})])
def my_pipeline():
solid1()
execute_pipeline(my_pipeline)
def test_configurable_root_input_manager():
@input_manager(config_schema={"base_dir": str}, input_config_schema={"value": int})
def my_configurable_csv_loader(context, resource_config):
assert resource_config["base_dir"] == "abc"
return context.config["value"]
@solid(input_defs=[InputDefinition("input1", manager_key="my_loader")])
def solid1(_, input1):
assert input1 == 5
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_loader": my_configurable_csv_loader})])
def my_configurable_pipeline():
solid1()
execute_pipeline(
my_configurable_pipeline,
run_config={
"solids": {"solid1": {"inputs": {"input1": {"value": 5}}}},
"resources": {"my_loader": {"config": {"base_dir": "abc"}}},
},
)
def test_override_object_manager():
metadata = {"name": 5}
class MyObjectManager(ObjectManager):
def handle_output(self, context, obj):
pass
def load_input(self, context):
assert False, "should not be called"
@resource
def my_object_manager(_):
return MyObjectManager()
@solid(
output_defs=[
OutputDefinition(name="my_output", manager_key="my_object_manager", metadata=metadata)
]
)
def solid1(_):
return 1
@solid(input_defs=[InputDefinition("input1", manager_key="spark_loader")])
def solid2(_, input1):
assert input1 == 5
@input_manager
def spark_table_loader(context, _resource_config):
output = context.upstream_output
assert output.metadata == metadata
assert output.name == "my_output"
assert output.step_key == "solid1"
assert context.pipeline_name == "my_pipeline"
assert context.solid_def.name == solid2.name
return 5
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={
"my_object_manager": my_object_manager,
"spark_loader": spark_table_loader,
}
)
]
)
def my_pipeline():
solid2(solid1())
execute_pipeline(my_pipeline)
def test_configured():
@input_manager(
config_schema={"base_dir": str},
description="abc",
input_config_schema={"format": str},
required_resource_keys={"r1", "r2"},
version="123",
)
def my_input_manager(_):
pass
configured_input_manager = my_input_manager.configured({"base_dir": "/a/b/c"})
assert isinstance(configured_input_manager, InputManagerDefinition)
assert configured_input_manager.description == my_input_manager.description
assert configured_input_manager.input_config_schema == my_input_manager.input_config_schema
assert (
configured_input_manager.required_resource_keys == my_input_manager.required_resource_keys
)
assert configured_input_manager.version is None
def test_input_manager_with_failure():
_called = False
@input_manager
def should_fail(_, _resource_config):
raise Failure(
description="Foolure",
metadata_entries=[
EventMetadataEntry.text(label="label", text="text", description="description")
],
)
@solid
def emit_str(_):
return "emit"
@solid(input_defs=[InputDefinition("_fail_input", manager_key="should_fail")])
def fail_on_input(_, _fail_input):
_called = True
@pipeline(mode_defs=[ModeDefinition(resource_defs={"should_fail": should_fail})])
def simple():
fail_on_input(emit_str())
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
assert not result.success
failure_data = result.result_for_solid("fail_on_input").failure_data
assert failure_data.error.cls_name == "Failure"
assert failure_data.user_failure_data.description == "Foolure"
assert failure_data.user_failure_data.metadata_entries[0].label == "label"
assert failure_data.user_failure_data.metadata_entries[0].entry_data.text == "text"
assert failure_data.user_failure_data.metadata_entries[0].description == "description"
assert not _called
def test_input_manager_with_retries():
_called = False
_count = {"total": 0}
@input_manager
def should_succeed(_, _resource_config):
if _count["total"] < 2:
_count["total"] += 1
raise RetryRequested(max_retries=3)
return "foo"
@input_manager
def should_retry(_, _resource_config):
raise RetryRequested(max_retries=3)
@input_manager
def should_not_execute(_, _resource_config):
_called = True
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={
"should_succeed": should_succeed,
"should_not_execute": should_not_execute,
"should_retry": should_retry,
}
)
]
)
def simple():
@solid
def source_solid(_):
return "foo"
@solid(input_defs=[InputDefinition("solid_input", manager_key="should_succeed")])
def take_input_1(_, solid_input):
return solid_input
@solid(input_defs=[InputDefinition("solid_input", manager_key="should_retry")])
def take_input_2(_, solid_input):
return solid_input
@solid(input_defs=[InputDefinition("solid_input", manager_key="should_not_execute")])
def take_input_3(_, solid_input):
return solid_input
take_input_3(take_input_2(take_input_1(source_solid())))
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
step_stats = instance.get_run_step_stats(result.run_id)
assert len(step_stats) == 3
step_stats_1 = instance.get_run_step_stats(result.run_id, step_keys=["take_input_1"])
assert len(step_stats_1) == 1
step_stat_1 = step_stats_1[0]
assert step_stat_1.status.value == "SUCCESS"
assert step_stat_1.attempts == 3
step_stats_2 = instance.get_run_step_stats(result.run_id, step_keys=["take_input_2"])
assert len(step_stats_2) == 1
step_stat_2 = step_stats_2[0]
assert step_stat_2.status.value == "FAILURE"
assert step_stat_2.attempts == 4
step_stats_3 = instance.get_run_step_stats(result.run_id, step_keys=["take_input_3"])
assert len(step_stats_3) == 0
assert _called == False
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py
index 70d6ca4b7..48e42beac 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_file_manager.py
@@ -1,67 +1,68 @@
+import tempfile
from contextlib import contextmanager
-from dagster import LocalFileHandle, ModeDefinition, execute_pipeline, pipeline, seven, solid
+from dagster import LocalFileHandle, ModeDefinition, execute_pipeline, pipeline, solid
from dagster.core.instance import DagsterInstance
from dagster.core.storage.file_manager import LocalFileManager, local_file_manager
from dagster.utils.temp_file import get_temp_file_handle_with_data
@contextmanager
def my_local_file_manager(instance, run_id):
manager = None
try:
manager = LocalFileManager.for_instance(instance, run_id)
yield manager
finally:
if manager:
manager.delete_local_temp()
def test_basic_file_manager_copy_handle_to_local_temp():
instance = DagsterInstance.ephemeral()
foo_data = "foo".encode()
with get_temp_file_handle_with_data(foo_data) as foo_handle:
with my_local_file_manager(instance, "0") as manager:
local_temp = manager.copy_handle_to_local_temp(foo_handle)
assert local_temp != foo_handle.path
with open(local_temp, "rb") as ff:
assert ff.read() == foo_data
def test_basic_file_manager_execute():
called = {}
@solid(required_resource_keys={"file_manager"})
def file_handle(context):
foo_bytes = "foo".encode()
file_handle = context.resources.file_manager.write_data(foo_bytes)
assert isinstance(file_handle, LocalFileHandle)
with open(file_handle.path, "rb") as handle_obj:
assert foo_bytes == handle_obj.read()
with context.resources.file_manager.read(file_handle) as handle_obj:
assert foo_bytes == handle_obj.read()
file_handle = context.resources.file_manager.write_data(foo_bytes, ext="foo")
assert isinstance(file_handle, LocalFileHandle)
assert file_handle.path[-4:] == ".foo"
with open(file_handle.path, "rb") as handle_obj:
assert foo_bytes == handle_obj.read()
with context.resources.file_manager.read(file_handle) as handle_obj:
assert foo_bytes == handle_obj.read()
called["yup"] = True
@pipeline(mode_defs=[ModeDefinition(resource_defs={"file_manager": local_file_manager})])
def pipe():
return file_handle()
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
result = execute_pipeline(
pipe, run_config={"resources": {"file_manager": {"config": {"base_dir": temp_dir}}}}
)
assert result.success
assert called["yup"]
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_instance.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_instance.py
index 624fdda06..a0e4553b0 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_instance.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_local_instance.py
@@ -1,235 +1,235 @@
import os
+import tempfile
import types
import pytest
import yaml
from dagster import (
DagsterEventType,
DagsterInvalidConfigError,
InputDefinition,
Output,
OutputDefinition,
PipelineRun,
check,
execute_pipeline,
pipeline,
- seven,
solid,
)
from dagster.core.definitions.events import RetryRequested
from dagster.core.execution.stats import StepEventStatus
from dagster.core.instance import DagsterInstance, InstanceRef, InstanceType
from dagster.core.launcher import DefaultRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.event_log import SqliteEventLogStorage
from dagster.core.storage.local_compute_log_manager import LocalComputeLogManager
from dagster.core.storage.pipeline_run import PipelineRunStatus
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
def test_fs_stores():
@pipeline
def simple():
@solid
def easy(context):
context.log.info("easy")
return "easy"
easy()
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = SqliteEventLogStorage(temp_dir)
compute_log_manager = LocalComputeLogManager(temp_dir)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=compute_log_manager,
run_coordinator=DefaultRunCoordinator(),
run_launcher=DefaultRunLauncher(),
)
result = execute_pipeline(simple, instance=instance)
assert run_store.has_run(result.run_id)
assert run_store.get_run_by_id(result.run_id).status == PipelineRunStatus.SUCCESS
assert DagsterEventType.PIPELINE_SUCCESS in [
event.dagster_event.event_type
for event in event_store.get_logs_for_run(result.run_id)
if event.is_dagster_event
]
stats = event_store.get_stats_for_run(result.run_id)
assert stats.steps_succeeded == 1
assert stats.end_time is not None
def test_init_compute_log_with_bad_config():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
with open(os.path.join(tmpdir_path, "dagster.yaml"), "w") as fd:
yaml.dump({"compute_logs": {"garbage": "flargh"}}, fd, default_flow_style=False)
with pytest.raises(
DagsterInvalidConfigError, match='Received unexpected config entry "garbage"'
):
DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
def test_init_compute_log_with_bad_config_override():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
with pytest.raises(
DagsterInvalidConfigError, match='Received unexpected config entry "garbage"'
):
DagsterInstance.from_ref(
InstanceRef.from_dir(tmpdir_path, overrides={"compute_logs": {"garbage": "flargh"}})
)
def test_init_compute_log_with_bad_config_module():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
with open(os.path.join(tmpdir_path, "dagster.yaml"), "w") as fd:
yaml.dump(
{"compute_logs": {"module": "flargh", "class": "Woble", "config": {}}},
fd,
default_flow_style=False,
)
with pytest.raises(check.CheckError, match="Couldn't import module"):
DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
MOCK_HAS_RUN_CALLED = False
def test_get_run_by_id():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
assert instance.get_runs() == []
pipeline_run = PipelineRun("foo_pipeline", "new_run")
assert instance.get_run_by_id(pipeline_run.run_id) is None
instance._run_storage.add_run(pipeline_run) # pylint: disable=protected-access
assert instance.get_runs() == [pipeline_run]
assert instance.get_run_by_id(pipeline_run.run_id) == pipeline_run
# Run is created after we check whether it exists
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
run = PipelineRun(pipeline_name="foo_pipeline", run_id="bar_run")
def _has_run(self, run_id):
# This is uglier than we would like because there is no nonlocal keyword in py2
global MOCK_HAS_RUN_CALLED # pylint: disable=global-statement
# pylint: disable=protected-access
if not self._run_storage.has_run(run_id) and not MOCK_HAS_RUN_CALLED:
self._run_storage.add_run(PipelineRun(pipeline_name="foo_pipeline", run_id=run_id))
return False
else:
return self._run_storage.has_run(run_id)
instance.has_run = types.MethodType(_has_run, instance)
assert instance.get_run_by_id(run.run_id) is None
# Run is created after we check whether it exists, but deleted before we can get it
global MOCK_HAS_RUN_CALLED # pylint:disable=global-statement
MOCK_HAS_RUN_CALLED = False
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
run = PipelineRun(pipeline_name="foo_pipeline", run_id="bar_run")
def _has_run(self, run_id):
global MOCK_HAS_RUN_CALLED # pylint: disable=global-statement
# pylint: disable=protected-access
if not self._run_storage.has_run(run_id) and not MOCK_HAS_RUN_CALLED:
self._run_storage.add_run(PipelineRun(pipeline_name="foo_pipeline", run_id=run_id))
MOCK_HAS_RUN_CALLED = True
return False
elif self._run_storage.has_run(run_id) and MOCK_HAS_RUN_CALLED:
MOCK_HAS_RUN_CALLED = False
return True
else:
return False
instance.has_run = types.MethodType(_has_run, instance)
assert instance.get_run_by_id(run.run_id) is None
def test_run_step_stats():
_called = None
@pipeline
def simple():
@solid
def should_succeed(context):
context.log.info("succeed")
return "yay"
@solid(input_defs=[InputDefinition("_input", str)], output_defs=[OutputDefinition(str)])
def should_fail(context, _input):
context.log.info("fail")
raise Exception("booo")
@solid
def should_not_execute(_, x):
_called = True
return x
should_not_execute(should_fail(should_succeed()))
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
step_stats = sorted(instance.get_run_step_stats(result.run_id), key=lambda x: x.end_time)
assert len(step_stats) == 2
assert step_stats[0].step_key == "should_succeed"
assert step_stats[0].status == StepEventStatus.SUCCESS
assert step_stats[0].end_time > step_stats[0].start_time
assert step_stats[0].attempts == 1
assert step_stats[1].step_key == "should_fail"
assert step_stats[1].status == StepEventStatus.FAILURE
assert step_stats[1].end_time > step_stats[0].start_time
assert step_stats[1].attempts == 1
assert not _called
def test_run_step_stats_with_retries():
_called = None
_count = {"total": 0}
@pipeline
def simple():
@solid
def should_succeed(_):
# Retry this solid a couple of times so that more than one step has retries, which properly
# exercises the step_keys filter on `get_run_step_stats`
if _count["total"] < 2:
_count["total"] += 1
raise RetryRequested(max_retries=3)
yield Output("yay")
@solid(input_defs=[InputDefinition("_input", str)], output_defs=[OutputDefinition(str)])
def should_retry(context, _input):
raise RetryRequested(max_retries=3)
@solid
def should_not_execute(_, x):
_called = True
return x
should_not_execute(should_retry(should_succeed()))
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
step_stats = instance.get_run_step_stats(result.run_id, step_keys=["should_retry"])
assert len(step_stats) == 1
assert step_stats[0].step_key == "should_retry"
assert step_stats[0].status == StepEventStatus.FAILURE
assert step_stats[0].end_time > step_stats[0].start_time
assert step_stats[0].attempts == 4
assert not _called
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_memoizable_object_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_memoizable_object_manager.py
index 542de6989..c8be1ff84 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_memoizable_object_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_memoizable_object_manager.py
@@ -1,36 +1,38 @@
+from tempfile import TemporaryDirectory
+
-from dagster import Any, seven
+from dagster import Any
from dagster.core.execution.context.system import InputContext, OutputContext
from dagster.core.storage.memoizable_object_manager import (
VersionedPickledObjectFilesystemObjectManager,
)
def test_versioned_pickled_object_filesystem_object_manager():
- with seven.TemporaryDirectory() as temp_dir:
+ with TemporaryDirectory() as temp_dir:
store = VersionedPickledObjectFilesystemObjectManager(temp_dir)
context = OutputContext(
step_key="foo",
name="bar",
mapping_key=None,
metadata={},
pipeline_name="fake",
solid_def=None,
dagster_type=Any,
run_id=None,
version="version1",
)
store.handle_output(context, "cat")
assert store.has_output(context)
assert store.load_input(InputContext(upstream_output=context, pipeline_name="abc")) == "cat"
context_diff_version = OutputContext(
step_key="foo",
name="bar",
mapping_key=None,
metadata={},
pipeline_name="fake",
solid_def=None,
dagster_type=Any,
run_id=None,
version="version2",
)
assert not store.has_output(context_diff_version)
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager.py
index dd4720409..4044bdc4d 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager.py
@@ -1,343 +1,343 @@
import os
+import tempfile
import pytest
from dagster import (
AssetMaterialization,
DagsterInstance,
DagsterInvariantViolationError,
ModeDefinition,
ObjectManagerDefinition,
OutputDefinition,
execute_pipeline,
pipeline,
reexecute_pipeline,
resource,
- seven,
solid,
)
from dagster.core.definitions.events import AssetStoreOperationType
from dagster.core.execution.api import create_execution_plan, execute_plan
from dagster.core.storage.asset_store import mem_asset_store
from dagster.core.storage.fs_object_manager import custom_path_fs_object_manager, fs_object_manager
from dagster.core.storage.object_manager import ObjectManager, object_manager
def test_object_manager_with_config():
@solid
def my_solid(_):
pass
class MyObjectManager(ObjectManager):
def load_input(self, context):
assert context.upstream_output.config["some_config"] == "some_value"
return 1
def handle_output(self, context, obj):
assert context.config["some_config"] == "some_value"
@object_manager(output_config_schema={"some_config": str})
def configurable_object_manager(_):
return MyObjectManager()
@pipeline(
mode_defs=[ModeDefinition(resource_defs={"object_manager": configurable_object_manager})]
)
def my_pipeline():
my_solid()
run_config = {"solids": {"my_solid": {"outputs": {"result": {"some_config": "some_value"}}}}}
result = execute_pipeline(my_pipeline, run_config=run_config)
assert result.output_for_solid("my_solid") == 1
def test_object_manager_with_required_resource_keys():
@solid
def my_solid(_):
pass
class MyObjectManager(ObjectManager):
def __init__(self, prefix):
self._prefix = prefix
def load_input(self, _context):
return self._prefix + "bar"
def handle_output(self, _context, obj):
pass
@object_manager(required_resource_keys={"foo_resource"})
def object_manager_requires_resource(init_context):
return MyObjectManager(init_context.resources.foo_resource)
@resource
def foo_resource(_):
return "foo"
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={
"object_manager": object_manager_requires_resource,
"foo_resource": foo_resource,
}
)
]
)
def my_pipeline():
my_solid()
result = execute_pipeline(my_pipeline)
assert result.output_for_solid("my_solid") == "foobar"
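+# Helper: builds a two-solid linear pipeline wired to the given object manager, with optional
+# per-output metadata (e.g. custom file paths) supplied via metadata_dict.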
def define_pipeline(manager, metadata_dict):
@solid(output_defs=[OutputDefinition(metadata=metadata_dict.get("solid_a"),)],)
def solid_a(_context):
return [1, 2, 3]
@solid(output_defs=[OutputDefinition(metadata=metadata_dict.get("solid_b"),)],)
def solid_b(_context, _df):
return 1
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"object_manager": manager})])
def my_pipeline():
solid_b(solid_a())
return my_pipeline
def test_result_output():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_object_manager.configured({"base_dir": tmpdir_path})
pipeline_def = define_pipeline(asset_store, {})
result = execute_pipeline(pipeline_def)
assert result.success
# test output_value
assert result.result_for_solid("solid_a").output_value() == [1, 2, 3]
assert result.result_for_solid("solid_b").output_value() == 1
def test_fs_object_manager_reexecution():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
default_asset_store = fs_object_manager.configured({"base_dir": tmpdir_path})
pipeline_def = define_pipeline(default_asset_store, {})
instance = DagsterInstance.ephemeral()
result = execute_pipeline(pipeline_def, instance=instance)
assert result.success
re_result = reexecute_pipeline(
pipeline_def, result.run_id, instance=instance, step_selection=["solid_b"],
)
# re-execution should yield asset_store_operation events instead of intermediate events
get_asset_events = list(
filter(
lambda evt: evt.is_asset_store_operation
and AssetStoreOperationType(evt.event_specific_data.op)
== AssetStoreOperationType.GET_ASSET,
re_result.event_list,
)
)
assert len(get_asset_events) == 1
assert get_asset_events[0].event_specific_data.step_key == "solid_a"
def test_can_reexecute():
pipeline_def = define_pipeline(fs_object_manager, {})
plan = create_execution_plan(pipeline_def)
assert plan.artifacts_persisted
def execute_pipeline_with_steps(pipeline_def, step_keys_to_execute=None):
plan = create_execution_plan(pipeline_def, step_keys_to_execute=step_keys_to_execute)
with DagsterInstance.ephemeral() as instance:
pipeline_run = instance.create_run_for_pipeline(
pipeline_def=pipeline_def, step_keys_to_execute=step_keys_to_execute,
)
return execute_plan(plan, instance, pipeline_run)
def test_step_subset_with_custom_paths():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = custom_path_fs_object_manager
# pass hardcoded file paths via per-output metadata
test_asset_metadata_dict = {
"solid_a": {"path": os.path.join(tmpdir_path, "a")},
"solid_b": {"path": os.path.join(tmpdir_path, "b")},
}
pipeline_def = define_pipeline(asset_store, test_asset_metadata_dict)
events = execute_pipeline_with_steps(pipeline_def)
for evt in events:
assert not evt.is_failure
# when a path is provided via the asset store, a step subset can be executed from an
# execution plan even when the upstream outputs were not previously created by
# dagster-controlled computations
step_subset_events = execute_pipeline_with_steps(
pipeline_def, step_keys_to_execute=["solid_b"]
)
for evt in step_subset_events:
assert not evt.is_failure
# only the selected step subset was executed
assert set([evt.step_key for evt in step_subset_events]) == {"solid_b"}
# Asset Materialization events
step_materialization_events = list(
filter(lambda evt: evt.is_step_materialization, step_subset_events)
)
assert len(step_materialization_events) == 1
assert test_asset_metadata_dict["solid_b"]["path"] == (
step_materialization_events[0]
.event_specific_data.materialization.metadata_entries[0]
.entry_data.path
)
def test_multi_materialization():
class DummyObjectManager(ObjectManager):
def __init__(self):
self.values = {}
def handle_output(self, context, obj):
keys = tuple(context.get_run_scoped_output_identifier())
self.values[keys] = obj
yield AssetMaterialization(asset_key="yield_one")
yield AssetMaterialization(asset_key="yield_two")
def load_input(self, context):
keys = tuple(context.upstream_output.get_run_scoped_output_identifier())
return self.values[keys]
def has_asset(self, context):
keys = tuple(context.get_run_scoped_output_identifier())
return keys in self.values
@object_manager
def dummy_object_manager(_):
return DummyObjectManager()
@solid(output_defs=[OutputDefinition(manager_key="my_object_manager")])
def solid_a(_context):
return 1
@solid()
def solid_b(_context, a):
assert a == 1
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_object_manager": dummy_object_manager})])
def my_pipeline():
solid_b(solid_a())
result = execute_pipeline(my_pipeline)
assert result.success
# Asset Materialization events
step_materialization_events = list(
filter(lambda evt: evt.is_step_materialization, result.event_list)
)
assert len(step_materialization_events) == 2
def test_different_object_managers():
@solid(output_defs=[OutputDefinition(manager_key="my_object_manager")],)
def solid_a(_context):
return 1
@solid()
def solid_b(_context, a):
assert a == 1
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_object_manager": mem_asset_store})])
def my_pipeline():
solid_b(solid_a())
assert execute_pipeline(my_pipeline).success
@object_manager
def my_object_manager(_):
pass
def test_set_object_manager_and_intermediate_storage():
from dagster import intermediate_storage, fs_intermediate_storage
@intermediate_storage()
def my_intermediate_storage(_):
pass
with pytest.raises(DagsterInvariantViolationError):
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={"object_manager": my_object_manager},
intermediate_storage_defs=[my_intermediate_storage, fs_intermediate_storage],
)
]
)
def my_pipeline():
pass
execute_pipeline(my_pipeline)
def test_set_asset_store_configure_intermediate_storage():
with pytest.raises(DagsterInvariantViolationError):
@pipeline(mode_defs=[ModeDefinition(resource_defs={"object_manager": my_object_manager})])
def my_pipeline():
pass
execute_pipeline(my_pipeline, run_config={"intermediate_storage": {"filesystem": {}}})
def test_fan_in():
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
asset_store = fs_object_manager.configured({"base_dir": tmpdir_path})
@solid
def input_solid1(_):
return 1
@solid
def input_solid2(_):
return 2
@solid
def solid1(_, input1):
assert input1 == [1, 2]
@pipeline(mode_defs=[ModeDefinition(resource_defs={"object_manager": asset_store})])
def my_pipeline():
solid1(input1=[input_solid1(), input_solid2()])
execute_pipeline(my_pipeline)
def test_configured():
@object_manager(
config_schema={"base_dir": str},
description="abc",
output_config_schema={"path": str},
input_config_schema={"format": str},
required_resource_keys={"r1", "r2"},
version="123",
)
def an_object_manager(_):
pass
configured_object_manager = an_object_manager.configured({"base_dir": "/a/b/c"})
assert isinstance(configured_object_manager, ObjectManagerDefinition)
assert configured_object_manager.description == an_object_manager.description
assert configured_object_manager.output_config_schema == an_object_manager.output_config_schema
assert configured_object_manager.input_config_schema == an_object_manager.input_config_schema
assert (
configured_object_manager.required_resource_keys == an_object_manager.required_resource_keys
)
assert configured_object_manager.version is None
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager_multiprocess.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager_multiprocess.py
index 66066ad78..6b35e5255 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager_multiprocess.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_object_manager_multiprocess.py
@@ -1,39 +1,40 @@
+import tempfile
+
from dagster import (
ModeDefinition,
execute_pipeline,
fs_object_manager,
pipeline,
reconstructable,
- seven,
solid,
)
from dagster.core.test_utils import instance_for_test
@solid
def solid_a(_context):
return [1, 2, 3]
@solid
def solid_b(_context, _df):
return 1
@pipeline(mode_defs=[ModeDefinition("local", resource_defs={"object_manager": fs_object_manager})])
def my_pipeline():
solid_b(solid_a())
def test_object_manager_with_multi_process_executor():
with instance_for_test() as instance:
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
result = execute_pipeline(
reconstructable(my_pipeline),
run_config={
"execution": {"multiprocess": {}},
"resources": {"object_manager": {"config": {"base_dir": tmpdir_path}}},
},
instance=instance,
)
assert result.success
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_output_manager.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_output_manager.py
index 053f590d9..a11a68cc6 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_output_manager.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_output_manager.py
@@ -1,370 +1,371 @@
+import tempfile
+
import pytest
from dagster import (
AssetMaterialization,
DagsterInstance,
DagsterInvalidDefinitionError,
DagsterType,
EventMetadataEntry,
Failure,
InputDefinition,
ModeDefinition,
ObjectManager,
Output,
OutputDefinition,
OutputManagerDefinition,
RetryRequested,
dagster_type_materializer,
execute_pipeline,
object_manager,
pipeline,
- seven,
solid,
)
from dagster.core.instance import InstanceRef
from dagster.core.storage.input_manager import input_manager
from dagster.core.storage.output_manager import output_manager
def test_output_manager():
adict = {}
@output_manager
def my_output_manager(_context, _resource_config, obj):
adict["result"] = obj
@solid(output_defs=[OutputDefinition(manager_key="my_output_manager")])
def my_solid(_):
return 5
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_output_manager": my_output_manager})])
def my_pipeline():
my_solid()
execute_pipeline(my_pipeline)
assert adict["result"] == 5
def test_configurable_output_manager():
adict = {}
@output_manager(output_config_schema=str)
def my_output_manager(context, _resource_config, obj):
adict["result"] = (context.config, obj)
@solid(output_defs=[OutputDefinition(name="my_output", manager_key="my_output_manager")])
def my_solid(_):
return 5
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_output_manager": my_output_manager})])
def my_pipeline():
my_solid()
execute_pipeline(
my_pipeline, run_config={"solids": {"my_solid": {"outputs": {"my_output": "a"}}}}
)
assert adict["result"] == ("a", 5)
def test_separate_output_manager_input_manager():
adict = {}
@output_manager
def my_output_manager(_context, _resource_config, obj):
adict["result"] = obj
@input_manager
def my_input_manager(_context, _resource_config):
return adict["result"]
@solid(output_defs=[OutputDefinition(manager_key="my_output_manager")])
def my_solid(_):
return 5
@solid(input_defs=[InputDefinition("input1", manager_key="my_input_manager")])
def my_downstream_solid(_, input1):
return input1 + 1
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={
"my_input_manager": my_input_manager,
"my_output_manager": my_output_manager,
}
)
]
)
def my_pipeline():
my_downstream_solid(my_solid())
execute_pipeline(my_pipeline)
assert adict["result"] == 5
def test_type_materializer_and_configurable_output_manager():
@dagster_type_materializer(config_schema={"type_materializer_path": str})
def my_materializer(_, _config, _value):
assert False, "shouldn't get here"
adict = {}
@output_manager(output_config_schema={"output_manager_path": str})
def my_output_manager(_context, _resource_config, obj):
adict["result"] = obj
my_type = DagsterType(lambda _, _val: True, name="my_type", materializer=my_materializer)
@solid(
output_defs=[
OutputDefinition(name="output1", manager_key="my_output_manager", dagster_type=my_type),
OutputDefinition(name="output2", dagster_type=my_type),
]
)
def my_solid(_):
yield Output(5, "output1")
yield Output(7, "output2")
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_output_manager": my_output_manager})])
def my_pipeline():
my_solid()
execute_pipeline(
my_pipeline,
run_config={"solids": {"my_solid": {"outputs": {"output1": {"output_manager_path": "a"}}}}},
)
assert adict["result"] == 5
def test_type_materializer_and_nonconfigurable_output_manager():
adict = {}
@dagster_type_materializer(config_schema={"type_materializer_path": str})
def my_materializer(_, _config, _value):
adict["materialized"] = True
return AssetMaterialization(asset_key="a")
@output_manager
def my_output_manager(_context, _resource_config, obj):
adict["result"] = obj
my_type = DagsterType(lambda _, _val: True, name="my_type", materializer=my_materializer)
@solid(
output_defs=[
OutputDefinition(name="output1", manager_key="my_output_manager", dagster_type=my_type),
OutputDefinition(name="output2", dagster_type=my_type),
]
)
def my_solid(_):
yield Output(5, "output1")
yield Output(7, "output2")
@pipeline(mode_defs=[ModeDefinition(resource_defs={"my_output_manager": my_output_manager})])
def my_pipeline():
my_solid()
execute_pipeline(
my_pipeline,
run_config={
"solids": {"my_solid": {"outputs": [{"output1": {"type_materializer_path": "a"}}]}}
},
)
assert adict["result"] == 5
assert adict["materialized"]
def test_configured():
@output_manager(
config_schema={"base_dir": str},
description="abc",
output_config_schema={"format": str},
required_resource_keys={"r1", "r2"},
version="123",
)
def my_output_manager(_):
pass
configured_output_manager = my_output_manager.configured({"base_dir": "/a/b/c"})
assert isinstance(configured_output_manager, OutputManagerDefinition)
assert configured_output_manager.description == my_output_manager.description
assert configured_output_manager.output_config_schema == my_output_manager.output_config_schema
assert (
configured_output_manager.required_resource_keys == my_output_manager.required_resource_keys
)
assert configured_output_manager.version is None
def test_output_manager_with_failure():
_called_input_manager = False
_called_solid = False
@output_manager
def should_fail(_, _resource_config, _obj):
raise Failure(
description="Foolure",
metadata_entries=[
EventMetadataEntry.text(label="label", text="text", description="description")
],
)
@input_manager
def should_not_enter(_):
_called_input_manager = True
@solid(output_defs=[OutputDefinition(manager_key="should_fail")])
def emit_str(_):
return "emit"
@solid(
input_defs=[
InputDefinition(name="_input_str", dagster_type=str, manager_key="should_not_enter")
]
)
def should_not_call(_, _input_str):
_called_solid = True
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={"should_fail": should_fail, "should_not_enter": should_not_enter}
)
]
)
def simple():
should_not_call(emit_str())
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
assert not result.success
failure_data = result.result_for_solid("emit_str").failure_data
assert failure_data.error.cls_name == "Failure"
assert failure_data.user_failure_data.description == "Foolure"
assert failure_data.user_failure_data.metadata_entries[0].label == "label"
assert failure_data.user_failure_data.metadata_entries[0].entry_data.text == "text"
assert failure_data.user_failure_data.metadata_entries[0].description == "description"
assert not _called_input_manager and not _called_solid
def test_output_manager_with_retries():
_called = False
_count = {"total": 0}
@object_manager
def should_succeed(_):
class FakeObjectManager(ObjectManager):
def load_input(self, _context):
return "foo"
def handle_output(self, _context, _obj):
if _count["total"] < 2:
_count["total"] += 1
raise RetryRequested(max_retries=3)
return FakeObjectManager()
@object_manager
def should_retry(_):
class FakeObjectManager(ObjectManager):
def load_input(self, _context):
return "foo"
def handle_output(self, _context, _obj):
raise RetryRequested(max_retries=3)
return FakeObjectManager()
@pipeline(
mode_defs=[
ModeDefinition(
resource_defs={"should_succeed": should_succeed, "should_retry": should_retry,}
)
]
)
def simple():
@solid(output_defs=[OutputDefinition(manager_key="should_succeed")],)
def source_solid(_):
return "foo"
@solid(
input_defs=[InputDefinition("solid_input")],
output_defs=[OutputDefinition(manager_key="should_retry")],
)
def take_input(_, solid_input):
return solid_input
@solid(input_defs=[InputDefinition("_solid_input")])
def should_not_execute(_, _solid_input):
_called = True
should_not_execute(take_input(source_solid()))
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
instance = DagsterInstance.from_ref(InstanceRef.from_dir(tmpdir_path))
result = execute_pipeline(simple, instance=instance, raise_on_error=False)
step_stats = instance.get_run_step_stats(result.run_id)
assert len(step_stats) == 2
step_stats_1 = instance.get_run_step_stats(result.run_id, step_keys=["source_solid"])
assert len(step_stats_1) == 1
step_stat_1 = step_stats_1[0]
assert step_stat_1.status.value == "SUCCESS"
assert step_stat_1.attempts == 3
step_stats_2 = instance.get_run_step_stats(result.run_id, step_keys=["take_input"])
assert len(step_stats_2) == 1
step_stat_2 = step_stats_2[0]
assert step_stat_2.status.value == "FAILURE"
assert step_stat_2.attempts == 4
step_stats_3 = instance.get_run_step_stats(result.run_id, step_keys=["should_not_execute"])
assert len(step_stats_3) == 0
assert _called == False
def test_output_manager_no_input_manager():
@output_manager
def output_manager_alone(_):
raise NotImplementedError()
@solid(output_defs=[OutputDefinition(name="output_alone", manager_key="output_manager_alone")])
def emit_str(_):
raise NotImplementedError()
@solid(input_defs=[InputDefinition("_str_input")])
def ingest_str(_, _str_input):
raise NotImplementedError()
with pytest.raises(
DagsterInvalidDefinitionError,
match='Input "_str_input" of solid "ingest_str" is connected to output "output_alone" of '
'solid "emit_str". In mode "default", that output does not have an output manager that '
"knows how to load inputs, so we don't know how to load the input. To address this, "
"assign an InputManager to this input or assign an ObjectManager to the upstream output.",
):
@pipeline(
mode_defs=[
ModeDefinition(
"default", resource_defs={"output_manager_alone": output_manager_alone}
)
]
)
def _should_fail():
ingest_str(emit_str())
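One way to satisfy the definition error above, following the same pattern as test_separate_output_manager_input_manager: give the downstream input its own input manager (or swap the output manager for an ObjectManager on the upstream output). A minimal sketch; the resource key "str_loader", the pipeline name, and both manager bodies are illustrative.
from dagster import InputDefinition, ModeDefinition, OutputDefinition, pipeline, solid
from dagster.core.storage.input_manager import input_manager
from dagster.core.storage.output_manager import output_manager

_store = {}

@output_manager
def store_str(_context, _resource_config, obj):
    # hand the output to some external store (here, a module-level dict)
    _store["value"] = obj

@input_manager
def load_str(_context, _resource_config):
    # load the stored value back for the downstream input
    return _store["value"]

@solid(output_defs=[OutputDefinition(name="output_alone", manager_key="output_manager_alone")])
def emit_str(_):
    return "emit"

@solid(input_defs=[InputDefinition("_str_input", manager_key="str_loader")])
def ingest_str(_, _str_input):
    return _str_input

@pipeline(
    mode_defs=[
        ModeDefinition(
            "default",
            resource_defs={"output_manager_alone": store_str, "str_loader": load_str},
        )
    ]
)
def connected_pipeline():
    ingest_str(emit_str())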
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_run_storage.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_run_storage.py
index 7d139ccf4..bd4d1b01c 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_run_storage.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_run_storage.py
@@ -1,38 +1,38 @@
+import tempfile
from contextlib import contextmanager
import pytest
-from dagster import seven
from dagster.core.storage.runs import InMemoryRunStorage, SqliteRunStorage
from dagster_tests.core_tests.storage_tests.utils.run_storage import TestRunStorage
@contextmanager
def create_sqlite_run_storage():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
yield SqliteRunStorage.from_local(tempdir)
@contextmanager
def create_in_memory_storage():
yield InMemoryRunStorage()
TestRunStorage.__test__ = False
class TestSqliteImplementation(TestRunStorage):
__test__ = True
@pytest.fixture(name="storage", params=[create_sqlite_run_storage])
def run_storage(self, request):
with request.param() as s:
yield s
class TestInMemoryImplementation(TestRunStorage):
__test__ = True
@pytest.fixture(name="storage", params=[create_in_memory_storage])
def run_storage(self, request):
with request.param() as s:
yield s
diff --git a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_schedule_storage.py b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_schedule_storage.py
index 89269624d..e0b8a4026 100644
--- a/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_schedule_storage.py
+++ b/python_modules/dagster/dagster_tests/core_tests/storage_tests/test_schedule_storage.py
@@ -1,24 +1,24 @@
+import tempfile
from contextlib import contextmanager
import pytest
-from dagster import seven
from dagster.core.storage.schedules import SqliteScheduleStorage
from dagster.utils.test.schedule_storage import TestScheduleStorage
@contextmanager
def create_sqlite_schedule_storage():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
yield SqliteScheduleStorage.from_local(tempdir)
TestScheduleStorage.__test__ = False
class TestSqliteScheduleStorage(TestScheduleStorage):
__test__ = True
@pytest.fixture(name="storage", params=[create_sqlite_schedule_storage])
def schedule_storage(self, request):
with request.param() as s:
yield s
diff --git a/python_modules/dagster/dagster_tests/general_tests/py3_tests/test_type_examples_py3.py b/python_modules/dagster/dagster_tests/general_tests/py3_tests/test_type_examples_py3.py
index 001137a1d..a634574a1 100644
--- a/python_modules/dagster/dagster_tests/general_tests/py3_tests/test_type_examples_py3.py
+++ b/python_modules/dagster/dagster_tests/general_tests/py3_tests/test_type_examples_py3.py
@@ -1,484 +1,484 @@
import os
import pickle
+import tempfile
import time
import pytest
from dagster import (
Any,
Bool,
DagsterInvalidConfigError,
Dict,
Field,
Float,
InputDefinition,
Int,
List,
Nothing,
Optional,
Permissive,
Selector,
Set,
String,
Tuple,
check,
execute_pipeline,
execute_solid,
pipeline,
- seven,
solid,
)
@solid
def identity(_, x: Any) -> Any:
return x
@solid
def identity_imp(_, x):
return x
@solid
def boolean(_, x: Bool) -> String:
return "true" if x else "false"
@solid
def empty_string(_, x: String) -> bool:
return len(x) == 0
@solid
def add_3(_, x: Int) -> int:
return x + 3
@solid
def div_2(_, x: Float) -> float:
return x / 2
@solid
def concat(_, x: String, y: str) -> str:
return x + y
@solid
def wait(_) -> Nothing:
time.sleep(0.2)
return
@solid(input_defs=[InputDefinition("ready", dagster_type=Nothing)])
def done(_) -> str:
return "done"
@pipeline
def nothing_pipeline():
done(wait())
@solid
def wait_int(_) -> Int:
time.sleep(0.2)
return 1
@pipeline
def nothing_int_pipeline():
done(wait_int())
@solid
def nullable_concat(_, x: String, y: Optional[String]) -> String:
return x + (y or "")
@solid
def concat_list(_, xs: List[String]) -> String:
return "".join(xs)
@solid
def emit_1(_) -> int:
return 1
@solid
def emit_2(_) -> int:
return 2
@solid
def emit_3(_) -> int:
return 3
@solid
def sum_solid(_, xs: List[int]) -> int:
return sum(xs)
@pipeline
def sum_pipeline():
sum_solid([emit_1(), emit_2(), emit_3()])
@solid
def repeat(_, spec: Dict) -> str:
return spec["word"] * spec["times"]
@solid
def set_solid(_, set_input: Set[String]) -> List[String]:
return sorted([x for x in set_input])
@solid
def tuple_solid(_, tuple_input: Tuple[String, Int, Float]) -> List:
return [x for x in tuple_input]
@solid
def dict_return_solid(_) -> Dict[str, str]:
return {"foo": "bar"}
def test_identity():
res = execute_solid(identity, input_values={"x": "foo"})
assert res.output_value() == "foo"
def test_identity_imp():
res = execute_solid(identity_imp, input_values={"x": "foo"})
assert res.output_value() == "foo"
def test_boolean():
res = execute_solid(boolean, input_values={"x": True})
assert res.output_value() == "true"
res = execute_solid(boolean, input_values={"x": False})
assert res.output_value() == "false"
def test_empty_string():
res = execute_solid(empty_string, input_values={"x": ""})
assert res.output_value() is True
res = execute_solid(empty_string, input_values={"x": "foo"})
assert res.output_value() is False
def test_add_3():
res = execute_solid(add_3, input_values={"x": 3})
assert res.output_value() == 6
def test_div_2():
res = execute_solid(div_2, input_values={"x": 7.0})
assert res.output_value() == 3.5
def test_concat():
res = execute_solid(concat, input_values={"x": "foo", "y": "bar"})
assert res.output_value() == "foobar"
def test_nothing_pipeline():
res = execute_pipeline(nothing_pipeline)
assert res.result_for_solid("wait").output_value() is None
assert res.result_for_solid("done").output_value() == "done"
def test_nothing_int_pipeline():
res = execute_pipeline(nothing_int_pipeline)
assert res.result_for_solid("wait_int").output_value() == 1
assert res.result_for_solid("done").output_value() == "done"
def test_nullable_concat():
res = execute_solid(nullable_concat, input_values={"x": "foo", "y": None})
assert res.output_value() == "foo"
def test_concat_list():
res = execute_solid(concat_list, input_values={"xs": ["foo", "bar", "baz"]})
assert res.output_value() == "foobarbaz"
def test_sum_pipeline():
res = execute_pipeline(sum_pipeline)
assert res.result_for_solid("sum_solid").output_value() == 6
def test_repeat():
res = execute_solid(repeat, input_values={"spec": {"word": "foo", "times": 3}})
assert res.output_value() == "foofoofoo"
def test_set_solid():
res = execute_solid(set_solid, input_values={"set_input": {"foo", "bar", "baz"}})
assert res.output_value() == sorted(["foo", "bar", "baz"])
def test_set_solid_configable_input():
res = execute_solid(
set_solid,
run_config={
"solids": {
"set_solid": {
"inputs": {"set_input": [{"value": "foo"}, {"value": "bar"}, {"value": "baz"}]}
}
}
},
)
assert res.output_value() == sorted(["foo", "bar", "baz"])
def test_set_solid_configable_input_bad():
with pytest.raises(DagsterInvalidConfigError,) as exc_info:
execute_solid(
set_solid,
run_config={"solids": {"set_solid": {"inputs": {"set_input": {"foo", "bar", "baz"}}}}},
)
expected = "Value at path root:solids:set_solid:inputs:set_input must be list."
assert expected in str(exc_info.value)
def test_tuple_solid():
res = execute_solid(tuple_solid, input_values={"tuple_input": ("foo", 1, 3.1)})
assert res.output_value() == ["foo", 1, 3.1]
def test_tuple_solid_configable_input():
res = execute_solid(
tuple_solid,
run_config={
"solids": {
"tuple_solid": {
"inputs": {"tuple_input": [{"value": "foo"}, {"value": 1}, {"value": 3.1}]}
}
}
},
)
assert res.output_value() == ["foo", 1, 3.1]
def test_dict_return_solid():
res = execute_solid(dict_return_solid)
assert res.output_value() == {"foo": "bar"}
######
@solid(config_schema=Field(Any))
def any_config(context):
return context.solid_config
@solid(config_schema=Field(Bool))
def bool_config(context):
return "true" if context.solid_config else "false"
@solid(config_schema=Int)
def add_n(context, x: Int) -> int:
return x + context.solid_config
@solid(config_schema=Field(Float))
def div_y(context, x: Float) -> float:
return x / context.solid_config
@solid(config_schema=Field(float))
def div_y_var(context, x: Float) -> float:
return x / context.solid_config
@solid(config_schema=Field(String))
def hello(context) -> str:
return "Hello, {friend}!".format(friend=context.solid_config)
@solid(config_schema=Field(String))
def unpickle(context) -> Any:
with open(context.solid_config, "rb") as fd:
return pickle.load(fd)
@solid(config_schema=Field(list))
def concat_typeless_list_config(context) -> String:
return "".join(context.solid_config)
@solid(config_schema=Field([str]))
def concat_config(context) -> String:
return "".join(context.solid_config)
@solid(config_schema={"word": String, "times": Int})
def repeat_config(context) -> str:
return context.solid_config["word"] * context.solid_config["times"]
@solid(config_schema=Field(Selector({"haw": {}, "cn": {}, "en": {}})))
def hello_world(context) -> str:
if "haw" in context.solid_config:
return "Aloha honua!"
if "cn" in context.solid_config:
return "ä½ å¥½ï¼Œä¸–ç•Œ!"
return "Hello, world!"
@solid(
config_schema=Field(
Selector(
{
"haw": {"whom": Field(String, default_value="honua", is_required=False)},
"cn": {"whom": Field(String, default_value="世界", is_required=False)},
"en": {"whom": Field(String, default_value="world", is_required=False)},
}
),
is_required=False,
default_value={"en": {"whom": "world"}},
)
)
def hello_world_default(context) -> str:
if "haw" in context.solid_config:
return "Aloha {whom}!".format(whom=context.solid_config["haw"]["whom"])
if "cn" in context.solid_config:
return "ä½ å¥½ï¼Œ{whom}!".format(whom=context.solid_config["cn"]["whom"])
if "en" in context.solid_config:
return "Hello, {whom}!".format(whom=context.solid_config["en"]["whom"])
@solid(config_schema=Field(Permissive({"required": Field(String)})))
def partially_specified_config(context) -> List:
return sorted(list(context.solid_config.items()))
def test_any_config():
res = execute_solid(any_config, run_config={"solids": {"any_config": {"config": "foo"}}})
assert res.output_value() == "foo"
res = execute_solid(
any_config, run_config={"solids": {"any_config": {"config": {"zip": "zowie"}}}}
)
assert res.output_value() == {"zip": "zowie"}
def test_bool_config():
res = execute_solid(bool_config, run_config={"solids": {"bool_config": {"config": True}}})
assert res.output_value() == "true"
res = execute_solid(bool_config, run_config={"solids": {"bool_config": {"config": False}}})
assert res.output_value() == "false"
def test_add_n():
res = execute_solid(
add_n, input_values={"x": 3}, run_config={"solids": {"add_n": {"config": 7}}}
)
assert res.output_value() == 10
def test_div_y():
res = execute_solid(
div_y, input_values={"x": 3.0}, run_config={"solids": {"div_y": {"config": 2.0}}}
)
assert res.output_value() == 1.5
def test_div_y_var():
res = execute_solid(
div_y_var, input_values={"x": 3.0}, run_config={"solids": {"div_y_var": {"config": 2.0}}},
)
assert res.output_value() == 1.5
def test_hello():
res = execute_solid(hello, run_config={"solids": {"hello": {"config": "Max"}}})
assert res.output_value() == "Hello, Max!"
def test_unpickle():
- with seven.TemporaryDirectory() as tmpdir:
+ with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, "foo.pickle")
with open(filename, "wb") as f:
pickle.dump("foo", f)
res = execute_solid(unpickle, run_config={"solids": {"unpickle": {"config": filename}}})
assert res.output_value() == "foo"
def test_concat_config():
res = execute_solid(
concat_config, run_config={"solids": {"concat_config": {"config": ["foo", "bar", "baz"]}}},
)
assert res.output_value() == "foobarbaz"
def test_concat_typeless_config():
res = execute_solid(
concat_typeless_list_config,
run_config={"solids": {"concat_typeless_list_config": {"config": ["foo", "bar", "baz"]}}},
)
assert res.output_value() == "foobarbaz"
def test_repeat_config():
res = execute_solid(
repeat_config,
run_config={"solids": {"repeat_config": {"config": {"word": "foo", "times": 3}}}},
)
assert res.output_value() == "foofoofoo"
def test_tuple_none_config():
with pytest.raises(check.CheckError, match="Param tuple_types cannot be none"):
@solid(config_schema=Field(Tuple[None]))
def _tuple_none_config(context) -> str:
return ":".join([str(x) for x in context.solid_config])
def test_selector_config():
res = execute_solid(
hello_world, run_config={"solids": {"hello_world": {"config": {"haw": {}}}}}
)
assert res.output_value() == "Aloha honua!"
def test_selector_config_default():
res = execute_solid(hello_world_default)
assert res.output_value() == "Hello, world!"
res = execute_solid(
hello_world_default,
run_config={"solids": {"hello_world_default": {"config": {"haw": {}}}}},
)
assert res.output_value() == "Aloha honua!"
res = execute_solid(
hello_world_default,
run_config={"solids": {"hello_world_default": {"config": {"haw": {"whom": "Max"}}}}},
)
assert res.output_value() == "Aloha Max!"
def test_permissive_config():
res = execute_solid(
partially_specified_config,
run_config={
"solids": {
"partially_specified_config": {"config": {"required": "yes", "also": "this"}}
}
},
)
assert res.output_value() == sorted([("required", "yes"), ("also", "this")])
diff --git a/python_modules/dagster/dev-requirements.txt b/python_modules/dagster/dev-requirements.txt
index 8d82f4612..1ab221f37 100644
--- a/python_modules/dagster/dev-requirements.txt
+++ b/python_modules/dagster/dev-requirements.txt
@@ -1,28 +1,25 @@
-astroid>=2.3.3; python_version >= '3.6'
-black==19.10b0; python_version >= '3.6'
+astroid>=2.3.3
+black==19.10b0
coverage==5.3
docker
flake8>=3.7.8
freezegun>=0.3.15
grpcio-tools==1.32.0
isort<5,>=4.3.21
mock==3.0.5
nbsphinx==0.4.2
protobuf==3.13.0 # without this, pip will install the most up-to-date protobuf
-pylint==2.6.0; python_version >= '3.6'
+pylint==2.6.0
pytest-cov==2.10.1
pytest-dependency==0.5.1
-pytest-mock==2.0.0; python_version < '3.6'
-pytest-mock==3.3.1; python_version >= '3.6'
+pytest-mock==3.3.1
pytest-runner==5.2
-pytest-xdist==1.34.0; python_version < '3.6'
-pytest-xdist==2.1.0; python_version >= '3.6'
-pytest==4.6.11; python_version < '3.6'
-pytest==6.1.1; python_version >= '3.6'
+pytest-xdist==2.1.0
+pytest==6.1.1
recommonmark==0.4.0
responses==0.10.*
snapshottest==0.6.0
tox==3.14.2
tox-pip-version==0.0.7
tqdm==4.48.0 # pylint crash 48.1+
yamllint
diff --git a/python_modules/dagster/setup.py b/python_modules/dagster/setup.py
index aca9396cb..f734d8142 100644
--- a/python_modules/dagster/setup.py
+++ b/python_modules/dagster/setup.py
@@ -1,102 +1,92 @@
from setuptools import find_packages, setup
def long_description():
return """
## Dagster
Dagster is a data orchestrator for machine learning, analytics, and ETL.
Dagster lets you define pipelines in terms of the data flow between reusable, logical components,
then test locally and run anywhere. With a unified view of pipelines and the assets they produce,
Dagster can schedule and orchestrate Pandas, Spark, SQL, or anything else that Python can invoke.
Dagster is designed for data platform engineers, data engineers, and full-stack data scientists.
Building a data platform with Dagster makes your stakeholders more independent and your systems
more robust. Developing data pipelines with Dagster makes testing easier and deploying faster.
""".strip()
def get_version():
version = {}
with open("dagster/version.py") as fp:
exec(fp.read(), version) # pylint: disable=W0122
return version["__version__"]
if __name__ == "__main__":
setup(
name="dagster",
version=get_version(),
author="Elementl",
author_email="hello@elementl.com",
license="Apache-2.0",
description="A data orchestrator for machine learning, analytics, and ETL.",
long_description=long_description(),
long_description_content_type="text/markdown",
url="https://github.com/dagster-io/dagster",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["dagster_tests"]),
package_data={
"dagster": [
"dagster/core/storage/event_log/sqlite/alembic/*",
"dagster/core/storage/runs/sqlite/alembic/*",
"dagster/core/storage/schedules/sqlite/alembic/*",
"dagster/grpc/protos/*",
]
},
include_package_data=True,
install_requires=[
- # standard python 2/3 compatability things
- 'enum34; python_version < "3.4"',
"future",
- "funcsigs",
- 'functools32; python_version<"3"',
- "contextlib2>=0.5.4",
- 'pathlib2>=2.3.4; python_version<"3"',
# cli
"click>=5.0",
"coloredlogs>=6.1, <=14.0",
"PyYAML",
# core (not explicitly expressed atm)
"alembic>=1.2.1",
"croniter>=0.3.34",
"grpcio>=1.32.0", # ensure version we require is >= that with which we generated the grpc code (set in dev-requirements)
"grpcio-health-checking>=1.32.0",
"pendulum==1.4.4", # pinned to match airflow, can upgrade to 2.0 once airflow 1.10.13 is released
"protobuf>=3.13.0", # ensure version we require is >= that with which we generated the proto code (set in dev-requirements)
- "pyrsistent>=0.14.8,<=0.16.0; python_version < '3'", # 0.17.0 breaks py2 support
- "pyrsistent>=0.14.8; python_version >='3'",
+ "pyrsistent>=0.14.8",
"python-dateutil",
"requests",
"rx<=1.6.1", # 3.0 was a breaking change. No py2 compatability as well.
- 'futures; python_version < "3"',
"six",
"tabulate",
"tqdm",
"sqlalchemy>=1.0",
- 'typing; python_version<"3"',
- 'backports.tempfile; python_version<"3"',
"toposort>=1.0",
"watchdog>=0.8.3",
'psutil >= 1.0; platform_system=="Windows"',
# https://github.com/mhammond/pywin32/issues/1439
'pywin32 != 226; platform_system=="Windows"',
"pytz",
- 'docstring-parser==0.7.1; python_version >="3.6"',
+ "docstring-parser==0.7.1",
],
extras_require={"docker": ["docker"],},
entry_points={
"console_scripts": [
"dagster = dagster.cli:main",
"dagster-scheduler = dagster.scheduler.cli:main",
"dagster-daemon = dagster.daemon.cli:main",
]
},
)
diff --git a/python_modules/libraries/dagster-airflow/dagster_airflow/factory.py b/python_modules/libraries/dagster-airflow/dagster_airflow/factory.py
index b03bd6254..a06dabe65 100644
--- a/python_modules/libraries/dagster-airflow/dagster_airflow/factory.py
+++ b/python_modules/libraries/dagster-airflow/dagster_airflow/factory.py
@@ -1,485 +1,485 @@
import datetime
import re
from collections import namedtuple
from airflow import DAG
from airflow.operators import BaseOperator
from dagster import check, seven
from dagster.core.definitions.reconstructable import ReconstructableRepository
from dagster.core.execution.api import create_execution_plan
from dagster.core.instance import DagsterInstance
from dagster.core.instance.ref import InstanceRef
from dagster.core.snap import ExecutionPlanSnapshot, PipelineSnapshot, snapshot_from_execution_plan
from dagster_airflow.operators.util import check_storage_specified
from .compile import coalesce_execution_steps
from .operators.docker_operator import DagsterDockerOperator
from .operators.python_operator import DagsterPythonOperator
DEFAULT_ARGS = {
"depends_on_past": False,
"email": ["airflow@example.com"],
"email_on_failure": False,
"email_on_retry": False,
"owner": "airflow",
"retries": 1,
"retry_delay": datetime.timedelta(0, 300),
"start_date": datetime.datetime(1900, 1, 1, 0, 0),
}
# Airflow DAG names are not allowed to be longer than 250 chars
AIRFLOW_MAX_DAG_NAME_LEN = 250
def _make_dag_description(pipeline_name):
return """Editable scaffolding autogenerated by dagster-airflow from pipeline {pipeline_name}
""".format(
pipeline_name=pipeline_name
)
def _rename_for_airflow(name):
"""Modify pipeline name for Airflow to meet constraints on DAG names:
https://github.com/apache/airflow/blob/1.10.3/airflow/utils/helpers.py#L52-L63
Here, we just substitute underscores for illegal characters to avoid imposing Airflow's
constraints on our naming schemes.
"""
return re.sub(r"[^\w\-\.]", "_", name)[:AIRFLOW_MAX_DAG_NAME_LEN]
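For example (illustrative inputs only): characters outside [\w\-\.] become underscores, and the result is truncated to AIRFLOW_MAX_DAG_NAME_LEN characters.
assert _rename_for_airflow("my pipeline/v2") == "my_pipeline_v2"
assert len(_rename_for_airflow("x" * 300)) == AIRFLOW_MAX_DAG_NAME_LEN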
class DagsterOperatorInvocationArgs(
namedtuple(
"DagsterOperatorInvocationArgs",
"recon_repo pipeline_name run_config mode step_keys instance_ref pipeline_snapshot "
"execution_plan_snapshot parent_pipeline_snapshot",
)
):
def __new__(
cls,
recon_repo,
pipeline_name,
run_config,
mode,
step_keys,
instance_ref,
pipeline_snapshot,
execution_plan_snapshot,
parent_pipeline_snapshot,
):
return super(DagsterOperatorInvocationArgs, cls).__new__(
cls,
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
step_keys=step_keys,
instance_ref=instance_ref,
pipeline_snapshot=pipeline_snapshot,
execution_plan_snapshot=execution_plan_snapshot,
parent_pipeline_snapshot=parent_pipeline_snapshot,
)
class DagsterOperatorParameters(
namedtuple(
"_DagsterOperatorParameters",
(
"recon_repo pipeline_name run_config "
"mode task_id step_keys dag instance_ref op_kwargs pipeline_snapshot "
"execution_plan_snapshot parent_pipeline_snapshot"
),
)
):
def __new__(
cls,
pipeline_name,
task_id,
recon_repo=None,
run_config=None,
mode=None,
step_keys=None,
dag=None,
instance_ref=None,
op_kwargs=None,
pipeline_snapshot=None,
execution_plan_snapshot=None,
parent_pipeline_snapshot=None,
):
check_storage_specified(run_config)
return super(DagsterOperatorParameters, cls).__new__(
cls,
recon_repo=check.opt_inst_param(recon_repo, "recon_repo", ReconstructableRepository),
pipeline_name=check.str_param(pipeline_name, "pipeline_name"),
run_config=check.opt_dict_param(run_config, "run_config", key_type=str),
mode=check.opt_str_param(mode, "mode"),
task_id=check.str_param(task_id, "task_id"),
step_keys=check.opt_list_param(step_keys, "step_keys", of_type=str),
dag=check.opt_inst_param(dag, "dag", DAG),
instance_ref=check.opt_inst_param(instance_ref, "instance_ref", InstanceRef),
op_kwargs=check.opt_dict_param(op_kwargs.copy(), "op_kwargs", key_type=str),
pipeline_snapshot=check.inst_param(
pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot
),
execution_plan_snapshot=check.inst_param(
execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot
),
parent_pipeline_snapshot=check.opt_inst_param(
parent_pipeline_snapshot, "parent_pipeline_snapshot", PipelineSnapshot
),
)
@property
def invocation_args(self):
return DagsterOperatorInvocationArgs(
recon_repo=self.recon_repo,
pipeline_name=self.pipeline_name,
run_config=self.run_config,
mode=self.mode,
step_keys=self.step_keys,
instance_ref=self.instance_ref,
pipeline_snapshot=self.pipeline_snapshot,
execution_plan_snapshot=self.execution_plan_snapshot,
parent_pipeline_snapshot=self.parent_pipeline_snapshot,
)
def _make_airflow_dag(
recon_repo,
pipeline_name,
run_config=None,
mode=None,
instance=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
operator=DagsterPythonOperator,
):
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
check.str_param(pipeline_name, "pipeline_name")
run_config = check.opt_dict_param(run_config, "run_config", key_type=str)
mode = check.opt_str_param(mode, "mode")
- # Default to use the (persistent) system temp directory rather than a seven.TemporaryDirectory,
+ # Default to use the (persistent) system temp directory rather than a TemporaryDirectory,
# which would not be consistent between Airflow task invocations.
instance = (
check.inst_param(instance, "instance", DagsterInstance)
if instance
else DagsterInstance.get(fallback_storage=seven.get_system_temp_directory())
)
# Only used for Airflow; internally we continue to use pipeline.name
dag_id = check.opt_str_param(dag_id, "dag_id", _rename_for_airflow(pipeline_name))
dag_description = check.opt_str_param(
dag_description, "dag_description", _make_dag_description(pipeline_name)
)
check.subclass_param(operator, "operator", BaseOperator)
dag_kwargs = dict(
{"default_args": DEFAULT_ARGS},
**check.opt_dict_param(dag_kwargs, "dag_kwargs", key_type=str),
)
op_kwargs = check.opt_dict_param(op_kwargs, "op_kwargs", key_type=str)
dag = DAG(dag_id=dag_id, description=dag_description, **dag_kwargs)
pipeline = recon_repo.get_definition().get_pipeline(pipeline_name)
if mode is None:
mode = pipeline.get_default_mode_name()
execution_plan = create_execution_plan(pipeline, run_config, mode=mode)
tasks = {}
coalesced_plan = coalesce_execution_steps(execution_plan)
for solid_handle, solid_steps in coalesced_plan.items():
step_keys = [step.key for step in solid_steps]
operator_parameters = DagsterOperatorParameters(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
task_id=solid_handle,
step_keys=step_keys,
dag=dag,
instance_ref=instance.get_ref(),
op_kwargs=op_kwargs,
pipeline_snapshot=pipeline.get_pipeline_snapshot(),
execution_plan_snapshot=snapshot_from_execution_plan(
execution_plan, pipeline_snapshot_id=pipeline.get_pipeline_snapshot_id()
),
)
task = operator(operator_parameters)
tasks[solid_handle] = task
for solid_step in solid_steps:
for step_input in solid_step.step_inputs:
for key in step_input.dependency_keys:
prev_solid_handle = execution_plan.get_step_by_key(key).solid_handle.to_string()
if solid_handle != prev_solid_handle:
tasks[prev_solid_handle].set_downstream(task)
return (dag, [tasks[solid_handle] for solid_handle in coalesced_plan.keys()])
def make_airflow_dag(
module_name,
pipeline_name,
run_config=None,
mode=None,
instance=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
):
"""Construct an Airflow DAG corresponding to a given Dagster pipeline.
Tasks in the resulting DAG will execute the Dagster logic they encapsulate as a Python
callable, run by an underlying :py:class:`PythonOperator `. As a
consequence, dagster itself, any Python dependencies required by your solid logic, and the module
containing your pipeline definition must all be available in the Python environment within which
your Airflow tasks execute. If you cannot install requirements into this environment, or you
are looking for a containerized solution to provide better isolation, see instead
:py:func:`make_airflow_dag_containerized`.
This function should be invoked in an Airflow DAG definition file, such as that created by an
invocation of the dagster-airflow scaffold CLI tool.
Args:
module_name (str): The name of the importable module in which the pipeline definition can be
found.
pipeline_name (str): The name of the pipeline definition.
run_config (Optional[dict]): The environment config, if any, with which to compile
the pipeline to an execution plan, as a Python dict.
mode (Optional[str]): The mode in which to execute the pipeline.
instance (Optional[DagsterInstance]): The Dagster instance to use to execute the pipeline.
dag_id (Optional[str]): The id to use for the compiled Airflow DAG (passed through to
:py:class:`DAG `).
dag_description (Optional[str]): The description to use for the compiled Airflow DAG
(passed through to :py:class:`DAG `)
dag_kwargs (Optional[dict]): Any additional kwargs to pass to the Airflow
:py:class:`DAG ` constructor, including ``default_args``.
op_kwargs (Optional[dict]): Any additional kwargs to pass to the underlying Airflow
operator (a subclass of
:py:class:`PythonOperator `).
Returns:
(airflow.models.DAG, List[airflow.models.BaseOperator]): The generated Airflow DAG, and a
list of its constituent tasks.
"""
check.str_param(module_name, "module_name")
recon_repo = ReconstructableRepository.for_module(module_name, pipeline_name)
return _make_airflow_dag(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
instance=instance,
dag_id=dag_id,
dag_description=dag_description,
dag_kwargs=dag_kwargs,
op_kwargs=op_kwargs,
)
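A minimal usage sketch for an Airflow DAG definition file. The module and pipeline names are placeholders, and the run_config simply reuses the filesystem intermediate storage seen in the tests above to satisfy the storage check; none of these values come from this diff.
from dagster_airflow.factory import make_airflow_dag

# Hypothetical module/pipeline names; the module must be importable by the Airflow workers.
dag, tasks = make_airflow_dag(
    module_name="my_package.my_repo",
    pipeline_name="my_pipeline",
    run_config={"intermediate_storage": {"filesystem": {}}},
)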
def make_airflow_dag_for_operator(
recon_repo,
pipeline_name,
operator,
run_config=None,
mode=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
):
"""Construct an Airflow DAG corresponding to a given Dagster pipeline and custom operator.
`Custom operator template `_
Tasks in the resulting DAG will execute the Dagster logic they encapsulate, run by the given
operator, a subclass of :py:class:`BaseOperator `. If you
are looking for a containerized solution to provide better isolation, see instead
:py:func:`make_airflow_dag_containerized`.
This function should be invoked in an Airflow DAG definition file, such as that created by an
invocation of the dagster-airflow scaffold CLI tool.
Args:
recon_repo (:class:`dagster.ReconstructableRepository`): reference to a Dagster RepositoryDefinition
that can be reconstructed in another process
pipeline_name (str): The name of the pipeline definition.
operator (type): The operator to use. Must be a class that inherits from
:py:class:`BaseOperator `
run_config (Optional[dict]): The environment config, if any, with which to compile
the pipeline to an execution plan, as a Python dict.
mode (Optional[str]): The mode in which to execute the pipeline.
instance (Optional[DagsterInstance]): The Dagster instance to use to execute the pipeline.
dag_id (Optional[str]): The id to use for the compiled Airflow DAG (passed through to
:py:class:`DAG `).
dag_description (Optional[str]): The description to use for the compiled Airflow DAG
(passed through to :py:class:`DAG `)
dag_kwargs (Optional[dict]): Any additional kwargs to pass to the Airflow
:py:class:`DAG ` constructor, including ``default_args``.
op_kwargs (Optional[dict]): Any additional kwargs to pass to the underlying Airflow
operator.
Returns:
(airflow.models.DAG, List[airflow.models.BaseOperator]): The generated Airflow DAG, and a
list of its constituent tasks.
"""
check.subclass_param(operator, "operator", BaseOperator)
return _make_airflow_dag(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
dag_id=dag_id,
dag_description=dag_description,
dag_kwargs=dag_kwargs,
op_kwargs=op_kwargs,
operator=operator,
)
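A sketch of the custom-operator variant. MyOperator is a hypothetical BaseOperator subclass (it is constructed with the DagsterOperatorParameters built above), and the module, pipeline, and run_config values are placeholders.
from dagster.core.definitions.reconstructable import ReconstructableRepository
from dagster_airflow.factory import make_airflow_dag_for_operator

from my_package.airflow_operators import MyOperator  # hypothetical BaseOperator subclass

recon_repo = ReconstructableRepository.for_module("my_package.my_repo", "my_pipeline")
dag, tasks = make_airflow_dag_for_operator(
    recon_repo=recon_repo,
    pipeline_name="my_pipeline",
    operator=MyOperator,
    run_config={"intermediate_storage": {"filesystem": {}}},
)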
def make_airflow_dag_for_recon_repo(
recon_repo,
pipeline_name,
run_config=None,
mode=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
):
return _make_airflow_dag(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
dag_id=dag_id,
dag_description=dag_description,
dag_kwargs=dag_kwargs,
op_kwargs=op_kwargs,
)
def make_airflow_dag_containerized(
module_name,
pipeline_name,
image,
run_config=None,
mode=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
):
"""Construct a containerized Airflow DAG corresponding to a given Dagster pipeline.
Tasks in the resulting DAG will execute the Dagster logic they encapsulate by calling the
dagster-graphql API exposed by a container run using a subclass of
:py:class:`DockerOperator `. As a
consequence, dagster itself, any Python dependencies required by your solid logic, and the module
containing your pipeline definition must all be available in the container spun up by this operator.
Typically you'll want to install these requirements onto the image you're using.
This function should be invoked in an Airflow DAG definition file, such as that created by an
invocation of the dagster-airflow scaffold CLI tool.
Args:
module_name (str): The name of the importable module in which the pipeline definition can be
found.
pipeline_name (str): The name of the pipeline definition.
image (str): The name of the Docker image to use for execution (passed through to
:py:class:`DockerOperator `).
run_config (Optional[dict]): The environment config, if any, with which to compile
the pipeline to an execution plan, as a Python dict.
mode (Optional[str]): The mode in which to execute the pipeline.
dag_id (Optional[str]): The id to use for the compiled Airflow DAG (passed through to
:py:class:`DAG `).
dag_description (Optional[str]): The description to use for the compiled Airflow DAG
(passed through to :py:class:`DAG `)
dag_kwargs (Optional[dict]): Any additional kwargs to pass to the Airflow
:py:class:`DAG ` constructor, including ``default_args``.
op_kwargs (Optional[dict]): Any additional kwargs to pass to the underlying Airflow
operator (a subclass of
:py:class:`DockerOperator `).
Returns:
(airflow.models.DAG, List[airflow.models.BaseOperator]): The generated Airflow DAG, and a
list of its constituent tasks.
"""
check.str_param(module_name, "module_name")
check.str_param(pipeline_name, "pipeline_name")
check.str_param(image, "image")
check.opt_dict_param(run_config, "run_config")
check.opt_str_param(mode, "mode")
check.opt_str_param(dag_id, "dag_id")
check.opt_str_param(dag_description, "dag_description")
check.opt_dict_param(dag_kwargs, "dag_kwargs")
check.opt_dict_param(op_kwargs, "op_kwargs")
recon_repo = ReconstructableRepository.for_module(module_name, pipeline_name)
op_kwargs = check.opt_dict_param(op_kwargs, "op_kwargs", key_type=str)
op_kwargs["image"] = image
return _make_airflow_dag(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
dag_id=dag_id,
dag_description=dag_description,
dag_kwargs=dag_kwargs,
op_kwargs=op_kwargs,
operator=DagsterDockerOperator,
)
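A minimal usage sketch of the containerized variant. The module, pipeline, and image names are placeholders; the image is assumed to have dagster, dagster-graphql, and the pipeline module installed. Filesystem intermediate storage is shown only to satisfy the storage check; a shared object store is typically what you'd use when tasks run in separate containers.
from dagster_airflow.factory import make_airflow_dag_containerized

dag, tasks = make_airflow_dag_containerized(
    module_name="my_package.my_repo",
    pipeline_name="my_pipeline",
    image="my-registry/my-pipeline:0.1.0",
    run_config={"intermediate_storage": {"filesystem": {}}},
)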
def make_airflow_dag_containerized_for_recon_repo(
recon_repo,
pipeline_name,
image,
run_config=None,
mode=None,
dag_id=None,
dag_description=None,
dag_kwargs=None,
op_kwargs=None,
instance=None,
):
check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
check.str_param(pipeline_name, "pipeline_name")
check.str_param(image, "image")
check.opt_dict_param(run_config, "run_config")
check.opt_str_param(mode, "mode")
check.opt_str_param(dag_id, "dag_id")
check.opt_str_param(dag_description, "dag_description")
check.opt_dict_param(dag_kwargs, "dag_kwargs")
op_kwargs = check.opt_dict_param(op_kwargs, "op_kwargs", key_type=str)
op_kwargs["image"] = image
return _make_airflow_dag(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
run_config=run_config,
mode=mode,
dag_id=dag_id,
dag_description=dag_description,
dag_kwargs=dag_kwargs,
op_kwargs=op_kwargs,
operator=DagsterDockerOperator,
instance=instance,
)
diff --git a/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_dagster_pipeline_factory/test_load_dag_bag.py b/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_dagster_pipeline_factory/test_load_dag_bag.py
index a64583886..56322fae0 100644
--- a/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_dagster_pipeline_factory/test_load_dag_bag.py
+++ b/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_dagster_pipeline_factory/test_load_dag_bag.py
@@ -1,438 +1,439 @@
import os
+import tempfile
import pytest
-from dagster import execute_pipeline, seven
+from dagster import execute_pipeline
from dagster_airflow.dagster_pipeline_factory import (
make_dagster_repo_from_airflow_dags_path,
make_dagster_repo_from_airflow_example_dags,
)
from dagster_airflow_tests.marks import requires_airflow_db
COMPLEX_DAG_FILE_CONTENTS = '''#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG that shows the complex DAG structure.
"""
import sys
from airflow import models
from airflow.utils.dates import days_ago
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.helpers import chain
default_args = {"start_date": days_ago(1)}
with models.DAG(
dag_id="example_complex", default_args=default_args, schedule_interval=None, tags=['example'],
) as complex_dag:
# Create
create_entry_group = BashOperator(
task_id="create_entry_group", bash_command="echo create_entry_group"
)
create_entry_group_result = BashOperator(
task_id="create_entry_group_result", bash_command="echo create_entry_group_result"
)
create_entry_group_result2 = BashOperator(
task_id="create_entry_group_result2", bash_command="echo create_entry_group_result2"
)
create_entry_gcs = BashOperator(
task_id="create_entry_gcs", bash_command="echo create_entry_gcs"
)
create_entry_gcs_result = BashOperator(
task_id="create_entry_gcs_result", bash_command="echo create_entry_gcs_result"
)
create_entry_gcs_result2 = BashOperator(
task_id="create_entry_gcs_result2", bash_command="echo create_entry_gcs_result2"
)
create_tag = BashOperator(task_id="create_tag", bash_command="echo create_tag")
create_tag_result = BashOperator(
task_id="create_tag_result", bash_command="echo create_tag_result"
)
create_tag_result2 = BashOperator(
task_id="create_tag_result2", bash_command="echo create_tag_result2"
)
create_tag_template = BashOperator(
task_id="create_tag_template", bash_command="echo create_tag_template"
)
create_tag_template_result = BashOperator(
task_id="create_tag_template_result", bash_command="echo create_tag_template_result"
)
create_tag_template_result2 = BashOperator(
task_id="create_tag_template_result2", bash_command="echo create_tag_template_result2"
)
create_tag_template_field = BashOperator(
task_id="create_tag_template_field", bash_command="echo create_tag_template_field"
)
create_tag_template_field_result = BashOperator(
task_id="create_tag_template_field_result",
bash_command="echo create_tag_template_field_result",
)
create_tag_template_field_result2 = BashOperator(
task_id="create_tag_template_field_result",
bash_command="echo create_tag_template_field_result",
)
# Delete
delete_entry = BashOperator(task_id="delete_entry", bash_command="echo delete_entry")
create_entry_gcs >> delete_entry
delete_entry_group = BashOperator(
task_id="delete_entry_group", bash_command="echo delete_entry_group"
)
create_entry_group >> delete_entry_group
delete_tag = BashOperator(task_id="delete_tag", bash_command="echo delete_tag")
create_tag >> delete_tag
delete_tag_template_field = BashOperator(
task_id="delete_tag_template_field", bash_command="echo delete_tag_template_field"
)
delete_tag_template = BashOperator(
task_id="delete_tag_template", bash_command="echo delete_tag_template"
)
# Get
get_entry_group = BashOperator(task_id="get_entry_group", bash_command="echo get_entry_group")
get_entry_group_result = BashOperator(
task_id="get_entry_group_result", bash_command="echo get_entry_group_result"
)
get_entry = BashOperator(task_id="get_entry", bash_command="echo get_entry")
get_entry_result = BashOperator(
task_id="get_entry_result", bash_command="echo get_entry_result"
)
get_tag_template = BashOperator(
task_id="get_tag_template", bash_command="echo get_tag_template"
)
get_tag_template_result = BashOperator(
task_id="get_tag_template_result", bash_command="echo get_tag_template_result"
)
# List
list_tags = BashOperator(task_id="list_tags", bash_command="echo list_tags")
list_tags_result = BashOperator(
task_id="list_tags_result", bash_command="echo list_tags_result"
)
# Lookup
lookup_entry = BashOperator(task_id="lookup_entry", bash_command="echo lookup_entry")
lookup_entry_result = BashOperator(
task_id="lookup_entry_result", bash_command="echo lookup_entry_result"
)
# Rename
rename_tag_template_field = BashOperator(
task_id="rename_tag_template_field", bash_command="echo rename_tag_template_field"
)
# Search
search_catalog = PythonOperator(
task_id="search_catalog", python_callable=lambda: sys.stdout.write("search_catalog\\n")
)
search_catalog_result = BashOperator(
task_id="search_catalog_result", bash_command="echo search_catalog_result"
)
# Update
update_entry = BashOperator(task_id="update_entry", bash_command="echo update_entry")
update_tag = BashOperator(task_id="update_tag", bash_command="echo update_tag")
update_tag_template = BashOperator(
task_id="update_tag_template", bash_command="echo update_tag_template"
)
update_tag_template_field = BashOperator(
task_id="update_tag_template_field", bash_command="echo update_tag_template_field"
)
# Create
create_tasks = [
create_entry_group,
create_entry_gcs,
create_tag_template,
create_tag_template_field,
create_tag,
]
chain(*create_tasks)
create_entry_group >> delete_entry_group
create_entry_group >> create_entry_group_result
create_entry_group >> create_entry_group_result2
create_entry_gcs >> delete_entry
create_entry_gcs >> create_entry_gcs_result
create_entry_gcs >> create_entry_gcs_result2
create_tag_template >> delete_tag_template_field
create_tag_template >> create_tag_template_result
create_tag_template >> create_tag_template_result2
create_tag_template_field >> delete_tag_template_field
create_tag_template_field >> create_tag_template_field_result
create_tag_template_field >> create_tag_template_field_result2
create_tag >> delete_tag
create_tag >> create_tag_result
create_tag >> create_tag_result2
# Delete
delete_tasks = [
delete_tag,
delete_tag_template_field,
delete_tag_template,
delete_entry_group,
delete_entry,
]
chain(*delete_tasks)
# Get
create_tag_template >> get_tag_template >> delete_tag_template
get_tag_template >> get_tag_template_result
create_entry_gcs >> get_entry >> delete_entry
get_entry >> get_entry_result
create_entry_group >> get_entry_group >> delete_entry_group
get_entry_group >> get_entry_group_result
# List
create_tag >> list_tags >> delete_tag
list_tags >> list_tags_result
# Lookup
create_entry_gcs >> lookup_entry >> delete_entry
lookup_entry >> lookup_entry_result
# Rename
create_tag_template_field >> rename_tag_template_field >> delete_tag_template_field
# Search
chain(create_tasks, search_catalog, delete_tasks)
search_catalog >> search_catalog_result
# Update
create_entry_gcs >> update_entry >> delete_entry
create_tag >> update_tag >> delete_tag
create_tag_template >> update_tag_template >> delete_tag_template
create_tag_template_field >> update_tag_template_field >> rename_tag_template_field
'''
BASH_DAG_FILE_CONTENTS = '''#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of the BashOperator."""
# DAG
# airflow
from datetime import timedelta
from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
args = {
'owner': 'airflow',
'start_date': days_ago(2),
}
bash_dag = DAG(
dag_id='example_bash_operator',
default_args=args,
schedule_interval='0 0 * * *',
dagrun_timeout=timedelta(minutes=60),
tags=['example'],
)
run_this_last = DummyOperator(task_id='run_this_last', dag=bash_dag,)
# [START howto_operator_bash]
run_this = BashOperator(task_id='run_after_loop', bash_command='echo 1', dag=bash_dag,)
# [END howto_operator_bash]
run_this >> run_this_last
for i in range(3):
task = BashOperator(
task_id='runme_' + str(i),
bash_command='echo "{{ task_instance_key_str }}" && sleep 1',
dag=bash_dag,
)
task >> run_this
# [START howto_operator_bash_template]
also_run_this = BashOperator(
task_id='also_run_this',
bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
dag=bash_dag,
)
# [END howto_operator_bash_template]
also_run_this >> run_this_last
if __name__ == "__main__":
bash_dag.cli()
'''
COMBINED_FILE_CONTENTS = COMPLEX_DAG_FILE_CONTENTS + BASH_DAG_FILE_CONTENTS
test_make_repo_inputs = [
([("complex.py", COMPLEX_DAG_FILE_CONTENTS)], None, ["airflow_example_complex"]),
([("bash.py", BASH_DAG_FILE_CONTENTS)], None, ["airflow_example_bash_operator"]),
(
[("complex.py", COMPLEX_DAG_FILE_CONTENTS), ("bash.py", BASH_DAG_FILE_CONTENTS)],
None,
["airflow_example_complex", "airflow_example_bash_operator"],
),
([("complex.py", COMPLEX_DAG_FILE_CONTENTS)], "complex.py", ["airflow_example_complex"],),
([("bash.py", BASH_DAG_FILE_CONTENTS)], "bash.py", ["airflow_example_bash_operator"],),
(
[("combined.py", COMBINED_FILE_CONTENTS)],
None,
["airflow_example_complex", "airflow_example_bash_operator"],
),
]
@pytest.mark.parametrize(
"path_and_content_tuples, fn_arg_path, expected_pipeline_names", test_make_repo_inputs,
)
def test_make_repo(
path_and_content_tuples, fn_arg_path, expected_pipeline_names,
):
repo_name = "my_repo_name"
- with seven.TemporaryDirectory() as tmpdir_path:
+ with tempfile.TemporaryDirectory() as tmpdir_path:
for (path, content) in path_and_content_tuples:
with open(os.path.join(tmpdir_path, path), "wb") as f:
f.write(bytes(content.encode("utf-8")))
repo = (
make_dagster_repo_from_airflow_dags_path(tmpdir_path, repo_name,)
if fn_arg_path is None
else make_dagster_repo_from_airflow_dags_path(
os.path.join(tmpdir_path, fn_arg_path), repo_name
)
)
for pipeline_name in expected_pipeline_names:
assert repo.name == repo_name
assert repo.has_pipeline(pipeline_name)
pipeline = repo.get_pipeline(pipeline_name)
result = execute_pipeline(pipeline)
assert result.success
assert set(repo.pipeline_names) == set(expected_pipeline_names)
test_airflow_example_dags_inputs = [
(
[
"airflow_example_bash_operator",
"airflow_example_branch_dop_operator_v3",
"airflow_example_branch_operator",
"airflow_example_complex",
"airflow_example_external_task_marker_child",
"airflow_example_external_task_marker_parent",
"airflow_example_http_operator",
"airflow_example_nested_branch_dag", # only exists in airflow v1.10.10
"airflow_example_passing_params_via_test_command",
"airflow_example_pig_operator",
"airflow_example_python_operator",
"airflow_example_short_circuit_operator",
"airflow_example_skip_dag",
"airflow_example_subdag_operator",
"airflow_example_subdag_operator_section_1",
"airflow_example_subdag_operator_section_2",
"airflow_example_trigger_controller_dag",
"airflow_example_trigger_target_dag",
"airflow_example_xcom",
"airflow_latest_only",
"airflow_latest_only_with_trigger",
"airflow_test_utils",
"airflow_tutorial",
],
[
"airflow_example_external_task_marker_child",
"airflow_example_pig_operator",
"airflow_example_skip_dag",
"airflow_example_trigger_target_dag",
"airflow_example_xcom",
"airflow_test_utils",
],
),
]
@pytest.mark.parametrize(
"expected_pipeline_names, exclude_from_execution_tests", test_airflow_example_dags_inputs,
)
@requires_airflow_db
def test_airflow_example_dags(
expected_pipeline_names, exclude_from_execution_tests,
):
repo = make_dagster_repo_from_airflow_example_dags()
for pipeline_name in expected_pipeline_names:
assert repo.name == "airflow_example_dags_repo"
assert repo.has_pipeline(pipeline_name)
pipeline = repo.get_pipeline(pipeline_name)
if pipeline_name not in exclude_from_execution_tests:
result = execute_pipeline(pipeline)
assert result.success
assert set(repo.pipeline_names) == set(expected_pipeline_names)
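For reference, a minimal usage sketch of the repository factory these tests exercise, outside of pytest. The DAGs directory, repository name, and pipeline name are hypothetical, and the import path is assumed to match this version of dagster-airflow; adjust it if the factory lives elsewhere in your install.
# Sketch only (not part of the diff): turn a directory of Airflow DAG files into a
# Dagster repository and execute one of the converted pipelines.
from dagster import execute_pipeline
from dagster_airflow.dagster_pipeline_factory import make_dagster_repo_from_airflow_dags_path

repo = make_dagster_repo_from_airflow_dags_path(
    "/path/to/airflow/dags",  # directory (or single file) containing Airflow DAG definitions
    "my_airflow_repo",        # name of the generated Dagster repository
)
pipeline = repo.get_pipeline("airflow_example_bash_operator")  # one pipeline per Airflow DAG
assert execute_pipeline(pipeline).success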
diff --git a/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_fixtures.py b/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_fixtures.py
index fbf39044f..4618db296 100644
--- a/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_fixtures.py
+++ b/python_modules/libraries/dagster-airflow/dagster_airflow_tests/test_fixtures.py
@@ -1,226 +1,227 @@
import logging
import sys
+import tempfile
from contextlib import contextmanager
import pytest
from airflow import DAG
from airflow.exceptions import AirflowSkipException
from airflow.models import TaskInstance
from airflow.settings import LOG_FORMAT
from airflow.utils import timezone
-from dagster import file_relative_path, seven
+from dagster import file_relative_path
from dagster.core.test_utils import instance_for_test_tempdir
from dagster.core.utils import make_new_run_id
from dagster.utils import load_yaml_from_glob_list, merge_dicts
from dagster.utils.test.postgres_instance import TestPostgresInstance
@contextmanager
def postgres_instance(overrides=None):
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
with TestPostgresInstance.docker_service_up_or_skip(
file_relative_path(__file__, "docker-compose.yml"), "test-postgres-db-airflow",
) as pg_conn_string:
TestPostgresInstance.clean_run_storage(pg_conn_string)
TestPostgresInstance.clean_event_log_storage(pg_conn_string)
TestPostgresInstance.clean_schedule_storage(pg_conn_string)
with instance_for_test_tempdir(
temp_dir,
overrides=merge_dicts(
{
"run_storage": {
"module": "dagster_postgres.run_storage.run_storage",
"class": "PostgresRunStorage",
"config": {"postgres_url": pg_conn_string},
},
"event_log_storage": {
"module": "dagster_postgres.event_log.event_log",
"class": "PostgresEventLogStorage",
"config": {"postgres_url": pg_conn_string},
},
"schedule_storage": {
"module": "dagster_postgres.schedule_storage.schedule_storage",
"class": "PostgresScheduleStorage",
"config": {"postgres_url": pg_conn_string},
},
},
overrides if overrides else {},
),
) as instance:
yield instance
def execute_tasks_in_dag(dag, tasks, run_id, execution_date):
assert isinstance(dag, DAG)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter(LOG_FORMAT))
root = logging.getLogger("airflow.task.operators")
root.setLevel(logging.DEBUG)
root.addHandler(handler)
dag_run = dag.create_dagrun(run_id=run_id, state="success", execution_date=execution_date)
results = {}
for task in tasks:
ti = TaskInstance(task=task, execution_date=execution_date)
context = ti.get_template_context()
context["dag_run"] = dag_run
try:
results[ti] = task.execute(context)
except AirflowSkipException as exc:
results[ti] = exc
return results
@pytest.fixture(scope="function")
def dagster_airflow_python_operator_pipeline():
"""This is a test fixture for running Dagster pipelines as Airflow DAGs.
Usage:
from dagster_airflow_tests.test_fixtures import dagster_airflow_python_operator_pipeline
def test_airflow(dagster_airflow_python_operator_pipeline):
results = dagster_airflow_python_operator_pipeline(
pipeline_name='test_pipeline',
recon_repo=reconstructable(define_pipeline),
environment_yaml=['environments/test_*.yaml']
)
assert len(results) == 3
"""
from dagster_airflow.factory import make_airflow_dag_for_recon_repo
from dagster_airflow.vendor.python_operator import PythonOperator
def _pipeline_fn(
recon_repo,
pipeline_name,
run_config=None,
environment_yaml=None,
op_kwargs=None,
mode=None,
execution_date=timezone.utcnow(),
):
if run_config is None and environment_yaml is not None:
run_config = load_yaml_from_glob_list(environment_yaml)
dag, tasks = make_airflow_dag_for_recon_repo(
recon_repo, pipeline_name, run_config, mode=mode, op_kwargs=op_kwargs
)
assert isinstance(dag, DAG)
for task in tasks:
assert isinstance(task, PythonOperator)
return execute_tasks_in_dag(
dag, tasks, run_id=make_new_run_id(), execution_date=execution_date
)
return _pipeline_fn
@pytest.fixture(scope="function")
def dagster_airflow_custom_operator_pipeline():
"""This is a test fixture for running Dagster pipelines with custom operators as Airflow DAGs.
Usage:
from dagster_airflow_tests.test_fixtures import dagster_airflow_custom_operator_pipeline
def test_airflow(dagster_airflow_custom_operator_pipeline):
results = dagster_airflow_custom_operator_pipeline(
pipeline_name='test_pipeline',
recon_repo=reconstructable(define_pipeline),
operator=MyCustomOperator,
environment_yaml=['environments/test_*.yaml']
)
assert len(results) == 3
"""
from dagster_airflow.factory import make_airflow_dag_for_operator
from dagster_airflow.vendor.python_operator import PythonOperator
def _pipeline_fn(
recon_repo,
pipeline_name,
operator,
run_config=None,
environment_yaml=None,
op_kwargs=None,
mode=None,
execution_date=timezone.utcnow(),
):
if run_config is None and environment_yaml is not None:
run_config = load_yaml_from_glob_list(environment_yaml)
dag, tasks = make_airflow_dag_for_operator(
recon_repo, pipeline_name, operator, run_config, mode=mode, op_kwargs=op_kwargs
)
assert isinstance(dag, DAG)
for task in tasks:
assert isinstance(task, PythonOperator)
return execute_tasks_in_dag(
dag, tasks, run_id=make_new_run_id(), execution_date=execution_date
)
return _pipeline_fn
@pytest.fixture(scope="function")
def dagster_airflow_docker_operator_pipeline():
"""This is a test fixture for running Dagster pipelines as containerized Airflow DAGs.
Usage:
from dagster_airflow_tests.test_fixtures import dagster_airflow_docker_operator_pipeline
def test_airflow(dagster_airflow_docker_operator_pipeline):
results = dagster_airflow_docker_operator_pipeline(
pipeline_name='test_pipeline',
recon_repo=reconstructable(define_pipeline),
environment_yaml=['environments/test_*.yaml'],
image='myimage:latest'
)
assert len(results) == 3
"""
from dagster_airflow.factory import make_airflow_dag_containerized_for_recon_repo
from dagster_airflow.operators.docker_operator import DagsterDockerOperator
def _pipeline_fn(
recon_repo,
pipeline_name,
image,
run_config=None,
environment_yaml=None,
op_kwargs=None,
mode=None,
execution_date=timezone.utcnow(),
):
if run_config is None and environment_yaml is not None:
run_config = load_yaml_from_glob_list(environment_yaml)
op_kwargs = op_kwargs or {}
op_kwargs["network_mode"] = "container:test-postgres-db-airflow"
with postgres_instance() as instance:
dag, tasks = make_airflow_dag_containerized_for_recon_repo(
recon_repo=recon_repo,
pipeline_name=pipeline_name,
image=image,
mode=mode,
run_config=run_config,
op_kwargs=op_kwargs,
instance=instance,
)
assert isinstance(dag, DAG)
for task in tasks:
assert isinstance(task, DagsterDockerOperator)
return execute_tasks_in_dag(
dag, tasks, run_id=make_new_run_id(), execution_date=execution_date
)
return _pipeline_fn
diff --git a/python_modules/libraries/dagster-airflow/setup.py b/python_modules/libraries/dagster-airflow/setup.py
index 2aa081aaf..521db97b9 100644
--- a/python_modules/libraries/dagster-airflow/setup.py
+++ b/python_modules/libraries/dagster-airflow/setup.py
@@ -1,43 +1,41 @@
from setuptools import find_packages, setup
def get_version():
version = {}
with open("dagster_airflow/version.py") as fp:
exec(fp.read(), version) # pylint: disable=W0122
return version["__version__"]
if __name__ == "__main__":
ver = get_version()
setup(
name="dagster-airflow",
version=ver,
author="Elementl",
author_email="hello@elementl.com",
license="Apache-2.0",
description="Airflow plugin for Dagster",
url="https://github.com/dagster-io/dagster",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["dagster_airflow_tests"]),
install_requires=[
"six",
"dagster=={ver}".format(ver=ver),
"docker",
"python-dateutil>=2.8.0",
"lazy_object_proxy",
"pendulum==1.4.4",
- # RSA 4.1+ is incompatible with py2.7
- 'rsa<=4.0; python_version<"3"',
# https://issues.apache.org/jira/browse/AIRFLOW-6854
'typing_extensions; python_version>="3.8"',
],
extras_require={"kubernetes": ["kubernetes>=3.0.0", "cryptography>=2.0.0"]},
entry_points={"console_scripts": ["dagster-airflow = dagster_airflow.cli:main"]},
)
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py b/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py
index afcdc99ce..9f18c1ebb 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/athena/resources.py
@@ -1,267 +1,267 @@
import csv
import io
import os
import time
import uuid
+from urllib.parse import urlparse
import boto3
from botocore.stub import Stubber
from dagster import Field, StringSource, check, resource
-from dagster.seven import urlparse
class AthenaError(Exception):
pass
class AthenaTimeout(AthenaError):
pass
class AthenaResource:
def __init__(self, client, workgroup="primary", polling_interval=5, max_polls=120):
check.invariant(
polling_interval >= 0, "polling_interval must be greater than or equal to 0"
)
check.invariant(max_polls > 0, "max_polls must be greater than 0")
self.client = client
self.workgroup = workgroup
self.max_polls = max_polls
self.polling_interval = polling_interval
def execute_query(self, query, fetch_results=False):
"""Synchronously execute a single query against Athena. If fetch_results is set to true,
will return a list of rows, where each row is a tuple of stringified values,
e.g. SELECT 1 will return [("1",)].
Args:
query (str): The query to execute.
fetch_results (Optional[bool]): Whether to return the results of executing the query.
Defaults to False, in which case the query will be executed without retrieving the
results.
Returns:
Optional[List[Tuple[Optional[str], ...]]]: Results of the query, as a list of tuples,
when fetch_results is set. Otherwise, return None. All items in the tuple are
represented as strings except for empty columns which are represented as None.
"""
check.str_param(query, "query")
check.bool_param(fetch_results, "fetch_results")
execution_id = self.client.start_query_execution(
QueryString=query, WorkGroup=self.workgroup
)["QueryExecutionId"]
self._poll(execution_id)
if fetch_results:
return self._results(execution_id)
def _poll(self, execution_id):
retries = self.max_polls
state = "QUEUED"
while retries > 0:
execution = self.client.get_query_execution(QueryExecutionId=execution_id)[
"QueryExecution"
]
state = execution["Status"]["State"]
if state not in ["QUEUED", "RUNNING"]:
break
retries -= 1
time.sleep(self.polling_interval)
if retries <= 0:
raise AthenaTimeout()
if state != "SUCCEEDED":
raise AthenaError(execution["Status"]["StateChangeReason"])
def _results(self, execution_id):
execution = self.client.get_query_execution(QueryExecutionId=execution_id)["QueryExecution"]
s3 = boto3.resource("s3")
output_location = execution["ResultConfiguration"]["OutputLocation"]
bucket = urlparse(output_location).netloc
prefix = urlparse(output_location).path.lstrip("/")
results = []
rows = s3.Bucket(bucket).Object(prefix).get()["Body"].read().decode().splitlines()
reader = csv.reader(rows)
next(reader) # Skip the CSV's header row
for row in reader:
results.append(tuple(row))
return results
class FakeAthenaResource(AthenaResource):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.polling_interval = 0
self.stubber = Stubber(self.client)
s3 = boto3.resource("s3", region_name="us-east-1")
self.bucket = s3.Bucket("fake-athena-results-bucket")
self.bucket.create()
def execute_query(
self, query, fetch_results=False, expected_states=None, expected_results=None
): # pylint: disable=arguments-differ
"""Fake for execute_query; stubs the expected Athena endpoints, polls against the provided
expected query execution states, and returns the provided results as a list of tuples.
Args:
query (str): The query to execute.
fetch_results (Optional[bool]): Whether to return the results of executing the query.
Defaults to False, in which case the query will be executed without retrieving the
results.
expected_states (list[str]): The expected query execution states.
Defaults to successfully passing through QUEUED, RUNNING, and SUCCEEDED.
expected_results (List[Tuple[Any, ...]]): The expected results. All non-None items
are cast to strings.
Defaults to [(1,)].
Returns:
Optional[List[Tuple[Optional[str], ...]]]: The expected_results when fetch_results is
set. Otherwise, return None. All items in the tuple are represented as strings except
for empty columns which are represented as None.
"""
if not expected_states:
expected_states = ["QUEUED", "RUNNING", "SUCCEEDED"]
if not expected_results:
expected_results = [("1",)]
self.stubber.activate()
execution_id = str(uuid.uuid4())
self._stub_start_query_execution(execution_id, query)
self._stub_get_query_execution(execution_id, expected_states)
if expected_states[-1] == "SUCCEEDED" and fetch_results:
self._fake_results(execution_id, expected_results)
result = super().execute_query(query, fetch_results=fetch_results)
self.stubber.deactivate()
self.stubber.assert_no_pending_responses()
return result
def _stub_start_query_execution(self, execution_id, query):
self.stubber.add_response(
method="start_query_execution",
service_response={"QueryExecutionId": execution_id},
expected_params={"QueryString": query, "WorkGroup": self.workgroup},
)
def _stub_get_query_execution(self, execution_id, states):
for state in states:
self.stubber.add_response(
method="get_query_execution",
service_response={
"QueryExecution": {
"Status": {"State": state, "StateChangeReason": "state change reason"},
}
},
expected_params={"QueryExecutionId": execution_id},
)
def _fake_results(self, execution_id, expected_results):
with io.StringIO() as results:
writer = csv.writer(results)
# Athena adds a header row to its CSV output
writer.writerow([])
for row in expected_results:
# Athena writes all non-null columns as strings in its CSV output
stringified = tuple([str(item) for item in row if item])
writer.writerow(stringified)
results.seek(0)
self.bucket.Object(execution_id + ".csv").put(Body=results.read())
self.stubber.add_response(
method="get_query_execution",
service_response={
"QueryExecution": {
"ResultConfiguration": {
"OutputLocation": os.path.join(
"s3://", self.bucket.name, execution_id + ".csv"
)
}
}
},
expected_params={"QueryExecutionId": execution_id},
)
def athena_config():
"""Athena configuration."""
return {
"workgroup": Field(
str,
description="The Athena WorkGroup. https://docs.aws.amazon.com/athena/latest/ug/manage-queries-control-costs-with-workgroups.html",
is_required=False,
default_value="primary",
),
"polling_interval": Field(
int,
description="Time in seconds between checks to see if a query execution is finished. 5 seconds by default. Must be non-negative.",
is_required=False,
default_value=5,
),
"max_polls": Field(
int,
description="Number of times to poll before timing out. 120 attempts by default. When coupled with the default polling_interval, queries will timeout after 10 minutes (120 * 5 seconds). Must be greater than 0.",
is_required=False,
default_value=120,
),
"aws_access_key_id": Field(StringSource, is_required=False),
"aws_secret_access_key": Field(StringSource, is_required=False),
}
@resource(
config_schema=athena_config(), description="Resource for connecting to AWS Athena",
)
def athena_resource(context):
"""This resource enables connecting to AWS Athena and issuing queries against it.
Example:
.. code-block:: python
from dagster import ModeDefinition, execute_solid, solid
from dagster_aws.athena import athena_resource
@solid(required_resource_keys={"athena"})
def example_athena_solid(context):
return context.resources.athena.execute_query("SELECT 1", fetch_results=True)
result = execute_solid(
example_athena_solid, mode_def=ModeDefinition(resource_defs={"athena": athena_resource}),
)
assert result.output_value() == [("1",)]
"""
client = boto3.client(
"athena",
aws_access_key_id=context.resource_config.get("aws_access_key_id"),
aws_secret_access_key=context.resource_config.get("aws_secret_access_key"),
)
return AthenaResource(
client=client,
workgroup=context.resource_config.get("workgroup"),
polling_interval=context.resource_config.get("polling_interval"),
max_polls=context.resource_config.get("max_polls"),
)
@resource(
config_schema=athena_config(), description="Fake resource for connecting to AWS Athena",
)
def fake_athena_resource(context):
return FakeAthenaResource(
client=boto3.client("athena", region_name="us-east-1"),
workgroup=context.resource_config.get("workgroup"),
polling_interval=context.resource_config.get("polling_interval"),
max_polls=context.resource_config.get("max_polls"),
)
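The fake resource above exists so Athena-backed solids can be tested offline: it stubs the Athena API calls and writes fake CSV results to a fake S3 bucket. A minimal sketch of exercising it directly, assuming the moto package is available to mock S3 (as the dagster-aws test suite does); the credentials and region below are dummies.
# Sketch only: run a stubbed query through FakeAthenaResource without touching AWS.
import os

import boto3
from moto import mock_s3

from dagster_aws.athena.resources import FakeAthenaResource

# Dummy credentials so boto3/moto never look for a real profile.
os.environ.setdefault("AWS_ACCESS_KEY_ID", "test")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "test")

with mock_s3():
    fake = FakeAthenaResource(client=boto3.client("athena", region_name="us-east-1"))
    # The default stubbed execution walks QUEUED -> RUNNING -> SUCCEEDED and
    # returns [("1",)], matching the docstring above.
    assert fake.execute_query("SELECT 1", fetch_results=True) == [("1",)]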
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py b/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py
index cf350dc7d..628990547 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/emr/emr.py
@@ -1,425 +1,425 @@
# Portions of this file are copied from the Yelp MRJob project:
#
# https://github.com/Yelp/mrjob
#
#
# Copyright 2009-2013 Yelp, David Marin
# Copyright 2015 Yelp
# Copyright 2017 Yelp
# Copyright 2018 Contributors
# Copyright 2019 Yelp and Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import re
from io import BytesIO
+from urllib.parse import urlparse
import boto3
import dagster
import six
from botocore.exceptions import WaiterError
from dagster import check
-from dagster.seven import urlparse
from dagster_aws.utils.mrjob.utils import _boto3_now, _wrap_aws_client, strip_microseconds
from .types import EMR_CLUSTER_TERMINATED_STATES, EmrClusterState, EmrStepState
# if we can't create or find our own service role, use the one
# created by the AWS console and CLI
_FALLBACK_SERVICE_ROLE = "EMR_DefaultRole"
# if we can't create or find our own instance profile, use the one
# created by the AWS console and CLI
_FALLBACK_INSTANCE_PROFILE = "EMR_EC2_DefaultRole"
class EmrError(Exception):
pass
class EmrJobRunner:
def __init__(
self, region, check_cluster_every=30, aws_access_key_id=None, aws_secret_access_key=None,
):
"""This object encapsulates various utilities for interacting with EMR clusters and invoking
steps (jobs) on them.
See also :py:class:`~dagster_aws.emr.EmrPySparkResource`, which wraps this job runner in a
resource for pyspark workloads.
Args:
region (str): AWS region to use
check_cluster_every (int, optional): How frequently to poll boto3 APIs for updates.
Defaults to 30 seconds.
aws_access_key_id (str, optional): AWS access key ID. Defaults to None, which will
use the default boto3 credentials chain.
aws_secret_access_key (str, optional): AWS secret access key. Defaults to None, which
will use the default boto3 credentials chain.
"""
self.region = check.str_param(region, "region")
# This is in seconds
self.check_cluster_every = check.int_param(check_cluster_every, "check_cluster_every")
self.aws_access_key_id = check.opt_str_param(aws_access_key_id, "aws_access_key_id")
self.aws_secret_access_key = check.opt_str_param(
aws_secret_access_key, "aws_secret_access_key"
)
def make_emr_client(self):
"""Creates a boto3 EMR client. Construction is wrapped in retries in case client connection
fails transiently.
Returns:
botocore.client.EMR: An EMR client
"""
raw_emr_client = boto3.client(
"emr",
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
region_name=self.region,
)
return _wrap_aws_client(raw_emr_client, min_backoff=self.check_cluster_every)
def cluster_id_from_name(self, cluster_name):
"""Get a cluster ID in the format "j-123ABC123ABC1" given a cluster name "my cool cluster".
Args:
cluster_name (str): The name of the cluster for which to find an ID
Returns:
str: The ID of the cluster
Raises:
EmrError: No cluster with the specified name exists
"""
check.str_param(cluster_name, "cluster_name")
response = self.make_emr_client().list_clusters().get("Clusters", [])
for cluster in response:
if cluster["Name"] == cluster_name:
return cluster["Id"]
raise EmrError(
"cluster {cluster_name} not found in region {region}".format(
cluster_name=cluster_name, region=self.region
)
)
@staticmethod
def construct_step_dict_for_command(step_name, command, action_on_failure="CONTINUE"):
"""Construct an EMR step definition which uses command-runner.jar to execute a shell command
on the EMR master.
Args:
step_name (str): The name of the EMR step (will show up in the EMR UI)
command (str): The shell command to execute with command-runner.jar
action_on_failure (str, optional): Configure action on failure (e.g., continue, or
terminate the cluster). Defaults to 'CONTINUE'.
Returns:
dict: Step definition dict
"""
check.str_param(step_name, "step_name")
check.list_param(command, "command", of_type=str)
check.str_param(action_on_failure, "action_on_failure")
return {
"Name": step_name,
"ActionOnFailure": action_on_failure,
"HadoopJarStep": {"Jar": "command-runner.jar", "Args": command},
}
def add_tags(self, log, tags, cluster_id):
"""Add tags in the dict tags to cluster cluster_id.
Args:
log (DagsterLogManager): Log manager, for logging
tags (dict): Dictionary of {'key': 'value'} tags
cluster_id (str): The ID of the cluster to tag
"""
check.dict_param(tags, "tags")
check.str_param(cluster_id, "cluster_id")
tags_items = sorted(tags.items())
self.make_emr_client().add_tags(
ResourceId=cluster_id, Tags=[dict(Key=k, Value=v) for k, v in tags_items]
)
log.info(
"Added EMR tags to cluster %s: %s"
% (cluster_id, ", ".join("%s=%s" % (tag, value) for tag, value in tags_items))
)
def run_job_flow(self, log, cluster_config):
"""Create an empty cluster on EMR, and return the ID of that job flow.
Args:
log (DagsterLogManager): Log manager, for logging
cluster_config (dict): Configuration for this EMR job flow. See:
https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html
Returns:
str: The cluster ID, e.g. "j-ZKIY4CKQRX72"
"""
check.dict_param(cluster_config, "cluster_config")
log.debug("Creating Elastic MapReduce cluster")
emr_client = self.make_emr_client()
log.debug(
"Calling run_job_flow(%s)"
% (", ".join("%s=%r" % (k, v) for k, v in sorted(cluster_config.items())))
)
cluster_id = emr_client.run_job_flow(**cluster_config)["JobFlowId"]
log.info("Created new cluster %s" % cluster_id)
# set EMR tags for the cluster
tags = cluster_config.get("Tags", {})
tags["__dagster_version"] = dagster.__version__
self.add_tags(log, tags, cluster_id)
return cluster_id
def describe_cluster(self, cluster_id):
"""Thin wrapper over boto3 describe_cluster.
Args:
cluster_id (str): Cluster to inspect
Returns:
dict: The cluster info. See:
https://docs.aws.amazon.com/emr/latest/APIReference/API_DescribeCluster.html
"""
check.str_param(cluster_id, "cluster_id")
emr_client = self.make_emr_client()
return emr_client.describe_cluster(ClusterId=cluster_id)
def describe_step(self, cluster_id, step_id):
"""Thin wrapper over boto3 describe_step.
Args:
cluster_id (str): Cluster to inspect
step_id (str): Step ID to describe
Returns:
dict: The step info. See:
https://docs.aws.amazon.com/emr/latest/APIReference/API_DescribeStep.html
"""
check.str_param(cluster_id, "cluster_id")
check.str_param(step_id, "step_id")
emr_client = self.make_emr_client()
return emr_client.describe_step(ClusterId=cluster_id, StepId=step_id)
def add_job_flow_steps(self, log, cluster_id, step_defs):
"""Submit the constructed job flow steps to EMR for execution.
Args:
log (DagsterLogManager): Log manager, for logging
cluster_id (str): The ID of the cluster
step_defs (List[dict]): List of steps; see also `construct_step_dict_for_command`
Returns:
List[str]: list of step IDs.
"""
check.str_param(cluster_id, "cluster_id")
check.list_param(step_defs, "step_defs", of_type=dict)
emr_client = self.make_emr_client()
steps_kwargs = dict(JobFlowId=cluster_id, Steps=step_defs)
log.debug(
"Calling add_job_flow_steps(%s)"
% ",".join(("%s=%r" % (k, v)) for k, v in steps_kwargs.items())
)
return emr_client.add_job_flow_steps(**steps_kwargs)["StepIds"]
def is_emr_step_complete(self, log, cluster_id, emr_step_id):
step = self.describe_step(cluster_id, emr_step_id)["Step"]
step_state = EmrStepState(step["Status"]["State"])
if step_state == EmrStepState.Pending:
cluster = self.describe_cluster(cluster_id)["Cluster"]
reason = _get_reason(cluster)
reason_desc = (": %s" % reason) if reason else ""
log.info("PENDING (cluster is %s%s)" % (cluster["Status"]["State"], reason_desc))
return False
elif step_state == EmrStepState.Running:
time_running_desc = ""
start = step["Status"]["Timeline"].get("StartDateTime")
if start:
time_running_desc = " for %s" % strip_microseconds(_boto3_now() - start)
log.info("RUNNING%s" % time_running_desc)
return False
# we're done, will return at the end of this
elif step_state == EmrStepState.Completed:
log.info("COMPLETED")
return True
else:
# step has failed somehow. *reason* seems to only be set
# when job is cancelled (e.g. 'Job terminated')
reason = _get_reason(step)
reason_desc = (" (%s)" % reason) if reason else ""
log.info("%s%s" % (step_state.value, reason_desc))
# print cluster status; this might give more context on
# why the step didn't succeed
cluster = self.describe_cluster(cluster_id)["Cluster"]
reason = _get_reason(cluster)
reason_desc = (": %s" % reason) if reason else ""
log.info(
"Cluster %s %s %s%s"
% (
cluster["Id"],
"was" if "ED" in cluster["Status"]["State"] else "is",
cluster["Status"]["State"],
reason_desc,
)
)
if EmrClusterState(cluster["Status"]["State"]) in EMR_CLUSTER_TERMINATED_STATES:
# was it caused by IAM roles?
self._check_for_missing_default_iam_roles(log, cluster)
# TODO: extract logs here to surface failure reason
# See: https://github.com/dagster-io/dagster/issues/1954
if step_state == EmrStepState.Failed:
log.error("EMR step %s failed" % emr_step_id)
raise EmrError("EMR step %s failed" % emr_step_id)
def _check_for_missing_default_iam_roles(self, log, cluster):
"""If cluster couldn't start due to missing IAM roles, tell user what to do."""
check.dict_param(cluster, "cluster")
reason = _get_reason(cluster)
if any(
reason.endswith("/%s is invalid" % role)
for role in (_FALLBACK_INSTANCE_PROFILE, _FALLBACK_SERVICE_ROLE)
):
log.warning(
"IAM roles are missing. See documentation for IAM roles on EMR here: "
"https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-iam-roles.html"
)
def log_location_for_cluster(self, cluster_id):
"""EMR clusters are typically launched with S3 logging configured. This method inspects a
cluster using boto3 describe_cluster to retrieve the log URI.
Args:
cluster_id (str): The cluster to inspect.
Raises:
EmrError: the log URI was missing (S3 log mirroring not enabled for this cluster)
Returns:
(str, str): log bucket and key
"""
check.str_param(cluster_id, "cluster_id")
# The S3 log URI is specified per job flow (cluster)
log_uri = self.describe_cluster(cluster_id)["Cluster"].get("LogUri", None)
# ugh, seriously boto3?! This will come back as string "None"
if log_uri == "None" or log_uri is None:
raise EmrError("Log URI not specified, cannot retrieve step execution logs")
# For some reason the API returns an s3n:// protocol log URI instead of s3://
log_uri = re.sub("^s3n", "s3", log_uri)
log_uri_parsed = urlparse(log_uri)
log_bucket = log_uri_parsed.netloc
log_key_prefix = log_uri_parsed.path.lstrip("/")
return log_bucket, log_key_prefix
def retrieve_logs_for_step_id(self, log, cluster_id, step_id):
"""Retrieves stdout and stderr logs for the given step ID.
Args:
log (DagsterLogManager): Log manager, for logging
cluster_id (str): EMR cluster ID
step_id (str): EMR step ID for the job that was submitted.
Returns:
(str, str): Tuple of stdout log string contents, and stderr log string contents
"""
check.str_param(cluster_id, "cluster_id")
check.str_param(step_id, "step_id")
log_bucket, log_key_prefix = self.log_location_for_cluster(cluster_id)
prefix = "{log_key_prefix}{cluster_id}/steps/{step_id}".format(
log_key_prefix=log_key_prefix, cluster_id=cluster_id, step_id=step_id
)
stdout_log = self.wait_for_log(log, log_bucket, "{prefix}/stdout.gz".format(prefix=prefix))
stderr_log = self.wait_for_log(log, log_bucket, "{prefix}/stderr.gz".format(prefix=prefix))
return stdout_log, stderr_log
def wait_for_log(self, log, log_bucket, log_key, waiter_delay=30, waiter_max_attempts=20):
"""Wait for gzipped EMR logs to appear on S3. Note that EMR syncs logs to S3 every 5
minutes, so this may take a long time.
Args:
log_bucket (str): S3 bucket where log is expected to appear
log_key (str): S3 key for the log file
waiter_delay (int): How long to wait between attempts to check S3 for the log file
waiter_max_attempts (int): Number of attempts before giving up on waiting
Raises:
EmrError: Raised if we waited the full duration and the logs did not appear
Returns:
str: contents of the log file
"""
check.str_param(log_bucket, "log_bucket")
check.str_param(log_key, "log_key")
check.int_param(waiter_delay, "waiter_delay")
check.int_param(waiter_max_attempts, "waiter_max_attempts")
log.info(
"Attempting to get log: s3://{log_bucket}/{log_key}".format(
log_bucket=log_bucket, log_key=log_key
)
)
s3 = _wrap_aws_client(boto3.client("s3"), min_backoff=self.check_cluster_every)
waiter = s3.get_waiter("object_exists")
try:
waiter.wait(
Bucket=log_bucket,
Key=log_key,
WaiterConfig={"Delay": waiter_delay, "MaxAttempts": waiter_max_attempts},
)
except WaiterError as err:
six.raise_from(
EmrError("EMR log file did not appear on S3 after waiting"), err,
)
obj = BytesIO(s3.get_object(Bucket=log_bucket, Key=log_key)["Body"].read())
gzip_file = gzip.GzipFile(fileobj=obj)
return gzip_file.read().decode("utf-8")
def _get_reason(cluster_or_step):
"""Get state change reason message."""
# StateChangeReason is {} before the first state change
return cluster_or_step["Status"]["StateChangeReason"].get("Message", "")
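A quick sketch of what construct_step_dict_for_command above produces; the step name and shell command are illustrative, and nothing is submitted to EMR here (in real use the dict would be passed to add_job_flow_steps).
# Sketch only: build (but do not submit) an EMR step definition with the helper above.
from dagster_aws.emr import EmrJobRunner

step_def = EmrJobRunner.construct_step_dict_for_command(
    step_name="say hello",
    command=["/bin/bash", "-c", "echo hello"],
)
assert step_def == {
    "Name": "say hello",
    "ActionOnFailure": "CONTINUE",  # the default; the step launcher below uses "CANCEL_AND_WAIT"
    "HadoopJarStep": {"Jar": "command-runner.jar", "Args": ["/bin/bash", "-c", "echo hello"]},
}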
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/emr/pyspark_step_launcher.py b/python_modules/libraries/dagster-aws/dagster_aws/emr/pyspark_step_launcher.py
index 068433558..146c7c2ff 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/emr/pyspark_step_launcher.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/emr/pyspark_step_launcher.py
@@ -1,342 +1,343 @@
import os
import pickle
+import tempfile
import time
import boto3
from botocore.exceptions import ClientError
-from dagster import Field, StringSource, check, resource, seven
+from dagster import Field, StringSource, check, resource
from dagster.core.definitions.step_launcher import StepLauncher
from dagster.core.errors import raise_execution_interrupts
from dagster.core.events import log_step_event
from dagster.core.execution.plan.external_step import (
PICKLED_EVENTS_FILE_NAME,
PICKLED_STEP_RUN_REF_FILE_NAME,
step_context_to_step_run_ref,
)
from dagster_aws.emr import EmrError, EmrJobRunner, emr_step_main
from dagster_aws.emr.configs_spark import spark_config as get_spark_config
from dagster_aws.utils.mrjob.log4j import parse_hadoop_log4j_records
# On EMR, Spark is installed here
EMR_SPARK_HOME = "/usr/lib/spark/"
CODE_ZIP_NAME = "code.zip"
@resource(
{
"spark_config": get_spark_config(),
"cluster_id": Field(
StringSource, description="Name of the job flow (cluster) on which to execute."
),
"region_name": Field(StringSource, description="The AWS region that the cluster is in."),
"action_on_failure": Field(
str,
is_required=False,
default_value="CANCEL_AND_WAIT",
description="The EMR action to take when the cluster step fails: "
"https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html",
),
"staging_bucket": Field(
StringSource,
is_required=True,
description="S3 bucket to use for passing files between the plan process and EMR "
"process.",
),
"staging_prefix": Field(
StringSource,
is_required=False,
default_value="emr_staging",
description="S3 key prefix inside the staging_bucket to use for files passed the plan "
"process and EMR process",
),
"wait_for_logs": Field(
bool,
is_required=False,
default_value=False,
description="If set, the system will wait for EMR logs to appear on S3. Note that logs "
"are copied every 5 minutes, so enabling this will add several minutes to the job "
"runtime.",
),
"local_pipeline_package_path": Field(
StringSource,
is_required=True,
description="Absolute path to the package that contains the pipeline definition(s) "
"whose steps will execute remotely on EMR. This is a path on the local fileystem of "
"the process executing the pipeline. The expectation is that this package will also be "
"available on the python path of the launched process running the Spark step on EMR, "
"either deployed on step launch via the deploy_pipeline_package option, referenced on "
"s3 via the s3_pipeline_package_path option, or installed on the cluster via bootstrap "
"actions.",
),
"deploy_local_pipeline_package": Field(
bool,
default_value=False,
is_required=False,
description="If set, before every step run, the launcher will zip up all the code in "
"local_pipeline_package_path, upload it to s3, and pass it to spark-submit's "
"--py-files option. This gives the remote process access to up-to-date user code. "
"If not set, the assumption is that some other mechanism is used for distributing code "
"to the EMR cluster. If this option is set to True, s3_pipeline_package_path should "
"not also be set.",
),
"s3_pipeline_package_path": Field(
StringSource,
is_required=False,
description="If set, this path will be passed to the --py-files option of spark-submit. "
"This should usually be a path to a zip file. If this option is set, "
"deploy_local_pipeline_package should not be set to True.",
),
}
)
def emr_pyspark_step_launcher(context):
return EmrPySparkStepLauncher(**context.resource_config)
emr_pyspark_step_launcher.__doc__ = "\n".join(
"- **" + option + "**: " + (field.description or "")
for option, field in emr_pyspark_step_launcher.config_schema.config_type.fields.items()
)
class EmrPySparkStepLauncher(StepLauncher):
def __init__(
self,
region_name,
staging_bucket,
staging_prefix,
wait_for_logs,
action_on_failure,
cluster_id,
spark_config,
local_pipeline_package_path,
deploy_local_pipeline_package,
s3_pipeline_package_path=None,
):
self.region_name = check.str_param(region_name, "region_name")
self.staging_bucket = check.str_param(staging_bucket, "staging_bucket")
self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")
self.action_on_failure = check.str_param(action_on_failure, "action_on_failure")
self.cluster_id = check.str_param(cluster_id, "cluster_id")
self.spark_config = spark_config
check.invariant(
not deploy_local_pipeline_package or not s3_pipeline_package_path,
"If deploy_local_pipeline_package is set to True, s3_pipeline_package_path should not "
"also be set.",
)
self.local_pipeline_package_path = check.str_param(
local_pipeline_package_path, "local_pipeline_package_path"
)
self.deploy_local_pipeline_package = check.bool_param(
deploy_local_pipeline_package, "deploy_local_pipeline_package"
)
self.s3_pipeline_package_path = check.opt_str_param(
s3_pipeline_package_path, "s3_pipeline_package_path"
)
self.emr_job_runner = EmrJobRunner(region=self.region_name)
def _post_artifacts(self, log, step_run_ref, run_id, step_key):
"""
Synchronize the step run ref and pyspark code to an S3 staging bucket for use on EMR.
For the zip file, consider the following toy example:
# Folder: my_pyspark_project/
# a.py
def foo():
print(1)
# b.py
def bar():
print(2)
# main.py
from a import foo
from b import bar
foo()
bar()
This will zip up `my_pyspark_project/` as `my_pyspark_project.zip`. Then, when running
`spark-submit --py-files my_pyspark_project.zip emr_step_main.py` on EMR this will
print 1, 2.
"""
from dagster_pyspark.utils import build_pyspark_zip
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
s3 = boto3.client("s3", region_name=self.region_name)
# Upload step run ref
def _upload_file_to_s3(local_path, s3_filename):
key = self._artifact_s3_key(run_id, step_key, s3_filename)
s3_uri = self._artifact_s3_uri(run_id, step_key, s3_filename)
log.debug(
"Uploading file {local_path} to {s3_uri}".format(
local_path=local_path, s3_uri=s3_uri
)
)
s3.upload_file(Filename=local_path, Bucket=self.staging_bucket, Key=key)
# Upload main file.
# The remote Dagster installation should also have the file, but locating it there
# could be a pain.
main_local_path = self._main_file_local_path()
_upload_file_to_s3(main_local_path, self._main_file_name())
if self.deploy_local_pipeline_package:
# Zip and upload package containing pipeline
zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
build_pyspark_zip(zip_local_path, self.local_pipeline_package_path)
_upload_file_to_s3(zip_local_path, CODE_ZIP_NAME)
# Create step run ref pickle file
step_run_ref_local_path = os.path.join(temp_dir, PICKLED_STEP_RUN_REF_FILE_NAME)
with open(step_run_ref_local_path, "wb") as step_pickle_file:
pickle.dump(step_run_ref, step_pickle_file)
_upload_file_to_s3(step_run_ref_local_path, PICKLED_STEP_RUN_REF_FILE_NAME)
def launch_step(self, step_context, prior_attempts_count):
step_run_ref = step_context_to_step_run_ref(
step_context, prior_attempts_count, self.local_pipeline_package_path
)
run_id = step_context.pipeline_run.run_id
log = step_context.log
step_key = step_run_ref.step_key
self._post_artifacts(log, step_run_ref, run_id, step_key)
emr_step_def = self._get_emr_step_def(run_id, step_key, step_context.solid.name)
emr_step_id = self.emr_job_runner.add_job_flow_steps(log, self.cluster_id, [emr_step_def])[
0
]
return self.wait_for_completion_and_log(log, run_id, step_key, emr_step_id, step_context)
def wait_for_completion_and_log(self, log, run_id, step_key, emr_step_id, step_context):
s3 = boto3.resource("s3", region_name=self.region_name)
try:
for event in self.wait_for_completion(log, s3, run_id, step_key, emr_step_id):
log_step_event(step_context, event)
yield event
except EmrError as emr_error:
if self.wait_for_logs:
self._log_logs_from_s3(log, emr_step_id)
raise emr_error
if self.wait_for_logs:
self._log_logs_from_s3(log, emr_step_id)
def wait_for_completion(self, log, s3, run_id, step_key, emr_step_id, check_interval=15):
""" We want to wait for the EMR steps to complete, and while that's happening, we want to
yield any events that have been written to S3 for us by the remote process.
After the EMR steps complete, we want a final chance to fetch events before finishing
the step.
"""
done = False
all_events = []
# If this is being called within a `capture_interrupts` context, allow interrupts
# while waiting for the pyspark execution to complete, so that we can terminate slow or
# hanging steps
while not done:
with raise_execution_interrupts():
time.sleep(check_interval) # AWS rate-limits us if we poll it too often
done = self.emr_job_runner.is_emr_step_complete(log, self.cluster_id, emr_step_id)
all_events_new = self.read_events(s3, run_id, step_key)
if len(all_events_new) > len(all_events):
for i in range(len(all_events), len(all_events_new)):
yield all_events_new[i]
all_events = all_events_new
def read_events(self, s3, run_id, step_key):
events_s3_obj = s3.Object( # pylint: disable=no-member
self.staging_bucket, self._artifact_s3_key(run_id, step_key, PICKLED_EVENTS_FILE_NAME)
)
try:
events_data = events_s3_obj.get()["Body"].read()
return pickle.loads(events_data)
except ClientError as ex:
# The file might not be there yet, which is fine
if ex.response["Error"]["Code"] == "NoSuchKey":
return []
else:
raise ex
def _log_logs_from_s3(self, log, emr_step_id):
"""Retrieves the logs from the remote PySpark process that EMR posted to S3 and logs
them to the given log."""
stdout_log, stderr_log = self.emr_job_runner.retrieve_logs_for_step_id(
log, self.cluster_id, emr_step_id
)
# Since stderr is YARN / Hadoop Log4J output, parse and reformat those log lines for
# Dagster's logging system.
records = parse_hadoop_log4j_records(stderr_log)
for record in records:
log._log( # pylint: disable=protected-access
record.level,
"".join(["Spark Driver stderr: ", record.logger, ": ", record.message]),
{},
)
log.info("Spark Driver stdout: " + stdout_log)
def _get_emr_step_def(self, run_id, step_key, solid_name):
"""From the local Dagster instance, construct EMR steps that will kick off execution on a
remote EMR cluster.
"""
from dagster_spark.utils import flatten_dict, format_for_cli
action_on_failure = self.action_on_failure
# Execute Solid via spark-submit
conf = dict(flatten_dict(self.spark_config))
conf["spark.app.name"] = conf.get("spark.app.name", solid_name)
check.invariant(
conf.get("spark.master", "yarn") == "yarn",
desc="spark.master is configured as %s; cannot set Spark master on EMR to anything "
'other than "yarn"' % conf.get("spark.master"),
)
command = (
[
EMR_SPARK_HOME + "bin/spark-submit",
"--master",
"yarn",
"--deploy-mode",
conf.get("spark.submit.deployMode", "client"),
]
+ format_for_cli(list(flatten_dict(conf)))
+ [
"--py-files",
self._artifact_s3_uri(run_id, step_key, CODE_ZIP_NAME),
self._artifact_s3_uri(run_id, step_key, self._main_file_name()),
self.staging_bucket,
self._artifact_s3_key(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
]
)
return EmrJobRunner.construct_step_dict_for_command(
"Execute Solid %s" % solid_name, command, action_on_failure=action_on_failure
)
def _main_file_name(self):
return os.path.basename(self._main_file_local_path())
def _main_file_local_path(self):
return emr_step_main.__file__
def _artifact_s3_uri(self, run_id, step_key, filename):
key = self._artifact_s3_key(run_id, step_key, filename)
return "s3://{bucket}/{key}".format(bucket=self.staging_bucket, key=key)
def _artifact_s3_key(self, run_id, step_key, filename):
return "/".join([self.staging_prefix, run_id, step_key, os.path.basename(filename)])
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py
index fc8d3ac89..caa9325ef 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/s3_tests/test_compute_log_manager.py
@@ -1,113 +1,114 @@
import os
import sys
+import tempfile
import six
-from dagster import DagsterEventType, execute_pipeline, pipeline, seven, solid
+from dagster import DagsterEventType, execute_pipeline, pipeline, solid
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher import DefaultRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.compute_log_manager import ComputeIOType
from dagster.core.storage.event_log import SqliteEventLogStorage
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
from dagster_aws.s3 import S3ComputeLogManager
HELLO_WORLD = "Hello World"
SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n"
EXPECTED_LOGS = [
'STEP_START - Started execution of step "easy".',
'STEP_OUTPUT - Yielded output "result" of type "Any"',
'STEP_SUCCESS - Finished execution of step "easy"',
]
def test_compute_log_manager(mock_s3_bucket):
@pipeline
def simple():
@solid
def easy(context):
context.log.info("easy")
print(HELLO_WORLD) # pylint: disable=print-call
return "easy"
easy()
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = SqliteEventLogStorage(temp_dir)
manager = S3ComputeLogManager(
bucket=mock_s3_bucket.name, prefix="my_prefix", local_dir=temp_dir
)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=manager,
run_coordinator=DefaultRunCoordinator(),
run_launcher=DefaultRunLauncher(),
)
result = execute_pipeline(simple, instance=instance)
compute_steps = [
event.step_key
for event in result.step_event_list
if event.event_type == DagsterEventType.STEP_START
]
assert len(compute_steps) == 1
step_key = compute_steps[0]
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
# Check S3 directly
s3_object = mock_s3_bucket.Object(
key="{prefix}/storage/{run_id}/compute_logs/easy.err".format(
prefix="my_prefix", run_id=result.run_id
),
)
stderr_s3 = six.ensure_str(s3_object.get()["Body"].read())
for expected in EXPECTED_LOGS:
assert expected in stderr_s3
# Check download behavior by deleting locally cached logs
compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs")
for filename in os.listdir(compute_logs_dir):
os.unlink(os.path.join(compute_logs_dir, filename))
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
def test_compute_log_manager_from_config(mock_s3_bucket):
s3_prefix = "foobar"
dagster_yaml = """
compute_logs:
module: dagster_aws.s3.compute_log_manager
class: S3ComputeLogManager
config:
bucket: "{s3_bucket}"
local_dir: "/tmp/cool"
prefix: "{s3_prefix}"
""".format(
s3_bucket=mock_s3_bucket.name, s3_prefix=s3_prefix
)
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f:
f.write(six.ensure_binary(dagster_yaml))
instance = DagsterInstance.from_config(tempdir)
assert (
instance.compute_log_manager._s3_bucket # pylint: disable=protected-access
== mock_s3_bucket.name
)
assert instance.compute_log_manager._s3_prefix == s3_prefix # pylint: disable=protected-access
diff --git a/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py
index e2508c28f..0a8f79832 100644
--- a/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py
+++ b/python_modules/libraries/dagster-azure/dagster_azure_tests/blob_tests/test_compute_log_manager.py
@@ -1,130 +1,131 @@
import os
import sys
+import tempfile
import six
-from dagster import DagsterEventType, execute_pipeline, pipeline, seven, solid
+from dagster import DagsterEventType, execute_pipeline, pipeline, solid
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.compute_log_manager import ComputeIOType
from dagster.core.storage.event_log import SqliteEventLogStorage
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
from dagster.seven import mock
from dagster_azure.blob import AzureBlobComputeLogManager, FakeBlobServiceClient
HELLO_WORLD = "Hello World"
SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n"
EXPECTED_LOGS = [
'STEP_START - Started execution of step "easy".',
'STEP_OUTPUT - Yielded output "result" of type "Any"',
'STEP_SUCCESS - Finished execution of step "easy"',
]
@mock.patch("dagster_azure.blob.compute_log_manager.generate_blob_sas")
@mock.patch("dagster_azure.blob.compute_log_manager.create_blob_client")
def test_compute_log_manager(
mock_create_blob_client, mock_generate_blob_sas, storage_account, container, credential
):
mock_generate_blob_sas.return_value = "fake-url"
fake_client = FakeBlobServiceClient(storage_account)
mock_create_blob_client.return_value = fake_client
@pipeline
def simple():
@solid
def easy(context):
context.log.info("easy")
print(HELLO_WORLD) # pylint: disable=print-call
return "easy"
easy()
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = SqliteEventLogStorage(temp_dir)
manager = AzureBlobComputeLogManager(
storage_account=storage_account,
container=container,
prefix="my_prefix",
local_dir=temp_dir,
secret_key=credential,
)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=manager,
run_coordinator=DefaultRunCoordinator(),
run_launcher=SyncInMemoryRunLauncher(),
)
result = execute_pipeline(simple, instance=instance)
compute_steps = [
event.step_key
for event in result.step_event_list
if event.event_type == DagsterEventType.STEP_START
]
assert len(compute_steps) == 1
step_key = compute_steps[0]
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
# Check ADLS2 directly
adls2_object = fake_client.get_blob_client(
container=container,
blob="{prefix}/storage/{run_id}/compute_logs/easy.err".format(
prefix="my_prefix", run_id=result.run_id
),
)
adls2_stderr = six.ensure_str(adls2_object.download_blob().readall())
for expected in EXPECTED_LOGS:
assert expected in adls2_stderr
# Check download behavior by deleting locally cached logs
compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs")
for filename in os.listdir(compute_logs_dir):
os.unlink(os.path.join(compute_logs_dir, filename))
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
def test_compute_log_manager_from_config(storage_account, container, credential):
prefix = "foobar"
dagster_yaml = """
compute_logs:
module: dagster_azure.blob.compute_log_manager
class: AzureBlobComputeLogManager
config:
storage_account: "{storage_account}"
container: {container}
secret_key: {credential}
local_dir: "/tmp/cool"
prefix: "{prefix}"
""".format(
storage_account=storage_account, container=container, credential=credential, prefix=prefix
)
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f:
f.write(six.ensure_binary(dagster_yaml))
instance = DagsterInstance.from_config(tempdir)
assert (
instance.compute_log_manager._storage_account # pylint: disable=protected-access
== storage_account
)
assert instance.compute_log_manager._container == container # pylint: disable=protected-access
assert instance.compute_log_manager._blob_prefix == prefix # pylint: disable=protected-access
diff --git a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_execute.py b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_execute.py
index 3a12c5fa4..fb1881827 100644
--- a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_execute.py
+++ b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_execute.py
@@ -1,314 +1,315 @@
# pylint doesn't know about pytest fixtures
# pylint: disable=unused-argument
import os
+import tempfile
from threading import Thread
import pytest
from dagster import (
CompositeSolidExecutionResult,
PipelineExecutionResult,
SolidExecutionResult,
execute_pipeline,
execute_pipeline_iterator,
seven,
)
from dagster.core.definitions.reconstructable import ReconstructablePipeline
from dagster.core.errors import DagsterExecutionInterruptedError, DagsterSubprocessError
from dagster.core.events import DagsterEventType
from dagster.core.test_utils import instance_for_test, instance_for_test_tempdir
from dagster.utils import send_interrupt
from dagster_celery_tests.repo import COMPOSITE_DEPTH
from dagster_celery_tests.utils import start_celery_worker
from .utils import ( # isort:skip
execute_eagerly_on_celery,
execute_pipeline_on_celery,
events_of_type,
REPO_FILE,
)
def test_execute_on_celery_default(dagster_celery_worker):
with execute_pipeline_on_celery("test_pipeline") as result:
assert result.result_for_solid("simple").output_value() == 1
assert len(result.step_event_list) == 4
assert len(events_of_type(result, "STEP_START")) == 1
assert len(events_of_type(result, "STEP_OUTPUT")) == 1
assert len(events_of_type(result, "OBJECT_STORE_OPERATION")) == 1
assert len(events_of_type(result, "STEP_SUCCESS")) == 1
def test_execute_serial_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("test_serial_pipeline") as result:
assert result.result_for_solid("simple").output_value() == 1
assert result.result_for_solid("add_one").output_value() == 2
assert len(result.step_event_list) == 10
assert len(events_of_type(result, "STEP_START")) == 2
assert len(events_of_type(result, "STEP_INPUT")) == 1
assert len(events_of_type(result, "STEP_OUTPUT")) == 2
assert len(events_of_type(result, "OBJECT_STORE_OPERATION")) == 3
assert len(events_of_type(result, "STEP_SUCCESS")) == 2
def test_execute_diamond_pipeline_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("test_diamond_pipeline") as result:
assert result.result_for_solid("emit_values").output_values == {
"value_one": 1,
"value_two": 2,
}
assert result.result_for_solid("add_one").output_value() == 2
assert result.result_for_solid("renamed").output_value() == 3
assert result.result_for_solid("subtract").output_value() == -1
def test_execute_parallel_pipeline_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("test_parallel_pipeline") as result:
assert len(result.solid_result_list) == 11
def test_execute_composite_pipeline_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("composite_pipeline") as result:
assert result.success
assert isinstance(result, PipelineExecutionResult)
assert len(result.solid_result_list) == 1
composite_solid_result = result.solid_result_list[0]
assert len(composite_solid_result.solid_result_list) == 2
for r in composite_solid_result.solid_result_list:
assert isinstance(r, CompositeSolidExecutionResult)
composite_solid_results = composite_solid_result.solid_result_list
for i in range(COMPOSITE_DEPTH):
next_level = []
assert len(composite_solid_results) == pow(2, i + 1)
for res in composite_solid_results:
assert isinstance(res, CompositeSolidExecutionResult)
for r in res.solid_result_list:
next_level.append(r)
composite_solid_results = next_level
assert len(composite_solid_results) == pow(2, COMPOSITE_DEPTH + 1)
assert all(
(isinstance(r, SolidExecutionResult) and r.success for r in composite_solid_results)
)
def test_execute_optional_outputs_pipeline_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("test_optional_outputs") as result:
assert len(result.solid_result_list) == 4
assert sum([int(x.skipped) for x in result.solid_result_list]) == 2
assert sum([int(x.success) for x in result.solid_result_list]) == 2
def test_execute_fails_pipeline_on_celery(dagster_celery_worker):
with execute_pipeline_on_celery("test_fails") as result:
assert len(result.solid_result_list) == 2 # fail & skip
assert not result.result_for_solid("fails").success
assert (
result.result_for_solid("fails").failure_data.error.message == "Exception: argjhgjh\n"
)
assert result.result_for_solid("should_never_execute").skipped
def test_terminate_pipeline_on_celery(rabbitmq):
with start_celery_worker():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
pipeline_def = ReconstructablePipeline.for_file(REPO_FILE, "interrupt_pipeline")
with instance_for_test_tempdir(tempdir) as instance:
run_config = {
"intermediate_storage": {"filesystem": {"config": {"base_dir": tempdir}}},
"execution": {"celery": {}},
}
results = []
result_types = []
interrupt_thread = None
received_interrupt = False
try:
for result in execute_pipeline_iterator(
pipeline=pipeline_def, run_config=run_config, instance=instance,
):
# Interrupt once the first step starts
if (
result.event_type == DagsterEventType.STEP_START
and not interrupt_thread
):
interrupt_thread = Thread(target=send_interrupt, args=())
interrupt_thread.start()
results.append(result)
result_types.append(result.event_type)
assert False
except DagsterExecutionInterruptedError:
received_interrupt = True
interrupt_thread.join()
assert received_interrupt
# At least one step succeeded (the one that was running when the interrupt fired)
assert DagsterEventType.STEP_SUCCESS in result_types
# At least one step was revoked (and there were no step failure events)
revoke_steps = [
result
for result in results
if result.event_type == DagsterEventType.ENGINE_EVENT
and "was revoked." in result.message
]
assert len(revoke_steps) > 0
# The overall pipeline failed
assert DagsterEventType.PIPELINE_FAILURE in result_types
def test_execute_eagerly_on_celery():
with instance_for_test() as instance:
with execute_eagerly_on_celery("test_pipeline", instance=instance) as result:
assert result.result_for_solid("simple").output_value() == 1
assert len(result.step_event_list) == 4
assert len(events_of_type(result, "STEP_START")) == 1
assert len(events_of_type(result, "STEP_OUTPUT")) == 1
assert len(events_of_type(result, "OBJECT_STORE_OPERATION")) == 1
assert len(events_of_type(result, "STEP_SUCCESS")) == 1
events = instance.all_logs(result.run_id)
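# Engine events carry paired start/end markers; collect their timestamps so we can verify
# that every marker that starts also ends, and that the end comes after the start.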
start_markers = {}
end_markers = {}
for event in events:
dagster_event = event.dagster_event
if dagster_event.is_engine_event:
if dagster_event.engine_event_data.marker_start:
key = "{step}.{marker}".format(
step=event.step_key, marker=dagster_event.engine_event_data.marker_start
)
start_markers[key] = event.timestamp
if dagster_event.engine_event_data.marker_end:
key = "{step}.{marker}".format(
step=event.step_key, marker=dagster_event.engine_event_data.marker_end
)
end_markers[key] = event.timestamp
seen = set()
assert set(start_markers.keys()) == set(end_markers.keys())
for key in end_markers:
assert end_markers[key] - start_markers[key] > 0
seen.add(key)
def test_execute_eagerly_serial_on_celery():
with execute_eagerly_on_celery("test_serial_pipeline") as result:
assert result.result_for_solid("simple").output_value() == 1
assert result.result_for_solid("add_one").output_value() == 2
assert len(result.step_event_list) == 10
assert len(events_of_type(result, "STEP_START")) == 2
assert len(events_of_type(result, "STEP_INPUT")) == 1
assert len(events_of_type(result, "STEP_OUTPUT")) == 2
assert len(events_of_type(result, "OBJECT_STORE_OPERATION")) == 3
assert len(events_of_type(result, "STEP_SUCCESS")) == 2
def test_execute_eagerly_diamond_pipeline_on_celery():
with execute_eagerly_on_celery("test_diamond_pipeline") as result:
assert result.result_for_solid("emit_values").output_values == {
"value_one": 1,
"value_two": 2,
}
assert result.result_for_solid("add_one").output_value() == 2
assert result.result_for_solid("renamed").output_value() == 3
assert result.result_for_solid("subtract").output_value() == -1
def test_execute_eagerly_diamond_pipeline_subset_on_celery():
with execute_eagerly_on_celery("test_diamond_pipeline", subset=["emit_values"]) as result:
assert result.result_for_solid("emit_values").output_values == {
"value_one": 1,
"value_two": 2,
}
assert len(result.solid_result_list) == 1
def test_execute_eagerly_parallel_pipeline_on_celery():
with execute_eagerly_on_celery("test_parallel_pipeline") as result:
assert len(result.solid_result_list) == 11
def test_execute_eagerly_composite_pipeline_on_celery():
with execute_eagerly_on_celery("composite_pipeline") as result:
assert result.success
assert isinstance(result, PipelineExecutionResult)
assert len(result.solid_result_list) == 1
composite_solid_result = result.solid_result_list[0]
assert len(composite_solid_result.solid_result_list) == 2
for r in composite_solid_result.solid_result_list:
assert isinstance(r, CompositeSolidExecutionResult)
composite_solid_results = composite_solid_result.solid_result_list
for i in range(COMPOSITE_DEPTH):
next_level = []
assert len(composite_solid_results) == pow(2, i + 1)
for res in composite_solid_results:
assert isinstance(res, CompositeSolidExecutionResult)
for r in res.solid_result_list:
next_level.append(r)
composite_solid_results = next_level
assert len(composite_solid_results) == pow(2, COMPOSITE_DEPTH + 1)
assert all(
(isinstance(r, SolidExecutionResult) and r.success for r in composite_solid_results)
)
def test_execute_eagerly_optional_outputs_pipeline_on_celery():
with execute_eagerly_on_celery("test_optional_outputs") as result:
assert len(result.solid_result_list) == 4
assert sum([int(x.skipped) for x in result.solid_result_list]) == 2
assert sum([int(x.success) for x in result.solid_result_list]) == 2
def test_execute_eagerly_resources_limit_pipeline_on_celery():
with execute_eagerly_on_celery("test_resources_limit") as result:
assert result.result_for_solid("resource_req_solid").success
assert result.success
def test_execute_eagerly_fails_pipeline_on_celery():
with execute_eagerly_on_celery("test_fails") as result:
assert len(result.solid_result_list) == 2
assert not result.result_for_solid("fails").success
assert (
result.result_for_solid("fails").failure_data.error.message == "Exception: argjhgjh\n"
)
assert result.result_for_solid("should_never_execute").skipped
def test_execute_eagerly_retries_pipeline_on_celery():
with execute_eagerly_on_celery("test_retries") as result:
assert len(events_of_type(result, "STEP_START")) == 1
assert len(events_of_type(result, "STEP_UP_FOR_RETRY")) == 1
assert len(events_of_type(result, "STEP_RESTARTED")) == 1
assert len(events_of_type(result, "STEP_FAILURE")) == 1
def test_engine_error():
with seven.mock.patch(
"dagster.core.execution.context.system.SystemExecutionContextData.raise_on_error",
return_value=True,
):
with pytest.raises(DagsterSubprocessError):
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with instance_for_test_tempdir(tempdir) as instance:
storage = os.path.join(tempdir, "flakey_storage")
execute_pipeline(
ReconstructablePipeline.for_file(REPO_FILE, "engine_error"),
run_config={
"intermediate_storage": {
"filesystem": {"config": {"base_dir": storage}}
},
"execution": {
"celery": {"config": {"config_source": {"task_always_eager": True}}}
},
"solids": {"destroy": {"config": storage}},
},
instance=instance,
)
diff --git a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_priority.py b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_priority.py
index 2aca61f6e..4fab242c6 100644
--- a/python_modules/libraries/dagster-celery/dagster_celery_tests/test_priority.py
+++ b/python_modules/libraries/dagster-celery/dagster_celery_tests/test_priority.py
@@ -1,83 +1,84 @@
# pylint doesn't know about pytest fixtures
# pylint: disable=unused-argument
+import tempfile
import threading
import time
from collections import OrderedDict
-from dagster import ModeDefinition, default_executors, seven
+from dagster import ModeDefinition, default_executors
from dagster.core.storage.pipeline_run import PipelineRunsFilter
from dagster.core.test_utils import instance_for_test_tempdir
from dagster_celery import celery_executor
from dagster_celery.tags import DAGSTER_CELERY_RUN_PRIORITY_TAG
from .utils import execute_eagerly_on_celery, execute_on_thread, start_celery_worker
celery_mode_defs = [ModeDefinition(executor_defs=default_executors + [celery_executor])]
def test_eager_priority_pipeline():
with execute_eagerly_on_celery("simple_priority_pipeline") as result:
assert result.success
assert list(OrderedDict.fromkeys([evt.step_key for evt in result.step_event_list])) == [
"ten",
"nine",
"eight",
"seven_",
"six",
"five",
"four",
"three",
"two",
"one",
"zero",
]
# If this test is failing locally, it likely means that there is a rogue
# celery worker still running on your machine.
def test_run_priority_pipeline(rabbitmq):
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with instance_for_test_tempdir(tempdir) as instance:
low_done = threading.Event()
hi_done = threading.Event()
# enqueue low-priority tasks
low_thread = threading.Thread(
target=execute_on_thread,
args=("low_pipeline", low_done, instance.get_ref()),
kwargs={"tempdir": tempdir, "tags": {DAGSTER_CELERY_RUN_PRIORITY_TAG: "-3"}},
)
low_thread.daemon = True
low_thread.start()
time.sleep(1) # sleep so that we don't hit any sqlite concurrency issues
# enqueue high-priority tasks
hi_thread = threading.Thread(
target=execute_on_thread,
args=("hi_pipeline", hi_done, instance.get_ref()),
kwargs={"tempdir": tempdir, "tags": {DAGSTER_CELERY_RUN_PRIORITY_TAG: "3"}},
)
hi_thread.daemon = True
hi_thread.start()
time.sleep(5) # sleep to give queue time to prioritize tasks
with start_celery_worker():
while not low_done.is_set() or not hi_done.is_set():
time.sleep(1)
low_runs = instance.get_runs(
filters=PipelineRunsFilter(pipeline_name="low_pipeline")
)
assert len(low_runs) == 1
low_run = low_runs[0]
lowstats = instance.get_run_stats(low_run.run_id)
hi_runs = instance.get_runs(filters=PipelineRunsFilter(pipeline_name="hi_pipeline"))
assert len(hi_runs) == 1
hi_run = hi_runs[0]
histats = instance.get_run_stats(hi_run.run_id)
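# The low-priority run was enqueued first, so it starts first, but the high-priority run
# should overtake it and finish earlier.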
assert lowstats.start_time < histats.start_time
assert lowstats.end_time > histats.end_time
diff --git a/python_modules/libraries/dagster-celery/dagster_celery_tests/utils.py b/python_modules/libraries/dagster-celery/dagster_celery_tests/utils.py
index 42a191a96..627262b3f 100644
--- a/python_modules/libraries/dagster-celery/dagster_celery_tests/utils.py
+++ b/python_modules/libraries/dagster-celery/dagster_celery_tests/utils.py
@@ -1,99 +1,100 @@
import os
import signal
import subprocess
+import tempfile
from contextlib import contextmanager
-from dagster import execute_pipeline, seven
+from dagster import execute_pipeline
from dagster.core.definitions.reconstructable import ReconstructablePipeline
from dagster.core.instance import DagsterInstance
from dagster.core.test_utils import instance_for_test
BUILDKITE = os.getenv("BUILDKITE")
REPO_FILE = os.path.join(os.path.dirname(__file__), "repo.py")
@contextmanager
def tempdir_wrapper(tempdir=None):
if tempdir:
yield tempdir
else:
- with seven.TemporaryDirectory() as t:
+ with tempfile.TemporaryDirectory() as t:
yield t
@contextmanager
def _instance_wrapper(instance):
if instance:
yield instance
else:
with instance_for_test() as instance:
yield instance
@contextmanager
def execute_pipeline_on_celery(
pipeline_name, instance=None, run_config=None, tempdir=None, tags=None, subset=None
):
with tempdir_wrapper(tempdir) as tempdir:
pipeline_def = ReconstructablePipeline.for_file(
REPO_FILE, pipeline_name
).subset_for_execution(subset)
with _instance_wrapper(instance) as wrapped_instance:
run_config = run_config or {
"intermediate_storage": {"filesystem": {"config": {"base_dir": tempdir}}},
"execution": {"celery": {}},
}
result = execute_pipeline(
pipeline_def, run_config=run_config, instance=wrapped_instance, tags=tags,
)
yield result
@contextmanager
def execute_eagerly_on_celery(pipeline_name, instance=None, tempdir=None, tags=None, subset=None):
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
run_config = {
"intermediate_storage": {"filesystem": {"config": {"base_dir": tempdir}}},
"execution": {"celery": {"config": {"config_source": {"task_always_eager": True}}}},
}
with execute_pipeline_on_celery(
pipeline_name,
instance=instance,
run_config=run_config,
tempdir=tempdir,
tags=tags,
subset=subset,
) as result:
yield result
def execute_on_thread(pipeline_name, done, instance_ref, tempdir=None, tags=None):
with DagsterInstance.from_ref(instance_ref) as instance:
with execute_pipeline_on_celery(
pipeline_name, tempdir=tempdir, tags=tags, instance=instance
):
done.set()
@contextmanager
def start_celery_worker(queue=None):
process = subprocess.Popen(
["dagster-celery", "worker", "start", "-A", "dagster_celery.app"]
+ (["-q", queue] if queue else [])
+ (["--", "--concurrency", "1"])
)
try:
yield
finally:
os.kill(process.pid, signal.SIGINT)
process.wait()
subprocess.check_output(["dagster-celery", "worker", "terminate"])
def events_of_type(result, event_type):
return [event for event in result.event_list if event.event_type_value == event_type]
diff --git a/python_modules/libraries/dagster-cron/dagster_cron_tests/test_cron_scheduler.py b/python_modules/libraries/dagster-cron/dagster_cron_tests/test_cron_scheduler.py
index 7dab8bb37..d3cf7a553 100644
--- a/python_modules/libraries/dagster-cron/dagster_cron_tests/test_cron_scheduler.py
+++ b/python_modules/libraries/dagster-cron/dagster_cron_tests/test_cron_scheduler.py
@@ -1,799 +1,796 @@
import os
import re
import subprocess
import sys
from contextlib import contextmanager
+from tempfile import TemporaryDirectory
import pytest
import yaml
from dagster import ScheduleDefinition
from dagster.core.definitions import lambda_solid, pipeline, repository
from dagster.core.host_representation import (
ManagedGrpcPythonEnvRepositoryLocationOrigin,
RepositoryLocation,
RepositoryLocationHandle,
)
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher.sync_in_memory_run_launcher import SyncInMemoryRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.scheduler.job import JobState, JobStatus, JobType, ScheduleJobData
from dagster.core.scheduler.scheduler import (
DagsterScheduleDoesNotExist,
DagsterScheduleReconciliationError,
DagsterSchedulerError,
)
from dagster.core.storage.event_log import InMemoryEventLogStorage
from dagster.core.storage.noop_compute_log_manager import NoOpComputeLogManager
from dagster.core.storage.pipeline_run import PipelineRunStatus
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import InMemoryRunStorage
from dagster.core.storage.schedules import SqliteScheduleStorage
from dagster.core.test_utils import environ
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
-from dagster.seven import (
- TemporaryDirectory,
- get_current_datetime_in_utc,
- get_timestamp_from_utc_datetime,
-)
+from dagster.seven import get_current_datetime_in_utc, get_timestamp_from_utc_datetime
from dagster_cron import SystemCronScheduler
from freezegun import freeze_time
@pytest.fixture(scope="function")
def restore_cron_tab():
with TemporaryDirectory() as tempdir:
crontab_backup = os.path.join(tempdir, "crontab_backup.txt")
with open(crontab_backup, "wb+") as f:
try:
output = subprocess.check_output(["crontab", "-l"])
f.write(output)
except subprocess.CalledProcessError:
# If a crontab hasn't been created yet, the command fails with a
# non-zero error code
pass
try:
subprocess.check_output(["crontab", "-r"])
except subprocess.CalledProcessError:
# If a crontab hasn't been created yet, the command fails with a
# non-zero error code
pass
yield
subprocess.check_output(["crontab", crontab_backup])
@pytest.fixture(scope="function")
def unset_dagster_home():
old_env = os.getenv("DAGSTER_HOME")
if old_env is not None:
del os.environ["DAGSTER_HOME"]
yield
if old_env is not None:
os.environ["DAGSTER_HOME"] = old_env
@pipeline
def no_config_pipeline():
@lambda_solid
def return_hello():
return "Hello"
return return_hello()
schedules_dict = {
"no_config_pipeline_daily_schedule": ScheduleDefinition(
name="no_config_pipeline_daily_schedule",
cron_schedule="0 0 * * *",
pipeline_name="no_config_pipeline",
run_config={"intermediate_storage": {"filesystem": None}},
),
"no_config_pipeline_every_min_schedule": ScheduleDefinition(
name="no_config_pipeline_every_min_schedule",
cron_schedule="* * * * *",
pipeline_name="no_config_pipeline",
run_config={"intermediate_storage": {"filesystem": None}},
),
"default_config_pipeline_every_min_schedule": ScheduleDefinition(
name="default_config_pipeline_every_min_schedule",
cron_schedule="* * * * *",
pipeline_name="no_config_pipeline",
),
}
def define_schedules():
return list(schedules_dict.values())
@repository
def test_repository():
if os.getenv("DAGSTER_TEST_SMALL_REPO"):
return [no_config_pipeline] + list(
filter(
lambda x: not x.name == "default_config_pipeline_every_min_schedule",
define_schedules(),
)
)
return [no_config_pipeline] + define_schedules()
@contextmanager
def get_test_external_repo():
with RepositoryLocationHandle.create_from_repository_location_origin(
ManagedGrpcPythonEnvRepositoryLocationOrigin(
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable, python_file=__file__, attribute="test_repository",
),
location_name="test_location",
)
) as handle:
yield RepositoryLocation.from_handle(handle).get_repository("test_repository")
@contextmanager
def get_smaller_external_repo():
with environ({"DAGSTER_TEST_SMALL_REPO": "1"}):
with get_test_external_repo() as repo:
yield repo
def get_cron_jobs():
output = subprocess.check_output(["crontab", "-l"])
return list(filter(None, output.decode("utf-8").strip().split("\n")))
def define_scheduler_instance(tempdir):
return DagsterInstance(
instance_type=InstanceType.EPHEMERAL,
local_artifact_storage=LocalArtifactStorage(tempdir),
run_storage=InMemoryRunStorage(),
event_storage=InMemoryEventLogStorage(),
compute_log_manager=NoOpComputeLogManager(),
schedule_storage=SqliteScheduleStorage.from_local(os.path.join(tempdir, "schedules")),
scheduler=SystemCronScheduler(),
run_coordinator=DefaultRunCoordinator(),
run_launcher=SyncInMemoryRunLauncher(),
)
def test_init(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repository:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repository)
# Check schedules are saved to disk
assert "schedules" in os.listdir(tempdir)
assert instance.all_stored_job_state(job_type=JobType.SCHEDULE)
@freeze_time("2019-02-27")
def test_re_init(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
now = get_current_datetime_in_utc()
# Start schedule
schedule_state = instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
assert (
schedule_state.job_specific_data.start_timestamp
== get_timestamp_from_utc_datetime(now)
)
# Check schedules are saved to disk
assert "schedules" in os.listdir(tempdir)
schedule_states = instance.all_stored_job_state(job_type=JobType.SCHEDULE)
for state in schedule_states:
if state.name == "no_config_pipeline_every_min_schedule":
assert state == schedule_state
@pytest.mark.parametrize("do_initial_reconcile", [True, False])
def test_start_and_stop_schedule(
restore_cron_tab, do_initial_reconcile,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
if do_initial_reconcile:
instance.reconcile_scheduler_state(external_repo)
schedule = external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
schedule_origin_id = schedule.get_external_origin_id()
instance.start_schedule_and_update_storage_state(schedule)
assert "schedules" in os.listdir(tempdir)
assert "{}.sh".format(schedule_origin_id) in os.listdir(
os.path.join(tempdir, "schedules", "scripts")
)
instance.stop_schedule_and_update_storage_state(schedule_origin_id)
assert "{}.sh".format(schedule_origin_id) not in os.listdir(
os.path.join(tempdir, "schedules", "scripts")
)
def test_start_non_existent_schedule(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with pytest.raises(DagsterScheduleDoesNotExist):
instance.stop_schedule_and_update_storage_state("asdf")
@pytest.mark.parametrize("do_initial_reconcile", [True, False])
def test_start_schedule_cron_job(
do_initial_reconcile, restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
if do_initial_reconcile:
instance.reconcile_scheduler_state(external_repo)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
# Inspect the cron tab
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 3
external_schedules_dict = {
external_repo.get_external_schedule(name).get_external_origin_id(): schedule_def
for name, schedule_def in schedules_dict.items()
}
for cron_job in cron_jobs:
match = re.findall(r"^(.*?) (/.*) > (.*) 2>&1 # dagster-schedule: (.*)", cron_job)
cron_schedule, command, log_file, schedule_origin_id = match[0]
schedule_def = external_schedules_dict[schedule_origin_id]
# Check cron schedule matches
if schedule_def.cron_schedule == "0 0 * * *":
assert cron_schedule == "@daily"
else:
assert cron_schedule == schedule_def.cron_schedule
# Check bash file exists
assert os.path.isfile(command)
# Check log file is correct
assert log_file.endswith("scheduler.log")
def test_remove_schedule_def(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
with get_smaller_external_repo() as smaller_repo:
instance.reconcile_scheduler_state(smaller_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 2
def test_add_schedule_def(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_smaller_external_repo() as external_repo:
# Start all schedules and verify the crontab, schedule storage, and errors
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 2
assert len(get_cron_jobs()) == 2
assert len(instance.scheduler_debug_info().errors) == 0
with get_test_external_repo() as external_repo:
# Reconcile with an additional schedule added
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
assert len(get_cron_jobs()) == 2
assert len(instance.scheduler_debug_info().errors) == 0
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
assert len(get_cron_jobs()) == 3
assert len(instance.scheduler_debug_info().errors) == 0
def test_start_and_stop_schedule_cron_tab(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
# Start schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 1
# Try starting it again
with pytest.raises(DagsterSchedulerError):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 1
# Start another schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 2
# Stop second schedule
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 1
# Try stopping second schedule again
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 1
# Start second schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_daily_schedule")
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 2
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 2
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("default_config_pipeline_every_min_schedule")
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 3
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 3
# Stop all schedules
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
).get_external_origin_id()
)
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"default_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 0
# Reconcile schedule state, should be in the same state
instance.reconcile_scheduler_state(external_repo)
cron_jobs = get_cron_jobs()
assert len(cron_jobs) == 0
def test_script_execution(
restore_cron_tab, unset_dagster_home
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
os.environ["DAGSTER_HOME"] = tempdir
config = {
"scheduler": {"module": "dagster_cron", "class": "SystemCronScheduler", "config": {}},
# This needs to synchronously execute to completion when
# the generated bash script is invoked
"run_launcher": {
"module": "dagster.core.launcher.sync_in_memory_run_launcher",
"class": "SyncInMemoryRunLauncher",
},
}
with open(os.path.join(tempdir, "dagster.yaml"), "w+") as f:
f.write(yaml.dump(config))
instance = DagsterInstance.get()
with get_test_external_repo() as external_repo:
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
schedule_origin_id = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
script = instance.scheduler._get_bash_script_file_path( # pylint: disable=protected-access
instance, schedule_origin_id
)
subprocess.check_output([script], shell=True, env={"DAGSTER_HOME": tempdir})
runs = instance.get_runs()
assert len(runs) == 1
assert runs[0].status == PipelineRunStatus.SUCCESS
def test_start_schedule_fails(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
def raises(*args, **kwargs):
raise Exception("Patch")
instance._scheduler._start_cron_job = raises # pylint: disable=protected-access
with pytest.raises(Exception, match="Patch"):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
schedule = instance.get_job_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
assert schedule.status == JobStatus.STOPPED
def test_start_schedule_unsuccessful(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
def do_nothing(*_):
pass
instance._scheduler._start_cron_job = do_nothing # pylint: disable=protected-access
# Start schedule
with pytest.raises(
DagsterSchedulerError,
match="Attempted to write cron job for schedule no_config_pipeline_every_min_schedule, "
"but failed. The scheduler is not running no_config_pipeline_every_min_schedule.",
):
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
def test_start_schedule_manual_delete_debug(
restore_cron_tab, snapshot # pylint:disable=unused-argument,redefined-outer-name
):
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# Manually delete the schedule from the crontab
instance.scheduler._end_cron_job( # pylint: disable=protected-access
instance,
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id(),
)
# Check debug command
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 1
# Reconcile should fix error
instance.reconcile_scheduler_state(external_repo)
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 0
def test_start_schedule_manual_add_debug(
restore_cron_tab, snapshot # pylint:disable=unused-argument,redefined-outer-name
):
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
# Initialize scheduler
instance.reconcile_scheduler_state(external_repo)
# Manually add the schedule to the crontab
instance.scheduler._start_cron_job( # pylint: disable=protected-access
instance,
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule"),
)
# Check debug command
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 1
# Reconcile should fix error
instance.reconcile_scheduler_state(external_repo)
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 0
def test_start_schedule_manual_duplicate_schedules_add_debug(
restore_cron_tab, snapshot # pylint:disable=unused-argument,redefined-outer-name
):
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
external_schedule = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
)
instance.start_schedule_and_update_storage_state(external_schedule)
# Manually add extra cron jobs for the schedule
instance.scheduler._start_cron_job( # pylint: disable=protected-access
instance, external_schedule,
)
instance.scheduler._start_cron_job( # pylint: disable=protected-access
instance, external_schedule,
)
# Check debug command
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 1
# Reconcile should fix error
instance.reconcile_scheduler_state(external_repo)
debug_info = instance.scheduler_debug_info()
assert len(debug_info.errors) == 0
def test_stop_schedule_fails(
restore_cron_tab, # pylint:disable=unused-argument,redefined-outer-name
):
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
external_schedule = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
)
schedule_origin_id = external_schedule.get_external_origin_id()
def raises(*args, **kwargs):
raise Exception("Patch")
instance._scheduler._end_cron_job = raises # pylint: disable=protected-access
instance.start_schedule_and_update_storage_state(external_schedule)
assert "schedules" in os.listdir(tempdir)
assert "{}.sh".format(schedule_origin_id) in os.listdir(
os.path.join(tempdir, "schedules", "scripts")
)
# End schedule
with pytest.raises(Exception, match="Patch"):
instance.stop_schedule_and_update_storage_state(schedule_origin_id)
schedule = instance.get_job_state(schedule_origin_id)
assert schedule.status == JobStatus.RUNNING
def test_stop_schedule_unsuccessful(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
def do_nothing(*_):
pass
instance._scheduler._end_cron_job = do_nothing # pylint: disable=protected-access
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# End schedule
with pytest.raises(
DagsterSchedulerError,
match="Attempted to remove existing cron job for schedule "
"no_config_pipeline_every_min_schedule, but failed. There are still 1 jobs running for "
"the schedule.",
):
instance.stop_schedule_and_update_storage_state(
external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
).get_external_origin_id()
)
def test_wipe(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
# Start schedule
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
# Wipe scheduler
instance.wipe_all_schedules()
# Check schedules are wiped
assert instance.all_stored_job_state(job_type=JobType.SCHEDULE) == []
def test_log_directory(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
external_schedule = external_repo.get_external_schedule(
"no_config_pipeline_every_min_schedule"
)
schedule_log_path = instance.logs_path_for_schedule(
external_schedule.get_external_origin_id()
)
assert schedule_log_path.endswith(
"/schedules/logs/{schedule_origin_id}/scheduler.log".format(
schedule_origin_id=external_schedule.get_external_origin_id()
)
)
# Start schedule
instance.start_schedule_and_update_storage_state(external_schedule)
# Wipe scheduler
instance.wipe_all_schedules()
# Check schedules are wiped
assert instance.all_stored_job_state(job_type=JobType.SCHEDULE) == []
def test_reconcile_failure(restore_cron_tab): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
instance.reconcile_scheduler_state(external_repo)
instance.start_schedule_and_update_storage_state(
external_repo.get_external_schedule("no_config_pipeline_every_min_schedule")
)
def failed_start_job(*_):
raise DagsterSchedulerError("Failed to start")
def failed_end_job(*_):
raise DagsterSchedulerError("Failed to stop")
instance._scheduler.start_schedule = ( # pylint: disable=protected-access
failed_start_job
)
instance._scheduler.stop_schedule = failed_end_job # pylint: disable=protected-access
with pytest.raises(
DagsterScheduleReconciliationError,
match="Error 1: Failed to stop\n Error 2: Failed to stop\n Error 3: Failed to stop",
):
instance.reconcile_scheduler_state(external_repo)
@freeze_time("2019-02-27")
def test_reconcile_schedule_without_start_time():
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
external_schedule = external_repo.get_external_schedule(
"no_config_pipeline_daily_schedule"
)
legacy_schedule_state = JobState(
external_schedule.get_external_origin(),
JobType.SCHEDULE,
JobStatus.RUNNING,
ScheduleJobData(external_schedule.cron_schedule, None),
)
instance.add_job_state(legacy_schedule_state)
instance.reconcile_scheduler_state(external_repository=external_repo)
reconciled_schedule_state = instance.get_job_state(
external_schedule.get_external_origin_id()
)
assert reconciled_schedule_state.status == JobStatus.RUNNING
assert (
reconciled_schedule_state.job_specific_data.start_timestamp
== get_timestamp_from_utc_datetime(get_current_datetime_in_utc())
)
def test_reconcile_failure_when_deleting_schedule_def(
restore_cron_tab,
): # pylint:disable=unused-argument,redefined-outer-name
with TemporaryDirectory() as tempdir:
instance = define_scheduler_instance(tempdir)
with get_test_external_repo() as external_repo:
instance.reconcile_scheduler_state(external_repo)
assert len(instance.all_stored_job_state(job_type=JobType.SCHEDULE)) == 3
def failed_end_job(*_):
raise DagsterSchedulerError("Failed to stop")
instance._scheduler.stop_schedule_and_delete_from_storage = ( # pylint: disable=protected-access
failed_end_job
)
with pytest.raises(
DagsterScheduleReconciliationError, match="Error 1: Failed to stop",
):
with get_smaller_external_repo() as smaller_repo:
instance.reconcile_scheduler_state(smaller_repo)
diff --git a/python_modules/libraries/dagster-dask/dagster_dask_tests/test_execute.py b/python_modules/libraries/dagster-dask/dagster_dask_tests/test_execute.py
index b31c02155..21247820b 100644
--- a/python_modules/libraries/dagster-dask/dagster_dask_tests/test_execute.py
+++ b/python_modules/libraries/dagster-dask/dagster_dask_tests/test_execute.py
@@ -1,251 +1,251 @@
import asyncio
+import tempfile
import time
from threading import Thread
import dagster_pandas as dagster_pd
import pytest
from dagster import (
DagsterUnmetExecutorRequirementsError,
InputDefinition,
ModeDefinition,
execute_pipeline,
execute_pipeline_iterator,
file_relative_path,
pipeline,
reconstructable,
- seven,
solid,
)
from dagster.core.definitions.executor import default_executors
from dagster.core.definitions.reconstructable import ReconstructablePipeline
from dagster.core.errors import DagsterExecutionInterruptedError
from dagster.core.events import DagsterEventType
from dagster.core.test_utils import (
instance_for_test,
instance_for_test_tempdir,
nesting_composite_pipeline,
)
from dagster.utils import send_interrupt
from dagster_dask import DataFrame, dask_executor
from dask.distributed import Scheduler, Worker
@solid
def simple(_):
return 1
@pipeline(mode_defs=[ModeDefinition(executor_defs=default_executors + [dask_executor])])
def dask_engine_pipeline():
simple()
def test_execute_on_dask_local():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with instance_for_test_tempdir(tempdir) as instance:
result = execute_pipeline(
reconstructable(dask_engine_pipeline),
run_config={
"intermediate_storage": {"filesystem": {"config": {"base_dir": tempdir}}},
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
},
instance=instance,
)
assert result.result_for_solid("simple").output_value() == 1
def dask_composite_pipeline():
return nesting_composite_pipeline(
6, 2, mode_defs=[ModeDefinition(executor_defs=default_executors + [dask_executor])]
)
def test_composite_execute():
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(dask_composite_pipeline),
run_config={
"intermediate_storage": {"filesystem": {}},
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
},
instance=instance,
)
assert result.success
@solid(input_defs=[InputDefinition("df", dagster_pd.DataFrame)])
def pandas_solid(_, df): # pylint: disable=unused-argument
pass
@pipeline(mode_defs=[ModeDefinition(executor_defs=default_executors + [dask_executor])])
def pandas_pipeline():
pandas_solid()
def test_pandas_dask():
run_config = {
"solids": {
"pandas_solid": {
"inputs": {"df": {"csv": {"path": file_relative_path(__file__, "ex.csv")}}}
}
}
}
with instance_for_test() as instance:
result = execute_pipeline(
ReconstructablePipeline.for_file(__file__, pandas_pipeline.name),
run_config={
"intermediate_storage": {"filesystem": {}},
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
**run_config,
},
instance=instance,
)
assert result.success
@solid(input_defs=[InputDefinition("df", DataFrame)])
def dask_solid(_, df): # pylint: disable=unused-argument
pass
@pipeline(mode_defs=[ModeDefinition(executor_defs=default_executors + [dask_executor])])
def dask_pipeline():
dask_solid()
def test_dask():
run_config = {
"solids": {
"dask_solid": {
"inputs": {"df": {"csv": {"path": file_relative_path(__file__, "ex*.csv")}}}
}
}
}
with instance_for_test() as instance:
result = execute_pipeline(
ReconstructablePipeline.for_file(__file__, dask_pipeline.name),
run_config={
"intermediate_storage": {"filesystem": {}},
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
**run_config,
},
instance=instance,
)
assert result.success
def test_execute_on_dask_local_with_intermediate_storage():
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with instance_for_test_tempdir(tempdir) as instance:
result = execute_pipeline(
reconstructable(dask_engine_pipeline),
run_config={
"intermediate_storage": {"filesystem": {"config": {"base_dir": tempdir}}},
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
},
instance=instance,
)
assert result.result_for_solid("simple").output_value() == 1
def test_execute_on_dask_local_with_default_storage():
with pytest.raises(DagsterUnmetExecutorRequirementsError):
with instance_for_test() as instance:
result = execute_pipeline(
reconstructable(dask_engine_pipeline),
run_config={
"execution": {"dask": {"config": {"cluster": {"local": {"timeout": 30}}}}},
},
instance=instance,
)
assert result.result_for_solid("simple").output_value() == 1
@solid(input_defs=[InputDefinition("df", DataFrame)])
def sleepy_dask_solid(_, df): # pylint: disable=unused-argument
start_time = time.time()
while True:
time.sleep(0.1)
if time.time() - start_time > 120:
raise Exception("Timed out")
@pipeline(mode_defs=[ModeDefinition(executor_defs=default_executors + [dask_executor])])
def sleepy_dask_pipeline():
sleepy_dask_solid()
def test_dask_terminate():
run_config = {
"solids": {
"sleepy_dask_solid": {
"inputs": {"df": {"csv": {"path": file_relative_path(__file__, "ex*.csv")}}}
}
}
}
interrupt_thread = None
result_types = []
received_interrupt = False
with instance_for_test() as instance:
try:
for result in execute_pipeline_iterator(
pipeline=ReconstructablePipeline.for_file(__file__, sleepy_dask_pipeline.name),
run_config=run_config,
instance=instance,
):
# Interrupt once the first step starts
if result.event_type == DagsterEventType.STEP_START and not interrupt_thread:
interrupt_thread = Thread(target=send_interrupt, args=())
interrupt_thread.start()
if result.event_type == DagsterEventType.STEP_FAILURE:
assert (
"DagsterExecutionInterruptedError"
in result.event_specific_data.error.message
)
result_types.append(result.event_type)
assert False
except DagsterExecutionInterruptedError:
received_interrupt = True
assert received_interrupt
interrupt_thread.join()
assert DagsterEventType.STEP_FAILURE in result_types
assert DagsterEventType.PIPELINE_FAILURE in result_types
def test_existing_scheduler():
def _execute(scheduler_address, instance):
return execute_pipeline(
reconstructable(dask_engine_pipeline),
run_config={
"intermediate_storage": {"filesystem": {}},
"execution": {
"dask": {"config": {"cluster": {"existing": {"address": scheduler_address}}}}
},
},
instance=instance,
)
async def _run_test():
with instance_for_test() as instance:
async with Scheduler() as scheduler:
async with Worker(scheduler.address) as _:
result = await asyncio.get_event_loop().run_in_executor(
None, _execute, scheduler.address, instance
)
assert result.success
assert result.result_for_solid("simple").output_value() == 1
asyncio.get_event_loop().run_until_complete(_run_test())
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
index fdfbc5801..73fe5dc13 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/databricks_pyspark_step_launcher.py
@@ -1,319 +1,320 @@
import io
import os.path
import pickle
+import tempfile
-from dagster import Bool, Field, StringSource, check, resource, seven
+from dagster import Bool, Field, StringSource, check, resource
from dagster.core.definitions.step_launcher import StepLauncher
from dagster.core.errors import raise_execution_interrupts
from dagster.core.events import log_step_event
from dagster.core.execution.plan.external_step import (
PICKLED_EVENTS_FILE_NAME,
PICKLED_STEP_RUN_REF_FILE_NAME,
step_context_to_step_run_ref,
)
from dagster.serdes import deserialize_value
from dagster_databricks import DatabricksJobRunner, databricks_step_main
from dagster_pyspark.utils import build_pyspark_zip
from .configs import (
define_databricks_secrets_config,
define_databricks_storage_config,
define_databricks_submit_run_config,
)
CODE_ZIP_NAME = "code.zip"
PICKLED_CONFIG_FILE_NAME = "config.pkl"
@resource(
{
"run_config": define_databricks_submit_run_config(),
"databricks_host": Field(
StringSource,
is_required=True,
description="Databricks host, e.g. uksouth.azuredatabricks.com",
),
"databricks_token": Field(
StringSource, is_required=True, description="Databricks access token",
),
"secrets_to_env_variables": define_databricks_secrets_config(),
"storage": define_databricks_storage_config(),
"local_pipeline_package_path": Field(
StringSource,
is_required=True,
description="Absolute path to the package that contains the pipeline definition(s) "
"whose steps will execute remotely on Databricks. This is a path on the local "
"fileystem of the process executing the pipeline. Before every step run, the "
"launcher will zip up the code in this path, upload it to DBFS, and unzip it "
"into the Python path of the remote Spark process. This gives the remote process "
"access to up-to-date user code.",
),
"staging_prefix": Field(
StringSource,
is_required=False,
default_value="/dagster_staging",
description="Directory in DBFS to use for uploaded job code. Must be absolute.",
),
"wait_for_logs": Field(
Bool,
is_required=False,
default_value=False,
description="If set, and if the specified cluster is configured to export logs, "
"the system will wait after job completion for the logs to appear in the configured "
"location. Note that logs are copied every 5 minutes, so enabling this will add "
"several minutes to the job runtime.",
),
}
)
def databricks_pyspark_step_launcher(context):
"""Resource for running solids as a Databricks Job.
When this resource is used, the solid will be executed in Databricks using the 'Run Submit'
API. Pipeline code will be zipped up and copied to a directory in DBFS along with the solid's
execution context.
Use the 'run_config' configuration to specify the details of the Databricks cluster used, and
the 'storage' key to configure persistent storage on that cluster. Storage is accessed by
setting the credentials in the Spark context, as documented `here for S3`_ and `here for ADLS`_.
.. _`here for S3`: https://docs.databricks.com/data/data-sources/aws/amazon-s3.html#alternative-1-set-aws-keys-in-the-spark-context
.. _`here for ADLS`: https://docs.microsoft.com/en-gb/azure/databricks/data/data-sources/azure/azure-datalake-gen2#--access-directly-using-the-storage-account-access-key
"""
return DatabricksPySparkStepLauncher(**context.resource_config)
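# Illustrative sketch only: the resource above is typically attached to a mode and configured
# through run config. The keys mirror the Fields declared on the resource; the host, token
# source, and package path shown here are hypothetical placeholders, not values from this repo.
#
#     ModeDefinition(resource_defs={"pyspark_step_launcher": databricks_pyspark_step_launcher, ...})
#
#     run_config = {
#         "resources": {
#             "pyspark_step_launcher": {
#                 "config": {
#                     "run_config": {...},  # cluster spec forwarded to the 'Run Submit' API
#                     "databricks_host": "<databricks-host>",
#                     "databricks_token": {"env": "DATABRICKS_TOKEN"},
#                     "secrets_to_env_variables": [],
#                     "storage": {...},
#                     "local_pipeline_package_path": "/path/to/pipeline/package",
#                 }
#             }
#         }
#     }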
class DatabricksPySparkStepLauncher(StepLauncher):
def __init__(
self,
run_config,
databricks_host,
databricks_token,
secrets_to_env_variables,
storage,
local_pipeline_package_path,
staging_prefix,
wait_for_logs,
):
self.run_config = check.dict_param(run_config, "run_config")
self.databricks_host = check.str_param(databricks_host, "databricks_host")
self.databricks_token = check.str_param(databricks_token, "databricks_token")
self.secrets = check.list_param(secrets_to_env_variables, "secrets_to_env_variables", dict)
self.storage = check.dict_param(storage, "storage")
self.local_pipeline_package_path = check.str_param(
local_pipeline_package_path, "local_pipeline_package_path"
)
self.staging_prefix = check.str_param(staging_prefix, "staging_prefix")
check.invariant(staging_prefix.startswith("/"), "staging_prefix must be an absolute path")
self.wait_for_logs = check.bool_param(wait_for_logs, "wait_for_logs")
self.databricks_runner = DatabricksJobRunner(host=databricks_host, token=databricks_token)
def launch_step(self, step_context, prior_attempts_count):
step_run_ref = step_context_to_step_run_ref(
step_context, prior_attempts_count, self.local_pipeline_package_path
)
run_id = step_context.pipeline_run.run_id
log = step_context.log
step_key = step_run_ref.step_key
self._upload_artifacts(log, step_run_ref, run_id, step_key)
task = self._get_databricks_task(run_id, step_key)
databricks_run_id = self.databricks_runner.submit_run(self.run_config, task)
try:
# If this is being called within a `capture_interrupts` context, allow interrupts while
# waiting for the execution to complete, so that we can terminate slow or hanging steps
with raise_execution_interrupts():
self.databricks_runner.wait_for_run_to_complete(log, databricks_run_id)
finally:
if self.wait_for_logs:
self._log_logs_from_cluster(log, databricks_run_id)
for event in self.get_step_events(run_id, step_key):
log_step_event(step_context, event)
yield event
def get_step_events(self, run_id, step_key):
path = self._dbfs_path(run_id, step_key, PICKLED_EVENTS_FILE_NAME)
events_data = self.databricks_runner.client.read_file(path)
return deserialize_value(pickle.loads(events_data))
def _get_databricks_task(self, run_id, step_key):
"""Construct the 'task' parameter to be submitted to the Databricks API.
This will create a 'spark_python_task' dict where `python_file` is a path on DBFS
pointing to the 'databricks_step_main.py' file, and `parameters` is an array of DBFS paths
pointing to the pickled `step_run_ref` data, the pickled Databricks config, and the zipped
pipeline code.
See https://docs.databricks.com/dev-tools/api/latest/jobs.html#jobssparkpythontask.
"""
python_file = self._dbfs_path(run_id, step_key, self._main_file_name())
parameters = [
self._internal_dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
self._internal_dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
self._internal_dbfs_path(run_id, step_key, CODE_ZIP_NAME),
]
return {"spark_python_task": {"python_file": python_file, "parameters": parameters}}
def _upload_artifacts(self, log, step_run_ref, run_id, step_key):
"""Upload the step run ref and pyspark code to DBFS to run as a job."""
log.info("Uploading main file to DBFS")
main_local_path = self._main_file_local_path()
with open(main_local_path, "rb") as infile:
self.databricks_runner.client.put_file(
infile, self._dbfs_path(run_id, step_key, self._main_file_name())
)
log.info("Uploading pipeline to DBFS")
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
# Zip and upload package containing pipeline
zip_local_path = os.path.join(temp_dir, CODE_ZIP_NAME)
build_pyspark_zip(zip_local_path, self.local_pipeline_package_path)
with open(zip_local_path, "rb") as infile:
self.databricks_runner.client.put_file(
infile, self._dbfs_path(run_id, step_key, CODE_ZIP_NAME)
)
log.info("Uploading step run ref file to DBFS")
step_pickle_file = io.BytesIO()
pickle.dump(step_run_ref, step_pickle_file)
step_pickle_file.seek(0)
self.databricks_runner.client.put_file(
step_pickle_file, self._dbfs_path(run_id, step_key, PICKLED_STEP_RUN_REF_FILE_NAME),
)
databricks_config = DatabricksConfig(storage=self.storage, secrets=self.secrets,)
log.info("Uploading Databricks configuration to DBFS")
databricks_config_file = io.BytesIO()
pickle.dump(databricks_config, databricks_config_file)
databricks_config_file.seek(0)
self.databricks_runner.client.put_file(
databricks_config_file, self._dbfs_path(run_id, step_key, PICKLED_CONFIG_FILE_NAME),
)
def _log_logs_from_cluster(self, log, run_id):
logs = self.databricks_runner.retrieve_logs_for_run_id(log, run_id)
if logs is None:
return
stdout, stderr = logs
if stderr:
log.info(stderr)
if stdout:
log.info(stdout)
def _main_file_name(self):
return os.path.basename(self._main_file_local_path())
def _main_file_local_path(self):
return databricks_step_main.__file__
def _dbfs_path(self, run_id, step_key, filename):
path = "/".join([self.staging_prefix, run_id, step_key, os.path.basename(filename)])
return "dbfs://{}".format(path)
def _internal_dbfs_path(self, run_id, step_key, filename):
"""Scripts running on Databricks should access DBFS at /dbfs/."""
path = "/".join([self.staging_prefix, run_id, step_key, os.path.basename(filename)])
return "/dbfs/{}".format(path)
class DatabricksConfig:
"""Represents configuration required by Databricks to run jobs.
Instances of this class will be created when a Databricks step is launched and will contain
all configuration and secrets required to set up storage and environment variables within
the Databricks environment. The instance will be serialized and uploaded to Databricks
by the step launcher, then deserialized as part of the 'main' script when the job is running
in Databricks.
The `setup` method handles the actual setup prior to solid execution on the Databricks side.
This config is separated out from the regular Dagster run config system because the setup
is done by the 'main' script before entering a Dagster context (i.e. using `run_step_from_ref`).
We use a separate class to avoid coupling the setup to the format of the `step_run_ref` object.
"""
def __init__(self, storage, secrets):
"""Create a new DatabricksConfig object.
`storage` and `secrets` should be of the same shape as the `storage` and
`secrets_to_env_variables` config passed to `databricks_pyspark_step_launcher`.
"""
self.storage = storage
self.secrets = secrets
def setup(self, dbutils, sc):
"""Set up storage and environment variables on Databricks.
The `dbutils` and `sc` arguments must be passed in by the 'main' script, as they
aren't accessible by any other modules.
"""
self.setup_storage(dbutils, sc)
self.setup_environment(dbutils)
def setup_storage(self, dbutils, sc):
"""Set up storage using either S3 or ADLS2."""
if "s3" in self.storage:
self.setup_s3_storage(self.storage["s3"], dbutils, sc)
elif "adls2" in self.storage:
self.setup_adls2_storage(self.storage["adls2"], dbutils, sc)
else:
raise Exception("No valid storage found in Databricks configuration!")
def setup_s3_storage(self, s3_storage, dbutils, sc):
"""Obtain AWS credentials from Databricks secrets and export so both Spark and boto can use them."""
scope = s3_storage["secret_scope"]
access_key = dbutils.secrets.get(scope=scope, key=s3_storage["access_key_key"])
secret_key = dbutils.secrets.get(scope=scope, key=s3_storage["secret_key_key"])
# Spark APIs will use this.
# See https://docs.databricks.com/data/data-sources/aws/amazon-s3.html#alternative-1-set-aws-keys-in-the-spark-context.
sc._jsc.hadoopConfiguration().set( # pylint: disable=protected-access
"fs.s3n.awsAccessKeyId", access_key
)
sc._jsc.hadoopConfiguration().set( # pylint: disable=protected-access
"fs.s3n.awsSecretAccessKey", secret_key
)
# Boto will use these.
os.environ["AWS_ACCESS_KEY_ID"] = access_key
os.environ["AWS_SECRET_ACCESS_KEY"] = secret_key
def setup_adls2_storage(self, adls2_storage, dbutils, sc):
"""Obtain an Azure Storage Account key from Databricks secrets and export so Spark can use it."""
storage_account_key = dbutils.secrets.get(
scope=adls2_storage["secret_scope"], key=adls2_storage["storage_account_key_key"]
)
# Spark APIs will use this.
# See https://docs.microsoft.com/en-gb/azure/databricks/data/data-sources/azure/azure-datalake-gen2#--access-directly-using-the-storage-account-access-key
# sc is globally defined in the Databricks runtime and points to the Spark context
sc._jsc.hadoopConfiguration().set( # pylint: disable=protected-access
"fs.azure.account.key.{}.dfs.core.windows.net".format(
adls2_storage["storage_account_name"]
),
storage_account_key,
)
def setup_environment(self, dbutils):
"""Setup any environment variables required by the run.
Extract any secrets in the run config and export them as environment variables.
This is important for any `StringSource` config since the environment variables
won't ordinarily be available in the Databricks execution environment.
"""
for secret in self.secrets:
name = secret["name"]
key = secret["key"]
scope = secret["scope"]
print( # pylint: disable=print-call
"Exporting {} from Databricks secret {}, scope {}".format(name, key, scope)
)
val = dbutils.secrets.get(scope=scope, key=key)
os.environ[name] = val
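# Illustrative sketch only (hypothetical values): each entry in `secrets_to_env_variables` names
# an environment variable to export plus the Databricks secret scope/key that supplies its value.
#
#     secrets_to_env_variables = [
#         {"name": "MY_SERVICE_TOKEN", "scope": "my-databricks-scope", "key": "my-service-token"},
#     ]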
diff --git a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py
index 887665541..92796348d 100644
--- a/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py
+++ b/python_modules/libraries/dagster-gcp/dagster_gcp_tests/gcs_tests/test_compute_log_manager.py
@@ -1,113 +1,114 @@
import os
import sys
+import tempfile
import six
-from dagster import DagsterEventType, execute_pipeline, pipeline, seven, solid
+from dagster import DagsterEventType, execute_pipeline, pipeline, solid
from dagster.core.instance import DagsterInstance, InstanceType
from dagster.core.launcher import DefaultRunLauncher
from dagster.core.run_coordinator import DefaultRunCoordinator
from dagster.core.storage.compute_log_manager import ComputeIOType
from dagster.core.storage.event_log import SqliteEventLogStorage
from dagster.core.storage.root import LocalArtifactStorage
from dagster.core.storage.runs import SqliteRunStorage
from dagster_gcp.gcs import GCSComputeLogManager
from google.cloud import storage
HELLO_WORLD = "Hello World"
SEPARATOR = os.linesep if (os.name == "nt" and sys.version_info < (3,)) else "\n"
EXPECTED_LOGS = [
'STEP_START - Started execution of step "easy".',
'STEP_OUTPUT - Yielded output "result" of type "Any"',
'STEP_SUCCESS - Finished execution of step "easy"',
]
def test_compute_log_manager(gcs_bucket):
@pipeline
def simple():
@solid
def easy(context):
context.log.info("easy")
print(HELLO_WORLD) # pylint: disable=print-call
return "easy"
easy()
- with seven.TemporaryDirectory() as temp_dir:
+ with tempfile.TemporaryDirectory() as temp_dir:
run_store = SqliteRunStorage.from_local(temp_dir)
event_store = SqliteEventLogStorage(temp_dir)
manager = GCSComputeLogManager(bucket=gcs_bucket, prefix="my_prefix", local_dir=temp_dir)
instance = DagsterInstance(
instance_type=InstanceType.PERSISTENT,
local_artifact_storage=LocalArtifactStorage(temp_dir),
run_storage=run_store,
event_storage=event_store,
compute_log_manager=manager,
run_coordinator=DefaultRunCoordinator(),
run_launcher=DefaultRunLauncher(),
)
result = execute_pipeline(simple, instance=instance)
compute_steps = [
event.step_key
for event in result.step_event_list
if event.event_type == DagsterEventType.STEP_START
]
assert len(compute_steps) == 1
step_key = compute_steps[0]
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
# Check GCS directly
stderr_gcs = six.ensure_str(
storage.Client()
.get_bucket(gcs_bucket)
.blob(
"{prefix}/storage/{run_id}/compute_logs/easy.err".format(
prefix="my_prefix", run_id=result.run_id
)
)
.download_as_string()
)
for expected in EXPECTED_LOGS:
assert expected in stderr_gcs
# Check download behavior by deleting locally cached logs
compute_logs_dir = os.path.join(temp_dir, result.run_id, "compute_logs")
for filename in os.listdir(compute_logs_dir):
os.unlink(os.path.join(compute_logs_dir, filename))
stdout = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDOUT)
assert stdout.data == HELLO_WORLD + SEPARATOR
stderr = manager.read_logs_file(result.run_id, step_key, ComputeIOType.STDERR)
for expected in EXPECTED_LOGS:
assert expected in stderr.data
def test_compute_log_manager_from_config(gcs_bucket):
s3_prefix = "foobar"
dagster_yaml = """
compute_logs:
module: dagster_gcp.gcs.compute_log_manager
class: GCSComputeLogManager
config:
bucket: "{bucket}"
local_dir: "/tmp/cool"
prefix: "{prefix}"
""".format(
bucket=gcs_bucket, prefix=s3_prefix
)
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with open(os.path.join(tempdir, "dagster.yaml"), "wb") as f:
f.write(six.ensure_binary(dagster_yaml))
instance = DagsterInstance.from_config(tempdir)
assert isinstance(instance.compute_log_manager, GCSComputeLogManager)
diff --git a/python_modules/libraries/dagster-gcp/setup.py b/python_modules/libraries/dagster-gcp/setup.py
index ae3c21fe7..eac912f47 100644
--- a/python_modules/libraries/dagster-gcp/setup.py
+++ b/python_modules/libraries/dagster-gcp/setup.py
@@ -1,41 +1,39 @@
from setuptools import find_packages, setup
def get_version():
version = {}
with open("dagster_gcp/version.py") as fp:
exec(fp.read(), version) # pylint: disable=W0122
return version["__version__"]
if __name__ == "__main__":
setup(
name="dagster-gcp",
version=get_version(),
author="Elementl",
author_email="hello@elementl.com",
license="Apache-2.0",
description="Package for GCP-specific Dagster framework solid and resource components.",
# pylint: disable=line-too-long
url="https://github.com/dagster-io/dagster/tree/master/python_modules/libraries/dagster-gcp",
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
packages=find_packages(exclude=["test"]),
install_requires=[
"dagster",
"dagster_pandas",
"google-api-python-client",
"google-cloud-bigquery>=1.19.*",
"google-cloud-storage",
"oauth2client",
- # RSA 4.1+ is incompatible with py2.7
- 'rsa<=4.0; python_version<"3"',
],
extras_require={"pyarrow": ["pyarrow; python_version < '3.9'"]},
zip_safe=False,
)
diff --git a/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py b/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py
index 11fc8b2a1..da2b0df53 100644
--- a/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py
+++ b/python_modules/libraries/dagster-postgres/dagster_postgres/utils.py
@@ -1,154 +1,154 @@
import logging
import time
from contextlib import contextmanager
+from urllib.parse import quote_plus as urlquote
import psycopg2
import psycopg2.errorcodes
import six
import sqlalchemy
from dagster import Field, IntSource, Selector, StringSource, check
from dagster.core.storage.sql import get_alembic_config, handle_schema_errors
-from dagster.seven import quote_plus as urlquote
class DagsterPostgresException(Exception):
pass
def get_conn(conn_string):
conn = psycopg2.connect(conn_string)
conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
return conn
def pg_config():
return Selector(
{
"postgres_url": str,
"postgres_db": {
"username": StringSource,
"password": StringSource,
"hostname": StringSource,
"db_name": StringSource,
"port": Field(IntSource, is_required=False, default_value=5432),
},
}
)
def pg_url_from_config(config_value):
if config_value.get("postgres_url"):
return config_value["postgres_url"]
return get_conn_string(**config_value["postgres_db"])
def get_conn_string(username, password, hostname, db_name, port="5432"):
return "postgresql://{username}:{password}@{hostname}:{port}/{db_name}".format(
username=username,
password=urlquote(password),
hostname=hostname,
db_name=db_name,
port=port,
)
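# Illustrative example (hypothetical credentials): special characters in the password are
# URL-quoted so the resulting connection URL stays parseable.
#
#     get_conn_string("test", "p@ss/word", "localhost", "test")
#     # -> "postgresql://test:p%40ss%2Fword@localhost:5432/test"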
def retry_pg_creation_fn(fn, retry_limit=5, retry_wait=0.2):
# Retry logic to recover from the case where two processes are creating
# tables at the same time using sqlalchemy
check.callable_param(fn, "fn")
check.int_param(retry_limit, "retry_limit")
check.numeric_param(retry_wait, "retry_wait")
while True:
try:
return fn()
except (
psycopg2.ProgrammingError,
psycopg2.IntegrityError,
sqlalchemy.exc.ProgrammingError,
sqlalchemy.exc.IntegrityError,
) as exc:
# Only programming error we want to retry on is the DuplicateTable error
if (
isinstance(exc, sqlalchemy.exc.ProgrammingError)
and exc.orig
and exc.orig.pgcode != psycopg2.errorcodes.DUPLICATE_TABLE
) or (
isinstance(exc, psycopg2.ProgrammingError)
and exc.pgcode != psycopg2.errorcodes.DUPLICATE_TABLE
):
raise
logging.warning("Retrying failed database creation")
if retry_limit == 0:
six.raise_from(DagsterPostgresException("too many retries for DB creation"), exc)
time.sleep(retry_wait)
retry_limit -= 1
def retry_pg_connection_fn(fn, retry_limit=5, retry_wait=0.2):
"""Reusable retry logic for any psycopg2/sqlalchemy PG connection functions that may fail.
Intended to be used anywhere we connect to PG, to gracefully handle transient connection issues.
"""
check.callable_param(fn, "fn")
check.int_param(retry_limit, "retry_limit")
check.numeric_param(retry_wait, "retry_wait")
while True:
try:
return fn()
except (
# See: https://www.psycopg.org/docs/errors.html
# These are broad, we may want to list out specific exceptions to capture
psycopg2.DatabaseError,
psycopg2.OperationalError,
sqlalchemy.exc.DatabaseError,
sqlalchemy.exc.OperationalError,
) as exc:
logging.warning("Retrying failed database connection")
if retry_limit == 0:
six.raise_from(DagsterPostgresException("too many retries for DB connection"), exc)
time.sleep(retry_wait)
retry_limit -= 1
def wait_for_connection(conn_string, retry_limit=5, retry_wait=0.2):
retry_pg_connection_fn(
lambda: psycopg2.connect(conn_string), retry_limit=retry_limit, retry_wait=retry_wait
)
return True
@contextmanager
def create_pg_connection(engine, dunder_file, storage_type_desc=None):
check.inst_param(engine, "engine", sqlalchemy.engine.Engine)
check.str_param(dunder_file, "dunder_file")
check.opt_str_param(storage_type_desc, "storage_type_desc", "")
if storage_type_desc:
storage_type_desc += " "
else:
storage_type_desc = ""
conn = None
try:
# Retry connection to gracefully handle transient connection issues
conn = retry_pg_connection_fn(engine.connect)
with handle_schema_errors(
conn,
get_alembic_config(dunder_file),
msg="Postgres {}storage requires migration".format(storage_type_desc),
):
yield conn
finally:
if conn:
conn.close()
def pg_statement_timeout(millis):
check.int_param(millis, "millis")
return "-c statement_timeout={}".format(millis)
diff --git a/python_modules/libraries/dagster-postgres/dagster_postgres_tests/compat_tests/test_back_compat.py b/python_modules/libraries/dagster-postgres/dagster_postgres_tests/compat_tests/test_back_compat.py
index a09684e17..e31b86159 100644
--- a/python_modules/libraries/dagster-postgres/dagster_postgres_tests/compat_tests/test_back_compat.py
+++ b/python_modules/libraries/dagster-postgres/dagster_postgres_tests/compat_tests/test_back_compat.py
@@ -1,220 +1,221 @@
# pylint: disable=protected-access
import os
import re
import subprocess
+import tempfile
import pytest
-from dagster import execute_pipeline, pipeline, seven, solid
+from dagster import execute_pipeline, pipeline, solid
from dagster.core.errors import DagsterInstanceMigrationRequired
from dagster.core.instance import DagsterInstance
from dagster.core.storage.event_log.migration import migrate_event_log_data
from dagster.core.storage.event_log.sql_event_log import SqlEventLogStorage
from dagster.utils import file_relative_path
from sqlalchemy import create_engine
def test_0_6_6_postgres(hostname, conn_string):
# Init a fresh postgres with a 0.6.6 snapshot
engine = create_engine(conn_string)
engine.execute("drop schema public cascade;")
engine.execute("create schema public;")
env = os.environ.copy()
env["PGPASSWORD"] = "test"
subprocess.check_call(
[
"psql",
"-h",
hostname,
"-p",
"5432",
"-U",
"test",
"-f",
file_relative_path(__file__, "snapshot_0_6_6/postgres/pg_dump.txt"),
],
env=env,
)
run_id = "089287c5-964d-44c0-b727-357eb7ba522e"
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
# Create the dagster.yaml
with open(file_relative_path(__file__, "dagster.yaml"), "r") as template_fd:
with open(os.path.join(tempdir, "dagster.yaml"), "w") as target_fd:
template = template_fd.read().format(hostname=hostname)
target_fd.write(template)
instance = DagsterInstance.from_config(tempdir)
# Runs will appear in DB, but event logs need migration
runs = instance.get_runs()
assert len(runs) == 1
assert instance.get_run_by_id(run_id)
assert instance.all_logs(run_id) == []
# Post migration, event logs appear in DB
instance.upgrade()
runs = instance.get_runs()
assert len(runs) == 1
assert instance.get_run_by_id(run_id)
assert len(instance.all_logs(run_id)) == 89
def test_0_7_6_postgres_pre_event_log_migration(hostname, conn_string):
engine = create_engine(conn_string)
engine.execute("drop schema public cascade;")
engine.execute("create schema public;")
env = os.environ.copy()
env["PGPASSWORD"] = "test"
subprocess.check_call(
[
"psql",
"-h",
hostname,
"-p",
"5432",
"-U",
"test",
"-f",
file_relative_path(
__file__, "snapshot_0_7_6_pre_event_log_migration/postgres/pg_dump.txt"
),
],
env=env,
)
run_id = "ca7f1e33-526d-4f75-9bc5-3e98da41ab97"
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with open(file_relative_path(__file__, "dagster.yaml"), "r") as template_fd:
with open(os.path.join(tempdir, "dagster.yaml"), "w") as target_fd:
template = template_fd.read().format(hostname=hostname)
target_fd.write(template)
instance = DagsterInstance.from_config(tempdir)
# Runs will appear in DB, but event logs need migration
runs = instance.get_runs()
assert len(runs) == 1
assert instance.get_run_by_id(run_id)
# Make sure the schema is migrated
instance.upgrade()
assert isinstance(instance._event_storage, SqlEventLogStorage)
events_by_id = instance._event_storage.get_logs_for_run_by_log_id(run_id)
assert len(events_by_id) == 40
step_key_records = []
for record_id, _event in events_by_id.items():
row_data = instance._event_storage.get_event_log_table_data(run_id, record_id)
if row_data.step_key is not None:
step_key_records.append(row_data)
assert len(step_key_records) == 0
# run the event_log data migration
migrate_event_log_data(instance=instance)
step_key_records = []
for record_id, _event in events_by_id.items():
row_data = instance._event_storage.get_event_log_table_data(run_id, record_id)
if row_data.step_key is not None:
step_key_records.append(row_data)
assert len(step_key_records) > 0
def test_0_7_6_postgres_pre_add_pipeline_snapshot(hostname, conn_string):
engine = create_engine(conn_string)
engine.execute("drop schema public cascade;")
engine.execute("create schema public;")
env = os.environ.copy()
env["PGPASSWORD"] = "test"
subprocess.check_call(
[
"psql",
"-h",
hostname,
"-p",
"5432",
"-U",
"test",
"-f",
file_relative_path(
__file__, "snapshot_0_7_6_pre_add_pipeline_snapshot/postgres/pg_dump.txt"
),
],
env=env,
)
run_id = "d5f89349-7477-4fab-913e-0925cef0a959"
- with seven.TemporaryDirectory() as tempdir:
+ with tempfile.TemporaryDirectory() as tempdir:
with open(file_relative_path(__file__, "dagster.yaml"), "r") as template_fd:
with open(os.path.join(tempdir, "dagster.yaml"), "w") as target_fd:
template = template_fd.read().format(hostname=hostname)
target_fd.write(template)
instance = DagsterInstance.from_config(tempdir)
@solid
def noop_solid(_):
pass
@pipeline
def noop_pipeline():
noop_solid()
with pytest.raises(
DagsterInstanceMigrationRequired, match=_migration_regex("run", current_revision=None)
):
execute_pipeline(noop_pipeline, instance=instance)
# ensure migration is run
instance.upgrade()
runs = instance.get_runs()
assert len(runs) == 1
assert runs[0].run_id == run_id
run = instance.get_run_by_id(run_id)
assert run.run_id == run_id
assert run.pipeline_snapshot_id is None
result = execute_pipeline(noop_pipeline, instance=instance)
assert result.success
runs = instance.get_runs()
assert len(runs) == 2
new_run_id = result.run_id
new_run = instance.get_run_by_id(new_run_id)
assert new_run.pipeline_snapshot_id
def _migration_regex(storage_name, current_revision, expected_revision=None):
warning = re.escape(
"Instance is out of date and must be migrated (Postgres {} storage requires migration).".format(
storage_name
)
)
if expected_revision:
revision = re.escape(
"Database is at revision {}, head is {}.".format(current_revision, expected_revision)
)
else:
revision = "Database is at revision {}, head is [a-z0-9]+.".format(current_revision)
instruction = re.escape("Please run `dagster instance migrate`.")
return "{} {} {}".format(warning, revision, instruction)
diff --git a/python_modules/libraries/dagster-postgres/dagster_postgres_tests/test_instance.py b/python_modules/libraries/dagster-postgres/dagster_postgres_tests/test_instance.py
index 4551fdab2..33fb09764 100644
--- a/python_modules/libraries/dagster-postgres/dagster_postgres_tests/test_instance.py
+++ b/python_modules/libraries/dagster-postgres/dagster_postgres_tests/test_instance.py
@@ -1,92 +1,93 @@
+import tempfile
+
import pytest
import sqlalchemy as db
import yaml
-from dagster import seven
from dagster.core.instance import DagsterInstance, InstanceRef
from dagster.core.test_utils import instance_for_test
from dagster_postgres.utils import get_conn
def full_pg_config(hostname):
return """
run_storage:
module: dagster_postgres.run_storage
class: PostgresRunStorage
config:
postgres_db:
username: test
password: test
hostname: {hostname}
db_name: test
event_log_storage:
module: dagster_postgres.event_log
class: PostgresEventLogStorage
config:
postgres_db:
username: test
password: test
hostname: {hostname}
db_name: test
schedule_storage:
module: dagster_postgres.schedule_storage
class: PostgresScheduleStorage
config:
postgres_db:
username: test
password: test
hostname: {hostname}
db_name: test
""".format(
hostname=hostname
)
def test_connection_leak(hostname, conn_string):
num_instances = 20
- tempdir = seven.TemporaryDirectory()
+ tempdir = tempfile.TemporaryDirectory()
copies = []
for _ in range(num_instances):
copies.append(
DagsterInstance.from_ref(
InstanceRef.from_dir(
tempdir.name, overrides=yaml.safe_load(full_pg_config(hostname))
)
)
)
with get_conn(conn_string).cursor() as curs:
# count open connections
curs.execute("SELECT count(*) from pg_stat_activity")
res = curs.fetchall()
# This includes a number of internal connections, so just ensure it did not scale
# with number of instances
assert res[0][0] < num_instances
for copy in copies:
copy.dispose()
tempdir.cleanup()
def test_statement_timeouts(hostname):
with instance_for_test(overrides=yaml.safe_load(full_pg_config(hostname))) as instance:
instance.optimize_for_dagit(statement_timeout=500) # 500ms
# ensure migration error is not raised by being up to date
instance.upgrade()
with pytest.raises(db.exc.OperationalError, match="QueryCanceled"):
with instance._run_storage.connect() as conn: # pylint: disable=protected-access
conn.execute("select pg_sleep(1)").fetchone()
with pytest.raises(db.exc.OperationalError, match="QueryCanceled"):
with instance._event_storage.connect() as conn: # pylint: disable=protected-access
conn.execute("select pg_sleep(1)").fetchone()
with pytest.raises(db.exc.OperationalError, match="QueryCanceled"):
with instance._schedule_storage.connect() as conn: # pylint: disable=protected-access
conn.execute("select pg_sleep(1)").fetchone()
diff --git a/python_modules/libraries/dagstermill/dagstermill/solids.py b/python_modules/libraries/dagstermill/dagstermill/solids.py
index 435bac065..e51366be1 100644
--- a/python_modules/libraries/dagstermill/dagstermill/solids.py
+++ b/python_modules/libraries/dagstermill/dagstermill/solids.py
@@ -1,329 +1,330 @@
import copy
import os
import pickle
import sys
+import tempfile
import uuid
import nbformat
import papermill
from dagster import (
AssetMaterialization,
EventMetadataEntry,
InputDefinition,
Output,
OutputDefinition,
SolidDefinition,
check,
seven,
)
from dagster.core.definitions.reconstructable import ReconstructablePipeline
from dagster.core.execution.context.compute import SolidExecutionContext
from dagster.core.execution.context.system import SystemComputeExecutionContext
from dagster.core.storage.file_manager import FileHandle
from dagster.serdes import pack_value
from dagster.utils import mkdir_p, safe_tempfile_path
from dagster.utils.error import serializable_error_info_from_exc_info
from papermill.engines import papermill_engines
from papermill.iorw import load_notebook_node, write_ipynb
from papermill.parameterize import _find_first_tagged_cell_index
from .engine import DagstermillNBConvertEngine
from .errors import DagstermillError
from .serialize import read_value, write_value
from .translator import RESERVED_INPUT_NAMES, DagsterTranslator
# This is based on papermill.parameterize.parameterize_notebook
# Typically, papermill injects the injected-parameters cell *below* the parameters cell
# but we want to *replace* the parameters cell, which is what this function does.
def replace_parameters(context, nb, parameters):
"""Assigned parameters into the appropiate place in the input notebook
Args:
nb (NotebookNode): Executable notebook object
parameters (dict): Arbitrary keyword arguments to pass to the notebook parameters.
"""
check.dict_param(parameters, "parameters")
# Copy the nb object to avoid polluting the input
nb = copy.deepcopy(nb)
# papermill method chooses translator based on kernel_name and language, but we just call the
# DagsterTranslator to generate parameter content based on the kernel_name
param_content = DagsterTranslator.codify(parameters)
newcell = nbformat.v4.new_code_cell(source=param_content)
newcell.metadata["tags"] = ["injected-parameters"]
param_cell_index = _find_first_tagged_cell_index(nb, "parameters")
injected_cell_index = _find_first_tagged_cell_index(nb, "injected-parameters")
if injected_cell_index >= 0:
# Replace the injected cell with a new version
before = nb.cells[:injected_cell_index]
after = nb.cells[injected_cell_index + 1 :]
check.int_value_param(param_cell_index, -1, "param_cell_index")
# We should have blown away the parameters cell if there is an injected-parameters cell
elif param_cell_index >= 0:
# Replace the parameter cell with the injected-parameters cell
before = nb.cells[:param_cell_index]
after = nb.cells[param_cell_index + 1 :]
else:
# Inject at the top of the notebook; presumably the first cell includes the dagstermill import
context.log.debug(
(
"Executing notebook with no tagged parameters cell: injecting boilerplate in first "
"cell."
)
)
before = []
after = nb.cells
nb.cells = before + [newcell] + after
nb.metadata.papermill["parameters"] = seven.json.dumps(parameters)
return nb
def get_papermill_parameters(compute_context, inputs, output_log_path):
check.inst_param(compute_context, "compute_context", SystemComputeExecutionContext)
check.param_invariant(
isinstance(compute_context.run_config, dict),
"compute_context",
"SystemComputeExecutionContext must have valid run_config",
)
check.dict_param(inputs, "inputs", key_type=str)
run_id = compute_context.run_id
marshal_dir = "/tmp/dagstermill/{run_id}/marshal".format(run_id=run_id)
mkdir_p(marshal_dir)
if not isinstance(compute_context.pipeline, ReconstructablePipeline):
raise DagstermillError(
"Can't execute a dagstermill solid from a pipeline that is not reconstructable. "
"Use the reconstructable() function if executing from python"
)
dm_executable_dict = compute_context.pipeline.to_dict()
dm_context_dict = {
"output_log_path": output_log_path,
"marshal_dir": marshal_dir,
"run_config": compute_context.run_config,
}
dm_solid_handle_kwargs = compute_context.solid_handle._asdict()
parameters = {}
input_def_dict = compute_context.solid_def.input_dict
for input_name, input_value in inputs.items():
assert (
input_name not in RESERVED_INPUT_NAMES
), "Dagstermill solids cannot have inputs named {input_name}".format(input_name=input_name)
dagster_type = input_def_dict[input_name].dagster_type
parameter_value = write_value(
dagster_type, input_value, os.path.join(marshal_dir, "input-{}".format(input_name))
)
parameters[input_name] = parameter_value
parameters["__dm_context"] = dm_context_dict
parameters["__dm_executable_dict"] = dm_executable_dict
parameters["__dm_pipeline_run_dict"] = pack_value(compute_context.pipeline_run)
parameters["__dm_solid_handle_kwargs"] = dm_solid_handle_kwargs
parameters["__dm_instance_ref_dict"] = pack_value(compute_context.instance.get_ref())
return parameters
def _dm_solid_compute(name, notebook_path, output_notebook=None, asset_key_prefix=None):
check.str_param(name, "name")
check.str_param(notebook_path, "notebook_path")
check.opt_str_param(output_notebook, "output_notebook")
check.opt_list_param(asset_key_prefix, "asset_key_prefix")
def _t_fn(compute_context, inputs):
check.inst_param(compute_context, "compute_context", SolidExecutionContext)
check.param_invariant(
isinstance(compute_context.run_config, dict),
"context",
"SystemComputeExecutionContext must have valid run_config",
)
system_compute_context = compute_context.get_system_context()
- with seven.TemporaryDirectory() as output_notebook_dir:
+ with tempfile.TemporaryDirectory() as output_notebook_dir:
with safe_tempfile_path() as output_log_path:
parameterized_notebook_path = os.path.join(
output_notebook_dir, "{prefix}-inter.ipynb".format(prefix=str(uuid.uuid4()))
)
executed_notebook_path = os.path.join(
output_notebook_dir, "{prefix}-out.ipynb".format(prefix=str(uuid.uuid4()))
)
# Scaffold the registration here
nb = load_notebook_node(notebook_path)
nb_no_parameters = replace_parameters(
system_compute_context,
nb,
get_papermill_parameters(system_compute_context, inputs, output_log_path),
)
write_ipynb(nb_no_parameters, parameterized_notebook_path)
try:
papermill_engines.register("dagstermill", DagstermillNBConvertEngine)
papermill.execute_notebook(
input_path=parameterized_notebook_path,
output_path=executed_notebook_path,
engine_name="dagstermill",
log_output=True,
)
except Exception: # pylint: disable=broad-except
try:
with open(executed_notebook_path, "rb") as fd:
executed_notebook_file_handle = compute_context.resources.file_manager.write(
fd, mode="wb", ext="ipynb"
)
executed_notebook_materialization_path = (
executed_notebook_file_handle.path_desc
)
except Exception: # pylint: disable=broad-except
compute_context.log.warning(
"Error when attempting to materialize executed notebook using file manager (falling back to local): {exc}".format(
exc=str(serializable_error_info_from_exc_info(sys.exc_info()))
)
)
executed_notebook_materialization_path = executed_notebook_path
yield AssetMaterialization(
asset_key=(asset_key_prefix + [f"{name}_output_notebook"]),
description="Location of output notebook in file manager",
metadata_entries=[
EventMetadataEntry.fspath(
executed_notebook_materialization_path,
label="executed_notebook_path",
)
],
)
raise
system_compute_context.log.debug(
"Notebook execution complete for {name} at {executed_notebook_path}.".format(
name=name, executed_notebook_path=executed_notebook_path,
)
)
executed_notebook_file_handle = None
try:
# use binary mode when moving the file since certain file_managers such as S3
# may try to hash the contents
with open(executed_notebook_path, "rb") as fd:
executed_notebook_file_handle = compute_context.resources.file_manager.write(
fd, mode="wb", ext="ipynb"
)
executed_notebook_materialization_path = executed_notebook_file_handle.path_desc
except Exception: # pylint: disable=broad-except
compute_context.log.warning(
"Error when attempting to materialize executed notebook using file manager (falling back to local): {exc}".format(
exc=str(serializable_error_info_from_exc_info(sys.exc_info()))
)
)
executed_notebook_materialization_path = executed_notebook_path
yield AssetMaterialization(
asset_key=(asset_key_prefix + [f"{name}_output_notebook"]),
description="Location of output notebook in file manager",
metadata_entries=[
EventMetadataEntry.fspath(executed_notebook_materialization_path)
],
)
if output_notebook is not None:
yield Output(executed_notebook_file_handle, output_notebook)
# deferred import for perf
import scrapbook
output_nb = scrapbook.read_notebook(executed_notebook_path)
for (output_name, output_def) in system_compute_context.solid_def.output_dict.items():
data_dict = output_nb.scraps.data_dict
if output_name in data_dict:
value = read_value(output_def.dagster_type, data_dict[output_name])
yield Output(value, output_name)
for key, value in output_nb.scraps.items():
if key.startswith("event-"):
with open(value.data, "rb") as fd:
yield pickle.loads(fd.read())
return _t_fn
def define_dagstermill_solid(
name,
notebook_path,
input_defs=None,
output_defs=None,
config_schema=None,
required_resource_keys=None,
output_notebook=None,
asset_key_prefix=None,
):
"""Wrap a Jupyter notebook in a solid.
Arguments:
name (str): The name of the solid.
notebook_path (str): Path to the backing notebook.
input_defs (Optional[List[InputDefinition]]): The solid's inputs.
output_defs (Optional[List[OutputDefinition]]): The solid's outputs. Your notebook should
call :py:func:`~dagstermill.yield_result` to yield each of these outputs.
required_resource_keys (Optional[Set[str]]): The string names of any required resources.
output_notebook (Optional[str]): If set, will be used as the name of an injected output of
type :py:class:`~dagster.FileHandle` that will point to the executed notebook (in
addition to the :py:class:`~dagster.AssetMaterialization` that is always created). This
respects the :py:class:`~dagster.core.storage.file_manager.FileManager` configured on
the pipeline resources via the "file_manager" resource key, so, e.g.,
if :py:class:`~dagster_aws.s3.s3_file_manager` is configured, the output will be an
:py:class:`~dagster_aws.s3.S3FileHandle`.
asset_key_prefix (Optional[Union[List[str], str]]): If set, will be used to prefix the
asset keys for materialized notebooks.
Returns:
:py:class:`~dagster.SolidDefinition`
"""
check.str_param(name, "name")
check.str_param(notebook_path, "notebook_path")
input_defs = check.opt_list_param(input_defs, "input_defs", of_type=InputDefinition)
output_defs = check.opt_list_param(output_defs, "output_defs", of_type=OutputDefinition)
required_resource_keys = check.opt_set_param(
required_resource_keys, "required_resource_keys", of_type=str
)
if output_notebook is not None:
required_resource_keys.add("file_manager")
if isinstance(asset_key_prefix, str):
asset_key_prefix = [asset_key_prefix]
asset_key_prefix = check.opt_list_param(asset_key_prefix, "asset_key_prefix", of_type=str)
return SolidDefinition(
name=name,
input_defs=input_defs,
compute_fn=_dm_solid_compute(
name, notebook_path, output_notebook, asset_key_prefix=asset_key_prefix
),
output_defs=output_defs
+ (
[OutputDefinition(dagster_type=FileHandle, name=output_notebook)]
if output_notebook
else []
),
config_schema=config_schema,
required_resource_keys=required_resource_keys,
description="This solid is backed by the notebook at {path}".format(path=notebook_path),
tags={"notebook_path": notebook_path, "kind": "ipynb"},
)
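# Illustrative usage sketch (the notebook path, solid name, and definitions are hypothetical):
#
#     from dagster import InputDefinition, OutputDefinition, Int
#     from dagstermill import define_dagstermill_solid
#
#     clean_data = define_dagstermill_solid(
#         name="clean_data",
#         notebook_path="notebooks/clean_data.ipynb",
#         input_defs=[InputDefinition("raw_count", Int)],
#         output_defs=[OutputDefinition(Int, "clean_count")],
#         output_notebook="output_notebook",  # requires a "file_manager" resource on the pipeline
#     )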
diff --git a/python_modules/libraries/dagstermill/dev-requirements.txt b/python_modules/libraries/dagstermill/dev-requirements.txt
index d950c0441..ab25a047f 100644
--- a/python_modules/libraries/dagstermill/dev-requirements.txt
+++ b/python_modules/libraries/dagstermill/dev-requirements.txt
@@ -1,3 +1,2 @@
matplotlib
-scikit-learn>=0.19.0,<0.22.1; python_version <= '3.5'
-scikit-learn>=0.19.0; python_version > '3.5'
+scikit-learn>=0.19.0
\ No newline at end of file