Scatter Charts in Matplotlib

  • Post category: Python

We start by importing matplotlib, displaying all figures inline in the notebook, and applying the ggplot style sheet.

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.style.use('ggplot')
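`plt.style.use()` changes the style for every figure drawn afterwards in the session. If you only want the ggplot look for a single figure, `plt.style.context()` applies a style temporarily and restores the previous settings when the `with` block ends. A minimal sketch, assuming a fresh Python process; the `Agg` backend is chosen here only so it runs as a plain script outside a notebook:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs as a script
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

# 'ggplot' ships with matplotlib and appears in the list of available styles
ggplot_available = 'ggplot' in plt.style.available

# Apply ggplot only inside the with-block; settings revert afterwards
with plt.style.context('ggplot'):
    fig, ax = plt.subplots()
    ax.scatter([1, 2, 3], [3, 1, 2])
    ggplot_face = to_hex(ax.get_facecolor())  # grey ggplot axes background

# Outside the block the default (non-ggplot) axes background applies again
default_face = to_hex(plt.rcParams['axes.facecolor'])
plt.close(fig)

print(ggplot_available, ggplot_face, default_face)
```

The context manager is handy when one notebook mixes several styles, since it avoids resetting `rcParams` by hand.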

plt.scatter() takes two equal-length arrays as input: the x coordinates and the y coordinates of the points.

In [2]:
import numpy as np

# 100 random x values in [0, 10), with y values scattered between 0 and 5x
x = np.random.rand(100) * 10
y = x * np.random.rand(100) * 5
plt.scatter(x, y)
plt.show()
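To see what "equal length" means in practice, the sketch below uses the object-oriented `ax.scatter()`, which returns a `PathCollection` holding one marker per (x, y) pair, and then tries arrays of different lengths. The seeded generator and the `Agg` backend are assumptions made so the example is reproducible and runs outside a notebook:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend for script use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seeded so the points are reproducible
x = rng.random(100) * 10
y = x * rng.random(100) * 5

fig, ax = plt.subplots()
points = ax.scatter(x, y)       # PathCollection: one marker per (x, y) pair
offsets = points.get_offsets()  # array of plotted coordinates, one row per point
print(offsets.shape)

# Arrays of different lengths are rejected rather than silently truncated
try:
    ax.scatter(x, y[:50])
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)
print(mismatch_error)

plt.close(fig)
```

Because matplotlib raises a `ValueError` here, a length mismatch usually points to a filtering step that was applied to only one of the two arrays.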

Plotting multiple series

To plot multiple series on the same axes, call plt.scatter() once per series before calling plt.show(). Each call adds a new set of points, and the label argument gives each series its own entry in the legend.

In [3]:
from sklearn import datasets
import pandas as pd

# Load the iris dataset; 'species' holds an integer code (0, 1 or 2) per flower
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris['data'], columns=iris['feature_names'])
iris_df['species'] = iris['target']

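# One colour and display name per species code
colours = ['red', 'orange', 'blue']
species = ['I. setosa', 'I. versicolor', 'I. virginica']

# Call plt.scatter() once per species so each gets its own colour and legend entry
for i in range(3):
    species_df = iris_df[iris_df['species'] == i]
    plt.scatter(
        species_df['sepal length (cm)'],
        species_df['petal length (cm)'],
        color=colours[i],
        alpha=0.5,  # semi-transparent markers keep overlapping points visible
        label=species[i],
    )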

plt.xlabel('sepal length (cm)')
plt.ylabel('petal length (cm)')
plt.title('Iris dataset: petal length vs sepal length')
plt.legend(loc='lower right')

plt.show()
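The same one-call-per-series pattern works with the object-oriented API, where each `ax.scatter()` call adds one collection to the axes and one entry to the legend. The sketch below swaps the iris data for made-up clusters so it needs only NumPy and matplotlib; the group names, colours and cluster centres are hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend for script use
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical groups: (legend label, colour, centre of the cluster)
groups = [
    ('group A', 'red', (5.0, 1.5)),
    ('group B', 'orange', (6.0, 4.3)),
    ('group C', 'blue', (6.6, 5.5)),
]

fig, ax = plt.subplots()
for label, colour, (cx, cy) in groups:
    # 50 normally distributed points around each centre
    xs = cx + rng.normal(scale=0.3, size=50)
    ys = cy + rng.normal(scale=0.3, size=50)
    # one scatter call per series; label feeds the legend
    ax.scatter(xs, ys, color=colour, alpha=0.5, label=label)

ax.set_xlabel('x')
ax.set_ylabel('y')
legend = ax.legend(loc='lower right')

n_series = len(ax.collections)  # one PathCollection per scatter call
legend_labels = [text.get_text() for text in legend.get_texts()]
print(n_series, legend_labels)

plt.close(fig)
```

Iterating over `(label, colour, centre)` tuples with `zip`-style unpacking avoids indexing parallel lists by position, which keeps colours and names from drifting out of sync.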