2017-10-11

Define torch.cuda.LongTensor instead of torch.LongTensor

When I run the code below, the variable's type is torch.LongTensor. How can I create a torch.cuda.LongTensor instead?

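import string
import torch
from torch.autograd import Variable

all_characters = string.printable  # assumed; not shown in the question

# Turn a string into a LongTensor of character indices
def char_tensor(text):
    tensor = torch.zeros(len(text)).long()
    for c in range(len(text)):
        tensor[c] = all_characters.index(text[c])
    return Variable(tensor)

print(char_tensor('abcDEF'))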

Output:

Variable containing: 
10 
11 
12 
39 
40 
41 
[torch.LongTensor of size 6] 
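The question never shows how all_characters is defined. The printed indices match string.printable (digits 0-9, then lowercase letters at 10-35, then uppercase letters at 36-61), so the following sketch assumes that definition. It reproduces only the index mapping in plain Python, without any tensors. The helper name char_indices is hypothetical.

```python
import string

# Assumption: all_characters is string.printable, which matches the output above.
all_characters = string.printable

def char_indices(s):
    """Hypothetical helper: map each character of s to its position in all_characters."""
    return [all_characters.index(ch) for ch in s]

print(char_indices("abcDEF"))  # [10, 11, 12, 39, 40, 41]
```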

Answer


Correct answer: call .cuda() on the returned Variable to copy it to the GPU:

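import torch
from torch.autograd import Variable

# Turn a string into a LongTensor of character indices
def char_tensor(text):
    tensor = torch.zeros(len(text)).long()
    for c in range(len(text)):
        tensor[c] = all_characters.index(text[c])
    # .cuda() copies the result to the current GPU, so its type becomes
    # torch.cuda.LongTensor (this requires a CUDA-capable GPU).
    return Variable(tensor).cuda()

print(char_tensor('abcDEF'))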
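The answer above uses the pre-0.4 PyTorch API. From PyTorch 0.4 on, Variable is merged into Tensor, and a tensor can be allocated directly on a chosen device instead of being built on the CPU and then copied with .cuda(). The following is a minimal sketch under two assumptions: PyTorch 0.4 or later, and all_characters = string.printable (not shown in the question). It falls back to the CPU when no GPU is present, so on a CPU-only machine the type stays torch.LongTensor.

```python
import string

import torch

# Assumption: all_characters is string.printable (matches the question's output).
all_characters = string.printable

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def char_tensor(text, device=device):
    """Return a 1-D int64 tensor of character indices allocated on `device`."""
    indices = [all_characters.index(ch) for ch in text]
    # Creating the tensor directly on the target device avoids building a
    # CPU tensor first and copying it over with .cuda().
    return torch.tensor(indices, dtype=torch.long, device=device)

t = char_tensor("abcDEF")
print(t)
print(t.type())  # 'torch.cuda.LongTensor' on a GPU machine, 'torch.LongTensor' otherwise
```

Passing device as a parameter keeps the function usable on machines without a GPU, which the plain .cuda() call in the answer does not.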