diff --git a/examples/airline_demo/airline_demo/pipelines.py b/examples/airline_demo/airline_demo/pipelines.py
index 7cb42f347..023ce83f0 100644
--- a/examples/airline_demo/airline_demo/pipelines.py
+++ b/examples/airline_demo/airline_demo/pipelines.py
@@ -1,194 +1,230 @@
 """Pipeline definitions for the airline_demo.
 """
-from dagster import ModeDefinition, PresetDefinition, composite_solid, pipeline
+import os
+
+from dagster import (
+    Field,
+    IOManager,
+    ModeDefinition,
+    Noneable,
+    PresetDefinition,
+    composite_solid,
+    fs_io_manager,
+    io_manager,
+    pipeline,
+)
 from dagster.core.definitions.no_step_launcher import no_step_launcher
 from dagster.core.storage.file_cache import fs_file_cache
 from dagster.core.storage.file_manager import local_file_manager
 from dagster.core.storage.temp_file_manager import tempfile_resource
 from dagster_aws.emr import emr_pyspark_step_launcher
 from dagster_aws.s3 import (
     S3FileHandle,
     file_handle_to_s3,
     s3_file_cache,
     s3_file_manager,
-    s3_plus_default_intermediate_storage_defs,
     s3_resource,
 )
 from dagster_pyspark import pyspark_resource
+from pyspark.sql import SparkSession

 from .cache_file_from_s3 import cache_file_from_s3
 from .resources import postgres_db_info_resource, redshift_db_info_resource
 from .solids import (
     average_sfo_outbound_avg_delays_by_destination,
     delays_by_geography,
     delays_vs_fares,
     delays_vs_fares_nb,
     eastbound_delays,
     ingest_csv_file_handle_to_spark,
     join_q2_data,
     load_data_to_database_from_spark,
     process_sfo_weather_data,
     q2_sfo_outbound_flights,
     s3_to_df,
     s3_to_dw_table,
     sfo_delays_by_destination,
     tickets_with_destination,
     westbound_delays,
 )

+
+class ParquetIOManager(IOManager):
+    def __init__(self, base_dir=None):
+        self._base_dir = base_dir or os.getcwd()
+
+    def _get_path(self, context):
+        return os.path.join(self._base_dir, context.run_id, context.step_key, context.name)
+
+    def handle_output(self, context, obj):
+        obj.write.parquet(self._get_path(context))
+
+    def load_input(self, context):
+        spark = SparkSession.builder.getOrCreate()
+        return spark.read.parquet(self._get_path(context.upstream_output))
+
+
+@io_manager(config_schema={"base_dir": Field(Noneable(str), is_required=False, default_value=None)})
+def local_parquet_io_manager(init_context):
+    return ParquetIOManager(base_dir=init_context.resource_config["base_dir"])
+
+
 # start_pipelines_marker_2
 test_mode = ModeDefinition(
     name="test",
     resource_defs={
         "pyspark_step_launcher": no_step_launcher,
         "pyspark": pyspark_resource,
         "db_info": redshift_db_info_resource,
         "tempfile": tempfile_resource,
         "s3": s3_resource,
         "file_cache": fs_file_cache,
         "file_manager": local_file_manager,
+        "io_manager": fs_io_manager,
+        "pyspark_io_manager": local_parquet_io_manager,
     },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
 )

 local_mode = ModeDefinition(
     name="local",
     resource_defs={
         "pyspark_step_launcher": no_step_launcher,
         "pyspark": pyspark_resource,
         "s3": s3_resource,
         "db_info": postgres_db_info_resource,
         "tempfile": tempfile_resource,
         "file_cache": fs_file_cache,
         "file_manager": local_file_manager,
+        "io_manager": fs_io_manager,
+        "pyspark_io_manager": local_parquet_io_manager,
     },
-    intermediate_storage_defs=s3_plus_default_intermediate_storage_defs,
 )

 prod_mode = ModeDefinition(
     name="prod",
     resource_defs={
         "pyspark_step_launcher": emr_pyspark_step_launcher,
         "pyspark": pyspark_resource,
         "s3": s3_resource,
         "db_info": redshift_db_info_resource,
         "tempfile": tempfile_resource,
         "file_cache": s3_file_cache,
         "file_manager": s3_file_manager,
+        "io_manager":
fs_io_manager, + "pyspark_io_manager": local_parquet_io_manager, }, - intermediate_storage_defs=s3_plus_default_intermediate_storage_defs, ) # end_pipelines_marker_2 # start_pipelines_marker_0 @pipeline( # ordered so the local is first and therefore the default mode_defs=[local_mode, test_mode, prod_mode], # end_pipelines_marker_0 preset_defs=[ PresetDefinition.from_pkg_resources( name="local_fast", mode="local", pkg_resource_defs=[ ("airline_demo.environments", "local_base.yaml"), ("airline_demo.environments", "local_fast_ingest.yaml"), ], ), PresetDefinition.from_pkg_resources( name="local_full", mode="local", pkg_resource_defs=[ ("airline_demo.environments", "local_base.yaml"), ("airline_demo.environments", "local_full_ingest.yaml"), ], ), PresetDefinition.from_pkg_resources( name="prod_fast", mode="prod", pkg_resource_defs=[ ("airline_demo.environments", "prod_base.yaml"), ("airline_demo.environments", "s3_storage.yaml"), ("airline_demo.environments", "local_fast_ingest.yaml"), ], ), ], ) def airline_demo_ingest_pipeline(): # on time data # start_airline_demo_ingest_pipeline load_data_to_database_from_spark.alias("load_q2_on_time_data")( data_frame=join_q2_data( april_data=s3_to_df.alias("april_on_time_s3_to_df")(), may_data=s3_to_df.alias("may_on_time_s3_to_df")(), june_data=s3_to_df.alias("june_on_time_s3_to_df")(), master_cord_data=s3_to_df.alias("master_cord_s3_to_df")(), ) ) # end_airline_demo_ingest_pipeline # load weather data load_data_to_database_from_spark.alias("load_q2_sfo_weather")( process_sfo_weather_data( ingest_csv_file_handle_to_spark.alias("ingest_q2_sfo_weather")( cache_file_from_s3.alias("download_q2_sfo_weather")() ) ) ) s3_to_dw_table.alias("process_q2_coupon_data")() s3_to_dw_table.alias("process_q2_market_data")() s3_to_dw_table.alias("process_q2_ticket_data")() def define_airline_demo_ingest_pipeline(): return airline_demo_ingest_pipeline @composite_solid def process_delays_by_geo() -> S3FileHandle: return file_handle_to_s3.alias("upload_delays_by_geography_pdf_plots")( delays_by_geography( westbound_delays=westbound_delays(), eastbound_delays=eastbound_delays() ) ) @pipeline( mode_defs=[test_mode, local_mode, prod_mode], preset_defs=[ PresetDefinition.from_pkg_resources( name="local", mode="local", pkg_resource_defs=[ ("airline_demo.environments", "local_base.yaml"), ("airline_demo.environments", "local_warehouse.yaml"), ], ) ], ) def airline_demo_warehouse_pipeline(): process_delays_by_geo() outbound_delays = average_sfo_outbound_avg_delays_by_destination(q2_sfo_outbound_flights()) file_handle_to_s3.alias("upload_delays_vs_fares_pdf_plots")( delays_vs_fares_nb.alias("fares_vs_delays")( delays_vs_fares( tickets_with_destination=tickets_with_destination(), average_sfo_outbound_avg_delays_by_destination=outbound_delays, ) ) ) file_handle_to_s3.alias("upload_outbound_avg_delay_pdf_plots")( sfo_delays_by_destination(outbound_delays) ) def define_airline_demo_warehouse_pipeline(): return airline_demo_warehouse_pipeline diff --git a/examples/airline_demo/airline_demo/solids.py b/examples/airline_demo/airline_demo/solids.py index 2597c86dd..486782def 100644 --- a/examples/airline_demo/airline_demo/solids.py +++ b/examples/airline_demo/airline_demo/solids.py @@ -1,593 +1,604 @@ """A fully fleshed out demo dagster repository with many configurable options.""" import os import re import dagster_pyspark from dagster import ( AssetMaterialization, EventMetadataEntry, ExpectationResult, Field, FileHandle, InputDefinition, Int, Output, OutputDefinition, String, check, 
composite_solid, make_python_type_usable_as_dagster_type, solid, ) from dagster.core.types.dagster_type import create_string_type from dagster_aws.s3 import S3Coordinate from dagstermill import define_dagstermill_solid from pyspark.sql import DataFrame from sqlalchemy import text from .cache_file_from_s3 import cache_file_from_s3 from .unzip_file_handle import unzip_file_handle SqlTableName = create_string_type("SqlTableName", description="The name of a database table") # Make pyspark.sql.DataFrame map to dagster_pyspark.DataFrame make_python_type_usable_as_dagster_type( python_type=DataFrame, dagster_type=dagster_pyspark.DataFrame ) PARQUET_SPECIAL_CHARACTERS = r"[ ,;{}()\n\t=]" def _notebook_path(name): return os.path.join(os.path.dirname(os.path.abspath(__file__)), "notebooks", name) # start_solids_marker_3 def notebook_solid(name, notebook_path, input_defs, output_defs, required_resource_keys): return define_dagstermill_solid( name, _notebook_path(notebook_path), input_defs, output_defs, required_resource_keys=required_resource_keys, ) # end_solids_marker_3 # need a sql context w a sqlalchemy engine def sql_solid(name, select_statement, materialization_strategy, table_name=None, input_defs=None): """Return a new solid that executes and materializes a SQL select statement. Args: name (str): The name of the new solid. select_statement (str): The select statement to execute. materialization_strategy (str): Must be 'table', the only currently supported materialization strategy. If 'table', the kwarg `table_name` must also be passed. Kwargs: table_name (str): THe name of the new table to create, if the materialization strategy is 'table'. Default: None. input_defs (list[InputDefinition]): Inputs, if any, for the new solid. Default: None. Returns: function: The new SQL solid. """ input_defs = check.opt_list_param(input_defs, "input_defs", InputDefinition) materialization_strategy_output_types = { # pylint:disable=C0103 "table": SqlTableName, # 'view': String, # 'query': SqlAlchemyQueryType, # 'subquery': SqlAlchemySubqueryType, # 'result_proxy': SqlAlchemyResultProxyType, # could also materialize as a Pandas table, as a Spark table, as an intermediate file, etc. } if materialization_strategy not in materialization_strategy_output_types: raise Exception( "Invalid materialization strategy {materialization_strategy}, must " "be one of {materialization_strategies}".format( materialization_strategy=materialization_strategy, materialization_strategies=str(list(materialization_strategy_output_types.keys())), ) ) if materialization_strategy == "table": if table_name is None: raise Exception("Missing table_name: required for materialization strategy 'table'") output_description = ( "The string name of the new table created by the solid" if materialization_strategy == "table" else "The materialized SQL statement. If the materialization_strategy is " "'table', this is the string name of the new table created by the solid." 
     )

     description = """This solid executes the following SQL statement:
 {select_statement}""".format(
         select_statement=select_statement
     )

     # n.b., we will eventually want to make this resources key configurable
     sql_statement = (
         "drop table if exists {table_name};\n"
         "create table {table_name} as {select_statement};"
     ).format(table_name=table_name, select_statement=select_statement)

     # start_solids_marker_1
     @solid(
         name=name,
         input_defs=input_defs,
         output_defs=[
             OutputDefinition(
                 materialization_strategy_output_types[materialization_strategy],
                 description=output_description,
             )
         ],
         description=description,
         required_resource_keys={"db_info"},
         tags={"kind": "sql", "sql": sql_statement},
     )
     def _sql_solid(context, **input_defs):  # pylint: disable=unused-argument
         """Inner function defining the new solid.

         Args:
             context (SolidExecutionContext): Must expose a `db` resource with an `execute`
                 method, like a SQLAlchemy engine, that can execute raw SQL against a database.

         Returns:
             str: The table name of the newly materialized SQL select statement.
         """
         context.log.info(
             "Executing sql statement:\n{sql_statement}".format(sql_statement=sql_statement)
         )
         context.resources.db_info.engine.execute(text(sql_statement))
         yield Output(value=table_name, output_name="result")

     # end_solids_marker_1
     return _sql_solid


 @solid(
-    required_resource_keys={"pyspark_step_launcher", "pyspark", "file_manager"},
-    description="""Take a file handle that contains a csv with headers and load it
-into a Spark DataFrame. It infers header names but does *not* infer schema.
-
-It also ensures that the column names are valid parquet column names by
-filtering out any of the following characters from column names:
-
-Characters (within quotations): "`{chars}`"
-
-""".format(
-        chars=PARQUET_SPECIAL_CHARACTERS
+    required_resource_keys={"pyspark_step_launcher", "pyspark", "file_manager"},
+    description=(
+        "Take a file handle that contains a csv with headers and load it "
+        "into a Spark DataFrame. It infers header names but does *not* infer schema.\n\n"
+        "It also ensures that the column names are valid parquet column names by "
+        "filtering out any of the following characters from column names:\n\n"
+        f"Characters (within quotations): {PARQUET_SPECIAL_CHARACTERS}"
     ),
+    output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")],
 )
-def ingest_csv_file_handle_to_spark(context, csv_file_handle: FileHandle) -> DataFrame:
+def ingest_csv_file_handle_to_spark(context, csv_file_handle):
     # fs case: copies from file manager location into system temp
     # - This is potentially an unnecessary copy. We could potentially specialize
     #   the implementation of copy_handle_to_local_temp to not to do this in the
     #   local fs case. Somewhat more dangerous though.
     # s3 case: downloads from s3 to local temp directory
     temp_file_name = context.resources.file_manager.copy_handle_to_local_temp(csv_file_handle)

     # In fact for a generic component this should really be using
     # the spark APIs to load directly from whatever object store, rather
     # than using any interleaving temp files.
data_frame = ( context.resources.pyspark.spark_session.read.format("csv") .options( header="true", # inferSchema='true', ) .load(temp_file_name) ) # parquet compat return rename_spark_dataframe_columns( data_frame, lambda x: re.sub(PARQUET_SPECIAL_CHARACTERS, "", x) ) def rename_spark_dataframe_columns(data_frame, fn): return data_frame.toDF(*[fn(c) for c in data_frame.columns]) def do_prefix_column_names(df, prefix): check.inst_param(df, "df", DataFrame) check.str_param(prefix, "prefix") return rename_spark_dataframe_columns(df, lambda c: "{prefix}{c}".format(prefix=prefix, c=c)) -@solid(required_resource_keys={"pyspark_step_launcher"}) -def canonicalize_column_names(_context, data_frame: DataFrame) -> DataFrame: +@solid( + required_resource_keys={"pyspark_step_launcher"}, + input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)], + output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")], +) +def canonicalize_column_names(_context, data_frame): return rename_spark_dataframe_columns(data_frame, lambda c: c.lower()) def replace_values_spark(data_frame, old, new): return data_frame.na.replace(old, new) -@solid(required_resource_keys={"pyspark_step_launcher"}) -def process_sfo_weather_data(_context, sfo_weather_data: DataFrame) -> DataFrame: +@solid( + required_resource_keys={"pyspark_step_launcher"}, + input_defs=[InputDefinition(name="sfo_weather_data", dagster_type=DataFrame)], + output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")], +) +def process_sfo_weather_data(_context, sfo_weather_data): normalized_sfo_weather_data = replace_values_spark(sfo_weather_data, "M", None) return rename_spark_dataframe_columns(normalized_sfo_weather_data, lambda c: c.lower()) # start_solids_marker_0 @solid( + input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)], output_defs=[OutputDefinition(name="table_name", dagster_type=String)], config_schema={"table_name": String}, required_resource_keys={"db_info", "pyspark_step_launcher"}, ) -def load_data_to_database_from_spark(context, data_frame: DataFrame): +def load_data_to_database_from_spark(context, data_frame): context.resources.db_info.load_table(data_frame, context.solid_config["table_name"]) table_name = context.solid_config["table_name"] yield AssetMaterialization( asset_key="table:{table_name}".format(table_name=table_name), description=( "Persisted table {table_name} in database configured in the db_info resource." 
).format(table_name=table_name), metadata_entries=[ EventMetadataEntry.text(label="Host", text=context.resources.db_info.host), EventMetadataEntry.text(label="Db", text=context.resources.db_info.db_name), ], ) yield Output(value=table_name, output_name="table_name") # end_solids_marker_0 @solid( required_resource_keys={"pyspark_step_launcher"}, description="Subsample a spark dataset via the configuration option.", config_schema={ "subsample_pct": Field( Int, description="The integer percentage of rows to sample from the input dataset.", ) }, + input_defs=[InputDefinition(name="data_frame", dagster_type=DataFrame)], + output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")], ) -def subsample_spark_dataset(context, data_frame: DataFrame) -> DataFrame: +def subsample_spark_dataset(context, data_frame): return data_frame.sample( withReplacement=False, fraction=context.solid_config["subsample_pct"] / 100.0 ) @composite_solid( - description="""Ingest a zipped csv file from s3, -stash in a keyed file store (does not download if already -present by default), unzip that file, and load it into a -Spark Dataframe. See documentation in constituent solids for -more detail.""" + description=( + "Ingest a zipped csv file from s3, stash in a keyed file store (does not download if " + "already present by default), unzip that file, and load it into a Spark Dataframe. See " + "documentation in constituent solids for more detail." + ), ) def s3_to_df(s3_coordinate: S3Coordinate, archive_member: String) -> DataFrame: return ingest_csv_file_handle_to_spark( unzip_file_handle(cache_file_from_s3(s3_coordinate), archive_member) ) @composite_solid( config_fn=lambda cfg: { "subsample_spark_dataset": {"config": {"subsample_pct": cfg["subsample_pct"]}}, "load_data_to_database_from_spark": {"config": {"table_name": cfg["table_name"]}}, }, config_schema={"subsample_pct": int, "table_name": str}, - description="""Ingest zipped csv file from s3, load into a Spark -DataFrame, optionally subsample it (via configuring the -subsample_spark_dataset, solid), canonicalize the column names, and then -load it into a data warehouse. -""", + description=( + "Ingest zipped csv file from s3, load into a Spark DataFrame, optionally subsample it " + "(via configuring the subsample_spark_dataset, solid), canonicalize the column names, " + "and then load it into a data warehouse." 
+ ), ) def s3_to_dw_table(s3_coordinate: S3Coordinate, archive_member: String) -> String: return load_data_to_database_from_spark( canonicalize_column_names(subsample_spark_dataset(s3_to_df(s3_coordinate, archive_member))) ) q2_sfo_outbound_flights = sql_solid( "q2_sfo_outbound_flights", """ select * from q2_on_time_data where origin = 'SFO' """, "table", table_name="q2_sfo_outbound_flights", ) average_sfo_outbound_avg_delays_by_destination = sql_solid( "average_sfo_outbound_avg_delays_by_destination", """ select cast(cast(arrdelay as float) as integer) as arrival_delay, cast(cast(depdelay as float) as integer) as departure_delay, origin, dest as destination from q2_sfo_outbound_flights """, "table", table_name="average_sfo_outbound_avg_delays_by_destination", input_defs=[InputDefinition("q2_sfo_outbound_flights", dagster_type=SqlTableName)], ) ticket_prices_with_average_delays = sql_solid( "tickets_with_destination", """ select tickets.*, coupons.dest, coupons.destairportid, coupons.destairportseqid, coupons.destcitymarketid, coupons.destcountry, coupons.deststatefips, coupons.deststate, coupons.deststatename, coupons.destwac from q2_ticket_data as tickets, q2_coupon_data as coupons where tickets.itinid = coupons.itinid; """, "table", table_name="tickets_with_destination", ) tickets_with_destination = sql_solid( "tickets_with_destination", """ select tickets.*, coupons.dest, coupons.destairportid, coupons.destairportseqid, coupons.destcitymarketid, coupons.destcountry, coupons.deststatefips, coupons.deststate, coupons.deststatename, coupons.destwac from q2_ticket_data as tickets, q2_coupon_data as coupons where tickets.itinid = coupons.itinid; """, "table", table_name="tickets_with_destination", ) delays_vs_fares = sql_solid( "delays_vs_fares", """ with avg_fares as ( select tickets.origin, tickets.dest, avg(cast(tickets.itinfare as float)) as avg_fare, avg(cast(tickets.farepermile as float)) as avg_fare_per_mile from tickets_with_destination as tickets where origin = 'SFO' group by (tickets.origin, tickets.dest) ) select avg_fares.*, avg(avg_delays.arrival_delay) as avg_arrival_delay, avg(avg_delays.departure_delay) as avg_departure_delay from avg_fares, average_sfo_outbound_avg_delays_by_destination as avg_delays where avg_fares.origin = avg_delays.origin and avg_fares.dest = avg_delays.destination group by ( avg_fares.avg_fare, avg_fares.avg_fare_per_mile, avg_fares.origin, avg_delays.origin, avg_fares.dest, avg_delays.destination ) """, "table", table_name="delays_vs_fares", input_defs=[ InputDefinition("tickets_with_destination", SqlTableName), InputDefinition("average_sfo_outbound_avg_delays_by_destination", SqlTableName), ], ) eastbound_delays = sql_solid( "eastbound_delays", """ select avg(cast(cast(arrdelay as float) as integer)) as avg_arrival_delay, avg(cast(cast(depdelay as float) as integer)) as avg_departure_delay, origin, dest as destination, count(1) as num_flights, avg(cast(dest_latitude as float)) as dest_latitude, avg(cast(dest_longitude as float)) as dest_longitude, avg(cast(origin_latitude as float)) as origin_latitude, avg(cast(origin_longitude as float)) as origin_longitude from q2_on_time_data where cast(origin_longitude as float) < cast(dest_longitude as float) and originstate != 'HI' and deststate != 'HI' and originstate != 'AK' and deststate != 'AK' group by (origin,destination) order by num_flights desc limit 100; """, "table", table_name="eastbound_delays", ) # start_solids_marker_2 westbound_delays = sql_solid( "westbound_delays", """ select 
avg(cast(cast(arrdelay as float) as integer)) as avg_arrival_delay, avg(cast(cast(depdelay as float) as integer)) as avg_departure_delay, origin, dest as destination, count(1) as num_flights, avg(cast(dest_latitude as float)) as dest_latitude, avg(cast(dest_longitude as float)) as dest_longitude, avg(cast(origin_latitude as float)) as origin_latitude, avg(cast(origin_longitude as float)) as origin_longitude from q2_on_time_data where cast(origin_longitude as float) > cast(dest_longitude as float) and originstate != 'HI' and deststate != 'HI' and originstate != 'AK' and deststate != 'AK' group by (origin,destination) order by num_flights desc limit 100; """, "table", table_name="westbound_delays", ) # end_solids_marker_2 # start_solids_marker_4 delays_by_geography = notebook_solid( "delays_by_geography", "Delays_by_Geography.ipynb", input_defs=[ InputDefinition( "westbound_delays", SqlTableName, description="The SQL table containing westbound delays.", ), InputDefinition( "eastbound_delays", SqlTableName, description="The SQL table containing eastbound delays.", ), ], output_defs=[ OutputDefinition( dagster_type=FileHandle, # name='plots_pdf_path', description="The saved PDF plots.", ) ], required_resource_keys={"db_info"}, ) # end_solids_marker_4 delays_vs_fares_nb = notebook_solid( "fares_vs_delays", "Fares_vs_Delays.ipynb", input_defs=[ InputDefinition( "table_name", SqlTableName, description="The SQL table to use for calcuations." ) ], output_defs=[ OutputDefinition( dagster_type=FileHandle, # name='plots_pdf_path', description="The path to the saved PDF plots.", ) ], required_resource_keys={"db_info"}, ) sfo_delays_by_destination = notebook_solid( "sfo_delays_by_destination", "SFO_Delays_by_Destination.ipynb", input_defs=[ InputDefinition( "table_name", SqlTableName, description="The SQL table to use for calcuations." ) ], output_defs=[ OutputDefinition( dagster_type=FileHandle, # name='plots_pdf_path', description="The path to the saved PDF plots.", ) ], required_resource_keys={"db_info"}, ) @solid( required_resource_keys={"pyspark_step_launcher", "pyspark"}, config_schema={"subsample_pct": Int}, - description=""" - This solid takes April, May, and June data and coalesces it into a q2 data set. - It then joins the that origin and destination airport with the data in the - master_cord_data. - """, + description=( + "This solid takes April, May, and June data and coalesces it into a Q2 data set. " + "It then joins the that origin and destination airport with the data in the " + "master_cord_data." 
+ ), + input_defs=[ + InputDefinition(name="april_data", dagster_type=DataFrame), + InputDefinition(name="may_data", dagster_type=DataFrame), + InputDefinition(name="june_data", dagster_type=DataFrame), + InputDefinition(name="master_cord_data", dagster_type=DataFrame), + ], + output_defs=[OutputDefinition(DataFrame, io_manager_key="pyspark_io_manager")], ) def join_q2_data( - context, - april_data: DataFrame, - may_data: DataFrame, - june_data: DataFrame, - master_cord_data: DataFrame, -) -> DataFrame: + context, april_data, may_data, june_data, master_cord_data, +): dfs = {"april": april_data, "may": may_data, "june": june_data} missing_things = [] for required_column in ["DestAirportSeqID", "OriginAirportSeqID"]: for month, df in dfs.items(): if required_column not in df.columns: missing_things.append({"month": month, "missing_column": required_column}) yield ExpectationResult( success=not bool(missing_things), label="airport_ids_present", description="Sequence IDs present in incoming monthly flight data.", metadata_entries=[ EventMetadataEntry.json(label="metadata", data={"missing_columns": missing_things}) ], ) yield ExpectationResult( success=set(april_data.columns) == set(may_data.columns) == set(june_data.columns), label="flight_data_same_shape", metadata_entries=[ EventMetadataEntry.json(label="metadata", data={"columns": april_data.columns}) ], ) q2_data = april_data.union(may_data).union(june_data) sampled_q2_data = q2_data.sample( withReplacement=False, fraction=context.solid_config["subsample_pct"] / 100.0 ) sampled_q2_data.createOrReplaceTempView("q2_data") dest_prefixed_master_cord_data = do_prefix_column_names(master_cord_data, "DEST_") dest_prefixed_master_cord_data.createOrReplaceTempView("dest_cord_data") origin_prefixed_master_cord_data = do_prefix_column_names(master_cord_data, "ORIGIN_") origin_prefixed_master_cord_data.createOrReplaceTempView("origin_cord_data") full_data = context.resources.pyspark.spark_session.sql( """ SELECT * FROM origin_cord_data LEFT JOIN ( SELECT * FROM q2_data LEFT JOIN dest_cord_data ON q2_data.DestAirportSeqID = dest_cord_data.DEST_AIRPORT_SEQ_ID ) q2_dest_data ON origin_cord_data.ORIGIN_AIRPORT_SEQ_ID = q2_dest_data.OriginAirportSeqID """ ) yield Output(rename_spark_dataframe_columns(full_data, lambda c: c.lower())) diff --git a/examples/airline_demo/airline_demo_tests/test_pipelines.py b/examples/airline_demo/airline_demo_tests/test_pipelines.py index ee89ca9b0..23d1577b6 100644 --- a/examples/airline_demo/airline_demo_tests/test_pipelines.py +++ b/examples/airline_demo/airline_demo_tests/test_pipelines.py @@ -1,82 +1,72 @@ import os +import tempfile # pylint: disable=unused-argument import pytest from dagster import execute_pipeline, file_relative_path from dagster.core.definitions.reconstructable import ReconstructablePipeline from dagster.core.test_utils import instance_for_test from dagster.utils import load_yaml_from_globs ingest_pipeline = ReconstructablePipeline.for_module( "airline_demo.pipelines", "define_airline_demo_ingest_pipeline", ) warehouse_pipeline = ReconstructablePipeline.for_module( "airline_demo.pipelines", "define_airline_demo_warehouse_pipeline", ) def config_path(relative_path): return file_relative_path( __file__, os.path.join("../airline_demo/environments/", relative_path) ) @pytest.mark.db @pytest.mark.nettest @pytest.mark.py3 @pytest.mark.spark def test_ingest_pipeline_fast(postgres, pg_hostname): - with instance_for_test() as instance: - ingest_config_dict = load_yaml_from_globs( - 
config_path("test_base.yaml"), config_path("local_fast_ingest.yaml") - ) - result_ingest = execute_pipeline( - pipeline=ingest_pipeline, - mode="local", - run_config=ingest_config_dict, - instance=instance, - ) + with tempfile.TemporaryDirectory() as temp_dir: + with instance_for_test() as instance: + ingest_config_dict = load_yaml_from_globs( + config_path("test_base.yaml"), config_path("local_fast_ingest.yaml"), + ) + ingest_config_dict["resources"]["io_manager"] = {"config": {"base_dir": temp_dir}} + ingest_config_dict["resources"]["pyspark_io_manager"] = { + "config": {"base_dir": temp_dir} + } + result_ingest = execute_pipeline( + pipeline=ingest_pipeline, + mode="local", + run_config=ingest_config_dict, + instance=instance, + ) - assert result_ingest.success - - -@pytest.mark.db -@pytest.mark.nettest -@pytest.mark.py3 -@pytest.mark.spark -def test_ingest_pipeline_fast_filesystem_storage(postgres, pg_hostname): - with instance_for_test() as instance: - ingest_config_dict = load_yaml_from_globs( - config_path("test_base.yaml"), - config_path("local_fast_ingest.yaml"), - config_path("filesystem_storage.yaml"), - ) - result_ingest = execute_pipeline( - pipeline=ingest_pipeline, - mode="local", - run_config=ingest_config_dict, - instance=instance, - ) - - assert result_ingest.success + assert result_ingest.success @pytest.mark.db @pytest.mark.nettest @pytest.mark.py3 @pytest.mark.spark @pytest.mark.skipif('"win" in sys.platform', reason="avoiding the geopandas tests") def test_airline_pipeline_1_warehouse(postgres, pg_hostname): - with instance_for_test() as instance: + with tempfile.TemporaryDirectory() as temp_dir: + with instance_for_test() as instance: - warehouse_config_object = load_yaml_from_globs( - config_path("test_base.yaml"), config_path("local_warehouse.yaml") - ) - result_warehouse = execute_pipeline( - pipeline=warehouse_pipeline, - mode="local", - run_config=warehouse_config_object, - instance=instance, - ) - assert result_warehouse.success + warehouse_config_object = load_yaml_from_globs( + config_path("test_base.yaml"), config_path("local_warehouse.yaml") + ) + warehouse_config_object["resources"]["io_manager"] = {"config": {"base_dir": temp_dir}} + warehouse_config_object["resources"]["pyspark_io_manager"] = { + "config": {"base_dir": temp_dir} + } + result_warehouse = execute_pipeline( + pipeline=warehouse_pipeline, + mode="local", + run_config=warehouse_config_object, + instance=instance, + ) + assert result_warehouse.success diff --git a/examples/airline_demo/airline_demo_tests/test_solids.py b/examples/airline_demo/airline_demo_tests/test_solids.py index f322b6c1b..701f0cbd5 100644 --- a/examples/airline_demo/airline_demo_tests/test_solids.py +++ b/examples/airline_demo/airline_demo_tests/test_solids.py @@ -1,60 +1,32 @@ """Unit and pipeline tests for the airline_demo. As is common in real-world pipelines, we want to test some fairly heavy-weight operations, requiring, e.g., a connection to S3, Spark, and a database. We lever pytest marks to isolate subsets of tests with different requirements. E.g., to run only those tests that don't require Spark, `pytest -m "not spark"`. 
""" import pytest from airline_demo.solids import sql_solid from dagster import ModeDefinition from dagster.core.storage.temp_file_manager import tempfile_resource tempfile_mode = ModeDefinition(name="tempfile", resource_defs={"tempfile": tempfile_resource}) def test_sql_solid_with_bad_materialization_strategy(): with pytest.raises(Exception) as excinfo: sql_solid("foo", "select * from bar", "view") assert str(excinfo.value) == "Invalid materialization strategy view, must be one of ['table']" def test_sql_solid_without_table_name(): with pytest.raises(Exception) as excinfo: sql_solid("foo", "select * from bar", "table") assert str(excinfo.value) == "Missing table_name: required for materialization strategy 'table'" def test_sql_solid(): - result = sql_solid("foo", "select * from bar", "table", "quux") - assert result - # TODO: test execution? - - -@pytest.mark.postgres -@pytest.mark.skip -@pytest.mark.spark -def test_load_data_to_postgres_from_spark_postgres(): - raise NotImplementedError() - - -@pytest.mark.nettest -@pytest.mark.redshift -@pytest.mark.skip -@pytest.mark.spark -def test_load_data_to_redshift_from_spark(): - raise NotImplementedError() - - -@pytest.mark.skip -@pytest.mark.spark -def test_subsample_spark_dataset(): - raise NotImplementedError() - - -@pytest.mark.skip -@pytest.mark.spark -def test_join_spark_data_frame(): - raise NotImplementedError() + solid = sql_solid("foo", "select * from bar", "table", "quux") + assert solid diff --git a/examples/airline_demo/airline_demo_tests/test_types.py b/examples/airline_demo/airline_demo_tests/test_types.py deleted file mode 100644 index 97048fe0d..000000000 --- a/examples/airline_demo/airline_demo_tests/test_types.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -from tempfile import TemporaryDirectory - -import botocore -import pyspark -from airline_demo.solids import ingest_csv_file_handle_to_spark -from dagster import ( - InputDefinition, - LocalFileHandle, - ModeDefinition, - OutputDefinition, - execute_pipeline, - file_relative_path, - local_file_manager, - pipeline, - solid, -) -from dagster.core.definitions.no_step_launcher import no_step_launcher -from dagster.core.instance import DagsterInstance -from dagster.core.storage.intermediate_storage import build_fs_intermediate_storage -from dagster.core.storage.temp_file_manager import tempfile_resource -from dagster_aws.s3 import s3_file_manager, s3_plus_default_intermediate_storage_defs, s3_resource -from dagster_aws.s3.intermediate_storage import S3IntermediateStorage -from dagster_pyspark import DataFrame, pyspark_resource -from pyspark.sql import Row, SparkSession - -spark_local_fs_mode = ModeDefinition( - name="spark", - resource_defs={ - "pyspark": pyspark_resource, - "tempfile": tempfile_resource, - "s3": s3_resource, - "pyspark_step_launcher": no_step_launcher, - "file_manager": local_file_manager, - }, - intermediate_storage_defs=s3_plus_default_intermediate_storage_defs, -) - -spark_s3_mode = ModeDefinition( - name="spark", - resource_defs={ - "pyspark": pyspark_resource, - "tempfile": tempfile_resource, - "s3": s3_resource, - "pyspark_step_launcher": no_step_launcher, - "file_manager": s3_file_manager, - }, - intermediate_storage_defs=s3_plus_default_intermediate_storage_defs, -) - - -def test_spark_data_frame_serialization_file_system_file_handle(spark_config): - @solid - def nonce(_): - return LocalFileHandle(file_relative_path(__file__, "data/test.csv")) - - @pipeline(mode_defs=[spark_local_fs_mode]) - def spark_df_test_pipeline(): - 
ingest_csv_file_handle_to_spark(nonce()) - - instance = DagsterInstance.ephemeral() - - result = execute_pipeline( - spark_df_test_pipeline, - mode="spark", - run_config={ - "intermediate_storage": {"filesystem": {}}, - "resources": {"pyspark": {"config": {"spark_conf": spark_config}}}, - }, - instance=instance, - ) - - intermediate_storage = build_fs_intermediate_storage( - instance.intermediates_directory, run_id=result.run_id - ) - - assert result.success - result_dir = os.path.join( - intermediate_storage.root, "intermediates", "ingest_csv_file_handle_to_spark", "result", - ) - - assert "_SUCCESS" in os.listdir(result_dir) - - spark = SparkSession.builder.getOrCreate() - df = spark.read.parquet(result_dir) - assert isinstance(df, pyspark.sql.dataframe.DataFrame) - assert df.head()[0] == "1" - - -def test_spark_data_frame_serialization_s3_file_handle(s3_bucket, spark_config): - @solid(required_resource_keys={"file_manager"}) - def nonce(context): - with open(os.path.join(os.path.dirname(__file__), "data/test.csv"), "rb") as fd: - return context.resources.file_manager.write_data(fd.read()) - - @pipeline(mode_defs=[spark_s3_mode]) - def spark_df_test_pipeline(): - - ingest_csv_file_handle_to_spark(nonce()) - - result = execute_pipeline( - spark_df_test_pipeline, - run_config={ - "intermediate_storage": {"s3": {"config": {"s3_bucket": s3_bucket}}}, - "resources": { - "pyspark": {"config": {"spark_conf": spark_config}}, - "file_manager": {"config": {"s3_bucket": s3_bucket}}, - }, - }, - mode="spark", - ) - - assert result.success - - intermediate_storage = S3IntermediateStorage(s3_bucket=s3_bucket, run_id=result.run_id) - - success_key = "/".join( - [ - intermediate_storage.root.strip("/"), - "intermediates", - "ingest_csv_file_handle_to_spark", - "result", - "_SUCCESS", - ] - ) - try: - assert intermediate_storage.object_store.s3.get_object( - Bucket=intermediate_storage.object_store.bucket, Key=success_key - ) - except botocore.exceptions.ClientError: - raise Exception("Couldn't find object at {success_key}".format(success_key=success_key)) - - -def test_spark_dataframe_output_csv(spark_config): - spark = SparkSession.builder.getOrCreate() - num_df = ( - spark.read.format("csv") - .options(header="true", inferSchema="true") - .load(file_relative_path(__file__, "num.csv")) - ) - - assert num_df.collect() == [Row(num1=1, num2=2)] - - @solid - def emit(_): - return num_df - - @solid(input_defs=[InputDefinition("df", DataFrame)], output_defs=[OutputDefinition(DataFrame)]) - def passthrough_df(_context, df): - return df - - @pipeline(mode_defs=[spark_local_fs_mode]) - def passthrough(): - passthrough_df(emit()) - - with TemporaryDirectory() as tempdir: - file_name = os.path.join(tempdir, "output.csv") - result = execute_pipeline( - passthrough, - run_config={ - "resources": {"pyspark": {"config": {"spark_conf": spark_config}}}, - "solids": { - "passthrough_df": { - "outputs": [{"result": {"csv": {"path": file_name, "header": True}}}] - } - }, - }, - ) - - from_file_df = ( - spark.read.format("csv").options(header="true", inferSchema="true").load(file_name) - ) - - assert ( - result.result_for_solid("passthrough_df").output_value().collect() - == from_file_df.collect() - ) diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py index 00457df29..c7a5dcb3d 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py +++ 
b/examples/airline_demo/airline_demo_tests/unit_tests/test_cache_file_from_s3.py @@ -1,211 +1,211 @@ import os +import tempfile import pytest from airline_demo.cache_file_from_s3 import cache_file_from_s3 from dagster import ( DagsterInvalidDefinitionError, ModeDefinition, ResourceDefinition, execute_pipeline, execute_solid, pipeline, ) from dagster.core.storage.file_cache import LocalFileHandle, fs_file_cache from dagster.seven import mock -from dagster.utils.temp_file import get_temp_dir def execute_solid_with_resources(solid_def, resource_defs, run_config): @pipeline( name="{}_solid_test".format(solid_def.name), mode_defs=[ModeDefinition(resource_defs=resource_defs)], ) def test_pipeline(): return solid_def() return execute_pipeline(test_pipeline, run_config) def test_cache_file_from_s3_basic(): s3_session = mock.MagicMock() - with get_temp_dir() as temp_dir: + with tempfile.TemporaryDirectory() as temp_dir: solid_result = execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": ResourceDefinition.hardcoded_resource(s3_session), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}} } }, "resources": {"file_cache": {"config": {"target_folder": temp_dir}}}, }, ) # assert the download occurred assert s3_session.download_file.call_count == 1 assert solid_result.success expectation_results = solid_result.expectation_results_during_compute assert len(expectation_results) == 1 expectation_result = expectation_results[0] assert expectation_result.success assert expectation_result.label == "file_handle_exists" path_in_metadata = expectation_result.metadata_entries[0].entry_data.path assert isinstance(path_in_metadata, str) assert os.path.exists(path_in_metadata) assert isinstance(solid_result.output_value(), LocalFileHandle) assert "some-key" in solid_result.output_value().path_desc def test_cache_file_from_s3_specify_target_key(): s3_session = mock.MagicMock() - with get_temp_dir() as temp_dir: + with tempfile.TemporaryDirectory() as temp_dir: solid_result = execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": ResourceDefinition.hardcoded_resource(s3_session), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}}, "config": {"file_key": "specified-file-key"}, } }, "resources": {"file_cache": {"config": {"target_folder": temp_dir}}}, }, ) # assert the download occurred assert s3_session.download_file.call_count == 1 assert solid_result.success assert isinstance(solid_result.output_value(), LocalFileHandle) assert "specified-file-key" in solid_result.output_value().path_desc def test_cache_file_from_s3_skip_download(): - with get_temp_dir() as temp_dir: + with tempfile.TemporaryDirectory() as temp_dir: s3_session_one = mock.MagicMock() execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": ResourceDefinition.hardcoded_resource(s3_session_one), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}} } }, "resources": {"file_cache": {"config": {"target_folder": temp_dir}}}, }, ) # assert the download occurred assert s3_session_one.download_file.call_count == 1 s3_session_two = mock.MagicMock() execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": 
ResourceDefinition.hardcoded_resource(s3_session_two), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}} } }, "resources": {"file_cache": {"config": {"target_folder": temp_dir}}}, }, ) # assert the download did not occur because file is already there assert s3_session_two.download_file.call_count == 0 def test_cache_file_from_s3_overwrite(): - with get_temp_dir() as temp_dir: + with tempfile.TemporaryDirectory() as temp_dir: s3_session_one = mock.MagicMock() execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": ResourceDefinition.hardcoded_resource(s3_session_one), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}} } }, "resources": { "file_cache": {"config": {"target_folder": temp_dir, "overwrite": True}} }, }, ) # assert the download occurred assert s3_session_one.download_file.call_count == 1 s3_session_two = mock.MagicMock() execute_solid( cache_file_from_s3, ModeDefinition( resource_defs={ "file_cache": fs_file_cache, "s3": ResourceDefinition.hardcoded_resource(s3_session_two), } ), run_config={ "solids": { "cache_file_from_s3": { "inputs": {"s3_coordinate": {"bucket": "some-bucket", "key": "some-key"}} } }, "resources": { "file_cache": {"config": {"target_folder": temp_dir, "overwrite": True}} }, }, ) # assert the download did not occur because file is already there assert s3_session_two.download_file.call_count == 0 def test_missing_resources(): with pytest.raises(DagsterInvalidDefinitionError): - with get_temp_dir() as temp_dir: + with tempfile.TemporaryDirectory() as temp_dir: execute_solid( cache_file_from_s3, ModeDefinition(resource_defs={"file_cache": fs_file_cache}), run_config={ "solids": { "cache_file_from_s3": { "inputs": { "s3_coordinate": {"bucket": "some-bucket", "key": "some-key"} } } }, "resources": {"file_cache": {"config": {"target_folder": temp_dir}}}, }, ) diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py index 7926c8535..464023346 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_ingest_csv_file_handle_to_spark.py @@ -1,80 +1,102 @@ +import tempfile + +from airline_demo.pipelines import local_parquet_io_manager from airline_demo.solids import ingest_csv_file_handle_to_spark from dagster import ( LocalFileHandle, ModeDefinition, execute_pipeline, + fs_io_manager, local_file_manager, pipeline, solid, ) from dagster.core.definitions.no_step_launcher import no_step_launcher from dagster.utils import file_relative_path from dagster_pyspark import pyspark_resource from pyspark.sql import Row @solid def collect_df(_, df): """The pyspark Spark context will be stopped on pipeline termination, so we need to collect the pyspark DataFrame before pipeline completion. 
""" return df.collect() def test_ingest_csv_file_handle_to_spark(spark_config): @solid def emit_num_csv_local_file(_): return LocalFileHandle(file_relative_path(__file__, "../num.csv")) @pipeline( mode_defs=[ ModeDefinition( resource_defs={ "pyspark": pyspark_resource, "pyspark_step_launcher": no_step_launcher, + "pyspark_io_manager": local_parquet_io_manager, "file_manager": local_file_manager, + "io_manager": fs_io_manager, } ) ] ) def ingest_csv_file_test(): return collect_df(ingest_csv_file_handle_to_spark(emit_num_csv_local_file())) - result = execute_pipeline( - ingest_csv_file_test, - run_config={"resources": {"pyspark": {"config": {"spark_conf": spark_config}}}}, - ) - assert result.success + with tempfile.TemporaryDirectory() as temp_dir: + result = execute_pipeline( + ingest_csv_file_test, + run_config={ + "resources": { + "pyspark": {"config": {"spark_conf": spark_config}}, + "pyspark_io_manager": {"config": {"base_dir": temp_dir}}, + "io_manager": {"config": {"base_dir": temp_dir}}, + } + }, + ) + assert result.success - df = result.result_for_solid("collect_df").output_value() - assert df == [Row(num1="1", num2="2")] + df = result.result_for_solid("collect_df").output_value() + assert df == [Row(num1="1", num2="2")] def test_ingest_csv_file_with_special_handle_to_spark(spark_config): @solid def emit_num_special_csv_local_file(_): return LocalFileHandle(file_relative_path(__file__, "../num_with_special_chars.csv")) @pipeline( mode_defs=[ ModeDefinition( resource_defs={ "pyspark": pyspark_resource, "pyspark_step_launcher": no_step_launcher, "file_manager": local_file_manager, + "pyspark_io_manager": local_parquet_io_manager, + "io_manager": fs_io_manager, } ) ] ) def ingest_csv_file_test(): return collect_df(ingest_csv_file_handle_to_spark(emit_num_special_csv_local_file())) - result = execute_pipeline( - ingest_csv_file_test, - run_config={"resources": {"pyspark": {"config": {"spark_conf": spark_config}}}}, - ) - assert result.success + with tempfile.TemporaryDirectory() as temp_dir: + result = execute_pipeline( + ingest_csv_file_test, + run_config={ + "resources": { + "pyspark": {"config": {"spark_conf": spark_config}}, + "pyspark_io_manager": {"config": {"base_dir": temp_dir}}, + "io_manager": {"config": {"base_dir": temp_dir}}, + } + }, + ) + assert result.success - df = result.result_for_solid("collect_df").output_value() + df = result.result_for_solid("collect_df").output_value() - assert df == [Row(num1="1", num2="2")] + assert df == [Row(num1="1", num2="2")] diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py index 7ae84030e..2feb62600 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_load_data_from_spark.py @@ -1,53 +1,76 @@ +import tempfile + +from airline_demo.pipelines import local_parquet_io_manager from airline_demo.resources import DbInfo from airline_demo.solids import load_data_to_database_from_spark -from dagster import ModeDefinition, ResourceDefinition, execute_pipeline, pipeline, solid +from dagster import ( + ModeDefinition, + OutputDefinition, + ResourceDefinition, + execute_pipeline, + fs_io_manager, + pipeline, + solid, +) from dagster.core.definitions.no_step_launcher import no_step_launcher from dagster.seven import mock +from dagster.utils import file_relative_path from dagster_pyspark import pyspark_resource -from 
pyspark.sql import DataFrame def test_airline_demo_load_df(): db_info_mock = DbInfo( engine=mock.MagicMock(), url="url", jdbc_url="url", dialect="dialect", load_table=mock.MagicMock(), host="host", db_name="db_name", ) - @solid - def emit_mock(_): - return mock.MagicMock(spec=DataFrame) + @solid( + required_resource_keys={"pyspark"}, + output_defs=[OutputDefinition(io_manager_key="pyspark_io_manager")], + ) + def emit_mock(context): + return context.resources.pyspark.spark_session.read.csv( + file_relative_path(__file__, "../data/test.csv") + ) @pipeline( mode_defs=[ ModeDefinition( resource_defs={ "db_info": ResourceDefinition.hardcoded_resource(db_info_mock), "pyspark": pyspark_resource, "pyspark_step_launcher": no_step_launcher, + "pyspark_io_manager": local_parquet_io_manager, + "io_manager": fs_io_manager, } ) ] ) def load_df_test(): load_data_to_database_from_spark(emit_mock()) - solid_result = execute_pipeline( - load_df_test, - run_config={ - "solids": {"load_data_to_database_from_spark": {"config": {"table_name": "foo"}}} - }, - ).result_for_solid("load_data_to_database_from_spark") + with tempfile.TemporaryDirectory() as temp_dir: + solid_result = execute_pipeline( + load_df_test, + run_config={ + "solids": {"load_data_to_database_from_spark": {"config": {"table_name": "foo"}}}, + "resources": { + "io_manager": {"config": {"base_dir": temp_dir}}, + "pyspark_io_manager": {"config": {"base_dir": temp_dir}}, + }, + }, + ).result_for_solid("load_data_to_database_from_spark") - assert solid_result.success - mats = solid_result.materializations_during_compute - assert len(mats) == 1 - mat = mats[0] - assert len(mat.metadata_entries) == 2 - entries = {me.label: me for me in mat.metadata_entries} - assert entries["Host"].entry_data.text == "host" - assert entries["Db"].entry_data.text == "db_name" + assert solid_result.success + mats = solid_result.materializations_during_compute + assert len(mats) == 1 + mat = mats[0] + assert len(mat.metadata_entries) == 2 + entries = {me.label: me for me in mat.metadata_entries} + assert entries["Host"].entry_data.text == "host" + assert entries["Db"].entry_data.text == "db_name" diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py index ecf9a247f..889fa0be3 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/test_unzip_file_handle.py @@ -1,102 +1,103 @@ import zipfile import boto3 from airline_demo.unzip_file_handle import unzip_file_handle from dagster import ( LocalFileHandle, ModeDefinition, OutputDefinition, ResourceDefinition, execute_pipeline, local_file_manager, pipeline, solid, ) from dagster.utils.test import get_temp_file_name from dagster_aws.s3 import S3FileHandle, S3FileManager, s3_intermediate_storage from moto import mock_s3 # for dep graphs def write_zip_file_to_disk(zip_file_path, archive_member, data): with zipfile.ZipFile(zip_file_path, mode="w") as archive: archive.writestr(data=data, zinfo_or_arcname=archive_member) def test_unzip_file_handle(): data = b"foo" with get_temp_file_name() as zip_file_name: write_zip_file_to_disk(zip_file_name, "some_archive_member", data) @solid def to_zip_file_handle(_): return LocalFileHandle(zip_file_name) @pipeline(mode_defs=[ModeDefinition(resource_defs={"file_manager": local_file_manager})]) def do_test_unzip_file_handle(): return unzip_file_handle(to_zip_file_handle()) result = 
execute_pipeline( do_test_unzip_file_handle, run_config={ "solids": { "unzip_file_handle": { "inputs": {"archive_member": {"value": "some_archive_member"}} } } }, ) assert result.success @mock_s3 def test_unzip_file_handle_on_fake_s3(): foo_bytes = b"foo" @solid(required_resource_keys={"file_manager"}, output_defs=[OutputDefinition(S3FileHandle)]) def write_zipped_file_to_s3_store(context): with get_temp_file_name() as zip_file_name: write_zip_file_to_disk(zip_file_name, "an_archive_member", foo_bytes) with open(zip_file_name, "rb") as ff: s3_file_handle = context.resources.file_manager.write_data(ff.read()) return s3_file_handle # Uses mock S3 - s3 = boto3.client("s3") + # https://github.com/spulec/moto/issues/3292 + s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="some-bucket") file_manager = S3FileManager(s3_session=s3, s3_bucket="some-bucket", s3_base_key="dagster") @pipeline( mode_defs=[ ModeDefinition( resource_defs={ "s3": ResourceDefinition.hardcoded_resource(s3), "file_manager": ResourceDefinition.hardcoded_resource(file_manager), }, intermediate_storage_defs=[s3_intermediate_storage], ) ] ) def do_test_unzip_file_handle_s3(): return unzip_file_handle(write_zipped_file_to_s3_store()) result = execute_pipeline( do_test_unzip_file_handle_s3, run_config={ "storage": {"s3": {"config": {"s3_bucket": "some-bucket"}}}, "solids": { "unzip_file_handle": {"inputs": {"archive_member": {"value": "an_archive_member"}}} }, }, ) assert result.success zipped_s3_file = result.result_for_solid("write_zipped_file_to_s3_store").output_value() unzipped_s3_file = result.result_for_solid("unzip_file_handle").output_value() bucket_keys = [obj["Key"] for obj in s3.list_objects(Bucket="some-bucket")["Contents"]] assert zipped_s3_file.s3_key in bucket_keys assert unzipped_s3_file.s3_key in bucket_keys diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py index 598f2df9a..67c1ba63f 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_four.py @@ -1,63 +1,64 @@ import boto3 from dagster import FileHandle, ModeDefinition, solid from dagster.utils.temp_file import get_temp_file_name from dagster.utils.test import execute_solid from dagster_aws.s3 import S3Coordinate, S3FileCache from moto import mock_s3 @solid(required_resource_keys={"file_cache", "s3"}) def cache_file_from_s3(context, s3_coord: S3Coordinate) -> FileHandle: # we default the target_key to the last component of the s3 key. 
target_key = s3_coord["key"].split("/")[-1] with get_temp_file_name() as tmp_file: context.resources.s3.download_file( Bucket=s3_coord["bucket"], Key=s3_coord["key"], Filename=tmp_file ) file_cache = context.resources.file_cache with open(tmp_file, "rb") as tmp_file_object: # returns a handle rather than a path file_handle = file_cache.write_file_object(target_key, tmp_file_object) return file_handle def unittest_for_aws_mode_def(s3_session): return ModeDefinition.from_resources( { "file_cache": S3FileCache("file-cache-bucket", "file-cache", s3_session), "s3": s3_session, } ) @mock_s3 def test_cache_file_from_s3_step_four(snapshot): - s3 = boto3.client("s3") + # https://github.com/spulec/moto/issues/3292 + s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="source-bucket") s3.create_bucket(Bucket="file-cache-bucket") s3.put_object(Bucket="source-bucket", Key="source-file", Body=b"foo") solid_result = execute_solid( cache_file_from_s3, unittest_for_aws_mode_def(s3), input_values={"s3_coord": {"bucket": "source-bucket", "key": "source-file"}}, ) assert solid_result.output_value().path_desc == "s3://file-cache-bucket/file-cache/source-file" file_cache_obj = s3.get_object(Bucket="file-cache-bucket", Key="file-cache/source-file") assert file_cache_obj["Body"].read() == b"foo" snapshot.assert_match( { "file-cache-bucket": { k: s3.get_object(Bucket="file-cache-bucket", Key=k)["Body"].read() for k in [ obj["Key"] for obj in s3.list_objects(Bucket="file-cache-bucket")["Contents"] ] } } ) diff --git a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py index 30feeb364..d277e614e 100644 --- a/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py +++ b/examples/airline_demo/airline_demo_tests/unit_tests/testing_guide_tests/test_cache_file_from_s3_in_guide_step_three.py @@ -1,74 +1,75 @@ import os import boto3 from dagster import ModeDefinition, solid from dagster.core.storage.file_cache import FSFileCache from dagster.seven import mock from dagster.utils.temp_file import get_temp_dir, get_temp_file_name from dagster.utils.test import execute_solid from dagster_aws.s3 import S3Coordinate from moto import mock_s3 @solid(required_resource_keys={"file_cache", "s3"}) def cache_file_from_s3(context, s3_coord: S3Coordinate) -> str: # we default the target_key to the last component of the s3 key. 
target_key = s3_coord["key"].split("/")[-1] with get_temp_file_name() as tmp_file: context.resources.s3.download_file( Bucket=s3_coord["bucket"], Key=s3_coord["key"], Filename=tmp_file ) file_cache = context.resources.file_cache with open(tmp_file, "rb") as tmp_file_object: # returns a handle rather than a path file_handle = file_cache.write_file_object(target_key, tmp_file_object) return file_handle.path def unittest_for_local_mode_def(temp_dir, s3_session): return ModeDefinition.from_resources({"file_cache": FSFileCache(temp_dir), "s3": s3_session}) def test_cache_file_from_s3_step_three_mock(): s3_session = mock.MagicMock() with get_temp_dir() as temp_dir: execute_solid( cache_file_from_s3, unittest_for_local_mode_def(temp_dir, s3_session), input_values={"s3_coord": {"bucket": "some-bucket", "key": "some-key"}}, ) assert s3_session.download_file.call_count == 1 assert os.path.exists(os.path.join(temp_dir, "some-key")) @mock_s3 def test_cache_file_from_s3_step_three_fake(snapshot): - s3 = boto3.client("s3") + # https://github.com/spulec/moto/issues/3292 + s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="some-bucket") s3.put_object(Bucket="some-bucket", Key="some-key", Body=b"foo") with get_temp_dir() as temp_dir: execute_solid( cache_file_from_s3, unittest_for_local_mode_def(temp_dir, s3), input_values={"s3_coord": {"bucket": "some-bucket", "key": "some-key"}}, ) target_file = os.path.join(temp_dir, "some-key") assert os.path.exists(target_file) with open(target_file, "rb") as ff: assert ff.read() == b"foo" snapshot.assert_match( { "some-bucket": { k: s3.get_object(Bucket="some-bucket", Key=k)["Body"].read() for k in [obj["Key"] for obj in s3.list_objects(Bucket="some-bucket")["Contents"]] } } ) diff --git a/examples/airline_demo/an_archive_member b/examples/airline_demo/an_archive_member deleted file mode 100644 index 191028156..000000000 --- a/examples/airline_demo/an_archive_member +++ /dev/null @@ -1 +0,0 @@ -foo \ No newline at end of file diff --git a/examples/airline_demo/setup.py b/examples/airline_demo/setup.py index a508bba4b..d6d343630 100644 --- a/examples/airline_demo/setup.py +++ b/examples/airline_demo/setup.py @@ -1,53 +1,53 @@ from setuptools import find_packages, setup setup( name="airline_demo", version="dev", author="Elementl", license="Apache-2.0", description="Dagster Examples", url="https://github.com/dagster-io/dagster/tree/master/examples/airline_demo", classifiers=[ "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], packages=find_packages(exclude=["test"]), # default supports basic tutorial & toy examples install_requires=["dagster"], extras_require={ # full is for running the more realistic demos "full": [ "dagstermill", "dagster-aws", "dagster-cron", "dagster-postgres", "dagster-pyspark", "dagster-slack", "dagster-snowflake", # These two packages, descartes and geopandas, are used in the airline demo notebooks "descartes", 'geopandas; "win" not in sys_platform', "google-api-python-client", "google-cloud-storage", "keras", "lakehouse", "matplotlib", "mock", - "moto>=1.3.7", + "moto>=1.3.16", "pandas>=1.0.0", "pytest-mock", # Pyspark 2.x is incompatible with Python 3.8+ 'pyspark>=3.0.0; python_version >= "3.8"', 'pyspark>=2.0.2; python_version < "3.8"', "seaborn", "sqlalchemy-redshift>=0.7.2", "SQLAlchemy-Utils==0.33.8", 'tensorflow; python_version < "3.9"', ], "airflow": 
["dagster_airflow", "docker-compose==1.23.2"], }, include_package_data=True, )