Torchvision介紹
Torchvision是基于Pytorch的視覺深度學習遷移學習訓練框架,當前支持的圖像分類、對象檢測、實例分割、語義分割、姿態評估模型的遷移學習訓練與評估。支持對數據集的合成、變換、增強等,此外還支持預訓練模型庫下載相關的模型,直接預測推理。
預訓練模型使用
Torchvision從0.13版本開始預訓練模型支持多源backbone設置,以圖像分類的ResNet網絡模型為例:
支持多個不同的數據集上不同精度的預訓練模型,下載模型,轉化為推理模型
對輸入圖像實現預處理
本地加載模型
Torchvision中支持的預訓練模型當你使用的時候都會加載模型的預訓練模型,然后才可以加載你自己的權重文件,如果你不想加載torchvision的預訓練模型,只想從本地加載pt或者pth文件實現推理或者訓練的時候,一定要通過下面的方式完成,以Faster-RCNN為例:
# Load the model from local host num_classes = len(self.labels) self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=num_classes, pretrained_backbone=False) self.model.load_state_dict(torch.load(self.model_file)) self.model.eval() self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) # 使用GPU train_on_gpu = torch.cuda.is_available() if train_on_gpu: self.model.cuda()
就這樣解鎖了在torchvision框架下如何從本地加載預訓練模型文件或者定義訓練模型文件。
審核編輯:湯梓紅
-
模型
+關注
關注
1文章
3486瀏覽量
49990 -
深度學習
+關注
關注
73文章
5554瀏覽量
122478 -
pytorch
+關注
關注
2文章
809瀏覽量
13763
原文標題:torchvision中怎么加載本地模型實現訓練與推理
文章出處:【微信號:CVSCHOOL,微信公眾號:OpenCV學堂】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
評論