Building a Hand-Drawn Digit Recognizer with PyTorch and MNIST
Table of contents
- Demo of the "MNIST mind-mapper" app in action
- Overview
- Open source code
- App flow diagram
- Step 1: Model training
- First challenge: Vercel's Python support is immature
- Converting to ONNX and trying again
- Finding a Python backend host in Modal.com
Demo of the "MNIST mind-mapper" app in action
You can watch this short demo to see how the app works:
Overview
I've been an application and infrastructure developer for the majority of my career, so I wanted to get hands-on with training and deploying a neural network.
I also wanted to wrap the trained neural network in a REST API so that I could build a frontend that lets folks play with it; interacting with something is more engaging than reading a text description of it.
I knew it would be important to go beyond a working neural net, because issues often arise at the seams where system components fit together.
This article details the steps I took, and the many issues I encountered along the way, in building and successfully deploying my original vision.
Open source code
I open source most of my work, and this project is no exception:
App flow diagram
Let's step through how the app works, end to end:
The frontend exposes a small drawable canvas, which the user scribbles on.
At a regular interval, the frontend captures what the user drew, using the toDataURL method:
/**
 * Returns the content of the current canvas as an image that you can use as a source for another canvas or an HTML element.
 * @param type The standard MIME type for the image format to return. If you do not specify this parameter, the default value is a PNG format image.
 *
 * [MDN Reference](https://developer.mozilla.org/docs/Web/API/HTMLCanvasElement/toDataURL)
 */
toDataURL(type?: string, quality?: any): string;
This image is sent to the backend API, which wraps the trained neural network. The backend runs inference on the image, and returns the predicted digit, which the frontend displays.
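On the backend side, that data URL has to be turned back into the 28x28 grayscale shape the model was trained on before inference can run. Here's a minimal sketch of that preprocessing, assuming Pillow and NumPy and reusing the normalization constants from the training step (the helper name is illustrative):

import base64
import io

import numpy as np
from PIL import Image

def data_url_to_model_input(data_url: str) -> np.ndarray:
    """Decode a canvas data URL into a 1x1x28x28 float32 array, MNIST-style."""
    # Strip the "data:image/png;base64," prefix and decode the payload
    _, b64_payload = data_url.split(",", 1)
    png_bytes = base64.b64decode(b64_payload)

    # Downscale to a 28x28 grayscale image, matching MNIST's input shape
    image = Image.open(io.BytesIO(png_bytes)).convert("L").resize((28, 28))

    # Depending on your canvas colors you may need to invert here:
    # MNIST digits are white strokes on a black background.
    pixels = np.asarray(image, dtype=np.float32) / 255.0

    # Apply the same mean/std normalization used during training
    pixels = (pixels - 0.1307) / 0.3081
    return pixels.reshape(1, 1, 28, 28)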
Step 1: Model training
The following code is all that's needed with PyTorch to:
- Define a simple neural network architecture
- Load the MNIST dataset
- Train the neural network
- Save the model to disk so that it can be loaded and reused later
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Create a simple neural net
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 2. Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 3. Train the neural network
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

# Test the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f"Accuracy on test set: {100 * correct / total:.2f}%")

# Save the model
torch.save(model.state_dict(), "mnist_model.pth")
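Because only the state_dict is saved, any process that reloads the model later (such as the API server below) has to rebuild the architecture first. A quick sketch of that reload:

# In a fresh process: rebuild the architecture, then load the trained weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)
model.load_state_dict(torch.load("mnist_model.pth", map_location=device))
model.eval()  # switch to inference mode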
With training complete, I wanted to quickly sanity-check the trained network's performance on the intended task, so I created a simple FastAPI server that exposes a /test_images route.
You can view the entire main.py file in the repository, but we'll examine the route itself here. It runs a self-test and returns the output as a PNG image, like so:
import tempfile

import matplotlib.pyplot as plt
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse

# model, device, and transform are defined earlier in main.py,
# the same way they are in the training script above

app = FastAPI()

@app.get("/test_images")
async def test_images():
    # Load test dataset
    from torchvision import datasets
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=True)

    # Get first batch of images
    images, labels = next(iter(test_loader))

    # Plot a 2x5 grid of digits with the model's prediction for each
    fig, axs = plt.subplots(2, 5, figsize=(12, 6))
    fig.suptitle('Sample Test Images with Predictions')
    for i, (image, label) in enumerate(zip(images, labels)):
        output = model(image.unsqueeze(0).to(device))
        _, predicted = torch.max(output.data, 1)
        ax = axs[i // 5, i % 5]
        ax.imshow(image.squeeze().numpy(), cmap='gray')
        ax.set_title(f'Pred: {predicted.item()}, True: {label.item()}')
        ax.axis('off')
    plt.tight_layout()

    # Save the plot to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
        plt.savefig(tmp.name, format='png')
        tmp_path = tmp.name
    plt.close()  # Close the plot to free up memory

    return FileResponse(tmp_path, media_type="image/png", filename="test_images.png")
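To check the output without opening a browser, you can hit the route with FastAPI's TestClient and save the PNG locally. A quick sketch, assuming the app above is importable from main:

from fastapi.testclient import TestClient

from main import app  # the FastAPI app defined above

# Request the self-test route and save the returned PNG for inspection
client = TestClient(app)
response = client.get("/test_images")
assert response.status_code == 200

with open("test_images.png", "wb") as f:
    f.write(response.content)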
Now I had a trained and working neural network and I was ready to deploy it. Here's where all the fun began.
First challenge: Vercel's Python support is immature
Vercel's beta support for the Python runtime in backend functions is very exciting. I believe the ability to deploy a mixed app with a Next.js frontend and a Python backend has huge potential.
The Python ecosystem is rich with popular machine learning libraries, utilities, and datasets. Meanwhile, JavaScript excels at charts, graphs, data visualizations, games, and other interactive representations of complex data models.
I'm personally very excited to build a ton of example applications with Python backends and TypeScript frontends on Vercel. But we're not quite there yet.
Unfortunately, Vercel's docs for the Python runtime are very sparse: the examples are thin, and most of the new concepts are not sufficiently explained. You have to read through the Python Vercel templates to understand how everything fits together.
Errors are also opaque. The showstopper for getting my Python backend deployed successfully on Vercel was an unintuitive error message: data too long.
I was pretty sure that PyTorch and torchvision were blowing out the 4.5MB size limit on serverless functions, but there wasn't a great way to confirm this. My model/weights file itself was just under 400KB.
Docs for things you'd expect to be extensively covered, such as how to work around common bundling issues or how to deploy a Python backend with PyTorch and a custom model, are also sorely needed.
You find extremely important details like this hiding out in four-year-old GitHub issue comment threads:
Converting to ONNX and trying again
ONNX (which stands for Open Neural Network Exchange) is a fascinating project that defines a common language for machine learning models, allowing cloud providers and developers to more easily write and deploy them.
You can convert a PyTorch model to ONNX using the torch.onnx.export function. ONNX can sometimes reduce the size of the exported model, so I decided to give it a shot.
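Here's roughly what that conversion looks like for this model; a minimal sketch, assuming the SimpleNN class and the mnist_model.pth checkpoint from the training step:

import torch

# Rebuild the architecture and load the trained weights
model = SimpleNN()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Export by tracing the model with a dummy input of the right shape
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model,
    dummy_input,
    "mnist_model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},  # allow variable batch sizes
)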
I was also happy to take any excuse to play around with ONNX after reading about it. After converting the model to the ONNX format, I tested everything again locally to ensure the app still functioned as intended. Unfortunately, the potential size savings from ONNX did not materialize in my case, and I got the same error.
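For that local testing step, one option is to load the exported model with onnxruntime and confirm it still predicts sensibly; a minimal smoke-test sketch, reusing the "input" name from the export sketch above:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("mnist_model.onnx")

# Any 1x1x28x28 float32 tensor works for a quick smoke test
sample = np.random.randn(1, 1, 28, 28).astype(np.float32)
(logits,) = session.run(None, {"input": sample})
print("Predicted digit:", int(logits.argmax()))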
Finding a Python backend host in Modal.com
In researching platforms that host your Python code as cloud functions, I found modal.com, which was lightweight and fast to set up.
I ended up converting the local version of my FastAPI Python backend to a file I named modal_function.py, which I then deployed like so:
modal deploy modal_function.py
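The full modal_function.py is in the repository, but the overall shape Modal expects looks roughly like this; a sketch using Modal's ASGI integration, with an illustrative app name and dependency list:

import modal

# Container image with the backend's dependencies (an illustrative list)
image = modal.Image.debian_slim().pip_install("fastapi", "onnxruntime", "numpy", "pillow")

app = modal.App("mnist-mind-mapper", image=image)

@app.function()
@modal.asgi_app()
def fastapi_app():
    # Imports run inside the remote container, where the image's packages exist
    from fastapi import FastAPI

    web_app = FastAPI()

    @web_app.get("/health")
    async def health():
        return {"status": "ok"}

    return web_app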
Of course, my preference would be for the backend to be hosted on Vercel, so that I could colocate the frontend and backend code and have a simpler deployment model, but modal.com ended up being exactly what I wanted in the absence of that.
Modal.com lets you sign up quickly, install their CLI, and deploy your Python code as a serverless function, a deployed REST API, or an ephemeral testing endpoint. So far, it's been great.
There is the issue of cold starts, because Modal will spin down your container when your service is not receiving traffic, but overall I've been impressed with the developer experience.