tf.gather_nd多次使用真的很慢

tf.gather_nd is really slow when used for many times

提问人:ZHANG Juenjie 提问时间:7/2/2018 最后编辑:MichaelZHANG Juenjie 更新时间:8/1/2018 访问量:645

问:

我想要一个 tensorflow 中的损失函数,它是许多元素的复杂组合。例如,以下代码:

import tensorflow as tf
import numpy as np
import time

input_layer = tf.placeholder(tf.float64, shape=[64,4])
output_layer = input_layer + 0.5*tf.tanh(tf.Variable(tf.random_uniform(shape=[64,4],\
                                                       minval=-1,maxval=1,dtype=tf.float64)))

# random_combination is 2-d numpy array of the form:
# [[32, 34, 23, 56],[23,54,33,21],...]
random_combination = np.random.randint(64, size=(210000000, 4))

# a collector to collect the values 
collector=[]

print('start looping')   
print(time.asctime(time.localtime(time.time())))

# loop through random_combination and pick the elements of output_layer
for i in range(len(random_combination)):
    [i,j,k,l] = [random_combination[i][0],random_combination[i][1],\
                 random_combination[i][2],random_combination[i][3]]

    # pick the needed element from output_layer
    f1 = tf.gather_nd(output_layer,[i,0])
    f2 = tf.gather_nd(output_layer,[i,2])
    f3 = tf.gather_nd(output_layer,[i,3])
    f4 = tf.gather_nd(output_layer,[i,4])

    tf1 = f1+1
    tf2 = f2+1
    tf3 = f3+1
    tf4 = f4+1
    collector.append(0.3*tf.abs(f1*f2*tf3*tf4-tf1*tf2*f3*f4))

print('end looping')   
print(time.asctime(time.localtime(time.time())))

# loss function
loss = tf.add_n(collector)

这在我的电脑上大约需要 50 分钟。 我的问题是,这是在 tensorflow 中进行编码的正确方法吗? 或者有一种更省时的方法来索引元素?

TensorFlow 切片

评论

0赞 tucan9389 8/1/2021
我有同样的问题。

答: 暂无答案