多class的dice loss实现
vesper1215 发布于2021-02 浏览:1254 回复:0
0
收藏

我在使用PaddleX进行多class的语义分割,paddleX不支持多class 的dice loss于是想自己写一个,但一直不成功

我以为只要把label换成one-hot-shot,循环读每个class的logit就可以计算,但一直不成功。

#自定义 multi-class dice loss
def multi_dice_loss(logit, label, num_classes= 11):

    label = fluid.layers.cast(label, 'int64')
    label_one_hot = fluid.layers.one_hot(input=label, depth=num_classes)

    dice_sum = 0
    for i in range(len(num_classes)):
        dice_sum += dice_loss(logit[:,i,:,:], label_one_hot[:,:,:,i],ignore_mask, epsilon=0.00001)
    return dice_sum

是我对logit的shape(NCHW)理解不对吗?

 

谢谢高手解答!

收藏
点赞
0
个赞
TOP
切换版块