提问人:simon 提问时间:3/17/2021 更新时间:3/17/2021 访问量:364
如何访问三元组损失的嵌入
How to access embeddings for triplet loss
问:
我正在尝试创建一个具有三重损失的连体网络,我正在使用一个 github 示例来帮助我。我对此相当陌生,我很难理解如何从模型中提取嵌入。以下是架构:
提取我在几个页面上找到的嵌入的代码是这样的:
def triplet_loss(y_true, y_pred):
anchor, positive, negative = y_pred[:,:emb_size], y_pred[:,emb_size:2*emb_size], y_pred[:,2*emb_size:]
positive_dist = tf.reduce_mean(tf.square(anchor - positive), axis=1)
negative_dist = tf.reduce_mean(tf.square(anchor - negative), axis=1)
return tf.maximum(positive_dist - negative_dist + alpha, 0.)
让我感到困惑的是,我发现很难可视化矩阵,我不明白为什么锚点是 y[:,:emb_size],正是 y_pred[:,emb_size:2emb_size] 和负 y_pred[:,2 emb_size:]。
如果需要更多上下文,请提供完整代码:https://github.com/pranjalg2308/siamese_triplet_loss/blob/master/Siamese_With_Triplet_Loss.ipynb
答:
1赞
FancyXun
3/17/2021
#1
在完整代码片段中
in_anc = Input(shape=(105,105,1))
in_pos = Input(shape=(105,105,1))
in_neg = Input(shape=(105,105,1))
em_anc = embedding_model(in_anc)
em_pos = embedding_model(in_pos)
em_neg = embedding_model(in_neg)
out = concatenate([em_anc, em_pos, em_neg], axis=1)
siamese_net = Model(
[in_anc, in_pos, in_neg],
out
)
锚点、pos 和 neg 连接到一个输出张量,因此锚点是 y_pred[:,:emb_size]...
并会给你的嵌入。embedding_model.predict(np.expand_dims(anchor_image[3], axis=0))
评论
0赞
simon
3/18/2021
是的,但为什么不y_pred[0]作为锚点,而不是y_pred[:,:emb_size]
1赞
FancyXun
3/18/2021
@simon y_pred是二维张量。第一个是batch_size,第二个的大小是emb_size*3。
评论