Scatter Charts in Matplotlib
We start by importing matplotlib, displaying all visuals inline, and applying the ggplot style sheet.
In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
Scatter plots take two equal-length arrays as input: one for the x coordinates and one for the y coordinates.
In [2]:
import numpy as np
x = np.random.rand(100) * 10
y = x * np.random.rand(100) * 5
plt.scatter(x, y)
plt.show()
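To illustrate the equal-length requirement, here is a minimal sketch, assuming a recent matplotlib and NumPy. The names y_short, points and mismatch_error are made up for this example. It draws one valid scatter, then tries a y array that is one element short, which matplotlib rejects with a ValueError.

```python
import matplotlib.pyplot as plt
import numpy as np

# Seeded generator so the example is reproducible
rng = np.random.default_rng(0)
x = rng.random(100) * 10
y = x * rng.random(100) * 5

fig, ax = plt.subplots()

# Equal lengths: one point is drawn per (x, y) pair
points = ax.scatter(x, y)
n_points = len(points.get_offsets())

# Unequal lengths: y_short (hypothetical name) has 99 values against 100 x values
y_short = y[:-1]
try:
    ax.scatter(x, y_short)
    mismatch_error = None
except ValueError as err:
    mismatch_error = str(err)

print(n_points)
print(mismatch_error)
```

In a notebook with inline plotting enabled, the figure is rendered at the end of the cell, so no explicit plt.show() is needed here.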
Plotting multiple series
If we have multiple series to plot on the same axes, we can call plt.scatter() once for each series
before calling plt.show(). Giving each call its own color and label lets plt.legend() tell the series apart.
In [3]:
from sklearn import datasets
import pandas as pd
# Load some data
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris['data'], columns=iris['feature_names'])
iris_df['species'] = iris['target']
colours = ['red', 'orange', 'blue']
species = ['I. setosa', 'I. versicolor', 'I. virginica']
for i in range(0, 3):
    species_df = iris_df[iris_df['species'] == i]
    plt.scatter(
        species_df['sepal length (cm)'],
        species_df['petal length (cm)'],
        color=colours[i],
        alpha=0.5,
        label=species[i]
    )
plt.xlabel('sepal length (cm)')
plt.ylabel('petal length (cm)')
plt.title('Iris dataset: petal length vs sepal length')
plt.legend(loc='lower right')
plt.show()
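The same one-call-per-series pattern can be sketched without scikit-learn or pandas. The example below is only an illustration on synthetic data: the series dictionary, its names and its values are all made up, and it uses matplotlib's Axes interface (ax.scatter) rather than the plt functions used above.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical data: series name -> (x values, y values)
series = {
    'series A': (rng.normal(2, 0.5, 50), rng.normal(2, 0.5, 50)),
    'series B': (rng.normal(4, 0.5, 50), rng.normal(5, 0.5, 50)),
    'series C': (rng.normal(6, 0.5, 50), rng.normal(3, 0.5, 50)),
}

fig, ax = plt.subplots()

# One scatter call per series; matplotlib cycles through the style's colors
for name, (xs, ys) in series.items():
    ax.scatter(xs, ys, alpha=0.5, label=name)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Three synthetic series on one set of axes')
legend = ax.legend(loc='lower right')

# The legend picks up one entry per labelled scatter call
legend_labels = [text.get_text() for text in legend.get_texts()]
print(legend_labels)
```

Leaving out the color argument lets the active style sheet choose a distinct color for each call, which is convenient when the number of series is not fixed in advance.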