diff --git a/train.py b/train.py index 5cf6bb4..851c777 100644 --- a/train.py +++ b/train.py @@ -174,12 +174,12 @@ if args.pretrained_disp: print("=> using pre-trained weights for DispResNet") weights = torch.load(args.pretrained_disp) - disp_net.load_state_dict(weights['state_dict'], strict=False) + disp_net.load_state_dict(weights, strict=False) if args.pretrained_pose: print("=> using pre-trained weights for PoseResNet") weights = torch.load(args.pretrained_pose) - pose_net.load_state_dict(weights['state_dict'], strict=False) + pose_net.load_state_dict(weights, strict=False) print('=> setting adam solver') optim_params = [