Key Takeaways

  • AI training and cryptographic protocols pursue opposite goals (one extracts structure from data, the other hides it), yet they share a surprising architectural kinship in how they mix and scramble information.
  • "Differential cryptanalysis" is a sophisticated attack that tracks how differences between inputs propagate through a cipher, showing that even operations over binary fields can be conceptually "differentiated".
  • The Feistel network, a cryptographic construction for building invertible functions from non-invertible ones, offers a robust pattern for information scrambling without loss.
  • This cryptographic pattern was adopted for neural networks in 2017 in "RevNets" (reversible residual networks), which make a stack of layers invertible so that each layer's input can be recomputed from its output.
  • The core benefit of the Feistel Network / Reversible Neural Network (RevNet) Construction is a drastic reduction in memory footprint during training, achieved by rematerializing intermediate activations instead of storing them.

The Feistel Network / Reversible Neural Network (RevNet) Construction

  • Invertible Function Construction: Start with some function f that is not invertible but is useful because it does interesting things: it might be an MLP, for example, or it might mix its input in an interesting way. The goal is to build something invertible out of it. The construction takes two inputs, x and y, instead of one. It computes f(x), but because we need to be able to work backwards, it also passes x through unchanged. It cannot drop y either, so it adds y to f(x). The output is the tuple (x, y + f(x)).
  • Inversion Process: Given the output (x, z), where z = y + f(x), recovering x is trivial: it is read straight off the first component. With x recovered, y = z - f(x). This never requires inverting f itself, only recomputing it, so the construction is invertible for any f.
  • Application to Neural Networks: The RevNets paper applied this idea to network layers; take f to be, for example, a transformer layer. Normally a layer has a single input, and a residual connection adds that input back to f's output. In the reversible variant there are two inputs, x and y: x goes through f and is added to y, and that sum becomes the new output x; the old x becomes the output y. In other words, (x, y) maps to (y + f(x), x), a Feistel step with the two halves swapped so that the next layer applies f to the other half.
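The construction and its inverse can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the RevNets paper: f is a hypothetical one-hidden-layer ReLU MLP with random weights (the ReLU discards information, so f alone is not invertible), and the output swap follows the layer description above. The published RevNet block composes two such steps with separate residual functions; this sketch shows only the single-function step discussed here.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4
# Hypothetical weights for a small, non-invertible f (ReLU MLP).
W1 = rng.standard_normal((D, 8))
W2 = rng.standard_normal((8, D))

def f(x):
    """Non-invertible mixing function: ReLU zeroes out negative pre-activations."""
    return np.maximum(x @ W1, 0.0) @ W2

def feistel_forward(x, y):
    """One reversible step: y + f(x) becomes the new x, the old x becomes the new y."""
    return y + f(x), x

def feistel_inverse(x_out, y_out):
    """Read the old x off directly, then recover y = x_out - f(x)."""
    x = y_out
    y = x_out - f(x)
    return x, y

x = rng.standard_normal((2, D))
y = rng.standard_normal((2, D))
x_out, y_out = feistel_forward(x, y)
x_rec, y_rec = feistel_inverse(x_out, y_out)
print(np.allclose(x_rec, x), np.allclose(y_rec, y))  # True True
```

Note that `feistel_inverse` only calls `f` forward; nothing about `f` has to be invertible, which is the whole point of the construction.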

When This Works (and When It Doesn't)

Reiner Pope notes that the big win for this construction is during neural network training. The backward pass needs each layer's input activations to compute gradients, so a standard implementation stores all of them; for deep models, this can become the single largest consumer of high-bandwidth memory (HBM). RevNets address this directly by letting those activations be rematerialized on demand from the layer outputs, trading extra computation for significant memory savings. As Pope puts it, “because it's invertible, I don't need to store this at all. I can completely rematerialize it.”
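A minimal sketch of rematerialization, assuming a stack of the two-input layers described earlier, each with its own hypothetical random-weight ReLU MLP as f. The forward pass keeps only the final (x, y) pair; walking the layers in reverse recovers every layer's input, which is what a backward pass would need. The `stored` list exists only to check the result and is exactly what a reversible implementation avoids keeping. A real backward pass would also compute gradients at each step; this sketch shows only the activation recovery.

```python
import numpy as np

rng = np.random.default_rng(1)
D, HIDDEN, LAYERS = 4, 8, 6
# Hypothetical per-layer weights, scaled by fan-in so values stay O(1).
weights = [(rng.standard_normal((D, HIDDEN)) / np.sqrt(D),
            rng.standard_normal((HIDDEN, D)) / np.sqrt(HIDDEN))
           for _ in range(LAYERS)]

def f(x, w):
    """Per-layer non-invertible function (ReLU MLP)."""
    w1, w2 = w
    return np.maximum(x @ w1, 0.0) @ w2

def forward(x, y):
    """Run every reversible layer, keeping only the final (x, y) pair."""
    for w in weights:
        x, y = y + f(x, w), x
    return x, y

def rematerialize(x, y):
    """Walk the layers in reverse, recovering each layer's input from its output.
    Returns the inputs in forward (first-to-last) order."""
    inputs = []
    for w in reversed(weights):
        x, y = y, x - f(y, w)  # old x = y_out; old y = x_out - f(old x)
        inputs.append((x, y))
    return inputs[::-1]

x0, y0 = rng.standard_normal((2, D)), rng.standard_normal((2, D))

# Reference only: what a standard implementation would have to keep in memory.
stored, x, y = [], x0, y0
for w in weights:
    stored.append((x, y))
    x, y = y + f(x, w), x

recovered = rematerialize(*forward(x0, y0))
ok = all(np.allclose(a, c) and np.allclose(b, d)
         for (a, b), (c, d) in zip(stored, recovered))
print(len(recovered), ok)  # 6 True
```

Memory for the stored inputs is constant in depth (one pair), while the cost is one extra evaluation of f per layer on the way back.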

However, this approach isn't a silver bullet. Rematerialization means recomputing f for every layer during the backward pass, so if your training is already compute-bound, it will slow things down further. For smaller models, or tasks where HBM isn't a bottleneck, the overhead of implementing and managing reversible layers may not justify the memory benefits. It's a strategic trade-off: spend more cycles to save gigabytes.
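To put rough numbers on that trade-off, here is a back-of-envelope estimator. Everything in it is an assumption for illustration, not a measurement: the model shape is hypothetical, only the per-layer input tensors are counted (real models also keep attention and MLP intermediates, plus those of the one layer currently being processed), backward compute is taken as about twice the forward pass (a common rule of thumb), and rematerialization is taken to cost roughly one extra forward pass.

```python
def activation_memory_gb(layers, batch, seq_len, d_model,
                         bytes_per_elem=2, reversible=False):
    """Rough memory for the per-layer inputs kept for the backward pass.

    Assumption: a standard stack keeps one [batch, seq_len, d_model] tensor
    per layer; a reversible stack keeps only the final (x, y) pair, i.e. two
    tensors regardless of depth. Treat the result as a lower bound.
    """
    per_tensor_bytes = batch * seq_len * d_model * bytes_per_elem
    tensors_kept = 2 if reversible else layers
    return tensors_kept * per_tensor_bytes / 1e9

def step_compute(reversible, backward_to_forward=2.0):
    """Training-step compute in units of one forward pass.

    Assumption: backward ~ 2x forward; rematerialization adds ~1 forward pass.
    """
    return 1.0 + backward_to_forward + (1.0 if reversible else 0.0)

# Hypothetical model shape: 48 layers, batch 8, 4096 tokens, width 8192, bf16.
shape = dict(layers=48, batch=8, seq_len=4096, d_model=8192)
print(f"{activation_memory_gb(**shape):.1f} GB")                   # 25.8 GB
print(f"{activation_memory_gb(**shape, reversible=True):.1f} GB")  # 1.1 GB
print(f"{step_compute(True) / step_compute(False):.2f}x compute")  # 1.33x compute
```

Under these assumptions the memory saving grows with depth while the compute overhead stays roughly constant per step, so the case for reversible layers strengthens as models get deeper.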

What to Do With This

Your ML team just submitted a budget request for more expensive GPUs with higher HBM capacity, claiming it's essential for training your latest, larger model. Don't just approve it. Challenge their memory assumptions using the lens of the Feistel Network / RevNet Construction:

1. Invertible Function Construction: Ask your ML lead, "Are there specific parts of our model—like deep transformer layers—where we could embed a non-invertible function within an invertible wrapper, similar to a Feistel network, to reduce our memory footprint?" Push them to identify where intermediate activations are most costly.

2. Inversion Process: Follow up: "If we adopted such a reversible design, how precisely would we recover those intermediate states for the backward pass without storing them? What's the computational cost of this 'rematerialization' in our specific setup?" This makes them articulate the compute-memory trade-off for your use case.

3. Application to Neural Networks: Specifically inquire about existing solutions: "Have we explored applying concepts from reversible networks, like the 2017 RevNets paper, particularly for our deepest layers? How would implementing a two-input (x, y) structure per layer change our HBM requirements and our need for new hardware?" This prompts a discussion on architectural optimization before escalating to hardware spend.