Can KANs (Kolmogorov-Arnold Networks) do Computer Vision?
A recent article on Kolmogorov-Arnold networks has caused quite a buzz. The paper [1] has exploded on the internet and within the community: multiple variations of these networks have been presented, heated debates are ongoing, experiments are being conducted, and so on. In this article, we will dive into computer vision with KANs and attempt to answer a simple question: can KANs do CV? For those who strongly believe in open source and open science, an advanced investigation with data and code is available here.
Basics
Let’s start with the very basics: the math. In short, a Multi-Layer Perceptron layer applies a fixed non-linear function to a weighted sum of its inputs, while a KAN layer sums learnable non-linear univariate functions of its inputs.
Classical Multi-Layer Perceptrons rely on the widely known universal approximation theorems. With some simplification, they state that a continuous function can be approximated by a perceptron with a wide enough hidden layer and a non-linear activation:
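One common single-hidden-layer form of this statement is given below as a reference sketch. The symbols N (hidden width), a_i, w_i, b_i (weights and biases), and sigma (the activation) are notation chosen here for illustration and are not taken from the article:

```latex
f(x) \approx \sum_{i=1}^{N} a_i \, \sigma\left( w_i^{\top} x + b_i \right)
```

The approximation can be made arbitrarily accurate on a compact domain by increasing N, provided the activation satisfies the conditions of the particular theorem used.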
Kolmogorov-Arnold networks rely on another theorem, the Kolmogorov-Arnold representation theorem:
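In its standard form, for a continuous function f of n variables on the unit cube, the theorem reads as follows. The symbols Phi_q and phi_{q,p} follow the usual textbook statement and are not defined elsewhere in this article:

```latex
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
```

Every phi_{q,p} and Phi_q is a continuous function of a single variable, so the only truly multivariate operation in the representation is addition.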
So, from this formula, the authors of “KAN: Kolmogorov-Arnold Networks” derived the new architecture: learnable activations on edges and summation on nodes. An MLP, by contrast, applies a fixed non-linearity on nodes and learnable linear projections on edges.
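The contrast between the two node types can be sketched in a few lines of plain Python. This is only an illustration: the function names `mlp_node` and `kan_node` are hypothetical, and the edge functions are fixed placeholders rather than learned ones.

```python
import math


def mlp_node(inputs, weights, bias, activation=math.tanh):
    """MLP-style node: a fixed non-linearity applied to a learnable weighted sum."""
    return activation(sum(w * x for w, x in zip(weights, inputs)) + bias)


def kan_node(inputs, edge_functions):
    """KAN-style node: a plain sum of univariate functions, one per incoming edge."""
    return sum(phi(x) for phi, x in zip(edge_functions, inputs))


x = [0.5, -1.0]

# Learnable part of the MLP node: the weights and the bias.
mlp_out = mlp_node(x, weights=[2.0, 1.0], bias=0.0)

# Learnable part of the KAN node: the functions themselves (placeholders here).
kan_out = kan_node(x, [lambda t: t ** 2, abs])
```

In the MLP node the non-linearity is shared and fixed, while in the KAN node all of the modelling capacity sits in the per-edge functions.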
So, okay. We need to learn an activation. How could we do that? In the original paper, B-splines were proposed (see more about splines here and here). A spline function of order n is a piecewise polynomial function of degree (n-1). The places where the pieces meet are known as knots. The key property of spline functions is that they and their derivatives may be continuous, depending on the multiplicities of the knots.
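To make the definition concrete, here is a minimal, dependency-free sketch of B-spline basis functions computed with the Cox-de Boor recursion. The uniform knot vector, the coefficients, and the function names are assumptions chosen for illustration; practical implementations typically use vectorized tensor code instead.

```python
def bspline_basis(i, k, t, knots):
    """Value at t of the i-th B-spline basis function of degree k (Cox-de Boor)."""
    if k == 0:
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left_den = knots[i + k] - knots[i]
    right_den = knots[i + k + 1] - knots[i + 1]
    left = ((t - knots[i]) / left_den * bspline_basis(i, k - 1, t, knots)
            if left_den > 0 else 0.0)
    right = ((knots[i + k + 1] - t) / right_den * bspline_basis(i + 1, k - 1, t, knots)
             if right_den > 0 else 0.0)
    return left + right


degree = 3                                  # cubic pieces, i.e. order 4 in the wording above
knots = [float(j) for j in range(-3, 8)]    # uniform knots -3, -2, ..., 7
num_basis = len(knots) - degree - 1         # number of basis functions on this knot vector

# In a KAN these coefficients are the trainable parameters; here they are arbitrary.
coeffs = [0.0, 0.5, 1.0, 0.2, -0.3, 0.8, 0.1]


def spline(t):
    """Learnable activation: a linear combination of the basis functions."""
    return sum(c * bspline_basis(i, degree, t, knots) for i, c in enumerate(coeffs))


value = spline(1.7)
```

Inside the interval [0, 4) the basis functions of this knot vector are non-negative and sum to one, so the spline value there is a weighted average of nearby coefficients.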
Pros and cons
The authors claim several advantages of KANs over MLPs:
- Better Accuracy: KANs have showcased remarkable accuracy in various tasks compared to MLPs. KANs can represent complex multivariate functions more effectively than MLPs (thanks to the Kolmogorov-Arnold representation theorem), and that leads to better predictions.
- Interpretability: MLPs are black boxes: it’s hard to say what’s going on inside the models, but KANs can offer us something interesting. One can decompose complex functions into simpler univariate components and derive from them some insights about models and data.
- Flexibility and Generalization: KANs provide more flexibility and better generalization capacities than traditional MLPs. Their adaptive learning of activation functions enables a more effective capture of nonlinear relationships in the data, leading to enhanced performance in generalization (but it’s not that easy).
- Robustness to Noisy Data and Adversarial Attacks: KANs demonstrate increased resilience to noisy data and adversarial attacks compared to MLPs. Their capacity to learn more robust data representations via adaptive activation functions makes them less vulnerable to disruptions and adversarial interferences.
But there are some challenges with KANs (no free lunch, remember?):
- Sensitivity to Hyperparameters: Like all neural network architectures, KANs are sensitive to hyperparameters, including learning rate, regularization strength, and network architecture. Choosing the right hyperparameters can significantly affect KANs’ performance and convergence properties, requiring careful tuning and experimentation. It is also worth mentioning that there are a lot of parameters to tune (L1/L2 weights and activation parameters, dropouts, and so on), and there are no go-to recipes for training KANs at the moment. Those who remember the dawn of ConvNets may get some flashbacks.
- Computational Overhead: The computational overhead of KANs, especially during training and inference, can be challenging in resource-constrained environments. The adaptive nature of activation functions and spline parameters may require more computational resources than traditional MLPs, leading to longer training times and higher computational costs.
- Model Complexity and Scalability: KANs are scalable in terms of architecture flexibility, but deeper architectures with multiple layers and complex activation functions can increase model complexity and computational overhead. Scaling KANs to handle large-scale datasets and complex tasks while maintaining computational efficiency and model interpretability is a significant challenge.
Variations
To address some of these problems, several modifications have been proposed.
Firstly, a Fast KAN [3] version was introduced, in which B-splines are replaced by radial basis functions (RBFs). This modification helps to reduce the computational overhead of splines.
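As a rough illustration of why this is cheaper, an RBF-based activation needs only one exponential per basis function instead of a recursive spline evaluation. The sketch below assumes Gaussian RBFs on a uniform grid; the grid range, width, coefficients, and function names are hypothetical choices, not settings taken from the Fast KAN repository.

```python
import math


def rbf_basis(x, centers, width):
    """Gaussian radial basis functions evaluated at a scalar x."""
    return [math.exp(-((x - c) / width) ** 2) for c in centers]


def rbf_activation(x, centers, width, coeffs):
    """Learnable activation: a weighted sum of the RBF responses."""
    return sum(w * b for w, b in zip(coeffs, rbf_basis(x, centers, width)))


centers = [-2.0 + 0.5 * i for i in range(9)]   # uniform grid from -2 to 2
width = 0.5                                    # shared width of every bump
coeffs = [0.1 * i for i in range(9)]           # trainable in a real layer, arbitrary here

basis = rbf_basis(0.0, centers, width)
value = rbf_activation(0.0, centers, width, coeffs)
```

Each basis function peaks at its own center and decays smoothly, so the coefficients play a role similar to spline control points.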
There are also several Polynomial KANs (Legendre [2], Chebyshev [4, 2], Jacobi [7], Gram [5], Bernstein [8], etc.) and Wavelet-based KANs [6]. For more detailed information on these variations, please visit the respective repositories listed in the references.
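Polynomial variants follow the same recipe with a different basis. As one example, a Chebyshev basis can be generated with the three-term recurrence T_{n+1}(u) = 2u T_n(u) - T_{n-1}(u). The sketch below squashes the input with tanh first so that it falls in the interval where these polynomials stay bounded; that normalization and the function name are assumptions made for this illustration.

```python
import math


def chebyshev_basis(x, degree):
    """Return [T_0(u), ..., T_degree(u)] with u = tanh(x)."""
    u = math.tanh(x)                 # map the input into (-1, 1)
    values = [1.0, u]                # T_0 and T_1
    for _ in range(2, degree + 1):
        values.append(2.0 * u * values[-1] - values[-2])
    return values[: degree + 1]


features = chebyshev_basis(0.3, degree=4)

# A learnable activation is then a weighted sum of these features;
# the weights below are arbitrary stand-ins for trained coefficients.
weights = [0.2, -0.1, 0.4, 0.0, 0.3]
activation = sum(w * t for w, t in zip(weights, features))
```

Other polynomial families differ mainly in the recurrence coefficients and in how the input is normalized.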
Convolutions
What are convolutions? The most common type of convolution used is the 2D convolution layer, usually abbreviated as Conv2D. In this layer, a filter or a kernel “slides” over the 2D input data, performing an elementwise multiplication. The results are summed up into a single output pixel. The kernel performs the same operation for every location it slides over, transforming a 2D matrix of features into a different one. Although 1D and 3D convolutions share the same concept, they have different filters, input data, and output data dimensions. However, we’ll focus on 2D for simplicity. If you want to dive deeper, here is a good post: Intuitively Understanding Convolutions for Deep Learning
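The sliding-window description above can be written out directly. The following is a deliberately naive, single-channel sketch in plain Python (no padding, stride 1); the function name and the tiny example are illustrative, and deep learning frameworks use far more efficient implementations.

```python
def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` and sum the elementwise products at each position."""
    n_rows, n_cols = len(image), len(image[0])
    m_rows, m_cols = len(kernel), len(kernel[0])
    out = []
    for i in range(n_rows - m_rows + 1):
        row = []
        for j in range(n_cols - m_cols + 1):
            row.append(sum(kernel[a][b] * image[i + a][j + b]
                           for a in range(m_rows)
                           for b in range(m_cols)))
        out.append(row)
    return out


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 0],
          [0, -1]]

# Each output pixel is image[i][j] - image[i + 1][j + 1].
result = conv2d_valid(image, kernel)
```

The same kernel weights are reused at every position, which is what makes the operation translation-equivariant and parameter-efficient.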
Typically, after a convolutional layer, a normalization layer (like BatchNorm, InstanceNorm, etc.) and a non-linear activation (ReLU, LeakyReLU, SiLU, and many more) are applied. More formally: suppose we have an input image y of size N x N. We omit the channel axis for simplicity; it only adds another summation across channels. So, first, we need to convolve it with our kernel W of size m x m:
Then, we apply batch norm and a non-linearity, for example, ReLU:
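Both steps can be written in symbols as follows. This is a reconstruction from the description above: the names x (convolution output) and z (layer output) and the zero-based indices a, b are notational assumptions.

```latex
x_{ij} = \sum_{a=0}^{m-1} \sum_{b=0}^{m-1} W_{ab} \, y_{(i+a)(j+b)},
\qquad
z_{ij} = \mathrm{ReLU}\left( \mathrm{BatchNorm}(x)_{ij} \right)
```

All learnable parameters of the convolution sit in the weights W_{ab}; the non-linearity is fixed and applied afterwards.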
Kolmogorov-Arnold Convolutions work differently: the kernel consists of a set of univariate non-linear functions. This kernel “slides” over the 2D input data, applying each of the kernel’s functions to the corresponding input element. The results are then summed up into a single output pixel. More formally: suppose we have an input image y (again) of size N x N. We omit the channel axis for simplicity; it only adds another summation. So, the KAN-based convolutions are defined as:
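Written with the same image y and kernel size m, this reads as follows. The formula is reconstructed from the description above; the output name x and the zero-based indices a, b are notational assumptions.

```latex
x_{ij} = \sum_{a=0}^{m-1} \sum_{b=0}^{m-1} \varphi_{ab}\left( y_{(i+a)(j+b)} \right)
```

Compared with an ordinary convolution, the product with a weight is replaced by the application of a function, one function per kernel position.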
Each phi is a univariate non-linear learnable function with its own trainable parameters. In the original paper, the authors propose to use this form of the functions:
The authors propose SiLU as the activation b(x):
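For reference, the parameterization from the original paper [1] combines a base activation with a B-spline term. In the notation below, w_b and w_s are trainable scale factors, c_i are trainable spline coefficients, and B_i are the B-spline basis functions:

```latex
\varphi(x) = w_b \, b(x) + w_s \, \mathrm{spline}(x),
\qquad
\mathrm{spline}(x) = \sum_{i} c_i B_i(x),
\qquad
b(x) = \mathrm{silu}(x) = \frac{x}{1 + e^{-x}}
```

The base term behaves like a residual connection around the spline, which tends to make optimization easier.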
Along with this way of learning phi, using RBFs and different polynomials instead of splines was proposed recently. So that’s it: the KAN Convolutions.
To sum up, the “traditional” convolution is a matrix of weights, while Kolmogorov-Arnold convolutions are a matrix of functions.
That’s the primary difference. The key question here is how we should construct these univariate non-linear functions. The answer is the same as for KANs: B-splines, polynomials, RBFs, wavelets, etc.
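The "matrix of functions" idea can be illustrated with a naive single-channel sketch in plain Python. The function names and the placeholder kernel functions are hypothetical; in a real layer every function would be a learnable spline, RBF, or polynomial expansion, and the computation would be vectorized.

```python
import math


def kan_conv2d_valid(image, kernel_functions):
    """Like a plain convolution, but each kernel cell applies its own function."""
    n_rows, n_cols = len(image), len(image[0])
    m_rows, m_cols = len(kernel_functions), len(kernel_functions[0])
    return [[sum(kernel_functions[a][b](image[i + a][j + b])
                 for a in range(m_rows)
                 for b in range(m_cols))
             for j in range(n_cols - m_cols + 1)]
            for i in range(n_rows - m_rows + 1)]


def silu(x):
    return x / (1.0 + math.exp(-x))


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0],
         [7.0, 8.0, 9.0]]

# A 2 x 2 kernel of fixed placeholder functions (learnable in a real layer).
kernel_functions = [[silu, math.tanh],
                    [lambda v: v * v, lambda v: 0.5 * v]]
output = kan_conv2d_valid(image, kernel_functions)

# Special case: if every function is linear, v -> w * v,
# the layer reduces to an ordinary convolution with weights w.
weights = [[1.0, 0.0],
           [0.0, -1.0]]
linear_functions = [[(lambda v, w=w: w * v) for w in row] for row in weights]
linear_output = kan_conv2d_valid(image, linear_functions)
```

The linear special case shows that a traditional convolution is contained in this formulation, with the extra capacity coming from allowing each kernel cell to be non-linear.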
Experiments
Let’s dive into experiments.
MNIST
First of all, let's start with MNIST!
Baseline models were chosen to be simple networks with 4 convolutional layers. To reduce spatial dimensionality, the second and third convolutions used stride=2 (as in the SimpleConvKAN code later in this article).
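How much a 3x3 convolution shrinks a feature map depends on its stride, padding, and dilation. The helper below is a small sketch of the output-size formula given in the PyTorch Conv2d documentation; the function name and the 28-pixel MNIST example are illustrative assumptions.

```python
def conv_output_size(size, kernel_size=3, stride=1, padding=1, dilation=1):
    """Output length along one spatial axis of a convolution."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


same = conv_output_size(28)                 # padding=1 keeps the size
strided = conv_output_size(28, stride=2)    # stride=2 halves the size
dilated = conv_output_size(28, dilation=2)  # dilation=2 trims the border only

# Spatial size after four layers where the second and third halve the resolution.
size = 28
for stride in (1, 2, 2, 1):
    size = conv_output_size(size, stride=stride)
```

With padding=1, a stride of 2 halves each spatial dimension, whereas a dilation of 2 only removes two pixels per axis.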
The number of channels in the convolutions was the same for all models: 32, 64, 128, 256. After the convolutions, Global Average Pooling was applied, followed by a linear output layer. In addition, dropout layers were used: with p = 0.25 in convolutional layers, and with p = 0.5 before the output layer. You can check out the implementations of all models here.
Also, augmentations have been used:
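from torchvision.transforms import v2

# Training-time augmentation pipeline for MNIST
transform_train = v2.Compose([
    # Small random rotations, shifts, and rescaling
    v2.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    # Mild brightness and contrast perturbation
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    # Convert to a tensor, then map pixel values to roughly [-1, 1]
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])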
In the case of classic convolutions, a traditional structure was used: convolution — batch normalization — ReLU.
In addition, we investigate the impact of different normalization layers inside KAN Convolutions and the impact of L1 weight regularization. The Norm Layer column in all tables indicates which normalization layer was used during the experiment; the Affine column indicates whether the affine parameter of the normalization layer was set to True or False.
All experiments were conducted on an NVIDIA RTX 3090 with identical training parameters.
Classical convolutions slightly outperform the KAN-based model, which has 34 times more parameters and requires almost four times more inference time. This performance doesn’t seem as “revolutionary” as expected. Next, let’s test the FastKAN version (splines replaced by RBFs).
While it does work faster, it still underperforms compared to traditional convolutions.
Okay then, let’s choose another basis function: Gram polynomials, a type of discrete orthogonal polynomials. In theory, they should work pretty well in CV tasks, because images are discrete too.
Gram-based KAN ConvNets outperform traditional ones but are 2.5 times slower and have almost five times more parameters. L1-weights regularization slightly decreases model performance, but this is an area for further improvement.
CIFAR 100
Next, let’s discuss CIFAR 100.
Baseline models were chosen to be simple networks with 8 convolutional layers. To reduce spatial dimensionality, the second, third, and sixth convolutions used stride=2.
The number of channels in the convolutions was the same for all models: 16, 32, 64, 128, 256, 256, 512, 512. Everything else is the same as in the MNIST case: after the convolutions, Global Average Pooling was applied, followed by a linear output layer. In addition, dropout layers were used: with p=0.25 in conv layers, and p=0.5 before the output layer. In the case of classic convolutions, a traditional structure was used: convolution, then batch normalization, then ReLU. You can check out the implementations of all models here.
Also, augmentations have been used:
from torchvision.transforms import v2
from torchvision.transforms.autoaugment import AutoAugmentPolicy

# Training-time augmentation pipeline for CIFAR 100
transform_train = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    # Pick one automatic augmentation policy at random for each sample
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET),
                     v2.AutoAugment(AutoAugmentPolicy.SVHN),
                     v2.TrivialAugmentWide()]),
    # Convert to a tensor, then map pixel values to roughly [-1, 1]
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])
Only the results of the Gram-based KAN Convs version are shown vs baseline models.
Gram-based KAN Convolutions perform better, albeit with slightly more time overhead and significantly more parameters (over 20 times more). BatchNorm2D appears to be the best option for inner feature normalization within KAN Convolutions. Using Gram polynomials as a basis function for Kolmogorov-Arnold Convolutions seems promising for further experiments on ImageNet1k and other “real” datasets.
For full tables and results using other basis functions (like wavelets and other types of polynomials) on MNIST, CIFAR10, and CIFAR100, refer to the reports here.
Wanna do your own runs?
In my repository, you can find implementations of 1D, 2D, and 3D Kolmogorov-Arnold Convolutional Layers with different basis functions (B-splines, RBFs, Wavelets, and several polynomials: Legendre, Gram, Chebyshev, Bernstein, Jacobi).
They are pretty easy to use: each is a drop-in replacement for a PyTorch convolutional layer, with extra parameters to control the behavior of the learnable functions. Here is a simple KANConvNet that has been used in the MNIST experiments:
import torch
import torch.nn as nn

from kan_convs import KANConv2DLayer


class SimpleConvKAN(nn.Module):
    def __init__(
            self,
            layer_sizes,
            num_classes: int = 10,
            input_channels: int = 1,
            spline_order: int = 3,
            groups: int = 1):
        super(SimpleConvKAN, self).__init__()

        # Four KAN convolutions followed by global average pooling
        self.layers = nn.Sequential(
            KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
                           dilation=1),
            KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=1, dilation=1),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Linear classification head with dropout in front of it
        self.output = nn.Linear(layer_sizes[3], num_classes)
        self.drop = nn.Dropout(p=0.25)

    def forward(self, x):
        x = self.layers(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        x = self.output(x)
        return x
In the repository, you can find VGG-like, ResNet-like, DenseNet-like, Unet-like, and U2Net-like models, accelerate-based training scripts, and weights pre-trained on ImageNet1k (VGG11-like for now, but more models are in training right now).
Conclusion
So, can KANs do CV?
It seems that yes, KANs can!
Can they do better than CNNs? Well, we still need to figure that out.
The Multi-Layer Perceptron (MLP) has been in use for years and is overdue for an upgrade. We’ve seen this kind of shift before. For instance, six years ago, Long Short-Term Memory (LSTM) networks, once a staple in sequence modeling, were replaced by transformers as the standard building block for language model architecture. A similar shift for MLPs would be intriguing.
Convolutional networks, which have dominated for many years (and still serve as the workhorse of Computer Vision), were eventually challenged by Vision Transformer (ViT) models. Perhaps it’s time for a new leader in the field? However, before that happens, the community needs to find effective methods to train Kolmogorov-Arnold Networks (KAN), Convolutional Kolmogorov-Arnold Networks (ConvKANs), and ViT-KANs, and address the challenges these models present.
While I’m excited about this new architecture and initial experiments show promising results, I remain somewhat skeptical. Further experiments are necessary. Stay tuned, we are going to dive deeper.
References
[1] Ziming Liu et al., “KAN: Kolmogorov-Arnold Networks”, 2024, arXiv. https://arxiv.org/abs/2404.19756
[2] https://github.com/1ssb/torchkan
[3] https://github.com/ZiyaoLi/fast-kan
[4] https://github.com/SynodicMonth/ChebyKAN
[5] https://github.com/Khochawongwat/GRAMKAN
[6] https://github.com/zavareh1/Wav-KAN
[7] https://github.com/SpaceLearner/JacobiKAN
[8] https://github.com/IvanDrokin/torch-conv-kan
[9] https://github.com/KindXiaoming/pykan
[10] https://github.com/Blealtan/efficient-kan
Acknowledgments
I want to thank Elena Ericheva and Anton Klochkov for the helpful feedback and reviews of the piece.