How to create a Swarmplot with connected dots, that coincides with Boxplot using hues

Asked by: Zaida · Asked: 4/22/2023 · Last edited by: Trenton McKinney · Updated: 4/23/2023 · Views: 247

Question:

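Because of the nature of my data, I have two age groups, each tested in two sessions. I need a way to visualize both how the whole sample behaves (boxplot) and how each individual changes between sessions (swarmplot/lineplot).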

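Without hue or groups, this works by calling the three functions one after another, or by skipping the lineplot, as in Swarmplot with connected dots. Because I use hue to separate the groups, I have not been able to connect each subject's data points.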

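So far I have managed to draw the lines. However, they do not line up with the boxplots. Instead, they are aligned with the ticks for the "Pre" and "Post" conditions: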

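The figure below shows four boxplots (pre_young, pre_old, post_young, post_old). The data points are aligned with their boxplots, but the lines are aligned with the "Pre" and "Post" ticks, not with the actual data points or the middle of the boxplots.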


I produced it with this code:

import matplotlib.pyplot as plt
import seaborn as sns

# test_data: the long-format DataFrame built from the CSV sample below.
# boxprops was not shown in the question; this value is only a placeholder.
boxprops = dict(alpha=0.5)

fig, ax = plt.subplots(figsize=(7, 5))
sns.boxplot(data=test_data,
            x="Session",
            y="Pre_Post",
            hue="Age",
            palette="pastel",
            boxprops=boxprops,
            ax=ax)

sns.swarmplot(data=test_data,
              x="Session",
              y="Pre_Post",
              hue="Age",
              dodge=True,
              palette="dark",
              ax=ax)

sns.lineplot(data=test_data,
             x="Session",
             y="Pre_Post",
             hue="Age",
             estimator=None,
             units="Subject",
             style="Age",
             markers=True,
             palette="dark",
             ax=ax)

plt.title("Test")
plt.xlabel("Session")
plt.ylabel("Score")

# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.show()

I also tried to get the coordinates of the points with:

points = ax.collections[0]
offsets = points.get_offsets()
x_coords = offsets[:, 0]
y_coords = offsets[:, 1]

But I could not match each coordinate to the subject it belongs to.
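The offsets read back from `ax.collections` carry no subject labels. One workaround is to pair the values by `Subject` in the data itself, before plotting, for example with `DataFrame.pivot`. This is a minimal sketch, not the asker's code. It assumes the column names from the question's CSV and uses a few rows copied from that sample. Each row of the result is exactly one Pre-to-Post line to draw.

```python
import io

import pandas as pd

# A few rows copied from the question's CSV sample.
csv_text = """Session,Subject,Age,Pre_Post
Pre,SY01,young,14.0
Pre,SY02,young,14.0
Pre,SA01,old,5.0
Pre,SA02,old,1.0
Post,SY01,young,14.0
Post,SY02,young,13.0
Post,SA01,old,6.0
Post,SA02,old,2.0
"""
test_data = pd.read_csv(io.StringIO(csv_text))

# One row per (Age, Subject), one column per Session:
# each row holds the two y values of one connecting line.
wide = test_data.pivot(index=["Age", "Subject"], columns="Session", values="Pre_Post")

# pivot sorts the columns alphabetically (Post, Pre); restore plotting order.
wide = wide[["Pre", "Post"]]
print(wide)
```

With the subject pairing held in the data, the x position of each line can be computed rather than recovered from the drawn artists.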

In case it helps, here is a sample of my dataset in CSV format:

'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0\n'
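The sample is a CSV string, so it can be loaded with `pandas.read_csv` on an `io.StringIO` wrapper instead of splitting lines and commas by hand. A minimal sketch follows. `load_scores` is a hypothetical helper name, and only a few rows of the sample are used; the full string loads the same way.

```python
import io

import pandas as pd


def load_scores(csv_text: str) -> pd.DataFrame:
    """Parse CSV text like the sample above into a long-format DataFrame."""
    # read_csv infers Pre_Post as float64, so no manual float conversion is needed.
    return pd.read_csv(io.StringIO(csv_text))


# A few rows of the sample above.
sample = (
    "Session,Subject,Age,Pre_Post\n"
    "Pre,SY01,young,14.0\n"
    "Pre,SA01,old,5.0\n"
    "Post,SY01,young,14.0\n"
    "Post,SA01,old,6.0\n"
)
test_data = load_scores(sample)
print(test_data.dtypes)
```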
Tags: python pandas matplotlib seaborn


Answers:

3 votes · Ken Myers · 4/22/2023 · #1

This will work:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

s = 'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0'

a = np.array([r.split(',') for r in s.split('\n')])

test_data = pd.DataFrame(a[1:, :], columns = a[0])
test_data['Pre_Post'] = test_data['Pre_Post'].apply(float)
def encode_session(x):
  if x=='Pre':
    return 0
  else:
    return 1
test_data['Session'] = test_data['Session'].apply(encode_session)

test_data2 = test_data.copy()
def offset_session(row):
  if row['Age']=='young':
    return row['Session']-0.2
  else:
    return row['Session']+0.2
test_data2['Session'] = test_data2.apply(offset_session, axis=1)

fig, ax = plt.subplots(figsize=(7,5))
sns.boxplot(data=test_data, 
            x="Session", 
            y="Pre_Post", 
            hue="Age", 
            palette="pastel", 
            #boxprops=boxprops, 
            ax=ax)

sns.swarmplot(data=test_data, 
              x="Session", 
              y="Pre_Post", 
              hue="Age", 
              dodge=True, 
              palette="dark", 
              ax=ax)
    
sns.lineplot(data=test_data2, 
                 x="Session", 
                 y="Pre_Post", 
                 hue="Age", 
                 estimator=None, 
                 units="Subject", 
                 style="Age", 
                 markers=True, 
                 palette="dark", 
                 ax=ax)

plt.title("Test")
plt.xlabel("Session")
plt.ylabel("Score")

# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.xticks([0,1],['Pre', 'Post'])

plt.show()
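The ±0.2 shifts in `offset_session` are specific to two hue levels and to seaborn's default box width of 0.8. Below is a sketch that generalizes the same idea to any number of hue levels. It assumes seaborn's default dodge layout: category i sits at x = i, and each of the n hue levels gets width / n of that slot. It also assumes levels are ordered by first appearance, which is seaborn's default for string columns. `dodged_x` is a hypothetical helper, not part of seaborn.

```python
import numpy as np
import pandas as pd


def dodged_x(df, x, hue, x_order=None, hue_order=None, width=0.8):
    """Numeric x position of each row's dodged box/swarm centre.

    Assumes seaborn's default layout: category i sits at x = i, and with
    dodge=True each of the n hue levels gets width / n of the slot.
    """
    # Default to order of first appearance, as seaborn does for string columns.
    x_order = list(pd.unique(df[x])) if x_order is None else list(x_order)
    hue_order = list(pd.unique(df[hue])) if hue_order is None else list(hue_order)
    n = len(hue_order)
    each = width / n
    # Centres of the n sub-slots, symmetric around the category tick.
    offsets = np.linspace(0, width - each, n)
    offsets -= offsets.mean()
    x_index = df[x].map({level: i for i, level in enumerate(x_order)})
    hue_offset = df[hue].map(dict(zip(hue_order, offsets)))
    return x_index + hue_offset


df = pd.DataFrame({
    "Session": ["Pre", "Pre", "Post", "Post"],
    "Age": ["young", "old", "young", "old"],
    "Pre_Post": [14.0, 5.0, 14.0, 6.0],
})
df["x_dodged"] = dodged_x(df, "Session", "Age")
print(df)
```

For the data in this answer, `test_data2['Session'] = dodged_x(test_data, 'Session', 'Age')` gives the same shifts as `offset_session`. If `order=` or `hue_order=` is passed to the seaborn calls, pass the same lists here so the lines stay on their boxes. Like the original approach, the lines end at the centre of each dodged group, not at the individual swarm markers.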


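We can debate the merits of this plot. It is certainly cluttered, and it would be better split across two separate axes with less overlapping data. Personally, I don't think a bar chart is better, because a before/after line plot can tell a good story. For example, I would rather look at the plot below, which I found on Google, than at ~40 pairs of bars in a bar chart: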

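The suggestion above is to split the groups onto separate axes and connect each subject's before/after values. That layout can be sketched with plain matplotlib. This is not the answerer's code. It assumes the question's long-format columns (Session, Subject, Age, Pre_Post). It also swaps the swarm for plain markers drawn on each subject's own line at exactly x = 0 and x = 1, which guarantees the lines meet their markers. `paired_plot` is a hypothetical helper name.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend for this sketch; remove to open a window
import matplotlib.pyplot as plt
import pandas as pd


def paired_plot(df, sessions=("Pre", "Post")):
    """One panel per Age group: a box per session plus one line per Subject."""
    ages = list(pd.unique(df["Age"]))
    fig, axes = plt.subplots(1, len(ages), figsize=(4 * len(ages), 4),
                             sharey=True, squeeze=False)
    subject_lines = {}
    positions = list(range(len(sessions)))
    for ax, age in zip(axes[0], ages):
        group = df[df["Age"] == age]
        # Rows = subjects, columns = sessions in plotting order.
        wide = group.pivot(index="Subject", columns="Session", values="Pre_Post")
        wide = wide[list(sessions)]
        # Boxes at x = 0, 1, ... so the subject lines can share those x values.
        ax.boxplot([wide[s].to_numpy() for s in sessions], positions=positions, widths=0.5)
        for subject, row in wide.iterrows():
            (line,) = ax.plot(positions, row.to_numpy(), marker="o", color="grey", alpha=0.6)
            subject_lines[subject] = line
        ax.set_xticks(positions)
        ax.set_xticklabels(sessions)
        ax.set_title(f"Age: {age}")
    axes[0][0].set_ylabel("Score")
    return fig, subject_lines


# A few rows from the question's sample.
data = pd.DataFrame({
    "Session": ["Pre"] * 4 + ["Post"] * 4,
    "Subject": ["SY01", "SY02", "SA01", "SA02"] * 2,
    "Age": ["young", "young", "old", "old"] * 2,
    "Pre_Post": [14.0, 14.0, 5.0, 1.0, 14.0, 13.0, 6.0, 2.0],
})
fig, lines = paired_plot(data)
```

The trade-off is that identical scores overlap exactly instead of being spread sideways. A small `alpha` on the lines, as here, keeps overlaps visible.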

0 votes · Trenton McKinney · 4/23/2023 · #2
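  • The purpose of a visualization is to make it easier to extract meaning from the data.
    • It is common to place a swarmplot over a boxplot, because it adds information about the distribution.
    • You can, but should not, place trendlines on a distribution plot. The two plot types convey different information about the data, and the combined plot becomes difficult to interpret.
  • Since the goal is to show the distribution of the data while clearly showing the change in 'Score' for each 'Subject', a barplot is more appropriate.
    • Separating the 'Age' groups also makes for a clearer visualization.
  • As shown in the other answer:
    • Adding a trendline from 'Pre' to 'Post' for each marker, with the 'Age' groups dodged, produces a hard-to-read plot even with this small subset of data.
    • When there are many markers, each trendline only reaches the center of its group, not its own marker. This is because lineplot cannot follow the horizontal positions that swarmplot assigns to the markers.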
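The barplot recommendation is about making each Subject's change in Score easy to read. Below is a minimal sketch that computes that change directly. It assumes the renamed `Score` column used in this answer and a few rows from the question's sample. `change_per_subject` is a hypothetical helper. Sorting by the change gives a subject order that could be passed to seaborn barplot's `order=` parameter.

```python
import pandas as pd


def change_per_subject(df, value="Score"):
    """Post minus Pre for each subject, keeping the Age label."""
    wide = df.pivot(index=["Age", "Subject"], columns="Session", values=value)
    wide["Change"] = wide["Post"] - wide["Pre"]
    return wide.reset_index()


df = pd.DataFrame({
    "Session": ["Pre"] * 4 + ["Post"] * 4,
    "Subject": ["SY01", "SY02", "SA01", "SA02"] * 2,
    "Age": ["young", "young", "old", "old"] * 2,
    "Score": [14.0, 14.0, 5.0, 1.0, 14.0, 13.0, 6.0, 2.0],
})
changes = change_per_subject(df)

# Young subjects ordered from largest drop to largest gain.
young_order = (changes[changes["Age"] == "young"]
               .sort_values("Change")["Subject"].tolist())

# Average change per age group.
means = changes.groupby("Age")["Change"].mean()
print(means)
```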

Imports and Data

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# data string
s = 'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0'

# split the data into separate components
data = [v.split(',') for v in s.split('\n')]

# load the list of lists into a dataframe
df = pd.DataFrame(data=data[1:], columns=data[0])

# rename the column
df.rename({'Pre_Post': 'Score'}, axis=1, inplace=True)

# convert the column from a string to a float
df['Score'] = df['Score'].apply(float)

# create separate groups of data for the ages
(_, old), (_, young) = df.groupby('Age')

old

   Session Subject  Age  Score
8      Pre    SA01  old    5.0
9      Pre    SA02  old    1.0
10     Pre    SA03  old   10.0
11     Pre    SA04  old    3.0
12     Pre    SA05  old    9.0
13     Pre    SA06  old    5.0
14     Pre    SA07  old   13.0
15     Pre    SA08  old   13.0
24    Post    SA01  old    6.0
25    Post    SA02  old    2.0
26    Post    SA03  old   10.0
27    Post    SA04  old    7.0
28    Post    SA05  old    8.0
29    Post    SA06  old   11.0
30    Post    SA07  old   14.0
31    Post    SA08  old   11.0

young

   Session Subject    Age  Score
0      Pre    SY01  young   14.0
1      Pre    SY02  young   14.0
2      Pre    SY03  young   13.0
3      Pre    SY04  young   13.0
4      Pre    SY05  young   13.0
5      Pre    SY06  young   15.0
6      Pre    SY07  young   14.0
7      Pre    SY08  young   14.0
16    Post    SY01  young   14.0
17    Post    SY02  young   13.0
18    Post    SY03  young   14.0
19    Post    SY04  young   13.0
20    Post    SY05  young   15.0
21    Post    SY06  young   13.0
22    Post    SY07  young   15.0
23    Post    SY08  young   14.0
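The unpacking `(_, old), (_, young) = df.groupby('Age')` works because groupby sorts its keys and 'old' sorts before 'young'. With a third age group, it would raise a ValueError (too many values to unpack). Below is a sketch that selects each group by name instead. It uses the same column names as this answer and a few rows from the sample.

```python
import pandas as pd

df = pd.DataFrame({
    "Session": ["Pre", "Pre", "Post", "Post"],
    "Subject": ["SY01", "SA01", "SY01", "SA01"],
    "Age": ["young", "old", "young", "old"],
    "Score": [14.0, 5.0, 14.0, 6.0],
})

# Look each group up by name instead of relying on groupby's sorted key order.
by_age = df.groupby("Age")
young = by_age.get_group("young")
old = by_age.get_group("old")
print(old)
```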

Plotting

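  • The real data probably has more observations. To make the figure taller, increase the second number in the figsize tuple. To give the barplots more of the figure, adjust the second number in height_ratios.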
# create the figure using height_ratios to make the bottom subplots larger than the top subplots
# (height_ratios in subplots needs matplotlib >= 3.6; constrained layout lets the
# 'outside' legend location below reserve space next to the axes)
fig, axes = plt.subplots(2, 2, figsize=(11, 11), height_ratios=[1, 2], layout='constrained')

# flatten the axes for easy access
axes = axes.flat

# plot the boxplots
sns.boxplot(data=young, x="Session", y="Score", ax=axes[0])
sns.boxplot(data=old, x="Session", y="Score", ax=axes[1])

# plot the swarmplots
sns.swarmplot(data=young, x="Session", y="Score", hue='Session', edgecolor='k', linewidth=1, legend=False, ax=axes[0])
sns.swarmplot(data=old, x="Session", y="Score", hue='Session', edgecolor='k', linewidth=1, legend=False, ax=axes[1])

# add a title
axes[0].set_title('Age: Young', fontsize=15)
axes[1].set_title('Age: Old', fontsize=15)

# add the barplots
sns.barplot(data=young, x='Score', y='Subject', hue='Session', ax=axes[2])
sns.barplot(data=old, x='Score', y='Subject', hue='Session', ax=axes[3])

# extract the legend handles and labels from one of the bottom axes
handles, labels = axes[3].get_legend_handles_labels()

# iterate through the bottom axes
for ax in axes[2:]:
    # remove the axes-level legend
    ax.get_legend().remove()

    # iterate through the bar containers (one per Session)
    for c in ax.containers:
        # annotate each bar with its value, centered inside the bar
        ax.bar_label(c, label_type='center')

# add a figure-level legend ('outside ...' locations need matplotlib >= 3.7)
_ = fig.legend(handles, labels, title='Session', loc='outside right center', frameon=False)

plt.show()

An easy-to-read visualization
