最近,在学习吴恩达老师的dl课程作业时,遇到了一些问题,代码如下
## START CODE HERE ##
my_image = "my_image.jpg" # change this to the name of your image file
my_label_y = [1] # the true class of your image (1 -> cat, 0 -> non-cat)
## END CODE HERE ##
fname = "images/" + my_image
image = np.array(ndimage.imread(fname, flatten=False))
my_image = scipy.misc.imresize(image, size=(num_px,num_px)).reshape((num_px*num_px*3,1))
my_predicted_image = predict(my_image, my_label_y, parameters)
plt.imshow(image)
print ("y = " + str(np.squeeze(my_predicted_image)) + ", your L-layer model predicts a \"" + classes[int(np.squeeze(my_predicted_image)),].decode("utf-8") + "\" picture.")
我按照源代码运行,出现了以下几种错误
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[63], line 7
4 ## END CODE HERE ##
6 fname = "images/" + my_image
----> 7 image = np.array(ndimage.imread(fname, flatten=False))
8 my_image = scipy.misc.imresize(image, size=(num_px,num_px)).reshape((num_px*num_px*3,1))
9 my_predicted_image = predict(my_image, my_label_y, parameters)
AttributeError: module 'scipy.ndimage' has no attribute 'imread'
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[65], line 9
7 fname = "images/" + my_image
8 image = Image.open(fname)
----> 9 my_image = scipy.misc.imresize(image, size=(num_px,num_px)).reshape((num_px*num_px*3,1))
10 my_predicted_image = predict(my_image, my_label_y, parameters)
12 plt.imshow(image)
AttributeError: module 'scipy.misc' has no attribute 'imresize'
在简单的查询官方文档后,得出是由于版本更新,一些方法被抛弃的原因。
那么接下来就是要用规范的方法将代码改动,以提高代码的实效性,首先分析代码做了什么
分析完毕开始操作
在百度上查看后,有的推荐降级,有的让使用其他的库,明显是比较粗糙的做法,于是查询官方找到目前推荐的打开图片方式,代码如下
from PIL import Image
my_image = "my_image.jpg"
image = Image.open(my_image)
第二部就是缩放图片了,代码如下
image = image.resize((num_px,num_px))
from PIL import Image
## START CODE HERE ##
my_image = "my_image.jpg" # change this to the name of your image file
my_label_y = [1] # the true class of your image (1 -> cat, 0 -> non-cat)
## END CODE HERE ##
fname = "images/" + my_image
num_px = 64
image = Image.open(fname)
image = image.resize((num_px,num_px))
my_image = np.array(image).reshape((num_px*num_px*3,1))
print(my_image.shape)
my_predicted_image = predict(my_image, my_label_y, parameters)
plt.imshow(image)
print ("y = " + str(np.squeeze(my_predicted_image)) + ", your L-layer model predicts a \"" + classes[int(np.squeeze(my_predicted_image)),].decode("utf-8") + "\" picture.")
遇到问题,先分析代码运行过程,然后合理查询文档,得到最优解决方案。既保证了代码的优美,又能掌握最新的知识,还能是代码有时效性,鲁棒性。