Unit Testing One Day Tour: Testing a Minimal Neural Network in Python

Having a testing mindset can take your data science toolkit to the next level. With proper testing, you can ensure your code is well-designed, maintainable, and working as intended.

When I was leaving Google, a colleague asked me, "What was the most important lesson you've learned here?" What a great question, I thought. On the spot, the first answer that came to mind was "unit testing" - an answer so seemingly trivial that it earned a few nerdy chuckles from my teammates.

Since then, this question has stayed with me. Yet even after years of reflection, I still don't have a better answer. The fact is, learning about the testing mindset was indeed one of the most important things I gained from my time there.

And I had someone specific to thank for that - a code reviewer who tirelessly guided me through a lengthy pull request. My original code had no tests. Through a review process spanning multiple iterations, this reviewer gently introduced me to many core concepts of testing. By the end of the review, my pull request had thorough tests, ranging from basic unit tests to stubbing out external dependencies. The most amazing thing was that this reviewer was not even on our team, but volunteered to keep code quality high and spread Python knowledge across the organization. Reviewing my newbie code probably took her many hours of back-and-forth. To this day, I'm deeply grateful to this person, whom I have actually never met.

But enough digression. What I wish to convey is that learning the testing mindset can be a powerful tool for data scientists. So in today's one-day tour, we will use a minimal neural network as an example to try out some fundamentals of testing in Python.

Prerequisites

You should have used Python before and be able to run Python code in your terminal. In this tour, we will use Python's built-in unittest module, so you don't have to install additional dependencies; any reasonably up-to-date version of Python should work fine.

Test-Driven Development

As data scientists, we oftentimes write what I call "stream of consciousness" code: we have an idea of what needs to be done, and we write down lines of code the way thoughts progress through our head.

This is especially common because the data science workflow is highly interactive: we analyze data and, depending on what we find, make adjustments and proceed to the next step.

As a result, well-known engineering practices like abstraction, encapsulation, and separation of concerns can be new concepts for data scientists. Yet these concepts are not difficult to learn, and they are essential for building maintainable and scalable code.

Test-Driven Development (TDD) is one such practice. Writing tests is such a fundamental part of software development that engineers oftentimes can't submit their code without tests. Yet in the world of data science (and in academic research as well), the practice of writing tests is comparatively rare.

The idea of test-driven development is simple: write tests before you write the implementation code. There are many benefits to this approach. For example, writing tests forces you to define the inputs and outputs of the functions you will implement. It also forces you to think about all the possible scenarios and how your code should handle them. In this way, the tests you write also serve as documentation of the intended behaviors of your code.

Test-Driven Development compared to common data science code development.

Having to write tests first can feel tedious and time-consuming. When I first got started with test-driven development, there were also situations where I had difficulty defining test cases. But over time, it became apparent that difficulty in writing tests is a symptom of a bad implementation design. Once my functions had a more clearly defined scope and behavior, writing tests became natural and straightforward.

Example: A Minimal Neural Network

To demonstrate the testing mindset, we will use a minimal neural network. This neural "net" has just one node, with one weight and one bias. We will also only look at forward propagation, using the ReLU activation function.

A simplistic neuron that we will use as an example for developing unit tests.

For this minimal neural network, what are the feature requirements? We probably want to be able to:

  1. Store the weight and bias
  2. Given an input, x, calculate w * x + b
  3. Then apply the ReLU activation function, where negative values are squashed to zero
  4. Get back the output, y, after steps 2 and 3
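
For example, with w = 2, b = 1, and x = -3, we would compute w * x + b = 2 * (-3) + 1 = -5, which ReLU squashes to y = 0.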

Unit Testing in Python

There are many ways to write tests in Python, and there are also many libraries available. But for this one-day tour, we will stick to the built-in unittest module. You can already accomplish quite a bit with unittest!

Let's start with our files and boilerplate code. Create two files: simple_neural_net.py and test_simple_neural_net.py. Leave the simple_neural_net.py blank for now, and in test_simple_neural_net.py, start with the following:

"""test_simple_neural_net.py"""
import unittest
import simple_neural_net

class TestSimpleNeuralNet(unittest.TestCase):

    def test_example(self):
        self.assertEqual(1 + 1, 2)

if __name__ == '__main__':
    unittest.main()

Now if you run python test_simple_neural_net.py in your terminal, you should see the following output:

.
----------------------------------------------------------------------
Ran 1 test in 0.000s

OK

If you intentionally change the test to self.assertEqual(1 + 1, 3), and run python test_simple_neural_net.py again, you should see the following output:

F
======================================================================
FAIL: test_example (__main__.TestSimpleNeuralNet)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/unicorn/unit_testing_one_day_tour/test_simple_neural_net.py", line 7, in test_example
    self.assertEqual(1 + 1, 3)
AssertionError: 2 != 3

----------------------------------------------------------------------
Ran 1 test in 0.001s

FAILED (failures=1)

As you can see, you have control over how you define the tests, and whether they succeed or fail!

Assertions

The self.assertEqual method is one of unittest's built-in assertion methods. There are many others, which you can find in the Python documentation.
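
For a quick taste, here is a small sampler of commonly used assertions, all from the standard unittest API (the file name is just for illustration):

"""test_assertion_sampler.py"""
import unittest

class TestAssertionSampler(unittest.TestCase):

    def test_sampler(self):
        self.assertTrue(1 < 2)  # Truthiness
        self.assertIn(3, [1, 2, 3])  # Membership
        self.assertAlmostEqual(0.1 + 0.2, 0.3)  # Floats, compared to 7 decimal places by default
        with self.assertRaises(ZeroDivisionError):  # Expected exceptions
            1 / 0

if __name__ == '__main__':
    unittest.main()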

Writing Feature Tests

Going back to our minimal neural network, and looking at the list of feature requirements, we can start writing tests for each of them. Let's begin with the first one: storing the weight and bias. That means we would need a SimpleNeuralNet class that can be initialized with a provided weight and bias, and then we can test that the weight and bias are stored correctly.

"""test_simple_neural_net.py"""
import unittest
import simple_neural_net

class TestSimpleNeuralNet(unittest.TestCase):

    def test_store_params(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        self.assertEqual(snn.weight, 1)
        self.assertEqual(snn.bias, 2)

if __name__ == '__main__':
    unittest.main()


If we run python test_simple_neural_net.py, we should see the test fail, because we haven't implemented the SimpleNeuralNet class yet.

E
======================================================================
ERROR: test_store_params (__main__.TestSimpleNeuralNet)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/unicorn/unit_testing_one_day_tour/test_simple_neural_net.py", line 8, in test_store_params
    snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
AttributeError: module 'simple_neural_net' has no attribute 'SimpleNeuralNet'

----------------------------------------------------------------------
Ran 1 test in 0.000s

FAILED (errors=1)

No big deal. Now that we have a test for the SimpleNeuralNet class, let's implement it!

"""simple_neural_net.py"""
class SimpleNeuralNet:

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

Run python test_simple_neural_net.py again, and the test should now pass. Congratulations, you have just completed one cycle of test-driven development!

Typically, at this point, I'd add some more tests. For example, below I'm adding a new test_update_params() test to see that we can update the weight and bias of the SimpleNeuralNet class.

"""test_simple_neural_net.py"""
import unittest
import simple_neural_net

class TestSimpleNeuralNet(unittest.TestCase):

    def test_store_params(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        self.assertEqual(snn.weight, 1)
        self.assertEqual(snn.bias, 2)

    def test_update_params(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        snn.weight = 0.1
        snn.bias = 0.2
        self.assertEqual(snn.weight, 0.1)
        self.assertEqual(snn.bias, 0.2)

if __name__ == '__main__':
    unittest.main()

Running python test_simple_neural_net.py again, we should now see two tests passing:

..
----------------------------------------------------------------------
Ran 2 tests in 0.000s

OK

Very satisfying!

Writing More Feature Tests

Let's now confidently (?) write the tests for the rest of the forward propagation steps.

"""test_simple_neural_net.py"""
import unittest
import simple_neural_net

class TestSimpleNeuralNet(unittest.TestCase):

    def test_store_params(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        self.assertEqual(snn.weight, 1)
        self.assertEqual(snn.bias, 2)

    def test_update_params(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        snn.weight = 0.1
        snn.bias = 0.2
        self.assertEqual(snn.weight, 0.1)
        self.assertEqual(snn.bias, 0.2)

    def test_relu(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        y = snn.relu(x=3)
        self.assertEqual(y, 3)
        y = snn.relu(x=-3)
        self.assertEqual(y, 0)

    def test_forward_propagation(self):
        snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)
        y = snn.forward_propagation(x=3)
        # y = 1 * x + 2 = 3 + 2 = 5 which is still 5 after ReLU
        self.assertEqual(y, 5)

        y = snn.forward_propagation(x=-10)
        # y = 1 * x + 2 = -8 + 2 = -6 which is 0 after ReLU
        self.assertEqual(y, 0)


if __name__ == '__main__':
    unittest.main()

Here we have added two more tests: test_relu() and test_forward_propagation(). The first one assumes we will have a ReLU activation function, which turns any negative values to zero. The second one assumes we will have a forward propagation function, which takes an input x and computes y = w * x + b, then applies the ReLU activation function.

Running python test_simple_neural_net.py will lead to failing tests, as we haven't implemented the relu() and forward_propagation() methods yet. So let's do that!

"""simple_neural_net.py"""
class SimpleNeuralNet:

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def relu(self, x):
        return max(0, x)

    def forward_propagation(self, x):
        return self.relu(x * self.weight + self.bias)

As you can see, the implementation is extremely simple (because this is a minimal neural network!). Notably, it is shorter than the test file! This is actually pretty typical for test-driven development: the test file is comprehensive and exhaustive, whereas the implementation code contains clearly defined functions, each short and focused on a single task.

Running python test_simple_neural_net.py again, we should see all four tests are passing!

....
----------------------------------------------------------------------
Ran 4 tests in 0.000s

OK

Utilizing the Unit Testing Framework

Python's built-in unittest module comes with many useful features. For example, you might have noticed that in test_simple_neural_net.py, we keep initializing a SimpleNeuralNet instance with snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2). If many tests share the same initialization, you can move it into a setUp() method.

"""test_simple_neural_net.py"""
import unittest
import simple_neural_net

class TestSimpleNeuralNet(unittest.TestCase):
    def setUp(self):
        self.snn = simple_neural_net.SimpleNeuralNet(weight=1, bias=2)

    def test_store_params(self):
        self.assertEqual(self.snn.weight, 1)
        self.assertEqual(self.snn.bias, 2)

    def test_update_params(self):
        self.snn.weight = 0.1
        self.snn.bias = 0.2
        self.assertEqual(self.snn.weight, 0.1)
        self.assertEqual(self.snn.bias, 0.2)

    def test_relu(self):
        y = self.snn.relu(x=3)
        self.assertEqual(y, 3)
        y = self.snn.relu(x=-3)
        self.assertEqual(y, 0)

    def test_forward_propagation(self):
        y = self.snn.forward_propagation(x=3)
        # y = 1 * x + 2 = 3 + 2 = 5 which is still 5 after ReLU
        self.assertEqual(y, 5)

        y = self.snn.forward_propagation(x=-10)
        # y = 1 * x + 2 = -8 + 2 = -6 which is 0 after ReLU
        self.assertEqual(y, 0)


if __name__ == '__main__':
    unittest.main()


Here we have added a setUp() method, which initializes a SimpleNeuralNet instance. Note that this initialization happens before each test. So even though we have a test, test_update_params(), that mutates the weight and bias of self.snn, it does not affect other tests. This is an important feature of unit testing, as it allows you to isolate the behavior of each test.
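
To see this isolation in action, here is a minimal standalone sketch (separate from our test file): the second test passes no matter what order the tests run in, because setUp() builds a fresh list before each one.

"""test_isolation_sketch.py"""
import unittest

class TestIsolation(unittest.TestCase):

    def setUp(self):
        # Runs fresh before EACH test method.
        self.items = []

    def test_mutate(self):
        self.items.append("x")
        self.assertEqual(len(self.items), 1)

    def test_still_fresh(self):
        # Passes regardless of test order: setUp() gave us a new list.
        self.assertEqual(self.items, [])

if __name__ == '__main__':
    unittest.main()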

Mocking External Dependencies

I mentioned earlier the code reviewer to whom I owe a big debt of gratitude. One major lesson she taught me was how to stub out external dependencies. So, even though this is a one-day tour, I do want to include a section on this topic.

So far our functions have been very simple and self-contained. But in reality, our code rarely lives in a vacuum. Instead, we are constantly fetching data from the outside world (say, via API), or we are calling other modules (say, via import). That means the behavior of our functions will heavily depend on the behavior of external data or modules. This can make testing our own functions difficult.

The Python unittest module comes with a built-in mocking library, unittest.mock. It allows you to replace external dependencies with mock objects, which you can use to control the behavior of those dependencies. For example, if you want to test a function that calls an external API, you can replace the API call with a mock object that returns a predefined response.

To illustrate with an example, let's return to our SimpleNeuralNet class. When initializing the class, we require a weight and bias. But model initialization is actually a major research topic - poorly initialized models can be more difficult to train. So let's imagine we have a magical module called param_initializer.py that provides optimized initialization:

"""simple_neural_net.py"""
import param_initializer

class SimpleNeuralNet:

    def __init__(self):
        params = param_initializer.get_params()
        self.weight = params["weight"]
        self.bias = params["bias"]

    def relu(self, x):
        return max(0, x)

    def forward_propagation(self, x):
        return self.relu(x * self.weight + self.bias)

And we will create param_initializer.py, but without an actual implementation. This is just to allow Python to find the module.

"""param_initializer.py"""

def get_params():
    pass


Then, in the test, we will use the patch decorator to mock the get_params() function. Note that we want to patch get_params() where it is used, not where it is defined: since simple_neural_net.py imports the function into its own namespace, we use @patch('simple_neural_net.get_params') as the target path. Patching gives us a mock object, which we can use to control the return value of the function.


Moreover, the patched object has additional assertion methods available. For example, we would expect the get_params() function to be called each time we initialize the SimpleNeuralNet class. So we can add a mock_get_params.assert_called_once() check in the setUp() method.

"""test_simple_neural_net.py"""

import unittest
from unittest.mock import patch

import simple_neural_net


class TestSimpleNeuralNet(unittest.TestCase):

    @patch('simple_neural_net.get_params')
    def setUp(self, mock_get_params):
        mock_get_params.return_value = {"weight": 1, "bias": 2}
        self.snn = simple_neural_net.SimpleNeuralNet()
        mock_get_params.assert_called_once()

    def test_store_params(self):
        self.assertEqual(self.snn.weight, 1)
        self.assertEqual(self.snn.bias, 2)

    def test_update_params(self):
        self.snn.weight = 0.1
        self.snn.bias = 0.2
        self.assertEqual(self.snn.weight, 0.1)
        self.assertEqual(self.snn.bias, 0.2)

    def test_relu(self):
        y = self.snn.relu(x=3)
        self.assertEqual(y, 3)
        y = self.snn.relu(x=-3)
        self.assertEqual(y, 0)

    def test_forward_propagation(self):
        y = self.snn.forward_propagation(x=3)
        # y = 1 * x + 2 = 3 + 2 = 5 which is still 5 after ReLU
        self.assertEqual(y, 5)

        y = self.snn.forward_propagation(x=-10)
        # y = 1 * x + 2 = -8 + 2 = -6 which is 0 after ReLU
        self.assertEqual(y, 0)


if __name__ == '__main__':
    unittest.main()


Now you can test the simple neural net even though your dependency, get_params(), is not implemented. Such is the power of mocking!

There are more assertion methods available to the mock object. For example, you can check whether the get_params() function was called with certain arguments, or whether it was called multiple times. All these features allow you to define the precise behavior of external dependencies, so you can craft the test scenarios you need.
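
As a quick illustration, here is a standalone sketch using a bare Mock object to show a few of these helpers (the seed argument is made up for illustration):

"""mock_sketch.py"""
from unittest.mock import Mock

mock_get_params = Mock(return_value={"weight": 1, "bias": 2})

mock_get_params()  # First call, no arguments
mock_get_params(seed=42)  # Second call, with a keyword argument

print(mock_get_params.call_count)  # 2
mock_get_params.assert_called_with(seed=42)  # Checks the most recent call
mock_get_params.assert_any_call()  # Checks that some call matched these arguments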

Where to Go from Here

This short one-day tour has given you a taste of unit testing in Python. With this knowledge, you have just become a rare breed of data scientist who can build reliable and maintainable software. So where to go from here?

Actually Write Tests

No tutorial is as good as putting your knowledge into practice. Start looking at your projects through the test-driven development lens. Identify tests that are missing, and start implementing them. Think about the scenarios your code should handle, and challenge yourself to identify edge cases. As your testing skills mature, you might also start thinking about continuous integration (CI, where tests are run automatically on every commit), as well as coverage metrics (which show how much of your code is covered by tests).

And, because this is 2024, it's also worth mentioning that pairing up with an LLM assistant can be quite helpful with writing tests. A lot of unit tests are very structured. For example, we started with a boilerplate test file for the Python unittest framework. While you, the data scientist, should fully be in the driver's seat when it comes to designing scenarios to test, an LLM assistant can speed up the process of test implementation.

Other Testing Frameworks

For the sake of simplicity, we have been using the built-in unittest framework. But there are many useful testing tools in the Python ecosystem. For example:

  • pytest - A very popular library known for its simple syntax, powerful fixtures, and rich plugin options (see the sketch after this list).
  • coverage - Commonly paired with pytest to measure how much of your code has corresponding tests. Because what gets measured gets improved.
  • doctest - Write documentation, or write tests? Why not both? With doctest, you can embed tests in docstrings that also serve as code examples.
  • numpy.testing - You have very likely used numpy, a foundational library for numerical computation. numpy comes with specialized assert functions like assert_array_equal, assert_almost_equal, or assert_allclose. Similarly, pandas has pandas.testing to check equality between Series and DataFrames.
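
To give a flavor of the pytest style, here is a sketch of how a couple of our earlier tests might look. This assumes the original weight-and-bias constructor, from before we introduced param_initializer:

"""test_simple_neural_net_pytest.py"""
import pytest

import simple_neural_net

@pytest.fixture
def snn():
    # Fixtures replace setUp(): tests request them by name in their signatures.
    return simple_neural_net.SimpleNeuralNet(weight=1, bias=2)

def test_relu(snn):
    # Plain assert statements; pytest rewrites them to give rich failure messages.
    assert snn.relu(x=3) == 3
    assert snn.relu(x=-3) == 0

def test_forward_propagation(snn):
    assert snn.forward_propagation(x=3) == 5
    assert snn.forward_propagation(x=-10) == 0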

Beyond a Simple Neural Network

The example neural "net" we used here is extremely simple, as the focus of the tour is unit testing. However, we can build on it to construct a more common neural network structure. To do so, you'd probably want to:

  • Vectorize weights and biases with numpy arrays (see the sketch after this list)
  • Add a loss calculation function
  • Add a backward propagation function
  • Add a training function which iteratively updates weights and biases
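
As a sketch of that first step, here is roughly what a vectorized forward pass could look like. The class name and shapes below are made up for illustration:

"""vector_neural_net.py"""
import numpy as np

class VectorNeuralNet:

    def __init__(self, weights, biases):
        # weights has shape (n_outputs, n_inputs); biases has shape (n_outputs,).
        self.weights = np.asarray(weights)
        self.biases = np.asarray(biases)

    def relu(self, x):
        # Element-wise ReLU over the whole array.
        return np.maximum(0, x)

    def forward_propagation(self, x):
        # y = ReLU(W @ x + b), now for vectors instead of scalars.
        return self.relu(self.weights @ np.asarray(x) + self.biases)

# Usage: a "layer" with 3 inputs and 2 outputs.
net = VectorNeuralNet(weights=[[1, 0, -1], [0.5, 0.5, 0.5]], biases=[0, -2])
print(net.forward_propagation([1, 2, 3]))  # [0. 1.]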

Building a neural network from the ground up can be a great learning experience if you'd like a first-principles understanding of how AI works. And there are great resources on this topic, such as this excellent lecture.

That's it for today's one-day tour! I hope learning about unit testing gives you lots of ideas on what to test. In fact, writing unit tests is a lot like designing experiments – isolated and reproducible experiments. And as scientists, who wouldn't want to have little experiments to verify our code is working as intended?