Difference between Variable and get_variable in TensorFlow

As far as I know, Variable is the default operation for making a variable, and get_variable is mainly used for weight sharing.

On the one hand, some people suggest using get_variable instead of the primitive Variable operation whenever you need a variable. On the other hand, I rarely see get_variable used in TensorFlow’s official documentation and demos.

So I’d like to know some rules of thumb for using these two mechanisms correctly. Are there any “standard” principles?

Answers:

Thank you for visiting the Q&A section on Magenaut.

Method 1

I’d recommend always using tf.get_variable(...): it makes it much easier to refactor your code if you ever need to share variables, e.g. in a multi-GPU setting (see the multi-GPU CIFAR example). There is no downside to it.

Plain tf.Variable is lower-level; tf.get_variable() did not exist in early versions of TensorFlow, so some code still uses the low-level way.
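To make the refactoring argument concrete, here is a pure-Python toy model (no TensorFlow needed) of the get-or-create behavior behind tf.get_variable. ToyVariableStore, build_tower and the two-tower setup are hypothetical stand-ins, assuming TF 1.x semantics where requesting an existing name fails unless reuse is enabled. Once a model fetches its weights by name, sharing them across GPU towers only takes a reuse flag.

```python
import random


class ToyVariableStore:
    """Hypothetical stand-in for TensorFlow's variable store (not the real class).

    Mimics the TF 1.x get_variable rules: requesting an existing name fails
    unless reuse=True, and reuse=True fails if the name does not exist yet.
    """

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, shape, reuse=False):
        if name in self._vars:
            if not reuse:
                raise ValueError(f"Variable {name} already exists")
            return self._vars[name]
        if reuse:
            raise ValueError(f"Variable {name} does not exist")
        # A flat list of random floats stands in for a tensor of `shape`.
        size = 1
        for dim in shape:
            size *= dim
        self._vars[name] = [random.uniform(-0.1, 0.1) for _ in range(size)]
        return self._vars[name]


def build_tower(store, reuse):
    """Hypothetical model tower: it asks for weights by name instead of creating them."""
    w1 = store.get_variable("model/w1", [4, 3], reuse=reuse)
    w2 = store.get_variable("model/w2", [3, 2], reuse=reuse)
    return w1, w2


store = ToyVariableStore()
# One tower per (pretend) GPU; every tower after the first reuses the weights.
towers = [build_tower(store, reuse=(i > 0)) for i in range(2)]

print(towers[0][0] is towers[1][0], towers[0][1] is towers[1][1])  # True True
print(sorted(store._vars))  # ['model/w1', 'model/w2']
```

With tf.Variable, calling the tower-building code twice would instead give each tower its own unshared copy of every weight, and sharing would mean passing the variable objects through every function by hand.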

Method 2

tf.Variable is a class, and there are several ways to create a tf.Variable, including tf.Variable.__init__ and tf.get_variable.

tf.Variable.__init__: Creates a new variable with initial_value.

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Gets an existing variable with these parameters or creates a new one. You can also pass an initializer.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Initializers such as xavier_initializer are very handy here:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())
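For reference, xavier_initializer (Glorot initialization) is uniform by default: each entry is drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). For the [784, 256] weight above, fan_in = 784 and fan_out = 256, so limit = sqrt(6 / 1040) ≈ 0.076. The pure-Python sketch below reproduces only that formula for 2-D shapes; the function name glorot_uniform and the use of Python’s random module are illustrative choices, not TensorFlow’s implementation.

```python
import math
import random


def glorot_uniform(shape, rng=random):
    """Sample a 2-D weight matrix with the Glorot (Xavier) uniform rule.

    limit = sqrt(6 / (fan_in + fan_out)); every entry is drawn from
    U(-limit, limit). Only 2-D shapes [fan_in, fan_out] are handled here.
    """
    fan_in, fan_out = shape
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    weights = [[rng.uniform(-limit, limit) for _ in range(fan_out)]
               for _ in range(fan_in)]
    return weights, limit


W, limit = glorot_uniform([784, 256], rng=random.Random(0))
print(len(W), len(W[0]))  # 784 256
print(round(limit, 4))    # 0.076
```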

More information here.

Method 3

I see two main differences between them:

  1. tf.Variable always creates a new variable, whereas tf.get_variable fetches an existing variable with the specified parameters from the graph and creates a new one only if it doesn’t exist.
  2. tf.Variable requires an initial value, whereas tf.get_variable can build one from a shape and an initializer.

It is important to clarify that the function tf.get_variable prefixes the name with the current variable scope to perform reuse checks. For example:

import tensorflow as tf  # TensorFlow 1.x graph-mode API
# Under TensorFlow 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1])  # a.name == "one/v:0"
with tf.variable_scope("one"):
    try:
        b = tf.get_variable("v", [1])
    except ValueError as err:
        print(err)  # Variable one/v already exists, ...
with tf.variable_scope("one", reuse=True):
    c = tf.get_variable("v", [1])  # c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1])  # d.name == "two/v:0"
    e = tf.Variable(tf.zeros([1]), name="v")  # e.name == "two/v_1:0"

assert a is c      # passes: they refer to the same object
assert a is not d  # passes: they are different objects
assert d is not e  # passes: different objects, despite the same requested name

The last assertion is the interesting one: you might expect two variables with the same name under the same scope to be the same variable. But if you check the names of variables d and e, you will see that TensorFlow renamed variable e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"
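The naming rules in this example can be mimicked without TensorFlow. The sketch below is a hypothetical, heavily simplified model (ToyScopedGraph, variable_scope, get_variable and make_variable are invented names, not TensorFlow APIs): get_variable looks the scoped name up in a store and enforces the reuse check, while make_variable, standing in for tf.Variable, never consults the store and, like the graph’s unique-name logic, appends _1, _2, ... when the requested name is already taken.

```python
from contextlib import contextmanager


class ToyVariable:
    def __init__(self, full_name):
        self.name = full_name + ":0"  # TF appends the output index ":0"


class ToyScopedGraph:
    """Hypothetical, simplified model of TF 1.x variable naming.

    It ignores dtypes, shapes and name-scope uniquification; it only models
    scope prefixes, get_variable's reuse check, and unique op names.
    """

    def __init__(self):
        self._scopes = []   # stack of (scope_name, reuse) pairs
        self._store = {}    # get_variable's store: full name -> variable
        self._used = set()  # every op name already taken in the graph

    @contextmanager
    def variable_scope(self, name, reuse=False):
        self._scopes.append((name, reuse))
        try:
            yield
        finally:
            self._scopes.pop()

    def _full_name(self, name):
        return "/".join([s for s, _ in self._scopes] + [name])

    def _unique_name(self, full):
        # Like the graph's unique-name logic: v, then v_1, v_2, ...
        candidate, i = full, 0
        while candidate in self._used:
            i += 1
            candidate = f"{full}_{i}"
        self._used.add(candidate)
        return candidate

    def get_variable(self, name):
        full = self._full_name(name)
        reuse = any(r for _, r in self._scopes)  # reuse is inherited by inner scopes
        if full in self._store:
            if not reuse:
                raise ValueError(f"Variable {full} already exists")
            return self._store[full]
        if reuse:
            raise ValueError(f"Variable {full} does not exist")
        self._used.add(full)
        self._store[full] = ToyVariable(full)
        return self._store[full]

    def make_variable(self, name):
        # Stand-in for tf.Variable: never consults the store, always creates.
        return ToyVariable(self._unique_name(self._full_name(name)))


g = ToyScopedGraph()
with g.variable_scope("one"):
    a = g.get_variable("v")
try:
    with g.variable_scope("one"):
        g.get_variable("v")
except ValueError as err:
    print(err)  # Variable one/v already exists
with g.variable_scope("one", reuse=True):
    c = g.get_variable("v")
with g.variable_scope("two"):
    d = g.get_variable("v")
    e = g.make_variable("v")

print(a is c, a is d, d is e)  # True False False
print(d.name, e.name)          # two/v:0 two/v_1:0
```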

Method 4

Another difference is that tf.get_variable records its variables in a variable store object kept in the ('__variable_store',) graph collection, while tf.Variable does not.

Please see the source code:

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Let me illustrate that:

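import tensorflow as tf  # TensorFlow 1.x graph-mode API
# Under TensorFlow 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32)
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

# Print every collection in the default graph; the variable store object
# exposes the variables it holds through its private _vars dict.
for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except AttributeError:
            print('\t%d: %s' % (k, str(store)))
    print('')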

The output:

collection ('__variable_store',):
    0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}


All methods were sourced from stackoverflow.com or stackexchange.com and are licensed under cc by-sa 2.5, cc by-sa 3.0 and cc by-sa 4.0.
