When working with large datasets in Apache Spark, a common performance issue is data skew: a few keys dominate the data distribution, leading to uneven partitions and slow queries. It typically shows up during operations that require a shuffle, such as joins and aggregations.

A practical way to reduce skew is salting, which artificially spreads heavy keys across multiple partitions. In this post, I’ll walk through the technique with a worked example.

How Salting Resolves Data Skew Issues

By adding a randomly generated number to the join key and then joining over this combined key, we can distribute large keys more evenly. This makes the data distribution more uniform and spreads the load across more workers, instead of sending most of the data to one worker and leaving the others idle.

Benefits of Salting

- More uniform partition sizes, since a heavy key is spread across several salted keys.
- Better cluster utilization: the load is shared across more workers instead of piling onto one while the rest sit idle.
- Shorter shuffle and stage times for joins and aggregations on skewed keys.

When to Use Salting

Use salting during joins or aggregations on skewed keys when you notice long shuffle times or executor failures caused by data skew, or when most workers sit idle while a few stay stuck running. It is also helpful in real-time streaming applications where partitioning drives processing efficiency.
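
A quick way to check for this kind of skew (a minimal diagnostic sketch, reusing the largeDF and imports defined in the Scala example below) is to count rows per join key; a handful of keys with counts far above the rest is the telltale sign:

// Count rows per join key; a few outsized counts indicate skew
largeDF.
    groupBy("customer_id").
    count().
    orderBy(desc("count")).
    show(10)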

Salting Example in Scala

Let's generate some data with an unbalanced number of rows per key. Assume there are two datasets we need to join: a large, skewed dataset and a small lookup dataset.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for the example; the implicits import is needed for toDF on local Seqs
val spark = SparkSession.builder().appName("SaltingExample").master("local[*]").getOrCreate()
import spark.implicits._

// Simulated large dataset with skew
val largeDF = Seq(
  (1, "txn1"), (1, "txn2"), (1, "txn3"), (2, "txn4"), (3, "txn5")
).toDF("customer_id", "transaction")

// Small dataset
val smallDF = Seq(
  (1, "Ahmed"), (2, "Ali"), (3, "Hassan")
).toDF("customer_id", "name")

Let’s add a salt column to the large dataset, using randomization to spread the values of the heavy key across more, smaller partitions.

// Step 1: create a salting key in the large dataset
val numBuckets = 3
val saltedLargeDF = largeDF.
    withColumn("salt", (rand() * numBuckets).cast("int")).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt"))

saltedLargeDF.show()
+-----------+-----------+----+------------------+
|customer_id|transaction|salt|salted_customer_id|
+-----------+-----------+----+------------------+
|          1|       txn1|   1|               1_1|
|          1|       txn2|   1|               1_1|
|          1|       txn3|   2|               1_2|
|          2|       txn4|   2|               2_2|
|          3|       txn5|   0|               3_0|
+-----------+-----------+----+------------------+

To make sure we cover every possible salted key in the large dataset, we need to explode the small dataset with all possible salt values.

// Step 2: Explode rows in smallDF for possible salted keys
val saltedSmallDF = (0 until numBuckets).toDF("salt").
    crossJoin(smallDF).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt")) 

saltedSmallDF.show()
+----+-----------+------+------------------+
|salt|customer_id|  name|salted_customer_id|
+----+-----------+------+------------------+
|   0|          1| Ahmed|               1_0|
|   1|          1| Ahmed|               1_1|
|   2|          1| Ahmed|               1_2|
|   0|          2|   Ali|               2_0|
|   1|          2|   Ali|               2_1|
|   2|          2|   Ali|               2_2|
|   0|          3|Hassan|               3_0|
|   1|          3|Hassan|               3_1|
|   2|          3|Hassan|               3_2|
+----+-----------+------+------------------+

Now we can join the two datasets on the salted key.

// Step 3: Perform salted join
val joinedDF = saltedLargeDF.
    join(saltedSmallDF, Seq("salted_customer_id", "customer_id"), "inner").
    select("customer_id", "transaction", "name")

joinedDF.show()
+-----------+-----------+------+
|customer_id|transaction|  name|
+-----------+-----------+------+
|          1|       txn2| Ahmed|
|          1|       txn1| Ahmed|
|          1|       txn3| Ahmed|
|          2|       txn4|   Ali|
|          3|       txn5|Hassan|
+-----------+-----------+------+
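
To confirm the salt actually did its job, a quick check (reusing saltedLargeDF from step 1) is to count rows per salted key; customer 1's transactions are now split across keys such as 1_1 and 1_2 instead of piling onto a single key:

// Rows per salted key; the formerly heavy key is now spread over up to numBuckets keys
saltedLargeDF.
    groupBy("salted_customer_id").
    count().
    orderBy(desc("count")).
    show()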

Salting Example in Python

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, rand, lit, concat
from pyspark.sql.types import IntegerType

# Local session for the example
spark = SparkSession.builder.appName("SaltingExample").master("local[*]").getOrCreate()

# Simulated large dataset with skew
largeDF = spark.createDataFrame([
    (1, "txn1"), (1, "txn2"), (1, "txn3"), (2, "txn4"), (3, "txn5")
], ["customer_id", "transaction"])

# Small dataset
smallDF = spark.createDataFrame([
    (1, "Ahmed"), (2, "Ali"), (3, "Hassan")
], ["customer_id", "name"])

# Step 1: create a salting key in the large dataset
numBuckets = 3
saltedLargeDF = largeDF.withColumn("salt", (rand() * numBuckets).cast(IntegerType())) \
    .withColumn("salted_customer_id", concat(col("customer_id"), lit("_"), col("salt")))

# Step 2: Explode rows in smallDF for possible salted keys
salt_range = spark.range(0, numBuckets).withColumnRenamed("id", "salt")
saltedSmallDF = salt_range.crossJoin(smallDF) \
    .withColumn("salted_customer_id", concat(col("customer_id"), lit("_"), col("salt")))

# Step 3: Perform salted join
joinedDF = saltedLargeDF.join(
    saltedSmallDF,
    on=["salted_customer_id", "customer_id"],
    how="inner"
).select("customer_id", "transaction", "name")

Notes

Tuning Tip: Choosing numBuckets

Rule of Thumb: Start small (e.g., 10-20) and increase gradually based on observed shuffle sizes and task runtime.
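
One way to turn that rule of thumb into code (a rough sketch built on assumed numbers, not a formula from Spark itself; targetRowsPerTask and the cap of 200 are illustrative values) is to derive the salt range from how many rows the heaviest key contributes:

// Hypothetical per-task row budget; tune for your cluster and row width
val targetRowsPerTask = 1000000L
// Rows contributed by the single heaviest key (reusing largeDF from the Scala example)
val heaviestKeyCount = largeDF.
    groupBy("customer_id").
    count().
    orderBy(desc("count")).
    first().
    getAs[Long]("count")
// Clamp to a sane range so the small side is not exploded excessively
val suggestedBuckets = math.min(200L, math.max(1L, heaviestKeyCount / targetRowsPerTask + 1)).toInt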


Final Thoughts

Salting is a simple and effective way to manage skew in Apache Spark when repartitioning or the engine's built-in skew handling (such as the adaptive query execution skew-join optimization) is insufficient. With the right tuning and monitoring, this technique can significantly decrease job execution times on highly skewed datasets.
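
For comparison, the built-in route on Spark 3.x is adaptive query execution, which can detect and split skewed partitions at join time without any manual salting; enabling it is just configuration (recent versions ship with both settings on by default):

// Let AQE detect and split skewed partitions during shuffle joins
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")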

Originally published at https://practical-software.com on May 11, 2025.