if __name__ == '__main__': module = CAM_Module() in_data = torch.randint(0, 255, (2, 3, 7, 7), dtype=torch.float32)