pytorch 常用函数 max ,eq说明

这篇文章主要介绍了pytorch 常用函数 max eq说明,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

max找出tensor 的行或者列最大的值:

找出每行的最大值:

 import torch outputs=torch.FloatTensor([[1],[2],[3]]) print(torch.max(outputs.data,1))

输出:

(tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))

找出每列的最大值:

 import torch outputs=torch.FloatTensor([[1],[2],[3]]) print(torch.max(outputs.data,0))

输出结果:

(tensor([ 3.]), tensor([ 2]))

Tensor比较eq相等:

 import torch outputs=torch.FloatTensor([[1],[2],[3]]) targets=torch.FloatTensor([[0],[2],[3]]) print(targets.eq(outputs.data)) 

输出结果:

 tensor([[ 0], [ 1], [ 1]], dtype=torch.uint8)

使用sum() 统计相等的个数:

 import torch outputs=torch.FloatTensor([[1],[2],[3]]) targets=torch.FloatTensor([[0],[2],[3]]) print(targets.eq(outputs.data).cpu().sum()) 

输出结果:

tensor(2)

补充知识:PyTorch - torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le

flyfish

torch.eq、torch.ne、torch.gt、torch.lt、torch.ge、torch.le

以上全是简写

参数是input, other, out=None

逐元素比较input和other

返回是torch.BoolTensor

 import torch a=torch.tensor([[1, 2], [3, 4]]) b=torch.tensor([[1, 2], [4, 3]]) print(torch.eq(a,b))#equals # tensor([[ True, True], #     [False, False]]) print(torch.ne(a,b))#not equal to # tensor([[False, False], #     [ True, True]]) print(torch.gt(a,b))#greater than # tensor([[False, False], #     [False, True]]) print(torch.lt(a,b))#less than # tensor([[False, False], #     [ True, False]]) print(torch.ge(a,b))#greater than or equal to # tensor([[ True, True], #     [False, True]]) print(torch.le(a,b))#less than or equal to # tensor([[ True, True], #     [ True, False]]) 

以上就是pytorch 常用函数 max ,eq说明的详细内容,更多请关注0133技术站其它相关文章!

赞(0) 打赏
未经允许不得转载:0133技术站首页 » python