Generating Large Image Sets Using Apache Spark
Background
On my research team, we often need to create large labeled datasets for machine learning. These datasets typically include images of cardiac electrogram recordings, which are labeled by expert human adjudicators. These labels include things like “presence of atrial fibrillation,” “lead noise,” etc. An example of one of these images is shown below, lightly redacted to protect trade secrets:
This post explains how to quickly create a dataset containing tens of thousands of such images, conveniently compressed into a single file for easy export into whatever system is desired. The basic logic applies whether you are creating images or any other binary objects that take some time to generate individually.
Prerequisites
This post assumes you have some experience with Apache Spark and/or Databricks, and optionally, the Databricks Unity Catalog. Each of these is a lengthy topic in its own right.
The Scenario
In this example scenario we’ll be processing data from a simple delta table which has this schema:
CREATE TABLE IF NOT EXISTS my_catalog.my_schema.electrograms
(
    DeviceModel STRING,
    Serial LONG,
    Timestamp LONG,
    Content STRING
) PARTITIONED BY (DeviceModel)
The Content column contains base64-encoded binary data that we want to generate images from.
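For orientation, decoding one payload back to raw bytes is a one-liner; the helper name here is illustrative, and the actual structure of the decoded bytes is proprietary:

import base64

def decode_content(content_b64: str) -> bytes:
    # The Content column stores base64 text; decode it back to raw bytes.
    return base64.b64decode(content_b64)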
We want to generate images for only a subset of the rows in this table. Let’s assume the subset is defined in a CSV file named subset.csv with similar column names, for example:
| Model | Serial | Timestamp |
|-------|--------|-----------|
| 9000  | 123456 | 99936382  |
| 8000  | 654321 | 89467833  |
| 7000  | 124816 | 36836302  |
| …     | …      | …         |
We’ll write the image files to a Unity Catalog volume, where they can later be downloaded using a tool such as Azure Storage Explorer:
# dbutils is available by default in Databricks notebooks; no import needed
output_path = "/Volumes/some_volume/path/to/segm_images"

# Delete the output_path folder if it exists from a previous run, and recreate it
dbutils.fs.rm(output_path, recurse=True)
dbutils.fs.mkdirs(output_path)
Creating a Dataframe with the Subset of Electrograms
Let’s read subset.csv into a Spark dataframe with the same column datatypes as the my_catalog.my_schema.electrograms table, then join it to that table to get a dataframe which contains the Content column that we’ll use to generate the images.
from pyspark.sql.types import StructType, StructField, StringType, LongType

schema = StructType([
    StructField("Model", StringType(), False),
    StructField("Serial", LongType(), False),
    StructField("Timestamp", LongType(), False),
])

subset_df = spark.read.csv("subset.csv", header=True, schema=schema)
subset_df.createOrReplaceTempView("subset_df")
df_with_content = spark.sql("""
    select subset_df.Model, subset_df.Serial, subset_df.Timestamp, egms.Content
    from subset_df
    join my_catalog.my_schema.electrograms egms on
        subset_df.Model = egms.DeviceModel AND
        subset_df.Serial = egms.Serial AND
        subset_df.Timestamp = egms.Timestamp
""")
Note: If the join involves more complex matching, such as finding all electrograms BETWEEN a certain time range around the subset timestamps, this will be very slow, because Spark will extract a potentially huge amount of binary data from the Content column and then shuffle it in order to perform a merge join. In that case, the dataset should first be narrowed to an intermediate dataframe that doesn’t include the Content column, which can then be joined to the electrograms table, like above, in a final step, as sketched below. (Perhaps more about this in a future blog post.)
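Here is a rough sketch of that two-step pattern; the 5,000-unit window is an arbitrary illustration, not a value from our pipeline:

# Step 1: narrow to matching keys WITHOUT selecting the Content column,
# so the expensive range join never shuffles the binary payloads.
narrowed_df = spark.sql("""
    select egms.DeviceModel, egms.Serial, egms.Timestamp
    from subset_df
    join my_catalog.my_schema.electrograms egms on
        subset_df.Model = egms.DeviceModel AND
        subset_df.Serial = egms.Serial AND
        egms.Timestamp between subset_df.Timestamp - 5000
                           and subset_df.Timestamp + 5000
""")
narrowed_df.createOrReplaceTempView("narrowed_df")

# Step 2: a plain equi-join pulls Content only for the rows we kept.
df_with_content = spark.sql("""
    select n.DeviceModel as Model, n.Serial, n.Timestamp, egms.Content
    from narrowed_df n
    join my_catalog.my_schema.electrograms egms on
        n.DeviceModel = egms.DeviceModel AND
        n.Serial = egms.Serial AND
        n.Timestamp = egms.Timestamp
""")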
Define a Function to Generate Images
The actual structure of the data in the Content column, and the logic used to render images from it, are fairly complex and beyond the scope of this post. For simplicity, let’s define our generate_plots() function like so:
import io
import traceback
from logging import getLogger

logger = getLogger(__name__)

def generate_plots(row):
    plot_title = "NA"
    try:
        with io.BytesIO() as buffer:
            # Create a PNG image and write its binary content to the buffer
            # plot_title = ...
            # ...
            # Write the buffer to a PNG file in the volume.
            buffer.seek(0)
            with open(f"{output_path}/{plot_title}.png", "wb") as f:
                f.write(buffer.getvalue())
    except Exception:
        # Log the exception with the full traceback
        logger.error(f"Error processing {plot_title}. Ex: {traceback.format_exc()}")
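For concreteness, here is one way the elided rendering step might look. The int16 decoding and the matplotlib layout are assumptions for illustration only, not our team’s actual format or renderer:

import base64
import io
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, required on Spark workers
import matplotlib.pyplot as plt

def render_egm_png(content_b64: str) -> bytes:
    # Hypothetical: assume Content decodes to a flat array of int16 samples.
    samples = np.frombuffer(base64.b64decode(content_b64), dtype=np.int16)
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(samples, linewidth=0.5)
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", dpi=150)
    plt.close(fig)
    return buffer.getvalue()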
Call the Function on Every Row in the Subset
This step is not needed for every use case, but our subsets tend to be rather unbalanced, because some patients have many electrograms and others very few. It is therefore often necessary to repartition the data before processing, so that every Spark worker does a roughly equal amount of work. Processing partition by partition, rather than row by row, also minimizes library initialization overhead:
# Calls the above function for each row in a partition
def generate_plots_partition(rows):
    for row in rows:
        generate_plots(row)

# The choice of 256 partitions is based on the size of the input data and the
# number of workers; tune it for your dataset.
df_with_content.repartition(256).foreachPartition(generate_plots_partition)
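Once the job completes, a quick listing of the output directory confirms that the expected number of images landed in the volume:

# Should roughly match the number of rows in the subset.
print(len(dbutils.fs.ls(output_path)))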
Wrapping Up
Our image generation library can produce a PNG at a rate of roughly one per second. Assuming the subset contains about 10,000 electrograms, an 8-node Databricks cluster with 8 cores per node runs 64 plotting tasks in parallel, so the job takes roughly 10,000 / (8 * 8) ≈ 156 seconds, or about two and a half minutes. Not very long at all.
After the job finishes, we can run a shell command from the Spark driver node to put all of the images into a single tar.gz file for convenient copying or downloading.
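For example, from a notebook cell on the driver (a sketch; the archive is staged on the driver’s local disk first, then copied into the volume, with paths mirroring output_path above):

import subprocess

# Bundle the PNG folder into a single archive on the driver's local disk.
subprocess.run(
    ["tar", "-czf", "/tmp/segm_images.tar.gz",
     "-C", "/Volumes/some_volume/path/to", "segm_images"],
    check=True,
)

# Copy the archive into the volume so it can be downloaded.
dbutils.fs.cp("file:/tmp/segm_images.tar.gz",
              "/Volumes/some_volume/path/to/segm_images.tar.gz")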