How to Use randomSplit() in PySpark | Split Your DataFrame into Train and Test Sets | PySpark Tutorial

Learn how to efficiently split your DataFrame into training and testing datasets using PySpark's randomSplit() method. A must-have step when preparing data for machine learning workflows.

📘 Introduction

In machine learning and data science, splitting data into training and testing sets is a fundamental step. PySpark's DataFrame.randomSplit() method lets you do this easily and reproducibly across distributed data.

🔧 PySpark Code Example

from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName("PySpark randomSplit Example").getOrCreate()

# Sample data
data = [
    ("Aamir Shahzad", 85, "Math"),
    ("Ali Raza", 78, "Science"),
    ("Bob", 92, "History"),
    ("Lisa", 80, "Math"),
    ("John", 88, "Science"),
    ("Emma", 75, "History"),
    ("Sophia", 90, "Math"),
    ("Daniel", 83, "Science"),
    ("David", 95, "History"),
    ("Olivia", 77, "Math")
]

columns = ["Name", "Score", "Subject"]

# Create DataFrame
df = spark.createDataFrame(data, columns)
df.show()

# Split DataFrame (70% train, 30% test)
train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)

# Show results
print("Training Set:")
train_df.show()

print("Testing Set:")
test_df.show()

📊 Original DataFrame Output

+-------------+-----+--------+
| Name        |Score|Subject |
+-------------+-----+--------+
|Aamir Shahzad|   85|    Math|
|Ali Raza     |   78| Science|
|Bob          |   92| History|
|Lisa         |   80|    Math|
|John         |   88| Science|
|Emma         |   75| History|
|Sophia       |   90|    Math|
|Daniel       |   83| Science|
|David        |   95| History|
|Olivia       |   77|    Math|
+-------------+-----+--------+

✅ Training Dataset Output (Approx. 70%)

+-------------+-----+--------+
| Name        |Score|Subject |
+-------------+-----+--------+
|Ali Raza     |   78| Science|
|Daniel       |   83| Science|
|Emma         |   75| History|
|John         |   88| Science|
|Lisa         |   80|    Math|
|Sophia       |   90|    Math|
+-------------+-----+--------+

✅ Testing Dataset Output (Approx. 30%)

+-------------+-----+--------+
| Name        |Score|Subject |
+-------------+-----+--------+
|Aamir Shahzad|   85|    Math|
|Bob          |   92| History|
|David        |   95| History|
|Olivia       |   77|    Math|
+-------------+-----+--------+

💡 Key Notes

  • randomSplit([0.7, 0.3]) divides your data into approximately 70% and 30% splits; weights that don't sum to 1 are normalized.
  • You can split into more than two sets, e.g., `[0.6, 0.2, 0.2]` for train/validation/test.
  • The seed parameter makes the split reproducible, provided the input DataFrame's contents and partitioning are identical across runs (caching the DataFrame first helps guarantee this).
  • The split runs in a distributed fashion, so it scales to large datasets on Spark clusters.


Author: Aamir Shahzad

© 2025 PySpark Tutorials. All rights reserved.