DeepForest调试记录

DeepForest 树检测

在已安装 cudatoolkit 11.3.1、对应版本的 PyTorch 以及 Python 3.9 的环境中(记录于 2022-01-26 22:23),执行:

conda install deepforest albumentations -c conda-forge

该命令安装的是 deepforest 1.2.0 版本。

使用新数据集重新训练模型(基于官方给出的 demo):

#load the modules
import os
import time
import cv2
import numpy as np
from deepforest import main 
from deepforest import get_data
from deepforest import utilities
from deepforest import preprocess
from PIL import Image
import rasterio
import matplotlib.pyplot as plt
# Fraction of the image crops held out for validation.
VALID_FRACTION = 0.25

# Convert the hand annotations from XML into a dataframe in retinanet format.
# get_data is only needed to locate the sample data bundled with deepforest.
YELL_xml = get_data("2019_YELL_2_528000_4978000_image_crop2.xml")
annotation = utilities.xml_to_annotations(YELL_xml)
annotation.head()  # Preview; only displays something in a notebook.

# Locate the image file that the annotation file describes.
YELL_train = get_data("2019_YELL_2_528000_4978000_image_crop2.png")
image_path = os.path.dirname(YELL_train)

# Write the converted dataframe to a CSV saved alongside the image.
annotation_path = os.path.join(image_path, "train_example.csv")
annotation.to_csv(annotation_path, index=False)

# ---- Prepare training and validation data ----
# Cropped patches and the train/valid CSVs below are written into crop_dir.
# Create it up front so the later to_csv calls cannot fail on a missing folder.
crop_dir = os.path.join(os.getcwd(), "train_data_folder")
os.makedirs(crop_dir, exist_ok=True)
train_annotations = preprocess.split_raster(
    path_to_raster=YELL_train,
    annotations_file=annotation_path,
    base_dir=crop_dir,
    patch_size=400,
    patch_overlap=0.05,
)

# Split the crops into training and validation sets. Normally these would be
# different tiles; splitting crops of a single tile is just an example.
image_paths = train_annotations.image_path.unique()
# Sample WITHOUT replacement. np.random.choice defaults to replace=True, and a
# duplicate draw would silently shrink the validation set. Keep at least one
# crop for validation whenever there are two or more crops.
if len(image_paths) > 1:
    n_valid = max(1, int(len(image_paths) * VALID_FRACTION))
else:
    n_valid = 0
valid_paths = np.random.choice(image_paths, n_valid, replace=False)
is_valid = train_annotations.image_path.isin(valid_paths)
valid_annotations = train_annotations.loc[is_valid]
train_annotations = train_annotations.loc[~is_valid]

# View output.
train_annotations.head()  # Preview; only displays something in a notebook.
print(f"There are {train_annotations.shape[0]} training crown annotations")
print(f"There are {valid_annotations.shape[0]} validation crown annotations")

# Save both splits in crop_dir (the "base_dir" passed to split_raster above).
# pandas writes a header row by default; index=False only drops the index.
annotations_file = os.path.join(crop_dir, "train.csv")
validation_file = os.path.join(crop_dir, "valid.csv")
train_annotations.to_csv(annotations_file, index=False)
valid_annotations.to_csv(validation_file, index=False)

# ---- Initialise the model and point its config at the new data ----
m = main.deepforest()
m.config["gpus"] = "-1"  # Move to GPU and use all available GPUs.
m.config["train"]["csv_file"] = annotations_file
m.config["train"]["root_dir"] = os.path.dirname(annotations_file)
# NOTE(review): depending on the deepforest version, the prediction threshold
# may also need to be set on m.model.score_thresh; confirm it takes effect.
m.config["score_thresh"] = 0.4
m.config["train"]["epochs"] = 2
m.config["validation"]["csv_file"] = validation_file
m.config["validation"]["root_dir"] = os.path.dirname(validation_file)
# Create the PyTorch Lightning trainer used for training (after the config
# edits above).
m.create_trainer()
# Load the latest released model weights as the starting point.
m.use_release()

print("data ok")

start_time = time.perf_counter()  # Monotonic clock for measuring durations.
m.trainer.fit(m)
print(f"--- Training on GPU: {(time.perf_counter() - start_time):.2f} seconds ---")

print()
print(annotations_file)
print()

# annotations_file='/home/pikapikaq/Desktop/TreeHeight/test/1.png'
# annotations_file='/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.csv'

# Save the evaluation results into a prediction folder.
save_dir = os.path.join(os.getcwd(), "pred_result")
os.makedirs(save_dir, exist_ok=True)
# Evaluate on the held-out crops. Scoring the training crops would overstate
# accuracy. Fall back to the training file only when no crop could be held out.
eval_file = annotations_file if valid_annotations.empty else validation_file
results = m.evaluate(
    eval_file,
    os.path.dirname(eval_file),
    iou_threshold=0.4,
    savedir=save_dir,
)

# csv_file = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/train_data_folder/2019_YELL_2_528000_4978000_image_crop2_5.png'#'/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg' #'/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.tif'
# img=cv2.imread(csv_file)
# img = img.astype(np.float32)/255
# #print(img.dtype)
# # print(img)
# df = m.predict_image(image=img,return_plot=True)#root_dir = os.path.dirname(csv_file))
# print(df)

# csv_file = '/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.csv'
# df = m.predict_file(csv_file, root_dir = os.path.dirname(csv_file))
# print(df)

# csv_file = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/pic/OSBS_029.tif' #'/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'
# img=Image.open(csv_file)

# img_arr = np.array(img)
# img_arr = img_arr.astype(np.float32)/255
# print(img_arr.shape) # uint8
# print(img_arr.dtype)

# df = m.predict_image(image=img_arr,return_plot=True) # 
# print(df)

# raster = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'# get_data("2019_YELL_2_528000_4978000_image_crop2.png")
# # '/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'
# src = rasterio.open(raster)
# #s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
# print(src.read().shape) # #(3, 2472, 2299)

# predicted_boxes =m.predict_tile(raster_path = raster,
#                                         patch_size = 300,
#                                         patch_overlap = 0.5,
#                                         return_plot = True)
# plt.imshow(predicted_boxes[:,:,::-1])
# plt.show()

# print(predicted_boxes)

运行 demo 时,在 `m.trainer.fit(m)` 处出现如下报错(注:下面 traceback 中的路径来自 Python 3.7 环境,与上文的 3.9 环境不同,但报错内容相同):

KeyError                                  Traceback (most recent call last)

[<ipython-input-42-0501e4367b85>](https://localhost:8080/#) in ()
      1 start_time = time.time()
      2 
----> 3 m.trainer.fit(m)
      4 
      5 print(f"--- Training on CPU: {(time.time() - start_time):.2f} seconds ---")

16 frames

[/usr/local/lib/python3.7/dist-packages/deepforest/main.py](https://localhost:8080/#) in load_dataset(self, csv_file, root_dir, augment, shuffle, batch_size, train)
    167                                  transforms=self.transforms(augment=augment),
    168                                  label_dict=self.label_dict,
--> 169                                  preload_images=self.config["train"]["preload_images"])
    170 
    171         data_loader = torch.utils.data.DataLoader(

KeyError: 'preload_images'

报错原因是配置中缺少 `train` 下的 `preload_images` 字段。从项目讨论区得知,需要将 deepforest 升级到最新版本,执行:

pip install deepforest --upgrade

升级到 deepforest 1.2.1 后,该错误不再出现。

你可能感兴趣的:(环境配置,python,深度学习,开发语言)