While doing some research, I came across the term “Sponge function”. After playing around with them and implementing one inside my kernel, I decided to write this blog post about how to write a simplified version. To keep low-level cryptographic code to a minimum, we will be relying on the hash function MD5. Buckle in, this is going to be a long one.

This article will start from the simple concept of an MD5 hash, and incrementally build on it until we have implemented a lot of common functionality that usually seems like a black box. Every step should be small enough to be digested individually, while still contributing to the overall understanding of the topic. If anything is unclear, feel free to discuss it in the comments. The post is designed so that you can pause and play around with the concepts yourself any time you want, or speed through it if that is what you prefer.

Since we are going to base our project on the MD5 hash function, let’s set aside a small section to go through what it is. We will treat MD5 as a black box and ignore any complicated details for the sake of brevity.

MD5 - Background

MD5 is a cryptographic hash function that maps an arbitrary amount of data into 16 bytes (or 128 bits). In its heyday, MD5 was the go-to choice for hashing passwords, checking files for corruption, and tagging data against tampering. Those days are long gone. It has been considered broken for some time, and using it for anything security-related is not recommended. But it is well-known and has been implemented for practically every computing device ever created. Luckily for us, Python comes bundled with a collection of hash functions in the hashlib module. Let’s quickly see how it works.
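A note before running the cells in this post: they call md5(...) directly, call .hex() on the result, and later modify the state in place. hashlib.md5 returns a hash object rather than raw bytes, so the cells appear to rely on a small helper that is not shown. Here is a minimal sketch of such a helper, assuming it simply wraps hashlib and returns a mutable bytearray. The definition is an assumption, not necessarily the original one.

```python
import hashlib


def md5(data):
    """Assumed helper: raw 16-byte MD5 digest, returned as a mutable bytearray."""
    return bytearray(hashlib.md5(data).digest())


digest = md5(b"")
print(len(digest), digest.hex())

# A bytearray can be modified in place, which plain bytes would not allow.
digest[0] ^= 0xFF
```

Returning a bytearray instead of bytes matters later on, because absorbing a byte assigns to the first element of the state.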

In [3]:
md5(b"Test").hex()
md5(b"Test 123").hex()
Out [3]:
'0cbc6611f5540bd0809a388dc95a615b'
Out [3]:
'f3957228139a2686632e206478ad1c9e'

As we can see, inputs of different lengths map to fixed-size outputs, and small changes in the input lead to completely different output values. This is basically what we would expect from a hash function. In this article, we will be (ab)using the MD5 hash function to create a sponge function. But before we can do that, we need to know what a sponge function is.

Sponge functions

A sponge function is a cryptographic function that can “absorb” any number of bits and “squeeze” any number of bits, like a sponge. This is different from what we observed about MD5; while MD5 will only produce fixed-size outputs of 16 bytes, a sponge function can output 1 byte, 26 bytes, 5000 bytes or any number that you like. This sounds fun and could be useful for a lot of different tasks, so let’s do some unholy programming and turn MD5 into a sponge function.

Sponge functions are fascinating. You can use a sponge function as a hash function, random number generator, Message Authentication Code or for data encryption. It would be apt to describe one as a cryptographic Swiss army knife.

Theory

In order to create a sponge function, we need an internal state (which is just a buffer) and a function to pseudorandomly transform one state into another. We will take advantage of the two properties of the MD5 hash that we observed earlier: its output is always 16 bytes, and even a small change in the input produces a completely different output. Our state buffer will be the 16 bytes of MD5 output and our transform function will be MD5 itself.

Sponge functions keep most of their internal state hidden. Both the bits absorbed and the bits squeezed only touch a small portion of it, so the output never reveals the full state of the function.

  • The first step is to initialize the state, either to zero or to a sensible default value.
  • For each byte of the input:
    • The first byte of the state is XOR’ed with the input byte.
    • State is replaced by MD5(State).

This process absorbs all the input data into the state. After we have absorbed our input, we can squeeze out as many output bytes as we want by following a very similar process. For each byte that we want to produce:

  • Output the first byte of the state.
  • Transform the state using MD5.
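The absorb and squeeze steps above can be sketched as plain functions before we wrap them in a class. This is a minimal sketch that works on immutable bytes and uses hashlib directly. The names transform, absorb and squeeze are my own choices for illustration.

```python
import hashlib


def transform(state):
    # The pseudorandom state update: State <- MD5(State)
    return hashlib.md5(state).digest()


def absorb(state, data):
    for byte in data:
        # XOR the input byte into the first state byte, then mix the state
        state = bytes([state[0] ^ byte]) + state[1:]
        state = transform(state)
    return state


def squeeze(state, count):
    out = bytearray()
    for _ in range(count):
        out.append(state[0])      # reveal only the first byte of the state
        state = transform(state)  # mix before producing the next byte
    return bytes(out), state


initial = transform(b"")          # default state: MD5 of the empty string
state = absorb(initial, b"Test")
short_out, _ = squeeze(state, 4)
long_out, _ = squeeze(state, 32)
print(short_out.hex(), len(long_out))
```

Only one byte of the 16-byte state is read or written per step. The remaining 15 bytes stay hidden, which is the point of the construction.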

Warning! You probably don’t want to use this for anything too sensitive. This is a proof-of-concept implementation using the broken MD5 function as a base. At least pick something better like ChaCha20 or SHA-512. In general, we want a large state and a transform function that mixes the state really well.

Implementation

Now that we have briefly gone over the theory, let’s get to the implementation. We will take this step by step and implement each operation we mentioned above. Our first step is the transform function.

Transform function

According to the theory we outlined above, we need a transform function that will take our state and pseudorandomly map it to another state. In our case, the MD5 hash function will be doing the heavy lifting for us. And by heavy lifting I mean MD5 will be doing pretty much the whole job.

We can transform the current state by passing it through the MD5 function. Here’s a small demonstration.

In [5]:
# Initial state
md5(b"").hex()

# Transform once
md5(md5(b"")).hex()

# Transform again
md5(md5(md5(b""))).hex()

# And so on...
Out [5]:
'd41d8cd98f00b204e9800998ecf8427e'
Out [5]:
'59adb24ef3cdbe0297f05b395827453f'
Out [5]:
'8b8154f03b75f58a6c702235bf643629'

Looks like everything is working. Let’s encapsulate this in a method of the Sponge class. Every time we absorb or squeeze a byte, we will mix the state using this method.

In [6]:
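class Sponge:  # first definition; the later cells extend it with class Sponge(Sponge)
    def transform(self):
        # Replace the state with its own MD5 digest
        self.state = md5(self.state)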

Initialization

As we mentioned before, the state needs to be initialized before we start absorbing and squeezing any bits. Since we are using MD5, we want our state to be 16 bytes. Fortunately, MD5 makes sure that no matter the value we provide, the state will end up being 16 bytes. So we can pick any value, including an empty string. Let’s go with that option.

In [7]:
class Sponge(Sponge):
    def __init__(self):
        self.state = b""
        self.transform()

Let’s see if everything is working. After creating a Sponge instance, we should be greeted with the MD5 of the empty string, d41d8cd98f00b204e9800998ecf8427e.

In [8]:
s = Sponge()
s.state.hex()
Out [8]:
'd41d8cd98f00b204e9800998ecf8427e'

Absorbing a byte

Using the logic from the theory section, we can write the code for absorbing a single byte easily. We will replace the first byte of the state with input XOR first byte, and then transform the state.

In [9]:
class Sponge(Sponge):
    def absorb_byte(self, byte):
        self.state[0] = byte ^ self.state[0]
        self.transform()

We can quickly test that this results in different states after we absorb different data. Let’s try to absorb [1, 2] and [3, 4] and observe the difference in the states.

In [10]:
s = Sponge()
s.absorb_byte(1)
s.absorb_byte(2)

s.state.hex()
Out [10]:
'29a3a137fccfa18e5cfb5054b13aa412'
In [11]:
s = Sponge()
s.absorb_byte(3)
s.absorb_byte(4)

s.state.hex()
Out [11]:
'0291c72acd7e7da67bedcb15aa4733c6'

Absorbing a buffer

Generalizing this to buffers of arbitrary sizes is trivial. Just iterate over a buffer and absorb the bytes one-by-one. This is a useful abstraction because in real code we will commonly work with buffers instead of individual bytes.

In [12]:
class Sponge(Sponge):
    def absorb(self, buffer):
        for byte in buffer:
            self.absorb_byte(byte)

Here’s a quick sanity check: Our state should be different from the empty state after absorbing bytes. Let’s quickly verify this before moving on.

In [13]:
s = Sponge()
s.absorb(b"Test")
s.state.hex()
Out [13]:
'28a7cbf238c85bad13cc0fc4933a68ae'

Squeezing a byte

Since we don’t need to do any input-mixing, our squeeze logic will be simpler than our absorb logic. Following the theory part, we will output the first byte and transform the state again in order to produce one byte.

In [14]:
class Sponge(Sponge):
    def squeeze_byte(self):
        byte = self.state[0]
        self.transform()
        return byte

Let’s produce some bytes and see if it’s working.

In [15]:
s = Sponge()
s.absorb(b"Test")
[s.squeeze_byte() for _ in range(5)]
Out [15]:
[40, 243, 39, 189, 220]

Squeezing a buffer

Going from extracting single bytes to extracting buffers is not too difficult. We can use a list comprehension to write this in a concise way.

In [16]:
class Sponge(Sponge):
    def squeeze(self, size):
        buf = [self.squeeze_byte() for _ in range(size)]
        return bytes(buf)
In [17]:
s = Sponge()
s.absorb(b"Test")
s.squeeze(5).hex()
Out [17]:
'28f327bddc'

It might seem like a very small amount of code, but this is all we need. It might be useful to add some convenience functions later, but for 99% of the use cases these methods will be sufficient. Now we can start playing around with our sponge function.

Use cases

In the beginning, we mentioned that sponge functions have a wide range of cryptographic use cases. In this section I will implement them in simple ways to provide some examples on how useful sponge functions can be.

Hash function

Hashing is the easiest thing to implement with a sponge function. In fact, we already saw a demonstration of this when testing the squeeze function above. To clarify, we can produce a hash by absorbing all the input and squeezing a fixed number of bytes.

In [18]:
def sponge_hash(data):
    s = Sponge()
    s.absorb(data)
    return s.squeeze(10).hex()

sponge_hash(b"123")
sponge_hash(b"Test 123")
sponge_hash(b"Test 113")
Out [18]:
'91e292b50acc3c838a0a'
Out [18]:
'b7a2027b77e56ca5d11f'
Out [18]:
'62eb28a8017c976f7ccc'

This seems to fit our criteria of a hash function; inputs of different sizes map to fixed-size outputs, and small changes in the input result in completely different hashes. You can substitute 10 with any other length in order to change the output size of your hash. In general, longer hashes are less likely to have collisions but take up more space. You can play around and pick the sweet spot for your use case.
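One consequence of changing the output length this way is that a shorter hash is simply a prefix of a longer hash of the same input, because both are squeezed from the same state. The sketch below shows this with a compact restatement of the Sponge class and an assumed md5 helper. The sponge_hash_sized variant, which absorbs the requested length first so that different sizes give unrelated outputs, is a hypothetical addition of mine.

```python
import hashlib


def md5(data):
    # Assumed helper: raw MD5 digest as a mutable bytearray
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def sponge_hash(data, size=10):
    s = Sponge()
    s.absorb(data)
    return s.squeeze(size)


def sponge_hash_sized(data, size=10):
    # Hypothetical variant: absorb the output length before the data
    s = Sponge()
    s.absorb(size.to_bytes(4, "big"))
    s.absorb(data)
    return s.squeeze(size)


short_hash = sponge_hash(b"Test 123", 5)
long_hash = sponge_hash(b"Test 123", 10)
print(long_hash.startswith(short_hash))  # True: same state, same leading bytes
print(sponge_hash_sized(b"Test 123", 5).hex())
```

Whether the prefix behaviour matters depends on the use case. It is harmless for checksums, but worth knowing about if hashes of different lengths are supposed to be independent.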

Random number generator

Random number generation is also something that can be done with sponge functions. The basic idea is to absorb the RNG seed and then squeeze out bytes for as many random numbers as you need. In the following example, I am using a fixed seed in order to generate 10 unsigned 16-bit integers.

In [19]:
import struct

s = Sponge()
s.absorb(b"Seeding the RNG")

def rng():
    buf = s.squeeze(2)
    return struct.unpack('H', buf)[0]

[rng() for _ in range(10)]
Out [19]:
[29342, 19407, 47040, 9984, 55893, 40500, 56312, 36293, 58610, 10880]

If we use the same seed, we will always get the same output. This might sound counterintuitive for the goal of generating “random” numbers, but it is commonly required to be able to replicate random results. If this is not something you need, you can seed from an actually random source or something that changes regularly like the current time. This depends on what qualities you expect from your random numbers. Below is a demonstration of how to read a random seed from /dev/urandom.

In [20]:
s = Sponge()

with open("/dev/urandom", "rb") as urandom:
    s.absorb(urandom.read(64))

[rng() for _ in range(10)]
Out [20]:
[56437, 39690, 47308, 16515, 29378, 11318, 32523, 18419, 47972, 4874]
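Reading /dev/urandom only works on Unix-like systems. A more portable sketch, assuming the same Sponge design as above (restated compactly here with an assumed md5 helper), is to seed from os.urandom, which asks the operating system for random bytes on any platform. I also decode the two bytes with an explicit little-endian byte order, so the numbers do not depend on the machine the code runs on.

```python
import hashlib
import os


def md5(data):
    # Assumed helper: raw MD5 digest as a mutable bytearray
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


s = Sponge()
s.absorb(os.urandom(64))  # 64 bytes from the OS random source


def rng():
    # Two squeezed bytes, read as an unsigned little-endian 16-bit integer
    return int.from_bytes(s.squeeze(2), "little")


numbers = [rng() for _ in range(10)]
print(numbers)
```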

Idea! You can absorb values while generating them as well. This allows you to periodically reseed your RNG using external sources.
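Here is a rough sketch of that reseeding idea. The ReseedingRNG class and its method names are hypothetical, and the Sponge class is a compact restatement of the one built in this post, with an assumed md5 helper.

```python
import hashlib


def md5(data):
    # Assumed helper: raw MD5 digest as a mutable bytearray
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


class ReseedingRNG:  # hypothetical name, for illustration only
    def __init__(self, seed):
        self.sponge = Sponge()
        self.sponge.absorb(seed)

    def reseed(self, entropy):
        # Absorbing between squeezes mixes new input into the existing state
        self.sponge.absorb(entropy)

    def next_u16(self):
        return int.from_bytes(self.sponge.squeeze(2), "little")


a = ReseedingRNG(b"Seeding the RNG")
b = ReseedingRNG(b"Seeding the RNG")
first_a = [a.next_u16() for _ in range(4)]
first_b = [b.next_u16() for _ in range(4)]

b.reseed(b"fresh entropy")  # in practice this could be os.urandom(16)
print(first_a == first_b, a.next_u16(), b.next_u16())
```

Both generators produce the same numbers until one of them is reseeded. From that point on their states differ, so their outputs go their separate ways.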

Message Authentication Code

We can use a sponge function to build a mechanism that creates and verifies signatures using a secret key. This is a very common technique, especially in mobile and web applications, where it is used to store the session on the client without letting the client tamper with it. If you want to read more about this use case, check out JSON Web Tokens.

In order to produce a signature, we will absorb the data and the secret key. After this, we can squeeze any number of bits to use as the signature.

In [21]:
def sign(data, key):
    s = Sponge()
    s.absorb(data)
    s.absorb(key)
    return s.squeeze(5)

data = b"Hello world!"
key  = b"password123"
signature = sign(data, key)

signature.hex()
Out [21]:
'480e4c2b9d'
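One subtlety of absorbing the data and the key back to back is that the sponge only sees one long stream of bytes, so it cannot tell where the data ends and the key begins. The sketch below demonstrates this and shows one possible way around it: absorbing the length of each field before the field itself. sign_framed is a hypothetical variant, and the Sponge class is a compact restatement with an assumed md5 helper.

```python
import hashlib


def md5(data):
    # Assumed helper: raw MD5 digest as a mutable bytearray
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def sign(data, key):
    s = Sponge()
    s.absorb(data)
    s.absorb(key)
    return s.squeeze(5)


def sign_framed(data, key):
    # Hypothetical variant: prefix every field with its 8-byte length
    s = Sponge()
    for field in (data, key):
        s.absorb(len(field).to_bytes(8, "big"))
        s.absorb(field)
    return s.squeeze(5)


# Both calls absorb the byte stream b"abc", so the signatures are identical.
print(sign(b"ab", b"c") == sign(b"a", b"bc"))  # True
print(sign_framed(b"ab", b"c").hex(), sign_framed(b"a", b"bc").hex())
```

With a fixed secret key this is mostly a curiosity, but it becomes relevant as soon as more than one variable-length field is absorbed.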

Verification of a signature can be done by generating the signature yourself and comparing the signature you received with the signature you generated. If they match up, the data and the signature have not been tampered with.

In [22]:
def verify(data, sig, key):
    correct = sign(data, key)
    return sig == correct

verify(data, signature, key)
Out [22]:
True

As expected, the signature can be verified successfully. Let’s try to modify the data a little and switch two characters around.

In [23]:
data = b"Hello wordl!"

verify(data, signature, key)
Out [23]:
False

Similarly, we can have the correct data and tamper with the signature instead. The verification fails, showing that both the signature and the data are protected against corruption and tampering.

In [24]:
data = b"Hello world!"
signature = bytes.fromhex("481e4c2b9d")

verify(data, signature, key)
Out [24]:
False
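The verify function above compares the two signatures with ==, which can stop at the first differing byte and may leak timing information. A common alternative is hmac.compare_digest from the standard library, which is designed to take the same time regardless of where the inputs differ. Here is a minimal sketch, with the Sponge class restated compactly and an assumed md5 helper.

```python
import hashlib
import hmac


def md5(data):
    # Assumed helper: raw MD5 digest as a mutable bytearray
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def sign(data, key):
    s = Sponge()
    s.absorb(data)
    s.absorb(key)
    return s.squeeze(5)


def verify(data, sig, key):
    # Constant-time comparison of the received and the expected signature
    return hmac.compare_digest(sig, sign(data, key))


data = b"Hello world!"
key = b"password123"
signature = sign(data, key)

# Flip one bit of the signature to simulate tampering
tampered = bytes([signature[0] ^ 0x01]) + signature[1:]
print(verify(data, signature, key), verify(data, tampered, key))
```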

Stream cipher

A stream cipher allows us to encrypt and decrypt a stream of bytes using a single secret key. This can be used to make sure that only you, or anyone you entrust with the secret key, can decrypt the data.

In [25]:
def stream_cipher(data, key):
    s = Sponge()
    s.absorb(key)

    output = bytearray(data)

    for i in range(len(data)):
        # XOR each data byte with the next keystream byte
        output[i] ^= s.squeeze_byte()

    return output

data = b"Hello, world!"
encrypted = stream_cipher(data, b"password123")

encrypted.hex()
Out [25]:
'b571d4065c54547bdf1a002d8e'

Decoding a stream cipher is very simple; in fact, it takes no extra code at all. Simply encrypting the already encrypted value with the correct key will end up decrypting your data. Let’s try to decode our data with the correct and incorrect passwords.

In [26]:
stream_cipher(encrypted, b"password123")
stream_cipher(encrypted, b"password132")
Out [26]:
bytearray(b'Hello, world!')
Out [26]:
bytearray(b'\x12\x88\x98?\x9aESh\x9a\x96\x9d\x17\x1d')

Idea! You can combine the Message Authentication Code and the Stream Cipher in order to make a chunk of data that is encrypted and resistant against tampering. This is called authenticated encryption, and it is commonly done in real protocols. Try to implement the AE and AEAD variants.

Warning! It is recommended to also include an IV / nonce with your key in order to make sure the same plaintext encrypts to different ciphertexts.
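The idea and the warning above can be combined into one rough sketch: encrypt with a key and a random nonce, then compute a tag over the ciphertext (encrypt-then-MAC). The function names, the 12-byte nonce, the 16-byte tag and the b"mac" label are all assumptions chosen for illustration, and the Sponge class is a compact restatement with an assumed md5 helper.

```python
import hashlib
import hmac
import os

NONCE_SIZE = 12  # assumed sizes, picked for the sketch
TAG_SIZE = 16


def md5(data):
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def keystream_xor(data, key, nonce):
    s = Sponge()
    s.absorb(key)
    s.absorb(nonce)  # a fresh nonce gives a fresh keystream for the same key
    return bytes(b ^ k for b, k in zip(data, s.squeeze(len(data))))


def make_tag(ciphertext, key, nonce):
    s = Sponge()
    s.absorb(b"mac")  # label so the tag sponge differs from the cipher sponge
    s.absorb(key)
    s.absorb(nonce)
    s.absorb(ciphertext)
    return s.squeeze(TAG_SIZE)


def seal(plaintext, key):
    nonce = os.urandom(NONCE_SIZE)
    ciphertext = keystream_xor(plaintext, key, nonce)
    return nonce + ciphertext + make_tag(ciphertext, key, nonce)


def open_sealed(message, key):
    nonce = message[:NONCE_SIZE]
    ciphertext = message[NONCE_SIZE:-TAG_SIZE]
    received = message[-TAG_SIZE:]
    if not hmac.compare_digest(received, make_tag(ciphertext, key, nonce)):
        raise ValueError("authentication failed")
    return keystream_xor(ciphertext, key, nonce)


sealed = seal(b"Hello, world!", b"password123")
print(open_sealed(sealed, b"password123"))
```

Sealing the same plaintext twice gives different outputs because each call draws a new nonce, and open_sealed refuses to decrypt anything whose tag does not match.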

Time-based one-time password

You might have noticed that a lot of services these days ask you for one-time tokens when trying to authenticate. These tokens are usually displayed as 6 digits and expire in ~30 seconds. Using a sponge function, we can implement our own version pretty easily. Here’s how these one-time tokens work.

  1. The server and the client have a pre-agreed secret key.
  2. When authenticating, the server asks the client to produce a token.
  3. The client absorbs the current time and the secret key in order to produce a token, and sends it to the server.
  4. The server independently produces a token using the same key and following the same rules.
  5. If the tokens match up, the client is granted access.
In [27]:
import time

key = b"Secret key 123"

def get_otp(key, period=10):
    t = time.time()
    value = int(t / period)
    time_left = period - (t % period)
    
    s = Sponge()
    s.absorb(key)
    s.absorb(str(value).encode('ascii'))
    
    otp = [s.squeeze(1).hex() for _ in range(3)]
    otp = ' '.join(otp)
    
    return otp, int(time_left)

otp, time_left = get_otp(key)

f"OTP is '{otp}'."
f"Valid for {time_left} more seconds."
Out [27]:
"OTP is '7c 0b c8'."
Out [27]:
'Valid for 7 more seconds.'

As long as we are still inside the same time window, meaning that time_left has not run out, generating the OTP again produces the same value and the code is considered valid.

In [28]:
otp == get_otp(key)[0]
Out [28]:
True

If we wait until the timer runs out, our OTP will no longer validate.

In [29]:
time.sleep(time_left + 1)

otp == get_otp(key)[0]
Out [29]:
False

Idea! It is recommended to also accept codes that should have been generated before or after the current time in order to account for clock skew. After all, the current time is the input that determines what the code will be, so authentication won’t be possible if the clocks don’t match up.
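A sketch of what accepting neighbouring codes could look like is below. The helper names otp_for_step and verify_otp, the window parameter, and the choice to pass the time in explicitly (which keeps the example reproducible) are assumptions of mine. The token here is three squeezed bytes as hex without spaces, and the Sponge class is a compact restatement with an assumed md5 helper.

```python
import hashlib
import hmac


def md5(data):
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def otp_for_step(key, step):
    # step is the number of whole periods since the epoch
    s = Sponge()
    s.absorb(key)
    s.absorb(str(step).encode("ascii"))
    return s.squeeze(3).hex()


def verify_otp(key, otp, now, period=10, window=1):
    current = int(now / period)
    # Accept the current step and `window` steps on either side of it
    for step in range(current - window, current + window + 1):
        if hmac.compare_digest(otp, otp_for_step(key, step)):
            return True
    return False


key = b"Secret key 123"
client_time = 1_000_000.0
otp = otp_for_step(key, int(client_time / 10))

# The server clock is 8 seconds ahead and already in the next period
print(verify_otp(key, otp, client_time + 8))
```

A wider window tolerates more clock drift, but it also keeps each code valid for longer, so it is a trade-off.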

Block cipher

Stream ciphers that repeatedly hash their state have a risk of running into cycles. This is when repeatedly calling the transform function on the state eventually leads back to a previous state, after which the keystream starts repeating. In order to mitigate this, we can use a block cipher instead.
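Cycles are hard to observe with a 16-byte state, but they become easy to see if we shrink the state. The sketch below truncates MD5 to 2 bytes, which is a deliberately tiny toy state assumed purely for demonstration, and follows the transform until a state repeats.

```python
import hashlib


def tiny_transform(state):
    # Toy transform: keep only the first 2 bytes of the MD5 digest
    return hashlib.md5(state).digest()[:2]


def find_cycle(start):
    seen = {}  # state -> step at which it was first reached
    state = start
    step = 0
    while state not in seen:
        seen[state] = step
        state = tiny_transform(state)
        step += 1
    tail = seen[state]    # steps taken before entering the cycle
    period = step - tail  # length of the cycle itself
    return tail, period


tail, period = find_cycle(b"\x00\x00")
print(f"entered a cycle after {tail} steps, cycle length {period}")
```

With only 65536 possible states a repeat is guaranteed within 65536 steps. The real 16-byte state has vastly more room, but the same principle applies.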

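The difference with a block cipher is that, instead of constructing the sponge once and squeezing bytes from it for the whole stream, we absorb a counter into the sponge along with the key (and ideally a nonce, which the example below leaves out) in order to generate a fixed-size block of bytes. This is where we get the name “block cipher”. Strictly speaking, what we build here is a keystream generated block by block from a counter, much like a conventional block cipher running in counter (CTR) mode.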

In [30]:
BLOCKSIZE = 10

def get_block(key, counter):
    s = Sponge()
    s.absorb(key)
    s.absorb(str(counter).encode("ascii"))
    return bytearray(s.squeeze(BLOCKSIZE))

def block_encrypt(data, key):
    size = len(data)
    result = b""
    
    counter = 0
    while data:
        # Chop off BLOCKSIZE bytes from the data
        data_block = data[:BLOCKSIZE]
        data = data[BLOCKSIZE:]
        
        # Generate a block cipher block
        block = get_block(key, counter)
        
        for i, byte in enumerate(data_block):
            block[i] ^= byte
        
        result += block
        counter += 1
        
    return result[:size]

data = b"Hello, world! Don't forget to stay hydrated."
encrypted = block_encrypt(data, b"test")
encrypted.hex()
Out [30]:
'eec587d16686e81d26ed800677e609a6d2fed11b7a27bbb233370cdba1d941cdc01d42c4c3e7ee90a09333c1'

As we did with the stream cipher, let’s try to decrypt our data with correct and incorrect keys.

In [31]:
block_encrypt(encrypted, b"test")
block_encrypt(encrypted, b"TEST")
Out [31]:
b"Hello, world! Don't forget to stay hydrated."
Out [31]:
b'\xd1%\x17\xd9\xe0\x1bh\xaf~2\xc0\x9f\x8da\xb2\xe4\xa4\x05\x99\xc4\x82\xf7\x02\x0c\xed+\xa1\xf4\xefa?\x82l9Q\x05=B>p%\x9e\xa0q'
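A side benefit of deriving every block from a counter is that any block can be decrypted on its own, without processing the blocks before it. Here is a minimal sketch: decrypt_block is a hypothetical helper, block_encrypt is rewritten slightly but is meant to behave like the version above, and the Sponge class is a compact restatement with an assumed md5 helper.

```python
import hashlib

BLOCKSIZE = 10


def md5(data):
    return bytearray(hashlib.md5(data).digest())


class Sponge:
    def __init__(self):
        self.state = md5(b"")

    def absorb(self, buffer):
        for byte in buffer:
            self.state[0] ^= byte
            self.state = md5(self.state)

    def squeeze(self, size):
        out = bytearray()
        for _ in range(size):
            out.append(self.state[0])
            self.state = md5(self.state)
        return bytes(out)


def get_block(key, counter):
    s = Sponge()
    s.absorb(key)
    s.absorb(str(counter).encode("ascii"))
    return s.squeeze(BLOCKSIZE)


def block_encrypt(data, key):
    out = bytearray()
    for counter, start in enumerate(range(0, len(data), BLOCKSIZE)):
        chunk = data[start:start + BLOCKSIZE]
        # zip stops at the shorter input, so a short final chunk is handled
        out += bytes(b ^ k for b, k in zip(chunk, get_block(key, counter)))
    return bytes(out)


def decrypt_block(ciphertext, key, index):
    # Only block `index` is touched; earlier blocks are never generated
    start = index * BLOCKSIZE
    chunk = ciphertext[start:start + BLOCKSIZE]
    return bytes(b ^ k for b, k in zip(chunk, get_block(key, index)))


data = b"Hello, world! Don't forget to stay hydrated."
encrypted = block_encrypt(data, b"test")
print(decrypt_block(encrypted, b"test", 2))
```

With the plain stream cipher, reaching byte 20 means squeezing the 20 bytes before it first. Here, the cost of reaching a block does not depend on its position.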

Closing words

If you made it this far, thank you for reading my article. I’d appreciate any emails or comments. You now have an understanding of how to implement some commonly used cryptographic technologies from scratch. Please let me know what projects you end up doing with sponge functions.