TensorFlow is evolving fast, and one thing that has made running distributed TensorFlow code much easier is the Estimator class. Here's our hands-on tutorial on how to use it with distributed MNIST.

Note: The code for this tutorial can be found on our GitHub.

TensorFlow's support for distributed training has evolved a little since our last blog post on this topic. At the time, the Experiment class was a new high-level class that abstracted away a lot of manual code.

In more recent TensorFlow versions, the Experiment class has been deprecated. Instead, the Estimator class is used directly. This is an extremely useful tool: it makes building models a lot easier and saves you time writing and debugging your code.

Another beauty of the Estimator class is that TensorFlow now supports converting any Keras model into an Estimator, speeding up your model development. This is my favorite way of creating Estimators. Unfortunately, there are not many approachable resources to learn how to use this tool, so I will walk you through an MNIST example to write your first Estimator code to run distributed training!

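Writing TensorFlow code with Estimators mainly involves three steps:

  1. define the model and wrap it in an Estimator,

  2. define the input functions that feed it data, and

  3. train and evaluate the Estimator.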

Here is a simple template involving these three steps:

In this blog, I will cover these three steps and give you an example for each step.

Model & Estimator class

There are many ways to create a model with Estimators. The above code snippet shows what may be the most common method: creating your own Estimator by defining a model_fn. In this blog, I will discuss three ways you can create an Estimator:

  1. select a pre-made Estimator,

  2. build a custom Estimator from a TensorFlow graph in model_fn, and

  3. convert a Keras model into an Estimator.

Pre-Made Estimators

The simplest method might be to use one of TensorFlow's pre-made Estimators. For example, the DNNClassifier Estimator would work well with this example.

However, as of today, the n_classes parameter for DNNClassifier is hard-coded as 1, so it does not work for problems involving more than one class (or two, if you count the first class as the “positive” class and the second as the “negative” class). Essentially, it can only do true/false classification until this is fixed. As a result, we will not use the DNNClassifier for our MNIST example.

Build a Custom Estimator from a TensorFlow Graph

The second option is to build an Estimator from a custom TensorFlow graph. To create a custom Estimator, we need to write a model_fn function.

The function takes three arguments: features (the input tensor), labels (the target tensor), and mode (training, evaluation, or prediction). model_fn returns a tf.estimator.EstimatorSpec whose parameters depend on the mode the Estimator is running in.

For example, for training, we should return an EstimatorSpec with the loss and the train op. For the sake of comparison, example code is provided below (note: there is a slightly easier way to do this if you scroll down):
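Here is a minimal sketch of such a model_fn, not the original listing. It assumes TensorFlow 1.x, flattened 784-pixel MNIST images passed under the features key "input", and integer class labels of shape (batch,). The layer sizes and the Adam optimizer are my own illustrative choices.

```python
def model_fn(features, labels, mode):
    """Sketch of a custom model_fn for MNIST (assumes TensorFlow 1.x)."""
    # Local import keeps this sketch loadable without TensorFlow installed;
    # a real script would import it at the top of the file.
    import tensorflow as tf

    # Assumption: features is a dict holding flattened images under "input".
    images = tf.cast(features["input"], tf.float32)
    hidden = tf.layers.dense(images, units=128, activation=tf.nn.relu)
    logits = tf.layers.dense(hidden, units=10)  # one logit per digit

    predictions = {
        "classes": tf.argmax(logits, axis=1),
        "probabilities": tf.nn.softmax(logits),
    }

    # PREDICT: only predictions are required (labels are None here).
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Assumption: labels are integer class ids, not one-hot vectors.
    labels = tf.cast(labels, tf.int32)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # EVAL: loss is required, metrics are optional.
    if mode == tf.estimator.ModeKeys.EVAL:
        metrics = {
            "accuracy": tf.metrics.accuracy(
                labels=labels, predictions=predictions["classes"])
        }
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=metrics)

    # TRAIN: loss and train_op are required.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```

With TensorFlow 1.x installed, this function could be passed as tf.estimator.Estimator(model_fn=model_fn, config=config).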

Phew! If you're not an expert in TensorFlow, that may not have been what you wanted to see. This is actually simpler than the traditional way of writing your own sess.run for all the tensors and ops as we did in our previous blog, but TensorFlow allows for an even simpler way.

Let's break this down slightly before moving on, though. Basically, model_fn is a function that takes three inputs (features, labels, and mode) and returns the behavior of the estimator for three modes (PREDICT, EVAL, and TRAIN). features and labels will be outputs of the input function that we will cover in the next section, and mode is one of tf.estimator.ModeKeys. For example, when you call estimator.train, tf.estimator.ModeKeys.TRAIN will be passed.

The output of model_fn should be an instance of the tf.estimator.EstimatorSpec class. As mentioned in the TensorFlow documentation, different modes require different arguments:

  • For mode == ModeKeys.TRAIN: required fields are loss and train_op.

  • For mode == ModeKeys.EVAL: required field is loss (optional: eval_metric_ops).

  • For mode == ModeKeys.PREDICT: required field is predictions.

Converting a Keras Model into an Estimator

The easiest way of creating an estimator, in my opinion, is writing your model in Keras and then converting it to an Estimator.
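Here is a rough sketch of that conversion, assuming TensorFlow 1.x and one-hot encoded labels. The helper name build_keras_estimator, the layer sizes, and the optimizer are hypothetical choices of mine; the input layer is named "input" so that it can be matched by key in the input function.

```python
def build_keras_estimator(config=None):
    """Build a small Keras MNIST classifier and convert it to an Estimator."""
    # Local import keeps this sketch loadable without TensorFlow installed;
    # assumes TensorFlow 1.x.
    import tensorflow as tf

    # The input layer's name is the key expected from the input function.
    inputs = tf.keras.layers.Input(shape=(784,), name="input")
    hidden = tf.keras.layers.Dense(128, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(10, activation="softmax")(hidden)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    # Assumption: labels are one-hot vectors, hence categorical_crossentropy.
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])

    # config is an optional tf.estimator.RunConfig instance.
    return tf.keras.estimator.model_to_estimator(keras_model=model,
                                                 config=config)
```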

Tada! Wasn't that much simpler? Adding basic metrics like accuracy to a Keras model is easy. Keras also conveniently takes care of the different operations that happen for different modes.


You might be wondering what config is in the above code snippets. It is an instance of the RunConfig class, which describes the training-time settings (for example, where and how often to save checkpoints).
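A sketch of such a config, assuming TensorFlow 1.x. The names RUN_CONFIG_KWARGS and make_run_config are hypothetical; the values are picked to match the description that follows.

```python
# Settings kept in a plain dict so they are easy to inspect and reuse.
RUN_CONFIG_KWARGS = {
    "model_dir": "/logs/",          # checkpoints and summaries go here
    "save_summary_steps": 100,      # write TensorBoard summaries every 100 steps
    "save_checkpoints_steps": 100,  # save a checkpoint every 100 steps
    "keep_checkpoint_max": 2,       # keep only the 2 most recent checkpoints
    "log_step_count_steps": 10,     # log global_steps/s every 10 steps
}


def make_run_config():
    """Create the RunConfig that is passed to the Estimator as `config`."""
    # Local import keeps this sketch loadable without TensorFlow installed;
    # assumes TensorFlow 1.x.
    import tensorflow as tf

    return tf.estimator.RunConfig(**RUN_CONFIG_KWARGS)
```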

In the above example, checkpoints and TensorBoard summaries will be saved to the directory /logs/ every 100 steps. At most 2 checkpoints will be kept, and global_steps/s will be logged every 10 steps.

Input functions

Now that we have an Estimator, let's talk about the data.

This is where we define our input pipeline. There are many ways to do this. To keep this blog post simple, I will show you how to feed the data from a NumPy array using the numpy_input_fn function. We will dedicate another blog post to describe different input pipelines in the near future.
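A minimal sketch of the two input functions, assuming TensorFlow 1.x and MNIST images and labels already loaded as NumPy arrays. The helper name make_input_fns, its arguments, and the batch size are hypothetical; the key "input" is assumed to be the name of the model's input tensor.

```python
INPUT_KEY = "input"  # must match the name of the model's input tensor


def make_input_fns(train_images, train_labels, eval_images, eval_labels,
                   batch_size=128):
    """Return (train_input_fn, eval_input_fn) built from NumPy arrays."""
    # Local import keeps this sketch loadable without TensorFlow installed;
    # assumes TensorFlow 1.x.
    import tensorflow as tf

    # Training: shuffle and repeat indefinitely (num_epochs=None);
    # the number of training steps is limited elsewhere.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={INPUT_KEY: train_images},
        y=train_labels,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)

    # Evaluation: one ordered pass over the data, then end of input.
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={INPUT_KEY: eval_images},
        y=eval_labels,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)

    return train_input_fn, eval_input_fn
```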

One important thing to note here: make sure your keys for x are correct. You can have multiple input tensors in your graph, and you need to have an entry in the x dictionary for each of these input tensors. In order for TensorFlow to know which input goes where, make sure the key matches the name you defined in the graph/model. In the above example, I named my input tensor input, so I use input as the key.

Note: Since TensorFlow 1.7, this is not strictly required. If you only have one input tensor, then you can simply pass x=data.train.images instead of passing a dictionary.

Train and evaluate

Now that we have built our estimator and defined our input functions, all that's left is to run our training! We just wrap our input functions into TrainSpec and EvalSpec, then feed the three components into tf.estimator.train_and_evaluate:
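A sketch of this last step, assuming TensorFlow 1.x. The function name run_training and the constant names are hypothetical; the estimator and the two input functions are whatever you built in the previous steps.

```python
MAX_STEPS = 1000000        # stop training after this many global steps
EVAL_START_DELAY_SECS = 0  # allow evaluation as soon as training starts
EVAL_THROTTLE_SECS = 30    # evaluate at most once every 30 seconds


def run_training(estimator, train_input_fn, eval_input_fn):
    """Wrap the input functions into specs and start training."""
    # Local import keeps this sketch loadable without TensorFlow installed;
    # assumes TensorFlow 1.x.
    import tensorflow as tf

    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn,
        max_steps=MAX_STEPS)

    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=None,  # evaluate until eval_input_fn signals end of input
        start_delay_secs=EVAL_START_DELAY_SECS,
        throttle_secs=EVAL_THROTTLE_SECS)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```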

Here, TrainSpec and EvalSpec define what the training and evaluation processes look like. In the above example, the Estimator will train using train_input_fn for a maximum of 1,000,000 steps. Evaluation can start as soon as training starts and then runs at most once every 30 seconds, whenever a new checkpoint is available (that is, when a checkpoint is saved at least 30 seconds after the last evaluation). Because steps=None, each evaluation keeps calling eval_input_fn until it raises an end-of-input exception. If evaluation takes a long time, you may want throttle_secs to be much longer than the interval at which checkpoints are saved.

That's it! We did not have to set devices, use MonitoredTrainingSession, set up validation runs, etc. This code will run in both single and distributed mode. By default, this will run an evaluation step at a regular interval (either by number of steps or time) during the training, or you can customize this using Hooks.

Run in Distributed Mode

In order to run this code in distributed mode, we need to tell each node how to find the other machines in the cluster and what it is supposed to do. The Estimator class makes this pretty simple: all we need is an environment variable called TF_CONFIG that contains the relevant information (about the cluster and the current node's identity) in JSON format. Below is an example:
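A sketch of what TF_CONFIG could look like for a small cluster, built in Python so the JSON stays valid. The hostnames, ports, and the helper make_tf_config are hypothetical; every node shares the same cluster section and differs only in its task entry.

```python
import json
import os

# Hypothetical cluster: one chief, two workers, one parameter server.
CLUSTER = {
    "chief": ["host0.example.com:2222"],
    "worker": ["host1.example.com:2222", "host2.example.com:2222"],
    "ps": ["host3.example.com:2222"],
}


def make_tf_config(task_type, task_index):
    """Return the TF_CONFIG JSON string for one node of the cluster."""
    return json.dumps({
        "cluster": CLUSTER,                                # same on every node
        "task": {"type": task_type, "index": task_index},  # this node's role
    })


# On the first worker machine, set the variable before creating the
# RunConfig/Estimator, since that is when TF_CONFIG is read.
os.environ["TF_CONFIG"] = make_tf_config("worker", 0)
```

The second worker would use make_tf_config("worker", 1), the parameter server make_tf_config("ps", 0), and so on.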

Last note: setting up a distributed environment is not trivial. With Clusterone, this is as simple as clicking a button, and it's basically free. You only pay for your AWS instance costs. TensorBoard is built into the platform, so all you need to do is click a button to visualize your results.

If you're using an Estimator class, some summaries will be saved for you automatically (such as loss, global steps per second, etc.). Below is an example of this code trained on C4.xlarge instances on AWS (the x-axis is time in hours). You can clearly see the speed-up in distributed learning.

The graph below shows the global_steps/s for three different configurations: a single C4.xlarge instance (blue line), a distributed configuration with 2 C4.xlarge instances (purple line), and a distributed configuration with 4 C4.xlarge instances (green line). As can be seen, the number of steps per second increases linearly with the number of instances.

[Figure: global_steps/s for distributed TensorFlow MNIST training]


Below is the code you can run both on your local machine and on the Clusterone platform in asynchronous distributed mode. You can also find this code in our GitHub repository, including instructions on how to run it on Clusterone.