Handling large-scale data efficiently is a critical skill for any Senior Data Engineer, especially when working with Apache Spark. A common challenge is removing duplicates from massive datasets while ensuring scalability, fault tolerance, and minimal performance overhead.

In this article, we'll explore:

• How to efficiently remove duplicates in PySpark

• Performance pitfalls to avoid

• Optimizations for handling skewed data

• Answers to tricky follow-up questions that could come up in a senior data engineering interview

Let's dive in!

🔥 The Problem: Deduplicating Billion-Row Datasets in PySpark

Imagine you have a billion-row dataset stored in a distributed file system (e.g., Parquet files on S3 or GCS), and you need to remove duplicates without creating a performance bottleneck.
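For context, the starting point might look something like this, a minimal sketch in which the bucket path is purely illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Deduplication").getOrCreate()

# Read the (hypothetical) billion-row Parquet dataset straight from object storage
df = spark.read.parquet("s3a://my-bucket/events/")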

Constraints:

✅ The dataset is distributed across multiple worker nodes

✅ You must minimize shuffle operations for better performance

✅ The dataset may have data skew, making some partitions disproportionately large

✅ The solution must be fault-tolerant and scalable

❌ The Naïve Approach (What NOT to Do)

Most beginners might try:

df = df.dropDuplicates() # Not efficient for big data!

🚨 Why is this bad?

• dropDuplicates() triggers a full shuffle: Spark has to redistribute every row across the cluster, which is extremely slow and memory-hungry at billion-row scale (you can confirm this in the physical plan, as shown right after this list).

• If the data is skewed, some partitions may become overloaded, leading to longer processing times and potential failures.
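You can confirm the cost yourself before running anything heavy. Assuming df is your DataFrame, print the physical plan and look for the Exchange step, which is Spark's shuffle:

# The plan for dropDuplicates() includes an Exchange hashpartitioning step,
# i.e. a full shuffle of the dataset keyed by every column
df.dropDuplicates().explain()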

So, what's the right approach?

✅ The Efficient Approach: Window Functions

Using Spark's window functions, we can shuffle the data just once by the duplicate keys, resolve each group of duplicates locally on its worker, and keep full control over which copy survives.

Optimized Deduplication Code:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# Initialize Spark session
spark = SparkSession.builder.appName("Deduplication").getOrCreate()

# Sample DataFrame
data = [
  (1, "Alice", "2024-01-01"),
  (2, "Bob", "2024-01-02"),
  (1, "Alice", "2024-01-01"), # Duplicate
  (3, "Charlie", "2024-01-03"),
  (2, "Bob", "2024-01-02"), # Duplicate
]

df = spark.createDataFrame(data, ["id", "name", "date"])

# Efficient deduplication using a window function
# (row_number() needs an orderBy; since all rows in a group are exact duplicates,
# ordering by id is just a deterministic tie-breaker)
window_spec = Window.partitionBy("id", "name", "date").orderBy(col("id"))

deduped_df = df.withColumn("row_num", row_number().over(window_spec)) \
  .filter(col("row_num") == 1) \
  .drop("row_num")

deduped_df.show() # 3 unique rows remain: Alice, Bob, Charlie

🔥 Why This Works Efficiently:

1. Window Partitioning:

• The data is hash-partitioned by the duplicate keys (id, name, date), so every copy of a given row lands on the same worker.

• Each group is then resolved locally; that single shuffle by the keys is the only data movement required.

2. Assigning Row Numbers:

• row_number() assigns a unique number to each duplicate row within a partition.

• We keep only the first occurrence and filter the rest out (a sketch of keeping the latest record instead follows this list).

3. Avoiding distinct() and dropDuplicates():

• Both also shuffle the full dataset by the dedup keys, but they give you no control over which copy survives.

• row_number() makes the surviving record explicit via orderBy and composes cleanly with the skew-handling techniques below, which is what keeps it scalable.
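Because the surviving row is simply the one row_number() ranks first, the window's orderBy lets you decide which duplicate to keep. For instance, if rows sharing an id can carry different dates and you only want the most recent one, a minimal sketch looks like this (note it dedupes on id alone, which is an assumption about your data):

from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# Keep the latest record per id, "latest" meaning the greatest value in the date column
latest_window = Window.partitionBy("id").orderBy(col("date").desc())

latest_df = df.withColumn("row_num", row_number().over(latest_window)) \
  .filter(col("row_num") == 1) \
  .drop("row_num")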

🚀 Optimizing for Data Skew:

Problem: If certain keys (e.g., id) have too many duplicates, Spark may struggle with uneven workload distribution.
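Before reaching for salting, it's worth measuring how bad the skew actually is. A quick check, assuming the same id, name, date key columns:

from pyspark.sql.functions import col

# Count rows per dedup key and inspect the heaviest groups;
# a handful of keys with millions of rows is the classic skew signature
df.groupBy("id", "name", "date") \
  .count() \
  .orderBy(col("count").desc()) \
  .show(10)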

Solution: Salting the Data

We can introduce a salt column to break the large duplicate groups into smaller sub-groups, deduplicate within those, and then finish with a much cheaper pass on the original keys:

from pyspark.sql.functions import monotonically_increasing_id

# Pass 1: add a salt so each key's duplicates are spread across up to 10 sub-groups
df = df.withColumn("salt", monotonically_increasing_id() % 10)

salted_window = Window.partitionBy("id", "name", "date", "salt").orderBy(col("id"))

partially_deduped = df.withColumn("row_num", row_number().over(salted_window)) \
  .filter(col("row_num") == 1) \
  .drop("row_num", "salt")

# Pass 2: duplicates of the same key may have received different salts,
# so finish with a dedup on the original keys, now over a much smaller dataset
final_window = Window.partitionBy("id", "name", "date").orderBy(col("id"))

deduped_df = partially_deduped.withColumn("row_num", row_number().over(final_window)) \
  .filter(col("row_num") == 1) \
  .drop("row_num")

🔹 Why This Helps:

• The salt column splits each oversized duplicate group into several smaller ones, so no single task gets overloaded.

• The second pass does repartition by the original keys, but it runs on data that pass 1 has already shrunk, so it stays cheap.

Follow-Up Questions:

1️⃣ What Happens If the Dataset Has NULL Values?

By default, partitionBy() treats NULL as a separate group.

This means that:

• All NULL values will be grouped into a single partition, potentially leading to skew.

• If duplicates contain NULLs, Spark will still process them normally in the row_number() logic.

Solution:

• If NULL values exist in critical columns, replace them with a default value:

# Replace NULL keys with sentinel defaults before deduplicating
df = df.fillna({"name": "UNKNOWN", "date": "1900-01-01"})
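Alternatively, if rows with NULL keys don't need deduplication at all, you can route them around the window step entirely. A sketch, and whether this is acceptable depends on your data:

from pyspark.sql.functions import col

# Split out rows whose key columns are NULL, deduplicate the clean rows with the
# window approach above, then stitch the two DataFrames back together
null_key_rows = df.filter(col("name").isNull() | col("date").isNull())
clean_rows = df.filter(col("name").isNotNull() & col("date").isNotNull())

# ...apply the window-based dedup to clean_rows, producing deduped_clean...
# result = deduped_clean.unionByName(null_key_rows)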

2️⃣ How Would You Implement This in RDDs Instead of DataFrames?

While PySpark DataFrames are optimized, an RDD-based approach is possible:

# Create key-value pairs keyed by the dedup columns
rdd = df.rdd.map(lambda row: ((row.id, row.name, row.date), row))
# Keep the first occurrence for each key
deduped_rdd = rdd.reduceByKey(lambda row1, row2: row1)
# Reuse the original schema instead of re-inferring it
final_df = spark.createDataFrame(deduped_rdd.map(lambda x: x[1]), df.schema)

🚨 Why This Isn't Ideal:

• RDD transformations bypass the Catalyst optimizer and Tungsten execution engine, so they are typically slower.

• reduceByKey() requires a full shuffle, which is costly at scale.

• DataFrames handle schema enforcement & optimization better.

🔹 Recommendation: Always prefer DataFrames over RDDs unless absolutely necessary.

3️⃣ How to Handle Duplicates Across Multiple Parquet Files?

If duplicates can span different Parquet files, pre-organize the data with Bucketing (or Z-Ordering, if you're on Delta Lake) when you write it out:

# Bucketing requires writing as a table (saveAsTable), not a plain path
deduped_df.write \
    .format("parquet") \
    .partitionBy("date") \
    .bucketBy(100, "id") \
    .saveAsTable("deduped_table")

🔹 Why This Works:

• Partitioning by date keeps each date's data in its own directory, so queries and rewrites touch only the dates they need.

• Bucketing by id places all rows with the same id hash in the same bucket file, so later operations keyed on id avoid another shuffle.

• Spark can then deduplicate bucket by bucket instead of reshuffling the whole table (see the read-back sketch below).
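To see the payoff, read the bucketed table back and deduplicate on the bucket column. With bucketing enabled (spark.sql.sources.bucketing.enabled is true by default), the plan should be able to reuse the bucket layout rather than adding another full shuffle on id; a small sketch:

# Read the bucketed table back; the data is already grouped into buckets by id,
# so a dedup keyed on id can reuse that layout instead of reshuffling the whole table
bucketed = spark.table("deduped_table")
bucketed.dropDuplicates(["id"]).explain()  # inspect the plan for Exchange steps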

Removing duplicates in PySpark isn't just about calling distinct() — it's about understanding Spark's execution model. By using window functions, partitioning, salting, and bucketing, you can:

✅ Eliminate duplicates efficiently

✅ Minimize costly shuffle operations

✅ Handle skewed data gracefully

✅ Optimize performance for large-scale datasets

🚀 Interview Tip:

If you're asked about deduplication in a senior data engineering interview, walk through:

• Naïve vs. optimized approaches

• How to handle data skew

• Advanced techniques like bucketing & salting

💡 Mastering these techniques will make you stand out as a true PySpark expert!

💬 What's your favorite PySpark performance trick? Drop a comment below!

If you are an aspiring Data Engineer, a Data Engineer looking to add more weight to your skill bag, or simply someone interested in topics like this, please hit Follow 👉 and Clap 👏 to show your support. It might not be much, but it definitely boosts my confidence to pump out more use-case-based content on different Data Engineering tools.

Thank You 🖤 for Reading!