Error while using Tensorflow-Hub and Colab TPU

I am trying to use BERT for text classification using Tensorflow hub. The code runs fine on Colab GPU but when I converted it for Colab TPU it shows up the following 'uninitialized layer' error. Following is the Bert Layer-

class BertLayer(tf.keras.layers.Layer):
          def __init__(self, n_fine_tune_layers, **kwargs):
              self.n_fine_tune_layers = n_fine_tune_layers
              self.trainable = True
              self.output_size = 768
              super(BertLayer, self).__init__(**kwargs)

          def build(self, input_shape):
              self.bert = hub.Module(
                  bert_path,
                  trainable=True,# did this in place of self.trainable
                  name="{}_module".format(self.name)
              )

              trainable_vars = self.bert.variables


              trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name]
              #print("--------------------------len=",len(trainable_vars))
              # Select how many layers to fine tune
              trainable_vars = trainable_vars[-self.n_fine_tune_layers:]

              # Add to trainable weights
              for var in trainable_vars:
                  self._trainable_weights.append(var)

              for var in self.bert.variables:
                  if var not in self._trainable_weights:
                      self._non_trainable_weights.append(var)

              super(BertLayer, self).build(input_shape)

          def call(self, inputs):
              inputs = [K.cast(x, dtype="int32") for x in inputs]
              input_ids, input_mask, segment_ids = inputs
              bert_inputs = dict(
                  input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
              )
              result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
                  "pooled_output"
              ]
              return result

          def compute_output_shape(self, input_shape):
              return (input_shape[0], self.output_size)

Following is my model -

print("-----------------------------1")
      from tensorflow.keras.layers import Input,Dense
      #from tensorflow.keras.callbacks import TensorBoard
      #Name="test run-{}".format(1)
      #tensorboard=TensorBoard(log_dir="logs/{}".format(Name))
      in_id=Input(shape=(max_seq_length,),)
      in_mask=Input(shape=(max_seq_length,),)
      in_segment=Input(shape=(max_seq_length,),)
      print("-----------------------------2")
      in_id = tf.keras.layers.Input(shape=(max_seq_length,))
      in_mask = tf.keras.layers.Input(shape=(max_seq_length,))
      in_segment = tf.keras.layers.Input(shape=(max_seq_length,))
      print("-----------------------------3")
      bert_inputs=[in_id,in_mask,in_segment]
      bert_outputs=BertLayer(n_fine_tune_layers=100)(bert_inputs)
      #step=bert_outpu
      step=tf.keras.layers.Dropout(rate=0.1)(bert_outputs)
      step=tf.keras.layers.Dense(512,activation='relu',kernel_initializer='glorot_normal')(step)
      #step=tf.keras.layers.Dense(256,activation='relu')(step)
      step=tf.keras.layers.Dropout(rate=dropout)(step)
      step=tf.keras.layers.Dense(256,activation='relu',kernel_initializer='glorot_normal')(step)
      #dense=tf.keras.layers.Dense(256,activation='relu')(bert_outputs)
      step=tf.keras.layers.Dropout(rate=dropout)(step)
      pred=tf.keras.layers.Dense(1,activation='sigmoid')(step)
      model=tf.keras.Model(inputs=bert_inputs,outputs=pred)
      print("-----------------------------4")
      model.compile(loss='binary_crossentropy',
                    optimizer=tf.train.AdamOptimizer(lr),
                    metrics=[f1,'accuracy'])
      print("-----------------------------5")
      tpu_model = tf.contrib.tpu.keras_to_tpu_model(model,
                                                    strategy=tf.contrib.tpu.TPUDistributionStrategy(tf.contrib.cluster_resolver.TPUClusterResolver(tpu_address)))

      print("-----------------------------6")
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())
      sess.run(tf.tables_initializer())
      K.set_session(sess)
      print("-----------------------------7")
      tpu_model.fit([train_input_ids, train_input_masks, train_segment_ids],
                train_labels,
                epochs=epochs,
                batch_size=64)
                #validation_data=([val_input_ids, val_input_masks, val_segment_ids],val_labels))
      model.evaluate([test_input_ids, test_input_masks, test_segment_ids],test_labels)

And follwing is the Error-

-----------------------------1
-----------------------------2
-----------------------------3
INFO:tensorflow:Saver not created because there are no variables in the graph to restore

I0605 10:36:17.033424 140062933383040 saver.py:1483] Saver not created because there are no variables in the graph to restore

-----------------------------4
-----------------------------5
INFO:tensorflow:Querying Tensorflow master (grpc://10.85.103.202:8470) for TPU system metadata.

I0605 10:36:17.748405 140062933383040 tpu_system_metadata.py:59] Querying Tensorflow master (grpc://10.85.103.202:8470) for TPU system metadata.

INFO:tensorflow:Found TPU system:

I0605 10:36:17.768394 140062933383040 tpu_system_metadata.py:120] Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 8

I0605 10:36:17.770817 140062933383040 tpu_system_metadata.py:121] *** Num TPU Cores: 8

INFO:tensorflow:*** Num TPU Workers: 1

I0605 10:36:17.773086 140062933383040 tpu_system_metadata.py:122] *** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 8

I0605 10:36:17.775260 140062933383040 tpu_system_metadata.py:124] *** Num TPU Cores Per Worker: 8

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 8086897810259541316)

I0605 10:36:17.779561 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 8086897810259541316)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 8309407237260141527)

I0605 10:36:17.782429 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 8309407237260141527)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 7140089854169573112)

I0605 10:36:17.785550 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 7140089854169573112)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 17762152438583970348)

I0605 10:36:17.789351 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 17762152438583970348)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 12631201787268957494)

I0605 10:36:17.793601 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 12631201787268957494)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8708359633115695081)

I0605 10:36:17.796261 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8708359633115695081)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 601478800410838022)

I0605 10:36:17.800481 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 601478800410838022)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 16793071921697081555)

I0605 10:36:17.804739 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 16793071921697081555)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 16730824918382181321)

I0605 10:36:17.807698 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 16730824918382181321)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 11133990522845180639)

I0605 10:36:17.810022 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 11133990522845180639)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 18001585464951191022)

I0605 10:36:17.812952 140062933383040 tpu_system_metadata.py:126] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 18001585464951191022)

WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.

W0605 10:36:17.816158 140062933383040 experimental.py:63] tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.

---------------------------------------------------------------------------

InvalidArgumentError                      Traceback (most recent call last)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1333     try:
-> 1334       return fn(*args)
   1335     except errors.OpError as e:

10 frames

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1318       return self._call_tf_sessionrun(
-> 1319           options, feed_dict, fetch_list, target_list, run_metadata)
   1320 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1406         self._session, options, feed_dict, fetch_list, target_list,
-> 1407         run_metadata)
   1408 

InvalidArgumentError: In ReadVariableOp the following variables were found uninitialized: bert_layer_12_module/bert/embeddings/LayerNorm/beta, bert_layer_12_module/bert/embeddings/LayerNorm/gamma, bert_layer_12_module/bert/embeddings/position_embeddings, bert_layer_12_module/bert/embeddings/token_type_embeddings, bert_layer_12_module/bert/embeddings/word_embeddings, bert_layer_12_module/bert/encoder/layer_0/attention/output/LayerNorm/beta, bert_layer_12_module/bert/encoder/layer_0/attention/output/LayerNorm/gamma, bert_layer_12_module/bert/encoder/layer_0/attention/output/dense/bias, bert_layer_12_module/bert/encoder/layer_0/attention/output/dense/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/key/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/key/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/query/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/query/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/value/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/value/kernel, bert_layer_12_module/bert/encoder/layer_0/intermediate/dense/bias, bert_layer_12_module/bert/encoder/layer_0/intermediate/dense/kernel, bert_layer_12_module/bert/encoder/layer_0/output/LayerNorm/beta, bert_layer_12_module/bert/encoder/layer_0/output/LayerNorm/gamma, bert_layer_12_module/bert/encoder/layer_0/output/dense/bias, bert_layer_12_module/bert/encoder/layer_0/output/dense/kernel, bert_layer_12_m...
     [[{{node ReadVariables_14728137872467799544/_1}}]]


During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-40-fdb59f59a0ef> in <module>()
     82       print("-----------------------------5")
     83       tpu_model = tf.contrib.tpu.keras_to_tpu_model(model,
---> 84                                                     strategy=tf.contrib.tpu.TPUDistributionStrategy(tf.contrib.cluster_resolver.TPUClusterResolver(tpu_address)))
     85 
     86       print("-----------------------------6")

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/framework/python/framework/experimental.py in new_func(*args, **kwargs)
     62         'any time, and without warning.',
     63         decorator_utils.get_qualified_name(func), func.__module__)
---> 64     return func(*args, **kwargs)
     65   new_func.__doc__ = _add_experimental_function_notice_to_docstring(
     66       func.__doc__)

/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in tpu_model(model, strategy)
   2219     else:
   2220       optimizer_config = None
-> 2221     model_weights = model.get_weights()
   2222   else:
   2223     model_weights = None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in get_weights(self)
    390     for layer in self.layers:
    391       weights += layer.weights
--> 392     return backend.batch_get_value(weights)
    393 
    394   def set_weights(self, weights):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/backend.py in batch_get_value(tensors)
   2817     raise RuntimeError('Cannot get value inside Tensorflow graph function.')
   2818   if tensors:
-> 2819     return get_session().run(tensors)
   2820   else:
   2821     return []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    927     try:
    928       result = self._run(None, fetches, feed_dict, options_ptr,
--> 929                          run_metadata_ptr)
    930       if run_metadata:
    931         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1150     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1151       results = self._do_run(handle, final_targets, final_fetches,
-> 1152                              feed_dict_tensor, options, run_metadata)
   1153     else:
   1154       results = []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1326     if handle is None:
   1327       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1328                            run_metadata)
   1329     else:
   1330       return self._do_call(_prun_fn, handle, feeds, fetches)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1346           pass
   1347       message = error_interpolation.interpolate(message, self._graph)
-> 1348       raise type(e)(node_def, op, message)
   1349 
   1350   def _extend_graph(self):

InvalidArgumentError: In ReadVariableOp the following variables were found uninitialized: bert_layer_12_module/bert/embeddings/LayerNorm/beta, bert_layer_12_module/bert/embeddings/LayerNorm/gamma, bert_layer_12_module/bert/embeddings/position_embeddings, bert_layer_12_module/bert/embeddings/token_type_embeddings, bert_layer_12_module/bert/embeddings/word_embeddings, bert_layer_12_module/bert/encoder/layer_0/attention/output/LayerNorm/beta, bert_layer_12_module/bert/encoder/layer_0/attention/output/LayerNorm/gamma, bert_layer_12_module/bert/encoder/layer_0/attention/output/dense/bias, bert_layer_12_module/bert/encoder/layer_0/attention/output/dense/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/key/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/key/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/query/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/query/kernel, bert_layer_12_module/bert/encoder/layer_0/attention/self/value/bias, bert_layer_12_module/bert/encoder/layer_0/attention/self/value/kernel, bert_layer_12_module/bert/encoder/layer_0/intermediate/dense/bias, bert_layer_12_module/bert/encoder/layer_0/intermediate/dense/kernel, bert_layer_12_module/bert/encoder/layer_0/output/LayerNorm/beta, bert_layer_12_module/bert/encoder/layer_0/output/LayerNorm/gamma, bert_layer_12_module/bert/encoder/layer_0/output/dense/bias, bert_layer_12_module/bert/encoder/layer_0/output/dense/kernel, bert_layer_12_m...
     [[{{node ReadVariables_14728137872467799544/_1}}]]

Please help me resolve this error.