18 Tips for Training your own Tensorflow.js Models in the Browser

Training efficient Image Classifiers and Object Detectors for the Web with Tensorflow.js

Vincent Mühler
ITNEXT


After porting existing models for object detection, face detection, face recognition and more to tensorflow.js, I found that some of these models did not perform optimally, while others ran pretty well in the browser. This is actually kind of astonishing if you think about the potential of in-browser machine learning and all the possibilities libraries such as tensorflow.js offer to us web developers.

However, with deep learning models running directly in the browser, we are also facing new challenges, as well as the limitations of existing models that might not have been designed to run client side in a browser, let alone in a mobile browser. Just take state of the art object detectors as an example: they usually require a substantial amount of computing resources to run at reasonable fps, not to mention at realtime speed. Furthermore, it is simply not feasible to ship 100MB+ of model weights to a client’s browser in a simple web application.

Training efficient Deep Learning Models for the Web

But hope dies last! By considering some basic principles, we are able to build and train fairly decent models that are optimized for running in a web environment. Believe it or not: we can actually train pretty decent image classification and even object detection models, which end up being just a few megabytes or even just a few kilobytes in size.

In this article I want to give you some general tips to get started with training your own convolutional neural network (CNN), but also some tips, which are directly targeted at training a CNN for the web and mobile devices in the browser with tensorflow.js.

Now you might wonder: Why should I train my models with tensorflow.js in the browser, when I could simply train them with tensorflow on my machine? Certainly you could do that, provided that your machine is equipped with an NVIDIA card. One huge advantage of in-browser deep learning frameworks, which utilize WebGL under the hood, is that you don’t need an NVIDIA GPU to train a model. After discovering tensorflow.js, I was able to train deep learning models on my AMD GPU for the very first time.

So, if your machine is equipped with an NVIDIA card, you can simply go for the standard tensorflow approach (in this case you can either write your training code in Python, or stick with tfjs and use the Node.js bindings, i.e. tfjs-node or its CUDA-enabled variant tfjs-node-gpu) and ignore the browser specific tips. But now, let’s get started!

Network Architecture

Before we can train our own image classifier, object detector or whatever, we obviously have to implement a network architecture first. It is often recommended to pick an existing architecture that has proven to work well, such as Yolo, SSD, ResNet, MobileNet, etc.

Personally, I think it is valuable to use some of the concepts employed by those architectures in your own architecture. However, as I initially pointed out, simply adopting these architectures won’t cut it for the web in my opinion, since we want our models to be small in size, fast at inference (ideally realtime) and as easy to train as possible.

No matter whether you want to adapt an existing architecture or start completely from scratch, I want to give you the following tips, which helped me a lot in designing efficient CNN architectures for the web:

1. Start off with a small Network Architecture!

Keep in mind: the smaller our network can be while still achieving good accuracy on our problem, the faster it will run at inference time and the easier it will be for a client to download and cache the model. Furthermore, smaller models have fewer parameters and thus converge faster at training time.

If your current network architecture does not perform very well, or does not reach the level of accuracy you would like, you can still incrementally increase the size of your network, e.g. by increasing the number of convolutional filters in each layer or by making your network deeper by stacking more layers.

2. Employ Depthwise Separable Convolutions!

Since we are training a new model anyway, we definitely want to use depthwise separable convolutions instead of plain 2D convolutions. Depthwise separable convolutions split the regular convolution operation into a depthwise convolution followed by a pointwise (1x1) convolution. Compared to a regular convolution, they have fewer parameters, which results in far fewer floating point operations, and they are easier to parallelize. This means inference will be much faster (I have even seen speed-ups of up to 10x at inference by simply replacing regular convolutions with depthwise separable ones) and less resource consuming (which can considerably boost performance on mobile devices). Furthermore, because they have fewer parameters, they take less time to train.

The idea of depthwise separable convolutions is employed in MobileNet and Xception, and you can find them in the tensorflow.js models for MobileNet and PoseNet, for example. Whether or not depthwise separable convolutions result in less accurate models is probably an open debate, but from my experience they are definitely the way to go for web (and mobile) models.

Long story short: I would recommend using a regular conv2d operation in your very first layer, which usually does not have many parameters anyway, to preserve the relations between the RGB channels in the extracted features.

For the rest of the convolutions, simply go with depthwise separable convolutions. Instead of a single kernel, each of these layers then has a 3 x 3 x channels_in x 1 depthwise filter and a 1 x 1 x channels_in x channels_out pointwise filter.

So instead of using tf.conv2d with a kernel having a shape of [3, 3, 32, 64], we would simply use tf.separableConv2d using a depthwise kernel with a shape of [3, 3, 32, 1] as well as a [1, 1, 32, 64] shaped pointwise kernel.
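To put numbers on this, here is a minimal sketch that compares the weight counts of both kernel layouts (biases ignored) and runs tf.conv2d and tf.separableConv2d on a random 16 x 16 x 32 input. The helpers conv2dParams and separableConv2dParams are made up for illustration, and the input size is arbitrary.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical helper: weights of a regular k x k convolution (biases ignored).
function conv2dParams(k: number, cIn: number, cOut: number): number {
  return k * k * cIn * cOut; // one k x k x cIn kernel per output channel
}

// Hypothetical helper: weights of a depthwise separable convolution (biases ignored).
function separableConv2dParams(k: number, cIn: number, cOut: number): number {
  const depthwise = k * k * cIn * 1; // one k x k filter per input channel
  const pointwise = 1 * 1 * cIn * cOut; // 1x1 convolution mixing the channels
  return depthwise + pointwise;
}

// Random input of shape [batch, height, width, channels].
const x = tf.randomNormal([1, 16, 16, 32]) as tf.Tensor4D;

// Regular convolution with a [3, 3, 32, 64] kernel.
const kernel = tf.randomNormal([3, 3, 32, 64]) as tf.Tensor4D;
const regular = tf.conv2d(x, kernel, 1, 'same');

// Depthwise separable convolution: [3, 3, 32, 1] depthwise + [1, 1, 32, 64] pointwise kernel.
const depthwiseKernel = tf.randomNormal([3, 3, 32, 1]) as tf.Tensor4D;
const pointwiseKernel = tf.randomNormal([1, 1, 32, 64]) as tf.Tensor4D;
const separable = tf.separableConv2d(x, depthwiseKernel, pointwiseKernel, 1, 'same');

console.log(regular.shape, separable.shape); // [1, 16, 16, 64] [1, 16, 16, 64]
console.log(conv2dParams(3, 32, 64), separableConv2dParams(3, 32, 64)); // 18432 2336
```

For this layer, the separable version needs roughly 8x fewer weights while producing an output of the same shape.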

3. Skip Connections and Densely Connected Blocks

Once I decided to build deeper networks, I was quickly facing one of the most common problems of training a neural network: the vanishing gradient problem. After some epochs, the loss would only decrease in very tiny steps, which would either result in ridiculously long training times, or cause the model to not converge at all.

Skip connections, which are employed in ResNet and DenseNet, allow us to build deeper architectures while mitigating the vanishing gradient problem. All we have to do is add the output of previous layers to the input of layers located deeper in our network, before the activation function is applied:

Skip Connection

Skip connections work because, by connecting layers via shortcuts, the network can at least learn the identity function. The intuition behind this technique is that gradients do not have to be backpropagated solely through convolutional (or fully connected) layers, which cause gradients to diminish by the time they reach the earlier layers of the network. Instead, they can “skip” layers through the addition operation of the skip connection.

Obviously, for this to work, if you want to connect a layer A with a layer B, the output shape of A has to match the input shape of B. If you want to build residual or densely connected blocks, simply make sure to keep the same number of filters across the convolutions in that block and use a stride of 1 with same padding. As a side note, there are also other approaches, which either pad the output of A such that it matches the input shape of B, or which concatenate the feature maps of previous layers along the channel axis instead of adding them, in which case only the spatial dimensions have to match.
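A minimal sketch of a residual block built from two depthwise separable convolutions, assuming stride 1 and same padding so that the block input can be added to the output of the second convolution before the final ReLU. The SepConvParams interface and the randomParams helper are hypothetical; in a real model the weights would be trainable variables loaded from a checkpoint.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical parameter layout of one depthwise separable convolution.
interface SepConvParams {
  depthwise: tf.Tensor4D; // [3, 3, channels, 1]
  pointwise: tf.Tensor4D; // [1, 1, channels, channels]
  bias: tf.Tensor1D; // [channels]
}

// Stride 1 and 'same' padding keep height, width and depth unchanged.
function sepConv(x: tf.Tensor4D, p: SepConvParams): tf.Tensor4D {
  const y = tf.separableConv2d(x, p.depthwise, p.pointwise, 1, 'same');
  return tf.add<tf.Tensor4D>(y, p.bias);
}

// Two convolutions plus a skip connection: the block input x is added to the
// output of the second convolution *before* the activation is applied.
function residualBlock(x: tf.Tensor4D, p1: SepConvParams, p2: SepConvParams): tf.Tensor4D {
  return tf.tidy(() => {
    const hidden = tf.relu(sepConv(x, p1));
    const out = sepConv(hidden, p2);
    return tf.relu(tf.add<tf.Tensor4D>(out, x)); // out and x must have the same shape
  });
}

// Hypothetical helper creating random weights for a quick shape check.
function randomParams(channels: number): SepConvParams {
  return {
    depthwise: tf.randomNormal([3, 3, channels, 1], 0, 0.1) as tf.Tensor4D,
    pointwise: tf.randomNormal([1, 1, channels, channels], 0, 0.1) as tf.Tensor4D,
    bias: tf.zeros([channels]) as tf.Tensor1D,
  };
}

const input = tf.randomNormal([1, 8, 8, 16]) as tf.Tensor4D;
const output = residualBlock(input, randomParams(16), randomParams(16));
console.log(output.shape); // [1, 8, 8, 16]
```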

At first, I was fiddling around with a ResNet-like approach, simply introducing a skip connection between every other layer as shown in the above image, but I soon figured out that densely connected blocks work even better and immensely decrease the time required for the model to converge:

Sketch of a Denseblock

I used such dense blocks as the basic building block of the 68 point face landmark detector of face-api.js. Each block consists of 4 depthwise separable convolutional layers (note that the first convolution of the very first dense block is a regular convolution), and the first convolution of each block uses a stride of 2 to scale the input down.
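The structure described above can be sketched as follows. This is a minimal reconstruction for illustration, not the face-api.js source: it assumes all layers of the block output the same number of channels, combines previous outputs by addition (rather than DenseNet-style concatenation), and leaves out the regular convolution of the very first block. SepConvParams and randomParams are hypothetical helpers.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical parameter layout of one depthwise separable convolution.
interface SepConvParams {
  depthwise: tf.Tensor4D; // [3, 3, cIn, 1]
  pointwise: tf.Tensor4D; // [1, 1, cIn, cOut]
  bias: tf.Tensor1D; // [cOut]
}

function sepConv(x: tf.Tensor4D, p: SepConvParams, stride: number): tf.Tensor4D {
  const y = tf.separableConv2d(x, p.depthwise, p.pointwise, stride, 'same');
  return tf.add<tf.Tensor4D>(y, p.bias);
}

// Dense block with 4 separable convolutions: the first one downsamples with
// stride 2, every later layer gets the ReLU of the sum of all previous outputs.
function denseBlock(x: tf.Tensor4D, p: SepConvParams[]): tf.Tensor4D {
  return tf.tidy(() => {
    const out1 = sepConv(x, p[0], 2);
    const out2 = sepConv(tf.relu(out1), p[1], 1);
    const out3 = sepConv(tf.relu(tf.addN([out1, out2])), p[2], 1);
    const out4 = sepConv(tf.relu(tf.addN([out1, out2, out3])), p[3], 1);
    return tf.relu(tf.addN([out1, out2, out3, out4]));
  });
}

// Hypothetical helper creating random weights for a quick shape check.
function randomParams(cIn: number, cOut: number): SepConvParams {
  return {
    depthwise: tf.randomNormal([3, 3, cIn, 1], 0, 0.1) as tf.Tensor4D,
    pointwise: tf.randomNormal([1, 1, cIn, cOut], 0, 0.1) as tf.Tensor4D,
    bias: tf.zeros([cOut]) as tf.Tensor1D,
  };
}

const params = [randomParams(16, 32), randomParams(32, 32), randomParams(32, 32), randomParams(32, 32)];
const blockInput = tf.randomNormal([1, 32, 32, 16]) as tf.Tensor4D;
const blockOutput = denseBlock(blockInput, params);
console.log(blockOutput.shape); // [1, 16, 16, 32]
```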

4. Use ReLU Type Activation Functions!

Unless you have a specific reason to use another type of activation function, I would simply go with tf.relu, the reason being that ReLU type activation functions help mitigate the problem of vanishing gradients.

You can also experiment with variations of ReLU, such as leaky ReLU, which is utilized in the Yolo architecture, or ReLU-6, as employed by MobileNet.
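Both variants are one-liners in tensorflow.js. The sketch below applies them to a small tensor; the leaky ReLU slope of 0.1 is the value used by Darknet’s Yolo models, and ReLU-6 is written with tf.clipByValue, which caps the output of a ReLU at 6.

```typescript
import * as tf from '@tensorflow/tfjs';

const x = tf.tensor1d([-2, -0.5, 0, 3, 8]);

// Leaky ReLU: x for x > 0, alpha * x otherwise (alpha = 0.1 here).
const leaky = tf.leakyRelu(x, 0.1);

// ReLU-6: min(max(x, 0), 6).
const relu6 = tf.clipByValue(x, 0, 6);

console.log(Array.from(leaky.dataSync())); // approximately [-0.2, -0.05, 0, 3, 8]
console.log(Array.from(relu6.dataSync())); // [0, 0, 0, 3, 6]
```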

Training

Once we have come up with an initial architecture, we can start training our model.

5. If in doubt, simply use Adam Optimizer!

When I first started training my own models, I was wondering which optimizer is best. I started off using plain SGD, which sometimes seemed to get stuck in local minima or even resulted in exploding gradients, causing the model weights to grow without bound and eventually resulting in NaNs.

I am not saying that Adam is the best option for all problems, but I found it to be the easiest and most robust way to train a new model: simply start off using Adam with default parameters and a learning rate of 0.001.
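A minimal sketch of this setup, using a toy problem (a single weight w that should learn y = 2x) in place of a real network. tf.train.adam uses beta1 = 0.9 and beta2 = 0.999 by default, so only the learning rate is passed explicitly.

```typescript
import * as tf from '@tensorflow/tfjs';

// Adam with default betas/epsilon and a learning rate of 0.001.
const optimizer = tf.train.adam(0.001);

// Toy problem for illustration: learn w such that w * x ≈ 2 * x.
const w = tf.variable(tf.scalar(0));
const xs = tf.tensor1d([0.5, 1, 1.5, 2]);
const ys = tf.mul(xs, 2);

const losses: number[] = [];
for (let step = 0; step < 100; step++) {
  // minimize() computes the gradients of the returned scalar with respect to the
  // trainable variables it depends on (here: w) and applies one Adam update;
  // passing `true` makes it return the loss tensor.
  const loss = optimizer.minimize(
    () => tf.losses.meanSquaredError(ys, tf.mul(w, xs)) as tf.Scalar,
    true,
  );
  if (loss !== null) {
    losses.push(loss.dataSync()[0]); // synchronous read, fine for a toy example
    loss.dispose();
  }
}
console.log(losses[0], losses[losses.length - 1]);
```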

6. Adapting the Learning Rate

Once the loss stops decreasing considerably, chances are that our model has converged (or got stuck) and is not able to learn much more. At that point we might as well stop the training process to prevent our model from overfitting (or try a different architecture).

However, you might also be able to squeeze some more out of the training process by decreasing the learning rate at that point. Especially if the overall loss computed over the training set starts to oscillate (jump up and down), this is an indicator that decreasing the learning rate might be a good idea.

For example, while training the 68 point face landmark model, the loss started to oscillate at epoch 46. By continuing training from the checkpoint of epoch 46 for 10 more epochs with a learning rate of 0.0001 instead of 0.001, I was able to drive the overall error down even further.
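If you want to automate the decision instead of looking at the loss plot, one possible heuristic is to count how often the epoch loss changes direction over the last few epochs. The isOscillating function and its thresholds below are hypothetical choices for illustration. The sketch simply creates a new Adam optimizer with a 10x smaller learning rate when resuming from a checkpoint.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical heuristic: the epoch loss "oscillates" if, within the last
// `windowSize` epochs, it changed direction at least `minFlips` times.
function isOscillating(epochLosses: number[], windowSize = 6, minFlips = 3): boolean {
  const recent = epochLosses.slice(-windowSize);
  let flips = 0;
  for (let i = 2; i < recent.length; i++) {
    const previousDelta = recent[i - 1] - recent[i - 2];
    const delta = recent[i] - recent[i - 1];
    if (previousDelta * delta < 0) flips++; // down -> up or up -> down
  }
  return flips >= minFlips;
}

// Optimizer for continuing from a checkpoint with a smaller learning rate,
// e.g. 0.001 -> 0.0001. Note: a new Adam instance starts with fresh moment estimates.
function optimizerForResume(previousLearningRate: number, factor = 10) {
  return tf.train.adam(previousLearningRate / factor);
}

console.log(isOscillating([1, 0.8, 0.6, 0.5, 0.45, 0.42])); // false
console.log(isOscillating([0.5, 0.4, 0.45, 0.41, 0.46, 0.4])); // true
```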

7. Weight Initialization

If you have no clue how to properly initialize your model weights (just like I had no idea when I got started), here is a simple rule of thumb: initialize all your biases with zeros (tf.zeros(shape)) and your weights (kernels of convolutions and weights of fully connected layers) with non-zero values drawn from some kind of normal distribution. For example, you could simply use tf.randomNormal(shape), but nowadays I prefer to use a Glorot normal distribution, which is available in tfjs-layers as tf.initializers.glorotNormal.
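A minimal sketch of this rule of thumb, assuming a conv kernel of shape [3, 3, 32, 64] and a fully connected layer with 256 inputs and 10 outputs (both shapes are arbitrary). tf.initializers.glorotNormal returns an initializer whose apply(shape) method creates the tensor, with a variance scaled by the fan-in and fan-out of the layer; wrapping the results in tf.variable makes them trainable.

```typescript
import * as tf from '@tensorflow/tfjs';

// Glorot (Xavier) normal initializer from tfjs-layers; {} means no fixed seed.
const glorotNormal = tf.initializers.glorotNormal({});

// Convolution: kernel [filterHeight, filterWidth, channelsIn, channelsOut], bias [channelsOut].
const convKernel = tf.variable(glorotNormal.apply([3, 3, 32, 64]));
const convBias = tf.variable(tf.zeros([64]));

// Fully connected layer: weights [inputUnits, outputUnits], bias [outputUnits].
const fcWeights = tf.variable(glorotNormal.apply([256, 10]));
const fcBias = tf.variable(tf.zeros([10]));
```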

8. Shuffle your Inputs!

A common piece of advice for training a neural network is to randomize the order of your training samples by shuffling them at the beginning of each epoch. Conveniently, we can use tf.util.shuffle for that purpose, which shuffles an arbitrary array in place.
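A minimal sketch of shuffling once per epoch; the Sample type and the file names are hypothetical placeholders for your own training data.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical training sample: an image file name plus its label.
interface Sample {
  imageUri: string;
  label: number;
}

const samples: Sample[] = Array.from({ length: 10 }, (_, i) => ({ imageUri: `img_${i}.jpg`, label: i % 2 }));

for (let epoch = 0; epoch < 3; epoch++) {
  // Shuffles the array in place (Fisher-Yates) and returns nothing.
  tf.util.shuffle(samples);
  // ... iterate over `samples` and run the training steps of this epoch here
}
```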

9. Saving Model Checkpoints using FileSaver.js

Since we are training our model in the browser, you may now ask yourself: How do we automatically save checkpoints of our model weights while training? We simply use FileSaver.js. The script exposes a function called saveAs, which we can use to store arbitrary types of files, which will end up in our downloads folder.

This way we can save our model weights as binary files, or even JSON files, for example to store the accumulated losses of an epoch.
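A minimal sketch of both cases, assuming FileSaver.js is included via a script tag, which exposes a global saveAs(blob, filename). The helper names and the file name scheme are hypothetical, and the flattened weights file does not store tensor shapes, so the loading code has to know them.

```typescript
import * as tf from '@tensorflow/tfjs';

// FileSaver.js, included via a <script> tag, exposes a global saveAs(blob, filename).
declare const saveAs: (data: Blob, filename: string) => void;

// Concatenates the values of all weight tensors (in a fixed order) into one Float32Array.
async function serializeWeights(weights: tf.Tensor[]): Promise<Float32Array> {
  const arrays = await Promise.all(weights.map(w => w.data()));
  const total = arrays.reduce((sum, a) => sum + a.length, 0);
  const flat = new Float32Array(total);
  let offset = 0;
  for (const a of arrays) {
    flat.set(a, offset);
    offset += a.length;
  }
  return flat;
}

// Downloads the weights as a binary checkpoint file.
async function saveCheckpoint(weights: tf.Tensor[], epoch: number): Promise<void> {
  const flat = await serializeWeights(weights);
  saveAs(new Blob([flat]), `checkpoint_epoch${epoch}.weights`);
}

// Downloads arbitrary JSON, e.g. the losses accumulated during an epoch.
function saveJson(data: unknown, filename: string): void {
  saveAs(new Blob([JSON.stringify(data)], { type: 'application/json' }), filename);
}
```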

Troubleshooting

Before spending a lot of time on training your model, you want to make sure that your model is actually learning what it is supposed to and eliminate any potential source of errors and bugs. If you do not consider the following tips, you might end up wasting your time training complete garbage and wondering why your model is not learning anything.

10. Check your Input Data, Pre- and Post Processing Logic!

If you pass garbage into your network, it will throw garbage back at you. Thus, make sure your input data is labeled correctly and that your network inputs are what you expect them to be. Especially if you have implemented some preprocessing logic like random cropping, padding, squaring, centering, mean subtraction or whatever else, make sure to visualize your inputs after preprocessing. I would also highly recommend unit testing these steps. The same goes for post processing, of course!

I know this sounds like a tedious amount of extra work, but it is definitely worth it! You won’t believe how many hours I spent trying to figure out why the heck my object detector did not learn to detect faces at all, until I eventually discovered that my preprocessing logic was turning the inputs into trash due to incorrect cropping and distortion.

11. Check your Loss Function!

In most cases, tensorflow.js luckily provides the loss function you need. However, in case you need to implement your own loss function, you should definitely unit test it! A while ago, I implemented the Yolo v2 loss function from scratch using the tfjs-core API to train Yolo object detectors for the web. Let me tell you that this can get very hairy, unless you break down the problem and make sure the individual components compute what they are supposed to.

12. Overfit on a small Dataset first!

Generally, it’s a good idea to first overfit on a small subset of your training data, to verify that the loss is converging and that your model is actually learning something useful. To do so, simply pick 10 to 20 images of your training data and train for some epochs. Once the loss converges, run inference on these 10 to 20 images and visualize the results.

This is a very important step, which will help you eliminate all kinds of bugs in the implementation of your network and your pre and post processing logic, as it is unlikely that your model will learn to make the desired predictions with substantial bugs in your code.

Especially if you are implementing your own loss function (tip 11), you definitely want to make sure your model is able to converge before jumping into training it on the full dataset!
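A minimal sketch of such an overfitting check, assuming a hypothetical trainOnSample callback that runs one optimization step on a sample and returns its loss. The toy usage trains a single weight w on y = 2x; for a real model you would inspect the returned per-epoch losses and then visualize the predictions on the same subset.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical helper: trains repeatedly on a small, fixed subset and returns
// the mean loss per epoch, which should go down steadily if everything works.
async function overfitSmallSubset<T>(
  trainingData: T[],
  trainOnSample: (sample: T) => Promise<number>,
  subsetSize = 20,
  epochs = 50,
): Promise<number[]> {
  const subset = trainingData.slice(0, subsetSize);
  const epochLosses: number[] = [];
  for (let epoch = 0; epoch < epochs; epoch++) {
    let total = 0;
    for (const sample of subset) {
      total += await trainOnSample(sample);
    }
    epochLosses.push(total / subset.length);
  }
  return epochLosses;
}

// Toy usage: a single weight w that should learn y = 2 * x.
const w = tf.variable(tf.scalar(0));
const optimizer = tf.train.adam(0.1);
const data = Array.from({ length: 100 }, (_, i) => ({ x: i / 100, y: (2 * i) / 100 }));

const result = overfitSmallSubset(data, async ({ x, y }) => {
  const loss = optimizer.minimize(() => tf.square(tf.sub(tf.mul(w, x), y)) as tf.Scalar, true);
  const value = loss ? (await loss.data())[0] : NaN;
  loss?.dispose();
  return value;
});
result.then(losses => console.log(losses[0], losses[losses.length - 1]));
```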

Performance

Finally, I want to give you some advice that will help you reduce training time as much as possible and prevent your browser from crashing due to memory leaks, by considering some basic principles.

13. Preventing obvious Memory Leaks

Unless you are completely new to tensorflow.js, you probably already know that we have to dispose unused tensors manually to free up their memory, either by calling tensor.dispose() or by wrapping our operations in tf.tidy blocks. Make sure that there are no such memory leaks due to tensors not being disposed correctly, otherwise your application will sooner or later run out of memory.

Identifying these kinds of memory leaks is pretty easy: simply log tf.memory() for a few iterations to verify that the number of tensors does not inadvertently grow with each iteration.
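The sketch below logs tf.memory().numTensors after each iteration, once for a leaky step and once for the same step wrapped in tf.tidy. The helper tensorCountsPerIteration and the matrix multiplication standing in for a training step are made up for illustration.

```typescript
import * as tf from '@tensorflow/tfjs';

// Hypothetical helper: runs `step` a few times and records the number of
// allocated tensors after each iteration. A steadily growing count means a leak.
function tensorCountsPerIteration(step: () => void, iterations = 5): number[] {
  const counts: number[] = [];
  for (let i = 0; i < iterations; i++) {
    step();
    const { numTensors } = tf.memory();
    console.log(`iteration ${i}: ${numTensors} tensors`);
    counts.push(numTensors);
  }
  return counts;
}

const weights = tf.randomNormal([4, 4]);

// Leaky: the result of matMul and the result of sum are never disposed.
const leakyCounts = tensorCountsPerIteration(() => {
  tf.sum(tf.matMul(weights, weights));
});

// Fixed: tf.tidy disposes all tensors created inside the callback.
const tidyCounts = tensorCountsPerIteration(() => {
  tf.tidy(() => {
    tf.sum(tf.matMul(weights, weights));
  });
});
```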

14. Resize your Canvases and not your Tensors!

Note that the following statements only apply to the state of tfjs-core at the time of writing (I am currently using tfjs-core version 0.12.14), until this eventually gets fixed.

I know this might sound a bit strange: why not use tf.resizeBilinear, tf.pad and such to reshape your input tensors to the desired network input shape? There is currently an open issue at tfjs illustrating the problem.

TL;DR: Before calling tf.fromPixels to convert your canvases to tensors, resize your canvases such that they have the size accepted by your network. Otherwise you will quickly run out of GPU memory, depending on the variety of input sizes of the images in your training data. This is less of a problem if your training images all have the same size anyway, but in case you have to resize them explicitly, draw them onto a canvas of the network input size before converting them to tensors.
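A minimal sketch of resizing via a canvas, assuming you want to keep the aspect ratio and center the image on a square canvas (one possible choice; stretching to the target size works too). It uses tf.browser.fromPixels, which is where fromPixels lives in current tfjs versions (in tfjs-core 0.12 it was tf.fromPixels). fitIntoSquare and toNetworkInput are hypothetical helper names, and toNetworkInput only runs in a browser.

```typescript
import * as tf from '@tensorflow/tfjs';

// Size and offset of an image scaled to fit into an inputSize x inputSize square
// while keeping its aspect ratio; the image is centered, the rest stays as padding.
function fitIntoSquare(width: number, height: number, inputSize: number) {
  const scale = inputSize / Math.max(width, height);
  const scaledWidth = Math.round(width * scale);
  const scaledHeight = Math.round(height * scale);
  return {
    x: Math.floor((inputSize - scaledWidth) / 2),
    y: Math.floor((inputSize - scaledHeight) / 2),
    width: scaledWidth,
    height: scaledHeight,
  };
}

// Browser only: draws the image onto a canvas of the network input size and converts
// that canvas, so every resulting tensor has the shape [inputSize, inputSize, 3].
function toNetworkInput(image: HTMLImageElement | HTMLCanvasElement, inputSize: number): tf.Tensor3D {
  const sourceWidth = image instanceof HTMLImageElement ? image.naturalWidth : image.width;
  const sourceHeight = image instanceof HTMLImageElement ? image.naturalHeight : image.height;
  const canvas = document.createElement('canvas');
  canvas.width = inputSize;
  canvas.height = inputSize;
  const ctx = canvas.getContext('2d');
  if (ctx === null) throw new Error('Could not get a 2d canvas context');
  const box = fitIntoSquare(sourceWidth, sourceHeight, inputSize);
  ctx.drawImage(image, box.x, box.y, box.width, box.height);
  return tf.browser.fromPixels(canvas); // 3 channels (RGB) by default
}
```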

15. Figuring out the optimal Batch Size

Don’t go overboard with batching your inputs! Try out different batch sizes and measure the time required for backpropagation. The optimal batch size obviously depends on your GPU specs, the input size and the complexity of your network. In some cases you don’t want to batch your inputs at all.

If in doubt, however, I would always go with a batch size of 1. Personally, I found that in some cases increasing the batch size doesn’t really help performance, but in other cases I could see an overall speedup by a factor of around 1.5–2.0 with batch sizes of 16 to 24, for an input image size of 112 x 112 pixels and a fairly small network.
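A minimal timing harness for comparing batch sizes, assuming a hypothetical trainStep(batchSize) callback that runs one training step and only resolves after the GPU work is done (for instance by awaiting loss.data() at its end), since tensorflow.js queues WebGL work asynchronously. The first run per batch size is treated as warm-up and not timed.

```typescript
// Hypothetical harness: returns the average time per training sample for each batch size.
async function msPerSampleByBatchSize(
  batchSizes: number[],
  trainStep: (batchSize: number) => Promise<void>,
  repeats = 5,
): Promise<Map<number, number>> {
  const msPerSample = new Map<number, number>();
  for (const batchSize of batchSizes) {
    await trainStep(batchSize); // warm-up run, not timed (e.g. shader compilation)
    const start = performance.now();
    for (let i = 0; i < repeats; i++) {
      await trainStep(batchSize);
    }
    const elapsedMs = performance.now() - start;
    // Normalize by the number of processed samples so batch sizes are comparable.
    msPerSample.set(batchSize, elapsedMs / (repeats * batchSize));
  }
  return msPerSample;
}
```

The batch size with the lowest time per sample is the one to pick, and that may well turn out to be 1.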

16. Caching, Offline Storage, Indexeddb

Our training images (and labels) might take up considerable space, maybe 1GB or even more, depending on the size and the number of your images. Since we cannot simply read images from disk in the browser, we would instead use a file proxy, which might be a simple express server, to host our training data, and the browser would fetch every single data item.

Clearly, this is very inefficient, but it is something we have to keep in mind when training in the browser. If your dataset is small enough, you could try to keep your entire data in memory, but that’s obviously not very efficient either. Initially, I tried to increase the browser cache size to simply cache the entire data on disk, but that does not seem to work anymore in later versions of Chrome, and I had no luck with Firefox either.

Finally, I decided to go for IndexedDB, an in-browser database in case you are not familiar with it, which we can utilize to store our entire training and test datasets. Getting started with IndexedDB is quite simple, as we can basically store and query our entire data as key-value stores with only a few lines of code. With IndexedDB we can conveniently store our labels as plain JSON objects and our image data as blobs. Check out this blog post, which nicely explains how to persist image data and other files in IndexedDB.

Querying IndexedDB is quite fast; at least I found it to be way faster to query each data item than to fetch files from the proxy over and over again. Plus, after moving your data into IndexedDB, training technically works completely offline, meaning we might not need the proxy server anymore.

17. Async Loss Reporting

This is a simple yet pretty effective tip, which helped me a lot in reducing iteration times while training. We certainly want to retrieve the values of the loss tensors returned by optimizer.minimize, to keep track of our loss while training. However, we want to avoid awaiting the Promise returned by loss.data() at each iteration, since that makes us wait for CPU and GPU to synchronize. Instead, we simply attach a callback to that Promise, which accumulates the loss value of the iteration once it resolves.

We simply have to keep in mind that our losses are now reported asynchronously, so in case we want to save the overall loss to a file at the end of each epoch, we have to wait for the last promises to resolve before doing so. I usually just hack around this issue by using a setTimeout to save the overall loss value 10 seconds or so after an epoch has finished.
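A minimal sketch of an epoch loop with asynchronous loss reporting. As an alternative to the setTimeout workaround, it collects the pending loss.data() promises and awaits them once at the end of the epoch, so the overall loss is complete when it is returned. trainEpoch and computeLoss are hypothetical names; computeLoss stands in for your model’s forward pass plus loss function.

```typescript
import * as tf from '@tensorflow/tfjs';

// Runs one epoch without awaiting each loss value inside the loop; the reads are
// collected as promises and awaited once after all batches have been processed.
async function trainEpoch<B>(
  optimizer: tf.Optimizer,
  batches: B[],
  computeLoss: (batch: B) => tf.Scalar,
): Promise<number> {
  let totalLoss = 0;
  const pendingReads: Promise<void>[] = [];
  for (const batch of batches) {
    const loss = optimizer.minimize(() => computeLoss(batch), true);
    if (loss === null) continue;
    pendingReads.push(
      loss.data().then(data => {
        totalLoss += data[0];
        loss.dispose(); // dispose only after the value has been read
      }),
    );
  }
  await Promise.all(pendingReads);
  return totalLoss / batches.length;
}
```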

After successfully Training a Model

18. Weight Quantization

Once we are done training our model and are satisfied with its performance, I would recommend shrinking the model size by applying weight quantization. By quantizing our model weights, we can reduce the size of our model to 1/4 of its original size! Reducing the size of our model as much as possible is critical for fast delivery of the model weights to the client application, especially if we can get it basically for free.
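The factor of 4 comes from storing each weight as a uint8 (1 byte) instead of a float32 (4 bytes). A minimal sketch of affine min/max quantization to illustrate the arithmetic; it is not the tensorflowjs converter’s actual implementation, the function names are made up, and this sketch additionally keeps a minimum and a scale per tensor to map the bytes back to floats.

```typescript
// Affine uint8 quantization of one weight tensor (illustrative sketch).
interface QuantizedWeights {
  data: Uint8Array; // 1 byte per weight
  min: number; // value represented by 0
  scale: number; // step between two neighboring uint8 values
}

function quantize(weights: Float32Array): QuantizedWeights {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < weights.length; i++) {
    min = Math.min(min, weights[i]);
    max = Math.max(max, weights[i]);
  }
  // 256 levels between min and max; guard against constant tensors.
  const scale = max > min ? (max - min) / 255 : 1;
  const data = new Uint8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    data[i] = Math.round((weights[i] - min) / scale);
  }
  return { data, min, scale };
}

function dequantize(q: QuantizedWeights): Float32Array {
  const weights = new Float32Array(q.data.length);
  for (let i = 0; i < q.data.length; i++) {
    weights[i] = q.data[i] * q.scale + q.min;
  }
  return weights;
}

const original = new Float32Array([-1, -0.5, 0, 0.25, 1]);
const quantized = quantize(original);
console.log(original.byteLength, quantized.data.byteLength); // 20 5
```

Each restored weight is off by at most half a quantization step (scale / 2), which is the precision traded for the smaller download.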

Thus, make sure to check out my guide about weight quantization with tensorflow.js: Shrink your Tensorflow.js Web Model Size with Weight Quantization.

If you liked this article you are invited to leave some claps and follow me on medium and/or twitter :). Also stay tuned for further articles and if you are interested, check out my open source work!
