From 4e7bbc88debff27e8cb6a629603b382d7ac12cb7 Mon Sep 17 00:00:00 2001 From: Jeongseok Kang Date: Mon, 24 Jun 2024 09:57:55 +0900 Subject: [PATCH] feat: Add example notebooks for fine-tuning Gemma and Llama3 --- FineTuning_Gemma-Instruct-2B-tf2.ipynb | 593 +++++++++++++++++++++ FineTuning_Llama3-Instruct-8B-tf2.ipynb | 671 ++++++++++++++++++++++++ 2 files changed, 1264 insertions(+) create mode 100644 FineTuning_Gemma-Instruct-2B-tf2.ipynb create mode 100644 FineTuning_Llama3-Instruct-8B-tf2.ipynb diff --git a/FineTuning_Gemma-Instruct-2B-tf2.ipynb b/FineTuning_Gemma-Instruct-2B-tf2.ipynb new file mode 100644 index 0000000..f4def58 --- /dev/null +++ b/FineTuning_Gemma-Instruct-2B-tf2.ipynb @@ -0,0 +1,593 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --upgrade -q \\\n", + " keras-nlp==0.12.1 \\\n", + " keras==3.3.3 \\\n", + " jaxlib==0.4.30 \\\n", + " jax[cuda12]==0.4.30 \\\n", + " git+https://github.com/google-deepmind/gemma.git@a24194737dcb54b7392091e9ba772aea1cb68ffb \\\n", + " \\\n", + " kagglehub==0.2.5\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.6)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Path to model files: /home/work/.cache/kagglehub/models/keras/gemma/keras/gemma_instruct_2b_en/2\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "# TODO: Create Kaggle API token from https://www.kaggle.com/settings\n", + "os.environ[\"KAGGLE_USERNAME\"] = \"[TODO]\"\n", + "os.environ[\"KAGGLE_KEY\"] = \"[TODO]\"\n", + "\n", + "import kagglehub\n", + "\n", + "# Download latest version\n", + "model_path = kagglehub.model_download(\"keras/gemma/keras/gemma_instruct_2b_en\")\n", + "\n", + "print(\"Path to model files:\", model_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-06-24 00:51:50-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n", + "Resolving huggingface.co (huggingface.co)... 13.225.131.35, 13.225.131.94, 13.225.131.6, ...\n", + "Connecting to huggingface.co (huggingface.co)|13.225.131.35|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1719449510&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxOTQ0OTUxMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=n7shMBYyZI3sYv1048Q2%7ErjcP9XZif-Dw9O61kWssl%7EwVSdo%7EaXUbuqXA8PthV-Oj2y6RuGy6BtMs6e-2BfBHViAlHjuJBjjZ0HCzQm-9whQKAyfOHNC9yALXBYmpVCoAo9OsJGGj4j0PRBSCgFP3jaiZD-Jxlol-lundpz1kSYV10zTXt9ZgzS%7EanoVCgpTOTm5Xuu2%7EYM9KE9m1ROdYFmk6J9DmCiQGrc-BJLnAuFYuIHFOkttttoPePR5dTaf1jh1oe55lI3SxRT1JkU%7Ery6SCAjLVwIRKT0EDwXA3LAUYX2lK%7E%7EM3kjNzJvxsBAkamFCtSmEcZZCOJEA7HWf%7EQ__&Key-Pair-Id=K3ESJI6DHPFC7 [following]\n", + "--2024-06-24 00:51:50-- https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1719449510&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxOTQ0OTUxMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=n7shMBYyZI3sYv1048Q2%7ErjcP9XZif-Dw9O61kWssl%7EwVSdo%7EaXUbuqXA8PthV-Oj2y6RuGy6BtMs6e-2BfBHViAlHjuJBjjZ0HCzQm-9whQKAyfOHNC9yALXBYmpVCoAo9OsJGGj4j0PRBSCgFP3jaiZD-Jxlol-lundpz1kSYV10zTXt9ZgzS%7EanoVCgpTOTm5Xuu2%7EYM9KE9m1ROdYFmk6J9DmCiQGrc-BJLnAuFYuIHFOkttttoPePR5dTaf1jh1oe55lI3SxRT1JkU%7Ery6SCAjLVwIRKT0EDwXA3LAUYX2lK%7E%7EM3kjNzJvxsBAkamFCtSmEcZZCOJEA7HWf%7EQ__&Key-Pair-Id=K3ESJI6DHPFC7\n", + "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.244.61.94, 18.244.61.40, 18.244.61.106, ...\n", + "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.244.61.94|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 13085339 (12M) [text/plain]\n", + "Saving to: ‘datasets/databricks-dolly-15k.jsonl’\n", + "\n", + "datasets/databricks 100%[===================>] 12.48M --.-KB/s in 0.1s \n", + "\n", + "2024-06-24 00:51:51 (126 MB/s) - ‘datasets/databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n", + "\n" + ] + } + ], + "source": [ + "!mkdir -p datasets\n", + "!wget \\\n", + " https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl \\\n", + " -O datasets/databricks-dolly-15k.jsonl\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-process dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(15011, 4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8377baea0ac44cb49830947569cf34c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "from tqdm.notebook import tqdm\n", + "\n", + "num_samples = 1000\n", + "\n", + "dataset_path = Path().parent / \"datasets\" / \"databricks-dolly-15k.jsonl\"\n", + "data = pd.read_json(dataset_path, lines=True)\n", + "print(data.shape)\n", + "\n", + "prompt_template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", + "\n", + "preprocessed_data = []\n", + "for _, row in tqdm(data.iterrows()):\n", + " preprocessed_data.append(\n", + " prompt_template.format(\n", + " instruction=row[\"instruction\"],\n", + " response=row[\"response\"],\n", + " )\n", + " )\n", + "\n", + "# Only use a limited number of training examples\n", + "preprocessed_data = preprocessed_data[:num_samples]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine-tune" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-24 00:51:55.491401: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n", + "2024-06-24 00:52:06.671942: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n", + "2024-06-24 00:52:06.674296: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n", + "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n" + ] + }, + { + "data": { + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
+       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ gemma_backbone                │ (None, None, 2048)        │   2,506,172,416 │ padding_mask[0][0],        │\n",
+       "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │\n",
+       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,506,172,416\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 2,506,172,416 (9.34 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 2,506,172,416 (9.34 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Preprocessor: \"gemma_causal_lm_preprocessor\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"gemma_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │\n",
+       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ gemma_tokenizer (\u001b[38;5;33mGemmaTokenizer\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"gemma_causal_lm\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"gemma_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ gemma_backbone                │ (None, None, 2048)        │   2,507,536,384 │ padding_mask[0][0],        │\n",
+       "│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │\n",
+       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ gemma_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2048\u001b[0m) │ \u001b[38;5;34m2,507,536,384\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mGemmaBackbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256000\u001b[0m) │ \u001b[38;5;34m524,288,000\u001b[0m │ gemma_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 2,507,536,384 (9.34 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m2,507,536,384\u001b[0m (9.34 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 1,363,968 (5.20 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,363,968\u001b[0m (5.20 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 2,506,172,416 (9.34 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m2,506,172,416\u001b[0m (9.34 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-24 00:52:08.298407: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.\n", + "2024-06-24 00:54:14.215545: E external/xla/xla/service/slow_operation_alarm.cc:65] \n", + "********************************\n", + "[Compiling module gemm_fusion_dot.639] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.\n", + "********************************\n", + "2024-06-24 00:54:15.572648: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m1.357171695s\n", + "\n", + "********************************\n", + "[Compiling module gemm_fusion_dot.639] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.\n", + "********************************\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m174s\u001b[0m 42ms/step - loss: 0.5734 - sparse_categorical_accuracy: 0.4935\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-24 00:55:02.705848: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n", + "\n", + "import keras\n", + "import keras_nlp\n", + "\n", + "batch_size = 1\n", + "\n", + "model = keras_nlp.models.GemmaCausalLM.from_preset(str(model_path))\n", + "model.summary()\n", + "\n", + "model.backbone.enable_lora(rank=4)\n", + "model.summary()\n", + "\n", + "model.preprocessor.sequence_length = 512\n", + "optimizer = keras.optimizers.AdamW(\n", + " learning_rate=5e-5,\n", + " weight_decay=0.01,\n", + ")\n", + "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", + "\n", + "model.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")\n", + "\n", + "model.fit(preprocessed_data, epochs=1, batch_size=batch_size, verbose=1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Instruction:\n", + "What should I do on a trip to Europe?\n", + "\n", + "Response:\n", + "There are two main types of trips to Europe: short trips and long trips. If you are looking to spend a weekend in Europe, there are many cities to choose from such as London, Paris, Rome, and Amsterdam. If you are looking to spend several weeks in Europe, there are many cities and countries to explore such as Barcelona, Berlin, and Prague.\n", + "Instruction:\n", + "Explain the process of photosynthesis in a way that a child could understand.\n", + "\n", + "Response:\n", + "Sure, here's the process of photosynthesis explained in simpler terms.\n", + "Sure, photosynthesis is when plants and other organisms use sunlight to convert water, carbon dioxide and energy to make food, or glucose. It's a process that helps us to get the food that we need to survive.\n", + "It's done by special cells called chloroplasts in plant and algal cells called chloroplasts in plant and algal cells.\n", + "The chloroplasts contain chlorophyll, a green pigment that absorbs the energy from the Sun.\n", + "When the chlorophyll absorbs sunlight, it uses the light energy to split water molecules, which are then used to produce oxygen and glucose.\n", + "The glucose is a type of sugar that the plant uses for energy.\n", + "The oxygen is released from the leaves of the plants as waste products of photosynthesis.\n", + "Photosynthesis also helps to regulate the Earth's atmosphere, as it helps to remove carbon dioxide and water vapor from the air.\n" + ] + } + ], + "source": [ + "prompts = [\n", + " prompt_template.format(\n", + " instruction=\"What should I do on a trip to Europe?\",\n", + " response=\"\",\n", + " ),\n", + " prompt_template.format(\n", + " instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n", + " response=\"\",\n", + " ),\n", + "]\n", + "\n", + "sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", + "model.compile(sampler=sampler)\n", + "\n", + "for prompt in prompts:\n", + " print(model.generate(prompt, max_length=256))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/FineTuning_Llama3-Instruct-8B-tf2.ipynb b/FineTuning_Llama3-Instruct-8B-tf2.ipynb new file mode 100644 index 0000000..6d481cc --- /dev/null +++ b/FineTuning_Llama3-Instruct-8B-tf2.ipynb @@ -0,0 +1,671 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --upgrade -q \\\n", + " keras-nlp==0.12.1 \\\n", + " keras==3.3.3 \\\n", + " \\\n", + " kagglehub==0.2.5\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.6)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Path to model files: /home/work/.cache/kagglehub/models/keras/llama3/keras/llama3_instruct_8b_en/3\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "# TODO: Create Kaggle API token from https://www.kaggle.com/settings\n", + "os.environ[\"KAGGLE_USERNAME\"] = \"[TODO]\"\n", + "os.environ[\"KAGGLE_KEY\"] = \"[TODO]\"\n", + "\n", + "import kagglehub\n", + "\n", + "# Download latest version\n", + "model_path = kagglehub.model_download(\"keras/llama3/keras/llama3_instruct_8b_en\")\n", + "\n", + "print(\"Path to model files:\", model_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-06-24 00:46:24-- https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl\n", + "Resolving huggingface.co (huggingface.co)... 13.225.131.94, 13.225.131.93, 13.225.131.35, ...\n", + "Connecting to huggingface.co (huggingface.co)|13.225.131.94|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1719449184&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxOTQ0OTE4NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=U%7EfbSjEKE4-67ixl%7E8KtfDIisyUzdES%7E%7ETsRBfa2TL2EShC-xCX5d4PBEWzwgM14DBRKU3WkFgVQuk5tH3ULbSWtjpcigVXVYuvGj1pOOBx1OhB958d-dgni3e6P4pO9FnkXT6DpqYwsNl4%7EU%7EapVOHk30jKdbovOxk8w%7EtN8zOHXXSvUS6oeTJY9ECnqD8FfMQ5h2ekoYVv-yAJd21x0oOyHlEKL9x0tvHSORvPqrKd43GqhmPOkeiJ-y2j7CXfKuM0KU7Fvi2vnMvbMbn-%7EJHaXrNE8qhL%7ELYQgPhZRNWr9Xx5-YqUKIMLv7zp0mHfh5C%7EP-sJs9iCPisRSCkitw__&Key-Pair-Id=K3ESJI6DHPFC7 [following]\n", + "--2024-06-24 00:46:24-- https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1719449184&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxOTQ0OTE4NH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=U%7EfbSjEKE4-67ixl%7E8KtfDIisyUzdES%7E%7ETsRBfa2TL2EShC-xCX5d4PBEWzwgM14DBRKU3WkFgVQuk5tH3ULbSWtjpcigVXVYuvGj1pOOBx1OhB958d-dgni3e6P4pO9FnkXT6DpqYwsNl4%7EU%7EapVOHk30jKdbovOxk8w%7EtN8zOHXXSvUS6oeTJY9ECnqD8FfMQ5h2ekoYVv-yAJd21x0oOyHlEKL9x0tvHSORvPqrKd43GqhmPOkeiJ-y2j7CXfKuM0KU7Fvi2vnMvbMbn-%7EJHaXrNE8qhL%7ELYQgPhZRNWr9Xx5-YqUKIMLv7zp0mHfh5C%7EP-sJs9iCPisRSCkitw__&Key-Pair-Id=K3ESJI6DHPFC7\n", + "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 13.225.131.103, 13.225.131.94, 13.225.131.126, ...\n", + "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|13.225.131.103|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 13085339 (12M) [text/plain]\n", + "Saving to: ‘datasets/databricks-dolly-15k.jsonl’\n", + "\n", + "datasets/databricks 100%[===================>] 12.48M --.-KB/s in 0.07s \n", + "\n", + "2024-06-24 00:46:25 (176 MB/s) - ‘datasets/databricks-dolly-15k.jsonl’ saved [13085339/13085339]\n", + "\n" + ] + } + ], + "source": [ + "!mkdir -p datasets\n", + "!wget \\\n", + " https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl \\\n", + " -O datasets/databricks-dolly-15k.jsonl\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-process dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(15011, 4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c4d3555c01504aed8a9989381de3d4c9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import json\n", + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "from tqdm.notebook import tqdm\n", + "\n", + "num_samples = 1000\n", + "\n", + "dataset_path = Path().parent / \"datasets\" / \"databricks-dolly-15k.jsonl\"\n", + "data = pd.read_json(dataset_path, lines=True)\n", + "print(data.shape)\n", + "\n", + "prompt_template = \"Instruction:\\n{instruction}\\n\\nResponse:\\n{response}\"\n", + "\n", + "preprocessed_data = []\n", + "for _, row in tqdm(data.iterrows()):\n", + " preprocessed_data.append(\n", + " prompt_template.format(\n", + " instruction=row[\"instruction\"],\n", + " response=row[\"response\"],\n", + " )\n", + " )\n", + "\n", + "# Only use a limited number of training examples\n", + "preprocessed_data = preprocessed_data[:num_samples]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fine-tune" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-24 00:46:26.391217: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-24 00:46:26.424725: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-24 00:46:29.194528: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 79192 MB memory: -> device: 0, name: CUDA GPU, pci bus id: 0000:43:00.0, compute capability: 9.0\n", + "2024-06-24 00:46:29.196389: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 79192 MB memory: -> device: 1, name: CUDA GPU, pci bus id: 0000:52:00.0, compute capability: 9.0\n" + ] + }, + { + "data": { + "text/html": [ + "
Preprocessor: \"llama3_causal_lm_preprocessor\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"llama3_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ llama3_tokenizer (Llama3Tokenizer)                 │                                             128,256 │\n",
+       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ llama3_tokenizer (\u001b[38;5;33mLlama3Tokenizer\u001b[0m) │ \u001b[38;5;34m128,256\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"llama3_causal_lm\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"llama3_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ llama_backbone                │ (None, None, 4096)        │   8,030,261,248 │ padding_mask[0][0],        │\n",
+       "│ (Llama3Backbone)              │                           │                 │ token_ids[0][0]            │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_embedding               │ (None, None, 128256)      │   1,050,673,152 │ llama_backbone[0][0]       │\n",
+       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ llama_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4096\u001b[0m) │ \u001b[38;5;34m8,030,261,248\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mLlama3Backbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128256\u001b[0m) │ \u001b[38;5;34m1,050,673,152\u001b[0m │ llama_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 8,030,261,248 (29.92 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m8,030,261,248\u001b[0m (29.92 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 8,030,261,248 (29.92 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m8,030,261,248\u001b[0m (29.92 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Preprocessor: \"llama3_causal_lm_preprocessor\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mPreprocessor: \"llama3_causal_lm_preprocessor\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Tokenizer (type)                                                                                Vocab # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ llama3_tokenizer (Llama3Tokenizer)                 │                                             128,256 │\n",
+       "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ llama3_tokenizer (\u001b[38;5;33mLlama3Tokenizer\u001b[0m) │ \u001b[38;5;34m128,256\u001b[0m │\n", + "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Model: \"llama3_causal_lm\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"llama3_causal_lm\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ llama_backbone                │ (None, None, 4096)        │   8,051,265,536 │ padding_mask[0][0],        │\n",
+       "│ (Llama3Backbone)              │                           │                 │ token_ids[0][0]            │\n",
+       "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
+       "│ token_embedding               │ (None, None, 128256)      │   1,050,673,152 │ llama_backbone[0][0]       │\n",
+       "│ (ReversibleEmbedding)         │                           │                 │                            │\n",
+       "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ llama_backbone │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m4096\u001b[0m) │ \u001b[38;5;34m8,051,265,536\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mLlama3Backbone\u001b[0m) │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n", + "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128256\u001b[0m) │ \u001b[38;5;34m1,050,673,152\u001b[0m │ llama_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n", + "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 8,051,265,536 (29.99 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m8,051,265,536\u001b[0m (29.99 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 21,004,288 (80.12 MB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m21,004,288\u001b[0m (80.12 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 8,030,261,248 (29.92 GB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m8,030,261,248\u001b[0m (29.92 GB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-24 00:47:11.822947: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1719190057.324048 64994 service.cc:145] XLA service 0x7fc2740c9c30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", + "I0000 00:00:1719190057.324080 64994 service.cc:153] StreamExecutor device (0): CUDA GPU, Compute Capability 9.0\n", + "I0000 00:00:1719190057.324082 64994 service.cc:153] StreamExecutor device (1): CUDA GPU, Compute Capability 9.0\n", + "2024-06-24 00:47:38.671381: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", + "W0000 00:00:1719190060.410455 64994 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert\n", + "2024-06-24 00:47:41.899200: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8905\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "I0000 00:00:1719190064.546372 65262 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_768', 76 bytes spill stores, 76 bytes spill loads\n", + "\n", + "I0000 00:00:1719190065.328185 65265 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_641', 412 bytes spill stores, 380 bytes spill loads\n", + "\n", + "I0000 00:00:1719190065.950308 65263 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_2', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190066.046587 65266 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_4', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190066.427901 65262 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_2', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190066.851865 65263 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_4', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190067.550107 65269 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_7', 4 bytes spill stores, 4 bytes spill loads\n", + "\n", + "I0000 00:00:1719190068.115795 65267 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_2', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190068.692605 65264 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_4', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190068.753953 65268 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_7', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190069.418559 65268 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_765', 88 bytes spill stores, 88 bytes spill loads\n", + "\n", + "I0000 00:00:1719190069.448741 65262 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_3', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190069.754501 65264 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_3', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190071.235632 65266 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_288', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190071.543079 65265 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_7', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190072.085074 65269 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_576', 16 bytes spill stores, 16 bytes spill loads\n", + "\n", + "I0000 00:00:1719190072.112840 65267 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_3', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190072.440780 65267 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_7', 192 bytes spill stores, 192 bytes spill loads\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m 1/1000\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m14:18:08\u001b[0m 52s/step - loss: 0.3467 - sparse_categorical_accuracy: 0.3421" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0000 00:00:1719190082.178672 64994 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1000/1000\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m153s\u001b[0m 102ms/step - loss: 0.4920 - sparse_categorical_accuracy: 0.5289\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import keras\n", + "import keras_nlp\n", + "\n", + "batch_size = 1\n", + "\n", + "model = keras_nlp.models.Llama3CausalLM.from_preset(str(model_path))\n", + "model.summary()\n", + "\n", + "model.backbone.enable_lora(rank=4)\n", + "model.summary()\n", + "\n", + "model.preprocessor.sequence_length = 512\n", + "optimizer = keras.optimizers.AdamW(\n", + " learning_rate=5e-5,\n", + " weight_decay=0.01,\n", + ")\n", + "optimizer.exclude_from_weight_decay(var_names=[\"bias\", \"scale\"])\n", + "\n", + "model.compile(\n", + " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=optimizer,\n", + " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", + ")\n", + "\n", + "model.fit(preprocessed_data, epochs=1, batch_size=batch_size, verbose=1)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0000 00:00:1719190196.230213 65941 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_732', 4 bytes spill stores, 4 bytes spill loads\n", + "\n", + "I0000 00:00:1719190196.857451 65941 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_291', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190196.884180 65934 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_291', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190197.369031 65937 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_357', 4 bytes spill stores, 4 bytes spill loads\n", + "\n", + "I0000 00:00:1719190197.390453 65939 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_357', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190197.690251 65938 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_291', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190198.323155 65941 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_293', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190198.668158 65936 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_357', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.036451 65940 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_292', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.453389 65939 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_292', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.632895 65940 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_573', 468 bytes spill stores, 416 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.639171 65936 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_357', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.867628 65941 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_573', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.920120 65940 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_573', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190199.928212 65936 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_293', 192 bytes spill stores, 192 bytes spill loads\n", + "\n", + "I0000 00:00:1719190200.514925 65934 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_292', 128 bytes spill stores, 128 bytes spill loads\n", + "\n", + "I0000 00:00:1719190200.842650 65938 asm_compiler.cc:369] ptxas warning : Registers are spilled to local memory in function 'triton_gemm_dot_293', 128 bytes spill stores, 128 bytes spill loads\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Instruction:\n", + "What should I do on a trip to Europe?\n", + "\n", + "Response:\n", + "Europe has a lot to offer. First, you should learn to navigate the public transportation system. You can also learn a few basic phrases in the local language. If you are visiting Italy, you might consider trying pasta for the first time. If you are going to the UK, you can visit the Tower of London. If you are going to Germany, you can visit a beer festival. Finally, you should consider staying in hostels and budget hotels to save money.\n", + "\n", + "Second, it would be a good idea to learn about European customs and traditions. For example, you should always say \"please\" and \"thank you\" when asking someone for something, and you should never eat while walking in the street. You should also respect the local culture. For example, you should avoid public displays of affection in some countries.\n", + "\n", + "Third, it would be a good idea to have a budget and to plan your trip in advance. You should know how much you want to spend each day and what you want to do. It would also be a good idea to make a reservation at a restaurant before showing up.\n", + "\n", + "\n", + "Instruction:\n", + "Explain the process of photosynthesis in a way that a child could understand.\n", + "\n", + "Response:\n", + "Photosynthesis is the process plants and some animals use to make food from sunlight. They do this by taking a type of sugar from the air, water and sunlight to make glucose. This process happens in the cells of the plant's leaves. The plant uses this glucose for energy. When animals eat the plant's leaves, they are also eating glucose for energy. Photosynthesis is important for plants and animals because it provides them with energy to grow and survive. It also provides oxygen that humans, animals and other animals breathe in. This is why it is important to have trees and plants in the world.\n", + "\n", + "It is not possible to live on earth without plants because they provide oxygen for humans and animals to breathe. They also provide food for animals to eat. Without plants and photosynthesis, we wouldn't have food to eat or oxygen to breathe. It is a necessary part of life.\n", + "\n", + "In summary plants use sunlight to make glucose, a type of sugar, which they use for energy. When animals eat the plants leaves, they are also eating glucose for energy. The glucose that plants make also turns into oxygen that we breathe in. Photosynthesis\n" + ] + } + ], + "source": [ + "prompts = [\n", + " prompt_template.format(\n", + " instruction=\"What should I do on a trip to Europe?\",\n", + " response=\"\",\n", + " ),\n", + " prompt_template.format(\n", + " instruction=\"Explain the process of photosynthesis in a way that a child could understand.\",\n", + " response=\"\",\n", + " ),\n", + "]\n", + "\n", + "sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)\n", + "model.compile(sampler=sampler)\n", + "\n", + "for prompt in prompts:\n", + " print(model.generate(prompt, max_length=256))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}