matplotlib Legends: Single Legend Shared Across Multiple Subplots


Example

Sometimes you will have a grid of subplots and want a single legend that describes the lines in all of the subplots, as in the following image.

Image of Single Legend Across Multiple Subplots

To do this, create one global legend for the figure instead of a legend at the axes level (calling ax.legend() on each subplot would create a separate legend for every subplot). A figure-level legend is created by calling fig.legend(), as shown in the following code.

import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))
fig.suptitle('Example of a Single Legend Shared Across Multiple Subplots')

# The data
x =  [1, 2, 3]
y1 = [1, 2, 3]
y2 = [3, 1, 3]
y3 = [1, 3, 1]
y4 = [2, 2, 3]

# Labels to use in the legend for each line
line_labels = ["Line A", "Line B", "Line C", "Line D"]

# Create the sub-plots, assigning a different color for each line.
# Also store the line objects created
l1 = ax1.plot(x, y1, color="red")[0]
l2 = ax2.plot(x, y2, color="green")[0]
l3 = ax3.plot(x, y3, color="blue")[0]
l4 = ax3.plot(x, y4, color="orange")[0] # A second line in the third subplot

# Create the legend
fig.legend([l1, l2, l3, l4],     # The line objects (legend handles)
           line_labels,          # The labels for each line, in the same order
           loc="center right",   # Position of legend
           borderaxespad=0.1,    # Small padding between the legend box and the figure edge
           title="Legend Title"  # Title for the legend
           )

# `right` is the position of the subplots' right edge as a fraction of the
# figure width; a smaller value leaves more room for the legend on the right
plt.subplots_adjust(right=0.85)

plt.show()
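If you are on matplotlib 3.7 or newer (an assumption: older releases reject these location strings), a constrained layout can reserve space for a figure legend automatically, so the right=0.85 value does not have to be tuned by hand. The sketch below reuses the data from the example and asks for an "outside" legend location.

```python
import matplotlib.pyplot as plt

# Assumes matplotlib >= 3.7, where fig.legend() accepts "outside ..." locations
# when the figure uses the constrained layout engine.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4), layout="constrained")
fig.suptitle('Example of a Single Legend Shared Across Multiple Subplots')

x = [1, 2, 3]
l1 = ax1.plot(x, [1, 2, 3], color="red")[0]
l2 = ax2.plot(x, [3, 1, 3], color="green")[0]
l3 = ax3.plot(x, [1, 3, 1], color="blue")[0]
l4 = ax3.plot(x, [2, 2, 3], color="orange")[0]

# "outside right upper" puts the legend to the right of all subplots; the layout
# engine shrinks the axes to make room, so no subplots_adjust() call is needed
fig.legend([l1, l2, l3, l4], ["Line A", "Line B", "Line C", "Line D"],
           loc="outside right upper", title="Legend Title")

# plt.show()  # uncomment to display the figure interactively
```

The trade-off is flexibility: subplots_adjust() works with any matplotlib version and layout, while the "outside" locations only take effect under the constrained layout engine.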

One detail of the example above is worth noting:

l1 = ax1.plot(x, y1, color="red")[0]

plot() returns a list of Line2D objects, one for each line it draws. Here each call draws a single line, so the list holds one Line2D object, which is extracted with [0] and stored in l1.
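To see this for yourself, you can inspect the return value directly. The sketch below also shows tuple unpacking (line, = ...), a common alternative to [0] indexing that raises an error instead of silently picking the first line if plot() ever returns more than one.

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# plot() always returns a list, even when it draws only one line
lines = ax.plot([1, 2, 3], [1, 2, 3], color="red")
print(type(lines).__name__, len(lines))  # list 1
print(isinstance(lines[0], Line2D))      # True

# Equivalent to [0] indexing: the trailing comma unpacks a one-element list
l1, = ax.plot([1, 2, 3], [3, 1, 3], color="green")
print(isinstance(l1, Line2D))            # True
```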

The list of Line2D objects to include in the legend is passed as the first argument to fig.legend(). The second argument is a list of strings to use as the legend labels, one per line and in the same order. It is needed here because the lines were plotted without a label= keyword, so they carry no labels of their own.
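As an alternative to passing handles and labels explicitly, you can give each line a label= when it is plotted and call fig.legend() with no handles or labels; matplotlib then collects the labelled artists from every axes in the figure, in axes order. A minimal sketch of that variant, using the same data:

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

x = [1, 2, 3]
# Attach each legend label to its line when the line is created
ax1.plot(x, [1, 2, 3], color="red", label="Line A")
ax2.plot(x, [3, 1, 3], color="green", label="Line B")
ax3.plot(x, [1, 3, 1], color="blue", label="Line C")
ax3.plot(x, [2, 2, 3], color="orange", label="Line D")

# With no handles or labels given, fig.legend() gathers every labelled
# artist from all axes of the figure
fig.legend(loc="center right", title="Legend Title")
plt.subplots_adjust(right=0.85)
```

Explicit handles are still useful when you want only some of the lines in the legend, or a different order than the plotting order.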

The remaining arguments passed to fig.legend() (loc, borderaxespad and title) are optional; they only control where the legend is placed and how it looks.
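As one example of that fine-tuning, the same handles and labels can be laid out in a single row below the subplots by changing loc and adding ncol (the number of legend columns), then reserving space at the bottom of the figure instead of on the right. The numeric values below are illustrative, not tuned.

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

x = [1, 2, 3]
l1 = ax1.plot(x, [1, 2, 3], color="red")[0]
l2 = ax2.plot(x, [3, 1, 3], color="green")[0]
l3 = ax3.plot(x, [1, 3, 1], color="blue")[0]
l4 = ax3.plot(x, [2, 2, 3], color="orange")[0]

fig.legend([l1, l2, l3, l4],                         # handles
           ["Line A", "Line B", "Line C", "Line D"],  # labels
           loc="lower center",  # anchor the legend to the bottom center of the figure
           ncol=4,              # one column per entry, so all four sit in one row
           frameon=False)       # do not draw a border around the legend

# Raise the bottom edge of the subplot area to leave room for the legend row
plt.subplots_adjust(bottom=0.2)
```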