提问人:Paulo Pacheco 提问时间:11/19/2019 最后编辑:Paulo Pacheco 更新时间:11/19/2019 访问量:95
递归分配给 Tensorflow 中的变量切片
Recursively Assign to Variable Slices in Tensorflow
问:
我想以递归方式为 Tensorflow (1.15) 变量中的切片赋值。
举例来说,这是有效的:
def test_loss():
m = tf.Variable(1)
n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
Out: 1
然后我试了一下:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
但这会返回一条错误消息:
---> 10 A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
...
ValueError: Sliced assignment is only supported for variables
我知道“assign”返回的不是“变量”,因此在下一个循环中传递“A”将 不再找到“变量”。
然后我试了一下:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
return 1
test_loss()
然后我得到了:
InvalidArgumentError: Input 'ref' passed float expected ref type while building NodeDef...
关于我可以递归地为 Tensorflow 变量切片赋值的任何想法?
答:
0赞
thushv89
11/19/2019
#1
以下是使用 和 的一些见解。tf.Variable
assign()
第一个失败的解决方案
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
当你这样做时,它实际上返回一个张量(即不是 )。因此,它适用于第一次迭代。从下一次迭代开始,您将尝试将值分配给 ,这是不允许的。A.assign(B)
tf.Variable
tf.Tensor
第二个失败的解决方案
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
这又是一个非常糟糕的主意,因为你是在循环中创建变量。这样做足够多,你就会耗尽内存。但这甚至不会运行,因为你最终陷入了一个时髦的僵局。您正在尝试创建具有某些张量的变量,这些张量将在 Graph 执行时计算。要执行图形,您需要变量。
正确的方法
我能想到的最好的方法是返回更新操作并创建一个 TensorFlow 占位符。在运行会话的每次迭代中,您都会传递一个值(即当前迭代)。test_loss
n
n
def test_loss(n):
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
update = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return update
with tf.Session() as sess:
tf_n = tf.placeholder(shape=None, dtype=tf.int32, name='n')
update_op = test_loss(tf_n)
print(type(update_op))
tf.global_variables_initializer().run()
for n in range(5):
print(1)
#print(sess.run(update_op, feed_dict={tf_n: n}))
评论
0赞
Paulo Pacheco
12/6/2019
感谢您的见解。在我看来,在这种情况下,使用 PyTorch 是一个更好的范式(更“pythonic”)。
0赞
thushv89
12/6/2019
我同意,如果您有动态操作,则可以使用 Pytorch。但你也可以试试TF2。他们转向了急切执行,这意味着张量会立即执行。
评论