我在Pytorch中实现自定义激活功能(例如Swish)时遇到问题。我应该如何在Pytorch中实现和使用自定义激活功能?

最佳答案

您可以编写如下的自定义激活功能(例如加权Tanh)。

class weightedTanh(nn.Module):
    def __init__(self, weights = 1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        ex = torch.exp(2*self.weights*input)
        return (ex-1)/(ex+1)


如果您使用兼容autograd的操作,请不要担心反向传播。

关于python - pytorch自定义激活功能?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/55765234/

10-12 01:28