GRAM (a recently proposed method whose name we interpret here from context) draws inspiration from Kolmogorov-Arnold Network (KAN) alternatives such as TorchKAN and ChebyKAN. GRAM introduces a simplified KAN model while exploiting the simplicity of the Gram polynomial transform. What distinguishes it from the other alternatives is its inherently discrete character: unlike classical polynomials defined on a continuous interval, Gram polynomials are defined on a set of discrete points. This discreteness gives GRAM a novel way to handle discrete data sets such as images and text.

Kolmogorov-Arnold Networks (KAN): KAN is a neural network model based on the Kolmogorov-Arnold representation theorem, which states that any multivariate continuous function can be written as a finite superposition of continuous univariate functions and addition. KAN is typically used for function approximation and for modeling complex systems.
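
Formally, the theorem states that every continuous f on [0, 1]^n can be written as (a standard formulation, included here for reference):

f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)

where the \Phi_q and \phi_{q,p} are continuous univariate functions. KANs make the univariate functions in the roles of \phi and \Phi learnable instead of fixed.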

Gram Polynomials: Gram polynomials are a family of polynomials defined on a set of discrete points, unlike traditional polynomials defined on a continuous interval. This discreteness gives them a distinctive advantage when working with discrete data sets; a small sketch of the construction follows.
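
As a minimal illustration (standalone PyTorch code, independent of the full layer implementation below; the beta values here are fixed placeholders, whereas GRAM learns them), a Gram-style polynomial basis can be generated with the three-term recurrence P_{n+1}(x) = x * P_n(x) - beta_n * P_{n-1}(x):

import torch

def gram_basis(x, degree, betas):
    # Build [P_0(x), ..., P_degree(x)] via the three-term recurrence
    # P_{n+1}(x) = x * P_n(x) - betas[n] * P_{n-1}(x)
    p0 = torch.ones_like(x)
    if degree == 0:
        return [p0]
    p1 = x
    basis = [p0, p1]
    for n in range(1, degree):
        p2 = x * p1 - betas[n] * p0
        basis.append(p2)
        p0, p1 = p1, p2
    return basis

# Degree-3 basis on inputs squashed into [-1, 1]
x = torch.tanh(torch.randn(4))
print(torch.stack(gram_basis(x, 3, betas=[0.5, 0.5, 0.5])).shape)  # torch.Size([4, 4])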

Handling discrete data sets: In machine learning and data science, handling discrete data sets (such as images and text) is an important task. An image can be viewed as a discrete grid of pixels, and text as a discrete sequence of characters or words. Through its discrete formulation, GRAM offers a new perspective on, and approach to, this kind of data.

Model simplification and efficiency: By simplifying the KAN model and introducing the Gram polynomial transform, GRAM aims to improve efficiency and practicality. This simplification can give GRAM lower computational complexity and faster training while retaining comparable performance.

Application prospects: GRAM's discrete character gives it potential applications in image processing, natural language processing, and related areas. For example, in image classification GRAM may extract image features more effectively, and in NLP tasks it may capture semantic information in text more accurately.

By combining the ideas of KAN and Gram polynomials, GRAM offers a novel and effective approach to handling discrete data sets. With further research and application, GRAM may demonstrate its unique advantages and value across multiple domains.

Code

The code below modifies the original author's implementation:

# Based on this: https://github.com/Khochawongwat/GRAMKAN/blob/main/model.py
from functools import lru_cache
import torch
import torch.nn as nn
from torch.nn.functional import conv3d, conv2d, conv1d


class KAGNConvNDLayer(nn.Module):
    def __init__(self, conv_class, norm_class, conv_w_fun, input_dim, output_dim, degree, kernel_size,
                 groups=1, padding=0, stride=1, dilation=1, dropout: float = 0.0, ndim: int = 2,
                 **norm_kwargs):
        super(KAGNConvNDLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.base_activation = nn.SiLU()
        self.conv_w_fun = conv_w_fun
        self.ndim = ndim
        self.dropout = None
        self.norm_kwargs = norm_kwargs
        self.p_dropout = dropout
        if dropout > 0:
            if ndim == 1:
                self.dropout = nn.Dropout1d(p=dropout)
            if ndim == 2:
                self.dropout = nn.Dropout2d(p=dropout)
            if ndim == 3:
                self.dropout = nn.Dropout3d(p=dropout)

        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if input_dim % groups != 0:
            raise ValueError('input_dim must be divisible by groups')
        if output_dim % groups != 0:
            raise ValueError('output_dim must be divisible by groups')

        self.base_conv = nn.ModuleList([conv_class(input_dim // groups,
                                                   output_dim // groups,
                                                   kernel_size,
                                                   stride,
                                                   padding,
                                                   dilation,
                                                   groups=1,
                                                   bias=False) for _ in range(groups)])

        self.layer_norm = nn.ModuleList([norm_class(output_dim // groups, **norm_kwargs) for _ in range(groups)])

        poly_shape = (groups, output_dim // groups, (input_dim // groups) * (degree + 1)) + tuple(
            kernel_size for _ in range(ndim))

        self.poly_weights = nn.Parameter(torch.randn(*poly_shape))
        self.beta_weights = nn.Parameter(torch.zeros(degree + 1, dtype=torch.float32))

        # Initialize weights using Kaiming uniform distribution for better training start
        for conv_layer in self.base_conv:
            nn.init.kaiming_uniform_(conv_layer.weight, nonlinearity='linear')

        nn.init.kaiming_uniform_(self.poly_weights, nonlinearity='linear')
        nn.init.normal_(
            self.beta_weights,
            mean=0.0,
            std=1.0 / ((kernel_size ** ndim) * self.inputdim * (self.degree + 1.0)),
        )

    def beta(self, n, m):
        # Recurrence coefficient of the three-term Gram recurrence, scaled by a learnable weight
        return (((m + n) * (m - n) * n ** 2) / (m ** 2 / (4.0 * n ** 2 - 1.0))) * self.beta_weights[n]

    @lru_cache(maxsize=128)  # Cache Gram bases; note the cache is keyed on tensor identity, so hits require the same tensor object
    def gram_poly(self, x, degree):
        p0 = x.new_ones(x.size())

        if degree == 0:
            # Degree 0: the basis is just the constant polynomial (channel count unchanged for the conv below)
            return p0

        p1 = x
        grams_basis = [p0, p1]

        for i in range(2, degree + 1):
            p2 = x * p1 - self.beta(i - 1, i) * p0
            grams_basis.append(p2)
            p0, p1 = p1, p2

        return torch.cat(grams_basis, dim=1)

    def forward_kag(self, x, group_index):
        # Apply base activation to input and then linear transform with base weights
        basis = self.base_conv[group_index](self.base_activation(x))

        # Normalize x to the range [-1, 1] for stable Gram polynomial computation
        x = torch.tanh(x).contiguous()

        if self.dropout is not None:
            x = self.dropout(x)

        grams_basis = self.base_activation(self.gram_poly(x, self.degree))

        y = self.conv_w_fun(grams_basis, self.poly_weights[group_index],
                            stride=self.stride, dilation=self.dilation,
                            padding=self.padding, groups=1)

        y = self.base_activation(self.layer_norm[group_index](y + basis))

        return y

    def forward(self, x):

        split_x = torch.split(x, self.inputdim // self.groups, dim=1)
        output = []
        for group_ind, _x in enumerate(split_x):
            y = self.forward_kag(_x.clone(), group_ind)
            output.append(y.clone())
        y = torch.cat(output, dim=1)
        return y


class KAGNConv3DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm3d, **norm_kwargs):
        super(KAGNConv3DLayer, self).__init__(nn.Conv3d, norm_layer, conv3d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=3, dropout=dropout, **norm_kwargs)


class KAGNConv2DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm2d, **norm_kwargs):
        super(KAGNConv2DLayer, self).__init__(nn.Conv2d, norm_layer, conv2d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=2, dropout=dropout, **norm_kwargs)


class KAGNConv1DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm1d, **norm_kwargs):
        super(KAGNConv1DLayer, self).__init__(nn.Conv1d, norm_layer, conv1d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=1, dropout=dropout, **norm_kwargs)
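
As a quick sanity check (a hypothetical smoke test; the tensor shapes are chosen purely for illustration), the 2D layer can be instantiated and run like this:

layer = KAGNConv2DLayer(input_dim=16, output_dim=32, kernel_size=3,
                        degree=3, groups=2, padding=1)
x = torch.randn(2, 16, 64, 64)  # (batch, channels, height, width)
y = layer(x)
print(y.shape)  # expected: torch.Size([2, 32, 64, 64])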

Test Results

YOLOv8l summary: 388 layers, 91843560 parameters, 0 gradients
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 15/15 [00:02<00:00,  6.02it/s]
                   all        230       1412      0.958      0.977      0.991      0.744
                   c17         40        131      0.996      0.992      0.995      0.835
                    c5         19         68       0.94          1      0.995      0.842
            helicopter         13         43      0.976      0.962      0.982      0.618
                  c130         20         85      0.987          1      0.995      0.678
                   f16         11         57          1      0.958      0.978       0.66
                    b2          2          2      0.861          1      0.995      0.723
                 other         13         86      0.973      0.965      0.977      0.548
                   b52         21         70      0.956      0.971      0.983      0.839
                  kc10         12         62      0.999      0.984      0.989       0.84
               command         12         40      0.986          1      0.995      0.855
                   f15         21        123       0.99          1      0.995      0.705
                 kc135         24         91      0.968      0.989       0.99      0.707
                   a10          4         27          1       0.84      0.967      0.509
                    b1          5         20       0.98          1      0.995      0.731
                   aew          4         25      0.943          1      0.992      0.795
                   f22          3         17      0.974          1      0.995      0.749
                    p3          6        105          1       0.98      0.995      0.813
                    p8          1          1       0.79          1      0.995      0.597
                   f35          5         32          1      0.954      0.995      0.573
                   f18         13        125      0.981      0.992      0.994      0.829
                   v22          5         41      0.991          1      0.995      0.731
                 su-27          5         31      0.989          1      0.995       0.86
                 il-38         10         27      0.985          1      0.995       0.84
                tu-134          1          1      0.768          1      0.995      0.895
                 su-33          1          2          1      0.785      0.995      0.752
                 an-70          1          2      0.849          1      0.995       0.73
                 tu-22          8         98      0.995          1      0.995      0.838