DataLoader加载无Label数据
ljjfordownload 发布于2021-04 浏览:1524 回复:1
0
收藏

以下是完整的测试代码

#测试性的DataLoader
import paddle
from paddle.io import Dataset,DataLoader
class TestDataset(Dataset):
    def __init__(self,num):
        self.num_samples = num
    def __getitem__(self,idx):
        random_img = paddle.randn((3,64,64))
        return random_img
    def __len__(self):
        return self.num_samples

dataset = TestDataset(10000)
#从Dataset中获取的单个元素,预期shape是 [3,64,64]
print(dataset[10].shape) #实际输出确实是 [3,64,64]

#使用DataLoader进行加载,BatchSize设置成 256,预期获取到的数据 是一个 Shape 为 [256,3,64,64] 的Tensor (与pytorch行为一致)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=0)
for data in dataloader:
    print(len(data))  #data 不是一个Tensor,而变成了一个List,长度为 3
    print(data[0].shape) #data中每个元素的 shape 为 [256,64,64]
    break
收藏
点赞
0
个赞
共1条回复 最后由ljjfordownload回复于2021-04
#2ljjfordownload回复于2021-04

将 random_img = paddle.randn((3,64,64)) 修改成 random_img = paddle.randn((1,3,64,64)) 可以解决问题,如果是实际图片,可以进行一次reshape(1,3,64,64)来解决问题

0
TOP
切换版块