You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
class MoleculeSampler(paddle.io.Sampler):
    """Yield dataset indices, optionally shuffled and/or class-balanced.

    With ``class_balance`` enabled, datapoints are split into "positive"
    (at least one target equal to 1) and "negative" (everything else). Each
    pass then alternates one positive and one negative index and stops as
    soon as the smaller group is exhausted.

    Args:
        dataset: Sized, iterable collection of datapoints. When
            ``class_balance`` is True every datapoint must expose an
            iterable ``targets`` attribute.
        class_balance: Alternate positive and negative datapoints.
        shuffle: Shuffle the indices on every pass.
        seed: Seed of the private ``random.Random`` used for shuffling.
    """

    def __init__(self, dataset, class_balance: bool = False,
                 shuffle: bool = False, seed: int = 0):
        # The zero-argument form starts the MRO lookup at the direct parent;
        # ``super(paddle.io.Sampler, self)`` would start *after* it and skip
        # the parent's initializer.
        super().__init__()
        self.dataset = dataset
        self.class_balance = class_balance
        self.shuffle = shuffle
        self._random = Random(seed)

        if self.class_balance:
            indices = np.arange(len(dataset))
            # Explicit bool dtype keeps the mask usable for indexing and for
            # ``~`` even when the dataset is empty (a bare np.array([]) is
            # float64).
            has_active = np.array(
                [any(target == 1 for target in datapoint.targets)
                 for datapoint in dataset],
                dtype=bool,
            )
            self.positive_indices = indices[has_active].tolist()
            self.negative_indices = indices[~has_active].tolist()
            # One positive plus one negative per pair.
            self.length = 2 * min(len(self.positive_indices),
                                  len(self.negative_indices))
        else:
            self.positive_indices = self.negative_indices = None
            self.length = len(self.dataset)

    def __iter__(self):
        """Return an iterator over the indices for one pass."""
        if self.class_balance:
            if self.shuffle:
                # Shuffled in place, so the order carries over between passes.
                self._random.shuffle(self.positive_indices)
                self._random.shuffle(self.negative_indices)
            # zip() truncates to the shorter list, matching ``self.length``.
            indices = [index
                       for pair in zip(self.positive_indices,
                                       self.negative_indices)
                       for index in pair]
        else:
            indices = list(range(len(self.dataset)))
            if self.shuffle:
                self._random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.length
bug描述 Describe the Bug
按照 Paddle 提供的代码进行数据集预处理和图类构建后,使用自定义数据采样器(Sampler)时,训练和测试数据无法加载(load)。
paddlepaddle==2.6.1
以下是完整代码:
import paddle
import numpy as np
from random import Random
from tqdm import trange, tqdm
class MoleculeDatapoint:
    """Plain record holding one molecule's SMILES, targets and features.

    Args:
        smiles: SMILES representation of the molecule (stored as given).
        targets: Target value(s) for the molecule (stored as given).
        features: Extra features for the molecule (stored as given).
    """

    # Must be spelled ``__init__``; a method called ``init`` is never invoked
    # by ``MoleculeDatapoint(...)`` and the call would raise TypeError.
    def __init__(self, smiles, targets, features):
        self.smiles = smiles
        self.targets = targets
        self.features = features
class MoleculeDataset(paddle.io.Dataset):
    """Map-style dataset backed by an in-memory sequence of datapoints.

    Args:
        data: Indexable, sized sequence of datapoints.
    """

    def __init__(self, data):
        super().__init__()
        self._data = data
        self._batch_graph = None
        self._random = Random()

    # A map-style dataset has to answer len() and integer indexing itself;
    # without these two methods samplers and loaders cannot read any item.
    def __len__(self):
        """Return the number of stored datapoints."""
        return len(self._data)

    def __getitem__(self, item):
        """Return the datapoint (or slice of datapoints) at ``item``."""
        return self._data[item]

    def __iter__(self):
        """Iterate over the stored datapoints in order."""
        return iter(self._data)
class MoleculeSampler(paddle.io.Sampler):
    """Yield dataset indices, optionally shuffled and/or class-balanced.

    With ``class_balance`` enabled, datapoints are split into "positive"
    (at least one target equal to 1) and "negative" (everything else). Each
    pass then alternates one positive and one negative index and stops as
    soon as the smaller group is exhausted.

    Args:
        dataset: Sized, iterable collection of datapoints. When
            ``class_balance`` is True every datapoint must expose an
            iterable ``targets`` attribute.
        class_balance: Alternate positive and negative datapoints.
        shuffle: Shuffle the indices on every pass.
        seed: Seed of the private ``random.Random`` used for shuffling.
    """

    def __init__(self, dataset, class_balance: bool = False,
                 shuffle: bool = False, seed: int = 0):
        # The zero-argument form starts the MRO lookup at the direct parent;
        # ``super(paddle.io.Sampler, self)`` would start *after* it and skip
        # the parent's initializer.
        super().__init__()
        self.dataset = dataset
        self.class_balance = class_balance
        self.shuffle = shuffle
        self._random = Random(seed)

        if self.class_balance:
            indices = np.arange(len(dataset))
            # Explicit bool dtype keeps the mask usable for indexing and for
            # ``~`` even when the dataset is empty (a bare np.array([]) is
            # float64).
            has_active = np.array(
                [any(target == 1 for target in datapoint.targets)
                 for datapoint in dataset],
                dtype=bool,
            )
            self.positive_indices = indices[has_active].tolist()
            self.negative_indices = indices[~has_active].tolist()
            # One positive plus one negative per pair.
            self.length = 2 * min(len(self.positive_indices),
                                  len(self.negative_indices))
        else:
            self.positive_indices = self.negative_indices = None
            self.length = len(self.dataset)

    # A sampler is consumed by iterating it, so it must provide __iter__;
    # __len__ reports how many indices one pass yields.
    def __iter__(self):
        """Return an iterator over the indices for one pass."""
        if self.class_balance:
            if self.shuffle:
                # Shuffled in place, so the order carries over between passes.
                self._random.shuffle(self.positive_indices)
                self._random.shuffle(self.negative_indices)
            # zip() truncates to the shorter list, matching ``self.length``.
            indices = [index
                       for pair in zip(self.positive_indices,
                                       self.negative_indices)
                       for index in pair]
        else:
            indices = list(range(len(self.dataset)))
            if self.shuffle:
                self._random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.length
class DataLoader(paddle.io.DataLoader):
    """Front end for ``paddle.io.DataLoader`` that also accepts ``sampler``.

    NOTE(review): the constructor body was missing from the original
    snippet; this implementation is a reconstruction and should be confirmed
    against the intended behavior.

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        shuffle: Shuffle samples; cannot be combined with ``sampler``.
        sampler: Index sampler. When given it is wrapped in a
            ``paddle.io.BatchSampler`` together with ``batch_size`` and
            ``drop_last``.
        batch_sampler: Ready-made batch sampler; cannot be combined with
            ``sampler``.
        num_workers: Number of loader worker processes.
        collate_fn: Callable that merges a list of samples into a batch.
        pin_memory: Accepted for signature compatibility only; not forwarded.
        drop_last: Drop the final incomplete batch.
        timeout: Forwarded to the base loader.
        worker_init_fn: Forwarded to the base loader.
        multiprocessing_context: Accepted for signature compatibility only;
            not forwarded.
        generator: Accepted for signature compatibility only; not forwarded.

    Raises:
        ValueError: If ``sampler`` is combined with ``batch_sampler`` or with
            ``shuffle=True``.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=None,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None,
                 multiprocessing_context=None,
                 generator=None):
        if sampler is not None and batch_sampler is not None:
            raise ValueError(
                "sampler and batch_sampler are mutually exclusive")
        if sampler is not None and shuffle:
            raise ValueError("sampler is mutually exclusive with shuffle")

        loader_kwargs = {
            "num_workers": num_workers,
            "collate_fn": collate_fn,
            "timeout": timeout,
            "worker_init_fn": worker_init_fn,
        }
        if sampler is not None:
            # The base loader is driven through a batch sampler here;
            # batch_size and drop_last are consumed by the BatchSampler, so
            # they are not passed to the loader a second time.
            loader_kwargs["batch_sampler"] = paddle.io.BatchSampler(
                sampler=sampler,
                batch_size=batch_size,
                drop_last=drop_last,
            )
        else:
            # Pass everything through and let the base class validate the
            # batch_sampler / batch_size / shuffle / drop_last combination.
            loader_kwargs.update(
                batch_sampler=batch_sampler,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=drop_last,
            )
        super().__init__(dataset, **loader_kwargs)
def construct_molecule_batch(data):
    """Wrap ``data`` in a ``MoleculeDataset`` and return the new dataset."""
    return MoleculeDataset(data)
class MoleculeDataLoader(DataLoader):
    """Loader that batches a molecule dataset through a ``MoleculeSampler``.

    Args:
        dataset: Dataset of molecule datapoints to load.
        batch_size: Number of datapoints per batch.
        num_workers: Number of loader worker processes.
        class_balance: Forwarded to ``MoleculeSampler``.
        shuffle: Forwarded to ``MoleculeSampler``.
        seed: Forwarded to ``MoleculeSampler``.
    """

    # Must be spelled ``__init__``; a method called ``init`` is never invoked
    # on construction, so none of the state below would be set up.
    def __init__(self, dataset, batch_size: int = 50,
                 num_workers: int = 8, class_balance: bool = False,
                 shuffle: bool = False, seed: int = 0):
        self._dataset = dataset
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._class_balance = class_balance
        self._shuffle = shuffle
        self._seed = seed
        self._context = None
        self._timeout = 0

        # Ordering (shuffle / class balance) is delegated to the sampler, so
        # ``shuffle`` is deliberately not passed to the parent loader.
        self._sampler = MoleculeSampler(
            dataset=self._dataset,
            class_balance=self._class_balance,
            shuffle=self._shuffle,
            seed=self._seed,
        )

        super().__init__(
            dataset=self._dataset,
            batch_size=self._batch_size,
            sampler=self._sampler,
            num_workers=self._num_workers,
            collate_fn=construct_molecule_batch,
            multiprocessing_context=self._context,
            timeout=self._timeout,
        )
def chemprop_build_data_loader(
    smiles: list[str],
    fingerprints: np.ndarray = None,
    properties: list[int] = None,
    shuffle: bool = False,
    num_workers: int = 0
):
    """Build a ``MoleculeDataLoader`` from parallel per-molecule inputs.

    NOTE(review): the body was missing from the original snippet; this
    implementation is reconstructed from the signature and from the
    constructors of ``MoleculeDatapoint``, ``MoleculeDataset`` and
    ``MoleculeDataLoader`` - confirm against the intended behavior.

    Args:
        smiles: One SMILES entry per molecule.
        fingerprints: Optional per-molecule features, same length as
            ``smiles``. Missing features are stored as ``None``.
        properties: Optional per-molecule property values, same length as
            ``smiles``. Missing targets are stored as ``None``.
        shuffle: Forwarded to the loader.
        num_workers: Forwarded to the loader.

    Returns:
        A ``MoleculeDataLoader`` over the assembled datapoints.

    Raises:
        ValueError: If ``fingerprints`` or ``properties`` is given with a
            length different from ``smiles``.
    """
    count = len(smiles)

    if fingerprints is None:
        fingerprints = [None] * count
    elif len(fingerprints) != count:
        raise ValueError("fingerprints must have the same length as smiles")

    if properties is None:
        targets = [None] * count
    elif len(properties) != count:
        raise ValueError("properties must have the same length as smiles")
    else:
        # Each target is wrapped in a list because MoleculeSampler iterates
        # over ``datapoint.targets``.
        targets = [[prop] for prop in properties]

    datapoints = [
        MoleculeDatapoint(smiles=smi, targets=target, features=feature)
        for smi, target, feature in zip(smiles, targets, fingerprints)
    ]
    return MoleculeDataLoader(
        dataset=MoleculeDataset(datapoints),
        shuffle=shuffle,
        num_workers=num_workers,
    )
def main():
    """Define the sample training data from the report.

    The three lists are created as locals only; nothing else is done with
    them in this function and it returns ``None``.
    """
    train_smiles = [
        384.27, 405.311, 386.286, 356.26, 384.27, 463.441, 236.109, 352.316,
        379.367, 275.106, 507.472, 292.14, 98.145, 312.453, 284.44, 570.896,
        326.477, 654.97, 223.747, 308.636, 195.693, 350.48, 310.437, 298.426,
        298.426, 312.453, 284.443, 337.463, 296.41, 310.437, 364.529, 310.481,
        369.505, 340.463, 384.516, 368.524, 466.456, 500.629, 282.512, 204.357,
    ]
    train_fingerprints = [
        2.6355, 4.39088, 1.4796, 2.4663, 1.496, 4.3989, 0.9992, 3.9613,
        4.44214, -0.4514, 4.5354, 1.7634, 0.7806, 3.8826, 3.0149, 6.5826,
        3.5857, 7.7242, 2.6044, 3.7491, 2.1734, 3.9192, 3.6586, 3.6366,
        3.4925, 3.8826, 4.3135, 4.221, 3.6126, 3.9156, 5.2284, 4.8697,
        4.7145, 4.0633, 4.4259, 5.15, 2.63012, 2.5925, 5.974, 5.0354,
    ]
    train_properties = [
        384.27, 405.311, 386.286, 356.26, 384.27, 463.441, 236.109, 352.316,
        379.367, 275.106, 507.472, 292.14, 98.145, 312.453, 284.44, 570.896,
        326.477, 654.97, 223.747, 308.636, 195.693, 350.48, 310.437, 298.426,
        298.426, 312.453, 284.443, 337.463, 296.41, 310.437, 364.529, 310.481,
        369.505, 340.463, 384.516, 368.524, 466.456, 500.629, 282.512, 204.357,
    ]
# Run main() only when this file is executed as a script. The module name is
# exposed as ``__name__`` and equals '__main__' in that case; a bare ``name``
# is undefined and raises NameError.
if __name__ == '__main__':
    main()
其他补充信息 Additional Supplementary Information
No response
The text was updated successfully, but these errors were encountered: