How to select a range of rows from a dataframe in PySpark?
A PySpark DataFrame is a distributed collection of data organized into rows and columns. Selecting a range of rows usually means filtering rows whose column values fall between a lower and an upper bound. PySpark provides several methods, such as filter(), where(), and collect(), to achieve this.
Setting Up PySpark
First, install PySpark (a shell command) and import the required modules:

pip install pyspark

Then, in Python:

from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder \
    .appName('DataFrame_Range_Selection') \
    .getOrCreate()
# Sample data
customer_data = [
("PREM KUMAR", 1281, "AC", 40000, 4000),
("RATAN SINGH", 1289, "HOME THEATER", 35000, 5000),
("DAVID K", 1221, "NIKON CAMERA", 88000, 10000),
("JONATHAN REDDY", 1743, "GEYSER", 15000, 500),
("JASPREET BRAR", 1234, "HP LAPTOP", 78000, 3564),
("NEIL KAMANT", 1222, "WASHING MACHINE", 25000, 2000)
]
columns = ["CUSTOMER_NAME", "PRODUCT_ID", "PRODUCT_NAME", "ACTUAL_PRICE", "EMI_PER_MONTH"]
df = spark.createDataFrame(customer_data, columns)
df.show()
+--------------+----------+---------------+------------+-------------+
| CUSTOMER_NAME|PRODUCT_ID|   PRODUCT_NAME|ACTUAL_PRICE|EMI_PER_MONTH|
+--------------+----------+---------------+------------+-------------+
|    PREM KUMAR|      1281|             AC|       40000|         4000|
|   RATAN SINGH|      1289|   HOME THEATER|       35000|         5000|
|       DAVID K|      1221|   NIKON CAMERA|       88000|        10000|
|JONATHAN REDDY|      1743|         GEYSER|       15000|          500|
| JASPREET BRAR|      1234|      HP LAPTOP|       78000|         3564|
|   NEIL KAMANT|      1222|WASHING MACHINE|       25000|         2000|
+--------------+----------+---------------+------------+-------------+
Method 1: Using filter()
The filter() method allows you to specify conditions to select rows within a range:

# Filter rows where ACTUAL_PRICE is between 25000 and 40000
df.filter((df['ACTUAL_PRICE'] >= 25000) & (df['ACTUAL_PRICE'] <= 40000)).show()
+--------------+----------+---------------+------------+-------------+
| CUSTOMER_NAME|PRODUCT_ID|   PRODUCT_NAME|ACTUAL_PRICE|EMI_PER_MONTH|
+--------------+----------+---------------+------------+-------------+
|    PREM KUMAR|      1281|             AC|       40000|         4000|
|   RATAN SINGH|      1289|   HOME THEATER|       35000|         5000|
|   NEIL KAMANT|      1222|WASHING MACHINE|       25000|         2000|
+--------------+----------+---------------+------------+-------------+
Method 2: Using where()
The where() method works similarly to filter() and is often used interchangeably:

# Select rows where EMI_PER_MONTH is between 2000 and 5000
df.where((df['EMI_PER_MONTH'] >= 2000) & (df['EMI_PER_MONTH'] <= 5000)).show()
+--------------+----------+---------------+------------+-------------+
| CUSTOMER_NAME|PRODUCT_ID|   PRODUCT_NAME|ACTUAL_PRICE|EMI_PER_MONTH|
+--------------+----------+---------------+------------+-------------+
|    PREM KUMAR|      1281|             AC|       40000|         4000|
|   RATAN SINGH|      1289|   HOME THEATER|       35000|         5000|
| JASPREET BRAR|      1234|      HP LAPTOP|       78000|         3564|
|   NEIL KAMANT|      1222|WASHING MACHINE|       25000|         2000|
+--------------+----------+---------------+------------+-------------+
Method 3: Using collect() with Python Logic
The collect() method brings all data to the driver node, allowing Python-style filtering:

# Collect data and filter using Python logic
filtered_rows = []
for row in df.collect():
    if 30000 <= row['ACTUAL_PRICE'] <= 50000:
        filtered_rows.append(row)

for row in filtered_rows:
    print(row)
Row(CUSTOMER_NAME='PREM KUMAR', PRODUCT_ID=1281, PRODUCT_NAME='AC', ACTUAL_PRICE=40000, EMI_PER_MONTH=4000)
Row(CUSTOMER_NAME='RATAN SINGH', PRODUCT_ID=1289, PRODUCT_NAME='HOME THEATER', ACTUAL_PRICE=35000, EMI_PER_MONTH=5000)
Comparison of Methods
| Method | Performance | Best For | Data Transfer |
|---|---|---|---|
| filter() | High | Large datasets | Distributed processing |
| where() | High | SQL-like syntax preference | Distributed processing |
| collect() | Low | Small datasets, complex logic | All data to driver |
Conclusion
Use filter() or where() for efficient range selection on large datasets. Avoid collect() for big data as it transfers all data to the driver node, which can cause memory issues.
