diff --git a/federated-learning/duet_mnist/MNIST_Syft_Data_Scientist.ipynb b/federated-learning/duet_mnist/MNIST_Syft_Data_Scientist.ipynb index 1fae98a..5e31aad 100644 --- a/federated-learning/duet_mnist/MNIST_Syft_Data_Scientist.ipynb +++ b/federated-learning/duet_mnist/MNIST_Syft_Data_Scientist.ipynb @@ -686,16 +686,31 @@ "import PIL.ImageOps \n", "\n", "import os\n", + "\n", + "transform = torchvision.transforms.ToTensor()\n", + "\n", + "\n", + "def image_to_tensor(im: PIL.Image.Image) -> torch.Tensor:\n", + " \"\"\"Converts given PIL.Image.Image object to torch.Tensor\n", + " \"\"\"\n", + " \n", + " im_transformed = transform(im)\n", + " im_transformed = im_transformed[0].clone().detach()\n", + " \n", + " return im_transformed\n", + "\n", + "\n", "def classify_url_image(image_url):\n", " filename = os.path.basename(image_url)\n", " os.system(f'curl -O {image_url}')\n", " im = Image.open(filename)\n", " im = PIL.ImageOps.invert(im)\n", - "# im = im.resize((28,28), Image.ANTIALIAS)\n", + " im = im.resize((28,28), Image.ANTIALIAS)\n", " im = im.convert('LA')\n", " enhancer = ImageEnhance.Brightness(im)\n", " im = enhancer.enhance(3)\n", - "\n", + " # convert the image to a torch.tensor to send it to the model for prediction \n", + " im_tensor = image_to_tensor(im)\n", "\n", " print(im.size)\n", " fig = plt.figure()\n", @@ -703,7 +718,7 @@ " plt.imshow(im, cmap=\"gray\", interpolation=\"none\")\n", " \n", " # classify local\n", - " class_num, preds = classify_local(image_1, local_model)\n", + " class_num, preds = classify_local(im_tensor, local_model)\n", " print(f\"Prediction: {class_num}\")\n", " print(preds)" ] @@ -714,8 +729,8 @@ "metadata": {}, "outputs": [], "source": [ - "# image_url = \"https://raw.githubusercontent.com/kensanata/numbers/master/0018_CHXX/0/number-100.png\"\n", - "# classify_url_image(image_url)" + "image_url = \"https://raw.githubusercontent.com/kensanata/numbers/master/0018_CHXX/0/number-100.png\"\n", + "classify_url_image(image_url)" ] }, { @@ -728,7 +743,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -742,7 +757,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.9.7" }, "pycharm": { "stem_cell": {