Few papers released this year have sparked more controversy than the one presenting Kolmogorov-Arnold Networks, or KANs, a new type of neural network that could replace one of the main pillars of the AI revolution, an essential component of tools ranging from ChatGPT to Stable Diffusion and Sora.
In fact, KANs could change how we build AI entirely, as they aim to replace the very foundation on which the AI industry has been built.
This may sound like a handful of overly pompous claims, but as you'll see right now, this is exactly what we are looking at.
You are probably sick of AI newsletters talking about how this or that **just** happened. These newsletters abound because superficially recapping events that have already taken place is easy, but the value they provide is limited and the hype exaggerated.
However, newsletters talking about what will happen are a rare sight. If you're into easy-to-understand insights looking into the future of AI before anyone else does, TheTechOasis newsletter might be perfect for you.
Why Does AI Work?
To understand the huge impact that KAN networks can have on the industry, we first need to understand the main building blocks of current frontier AI models.
I recommend my blog post as a broader introduction to AI, but here is the short version: of all the fields in AI, none is as influential and important today as Deep Learning (DL).
The Blind Belief That Delivered
It is hard to explain the blind faith that visionaries like Yann LeCun and Geoffrey Hinton kept in Deep Learning during the decades in which it failed to deliver on its promises.
But it certainly panned out.
Today, nothing creates more value (at least in private and open markets) than Deep Learning, and one can confidently claim that AI has already changed the world in several ways.
Just to name a few:
- Millions of biologists worldwide use AlphaFold. With its newer version, we now have a much better understanding of proteins, the building blocks of every living thing.
- AI is already superhuman in specific tasks, like board games such as Go or Chess, enabling us to discover new moves that have been unknown for thousands of years.
- Using semantic space theory, AI is also helping us transform our understanding of emotions and even discover new smells.
And all these achievements can't be explained without Deep Learning. So, what is it?
In short, DL involves using neural networks to parameterize unknown multivariate functions, which is a fancy way of describing functions that take in multiple inputs and generate at least one output (possibly many).
In layman's terms, most current AI models are simply a way of learning a complex relationship between a set of input variables and a desired output that we can observe (and thus know exists) but can't explain.
For instance, ChatGPT is a model that takes a sequence of inputs (words) and generates the next one.
It does so because researchers posited that next-word prediction would be a good proxy for teaching machines to speak our language and make them intelligent (or at least help them reasonably imitate human intelligence).
This is a good fit for neural networks because we know both the inputs and the outputs (we know how words follow each other and have plenty of training data to use), but we can't claim to know the relationships between words that would let us statistically predict the next one every single time.
In other words, although we can observe and perform that exact behavior ourselves (writing or speaking the language), we can't put pen to paper and write down the statistical function that generates the next word.
And this is where neural networks come in to parametrize that precise function.
'Parametrize' means finding the set of parameters that maps the inputs to the outputs, just like 'm' and 'b' parametrize the linear relationship between inputs 'x' and outputs 'y' in 'y = mx + b'.
Sadly, although we know that 'm' and 'b' represent the slope and the y-intercept, we don't know the full significance of each parameter in a neural network, which is why we call them 'black boxes.'
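To make 'parametrize' concrete, here is a minimal Python sketch that recovers 'm' and 'b' from example data using gradient descent, the same basic procedure used to adjust neural network weights, just with two parameters instead of billions. The toy data, learning rate, and number of steps are arbitrary assumptions chosen for illustration.

```python
# Hypothetical toy data generated from y = 2x + 1 (true m = 2, b = 1)
xs = [i / 10 for i in range(-20, 21)]
ys = [2.0 * x + 1.0 for x in xs]


def fit_line(xs, ys, lr=0.1, steps=2000):
    """Learn m and b by gradient descent on the mean squared error."""
    m, b = 0.0, 0.0  # start from an uninformed guess
    n = len(xs)
    for _ in range(steps):
        # Gradients of MSE = mean((m*x + b - y)^2) with respect to m and b
        grad_m = sum(2 * (m * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (m * x + b - y) for x, y in zip(xs, ys)) / n
        m -= lr * grad_m
        b -= lr * grad_b
    return m, b


m, b = fit_line(xs, ys)
print(f"m = {m:.3f}, b = {b:.3f}")  # prints: m = 2.000, b = 1.000
```

A neural network does the same thing, except that nobody can read off what each of its parameters "means" once training is done.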
But we still haven't answered the real question: why do neural networks work?
The Universal Approximation Theorem
As mentioned, neural networks are humanity's best bet at approximating these unknown functions that govern input/output relationships we know exist.
However, of all the popular neural network architectures today, one component has always been present, no matter the industry or use case.
If we take Large Language Models (LLMs) as an example, even though the attention mechanism receives almost all the credit as the secret behind their success, no frontier AI model today would exist without MultiLayer Perceptrons, or MLPs.
Initially "proposed" by Frank Rosenblatt and actually invented by Alexey Grigorevich Ivakhnenko, MLPs are considered the main element explaining Deep Learning's triumph.
Below is a standard depiction of a shallow MLP (one hidden layer), in which the relationship between the inputs and the outputs (the black box we showed earlier) is defined by scaled linear combinations of hidden units, usually known as 'neurons.'
In particular, each of the hidden units, depicted as 'h1' to 'h5', computes the weighted sum w1·x1 + w2·x2 + w3·x3 + w4·x4 + b.
Broadly speaking, each neuron is a linear combination of the inputs from the previous layer (in this case, the inputs to the model), each weighted by one of the parameters w1 to w4, plus the neuron's bias 'b'.
These 'weights' are the variables we adapt during training so that the model learns to perform the desired prediction.
Then, the result is passed through an activation function that determines whether the neuron activates or not.
This activation function is crucial for approximating non-linear functions. In other words, it lets the model adapt to complex non-linear relationships between inputs and outputs, which are the norm in our world.
It also explains why they are called 'neurons', as they loosely imitate the firing/non-firing behavior of brain neurons.
For example, the graph below shows the fairly common ReLU activation function.
For each neuron that uses this function, if the result is positive, the neuron fires; if the result is negative, the activation function brings its value to zero, and thus, the neuron doesn't activate for that particular prediction.
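In code, a single hidden unit with a ReLU activation boils down to a couple of lines. A minimal Python sketch; the inputs, weights, and biases are made-up values chosen only for illustration:

```python
def relu(z):
    """ReLU: keep positive values, clamp everything else to zero."""
    return max(0.0, z)


def neuron(inputs, weights, bias):
    """One hidden unit: weighted sum of the inputs plus a bias, passed through ReLU."""
    z = sum(w * x for w, x in zip(weights, inputs)) + bias
    return relu(z)


x = [1.0, 2.0, 3.0, 4.0]        # four inputs, matching weights w1 to w4
w = [0.5, -0.25, 0.25, 0.125]   # illustrative weights

print(neuron(x, w, bias=0.25))  # 1.5: positive, so the neuron fires
print(neuron(x, w, bias=-2.0))  # 0.0: negative sum, clamped by ReLU
```

Training adjusts the weights and the bias; the ReLU itself never changes, which is precisely the part KANs make learnable.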
With all that understood, why do we use this particular architecture?
The reason is the irresistible promise of the universal approximation theorem.
A Principle to Rule Them All
The universal approximation theorem (UAT) formally proves "that a neural network with at least one hidden layer can approximate any continuous function to arbitrary precision with enough hidden units."
In layman's terms, with enough neurons like the one above, a network can approximate any continuous mapping between inputs and outputs, be it finding patterns in housing data to predict prices or finding patterns in amino acid sequences to predict protein structures (AlphaFold). Note that the theorem only guarantees that such a network exists; finding its weights is the job of training.
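The UAT says nothing about how to find the weights, but a hand-built sketch can still show the intuition. Below, a one-hidden-layer ReLU network is constructed directly (not trained) so that it traces sin(x) on [0, π] piece by piece, and the worst-case error shrinks as hidden units are added. The target function and this particular construction are illustrative choices, not part of the theorem's proof.

```python
import math


def relu(z):
    return max(0.0, z)


def build_relu_net(f, a, b, hidden_units):
    """Hand-build (not train) a one-hidden-layer ReLU network that linearly interpolates f on [a, b].

    Hidden unit k computes relu(x - t_k) for a knot t_k; its output weight is the
    change in slope between consecutive linear pieces, and the output bias is f(a).
    """
    knots = [a + (b - a) * k / hidden_units for k in range(hidden_units + 1)]
    values = [f(t) for t in knots]
    slopes = [(values[k + 1] - values[k]) / (knots[k + 1] - knots[k])
              for k in range(hidden_units)]
    out_weights = [slopes[0]] + [slopes[k] - slopes[k - 1] for k in range(1, hidden_units)]
    out_bias = values[0]

    def net(x):
        return out_bias + sum(w * relu(x - t) for w, t in zip(out_weights, knots[:-1]))

    return net


def max_error(f, net, a, b, samples=1000):
    """Worst-case gap between f and the network on an evenly spaced sample of [a, b]."""
    xs = [a + (b - a) * i / samples for i in range(samples + 1)]
    return max(abs(f(x) - net(x)) for x in xs)


for units in (2, 4, 8, 16, 32):
    net = build_relu_net(math.sin, 0.0, math.pi, units)
    print(f"{units:>2} hidden units -> max error {max_error(math.sin, net, 0.0, math.pi):.5f}")
```

The error keeps falling as hidden units are added, but nothing tells you in advance how many units a given precision will require.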
In the particular case of ChatGPT, with enough neurons (billions, in some cases), one can effectively define a function that predicts the next word, given a set of previous words, across a large share of the Internet's text.
If you want to genuinely understand why MLPs work in an intuitive and visual way, I recently published an article on my blog addressing why standard neural networks work.
But now, a group of researchers from universities like MIT or Caltech say there's a better way.
And that way is Kolmogorov-Arnold Networks.
KANs, A New Foundation?
As you will see in a minute, I was not overstating the impact that KAN networks can have on the future of AI.
The Current Sacrifices
Despite their unequivocal importance, MLPs are a pain in the butt. This becomes apparent just by looking at Large Language Models (LLMs).
Today, according to research by Meta, 98% of the model's FLOPs (the number of operations the GPUs perform to run the model) come from these layers alone, with just 2% coming from other critical pieces like the attention mechanism, the core operator behind LLMs' success.
The reason is none other than our friend, the UAT. While it promises that you will eventually approximate the mapping you desire, it doesn't specify how many neurons you will need, and the answer is usually a lot.
So… what if there was a better way?
A New Theorem
Although the Kolmogorov-Arnold Representation Theorem was proven in the 1950s, applying it to Deep Learning was long considered impractical, so it was not seen as a viable way to approximate these complex relationships.
Now, the team behind KANs claims to have solved this problem by parameterizing the theorem as a neural network. In other words, they have found a neural network architecture that applies this theorem.
If we think of neural networks as graphs, where neurons are nodes and the connections between neurons are edges, the biggest difference between MLPs and KANs becomes clear.
Recall that in the MLPs discussed earlier, each neuron combines learnable weights, represented as the connections to that neuron, and then passes the result through a fixed activation function.
Consequently, while weights are located at the edges, the nodes are the actual neurons, where a fixed activation function decides whether that node fires or not.
On the other hand, KANs modify this learning process. While MLPs learn the weights on the edges of the network, KANs place learnable activation functions, represented as splines (more on that later), on the edges, and the nodes simply sum the outputs of these splines.
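The contrast can be sketched in a few lines of Python. This is a simplified illustration under explicit assumptions, not the authors' implementation: the edge functions here are piecewise-linear (degree-1 B-spline) curves on a fixed grid over [-1, 1], whereas the paper uses cubic B-splines, and the class names, grid size, and random initialization are hypothetical choices.

```python
import random


def hat_basis(x, grid):
    """Degree-1 B-spline ("hat") basis on a uniform grid: a simplified stand-in for cubic B-splines."""
    h = grid[1] - grid[0]
    return [max(0.0, 1.0 - abs(x - g) / h) for g in grid]


class MLPLayer:
    """MLP layer: learnable weights on the edges, a fixed activation (ReLU here) on the nodes."""

    def __init__(self, n_in, n_out, rng):
        self.w = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def __call__(self, x):
        return [max(0.0, sum(wi * xi for wi, xi in zip(row, x)) + bj)
                for row, bj in zip(self.w, self.b)]


class KANLayer:
    """KAN-style layer: a learnable univariate function on every edge; nodes only sum."""

    def __init__(self, n_in, n_out, rng, grid_size=5, lo=-1.0, hi=1.0):
        self.grid = [lo + (hi - lo) * k / grid_size for k in range(grid_size + 1)]
        # coef[j][i] holds the spline coefficients of the edge from input i to output j
        self.coef = [[[rng.uniform(-1, 1) for _ in self.grid] for _ in range(n_in)]
                     for _ in range(n_out)]

    def edge(self, j, i, x):
        """Evaluate the learnable function sitting on edge i -> j."""
        return sum(c * b for c, b in zip(self.coef[j][i], hat_basis(x, self.grid)))

    def __call__(self, x):
        return [sum(self.edge(j, i, xi) for i, xi in enumerate(x))
                for j in range(len(self.coef))]


x = [0.3, -0.7]  # inputs are assumed to lie inside the grid range [-1, 1]
mlp = MLPLayer(2, 3, random.Random(0))
kan = KANLayer(2, 3, random.Random(0))
print("MLP outputs:", mlp(x))
print("KAN outputs:", kan(x))
```

Each KAN edge carries several parameters (one per basis function) instead of a single weight, so a KAN layer of the same width has more parameters than an MLP layer; the claim is that KANs make up for this by needing much smaller networks.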
Ok, but why?
Because KANs follow the Kolmogorov-Arnold Representation Theorem (KART), which states that any continuous multivariate function on a bounded domain can be written exactly as sums and compositions of continuous univariate functions (functions of a single variable).
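For reference, the theorem is usually written as follows, for a continuous function f of n variables on a bounded domain such as [0, 1]^n, where every Φ_q and φ_{q,p} is a continuous function of a single variable:

```latex
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
```

The inner and outer sums are exactly the "nodes that add up univariate edge functions" described above; KANs generalize this fixed two-layer shape to networks of arbitrary width and depth.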
Visualizing the Key Intuition
This feels extremely hard to understand, but please bear with me; you'll see everything crystal clear in a minute.
Besides giving plenty of proof that they work, the research team also included some very intuitive examples like the one below that will make it very easy for you to understand why KANs work:
If we want to train a KAN to approximate the equation exp(sin(πx) + y²) above (that is, to take a form that outputs the same results as the equation itself), we can break this function into three steps:
- Square 'y'
- Compute the sine of the product 'π * x'
- Apply the exponential function to the sum of steps 1 and 2
Fascinatingly, after training, the KAN network autonomously reduces itself into three functions:
- a square function applied to 'y'
- a sine function with a frequency proportional to the constant π, applied to 'x'
- an exponential function: as nodes are sums, the two results are added, and the sum is used as input to the exponential function, giving the final result
Simply put, the KAN network has adapted to simulate the exact functions, which in turn means that running 'x' and 'y' values through that network will output the same results as if we plugged the values into the original symbolic equation.
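Assuming training has converged to exactly these three functions (in practice, a KAN learns spline approximations of them, which can then be matched to symbolic formulas), the resulting network can be written out in Python and checked against the original equation. The dictionary layout and function names below are hypothetical, chosen only to mirror the edges and nodes of the network.

```python
import math

# Edge functions after training, assumed here to match the symbolic forms exactly
layer1_edges = {
    "x": lambda x: math.sin(math.pi * x),  # edge from input x to the hidden node
    "y": lambda y: y ** 2,                  # edge from input y to the hidden node
}
layer2_edge = math.exp                      # edge from the hidden node to the output


def kan_forward(x, y):
    hidden = layer1_edges["x"](x) + layer1_edges["y"](y)  # a node just sums its incoming edges
    return layer2_edge(hidden)


def target(x, y):
    return math.exp(math.sin(math.pi * x) + y ** 2)


for x, y in [(0.0, 0.0), (0.5, 1.0), (-0.3, 0.8)]:
    print(x, y, kan_forward(x, y), target(x, y))
```

Both columns match because the network is, edge by edge, the same composition of univariate functions as the formula.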
Of course, practically speaking, this example is pointless because we already know the symbolic equation that governs this input/output relationship.
But what if we didn't know the function? That's precisely the point of AI models like ChatGPT, right?
That is why neural networks are increasingly used to discover new laws, and even to rediscover known ones like gravity. But you don't have to take my word for it; I can show you.
In 2022, a group of researchers rediscovered Newton's gravity laws with a neural network. They fed it the observations (when 'x' happens, 'y' happens) based on known solar system dynamics, and the model 'discovered' the laws.
In fact, the power of neural networks to uncover nuanced patterns in data allowed it to discover gravity despite not being told crucial attributes like planet mass (it discovered those, too).
That's why neural networks are so valuable; they pick up on patterns that humans miss and eventually converge on an approximation of the function (the equation) that governs that mapping.
However, just as with the UAT, the promise of arbitrary precision comes with a catch: the problem isn't the architectural principle but knowing how big the network needs to be.
In the case of MLPs, that means adding more hidden units or layers. But how does this work for KANs?
From Coarse to Fine-Grained
Using B-splines as the learnable functions wasn't an arbitrary choice.
B-splines are not only smooth (and thus continuous) but also extremely malleable: a spline can change shape on one side of the curve without impacting the other parts (or at least not much), which means these functions can be tuned very precisely.
For the sake of length, I can't go into detail about why B-splines in particular are so great. However, this amazing YouTube video explains it better than I ever could.
That said, let's visualize this a little bit more.
Whereas MLPs increase granularity by adding many more neurons, KANs can increase precision through 'grid extension.' Splines, or univariate functions (one variable as input, one variable as output, as in 'f(x) = y'), are defined over a bounded interval that is divided into grid regions.
Thus, you can increase a spline's resolution, in principle indefinitely, by increasing its number of grid regions, as shown below, where a spline initially divided into 5 regions is refined into 10.
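Here is a simplified Python sketch of why more grid regions buy more precision. It uses degree-1 (piecewise-linear) B-splines, whose coefficients can simply be set to the target's values at the grid points, rather than the cubic splines and the refinement procedure used in the paper; the target function and interval are arbitrary choices.

```python
import math


def hat_spline(coefs, grid):
    """Degree-1 B-spline: sum of coefficient-weighted hat functions on a uniform grid."""
    h = grid[1] - grid[0]

    def s(x):
        return sum(c * max(0.0, 1.0 - abs(x - g) / h) for c, g in zip(coefs, grid))

    return s


def fit_on_grid(f, lo, hi, regions):
    """Interpolating fit: for hat functions, the coefficient at each grid point is f there."""
    grid = [lo + (hi - lo) * k / regions for k in range(regions + 1)]
    return hat_spline([f(g) for g in grid], grid)


def max_error(f, s, lo, hi, samples=2000):
    xs = [lo + (hi - lo) * i / samples for i in range(samples + 1)]
    return max(abs(f(x) - s(x)) for x in xs)


def target(x):
    """Arbitrary 1-D function to approximate."""
    return math.sin(3 * x) * math.exp(-x)


for regions in (5, 10, 20):
    s = fit_on_grid(target, 0.0, 2.0, regions)
    print(f"{regions:>2} grid regions -> max error {max_error(target, s, 0.0, 2.0):.4f}")
```

Each doubling of the grid cuts the error by roughly a factor of four here, without adding any new edges to the network; only the number of coefficients per edge grows.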
But what are these regions?
Splines are defined by control points. As shown below, these control points determine a spline's shape.
These control points are linked to a set of basis functions (shown in the previous image as little bell-shaped functions under the actual spline) that govern the spline's behavior in that particular section of the curve.
The coefficients that scale these basis functions (the control points) are the parameters that KANs actually learn, the equivalent of MLP weights.
Importantly, the effect of each control point is localized, so moving the second point affects a specific part of the curve, while moving the fourth point impacts the right end of the curve, all the while guaranteeing that the curve remains smooth and connected.
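The locality property is easy to check numerically. The sketch below evaluates B-spline basis functions with the standard Cox-de Boor recursion, builds a cubic spline from 8 coefficients (the control points), and then moves only the second one; the knot vector and coefficient values are arbitrary choices for illustration.

```python
def bspline_basis(i, k, t, x):
    """Cox-de Boor recursion: value at x of the i-th B-spline basis function of degree k on knots t."""
    if k == 0:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    left = 0.0
    if t[i + k] != t[i]:
        left = (x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
    right = 0.0
    if t[i + k + 1] != t[i + 1]:
        right = (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x)
    return left + right


def spline(coefs, k, t, x):
    """A spline is a weighted sum of basis functions; the coefficients act as control points."""
    return sum(c * bspline_basis(i, k, t, x) for i, c in enumerate(coefs))


k = 3                                  # cubic spline
t = [float(v) for v in range(-3, 9)]   # 12 uniform knots -> 12 - k - 1 = 8 basis functions
coefs = [0.0, 1.0, 0.5, -0.5, 1.0, 0.0, 0.3, 0.8]
moved = list(coefs)
moved[1] += 1.0                        # move only the second control point

xs = [v / 4 for v in range(0, 20)]     # sample the valid region [0, 5)
changed = [x for x in xs if abs(spline(moved, k, t, x) - spline(coefs, k, t, x)) > 1e-12]
print("curve changed only for x in", min(changed), "to", max(changed))
```

Because the second basis function is nonzero only on [-2, 2), the change shows up only in the left part of the sampled range [0, 5); the rest of the curve is untouched. Adding knots, and therefore basis functions, is what gives the spline room to take more complex shapes.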
Consequently, if we need our splines to take more complex forms, we simply increase the number of basis functions.
However, we still haven't clarified why KANs could take over from MLPs.
Interpretability and Speed
As we have seen above, KANs are beautifully interpretable (at least at lower dimensions), as we can clearly see how the network has approximated the objective function.
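But they are also potentially much, much cheaper, since in the researchers' small-scale experiments, much smaller KANs matched or beat the accuracy of far larger MLPs. Thus, KANs are spurring quite a heated debate, with overly enthusiastic claims that 'Deep Learning as we know it is dead.'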
But is it?
Well, hold your horses for one second.
- KANs are unproven at scale. We don't know whether the clear advantages they have shown in small-scale experiments still hold in large-scale training pipelines like the ones we are growing accustomed to.
- They haven't been tested in sequence-to-sequence models. In other words, we don't know if KANs work in LLMs.
- They are not adapted for GPUs. The researchers specifically mentioned that they focused on theoretical and small-scale evidence that KANs are better than MLPs, not on proving their efficiency on our hardware. Current Deep Learning architectures like the Transformer were essentially designed around GPUs. Hence, unless KANs achieve similar efficiencies, they aren't going to be used, period.
- KANs are also much slower to train and run. That said, new implementations like FastKAN are appearing, speeding up inference more than threefold… and it's only been a few weeks since the original release.
A New Age for AI?
Whenever something appears to disrupt the foundation of a technology that is already trillion-dollar-level, we must be very careful with our statements.
Kolmogorov-Arnold Networks (KANs) have sparked significant debate and hype as a potentially revolutionary neural network architecture that could replace the current foundations of deep learning.
Indeed, one could argue that KANs could change it all.
Despite their alleged advantages, however, KANs remain unproven at scale and specifically for LLMs; they aren't optimized for current GPU hardware, and are slower to train than existing models.
But if they deliver on their promise, the entire world of AI will be thrown into chaos, much like our world was when we realized that the Earth was not at the center of the universe.
I say this because much of the current investment in AI rests on the idea that building the next frontier of AI is extremely expensive. If something with the potential to cut costs by three orders of magnitude comes along and AI suddenly becomes simpler and cheaper to create, the companies behind those insane investments might reconsider their position.
We don't know the answer, but the excitement is hard to hide.
On a final note, if you have enjoyed this article, I share similar thoughts in a more comprehensive and simplified manner for free on my LinkedIn.
If you prefer, you can also connect with me on X.
Looking forward to connecting with you.