diff --git a/transformers_distillation/LICENSE b/transformers_distillation/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/transformers_distillation/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
diff --git a/transformers_distillation/README.md b/transformers_distillation/README.md
new file mode 100644
index 0000000..a4a130e
--- /dev/null
+++ b/transformers_distillation/README.md
@@ -0,0 +1,149 @@
+# πŸ§ͺ HF Distiller β€” Knowledge Distillation for Hugging Face Models
+
+![HF Banner](https://huggingface.co/front/assets/huggingface_logo.svg)
+
+[![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://www.python.org/)
+[![License](https://img.shields.io/badge/license-Apache--2.0-green)](LICENSE)
+[![Hugging Face](https://img.shields.io/badge/huggingface-Dhiraj309-orange)](https://huggingface.co/Dhiraj309)
+
+**HF Distiller** is an **open-source toolkit** for performing **knowledge distillation** on Hugging Face Transformers models. It allows developers to **train smaller, faster student models** from large pre-trained teacher models while retaining much of the teacher's performance.
+
+---
+
+## πŸ“– Overview
+
+Knowledge Distillation (KD) compresses a large model into a smaller one by transferring the β€œknowledge” learned by the teacher to the student. HF Distiller wraps around Hugging Face’s `Trainer` to make KD **accessible, modular, and intuitive**.
+
+**Key Features:**
+
+* βœ… Load any teacher model from the Hugging Face Hub
+* βœ… Create smaller student models from scratch
+* βœ… Supports Hugging Face tokenizers
+* βœ… Seamless integration with the `datasets` library
+* βœ… Transparent logging and checkpointing
+* βœ… Fully compatible with PyTorch and Transformers
+
+---
+
+## πŸ–Ό Architecture
+
+```text
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
+ β”‚     Teacher Model      β”‚  Pretrained Hugging Face LM
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
+              β”‚
+              β–Ό
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
+ β”‚ Knowledge Distillation β”‚  Transfer teacher knowledge + KD loss
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
+              β”‚
+              β–Ό
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
+ β”‚     Student Model      β”‚  Smaller, efficient model trained from scratch
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
+```
+
+---
+
+## ⚑ Installation
+
+```bash
+# Install transformers_distillation (recommended)
+pip install --no-deps git+https://github.com/Dhiraj309/transformers_distillation.git
+
+# OR
+
+# Clone repository
+git clone https://github.com/Dhiraj309/transformers_distillation.git
+cd transformers_distillation
+
+# Install dependencies
+pip install -r requirements.txt
+```
+
+---
+
+## πŸƒ Quick Start
+
+```python
+from transformers_distillation.models import load_teacher, load_student
+from transformers_distillation.trainer import DistillTrainer
+from transformers import AutoTokenizer, TrainingArguments
+from datasets import Dataset
+
+# Example dataset
+dataset = Dataset.from_dict({"text": ["Hello world!", "AI is amazing."]})
+
+# Load teacher
+teacher = load_teacher("google-bert/bert-base-uncased")
+tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+
+# Create student model
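+# Note: with from_scratch=True, load_student() (see models.py) builds the student
+# from a fresh AutoConfig of the teacher checkpoint and randomly initializes its
+# weights; n_layers/n_heads map onto num_hidden_layers/num_attention_heads, and
+# n_embd is applied only to configs that expose `n_embd` or `hidden_dim`.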
+student = load_student(
+    model_name_or_path="google-bert/bert-base-uncased",
+    from_scratch=True,
+    n_layers=4,
+    n_heads=4,
+    n_embd=256,
+    is_pretrained=False
+)
+
+# Tokenize
+def tokenize(batch):
+    return tokenizer(batch["text"], max_length=128, padding=True, truncation=True)
+
+tokenized = dataset.map(tokenize, remove_columns=["text"])
+
+# Training arguments
+training_args = TrainingArguments(
+    output_dir="./student-llm",
+    per_device_train_batch_size=1,
+    num_train_epochs=1,
+    learning_rate=2e-4,
+    report_to="none"
+)
+
+# Train student with KD
+trainer = DistillTrainer(
+    teacher_model=teacher,
+    student_model=student,
+    train_dataset=tokenized,
+    tokenizer=tokenizer,
+    training_args=training_args,
+    kd_alpha=0.5,
+    temperature=2.0
+)
+trainer.train()
+```
+
+---
+
+## πŸ“‚ Project Status
+
+| Stage                | Status         |
+| -------------------- | -------------- |
+| Core Development     | βœ… Complete     |
+| Documentation        | βœ… Complete     |
+| Community Feedback   | 🚧 In Progress |
+| Tutorials & Examples | 🚧 In Progress |
+
+---
+
+## 🀝 Collaboration
+
+We welcome contributions from the community, including:
+
+* Pull requests for new KD strategies
+* Bug reports and feature requests
+* Tutorials and example scripts
+* Optimization for faster student training
+
+πŸ”— GitHub: [Dhiraj309](https://github.com/Dhiraj309)
+πŸ”— Hugging Face: [dignity045](https://huggingface.co/dignity045)
+
+---
+
+## πŸ“œ License
+
+Released under the **Apache License 2.0** β€” free to use, modify, and distribute. See [LICENSE](LICENSE) for full terms.
+
diff --git a/transformers_distillation/examples/CausalLM.ipynb b/transformers_distillation/examples/CausalLM.ipynb
new file mode 100644
index 0000000..0932adf
--- /dev/null
+++ b/transformers_distillation/examples/CausalLM.ipynb
@@ -0,0 +1,470 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "1e6347fb",
+   "metadata": {},
+   "source": [
+    "# Knowledge Distillation with hf_distiller\n",
+    "This notebook demonstrates:\n",
+    "1. Loading a teacher model from Hugging Face Hub\n",
+    "2. Creating a smaller student model\n",
+    "3. Preparing a toy dataset\n",
+    "4. Training the student using knowledge distillation\n",
+    "5. Visualizing training loss and logits comparison\n",
+    "\n",
+    "You can replace the demo dataset with your own dataset for real training."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d5990fd5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Collecting git+https://github.com/Dhiraj309/transformers_distillation.git\n", + " Cloning https://github.com/Dhiraj309/transformers_distillation.git to c:\\users\\patil\\appdata\\local\\temp\\pip-req-build-_93suvg_\n", + " Resolved https://github.com/Dhiraj309/transformers_distillation.git to commit eec0c657b772f842c30878f7d36fcf69731e3f21\n", + " Installing build dependencies: started\n", + " Installing build dependencies: finished with status 'done'\n", + " Getting requirements to build wheel: started\n", + " Getting requirements to build wheel: finished with status 'done'\n", + " Preparing metadata (pyproject.toml): started\n", + " Preparing metadata (pyproject.toml): finished with status 'done'\n", + "Building wheels for collected packages: transformers_distiller\n", + " Building wheel for transformers_distiller (pyproject.toml): started\n", + " Building wheel for transformers_distiller (pyproject.toml): finished with status 'done'\n", + " Created wheel for transformers_distiller: filename=transformers_distiller-0.1.0-py3-none-any.whl size=11639 sha256=f4c101bf67e2ea8b6d103fd7e6ca93a3c90081ed55cec426a1711928811a6ea0\n", + " Stored in directory: C:\\Users\\patil\\AppData\\Local\\Temp\\pip-ephem-wheel-cache-c7cjlwlf\\wheels\\0d\\22\\7a\\7b6f72d21e3a6e4f60a7b03fda4acb5cfeeb146b6c0ea5c5e8\n", + "Successfully built transformers_distiller\n", + "Installing collected packages: transformers_distiller\n", + "Successfully installed transformers_distiller-0.1.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " Running command git clone --filter=blob:none --quiet https://github.com/Dhiraj309/transformers_distillation.git 'C:\\Users\\patil\\AppData\\Local\\Temp\\pip-req-build-_93suvg_'\n" + ] + } + ], + "source": [ + "# Step 0 β€” Install requirements (run only once)\n", + "!pip install --no-deps git+https://github.com/Dhiraj309/transformers_distillation.git" + ] + }, + { + "cell_type": "markdown", + "id": "78a19388", + "metadata": {}, + "source": [ + "## Step 1 β€” Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "41b4f8c7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\patil\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import sys\n", + "import os\n", + "from transformers import AutoTokenizer, TrainingArguments\n", + "from datasets import Dataset\n", + "from transformers_distillation.models import load_teacher, load_student\n", + "from transformers_distillation import DistillTrainer\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "3bc13526", + "metadata": {}, + "source": [ + "## Step 2 β€” Load Teacher Model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b4a7d596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Teacher model loaded: LlamaForCausalLM\n", + "Tokenizer vocab size: 49152\n" + ] + } + ], + "source": [ + "MODEL_NAME = 'HuggingFaceTB/SmolLM2-135M'\n", + "\n", + "# Load teacher and tokenizer\n", + "teacher = load_teacher(model_name_or_path=MODEL_NAME)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "print(\"Teacher model loaded:\", teacher.__class__.__name__)\n", + "print(\"Tokenizer vocab size:\", len(tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "014d882a", + "metadata": {}, + "source": [ + "## Step 3 β€” Create Student Model\n", + "A smaller architecture for faster inference and lower memory usage." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7704b013", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Student model created: LlamaForCausalLM\n" + ] + } + ], + "source": [ + "student = load_student(\n", + " model_name_or_path=MODEL_NAME,\n", + " from_scratch=True,\n", + " n_layers=4,\n", + " n_heads=4,\n", + " n_embd=256,\n", + " is_pretrained=False\n", + ")\n", + "print(\"Student model created:\", student.__class__.__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "d1d46085", + "metadata": {}, + "source": [ + "## Step 4 β€” Prepare Dataset\n", + "Small in-memory dataset for demonstration. Replace with your own data for real training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "456dd4dc", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:00<00:00, 92.21 examples/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized example: {'input_ids': [19556, 905, 17], 'attention_mask': [1, 1, 1]}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "texts = [\n", + " \"Hello world!\",\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Artificial intelligence is transforming industries.\",\n", + " \"Once upon a time, there was a curious developer.\",\n", + " \"PyTorch makes deep learning both fun and powerful.\"\n", + "]\n", + "dataset = Dataset.from_dict({\"text\": texts})\n", + "\n", + "def tokenize(batch):\n", + " return tokenizer(batch['text'], max_length=128, padding=True, truncation=True)\n", + "\n", + "tokenized_dataset = dataset.map(tokenize, remove_columns=['text'])\n", + "eval_dataset = tokenized_dataset.select(range(1))\n", + "print(\"Tokenized example:\", tokenized_dataset[0])" + ] + }, + { + "cell_type": "markdown", + "id": "c66608ee", + "metadata": {}, + "source": [ + "## Step 5 β€” Define Training Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "448b87d0", + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir='./student-llm',\n", + " per_device_train_batch_size=1,\n", + " num_train_epochs=3,\n", + " learning_rate=2e-4,\n", + " logging_steps=1,\n", + " save_steps=100,\n", + " save_total_limit=5,\n", + " report_to='none',\n", + " lr_scheduler_type='cosine',\n", + " warmup_steps=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c6eccffe", + "metadata": {}, + "source": [ + "## Step 6 β€” Initialize Distillation Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7dd3905f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\patil\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers_distillation\\trainer.py:38: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `DistillationTrainer.__init__`. Use `processing_class` instead.\n", + " super().__init__(\n" + ] + } + ], + "source": [ + "trainer = DistillTrainer(\n", + " teacher_model=teacher,\n", + " student_model=student,\n", + " train_dataset=tokenized_dataset,\n", + " tokenizer=tokenizer,\n", + " training_args=training_args,\n", + " kd_alpha=0.5,\n", + " temperature=2.0,\n", + " is_pretrained=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "eb725dc3", + "metadata": {}, + "source": [ + "## Step 7 β€” Train Student Model\n", + "The student learns from both teacher outputs and ground truth labels." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "77f47a10", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
[15/15 00:05, Epoch 3/3]

| Step | Training Loss |
|------|---------------|
| 1    | 27.726200     |
| 2    | 46.899300     |
| 3    | 73.354600     |
| 4    | 12.300200     |
| 5    | 49.185500     |
| 6    | 70.402400     |
| 7    | 47.146400     |
| 8    | 44.568800     |
| 9    | 25.926500     |
| 10   | 9.427700      |
| 11   | 23.904900     |
| 12   | 40.299700     |
| 13   | 59.467800     |
| 14   | 40.667700     |
| 15   | 7.163500      |
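The `kd_alpha` and `temperature` arguments passed to `DistillTrainer` in Step 6 control how the two parts of the distillation loss are blended. The trainer's source is not included in this diff, so the snippet below is only a minimal sketch of the standard formulation these parameters usually denote (temperature-softened KL against the teacher's logits mixed with the ordinary label loss), not the package's exact implementation:

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, kd_alpha=0.5, temperature=2.0):
    # Soft-target term: KL divergence between temperature-softened distributions,
    # scaled by T**2 so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard-target term: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)), labels.view(-1)
    )
    # kd_alpha weights the soft (teacher) term against the hard (label) term.
    return kd_alpha * soft + (1.0 - kd_alpha) * hard
```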

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Keep track of loss for visualization\n", + "trainer_state = trainer.train()\n", + "losses = trainer_state.training_loss if hasattr(trainer_state, 'training_loss') else []" + ] + }, + { + "cell_type": "markdown", + "id": "f229980c", + "metadata": {}, + "source": [ + "## Step 8 β€” Evaluate Student Model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "33ac356d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1/1 : < :]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation results: {'eval_runtime': 0.0305, 'eval_samples_per_second': 32.834, 'eval_steps_per_second': 32.834, 'epoch': 3.0}\n" + ] + } + ], + "source": [ + "results = trainer.evaluate(eval_dataset = eval_dataset)\n", + "print('Evaluation results:', results)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38ce6f56", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/transformers_distillation/examples/CausalLM.py b/transformers_distillation/examples/CausalLM.py new file mode 100644 index 0000000..a5277a0 --- /dev/null +++ b/transformers_distillation/examples/CausalLM.py @@ -0,0 +1,107 @@ +""" +Knowledge Distillation with hf_distiller (Python Script) + +This script demonstrates: +1. Loading a teacher model from Hugging Face Hub +2. Creating a smaller student model +3. Preparing a toy dataset +4. Training the student using knowledge distillation + +Run: + pip install -r requirements.txt + python distill_demo.py +""" + +import sys +import os +from transformers import AutoTokenizer, TrainingArguments +from datasets import Dataset +from transformers_distillation.models import load_teacher, load_student +from transformers_distillation import DistillTrainer + +# ------------------------------------------------------------------------- +# Step 1 β€” Ensure src/ is in Python path +# ------------------------------------------------------------------------- +# ------------------------------------------------------------------------- +# Step 2 β€” Select teacher model +# ------------------------------------------------------------------------- +MODEL_NAME = "HuggingFaceTB/SmolLM2-135M" + +# ------------------------------------------------------------------------- +# Step 3 β€” Load Teacher & Tokenizer +# ------------------------------------------------------------------------- +teacher = load_teacher(model_name_or_path=MODEL_NAME) +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +# ------------------------------------------------------------------------- +# Step 4 β€” Create Student model (smaller) +# ------------------------------------------------------------------------- +student = load_student( + model_name_or_path=MODEL_NAME, + from_scratch=True, + n_layers=4, + n_heads=4, + n_embd=256, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 5 β€” Prepare Dataset +# ------------------------------------------------------------------------- +texts = [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming industries.", + "Once upon a time, there was a curious developer.", + "PyTorch makes deep learning both fun and powerful." 
+] +dataset = Dataset.from_dict({"text": texts}) + +def tokenize(batch): + return tokenizer(batch["text"], max_length=128, padding=True, truncation=True) + +tokenized_dataset = dataset.map(tokenize, remove_columns=["text"]) +eval_dataset = tokenized_dataset.select(range(1)) + +# ------------------------------------------------------------------------- +# Step 6 β€” Training Arguments +# ------------------------------------------------------------------------- +training_args = TrainingArguments( + output_dir="./student-llm", + per_device_train_batch_size=1, + num_train_epochs=1, + learning_rate=2e-4, + logging_steps=10, + save_steps=100, + save_total_limit=5, + report_to="none", + lr_scheduler_type="cosine", + warmup_steps=500, +) + +# ------------------------------------------------------------------------- +# Step 7 β€” Initialize Distillation Trainer +# ------------------------------------------------------------------------- +trainer = DistillTrainer( + teacher_model=teacher, + student_model=student, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + training_args=training_args, + kd_alpha=0.5, + temperature=2.0, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 8 β€” Train +# ------------------------------------------------------------------------- +trainer.train() + +# ------------------------------------------------------------------------- +# Optional: Evaluate Student (Requires Eval Dataset) +# ------------------------------------------------------------------------- +results = trainer.evaluate(eval_dataset = eval_dataset) +print("Evaluation results:", results) \ No newline at end of file diff --git a/transformers_distillation/examples/MLM.ipynb b/transformers_distillation/examples/MLM.ipynb new file mode 100644 index 0000000..8bb9ce5 --- /dev/null +++ b/transformers_distillation/examples/MLM.ipynb @@ -0,0 +1,439 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fb561b16", + "metadata": {}, + "source": [ + "# Knowledge Distillation with hf_distiller\n", + "This notebook demonstrates:\n", + "1. Loading a teacher model from Hugging Face Hub\n", + "2. Creating a smaller student model\n", + "3. Preparing a toy dataset\n", + "4. Training the student using knowledge distillation\n", + "5. Visualizing training loss and logits comparison\n", + "\n", + "You can replace the demo dataset with your own dataset for real training." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6d48c507", + "metadata": {}, + "outputs": [], + "source": [ + "# Step 0 β€” Install requirements (run only once)\n", + "# !pip install --no-deps git+https://github.com/Dhiraj309/transformers_distillation.git" + ] + }, + { + "cell_type": "markdown", + "id": "83171e73", + "metadata": {}, + "source": [ + "## Step 1 β€” Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "32d8b9fa", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\patil\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import sys\n", + "import os\n", + "from transformers import AutoTokenizer, TrainingArguments\n", + "from datasets import Dataset\n", + "from transformers_distillation.models import load_teacher, load_student\n", + "from transformers_distillation import DistillTrainer\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "e835a210", + "metadata": {}, + "source": [ + "## Step 2 β€” Load Teacher Model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b95a85a6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Teacher model loaded: BertForMaskedLM\n", + "Tokenizer vocab size: 30522\n" + ] + } + ], + "source": [ + "MODEL_NAME = 'google-bert/bert-base-uncased'\n", + "\n", + "# Load teacher and tokenizer\n", + "teacher = load_teacher(model_name_or_path=MODEL_NAME)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "print(\"Teacher model loaded:\", teacher.__class__.__name__)\n", + "print(\"Tokenizer vocab size:\", len(tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "7527041b", + "metadata": {}, + "source": [ + "## Step 3 β€” Create Student Model\n", + "A smaller architecture for faster inference and lower memory usage." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9bbc0e43", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Student model created: BertForMaskedLM\n" + ] + } + ], + "source": [ + "student = load_student(\n", + " model_name_or_path=MODEL_NAME,\n", + " from_scratch=True,\n", + " n_layers=4,\n", + " n_heads=4,\n", + " n_embd=256,\n", + " is_pretrained=False\n", + ")\n", + "print(\"Student model created:\", student.__class__.__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "ef99b6f9", + "metadata": {}, + "source": [ + "## Step 4 β€” Prepare Dataset\n", + "Small in-memory dataset for demonstration. Replace with your own data for real training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5e10a9e6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:00<00:00, 104.16 examples/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized example: {'input_ids': [101, 7592, 2088, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "texts = [\n", + " \"Hello world!\",\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"Artificial intelligence is transforming industries.\",\n", + " \"Once upon a time, there was a curious developer.\",\n", + " \"PyTorch makes deep learning both fun and powerful.\"\n", + "]\n", + "dataset = Dataset.from_dict({\"text\": texts})\n", + "\n", + "def tokenize(batch):\n", + " return tokenizer(batch['text'], max_length=128, padding=True, truncation=True)\n", + "\n", + "tokenized_dataset = dataset.map(tokenize, remove_columns=['text'])\n", + "eval_dataset = tokenized_dataset.select(range(1))\n", + "print(\"Tokenized example:\", tokenized_dataset[0])" + ] + }, + { + "cell_type": "markdown", + "id": "3e86cd8e", + "metadata": {}, + "source": [ + "## Step 5 β€” Define Training Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9f1a0060", + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir='./student-llm',\n", + " per_device_train_batch_size=1,\n", + " num_train_epochs=3,\n", + " learning_rate=2e-4,\n", + " logging_steps=1,\n", + " save_steps=100,\n", + " save_total_limit=5,\n", + " report_to='none',\n", + " lr_scheduler_type='cosine',\n", + " warmup_steps=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "824de177", + "metadata": {}, + "source": [ + "## Step 6 β€” Initialize Distillation Trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a7793974", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\patil\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers_distillation\\trainer.py:38: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `DistillationTrainer.__init__`. Use `processing_class` instead.\n", + " super().__init__(\n" + ] + } + ], + "source": [ + "trainer = DistillTrainer(\n", + " teacher_model=teacher,\n", + " student_model=student,\n", + " train_dataset=tokenized_dataset,\n", + " tokenizer=tokenizer,\n", + " training_args=training_args,\n", + " kd_alpha=0.5,\n", + " temperature=2.0,\n", + " is_pretrained=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4da02536", + "metadata": {}, + "source": [ + "## Step 7 β€” Train Student Model\n", + "The student learns from both teacher outputs and ground truth labels." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "91af179d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
[15/15 00:04, Epoch 3/3]

| Step | Training Loss |
|------|---------------|
| 1    | 82.397400     |
| 2    | 147.293900    |
| 3    | 134.957400    |
| 4    | 48.218500     |
| 5    | 127.869000    |
| 6    | 118.825800    |
| 7    | 116.639000    |
| 8    | 128.206400    |
| 9    | 62.497800     |
| 10   | 39.070000     |
| 11   | 56.712400     |
| 12   | 115.762400    |
| 13   | 103.406100    |
| 14   | 98.360000     |
| 15   | 32.113400     |
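Note that this masked-LM demo only tokenizes raw text and never creates masked `labels`, so the numbers above are purely illustrative. On real data the usual approach is Hugging Face's `DataCollatorForLanguageModeling`; whether `DistillTrainer` forwards a `data_collator` to the underlying `Trainer` is not visible in this diff, so the wiring shown here is an assumption rather than documented behaviour:

```python
from transformers import DataCollatorForLanguageModeling

# Randomly masks 15% of input tokens per batch and writes the original ids into
# `labels`; unmasked positions receive -100 and are ignored by the loss.
mlm_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Hypothetical wiring -- assumes DistillTrainer accepts the same `data_collator`
# keyword as transformers.Trainer:
# trainer = DistillTrainer(..., data_collator=mlm_collator)
```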

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Keep track of loss for visualization\n", + "trainer_state = trainer.train()\n", + "losses = trainer_state.training_loss if hasattr(trainer_state, 'training_loss') else []" + ] + }, + { + "cell_type": "markdown", + "id": "44059de0", + "metadata": {}, + "source": [ + "## Step 8 β€” Evaluate Student Model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "54c6f5b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1/1 : < :]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation results: {'eval_runtime': 0.0184, 'eval_samples_per_second': 54.458, 'eval_steps_per_second': 54.458, 'epoch': 3.0}\n" + ] + } + ], + "source": [ + "results = trainer.evaluate(eval_dataset = eval_dataset)\n", + "print('Evaluation results:', results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/transformers_distillation/examples/MLM.py b/transformers_distillation/examples/MLM.py new file mode 100644 index 0000000..8e8f982 --- /dev/null +++ b/transformers_distillation/examples/MLM.py @@ -0,0 +1,113 @@ +""" +Knowledge Distillation with hf_distiller (Python Script) + +This script demonstrates: +1. Loading a teacher model from Hugging Face Hub +2. Creating a smaller student model +3. Preparing a toy dataset +4. Training the student using knowledge distillation + +Run: + pip install -r requirements.txt + python distill_demo.py +""" + +import sys +import os +from transformers import AutoTokenizer, TrainingArguments +from datasets import Dataset +from transformers_distillation.models import load_teacher, load_student +from transformers_distillation import DistillTrainer + +# ------------------------------------------------------------------------- +# Step 1 β€” Ensure src/ is in Python path +# ------------------------------------------------------------------------- +# ------------------------------------------------------------------------- +# Step 2 β€” Select teacher model +# ------------------------------------------------------------------------- +MODEL_NAME = "google-bert/bert-base-uncased" + +# ------------------------------------------------------------------------- +# Step 3 β€” Load Teacher & Tokenizer +# ------------------------------------------------------------------------- +teacher = load_teacher(model_name_or_path=MODEL_NAME) +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +# ------------------------------------------------------------------------- +# Step 4 β€” Create Student model (smaller) +# ------------------------------------------------------------------------- +student = load_student( + model_name_or_path=MODEL_NAME, + from_scratch=True, + n_layers=4, + n_heads=4, + n_embd=256, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 5 β€” Prepare Dataset +# ------------------------------------------------------------------------- +texts = [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming industries.", + "Once upon a time, there was a curious developer.", + "PyTorch makes deep learning both fun and powerful." 
+] +dataset = Dataset.from_dict({"text": texts}) + +def tokenize(batch): + return tokenizer(batch["text"], max_length=128, padding=True, truncation=True) + +tokenized_dataset = dataset.map(tokenize, remove_columns=["text"]) + +# ------------------------------------------------------------------------- +# Step 6 β€” Training Arguments +# ------------------------------------------------------------------------- +training_args = TrainingArguments( + output_dir="./student-llm", + per_device_train_batch_size=1, + num_train_epochs=1, + learning_rate=2e-4, + logging_steps=10, + save_steps=100, + save_total_limit=5, + report_to="none", + lr_scheduler_type="cosine", + warmup_steps=500, +) + +# ------------------------------------------------------------------------- +# Step 7 β€” Initialize Distillation Trainer +# ------------------------------------------------------------------------- +trainer = DistillTrainer( + teacher_model=teacher, + student_model=student, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + training_args=training_args, + kd_alpha=0.5, + temperature=2.0, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 8 β€” Train +# ------------------------------------------------------------------------- +trainer.train() + +# ------------------------------------------------------------------------- +<<<<<<< HEAD +# Optional: Evaluate Student (Requires Eval Dataset) +# ------------------------------------------------------------------------- +# results = trainer.evaluate() +# print("Evaluation results:", results) +======= +# Optional: Evaluate Student +# ------------------------------------------------------------------------- +results = trainer.evaluate() +print("Evaluation results:", results) +>>>>>>> origin/main diff --git a/transformers_distillation/examples/Seq2SeqLM.ipynb b/transformers_distillation/examples/Seq2SeqLM.ipynb new file mode 100644 index 0000000..a4a227f --- /dev/null +++ b/transformers_distillation/examples/Seq2SeqLM.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "68675e3b", + "metadata": {}, + "source": [ + "# Knowledge Distillation with hf_distiller\n", + "This notebook demonstrates:\n", + "1. Loading a teacher model from Hugging Face Hub\n", + "2. Creating a smaller student model\n", + "3. Preparing a toy dataset\n", + "4. Training the student using knowledge distillation\n", + "5. Visualizing training loss and logits comparison\n", + "\n", + "You can replace the demo dataset with your own dataset for real training." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "980f2725", + "metadata": {}, + "outputs": [], + "source": [ + "# Step 0 β€” Install requirements (run only once)\n", + "# !pip install --no-deps git+https://github.com/Dhiraj309/transformers_distillation.git" + ] + }, + { + "cell_type": "markdown", + "id": "e22c2484", + "metadata": {}, + "source": [ + "## Step 1 β€” Imports and Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "acab9153", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\patil\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import sys\n", + "import os\n", + "from transformers import AutoTokenizer, TrainingArguments\n", + "from datasets import Dataset\n", + "from transformers_distillation.models import load_teacher, load_student\n", + "from transformers_distillation import DistillTrainer\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "b4223dcc", + "metadata": {}, + "source": [ + "## Step 2 β€” Load Teacher Model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2d061642", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Teacher model loaded: T5ForConditionalGeneration\n", + "Tokenizer vocab size: 32100\n" + ] + } + ], + "source": [ + "MODEL_NAME = 'google/flan-t5-small'\n", + "\n", + "# Load teacher and tokenizer\n", + "teacher = load_teacher(model_name_or_path=MODEL_NAME)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "print(\"Teacher model loaded:\", teacher.__class__.__name__)\n", + "print(\"Tokenizer vocab size:\", len(tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "3967d783", + "metadata": {}, + "source": [ + "## Step 3 β€” Create Student Model\n", + "A smaller architecture for faster inference and lower memory usage." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "84e2af04", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Student model created: T5ForConditionalGeneration\n" + ] + } + ], + "source": [ + "student = load_student(\n", + " model_name_or_path=MODEL_NAME,\n", + " from_scratch=True,\n", + " n_layers=4,\n", + " n_heads=4,\n", + " n_embd=256,\n", + " is_pretrained=False\n", + ")\n", + "print(\"Student model created:\", student.__class__.__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "2f7a02c4", + "metadata": {}, + "source": [ + "## Step 4 β€” Prepare Dataset\n", + "Small in-memory dataset for demonstration. Replace with your own data for real training." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "60678685", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map: 0%| | 0/5 [00:00\n", + " \n", + " \n", + " [15/15 00:12, Epoch 3/3]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| Step | Training Loss |
|------|---------------|
| 1    | 10405.671900  |
| 2    | 10288.678700  |
| 3    | 10422.496100  |
| 4    | 10445.543000  |
| 5    | 10352.596700  |
| 6    | 10209.575200  |
| 7    | 10074.432600  |
| 8    | 10033.449200  |
| 9    | 10038.668900  |
| 10   | 10291.808600  |
| 11   | 9841.254900   |
| 12   | 10065.089800  |
| 13   | 9404.201200   |
| 14   | 9463.793000   |
| 15   | 9407.624000   |
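This toy run feeds a T5-style encoder-decoder plain text with no target sequences, so the loss values above mainly show that the loop runs end to end. For real sequence-to-sequence distillation the targets are normally tokenized into `labels` and padded with `DataCollatorForSeq2Seq`; in the sketch below the `summary` column and the `data_collator` hand-off to `DistillTrainer` are illustrative assumptions, not part of this package's documented API:

```python
from transformers import DataCollatorForSeq2Seq

def tokenize_pairs(batch):
    # text_target= tokenizes the target side with the correct special tokens
    # and stores the result under `labels` for the seq2seq loss.
    return tokenizer(
        batch["text"],
        text_target=batch["summary"],  # hypothetical target column
        max_length=128,
        truncation=True,
    )

# Pads inputs and labels per batch; label padding uses -100 so it is ignored.
collator = DataCollatorForSeq2Seq(tokenizer, model=student)

# Assumed wiring, mirroring transformers.Trainer:
# trainer = DistillTrainer(..., data_collator=collator)
```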

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Keep track of loss for visualization\n", + "trainer_state = trainer.train()\n", + "losses = trainer_state.training_loss if hasattr(trainer_state, 'training_loss') else []" + ] + }, + { + "cell_type": "markdown", + "id": "3320dc18", + "metadata": {}, + "source": [ + "## Step 8 β€” Evaluate Student Model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "db37670f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1/1 : < :]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluation results: {'eval_loss': 8294.9560546875, 'eval_runtime': 0.637, 'eval_samples_per_second': 1.57, 'eval_steps_per_second': 1.57, 'epoch': 3.0}\n" + ] + } + ], + "source": [ + "results = trainer.evaluate(eval_dataset = eval_dataset)\n", + "print('Evaluation results:', results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/transformers_distillation/examples/Seq2SeqLM.py b/transformers_distillation/examples/Seq2SeqLM.py new file mode 100644 index 0000000..aee4a8b --- /dev/null +++ b/transformers_distillation/examples/Seq2SeqLM.py @@ -0,0 +1,113 @@ +""" +Knowledge Distillation with hf_distiller (Python Script) + +This script demonstrates: +1. Loading a teacher model from Hugging Face Hub +2. Creating a smaller student model +3. Preparing a toy dataset +4. Training the student using knowledge distillation + +Run: + pip install -r requirements.txt + python distill_demo.py +""" + +import sys +import os +from transformers import AutoTokenizer, TrainingArguments +from datasets import Dataset +from transformers_distillation.models import load_teacher, load_student +from transformers_distillation import DistillTrainer + +# ------------------------------------------------------------------------- +# Step 1 β€” Ensure src/ is in Python path +# ------------------------------------------------------------------------- +# ------------------------------------------------------------------------- +# Step 2 β€” Select teacher model +# ------------------------------------------------------------------------- +MODEL_NAME = "google/flan-t5-small" + +# ------------------------------------------------------------------------- +# Step 3 β€” Load Teacher & Tokenizer +# ------------------------------------------------------------------------- +teacher = load_teacher(model_name_or_path=MODEL_NAME) +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +# ------------------------------------------------------------------------- +# Step 4 β€” Create Student model (smaller) +# ------------------------------------------------------------------------- +student = load_student( + model_name_or_path=MODEL_NAME, + from_scratch=True, + n_layers=4, + n_heads=4, + n_embd=256, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 5 β€” Prepare Dataset +# ------------------------------------------------------------------------- +texts = [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming industries.", + "Once upon a time, there was a curious developer.", + "PyTorch makes deep learning both fun and powerful." 
+] +dataset = Dataset.from_dict({"text": texts}) + +def tokenize(batch): + return tokenizer(batch["text"], max_length=128, padding=True, truncation=True) + +tokenized_dataset = dataset.map(tokenize, remove_columns=["text"]) + +# ------------------------------------------------------------------------- +# Step 6 β€” Training Arguments +# ------------------------------------------------------------------------- +training_args = TrainingArguments( + output_dir="./student-llm", + per_device_train_batch_size=1, + num_train_epochs=1, + learning_rate=2e-4, + logging_steps=10, + save_steps=100, + save_total_limit=5, + report_to="none", + lr_scheduler_type="cosine", + warmup_steps=500, +) + +# ------------------------------------------------------------------------- +# Step 7 β€” Initialize Distillation Trainer +# ------------------------------------------------------------------------- +trainer = DistillTrainer( + teacher_model=teacher, + student_model=student, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + training_args=training_args, + kd_alpha=0.5, + temperature=2.0, + is_pretrained=False +) + +# ------------------------------------------------------------------------- +# Step 8 β€” Train +# ------------------------------------------------------------------------- +trainer.train() + +# ------------------------------------------------------------------------- +<<<<<<< HEAD +# Optional: Evaluate Student (Requires Eval Dataset) +# ------------------------------------------------------------------------- +# results = trainer.evaluate() +# print("Evaluation results:", results) +======= +# Optional: Evaluate Student +# ------------------------------------------------------------------------- +results = trainer.evaluate() +print("Evaluation results:", results) +>>>>>>> origin/main diff --git a/transformers_distillation/pyproject.toml b/transformers_distillation/pyproject.toml new file mode 100644 index 0000000..f41fbd8 --- /dev/null +++ b/transformers_distillation/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "transformers_distiller" +version = "0.1.0" +description = "A Hugging Face model distillation trainer" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "Apache-2.0"} +authors = [ + {name = "Dhiraj Patil", email = "patildhiraj1197@gmail.com"} +] +dependencies = [ + "torch>=2.0.0", + "transformers==4.55.2", + "datasets==4.0.0", + "accelerate==1.10.0", + "bitsandbytes==0.47.0", + "huggingface-hub==0.34.4", + "safetensors==0.6.2", + "numpy>=2.1.2", + "pandas>=2.3.1", + "tqdm>=4.67.1" +] + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] + +where = ["src"] diff --git a/transformers_distillation/requirements.txt b/transformers_distillation/requirements.txt new file mode 100644 index 0000000..7037d8e --- /dev/null +++ b/transformers_distillation/requirements.txt @@ -0,0 +1,10 @@ +torch>=2.0.0 +transformers==4.55.2 +datasets==4.0.0 +accelerate==1.10.0 +bitsandbytes==0.47.0 +huggingface-hub==0.34.4 +safetensors==0.6.2 +numpy>=2.1.2 +pandas>=2.3.1 +tqdm>=4.67.1 \ No newline at end of file diff --git a/transformers_distillation/setup.py b/transformers_distillation/setup.py new file mode 100644 index 0000000..686f40d --- /dev/null +++ b/transformers_distillation/setup.py @@ -0,0 +1,35 @@ +from setuptools import setup, find_packages + +setup( + name="transformers_distiller", + version="0.1.0", + description="A Hugging Face model distillation 
trainer", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + author="Dhiraj Patil", + author_email="patildhiraj1197@gmail.com", + python_requires=">=3.9", + license="Apache-2.0", + packages=find_packages(where="src"), + package_dir={"": "src"}, + install_requires=[ + "torch>=2.0.0", + "transformers==4.55.2", + "datasets==4.0.0", + "accelerate==1.10.0", + "bitsandbytes==0.47.0", + "huggingface-hub==0.34.4", + "safetensors==0.6.2", + "numpy>=2.1.2", + "pandas>=2.3.1", + "tqdm>=4.67.1" + ], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + include_package_data=True, + zip_safe=False, +) \ No newline at end of file diff --git a/transformers_distillation/src/transformers_distillation/__init__.py b/transformers_distillation/src/transformers_distillation/__init__.py new file mode 100644 index 0000000..5e35175 --- /dev/null +++ b/transformers_distillation/src/transformers_distillation/__init__.py @@ -0,0 +1,11 @@ +from .models import load_teacher, load_student +from .trainer import DistillationTrainer, DistillTrainer +from .utils import detect_task_type, TaskType + +__all__ = [ + "load_teacher", + "load_student", + "DistillationTrainer", + "detect_task_type", + "TaskType" +] diff --git a/transformers_distillation/src/transformers_distillation/configs.py b/transformers_distillation/src/transformers_distillation/configs.py new file mode 100644 index 0000000..fa5b967 --- /dev/null +++ b/transformers_distillation/src/transformers_distillation/configs.py @@ -0,0 +1,37 @@ +from typing import Optional +import torch + +try: + from transformers import BitsAndBytesConfig +except Exception: + BitsAndBytesConfig = None + + +def no_quant(): + return None + +def quant_8(): + if BitsAndBytesConfig is None: + raise ImportError("BitsAndBytes Not Available. Install 'BitsAndBytes' To Use 8-bit Quantization") + + return BitsAndBytesConfig(load_in_8bit = True) + +# def quant_16(): +# return BitsAndBytesConfig(load_in_16bit = True) + +def quant_4(): + if BitsAndBytesConfig is None: + raise ImportError("BitsAndBytes Not Available. Install 'BitsAndBytes' To Use 4-bit Quantization") + + return BitsAndBytesConfig( + load_in_4bit = True, + bnb_4bit_quant_type = "nf4", + bnb_4bit_use_double_quant = True, + bnb_4bit_compute_dtype = torch.bfloat16 + ) + +def custom_quant(**kwargs): + if BitsAndBytesConfig is None: + raise ImportError("BitsAndBytes Not Available. 
Install 'BitsAndBytes' To Use Custom Quantization")
+
+    return BitsAndBytesConfig(**kwargs)
\ No newline at end of file
diff --git a/transformers_distillation/src/transformers_distillation/models.py b/transformers_distillation/src/transformers_distillation/models.py
new file mode 100644
index 0000000..6cc1e70
--- /dev/null
+++ b/transformers_distillation/src/transformers_distillation/models.py
@@ -0,0 +1,104 @@
+from typing import Optional
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForMaskedLM,
+    AutoConfig,
+    AutoTokenizer,
+    BitsAndBytesConfig
+)
+from .utils import detect_task_type, TaskType
+
+
+def _freeze_eval(model: torch.nn.Module) -> torch.nn.Module:
+    model.eval()
+    for param in model.parameters():
+        param.requires_grad = False
+
+    return model
+
+def load_teacher(model_name_or_path: str, quant_config: Optional[object] = None, device_map: str = "auto"):
+
+    # Detect the LM task type automatically
+    task = detect_task_type(model_name_or_path)
+    common_kwargs = {}
+    if quant_config is not None:
+        common_kwargs["quantization_config"] = quant_config
+        common_kwargs["device_map"] = device_map
+
+    # Causal LM
+    if task == TaskType.CAUSAL_LM:
+        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **common_kwargs)
+
+    # Seq2Seq LM
+    elif task == TaskType.SEQ2SEQ_LM:
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **common_kwargs)
+
+    # Masked LM
+    elif task == TaskType.MLM:
+        model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **common_kwargs)
+
+    # Fallback to causal LM
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **common_kwargs)
+
+    # Freeze the teacher so it is never updated during distillation
+    return _freeze_eval(model)
+
+def load_student(
+    model_name_or_path: str,
+    is_pretrained: bool = False,
+    n_layers: Optional[int] = None,
+    n_heads: Optional[int] = None,
+    num_key_value_heads: Optional[int] = None,  # If None, will match n_heads
+    n_embd: Optional[int] = None,
+    from_scratch: bool = True,
+    explicit_task: Optional[TaskType] = None
+):
+    # Use the explicit task if given, otherwise detect it
+    task = explicit_task or detect_task_type(model_name_or_path)
+
+    if from_scratch:
+        cfg = AutoConfig.from_pretrained(model_name_or_path)
+
+        # NUM LAYERS
+        if hasattr(cfg, "n_layers") and n_layers is not None:
+            cfg.n_layers = n_layers
+        if hasattr(cfg, "num_hidden_layers") and n_layers is not None:
+            cfg.num_hidden_layers = n_layers
+
+        # NUM HEADS
+        if hasattr(cfg, "n_heads") and n_heads is not None:
+            cfg.n_heads = n_heads
+        if hasattr(cfg, "num_attention_heads") and n_heads is not None:
+            cfg.num_attention_heads = n_heads
+
+        # FIX: Ensure num_key_value_heads matches attention heads if not explicitly set
+        if hasattr(cfg, "num_key_value_heads"):
+            cfg.num_key_value_heads = (
+                num_key_value_heads if num_key_value_heads is not None
+                else getattr(cfg, "num_attention_heads", n_heads or cfg.num_key_value_heads)
+            )
+
+        # HIDDEN SIZE
+        if hasattr(cfg, "n_embd") and n_embd is not None:
+            cfg.n_embd = n_embd
+        if hasattr(cfg, "hidden_dim") and n_embd is not None:
+            cfg.hidden_dim = n_embd
+
+        if task == TaskType.CAUSAL_LM:
+            return AutoModelForCausalLM.from_config(cfg)
+        if task == TaskType.SEQ2SEQ_LM:
+            return AutoModelForSeq2SeqLM.from_config(cfg)
+        if task == TaskType.MLM:
+            return AutoModelForMaskedLM.from_config(cfg)
+        return AutoModelForCausalLM.from_config(cfg)
+
+    else:
+        if task == TaskType.CAUSAL_LM:
+            return AutoModelForCausalLM.from_pretrained(model_name_or_path)
+        if task == TaskType.SEQ2SEQ_LM:
+            return 
AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + if task == TaskType.MLM: + return AutoModelForMaskedLM.from_pretrained(model_name_or_path) \ No newline at end of file diff --git a/transformers_distillation/src/transformers_distillation/trainer.py b/transformers_distillation/src/transformers_distillation/trainer.py new file mode 100644 index 0000000..7aeba67 --- /dev/null +++ b/transformers_distillation/src/transformers_distillation/trainer.py @@ -0,0 +1,137 @@ +from typing import Optional, Dict, Any +import torch +import torch.nn.functional as F +from transformers import Trainer, TrainingArguments +from .utils import TaskType, detect_task_type + +try: + from transformers.integrations.accelerate import AcceleratorConfig +except ImportError: + AcceleratorConfig = None # Older Transformers versions won't have this + + +class DistillationTrainer(Trainer): + def __init__( + self, + model, + args: TrainingArguments, + train_dataset=None, + eval_dataset=None, + tokenizer=None, + teacher_model=None, + is_pretrained=False, + kd_alpha=0.5, + temperature=2.0, + **kwargs + ): + super().__init__( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + **kwargs + ) + + self.teacher_model = teacher_model + self.kd_alpha = kd_alpha + self.temperature = temperature + + # Detect task type + self.task_type = detect_task_type(model.name_or_path if is_pretrained else model) + + # Setup teacher model + if self.teacher_model is not None: + self.teacher_model.to(self.model.device) + self.teacher_model.eval() + for param in self.teacher_model.parameters(): + param.requires_grad = False + + def shift_tokens_right(self, input_ids, pad_token_id, decoder_start_token_id): + shifted = input_ids.new_zeros(input_ids.shape) + shifted[:, 1:] = input_ids[:, :-1].clone() + shifted[:, 0] = decoder_start_token_id + shifted.masked_fill_(shifted == -100, pad_token_id) + return shifted + + def prepare_labels(self, inputs): + """ + Prepare labels depending on task type. + Ensures causal LM and seq2seq LM have properly shifted labels. 
+ """ + if "labels" not in inputs: + inputs["labels"] = inputs["input_ids"].clone() + + if self.task_type == TaskType.CAUSAL_LM: + # Optionally shift labels for causal LM if model requires it + if getattr(self.model.config, "use_cache", False): + inputs["labels"] = inputs["labels"].clone() + return inputs["labels"] + + elif self.task_type == TaskType.MLM: + # Labels for MLM should already have -100 for masked tokens + return inputs["labels"] + + elif self.task_type == TaskType.SEQ2SEQ_LM: + if "decoder_input_ids" not in inputs: + inputs["decoder_input_ids"] = self.shift_tokens_right( + inputs["labels"], + self.model.config.pad_token_id, + self.model.config.decoder_start_token_id + ) + return inputs["labels"] + + else: + # Fallback + return inputs["labels"] + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + labels = self.prepare_labels(inputs) + student_outputs = model(**inputs) + student_logits = student_outputs.logits + + loss_fct = torch.nn.CrossEntropyLoss() + lm_loss = loss_fct( + student_logits.view(-1, student_logits.size(-1)), + labels.view(-1) + ) + + # Knowledge Distillation loss + if self.teacher_model is not None and model.training: + with torch.no_grad(): + teacher_outputs = self.teacher_model(**inputs) + teacher_logits = teacher_outputs.logits + + kd_loss = F.kl_div( + input=F.log_softmax(student_logits / self.temperature, dim=-1), + target=F.softmax(teacher_logits / self.temperature, dim=-1), + reduction="batchmean" + ) * (self.temperature ** 2) + + loss = self.kd_alpha * kd_loss + (1.0 - self.kd_alpha) * lm_loss + else: + loss = lm_loss + + return (loss, student_outputs) if return_outputs else loss + + +def DistillTrainer( + teacher_model, + student_model, + train_dataset, + tokenizer, + training_args: TrainingArguments, + is_pretrained=False, + kd_alpha=0.5, + temperature=2.0 +): + trainer = DistillationTrainer( + model=student_model, + teacher_model=teacher_model, + args=training_args, + train_dataset=train_dataset, + tokenizer=tokenizer, + kd_alpha=kd_alpha, + temperature=temperature + ) + return trainer diff --git a/transformers_distillation/src/transformers_distillation/utils.py b/transformers_distillation/src/transformers_distillation/utils.py new file mode 100644 index 0000000..db7e476 --- /dev/null +++ b/transformers_distillation/src/transformers_distillation/utils.py @@ -0,0 +1,28 @@ +from enum import Enum +from transformers import AutoConfig, PreTrainedModel + +class TaskType(str, Enum): + CAUSAL_LM = "causal_lm" + SEQ2SEQ_LM = "seq2seq_lm" + MLM = "mlm" + +def detect_task_type(model_or_path) -> TaskType: + # If it's already a model, use its config + if isinstance(model_or_path, PreTrainedModel): + cfg = model_or_path.config + else: + cfg = AutoConfig.from_pretrained(model_or_path) + + archs = (cfg.architectures or []) + model_type = getattr(cfg, "model_type", "").lower() + + if any("ForCausalLM" in a for a in archs) or model_type in {"gpt2", "llama", "mistral", "gpt_neo", "phi"}: + return TaskType.CAUSAL_LM + + if any("ForConditionalGeneration" in a for a in archs) or model_type in {"t5", "flan-t5", "ul", "mt5", "mbart"}: + return TaskType.SEQ2SEQ_LM + + if any("ForMaskedLM" in a for a in archs) or model_type in {"bert", "roberta", "albert", "electra"}: + return TaskType.MLM + + return TaskType.CAUSAL_LM diff --git a/transformers_distillation/tests/test_MLM.py b/transformers_distillation/tests/test_MLM.py new file mode 100644 index 0000000..28c55c1 --- /dev/null +++ b/transformers_distillation/tests/test_MLM.py @@ -0,0 +1,73 @@ +import 
sys +import os +import pytest +from transformers import AutoTokenizer, TrainingArguments +from datasets import Dataset +from transformers_distillation.models import load_teacher, load_student +from transformers_distillation import DistillTrainer + +# MODEL_NAME = "HuggingFaceTB/SmolLM2-135M" +model_names =[ + "google-bert/bert-base-uncased" +] + +@pytest.mark.parametrize("model_name", model_names) +def test_distillation_runs(model_name): + print(F"\nThe {model_name} Is Currently Being Tested") + teacher = load_teacher(model_name_or_path=model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + student = load_student( + model_name_or_path=model_name, + from_scratch=True, + n_layers=4, + n_heads=4, + n_embd=256, + is_pretrained=False + ) + + texts = [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + "Artificial intelligence is transforming industries.", + "Once upon a time, there was a curious developer.", + "PyTorch makes deep learning both fun and powerful." + ] + dataset = Dataset.from_dict({"text": texts}) + + def tokenize(batch): + return tokenizer(batch["text"], max_length=128, padding=True, truncation=True) + + tokenized_dataset = dataset.map(tokenize, remove_columns=["text"]) + eval_dataset = tokenized_dataset.select(range(1)) + + training_args = TrainingArguments( + output_dir="./student-llm", + per_device_train_batch_size=1, + num_train_epochs=1, + learning_rate=2e-4, + logging_steps=10, + save_steps=100, + save_total_limit=5, + report_to="none", + lr_scheduler_type="cosine", + warmup_steps=500, + ) + + trainer = DistillTrainer( + teacher_model=teacher, + student_model=student, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + training_args=training_args, + kd_alpha=0.5, + temperature=2.0, + is_pretrained=False + ) + + trainer.train() + + results = trainer.evaluate(eval_dataset = eval_dataset) + print("Evaluation results:", results) diff --git a/transformers_distillation/tests/test_Seq2SeqLM.py b/transformers_distillation/tests/test_Seq2SeqLM.py new file mode 100644 index 0000000..d4bd5ef --- /dev/null +++ b/transformers_distillation/tests/test_Seq2SeqLM.py @@ -0,0 +1,102 @@ +import sys +import os +import pytest +from transformers import AutoTokenizer, TrainingArguments +from datasets import Dataset +from transformers_distillation.models import load_teacher, load_student +from transformers_distillation import DistillTrainer + +# MODEL_NAME = "HuggingFaceTB/SmolLM2-135M" +model_names =[ + "google/flan-t5-small", + "google-t5/t5-small" +] + +@pytest.mark.parametrize("model_name", model_names) +def test_distillation_runs(model_name): + print(F"\nThe {model_name} Is Currently Being Tested") + teacher = load_teacher(model_name_or_path=model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + student = load_student( + model_name_or_path=model_name, + from_scratch=True, + n_layers=4, + n_heads=4, + n_embd=256, + is_pretrained=False + ) + + sources = [ + "Translate English to French: Hello world!", + "Translate English to French: The quick brown fox jumps over the lazy dog.", + "Translate English to French: Artificial intelligence is transforming industries.", + "Translate English to French: Once upon a time, there was a curious developer.", + "Translate English to French: PyTorch makes deep learning both fun and powerful." 
+ ] + + targets = [ + "Bonjour le monde!", + "Le renard brun rapide saute par-dessus le chien paresseux.", + "L'intelligence artificielle transforme les industries.", + "Il Γ©tait une fois un dΓ©veloppeur curieux.", + "PyTorch rend l'apprentissage profond Γ  la fois amusant et puissant." + ] + + dataset = Dataset.from_dict({"source": sources, "target": targets}) + + def tokenize(batch): + # Tokenize encoder inputs + model_inputs = tokenizer( + batch["source"], + max_length=128, + truncation=True, + padding="max_length" + ) + + # Tokenize decoder targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer( + batch["target"], + max_length=128, + truncation=True, + padding="max_length" + )["input_ids"] + + model_inputs["labels"] = labels + return model_inputs + + + tokenized_dataset = dataset.map(tokenize, remove_columns=["source", "target"]) + eval_dataset = tokenized_dataset.select(range(1)) + + training_args = TrainingArguments( + output_dir="./student-llm", + per_device_train_batch_size=1, + num_train_epochs=1, + learning_rate=2e-4, + logging_steps=10, + save_steps=100, + save_total_limit=5, + report_to="none", + lr_scheduler_type="cosine", + warmup_steps=500, + ) + + trainer = DistillTrainer( + teacher_model=teacher, + student_model=student, + train_dataset=tokenized_dataset, + tokenizer=tokenizer, + training_args=training_args, + kd_alpha=0.5, + temperature=2.0, + is_pretrained=False + ) + + trainer.train() + + results = trainer.evaluate(eval_dataset = eval_dataset) + print("Evaluation results:", results)