diff --git a/prediction/pred_pose.py b/prediction/pred_pose.py index f6aa77a..827480e 100644 --- a/prediction/pred_pose.py +++ b/prediction/pred_pose.py @@ -73,7 +73,7 @@ pose_net = models.PoseResNet(18, False).to(device) -pose_net = torch.nn.DataParallel(pose_net) +# pose_net = torch.nn.DataParallel(pose_net) weights = torch.load(args.weight_pth) pose_net.load_state_dict(weights) pose_net.eval()