What are the pitfalls in PyTorch?

1 Details of torchvision.transforms.ToTensor()

What it does: converts a PIL.Image, or a numpy.ndarray of shape (H, W, C), with values in [0, 255] into a torch.FloatTensor of shape (C, H, W) with values in [0.0, 1.0].

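Note: for a numpy.ndarray, the pixel values are scaled to [0.0, 1.0] only when dtype=uint8. For any other dtype the tensor is returned without scaling and keeps the array's original dtype.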
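To make the rule concrete, here is a minimal NumPy-only sketch of how ToTensor() treats an ndarray. The helper name to_tensor_like is hypothetical, and it only mimics the documented behavior (HxWxC -> CxHxW, divide by 255 only for uint8); it is not torchvision's actual code and returns a NumPy array rather than a tensor.

```python
import numpy as np


def to_tensor_like(arr: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mimic of ToTensor()'s ndarray branch (illustration only)."""
    if arr.ndim == 2:
        # A grayscale HxW array is treated as HxWx1
        arr = arr[:, :, None]
    # HxWxC -> CxHxW, stored contiguously
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1))
    if arr.dtype == np.uint8:
        # Only uint8 input is rescaled from [0, 255] to [0.0, 1.0]
        return chw.astype(np.float32) / 255.0
    # Any other dtype (e.g. float32) passes through unscaled, dtype unchanged
    return chw


u8 = np.full((2, 3, 3), 111, dtype=np.uint8)  # HxWxC image, uint8
f32 = u8.astype(np.float32)                   # same pixels, but float32
print(to_tensor_like(u8)[0, 0, 0])   # scaled: 111 / 255
print(to_tensor_like(f32)[0, 0, 0])  # not scaled: still 111.0
```

The two prints mirror case 1 and case 2 below: the same pixel value either becomes 111/255 or stays at 111.0 depending only on the input dtype.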

import torch
import numpy as np

from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
])
img = Image.open('cat.png')  # 8-bit image: np.asarray(img) has dtype=uint8

# case1: the normal case
img1 = np.asarray(img)
img1 = transform(img1)
img1 = img1.permute([1, 2, 0])  # CxHxW -> HxWxC, for easier comparison below
img1 = img1.cpu().data.numpy()
print(img1[300: 305, 695: 700, 0])  # print a patch of pixel values
"""
[[0.43529412 0.44705883 0.45490196 0.45882353 0.45490196]
[0.41960785 0.43137255 0.4392157 0.44313726 0.4392157 ]
[0.40392157 0.41568628 0.42745098 0.43137255 0.43529412]
[0.3882353 0.40784314 0.41568628 0.41960785 0.42745098]
[0.38431373 0.4 0.4117647 0.41568628 0.42352942]]
"""

# case2: the mistake: converting uint8 data to float32, so no scaling is applied
img2 = np.asarray(img, dtype=np.float32)  # the resulting tensor is not a torch.ByteTensor, so ToTensor() skips div(255)
img2 = transform(img2)
img2 = img2.permute([1, 2, 0])
img2 = img2.cpu().data.numpy()
print(img2[300: 305, 695: 700, 0])
"""
[[111. 114. 116. 117. 116.]
[107. 110. 112. 113. 112.]
[103. 106. 109. 110. 111.]
[ 99. 104. 106. 107. 109.]
[ 98. 102. 105. 106. 108.]]
"""

# case3: also correct: skip ToTensor() and scale manually
img3 = np.asarray(img, dtype=np.float32)
img3 = img3 / 255.
print(img3[300: 305, 695: 700, 0])
"""
[[0.43529412 0.44705883 0.45490196 0.45882353 0.45490196]
[0.41960785 0.43137255 0.4392157 0.44313726 0.4392157 ]
[0.40392157 0.41568628 0.42745098 0.43137255 0.43529412]
[0.3882353 0.40784314 0.41568628 0.41960785 0.42745098]
[0.38431373 0.4 0.4117647 0.41568628 0.42352942]]
"""

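Conclusion: when using torchvision.transforms.ToTensor(), do not convert the image to float32 beforehand. If the data has to be float, divide by 255 yourself, as in case 3.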
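One way to catch the case 2 mistake early is a small guard called right before ToTensor(). The function check_to_tensor_input below is a hypothetical sketch, and it relies on a heuristic assumption: a non-uint8 image is expected to already lie in [0, 1], so any value above 1.0 is taken as a sign of unscaled [0, 255] data.

```python
import numpy as np


def check_to_tensor_input(arr: np.ndarray) -> np.ndarray:
    """Hypothetical guard: reject non-uint8 arrays that look like unscaled 0-255 pixels.

    Heuristic assumption: float (or other non-uint8) images should already be in [0, 1].
    """
    if arr.dtype != np.uint8 and arr.size > 0 and float(arr.max()) > 1.0:
        raise ValueError(
            f"dtype={arr.dtype} with max={float(arr.max())}: ToTensor() will not "
            "divide by 255; pass uint8 data or scale to [0, 1] first"
        )
    # Input passes the check; return it unchanged so the call can be chained
    return arr


pixels = np.full((2, 2, 3), 111, dtype=np.uint8)
check_to_tensor_input(pixels)                          # uint8: accepted
check_to_tensor_input(pixels.astype(np.float32) / 255.)  # float in [0, 1]: accepted
try:
    check_to_tensor_input(pixels.astype(np.float32))   # the case 2 mistake
except ValueError as err:
    print("caught:", err)
```

The check cannot tell a very dark float image in [0, 255] (max at or below 1.0) from a properly scaled one, so it is a cheap safety net rather than a guarantee.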