背景
本文给出基于 Gumbel-softmax 的 EEG 通道选择层的 PyTorch 实现。该层可以放在任意深度神经网络之前,与网络权重一起端到端地学习针对给定任务的最优脑电通道子集。该层由若干选择神经元组成:每个神经元利用输入通道上离散分布的连续松弛(Gumbel-softmax),学习一个最优的 one-hot 权重向量,从而直接选出一个输入通道,而不是对各通道进行线性组合。
亮点
使用 Gumbel-softmax 方法从多通道脑电数据中选择单个通道(每个选择神经元只选出一个通道,而非对多个通道做线性加权)。
使用多尺度滤波器组卷积网络(MSFBCNN)完成四分类运动想象解码。
环境配置
PyTorch 0.3.1(注意:代码中使用的 torch.argmax 需要 PyTorch ≥ 0.4)
CUDA 9.1
数据
High-Gamma Dataset
方法
多尺度滤波器组卷积网络(MSFBCNN)主要代码:
class MSFBCNN(nn.Module):
    """Multi-scale filter-bank CNN used as the downstream classifier.

    Pipeline (as implemented below): four parallel temporal convolutions
    with kernel lengths 65 / 41 / 27 / 17 -> batch norm -> one spatial
    convolution spanning all C input channels -> batch norm -> square ->
    temporal average pooling -> log -> dropout -> linear classifier.

    Args:
        input_dim: ``(C, T)`` -- number of EEG channels and number of time
            samples per trial.
        output_dim: number of classes (length of the output logit vector).
        FT: number of temporal filters in each of the four branches.

    ``forward`` expects input shaped ``(batch, 1, C, T)``: the temporal
    convolutions take a single input plane and ``conv2``'s kernel height
    is C.

    Raises:
        ValueError: if ``T`` is shorter than the pooling window.
    """

    # Temporal average-pooling window and stride, in samples.
    POOL_SIZE = 75
    POOL_STRIDE = 15

    def __init__(self, input_dim, output_dim, FT=10):
        super(MSFBCNN, self).__init__()
        self.T = input_dim[1]
        self.FT = FT
        self.D = 1
        self.FS = self.FT * self.D
        self.C = input_dim[0]
        self.output_dim = output_dim
        if self.T < self.POOL_SIZE:
            raise ValueError(
                "MSFBCNN needs at least %d time samples, got %d"
                % (self.POOL_SIZE, self.T))

        # Parallel temporal convolutions. Odd kernels padded by (k - 1) / 2
        # keep the time length equal to T in every branch, so the branches
        # can be concatenated along the channel axis.
        self.conv1a = nn.Conv2d(1, self.FT, (1, 65), padding=(0, 32), bias=False)
        self.conv1b = nn.Conv2d(1, self.FT, (1, 41), padding=(0, 20), bias=False)
        self.conv1c = nn.Conv2d(1, self.FT, (1, 27), padding=(0, 13), bias=False)
        self.conv1d = nn.Conv2d(1, self.FT, (1, 17), padding=(0, 8), bias=False)
        # The second positional argument of BatchNorm2d is ``eps``. The old
        # ``BatchNorm2d(n, False)`` therefore set eps = 0, which can divide by
        # zero. Use the library default eps explicitly. The layer stays affine,
        # so the state_dict layout is unchanged.
        # NOTE(review): the author may have meant ``affine=False``. That would
        # drop weight/bias and break existing checkpoints, so it is not done.
        self.batchnorm1 = nn.BatchNorm2d(4 * self.FT, eps=1e-5)
        # Spatial convolution over all C electrodes (collapses height to 1).
        self.conv2 = nn.Conv2d(4 * self.FT, self.FS, (self.C, 1),
                               padding=(0, 0), groups=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(self.FS, eps=1e-5)
        # Temporal average pooling.
        self.pooling2 = nn.AvgPool2d(kernel_size=(1, self.POOL_SIZE),
                                     stride=(1, self.POOL_STRIDE),
                                     padding=(0, 0))
        self.drop = nn.Dropout(0.5)
        # Classification. AvgPool2d (ceil_mode=False) emits
        # floor((T - 75) / 15) + 1 windows. Integer division keeps this exact
        # under Python 3. The former math.ceil(1 + (T - 75) / 15) came out one
        # too large whenever 15 does not divide T - 75.
        n_windows = (self.T - self.POOL_SIZE) // self.POOL_STRIDE + 1
        self.fc1 = nn.Linear(self.FS * n_windows, self.output_dim)

    def forward(self, x):
        """Return class logits of shape ``(batch, output_dim)``."""
        # Layer 1: multi-scale temporal filtering, branches stacked on dim 1.
        x = torch.cat([self.conv1a(x), self.conv1b(x),
                       self.conv1c(x), self.conv1d(x)], dim=1)
        x = self.batchnorm1(x)
        # Layer 2: spatial filtering, squared and averaged per window, then
        # log. The log yields -inf if a window is exactly zero.
        # NOTE(review): consider clamping before the log.
        x = torch.pow(self.batchnorm2(self.conv2(x)), 2)
        x = self.pooling2(x)
        x = torch.log(x)
        x = self.drop(x)
        # FC layer.
        x = x.view(-1, self.num_flat_features(x))
        return self.fc1(x)

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of non-batch dims)."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features
Gumbel-softmax 重参数化(通道选择层 SelectionLayer)主要代码:
class SelectionLayer(nn.Module):
    """Gumbel-softmax channel-selection layer.

    Holds ``M`` selection neurons over ``N`` input channels; neuron ``m`` is
    parameterised by the logit column ``qz_loga[:, m]``.

    - Training (``self.training`` and not ``self.freeze``): every batch item
      gets an independent Gumbel-softmax (concrete) relaxation of a one-hot
      vector over the N channels, for every neuron.
    - Otherwise: each neuron hard-selects the channel with the largest logit.

    Args:
        N: number of input channels.
        M: number of selection neurons, i.e. channels to keep.
        temperature: Gumbel-softmax temperature; lower values give samples
            closer to one-hot. May later be reassigned as a Python float or
            a one-element tensor (e.g. for annealing).
        epsilon: keeps the uniform noise inside (epsilon, 1 - epsilon) so the
            Gumbel transform -log(-log(u)) stays finite.

    For an input shaped ``(batch, 1, N, T)``, ``forward`` returns
    ``(batch, 1, M, T)``.
    """

    # Class-level fallback so instances pickled before ``epsilon`` existed
    # still work.
    epsilon = 1e-6

    def __init__(self, N, M, temperature=1.0, epsilon=1e-6):
        super(SelectionLayer, self).__init__()
        # Kept for backward compatibility. The layer itself now creates its
        # tensors on the device/dtype of ``qz_loga``.
        self.floatTensor = torch.FloatTensor if not torch.cuda.is_available() else torch.cuda.FloatTensor
        self.N = N
        self.M = M
        self.qz_loga = Parameter(torch.randn(N, M) / 100)
        self.temperature = self.floatTensor([temperature])
        self.epsilon = epsilon
        self.freeze = False
        # Summed selection probability a channel may receive (over all
        # neurons) before regularization() starts penalizing it.
        self.thresh = 3.0

    def _temperature_like(self, ref):
        """Return the temperature as a float or a tensor on ``ref``'s device/dtype."""
        t = self.temperature
        if torch.is_tensor(t):
            t = t.to(device=ref.device, dtype=ref.dtype)
        return t

    def quantile_concrete(self, x):
        """Turn uniform noise ``x`` in (0, 1) into a Gumbel-softmax sample.

        ``x`` is broadcast against ``qz_loga`` (N, M), e.g. shape
        ``(batch, N, M)``. The result sums to 1 over dim 1 (the channel axis).
        """
        g = -torch.log(-torch.log(x))  # standard Gumbel noise
        y = (self.qz_loga + g) / self._temperature_like(g)
        return torch.softmax(y, dim=1)

    def regularization(self):
        """Penalty for channels that several neurons try to select.

        ``softmax(qz_loga, dim=0)`` is each neuron's distribution over the N
        channels. Its L1 norm over the M neurons is the total selection mass
        per channel. Only the excess above ``self.thresh`` is penalized,
        summed over channels.
        """
        eps = 1e-10
        z = torch.clamp(torch.softmax(self.qz_loga, dim=0), eps, 1)
        return torch.sum(F.relu(torch.norm(z, 1, dim=1) - self.thresh))

    def get_eps(self, size):
        """Draw uniform noise in (epsilon, 1 - epsilon) on ``qz_loga``'s device.

        ``size`` may be a shape (tuple, int or ``torch.Size``) or a tensor
        whose shape is reused (the form the original callers passed).
        """
        shape = size.size() if torch.is_tensor(size) else size
        eps = self.epsilon
        return self.qz_loga.new_empty(shape).uniform_(eps, 1 - eps)

    def sample_z(self, batch_size, training):
        """Return the selection matrix, shape ``(batch_size, 1, N, M)``.

        training=True: one Gumbel-softmax sample per batch item.
        training=False: the one-hot matrix with a 1 at
        ``(argmax(qz_loga[:, m]), m)`` for each neuron m, expanded (not
        copied) over the batch.
        """
        if training:
            noise = self.get_eps((batch_size, self.N, self.M))
            z = self.quantile_concrete(noise)
            return z.view(z.size(0), 1, z.size(1), z.size(2))
        ind = torch.argmax(self.qz_loga, dim=0)
        one_hot = self.qz_loga.new_zeros((self.N, self.M))
        one_hot[ind, torch.arange(self.M, device=ind.device)] = 1
        one_hot = one_hot.view(1, 1, self.N, self.M)
        return one_hot.expand(batch_size, 1, self.N, self.M)

    def forward(self, x):
        """Apply the (soft or hard) selection: ``z^T @ x`` over the last two dims."""
        z = self.sample_z(x.size(0), training=(self.training and not self.freeze))
        return torch.matmul(torch.transpose(z, 2, 3), x)
结果
实现从 64 通道脑电信号中选出 M 个重要通道(对应代码中 SelectionLayer(N, M):N 为输入通道数,此处 N = 64;M 为选择神经元个数,即保留的通道数),以提升后续分类任务的性能。
代码获取
https://download.csdn.net/download/YINTENAXIONGNAIER/88946872
参考文献
- Strypsteen, Thomas, and Alexander Bertrand. "End-to-end learnable EEG channel selection for deep neural networks with Gumbel-softmax." Journal of Neural Engineering 18.4 (2021): 0460a9.