我在使用PaddleX进行多class的语义分割,paddleX不支持多class 的dice loss于是想自己写一个,但一直不成功
我以为只要把label换成one-hot-shot,循环读每个class的logit就可以计算,但一直不成功。
#自定义 multi-class dice loss
def multi_dice_loss(logit, label, num_classes= 11):
label = fluid.layers.cast(label, 'int64')
label_one_hot = fluid.layers.one_hot(input=label, depth=num_classes)
dice_sum = 0
for i in range(len(num_classes)):
dice_sum += dice_loss(logit[:,i,:,:], label_one_hot[:,:,:,i],ignore_mask, epsilon=0.00001)
return dice_sum
是我对logit的shape(NCHW)理解不对吗?
谢谢高手解答!
收藏
点赞
0
个赞
请登录后评论
TOP
切换版块