GRAM(GRAM可能是一个新提出的模型或方法的缩写,这里我们根据上下文进行解释)受到诸如TorchKAN和ChebyKAN等Kolmogorov-Arnold网络(KAN)替代方案的启发。GRAM引入了一种简化的KAN模型,但同时利用了Gram多项式变换的简单性。它与其他替代方案的不同之处在于其独特的离散性特征。与其他在连续区间上定义的多项式不同,Gram多项式是在一组离散点上定义的。GRAM的这种离散性为处理像图像和文本数据这样的离散数据集提供了一种新颖的方法。
Kolmogorov-Arnold Networks (KAN): KAN是一种基于Kolmogorov-Arnold表示定理的神经网络模型,该定理表明任何多元连续函数都可以表示为一系列一元函数的复合。KAN通常用于函数逼近和复杂系统的建模。
Gram Polynomials: Gram多项式是一种在离散点上定义的特殊多项式,与在连续区间上定义的传统多项式不同。这种离散性使得Gram多项式在处理离散数据集时具有独特的优势。
离散数据集处理: 在机器学习和数据科学中,处理离散数据集(如图像和文本)是一个重要任务。图像可以视为像素的离散集合,而文本则可以视为字符或单词的离散序列。GRAM通过其离散性特征,为这类数据的处理提供了一种新的视角和方法。
模型简化与效率: GRAM通过简化KAN模型并引入Gram多项式变换,旨在提高模型的效率和实用性。这种简化可能使得GRAM在保持一定性能的同时,具有更低的计算复杂度和更快的训练速度。
应用前景: GRAM的离散性特征使其在图像处理、自然语言处理等领域具有潜在的应用前景。例如,在图像分类任务中,GRAM可能能够更有效地提取图像中的特征;在自然语言处理任务中,GRAM可能能够更准确地捕捉文本中的语义信息。
GRAM通过结合KAN和Gram多项式的思想,为离散数据集的处理提供了一种新颖而有效的解决方案。随着进一步的研究和应用,GRAM有望在多个领域展现出其独特的优势和价值。
代码
代码在原作者的基础上做了修改,代码如下:
# Based on this: https://github.com/Khochawongwat/GRAMKAN/blob/main/model.py
from functools import lru_cache
import torch
import torch.nn as nn
from torch.nn.functional import conv3d, conv2d, conv1d
class KAGNConvNDLayer(nn.Module):
def __init__(self, conv_class, norm_class, conv_w_fun, input_dim, output_dim, degree, kernel_size,
groups=1, padding=0, stride=1, dilation=1, dropout: float = 0.0, ndim: int = 2.,
**norm_kwargs):
super(KAGNConvNDLayer, self).__init__()
self.inputdim = input_dim
self.outdim = output_dim
self.degree = degree
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
self.base_activation = nn.SiLU()
self.conv_w_fun = conv_w_fun
self.ndim = ndim
self.dropout = None
self.norm_kwargs = norm_kwargs
self.p_dropout = dropout
if dropout > 0:
if ndim == 1:
self.dropout = nn.Dropout1d(p=dropout)
if ndim == 2:
self.dropout = nn.Dropout2d(p=dropout)
if ndim == 3:
self.dropout = nn.Dropout3d(p=dropout)
if groups <= 0:
raise ValueError('groups must be a positive integer')
if input_dim % groups != 0:
raise ValueError('input_dim must be divisible by groups')
if output_dim % groups != 0:
raise ValueError('output_dim must be divisible by groups')
self.base_conv = nn.ModuleList([conv_class(input_dim // groups,
output_dim // groups,
kernel_size,
stride,
padding,
dilation,
groups=1,
bias=False) for _ in range(groups)])
self.layer_norm = nn.ModuleList([norm_class(output_dim // groups, **norm_kwargs) for _ in range(groups)])
poly_shape = (groups, output_dim // groups, (input_dim // groups) * (degree + 1)) + tuple(
kernel_size for _ in range(ndim))
self.poly_weights = nn.Parameter(torch.randn(*poly_shape))
self.beta_weights = nn.Parameter(torch.zeros(degree + 1, dtype=torch.float32))
# Initialize weights using Kaiming uniform distribution for better training start
for conv_layer in self.base_conv:
nn.init.kaiming_uniform_(conv_layer.weight, nonlinearity='linear')
nn.init.kaiming_uniform_(self.poly_weights, nonlinearity='linear')
nn.init.normal_(
self.beta_weights,
mean=0.0,
std=1.0 / ((kernel_size ** ndim) * self.inputdim * (self.degree + 1.0)),
)
def beta(self, n, m):
return (
((m + n) * (m - n) * n ** 2) / (m ** 2 / (4.0 * n ** 2 - 1.0))
) * self.beta_weights[n]
@lru_cache(maxsize=128) # Cache to avoid recomputation of Gram polynomials
def gram_poly(self, x, degree):
p0 = x.new_ones(x.size())
if degree == 0:
return p0.unsqueeze(-1)
p1 = x
grams_basis = [p0, p1]
for i in range(2, degree + 1):
p2 = x * p1 - self.beta(i - 1, i) * p0
grams_basis.append(p2)
p0, p1 = p1, p2
return torch.concatenate(grams_basis, dim=1)
def forward_kag(self, x, group_index):
# Apply base activation to input and then linear transform with base weights
basis = self.base_conv[group_index](self.base_activation(x))
# Normalize x to the range [-1, 1] for stable Legendre polynomial computation
x = torch.tanh(x).contiguous()
if self.dropout is not None:
x = self.dropout(x)
grams_basis = self.base_activation(self.gram_poly(x, self.degree))
y = self.conv_w_fun(grams_basis, self.poly_weights[group_index],
stride=self.stride, dilation=self.dilation,
padding=self.padding, groups=1)
y = self.base_activation(self.layer_norm[group_index](y + basis))
return y
def forward(self, x):
split_x = torch.split(x, self.inputdim // self.groups, dim=1)
output = []
for group_ind, _x in enumerate(split_x):
y = self.forward_kag(_x.clone(), group_ind)
output.append(y.clone())
y = torch.cat(output, dim=1)
return y
class KAGNConv3DLayer(KAGNConvNDLayer):
def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
dropout: float = 0.0, norm_layer=nn.InstanceNorm3d, **norm_kwargs):
super(KAGNConv3DLayer, self).__init__(nn.Conv3d, norm_layer, conv3d,
input_dim, output_dim,
degree, kernel_size,
groups=groups, padding=padding, stride=stride, dilation=dilation,
ndim=3, dropout=dropout, **norm_kwargs)
class KAGNConv2DLayer(KAGNConvNDLayer):
def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
dropout: float = 0.0, norm_layer=nn.InstanceNorm2d, **norm_kwargs):
super(KAGNConv2DLayer, self).__init__(nn.Conv2d, norm_layer, conv2d,
input_dim, output_dim,
degree, kernel_size,
groups=groups, padding=padding, stride=stride, dilation=dilation,
ndim=2, dropout=dropout, **norm_kwargs)
class KAGNConv1DLayer(KAGNConvNDLayer):
def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
dropout: float = 0.0, norm_layer=nn.InstanceNorm1d, **norm_kwargs):
super(KAGNConv1DLayer, self).__init__(nn.Conv1d, norm_layer, conv1d,
input_dim, output_dim,
degree, kernel_size,
groups=groups, padding=padding, stride=stride, dilation=dilation,
ndim=1, dropout=dropout, **norm_kwargs)
测试结果
YOLOv8l summary: 388 layers, 91843560 parameters, 0 gradients
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 15/15 [00:02<00:00, 6.02it/s]
all 230 1412 0.958 0.977 0.991 0.744
c17 40 131 0.996 0.992 0.995 0.835
c5 19 68 0.94 1 0.995 0.842
helicopter 13 43 0.976 0.962 0.982 0.618
c130 20 85 0.987 1 0.995 0.678
f16 11 57 1 0.958 0.978 0.66
b2 2 2 0.861 1 0.995 0.723
other 13 86 0.973 0.965 0.977 0.548
b52 21 70 0.956 0.971 0.983 0.839
kc10 12 62 0.999 0.984 0.989 0.84
command 12 40 0.986 1 0.995 0.855
f15 21 123 0.99 1 0.995 0.705
kc135 24 91 0.968 0.989 0.99 0.707
a10 4 27 1 0.84 0.967 0.509
b1 5 20 0.98 1 0.995 0.731
aew 4 25 0.943 1 0.992 0.795
f22 3 17 0.974 1 0.995 0.749
p3 6 105 1 0.98 0.995 0.813
p8 1 1 0.79 1 0.995 0.597
f35 5 32 1 0.954 0.995 0.573
f18 13 125 0.981 0.992 0.994 0.829
v22 5 41 0.991 1 0.995 0.731
su-27 5 31 0.989 1 0.995 0.86
il-38 10 27 0.985 1 0.995 0.84
tu-134 1 1 0.768 1 0.995 0.895
su-33 1 2 1 0.785 0.995 0.752
an-70 1 2 0.849 1 0.995 0.73
tu-22 8 98 0.995 1 0.995 0.838