Matplotlib: Python code to create multiple subplots in a figure

Sometimes you need to put multiple subplots in a single figure. Generating each plot separately and then combining them in another application is tedious and makes the result hard to reproduce. Matplotlib's pyplot module has the function subplots(), which creates a figure and a grid of subplots with a single call, so this task becomes very simple. In this post, I will write a short script that generates 4 subplots in a single figure. You can adapt it to your own data and grid size.
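
As a minimal sketch of what subplots() returns, the snippet below checks the shape of the Axes array it hands back. It assumes a recent Matplotlib and selects the non-interactive Agg backend so it runs without a display. With the default squeeze=True, a grid with a single row or column collapses to a 1-D array. Passing squeeze=False keeps the array 2-D, so axs[i, j] indexing works for any grid size.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display is available)
import matplotlib.pyplot as plt

# 2 x 2 grid: axs is a 2-D NumPy array of Axes objects
fig, axs = plt.subplots(nrows=2, ncols=2)
grid_shape = axs.shape

# A single row with the default squeeze=True collapses to a 1-D array
fig_row, axs_row = plt.subplots(nrows=1, ncols=3)
row_shape = axs_row.shape

# squeeze=False always returns a 2-D array, so axs[i, j] keeps working
fig_row2, axs_row2 = plt.subplots(nrows=1, ncols=3, squeeze=False)
row_shape_2d = axs_row2.shape

print(grid_shape, row_shape, row_shape_2d)  # (2, 2) (3,) (1, 3)
plt.close("all")
```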

import matplotlib.pyplot as plt
import numpy as np

# Data for the plots: one metric series per subplot
y1 = [1.58313312147504, 2.151702043136252, 2.5608005252087307, 2.9424727091929768, 3.3129772922261838,
      3.6862908437076696,
      4.045384082850282, 4.386794877471702, 4.744930744894622, 5.103653599598331, 5.41626619992688, 5.75177855846055,
      6.054462369396616, 6.321792380759004, 6.60552357162056, 6.948063864336287, 7.5237405011977785, 7.902742212122294,
      8.583677229923369, 8.798699623732597]
y2 = [0.6320960208668404, 0.4648097436871075, 0.3905281845255269, 0.33987100666513936, 0.3018482093597703,
      0.2712850771688737,
      0.24721116837655902, 0.22795881770408896, 0.2107546716866719, 0.19594054644898817, 0.18463334110502955,
      0.1738652223227884,
      0.1651737072553206, 0.158187446061609, 0.15139036969026198, 0.1439360244804642, 0.13343059148562944,
      0.1269709741638624,
      0.11723902004666707, 0.11448991641692413]
y3 = [0.1200896247696161, 0.18244589449149176, 0.2242941214229461, 0.24581329861577975, 0.2566183603599734,
      0.2606710430402102,
      0.2644363290239076, 0.2645337587776442, 0.26377230895525217, 0.2625894786364123, 0.2609595354056566,
      0.2582966622862364,
      0.25571240885210966, 0.2535724524765861, 0.2509086427940947, 0.24808285583731363, 0.24272653849104733,
      0.23964893311729024,
      0.2341399651062872, 0.23320303742803714]
y4 = [0.04569573347394232, 0.13024486559120693, 0.20371026778817394, 0.24547634001260885, 0.26618036558414054,
      0.27361281867429255,
      0.2768022627421338, 0.27426966120148877, 0.26944247394263277, 0.26336341959826876, 0.2572481425562363,
      0.24999618843208762,
      0.24337338135684394, 0.23765852380490024, 0.23153483152809465, 0.22450268972327098, 0.2133040155152524,
      0.20627426135953833,
      0.1947400499553148, 0.19154333108493005]
x = np.arange(len(y1))  # x positions 0, 1, ..., len(y1) - 1
y = [y1, y2, y3, y4]
lbl = ['ippv', 'ppv', 'mcc', 'f1']  # subplot titles, in the same order as y

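# Draw a 2 x 2 grid of subplots that share the x-axis
nr = 2  # number of rows
nc = 2  # number of columns
k = 0   # index into y and lbl
fig, axs = plt.subplots(nrows=nr, ncols=nc, sharex=True)
for i in range(nr):
    for j in range(nc):
        ax = axs[i, j]
        ax.plot(x, y[k])
        ax.set_xlabel("Scale_pos_weight")
        ax.set_ylabel('Value')
        ax.set_title(lbl[k])
        # grid(b=...) was renamed grid(visible=...) in Matplotlib 3.5; newer releases reject 'b'
        ax.grid(visible=True, which='major', color='b', linestyle='-')
        ax.grid(visible=True, which='minor', color='r', linestyle='--')
        ax.minorticks_on()
        k += 1
fig.suptitle('ML classification results')
fig.tight_layout()  # reduce overlap between titles and axis labels
plt.savefig('ml_results.png')
plt.show()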

Running the code above saves the figure as ml_results.png and displays it. The figure contains 4 subplots arranged in a 2 × 2 grid.

[Figure: Subplot in Python, 2 × 2 grid of panels ippv, ppv, mcc and f1 under the title 'ML classification results']
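
The nested loops with the manual counter k can also be written as a single loop over axs.flat, which walks the grid row by row. The sketch below is one possible variant, not the post's original code. It uses only the first five values of y1 to y4 from the post, rounded to three decimals to keep it short, and assumes the headless Agg backend. Because sharex=True already hides the x tick labels on the top row, it sets the x-axis label only on the bottom row.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend (assumption: no display is available)
import matplotlib.pyplot as plt
import numpy as np

# First five values of y1..y4 from the post, rounded to three decimals
y = [
    [1.583, 2.152, 2.561, 2.942, 3.313],
    [0.632, 0.465, 0.391, 0.340, 0.302],
    [0.120, 0.182, 0.224, 0.246, 0.257],
    [0.046, 0.130, 0.204, 0.245, 0.266],
]
lbl = ['ippv', 'ppv', 'mcc', 'f1']
x = np.arange(len(y[0]))

fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True)
# axs.flat yields the Axes row by row, so no manual counter is needed
for ax, values, title in zip(axs.flat, y, lbl):
    ax.plot(x, values)
    ax.set_title(title)
    ax.set_ylabel("Value")
# Only the bottom row shows x tick labels when sharex=True, so label only that row
for ax in axs[-1, :]:
    ax.set_xlabel("Scale_pos_weight")
fig.suptitle('ML classification results')

titles = [ax.get_title() for ax in axs.flat]
xlabels = [ax.get_xlabel() for ax in axs.flat]
plt.close(fig)
```

Using axs.flat also means the same loop works unchanged for any nr × nc grid, as long as y and lbl hold nr * nc entries.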
