diff --git a/examples/airline_demo/airline_demo/pipelines.py b/examples/airline_demo/airline_demo/pipelines.py
--- a/examples/airline_demo/airline_demo/pipelines.py
+++ b/examples/airline_demo/airline_demo/pipelines.py
@@ -1,7 +1,19 @@
 """Pipeline definitions for the airline_demo.
 """
-from dagster import ModeDefinition, PresetDefinition, composite_solid, pipeline
+import os
+
+from dagster import (
+    Field,
+    IOManager,
+    ModeDefinition,
+    Noneable,
+    PresetDefinition,
+    composite_solid,
+    fs_io_manager,
+    io_manager,
+    pipeline,
+)
 from dagster.core.definitions.no_step_launcher import no_step_launcher
 from dagster.core.storage.file_cache import fs_file_cache
 from dagster.core.storage.file_manager import local_file_manager
@@ -12,10 +24,10 @@
     file_handle_to_s3,
     s3_file_cache,
     s3_file_manager,
-    s3_plus_default_intermediate_storage_defs,
     s3_resource,
 )
 from dagster_pyspark import pyspark_resource
+from pyspark.sql import SparkSession
 
 from .cache_file_from_s3 import cache_file_from_s3
 from .resources import postgres_db_info_resource, redshift_db_info_resource
@@ -37,6 +49,27 @@
     westbound_delays,
 )
 
+
+class ParquetIOManager(IOManager):
+    def __init__(self, base_dir=None):
+        self._base_dir = base_dir or os.getcwd()
+
+    def _get_path(self, context):
+        return os.path.join(self._base_dir, context.run_id, context.step_key, context.name)
+
+    def handle_output(self, context, obj):
+        obj.write.parquet(self._get_path(context))
+
+    def load_input(self, context):
+        spark = SparkSession.builder.getOrCreate()
+        return spark.read.parquet(self._get_path(context.upstream_output))
+
+
+@io_manager(config_schema={"base_dir": Field(Noneable(str), is_required=False, default_value=None)})
+def local_parquet_io_manager(init_context):
+    return ParquetIOManager(base_dir=init_context.resource_config["base_dir"])
+
+
 # start_pipelines_marker_2
 test_mode = ModeDefinition(
     name="test",
@@ -48,8 +81,9 @@
         "s3": s3_resource,
         "file_cache": fs_file_cache,
         "file_manager": local_file_manager,
+        "io_manager": fs_io_manager,
+        "pyspark_io_manager": local_parquet_io_manager,
     },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
 )
 
 
@@ -63,8 +97,9 @@
         "tempfile": tempfile_resource,
         "file_cache": fs_file_cache,
         "file_manager": local_file_manager,
+        "io_manager": fs_io_manager,
+        "pyspark_io_manager": local_parquet_io_manager,
     },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
 )
 
 
@@ -78,8 +113,9 @@
         "tempfile": tempfile_resource,
         "file_cache": s3_file_cache,
         "file_manager": s3_file_manager,
+        "io_manager": fs_io_manager,
+        "pyspark_io_manager": local_parquet_io_manager,
     },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
 )
 # end_pipelines_marker_2
 
diff --git a/examples/airline_demo/airline_demo/solids.py b/examples/airline_demo/airline_demo/solids.py
--- a/examples/airline_demo/airline_demo/solids.py
+++ b/examples/airline_demo/airline_demo/solids.py
@@ -154,20 +154,17 @@
 
 
 @solid(
-    required_resource_keys={"pyspark_step_launcher", "pyspark", "file_manager"},
-    description="""Take a file handle that contains a csv with headers and load it
-into a Spark DataFrame. It infers header names but does *not* infer schema.
-
-It also ensures that the column names are valid parquet column names by
-filtering out any of the following characters from column names:
-
-Characters (within quotations): "`{chars}`"
-
-""".format(
-        chars=PARQUET_SPECIAL_CHARACTERS
+    required_resource_keys={"pyspark_step_launcher", "pyspark", "file_manager"},
+    description=(
+        "Take a file handle that contains a csv with headers and load it "
+        "into a Spark DataFrame. It infers header names but does *not* infer schema.\n\n"
+        "It also ensures that the column names are valid parquet column names by "
+        "filtering out any of the following characters from column names:\n\n"
+        f"Characters (within quotations): {PARQUET_SPECIAL_CHARACTERS}"
     ),
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
 )
-def ingest_csv_file_handle_to_spark(context, csv_file_handle: FileHandle) -> DataFrame:
+def ingest_csv_file_handle_to_spark(context, csv_file_handle: FileHandle):
     # fs case: copies from file manager location into system temp
     #  - This is potentially an unnecessary copy. We could potentially specialize
     #    the implementation of copy_handle_to_local_temp to not to do this in the
@@ -203,8 +200,12 @@
     return rename_spark_dataframe_columns(df, lambda c: "{prefix}{c}".format(prefix=prefix, c=c))
 
 
-@solid(required_resource_keys={"pyspark_step_launcher"})
-def canonicalize_column_names(_context, data_frame: DataFrame) -> DataFrame:
+@solid(
+    required_resource_keys={"pyspark_step_launcher"},
+    input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)],
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
+)
+def canonicalize_column_names(_context, data_frame):
     return rename_spark_dataframe_columns(data_frame, lambda c: c.lower())
 
 
@@ -212,19 +213,24 @@
     return data_frame.na.replace(old, new)
 
 
-@solid(required_resource_keys={"pyspark_step_launcher"})
-def process_sfo_weather_data(_context, sfo_weather_data: DataFrame) -> DataFrame:
+@solid(
+    required_resource_keys={"pyspark_step_launcher"},
+    input_defs=[InputDefinition(name="sfo_weather_data", dagster_type=DataFrame)],
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
+)
+def process_sfo_weather_data(_context, sfo_weather_data):
     normalized_sfo_weather_data = replace_values_spark(sfo_weather_data, "M", None)
     return rename_spark_dataframe_columns(normalized_sfo_weather_data, lambda c: c.lower())
 
 
 # start_solids_marker_0
 @solid(
+    input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)],
     output_defs=[OutputDefinition(name="table_name", dagster_type=String)],
     config_schema={"table_name": String},
     required_resource_keys={"db_info", "pyspark_step_launcher"},
 )
-def load_data_to_database_from_spark(context, data_frame: DataFrame):
+def load_data_to_database_from_spark(context, data_frame):
     context.resources.db_info.load_table(data_frame, context.solid_config["table_name"])
 
     table_name = context.solid_config["table_name"]
@@ -252,19 +258,21 @@
             Int, description="The integer percentage of rows to sample from the input dataset.",
         )
     },
+    input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)],
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
 )
-def subsample_spark_dataset(context, data_frame: DataFrame) -> DataFrame:
+def subsample_spark_dataset(context, data_frame):
     return data_frame.sample(
         withReplacement=False, fraction=context.solid_config["subsample_pct"] / 100.0
     )
 
 
 @composite_solid(
-    description="""Ingest a zipped csv file from s3,
-stash in a keyed file store (does not download if already
-present by default), unzip that file, and load it into a
-Spark Dataframe. See documentation in constituent solids for
-more detail."""
+    description=(
+        "Ingest a zipped csv file from s3, stash in a keyed file store (does not download if "
+        "already present by default), unzip that file, and load it into a Spark DataFrame. See "
+        "documentation in constituent solids for more detail."
+    ),
 )
 def s3_to_df(s3_coordinate: S3Coordinate, archive_member: String) -> DataFrame:
     return ingest_csv_file_handle_to_spark(
@@ -278,11 +286,11 @@
         "load_data_to_database_from_spark": {"config": {"table_name": cfg["table_name"]}},
     },
     config_schema={"subsample_pct": int, "table_name": str},
-    description="""Ingest zipped csv file from s3, load into a Spark
-DataFrame, optionally subsample it (via configuring the
-subsample_spark_dataset, solid), canonicalize the column names, and then
-load it into a data warehouse.
-""",
+    description=(
+        "Ingest zipped csv file from s3, load into a Spark DataFrame, optionally subsample it "
+        "(via configuring the subsample_spark_dataset solid), canonicalize the column names, "
+        "and then load it into a data warehouse."
+    ),
 )
 def s3_to_dw_table(s3_coordinate: S3Coordinate, archive_member: String) -> String:
     return load_data_to_database_from_spark(
@@ -526,19 +534,22 @@
 @solid(
     required_resource_keys={"pyspark_step_launcher", "pyspark"},
     config_schema={"subsample_pct": Int},
-    description="""
-    This solid takes April, May, and June data and coalesces it into a q2 data set.
-    It then joins the that origin and destination airport with the data in the
-    master_cord_data.
-    """,
+    description=(
+        "This solid takes April, May, and June data and coalesces it into a Q2 data set. "
+        "It then joins that origin and destination airport with the data in the "
+        "master_cord_data."
+    ),
+    input_defs=[
+        InputDefinition(name="april_data", dagster_type=DataFrame),
+        InputDefinition(name="may_data", dagster_type=DataFrame),
+        InputDefinition(name="june_data", dagster_type=DataFrame),
+        InputDefinition(name="master_cord_data", dagster_type=DataFrame),
+    ],
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
 )
 def join_q2_data(
-    context,
-    april_data: DataFrame,
-    may_data: DataFrame,
-    june_data: DataFrame,
-    master_cord_data: DataFrame,
-) -> DataFrame:
+    context, april_data, may_data, june_data, master_cord_data,
+):
     dfs = {"april": april_data, "may": may_data, "june": june_data}
diff --git a/examples/airline_demo/airline_demo_tests/test_pipelines.py b/examples/airline_demo/airline_demo_tests/test_pipelines.py
--- a/examples/airline_demo/airline_demo_tests/test_pipelines.py
+++ b/examples/airline_demo/airline_demo_tests/test_pipelines.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 
 # pylint: disable=unused-argument
 import pytest
@@ -27,39 +28,23 @@
 @pytest.mark.py3
 @pytest.mark.spark
 def test_ingest_pipeline_fast(postgres, pg_hostname):
-    with instance_for_test() as instance:
-        ingest_config_dict = load_yaml_from_globs(
-            config_path("test_base.yaml"), config_path("local_fast_ingest.yaml")
-        )
-        result_ingest = execute_pipeline(
-            pipeline=ingest_pipeline,
-            mode="local",
-            run_config=ingest_config_dict,
-            instance=instance,
-        )
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with instance_for_test() as instance:
+            ingest_config_dict = load_yaml_from_globs(
+                config_path("test_base.yaml"), config_path("local_fast_ingest.yaml"),
+            )
+            ingest_config_dict["resources"]["io_manager"] = {"config": {"base_dir": temp_dir}}
+            ingest_config_dict["resources"]["pyspark_io_manager"] = {
+                "config": {"base_dir": temp_dir}
+            }
+            result_ingest = execute_pipeline(
+                pipeline=ingest_pipeline,
+                mode="local",
+                run_config=ingest_config_dict,
+                instance=instance,
+            )
 
-        assert result_ingest.success
-
-
-@pytest.mark.db
-@pytest.mark.nettest
-@pytest.mark.py3
-@pytest.mark.spark
-def test_ingest_pipeline_fast_filesystem_storage(postgres, pg_hostname):
-    with instance_for_test() as instance:
-        ingest_config_dict = load_yaml_from_globs(
-            config_path("test_base.yaml"),
-            config_path("local_fast_ingest.yaml"),
-            config_path("filesystem_storage.yaml"),
-        )
-        result_ingest = execute_pipeline(
-            pipeline=ingest_pipeline,
-            mode="local",
-            run_config=ingest_config_dict,
-            instance=instance,
-        )
-
-        assert result_ingest.success
+            assert result_ingest.success
 
 
 @pytest.mark.db
@@ -68,15 +53,20 @@
 @pytest.mark.spark
 @pytest.mark.skipif('"win" in sys.platform', reason="avoiding the geopandas tests")
 def test_airline_pipeline_1_warehouse(postgres, pg_hostname):
-    with instance_for_test() as instance:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with instance_for_test() as instance:
 
-        warehouse_config_object = load_yaml_from_globs(
-            config_path("test_base.yaml"), config_path("local_warehouse.yaml")
-        )
-        result_warehouse = execute_pipeline(
-            pipeline=warehouse_pipeline,
-            mode="local",
-            run_config=warehouse_config_object,
-            instance=instance,
-        )
-        assert result_warehouse.success
+            warehouse_config_object = load_yaml_from_globs(
+                config_path("test_base.yaml"), config_path("local_warehouse.yaml")
+            )
+            warehouse_config_object["resources"]["io_manager"] = {"config": {"base_dir": temp_dir}}
+            warehouse_config_object["resources"]["pyspark_io_manager"] = {
+                "config": {"base_dir": temp_dir}
+            }
+            result_warehouse = execute_pipeline(
+                pipeline=warehouse_pipeline,
+                mode="local",
+                run_config=warehouse_config_object,
+                instance=instance,
+            )
+            assert result_warehouse.success
diff --git a/examples/airline_demo/airline_demo_tests/test_solids.py b/examples/airline_demo/airline_demo_tests/test_solids.py
--- a/examples/airline_demo/airline_demo_tests/test_solids.py
+++ b/examples/airline_demo/airline_demo_tests/test_solids.py
@@ -28,33 +28,5 @@
 
 
 def test_sql_solid():
-    result = sql_solid("foo", "select * from bar", "table", "quux")
-    assert result
-    # TODO: test execution?
-
-
-@pytest.mark.postgres
-@pytest.mark.skip
-@pytest.mark.spark
-def test_load_data_to_postgres_from_spark_postgres():
-    raise NotImplementedError()
-
-
-@pytest.mark.nettest
-@pytest.mark.redshift
-@pytest.mark.skip
-@pytest.mark.spark
-def test_load_data_to_redshift_from_spark():
-    raise NotImplementedError()
-
-
-@pytest.mark.skip
-@pytest.mark.spark
-def test_subsample_spark_dataset():
-    raise NotImplementedError()
-
-
-@pytest.mark.skip
-@pytest.mark.spark
-def test_join_spark_data_frame():
-    raise NotImplementedError()
+    solid = sql_solid("foo", "select * from bar", "table", "quux")
+    assert solid
diff --git a/examples/airline_demo/airline_demo_tests/test_types.py b/examples/airline_demo/airline_demo_tests/test_types.py
deleted file mode 100644
--- a/examples/airline_demo/airline_demo_tests/test_types.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import os
-from tempfile import TemporaryDirectory
-
-import botocore
-import pyspark
-from airline_demo.solids import ingest_csv_file_handle_to_spark
-from dagster import (
-    InputDefinition,
-    LocalFileHandle,
-    ModeDefinition,
-    OutputDefinition,
-    execute_pipeline,
-    file_relative_path,
-    local_file_manager,
-    pipeline,
-    solid,
-)
-from dagster.core.definitions.no_step_launcher import no_step_launcher
-from dagster.core.instance import DagsterInstance
-from dagster.core.storage.intermediate_storage import build_fs_intermediate_storage
-from dagster.core.storage.temp_file_manager import tempfile_resource
-from dagster_aws.s3 import s3_file_manager, s3_plus_default_intermediate_storage_defs, s3_resource
-from dagster_aws.s3.intermediate_storage import S3IntermediateStorage
-from dagster_pyspark import DataFrame, pyspark_resource
-from pyspark.sql import Row, SparkSession
-
-spark_local_fs_mode = ModeDefinition(
-    name="spark",
-    resource_defs={
-        "pyspark": pyspark_resource,
-        "tempfile": tempfile_resource,
-        "s3": s3_resource,
-        "pyspark_step_launcher": no_step_launcher,
-        "file_manager": local_file_manager,
-    },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
-)
-
-spark_s3_mode = ModeDefinition(
-    name="spark",
-    resource_defs={
-        "pyspark": pyspark_resource,
-        "tempfile": tempfile_resource,
-        "s3": s3_resource,
-        "pyspark_step_launcher": no_step_launcher,
-        "file_manager": s3_file_manager,
-    },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
-)
-
-
-def test_spark_data_frame_serialization_file_system_file_handle(spark_config):
-    @solid
-    def nonce(_):
-        return LocalFileHandle(file_relative_path(__file__, "data/test.csv"))
-
-    @pipeline(mode_defs=[spark_local_fs_mode])
-    def spark_df_test_pipeline():
-        ingest_csv_file_handle_to_spark(nonce())
-
-    instance = DagsterInstance.ephemeral()
-
-    result = execute_pipeline(
-        spark_df_test_pipeline,
-        mode="spark",
-        run_config={
-            "intermediate_storage": {"filesystem": {}},
-            "resources": {"pyspark": {"config": {"spark_conf": spark_config}}},
-        },
-        instance=instance,
-    )
-
-    intermediate_storage = build_fs_intermediate_storage(
-        instance.intermediates_directory, run_id=result.run_id
-    )
-
-    assert result.success
-    result_dir = os.path.join(
-        intermediate_storage.root, "intermediates", "ingest_csv_file_handle_to_spark", "result",
-    )
-
-    assert "_SUCCESS" in os.listdir(result_dir)
-
-    spark = SparkSession.builder.getOrCreate()
-    df = spark.read.parquet(result_dir)
-    assert isinstance(df, pyspark.sql.dataframe.DataFrame)
-    assert df.head()[0] == "1"
-
-
-def test_spark_data_frame_serialization_s3_file_handle(s3_bucket, spark_config):
-    @solid(required_resource_keys={"file_manager"})
-    def nonce(context):
-        with open(os.path.join(os.path.dirname(__file__), "data/test.csv"), "rb") as fd:
-            return context.resources.file_manager.write_data(fd.read())
-
-    @pipeline(mode_defs=[spark_s3_mode])
-    def spark_df_test_pipeline():
-
-        ingest_csv_file_handle_to_spark(nonce())
-
-    result = execute_pipeline(
-        spark_df_test_pipeline,
-        run_config={
-            "intermediate_storage": {"s3": {"config": {"s3_bucket": s3_bucket}}},
-            "resources": {
-                "pyspark": {"config": {"spark_conf": spark_config}},
-                "file_manager": {"config": {"s3_bucket": s3_bucket}},
-            },
-        },
-        mode="spark",
-    )
-
-    assert result.success
-
-    intermediate_storage = S3IntermediateStorage(s3_bucket=s3_bucket, run_id=result.run_id)
-
-    success_key = "/".join(
-        [
-            intermediate_storage.root.strip("/"),
-            "intermediates",
-            "ingest_csv_file_handle_to_spark",
-            "result",
-            "_SUCCESS",
-        ]
-    )
-    try:
-        assert intermediate_storage.object_store.s3.get_object(
-            Bucket=intermediate_storage.object_store.bucket, Key=success_key
-        )
-    except botocore.exceptions.ClientError:
-        raise Exception("Couldn't find object at {success_key}".format(success_key=success_key))
-
-
-def test_spark_dataframe_output_csv(spark_config):
-    spark = SparkSession.builder.getOrCreate()
-    num_df = (
-        spark.read.format("csv")
-        .options(header="true", inferSchema="true")
-        .load(file_relative_path(__file__, "num.csv"))
-    )
-
-    assert num_df.collect() == [Row(num1=1, num2=2)]
-
-    @solid
-    def emit(_):
-        return num_df
-
-    @solid(input_defs=[InputDefinition("df", DataFrame)], output_defs=[OutputDefinition(DataFrame)])
-    def passthrough_df(_context, df):
-        return df
-
-    @pipeline(mode_defs=[spark_local_fs_mode])
-    def passthrough():
-        passthrough_df(emit())
-
-    with TemporaryDirectory() as tempdir:
-        file_name = os.path.join(tempdir, "output.csv")
-        result = execute_pipeline(
-            passthrough,
-            run_config={
-                "resources": {"pyspark": {"config": {"spark_conf": spark_config}}},
-                "solids": {
-                    "passthrough_df": {
-                        "outputs": [{"result": {"csv": {"path": file_name, "header": True}}}]
-                    }
-                },
-            },
-        )
-
-        from_file_df = (
-            spark.read.format("csv").options(header="true", inferSchema="true").load(file_name)
-        )
-
-        assert (
-            result.result_for_solid("passthrough_df").output_value().collect()
-            == from_file_df.collect()
-        )
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 
 import pytest
 from airline_demo.cache_file_from_s3 import cache_file_from_s3
@@ -12,7 +13,6 @@
 )
 from dagster.core.storage.file_cache import LocalFileHandle, fs_file_cache
 from dagster.seven import mock
-from dagster.utils.temp_file import get_temp_dir
 
 
 def execute_solid_with_resources(solid_def, resource_defs, run_config):
@@ -28,7 +28,7 @@
 
 def test_cache_file_from_s3_basic():
     s3_session = mock.MagicMock()
-    with get_temp_dir() as temp_dir:
+    with tempfile.TemporaryDirectory() as temp_dir:
         solid_result = execute_solid(
             cache_file_from_s3,
             ModeDefinition(
@@ -67,7 +67,7 @@
 
 def test_cache_file_from_s3_specify_target_key():
     s3_session = mock.MagicMock()
-    with get_temp_dir() as temp_dir:
+    with tempfile.TemporaryDirectory() as temp_dir:
         solid_result = execute_solid(
             cache_file_from_s3,
             ModeDefinition(
@@ -95,7 +95,7 @@
 
 
 def test_cache_file_from_s3_skip_download():
-    with get_temp_dir() as temp_dir:
+    with tempfile.TemporaryDirectory() as temp_dir:
         s3_session_one = mock.MagicMock()
         execute_solid(
             cache_file_from_s3,
@@ -142,7 +142,7 @@
 
 
 def test_cache_file_from_s3_overwrite():
-    with get_temp_dir() as temp_dir:
+    with tempfile.TemporaryDirectory() as temp_dir:
         s3_session_one = mock.MagicMock()
         execute_solid(
             cache_file_from_s3,
@@ -194,7 +194,7 @@
 
 def test_missing_resources():
     with pytest.raises(DagsterInvalidDefinitionError):
-        with get_temp_dir() as temp_dir:
+        with tempfile.TemporaryDirectory() as temp_dir:
             execute_solid(
                 cache_file_from_s3,
                 ModeDefinition(resource_defs={"file_cache": fs_file_cache}),
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py
@@ -1,8 +1,12 @@
+import tempfile
+
+from airline_demo.pipelines import local_parquet_io_manager
 from airline_demo.solids import ingest_csv_file_handle_to_spark
 from dagster import (
     LocalFileHandle,
     ModeDefinition,
     execute_pipeline,
+    fs_io_manager,
     local_file_manager,
     pipeline,
     solid,
@@ -32,7 +36,9 @@
                 resource_defs={
                     "pyspark": pyspark_resource,
                     "pyspark_step_launcher": no_step_launcher,
+                    "pyspark_io_manager": local_parquet_io_manager,
                     "file_manager": local_file_manager,
+                    "io_manager": fs_io_manager,
                 }
             )
         ]
@@ -40,14 +46,21 @@
     def ingest_csv_file_test():
         return collect_df(ingest_csv_file_handle_to_spark(emit_num_csv_local_file()))
 
-    result = execute_pipeline(
-        ingest_csv_file_test,
-        run_config={"resources": {"pyspark": {"config": {"spark_conf": spark_config}}}},
-    )
-    assert result.success
+    with tempfile.TemporaryDirectory() as temp_dir:
+        result = execute_pipeline(
+            ingest_csv_file_test,
+            run_config={
+                "resources": {
+                    "pyspark": {"config": {"spark_conf": spark_config}},
+                    "pyspark_io_manager": {"config": {"base_dir": temp_dir}},
+                    "io_manager": {"config": {"base_dir": temp_dir}},
+                }
+            },
+        )
+        assert result.success
 
-    df = result.result_for_solid("collect_df").output_value()
-    assert df == [Row(num1="1", num2="2")]
+        df = result.result_for_solid("collect_df").output_value()
+        assert df == [Row(num1="1", num2="2")]
 
 
 def test_ingest_csv_file_with_special_handle_to_spark(spark_config):
@@ -62,6 +75,8 @@
                     "pyspark": pyspark_resource,
                     "pyspark_step_launcher": no_step_launcher,
                     "file_manager": local_file_manager,
+                    "pyspark_io_manager": local_parquet_io_manager,
+                    "io_manager": fs_io_manager,
                 }
             )
         ]
@@ -69,12 +84,19 @@
     def ingest_csv_file_test():
         return collect_df(ingest_csv_file_handle_to_spark(emit_num_special_csv_local_file()))
 
-    result = execute_pipeline(
-        ingest_csv_file_test,
-        run_config={"resources": {"pyspark": {"config": {"spark_conf": spark_config}}}},
-    )
-    assert result.success
+    with tempfile.TemporaryDirectory() as temp_dir:
+        result = execute_pipeline(
+            ingest_csv_file_test,
+            run_config={
+                "resources": {
+                    "pyspark": {"config": {"spark_conf": spark_config}},
+                    "pyspark_io_manager": {"config": {"base_dir": temp_dir}},
+                    "io_manager": {"config": {"base_dir": temp_dir}},
+                }
+            },
+        )
+        assert result.success
 
-    df = result.result_for_solid("collect_df").output_value()
+        df = result.result_for_solid("collect_df").output_value()
 
-    assert df == [Row(num1="1", num2="2")]
+        assert df == [Row(num1="1", num2="2")]
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py
@@ -1,10 +1,21 @@
+import tempfile
+
+from airline_demo.pipelines import local_parquet_io_manager
 from airline_demo.resources import DbInfo
 from airline_demo.solids import load_data_to_database_from_spark
-from dagster import ModeDefinition, ResourceDefinition, execute_pipeline, pipeline, solid
+from dagster import (
+    ModeDefinition,
+    OutputDefinition,
+    ResourceDefinition,
+    execute_pipeline,
+    fs_io_manager,
+    pipeline,
+    solid,
+)
 from dagster.core.definitions.no_step_launcher import no_step_launcher
 from dagster.seven import mock
+from dagster.utils import file_relative_path
 from dagster_pyspark import pyspark_resource
-from pyspark.sql import DataFrame
 
 
 def test_airline_demo_load_df():
@@ -18,9 +29,14 @@
         db_name="db_name",
     )
 
-    @solid
-    def emit_mock(_):
-        return mock.MagicMock(spec=DataFrame)
+    @solid(
+        required_resource_keys={"pyspark"},
+        output_defs=[OutputDefinition(io_manager_key="pyspark_io_manager")],
+    )
+    def emit_mock(context):
+        return context.resources.pyspark.spark_session.read.csv(
+            file_relative_path(__file__, "../data/test.csv")
+        )
 
     @pipeline(
         mode_defs=[
@@ -29,6 +45,8 @@
                 "db_info": ResourceDefinition.hardcoded_resource(db_info_mock),
                 "pyspark": pyspark_resource,
                 "pyspark_step_launcher": no_step_launcher,
+                "pyspark_io_manager": local_parquet_io_manager,
+                "io_manager": fs_io_manager,
             }
         )
     ]
@@ -36,18 +54,23 @@
     def load_df_test():
        load_data_to_database_from_spark(emit_mock())
 
-    solid_result = execute_pipeline(
-        load_df_test,
-        run_config={
-            "solids": {"load_data_to_database_from_spark": {"config": {"table_name": "foo"}}}
-        },
-    ).result_for_solid("load_data_to_database_from_spark")
+    with tempfile.TemporaryDirectory() as temp_dir:
+        solid_result = execute_pipeline(
+            load_df_test,
+            run_config={
+                "solids": {"load_data_to_database_from_spark": {"config": {"table_name": "foo"}}},
+                "resources": {
+                    "io_manager": {"config": {"base_dir": temp_dir}},
+                    "pyspark_io_manager": {"config": {"base_dir": temp_dir}},
+                },
+            },
+        ).result_for_solid("load_data_to_database_from_spark")
 
-    assert solid_result.success
-    mats = solid_result.materializations_during_compute
-    assert len(mats) == 1
-    mat = mats[0]
-    assert len(mat.metadata_entries) == 2
-    entries = {me.label: me for me in mat.metadata_entries}
-    assert entries["Host"].entry_data.text == "host"
-    assert entries["Db"].entry_data.text == "db_name"
+        assert solid_result.success
+        mats = solid_result.materializations_during_compute
+        assert len(mats) == 1
+        mat = mats[0]
+        assert len(mat.metadata_entries) == 2
+        entries = {me.label: me for me in mat.metadata_entries}
+        assert entries["Host"].entry_data.text == "host"
+        assert entries["Db"].entry_data.text == "db_name"
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py
@@ -64,7 +64,8 @@
         return s3_file_handle
 
     # Uses mock S3
-    s3 = boto3.client("s3")
+    # https://github.com/spulec/moto/issues/3292
+    s3 = boto3.client("s3", region_name="us-east-1")
     s3.create_bucket(Bucket="some-bucket")
     file_manager = S3FileManager(s3_session=s3, s3_bucket="some-bucket", s3_base_key="dagster")
 
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py
@@ -34,7 +34,8 @@
 
 @mock_s3
 def test_cache_file_from_s3_step_four(snapshot):
-    s3 = boto3.client("s3")
+    # https://github.com/spulec/moto/issues/3292
+    s3 = boto3.client("s3", region_name="us-east-1")
     s3.create_bucket(Bucket="source-bucket")
     s3.create_bucket(Bucket="file-cache-bucket")
     s3.put_object(Bucket="source-bucket", Key="source-file", Body=b"foo")
diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py
--- a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py
+++ b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py
@@ -47,7 +47,8 @@
 
 @mock_s3
 def test_cache_file_from_s3_step_three_fake(snapshot):
-    s3 = boto3.client("s3")
+    # https://github.com/spulec/moto/issues/3292
+    s3 = boto3.client("s3", region_name="us-east-1")
     s3.create_bucket(Bucket="some-bucket")
     s3.put_object(Bucket="some-bucket", Key="some-key", Body=b"foo")
 
diff --git a/examples/airline_demo/an_archive_member b/examples/airline_demo/an_archive_member
deleted file mode 100644
--- a/examples/airline_demo/an_archive_member
+++ /dev/null
@@ -1 +0,0 @@
-foo
\ No newline at end of file
diff --git a/examples/airline_demo/setup.py b/examples/airline_demo/setup.py
--- a/examples/airline_demo/setup.py
+++ b/examples/airline_demo/setup.py
@@ -36,7 +36,7 @@
         "lakehouse",
         "matplotlib",
         "mock",
-        "moto>=1.3.7",
+        "moto>=1.3.16",
         "pandas>=1.0.0",
         "pytest-mock",
         # Pyspark 2.x is incompatible with Python 3.8+