
PyTorch mini-batch training


# A Dataset that returns both the sample and its index
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class LoadDataset(Dataset):
    def __init__(self, data):
        self.x = data

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return torch.from_numpy(np.array(self.x[idx])).float(), \
               torch.from_numpy(np.array(idx))

x_train = np.loadtxt("./data_s/CITE/cite.txt")
dataset = LoadDataset(x_train)
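Because `__getitem__` returns the index alongside the sample, each shuffled batch still knows which original rows it contains. A minimal check of this behavior, using a made-up 5x3 array in place of the loaded file:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class LoadDataset(Dataset):
    def __init__(self, data):
        self.x = data

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return torch.from_numpy(np.array(self.x[idx])).float(), \
               torch.from_numpy(np.array(idx))

data = np.arange(15).reshape(5, 3)   # toy stand-in for np.loadtxt(...)
ds = LoadDataset(data)

sample, idx = ds[2]                  # fetch row 2 and its index
print(len(ds))                       # 5
print(sample)                        # tensor([6., 7., 8.])
print(int(idx))                      # 2
```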

# During training (model, args, device, eva, and y_labels are defined elsewhere)
import torch.nn.functional as F
from sklearn.cluster import KMeans

# Feed the indexed dataset into DataLoader, which shuffles it and splits it
# into batches of 256 samples.
train_loader = DataLoader(dataset, batch_size=256, shuffle=True)

for epoch in range(args.pretrain_epoch):
    for batch_idx, (x, _) in enumerate(train_loader):   # iterate over the batches
        x = x.to(device)                                # move one batch to the device
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        # The data here must come from the original dataset. Otherwise x would
        # still be a single batch, while the labels cover all samples, and the
        # two sizes would not match.
        x = torch.Tensor(dataset.x).to(device)
        # x = dataset.x.to(device)
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        print('{} loss: {}'.format(epoch, loss))
        kmeans = KMeans(n_clusters=args.n_clusters, n_init=20).fit(z.data.cpu().numpy())
        eva(y_labels, kmeans.labels_, epoch)

torch.save(model.state_dict(), args.pretrain_path)
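The snippet above depends on externally defined objects (`model`, `args`, `device`, `eva`, `y_labels`). The batch-update plus full-dataset-evaluation pattern can be run end to end with a toy autoencoder; `TinyAE` and the synthetic data below are made up for illustration, but `forward` returns `(z, x_bar)` to match the `z, x_bar = model(x)` call in the article:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class TinyAE(nn.Module):
    """Hypothetical stand-in for the article's autoencoder `model`."""
    def __init__(self, d_in=10, d_z=2):
        super().__init__()
        self.enc = nn.Linear(d_in, d_z)
        self.dec = nn.Linear(d_z, d_in)

    def forward(self, x):
        z = self.enc(x)
        return z, self.dec(z)   # latent code and reconstruction

class LoadDataset(Dataset):
    def __init__(self, data):
        self.x = data

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return torch.from_numpy(np.array(self.x[idx])).float(), \
               torch.from_numpy(np.array(idx))

rng = np.random.default_rng(0)
x_train = rng.normal(size=(100, 10))          # synthetic stand-in for cite.txt
dataset = LoadDataset(x_train)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

device = torch.device("cpu")
model = TinyAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(3):
    for x, _ in train_loader:                 # per-batch parameter updates
        x = x.to(device)
        z, x_bar = model(x)
        loss = F.mse_loss(x_bar, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():                     # full-dataset evaluation
        x_all = torch.Tensor(dataset.x).to(device)
        z_all, x_bar_all = model(x_all)
        full_loss = F.mse_loss(x_bar_all, x_all)
        print(f"epoch {epoch}: loss {full_loss.item():.4f}")
```

The clustering step is omitted here; in the article it would run on `z_all` (the full-dataset embeddings), which is exactly why the evaluation pass must use all of `dataset.x` rather than a single batch.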

