PyTorch mini-batch training (分批训练)
# Dataset wrapper that yields both the sample and its index, so downstream
# code can match items from a shuffled batch back to the full label array.
from torch.utils.data import Dataset, DataLoader


class LoadDataset(Dataset):
    """Wrap a 2-D array; each item is a (sample_tensor, index_tensor) pair."""

    def __init__(self, data):
        # data: array-like of shape (n_samples, n_features) — TODO confirm
        self.x = data

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample = torch.from_numpy(np.array(self.x[idx])).float()
        index = torch.from_numpy(np.array(idx))
        return sample, index


x_train = np.loadtxt("./data_s/CITE/cite.txt")
dataset = LoadDataset(x_train)
# Shuffle the indexed dataset and split it into mini-batches of 256.
train_loader = DataLoader(dataset, batch_size=256, shuffle=True)

for epoch in range(args.pretrain_epoch):
    # Optimize the autoencoder on each mini-batch via reconstruction loss.
    for x, _ in train_loader:
        x = x.to(device)  # move this batch onto the target device
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Per-epoch evaluation on the FULL dataset. The data must come from the
    # original `dataset` (not the last mini-batch): the label array covers
    # every sample, so using a partial batch would cause a size mismatch.
    with torch.no_grad():
        x = torch.Tensor(dataset.x).to(device)
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        print('{} loss: {}'.format(epoch, loss))
        # Cluster the latent codes and score them against the ground truth.
        kmeans = KMeans(n_clusters=args.n_clusters, n_init=20).fit(z.data.cpu().numpy())
        eva(y_labels, kmeans.labels_, epoch)

torch.save(model.state_dict(), args.pretrain_path)
PyTorch mini-batch training (分批训练)
# Dataset wrapper that yields both the sample and its index, so downstream
# code can match items from a shuffled batch back to the full label array.
from torch.utils.data import Dataset, DataLoader


class LoadDataset(Dataset):
    """Wrap a 2-D array; each item is a (sample_tensor, index_tensor) pair."""

    def __init__(self, data):
        # data: array-like of shape (n_samples, n_features) — TODO confirm
        self.x = data

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample = torch.from_numpy(np.array(self.x[idx])).float()
        index = torch.from_numpy(np.array(idx))
        return sample, index


x_train = np.loadtxt("./data_s/CITE/cite.txt")
dataset = LoadDataset(x_train)
# Shuffle the indexed dataset and split it into mini-batches of 256.
train_loader = DataLoader(dataset, batch_size=256, shuffle=True)

for epoch in range(args.pretrain_epoch):
    # Optimize the autoencoder on each mini-batch via reconstruction loss.
    for x, _ in train_loader:
        x = x.to(device)  # move this batch onto the target device
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Per-epoch evaluation on the FULL dataset. The data must come from the
    # original `dataset` (not the last mini-batch): the label array covers
    # every sample, so using a partial batch would cause a size mismatch.
    with torch.no_grad():
        x = torch.Tensor(dataset.x).to(device)
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        print('{} loss: {}'.format(epoch, loss))
        # Cluster the latent codes and score them against the ground truth.
        kmeans = KMeans(n_clusters=args.n_clusters, n_init=20).fit(z.data.cpu().numpy())
        eva(y_labels, kmeans.labels_, epoch)

torch.save(model.state_dict(), args.pretrain_path)