Zack Proser

Building a Hand-Drawn Digit Recognizer with PyTorch and MNIST

The MNIST mind mapper

Demo of the "MNIST mind-mapper" app in action

You can watch this short demo to see how the app works:

Overview

I've been an application and infrastructure developer for the majority of my career, so I wanted to get hands-on with training and deploying a neural network.

A neural network

I also wanted to wrap the trained neural network in a REST API so that I could build a frontend that lets folks play with it, because interacting with something is more engaging than reading a text description of it.

I knew it would be important to go beyond the working neural net, because issues often arise at the seams when you're fitting system components together.

This article details the steps I took, and the many issues I encountered along the way, to build and successfully deploy my original vision.

Open source code

I open source most of my work, and this project is no exception:

App flow diagram

Let's step through how the app works, end to end:

The MNIST mind mapper

The frontend exposes a small drawable canvas, which the user scribbles on.

On a regular interval, the frontend captures what the user drew, using the toDataURL method:

 /**
  * Returns the content of the current canvas as an image that you can use as a source for another canvas or an HTML element.
  * @param type The standard MIME type for the image format to return. If you do not specify this parameter, the default value is a PNG format image.
  *
  * [MDN Reference](https://developer.mozilla.org/docs/Web/API/HTMLCanvasElement/toDataURL)
  */
 toDataURL(type?: string, quality?: any): string;

This image is sent to the backend API, which wraps the trained neural network. The backend runs inference on the image, and returns the predicted digit, which the frontend displays.
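Here's a minimal sketch of what that backend route can look like. The /predict route name and request shape are illustrative assumptions, not the app's exact code; transform, model, and device come from the training step below.

import base64
import io

import torch
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel

class DrawingRequest(BaseModel):
    image: str  # the data URL produced by canvas.toDataURL()

app = FastAPI()

@app.post("/predict")  # hypothetical route name
async def predict(req: DrawingRequest):
    # Strip the "data:image/png;base64," prefix and decode the PNG bytes
    _, b64_data = req.image.split(",", 1)
    img = Image.open(io.BytesIO(base64.b64decode(b64_data)))

    # Match MNIST's input format: a grayscale 28x28 image
    img = img.convert("L").resize((28, 28))

    # transform, model, and device are defined as in the training step
    tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(tensor)
    return {"prediction": logits.argmax(dim=1).item()}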

Step 1: Model training

The following code is all that's needed with PyTorch to:

  1. Define a simple neural network architecture
  2. Load the MNIST dataset
  3. Train the neural network
  4. Save the model to disk so that it can be loaded and reused later

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Create a simple neural net
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 2. Load the MNIST dataset
# Normalize with the MNIST dataset's global mean (0.1307) and standard deviation (0.3081)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 3. Train the neural network
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()  # expects raw logits; applies log-softmax internally
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

# Test the model
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f"Accuracy on test set: {100 * correct / total:.2f}%")

# Save the model
torch.save(model.state_dict(), "mnist_model.pth")
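
Reloading those weights later is the mirror image of the save call; a minimal sketch:

model = SimpleNN()
model.load_state_dict(torch.load("mnist_model.pth", map_location=device))
model.eval()  # switch to inference mode before serving predictions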

With training complete, I wanted to quickly sanity-check the trained network's performance on the intended task, so I created a simple FastAPI server that exposes a /test_images route.

You can view the entire main.py file in the repository, but we'll examine the route itself here, which runs a self-test and shares the output as a PNG image like so:

MNIST test image

import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend, since the server renders plots off-screen
import matplotlib.pyplot as plt
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse

# model, device, and transform are defined earlier in main.py (see the training step)

app = FastAPI()

@app.get("/test_images")
async def test_images():
    # Load test dataset
    from torchvision import datasets
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=True)

    # Get first batch of images
    images, labels = next(iter(test_loader))

    fig, axs = plt.subplots(2, 5, figsize=(12, 6))
    fig.suptitle('Sample Test Images with Predictions')

    for i, (image, label) in enumerate(zip(images, labels)):
        output = model(image.unsqueeze(0).to(device))
        _, predicted = torch.max(output.data, 1)
        
        ax = axs[i // 5, i % 5]
        ax.imshow(image.squeeze().numpy(), cmap='gray')
        ax.set_title(f'Pred: {predicted.item()}, True: {label.item()}')
        ax.axis('off')

    plt.tight_layout()
    
    # Save the plot to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        plt.savefig(tmp.name, format='png')
        tmp_path = tmp.name

    plt.close()  # Close the plot to free up memory
    
    return FileResponse(tmp_path, media_type="image/png", filename="test_images.png")
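
To try it locally, you can run the server with uvicorn and fetch the image (assuming the file is main.py and uvicorn's default port):

uvicorn main:app --reload
curl http://localhost:8000/test_images --output test_images.png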

Now I had a trained and working neural network and I was ready to deploy it. Here's where all the fun began.

First challenge: Vercel's Python support is immature

Vercel's beta support for the Python runtime in backend functions is very exciting. I believe the ability to deploy a mixed app with a Next.js frontend and a Python backend has huge potential.

The Python ecosystem is rich with popular machine learning libraries, utilities, and datasets. Meanwhile, JavaScript excels at charts, graphs, data visualizations, games, and other interactive representations of complex data models.

I'm personally very excited to build a ton of example applications with Python backends and TypeScript frontends on Vercel. But we're not quite there yet.

Unfortunately, Vercel's docs for the Python runtime are sparse, the examples are light, and most of the new concepts are not sufficiently explained. You have to read through the Python Vercel templates to understand how everything fits together.

Errors are also opaque. The showstopper for getting my Python backend deployed successfully on Vercel was an unintuitive error message: data too long.

I suspected that PyTorch and torchvision were blowing out the 4.5MB size limit on serverless functions, but there wasn't a great way to confirm this. My model/weights file was just under 400KB.

Documentation you'd expect to find, such as how to work around common bundling issues or how to deploy a Python backend with PyTorch and a custom model, is sorely missing.

You find extremely important details like this hiding out in four-year-old GitHub issue comment threads:

Converting to ONNX and trying again

ONNX (Open Neural Network Exchange) is a fascinating project that defines a common language for machine learning models, allowing cloud providers and developers to more easily write and deploy them.

You can convert a PyTorch model to ONNX using the torch.onnx.export function. ONNX can sometimes reduce the size of the exported model, so I decided to give it a shot.
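
A minimal sketch of the export, reusing the SimpleNN class and checkpoint from the training step:

import torch

# Load the trained weights saved earlier
model = SimpleNN()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# The exporter traces the model with a dummy input of the right shape
dummy_input = torch.randn(1, 1, 28, 28)  # one grayscale 28x28 image
torch.onnx.export(
    model,
    dummy_input,
    "mnist_model.onnx",
    input_names=["image"],
    output_names=["logits"],
)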

I was also happy to take any excuse to play around with ONNX after reading about it. After converting the model to the ONNX format, I tested everything again locally to ensure the app still functioned as intended. But the potential size savings from ONNX didn't make a difference in my case, and I got the same error.

Finding a Python backend host in Modal.com

In researching platforms that host your Python code as cloud functions, I found modal.com, which was lightweight and fast to set up.

Modal.com dashboard

I ended up converting the local version of my FastAPI Python backend into a file I named modal_function.py, which I then deployed like so:

modal deploy modal_function.py
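
The shape such a file takes looks roughly like this minimal sketch (the names and packages here are illustrative, not the exact contents of modal_function.py):

import modal

app = modal.App("mnist-mind-mapper")  # illustrative app name

# Build a container image with the backend's dependencies
image = modal.Image.debian_slim().pip_install("fastapi", "torch", "torchvision")

@app.function(image=image)
@modal.asgi_app()
def web():
    # Construct the FastAPI app inside the Modal function so it runs in the container
    from fastapi import FastAPI

    api = FastAPI()

    @api.get("/health")
    def health():
        return {"ok": True}

    return api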

Of course, my preference would be for the backend to be hosted on Vercel, so that I could colocate the frontend and backend code and have a simpler deployment model, but modal.com ended up being exactly what I wanted in the absence of that.

Modal.com lets you sign up quickly, install their CLI, and deploy your Python code as a serverless function, a REST API, or an ephemeral testing endpoint. So far, it's been great.

There is the issue of cold starts, because Modal will spin down your container when your service is not receiving traffic, but overall I've been impressed with the developer experience.