Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

数据集按照类的方式写作,采用自定义数据采集器,出现数据无法加载现象 #68532

Open
niushao12 opened this issue Sep 30, 2024 · 1 comment
Assignees

Comments

@niushao12
Copy link

niushao12 commented Sep 30, 2024

bug描述 Describe the Bug

按照paddle提供代码进行数据集前处理图类构建,通过自定义数据采集器,训练测试数据无法load。

paddlepaddle ==2.6.1

以下是整个代码文本

import paddle
import numpy as np
from random import Random
from tqdm import trange, tqdm

class MoleculeDatapoint:
def init(self, smiles, targets, features):
self.smiles = smiles
self.targets = targets
self.features = features

class MoleculeDataset(paddle.io.Dataset):
def init(self, data):
self._data = data
self._batch_graph = None
self._random = Random()

def __len__(self) ->int:
    return len(self._data)

def __getitem__(self, item):
    print(item, self._data[item].smiles)
    return self._data[item]

class MoleculeSampler(paddle.io.Sampler):
def init(self, dataset, class_balance: bool=False,
shuffle: bool=False, seed: int=0):
super(paddle.io.Sampler, self).init()
self.dataset = dataset
self.class_balance = class_balance
self.shuffle = shuffle
self._random = Random(seed)
if self.class_balance:
indices = np.arange(len(dataset))
has_active = np.array([any(target == 1 for target in datapoint.
targets) for datapoint in dataset])
self.positive_indices = indices[has_active].tolist()
self.negative_indices = indices[~has_active].tolist()
self.length = 2 * min(len(self.positive_indices), len(self.
negative_indices))
else:
self.positive_indices = self.negative_indices = None
self.length = len(self.dataset)

def __iter__(self):
    if self.class_balance:
        if self.shuffle:
            self._random.shuffle(self.positive_indices)
            self._random.shuffle(self.negative_indices)
        indices = [index for pair in zip(self.positive_indices, self.
            negative_indices) for index in pair]
    else:
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self._random.shuffle(indices)
    return iter(indices)

def __len__(self):
    return self.length

class DataLoader(paddle.io.DataLoader):
def init(self,
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None):

    if isinstance(dataset[0], (tuple, list)):
        return_list = True
    else:
        return_list = False

    super().__init__(
        dataset,
        feed_list=None,
        places=None,
        return_list=return_list,
        batch_sampler=batch_sampler,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        collate_fn=collate_fn,
        num_workers=num_workers,
        use_buffer_reader=True,
        use_shared_memory=False,
        timeout=timeout,
        worker_init_fn=worker_init_fn)
    if sampler is not None:
        self.batch_sampler.sampler = sampler

def construct_molecule_batch(data):
data = MoleculeDataset(data)
return data

class MoleculeDataLoader(DataLoader):
def init(self, dataset, batch_size: int=50,
num_workers: int=8, class_balance: bool=False, shuffle: bool=False,
seed: int=0):
self._dataset = dataset
self._batch_size = batch_size
self._num_workers = num_workers
self._class_balance = class_balance
self._shuffle = shuffle
self._seed = seed
self._context = None
self._timeout = 0
self._sampler = MoleculeSampler(dataset=self._dataset,
class_balance=self._class_balance, shuffle=self._shuffle, seed=
self._seed)
super(MoleculeDataLoader, self).init(dataset=self._dataset,
batch_size=self._batch_size, sampler=self._sampler, num_workers
=self._num_workers, collate_fn=construct_molecule_batch,
multiprocessing_context=self._context, timeout=self._timeout)

def __iter__(self):
    return super(MoleculeDataLoader, self).__iter__()

def chemprop_build_data_loader(
smiles: list[str],
fingerprints: np.ndarray = None,
properties: list[int] = None ,
shuffle: bool = False,
num_workers: int = 0
):

if fingerprints is None:
    fingerprints = [None] * len(smiles)

if properties is None:
    properties = [None] * len(smiles)
else:
    properties = [[float(prop)] for prop in properties]

return MoleculeDataLoader(
    dataset=MoleculeDataset([
        MoleculeDatapoint(
            smiles=[smiles],
            targets=prop,
            features=fingerprint,
        ) for smiles, fingerprint, prop in zip(smiles, fingerprints, properties)
    ]),
    batch_size = 4,
    num_workers=num_workers,
    shuffle=True
)

def main():
train_smiles=[384.27,405.311,386.286,356.26,384.27,463.441,236.109,352.316,379.367,275.106,507.472,292.14,98.145,312.453,284.44,570.896,326.477,654.97,223.747,308.636,195.693,350.48,310.437,298.426,298.426,312.453,284.443,337.463,296.41,310.437,364.529,310.481,369.505,340.463,384.516,368.524,466.456,500.629,282.512,204.357]
train_fingerprints=[2.6355,4.39088,1.4796,2.4663,1.496,4.3989,0.9992,3.9613,4.44214,-0.4514,4.5354,1.7634,0.7806,3.8826,3.0149,6.5826,3.5857,7.7242,2.6044,3.7491,2.1734,3.9192,3.6586,3.6366,3.4925,3.8826,4.3135,4.221,3.6126,3.9156,5.2284,4.8697,4.7145,4.0633,4.4259,5.15,2.63012,2.5925,5.974,5.0354]
train_properties=[384.27,405.311,386.286,356.26,384.27,463.441,236.109,352.316,379.367,275.106,507.472,292.14,98.145,312.453,284.44,570.896,326.477,654.97,223.747,308.636,195.693,350.48,310.437,298.426,298.426,312.453,284.443,337.463,296.41,310.437,364.529,310.481,369.505,340.463,384.516,368.524,466.456,500.629,282.512,204.357]

# Build data loaders
train_data_loader = chemprop_build_data_loader(
    smiles=train_smiles,
    fingerprints=train_fingerprints,
    properties=train_properties,
    shuffle=True,
    num_workers=0
)

print(len(train_data_loader))

for batch in tqdm(train_data_loader, total=len(train_data_loader), leave=False):#:这里出现了bug,无法进行的数据,数据长度没有问题,在torch下完全可以加载数据。
    batch: MoleculeDataset
    smile = batch.smiles()
    print("batch", smile)

if name == 'main':
main()

其他补充信息 Additional Supplementary Information

No response

@liaoxin2
Copy link

liaoxin2 commented Oct 8, 2024

您好,我初步排查了一下,该问题可能是由于construct_molecule_batch函数中的data不是张量造成的,可以参看https://github.com/PaddlePaddle/Paddle/issues/41883,该问题我已反馈,后续版本应该会得到解决

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants