自学内容网 自学内容网

商品推荐场景的triplet_loss

def _get_pairwise_mask(labels, ids):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
    A triplet (i, j, k) is valid if:
        - i, j, k are distinct
        - labels[i] == labels[j] and labels[i] != labels[k]
        - id[i] == id[j] == id[k]
    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]
    """
    # Check that i, j are distinct
    labels = tf.reshape(labels, shape=(-1,))
    ids = tf.reshape(ids, shape=(-1,))

    # check id i,j,k are equal
    id_equal = tf.equal(tf.expand_dims(ids, 0), tf.expand_dims(ids, 1))

    # Check if labels[i] != labels[k]
    # check if labels[i] > labels[j]
    label_less = tf.less(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0))

    # Combine the two masks
    mask = tf.logical_and(id_equal, label_less)

    return mask
  • 构建一个triplet(a,p,n)需要满足三个条件,如上所示。
    • i、j、k are distinct:都代表商品,他们是三个不同的商品。
    • labels[i] == labels[j] and labels[i] != labels[k]:i和j同标签,i和k不同标签。
    • id[i] == id[j] == id[k]:来自统一个用户。
  • label_less:这里用less是为了得到label不同的情况,由于label只有0和1,避免另两个label不同的位置计算两次,所以只取label小于另一个label的情况。
  • 这个方式返回的mask,只满足了条件2和条件3.


原文地址:https://blog.csdn.net/liuhe2296044/article/details/142861231

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!