How to export a TensorFlow model?

Posted on 06 Aug 2017

The TensorFlow Serving documentation on exporting a SavedModel is not very clear, so this article gives a complete code snippet.
Suppose we already have a trained model that handles a two-class prediction problem.

import os
import logging

import tensorflow as tf
from tensorflow.python.saved_model import builder, tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
from tensorflow.python.saved_model.utils import build_tensor_info


class MyModel(object):
    """Minimal linear softmax classifier used to demonstrate SavedModel export.

    The graph maps a batch of 10-feature inputs to a predicted class index
    (2 classes). Variables are created with zero values and initialized in
    the given session; nothing is restored from a checkpoint here.
    """

    def __init__(self, sess):
        """Build the inference graph and initialize its variables.

        Args:
            sess: A ``tf.Session`` whose graph is the current default graph;
                used to run the variable initializer.
        """
        self._build_inference_graph()
        sess.run(tf.global_variables_initializer())

    def _build_inference_graph(self):
        """Create the input placeholder, the weights and the prediction op."""
        # Leading dimension is None so the exported signature accepts any
        # batch size (a batch of one still works as before).
        x = tf.placeholder(tf.float32, [None, 10])  # each example has 10 features
        W = tf.Variable(tf.zeros([10, 2]))
        b = tf.Variable(tf.zeros([2]))
        logits = tf.matmul(x, W) + b
        # Reduce over the class axis (1), not the batch axis: tf.argmax
        # defaults to axis 0, which would return one index per class instead
        # of one predicted class per example.
        prediction = tf.argmax(tf.nn.softmax(logits), axis=1)
        self._x = x
        self._prediction = prediction

    @property
    def x(self):
        """Input placeholder of shape ``[batch, 10]``, dtype float32."""
        return self._x

    @property
    def prediction(self):
        """Predicted class index for each example, shape ``[batch]``."""
        return self._prediction

SAVED_MODEL_PATH = "Base directory to export the model to"
# TensorFlow Serving only loads version sub-directories whose names are
# integers, and serves the highest one it finds.
VERSION = "1"


def export_model(saved_model_path=SAVED_MODEL_PATH, version=VERSION):
    """Build the model graph and export it as a SavedModel for TF Serving.

    Args:
        saved_model_path: Base export directory.
        version: Version sub-directory name (int or str). It must be an
            integer for TensorFlow Serving to load it.

    Returns:
        The directory the SavedModel was written to.
    """
    export_path = os.path.join(saved_model_path, str(version))
    with tf.Graph().as_default() as graph:
        # No explicit tf.device("/gpu:0") pin: it makes the export fail on
        # hosts without a GPU and bakes that placement into the saved graph.
        with tf.Session(graph=graph) as session:
            my_model = MyModel(sess=session)
            # NOTE: MyModel only initializes its variables (to zeros). Restore
            # your trained weights into `session` here before exporting, e.g.
            #     tf.train.Saver().restore(session, checkpoint_path)
            # otherwise the exported model contains untrained values.

            logging.info("Exporting Model to %s", export_path)
            model_builder = builder.SavedModelBuilder(export_path)
            signature_predict = build_signature_def(
                inputs={'x': build_tensor_info(my_model.x)},
                outputs={'prediction': build_tensor_info(my_model.prediction)},
                method_name=signature_constants.PREDICT_METHOD_NAME,
            )
            # tf.initialize_all_tables() is deprecated; tf.tables_initializer()
            # is its replacement.
            legacy_init_op = tf.group(tf.tables_initializer(),
                                      name='legacy_init_op')
            model_builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map={
                    'predict_class': signature_predict,
                    # Also register the same signature under the default key,
                    # which clients use when they don't name a signature.
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        signature_predict,
                },
                legacy_init_op=legacy_init_op)
            model_builder.save()
    return export_path


if __name__ == "__main__":
    export_model()

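As the code above shows, we have to specify two things:

  • SAVED_MODEL_PATH: the base directory the SavedModel is exported to
  • VERSION: the version number of this export. It becomes a numeric subdirectory of SAVED_MODEL_PATH, and TensorFlow Serving serves the highest version it finds.

Note that MyModel only initializes its variables. Restore your trained weights into the session (for example with tf.train.Saver) before exporting; otherwise the exported model contains untrained values.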