PyTorch加載數據

PyTorch包含一個名為torchvision的包,用於加載和準備數據集。它包括兩個基本功能,即DatasetDataLoader,它們有助於數據集的轉換和加載。

數據集

數據集用於從給定數據集讀取和轉換數據點。實現的基本語法如下所述 -

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader用於隨機播放和批量處理數據。它可用於與多處理工作程式並行加載數據。

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

示例:加載CSV檔

使用Python包Panda來加載csv檔。原始檔具有以下格式:(圖像名稱,68個標記 - 每個標記具有xy座標)。

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

上一篇: PyTorch術語 下一篇: PyTorch線性回歸