0
我有一个函数dice
Tensorflow - 应用功能在1D张量
def dice(yPred,yTruth,thresh):
smooth = tf.constant(1.0)
threshold = tf.constant(thresh)
yPredThresh = tf.to_float(tf.greater_equal(yPred,threshold))
mul = tf.mul(yPredThresh,yTruth)
intersection = 2*tf.reduce_sum(mul) + smooth
union = tf.reduce_sum(yPredThresh) + tf.reduce_sum(yTruth) + smooth
dice = intersection/union
return dice, yPredThresh
其中工程。一个例子是这里
with tf.Session() as sess:
thresh = 0.5
print("Dice example")
yPred = tf.constant([0.1,0.9,0.7,0.3,0.1,0.1,0.9,0.9,0.1],shape=[3,3])
yTruth = tf.constant([0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0],shape=[3,3])
diceScore, yPredThresh= dice(yPred=yPred,yTruth=yTruth,thresh= thresh)
diceScore_ , yPredThresh_ , yPred_, yTruth_ = sess.run([diceScore,yPredThresh,yPred, yTruth])
print("\nScore = {0}".format(diceScore_))
>>> Score = 0.899999976158
我希望能够遍历骰子的第三arguement给出脱粒。我不知道这样做的最佳方式,以便我可以从图中提取它。因为我不能环路成,显然分割一个TF张量大致如下的东西线...
def diceROC(yPred,yTruth,thresholds=np.linspace(0.1,0.9,20)):
thresholds = thresholds.astype(np.float32)
nThreshs = thresholds.size
diceScores = tf.zeros(shape=nThreshs)
for i in xrange(nThreshs):
score,_ = dice(yPred,yTruth,thresholds[i])
diceScores[i] = score
return diceScores
评估diceScoreROC
产生错误'Tensor' object does not support item assignment
。
真棒,很好的答案谢谢 – mattdns