如果我有一个像这样的数据集:
image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)
如何以编程方式确定数据集中的类或唯一标签的数量?
最佳答案
如果您的数据类型是张量,则可以使用:import torch n_classes = len(torch.unique(Your_Target_Vector))
关于python - pytorch:获取给定数据集的类数,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/55235594/