
Pyspark Intro


0. Overview

Apache Spark is a very fast, unified analytics engine for big data and machine learning. It builds on the MapReduce model: work is split into tasks that are distributed across a cluster. This allows Spark to process petabytes of data on clusters with hundreds or thousands of workers.

Apache Spark is a 100% open-source project. At the moment, Databricks is the main company behind its development.

Apache Spark consists of four major components:

  • DataFrames + Spark SQL
  • Streaming
  • MLlib (Machine Learning)
  • GraphX (Graph Computation)

This post is an introduction to Spark DataFrames.

Installing Spark is not an easy task and is out of the scope of this post.

1. Create dataframe

1.1. Read file

One of the most common ways to create a dataframe is by reading a file. For example, to read a CSV file:

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("pyspark_intro").getOrCreate()

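# header: the first line holds the column names; inferSchema: detect numeric types
sdf = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load("titanic_train.csv")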

It is even possible to use wildcards to read multiple files at once.

sdf = spark.read.format("csv").option("header", "true").option("inferSchema", "true").load("titanic_*.csv")  # Read both train and test

1.2. RDD

It is possible to read raw files with Spark as an RDD and process them. As an example, we read a CSV and transform it into a dataframe.

from pyspark.sql import SparkSession, types

spark = SparkSession.builder.appName("pyspark_intro").getOrCreate()
sc = spark.sparkContext

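data = sc.textFile("datasets/iris.csv")  # one string per line
parts = data.map(lambda x: x.split(";"))  # split each line into its fields

# Assumes 5 fields per line (sepal length/width, petal length/width, class); values stay strings
iris_data = parts.map(lambda x: types.Row(SL=x[0], SW=x[1], PL=x[2], PW=x[3], classification=x[4]))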
sdf = spark.createDataFrame(iris_data)

This is only an example; it is usually better to read the file directly as a dataframe.

2. Inspect

2.1. Show data

To show the data, simply call sdf.show(N), where N is the number of rows to print (default: 20).

2.2. General info

There are some functions and attributes to get general info about the dataframe:

  • sdf.count(): number of rows
  • sdf.schema: details about the structure of the dataframe (sdf.printSchema() prints it as a tree)
  • sdf.columns: columns of the dataframe as a Python list
  • sdf.dtypes: columns of the dataframe with their data types (dtypes)
  • sdf.describe(): basic statistics (count, mean, stddev, min, max), returned as a new dataframe; call .show() on it to display them

3. Slicing

3.1. First rows

To retrieve the first N rows, you can use either sdf.head(N) or sdf.take(N). Both return a Python list of Row objects.

3.2. Filter columns

To get one or more columns, use sdf.select. For example:

sdf.select("Sex")
sdf.select("Sex", "Age")

3.3. Filter rows

Filters in PySpark use a syntax similar to pandas. There are a few synonyms for filtering; these three lines do exactly the same:

sdf[sdf["Age"] > 24]
sdf.filter(sdf["Age"] > 24)
sdf.where(sdf["Age"] > 24)

Columns also have some handy methods for filtering. Some examples are:

# Rows where age is between 20 and 30 (both bounds inclusive)
sdf[sdf["Age"].between(20, 30)]

# Rows where Pclass is one of the values of the list [1, 2]
sdf[sdf["Pclass"].isin([1, 2])]

# Rows that include 'Miss.' in the Name
sdf[sdf["Name"].like("%Miss.%")]

# Rows with Name starting or ending with a string
sdf[sdf["Name"].startswith("Hei")]
sdf[sdf["Name"].endswith(".")]

3.4. Unique values

sdf.select("Pclass").distinct()

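It is also possible to subtract the rows of another dataframe. exceptAll works like SQL EXCEPT ALL, so duplicates are preserved (subtract is the deduplicating version):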

sdf.select("Pclass").exceptAll(sdf.select("Survived"))

4. Modify the dataframe

Spark dataframes are immutable: every transformation returns a new dataframe. To keep a change, reassign the result, for example sdf = do_something(sdf).

Calling do_something(sdf) on its own is valid, but it won't update sdf.

4.1. Add columns

As an example, to add a column called new_col with the same values as Age, do:

sdf = sdf.withColumn("new_col", sdf["Age"])

4.2. Change dtypes

Dtypes are changed using the cast method of a column.

import pyspark.sql.types as T
sdf = sdf.withColumn("Age", sdf["Age"].cast(T.IntegerType()))

The target can be a type from pyspark.sql.types or a type name string such as "int".

4.3. Modify certain values

To modify the data itself, use sdf.withColumn. You can either create a new column or overwrite an existing one by reusing its name. For example:

from pyspark.sql.functions import when

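# New column with the first letter of Sex (substr positions start at 1)
sdf.withColumn("sex_code", sdf["Sex"].substr(1, 1))

# Conditional update, similar to sdf.loc[cond, "Age"] = ... in pandas;
# missing ages also fall into otherwise(-1)
sdf.withColumn("Age", when(sdf["Age"] > 0, sdf["Age"]).otherwise(-1))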

You can fill missing values of multiple columns using:

sdf.fillna({"Age": 0, "Cabin": "no cabin"})

And replace values in specific columns (only whole values are matched, so "female" is left untouched):

sdf.replace("male", "m", subset=["Sex"])

4.4. Sort

sdf.sort("Age", ascending=False)

4.5. Delete columns

sdf.drop("new_col")

4.6. Drop duplicates

sdf.drop_duplicates(["Pclass"])

4.7. User Defined Functions (UDF)

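It is possible to define new functions and apply them to a column. You need to specify the output type as a type from pyspark.sql.types, such as IntegerType() or StringType() (if omitted, it defaults to StringType()). The function is applied element-wise, one value at a time.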

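from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# The decorator argument is the return type; Age is assumed to be an integer column
@udf(IntegerType())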
def sum_1(x):
    if x is not None:
        return x + 1
    return None

sdf.withColumn("Next_Age", sum_1(sdf["Age"]))

5. Group data

To group data, use the sdf.groupby method. You can call simple functions like count or use agg to apply several aggregations at once.

sdf.groupby("Sex").count()

sdf.groupby("Sex").agg({"Sex": "count", "Age": "max"})

6. SQL

One really useful feature of Spark is that you can write queries in SQL and Spark will translate and run them.

Before using SQL, you need to register the dataframe as a temporary view:

sdf.createOrReplaceTempView("titanic")

Then you can run queries using the spark object:

spark.sql("SELECT * FROM titanic WHERE sex = 'female' LIMIT 5")

The output of spark.sql is another dataframe, so you can keep it with sdf = spark.sql(query).

7. Write

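Spark can write files in many different formats. Parquet, a compressed columnar format that also stores the schema, is a good default. To save a dataframe as Parquet (Spark creates a folder with that name containing one file per partition), use: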

sdf.write.format("parquet").mode("overwrite").save("titanic_train.parquet")

8. Interacting with python

Sometimes you need to retrieve data from Spark into Python.

For example, you can always convert a dataframe to pandas with sdf.toPandas().

Converting to pandas loads all the data into the memory of the driver, so be careful not to convert a dataframe that is too big.

To get a column as a python list use:

sdf.select("Name").rdd.flatMap(lambda x: x).collect()

Finally, you can use collect to retrieve a single value:

sdf.agg({"Age": "max"}).collect()[0][0]