PySpark Cheat Sheet

This cheat sheet collects PySpark code snippets for common DataFrame operations, along with a few scenario-based examples. I add new snippets regularly. If you need something specific, request it and I will try to add it quickly. You are also welcome to share your own snippets in the comment section, and we will add them to the main list.

PySpark Code Snippets

PySpark – Create Dataframe from CSV in S3

df_taxi = spark.read.option("header","true").option("delimiter",",").option("inferschema","true").csv("s3://nyc-tlc/trip data/yellow_tripdata_2020-12.csv")

PySpark – Print Dataframe Schema

df_taxi.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)

PySpark – Check rows count in dataframe

df_taxi.count()

1461897

PySpark – Print all column names in dataframe

df_taxi.columns

['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_distance', 'RatecodeID', 'store_and_fwd_flag', 'PULocationID', 'DOLocationID', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount', 'congestion_surcharge']

PySpark – Check datatype of all columns in dataframe

df_taxi.dtypes

[('VendorID', 'int'), ('tpep_pickup_datetime', 'timestamp'), ('tpep_dropoff_datetime', 'timestamp'), ('passenger_count', 'int'), ('trip_distance', 'double'), ('RatecodeID', 'int'), ('store_and_fwd_flag', 'string'), ('PULocationID', 'int'), ('DOLocationID', 'int'), ('payment_type', 'int'), ('fare_amount', 'double'), ('extra', 'double'), ('mta_tax', 'double'), ('tip_amount', 'double'), ('tolls_amount', 'double'), ('improvement_surcharge', 'double'), ('total_amount', 'double'), ('congestion_surcharge', 'double')]

PySpark – Pick 5 rows from dataframe and return list

df_taxi.take(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2020, 12, 1, 0, 7, 13), tpep_dropoff_datetime=datetime.datetime(2020, 12, 1, 0, 18, 12), passenger_count=1, trip_distance=7.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=138, DOLocationID=263, payment_type=1, fare_amount=21.5, extra=3.0, mta_tax=0.5, tip_amount=2.5, tolls_amount=6.12, improvement_surcharge=0.3, total_amount=33.92, congestion_surcharge=2.5), 
Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2020, 12, 1, 0, 41, 19), tpep_dropoff_datetime=datetime.datetime(2020, 12, 1, 0, 49, 45), passenger_count=1, trip_distance=1.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=263, payment_type=1, fare_amount=8.0, extra=3.0, mta_tax=0.5, tip_amount=2.95, tolls_amount=0.0, improvement_surcharge=0.3, total_amount=14.75, congestion_surcharge=2.5), 
Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2020, 12, 1, 0, 33, 40), tpep_dropoff_datetime=datetime.datetime(2020, 12, 1, 1, 0, 35), passenger_count=1, trip_distance=16.74, RatecodeID=2, store_and_fwd_flag='N', PULocationID=132, DOLocationID=164, payment_type=1, fare_amount=52.0, extra=0.0, mta_tax=0.5, tip_amount=2.5, tolls_amount=6.12, improvement_surcharge=0.3, total_amount=63.92, congestion_surcharge=2.5), 
Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2020, 12, 1, 0, 2, 15), tpep_dropoff_datetime=datetime.datetime(2020, 12, 1, 0, 13, 9), passenger_count=1, trip_distance=4.16, RatecodeID=1, store_and_fwd_flag='N', PULocationID=238, DOLocationID=48, payment_type=1, fare_amount=14.0, extra=0.5, mta_tax=0.5, tip_amount=1.0, tolls_amount=0.0, improvement_surcharge=0.3, total_amount=18.8, congestion_surcharge=2.5), 
Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2020, 12, 1, 0, 37, 42), tpep_dropoff_datetime=datetime.datetime(2020, 12, 1, 0, 45, 11), passenger_count=1, trip_distance=2.22, RatecodeID=1, store_and_fwd_flag='N', PULocationID=238, DOLocationID=41, payment_type=2, fare_amount=8.5, extra=0.5, mta_tax=0.5, tip_amount=0.0, tolls_amount=0.0, improvement_surcharge=0.3, total_amount=9.8, congestion_surcharge=0.0)]

PySpark – Pick 5 rows from dataframe and return dataframe

df_taxi.limit(5).show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       1| 2020-12-01 00:07:13|  2020-12-01 00:18:12|              1|          7.6|         1|                 N|         138|         263|           1|       21.5|  3.0|    0.5|       2.5|        6.12|                  0.3|       33.92|                 2.5|
|       1| 2020-12-01 00:41:19|  2020-12-01 00:49:45|              1|          1.6|         1|                 N|         140|         263|           1|        8.0|  3.0|    0.5|      2.95|         0.0|                  0.3|       14.75|                 2.5|
|       2| 2020-12-01 00:33:40|  2020-12-01 01:00:35|              1|        16.74|         2|                 N|         132|         164|           1|       52.0|  0.0|    0.5|       2.5|        6.12|                  0.3|       63.92|                 2.5|
|       2| 2020-12-01 00:02:15|  2020-12-01 00:13:09|              1|         4.16|         1|                 N|         238|          48|           1|       14.0|  0.5|    0.5|       1.0|         0.0|                  0.3|        18.8|                 2.5|
|       2| 2020-12-01 00:37:42|  2020-12-01 00:45:11|              1|         2.22|         1|                 N|         238|          41|           2|        8.5|  0.5|    0.5|       0.0|         0.0|                  0.3|         9.8|                 0.0|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+

PySpark – Show only 5 rows in output

Display rows using “show”, which prints the first n rows (20 by default) as a formatted table.

df_taxi.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       1| 2020-12-01 00:07:13|  2020-12-01 00:18:12|              1|          7.6|         1|                 N|         138|         263|           1|       21.5|  3.0|    0.5|       2.5|        6.12|                  0.3|       33.92|                 2.5|
|       1| 2020-12-01 00:41:19|  2020-12-01 00:49:45|              1|          1.6|         1|                 N|         140|         263|           1|        8.0|  3.0|    0.5|      2.95|         0.0|                  0.3|       14.75|                 2.5|
|       2| 2020-12-01 00:33:40|  2020-12-01 01:00:35|              1|        16.74|         2|                 N|         132|         164|           1|       52.0|  0.0|    0.5|       2.5|        6.12|                  0.3|       63.92|                 2.5|
|       2| 2020-12-01 00:02:15|  2020-12-01 00:13:09|              1|         4.16|         1|                 N|         238|          48|           1|       14.0|  0.5|    0.5|       1.0|         0.0|                  0.3|        18.8|                 2.5|
|       2| 2020-12-01 00:37:42|  2020-12-01 00:45:11|              1|         2.22|         1|                 N|         238|          41|           2|        8.5|  0.5|    0.5|       0.0|         0.0|                  0.3|         9.8|                 0.0|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
only showing top 5 rows

PySpark – Select column from dataframe

Select columns using “select”, which accepts one or more column names as arguments. Column names are case-insensitive by default, so "vendorid" matches the VendorID column.

df_taxi.select("vendorid","tpep_pickup_datetime","tpep_dropoff_datetime","passenger_count","trip_distance","total_amount").show(5)

+--------+--------------------+---------------------+---------------+-------------+------------+
|vendorid|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|total_amount|
+--------+--------------------+---------------------+---------------+-------------+------------+
|       1| 2020-12-01 00:07:13|  2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1| 2020-12-01 00:41:19|  2020-12-01 00:49:45|              1|          1.6|       14.75|
|       2| 2020-12-01 00:33:40|  2020-12-01 01:00:35|              1|        16.74|       63.92|
|       2| 2020-12-01 00:02:15|  2020-12-01 00:13:09|              1|         4.16|        18.8|
|       2| 2020-12-01 00:37:42|  2020-12-01 00:45:11|              1|         2.22|         9.8|
+--------+--------------------+---------------------+---------------+-------------+------------+
only showing top 5 rows

PySpark – Rename columns in Dataframe

Rename a column using “withColumnRenamed”, which takes the existing name and the new name and returns a new dataframe. Chain the calls to rename several columns.

df_taxi.withColumnRenamed("tpep_pickup_datetime","pickup_time").withColumnRenamed("tpep_dropoff_datetime","drop_time").select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|
+--------+-------------------+-------------------+---------------+-------------+------------+
only showing top 5 rows

PySpark – Rename LIST of columns in Dataframe

Rename a list of columns by looping over two lists in a FOR loop. Python's zip function pairs each existing name with its new name.

existing_cols = ["tpep_pickup_datetime","tpep_dropoff_datetime"]
new_cols = ["pickup_time","drop_time"]
print(list(zip(existing_cols, new_cols)))

[('tpep_pickup_datetime', 'pickup_time'), ('tpep_dropoff_datetime', 'drop_time')]

# Reassign df_taxi on each iteration so the renames accumulate
for old_name, new_name in zip(existing_cols, new_cols):
    df_taxi = df_taxi.withColumnRenamed(old_name, new_name)

df_taxi.select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|
+--------+-------------------+-------------------+---------------+-------------+------------+
only showing top 5 rows

PySpark – Describe Dataframe

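“describe” computes count, mean, stddev, min and max for the numeric and string columns (the timestamp columns are skipped). The count covers non-null values only, which is why VendorID reports 1362441 while trip_distance reports 1461897.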

df_taxi.describe().show()

+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+---------------------+------------------+--------------------+
|summary|          VendorID|   passenger_count|    trip_distance|        RatecodeID|store_and_fwd_flag|     PULocationID|      DOLocationID|      payment_type|       fare_amount|             extra|            mta_tax|        tip_amount|       tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|
+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+---------------------+------------------+--------------------+
|  count|           1362441|           1362441|          1461897|           1362441|           1362441|          1461897|           1461897|           1362441|           1461897|           1461897|            1461897|           1461897|            1461897|              1461897|           1461897|             1461897|
|   mean|1.6853148136322968|1.4193128362989664|6.070390177967453|  1.02891501356756|              null|166.1230250831625|162.04916009814644|1.2890679302810177|12.213858596053779| 0.932831027083303| 0.4924785740719079|1.9610265292287639|0.21780083685779542|  0.29669258504199447|17.598124874775678|  2.1472160829388116|
| stddev|0.4643905448400174|1.0633235169199586|886.5289264178996|0.5479180986778657|              null|67.83917050363657| 71.80259558721558|0.4967290209582704| 329.7733338785015|1.2030748172566303|0.07956615727017227| 2.637018116691093|  1.287813125568131|  0.04391618443743257| 329.8461373300217|  0.9017798324971192|
|    min|                 1|                 0|              0.0|                 1|                 N|                1|                 1|                 1|            -500.0|              -4.5|               -0.5|            -22.55|             -29.62|                 -0.3|            -502.8|                -2.5|
|    max|                 2|                 9|        350914.89|                99|                 Y|              265|               265|                 4|         398464.88|              7.43|                3.3|           1393.56|             102.25|                  0.3|          398467.7|                 2.5|
+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+---------------------+------------------+--------------------+

PySpark – Check number of partitions in Dataframe

Call the getNumPartitions() method on the dataframe's underlying RDD (df.rdd) to get the number of partitions.

df_taxi.rdd.getNumPartitions()

2

PySpark – Write Dataframe to CSV

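Spark writes one part file per partition. Because the dataframe has 2 partitions, the output contains 2 CSV part files, plus an empty _SUCCESS marker file that signals the job completed.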

df_taxi.write.format("csv").option("header","true").save("hdfs:///var/data/taxi/")

#check the output file using shell command
hdfs dfs -ls -h hdfs:///var/data/taxi/

Found 3 items
-rw-r--r--   1 hadoop hdfsadmingroup          0 2021-06-30 11:36 hdfs:///var/data/taxi/_SUCCESS
-rw-r--r--   1 hadoop hdfsadmingroup     76.3 M 2021-06-30 11:36 hdfs:///var/data/taxi/part-00000-cbee1ee9-7ef6-43a1-9e2f-8542f2e6a8a1-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     72.8 M 2021-06-30 11:36 hdfs:///var/data/taxi/part-00001-cbee1ee9-7ef6-43a1-9e2f-8542f2e6a8a1-c000.csv

PySpark – Write Dataframe to CSV in single part

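Use “coalesce(1)” to reduce the dataframe to a single partition before writing, so Spark produces a single part file. All rows are then written by one task, so use this only when the output is small enough for one executor to handle.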

df_taxi.coalesce(1).write.format("csv").option("header","true").save("hdfs:///var/data/taxi1/")

#check the output file using shell command
hdfs dfs -ls -h hdfs:///var/data/taxi1/

Found 2 items
-rw-r--r--   1 hadoop hdfsadmingroup          0 2021-06-30 11:43 hdfs:///var/data/taxi1/_SUCCESS
-rw-r--r--   1 hadoop hdfsadmingroup    149.1 M 2021-06-30 11:43 hdfs:///var/data/taxi1/part-00000-ee55afc0-10a6-4242-ba85-70aca302360a-c000.csv

PySpark – Write Dataframe to CSV into multiple parts

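Use “repartition” to increase the number of partitions to 5 before writing. Without a column argument, repartition performs a full shuffle that spreads rows round-robin, which is why the 5 part files below are the same size.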

df_taxi.repartition(5).write.format("csv").option("header","true").save("hdfs:///var/data/taxi2/")

#check the output file using shell command
hdfs dfs -ls -h hdfs:///var/data/taxi2/

Found 6 items
-rw-r--r--   1 hadoop hdfsadmingroup          0 2021-06-30 11:45 hdfs:///var/data/taxi2/_SUCCESS
-rw-r--r--   1 hadoop hdfsadmingroup     29.8 M 2021-06-30 11:45 hdfs:///var/data/taxi2/part-00000-4ba901d0-1cd5-4f24-be59-3b798c566a9b-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     29.8 M 2021-06-30 11:45 hdfs:///var/data/taxi2/part-00001-4ba901d0-1cd5-4f24-be59-3b798c566a9b-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     29.8 M 2021-06-30 11:45 hdfs:///var/data/taxi2/part-00002-4ba901d0-1cd5-4f24-be59-3b798c566a9b-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     29.8 M 2021-06-30 11:45 hdfs:///var/data/taxi2/part-00003-4ba901d0-1cd5-4f24-be59-3b798c566a9b-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     29.8 M 2021-06-30 11:45 hdfs:///var/data/taxi2/part-00004-4ba901d0-1cd5-4f24-be59-3b798c566a9b-c000.csv

PySpark – Write Dataframe to CSV by repartitioning on column basis

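Passing a column to “repartition” hash-partitions the data on that column's values, so all rows that share a value land in the same partition and therefore in the same part file. When you give no partition count, Spark uses spark.sql.shuffle.partitions (200 by default). Only partitions that receive rows produce part files with data, although an empty part file may also appear, as part-00000 does below. VendorID has 3 distinct values (1, 2 and null), which gives the 3 part files with data in the output.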

from pyspark.sql.functions import col
df_taxi.repartition(col("VendorID")).write.format("csv").option("header","true").save("hdfs:///var/data/taxi3/")

#check the output file using shell command
hdfs dfs -ls -h hdfs:///var/data/taxi3/
  
Found 5 items
-rw-r--r--   1 hadoop hdfsadmingroup          0 2021-06-30 11:48 hdfs:///var/data/taxi3/_SUCCESS
-rw-r--r--   1 hadoop hdfsadmingroup          0 2021-06-30 11:48 hdfs:///var/data/taxi3/part-00000-6fcbbe85-6cd2-4522-9a44-8c1d87041f8d-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     10.8 M 2021-06-30 11:48 hdfs:///var/data/taxi3/part-00042-6fcbbe85-6cd2-4522-9a44-8c1d87041f8d-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     43.2 M 2021-06-30 11:48 hdfs:///var/data/taxi3/part-00043-6fcbbe85-6cd2-4522-9a44-8c1d87041f8d-c000.csv
-rw-r--r--   1 hadoop hdfsadmingroup     95.1 M 2021-06-30 11:48 hdfs:///var/data/taxi3/part-00174-6fcbbe85-6cd2-4522-9a44-8c1d87041f8d-c000.csv

PySpark – Add new column to the dataframe with static value

Use “withColumn” to add a new column. It returns a new dataframe and leaves the original unchanged. The “lit” function wraps a constant so it can be used as a column value.

from pyspark.sql.functions import lit
df_taxi.withColumn("bonus",lit(1)).select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount","bonus").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|bonus|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|    1|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|    1|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|    1|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|    1|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|    1|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
only showing top 5 rows

PySpark – Add new column to the dataframe with derived value

df_taxi.withColumn("bonus",col("total_amount")*0.1).select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount","bonus").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+------------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|             bonus|
+--------+-------------------+-------------------+---------------+-------------+------------+------------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|3.3920000000000003|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|             1.475|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|             6.392|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|1.8800000000000001|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|0.9800000000000001|
+--------+-------------------+-------------------+---------------+-------------+------------+------------------+

PySpark – Cast column to different datatype in dataframe

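Convert a column to another datatype with the column's “cast” method. Casting a double to integer drops the fractional part rather than rounding, so 3.392 becomes 3 and 0.98 becomes 0.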

df_taxi.withColumn("bonus",(col("total_amount")*0.1).cast("integer")).select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount","bonus").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|bonus|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|    3|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|    1|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|    6|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|    1|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|    0|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
only showing top 5 rows

PySpark – Calculate count on aggregate data in dataframe

df_taxi.groupBy("VendorID").count().show()

+--------+------+
|VendorID| count|
+--------+------+
|    null| 99456|
|       1|428740|
|       2|933701|
+--------+------+

PySpark – Calculate average of columns on aggregate data

df_taxi.groupBy("VendorID").avg("trip_distance","total_amount").show()

+--------+------------------+------------------+
|VendorID|avg(trip_distance)| avg(total_amount)|
+--------+------------------+------------------+
|    null| 55.74605383285073| 31.10511090331231|
|       1|  2.37059429957549| 17.32007029899172|
|       2|2.4779207262282674|16.287065248877628|
+--------+------------------+------------------+

PySpark – Calculate sum of columns on aggregate data

df_taxi.groupBy("VendorID").sum("trip_distance","total_amount").show()

+--------+------------------+-------------------+
|VendorID|sum(trip_distance)|  sum(total_amount)|
+--------+------------------+-------------------+
|    null| 5544279.530000002|  3093589.909999829|
|       1|1016368.5999999954|   7425806.93998971|
|       2|2313637.0600000597|1.520724910994229E7|
+--------+------------------+-------------------+

PySpark – Calculate maximum of columns on aggregate data

df_taxi.groupBy("VendorID").max("trip_distance","total_amount").show()

+--------+------------------+-----------------+
|VendorID|max(trip_distance)|max(total_amount)|
+--------+------------------+-----------------+
|    null|         350914.89|           158.41|
|       1|             283.7|         398467.7|
|       2|            407.78|          8361.36|
+--------+------------------+-----------------+

PySpark – Calculate minimum of columns on aggregate data

df_taxi.groupBy("VendorID").min("trip_distance","total_amount").show()

+--------+------------------+-----------------+
|VendorID|min(trip_distance)|min(total_amount)|
+--------+------------------+-----------------+
|    null|               0.0|           -56.32|
|       1|               0.0|              0.0|
|       2|               0.0|           -502.8|
+--------+------------------+-----------------+

PySpark – Merge two Dataframe data into one

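Use “union” to merge the rows of two dataframes into one. Like SQL UNION ALL, it keeps duplicate rows and matches columns by position, not by name. Note that "limit" without an "orderBy" is not guaranteed to return the same rows on every run.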

df_taxi1 = df_taxi.select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").filter("VendorID==1").limit(5)

df_taxi1.show()

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       1|2020-12-01 00:27:47|2020-12-01 00:45:40|              0|          8.4|       40.92|
|       1|2020-12-01 00:08:15|2020-12-01 00:16:04|              2|          2.7|       15.95|
|       1|2020-12-01 00:14:34|2020-12-01 00:31:04|              1|          7.6|       32.15|
+--------+-------------------+-------------------+---------------+-------------+------------+


df_taxi2 = df_taxi.select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").filter("VendorID==2").limit(5)

df_taxi2.show()

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|
|       2|2020-12-01 00:40:47|2020-12-01 00:57:03|              1|         6.44|       24.96|
|       2|2020-12-01 00:01:42|2020-12-01 00:06:06|              1|         0.99|       11.16|
+--------+-------------------+-------------------+---------------+-------------+------------+

# union() works like SQL UNION ALL (unionAll() is an alias for it)
df_taxi1.union(df_taxi2).show()

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       1|2020-12-01 00:27:47|2020-12-01 00:45:40|              0|          8.4|       40.92|
|       1|2020-12-01 00:08:15|2020-12-01 00:16:04|              2|          2.7|       15.95|
|       1|2020-12-01 00:14:34|2020-12-01 00:31:04|              1|          7.6|       32.15|
|       2|2020-12-01 00:33:40|2020-12-01 01:00:35|              1|        16.74|       63.92|
|       2|2020-12-01 00:02:15|2020-12-01 00:13:09|              1|         4.16|        18.8|
|       2|2020-12-01 00:37:42|2020-12-01 00:45:11|              1|         2.22|         9.8|
|       2|2020-12-01 00:40:47|2020-12-01 00:57:03|              1|         6.44|       24.96|
|       2|2020-12-01 00:01:42|2020-12-01 00:06:06|              1|         0.99|       11.16|
+--------+-------------------+-------------------+---------------+-------------+------------+

PySpark – Inner join two dataframe

df_taxi1 = df_taxi.select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").filter("VendorID==1").limit(10)
df_taxi1.show()

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       1|2020-12-01 00:27:47|2020-12-01 00:45:40|              0|          8.4|       40.92|
|       1|2020-12-01 00:08:15|2020-12-01 00:16:04|              2|          2.7|       15.95|
|       1|2020-12-01 00:14:34|2020-12-01 00:31:04|              1|          7.6|       32.15|
|       1|2020-12-01 00:11:02|2020-12-01 00:17:34|              1|          1.7|        11.3|
|       1|2020-12-01 00:54:55|2020-12-01 00:57:09|              1|          0.5|         7.8|
|       1|2020-12-01 00:11:22|2020-12-01 00:40:36|              1|         21.0|       71.85|
|       1|2020-12-01 00:53:58|2020-12-01 00:54:06|              1|          0.0|         3.8|
|       1|2020-12-01 00:52:58|2020-12-01 00:54:28|              1|          0.7|         7.8|
+--------+-------------------+-------------------+---------------+-------------+------------+

df_taxi2 = df_taxi.select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").filter("VendorID==1").limit(5)
df_taxi2.show()

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|       1|2020-12-01 00:07:13|2020-12-01 00:18:12|              1|          7.6|       33.92|
|       1|2020-12-01 00:41:19|2020-12-01 00:49:45|              1|          1.6|       14.75|
|       1|2020-12-01 00:27:47|2020-12-01 00:45:40|              0|          8.4|       40.92|
|       1|2020-12-01 00:08:15|2020-12-01 00:16:04|              2|          2.7|       15.95|
|       1|2020-12-01 00:14:34|2020-12-01 00:31:04|              1|          7.6|       32.15|
+--------+-------------------+-------------------+---------------+-------------+------------+

# INNER JOIN
df_taxi1.join(df_taxi2, 'pickup_time', how='inner').show()

+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|        pickup_time|vendorid|          drop_time|passenger_count|trip_distance|total_amount|vendorid|          drop_time|passenger_count|trip_distance|total_amount|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|2020-12-01 00:07:13|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|
|2020-12-01 00:41:19|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|
|2020-12-01 00:27:47|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|
|2020-12-01 00:08:15|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|
|2020-12-01 00:14:34|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
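
The plain-Python model below makes the join semantics concrete. It is a sketch of an inner equi-join on one key column, not PySpark code. It assumes rows are dicts that share the key column. The function name `inner_join` is hypothetical, and the sample rows are trimmed from the tables above. As in the PySpark output, each result row holds the key once, then the left row's other values, then the right row's.

```python
from collections import defaultdict


def inner_join(left, right, key):
    """Plain-Python model of an inner equi-join on one column (not the PySpark API).

    Rows are dicts. Each output row is a tuple: the key once, then the left
    row's remaining values, then the right row's remaining values.
    """
    # Index the right side by key so each left row can look up its matches
    index = defaultdict(list)
    for rrow in right:
        if rrow[key] is not None:  # NULL keys never match in an equi-join
            index[rrow[key]].append(rrow)
    out = []
    for lrow in left:
        # .get() does not insert empty entries; no match means the left row is dropped
        for rrow in index.get(lrow[key], []):
            left_rest = tuple(v for k, v in lrow.items() if k != key)
            right_rest = tuple(v for k, v in rrow.items() if k != key)
            out.append((lrow[key],) + left_rest + right_rest)
    return out


taxi1 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-01 00:11:02", "total_amount": 11.3},
]
taxi2 = [{"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92}]

print(inner_join(taxi1, taxi2, "pickup_time"))
# [('2020-12-01 00:07:13', 33.92, 33.92)]
```

A key that appears several times on either side produces one output row per matching pair. This is why joining on a non-unique column can multiply rows.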

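PySpark – LEFT join two dataframes

A left join keeps every row of df_taxi1. Rows with no matching pickup_time in df_taxi2 get null in all of df_taxi2's columns.

df_taxi1.join(df_taxi2, 'pickup_time', how='left').show()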

+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|        pickup_time|vendorid|          drop_time|passenger_count|trip_distance|total_amount|vendorid|          drop_time|passenger_count|trip_distance|total_amount|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|2020-12-01 00:07:13|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|
|2020-12-01 00:41:19|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|
|2020-12-01 00:27:47|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|
|2020-12-01 00:08:15|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|
|2020-12-01 00:14:34|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|
|2020-12-01 00:11:02|       1|2020-12-01 00:17:34|              1|          1.7|        11.3|    null|               null|           null|         null|        null|
|2020-12-01 00:54:55|       1|2020-12-01 00:57:09|              1|          0.5|         7.8|    null|               null|           null|         null|        null|
|2020-12-01 00:11:22|       1|2020-12-01 00:40:36|              1|         21.0|       71.85|    null|               null|           null|         null|        null|
|2020-12-01 00:53:58|       1|2020-12-01 00:54:06|              1|          0.0|         3.8|    null|               null|           null|         null|        null|
|2020-12-01 00:52:58|       1|2020-12-01 00:54:28|              1|          0.7|         7.8|    null|               null|           null|         null|        null|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
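
A similar plain-Python model shows where the null-filled rows come from. Again, this is a sketch and not PySpark code. Every left row is kept, and a row with no match is padded with `None` for each right-side column. The function name `left_join`, its `right_cols` parameter and the sample rows are illustrative assumptions. The `right_cols` parameter lists the right side's non-key columns, which is needed to pad misses to the same width.

```python
def left_join(left, right, key, right_cols):
    """Plain-Python model of a left outer equi-join on dict rows (not the PySpark API).

    Output rows are (key, left values..., right values...). Left rows
    without a match get None for every column in right_cols.
    """
    # Index the right side by key; NULL keys never match in an equi-join
    index = {}
    for rrow in right:
        if rrow[key] is not None:
            index.setdefault(rrow[key], []).append(rrow)
    out = []
    for lrow in left:
        left_rest = tuple(v for k, v in lrow.items() if k != key)
        matches = index.get(lrow[key], [])
        if not matches:
            # Unmatched left row: keep it and pad the right side with None
            out.append((lrow[key],) + left_rest + (None,) * len(right_cols))
        for rrow in matches:
            out.append((lrow[key],) + left_rest + tuple(rrow[c] for c in right_cols))
    return out


taxi1 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-01 00:11:02", "total_amount": 11.3},
]
taxi2 = [{"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92}]

print(left_join(taxi1, taxi2, "pickup_time", ["total_amount"]))
# [('2020-12-01 00:07:13', 33.92, 33.92), ('2020-12-01 00:11:02', 11.3, None)]
```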

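PySpark – RIGHT join two dataframes

A right join keeps every row of df_taxi2. Every pickup_time in df_taxi2 also appears in df_taxi1, so here the result is the same as the inner join.

df_taxi1.join(df_taxi2, 'pickup_time', how='right').show()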

+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|        pickup_time|vendorid|          drop_time|passenger_count|trip_distance|total_amount|vendorid|          drop_time|passenger_count|trip_distance|total_amount|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|2020-12-01 00:07:13|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|
|2020-12-01 00:41:19|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|
|2020-12-01 00:27:47|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|
|2020-12-01 00:08:15|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|
|2020-12-01 00:14:34|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
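
A right join is a left join with the two inputs swapped; only the column order differs. The plain-Python sketch below builds `right_join` that way, on top of a small `left_join` helper. It is not PySpark code. Both function names, the `left_cols`/`right_cols` parameters and the second `taxi2` row are hypothetical.

```python
def left_join(left, right, key, left_cols, right_cols):
    """Plain-Python left equi-join on dict rows: keep every left row, pad misses with None."""
    index = {}
    for rrow in right:
        if rrow[key] is not None:  # NULL keys never match in an equi-join
            index.setdefault(rrow[key], []).append(rrow)
    out = []
    for lrow in left:
        lvals = tuple(lrow[c] for c in left_cols)
        matches = index.get(lrow[key], [])
        if not matches:
            out.append((lrow[key],) + lvals + (None,) * len(right_cols))
        for rrow in matches:
            out.append((lrow[key],) + lvals + tuple(rrow[c] for c in right_cols))
    return out


def right_join(left, right, key, left_cols, right_cols):
    """Right join modeled as a left join with the inputs swapped."""
    n = len(right_cols)
    # Swapped rows come back as (key, right values..., left values...)
    swapped = left_join(right, left, key, right_cols, left_cols)
    # Reorder each row to (key, left values..., right values...)
    return [(row[0],) + row[1 + n:] + row[1:1 + n] for row in swapped]


taxi1 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-01 00:11:02", "total_amount": 11.3},
]
taxi2 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-02 10:00:00", "total_amount": 5.0},  # hypothetical row with no match in taxi1
]

print(right_join(taxi1, taxi2, "pickup_time", ["total_amount"], ["total_amount"]))
# [('2020-12-01 00:07:13', 33.92, 33.92), ('2020-12-02 10:00:00', None, 5.0)]
```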

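PySpark – FULL OUTER join two dataframes

A full outer join keeps the rows of both dataframes and fills in null wherever one side has no match. Spark does not guarantee the row order of a join result, so add orderBy if you need sorted output.

df_taxi1.join(df_taxi2, 'pickup_time', how='outer').show()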

+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|        pickup_time|vendorid|          drop_time|passenger_count|trip_distance|total_amount|vendorid|          drop_time|passenger_count|trip_distance|total_amount|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
|2020-12-01 00:07:13|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|       1|2020-12-01 00:18:12|              1|          7.6|       33.92|
|2020-12-01 00:08:15|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|       1|2020-12-01 00:16:04|              2|          2.7|       15.95|
|2020-12-01 00:11:02|       1|2020-12-01 00:17:34|              1|          1.7|        11.3|    null|               null|           null|         null|        null|
|2020-12-01 00:11:22|       1|2020-12-01 00:40:36|              1|         21.0|       71.85|    null|               null|           null|         null|        null|
|2020-12-01 00:14:34|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|       1|2020-12-01 00:31:04|              1|          7.6|       32.15|
|2020-12-01 00:27:47|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|       1|2020-12-01 00:45:40|              0|          8.4|       40.92|
|2020-12-01 00:41:19|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|       1|2020-12-01 00:49:45|              1|          1.6|       14.75|
|2020-12-01 00:52:58|       1|2020-12-01 00:54:28|              1|          0.7|         7.8|    null|               null|           null|         null|        null|
|2020-12-01 00:53:58|       1|2020-12-01 00:54:06|              1|          0.0|         3.8|    null|               null|           null|         null|        null|
|2020-12-01 00:54:55|       1|2020-12-01 00:57:09|              1|          0.5|         7.8|    null|               null|           null|         null|        null|
+-------------------+--------+-------------------+---------------+-------------+------------+--------+-------------------+---------------+-------------+------------+
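
A full outer join can be modeled as the left join plus the right rows that found no partner. The plain-Python sketch below does exactly that and is not PySpark code. The function name, the `left_cols`/`right_cols` parameters and the second `taxi2` row are assumptions. The model takes the key from whichever side has it, which mirrors the single pickup_time column in the output above.

```python
def full_outer_join(left, right, key, left_cols, right_cols):
    """Plain-Python model of a full outer equi-join on dict rows (not the PySpark API).

    Output rows are (key, left values..., right values...); the side
    without a match is padded with None.
    """
    # Index the right side by key; NULL keys never match in an equi-join
    index = {}
    for rrow in right:
        if rrow[key] is not None:
            index.setdefault(rrow[key], []).append(rrow)
    left_keys = {lrow[key] for lrow in left if lrow[key] is not None}

    out = []
    # Pass 1: every left row, matched or padded (this part is a left join)
    for lrow in left:
        lvals = tuple(lrow[c] for c in left_cols)
        matches = index.get(lrow[key], [])
        if not matches:
            out.append((lrow[key],) + lvals + (None,) * len(right_cols))
        for rrow in matches:
            out.append((lrow[key],) + lvals + tuple(rrow[c] for c in right_cols))
    # Pass 2: right rows whose key never appeared on the left (including NULL keys)
    for rrow in right:
        if rrow[key] is None or rrow[key] not in left_keys:
            out.append((rrow[key],) + (None,) * len(left_cols) + tuple(rrow[c] for c in right_cols))
    return out


taxi1 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-01 00:11:02", "total_amount": 11.3},
]
taxi2 = [
    {"pickup_time": "2020-12-01 00:07:13", "total_amount": 33.92},
    {"pickup_time": "2020-12-02 10:00:00", "total_amount": 5.0},  # hypothetical row with no match in taxi1
]

print(full_outer_join(taxi1, taxi2, "pickup_time", ["total_amount"], ["total_amount"]))
# [('2020-12-01 00:07:13', 33.92, 33.92), ('2020-12-01 00:11:02', 11.3, None), ('2020-12-02 10:00:00', None, 5.0)]
```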

PySpark – Check for NULL values in a dataframe column

df_taxi.filter("VendorID is NULL").select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount").show(5)

+--------+-------------------+-------------------+---------------+-------------+------------+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|
+--------+-------------------+-------------------+---------------+-------------+------------+
|    null|2020-12-01 00:04:00|2020-12-01 00:07:00|           null|         0.51|       59.84|
|    null|2020-12-01 00:24:00|2020-12-01 00:44:00|           null|         5.51|       30.08|
|    null|2020-12-01 00:46:00|2020-12-01 01:02:00|           null|         8.07|        43.0|
|    null|2020-12-01 00:06:00|2020-12-01 00:27:00|           null|         3.83|        25.5|
|    null|2020-12-01 00:27:13|2020-12-01 00:41:18|           null|         8.53|       26.22|
+--------+-------------------+-------------------+---------------+-------------+------------+
only showing top 5 rows
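
The filter above checks one column at a time. A common PySpark idiom counts the NULLs in every column in a single pass: `df_taxi.select([count(when(col(c).isNull(), c)).alias(c) for c in df_taxi.columns]).show()`, with `count`, `when` and `col` imported from `pyspark.sql.functions`. The plain-Python sketch below models the same per-column count on dict rows. The function name `null_counts` is illustrative. The sample rows are taken from the table above, plus one VendorID 1 row from earlier.

```python
def null_counts(rows, columns):
    """Count None values per column, a plain-Python stand-in for counting NULLs in every column.

    A column missing from a row is treated as None.
    """
    return {c: sum(1 for row in rows if row.get(c) is None) for c in columns}


rows = [
    {"vendorid": None, "passenger_count": None, "trip_distance": 0.51},
    {"vendorid": None, "passenger_count": None, "trip_distance": 5.51},
    {"vendorid": 1, "passenger_count": 1, "trip_distance": 7.6},
]
print(null_counts(rows, ["vendorid", "passenger_count", "trip_distance"]))
# {'vendorid': 2, 'passenger_count': 2, 'trip_distance': 0}
```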

PySpark – DROP rows where any of the given columns is NULL

We use the dropna function with how='any' to drop every row that has a NULL in at least one of the columns listed in subset.

df_taxi.count()
1461897

df_taxi.dropna(how='any',subset=["VendorID","passenger_count"]).count()
1362441

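PySpark – DROP rows where all of the given columns are NULL

With how='all', a row is dropped only when every column listed in subset is NULL. The count below matches the how='any' result above because, in this file, VendorID and passenger_count are NULL in exactly the same rows.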

df_taxi.count()
1461897

df_taxi.dropna(how='all',subset=["VendorID","passenger_count"]).count()
1362441
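
The difference between how='any' and how='all' is easiest to see on a row where only one of the subset columns is NULL. The plain-Python sketch below models the two modes and is not the PySpark implementation. The function name and the sample rows are hypothetical, including the partly NULL second row. The sketch also leaves out PySpark's `thresh` parameter.

```python
def drop_null_rows(rows, how="any", subset=None):
    """Plain-Python model of DataFrame.dropna(how=..., subset=...) on dict rows.

    how="any": drop a row if any subset column is None.
    how="all": drop a row only if every subset column is None.
    subset=None checks every column of the row.
    """
    if how not in ("any", "all"):
        raise ValueError("how must be 'any' or 'all'")
    test = any if how == "any" else all
    kept = []
    for row in rows:
        cols = subset if subset is not None else list(row)
        if not test(row[c] is None for c in cols):
            kept.append(row)
    return kept


rows = [
    {"VendorID": None, "passenger_count": None},  # both NULL: dropped by either mode
    {"VendorID": 1, "passenger_count": None},     # one NULL: dropped only by how="any"
    {"VendorID": 1, "passenger_count": 1},        # no NULL: always kept
]
subset = ["VendorID", "passenger_count"]
print(len(drop_null_rows(rows, "any", subset)), len(drop_null_rows(rows, "all", subset)))
# 1 2
```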

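PySpark – Replace NULL values in given columns with a default value

We use the fillna function to replace NULL values in the columns listed in subset with a default value. Like other dataframe transformations, fillna returns a new dataframe rather than modifying df_taxi in place, so assign the result if you want to keep it.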

df_taxi.filter("VendorID is null").count()
99456

df_taxi.fillna(0,subset=["VendorID"]).filter("VendorID is null").count()
0
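
The plain-Python sketch below shows what `fillna(value, subset=...)` does to each row. It is not the PySpark implementation, and the function name and rows are illustrative. The sketch does not model one PySpark detail: columns in subset whose type does not match value are ignored. For example, `fillna(0, ...)` leaves a string column such as store_and_fwd_flag untouched.

```python
def fill_nulls(rows, value, subset=None):
    """Plain-Python model of DataFrame.fillna(value, subset=...) on dict rows.

    Returns new rows with None replaced by `value` in the chosen columns;
    the input rows are not modified. subset=None fills every column.
    """
    filled = []
    for row in rows:
        cols = set(subset) if subset is not None else set(row)
        filled.append({k: value if k in cols and v is None else v for k, v in row.items()})
    return filled


rows = [
    {"VendorID": None, "total_amount": 59.84},
    {"VendorID": 2, "total_amount": None},
]
print(fill_nulls(rows, 0, subset=["VendorID"]))
# [{'VendorID': 0, 'total_amount': 59.84}, {'VendorID': 2, 'total_amount': None}]
```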

PySpark – Window function row number

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Number rows within each VendorID by descending total_amount and keep the top 3 per vendor
w = Window.partitionBy("VendorID").orderBy(col("total_amount").desc())
df_taxi.withColumn("rowid", row_number().over(w)).filter("rowid < 4").select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount","rowid").show(10)

+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|rowid|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|    null|2020-12-28 08:16:00|2020-12-28 09:33:00|           null|        60.93|      158.41|    1|
|    null|2020-12-18 09:12:00|2020-12-18 10:32:00|           null|        40.09|      150.75|    2|
|    null|2020-12-22 14:30:00|2020-12-22 15:25:00|           null|        34.28|      128.89|    3|
|       1|2020-12-26 13:39:29|2020-12-26 13:46:24|              2|          0.0|    398467.7|    1|
|       1|2020-12-08 19:01:07|2020-12-09 00:30:36|              0|        265.1|       921.2|    2|
|       1|2020-12-18 19:42:24|2020-12-18 19:42:57|              1|          0.0|      656.15|    3|
|       2|2020-12-20 23:01:34|2020-12-20 23:14:21|              1|         4.51|     8361.36|    1|
|       2|2020-12-11 13:09:01|2020-12-11 13:09:10|              1|          0.0|       617.8|    2|
|       2|2020-12-11 13:10:01|2020-12-11 13:10:10|              1|          0.0|       612.8|    3|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
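
The window above numbers rows from 1 upward within each VendorID partition, in descending total_amount order, and then keeps numbers below 4. NULL VendorID forms its own partition, as the null rows in the output show. The plain-Python sketch below models that top-N-per-group logic, not how Spark executes it. The names are illustrative, the values come from the table above, and the sketch assumes the ordering column holds no None. When values tie, Spark's row_number orders the tied rows arbitrarily, while this sketch keeps input order.

```python
def top_n_per_group(rows, part_col, order_col, n):
    """Plain-Python model of row_number() over partitionBy(part_col).orderBy(desc(order_col)),
    keeping rows numbered 1..n. Assumes order_col holds no None values."""
    partitions = {}
    for row in rows:
        # None is a valid partition key here, just as NULL VendorID forms its own partition
        partitions.setdefault(row[part_col], []).append(row)
    out = []
    for members in partitions.values():
        ranked = sorted(members, key=lambda r: r[order_col], reverse=True)
        # row_number gives consecutive numbers even when values tie
        for rowid, row in enumerate(ranked[:n], start=1):
            out.append({**row, "rowid": rowid})
    return out


rows = [
    {"vendorid": 1, "total_amount": 921.2},
    {"vendorid": 1, "total_amount": 398467.7},
    {"vendorid": 1, "total_amount": 656.15},
    {"vendorid": 1, "total_amount": 33.92},
    {"vendorid": None, "total_amount": 150.75},
    {"vendorid": None, "total_amount": 158.41},
]
for row in top_n_per_group(rows, "vendorid", "total_amount", 3):
    print(row["vendorid"], row["total_amount"], row["rowid"])
```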

PySpark – Window function rank

from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank

# Rank rows within each VendorID by descending total_amount (ties share a rank) and keep rank 2
w = Window.partitionBy("VendorID").orderBy(col("total_amount").desc())
df_taxi.withColumn("rowid", rank().over(w)).filter("rowid=2").select("vendorid","pickup_time","drop_time","passenger_count","trip_distance","total_amount","rowid").show(10)

+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|vendorid|        pickup_time|          drop_time|passenger_count|trip_distance|total_amount|rowid|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
|    null|2020-12-18 09:12:00|2020-12-18 10:32:00|           null|        40.09|      150.75|    2|
|       1|2020-12-08 19:01:07|2020-12-09 00:30:36|              0|        265.1|       921.2|    2|
|       2|2020-12-11 13:09:01|2020-12-11 13:09:10|              1|          0.0|       617.8|    2|
+--------+-------------------+-------------------+---------------+-------------+------------+-----+
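
rank differs from row_number only when values tie. Tied rows share a rank, and the next rank skips ahead. As a result, filtering on rank = 2 returns nothing for a partition whose two largest amounts are equal. PySpark also provides dense_rank, which also gives tied rows the same rank but leaves no gaps. The plain-Python sketch below lists all three numberings for one partition sorted in descending order. The function name and the tied sample values are hypothetical.

```python
def window_numbers(values):
    """For values sorted descending, return (value, row_number, rank, dense_rank) tuples,
    a plain-Python model of the three window functions over one partition."""
    out = []
    rank = dense_rank = 0
    prev = None
    for row_number, value in enumerate(sorted(values, reverse=True), start=1):
        if row_number == 1 or value != prev:
            rank = row_number  # rank jumps to the row's position, leaving gaps after ties
            dense_rank += 1    # dense_rank only moves up by one
            prev = value
        out.append((value, row_number, rank, dense_rank))
    return out


# Hypothetical partition where the two largest amounts tie
for value, row_number, rank, dense_rank in window_numbers([612.8, 617.8, 617.8, 600.0]):
    print(value, row_number, rank, dense_rank)
```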

PySpark – Pivot to convert rows into columns

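We can convert rows into columns using the pivot function in PySpark. In this example, each value of the “passenger_count” column becomes a separate column, and each cell holds the sum of “total_amount” for that VendorID and passenger count. Passing the list of values to pivot has two effects. It limits the output to passenger counts 1 to 6, so rows with 0 or NULL passengers are left out of the sums. It also saves Spark from first computing the distinct values itself. The sums are cast to bigint only to keep the table compact, which truncates the decimal part.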

df_taxi.select("VendorID","passenger_count","total_amount").show(5)

+--------+---------------+------------+
|VendorID|passenger_count|total_amount|
+--------+---------------+------------+
|       1|              1|       33.92|
|       1|              1|       14.75|
|       2|              1|       63.92|
|       2|              1|        18.8|
|       2|              1|         9.8|
+--------+---------------+------------+
only showing top 5 rows

from pyspark.sql.functions import col
df_taxi.filter("VendorID in (1,2)").groupBy("VendorID").pivot("passenger_count",values=['1','2','3','4','5','6']).sum("total_amount").select("VendorID",col("1").cast("bigint"),col("2").cast("bigint"),col("3").cast("bigint"),col("4").cast("bigint"),col("5").cast("bigint"),col("6").cast("bigint")).show()

+--------+--------+-------+------+------+------+------+
|VendorID|       1|      2|     3|     4|     5|     6|
+--------+--------+-------+------+------+------+------+
|       1| 5482565|1265902|160975| 49064|  6209|  4495|
|       2|11213768|2122529|626923|263774|546248|431529|
+--------+--------+-------+------+------+------+------+
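
The plain-Python sketch below models groupBy(...).pivot(..., values).sum(...). It is not the PySpark implementation, the function name and rows are illustrative, and the last row is hypothetical. The model produces one output row per group and one column per listed pivot value, and skips pivot values outside the list. A cell with no matching rows stays None, where PySpark shows null. The sketch uses integer pivot values to match the integer passenger_count column.

```python
def pivot_sum(rows, group_col, pivot_col, values, agg_col):
    """Plain-Python model of groupBy(group_col).pivot(pivot_col, values).sum(agg_col).

    Returns {group: {pivot_value: sum or None}}. Rows whose pivot value is not
    listed in `values` still create their group but add to no cell; a cell with
    no non-None amounts stays None.
    """
    table = {}
    for row in rows:
        cells = table.setdefault(row[group_col], {v: None for v in values})
        pv, amount = row[pivot_col], row[agg_col]
        if pv in cells and amount is not None:
            cells[pv] = amount if cells[pv] is None else cells[pv] + amount
    return table


rows = [
    {"VendorID": 1, "passenger_count": 1, "total_amount": 33.92},
    {"VendorID": 1, "passenger_count": 1, "total_amount": 14.75},
    {"VendorID": 2, "passenger_count": 1, "total_amount": 63.92},
    {"VendorID": 2, "passenger_count": 1, "total_amount": 18.8},
    {"VendorID": 2, "passenger_count": 1, "total_amount": 9.8},
    {"VendorID": 2, "passenger_count": 0, "total_amount": 5.0},  # hypothetical: 0 is not listed, so it is skipped
]
table = pivot_sum(rows, "VendorID", "passenger_count", [1, 2], "total_amount")
for vendor, cells in table.items():
    # Round for display only; the sums are ordinary floats
    print(vendor, {v: None if s is None else round(s, 2) for v, s in cells.items()})
```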

I will keep updating this list with more PySpark code snippets.