diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f9832b24aaf1330d4fdcc726e05988eecc0f9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,142 @@
+# macOS
+.DS_Store
+.AppleDouble
+.LSOverride
+.Spotlight-V100
+.Trashes
+
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Virtual environments
+.venv/
+venv/
+env/
+.envrc
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache/
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+pytestdebug.log
+coverage/
+
+# Type checkers / linters
+.mypy_cache/
+.dmypy.json
+dmypy.json
+.pyre/
+.pytype/
+.ruff_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints/
+profile_default/
+
+# IPython
+ipython_config.py
+
+# VSCode
+.vscode/
+
+# IDEs
+.idea/
+*.iml
+*.sublime-project
+*.sublime-workspace
+
+# Logs and temp files
+logs/
+*.log
+log/
+tmp/
+temp/
+
+# TensorBoard
+events.out.tfevents.*
+
+# ML experiment tracking
+wandb/
+mlruns/
+lightning_logs/
+checkpoints/
+runs/
+
+# Data & outputs (uncomment if you keep these out of git)
+# data/
+# datasets/
+# output/
+# outputs/
+# models/
+# results/
+
+# System files (Windows)
+Thumbs.db
+ehthumbs.db
+Desktop.ini
+
+# Secrets and environment files
+.env
+.env.*
+*.env
+*.secret
+*.key
+*.pem
+
+# Node (if present)
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+pnpm-debug.log*
+
+# Pyenv / Poetry
+.python-version
+poetry.lock
+
+# Editor swap/backup
+*~
+*.swp
+*.swo
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index f9e86bda25c712ddc6a159bef3359e81b7acd276..9ae55a56abba20313ae516fb7b8aa43c07d41a85 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,305 @@
---
-title: ML Starter
-emoji: ๐
-colorFrom: purple
-colorTo: gray
+title: ML Starter MCP Server
+emoji: 🧠
+colorFrom: blue
+colorTo: green
sdk: gradio
-sdk_version: 6.0.1
+sdk_version: "6.0.0"
app_file: app.py
-pinned: false
license: apache-2.0
-short_description: MCP server that exposes a problem-specific ML codes
+pinned: true
+short_description: Pure-retrieval MCP server that indexes the ML Starter knowledge base with deterministic semantic search.
+tags:
+ - building-mcp-track-enterprise
+ - gradio
+ - mcp
+ - retrieval
+ - embeddings
+ - python
+ - knowledge-base
+ - semantic-search
+ - sentence-transformers
+ - huggingface
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# ML Starter MCP Server
+
+Gradio-powered **remote-only** MCP server that exposes a curated ML knowledge base through deterministic, read-only tooling. Ideal for MCP clients like Claude Desktop, VS Code (Kilo Code), or Cursor that want a trustworthy retrieval endpoint with **no side effects**.
+
+
+---
+
+## 🧩 Overview
+
+The **ML Starter MCP Server** indexes the entire `knowledge_base/` tree (audio, vision, NLP, RL, etc.) and makes it searchable through:
+
+* `list_items` – enumerate every tutorial/script with metadata.
+* `semantic_search` – vector search over docstrings and lead context to find the single best code example for a natural-language brief.
+* `get_code` – return the full Python source for a safe, validated path.
+
+The server is deterministic (seeded numpy/torch), write-protected, and designed to run as a **Gradio MCP SSE endpoint** suitable for Hugging Face Spaces or on-prem deployments.
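+
+Determinism here means the index is rebuilt identically on every boot. A minimal
+sketch of the kind of seeding involved (the exact seed value and the presence of
+torch are assumptions, not guarantees about this repo):
+
+```python
+import random
+
+import numpy as np
+
+SEED = 42  # hypothetical constant; the server may use a different value
+
+random.seed(SEED)
+np.random.seed(SEED)
+
+try:
+    import torch
+
+    torch.manual_seed(SEED)  # only matters when torch is installed
+except ImportError:
+    pass
+```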
+
+---
+
+## 📚 ML Starter Knowledge Base
+
+* Root: `knowledge_base/`
+* Domains:
+ * `audio/`
+ * `generative/`
+ * `graph/`
+ * `nlp/`
+ * `rl/`
+ * `structured_data/`
+ * `timeseries/`
+ * `vision/`
+* Each file stores a complete, runnable ML example with docstring summaries leveraged during indexing.
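+
+A minimal sketch of how a loader could pull that leading docstring with the standard
+library (`leading_docstring` is an illustrative name, not necessarily the repo's API):
+
+```python
+import ast
+from pathlib import Path
+
+
+def leading_docstring(py_file: Path) -> str:
+    """Return the module-level docstring of a KB file, or '' if absent."""
+    tree = ast.parse(py_file.read_text(encoding="utf-8"))
+    return ast.get_docstring(tree) or ""
+
+
+# Example: build the index text for one knowledge-base entry
+doc = leading_docstring(Path("knowledge_base/audio/ctc_asr.py"))
+```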
+
+### Features exposed via MCP
+
+* ✅ Vector search via `sentence-transformers/all-MiniLM-L6-v2` with cosine similarity (see the sketch after this list).
+* ✔️ Safe path resolution ensures only in-repo `.py` files can be fetched.
+* 🧮 Metadata-first outputs (category, filename, semantic score) for quick triage.
+* 🛡️ Read-only contract; zero KB mutations, uploads, or side effects.
+* 🌐 Spaces-ready networking with auto `0.0.0.0` binding when environment variables are provided by the platform.
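+
+As an illustration, a minimal sketch of the embedding step (the model name comes from
+this README; the batching and caching details of `embeddings.py` are assumptions):
+
+```python
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+
+# normalize_embeddings=True makes cosine similarity a plain dot product later on
+corpus_vecs = model.encode(
+    ["Fine-tune a Transformer for sentiment classification."],
+    normalize_embeddings=True,
+)
+query_vec = model.encode(
+    ["I need a transformer example for multilingual NER."],
+    normalize_embeddings=True,
+)[0]
+```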
+
+---
+
+## 🚀 Quick Start
+
+### Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+### Running the MCP Server
+
+```bash
+python -m mcp_server.server --host 127.0.0.1 --port 7860
+```
+
+* **SSE Endpoint:** `http://127.0.0.1:7860/gradio_api/mcp/sse`
+* Launch with `mcp_server=True` (handled by `mcp_server/server.py`).
+
+### VS Code Kilo Code Settings
+
+```json
+{
+ "mcpServers": {
+ "ml-starter-kb": {
+ "url": "http://127.0.0.1:7860/gradio_api/mcp/sse",
+ "disabled": false,
+ "timeout": 60,
+ "alwaysAllow": [],
+ "disabledTools": []
+ }
+ }
+}
+```
+
+### Environment Variables
+
+```bash
+export TOKENIZERS_PARALLELISM=false
+export PYTORCH_ENABLE_MPS_FALLBACK=1 # optional, improves macOS stability
+```
+
+---
+
+## 🔧 MCP Usage
+
+Any MCP-capable client can connect to the SSE endpoint to:
+
+* Browse the full inventory of ML tutorials.
+* Submit a markdown problem statement and receive the best-matching file path plus relevance score.
+* Fetch the code immediately and render it inline (clients typically syntax-highlight the response).
+
+The Gradio UI mirrors these capabilities via three tabs (List Items, Semantic Search, Get Code) for manual exploration.
+
+---
+
+## 🤖 Supported Embeddings
+
+* `sentence-transformers/all-MiniLM-L6-v2`
+
+### Configuration Example
+
+```yaml
+embedding_model: sentence-transformers/all-MiniLM-L6-v2
+batch_size: 32
+similarity: cosine
+```
+
+---
+
+## 🔍 Retrieval Strategy
+
+| Component | Description |
+|----------------------|--------------------------------------------------------------|
+| Index Type | In-memory cosine index backed by numpy vectors |
+| Chunking | File-level (docstring + prefix) |
+| Similarity Function | Dot product on L2-normalized vectors |
+| Results Returned | Top-1 match (deterministic) |
+
+### Configuration Example
+
+```yaml
+retriever: cosine
+max_results: 1
+```
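+
+A sketch of the top-1 lookup those settings imply, assuming the corpus vectors are
+already L2-normalized as the table above states:
+
+```python
+import numpy as np
+
+
+def best_match(query_vec: np.ndarray, corpus_vecs: np.ndarray) -> tuple[int, float]:
+    """Return (row index, cosine score) of the single best match."""
+    scores = corpus_vecs @ query_vec  # dot product == cosine on unit vectors
+    idx = int(np.argmax(scores))  # deterministic top-1
+    return idx, float(scores[idx])
+```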
+
+---
+
+## 🧩 Folder Structure
+
+```
+ml-starter/
+├── app.py                 # Optional Gradio hook
+├── mcp_server/
+│   ├── server.py          # Remote MCP entrypoint & UI builder
+│   ├── loader.py          # KB scanning + safe path resolution
+│   ├── embeddings.py      # MiniLM wrapper + cosine index
+│   └── tools/
+│       ├── list_items.py        # list_items()
+│       ├── semantic_search.py   # semantic_search()
+│       └── get_code.py          # get_code()
+├── knowledge_base/        # ML examples grouped by domain
+├── requirements.txt
+└── README.md
+```
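+
+`loader.py` is described above as handling safe path resolution. A hedged sketch of
+what that contract could look like (`resolve_kb_path` is illustrative, not the repo's
+actual API):
+
+```python
+from pathlib import Path
+
+KB_ROOT = Path("knowledge_base").resolve()
+
+
+def resolve_kb_path(user_path: str) -> Path:
+    """Resolve a requested path, refusing anything outside the KB tree."""
+    candidate = Path(user_path)
+    if not candidate.is_absolute():
+        # KB-relative inputs are expected to start with "knowledge_base/"
+        candidate = KB_ROOT.parent / candidate
+    candidate = candidate.resolve()
+    if candidate.suffix != ".py" or not candidate.is_relative_to(KB_ROOT):
+        raise ValueError(f"refusing path outside the knowledge base: {user_path}")
+    return candidate
+```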
+
+---
+
+## 🧠 MCP Tools (`mcp_server/server.py`)
+
+| MCP Tool | Python Function | Description |
+|----------------|------------------------------------|-----------------------------------------------------------------------------------------|
+| `list_items` | `list_items()` | Enumerates every KB entry with category, filename, absolute path, and summary metadata. |
+| `semantic_search` | `semantic_search(problem_markdown: str)` | Embeds the prompt and returns the single best match plus cosine score. |
+| `get_code` | `get_code(path: str)` | Streams back the full Python source for a validated KB path. |
+
+`server.py` registers these functions with Gradio's MCP adapter, wires docstrings into tool descriptions, and ensures the SSE endpoint stays read-only.
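+
+In Gradio, exposing a typed, docstringed function through an interface and launching
+with `mcp_server=True` is enough to publish it as an MCP tool. A minimal sketch (the
+tool body is a stub, not the repo's implementation):
+
+```python
+import gradio as gr
+
+
+def get_code(path: str) -> str:
+    """Return the full Python source for a validated knowledge-base path."""
+    return "<source elided in this sketch>"
+
+
+demo = gr.Interface(fn=get_code, inputs="text", outputs="text")
+demo.launch(mcp_server=True)  # also serves /gradio_api/mcp/sse next to the UI
+```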
+
+---
+
+## 🎬 Demo
+
+* In progress
+
+---
+
+## 📥 Inputs
+
+### 1. `list_items`
+
+No input parameters; returns the entire catalog.
+
+### 2. `semantic_search`
+
+**Input Model**
+
+| Field | Type | Description | Example |
+|------------------|--------|---------------------------------------------------------|-----------------------------------------------------------------|
+| problem_markdown | str | Natural-language description of the ML task or need. | "I need a transformer example for multilingual NER." |
+
+
+### 3. `get_code`
+
+**Input Model**
+
+| Field | Type | Description | Example |
+|-------|------|-----------------------------------------------|------------------------------------------------------|
+| path | str | KB-relative or absolute path to a `.py` file. | "knowledge_base/nlp/text_classification_from_scratch.py" |
+
+
+---
+
+## 📤 Outputs
+
+### 1. `list_items`
+
+**Response Example**
+
+```json
+[
+ {
+ "id": "nlp/text_classification_with_transformer.py",
+ "category": "nlp",
+ "filename": "text_classification_with_transformer.py",
+ "path": "knowledge_base/nlp/text_classification_with_transformer.py",
+ "summary": "Fine-tune a Transformer for sentiment classification."
+ }
+]
+```
+
+
+### 2. `semantic_search`
+
+**Response Example**
+
+```json
+{
+ "best_match": "knowledge_base/nlp/text_classification_with_transformer.py",
+ "score": 0.89
+}
+```
+
+
+### 3. `get_code`
+
+**Response Example**
+
+```json
+{
+ "path": "knowledge_base/vision/grad_cam.py",
+ "source": ""
+}
+```
+
+
+Each response is deterministic for the same corpus and embeddings, allowing MCP clients to trust caching and diffing workflows.
+
+---
+
+## 👥 Team
+
+**Team Name:** Hepheon
+
+**Team Members:**
+- **Tutkum Akyildiz** - [@Tutkum](https://huggingface.co/Tutkum) - Product
+- **Emre Atilgan** - [@emreatilgan](https://huggingface.co/emreatilgan) - Tech
+
+---
+
+## 🛠️ Next Steps
+
+Today the knowledge base focuses on curated **Keras** walkthroughs. Upcoming updates will expand coverage to include:
+
+* TensorFlow
+* PyTorch
+* scikit-learn
+* ...
+
+These additions will land in the same deterministic retrieval flow, making mixed-framework discovery as seamless as the current experience.
+
+---
+
+## 📄 License
+
+This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for full terms.
+
+---
+
+
+Built with ❤️ for the ML Starter knowledge base • Apache 2.0
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..da0ef98745cd465195afd418c249ab640dc44bff
--- /dev/null
+++ b/app.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+import os
+
+import gradio as gr
+
+from mcp_server.server import create_gradio_blocks
+
+# Expose a demo/app object for Hugging Face Spaces auto-discovery
+demo: gr.Blocks = create_gradio_blocks()
+app: gr.Blocks = demo
+
+if __name__ == "__main__":
+ # Respect common env vars used by Spaces/containers
+ host = os.getenv("GRADIO_SERVER_NAME") or os.getenv("HOST") or "0.0.0.0"
+ port_str = os.getenv("GRADIO_SERVER_PORT") or os.getenv("PORT") or "7860"
+ try:
+ port = int(port_str)
+ except Exception:
+ port = 7860
+ demo.launch(server_name=host, server_port=port, mcp_server=True)
\ No newline at end of file
diff --git a/knowledge_base/audio/ctc_asr.py b/knowledge_base/audio/ctc_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..349b4f13bbdd01fce17b01f0a6f1c088dda5a24a
--- /dev/null
+++ b/knowledge_base/audio/ctc_asr.py
@@ -0,0 +1,464 @@
+"""
+Title: Automatic Speech Recognition using CTC
+Authors: [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
+Date created: 2021/09/26
+Last modified: 2021/09/26
+Description: Training a CTC-based model for automatic speech recognition.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Speech recognition is an interdisciplinary subfield of computer science
+and computational linguistics that develops methodologies and technologies
+that enable the recognition and translation of spoken language into text
+by computers. It is also known as automatic speech recognition (ASR),
+computer speech recognition or speech to text (STT). It incorporates
+knowledge and research in the computer science, linguistics and computer
+engineering fields.
+
+This demonstration shows how to combine a 2D CNN, RNN and a Connectionist
+Temporal Classification (CTC) loss to build an ASR. CTC is an algorithm
+used to train deep neural networks in speech recognition, handwriting
+recognition and other sequence problems. CTC is used when we don't know
+how the input aligns with the output (how the characters in the transcript
+align to the audio). The model we create is similar to
+[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
+
+We will use the LJSpeech dataset from the
+[LibriVox](https://librivox.org/) project. It consists of short
+audio clips of a single speaker reading passages from 7 non-fiction books.
+
+We will evaluate the quality of the model using
+[Word Error Rate (WER)](https://en.wikipedia.org/wiki/Word_error_rate).
+WER is obtained by adding up the substitutions (S), insertions (I), and deletions (D)
+that occur in a sequence of recognized words, then dividing by the total number of
+words originally spoken (N): `WER = (S + I + D) / N`. To get the WER score you need
+to install the [jiwer](https://pypi.org/project/jiwer/) package:
+
+```
+pip install jiwer
+```
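+
+For example, a single substitution in a four-word reference yields a WER of 0.25:
+
+```python
+from jiwer import wer
+
+# "cat" -> "bat" is one substitution over four reference words: 1/4 = 0.25
+print(wer("the cat sat down", "the bat sat down"))
+```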
+
+**References:**
+
+- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
+- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
+- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
+- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
+
+"""
+
+"""
+## Setup
+"""
+
+import pandas as pd
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import matplotlib.pyplot as plt
+from IPython import display
+from jiwer import wer
+
+
+"""
+## Load the LJSpeech Dataset
+
+Let's download the [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/).
+The dataset contains 13,100 audio files as `wav` files in the `/wavs/` folder.
+The label (transcript) for each audio file is a string
+given in the `metadata.csv` file. The fields are:
+
+- **ID**: this is the name of the corresponding .wav file
+- **Transcription**: words spoken by the reader (UTF-8)
+- **Normalized transcription**: transcription with numbers,
+ordinals, and monetary units expanded into full words (UTF-8).
+
+For this demo, we will use only the "Normalized transcription" field.
+
+Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 22,050 Hz.
+"""
+
+data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
+data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
+wavs_path = data_path + "/wavs/"
+metadata_path = data_path + "/metadata.csv"
+
+
+# Read metadata file and parse it
+metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
+metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
+metadata_df = metadata_df[["file_name", "normalized_transcription"]]
+metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
+metadata_df.head(3)
+
+
+"""
+We now split the data into training and validation set.
+"""
+
+split = int(len(metadata_df) * 0.90)
+df_train = metadata_df[:split]
+df_val = metadata_df[split:]
+
+print(f"Size of the training set: {len(df_train)}")
+print(f"Size of the validation set: {len(df_val)}")
+
+
+"""
+## Preprocessing
+
+We first prepare the vocabulary to be used.
+"""
+
+# The set of characters accepted in the transcription.
+characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
+# Mapping characters to integers
+char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
+# Mapping integers back to original characters
+num_to_char = keras.layers.StringLookup(
+ vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
+)
+
+print(
+ f"The vocabulary is: {char_to_num.get_vocabulary()} "
+ f"(size ={char_to_num.vocabulary_size()})"
+)
+
+"""
+Next, we create the function that describes the transformation that we apply to each
+element of our dataset.
+"""
+
+# An integer scalar Tensor. The window length in samples.
+frame_length = 256
+# An integer scalar Tensor. The number of samples to step.
+frame_step = 160
+# An integer scalar Tensor. The size of the FFT to apply.
+# If not provided, uses the smallest power of 2 enclosing frame_length.
+fft_length = 384
+
+
+def encode_single_sample(wav_file, label):
+ ###########################################
+ ## Process the Audio
+ ##########################################
+ # 1. Read wav file
+ file = tf.io.read_file(wavs_path + wav_file + ".wav")
+ # 2. Decode the wav file
+ audio, _ = tf.audio.decode_wav(file)
+ audio = tf.squeeze(audio, axis=-1)
+ # 3. Change type to float
+ audio = tf.cast(audio, tf.float32)
+ # 4. Get the spectrogram
+ spectrogram = tf.signal.stft(
+ audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
+ )
+ # 5. We only need the magnitude, which can be derived by applying tf.abs
+ spectrogram = tf.abs(spectrogram)
+ spectrogram = tf.math.pow(spectrogram, 0.5)
+ # 6. normalisation
+ means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
+ stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
+ spectrogram = (spectrogram - means) / (stddevs + 1e-10)
+ ###########################################
+ ## Process the label
+ ##########################################
+ # 7. Convert label to Lower case
+ label = tf.strings.lower(label)
+ # 8. Split the label
+ label = tf.strings.unicode_split(label, input_encoding="UTF-8")
+ # 9. Map the characters in label to numbers
+ label = char_to_num(label)
+ # 10. Return a dict as our model is expecting two inputs
+ return spectrogram, label
+
+
+"""
+## Creating `Dataset` objects
+
+We create a `tf.data.Dataset` object that yields
+the transformed elements, in the same order as they
+appeared in the input.
+"""
+
+batch_size = 32
+# Define the training dataset
+train_dataset = tf.data.Dataset.from_tensor_slices(
+ (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
+)
+train_dataset = (
+ train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
+ .padded_batch(batch_size)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+)
+
+# Define the validation dataset
+validation_dataset = tf.data.Dataset.from_tensor_slices(
+ (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
+)
+validation_dataset = (
+ validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
+ .padded_batch(batch_size)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+)
+
+
+"""
+## Visualize the data
+
+Let's visualize an example in our dataset, including the
+audio clip, the spectrogram and the corresponding label.
+"""
+
+fig = plt.figure(figsize=(8, 5))
+for batch in train_dataset.take(1):
+ spectrogram = batch[0][0].numpy()
+ spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)])
+ label = batch[1][0]
+ # Spectrogram
+ label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
+ ax = plt.subplot(2, 1, 1)
+ ax.imshow(spectrogram, vmax=1)
+ ax.set_title(label)
+ ax.axis("off")
+ # Wav
+ file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
+ audio, _ = tf.audio.decode_wav(file)
+ audio = audio.numpy()
+ ax = plt.subplot(2, 1, 2)
+ plt.plot(audio)
+ ax.set_title("Signal Wave")
+ ax.set_xlim(0, len(audio))
+    display.display(display.Audio(np.transpose(audio), rate=22050))  # LJSpeech is sampled at 22,050 Hz
+plt.show()
+
+"""
+## Model
+
+We first define the CTC Loss function.
+"""
+
+
+def CTCLoss(y_true, y_pred):
+ # Compute the training-time loss value
+ batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
+ input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
+ label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
+
+ input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
+ label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
+
+ loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
+ return loss
+
+
+"""
+We now define our model. We will define a model similar to
+[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
+"""
+
+
+def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
+ """Model similar to DeepSpeech2."""
+ # Model's input
+ input_spectrogram = layers.Input((None, input_dim), name="input")
+ # Expand the dimension to use 2D CNN.
+ x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram)
+ # Convolution layer 1
+ x = layers.Conv2D(
+ filters=32,
+ kernel_size=[11, 41],
+ strides=[2, 2],
+ padding="same",
+ use_bias=False,
+ name="conv_1",
+ )(x)
+ x = layers.BatchNormalization(name="conv_1_bn")(x)
+ x = layers.ReLU(name="conv_1_relu")(x)
+ # Convolution layer 2
+ x = layers.Conv2D(
+ filters=32,
+ kernel_size=[11, 21],
+ strides=[1, 2],
+ padding="same",
+ use_bias=False,
+ name="conv_2",
+ )(x)
+ x = layers.BatchNormalization(name="conv_2_bn")(x)
+ x = layers.ReLU(name="conv_2_relu")(x)
+ # Reshape the resulted volume to feed the RNNs layers
+ x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
+ # RNN layers
+ for i in range(1, rnn_layers + 1):
+ recurrent = layers.GRU(
+ units=rnn_units,
+ activation="tanh",
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ return_sequences=True,
+ reset_after=True,
+ name=f"gru_{i}",
+ )
+ x = layers.Bidirectional(
+ recurrent, name=f"bidirectional_{i}", merge_mode="concat"
+ )(x)
+ if i < rnn_layers:
+ x = layers.Dropout(rate=0.5)(x)
+ # Dense layer
+ x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
+ x = layers.ReLU(name="dense_1_relu")(x)
+ x = layers.Dropout(rate=0.5)(x)
+ # Classification layer
+ output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
+ # Model
+ model = keras.Model(input_spectrogram, output, name="DeepSpeech_2")
+ # Optimizer
+ opt = keras.optimizers.Adam(learning_rate=1e-4)
+ # Compile the model and return
+ model.compile(optimizer=opt, loss=CTCLoss)
+ return model
+
+
+# Get the model
+model = build_model(
+ input_dim=fft_length // 2 + 1,
+ output_dim=char_to_num.vocabulary_size(),
+ rnn_units=512,
+)
+model.summary(line_length=110)
+
+"""
+## Training and Evaluating
+"""
+
+
+# A utility function to decode the output of the network
+def decode_batch_predictions(pred):
+ input_len = np.ones(pred.shape[0]) * pred.shape[1]
+ # Use greedy search. For complex tasks, you can use beam search
+ results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
+ # Iterate over the results and get back the text
+ output_text = []
+ for result in results:
+ result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
+ output_text.append(result)
+ return output_text
+
+
+# A callback class to output a few transcriptions during training
+class CallbackEval(keras.callbacks.Callback):
+ """Displays a batch of outputs after every epoch."""
+
+ def __init__(self, dataset):
+ super().__init__()
+ self.dataset = dataset
+
+ def on_epoch_end(self, epoch: int, logs=None):
+ predictions = []
+ targets = []
+ for batch in self.dataset:
+ X, y = batch
+ batch_predictions = model.predict(X)
+ batch_predictions = decode_batch_predictions(batch_predictions)
+ predictions.extend(batch_predictions)
+ for label in y:
+ label = (
+ tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
+ )
+ targets.append(label)
+ wer_score = wer(targets, predictions)
+ print("-" * 100)
+ print(f"Word Error Rate: {wer_score:.4f}")
+ print("-" * 100)
+ for i in np.random.randint(0, len(predictions), 2):
+ print(f"Target : {targets[i]}")
+ print(f"Prediction: {predictions[i]}")
+ print("-" * 100)
+
+
+"""
+Let's start the training process.
+"""
+
+# Define the number of epochs.
+epochs = 1
+# Callback function to check transcription on the val set.
+validation_callback = CallbackEval(validation_dataset)
+# Train the model
+history = model.fit(
+ train_dataset,
+ validation_data=validation_dataset,
+ epochs=epochs,
+ callbacks=[validation_callback],
+)
+
+
+"""
+## Inference
+"""
+
+# Let's check results on more validation samples
+predictions = []
+targets = []
+for batch in validation_dataset:
+ X, y = batch
+ batch_predictions = model.predict(X)
+ batch_predictions = decode_batch_predictions(batch_predictions)
+ predictions.extend(batch_predictions)
+ for label in y:
+ label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
+ targets.append(label)
+wer_score = wer(targets, predictions)
+print("-" * 100)
+print(f"Word Error Rate: {wer_score:.4f}")
+print("-" * 100)
+for i in np.random.randint(0, len(predictions), 5):
+ print(f"Target : {targets[i]}")
+ print(f"Prediction: {predictions[i]}")
+ print("-" * 100)
+
+
+"""
+## Conclusion
+
+In practice, you should train for around 50 epochs or more. Each epoch
+takes approximately 5-6 minutes using a `GeForce RTX 2080 Ti` GPU.
+The model we trained at 50 epochs has a `Word Error Rate (WER) ≈ 16% to 17%`.
+
+Some of the transcriptions around epoch 50:
+
+**Audio file: LJ017-0009.wav**
+```
+- Target : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign
+of james the first
+- Prediction: cer thomas overbery was undoubtedly poisoned by lordrochester in the reign
+of james the first
+```
+
+**Audio file: LJ003-0340.wav**
+```
+- Target : the committee does not seem to have yet understood that newgate could be
+only and properly replaced
+- Prediction: the committee does not seem to have yet understood that newgate could be
+only and proberly replace
+```
+
+**Audio file: LJ011-0136.wav**
+```
+- Target : still no sentence of death was carried out for the offense and in eighteen
+thirtytwo
+- Prediction: still no sentence of death was carried out for the offense and in eighteen
+thirtytwo
+```
+
+Example available on Hugging Face:
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [keras-io/ctc_asr](https://huggingface.co/keras-io/ctc_asr) | [Spaces demo](https://huggingface.co/spaces/keras-io/ctc_asr) |
+
+"""
diff --git a/knowledge_base/audio/melgan_spectrogram_inversion.py b/knowledge_base/audio/melgan_spectrogram_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f7bf232553489b42cd873e69277b38ac3661ae
--- /dev/null
+++ b/knowledge_base/audio/melgan_spectrogram_inversion.py
@@ -0,0 +1,607 @@
+"""
+Title: MelGAN-based spectrogram inversion using feature matching
+Author: [Darshan Deshpande](https://twitter.com/getdarshan)
+Date created: 02/09/2021
+Last modified: 15/09/2021
+Description: Inversion of audio from mel-spectrograms using the MelGAN architecture and feature matching.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Autoregressive vocoders have been ubiquitous for a majority of the history of speech processing,
+but for most of their existence they have lacked parallelism.
+[MelGAN](https://arxiv.org/abs/1910.06711) is a
+non-autoregressive, fully convolutional vocoder architecture used for purposes ranging
+from spectral inversion and speech enhancement to present-day state-of-the-art
+speech synthesis when used as a decoder
+with models like Tacotron2 or FastSpeech that convert text to mel spectrograms.
+
+In this tutorial, we will have a look at the MelGAN architecture and how it can achieve
+fast spectral inversion, i.e. conversion of spectrograms to audio waves. The MelGAN
+implemented in this tutorial is similar to the original implementation, the only
+difference being the padding method for the convolutions: we use 'same' padding
+instead of reflect padding.
+
+"""
+## Importing and Defining Hyperparameters
+"""
+
+"""shell
+pip install -qqq tensorflow_addons
+pip install -qqq tensorflow-io
+"""
+
+import tensorflow as tf
+import tensorflow_io as tfio
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow_addons import layers as addon_layers
+
+# Setting logger level to avoid input shape warnings
+tf.get_logger().setLevel("ERROR")
+
+# Defining hyperparameters
+
+DESIRED_SAMPLES = 8192
+LEARNING_RATE_GEN = 1e-5
+LEARNING_RATE_DISC = 1e-6
+BATCH_SIZE = 16
+
+mse = keras.losses.MeanSquaredError()
+mae = keras.losses.MeanAbsoluteError()
+
+"""
+## Loading the Dataset
+
+This example uses the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/).
+
+The LJSpeech dataset is primarily used for text-to-speech and consists of 13,100 discrete
+speech samples taken from 7 non-fiction books, having a total length of approximately 24
+hours. The MelGAN training is only concerned with the audio waves so we process only the
+WAV files and ignore the audio annotations.
+"""
+
+"""shell
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar -xf /content/LJSpeech-1.1.tar.bz2
+"""
+
+"""
+We create a `tf.data.Dataset` to load and process the audio files on the fly.
+The `preprocess()` function takes the file path as input and returns two instances of the
+wave, one for input and one as the ground truth for comparison. The input wave will be
+mapped to a spectrogram using the custom `MelSpec` layer as shown later in this example.
+"""
+
+# Splitting the dataset into training and testing splits
+wavs = tf.io.gfile.glob("LJSpeech-1.1/wavs/*.wav")
+print(f"Number of audio files: {len(wavs)}")
+
+
+# Mapper function for loading the audio. This function returns two instances of the wave
+def preprocess(filename):
+ audio = tf.audio.decode_wav(tf.io.read_file(filename), 1, DESIRED_SAMPLES).audio
+ return audio, audio
+
+
+# Create tf.data.Dataset objects and apply preprocessing
+train_dataset = tf.data.Dataset.from_tensor_slices((wavs,))
+train_dataset = train_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
+
+"""
+## Defining custom layers for MelGAN
+
+The MelGAN architecture consists of 3 main modules:
+
+1. The residual block
+2. Dilated convolutional block
+3. Discriminator block
+
+
+"""
+
+"""
+Since the network takes a mel-spectrogram as input, we will create an additional custom
+layer
+which can convert the raw audio wave to a spectrogram on-the-fly. We use the raw audio
+tensor from `train_dataset` and map it to a mel-spectrogram using the `MelSpec` layer
+below.
+"""
+
+# Custom keras layer for on-the-fly audio to spectrogram conversion
+
+
+class MelSpec(layers.Layer):
+ def __init__(
+ self,
+ frame_length=1024,
+ frame_step=256,
+ fft_length=None,
+ sampling_rate=22050,
+ num_mel_channels=80,
+ freq_min=125,
+ freq_max=7600,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.frame_length = frame_length
+ self.frame_step = frame_step
+ self.fft_length = fft_length
+ self.sampling_rate = sampling_rate
+ self.num_mel_channels = num_mel_channels
+ self.freq_min = freq_min
+ self.freq_max = freq_max
+ # Defining mel filter. This filter will be multiplied with the STFT output
+ self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
+ num_mel_bins=self.num_mel_channels,
+ num_spectrogram_bins=self.frame_length // 2 + 1,
+ sample_rate=self.sampling_rate,
+ lower_edge_hertz=self.freq_min,
+ upper_edge_hertz=self.freq_max,
+ )
+
+ def call(self, audio, training=True):
+ # We will only perform the transformation during training.
+ if training:
+ # Taking the Short Time Fourier Transform. Ensure that the audio is padded.
+ # In the paper, the STFT output is padded using the 'REFLECT' strategy.
+ stft = tf.signal.stft(
+ tf.squeeze(audio, -1),
+ self.frame_length,
+ self.frame_step,
+ self.fft_length,
+ pad_end=True,
+ )
+
+ # Taking the magnitude of the STFT output
+ magnitude = tf.abs(stft)
+
+ # Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale
+ mel = tf.matmul(tf.square(magnitude), self.mel_filterbank)
+ log_mel_spec = tfio.audio.dbscale(mel, top_db=80)
+ return log_mel_spec
+ else:
+ return audio
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "frame_length": self.frame_length,
+ "frame_step": self.frame_step,
+ "fft_length": self.fft_length,
+ "sampling_rate": self.sampling_rate,
+ "num_mel_channels": self.num_mel_channels,
+ "freq_min": self.freq_min,
+ "freq_max": self.freq_max,
+ }
+ )
+ return config
+
+
+"""
+The residual convolutional block extensively uses dilations and has a total receptive
+field of 27 timesteps per block. The dilations must grow as a power of the `kernel_size`
+to ensure reduction of hissing noise in the output. The network proposed by the paper is
+as follows:
+
+
+"""
+
+# Creating the residual stack block
+
+
+def residual_stack(input, filters):
+ """Convolutional residual stack with weight normalization.
+
+ Args:
+ filters: int, determines filter size for the residual stack.
+
+ Returns:
+ Residual stack output.
+ """
+ c1 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
+ )(input)
+ lrelu1 = layers.LeakyReLU()(c1)
+ c2 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
+ )(lrelu1)
+ add1 = layers.Add()([c2, input])
+
+ lrelu2 = layers.LeakyReLU()(add1)
+ c3 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=3, padding="same"), data_init=False
+ )(lrelu2)
+ lrelu3 = layers.LeakyReLU()(c3)
+ c4 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
+ )(lrelu3)
+ add2 = layers.Add()([add1, c4])
+
+ lrelu4 = layers.LeakyReLU()(add2)
+ c5 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=9, padding="same"), data_init=False
+ )(lrelu4)
+ lrelu5 = layers.LeakyReLU()(c5)
+ c6 = addon_layers.WeightNormalization(
+ layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
+ )(lrelu5)
+ add3 = layers.Add()([c6, add2])
+
+ return add3
+
+
+"""
+Each convolutional block uses the dilations offered by the residual stack
+and upsamples the input data by the `upsampling_factor`.
+"""
+
+# Dilated convolutional block consisting of the Residual stack
+
+
+def conv_block(input, conv_dim, upsampling_factor):
+ """Dilated Convolutional Block with weight normalization.
+
+ Args:
+ conv_dim: int, determines filter size for the block.
+ upsampling_factor: int, scale for upsampling.
+
+ Returns:
+ Dilated convolution block.
+ """
+ conv_t = addon_layers.WeightNormalization(
+ layers.Conv1DTranspose(conv_dim, 16, upsampling_factor, padding="same"),
+ data_init=False,
+ )(input)
+ lrelu1 = layers.LeakyReLU()(conv_t)
+ res_stack = residual_stack(lrelu1, conv_dim)
+ lrelu2 = layers.LeakyReLU()(res_stack)
+ return lrelu2
+
+
+"""
+The discriminator block consists of convolutions and downsampling layers. This block is
+essential for the implementation of the feature matching technique.
+
+Each discriminator outputs a list of feature maps that will be compared during training
+to compute the feature matching loss.
+"""
+
+
+def discriminator_block(input):
+ conv1 = addon_layers.WeightNormalization(
+ layers.Conv1D(16, 15, 1, "same"), data_init=False
+ )(input)
+ lrelu1 = layers.LeakyReLU()(conv1)
+ conv2 = addon_layers.WeightNormalization(
+ layers.Conv1D(64, 41, 4, "same", groups=4), data_init=False
+ )(lrelu1)
+ lrelu2 = layers.LeakyReLU()(conv2)
+ conv3 = addon_layers.WeightNormalization(
+ layers.Conv1D(256, 41, 4, "same", groups=16), data_init=False
+ )(lrelu2)
+ lrelu3 = layers.LeakyReLU()(conv3)
+ conv4 = addon_layers.WeightNormalization(
+ layers.Conv1D(1024, 41, 4, "same", groups=64), data_init=False
+ )(lrelu3)
+ lrelu4 = layers.LeakyReLU()(conv4)
+ conv5 = addon_layers.WeightNormalization(
+ layers.Conv1D(1024, 41, 4, "same", groups=256), data_init=False
+ )(lrelu4)
+ lrelu5 = layers.LeakyReLU()(conv5)
+ conv6 = addon_layers.WeightNormalization(
+ layers.Conv1D(1024, 5, 1, "same"), data_init=False
+ )(lrelu5)
+ lrelu6 = layers.LeakyReLU()(conv6)
+ conv7 = addon_layers.WeightNormalization(
+ layers.Conv1D(1, 3, 1, "same"), data_init=False
+ )(lrelu6)
+ return [lrelu1, lrelu2, lrelu3, lrelu4, lrelu5, lrelu6, conv7]
+
+
+"""
+### Create the generator
+"""
+
+
+def create_generator(input_shape):
+ inp = keras.Input(input_shape)
+ x = MelSpec()(inp)
+ x = layers.Conv1D(512, 7, padding="same")(x)
+ x = layers.LeakyReLU()(x)
+ x = conv_block(x, 256, 8)
+ x = conv_block(x, 128, 8)
+ x = conv_block(x, 64, 2)
+ x = conv_block(x, 32, 2)
+ x = addon_layers.WeightNormalization(
+ layers.Conv1D(1, 7, padding="same", activation="tanh")
+ )(x)
+ return keras.Model(inp, x)
+
+
+# We use a dynamic input shape for the generator since the model is fully convolutional
+generator = create_generator((None, 1))
+generator.summary()
+
+"""
+### Create the discriminator
+"""
+
+
+def create_discriminator(input_shape):
+ inp = keras.Input(input_shape)
+ out_map1 = discriminator_block(inp)
+ pool1 = layers.AveragePooling1D()(inp)
+ out_map2 = discriminator_block(pool1)
+ pool2 = layers.AveragePooling1D()(pool1)
+ out_map3 = discriminator_block(pool2)
+ return keras.Model(inp, [out_map1, out_map2, out_map3])
+
+
+# We use a dynamic input shape for the discriminator
+# This is done because the input shape for the generator is unknown
+discriminator = create_discriminator((None, 1))
+
+discriminator.summary()
+
+"""
+## Defining the loss functions
+
+**Generator Loss**
+
+The generator architecture uses a combination of two losses
+
+1. Mean Squared Error:
+
+This is the standard MSE generator loss calculated between ones and the outputs from the
+discriminator with _N_ layers.
+
+
+
+
+
+2. Feature Matching Loss:
+
+This loss involves extracting the outputs of every layer from the discriminator for both
+the generator and ground truth and compare each layer output _k_ using Mean Absolute Error.
+
+
+
+
+
+**Discriminator Loss**
+
+The discriminator uses the Mean Absolute Error and compares the real data predictions
+with ones and generated predictions with zeros.
+
+
+
+
+"""
+
+# Generator loss
+
+
+def generator_loss(real_pred, fake_pred):
+ """Loss function for the generator.
+
+ Args:
+ real_pred: Tensor, output of the ground truth wave passed through the discriminator.
+ fake_pred: Tensor, output of the generator prediction passed through the discriminator.
+
+ Returns:
+ Loss for the generator.
+ """
+ gen_loss = []
+ for i in range(len(fake_pred)):
+ gen_loss.append(mse(tf.ones_like(fake_pred[i][-1]), fake_pred[i][-1]))
+
+ return tf.reduce_mean(gen_loss)
+
+
+def feature_matching_loss(real_pred, fake_pred):
+ """Implements the feature matching loss.
+
+ Args:
+ real_pred: Tensor, output of the ground truth wave passed through the discriminator.
+ fake_pred: Tensor, output of the generator prediction passed through the discriminator.
+
+ Returns:
+ Feature Matching Loss.
+ """
+ fm_loss = []
+ for i in range(len(fake_pred)):
+ for j in range(len(fake_pred[i]) - 1):
+ fm_loss.append(mae(real_pred[i][j], fake_pred[i][j]))
+
+ return tf.reduce_mean(fm_loss)
+
+
+def discriminator_loss(real_pred, fake_pred):
+ """Implements the discriminator loss.
+
+ Args:
+ real_pred: Tensor, output of the ground truth wave passed through the discriminator.
+ fake_pred: Tensor, output of the generator prediction passed through the discriminator.
+
+ Returns:
+ Discriminator Loss.
+ """
+ real_loss, fake_loss = [], []
+ for i in range(len(real_pred)):
+ real_loss.append(mse(tf.ones_like(real_pred[i][-1]), real_pred[i][-1]))
+ fake_loss.append(mse(tf.zeros_like(fake_pred[i][-1]), fake_pred[i][-1]))
+
+ # Calculating the final discriminator loss after scaling
+ disc_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)
+ return disc_loss
+
+
+"""
+Defining the MelGAN model for training.
+This subclass overrides the `train_step()` method to implement the training logic.
+"""
+
+
+class MelGAN(keras.Model):
+ def __init__(self, generator, discriminator, **kwargs):
+ """MelGAN trainer class
+
+ Args:
+ generator: keras.Model, Generator model
+ discriminator: keras.Model, Discriminator model
+ """
+ super().__init__(**kwargs)
+ self.generator = generator
+ self.discriminator = discriminator
+
+ def compile(
+ self,
+ gen_optimizer,
+ disc_optimizer,
+ generator_loss,
+ feature_matching_loss,
+ discriminator_loss,
+ ):
+ """MelGAN compile method.
+
+ Args:
+ gen_optimizer: keras.optimizer, optimizer to be used for training
+ disc_optimizer: keras.optimizer, optimizer to be used for training
+ generator_loss: callable, loss function for generator
+ feature_matching_loss: callable, loss function for feature matching
+ discriminator_loss: callable, loss function for discriminator
+ """
+ super().compile()
+
+ # Optimizers
+ self.gen_optimizer = gen_optimizer
+ self.disc_optimizer = disc_optimizer
+
+ # Losses
+ self.generator_loss = generator_loss
+ self.feature_matching_loss = feature_matching_loss
+ self.discriminator_loss = discriminator_loss
+
+ # Trackers
+ self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
+ self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
+
+ def train_step(self, batch):
+ x_batch_train, y_batch_train = batch
+
+ with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+            # Generating the audio wave
+            gen_audio_wave = self.generator(x_batch_train, training=True)
+
+            # Generating the features using the discriminator
+            real_pred = self.discriminator(y_batch_train)
+            fake_pred = self.discriminator(gen_audio_wave)
+
+            # Calculating the generator losses
+            gen_loss = self.generator_loss(real_pred, fake_pred)
+            fm_loss = self.feature_matching_loss(real_pred, fake_pred)
+
+            # Calculating final generator loss (feature matching loss weighted by 10)
+            gen_fm_loss = gen_loss + 10 * fm_loss
+
+            # Calculating the discriminator losses
+            disc_loss = self.discriminator_loss(real_pred, fake_pred)
+
+        # Calculating and applying the gradients for generator and discriminator
+        grads_gen = gen_tape.gradient(gen_fm_loss, self.generator.trainable_weights)
+        grads_disc = disc_tape.gradient(disc_loss, self.discriminator.trainable_weights)
+        self.gen_optimizer.apply_gradients(
+            zip(grads_gen, self.generator.trainable_weights)
+        )
+        self.disc_optimizer.apply_gradients(
+            zip(grads_disc, self.discriminator.trainable_weights)
+        )
+
+ self.gen_loss_tracker.update_state(gen_fm_loss)
+ self.disc_loss_tracker.update_state(disc_loss)
+
+ return {
+ "gen_loss": self.gen_loss_tracker.result(),
+ "disc_loss": self.disc_loss_tracker.result(),
+ }
+
+
+"""
+## Training
+
+The paper suggests that the training with dynamic shapes takes around 400,000 steps (~500
+epochs). For this example, we will run it only for a single epoch (819 steps).
+Longer training time (greater than 300 epochs) will almost certainly provide better results.
+"""
+
+gen_optimizer = keras.optimizers.Adam(
+ LEARNING_RATE_GEN, beta_1=0.5, beta_2=0.9, clipnorm=1
+)
+disc_optimizer = keras.optimizers.Adam(
+ LEARNING_RATE_DISC, beta_1=0.5, beta_2=0.9, clipnorm=1
+)
+
+# Start training
+generator = create_generator((None, 1))
+discriminator = create_discriminator((None, 1))
+
+mel_gan = MelGAN(generator, discriminator)
+mel_gan.compile(
+ gen_optimizer,
+ disc_optimizer,
+ generator_loss,
+ feature_matching_loss,
+ discriminator_loss,
+)
+mel_gan.fit(
+ train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE), epochs=1
+)
+
+"""
+## Testing the model
+
+The trained model can now be used for real-time text-to-speech tasks.
+To test how fast the MelGAN inference can be, let us take a sample audio mel-spectrogram
+and convert it. Note that the actual model pipeline will not include the `MelSpec` layer
+and hence this layer will be disabled during inference. The inference input will be a
+mel-spectrogram processed similarly to the `MelSpec` layer configuration.
+
+For this test, we will create a uniformly distributed random tensor to simulate the
+behavior of the inference pipeline.
+"""
+
+# Sampling a random tensor to mimic a batch of 128 spectrograms of shape [50, 80]
+audio_sample = tf.random.uniform([128, 50, 80])
+
+"""
+Timing the inference speed of a single sample: running this, you can see that the
+average inference time per spectrogram ranges from 8 to 10 milliseconds on a K80 GPU,
+which is pretty fast.
+"""
+pred = generator.predict(audio_sample, batch_size=32, verbose=1)
+"""
+## Conclusion
+
+MelGAN is a highly effective architecture for spectral inversion: it achieves a Mean
+Opinion Score (MOS) of 3.61, considerably outperforming the Griffin-Lim algorithm with
+its MOS of just 1.57, and it is competitive with the state-of-the-art WaveGlow and
+WaveNet architectures on text-to-speech and speech enhancement tasks on the LJSpeech
+and VCTK datasets [1].
+
+This tutorial highlights:
+
+1. The advantages of using dilated convolutions that grow with the filter size
+2. Implementation of a custom layer for on-the-fly conversion of audio waves to
+mel-spectrograms
+3. Effectiveness of using the feature matching loss function for training GAN generators.
+
+Further reading
+
+1. [MelGAN paper](https://arxiv.org/abs/1910.06711) (Kundan Kumar et al.) to
+understand the reasoning behind the architecture and training process
+2. For in-depth understanding of the feature matching loss, you can refer to [Improved
+Techniques for Training GANs](https://arxiv.org/abs/1606.03498) (Tim Salimans et
+al.).
+"""
diff --git a/knowledge_base/audio/speaker_recognition_using_cnn.py b/knowledge_base/audio/speaker_recognition_using_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b95f280846528604d0839a86e9b6aaada0ffe8e
--- /dev/null
+++ b/knowledge_base/audio/speaker_recognition_using_cnn.py
@@ -0,0 +1,489 @@
+"""
+Title: Speaker Recognition
+Author: [Fadi Badine](https://twitter.com/fadibadine)
+Date created: 14/06/2020
+Last modified: 19/07/2023
+Description: Classify speakers using Fast Fourier Transform (FFT) and a 1D Convnet.
+Accelerator: GPU
+Converted to Keras 3 by: [Fadi Badine](https://twitter.com/fadibadine)
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to create a model to classify speakers from the
+frequency domain representation of speech recordings, obtained via Fast Fourier
+Transform (FFT).
+
+It shows the following:
+
+- How to use `tf.data` to load, preprocess and feed audio streams into a model
+- How to create a 1D convolutional network with residual
+connections for audio classification.
+
+Our process:
+
+- We prepare a dataset of speech samples from different speakers, with the speaker as label.
+- We add background noise to these samples to augment our data.
+- We take the FFT of these samples.
+- We train a 1D convnet to predict the correct speaker given a noisy FFT speech sample.
+
+Note:
+
+- This example should be run with TensorFlow 2.3 or higher, or `tf-nightly`.
+- The noise samples in the dataset need to be resampled to a sampling rate of 16000 Hz
+before using the code in this example. In order to do this, you will need to have
+`ffmpeg` installed.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import shutil
+import numpy as np
+
+import tensorflow as tf
+import keras
+
+from pathlib import Path
+from IPython.display import display, Audio
+
+# Get the data from https://www.kaggle.com/kongaevans/speaker-recognition-dataset/
+# and save it to ./speaker-recognition-dataset.zip
+# then unzip it to ./16000_pcm_speeches
+"""shell
+kaggle datasets download -d kongaevans/speaker-recognition-dataset
+unzip -qq speaker-recognition-dataset.zip
+"""
+
+DATASET_ROOT = "16000_pcm_speeches"
+
+# The folders in which we will put the audio samples and the noise samples
+AUDIO_SUBFOLDER = "audio"
+NOISE_SUBFOLDER = "noise"
+
+DATASET_AUDIO_PATH = os.path.join(DATASET_ROOT, AUDIO_SUBFOLDER)
+DATASET_NOISE_PATH = os.path.join(DATASET_ROOT, NOISE_SUBFOLDER)
+
+# Percentage of samples to use for validation
+VALID_SPLIT = 0.1
+
+# Seed to use when shuffling the dataset and the noise
+SHUFFLE_SEED = 43
+
+# The sampling rate to use.
+# This is the one used in all the audio samples.
+# We will resample all the noise to this sampling rate.
+# This will also be the output size of the audio wave samples
+# (since all samples are of 1 second long)
+SAMPLING_RATE = 16000
+
+# The factor to multiply the noise with according to:
+# noisy_sample = sample + noise * prop * scale
+# where prop = sample_amplitude / noise_amplitude
+SCALE = 0.5
+
+BATCH_SIZE = 128
+EPOCHS = 1 # For a real training run, use EPOCHS = 100
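+
+"""
+As a toy illustration of the noise-scaling formula above (not part of the
+training pipeline): if the sample peaks at 0.8 and the noise peaks at 0.2,
+then `prop = 4.0`, so the rescaled noise peaks at `0.8 * SCALE`.
+"""
+
+_toy_sample = np.array([0.8, -0.4, 0.2], dtype=np.float32)
+_toy_noise = np.array([0.2, 0.1, -0.05], dtype=np.float32)
+_prop = _toy_sample.max() / _toy_noise.max()
+print(_toy_sample + _toy_noise * _prop * SCALE)  # noise rescaled, then damped by SCALE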
+
+
+"""
+## Data preparation
+
+The dataset is composed of 7 folders, divided into 2 groups:
+
+- Speech samples, with 5 folders for 5 different speakers. Each folder contains
+1500 audio files, each 1 second long and sampled at 16000 Hz.
+- Background noise samples, with 2 folders and a total of 6 files. These files
+are longer than 1 second (and originally not sampled at 16000 Hz, but we will resample them to 16000 Hz).
+We will use those 6 files to create 354 1-second-long noise samples to be used for training.
+
+Let's sort these 2 categories into 2 folders:
+
+- An `audio` folder which will contain all the per-speaker speech sample folders
+- A `noise` folder which will contain all the noise samples
+"""
+
+"""
+Before sorting the audio and noise categories into 2 folders,
+we have the following directory structure:
+
+```
+main_directory/
+...speaker_a/
+...speaker_b/
+...speaker_c/
+...speaker_d/
+...speaker_e/
+...other/
+..._background_noise_/
+```
+
+After sorting, we end up with the following structure:
+
+```
+main_directory/
+...audio/
+......speaker_a/
+......speaker_b/
+......speaker_c/
+......speaker_d/
+......speaker_e/
+...noise/
+......other/
+......_background_noise_/
+```
+"""
+
+for folder in os.listdir(DATASET_ROOT):
+ if os.path.isdir(os.path.join(DATASET_ROOT, folder)):
+ if folder in [AUDIO_SUBFOLDER, NOISE_SUBFOLDER]:
+ # If folder is `audio` or `noise`, do nothing
+ continue
+ elif folder in ["other", "_background_noise_"]:
+ # If folder is one of the folders that contains noise samples,
+ # move it to the `noise` folder
+ shutil.move(
+ os.path.join(DATASET_ROOT, folder),
+ os.path.join(DATASET_NOISE_PATH, folder),
+ )
+ else:
+ # Otherwise, it should be a speaker folder, then move it to
+ # `audio` folder
+ shutil.move(
+ os.path.join(DATASET_ROOT, folder),
+ os.path.join(DATASET_AUDIO_PATH, folder),
+ )
+
+"""
+## Noise preparation
+
+In this section:
+
+- We load all noise samples (which should have been resampled to 16000 Hz)
+- We split those noise samples into chunks of 16000 samples, each
+corresponding to 1 second of audio
+"""
+
+# Get the list of all noise files
+noise_paths = []
+for subdir in os.listdir(DATASET_NOISE_PATH):
+ subdir_path = Path(DATASET_NOISE_PATH) / subdir
+ if os.path.isdir(subdir_path):
+ noise_paths += [
+ os.path.join(subdir_path, filepath)
+ for filepath in os.listdir(subdir_path)
+ if filepath.endswith(".wav")
+ ]
+if not noise_paths:
+ raise RuntimeError(f"Could not find any files at {DATASET_NOISE_PATH}")
+print(
+ "Found {} files belonging to {} directories".format(
+ len(noise_paths), len(os.listdir(DATASET_NOISE_PATH))
+ )
+)
+
+"""
+Resample all noise samples to 16000 Hz
+"""
+
+command = (
+ "for dir in `ls -1 " + DATASET_NOISE_PATH + "`; do "
+ "for file in `ls -1 " + DATASET_NOISE_PATH + "/$dir/*.wav`; do "
+ "sample_rate=`ffprobe -hide_banner -loglevel panic -show_streams "
+ "$file | grep sample_rate | cut -f2 -d=`; "
+ "if [ $sample_rate -ne 16000 ]; then "
+ "ffmpeg -hide_banner -loglevel panic -y "
+ "-i $file -ar 16000 temp.wav; "
+ "mv temp.wav $file; "
+ "fi; done; done"
+)
+os.system(command)
+
+
+# Split noise into chunks of 16,000 steps each
+def load_noise_sample(path):
+ sample, sampling_rate = tf.audio.decode_wav(
+ tf.io.read_file(path), desired_channels=1
+ )
+ if sampling_rate == SAMPLING_RATE:
+ # Number of slices of 16000 each that can be generated from the noise sample
+ slices = int(sample.shape[0] / SAMPLING_RATE)
+ sample = tf.split(sample[: slices * SAMPLING_RATE], slices)
+ return sample
+ else:
+ print("Sampling rate for {} is incorrect. Ignoring it".format(path))
+ return None
+
+
+noises = []
+for path in noise_paths:
+ sample = load_noise_sample(path)
+ if sample:
+ noises.extend(sample)
+noises = tf.stack(noises)
+
+print(
+ "{} noise files were split into {} noise samples where each is {} sec. long".format(
+ len(noise_paths), noises.shape[0], noises.shape[1] // SAMPLING_RATE
+ )
+)
+
+"""
+## Dataset generation
+"""
+
+
+def paths_and_labels_to_dataset(audio_paths, labels):
+ """Constructs a dataset of audios and labels."""
+ path_ds = tf.data.Dataset.from_tensor_slices(audio_paths)
+ audio_ds = path_ds.map(
+ lambda x: path_to_audio(x), num_parallel_calls=tf.data.AUTOTUNE
+ )
+ label_ds = tf.data.Dataset.from_tensor_slices(labels)
+ return tf.data.Dataset.zip((audio_ds, label_ds))
+
+
+def path_to_audio(path):
+ """Reads and decodes an audio file."""
+ audio = tf.io.read_file(path)
+ audio, _ = tf.audio.decode_wav(audio, 1, SAMPLING_RATE)
+ return audio
+
+
+def add_noise(audio, noises=None, scale=0.5):
+ if noises is not None:
+ # Create a random tensor of the same size as audio ranging from
+ # 0 to the number of noise stream samples that we have.
+ tf_rnd = tf.random.uniform(
+ (tf.shape(audio)[0],), 0, noises.shape[0], dtype=tf.int32
+ )
+ noise = tf.gather(noises, tf_rnd, axis=0)
+
+ # Get the amplitude proportion between the audio and the noise
+ prop = tf.math.reduce_max(audio, axis=1) / tf.math.reduce_max(noise, axis=1)
+ prop = tf.repeat(tf.expand_dims(prop, axis=1), tf.shape(audio)[1], axis=1)
+
+ # Adding the rescaled noise to audio
+ audio = audio + noise * prop * scale
+
+ return audio
+
+
+def audio_to_fft(audio):
+ # Since tf.signal.fft applies FFT on the innermost dimension,
+ # we need to squeeze the dimensions and then expand them again
+ # after FFT
+ audio = tf.squeeze(audio, axis=-1)
+ fft = tf.signal.fft(
+ tf.cast(tf.complex(real=audio, imag=tf.zeros_like(audio)), tf.complex64)
+ )
+ fft = tf.expand_dims(fft, axis=-1)
+
+ # Return the absolute value of the first half of the FFT
+ # which represents the positive frequencies
+ return tf.math.abs(fft[:, : (audio.shape[1] // 2), :])
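+
+
+"""
+As a quick sanity check on `audio_to_fft` (illustrative, not part of the
+pipeline), a pure 440 Hz tone sampled for one second should produce its peak
+magnitude at frequency bin 440, since a 1-second window gives 1 Hz bins:
+"""
+
+_t = np.arange(SAMPLING_RATE, dtype=np.float32) / SAMPLING_RATE
+_tone = np.sin(2 * np.pi * 440.0 * _t)[None, :, None]  # shape (1, 16000, 1)
+print("Peak frequency bin:", int(tf.argmax(audio_to_fft(tf.constant(_tone))[0, :, 0])))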
+
+
+# Get the list of audio file paths along with their corresponding labels
+
+class_names = os.listdir(DATASET_AUDIO_PATH)
+print(
+ "Our class names: {}".format(
+ class_names,
+ )
+)
+
+audio_paths = []
+labels = []
+for label, name in enumerate(class_names):
+ print(
+ "Processing speaker {}".format(
+ name,
+ )
+ )
+ dir_path = Path(DATASET_AUDIO_PATH) / name
+ speaker_sample_paths = [
+ os.path.join(dir_path, filepath)
+ for filepath in os.listdir(dir_path)
+ if filepath.endswith(".wav")
+ ]
+ audio_paths += speaker_sample_paths
+ labels += [label] * len(speaker_sample_paths)
+
+print(
+ "Found {} files belonging to {} classes.".format(len(audio_paths), len(class_names))
+)
+
+# Shuffle the paths and labels in unison by reseeding the RNG with the same seed
+rng = np.random.RandomState(SHUFFLE_SEED)
+rng.shuffle(audio_paths)
+rng = np.random.RandomState(SHUFFLE_SEED)
+rng.shuffle(labels)
+
+# Split into training and validation
+num_val_samples = int(VALID_SPLIT * len(audio_paths))
+print("Using {} files for training.".format(len(audio_paths) - num_val_samples))
+train_audio_paths = audio_paths[:-num_val_samples]
+train_labels = labels[:-num_val_samples]
+
+print("Using {} files for validation.".format(num_val_samples))
+valid_audio_paths = audio_paths[-num_val_samples:]
+valid_labels = labels[-num_val_samples:]
+
+# Create 2 datasets, one for training and the other for validation
+train_ds = paths_and_labels_to_dataset(train_audio_paths, train_labels)
+train_ds = train_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(
+ BATCH_SIZE
+)
+
+valid_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)
+valid_ds = valid_ds.shuffle(buffer_size=32 * 8, seed=SHUFFLE_SEED).batch(32)
+
+
+# Add noise to the training set
+train_ds = train_ds.map(
+ lambda x, y: (add_noise(x, noises, scale=SCALE), y),
+ num_parallel_calls=tf.data.AUTOTUNE,
+)
+
+# Transform audio wave to the frequency domain using `audio_to_fft`
+train_ds = train_ds.map(
+ lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE
+)
+train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
+
+valid_ds = valid_ds.map(
+ lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE
+)
+valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE)
+
+"""
+## Model Definition
+"""
+
+
+def residual_block(x, filters, conv_num=3, activation="relu"):
+ # Shortcut
+ s = keras.layers.Conv1D(filters, 1, padding="same")(x)
+ for i in range(conv_num - 1):
+ x = keras.layers.Conv1D(filters, 3, padding="same")(x)
+ x = keras.layers.Activation(activation)(x)
+ x = keras.layers.Conv1D(filters, 3, padding="same")(x)
+ x = keras.layers.Add()([x, s])
+ x = keras.layers.Activation(activation)(x)
+ return keras.layers.MaxPool1D(pool_size=2, strides=2)(x)
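+
+
+"""
+For intuition (an illustrative shape check, not part of the pipeline): the
+convolutions use `same` padding, the 1x1 shortcut convolution matches the
+channel count, and the final max pooling halves the time dimension:
+"""
+
+print(residual_block(keras.layers.Input(shape=(8000, 1)), 16, 2).shape)  # (None, 4000, 16)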
+
+
+def build_model(input_shape, num_classes):
+ inputs = keras.layers.Input(shape=input_shape, name="input")
+
+ x = residual_block(inputs, 16, 2)
+ x = residual_block(x, 32, 2)
+ x = residual_block(x, 64, 3)
+ x = residual_block(x, 128, 3)
+ x = residual_block(x, 128, 3)
+
+ x = keras.layers.AveragePooling1D(pool_size=3, strides=3)(x)
+ x = keras.layers.Flatten()(x)
+ x = keras.layers.Dense(256, activation="relu")(x)
+ x = keras.layers.Dense(128, activation="relu")(x)
+
+ outputs = keras.layers.Dense(num_classes, activation="softmax", name="output")(x)
+
+ return keras.models.Model(inputs=inputs, outputs=outputs)
+
+
+model = build_model((SAMPLING_RATE // 2, 1), len(class_names))
+
+model.summary()
+
+# Compile the model using Adam's default learning rate
+model.compile(
+ optimizer="Adam",
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+
+# Add callbacks:
+# 'EarlyStopping' to stop training when the model stops improving
+# 'ModelCheckpoint' to always keep the model that has the best val_accuracy
+model_save_filename = "model.keras"
+
+earlystopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
+mdlcheckpoint_cb = keras.callbacks.ModelCheckpoint(
+ model_save_filename, monitor="val_accuracy", save_best_only=True
+)
+
+"""
+## Training
+"""
+
+history = model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=valid_ds,
+ callbacks=[earlystopping_cb, mdlcheckpoint_cb],
+)
+
+"""
+## Evaluation
+"""
+
+print(model.evaluate(valid_ds))
+
+"""
+We get ~98% validation accuracy.
+"""
+
+"""
+## Demonstration
+
+Let's take some samples and:
+
+- Predict the speaker
+- Compare the prediction with the real speaker
+- Listen to the audio to see that despite the samples being noisy,
+the model is still pretty accurate
+"""
+
+SAMPLES_TO_DISPLAY = 10
+
+test_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)
+test_ds = test_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(
+ BATCH_SIZE
+)
+
+test_ds = test_ds.map(
+ lambda x, y: (add_noise(x, noises, scale=SCALE), y),
+ num_parallel_calls=tf.data.AUTOTUNE,
+)
+
+for audios, labels in test_ds.take(1):
+ # Get the signal FFT
+ ffts = audio_to_fft(audios)
+ # Predict
+ y_pred = model.predict(ffts)
+ # Take random samples
+ rnd = np.random.randint(0, BATCH_SIZE, SAMPLES_TO_DISPLAY)
+ audios = audios.numpy()[rnd, :, :]
+ labels = labels.numpy()[rnd]
+ y_pred = np.argmax(y_pred, axis=-1)[rnd]
+
+ for index in range(SAMPLES_TO_DISPLAY):
+        # For every sample, print the true and predicted label
+        # as well as play back the noisy audio
+ print(
+ "Speaker:\33{} {}\33[0m\tPredicted:\33{} {}\33[0m".format(
+ "[92m" if labels[index] == y_pred[index] else "[91m",
+ class_names[labels[index]],
+ "[92m" if labels[index] == y_pred[index] else "[91m",
+ class_names[y_pred[index]],
+ )
+ )
+ display(Audio(audios[index, :, :].squeeze(), rate=SAMPLING_RATE))
diff --git a/knowledge_base/audio/stft.py b/knowledge_base/audio/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6f74a9ad908dff08ccd34287340c8a4f7900a3e
--- /dev/null
+++ b/knowledge_base/audio/stft.py
@@ -0,0 +1,409 @@
+"""
+Title: Audio Classification with the STFTSpectrogram layer
+Author: [Mostafa M. Amin](https://mostafa-amin.com)
+Date created: 2024/10/04
+Last modified: 2024/10/04
+Description: Introducing the `STFTSpectrogram` layer to extract spectrograms for audio classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Preprocessing audio as spectrograms is an essential step in the vast majority
+of audio-based applications. Spectrograms represent the frequency content of a
+signal over time and are widely used for this purpose. In this tutorial, we'll
+demonstrate how to use the `STFTSpectrogram` layer in Keras to convert raw
+audio waveforms into spectrograms **within the model**. We'll then feed
+these spectrograms into convolutional layers followed by dense layers to perform
+audio classification on the ESC-10 dataset.
+
+We will:
+
+- Load the ESC-10 dataset.
+- Preprocess the raw audio waveforms and generate spectrograms using
+ `STFTSpectrogram`.
+- Build two models, one using the spectrograms as 1D signals and the other using
+  them as images (2D signals) with a pretrained image model.
+- Train and evaluate the models.
+
+## Setup
+
+### Importing the necessary libraries
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+import keras
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scipy.io.wavfile
+from keras import layers
+from scipy.signal import resample
+
+keras.utils.set_random_seed(41)
+
+"""
+### Define some variables
+"""
+
+BASE_DATA_DIR = "./datasets/esc-50_extracted/ESC-50-master/"
+BATCH_SIZE = 16
+NUM_CLASSES = 10
+EPOCHS = 200
+SAMPLE_RATE = 16000
+
+"""
+## Download and Preprocess the ESC-10 Dataset
+
+We'll use the ESC-10 subset of the Environmental Sound Classification (ESC-50)
+dataset. It consists of five-second `.wav` recordings of environmental sounds.
+
+### Download and Extract the dataset
+"""
+
+keras.utils.get_file(
+ "esc-50.zip",
+ "https://github.com/karoldvl/ESC-50/archive/master.zip",
+ cache_dir="./",
+ cache_subdir="datasets",
+ extract=True,
+)
+
+"""
+### Read the CSV file
+"""
+
+pd_data = pd.read_csv(os.path.join(BASE_DATA_DIR, "meta", "esc50.csv"))
+# filter ESC-50 to ESC-10 and reassign the targets
+pd_data = pd_data[pd_data["esc10"]]
+targets = sorted(pd_data["target"].unique().tolist())
+assert len(targets) == NUM_CLASSES
+old_target_to_new_target = {old: new for new, old in enumerate(targets)}
+pd_data["target"] = pd_data["target"].map(lambda t: old_target_to_new_target[t])
+pd_data
+
+"""
+### Define functions to read and preprocess the WAV files
+"""
+
+
+def read_wav_file(path, target_sr=SAMPLE_RATE):
+ sr, wav = scipy.io.wavfile.read(os.path.join(BASE_DATA_DIR, "audio", path))
+ wav = wav.astype(np.float32) / 32768.0 # normalize to [-1, 1]
+ num_samples = int(len(wav) * target_sr / sr) # resample to 16 kHz
+ wav = resample(wav, num_samples)
+ return wav[:, None] # Add a channel dimension (of size 1)
+
+
+"""
+Create a function that uses the `STFTSpectrogram` to compute a spectrogram,
+then plots it.
+"""
+
+
+def plot_single_spectrogram(sample_wav_data):
+ spectrogram = layers.STFTSpectrogram(
+ mode="log",
+        frame_length=SAMPLE_RATE * 20 // 1000,  # 20 ms window
+        frame_step=SAMPLE_RATE * 5 // 1000,  # 5 ms hop
+ fft_length=1024,
+ trainable=False,
+ )(sample_wav_data[None, ...])[0, ...]
+
+ # Plot the spectrogram
+ plt.imshow(spectrogram.T, origin="lower")
+ plt.title("Single Channel Spectrogram")
+ plt.xlabel("Time")
+ plt.ylabel("Frequency")
+ plt.show()
+
+
+"""
+Create a function that uses the `STFTSpectrogram` to compute three
+spectrograms with multiple bandwidths, then aligns them as an image
+with different channels to get a multi-bandwidth spectrogram,
+and plots the result.
+"""
+
+
+def plot_multi_bandwidth_spectrogram(sample_wav_data):
+ # All spectrograms must use the same `fft_length`, `frame_step`, and
+ # `padding="same"` in order to produce spectrograms with identical shapes,
+ # hence aligning them together. `expand_dims` ensures that the shapes are
+ # compatible with image models.
+
+ spectrograms = np.concatenate(
+ [
+ layers.STFTSpectrogram(
+ mode="log",
+ frame_length=SAMPLE_RATE * x // 1000,
+ frame_step=SAMPLE_RATE * 5 // 1000,
+ fft_length=1024,
+ padding="same",
+ expand_dims=True,
+ )(sample_wav_data[None, ...])[0, ...]
+ for x in [5, 10, 20]
+ ],
+ axis=-1,
+ ).transpose([1, 0, 2])
+
+ # normalize each color channel for better viewing
+ mn = spectrograms.min(axis=(0, 1), keepdims=True)
+ mx = spectrograms.max(axis=(0, 1), keepdims=True)
+ spectrograms = (spectrograms - mn) / (mx - mn)
+
+ plt.imshow(spectrograms, origin="lower")
+ plt.title("Multi-bandwidth Spectrogram")
+ plt.xlabel("Time")
+ plt.ylabel("Frequency")
+ plt.show()
+
+
+"""
+Let's visualize a sample wav file.
+"""
+
+sample_wav_data = read_wav_file(pd_data["filename"].tolist()[52])
+plt.plot(sample_wav_data[:, 0])
+plt.show()
+
+"""
+Plot a Spectrogram
+"""
+
+plot_single_spectrogram(sample_wav_data)
+
+"""
+Plot a multi-bandwidth spectrogram
+"""
+
+plot_multi_bandwidth_spectrogram(sample_wav_data)
+
+"""
+### Define functions to construct a TF Dataset
+"""
+
+
+def read_dataset(df, folds):
+ msk = df["fold"].isin(folds)
+ filenames = df["filename"][msk]
+ targets = df["target"][msk].values
+ waves = np.array([read_wav_file(fil) for fil in filenames], dtype=np.float32)
+ return waves, targets
+
+
+"""
+### Create the datasets
+"""
+
+train_x, train_y = read_dataset(pd_data, [1, 2, 3])
+valid_x, valid_y = read_dataset(pd_data, [4])
+test_x, test_y = read_dataset(pd_data, [5])
+
+"""
+## Training the Models
+
+In this tutorial we demonstrate the different use cases of the `STFTSpectrogram`
+layer.
+
+The first model will use a non-trainable `STFTSpectrogram` layer, so it is
+intended purely for preprocessing. Additionally, the model will operate on 1D
+signals, hence it makes use of `Conv1D` layers.
+
+The second model will use a trainable `STFTSpectrogram` layer with the
+`expand_dims` option, which expands the shapes to be compatible with image
+models.
+
+### Create the 1D model
+
+1. Create a non-trainable spectrogram layer, extracting a 1D time signal.
+2. Apply `Conv1D` layers with `LayerNormalization`, similar to the
+   classic VGG design.
+3. Apply global max pooling to obtain a fixed-size set of features.
+4. Add `Dense` layers to make the final predictions based on the features.
+"""
+
+model1d = keras.Sequential(
+ [
+ layers.InputLayer((None, 1)),
+ layers.STFTSpectrogram(
+ mode="log",
+            frame_length=SAMPLE_RATE * 40 // 1000,  # 40 ms window
+            frame_step=SAMPLE_RATE * 15 // 1000,  # 15 ms hop
+ trainable=False,
+ ),
+ layers.Conv1D(64, 64, activation="relu"),
+ layers.Conv1D(128, 16, activation="relu"),
+ layers.LayerNormalization(),
+ layers.MaxPooling1D(4),
+ layers.Conv1D(128, 8, activation="relu"),
+ layers.Conv1D(256, 8, activation="relu"),
+ layers.Conv1D(512, 4, activation="relu"),
+ layers.LayerNormalization(),
+ layers.Dropout(0.5),
+ layers.GlobalMaxPooling1D(),
+ layers.Dense(256, activation="relu"),
+ layers.Dense(256, activation="relu"),
+ layers.Dropout(0.5),
+ layers.Dense(NUM_CLASSES, activation="softmax"),
+ ],
+ name="model_1d_non_trainble_stft",
+)
+model1d.compile(
+ optimizer=keras.optimizers.Adam(1e-5),
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+model1d.summary()
+
+"""
+Train the model and restore the best weights.
+"""
+
+history_model1d = model1d.fit(
+ train_x,
+ train_y,
+ batch_size=BATCH_SIZE,
+ validation_data=(valid_x, valid_y),
+ epochs=EPOCHS,
+ callbacks=[
+ keras.callbacks.EarlyStopping(
+ monitor="val_loss",
+ patience=EPOCHS,
+ restore_best_weights=True,
+ )
+ ],
+)
+
+"""
+### Create the 2D model
+
+1. Create three spectrograms with multiple bandwidths from the raw input.
+2. Concatenate the three spectrograms to form three channels.
+3. Load `MobileNet` with weights pretrained on `ImageNet`.
+4. Apply global max pooling to obtain a fixed-size set of features.
+5. Add `Dense` layers to make the final predictions based on the features.
+"""
+
+input = layers.Input((None, 1))
+spectrograms = [
+ layers.STFTSpectrogram(
+ mode="log",
+ frame_length=SAMPLE_RATE * frame_size // 1000,
+ frame_step=SAMPLE_RATE * 15 // 1000,
+ fft_length=2048,
+ padding="same",
+ expand_dims=True,
+ # trainable=True, # trainable by default
+ )(input)
+ for frame_size in [30, 40, 50] # frame size in milliseconds
+]
+
+multi_spectrograms = layers.Concatenate(axis=-1)(spectrograms)
+
+img_model = keras.applications.MobileNet(include_top=False, pooling="max")
+output = img_model(multi_spectrograms)
+
+output = layers.Dropout(0.5)(output)
+output = layers.Dense(256, activation="relu")(output)
+output = layers.Dense(256, activation="relu")(output)
+output = layers.Dense(NUM_CLASSES, activation="softmax")(output)
+model2d = keras.Model(input, output, name="model_2d_trainable_stft")
+
+model2d.compile(
+ optimizer=keras.optimizers.Adam(1e-4),
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+model2d.summary()
+
+"""
+Train the model and restore the best weights.
+"""
+
+history_model2d = model2d.fit(
+ train_x,
+ train_y,
+ batch_size=BATCH_SIZE,
+ validation_data=(valid_x, valid_y),
+ epochs=EPOCHS,
+ callbacks=[
+ keras.callbacks.EarlyStopping(
+ monitor="val_loss",
+ patience=EPOCHS,
+ restore_best_weights=True,
+ )
+ ],
+)
+
+"""
+### Plot Training History
+"""
+
+epochs_range = range(EPOCHS)
+
+plt.figure(figsize=(14, 5))
+plt.subplot(1, 2, 1)
+plt.plot(
+ epochs_range,
+ history_model1d.history["accuracy"],
+ label="Training Accuracy,1D model with non-trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model1d.history["val_accuracy"],
+ label="Validation Accuracy, 1D model with non-trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model2d.history["accuracy"],
+ label="Training Accuracy, 2D model with trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model2d.history["val_accuracy"],
+ label="Validation Accuracy, 2D model with trainable STFT",
+)
+plt.legend(loc="lower right")
+plt.title("Training and Validation Accuracy")
+
+plt.subplot(1, 2, 2)
+plt.plot(
+ epochs_range,
+ history_model1d.history["loss"],
+ label="Training Loss,1D model with non-trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model1d.history["val_loss"],
+ label="Validation Loss, 1D model with non-trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model2d.history["loss"],
+ label="Training Loss, 2D model with trainable STFT",
+)
+plt.plot(
+ epochs_range,
+ history_model2d.history["val_loss"],
+ label="Validation Loss, 2D model with trainable STFT",
+)
+plt.legend(loc="upper right")
+plt.title("Training and Validation Loss")
+plt.show()
+
+"""
+### Evaluate on Test Data
+
+Running the models on the test set.
+"""
+
+_, test_acc = model1d.evaluate(test_x, test_y)
+print(f"1D model wit non-trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%")
+
+_, test_acc = model2d.evaluate(test_x, test_y)
+print(f"2D model with trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%")
diff --git a/knowledge_base/audio/transformer_asr.py b/knowledge_base/audio/transformer_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b661c73c83abdf1b0c5d734b7e386405511d315c
--- /dev/null
+++ b/knowledge_base/audio/transformer_asr.py
@@ -0,0 +1,542 @@
+"""
+Title: Automatic Speech Recognition with Transformer
+Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
+Date created: 2021/01/13
+Last modified: 2021/01/13
+Description: Training a sequence-to-sequence Transformer for automatic speech recognition.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.
+ASR can be treated as a sequence-to-sequence problem, where the
+audio can be represented as a sequence of feature vectors
+and the text as a sequence of characters, words, or subword tokens.
+
+For this demonstration, we will use the LJSpeech dataset from the
+[LibriVox](https://librivox.org/) project. It consists of short
+audio clips of a single speaker reading passages from 7 non-fiction books.
+Our model will be similar to the original Transformer (both encoder and decoder)
+as proposed in the paper, "Attention is All You Need".
+
+
+**References:**
+
+- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
+- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)
+- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)
+- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
+"""
+
+import re
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+from glob import glob
+import tensorflow as tf
+import keras
+from keras import layers
+
+
+"""
+## Define the Transformer Input Layer
+
+When processing past target tokens for the decoder, we compute the sum of
+position embeddings and token embeddings.
+
+When processing audio features, we apply convolutional layers to downsample
+them (via convolution strides) and process local relationships.
+"""
+
+
+class TokenEmbedding(layers.Layer):
+ def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
+ super().__init__()
+ self.emb = keras.layers.Embedding(num_vocab, num_hid)
+ self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)
+
+ def call(self, x):
+ maxlen = tf.shape(x)[-1]
+ x = self.emb(x)
+ positions = tf.range(start=0, limit=maxlen, delta=1)
+ positions = self.pos_emb(positions)
+ return x + positions
+
+
+class SpeechFeatureEmbedding(layers.Layer):
+ def __init__(self, num_hid=64, maxlen=100):
+ super().__init__()
+ self.conv1 = keras.layers.Conv1D(
+ num_hid, 11, strides=2, padding="same", activation="relu"
+ )
+ self.conv2 = keras.layers.Conv1D(
+ num_hid, 11, strides=2, padding="same", activation="relu"
+ )
+ self.conv3 = keras.layers.Conv1D(
+ num_hid, 11, strides=2, padding="same", activation="relu"
+ )
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return self.conv3(x)
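+
+
+"""
+Note that with three stride-2 convolutions, the time dimension of the audio
+features is downsampled by a factor of 8. A quick shape check (illustrative;
+the 2754 frames and 129 frequency bins match the STFT configuration used
+later in this example):
+"""
+
+print(SpeechFeatureEmbedding(num_hid=64)(tf.zeros((1, 2754, 129))).shape)  # (1, 345, 64)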
+
+
+"""
+## Transformer Encoder Layer
+"""
+
+
+class TransformerEncoder(layers.Layer):
+ def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
+ super().__init__()
+ self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(feed_forward_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.dropout1 = layers.Dropout(rate)
+ self.dropout2 = layers.Dropout(rate)
+
+ def call(self, inputs, training=False):
+ attn_output = self.att(inputs, inputs)
+ attn_output = self.dropout1(attn_output, training=training)
+ out1 = self.layernorm1(inputs + attn_output)
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output, training=training)
+ return self.layernorm2(out1 + ffn_output)
+
+
+"""
+## Transformer Decoder Layer
+"""
+
+
+class TransformerDecoder(layers.Layer):
+ def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
+ super().__init__()
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
+ self.self_att = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
+ self.self_dropout = layers.Dropout(0.5)
+ self.enc_dropout = layers.Dropout(0.1)
+ self.ffn_dropout = layers.Dropout(0.1)
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(feed_forward_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+
+ def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
+ """Masks the upper half of the dot product matrix in self attention.
+
+ This prevents flow of information from future tokens to current token.
+ 1's in the lower triangle, counting from the lower right corner.
+ """
+ i = tf.range(n_dest)[:, None]
+ j = tf.range(n_src)
+ m = i >= j - n_src + n_dest
+ mask = tf.cast(m, dtype)
+ mask = tf.reshape(mask, [1, n_dest, n_src])
+ mult = tf.concat(
+ [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
+ )
+ return tf.tile(mask, mult)
+
+ def call(self, enc_out, target):
+ input_shape = tf.shape(target)
+ batch_size = input_shape[0]
+ seq_len = input_shape[1]
+ causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
+ target_att = self.self_att(target, target, attention_mask=causal_mask)
+ target_norm = self.layernorm1(target + self.self_dropout(target_att))
+ enc_out = self.enc_att(target_norm, enc_out)
+ enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
+ ffn_out = self.ffn(enc_out_norm)
+ ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
+ return ffn_out_norm
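+
+
+"""
+To visualize the causal mask (illustrative): rows are destination positions
+and columns are source positions, so position `i` may only attend to
+positions `j <= i`:
+"""
+
+_dec = TransformerDecoder(embed_dim=8, num_heads=2, feed_forward_dim=16)
+print(_dec.causal_attention_mask(tf.constant(1), 3, 3, tf.int32))  # lower-triangular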
+
+
+"""
+## Complete the Transformer model
+
+Our model takes audio spectrograms as inputs and predicts a sequence of characters.
+During training, we give the decoder the target character sequence shifted to the left
+as input. During inference, the decoder uses its own past predictions to predict the
+next token.
+"""
+
+
+class Transformer(keras.Model):
+ def __init__(
+ self,
+ num_hid=64,
+ num_head=2,
+ num_feed_forward=128,
+ source_maxlen=100,
+ target_maxlen=100,
+ num_layers_enc=4,
+ num_layers_dec=1,
+ num_classes=10,
+ ):
+ super().__init__()
+ self.loss_metric = keras.metrics.Mean(name="loss")
+ self.num_layers_enc = num_layers_enc
+ self.num_layers_dec = num_layers_dec
+ self.target_maxlen = target_maxlen
+ self.num_classes = num_classes
+
+ self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)
+ self.dec_input = TokenEmbedding(
+ num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
+ )
+
+ self.encoder = keras.Sequential(
+ [self.enc_input]
+ + [
+ TransformerEncoder(num_hid, num_head, num_feed_forward)
+ for _ in range(num_layers_enc)
+ ]
+ )
+
+ for i in range(num_layers_dec):
+ setattr(
+ self,
+ f"dec_layer_{i}",
+ TransformerDecoder(num_hid, num_head, num_feed_forward),
+ )
+
+ self.classifier = layers.Dense(num_classes)
+
+ def decode(self, enc_out, target):
+ y = self.dec_input(target)
+ for i in range(self.num_layers_dec):
+ y = getattr(self, f"dec_layer_{i}")(enc_out, y)
+ return y
+
+ def call(self, inputs):
+ source = inputs[0]
+ target = inputs[1]
+ x = self.encoder(source)
+ y = self.decode(x, target)
+ return self.classifier(y)
+
+ @property
+ def metrics(self):
+ return [self.loss_metric]
+
+ def train_step(self, batch):
+ """Processes one batch inside model.fit()."""
+ source = batch["source"]
+ target = batch["target"]
+ dec_input = target[:, :-1]
+ dec_target = target[:, 1:]
+ with tf.GradientTape() as tape:
+ preds = self([source, dec_input])
+ one_hot = tf.one_hot(dec_target, depth=self.num_classes)
+            # Mask out padding tokens (index 0) so they don't contribute to the loss
+            mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
+ loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
+ trainable_vars = self.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+ self.loss_metric.update_state(loss)
+ return {"loss": self.loss_metric.result()}
+
+ def test_step(self, batch):
+ source = batch["source"]
+ target = batch["target"]
+ dec_input = target[:, :-1]
+ dec_target = target[:, 1:]
+ preds = self([source, dec_input])
+ one_hot = tf.one_hot(dec_target, depth=self.num_classes)
+ mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
+ loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
+ self.loss_metric.update_state(loss)
+ return {"loss": self.loss_metric.result()}
+
+ def generate(self, source, target_start_token_idx):
+ """Performs inference over one batch of inputs using greedy decoding."""
+ bs = tf.shape(source)[0]
+ enc = self.encoder(source)
+ dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx
+ dec_logits = []
+ for i in range(self.target_maxlen - 1):
+ dec_out = self.decode(enc, dec_input)
+ logits = self.classifier(dec_out)
+ logits = tf.argmax(logits, axis=-1, output_type=tf.int32)
+ last_logit = tf.expand_dims(logits[:, -1], axis=-1)
+ dec_logits.append(last_logit)
+ dec_input = tf.concat([dec_input, last_logit], axis=-1)
+ return dec_input
+
+
+"""
+## Download the dataset
+
+Note: This requires ~3.6 GB of disk space and
+takes ~5 minutes for the extraction of files.
+"""
+
+pattern_wav_name = re.compile(r"([^/\\\.]+)")
+
+keras.utils.get_file(
+ os.path.join(os.getcwd(), "data.tar.gz"),
+ "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
+ extract=True,
+ archive_format="tar",
+ cache_dir=".",
+)
+
+
+saveto = "./datasets/LJSpeech-1.1"
+wavs = glob("{}/**/*.wav".format(saveto), recursive=True)
+
+id_to_text = {}
+with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:
+ for line in f:
+ id = line.strip().split("|")[0]
+ text = line.strip().split("|")[2]
+ id_to_text[id] = text
+
+
+def get_data(wavs, id_to_text, maxlen=50):
+    """Returns a list of {audio path, transcription text} pairs."""
+    data = []
+    for w in wavs:
+        # Extract the file stem, e.g. "LJ001-0001" from ".../LJ001-0001.wav"
+        id = pattern_wav_name.split(w)[-4]
+        if len(id_to_text[id]) < maxlen:
+            data.append({"audio": w, "text": id_to_text[id]})
+    return data
+
+
+"""
+## Preprocess the dataset
+"""
+
+
+class VectorizeChar:
+ def __init__(self, max_len=50):
+ self.vocab = (
+ ["-", "#", "<", ">"]
+ + [chr(i + 96) for i in range(1, 27)]
+ + [" ", ".", ",", "?"]
+ )
+ self.max_len = max_len
+ self.char_to_idx = {}
+ for i, ch in enumerate(self.vocab):
+ self.char_to_idx[ch] = i
+
+ def __call__(self, text):
+ text = text.lower()
+ text = text[: self.max_len - 2]
+ text = "<" + text + ">"
+ pad_len = self.max_len - len(text)
+ return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len
+
+ def get_vocabulary(self):
+ return self.vocab
+
+
+max_target_len = 200  # all transcripts in our data are < 200 characters
+data = get_data(wavs, id_to_text, max_target_len)
+vectorizer = VectorizeChar(max_target_len)
+print("vocab size", len(vectorizer.get_vocabulary()))
+
+
+def create_text_ds(data):
+ texts = [_["text"] for _ in data]
+ text_ds = [vectorizer(t) for t in texts]
+ text_ds = tf.data.Dataset.from_tensor_slices(text_ds)
+ return text_ds
+
+
+def path_to_audio(path):
+ # spectrogram using stft
+ audio = tf.io.read_file(path)
+ audio, _ = tf.audio.decode_wav(audio, 1)
+ audio = tf.squeeze(audio, axis=-1)
+ stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)
+ x = tf.math.pow(tf.abs(stfts), 0.5)
+ # normalisation
+ means = tf.math.reduce_mean(x, 1, keepdims=True)
+ stddevs = tf.math.reduce_std(x, 1, keepdims=True)
+ x = (x - means) / stddevs
+ audio_len = tf.shape(x)[0]
+    # pad/truncate to 10 seconds: at LJSpeech's native 22,050 Hz sample rate,
+    # that is (220500 - 200) // 80 + 1 ~= 2754 STFT frames
+    pad_len = 2754
+ paddings = tf.constant([[0, pad_len], [0, 0]])
+ x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]
+ return x
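+
+
+"""
+A quick shape check (illustrative): every clip is padded or truncated to
+exactly 2754 STFT frames with 129 frequency bins (`fft_length // 2 + 1`):
+"""
+
+print(path_to_audio(tf.constant(wavs[0])).shape)  # (2754, 129)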
+
+
+def create_audio_ds(data):
+ flist = [_["audio"] for _ in data]
+ audio_ds = tf.data.Dataset.from_tensor_slices(flist)
+ audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)
+ return audio_ds
+
+
+def create_tf_dataset(data, bs=4):
+ audio_ds = create_audio_ds(data)
+ text_ds = create_text_ds(data)
+ ds = tf.data.Dataset.zip((audio_ds, text_ds))
+ ds = ds.map(lambda x, y: {"source": x, "target": y})
+ ds = ds.batch(bs)
+ ds = ds.prefetch(tf.data.AUTOTUNE)
+ return ds
+
+
+split = int(len(data) * 0.99)
+train_data = data[:split]
+test_data = data[split:]
+ds = create_tf_dataset(train_data, bs=64)
+val_ds = create_tf_dataset(test_data, bs=4)
+
+"""
+## Callbacks to display predictions
+"""
+
+
+class DisplayOutputs(keras.callbacks.Callback):
+ def __init__(
+ self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28
+ ):
+ """Displays a batch of outputs after every epoch
+
+ Args:
+ batch: A test batch containing the keys "source" and "target"
+ idx_to_token: A List containing the vocabulary tokens corresponding to their indices
+ target_start_token_idx: A start token index in the target vocabulary
+ target_end_token_idx: An end token index in the target vocabulary
+ """
+ self.batch = batch
+ self.target_start_token_idx = target_start_token_idx
+ self.target_end_token_idx = target_end_token_idx
+ self.idx_to_char = idx_to_token
+
+ def on_epoch_end(self, epoch, logs=None):
+ if epoch % 5 != 0:
+ return
+ source = self.batch["source"]
+ target = self.batch["target"].numpy()
+ bs = tf.shape(source)[0]
+ preds = self.model.generate(source, self.target_start_token_idx)
+ preds = preds.numpy()
+ for i in range(bs):
+ target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])
+ prediction = ""
+ for idx in preds[i, :]:
+ prediction += self.idx_to_char[idx]
+ if idx == self.target_end_token_idx:
+ break
+ print(f"target: {target_text.replace('-','')}")
+ print(f"prediction: {prediction}\n")
+
+
+"""
+## Learning rate schedule
+"""
+
+
+class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(
+ self,
+ init_lr=0.00001,
+ lr_after_warmup=0.001,
+ final_lr=0.00001,
+ warmup_epochs=15,
+ decay_epochs=85,
+ steps_per_epoch=203,
+ ):
+ super().__init__()
+ self.init_lr = init_lr
+ self.lr_after_warmup = lr_after_warmup
+ self.final_lr = final_lr
+ self.warmup_epochs = warmup_epochs
+ self.decay_epochs = decay_epochs
+ self.steps_per_epoch = steps_per_epoch
+
+ def calculate_lr(self, epoch):
+ """linear warm up - linear decay"""
+ warmup_lr = (
+ self.init_lr
+ + ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch
+ )
+ decay_lr = tf.math.maximum(
+ self.final_lr,
+ self.lr_after_warmup
+ - (epoch - self.warmup_epochs)
+ * (self.lr_after_warmup - self.final_lr)
+ / self.decay_epochs,
+ )
+ return tf.math.minimum(warmup_lr, decay_lr)
+
+ def __call__(self, step):
+ epoch = step // self.steps_per_epoch
+ epoch = tf.cast(epoch, "float32")
+ return self.calculate_lr(epoch)
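+
+
+"""
+A quick way to sanity-check the schedule (illustrative): print the learning
+rate at a few epochs and verify the linear warmup to `lr_after_warmup`
+followed by the linear decay towards `final_lr`:
+"""
+
+_schedule = CustomSchedule()
+for _epoch in (0, 7, 15, 50, 100):
+    print(_epoch, float(_schedule(_epoch * 203)))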
+
+
+"""
+## Create & train the end-to-end model
+"""
+
+batch = next(iter(val_ds))
+
+# The vocabulary to convert predicted indices into characters
+idx_to_char = vectorizer.get_vocabulary()
+display_cb = DisplayOutputs(
+ batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3
+) # set the arguments as per vocabulary index for '<' and '>'
+
+model = Transformer(
+ num_hid=200,
+ num_head=2,
+ num_feed_forward=400,
+ target_maxlen=max_target_len,
+ num_layers_enc=4,
+ num_layers_dec=1,
+ num_classes=34,
+)
+loss_fn = keras.losses.CategoricalCrossentropy(
+ from_logits=True,
+ label_smoothing=0.1,
+)
+
+learning_rate = CustomSchedule(
+ init_lr=0.00001,
+ lr_after_warmup=0.001,
+ final_lr=0.00001,
+ warmup_epochs=15,
+ decay_epochs=85,
+ steps_per_epoch=len(ds),
+)
+optimizer = keras.optimizers.Adam(learning_rate)
+model.compile(optimizer=optimizer, loss=loss_fn)
+
+history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
+
+"""
+In practice, you should train for around 100 epochs or more.
+
+Some of the predicted text at or around epoch 35 may look as follows:
+```
+target:
+prediction:
+
+target:
+prediction:
+```
+"""
diff --git a/knowledge_base/audio/uk_ireland_accent_recognition.py b/knowledge_base/audio/uk_ireland_accent_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..283b0a4edc671efd23a8c3eda42e82f3f1c33b1b
--- /dev/null
+++ b/knowledge_base/audio/uk_ireland_accent_recognition.py
@@ -0,0 +1,696 @@
+"""
+Title: English speaker accent recognition using Transfer Learning
+Author: [Fadi Badine](https://twitter.com/fadibadine)
+Date created: 2022/04/16
+Last modified: 2022/04/16
+Description: Training a model to classify UK & Ireland accents using feature extraction from Yamnet.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The following example shows how to use feature extraction in order to
+train a model to classify the English accent spoken in an audio wave.
+
+Instead of training a model from scratch, transfer learning enables us to
+take advantage of existing state-of-the-art deep learning models and use them as feature extractors.
+
+Our process:
+
+* Use a TF Hub pre-trained model (Yamnet) and apply it as part of the tf.data pipeline which transforms
+the audio files into feature vectors.
+* Train a dense model on the feature vectors.
+* Use the trained model for inference on a new audio file.
+
+Note:
+
+* We need to install TensorFlow IO in order to resample audio files to 16 kHz as required by the Yamnet model.
+* In the test section, ffmpeg is used to convert the mp3 file to wav.
+
+You can install TensorFlow IO with the following command:
+"""
+
+"""shell
+pip install -U -q tensorflow_io
+"""
+
+"""
+## Configuration
+"""
+
+SEED = 1337
+EPOCHS = 100
+BATCH_SIZE = 64
+VALIDATION_RATIO = 0.1
+MODEL_NAME = "uk_irish_accent_recognition"
+
+# Location where the dataset will be downloaded.
+# By default (None), keras.utils.get_file will use ~/.keras/ as the CACHE_DIR
+CACHE_DIR = None
+
+# The location of the dataset
+URL_PATH = "https://www.openslr.org/resources/83/"
+
+# List of datasets compressed files that contain the audio files
+zip_files = {
+ 0: "irish_english_male.zip",
+ 1: "midlands_english_female.zip",
+ 2: "midlands_english_male.zip",
+ 3: "northern_english_female.zip",
+ 4: "northern_english_male.zip",
+ 5: "scottish_english_female.zip",
+ 6: "scottish_english_male.zip",
+ 7: "southern_english_female.zip",
+ 8: "southern_english_male.zip",
+ 9: "welsh_english_female.zip",
+ 10: "welsh_english_male.zip",
+}
+
+# We see that there are 2 compressed files for each accent (except Irish):
+# - One for male speakers
+# - One for female speakers
+# However, we will be using a gender agnostic dataset.
+
+# List of gender agnostic categories
+gender_agnostic_categories = [
+ "ir", # Irish
+ "mi", # Midlands
+ "no", # Northern
+ "sc", # Scottish
+ "so", # Southern
+ "we", # Welsh
+]
+
+class_names = [
+ "Irish",
+ "Midlands",
+ "Northern",
+ "Scottish",
+ "Southern",
+ "Welsh",
+ "Not a speech",
+]
+
+"""
+## Imports
+"""
+
+import os
+import io
+import csv
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_io as tfio
+from tensorflow import keras
+import matplotlib.pyplot as plt
+import seaborn as sns
+from scipy import stats
+from IPython.display import Audio
+
+
+# Set all random seeds in order to get reproducible results
+keras.utils.set_random_seed(SEED)
+
+# Where to download the dataset
+DATASET_DESTINATION = os.path.join(CACHE_DIR if CACHE_DIR else "~/.keras/", "datasets")
+
+"""
+## Yamnet Model
+
+Yamnet is an audio event classifier trained on the AudioSet dataset to predict audio
+events from the AudioSet ontology. It is available on TensorFlow Hub.
+
+Yamnet accepts a 1-D tensor of audio samples with a sample rate of 16 kHz.
+As output, the model returns a 3-tuple:
+
+* Scores of shape `(N, 521)` representing the scores of the 521 classes.
+* Embeddings of shape `(N, 1024)`.
+* The log-mel spectrogram of the entire audio frame.
+
+We will use the embeddings, which are the features extracted from the audio samples, as the input to our dense model.
+
+For more detailed information about Yamnet, please refer to its [TensorFlow Hub](https://tfhub.dev/google/yamnet/1) page.
+"""
+
+yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
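+
+"""
+As a quick shape check (illustrative), running Yamnet on one second of
+silence shows the per-frame scores and embeddings described above:
+"""
+
+_scores, _embeddings, _log_mel = yamnet_model(tf.zeros(16000, dtype=tf.float32))
+print(_scores.shape, _embeddings.shape, _log_mel.shape)  # (N, 521) (N, 1024) ...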
+
+"""
+## Dataset
+
+The dataset used is the
+[Crowdsourced high-quality UK and Ireland English Dialect speech data set](https://openslr.org/83/)
+which consists of a total of 17,877 high-quality audio wav files.
+
+This dataset includes over 31 hours of recording from 120 volunteers who self-identify as
+native speakers of Southern England, Midlands, Northern England, Wales, Scotland and Ireland.
+
+For more info, please refer to the above link or to the following paper:
+[Open-source Multi-speaker Corpora of the English Accents in the British Isles](https://aclanthology.org/2020.lrec-1.804.pdf)
+"""
+
+"""
+## Download the data
+"""
+
+# CSV file that contains information about the dataset. For each entry, we have:
+# - ID
+# - wav file name
+# - transcript
+line_index_file = keras.utils.get_file(
+ fname="line_index_file", origin=URL_PATH + "line_index_all.csv"
+)
+
+# Download the list of compressed files that contain the audio wav files
+for i in zip_files:
+ fname = zip_files[i].split(".")[0]
+ url = URL_PATH + zip_files[i]
+
+ zip_file = keras.utils.get_file(fname=fname, origin=url, extract=True)
+ os.remove(zip_file)
+
+"""
+## Load the data in a Dataframe
+
+Of the 3 columns (ID, filename and transcript), we are only interested in the filename column in order to read the audio file.
+We will ignore the other two.
+"""
+
+dataframe = pd.read_csv(
+ line_index_file, names=["id", "filename", "transcript"], usecols=["filename"]
+)
+dataframe.head()
+
+"""
+Let's now preprocess the dataset by:
+
+* Adjusting the filename (removing a leading space & adding ".wav" extension to the
+filename).
+* Creating a label using the first 2 characters of the filename which indicate the
+accent.
+* Shuffling the samples.
+"""
+
+
+# The purpose of this function is to preprocess the dataframe by applying the following:
+# - Cleaning the filename from a leading space
+# - Generating a label column that is gender agnostic i.e.
+# welsh english male and welsh english female for example are both labeled as
+# welsh english
+# - Add extension .wav to the filename
+# - Shuffle samples
+def preprocess_dataframe(dataframe):
+ # Remove leading space in filename column
+ dataframe["filename"] = dataframe.apply(lambda row: row["filename"].strip(), axis=1)
+
+ # Create gender agnostic labels based on the filename first 2 letters
+ dataframe["label"] = dataframe.apply(
+ lambda row: gender_agnostic_categories.index(row["filename"][:2]), axis=1
+ )
+
+ # Add the file path to the name
+ dataframe["filename"] = dataframe.apply(
+ lambda row: os.path.join(DATASET_DESTINATION, row["filename"] + ".wav"), axis=1
+ )
+
+ # Shuffle the samples
+ dataframe = dataframe.sample(frac=1, random_state=SEED).reset_index(drop=True)
+
+ return dataframe
+
+
+dataframe = preprocess_dataframe(dataframe)
+dataframe.head()
+
+"""
+## Prepare training & validation sets
+
+Let's split the samples creating training and validation sets.
+"""
+
+split = int(len(dataframe) * (1 - VALIDATION_RATIO))
+train_df = dataframe[:split]
+valid_df = dataframe[split:]
+
+print(
+ f"We have {train_df.shape[0]} training samples & {valid_df.shape[0]} validation ones"
+)
+
+"""
+## Prepare a TensorFlow Dataset
+
+Next, we need to create a `tf.data.Dataset`.
+This is done by creating a `dataframe_to_dataset` function that does the following:
+
+* Create a dataset using filenames and labels.
+* Get the Yamnet embeddings by calling another function `filepath_to_embeddings`.
+* Apply caching, reshuffling and setting batch size.
+
+The `filepath_to_embeddings` does the following:
+
+* Load audio file.
+* Resample audio to 16 kHz.
+* Generate scores and embeddings from Yamnet model.
+* Since Yamnet generates multiple samples for each audio file,
+this function also duplicates the label for all the generated samples that
+are scored as speech (`score=0`), while setting the label of the others to
+'other', indicating that the audio segment is not speech and we won't label it as one of the accents.
+
+The `load_16k_audio_wav` function below is copied from the following tutorial
+[Transfer learning with YAMNet for environmental sound classification](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio).
+"""
+
+
+@tf.function
+def load_16k_audio_wav(filename):
+ # Read file content
+ file_content = tf.io.read_file(filename)
+
+ # Decode audio wave
+ audio_wav, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)
+ audio_wav = tf.squeeze(audio_wav, axis=-1)
+ sample_rate = tf.cast(sample_rate, dtype=tf.int64)
+
+ # Resample to 16k
+ audio_wav = tfio.audio.resample(audio_wav, rate_in=sample_rate, rate_out=16000)
+
+ return audio_wav
+
+
+def filepath_to_embeddings(filename, label):
+ # Load 16k audio wave
+ audio_wav = load_16k_audio_wav(filename)
+
+ # Get audio embeddings & scores.
+ # The embeddings are the audio features extracted using transfer learning
+ # while scores will be used to identify time slots that are not speech
+ # which will then be gathered into a specific new category 'other'
+ scores, embeddings, _ = yamnet_model(audio_wav)
+
+ # Number of embeddings in order to know how many times to repeat the label
+ embeddings_num = tf.shape(embeddings)[0]
+ labels = tf.repeat(label, embeddings_num)
+
+ # Change labels for time-slots that are not speech into a new category 'other'
+ labels = tf.where(tf.argmax(scores, axis=1) == 0, label, len(class_names) - 1)
+
+ # Using one-hot in order to use AUC
+ return (embeddings, tf.one_hot(labels, len(class_names)))
+
+
+def dataframe_to_dataset(dataframe, batch_size=64):
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (dataframe["filename"], dataframe["label"])
+ )
+
+ dataset = dataset.map(
+ lambda x, y: filepath_to_embeddings(x, y),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
+ ).unbatch()
+
+ return dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)
+
+
+train_ds = dataframe_to_dataset(train_df)
+valid_ds = dataframe_to_dataset(valid_df)
+
+"""
+## Build the model
+
+The model that we use consists of:
+
+* An input layer which is the embedding output of the Yamnet classifier.
+* 4 dense hidden layers and 4 dropout layers.
+* An output dense layer.
+
+The model's hyperparameters were selected using
+[KerasTuner](https://keras.io/keras_tuner/).
+"""
+
+keras.backend.clear_session()
+
+
+def build_and_compile_model():
+ inputs = keras.layers.Input(shape=(1024), name="embedding")
+
+ x = keras.layers.Dense(256, activation="relu", name="dense_1")(inputs)
+ x = keras.layers.Dropout(0.15, name="dropout_1")(x)
+
+ x = keras.layers.Dense(384, activation="relu", name="dense_2")(x)
+ x = keras.layers.Dropout(0.2, name="dropout_2")(x)
+
+ x = keras.layers.Dense(192, activation="relu", name="dense_3")(x)
+ x = keras.layers.Dropout(0.25, name="dropout_3")(x)
+
+ x = keras.layers.Dense(384, activation="relu", name="dense_4")(x)
+ x = keras.layers.Dropout(0.2, name="dropout_4")(x)
+
+    outputs = keras.layers.Dense(len(class_names), activation="softmax", name="output")(
+ x
+ )
+
+ model = keras.Model(inputs=inputs, outputs=outputs, name="accent_recognition")
+
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=1.9644e-5),
+ loss=keras.losses.CategoricalCrossentropy(),
+ metrics=["accuracy", keras.metrics.AUC(name="auc")],
+ )
+
+ return model
+
+
+model = build_and_compile_model()
+model.summary()
+
+"""
+## Class weights calculation
+
+Since the dataset is quite unbalanced, we will use `class_weight` argument during training.
+
+Getting the class weights is a little tricky because even though we know the number of
+audio files for each class, it does not represent the number of samples for that class
+since Yamnet transforms each audio file into multiple audio samples of 0.96 seconds each.
+So every audio file will be split into a number of samples that is proportional to its length.
+
+Therefore, to get those weights, we have to calculate the number of samples for each class
+after preprocessing through Yamnet.
+"""
+
+class_counts = tf.zeros(shape=(len(class_names),), dtype=tf.int32)
+
+for x, y in iter(train_ds):
+ class_counts = class_counts + tf.math.bincount(
+ tf.cast(tf.math.argmax(y, axis=1), tf.int32), minlength=len(class_names)
+ )
+
+class_weight = {
+ i: tf.math.reduce_sum(class_counts).numpy() / class_counts[i].numpy()
+ for i in range(len(class_counts))
+}
+
+print(class_weight)
+
+"""
+## Callbacks
+
+We use Keras callbacks in order to:
+
+* Stop whenever the validation AUC stops improving.
+* Save the best model.
+* Call TensorBoard in order to later view the training and validation logs.
+"""
+
+early_stopping_cb = keras.callbacks.EarlyStopping(
+ monitor="val_auc", patience=10, restore_best_weights=True
+)
+
+model_checkpoint_cb = keras.callbacks.ModelCheckpoint(
+ MODEL_NAME + ".h5", monitor="val_auc", save_best_only=True
+)
+
+tensorboard_cb = keras.callbacks.TensorBoard(
+ os.path.join(os.curdir, "logs", model.name)
+)
+
+callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]
+
+"""
+## Training
+"""
+
+history = model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=valid_ds,
+ class_weight=class_weight,
+ callbacks=callbacks,
+ verbose=2,
+)
+
+"""
+## Results
+
+Let's plot the training and validation AUC and accuracy.
+"""
+
+fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))
+
+axs[0].plot(range(EPOCHS), history.history["accuracy"], label="Training")
+axs[0].plot(range(EPOCHS), history.history["val_accuracy"], label="Validation")
+axs[0].set_xlabel("Epochs")
+axs[0].set_title("Training & Validation Accuracy")
+axs[0].legend()
+axs[0].grid(True)
+
+axs[1].plot(range(EPOCHS), history.history["auc"], label="Training")
+axs[1].plot(range(EPOCHS), history.history["val_auc"], label="Validation")
+axs[1].set_xlabel("Epochs")
+axs[1].set_title("Training & Validation AUC")
+axs[1].legend()
+axs[1].grid(True)
+
+plt.show()
+
+"""
+## Evaluation
+"""
+
+train_loss, train_acc, train_auc = model.evaluate(train_ds)
+valid_loss, valid_acc, valid_auc = model.evaluate(valid_ds)
+
+"""
+Let's compare our model's performance to Yamnet's using one of Yamnet's own metrics (d-prime).
+Yamnet achieved a d-prime value of 2.318.
+Let's check our model's performance.
+"""
+
+
+# The following function calculates the d-prime score from the AUC
+def d_prime(auc):
+ standard_normal = stats.norm()
+ d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
+ return d_prime
+
+
+print(
+ "train d-prime: {0:.3f}, validation d-prime: {1:.3f}".format(
+ d_prime(train_auc), d_prime(valid_auc)
+ )
+)
+
+"""
+We can see that the model achieves the following results:
+
+Results | Training | Validation
+-----------|-----------|------------
+Accuracy | 54% | 51%
+AUC | 0.91 | 0.89
+d-prime | 1.882 | 1.740
+
+"""
+
+"""
+## Confusion Matrix
+
+Let's now plot the confusion matrix for the validation dataset.
+
+The confusion matrix lets us see, for every class, not only how many samples were correctly classified,
+but also which other classes were the samples confused with.
+
+It allows us to calculate the precision and recall for every class.
+"""
+
+# Create x and y tensors
+x_valid = None
+y_valid = None
+
+for x, y in iter(valid_ds):
+ if x_valid is None:
+ x_valid = x.numpy()
+ y_valid = y.numpy()
+ else:
+ x_valid = np.concatenate((x_valid, x.numpy()), axis=0)
+ y_valid = np.concatenate((y_valid, y.numpy()), axis=0)
+
+# Generate predictions
+y_pred = model.predict(x_valid)
+
+# Calculate confusion matrix
+confusion_mtx = tf.math.confusion_matrix(
+ np.argmax(y_valid, axis=1), np.argmax(y_pred, axis=1)
+)
+
+# Plot the confusion matrix
+plt.figure(figsize=(10, 8))
+sns.heatmap(
+ confusion_mtx, xticklabels=class_names, yticklabels=class_names, annot=True, fmt="g"
+)
+plt.xlabel("Prediction")
+plt.ylabel("Label")
+plt.title("Validation Confusion Matrix")
+plt.show()
+
+"""
+## Precision & recall
+
+For every class:
+
+* Recall is the ratio of correctly classified samples of a given class, i.e. it shows
+how many samples of this specific class the model is able to detect.
+It is the ratio of the diagonal element to the sum of all elements in its row.
+* Precision shows how accurate the classifier is when it predicts a given class, i.e.
+the ratio of correctly predicted samples among the ones classified as belonging to this class.
+It is the ratio of the diagonal element to the sum of all elements in its column.
+"""
+
+for i, label in enumerate(class_names):
+ precision = confusion_mtx[i, i] / np.sum(confusion_mtx[:, i])
+ recall = confusion_mtx[i, i] / np.sum(confusion_mtx[i, :])
+ print(
+ "{0:15} Precision:{1:.2f}%; Recall:{2:.2f}%".format(
+ label, precision * 100, recall * 100
+ )
+ )
+
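+"""
+For reference, the same per-class metrics can be computed with scikit-learn's
+`classification_report` (a minimal sketch, assuming scikit-learn is available
+in the environment):
+"""
+
+from sklearn.metrics import classification_report
+
+print(
+    classification_report(
+        np.argmax(y_valid, axis=1),
+        np.argmax(y_pred, axis=1),
+        labels=list(range(len(class_names))),
+        target_names=class_names,
+    )
+)
+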
+"""
+## Run inference on test data
+
+Let's now run the model on a single audio file, using this example from
+[The Scottish Voice](https://www.thescottishvoice.org.uk/home/).
+
+We will:
+
+* Download the mp3 file.
+* Convert it to a 16k wav file.
+* Run the model on the wav file.
+* Plot the results.
+"""
+
+filename = "audio-sample-Stuart"
+url = "https://www.thescottishvoice.org.uk/files/cm/files/"
+
+if not os.path.exists(filename + ".wav"):
+    print(f"Downloading {filename}.mp3 from {url}")
+    command = f"wget {url}{filename}.mp3"
+    os.system(command)
+
+    print("Converting mp3 to wav and resampling to 16 kHz")
+ command = (
+ f"ffmpeg -hide_banner -loglevel panic -y -i {filename}.mp3 -acodec "
+ f"pcm_s16le -ac 1 -ar 16000 {filename}.wav"
+ )
+ os.system(command)
+
+filename = filename + ".wav"
+
+
+"""
+The function below, `yamnet_class_names_from_csv`, was copied, with very slight changes,
+from this [Yamnet Notebook](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/yamnet.ipynb).
+"""
+
+
+def yamnet_class_names_from_csv(yamnet_class_map_csv_text):
+ """Returns list of class names corresponding to score vector."""
+ yamnet_class_map_csv = io.StringIO(yamnet_class_map_csv_text)
+ yamnet_class_names = [
+ name for (class_index, mid, name) in csv.reader(yamnet_class_map_csv)
+ ]
+ yamnet_class_names = yamnet_class_names[1:] # Skip CSV header
+ return yamnet_class_names
+
+
+yamnet_class_map_path = yamnet_model.class_map_path().numpy()
+yamnet_class_names = yamnet_class_names_from_csv(
+ tf.io.read_file(yamnet_class_map_path).numpy().decode("utf-8")
+)
+
+
+def calculate_number_of_non_speech(scores):
+ number_of_non_speech = tf.math.reduce_sum(
+ tf.where(tf.math.argmax(scores, axis=1, output_type=tf.int32) != 0, 1, 0)
+ )
+
+ return number_of_non_speech
+
+
+def filename_to_predictions(filename):
+ # Load 16k audio wave
+ audio_wav = load_16k_audio_wav(filename)
+
+ # Get audio embeddings & scores.
+ scores, embeddings, mel_spectrogram = yamnet_model(audio_wav)
+
+ print(
+ "Out of {} samples, {} are not speech".format(
+ scores.shape[0], calculate_number_of_non_speech(scores)
+ )
+ )
+
+ # Predict the output of the accent recognition model with embeddings as input
+ predictions = model.predict(embeddings)
+
+ return audio_wav, predictions, mel_spectrogram
+
+
+"""
+Let's run the model on the audio file:
+"""
+
+audio_wav, predictions, mel_spectrogram = filename_to_predictions(filename)
+
+inferred_class = class_names[predictions.mean(axis=0).argmax()]
+print(f"The main accent is: {inferred_class} English")
+
+"""
+Listen to the audio
+"""
+
+Audio(audio_wav, rate=16000)
+
+"""
+The function below was copied from this [Yamnet notebook](tinyurl.com/4a8xn7at) and adjusted to our needs.
+
+This function plots the following:
+
+* Audio waveform
+* Mel spectrogram
+* Predictions for every time step
+"""
+
+plt.figure(figsize=(10, 6))
+
+# Plot the waveform.
+plt.subplot(3, 1, 1)
+plt.plot(audio_wav)
+plt.xlim([0, len(audio_wav)])
+
+# Plot the log-mel spectrogram (returned by the model).
+plt.subplot(3, 1, 2)
+plt.imshow(
+ mel_spectrogram.numpy().T, aspect="auto", interpolation="nearest", origin="lower"
+)
+
+# Plot and label the model output scores for the top-scoring classes.
+mean_predictions = np.mean(predictions, axis=0)
+
+top_class_indices = np.argsort(mean_predictions)[::-1]
+plt.subplot(3, 1, 3)
+plt.imshow(
+ predictions[:, top_class_indices].T,
+ aspect="auto",
+ interpolation="nearest",
+ cmap="gray_r",
+)
+
+# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS
+# values from the model documentation
+patch_padding = (0.025 / 2) / 0.01
+plt.xlim([-patch_padding - 0.5, predictions.shape[0] + patch_padding - 0.5])
+# Label the top_N classes.
+yticks = range(0, len(class_names), 1)
+plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])
+_ = plt.ylim(-0.5 + np.array([len(class_names), 0]))
diff --git a/knowledge_base/audio/vocal_track_separation.py b/knowledge_base/audio/vocal_track_separation.py
new file mode 100644
index 0000000000000000000000000000000000000000..25574e0ab843f29024cea95439799b0764ee32e0
--- /dev/null
+++ b/knowledge_base/audio/vocal_track_separation.py
@@ -0,0 +1,673 @@
+"""
+Title: Vocal Track Separation with Encoder-Decoder Architecture
+Author: [Joaquin Jimenez](https://github.com/johacks/)
+Date created: 2024/12/10
+Last modified: 2024/12/10
+Description: Train a model to separate vocal tracks from music mixtures.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this tutorial, we build a vocal track separation model using an encoder-decoder
+architecture in Keras 3.
+
+We train the model on the [MUSDB18 dataset](https://doi.org/10.5281/zenodo.1117372),
+which provides music mixtures and isolated tracks for drums, bass, other, and vocals.
+
+Key concepts covered:
+
+- Audio data preprocessing using the Short-Time Fourier Transform (STFT).
+- Audio data augmentation techniques.
+- Implementing custom encoders and decoders specialized for audio data.
+- Defining appropriate loss functions and metrics for audio source separation tasks.
+
+The model architecture is derived from the TFC_TDF_Net model described in:
+
+W. Choi, M. Kim, J. Chung, D. Lee, and S. Jung, "Investigating U-Nets with various
+intermediate blocks for spectrogram-based singing voice separation," in the 21st
+International Society for Music Information Retrieval Conference, 2020.
+
+For reference code, see:
+[GitHub: ws-choi/ISMIR2020_U_Nets_SVS](https://github.com/ws-choi/ISMIR2020_U_Nets_SVS).
+
+The data processing and model training routines are partly derived from:
+[ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main).
+"""
+
+"""
+## Setup
+
+Install and import all the required dependencies.
+"""
+
+"""shell
+pip install -qq audiomentations soundfile ffmpeg-binaries
+pip install -qq "keras==3.7.0"
+sudo -n apt-get install -y graphviz >/dev/null 2>&1 # Required for plotting the model
+"""
+
+import glob
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
+
+import random
+import subprocess
+import tempfile
+import typing
+from os import path
+
+import audiomentations as aug
+import ffmpeg
+import keras
+import numpy as np
+import soundfile as sf
+from IPython import display
+from keras import callbacks, layers, ops, saving
+from matplotlib import pyplot as plt
+
+"""
+## Configuration
+
+The following constants define configuration parameters for audio processing
+and model training, including dataset paths, audio chunk sizes, Short-Time Fourier
+Transform (STFT) parameters, and training hyperparameters.
+"""
+
+# MUSDB18 dataset configuration
+MUSDB_STREAMS = {"mixture": 0, "drums": 1, "bass": 2, "other": 3, "vocals": 4}
+TARGET_INSTRUMENTS = {track: MUSDB_STREAMS[track] for track in ("vocals",)}
+N_INSTRUMENTS = len(TARGET_INSTRUMENTS)
+SOURCE_INSTRUMENTS = tuple(k for k in MUSDB_STREAMS if k != "mixture")
+
+# Audio preprocessing parameters for Short-Time Fourier Transform (STFT)
+N_SUBBANDS = 4 # Number of subbands into which frequencies are split
+CHUNK_SIZE = 65024 # Number of amplitude samples per audio chunk (~4 seconds)
+STFT_N_FFT = 2048 # FFT points used in STFT
+STFT_HOP_LENGTH = 512 # Hop length for STFT
+
+# Training hyperparameters
+N_CHANNELS = 64 # Base channel count for the model
+BATCH_SIZE = 3
+ACCUMULATION_STEPS = 2
+EFFECTIVE_BATCH_SIZE = BATCH_SIZE * (ACCUMULATION_STEPS or 1)
+
+# Paths
+TMP_DIR = path.expanduser("~/.keras/tmp")
+DATASET_DIR = path.expanduser("~/.keras/datasets")
+MODEL_PATH = path.join(TMP_DIR, f"model_{keras.backend.backend()}.keras")
+CSV_LOG_PATH = path.join(TMP_DIR, f"training_{keras.backend.backend()}.csv")
+os.makedirs(DATASET_DIR, exist_ok=True)
+os.makedirs(TMP_DIR, exist_ok=True)
+
+# Set random seed for reproducibility
+keras.utils.set_random_seed(21)
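+
+"""
+As a quick illustrative sanity check, the chunk duration and the approximate
+number of STFT frames per chunk follow directly from these constants:
+"""
+
+# 65024 samples at 16 kHz is roughly 4 seconds of audio.
+print("Chunk duration (s):", CHUNK_SIZE / 16000)
+# Roughly one frame per hop; the exact count depends on the STFT padding mode.
+print("Approx. STFT frames per chunk:", CHUNK_SIZE // STFT_HOP_LENGTH)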
+
+"""
+## MUSDB18 Dataset
+
+The MUSDB18 dataset is a standard benchmark for music source separation, containing
+150 full-length music tracks along with isolated drums, bass, other, and vocals.
+The dataset is stored in .mp4 format, and each .mp4 file includes multiple audio
+streams (mixture and individual tracks).
+
+### Download and Conversion
+
+The following utility function downloads MUSDB18 and converts its .mp4 files to
+.wav files for each instrument track, resampled to 16 kHz.
+"""
+
+
+def download_musdb18(out_dir=None):
+    """Download and extract the MUSDB18 dataset, then convert .mp4 files to .wav files.
+
+    MUSDB18 reference:
+    Rafii, Z., Liutkus, A., Stöter, F.-R., Mimilakis, S. I., & Bittner, R. (2017).
+    MUSDB18 - a corpus for music separation (1.0.0) [Data set]. Zenodo.
+    """
+ ffmpeg.init()
+ from ffmpeg import FFMPEG_PATH
+
+ # Create output directories
+ os.makedirs((base := out_dir or tempfile.mkdtemp()), exist_ok=True)
+ if path.exists((out_dir := path.join(base, "musdb18_wav"))):
+ print("MUSDB18 dataset already downloaded")
+ return out_dir
+
+ # Download and extract the dataset
+ download_dir = keras.utils.get_file(
+ fname="musdb18",
+ origin="https://zenodo.org/records/1117372/files/musdb18.zip",
+ extract=True,
+ )
+
+ # ffmpeg command template: input, stream index, output
+ ffmpeg_args = str(FFMPEG_PATH) + " -v error -i {} -map 0:{} -vn -ar 16000 {}"
+
+ # Convert each mp4 file to multiple .wav files for each track
+ for split in ("train", "test"):
+ songs = os.listdir(path.join(download_dir, split))
+ for i, song in enumerate(songs):
+ if i % 10 == 0:
+ print(f"{split.capitalize()}: {i}/{len(songs)} songs processed")
+
+ mp4_path_orig = path.join(download_dir, split, song)
+ mp4_path = path.join(tempfile.mkdtemp(), split, song.replace(" ", "_"))
+ os.makedirs(path.dirname(mp4_path), exist_ok=True)
+ os.rename(mp4_path_orig, mp4_path)
+
+ wav_dir = path.join(out_dir, split, path.basename(mp4_path).split(".")[0])
+ os.makedirs(wav_dir, exist_ok=True)
+
+ for track in SOURCE_INSTRUMENTS:
+ out_path = path.join(wav_dir, f"{track}.wav")
+ stream_index = MUSDB_STREAMS[track]
+ args = ffmpeg_args.format(mp4_path, stream_index, out_path).split()
+ assert subprocess.run(args).returncode == 0, "ffmpeg conversion failed"
+ return out_dir
+
+
+# Download and prepare the MUSDB18 dataset
+songs = download_musdb18(out_dir=DATASET_DIR)
+
+"""
+### Custom Dataset
+
+We define a custom dataset class to generate random audio chunks and their corresponding
+labels. The dataset does the following:
+
+1. Selects a random chunk from a random song and instrument.
+2. Applies optional data augmentations.
+3. Combines isolated tracks to form new synthetic mixtures.
+4. Prepares features (mixtures) and labels (vocals) for training.
+
+This approach allows creating an effectively infinite variety of training examples
+through randomization and augmentation.
+"""
+
+
+class Dataset(keras.utils.PyDataset):
+ def __init__(
+ self,
+ songs,
+ batch_size=BATCH_SIZE,
+ chunk_size=CHUNK_SIZE,
+ batches_per_epoch=1000 * ACCUMULATION_STEPS,
+ augmentation=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.augmentation = augmentation
+ self.vocals_augmentations = [
+ aug.PitchShift(min_semitones=-5, max_semitones=5, p=0.1),
+ aug.SevenBandParametricEQ(-9, 9, p=0.25),
+ aug.TanhDistortion(0.1, 0.7, p=0.1),
+ ]
+ self.other_augmentations = [
+ aug.PitchShift(p=0.1),
+ aug.AddGaussianNoise(p=0.1),
+ ]
+ self.songs = songs
+ self.sizes = {song: self.get_track_set_size(song) for song in self.songs}
+ self.batch_size = batch_size
+ self.chunk_size = chunk_size
+ self.batches_per_epoch = batches_per_epoch
+
+ def get_track_set_size(self, song: str):
+ """Return the smallest track length in the given song directory."""
+ sizes = [len(sf.read(p)[0]) for p in glob.glob(path.join(song, "*.wav"))]
+ if max(sizes) != min(sizes):
+ print(f"Warning: {song} has different track lengths")
+ return min(sizes)
+
+ def random_chunk_of_instrument_type(self, instrument: str):
+ """Extract a random chunk for the specified instrument from a random song."""
+ song, size = random.choice(list(self.sizes.items()))
+ track = path.join(song, f"{instrument}.wav")
+
+ if self.chunk_size <= size:
+ start = np.random.randint(size - self.chunk_size + 1)
+ audio = sf.read(track, self.chunk_size, start, dtype="float32")[0]
+ audio_mono = np.mean(audio, axis=1)
+ else:
+ # If the track is shorter than chunk_size, pad the signal
+ audio_mono = np.mean(sf.read(track, dtype="float32")[0], axis=1)
+ audio_mono = np.pad(audio_mono, ((0, self.chunk_size - size),))
+
+ # If the chunk is almost silent, retry
+ if np.mean(np.abs(audio_mono)) < 0.01:
+ return self.random_chunk_of_instrument_type(instrument)
+
+ return self.data_augmentation(audio_mono, instrument)
+
+ def data_augmentation(self, audio: np.ndarray, instrument: str):
+ """Apply data augmentation to the audio chunk, if enabled."""
+
+ def coin_flip(x, probability: float, fn: typing.Callable):
+ return fn(x) if random.uniform(0, 1) < probability else x
+
+ if self.augmentation:
+ augmentations = (
+ self.vocals_augmentations
+ if instrument == "vocals"
+ else self.other_augmentations
+ )
+ # Loudness augmentation
+ audio *= np.random.uniform(0.5, 1.5, (len(audio),)).astype("float32")
+ # Random reverse
+ audio = coin_flip(audio, 0.1, lambda x: np.flip(x))
+ # Random polarity inversion
+ audio = coin_flip(audio, 0.5, lambda x: -x)
+ # Apply selected augmentations
+ for aug_ in augmentations:
+ aug_.randomize_parameters(audio, sample_rate=16000)
+ audio = aug_(audio, sample_rate=16000)
+ return audio
+
+ def random_mix_of_tracks(self) -> dict:
+ """Create a random mix of instruments by summing their individual chunks."""
+ tracks = {}
+ for instrument in SOURCE_INSTRUMENTS:
+ # Start with a single random chunk
+ mixup = [self.random_chunk_of_instrument_type(instrument)]
+
+ # Randomly add more chunks of the same instrument (mixup augmentation)
+ if self.augmentation:
+ for p in (0.2, 0.02):
+ if random.uniform(0, 1) < p:
+ mixup.append(self.random_chunk_of_instrument_type(instrument))
+
+ tracks[instrument] = np.mean(mixup, axis=0, dtype="float32")
+ return tracks
+
+ def __len__(self):
+ return self.batches_per_epoch
+
+ def __getitem__(self, idx):
+ # Generate a batch of random mixtures
+ batch = [self.random_mix_of_tracks() for _ in range(self.batch_size)]
+
+ # Features: sum of all tracks
+ batch_x = ops.sum(
+ np.array([list(track_set.values()) for track_set in batch]), axis=1
+ )
+
+ # Labels: isolated target instruments (e.g., vocals)
+ batch_y = np.array(
+ [[track_set[t] for t in TARGET_INSTRUMENTS] for track_set in batch]
+ )
+
+ return batch_x, ops.convert_to_tensor(batch_y)
+
+
+# Create train and validation datasets
+train_ds = Dataset(glob.glob(path.join(songs, "train", "*")))
+val_ds = Dataset(
+ glob.glob(path.join(songs, "test", "*")),
+ batches_per_epoch=int(0.1 * train_ds.batches_per_epoch),
+ augmentation=False,
+)
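+
+"""
+Each batch pairs mixture waveforms (features) with the isolated target tracks
+(labels). A quick illustrative shape check makes this layout explicit:
+"""
+
+example_x, example_y = train_ds[0]
+print("Features (mixtures):", example_x.shape)  # (BATCH_SIZE, CHUNK_SIZE)
+print("Labels (targets):", example_y.shape)  # (BATCH_SIZE, N_INSTRUMENTS, CHUNK_SIZE)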
+
+"""
+### Visualize a Sample
+
+Let's visualize a random mixed audio chunk and its corresponding isolated vocals.
+This helps to understand the nature of the preprocessed input data.
+"""
+
+
+def visualize_audio_np(audio: np.ndarray, rate=16000, name="mixup"):
+ """Plot and display an audio waveform and also produce an Audio widget."""
+ plt.figure(figsize=(10, 6))
+ plt.plot(audio)
+ plt.title(f"Waveform: {name}")
+ plt.xlim(0, len(audio))
+ plt.ylabel("Amplitude")
+ plt.show()
+ # plt.savefig(f"tmp/{name}.png")
+
+ # Normalize and display audio
+ audio_norm = (audio - np.min(audio)) / (np.max(audio) - np.min(audio) + 1e-8)
+ audio_norm = (audio_norm * 2 - 1) * 0.6
+ display.display(display.Audio(audio_norm, rate=rate))
+ # sf.write(f"tmp/{name}.wav", audio_norm, rate)
+
+
+sample_batch_x, sample_batch_y = val_ds[None] # Random batch
+visualize_audio_np(ops.convert_to_numpy(sample_batch_x[0]))
+visualize_audio_np(ops.convert_to_numpy(sample_batch_y[0, 0]), name="vocals")
+
+"""
+## Model
+
+### Preprocessing
+
+The model operates on STFT representations rather than raw audio. We define a
+preprocessing model to compute STFT and a corresponding inverse transform (iSTFT).
+"""
+
+
+def stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
+ """Compute the STFT for the input audio and return the real and imaginary parts."""
+ real_x, imag_x = ops.stft(inputs, fft_size, sequence_stride, fft_size)
+ real_x, imag_x = ops.expand_dims(real_x, -1), ops.expand_dims(imag_x, -1)
+ x = ops.concatenate((real_x, imag_x), axis=-1)
+
+ # Drop last freq sample for convenience
+ return ops.split(x, [x.shape[2] - 1], axis=2)[0]
+
+
+def inverse_stft(inputs, fft_size=STFT_N_FFT, sequence_stride=STFT_HOP_LENGTH):
+ """Compute the inverse STFT for the given STFT input."""
+ x = inputs
+
+ # Pad back dropped freq sample if using torch backend
+ if keras.backend.backend() == "torch":
+ x = ops.pad(x, ((0, 0), (0, 0), (0, 1), (0, 0)))
+
+ real_x, imag_x = ops.split(x, 2, axis=-1)
+ real_x = ops.squeeze(real_x, axis=-1)
+ imag_x = ops.squeeze(imag_x, axis=-1)
+
+ return ops.istft((real_x, imag_x), fft_size, sequence_stride, fft_size)
+
+
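+"""
+As an illustrative shape check: for a batch of chunks, `stft` produces a tensor
+of shape `(batch, frames, STFT_N_FFT // 2, 2)`, where the last axis holds the
+real and imaginary parts:
+"""
+
+print(stft(ops.convert_to_tensor(np.zeros((1, CHUNK_SIZE), "float32"))).shape)
+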
+"""
+### Model Architecture
+
+The model uses a custom encoder-decoder architecture with Time-Frequency Convolution
+(TFC) and Time-Distributed Fully Connected (TDF) blocks. They are grouped into a
+`TimeFrequencyTransformBlock`, i.e. "TFC_TDF" in the original paper by Choi et al.
+
+We then define an encoder-decoder network with multiple scales. Each encoder scale
+applies TFC_TDF blocks followed by downsampling, while decoder scales apply TFC_TDF
+blocks over the concatenation of upsampled features and associated encoder outputs.
+"""
+
+
+@saving.register_keras_serializable()
+class TimeDistributedDenseBlock(layers.Layer):
+ """Time-Distributed Fully Connected layer block.
+
+ Applies frequency-wise dense transformations across time frames with instance
+ normalization and GELU activation.
+ """
+
+ def __init__(self, bottleneck_factor, fft_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.fft_dim = fft_dim
+ self.hidden_dim = fft_dim // bottleneck_factor
+
+ def build(self, *_):
+ self.group_norm_1 = layers.GroupNormalization(groups=-1)
+ self.group_norm_2 = layers.GroupNormalization(groups=-1)
+ self.dense_1 = layers.Dense(self.hidden_dim, use_bias=False)
+ self.dense_2 = layers.Dense(self.fft_dim, use_bias=False)
+
+ def call(self, x):
+ # Apply normalization and dense layers frequency-wise
+ x = ops.gelu(self.group_norm_1(x))
+ x = ops.swapaxes(x, -1, -2)
+ x = self.dense_1(x)
+
+ x = ops.gelu(self.group_norm_2(ops.swapaxes(x, -1, -2)))
+ x = ops.swapaxes(x, -1, -2)
+ x = self.dense_2(x)
+ return ops.swapaxes(x, -1, -2)
+
+
+@saving.register_keras_serializable()
+class TimeFrequencyConvolution(layers.Layer):
+ """Time-Frequency Convolutional layer.
+
+ Applies a 2D convolution over time-frequency representations and applies instance
+ normalization and GELU activation.
+ """
+
+ def __init__(self, channels, **kwargs):
+ super().__init__(**kwargs)
+ self.channels = channels
+
+ def build(self, *_):
+ self.group_norm = layers.GroupNormalization(groups=-1)
+ self.conv = layers.Conv2D(self.channels, 3, padding="same", use_bias=False)
+
+ def call(self, x):
+ return self.conv(ops.gelu(self.group_norm(x)))
+
+
+@saving.register_keras_serializable()
+class TimeFrequencyTransformBlock(layers.Layer):
+ """Implements TFC_TDF block for encoder-decoder architecture.
+
+ Repeatedly apply Time-Frequency Convolution and Time-Distributed Dense blocks as
+ many times as specified by the `length` parameter.
+ """
+
+ def __init__(
+ self, channels, length, fft_dim, bottleneck_factor, in_channels=None, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.channels = channels
+ self.length = length
+ self.fft_dim = fft_dim
+ self.bottleneck_factor = bottleneck_factor
+ self.in_channels = in_channels or channels
+
+ def build(self, *_):
+ self.blocks = []
+ # Add blocks in a flat list to avoid nested structures
+ for i in range(self.length):
+ in_channels = self.channels if i > 0 else self.in_channels
+ self.blocks.append(TimeFrequencyConvolution(in_channels))
+ self.blocks.append(
+ TimeDistributedDenseBlock(self.bottleneck_factor, self.fft_dim)
+ )
+ self.blocks.append(TimeFrequencyConvolution(self.channels))
+ # Residual connection
+ self.blocks.append(layers.Conv2D(self.channels, 1, 1, use_bias=False))
+
+ def call(self, inputs):
+ x = inputs
+ # Each block consists of 4 layers:
+ # 1. Time-Frequency Convolution
+ # 2. Time-Distributed Dense
+ # 3. Time-Frequency Convolution
+ # 4. Residual connection
+ for i in range(0, len(self.blocks), 4):
+ tfc_1 = self.blocks[i](x)
+ tdf = self.blocks[i + 1](x)
+ tfc_2 = self.blocks[i + 2](tfc_1 + tdf)
+ x = tfc_2 + self.blocks[i + 3](x) # Residual connection
+ return x
+
+
+@saving.register_keras_serializable()
+class Downscale(layers.Layer):
+ """Downscale time-frequency dimensions using a convolution."""
+
+ conv_cls = layers.Conv2D
+
+ def __init__(self, channels, scale, **kwargs):
+ super().__init__(**kwargs)
+ self.channels = channels
+ self.scale = scale
+
+ def build(self, *_):
+ self.conv = self.conv_cls(self.channels, self.scale, self.scale, use_bias=False)
+ self.norm = layers.GroupNormalization(groups=-1)
+
+ def call(self, inputs):
+ return self.norm(ops.gelu(self.conv(inputs)))
+
+
+@saving.register_keras_serializable()
+class Upscale(Downscale):
+ """Upscale time-frequency dimensions using a transposed convolution."""
+
+ conv_cls = layers.Conv2DTranspose
+
+
+def build_model(
+ inputs,
+ n_instruments=N_INSTRUMENTS,
+ n_subbands=N_SUBBANDS,
+ channels=N_CHANNELS,
+ fft_dim=(STFT_N_FFT // 2) // N_SUBBANDS,
+ n_scales=4,
+ scale=(2, 2),
+ block_size=2,
+ growth=128,
+ bottleneck_factor=2,
+ **kwargs,
+):
+ """Build the TFC_TDF encoder-decoder model for source separation."""
+ # Compute STFT
+ x = stft(inputs)
+
+ # Split mixture into subbands as separate channels
+ mix = ops.reshape(x, (-1, x.shape[1], x.shape[2] // n_subbands, 2 * n_subbands))
+ first_conv_out = layers.Conv2D(channels, 1, 1, use_bias=False)(mix)
+ x = first_conv_out
+
+ # Encoder path
+ encoder_outs = []
+ for _ in range(n_scales):
+ x = TimeFrequencyTransformBlock(
+ channels, block_size, fft_dim, bottleneck_factor
+ )(x)
+ encoder_outs.append(x)
+ fft_dim, channels = fft_dim // scale[0], channels + growth
+ x = Downscale(channels, scale)(x)
+
+ # Bottleneck
+ x = TimeFrequencyTransformBlock(channels, block_size, fft_dim, bottleneck_factor)(x)
+
+ # Decoder path
+ for _ in range(n_scales):
+ fft_dim, channels = fft_dim * scale[0], channels - growth
+ x = ops.concatenate([Upscale(channels, scale)(x), encoder_outs.pop()], axis=-1)
+ x = TimeFrequencyTransformBlock(
+ channels, block_size, fft_dim, bottleneck_factor, in_channels=x.shape[-1]
+ )(x)
+
+ # Residual connection and final convolutions
+ x = ops.concatenate([mix, x * first_conv_out], axis=-1)
+ x = layers.Conv2D(channels, 1, 1, use_bias=False, activation="gelu")(x)
+ x = layers.Conv2D(n_instruments * n_subbands * 2, 1, 1, use_bias=False)(x)
+
+ # Reshape back to instrument-wise STFT
+ x = ops.reshape(x, (-1, x.shape[1], x.shape[2] * n_subbands, n_instruments, 2))
+ x = ops.transpose(x, (0, 3, 1, 2, 4))
+ x = ops.reshape(x, (-1, n_instruments, x.shape[2], x.shape[3] * 2))
+
+ return keras.Model(inputs=inputs, outputs=x, **kwargs)
+
+
+"""
+## Loss and Metrics
+
+We define:
+
+- `spectral_loss`: Mean absolute error in STFT domain.
+- `sdr`: Signal-to-Distortion Ratio, a common source separation metric.
+"""
+
+
+def prediction_to_wave(x, n_instruments=N_INSTRUMENTS):
+ """Convert STFT predictions back to waveform."""
+ x = ops.reshape(x, (-1, x.shape[2], x.shape[3] // 2, 2))
+ x = inverse_stft(x)
+ return ops.reshape(x, (-1, n_instruments, x.shape[1]))
+
+
+def target_to_stft(y):
+ """Convert target waveforms to their STFT representations."""
+ y = ops.reshape(y, (-1, CHUNK_SIZE))
+ y_real, y_imag = ops.stft(y, STFT_N_FFT, STFT_HOP_LENGTH, STFT_N_FFT)
+ y_real, y_imag = y_real[..., :-1], y_imag[..., :-1]
+ y = ops.stack([y_real, y_imag], axis=-1)
+ return ops.reshape(y, (-1, N_INSTRUMENTS, y.shape[1], y.shape[2] * 2))
+
+
+@saving.register_keras_serializable()
+def sdr(y_true, y_pred):
+ """Signal-to-Distortion Ratio metric."""
+ y_pred = prediction_to_wave(y_pred)
+ # Add epsilon for numerical stability
+ num = ops.sum(ops.square(y_true), axis=-1) + 1e-8
+ den = ops.sum(ops.square(y_true - y_pred), axis=-1) + 1e-8
+ return 10 * ops.log10(num / den)
+
+
+@saving.register_keras_serializable()
+def spectral_loss(y_true, y_pred):
+ """Mean absolute error in the STFT domain."""
+ y_true = target_to_stft(y_true)
+ return ops.mean(ops.absolute(y_true - y_pred))
+
+
+"""
+## Training
+
+### Visualize Model Architecture
+"""
+
+# Load or create the model
+if path.exists(MODEL_PATH):
+ model = saving.load_model(MODEL_PATH)
+else:
+ model = build_model(keras.Input(sample_batch_x.shape[1:]), name="tfc_tdf_net")
+
+# Display the model architecture
+model.summary()
+img = keras.utils.plot_model(model, path.join(TMP_DIR, "model.png"), show_shapes=True)
+display.display(img)
+
+"""
+### Compile and Train the Model
+"""
+
+# Compile the model
+optimizer = keras.optimizers.Adam(5e-05, gradient_accumulation_steps=ACCUMULATION_STEPS)
+model.compile(optimizer=optimizer, loss=spectral_loss, metrics=[sdr])
+
+# Define callbacks
+cbs = [
+ callbacks.ModelCheckpoint(MODEL_PATH, "val_sdr", save_best_only=True, mode="max"),
+ callbacks.ReduceLROnPlateau(factor=0.95, patience=2),
+ callbacks.CSVLogger(CSV_LOG_PATH),
+]
+
+if not path.exists(MODEL_PATH):
+ model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=cbs, shuffle=False)
+else:
+ # Demonstration of a single epoch of training when model already exists
+ model.fit(train_ds, validation_data=val_ds, epochs=1, shuffle=False, verbose=2)
+
+"""
+## Evaluation
+
+Evaluate the model on the validation dataset and visualize predicted vocals.
+"""
+
+model.evaluate(val_ds, verbose=2)
+y_pred = model.predict(sample_batch_x, verbose=2)
+y_pred = prediction_to_wave(y_pred)
+visualize_audio_np(ops.convert_to_numpy(y_pred[0, 0]), name="vocals_pred")
+
+"""
+## Conclusion
+
+We built and trained a vocal track separation model using an encoder-decoder
+architecture with custom blocks applied to the MUSDB18 dataset. We demonstrated
+STFT-based preprocessing, data augmentation, and a source separation metric (SDR).
+
+**Next steps:**
+
+- Train for more epochs and refine hyperparameters.
+- Separate multiple instruments simultaneously.
+- Enhance the model to handle instruments not present in the mixture.
+"""
diff --git a/knowledge_base/generative/adain.py b/knowledge_base/generative/adain.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a0ea5539afcff37f5eae38b5974a7a56515ecf3
--- /dev/null
+++ b/knowledge_base/generative/adain.py
@@ -0,0 +1,677 @@
+"""
+Title: Neural Style Transfer with AdaIN
+Author: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+Date created: 2021/11/08
+Last modified: 2021/11/08
+Description: Neural Style Transfer with Adaptive Instance Normalization.
+Accelerator: GPU
+"""
+
+"""
+# Introduction
+
+[Neural Style Transfer](https://www.tensorflow.org/tutorials/generative/style_transfer)
+is the process of transferring the style of one image onto the content
+of another. This was first introduced in the seminal paper
+["A Neural Algorithm of Artistic Style"](https://arxiv.org/abs/1508.06576)
+by Gatys et al. A major limitation of the technique proposed in this
+work is in its runtime, as the algorithm uses a slow iterative
+optimization process.
+
+Follow-up papers that introduced
+[Batch Normalization](https://arxiv.org/abs/1502.03167),
+[Instance Normalization](https://arxiv.org/abs/1701.02096) and
+[Conditional Instance Normalization](https://arxiv.org/abs/1610.07629)
+allowed Style Transfer to be performed in new ways, no longer
+requiring a slow iterative process.
+
+Following these papers, the authors Xun Huang and Serge
+Belongie propose
+[Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) (AdaIN),
+which allows arbitrary style transfer in real time.
+
+In this example we implement Adaptive Instance Normalization
+for Neural Style Transfer. The outputs shown in this example come from an
+AdaIN model trained for only **30 epochs**.
+
+
+
+You can also try out the model with your own images with this
+[Hugging Face demo](https://huggingface.co/spaces/ariG23498/nst).
+"""
+
+"""
+# Setup
+
+We begin with importing the necessary packages. We also set the
+seed for reproducibility. The global variables are hyperparameters
+which we can change as we like.
+"""
+
+import os
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+import matplotlib.pyplot as plt
+import tensorflow_datasets as tfds
+from tensorflow.keras import layers
+
+# Defining the global variables.
+IMAGE_SIZE = (224, 224)
+BATCH_SIZE = 64
+# Training for a single epoch due to time constraints.
+# Please use at least 30 epochs to see good results.
+EPOCHS = 1
+AUTOTUNE = tf.data.AUTOTUNE
+
+"""
+## Style transfer sample gallery
+
+For Neural Style Transfer we need style images and content images. In
+this example we will use the
+[Best Artworks of All Time](https://www.kaggle.com/ikarus777/best-artworks-of-all-time)
+as our style dataset and
+[Pascal VOC](https://www.tensorflow.org/datasets/catalog/voc)
+as our content dataset.
+
+This is a deviation from the original paper implementation by the
+authors, where they use
+[WIKI-Art](https://paperswithcode.com/dataset/wikiart) as style and
+[MSCOCO](https://cocodataset.org/#home) as content datasets
+respectively. We do this to create a minimal yet reproducible example.
+
+## Downloading the dataset from Kaggle
+
+The [Best Artworks of All Time](https://www.kaggle.com/ikarus777/best-artworks-of-all-time)
+dataset is hosted on Kaggle and one can easily download it in Colab by
+following these steps:
+
+- Follow the instructions [here](https://github.com/Kaggle/kaggle-api)
+in order to obtain your Kaggle API keys in case you don't have them.
+- Use the following command to upload the Kaggle API keys.
+
+```python
+from google.colab import files
+files.upload()
+```
+
+- Use the following commands to move the API keys to the proper
+directory and download the dataset.
+
+```shell
+$ mkdir ~/.kaggle
+$ cp kaggle.json ~/.kaggle/
+$ chmod 600 ~/.kaggle/kaggle.json
+$ kaggle datasets download ikarus777/best-artworks-of-all-time
+$ unzip -qq best-artworks-of-all-time.zip
+$ rm -rf images
+$ mv resized artwork
+$ rm best-artworks-of-all-time.zip artists.csv
+```
+"""
+
+"""
+## `tf.data` pipeline
+
+In this section, we will build the `tf.data` pipeline for the project.
+For the style dataset, we decode, convert and resize the images from
+the folder. For the content images we are already presented with a
+`tf.data` dataset as we use the `tfds` module.
+
+After we have our style and content data pipeline ready, we zip the
+two together to obtain the data pipeline that our model will consume.
+"""
+
+
+def decode_and_resize(image_path):
+ """Decodes and resizes an image from the image file path.
+
+ Args:
+ image_path: The image file path.
+
+ Returns:
+ A resized image.
+ """
+ image = tf.io.read_file(image_path)
+ image = tf.image.decode_jpeg(image, channels=3)
+ image = tf.image.convert_image_dtype(image, dtype="float32")
+ image = tf.image.resize(image, IMAGE_SIZE)
+ return image
+
+
+def extract_image_from_voc(element):
+ """Extracts image from the PascalVOC dataset.
+
+ Args:
+ element: A dictionary of data.
+
+ Returns:
+ A resized image.
+ """
+ image = element["image"]
+ image = tf.image.convert_image_dtype(image, dtype="float32")
+ image = tf.image.resize(image, IMAGE_SIZE)
+ return image
+
+
+# Get the image file paths for the style images.
+style_images = os.listdir("artwork/resized")
+style_images = [os.path.join("artwork/resized", path) for path in style_images]
+
+# Split the style images into train, val and test sets.
+total_style_images = len(style_images)
+train_style = style_images[: int(0.8 * total_style_images)]
+val_style = style_images[int(0.8 * total_style_images) : int(0.9 * total_style_images)]
+test_style = style_images[int(0.9 * total_style_images) :]
+
+# Build the style and content tf.data datasets.
+train_style_ds = (
+ tf.data.Dataset.from_tensor_slices(train_style)
+ .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
+ .repeat()
+)
+train_content_ds = tfds.load("voc", split="train").map(extract_image_from_voc).repeat()
+
+val_style_ds = (
+ tf.data.Dataset.from_tensor_slices(val_style)
+ .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
+ .repeat()
+)
+val_content_ds = (
+ tfds.load("voc", split="validation").map(extract_image_from_voc).repeat()
+)
+
+test_style_ds = (
+ tf.data.Dataset.from_tensor_slices(test_style)
+ .map(decode_and_resize, num_parallel_calls=AUTOTUNE)
+ .repeat()
+)
+test_content_ds = (
+ tfds.load("voc", split="test")
+ .map(extract_image_from_voc, num_parallel_calls=AUTOTUNE)
+ .repeat()
+)
+
+# Zipping the style and content datasets.
+train_ds = (
+ tf.data.Dataset.zip((train_style_ds, train_content_ds))
+ .shuffle(BATCH_SIZE * 2)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTOTUNE)
+)
+
+val_ds = (
+ tf.data.Dataset.zip((val_style_ds, val_content_ds))
+ .shuffle(BATCH_SIZE * 2)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTOTUNE)
+)
+
+test_ds = (
+ tf.data.Dataset.zip((test_style_ds, test_content_ds))
+ .shuffle(BATCH_SIZE * 2)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTOTUNE)
+)
+
+"""
+## Visualizing the data
+
+It is always better to visualize the data before training. To ensure
+the correctness of our preprocessing pipeline, we visualize 10 samples
+from our dataset.
+"""
+
+style, content = next(iter(train_ds))
+fig, axes = plt.subplots(nrows=10, ncols=2, figsize=(5, 30))
+[ax.axis("off") for ax in np.ravel(axes)]
+
+for axis, style_image, content_image in zip(axes, style[0:10], content[0:10]):
+ (ax_style, ax_content) = axis
+ ax_style.imshow(style_image)
+ ax_style.set_title("Style Image")
+
+ ax_content.imshow(content_image)
+ ax_content.set_title("Content Image")
+
+"""
+## Architecture
+
+The style transfer network takes a content image and a style image as
+inputs and outputs the style transferred image. The authors of AdaIN
+propose a simple encoder-decoder structure for achieving this.
+
+
+
+The content image (`C`) and the style image (`S`) are both fed to the
+encoder networks. The output from these encoder networks (feature maps)
+are then fed to the AdaIN layer. The AdaIN layer computes a combined
+feature map. This feature map is then fed into a randomly initialized
+decoder network that serves as the generator for the neural style
+transferred image.
+
+
+
+The style feature map (`fs`) and the content feature map (`fc`) are
+fed to the AdaIN layer. This layer produces the combined feature map
+`t`. The function `g` represents the decoder (generator) network.
+"""
+
+"""
+### Encoder
+
+The encoder is a part of the VGG19 model, pretrained on
+[ImageNet](https://www.image-net.org/). We slice the
+model at the `block4_conv1` layer. The output layer is as suggested
+by the authors in their paper.
+"""
+
+
+def get_encoder():
+ vgg19 = keras.applications.VGG19(
+ include_top=False,
+ weights="imagenet",
+ input_shape=(*IMAGE_SIZE, 3),
+ )
+ vgg19.trainable = False
+ mini_vgg19 = keras.Model(vgg19.input, vgg19.get_layer("block4_conv1").output)
+
+ inputs = layers.Input([*IMAGE_SIZE, 3])
+ mini_vgg19_out = mini_vgg19(inputs)
+ return keras.Model(inputs, mini_vgg19_out, name="mini_vgg19")
+
+
+"""
+### Adaptive Instance Normalization
+
+The AdaIN layer takes in the features
+of the content and style image. The layer can be defined via the
+following equation:
+
+`AdaIN(fc, fs) = sigma(fs) * ((fc - mu(fc)) / sigma(fc)) + mu(fs)`
+
+where `sigma` is the standard deviation and `mu` is the mean of the
+concerned variable. In the above equation, the mean and variance of the
+content feature map `fc` are aligned with the mean and variance of the
+style feature map `fs`.
+
+It is important to note that the AdaIN layer proposed by the authors
+uses no other parameters apart from mean and variance. The layer also
+does not have any trainable parameters. This is why we use a
+*Python function* instead of using a *Keras layer*. The function takes
+style and content feature maps, computes the mean and standard deviation
+of the images and returns the adaptive instance normalized feature map.
+"""
+
+
+def get_mean_std(x, epsilon=1e-5):
+ axes = [1, 2]
+
+ # Compute the mean and standard deviation of a tensor.
+ mean, variance = tf.nn.moments(x, axes=axes, keepdims=True)
+ standard_deviation = tf.sqrt(variance + epsilon)
+ return mean, standard_deviation
+
+
+def ada_in(style, content):
+ """Computes the AdaIn feature map.
+
+ Args:
+ style: The style feature map.
+ content: The content feature map.
+
+ Returns:
+ The AdaIN feature map.
+ """
+ content_mean, content_std = get_mean_std(content)
+ style_mean, style_std = get_mean_std(style)
+ t = style_std * (content - content_mean) / content_std + style_mean
+ return t
+
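+"""
+As a quick illustrative check: AdaIN only aligns channel-wise statistics, so the
+output has the same shape as the content feature map. The toy shapes below match
+the encoder's `block4_conv1` output for 224x224 inputs:
+"""
+
+toy_content = tf.random.normal((1, 28, 28, 512))  # hypothetical content features
+toy_style = tf.random.normal((1, 28, 28, 512))  # hypothetical style features
+print(ada_in(style=toy_style, content=toy_content).shape)  # (1, 28, 28, 512)
+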
+
+"""
+### Decoder
+
+The authors specify that the decoder network must mirror the encoder
+network. We have symmetrically inverted the encoder to build our
+decoder. We have used `UpSampling2D` layers to increase the spatial
+resolution of the feature maps.
+
+Note that the authors warn against using any normalization layer
+in the decoder network, and do indeed go on to show that including
+batch normalization or instance normalization hurts the performance
+of the overall network.
+
+This is the only portion of the entire architecture that is trainable.
+"""
+
+
+def get_decoder():
+ config = {"kernel_size": 3, "strides": 1, "padding": "same", "activation": "relu"}
+ decoder = keras.Sequential(
+ [
+ layers.InputLayer((None, None, 512)),
+ layers.Conv2D(filters=512, **config),
+ layers.UpSampling2D(),
+ layers.Conv2D(filters=256, **config),
+ layers.Conv2D(filters=256, **config),
+ layers.Conv2D(filters=256, **config),
+ layers.Conv2D(filters=256, **config),
+ layers.UpSampling2D(),
+ layers.Conv2D(filters=128, **config),
+ layers.Conv2D(filters=128, **config),
+ layers.UpSampling2D(),
+ layers.Conv2D(filters=64, **config),
+ layers.Conv2D(
+ filters=3,
+ kernel_size=3,
+ strides=1,
+ padding="same",
+ activation="sigmoid",
+ ),
+ ]
+ )
+ return decoder
+
+
+"""
+### Loss functions
+
+Here we build the loss functions for the neural style transfer model.
+The authors propose to use a pretrained VGG-19 to compute the loss
+function of the network. It is important to keep in mind that this
+will be used for training only the decoder network. The total
+loss (`Lt`) is a weighted combination of content loss (`Lc`) and style
+loss (`Ls`). The `lambda` term is used to vary the amount of style
+transferred:
+
+`Lt = Lc + lambda * Ls`
+
+### Content Loss
+
+This is the Euclidean distance between the content image features
+and the features of the neural style transferred image:
+
+`Lc = || f(g(t)) - t ||_2`
+
+where `f` denotes the encoder and `g` the decoder. Here the authors propose
+to use the output of the AdaIN layer `t` as the content target, rather than
+the features of the original image. This is done to speed up convergence.
+
+### Style Loss
+
+Rather than using the more commonly used
+[Gram Matrix](https://mathworld.wolfram.com/GramMatrix.html),
+the authors propose to compute the difference between the statistical features
+(mean and variance), which makes it conceptually cleaner. This can be
+easily visualized via the following equation:
+
+`Ls = sum_i ( || mu(phi_i(g(t))) - mu(phi_i(s)) ||_2 + || sigma(phi_i(g(t))) - sigma(phi_i(s)) ||_2 )`
+
+where `phi_i` denotes the layers in VGG-19 used to compute the loss.
+In this case this corresponds to:
+
+- `block1_conv1`
+- `block2_conv1`
+- `block3_conv1`
+- `block4_conv1`
+
+"""
+
+
+def get_loss_net():
+ vgg19 = keras.applications.VGG19(
+ include_top=False, weights="imagenet", input_shape=(*IMAGE_SIZE, 3)
+ )
+ vgg19.trainable = False
+ layer_names = ["block1_conv1", "block2_conv1", "block3_conv1", "block4_conv1"]
+ outputs = [vgg19.get_layer(name).output for name in layer_names]
+ mini_vgg19 = keras.Model(vgg19.input, outputs)
+
+ inputs = layers.Input([*IMAGE_SIZE, 3])
+ mini_vgg19_out = mini_vgg19(inputs)
+ return keras.Model(inputs, mini_vgg19_out, name="loss_net")
+
+
+"""
+## Neural Style Transfer
+
+This is the trainer module. We wrap the encoder and decoder inside
+a `tf.keras.Model` subclass. This allows us to customize what happens
+in the `model.fit()` loop.
+"""
+
+
+class NeuralStyleTransfer(tf.keras.Model):
+ def __init__(self, encoder, decoder, loss_net, style_weight, **kwargs):
+ super().__init__(**kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+ self.loss_net = loss_net
+ self.style_weight = style_weight
+
+ def compile(self, optimizer, loss_fn):
+ super().compile()
+ self.optimizer = optimizer
+ self.loss_fn = loss_fn
+ self.style_loss_tracker = keras.metrics.Mean(name="style_loss")
+ self.content_loss_tracker = keras.metrics.Mean(name="content_loss")
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+
+ def train_step(self, inputs):
+ style, content = inputs
+
+ # Initialize the content and style loss.
+ loss_content = 0.0
+ loss_style = 0.0
+
+ with tf.GradientTape() as tape:
+ # Encode the style and content image.
+ style_encoded = self.encoder(style)
+ content_encoded = self.encoder(content)
+
+ # Compute the AdaIN target feature maps.
+ t = ada_in(style=style_encoded, content=content_encoded)
+
+ # Generate the neural style transferred image.
+ reconstructed_image = self.decoder(t)
+
+ # Compute the losses.
+ reconstructed_vgg_features = self.loss_net(reconstructed_image)
+ style_vgg_features = self.loss_net(style)
+ loss_content = self.loss_fn(t, reconstructed_vgg_features[-1])
+ for inp, out in zip(style_vgg_features, reconstructed_vgg_features):
+ mean_inp, std_inp = get_mean_std(inp)
+ mean_out, std_out = get_mean_std(out)
+ loss_style += self.loss_fn(mean_inp, mean_out) + self.loss_fn(
+ std_inp, std_out
+ )
+ loss_style = self.style_weight * loss_style
+ total_loss = loss_content + loss_style
+
+ # Compute gradients and optimize the decoder.
+ trainable_vars = self.decoder.trainable_variables
+ gradients = tape.gradient(total_loss, trainable_vars)
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+ # Update the trackers.
+ self.style_loss_tracker.update_state(loss_style)
+ self.content_loss_tracker.update_state(loss_content)
+ self.total_loss_tracker.update_state(total_loss)
+ return {
+ "style_loss": self.style_loss_tracker.result(),
+ "content_loss": self.content_loss_tracker.result(),
+ "total_loss": self.total_loss_tracker.result(),
+ }
+
+ def test_step(self, inputs):
+ style, content = inputs
+
+ # Initialize the content and style loss.
+ loss_content = 0.0
+ loss_style = 0.0
+
+ # Encode the style and content image.
+ style_encoded = self.encoder(style)
+ content_encoded = self.encoder(content)
+
+ # Compute the AdaIN target feature maps.
+ t = ada_in(style=style_encoded, content=content_encoded)
+
+ # Generate the neural style transferred image.
+ reconstructed_image = self.decoder(t)
+
+ # Compute the losses.
+ recons_vgg_features = self.loss_net(reconstructed_image)
+ style_vgg_features = self.loss_net(style)
+ loss_content = self.loss_fn(t, recons_vgg_features[-1])
+ for inp, out in zip(style_vgg_features, recons_vgg_features):
+ mean_inp, std_inp = get_mean_std(inp)
+ mean_out, std_out = get_mean_std(out)
+ loss_style += self.loss_fn(mean_inp, mean_out) + self.loss_fn(
+ std_inp, std_out
+ )
+ loss_style = self.style_weight * loss_style
+ total_loss = loss_content + loss_style
+
+ # Update the trackers.
+ self.style_loss_tracker.update_state(loss_style)
+ self.content_loss_tracker.update_state(loss_content)
+ self.total_loss_tracker.update_state(total_loss)
+ return {
+ "style_loss": self.style_loss_tracker.result(),
+ "content_loss": self.content_loss_tracker.result(),
+ "total_loss": self.total_loss_tracker.result(),
+ }
+
+ @property
+ def metrics(self):
+ return [
+ self.style_loss_tracker,
+ self.content_loss_tracker,
+ self.total_loss_tracker,
+ ]
+
+
+"""
+## Train Monitor callback
+
+This callback is used to visualize the style transfer output of
+the model at the end of each epoch. The objective of style transfer cannot be
+quantified properly, and is to be subjectively evaluated by an audience.
+For this reason, visualization is a key aspect of evaluating the model.
+"""
+
+test_style, test_content = next(iter(test_ds))
+
+
+class TrainMonitor(tf.keras.callbacks.Callback):
+ def on_epoch_end(self, epoch, logs=None):
+ # Encode the style and content image.
+ test_style_encoded = self.model.encoder(test_style)
+ test_content_encoded = self.model.encoder(test_content)
+
+ # Compute the AdaIN features.
+ test_t = ada_in(style=test_style_encoded, content=test_content_encoded)
+ test_reconstructed_image = self.model.decoder(test_t)
+
+ # Plot the Style, Content and the NST image.
+ fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5))
+ ax[0].imshow(tf.keras.utils.array_to_img(test_style[0]))
+ ax[0].set_title(f"Style: {epoch:03d}")
+
+ ax[1].imshow(tf.keras.utils.array_to_img(test_content[0]))
+ ax[1].set_title(f"Content: {epoch:03d}")
+
+ ax[2].imshow(tf.keras.utils.array_to_img(test_reconstructed_image[0]))
+ ax[2].set_title(f"NST: {epoch:03d}")
+
+ plt.show()
+ plt.close()
+
+
+"""
+## Train the model
+
+In this section, we define the optimizer, the loss function, and the
+trainer module. We compile the trainer module with the optimizer and
+the loss function and then train it.
+
+*Note*: We train the model for a single epoch due to time constraints,
+but it needs to be trained for at least 30 epochs to give good results.
+"""
+
+optimizer = keras.optimizers.Adam(learning_rate=1e-5)
+loss_fn = keras.losses.MeanSquaredError()
+
+encoder = get_encoder()
+loss_net = get_loss_net()
+decoder = get_decoder()
+
+model = NeuralStyleTransfer(
+ encoder=encoder, decoder=decoder, loss_net=loss_net, style_weight=4.0
+)
+
+model.compile(optimizer=optimizer, loss_fn=loss_fn)
+
+history = model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ steps_per_epoch=50,
+ validation_data=val_ds,
+ validation_steps=50,
+ callbacks=[TrainMonitor()],
+)
+
+"""
+## Inference
+
+After training the model, we can run inference with it. We will
+pass arbitrary content and style images from the test dataset and take a look at
+the output images.
+
+*NOTE*: To try out the model on your own images, you can use this
+[Hugging Face demo](https://huggingface.co/spaces/ariG23498/nst).
+"""
+
+for style, content in test_ds.take(1):
+ style_encoded = model.encoder(style)
+ content_encoded = model.encoder(content)
+ t = ada_in(style=style_encoded, content=content_encoded)
+ reconstructed_image = model.decoder(t)
+ fig, axes = plt.subplots(nrows=10, ncols=3, figsize=(10, 30))
+ [ax.axis("off") for ax in np.ravel(axes)]
+
+ for axis, style_image, content_image, reconstructed_image in zip(
+ axes, style[0:10], content[0:10], reconstructed_image[0:10]
+ ):
+ (ax_style, ax_content, ax_reconstructed) = axis
+ ax_style.imshow(style_image)
+ ax_style.set_title("Style Image")
+ ax_content.imshow(content_image)
+ ax_content.set_title("Content Image")
+ ax_reconstructed.imshow(reconstructed_image)
+ ax_reconstructed.set_title("NST Image")
+
+"""
+## Conclusion
+
+Adaptive Instance Normalization allows arbitrary style transfer in
+real time. It is also important to note that the novel proposition of
+the authors is to achieve this only by aligning the statistical
+features (mean and standard deviation) of the style and the content
+images.
+
+*Note*: AdaIN also serves as the base for
+[Style-GANs](https://arxiv.org/abs/1812.04948).
+
+## Reference
+
+- [TF implementation](https://github.com/ftokarev/tf-adain)
+
+## Acknowledgement
+
+We thank [Luke Wood](https://lukewood.xyz) for his
+detailed review.
+"""
diff --git a/knowledge_base/generative/conditional_gan.py b/knowledge_base/generative/conditional_gan.py
new file mode 100644
index 0000000000000000000000000000000000000000..810e28d0fe04513f9636fb4eec2adc5fc66d6f22
--- /dev/null
+++ b/knowledge_base/generative/conditional_gan.py
@@ -0,0 +1,336 @@
+"""
+Title: Conditional GAN
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/07/13
+Last modified: 2024/01/02
+Description: Training a GAN conditioned on class labels to generate handwritten digits.
+Accelerator: GPU
+"""
+
+"""
+Generative Adversarial Networks (GANs) let us generate novel image data, video data,
+or audio data from a random input. Typically, the random input is sampled
+from a normal distribution, before going through a series of transformations that turn
+it into something plausible (image, video, audio, etc.).
+
+However, a simple [DCGAN](https://arxiv.org/abs/1511.06434) doesn't let us control
+the appearance (e.g. class) of the samples we're generating. For instance,
+with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us
+choose the class of digits we're generating.
+To be able to control what we generate, we need to _condition_ the GAN output
+on a semantic input, such as the class of an image.
+
+In this example, we'll build a **Conditional GAN** that can generate MNIST handwritten
+digits conditioned on a given class. Such a model can have various useful applications:
+
+* let's say you are dealing with an
+[imbalanced image dataset](https://developers.google.com/machine-learning/data-prep/construct/sampling-splitting/imbalanced-data),
+and you'd like to gather more examples for the skewed class to balance the dataset.
+Data collection can be a costly process on its own. You could instead train a Conditional GAN and use
+it to generate novel images for the class that needs balancing.
+* Since the generator learns to associate the generated samples with the class labels,
+its representations can also be used for [other downstream tasks](https://arxiv.org/abs/1809.11096).
+
+Following are the references used for developing this example:
+
+* [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
+* [Lecture on Conditional Generation from Coursera](https://www.coursera.org/lecture/build-basic-generative-adversarial-networks-gans/conditional-generation-inputs-2OPrG)
+
+If you need a refresher on GANs, you can refer to the "Generative adversarial networks"
+section of
+[this resource](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/r-3/232).
+
+This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be
+installed using the following command:
+"""
+
+"""shell
+pip install -q git+https://github.com/tensorflow/docs
+"""
+
+"""
+## Imports
+"""
+
+import keras
+
+from keras import layers
+from keras import ops
+from tensorflow_docs.vis import embed
+import tensorflow as tf
+import numpy as np
+import imageio
+
+"""
+## Constants and hyperparameters
+"""
+
+batch_size = 64
+num_channels = 1
+num_classes = 10
+image_size = 28
+latent_dim = 128
+
+"""
+## Loading the MNIST dataset and preprocessing it
+"""
+
+# We'll use all the available examples from both the training and test
+# sets.
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+all_digits = np.concatenate([x_train, x_test])
+all_labels = np.concatenate([y_train, y_test])
+
+# Scale the pixel values to [0, 1] range, add a channel dimension to
+# the images, and one-hot encode the labels.
+all_digits = all_digits.astype("float32") / 255.0
+all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
+all_labels = keras.utils.to_categorical(all_labels, 10)
+
+# Create tf.data.Dataset.
+dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
+dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
+
+print(f"Shape of training images: {all_digits.shape}")
+print(f"Shape of training labels: {all_labels.shape}")
+
+"""
+## Calculating the number of input channels for the generator and discriminator
+
+In a regular (unconditional) GAN, we start by sampling noise (of some fixed
+dimension) from a normal distribution. In our case, we also need to account
+for the class labels. We will have to add the number of classes to
+the input channels of the generator (noise input) as well as the discriminator
+(generated image input).
+"""
+
+generator_in_channels = latent_dim + num_classes
+discriminator_in_channels = num_channels + num_classes
+print(generator_in_channels, discriminator_in_channels)
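+
+"""
+As a minimal illustrative sketch, here is how a single one-hot label can be
+expanded into image-shaped channels for the discriminator; the same reshaping
+logic appears in `train_step` further below:
+"""
+
+example_labels = keras.utils.to_categorical([3], num_classes)  # shape: (1, 10)
+example_maps = ops.repeat(
+    example_labels[:, :, None, None], repeats=[image_size * image_size]
+)
+example_maps = ops.reshape(example_maps, (-1, image_size, image_size, num_classes))
+print(example_maps.shape)  # (1, 28, 28, 10)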
+
+"""
+## Creating the discriminator and generator
+
+The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been
+adapted from [this example](https://keras.io/guides/customizing_what_happens_in_fit/).
+"""
+
+# Create the discriminator.
+discriminator = keras.Sequential(
+ [
+ keras.layers.InputLayer((28, 28, discriminator_in_channels)),
+ layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.GlobalMaxPooling2D(),
+ layers.Dense(1),
+ ],
+ name="discriminator",
+)
+
+# Create the generator.
+generator = keras.Sequential(
+ [
+ keras.layers.InputLayer((generator_in_channels,)),
+ # We want to generate 128 + num_classes coefficients to reshape into a
+ # 7x7x(128 + num_classes) map.
+ layers.Dense(7 * 7 * generator_in_channels),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Reshape((7, 7, generator_in_channels)),
+ layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
+ ],
+ name="generator",
+)
+
+"""
+## Creating a `ConditionalGAN` model
+"""
+
+
+class ConditionalGAN(keras.Model):
+ def __init__(self, discriminator, generator, latent_dim):
+ super().__init__()
+ self.discriminator = discriminator
+ self.generator = generator
+ self.latent_dim = latent_dim
+ self.seed_generator = keras.random.SeedGenerator(1337)
+ self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
+ self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
+
+ @property
+ def metrics(self):
+ return [self.gen_loss_tracker, self.disc_loss_tracker]
+
+ def compile(self, d_optimizer, g_optimizer, loss_fn):
+ super().compile()
+ self.d_optimizer = d_optimizer
+ self.g_optimizer = g_optimizer
+ self.loss_fn = loss_fn
+
+ def train_step(self, data):
+ # Unpack the data.
+ real_images, one_hot_labels = data
+
+ # Add dummy dimensions to the labels so that they can be concatenated with
+ # the images. This is for the discriminator.
+ image_one_hot_labels = one_hot_labels[:, :, None, None]
+ image_one_hot_labels = ops.repeat(
+ image_one_hot_labels, repeats=[image_size * image_size]
+ )
+ image_one_hot_labels = ops.reshape(
+ image_one_hot_labels, (-1, image_size, image_size, num_classes)
+ )
+
+ # Sample random points in the latent space and concatenate the labels.
+ # This is for the generator.
+ batch_size = ops.shape(real_images)[0]
+ random_latent_vectors = keras.random.normal(
+ shape=(batch_size, self.latent_dim), seed=self.seed_generator
+ )
+ random_vector_labels = ops.concatenate(
+ [random_latent_vectors, one_hot_labels], axis=1
+ )
+
+ # Decode the noise (guided by labels) to fake images.
+ generated_images = self.generator(random_vector_labels)
+
+ # Combine them with real images. Note that we are concatenating the labels
+ # with these images here.
+ fake_image_and_labels = ops.concatenate(
+ [generated_images, image_one_hot_labels], -1
+ )
+ real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
+ combined_images = ops.concatenate(
+ [fake_image_and_labels, real_image_and_labels], axis=0
+ )
+
+ # Assemble labels discriminating real from fake images.
+ labels = ops.concatenate(
+ [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
+ )
+
+ # Train the discriminator.
+ with tf.GradientTape() as tape:
+ predictions = self.discriminator(combined_images)
+ d_loss = self.loss_fn(labels, predictions)
+ grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
+ self.d_optimizer.apply_gradients(
+ zip(grads, self.discriminator.trainable_weights)
+ )
+
+ # Sample random points in the latent space.
+ random_latent_vectors = keras.random.normal(
+ shape=(batch_size, self.latent_dim), seed=self.seed_generator
+ )
+ random_vector_labels = ops.concatenate(
+ [random_latent_vectors, one_hot_labels], axis=1
+ )
+
+ # Assemble labels that say "all real images".
+ misleading_labels = ops.zeros((batch_size, 1))
+
+ # Train the generator (note that we should *not* update the weights
+ # of the discriminator)!
+ with tf.GradientTape() as tape:
+ fake_images = self.generator(random_vector_labels)
+ fake_image_and_labels = ops.concatenate(
+ [fake_images, image_one_hot_labels], -1
+ )
+ predictions = self.discriminator(fake_image_and_labels)
+ g_loss = self.loss_fn(misleading_labels, predictions)
+ grads = tape.gradient(g_loss, self.generator.trainable_weights)
+ self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
+
+ # Monitor loss.
+ self.gen_loss_tracker.update_state(g_loss)
+ self.disc_loss_tracker.update_state(d_loss)
+ return {
+ "g_loss": self.gen_loss_tracker.result(),
+ "d_loss": self.disc_loss_tracker.result(),
+ }
+
+
+"""
+## Training the Conditional GAN
+"""
+
+cond_gan = ConditionalGAN(
+ discriminator=discriminator, generator=generator, latent_dim=latent_dim
+)
+cond_gan.compile(
+ d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
+ g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
+ loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
+)
+
+cond_gan.fit(dataset, epochs=20)
+
+"""
+## Interpolating between classes with the trained generator
+"""
+
+# We first extract the trained generator from our Conditional GAN.
+trained_gen = cond_gan.generator
+
+# Choose the total number of images to generate along the interpolation,
+# including the start and end class images.
+num_interpolation = 9 # @param {type:"integer"}
+
+# Sample noise for the interpolation.
+interpolation_noise = keras.random.normal(shape=(1, latent_dim))
+interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation, axis=0)
+interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))
+
+
+def interpolate_class(first_number, second_number):
+ # Convert the start and end labels to one-hot encoded vectors.
+ first_label = keras.utils.to_categorical([first_number], num_classes)
+ second_label = keras.utils.to_categorical([second_number], num_classes)
+ first_label = ops.cast(first_label, "float32")
+ second_label = ops.cast(second_label, "float32")
+
+ # Calculate the interpolation vector between the two labels.
+ percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
+ percent_second_label = ops.cast(percent_second_label, "float32")
+ interpolation_labels = (
+ first_label * (1 - percent_second_label) + second_label * percent_second_label
+ )
+
+ # Combine the noise and the labels and run inference with the generator.
+ noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
+ fake = trained_gen.predict(noise_and_labels)
+ return fake
+
+
+start_class = 2 # @param {type:"slider", min:0, max:9, step:1}
+end_class = 6 # @param {type:"slider", min:0, max:9, step:1}
+
+fake_images = interpolate_class(start_class, end_class)
+
+"""
+Here, we first sample noise from a normal distribution, repeat it
+`num_interpolation` times, and reshape the result so that every interpolation
+step uses the same latent vector. We then blend the two one-hot labels linearly
+across the `num_interpolation` steps, so each step contains the two class
+identities in a different proportion.
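+
+As a small sketch of how the label blending behaves (NumPy, with hypothetical toy
+values: 3 steps between classes 0 and 1 out of 3 classes):
+
+```
+import numpy as np
+
+first = np.array([1.0, 0.0, 0.0])   # one-hot for the start class
+second = np.array([0.0, 1.0, 0.0])  # one-hot for the end class
+percent = np.linspace(0, 1, 3)[:, None]
+print(first * (1 - percent) + second * percent)
+# [[1.  0.  0. ]
+#  [0.5 0.5 0. ]
+#  [0.  1.  0. ]]
+```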
+"""
+
+fake_images *= 255.0
+converted_images = fake_images.astype(np.uint8)
+converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
+imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
+embed.embed_file("animation.gif")
+
+"""
+We can further improve the performance of this model with recipes like
+[WGAN-GP](https://keras.io/examples/generative/wgan_gp).
+Conditional generation is also widely used in many modern image generation architectures like
+[VQ-GANs](https://arxiv.org/abs/2012.09841), [DALL-E](https://openai.com/blog/dall-e/),
+etc.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conditional-gan) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conditional-GAN).
+"""
diff --git a/knowledge_base/generative/cyclegan.py b/knowledge_base/generative/cyclegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef76f5d9b62973527d458c1ae9de5d210880eac
--- /dev/null
+++ b/knowledge_base/generative/cyclegan.py
@@ -0,0 +1,663 @@
+"""
+Title: CycleGAN
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2020/08/12
+Last modified: 2024/09/30
+Description: Implementation of CycleGAN.
+Accelerator: GPU
+"""
+
+"""
+## CycleGAN
+
+CycleGAN is a model that aims to solve the image-to-image translation
+problem. The goal of the image-to-image translation problem is to learn the
+mapping between an input image and an output image using a training set of
+aligned image pairs. However, obtaining paired examples isn't always feasible.
+CycleGAN tries to learn this mapping without requiring paired input-output images,
+using cycle-consistent adversarial networks.
+
+- [Paper](https://arxiv.org/abs/1703.10593)
+- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
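+
+At its core, CycleGAN adds a cycle-consistency term to the usual adversarial
+losses: translating an image to the other domain and back should reproduce the
+original image. A minimal runnable sketch of that idea (NumPy, with hypothetical
+stand-in "generators" `G` and `F`):
+
+```
+import numpy as np
+
+G = lambda x: x + 1.0  # hypothetical X -> Y translator
+F = lambda y: y - 1.0  # hypothetical Y -> X translator
+
+real_x = np.random.rand(4)
+cycled_x = F(G(real_x))  # X -> Y -> X round trip
+cycle_loss = np.mean(np.abs(real_x - cycled_x))
+print(cycle_loss)  # ~0 when F inverts G, i.e. the cycle is consistent
+```
+
+In the real model, `G` and `F` are the two ResNet generators defined below, and
+this L1 round-trip penalty is weighted by `lambda_cycle`.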
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+# The backend must be set before Keras is imported for it to take effect.
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import keras
+from keras import layers, ops
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+autotune = tf.data.AUTOTUNE
+
+"""
+## Prepare the dataset
+
+In this example, we will be using the
+[horse to zebra](https://www.tensorflow.org/datasets/catalog/cycle_gan#cycle_ganhorse2zebra)
+dataset.
+"""
+
+# Load the horse-zebra dataset using tensorflow-datasets.
+dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
+train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
+test_horses, test_zebras = dataset["testA"], dataset["testB"]
+
+# Define the standard image size.
+orig_img_size = (286, 286)
+# Size of the random crops to be used during training.
+input_img_size = (256, 256, 3)
+# Weights initializer for the layers.
+kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
+# Gamma initializer for instance normalization.
+gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
+
+buffer_size = 256
+batch_size = 1
+
+
+def normalize_img(img):
+ img = ops.cast(img, dtype=tf.float32)
+    # Map values to the range [-1, 1]
+ return (img / 127.5) - 1.0
+
+
+def preprocess_train_image(img, label):
+ # Random flip
+ img = tf.image.random_flip_left_right(img)
+ # Resize to the original size first
+ img = ops.image.resize(img, [*orig_img_size])
+    # Random crop to 256x256
+ img = tf.image.random_crop(img, size=[*input_img_size])
+ # Normalize the pixel values in the range [-1, 1]
+ img = normalize_img(img)
+ return img
+
+
+def preprocess_test_image(img, label):
+ # Only resizing and normalization for the test images.
+ img = ops.image.resize(img, [input_img_size[0], input_img_size[1]])
+ img = normalize_img(img)
+ return img
+
+
+"""
+## Create `Dataset` objects
+"""
+
+
+# Apply the preprocessing operations to the training data
+train_horses = (
+ train_horses.map(preprocess_train_image, num_parallel_calls=autotune)
+ .cache()
+ .shuffle(buffer_size)
+ .batch(batch_size)
+)
+train_zebras = (
+ train_zebras.map(preprocess_train_image, num_parallel_calls=autotune)
+ .cache()
+ .shuffle(buffer_size)
+ .batch(batch_size)
+)
+
+# Apply the preprocessing operations to the test data
+test_horses = (
+ test_horses.map(preprocess_test_image, num_parallel_calls=autotune)
+ .cache()
+ .shuffle(buffer_size)
+ .batch(batch_size)
+)
+test_zebras = (
+ test_zebras.map(preprocess_test_image, num_parallel_calls=autotune)
+ .cache()
+ .shuffle(buffer_size)
+ .batch(batch_size)
+)
+
+
+"""
+## Visualize some samples
+"""
+
+
+_, ax = plt.subplots(4, 2, figsize=(10, 15))
+for i, samples in enumerate(zip(train_horses.take(4), train_zebras.take(4))):
+ horse = (((samples[0][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
+ zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
+ ax[i, 0].imshow(horse)
+ ax[i, 1].imshow(zebra)
+plt.show()
+
+
+"""
+## Building blocks used in the CycleGAN generators and discriminators
+"""
+
+
+class ReflectionPadding2D(layers.Layer):
+ """Implements Reflection Padding as a layer.
+
+ Args:
+ padding(tuple): Amount of padding for the
+ spatial dimensions.
+
+ Returns:
+ A padded tensor with the same type as the input tensor.
+ """
+
+ def __init__(self, padding=(1, 1), **kwargs):
+ self.padding = tuple(padding)
+ super().__init__(**kwargs)
+
+ def call(self, input_tensor, mask=None):
+ padding_width, padding_height = self.padding
+ padding_tensor = [
+ [0, 0],
+ [padding_height, padding_height],
+ [padding_width, padding_width],
+ [0, 0],
+ ]
+ return ops.pad(input_tensor, padding_tensor, mode="REFLECT")
+
+
+def residual_block(
+ x,
+ activation,
+ kernel_initializer=kernel_init,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding="valid",
+ gamma_initializer=gamma_init,
+ use_bias=False,
+):
+ dim = x.shape[-1]
+ input_tensor = x
+
+ x = ReflectionPadding2D()(input_tensor)
+ x = layers.Conv2D(
+ dim,
+ kernel_size,
+ strides=strides,
+ kernel_initializer=kernel_initializer,
+ padding=padding,
+ use_bias=use_bias,
+ )(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
+ x = activation(x)
+
+ x = ReflectionPadding2D()(x)
+ x = layers.Conv2D(
+ dim,
+ kernel_size,
+ strides=strides,
+ kernel_initializer=kernel_initializer,
+ padding=padding,
+ use_bias=use_bias,
+ )(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
+ x = layers.add([input_tensor, x])
+ return x
+
+
+def downsample(
+ x,
+ filters,
+ activation,
+ kernel_initializer=kernel_init,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding="same",
+ gamma_initializer=gamma_init,
+ use_bias=False,
+):
+ x = layers.Conv2D(
+ filters,
+ kernel_size,
+ strides=strides,
+ kernel_initializer=kernel_initializer,
+ padding=padding,
+ use_bias=use_bias,
+ )(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
+ if activation:
+ x = activation(x)
+ return x
+
+
+def upsample(
+ x,
+ filters,
+ activation,
+ kernel_size=(3, 3),
+ strides=(2, 2),
+ padding="same",
+ kernel_initializer=kernel_init,
+ gamma_initializer=gamma_init,
+ use_bias=False,
+):
+ x = layers.Conv2DTranspose(
+ filters,
+ kernel_size,
+ strides=strides,
+ padding=padding,
+ kernel_initializer=kernel_initializer,
+ use_bias=use_bias,
+ )(x)
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
+ if activation:
+ x = activation(x)
+ return x
+
+
+"""
+## Build the generators
+
+The generator consists of downsampling blocks, nine residual blocks,
+and upsampling blocks. The structure of the generator is the following:
+
+```
+c7s1-64 ==> Conv block with `relu` activation, filter size of 7
+d128 ====|
+ |-> 2 downsampling blocks
+d256 ====|
+R256 ====|
+R256 |
+R256 |
+R256 |
+R256 |-> 9 residual blocks
+R256 |
+R256 |
+R256 |
+R256 ====|
+u128 ====|
+ |-> 2 upsampling blocks
+u64 ====|
+c7s1-3 => Last conv block with `tanh` activation, filter size of 7.
+```
+"""
+
+
+def get_resnet_generator(
+ filters=64,
+ num_downsampling_blocks=2,
+ num_residual_blocks=9,
+ num_upsample_blocks=2,
+ gamma_initializer=gamma_init,
+ name=None,
+):
+ img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
+ x = ReflectionPadding2D(padding=(3, 3))(img_input)
+ x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
+ x
+ )
+ x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
+ x
+ )
+ x = layers.Activation("relu")(x)
+
+ # Downsampling
+ for _ in range(num_downsampling_blocks):
+ filters *= 2
+ x = downsample(x, filters=filters, activation=layers.Activation("relu"))
+
+ # Residual blocks
+ for _ in range(num_residual_blocks):
+ x = residual_block(x, activation=layers.Activation("relu"))
+
+ # Upsampling
+ for _ in range(num_upsample_blocks):
+ filters //= 2
+ x = upsample(x, filters, activation=layers.Activation("relu"))
+
+ # Final block
+ x = ReflectionPadding2D(padding=(3, 3))(x)
+ x = layers.Conv2D(3, (7, 7), padding="valid")(x)
+ x = layers.Activation("tanh")(x)
+
+ model = keras.models.Model(img_input, x, name=name)
+ return model
+
+
+"""
+## Build the discriminators
+
+The discriminators implement the following architecture:
+`C64->C128->C256->C512`
+"""
+
+
+def get_discriminator(
+ filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
+):
+ img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
+ x = layers.Conv2D(
+ filters,
+ (4, 4),
+ strides=(2, 2),
+ padding="same",
+ kernel_initializer=kernel_initializer,
+ )(img_input)
+ x = layers.LeakyReLU(0.2)(x)
+
+ num_filters = filters
+    for num_downsample_block in range(num_downsampling):
+        num_filters *= 2
+        if num_downsample_block < num_downsampling - 1:
+ x = downsample(
+ x,
+ filters=num_filters,
+ activation=layers.LeakyReLU(0.2),
+ kernel_size=(4, 4),
+ strides=(2, 2),
+ )
+ else:
+ x = downsample(
+ x,
+ filters=num_filters,
+ activation=layers.LeakyReLU(0.2),
+ kernel_size=(4, 4),
+ strides=(1, 1),
+ )
+
+ x = layers.Conv2D(
+ 1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
+ )(x)
+
+ model = keras.models.Model(inputs=img_input, outputs=x, name=name)
+ return model
+
+
+# Get the generators
+gen_G = get_resnet_generator(name="generator_G")
+gen_F = get_resnet_generator(name="generator_F")
+
+# Get the discriminators
+disc_X = get_discriminator(name="discriminator_X")
+disc_Y = get_discriminator(name="discriminator_Y")
+
+
+"""
+## Build the CycleGAN model
+
+We will override the `train_step()` method of the `Model` class
+for training via `fit()`.
+"""
+
+
+class CycleGan(keras.Model):
+ def __init__(
+ self,
+ generator_G,
+ generator_F,
+ discriminator_X,
+ discriminator_Y,
+ lambda_cycle=10.0,
+ lambda_identity=0.5,
+ ):
+ super().__init__()
+ self.gen_G = generator_G
+ self.gen_F = generator_F
+ self.disc_X = discriminator_X
+ self.disc_Y = discriminator_Y
+ self.lambda_cycle = lambda_cycle
+ self.lambda_identity = lambda_identity
+
+ def call(self, inputs):
+ return (
+ self.disc_X(inputs),
+ self.disc_Y(inputs),
+ self.gen_G(inputs),
+ self.gen_F(inputs),
+ )
+
+ def compile(
+ self,
+ gen_G_optimizer,
+ gen_F_optimizer,
+ disc_X_optimizer,
+ disc_Y_optimizer,
+ gen_loss_fn,
+ disc_loss_fn,
+ ):
+ super().compile()
+ self.gen_G_optimizer = gen_G_optimizer
+ self.gen_F_optimizer = gen_F_optimizer
+ self.disc_X_optimizer = disc_X_optimizer
+ self.disc_Y_optimizer = disc_Y_optimizer
+ self.generator_loss_fn = gen_loss_fn
+ self.discriminator_loss_fn = disc_loss_fn
+ self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
+ self.identity_loss_fn = keras.losses.MeanAbsoluteError()
+
+ def train_step(self, batch_data):
+        # x is a horse image and y is a zebra image
+ real_x, real_y = batch_data
+
+ # For CycleGAN, we need to calculate different
+ # kinds of losses for the generators and discriminators.
+ # We will perform the following steps here:
+ #
+ # 1. Pass real images through the generators and get the generated images
+ # 2. Pass the generated images back to the generators to check if we
+ # can predict the original image from the generated image.
+ # 3. Do an identity mapping of the real images using the generators.
+ # 4. Pass the generated images in 1) to the corresponding discriminators.
+ # 5. Calculate the generators total loss (adversarial + cycle + identity)
+ # 6. Calculate the discriminators loss
+ # 7. Update the weights of the generators
+ # 8. Update the weights of the discriminators
+ # 9. Return the losses in a dictionary
+
+ with tf.GradientTape(persistent=True) as tape:
+            # Horse to fake zebra: x -> y
+            fake_y = self.gen_G(real_x, training=True)
+            # Zebra to fake horse: y -> x
+            fake_x = self.gen_F(real_y, training=True)
+
+ # Cycle (Horse to fake zebra to fake horse): x -> y -> x
+ cycled_x = self.gen_F(fake_y, training=True)
+ # Cycle (Zebra to fake horse to fake zebra) y -> x -> y
+ cycled_y = self.gen_G(fake_x, training=True)
+
+ # Identity mapping
+ same_x = self.gen_F(real_x, training=True)
+ same_y = self.gen_G(real_y, training=True)
+
+ # Discriminator output
+ disc_real_x = self.disc_X(real_x, training=True)
+ disc_fake_x = self.disc_X(fake_x, training=True)
+
+ disc_real_y = self.disc_Y(real_y, training=True)
+ disc_fake_y = self.disc_Y(fake_y, training=True)
+
+ # Generator adversarial loss
+ gen_G_loss = self.generator_loss_fn(disc_fake_y)
+ gen_F_loss = self.generator_loss_fn(disc_fake_x)
+
+ # Generator cycle loss
+ cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
+ cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle
+
+ # Generator identity loss
+ id_loss_G = (
+ self.identity_loss_fn(real_y, same_y)
+ * self.lambda_cycle
+ * self.lambda_identity
+ )
+ id_loss_F = (
+ self.identity_loss_fn(real_x, same_x)
+ * self.lambda_cycle
+ * self.lambda_identity
+ )
+
+ # Total generator loss
+ total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
+ total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F
+
+ # Discriminator loss
+ disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
+ disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)
+
+ # Get the gradients for the generators
+ grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
+ grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)
+
+ # Get the gradients for the discriminators
+ disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
+ disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)
+
+ # Update the weights of the generators
+ self.gen_G_optimizer.apply_gradients(
+ zip(grads_G, self.gen_G.trainable_variables)
+ )
+ self.gen_F_optimizer.apply_gradients(
+ zip(grads_F, self.gen_F.trainable_variables)
+ )
+
+ # Update the weights of the discriminators
+ self.disc_X_optimizer.apply_gradients(
+ zip(disc_X_grads, self.disc_X.trainable_variables)
+ )
+ self.disc_Y_optimizer.apply_gradients(
+ zip(disc_Y_grads, self.disc_Y.trainable_variables)
+ )
+
+ return {
+ "G_loss": total_loss_G,
+ "F_loss": total_loss_F,
+ "D_X_loss": disc_X_loss,
+ "D_Y_loss": disc_Y_loss,
+ }
+
+
+"""
+## Create a callback that periodically saves generated images
+"""
+
+
+class GANMonitor(keras.callbacks.Callback):
+ """A callback to generate and save images after each epoch"""
+
+ def __init__(self, num_img=4):
+ self.num_img = num_img
+
+ def on_epoch_end(self, epoch, logs=None):
+        _, ax = plt.subplots(self.num_img, 2, figsize=(12, 12))
+ for i, img in enumerate(test_horses.take(self.num_img)):
+ prediction = self.model.gen_G(img)[0].numpy()
+ prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
+ img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)
+
+ ax[i, 0].imshow(img)
+ ax[i, 1].imshow(prediction)
+ ax[i, 0].set_title("Input image")
+ ax[i, 1].set_title("Translated image")
+ ax[i, 0].axis("off")
+ ax[i, 1].axis("off")
+
+ prediction = keras.utils.array_to_img(prediction)
+ prediction.save(
+ "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
+ )
+ plt.show()
+ plt.close()
+
+
+"""
+## Train the end-to-end model
+"""
+
+
+# Loss function for evaluating adversarial loss
+adv_loss_fn = keras.losses.MeanSquaredError()
+
+# Define the loss function for the generators
+
+
+def generator_loss_fn(fake):
+ fake_loss = adv_loss_fn(ops.ones_like(fake), fake)
+ return fake_loss
+
+
+# Define the loss function for the discriminators
+def discriminator_loss_fn(real, fake):
+ real_loss = adv_loss_fn(ops.ones_like(real), real)
+ fake_loss = adv_loss_fn(ops.zeros_like(fake), fake)
+ return (real_loss + fake_loss) * 0.5
+
+
+# Create cycle gan model
+cycle_gan_model = CycleGan(
+ generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
+)
+
+# Compile the model
+cycle_gan_model.compile(
+ gen_G_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_F_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_X_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ disc_Y_optimizer=keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
+ gen_loss_fn=generator_loss_fn,
+ disc_loss_fn=discriminator_loss_fn,
+)
+# Callbacks
+plotter = GANMonitor()
+checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
+model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_filepath, save_weights_only=True
+)
+
+# Here we train the model for 90 epochs. Each epoch takes around 7 minutes on a
+# single P100-backed machine, so reduce `epochs` if you only want a quick test.
+cycle_gan_model.fit(
+ tf.data.Dataset.zip((train_horses, train_zebras)),
+ epochs=90,
+ callbacks=[plotter, model_checkpoint_callback],
+)
+
+"""
+Test the performance of the model.
+"""
+
+
+# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
+
+
+# Load the checkpoints
+cycle_gan_model.load_weights(checkpoint_filepath)
+print("Weights loaded successfully")
+
+_, ax = plt.subplots(4, 2, figsize=(10, 15))
+for i, img in enumerate(test_horses.take(4)):
+ prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
+ prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
+ img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)
+
+ ax[i, 0].imshow(img)
+ ax[i, 1].imshow(prediction)
+    ax[i, 0].set_title("Input image")
+ ax[i, 1].set_title("Translated image")
+ ax[i, 0].axis("off")
+ ax[i, 1].axis("off")
+
+ prediction = keras.utils.array_to_img(prediction)
+ prediction.save("predicted_img_{i}.png".format(i=i))
+plt.tight_layout()
+plt.show()
diff --git a/knowledge_base/generative/dcgan_overriding_train_step.py b/knowledge_base/generative/dcgan_overriding_train_step.py
new file mode 100644
index 0000000000000000000000000000000000000000..26779c76cf0fc5ab7e8fd48f9d80541e9cf43029
--- /dev/null
+++ b/knowledge_base/generative/dcgan_overriding_train_step.py
@@ -0,0 +1,233 @@
+"""
+Title: DCGAN to generate face images
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2019/04/29
+Last modified: 2023/12/21
+Description: A simple DCGAN trained using `fit()` by overriding `train_step` on CelebA images.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import keras
+import tensorflow as tf
+
+from keras import layers
+from keras import ops
+import matplotlib.pyplot as plt
+import os
+import gdown
+from zipfile import ZipFile
+
+
+"""
+## Prepare CelebA data
+
+We'll use face images from the CelebA dataset, resized to 64x64.
+"""
+
+os.makedirs("celeba_gan", exist_ok=True)
+
+url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
+output = "celeba_gan/data.zip"
+gdown.download(url, output, quiet=True)
+
+with ZipFile("celeba_gan/data.zip", "r") as zipobj:
+ zipobj.extractall("celeba_gan")
+
+"""
+Create a dataset from our folder, and rescale the images to the [0-1] range:
+"""
+
+dataset = keras.utils.image_dataset_from_directory(
+ "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
+)
+dataset = dataset.map(lambda x: x / 255.0)
+
+
+"""
+Let's display a sample image:
+"""
+
+
+for x in dataset:
+ plt.axis("off")
+ plt.imshow((x.numpy() * 255).astype("int32")[0])
+ break
+
+
+"""
+## Create the discriminator
+
+It maps a 64x64 image to a binary classification score.
+"""
+
+discriminator = keras.Sequential(
+ [
+ keras.Input(shape=(64, 64, 3)),
+ layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Flatten(),
+ layers.Dropout(0.2),
+ layers.Dense(1, activation="sigmoid"),
+ ],
+ name="discriminator",
+)
+discriminator.summary()
+
+"""
+## Create the generator
+
+It mirrors the discriminator, replacing `Conv2D` layers with `Conv2DTranspose` layers.
+"""
+
+latent_dim = 128
+
+generator = keras.Sequential(
+ [
+ keras.Input(shape=(latent_dim,)),
+ layers.Dense(8 * 8 * 128),
+ layers.Reshape((8, 8, 128)),
+ layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
+ ],
+ name="generator",
+)
+generator.summary()
+
+"""
+## Override `train_step`
+"""
+
+
+class GAN(keras.Model):
+ def __init__(self, discriminator, generator, latent_dim):
+ super().__init__()
+ self.discriminator = discriminator
+ self.generator = generator
+ self.latent_dim = latent_dim
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def compile(self, d_optimizer, g_optimizer, loss_fn):
+ super().compile()
+ self.d_optimizer = d_optimizer
+ self.g_optimizer = g_optimizer
+ self.loss_fn = loss_fn
+ self.d_loss_metric = keras.metrics.Mean(name="d_loss")
+ self.g_loss_metric = keras.metrics.Mean(name="g_loss")
+
+ @property
+ def metrics(self):
+ return [self.d_loss_metric, self.g_loss_metric]
+
+ def train_step(self, real_images):
+ # Sample random points in the latent space
+ batch_size = ops.shape(real_images)[0]
+ random_latent_vectors = keras.random.normal(
+ shape=(batch_size, self.latent_dim), seed=self.seed_generator
+ )
+
+ # Decode them to fake images
+ generated_images = self.generator(random_latent_vectors)
+
+ # Combine them with real images
+ combined_images = ops.concatenate([generated_images, real_images], axis=0)
+
+ # Assemble labels discriminating real from fake images
+ labels = ops.concatenate(
+ [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
+ )
+ # Add random noise to the labels - important trick!
+ labels += 0.05 * tf.random.uniform(tf.shape(labels))
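+        # (Noisy labels keep the discriminator from becoming overconfident,
+        # which helps stabilize adversarial training.)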
+
+ # Train the discriminator
+ with tf.GradientTape() as tape:
+ predictions = self.discriminator(combined_images)
+ d_loss = self.loss_fn(labels, predictions)
+ grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
+ self.d_optimizer.apply_gradients(
+ zip(grads, self.discriminator.trainable_weights)
+ )
+
+ # Sample random points in the latent space
+ random_latent_vectors = keras.random.normal(
+ shape=(batch_size, self.latent_dim), seed=self.seed_generator
+ )
+
+ # Assemble labels that say "all real images"
+ misleading_labels = ops.zeros((batch_size, 1))
+
+ # Train the generator (note that we should *not* update the weights
+ # of the discriminator)!
+ with tf.GradientTape() as tape:
+ predictions = self.discriminator(self.generator(random_latent_vectors))
+ g_loss = self.loss_fn(misleading_labels, predictions)
+ grads = tape.gradient(g_loss, self.generator.trainable_weights)
+ self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
+
+ # Update metrics
+ self.d_loss_metric.update_state(d_loss)
+ self.g_loss_metric.update_state(g_loss)
+ return {
+ "d_loss": self.d_loss_metric.result(),
+ "g_loss": self.g_loss_metric.result(),
+ }
+
+
+"""
+## Create a callback that periodically saves generated images
+"""
+
+
+class GANMonitor(keras.callbacks.Callback):
+ def __init__(self, num_img=3, latent_dim=128):
+ self.num_img = num_img
+ self.latent_dim = latent_dim
+ self.seed_generator = keras.random.SeedGenerator(42)
+
+ def on_epoch_end(self, epoch, logs=None):
+ random_latent_vectors = keras.random.normal(
+ shape=(self.num_img, self.latent_dim), seed=self.seed_generator
+ )
+ generated_images = self.model.generator(random_latent_vectors)
+ generated_images *= 255
+        generated_images = generated_images.numpy()
+ for i in range(self.num_img):
+ img = keras.utils.array_to_img(generated_images[i])
+ img.save("generated_img_%03d_%d.png" % (epoch, i))
+
+
+"""
+## Train the end-to-end model
+"""
+
+epochs = 1 # In practice, use ~100 epochs
+
+gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
+gan.compile(
+ d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
+ g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
+ loss_fn=keras.losses.BinaryCrossentropy(),
+)
+
+gan.fit(
+ dataset, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
+)
+
+"""
+Some of the last generated images around epoch 30
+(results keep improving after that):
+
+
+"""
diff --git a/knowledge_base/generative/ddim.py b/knowledge_base/generative/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd533ee4ff6ec03b588167ac5c82bdf89a4fe392
--- /dev/null
+++ b/knowledge_base/generative/ddim.py
@@ -0,0 +1,868 @@
+"""
+Title: Denoising Diffusion Implicit Models
+Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
+Date created: 2022/06/24
+Last modified: 2022/06/24
+Description: Generating images of flowers with denoising diffusion implicit models.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+### What are diffusion models?
+
+Recently, [denoising diffusion models](https://arxiv.org/abs/2006.11239), including
+[score-based generative models](https://arxiv.org/abs/1907.05600), have gained popularity as a
+powerful class of generative models that can [rival](https://arxiv.org/abs/2105.05233)
+even [generative adversarial networks (GANs)](https://arxiv.org/abs/1406.2661) in image
+synthesis quality. They tend to generate more diverse samples, while being stable to
+train and easy to scale. Recent large diffusion models, such as
+[DALL-E 2](https://openai.com/dall-e-2/) and [Imagen](https://imagen.research.google/),
+have shown incredible text-to-image generation capability. One of their drawbacks,
+however, is that they are slower to sample from, because they require multiple forward
+passes to generate an image.
+
+Diffusion refers to the process of turning a structured signal (an image) into noise
+step-by-step. By simulating diffusion, we can generate noisy images from our training
+images, and can train a neural network to try to denoise them. Using the trained network
+we can simulate the opposite of diffusion, reverse diffusion, which is the process of an
+image emerging from noise.
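+
+As a toy sketch of the forward (noising) direction (NumPy, with assumed example
+mixing rates; the actual rates come from the diffusion schedule defined later):
+
+```
+import numpy as np
+
+image = np.random.rand(64, 64, 3)   # stand-in for a training image
+noise = np.random.randn(64, 64, 3)  # Gaussian noise
+signal_rate, noise_rate = 0.8, 0.6  # assumed rates; signal**2 + noise**2 == 1
+noisy_image = signal_rate * image + noise_rate * noise
+# The network is trained to recover `noise` (and thus `image`) from `noisy_image`.
+```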
+
+
+
+One-sentence summary: **diffusion models are trained to denoise noisy images, and can
+generate images by iteratively denoising pure noise.**
+
+### Goal of this example
+
+This code example intends to be a minimal but feature-complete (with a generation quality
+metric) implementation of diffusion models, with modest compute requirements and
+reasonable performance. My implementation choices and hyperparameter tuning were done
+with these goals in mind.
+
+Since the literature on diffusion models is currently
+[mathematically quite complex](https://arxiv.org/abs/2206.00364),
+with multiple theoretical frameworks
+([score matching](https://arxiv.org/abs/1907.05600),
+[differential equations](https://arxiv.org/abs/2011.13456),
+[Markov chains](https://arxiv.org/abs/2006.11239)) and sometimes even
+[conflicting notations (see Appendix C.2)](https://arxiv.org/abs/2010.02502),
+it can be daunting to understand them. My view of these models in this example
+will be that they learn to separate a noisy image into its image and Gaussian
+noise components.
+
+In this example I made an effort to break down all long mathematical expressions into
+digestible pieces and gave all variables explanatory names. I also included numerous
+links to relevant literature to help interested readers dive deeper into the topic, in
+the hope that this code example will become a good starting point for practitioners
+learning about diffusion models.
+
+In the following sections, we will implement a continuous time version of
+[Denoising Diffusion Implicit Models (DDIMs)](https://arxiv.org/abs/2010.02502)
+with deterministic sampling.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import math
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras
+from keras import layers
+from keras import ops
+
+"""
+## Hyperparameters
+"""
+
+# data
+dataset_name = "oxford_flowers102"
+dataset_repetitions = 5
+num_epochs = 1 # train for at least 50 epochs for good results
+image_size = 64
+# KID = Kernel Inception Distance, see related section
+kid_image_size = 75
+kid_diffusion_steps = 5
+plot_diffusion_steps = 20
+
+# sampling
+min_signal_rate = 0.02
+max_signal_rate = 0.95
+
+# architecture
+embedding_dims = 32
+embedding_max_frequency = 1000.0
+widths = [32, 64, 96, 128]
+block_depth = 2
+
+# optimization
+batch_size = 64
+ema = 0.999
+learning_rate = 1e-3
+weight_decay = 1e-4
+
+"""
+## Data pipeline
+
+We will use the
+[Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102)
+dataset for
+generating images of flowers, which is a diverse natural dataset containing around 8,000
+images. Unfortunately the official splits are imbalanced, as most of the images are
+contained in the test split. We create new splits (80% train, 20% validation) using the
+[Tensorflow Datasets slicing API](https://www.tensorflow.org/datasets/splits). We apply
+center crops as preprocessing, and repeat the dataset multiple times (reason given in the
+next section).
+"""
+
+
+def preprocess_image(data):
+ # center crop image
+ height = ops.shape(data["image"])[0]
+ width = ops.shape(data["image"])[1]
+ crop_size = ops.minimum(height, width)
+ image = tf.image.crop_to_bounding_box(
+ data["image"],
+ (height - crop_size) // 2,
+ (width - crop_size) // 2,
+ crop_size,
+ crop_size,
+ )
+
+ # resize and clip
+ # for image downsampling it is important to turn on antialiasing
+ image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
+ return ops.clip(image / 255.0, 0.0, 1.0)
+
+
+def prepare_dataset(split):
+ # the validation dataset is shuffled as well, because data order matters
+ # for the KID estimation
+ return (
+ tfds.load(dataset_name, split=split, shuffle_files=True)
+ .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
+ .cache()
+ .repeat(dataset_repetitions)
+ .shuffle(10 * batch_size)
+ .batch(batch_size, drop_remainder=True)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+ )
+
+
+# load dataset
+train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]")
+val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]")
+
+"""
+## Kernel inception distance
+
+[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) is an image quality
+metric which was proposed as a replacement for the popular
+[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500).
+I prefer KID to FID because it is simpler to
+implement, can be estimated per-batch, and is computationally lighter. More details
+[here](https://keras.io/examples/generative/gan_ada/#kernel-inception-distance).
+
+In this example, the images are evaluated at the minimal possible resolution of the
+Inception network (75x75 instead of 299x299), and the metric is only measured on the
+validation set for computational efficiency. We also limit the number of sampling steps
+at evaluation to 5 for the same reason.
+
+Since the dataset is relatively small, we go over the train and validation splits
+multiple times per epoch. Because the KID estimation is noisy and compute-intensive,
+we want to evaluate only after many training iterations, but over many batches when
+we do.
+
+"""
+
+
+@keras.saving.register_keras_serializable()
+class KID(keras.metrics.Metric):
+ def __init__(self, name, **kwargs):
+ super().__init__(name=name, **kwargs)
+
+ # KID is estimated per batch and is averaged across batches
+ self.kid_tracker = keras.metrics.Mean(name="kid_tracker")
+
+ # a pretrained InceptionV3 is used without its classification layer
+ # transform the pixel values to the 0-255 range, then use the same
+ # preprocessing as during pretraining
+ self.encoder = keras.Sequential(
+ [
+ keras.Input(shape=(image_size, image_size, 3)),
+ layers.Rescaling(255.0),
+ layers.Resizing(height=kid_image_size, width=kid_image_size),
+ layers.Lambda(keras.applications.inception_v3.preprocess_input),
+ keras.applications.InceptionV3(
+ include_top=False,
+ input_shape=(kid_image_size, kid_image_size, 3),
+ weights="imagenet",
+ ),
+ layers.GlobalAveragePooling2D(),
+ ],
+ name="inception_encoder",
+ )
+
+ def polynomial_kernel(self, features_1, features_2):
+ feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32")
+ return (
+ features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
+ ) ** 3.0
+
+ def update_state(self, real_images, generated_images, sample_weight=None):
+ real_features = self.encoder(real_images, training=False)
+ generated_features = self.encoder(generated_images, training=False)
+
+ # compute polynomial kernels using the two sets of features
+ kernel_real = self.polynomial_kernel(real_features, real_features)
+ kernel_generated = self.polynomial_kernel(
+ generated_features, generated_features
+ )
+ kernel_cross = self.polynomial_kernel(real_features, generated_features)
+
+ # estimate the squared maximum mean discrepancy using the average kernel values
+ batch_size = real_features.shape[0]
+ batch_size_f = ops.cast(batch_size, dtype="float32")
+ mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
+ batch_size_f * (batch_size_f - 1.0)
+ )
+ mean_kernel_generated = ops.sum(
+ kernel_generated * (1.0 - ops.eye(batch_size))
+ ) / (batch_size_f * (batch_size_f - 1.0))
+ mean_kernel_cross = ops.mean(kernel_cross)
+ kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
+
+ # update the average KID estimate
+ self.kid_tracker.update_state(kid)
+
+ def result(self):
+ return self.kid_tracker.result()
+
+ def reset_state(self):
+ self.kid_tracker.reset_state()
+
+
+"""
+## Network architecture
+
+Here we specify the architecture of the neural network that we will use for denoising. We
+build a [U-Net](https://arxiv.org/abs/1505.04597) with identical input and output
+dimensions. U-Net is a popular semantic segmentation architecture, whose main idea is
+that it progressively downsamples and then upsamples its input image, and adds skip
+connections between layers having the same resolution. These help with gradient flow and
+avoid introducing a representation bottleneck, unlike usual
+[autoencoders](https://www.deeplearningbook.org/contents/autoencoders.html). Based on
+this, one can view
+[diffusion models as denoising autoencoders](https://benanne.github.io/2022/01/31/diffusion.html)
+without a bottleneck.
+
+The network takes two inputs, the noisy images and the variances of their noise
+components. The latter is required since denoising a signal requires different operations
+at different levels of noise. We transform the noise variances using sinusoidal
+embeddings, similarly to positional encodings used both in
+[transformers](https://arxiv.org/abs/1706.03762) and
+[NeRF](https://arxiv.org/abs/2003.08934). This helps the network to be
+[highly sensitive](https://arxiv.org/abs/2006.10739) to the noise level, which is
+crucial for good performance. We implement sinusoidal embeddings using a
+[Lambda layer](https://keras.io/api/layers/core_layers/lambda/).
+
+Some other considerations:
+
+* We build the network using the
+[Keras Functional API](https://keras.io/guides/functional_api/), and use
+[closures](https://twitter.com/fchollet/status/1441927912836321280) to build blocks of
+layers in a consistent style.
+* [Diffusion models](https://arxiv.org/abs/2006.11239) embed the index of the timestep of
+the diffusion process instead of the noise variance, while
+[score-based models (Table 1)](https://arxiv.org/abs/2206.00364)
+usually use some function of the noise level. I
+prefer the latter so that we can change the sampling schedule at inference time, without
+retraining the network.
+* [Diffusion models](https://arxiv.org/abs/2006.11239) input the embedding to each
+convolution block separately. We only input it at the start of the network for
+simplicity, which in my experience barely decreases performance, because the skip and
+residual connections help the information propagate through the network properly.
+* In the literature it is common to use
+[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/)
+at lower resolutions for better global coherence. I omitted it for simplicity.
+* We disable the learnable center and scale parameters of the batch normalization layers,
+since the following convolution layers make them redundant.
+* We initialize the last convolution's kernel to all zeros as a good practice, making the
+network predict only zeros after initialization, which is the mean of its targets. This
+will improve behaviour at the start of training and make the mean squared error loss
+start at exactly 1.
+"""
+
+
+@keras.saving.register_keras_serializable()
+def sinusoidal_embedding(x):
+ embedding_min_frequency = 1.0
+ frequencies = ops.exp(
+ ops.linspace(
+ ops.log(embedding_min_frequency),
+ ops.log(embedding_max_frequency),
+ embedding_dims // 2,
+ )
+ )
+ angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")
+ embeddings = ops.concatenate(
+ [ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3
+ )
+ return embeddings
+
+
+def ResidualBlock(width):
+ def apply(x):
+ input_width = x.shape[3]
+ if input_width == width:
+ residual = x
+ else:
+ residual = layers.Conv2D(width, kernel_size=1)(x)
+ x = layers.BatchNormalization(center=False, scale=False)(x)
+ x = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(x)
+ x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
+ x = layers.Add()([x, residual])
+ return x
+
+ return apply
+
+
+def DownBlock(width, block_depth):
+ def apply(x):
+ x, skips = x
+ for _ in range(block_depth):
+ x = ResidualBlock(width)(x)
+ skips.append(x)
+ x = layers.AveragePooling2D(pool_size=2)(x)
+ return x
+
+ return apply
+
+
+def UpBlock(width, block_depth):
+ def apply(x):
+ x, skips = x
+ x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
+ for _ in range(block_depth):
+ x = layers.Concatenate()([x, skips.pop()])
+ x = ResidualBlock(width)(x)
+ return x
+
+ return apply
+
+
+def get_network(image_size, widths, block_depth):
+ noisy_images = keras.Input(shape=(image_size, image_size, 3))
+ noise_variances = keras.Input(shape=(1, 1, 1))
+
+ e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32))(noise_variances)
+ e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)
+
+ x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
+ x = layers.Concatenate()([x, e])
+
+ skips = []
+ for width in widths[:-1]:
+ x = DownBlock(width, block_depth)([x, skips])
+
+ for _ in range(block_depth):
+ x = ResidualBlock(widths[-1])(x)
+
+ for width in reversed(widths[:-1]):
+ x = UpBlock(width, block_depth)([x, skips])
+
+ x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)
+
+ return keras.Model([noisy_images, noise_variances], x, name="residual_unet")
+
+
+"""
+This showcases the power of the Functional API. Note how we built a relatively complex
+U-Net with skip connections, residual blocks, multiple inputs, and sinusoidal embeddings
+in 80 lines of code!
+"""
+
+"""
+## Diffusion model
+
+### Diffusion schedule
+
+Let us say that a diffusion process starts at time = 0 and ends at time = 1. This
+variable will be called diffusion time, and can be either discrete (common in diffusion
+models) or continuous (common in score-based models). I choose the latter, so that the
+number of sampling steps can be changed at inference time.
+
+We need to have a function that tells us at each point in the diffusion process the noise
+levels and signal levels of the noisy image corresponding to the actual diffusion time.
+This will be called the diffusion schedule (see `diffusion_schedule()`).
+
+This schedule outputs two quantities: the `noise_rate` and the `signal_rate`
+(corresponding to sqrt(1 - alpha) and sqrt(alpha) in the DDIM paper, respectively). We
+generate the noisy image by weighting the random noise and the training image by their
+corresponding rates and adding them together.
+
+Since the (standard normal) random noises and the (normalized) images both have zero mean
+and unit variance, the noise rate and signal rate can be interpreted as the standard
+deviation of their components in the noisy image, while the squares of their rates can be
+interpreted as their variance (or their power in the signal processing sense). The rates
+will always be set so that their squared sum is 1, meaning that the noisy images will
+always have unit variance, just like their unscaled components.
+
+We will use a simplified, continuous version of the
+[cosine schedule (Section 3.2)](https://arxiv.org/abs/2102.09672),
+that is quite commonly used in the literature.
+This schedule is symmetric, slow towards the start and end of the diffusion process, and
+it also has a nice geometric interpretation, using the
+[trigonometric properties of the unit circle](https://en.wikipedia.org/wiki/Unit_circle#/media/File:Circle-trig6.svg):
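+
+As a quick standalone check (NumPy, using the `min_signal_rate` and
+`max_signal_rate` values from the hyperparameters above), the cosine schedule
+keeps the squared rates summing to one at every diffusion time:
+
+```
+import numpy as np
+
+diffusion_times = np.linspace(0.0, 1.0, 5)
+start_angle = np.arccos(0.95)  # arccos(max_signal_rate)
+end_angle = np.arccos(0.02)  # arccos(min_signal_rate)
+angles = start_angle + diffusion_times * (end_angle - start_angle)
+signal_rates, noise_rates = np.cos(angles), np.sin(angles)
+print(signal_rates**2 + noise_rates**2)  # -> [1. 1. 1. 1. 1.]
+```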
+
+
+
+### Training process
+
+The training procedure (see `train_step()` and `denoise()`) of denoising diffusion models
+is the following: we sample random diffusion times uniformly, and mix the training images
+with random Gaussian noises at rates corresponding to the diffusion times. Then, we train
+the model to separate the noisy image into its two components.
+
+Usually, the neural network is trained to predict the unscaled noise component, from
+which the predicted image component can be calculated using the signal and noise rates.
+Pixelwise
+[mean squared error](https://keras.io/api/losses/regression_losses/#mean_squared_error-function)
+is the theoretically correct choice; however, I recommend using
+[mean absolute error](https://keras.io/api/losses/regression_losses/#mean_absolute_error-function)
+instead (similarly to
+[this](https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L371)
+implementation), which produces better results on this dataset.
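+
+As a small numeric check (NumPy, toy shapes) of how the image component is
+recovered from a noise prediction using the two rates:
+
+```
+import numpy as np
+
+images = np.random.randn(2, 4)  # toy "images" with zero mean and unit variance
+noises = np.random.randn(2, 4)
+signal_rates, noise_rates = 0.8, 0.6  # signal**2 + noise**2 == 1
+noisy_images = signal_rates * images + noise_rates * noises
+# With a perfect noise prediction, the image component is recovered exactly:
+pred_images = (noisy_images - noise_rates * noises) / signal_rates
+assert np.allclose(pred_images, images)
+```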
+
+### Sampling (reverse diffusion)
+
+When sampling (see `reverse_diffusion()`), at each step we take the previous estimate of
+the noisy image and separate it into image and noise using our network. Then we recombine
+these components using the signal and noise rate of the following step.
+
+Though a similar view is shown in
+[Equation 12 of DDIMs](https://arxiv.org/abs/2010.02502), I believe the above explanation
+of the sampling equation is not widely known.
+
+This example only implements the deterministic sampling procedure from DDIM, which
+corresponds to *eta = 0* in the paper. One can also use stochastic sampling (in which
+case the model becomes a
+[Denoising Diffusion Probabilistic Model (DDPM)](https://arxiv.org/abs/2006.11239)),
+where a part of the predicted noise is
+replaced with the same or larger amount of random noise
+([see Equation 16 and below](https://arxiv.org/abs/2010.02502)).
+
+Stochastic sampling can be used without retraining the network (since both models are
+trained the same way), and it can improve sample quality, though it usually requires
+more sampling steps.
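+
+A sketch of a single deterministic sampling step (NumPy, toy values, with names
+mirroring `reverse_diffusion()` below; the network output is a stand-in here):
+
+```
+import numpy as np
+
+noisy_images = np.random.randn(2, 4)  # current noisy estimate
+noise_rates, signal_rates = 0.9, np.sqrt(1 - 0.9**2)
+pred_noises = np.random.randn(2, 4)  # stand-in for the network's noise prediction
+pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
+# Remix with the *next* step's rates to get the next, less noisy estimate:
+next_noise_rates, next_signal_rates = 0.7, np.sqrt(1 - 0.7**2)
+next_noisy_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
+```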
+"""
+
+
+@keras.saving.register_keras_serializable()
+class DiffusionModel(keras.Model):
+ def __init__(self, image_size, widths, block_depth):
+ super().__init__()
+
+ self.normalizer = layers.Normalization()
+ self.network = get_network(image_size, widths, block_depth)
+ self.ema_network = keras.models.clone_model(self.network)
+
+ def compile(self, **kwargs):
+ super().compile(**kwargs)
+
+ self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
+ self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
+ self.kid = KID(name="kid")
+
+ @property
+ def metrics(self):
+ return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]
+
+ def denormalize(self, images):
+ # convert the pixel values back to 0-1 range
+ images = self.normalizer.mean + images * self.normalizer.variance**0.5
+ return ops.clip(images, 0.0, 1.0)
+
+ def diffusion_schedule(self, diffusion_times):
+ # diffusion times -> angles
+ start_angle = ops.cast(ops.arccos(max_signal_rate), "float32")
+ end_angle = ops.cast(ops.arccos(min_signal_rate), "float32")
+
+ diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
+
+ # angles -> signal and noise rates
+ signal_rates = ops.cos(diffusion_angles)
+ noise_rates = ops.sin(diffusion_angles)
+ # note that their squared sum is always: sin^2(x) + cos^2(x) = 1
+
+ return noise_rates, signal_rates
+
+ def denoise(self, noisy_images, noise_rates, signal_rates, training):
+ # the exponential moving average weights are used at evaluation
+ if training:
+ network = self.network
+ else:
+ network = self.ema_network
+
+ # predict noise component and calculate the image component using it
+ pred_noises = network([noisy_images, noise_rates**2], training=training)
+ pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
+
+ return pred_noises, pred_images
+
+ def reverse_diffusion(self, initial_noise, diffusion_steps):
+ # reverse diffusion = sampling
+ num_images = initial_noise.shape[0]
+ step_size = 1.0 / diffusion_steps
+
+ # important line:
+ # at the first sampling step, the "noisy image" is pure noise
+ # but its signal rate is assumed to be nonzero (min_signal_rate)
+ next_noisy_images = initial_noise
+ for step in range(diffusion_steps):
+ noisy_images = next_noisy_images
+
+ # separate the current noisy image to its components
+ diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size
+ noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
+ pred_noises, pred_images = self.denoise(
+ noisy_images, noise_rates, signal_rates, training=False
+ )
+ # network used in eval mode
+
+ # remix the predicted components using the next signal and noise rates
+ next_diffusion_times = diffusion_times - step_size
+ next_noise_rates, next_signal_rates = self.diffusion_schedule(
+ next_diffusion_times
+ )
+ next_noisy_images = (
+ next_signal_rates * pred_images + next_noise_rates * pred_noises
+ )
+ # this new noisy image will be used in the next step
+
+ return pred_images
+
+ def generate(self, num_images, diffusion_steps):
+ # noise -> images -> denormalized images
+ initial_noise = keras.random.normal(
+ shape=(num_images, image_size, image_size, 3)
+ )
+ generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
+ generated_images = self.denormalize(generated_images)
+ return generated_images
+
+ def train_step(self, images):
+ # normalize images to have standard deviation of 1, like the noises
+ images = self.normalizer(images, training=True)
+ noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))
+
+ # sample uniform random diffusion times
+ diffusion_times = keras.random.uniform(
+ shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
+ )
+ noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
+ # mix the images with noises accordingly
+ noisy_images = signal_rates * images + noise_rates * noises
+
+ with tf.GradientTape() as tape:
+ # train the network to separate noisy images to their components
+ pred_noises, pred_images = self.denoise(
+ noisy_images, noise_rates, signal_rates, training=True
+ )
+
+ noise_loss = self.loss(noises, pred_noises) # used for training
+ image_loss = self.loss(images, pred_images) # only used as metric
+
+ gradients = tape.gradient(noise_loss, self.network.trainable_weights)
+ self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
+
+ self.noise_loss_tracker.update_state(noise_loss)
+ self.image_loss_tracker.update_state(image_loss)
+
+ # track the exponential moving averages of weights
+ for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
+ ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
+
+ # KID is not measured during the training phase for computational efficiency
+ return {m.name: m.result() for m in self.metrics[:-1]}
+
+ def test_step(self, images):
+ # normalize images to have standard deviation of 1, like the noises
+ images = self.normalizer(images, training=False)
+ noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))
+
+ # sample uniform random diffusion times
+ diffusion_times = keras.random.uniform(
+ shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
+ )
+ noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
+ # mix the images with noises accordingly
+ noisy_images = signal_rates * images + noise_rates * noises
+
+ # use the network to separate noisy images to their components
+ pred_noises, pred_images = self.denoise(
+ noisy_images, noise_rates, signal_rates, training=False
+ )
+
+ noise_loss = self.loss(noises, pred_noises)
+ image_loss = self.loss(images, pred_images)
+
+ self.image_loss_tracker.update_state(image_loss)
+ self.noise_loss_tracker.update_state(noise_loss)
+
+ # measure KID between real and generated images
+ # this is computationally demanding, kid_diffusion_steps has to be small
+ images = self.denormalize(images)
+ generated_images = self.generate(
+ num_images=batch_size, diffusion_steps=kid_diffusion_steps
+ )
+ self.kid.update_state(images, generated_images)
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
+ # plot random generated images for visual evaluation of generation quality
+ generated_images = self.generate(
+ num_images=num_rows * num_cols,
+ diffusion_steps=plot_diffusion_steps,
+ )
+
+ plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
+ for row in range(num_rows):
+ for col in range(num_cols):
+ index = row * num_cols + col
+ plt.subplot(num_rows, num_cols, index + 1)
+ plt.imshow(generated_images[index])
+ plt.axis("off")
+ plt.tight_layout()
+ plt.show()
+ plt.close()
+
+
+"""
+## Training
+"""
+
+# create and compile the model
+model = DiffusionModel(image_size, widths, block_depth)
+# Note: below TensorFlow 2.9, AdamW was only available via TensorFlow Addons:
+# pip install tensorflow_addons
+# import tensorflow_addons as tfa
+# optimizer=tfa.optimizers.AdamW
+model.compile(
+ optimizer=keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ ),
+ loss=keras.losses.mean_absolute_error,
+)
+# pixelwise mean absolute error is used as loss
+
+# save the best model based on the validation KID metric
+checkpoint_path = "checkpoints/diffusion_model.weights.h5"
+checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path,
+ save_weights_only=True,
+ monitor="val_kid",
+ mode="min",
+ save_best_only=True,
+)
+
+# calculate mean and variance of training dataset for normalization
+model.normalizer.adapt(train_dataset)
+
+# run training and plot generated images periodically
+model.fit(
+ train_dataset,
+ epochs=num_epochs,
+ validation_data=val_dataset,
+ callbacks=[
+ keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
+ checkpoint_callback,
+ ],
+)
+
+"""
+## Inference
+"""
+
+# load the best model and generate images
+model.load_weights(checkpoint_path)
+model.plot_images()
+
+"""
+## Results
+
+By running the training for at least 50 epochs (takes 2 hours on a T4 GPU and 30 minutes
+on an A100 GPU), one can get high quality image generations using this code example.
+
+The evolution of a batch of images over an 80-epoch training run (color artifacts are
+due to GIF compression):
+
+
+
+Images generated using between 1 and 20 sampling steps from the same initial noise:
+
+
+
+Interpolation (spherical) between initial noise samples:
+
+
+
+Deterministic sampling process (noisy images on top, predicted images on bottom, 40
+steps):
+
+
+
+Stochastic sampling process (noisy images on top, predicted images on bottom, 80 steps):
+
+
+
+"""
+
+"""
+## Lessons learned
+
+During preparation for this code example I have run numerous experiments using
+[this repository](https://github.com/beresandras/clear-diffusion-keras).
+In this section I list
+the lessons learned and my recommendations in my subjective order of importance.
+
+### Algorithmic tips
+
+* **min. and max. signal rates**: I found the min. signal rate to be an important
+hyperparameter. Setting it too low will make the generated images oversaturated, while
+setting it too high will make them undersaturated. I recommend tuning it carefully. Also,
+setting it to 0 will lead to a division by zero error. The max. signal rate can be set to
+1, but I found that setting it lower slightly improves generation quality.
+* **loss function**: While large models tend to use mean squared error (MSE) loss, I
+recommend using mean absolute error (MAE) on this dataset. In my experience MSE loss
+generates more diverse samples (it also seems to hallucinate more, see
+[Section 3](https://arxiv.org/abs/2111.05826)), while MAE loss leads to smoother images.
+I recommend trying both.
+* **weight decay**: I did occasionally run into diverged trainings when scaling up the
+model, and found that weight decay helps in avoiding instabilities at a low performance
+cost. This is why I use
+[AdamW](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/experimental/AdamW)
+instead of [Adam](https://keras.io/api/optimizers/adam/) in this example.
+* **exponential moving average of weights**: This helps to reduce the variance of the KID
+metric, and helps in averaging out short-term changes during training (see the sketch
+after this list).
+* **image augmentations**: Though I did not use image augmentations in this example, in
+my experience adding horizontal flips to the training increases generation performance,
+while random crops do not. Since we use a supervised denoising loss, overfitting can be
+an issue, so image augmentations might be important on small datasets. One should also be
+careful not to use
+[leaky augmentations](https://keras.io/examples/generative/gan_ada/#invertible-data-augmentation),
+which can be done following
+[this method (end of Section 5)](https://arxiv.org/abs/2206.00364) for instance.
+* **data normalization**: In the literature the pixel values of images are usually
+converted to the -1 to 1 range. For theoretical correctness, I normalize the images to
+have zero mean and unit variance instead, exactly like the random noises.
+* **noise level input**: I chose to input the noise variance to the network, as it is
+symmetrical under our sampling schedule. One could also input the noise rate (similar
+performance), the signal rate (lower performance), or even the
+[log-signal-to-noise ratio (Appendix B.1)](https://arxiv.org/abs/2107.00630)
+(did not try, as its range is highly
+dependent on the min. and max. signal rates, and would require adjusting the min.
+embedding frequency accordingly).
+* **gradient clipping**: Using global gradient clipping with a value of 1 can help with
+training stability for large models, but decreased performance significantly in my
+experience.
+* **residual connection downscaling**: For
+[deeper models (Appendix B)](https://arxiv.org/abs/2205.11487), scaling the residual
+connections by 1/sqrt(2) can be helpful, but it did not help in my case.
+* **learning rate**: For me, [Adam optimizer's](https://keras.io/api/optimizers/adam/)
+default learning rate of 1e-3 worked very well, but lower learning rates are more common
+in the [literature (Tables 11-13)](https://arxiv.org/abs/2105.05233).
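+
+As referenced in the EMA tip above, here is a minimal sketch of tracking an exponential
+moving average of the weights, assuming `network` and `ema_network` are two identically
+built Keras models:
+
+```py
+ema = 0.999
+# after each optimizer step:
+for weight, ema_weight in zip(network.weights, ema_network.weights):
+    ema_weight.assign(ema * ema_weight + (1.0 - ema) * weight)
+# global gradient clipping (see the tip above) can be enabled directly on the
+# optimizer, e.g. keras.optimizers.AdamW(..., global_clipnorm=1.0)
+```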
+
+### Architectural tips
+
+* **sinusoidal embedding**: Using sinusoidal embeddings on the noise level input of the
+network is crucial for good performance. I recommend setting the min. embedding frequency
+to the reciprocal of the range of this input; since we use the noise variance in this
+example, it can always be left at 1. The max. embedding frequency controls the smallest
+change in the noise variance that the network will be sensitive to, and the embedding
+dimensions set the number of frequency components in the embedding. In my experience the
+performance is not too sensitive to these values (see the sketch after this list).
+* **skip connections**: Using skip connections in the network architecture is absolutely
+critical: without them, the model will fail to learn to denoise well.
+* **residual connections**: In my experience residual connections also significantly
+improve performance, but this might be due to the fact that we only input the noise
+level embeddings to the first layer of the network instead of to all of them.
+* **normalization**: When scaling up the model, I did occasionally encounter diverged
+trainings; using normalization layers helped to mitigate this issue. In the literature it
+is common to use
+[GroupNormalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization)
+(with 8 groups for example) or
+[LayerNormalization](https://keras.io/api/layers/normalization_layers/layer_normalization/)
+in the network, I however chose to use
+[BatchNormalization](https://keras.io/api/layers/normalization_layers/batch_normalization/),
+as it gave similar benefits in my experiments but was computationally lighter.
+* **activations**: The choice of activation functions had a larger effect on generation
+quality than I expected. In my experiments using non-monotonic activation functions
+outperformed monotonic ones (such as
+[ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/relu)), with
+[Swish](https://www.tensorflow.org/api_docs/python/tf/keras/activations/swish) performing
+the best (this is also what [Imagen uses, page 41](https://arxiv.org/abs/2205.11487)).
+* **attention**: As mentioned earlier, it is common in the literature to use
+[attention layers](https://keras.io/api/layers/attention_layers/multi_head_attention/) at low
+resolutions for better global coherence. I omitted them for simplicity.
+* **upsampling**:
+[Bilinear and nearest neighbour upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/)
+in the network performed similarly; however, I did not try
+[transposed convolutions](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/).
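+
+As referenced in the sinusoidal embedding tip, the embedding maps the scalar noise
+variance to sines and cosines at geometrically spaced frequencies. A minimal sketch
+(mirroring the `sinusoidal_embedding` function defined earlier in this example, with
+hypothetical frequency bounds):
+
+```py
+def sinusoidal_embedding(x, embedding_dims=32, min_freq=1.0, max_freq=1000.0):
+    # x is expected to have shape (batch_size, 1, 1, 1)
+    frequencies = tf.exp(
+        tf.linspace(tf.math.log(min_freq), tf.math.log(max_freq), embedding_dims // 2)
+    )
+    angular_speeds = 2.0 * math.pi * frequencies
+    return tf.concat([tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3)
+```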
+
+For a similar list about GANs check out
+[this Keras tutorial](https://keras.io/examples/generative/gan_ada/#gan-tips-and-tricks).
+"""
+
+"""
+## What to try next?
+
+If you would like to dive deeper into the topic, I recommend checking out
+[this repository](https://github.com/beresandras/clear-diffusion-keras) that I created in
+preparation for this code example, which implements a wider range of features in a
+similar style, such as:
+
+* stochastic sampling
+* second-order sampling based on the
+[differential equation view of DDIMs (Equation 13)](https://arxiv.org/abs/2010.02502)
+* more diffusion schedules
+* more network output types: predicting image or
+[velocity (Appendix D)](https://arxiv.org/abs/2202.00512) instead of noise
+* more datasets
+"""
+
+"""
+## Related works
+
+* [Score-based generative modeling](https://yang-song.github.io/blog/2021/score/)
+(blogpost)
+* [What are diffusion models?](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/)
+(blogpost)
+* [Annotated diffusion model](https://huggingface.co/blog/annotated-diffusion) (blogpost)
+* [CVPR 2022 tutorial on diffusion models](https://cvpr2022-tutorial-diffusion-models.github.io/)
+(slides available)
+* [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364):
+attempts unifying diffusion methods under a common framework
+* High-level video overviews: [1](https://www.youtube.com/watch?v=yTAMrHVG1ew),
+[2](https://www.youtube.com/watch?v=344w5h24-h8)
+* Detailed technical videos: [1](https://www.youtube.com/watch?v=fbLgFrlTnGU),
+[2](https://www.youtube.com/watch?v=W-O7AZNzbzQ)
+* Score-based generative models: [NCSN](https://arxiv.org/abs/1907.05600),
+[NCSN+](https://arxiv.org/abs/2006.09011), [NCSN++](https://arxiv.org/abs/2011.13456)
+* Denoising diffusion models: [DDPM](https://arxiv.org/abs/2006.11239),
+[DDIM](https://arxiv.org/abs/2010.02502), [DDPM+](https://arxiv.org/abs/2102.09672),
+[DDPM++](https://arxiv.org/abs/2105.05233)
+* Large diffusion models: [GLIDE](https://arxiv.org/abs/2112.10741),
+[DALL-E 2](https://arxiv.org/abs/2204.06125/), [Imagen](https://arxiv.org/abs/2205.11487)
+
+
+"""
diff --git a/knowledge_base/generative/ddpm.py b/knowledge_base/generative/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..baae14ee53cb5913687b3db01ec21c82bda41fe7
--- /dev/null
+++ b/knowledge_base/generative/ddpm.py
@@ -0,0 +1,815 @@
+"""
+Title: Denoising Diffusion Probabilistic Model
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2022/11/30
+Last modified: 2022/12/07
+Description: Generating images of flowers with denoising diffusion probabilistic models.
+"""
+
+"""
+## Introduction
+
+Generative modeling has experienced tremendous growth in the last five years. Models like
+VAEs, GANs, and flow-based models proved to be a great success in generating
+high-quality content, especially images. Diffusion models are a newer type of generative
+model that has proven to surpass previous approaches in sample quality.
+
+Diffusion models are inspired by non-equilibrium thermodynamics, and they learn to
+generate by denoising. Learning by denoising consists of two processes,
+each of which is a Markov Chain. These are:
+
+1. The forward process: In the forward process, we slowly add random noise to the data
+in a series of time steps `(t1, t2, ..., tn)`. Samples at the current time step are
+drawn from a Gaussian distribution whose mean is conditioned
+on the sample at the previous time step, and whose variance follows
+a fixed schedule. At the end of the forward process, the samples end up following a pure
+noise distribution.
+
+2. The reverse process: During the reverse process, we try to undo the added noise at
+every time step. We start with the pure noise distribution (the last step of the
+forward process) and try to denoise the samples in the backward direction
+`(tn, tn-1, ..., t1)`.
+
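+As a concrete sketch (using the linear `beta` schedule defined later in this example),
+the forward process admits a closed form that jumps from the clean sample `x_0` directly
+to the noisy sample at step `t`:
+
+```py
+import numpy as np
+
+betas = np.linspace(1e-4, 0.02, 1000)     # linear variance schedule
+alphas_cumprod = np.cumprod(1.0 - betas)  # "alpha bar" in the paper
+
+def q_sample(x0, t, noise):
+    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
+    return (
+        np.sqrt(alphas_cumprod[t]) * x0
+        + np.sqrt(1.0 - alphas_cumprod[t]) * noise
+    )
+```
+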
+In this code example, we implement the
+[Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) paper,
+DDPM for short. It was the first paper demonstrating
+the use of diffusion models for generating high-quality images. The authors proved
+that a certain parameterization of diffusion models reveals an equivalence with
+denoising score matching over multiple noise levels during training, and with annealed
+Langevin dynamics during sampling, which generates the best-quality results.
+
+This paper replicates both of the Markov chains (forward process and reverse process)
+involved in the diffusion process, but for images. The forward process is fixed and
+gradually adds Gaussian noise to the images according to a fixed variance schedule
+denoted by beta in the paper. This is what the diffusion process looks like in the case
+of images: (image -> noise) and (noise -> image)
+
+
+
+
+The paper describes two algorithms, one for training the model, and the other for
+sampling from the trained model. Training is performed by optimizing the usual
+variational bound on negative log-likelihood. The objective function is further
+simplified, and the network is treated as a noise prediction network. Once optimized,
+we can sample from the network to generate new images from noise samples. Here is an
+overview of both algorithms as presented in the paper:
+
+
+
+
+**Note:** DDPM is just one way of implementing a diffusion model. Also, the sampling
+algorithm in the DDPM replicates the complete Markov chain. Hence, it's slow in
+generating new samples compared to other generative models like GANs. Lots of research
+efforts have been made to address this issue. One such example is Denoising Diffusion
+Implicit Models, or DDIM for short, where the authors replaced the Markov chain with a
+non-Markovian process to sample faster. You can find the code example for DDIM
+[here](https://keras.io/examples/generative/ddim/).
+
+Implementing a DDPM model is simple. We define a model that takes
+two inputs: images and randomly sampled time steps. At each training step, we
+perform the following operations to train our model:
+
+1. Sample random noise to be added to the inputs.
+2. Apply the forward process to diffuse the inputs with the sampled noise.
+3. Pass these noisy samples to the model, which outputs the noise
+prediction for each time step.
+4. Given the true noise and the predicted noise, calculate the loss values.
+5. Calculate the gradients and update the model weights.
+
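+In pseudocode, one training step maps to the five steps above roughly as follows (a
+sketch; `gdf_util` and `network` mirror the implementation later in this example, and
+`loss_fn` stands in for the compiled loss):
+
+```py
+noise = tf.random.normal(tf.shape(images))          # step 1
+images_t = gdf_util.q_sample(images, t, noise)      # step 2
+pred_noise = network([images_t, t], training=True)  # step 3
+loss = loss_fn(noise, pred_noise)                   # step 4
+# step 5: compute gradients of `loss` w.r.t. the network weights with
+# tf.GradientTape and apply the optimizer update
+```
+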
+Given that our model knows how to denoise a noisy sample at a given time step,
+we can leverage this idea to generate new samples, starting from a pure noise
+distribution.
+"""
+
+"""
+## Setup
+"""
+
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+
+# Requires TensorFlow >=2.11 for the GroupNormalization layer.
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import tensorflow_datasets as tfds
+
+"""
+## Hyperparameters
+"""
+
+batch_size = 32
+num_epochs = 1 # Just for the sake of demonstration
+total_timesteps = 1000
+norm_groups = 8 # Number of groups used in GroupNormalization layer
+learning_rate = 2e-4
+
+img_size = 64
+img_channels = 3
+clip_min = -1.0
+clip_max = 1.0
+
+first_conv_channels = 64
+channel_multiplier = [1, 2, 4, 8]
+widths = [first_conv_channels * mult for mult in channel_multiplier]
+has_attention = [False, False, True, True]
+num_res_blocks = 2 # Number of residual blocks
+
+dataset_name = "oxford_flowers102"
+splits = ["train"]
+
+
+"""
+## Dataset
+
+We use the [Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102)
+dataset for generating images of flowers. In terms of preprocessing, we use center
+cropping for resizing the images to the desired image size, and we rescale the pixel
+values in the range `[-1.0, 1.0]`. This is in line with the range of the pixel values that
+was applied by the authors of the [DDPMs paper](https://arxiv.org/abs/2006.11239). For
+augmenting training data, we randomly flip the images left/right.
+"""
+
+
+# Load the dataset
+(ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)
+
+
+def augment(img):
+ """Flips an image left/right randomly."""
+ return tf.image.random_flip_left_right(img)
+
+
+def resize_and_rescale(img, size):
+ """Resize the image to the desired size first and then
+ rescale the pixel values in the range [-1.0, 1.0].
+
+ Args:
+ img: Image tensor
+ size: Desired image size for resizing
+ Returns:
+ Resized and rescaled image tensor
+ """
+
+ height = tf.shape(img)[0]
+ width = tf.shape(img)[1]
+ crop_size = tf.minimum(height, width)
+
+ img = tf.image.crop_to_bounding_box(
+ img,
+ (height - crop_size) // 2,
+ (width - crop_size) // 2,
+ crop_size,
+ crop_size,
+ )
+
+ # Resize
+ img = tf.cast(img, dtype=tf.float32)
+ img = tf.image.resize(img, size=size, antialias=True)
+
+ # Rescale the pixel values
+ img = img / 127.5 - 1.0
+ img = tf.clip_by_value(img, clip_min, clip_max)
+ return img
+
+
+def train_preprocessing(x):
+ img = x["image"]
+ img = resize_and_rescale(img, size=(img_size, img_size))
+ img = augment(img)
+ return img
+
+
+train_ds = (
+ ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(batch_size, drop_remainder=True)
+ .shuffle(batch_size * 2)
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+
+"""
+## Gaussian diffusion utilities
+
+We define the forward process and the reverse process
+as a separate utility. Most of the code in this utility has been borrowed
+from the original implementation with some slight modifications.
+"""
+
+
+class GaussianDiffusion:
+ """Gaussian diffusion utility.
+
+ Args:
+ beta_start: Start value of the scheduled variance
+ beta_end: End value of the scheduled variance
+ timesteps: Number of time steps in the forward process
+ """
+
+ def __init__(
+ self,
+ beta_start=1e-4,
+ beta_end=0.02,
+ timesteps=1000,
+ clip_min=-1.0,
+ clip_max=1.0,
+ ):
+ self.beta_start = beta_start
+ self.beta_end = beta_end
+ self.timesteps = timesteps
+ self.clip_min = clip_min
+ self.clip_max = clip_max
+
+ # Define the linear variance schedule
+ self.betas = betas = np.linspace(
+ beta_start,
+ beta_end,
+ timesteps,
+ dtype=np.float64, # Using float64 for better precision
+ )
+ self.num_timesteps = int(timesteps)
+
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ self.betas = tf.constant(betas, dtype=tf.float32)
+ self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
+ self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)
+
+ # Calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = tf.constant(
+ np.sqrt(alphas_cumprod), dtype=tf.float32
+ )
+
+ self.sqrt_one_minus_alphas_cumprod = tf.constant(
+ np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
+ )
+
+ self.log_one_minus_alphas_cumprod = tf.constant(
+ np.log(1.0 - alphas_cumprod), dtype=tf.float32
+ )
+
+ self.sqrt_recip_alphas_cumprod = tf.constant(
+ np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
+ )
+ self.sqrt_recipm1_alphas_cumprod = tf.constant(
+ np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
+ )
+
+ # Calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)
+
+ # Log calculation clipped because the posterior variance is 0 at the beginning
+ # of the diffusion chain
+ self.posterior_log_variance_clipped = tf.constant(
+ np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
+ )
+
+ self.posterior_mean_coef1 = tf.constant(
+ betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
+ dtype=tf.float32,
+ )
+
+ self.posterior_mean_coef2 = tf.constant(
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
+ dtype=tf.float32,
+ )
+
+ def _extract(self, a, t, x_shape):
+ """Extract some coefficients at specified timesteps,
+        then reshape to [batch_size, 1, 1, 1] for broadcasting purposes.
+
+ Args:
+ a: Tensor to extract from
+ t: Timestep for which the coefficients are to be extracted
+ x_shape: Shape of the current batched samples
+ """
+ batch_size = x_shape[0]
+ out = tf.gather(a, t)
+ return tf.reshape(out, [batch_size, 1, 1, 1])
+
+ def q_mean_variance(self, x_start, t):
+ """Extracts the mean, and the variance at current timestep.
+
+ Args:
+ x_start: Initial sample (before the first diffusion step)
+ t: Current timestep
+ """
+ x_start_shape = tf.shape(x_start)
+ mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
+ variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
+ log_variance = self._extract(
+ self.log_one_minus_alphas_cumprod, t, x_start_shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise):
+ """Diffuse the data.
+
+ Args:
+ x_start: Initial sample (before the first diffusion step)
+ t: Current timestep
+ noise: Gaussian noise to be added at the current timestep
+ Returns:
+ Diffused samples at timestep `t`
+ """
+ x_start_shape = tf.shape(x_start)
+ return (
+            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
+ + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
+ * noise
+ )
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ x_t_shape = tf.shape(x_t)
+ return (
+ self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
+ - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ """Compute the mean and variance of the diffusion
+ posterior q(x_{t-1} | x_t, x_0).
+
+ Args:
+            x_start: Starting point (sample) for the posterior computation
+ x_t: Sample at timestep `t`
+ t: Current timestep
+ Returns:
+ Posterior mean and variance at current timestep
+ """
+
+ x_t_shape = tf.shape(x_t)
+ posterior_mean = (
+ self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
+ + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
+ )
+ posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
+ posterior_log_variance_clipped = self._extract(
+ self.posterior_log_variance_clipped, t, x_t_shape
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
+ x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
+ if clip_denoised:
+ x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ return model_mean, posterior_variance, posterior_log_variance
+
+ def p_sample(self, pred_noise, x, t, clip_denoised=True):
+ """Sample from the diffusion model.
+
+ Args:
+ pred_noise: Noise predicted by the diffusion model
+ x: Samples at a given timestep for which the noise was predicted
+ t: Current timestep
+            clip_denoised (bool): Whether to clip the reconstructed sample
+                (the predicted `x_0`) to the specified range or not.
+ """
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ pred_noise, x=x, t=t, clip_denoised=clip_denoised
+ )
+ noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
+ # No noise when t == 0
+ nonzero_mask = tf.reshape(
+ 1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
+ )
+ return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
+
+
+"""
+## Network architecture
+
+U-Net, originally developed for semantic segmentation, is an architecture that is
+widely used for implementing diffusion models but with some slight modifications:
+
+1. The network accepts two inputs: Image and time step
+2. Self-attention between the convolution blocks once we reach a specific resolution
+(16x16 in the paper)
+3. Group Normalization instead of weight normalization
+
+We implement most of the design choices from the original paper. We use the
+`swish` activation function throughout the network, and the variance-scaling
+kernel initializer.
+
+The only difference here is the number of groups used for the
+`GroupNormalization` layer. For the flowers dataset,
+we found that a value of `groups=8` produces better results
+compared to the default value of `groups=32`. Dropout is optional and should be
+used where the chance of overfitting is high. In the paper, the authors used dropout
+only when training on CIFAR10.
+"""
+
+
+# Kernel initializer to use
+def kernel_init(scale):
+ scale = max(scale, 1e-10)
+ return keras.initializers.VarianceScaling(
+ scale, mode="fan_avg", distribution="uniform"
+ )
+
+
+class AttentionBlock(layers.Layer):
+ """Applies self-attention.
+
+ Args:
+ units: Number of units in the dense layers
+ groups: Number of groups to be used for GroupNormalization layer
+ """
+
+ def __init__(self, units, groups=8, **kwargs):
+ self.units = units
+ self.groups = groups
+ super().__init__(**kwargs)
+
+ self.norm = layers.GroupNormalization(groups=groups)
+ self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
+ self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
+ self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
+ self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))
+
+ def call(self, inputs):
+ batch_size = tf.shape(inputs)[0]
+ height = tf.shape(inputs)[1]
+ width = tf.shape(inputs)[2]
+ scale = tf.cast(self.units, tf.float32) ** (-0.5)
+
+ inputs = self.norm(inputs)
+ q = self.query(inputs)
+ k = self.key(inputs)
+ v = self.value(inputs)
+
+ attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
+ attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])
+
+ attn_score = tf.nn.softmax(attn_score, -1)
+ attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])
+
+ proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
+ proj = self.proj(proj)
+ return inputs + proj
+
+
+class TimeEmbedding(layers.Layer):
+ def __init__(self, dim, **kwargs):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.half_dim = dim // 2
+ self.emb = math.log(10000) / (self.half_dim - 1)
+ self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)
+
+ def call(self, inputs):
+ inputs = tf.cast(inputs, dtype=tf.float32)
+ emb = inputs[:, None] * self.emb[None, :]
+ emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
+ return emb
+
+
+def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
+ def apply(inputs):
+ x, t = inputs
+ input_width = x.shape[3]
+
+ if input_width == width:
+ residual = x
+ else:
+ residual = layers.Conv2D(
+ width, kernel_size=1, kernel_initializer=kernel_init(1.0)
+ )(x)
+
+ temb = activation_fn(t)
+ temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
+ :, None, None, :
+ ]
+
+ x = layers.GroupNormalization(groups=groups)(x)
+ x = activation_fn(x)
+ x = layers.Conv2D(
+ width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
+ )(x)
+
+ x = layers.Add()([x, temb])
+ x = layers.GroupNormalization(groups=groups)(x)
+ x = activation_fn(x)
+
+ x = layers.Conv2D(
+ width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
+ )(x)
+ x = layers.Add()([x, residual])
+ return x
+
+ return apply
+
+
+def DownSample(width):
+ def apply(x):
+ x = layers.Conv2D(
+ width,
+ kernel_size=3,
+ strides=2,
+ padding="same",
+ kernel_initializer=kernel_init(1.0),
+ )(x)
+ return x
+
+ return apply
+
+
+def UpSample(width, interpolation="nearest"):
+ def apply(x):
+ x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
+ x = layers.Conv2D(
+ width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
+ )(x)
+ return x
+
+ return apply
+
+
+def TimeMLP(units, activation_fn=keras.activations.swish):
+ def apply(inputs):
+ temb = layers.Dense(
+ units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
+ )(inputs)
+ temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
+ return temb
+
+ return apply
+
+
+def build_model(
+ img_size,
+ img_channels,
+ widths,
+ has_attention,
+ num_res_blocks=2,
+ norm_groups=8,
+ interpolation="nearest",
+ activation_fn=keras.activations.swish,
+):
+ image_input = layers.Input(
+ shape=(img_size, img_size, img_channels), name="image_input"
+ )
+ time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
+
+ x = layers.Conv2D(
+ first_conv_channels,
+ kernel_size=(3, 3),
+ padding="same",
+ kernel_initializer=kernel_init(1.0),
+ )(image_input)
+
+ temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
+ temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)
+
+ skips = [x]
+
+ # DownBlock
+ for i in range(len(widths)):
+ for _ in range(num_res_blocks):
+ x = ResidualBlock(
+ widths[i], groups=norm_groups, activation_fn=activation_fn
+ )([x, temb])
+ if has_attention[i]:
+ x = AttentionBlock(widths[i], groups=norm_groups)(x)
+ skips.append(x)
+
+ if widths[i] != widths[-1]:
+ x = DownSample(widths[i])(x)
+ skips.append(x)
+
+ # MiddleBlock
+ x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
+ [x, temb]
+ )
+ x = AttentionBlock(widths[-1], groups=norm_groups)(x)
+ x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
+ [x, temb]
+ )
+
+ # UpBlock
+ for i in reversed(range(len(widths))):
+ for _ in range(num_res_blocks + 1):
+ x = layers.Concatenate(axis=-1)([x, skips.pop()])
+ x = ResidualBlock(
+ widths[i], groups=norm_groups, activation_fn=activation_fn
+ )([x, temb])
+ if has_attention[i]:
+ x = AttentionBlock(widths[i], groups=norm_groups)(x)
+
+ if i != 0:
+ x = UpSample(widths[i], interpolation=interpolation)(x)
+
+ # End block
+ x = layers.GroupNormalization(groups=norm_groups)(x)
+ x = activation_fn(x)
+ x = layers.Conv2D(3, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
+ return keras.Model([image_input, time_input], x, name="unet")
+
+
+"""
+## Training
+
+We follow the same setup for training the diffusion model as described
+in the paper. We use `Adam` optimizer with a learning rate of `2e-4`.
+We use EMA on the model parameters with a decay factor of 0.999. We
+treat our model as a noise prediction network, i.e. at every training step, we
+input a batch of images and the corresponding time steps to our UNet,
+and the network outputs the noise as its prediction.
+
+The only difference is that we aren't using the Kernel Inception Distance (KID)
+or Frechet Inception Distance (FID) for evaluating the quality of generated
+samples during training. This is because both of these metrics are compute-heavy,
+and we skip them for brevity of implementation.
+
+**Note:** We are using mean squared error as the loss function, which is aligned with
+the paper and theoretically makes sense. In practice, though, it is also common to
+use mean absolute error or Huber loss as the loss function.
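+
+For instance (a hypothetical variation, not used in this example), swapping in Huber loss
+would be a one-line change at compile time:
+
+```py
+model.compile(
+    loss=keras.losses.Huber(),
+    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
+)
+```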
+"""
+
+
+class DiffusionModel(keras.Model):
+ def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
+ super().__init__()
+ self.network = network
+ self.ema_network = ema_network
+ self.timesteps = timesteps
+ self.gdf_util = gdf_util
+ self.ema = ema
+
+ def train_step(self, images):
+ # 1. Get the batch size
+ batch_size = tf.shape(images)[0]
+
+ # 2. Sample timesteps uniformly
+ t = tf.random.uniform(
+ minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
+ )
+
+ with tf.GradientTape() as tape:
+ # 3. Sample random noise to be added to the images in the batch
+ noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
+
+ # 4. Diffuse the images with noise
+ images_t = self.gdf_util.q_sample(images, t, noise)
+
+ # 5. Pass the diffused images and time steps to the network
+ pred_noise = self.network([images_t, t], training=True)
+
+ # 6. Calculate the loss
+ loss = self.loss(noise, pred_noise)
+
+ # 7. Get the gradients
+ gradients = tape.gradient(loss, self.network.trainable_weights)
+
+ # 8. Update the weights of the network
+ self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
+
+        # 9. Update the EMA network's weights with an exponential moving
+        #    average of the main network's weights
+ for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
+ ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)
+
+ # 10. Return loss values
+ return {"loss": loss}
+
+ def generate_images(self, num_images=16):
+ # 1. Randomly sample noise (starting point for reverse process)
+ samples = tf.random.normal(
+ shape=(num_images, img_size, img_size, img_channels), dtype=tf.float32
+ )
+ # 2. Sample from the model iteratively
+ for t in reversed(range(0, self.timesteps)):
+ tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
+ pred_noise = self.ema_network.predict(
+ [samples, tt], verbose=0, batch_size=num_images
+ )
+ samples = self.gdf_util.p_sample(
+ pred_noise, samples, tt, clip_denoised=True
+ )
+ # 3. Return generated samples
+ return samples
+
+ def plot_images(
+ self, epoch=None, logs=None, num_rows=2, num_cols=8, figsize=(12, 5)
+ ):
+ """Utility to plot images using the diffusion model during training."""
+ generated_samples = self.generate_images(num_images=num_rows * num_cols)
+ generated_samples = (
+ tf.clip_by_value(generated_samples * 127.5 + 127.5, 0.0, 255.0)
+ .numpy()
+ .astype(np.uint8)
+ )
+
+ _, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
+ for i, image in enumerate(generated_samples):
+ if num_rows == 1:
+ ax[i].imshow(image)
+ ax[i].axis("off")
+ else:
+ ax[i // num_cols, i % num_cols].imshow(image)
+ ax[i // num_cols, i % num_cols].axis("off")
+
+ plt.tight_layout()
+ plt.show()
+
+
+# Build the unet model
+network = build_model(
+ img_size=img_size,
+ img_channels=img_channels,
+ widths=widths,
+ has_attention=has_attention,
+ num_res_blocks=num_res_blocks,
+ norm_groups=norm_groups,
+ activation_fn=keras.activations.swish,
+)
+ema_network = build_model(
+ img_size=img_size,
+ img_channels=img_channels,
+ widths=widths,
+ has_attention=has_attention,
+ num_res_blocks=num_res_blocks,
+ norm_groups=norm_groups,
+ activation_fn=keras.activations.swish,
+)
+ema_network.set_weights(network.get_weights()) # Initially the weights are the same
+
+# Get an instance of the Gaussian Diffusion utilities
+gdf_util = GaussianDiffusion(timesteps=total_timesteps)
+
+# Get the model
+model = DiffusionModel(
+ network=network,
+ ema_network=ema_network,
+ gdf_util=gdf_util,
+ timesteps=total_timesteps,
+)
+
+# Compile the model
+model.compile(
+ loss=keras.losses.MeanSquaredError(),
+ optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
+)
+
+# Train the model
+model.fit(
+ train_ds,
+ epochs=num_epochs,
+ batch_size=batch_size,
+ callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images)],
+)
+
+"""
+## Results
+
+We trained this model for 800 epochs on a V100 GPU,
+and each epoch took almost 8 seconds to finish. We load those weights
+here, and we generate a few samples starting from pure noise.
+"""
+
+"""shell
+curl -LO https://github.com/AakashKumarNain/ddpms/releases/download/v3.0.0/checkpoints.zip
+unzip -qq checkpoints.zip
+"""
+
+# Load the model weights
+model.ema_network.load_weights("checkpoints/diffusion_model_checkpoint")
+
+# Generate and plot some samples
+model.plot_images(num_rows=4, num_cols=8)
+
+
+"""
+## Conclusion
+
+We successfully implemented and trained a diffusion model in the same
+fashion as described by the authors of the DDPM paper. You can find the
+original implementation [here](https://github.com/hojonathanho/diffusion).
+
+There are a few things that you can try to improve the model:
+
+1. Increasing the width of each block. A bigger model can learn to denoise
+in fewer epochs, though you may have to take care of overfitting.
+
+2. We implemented the linear schedule for variance scheduling. You can implement
+other schemes like cosine scheduling and compare the performance.
+"""
+
+"""
+## References
+
+1. [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
+2. [Author's implementation](https://github.com/hojonathanho/diffusion)
+3. [A deep dive into DDPMs](https://magic-with-latents.github.io/latent/posts/ddpms/part3/)
+4. [Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/)
+5. [Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion)
+6. [AIAIART](https://www.youtube.com/watch?v=XTs7M6TSK9I&t=14s)
+"""
diff --git a/knowledge_base/generative/deep_dream.py b/knowledge_base/generative/deep_dream.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2694d0c12d6cfa02dbfca558dc4b283963db6da
--- /dev/null
+++ b/knowledge_base/generative/deep_dream.py
@@ -0,0 +1,209 @@
+"""
+Title: Deep Dream
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2016/01/13
+Last modified: 2020/05/02
+Description: Generating Deep Dreams with Keras.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+"Deep dream" is an image-filtering technique which consists of taking an image
+classification model, and running gradient ascent over an input image to
+try to maximize the activations of specific layers (and sometimes, specific units in
+specific layers) for this input. It produces hallucination-like visuals.
+
+It was first introduced by Alexander Mordvintsev from Google in July 2015.
+
+Process:
+
+- Load the original image.
+- Define a number of processing scales ("octaves"),
+from smallest to largest.
+- Resize the original image to the smallest scale.
+- For every scale, starting with the smallest (i.e. current one):
+ - Run gradient ascent
+ - Upscale image to the next scale
+ - Reinject the detail that was lost at upscaling time
+- Stop when we are back to the original size.
+
+To obtain the detail lost during upscaling, we simply
+take the original image, shrink it down, upscale it,
+and compare the result to the (resized) original image.
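+
+For example (a quick sketch, assuming a 512x512 input and the default hyperparameters
+defined below), the successive octave shapes would be computed like this:
+
+```py
+original_shape = (512, 512)
+octave_scale, num_octave = 1.4, 3
+shapes = [
+    tuple(int(dim / octave_scale**i) for dim in original_shape)
+    for i in range(num_octave)
+][::-1]
+print(shapes)  # [(261, 261), (365, 365), (512, 512)]
+```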
+"""
+
+"""
+## Setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import tensorflow as tf
+import keras
+from keras.applications import inception_v3
+
+base_image_path = keras.utils.get_file("sky.jpg", "https://i.imgur.com/aGBdQyK.jpg")
+result_prefix = "sky_dream"
+
+# These are the names of the layers
+# for which we try to maximize activation,
+# as well as their weight in the final loss
+# we try to maximize.
+# You can tweak these settings to obtain new visual effects.
+layer_settings = {
+ "mixed4": 1.0,
+ "mixed5": 1.5,
+ "mixed6": 2.0,
+ "mixed7": 2.5,
+}
+
+# Playing with these hyperparameters will also allow you to achieve new effects
+step = 0.01 # Gradient ascent step size
+num_octave = 3 # Number of scales at which to run gradient ascent
+octave_scale = 1.4 # Size ratio between scales
+iterations = 20 # Number of ascent steps per scale
+max_loss = 15.0
+
+"""
+This is our base image:
+"""
+
+from IPython.display import Image, display
+
+display(Image(base_image_path))
+
+"""
+Let's set up some image preprocessing/deprocessing utilities:
+"""
+
+
+def preprocess_image(image_path):
+ # Util function to open, resize and format pictures
+ # into appropriate arrays.
+ img = keras.utils.load_img(image_path)
+ img = keras.utils.img_to_array(img)
+ img = np.expand_dims(img, axis=0)
+ img = inception_v3.preprocess_input(img)
+ return img
+
+
+def deprocess_image(x):
+ # Util function to convert a NumPy array into a valid image.
+ x = x.reshape((x.shape[1], x.shape[2], 3))
+ # Undo inception v3 preprocessing
+ x /= 2.0
+ x += 0.5
+ x *= 255.0
+    # Clip to the valid range [0, 255] and convert to uint8
+ x = np.clip(x, 0, 255).astype("uint8")
+ return x
+
+
+"""
+## Compute the Deep Dream loss
+
+First, build a feature extraction model to retrieve the activations of our target layers
+given an input image.
+"""
+
+# Build an InceptionV3 model loaded with pre-trained ImageNet weights
+model = inception_v3.InceptionV3(weights="imagenet", include_top=False)
+
+# Get the symbolic outputs of each "key" layer (we gave them unique names).
+outputs_dict = dict(
+ [
+ (layer.name, layer.output)
+ for layer in [model.get_layer(name) for name in layer_settings.keys()]
+ ]
+)
+
+# Set up a model that returns the activation values for every target layer
+# (as a dict)
+feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)
+
+"""
+The actual loss computation is very simple:
+"""
+
+
+def compute_loss(input_image):
+ features = feature_extractor(input_image)
+ # Initialize the loss
+ loss = tf.zeros(shape=())
+ for name in features.keys():
+ coeff = layer_settings[name]
+ activation = features[name]
+ # We avoid border artifacts by only involving non-border pixels in the loss.
+ scaling = tf.reduce_prod(tf.cast(tf.shape(activation), "float32"))
+ loss += coeff * tf.reduce_sum(tf.square(activation[:, 2:-2, 2:-2, :])) / scaling
+ return loss
+
+
+"""
+## Set up the gradient ascent loop for one octave
+"""
+
+
+@tf.function
+def gradient_ascent_step(img, learning_rate):
+ with tf.GradientTape() as tape:
+ tape.watch(img)
+ loss = compute_loss(img)
+ # Compute gradients.
+ grads = tape.gradient(loss, img)
+ # Normalize gradients.
+ grads /= tf.maximum(tf.reduce_mean(tf.abs(grads)), 1e-6)
+ img += learning_rate * grads
+ return loss, img
+
+
+def gradient_ascent_loop(img, iterations, learning_rate, max_loss=None):
+ for i in range(iterations):
+ loss, img = gradient_ascent_step(img, learning_rate)
+ if max_loss is not None and loss > max_loss:
+ break
+ print("... Loss value at step %d: %.2f" % (i, loss))
+ return img
+
+
+"""
+## Run the training loop, iterating over different octaves
+"""
+
+original_img = preprocess_image(base_image_path)
+original_shape = original_img.shape[1:3]
+
+successive_shapes = [original_shape]
+for i in range(1, num_octave):
+ shape = tuple([int(dim / (octave_scale**i)) for dim in original_shape])
+ successive_shapes.append(shape)
+successive_shapes = successive_shapes[::-1]
+shrunk_original_img = tf.image.resize(original_img, successive_shapes[0])
+
+img = tf.identity(original_img) # Make a copy
+for i, shape in enumerate(successive_shapes):
+ print("Processing octave %d with shape %s" % (i, shape))
+ img = tf.image.resize(img, shape)
+ img = gradient_ascent_loop(
+ img, iterations=iterations, learning_rate=step, max_loss=max_loss
+ )
+ upscaled_shrunk_original_img = tf.image.resize(shrunk_original_img, shape)
+ same_size_original = tf.image.resize(original_img, shape)
+ lost_detail = same_size_original - upscaled_shrunk_original_img
+
+ img += lost_detail
+ shrunk_original_img = tf.image.resize(original_img, shape)
+
+keras.utils.save_img(result_prefix + ".png", deprocess_image(img.numpy()))
+
+"""
+Display the result.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deep-dream)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/deep-dream).
+"""
+
+display(Image(result_prefix + ".png"))
diff --git a/knowledge_base/generative/dreambooth.py b/knowledge_base/generative/dreambooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e487a89bd67262e18765c7b3a8c2db34265e2a
--- /dev/null
+++ b/knowledge_base/generative/dreambooth.py
@@ -0,0 +1,657 @@
+"""
+Title: DreamBooth
+Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)
+Date created: 2023/02/01
+Last modified: 2023/02/05
+Description: Implementing DreamBooth.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we implement DreamBooth, a fine-tuning technique to teach new visual
+concepts to text-conditioned Diffusion models with just 3 - 5 images. DreamBooth was
+proposed in
+[DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation](https://arxiv.org/abs/2208.12242)
+by Ruiz et al.
+
+DreamBooth, in a sense, is similar to the
+[traditional way of fine-tuning a text-conditioned Diffusion model](https://keras.io/examples/generative/finetune_stable_diffusion/),
+except for a few gotchas. This example assumes that you have basic familiarity with
+Diffusion models and how to fine-tune them. Here are some reference examples that might
+help you get familiarized quickly:
+
+* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
+* [Teach StableDiffusion new concepts via Textual Inversion](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)
+* [Fine-tuning Stable Diffusion](https://keras.io/examples/generative/finetune_stable_diffusion/)
+
+First, let's install the latest versions of KerasCV and TensorFlow.
+
+"""
+
+"""shell
+pip install -q -U keras_cv==0.6.0
+pip install -q -U tensorflow
+"""
+
+"""
+If you're running the code, please ensure you're using a GPU with at least 24 GB of
+VRAM.
+"""
+
+"""
+## Initial imports
+"""
+
+import math
+
+import keras_cv
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+from imutils import paths
+from tensorflow import keras
+
+"""
+## Usage of DreamBooth
+
+... is very versatile. By teaching Stable Diffusion about your favorite visual
+concepts, you can
+
+* Recontextualize objects in interesting ways:
+
+ 
+
+* Generate artistic renderings of the underlying visual concept:
+
+ 
+
+
+And many other applications. We welcome you to check out the original
+[DreamBooth paper](https://arxiv.org/abs/2208.12242) in this regard.
+"""
+
+"""
+## Download the instance and class images
+
+DreamBooth uses a technique called "prior preservation" to meaningfully guide the
+training procedure such that the fine-tuned models can still preserve some of the prior
+semantics of the visual concept you're introducing. To learn more about the idea of
+"prior preservation", refer to [this document](https://dreambooth.github.io/).
+
+Here, we need to introduce a few key terms specific to DreamBooth:
+
+* **Unique class**: Examples include "dog", "person", etc. In this example, we use "dog".
+* **Unique identifier**: A unique identifier that is prepended to the unique class while
+forming the "instance prompts". In this example, we use "sks" as this unique identifier.
+* **Instance prompt**: Denotes a prompt that best describes the "instance images". An
+example prompt could be `f"a photo of {unique_id} {unique_class}"`. So, for our example,
+this becomes "a photo of sks dog".
+* **Class prompt**: Denotes a prompt without the unique identifier. This prompt is used
+for generating "class images" for prior preservation. For our example, this prompt is
+"a photo of dog".
+* **Instance images**: Denote the images that represent the visual concept you're trying
+to teach, a.k.a. the "instance prompt". There are typically just 3 - 5 such images, and we
+usually gather them ourselves.
+* **Class images**: Denote the images generated using the "class prompt" for using prior
+preservation in DreamBooth training. We leverage the pre-trained model before fine-tuning
+it to generate these class images. Typically, 200 - 300 class images are enough.
+
+In code, this generation process looks quite simple:
+
+```py
+from tqdm import tqdm
+import numpy as np
+import hashlib
+import keras_cv
+import PIL
+import os
+
+class_images_dir = "class-images"
+os.makedirs(class_images_dir, exist_ok=True)
+
+model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)
+
+class_prompt = "a photo of dog"
+num_imgs_to_generate = 200
+for i in tqdm(range(num_imgs_to_generate)):
+ images = model.text_to_image(
+ class_prompt,
+ batch_size=3,
+ )
+ idx = np.random.choice(len(images))
+ selected_image = PIL.Image.fromarray(images[idx])
+ hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
+ image_filename = os.path.join(class_images_dir, f"{hash_image}.jpg")
+ selected_image.save(image_filename)
+```
+
+To keep the runtime of this example short, the authors of this example have gone ahead
+and generated some class images using
+[this notebook](https://colab.research.google.com/gist/sayakpaul/6b5de345d29cf5860f84b6d04d958692/generate_class_priors.ipynb).
+
+**Note** that prior preservation is an optional technique used in DreamBooth, but it
+almost always helps in improving the quality of the generated images.
+"""
+
+instance_images_root = tf.keras.utils.get_file(
+ origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
+ untar=True,
+)
+class_images_root = tf.keras.utils.get_file(
+ origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
+ untar=True,
+)
+
+"""
+## Visualize images
+
+First, let's load the image paths.
+"""
+instance_image_paths = list(paths.list_images(instance_images_root))
+class_image_paths = list(paths.list_images(class_images_root))
+
+"""
+Then we load the images from the paths.
+"""
+
+
+def load_images(image_paths):
+ images = [np.array(keras.utils.load_img(path)) for path in image_paths]
+ return images
+
+
+"""
+And then we make use of a utility function to plot the loaded images.
+"""
+
+
+def plot_images(images, title=None):
+ plt.figure(figsize=(20, 20))
+ for i in range(len(images)):
+ ax = plt.subplot(1, len(images), i + 1)
+ if title is not None:
+ plt.title(title)
+ plt.imshow(images[i])
+ plt.axis("off")
+
+
+"""
+**Instance images**:
+"""
+
+plot_images(load_images(instance_image_paths[:5]))
+
+"""
+**Class images**:
+"""
+
+plot_images(load_images(class_image_paths[:5]))
+
+"""
+## Prepare datasets
+
+Dataset preparation includes two stages: (1) preparing the captions, and (2) processing
+the images.
+"""
+
+"""
+### Prepare the captions
+"""
+
+# Since we're using prior preservation, we need to match the number
+# of instance images we're using. We just repeat the instance image paths
+# to do so.
+new_instance_image_paths = []
+for index in range(len(class_image_paths)):
+ instance_image = instance_image_paths[index % len(instance_image_paths)]
+ new_instance_image_paths.append(instance_image)
+
+# We just repeat the prompts / captions per image.
+unique_id = "sks"
+class_label = "dog"
+
+instance_prompt = f"a photo of {unique_id} {class_label}"
+instance_prompts = [instance_prompt] * len(new_instance_image_paths)
+
+class_prompt = f"a photo of {class_label}"
+class_prompts = [class_prompt] * len(class_image_paths)
+
+"""
+Next, we embed the prompts to save some compute.
+"""
+
+import itertools
+
+# The padding token and maximum prompt length are specific to the text encoder.
+# If you're using a different text encoder be sure to change them accordingly.
+padding_token = 49407
+max_prompt_length = 77
+
+# Load the tokenizer.
+tokenizer = keras_cv.models.stable_diffusion.SimpleTokenizer()
+
+
+# Method to tokenize and pad the tokens.
+def process_text(caption):
+ tokens = tokenizer.encode(caption)
+ tokens = tokens + [padding_token] * (max_prompt_length - len(tokens))
+ return np.array(tokens)
+
+
+# Collate the tokenized captions into an array.
+tokenized_texts = np.empty(
+ (len(instance_prompts) + len(class_prompts), max_prompt_length)
+)
+
+for i, caption in enumerate(itertools.chain(instance_prompts, class_prompts)):
+ tokenized_texts[i] = process_text(caption)
+
+
+# We also pre-compute the text embeddings to save some memory during training.
+POS_IDS = tf.convert_to_tensor([list(range(max_prompt_length))], dtype=tf.int32)
+text_encoder = keras_cv.models.stable_diffusion.TextEncoder(max_prompt_length)
+
+gpus = tf.config.list_logical_devices("GPU")
+
+# Ensure the computation takes place on a GPU.
+# Note that this happens automatically when a GPU is present;
+# this example just shows how to do it more explicitly.
+with tf.device(gpus[0].name):
+ embedded_text = text_encoder(
+ [tf.convert_to_tensor(tokenized_texts), POS_IDS], training=False
+ ).numpy()
+
+# Delete the text encoder so that it doesn't occupy any GPU memory.
+del text_encoder
+
+"""
+## Prepare the images
+"""
+
+resolution = 512
+auto = tf.data.AUTOTUNE
+
+augmenter = keras.Sequential(
+ layers=[
+ keras_cv.layers.CenterCrop(resolution, resolution),
+ keras_cv.layers.RandomFlip(),
+ keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
+ ]
+)
+
+
+def process_image(image_path, tokenized_text):
+ image = tf.io.read_file(image_path)
+ image = tf.io.decode_png(image, 3)
+ image = tf.image.resize(image, (resolution, resolution))
+ return image, tokenized_text
+
+
+def apply_augmentation(image_batch, embedded_tokens):
+ return augmenter(image_batch), embedded_tokens
+
+
+def prepare_dict(instance_only=True):
+ def fn(image_batch, embedded_tokens):
+ if instance_only:
+ batch_dict = {
+ "instance_images": image_batch,
+ "instance_embedded_texts": embedded_tokens,
+ }
+ return batch_dict
+ else:
+ batch_dict = {
+ "class_images": image_batch,
+ "class_embedded_texts": embedded_tokens,
+ }
+ return batch_dict
+
+ return fn
+
+
+def assemble_dataset(image_paths, embedded_texts, instance_only=True, batch_size=1):
+ dataset = tf.data.Dataset.from_tensor_slices((image_paths, embedded_texts))
+ dataset = dataset.map(process_image, num_parallel_calls=auto)
+ dataset = dataset.shuffle(5, reshuffle_each_iteration=True)
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.map(apply_augmentation, num_parallel_calls=auto)
+
+ prepare_dict_fn = prepare_dict(instance_only=instance_only)
+ dataset = dataset.map(prepare_dict_fn, num_parallel_calls=auto)
+ return dataset
+
+
+"""
+## Assemble dataset
+"""
+instance_dataset = assemble_dataset(
+ new_instance_image_paths,
+ embedded_text[: len(new_instance_image_paths)],
+)
+class_dataset = assemble_dataset(
+ class_image_paths,
+ embedded_text[len(new_instance_image_paths) :],
+ instance_only=False,
+)
+train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset))
+"""
+## Check shapes
+
+Now that the dataset has been prepared, let's quickly check what's inside it.
+"""
+
+sample_batch = next(iter(train_dataset))
+print(sample_batch[0].keys(), sample_batch[1].keys())
+
+for k in sample_batch[0]:
+ print(k, sample_batch[0][k].shape)
+
+for k in sample_batch[1]:
+ print(k, sample_batch[1][k].shape)
+
+"""
+During training, we make use of these keys to gather the images and text embeddings and
+concatenate them accordingly.
+"""
+
+"""
+## DreamBooth training loop
+
+Our DreamBooth training loop is very much inspired by
+[this script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
+provided by the Diffusers team at Hugging Face. However, there is an important
+difference to note. We only fine-tune the UNet (the model responsible for predicting
+noise) and don't fine-tune the text encoder in this example. If you're looking for an
+implementation that also performs the additional fine-tuning of the text encoder, refer
+to [this repository](https://github.com/sayakpaul/dreambooth-keras/).
+"""
+
+import tensorflow.experimental.numpy as tnp
+
+
+class DreamBoothTrainer(tf.keras.Model):
+ # Reference:
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
+
+ def __init__(
+ self,
+ diffusion_model,
+ vae,
+ noise_scheduler,
+ use_mixed_precision=False,
+ prior_loss_weight=1.0,
+ max_grad_norm=1.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.diffusion_model = diffusion_model
+ self.vae = vae
+ self.noise_scheduler = noise_scheduler
+ self.prior_loss_weight = prior_loss_weight
+ self.max_grad_norm = max_grad_norm
+
+ self.use_mixed_precision = use_mixed_precision
+ self.vae.trainable = False
+
+ def train_step(self, inputs):
+ instance_batch = inputs[0]
+ class_batch = inputs[1]
+
+ instance_images = instance_batch["instance_images"]
+ instance_embedded_text = instance_batch["instance_embedded_texts"]
+ class_images = class_batch["class_images"]
+ class_embedded_text = class_batch["class_embedded_texts"]
+
+ images = tf.concat([instance_images, class_images], 0)
+ embedded_texts = tf.concat([instance_embedded_text, class_embedded_text], 0)
+ batch_size = tf.shape(images)[0]
+
+ with tf.GradientTape() as tape:
+ # Project image into the latent space and sample from it.
+ latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
+ # Know more about the magic number here:
+ # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents.
+ noise = tf.random.normal(tf.shape(latents))
+
+ # Sample a random timestep for each image.
+ timesteps = tnp.random.randint(
+ 0, self.noise_scheduler.train_timesteps, (batch_size,)
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process).
+ noisy_latents = self.noise_scheduler.add_noise(
+ tf.cast(latents, noise.dtype), noise, timesteps
+ )
+
+ # Get the target for loss depending on the prediction type
+ # just the sampled noise for now.
+ target = noise # noise_schedule.predict_epsilon == True
+
+ # Predict the noise residual and compute loss.
+ timestep_embedding = tf.map_fn(
+ lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
+ )
+ model_pred = self.diffusion_model(
+ [noisy_latents, timestep_embedding, embedded_texts], training=True
+ )
+ loss = self.compute_loss(target, model_pred)
+ if self.use_mixed_precision:
+ loss = self.optimizer.get_scaled_loss(loss)
+
+ # Update parameters of the diffusion model.
+ trainable_vars = self.diffusion_model.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+ if self.use_mixed_precision:
+ gradients = self.optimizer.get_unscaled_gradients(gradients)
+ gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
+ half = dim // 2
+ log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
+ freqs = tf.math.exp(
+ -log_max_period * tf.range(0, half, dtype=tf.float32) / half
+ )
+ args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
+ embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
+ return embedding
+
+ def sample_from_encoder_outputs(self, outputs):
+ mean, logvar = tf.split(outputs, 2, axis=-1)
+ logvar = tf.clip_by_value(logvar, -30.0, 20.0)
+ std = tf.exp(0.5 * logvar)
+ sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
+ return mean + std * sample
+
+ def compute_loss(self, target, model_pred):
+ # Chunk the noise and model_pred into two parts and compute the loss
+ # on each part separately.
+ # Since the first half of the inputs has instance samples and the second half
+ # has class samples, we do the chunking accordingly.
+ model_pred, model_pred_prior = tf.split(
+ model_pred, num_or_size_splits=2, axis=0
+ )
+ target, target_prior = tf.split(target, num_or_size_splits=2, axis=0)
+
+ # Compute instance loss.
+ loss = self.compiled_loss(target, model_pred)
+
+ # Compute prior loss.
+ prior_loss = self.compiled_loss(target_prior, model_pred_prior)
+
+ # Add the prior loss to the instance loss.
+ loss = loss + self.prior_loss_weight * prior_loss
+ return loss
+
+ def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
+ # Overriding this method will allow us to use the `ModelCheckpoint`
+ # callback directly with this trainer class. In this case, it will
+ # only checkpoint the `diffusion_model` since that's what we're training
+ # during fine-tuning.
+ self.diffusion_model.save_weights(
+ filepath=filepath,
+ overwrite=overwrite,
+ save_format=save_format,
+ options=options,
+ )
+
+ def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
+ # Similarly override `load_weights()` so that we can directly call it on
+ # the trainer class object.
+ self.diffusion_model.load_weights(
+ filepath=filepath,
+ by_name=by_name,
+ skip_mismatch=skip_mismatch,
+ options=options,
+ )
+
+
+"""
+## Trainer initialization
+"""
+
+# Enable mixed-precision training if the underlying GPU has tensor cores.
+use_mp = True  # Set it to False if you're not using a GPU with tensor cores.
+if use_mp:
+    tf.keras.mixed_precision.set_global_policy("mixed_float16")
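+
+"""
+If you are unsure whether your GPU has tensor cores, a quick, illustrative check of its
+compute capability can help you decide (NVIDIA GPUs with compute capability 7.0 or
+higher ship tensor cores):
+"""
+
+gpus = tf.config.list_physical_devices("GPU")
+if gpus:
+    details = tf.config.experimental.get_device_details(gpus[0])
+    # `compute_capability` may be absent on non-NVIDIA devices.
+    print("Compute capability:", details.get("compute_capability"))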
+
+image_encoder = keras_cv.models.stable_diffusion.ImageEncoder()
+dreambooth_trainer = DreamBoothTrainer(
+ diffusion_model=keras_cv.models.stable_diffusion.DiffusionModel(
+ resolution, resolution, max_prompt_length
+ ),
+ # Remove the top layer from the encoder, which cuts off the variance and only
+ # returns the mean.
+ vae=tf.keras.Model(
+ image_encoder.input,
+ image_encoder.layers[-2].output,
+ ),
+ noise_scheduler=keras_cv.models.stable_diffusion.NoiseScheduler(),
+ use_mixed_precision=use_mp,
+)
+
+# These hyperparameters come from this tutorial by Hugging Face:
+# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
+learning_rate = 5e-6
+beta_1, beta_2 = 0.9, 0.999
+weight_decay = 1e-2
+epsilon = 1e-08
+
+optimizer = tf.keras.optimizers.experimental.AdamW(
+ learning_rate=learning_rate,
+ weight_decay=weight_decay,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+)
+dreambooth_trainer.compile(optimizer=optimizer, loss="mse")
+
+"""
+## Train!
+
+We first calculate the number of epochs that we need to train for.
+"""
+
+num_update_steps_per_epoch = train_dataset.cardinality()
+max_train_steps = 800
+epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
+print(f"Training for {epochs} epochs.")
+
+"""
+And then we start training!
+"""
+
+ckpt_path = "dreambooth-unet.h5"
+ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
+ ckpt_path,
+ save_weights_only=True,
+ monitor="loss",
+ mode="min",
+)
+dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=[ckpt_callback])
+
+"""
+## Experiments and inference
+
+We ran various experiments with a slightly modified version of this example. Our
+experiments are based on
+[this repository](https://github.com/sayakpaul/dreambooth-keras/) and are inspired by
+[this blog post](https://huggingface.co/blog/dreambooth) from Hugging Face.
+
+First, let's see how we can use the fine-tuned checkpoint for running inference.
+"""
+
+# Initialize a new Stable Diffusion model.
+dreambooth_model = keras_cv.models.StableDiffusion(
+ img_width=resolution, img_height=resolution, jit_compile=True
+)
+dreambooth_model.diffusion_model.load_weights(ckpt_path)
+
+# Note how the unique identifier and the class have been used in the prompt.
+prompt = f"A photo of {unique_id} {class_label} in a bucket"
+num_imgs_to_gen = 3
+
+images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
+plot_images(images_dreamboothed, prompt)
+
+"""
+Now, let's load checkpoints from a different experiment we conducted where we also
+fine-tuned the text encoder along with the UNet:
+"""
+
+unet_weights = tf.keras.utils.get_file(
+ origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-unet.h5"
+)
+text_encoder_weights = tf.keras.utils.get_file(
+ origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-text_encoder.h5"
+)
+
+dreambooth_model.diffusion_model.load_weights(unet_weights)
+dreambooth_model.text_encoder.load_weights(text_encoder_weights)
+
+images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
+plot_images(images_dreamboothed, prompt)
+
+"""
+The default number of steps for generating an image in `text_to_image()`
+[is 50](https://github.com/keras-team/keras-cv/blob/3575bc3b944564fe15b46b917e6555aa6a9d7be0/keras_cv/models/stable_diffusion/stable_diffusion.py#L73).
+Let's increase it to 100.
+"""
+
+images_dreamboothed = dreambooth_model.text_to_image(
+ prompt, batch_size=num_imgs_to_gen, num_steps=100
+)
+plot_images(images_dreamboothed, prompt)
+
+"""
+Feel free to experiment with different prompts (don't forget to add the unique identifier
+and the class label!) to see how the results change. We welcome you to check out our
+codebase and more experimental results
+[here](https://github.com/sayakpaul/dreambooth-keras#results). You can also read
+[this blog post](https://huggingface.co/blog/dreambooth) to get more ideas.
+"""
+
+"""
+## Acknowledgements
+
+* Thanks to the
+[DreamBooth example script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
+provided by Hugging Face which helped us a lot in getting the initial implementation
+ready quickly.
+* Getting DreamBooth to work on human faces can be challenging. We have compiled some
+general recommendations
+[here](https://github.com/sayakpaul/dreambooth-keras#notes-on-preparing-data-for-dreambooth-training-of-faces).
+Thanks to
+[Abhishek Thakur](https://no.linkedin.com/in/abhi1thakur)
+for helping with these.
+"""
diff --git a/knowledge_base/generative/fine_tune_via_textual_inversion.py b/knowledge_base/generative/fine_tune_via_textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad904d91b72b380ec786c77ee6750c610193c76
--- /dev/null
+++ b/knowledge_base/generative/fine_tune_via_textual_inversion.py
@@ -0,0 +1,729 @@
+"""
+Title: Teach StableDiffusion new concepts via Textual Inversion
+Authors: Ian Stenbit, [lukewood](https://lukewood.xyz)
+Date created: 2022/12/09
+Last modified: 2022/12/09
+Description: Learning new visual concepts with KerasCV's StableDiffusion implementation.
+"""
+
+"""
+## Textual Inversion
+
+Since its release, StableDiffusion has quickly become a favorite amongst
+the generative machine learning community.
+The high volume of traffic has led to open source contributed improvements,
+heavy prompt engineering, and even the invention of novel algorithms.
+
+Perhaps the most impressive new algorithm being used is
+[Textual Inversion](https://github.com/rinongal/textual_inversion), presented in
+[_An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion_](https://textual-inversion.github.io/).
+
+Textual Inversion is the process of teaching an image generator a specific visual concept
+through the use of fine-tuning. In the diagram below, you can see an
+example of this process where the authors teach the model new concepts, calling them
+"S_*".
+
+
+
+Conceptually, textual inversion works by learning a token embedding for a new text
+token, keeping the remaining components of StableDiffusion frozen.
+
+This guide shows you how to fine-tune the StableDiffusion model shipped in KerasCV
+using the Textual-Inversion algorithm. By the end of the guide, you will be able to
+generate images of "Gandalf the Gray as a <my-funny-cat-token>".
+
+
+
+
+First, let's import the packages we need, and create a
+StableDiffusion instance so we can use some of its subcomponents for fine-tuning.
+"""
+
+"""shell
+pip install -q git+https://github.com/keras-team/keras-cv.git
+pip install -q tensorflow==2.11.0
+"""
+
+import math
+
+import keras_cv
+import numpy as np
+import tensorflow as tf
+from keras_cv import layers as cv_layers
+from keras_cv.models.stable_diffusion import NoiseScheduler
+from tensorflow import keras
+import matplotlib.pyplot as plt
+
+stable_diffusion = keras_cv.models.StableDiffusion()
+
+"""
+Next, let's define a visualization utility to show off the generated images:
+"""
+
+
+def plot_images(images):
+ plt.figure(figsize=(20, 20))
+ for i in range(len(images)):
+ ax = plt.subplot(1, len(images), i + 1)
+ plt.imshow(images[i])
+ plt.axis("off")
+
+
+"""
+## Assembling a text-image pair dataset
+
+In order to train the embedding of our new token, we first must assemble a dataset
+consisting of text-image pairs.
+Each sample from the dataset must contain an image of the concept we are teaching
+StableDiffusion, as well as a caption accurately representing the content of the image.
+In this tutorial, we will teach StableDiffusion the concept of Luke and Ian's GitHub
+avatars:
+
+
+
+First, let's construct an image dataset of cat dolls:
+"""
+
+
+def assemble_image_dataset(urls):
+ # Fetch all remote files
+ files = [tf.keras.utils.get_file(origin=url) for url in urls]
+
+ # Resize images
+ resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
+ images = [keras.utils.load_img(img) for img in files]
+ images = [keras.utils.img_to_array(img) for img in images]
+ images = np.array([resize(img) for img in images])
+
+ # The StableDiffusion image encoder requires images to be normalized to the
+ # [-1, 1] pixel value range
+ images = images / 127.5 - 1
+
+ # Create the tf.data.Dataset
+ image_dataset = tf.data.Dataset.from_tensor_slices(images)
+
+ # Shuffle and introduce random noise
+ image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
+ image_dataset = image_dataset.map(
+ cv_layers.RandomCropAndResize(
+ target_size=(512, 512),
+ crop_area_factor=(0.8, 1.0),
+ aspect_ratio_factor=(1.0, 1.0),
+ ),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
+ image_dataset = image_dataset.map(
+ cv_layers.RandomFlip(mode="horizontal"),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
+ return image_dataset
+
+
+"""
+Next, we assemble a text dataset:
+"""
+
+MAX_PROMPT_LENGTH = 77
+placeholder_token = "<my-funny-cat-token>"
+
+
+def pad_embedding(embedding):
+ return embedding + (
+ [stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
+ )
+
+
+stable_diffusion.tokenizer.add_tokens(placeholder_token)
+
+
+def assemble_text_dataset(prompts):
+ prompts = [prompt.format(placeholder_token) for prompt in prompts]
+ embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
+ embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
+ text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
+ text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
+ return text_dataset
+
+
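+"""
+To make the padding concrete, here is a small, illustrative check (the prompt string is
+an arbitrary example):
+"""
+
+example_tokens = pad_embedding(
+    stable_diffusion.tokenizer.encode(f"a photo of a {placeholder_token}")
+)
+print(len(example_tokens))  # 77, i.e. MAX_PROMPT_LENGTH
+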
+"""
+Finally, we zip our datasets together to produce a text-image pair dataset.
+"""
+
+
+def assemble_dataset(urls, prompts):
+ image_dataset = assemble_image_dataset(urls)
+ text_dataset = assemble_text_dataset(prompts)
+    # The image dataset is quite short, so we repeat it to match the length of the
+    # text prompt dataset.
+    image_dataset = image_dataset.repeat()
+    # We use the text prompt dataset to determine the length of the dataset. Since
+    # there are relatively few prompts, we repeat the dataset 5 times; we have found
+    # that this anecdotally improves results.
+ text_dataset = text_dataset.repeat(5)
+ return tf.data.Dataset.zip((image_dataset, text_dataset))
+
+
+"""
+To ensure that our prompts accurately describe the images they are paired with, we use
+extremely generic prompts.
+
+Let's try this out with some sample images and prompts.
+"""
+
+train_ds = assemble_dataset(
+ urls=[
+ "https://i.imgur.com/VIedH1X.jpg",
+ "https://i.imgur.com/eBw13hE.png",
+ "https://i.imgur.com/oJ3rSg7.png",
+ "https://i.imgur.com/5mCL6Df.jpg",
+ "https://i.imgur.com/4Q6WWyI.jpg",
+ ],
+ prompts=[
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+ ],
+)
+
+"""
+## On the importance of prompt accuracy
+
+During our first attempt at writing this guide, we included images of groups of these cat
+dolls in our dataset but continued to use the generic prompts listed above.
+Our results were anecdotally poor. For example, here's a cat doll Gandalf generated with
+this method:
+
+
+
+It's conceptually close, but it isn't as good as it could be.
+
+In order to remedy this, we began experimenting with splitting our images into images of
+singular cat dolls and groups of cat dolls.
+Following this split, we came up with new prompts for the group shots.
+
+Training on text-image pairs that accurately represent the content boosted the quality
+of our results *substantially*. This speaks to the importance of prompt accuracy.
+
+In addition to separating the images into singular and group images, we also removed
+some inaccurate prompts, such as "a dark photo of the {}".
+
+Keeping this in mind, we assemble our final training dataset below:
+"""
+
+single_ds = assemble_dataset(
+ urls=[
+ "https://i.imgur.com/VIedH1X.jpg",
+ "https://i.imgur.com/eBw13hE.png",
+ "https://i.imgur.com/oJ3rSg7.png",
+ "https://i.imgur.com/5mCL6Df.jpg",
+ "https://i.imgur.com/4Q6WWyI.jpg",
+ ],
+ prompts=[
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+ ],
+)
+
+"""
+
+
+Looks great!
+
+Next, we assemble a dataset of groups of our GitHub avatars:
+"""
+
+group_ds = assemble_dataset(
+ urls=[
+ "https://i.imgur.com/yVmZ2Qa.jpg",
+ "https://i.imgur.com/JbyFbZJ.jpg",
+ "https://i.imgur.com/CCubd3q.jpg",
+ ],
+ prompts=[
+ "a photo of a group of {}",
+ "a rendering of a group of {}",
+ "a cropped photo of the group of {}",
+ "the photo of a group of {}",
+ "a photo of a clean group of {}",
+ "a photo of my group of {}",
+ "a photo of a cool group of {}",
+ "a close-up photo of a group of {}",
+ "a bright photo of the group of {}",
+ "a cropped photo of a group of {}",
+ "a photo of the group of {}",
+ "a good photo of the group of {}",
+ "a photo of one group of {}",
+ "a close-up photo of the group of {}",
+ "a rendition of the group of {}",
+ "a photo of the clean group of {}",
+ "a rendition of a group of {}",
+ "a photo of a nice group of {}",
+ "a good photo of a group of {}",
+ "a photo of the nice group of {}",
+ "a photo of the small group of {}",
+ "a photo of the weird group of {}",
+ "a photo of the large group of {}",
+ "a photo of a cool group of {}",
+ "a photo of a small group of {}",
+ ],
+)
+
+"""
+
+
+Finally, we concatenate the two datasets:
+"""
+
+train_ds = single_ds.concatenate(group_ds)
+train_ds = train_ds.batch(1).shuffle(
+ train_ds.cardinality(), reshuffle_each_iteration=True
+)
+
+"""
+## Adding a new token to the text encoder
+
+Next, we create a new text encoder for the StableDiffusion model and add our new
+embedding for '<my-funny-cat-token>' into the model.
+"""
+tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
+new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
+ tf.constant(tokenized_initializer)
+)
+
+# Get the new vocabulary size from the tokenizer's vocab (it includes the added token).
+new_vocab_size = len(stable_diffusion.tokenizer.vocab)
+
+# The embedding layer is at index 2 in the text encoder.
+old_token_weights = stable_diffusion.text_encoder.layers[
+ 2
+].token_embedding.get_weights()
+old_position_weights = stable_diffusion.text_encoder.layers[
+ 2
+].position_embedding.get_weights()
+
+old_token_weights = old_token_weights[0]
+new_weights = np.expand_dims(new_weights, axis=0)
+new_weights = np.concatenate([old_token_weights, new_weights], axis=0)
+
+
+"""
+Let's construct a new TextEncoder and prepare it.
+"""
+
+# Have to set download_weights False so we can init (otherwise tries to load weights)
+new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
+ keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
+ vocab_size=new_vocab_size,
+ download_weights=False,
+)
+for index, layer in enumerate(stable_diffusion.text_encoder.layers):
+ # Layer 2 is the embedding layer, so we omit it from our weight-copying
+ if index == 2:
+ continue
+ new_encoder.layers[index].set_weights(layer.get_weights())
+
+
+new_encoder.layers[2].token_embedding.set_weights([new_weights])
+new_encoder.layers[2].position_embedding.set_weights(old_position_weights)
+
+stable_diffusion._text_encoder = new_encoder
+stable_diffusion._text_encoder.compile(jit_compile=True)
+
+"""
+## Training
+
+Now we can move on to the exciting part: training!
+
+In Textual Inversion, the only piece of the model that is trained is the embedding vector.
+Let's freeze the rest of the model.
+"""
+
+
+stable_diffusion.diffusion_model.trainable = False
+stable_diffusion.decoder.trainable = False
+stable_diffusion.text_encoder.trainable = True
+
+stable_diffusion.text_encoder.layers[2].trainable = True
+
+
+def traverse_layers(layer):
+ if hasattr(layer, "layers"):
+ for layer in layer.layers:
+ yield layer
+ if hasattr(layer, "token_embedding"):
+ yield layer.token_embedding
+ if hasattr(layer, "position_embedding"):
+ yield layer.position_embedding
+
+
+for layer in traverse_layers(stable_diffusion.text_encoder):
+ if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
+ layer.trainable = True
+ else:
+ layer.trainable = False
+
+new_encoder.layers[2].position_embedding.trainable = False
+
+"""
+Let's confirm the proper weights are set to trainable.
+"""
+
+all_models = [
+ stable_diffusion.text_encoder,
+ stable_diffusion.diffusion_model,
+ stable_diffusion.decoder,
+]
+print([[w.shape for w in model.trainable_weights] for model in all_models])
+
+"""
+## Training the new embedding
+
+In order to train the embedding, we need a couple of utilities.
+We import a NoiseScheduler from KerasCV, and define the following utilities below:
+
+- `sample_from_encoder_outputs` is a wrapper around the base StableDiffusion image
+encoder which samples from the statistical distribution produced by the image
+encoder, rather than taking just the mean (like many other SD applications)
+- `get_timestep_embedding` produces an embedding for a specified timestep for the
+diffusion model
+- `get_position_ids` produces a tensor of position IDs for the text encoder (which is just
+the series `0, 1, ..., MAX_PROMPT_LENGTH - 1`)
+"""
+
+
+# Remove the top layer from the encoder, which cuts off the variance and only returns
+# the mean
+training_image_encoder = keras.Model(
+ stable_diffusion.image_encoder.input,
+ stable_diffusion.image_encoder.layers[-2].output,
+)
+
+
+def sample_from_encoder_outputs(outputs):
+ mean, logvar = tf.split(outputs, 2, axis=-1)
+ logvar = tf.clip_by_value(logvar, -30.0, 20.0)
+ std = tf.exp(0.5 * logvar)
+ sample = tf.random.normal(tf.shape(mean))
+ return mean + std * sample
+
+
+def get_timestep_embedding(timestep, dim=320, max_period=10000):
+ half = dim // 2
+ freqs = tf.math.exp(
+ -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
+ )
+ args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
+ embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
+ return embedding
+
+
+def get_position_ids():
+ return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
+
+
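+"""
+As a quick sanity check (shapes only) of the utilities above:
+"""
+
+print(get_timestep_embedding(10).shape)  # (320,)
+print(get_position_ids().shape)  # (1, 77)
+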
+"""
+Next, we implement a `StableDiffusionFineTuner`, which is a subclass of `keras.Model`
+that overrides `train_step` to train the token embeddings of our text encoder.
+This is the core of the Textual Inversion algorithm.
+
+Abstractly speaking, the train step takes a sample from the output of the frozen SD
+image encoder's latent distribution for a training image, adds noise to that sample, and
+then passes that noisy sample to the frozen diffusion model.
+The hidden state of the diffusion model is the output of the text encoder for the prompt
+corresponding to the image.
+
+Our final goal state is that the diffusion model is able to separate the noise from the
+sample using the text encoding as hidden state, so our loss is the mean-squared error of
+the noise and the output of the diffusion model (which has, ideally, removed the image
+latents from the noise).
+
+We compute gradients for only the token embeddings of the text encoder, and in the
+train step we zero-out the gradients for all tokens other than the token that we're
+learning.
+
+See in-line code comments for more details about the train step.
+"""
+
+
+class StableDiffusionFineTuner(keras.Model):
+ def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
+ super().__init__(**kwargs)
+ self.stable_diffusion = stable_diffusion
+ self.noise_scheduler = noise_scheduler
+
+ def train_step(self, data):
+ images, embeddings = data
+
+ with tf.GradientTape() as tape:
+ # Sample from the predicted distribution for the training image
+ latents = sample_from_encoder_outputs(training_image_encoder(images))
+ # The latents must be downsampled to match the scale of the latents used
+ # in the training of StableDiffusion. This number is truly just a "magic"
+ # constant that they chose when training the model.
+ latents = latents * 0.18215
+
+ # Produce random noise in the same shape as the latent sample
+ noise = tf.random.normal(tf.shape(latents))
+ batch_dim = tf.shape(latents)[0]
+
+ # Pick a random timestep for each sample in the batch
+ timesteps = tf.random.uniform(
+ (batch_dim,),
+ minval=0,
+                maxval=self.noise_scheduler.train_timesteps,
+ dtype=tf.int64,
+ )
+
+ # Add noise to the latents based on the timestep for each sample
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Encode the text in the training samples to use as hidden state in the
+ # diffusion model
+ encoder_hidden_state = self.stable_diffusion.text_encoder(
+ [embeddings, get_position_ids()]
+ )
+
+ # Compute timestep embeddings for the randomly-selected timesteps for each
+ # sample in the batch
+ timestep_embeddings = tf.map_fn(
+ fn=get_timestep_embedding,
+ elems=timesteps,
+ fn_output_signature=tf.float32,
+ )
+
+ # Call the diffusion model
+ noise_pred = self.stable_diffusion.diffusion_model(
+ [noisy_latents, timestep_embeddings, encoder_hidden_state]
+ )
+
+ # Compute the mean-squared error loss and reduce it.
+ loss = self.compiled_loss(noise_pred, noise)
+ loss = tf.reduce_mean(loss, axis=2)
+ loss = tf.reduce_mean(loss, axis=1)
+ loss = tf.reduce_mean(loss)
+
+        # Load the trainable weights and compute the gradients for them. Note that
+        # we do this outside the tape's context.
+        trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
+        grads = tape.gradient(loss, trainable_weights)
+
+        # Gradients are stored in indexed slices, so we have to find the index
+        # of the slice(s) which contain the placeholder token. The placeholder
+        # token was appended to the vocabulary, so its index is 49408.
+        condition = grads[0].indices == 49408
+        condition = tf.expand_dims(condition, axis=-1)
+
+        # Override the gradients, zeroing out the gradients for all slices that
+        # aren't for the placeholder token, effectively freezing the weights for
+        # all other tokens.
+        grads[0] = tf.IndexedSlices(
+            values=tf.where(condition, grads[0].values, 0),
+            indices=grads[0].indices,
+            dense_shape=grads[0].dense_shape,
+        )
+
+        self.optimizer.apply_gradients(zip(grads, trainable_weights))
+        return {"loss": loss}
+
+
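+"""
+The gradient-masking trick in `train_step` is worth a closer look. Here is a toy,
+self-contained illustration (with made-up shapes and indices) of zeroing every row of a
+`tf.IndexedSlices` gradient except the one belonging to a chosen token index:
+"""
+
+toy_grads = tf.IndexedSlices(
+    values=tf.ones((3, 4)),
+    indices=tf.constant([5, 7, 9], dtype=tf.int64),
+    dense_shape=tf.constant([10, 4], dtype=tf.int64),
+)
+keep = tf.expand_dims(toy_grads.indices == 7, axis=-1)
+masked_values = tf.where(keep, toy_grads.values, 0)
+print(masked_values)  # only the row for index 7 is non-zero
+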
+"""
+Before we start training, let's take a look at what StableDiffusion produces for our
+token.
+"""
+
+generated = stable_diffusion.text_to_image(
+ f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
+)
+plot_images(generated)
+
+"""
+As you can see, the model still thinks of our token as a cat, as this was the seed token
+we used to initialize our custom token.
+
+Now, to get started with training, we can just `compile()` our model like any other
+Keras model. Before doing so, we also instantiate a noise scheduler for training and
+configure our training parameters such as learning rate and optimizer.
+"""
+
+noise_scheduler = NoiseScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ train_timesteps=1000,
+)
+trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
+EPOCHS = 50
+learning_rate = keras.optimizers.schedules.CosineDecay(
+ initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
+)
+optimizer = keras.optimizers.Adam(
+ weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
+)
+
+trainer.compile(
+ optimizer=optimizer,
+ # We are performing reduction manually in our train step, so none is required here.
+ loss=keras.losses.MeanSquaredError(reduction="none"),
+)
+
+"""
+To monitor training, we can produce a `keras.callbacks.Callback` to produce a few images
+every epoch using our custom token.
+
+We create three callbacks with different prompts so that we can see how they progress
+over the course of training. We use a fixed seed so that we can easily see the
+progression of the learned token.
+"""
+
+
+class GenerateImages(keras.callbacks.Callback):
+ def __init__(
+ self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.stable_diffusion = stable_diffusion
+ self.prompt = prompt
+ self.seed = seed
+ self.frequency = frequency
+ self.steps = steps
+
+    def on_epoch_end(self, epoch, logs=None):
+ if epoch % self.frequency == 0:
+ images = self.stable_diffusion.text_to_image(
+ self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
+ )
+            plot_images(images)
+
+
+cbs = [
+ GenerateImages(
+ stable_diffusion, prompt=f"an oil painting of {placeholder_token}", seed=1337
+ ),
+ GenerateImages(
+ stable_diffusion, prompt=f"gandalf the gray as a {placeholder_token}", seed=1337
+ ),
+ GenerateImages(
+ stable_diffusion,
+ prompt=f"two {placeholder_token} getting married, photorealistic, high quality",
+ seed=1337,
+ ),
+]
+
+"""
+Now, all that is left to do is to call `model.fit()`!
+"""
+
+trainer.fit(
+ train_ds,
+ epochs=EPOCHS,
+ callbacks=cbs,
+)
+
+"""
+It's pretty fun to see how the model learns our new token over time. Play around with it
+and see how you can tune training parameters and your training dataset to produce the
+best images!
+"""
+
+"""
+## Taking the Fine-Tuned Model for a Spin
+
+Now for the really fun part. We've learned a token embedding for our custom token, so
+now we can generate images with StableDiffusion the same way we would for any other
+token!
+
+Here are some fun example prompts to get you started, with sample outputs from our cat
+doll token!
+"""
+
+generated = stable_diffusion.text_to_image(
+ f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
+ "golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
+ "character concepts, digital painting, mystery, adventure",
+ batch_size=3,
+)
+plot_images(generated)
+
+"""
+"""
+
+generated = stable_diffusion.text_to_image(
+ f"A masterpiece of a {placeholder_token} crying out to the heavens. "
+ f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "
+ "life right out of it.",
+ batch_size=3,
+)
+plot_images(generated)
+
+"""
+"""
+
+generated = stable_diffusion.text_to_image(
+ f"An evil {placeholder_token}.", batch_size=3
+)
+plot_images(generated)
+
+"""
+"""
+
+generated = stable_diffusion.text_to_image(
+ f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
+ batch_size=3,
+)
+plot_images(generated)
+
+"""
+## Conclusions
+
+Using the Textual Inversion algorithm you can teach StableDiffusion new concepts!
+
+Some possible next steps to follow:
+
+- Try out your own prompts
+- Teach the model a style
+- Gather a dataset of your favorite pet cat or dog and teach the model about it
+"""
diff --git a/knowledge_base/generative/finetune_stable_diffusion.py b/knowledge_base/generative/finetune_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..09e7c6fdf324f4fb23dc713104a15fd26378dce6
--- /dev/null
+++ b/knowledge_base/generative/finetune_stable_diffusion.py
@@ -0,0 +1,513 @@
+"""
+Title: Fine-tuning Stable Diffusion
+Author: [Sayak Paul](https://twitter.com/RisingSayak), [Chansung Park](https://twitter.com/algo_diver)
+Date created: 2022/12/28
+Last modified: 2023/01/13
+Description: Fine-tuning Stable Diffusion using a custom image-caption dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This tutorial shows how to fine-tune a
+[Stable Diffusion model](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
+on a custom dataset of `{image, caption}` pairs. We build on top of the fine-tuning
+script provided by Hugging Face
+[here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py).
+
+We assume that you have a high-level understanding of the Stable Diffusion model.
+The following resources can be helpful if you're looking for more information in that regard:
+
+* [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)
+* [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion)
+
+It's highly recommended that you use a GPU with at least 30GB of memory to execute
+the code.
+
+By the end of the guide, you'll be able to generate images of interesting Pokémon:
+
+
+
+The tutorial relies on KerasCV 0.6.0. Additionally, we need
+at least TensorFlow 2.11 in order to use AdamW with mixed precision.
+"""
+
+"""shell
+pip install keras-cv==0.6.0 -q
+pip install -U tensorflow -q
+pip install keras-core -q
+"""
+
+"""
+## What are we fine-tuning?
+
+A Stable Diffusion model can be decomposed into several key models:
+
+* A text encoder that projects the input prompt to a latent space. (The caption
+associated with an image is referred to as the "prompt".)
+* A variational autoencoder (VAE) that projects an input image to a latent space acting
+as an image vector space.
+* A diffusion model that refines a latent vector and produces another latent vector, conditioned
+on the encoded text prompt.
+* A decoder that generates images given a latent vector from the diffusion model.
+
+It's worth noting that during the process of generating an image from a text prompt, the
+image encoder is not typically employed.
+
+However, during the process of fine-tuning, the workflow goes like the following:
+
+1. An input text prompt is projected to a latent space by the text encoder.
+2. An input image is projected to a latent space by the image encoder portion of the VAE.
+3. A small amount of noise is added to the image latent vector for a given timestep.
+4. The diffusion model uses latent vectors from these two spaces along with a timestep embedding
+to predict the noise that was added to the image latent.
+5. A reconstruction loss is calculated between the predicted noise and the original noise
+added in step 3.
+6. Finally, the diffusion model parameters are optimized w.r.t. this loss using
+gradient descent.
+
+Note that only the diffusion model parameters are updated during fine-tuning, while the
+(pre-trained) text and the image encoders are kept frozen.
+
+Don't worry if this sounds complicated. The code is much simpler than this!
+"""
+
+"""
+## Imports
+"""
+
+from textwrap import wrap
+import os
+
+import keras_cv
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import tensorflow.experimental.numpy as tnp
+from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
+from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
+from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
+from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
+from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
+from tensorflow import keras
+
+"""
+## Data loading
+
+We use the dataset
+[Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
+However, we'll use a slightly different version which was derived from the original
+dataset to fit better with `tf.data`. Refer to
+[the documentation](https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version)
+for more details.
+"""
+
+data_path = tf.keras.utils.get_file(
+ origin="https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
+ untar=True,
+)
+
+data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))
+
+data_frame["image_path"] = data_frame["image_path"].apply(
+ lambda x: os.path.join(data_path, x)
+)
+data_frame.head()
+
+"""
+Since we have only 833 `{image, caption}` pairs, we can precompute the text embeddings from
+the captions. Moreover, the text encoder will be kept frozen during the course of
+fine-tuning, so we can save some compute by doing this.
+
+Before we use the text encoder, we need to tokenize the captions.
+"""
+
+# The padding token and maximum prompt length are specific to the text encoder.
+# If you're using a different text encoder be sure to change them accordingly.
+PADDING_TOKEN = 49407
+MAX_PROMPT_LENGTH = 77
+
+# Load the tokenizer.
+tokenizer = SimpleTokenizer()
+
+
+# Method to tokenize and pad the tokens.
+def process_text(caption):
+ tokens = tokenizer.encode(caption)
+ tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
+ return np.array(tokens)
+
+
+# Collate the tokenized captions into an array.
+tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))
+
+all_captions = list(data_frame["caption"].values)
+for i, caption in enumerate(all_captions):
+ tokenized_texts[i] = process_text(caption)
+
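+"""
+As a quick, illustrative sanity check: every row now has length `MAX_PROMPT_LENGTH`,
+with the unused positions filled by the padding token.
+"""
+
+print(tokenized_texts.shape)  # (num_captions, 77)
+print(int((tokenized_texts[0] == PADDING_TOKEN).sum()), "padding positions in row 0")
+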
+"""
+## Prepare a `tf.data.Dataset`
+
+In this section, we'll prepare a `tf.data.Dataset` object from the input image file paths
+and their corresponding caption tokens. The section will include the following:
+
+* Pre-computation of the text embeddings from the tokenized captions.
+* Loading and augmentation of the input images.
+* Shuffling and batching of the dataset.
+"""
+
+RESOLUTION = 256
+AUTO = tf.data.AUTOTUNE
+POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
+
+augmenter = keras.Sequential(
+ layers=[
+ keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
+ keras_cv.layers.RandomFlip(),
+ tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
+ ]
+)
+text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
+
+
+def process_image(image_path, tokenized_text):
+ image = tf.io.read_file(image_path)
+ image = tf.io.decode_png(image, 3)
+ image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
+ return image, tokenized_text
+
+
+def apply_augmentation(image_batch, token_batch):
+ return augmenter(image_batch), token_batch
+
+
+def run_text_encoder(image_batch, token_batch):
+ return (
+ image_batch,
+ token_batch,
+ text_encoder([token_batch, POS_IDS], training=False),
+ )
+
+
+def prepare_dict(image_batch, token_batch, encoded_text_batch):
+ return {
+ "images": image_batch,
+ "tokens": token_batch,
+ "encoded_text": encoded_text_batch,
+ }
+
+
+def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
+ dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
+ dataset = dataset.shuffle(batch_size * 10)
+ dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
+ dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
+ dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
+ dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
+ return dataset.prefetch(AUTO)
+
+
+"""
+The baseline Stable Diffusion model was trained using images with 512x512 resolution. It's
+unlikely for a model that's trained using higher-resolution images to transfer well to
+lower-resolution images. However, the current model will lead to OOM if we keep the
+resolution at 512x512 (without enabling mixed-precision). Therefore, in the interest of
+interactive demonstrations, we kept the input resolution at 256x256.
+"""
+
+# Prepare the dataset.
+training_dataset = prepare_dataset(
+ np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
+)
+
+# Take a sample batch and investigate.
+sample_batch = next(iter(training_dataset))
+
+for k in sample_batch:
+ print(k, sample_batch[k].shape)
+
+"""
+We can also take a look at the training images and their corresponding captions.
+"""
+
+plt.figure(figsize=(20, 10))
+
+for i in range(3):
+ ax = plt.subplot(1, 4, i + 1)
+ plt.imshow((sample_batch["images"][i] + 1) / 2)
+
+ text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
+ text = text.replace("<|startoftext|>", "")
+ text = text.replace("<|endoftext|>", "")
+ text = "\n".join(wrap(text, 12))
+ plt.title(text, fontsize=15)
+
+ plt.axis("off")
+
+"""
+## A trainer class for the fine-tuning loop
+"""
+
+
+class Trainer(tf.keras.Model):
+ # Reference:
+ # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+
+ def __init__(
+ self,
+ diffusion_model,
+ vae,
+ noise_scheduler,
+ use_mixed_precision=False,
+ max_grad_norm=1.0,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.diffusion_model = diffusion_model
+ self.vae = vae
+ self.noise_scheduler = noise_scheduler
+ self.max_grad_norm = max_grad_norm
+
+ self.use_mixed_precision = use_mixed_precision
+ self.vae.trainable = False
+
+ def train_step(self, inputs):
+ images = inputs["images"]
+ encoded_text = inputs["encoded_text"]
+ batch_size = tf.shape(images)[0]
+
+ with tf.GradientTape() as tape:
+ # Project image into the latent space and sample from it.
+ latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
+            # Learn more about this magic number here:
+ # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents.
+ noise = tf.random.normal(tf.shape(latents))
+
+ # Sample a random timestep for each image.
+ timesteps = tnp.random.randint(
+ 0, self.noise_scheduler.train_timesteps, (batch_size,)
+ )
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process).
+ noisy_latents = self.noise_scheduler.add_noise(
+ tf.cast(latents, noise.dtype), noise, timesteps
+ )
+
+ # Get the target for loss depending on the prediction type
+ # just the sampled noise for now.
+ target = noise # noise_schedule.predict_epsilon == True
+
+ # Predict the noise residual and compute loss.
+ timestep_embedding = tf.map_fn(
+ lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
+ )
+ timestep_embedding = tf.squeeze(timestep_embedding, 1)
+ model_pred = self.diffusion_model(
+ [noisy_latents, timestep_embedding, encoded_text], training=True
+ )
+ loss = self.compiled_loss(target, model_pred)
+ if self.use_mixed_precision:
+ loss = self.optimizer.get_scaled_loss(loss)
+
+ # Update parameters of the diffusion model.
+ trainable_vars = self.diffusion_model.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+ if self.use_mixed_precision:
+ gradients = self.optimizer.get_unscaled_gradients(gradients)
+ gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
+ half = dim // 2
+ log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
+ freqs = tf.math.exp(
+ -log_max_period * tf.range(0, half, dtype=tf.float32) / half
+ )
+ args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
+ embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
+ embedding = tf.reshape(embedding, [1, -1])
+ return embedding
+
+ def sample_from_encoder_outputs(self, outputs):
+ mean, logvar = tf.split(outputs, 2, axis=-1)
+ logvar = tf.clip_by_value(logvar, -30.0, 20.0)
+ std = tf.exp(0.5 * logvar)
+ sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
+ return mean + std * sample
+
+ def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
+ # Overriding this method will allow us to use the `ModelCheckpoint`
+ # callback directly with this trainer class. In this case, it will
+ # only checkpoint the `diffusion_model` since that's what we're training
+ # during fine-tuning.
+ self.diffusion_model.save_weights(
+ filepath=filepath,
+ overwrite=overwrite,
+ save_format=save_format,
+ options=options,
+ )
+
+
+"""
+One important implementation detail to note here: Instead of directly taking
+the latent vector produced by the image encoder (which is a VAE), we sample from the
+mean and log-variance predicted by it. This way, we can achieve better sample
+quality and diversity.
+
+It's common to add support for mixed-precision training along with exponential
+moving averaging of model weights for fine-tuning these models. However, in the interest
+of brevity, we discard those elements. More on this later in the tutorial.
+"""
+
+"""
+## Initialize the trainer and compile it
+"""
+
+# Enable mixed-precision training if the underlying GPU has tensor cores.
+USE_MP = True
+if USE_MP:
+ keras.mixed_precision.set_global_policy("mixed_float16")
+
+image_encoder = ImageEncoder()
+diffusion_ft_trainer = Trainer(
+ diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
+ # Remove the top layer from the encoder, which cuts off the variance and only
+ # returns the mean.
+ vae=tf.keras.Model(
+ image_encoder.input,
+ image_encoder.layers[-2].output,
+ ),
+ noise_scheduler=NoiseScheduler(),
+ use_mixed_precision=USE_MP,
+)
+
+# These hyperparameters come from this tutorial by Hugging Face:
+# https://huggingface.co/docs/diffusers/training/text2image
+lr = 1e-5
+beta_1, beta_2 = 0.9, 0.999
+weight_decay = 1e-2
+epsilon = 1e-08
+
+optimizer = tf.keras.optimizers.experimental.AdamW(
+ learning_rate=lr,
+ weight_decay=weight_decay,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+)
+diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")
+
+"""
+## Fine-tuning
+
+To keep the runtime of this tutorial short, we just fine-tune for an epoch.
+"""
+
+epochs = 1
+ckpt_path = "finetuned_stable_diffusion.h5"
+ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
+ ckpt_path,
+ save_weights_only=True,
+ monitor="loss",
+ mode="min",
+)
+diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])
+
+"""
+## Inference
+
+We fine-tuned the model for 60 epochs on an image resolution of 512x512. To allow
+training with this resolution, we incorporated mixed-precision support. You can
+check out
+[this repository](https://github.com/sayakpaul/stabe-diffusion-keras-ft)
+for more details. It additionally provides support for exponential moving averaging of
+the fine-tuned model parameters and model checkpointing.
+
+
+For this section, we'll use the checkpoint derived after 60 epochs of fine-tuning.
+"""
+
+weights_path = tf.keras.utils.get_file(
+ origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
+)
+
+img_height = img_width = 512
+pokemon_model = keras_cv.models.StableDiffusion(
+ img_width=img_width, img_height=img_height
+)
+# We just reload the weights of the fine-tuned diffusion model.
+pokemon_model.diffusion_model.load_weights(weights_path)
+
+"""
+Now, we can take this model for a test-drive.
+"""
+
+prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
+images_to_generate = 3
+outputs = {}
+
+for prompt in prompts:
+ generated_images = pokemon_model.text_to_image(
+ prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
+ )
+ outputs.update({prompt: generated_images})
+
+"""
+With 60 epochs of fine-tuning (a good number is about 70), the generated images were not
+up to the mark. So, we experimented with the number of steps Stable Diffusion takes
+during inference and with the `unconditional_guidance_scale` parameter.
+
+We found the best results with this checkpoint with `unconditional_guidance_scale` set to
+40.
+"""
+
+
+def plot_images(images, title):
+ plt.figure(figsize=(20, 20))
+ for i in range(len(images)):
+ ax = plt.subplot(1, len(images), i + 1)
+ plt.imshow(images[i])
+ plt.title(title, fontsize=12)
+ plt.axis("off")
+
+
+for prompt in outputs:
+ plot_images(outputs[prompt], prompt)
+
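+"""
+If you want to probe the effect of the guidance scale yourself, a sketch like the
+following sweeps over a few illustrative values:
+"""
+
+for scale in [7.5, 20, 40]:
+    sweep_images = pokemon_model.text_to_image(
+        "Yoda", batch_size=1, unconditional_guidance_scale=scale
+    )
+    plot_images(sweep_images, f"unconditional_guidance_scale={scale}")
+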
+"""
+We can see that the model has started adapting to the style of our dataset. You can
+check the
+[accompanying repository](https://github.com/sayakpaul/stable-diffusion-keras-ft#results)
+for more comparisons and commentary. If you're feeling adventurous and want to try out a
+demo, you can check out
+[this resource](https://huggingface.co/spaces/sayakpaul/pokemon-sd-kerascv).
+"""
+
+"""
+## Conclusion and acknowledgements
+
+We demonstrated how to fine-tune the Stable Diffusion model on a custom dataset. While
+the results are far from aesthetically pleasing, we believe with more epochs of
+fine-tuning, they will likely improve. To enable that, having support for gradient
+accumulation and distributed training is crucial. This can be thought of as the next step
+in this tutorial.
+
+There is another interesting way in which Stable Diffusion models can be fine-tuned,
+called textual inversion. You can refer to
+[this tutorial](https://keras.io/examples/generative/fine_tune_via_textual_inversion/)
+to know more about it.
+
+We'd like to acknowledge the GCP Credit support from ML Developer Programs' team at
+Google. We'd like to thank the Hugging Face team for providing the
+[fine-tuning script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py),
+which is very readable and easy to understand.
+"""
diff --git a/knowledge_base/generative/gan_ada.py b/knowledge_base/generative/gan_ada.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac79a8209010850f716b7b8c353b435771b80a69
--- /dev/null
+++ b/knowledge_base/generative/gan_ada.py
@@ -0,0 +1,772 @@
+"""
+Title: Data-efficient GANs with Adaptive Discriminator Augmentation
+Author: [Andrรกs Bรฉres](https://www.linkedin.com/in/andras-beres-789190210)
+Date created: 2021/10/28
+Last modified: 2025/01/23
+Description: Generating images from limited data using the Caltech Birds dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+### GANs
+
+[Generative Adversarial Networks (GANs)](https://arxiv.org/abs/1406.2661) are a popular
+class of generative deep learning models, commonly used for image generation. They
+consist of a pair of dueling neural networks, called the discriminator and the generator.
+The discriminator's task is to distinguish real images from generated (fake) ones, while
+the generator network tries to fool the discriminator by generating more and more
+realistic images. If the discriminator is too easy or too hard to fool, however, it might
+fail to provide a useful learning signal for the generator, which is why training GANs is
+usually considered a difficult task.
+
+### Data augmentation for GANs
+
+Data augmentation, a popular technique in deep learning, is the process of randomly
+applying semantics-preserving transformations to the input data to generate multiple
+realistic versions of it, thereby effectively multiplying the amount of training data
+available. The simplest example is left-right flipping an image, which preserves its
+contents while generating a second unique training sample. Data augmentation is commonly
+used in supervised learning to prevent overfitting and enhance generalization.
+
+The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) show that discriminator
+overfitting can be an issue in GANs, especially when only a small amount of training data
+is available. They propose Adaptive Discriminator Augmentation to mitigate this issue.
+
+Applying data augmentation to GANs however is not straightforward. Since the generator is
+updated using the discriminator's gradients, if the generated images are augmented, the
+augmentation pipeline has to be differentiable and also has to be GPU-compatible for
+computational efficiency. Luckily, the
+[Keras image augmentation layers](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
+fulfill both these requirements, and are therefore very well suited for this task.
+
+### Invertible data augmentation
+
+A possible difficulty when using data augmentation in generative models is the issue of
+["leaky augmentations" (section 2.2)](https://arxiv.org/abs/2006.06676), namely when the
+model generates images that are already augmented. This would mean that it was not able
+to separate the augmentation from the underlying data distribution, which can be caused
+by using non-invertible data transformations. For example, if either 0, 90, 180 or 270
+degree rotations are performed with equal probability, the original orientation of the
+images is impossible to infer, and this information is destroyed.
+
+A simple trick to make data augmentations invertible is to only apply them with some
+probability. That way the original version of the images will be more common, and the
+data distribution can be inferred. By properly choosing this probability, one can
+effectively regularize the discriminator without making the augmentations leaky.
+
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras
+from keras import ops
+from keras import layers
+
+"""
+## Hyperparameters
+"""
+
+# data
+num_epochs = 10 # train for 400 epochs for good results
+image_size = 64
+# resolution of Kernel Inception Distance measurement, see related section
+kid_image_size = 75
+padding = 0.25
+dataset_name = "caltech_birds2011"
+
+# adaptive discriminator augmentation
+max_translation = 0.125
+max_rotation = 0.125
+max_zoom = 0.25
+target_accuracy = 0.85
+integration_steps = 1000
+
+# architecture
+noise_size = 64
+depth = 4
+width = 128
+leaky_relu_slope = 0.2
+dropout_rate = 0.4
+
+# optimization
+batch_size = 128
+learning_rate = 2e-4
+beta_1 = 0.5 # not using the default value of 0.9 is important
+ema = 0.99
+
+"""
+## Data pipeline
+
+In this example, we will use the
+[Caltech Birds (2011)](https://www.tensorflow.org/datasets/catalog/caltech_birds2011) dataset for
+generating images of birds, a diverse natural dataset containing fewer than 6000
+images for training. When working with such a small amount of data, one has to take extra
+care to retain as high data quality as possible. In this example, we use the provided
+bounding boxes of the birds to cut them out with square crops while preserving their
+aspect ratios when possible.
+"""
+
+
+def round_to_int(float_value):
+ return ops.cast(ops.round(float_value), "int32")
+
+
+def preprocess_image(data):
+ # unnormalize bounding box coordinates
+ height = ops.cast(ops.shape(data["image"])[0], "float32")
+ width = ops.cast(ops.shape(data["image"])[1], "float32")
+ bounding_box = data["bbox"] * ops.stack([height, width, height, width])
+
+ # calculate center and length of longer side, add padding
+ target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
+ target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
+ target_size = ops.maximum(
+ (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
+ (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
+ )
+
+ # modify crop size to fit into image
+ target_height = ops.min(
+ [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
+ )
+ target_width = ops.min(
+ [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
+ )
+
+ # crop image, `ops.image.crop_images` only works with non-tensor croppings
+ image = ops.slice(
+ data["image"],
+ start_indices=(
+ round_to_int(target_center_y - 0.5 * target_height),
+ round_to_int(target_center_x - 0.5 * target_width),
+ 0,
+ ),
+ shape=(round_to_int(target_height), round_to_int(target_width), 3),
+ )
+
+ # resize and clip
+ image = ops.cast(image, "float32")
+ image = ops.image.resize(image, [image_size, image_size])
+
+ return ops.clip(image / 255.0, 0.0, 1.0)
+
+
+def prepare_dataset(split):
+ # the validation dataset is shuffled as well, because data order matters
+ # for the KID calculation
+ return (
+ tfds.load(dataset_name, split=split, shuffle_files=True)
+ .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
+ .cache()
+ .shuffle(10 * batch_size)
+ .batch(batch_size, drop_remainder=True)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+ )
+
+
+train_dataset = prepare_dataset("train")
+val_dataset = prepare_dataset("test")
+
+"""
+After preprocessing, the training images look like the following:
+
+"""
+
+"""
+## Kernel inception distance
+
+[Kernel Inception Distance (KID)](https://arxiv.org/abs/1801.01401) was proposed as a
+replacement for the popular
+[Frechet Inception Distance (FID)](https://arxiv.org/abs/1706.08500)
+metric for measuring image generation quality.
+Both metrics measure the difference in the generated and training distributions in the
+representation space of an [InceptionV3](https://keras.io/api/applications/inceptionv3/)
+network pretrained on
+[ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012).
+
+According to the paper, KID was proposed because FID has no unbiased estimator: its
+expected value is higher when it is measured on fewer images. KID is more suitable for
+small datasets because its expected value does not depend on the number of samples it is
+measured on. In my experience it is also computationally lighter, numerically more
+stable, and simpler to implement because it can be estimated in a per-batch manner.
+
+In this example, the images are evaluated at the minimal possible resolution of the
+Inception network (75x75 instead of 299x299), and the metric is only measured on the
+validation set for computational efficiency.
+
+
+"""
+
+
+class KID(keras.metrics.Metric):
+ def __init__(self, name="kid", **kwargs):
+ super().__init__(name=name, **kwargs)
+
+ # KID is estimated per batch and is averaged across batches
+ self.kid_tracker = keras.metrics.Mean()
+
+ # a pretrained InceptionV3 is used without its classification layer
+ # transform the pixel values to the 0-255 range, then use the same
+ # preprocessing as during pretraining
+ self.encoder = keras.Sequential(
+ [
+ layers.InputLayer(input_shape=(image_size, image_size, 3)),
+ layers.Rescaling(255.0),
+ layers.Resizing(height=kid_image_size, width=kid_image_size),
+ layers.Lambda(keras.applications.inception_v3.preprocess_input),
+ keras.applications.InceptionV3(
+ include_top=False,
+ input_shape=(kid_image_size, kid_image_size, 3),
+ weights="imagenet",
+ ),
+ layers.GlobalAveragePooling2D(),
+ ],
+ name="inception_encoder",
+ )
+
+ def polynomial_kernel(self, features_1, features_2):
+ feature_dimensions = ops.cast(ops.shape(features_1)[1], "float32")
+ return (
+ features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
+ ) ** 3.0
+
+ def update_state(self, real_images, generated_images, sample_weight=None):
+ real_features = self.encoder(real_images, training=False)
+ generated_features = self.encoder(generated_images, training=False)
+
+ # compute polynomial kernels using the two sets of features
+ kernel_real = self.polynomial_kernel(real_features, real_features)
+ kernel_generated = self.polynomial_kernel(
+ generated_features, generated_features
+ )
+ kernel_cross = self.polynomial_kernel(real_features, generated_features)
+
+ # estimate the squared maximum mean discrepancy using the average kernel values
+ batch_size = ops.shape(real_features)[0]
+ batch_size_f = ops.cast(batch_size, "float32")
+ mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
+ batch_size_f * (batch_size_f - 1.0)
+ )
+ mean_kernel_generated = ops.sum(
+ kernel_generated * (1.0 - ops.eye(batch_size))
+ ) / (batch_size_f * (batch_size_f - 1.0))
+ mean_kernel_cross = ops.mean(kernel_cross)
+ kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross
+
+ # update the average KID estimate
+ self.kid_tracker.update_state(kid)
+
+ def result(self):
+ return self.kid_tracker.result()
+
+ def reset_state(self):
+ self.kid_tracker.reset_state()
+
+
+"""
+
+## Adaptive discriminator augmentation
+
+The authors of [StyleGAN2-ADA](https://arxiv.org/abs/2006.06676) propose to change the
+augmentation probability adaptively during training. Though it is explained differently
+in the paper, they use [integral control](https://en.wikipedia.org/wiki/PID_controller#Integral) on the augmentation
+probability to keep the discriminator's accuracy on real images close to a target value.
+Note that their controlled variable is actually the average sign of the discriminator
+logits (r_t in the paper), which corresponds to 2 * accuracy - 1.
+
+This method requires two hyperparameters:
+
+1. `target_accuracy`: the target value for the discriminator's accuracy on real images. I
+recommend selecting its value from the 80-90% range.
+2. [`integration_steps`](https://en.wikipedia.org/wiki/PID_controller#Mathematical_form):
+the number of update steps required for an accuracy error of 100% to transform into an
+augmentation probability increase of 100%. To give an intuition, this defines how slowly
+the augmentation probability is changed. I recommend setting this to a relatively high
+value (1000 in this case) so that the augmentation strength is only adjusted slowly.
+
+The main motivation for this procedure is that the optimal value of the target accuracy
+is similar across different dataset sizes (see [figure 4 and 5 in the paper](https://arxiv.org/abs/2006.06676)),
+so it does not have to be re-tuned, because the
+process automatically applies stronger data augmentation when it is needed.
+
+"""
+
+
+# "hard sigmoid", useful for binary accuracy calculation from logits
+def step(values):
+ # negative values -> 0.0, positive values -> 1.0
+ return 0.5 * (1.0 + ops.sign(values))
+
+
+# augments images with a probability that is dynamically updated during training
+class AdaptiveAugmenter(keras.Model):
+ def __init__(self):
+ super().__init__()
+
+ # stores the current probability of an image being augmented
+ self.probability = keras.Variable(0.0)
+ self.seed_generator = keras.random.SeedGenerator(42)
+
+ # the corresponding augmentation names from the paper are shown above each layer
+        # the authors show (see figure 4) that the blitting and geometric augmentations
+ # are the most helpful in the low-data regime
+ self.augmenter = keras.Sequential(
+ [
+ layers.InputLayer(input_shape=(image_size, image_size, 3)),
+ # blitting/x-flip:
+ layers.RandomFlip("horizontal"),
+ # blitting/integer translation:
+ layers.RandomTranslation(
+ height_factor=max_translation,
+ width_factor=max_translation,
+ interpolation="nearest",
+ ),
+ # geometric/rotation:
+ layers.RandomRotation(factor=max_rotation),
+ # geometric/isotropic and anisotropic scaling:
+ layers.RandomZoom(
+ height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
+ ),
+ ],
+ name="adaptive_augmenter",
+ )
+
+ def call(self, images, training):
+ if training:
+ augmented_images = self.augmenter(images, training=training)
+
+ # during training either the original or the augmented images are selected
+ # based on self.probability
+ augmentation_values = keras.random.uniform(
+ shape=(batch_size, 1, 1, 1), seed=self.seed_generator
+ )
+ augmentation_bools = ops.less(augmentation_values, self.probability)
+
+ images = ops.where(augmentation_bools, augmented_images, images)
+ return images
+
+ def update(self, real_logits):
+ current_accuracy = ops.mean(step(real_logits))
+
+ # the augmentation probability is updated based on the discriminator's
+ # accuracy on real images
+ accuracy_error = current_accuracy - target_accuracy
+ self.probability.assign(
+ ops.clip(self.probability + accuracy_error / integration_steps, 0.0, 1.0)
+ )
+
+
+"""
+## Network architecture
+
+Here we specify the architecture of the two networks:
+
+* generator: maps a random vector to an image, which should be as realistic as possible
+* discriminator: maps an image to a scalar score, which should be high for real and low
+for generated images
+
+GANs tend to be sensitive to the network architecture, so I implemented a DCGAN
+architecture in this example, because it is relatively stable during training while
+being simple to implement. We use a constant number of filters throughout the network,
+use a sigmoid instead of tanh in the last layer of the generator, and use default
+initialization instead of random normal as further simplifications.
+
+As a good practice, we disable the learnable scale parameter in the batch normalization
+layers: on the one hand, the following relu + convolutional layers make it redundant
+(as noted in the
+[documentation](https://keras.io/api/layers/normalization_layers/batch_normalization/));
+on the other hand, theory says it should be disabled when using [spectral normalization
+(section 4.1)](https://arxiv.org/abs/1802.05957), which is not used here, but is common
+in GANs. We also disable the bias in the fully connected and convolutional layers, because
+the following batch normalization makes it redundant.
+"""
+
+
+# DCGAN generator
+def get_generator():
+ noise_input = keras.Input(shape=(noise_size,))
+ x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
+ x = layers.BatchNormalization(scale=False)(x)
+ x = layers.ReLU()(x)
+ x = layers.Reshape(target_shape=(4, 4, width))(x)
+ for _ in range(depth - 1):
+ x = layers.Conv2DTranspose(
+ width,
+ kernel_size=4,
+ strides=2,
+ padding="same",
+ use_bias=False,
+ )(x)
+ x = layers.BatchNormalization(scale=False)(x)
+ x = layers.ReLU()(x)
+ image_output = layers.Conv2DTranspose(
+ 3,
+ kernel_size=4,
+ strides=2,
+ padding="same",
+ activation="sigmoid",
+ )(x)
+
+ return keras.Model(noise_input, image_output, name="generator")
+
+
+# DCGAN discriminator
+def get_discriminator():
+ image_input = keras.Input(shape=(image_size, image_size, 3))
+ x = image_input
+ for _ in range(depth):
+ x = layers.Conv2D(
+ width,
+ kernel_size=4,
+ strides=2,
+ padding="same",
+ use_bias=False,
+ )(x)
+ x = layers.BatchNormalization(scale=False)(x)
+ x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
+ x = layers.Flatten()(x)
+ x = layers.Dropout(dropout_rate)(x)
+ output_score = layers.Dense(1)(x)
+
+ return keras.Model(image_input, output_score, name="discriminator")
+
+
+"""
+## GAN model
+"""
+
+
+class GAN_ADA(keras.Model):
+ def __init__(self):
+ super().__init__()
+
+ self.seed_generator = keras.random.SeedGenerator(seed=42)
+ self.augmenter = AdaptiveAugmenter()
+ self.generator = get_generator()
+ self.ema_generator = keras.models.clone_model(self.generator)
+ self.discriminator = get_discriminator()
+
+ self.generator.summary()
+ self.discriminator.summary()
+ # we have created all layers at this point, so we can mark the model
+ # as having been built
+ self.built = True
+
+ def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
+ super().compile(**kwargs)
+
+ # separate optimizers for the two networks
+ self.generator_optimizer = generator_optimizer
+ self.discriminator_optimizer = discriminator_optimizer
+
+ self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
+ self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
+ self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
+ self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
+ self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
+ self.kid = KID()
+
+ @property
+ def metrics(self):
+ return [
+ self.generator_loss_tracker,
+ self.discriminator_loss_tracker,
+ self.real_accuracy,
+ self.generated_accuracy,
+ self.augmentation_probability_tracker,
+ self.kid,
+ ]
+
+ def generate(self, batch_size, training):
+ latent_samples = keras.random.normal(
+ shape=(batch_size, noise_size), seed=self.seed_generator
+ )
+ # use ema_generator during inference
+ if training:
+ generated_images = self.generator(latent_samples, training=training)
+ else:
+ generated_images = self.ema_generator(latent_samples, training=training)
+ return generated_images
+
+ def adversarial_loss(self, real_logits, generated_logits):
+ # this is usually called the non-saturating GAN loss
+
+ real_labels = ops.ones(shape=(batch_size, 1))
+ generated_labels = ops.zeros(shape=(batch_size, 1))
+
+ # the generator tries to produce images that the discriminator considers as real
+ generator_loss = keras.losses.binary_crossentropy(
+ real_labels, generated_logits, from_logits=True
+ )
+ # the discriminator tries to determine if images are real or generated
+ discriminator_loss = keras.losses.binary_crossentropy(
+ ops.concatenate([real_labels, generated_labels], axis=0),
+ ops.concatenate([real_logits, generated_logits], axis=0),
+ from_logits=True,
+ )
+
+ return ops.mean(generator_loss), ops.mean(discriminator_loss)
+
+ def train_step(self, real_images):
+ real_images = self.augmenter(real_images, training=True)
+
+ # use persistent gradient tape because gradients will be calculated twice
+ with tf.GradientTape(persistent=True) as tape:
+ generated_images = self.generate(batch_size, training=True)
+ # gradient is calculated through the image augmentation
+ generated_images = self.augmenter(generated_images, training=True)
+
+ # separate forward passes for the real and generated images, meaning
+ # that batch normalization is applied separately
+ real_logits = self.discriminator(real_images, training=True)
+ generated_logits = self.discriminator(generated_images, training=True)
+
+ generator_loss, discriminator_loss = self.adversarial_loss(
+ real_logits, generated_logits
+ )
+
+ # calculate gradients and update weights
+ generator_gradients = tape.gradient(
+ generator_loss, self.generator.trainable_weights
+ )
+ discriminator_gradients = tape.gradient(
+ discriminator_loss, self.discriminator.trainable_weights
+ )
+ self.generator_optimizer.apply_gradients(
+ zip(generator_gradients, self.generator.trainable_weights)
+ )
+ self.discriminator_optimizer.apply_gradients(
+ zip(discriminator_gradients, self.discriminator.trainable_weights)
+ )
+
+ # update the augmentation probability based on the discriminator's performance
+ self.augmenter.update(real_logits)
+
+ self.generator_loss_tracker.update_state(generator_loss)
+ self.discriminator_loss_tracker.update_state(discriminator_loss)
+ self.real_accuracy.update_state(1.0, step(real_logits))
+ self.generated_accuracy.update_state(0.0, step(generated_logits))
+ self.augmentation_probability_tracker.update_state(self.augmenter.probability)
+
+ # track the exponential moving average of the generator's weights to decrease
+ # variance in the generation quality
+ for weight, ema_weight in zip(
+ self.generator.weights, self.ema_generator.weights
+ ):
+ ema_weight.assign(ema * ema_weight + (1 - ema) * weight)
+
+ # KID is not measured during the training phase for computational efficiency
+ return {m.name: m.result() for m in self.metrics[:-1]}
+
+ def test_step(self, real_images):
+ generated_images = self.generate(batch_size, training=False)
+
+ self.kid.update_state(real_images, generated_images)
+
+ # only KID is measured during the evaluation phase for computational efficiency
+ return {self.kid.name: self.kid.result()}
+
+ def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
+ # plot random generated images for visual evaluation of generation quality
+ if epoch is None or (epoch + 1) % interval == 0:
+ num_images = num_rows * num_cols
+ generated_images = self.generate(num_images, training=False)
+
+ plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
+ for row in range(num_rows):
+ for col in range(num_cols):
+ index = row * num_cols + col
+ plt.subplot(num_rows, num_cols, index + 1)
+ plt.imshow(generated_images[index])
+ plt.axis("off")
+ plt.tight_layout()
+ plt.show()
+ plt.close()
+
+
+"""
+## Training
+
+One should see from the metrics during training that if the real accuracy
+(the discriminator's accuracy on real images) is below the target accuracy, the
+augmentation probability is increased, and vice versa. In my experience, during
+a healthy GAN training, the discriminator accuracy should stay in the 80-95%
+range. Below that, the discriminator is too weak; above that, it is too strong.
+
+Note that we track the exponential moving average of the generator's weights, and use that
+for image generation and KID evaluation.
+"""
+
+# create and compile the model
+model = GAN_ADA()
+model.compile(
+ generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
+ discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
+)
+
+# save the best model based on the validation KID metric
+checkpoint_path = "gan_model.weights.h5"
+checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_path,
+ save_weights_only=True,
+ monitor="val_kid",
+ mode="min",
+ save_best_only=True,
+)
+
+# run training and plot generated images periodically
+model.fit(
+ train_dataset,
+ epochs=num_epochs,
+ validation_data=val_dataset,
+ callbacks=[
+ keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
+ checkpoint_callback,
+ ],
+)
+
+"""
+## Inference
+"""
+
+# load the best model and generate images
+model.load_weights(checkpoint_path)
+model.plot_images()
+
+"""
+## Results
+
+By running the training for 400 epochs (which takes 2-3 hours in a Colab
+notebook), one can get high-quality image generations using this code example.
+
+The evolution of a random batch of images over a 400-epoch training (ema=0.999
+for animation smoothness):
+
+
+Latent-space interpolation between a batch of selected images (a sketch of how
+to generate such an interpolation follows at the end of this section):
+
+
+I also recommend trying out training on other datasets, such as
+[CelebA](https://www.tensorflow.org/datasets/catalog/celeb_a). In my
+experience, good results can be achieved without changing any hyperparameters
+(though discriminator augmentation might not be necessary).
+"""
+
+"""
+## GAN tips and tricks
+
+My goal with this example was to find a good tradeoff between ease of implementation and
+generation quality for GANs. During preparation, I ran numerous ablations using
+[this repository](https://github.com/beresandras/gan-flavours-keras).
+
+In this section I list the lessons learned and my recommendations in my subjective order
+of importance.
+
+I recommend checking out the [DCGAN paper](https://arxiv.org/abs/1511.06434), this
+[NeurIPS talk](https://www.youtube.com/watch?v=myGAju4L7O8), and this
+[large scale GAN study](https://arxiv.org/abs/1711.10337) for others' takes on this subject.
+
+### Architectural tips
+
+* **resolution**: Training GANs at higher resolutions tends to get more difficult; I
+recommend experimenting at 32x32 or 64x64 resolutions initially.
+* **initialization**: If you see strong colorful patterns early on in the training, the
+initialization might be the issue. Set the kernel_initializer parameters of layers to
+[random normal](https://keras.io/api/layers/initializers/#randomnormal-class), and
+decrease the standard deviation (recommended value: 0.02, following DCGAN) until the
+issue disappears.
+* **upsampling**: There are two main methods for upsampling in the generator.
+[Transposed convolution](https://keras.io/api/layers/convolution_layers/convolution2d_transpose/)
+is faster, but can lead to
+[checkerboard artifacts](https://distill.pub/2016/deconv-checkerboard/), which can be reduced by using
+a kernel size that is divisible by the stride (the recommended kernel size is 4 for a stride of 2).
+[Upsampling](https://keras.io/api/layers/reshaping_layers/up_sampling2d/) +
+[standard convolution](https://keras.io/api/layers/convolution_layers/convolution2d/) can have slightly
+lower quality, but checkerboard artifacts are not an issue. I recommend using nearest-neighbor
+interpolation over bilinear for it (a sketch of this alternative follows after this list).
+* **batch normalization in discriminator**: Sometimes has a high impact; I recommend
+trying it out both ways.
+* **[spectral normalization](https://www.tensorflow.org/addons/api_docs/python/tfa/layers/SpectralNormalization)**:
+A popular technique for training GANs that can help with stability. I recommend
+disabling batch normalization's learnable scale parameters along with it.
+* **[residual connections](https://keras.io/guides/functional_api/#a-toy-resnet-model)**:
+While residual discriminators behave similarly, residual generators are more difficult to
+train in my experience. They are however necessary for training large and deep
+architectures. I recommend starting with non-residual architectures.
+* **dropout**: Using dropout before the last layer of the discriminator improves
+generation quality in my experience. Recommended dropout rate is below 0.5.
+* **[leaky ReLU](https://keras.io/api/layers/activation_layers/leaky_relu/)**: Use leaky
+ReLU activations in the discriminator to make its gradients less sparse. Recommended
+slope/alpha is 0.2 following DCGAN.
+
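+As an illustration of the upsampling tip above, here is a hedged sketch of an
+upsampling block built from nearest-neighbor `UpSampling2D` and a standard
+convolution (an alternative to the transposed convolutions used in this
+example; it is not part of the trained model):
+"""
+
+
+def upsample_block_sketch(x, block_width):
+    # nearest-neighbor upsampling followed by a standard convolution avoids
+    # checkerboard artifacts, at a slightly higher computational cost
+    x = layers.UpSampling2D(size=2, interpolation="nearest")(x)
+    x = layers.Conv2D(block_width, kernel_size=3, padding="same", use_bias=False)(x)
+    x = layers.BatchNormalization(scale=False)(x)
+    return layers.ReLU()(x)
+
+
+"""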
+### Algorithmic tips
+
+* **loss functions**: Numerous losses have been proposed over the years for training
+GANs, promising improved performance and stability. I have implemented 5 of them in
+[this repository](https://github.com/beresandras/gan-flavours-keras), and my experience is in
+line with [this GAN study](https://arxiv.org/abs/1711.10337): no loss seems to
+consistently outperform the default non-saturating GAN loss. I recommend using that as a
+default.
+* **Adam's beta_1 parameter**: The beta_1 parameter in Adam can be interpreted as the
+momentum of mean gradient estimation. Using 0.5 or even 0.0 instead of the default 0.9
+value was proposed in DCGAN and is important. This example would not work using its
+default value.
+* **separate batch normalization for generated and real images**: The forward pass of the
+discriminator should be separate for the generated and real images. Doing otherwise can
+lead to artifacts (45 degree stripes in my case) and decreased performance.
+* **exponential moving average of generator's weights**: This helps to reduce the
+variance of the KID measurement, and helps in averaging out the rapid color palette
+changes during training.
+* **[different learning rate for generator and discriminator](https://arxiv.org/abs/1706.08500)**:
+If one has the resources, it can help
+to tune the learning rates of the two networks separately. A similar idea is to update
+either network's (usually the discriminator's) weights multiple times for each of the
+other network's updates. I recommend using the same learning rate of 2e-4 (Adam),
+following DCGAN for both networks, and only updating both of them once as a default.
+* **label noise**: [One-sided label smoothing](https://arxiv.org/abs/1606.03498) (using
+less than 1.0 for real labels), or adding noise to the labels, can regularize the
+discriminator so that it does not become overconfident; however, in my case neither
+improved performance (a small sketch follows after this list).
+* **adaptive data augmentation**: Since it adds another dynamic component to the training
+process, disable it as a default, and only enable it when the other components already
+work well.
+"""
+
+"""
+## Related works
+
+Other GAN-related Keras code examples:
+
+* [DCGAN + CelebA](https://keras.io/examples/generative/dcgan_overriding_train_step/)
+* [WGAN + FashionMNIST](https://keras.io/examples/generative/wgan_gp/)
+* [WGAN + Molecules](https://keras.io/examples/generative/wgan-graphs/)
+* [ConditionalGAN + MNIST](https://keras.io/examples/generative/conditional_gan/)
+* [CycleGAN + Horse2Zebra](https://keras.io/examples/generative/cyclegan/)
+* [StyleGAN](https://keras.io/examples/generative/stylegan/)
+
+Modern GAN architecture lines:
+
+* [SAGAN](https://arxiv.org/abs/1805.08318), [BigGAN](https://arxiv.org/abs/1809.11096)
+* [ProgressiveGAN](https://arxiv.org/abs/1710.10196),
+[StyleGAN](https://arxiv.org/abs/1812.04948),
+[StyleGAN2](https://arxiv.org/abs/1912.04958),
+[StyleGAN2-ADA](https://arxiv.org/abs/2006.06676),
+[AliasFreeGAN](https://arxiv.org/abs/2106.12423)
+
+Concurrent papers on discriminator data augmentation:
+[1](https://arxiv.org/abs/2006.02595), [2](https://arxiv.org/abs/2006.05338), [3](https://arxiv.org/abs/2006.10738)
+
+Recent literature overview on GANs: [talk](https://www.youtube.com/watch?v=3ktD752xq5k)
+"""
diff --git a/knowledge_base/generative/gaugan.py b/knowledge_base/generative/gaugan.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c45f6795dfcfa432b73757c178ba913f786b39a
--- /dev/null
+++ b/knowledge_base/generative/gaugan.py
@@ -0,0 +1,845 @@
+"""
+Title: GauGAN for conditional image generation
+Author: [Soumik Rakshit](https://github.com/soumik12345), [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/12/26
+Last modified: 2022/01/03
+Description: Implementing a GauGAN for conditional image generation.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we present an implementation of the GauGAN architecture proposed in
+[Semantic Image Synthesis with Spatially-Adaptive Normalization](https://arxiv.org/abs/1903.07291).
+Briefly, GauGAN uses a Generative Adversarial Network (GAN) to generate realistic images
+that are conditioned on cue images and segmentation maps, as shown below
+([image source](https://nvlabs.github.io/SPADE/)):
+
+
+
+The main components of a GauGAN are:
+
+- **SPADE (aka spatially-adaptive normalization)**: The authors of GauGAN argue that the
+more conventional normalization layers (such as
+[Batch Normalization](https://arxiv.org/abs/1502.03167))
+destroy the semantic information obtained from segmentation maps that
+are provided as inputs. To address this problem, the authors introduce SPADE, a
+normalization layer particularly suitable for learning affine parameters (scale and bias)
+that are spatially adaptive. This is done by learning different sets of scaling and
+bias parameters for each semantic label.
+- **Variational encoder**: Inspired by
+[Variational Autoencoders](https://arxiv.org/abs/1312.6114), GauGAN uses a
+variational formulation wherein an encoder learns the mean and variance of a
+normal (Gaussian) distribution from the cue images. This is where GauGAN gets its name
+from. The generator of GauGAN takes as inputs the latents sampled from the Gaussian
+distribution as well as the one-hot encoded semantic segmentation label maps. The cue
+images act as style images that guide the generator to stylistic generation. This
+variational formulation helps GauGAN achieve image diversity as well as fidelity.
+- **Multi-scale patch discriminator**: Inspired by the
+[PatchGAN](https://paperswithcode.com/method/patchgan) model,
+GauGAN uses a discriminator that assesses a given image on a patch basis
+and produces an averaged score.
+
+As we proceed with the example, we will discuss each of the different
+components in further detail.
+
+For a thorough review of GauGAN, please refer to
+[this article](https://blog.paperspace.com/nvidia-gaugan-introduction/).
+We also encourage you to check out
+[the official GauGAN website](https://nvlabs.github.io/SPADE/), which
+has many creative applications of GauGAN. This example assumes that the reader is already
+familiar with the fundamental concepts of GANs. If you need a refresher, the following
+resources might be useful:
+
+* [Chapter on GANs](https://livebook.manning.com/book/deep-learning-with-python/chapter-8)
+from the Deep Learning with Python book by François Chollet.
+* GAN implementations on keras.io:
+
+ * [Data efficient GANs](https://keras.io/examples/generative/gan_ada)
+ * [CycleGAN](https://keras.io/examples/generative/cyclegan)
+ * [Conditional GAN](https://keras.io/examples/generative/conditional_gan)
+"""
+
+"""
+## Data collection
+
+We will be using the
+[Facades dataset](https://cmp.felk.cvut.cz/~tylecr1/facade/)
+for training our GauGAN model. Let's first download it.
+"""
+
+"""shell
+wget https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj -O facades_data.zip
+unzip -q facades_data.zip
+"""
+
+"""
+## Imports
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+import keras
+from keras import ops
+from keras import layers
+
+from glob import glob
+
+"""
+## Data splitting
+"""
+
+PATH = "./facades_data/"
+SPLIT = 0.2
+
+files = glob(PATH + "*.jpg")
+np.random.shuffle(files)
+
+split_index = int(len(files) * (1 - SPLIT))
+train_files = files[:split_index]
+val_files = files[split_index:]
+
+print(f"Total samples: {len(files)}.")
+print(f"Total training samples: {len(train_files)}.")
+print(f"Total validation samples: {len(val_files)}.")
+
+"""
+## Data loader
+"""
+
+BATCH_SIZE = 4
+IMG_HEIGHT = IMG_WIDTH = 256
+NUM_CLASSES = 12
+AUTOTUNE = tf.data.AUTOTUNE
+
+
+def load(image_files, batch_size, is_train=True):
+ def _random_crop(
+ segmentation_map,
+ image,
+ labels,
+ crop_size=(IMG_HEIGHT, IMG_WIDTH),
+ ):
+ crop_size = tf.convert_to_tensor(crop_size)
+ image_shape = tf.shape(image)[:2]
+ margins = image_shape - crop_size
+ y1 = tf.random.uniform(shape=(), maxval=margins[0], dtype=tf.int32)
+ x1 = tf.random.uniform(shape=(), maxval=margins[1], dtype=tf.int32)
+ y2 = y1 + crop_size[0]
+ x2 = x1 + crop_size[1]
+
+ cropped_images = []
+ images = [segmentation_map, image, labels]
+ for img in images:
+ cropped_images.append(img[y1:y2, x1:x2])
+ return cropped_images
+
+ def _load_data_tf(image_file, segmentation_map_file, label_file):
+ image = tf.image.decode_png(tf.io.read_file(image_file), channels=3)
+ segmentation_map = tf.image.decode_png(
+ tf.io.read_file(segmentation_map_file), channels=3
+ )
+ labels = tf.image.decode_bmp(tf.io.read_file(label_file), channels=0)
+ labels = tf.squeeze(labels)
+
+ image = tf.cast(image, tf.float32) / 127.5 - 1
+ segmentation_map = tf.cast(segmentation_map, tf.float32) / 127.5 - 1
+ return segmentation_map, image, labels
+
+ def _one_hot(segmentation_maps, real_images, labels):
+ labels = tf.one_hot(labels, NUM_CLASSES)
+ labels.set_shape((None, None, NUM_CLASSES))
+ return segmentation_maps, real_images, labels
+
+ segmentation_map_files = [
+ image_file.replace("images", "segmentation_map").replace("jpg", "png")
+ for image_file in image_files
+ ]
+ label_files = [
+ image_file.replace("images", "segmentation_labels").replace("jpg", "bmp")
+ for image_file in image_files
+ ]
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (image_files, segmentation_map_files, label_files)
+ )
+
+ dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
+ dataset = dataset.map(_load_data_tf, num_parallel_calls=AUTOTUNE)
+ dataset = dataset.map(_random_crop, num_parallel_calls=AUTOTUNE)
+ dataset = dataset.map(_one_hot, num_parallel_calls=AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ return dataset
+
+
+train_dataset = load(train_files, batch_size=BATCH_SIZE, is_train=True)
+val_dataset = load(val_files, batch_size=BATCH_SIZE, is_train=False)
+
+"""
+Now, let's visualize a few samples from the training set.
+"""
+
+sample_train_batch = next(iter(train_dataset))
+print(f"Segmentation map batch shape: {sample_train_batch[0].shape}.")
+print(f"Image batch shape: {sample_train_batch[1].shape}.")
+print(f"One-hot encoded label map shape: {sample_train_batch[2].shape}.")
+
+# Plot a few samples from the training set.
+for segmentation_map, real_image in zip(sample_train_batch[0], sample_train_batch[1]):
+ fig = plt.figure(figsize=(10, 10))
+ fig.add_subplot(1, 2, 1).set_title("Segmentation Map")
+ plt.imshow((segmentation_map + 1) / 2)
+ fig.add_subplot(1, 2, 2).set_title("Real Image")
+ plt.imshow((real_image + 1) / 2)
+ plt.show()
+
+"""
+Note that in the rest of this example, we use a couple of figures from the
+[original GauGAN paper](https://arxiv.org/abs/1903.07291) for convenience.
+"""
+
+"""
+## Custom layers
+
+In the following section, we implement the following layers:
+
+* SPADE
+* Residual block including SPADE
+* Gaussian sampler
+"""
+
+"""
+### Some more notes on SPADE
+
+
+
+**SPatially-Adaptive (DE) normalization** or **SPADE** is a simple but effective layer
+for synthesizing photorealistic images given an input semantic layout. Previous methods
+for conditional image generation from semantic input such as
+Pix2Pix ([Isola et al.](https://arxiv.org/abs/1611.07004))
+or Pix2PixHD ([Wang et al.](https://arxiv.org/abs/1711.11585))
+directly feed the semantic layout as input to the deep network, which is then processed
+through stacks of convolution, normalization, and nonlinearity layers. This is often
+suboptimal as the normalization layers have a tendency to wash away semantic information.
+
+In SPADE, the segmentation mask is first projected onto an embedding space, and then
+convolved to produce the modulation parameters `γ` and `β`. Unlike prior conditional
+normalization methods, `γ` and `β` are not vectors, but tensors with spatial dimensions.
+The produced `γ` and `β` are multiplied and added to the normalized activation
+element-wise. As the modulation parameters are adaptive to the input segmentation mask,
+SPADE is better suited for semantic image synthesis.
+"""
+
+
+class SPADE(layers.Layer):
+ def __init__(self, filters, epsilon=1e-5, **kwargs):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+ self.conv = layers.Conv2D(128, 3, padding="same", activation="relu")
+ self.conv_gamma = layers.Conv2D(filters, 3, padding="same")
+ self.conv_beta = layers.Conv2D(filters, 3, padding="same")
+
+ def build(self, input_shape):
+ self.resize_shape = input_shape[1:3]
+
+ def call(self, input_tensor, raw_mask):
+ mask = ops.image.resize(raw_mask, self.resize_shape, interpolation="nearest")
+ x = self.conv(mask)
+ gamma = self.conv_gamma(x)
+ beta = self.conv_beta(x)
+ mean, var = ops.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
+ std = ops.sqrt(var + self.epsilon)
+ normalized = (input_tensor - mean) / std
+ output = gamma * normalized + beta
+ return output
+
+
+class ResBlock(layers.Layer):
+ def __init__(self, filters, **kwargs):
+ super().__init__(**kwargs)
+ self.filters = filters
+
+ def build(self, input_shape):
+ input_filter = input_shape[-1]
+ self.spade_1 = SPADE(input_filter)
+ self.spade_2 = SPADE(self.filters)
+ self.conv_1 = layers.Conv2D(self.filters, 3, padding="same")
+ self.conv_2 = layers.Conv2D(self.filters, 3, padding="same")
+ self.learned_skip = False
+
+ if self.filters != input_filter:
+ self.learned_skip = True
+ self.spade_3 = SPADE(input_filter)
+ self.conv_3 = layers.Conv2D(self.filters, 3, padding="same")
+
+ def call(self, input_tensor, mask):
+ x = self.spade_1(input_tensor, mask)
+ x = self.conv_1(keras.activations.leaky_relu(x, 0.2))
+ x = self.spade_2(x, mask)
+ x = self.conv_2(keras.activations.leaky_relu(x, 0.2))
+ skip = (
+ self.conv_3(
+ keras.activations.leaky_relu(self.spade_3(input_tensor, mask), 0.2)
+ )
+ if self.learned_skip
+ else input_tensor
+ )
+ output = skip + x
+ return output
+
+
+class GaussianSampler(layers.Layer):
+ def __init__(self, batch_size, latent_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.batch_size = batch_size
+ self.latent_dim = latent_dim
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def call(self, inputs):
+ means, variance = inputs
+ epsilon = keras.random.normal(
+ shape=(self.batch_size, self.latent_dim),
+ mean=0.0,
+ stddev=1.0,
+ seed=self.seed_generator,
+ )
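+        # reparameterization trick: `variance` is treated as a log-variance,
+        # so exp(0.5 * variance) is the standard deviation of the sample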
+ samples = means + ops.exp(0.5 * variance) * epsilon
+ return samples
+
+
+"""
+Next, we implement the downsampling block for the encoder.
+"""
+
+
+def downsample(
+ channels,
+ kernels,
+ strides=2,
+ apply_norm=True,
+ apply_activation=True,
+ apply_dropout=False,
+):
+ block = keras.Sequential()
+ block.add(
+ layers.Conv2D(
+ channels,
+ kernels,
+ strides=strides,
+ padding="same",
+ use_bias=False,
+ kernel_initializer=keras.initializers.GlorotNormal(),
+ )
+ )
+ if apply_norm:
+ block.add(layers.GroupNormalization(groups=-1))
+ if apply_activation:
+ block.add(layers.LeakyReLU(0.2))
+ if apply_dropout:
+ block.add(layers.Dropout(0.5))
+ return block
+
+
+"""
+The GauGAN encoder consists of a few downsampling blocks. It outputs the mean and
+variance of a distribution.
+
+
+
+"""
+
+
+def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):
+ input_image = keras.Input(shape=image_shape)
+ x = downsample(encoder_downsample_factor, 3, apply_norm=False)(input_image)
+ x = downsample(2 * encoder_downsample_factor, 3)(x)
+ x = downsample(4 * encoder_downsample_factor, 3)(x)
+ x = downsample(8 * encoder_downsample_factor, 3)(x)
+ x = downsample(8 * encoder_downsample_factor, 3)(x)
+ x = layers.Flatten()(x)
+ mean = layers.Dense(latent_dim, name="mean")(x)
+ variance = layers.Dense(latent_dim, name="variance")(x)
+ return keras.Model(input_image, [mean, variance], name="encoder")
+
+
+"""
+Next, we implement the generator, which consists of the modified residual blocks and
+upsampling blocks. It takes latent vectors and one-hot encoded segmentation labels, and
+produces new images.
+
+
+
+With SPADE, there is no need to feed the segmentation map to the first layer of the
+generator, since the latent inputs have enough structural information about the style we
+want the generator to emulate. We also discard the encoder part of the generator, which is
+commonly used in prior architectures. This results in a more lightweight
+generator network, which can also take a random vector as input, enabling a simple and
+natural path to multi-modal synthesis.
+"""
+
+
+def build_generator(mask_shape, latent_dim=256):
+ latent = keras.Input(shape=(latent_dim,))
+ mask = keras.Input(shape=mask_shape)
+ x = layers.Dense(16384)(latent)
+ x = layers.Reshape((4, 4, 1024))(x)
+ x = ResBlock(filters=1024)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = ResBlock(filters=1024)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = ResBlock(filters=1024)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = ResBlock(filters=512)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = ResBlock(filters=256)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = ResBlock(filters=128)(x, mask)
+ x = layers.UpSampling2D((2, 2))(x)
+ x = keras.activations.leaky_relu(x, 0.2)
+ output_image = keras.activations.tanh(layers.Conv2D(3, 4, padding="same")(x))
+ return keras.Model([latent, mask], output_image, name="generator")
+
+
+"""
+The discriminator takes a segmentation map and an image and concatenates them. It
+then predicts if patches of the concatenated image are real or fake.
+
+
+"""
+
+
+def build_discriminator(image_shape, downsample_factor=64):
+ input_image_A = keras.Input(shape=image_shape, name="discriminator_image_A")
+ input_image_B = keras.Input(shape=image_shape, name="discriminator_image_B")
+ x = layers.Concatenate()([input_image_A, input_image_B])
+ x1 = downsample(downsample_factor, 4, apply_norm=False)(x)
+ x2 = downsample(2 * downsample_factor, 4)(x1)
+ x3 = downsample(4 * downsample_factor, 4)(x2)
+ x4 = downsample(8 * downsample_factor, 4, strides=1)(x3)
+ x5 = layers.Conv2D(1, 4)(x4)
+ outputs = [x1, x2, x3, x4, x5]
+ return keras.Model([input_image_A, input_image_B], outputs)
+
+
+"""
+## Loss functions
+
+GauGAN uses the following loss functions:
+
+* Generator:
+
+ * Expectation over the discriminator predictions.
+ * [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
+ for learning the mean and variance predicted by the encoder.
+ * Minimization between the discriminator predictions on original and generated
+ images to align the feature space of the generator.
+ * [Perceptual loss](https://arxiv.org/abs/1603.08155) for encouraging the generated
+ images to have perceptual quality.
+
+* Discriminator:
+
+ * [Hinge loss](https://en.wikipedia.org/wiki/Hinge_loss).
+"""
+
+
+def generator_loss(y):
+ return -ops.mean(y)
+
+
+def kl_divergence_loss(mean, variance):
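+    # standard VAE KL term against a unit Gaussian; as in `GaussianSampler`,
+    # the encoder's `variance` output is interpreted as a log-variance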
+ return -0.5 * ops.sum(1 + variance - ops.square(mean) - ops.exp(variance))
+
+
+class FeatureMatchingLoss(keras.losses.Loss):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.mae = keras.losses.MeanAbsoluteError()
+
+ def call(self, y_true, y_pred):
+ loss = 0
+ for i in range(len(y_true) - 1):
+ loss += self.mae(y_true[i], y_pred[i])
+ return loss
+
+
+class VGGFeatureMatchingLoss(keras.losses.Loss):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.encoder_layers = [
+ "block1_conv1",
+ "block2_conv1",
+ "block3_conv1",
+ "block4_conv1",
+ "block5_conv1",
+ ]
+ self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
+ vgg = keras.applications.VGG19(include_top=False, weights="imagenet")
+ layer_outputs = [vgg.get_layer(x).output for x in self.encoder_layers]
+ self.vgg_model = keras.Model(vgg.input, layer_outputs, name="VGG")
+ self.mae = keras.losses.MeanAbsoluteError()
+
+ def call(self, y_true, y_pred):
+ y_true = keras.applications.vgg19.preprocess_input(127.5 * (y_true + 1))
+ y_pred = keras.applications.vgg19.preprocess_input(127.5 * (y_pred + 1))
+ real_features = self.vgg_model(y_true)
+ fake_features = self.vgg_model(y_pred)
+ loss = 0
+ for i in range(len(real_features)):
+ loss += self.weights[i] * self.mae(real_features[i], fake_features[i])
+ return loss
+
+
+class DiscriminatorLoss(keras.losses.Loss):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.hinge_loss = keras.losses.Hinge()
+
+ def call(self, y, is_real):
+ return self.hinge_loss(is_real, y)
+
+
+"""
+## GAN monitor callback
+
+Next, we implement a callback to monitor the GauGAN results while it is training.
+"""
+
+
+class GanMonitor(keras.callbacks.Callback):
+    def __init__(self, val_dataset, n_samples, epoch_interval=5):
+        super().__init__()
+        self.val_images = next(iter(val_dataset))
+ self.n_samples = n_samples
+ self.epoch_interval = epoch_interval
+ self.seed_generator = keras.random.SeedGenerator(42)
+
+ def infer(self):
+ latent_vector = keras.random.normal(
+ shape=(self.model.batch_size, self.model.latent_dim),
+ mean=0.0,
+ stddev=2.0,
+ seed=self.seed_generator,
+ )
+ return self.model.predict([latent_vector, self.val_images[2]])
+
+ def on_epoch_end(self, epoch, logs=None):
+ if epoch % self.epoch_interval == 0:
+ generated_images = self.infer()
+ for _ in range(self.n_samples):
+ grid_row = min(generated_images.shape[0], 3)
+ f, axarr = plt.subplots(grid_row, 3, figsize=(18, grid_row * 6))
+ for row in range(grid_row):
+ ax = axarr if grid_row == 1 else axarr[row]
+ ax[0].imshow((self.val_images[0][row] + 1) / 2)
+ ax[0].axis("off")
+ ax[0].set_title("Mask", fontsize=20)
+ ax[1].imshow((self.val_images[1][row] + 1) / 2)
+ ax[1].axis("off")
+ ax[1].set_title("Ground Truth", fontsize=20)
+ ax[2].imshow((generated_images[row] + 1) / 2)
+ ax[2].axis("off")
+ ax[2].set_title("Generated", fontsize=20)
+ plt.show()
+
+
+"""
+## Subclassed GauGAN model
+
+Finally, we put everything together inside a subclassed model (from `keras.Model`),
+overriding its `train_step()` method.
+"""
+
+
+class GauGAN(keras.Model):
+ def __init__(
+ self,
+ image_size,
+ num_classes,
+ batch_size,
+ latent_dim,
+ feature_loss_coeff=10,
+ vgg_feature_loss_coeff=0.1,
+ kl_divergence_loss_coeff=0.1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.latent_dim = latent_dim
+ self.batch_size = batch_size
+ self.num_classes = num_classes
+ self.image_shape = (image_size, image_size, 3)
+ self.mask_shape = (image_size, image_size, num_classes)
+ self.feature_loss_coeff = feature_loss_coeff
+ self.vgg_feature_loss_coeff = vgg_feature_loss_coeff
+ self.kl_divergence_loss_coeff = kl_divergence_loss_coeff
+
+ self.discriminator = build_discriminator(self.image_shape)
+ self.generator = build_generator(self.mask_shape)
+ self.encoder = build_encoder(self.image_shape)
+ self.sampler = GaussianSampler(batch_size, latent_dim)
+ self.patch_size, self.combined_model = self.build_combined_generator()
+
+ self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
+ self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
+ self.feat_loss_tracker = keras.metrics.Mean(name="feat_loss")
+ self.vgg_loss_tracker = keras.metrics.Mean(name="vgg_loss")
+ self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
+
+ @property
+ def metrics(self):
+ return [
+ self.disc_loss_tracker,
+ self.gen_loss_tracker,
+ self.feat_loss_tracker,
+ self.vgg_loss_tracker,
+ self.kl_loss_tracker,
+ ]
+
+ def build_combined_generator(self):
+ # This method builds a model that takes as inputs the following:
+ # latent vector, one-hot encoded segmentation label map, and
+ # a segmentation map. It then (i) generates an image with the generator,
+ # (ii) passes the generated images and segmentation map to the discriminator.
+ # Finally, the model produces the following outputs: (a) discriminator outputs,
+ # (b) generated image.
+ # We will be using this model to simplify the implementation.
+ self.discriminator.trainable = False
+ mask_input = keras.Input(shape=self.mask_shape, name="mask")
+ image_input = keras.Input(shape=self.image_shape, name="image")
+ latent_input = keras.Input(shape=(self.latent_dim,), name="latent")
+ generated_image = self.generator([latent_input, mask_input])
+ discriminator_output = self.discriminator([image_input, generated_image])
+ combined_outputs = discriminator_output + [generated_image]
+ patch_size = discriminator_output[-1].shape[1]
+ combined_model = keras.Model(
+ [latent_input, mask_input, image_input], combined_outputs
+ )
+ return patch_size, combined_model
+
+ def compile(self, gen_lr=1e-4, disc_lr=4e-4, **kwargs):
+ super().compile(**kwargs)
+ self.generator_optimizer = keras.optimizers.Adam(
+ gen_lr, beta_1=0.0, beta_2=0.999
+ )
+ self.discriminator_optimizer = keras.optimizers.Adam(
+ disc_lr, beta_1=0.0, beta_2=0.999
+ )
+ self.discriminator_loss = DiscriminatorLoss()
+ self.feature_matching_loss = FeatureMatchingLoss()
+ self.vgg_loss = VGGFeatureMatchingLoss()
+
+ def train_discriminator(self, latent_vector, segmentation_map, real_image, labels):
+ fake_images = self.generator([latent_vector, labels])
+ with tf.GradientTape() as gradient_tape:
+ pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
+ pred_real = self.discriminator([segmentation_map, real_image])[-1]
+ loss_fake = self.discriminator_loss(pred_fake, -1.0)
+ loss_real = self.discriminator_loss(pred_real, 1.0)
+ total_loss = 0.5 * (loss_fake + loss_real)
+
+ self.discriminator.trainable = True
+ gradients = gradient_tape.gradient(
+ total_loss, self.discriminator.trainable_variables
+ )
+ self.discriminator_optimizer.apply_gradients(
+ zip(gradients, self.discriminator.trainable_variables)
+ )
+ return total_loss
+
+ def train_generator(
+ self, latent_vector, segmentation_map, labels, image, mean, variance
+ ):
+ # Generator learns through the signal provided by the discriminator. During
+ # backpropagation, we only update the generator parameters.
+ self.discriminator.trainable = False
+ with tf.GradientTape() as tape:
+ real_d_output = self.discriminator([segmentation_map, image])
+ combined_outputs = self.combined_model(
+ [latent_vector, labels, segmentation_map]
+ )
+ fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
+ pred = fake_d_output[-1]
+
+ # Compute generator losses.
+ g_loss = generator_loss(pred)
+ kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
+ vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
+ feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
+ real_d_output, fake_d_output
+ )
+ total_loss = g_loss + kl_loss + vgg_loss + feature_loss
+
+ all_trainable_variables = (
+ self.combined_model.trainable_variables + self.encoder.trainable_variables
+ )
+
+ gradients = tape.gradient(total_loss, all_trainable_variables)
+ self.generator_optimizer.apply_gradients(
+ zip(gradients, all_trainable_variables)
+ )
+ return total_loss, feature_loss, vgg_loss, kl_loss
+
+ def train_step(self, data):
+ segmentation_map, image, labels = data
+ mean, variance = self.encoder(image)
+ latent_vector = self.sampler([mean, variance])
+ discriminator_loss = self.train_discriminator(
+ latent_vector, segmentation_map, image, labels
+ )
+ (generator_loss, feature_loss, vgg_loss, kl_loss) = self.train_generator(
+ latent_vector, segmentation_map, labels, image, mean, variance
+ )
+
+ # Report progress.
+ self.disc_loss_tracker.update_state(discriminator_loss)
+ self.gen_loss_tracker.update_state(generator_loss)
+ self.feat_loss_tracker.update_state(feature_loss)
+ self.vgg_loss_tracker.update_state(vgg_loss)
+ self.kl_loss_tracker.update_state(kl_loss)
+ results = {m.name: m.result() for m in self.metrics}
+ return results
+
+ def test_step(self, data):
+ segmentation_map, image, labels = data
+ # Obtain the learned moments of the real image distribution.
+ mean, variance = self.encoder(image)
+
+ # Sample a latent from the distribution defined by the learned moments.
+ latent_vector = self.sampler([mean, variance])
+
+ # Generate the fake images.
+ fake_images = self.generator([latent_vector, labels])
+
+ # Calculate the losses.
+ pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
+ pred_real = self.discriminator([segmentation_map, image])[-1]
+ loss_fake = self.discriminator_loss(pred_fake, -1.0)
+ loss_real = self.discriminator_loss(pred_real, 1.0)
+ total_discriminator_loss = 0.5 * (loss_fake + loss_real)
+ real_d_output = self.discriminator([segmentation_map, image])
+ combined_outputs = self.combined_model(
+ [latent_vector, labels, segmentation_map]
+ )
+ fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
+ pred = fake_d_output[-1]
+ g_loss = generator_loss(pred)
+ kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
+ vgg_loss = self.vgg_feature_loss_coeff * self.vgg_loss(image, fake_image)
+ feature_loss = self.feature_loss_coeff * self.feature_matching_loss(
+ real_d_output, fake_d_output
+ )
+ total_generator_loss = g_loss + kl_loss + vgg_loss + feature_loss
+
+ # Report progress.
+ self.disc_loss_tracker.update_state(total_discriminator_loss)
+ self.gen_loss_tracker.update_state(total_generator_loss)
+ self.feat_loss_tracker.update_state(feature_loss)
+ self.vgg_loss_tracker.update_state(vgg_loss)
+ self.kl_loss_tracker.update_state(kl_loss)
+ results = {m.name: m.result() for m in self.metrics}
+ return results
+
+ def call(self, inputs):
+ latent_vectors, labels = inputs
+ return self.generator([latent_vectors, labels])
+
+
+"""
+## GauGAN training
+"""
+
+gaugan = GauGAN(IMG_HEIGHT, NUM_CLASSES, BATCH_SIZE, latent_dim=256)
+gaugan.compile()
+history = gaugan.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=15,
+ callbacks=[GanMonitor(val_dataset, BATCH_SIZE)],
+)
+
+
+def plot_history(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_history("disc_loss")
+plot_history("gen_loss")
+plot_history("feat_loss")
+plot_history("vgg_loss")
+plot_history("kl_loss")
+
+"""
+## Inference
+"""
+
+val_iterator = iter(val_dataset)
+
+for _ in range(5):
+ val_images = next(val_iterator)
+ # Sample latent from a normal distribution.
+ latent_vector = keras.random.normal(
+ shape=(gaugan.batch_size, gaugan.latent_dim), mean=0.0, stddev=2.0
+ )
+ # Generate fake images.
+ fake_images = gaugan.predict([latent_vector, val_images[2]])
+
+ real_images = val_images
+ grid_row = min(fake_images.shape[0], 3)
+ grid_col = 3
+ f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col * 6, grid_row * 6))
+ for row in range(grid_row):
+ ax = axarr if grid_row == 1 else axarr[row]
+ ax[0].imshow((real_images[0][row] + 1) / 2)
+ ax[0].axis("off")
+ ax[0].set_title("Mask", fontsize=20)
+ ax[1].imshow((real_images[1][row] + 1) / 2)
+ ax[1].axis("off")
+ ax[1].set_title("Ground Truth", fontsize=20)
+ ax[2].imshow((fake_images[row] + 1) / 2)
+ ax[2].axis("off")
+ ax[2].set_title("Generated", fontsize=20)
+ plt.show()
+
+"""
+## Final words
+
+* The dataset we used in this example is a small one. For even better results,
+we recommend using a bigger dataset. GauGAN results were demonstrated with the
+[COCO-Stuff](https://github.com/nightrome/cocostuff) and
+[CityScapes](https://www.cityscapes-dataset.com/) datasets.
+* This example was inspired by Chapter 6 of
+[Hands-On Image Generation with TensorFlow](https://www.packtpub.com/product/hands-on-image-generation-with-tensorflow/9781838826789)
+by [Soon-Yau Cheong](https://www.linkedin.com/in/soonyau/) and
+[Implementing SPADE using fastai](https://towardsdatascience.com/implementing-spade-using-fastai-6ad86b94030a) by
+[Divyansh Jha](https://medium.com/@divyanshj.16).
+* If you found this example interesting and exciting, you might want to check out
+[our repository](https://github.com/soumik12345/tf2_gans), which we are
+currently building. It will include reimplementations of popular GANs and pretrained
+models. Our focus will be on readability and making the code as accessible as possible.
+Our plan is to first train our implementation of GauGAN (following the code of
+this example) on a bigger dataset and then make the repository public. We welcome
+contributions!
+* Recently GauGAN2 was also released. You can check it out
+[here](https://blogs.nvidia.com/blog/2021/11/22/gaugan2-ai-art-demo/).
+
+"""
+"""
+Example available on HuggingFace.
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/GauGAN-Image-generation) | [](https://huggingface.co/spaces/keras-io/GauGAN_Conditional_Image_Generation) |
+"""
diff --git a/knowledge_base/generative/gpt2_text_generation_with_keras_hub.py b/knowledge_base/generative/gpt2_text_generation_with_keras_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..8076053470ff4c07a2b25a915aa568ec4f68b2e2
--- /dev/null
+++ b/knowledge_base/generative/gpt2_text_generation_with_keras_hub.py
@@ -0,0 +1,377 @@
+"""
+Title: GPT2 Text Generation with KerasHub
+Author: Chen Qian
+Date created: 2023/04/17
+Last modified: 2024/04/12
+Description: Use KerasHub GPT2 model and `samplers` to do text generation.
+Accelerator: GPU
+"""
+
+"""
+In this tutorial, you will learn to use [KerasHub](https://keras.io/keras_hub/) to load a
+pre-trained Large Language Model (LLM) - [GPT-2 model](https://openai.com/research/better-language-models)
+(originally invented by OpenAI), finetune it to a specific text style, and
+generate text based on users' input (also known as a prompt). You will also learn
+how GPT2 adapts quickly to non-English languages, such as Chinese.
+"""
+
+"""
+## Before we begin
+
+Colab offers different kinds of runtimes. Make sure to go to **Runtime ->
+Change runtime type** and choose the GPU Hardware Accelerator runtime
+(which should have >12G host RAM and ~15G GPU RAM) since you will finetune the
+GPT-2 model. Running this tutorial on CPU runtime will take hours.
+"""
+
+"""
+## Install KerasHub, Choose Backend and Import Dependencies
+
+This example uses [Keras 3](https://keras.io/keras_3/) to work in any of
+`"tensorflow"`, `"jax"` or `"torch"`. Support for Keras 3 is baked into
+KerasHub; simply change the `"KERAS_BACKEND"` environment variable to select
+the backend of your choice. We select the JAX backend below.
+"""
+
+"""shell
+pip install git+https://github.com/keras-team/keras-hub.git -q
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch"
+
+import keras_hub
+import keras
+import tensorflow as tf
+import time
+
+keras.mixed_precision.set_global_policy("mixed_float16")
+
+"""
+## Introduction to Generative Large Language Models (LLMs)
+
+Large language models (LLMs) are machine learning models trained on a large
+corpus of text data to generate outputs for various natural language processing
+(NLP) tasks, such as text generation, question answering, and machine
+translation.
+
+Generative LLMs are typically based on deep learning neural networks, such as
+the [Transformer architecture](https://arxiv.org/abs/1706.03762) invented by
+Google researchers in 2017, and are trained on massive amounts of text data,
+often involving billions of words. These models, such as Google [LaMDA](https://blog.google/technology/ai/lamda/)
+and [PaLM](https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html),
+are trained with a large dataset from various data sources, which allows them to
+generate output for many tasks. The core of generative LLMs is predicting the
+next word in a sentence, often referred to as **Causal LM Pretraining**. In this
+way, LLMs can generate coherent text based on user prompts. For a more
+pedagogical discussion on language models, you can refer to the
+[Stanford CS324 LLM class](https://stanford-cs324.github.io/winter2022/lectures/introduction/).
+"""
+
+"""
+## Introduction to KerasHub
+
+Large Language Models are complex to build and expensive to train from scratch.
+Luckily there are pretrained LLMs available for use right away. [KerasHub](https://keras.io/keras_hub/)
+provides a large number of pre-trained checkpoints that allow you to experiment
+with SOTA models without needing to train them yourself.
+
+KerasHub is a natural language processing library that supports users through
+their entire development cycle. KerasHub offers both pretrained models and
+modularized building blocks, so developers can easily reuse pretrained models
+or stack their own LLMs.
+
+In a nutshell, for generative LLM, KerasHub offers:
+
+- Pretrained models with `generate()` method, e.g.,
+ `keras_hub.models.GPT2CausalLM` and `keras_hub.models.OPTCausalLM`.
+- Sampler class that implements generation algorithms such as Top-K, Beam and
+ contrastive search. These samplers can be used to generate text with
+ custom models.
+"""
+
+"""
+## Load a pre-trained GPT-2 model and generate some text
+
+KerasHub provides a number of pre-trained models, such as [Google
+Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html)
+and [GPT-2](https://openai.com/research/better-language-models). You can see
+the list of models available in the [KerasHub repository](https://github.com/keras-team/keras-hub/tree/master/keras_hub/models).
+
+It's very easy to load the GPT-2 model as you can see below:
+"""
+
+# To speed up training and generation, we use preprocessor of length 128
+# instead of full length 1024.
+preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
+ "gpt2_base_en",
+ sequence_length=128,
+)
+gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
+ "gpt2_base_en", preprocessor=preprocessor
+)
+
+"""
+Once the model is loaded, you can use it to generate some text right away. Run
+the cells below to give it a try. It's as simple as calling a single function
+*generate()*:
+"""
+
+start = time.time()
+
+output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
+print("\nGPT-2 output:")
+print(output)
+
+end = time.time()
+print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
+
+"""
+Try another one:
+"""
+
+start = time.time()
+
+output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
+print("\nGPT-2 output:")
+print(output)
+
+end = time.time()
+print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
+
+"""
+Notice how much faster the second call is. This is because the computational
+graph is [XLA compiled](https://www.tensorflow.org/xla) in the 1st run and
+re-used in the 2nd behind the scenes.
+
+The quality of the generated text looks OK, but we can improve it via
+fine-tuning.
+"""
+
+"""
+## More on the GPT-2 model from KerasHub
+
+Next up, we will actually fine-tune the model to update its parameters, but
+before we do, let's take a look at the full set of tools we have for working
+with GPT2.
+
+The code of GPT2 can be found
+[here](https://github.com/keras-team/keras-hub/blob/master/keras_hub/models/gpt2/).
+Conceptually the `GPT2CausalLM` can be hierarchically broken down into several
+modules in KerasHub, all of which have a *from_preset()* function that loads a
+pretrained model:
+
+- `keras_hub.models.GPT2Tokenizer`: The tokenizer used by GPT2 model, which is a
+ [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt).
+- `keras_hub.models.GPT2CausalLMPreprocessor`: the preprocessor used for GPT2
+  causal LM training. It does the tokenization along with other preprocessing
+  work, such as creating the labels and appending the end token.
+- `keras_hub.models.GPT2Backbone`: the GPT2 model, which is a stack of
+  `keras_hub.layers.TransformerDecoder`. This is usually just referred to as
+ `GPT2`.
+- `keras_hub.models.GPT2CausalLM`: wraps `GPT2Backbone`; it multiplies the
+  output of `GPT2Backbone` by the embedding matrix to generate logits over
+  vocabulary tokens.
+"""
+
+"""
+## Finetune on Reddit dataset
+
+Now that you know the GPT-2 model from KerasHub, you can take one step
+further and finetune the model so that it generates text in a specific
+style, short or long, strict or casual. In this tutorial, we will use the
+Reddit dataset as an example.
+"""
+
+import tensorflow_datasets as tfds
+
+reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
+
+"""
+Let's take a look at a sample from the Reddit TensorFlow Dataset. There
+are two features:
+
+- **document**: the text of the post.
+- **title**: the title of the post.
+
+"""
+
+for document, title in reddit_ds:
+ print(document.numpy())
+ print(title.numpy())
+ break
+
+"""
+In our case, we are performing next word prediction in a language model, so we
+only need the 'document' feature.
+"""
+
+train_ds = (
+ reddit_ds.map(lambda document, _: document)
+ .batch(32)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+"""
+Now you can finetune the model using the familiar *fit()* function. Note that
+the `preprocessor` will be automatically called inside the `fit` method since
+`GPT2CausalLM` is a `keras_hub.models.Task` instance.
+
+Training the model all the way to a fully trained state would take quite a bit
+of GPU memory and a long time. Here we just use part of the dataset for demo
+purposes.
+"""
+
+train_ds = train_ds.take(500)
+num_epochs = 1
+
+# Linearly decaying learning rate.
+learning_rate = keras.optimizers.schedules.PolynomialDecay(
+ 5e-5,
+ decay_steps=train_ds.cardinality() * num_epochs,
+ end_learning_rate=0.0,
+)
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+gpt2_lm.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+gpt2_lm.fit(train_ds, epochs=num_epochs)
+
+"""
+After fine-tuning is finished, you can again generate text using the same
+*generate()* function. This time, the text will be closer to the Reddit writing
+style, and the generated length will be close to the sequence length we preset
+in the preprocessor.
+"""
+
+start = time.time()
+
+output = gpt2_lm.generate("I like basketball", max_length=200)
+print("\nGPT-2 output:")
+print(output)
+
+end = time.time()
+print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
+
+"""
+## Into the Sampling Method
+
+In KerasHub, we offer a few sampling methods, e.g., contrastive search,
+Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-K sampling, but
+you can choose your own sampling method.
+
+Much like optimizers and activations, there are two ways to specify your custom
+sampler:
+
+- Use a string identifier, such as "greedy"; this uses the default
+configuration for that sampler.
+- Pass a `keras_hub.samplers.Sampler` instance; this lets you use a custom
+configuration.
+"""
+
+# Use a string identifier.
+gpt2_lm.compile(sampler="top_k")
+output = gpt2_lm.generate("I like basketball", max_length=200)
+print("\nGPT-2 output:")
+print(output)
+
+# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
+greedy_sampler = keras_hub.samplers.GreedySampler()
+gpt2_lm.compile(sampler=greedy_sampler)
+
+output = gpt2_lm.generate("I like basketball", max_length=200)
+print("\nGPT-2 output:")
+print(output)
+
+"""
+For more details on the KerasHub `Sampler` classes, you can check the code
+[here](https://github.com/keras-team/keras-hub/tree/master/keras_hub/samplers). A
+minimal custom sampler sketch follows.
+"""
+
+"""
+## Finetune on Chinese Poem Dataset
+
+We can also finetune GPT2 on non-English datasets. For readers who know Chinese,
+this part illustrates how to fine-tune GPT2 on a Chinese poem dataset to teach our
+model to become a poet!
+
+Because GPT2 uses a byte-pair encoder, and the original pretraining dataset
+contains some Chinese characters, we can use the original vocabulary to finetune on
+a Chinese dataset.
+"""
+
+"""shell
+# Load chinese poetry dataset.
+git clone https://github.com/chinese-poetry/chinese-poetry.git
+"""
+
+"""
+Load text from the JSON file. We only use《全唐诗》(Complete Tang Poems) for demo purposes.
+"""
+
+import os
+import json
+
+poem_collection = []
+for file in os.listdir("chinese-poetry/全唐诗"):
+ if ".json" not in file or "poet" not in file:
+ continue
+    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
+ with open(full_filename, "r") as f:
+ content = json.load(f)
+ poem_collection.extend(content)
+
+paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]
+
+"""
+Let's take a look at sample data.
+"""
+
+print(paragraphs[0])
+
+"""
+As in the Reddit example, we convert the data to a TF dataset, and only use
+part of it for training.
+"""
+
+train_ds = (
+ tf.data.Dataset.from_tensor_slices(paragraphs)
+ .batch(16)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+# Running through the whole dataset takes a long time, so we only take 500
+# batches and run 1 epoch for demo purposes.
+train_ds = train_ds.take(500)
+num_epochs = 1
+
+learning_rate = keras.optimizers.schedules.PolynomialDecay(
+ 5e-4,
+ decay_steps=train_ds.cardinality() * num_epochs,
+ end_learning_rate=0.0,
+)
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+gpt2_lm.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+gpt2_lm.fit(train_ds, epochs=num_epochs)
+
+"""
+Let's check the result!
+"""
+
+output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
+print(output)
+
+"""
+Not bad 😀
+"""
diff --git a/knowledge_base/generative/lstm_character_level_text_generation.py b/knowledge_base/generative/lstm_character_level_text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b6401350f12f726a4cd9f6c8cd3e1379a5c093
--- /dev/null
+++ b/knowledge_base/generative/lstm_character_level_text_generation.py
@@ -0,0 +1,132 @@
+"""
+Title: Character-level text generation with LSTM
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2015/06/15
+Last modified: 2020/04/30
+Description: Generate text from Nietzsche's writings with a character-level LSTM.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to use an LSTM model to generate
+text character-by-character.
+
+At least 20 epochs are required before the generated text
+starts sounding locally coherent.
+
+It is recommended to run this script on GPU, as recurrent
+networks are quite computationally intensive.
+
+If you try this script on new data, make sure your corpus
+has at least ~100k characters. ~1M is better.
+"""
+
+"""
+## Setup
+"""
+import keras
+from keras import layers
+
+import numpy as np
+import random
+import io
+
+"""
+## Prepare the data
+"""
+
+path = keras.utils.get_file(
+ "nietzsche.txt",
+ origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt",
+)
+with io.open(path, encoding="utf-8") as f:
+ text = f.read().lower()
+text = text.replace("\n", " ")  # We remove newline chars for nicer display
+print("Corpus length:", len(text))
+
+chars = sorted(list(set(text)))
+print("Total chars:", len(chars))
+char_indices = dict((c, i) for i, c in enumerate(chars))
+indices_char = dict((i, c) for i, c in enumerate(chars))
+
+# Cut the text into semi-redundant sequences of maxlen characters
+maxlen = 40
+step = 3
+sentences = []
+next_chars = []
+for i in range(0, len(text) - maxlen, step):
+ sentences.append(text[i : i + maxlen])
+ next_chars.append(text[i + maxlen])
+print("Number of sequences:", len(sentences))
+
+x = np.zeros((len(sentences), maxlen, len(chars)), dtype="bool")
+y = np.zeros((len(sentences), len(chars)), dtype="bool")
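+# One-hot encode: x[i] is the input window of maxlen characters, and y[i] is
+# the character that immediately follows it.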
+for i, sentence in enumerate(sentences):
+ for t, char in enumerate(sentence):
+ x[i, t, char_indices[char]] = 1
+ y[i, char_indices[next_chars[i]]] = 1
+
+
+"""
+## Build the model: a single LSTM layer
+"""
+
+model = keras.Sequential(
+ [
+ keras.Input(shape=(maxlen, len(chars))),
+ layers.LSTM(128),
+ layers.Dense(len(chars), activation="softmax"),
+ ]
+)
+optimizer = keras.optimizers.RMSprop(learning_rate=0.01)
+model.compile(loss="categorical_crossentropy", optimizer=optimizer)
+
+"""
+## Prepare the text sampling function
+"""
+
+
+def sample(preds, temperature=1.0):
+ # helper function to sample an index from a probability array
+ preds = np.asarray(preds).astype("float64")
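+    # Rescale the log-probabilities by 1/temperature and re-normalize (a softmax):
+    # temperature < 1.0 sharpens the distribution, > 1.0 flattens it.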
+ preds = np.log(preds) / temperature
+ exp_preds = np.exp(preds)
+ preds = exp_preds / np.sum(exp_preds)
+ probas = np.random.multinomial(1, preds, 1)
+ return np.argmax(probas)
+
+
+"""
+## Train the model
+"""
+
+epochs = 40
+batch_size = 128
+
+for epoch in range(epochs):
+ model.fit(x, y, batch_size=batch_size, epochs=1)
+ print()
+ print("Generating text after epoch: %d" % epoch)
+
+ start_index = random.randint(0, len(text) - maxlen - 1)
+ for diversity in [0.2, 0.5, 1.0, 1.2]:
+ print("...Diversity:", diversity)
+
+ generated = ""
+ sentence = text[start_index : start_index + maxlen]
+ print('...Generating with seed: "' + sentence + '"')
+
+ for i in range(400):
+ x_pred = np.zeros((1, maxlen, len(chars)))
+ for t, char in enumerate(sentence):
+ x_pred[0, t, char_indices[char]] = 1.0
+ preds = model.predict(x_pred, verbose=0)[0]
+ next_index = sample(preds, diversity)
+ next_char = indices_char[next_index]
+ sentence = sentence[1:] + next_char
+ generated += next_char
+
+ print("...Generated: ", generated)
+ print("-")
diff --git a/knowledge_base/generative/midi_generation_with_transformer.py b/knowledge_base/generative/midi_generation_with_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1527dc86da693ca1d7b4bc6ebd6521d477cf4e3c
--- /dev/null
+++ b/knowledge_base/generative/midi_generation_with_transformer.py
@@ -0,0 +1,722 @@
+"""
+Title: Music Generation with Transformer Models
+Author: [Joaquin Jimenez](https://github.com/johacks/)
+Date created: 2024/11/22
+Last modified: 2024/11/26
+Description: Use a Transformer model to train on MIDI data and generate music sequences.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this tutorial, we learn how to build a music generation model using a
+Transformer decoder-only architecture.
+The model is trained on the [Maestro dataset](https://magenta.tensorflow.org/datasets/maestro)
+and implemented using Keras 3.
+In the process, we explore MIDI tokenization and relative global attention mechanisms.
+
+This example is based on the paper "Music Transformer" by Huang et al. (2018).
+Check out the original [paper](https://arxiv.org/abs/1809.04281) and
+[code](https://github.com/jason9693/MusicTransformer-tensorflow2.0).
+"""
+
+"""
+## Setup
+
+Before we start, let's import and install all the libraries we need.
+"""
+
+"""shell
+pip install -qq midi_neural_processor
+pip install -qq keras_hub
+pip install -qq "keras>=3.6.0" # Allows use of keras.utils.Config.
+"""
+
+"""
+### Optional dependencies
+
+To hear the audio, install the following additional dependencies:
+"""
+
+"""shell
+sudo apt-get -qq install -y fluidsynth 2> /dev/null
+pip install -qq pyfluidsynth scipy
+"""
+
+import os
+import random
+import tempfile
+
+import keras
+import midi_neural_processor.processor as midi_tokenizer
+import numpy as np
+from keras import callbacks, layers, ops, optimizers, utils
+from keras_hub import layers as hub_layers
+from os import path
+
+"""
+## Configuration
+
+Let's define the configuration for the model and the dataset to be used in this example.
+"""
+event_range = midi_tokenizer.RANGE_NOTE_ON
+event_range += midi_tokenizer.RANGE_NOTE_OFF
+event_range += midi_tokenizer.RANGE_TIME_SHIFT
+event_range += midi_tokenizer.RANGE_VEL
+CONFIG = utils.Config(
+ max_sequence_len=2048,
+ embedding_dim=256,
+ num_transformer_blocks=6,
+ batch_size=6,
+ token_pad=event_range,
+ token_start_of_sentence=event_range + 1,
+ token_end_of_sentence=event_range + 2,
+ vocabulary_size=event_range + 3,
+ model_out="tmp/music_transformer.keras",
+ seed=42,
+)
+utils.set_random_seed(CONFIG.seed)
+
+
+"""
+## Maestro dataset
+
+The Maestro dataset contains MIDI files for piano performances.
+
+### Download the dataset
+
+We now download and extract the dataset, then move the MIDI files to a new directory.
+"""
+
+
+def download_maestro(output_dir=None):
+ """Download the Maestro MIDI dataset.
+ Extracted from: https://magenta.tensorflow.org/datasets/maestro
+ """
+ # Ensure the output directory exists
+ output_dir = tempfile.mkdtemp() if output_dir is None else output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Download and extract zip file
+ dir = utils.get_file(
+ origin="https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip",
+ extract=True,
+ )
+
+ # Gather all MIDI files
+ midi_files, file_paths = set(), list()
+ for root, _, files in os.walk(dir):
+ for file in files:
+ if file.lower().endswith(".midi") or file.lower().endswith(".mid"):
+ midi_files.add(path.join(root, file))
+
+ # Move the files to the output directory
+ for file in sorted(midi_files):
+ file_paths.append(new_path := path.join(output_dir, path.basename(file)))
+ os.rename(file, new_path)
+ return file_paths
+
+
+paths = list(sorted(download_maestro(output_dir="datasets/maestro")))
+output_dir = path.dirname(paths[0])
+
+
+"""
+### Split the dataset
+
+We can now split the dataset into training and validation sets.
+"""
+
+indices = np.random.permutation(len(paths))
+split = int(len(paths) * 0.1)
+train_paths = [paths[i] for i in indices[split:]]
+val_paths = [paths[i] for i in indices[:split]]
+
+"""
+### Hear a MIDI file
+
+We use the pretty_midi library and fluidsynth to convert MIDI files into waveform audio.
+This allows us to listen to the data samples before and after processing.
+
+The following dependencies are required to play the audio:
+- fluidsynth: `sudo apt install -y fluidsynth`
+- pyfluidsynth, scipy: `pip install pyfluidsynth scipy`
+"""
+
+
+def visualize_midi(midi_path, sampling_rate=16000, seconds=15, out_dir=None):
+ import pretty_midi
+ from scipy.io.wavfile import write as write_wav
+ from IPython.display import Audio
+
+ # Create the audio waveform
+ pretty_midi_file = pretty_midi.PrettyMIDI(midi_path)
+ waveform = pretty_midi_file.fluidsynth(fs=sampling_rate)[: seconds * sampling_rate]
+
+ # Display the audio if no path is provided
+ if out_dir is None:
+ # IPython display
+ return Audio(waveform, rate=sampling_rate)
+
+ # Save the audio to a file
+ os.makedirs(out_dir, exist_ok=True)
+ audio_path = path.join(out_dir, path.basename(midi_path).split(".")[0] + ".wav")
+ write_wav(audio_path, sampling_rate, (waveform * 32767).astype(np.int16))
+ return audio_path
+
+
+print(visualize_midi(train_paths[0], out_dir="tmp/")) # Saved audio path
+visualize_midi(train_paths[0]) # Display the audio if in a Jupyter notebook
+
+
+"""
+### Tokenize the data
+
+We now preprocess the MIDI files into a tokenized format for training.
+"""
+
+
+def encode_midi_task(midi_path):
+ """Define a task that tokenizes a MIDI file."""
+ import midi_neural_processor.processor as midi_tokenizer
+
+ return midi_tokenizer.encode_midi(midi_path)
+
+
+def preprocess_midi_files(file_paths, save_dir=None):
+ """Preprocess a list of MIDI files and save the notes to a file."""
+ from multiprocessing import Pool, cpu_count
+
+ # Assume all files are in the same directory and save to the same directory
+ save_dir = path.dirname(file_paths[0]) if save_dir is None else save_dir
+ os.makedirs(save_dir, exist_ok=True)
+
+ # Check if the notes have already been preprocessed
+ output_file = path.join(save_dir, "notes.npz")
+ if path.exists(output_file):
+ npz_file = np.load(output_file)
+ return [npz_file[key] for key in npz_file.keys()]
+
+ # Preprocess the MIDI files in parallel
+ progbar = utils.Progbar(len(file_paths), unit_name="MIDI_file", interval=5)
+ pool = Pool(cpu_count() - 1)
+ all_notes = []
+ for notes in pool.imap_unordered(encode_midi_task, file_paths):
+ progbar.add(1)
+ all_notes.append(np.array(notes))
+
+ # Save the notes to a file
+ np.savez(output_file, *all_notes)
+ return all_notes
+
+
+train_midis = preprocess_midi_files(train_paths, path.join(output_dir, "train"))
+val_midis = preprocess_midi_files(val_paths, path.join(output_dir, "val"))
+
+
+"""
+### Dataset objects
+
+We now define a dataset class that yields batches of input sequences and target sequences.
+"""
+
+
+class MidiDataset(utils.PyDataset):
+ """A dataset for MIDI files that yields batches of input sequences and target sequences."""
+
+ def __init__(
+ self,
+ encoded_midis,
+ batch_size=CONFIG.batch_size,
+ max_sequence_len=CONFIG.max_sequence_len,
+ ):
+ super(MidiDataset, self).__init__()
+ self.batch_size = batch_size
+ self.max_sequence_len = max_sequence_len
+ self.encoded_midis = encoded_midis
+ batches, last_batch_size = divmod(len(encoded_midis), batch_size)
+ self._num_batches = batches + int(last_batch_size > 0)
+
+ def __len__(self):
+ """Get the number of batches."""
+ return self._num_batches
+
+ def __getitem__(self, idx):
+ """Generate random inputs and corresponding targets for the model."""
+ # Same as in the original paper, we always get a random batch.
+ # See: https://github.com/jason9693/MusicTransformer-tensorflow2.0/blob/f7c06c0cb2e9cdddcbf6db779cb39cd650282778/data.py
+ batch = random.sample(self.encoded_midis, k=self.batch_size)
+
+ # Convert the batch to sequences
+ batch_data = [
+ self._get_sequence(midi, self.max_sequence_len + 1) for midi in batch
+ ]
+ batch_data = np.array(batch_data)
+
+ # Split the data into input and target sequences
+ return batch_data[:, :-1], batch_data[:, 1:]
+
+ def _get_sequence(self, data, max_length):
+ """Get a random sequence of notes from a file."""
+ # Truncate or pad the sequence
+ if len(data) > max_length:
+ start = random.randrange(0, len(data) - max_length)
+ data = data[start : start + max_length]
+ elif len(data) < max_length:
+ data = np.append(data, CONFIG.token_end_of_sentence)
+
+ # Pad the sequence if necessary
+ if len(data) < max_length:
+ data = np.concatenate(
+ (data, np.full(max_length - len(data), CONFIG.token_pad))
+ )
+ return np.asanyarray(data, dtype="int32")
+
+
+train_dataset, val_dataset = MidiDataset(train_midis), MidiDataset(val_midis)
+
+
+"""
+## Model definition
+
+It is time to define the model architecture. We use a Transformer decoder
+architecture with a custom attention mechanism, relative global attention.
+
+### Relative Global Attention
+
+The following code implements the Relative Global Attention layer. It is used
+in place of the standard multi-head attention layer in the Transformer decoder.
+The main difference is that it includes a relative positional encoding that
+allows the model to learn relative positional information between tokens.
+"""
+
+
+@keras.utils.register_keras_serializable()
+class RelativeGlobalAttention(layers.Layer):
+ """
+ From Music Transformer (Huang et al., 2018)
+ https://arxiv.org/abs/1809.04281
+ """
+
+ def __init__(self, num_heads, embedding_dim, max_sequence_len, **kwargs):
+ super().__init__(**kwargs)
+ self.key_length = None
+ self.max_sequence_len = max_sequence_len
+ self.relative_embedding = None
+ self.num_heads = num_heads
+ self.embedding_dim = embedding_dim
+ self.head_dim = embedding_dim // num_heads
+ self.query_dense = layers.Dense(int(self.embedding_dim))
+ self.key_dense = layers.Dense(int(self.embedding_dim))
+ self.value_dense = layers.Dense(int(self.embedding_dim))
+ self.output_dense = layers.Dense(embedding_dim, name="output")
+
+ def build(self, input_shape):
+ self.query_length = input_shape[0][1]
+ self.key_length = input_shape[1][1]
+ self.relative_embedding = self.add_weight(
+ (self.max_sequence_len, int(self.head_dim)), name="relative_embedding"
+ )
+
+ def _apply_dense_layer_and_split_heads(self, inputs, dense_layer):
+ # Apply linear transformation
+ inputs = dense_layer(inputs)
+ new_shape = ops.shape(inputs)
+ # Reshape to split by attention heads
+ reshaped = ops.reshape(inputs, (new_shape[0], new_shape[1], self.num_heads, -1))
+ # Transpose for head-first format
+ return ops.transpose(reshaped, (0, 2, 1, 3))
+
+ def call(self, inputs, mask=None):
+ # Compute Q, K, V: Batch, head, sequence, features
+ query = self._apply_dense_layer_and_split_heads(inputs[0], self.query_dense)
+ key = self._apply_dense_layer_and_split_heads(inputs[1], self.key_dense)
+ value = self._apply_dense_layer_and_split_heads(inputs[2], self.value_dense)
+
+ # Compute scaled dot-product attention scores
+ attention_scores = ops.matmul(query, ops.transpose(key, [0, 1, 3, 2]))
+
+ # Compute relative positional encoding and combine with attention scores
+ start_idx = max(0, self.max_sequence_len - ops.shape(query)[2])
+ relative_embedding = self.relative_embedding[start_idx:, :]
+ attention_scores += self._compute_attention_scores(query, relative_embedding)
+ logits = attention_scores / ops.sqrt(self.head_dim)
+
+ # Apply mask if provided
+ if mask is not None:
+ logits += ops.cast(mask, "float32") * -1e9
+
+ # Compute attention weights
+ attention_weights = ops.nn.softmax(logits, axis=-1)
+ attention_output = ops.matmul(attention_weights, value)
+
+ # Merge heads and apply final linear transformation
+ merged_attention = ops.transpose(attention_output, (0, 2, 1, 3))
+ merged_attention = ops.reshape(
+ merged_attention, (ops.shape(merged_attention)[0], -1, self.embedding_dim)
+ )
+ output = self.output_dense(merged_attention)
+
+ return output, attention_weights
+
+ def _compute_attention_scores(self, query, relative_embedding):
+ """
+ Compute relative attention scores using positional encodings.
+ """
+ relative_scores = ops.einsum("bhld, md->bhlm", query, relative_embedding)
+ relative_scores = self._apply_mask_to_relative_scores(relative_scores)
+ return self._skew_attention_scores(relative_scores)
+
+ def _apply_mask_to_relative_scores(self, scores):
+ """
+ Apply masking to relative positional scores to ignore future positions.
+ """
+ mask = ops.flip(
+ ops.tri(scores.shape[-2], scores.shape[-1], dtype="float32"), axis=1
+ )
+ return mask * scores
+
+ def _skew_attention_scores(self, scores):
+ """
+ Perform skewing operation to align relative attention scores with the sequence.
+ """
+ padded_scores = ops.pad(scores, ((0, 0), (0, 0), (0, 0), (1, 0)))
+ padded_shape = ops.shape(padded_scores)
+ reshaped_scores = ops.reshape(
+ padded_scores, (-1, padded_shape[1], padded_shape[-1], padded_shape[-2])
+ )
+ skewed_scores = reshaped_scores[:, :, 1:, :]
+
+ if self.key_length > self.query_length:
+ size_diff = self.key_length - self.query_length
+ return ops.pad(skewed_scores, [[0, 0], [0, 0], [0, 0], [0, size_diff]])
+ else:
+ return skewed_scores[:, :, :, : self.key_length]
+
+
+"""
+### Decoder Layer
+
+Using the RelativeGlobalAttention layer, we can define the DecoderLayer. It is mostly like
+the standard Transformer decoder layer but with the custom attention mechanism.
+"""
+
+
+@keras.utils.register_keras_serializable()
+class DecoderLayer(layers.Layer):
+ def __init__(self, embedding_dim, num_heads, max_sequence_len, dropout=0.1):
+ super(DecoderLayer, self).__init__()
+
+ # Initialize attributes
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.max_sequence_len = max_sequence_len
+
+ # Initialize layers
+ self.relative_global_attention_1 = RelativeGlobalAttention(
+ num_heads, embedding_dim, max_sequence_len
+ )
+
+ self.feed_forward_network_pre = layers.Dense(self.embedding_dim // 2, "relu")
+ self.feed_forward_network_pos = layers.Dense(self.embedding_dim)
+
+ self.layer_normalization_1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layer_normalization_2 = layers.LayerNormalization(epsilon=1e-6)
+
+ self.dropout_1 = layers.Dropout(dropout)
+ self.dropout_2 = layers.Dropout(dropout)
+
+ def call(self, inputs, mask=None, training=False):
+ # Attention block. Inputs are (query, key, value)
+ attention_out, attention_weights = self.relative_global_attention_1(
+ (inputs, inputs, inputs), mask=mask
+ )
+ attention_out = self.dropout_1(attention_out, training=training)
+ attention_out_normalized = self.layer_normalization_1(attention_out + inputs)
+
+ ffn_out = self.feed_forward_network_pre(attention_out)
+ ffn_out = self.feed_forward_network_pos(ffn_out)
+ ffn_out = self.dropout_2(ffn_out, training=training)
+ out = self.layer_normalization_2(attention_out_normalized + ffn_out)
+
+ return out, attention_weights
+
+
+"""
+### Decoder
+
+The Decoder layer is composed of multiple DecoderLayer blocks. It also includes
+an embedding layer that converts our tokenized input into an embedding representation.
+"""
+
+
+@keras.utils.register_keras_serializable()
+class Decoder(layers.Layer):
+ def __init__(
+ self, embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
+ ):
+ super(Decoder, self).__init__()
+
+ self.embedding_dim = embedding_dim
+ self.num_blocks = num_blocks
+
+ self.embedding = layers.Embedding(vocabulary_size, self.embedding_dim)
+ self.positional_encoding = hub_layers.SinePositionEncoding()
+
+ self.decode_layers = [
+ DecoderLayer(
+ embedding_dim, embedding_dim // 64, max_sequence_len, dropout=dropout
+ )
+ for _ in range(num_blocks)
+ ]
+ self.dropout = layers.Dropout(dropout)
+
+ def call(self, inputs, mask=None, training=False, return_attention_weights=False):
+ weights = []
+
+ # Adding embedding and position encoding.
+ x = self.embedding(inputs)
+ x = x * ops.sqrt(ops.cast(self.embedding_dim, "float32"))
+ x = x + self.positional_encoding(x)
+ x = self.dropout(x, training=training)
+
+ # Passing through the transformer blocks.
+ for i in range(self.num_blocks):
+ x, w = self.decode_layers[i](x, mask=mask, training=training)
+ weights.append(w)
+ if return_attention_weights:
+ return x, weights
+ return x
+
+
+"""
+### Music Transformer Decoder
+
+With the above layers defined, we can now define the MusicTransformerDecoder model. It applies
+a linear transformation to the output of the decoder to get the logits for each token.
+"""
+
+
+@keras.utils.register_keras_serializable()
+class MusicTransformerDecoder(keras.Model):
+ def __init__(
+ self,
+ embedding_dim=CONFIG.embedding_dim,
+ vocabulary_size=CONFIG.vocabulary_size,
+ num_blocks=CONFIG.num_transformer_blocks,
+ max_sequence_len=CONFIG.max_sequence_len,
+ dropout=0.2,
+ ):
+ # Initialize attributes
+ super(MusicTransformerDecoder, self).__init__()
+ self.embedding_dim = embedding_dim
+ self.vocabulary_size = vocabulary_size
+ self.num_blocks = num_blocks
+ self.max_sequence_len = max_sequence_len
+
+ # Initialize layers
+ # Transformer decoder
+ self.decoder = Decoder(
+ embedding_dim, vocabulary_size, max_sequence_len, num_blocks, dropout
+ )
+ # Output layer
+ self.fc = layers.Dense(self.vocabulary_size, activation=None, name="output")
+
+ @staticmethod
+ def get_look_ahead_mask(max_sequence_len, inputs):
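+        # Combine a causal (look-ahead) mask with a padding mask, so attention
+        # neither peeks at future tokens nor attends to padding positions.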
+ sequence_length = min(max_sequence_len, inputs.shape[1])
+ sequence_mask = ops.logical_not(
+ ops.tri(sequence_length, sequence_length, dtype="bool")
+ )
+
+ inputs = ops.cast(inputs[:, None, None, :], "int32")
+ output_pad_tensor = ops.ones_like(inputs) * CONFIG.token_pad
+ decoder_output_mask = ops.equal(inputs, output_pad_tensor)
+ return ops.cast(ops.logical_or(decoder_output_mask, sequence_mask), "int32")
+
+ def call(self, inputs, training=False):
+ mask = self.get_look_ahead_mask(self.max_sequence_len, inputs)
+ decoding = self.decoder(
+ inputs, mask=mask, training=training, return_attention_weights=False
+ )
+ return self.fc(decoding)
+
+ # --- Sequence generation methods
+
+ def generate(self, inputs: list, length=CONFIG.max_sequence_len, top_k=5):
+ inputs = ops.convert_to_tensor([inputs])
+
+ # Generate a new token using output distribution at given index
+ def generate_token(inputs, end_idx):
+ distribution = ops.stop_gradient(self.call(inputs)[0, end_idx])
+
+ # Select the top-k tokens and their probabilities
+ top_k_distribution, top_k_indices = ops.top_k(distribution, k=top_k)
+
+ # Sample from the top-k probabilities
+ new_token_idx = keras.random.categorical(top_k_distribution[None, :], 1)
+ return ops.take(top_k_indices, new_token_idx[0])
+
+ # Compute the number of tokens to add
+ added_tokens = min(length, self.max_sequence_len - inputs.shape[1])
+ progbar = utils.Progbar(added_tokens, unit_name="token", interval=5)
+
+ # Pad the input sequence that will be filled with generated tokens
+ out = ops.pad(inputs, ((0, 0), (0, added_tokens)), "constant", CONFIG.token_pad)
+
+ # Generate tokens using top-k sampling
+ for token_idx in range(inputs.shape[1] - 1, inputs.shape[1] - 1 + added_tokens):
+ token = ops.cast(generate_token(out, end_idx=token_idx), out.dtype)
+ out = ops.scatter_update(out, ((0, token_idx + 1),), token)
+ progbar.add(1)
+
+ return ops.convert_to_numpy(out[0])
+
+ # --- Serialization methods
+
+ def get_config(self):
+ atts = ["embedding_dim", "vocabulary_size", "num_blocks", "max_sequence_len"]
+ return {a: getattr(self, a) for a in atts}
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(**config)
+
+
+"""
+### Loss function
+
+We define a custom loss function that computes the categorical cross-entropy
+loss for the model. It is computed only for non-padding tokens and uses
+`from_logits=True` since the model outputs logits.
+"""
+
+
+@keras.utils.register_keras_serializable()
+def train_loss(y_true, y_pred):
+ mask = ops.cast(ops.logical_not(ops.equal(y_true, CONFIG.token_pad)), "float32")
+ y_true = ops.one_hot(ops.cast(y_true, "int32"), CONFIG.vocabulary_size)
+ return ops.categorical_crossentropy(y_true, y_pred, from_logits=True) * mask
+
+
+"""
+### Learning rate schedule
+
+Following the Music Transformer paper, we define the standard Transformer
+learning rate schedule: a linear warmup followed by an inverse square root
+decay, scaled by the embedding dimension. Roughly:
+
+`lr(step) = embedding_dim**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)`
+"""
+
+
+@keras.utils.register_keras_serializable()
+class CustomSchedule(optimizers.schedules.LearningRateSchedule):
+ def __init__(self, embedding_dim, warmup_steps=4000):
+ super(CustomSchedule, self).__init__()
+
+ self.embedding_dim = embedding_dim
+ self.warmup_steps = warmup_steps
+
+ self._embedding_dim = ops.cast(self.embedding_dim, "float32")
+ # Numerical stability adjustment on torch, which is less precise
+ self._lr_adjust = 0.1 if keras.backend.backend() == "torch" else 1.0
+
+ def get_config(self):
+ return {"embedding_dim": self.embedding_dim, "warmup_steps": self.warmup_steps}
+
+ def __call__(self, step):
+ step_rsqrt = ops.rsqrt(ops.cast(step, "float32"))
+ warmup_adjust = step * (self.warmup_steps**-1.5)
+ output = ops.rsqrt(self._embedding_dim) * ops.minimum(step_rsqrt, warmup_adjust)
+ return self._lr_adjust * output
+
+
+"""
+## Training the model
+
+We can now train the model on the Maestro dataset. First, we define a training
+function. This function compiles the model, trains it, and saves the best model
+checkpoint. This way, we can continue training from the best model checkpoint
+if needed.
+"""
+
+
+def train_model(model, train_ds, val_ds, epochs=15):
+ # Configure optimizer
+ learning_rate = CustomSchedule(CONFIG.embedding_dim)
+ optimizer = optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
+
+ # Compile the model
+ model.compile(optimizer=optimizer, loss=train_loss)
+
+ # Train the model
+ save_cb = callbacks.ModelCheckpoint(CONFIG.model_out, save_best_only=True)
+ model.fit(
+ train_ds, validation_data=val_ds, epochs=epochs, callbacks=[save_cb], verbose=2
+ )
+ return model
+
+
+"""
+With the training function defined, we can train the model on the Maestro
+dataset. If a model checkpoint already exists, we load it and can optionally
+continue training from it.
+"""
+if path.exists(CONFIG.model_out):
+ model = keras.models.load_model(CONFIG.model_out)
+    # Uncomment to continue model training from the checkpoint
+ # train_model(model, train_dataset, val_dataset, epochs=10)
+else:
+ # Train the model
+ model = train_model(MusicTransformerDecoder(), train_dataset, val_dataset)
+
+
+"""
+## Generate music
+
+We can now generate music using the trained model. We use an existing MIDI file
+as a seed and generate a new sequence.
+"""
+
+
+def generate_music(model, seed_path, length=1024, out_dir=None, top_k=None):
+ # Ensure the output directory exists
+ out_dir = out_dir if out_dir is not None else tempfile.mkdtemp()
+ os.makedirs(out_dir, exist_ok=True)
+
+ # Get some tokens from the MIDI file
+ inputs = midi_tokenizer.encode_midi(seed_path)[100:125]
+ print(f"Seed tokens: {inputs}")
+
+ # Generate music that follows the input tokens until the maximum length
+ result = model.generate(inputs, length=length, top_k=top_k)
+
+ output_path = path.join(out_dir, path.basename(seed_path).split(".")[0] + ".mid")
+ midi_tokenizer.decode_midi(result, output_path)
+ return output_path
+
+
+output_file = generate_music(model, val_paths[-1], out_dir="tmp/", top_k=15)
+print(visualize_midi(output_file, out_dir="tmp/")) # Saved audio path
+visualize_midi(output_file) # Display the audio if in a Jupyter notebook
+
+"""
+## Conclusion
+
+In this example, we learned how to build a music generation model using a custom
+Transformer decoder architecture.
+
+We did it following the Music Transformer paper by Huang et al. (2018).
+To do so we had to:
+
+- Define a custom loss function and learning rate schedule.
+- Define a custom attention mechanism.
+- Preprocess MIDI files into a tokenized format.
+
+After training the model on the Maestro dataset, we generated music sequences
+using a seed MIDI file.
+
+### Next steps
+
+We could further improve inference times by caching attention weights during the
+forward pass, in a similar way to `keras_hub` `CausalLM` models, which use the
+`CachedMultiHeadAttention` layer.
+"""
diff --git a/knowledge_base/generative/molecule_generation.py b/knowledge_base/generative/molecule_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ef4b8f285a20df2ac4baac34f913d6aae21cc4
--- /dev/null
+++ b/knowledge_base/generative/molecule_generation.py
@@ -0,0 +1,623 @@
+"""
+Title: Drug Molecule Generation with VAE
+Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
+Date created: 2022/03/10
+Last modified: 2024/12/17
+Description: Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we use a Variational Autoencoder to generate molecules for drug discovery.
+We use the research papers
+[Automatic chemical design using a data-driven continuous representation of molecules](https://arxiv.org/abs/1610.02415)
+and [MolGAN: An implicit generative model for small molecular graphs](https://arxiv.org/abs/1805.11973)
+as a reference.
+
+The model described in the paper **Automatic chemical design using a data-driven
+continuous representation of molecules** generates new molecules via efficient exploration
+of open-ended spaces of chemical compounds. The model consists of
+three components: Encoder, Decoder and Predictor. The Encoder converts the discrete
+representation of a molecule into a real-valued continuous vector, and the Decoder
+converts these continuous vectors back to discrete molecule representations. The
+Predictor estimates chemical properties from the latent continuous vector representation
+of the molecule. Continuous representations allow the use of gradient-based
+optimization to efficiently guide the search for optimized functional compounds.
+
+
+
+**Figure (a)** - A diagram of the autoencoder used for molecule design, including the
+joint property prediction model. Starting from a discrete molecule representation, such
+as a SMILES string, the encoder network converts each molecule into a vector in the
+latent space, which is effectively a continuous molecule representation. Given a point
+in the latent space, the decoder network produces a corresponding SMILES string. A
+multilayer perceptron network estimates the value of target properties associated with
+each molecule.
+
+**Figure (b)** - Gradient-based optimization in continuous latent space. After training a
+surrogate model `f(z)` to predict the properties of molecules based on their latent
+representation `z`, we can optimize `f(z)` with respect to `z` to find new latent
+representations expected to match specific desired properties. These new latent
+representations can then be decoded into SMILES strings, at which point their properties
+can be tested empirically.
+
+For an explanation and implementation of MolGAN, please refer to the Keras Example
+[**WGAN-GP with R-GCN for the generation of small molecular graphs**](https://bit.ly/3pU6zXK) by
+Alexander Kensert. Many of the functions used in the present example are from the above Keras example.
+"""
+
+"""
+## Setup
+
+RDKit is an open source toolkit for cheminformatics and machine learning. This toolkit comes
+in handy if one works in the drug discovery domain. In this example, RDKit is used to
+conveniently and efficiently transform SMILES to molecule objects, and then from those obtain
+sets of atoms and bonds.
+
+Quoting from
+[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/):
+
+**"SMILES expresses the structure of a given molecule in the form of an ASCII string.
+The SMILES string is a compact encoding which, for smaller molecules, is relatively human-readable.
+Encoding molecules as a string both alleviates and facilitates database and/or web searching
+of a given molecule. RDKit uses algorithms to accurately transform a given SMILES to
+a molecule object, which can then be used to compute a great number of molecular properties/features."**
+"""
+
+"""shell
+pip -q install rdkit-pypi==2021.9.4
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import ast
+
+import pandas as pd
+import numpy as np
+
+import tensorflow as tf
+import keras
+from keras import layers
+from keras import ops
+
+import matplotlib.pyplot as plt
+from rdkit import Chem, RDLogger
+from rdkit.Chem import BondType
+from rdkit.Chem.Draw import MolsToGridImage
+
+RDLogger.DisableLog("rdApp.*")
+
+"""
+## Dataset
+
+We use the [**ZINC – A Free Database of Commercially Available Compounds for
+Virtual Screening**](https://bit.ly/3IVBI4x) dataset. The dataset comes with molecule
+formulas in SMILES representation, along with their respective molecular properties such as
+**logP** (water–octanol partition coefficient), **SAS** (synthetic
+accessibility score) and **QED** (Qualitative Estimate of Drug-likeness).
+
+"""
+
+csv_path = keras.utils.get_file(
+ "250k_rndm_zinc_drugs_clean_3.csv",
+ "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv",
+)
+
+df = pd.read_csv(csv_path)
+df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", ""))
+df.head()
+
+"""
+## Hyperparameters
+"""
+
+SMILE_CHARSET = '["C", "B", "F", "I", "H", "O", "N", "S", "P", "Cl", "Br"]'
+
+bond_mapping = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
+bond_mapping.update(
+ {0: BondType.SINGLE, 1: BondType.DOUBLE, 2: BondType.TRIPLE, 3: BondType.AROMATIC}
+)
+SMILE_CHARSET = ast.literal_eval(SMILE_CHARSET)
+
+MAX_MOLSIZE = max(df["smiles"].str.len())
+SMILE_to_index = dict((c, i) for i, c in enumerate(SMILE_CHARSET))
+index_to_SMILE = dict((i, c) for i, c in enumerate(SMILE_CHARSET))
+atom_mapping = dict(SMILE_to_index)
+atom_mapping.update(index_to_SMILE)
+
+BATCH_SIZE = 100
+EPOCHS = 10
+
+VAE_LR = 5e-4
+NUM_ATOMS = 120 # Maximum number of atoms
+
+ATOM_DIM = len(SMILE_CHARSET) # Number of atom types
+BOND_DIM = 4 + 1 # Number of bond types
+LATENT_DIM = 435 # Size of the latent space
+
+
+def smiles_to_graph(smiles):
+ # Converts SMILES to molecule object
+ molecule = Chem.MolFromSmiles(smiles)
+
+ # Initialize adjacency and feature tensor
+ adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
+ features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")
+
+ # loop over each atom in molecule
+ for atom in molecule.GetAtoms():
+ i = atom.GetIdx()
+ atom_type = atom_mapping[atom.GetSymbol()]
+ features[i] = np.eye(ATOM_DIM)[atom_type]
+ # loop over one-hop neighbors
+ for neighbor in atom.GetNeighbors():
+ j = neighbor.GetIdx()
+ bond = molecule.GetBondBetweenAtoms(i, j)
+ bond_type_idx = bond_mapping[bond.GetBondType().name]
+ adjacency[bond_type_idx, [i, j], [j, i]] = 1
+
+ # Where no bond, add 1 to last channel (indicating "non-bond")
+ # Notice: channels-first
+ adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1
+
+ # Where no atom, add 1 to last column (indicating "non-atom")
+ features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1
+
+ return adjacency, features
+
+
+def graph_to_molecule(graph):
+ # Unpack graph
+ adjacency, features = graph
+
+ # RWMol is a molecule object intended to be edited
+ molecule = Chem.RWMol()
+
+ # Remove "no atoms" & atoms with no bonds
+ keep_idx = np.where(
+ (np.argmax(features, axis=1) != ATOM_DIM - 1)
+ & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
+ )[0]
+ features = features[keep_idx]
+ adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]
+
+ # Add atoms to molecule
+ for atom_type_idx in np.argmax(features, axis=1):
+ atom = Chem.Atom(atom_mapping[atom_type_idx])
+ _ = molecule.AddAtom(atom)
+
+ # Add bonds between atoms in molecule; based on the upper triangles
+ # of the [symmetric] adjacency tensor
+ (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
+ for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):
+ if atom_i == atom_j or bond_ij == BOND_DIM - 1:
+ continue
+ bond_type = bond_mapping[bond_ij]
+ molecule.AddBond(int(atom_i), int(atom_j), bond_type)
+
+ # Sanitize the molecule; for more information on sanitization, see
+ # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
+ flag = Chem.SanitizeMol(molecule, catchErrors=True)
+ # Let's be strict. If sanitization fails, return None
+ if flag != Chem.SanitizeFlags.SANITIZE_NONE:
+ return None
+
+ return molecule
+
+
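+"""
+As a quick sanity check (a minimal sketch using an arbitrary row of the
+dataframe, here the hypothetical index 100), a molecule should survive a round
+trip through the two helper functions above:
+"""
+
+adjacency_rt, features_rt = smiles_to_graph(df["smiles"][100])
+molecule_rt = graph_to_molecule([adjacency_rt, features_rt])
+print(
+    "Round-trip SMILES:",
+    Chem.MolToSmiles(molecule_rt) if molecule_rt is not None else "sanitization failed",
+)
+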
+"""
+## Generate training set
+"""
+
+train_df = df.sample(frac=0.75, random_state=42) # random state is a seed value
+train_df.reset_index(drop=True, inplace=True)
+
+adjacency_tensor, feature_tensor, qed_tensor = [], [], []
+for idx in range(8000):
+ adjacency, features = smiles_to_graph(train_df.loc[idx]["smiles"])
+ qed = train_df.loc[idx]["qed"]
+ adjacency_tensor.append(adjacency)
+ feature_tensor.append(features)
+ qed_tensor.append(qed)
+
+adjacency_tensor = np.array(adjacency_tensor)
+feature_tensor = np.array(feature_tensor)
+qed_tensor = np.array(qed_tensor)
+
+
+class RelationalGraphConvLayer(keras.layers.Layer):
+ def __init__(
+ self,
+ units=128,
+ activation="relu",
+ use_bias=False,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.units = units
+ self.activation = keras.activations.get(activation)
+ self.use_bias = use_bias
+ self.kernel_initializer = keras.initializers.get(kernel_initializer)
+ self.bias_initializer = keras.initializers.get(bias_initializer)
+ self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+ self.bias_regularizer = keras.regularizers.get(bias_regularizer)
+
+ def build(self, input_shape):
+ bond_dim = input_shape[0][1]
+ atom_dim = input_shape[1][2]
+
+ self.kernel = self.add_weight(
+ shape=(bond_dim, atom_dim, self.units),
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ trainable=True,
+ name="W",
+ dtype="float32",
+ )
+
+ if self.use_bias:
+ self.bias = self.add_weight(
+ shape=(bond_dim, 1, self.units),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ trainable=True,
+ name="b",
+ dtype="float32",
+ )
+
+ self.built = True
+
+ def call(self, inputs, training=False):
+ adjacency, features = inputs
+ # Aggregate information from neighbors
+ x = ops.matmul(adjacency, features[:, None])
+ # Apply linear transformation
+ x = ops.matmul(x, self.kernel)
+ if self.use_bias:
+ x += self.bias
+ # Reduce bond types dim
+ x_reduced = ops.sum(x, axis=1)
+ # Apply non-linear transformation
+ return self.activation(x_reduced)
+
+
+"""
+## Build the Encoder and Decoder
+
+The Encoder takes as input a molecule's graph adjacency matrix and feature matrix.
+These features are processed via a Graph Convolution layer, then are flattened and
+processed by several Dense layers to derive `z_mean` and `log_var`, the
+latent-space representation of the molecule.
+
+**Graph Convolution layer**: The relational graph convolution layer implements
+non-linearly transformed neighbourhood aggregations. We can define these layers as
+follows:
+
+`H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l) * W**(l))`
+
+Where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A_hat` the
+adjacency tensor, `H_hat**(l)` the feature tensor at the `l-th` layer, `D_hat**(-1)` the
+inverse diagonal degree tensor of `A_hat`, and `W**(l)` the trainable weight tensor
+at the `l-th` layer. Specifically, for each bond type (relation), the degree tensor
+expresses, in the diagonal, the number of bonds attached to each atom.
+
+Source:
+[WGAN-GP with R-GCN for the generation of small molecular graphs](https://keras.io/examples/generative/wgan-graphs/)
+
+The Decoder takes as input the latent-space representation and predicts
+the graph adjacency matrix and feature matrix of the corresponding molecules.
+"""
+
+
+def get_encoder(
+ gconv_units, latent_dim, adjacency_shape, feature_shape, dense_units, dropout_rate
+):
+ adjacency = layers.Input(shape=adjacency_shape)
+ features = layers.Input(shape=feature_shape)
+
+ # Propagate through one or more graph convolutional layers
+ features_transformed = features
+ for units in gconv_units:
+ features_transformed = RelationalGraphConvLayer(units)(
+ [adjacency, features_transformed]
+ )
+ # Reduce 2-D representation of molecule to 1-D
+ x = layers.GlobalAveragePooling1D()(features_transformed)
+
+ # Propagate through one or more densely connected layers
+ for units in dense_units:
+ x = layers.Dense(units, activation="relu")(x)
+ x = layers.Dropout(dropout_rate)(x)
+
+ z_mean = layers.Dense(latent_dim, dtype="float32", name="z_mean")(x)
+ log_var = layers.Dense(latent_dim, dtype="float32", name="log_var")(x)
+
+ encoder = keras.Model([adjacency, features], [z_mean, log_var], name="encoder")
+
+ return encoder
+
+
+def get_decoder(dense_units, dropout_rate, latent_dim, adjacency_shape, feature_shape):
+ latent_inputs = keras.Input(shape=(latent_dim,))
+
+ x = latent_inputs
+ for units in dense_units:
+ x = layers.Dense(units, activation="tanh")(x)
+ x = layers.Dropout(dropout_rate)(x)
+
+ # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
+ x_adjacency = layers.Dense(np.prod(adjacency_shape))(x)
+ x_adjacency = layers.Reshape(adjacency_shape)(x_adjacency)
+ # Symmetrify tensors in the last two dimensions
+ x_adjacency = (x_adjacency + ops.transpose(x_adjacency, (0, 1, 3, 2))) / 2
+ x_adjacency = layers.Softmax(axis=1)(x_adjacency)
+
+ # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
+ x_features = layers.Dense(np.prod(feature_shape))(x)
+ x_features = layers.Reshape(feature_shape)(x_features)
+ x_features = layers.Softmax(axis=2)(x_features)
+
+ decoder = keras.Model(
+ latent_inputs, outputs=[x_adjacency, x_features], name="decoder"
+ )
+
+ return decoder
+
+
+"""
+## Build the Sampling layer
+"""
+
+
+class Sampling(layers.Layer):
+ def __init__(self, seed=None, **kwargs):
+ super().__init__(**kwargs)
+ self.seed_generator = keras.random.SeedGenerator(seed)
+
+ def call(self, inputs):
+ z_mean, z_log_var = inputs
+ batch, dim = ops.shape(z_log_var)
+ epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
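+        # Reparameterization trick: z = mean + std * epsilon keeps the
+        # sampling step differentiable with respect to z_mean and z_log_var.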
+ return z_mean + ops.exp(0.5 * z_log_var) * epsilon
+
+
+"""
+## Build the VAE
+
+This model is trained to optimize four losses:
+
+* Categorical crossentropy
+* KL divergence loss
+* Property prediction loss
+* Graph loss (gradient penalty)
+
+The categorical crossentropy loss function measures the model's
+reconstruction accuracy. The Property prediction loss estimates the mean squared
+error between predicted and actual properties after running the latent representation
+through a property prediction model. The property
+prediction of the model is optimized via binary crossentropy. The gradient
+penalty is further guided by the model's property (QED) prediction.
+
+A gradient penalty is an alternative soft constraint on the
+1-Lipschitz continuity as an improvement upon the gradient clipping scheme from the
+original neural network
+("1-Lipschitz continuity" means that the norm of the gradient is at most 1 at every single
+point of the function).
+It adds a regularization term to the loss function.
+"""
+
+
+class MoleculeGenerator(keras.Model):
+ def __init__(self, encoder, decoder, max_len, seed=None, **kwargs):
+ super().__init__(**kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+ self.property_prediction_layer = layers.Dense(1)
+ self.max_len = max_len
+ self.seed_generator = keras.random.SeedGenerator(seed)
+ self.sampling_layer = Sampling(seed=seed)
+
+ self.train_total_loss_tracker = keras.metrics.Mean(name="train_total_loss")
+ self.val_total_loss_tracker = keras.metrics.Mean(name="val_total_loss")
+
+ def train_step(self, data):
+ adjacency_tensor, feature_tensor, qed_tensor = data[0]
+ graph_real = [adjacency_tensor, feature_tensor]
+ self.batch_size = ops.shape(qed_tensor)[0]
+ with tf.GradientTape() as tape:
+ z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = self(
+ graph_real, training=True
+ )
+ graph_generated = [gen_adjacency, gen_features]
+ total_loss = self._compute_loss(
+ z_log_var, z_mean, qed_tensor, qed_pred, graph_real, graph_generated
+ )
+
+ grads = tape.gradient(total_loss, self.trainable_weights)
+ self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
+
+ self.train_total_loss_tracker.update_state(total_loss)
+ return {"loss": self.train_total_loss_tracker.result()}
+
+ def _compute_loss(
+ self, z_log_var, z_mean, qed_true, qed_pred, graph_real, graph_generated
+ ):
+ adjacency_real, features_real = graph_real
+ adjacency_gen, features_gen = graph_generated
+
+ adjacency_loss = ops.mean(
+ ops.sum(
+ keras.losses.categorical_crossentropy(
+ adjacency_real, adjacency_gen, axis=1
+ ),
+ axis=(1, 2),
+ )
+ )
+ features_loss = ops.mean(
+ ops.sum(
+ keras.losses.categorical_crossentropy(features_real, features_gen),
+ axis=(1),
+ )
+ )
+ kl_loss = -0.5 * ops.sum(
+ 1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1
+ )
+ kl_loss = ops.mean(kl_loss)
+
+ property_loss = ops.mean(
+ keras.losses.binary_crossentropy(qed_true, ops.squeeze(qed_pred, axis=1))
+ )
+
+ graph_loss = self._gradient_penalty(graph_real, graph_generated)
+
+ return kl_loss + property_loss + graph_loss + adjacency_loss + features_loss
+
+ def _gradient_penalty(self, graph_real, graph_generated):
+ # Unpack graphs
+ adjacency_real, features_real = graph_real
+ adjacency_generated, features_generated = graph_generated
+
+ # Generate interpolated graphs (adjacency_interp and features_interp)
+ alpha = keras.random.uniform(shape=(self.batch_size,), seed=self.seed_generator)
+ alpha = ops.reshape(alpha, (self.batch_size, 1, 1, 1))
+ adjacency_interp = (adjacency_real * alpha) + (
+ 1.0 - alpha
+ ) * adjacency_generated
+ alpha = ops.reshape(alpha, (self.batch_size, 1, 1))
+ features_interp = (features_real * alpha) + (1.0 - alpha) * features_generated
+
+ # Compute the logits of interpolated graphs
+ with tf.GradientTape() as tape:
+ tape.watch(adjacency_interp)
+ tape.watch(features_interp)
+ _, _, logits, _, _ = self(
+ [adjacency_interp, features_interp], training=True
+ )
+
+ # Compute the gradients with respect to the interpolated graphs
+ grads = tape.gradient(logits, [adjacency_interp, features_interp])
+ # Compute the gradient penalty
+ grads_adjacency_penalty = (1 - ops.norm(grads[0], axis=1)) ** 2
+ grads_features_penalty = (1 - ops.norm(grads[1], axis=2)) ** 2
+ return ops.mean(
+ ops.mean(grads_adjacency_penalty, axis=(-2, -1))
+ + ops.mean(grads_features_penalty, axis=(-1))
+ )
+
+ def inference(self, batch_size):
+ z = keras.random.normal(
+ shape=(batch_size, LATENT_DIM), seed=self.seed_generator
+ )
+        reconstruction_adjacency, reconstruction_features = self.decoder.predict(z)
+ # obtain one-hot encoded adjacency tensor
+ adjacency = ops.argmax(reconstruction_adjacency, axis=1)
+ adjacency = ops.one_hot(adjacency, num_classes=BOND_DIM, axis=1)
+ # Remove potential self-loops from adjacency
+ adjacency = adjacency * (1.0 - ops.eye(NUM_ATOMS, dtype="float32")[None, None])
+ # obtain one-hot encoded feature tensor
+ features = ops.argmax(reconstruction_features, axis=2)
+ features = ops.one_hot(features, num_classes=ATOM_DIM, axis=2)
+ return [
+ graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
+ for i in range(batch_size)
+ ]
+
+ def call(self, inputs):
+ z_mean, log_var = self.encoder(inputs)
+ z = self.sampling_layer([z_mean, log_var])
+
+ gen_adjacency, gen_features = self.decoder(z)
+
+ property_pred = self.property_prediction_layer(z_mean)
+
+ return z_mean, log_var, property_pred, gen_adjacency, gen_features
+
+
+"""
+## Train the model
+"""
+
+vae_optimizer = keras.optimizers.Adam(learning_rate=VAE_LR)
+
+encoder = get_encoder(
+ gconv_units=[9],
+ adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
+ feature_shape=(NUM_ATOMS, ATOM_DIM),
+ latent_dim=LATENT_DIM,
+ dense_units=[512],
+ dropout_rate=0.0,
+)
+decoder = get_decoder(
+ dense_units=[128, 256, 512],
+ dropout_rate=0.2,
+ latent_dim=LATENT_DIM,
+ adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
+ feature_shape=(NUM_ATOMS, ATOM_DIM),
+)
+
+model = MoleculeGenerator(encoder, decoder, MAX_MOLSIZE)
+
+model.compile(vae_optimizer)
+history = model.fit([adjacency_tensor, feature_tensor, qed_tensor], epochs=EPOCHS)
+
+"""
+## Inference
+
+We use our model to generate new valid molecules from different points of the latent space.
+"""
+
+"""
+### Generate unique Molecules with the model
+"""
+
+molecules = model.inference(1000)
+
+MolsToGridImage(
+ [m for m in molecules if m is not None][:1000], molsPerRow=5, subImgSize=(260, 160)
+)
+
+"""
+### Display latent space clusters with respect to molecular properties (QED)
+"""
+
+
+def plot_latent(vae, data, labels):
+ # display a 2D plot of the property in the latent space
+ z_mean, _ = vae.encoder.predict(data)
+ plt.figure(figsize=(12, 10))
+ plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
+ plt.colorbar()
+ plt.xlabel("z[0]")
+ plt.ylabel("z[1]")
+ plt.show()
+
+
+plot_latent(model, [adjacency_tensor[:8000], feature_tensor[:8000]], qed_tensor[:8000])
+
+"""
+## Conclusion
+
+In this example, we combined model architectures from two papers,
+"Automatic chemical design using a data-driven continuous representation of
+molecules" from 2016 and the "MolGAN" paper from 2018. The former paper
+treats SMILES inputs as strings and seeks to generate molecule strings in SMILES format,
+while the latter paper considers SMILES inputs as graphs (a combination of adjacency
+matrices and feature matrices) and seeks to generate molecules as graphs.
+
+This hybrid approach enables a new type of directed gradient-based search through chemical space.
+
+Example available on HuggingFace
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/drug-molecule-generation-with-VAE) | [](https://huggingface.co/spaces/keras-io/generating-drug-molecule-with-VAE) |
+"""
diff --git a/knowledge_base/generative/neural_style_transfer.py b/knowledge_base/generative/neural_style_transfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1586f8470a40f9ad21ab561fe47658485a906e
--- /dev/null
+++ b/knowledge_base/generative/neural_style_transfer.py
@@ -0,0 +1,273 @@
+"""
+Title: Neural style transfer
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2016/01/11
+Last modified: 2020/05/02
+Description: Transferring the style of a reference image to a target image using gradient descent.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Style transfer consists in generating an image
+with the same "content" as a base image, but with the
+"style" of a different picture (typically artistic).
+This is achieved through the optimization of a loss function
+that has 3 components: "style loss", "content loss",
+and "total variation loss":
+
+- The total variation loss imposes local spatial continuity between
+the pixels of the combination image, giving it visual coherence.
+- The style loss is where the deep learning kicks in --that one is defined
+using a deep convolutional neural network. Precisely, it consists in a sum of
+L2 distances between the Gram matrices of the representations of
+the base image and the style reference image, extracted from
+different layers of a convnet (trained on ImageNet). The general idea
+is to capture color/texture information at different spatial
+scales (fairly large scales --defined by the depth of the layer considered).
+- The content loss is an L2 distance between the features of the base
+image (extracted from a deep layer) and the features of the combination image,
+keeping the generated image close enough to the original one.
+
+**Reference:** [A Neural Algorithm of Artistic Style](
+ http://arxiv.org/abs/1508.06576)
+"""
+
+"""
+## Setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import numpy as np
+import tensorflow as tf
+from keras.applications import vgg19
+
+base_image_path = keras.utils.get_file("paris.jpg", "https://i.imgur.com/F28w3Ac.jpg")
+style_reference_image_path = keras.utils.get_file(
+ "starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg"
+)
+result_prefix = "paris_generated"
+
+# Weights of the different loss components
+total_variation_weight = 1e-6
+style_weight = 1e-6
+content_weight = 2.5e-8
+
+# Dimensions of the generated picture.
+width, height = keras.utils.load_img(base_image_path).size
+img_nrows = 400
+img_ncols = int(width * img_nrows / height)
+
+"""
+## Let's take a look at our base (content) image and our style reference image
+"""
+
+from IPython.display import Image, display
+
+display(Image(base_image_path))
+display(Image(style_reference_image_path))
+
+"""
+## Image preprocessing / deprocessing utilities
+"""
+
+
+def preprocess_image(image_path):
+ # Util function to open, resize and format pictures into appropriate tensors
+ img = keras.utils.load_img(image_path, target_size=(img_nrows, img_ncols))
+ img = keras.utils.img_to_array(img)
+ img = np.expand_dims(img, axis=0)
+ img = vgg19.preprocess_input(img)
+ return tf.convert_to_tensor(img)
+
+
+def deprocess_image(x):
+ # Util function to convert a tensor into a valid image
+ x = x.reshape((img_nrows, img_ncols, 3))
+ # Remove zero-center by mean pixel
+ x[:, :, 0] += 103.939
+ x[:, :, 1] += 116.779
+ x[:, :, 2] += 123.68
+ # 'BGR'->'RGB'
+ x = x[:, :, ::-1]
+ x = np.clip(x, 0, 255).astype("uint8")
+ return x
+
+
+"""
+## Compute the style transfer loss
+
+First, we need to define 4 utility functions:
+
+- `gram_matrix` (used to compute the style loss)
+- The `style_loss` function, which keeps the generated image close to the local textures
+of the style reference image
+- The `content_loss` function, which keeps the high-level representation of the
+generated image close to that of the base image
+- The `total_variation_loss` function, a regularization loss which keeps the generated
+image locally-coherent
+"""
+
+# The gram matrix of an image tensor (feature-wise outer product)
+
+
+def gram_matrix(x):
+ x = tf.transpose(x, (2, 0, 1))
+ features = tf.reshape(x, (tf.shape(x)[0], -1))
+ gram = tf.matmul(features, tf.transpose(features))
+ return gram
+
+
+# The "style loss" is designed to maintain
+# the style of the reference image in the generated image.
+# It is based on the gram matrices (which capture style) of
+# feature maps from the style reference image
+# and from the generated image
+
+
+def style_loss(style, combination):
+ S = gram_matrix(style)
+ C = gram_matrix(combination)
+ channels = 3
+ size = img_nrows * img_ncols
+ return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels**2) * (size**2))
+
+
+# An auxiliary loss function
+# designed to maintain the "content" of the
+# base image in the generated image
+
+
+def content_loss(base, combination):
+ return tf.reduce_sum(tf.square(combination - base))
+
+
+# The 3rd loss function, total variation loss,
+# designed to keep the generated image locally coherent
+
+
+def total_variation_loss(x):
+ a = tf.square(
+ x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]
+ )
+ b = tf.square(
+ x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]
+ )
+ return tf.reduce_sum(tf.pow(a + b, 1.25))
+
+
+"""
+Next, let's create a feature extraction model that retrieves the intermediate activations
+of VGG19 (as a dict, by name).
+"""
+
+# Build a VGG19 model loaded with pre-trained ImageNet weights
+model = vgg19.VGG19(weights="imagenet", include_top=False)
+
+# Get the symbolic outputs of each "key" layer (we gave them unique names).
+outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])
+
+# Set up a model that returns the activation values for every layer in
+# VGG19 (as a dict).
+feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)
+
+"""
+Finally, here's the code that computes the style transfer loss.
+"""
+
+# List of layers to use for the style loss.
+style_layer_names = [
+ "block1_conv1",
+ "block2_conv1",
+ "block3_conv1",
+ "block4_conv1",
+ "block5_conv1",
+]
+# The layer to use for the content loss.
+content_layer_name = "block5_conv2"
+
+
+def compute_loss(combination_image, base_image, style_reference_image):
+ input_tensor = tf.concat(
+ [base_image, style_reference_image, combination_image], axis=0
+ )
+ features = feature_extractor(input_tensor)
+
+ # Initialize the loss
+ loss = tf.zeros(shape=())
+
+ # Add content loss
+ layer_features = features[content_layer_name]
+ base_image_features = layer_features[0, :, :, :]
+ combination_features = layer_features[2, :, :, :]
+ loss = loss + content_weight * content_loss(
+ base_image_features, combination_features
+ )
+ # Add style loss
+ for layer_name in style_layer_names:
+ layer_features = features[layer_name]
+ style_reference_features = layer_features[1, :, :, :]
+ combination_features = layer_features[2, :, :, :]
+ sl = style_loss(style_reference_features, combination_features)
+ loss += (style_weight / len(style_layer_names)) * sl
+
+ # Add total variation loss
+ loss += total_variation_weight * total_variation_loss(combination_image)
+ return loss
+
+
+"""
+## Add a tf.function decorator to loss & gradient computation
+
+Wrapping the computation in `tf.function` compiles it, which makes it fast.
+"""
+
+
+@tf.function
+def compute_loss_and_grads(combination_image, base_image, style_reference_image):
+ with tf.GradientTape() as tape:
+ loss = compute_loss(combination_image, base_image, style_reference_image)
+ grads = tape.gradient(loss, combination_image)
+ return loss, grads
+
+
+"""
+## The training loop
+
+Repeatedly run vanilla gradient descent steps to minimize the loss, and save the
+resulting image every 100 iterations.
+
+We decay the learning rate by 0.96 every 100 steps.
+"""
+
+optimizer = keras.optimizers.SGD(
+ keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96
+ )
+)
+
+base_image = preprocess_image(base_image_path)
+style_reference_image = preprocess_image(style_reference_image_path)
+combination_image = tf.Variable(preprocess_image(base_image_path))
+
+iterations = 4000
+for i in range(1, iterations + 1):
+ loss, grads = compute_loss_and_grads(
+ combination_image, base_image, style_reference_image
+ )
+ optimizer.apply_gradients([(grads, combination_image)])
+ if i % 100 == 0:
+ print("Iteration %d: loss=%.2f" % (i, loss))
+ img = deprocess_image(combination_image.numpy())
+ fname = result_prefix + "_at_iteration_%d.png" % i
+ keras.utils.save_img(fname, img)
+
+"""
+After 4000 iterations, you get the following result:
+"""
+
+display(Image(result_prefix + "_at_iteration_4000.png"))
diff --git a/knowledge_base/generative/pixelcnn.py b/knowledge_base/generative/pixelcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c08a09fc246aae181ed384e3fa149ef102b676d
--- /dev/null
+++ b/knowledge_base/generative/pixelcnn.py
@@ -0,0 +1,187 @@
+"""
+Title: PixelCNN
+Author: [ADMoreau](https://github.com/ADMoreau)
+Date created: 2020/05/17
+Last modified: 2020/05/23
+Description: PixelCNN implemented in Keras.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+PixelCNN is a generative model proposed in 2016 by van den Oord et al.
+(reference: [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)).
+It is designed to generate images (or other data types) iteratively
+from an input vector where the probability distribution of prior elements dictates the
+probability distribution of later elements. In the following example, images are generated
+in this fashion, pixel-by-pixel, via a masked convolution kernel that only looks at data
+from previously generated pixels (origin at the top left) to generate later pixels.
+During inference, the output of the network is used as a probability distribution
+from which new pixel values are sampled to generate a new image
+(here, with MNIST, the pixel values are either black or white).
+"""
+
+import numpy as np
+import keras
+from keras import layers
+from keras import ops
+from tqdm import tqdm
+
+"""
+## Getting the Data
+"""
+
+# Model / data parameters
+num_classes = 10
+input_shape = (28, 28, 1)
+n_residual_blocks = 5
+# The data, split between train and test sets
+(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
+# Concatenate all the images together
+data = np.concatenate((x_train, x_test), axis=0)
+# Binarize: pixel values below 33% of the max value (256) are rounded down
+# to 0, and everything above is rounded up to 1, so that all values are
+# either 0 or 1
+data = np.where(data < (0.33 * 256), 0, 1)
+data = data.astype(np.float32)
+
+"""
+## Create two classes for the requisite Layers for the model
+"""
+
+
+# The first layer is the PixelCNN layer. This layer simply
+# builds on the 2D convolutional layer, but includes masking.
+class PixelConvLayer(layers.Layer):
+ def __init__(self, mask_type, **kwargs):
+ super().__init__()
+ self.mask_type = mask_type
+ self.conv = layers.Conv2D(**kwargs)
+
+ def build(self, input_shape):
+ # Build the conv2d layer to initialize kernel variables
+ self.conv.build(input_shape)
+ # Use the initialized kernel to create the mask
+ kernel_shape = ops.shape(self.conv.kernel)
+ self.mask = np.zeros(shape=kernel_shape)
+ self.mask[: kernel_shape[0] // 2, ...] = 1.0
+ self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
+ if self.mask_type == "B":
+ self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0
+
+ def call(self, inputs):
+ self.conv.kernel.assign(self.conv.kernel * self.mask)
+ return self.conv(inputs)
+
+
+# Next, we build our residual block layer.
+# This is just a normal residual block, but based on the PixelConvLayer.
+class ResidualBlock(keras.layers.Layer):
+ def __init__(self, filters, **kwargs):
+ super().__init__(**kwargs)
+ self.conv1 = keras.layers.Conv2D(
+ filters=filters, kernel_size=1, activation="relu"
+ )
+ self.pixel_conv = PixelConvLayer(
+ mask_type="B",
+ filters=filters // 2,
+ kernel_size=3,
+ activation="relu",
+ padding="same",
+ )
+ self.conv2 = keras.layers.Conv2D(
+ filters=filters, kernel_size=1, activation="relu"
+ )
+
+ def call(self, inputs):
+ x = self.conv1(inputs)
+ x = self.pixel_conv(x)
+ x = self.conv2(x)
+ return keras.layers.add([inputs, x])
+
+
+"""
+## Build the model based on the original paper
+"""
+
+inputs = keras.Input(shape=input_shape, batch_size=128)
+x = PixelConvLayer(
+ mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
+)(inputs)
+
+for _ in range(n_residual_blocks):
+ x = ResidualBlock(filters=128)(x)
+
+for _ in range(2):
+ x = PixelConvLayer(
+ mask_type="B",
+ filters=128,
+ kernel_size=1,
+ strides=1,
+ activation="relu",
+ padding="valid",
+ )(x)
+
+out = keras.layers.Conv2D(
+ filters=1, kernel_size=1, strides=1, activation="sigmoid", padding="valid"
+)(x)
+
+pixel_cnn = keras.Model(inputs, out)
+adam = keras.optimizers.Adam(learning_rate=0.0005)
+pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy")
+
+pixel_cnn.summary()
+pixel_cnn.fit(
+ x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2
+)
+
+"""
+## Demonstration
+
+The PixelCNN cannot generate the full image at once. Instead, it must generate each pixel in
+order, append the last generated pixel to the current image, and feed the image back into the
+model to repeat the process.
+"""
+
+from IPython.display import Image, display
+
+# Create an empty array of pixels.
+batch = 4
+pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
+batch, rows, cols, channels = pixels.shape
+
+# Iterate over the pixels because generation has to be done sequentially pixel by pixel.
+for row in tqdm(range(rows)):
+ for col in range(cols):
+ for channel in range(channels):
+            # Feed in the whole array and retrieve the pixel value probabilities for
+            # the next pixel.
+ probs = pixel_cnn.predict(pixels, verbose=0)[:, row, col, channel]
+ # Use the probabilities to pick pixel values and append the values to the image
+ # frame.
+ pixels[:, row, col, channel] = ops.ceil(
+ probs - keras.random.uniform(probs.shape)
+ )
+
+
+def deprocess_image(x):
+ # Stack the single channeled black and white image to rgb values.
+ x = np.stack((x, x, x), 2)
+ # Undo preprocessing
+ x *= 255.0
+ # Convert to uint8 and clip to the valid range [0, 255]
+ x = np.clip(x, 0, 255).astype("uint8")
+ return x
+
+
+# Iterate over the generated images and plot them with matplotlib.
+for i, pic in enumerate(pixels):
+ keras.utils.save_img(
+ "generated_image_{}.png".format(i), deprocess_image(np.squeeze(pic, -1))
+ )
+
+display(Image("generated_image_0.png"))
+display(Image("generated_image_1.png"))
+display(Image("generated_image_2.png"))
+display(Image("generated_image_3.png"))
diff --git a/knowledge_base/generative/random_walks_with_stable_diffusion.py b/knowledge_base/generative/random_walks_with_stable_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3d48603e597762cb918727b86f6b2e728ce335
--- /dev/null
+++ b/knowledge_base/generative/random_walks_with_stable_diffusion.py
@@ -0,0 +1,387 @@
+"""
+Title: A walk through latent space with Stable Diffusion
+Authors: Ian Stenbit, [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml)
+Date created: 2022/09/28
+Last modified: 2022/09/28
+Description: Explore the latent manifold of Stable Diffusion.
+Accelerator: GPU
+"""
+
+"""
+## Overview
+
+Generative image models learn a "latent manifold" of the visual world:
+a low-dimensional vector space where each point maps to an image.
+Going from such a point on the manifold back to a displayable image
+is called "decoding" -- in the Stable Diffusion model, this is handled by
+the "decoder" model.
+
+
+
+This latent manifold of images is continuous and interpolative, meaning that:
+
+1. Moving a little on the manifold only changes the corresponding image a little (continuity).
+2. For any two points A and B on the manifold (i.e. any two images), it is possible
+to move from A to B via a path where each intermediate point is also on the manifold (i.e.
+is also a valid image). Intermediate points would be called "interpolations" between
+the two starting images.
+
+Stable Diffusion isn't just an image model, though; it's also a natural language model.
+It has two latent spaces: the image representation space learned by the
+encoder used during training, and the prompt latent space
+which is learned using a combination of pretraining and training-time
+fine-tuning.
+
+_Latent space walking_, or _latent space exploration_, is the process of
+sampling a point in latent space and incrementally changing the latent
+representation. Its most common application is generating animations
+where each sampled point is fed to the decoder and is stored as a
+frame in the final animation.
+For high-quality latent representations, this produces coherent-looking
+animations. These animations can provide insight into the feature map of the
+latent space, and can ultimately lead to improvements in the training
+process. One such GIF is displayed below:
+
+
+
+In this guide, we will show how to take advantage of the Stable Diffusion API
+in KerasCV to perform prompt interpolation and circular walks through
+Stable Diffusion's visual latent manifold, as well as through
+the text encoder's latent manifold.
+
+This guide assumes the reader has a
+high-level understanding of Stable Diffusion.
+If you haven't already, you should start
+by reading the [Stable Diffusion Tutorial](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/).
+
+To start, we import KerasCV and load up a Stable Diffusion model using the
+optimizations discussed in the tutorial
+[Generate images with Stable Diffusion](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/).
+Note that if you are running with an M1 Mac GPU you should not enable mixed precision.
+"""
+
+"""shell
+pip install keras-cv --upgrade --quiet
+"""
+
+import keras_cv
+import keras
+import matplotlib.pyplot as plt
+from keras import ops
+import numpy as np
+import math
+from PIL import Image
+
+# Enable mixed precision
+# (only do this if you have a recent NVIDIA GPU)
+keras.mixed_precision.set_global_policy("mixed_float16")
+
+# Instantiate the Stable Diffusion model
+model = keras_cv.models.StableDiffusion(jit_compile=True)
+
+"""
+## Interpolating between text prompts
+
+In Stable Diffusion, a text prompt is first encoded into a vector,
+and that encoding is used to guide the diffusion process.
+The latent encoding vector has shape
+77x768 (that's huge!), and when we give Stable Diffusion a text prompt, we're
+generating images from just one such point on the latent manifold.
+
+To explore more of this manifold, we can interpolate between two text encodings
+and generate images at those interpolated points:
+"""
+
+prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
+prompt_2 = "A still life DSLR photo of a bowl of fruit"
+interpolation_steps = 5
+
+encoding_1 = ops.squeeze(model.encode_text(prompt_1))
+encoding_2 = ops.squeeze(model.encode_text(prompt_2))
+
+interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)
+
+# Show the size of the latent manifold
+print(f"Encoding shape: {encoding_1.shape}")
+
+"""
+Once we've interpolated the encodings, we can generate images from each point.
+Note that in order to maintain some stability between the resulting images we
+keep the diffusion noise constant between images.
+"""
+
+seed = 12345
+noise = keras.random.normal((512 // 8, 512 // 8, 4), seed=seed)
+
+images = model.generate_image(
+ interpolated_encodings,
+ batch_size=interpolation_steps,
+ diffusion_noise=noise,
+)
+
+"""
+Now that we've generated some interpolated images, let's take a look at them!
+
+Throughout this tutorial, we're going to export sequences of images as gifs so
+that they can be easily viewed with some temporal context. For sequences of
+images where the first and last images don't match conceptually, we rubber-band
+the gif.
+
+If you're running in Colab, you can view your own GIFs by running:
+
+```
+from IPython.display import Image as IImage
+IImage("doggo-and-fruit-5.gif")
+```
+"""
+
+
+def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
+ if rubber_band:
+        images += images[2:-1][::-1]  # Makes a rubber band: A->B->A
+ images[0].save(
+ filename,
+ save_all=True,
+ append_images=images[1:],
+ duration=1000 // frames_per_second,
+ loop=0,
+ )
+
+
+export_as_gif(
+ "doggo-and-fruit-5.gif",
+ [Image.fromarray(img) for img in images],
+ frames_per_second=2,
+ rubber_band=True,
+)
+
+"""
+
+
+The results may seem surprising. Generally, interpolating between prompts
+produces coherent looking images, and often demonstrates a progressive concept
+shift between the contents of the two prompts. This is indicative of a high
+quality representation space, that closely mirrors the natural structure
+of the visual world.
+
+To best visualize this, we should do a much more fine-grained interpolation,
+using hundreds of steps. To keep the batch size small (so that we don't
+OOM our GPU), we manually batch our interpolated encodings.
+"""
+
+interpolation_steps = 150
+batch_size = 3
+batches = interpolation_steps // batch_size
+
+interpolated_encodings = ops.linspace(encoding_1, encoding_2, interpolation_steps)
+batched_encodings = ops.split(interpolated_encodings, batches)
+
+images = []
+for batch in range(batches):
+ images += [
+ Image.fromarray(img)
+ for img in model.generate_image(
+ batched_encodings[batch],
+ batch_size=batch_size,
+ num_steps=25,
+ diffusion_noise=noise,
+ )
+ ]
+
+export_as_gif("doggo-and-fruit-150.gif", images, rubber_band=True)
+
+"""
+
+
+The resulting gif shows a much clearer and more coherent shift between the two
+prompts. Try out some prompts of your own and experiment!
+
+We can even extend this concept for more than one image. For example, we can
+interpolate between four prompts:
+"""
+
+prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
+prompt_2 = "A still life DSLR photo of a bowl of fruit"
+prompt_3 = "The eiffel tower in the style of starry night"
+prompt_4 = "An architectural sketch of a skyscraper"
+
+interpolation_steps = 6
+batch_size = 3
+batches = (interpolation_steps**2) // batch_size
+
+encoding_1 = ops.squeeze(model.encode_text(prompt_1))
+encoding_2 = ops.squeeze(model.encode_text(prompt_2))
+encoding_3 = ops.squeeze(model.encode_text(prompt_3))
+encoding_4 = ops.squeeze(model.encode_text(prompt_4))
+
+interpolated_encodings = ops.linspace(
+ ops.linspace(encoding_1, encoding_2, interpolation_steps),
+ ops.linspace(encoding_3, encoding_4, interpolation_steps),
+ interpolation_steps,
+)
+interpolated_encodings = ops.reshape(
+ interpolated_encodings, (interpolation_steps**2, 77, 768)
+)
+batched_encodings = ops.split(interpolated_encodings, batches)
+
+images = []
+for batch in range(batches):
+ images.append(
+ model.generate_image(
+ batched_encodings[batch],
+ batch_size=batch_size,
+ diffusion_noise=noise,
+ )
+ )
+
+
+def plot_grid(images, path, grid_size, scale=2):
+ fig, axs = plt.subplots(
+ grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
+ )
+ fig.tight_layout()
+ plt.subplots_adjust(wspace=0, hspace=0)
+ plt.axis("off")
+ for ax in axs.flat:
+ ax.axis("off")
+
+ images = images.astype(int)
+ for i in range(min(grid_size * grid_size, len(images))):
+ ax = axs.flat[i]
+ ax.imshow(images[i].astype("uint8"))
+ ax.axis("off")
+
+ for i in range(len(images), grid_size * grid_size):
+ axs.flat[i].axis("off")
+ axs.flat[i].remove()
+
+ plt.savefig(
+ fname=path,
+ pad_inches=0,
+ bbox_inches="tight",
+ transparent=False,
+ dpi=60,
+ )
+
+
+images = np.concatenate(images)
+plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)
+
+"""
+We can also interpolate while allowing diffusion noise to vary by dropping
+the `diffusion_noise` parameter:
+"""
+
+images = []
+for batch in range(batches):
+ images.append(model.generate_image(batched_encodings[batch], batch_size=batch_size))
+
+images = np.concatenate(images)
+plot_grid(images, "4-way-interpolation-varying-noise.jpg", interpolation_steps)
+
+"""
+Next up -- let's go for some walks!
+
+## A walk around a text prompt
+
+Our next experiment will be to go for a walk around the latent manifold
+starting from a point produced by a particular prompt.
+"""
+
+walk_steps = 150
+batch_size = 3
+batches = walk_steps // batch_size
+step_size = 0.005
+
+encoding = ops.squeeze(
+ model.encode_text("The Eiffel Tower in the style of starry night")
+)
+# Note that (77, 768) is the shape of the text encoding.
+delta = ops.ones_like(encoding) * step_size
+
+walked_encodings = []
+for step_index in range(walk_steps):
+ walked_encodings.append(encoding)
+ encoding += delta
+walked_encodings = ops.stack(walked_encodings)
+batched_encodings = ops.split(walked_encodings, batches)
+
+images = []
+for batch in range(batches):
+ images += [
+ Image.fromarray(img)
+ for img in model.generate_image(
+ batched_encodings[batch],
+ batch_size=batch_size,
+ num_steps=25,
+ diffusion_noise=noise,
+ )
+ ]
+
+export_as_gif("eiffel-tower-starry-night.gif", images, rubber_band=True)
+
+"""
+
+
+Perhaps unsurprisingly, walking too far from the encoder's latent manifold
+produces images that look incoherent. Try it for yourself by setting
+your own prompt, and adjusting `step_size` to increase or decrease the magnitude
+of the walk. Note that when the magnitude of the walk gets large, the walk often
+leads into areas which produce extremely noisy images.
+
+## A circular walk through the diffusion noise space for a single prompt
+
+Our final experiment is to stick to one prompt and explore the variety of images
+that the diffusion model can produce from that prompt. We do this by controlling
+the noise that is used to seed the diffusion process.
+
+We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing
+the cosine of our `x` component and the sine of our `y` component to produce noise.
+Using this approach, the end of our walk arrives at the same noise inputs where
+we began our walk, so we get a "loopable" result!
+"""
+
+prompt = "An oil paintings of cows in a field next to a windmill in Holland"
+encoding = ops.squeeze(model.encode_text(prompt))
+walk_steps = 150
+batch_size = 3
+batches = walk_steps // batch_size
+
+walk_noise_x = keras.random.normal(noise.shape, dtype="float64")
+walk_noise_y = keras.random.normal(noise.shape, dtype="float64")
+
+walk_scale_x = ops.cos(ops.linspace(0, 2, walk_steps) * math.pi)
+walk_scale_y = ops.sin(ops.linspace(0, 2, walk_steps) * math.pi)
+noise_x = ops.tensordot(walk_scale_x, walk_noise_x, axes=0)
+noise_y = ops.tensordot(walk_scale_y, walk_noise_y, axes=0)
+noise = ops.add(noise_x, noise_y)
+batched_noise = ops.split(noise, batches)
+
+images = []
+for batch in range(batches):
+ images += [
+ Image.fromarray(img)
+ for img in model.generate_image(
+ encoding,
+ batch_size=batch_size,
+ num_steps=25,
+ diffusion_noise=batched_noise[batch],
+ )
+ ]
+
+export_as_gif("cows.gif", images)
+
+"""
+
+
+Experiment with your own prompts and with different values of
+`unconditional_guidance_scale`!
+
+## Conclusion
+
+Stable Diffusion offers a lot more than just single text-to-image generation.
+Exploring the latent manifold of the text encoder and the noise space of the
+diffusion model are two fun ways to experience the power of this model, and
+KerasCV makes it easy!
+"""
diff --git a/knowledge_base/generative/random_walks_with_stable_diffusion_3.py b/knowledge_base/generative/random_walks_with_stable_diffusion_3.py
new file mode 100644
index 0000000000000000000000000000000000000000..52307af0a803aa37d9fa75da90e35899106aaa01
--- /dev/null
+++ b/knowledge_base/generative/random_walks_with_stable_diffusion_3.py
@@ -0,0 +1,693 @@
+"""
+Title: A walk through latent space with Stable Diffusion 3
+Authors: [Hongyu Chiu](https://github.com/james77777778), Ian Stenbit, [fchollet](https://twitter.com/fchollet), [lukewood](https://twitter.com/luke_wood_ml)
+Date created: 2024/11/11
+Last modified: 2024/11/11
+Description: Explore the latent manifold of Stable Diffusion 3.
+Accelerator: GPU
+"""
+
+"""
+## Overview
+
+Generative image models learn a "latent manifold" of the visual world: a
+low-dimensional vector space where each point maps to an image. Going from such
+a point on the manifold back to a displayable image is called "decoding" -- in
+the Stable Diffusion model, this is handled by the "decoder" model.
+
+
+
+This latent manifold of images is continuous and interpolative, meaning that:
+
+1. Moving a little on the manifold only changes the corresponding image a
+little (continuity).
+2. For any two points A and B on the manifold (i.e. any two images), it is
+possible to move from A to B via a path where each intermediate point is also on
+the manifold (i.e. is also a valid image). Intermediate points would be called
+"interpolations" between the two starting images.
+
+Stable Diffusion isn't just an image model, though, it's also a natural language
+model. It has two latent spaces: the image representation space learned by the
+encoder used during training, and the prompt latent space which is learned using
+a combination of pretraining and training-time fine-tuning.
+
+_Latent space walking_, or _latent space exploration_, is the process of
+sampling a point in latent space and incrementally changing the latent
+representation. Its most common application is generating animations where each
+sampled point is fed to the decoder and is stored as a frame in the final
+animation.
+For high-quality latent representations, this produces coherent-looking
+animations. These animations can provide insight into the feature map of the
+latent space, and can ultimately lead to improvements in the training process.
+One such GIF is displayed below:
+
+
+
+In this guide, we will show how to take advantage of the TextToImage API in
+KerasHub to perform prompt interpolation and circular walks through Stable
+Diffusion 3's visual latent manifold, as well as through the text encoder's
+latent manifold.
+
+This guide assumes the reader has a high-level understanding of Stable
+Diffusion 3. If you haven't already, you should start by reading the
+[Stable Diffusion 3 in KerasHub](
+https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/).
+
+It is also worth noting that the preset "stable_diffusion_3_medium" excludes the
+T5XXL text encoder, as it requires significantly more GPU memory. The performance
+degradation is negligible in most cases. The weights, including T5XXL, will be
+available on KerasHub soon.
+"""
+
+"""shell
+# Use the latest version of KerasHub
+!pip install -Uq git+https://github.com/keras-team/keras-hub.git
+"""
+
+import math
+
+import keras
+import keras_hub
+import matplotlib.pyplot as plt
+from keras import ops
+from keras import random
+from PIL import Image
+
+height, width = 512, 512
+num_steps = 28
+guidance_scale = 7.0
+dtype = "float16"
+
+# Instantiate the Stable Diffusion 3 model and the preprocessor
+backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
+ "stable_diffusion_3_medium", image_shape=(height, width, 3), dtype=dtype
+)
+preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
+ "stable_diffusion_3_medium"
+)
+
+"""
+Let's define some helper functions for this example.
+"""
+
+
+def get_text_embeddings(prompt):
+ """Get the text embeddings for a given prompt."""
+ token_ids = preprocessor.generate_preprocess([prompt])
+ negative_token_ids = preprocessor.generate_preprocess([""])
+ (
+ positive_embeddings,
+ negative_embeddings,
+ positive_pooled_embeddings,
+ negative_pooled_embeddings,
+ ) = backbone.encode_text_step(token_ids, negative_token_ids)
+ return (
+ positive_embeddings,
+ negative_embeddings,
+ positive_pooled_embeddings,
+ negative_pooled_embeddings,
+ )
+
+
+def decode_to_images(x, height, width):
+ """Concatenate and normalize the images to uint8 dtype."""
+ x = ops.concatenate(x, axis=0)
+ x = ops.reshape(x, (-1, height, width, 3))
+ x = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
+ return ops.cast(ops.round(ops.multiply(x, 255.0)), "uint8")
+
+
+def generate_with_latents_and_embeddings(
+ latents, embeddings, num_steps, guidance_scale
+):
+ """Generate images from latents and text embeddings."""
+
+ def body_fun(step, latents):
+ return backbone.denoise_step(
+ latents,
+ embeddings,
+ step,
+ num_steps,
+ guidance_scale,
+ )
+
+ latents = ops.fori_loop(0, num_steps, body_fun, latents)
+ return backbone.decode_step(latents)
+
+
+def export_as_gif(filename, images, frames_per_second=10, no_rubber_band=False):
+ if not no_rubber_band:
+ images += images[2:-1][::-1] # Makes a rubber band: A->B->A
+ images[0].save(
+ filename,
+ save_all=True,
+ append_images=images[1:],
+ duration=1000 // frames_per_second,
+ loop=0,
+ )
+
+
+"""
+We are going to generate images using custom latents and embeddings, so we need
+to implement the `generate_with_latents_and_embeddings` function. Additionally,
+it is important to compile this function to speed up the generation process.
+"""
+
+if keras.config.backend() == "torch":
+ import torch
+
+ @torch.no_grad()
+ def wrapped_function(*args, **kwargs):
+ return generate_with_latents_and_embeddings(*args, **kwargs)
+
+ generate_function = wrapped_function
+elif keras.config.backend() == "tensorflow":
+ import tensorflow as tf
+
+ generate_function = tf.function(
+ generate_with_latents_and_embeddings, jit_compile=True
+ )
+elif keras.config.backend() == "jax":
+ import itertools
+
+ import jax
+
+ @jax.jit
+ def compiled_function(state, *args, **kwargs):
+ (trainable_variables, non_trainable_variables) = state
+ mapping = itertools.chain(
+ zip(backbone.trainable_variables, trainable_variables),
+ zip(backbone.non_trainable_variables, non_trainable_variables),
+ )
+ with keras.StatelessScope(state_mapping=mapping):
+ return generate_with_latents_and_embeddings(*args, **kwargs)
+
+ def wrapped_function(*args, **kwargs):
+ state = (
+ [v.value for v in backbone.trainable_variables],
+ [v.value for v in backbone.non_trainable_variables],
+ )
+ return compiled_function(state, *args, **kwargs)
+
+ generate_function = wrapped_function
+
+
+"""
+## Interpolating between text prompts
+
+In Stable Diffusion 3, a text prompt is encoded into multiple vectors, which are
+then used to guide the diffusion process. These latent encoding vectors have
+shapes of 154x4096 and 2048 for both the positive and negative prompts - quite
+large! When we input a text prompt into Stable Diffusion 3, we generate images
+from a single point on this latent manifold.
+
+To explore more of this manifold, we can interpolate between two text encodings
+and generate images at those interpolated points:
+"""
+
+prompt_1 = "A cute dog in a beautiful field of lavander colorful flowers "
+prompt_1 += "everywhere, perfect lighting, leica summicron 35mm f2.0, kodak "
+prompt_1 += "portra 400, film grain"
+prompt_2 = prompt_1.replace("dog", "cat")
+interpolation_steps = 5
+
+encoding_1 = get_text_embeddings(prompt_1)
+encoding_2 = get_text_embeddings(prompt_2)
+
+
+# Show the size of the latent manifold
+print(f"Positive embeddings shape: {encoding_1[0].shape}")
+print(f"Negative embeddings shape: {encoding_1[1].shape}")
+print(f"Positive pooled embeddings shape: {encoding_1[2].shape}")
+print(f"Negative pooled embeddings shape: {encoding_1[3].shape}")
+
+
+"""
+In this example, we want to use Spherical Linear Interpolation (slerp) instead
+of simple linear interpolation. Slerp is commonly used in computer graphics to
+animate rotations smoothly and can also be applied to interpolate between
+high-dimensional data points, such as latent vectors used in generative models.
+
+The source is from Andrej Karpathy's gist:
+[https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355](https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355).
+
+A more detailed explanation of this method can be found at:
+[https://en.wikipedia.org/wiki/Slerp](https://en.wikipedia.org/wiki/Slerp).
+"""
+
+
+def slerp(v1, v2, num):
+ ori_dtype = v1.dtype
+ # Cast to float32 for numerical stability.
+ v1 = ops.cast(v1, "float32")
+ v2 = ops.cast(v2, "float32")
+
+ def interpolation(t, v1, v2, dot_threshold=0.9995):
+ """helper function to spherically interpolate two arrays."""
+ dot = ops.sum(
+ v1 * v2 / (ops.linalg.norm(ops.ravel(v1)) * ops.linalg.norm(ops.ravel(v2)))
+ )
+ if ops.abs(dot) > dot_threshold:
+ v2 = (1 - t) * v1 + t * v2
+ else:
+ theta_0 = ops.arccos(dot)
+ sin_theta_0 = ops.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = ops.sin(theta_t)
+ s0 = ops.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v1 + s1 * v2
+ return v2
+
+ t = ops.linspace(0, 1, num)
+ interpolated = ops.stack([interpolation(t[i], v1, v2) for i in range(num)], axis=0)
+ return ops.cast(interpolated, ori_dtype)
+
+
+interpolated_positive_embeddings = slerp(
+ encoding_1[0], encoding_2[0], interpolation_steps
+)
+interpolated_positive_pooled_embeddings = slerp(
+ encoding_1[2], encoding_2[2], interpolation_steps
+)
+# We don't use negative prompts in this example, so there's no need to
+# interpolate them.
+negative_embeddings = encoding_1[1]
+negative_pooled_embeddings = encoding_1[3]
+
+
+"""
+Once we've interpolated the encodings, we can generate images from each point.
+Note that in order to maintain some stability between the resulting images we
+keep the diffusion latents constant between images.
+"""
+
+latents = random.normal((1, height // 8, width // 8, 16), seed=42)
+
+images = []
+progbar = keras.utils.Progbar(interpolation_steps)
+for i in range(interpolation_steps):
+ images.append(
+ generate_function(
+ latents,
+ (
+ interpolated_positive_embeddings[i],
+ negative_embeddings,
+ interpolated_positive_pooled_embeddings[i],
+ negative_pooled_embeddings,
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == interpolation_steps - 1)
+
+"""
+Now that we've generated some interpolated images, let's take a look at them!
+
+Throughout this tutorial, we're going to export sequences of images as gifs so
+that they can be easily viewed with some temporal context. For sequences of
+images where the first and last images don't match conceptually, we rubber-band
+the gif.
+
+If you're running in Colab, you can view your own GIFs by running:
+
+```
+from IPython.display import Image as IImage
+IImage("dog_to_cat_5.gif")
+```
+"""
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+export_as_gif(
+ "dog_to_cat_5.gif",
+ [Image.fromarray(image) for image in images],
+ frames_per_second=2,
+)
+
+"""
+The results may seem surprising. Generally, interpolating between prompts
+produces coherent looking images, and often demonstrates a progressive concept
+shift between the contents of the two prompts. This is indicative of a high
+quality representation space, that closely mirrors the natural structure of the
+visual world.
+
+To best visualize this, we should do a much more fine-grained interpolation,
+using more steps.
+"""
+
+interpolation_steps = 64
+batch_size = 4
+batches = interpolation_steps // batch_size
+
+interpolated_positive_embeddings = slerp(
+ encoding_1[0], encoding_2[0], interpolation_steps
+)
+interpolated_positive_pooled_embeddings = slerp(
+ encoding_1[2], encoding_2[2], interpolation_steps
+)
+positive_embeddings_shape = ops.shape(encoding_1[0])
+positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
+interpolated_positive_embeddings = ops.reshape(
+ interpolated_positive_embeddings,
+ (
+ batches,
+ batch_size,
+ positive_embeddings_shape[-2],
+ positive_embeddings_shape[-1],
+ ),
+)
+interpolated_positive_pooled_embeddings = ops.reshape(
+ interpolated_positive_pooled_embeddings,
+ (batches, batch_size, positive_pooled_embeddings_shape[-1]),
+)
+negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
+negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))
+
+latents = random.normal((1, height // 8, width // 8, 16), seed=42)
+latents = ops.tile(latents, (batch_size, 1, 1, 1))
+
+images = []
+progbar = keras.utils.Progbar(batches)
+for i in range(batches):
+ images.append(
+ generate_function(
+ latents,
+ (
+ interpolated_positive_embeddings[i],
+ negative_embeddings,
+ interpolated_positive_pooled_embeddings[i],
+ negative_pooled_embeddings,
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == batches - 1)
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+export_as_gif(
+ "dog_to_cat_64.gif",
+ [Image.fromarray(image) for image in images],
+ frames_per_second=2,
+)
+
+"""
+The resulting gif shows a much clearer and more coherent shift between the two
+prompts. Try out some prompts of your own and experiment!
+
+We can even extend this concept for more than one image. For example, we can
+interpolate between four prompts:
+"""
+
+prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
+prompt_2 = "A still life DSLR photo of a bowl of fruit"
+prompt_3 = "The eiffel tower in the style of starry night"
+prompt_4 = "An architectural sketch of a skyscraper"
+
+interpolation_steps = 8
+batch_size = 4
+batches = (interpolation_steps**2) // batch_size
+
+encoding_1 = get_text_embeddings(prompt_1)
+encoding_2 = get_text_embeddings(prompt_2)
+encoding_3 = get_text_embeddings(prompt_3)
+encoding_4 = get_text_embeddings(prompt_4)
+
+positive_embeddings_shape = ops.shape(encoding_1[0])
+positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
+interpolated_positive_embeddings_12 = slerp(
+ encoding_1[0], encoding_2[0], interpolation_steps
+)
+interpolated_positive_embeddings_34 = slerp(
+ encoding_3[0], encoding_4[0], interpolation_steps
+)
+interpolated_positive_embeddings = slerp(
+ interpolated_positive_embeddings_12,
+ interpolated_positive_embeddings_34,
+ interpolation_steps,
+)
+interpolated_positive_embeddings = ops.reshape(
+ interpolated_positive_embeddings,
+ (
+ batches,
+ batch_size,
+ positive_embeddings_shape[-2],
+ positive_embeddings_shape[-1],
+ ),
+)
+interpolated_positive_pooled_embeddings_12 = slerp(
+ encoding_1[2], encoding_2[2], interpolation_steps
+)
+interpolated_positive_pooled_embeddings_34 = slerp(
+ encoding_3[2], encoding_4[2], interpolation_steps
+)
+interpolated_positive_pooled_embeddings = slerp(
+ interpolated_positive_pooled_embeddings_12,
+ interpolated_positive_pooled_embeddings_34,
+ interpolation_steps,
+)
+interpolated_positive_pooled_embeddings = ops.reshape(
+ interpolated_positive_pooled_embeddings,
+ (batches, batch_size, positive_pooled_embeddings_shape[-1]),
+)
+negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
+negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))
+
+latents = random.normal((1, height // 8, width // 8, 16), seed=42)
+latents = ops.tile(latents, (batch_size, 1, 1, 1))
+
+images = []
+progbar = keras.utils.Progbar(batches)
+for i in range(batches):
+ images.append(
+ generate_function(
+ latents,
+ (
+ interpolated_positive_embeddings[i],
+ negative_embeddings,
+ interpolated_positive_pooled_embeddings[i],
+ negative_pooled_embeddings,
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == batches - 1)
+
+
+"""
+Let's display the resulting images in a grid to make them easier to interpret.
+"""
+
+
+def plot_grid(images, path, grid_size, scale=2):
+ fig, axs = plt.subplots(
+ grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
+ )
+ fig.tight_layout()
+ plt.subplots_adjust(wspace=0, hspace=0)
+ plt.axis("off")
+ for ax in axs.flat:
+ ax.axis("off")
+
+ for i in range(min(grid_size * grid_size, len(images))):
+ ax = axs.flat[i]
+ ax.imshow(images[i])
+ ax.axis("off")
+
+ for i in range(len(images), grid_size * grid_size):
+ axs.flat[i].axis("off")
+ axs.flat[i].remove()
+
+ plt.savefig(
+ fname=path,
+ pad_inches=0,
+ bbox_inches="tight",
+ transparent=False,
+ dpi=60,
+ )
+
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)
+
+"""
+We can also interpolate while allowing diffusion latents to vary by dropping
+the `seed` parameter:
+"""
+
+images = []
+progbar = keras.utils.Progbar(batches)
+for i in range(batches):
+ # Vary diffusion latents for each input.
+ latents = random.normal((batch_size, height // 8, width // 8, 16))
+ images.append(
+ generate_function(
+ latents,
+ (
+ interpolated_positive_embeddings[i],
+ negative_embeddings,
+ interpolated_positive_pooled_embeddings[i],
+ negative_pooled_embeddings,
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == batches - 1)
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+plot_grid(images, "4-way-interpolation-varying-latent.jpg", interpolation_steps)
+
+"""
+Next up -- let's go for some walks!
+
+## A walk around a text prompt
+
+Our next experiment will be to go for a walk around the latent manifold
+starting from a point produced by a particular prompt.
+"""
+
+walk_steps = 64
+batch_size = 4
+batches = walk_steps // batch_size
+step_size = 0.01
+prompt = "The eiffel tower in the style of starry night"
+encoding = get_text_embeddings(prompt)
+
+positive_embeddings = encoding[0]
+positive_pooled_embeddings = encoding[2]
+negative_embeddings = encoding[1]
+negative_pooled_embeddings = encoding[3]
+
+# The shape of `positive_embeddings`: (1, 154, 4096)
+# The shape of `positive_pooled_embeddings`: (1, 2048)
+positive_embeddings_delta = ops.ones_like(positive_embeddings) * step_size
+positive_pooled_embeddings_delta = ops.ones_like(positive_pooled_embeddings) * step_size
+positive_embeddings_shape = ops.shape(positive_embeddings)
+positive_pooled_embeddings_shape = ops.shape(positive_pooled_embeddings)
+
+walked_positive_embeddings = []
+walked_positive_pooled_embeddings = []
+for step_index in range(walk_steps):
+ walked_positive_embeddings.append(positive_embeddings)
+ walked_positive_pooled_embeddings.append(positive_pooled_embeddings)
+ positive_embeddings += positive_embeddings_delta
+ positive_pooled_embeddings += positive_pooled_embeddings_delta
+walked_positive_embeddings = ops.stack(walked_positive_embeddings, axis=0)
+walked_positive_pooled_embeddings = ops.stack(walked_positive_pooled_embeddings, axis=0)
+walked_positive_embeddings = ops.reshape(
+ walked_positive_embeddings,
+ (
+ batches,
+ batch_size,
+ positive_embeddings_shape[-2],
+ positive_embeddings_shape[-1],
+ ),
+)
+walked_positive_pooled_embeddings = ops.reshape(
+ walked_positive_pooled_embeddings,
+ (batches, batch_size, positive_pooled_embeddings_shape[-1]),
+)
+negative_embeddings = ops.tile(encoding[1], (batch_size, 1, 1))
+negative_pooled_embeddings = ops.tile(encoding[3], (batch_size, 1))
+
+latents = random.normal((1, height // 8, width // 8, 16), seed=42)
+latents = ops.tile(latents, (batch_size, 1, 1, 1))
+
+images = []
+progbar = keras.utils.Progbar(batches)
+for i in range(batches):
+ images.append(
+ generate_function(
+ latents,
+ (
+ walked_positive_embeddings[i],
+ negative_embeddings,
+ walked_positive_pooled_embeddings[i],
+ negative_pooled_embeddings,
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == batches - 1)
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+export_as_gif(
+ "eiffel-tower-starry-night.gif",
+ [Image.fromarray(image) for image in images],
+ frames_per_second=2,
+)
+
+"""
+Perhaps unsurprisingly, walking too far from the encoder's latent manifold
+produces images that look incoherent. Try it for yourself by setting your own
+prompt, and adjusting `step_size` to increase or decrease the magnitude
+of the walk. Note that when the magnitude of the walk gets large, the walk often
+leads into areas which produce extremely noisy images.
+
+## A circular walk through the diffusion latent space for a single prompt
+
+Our final experiment is to stick to one prompt and explore the variety of images
+that the diffusion model can produce from that prompt. We do this by controlling
+the noise that is used to seed the diffusion process.
+
+We create two noise components, `x` and `y`, and do a walk from 0 to 2π, summing
+the cosine of our `x` component and the sine of our `y` component to produce
+noise. Using this approach, the end of our walk arrives at the same noise inputs
+where we began our walk, so we get a "loopable" result!
+"""
+
+walk_steps = 64
+batch_size = 4
+batches = walk_steps // batch_size
+prompt = "An oil paintings of cows in a field next to a windmill in Holland"
+encoding = get_text_embeddings(prompt)
+
+walk_latent_x = random.normal((1, height // 8, width // 8, 16))
+walk_latent_y = random.normal((1, height // 8, width // 8, 16))
+walk_scale_x = ops.cos(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
+walk_scale_y = ops.sin(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
+latent_x = ops.tensordot(walk_scale_x, walk_latent_x, axes=0)
+latent_y = ops.tensordot(walk_scale_y, walk_latent_y, axes=0)
+latents = ops.add(latent_x, latent_y)
+latents = ops.reshape(latents, (batches, batch_size, height // 8, width // 8, 16))
+
+images = []
+progbar = keras.utils.Progbar(batches)
+for i in range(batches):
+ images.append(
+ generate_function(
+ latents[i],
+ (
+ ops.tile(encoding[0], (batch_size, 1, 1)),
+ ops.tile(encoding[1], (batch_size, 1, 1)),
+ ops.tile(encoding[2], (batch_size, 1)),
+ ops.tile(encoding[3], (batch_size, 1)),
+ ),
+ ops.convert_to_tensor(num_steps),
+ ops.convert_to_tensor(guidance_scale),
+ )
+ )
+ progbar.update(i + 1, finalize=i == batches - 1)
+
+images = ops.convert_to_numpy(decode_to_images(images, height, width))
+export_as_gif(
+ "cows.gif",
+ [Image.fromarray(image) for image in images],
+ frames_per_second=4,
+ no_rubber_band=True,
+)
+
+"""
+Experiment with your own prompts and with different values of the parameters!
+
+## Conclusion
+
+Stable Diffusion 3 offers a lot more than just single text-to-image generation.
+Exploring the latent manifold of the text encoder and the latent space of the
+diffusion model are two fun ways to experience the power of this model, and
+KerasHub makes it easy!
+"""
diff --git a/knowledge_base/generative/real_nvp.py b/knowledge_base/generative/real_nvp.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaf6769d476dd4a6c0055b84980309400678d2f0
--- /dev/null
+++ b/knowledge_base/generative/real_nvp.py
@@ -0,0 +1,223 @@
+"""
+Title: Density estimation using Real NVP
+Authors: [Mandolini Giorgio Maria](https://www.linkedin.com/in/giorgio-maria-mandolini-a2a1b71b4/), [Sanna Daniele](https://www.linkedin.com/in/daniele-sanna-338629bb/), [Zannini Quirini Giorgio](https://www.linkedin.com/in/giorgio-zannini-quirini-16ab181a0/)
+Date created: 2020/08/10
+Last modified: 2020/08/10
+Description: Estimating the density distribution of the "double moon" dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The aim of this work is to map a simple distribution -- which is easy to sample
+and whose density is simple to estimate -- to a more complex one learned from the data.
+This kind of generative model is also known as "normalizing flow".
+
+In order to do this, the model is trained via the maximum
+likelihood principle, using the "change of variable" formula.
+
+We will use an affine coupling function. We create it such that its inverse, as well as
+the determinant of the Jacobian, are easy to obtain (more details in the referenced paper).
+
+**Requirements:**
+
+* Tensorflow 2.9.1
+* Tensorflow probability 0.17.0
+
+**Reference:**
+
+[Density estimation using Real NVP](https://arxiv.org/abs/1605.08803)
+"""
+
+"""
+## Setup
+
+"""
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import regularizers
+from sklearn.datasets import make_moons
+import numpy as np
+import matplotlib.pyplot as plt
+import tensorflow_probability as tfp
+
+"""
+## Load the data
+"""
+
+data = make_moons(3000, noise=0.05)[0].astype("float32")
+norm = layers.Normalization()
+norm.adapt(data)
+normalized_data = norm(data)
+
+"""
+## Affine coupling layer
+"""
+
+# Creating a custom layer with keras API.
+output_dim = 256
+reg = 0.01
+
+
+def Coupling(input_shape):
+ input = keras.layers.Input(shape=input_shape)
+
+ t_layer_1 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(input)
+ t_layer_2 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(t_layer_1)
+ t_layer_3 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(t_layer_2)
+ t_layer_4 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(t_layer_3)
+ t_layer_5 = keras.layers.Dense(
+ input_shape, activation="linear", kernel_regularizer=regularizers.l2(reg)
+ )(t_layer_4)
+
+ s_layer_1 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(input)
+ s_layer_2 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(s_layer_1)
+ s_layer_3 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(s_layer_2)
+ s_layer_4 = keras.layers.Dense(
+ output_dim, activation="relu", kernel_regularizer=regularizers.l2(reg)
+ )(s_layer_3)
+ s_layer_5 = keras.layers.Dense(
+ input_shape, activation="tanh", kernel_regularizer=regularizers.l2(reg)
+ )(s_layer_4)
+
+ return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])
+
+
+"""
+## Real NVP
+"""
+
+
+class RealNVP(keras.Model):
+ def __init__(self, num_coupling_layers):
+ super().__init__()
+
+ self.num_coupling_layers = num_coupling_layers
+
+ # Distribution of the latent space.
+ self.distribution = tfp.distributions.MultivariateNormalDiag(
+ loc=[0.0, 0.0], scale_diag=[1.0, 1.0]
+ )
+ self.masks = np.array(
+ [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype="float32"
+ )
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+ self.layers_list = [Coupling(2) for i in range(num_coupling_layers)]
+
+ @property
+ def metrics(self):
+ """List of the model's metrics.
+
+        We make sure the loss tracker is listed as part of `model.metrics`
+        so that `fit()` and `evaluate()` can reset the loss tracker
+        at the start of each epoch and at the start of an `evaluate()` call.
+ """
+ return [self.loss_tracker]
+
+    def call(self, x, training=True):
+        log_det_inv = 0
+        direction = 1
+        if training:
+            direction = -1
+        # During training we apply the inverse transform (data -> latent),
+        # iterating over the coupling layers in reverse; at sampling time we
+        # apply the forward transform (latent -> data).
+        for i in range(self.num_coupling_layers)[::direction]:
+            x_masked = x * self.masks[i]
+            reversed_mask = 1 - self.masks[i]
+            s, t = self.layers_list[i](x_masked)
+            s *= reversed_mask
+            t *= reversed_mask
+            # `gate` is 0 in the forward direction and -1 in the inverse one.
+            # It switches the update between y = x * exp(s) + t (forward) and
+            # x = (y - t) * exp(-s) (inverse), and turns on the log-det term
+            # only for the inverse pass.
+            gate = (direction - 1) / 2
+            x = (
+                reversed_mask
+                * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))
+                + x_masked
+            )
+            log_det_inv += gate * tf.reduce_sum(s, [1])
+
+ return x, log_det_inv
+
+ # Log likelihood of the normal distribution plus the log determinant of the jacobian.
+
+ def log_loss(self, x):
+ y, logdet = self(x)
+ log_likelihood = self.distribution.log_prob(y) + logdet
+ return -tf.reduce_mean(log_likelihood)
+
+ def train_step(self, data):
+ with tf.GradientTape() as tape:
+ loss = self.log_loss(data)
+
+ g = tape.gradient(loss, self.trainable_variables)
+ self.optimizer.apply_gradients(zip(g, self.trainable_variables))
+ self.loss_tracker.update_state(loss)
+
+ return {"loss": self.loss_tracker.result()}
+
+ def test_step(self, data):
+ loss = self.log_loss(data)
+ self.loss_tracker.update_state(loss)
+
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+## Model training
+"""
+
+model = RealNVP(num_coupling_layers=6)
+
+model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))
+
+history = model.fit(
+ normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2
+)
+
+"""
+## Performance evaluation
+"""
+
+plt.figure(figsize=(15, 10))
+plt.plot(history.history["loss"])
+plt.plot(history.history["val_loss"])
+plt.title("model loss")
+plt.legend(["train", "validation"], loc="upper right")
+plt.ylabel("loss")
+plt.xlabel("epoch")
+
+# From data to latent space.
+z, _ = model(normalized_data)
+
+# From latent space to data.
+samples = model.distribution.sample(3000)
+x, _ = model.predict(samples)
+
+f, axes = plt.subplots(2, 2)
+f.set_size_inches(20, 15)
+
+axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color="r")
+axes[0, 0].set(title="Inference data space X", xlabel="x", ylabel="y")
+axes[0, 1].scatter(z[:, 0], z[:, 1], color="r")
+axes[0, 1].set(title="Inference latent space Z", xlabel="x", ylabel="y")
+axes[0, 1].set_xlim([-3.5, 4])
+axes[0, 1].set_ylim([-4, 4])
+axes[1, 0].scatter(samples[:, 0], samples[:, 1], color="g")
+axes[1, 0].set(title="Generated latent space Z", xlabel="x", ylabel="y")
+axes[1, 1].scatter(x[:, 0], x[:, 1], color="g")
+axes[1, 1].set(title="Generated data space X", label="x", ylabel="y")
+axes[1, 1].set_xlim([-2, 2])
+axes[1, 1].set_ylim([-2, 2])
diff --git a/knowledge_base/generative/stylegan.py b/knowledge_base/generative/stylegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..11854f2f0b1d4f526b4df3821462f4a231e8df15
--- /dev/null
+++ b/knowledge_base/generative/stylegan.py
@@ -0,0 +1,772 @@
+"""
+Title: Face image generation with StyleGAN
+Author: [Soon-Yau Cheong](https://www.linkedin.com/in/soonyau/)
+Date created: 2021/07/01
+Last modified: 2021/07/01
+Description: Implementation of StyleGAN for image generation.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The key idea of StyleGAN is to progressively increase the resolution of the generated
+images and to incorporate style features in the generative process. This
+[StyleGAN](https://arxiv.org/abs/1812.04948) implementation is based on the book
+[Hands-on Image Generation with TensorFlow](https://www.amazon.com/dp/1838826785).
+The code from the book's
+[GitHub repository](https://github.com/PacktPublishing/Hands-On-Image-Generation-with-TensorFlow-2.0/tree/master/Chapter07)
+was refactored to leverage a custom `train_step()`, enabling
+faster training via compilation and distribution.
+"""
+
+"""
+## Setup
+"""
+
+"""
+### Install latest TFA
+"""
+"""shell
+pip install tensorflow_addons
+"""
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+
+from functools import partial
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras.models import Sequential
+from tensorflow_addons.layers import InstanceNormalization
+
+import gdown
+from zipfile import ZipFile
+
+"""
+## Prepare the dataset
+
+In this example, we will train on the CelebA dataset, downloaded from the project's Google Drive.
+"""
+
+
+def log2(x):
+ return int(np.log2(x))
+
+
+# We use a different batch size for each resolution, so that larger image sizes
+# can fit into GPU memory. The keys are image resolutions in log2.
+batch_sizes = {2: 16, 3: 16, 4: 16, 5: 16, 6: 16, 7: 8, 8: 4, 9: 2, 10: 1}
+# We adjust the train step accordingly
+train_step_ratio = {k: batch_sizes[2] / v for k, v in batch_sizes.items()}
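+# For example, with the sizes above, the 4x4 stage (log2 res = 2) has ratio
+# 16 / 16 = 1, while the 1024x1024 stage (log2 res = 10) has ratio 16 / 1 = 16,
+# i.e. it runs 16x more of its smaller-batch steps per epoch.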
+
+
+os.makedirs("celeba_gan", exist_ok=True)
+
+url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
+output = "celeba_gan/data.zip"
+gdown.download(url, output, quiet=True)
+
+with ZipFile("celeba_gan/data.zip", "r") as zipobj:
+ zipobj.extractall("celeba_gan")
+
+# Create a dataset from our folder, and rescale the images to the [0-1] range:
+
+ds_train = keras.utils.image_dataset_from_directory(
+ "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
+)
+
+
+def resize_image(res, image):
+ # only downsampling, so use nearest neighbor that is faster to run
+ image = tf.image.resize(
+ image, (res, res), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
+ )
+ image = tf.cast(image, tf.float32) / 127.5 - 1.0
+ return image
+
+
+def create_dataloader(res):
+ batch_size = batch_sizes[log2(res)]
+ # NOTE: we unbatch the dataset so we can `batch()` it again with the `drop_remainder=True` option
+ # since the model only supports a single batch size
+ dl = ds_train.map(
+ partial(resize_image, res), num_parallel_calls=tf.data.AUTOTUNE
+ ).unbatch()
+ dl = dl.shuffle(200).batch(batch_size, drop_remainder=True).prefetch(1).repeat()
+ return dl
+
+
+"""
+## Utility function to display images after each epoch
+"""
+
+
+def plot_images(images, log2_res, fname=""):
+ scales = {2: 0.5, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8}
+ scale = scales[log2_res]
+
+ grid_col = min(images.shape[0], int(32 // scale))
+ grid_row = 1
+
+ f, axarr = plt.subplots(
+ grid_row, grid_col, figsize=(grid_col * scale, grid_row * scale)
+ )
+
+ for row in range(grid_row):
+ ax = axarr if grid_row == 1 else axarr[row]
+ for col in range(grid_col):
+ ax[col].imshow(images[row * grid_col + col])
+ ax[col].axis("off")
+ plt.show()
+ if fname:
+ f.savefig(fname)
+
+
+"""
+## Custom Layers
+
+The following are building blocks that will be used to construct the generators and
+discriminators of the StyleGAN model.
+"""
+
+
+def fade_in(alpha, a, b):
+ return alpha * a + (1.0 - alpha) * b
+
+
+def wasserstein_loss(y_true, y_pred):
+ return -tf.reduce_mean(y_true * y_pred)
+
+
+def pixel_norm(x, epsilon=1e-8):
+ return x / tf.math.sqrt(tf.reduce_mean(x**2, axis=-1, keepdims=True) + epsilon)
+
+
+def minibatch_std(input_tensor, epsilon=1e-8):
+ n, h, w, c = tf.shape(input_tensor)
+ group_size = tf.minimum(4, n)
+ x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
+ group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
+ group_std = tf.sqrt(group_var + epsilon)
+ avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
+ x = tf.tile(avg_std, [group_size, h, w, 1])
+ return tf.concat([input_tensor, x], axis=-1)
+
+
+class EqualizedConv(layers.Layer):
+ def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
+ super().__init__(**kwargs)
+ self.kernel = kernel
+ self.out_channels = out_channels
+ self.gain = gain
+ self.pad = kernel != 1
+
+ def build(self, input_shape):
+ self.in_channels = input_shape[-1]
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
+ self.w = self.add_weight(
+ shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
+ initializer=initializer,
+ trainable=True,
+ name="kernel",
+ )
+ self.b = self.add_weight(
+ shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
+ )
+ fan_in = self.kernel * self.kernel * self.in_channels
+ self.scale = tf.sqrt(self.gain / fan_in)
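+ # Storing unit-variance weights and rescaling them by sqrt(gain / fan_in)
+ # at call time is the "equalized learning rate" trick from ProGAN/StyleGAN:
+ # it keeps the effective update magnitude similar across layers.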
+
+ def call(self, inputs):
+ if self.pad:
+ x = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
+ else:
+ x = inputs
+ output = (
+ tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
+ )
+ return output
+
+
+class EqualizedDense(layers.Layer):
+ def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
+ super().__init__(**kwargs)
+ self.units = units
+ self.gain = gain
+ self.learning_rate_multiplier = learning_rate_multiplier
+
+ def build(self, input_shape):
+ self.in_channels = input_shape[-1]
+ initializer = keras.initializers.RandomNormal(
+ mean=0.0, stddev=1.0 / self.learning_rate_multiplier
+ )
+ self.w = self.add_weight(
+ shape=[self.in_channels, self.units],
+ initializer=initializer,
+ trainable=True,
+ name="kernel",
+ )
+ self.b = self.add_weight(
+ shape=(self.units,), initializer="zeros", trainable=True, name="bias"
+ )
+ fan_in = self.in_channels
+ self.scale = tf.sqrt(self.gain / fan_in)
+
+ def call(self, inputs):
+ output = tf.add(tf.matmul(inputs, self.scale * self.w), self.b)
+ return output * self.learning_rate_multiplier
+
+
+class AddNoise(layers.Layer):
+ def build(self, input_shape):
+ n, h, w, c = input_shape[0]
+ initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
+ self.b = self.add_weight(
+ shape=[1, 1, 1, c], initializer=initializer, trainable=True, name="kernel"
+ )
+
+ def call(self, inputs):
+ x, noise = inputs
+ output = x + self.b * noise
+ return output
+
+
+class AdaIN(layers.Layer):
+ def __init__(self, gain=1, **kwargs):
+ super().__init__(**kwargs)
+ self.gain = gain
+
+ def build(self, input_shapes):
+ x_shape = input_shapes[0]
+ w_shape = input_shapes[1]
+
+ self.w_channels = w_shape[-1]
+ self.x_channels = x_shape[-1]
+
+ self.dense_1 = EqualizedDense(self.x_channels, gain=1)
+ self.dense_2 = EqualizedDense(self.x_channels, gain=1)
+
+ def call(self, inputs):
+ x, w = inputs
+ ys = tf.reshape(self.dense_1(w), (-1, 1, 1, self.x_channels))
+ yb = tf.reshape(self.dense_2(w), (-1, 1, 1, self.x_channels))
+ return ys * x + yb
+
+
+"""
+Next we build the following:
+
+- A mapping network to map the random noise into a style code
+- The generator
+- The discriminator
+
+For the generator, we build generator blocks at multiple resolutions,
+e.g. 4x4, 8x8, ..., up to 1024x1024. We only use the 4x4 block in the beginning,
+and we add progressively larger-resolution blocks as the training proceeds.
+The same goes for the discriminator.
+"""
+
+
+def Mapping(num_stages, input_shape=512):
+ z = layers.Input(shape=(input_shape))
+ w = pixel_norm(z)
+ for i in range(8):
+ w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
+ w = layers.LeakyReLU(0.2)(w)
+ w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
+ return keras.Model(z, w, name="mapping")
+
+
+class Generator:
+ def __init__(self, start_res_log2, target_res_log2):
+ self.start_res_log2 = start_res_log2
+ self.target_res_log2 = target_res_log2
+ self.num_stages = target_res_log2 - start_res_log2 + 1
+ # list of generator blocks at increasing resolution
+ self.g_blocks = []
+ # list of layers to convert g_block activation to RGB
+ self.to_rgb = []
+ # list of noise input of different resolutions into g_blocks
+ self.noise_inputs = []
+ # filter size to use at each stage, keys are log2(resolution)
+ self.filter_nums = {
+ 0: 512,
+ 1: 512,
+ 2: 512, # 4x4
+ 3: 512, # 8x8
+ 4: 512, # 16x16
+ 5: 512, # 32x32
+ 6: 256, # 64x64
+ 7: 128, # 128x128
+ 8: 64, # 256x256
+ 9: 32, # 512x512
+ 10: 16,
+ } # 1024x1024
+
+ start_res = 2**start_res_log2
+ self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
+ self.g_input = layers.Input(self.input_shape, name="generator_input")
+
+ for i in range(start_res_log2, target_res_log2 + 1):
+ filter_num = self.filter_nums[i]
+ res = 2**i
+ self.noise_inputs.append(
+ layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
+ )
+ to_rgb = Sequential(
+ [
+ layers.InputLayer(input_shape=(res, res, filter_num)),
+ EqualizedConv(3, 1, gain=1),
+ ],
+ name=f"to_rgb_{res}x{res}",
+ )
+ self.to_rgb.append(to_rgb)
+ is_base = i == self.start_res_log2
+ if is_base:
+ input_shape = (res, res, self.filter_nums[i - 1])
+ else:
+ input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
+ g_block = self.build_block(
+ filter_num, res=res, input_shape=input_shape, is_base=is_base
+ )
+ self.g_blocks.append(g_block)
+
+ def build_block(self, filter_num, res, input_shape, is_base):
+ input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
+ noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
+ w = layers.Input(shape=512)
+ x = input_tensor
+
+ if not is_base:
+ x = layers.UpSampling2D((2, 2))(x)
+ x = EqualizedConv(filter_num, 3)(x)
+
+ x = AddNoise()([x, noise])
+ x = layers.LeakyReLU(0.2)(x)
+ x = InstanceNormalization()(x)
+ x = AdaIN()([x, w])
+
+ x = EqualizedConv(filter_num, 3)(x)
+ x = AddNoise()([x, noise])
+ x = layers.LeakyReLU(0.2)(x)
+ x = InstanceNormalization()(x)
+ x = AdaIN()([x, w])
+ return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")
+
+ def grow(self, res_log2):
+ res = 2**res_log2
+
+ num_stages = res_log2 - self.start_res_log2 + 1
+ w = layers.Input(shape=(self.num_stages, 512), name="w")
+
+ alpha = layers.Input(shape=(1), name="g_alpha")
+ x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])
+
+ if num_stages == 1:
+ rgb = self.to_rgb[0](x)
+ else:
+ for i in range(1, num_stages - 1):
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
+
+ old_rgb = self.to_rgb[num_stages - 2](x)
+ old_rgb = layers.UpSampling2D((2, 2))(old_rgb)
+
+ i = num_stages - 1
+ x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])
+
+ new_rgb = self.to_rgb[i](x)
+
+ rgb = fade_in(alpha[0], new_rgb, old_rgb)
+
+ return keras.Model(
+ [self.g_input, w, self.noise_inputs, alpha],
+ rgb,
+ name=f"generator_{res}_x_{res}",
+ )
+
+
+class Discriminator:
+ def __init__(self, start_res_log2, target_res_log2):
+ self.start_res_log2 = start_res_log2
+ self.target_res_log2 = target_res_log2
+ self.num_stages = target_res_log2 - start_res_log2 + 1
+ # filter size to use at each stage, keys are log2(resolution)
+ self.filter_nums = {
+ 0: 512,
+ 1: 512,
+ 2: 512, # 4x4
+ 3: 512, # 8x8
+ 4: 512, # 16x16
+ 5: 512, # 32x32
+ 6: 256, # 64x64
+ 7: 128, # 128x128
+ 8: 64, # 256x256
+ 9: 32, # 512x512
+ 10: 16,
+ } # 1024x1024
+ # list of discriminator blocks at increasing resolution
+ self.d_blocks = []
+ # list of layers to convert RGB into activation for d_blocks inputs
+ self.from_rgb = []
+
+ for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
+ res = 2**res_log2
+ filter_num = self.filter_nums[res_log2]
+ from_rgb = Sequential(
+ [
+ layers.InputLayer(
+ input_shape=(res, res, 3), name=f"from_rgb_input_{res}"
+ ),
+ EqualizedConv(filter_num, 1),
+ layers.LeakyReLU(0.2),
+ ],
+ name=f"from_rgb_{res}",
+ )
+
+ self.from_rgb.append(from_rgb)
+
+ input_shape = (res, res, filter_num)
+ if len(self.d_blocks) == 0:
+ d_block = self.build_base(filter_num, res)
+ else:
+ d_block = self.build_block(
+ filter_num, self.filter_nums[res_log2 - 1], res
+ )
+
+ self.d_blocks.append(d_block)
+
+ def build_base(self, filter_num, res):
+ input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
+ x = minibatch_std(input_tensor)
+ x = EqualizedConv(filter_num, 3)(x)
+ x = layers.LeakyReLU(0.2)(x)
+ x = layers.Flatten()(x)
+ x = EqualizedDense(filter_num)(x)
+ x = layers.LeakyReLU(0.2)(x)
+ x = EqualizedDense(1)(x)
+ return keras.Model(input_tensor, x, name=f"d_{res}")
+
+ def build_block(self, filter_num_1, filter_num_2, res):
+ input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
+ x = EqualizedConv(filter_num_1, 3)(input_tensor)
+ x = layers.LeakyReLU(0.2)(x)
+ x = EqualizedConv(filter_num_2)(x)
+ x = layers.LeakyReLU(0.2)(x)
+ x = layers.AveragePooling2D((2, 2))(x)
+ return keras.Model(input_tensor, x, name=f"d_{res}")
+
+ def grow(self, res_log2):
+ res = 2**res_log2
+ idx = res_log2 - self.start_res_log2
+ alpha = layers.Input(shape=(1), name="d_alpha")
+ input_image = layers.Input(shape=(res, res, 3), name="input_image")
+ x = self.from_rgb[idx](input_image)
+ x = self.d_blocks[idx](x)
+ if idx > 0:
+ idx -= 1
+ downsized_image = layers.AveragePooling2D((2, 2))(input_image)
+ y = self.from_rgb[idx](downsized_image)
+ x = fade_in(alpha[0], x, y)
+
+ for i in range(idx, -1, -1):
+ x = self.d_blocks[i](x)
+ return keras.Model([input_image, alpha], x, name=f"discriminator_{res}_x_{res}")
+
+
+"""
+## Build StyleGAN with custom train step
+"""
+
+
+class StyleGAN(tf.keras.Model):
+ def __init__(self, z_dim=512, target_res=64, start_res=4):
+ super().__init__()
+ self.z_dim = z_dim
+
+ self.target_res_log2 = log2(target_res)
+ self.start_res_log2 = log2(start_res)
+ self.current_res_log2 = self.target_res_log2
+ self.num_stages = self.target_res_log2 - self.start_res_log2 + 1
+
+ self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")
+
+ self.mapping = Mapping(num_stages=self.num_stages)
+ self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
+ self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
+ self.g_input_shape = self.g_builder.input_shape
+
+ self.phase = None
+ self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)
+
+ self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}
+
+ def grow_model(self, res):
+ tf.keras.backend.clear_session()
+ res_log2 = log2(res)
+ self.generator = self.g_builder.grow(res_log2)
+ self.discriminator = self.d_builder.grow(res_log2)
+ self.current_res_log2 = res_log2
+ print(f"\nModel resolution:{res}x{res}")
+
+ def compile(
+ self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
+ ):
+ self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
+ self.steps_per_epoch = steps_per_epoch
+ if res != 2**self.current_res_log2:
+ self.grow_model(res)
+ self.d_optimizer = d_optimizer
+ self.g_optimizer = g_optimizer
+
+ self.train_step_counter.assign(0)
+ self.phase = phase
+ self.d_loss_metric = keras.metrics.Mean(name="d_loss")
+ self.g_loss_metric = keras.metrics.Mean(name="g_loss")
+ super().compile(*args, **kwargs)
+
+ @property
+ def metrics(self):
+ return [self.d_loss_metric, self.g_loss_metric]
+
+ def generate_noise(self, batch_size):
+ noise = [
+ tf.random.normal((batch_size, 2**res, 2**res, 1))
+ for res in range(self.start_res_log2, self.target_res_log2 + 1)
+ ]
+ return noise
+
+ def gradient_loss(self, grad):
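+ # WGAN-GP gradient penalty: per sample, the L2 norm of the critic's
+ # gradient at interpolated images is pushed towards 1.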
+ loss = tf.square(grad)
+ loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
+ loss = tf.sqrt(loss)
+ loss = tf.reduce_mean(tf.square(loss - 1))
+ return loss
+
+ def train_step(self, real_images):
+ self.train_step_counter.assign_add(1)
+
+ if self.phase == "TRANSITION":
+ self.alpha.assign(
+ tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
+ )
+ elif self.phase == "STABLE":
+ self.alpha.assign(1.0)
+ else:
+ raise NotImplementedError
+ alpha = tf.expand_dims(self.alpha, 0)
+ batch_size = tf.shape(real_images)[0]
+ real_labels = tf.ones(batch_size)
+ fake_labels = -tf.ones(batch_size)
+
+ z = tf.random.normal((batch_size, self.z_dim))
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
+ noise = self.generate_noise(batch_size)
+
+ # generator
+ with tf.GradientTape() as g_tape:
+ w = self.mapping(z)
+ fake_images = self.generator([const_input, w, noise, alpha])
+ pred_fake = self.discriminator([fake_images, alpha])
+ g_loss = wasserstein_loss(real_labels, pred_fake)
+
+ trainable_weights = (
+ self.mapping.trainable_weights + self.generator.trainable_weights
+ )
+ gradients = g_tape.gradient(g_loss, trainable_weights)
+ self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))
+
+ # discriminator
+ with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
+ # forward pass
+ pred_fake = self.discriminator([fake_images, alpha])
+ pred_real = self.discriminator([real_images, alpha])
+
+ epsilon = tf.random.uniform((batch_size, 1, 1, 1))
+ interpolates = epsilon * real_images + (1 - epsilon) * fake_images
+ gradient_tape.watch(interpolates)
+ pred_fake_grad = self.discriminator([interpolates, alpha])
+
+ # calculate losses
+ loss_fake = wasserstein_loss(fake_labels, pred_fake)
+ loss_real = wasserstein_loss(real_labels, pred_real)
+ loss_fake_grad = wasserstein_loss(fake_labels, pred_fake_grad)
+
+ # gradient penalty
+ gradients_fake = gradient_tape.gradient(loss_fake_grad, [interpolates])
+ gradient_penalty = self.loss_weights[
+ "gradient_penalty"
+ ] * self.gradient_loss(gradients_fake)
+
+ # drift loss
+ all_pred = tf.concat([pred_fake, pred_real], axis=0)
+ drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred**2)
+
+ d_loss = loss_fake + loss_real + gradient_penalty + drift_loss
+
+ gradients = total_tape.gradient(
+ d_loss, self.discriminator.trainable_weights
+ )
+ self.d_optimizer.apply_gradients(
+ zip(gradients, self.discriminator.trainable_weights)
+ )
+
+ # Update metrics
+ self.d_loss_metric.update_state(d_loss)
+ self.g_loss_metric.update_state(g_loss)
+ return {
+ "d_loss": self.d_loss_metric.result(),
+ "g_loss": self.g_loss_metric.result(),
+ }
+
+ def call(self, inputs: dict):
+ style_code = inputs.get("style_code", None)
+ z = inputs.get("z", None)
+ noise = inputs.get("noise", None)
+ batch_size = inputs.get("batch_size", 1)
+ alpha = inputs.get("alpha", 1.0)
+ alpha = tf.expand_dims(alpha, 0)
+ if style_code is None:
+ if z is None:
+ z = tf.random.normal((batch_size, self.z_dim))
+ style_code = self.mapping(z)
+
+ if noise is None:
+ noise = self.generate_noise(batch_size)
+
+ # self.alpha.assign(alpha)
+
+ const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
+ images = self.generator([const_input, style_code, noise, alpha])
+ images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
+
+ return images
+
+
+"""
+## Training
+
+We first build the StyleGAN at the smallest resolution, such as 4x4 or 8x8. Then we
+progressively grow the model to higher resolutions by appending new generator and
+discriminator blocks.
+"""
+
+START_RES = 4
+TARGET_RES = 128
+
+style_gan = StyleGAN(start_res=START_RES, target_res=TARGET_RES)
+
+"""
+The training for each new resolution happens in two phases - "transition" and "stable".
+In the transition phase, the features from the previous resolution are mixed with those
+of the current resolution, which allows for a smoother transition when scaling up, as
+sketched below. We use each epoch in `model.fit()` as a phase.
+"""
+
+
+def train(
+ start_res=START_RES,
+ target_res=TARGET_RES,
+ steps_per_epoch=5000,
+ display_images=True,
+):
+ opt_cfg = {"learning_rate": 1e-3, "beta_1": 0.0, "beta_2": 0.99, "epsilon": 1e-8}
+
+ val_batch_size = 16
+ val_z = tf.random.normal((val_batch_size, style_gan.z_dim))
+ val_noise = style_gan.generate_noise(val_batch_size)
+
+ start_res_log2 = int(np.log2(start_res))
+ target_res_log2 = int(np.log2(target_res))
+
+ for res_log2 in range(start_res_log2, target_res_log2 + 1):
+ res = 2**res_log2
+ for phase in ["TRANSITION", "STABLE"]:
+ if res == start_res and phase == "TRANSITION":
+ continue
+
+ train_dl = create_dataloader(res)
+
+ steps = int(train_step_ratio[res_log2] * steps_per_epoch)
+
+ style_gan.compile(
+ d_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
+ g_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
+ loss_weights={"gradient_penalty": 10, "drift": 0.001},
+ steps_per_epoch=steps,
+ res=res,
+ phase=phase,
+ run_eagerly=False,
+ )
+
+ prefix = f"res_{res}x{res}_{style_gan.phase}"
+
+ ckpt_cb = keras.callbacks.ModelCheckpoint(
+ f"checkpoints/stylegan_{res}x{res}.ckpt",
+ save_weights_only=True,
+ verbose=0,
+ )
+ print(phase)
+ style_gan.fit(
+ train_dl, epochs=1, steps_per_epoch=steps, callbacks=[ckpt_cb]
+ )
+
+ if display_images:
+ images = style_gan({"z": val_z, "noise": val_noise, "alpha": 1.0})
+ plot_images(images, res_log2)
+
+
+"""
+StyleGAN can take a long time to train. In the code below, a small `steps_per_epoch`
+value of 1 is used to sanity-check that the code is working. In practice, a larger
+`steps_per_epoch` value (over 10000)
+is required to get decent results.
+"""
+
+train(start_res=4, target_res=16, steps_per_epoch=1, display_images=False)
+
+"""
+## Results
+
+We can now run some inference using pre-trained 128x128 checkpoints. In general, the image
+fidelity increases with the resolution. You can try to train this StyleGAN to resolutions
+above 128x128 with the CelebA HQ dataset.
+"""
+
+url = "https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip"
+
+weights_path = keras.utils.get_file(
+ "stylegan_128x128.ckpt.zip",
+ url,
+ extract=True,
+ cache_dir=os.path.abspath("."),
+ cache_subdir="pretrained",
+)
+
+style_gan.grow_model(128)
+style_gan.load_weights(os.path.join("pretrained/stylegan_128x128.ckpt"))
+
+tf.random.set_seed(196)
+batch_size = 2
+z = tf.random.normal((batch_size, style_gan.z_dim))
+w = style_gan.mapping(z)
+noise = style_gan.generate_noise(batch_size=batch_size)
+images = style_gan({"style_code": w, "noise": noise, "alpha": 1.0})
+plot_images(images, 5)
+
+"""
+## Style Mixing
+
+We can also mix styles from two images to create a new image.
+"""
+
+alpha = 0.4
+w_mix = np.expand_dims(alpha * w[0] + (1 - alpha) * w[1], 0)
+noise_a = [np.expand_dims(n[0], 0) for n in noise]
+mix_images = style_gan({"style_code": w_mix, "noise": noise_a})
+image_row = np.hstack([images[0], images[1], mix_images[0]])
+plt.figure(figsize=(9, 3))
+plt.imshow(image_row)
+plt.axis("off")
diff --git a/knowledge_base/generative/text_generation_fnet.py b/knowledge_base/generative/text_generation_fnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eca13c9efc8738771ebcdef6549f2adf52c7bb5
--- /dev/null
+++ b/knowledge_base/generative/text_generation_fnet.py
@@ -0,0 +1,387 @@
+"""
+Title: Text Generation using FNet
+Author: [Darshan Deshpande](https://twitter.com/getdarshan)
+Date created: 2021/10/05
+Last modified: 2021/10/05
+Description: FNet transformer for text generation in Keras.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The original transformer implementation (Vaswani et al., 2017) was one of the major
+breakthroughs in Natural Language Processing, giving rise to important architectures such as BERT and GPT.
+However, the drawback of these architectures is
+that the self-attention mechanism they use is computationally expensive. The FNet
+architecture proposes to replace this self-attention mechanism with a leaner one:
+a Fourier transformation-based linear mixer for input tokens.
+
+The FNet model was able to achieve 92-97% of BERT's accuracy while training 80% faster on
+GPUs and almost 70% faster on TPUs. This type of design provides an efficient and small
+model size, leading to faster inference times.
+
+In this example, we will implement and train this architecture on the Cornell Movie
+Dialog corpus to show the applicability of this model to text generation.
+"""
+
+"""
+## Imports
+"""
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import os
+
+# Defining hyperparameters
+
+VOCAB_SIZE = 8192
+MAX_SAMPLES = 50000
+BUFFER_SIZE = 20000
+MAX_LENGTH = 40
+EMBED_DIM = 256
+LATENT_DIM = 512
+NUM_HEADS = 8
+BATCH_SIZE = 64
+
+"""
+## Loading data
+
+We will be using the Cornell Movie-Dialogs Corpus. We will parse the movie conversations into
+sets of questions and answers.
+"""
+
+path_to_zip = keras.utils.get_file(
+ "cornell_movie_dialogs.zip",
+ origin="http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip",
+ extract=True,
+)
+
+path_to_dataset = os.path.join(
+ os.path.dirname(path_to_zip), "cornell movie-dialogs corpus"
+)
+path_to_movie_lines = os.path.join(path_to_dataset, "movie_lines.txt")
+path_to_movie_conversations = os.path.join(path_to_dataset, "movie_conversations.txt")
+
+
+def load_conversations():
+ # Helper function for loading the conversation splits
+ id2line = {}
+ with open(path_to_movie_lines, errors="ignore") as file:
+ lines = file.readlines()
+ for line in lines:
+ parts = line.replace("\n", "").split(" +++$+++ ")
+ id2line[parts[0]] = parts[4]
+
+ inputs, outputs = [], []
+ with open(path_to_movie_conversations, "r") as file:
+ lines = file.readlines()
+ for line in lines:
+ parts = line.replace("\n", "").split(" +++$+++ ")
+ # get conversation in a list of line ID
+ conversation = [line[1:-1] for line in parts[3][1:-1].split(", ")]
+ for i in range(len(conversation) - 1):
+ inputs.append(id2line[conversation[i]])
+ outputs.append(id2line[conversation[i + 1]])
+ if len(inputs) >= MAX_SAMPLES:
+ return inputs, outputs
+ return inputs, outputs
+
+
+questions, answers = load_conversations()
+
+# Splitting training and validation sets
+
+train_dataset = tf.data.Dataset.from_tensor_slices((questions[:40000], answers[:40000]))
+val_dataset = tf.data.Dataset.from_tensor_slices((questions[40000:], answers[40000:]))
+
+"""
+### Preprocessing and Tokenization
+"""
+
+
+def preprocess_text(sentence):
+ sentence = tf.strings.lower(sentence)
+ # Adding a space between the punctuation and the last word to allow better tokenization
+ sentence = tf.strings.regex_replace(sentence, r"([?.!,])", r" \1 ")
+ # Replacing multiple continuous spaces with a single space
+ sentence = tf.strings.regex_replace(sentence, r"\s\s+", " ")
+ # Replacing non english words with spaces
+ sentence = tf.strings.regex_replace(sentence, r"[^a-z?.!,]+", " ")
+ sentence = tf.strings.strip(sentence)
+ sentence = tf.strings.join(["[start]", sentence, "[end]"], separator=" ")
+ return sentence
+
+
+vectorizer = layers.TextVectorization(
+ VOCAB_SIZE,
+ standardize=preprocess_text,
+ output_mode="int",
+ output_sequence_length=MAX_LENGTH,
+)
+
+# We will adapt the vectorizer to both the questions and answers
+# This dataset is batched to parallelize and speed up the process
+vectorizer.adapt(tf.data.Dataset.from_tensor_slices((questions + answers)).batch(128))
+
+"""
+### Tokenizing and padding sentences using `TextVectorization`
+"""
+
+
+def vectorize_text(inputs, outputs):
+ inputs, outputs = vectorizer(inputs), vectorizer(outputs)
+ # One extra padding token to the right to match the output shape
+ outputs = tf.pad(outputs, [[0, 1]])
+ return (
+ {"encoder_inputs": inputs, "decoder_inputs": outputs[:-1]},
+ {"outputs": outputs[1:]},
+ )
+
+
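+# Illustration: `outputs` has one extra pad token, so `outputs[:-1]` (the
+# decoder inputs) and `outputs[1:]` (the targets) both have length MAX_LENGTH;
+# at each position the decoder is trained to predict the following token
+# (teacher forcing).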
+train_dataset = train_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
+val_dataset = val_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
+
+train_dataset = (
+ train_dataset.cache()
+ .shuffle(BUFFER_SIZE)
+ .batch(BATCH_SIZE)
+ .prefetch(tf.data.AUTOTUNE)
+)
+val_dataset = val_dataset.cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
+
+"""
+## Creating the FNet Encoder
+
+The FNet paper proposes a replacement for the standard attention mechanism used by the
+Transformer architecture (Vaswani et al., 2017).
+
+
+
+The outputs of the FFT layer are complex numbers. To avoid dealing with complex layers,
+only the real part of the output is extracted.
+
+The dense layers that follow the Fourier transformation act as convolutions applied on
+the frequency domain.
+"""
+
+
+class FNetEncoder(layers.Layer):
+ def __init__(self, embed_dim, dense_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.dense_dim = dense_dim
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(dense_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+
+ def call(self, inputs):
+ # Casting the inputs to complex64
+ inp_complex = tf.cast(inputs, tf.complex64)
+ # Projecting the inputs to the frequency domain using FFT2D and
+ # extracting the real part of the output
+ fft = tf.math.real(tf.signal.fft2d(inp_complex))
+ proj_input = self.layernorm_1(inputs + fft)
+ proj_output = self.dense_proj(proj_input)
+ return self.layernorm_2(proj_input + proj_output)
+
+
+"""
+## Creating the Decoder
+
+The decoder architecture remains the same as the one proposed by (Vaswani et al., 2017)
+in the original transformer architecture, consisting of an embedding, positional
+encoding, two masked multi-head attention layers and finally the dense output layers.
+The architecture that follows is taken from
+[Deep Learning with Python, second edition, chapter 11](https://www.manning.com/books/deep-learning-with-python-second-edition).
+
+"""
+
+
+class PositionalEmbedding(layers.Layer):
+ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.token_embeddings = layers.Embedding(
+ input_dim=vocab_size, output_dim=embed_dim
+ )
+ self.position_embeddings = layers.Embedding(
+ input_dim=sequence_length, output_dim=embed_dim
+ )
+ self.sequence_length = sequence_length
+ self.vocab_size = vocab_size
+ self.embed_dim = embed_dim
+
+ def call(self, inputs):
+ length = tf.shape(inputs)[-1]
+ positions = tf.range(start=0, limit=length, delta=1)
+ embedded_tokens = self.token_embeddings(inputs)
+ embedded_positions = self.position_embeddings(positions)
+ return embedded_tokens + embedded_positions
+
+ def compute_mask(self, inputs, mask=None):
+ return tf.math.not_equal(inputs, 0)
+
+
+class FNetDecoder(layers.Layer):
+ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.latent_dim = latent_dim
+ self.num_heads = num_heads
+ self.attention_1 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.attention_2 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(latent_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.layernorm_3 = layers.LayerNormalization()
+ self.supports_masking = True
+
+ def call(self, inputs, encoder_outputs, mask=None):
+ causal_mask = self.get_causal_attention_mask(inputs)
+ if mask is not None:
+ padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
+ padding_mask = tf.minimum(padding_mask, causal_mask)
+
+ attention_output_1 = self.attention_1(
+ query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
+ )
+ out_1 = self.layernorm_1(inputs + attention_output_1)
+
+ attention_output_2 = self.attention_2(
+ query=out_1,
+ value=encoder_outputs,
+ key=encoder_outputs,
+ attention_mask=padding_mask,
+ )
+ out_2 = self.layernorm_2(out_1 + attention_output_2)
+
+ proj_output = self.dense_proj(out_2)
+ return self.layernorm_3(out_2 + proj_output)
+
+ def get_causal_attention_mask(self, inputs):
+ input_shape = tf.shape(inputs)
+ batch_size, sequence_length = input_shape[0], input_shape[1]
+ i = tf.range(sequence_length)[:, tf.newaxis]
+ j = tf.range(sequence_length)
+ mask = tf.cast(i >= j, dtype="int32")
+ mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+ mult = tf.concat(
+ [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
+ axis=0,
+ )
+ return tf.tile(mask, mult)
+
+
+def create_model():
+ encoder_inputs = keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
+ x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(encoder_inputs)
+ encoder_outputs = FNetEncoder(EMBED_DIM, LATENT_DIM)(x)
+ encoder = keras.Model(encoder_inputs, encoder_outputs)
+ decoder_inputs = keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
+ encoded_seq_inputs = keras.Input(
+ shape=(None, EMBED_DIM), name="decoder_state_inputs"
+ )
+ x = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(decoder_inputs)
+ x = FNetDecoder(EMBED_DIM, LATENT_DIM, NUM_HEADS)(x, encoded_seq_inputs)
+ x = layers.Dropout(0.5)(x)
+ decoder_outputs = layers.Dense(VOCAB_SIZE, activation="softmax")(x)
+ decoder = keras.Model(
+ [decoder_inputs, encoded_seq_inputs], decoder_outputs, name="outputs"
+ )
+ decoder_outputs = decoder([decoder_inputs, encoder_outputs])
+ fnet = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs, name="fnet")
+ return fnet
+
+
+"""
+## Creating and Training the model
+"""
+
+fnet = create_model()
+fnet.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
+
+"""
+Here, the `epochs` parameter is set to a single epoch, but in practice the model will take around
+**20-30 epochs** of training to start outputting comprehensible sentences. Although accuracy
+is not a good measure for this task, we will use it just to get a hint of how
+the network improves.
+"""
+
+fnet.fit(train_dataset, epochs=1, validation_data=val_dataset)
+
+"""
+## Performing inference
+"""
+
+VOCAB = vectorizer.get_vocabulary()
+
+
+def decode_sentence(input_sentence):
+ # Mapping the input sentence to tokens and adding start and end tokens
+ tokenized_input_sentence = vectorizer(
+ tf.constant("[start] " + preprocess_text(input_sentence) + " [end]")
+ )
+ # Initializing the initial sentence consisting of only the start token.
+ tokenized_target_sentence = tf.expand_dims(VOCAB.index("[start]"), 0)
+ decoded_sentence = ""
+
+ for i in range(MAX_LENGTH):
+ # Get the predictions
+ predictions = fnet.predict(
+ {
+ "encoder_inputs": tf.expand_dims(tokenized_input_sentence, 0),
+ "decoder_inputs": tf.expand_dims(
+ tf.pad(
+ tokenized_target_sentence,
+ [[0, MAX_LENGTH - tf.shape(tokenized_target_sentence)[0]]],
+ ),
+ 0,
+ ),
+ }
+ )
+ # Calculating the token with maximum probability and getting the corresponding word
+ sampled_token_index = tf.argmax(predictions[0, i, :])
+ sampled_token = VOCAB[sampled_token_index.numpy()]
+ # If sampled token is the end token then stop generating and return the sentence
+ if tf.equal(sampled_token_index, VOCAB.index("[end]")):
+ break
+ decoded_sentence += sampled_token + " "
+ tokenized_target_sentence = tf.concat(
+ [tokenized_target_sentence, [sampled_token_index]], 0
+ )
+
+ return decoded_sentence
+
+
+decode_sentence("Where have you been all this time?")
+
+"""
+## Conclusion
+
+This example shows how to train and perform inference using the FNet model.
+For getting insight into the architecture or for further reading, you can refer to:
+
+1. [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824v3)
+(Lee-Thorp et al., 2021)
+2. [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5) (Vaswani et al.,
+2017)
+
+Thanks to Franรงois Chollet for his Keras example on
+[English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
+from which the decoder implementation was extracted.
+"""
diff --git a/knowledge_base/generative/text_generation_gpt.py b/knowledge_base/generative/text_generation_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b63ab760c1d5319afd1e40142acb67e938be265
--- /dev/null
+++ b/knowledge_base/generative/text_generation_gpt.py
@@ -0,0 +1,422 @@
+"""
+Title: GPT text generation from scratch with KerasHub
+Author: [Jesse Chan](https://github.com/jessechancy)
+Date created: 2022/07/25
+Last modified: 2022/07/25
+Description: Using KerasHub to train a mini-GPT model for text generation.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we will use KerasHub to build a scaled down Generative
+Pre-Trained (GPT) model. GPT is a Transformer-based model that allows you to generate
+sophisticated text from a prompt.
+
+We will train the model on the [simplebooks-92](https://arxiv.org/abs/1911.12391) corpus,
+which is a dataset made from several novels. It is a good dataset for this example since
+it has a small vocabulary and high word frequency, which is beneficial when training a
+model with few parameters.
+
+This example combines concepts from
+[Text generation with a miniature GPT](https://keras.io/examples/generative/text_generation_with_miniature_gpt/)
+with KerasHub abstractions. We will demonstrate how KerasHub tokenization, layers and
+metrics simplify the training
+process, and then show how to generate output text using the KerasHub sampling utilities.
+
+Note: If you are running this example on a Colab,
+make sure to enable GPU runtime for faster training.
+
+This example requires KerasHub. You can install it via the following command:
+`pip install keras-hub`
+"""
+
+"""
+## Setup
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import os
+import keras_hub
+import keras
+
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
+
+"""
+## Settings & hyperparameters
+"""
+
+# Data
+BATCH_SIZE = 64
+MIN_STRING_LEN = 512 # Strings shorter than this will be discarded
+SEQ_LEN = 128 # Length of training sequences, in tokens
+
+# Model
+EMBED_DIM = 256
+FEED_FORWARD_DIM = 128
+NUM_HEADS = 3
+NUM_LAYERS = 2
+VOCAB_SIZE = 5000 # Limits parameters in model.
+
+# Training
+EPOCHS = 5
+
+# Inference
+NUM_TOKENS_TO_GENERATE = 80
+
+"""
+## Load the data
+
+Now, let's download the dataset! The SimpleBooks dataset consists of 1,573 Gutenberg books, and has
+one of the smallest ratios of vocabulary size to word-level tokens. It has a vocabulary size of ~98k,
+a third of WikiText-103's, with around the same number of tokens (~100M). This makes it easy to fit a small model.
+"""
+
+keras.utils.get_file(
+ origin="https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip",
+ extract=True,
+)
+data_dir = os.path.expanduser("~/.keras/datasets/simplebooks/")
+
+# Load simplebooks-92 train set and filter out short lines.
+raw_train_ds = (
+ tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
+ .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
+ .batch(BATCH_SIZE)
+ .shuffle(buffer_size=256)
+)
+
+# Load simplebooks-92 validation set and filter out short lines.
+raw_val_ds = (
+ tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
+ .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
+ .batch(BATCH_SIZE)
+)
+
+"""
+## Train the tokenizer
+
+We train the tokenizer from the training dataset for a vocabulary size of `VOCAB_SIZE`,
+which is a tuned hyperparameter. We want to limit the vocabulary as much as possible, as
+we will see later on
+that it has a large effect on the number of model parameters. We also don't want to include
+*too few* vocabulary terms, or there would be too many out-of-vocabulary (OOV) sub-words. In
+addition, three tokens are reserved in the vocabulary:
+
+- `"[PAD]"` for padding sequences to `SEQ_LEN`. This token has index 0 in both
+`reserved_tokens` and `vocab`, since `WordPieceTokenizer` (and other layers) consider
+`0`/`vocab[0]` as the default padding.
+- `"[UNK]"` for OOV sub-words, which should match the default `oov_token="[UNK]"` in
+`WordPieceTokenizer`.
+- `"[BOS]"` stands for beginning of sentence, but here technically it is a token
+representing the beginning of each line of training data.
+"""
+
+# Train tokenizer vocabulary
+vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
+ raw_train_ds,
+ vocabulary_size=VOCAB_SIZE,
+ lowercase=True,
+ reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],
+)
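+
+# Sanity check (illustrative): the reserved tokens occupy the first indices of
+# the vocabulary, so "[PAD]" maps to token id 0.
+print(vocab[:3]) # expected: ['[PAD]', '[UNK]', '[BOS]']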
+
+"""
+## Load tokenizer
+
+We use the vocabulary data to initialize
+`keras_hub.tokenizers.WordPieceTokenizer`. WordPieceTokenizer is an efficient
+implementation of the WordPiece algorithm used by BERT and other models. It will strip,
+lower-case and do other irreversible preprocessing operations.
+"""
+
+tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
+ vocabulary=vocab,
+ sequence_length=SEQ_LEN,
+ lowercase=True,
+)
+
+"""
+## Tokenize data
+
+We preprocess the dataset by tokenizing and splitting it into `features` and `labels`.
+"""
+
+# packer adds a start token
+start_packer = keras_hub.layers.StartEndPacker(
+ sequence_length=SEQ_LEN,
+ start_value=tokenizer.token_to_id("[BOS]"),
+)
+
+
+def preprocess(inputs):
+ outputs = tokenizer(inputs)
+ features = start_packer(outputs)
+ labels = outputs
+ return features, labels
+
+
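+# Illustration: `features` is the token sequence shifted right by one position
+# (the packer prepends "[BOS]"), while `labels` is the unshifted sequence, so
+# at each position the model is trained to predict the next token.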
+# Tokenize and split into train and label sequences.
+train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
+ tf_data.AUTOTUNE
+)
+val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
+ tf_data.AUTOTUNE
+)
+
+"""
+## Build the model
+
+We create our scaled down GPT model with the following layers:
+
+- One `keras_hub.layers.TokenAndPositionEmbedding` layer, which combines the embedding
+for the token and its position.
+- Multiple `keras_hub.layers.TransformerDecoder` layers, with the default causal masking.
+The layer has no cross-attention when run with decoder sequence only.
+- One final dense linear layer
+"""
+
+inputs = keras.layers.Input(shape=(None,), dtype="int32")
+# Embedding.
+embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
+ vocabulary_size=VOCAB_SIZE,
+ sequence_length=SEQ_LEN,
+ embedding_dim=EMBED_DIM,
+ mask_zero=True,
+)
+x = embedding_layer(inputs)
+# Transformer decoders.
+for _ in range(NUM_LAYERS):
+ decoder_layer = keras_hub.layers.TransformerDecoder(
+ num_heads=NUM_HEADS,
+ intermediate_dim=FEED_FORWARD_DIM,
+ )
+ x = decoder_layer(x) # Giving one argument only skips cross-attention.
+# Output.
+outputs = keras.layers.Dense(VOCAB_SIZE)(x)
+model = keras.Model(inputs=inputs, outputs=outputs)
+loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+perplexity = keras_hub.metrics.Perplexity(from_logits=True, mask_token_id=0)
+model.compile(optimizer="adam", loss=loss_fn, metrics=[perplexity])
+
+"""
+Let's take a look at our model summary - a large majority of the
+parameters are in the `token_and_position_embedding` and the output `dense` layer!
+This means that the vocabulary size (`VOCAB_SIZE`) has a large effect on the size of the model,
+while the number of Transformer decoder layers (`NUM_LAYERS`) doesn't affect it as much.
+"""
+
+model.summary()
+
+"""
+## Training
+
+Now that we have our model, let's train it with the `fit()` method.
+"""
+
+model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
+
+"""
+## Inference
+
+With our trained model, we can test it out to gauge its performance. To do this
+we can seed our model with an input sequence starting with the `"[BOS]"` token,
+and progressively sample the model by making predictions for each subsequent
+token in a loop.
+
+To start, let's build a prompt with the same shape as our model inputs, containing
+only the `"[BOS]"` token.
+"""
+
+# The "packer" layers adds the [BOS] token for us.
+prompt_tokens = start_packer(tokenizer([""]))
+prompt_tokens
+
+"""
+We will use the `keras_hub.samplers` module for inference, which requires a
+callback function wrapping the model we just trained. This wrapper calls
+the model and returns the logit predictions for the current token we are
+generating.
+
+Note: There are two pieces of more advanced functionality available when
+defining your callback. The first is the ability to take in a `cache` of states
+computed in previous generation steps, which can be used to speed up generation.
+The second is the ability to output the final dense "hidden state" of each
+generated token. This is used by `keras_hub.samplers.ContrastiveSampler`, which
+avoids repetition by penalizing repeated hidden states. Both are optional, and
+we will ignore them for now.
+"""
+
+
+def next(prompt, cache, index):
+ logits = model(prompt)[:, index - 1, :]
+ # Ignore hidden states for now; only needed for contrastive search.
+ hidden_states = None
+ return logits, hidden_states, cache
+
+
+"""
+Creating the wrapper function is the most complex part of using these functions. Now that
+it's done, let's test out the different utilities, starting with greedy search.
+"""
+
+"""
+### Greedy search
+
+We greedily pick the most probable token at each timestep. In other words, we get the
+argmax of the model output.
+"""
+
+sampler = keras_hub.samplers.GreedySampler()
+output_tokens = sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1, # Start sampling immediately after the [BOS] token.
+)
+txt = tokenizer.detokenize(output_tokens)
+print(f"Greedy search generated text: \n{txt}\n")
+
+"""
+As you can see, greedy search starts out making some sense, but quickly starts repeating
+itself. This is a common problem with text generation that can be fixed by some of the
+probabilistic text generation utilities shown later on!
+"""
+
+"""
+### Beam search
+
+At a high-level, beam search keeps track of the `num_beams` most probable sequences at
+each timestep, and predicts the best next token from all sequences. It is an improvement
+over greedy search since it stores more possibilities. However, it is less efficient than
+greedy search since it has to compute and store multiple potential sequences.
+
+**Note:** beam search with `num_beams=1` is identical to greedy search.
+"""
+
+sampler = keras_hub.samplers.BeamSampler(num_beams=10)
+output_tokens = sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1,
+)
+txt = tokenizer.detokenize(output_tokens)
+print(f"Beam search generated text: \n{txt}\n")
+
+"""
+Similar to greedy search, beam search quickly starts repeating itself, since it is still
+a deterministic method.
+"""
+
+"""
+### Random search
+
+Random search is our first probabilistic method. At each time step, it samples the next
+token using the softmax probabilities provided by the model.
+"""
+
+sampler = keras_hub.samplers.RandomSampler()
+output_tokens = sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1,
+)
+txt = tokenizer.detokenize(output_tokens)
+print(f"Random search generated text: \n{txt}\n")
+
+"""
+Voilร , no repetitions! However, with random search, we may see some nonsensical words
+appearing since any word in the vocabulary has a chance of appearing with this sampling
+method. This is fixed by our next search utility, top-k search.
+"""
+
+"""
+### Top-K search
+
+Similar to random search, we sample the next token from the probability distribution
+provided by the model. The only difference is that here, we select the top `k` most
+probable tokens and redistribute the probability mass over them before sampling. This way,
+we won't be sampling from low-probability tokens, and hence will get fewer
+nonsensical words!
+"""
+
+sampler = keras_hub.samplers.TopKSampler(k=10)
+output_tokens = sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1,
+)
+txt = tokenizer.detokenize(output_tokens)
+print(f"Top-K search generated text: \n{txt}\n")
+
+"""
+### Top-P search
+
+Even with the top-k search, there is something to improve upon. With top-k search, the
+number `k` is fixed, which means it selects the same number of tokens for any probability
+distribution. Consider two scenarios, one where the probability mass is concentrated over
+2 words and another where the probability mass is evenly spread across 10. Should
+we choose `k=2` or `k=10`? There is no one size that fits all `k` here.
+
+This is where top-p search comes in! Instead of choosing a `k`, we choose a probability
+`p` that we want the probabilities of the top tokens to sum up to. This way, we can
+dynamically adjust the `k` based on the probability distribution. By setting `p=0.9`, if
+90% of the probability mass is concentrated on the top 2 tokens, sampling is restricted
+to those 2 tokens. If instead the 90% is distributed over 10 tokens, sampling is
+similarly restricted to the top 10 tokens.
+"""
+
+sampler = keras_hub.samplers.TopPSampler(p=0.5)
+output_tokens = sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1,
+)
+txt = tokenizer.detokenize(output_tokens)
+print(f"Top-P search generated text: \n{txt}\n")
+
+"""
+### Using callbacks for text generation
+
+We can also wrap the utilities in a callback, which allows you to print out a prediction
+sequence for every epoch of the model! Here is an example of a callback for top-k search:
+"""
+
+
+class TopKTextGenerator(keras.callbacks.Callback):
+ """A callback to generate text from a trained model using top-k."""
+
+ def __init__(self, k):
+ super().__init__()
+ self.sampler = keras_hub.samplers.TopKSampler(k)
+
+ def on_epoch_end(self, epoch, logs=None):
+ output_tokens = self.sampler(
+ next=next,
+ prompt=prompt_tokens,
+ index=1,
+ )
+ txt = tokenizer.detokenize(output_tokens)
+ print(f"Top-K search generated text: \n{txt}\n")
+
+
+text_generation_callback = TopKTextGenerator(k=10)
+# Dummy training loop to demonstrate callback.
+model.fit(train_ds.take(1), verbose=2, epochs=2, callbacks=[text_generation_callback])
+
+"""
+## Conclusion
+
+To recap, in this example, we use KerasHub layers to train a sub-word vocabulary,
+tokenize training data, create a miniature GPT model, and perform inference with the
+text generation library.
+
+If you would like to understand how Transformers work, or learn more about training the
+full GPT model, here are some further readings:
+
+- Attention Is All You Need [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
+- GPT-3 Paper [Brown et al., 2020](https://arxiv.org/abs/2005.14165)
+"""
diff --git a/knowledge_base/generative/text_generation_with_miniature_gpt.py b/knowledge_base/generative/text_generation_with_miniature_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca34ea03d412f54f365aab8875442a1b41f8fc6e
--- /dev/null
+++ b/knowledge_base/generative/text_generation_with_miniature_gpt.py
@@ -0,0 +1,316 @@
+"""
+Title: Text generation with a miniature GPT
+Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
+Date created: 2020/05/29
+Last modified: 2020/05/29
+Description: Implement a miniature version of GPT and train it to generate text.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to implement an autoregressive language model
+using a miniature version of the GPT model.
+The model consists of a single Transformer block with causal masking
+in its attention layer.
+We use the text from the IMDB sentiment classification dataset for training
+and generate new movie reviews for a given prompt.
+When using this script with your own dataset, make sure it has at least
+1 million words.
+
+This example should be run with `tf-nightly>=2.3.0-dev20200531` or
+with TensorFlow 2.3 or higher.
+
+**References:**
+
+- [GPT](https://www.semanticscholar.org/paper/Improving-Language-Understanding-by-Generative-Radford/cd18800a0fe0b668a1cc19f2ec95b5003d0a5035)
+- [GPT-2](https://www.semanticscholar.org/paper/Language-Models-are-Unsupervised-Multitask-Learners-Radford-Wu/9405cc0d6169988371b2755e573cc28650d14dfe)
+- [GPT-3](https://arxiv.org/abs/2005.14165)
+"""
+"""
+## Setup
+"""
+# We set the backend to TensorFlow. The code works with
+# both `tensorflow` and `torch`. It does not work with JAX
+# due to the behavior of `jax.numpy.tile` in a jit scope
+# (used in `causal_attention_mask()`): `tile` in JAX does
+# not support a dynamic `reps` argument.
+# You can make the code work in JAX by wrapping the
+# inside of the `causal_attention_mask` function in
+# a decorator to prevent jit compilation:
+# `with jax.ensure_compile_time_eval():`.
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import ops
+from keras.layers import TextVectorization
+import numpy as np
+import string
+import random
+import tensorflow
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
+
+
+"""
+## Implement a Transformer block as a layer
+"""
+
+
+def causal_attention_mask(batch_size, n_dest, n_src, dtype):
+ """
+ Mask the upper half of the dot product matrix in self attention.
+ This prevents flow of information from future tokens to current token.
+ 1's in the lower triangle, counting from the lower right corner.
+ """
+ i = ops.arange(n_dest)[:, None]
+ j = ops.arange(n_src)
+ m = i >= j - n_src + n_dest
+ mask = ops.cast(m, dtype)
+ mask = ops.reshape(mask, [1, n_dest, n_src])
+ mult = ops.concatenate(
+ [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
+ )
+ return ops.tile(mask, mult)
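+
+
+# For example, causal_attention_mask(1, 3, 3, "bool")[0] is
+# [[1, 0, 0],
+#  [1, 1, 0],
+#  [1, 1, 1]]
+# i.e. each position attends only to itself and to earlier positions.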
+
+
+class TransformerBlock(layers.Layer):
+ def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
+ super().__init__()
+ self.att = layers.MultiHeadAttention(num_heads, embed_dim)
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(ff_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.dropout1 = layers.Dropout(rate)
+ self.dropout2 = layers.Dropout(rate)
+
+ def call(self, inputs):
+ input_shape = ops.shape(inputs)
+ batch_size = input_shape[0]
+ seq_len = input_shape[1]
+ causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
+ attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
+ attention_output = self.dropout1(attention_output)
+ out1 = self.layernorm1(inputs + attention_output)
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output)
+ return self.layernorm2(out1 + ffn_output)
+
+
+"""
+## Implement an embedding layer
+
+Create two separate embedding layers: one for tokens and one for token index
+(positions).
+"""
+
+
+class TokenAndPositionEmbedding(layers.Layer):
+ def __init__(self, maxlen, vocab_size, embed_dim):
+ super().__init__()
+ self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
+ self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
+
+ def call(self, x):
+ maxlen = ops.shape(x)[-1]
+ positions = ops.arange(0, maxlen, 1)
+ positions = self.pos_emb(positions)
+ x = self.token_emb(x)
+ return x + positions
+
+
+"""
+## Implement the miniature GPT model
+"""
+vocab_size = 20000 # Only consider the top 20k words
+maxlen = 80 # Max sequence size
+embed_dim = 256 # Embedding size for each token
+num_heads = 2 # Number of attention heads
+feed_forward_dim = 256 # Hidden layer size in feed forward network inside transformer
+
+
+def create_model():
+ inputs = layers.Input(shape=(maxlen,), dtype="int32")
+ embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
+ x = embedding_layer(inputs)
+ transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
+ x = transformer_block(x)
+ outputs = layers.Dense(vocab_size)(x)
+ model = keras.Model(inputs=inputs, outputs=[outputs, x])
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ model.compile(
+ "adam",
+ loss=[loss_fn, None],
+ ) # No loss is applied to the word-embedding output of the Transformer block
+ return model
+
+
+"""
+## Prepare the data for word-level language modelling
+
+Download the IMDB dataset and combine training and validation sets for a text
+generation task.
+"""
+
+"""shell
+curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+tar -xf aclImdb_v1.tar.gz
+"""
+
+
+batch_size = 128
+
+# The dataset contains each review in a separate text file
+# The text files are present in four different folders
+# Create a list of all files
+filenames = []
+directories = [
+ "aclImdb/train/pos",
+ "aclImdb/train/neg",
+ "aclImdb/test/pos",
+ "aclImdb/test/neg",
+]
+for directory in directories:
+ for f in os.listdir(directory):
+ filenames.append(os.path.join(directory, f))
+
+print(f"{len(filenames)} files")
+
+# Create a dataset from text files
+random.shuffle(filenames)
+text_ds = tf_data.TextLineDataset(filenames)
+text_ds = text_ds.shuffle(buffer_size=256)
+text_ds = text_ds.batch(batch_size)
+
+
+def custom_standardization(input_string):
+ """Remove html line-break tags and handle punctuation"""
+ lowercased = tf_strings.lower(input_string)
+ stripped_html = tf_strings.regex_replace(lowercased, "
", " ")
+ return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
+
+
+# Create a vectorization layer and adapt it to the text
+vectorize_layer = TextVectorization(
+ standardize=custom_standardization,
+ max_tokens=vocab_size - 1,
+ output_mode="int",
+ output_sequence_length=maxlen + 1,
+)
+vectorize_layer.adapt(text_ds)
+vocab = vectorize_layer.get_vocabulary() # To get words back from token indices
+
+
+def prepare_lm_inputs_labels(text):
+ """
+ Shift word sequences by 1 position so that the target for position (i) is
+ word at position (i+1). The model will use all words up till position (i)
+ to predict the next word.
+ """
+ text = tensorflow.expand_dims(text, -1)
+ tokenized_sentences = vectorize_layer(text)
+ x = tokenized_sentences[:, :-1]
+ y = tokenized_sentences[:, 1:]
+ return x, y
+
+
+text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
+text_ds = text_ds.prefetch(tf_data.AUTOTUNE)
+
+
+"""
+## Implement a Keras callback for generating text
+"""
+
+
+class TextGenerator(keras.callbacks.Callback):
+ """A callback to generate text from a trained model.
+ 1. Feed some starting prompt to the model
+ 2. Predict probabilities for the next token
+ 3. Sample the next token and add it to the next input
+
+ Arguments:
+ max_tokens: Integer, the number of tokens to be generated after prompt.
+ start_tokens: List of integers, the token indices for the starting prompt.
+ index_to_word: List of strings, obtained from the TextVectorization layer.
+ top_k: Integer, sample from the `top_k` token predictions.
+ print_every: Integer, print after this many epochs.
+ """
+
+ def __init__(
+ self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1
+ ):
+ self.max_tokens = max_tokens
+ self.start_tokens = start_tokens
+ self.index_to_word = index_to_word
+ self.print_every = print_every
+ self.k = top_k
+
+ def sample_from(self, logits):
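+        # Keep only the `top_k` largest logits, renormalize them with a softmax,
+        # and sample one token index from the resulting distribution.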
+ logits, indices = ops.top_k(logits, k=self.k, sorted=True)
+ indices = np.asarray(indices).astype("int32")
+ preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
+ preds = np.asarray(preds).astype("float32")
+ return np.random.choice(indices, p=preds)
+
+ def detokenize(self, number):
+ return self.index_to_word[number]
+
+ def on_epoch_end(self, epoch, logs=None):
+ start_tokens = [_ for _ in self.start_tokens]
+ if (epoch + 1) % self.print_every != 0:
+ return
+ num_tokens_generated = 0
+ tokens_generated = []
+ while num_tokens_generated <= self.max_tokens:
+ pad_len = maxlen - len(start_tokens)
+ sample_index = len(start_tokens) - 1
+ if pad_len < 0:
+ x = start_tokens[:maxlen]
+ sample_index = maxlen - 1
+ elif pad_len > 0:
+ x = start_tokens + [0] * pad_len
+ else:
+ x = start_tokens
+ x = np.array([x])
+ y, _ = self.model.predict(x, verbose=0)
+ sample_token = self.sample_from(y[0][sample_index])
+ tokens_generated.append(sample_token)
+ start_tokens.append(sample_token)
+ num_tokens_generated = len(tokens_generated)
+ txt = " ".join(
+ [self.detokenize(_) for _ in self.start_tokens + tokens_generated]
+ )
+ print(f"generated text:\n{txt}\n")
+
+
+# Tokenize starting prompt
+word_to_index = {}
+for index, word in enumerate(vocab):
+ word_to_index[word] = index
+
+start_prompt = "this movie is"
+start_tokens = [word_to_index.get(_, 1) for _ in start_prompt.split()]
+num_tokens_generated = 40
+text_gen_callback = TextGenerator(num_tokens_generated, start_tokens, vocab)
+
+
+"""
+## Train the model
+
+Note: This code should preferably be run on GPU.
+"""
+
+model = create_model()
+
+model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
diff --git a/knowledge_base/generative/vae.py b/knowledge_base/generative/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..9068026f5568a712ffacc404003c8d3570262483
--- /dev/null
+++ b/knowledge_base/generative/vae.py
@@ -0,0 +1,197 @@
+"""
+Title: Variational AutoEncoder
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/05/03
+Last modified: 2024/04/24
+Description: Convolutional Variational AutoEncoder (VAE) trained on MNIST digits.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import tensorflow as tf
+import keras
+from keras import ops
+from keras import layers
+
+"""
+## Create a sampling layer
+"""
+
+
+class Sampling(layers.Layer):
+ """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def call(self, inputs):
+ z_mean, z_log_var = inputs
+ batch = ops.shape(z_mean)[0]
+ dim = ops.shape(z_mean)[1]
+ epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
+ return z_mean + ops.exp(0.5 * z_log_var) * epsilon
+
+
+"""
+## Build the encoder
+"""
+
+latent_dim = 2
+
+encoder_inputs = keras.Input(shape=(28, 28, 1))
+x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
+x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
+x = layers.Flatten()(x)
+x = layers.Dense(16, activation="relu")(x)
+z_mean = layers.Dense(latent_dim, name="z_mean")(x)
+z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
+z = Sampling()([z_mean, z_log_var])
+encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
+encoder.summary()
+
+"""
+## Build the decoder
+"""
+
+latent_inputs = keras.Input(shape=(latent_dim,))
+x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
+x = layers.Reshape((7, 7, 64))(x)
+x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
+x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
+decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
+decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
+decoder.summary()
+
+"""
+## Define the VAE as a `Model` with a custom `train_step`
+"""
+
+
+class VAE(keras.Model):
+ def __init__(self, encoder, decoder, **kwargs):
+ super().__init__(**kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+ self.reconstruction_loss_tracker = keras.metrics.Mean(
+ name="reconstruction_loss"
+ )
+ self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
+
+ @property
+ def metrics(self):
+ return [
+ self.total_loss_tracker,
+ self.reconstruction_loss_tracker,
+ self.kl_loss_tracker,
+ ]
+
+ def train_step(self, data):
+ with tf.GradientTape() as tape:
+ z_mean, z_log_var, z = self.encoder(data)
+ reconstruction = self.decoder(z)
+ reconstruction_loss = ops.mean(
+ ops.sum(
+ keras.losses.binary_crossentropy(data, reconstruction),
+ axis=(1, 2),
+ )
+ )
+ kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
+ kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
+ total_loss = reconstruction_loss + kl_loss
+ grads = tape.gradient(total_loss, self.trainable_weights)
+ self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
+ self.total_loss_tracker.update_state(total_loss)
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+ self.kl_loss_tracker.update_state(kl_loss)
+ return {
+ "loss": self.total_loss_tracker.result(),
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
+ "kl_loss": self.kl_loss_tracker.result(),
+ }
+
+
+"""
+## Train the VAE
+"""
+
+(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
+mnist_digits = np.concatenate([x_train, x_test], axis=0)
+mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
+
+vae = VAE(encoder, decoder)
+vae.compile(optimizer=keras.optimizers.Adam())
+vae.fit(mnist_digits, epochs=30, batch_size=128)
+
+"""
+## Display a grid of sampled digits
+"""
+
+import matplotlib.pyplot as plt
+
+
+def plot_latent_space(vae, n=30, figsize=15):
+ # display a n*n 2D manifold of digits
+ digit_size = 28
+ scale = 1.0
+ figure = np.zeros((digit_size * n, digit_size * n))
+ # linearly spaced coordinates corresponding to the 2D plot
+ # of digit classes in the latent space
+ grid_x = np.linspace(-scale, scale, n)
+ grid_y = np.linspace(-scale, scale, n)[::-1]
+
+ for i, yi in enumerate(grid_y):
+ for j, xi in enumerate(grid_x):
+ z_sample = np.array([[xi, yi]])
+ x_decoded = vae.decoder.predict(z_sample, verbose=0)
+ digit = x_decoded[0].reshape(digit_size, digit_size)
+ figure[
+ i * digit_size : (i + 1) * digit_size,
+ j * digit_size : (j + 1) * digit_size,
+ ] = digit
+
+ plt.figure(figsize=(figsize, figsize))
+ start_range = digit_size // 2
+ end_range = n * digit_size + start_range
+ pixel_range = np.arange(start_range, end_range, digit_size)
+ sample_range_x = np.round(grid_x, 1)
+ sample_range_y = np.round(grid_y, 1)
+ plt.xticks(pixel_range, sample_range_x)
+ plt.yticks(pixel_range, sample_range_y)
+ plt.xlabel("z[0]")
+ plt.ylabel("z[1]")
+ plt.imshow(figure, cmap="Greys_r")
+ plt.show()
+
+
+plot_latent_space(vae)
+
+"""
+## Display how the latent space clusters different digit classes
+"""
+
+
+def plot_label_clusters(vae, data, labels):
+ # display a 2D plot of the digit classes in the latent space
+ z_mean, _, _ = vae.encoder.predict(data, verbose=0)
+ plt.figure(figsize=(12, 10))
+ plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
+ plt.colorbar()
+ plt.xlabel("z[0]")
+ plt.ylabel("z[1]")
+ plt.show()
+
+
+(x_train, y_train), _ = keras.datasets.mnist.load_data()
+x_train = np.expand_dims(x_train, -1).astype("float32") / 255
+
+plot_label_clusters(vae, x_train, y_train)
diff --git a/knowledge_base/generative/vq_vae.py b/knowledge_base/generative/vq_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..cccd4037c99466e2f80cb6f1809adf8f1e0249a5
--- /dev/null
+++ b/knowledge_base/generative/vq_vae.py
@@ -0,0 +1,596 @@
+"""
+Title: Vector-Quantized Variational Autoencoders
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/07/21
+Last modified: 2022/06/27
+Description: Training a VQ-VAE for image reconstruction and codebook sampling for generation.
+Accelerator: GPU
+"""
+
+"""
+In this example, we develop a Vector Quantized Variational Autoencoder (VQ-VAE).
+VQ-VAE was proposed in
+[Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)
+by van den Oord et al. In standard VAEs, the latent space is continuous and is sampled
+from a Gaussian distribution. It is generally harder to learn such a continuous
+distribution via gradient descent. VQ-VAEs, on the other hand,
+operate on a discrete latent space, making the optimization problem simpler. They do so
+by maintaining a discrete *codebook*. Encoder outputs are quantized by snapping them to
+their nearest entries (code words) in this codebook, which is learned jointly with the
+rest of the model. These discrete code words are then fed to the decoder, which is trained
+to generate reconstructed samples.
+
+For an overview of VQ-VAEs, please refer to the original paper and
+[this video explanation](https://www.youtube.com/watch?v=VZFVUrYcig0).
+If you need a refresher on VAEs, you can refer to
+[this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/).
+VQ-VAEs are one of the main recipes behind [DALL-E](https://openai.com/blog/dall-e/)
+and the idea of a codebook is used in [VQ-GANs](https://arxiv.org/abs/2012.09841).
+This example uses implementation details from the
+[official VQ-VAE tutorial](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb)
+from DeepMind.
+
+## Requirements
+
+To run this example, you will need TensorFlow 2.5 or higher, as well as
+TensorFlow Probability, which can be installed using the command below.
+"""
+
+"""shell
+pip install -q tensorflow-probability
+"""
+
+"""
+## Imports
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from tensorflow import keras
+from tensorflow.keras import layers
+import tensorflow_probability as tfp
+import tensorflow as tf
+
+"""
+## `VectorQuantizer` layer
+
+First, we implement a custom layer for the vector quantizer, which is the layer in between
+the encoder and decoder. Consider an output from the encoder, with shape `(batch_size, height, width,
+num_filters)`. The vector quantizer will first flatten this output, only keeping the
+`num_filters` dimension intact. So, the shape would become `(batch_size * height * width,
+num_filters)`. The rationale behind this is to treat the total number of filters as the size for
+the latent embeddings.
+
+An embedding table is then initialized to learn a codebook. We measure the L2-normalized
+distance between the flattened encoder outputs and code words of this codebook. We take the
+code that yields the minimum distance, and we apply one-hot encoding to achieve quantization.
+This way, the code yielding the minimum distance to the corresponding encoder output is
+mapped as one and the remaining codes are mapped as zeros.
+
+Since the quantization process is not differentiable, we apply a
+[straight-through estimator](https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html)
+in between the decoder and the encoder, so that the decoder gradients are directly propagated
+to the encoder. As the encoder and decoder share the same channel space, the decoder gradients are
+still meaningful to the encoder.
+"""
+
+
+class VectorQuantizer(layers.Layer):
+ def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
+ super().__init__(**kwargs)
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+
+ # The `beta` parameter is best kept between [0.25, 2] as per the paper.
+ self.beta = beta
+
+ # Initialize the embeddings which we will quantize.
+ w_init = tf.random_uniform_initializer()
+ self.embeddings = tf.Variable(
+ initial_value=w_init(
+ shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
+ ),
+ trainable=True,
+ name="embeddings_vqvae",
+ )
+
+ def call(self, x):
+ # Calculate the input shape of the inputs and
+ # then flatten the inputs keeping `embedding_dim` intact.
+ input_shape = tf.shape(x)
+ flattened = tf.reshape(x, [-1, self.embedding_dim])
+
+ # Quantization.
+ encoding_indices = self.get_code_indices(flattened)
+ encodings = tf.one_hot(encoding_indices, self.num_embeddings)
+ quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)
+
+ # Reshape the quantized values back to the original input shape
+ quantized = tf.reshape(quantized, input_shape)
+
+ # Calculate vector quantization loss and add that to the layer. You can learn more
+ # about adding losses to different layers here:
+ # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
+ # the original paper to get a handle on the formulation of the loss function.
+ commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
+ codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
+ self.add_loss(self.beta * commitment_loss + codebook_loss)
+
+ # Straight-through estimator.
+ quantized = x + tf.stop_gradient(quantized - x)
+ return quantized
+
+ def get_code_indices(self, flattened_inputs):
+ # Calculate L2-normalized distance between the inputs and the codes.
+ similarity = tf.matmul(flattened_inputs, self.embeddings)
+ distances = (
+ tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
+ + tf.reduce_sum(self.embeddings**2, axis=0)
+ - 2 * similarity
+ )
+
+ # Derive the indices for minimum distances.
+ encoding_indices = tf.argmin(distances, axis=1)
+ return encoding_indices
+
+
+"""
+**A note on straight-through estimation**:
+
+This line of code does the straight-through estimation part: `quantized = x +
+tf.stop_gradient(quantized - x)`. During backpropagation, `(quantized - x)` won't be
+included in the computation graph and the gradients obtained for `quantized`
+will be copied for `inputs`. Thanks to [this video](https://youtu.be/VZFVUrYcig0?t=1393)
+for helping me understand this technique.
+"""
+
+"""
+## Encoder and decoder
+
+Now for the encoder and the decoder for the VQ-VAE. We will keep them small so
+that their capacity is a good fit for the MNIST dataset. The implementation of the encoder and
+decoder come from
+[this example](https://keras.io/examples/generative/vae).
+
+Note that activations _other than ReLU_ may not work for the encoder and decoder layers in the
+quantization architecture: Leaky ReLU activated layers, for example, have proven difficult to
+train, resulting in intermittent loss spikes that the model has trouble recovering from.
+"""
+
+
+def get_encoder(latent_dim=16):
+ encoder_inputs = keras.Input(shape=(28, 28, 1))
+ x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(
+ encoder_inputs
+ )
+ x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
+ encoder_outputs = layers.Conv2D(latent_dim, 1, padding="same")(x)
+ return keras.Model(encoder_inputs, encoder_outputs, name="encoder")
+
+
+def get_decoder(latent_dim=16):
+ latent_inputs = keras.Input(shape=get_encoder(latent_dim).output.shape[1:])
+ x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
+ latent_inputs
+ )
+ x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
+ decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
+ return keras.Model(latent_inputs, decoder_outputs, name="decoder")
+
+
+"""
+## Standalone VQ-VAE model
+"""
+
+
+def get_vqvae(latent_dim=16, num_embeddings=64):
+ vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
+ encoder = get_encoder(latent_dim)
+ decoder = get_decoder(latent_dim)
+ inputs = keras.Input(shape=(28, 28, 1))
+ encoder_outputs = encoder(inputs)
+ quantized_latents = vq_layer(encoder_outputs)
+ reconstructions = decoder(quantized_latents)
+ return keras.Model(inputs, reconstructions, name="vq_vae")
+
+
+get_vqvae().summary()
+
+"""
+Note that the output channels of the encoder should match the `latent_dim` for the vector
+quantizer.
+"""
+
+"""
+## Wrapping up the training loop inside `VQVAETrainer`
+"""
+
+
+class VQVAETrainer(keras.models.Model):
+ def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs):
+ super().__init__(**kwargs)
+ self.train_variance = train_variance
+ self.latent_dim = latent_dim
+ self.num_embeddings = num_embeddings
+
+ self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)
+
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+ self.reconstruction_loss_tracker = keras.metrics.Mean(
+ name="reconstruction_loss"
+ )
+ self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
+
+ @property
+ def metrics(self):
+ return [
+ self.total_loss_tracker,
+ self.reconstruction_loss_tracker,
+ self.vq_loss_tracker,
+ ]
+
+ def train_step(self, x):
+ with tf.GradientTape() as tape:
+ # Outputs from the VQ-VAE.
+ reconstructions = self.vqvae(x)
+
+ # Calculate the losses.
+ reconstruction_loss = (
+ tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
+ )
+ total_loss = reconstruction_loss + sum(self.vqvae.losses)
+
+ # Backpropagation.
+ grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
+ self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))
+
+ # Loss tracking.
+ self.total_loss_tracker.update_state(total_loss)
+ self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+ self.vq_loss_tracker.update_state(sum(self.vqvae.losses))
+
+ # Log results.
+ return {
+ "loss": self.total_loss_tracker.result(),
+ "reconstruction_loss": self.reconstruction_loss_tracker.result(),
+ "vqvae_loss": self.vq_loss_tracker.result(),
+ }
+
+
+"""
+## Load and preprocess the MNIST dataset
+"""
+
+(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
+
+x_train = np.expand_dims(x_train, -1)
+x_test = np.expand_dims(x_test, -1)
+x_train_scaled = (x_train / 255.0) - 0.5
+x_test_scaled = (x_test / 255.0) - 0.5
+
+data_variance = np.var(x_train / 255.0)
+
+"""
+## Train the VQ-VAE model
+"""
+
+vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=128)
+vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
+vqvae_trainer.fit(x_train_scaled, epochs=30, batch_size=128)
+
+"""
+## Reconstruction results on the test set
+"""
+
+
+def show_subplot(original, reconstructed):
+ plt.subplot(1, 2, 1)
+ plt.imshow(original.squeeze() + 0.5)
+ plt.title("Original")
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(reconstructed.squeeze() + 0.5)
+ plt.title("Reconstructed")
+ plt.axis("off")
+
+ plt.show()
+
+
+trained_vqvae_model = vqvae_trainer.vqvae
+idx = np.random.choice(len(x_test_scaled), 10)
+test_images = x_test_scaled[idx]
+reconstructions_test = trained_vqvae_model.predict(test_images)
+
+for test_image, reconstructed_image in zip(test_images, reconstructions_test):
+ show_subplot(test_image, reconstructed_image)
+
+"""
+These results look decent. You are encouraged to play with different hyperparameters
+(especially the number of embeddings and the dimensions of the embeddings) and observe how
+they affect the results.
+"""
+
+"""
+## Visualizing the discrete codes
+"""
+
+encoder = vqvae_trainer.vqvae.get_layer("encoder")
+quantizer = vqvae_trainer.vqvae.get_layer("vector_quantizer")
+
+encoded_outputs = encoder.predict(test_images)
+flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
+codebook_indices = quantizer.get_code_indices(flat_enc_outputs)
+codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])
+
+for i in range(len(test_images)):
+ plt.subplot(1, 2, 1)
+ plt.imshow(test_images[i].squeeze() + 0.5)
+ plt.title("Original")
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(codebook_indices[i])
+ plt.title("Code")
+ plt.axis("off")
+ plt.show()
+
+"""
+The figure above shows that the discrete codes have been able to capture some
+regularities from the dataset. Now, how do we sample from this codebook to create
+novel images? Since these codes are discrete and we imposed a categorical distribution
+on them, we cannot use them yet to generate anything meaningful until we can generate likely
+sequences of codes that we can give to the decoder.
+
+The authors train a PixelCNN on these codes so that it can be used as a powerful prior to
+generate novel examples. PixelCNN was proposed in
+[Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)
+by van den Oord et al. We borrow the implementation from
+[this PixelCNN example](https://keras.io/examples/generative/pixelcnn/). It's an autoregressive
+generative model where the outputs are conditional on the prior ones. In other words, a PixelCNN
+generates an image on a pixel-by-pixel basis. In this example, however, its task
+is to generate code book indices instead of pixels directly. The trained VQ-VAE decoder is used
+to map the indices generated by the PixelCNN back into the pixel space.
+"""
+
+"""
+## PixelCNN hyperparameters
+"""
+
+num_residual_blocks = 2
+num_pixelcnn_layers = 2
+pixelcnn_input_shape = encoded_outputs.shape[1:-1]
+print(f"Input shape of the PixelCNN: {pixelcnn_input_shape}")
+
+"""
+This input shape reflects the reduction in resolution performed by the encoder. With "same"
+padding, each stride-2 convolution layer exactly halves the "resolution" of its output. So, with
+these two layers, we end up with an encoder output tensor of 7x7 along the spatial axes, with the
+first axis being the batch size and the last axis being the code book embedding size. Since the
+quantization layer in the autoencoder maps these 7x7 tensors to code book indices, the PixelCNN
+must take these same spatial sizes as its input shape. The task of the PixelCNN for this
+architecture is to generate _likely_ 7x7 arrangements of codebook indices.
+
+Note that this shape is something to optimize for in larger-sized image domains, along with the
+code book size. Since the PixelCNN is autoregressive, it needs to pass over each codebook index
+sequentially in order to generate novel images from the codebook. Each stride-2 (more precisely,
+stride (2, 2)) convolution layer divides the image generation time by four. Note, however, that
+there is probably a lower bound here: when the number of codes per image is too small, the decoder
+has insufficient information to represent the level of detail in the image, so the output quality
+will suffer. This can be amended, at least to some extent, by using a larger code book. Since the
+autoregressive part of image generation operates on codebook indices, a larger code book carries
+far less of a performance penalty than a longer sequence of indices: looking up a code in a larger
+code book is much cheaper than iterating over a longer sequence of code book indices, although the
+code book size does impact the batch size that can pass through the image generation procedure.
+Finding the sweet spot for this trade-off can require some architecture tweaking and could very
+well differ per dataset.
+"""
+
+"""
+## PixelCNN model
+
+The majority of this code comes from
+[this example](https://keras.io/examples/generative/pixelcnn/).
+
+## Notes
+
+Thanks to [Rein van 't Veer](https://github.com/reinvantveer) for improving this example with
+copy-edits and minor code clean-ups.
+"""
+
+
+# The first layer is the PixelCNN layer. This layer simply
+# builds on the 2D convolutional layer, but includes masking.
+class PixelConvLayer(layers.Layer):
+ def __init__(self, mask_type, **kwargs):
+ super().__init__()
+ self.mask_type = mask_type
+ self.conv = layers.Conv2D(**kwargs)
+
+ def build(self, input_shape):
+ # Build the conv2d layer to initialize kernel variables
+ self.conv.build(input_shape)
+ # Use the initialized kernel to create the mask
+ kernel_shape = self.conv.kernel.get_shape()
+ self.mask = np.zeros(shape=kernel_shape)
+ self.mask[: kernel_shape[0] // 2, ...] = 1.0
+ self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
+ if self.mask_type == "B":
+ self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0
+
+ def call(self, inputs):
+ self.conv.kernel.assign(self.conv.kernel * self.mask)
+ return self.conv(inputs)
+
+
+# Next, we build our residual block layer.
+# This is just a normal residual block, but based on the PixelConvLayer.
+class ResidualBlock(keras.layers.Layer):
+ def __init__(self, filters, **kwargs):
+ super().__init__(**kwargs)
+ self.conv1 = keras.layers.Conv2D(
+ filters=filters, kernel_size=1, activation="relu"
+ )
+ self.pixel_conv = PixelConvLayer(
+ mask_type="B",
+ filters=filters // 2,
+ kernel_size=3,
+ activation="relu",
+ padding="same",
+ )
+ self.conv2 = keras.layers.Conv2D(
+ filters=filters, kernel_size=1, activation="relu"
+ )
+
+ def call(self, inputs):
+ x = self.conv1(inputs)
+ x = self.pixel_conv(x)
+ x = self.conv2(x)
+ return keras.layers.add([inputs, x])
+
+
+pixelcnn_inputs = keras.Input(shape=pixelcnn_input_shape, dtype=tf.int32)
+ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
+x = PixelConvLayer(
+ mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
+)(ohe)
+
+for _ in range(num_residual_blocks):
+ x = ResidualBlock(filters=128)(x)
+
+for _ in range(num_pixelcnn_layers):
+ x = PixelConvLayer(
+ mask_type="B",
+ filters=128,
+ kernel_size=1,
+ strides=1,
+ activation="relu",
+ padding="valid",
+ )(x)
+
+out = keras.layers.Conv2D(
+ filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid"
+)(x)
+
+pixel_cnn = keras.Model(pixelcnn_inputs, out, name="pixel_cnn")
+pixel_cnn.summary()
+
+"""
+## Prepare data to train the PixelCNN
+
+We will train the PixelCNN to learn a categorical distribution of the discrete codes.
+First, we will generate code indices using the encoder and vector quantizer we just
+trained. Our training objective will be to minimize the crossentropy loss between these
+indices and the PixelCNN outputs. Here, the number of categories is equal to the number
+of embeddings present in our codebook (128 in our case). The PixelCNN model is
+trained to learn a distribution (as opposed to minimizing the L1/L2 loss), which is where
+it gets its generative capabilities from.
+"""
+
+# Generate the codebook indices.
+encoded_outputs = encoder.predict(x_train_scaled)
+flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
+codebook_indices = quantizer.get_code_indices(flat_enc_outputs)
+
+codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])
+print(f"Shape of the training data for PixelCNN: {codebook_indices.shape}")
+
+"""
+## PixelCNN training
+"""
+
+pixel_cnn.compile(
+ optimizer=keras.optimizers.Adam(3e-4),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=["accuracy"],
+)
+pixel_cnn.fit(
+ x=codebook_indices,
+ y=codebook_indices,
+ batch_size=128,
+ epochs=30,
+ validation_split=0.1,
+)
+
+"""
+We can improve these scores with more training and hyperparameter tuning.
+"""
+
+"""
+## Codebook sampling
+
+Now that our PixelCNN is trained, we can sample distinct codes from its outputs and pass
+them to our decoder to generate novel images.
+"""
+
+# Create a mini sampler model.
+inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
+outputs = pixel_cnn(inputs, training=False)
+categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
+outputs = categorical_layer(outputs)
+sampler = keras.Model(inputs, outputs)
+
+"""
+We now construct a prior to generate images. Here, we will generate 10 images.
+"""
+
+# Create an empty array of priors.
+batch = 10
+priors = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
+batch, rows, cols = priors.shape
+
+# Iterate over the priors because generation has to be done sequentially pixel by pixel.
+for row in range(rows):
+ for col in range(cols):
+        # Feed the whole array and retrieve the sampled codebook indices for the
+        # next position (the `DistributionLambda` layer samples from the categorical
+        # distribution over the PixelCNN logits).
+        probs = sampler.predict(priors)
+        # Write the sampled value for the current position into the priors.
+        priors[:, row, col] = probs[:, row, col]
+
+print(f"Prior shape: {priors.shape}")
+
+"""
+We can now use our decoder to generate the images.
+"""
+
+# Perform an embedding lookup.
+pretrained_embeddings = quantizer.embeddings
+priors_ohe = tf.one_hot(priors.astype("int32"), vqvae_trainer.num_embeddings).numpy()
+quantized = tf.matmul(
+ priors_ohe.astype("float32"), pretrained_embeddings, transpose_b=True
+)
+quantized = tf.reshape(quantized, (-1, *(encoded_outputs.shape[1:])))
+
+# Generate novel images.
+decoder = vqvae_trainer.vqvae.get_layer("decoder")
+generated_samples = decoder.predict(quantized)
+
+for i in range(batch):
+ plt.subplot(1, 2, 1)
+ plt.imshow(priors[i])
+ plt.title("Code")
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(generated_samples[i].squeeze() + 0.5)
+ plt.title("Generated Sample")
+ plt.axis("off")
+ plt.show()
+
+"""
+We can enhance the quality of these generated samples by tweaking the PixelCNN.
+"""
+
+"""
+## Additional notes
+
+* After the VQ-VAE paper was initially released, the authors developed an exponential
+moving averaging scheme to update the embeddings inside the quantizer. If you're
+interested you can check out
+[this snippet](https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/nets/vqvae.py#L124).
+* To further enhance the quality of the generated samples,
+[VQ-VAE-2](https://arxiv.org/abs/1906.00446) was proposed that follows a cascaded
+approach to learn the codebook and to generate the images.
+"""
diff --git a/knowledge_base/generative/wgan-graphs.py b/knowledge_base/generative/wgan-graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e470393af062b5b4946bb1fd72802fbda6d6692
--- /dev/null
+++ b/knowledge_base/generative/wgan-graphs.py
@@ -0,0 +1,611 @@
+"""
+Title: WGAN-GP with R-GCN for the generation of small molecular graphs
+Author: [akensert](https://github.com/akensert)
+Date created: 2021/06/30
+Last modified: 2021/06/30
+Description: Complete implementation of WGAN-GP with R-GCN to generate novel molecules.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this tutorial, we implement a generative model for graphs and use it to generate
+novel molecules.
+
+Motivation: The [development of new drugs](https://en.wikipedia.org/wiki/Drug_development)
+(molecules) can be extremely time-consuming and costly. The use of deep learning models
+can alleviate the search for good candidate drugs, by predicting properties of known molecules
+(e.g., solubility, toxicity, affinity to target protein, etc.). As the number of
+possible molecules is astronomical, the space in which we search for/explore molecules is
+just a fraction of the entire space. Therefore, it's arguably desirable to implement
+generative models that can learn to generate novel molecules (which would otherwise have never been explored).
+
+### References (implementation)
+
+The implementation in this tutorial is based on/inspired by the
+[MolGAN paper](https://arxiv.org/abs/1805.11973) and DeepChem's
+[Basic MolGAN](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#basicmolganmodel).
+
+### Further reading (generative models)
+Recent implementations of generative models for molecular graphs also include
+[Mol-CycleGAN](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-019-0404-1),
+[GraphVAE](https://arxiv.org/abs/1802.03480) and
+[JT-VAE](https://arxiv.org/abs/1802.04364). For more information on generative
+adversarial networks, see [GAN](https://arxiv.org/abs/1406.2661),
+[WGAN](https://arxiv.org/abs/1701.07875) and [WGAN-GP](https://arxiv.org/abs/1704.00028).
+
+"""
+
+"""
+## Setup
+
+### Install RDKit
+
+[RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning
+software written in C++ and Python. In this tutorial, RDKit is used to conveniently and
+efficiently transform
+[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to
+molecule objects, and then from those obtain sets of atoms and bonds.
+
+SMILES expresses the structure of a given molecule in the form of an ASCII string.
+The SMILES string is a compact encoding which, for smaller molecules, is relatively
+human-readable. Encoding molecules as strings also facilitates database
+and web searches for a given molecule. RDKit uses algorithms to
+accurately transform a given SMILES to a molecule object, which can then
+be used to compute a great number of molecular properties/features.
+
+Notice, RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
+However, thanks to
+[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
+can now (for the sake of this tutorial) be installed easily via pip, as follows:
+```
+pip -q install rdkit-pypi
+```
+And to allow easy visualization of molecule objects, Pillow needs to be installed:
+```
+pip -q install Pillow
+```
+
+"""
+
+"""
+### Import packages
+
+"""
+
+from rdkit import Chem, RDLogger
+from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+RDLogger.DisableLog("rdApp.*")
+
+"""
+## Dataset
+
+The dataset used in this tutorial is a
+[quantum mechanics dataset](http://quantum-machine.org/datasets/) (QM9), obtained from
+[MoleculeNet](http://moleculenet.ai/datasets-1). Although many feature and label columns
+come with the dataset, we'll only focus on the
+[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system)
+column. The QM9 dataset is a good first dataset to work with for generating
+graphs, as the maximum number of heavy (non-hydrogen) atoms found in a molecule is only nine.
+"""
+
+csv_path = tf.keras.utils.get_file(
+ "qm9.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv"
+)
+
+data = []
+with open(csv_path, "r") as f:
+ for line in f.readlines()[1:]:
+ data.append(line.split(",")[1])
+
+# Let's look at a molecule of the dataset
+smiles = data[1000]
+print("SMILES:", smiles)
+molecule = Chem.MolFromSmiles(smiles)
+print("Num heavy atoms:", molecule.GetNumHeavyAtoms())
+molecule
+
+"""
+### Define helper functions
+These helper functions will help convert SMILES to graphs and graphs to molecule objects.
+
+**Representing a molecular graph**. Molecules can naturally be expressed as undirected
+graphs `G = (V, E)`, where `V` is a set of vertices (atoms), and `E` a set of edges
+(bonds). As for this implementation, each graph (molecule) will be represented as an
+adjacency tensor `A`, which encodes existence/non-existence of atom-pairs with their
+one-hot encoded bond types stretching an extra dimension, and a feature tensor `H`, which
+for each atom, one-hot encodes its atom type. Notice that, as hydrogen atoms can be inferred by
+RDKit, they are excluded from `A` and `H` for easier modeling.
+
+"""
+
+atom_mapping = {
+ "C": 0,
+ 0: "C",
+ "N": 1,
+ 1: "N",
+ "O": 2,
+ 2: "O",
+ "F": 3,
+ 3: "F",
+}
+
+bond_mapping = {
+ "SINGLE": 0,
+ 0: Chem.BondType.SINGLE,
+ "DOUBLE": 1,
+ 1: Chem.BondType.DOUBLE,
+ "TRIPLE": 2,
+ 2: Chem.BondType.TRIPLE,
+ "AROMATIC": 3,
+ 3: Chem.BondType.AROMATIC,
+}
+
+NUM_ATOMS = 9 # Maximum number of atoms
+ATOM_DIM = 4 + 1 # Number of atom types
+BOND_DIM = 4 + 1 # Number of bond types
+LATENT_DIM = 64 # Size of the latent space
+
+
+def smiles_to_graph(smiles):
+ # Converts SMILES to molecule object
+ molecule = Chem.MolFromSmiles(smiles)
+
+ # Initialize adjacency and feature tensor
+ adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
+ features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")
+
+ # loop over each atom in molecule
+ for atom in molecule.GetAtoms():
+ i = atom.GetIdx()
+ atom_type = atom_mapping[atom.GetSymbol()]
+ features[i] = np.eye(ATOM_DIM)[atom_type]
+ # loop over one-hop neighbors
+ for neighbor in atom.GetNeighbors():
+ j = neighbor.GetIdx()
+ bond = molecule.GetBondBetweenAtoms(i, j)
+ bond_type_idx = bond_mapping[bond.GetBondType().name]
+ adjacency[bond_type_idx, [i, j], [j, i]] = 1
+
+ # Where no bond, add 1 to last channel (indicating "non-bond")
+ # Notice: channels-first
+ adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1
+
+ # Where no atom, add 1 to last column (indicating "non-atom")
+ features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1
+
+ return adjacency, features
+
+
+def graph_to_molecule(graph):
+ # Unpack graph
+ adjacency, features = graph
+
+ # RWMol is a molecule object intended to be edited
+ molecule = Chem.RWMol()
+
+ # Remove "no atoms" & atoms with no bonds
+ keep_idx = np.where(
+ (np.argmax(features, axis=1) != ATOM_DIM - 1)
+ & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
+ )[0]
+ features = features[keep_idx]
+ adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]
+
+ # Add atoms to molecule
+ for atom_type_idx in np.argmax(features, axis=1):
+ atom = Chem.Atom(atom_mapping[atom_type_idx])
+ _ = molecule.AddAtom(atom)
+
+ # Add bonds between atoms in molecule; based on the upper triangles
+ # of the [symmetric] adjacency tensor
+ (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
+ for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):
+ if atom_i == atom_j or bond_ij == BOND_DIM - 1:
+ continue
+ bond_type = bond_mapping[bond_ij]
+ molecule.AddBond(int(atom_i), int(atom_j), bond_type)
+
+ # Sanitize the molecule; for more information on sanitization, see
+ # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
+ flag = Chem.SanitizeMol(molecule, catchErrors=True)
+ # Let's be strict. If sanitization fails, return None
+ if flag != Chem.SanitizeFlags.SANITIZE_NONE:
+ return None
+
+ return molecule
+
+
+# Test helper functions
+graph_to_molecule(smiles_to_graph(smiles))
+
+"""
+### Generate training set
+
+To save training time, we'll only use a tenth of the QM9 dataset.
+"""
+
+adjacency_tensor, feature_tensor = [], []
+for smiles in data[::10]:
+ adjacency, features = smiles_to_graph(smiles)
+ adjacency_tensor.append(adjacency)
+ feature_tensor.append(features)
+
+adjacency_tensor = np.array(adjacency_tensor)
+feature_tensor = np.array(feature_tensor)
+
+print("adjacency_tensor.shape =", adjacency_tensor.shape)
+print("feature_tensor.shape =", feature_tensor.shape)
+
+"""
+## Model
+
+The idea is to implement a generator network and a discriminator network via WGAN-GP,
+that will result in a generator network that can generate small novel molecules
+(small graphs).
+
+The generator network needs to be able to map (for each example in the batch) a vector `z`
+to a 3-D adjacency tensor (`A`) and 2-D feature tensor (`H`). For this, `z` will first be
+passed through a fully-connected network, for which the output will be further passed
+through two separate fully-connected networks. Each of these two fully-connected
+networks will then output (for each example in the batch) a tanh-activated vector
+followed by a reshape and softmax to match that of a multi-dimensional adjacency/feature
+tensor.
+
+As the discriminator network receives as input a graph (`A`, `H`) from either the
+generator or the training set, we'll need to implement graph convolutional layers,
+which allow us to operate on graphs. This means that the input to the discriminator network
+will first pass through graph convolutional layers, then an average-pooling layer,
+and finally a few fully-connected layers. The final output should be a scalar (for each
+example in the batch) which indicates the "realness" of the associated input
+(in this case a "fake" or "real" molecule).
+
+
+### Graph generator
+"""
+
+
+def GraphGenerator(
+ dense_units,
+ dropout_rate,
+ latent_dim,
+ adjacency_shape,
+ feature_shape,
+):
+ z = keras.layers.Input(shape=(LATENT_DIM,))
+ # Propagate through one or more densely connected layers
+ x = z
+ for units in dense_units:
+ x = keras.layers.Dense(units, activation="tanh")(x)
+ x = keras.layers.Dropout(dropout_rate)(x)
+
+ # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)
+ x_adjacency = keras.layers.Dense(tf.math.reduce_prod(adjacency_shape))(x)
+ x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)
+ # Symmetrify tensors in the last two dimensions
+ x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2
+ x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)
+
+ # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)
+ x_features = keras.layers.Dense(tf.math.reduce_prod(feature_shape))(x)
+ x_features = keras.layers.Reshape(feature_shape)(x_features)
+ x_features = keras.layers.Softmax(axis=2)(x_features)
+
+ return keras.Model(inputs=z, outputs=[x_adjacency, x_features], name="Generator")
+
+
+generator = GraphGenerator(
+ dense_units=[128, 256, 512],
+ dropout_rate=0.2,
+ latent_dim=LATENT_DIM,
+ adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
+ feature_shape=(NUM_ATOMS, ATOM_DIM),
+)
+generator.summary()
+
+"""
+### Graph discriminator
+
+
+**Graph convolutional layer**. The
+[relational graph convolutional layers](https://arxiv.org/abs/1703.06103) implement non-linearly transformed
+neighborhood aggregations. We can define these layers as follows:
+
+`H^{l+1} = σ(D^{-1} @ A @ H^{l} @ W^{l})`
+
+where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A` the
+adjacency tensor, `H^{l}` the feature tensor at the `l`-th layer, `D^{-1}` the inverse
+diagonal degree tensor of `A`, and `W^{l}` the trainable weight tensor at the `l`-th
+layer. Specifically, for each bond type (relation), the degree tensor expresses, in the
+diagonal, the number of bonds attached to each atom. Notice, in this tutorial `D^{-1}` is
+omitted, for two reasons: (1) it's not obvious how to apply this normalization on the
+continuous adjacency tensors (generated by the generator), and (2) the performance of the
+WGAN without normalization seems to work just fine. Furthermore, in contrast to the
+[original paper](https://arxiv.org/abs/1703.06103), no self-loop is defined, as we don't
+want to train the generator to predict "self-bonding".
+
+
+
+"""
+
+
+class RelationalGraphConvLayer(keras.layers.Layer):
+ def __init__(
+ self,
+ units=128,
+ activation="relu",
+ use_bias=False,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.units = units
+ self.activation = keras.activations.get(activation)
+ self.use_bias = use_bias
+ self.kernel_initializer = keras.initializers.get(kernel_initializer)
+ self.bias_initializer = keras.initializers.get(bias_initializer)
+ self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+ self.bias_regularizer = keras.regularizers.get(bias_regularizer)
+
+ def build(self, input_shape):
+ bond_dim = input_shape[0][1]
+ atom_dim = input_shape[1][2]
+
+ self.kernel = self.add_weight(
+ shape=(bond_dim, atom_dim, self.units),
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ trainable=True,
+ name="W",
+ dtype=tf.float32,
+ )
+
+ if self.use_bias:
+ self.bias = self.add_weight(
+ shape=(bond_dim, 1, self.units),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ trainable=True,
+ name="b",
+ dtype=tf.float32,
+ )
+
+ self.built = True
+
+ def call(self, inputs, training=False):
+ adjacency, features = inputs
+ # Aggregate information from neighbors
+ x = tf.matmul(adjacency, features[:, None, :, :])
+ # Apply linear transformation
+ x = tf.matmul(x, self.kernel)
+ if self.use_bias:
+ x += self.bias
+ # Reduce bond types dim
+ x_reduced = tf.reduce_sum(x, axis=1)
+ # Apply non-linear transformation
+ return self.activation(x_reduced)
+
+
+def GraphDiscriminator(
+ gconv_units, dense_units, dropout_rate, adjacency_shape, feature_shape
+):
+ adjacency = keras.layers.Input(shape=adjacency_shape)
+ features = keras.layers.Input(shape=feature_shape)
+
+ # Propagate through one or more graph convolutional layers
+ features_transformed = features
+ for units in gconv_units:
+ features_transformed = RelationalGraphConvLayer(units)(
+ [adjacency, features_transformed]
+ )
+
+ # Reduce 2-D representation of molecule to 1-D
+ x = keras.layers.GlobalAveragePooling1D()(features_transformed)
+
+ # Propagate through one or more densely connected layers
+ for units in dense_units:
+ x = keras.layers.Dense(units, activation="relu")(x)
+ x = keras.layers.Dropout(dropout_rate)(x)
+
+ # For each molecule, output a single scalar value expressing the
+ # "realness" of the inputted molecule
+ x_out = keras.layers.Dense(1, dtype="float32")(x)
+
+ return keras.Model(inputs=[adjacency, features], outputs=x_out)
+
+
+discriminator = GraphDiscriminator(
+ gconv_units=[128, 128, 128, 128],
+ dense_units=[512, 512],
+ dropout_rate=0.2,
+ adjacency_shape=(BOND_DIM, NUM_ATOMS, NUM_ATOMS),
+ feature_shape=(NUM_ATOMS, ATOM_DIM),
+)
+discriminator.summary()
+
+"""
+### WGAN-GP
+"""
+
+
+class GraphWGAN(keras.Model):
+ def __init__(
+ self,
+ generator,
+ discriminator,
+ discriminator_steps=1,
+ generator_steps=1,
+ gp_weight=10,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.generator = generator
+ self.discriminator = discriminator
+ self.discriminator_steps = discriminator_steps
+ self.generator_steps = generator_steps
+ self.gp_weight = gp_weight
+ self.latent_dim = self.generator.input_shape[-1]
+
+ def compile(self, optimizer_generator, optimizer_discriminator, **kwargs):
+ super().compile(**kwargs)
+ self.optimizer_generator = optimizer_generator
+ self.optimizer_discriminator = optimizer_discriminator
+ self.metric_generator = keras.metrics.Mean(name="loss_gen")
+ self.metric_discriminator = keras.metrics.Mean(name="loss_dis")
+
+ def train_step(self, inputs):
+ if isinstance(inputs[0], tuple):
+ inputs = inputs[0]
+
+ graph_real = inputs
+
+ self.batch_size = tf.shape(inputs[0])[0]
+
+ # Train the discriminator for one or more steps
+ for _ in range(self.discriminator_steps):
+ z = tf.random.normal((self.batch_size, self.latent_dim))
+
+ with tf.GradientTape() as tape:
+ graph_generated = self.generator(z, training=True)
+ loss = self._loss_discriminator(graph_real, graph_generated)
+
+ grads = tape.gradient(loss, self.discriminator.trainable_weights)
+ self.optimizer_discriminator.apply_gradients(
+ zip(grads, self.discriminator.trainable_weights)
+ )
+ self.metric_discriminator.update_state(loss)
+
+ # Train the generator for one or more steps
+ for _ in range(self.generator_steps):
+ z = tf.random.normal((self.batch_size, self.latent_dim))
+
+ with tf.GradientTape() as tape:
+ graph_generated = self.generator(z, training=True)
+ loss = self._loss_generator(graph_generated)
+
+ grads = tape.gradient(loss, self.generator.trainable_weights)
+ self.optimizer_generator.apply_gradients(
+ zip(grads, self.generator.trainable_weights)
+ )
+ self.metric_generator.update_state(loss)
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def _loss_discriminator(self, graph_real, graph_generated):
+ logits_real = self.discriminator(graph_real, training=True)
+ logits_generated = self.discriminator(graph_generated, training=True)
+ loss = tf.reduce_mean(logits_generated) - tf.reduce_mean(logits_real)
+ loss_gp = self._gradient_penalty(graph_real, graph_generated)
+ return loss + loss_gp * self.gp_weight
+
+ def _loss_generator(self, graph_generated):
+ logits_generated = self.discriminator(graph_generated, training=True)
+ return -tf.reduce_mean(logits_generated)
+
+ def _gradient_penalty(self, graph_real, graph_generated):
+ # Unpack graphs
+ adjacency_real, features_real = graph_real
+ adjacency_generated, features_generated = graph_generated
+
+ # Generate interpolated graphs (adjacency_interp and features_interp)
+ alpha = tf.random.uniform([self.batch_size])
+ alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))
+ adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated
+ alpha = tf.reshape(alpha, (self.batch_size, 1, 1))
+ features_interp = (features_real * alpha) + (1 - alpha) * features_generated
+
+ # Compute the logits of interpolated graphs
+ with tf.GradientTape() as tape:
+ tape.watch(adjacency_interp)
+ tape.watch(features_interp)
+ logits = self.discriminator(
+ [adjacency_interp, features_interp], training=True
+ )
+
+ # Compute the gradients with respect to the interpolated graphs
+ grads = tape.gradient(logits, [adjacency_interp, features_interp])
+ # Compute the gradient penalty
+ grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2
+ grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2
+ return tf.reduce_mean(
+ tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))
+ + tf.reduce_mean(grads_features_penalty, axis=(-1))
+ )
+
+
+"""
+## Train the model
+
+To save time (if run on a CPU), we'll only train the model for 10 epochs.
+"""
+
+wgan = GraphWGAN(generator, discriminator, discriminator_steps=1)
+
+wgan.compile(
+ optimizer_generator=keras.optimizers.Adam(5e-4),
+ optimizer_discriminator=keras.optimizers.Adam(5e-4),
+)
+
+wgan.fit([adjacency_tensor, feature_tensor], epochs=10, batch_size=16)
+
+"""
+## Sample novel molecules with the generator
+"""
+
+
+def sample(generator, batch_size):
+ z = tf.random.normal((batch_size, LATENT_DIM))
+ graph = generator.predict(z)
+ # obtain one-hot encoded adjacency tensor
+ adjacency = tf.argmax(graph[0], axis=1)
+ adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)
+ # Remove potential self-loops from adjacency
+ adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))
+ # obtain one-hot encoded feature tensor
+ features = tf.argmax(graph[1], axis=2)
+ features = tf.one_hot(features, depth=ATOM_DIM, axis=2)
+ return [
+ graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])
+ for i in range(batch_size)
+ ]
+
+
+molecules = sample(wgan.generator, batch_size=48)
+
+MolsToGridImage(
+ [m for m in molecules if m is not None][:25], molsPerRow=5, subImgSize=(150, 150)
+)
+
+"""
+## Concluding thoughts
+
+**Inspecting the results**. Ten epochs of training seemed enough to generate some
+decent-looking molecules! Notice that, in contrast to the
+[MolGAN paper](https://arxiv.org/abs/1805.11973), the uniqueness of the generated
+molecules in this tutorial seems really high, which is great!
+
+**What we've learned, and prospects**. In this tutorial, a generative model for molecular
+graphs was successfully implemented, which allowed us to generate novel molecules. In the
+future, it would be interesting to implement generative models that can modify existing
+molecules (for instance, to optimize solubility or protein-binding of an existing
+molecule). For that however, a reconstruction loss would likely be needed, which is
+tricky to implement as there's no easy and obvious way to compute similarity between two
+molecular graphs.
+
+Example available on HuggingFace
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/wgan-molecular-graphs) | [](https://huggingface.co/spaces/keras-io/Generating-molecular-graphs-by-WGAN-GP) |
+"""
diff --git a/knowledge_base/generative/wgan_gp.py b/knowledge_base/generative/wgan_gp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d81974b5f51f2c5ce5ce480b740885fc480d84a3
--- /dev/null
+++ b/knowledge_base/generative/wgan_gp.py
@@ -0,0 +1,441 @@
+"""
+Title: WGAN-GP overriding `Model.train_step`
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2020/05/9
+Last modified: 2023/08/3
+Description: Implementation of Wasserstein GAN with Gradient Penalty.
+Accelerator: GPU
+"""
+
+"""
+## Wasserstein GAN (WGAN) with Gradient Penalty (GP)
+
+The original [Wasserstein GAN](https://arxiv.org/abs/1701.07875) leverages the
+Wasserstein distance to produce a value function that has better theoretical
+properties than the value function used in the original GAN paper. WGAN requires
+that the discriminator (aka the critic) lie within the space of 1-Lipschitz
+functions. The authors proposed the idea of weight clipping to achieve this
+constraint. Though weight clipping works, it can be a problematic way to enforce
+1-Lipschitz constraint and can cause undesirable behavior, e.g. a very deep WGAN
+discriminator (critic) often fails to converge.
+
+The [WGAN-GP](https://arxiv.org/abs/1704.00028) method proposes an
+alternative to weight clipping to ensure smooth training. Instead of clipping
+the weights, the authors proposed a "gradient penalty" by adding a loss term
+that keeps the L2 norm of the discriminator gradients close to 1.
+"""
+
+"""
+## Setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import tensorflow as tf
+from keras import layers
+
+
+"""
+## Prepare the Fashion-MNIST data
+
+To demonstrate how to train WGAN-GP, we will be using the
+[Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. Each
+sample in this dataset is a 28x28 grayscale image associated with a label from
+10 classes (e.g. trouser, pullover, sneaker, etc.)
+"""
+
+IMG_SHAPE = (28, 28, 1)
+BATCH_SIZE = 512
+
+# Size of the noise vector
+noise_dim = 128
+
+fashion_mnist = keras.datasets.fashion_mnist
+(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
+print(f"Number of examples: {len(train_images)}")
+print(f"Shape of the images in the dataset: {train_images.shape[1:]}")
+
+# Reshape each sample to (28, 28, 1) and normalize the pixel values in the [-1, 1] range
+train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
+train_images = (train_images - 127.5) / 127.5
+
+"""
+## Create the discriminator (the critic in the original WGAN)
+
+The samples in the dataset have a (28, 28, 1) shape. Because we will be
+using strided convolutions, this can result in a shape with odd dimensions.
+For example,
+`(28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3)`.
+
+While performing upsampling in the generator part of the network, we won't get
+the same input shape as the original images if we aren't careful. To avoid this,
+we will do something much simpler:
+- In the discriminator: "zero pad" the input to change the shape to `(32, 32, 1)`
+for each sample; and
+- In the generator: crop the final output to match the input shape.
+"""
+
+
+def conv_block(
+ x,
+ filters,
+ activation,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding="same",
+ use_bias=True,
+ use_bn=False,
+ use_dropout=False,
+ drop_value=0.5,
+):
+ x = layers.Conv2D(
+ filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
+ )(x)
+ if use_bn:
+ x = layers.BatchNormalization()(x)
+ x = activation(x)
+ if use_dropout:
+ x = layers.Dropout(drop_value)(x)
+ return x
+
+
+def get_discriminator_model():
+ img_input = layers.Input(shape=IMG_SHAPE)
+ # Zero pad the input to make the input images size to (32, 32, 1).
+ x = layers.ZeroPadding2D((2, 2))(img_input)
+ x = conv_block(
+ x,
+ 64,
+ kernel_size=(5, 5),
+ strides=(2, 2),
+ use_bn=False,
+ use_bias=True,
+ activation=layers.LeakyReLU(0.2),
+ use_dropout=False,
+ drop_value=0.3,
+ )
+ x = conv_block(
+ x,
+ 128,
+ kernel_size=(5, 5),
+ strides=(2, 2),
+ use_bn=False,
+ activation=layers.LeakyReLU(0.2),
+ use_bias=True,
+ use_dropout=True,
+ drop_value=0.3,
+ )
+ x = conv_block(
+ x,
+ 256,
+ kernel_size=(5, 5),
+ strides=(2, 2),
+ use_bn=False,
+ activation=layers.LeakyReLU(0.2),
+ use_bias=True,
+ use_dropout=True,
+ drop_value=0.3,
+ )
+ x = conv_block(
+ x,
+ 512,
+ kernel_size=(5, 5),
+ strides=(2, 2),
+ use_bn=False,
+ activation=layers.LeakyReLU(0.2),
+ use_bias=True,
+ use_dropout=False,
+ drop_value=0.3,
+ )
+
+ x = layers.Flatten()(x)
+ x = layers.Dropout(0.2)(x)
+ x = layers.Dense(1)(x)
+
+ d_model = keras.models.Model(img_input, x, name="discriminator")
+ return d_model
+
+
+d_model = get_discriminator_model()
+d_model.summary()
+
+"""
+## Create the generator
+"""
+
+
+def upsample_block(
+ x,
+ filters,
+ activation,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ up_size=(2, 2),
+ padding="same",
+ use_bn=False,
+ use_bias=True,
+ use_dropout=False,
+ drop_value=0.3,
+):
+ x = layers.UpSampling2D(up_size)(x)
+ x = layers.Conv2D(
+ filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
+ )(x)
+
+ if use_bn:
+ x = layers.BatchNormalization()(x)
+
+ if activation:
+ x = activation(x)
+ if use_dropout:
+ x = layers.Dropout(drop_value)(x)
+ return x
+
+
+def get_generator_model():
+ noise = layers.Input(shape=(noise_dim,))
+ x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
+ x = layers.BatchNormalization()(x)
+ x = layers.LeakyReLU(0.2)(x)
+
+ x = layers.Reshape((4, 4, 256))(x)
+ x = upsample_block(
+ x,
+ 128,
+ layers.LeakyReLU(0.2),
+ strides=(1, 1),
+ use_bias=False,
+ use_bn=True,
+ padding="same",
+ use_dropout=False,
+ )
+ x = upsample_block(
+ x,
+ 64,
+ layers.LeakyReLU(0.2),
+ strides=(1, 1),
+ use_bias=False,
+ use_bn=True,
+ padding="same",
+ use_dropout=False,
+ )
+ x = upsample_block(
+ x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
+ )
+ # At this point, we have an output which has the same shape as the input, (32, 32, 1).
+ # We will use a Cropping2D layer to make it (28, 28, 1).
+ x = layers.Cropping2D((2, 2))(x)
+
+ g_model = keras.models.Model(noise, x, name="generator")
+ return g_model
+
+
+g_model = get_generator_model()
+g_model.summary()
+
+"""
+## Create the WGAN-GP model
+
+Now that we have defined our generator and discriminator, it's time to implement
+the WGAN-GP model. We will also override the `train_step` for training.
+"""
+
+
+class WGAN(keras.Model):
+ def __init__(
+ self,
+ discriminator,
+ generator,
+ latent_dim,
+ discriminator_extra_steps=3,
+ gp_weight=10.0,
+ ):
+ super().__init__()
+ self.discriminator = discriminator
+ self.generator = generator
+ self.latent_dim = latent_dim
+ self.d_steps = discriminator_extra_steps
+ self.gp_weight = gp_weight
+
+ def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
+ super().compile()
+ self.d_optimizer = d_optimizer
+ self.g_optimizer = g_optimizer
+ self.d_loss_fn = d_loss_fn
+ self.g_loss_fn = g_loss_fn
+
+ def gradient_penalty(self, batch_size, real_images, fake_images):
+ """Calculates the gradient penalty.
+
+ This loss is calculated on an interpolated image
+ and added to the discriminator loss.
+ """
+ # Get the interpolated image
+ alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
+ diff = fake_images - real_images
+ interpolated = real_images + alpha * diff
+
+ with tf.GradientTape() as gp_tape:
+ gp_tape.watch(interpolated)
+ # 1. Get the discriminator output for this interpolated image.
+ pred = self.discriminator(interpolated, training=True)
+
+        # 2. Calculate the gradients w.r.t. this interpolated image.
+ grads = gp_tape.gradient(pred, [interpolated])[0]
+ # 3. Calculate the norm of the gradients.
+ norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
+ gp = tf.reduce_mean((norm - 1.0) ** 2)
+ return gp
+
+ def train_step(self, real_images):
+ if isinstance(real_images, tuple):
+ real_images = real_images[0]
+
+ # Get the batch size
+ batch_size = tf.shape(real_images)[0]
+
+ # For each batch, we are going to perform the
+ # following steps as laid out in the original paper:
+ # 1. Train the generator and get the generator loss
+ # 2. Train the discriminator and get the discriminator loss
+ # 3. Calculate the gradient penalty
+ # 4. Multiply this gradient penalty with a constant weight factor
+ # 5. Add the gradient penalty to the discriminator loss
+ # 6. Return the generator and discriminator losses as a loss dictionary
+
+        # Train the discriminator first. The original paper recommends several
+        # discriminator steps (typically 5) per generator step. Here we use
+        # 3 discriminator steps per generator step to reduce the training time.
+ for i in range(self.d_steps):
+ # Get the latent vector
+ random_latent_vectors = tf.random.normal(
+ shape=(batch_size, self.latent_dim)
+ )
+ with tf.GradientTape() as tape:
+ # Generate fake images from the latent vector
+ fake_images = self.generator(random_latent_vectors, training=True)
+ # Get the logits for the fake images
+ fake_logits = self.discriminator(fake_images, training=True)
+ # Get the logits for the real images
+ real_logits = self.discriminator(real_images, training=True)
+
+ # Calculate the discriminator loss using the fake and real image logits
+ d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
+ # Calculate the gradient penalty
+ gp = self.gradient_penalty(batch_size, real_images, fake_images)
+ # Add the gradient penalty to the original discriminator loss
+ d_loss = d_cost + gp * self.gp_weight
+
+ # Get the gradients w.r.t the discriminator loss
+ d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
+ # Update the weights of the discriminator using the discriminator optimizer
+ self.d_optimizer.apply_gradients(
+ zip(d_gradient, self.discriminator.trainable_variables)
+ )
+
+ # Train the generator
+ # Get the latent vector
+ random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
+ with tf.GradientTape() as tape:
+ # Generate fake images using the generator
+ generated_images = self.generator(random_latent_vectors, training=True)
+ # Get the discriminator logits for fake images
+ gen_img_logits = self.discriminator(generated_images, training=True)
+ # Calculate the generator loss
+ g_loss = self.g_loss_fn(gen_img_logits)
+
+ # Get the gradients w.r.t the generator loss
+ gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
+ # Update the weights of the generator using the generator optimizer
+ self.g_optimizer.apply_gradients(
+ zip(gen_gradient, self.generator.trainable_variables)
+ )
+ return {"d_loss": d_loss, "g_loss": g_loss}
+
+
+"""
+## Create a Keras callback that periodically saves generated images
+"""
+
+
+class GANMonitor(keras.callbacks.Callback):
+    def __init__(self, num_img=6, latent_dim=128):
+        super().__init__()
+        self.num_img = num_img
+        self.latent_dim = latent_dim
+
+ def on_epoch_end(self, epoch, logs=None):
+ random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
+ generated_images = self.model.generator(random_latent_vectors)
+ generated_images = (generated_images * 127.5) + 127.5
+
+ for i in range(self.num_img):
+ img = generated_images[i].numpy()
+ img = keras.utils.array_to_img(img)
+ img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))
+
+
+"""
+## Train the end-to-end model
+"""
+
+# Instantiate the optimizer for both networks
+# (learning_rate=0.0002, beta_1=0.5 are recommended)
+generator_optimizer = keras.optimizers.Adam(
+ learning_rate=0.0002, beta_1=0.5, beta_2=0.9
+)
+discriminator_optimizer = keras.optimizers.Adam(
+ learning_rate=0.0002, beta_1=0.5, beta_2=0.9
+)
+
+
+# Define the loss function for the discriminator,
+# which should be (fake_loss - real_loss).
+# We will add the gradient penalty to this loss function later.
+def discriminator_loss(real_img, fake_img):
+ real_loss = tf.reduce_mean(real_img)
+ fake_loss = tf.reduce_mean(fake_img)
+ return fake_loss - real_loss
+
+
+# Define the loss function for the generator.
+def generator_loss(fake_img):
+ return -tf.reduce_mean(fake_img)
+
+
+# Set the number of epochs for training.
+epochs = 20
+
+# Instantiate the custom `GANMonitor` Keras callback.
+cbk = GANMonitor(num_img=3, latent_dim=noise_dim)
+
+# Get the wgan model
+wgan = WGAN(
+ discriminator=d_model,
+ generator=g_model,
+ latent_dim=noise_dim,
+ discriminator_extra_steps=3,
+)
+
+# Compile the wgan model
+wgan.compile(
+ d_optimizer=discriminator_optimizer,
+ g_optimizer=generator_optimizer,
+ g_loss_fn=generator_loss,
+ d_loss_fn=discriminator_loss,
+)
+
+# Start training
+wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
+
+"""
+Display the last generated images:
+"""
+
+from IPython.display import Image, display
+
+display(Image("generated_img_0_19.png"))
+display(Image("generated_img_1_19.png"))
+display(Image("generated_img_2_19.png"))
diff --git a/knowledge_base/graph/gat_node_classification.py b/knowledge_base/graph/gat_node_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..701e52718a5e926683e5e6d4dc7ff4dbc4eb1172
--- /dev/null
+++ b/knowledge_base/graph/gat_node_classification.py
@@ -0,0 +1,398 @@
+"""
+Title: Graph attention network (GAT) for node classification
+Author: [akensert](https://github.com/akensert)
+Date created: 2021/09/13
+Last modified: 2021/12/26
+Description: An implementation of a Graph Attention Network (GAT) for node classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+[Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network)
+are the preferred neural network architecture for processing data structured as
+graphs (for example, social networks or molecular structures), typically yielding
+better results than fully-connected networks or convolutional networks.
+
+In this tutorial, we will implement a specific graph neural network known as a
+[Graph Attention Network](https://arxiv.org/abs/1710.10903) (GAT) to predict labels of
+scientific papers based on what type of papers cite them (using the
+[Cora](https://linqs.soe.ucsc.edu/data) dataset).
+
+### References
+
+For more information on GAT, see the original paper
+[Graph Attention Networks](https://arxiv.org/abs/1710.10903) as well as
+[DGL's Graph Attention Networks](https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html)
+documentation.
+"""
+
+"""
+### Import packages
+"""
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import numpy as np
+import pandas as pd
+import os
+import warnings
+
+warnings.filterwarnings("ignore")
+pd.set_option("display.max_columns", 6)
+pd.set_option("display.max_rows", 6)
+np.random.seed(2)
+
+"""
+## Obtain the dataset
+
+The preparation of the [Cora dataset](https://linqs.soe.ucsc.edu/data) follows that of the
+[Node classification with Graph Neural Networks](https://keras.io/examples/graph/gnn_citations/)
+tutorial. Refer to this tutorial for more details on the dataset and exploratory data analysis.
+In brief, the Cora dataset consists of two files: `cora.cites` which contains *directed links* (citations) between
+papers; and `cora.content` which contains *features* of the corresponding papers and one
+of seven labels (the *subject* of the paper).
+"""
+
+zip_file = keras.utils.get_file(
+ fname="cora.tgz",
+ origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
+ extract=True,
+)
+
+data_dir = os.path.join(os.path.dirname(zip_file), "cora")
+
+citations = pd.read_csv(
+ os.path.join(data_dir, "cora.cites"),
+ sep="\t",
+ header=None,
+ names=["target", "source"],
+)
+
+papers = pd.read_csv(
+ os.path.join(data_dir, "cora.content"),
+ sep="\t",
+ header=None,
+ names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
+)
+
+class_values = sorted(papers["subject"].unique())
+class_idx = {name: id for id, name in enumerate(class_values)}
+paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
+
+papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
+citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
+citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
+papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
+
+print(citations)
+
+print(papers)
+
+"""
+### Split the dataset
+"""
+
+# Obtain random indices
+random_indices = np.random.permutation(range(papers.shape[0]))
+
+# 50/50 split
+train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
+test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]
+
+"""
+### Prepare the graph data
+"""
+
+# Obtain paper indices which will be used to gather node states
+# from the graph later on when training the model
+train_indices = train_data["paper_id"].to_numpy()
+test_indices = test_data["paper_id"].to_numpy()
+
+# Obtain ground truth labels corresponding to each paper_id
+train_labels = train_data["subject"].to_numpy()
+test_labels = test_data["subject"].to_numpy()
+
+# Define graph, namely an edge tensor and a node feature tensor
+edges = tf.convert_to_tensor(citations[["target", "source"]])
+node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])
+
+# Print shapes of the graph
+print("Edges shape:\t\t", edges.shape)
+print("Node features shape:", node_states.shape)
+
+"""
+## Build the model
+
+GAT takes as input a graph (namely an edge tensor and a node feature tensor) and
+outputs \[updated\] node states. The node states are, for each target node, neighborhood
+aggregated information of *N*-hops (where *N* is decided by the number of layers of the
+GAT). Importantly, in contrast to the
+[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN)
+the GAT makes use of attention mechanisms
+to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply
+averaging/summing node states from source nodes (*source papers*) to the target node (*target paper*),
+GAT first applies normalized attention scores to each source node state and then sums.
+"""
+
+"""
+### (Multi-head) graph attention layer
+
+The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`
+layer is simply a concatenation (or averaging) of multiple graph attention layers
+(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer
+does the following:
+
+Consider input node states `h^{l}`, which are linearly transformed by `W^{l}`, resulting in `z^{l}`.
+
+For each target node:
+
+1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`,
+resulting in `e_{ij}` (for all `j`).
+`||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}`
+corresponds to a given 1-hop neighbor/source node.
+2. Normalizes `e_{ij}` via softmax, so that the attention scores of the incoming edges
+of each target node (`sum_{k}{e_{norm}_{ik}}`) add up to 1.
+3. Applies the attention scores `e_{norm}_{ij}` to `z_{j}` and sums the weighted states
+over all `j` to obtain the new target node state `h^{l+1}_{i}`.
+"""
+
+
+class GraphAttention(layers.Layer):
+ def __init__(
+ self,
+ units,
+ kernel_initializer="glorot_uniform",
+ kernel_regularizer=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.units = units
+ self.kernel_initializer = keras.initializers.get(kernel_initializer)
+ self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+
+ def build(self, input_shape):
+ self.kernel = self.add_weight(
+ shape=(input_shape[0][-1], self.units),
+ trainable=True,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ name="kernel",
+ )
+ self.kernel_attention = self.add_weight(
+ shape=(self.units * 2, 1),
+ trainable=True,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ name="kernel_attention",
+ )
+ self.built = True
+
+ def call(self, inputs):
+ node_states, edges = inputs
+
+ # Linearly transform node states
+ node_states_transformed = tf.matmul(node_states, self.kernel)
+
+ # (1) Compute pair-wise attention scores
+ node_states_expanded = tf.gather(node_states_transformed, edges)
+ node_states_expanded = tf.reshape(
+ node_states_expanded, (tf.shape(edges)[0], -1)
+ )
+ attention_scores = tf.nn.leaky_relu(
+ tf.matmul(node_states_expanded, self.kernel_attention)
+ )
+ attention_scores = tf.squeeze(attention_scores, -1)
+
+ # (2) Normalize attention scores
+ attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
+ attention_scores_sum = tf.math.unsorted_segment_sum(
+ data=attention_scores,
+ segment_ids=edges[:, 0],
+ num_segments=tf.reduce_max(edges[:, 0]) + 1,
+ )
+ attention_scores_sum = tf.repeat(
+ attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
+ )
+ attention_scores_norm = attention_scores / attention_scores_sum
+
+ # (3) Gather node states of neighbors, apply attention scores and aggregate
+ node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
+ out = tf.math.unsorted_segment_sum(
+ data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
+ segment_ids=edges[:, 0],
+ num_segments=tf.shape(node_states)[0],
+ )
+ return out
+
+
+class MultiHeadGraphAttention(layers.Layer):
+ def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
+ super().__init__(**kwargs)
+ self.num_heads = num_heads
+ self.merge_type = merge_type
+ self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]
+
+ def call(self, inputs):
+ atom_features, pair_indices = inputs
+
+ # Obtain outputs from each attention head
+ outputs = [
+ attention_layer([atom_features, pair_indices])
+ for attention_layer in self.attention_layers
+ ]
+ # Concatenate or average the node states from each head
+ if self.merge_type == "concat":
+ outputs = tf.concat(outputs, axis=-1)
+ else:
+ outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
+ # Activate and return node states
+ return tf.nn.relu(outputs)
+
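+# As a quick smoke test (illustrative only), run a multi-head attention layer on a
+# tiny made-up graph to confirm the output shape before building the full model.
+_toy_states = tf.random.normal((4, 16))  # 4 nodes with 16 features each
+_toy_edges = tf.constant([[0, 1], [1, 0], [2, 3], [3, 2]], dtype=tf.int64)
+_toy_out = MultiHeadGraphAttention(units=8, num_heads=2)([_toy_states, _toy_edges])
+print("Toy output shape:", _toy_out.shape)  # (4, 16): 2 heads x 8 units, concatenated
+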
+
+"""
+### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods
+
+Notice that the GAT model operates on the entire graph (namely, `node_states` and
+`edges`) in all phases (training, validation and testing). Hence, `node_states` and
+`edges` are passed to the constructor of the `keras.Model` and used as attributes.
+The only difference between the phases is the indices (and labels), which are used
+to gather certain outputs (`tf.gather(outputs, indices)`).
+
+"""
+
+
+class GraphAttentionNetwork(keras.Model):
+ def __init__(
+ self,
+ node_states,
+ edges,
+ hidden_units,
+ num_heads,
+ num_layers,
+ output_dim,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.node_states = node_states
+ self.edges = edges
+ self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
+ self.attention_layers = [
+ MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
+ ]
+ self.output_layer = layers.Dense(output_dim)
+
+ def call(self, inputs):
+ node_states, edges = inputs
+ x = self.preprocess(node_states)
+ for attention_layer in self.attention_layers:
+ x = attention_layer([x, edges]) + x
+ outputs = self.output_layer(x)
+ return outputs
+
+ def train_step(self, data):
+ indices, labels = data
+
+ with tf.GradientTape() as tape:
+ # Forward pass
+ outputs = self([self.node_states, self.edges])
+ # Compute loss
+ loss = self.compiled_loss(labels, tf.gather(outputs, indices))
+ # Compute gradients
+ grads = tape.gradient(loss, self.trainable_weights)
+ # Apply gradients (update weights)
+            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
+ # Update metric(s)
+ self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def predict_step(self, data):
+ indices = data
+ # Forward pass
+ outputs = self([self.node_states, self.edges])
+ # Compute probabilities
+ return tf.nn.softmax(tf.gather(outputs, indices))
+
+ def test_step(self, data):
+ indices, labels = data
+ # Forward pass
+ outputs = self([self.node_states, self.edges])
+ # Compute loss
+ loss = self.compiled_loss(labels, tf.gather(outputs, indices))
+ # Update metric(s)
+ self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
+
+ return {m.name: m.result() for m in self.metrics}
+
+
+"""
+### Train and evaluate
+"""
+
+# Define hyper-parameters
+HIDDEN_UNITS = 100
+NUM_HEADS = 8
+NUM_LAYERS = 3
+OUTPUT_DIM = len(class_values)
+
+NUM_EPOCHS = 100
+BATCH_SIZE = 256
+VALIDATION_SPLIT = 0.1
+LEARNING_RATE = 3e-1
+MOMENTUM = 0.9
+
+loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
+accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
+early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
+)
+
+# Build model
+gat_model = GraphAttentionNetwork(
+ node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
+)
+
+# Compile model
+gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])
+
+gat_model.fit(
+ x=train_indices,
+ y=train_labels,
+ validation_split=VALIDATION_SPLIT,
+ batch_size=BATCH_SIZE,
+ epochs=NUM_EPOCHS,
+ callbacks=[early_stopping],
+ verbose=2,
+)
+
+_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)
+
+print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")
+
+"""
+### Predict (probabilities)
+"""
+test_probs = gat_model.predict(x=test_indices)
+
+mapping = {v: k for (k, v) in class_idx.items()}
+
+for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
+ print(f"Example {i+1}: {mapping[label]}")
+ for j, c in zip(probs, class_idx.keys()):
+ print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
+ print("---" * 20)
+
+"""
+## Conclusions
+
+The results look OK! The GAT model seems to correctly predict the subjects of the papers,
+based on what they cite, about 80% of the time. Further improvements could be
+made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers,
+the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout);
+or modify the preprocessing step. We could also try to implement *self-loops*
+(i.e., paper X cites paper X) and/or make the graph *undirected*.
+"""
diff --git a/knowledge_base/graph/gnn_citations.py b/knowledge_base/graph/gnn_citations.py
new file mode 100644
index 0000000000000000000000000000000000000000..8165cb5a05bd15e2762bd03b121a2416ec2fd9e2
--- /dev/null
+++ b/knowledge_base/graph/gnn_citations.py
@@ -0,0 +1,706 @@
+"""
+Title: Node Classification with Graph Neural Networks
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/05/30
+Last modified: 2021/05/30
+Description: Implementing a graph neural network model for predicting the topic of a paper given its citations.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Many datasets in various machine learning (ML) applications have structural relationships
+between their entities, which can be represented as graphs. Such applications include
+social and communication network analysis, traffic prediction, and fraud detection.
+[Graph representation learning](https://www.cs.mcgill.ca/~wlh/grl_book/)
+aims to build and train models for graph datasets to be used for a variety of ML tasks.
+
+This example demonstrates a simple implementation of a [Graph Neural Network](https://arxiv.org/pdf/1901.00596.pdf)
+(GNN) model. The model is used for a node prediction task on the [Cora dataset](https://relational.fit.cvut.cz/dataset/CORA)
+to predict the subject of a paper given its words and its citation network.
+
+Note that **we implement a Graph Convolution Layer from scratch** to provide a better
+understanding of how it works. However, there are a number of specialized TensorFlow-based
+libraries that provide rich GNN APIs, such as [Spectral](https://graphneural.network/),
+[StellarGraph](https://stellargraph.readthedocs.io/en/stable/README.html), and
+[GraphNets](https://github.com/deepmind/graph_nets).
+"""
+
+"""
+## Setup
+"""
+
+import os
+import pandas as pd
+import numpy as np
+import networkx as nx
+import matplotlib.pyplot as plt
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+
+"""
+## Prepare the Dataset
+
+The Cora dataset consists of 2,708 scientific papers classified into one of seven classes.
+The citation network consists of 5,429 links. Each paper has a binary word vector of size
+1,433, indicating the presence of a corresponding word.
+
+### Download the dataset
+
+The dataset has two tab-separated files: `cora.cites` and `cora.content`.
+
+1. The `cora.cites` file includes the citation records with two columns:
+`cited_paper_id` (target) and `citing_paper_id` (source).
+2. The `cora.content` file includes the paper content records with 1,435 columns:
+`paper_id`, `subject`, and 1,433 binary features.
+
+Let's download the dataset.
+"""
+
+zip_file = keras.utils.get_file(
+ fname="cora.tgz",
+ origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
+ extract=True,
+)
+data_dir = os.path.join(os.path.dirname(zip_file), "cora")
+"""
+### Process and visualize the dataset
+
+Then we load the citations data into a Pandas DataFrame.
+"""
+
+citations = pd.read_csv(
+ os.path.join(data_dir, "cora.cites"),
+ sep="\t",
+ header=None,
+ names=["target", "source"],
+)
+print("Citations shape:", citations.shape)
+
+"""
+Now we display a sample of the `citations` DataFrame.
+The `target` column includes the paper ids cited by the paper ids in the `source` column.
+"""
+
+citations.sample(frac=1).head()
+
+"""
+Now let's load the papers data into a Pandas DataFrame.
+"""
+
+column_names = ["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"]
+papers = pd.read_csv(
+ os.path.join(data_dir, "cora.content"),
+ sep="\t",
+ header=None,
+ names=column_names,
+)
+print("Papers shape:", papers.shape)
+
+"""
+Now we display a sample of the `papers` DataFrame. The DataFrame includes the `paper_id`
+and the `subject` columns, as well as 1,433 binary columns representing whether a term exists
+in the paper or not.
+"""
+
+print(papers.sample(5).T)
+
+"""
+Let's display the count of the papers in each subject.
+"""
+
+print(papers.subject.value_counts())
+
+"""
+We convert the paper ids and the subjects into zero-based indices.
+"""
+
+class_values = sorted(papers["subject"].unique())
+class_idx = {name: id for id, name in enumerate(class_values)}
+paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
+
+papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
+citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
+citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
+papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
+
+"""
+Now let's visualize the citation graph. Each node in the graph represents a paper,
+and the color of the node corresponds to its subject. Note that we only show a sample of
+the papers in the dataset.
+"""
+
+plt.figure(figsize=(10, 10))
+colors = papers["subject"].tolist()
+cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
+subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
+nx.draw_spring(cora_graph, node_size=15, node_color=subjects)
+
+
+"""
+### Split the dataset into stratified train and test sets
+"""
+
+train_data, test_data = [], []
+
+for _, group_data in papers.groupby("subject"):
+ # Select around 50% of the dataset for training.
+ random_selection = np.random.rand(len(group_data.index)) <= 0.5
+ train_data.append(group_data[random_selection])
+ test_data.append(group_data[~random_selection])
+
+train_data = pd.concat(train_data).sample(frac=1)
+test_data = pd.concat(test_data).sample(frac=1)
+
+print("Train data shape:", train_data.shape)
+print("Test data shape:", test_data.shape)
+
+"""
+## Implement Train and Evaluate Experiment
+"""
+
+hidden_units = [32, 32]
+learning_rate = 0.01
+dropout_rate = 0.5
+num_epochs = 300
+batch_size = 256
+
+"""
+This function compiles and trains an input model using the given training data.
+"""
+
+
+def run_experiment(model, x_train, y_train):
+ # Compile the model.
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
+ )
+ # Create an early stopping callback.
+ early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_acc", patience=50, restore_best_weights=True
+ )
+ # Fit the model.
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ epochs=num_epochs,
+ batch_size=batch_size,
+ validation_split=0.15,
+ callbacks=[early_stopping],
+ )
+
+ return history
+
+
+"""
+This function displays the loss and accuracy curves of the model during training.
+"""
+
+
+def display_learning_curves(history):
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+
+ ax1.plot(history.history["loss"])
+ ax1.plot(history.history["val_loss"])
+ ax1.legend(["train", "test"], loc="upper right")
+ ax1.set_xlabel("Epochs")
+ ax1.set_ylabel("Loss")
+
+ ax2.plot(history.history["acc"])
+ ax2.plot(history.history["val_acc"])
+ ax2.legend(["train", "test"], loc="upper right")
+ ax2.set_xlabel("Epochs")
+ ax2.set_ylabel("Accuracy")
+ plt.show()
+
+
+"""
+## Implement Feedforward Network (FFN) Module
+
+We will use this module in the baseline and the GNN models.
+"""
+
+
+def create_ffn(hidden_units, dropout_rate, name=None):
+ fnn_layers = []
+
+ for units in hidden_units:
+ fnn_layers.append(layers.BatchNormalization())
+ fnn_layers.append(layers.Dropout(dropout_rate))
+ fnn_layers.append(layers.Dense(units, activation=tf.nn.gelu))
+
+ return keras.Sequential(fnn_layers, name=name)
+
+
+"""
+## Build a Baseline Neural Network Model
+
+### Prepare the data for the baseline model
+"""
+
+feature_names = list(set(papers.columns) - {"paper_id", "subject"})
+num_features = len(feature_names)
+num_classes = len(class_idx)
+
+# Create train and test features as a numpy array.
+x_train = train_data[feature_names].to_numpy()
+x_test = test_data[feature_names].to_numpy()
+# Create train and test targets as a numpy array.
+y_train = train_data["subject"]
+y_test = test_data["subject"]
+
+"""
+### Implement a baseline classifier
+
+We add five FFN blocks with skip connections, so that we generate a baseline model with
+roughly the same number of parameters as the GNN models to be built later.
+"""
+
+
+def create_baseline_model(hidden_units, num_classes, dropout_rate=0.2):
+ inputs = layers.Input(shape=(num_features,), name="input_features")
+ x = create_ffn(hidden_units, dropout_rate, name=f"ffn_block1")(inputs)
+ for block_idx in range(4):
+ # Create an FFN block.
+ x1 = create_ffn(hidden_units, dropout_rate, name=f"ffn_block{block_idx + 2}")(x)
+ # Add skip connection.
+ x = layers.Add(name=f"skip_connection{block_idx + 2}")([x, x1])
+ # Compute logits.
+ logits = layers.Dense(num_classes, name="logits")(x)
+ # Create the model.
+ return keras.Model(inputs=inputs, outputs=logits, name="baseline")
+
+
+baseline_model = create_baseline_model(hidden_units, num_classes, dropout_rate)
+baseline_model.summary()
+
+"""
+### Train the baseline classifier
+"""
+
+history = run_experiment(baseline_model, x_train, y_train)
+
+"""
+Let's plot the learning curves.
+"""
+
+display_learning_curves(history)
+
+"""
+Now we evaluate the baseline model on the test data split.
+"""
+
+_, test_accuracy = baseline_model.evaluate(x=x_test, y=y_test, verbose=0)
+print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
+
+"""
+### Examine the baseline model predictions
+
+Let's create new data instances by randomly generating binary word vectors with respect to
+the word presence probabilities.
+"""
+
+
+def generate_random_instances(num_instances):
+ token_probability = x_train.mean(axis=0)
+ instances = []
+ for _ in range(num_instances):
+ probabilities = np.random.uniform(size=len(token_probability))
+ instance = (probabilities <= token_probability).astype(int)
+ instances.append(instance)
+
+ return np.array(instances)
+
+
+def display_class_probabilities(probabilities):
+ for instance_idx, probs in enumerate(probabilities):
+ print(f"Instance {instance_idx + 1}:")
+ for class_idx, prob in enumerate(probs):
+ print(f"- {class_values[class_idx]}: {round(prob * 100, 2)}%")
+
+
+"""
+Now we show the baseline model predictions given these randomly generated instances.
+"""
+
+new_instances = generate_random_instances(num_classes)
+logits = baseline_model.predict(new_instances)
+probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
+display_class_probabilities(probabilities)
+
+"""
+## Build a Graph Neural Network Model
+
+### Prepare the data for the graph model
+
+Preparing and loading the graph data into the model for training is the most challenging
+part of working with GNN models, and it is addressed in different ways by the specialised libraries.
+In this example, we show a simple approach for preparing and using graph data that is suitable
+if your dataset consists of a single graph that fits entirely in memory.
+
+The graph data is represented by the `graph_info` tuple, which consists of the following
+three elements:
+
+1. `node_features`: This is a `[num_nodes, num_features]` NumPy array that includes the
+node features. In this dataset, the nodes are the papers, and the `node_features` are the
+word-presence binary vectors of each paper.
+2. `edges`: This is a `[2, num_edges]` NumPy array representing a sparse
+[adjacency matrix](https://en.wikipedia.org/wiki/Adjacency_matrix#:~:text=In%20graph%20theory%20and%20computer,with%20zeros%20on%20its%20diagonal.)
+of the links between the nodes. In this example, the links are the citations between the papers.
+3. `edge_weights` (optional): This is a `[num_edges]` NumPy array that includes the edge weights, which *quantify*
+the relationships between nodes in the graph. In this example, there are no weights for the paper citations.
+"""
+
+# Create an edges array (sparse adjacency matrix) of shape [2, num_edges].
+edges = citations[["source", "target"]].to_numpy().T
+# Create an edge weights array of ones.
+edge_weights = tf.ones(shape=edges.shape[1])
+# Create a node features array of shape [num_nodes, num_features].
+node_features = tf.cast(
+ papers.sort_values("paper_id")[feature_names].to_numpy(), dtype=tf.dtypes.float32
+)
+# Create graph info tuple with node_features, edges, and edge_weights.
+graph_info = (node_features, edges, edge_weights)
+
+print("Edges shape:", edges.shape)
+print("Nodes shape:", node_features.shape)
+
+"""
+### Implement a graph convolution layer
+
+We implement a graph convolution module as a [Keras Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer?version=nightly).
+Our `GraphConvLayer` performs the following steps:
+
+1. **Prepare**: The input node representations are processed using an FFN to produce a *message*. You can simplify
+the processing by only applying a linear transformation to the representations.
+2. **Aggregate**: The messages of the neighbours of each node are aggregated with
+respect to the `edge_weights` using a *permutation invariant* pooling operation, such as *sum*, *mean*, or *max*,
+to prepare a single aggregated message for each node. See, for example, the
+[tf.math.unsorted_segment_sum](https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum)
+API used to aggregate neighbour messages.
+3. **Update**: The `node_repesentations` and `aggregated_messages` (both of shape `[num_nodes, representation_dim]`)
+are combined and processed to produce the new state of the node representations (node embeddings).
+If `combination_type` is `gru`, the `node_repesentations` and `aggregated_messages` are stacked to create a sequence,
+then processed by a GRU layer. Otherwise, the `node_repesentations` and `aggregated_messages` are added
+or concatenated, then processed using an FFN.
+
+
+The technique implemented uses ideas from [Graph Convolutional Networks](https://arxiv.org/abs/1609.02907),
+[GraphSage](https://arxiv.org/abs/1706.02216), [Graph Isomorphism Network](https://arxiv.org/abs/1810.00826),
+[Simple Graph Networks](https://arxiv.org/abs/1902.07153), and
+[Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493).
+Two other key techniques that are not covered are [Graph Attention Networks](https://arxiv.org/abs/1710.10903)
+and [Message Passing Neural Networks](https://arxiv.org/abs/1704.01212).
+"""
+
+
+def create_gru(hidden_units, dropout_rate):
+ inputs = keras.layers.Input(shape=(2, hidden_units[0]))
+ x = inputs
+ for units in hidden_units:
+ x = layers.GRU(
+ units=units,
+ activation="tanh",
+ recurrent_activation="sigmoid",
+ return_sequences=True,
+ dropout=dropout_rate,
+ return_state=False,
+ recurrent_dropout=dropout_rate,
+ )(x)
+ return keras.Model(inputs=inputs, outputs=x)
+
+
+class GraphConvLayer(layers.Layer):
+ def __init__(
+ self,
+ hidden_units,
+ dropout_rate=0.2,
+ aggregation_type="mean",
+ combination_type="concat",
+ normalize=False,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.aggregation_type = aggregation_type
+ self.combination_type = combination_type
+ self.normalize = normalize
+
+ self.ffn_prepare = create_ffn(hidden_units, dropout_rate)
+ if self.combination_type == "gru":
+ self.update_fn = create_gru(hidden_units, dropout_rate)
+ else:
+ self.update_fn = create_ffn(hidden_units, dropout_rate)
+
+ def prepare(self, node_repesentations, weights=None):
+ # node_repesentations shape is [num_edges, embedding_dim].
+ messages = self.ffn_prepare(node_repesentations)
+ if weights is not None:
+ messages = messages * tf.expand_dims(weights, -1)
+ return messages
+
+ def aggregate(self, node_indices, neighbour_messages, node_repesentations):
+ # node_indices shape is [num_edges].
+ # neighbour_messages shape: [num_edges, representation_dim].
+ # node_repesentations shape is [num_nodes, representation_dim]
+ num_nodes = node_repesentations.shape[0]
+ if self.aggregation_type == "sum":
+ aggregated_message = tf.math.unsorted_segment_sum(
+ neighbour_messages, node_indices, num_segments=num_nodes
+ )
+ elif self.aggregation_type == "mean":
+ aggregated_message = tf.math.unsorted_segment_mean(
+ neighbour_messages, node_indices, num_segments=num_nodes
+ )
+ elif self.aggregation_type == "max":
+ aggregated_message = tf.math.unsorted_segment_max(
+ neighbour_messages, node_indices, num_segments=num_nodes
+ )
+ else:
+ raise ValueError(f"Invalid aggregation type: {self.aggregation_type}.")
+
+ return aggregated_message
+
+ def update(self, node_repesentations, aggregated_messages):
+ # node_repesentations shape is [num_nodes, representation_dim].
+ # aggregated_messages shape is [num_nodes, representation_dim].
+ if self.combination_type == "gru":
+ # Create a sequence of two elements for the GRU layer.
+ h = tf.stack([node_repesentations, aggregated_messages], axis=1)
+ elif self.combination_type == "concat":
+ # Concatenate the node_repesentations and aggregated_messages.
+ h = tf.concat([node_repesentations, aggregated_messages], axis=1)
+ elif self.combination_type == "add":
+ # Add node_repesentations and aggregated_messages.
+ h = node_repesentations + aggregated_messages
+ else:
+ raise ValueError(f"Invalid combination type: {self.combination_type}.")
+
+ # Apply the processing function.
+ node_embeddings = self.update_fn(h)
+ if self.combination_type == "gru":
+ node_embeddings = tf.unstack(node_embeddings, axis=1)[-1]
+
+ if self.normalize:
+ node_embeddings = tf.nn.l2_normalize(node_embeddings, axis=-1)
+ return node_embeddings
+
+ def call(self, inputs):
+ """Process the inputs to produce the node_embeddings.
+
+ inputs: a tuple of three elements: node_repesentations, edges, edge_weights.
+ Returns: node_embeddings of shape [num_nodes, representation_dim].
+ """
+
+ node_repesentations, edges, edge_weights = inputs
+ # Get node_indices (source) and neighbour_indices (target) from edges.
+ node_indices, neighbour_indices = edges[0], edges[1]
+ # neighbour_repesentations shape is [num_edges, representation_dim].
+ neighbour_repesentations = tf.gather(node_repesentations, neighbour_indices)
+
+ # Prepare the messages of the neighbours.
+ neighbour_messages = self.prepare(neighbour_repesentations, edge_weights)
+ # Aggregate the neighbour messages.
+ aggregated_messages = self.aggregate(
+ node_indices, neighbour_messages, node_repesentations
+ )
+ # Update the node embedding with the neighbour messages.
+ return self.update(node_repesentations, aggregated_messages)
+
+
+"""
+### Implement a graph neural network node classifier
+
+The GNN classification model follows the [Design Space for Graph Neural Networks](https://arxiv.org/abs/2011.08843) approach,
+as follows:
+
+1. Apply preprocessing using an FFN to the node features to generate initial node representations.
+2. Apply one or more graph convolutional layers, with skip connections, to the node representations
+to produce node embeddings.
+3. Apply post-processing using an FFN to the node embeddings to generate the final node embeddings.
+4. Feed the node embeddings into a softmax layer to predict the node class.
+
+Each graph convolutional layer added captures information from a further level of neighbours.
+However, adding many graph convolutional layers can cause oversmoothing, where the model
+produces similar embeddings for all the nodes.
+
+Note that `graph_info` is passed to the constructor of the Keras model, and used as a *property*
+of the Keras model object, rather than as input data for training or prediction.
+The model will accept a **batch** of `node_indices`, which are used to look up the
+node features and neighbours from the `graph_info`.
+"""
+
+
+class GNNNodeClassifier(tf.keras.Model):
+ def __init__(
+ self,
+ graph_info,
+ num_classes,
+ hidden_units,
+ aggregation_type="sum",
+ combination_type="concat",
+ dropout_rate=0.2,
+ normalize=True,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ # Unpack graph_info to three elements: node_features, edges, and edge_weight.
+ node_features, edges, edge_weights = graph_info
+ self.node_features = node_features
+ self.edges = edges
+ self.edge_weights = edge_weights
+ # Set edge_weights to ones if not provided.
+ if self.edge_weights is None:
+ self.edge_weights = tf.ones(shape=edges.shape[1])
+ # Scale edge_weights to sum to 1.
+ self.edge_weights = self.edge_weights / tf.math.reduce_sum(self.edge_weights)
+
+ # Create a process layer.
+ self.preprocess = create_ffn(hidden_units, dropout_rate, name="preprocess")
+ # Create the first GraphConv layer.
+ self.conv1 = GraphConvLayer(
+ hidden_units,
+ dropout_rate,
+ aggregation_type,
+ combination_type,
+ normalize,
+ name="graph_conv1",
+ )
+ # Create the second GraphConv layer.
+ self.conv2 = GraphConvLayer(
+ hidden_units,
+ dropout_rate,
+ aggregation_type,
+ combination_type,
+ normalize,
+ name="graph_conv2",
+ )
+ # Create a postprocess layer.
+ self.postprocess = create_ffn(hidden_units, dropout_rate, name="postprocess")
+ # Create a compute logits layer.
+ self.compute_logits = layers.Dense(units=num_classes, name="logits")
+
+ def call(self, input_node_indices):
+ # Preprocess the node_features to produce node representations.
+ x = self.preprocess(self.node_features)
+ # Apply the first graph conv layer.
+ x1 = self.conv1((x, self.edges, self.edge_weights))
+ # Skip connection.
+ x = x1 + x
+ # Apply the second graph conv layer.
+ x2 = self.conv2((x, self.edges, self.edge_weights))
+ # Skip connection.
+ x = x2 + x
+ # Postprocess node embedding.
+ x = self.postprocess(x)
+ # Fetch node embeddings for the input node_indices.
+ node_embeddings = tf.gather(x, input_node_indices)
+ # Compute logits
+ return self.compute_logits(node_embeddings)
+
+
+"""
+Let's test instantiating and calling the GNN model.
+Notice that if you provide `N` node indices, the output will be a tensor of shape `[N, num_classes]`,
+regardless of the size of the graph.
+"""
+
+gnn_model = GNNNodeClassifier(
+ graph_info=graph_info,
+ num_classes=num_classes,
+ hidden_units=hidden_units,
+ dropout_rate=dropout_rate,
+ name="gnn_model",
+)
+
+print("GNN output shape:", gnn_model([1, 10, 100]))
+
+gnn_model.summary()
+
+"""
+### Train the GNN model
+
+Note that we use the standard *supervised* cross-entropy loss to train the model.
+However, we can add another *self-supervised* loss term for the generated node embeddings
+that makes sure that neighbouring nodes in the graph have similar representations, while faraway
+nodes have dissimilar representations.
+"""
+
+x_train = train_data.paper_id.to_numpy()
+history = run_experiment(gnn_model, x_train, y_train)
+
+"""
+Let's plot the learning curves.
+"""
+
+display_learning_curves(history)
+
+"""
+Now we evaluate the GNN model on the test data split.
+The results may vary depending on the training sample; however, the GNN model consistently outperforms
+the baseline model in terms of test accuracy.
+"""
+
+x_test = test_data.paper_id.to_numpy()
+_, test_accuracy = gnn_model.evaluate(x=x_test, y=y_test, verbose=0)
+print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")
+
+"""
+### Examine the GNN model predictions
+
+Let's add the new instances as nodes to the `node_features`, and generate links
+(citations) to existing nodes.
+"""
+
+# First we add the N new_instances as nodes to the graph
+# by appending the new_instances to node_features.
+num_nodes = node_features.shape[0]
+new_node_features = np.concatenate([node_features, new_instances])
+# Second we add the M edges (citations) from each new node to a set
+# of existing nodes in a particular subject
+new_node_indices = [i + num_nodes for i in range(num_classes)]
+new_citations = []
+for subject_idx, group in papers.groupby("subject"):
+ subject_papers = list(group.paper_id)
+    # Select x random papers from the current subject.
+ selected_paper_indices1 = np.random.choice(subject_papers, 5)
+ # Select random y papers from any subject (where y < x).
+ selected_paper_indices2 = np.random.choice(list(papers.paper_id), 2)
+ # Merge the selected paper indices.
+ selected_paper_indices = np.concatenate(
+ [selected_paper_indices1, selected_paper_indices2], axis=0
+ )
+ # Create edges between a citing paper idx and the selected cited papers.
+ citing_paper_indx = new_node_indices[subject_idx]
+ for cited_paper_idx in selected_paper_indices:
+ new_citations.append([citing_paper_indx, cited_paper_idx])
+
+new_citations = np.array(new_citations).T
+new_edges = np.concatenate([edges, new_citations], axis=1)
+
+"""
+Now let's update the `node_features` and the `edges` in the GNN model.
+"""
+
+print("Original node_features shape:", gnn_model.node_features.shape)
+print("Original edges shape:", gnn_model.edges.shape)
+gnn_model.node_features = new_node_features
+gnn_model.edges = new_edges
+gnn_model.edge_weights = tf.ones(shape=new_edges.shape[1])
+print("New node_features shape:", gnn_model.node_features.shape)
+print("New edges shape:", gnn_model.edges.shape)
+
+logits = gnn_model.predict(tf.convert_to_tensor(new_node_indices))
+probabilities = keras.activations.softmax(tf.convert_to_tensor(logits)).numpy()
+display_class_probabilities(probabilities)
+
+"""
+Notice that the probabilities of the expected subjects
+(to which several citations are added) are higher compared to the baseline model.
+"""
diff --git a/knowledge_base/graph/mpnn-molecular-graphs.py b/knowledge_base/graph/mpnn-molecular-graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..d41e69f4ad238aed3cc77c9bbf25bffe03e90e0c
--- /dev/null
+++ b/knowledge_base/graph/mpnn-molecular-graphs.py
@@ -0,0 +1,675 @@
+"""
+Title: Message-passing neural network (MPNN) for molecular property prediction
+Author: [akensert](http://github.com/akensert)
+Date created: 2021/08/16
+Last modified: 2021/12/27
+Description: Implementation of an MPNN to predict blood-brain barrier permeability.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this tutorial, we will implement a type of graph neural network (GNN) known as
+_message passing neural network_ (MPNN) to predict graph properties. Specifically, we will
+implement an MPNN to predict a molecular property known as
+_blood-brain barrier permeability_ (BBBP).
+
+Motivation: as molecules are naturally represented as an undirected graph `G = (V, E)`,
+where `V` is a set of vertices (nodes; atoms) and `E` a set of edges (bonds), GNNs (such
+as MPNN) are proving to be a useful method for predicting molecular properties.
+
+Until now, more traditional methods, such as random forests, support vector machines, etc.,
+have been commonly used to predict molecular properties. In contrast to GNNs, these
+traditional approaches often operate on precomputed molecular features such as
+molecular weight, polarity, charge, number of carbon atoms, etc. Although these
+molecular features prove to be good predictors for various molecular properties, it is
+hypothesized that operating on these more "raw", "low-level", features could prove even
+better.
+
+### References
+
+In recent years, a lot of effort has been put into developing neural networks for
+graph data, including molecular graphs. For a summary of graph neural networks, see e.g.,
+[A Comprehensive Survey on Graph Neural Networks](https://arxiv.org/abs/1901.00596) and
+[Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/abs/1812.08434);
+and for further reading on the specific
+graph neural network implemented in this tutorial see
+[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
+[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
+"""
+
+"""
+## Setup
+
+### Install RDKit and other dependencies
+
+(Text below taken from
+[this tutorial](https://keras.io/examples/generative/wgan-graphs/)).
+
+[RDKit](https://www.rdkit.org/) is a collection of cheminformatics and machine-learning
+software written in C++ and Python. In this tutorial, RDKit is used to conveniently and
+efficiently transform
+[SMILES](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) to
+molecule objects, and then from those obtain sets of atoms and bonds.
+
+SMILES expresses the structure of a given molecule in the form of an ASCII string.
+The SMILES string is a compact encoding which, for smaller molecules, is relatively
+human-readable. Encoding molecules as strings both reduces storage requirements and facilitates
+database and/or web searching of a given molecule. RDKit uses algorithms to
+accurately transform a given SMILES to a molecule object, which can then
+be used to compute a great number of molecular properties/features.
+
+Notice, RDKit is commonly installed via [Conda](https://www.rdkit.org/docs/Install.html).
+However, thanks to
+[rdkit_platform_wheels](https://github.com/kuelumbus/rdkit_platform_wheels), rdkit
+can now (for the sake of this tutorial) be installed easily via pip, as follows:
+
+```
+pip -q install rdkit-pypi
+```
+
+For easy and efficient reading of CSV files and for visualization, the following
+packages need to be installed:
+
+```
+pip -q install pandas
+pip -q install Pillow
+pip -q install matplotlib
+pip -q install pydot
+sudo apt-get -qq install graphviz
+```
+"""
+
+"""
+### Import packages
+"""
+
+import os
+
+# Temporarily suppress TF logs
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import warnings
+from rdkit import Chem
+from rdkit import RDLogger
+from rdkit.Chem.Draw import IPythonConsole
+from rdkit.Chem.Draw import MolsToGridImage
+
+# Temporarily suppress warnings and RDKit logs
+warnings.filterwarnings("ignore")
+RDLogger.DisableLog("rdApp.*")
+
+np.random.seed(42)
+tf.random.set_seed(42)
+
+"""
+## Dataset
+
+Information about the dataset can be found in
+[A Bayesian Approach to in Silico Blood-Brain Barrier Penetration Modeling](https://pubs.acs.org/doi/10.1021/ci300124c)
+and [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564).
+The dataset will be downloaded from [MoleculeNet.org](https://moleculenet.org/datasets-1).
+
+### About
+
+The dataset contains **2,050** molecules. Each molecule comes with a **name**, a **label**,
+and a **SMILES** string.
+
+The blood-brain barrier (BBB) is a membrane separating the blood from the brain
+extracellular fluid, hence blocking out most drugs (molecules) from reaching
+the brain. Because of this, BBBP has been an important property to study for the development of
+new drugs that aim to target the central nervous system. The labels for this
+data set are binary (1 or 0) and indicate the permeability of the molecules.
+"""
+
+csv_path = keras.utils.get_file(
+ "BBBP.csv", "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
+)
+
+df = pd.read_csv(csv_path, usecols=[1, 2, 3])
+df.iloc[96:104]
+
+"""
+### Define features
+
+To encode features for atoms and bonds (which we will need later),
+we'll define two classes: `AtomFeaturizer` and `BondFeaturizer` respectively.
+
+To reduce the lines of code, i.e., to keep this tutorial short and concise,
+only a handful of (atom and bond) features will be considered: \[atom features\]
+[symbol (element)](https://en.wikipedia.org/wiki/Chemical_element),
+[number of valence electrons](https://en.wikipedia.org/wiki/Valence_electron),
+[number of hydrogen bonds](https://en.wikipedia.org/wiki/Hydrogen),
+[orbital hybridization](https://en.wikipedia.org/wiki/Orbital_hybridisation),
+\[bond features\]
+[(covalent) bond type](https://en.wikipedia.org/wiki/Covalent_bond), and
+[conjugation](https://en.wikipedia.org/wiki/Conjugated_system).
+"""
+
+
+class Featurizer:
+ def __init__(self, allowable_sets):
+ self.dim = 0
+ self.features_mapping = {}
+ for k, s in allowable_sets.items():
+ s = sorted(list(s))
+ self.features_mapping[k] = dict(zip(s, range(self.dim, len(s) + self.dim)))
+ self.dim += len(s)
+
+ def encode(self, inputs):
+ output = np.zeros((self.dim,))
+ for name_feature, feature_mapping in self.features_mapping.items():
+ feature = getattr(self, name_feature)(inputs)
+ if feature not in feature_mapping:
+ continue
+ output[feature_mapping[feature]] = 1.0
+ return output
+
+
+class AtomFeaturizer(Featurizer):
+ def __init__(self, allowable_sets):
+ super().__init__(allowable_sets)
+
+ def symbol(self, atom):
+ return atom.GetSymbol()
+
+ def n_valence(self, atom):
+ return atom.GetTotalValence()
+
+ def n_hydrogens(self, atom):
+ return atom.GetTotalNumHs()
+
+ def hybridization(self, atom):
+ return atom.GetHybridization().name.lower()
+
+
+class BondFeaturizer(Featurizer):
+ def __init__(self, allowable_sets):
+ super().__init__(allowable_sets)
+ self.dim += 1
+
+ def encode(self, bond):
+ output = np.zeros((self.dim,))
+ if bond is None:
+ output[-1] = 1.0
+ return output
+ output = super().encode(bond)
+ return output
+
+ def bond_type(self, bond):
+ return bond.GetBondType().name.lower()
+
+ def conjugated(self, bond):
+ return bond.GetIsConjugated()
+
+
+atom_featurizer = AtomFeaturizer(
+ allowable_sets={
+ "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
+ "n_valence": {0, 1, 2, 3, 4, 5, 6},
+ "n_hydrogens": {0, 1, 2, 3, 4},
+ "hybridization": {"s", "sp", "sp2", "sp3"},
+ }
+)
+
+bond_featurizer = BondFeaturizer(
+ allowable_sets={
+ "bond_type": {"single", "double", "triple", "aromatic"},
+ "conjugated": {True, False},
+ }
+)
+
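+# Quick check (illustrative): the encoding dimensions implied by the allowable sets.
+print("Atom feature dim:", atom_featurizer.dim)  # 13 symbols + 7 + 5 + 4 = 29
+print("Bond feature dim:", bond_featurizer.dim)  # 4 bond types + 2 + 1 ("no bond" flag) = 7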
+
+"""
+### Generate graphs
+
+Before we can generate complete graphs from SMILES, we need to implement the following functions:
+
+1. `molecule_from_smiles`, which takes as input a SMILES and returns a molecule object.
+This is all handled by RDKit.
+
+2. `graph_from_molecule`, which takes as input a molecule object and returns a graph,
+represented as a three-tuple (atom_features, bond_features, pair_indices). For this we
+will make use of the classes defined previously.
+
+Finally, we can now implement the function `graphs_from_smiles`, which applies function (1)
+and subsequently (2) on all SMILES of the training, validation and test datasets.
+
+Notice: although scaffold splitting is recommended for this dataset (see
+[here](https://arxiv.org/abs/1703.00564)), for simplicity, a simple random split was
+performed. Scaffold splitting keeps structurally similar molecules in the same split,
+which gives a more realistic estimate of generalization to novel chemistry.
+"""
+
+
+def molecule_from_smiles(smiles):
+ # MolFromSmiles(m, sanitize=True) should be equivalent to
+ # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
+ molecule = Chem.MolFromSmiles(smiles, sanitize=False)
+
+ # If sanitization is unsuccessful, catch the error, and try again without
+ # the sanitization step that caused the error
+ flag = Chem.SanitizeMol(molecule, catchErrors=True)
+ if flag != Chem.SanitizeFlags.SANITIZE_NONE:
+ Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)
+
+ Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
+ return molecule
+
+
+def graph_from_molecule(molecule):
+ # Initialize graph
+ atom_features = []
+ bond_features = []
+ pair_indices = []
+
+ for atom in molecule.GetAtoms():
+ atom_features.append(atom_featurizer.encode(atom))
+
+ # Add self-loops
+ pair_indices.append([atom.GetIdx(), atom.GetIdx()])
+ bond_features.append(bond_featurizer.encode(None))
+
+ for neighbor in atom.GetNeighbors():
+ bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
+ pair_indices.append([atom.GetIdx(), neighbor.GetIdx()])
+ bond_features.append(bond_featurizer.encode(bond))
+
+ return np.array(atom_features), np.array(bond_features), np.array(pair_indices)
+
+
+def graphs_from_smiles(smiles_list):
+ # Initialize graphs
+ atom_features_list = []
+ bond_features_list = []
+ pair_indices_list = []
+
+ for smiles in smiles_list:
+ molecule = molecule_from_smiles(smiles)
+ atom_features, bond_features, pair_indices = graph_from_molecule(molecule)
+
+ atom_features_list.append(atom_features)
+ bond_features_list.append(bond_features)
+ pair_indices_list.append(pair_indices)
+
+ # Convert lists to ragged tensors for tf.data.Dataset later on
+ return (
+ tf.ragged.constant(atom_features_list, dtype=tf.float32),
+ tf.ragged.constant(bond_features_list, dtype=tf.float32),
+ tf.ragged.constant(pair_indices_list, dtype=tf.int64),
+ )
+
+
+# Shuffle array of indices ranging from 0 to 2049
+permuted_indices = np.random.permutation(np.arange(df.shape[0]))
+
+# Train set: 80 % of data
+train_index = permuted_indices[: int(df.shape[0] * 0.8)]
+x_train = graphs_from_smiles(df.iloc[train_index].smiles)
+y_train = df.iloc[train_index].p_np
+
+# Valid set: 19 % of data
+valid_index = permuted_indices[int(df.shape[0] * 0.8) : int(df.shape[0] * 0.99)]
+x_valid = graphs_from_smiles(df.iloc[valid_index].smiles)
+y_valid = df.iloc[valid_index].p_np
+
+# Test set: 1 % of data
+test_index = permuted_indices[int(df.shape[0] * 0.99) :]
+x_test = graphs_from_smiles(df.iloc[test_index].smiles)
+y_test = df.iloc[test_index].p_np
+
+"""
+### Test the functions
+"""
+
+print(f"Name:\t{df.name[100]}\nSMILES:\t{df.smiles[100]}\nBBBP:\t{df.p_np[100]}")
+molecule = molecule_from_smiles(df.iloc[100].smiles)
+print("Molecule:")
+molecule
+
+"""
+"""
+
+graph = graph_from_molecule(molecule)
+print("Graph (including self-loops):")
+print("\tatom features\t", graph[0].shape)
+print("\tbond features\t", graph[1].shape)
+print("\tpair indices\t", graph[2].shape)
+
+
+"""
+### Create a `tf.data.Dataset`
+
+In this tutorial, the MPNN implementation will take as input (per iteration) a single graph.
+Therefore, given a batch of (sub)graphs (molecules), we need to merge them into a
+single graph (we'll refer to this graph as *global graph*).
+This global graph is a disconnected graph where each subgraph is
+completely separated from the other subgraphs.
+"""
+
+
+def prepare_batch(x_batch, y_batch):
+ """Merges (sub)graphs of batch into a single global (disconnected) graph"""
+
+ atom_features, bond_features, pair_indices = x_batch
+
+ # Obtain number of atoms and bonds for each graph (molecule)
+ num_atoms = atom_features.row_lengths()
+ num_bonds = bond_features.row_lengths()
+
+ # Obtain partition indices (molecule_indicator), which will be used to
+ # gather (sub)graphs from global graph in model later on
+ molecule_indices = tf.range(len(num_atoms))
+ molecule_indicator = tf.repeat(molecule_indices, num_atoms)
+
+    # Merge (sub)graphs into a global (disconnected) graph. Each molecule's
+    # 'pair_indices' must be offset by the total number of atoms in all preceding
+    # molecules; 'increment' holds this per-bond offset (zero for the first
+    # molecule, hence the left-padding by 'num_bonds[0]')
+ gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
+ increment = tf.cumsum(num_atoms[:-1])
+ increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
+ pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
+ pair_indices = pair_indices + increment[:, tf.newaxis]
+ atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
+ bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
+
+ return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch
+
+
+def MPNNDataset(X, y, batch_size=32, shuffle=False):
+    dataset = tf.data.Dataset.from_tensor_slices((X, y))
+    if shuffle:
+        dataset = dataset.shuffle(1024)
+    # -1 is shorthand for tf.data.AUTOTUNE (for both num_parallel_calls and the
+    # prefetch buffer size)
+    return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1)
+
+
+"""
+## Model
+
+The MPNN model can take on various shapes and forms. In this tutorial, we will implement an
+MPNN based on the original paper
+[Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) and
+[DeepChem's MPNNModel](https://deepchem.readthedocs.io/en/latest/api_reference/models.html#mpnnmodel).
+The MPNN of this tutorial consists of three stages: message passing, readout and
+classification.
+
+
+### Message passing
+
+The message passing step itself consists of two parts:
+
+1. The *edge network*, which passes messages from 1-hop neighbors `w_{i}` of `v`
+to `v`, based on the edge features between them (`e_{vw_{i}}`),
+resulting in an updated node (state) `v'`. `w_{i}` denotes the `i:th` neighbor of
+`v`.
+
+2. The *gated recurrent unit* (GRU), which takes as input the most recent node state
+and updates it based on previous node states. In
+other words, the most recent node state serves as the input to the GRU, while the previous
+node states are incorporated within the memory state of the GRU. This allows information
+to travel from one node state (e.g., `v`) to another (e.g., `v''`).
+
+Importantly, steps (1) and (2) are repeated for `k` steps, where at each step `1...k`
+the radius (or number of hops) of aggregated information from `v` increases by 1.
+"""
+
+
+class EdgeNetwork(layers.Layer):
+ def build(self, input_shape):
+ self.atom_dim = input_shape[0][-1]
+ self.bond_dim = input_shape[1][-1]
+ self.kernel = self.add_weight(
+ shape=(self.bond_dim, self.atom_dim * self.atom_dim),
+ initializer="glorot_uniform",
+ name="kernel",
+ )
+ self.bias = self.add_weight(
+            shape=(self.atom_dim * self.atom_dim,),
+ initializer="zeros",
+ name="bias",
+ )
+ self.built = True
+
+ def call(self, inputs):
+ atom_features, bond_features, pair_indices = inputs
+
+ # Apply linear transformation to bond features
+ bond_features = tf.matmul(bond_features, self.kernel) + self.bias
+
+ # Reshape for neighborhood aggregation later
+ bond_features = tf.reshape(bond_features, (-1, self.atom_dim, self.atom_dim))
+
+ # Obtain atom features of neighbors
+ atom_features_neighbors = tf.gather(atom_features, pair_indices[:, 1])
+ atom_features_neighbors = tf.expand_dims(atom_features_neighbors, axis=-1)
+
+ # Apply neighborhood aggregation
+ transformed_features = tf.matmul(bond_features, atom_features_neighbors)
+ transformed_features = tf.squeeze(transformed_features, axis=-1)
+ aggregated_features = tf.math.unsorted_segment_sum(
+ transformed_features,
+ pair_indices[:, 0],
+ num_segments=tf.shape(atom_features)[0],
+ )
+ return aggregated_features
+
+
+class MessagePassing(layers.Layer):
+ def __init__(self, units, steps=4, **kwargs):
+ super().__init__(**kwargs)
+ self.units = units
+ self.steps = steps
+
+ def build(self, input_shape):
+ self.atom_dim = input_shape[0][-1]
+ self.message_step = EdgeNetwork()
+ self.pad_length = max(0, self.units - self.atom_dim)
+ self.update_step = layers.GRUCell(self.atom_dim + self.pad_length)
+ self.built = True
+
+ def call(self, inputs):
+ atom_features, bond_features, pair_indices = inputs
+
+ # Pad atom features if number of desired units exceeds atom_features dim.
+ # Alternatively, a dense layer could be used here.
+ atom_features_updated = tf.pad(atom_features, [(0, 0), (0, self.pad_length)])
+
+ # Perform a number of steps of message passing
+ for i in range(self.steps):
+ # Aggregate information from neighbors
+ atom_features_aggregated = self.message_step(
+ [atom_features_updated, bond_features, pair_indices]
+ )
+
+ # Update node state via a step of GRU
+ atom_features_updated, _ = self.update_step(
+ atom_features_aggregated, atom_features_updated
+ )
+ return atom_features_updated
+
+
+"""
+### Readout
+
+When the message passing procedure ends, the k-step-aggregated node states are to be partitioned
+into subgraphs (corresponding to each molecule in the batch) and subsequently
+reduced to graph-level embeddings. In the
+[original paper](https://arxiv.org/abs/1704.01212), a
+[set-to-set layer](https://arxiv.org/abs/1511.06391) was used for this purpose.
+In this tutorial however, a transformer encoder + average pooling will be used. Specifically:
+
+* the k-step-aggregated node states will be partitioned into the subgraphs
+(corresponding to each molecule in the batch);
+* each subgraph will then be padded to match the subgraph with the greatest number of nodes, followed
+by a `tf.stack(...)`;
+* the (stacked padded) tensor, encoding subgraphs (each subgraph containing a set of node states), is
+masked to make sure the padding doesn't interfere with training;
+* finally, the tensor is passed to the transformer followed by average pooling.
+"""
+
+
+class PartitionPadding(layers.Layer):
+ def __init__(self, batch_size, **kwargs):
+ super().__init__(**kwargs)
+ self.batch_size = batch_size
+
+ def call(self, inputs):
+ atom_features, molecule_indicator = inputs
+
+ # Obtain subgraphs
+ atom_features_partitioned = tf.dynamic_partition(
+ atom_features, molecule_indicator, self.batch_size
+ )
+
+ # Pad and stack subgraphs
+ num_atoms = [tf.shape(f)[0] for f in atom_features_partitioned]
+ max_num_atoms = tf.reduce_max(num_atoms)
+ atom_features_stacked = tf.stack(
+ [
+ tf.pad(f, [(0, max_num_atoms - n), (0, 0)])
+ for f, n in zip(atom_features_partitioned, num_atoms)
+ ],
+ axis=0,
+ )
+
+ # Remove empty subgraphs (usually for last batch in dataset)
+ gather_indices = tf.where(tf.reduce_sum(atom_features_stacked, (1, 2)) != 0)
+ gather_indices = tf.squeeze(gather_indices, axis=-1)
+ return tf.gather(atom_features_stacked, gather_indices, axis=0)
+
+
+class TransformerEncoderReadout(layers.Layer):
+ def __init__(
+ self, num_heads=8, embed_dim=64, dense_dim=512, batch_size=32, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.partition_padding = PartitionPadding(batch_size)
+ self.attention = layers.MultiHeadAttention(num_heads, embed_dim)
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(dense_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.average_pooling = layers.GlobalAveragePooling1D()
+
+ def call(self, inputs):
+ x = self.partition_padding(inputs)
+ padding_mask = tf.reduce_any(tf.not_equal(x, 0.0), axis=-1)
+ padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
+ attention_output = self.attention(x, x, attention_mask=padding_mask)
+ proj_input = self.layernorm_1(x + attention_output)
+ proj_output = self.layernorm_2(proj_input + self.dense_proj(proj_input))
+ return self.average_pooling(proj_output)
+
+
+"""
+### Message Passing Neural Network (MPNN)
+
+It is now time to complete the MPNN model. In addition to the message passing
+and readout, a two-layer classification network will be implemented to make
+predictions of BBBP.
+"""
+
+
+def MPNNModel(
+ atom_dim,
+ bond_dim,
+ batch_size=32,
+ message_units=64,
+ message_steps=4,
+ num_attention_heads=8,
+ dense_units=512,
+):
+    atom_features = layers.Input((atom_dim,), dtype="float32", name="atom_features")
+    bond_features = layers.Input((bond_dim,), dtype="float32", name="bond_features")
+    pair_indices = layers.Input((2,), dtype="int32", name="pair_indices")
+ molecule_indicator = layers.Input((), dtype="int32", name="molecule_indicator")
+
+ x = MessagePassing(message_units, message_steps)(
+ [atom_features, bond_features, pair_indices]
+ )
+
+ x = TransformerEncoderReadout(
+ num_attention_heads, message_units, dense_units, batch_size
+ )([x, molecule_indicator])
+
+ x = layers.Dense(dense_units, activation="relu")(x)
+ x = layers.Dense(1, activation="sigmoid")(x)
+
+ model = keras.Model(
+ inputs=[atom_features, bond_features, pair_indices, molecule_indicator],
+ outputs=[x],
+ )
+ return model
+
+
+mpnn = MPNNModel(
+ atom_dim=x_train[0][0][0].shape[0],
+ bond_dim=x_train[1][0][0].shape[0],
+)
+
+mpnn.compile(
+ loss=keras.losses.BinaryCrossentropy(),
+ optimizer=keras.optimizers.Adam(learning_rate=5e-4),
+ metrics=[keras.metrics.AUC(name="AUC")],
+)
+
+keras.utils.plot_model(mpnn, show_dtype=True, show_shapes=True)
+
+"""
+### Training
+"""
+
+train_dataset = MPNNDataset(x_train, y_train)
+valid_dataset = MPNNDataset(x_valid, y_valid)
+test_dataset = MPNNDataset(x_test, y_test)
+
+history = mpnn.fit(
+ train_dataset,
+ validation_data=valid_dataset,
+ epochs=40,
+ verbose=2,
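+    # Weight the classes to compensate for the imbalance of the BBBP set, which
+    # contains many more permeable (1) than non-permeable (0) molecules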
+ class_weight={0: 2.0, 1: 0.5},
+)
+
+plt.figure(figsize=(10, 6))
+plt.plot(history.history["AUC"], label="train AUC")
+plt.plot(history.history["val_AUC"], label="valid AUC")
+plt.xlabel("Epochs", fontsize=16)
+plt.ylabel("AUC", fontsize=16)
+plt.legend(fontsize=16)
+
+"""
+### Predicting
+"""
+
+molecules = [molecule_from_smiles(df.smiles.values[index]) for index in test_index]
+y_true = [df.p_np.values[index] for index in test_index]
+y_pred = tf.squeeze(mpnn.predict(test_dataset), axis=1)
+
+legends = [f"y_true/y_pred = {y_true[i]}/{y_pred[i]:.2f}" for i in range(len(y_true))]
+MolsToGridImage(molecules, molsPerRow=4, legends=legends)
+
+"""
+## Conclusions
+
+In this tutorial, we demonstrated a message passing neural network (MPNN) to
+predict blood-brain barrier permeability (BBBP) for a number of different molecules. We
+first had to construct graphs from SMILES, then build a Keras model that could
+operate on these graphs, and finally train the model to make the predictions.
+
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/MPNN-for-molecular-property-prediction) | [](https://huggingface.co/spaces/keras-io/molecular-property-prediction) |
+"""
diff --git a/knowledge_base/graph/node2vec_movielens.py b/knowledge_base/graph/node2vec_movielens.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e9972f92ae9769f2e0f9aec6d3f9b8cfd2384e5
--- /dev/null
+++ b/knowledge_base/graph/node2vec_movielens.py
@@ -0,0 +1,599 @@
+"""
+Title: Graph representation learning with node2vec
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/05/15
+Last modified: 2021/05/15
+Description: Implementing the node2vec model to generate embeddings for movies from the MovieLens dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Learning useful representations from objects structured as graphs is valuable for
+a variety of machine learning (ML) applications, such as social and communication
+network analysis, biomedical studies, and recommendation systems.
+[Graph representation Learning](https://www.cs.mcgill.ca/~wlh/grl_book/) aims to
+learn embeddings for the graph nodes, which can be used for a variety of ML tasks
+such as node label prediction (e.g. categorizing an article based on its citations)
+and link prediction (e.g. recommending an interest group to a user in a social network).
+
+[node2vec](https://arxiv.org/abs/1607.00653) is a simple, yet scalable and effective
+technique for learning low-dimensional embeddings for nodes in a graph by optimizing
+a neighborhood-preserving objective. The aim is to learn similar embeddings for
+neighboring nodes, with respect to the graph structure.
+
+Given your data items structured as a graph (where the items are represented as
+nodes and the relationship between items are represented as edges),
+node2vec works as follows:
+
+1. Generate item sequences using (biased) random walk.
+2. Create positive and negative training examples from these sequences.
+3. Train a [word2vec](https://www.tensorflow.org/tutorials/text/word2vec) model
+(skip-gram) to learn embeddings for the items.
+
+In this example, we demonstrate the node2vec technique on the
+[small version of the Movielens dataset](https://files.grouplens.org/datasets/movielens/ml-latest-small-README.html)
+to learn movie embeddings. Such a dataset can be represented as a graph by treating
+the movies as nodes, and creating edges between movies that have similar ratings
+by the users. The learnt movie embeddings can be used for tasks such as movie
+recommendation or movie genre prediction.
+
+This example requires the `networkx` package, which can be installed using the following command:
+
+```shell
+pip install networkx
+```
+"""
+
+"""
+## Setup
+"""
+
+import os
+from collections import defaultdict
+import math
+import networkx as nx
+import random
+from tqdm import tqdm
+from zipfile import ZipFile
+from urllib.request import urlretrieve
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import matplotlib.pyplot as plt
+
+"""
+## Download the MovieLens dataset and prepare the data
+
+The small version of the MovieLens dataset includes around 100k ratings
+from 610 users on 9,742 movies.
+
+First, let's download the dataset. The downloaded folder will contain
+several data files. In this example,
+we will only need the `movies.csv` and `ratings.csv` data files.
+"""
+
+urlretrieve(
+ "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip", "movielens.zip"
+)
+ZipFile("movielens.zip", "r").extractall()
+
+"""
+Then, we load the data into a Pandas DataFrame and perform some basic preprocessing.
+"""
+
+# Load movies to a DataFrame.
+movies = pd.read_csv("ml-latest-small/movies.csv")
+# Create a `movieId` string.
+movies["movieId"] = movies["movieId"].apply(lambda x: f"movie_{x}")
+
+# Load ratings to a DataFrame.
+ratings = pd.read_csv("ml-latest-small/ratings.csv")
+# Convert the `ratings` to floating point
+ratings["rating"] = ratings["rating"].apply(lambda x: float(x))
+# Create the `movieId` string.
+ratings["movieId"] = ratings["movieId"].apply(lambda x: f"movie_{x}")
+
+print("Movies data shape:", movies.shape)
+print("Ratings data shape:", ratings.shape)
+
+"""
+Let's inspect a sample instance of the `ratings` DataFrame.
+"""
+
+ratings.head()
+
+"""
+Next, let's check a sample instance of the `movies` DataFrame.
+"""
+
+movies.head()
+
+"""
+Implement two utility functions for the `movies` DataFrame.
+"""
+
+
+def get_movie_title_by_id(movieId):
+ return list(movies[movies.movieId == movieId].title)[0]
+
+
+def get_movie_id_by_title(title):
+ return list(movies[movies.title == title].movieId)[0]
+
+
+"""
+## Construct the Movies graph
+
+We create an edge between two movie nodes in the graph if both movies are rated
+by the same user >= `min_rating`. The weight of the edge will be based on the
+[pointwise mutual information](https://en.wikipedia.org/wiki/Pointwise_mutual_information)
+between the two movies, which is computed as: `log(xy) - log(x) - log(y) + log(D)`, where:
+
+* `xy` is how many users rated both movie `x` and movie `y` with >= `min_rating`.
+* `x` is how many users rated movie `x` >= `min_rating`.
+* `y` is how many users rated movie `y` >= `min_rating`.
+* `D` is the total number of movie ratings >= `min_rating`.
+"""
+
+"""
+### Step 1: create the weighted edges between movies.
+"""
+
+min_rating = 5
+pair_frequency = defaultdict(int)
+item_frequency = defaultdict(int)
+
+# Filter instances where rating is greater than or equal to min_rating.
+rated_movies = ratings[ratings.rating >= min_rating]
+# Group instances by user.
+movies_grouped_by_users = list(rated_movies.groupby("userId"))
+for group in tqdm(
+ movies_grouped_by_users,
+ position=0,
+ leave=True,
+ desc="Compute movie rating frequencies",
+):
+ # Get a list of movies rated by the user.
+ current_movies = list(group[1]["movieId"])
+
+ for i in range(len(current_movies)):
+ item_frequency[current_movies[i]] += 1
+ for j in range(i + 1, len(current_movies)):
+ x = min(current_movies[i], current_movies[j])
+ y = max(current_movies[i], current_movies[j])
+ pair_frequency[(x, y)] += 1
+
+"""
+### Step 2: create the graph with the nodes and the edges
+
+To reduce the number of edges between nodes, we only add an edge between movies
+if the weight of the edge is greater than `min_weight`.
+"""
+
+min_weight = 10
+D = math.log(sum(item_frequency.values()))  # note: this variable holds log(D)
+
+# Create the movies undirected graph.
+movies_graph = nx.Graph()
+# Add weighted edges between movies.
+# This automatically adds the movie nodes to the graph.
+for pair in tqdm(
+ pair_frequency, position=0, leave=True, desc="Creating the movie graph"
+):
+ x, y = pair
+ xy_frequency = pair_frequency[pair]
+ x_frequency = item_frequency[x]
+ y_frequency = item_frequency[y]
+ pmi = math.log(xy_frequency) - math.log(x_frequency) - math.log(y_frequency) + D
+ weight = pmi * xy_frequency
+ # Only include edges with weight >= min_weight.
+ if weight >= min_weight:
+ movies_graph.add_edge(x, y, weight=weight)
+
+"""
+Let's display the total number of nodes and edges in the graph.
+Note that the number of nodes is less than the total number of movies,
+since only the movies that have edges to other movies are added.
+"""
+
+print("Total number of graph nodes:", movies_graph.number_of_nodes())
+print("Total number of graph edges:", movies_graph.number_of_edges())
+
+"""
+Let's display the average node degree (number of neighbours) in the graph.
+"""
+
+degrees = []
+for node in movies_graph.nodes:
+ degrees.append(movies_graph.degree[node])
+
+print("Average node degree:", round(sum(degrees) / len(degrees), 2))
+
+"""
+### Step 3: Create vocabulary and a mapping from tokens to integer indices
+
+The vocabulary is the nodes (movie IDs) in the graph.
+"""
+
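+# Index 0 ("NA") is reserved as a non-word token, as expected by
+# `keras.preprocessing.sequence.skipgrams` (used later to generate examples).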
+vocabulary = ["NA"] + list(movies_graph.nodes)
+vocabulary_lookup = {token: idx for idx, token in enumerate(vocabulary)}
+
+"""
+## Implement the biased random walk
+
+A random walk starts from a given node, and randomly picks a neighbour node to move to.
+If the edges are weighted, the neighbour is selected *probabilistically* with
+respect to weights of the edges between the current node and its neighbours.
+This procedure is repeated for `num_steps` to generate a sequence of *related* nodes.
+
+The [*biased* random walk](https://en.wikipedia.org/wiki/Biased_random_walk_on_a_graph) balances between **breadth-first sampling**
+(where only local neighbours are visited) and **depth-first sampling**
+(where distant neighbours are visited) by introducing the following two parameters:
+
+1. **Return parameter** (`p`): Controls the likelihood of immediately revisiting
+a node in the walk. Setting it to a high value encourages moderate exploration,
+while setting it to a low value would keep the walk local.
+2. **In-out parameter** (`q`): Allows the search to differentiate
+between *inward* and *outward* nodes. Setting it to a high value biases the
+random walk towards local nodes, while setting it to a low value biases the walk
+to visit nodes which are further away.
+
+"""
+
+
+def next_step(graph, previous, current, p, q):
+ neighbors = list(graph.neighbors(current))
+
+ weights = []
+ # Adjust the weights of the edges to the neighbors with respect to p and q.
+ for neighbor in neighbors:
+ if neighbor == previous:
+ # Control the probability to return to the previous node.
+ weights.append(graph[current][neighbor]["weight"] / p)
+ elif graph.has_edge(neighbor, previous):
+ # The probability of visiting a local node.
+ weights.append(graph[current][neighbor]["weight"])
+ else:
+ # Control the probability to move forward.
+ weights.append(graph[current][neighbor]["weight"] / q)
+
+ # Compute the probabilities of visiting each neighbor.
+ weight_sum = sum(weights)
+ probabilities = [weight / weight_sum for weight in weights]
+ # Probabilistically select a neighbor to visit.
+    next_node = np.random.choice(neighbors, size=1, p=probabilities)[0]
+    return next_node
+
+
+def random_walk(graph, num_walks, num_steps, p, q):
+ walks = []
+ nodes = list(graph.nodes())
+ # Perform multiple iterations of the random walk.
+ for walk_iteration in range(num_walks):
+ random.shuffle(nodes)
+
+ for node in tqdm(
+ nodes,
+ position=0,
+ leave=True,
+ desc=f"Random walks iteration {walk_iteration + 1} of {num_walks}",
+ ):
+ # Start the walk with a random node from the graph.
+ walk = [node]
+ # Randomly walk for num_steps.
+ while len(walk) < num_steps:
+ current = walk[-1]
+ previous = walk[-2] if len(walk) > 1 else None
+                # Compute the next node to visit.
+                next_node = next_step(graph, previous, current, p, q)
+                walk.append(next_node)
+ # Replace node ids (movie ids) in the walk with token ids.
+ walk = [vocabulary_lookup[token] for token in walk]
+ # Add the walk to the generated sequence.
+ walks.append(walk)
+
+ return walks
+
+
+"""
+## Generate training data using the biased random walk
+
+You can explore different configurations of `p` and `q` to obtain different
+sets of related movies.
+"""
+# Random walk return parameter.
+p = 1
+# Random walk in-out parameter.
+q = 1
+# Number of iterations of random walks.
+num_walks = 5
+# Number of steps of each random walk.
+num_steps = 10
+walks = random_walk(movies_graph, num_walks, num_steps, p, q)
+
+print("Number of walks generated:", len(walks))
+
+"""
+## Generate positive and negative examples
+
+To train a skip-gram model, we use the generated walks to create positive and
+negative training examples. Each example includes the following features:
+
+1. `target`: A movie in a walk sequence.
+2. `context`: Another movie in a walk sequence.
+3. `weight`: How many times these two movies occurred in walk sequences.
+4. `label`: The label is 1 if these two movies are sampled from the walk sequences,
+otherwise (i.e., if randomly sampled) the label is 0.
+"""
+
+"""
+### Generate examples
+"""
+
+
+def generate_examples(sequences, window_size, num_negative_samples, vocabulary_size):
+ example_weights = defaultdict(int)
+ # Iterate over all sequences (walks).
+ for sequence in tqdm(
+ sequences,
+ position=0,
+ leave=True,
+ desc=f"Generating positive and negative examples",
+ ):
+ # Generate positive and negative skip-gram pairs for a sequence (walk).
+ pairs, labels = keras.preprocessing.sequence.skipgrams(
+ sequence,
+ vocabulary_size=vocabulary_size,
+ window_size=window_size,
+ negative_samples=num_negative_samples,
+ )
+ for idx in range(len(pairs)):
+ pair = pairs[idx]
+ label = labels[idx]
+ target, context = min(pair[0], pair[1]), max(pair[0], pair[1])
+ if target == context:
+ continue
+ entry = (target, context, label)
+ example_weights[entry] += 1
+
+ targets, contexts, labels, weights = [], [], [], []
+ for entry in example_weights:
+ weight = example_weights[entry]
+ target, context, label = entry
+ targets.append(target)
+ contexts.append(context)
+ labels.append(label)
+ weights.append(weight)
+
+ return np.array(targets), np.array(contexts), np.array(labels), np.array(weights)
+
+
+num_negative_samples = 4
+targets, contexts, labels, weights = generate_examples(
+ sequences=walks,
+ window_size=num_steps,
+ num_negative_samples=num_negative_samples,
+ vocabulary_size=len(vocabulary),
+)
+
+"""
+Let's display the shapes of the outputs.
+"""
+
+print(f"Targets shape: {targets.shape}")
+print(f"Contexts shape: {contexts.shape}")
+print(f"Labels shape: {labels.shape}")
+print(f"Weights shape: {weights.shape}")
+
+"""
+### Convert the data into `tf.data.Dataset` objects
+"""
+
+batch_size = 1024
+
+
+def create_dataset(targets, contexts, labels, weights, batch_size):
+ inputs = {
+ "target": targets,
+ "context": contexts,
+ }
+ dataset = tf.data.Dataset.from_tensor_slices((inputs, labels, weights))
+ dataset = dataset.shuffle(buffer_size=batch_size * 2)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+dataset = create_dataset(
+ targets=targets,
+ contexts=contexts,
+ labels=labels,
+ weights=weights,
+ batch_size=batch_size,
+)
+
+"""
+## Train the skip-gram model
+
+Our skip-gram is a simple binary classification model that works as follows:
+
+1. An embedding is looked up for the `target` movie.
+2. An embedding is looked up for the `context` movie.
+3. The dot product is computed between these two embeddings.
+4. The result is interpreted as a logit and compared to the label.
+5. A binary crossentropy loss with `from_logits=True` is used.
+"""
+
+learning_rate = 0.001
+embedding_dim = 50
+num_epochs = 10
+
+"""
+### Implement the model
+"""
+
+
+def create_model(vocabulary_size, embedding_dim):
+ inputs = {
+ "target": layers.Input(name="target", shape=(), dtype="int32"),
+ "context": layers.Input(name="context", shape=(), dtype="int32"),
+ }
+ # Initialize item embeddings.
+ embed_item = layers.Embedding(
+ input_dim=vocabulary_size,
+ output_dim=embedding_dim,
+ embeddings_initializer="he_normal",
+ embeddings_regularizer=keras.regularizers.l2(1e-6),
+ name="item_embeddings",
+ )
+ # Lookup embeddings for target.
+ target_embeddings = embed_item(inputs["target"])
+ # Lookup embeddings for context.
+ context_embeddings = embed_item(inputs["context"])
+ # Compute dot similarity between target and context embeddings.
+ logits = layers.Dot(axes=1, normalize=False, name="dot_similarity")(
+ [target_embeddings, context_embeddings]
+ )
+ # Create the model.
+ model = keras.Model(inputs=inputs, outputs=logits)
+ return model
+
+
+"""
+### Train the model
+"""
+
+"""
+We instantiate the model and compile it.
+"""
+
+model = create_model(len(vocabulary), embedding_dim)
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=keras.losses.BinaryCrossentropy(from_logits=True),
+)
+
+"""
+Let's plot the model.
+"""
+
+keras.utils.plot_model(
+ model,
+ show_shapes=True,
+ show_dtype=True,
+ show_layer_names=True,
+)
+
+"""
+Now we train the model on the `dataset`.
+"""
+
+history = model.fit(dataset, epochs=num_epochs)
+
+"""
+Finally we plot the learning history.
+"""
+
+plt.plot(history.history["loss"])
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.show()
+
+"""
+## Analyze the learnt embeddings
+"""
+
+movie_embeddings = model.get_layer("item_embeddings").get_weights()[0]
+print("Embeddings shape:", movie_embeddings.shape)
+
+"""
+### Find related movies
+
+Define a list with some movies called `query_movies`.
+"""
+
+query_movies = [
+ "Matrix, The (1999)",
+ "Star Wars: Episode IV - A New Hope (1977)",
+ "Lion King, The (1994)",
+ "Terminator 2: Judgment Day (1991)",
+ "Godfather, The (1972)",
+]
+
+"""
+Get the embeddings of the movies in `query_movies`.
+"""
+
+query_embeddings = []
+
+for movie_title in query_movies:
+ movieId = get_movie_id_by_title(movie_title)
+ token_id = vocabulary_lookup[movieId]
+ movie_embedding = movie_embeddings[token_id]
+ query_embeddings.append(movie_embedding)
+
+query_embeddings = np.array(query_embeddings)
+
+"""
+Compute the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between the embeddings of `query_movies`
+and all the other movies, then pick the top k for each.
+"""
+
+similarities = tf.linalg.matmul(
+ tf.math.l2_normalize(query_embeddings),
+ tf.math.l2_normalize(movie_embeddings),
+ transpose_b=True,
+)
+
+_, indices = tf.math.top_k(similarities, k=5)
+indices = indices.numpy().tolist()
+
+"""
+Display the top related movies for each movie in `query_movies`.
+"""
+
+for idx, title in enumerate(query_movies):
+ print(title)
+ print("".rjust(len(title), "-"))
+ similar_tokens = indices[idx]
+ for token in similar_tokens:
+ similar_movieId = vocabulary[token]
+ similar_title = get_movie_title_by_id(similar_movieId)
+ print(f"- {similar_title}")
+ print()
+
+"""
+### Visualize the embeddings using the Embedding Projector
+"""
+
+import io
+
+out_v = io.open("embeddings.tsv", "w", encoding="utf-8")
+out_m = io.open("metadata.tsv", "w", encoding="utf-8")
+
+for idx, movie_id in enumerate(vocabulary[1:]):
+ movie_title = list(movies[movies.movieId == movie_id].title)[0]
+ vector = movie_embeddings[idx]
+ out_v.write("\t".join([str(x) for x in vector]) + "\n")
+ out_m.write(movie_title + "\n")
+
+out_v.close()
+out_m.close()
+
+"""
+Download the `embeddings.tsv` and `metadata.tsv` to analyze the obtained embeddings
+in the [Embedding Projector](https://projector.tensorflow.org/).
+"""
+
+"""
+
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/Node2Vec_MovieLens) | [](https://huggingface.co/spaces/keras-io/Node2Vec_MovieLens) |
+"""
diff --git a/knowledge_base/nlp/abstractive_summarization_with_bart.py b/knowledge_base/nlp/abstractive_summarization_with_bart.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1c2d495a62acaa7b6cd0ee0510ff007bd89a64f
--- /dev/null
+++ b/knowledge_base/nlp/abstractive_summarization_with_bart.py
@@ -0,0 +1,240 @@
+"""
+Title: Abstractive Text Summarization with BART
+Author: [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2023/07/08
+Last modified: 2024/03/20
+Description: Use KerasHub to fine-tune BART on the abstractive summarization task.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+In the era of information overload, it has become crucial to extract the crux
+of a long document or a conversation and express it in a few sentences. Because
+summarization has widespread applications in different domains, it has become
+a key, well-studied NLP task in recent years.
+
+[Bidirectional Autoregressive Transformer (BART)](https://arxiv.org/abs/1910.13461)
+is a Transformer-based encoder-decoder model, often used for
+sequence-to-sequence tasks like summarization and neural machine translation.
+BART is pre-trained in a self-supervised fashion on a large text corpus. During
+pre-training, the text is corrupted and BART is trained to reconstruct the
+original text (hence called a "denoising autoencoder"). Some pre-training tasks
+include token masking, token deletion, sentence permutation (shuffle sentences
+and train BART to fix the order), etc.
+
+In this example, we will demonstrate how to fine-tune BART on the abstractive
+summarization task (on conversations!) using KerasHub, and generate summaries
+using the fine-tuned model.
+"""
+
+"""
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasHub library. We will also need a
+couple of utility libraries.
+"""
+
+"""shell
+pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
+"""
+
+"""
+This example uses [Keras 3](https://keras.io/keras_3/) to work in any of the
+`"tensorflow"`, `"jax"` or `"torch"` backends. Support for Keras 3 is baked into
+KerasHub: simply change the `"KERAS_BACKEND"` environment variable to select
+the backend of your choice. We select the JAX backend below.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+"""
+Import all necessary libraries.
+"""
+
+import py7zr
+import time
+
+import keras_hub
+import keras
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+"""
+Let's also define our hyperparameters.
+"""
+
+BATCH_SIZE = 8
+NUM_BATCHES = 600
+EPOCHS = 1 # Can be set to a higher value for better results
+MAX_ENCODER_SEQUENCE_LENGTH = 512
+MAX_DECODER_SEQUENCE_LENGTH = 128
+MAX_GENERATION_LENGTH = 40
+
+"""
+## Dataset
+
+Let's load the [SAMSum dataset](https://arxiv.org/abs/1911.12237). This dataset
+contains around 15,000 pairs of conversations/dialogues and summaries.
+"""
+
+# Download the dataset.
+filename = keras.utils.get_file(
+ "corpus.7z",
+ origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
+)
+
+# Extract the `.7z` file.
+with py7zr.SevenZipFile(filename, mode="r") as z:
+ z.extractall(path="/root/tensorflow_datasets/downloads/manual")
+
+# Load data using TFDS.
+samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
+
+"""
+The dataset has two fields: `dialogue` and `summary`. Let's see a sample.
+"""
+for dialogue, summary in samsum_ds:
+ print(dialogue.numpy())
+ print(summary.numpy())
+ break
+
+"""
+We'll now batch the dataset and retain only a subset of it for the purpose of
+this example. The dialogue is fed to the encoder, and the corresponding summary
+serves as input to the decoder. We will, therefore, change the format of the
+dataset to a dictionary having two keys: `"encoder_text"` and `"decoder_text"`.
+This is the input format that `keras_hub.models.BartSeq2SeqLMPreprocessor`
+expects.
+"""
+
+train_ds = (
+ samsum_ds.map(
+ lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
+ )
+ .batch(BATCH_SIZE)
+ .cache()
+)
+train_ds = train_ds.take(NUM_BATCHES)
+
+"""
+## Fine-tune BART
+
+Let's load the model and preprocessor first. We use sequence lengths of 512
+and 128 for the encoder and decoder, respectively, instead of 1024 (which is the
+default sequence length). This will allow us to run this example quickly
+on Colab.
+
+If you observe carefully, the preprocessor is attached to the model. What this
+means is that we don't have to worry about preprocessing the text inputs;
+everything will be done internally. The preprocessor tokenizes the encoder text
+and the decoder text, adds special tokens and pads them. To generate labels
+for auto-regressive training, the preprocessor shifts the decoder text one
+position to the right. This is done because at every timestep, the model is
+trained to predict the next token.
+"""
+
+preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
+ "bart_base_en",
+ encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
+ decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
+)
+bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
+ "bart_base_en", preprocessor=preprocessor
+)
+
+bart_lm.summary()
+
+"""
+Define the optimizer and loss. We use the Adam optimizer with a linearly
+decaying learning rate. Compile the model.
+"""
+
+optimizer = keras.optimizers.AdamW(
+ learning_rate=5e-5,
+ weight_decay=0.01,
+ epsilon=1e-6,
+ global_clipnorm=1.0, # Gradient clipping.
+)
+# Exclude layernorm and bias terms from weight decay.
+optimizer.exclude_from_weight_decay(var_names=["bias"])
+optimizer.exclude_from_weight_decay(var_names=["gamma"])
+optimizer.exclude_from_weight_decay(var_names=["beta"])
+
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+bart_lm.compile(
+ optimizer=optimizer,
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+"""
+Let's train the model!
+"""
+
+bart_lm.fit(train_ds, epochs=EPOCHS)
+
+"""
+## Generate summaries and evaluate them!
+
+Now that the model has been trained, let's get to the fun part - actually
+generating summaries! Let's pick the first 100 samples from the validation set
+and generate summaries for them. We will use the default decoding strategy, i.e.,
+greedy search.
+
+Generation in KerasHub is highly optimized. It is backed by the power of XLA.
+Secondly, key/value tensors in the self-attention layer and cross-attention layer
+in the decoder are cached to avoid recomputation at every timestep.
+"""
+
+
+def generate_text(model, input_text, max_length=200, print_time_taken=False):
+    start = time.time()
+    output = model.generate(input_text, max_length=max_length)
+    end = time.time()
+    if print_time_taken:
+        print(f"Total Time Elapsed: {end - start:.2f}s")
+    return output
+
+
+# Load the dataset.
+val_ds = tfds.load("samsum", split="validation", as_supervised=True)
+val_ds = val_ds.take(100)
+
+dialogues = []
+ground_truth_summaries = []
+for dialogue, summary in val_ds:
+ dialogues.append(dialogue.numpy())
+ ground_truth_summaries.append(summary.numpy())
+
+# Let's make a dummy call - the first call to XLA generally takes a bit longer.
+_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
+
+# Generate summaries.
+generated_summaries = generate_text(
+ bart_lm,
+ val_ds.map(lambda dialogue, _: dialogue).batch(8),
+ max_length=MAX_GENERATION_LENGTH,
+ print_time_taken=True,
+)
+
+"""
+Let's see some of the summaries.
+"""
+for dialogue, generated_summary, ground_truth_summary in zip(
+ dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
+):
+ print("Dialogue:", dialogue)
+ print("Generated Summary:", generated_summary)
+ print("Ground Truth Summary:", ground_truth_summary)
+ print("=============================")
+
+"""
+The generated summaries look awesome! Not bad for a model trained only for 1
+epoch and on 5000 examples :)
+"""
diff --git a/knowledge_base/nlp/active_learning_review_classification.py b/knowledge_base/nlp/active_learning_review_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6ec81780ec02c24e31cd92ffca4437f984aaf4a
--- /dev/null
+++ b/knowledge_base/nlp/active_learning_review_classification.py
@@ -0,0 +1,524 @@
+"""
+Title: Review Classification using Active Learning
+Author: [Darshan Deshpande](https://twitter.com/getdarshan)
+Date created: 2021/10/29
+Last modified: 2024/05/08
+Description: Demonstrating the advantages of active learning through review classification.
+Accelerator: GPU
+Converted to Keras 3 by: [Sachin Prasad](https://github.com/sachinprasadhs)
+"""
+
+"""
+## Introduction
+
+With the growth of data-centric Machine Learning, Active Learning has grown in popularity
+amongst businesses and researchers. Active Learning seeks to progressively
+train ML models so that the resultant model requires a smaller amount of training data to
+achieve competitive scores.
+
+The structure of an Active Learning pipeline involves a classifier and an oracle. The
+oracle is an annotator that cleans, selects, labels the data, and feeds it to the model
+when required. The oracle is a trained individual or a group of individuals that
+ensure consistency in labeling of new data.
+
+The process starts with annotating a small subset of the full dataset and training an
+initial model. The best model checkpoint is saved and then tested on a balanced test
+set. The test set must be carefully sampled because the full training process will be
+dependent on it. Once we have the initial evaluation scores, the oracle is tasked with
+labeling more samples; the number of data points to be sampled is usually determined by
+the business requirements. After that, the newly sampled data is added to the training
+set, and the training procedure repeats. This cycle continues until either an
+acceptable score is reached or some other business metric is met.
+
+This tutorial provides a basic demonstration of how Active Learning works,
+using a ratio-based (least confidence) sampling strategy that results in lower
+overall false positive and negative rates when compared to a model trained on the entire
+dataset. This sampling falls under the domain of *uncertainty sampling*, in which new
+datasets are sampled based on the uncertainty that the model outputs for the
+corresponding label. In our example, we compare our model's false positive and false
+negative rates and annotate the new data based on their ratio.
+
+Some other sampling techniques include:
+
+1. [Committee sampling](https://www.researchgate.net/publication/51909346_Committee-Based_Sample_Selection_for_Probabilistic_Classifiers):
+Using multiple models to vote for the best data points to be sampled
+2. [Entropy reduction](https://www.researchgate.net/publication/51909346_Committee-Based_Sample_Selection_for_Probabilistic_Classifiers):
+Sampling according to an entropy threshold, selecting more of the samples that produce the highest entropy score.
+3. [Minimum margin based sampling](https://arxiv.org/abs/1906.00025v1):
+Selects data points closest to the decision boundary
+"""
+
+"""
+## Importing required libraries
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow" # @param ["tensorflow", "jax", "torch"]
+import keras
+from keras import ops
+from keras import layers
+import tensorflow_datasets as tfds
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+tfds.disable_progress_bar()
+
+"""
+## Loading and preprocessing the data
+
+We will be using the IMDB reviews dataset for our experiments. This dataset has 50,000
+reviews in total, including training and testing splits. We will merge these splits and
+sample our own, balanced training, validation and testing sets.
+"""
+
+dataset = tfds.load(
+ "imdb_reviews",
+ split="train + test",
+ as_supervised=True,
+ batch_size=-1,
+ shuffle_files=False,
+)
+reviews, labels = tfds.as_numpy(dataset)
+
+print("Total examples:", reviews.shape[0])
+
+"""
+Active learning starts with labeling a subset of data.
+For the ratio sampling technique that we will be using, we will need well-balanced training,
+validation and testing splits.
+"""
+
+val_split = 2500
+test_split = 2500
+train_split = 7500
+
+# Separating the negative and positive samples for manual stratification
+x_positives, y_positives = reviews[labels == 1], labels[labels == 1]
+x_negatives, y_negatives = reviews[labels == 0], labels[labels == 0]
+
+# Creating training, validation and testing splits
+x_val, y_val = (
+ tf.concat((x_positives[:val_split], x_negatives[:val_split]), 0),
+ tf.concat((y_positives[:val_split], y_negatives[:val_split]), 0),
+)
+x_test, y_test = (
+ tf.concat(
+ (
+ x_positives[val_split : val_split + test_split],
+ x_negatives[val_split : val_split + test_split],
+ ),
+ 0,
+ ),
+ tf.concat(
+ (
+ y_positives[val_split : val_split + test_split],
+ y_negatives[val_split : val_split + test_split],
+ ),
+ 0,
+ ),
+)
+x_train, y_train = (
+ tf.concat(
+ (
+ x_positives[val_split + test_split : val_split + test_split + train_split],
+ x_negatives[val_split + test_split : val_split + test_split + train_split],
+ ),
+ 0,
+ ),
+ tf.concat(
+ (
+ y_positives[val_split + test_split : val_split + test_split + train_split],
+ y_negatives[val_split + test_split : val_split + test_split + train_split],
+ ),
+ 0,
+ ),
+)
+
+# The remaining pool of samples is stored separately. These are labeled only as and when required
+x_pool_positives, y_pool_positives = (
+ x_positives[val_split + test_split + train_split :],
+ y_positives[val_split + test_split + train_split :],
+)
+x_pool_negatives, y_pool_negatives = (
+ x_negatives[val_split + test_split + train_split :],
+ y_negatives[val_split + test_split + train_split :],
+)
+
+# Creating TF Datasets for faster prefetching and parallelization
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+
+pool_negatives = tf.data.Dataset.from_tensor_slices(
+ (x_pool_negatives, y_pool_negatives)
+)
+pool_positives = tf.data.Dataset.from_tensor_slices(
+ (x_pool_positives, y_pool_positives)
+)
+
+print(f"Initial training set size: {len(train_dataset)}")
+print(f"Validation set size: {len(val_dataset)}")
+print(f"Testing set size: {len(test_dataset)}")
+print(f"Unlabeled negative pool: {len(pool_negatives)}")
+print(f"Unlabeled positive pool: {len(pool_positives)}")
+
+"""
+### Fitting the `TextVectorization` layer
+
+Since we are working with text data, we will need to encode the text strings as vectors which
+would then be passed through an `Embedding` layer. To make this tokenization process
+faster, we use the `map()` function with its parallelization functionality.
+"""
+
+
+vectorizer = layers.TextVectorization(
+ 3000, standardize="lower_and_strip_punctuation", output_sequence_length=150
+)
+# Adapting the dataset
+vectorizer.adapt(
+ train_dataset.map(lambda x, y: x, num_parallel_calls=tf.data.AUTOTUNE).batch(256)
+)
+
+
+def vectorize_text(text, label):
+ text = vectorizer(text)
+ return text, label
+
+
+train_dataset = train_dataset.map(
+ vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
+).prefetch(tf.data.AUTOTUNE)
+pool_negatives = pool_negatives.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
+pool_positives = pool_positives.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
+
+val_dataset = val_dataset.batch(256).map(
+ vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
+)
+test_dataset = test_dataset.batch(256).map(
+ vectorize_text, num_parallel_calls=tf.data.AUTOTUNE
+)
+
+"""
+## Creating Helper Functions
+"""
+
+
+# Helper function for merging new history objects with older ones
+def append_history(losses, val_losses, accuracy, val_accuracy, history):
+ losses = losses + history.history["loss"]
+ val_losses = val_losses + history.history["val_loss"]
+ accuracy = accuracy + history.history["binary_accuracy"]
+ val_accuracy = val_accuracy + history.history["val_binary_accuracy"]
+ return losses, val_losses, accuracy, val_accuracy
+
+
+# Plotter function
+def plot_history(losses, val_losses, accuracies, val_accuracies):
+ plt.plot(losses)
+ plt.plot(val_losses)
+ plt.legend(["train_loss", "val_loss"])
+ plt.xlabel("Epochs")
+ plt.ylabel("Loss")
+ plt.show()
+
+ plt.plot(accuracies)
+ plt.plot(val_accuracies)
+ plt.legend(["train_accuracy", "val_accuracy"])
+ plt.xlabel("Epochs")
+ plt.ylabel("Accuracy")
+ plt.show()
+
+
+"""
+## Creating the Model
+
+We create a small bidirectional LSTM model. When using Active Learning, you should make sure
+that the model architecture is capable of overfitting to the initial data.
+Overfitting gives a strong hint that the model will have enough capacity for
+future, unseen data.
+"""
+
+
+def create_model():
+ model = keras.models.Sequential(
+ [
+ layers.Input(shape=(150,)),
+ layers.Embedding(input_dim=3000, output_dim=128),
+ layers.Bidirectional(layers.LSTM(32, return_sequences=True)),
+ layers.GlobalMaxPool1D(),
+ layers.Dense(20, activation="relu"),
+ layers.Dropout(0.5),
+ layers.Dense(1, activation="sigmoid"),
+ ]
+ )
+ model.summary()
+ return model
+
+
+"""
+## Training on the entire dataset
+
+To show the effectiveness of Active Learning, we will first train the model on the entire
+dataset containing 40,000 labeled samples. This model will be used for comparison later.
+"""
+
+
+def train_full_model(full_train_dataset, val_dataset, test_dataset):
+ model = create_model()
+ model.compile(
+ loss="binary_crossentropy",
+ optimizer="rmsprop",
+ metrics=[
+ keras.metrics.BinaryAccuracy(),
+ keras.metrics.FalseNegatives(),
+ keras.metrics.FalsePositives(),
+ ],
+ )
+
+ # We will save the best model at every epoch and load the best one for evaluation on the test set
+ history = model.fit(
+ full_train_dataset.batch(256),
+ epochs=20,
+ validation_data=val_dataset,
+ callbacks=[
+ keras.callbacks.EarlyStopping(patience=4, verbose=1),
+ keras.callbacks.ModelCheckpoint(
+ "FullModelCheckpoint.keras", verbose=1, save_best_only=True
+ ),
+ ],
+ )
+
+ # Plot history
+ plot_history(
+ history.history["loss"],
+ history.history["val_loss"],
+ history.history["binary_accuracy"],
+ history.history["val_binary_accuracy"],
+ )
+
+ # Loading the best checkpoint
+ model = keras.models.load_model("FullModelCheckpoint.keras")
+
+ print("-" * 100)
+ print(
+ "Test set evaluation: ",
+ model.evaluate(test_dataset, verbose=0, return_dict=True),
+ )
+ print("-" * 100)
+ return model
+
+
+# Sampling the full train dataset to train on
+full_train_dataset = (
+ train_dataset.concatenate(pool_positives)
+ .concatenate(pool_negatives)
+ .cache()
+ .shuffle(20000)
+)
+
+# Training the full model
+full_dataset_model = train_full_model(full_train_dataset, val_dataset, test_dataset)
+
+"""
+## Training via Active Learning
+
+The general process we follow when performing Active Learning can be summarized in
+five parts:
+
+1. Sample and annotate a small, balanced training dataset
+2. Train the model on this small subset
+3. Evaluate the model on a balanced testing set
+4. If the model satisfies the business criteria, deploy it in a real time setting
+5. If it doesn't pass the criteria, sample a few more samples according to the ratio of
+false positives and negatives, add them to the training set and repeat from step 2 till
+the model passes the tests or till all available data is exhausted.
+
+For the code below, we will perform sampling using the following formula:
+
+
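+`number of samples of a class to label = sampling_size * (samples of that class incorrectly classified / total incorrectly classified)`
+
+(the same ratio is restated in the code comments below).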
+
+Active Learning techniques use callbacks extensively for progress tracking. We will be
+using model checkpointing and early stopping for this example. The `patience` parameter
+for Early Stopping can help minimize overfitting and the time required. We have set it
+`patience=4` for now but since the model is robust, we can increase the patience level if
+desired.
+
+Note: We are not loading the checkpoint after the first training iteration. In my
+experience working on Active Learning techniques, this helps the model probe the
+newly formed loss landscape. Even if the model fails to improve in the second iteration,
+we will still gain insight about the possible future false positive and negative rates.
+This will help us sample a better set in the next iteration where the model will have a
+greater chance to improve.
+"""
+
+
+def train_active_learning_models(
+ train_dataset,
+ pool_negatives,
+ pool_positives,
+ val_dataset,
+ test_dataset,
+ num_iterations=3,
+ sampling_size=5000,
+):
+
+ # Creating lists for storing metrics
+ losses, val_losses, accuracies, val_accuracies = [], [], [], []
+
+ model = create_model()
+ # We will monitor the false positives and false negatives predicted by our model
+ # These will decide the subsequent sampling ratio for every Active Learning loop
+ model.compile(
+ loss="binary_crossentropy",
+ optimizer="rmsprop",
+ metrics=[
+ keras.metrics.BinaryAccuracy(),
+ keras.metrics.FalseNegatives(),
+ keras.metrics.FalsePositives(),
+ ],
+ )
+
+ # Defining checkpoints.
+ # The checkpoint callback is reused throughout the training since it only saves the best overall model.
+ checkpoint = keras.callbacks.ModelCheckpoint(
+ "AL_Model.keras", save_best_only=True, verbose=1
+ )
+ # Here, patience is set to 4. This can be set higher if desired.
+ early_stopping = keras.callbacks.EarlyStopping(patience=4, verbose=1)
+
+ print(f"Starting to train with {len(train_dataset)} samples")
+ # Initial fit with a small subset of the training set
+ history = model.fit(
+ train_dataset.cache().shuffle(20000).batch(256),
+ epochs=20,
+ validation_data=val_dataset,
+ callbacks=[checkpoint, early_stopping],
+ )
+
+ # Appending history
+ losses, val_losses, accuracies, val_accuracies = append_history(
+ losses, val_losses, accuracies, val_accuracies, history
+ )
+
+ for iteration in range(num_iterations):
+ # Getting predictions from previously trained model
+ predictions = model.predict(test_dataset)
+
+ # Generating labels from the output probabilities
+ rounded = ops.where(ops.greater(predictions, 0.5), 1, 0)
+
+        # Evaluating the number of zeros and ones incorrectly classified
+ _, _, false_negatives, false_positives = model.evaluate(test_dataset, verbose=0)
+
+ print("-" * 100)
+ print(
+ f"Number of zeros incorrectly classified: {false_negatives}, Number of ones incorrectly classified: {false_positives}"
+ )
+
+ # This technique of Active Learning demonstrates ratio based sampling where
+ # Number of ones/zeros to sample = Number of ones/zeros incorrectly classified / Total incorrectly classified
+ if false_negatives != 0 and false_positives != 0:
+ total = false_negatives + false_positives
+ sample_ratio_ones, sample_ratio_zeros = (
+ false_positives / total,
+ false_negatives / total,
+ )
+ # In the case where all samples are correctly predicted, we can sample both classes equally
+ else:
+ sample_ratio_ones, sample_ratio_zeros = 0.5, 0.5
+
+ print(
+ f"Sample ratio for positives: {sample_ratio_ones}, Sample ratio for negatives:{sample_ratio_zeros}"
+ )
+
+ # Sample the required number of ones and zeros
+ sampled_dataset = pool_negatives.take(
+ int(sample_ratio_zeros * sampling_size)
+ ).concatenate(pool_positives.take(int(sample_ratio_ones * sampling_size)))
+
+        # Skip the sampled data points to avoid sampling them again
+ pool_negatives = pool_negatives.skip(int(sample_ratio_zeros * sampling_size))
+ pool_positives = pool_positives.skip(int(sample_ratio_ones * sampling_size))
+
+ # Concatenating the train_dataset with the sampled_dataset
+ train_dataset = train_dataset.concatenate(sampled_dataset).prefetch(
+ tf.data.AUTOTUNE
+ )
+
+ print(f"Starting training with {len(train_dataset)} samples")
+ print("-" * 100)
+
+ # We recompile the model to reset the optimizer states and retrain the model
+ model.compile(
+ loss="binary_crossentropy",
+ optimizer="rmsprop",
+ metrics=[
+ keras.metrics.BinaryAccuracy(),
+ keras.metrics.FalseNegatives(),
+ keras.metrics.FalsePositives(),
+ ],
+ )
+ history = model.fit(
+ train_dataset.cache().shuffle(20000).batch(256),
+ validation_data=val_dataset,
+ epochs=20,
+ callbacks=[
+ checkpoint,
+ keras.callbacks.EarlyStopping(patience=4, verbose=1),
+ ],
+ )
+
+ # Appending the history
+ losses, val_losses, accuracies, val_accuracies = append_history(
+ losses, val_losses, accuracies, val_accuracies, history
+ )
+
+ # Loading the best model from this training loop
+ model = keras.models.load_model("AL_Model.keras")
+
+ # Plotting the overall history and evaluating the final model
+ plot_history(losses, val_losses, accuracies, val_accuracies)
+ print("-" * 100)
+ print(
+ "Test set evaluation: ",
+ model.evaluate(test_dataset, verbose=0, return_dict=True),
+ )
+ print("-" * 100)
+
+ return model
+
+
+active_learning_model = train_active_learning_models(
+ train_dataset, pool_negatives, pool_positives, val_dataset, test_dataset
+)
+
+"""
+## Conclusion
+
+Active Learning is a growing area of research. This example demonstrates the
+cost-efficiency benefits of Active Learning, since it reduces the amount of data
+that must be annotated, saving resources.
+
+The following are some noteworthy observations from this example:
+
+1. We only require 30,000 samples to reach the same (if not better) scores as the model
+trained on the full dataset. This means that in a real-life setting, we save the effort
+required for annotating 10,000 samples!
+2. The numbers of false negatives and false positives are well balanced at the end of
+the training, as compared to the skewed ratio obtained from the full training. This makes
+the model slightly more useful in real-life scenarios where both labels hold equal
+importance.
+
+For further reading about types of sampling ratios, training techniques, or available
+open-source libraries/implementations, you can refer to the resources below:
+
+1. [Active Learning Literature Survey](http://burrsettles.com/pub/settles.activelearning.pdf) (Burr Settles, 2010).
+2. [modAL](https://github.com/modAL-python/modAL): A Modular Active Learning framework.
+3. Google's unofficial [Active Learning playground](https://github.com/google/active-learning).
+"""
diff --git a/knowledge_base/nlp/addition_rnn.py b/knowledge_base/nlp/addition_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..409d64568e139149824a080a835aab1af88cb0e8
--- /dev/null
+++ b/knowledge_base/nlp/addition_rnn.py
@@ -0,0 +1,251 @@
+"""
+Title: Sequence to sequence learning for performing number addition
+Author: [Smerity](https://twitter.com/Smerity) and others
+Date created: 2015/08/17
+Last modified: 2024/02/13
+Description: A model that learns to add strings of numbers, e.g. "535+61" -> "596".
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we train a model to learn to add two numbers, provided as strings.
+
+**Example:**
+
+- Input: "535+61"
+- Output: "596"
+
+Input may optionally be reversed, which was shown to increase performance in many tasks
+in: [Learning to Execute](http://arxiv.org/abs/1410.4615) and
+[Sequence to Sequence Learning with Neural Networks](http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf).
+
+Theoretically, sequence order inversion introduces shorter-term dependencies between
+source and target for this problem.
+
+**Results:**
+
+For two digits (reversed):
+
++ One layer LSTM (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
+
+Three digits (reversed):
+
++ One layer LSTM (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
+
+Four digits (reversed):
+
++ One layer LSTM (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
+
+Five digits (reversed):
+
++ One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import layers
+import numpy as np
+
+# Parameters for the model and dataset.
+TRAINING_SIZE = 50000
+DIGITS = 3
+REVERSE = True
+
+# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
+# int is DIGITS.
+MAXLEN = DIGITS + 1 + DIGITS
+
+"""
+## Generate the data
+"""
+
+
+class CharacterTable:
+ """Given a set of characters:
+ + Encode them to a one-hot integer representation
+ + Decode the one-hot or integer representation to their character output
+ + Decode a vector of probabilities to their character output
+ """
+
+ def __init__(self, chars):
+ """Initialize character table.
+ # Arguments
+ chars: Characters that can appear in the input.
+ """
+ self.chars = sorted(set(chars))
+ self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
+ self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
+
+ def encode(self, C, num_rows):
+ """One-hot encode given string C.
+ # Arguments
+ C: string, to be encoded.
+            num_rows: Number of rows in the returned one-hot encoding. This is
+                used to keep the number of rows the same for every sample.
+ """
+ x = np.zeros((num_rows, len(self.chars)))
+ for i, c in enumerate(C):
+ x[i, self.char_indices[c]] = 1
+ return x
+
+ def decode(self, x, calc_argmax=True):
+ """Decode the given vector or 2D array to their character output.
+ # Arguments
+ x: A vector or a 2D array of probabilities or one-hot representations;
+ or a vector of character indices (used with `calc_argmax=False`).
+ calc_argmax: Whether to find the character index with maximum
+ probability, defaults to `True`.
+ """
+ if calc_argmax:
+ x = x.argmax(axis=-1)
+ return "".join(self.indices_char[x] for x in x)
+
+
+# All the numbers, plus sign and space for padding.
+chars = "0123456789+ "
+ctable = CharacterTable(chars)
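+
+# Quick illustrative sanity check: encoding then decoding round-trips a string.
+# Unused rows are all zeros and decode to the space character, which sorts
+# first in the vocabulary, so short strings come back space-padded to MAXLEN.
+assert ctable.decode(ctable.encode("12+3", MAXLEN)) == "12+3   "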
+
+questions = []
+expected = []
+seen = set()
+print("Generating data...")
+while len(questions) < TRAINING_SIZE:
+ f = lambda: int(
+ "".join(
+ np.random.choice(list("0123456789"))
+ for i in range(np.random.randint(1, DIGITS + 1))
+ )
+ )
+ a, b = f(), f()
+    # Skip any addition questions we've already seen.
+    # Also skip any such that a+b == b+a (hence the sorting).
+ key = tuple(sorted((a, b)))
+ if key in seen:
+ continue
+ seen.add(key)
+ # Pad the data with spaces such that it is always MAXLEN.
+ q = "{}+{}".format(a, b)
+ query = q + " " * (MAXLEN - len(q))
+ ans = str(a + b)
+ # Answers can be of maximum size DIGITS + 1.
+ ans += " " * (DIGITS + 1 - len(ans))
+ if REVERSE:
+ # Reverse the query, e.g., '12+345 ' becomes ' 543+21'. (Note the
+ # space used for padding.)
+ query = query[::-1]
+ questions.append(query)
+ expected.append(ans)
+print("Total questions:", len(questions))
+
+"""
+## Vectorize the data
+"""
+
+print("Vectorization...")
+x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
+y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
+for i, sentence in enumerate(questions):
+ x[i] = ctable.encode(sentence, MAXLEN)
+for i, sentence in enumerate(expected):
+ y[i] = ctable.encode(sentence, DIGITS + 1)
+
+# Shuffle (x, y) in unison as the later parts of x will almost all be larger
+# digits.
+indices = np.arange(len(y))
+np.random.shuffle(indices)
+x = x[indices]
+y = y[indices]
+
+# Explicitly set apart 10% for validation data that we never train over.
+split_at = len(x) - len(x) // 10
+(x_train, x_val) = x[:split_at], x[split_at:]
+(y_train, y_val) = y[:split_at], y[split_at:]
+
+print("Training Data:")
+print(x_train.shape)
+print(y_train.shape)
+
+print("Validation Data:")
+print(x_val.shape)
+print(y_val.shape)
+
+"""
+## Build the model
+"""
+
+print("Build model...")
+num_layers = 1 # Try to add more LSTM layers!
+
+model = keras.Sequential()
+# "Encode" the input sequence using a LSTM, producing an output of size 128.
+# Note: In a situation where your input sequences have a variable length,
+# use input_shape=(None, num_features).
+model.add(layers.Input((MAXLEN, len(chars))))
+model.add(layers.LSTM(128))
+# As the decoder RNN's input, repeatedly provide the last output of the
+# encoder RNN for each time step. Repeat 'DIGITS + 1' times, as that's the
+# maximum length of the output, e.g., when DIGITS=3, max output is 999+999=1998.
+model.add(layers.RepeatVector(DIGITS + 1))
+# The decoder RNN could be multiple layers stacked or a single layer.
+for _ in range(num_layers):
+    # By setting return_sequences to True, return not only the last output but
+    # all the outputs so far in the form of (num_samples, timesteps,
+    # output_dim). This is necessary because the Dense layer below is applied
+    # to every temporal slice, so it expects the time dimension to be kept.
+ model.add(layers.LSTM(128, return_sequences=True))
+
+# Apply a dense layer to every temporal slice of the input. For each step
+# of the output sequence, decide which character should be chosen.
+model.add(layers.Dense(len(chars), activation="softmax"))
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+model.summary()
+
+"""
+## Train the model
+"""
+
+# Training parameters.
+epochs = 30
+batch_size = 32
+
+# Formatting characters for results display.
+green_color = "\033[92m"
+red_color = "\033[91m"
+end_char = "\033[0m"
+
+# Train the model each generation and show predictions against the validation
+# dataset.
+for epoch in range(1, epochs + 1):
+ print()
+ print("Iteration", epoch)
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=1,
+ validation_data=(x_val, y_val),
+ )
+ # Select 10 samples from the validation set at random so we can visualize
+ # errors.
+ for i in range(10):
+ ind = np.random.randint(0, len(x_val))
+ rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
+ preds = np.argmax(model.predict(rowx, verbose=0), axis=-1)
+ q = ctable.decode(rowx[0])
+ correct = ctable.decode(rowy[0])
+ guess = ctable.decode(preds[0], calc_argmax=False)
+ print("Q", q[::-1] if REVERSE else q, end=" ")
+ print("T", correct, end=" ")
+ if correct == guess:
+ print(f"{green_color}โ {guess}{end_char}")
+ else:
+ print(f"{red_color}โ {guess}{end_char}")
+
+"""
+You'll get to 99+% validation accuracy after ~30 epochs.
+"""
diff --git a/knowledge_base/nlp/bidirectional_lstm_imdb.py b/knowledge_base/nlp/bidirectional_lstm_imdb.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52814532e9207c51925e76e04b555c9224bbbda
--- /dev/null
+++ b/knowledge_base/nlp/bidirectional_lstm_imdb.py
@@ -0,0 +1,59 @@
+"""
+Title: Bidirectional LSTM on IMDB
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/05/03
+Last modified: 2020/05/03
+Description: Train a 2-layer bidirectional LSTM on the IMDB movie review sentiment classification dataset.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+from keras import layers
+
+max_features = 20000 # Only consider the top 20k words
+maxlen = 200 # Only consider the first 200 words of each movie review
+
+"""
+## Build the model
+"""
+
+# Input for variable-length sequences of integers
+inputs = keras.Input(shape=(None,), dtype="int32")
+# Embed each integer in a 128-dimensional vector
+x = layers.Embedding(max_features, 128)(inputs)
+# Add 2 bidirectional LSTMs
+x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)
+x = layers.Bidirectional(layers.LSTM(64))(x)
+# Add a classifier
+outputs = layers.Dense(1, activation="sigmoid")(x)
+model = keras.Model(inputs, outputs)
+model.summary()
+
+"""
+## Load the IMDB movie review sentiment data
+"""
+
+(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(
+ num_words=max_features
+)
+print(len(x_train), "Training sequences")
+print(len(x_val), "Validation sequences")
+# Use pad_sequences to standardize sequence length:
+# this will truncate sequences longer than 200 words and zero-pad sequences shorter than 200 words.
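+# By default, pad_sequences pads and truncates at the front ("pre"), e.g.
+# keras.utils.pad_sequences([[1, 2, 3]], maxlen=5) -> [[0, 0, 1, 2, 3]].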
+x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
+x_val = keras.utils.pad_sequences(x_val, maxlen=maxlen)
+
+"""
+## Train and evaluate the model
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/bidirectional-lstm-imdb)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/bidirectional_lstm_imdb).
+"""
+
+model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
+model.fit(x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val))
diff --git a/knowledge_base/nlp/data_parallel_training_with_keras_hub.py b/knowledge_base/nlp/data_parallel_training_with_keras_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..88912968e4175f6ab62027c8807b5e728b5ff7ef
--- /dev/null
+++ b/knowledge_base/nlp/data_parallel_training_with_keras_hub.py
@@ -0,0 +1,242 @@
+"""
+Title: Data Parallel Training with KerasHub and tf.distribute
+Author: Anshuman Mishra
+Date created: 2023/07/07
+Last modified: 2023/07/07
+Description: Data Parallel training with KerasHub and tf.distribute.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Distributed training is a technique used to train deep learning models on multiple devices
+or machines simultaneously. It helps to reduce training time and allows for training larger
+models with more data. KerasHub is a library that provides pretrained models and building
+blocks for natural language processing tasks, and it works with distributed training via `tf.distribute`.
+
+In this tutorial, we will use KerasHub to train a BERT-based masked language model (MLM)
+on the wikitext-2 dataset (a 2 million word dataset of Wikipedia articles). The MLM task
+involves predicting the masked words in a sentence, which helps the model learn contextual
+representations of words.
+
+This guide focuses on data parallelism, in particular synchronous data parallelism, where
+each accelerator (a GPU or TPU) holds a complete replica of the model, and sees a
+different partial batch of the input data. Partial gradients are computed on each device,
+aggregated, and used to compute a global gradient update.
+
+Specifically, this guide teaches you how to use the `tf.distribute` API to train Keras
+models on multiple GPUs, with minimal changes to your code, in the following two setups:
+
+- On multiple GPUs (typically 2 to 8) installed on a single machine (single host,
+multi-device training). This is the most common setup for researchers and small-scale
+industry workflows.
+- On a cluster of many machines, each hosting one or multiple GPUs (multi-worker
+distributed training). This is a good setup for large-scale industry workflows, e.g.
+training high-resolution text summarization models on billion word datasets on 20-100 GPUs.
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+"""
+## Imports
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf
+import keras
+import keras_hub
+
+"""
+Before we start any training, let's configure our single GPU to show up as two logical
+devices.
+
+When you are training with two or more physical GPUs, this is totally unnecessary. It
+is just a trick to show real distributed training on the default Colab GPU runtime,
+which has only one GPU available.
+"""
+
+"""shell
+nvidia-smi --query-gpu=memory.total --format=csv,noheader
+"""
+
+physical_devices = tf.config.list_physical_devices("GPU")
+tf.config.set_logical_device_configuration(
+ physical_devices[0],
+ [
+ tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),
+ tf.config.LogicalDeviceConfiguration(memory_limit=15360 // 2),
+ ],
+)
+
+logical_devices = tf.config.list_logical_devices("GPU")
+logical_devices
+
+EPOCHS = 3
+
+
+"""
+To do single-host, multi-device synchronous training with a Keras model, you would use
+the `tf.distribute.MirroredStrategy` API. Here's how it works:
+
+- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you
+want to use (by default the strategy will use all GPUs available).
+- Use the strategy object to open a scope, and within this scope, create all the Keras
+objects you need that contain variables. Typically, that means **creating & compiling the
+model** inside the distribution scope.
+- Train the model via `fit()` as usual.
+"""
+strategy = tf.distribute.MirroredStrategy()
+print(f"Number of devices: {strategy.num_replicas_in_sync}")
+
+"""
+Let's define the base batch size and learning rate:
+"""
+base_batch_size = 32
+base_learning_rate = 1e-4
+
+"""
+Now we calculate the scaled batch size and learning rate, multiplying the base values
+by the number of replicas:
+"""
+scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
+scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync
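+
+# For example, with 2 replicas this gives scaled_batch_size = 64 and
+# scaled_learning_rate = 2e-4: each replica still sees 32 samples per step,
+# while the learning rate follows the common linear-scaling heuristic.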
+
+"""
+Now, we need to download and preprocess the wikitext-2 dataset. This dataset will be
+used for pretraining the BERT model. We will filter out short lines to ensure that the
+data has enough context for training.
+"""
+
+keras.utils.get_file(
+ origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",
+ extract=True,
+)
+wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")
+
+# Load wikitext-2 and filter out short lines.
+wiki_train_ds = (
+ tf.data.TextLineDataset(
+ wiki_dir + "wiki.train.tokens",
+ )
+ .filter(lambda x: tf.strings.length(x) > 100)
+ .shuffle(buffer_size=500)
+ .batch(scaled_batch_size)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+wiki_val_ds = (
+ tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens")
+ .filter(lambda x: tf.strings.length(x) > 100)
+ .shuffle(buffer_size=500)
+ .batch(scaled_batch_size)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+wiki_test_ds = (
+ tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens")
+ .filter(lambda x: tf.strings.length(x) > 100)
+ .shuffle(buffer_size=500)
+ .batch(scaled_batch_size)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+"""
+In the above code, we download the wikitext-2 dataset and extract it. Then, we define
+three datasets: `wiki_train_ds`, `wiki_val_ds`, and `wiki_test_ds`. These datasets are
+filtered to remove short lines and are batched for efficient training.
+"""
+
+"""
+It's a common practice to use a decayed learning rate in NLP training/tuning. We'll
+use the `PolynomialDecay` schedule here.
+"""
+
+total_training_steps = sum(1 for _ in wiki_train_ds.as_numpy_iterator()) * EPOCHS
+lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
+ initial_learning_rate=scaled_learning_rate,
+ decay_steps=total_training_steps,
+ end_learning_rate=0.0,
+)
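+
+# With the default power=1 and end_learning_rate=0.0, this is a linear decay:
+# lr(step) = scaled_learning_rate * (1 - step / total_training_steps)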
+
+
+class PrintLR(tf.keras.callbacks.Callback):
+    def on_epoch_end(self, epoch, logs=None):
+        lr = self.model.optimizer.learning_rate
+        if callable(lr):
+            # When a LearningRateSchedule is used, evaluate it at the current step.
+            lr = lr(self.model.optimizer.iterations)
+        print(f"\nLearning rate for epoch {epoch + 1} is {float(lr)}")
+
+
+"""
+Let's also make a TensorBoard callback: this will enable visualization of different
+metrics while we train the model in a later part of this tutorial. We put all the
+callbacks together as follows:
+"""
+callbacks = [
+ tf.keras.callbacks.TensorBoard(log_dir="./logs"),
+ PrintLR(),
+]
+
+
+print(tf.config.list_physical_devices("GPU"))
+
+
+"""
+With the datasets prepared, we now initialize and compile our model and optimizer within
+the `strategy.scope()`:
+"""
+
+with strategy.scope():
+ # Everything that creates variables should be under the strategy scope.
+ # In general this is only model construction & `compile()`.
+ model_dist = keras_hub.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")
+
+    # This line just sets the pooled_dense layer as non-trainable; we do this to
+    # avoid warnings about this layer being unused
+ model_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = False
+
+ model_dist.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.AdamW(learning_rate=lr_schedule),
+ weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ jit_compile=False,
+ )
+
+ model_dist.fit(
+ wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks
+ )
+
+"""
+After fitting our model under the scope, we evaluate it normally!
+"""
+
+model_dist.evaluate(wiki_test_ds)
+
+"""
+For distributed training across multiple machines (as opposed to training that only leverages
+multiple devices on a single machine), there are two distribution strategies you
+could use: `MultiWorkerMirroredStrategy` and `ParameterServerStrategy`:
+
+- `tf.distribute.MultiWorkerMirroredStrategy` implements a synchronous CPU/GPU
+multi-worker solution to work with Keras-style model building and training loop,
+using synchronous reduction of gradients across the replicas.
+- `tf.distribute.experimental.ParameterServerStrategy` implements an asynchronous CPU/GPU
+multi-worker solution, where the parameters are stored on parameter servers, and
+workers update the gradients to parameter servers asynchronously.
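+
+As a minimal sketch (an illustration, not runnable as-is), a multi-worker job
+would only swap out the strategy object; this assumes the `TF_CONFIG`
+environment variable is configured on each worker, and everything inside the
+scope stays the same as above:
+
+```python
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+with strategy.scope():
+    model = keras_hub.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")
+    model.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),
+    )
+model.fit(wiki_train_ds, epochs=EPOCHS)
+```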
+
+### Further reading
+
+1. [TensorFlow distributed training guide](https://www.tensorflow.org/guide/distributed_training)
+2. [Tutorial on multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
+3. [MirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
+4. [MultiWorkerMirroredStrategy docs](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
+5. [Distributed training in tf.keras with Weights & Biases](https://towardsdatascience.com/distributed-training-in-tf-keras-with-w-b-ccf021f9322e)
+"""
diff --git a/knowledge_base/nlp/fnet_classification_with_keras_hub.py b/knowledge_base/nlp/fnet_classification_with_keras_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ca026cdb1ad16748d932abc91d25f8021e44819
--- /dev/null
+++ b/knowledge_base/nlp/fnet_classification_with_keras_hub.py
@@ -0,0 +1,366 @@
+"""
+Title: Text Classification using FNet
+Author: [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2022/06/01
+Last modified: 2022/12/21
+Description: Text Classification on the IMDb Dataset using `keras_hub.layers.FNetEncoder` layer.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we will demonstrate the ability of FNet to achieve comparable
+results with a vanilla Transformer model on the text classification task.
+We will be using the IMDb dataset, which is a
+collection of movie reviews labelled either positive or negative (sentiment
+analysis).
+
+To build the tokenizer, model, etc., we will use components from
+[KerasHub](https://github.com/keras-team/keras-hub). KerasHub makes life easier
+for people who want to build NLP pipelines! :)
+
+### Model
+
+Transformer-based language models (LMs) such as BERT, RoBERTa, XLNet, etc. have
+demonstrated the effectiveness of the self-attention mechanism for computing
+rich embeddings for input text. However, the self-attention mechanism is an
+expensive operation, with a time complexity of `O(n^2)`, where `n` is the number
+of tokens in the input. Hence, there has been an effort to reduce the time
+complexity of the self-attention mechanism and improve performance without
+sacrificing the quality of results.
+
+In 2021, a paper titled
+[FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824)
+replaced the self-attention layer in BERT with a simple Fourier Transform layer
+for "token mixing". This resulted in comparable accuracy and a speed-up during
+training. In particular, a couple of points from the paper stand out:
+
+* The authors claim that FNet is 80% faster than BERT on GPUs and 70% faster on
+TPUs. The reason for this speed-up is two-fold: a) the Fourier Transform layer
+is unparameterized: it does not have any learnable parameters, and b) the authors use the Fast
+Fourier Transform (FFT); this reduces the time complexity from `O(n^2)`
+(in the case of self-attention) to `O(n log n)`.
+* FNet manages to achieve 92-97% of the accuracy of BERT on the GLUE benchmark.
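+
+As a rough sketch of the core idea (an illustration, not the actual
+`keras_hub.layers.FNetEncoder` implementation), the token-mixing sublayer keeps
+the real part of a 2D discrete Fourier Transform applied over the sequence and
+hidden dimensions:
+
+```python
+import numpy as np
+
+def fourier_token_mixing(x):
+    # x: array of shape (seq_len, hidden_dim). The 2D FFT mixes information
+    # across tokens and across features, with no learnable parameters.
+    return np.real(np.fft.fft2(x))
+```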
+"""
+
+"""
+## Setup
+
+Before we start with the implementation, let's import all the necessary packages.
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import keras_hub
+import keras
+import tensorflow as tf
+import os
+
+keras.utils.set_random_seed(42)
+
+"""
+Let's also define our hyperparameters.
+"""
+BATCH_SIZE = 64
+EPOCHS = 3
+MAX_SEQUENCE_LENGTH = 512
+VOCAB_SIZE = 15000
+
+EMBED_DIM = 128
+INTERMEDIATE_DIM = 512
+
+"""
+## Loading the dataset
+
+First, let's download the IMDB dataset and extract it.
+"""
+
+"""shell
+wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+tar -xzf aclImdb_v1.tar.gz
+"""
+
+"""
+Samples are present in the form of text files. Let's inspect the structure of
+the directory.
+"""
+
+print(os.listdir("./aclImdb"))
+print(os.listdir("./aclImdb/train"))
+print(os.listdir("./aclImdb/test"))
+
+"""
+The directory contains two sub-directories: `train` and `test`. Each subdirectory
+in turn contains two folders: `pos` and `neg` for positive and negative reviews,
+respectively. Before we load the dataset, let's delete the `./aclImdb/train/unsup`
+folder since it has unlabelled samples.
+"""
+
+"""shell
+rm -rf aclImdb/train/unsup
+"""
+
+"""
+We'll use the `keras.utils.text_dataset_from_directory` utility to generate
+our labelled `tf.data.Dataset` dataset from text files.
+"""
+
+train_ds = keras.utils.text_dataset_from_directory(
+ "aclImdb/train",
+ batch_size=BATCH_SIZE,
+ validation_split=0.2,
+ subset="training",
+ seed=42,
+)
+val_ds = keras.utils.text_dataset_from_directory(
+ "aclImdb/train",
+ batch_size=BATCH_SIZE,
+ validation_split=0.2,
+ subset="validation",
+ seed=42,
+)
+test_ds = keras.utils.text_dataset_from_directory("aclImdb/test", batch_size=BATCH_SIZE)
+
+"""
+We will now convert the text to lowercase.
+"""
+train_ds = train_ds.map(lambda x, y: (tf.strings.lower(x), y))
+val_ds = val_ds.map(lambda x, y: (tf.strings.lower(x), y))
+test_ds = test_ds.map(lambda x, y: (tf.strings.lower(x), y))
+
+"""
+Let's print a few samples.
+"""
+for text_batch, label_batch in train_ds.take(1):
+ for i in range(3):
+ print(text_batch.numpy()[i])
+ print(label_batch.numpy()[i])
+
+
+"""
+### Tokenizing the data
+
+We'll be using the `keras_hub.tokenizers.WordPieceTokenizer` layer to tokenize
+the text. `keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary
+and has functions for tokenizing the text, and detokenizing sequences of tokens.
+
+Before we define the tokenizer, we first need to train it on the dataset
+we have. The WordPiece tokenization algorithm is a subword tokenization algorithm;
+training it on a corpus gives us a vocabulary of subwords. A subword tokenizer
+is a compromise between word tokenizers (word tokenizers need very large
+vocabularies for good coverage of input words), and character tokenizers
+(characters don't really encode meaning like words do). Luckily, KerasHub
+makes it very simple to train WordPiece on a corpus with the
+`keras_hub.tokenizers.compute_word_piece_vocabulary` utility.
+
+Note: The official implementation of FNet uses the SentencePiece Tokenizer.
+"""
+
+
+def train_word_piece(ds, vocab_size, reserved_tokens):
+ word_piece_ds = ds.unbatch().map(lambda x, y: x)
+ vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
+ word_piece_ds.batch(1000).prefetch(2),
+ vocabulary_size=vocab_size,
+ reserved_tokens=reserved_tokens,
+ )
+ return vocab
+
+
+"""
+Every vocabulary has a few special, reserved tokens. We have two such tokens:
+
+- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence length
+when the input sequence length is shorter than the maximum sequence length.
+- `"[UNK]"` - Unknown token.
+"""
+reserved_tokens = ["[PAD]", "[UNK]"]
+vocab = train_word_piece(train_ds, VOCAB_SIZE, reserved_tokens)
+
+"""
+Let's see some tokens!
+"""
+print("Tokens: ", vocab[100:110])
+
+"""
+Now, let's define the tokenizer. We will configure the tokenizer with the
+vocabulary trained above. We will define a maximum sequence length so that
+all sequences are padded to the same length if they are shorter than the
+specified sequence length. Otherwise, the sequence is truncated.
+"""
+tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
+ vocabulary=vocab,
+ lowercase=False,
+ sequence_length=MAX_SEQUENCE_LENGTH,
+)
+
+"""
+Let's try and tokenize a sample from our dataset! To verify whether the text has
+been tokenized correctly, we can also detokenize the list of tokens back to the
+original text.
+"""
+input_sentence_ex = train_ds.take(1).get_single_element()[0][0]
+input_tokens_ex = tokenizer(input_sentence_ex)
+
+print("Sentence: ", input_sentence_ex)
+print("Tokens: ", input_tokens_ex)
+print("Recovered text after detokenizing: ", tokenizer.detokenize(input_tokens_ex))
+
+
+"""
+## Formatting the dataset
+
+Next, we'll format our datasets in the form that will be fed to the models. We
+need to tokenize the text.
+"""
+
+
+def format_dataset(sentence, label):
+ sentence = tokenizer(sentence)
+ return ({"input_ids": sentence}, label)
+
+
+def make_dataset(dataset):
+ dataset = dataset.map(format_dataset, num_parallel_calls=tf.data.AUTOTUNE)
+ return dataset.shuffle(512).prefetch(16).cache()
+
+
+train_ds = make_dataset(train_ds)
+val_ds = make_dataset(val_ds)
+test_ds = make_dataset(test_ds)
+
+"""
+## Building the model
+
+Now, let's move on to the exciting part - defining our model!
+We first need an embedding layer, i.e., a layer that maps every token in the input
+sequence to a vector. This embedding layer can be initialised randomly. We also
+need a positional embedding layer which encodes the word order in the sequence.
+The convention is to add, i.e., sum, these two embeddings. KerasHub has a
+`keras_hub.layers.TokenAndPositionEmbedding` layer which does all of the above
+steps for us.
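+
+Conceptually, the layer computes something like the following (an illustrative
+sketch, not its exact implementation):
+
+```python
+token_emb = keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(input_ids)
+positions = keras.ops.arange(MAX_SEQUENCE_LENGTH)
+pos_emb = keras.layers.Embedding(MAX_SEQUENCE_LENGTH, EMBED_DIM)(positions)
+x = token_emb + pos_emb  # broadcasts over the batch dimension
+```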
+
+Our FNet classification model consists of three `keras_hub.layers.FNetEncoder`
+layers with a `keras.layers.Dense` layer on top.
+
+Note: For FNet, masking the padding tokens has a minimal effect on results. In the
+official implementation, the padding tokens are not masked.
+"""
+
+input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")
+
+x = keras_hub.layers.TokenAndPositionEmbedding(
+ vocabulary_size=VOCAB_SIZE,
+ sequence_length=MAX_SEQUENCE_LENGTH,
+ embedding_dim=EMBED_DIM,
+ mask_zero=True,
+)(input_ids)
+
+x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
+x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
+x = keras_hub.layers.FNetEncoder(intermediate_dim=INTERMEDIATE_DIM)(inputs=x)
+
+
+x = keras.layers.GlobalAveragePooling1D()(x)
+x = keras.layers.Dropout(0.1)(x)
+outputs = keras.layers.Dense(1, activation="sigmoid")(x)
+
+fnet_classifier = keras.Model(input_ids, outputs, name="fnet_classifier")
+
+"""
+## Training our model
+
+We'll use accuracy to monitor training progress on the validation data. Let's
+train our model for 3 epochs.
+"""
+fnet_classifier.summary()
+fnet_classifier.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ loss="binary_crossentropy",
+ metrics=["accuracy"],
+)
+fnet_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
+
+"""
+We obtain a train accuracy of around 92% and a validation accuracy of around
+85%. Moreover, for 3 epochs, it takes around 86 seconds to train the model
+(on Colab with a 16 GB Tesla T4 GPU).
+
+Let's calculate the test accuracy.
+"""
+fnet_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
+
+
+"""
+## Comparison with Transformer model
+
+Let's compare our FNet Classifier model with a Transformer Classifier model. We
+keep all the parameters/hyperparameters the same. For example, we use three
+`TransformerEncoder` layers.
+
+We set the number of heads to 2.
+"""
+NUM_HEADS = 2
+input_ids = keras.Input(shape=(None,), dtype="int64", name="input_ids")
+
+
+x = keras_hub.layers.TokenAndPositionEmbedding(
+ vocabulary_size=VOCAB_SIZE,
+ sequence_length=MAX_SEQUENCE_LENGTH,
+ embedding_dim=EMBED_DIM,
+ mask_zero=True,
+)(input_ids)
+
+x = keras_hub.layers.TransformerEncoder(
+ intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
+)(inputs=x)
+x = keras_hub.layers.TransformerEncoder(
+ intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
+)(inputs=x)
+x = keras_hub.layers.TransformerEncoder(
+ intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
+)(inputs=x)
+
+
+x = keras.layers.GlobalAveragePooling1D()(x)
+x = keras.layers.Dropout(0.1)(x)
+outputs = keras.layers.Dense(1, activation="sigmoid")(x)
+
+transformer_classifier = keras.Model(input_ids, outputs, name="transformer_classifier")
+
+
+transformer_classifier.summary()
+transformer_classifier.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ loss="binary_crossentropy",
+ metrics=["accuracy"],
+)
+transformer_classifier.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
+
+"""
+We obtain a train accuracy of around 94% and a validation accuracy of around
+86.5%. It takes around 146 seconds to train the model (on Colab with a 16 GB Tesla
+T4 GPU).
+
+Let's calculate the test accuracy.
+"""
+transformer_classifier.evaluate(test_ds, batch_size=BATCH_SIZE)
+
+"""
+Let's make a table and compare the two models. We can see that FNet
+significantly speeds up our run time (1.7x), with only a small sacrifice in
+overall accuracy (drop of 0.75%).
+
+| | **FNet Classifier** | **Transformer Classifier** |
+|:-----------------------:|:-------------------:|:--------------------------:|
+| **Training Time** | 86 seconds | 146 seconds |
+| **Train Accuracy** | 92.34% | 93.85% |
+| **Validation Accuracy** | 85.21% | 86.42% |
+| **Test Accuracy** | 83.94% | 84.69% |
+| **#Params** | 2,321,921 | 2,520,065 |
+"""
diff --git a/knowledge_base/nlp/lstm_seq2seq.py b/knowledge_base/nlp/lstm_seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..01f7fadd4c95ab3e3682728717ee1a69b97b1873
--- /dev/null
+++ b/knowledge_base/nlp/lstm_seq2seq.py
@@ -0,0 +1,274 @@
+"""
+Title: Character-level recurrent sequence-to-sequence model
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2017/09/29
+Last modified: 2023/11/22
+Description: Character-level recurrent sequence-to-sequence model.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to implement a basic character-level
+recurrent sequence-to-sequence model. We apply it to translating
+short English sentences into short French sentences,
+character-by-character. Note that it is fairly unusual to
+do character-level machine translation, as word-level
+models are more common in this domain.
+
+**Summary of the algorithm**
+
+- We start with input sequences from a domain (e.g. English sentences)
+ and corresponding target sequences from another domain
+ (e.g. French sentences).
+- An encoder LSTM turns input sequences into 2 state vectors
+ (we keep the last LSTM state and discard the outputs).
+- A decoder LSTM is trained to turn the target sequences into
+ the same sequence but offset by one timestep in the future,
+ a training process called "teacher forcing" in this context.
+ It uses as initial state the state vectors from the encoder.
+ Effectively, the decoder learns to generate `targets[t+1...]`
+ given `targets[...t]`, conditioned on the input sequence.
+- In inference mode, when we want to decode unknown input sequences, we:
+ - Encode the input sequence into state vectors
+ - Start with a target sequence of size 1
+ (just the start-of-sequence character)
+ - Feed the state vectors and 1-char target sequence
+ to the decoder to produce predictions for the next character
+ - Sample the next character using these predictions
+ (we simply use argmax).
+ - Append the sampled character to the target sequence
+ - Repeat until we generate the end-of-sequence character or we
+ hit the character limit.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+import os
+from pathlib import Path
+
+"""
+## Download the data
+"""
+
+fpath = keras.utils.get_file(origin="http://www.manythings.org/anki/fra-eng.zip")
+dirpath = Path(fpath).parent.absolute()
+os.system(f"unzip -q {fpath} -d {dirpath}")
+
+"""
+## Configuration
+"""
+
+batch_size = 64 # Batch size for training.
+epochs = 100 # Number of epochs to train for.
+latent_dim = 256 # Latent dimensionality of the encoding space.
+num_samples = 10000 # Number of samples to train on.
+# Path to the data txt file on disk.
+data_path = os.path.join(dirpath, "fra.txt")
+
+"""
+## Prepare the data
+"""
+
+# Vectorize the data.
+input_texts = []
+target_texts = []
+input_characters = set()
+target_characters = set()
+with open(data_path, "r", encoding="utf-8") as f:
+ lines = f.read().split("\n")
+for line in lines[: min(num_samples, len(lines) - 1)]:
+ input_text, target_text, _ = line.split("\t")
+ # We use "tab" as the "start sequence" character
+ # for the targets, and "\n" as "end sequence" character.
+ target_text = "\t" + target_text + "\n"
+ input_texts.append(input_text)
+ target_texts.append(target_text)
+ for char in input_text:
+ if char not in input_characters:
+ input_characters.add(char)
+ for char in target_text:
+ if char not in target_characters:
+ target_characters.add(char)
+
+input_characters = sorted(list(input_characters))
+target_characters = sorted(list(target_characters))
+num_encoder_tokens = len(input_characters)
+num_decoder_tokens = len(target_characters)
+max_encoder_seq_length = max([len(txt) for txt in input_texts])
+max_decoder_seq_length = max([len(txt) for txt in target_texts])
+
+print("Number of samples:", len(input_texts))
+print("Number of unique input tokens:", num_encoder_tokens)
+print("Number of unique output tokens:", num_decoder_tokens)
+print("Max sequence length for inputs:", max_encoder_seq_length)
+print("Max sequence length for outputs:", max_decoder_seq_length)
+
+input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
+target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])
+
+encoder_input_data = np.zeros(
+ (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
+ dtype="float32",
+)
+decoder_input_data = np.zeros(
+ (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
+ dtype="float32",
+)
+decoder_target_data = np.zeros(
+ (len(input_texts), max_decoder_seq_length, num_decoder_tokens),
+ dtype="float32",
+)
+
+for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
+ for t, char in enumerate(input_text):
+ encoder_input_data[i, t, input_token_index[char]] = 1.0
+ encoder_input_data[i, t + 1 :, input_token_index[" "]] = 1.0
+ for t, char in enumerate(target_text):
+ # decoder_target_data is ahead of decoder_input_data by one timestep
+ decoder_input_data[i, t, target_token_index[char]] = 1.0
+ if t > 0:
+ # decoder_target_data will be ahead by one timestep
+ # and will not include the start character.
+ decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
+ decoder_input_data[i, t + 1 :, target_token_index[" "]] = 1.0
+ decoder_target_data[i, t:, target_token_index[" "]] = 1.0
+
+"""
+## Build the model
+"""
+
+# Define an input sequence and process it.
+encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
+encoder = keras.layers.LSTM(latent_dim, return_state=True)
+encoder_outputs, state_h, state_c = encoder(encoder_inputs)
+
+# We discard `encoder_outputs` and only keep the states.
+encoder_states = [state_h, state_c]
+
+# Set up the decoder, using `encoder_states` as initial state.
+decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
+
+# We set up our decoder to return full output sequences,
+# and to return internal states as well. We don't use the
+# return states in the training model, but we will use them in inference.
+decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
+decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
+decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
+decoder_outputs = decoder_dense(decoder_outputs)
+
+# Define the model that will turn
+# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
+model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
+
+"""
+## Train the model
+"""
+
+model.compile(
+ optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"]
+)
+model.fit(
+ [encoder_input_data, decoder_input_data],
+ decoder_target_data,
+ batch_size=batch_size,
+ epochs=epochs,
+ validation_split=0.2,
+)
+# Save model
+model.save("s2s_model.keras")
+
+"""
+## Run inference (sampling)
+
+1. Encode the input and retrieve the initial decoder state.
+2. Run one step of the decoder with this initial state and a
+"start of sequence" token as target. The output will be the next target token.
+3. Repeat with the current target token and current states.
+"""
+
+# Define sampling models
+# Restore the model and construct the encoder and decoder.
+model = keras.models.load_model("s2s_model.keras")
+
+encoder_inputs = model.input[0] # input_1
+encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output # lstm_1
+encoder_states = [state_h_enc, state_c_enc]
+encoder_model = keras.Model(encoder_inputs, encoder_states)
+
+decoder_inputs = model.input[1] # input_2
+decoder_state_input_h = keras.Input(shape=(latent_dim,))
+decoder_state_input_c = keras.Input(shape=(latent_dim,))
+decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
+decoder_lstm = model.layers[3]
+decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
+ decoder_inputs, initial_state=decoder_states_inputs
+)
+decoder_states = [state_h_dec, state_c_dec]
+decoder_dense = model.layers[4]
+decoder_outputs = decoder_dense(decoder_outputs)
+decoder_model = keras.Model(
+ [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
+)
+
+# Reverse-lookup token index to decode sequences back to
+# something readable.
+reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
+reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())
+
+
+def decode_sequence(input_seq):
+ # Encode the input as state vectors.
+ states_value = encoder_model.predict(input_seq, verbose=0)
+
+ # Generate empty target sequence of length 1.
+ target_seq = np.zeros((1, 1, num_decoder_tokens))
+ # Populate the first character of target sequence with the start character.
+ target_seq[0, 0, target_token_index["\t"]] = 1.0
+
+ # Sampling loop for a batch of sequences
+ # (to simplify, here we assume a batch of size 1).
+ stop_condition = False
+ decoded_sentence = ""
+ while not stop_condition:
+ output_tokens, h, c = decoder_model.predict(
+ [target_seq] + states_value, verbose=0
+ )
+
+ # Sample a token
+ sampled_token_index = np.argmax(output_tokens[0, -1, :])
+ sampled_char = reverse_target_char_index[sampled_token_index]
+ decoded_sentence += sampled_char
+
+ # Exit condition: either hit max length
+ # or find stop character.
+ if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
+ stop_condition = True
+
+ # Update the target sequence (of length 1).
+ target_seq = np.zeros((1, 1, num_decoder_tokens))
+ target_seq[0, 0, sampled_token_index] = 1.0
+
+ # Update states
+ states_value = [h, c]
+ return decoded_sentence
+
+
+"""
+You can now generate decoded sentences as such:
+"""
+
+for seq_index in range(20):
+ # Take one sequence (part of the training set)
+ # for trying out decoding.
+ input_seq = encoder_input_data[seq_index : seq_index + 1]
+ decoded_sentence = decode_sequence(input_seq)
+ print("-")
+ print("Input sentence:", input_texts[seq_index])
+ print("Decoded sentence:", decoded_sentence)
diff --git a/knowledge_base/nlp/masked_language_modeling.py b/knowledge_base/nlp/masked_language_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..01262642690a542b469082c7d732df70cf700473
--- /dev/null
+++ b/knowledge_base/nlp/masked_language_modeling.py
@@ -0,0 +1,505 @@
+"""
+Title: End-to-end Masked Language Modeling with BERT
+Author: [Ankur Singh](https://twitter.com/ankur310794)
+Date created: 2020/09/18
+Last modified: 2024/03/15
+Description: Implement a Masked Language Model (MLM) with BERT and fine-tune it on the IMDB Reviews dataset.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+Masked Language Modeling is a fill-in-the-blank task,
+where a model uses the context words surrounding a mask token to try to predict what the
+masked word should be.
+
+For an input that contains one or more mask tokens,
+the model will generate the most likely substitution for each.
+
+Example:
+
+- Input: "I have watched this [MASK] and it was awesome."
+- Output: "I have watched this movie and it was awesome."
+
+Masked language modeling is a great way to train a language
+model in a self-supervised setting (without human-annotated labels).
+Such a model can then be fine-tuned to accomplish various supervised
+NLP tasks.
+
+This example teaches you how to build a BERT model from scratch,
+train it with the masked language modeling task,
+and then fine-tune this model on a sentiment classification task.
+
+We will use the Keras `TextVectorization` and `MultiHeadAttention` layers
+to create a BERT Transformer-Encoder network architecture.
+
+Note: This example is backend-agnostic and runs with any Keras 3 backend
+(TensorFlow, JAX, or PyTorch).
+"""
+
+"""
+## Setup
+
+This example requires `keras` (version 3 or higher) and `keras-hub`, which can be
+installed via `pip install -q --upgrade keras-hub keras`. Set the `KERAS_BACKEND`
+environment variable below to `"torch"`, `"jax"`, or `"tensorflow"`.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow
+
+import keras_hub
+
+import keras
+from keras import layers
+from keras.layers import TextVectorization
+
+from dataclasses import dataclass
+import pandas as pd
+import numpy as np
+import glob
+import re
+from pprint import pprint
+
+"""
+## Set-up Configuration
+"""
+
+
+@dataclass
+class Config:
+ MAX_LEN = 256
+ BATCH_SIZE = 32
+ LR = 0.001
+ VOCAB_SIZE = 30000
+ EMBED_DIM = 128
+ NUM_HEAD = 8 # used in bert model
+ FF_DIM = 128 # used in bert model
+ NUM_LAYERS = 1
+
+
+config = Config()
+
+"""
+## Load the data
+
+We will first download the IMDB data and load it into a Pandas dataframe.
+"""
+
+"""shell
+curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+tar -xf aclImdb_v1.tar.gz
+"""
+
+
+def get_text_list_from_files(files):
+ text_list = []
+ for name in files:
+ with open(name) as f:
+ for line in f:
+ text_list.append(line)
+ return text_list
+
+
+def get_data_from_text_files(folder_name):
+ pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
+ pos_texts = get_text_list_from_files(pos_files)
+ neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
+ neg_texts = get_text_list_from_files(neg_files)
+ df = pd.DataFrame(
+ {
+ "review": pos_texts + neg_texts,
+ "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
+ }
+ )
+ df = df.sample(len(df)).reset_index(drop=True)
+ return df
+
+
+train_df = get_data_from_text_files("train")
+test_df = get_data_from_text_files("test")
+
+all_data = pd.concat([train_df, test_df], ignore_index=True)
+
+"""
+## Dataset preparation
+
+We will use the `TextVectorization` layer to vectorize the text into integer token ids.
+It transforms a batch of strings into either
+a sequence of token indices (one sample = 1D array of integer token indices, in order)
+or a dense representation (one sample = 1D array of float values encoding an unordered set of tokens).
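+
+For instance, here is an illustrative toy example (the exact indices depend on
+the adapted vocabulary; index 1 is the out-of-vocabulary token and 0 is padding):
+
+```python
+toy_layer = TextVectorization(output_mode="int", output_sequence_length=6)
+toy_layer.adapt(["the movie was great", "the movie was bad"])
+toy_layer(["the movie was awful"])  # e.g. -> [[2, 3, 4, 1, 0, 0]]
+```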
+
+Below, we define 3 preprocessing functions.
+
+1. The `get_vectorize_layer` function builds the `TextVectorization` layer.
+2. The `encode` function encodes raw text into integer token ids.
+3. The `get_masked_input_and_labels` function will mask input token ids.
+It masks 15% of all input tokens in each sequence at random.
+"""
+
+# For data pre-processing and tf.data.Dataset
+import tensorflow as tf
+
+
+def custom_standardization(input_data):
+ lowercase = tf.strings.lower(input_data)
+    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
+ return tf.strings.regex_replace(
+ stripped_html, "[%s]" % re.escape("!#$%&'()*+,-./:;<=>?@\^_`{|}~"), ""
+ )
+
+
+def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
+ """Build Text vectorization layer
+
+ Args:
+ texts (list): List of string i.e input texts
+ vocab_size (int): vocab size
+ max_seq (int): Maximum sequence length.
+ special_tokens (list, optional): List of special tokens. Defaults to ['[MASK]'].
+
+ Returns:
+ layers.Layer: Return TextVectorization Keras Layer
+ """
+ vectorize_layer = TextVectorization(
+ max_tokens=vocab_size,
+ output_mode="int",
+ standardize=custom_standardization,
+ output_sequence_length=max_seq,
+ )
+ vectorize_layer.adapt(texts)
+
+ # Insert mask token in vocabulary
+ vocab = vectorize_layer.get_vocabulary()
+ vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[mask]"]
+ vectorize_layer.set_vocabulary(vocab)
+ return vectorize_layer
+
+
+vectorize_layer = get_vectorize_layer(
+ all_data.review.values.tolist(),
+ config.VOCAB_SIZE,
+ config.MAX_LEN,
+ special_tokens=["[mask]"],
+)
+
+# Get mask token id for masked language model
+mask_token_id = vectorize_layer(["[mask]"]).numpy()[0][0]
+
+
+def encode(texts):
+ encoded_texts = vectorize_layer(texts)
+ return encoded_texts.numpy()
+
+
+def get_masked_input_and_labels(encoded_texts):
+ # 15% BERT masking
+ inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
+ # Do not mask special tokens
+ inp_mask[encoded_texts <= 2] = False
+    # Set targets to -1 by default; -1 means "ignore this position"
+ labels = -1 * np.ones(encoded_texts.shape, dtype=int)
+ # Set labels for masked tokens
+ labels[inp_mask] = encoded_texts[inp_mask]
+
+ # Prepare input
+ encoded_texts_masked = np.copy(encoded_texts)
+    # Of the selected tokens, set 90% to the [MASK] token
+    # (the remaining 10% are left unchanged)
+ inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
+ encoded_texts_masked[inp_mask_2mask] = (
+ mask_token_id # mask token is the last in the dict
+ )
+
+    # Of the masked tokens, set 1/9 (i.e. 10% of all selected tokens) to a random
+    # token, so overall selected tokens are ~80% [MASK], ~10% random, ~10% unchanged
+ inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
+ encoded_texts_masked[inp_mask_2random] = np.random.randint(
+ 3, mask_token_id, inp_mask_2random.sum()
+ )
+
+ # Prepare sample_weights to pass to .fit() method
+ sample_weights = np.ones(labels.shape)
+ sample_weights[labels == -1] = 0
+
+    # y_labels is the same as encoded_texts, i.e. the original input tokens
+ y_labels = np.copy(encoded_texts)
+
+ return encoded_texts_masked, y_labels, sample_weights
+
+
+# We have 25000 examples for training
+x_train = encode(train_df.review.values) # encode reviews with vectorizer
+y_train = train_df.sentiment.values
+train_classifier_ds = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(1000)
+ .batch(config.BATCH_SIZE)
+)
+
+# We have 25000 examples for testing
+x_test = encode(test_df.review.values)
+y_test = test_df.sentiment.values
+test_classifier_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(
+ config.BATCH_SIZE
+)
+
+# Dataset for end to end model input (will be used at the end)
+test_raw_classifier_ds = test_df
+
+# Prepare data for masked language model
+x_all_review = encode(all_data.review.values)
+x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(
+ x_all_review
+)
+
+mlm_ds = tf.data.Dataset.from_tensor_slices(
+ (x_masked_train, y_masked_labels, sample_weights)
+)
+mlm_ds = mlm_ds.shuffle(1000).batch(config.BATCH_SIZE)
+
+"""
+## Create BERT model (Pretraining Model) for masked language modeling
+
+We will create a BERT-like pretraining model architecture
+using the `MultiHeadAttention` layer.
+It will take token ids as inputs (including masked tokens)
+and it will predict the correct ids for the masked input tokens.
+"""
+
+
+def bert_module(query, key, value, i):
+ # Multi headed self-attention
+ attention_output = layers.MultiHeadAttention(
+ num_heads=config.NUM_HEAD,
+ key_dim=config.EMBED_DIM // config.NUM_HEAD,
+ name="encoder_{}_multiheadattention".format(i),
+ )(query, key, value)
+ attention_output = layers.Dropout(0.1, name="encoder_{}_att_dropout".format(i))(
+ attention_output
+ )
+ attention_output = layers.LayerNormalization(
+ epsilon=1e-6, name="encoder_{}_att_layernormalization".format(i)
+ )(query + attention_output)
+
+ # Feed-forward layer
+ ffn = keras.Sequential(
+ [
+ layers.Dense(config.FF_DIM, activation="relu"),
+ layers.Dense(config.EMBED_DIM),
+ ],
+ name="encoder_{}_ffn".format(i),
+ )
+ ffn_output = ffn(attention_output)
+ ffn_output = layers.Dropout(0.1, name="encoder_{}_ffn_dropout".format(i))(
+ ffn_output
+ )
+ sequence_output = layers.LayerNormalization(
+ epsilon=1e-6, name="encoder_{}_ffn_layernormalization".format(i)
+ )(attention_output + ffn_output)
+ return sequence_output
+
+
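+# `reduction=None` keeps the unreduced per-token losses, so the per-position
+# sample weights (zero for unmasked tokens) can be applied before we sum them
+# in `compute_loss` below.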
+loss_fn = keras.losses.SparseCategoricalCrossentropy(reduction=None)
+loss_tracker = keras.metrics.Mean(name="loss")
+
+
+class MaskedLanguageModel(keras.Model):
+
+ def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
+
+ loss = loss_fn(y, y_pred, sample_weight)
+ loss_tracker.update_state(loss, sample_weight=sample_weight)
+ return keras.ops.sum(loss)
+
+ def compute_metrics(self, x, y, y_pred, sample_weight):
+
+ # Return a dict mapping metric names to current value
+ return {"loss": loss_tracker.result()}
+
+ @property
+ def metrics(self):
+ # We list our `Metric` objects here so that `reset_states()` can be
+ # called automatically at the start of each epoch
+ # or at the start of `evaluate()`.
+ # If you don't implement this property, you have to call
+ # `reset_states()` yourself at the time of your choosing.
+ return [loss_tracker]
+
+
+def create_masked_language_bert_model():
+ inputs = layers.Input((config.MAX_LEN,), dtype="int64")
+
+ word_embeddings = layers.Embedding(
+ config.VOCAB_SIZE, config.EMBED_DIM, name="word_embedding"
+ )(inputs)
+ position_embeddings = keras_hub.layers.PositionEmbedding(
+ sequence_length=config.MAX_LEN
+ )(word_embeddings)
+ embeddings = word_embeddings + position_embeddings
+
+ encoder_output = embeddings
+ for i in range(config.NUM_LAYERS):
+ encoder_output = bert_module(encoder_output, encoder_output, encoder_output, i)
+
+ mlm_output = layers.Dense(config.VOCAB_SIZE, name="mlm_cls", activation="softmax")(
+ encoder_output
+ )
+ mlm_model = MaskedLanguageModel(inputs, mlm_output, name="masked_bert_model")
+
+ optimizer = keras.optimizers.Adam(learning_rate=config.LR)
+ mlm_model.compile(optimizer=optimizer)
+ return mlm_model
+
+
+id2token = dict(enumerate(vectorize_layer.get_vocabulary()))
+token2id = {y: x for x, y in id2token.items()}
+
+
+class MaskedTextGenerator(keras.callbacks.Callback):
+ def __init__(self, sample_tokens, top_k=5):
+ self.sample_tokens = sample_tokens
+ self.k = top_k
+
+ def decode(self, tokens):
+ return " ".join([id2token[t] for t in tokens if t != 0])
+
+ def convert_ids_to_tokens(self, id):
+ return id2token[id]
+
+ def on_epoch_end(self, epoch, logs=None):
+ prediction = self.model.predict(self.sample_tokens)
+
+ masked_index = np.where(self.sample_tokens == mask_token_id)
+ masked_index = masked_index[1]
+ mask_prediction = prediction[0][masked_index]
+
+ top_indices = mask_prediction[0].argsort()[-self.k :][::-1]
+ values = mask_prediction[0][top_indices]
+
+ for i in range(len(top_indices)):
+ p = top_indices[i]
+ v = values[i]
+            tokens = np.copy(self.sample_tokens[0])
+ tokens[masked_index[0]] = p
+ result = {
+ "input_text": self.decode(sample_tokens[0].numpy()),
+ "prediction": self.decode(tokens),
+ "probability": v,
+ "predicted mask token": self.convert_ids_to_tokens(p),
+ }
+ pprint(result)
+
+
+sample_tokens = vectorize_layer(["I have watched this [mask] and it was awesome"])
+generator_callback = MaskedTextGenerator(sample_tokens.numpy())
+
+bert_masked_model = create_masked_language_bert_model()
+bert_masked_model.summary()
+
+"""
+## Train and Save
+"""
+
+bert_masked_model.fit(mlm_ds, epochs=5, callbacks=[generator_callback])
+bert_masked_model.save("bert_mlm_imdb.keras")
+
+"""
+## Fine-tune a sentiment classification model
+
+We will fine-tune our self-supervised model on a downstream task of sentiment classification.
+To do this, let's create a classifier by adding a pooling layer and a `Dense` layer on top of the
+pretrained BERT features.
+
+"""
+
+# Load pretrained bert model
+mlm_model = keras.models.load_model(
+ "bert_mlm_imdb.keras", custom_objects={"MaskedLanguageModel": MaskedLanguageModel}
+)
+pretrained_bert_model = keras.Model(
+ mlm_model.input, mlm_model.get_layer("encoder_0_ffn_layernormalization").output
+)
+
+# Freeze it
+pretrained_bert_model.trainable = False
+
+
+def create_classifier_bert_model():
+ inputs = layers.Input((config.MAX_LEN,), dtype="int64")
+ sequence_output = pretrained_bert_model(inputs)
+ pooled_output = layers.GlobalMaxPooling1D()(sequence_output)
+ hidden_layer = layers.Dense(64, activation="relu")(pooled_output)
+ outputs = layers.Dense(1, activation="sigmoid")(hidden_layer)
+    classifier_model = keras.Model(inputs, outputs, name="classification")
+    optimizer = keras.optimizers.Adam()
+    classifier_model.compile(
+        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
+    )
+    return classifier_model
+
+
+classifier_model = create_classifier_bert_model()
+classifier_model.summary()
+
+# Train the classifier with the BERT model frozen
+classifier_model.fit(
+ train_classifier_ds,
+ epochs=5,
+ validation_data=test_classifier_ds,
+)
+
+# Unfreeze the BERT model for fine-tuning
+pretrained_bert_model.trainable = True
+optimizer = keras.optimizers.Adam()
+classifier_model.compile(
+    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
+)
+classifier_model.fit(
+ train_classifier_ds,
+ epochs=5,
+ validation_data=test_classifier_ds,
+)
+
+"""
+## Create an end-to-end model and evaluate it
+
+When you want to deploy a model, it's best if it already includes its preprocessing
+pipeline, so that you don't have to reimplement the preprocessing logic in your
+production environment. Let's create an end-to-end model that incorporates
+the `TextVectorization` layer inside its `evaluate()` method, so that we can pass raw strings as input.
+"""
+
+
+# We create a custom Model to override the evaluate method so
+# that it first pre-processes the text data
+class ModelEndtoEnd(keras.Model):
+
+ def evaluate(self, inputs):
+ features = encode(inputs.review.values)
+ labels = inputs.sentiment.values
+ test_classifier_ds = (
+ tf.data.Dataset.from_tensor_slices((features, labels))
+ .shuffle(1000)
+ .batch(config.BATCH_SIZE)
+ )
+ return super().evaluate(test_classifier_ds)
+
+    # Mark the model as built; input preprocessing happens inside `evaluate()`.
+ def build(self, input_shape):
+ self.built = True
+
+
+def get_end_to_end(model):
+    inputs = model.inputs[0]
+    outputs = model.outputs
+ end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
+ optimizer = keras.optimizers.Adam(learning_rate=config.LR)
+ end_to_end_model.compile(
+ optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
+ )
+ return end_to_end_model
+
+
+end_to_end_classification_model = get_end_to_end(classifier_model)
+# Pass raw text dataframe to the model
+end_to_end_classification_model.evaluate(test_raw_classifier_ds)
diff --git a/knowledge_base/nlp/multi_label_classification.py b/knowledge_base/nlp/multi_label_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a9936f254aa0fb6a0e4dba63dd178daff6e272
--- /dev/null
+++ b/knowledge_base/nlp/multi_label_classification.py
@@ -0,0 +1,451 @@
+"""
+Title: Large-scale multi-label text classification
+Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)
+Date created: 2020/09/25
+Last modified: 2025/02/27
+Description: Implementing a large-scale multi-label text classification model.
+Accelerator: GPU
+Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+In this example, we will build a multi-label text classifier to predict the subject areas
+of arXiv papers from their abstract bodies. This type of classifier can be useful for
+conference submission portals like [OpenReview](https://openreview.net/). Given a paper
+abstract, the portal could provide suggestions for which areas the paper would
+best belong to.
+
+The dataset was collected using the
+[`arXiv` Python library](https://github.com/lukasschwab/arxiv.py)
+that provides a wrapper around the
+[original arXiv API](http://arxiv.org/help/api/index).
+To learn more about the data collection process, please refer to
+[this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb).
+Additionally, you can also find the dataset on
+[Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts).
+"""
+
+"""
+## Imports
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
+
+import keras
+from keras import layers, ops
+
+from sklearn.model_selection import train_test_split
+
+from ast import literal_eval
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+
+"""
+## Perform exploratory data analysis
+
+In this section, we first load the dataset into a `pandas` dataframe and then perform
+some basic exploratory data analysis (EDA).
+"""
+
+arxiv_data = pd.read_csv(
+ "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
+)
+arxiv_data.head()
+
+"""
+Our text features are present in the `summaries` column and their corresponding labels
+are in `terms`. As you can notice, there are multiple categories associated with a
+particular entry.
+"""
+
+print(f"There are {len(arxiv_data)} rows in the dataset.")
+
+"""
+Real-world data is noisy. One of the most commonly observed sources of noise is data
+duplication. Here we notice that our initial dataset has about 13k duplicate entries.
+"""
+
+total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
+print(f"There are {total_duplicate_titles} duplicate titles.")
+
+"""
+Before proceeding further, we drop these entries.
+"""
+
+arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
+print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")
+
+# There are some terms with occurrence as low as 1.
+print(sum(arxiv_data["terms"].value_counts() == 1))
+
+# How many unique terms?
+print(arxiv_data["terms"].nunique())
+
+"""
+As observed above, out of 3,157 unique combinations of `terms`, 2,321 occur only
+once. To prepare our train, validation, and test sets with
+[stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop
+these terms.
+"""
+
+# Filtering the rare terms.
+arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
+arxiv_data_filtered.shape
+
+"""
+## Convert the string labels to lists of strings
+
+The initial labels are represented as raw strings. Here we make them `List[str]` for a
+more compact representation.
+"""
+
+arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
+ lambda x: literal_eval(x)
+)
+arxiv_data_filtered["terms"].values[:5]
+
+"""
+## Use stratified splits because of class imbalance
+
+The dataset has a
+[class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset).
+So, to have a fair evaluation result, we need to ensure the datasets are sampled with
+stratification. To know more about different strategies to deal with the class imbalance
+problem, you can follow
+[this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
+For an end-to-end demonstration of classification with imbalanced data, refer to
+[Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/).
+"""
+
+test_split = 0.1
+
+# Initial train and test split.
+train_df, test_df = train_test_split(
+ arxiv_data_filtered,
+ test_size=test_split,
+ stratify=arxiv_data_filtered["terms"].values,
+)
+
+# Splitting the test set further into validation
+# and new test sets.
+val_df = test_df.sample(frac=0.5)
+test_df.drop(val_df.index, inplace=True)
+
+print(f"Number of rows in training set: {len(train_df)}")
+print(f"Number of rows in validation set: {len(val_df)}")
+print(f"Number of rows in test set: {len(test_df)}")
+
+"""
+## Multi-label binarization
+
+Now we preprocess our labels using the
+[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup)
+layer.
+"""
+
+# For RaggedTensor
+import tensorflow as tf
+
+terms = tf.ragged.constant(train_df["terms"].values)
+lookup = layers.StringLookup(output_mode="multi_hot")
+lookup.adapt(terms)
+vocab = lookup.get_vocabulary()
+
+
+def invert_multi_hot(encoded_labels):
+ """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
+ hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
+ return np.take(vocab, hot_indices)
+
+
+print("Vocabulary:\n")
+print(vocab)
+
+
+"""
+Here we are separating the individual unique classes available from the label
+pool and then using this information to represent a given label set with 0's and 1's.
+Below is an example.
+"""
+
+sample_label = train_df["terms"].iloc[0]
+print(f"Original label: {sample_label}")
+
+label_binarized = lookup([sample_label])
+print(f"Label-binarized representation: {label_binarized}")
+
+"""
+## Data preprocessing and `tf.data.Dataset` objects
+
+We first get percentile estimates of the sequence lengths. The purpose will be clear in a
+moment.
+"""
+
+train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()
+
+"""
+Notice that 50% of the abstracts have a length of 154 (you may get a different number
+based on the split). So, any number close to that value is a good enough approximation for the
+maximum sequence length.
+
+Now, we implement utilities to prepare our datasets.
+"""
+
+max_seqlen = 150
+batch_size = 128
+padding_token = ""
+auto = tf.data.AUTOTUNE
+
+
+def make_dataset(dataframe, is_train=True):
+ labels = tf.ragged.constant(dataframe["terms"].values)
+ label_binarized = lookup(labels).numpy()
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (dataframe["summaries"].values, label_binarized)
+ )
+ dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
+ return dataset.batch(batch_size)
+
+
+"""
+Now we can prepare the `tf.data.Dataset` objects.
+"""
+
+train_dataset = make_dataset(train_df, is_train=True)
+validation_dataset = make_dataset(val_df, is_train=False)
+test_dataset = make_dataset(test_df, is_train=False)
+
+"""
+## Dataset preview
+"""
+
+text_batch, label_batch = next(iter(train_dataset))
+
+for i, text in enumerate(text_batch[:5]):
+ label = label_batch[i].numpy()[None, ...]
+ print(f"Abstract: {text}")
+ print(f"Label(s): {invert_multi_hot(label[0])}")
+ print(" ")
+
+"""
+## Vectorization
+
+Before we feed the data to our model, we need to vectorize it (represent it in a numerical form).
+For that purpose, we will use the
+[`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
+It can operate as a part of your main model, so that the preprocessing logic does not
+have to be reimplemented at serving time. This greatly reduces the chances of
+training / serving skew.
+
+We first calculate the number of unique words present in the abstracts.
+"""
+
+# Source: https://stackoverflow.com/a/18937309/7636462
+vocabulary = set()
+train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
+vocabulary_size = len(vocabulary)
+print(vocabulary_size)
+
+
+"""
+We now create our vectorization layer and `map()` it over the `tf.data.Dataset`s created
+earlier.
+"""
+
+text_vectorizer = layers.TextVectorization(
+ max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
+)
+
+# `TextVectorization` layer needs to be adapted as per the vocabulary from our
+# training set.
+with tf.device("/CPU:0"):
+ text_vectorizer.adapt(train_dataset.map(lambda text, label: text))
+
+train_dataset = train_dataset.map(
+ lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
+).prefetch(auto)
+validation_dataset = validation_dataset.map(
+ lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
+).prefetch(auto)
+test_dataset = test_dataset.map(
+ lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
+).prefetch(auto)
+
+
+"""
+A batch of raw text will first go through the `TextVectorization` layer and it will
+generate their integer representations. Internally, the `TextVectorization` layer will
+first create bi-grams out of the sequences and then represent them using
+[TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then
+be passed to the shallow model responsible for text classification.
+
+To learn more about other possible configurations with `TextVectorization`, please consult
+the
+[official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
+
+**Note**: Setting the `max_tokens` argument to a pre-calculated vocabulary size is
+not a requirement.
+"""
+
+"""
+## Create a text classification model
+
+We will keep our model simple -- it will be a small stack of fully-connected layers with
+ReLU as the non-linearity.
+
+"""
+
+
+def make_model():
+ shallow_mlp_model = keras.Sequential(
+ [
+ layers.Dense(512, activation="relu"),
+ layers.Dense(256, activation="relu"),
+ layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
+ ] # More on why "sigmoid" has been used here in a moment.
+ )
+ return shallow_mlp_model
+
+
+"""
+## Train the model
+
+We will train our model using the binary crossentropy loss. This is because the labels
+are not disjoint. For a given abstract, we may have multiple categories. So, we will
+divide the prediction task into a series of multiple binary classification problems. This
+is also why we kept the activation function of the classification layer in our model to
+sigmoid. Researchers have used other combinations of loss function and activation
+function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932),
+Mahajan et al. used the softmax activation function and cross-entropy loss to train
+their models.
+
+There are several metrics that can be used in multi-label classification.
+To keep this code example narrow, we decided to use the
+[binary accuracy metric](https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class).
+For an explanation of why this metric is used, refer to this
+[pull-request](https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860).
+There are also other suitable metrics for multi-label classification, like
+[F1 Score](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score) or
+[Hamming loss](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss).
+"""
+
+epochs = 20
+
+shallow_mlp_model = make_model()
+shallow_mlp_model.compile(
+ loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
+)
+
+history = shallow_mlp_model.fit(
+ train_dataset, validation_data=validation_dataset, epochs=epochs
+)
+
+
+def plot_result(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_result("loss")
+plot_result("binary_accuracy")
+
+"""
+While training, we notice an initial sharp fall in the loss followed by a gradual decay.
+"""
+
+"""
+### Evaluate the model
+"""
+
+_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
+print(f"Categorical accuracy on the test set: {round(binary_acc * 100, 2)}%.")
+
+"""
+The trained model gives us a binary accuracy of ~99% on the test set.
+"""
+
+"""
+## Inference
+
+An important feature of the
+[preprocessing layers provided by Keras](https://keras.io/api/layers/preprocessing_layers/)
+is that they can be included inside a `tf.keras.Model`. We will export an inference model
+by including the `text_vectorization` layer on top of `shallow_mlp_model`. This will
+allow our inference model to directly operate on raw strings.
+
+**Note** that during training it is always preferable to use these preprocessing
+layers as a part of the data input pipeline rather than the model to avoid
+surfacing bottlenecks for the hardware accelerators. This also allows for
+asynchronous data processing.
+"""
+
+
+# We create a custom Model to override the predict method so
+# that it first vectorizes text data
+class ModelEndtoEnd(keras.Model):
+
+ def predict(self, inputs):
+ indices = text_vectorizer(inputs)
+ return super().predict(indices)
+
+
+def get_inference_model(model):
+    inputs = model.inputs
+    outputs = model.outputs
+ end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
+ end_to_end_model.compile(
+ optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
+ )
+ return end_to_end_model
+
+
+model_for_inference = get_inference_model(shallow_mlp_model)
+
+# Create a small dataset just for demonstrating inference.
+inference_dataset = make_dataset(test_df.sample(2), is_train=False)
+text_batch, label_batch = next(iter(inference_dataset))
+predicted_probabilities = model_for_inference.predict(text_batch)
+
+
+# Perform inference.
+for i, text in enumerate(text_batch[:5]):
+ label = label_batch[i].numpy()[None, ...]
+ print(f"Abstract: {text}")
+ print(f"Label(s): {invert_multi_hot(label[0])}")
+ top_3_labels = [
+ x
+ for _, x in sorted(
+ zip(predicted_probabilities[i], lookup.get_vocabulary()),
+ key=lambda pair: pair[0],
+ reverse=True,
+ )
+ ][:3]
+ print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
+ print(" ")
+
+"""
+The prediction results are not that great, but not below par for a simple model like
+ours. We can improve this performance with models that consider word order like LSTM or
+even those that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)).
+"""
+
+"""
+## Acknowledgements
+
+We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us
+tackle the multi-label binarization part and inverse-transforming the processed labels
+to the original form.
+
+Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.
+"""
diff --git a/knowledge_base/nlp/multimodal_entailment.py b/knowledge_base/nlp/multimodal_entailment.py
new file mode 100644
index 0000000000000000000000000000000000000000..dab6ca1063a01744ea9344ee6408f25970a4cd72
--- /dev/null
+++ b/knowledge_base/nlp/multimodal_entailment.py
@@ -0,0 +1,713 @@
+"""
+Title: Multimodal entailment
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/08/08
+Last modified: 2025/01/03
+Description: Training a multimodal model for predicting entailment.
+Accelerator: GPU
+Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+In this example, we will build and train a model for predicting multimodal entailment. We will be
+using the
+[multimodal entailment dataset](https://github.com/google-research-datasets/recognizing-multimodal-entailment)
+recently introduced by Google Research.
+
+### What is multimodal entailment?
+
+On social media platforms, to audit and moderate content
+we may want to find answers to the
+following questions in near real-time:
+
+* Does a given piece of information contradict the other?
+* Does a given piece of information imply the other?
+
+In NLP, this task is called analyzing _textual entailment_. However, that's only
+when the information comes from text content.
+In practice, it's often the case that the information available comes not just
+from text content, but from a multimodal combination of text, images, audio, video, etc.
+_Multimodal entailment_ is simply the extension of textual entailment to a variety
+of new input modalities.
+
+### Requirements
+
+This example requires TensorFlow 2.5 or higher. In addition, this converted version
+uses KerasHub for the BERT model
+([Devlin et al.](https://arxiv.org/abs/1810.04805)). TensorFlow Text can be installed
+using the following command:
+"""
+
+"""shell
+pip install -q tensorflow_text
+"""
+
+"""
+## Imports
+"""
+
+from sklearn.model_selection import train_test_split
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+import random
+import math
+from skimage.io import imread
+from skimage.transform import resize
+from PIL import Image
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
+
+import keras
+import keras_hub
+from keras.utils import PyDataset
+
+"""
+## Define a label map
+"""
+
+label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
+
+"""
+## Collect the dataset
+
+The original dataset is available
+[here](https://github.com/google-research-datasets/recognizing-multimodal-entailment).
+It comes with URLs of images which are hosted on Twitter's photo storage system called
+the
+[Photo Blob Storage (PBS for short)](https://blog.twitter.com/engineering/en_us/a/2012/blobstore-twitter-s-in-house-photo-storage-system).
+We will be working with the downloaded images along with additional data that comes with
+the original dataset. Thanks to
+[Nilabhra Roy Chowdhury](https://de.linkedin.com/in/nilabhraroychowdhury) who worked on
+preparing the image data.
+"""
+
+image_base_path = keras.utils.get_file(
+ "tweet_images",
+ "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
+ untar=True,
+)
+
+"""
+## Read the dataset and apply basic preprocessing
+"""
+
+df = pd.read_csv(
+ "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
+).iloc[
+ 0:1000
+] # Resources conservation since these are examples and not SOTA
+df.sample(10)
+
+"""
+The columns we are interested in are the following:
+
+* `text_1`
+* `image_1`
+* `text_2`
+* `image_2`
+* `label`
+
+The entailment task is formulated as the following:
+
+***Given the pairs of (`text_1`, `image_1`) and (`text_2`, `image_2`) do they entail (or
+not entail or contradict) each other?***
+
+We have the images already downloaded. `image_1` is downloaded with `id1` as its filename
+and `image_2` with `id2` as its filename. In the next step, we will add two
+more columns to `df` - the filepaths of the `image_1`s and `image_2`s.
+"""
+
+images_one_paths = []
+images_two_paths = []
+
+for idx in range(len(df)):
+ current_row = df.iloc[idx]
+ id_1 = current_row["id_1"]
+ id_2 = current_row["id_2"]
+    extension_one = current_row["image_1"].split(".")[-1]
+    extension_two = current_row["image_2"].split(".")[-1]
+
+    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
+    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")
+
+ images_one_paths.append(image_one_path)
+ images_two_paths.append(image_two_path)
+
+df["image_1_path"] = images_one_paths
+df["image_2_path"] = images_two_paths
+
+# Create another column containing the integer ids of
+# the string labels.
+df["label_idx"] = df["label"].apply(lambda x: label_map[x])
+
+"""
+## Dataset visualization
+"""
+
+
+def visualize(idx):
+ current_row = df.iloc[idx]
+ image_1 = plt.imread(current_row["image_1_path"])
+ image_2 = plt.imread(current_row["image_2_path"])
+ text_1 = current_row["text_1"]
+ text_2 = current_row["text_2"]
+ label = current_row["label"]
+
+ plt.subplot(1, 2, 1)
+ plt.imshow(image_1)
+ plt.axis("off")
+ plt.title("Image One")
+ plt.subplot(1, 2, 2)
+    plt.imshow(image_2)
+ plt.axis("off")
+ plt.title("Image Two")
+ plt.show()
+
+ print(f"Text one: {text_1}")
+ print(f"Text two: {text_2}")
+ print(f"Label: {label}")
+
+
+random_idx = random.choice(range(len(df)))
+visualize(random_idx)
+
+random_idx = random.choice(range(len(df)))
+visualize(random_idx)
+
+"""
+## Train/test split
+
+The dataset suffers from
+[class imbalance problem](https://developers.google.com/machine-learning/glossary#class-imbalanced-dataset).
+We can confirm that in the following cell.
+"""
+
+df["label"].value_counts()
+
+"""
+To account for that we will go for a stratified split.
+"""
+
+# 10% for test
+train_df, test_df = train_test_split(
+ df, test_size=0.1, stratify=df["label"].values, random_state=42
+)
+# 5% for validation
+train_df, val_df = train_test_split(
+ train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
+)
+
+print(f"Total training examples: {len(train_df)}")
+print(f"Total validation examples: {len(val_df)}")
+print(f"Total test examples: {len(test_df)}")
+
+"""
+## Data input pipeline
+
+KerasHub provides a
+[variety of models from the BERT family](https://keras.io/keras_hub/presets/).
+Each of those models comes with a
+corresponding preprocessing layer. You can learn more about these models and their
+preprocessing layers from
+[this resource](https://www.kaggle.com/models/keras/bert/keras/bert_base_en_uncased/2).
+
+To keep the runtime of this example relatively short, we will use the
+`bert_base_en_uncased` variant of the original BERT model.
+"""
+
+"""
+Text preprocessing using KerasHub
+"""
+
+text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
+ "bert_base_en_uncased",
+ sequence_length=128,
+)
+
+"""
+### Run the preprocessor on a sample input
+"""
+
+idx = random.choice(range(len(train_df)))
+row = train_df.iloc[idx]
+sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
+print(f"Text 1: {sample_text_1}")
+print(f"Text 2: {sample_text_2}")
+
+test_text = [sample_text_1, sample_text_2]
+text_preprocessed = text_preprocessor(test_text)
+
+print("Keys : ", list(text_preprocessed.keys()))
+print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
+print("Token Ids : ", text_preprocessed["token_ids"][0, :16])
+print(" Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)
+print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])
+print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
+print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])
+
+
+"""
+We will now create `tf.data.Dataset` objects from the dataframes.
+
+Note that the text inputs will be preprocessed as a part of the data input pipeline. But
+the preprocessing modules can also be a part of their corresponding BERT models. This
+helps reduce the training/serving skew and lets our models operate with raw text inputs.
+Follow [this tutorial](https://www.tensorflow.org/text/tutorials/classify_text_with_bert)
+to learn more about how to incorporate the preprocessing modules directly inside the
+models.
+"""
+
+
+def dataframe_to_dataset(dataframe):
+    ds = UnifiedPyDataset(
+        dataframe,
+        batch_size=32,
+        workers=4,
+    )
+    return ds
+
+
+"""
+### Preprocessing utilities
+"""
+
+bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
+
+
+def preprocess_text(text_1, text_2):
+ output = text_preprocessor([text_1, text_2])
+ output = {
+ feature: keras.ops.reshape(output[feature], [-1])
+ for feature in bert_input_features
+ }
+ return output
+
+
+"""
+### Create the final datasets, with a method adapted from the `PyDataset` docstring.
+"""
+
+
+class UnifiedPyDataset(PyDataset):
+ """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""
+
+ def __init__(
+ self,
+ df,
+ batch_size=32,
+ workers=4,
+ use_multiprocessing=False,
+ max_queue_size=10,
+ **kwargs,
+ ):
+ """
+ Args:
+ df: pandas DataFrame with data
+ batch_size: Batch size for dataset
+ workers: Number of workers to use for parallel loading (Keras)
+ use_multiprocessing: Whether to use multiprocessing
+ max_queue_size: Maximum size of the data queue for parallel loading
+ """
+ super().__init__(**kwargs)
+        self.dataframe = df
+        # image files
+        self.image_x_1 = self.dataframe["image_1_path"]
+        self.image_x_2 = self.dataframe["image_2_path"]
+        self.image_y = self.dataframe["label_idx"]
+ # text files
+ self.text_x_1 = self.dataframe["text_1"]
+ self.text_x_2 = self.dataframe["text_2"]
+ self.text_y = self.dataframe["label_idx"]
+ # general
+ self.batch_size = batch_size
+ self.workers = workers
+ self.use_multiprocessing = use_multiprocessing
+ self.max_queue_size = max_queue_size
+
+ def __getitem__(self, index):
+ """
+ Fetches a batch of data from the dataset at the given index.
+ """
+
+ # Return x, y for batch idx.
+ low = index * self.batch_size
+ # Cap upper bound at array length; the last batch may be smaller
+ # if the total number of items is not a multiple of batch size.
+ # image files
+ high_image_1 = min(low + self.batch_size, len(self.image_x_1))
+ high_image_2 = min(low + self.batch_size, len(self.image_x_2))
+ # text
+ high_text_1 = min(low + self.batch_size, len(self.text_x_1))
+        high_text_2 = min(low + self.batch_size, len(self.text_x_2))
+ # images files
+ batch_image_x_1 = self.image_x_1[low:high_image_1]
+ batch_image_y_1 = self.image_y[low:high_image_1]
+ batch_image_x_2 = self.image_x_2[low:high_image_2]
+ batch_image_y_2 = self.image_y[low:high_image_2]
+ # text files
+ batch_text_x_1 = self.text_x_1[low:high_text_1]
+ batch_text_y_1 = self.text_y[low:high_text_1]
+ batch_text_x_2 = self.text_x_2[low:high_text_2]
+ batch_text_y_2 = self.text_y[low:high_text_2]
+ # image number 1 inputs
+ image_1 = [
+ resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
+ ]
+        image_1 = [
+            # Some images carry an alpha channel; keep only the RGB channels so
+            # every image ends up with shape (128, 128, 3).
+            img[..., :3] if img.shape[-1] == 4 else img
+            for img in image_1
+        ]
+ image_1 = np.array(image_1)
+ # Both text inputs to the model, return a dict for inputs to BertBackbone
+ text = {
+ key: np.array(
+ [
+ d[key]
+ for d in [
+ preprocess_text(file_path1, file_path2)
+ for file_path1, file_path2 in zip(
+ batch_text_x_1, batch_text_x_2
+ )
+ ]
+ ]
+ )
+ for key in ["padding_mask", "token_ids", "segment_ids"]
+ }
+ # Image number 2 model inputs
+ image_2 = [
+ resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
+ ]
+        image_2 = [
+            # Some images carry an alpha channel; keep only the RGB channels so
+            # every image ends up with shape (128, 128, 3).
+            img[..., :3] if img.shape[-1] == 4 else img
+            for img in image_2
+        ]
+        # Stack the list of images into an np.ndarray
+ image_2 = np.array(image_2)
+ return (
+ {
+ "image_1": image_1,
+ "image_2": image_2,
+ "padding_mask": text["padding_mask"],
+ "segment_ids": text["segment_ids"],
+ "token_ids": text["token_ids"],
+ },
+            # Target labels
+ np.array(batch_image_y_1),
+ )
+
+ def __len__(self):
+ """
+ Returns the number of batches in the dataset.
+ """
+ return math.ceil(len(self.dataframe) / self.batch_size)
+
+
+"""
+Create train, validation and test datasets
+"""
+
+
+def prepare_dataset(dataframe):
+ ds = dataframe_to_dataset(dataframe)
+ return ds
+
+
+train_ds = prepare_dataset(train_df)
+validation_ds = prepare_dataset(val_df)
+test_ds = prepare_dataset(test_df)
+
+"""
+## Model building utilities
+
+Our final model will accept two images along with their text counterparts. While the
+images will be directly fed to the model, the text inputs will first be preprocessed and
+then fed into the model. Below is a visual illustration of this approach:
+
+
+
+The model consists of the following elements:
+
+* A standalone encoder for the images. We will use a
+[ResNet50V2](https://arxiv.org/abs/1603.05027) pre-trained on the ImageNet-1k dataset for
+this.
+* A standalone encoder for the text. A pre-trained BERT will be used for this.
+
+After extracting the individual embeddings, they will be projected in an identical space.
+Finally, their projections will be concatenated and be fed to the final classification
+layer.
+
+This is a multi-class classification problem involving the following classes:
+
+* NoEntailment
+* Implies
+* Contradictory
+
+`project_embeddings()`, `create_vision_encoder()`, and `create_text_encoder()` utilities
+are referred from [this example](https://keras.io/examples/nlp/nl_image_search/).
+"""
+
+"""
+Projection utilities
+"""
+
+
+def project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+):
+ projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
+ for _ in range(num_projection_layers):
+ x = keras.ops.nn.gelu(projected_embeddings)
+ x = keras.layers.Dense(projection_dims)(x)
+ x = keras.layers.Dropout(dropout_rate)(x)
+ x = keras.layers.Add()([projected_embeddings, x])
+ projected_embeddings = keras.layers.LayerNormalization()(x)
+ return projected_embeddings
+
+
+"""
+Vision encoder utilities
+"""
+
+
+def create_vision_encoder(
+ num_projection_layers, projection_dims, dropout_rate, trainable=False
+):
+ # Load the pre-trained ResNet50V2 model to be used as the base encoder.
+ resnet_v2 = keras.applications.ResNet50V2(
+ include_top=False, weights="imagenet", pooling="avg"
+ )
+ # Set the trainability of the base encoder.
+ for layer in resnet_v2.layers:
+ layer.trainable = trainable
+
+ # Receive the images as inputs.
+ image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
+ image_2 = keras.Input(shape=(128, 128, 3), name="image_2")
+
+ # Preprocess the input image.
+ preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
+ preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)
+
+ # Generate the embeddings for the images using the resnet_v2 model
+ # concatenate them.
+ embeddings_1 = resnet_v2(preprocessed_1)
+ embeddings_2 = resnet_v2(preprocessed_2)
+ embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])
+
+ # Project the embeddings produced by the model.
+ outputs = project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+ )
+ # Create the vision encoder model.
+ return keras.Model([image_1, image_2], outputs, name="vision_encoder")
+
+
+"""
+Text encoder utilities
+"""
+
+
+def create_text_encoder(
+ num_projection_layers, projection_dims, dropout_rate, trainable=False
+):
+ # Load the pre-trained BERT BackBone using KerasHub.
+ bert = keras_hub.models.BertBackbone.from_preset(
+ "bert_base_en_uncased", num_classes=3
+ )
+
+ # Set the trainability of the base encoder.
+ bert.trainable = trainable
+
+ # Receive the text as inputs.
+ bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
+ inputs = {
+ feature: keras.Input(shape=(256,), dtype="int32", name=feature)
+ for feature in bert_input_features
+ }
+
+ # Generate embeddings for the preprocessed text using the BERT model.
+ embeddings = bert(inputs)["pooled_output"]
+
+ # Project the embeddings produced by the model.
+ outputs = project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+ )
+ # Create the text encoder model.
+ return keras.Model(inputs, outputs, name="text_encoder")
+
+
+"""
+Multimodal model utilities
+"""
+
+
+def create_multimodal_model(
+ num_projection_layers=1,
+ projection_dims=256,
+ dropout_rate=0.1,
+ vision_trainable=False,
+ text_trainable=False,
+):
+ # Receive the images as inputs.
+ image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
+ image_2 = keras.Input(shape=(128, 128, 3), name="image_2")
+
+ # Receive the text as inputs.
+ bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
+ text_inputs = {
+ feature: keras.Input(shape=(256,), dtype="int32", name=feature)
+ for feature in bert_input_features
+ }
+ text_inputs = list(text_inputs.values())
+ # Create the encoders.
+ vision_encoder = create_vision_encoder(
+ num_projection_layers, projection_dims, dropout_rate, vision_trainable
+ )
+ text_encoder = create_text_encoder(
+ num_projection_layers, projection_dims, dropout_rate, text_trainable
+ )
+
+ # Fetch the embedding projections.
+ vision_projections = vision_encoder([image_1, image_2])
+ text_projections = text_encoder(text_inputs)
+
+ # Concatenate the projections and pass through the classification layer.
+ concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
+ outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
+ return keras.Model([image_1, image_2, *text_inputs], outputs)
+
+
+multimodal_model = create_multimodal_model()
+keras.utils.plot_model(multimodal_model, show_shapes=True)
+
+"""
+You can inspect the structure of the individual encoders as well by setting the
+`expand_nested` argument of `plot_model()` to `True`. You are encouraged
+to play with the different hyperparameters involved in building this model and
+observe how the final performance is affected.
+"""
+
+"""
+## Compile and train the model
+"""
+
+multimodal_model.compile(
+ optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+)
+
+history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
+
+"""
+## Evaluate the model
+"""
+
+_, acc = multimodal_model.evaluate(test_ds)
+print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
+
+"""
+## Additional notes regarding training
+
+**Incorporating regularization**:
+
+The training logs suggest that the model is starting to overfit and may have benefitted
+from regularization. Dropout ([Srivastava et al.](https://jmlr.org/papers/v15/srivastava14a.html))
+is a simple yet powerful regularization technique that we can use in our model.
+But how should we apply it here?
+
+We could always introduce Dropout (`keras.layers.Dropout`) in between different layers of the model.
+But here is another recipe. Our model expects inputs from two different data modalities.
+What if either of the modalities is not present during inference? To account for this,
+we can introduce Dropout to the individual projections just before they get concatenated:
+
+```python
+vision_projections = keras.layers.Dropout(rate)(vision_projections)
+text_projections = keras.layers.Dropout(rate)(text_projections)
+concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
+```
+
+**Attending to what matters**:
+
+Do all parts of the images correspond equally to their textual counterparts? It's likely
+not the case. To make our model only focus on the most important bits of the images that relate
+well to their corresponding textual parts we can use "cross-attention":
+
+```python
+# Embeddings.
+vision_projections = vision_encoder([image_1, image_2])
+text_projections = text_encoder(text_inputs)
+
+# Cross-attention (Luong-style).
+query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
+ [vision_projections, text_projections]
+)
+# Concatenate.
+concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
+contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
+```
+
+To see this in action, refer to
+[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment_attn.ipynb).
+
+**Handling class imbalance**:
+
+The dataset suffers from class imbalance. Investigating the confusion matrix of the
+above model reveals that it performs poorly on the minority classes. If we had used a
+weighted loss then the training would have been more guided. You can check out
+[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment.ipynb)
+that takes class-imbalance into account during model training.
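+
+A minimal sketch of that idea (assuming the `train_df` dataframe and its
+`label_idx` column from earlier; the weights are inverse class frequencies):
+
+```python
+counts = train_df["label_idx"].value_counts().to_dict()
+total = sum(counts.values())
+class_weight = {idx: total / (len(counts) * n) for idx, n in counts.items()}
+
+multimodal_model.fit(
+    train_ds, validation_data=validation_ds, epochs=1, class_weight=class_weight
+)
+```
+
+If `class_weight` is not supported by your input pipeline, the same weights can be
+folded into a sample-weighted loss instead.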
+
+**Using only text inputs**:
+
+Also, what if we had only incorporated text inputs for the entailment task? Because of
+the nature of the text inputs encountered on social media platforms, text inputs alone
+would have hurt the final performance. Under a similar training setup, by only using
+text inputs we get to 67.14% top-1 accuracy on the same test set. Refer to
+[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/text_entailment.ipynb)
+for details.
+
+Finally, here is a table comparing different approaches taken for the entailment task:
+
+| Type | Standard Cross-entropy | Loss-weighted Cross-entropy | Focal Loss |
+|:---: |:---: |:---: |:---: |
+| Multimodal | 77.86% | 67.86% | 86.43% |
+| Only text | 67.14% | 11.43% | 37.86% |
+
+You can check out [this repository](https://git.io/JR0HU) to learn more about how the
+experiments were conducted to obtain these numbers.
+"""
+
+"""
+## Final remarks
+
+* The architecture we used in this example is too large for the number of data points
+available for training. It's going to benefit from more data.
+* We used a smaller variant of the original BERT model. Chances are high that with a
+larger variant, this performance will be improved. TensorFlow Hub
+[provides](https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub)
+a number of different BERT models that you can experiment with.
+* We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment
+task would likely have resulted in better performance.
+* We built a simple baseline model for the multimodal entailment task. There are various
+approaches that have been proposed to tackle the entailment problem.
+[This presentation deck](https://docs.google.com/presentation/d/1mAB31BCmqzfedreNZYn4hsKPFmgHA9Kxz219DzyRY3c/edit?usp=sharing)
+from the
+[Recognizing Multimodal Entailment](https://multimodal-entailment.github.io/)
+tutorial provides a comprehensive overview.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/multimodal_entailment)
+"""
diff --git a/knowledge_base/nlp/multiple_choice_task_with_transfer_learning.py b/knowledge_base/nlp/multiple_choice_task_with_transfer_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..757b1abe6798ed929c7c99f91e71f4cd25139748
--- /dev/null
+++ b/knowledge_base/nlp/multiple_choice_task_with_transfer_learning.py
@@ -0,0 +1,560 @@
+"""
+Title: MultipleChoice Task with Transfer Learning
+Author: Md Awsafur Rahman
+Date created: 2023/09/14
+Last modified: 2025/06/16
+Description: Use pre-trained NLP models for the multiple-choice task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we will demonstrate how to perform the **MultipleChoice** task by
+fine-tuning a pre-trained DebertaV3 model. In this task, several candidate answers are
+provided along with a context, and the model is trained to select the correct answer,
+unlike in question answering. We will use the SWAG dataset to demonstrate this example.
+"""
+
+"""
+## Setup
+"""
+
+"""shell
+"""
+
+import keras_hub
+import keras
+import tensorflow as tf # For tf.data only.
+
+import numpy as np
+import pandas as pd
+
+import matplotlib.pyplot as plt
+
+"""
+## Dataset
+In this example, we'll use the **SWAG** dataset for the multiple-choice task.
+"""
+
+"""shell
+wget "https://github.com/rowanz/swagaf/archive/refs/heads/master.zip" -O swag.zip
+unzip -q swag.zip
+"""
+
+"""shell
+ls swagaf-master/data
+"""
+
+"""
+## Configuration
+"""
+
+
+class CFG:
+ preset = "deberta_v3_extra_small_en" # Name of pretrained models
+ sequence_length = 200 # Input sequence length
+ seed = 42 # Random seed
+ epochs = 5 # Training epochs
+ batch_size = 8 # Batch size
+ augment = True # Augmentation (Shuffle Options)
+
+
+"""
+## Reproducibility
+Sets the random seed to produce similar results on each run.
+"""
+
+keras.utils.set_random_seed(CFG.seed)
+
+
+"""
+## Meta Data
+* **train.csv** - will be used for training.
+* `sent1` and `sent2`: these fields show how a sentence starts, and if you put the two
+together, you get the `startphrase` field.
+* `ending0`-`ending3`: suggest possible endings for how the sentence can end, but only one of
+them is correct.
+* `label`: identifies the correct sentence ending.
+
+* **val.csv** - similar to `train.csv` but will be used for validation.
+"""
+
+# Train data
+train_df = pd.read_csv(
+ "swagaf-master/data/train.csv", index_col=0
+) # Read CSV file into a DataFrame
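+# Keep only a small fraction of the data so the example runs quickly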
+train_df = train_df.sample(frac=0.02)
+print("# Train Data: {:,}".format(len(train_df)))
+
+# Valid data
+valid_df = pd.read_csv(
+ "swagaf-master/data/val.csv", index_col=0
+) # Read CSV file into a DataFrame
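+# Keep only a small fraction of the data so the example runs quickly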
+valid_df = valid_df.sample(frac=0.02)
+print("# Valid Data: {:,}".format(len(valid_df)))
+
+"""
+## Contextualize Options
+
+Our approach entails furnishing the model with question and answer pairs, as opposed to
+employing a single question for all five options. In practice, this signifies that for
+the five options, we will supply the model with the same set of five questions combined
+with each respective answer choice (e.g., `(Q + A)`, `(Q + B)`, and so on). This analogy
+draws parallels to the practice of revisiting a question multiple times during an exam to
+promote a deeper understanding of the problem at hand.
+
+> Notably, in the context of SWAG dataset, question is the start of a sentence and
+options are possible ending of that sentence.
+"""
+
+
+# Define a function to create options based on the prompt and choices
+def make_options(row):
+ row["options"] = [
+ f"{row.startphrase}\n{row.ending0}", # Option 0
+ f"{row.startphrase}\n{row.ending1}", # Option 1
+ f"{row.startphrase}\n{row.ending2}", # Option 2
+ f"{row.startphrase}\n{row.ending3}",
+ ] # Option 3
+ return row
+
+
+"""
+Apply the `make_options` function to each row of the dataframe
+"""
+
+train_df = train_df.apply(make_options, axis=1)
+valid_df = valid_df.apply(make_options, axis=1)
+
+"""
+## Preprocessing
+
+**What it does:** The preprocessor takes input strings and transforms them into a
+dictionary (`token_ids`, `padding_mask`) containing preprocessed tensors. This process
+starts with tokenization, where input strings are converted into sequences of token IDs.
+
+**Why it's important:** Initially, raw text data is complex and challenging for modeling
+due to its high dimensionality. By converting text into a compact set of tokens, such as
+transforming `"The quick brown fox"` into `["the", "qu", "##ick", "br", "##own", "fox"]`,
+we simplify the data. Many models rely on special tokens and additional tensors to
+understand input. These tokens help divide input and identify padding, among other tasks.
+Making all sequences the same length through padding boosts computational efficiency,
+making subsequent steps smoother.
+
+Explore the following pages to access the available preprocessing and tokenizer layers in
+**KerasHub**:
+- [Preprocessing](https://keras.io/api/keras_hub/preprocessing_layers/)
+- [Tokenizers](https://keras.io/api/keras_hub/tokenizers/)
+"""
+
+preprocessor = keras_hub.models.DebertaV3Preprocessor.from_preset(
+ preset=CFG.preset, # Name of the model
+ sequence_length=CFG.sequence_length, # Max sequence length, will be padded if shorter
+)
+
+"""
+Now, let's examine what the output shape of the preprocessing layer looks like. The
+output shape of the layer can be represented as $(num\_choices, sequence\_length)$.
+"""
+
+outs = preprocessor(train_df.options.iloc[0]) # Process options for the first row
+
+# Display the shape of each processed output
+for k, v in outs.items():
+ print(k, ":", v.shape)
+
+"""
+We'll use the `preprocess_fn` function to transform each text option using the
+`dataset.map(preprocess_fn)` method.
+"""
+
+
+def preprocess_fn(text, label=None):
+ text = preprocessor(text) # Preprocess text
+ return (
+ (text, label) if label is not None else text
+ ) # Return processed text and label if available
+
+
+"""
+## Augmentation
+
+In this notebook, we'll experiment with an interesting augmentation technique,
+`option_shuffle`. Since we're providing the model with one option at a time, we can
+introduce a shuffle to the order of options. For instance, options `[A, C, E, D, B]`
+would be rearranged as `[D, B, A, E, C]`. This practice will help the model focus on the
+content of the options themselves, rather than being influenced by their positions.
+
+**Note:** Even though the `option_shuffle` function is written in pure
+TensorFlow, it can be used with any backend (e.g. JAX, PyTorch) as it is only used
+in a `tf.data.Dataset` pipeline, which is compatible with Keras 3 routines.
+"""
+
+
+def option_shuffle(options, labels, prob=0.50, seed=None):
+ if tf.random.uniform([]) > prob: # Shuffle probability check
+ return options, labels
+ # Shuffle indices of options and labels in the same order
+ indices = tf.random.shuffle(tf.range(tf.shape(options)[0]), seed=seed)
+ # Shuffle options and labels
+ options = tf.gather(options, indices)
+ labels = tf.gather(labels, indices)
+ return options, labels
+
+
+"""
+In the following function, we'll merge all augmentation functions to apply to the text.
+These augmentations will be applied to the data using the `dataset.map(augment_fn)`
+approach.
+"""
+
+
+def augment_fn(text, label=None):
+    if label is None:
+        return text
+    text, label = option_shuffle(text, label, prob=0.5)  # Shuffle the options
+    return (text, label)
+
+
+"""
+## DataLoader
+
+The code below sets up a robust data flow pipeline using `tf.data.Dataset` for data
+processing. Notable aspects of `tf.data` include its ability to simplify pipeline
+construction and represent components in sequences.
+
+To learn more about `tf.data`, refer to this
+[documentation](https://www.tensorflow.org/guide/data).
+"""
+
+
+def build_dataset(
+ texts,
+ labels=None,
+ batch_size=32,
+ cache=False,
+ augment=False,
+ repeat=False,
+ shuffle=1024,
+):
+ AUTO = tf.data.AUTOTUNE # AUTOTUNE option
+ slices = (
+ (texts,)
+ if labels is None
+ else (texts, keras.utils.to_categorical(labels, num_classes=4))
+ ) # Create slices
+ ds = tf.data.Dataset.from_tensor_slices(slices) # Create dataset from slices
+ ds = ds.cache() if cache else ds # Cache dataset if enabled
+ if augment: # Apply augmentation if enabled
+ ds = ds.map(augment_fn, num_parallel_calls=AUTO)
+ ds = ds.map(preprocess_fn, num_parallel_calls=AUTO) # Map preprocessing function
+ ds = ds.repeat() if repeat else ds # Repeat dataset if enabled
+ opt = tf.data.Options() # Create dataset options
+ if shuffle:
+ ds = ds.shuffle(shuffle, seed=CFG.seed) # Shuffle dataset if enabled
+ opt.experimental_deterministic = False
+ ds = ds.with_options(opt) # Set dataset options
+ ds = ds.batch(batch_size, drop_remainder=True) # Batch dataset
+ ds = ds.prefetch(AUTO) # Prefetch next batch
+ return ds # Return the built dataset
+
+
+"""
+Now let's create the train and valid dataloaders using the above function.
+"""
+
+# Build train dataloader
+train_texts = train_df.options.tolist() # Extract training texts
+train_labels = train_df.label.tolist() # Extract training labels
+train_ds = build_dataset(
+ train_texts,
+ train_labels,
+ batch_size=CFG.batch_size,
+ cache=True,
+ shuffle=True,
+ repeat=True,
+ augment=CFG.augment,
+)
+
+# Build valid dataloader
+valid_texts = valid_df.options.tolist() # Extract validation texts
+valid_labels = valid_df.label.tolist() # Extract validation labels
+valid_ds = build_dataset(
+ valid_texts,
+ valid_labels,
+ batch_size=CFG.batch_size,
+ cache=True,
+ shuffle=False,
+ repeat=False,
+ augment=False,
+)
+
+
+"""
+## LR Schedule
+
+Implementing a learning rate scheduler is crucial for transfer learning. The learning
+rate initiates at `lr_start` and gradually tapers down to `lr_min` using a **cosine**
+curve.
+
+**Importance:** A well-structured learning rate schedule is essential for efficient model
+training, ensuring optimal convergence and avoiding issues such as overshooting or
+stagnation.
+"""
+
+import math
+
+
+def get_lr_callback(batch_size=8, mode="cos", epochs=10, plot=False):
+ lr_start, lr_max, lr_min = 1.0e-6, 0.6e-6 * batch_size, 1e-6
+ lr_ramp_ep, lr_sus_ep = 2, 0
+
+ def lrfn(epoch): # Learning rate update function
+ if epoch < lr_ramp_ep:
+ lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
+ elif epoch < lr_ramp_ep + lr_sus_ep:
+ lr = lr_max
+ else:
+ decay_total_epochs, decay_epoch_index = (
+ epochs - lr_ramp_ep - lr_sus_ep + 3,
+ epoch - lr_ramp_ep - lr_sus_ep,
+ )
+ phase = math.pi * decay_epoch_index / decay_total_epochs
+ lr = (lr_max - lr_min) * 0.5 * (1 + math.cos(phase)) + lr_min
+ return lr
+
+ if plot: # Plot lr curve if plot is True
+ plt.figure(figsize=(10, 5))
+ plt.plot(
+ np.arange(epochs),
+ [lrfn(epoch) for epoch in np.arange(epochs)],
+ marker="o",
+ )
+ plt.xlabel("epoch")
+ plt.ylabel("lr")
+ plt.title("LR Scheduler")
+ plt.show()
+
+ return keras.callbacks.LearningRateScheduler(
+ lrfn, verbose=False
+ ) # Create lr callback
+
+
+_ = get_lr_callback(CFG.batch_size, plot=True)
+
+"""
+## Callbacks
+
+The function below will gather all the training callbacks, such as `lr_scheduler`,
+`model_checkpoint`.
+"""
+
+
+def get_callbacks():
+ callbacks = []
+ lr_cb = get_lr_callback(CFG.batch_size) # Get lr callback
+ ckpt_cb = keras.callbacks.ModelCheckpoint(
+ f"best.keras",
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=False,
+ mode="max",
+ ) # Get Model checkpoint callback
+ callbacks.extend([lr_cb, ckpt_cb]) # Add lr and checkpoint callbacks
+ return callbacks # Return the list of callbacks
+
+
+callbacks = get_callbacks()
+
+"""
+## MultipleChoice Model
+
+
+
+
+
+"""
+
+"""
+
+### Pre-trained Models
+
+The `KerasHub` library provides comprehensive, ready-to-use implementations of popular
+NLP model architectures. It features a variety of pre-trained models including `Bert`,
+`Roberta`, `DebertaV3`, and more. In this notebook, we'll showcase the usage of
+`DebertaV3`. However, feel free to explore all available models in the [KerasHub
+documentation](https://keras.io/api/keras_hub/models/). Also for a deeper understanding
+of `KerasHub`, refer to the informative [getting started
+guide](https://keras.io/guides/keras_hub/getting_started/).
+
+Our approach involves using `keras_hub.models.XXClassifier` to process each question and
+option pair (e.g. (Q+A), (Q+B), etc.), generating logits. These logits are then combined
+and passed through a softmax function to produce the final output.
+"""
+
+"""
+
+### Classifier for Multiple-Choice Tasks
+
+When dealing with multiple-choice questions, instead of giving the model the question and
+all options together `(Q + A + B + C ...)`, we provide the model with one option at a
+time along with the question. For instance, `(Q + A)`, `(Q + B)`, and so on. Once we have
+the prediction scores (logits) for all options, we combine them using the `Softmax`
+function to get the ultimate result. If we had given all options at once to the model,
+the text's length would increase, making it harder for the model to handle. The picture
+below illustrates this idea:
+
+
+
+
+
+From a coding perspective, remember that we use the same model for all four options, with
+shared weights. Despite the figure suggesting four separate models, they are, in fact,
+one model with shared weights. Another point to consider is the input shapes of
+the Classifier and the MultipleChoice task.
+
+* Input shape for **Multiple Choice**: $(batch\_size, num\_choices, seq\_length)$
+* Input shape for **Classifier**: $(batch\_size, seq\_length)$
+
+Certainly, it's clear that we can't directly give the data for the multiple-choice task
+to the model because the input shapes don't match. To handle this, we'll use **slicing**.
+This means we'll separate the features of each option, like $feature_{(Q + A)}$ and
+$feature_{(Q + B)}$, and give them one by one to the NLP classifier. After we get the
+prediction scores $logits_{(Q + A)}$ and $logits_{(Q + B)}$ for all the options, we'll
+use the Softmax function, like $\operatorname{Softmax}([logits_{(Q + A)}, logits_{(Q +
+B)}])$, to combine them. This final step helps us make the ultimate decision or choice.
+
+> Note that in the classifier, we set `num_classes=1` instead of `4`. This is because the
+classifier produces a single output for each option. When dealing with four options,
+these individual outputs are joined together and then processed through a softmax
+function to generate the final result, which has a dimension of `4`.
+"""
+
+
+# Selects one option from five
+class SelectOption(keras.layers.Layer):
+ def __init__(self, index, **kwargs):
+ super().__init__(**kwargs)
+ self.index = index
+
+ def call(self, inputs):
+ # Selects a specific slice from the inputs tensor
+ return inputs[:, self.index, :]
+
+ def get_config(self):
+        # For serializing the model
+ base_config = super().get_config()
+ config = {
+ "index": self.index,
+ }
+ return {**base_config, **config}
+
+
+def build_model():
+ # Define input layers
+ inputs = {
+ "token_ids": keras.Input(shape=(4, None), dtype="int32", name="token_ids"),
+ "padding_mask": keras.Input(
+ shape=(4, None), dtype="int32", name="padding_mask"
+ ),
+ }
+ # Create a DebertaV3Classifier model
+ classifier = keras_hub.models.DebertaV3Classifier.from_preset(
+ CFG.preset,
+ preprocessor=None,
+        num_classes=1,  # one output per option; with four options, 4 outputs in total
+ )
+ logits = []
+    # Loop through each option (Q+A), (Q+B), etc. and compute the associated logits
+ for option_idx in range(4):
+ option = {
+ k: SelectOption(option_idx, name=f"{k}_{option_idx}")(v)
+ for k, v in inputs.items()
+ }
+ logit = classifier(option)
+ logits.append(logit)
+
+ # Compute final output
+ logits = keras.layers.Concatenate(axis=-1)(logits)
+ outputs = keras.layers.Softmax(axis=-1)(logits)
+ model = keras.Model(inputs, outputs)
+
+ # Compile the model with optimizer, loss, and metrics
+ model.compile(
+ optimizer=keras.optimizers.AdamW(5e-6),
+ loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.02),
+ metrics=[
+ keras.metrics.CategoricalAccuracy(name="accuracy"),
+ ],
+ jit_compile=True,
+ )
+ return model
+
+
+# Build the model
+model = build_model()
+
+"""
+Let's check out the model summary to get a better insight into the model.
+"""
+
+model.summary()
+
+"""
+Finally, let's visually check the model structure to make sure everything is in place.
+"""
+
+keras.utils.plot_model(model, show_shapes=True)
+
+"""
+## Training
+"""
+
+# Start training the model
+history = model.fit(
+ train_ds,
+ epochs=CFG.epochs,
+ validation_data=valid_ds,
+ callbacks=callbacks,
+ steps_per_epoch=int(len(train_df) / CFG.batch_size),
+ verbose=1,
+)
+
+"""
+## Inference
+"""
+
+# Make predictions using the trained model on the validation data
+predictions = model.predict(
+ valid_ds,
+ batch_size=CFG.batch_size, # max batch size = valid size
+ verbose=1,
+)
+
+# Format predictions and true answers
+pred_answers = np.arange(4)[np.argsort(-predictions)][:, 0]  # index of the top-scoring option
+true_answers = valid_df.label.values
+
+# Check 5 Predictions
+print("# Predictions\n")
+for i in range(0, 50, 10):
+ row = valid_df.iloc[i]
+ question = row.startphrase
+ pred_answer = f"ending{pred_answers[i]}"
+ true_answer = f"ending{true_answers[i]}"
+ print(f"โ Sentence {i+1}:\n{question}\n")
+ print(f"โ
True Ending: {true_answer}\n >> {row[true_answer]}\n")
+ print(f"๐ค Predicted Ending: {pred_answer}\n >> {row[pred_answer]}\n")
+ print("-" * 90, "\n")
+
+"""
+## Reference
+* [Multiple Choice with
+HF](https://twitter.com/johnowhitaker/status/1689790373454041089?s=20)
+* [KerasHub](https://keras.io/api/keras_hub/)
+* [BirdCLEF23: Pretraining is All you Need
+[Train]](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train)
+* [Triple Stratified KFold with
+TFRecords](https://www.kaggle.com/code/cdeotte/triple-stratified-kfold-with-tfrecords)
+"""
diff --git a/knowledge_base/nlp/ner_transformers.py b/knowledge_base/nlp/ner_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d8c89afb47ca3f1bf13a7234e4e445452ae6bd
--- /dev/null
+++ b/knowledge_base/nlp/ner_transformers.py
@@ -0,0 +1,355 @@
+"""
+Title: Named Entity Recognition using Transformers
+Author: [Varun Singh](https://www.linkedin.com/in/varunsingh2/)
+Date created: 2021/06/23
+Last modified: 2024/04/05
+Description: NER using Transformers and data from the CoNLL 2003 shared task.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+Named Entity Recognition (NER) is the process of identifying named entities in text.
+Examples of named entities are "Person", "Location", "Organization", "Dates", etc. NER is
+essentially a token classification task where every token is classified into one or more
+predetermined categories.
+
+In this exercise, we will train a simple Transformer based model to perform NER. We will
+be using the data from CoNLL 2003 shared task. For more information about the dataset,
+please visit [the dataset website](https://www.clips.uantwerpen.be/conll2003/ner/).
+However, since obtaining this data requires an additional step of getting a free license, we will be using
+HuggingFace's datasets library which contains a processed version of this dataset.
+"""
+
+"""
+## Install the open source datasets library from HuggingFace
+
+We also download the script used to evaluate NER models.
+"""
+
+"""shell
+pip3 install datasets
+wget https://raw.githubusercontent.com/sighsmile/conlleval/master/conlleval.py
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import ops
+import numpy as np
+import tensorflow as tf
+from keras import layers
+from datasets import load_dataset
+from collections import Counter
+from conlleval import evaluate
+
+"""
+We will be using the transformer implementation from this fantastic
+[example](https://keras.io/examples/nlp/text_classification_with_transformer/).
+
+Let's start by defining a `TransformerBlock` layer:
+"""
+
+
+class TransformerBlock(layers.Layer):
+ def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
+ super().__init__()
+ self.att = keras.layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.ffn = keras.Sequential(
+ [
+ keras.layers.Dense(ff_dim, activation="relu"),
+ keras.layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
+ self.dropout1 = keras.layers.Dropout(rate)
+ self.dropout2 = keras.layers.Dropout(rate)
+
+ def call(self, inputs, training=False):
+ attn_output = self.att(inputs, inputs)
+ attn_output = self.dropout1(attn_output, training=training)
+ out1 = self.layernorm1(inputs + attn_output)
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output, training=training)
+ return self.layernorm2(out1 + ffn_output)
+
+
+"""
+Next, let's define a `TokenAndPositionEmbedding` layer:
+"""
+
+
+class TokenAndPositionEmbedding(layers.Layer):
+ def __init__(self, maxlen, vocab_size, embed_dim):
+ super().__init__()
+ self.token_emb = keras.layers.Embedding(
+ input_dim=vocab_size, output_dim=embed_dim
+ )
+ self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
+
+ def call(self, inputs):
+ maxlen = ops.shape(inputs)[-1]
+ positions = ops.arange(start=0, stop=maxlen, step=1)
+ position_embeddings = self.pos_emb(positions)
+ token_embeddings = self.token_emb(inputs)
+ return token_embeddings + position_embeddings
+
+
+"""
+## Build the NER model class as a `keras.Model` subclass
+"""
+
+
+class NERModel(keras.Model):
+ def __init__(
+ self, num_tags, vocab_size, maxlen=128, embed_dim=32, num_heads=2, ff_dim=32
+ ):
+ super().__init__()
+ self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
+ self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
+ self.dropout1 = layers.Dropout(0.1)
+ self.ff = layers.Dense(ff_dim, activation="relu")
+ self.dropout2 = layers.Dropout(0.1)
+ self.ff_final = layers.Dense(num_tags, activation="softmax")
+
+ def call(self, inputs, training=False):
+ x = self.embedding_layer(inputs)
+ x = self.transformer_block(x)
+ x = self.dropout1(x, training=training)
+ x = self.ff(x)
+ x = self.dropout2(x, training=training)
+ x = self.ff_final(x)
+ return x
+
+
+"""
+## Load the CoNLL 2003 dataset from the datasets library and process it
+"""
+
+conll_data = load_dataset("conll2003")
+
+"""
+We will export this data to a tab-separated file format which will be easy to read as a
+`tf.data.Dataset` object.
+"""
+
+
+def export_to_file(export_file_path, data):
+ with open(export_file_path, "w") as f:
+ for record in data:
+ ner_tags = record["ner_tags"]
+ tokens = record["tokens"]
+ if len(tokens) > 0:
+ f.write(
+ str(len(tokens))
+ + "\t"
+ + "\t".join(tokens)
+ + "\t"
+ + "\t".join(map(str, ner_tags))
+ + "\n"
+ )
+
+
+os.makedirs("data", exist_ok=True)
+export_to_file("./data/conll_train.txt", conll_data["train"])
+export_to_file("./data/conll_val.txt", conll_data["validation"])
+
+"""
+## Make the NER label lookup table
+
+NER labels are usually provided in IOB, IOB2 or IOBES formats. Checkout this link for
+more information:
+[Wikipedia](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging))
+
+Note that we start our label numbering from 1 since 0 will be reserved for padding. We
+have a total of 10 labels: 9 from the NER dataset and one for padding.
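+
+For example, under IOB2 the (hypothetical) sentence "John lives in New York"
+would be tagged `B-PER O O B-LOC I-LOC`: the first token of an entity gets a
+`B-` prefix, and tokens continuing the same entity get an `I-` prefix.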
+"""
+
+
+def make_tag_lookup_table():
+ iob_labels = ["B", "I"]
+ ner_labels = ["PER", "ORG", "LOC", "MISC"]
+ all_labels = [(label1, label2) for label2 in ner_labels for label1 in iob_labels]
+ all_labels = ["-".join([a, b]) for a, b in all_labels]
+ all_labels = ["[PAD]", "O"] + all_labels
+    return dict(enumerate(all_labels))
+
+
+mapping = make_tag_lookup_table()
+print(mapping)
+
+"""
+Get a list of all tokens in the training dataset. This will be used to create the
+vocabulary.
+"""
+
+all_tokens = sum(conll_data["train"]["tokens"], [])
+all_tokens_array = np.array(list(map(str.lower, all_tokens)))
+
+counter = Counter(all_tokens_array)
+print(len(counter))
+
+num_tags = len(mapping)
+vocab_size = 20000
+
+# We only take (vocab_size - 2) most common words from the training data since
+# the `StringLookup` class uses 2 additional tokens - one denoting an unknown
+# token and another one denoting a masking token
+vocabulary = [token for token, count in counter.most_common(vocab_size - 2)]
+
+# The StringLookup layer will convert tokens to token IDs
+lookup_layer = keras.layers.StringLookup(vocabulary=vocabulary)
+
+"""
+Create 2 new `Dataset` objects from the training and validation data
+"""
+
+train_data = tf.data.TextLineDataset("./data/conll_train.txt")
+val_data = tf.data.TextLineDataset("./data/conll_val.txt")
+
+"""
+Print out one line to make sure it looks good. The first record in the line is the number of tokens.
+After that, we will have all the tokens followed by all the NER tags.
+"""
+
+print(list(train_data.take(1).as_numpy_iterator()))
+
+"""
+We will be using the following map function to transform the data in the dataset:
+"""
+
+
+def map_record_to_training_data(record):
+ record = tf.strings.split(record, sep="\t")
+ length = tf.strings.to_number(record[0], out_type=tf.int32)
+ tokens = record[1 : length + 1]
+ tags = record[length + 1 :]
+ tags = tf.strings.to_number(tags, out_type=tf.int64)
+ tags += 1
+ return tokens, tags
+
+
+def lowercase_and_convert_to_ids(tokens):
+ tokens = tf.strings.lower(tokens)
+ return lookup_layer(tokens)
+
+
+# We use `padded_batch` here because each record in the dataset has a
+# different length.
+batch_size = 32
+train_dataset = (
+ train_data.map(map_record_to_training_data)
+ .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))
+ .padded_batch(batch_size)
+)
+val_dataset = (
+ val_data.map(map_record_to_training_data)
+ .map(lambda x, y: (lowercase_and_convert_to_ids(x), y))
+ .padded_batch(batch_size)
+)
+
+ner_model = NERModel(num_tags, vocab_size, embed_dim=32, num_heads=4, ff_dim=64)
+
+"""
+We will be using a custom loss function that will ignore the loss from padded tokens.
+"""
+
+
+class CustomNonPaddingTokenLoss(keras.losses.Loss):
+ def __init__(self, name="custom_ner_loss"):
+ super().__init__(name=name)
+
+ def call(self, y_true, y_pred):
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=False, reduction=None
+ )
+ loss = loss_fn(y_true, y_pred)
+ mask = ops.cast((y_true > 0), dtype="float32")
+ loss = loss * mask
+ return ops.sum(loss) / ops.sum(mask)
+
+
+loss = CustomNonPaddingTokenLoss()
+
+"""
+## Compile and fit the model
+"""
+
+tf.config.run_functions_eagerly(True)
+ner_model.compile(optimizer="adam", loss=loss)
+ner_model.fit(train_dataset, epochs=10)
+
+
+def tokenize_and_convert_to_ids(text):
+ tokens = text.split()
+ return lowercase_and_convert_to_ids(tokens)
+
+
+# Sample inference using the trained model
+sample_input = tokenize_and_convert_to_ids(
+ "eu rejects german call to boycott british lamb"
+)
+sample_input = ops.reshape(sample_input, [1, -1])
+print(sample_input)
+
+output = ner_model.predict(sample_input)
+prediction = np.argmax(output, axis=-1)[0]
+prediction = [mapping[i] for i in prediction]
+
+# eu -> B-ORG, german -> B-MISC, british -> B-MISC
+print(prediction)
+
+"""
+## Metrics calculation
+
+Here is a function to calculate the metrics. The function calculates F1 score for the
+overall NER dataset as well as individual scores for each NER tag.
+"""
+
+
+def calculate_metrics(dataset):
+ all_true_tag_ids, all_predicted_tag_ids = [], []
+
+ for x, y in dataset:
+ output = ner_model.predict(x, verbose=0)
+ predictions = ops.argmax(output, axis=-1)
+ predictions = ops.reshape(predictions, [-1])
+
+ true_tag_ids = ops.reshape(y, [-1])
+
+ mask = (true_tag_ids > 0) & (predictions > 0)
+ true_tag_ids = true_tag_ids[mask]
+ predicted_tag_ids = predictions[mask]
+
+ all_true_tag_ids.append(true_tag_ids)
+ all_predicted_tag_ids.append(predicted_tag_ids)
+
+ all_true_tag_ids = np.concatenate(all_true_tag_ids)
+ all_predicted_tag_ids = np.concatenate(all_predicted_tag_ids)
+
+ predicted_tags = [mapping[tag] for tag in all_predicted_tag_ids]
+ real_tags = [mapping[tag] for tag in all_true_tag_ids]
+
+ evaluate(real_tags, predicted_tags)
+
+
+calculate_metrics(val_dataset)
+
+"""
+## Conclusions
+
+In this exercise, we created a simple transformer based named entity recognition model.
+We trained it on the CoNLL 2003 shared task data and got an overall F1 score of around 70%.
+State-of-the-art NER models fine-tuned on pretrained models such as BERT or ELECTRA can easily
+achieve a much higher F1 score (between 90% and 95%) on this dataset, owing to the inherent
+knowledge of words gained during pretraining and the use of subword tokenization.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/ner-with-transformers)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ner_with_transformers).
+"""
diff --git a/knowledge_base/nlp/neural_machine_translation_with_keras_hub.py b/knowledge_base/nlp/neural_machine_translation_with_keras_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd260c5c42183041eaadb93cf57c9f2c797a995
--- /dev/null
+++ b/knowledge_base/nlp/neural_machine_translation_with_keras_hub.py
@@ -0,0 +1,510 @@
+"""
+Title: English-to-Spanish translation with KerasHub
+Author: [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2022/05/26
+Last modified: 2024/04/30
+Description: Use KerasHub to train a sequence-to-sequence Transformer model on the machine translation task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+KerasHub provides building blocks for NLP (model layers, tokenizers, metrics, etc.) and
+makes it convenient to construct NLP pipelines.
+
+In this example, we'll use KerasHub layers to build an encoder-decoder Transformer
+model, and train it on the English-to-Spanish machine translation task.
+
+This example is based on the
+[English-to-Spanish NMT
+example](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
+by [fchollet](https://twitter.com/fchollet). The original example is more low-level
+and implements layers from scratch, whereas this example uses KerasHub to show
+some more advanced approaches, such as subword tokenization and using metrics
+to compute the quality of generated translations.
+
+You'll learn how to:
+
+- Tokenize text using `keras_hub.tokenizers.WordPieceTokenizer`.
+- Implement a sequence-to-sequence Transformer model using KerasHub's
+`keras_hub.layers.TransformerEncoder`, `keras_hub.layers.TransformerDecoder` and
+`keras_hub.layers.TokenAndPositionEmbedding` layers, and train it.
+- Use `keras_hub.samplers` to generate translations of unseen input sentences
+ using the greedy decoding strategy!
+
+Don't worry if you aren't familiar with KerasHub. This tutorial will start with
+the basics. Let's dive right in!
+"""
+
+"""
+## Setup
+
+Before we start implementing the pipeline, let's import all the libraries we need.
+"""
+
+"""shell
+pip install -q --upgrade rouge-score
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import keras_hub
+import pathlib
+import random
+
+import keras
+from keras import ops
+
+import tensorflow.data as tf_data
+from tensorflow_text.tools.wordpiece_vocab import (
+ bert_vocab_from_dataset as bert_vocab,
+)
+
+"""
+Let's also define our parameters/hyperparameters.
+"""
+
+BATCH_SIZE = 64
+EPOCHS = 1 # This should be at least 10 for convergence
+MAX_SEQUENCE_LENGTH = 40
+ENG_VOCAB_SIZE = 15000
+SPA_VOCAB_SIZE = 15000
+
+EMBED_DIM = 256
+INTERMEDIATE_DIM = 2048
+NUM_HEADS = 8
+
+"""
+## Downloading the data
+
+We'll be working with an English-to-Spanish translation dataset
+provided by [Anki](https://www.manythings.org/anki/). Let's download it:
+"""
+
+text_file = keras.utils.get_file(
+ fname="spa-eng.zip",
+ origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
+ extract=True,
+)
+text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"
+
+"""
+## Parsing the data
+
+Each line contains an English sentence and its corresponding Spanish sentence.
+The English sentence is the *source sequence* and the Spanish one is the *target sequence*.
+Before adding the text to a list, we convert it to lowercase.
+"""
+
+with open(text_file) as f:
+ lines = f.read().split("\n")[:-1]
+text_pairs = []
+for line in lines:
+ eng, spa = line.split("\t")
+ eng = eng.lower()
+ spa = spa.lower()
+ text_pairs.append((eng, spa))
+
+"""
+Here's what our sentence pairs look like:
+"""
+
+for _ in range(5):
+ print(random.choice(text_pairs))
+
+"""
+Now, let's split the sentence pairs into a training set, a validation set,
+and a test set.
+"""
+
+random.shuffle(text_pairs)
+num_val_samples = int(0.15 * len(text_pairs))
+num_train_samples = len(text_pairs) - 2 * num_val_samples
+train_pairs = text_pairs[:num_train_samples]
+val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
+test_pairs = text_pairs[num_train_samples + num_val_samples :]
+
+print(f"{len(text_pairs)} total pairs")
+print(f"{len(train_pairs)} training pairs")
+print(f"{len(val_pairs)} validation pairs")
+print(f"{len(test_pairs)} test pairs")
+
+
+"""
+## Tokenizing the data
+
+We'll define two tokenizers - one for the source language (English), and the other
+for the target language (Spanish). We'll be using
+`keras_hub.tokenizers.WordPieceTokenizer` to tokenize the text.
+`keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary
+and has functions for tokenizing the text, and detokenizing sequences of tokens.
+
+Before we define the two tokenizers, we first need to train them on the dataset
+we have. The WordPiece tokenization algorithm is a subword tokenization algorithm;
+training it on a corpus gives us a vocabulary of subwords. A subword tokenizer
+is a compromise between word tokenizers (word tokenizers need very large
+vocabularies for good coverage of input words), and character tokenizers
+(characters don't really encode meaning like words do). Luckily, KerasHub
+makes it very simple to train WordPiece on a corpus with the
+`keras_hub.tokenizers.compute_word_piece_vocabulary` utility.
+"""
+
+
+def train_word_piece(text_samples, vocab_size, reserved_tokens):
+ word_piece_ds = tf_data.Dataset.from_tensor_slices(text_samples)
+ vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
+ word_piece_ds.batch(1000).prefetch(2),
+ vocabulary_size=vocab_size,
+ reserved_tokens=reserved_tokens,
+ )
+ return vocab
+
+
+"""
+Every vocabulary has a few special, reserved tokens. We have four such tokens:
+
+- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence
+length when the input sequence length is shorter than the maximum sequence length.
+- `"[UNK]"` - Unknown token.
+- `"[START]"` - Token that marks the start of the input sequence.
+- `"[END]"` - Token that marks the end of the input sequence.
+"""
+
+reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]
+
+eng_samples = [text_pair[0] for text_pair in train_pairs]
+eng_vocab = train_word_piece(eng_samples, ENG_VOCAB_SIZE, reserved_tokens)
+
+spa_samples = [text_pair[1] for text_pair in train_pairs]
+spa_vocab = train_word_piece(spa_samples, SPA_VOCAB_SIZE, reserved_tokens)
+
+"""
+Let's see some tokens!
+"""
+
+print("English Tokens: ", eng_vocab[100:110])
+print("Spanish Tokens: ", spa_vocab[100:110])
+
+"""
+Now, let's define the tokenizers. We will configure the tokenizers with the
+vocabularies trained above.
+"""
+
+eng_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
+ vocabulary=eng_vocab, lowercase=False
+)
+spa_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
+ vocabulary=spa_vocab, lowercase=False
+)
+
+"""
+Let's try and tokenize a sample from our dataset! To verify whether the text has
+been tokenized correctly, we can also detokenize the list of tokens back to the
+original text.
+"""
+
+eng_input_ex = text_pairs[0][0]
+eng_tokens_ex = eng_tokenizer.tokenize(eng_input_ex)
+print("English sentence: ", eng_input_ex)
+print("Tokens: ", eng_tokens_ex)
+print(
+ "Recovered text after detokenizing: ",
+ eng_tokenizer.detokenize(eng_tokens_ex),
+)
+
+print()
+
+spa_input_ex = text_pairs[0][1]
+spa_tokens_ex = spa_tokenizer.tokenize(spa_input_ex)
+print("Spanish sentence: ", spa_input_ex)
+print("Tokens: ", spa_tokens_ex)
+print(
+ "Recovered text after detokenizing: ",
+ spa_tokenizer.detokenize(spa_tokens_ex),
+)
+
+"""
+## Format datasets
+
+Next, we'll format our datasets.
+
+At each training step, the model will seek to predict target words N+1 (and beyond)
+using the source sentence and the target words 0 to N.
+
+As such, the training dataset will yield a tuple `(inputs, targets)`, where:
+
+- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
+`encoder_inputs` is the tokenized source sentence and `decoder_inputs` is the target
+sentence "so far",
+that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target
+sentence.
+- `target` is the target sentence offset by one step:
+it provides the next words in the target sentence -- what the model will try to predict.
+
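+For example (hypothetical token IDs): if the packed Spanish sentence is
+`[START, 5, 7, 9, END]`, then `decoder_inputs` is `[START, 5, 7, 9]` and the
+target is `[5, 7, 9, END]`.
+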
+We will add special tokens, `"[START]"` and `"[END]"`, to the input Spanish
+sentence after tokenizing the text. We will also pad the input to a fixed length.
+This can be easily done using `keras_hub.layers.StartEndPacker`.
+"""
+
+
+def preprocess_batch(eng, spa):
+ batch_size = ops.shape(spa)[0]
+
+ eng = eng_tokenizer(eng)
+ spa = spa_tokenizer(spa)
+
+ # Pad `eng` to `MAX_SEQUENCE_LENGTH`.
+ eng_start_end_packer = keras_hub.layers.StartEndPacker(
+ sequence_length=MAX_SEQUENCE_LENGTH,
+ pad_value=eng_tokenizer.token_to_id("[PAD]"),
+ )
+ eng = eng_start_end_packer(eng)
+
+ # Add special tokens (`"[START]"` and `"[END]"`) to `spa` and pad it as well.
+ spa_start_end_packer = keras_hub.layers.StartEndPacker(
+ sequence_length=MAX_SEQUENCE_LENGTH + 1,
+ start_value=spa_tokenizer.token_to_id("[START]"),
+ end_value=spa_tokenizer.token_to_id("[END]"),
+ pad_value=spa_tokenizer.token_to_id("[PAD]"),
+ )
+ spa = spa_start_end_packer(spa)
+
+ return (
+ {
+ "encoder_inputs": eng,
+ "decoder_inputs": spa[:, :-1],
+ },
+ spa[:, 1:],
+ )
+
+
+def make_dataset(pairs):
+ eng_texts, spa_texts = zip(*pairs)
+ eng_texts = list(eng_texts)
+ spa_texts = list(spa_texts)
+ dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
+ dataset = dataset.batch(BATCH_SIZE)
+ dataset = dataset.map(preprocess_batch, num_parallel_calls=tf_data.AUTOTUNE)
+    # Cache the preprocessed data, then reshuffle each epoch and prefetch batches.
+    return dataset.cache().shuffle(2048).prefetch(16)
+
+
+train_ds = make_dataset(train_pairs)
+val_ds = make_dataset(val_pairs)
+
+"""
+Let's take a quick look at the sequence shapes
+(we have batches of 64 pairs, and all sequences are 40 steps long):
+"""
+
+for inputs, targets in train_ds.take(1):
+ print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
+ print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
+ print(f"targets.shape: {targets.shape}")
+
+
+"""
+## Building the model
+
+Now, let's move on to the exciting part - defining our model!
+We first need an embedding layer, i.e., a vector for every token in our input sequence.
+This embedding layer can be initialised randomly. We also need a positional
+embedding layer which encodes the word order in the sequence. The convention is
+to add these two embeddings. KerasHub has a `keras_hub.layers.TokenAndPositionEmbedding`
+layer which does all of the above steps for us.
+
+Our sequence-to-sequence Transformer consists of a `keras_hub.layers.TransformerEncoder`
+layer and a `keras_hub.layers.TransformerDecoder` layer chained together.
+
+The source sequence will be passed to `keras_hub.layers.TransformerEncoder`, which
+will produce a new representation of it. This new representation will then be passed
+to the `keras_hub.layers.TransformerDecoder`, together with the target sequence
+so far (target words 0 to N). The `keras_hub.layers.TransformerDecoder` will
+then seek to predict the next words in the target sequence (N+1 and beyond).
+
+A key detail that makes this possible is causal masking.
+The `keras_hub.layers.TransformerDecoder` sees the entire sequence at once, and
+thus we must make sure that it only uses information from target tokens 0 to N
+when predicting token N+1 (otherwise, it could use information from the future,
+which would result in a model that cannot be used at inference time). Causal masking
+is enabled by default in `keras_hub.layers.TransformerDecoder`.
+
+We also need to mask the padding tokens (`"[PAD]"`). For this, we can set the
+`mask_zero` argument of the `keras_hub.layers.TokenAndPositionEmbedding` layer
+to True. This will then be propagated to all subsequent layers.
+"""
+
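+"""
+To build intuition for causal masking (a toy sketch, independent of the model
+below): position `i` may only attend to positions `j <= i`, which corresponds
+to a lower-triangular mask of ones.
+"""
+
+print(ops.tril(ops.ones((4, 4), dtype="int32")))  # Causal mask for 4 tokens.
+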
+# Encoder
+encoder_inputs = keras.Input(shape=(None,), name="encoder_inputs")
+
+x = keras_hub.layers.TokenAndPositionEmbedding(
+ vocabulary_size=ENG_VOCAB_SIZE,
+ sequence_length=MAX_SEQUENCE_LENGTH,
+ embedding_dim=EMBED_DIM,
+)(encoder_inputs)
+
+encoder_outputs = keras_hub.layers.TransformerEncoder(
+ intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
+)(inputs=x)
+encoder = keras.Model(encoder_inputs, encoder_outputs)
+
+
+# Decoder
+decoder_inputs = keras.Input(shape=(None,), name="decoder_inputs")
+encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")
+
+x = keras_hub.layers.TokenAndPositionEmbedding(
+ vocabulary_size=SPA_VOCAB_SIZE,
+ sequence_length=MAX_SEQUENCE_LENGTH,
+ embedding_dim=EMBED_DIM,
+)(decoder_inputs)
+
+x = keras_hub.layers.TransformerDecoder(
+ intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
+)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
+x = keras.layers.Dropout(0.5)(x)
+decoder_outputs = keras.layers.Dense(SPA_VOCAB_SIZE, activation="softmax")(x)
+decoder = keras.Model(
+ [
+ decoder_inputs,
+ encoded_seq_inputs,
+ ],
+ decoder_outputs,
+)
+decoder_outputs = decoder([decoder_inputs, encoder_outputs])
+
+transformer = keras.Model(
+ [encoder_inputs, decoder_inputs],
+ decoder_outputs,
+ name="transformer",
+)
+
+"""
+## Training our model
+
+We'll use accuracy as a quick way to monitor training progress on the validation data.
+Note that machine translation typically uses BLEU scores as well as other metrics,
+rather than accuracy. However, in order to use metrics like ROUGE, BLEU, etc. we
+would have to decode the probabilities and generate the text. Text generation is
+computationally expensive, and performing this during training is not recommended.
+
+Here we only train for 1 epoch, but to get the model to actually converge
+you should train for at least 10 epochs.
+"""
+
+transformer.summary()
+transformer.compile(
+ "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+)
+transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
+
+"""
+## Decoding test sentences (qualitative analysis)
+
+Finally, let's demonstrate how to translate brand new English sentences.
+We simply feed into the model the tokenized English sentence
+as well as the target token `"[START]"`. The model outputs probabilities of the
+next token. We then repeatedly generate the next token conditioned on the
+tokens generated so far, until we hit the token `"[END]"`.
+
+For decoding, we will use the `keras_hub.samplers` module from
+KerasHub. Greedy decoding is a text decoding method that outputs the most
+likely next token at each time step, i.e., the token with the highest probability.
+"""
+
+
+def decode_sequences(input_sentences):
+ batch_size = 1
+
+ # Tokenize the encoder input.
+ encoder_input_tokens = ops.convert_to_tensor(eng_tokenizer(input_sentences))
+ if len(encoder_input_tokens[0]) < MAX_SEQUENCE_LENGTH:
+ pads = ops.full((1, MAX_SEQUENCE_LENGTH - len(encoder_input_tokens[0])), 0)
+ encoder_input_tokens = ops.concatenate(
+ [encoder_input_tokens.to_tensor(), pads], 1
+ )
+
+ # Define a function that outputs the next token's probability given the
+ # input sequence.
+ def next(prompt, cache, index):
+ logits = transformer([encoder_input_tokens, prompt])[:, index - 1, :]
+ # Ignore hidden states for now; only needed for contrastive search.
+ hidden_states = None
+ return logits, hidden_states, cache
+
+ # Build a prompt of length 40 with a start token and padding tokens.
+ length = 40
+ start = ops.full((batch_size, 1), spa_tokenizer.token_to_id("[START]"))
+ pad = ops.full((batch_size, length - 1), spa_tokenizer.token_to_id("[PAD]"))
+ prompt = ops.concatenate((start, pad), axis=-1)
+
+ generated_tokens = keras_hub.samplers.GreedySampler()(
+ next,
+ prompt,
+ stop_token_ids=[spa_tokenizer.token_to_id("[END]")],
+ index=1, # Start sampling after start token.
+ )
+ generated_sentences = spa_tokenizer.detokenize(generated_tokens)
+ return generated_sentences
+
+
+test_eng_texts = [pair[0] for pair in test_pairs]
+for i in range(2):
+ input_sentence = random.choice(test_eng_texts)
+ translated = decode_sequences([input_sentence])
+ translated = translated.numpy()[0].decode("utf-8")
+ translated = (
+ translated.replace("[PAD]", "")
+ .replace("[START]", "")
+ .replace("[END]", "")
+ .strip()
+ )
+ print(f"** Example {i} **")
+ print(input_sentence)
+ print(translated)
+ print()
+
+"""
+## Evaluating our model (quantitative analysis)
+
+There are many metrics which are used for text generation tasks. Here, to
+evaluate translations generated by our model, let's compute the ROUGE-1 and
+ROUGE-2 scores. Essentially, ROUGE-N is a score based on the number of common
+n-grams between the reference text and the generated text. ROUGE-1 and ROUGE-2
+use the number of common unigrams and bigrams, respectively.
+
+We will calculate the score over 30 test samples (since decoding is an
+expensive process).
+"""
+
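+"""
+As a tiny worked example of the idea (hypothetical sentences): the candidate
+"the cat slept" shares two of its three unigrams with the reference
+"the cat sat", so its ROUGE-1 precision and recall are both 2/3.
+"""
+
+toy_rouge = keras_hub.metrics.RougeN(order=1)
+toy_rouge("the cat sat", "the cat slept")
+print(toy_rouge.result())  # precision = recall ~ 0.667
+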
+rouge_1 = keras_hub.metrics.RougeN(order=1)
+rouge_2 = keras_hub.metrics.RougeN(order=2)
+
+for test_pair in test_pairs[:30]:
+ input_sentence = test_pair[0]
+ reference_sentence = test_pair[1]
+
+ translated_sentence = decode_sequences([input_sentence])
+ translated_sentence = translated_sentence.numpy()[0].decode("utf-8")
+ translated_sentence = (
+ translated_sentence.replace("[PAD]", "")
+ .replace("[START]", "")
+ .replace("[END]", "")
+ .strip()
+ )
+
+ rouge_1(reference_sentence, translated_sentence)
+ rouge_2(reference_sentence, translated_sentence)
+
+print("ROUGE-1 Score: ", rouge_1.result())
+print("ROUGE-2 Score: ", rouge_2.result())
+
+"""
+After 10 epochs, the scores are as follows:
+
+| | **ROUGE-1** | **ROUGE-2** |
+|:-------------:|:-----------:|:-----------:|
+| **Precision** | 0.568 | 0.374 |
+| **Recall** | 0.615 | 0.394 |
+| **F1 Score** | 0.579 | 0.381 |
+"""
diff --git a/knowledge_base/nlp/neural_machine_translation_with_transformer.py b/knowledge_base/nlp/neural_machine_translation_with_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..54558533a61aa4f287e83ce5047f58b3e5a36b4c
--- /dev/null
+++ b/knowledge_base/nlp/neural_machine_translation_with_transformer.py
@@ -0,0 +1,507 @@
+"""
+Title: English-to-Spanish translation with a sequence-to-sequence Transformer
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2021/05/26
+Last modified: 2024/11/18
+Description: Implementing a sequence-to-sequence Transformer and training it on a machine translation task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we'll build a sequence-to-sequence Transformer model, which
+we'll train on an English-to-Spanish machine translation task.
+
+You'll learn how to:
+
+- Vectorize text using the Keras `TextVectorization` layer.
+- Implement a `TransformerEncoder` layer, a `TransformerDecoder` layer,
+and a `PositionalEmbedding` layer.
+- Prepare data for training a sequence-to-sequence model.
+- Use the trained model to generate translations of never-seen-before
+input sentences (sequence-to-sequence inference).
+
+The code featured here is adapted from the book
+[Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition)
+(chapter 11: Deep learning for text).
+The present example is fairly barebones, so for detailed explanations of
+how each building block works, as well as the theory behind Transformers,
+I recommend reading the book.
+"""
+"""
+## Setup
+"""
+
+# We set the backend to TensorFlow. The code works with
+# both `tensorflow` and `torch`. It does not work with JAX
+# due to the behavior of `jax.numpy.tile` in a jit scope
+# (used in `TransformerDecoder.get_causal_attention_mask()`):
+# `tile` in JAX does not support a dynamic `reps` argument.
+# You can make the code work in JAX by wrapping the
+# inside of the `get_causal_attention_mask` method in
+# a decorator to prevent jit compilation:
+# `with jax.ensure_compile_time_eval():`.
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import pathlib
+import random
+import string
+import re
+import numpy as np
+
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
+
+import keras
+from keras import layers
+from keras import ops
+from keras.layers import TextVectorization
+
+"""
+## Downloading the data
+
+We'll be working with an English-to-Spanish translation dataset
+provided by [Anki](https://www.manythings.org/anki/). Let's download it:
+"""
+
+text_file = keras.utils.get_file(
+ fname="spa-eng.zip",
+ origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
+ extract=True,
+)
+text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"
+
+"""
+## Parsing the data
+
+Each line contains an English sentence and its corresponding Spanish sentence.
+The English sentence is the *source sequence* and the Spanish one is the *target sequence*.
+We prepend the token `"[start]"` and we append the token `"[end]"` to the Spanish sentence.
+"""
+
+with open(text_file) as f:
+ lines = f.read().split("\n")[:-1]
+text_pairs = []
+for line in lines:
+ eng, spa = line.split("\t")
+ spa = "[start] " + spa + " [end]"
+ text_pairs.append((eng, spa))
+
+"""
+Here's what our sentence pairs look like:
+"""
+
+for _ in range(5):
+ print(random.choice(text_pairs))
+
+"""
+Now, let's split the sentence pairs into a training set, a validation set,
+and a test set.
+"""
+
+random.shuffle(text_pairs)
+num_val_samples = int(0.15 * len(text_pairs))
+num_train_samples = len(text_pairs) - 2 * num_val_samples
+train_pairs = text_pairs[:num_train_samples]
+val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
+test_pairs = text_pairs[num_train_samples + num_val_samples :]
+
+print(f"{len(text_pairs)} total pairs")
+print(f"{len(train_pairs)} training pairs")
+print(f"{len(val_pairs)} validation pairs")
+print(f"{len(test_pairs)} test pairs")
+
+"""
+## Vectorizing the text data
+
+We'll use two instances of the `TextVectorization` layer to vectorize the text
+data (one for English and one for Spanish),
+that is to say, to turn the original strings into integer sequences
+where each integer represents the index of a word in a vocabulary.
+
+The English layer will use the default string standardization (strip punctuation characters)
+and splitting scheme (split on whitespace), while
+the Spanish layer will use a custom standardization, where we add the character
+`"ยฟ"` to the set of punctuation characters to be stripped.
+
+Note: in a production-grade machine translation model, I would not recommend
+stripping the punctuation characters in either language. Instead, I would recommend turning
+each punctuation character into its own token,
+which you could achieve by providing a custom `split` function to the `TextVectorization` layer.
+"""
+
+strip_chars = string.punctuation + "¿"
+strip_chars = strip_chars.replace("[", "")
+strip_chars = strip_chars.replace("]", "")
+
+vocab_size = 15000
+sequence_length = 20
+batch_size = 64
+
+
+def custom_standardization(input_string):
+ lowercase = tf_strings.lower(input_string)
+ return tf_strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
+
+
+eng_vectorization = TextVectorization(
+ max_tokens=vocab_size,
+ output_mode="int",
+ output_sequence_length=sequence_length,
+)
+spa_vectorization = TextVectorization(
+ max_tokens=vocab_size,
+ output_mode="int",
+ output_sequence_length=sequence_length + 1,
+ standardize=custom_standardization,
+)
+train_eng_texts = [pair[0] for pair in train_pairs]
+train_spa_texts = [pair[1] for pair in train_pairs]
+eng_vectorization.adapt(train_eng_texts)
+spa_vectorization.adapt(train_spa_texts)
+
+"""
+Next, we'll format our datasets.
+
+At each training step, the model will seek to predict target words N+1 (and beyond)
+using the source sentence and the target words 0 to N.
+
+As such, the training dataset will yield a tuple `(inputs, targets)`, where:
+
+- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
+`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
+that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
+- `target` is the target sentence offset by one step:
+it provides the next words in the target sentence -- what the model will try to predict.
+"""
+
+
+def format_dataset(eng, spa):
+ eng = eng_vectorization(eng)
+ spa = spa_vectorization(spa)
+ return (
+ {
+ "encoder_inputs": eng,
+ "decoder_inputs": spa[:, :-1],
+ },
+ spa[:, 1:],
+ )
+
+
+def make_dataset(pairs):
+ eng_texts, spa_texts = zip(*pairs)
+ eng_texts = list(eng_texts)
+ spa_texts = list(spa_texts)
+ dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.map(format_dataset)
+ return dataset.cache().shuffle(2048).prefetch(16)
+
+
+train_ds = make_dataset(train_pairs)
+val_ds = make_dataset(val_pairs)
+
+"""
+Let's take a quick look at the sequence shapes
+(we have batches of 64 pairs, and all sequences are 20 steps long):
+"""
+
+for inputs, targets in train_ds.take(1):
+ print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
+ print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
+ print(f"targets.shape: {targets.shape}")
+
+"""
+## Building the model
+
+Our sequence-to-sequence Transformer consists of a `TransformerEncoder`
+and a `TransformerDecoder` chained together. To make the model aware of word order,
+we also use a `PositionalEmbedding` layer.
+
+The source sequence will be passed to the `TransformerEncoder`,
+which will produce a new representation of it.
+This new representation will then be passed
+to the `TransformerDecoder`, together with the target sequence so far (target words 0 to N).
+The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).
+
+A key detail that makes this possible is causal masking
+(see method `get_causal_attention_mask()` on the `TransformerDecoder`).
+The `TransformerDecoder` sees the entire sequence at once, and thus we must make
+sure that it only uses information from target tokens 0 to N when predicting token N+1
+(otherwise, it could use information from the future, which would
+result in a model that cannot be used at inference time).
+"""
+import keras.ops as ops
+
+
+class TransformerEncoder(layers.Layer):
+ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.dense_dim = dense_dim
+ self.num_heads = num_heads
+ self.attention = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(dense_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ if mask is not None:
+ padding_mask = ops.cast(mask[:, None, :], dtype="int32")
+ else:
+ padding_mask = None
+
+ attention_output = self.attention(
+ query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
+ )
+ proj_input = self.layernorm_1(inputs + attention_output)
+ proj_output = self.dense_proj(proj_input)
+ return self.layernorm_2(proj_input + proj_output)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "embed_dim": self.embed_dim,
+ "dense_dim": self.dense_dim,
+ "num_heads": self.num_heads,
+ }
+ )
+ return config
+
+
+class PositionalEmbedding(layers.Layer):
+ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.token_embeddings = layers.Embedding(
+ input_dim=vocab_size, output_dim=embed_dim
+ )
+ self.position_embeddings = layers.Embedding(
+ input_dim=sequence_length, output_dim=embed_dim
+ )
+ self.sequence_length = sequence_length
+ self.vocab_size = vocab_size
+ self.embed_dim = embed_dim
+
+ def call(self, inputs):
+ length = ops.shape(inputs)[-1]
+ positions = ops.arange(0, length, 1)
+ embedded_tokens = self.token_embeddings(inputs)
+ embedded_positions = self.position_embeddings(positions)
+ return embedded_tokens + embedded_positions
+
+ def compute_mask(self, inputs, mask=None):
+ return ops.not_equal(inputs, 0)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "sequence_length": self.sequence_length,
+ "vocab_size": self.vocab_size,
+ "embed_dim": self.embed_dim,
+ }
+ )
+ return config
+
+
+class TransformerDecoder(layers.Layer):
+ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.latent_dim = latent_dim
+ self.num_heads = num_heads
+ self.attention_1 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.attention_2 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim
+ )
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(latent_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.layernorm_3 = layers.LayerNormalization()
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ inputs, encoder_outputs = inputs
+ causal_mask = self.get_causal_attention_mask(inputs)
+
+ if mask is None:
+ inputs_padding_mask, encoder_outputs_padding_mask = None, None
+ else:
+ inputs_padding_mask, encoder_outputs_padding_mask = mask
+
+ attention_output_1 = self.attention_1(
+ query=inputs,
+ value=inputs,
+ key=inputs,
+ attention_mask=causal_mask,
+ query_mask=inputs_padding_mask,
+ )
+ out_1 = self.layernorm_1(inputs + attention_output_1)
+
+ attention_output_2 = self.attention_2(
+ query=out_1,
+ value=encoder_outputs,
+ key=encoder_outputs,
+ query_mask=inputs_padding_mask,
+ key_mask=encoder_outputs_padding_mask,
+ )
+ out_2 = self.layernorm_2(out_1 + attention_output_2)
+
+ proj_output = self.dense_proj(out_2)
+ return self.layernorm_3(out_2 + proj_output)
+
+ def get_causal_attention_mask(self, inputs):
+ input_shape = ops.shape(inputs)
+ batch_size, sequence_length = input_shape[0], input_shape[1]
+ i = ops.arange(sequence_length)[:, None]
+ j = ops.arange(sequence_length)
+ mask = ops.cast(i >= j, dtype="int32")
+ mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
+ mult = ops.concatenate(
+ [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
+ axis=0,
+ )
+ return ops.tile(mask, mult)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "embed_dim": self.embed_dim,
+ "latent_dim": self.latent_dim,
+ "num_heads": self.num_heads,
+ }
+ )
+ return config
+
+
+"""
+Next, we assemble the end-to-end model.
+"""
+
+embed_dim = 256
+latent_dim = 2048
+num_heads = 8
+
+encoder_inputs = keras.Input(shape=(None,), dtype="int64", name="encoder_inputs")
+x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
+encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
+encoder = keras.Model(encoder_inputs, encoder_outputs)
+
+decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
+encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
+x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
+x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
+x = layers.Dropout(0.5)(x)
+decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
+decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
+
+transformer = keras.Model(
+ {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
+ decoder_outputs,
+ name="transformer",
+)
+
+"""
+## Training our model
+
+We'll use accuracy as a quick way to monitor training progress on the validation data.
+Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy.
+
+Here we only train for 1 epoch, but to get the model to actually converge
+you should train for at least 30 epochs.
+"""
+
+epochs = 1 # This should be at least 30 for convergence
+
+transformer.summary()
+transformer.compile(
+ "rmsprop",
+ loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
+ metrics=["accuracy"],
+)
+transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
+
+"""
+## Decoding test sentences
+
+Finally, let's demonstrate how to translate brand new English sentences.
+We simply feed into the model the vectorized English sentence
+as well as the target token `"[start]"`, then we repeatedly generate the next token, until
+we hit the token `"[end]"`.
+"""
+
+spa_vocab = spa_vectorization.get_vocabulary()
+spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
+max_decoded_sentence_length = 20
+
+
+def decode_sequence(input_sentence):
+ tokenized_input_sentence = eng_vectorization([input_sentence])
+ decoded_sentence = "[start]"
+ for i in range(max_decoded_sentence_length):
+ tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
+ predictions = transformer(
+ {
+ "encoder_inputs": tokenized_input_sentence,
+ "decoder_inputs": tokenized_target_sentence,
+ }
+ )
+
+ # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
+ sampled_token_index = ops.convert_to_numpy(
+ ops.argmax(predictions[0, i, :])
+ ).item(0)
+ sampled_token = spa_index_lookup[sampled_token_index]
+ decoded_sentence += " " + sampled_token
+
+ if sampled_token == "[end]":
+ break
+ return decoded_sentence
+
+
+test_eng_texts = [pair[0] for pair in test_pairs]
+for _ in range(30):
+ input_sentence = random.choice(test_eng_texts)
+ translated = decode_sequence(input_sentence)
+
+"""
+After 30 epochs, we get results such as:
+
+> She handed him the money.
+> [start] ella le pasó el dinero [end]
+
+> Tom has never heard Mary sing.
+> [start] tom nunca ha oído cantar a mary [end]
+
+> Perhaps she will come tomorrow.
+> [start] tal vez ella vendrรก maรฑana [end]
+
+> I love to write.
+> [start] me encanta escribir [end]
+
+> His French is improving little by little.
+> [start] su francés va a [UNK] sólo un poco [end]
+
+> My hotel told me to call you.
+> [start] mi hotel me dijo que te [UNK] [end]
+"""
diff --git a/knowledge_base/nlp/parameter_efficient_finetuning_of_gpt2_with_lora.py b/knowledge_base/nlp/parameter_efficient_finetuning_of_gpt2_with_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..830cdc09c65b7b3fc37db883387f409dc3fdad86
--- /dev/null
+++ b/knowledge_base/nlp/parameter_efficient_finetuning_of_gpt2_with_lora.py
@@ -0,0 +1,607 @@
+"""
+Title: Parameter-efficient fine-tuning of GPT-2 with LoRA
+Author: [Abheesht Sharma](https://github.com/abheesht17/), [Matthew Watson](https://github.com/mattdangerw/)
+Date created: 2023/05/27
+Last modified: 2023/05/27
+Description: Use KerasHub to fine-tune a GPT-2 LLM with LoRA.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Large Language Models (LLMs) have been shown to be effective at a variety of NLP
+tasks. An LLM is first pre-trained on a large corpus of text in a
+self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge,
+such as statistical relationships between words. An LLM can then be fine-tuned
+on a downstream task of interest (such as sentiment analysis).
+
+However, LLMs are extremely large in size, and we don't need to train all the
+parameters in the model while fine-tuning, especially because datasets on which
+the model is fine-tuned are relatively small. Another way of saying this is
+that LLMs are over-parametrized for fine-tuning. This is where
+[Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) comes in; it
+significantly reduces the number of trainable parameters. This results in a
+decrease in training time and GPU memory usage, while maintaining the quality
+of the outputs.
+
+In this example, we will explain LoRA in technical terms, show how the technical
+explanation translates to code, hack KerasHub's
+[GPT-2 model](https://keras.io/api/keras_hub/models/gpt2/) and fine-tune
+it on the next token prediction task using LoRA. We will compare LoRA GPT-2
+with a fully fine-tuned GPT-2 in terms of the quality of the generated text,
+training time and GPU memory usage.
+
+Note: This example runs on the TensorFlow backend purely for the
+`tf.config.experimental.get_memory_info` API to easily plot memory usage.
+Outside of the memory usage callback, this example will run on `jax` and `torch`
+backends.
+"""
+
+"""
+## Setup
+
+Before we start implementing the pipeline, let's install and import all the
+libraries we need. We'll be using the KerasHub library.
+
+Secondly, let's enable mixed precision training. This will help us reduce the
+training time.
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras_hub
+import keras
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import time
+
+keras.mixed_precision.set_global_policy("mixed_float16")
+
+"""
+Let's also define our hyperparameters.
+"""
+
+# General hyperparameters
+BATCH_SIZE = 32
+NUM_BATCHES = 500
+EPOCHS = 1 # Can be set to a higher value for better results
+MAX_SEQUENCE_LENGTH = 128
+MAX_GENERATION_LENGTH = 200
+
+GPT2_PRESET = "gpt2_base_en"
+
+# LoRA-specific hyperparameters
+RANK = 4
+ALPHA = 32.0
+
+
+"""
+## Dataset
+
+Let's load a Reddit dataset. We will fine-tune both the GPT-2 model and the
+LoRA GPT-2 model on a subset of this dataset. The aim is to produce text similar
+in style to Reddit posts.
+"""
+
+reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
+
+"""
+The dataset has two fields: `document` and `title`.
+"""
+
+for document, title in reddit_ds:
+ print(document.numpy())
+ print(title.numpy())
+ break
+
+"""
+We'll now batch the dataset and retain only the `document` field because we are
+fine-tuning the model on the next-word-prediction task. We'll take a subset
+of the dataset for the purpose of this example.
+"""
+
+train_ds = (
+ reddit_ds.map(lambda document, _: document)
+ .batch(BATCH_SIZE)
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+train_ds = train_ds.take(NUM_BATCHES)
+
+"""
+## Helper functions
+
+Before we begin fine-tuning the models, let's define a few helper functions and
+classes.
+"""
+
+"""
+### Callback for tracking GPU memory usage
+
+We'll define a custom callback function which tracks GPU memory usage. The
+callback function uses TensorFlow's `tf.config.experimental.get_memory_info`
+API.
+
+Here, we assume that we are using a single GPU, `GPU:0`.
+"""
+
+
+class GPUMemoryCallback(keras.callbacks.Callback):
+ def __init__(
+ self,
+ target_batches,
+ print_stats=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.target_batches = target_batches
+ self.print_stats = print_stats
+
+ self.memory_usage = []
+ self.labels = []
+
+ def _compute_memory_usage(self):
+ memory_stats = tf.config.experimental.get_memory_info("GPU:0")
+ # Convert bytes to GB and store in list.
+ peak_usage = round(memory_stats["peak"] / (2**30), 3)
+ self.memory_usage.append(peak_usage)
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self._compute_memory_usage()
+ self.labels.append(f"epoch {epoch} start")
+
+ def on_train_batch_begin(self, batch, logs=None):
+ if batch in self.target_batches:
+ self._compute_memory_usage()
+ self.labels.append(f"batch {batch}")
+
+ def on_epoch_end(self, epoch, logs=None):
+ self._compute_memory_usage()
+ self.labels.append(f"epoch {epoch} end")
+
+
+"""
+### Function for text generation
+
+Here is a helper function to generate text.
+"""
+
+
+def generate_text(model, input_text, max_length=200):
+ start = time.time()
+
+ output = model.generate(input_text, max_length=max_length)
+ print("\nOutput:")
+ print(output)
+
+ end = time.time()
+ print(f"Total Time Elapsed: {end - start:.2f}s")
+
+
+"""
+### Define optimizer and loss
+
+We will use AdamW optimizer and cross-entropy loss for training both models.
+"""
+
+
+def get_optimizer_and_loss():
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=5e-5,
+ weight_decay=0.01,
+ epsilon=1e-6,
+ global_clipnorm=1.0, # Gradient clipping.
+ )
+ # Exclude layernorm and bias terms from weight decay.
+ optimizer.exclude_from_weight_decay(var_names=["bias"])
+ optimizer.exclude_from_weight_decay(var_names=["gamma"])
+ optimizer.exclude_from_weight_decay(var_names=["beta"])
+
+ loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ return optimizer, loss
+
+
+"""
+## Fine-tune GPT-2
+
+Let's load the model and preprocessor first. We use a sequence length of 128
+instead of 1024 (which is the default sequence length). This will limit our
+ability to predict long sequences, but will allow us to run this example quickly
+on Colab.
+"""
+
+preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
+ "gpt2_base_en",
+ sequence_length=MAX_SEQUENCE_LENGTH,
+)
+gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
+ "gpt2_base_en", preprocessor=preprocessor
+)
+
+gpt2_lm.summary()
+
+"""
+Initialize the GPU memory tracker callback object, and compile the model. We
+use the AdamW optimizer and loss defined above.
+"""
+
+gpu_memory_callback = GPUMemoryCallback(
+ target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
+ print_stats=True,
+)
+
+optimizer, loss = get_optimizer_and_loss()
+
+gpt2_lm.compile(
+ optimizer=optimizer,
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+"""
+We are all set to train the model!
+"""
+
+gpt2_lm.fit(train_ds, epochs=EPOCHS, callbacks=[gpu_memory_callback])
+gpt2_lm_memory_usage = gpu_memory_callback.memory_usage
+
+"""
+As a final step, let's generate some text. We will harness the power of XLA. The
+first call to `generate()` will be slow because of XLA compilation, but
+subsequent calls will be super-fast. :)
+"""
+
+generate_text(gpt2_lm, "I like basketball", max_length=MAX_GENERATION_LENGTH)
+generate_text(gpt2_lm, "That Italian restaurant is", max_length=MAX_GENERATION_LENGTH)
+
+"""
+## LoRA GPT-2
+
+In this section, we discuss the technical details of LoRA, build a LoRA GPT-2
+model, fine-tune it and generate text.
+
+### What exactly is LoRA?
+
+LoRA is a parameter-efficient fine-tuning technique for LLMs. It freezes the
+weights of the LLM, and injects trainable rank-decomposition matrices. Let's
+understand this more clearly.
+
+Assume we have an `n x n` pre-trained dense layer (or weight matrix), `W0`. We
+initialize two dense layers, `A` and `B`, of shapes `n x rank`, and `rank x n`,
+respectively. `rank` is much smaller than `n`. In the paper, values between 1
+and 4 are shown to work well.
+
+
+#### LoRA equation
+
+The original equation is `output = W0x + b0`, where `x` is the input, `W0` and
+`b0` are the weight matrix and bias terms of the original dense layer (frozen).
+The LoRA equation is: `output = W0x + b0 + BAx`, where `A` and `B` are the
+rank-decomposition matrices.
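+
+To make the shapes concrete, here is a minimal NumPy sketch of this forward pass
+(illustrative only: the sizes are arbitrary, and the row-vector convention
+`x @ A @ B` realizes the `BAx` term):
+"""
+
+import numpy as np
+
+n, rank = 8, 2
+x = np.random.randn(1, n)  # A single input, as a row vector.
+W0 = np.random.randn(n, n)  # Frozen pre-trained kernel.
+b0 = np.random.randn(n)  # Frozen bias.
+A = np.random.randn(n, rank) * 0.01  # Trainable, shape (n, rank).
+B = np.zeros((rank, n))  # Trainable, zero-initialized, shape (rank, n).
+
+# Frozen path plus low-rank update path. Since B starts at zero, the LoRA path
+# contributes nothing at initialization.
+lora_output = x @ W0 + b0 + x @ A @ B
+
+"""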
+
+LoRA is based on the idea that updates to the weights of the pre-trained
+language model have a low "intrinsic rank" since pre-trained language models are
+over-parametrized. Predictive performance of full fine-tuning can be replicated
+even by constraining `W0`'s updates to low-rank decomposition matrices.
+
+#### Number of trainable parameters
+
+Let's do some quick math. Suppose `n` is 768, and `rank` is 4. `W0` has
+`768 x 768 = 589,824` parameters, whereas the LoRA layers, `A` and `B` together
+have `768 x 4 + 4 x 768 = 6,144` parameters. So, for the dense layer, we go from
+`589,824` trainable parameters to `6,144` trainable parameters!
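+
+A quick sanity check of that arithmetic in code:
+"""
+
+# Parameter counts for the example above (n = 768, rank = 4).
+n, rank = 768, 4
+dense_params = n * n  # 589,824 parameters in the original dense layer.
+lora_params = n * rank + rank * n  # 6,144 parameters in A and B combined.
+print(dense_params, lora_params)
+
+"""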
+
+#### Why does LoRA reduce memory footprint?
+
+Even though the total number of parameters increases (since we are adding LoRA
+layers), the memory footprint shrinks, because the number of trainable
+parameters decreases. Let's dive deeper into this.
+
+The memory usage of a model can be split into four parts:
+
+- Model memory: This is the memory required to store the model weights. This
+will be slightly higher for LoRA GPT-2 than for the original GPT-2.
+- Forward pass memory: This mostly depends on batch size, sequence length, etc.
+We keep this constant for both models for a fair comparison.
+- Backward pass memory: This is the memory required to store the gradients.
+Note that the gradients are computed only for the trainable parameters.
+- Optimizer memory: This is the memory required to store the optimizer state.
+For example, the Adam optimizer stores the "1st moment vectors" and
+"2nd moment vectors" for the trainable parameters.
+
+Since, with LoRA, there is a huge reduction in the number of trainable
+parameters, the optimizer memory and the memory required to store the gradients
+are much lower for LoRA than for GPT-2. This is where most of the memory savings
+happen.
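+
+As a rough back-of-the-envelope estimate of the optimizer part alone (assuming
+float32 state, i.e. 4 bytes per value, and Adam-style first and second moment
+vectors for every trainable parameter):
+"""
+
+# Approximate optimizer-state memory: full fine-tuning vs. LoRA.
+bytes_per_value = 4  # float32
+moments_per_param = 2  # Adam keeps 1st and 2nd moment vectors.
+
+full_params = 124_439_808  # All GPT-2 parameters trainable.
+lora_params = 147_456  # Only the LoRA A/B matrices trainable.
+
+print(f"Full: {full_params * moments_per_param * bytes_per_value / 2**30:.3f} GB")
+print(f"LoRA: {lora_params * moments_per_param * bytes_per_value / 2**20:.3f} MB")
+
+"""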
+
+#### Why is LoRA so popular?
+
+- Reduced GPU memory usage;
+- Faster training; and
+- No additional inference latency.
+
+### Create LoRA layer
+
+Following the technical description above, let's create a LoRA layer. In
+a transformer model, the LoRA layer is created and injected for the query and
+value projection matrices. In `keras.layers.MultiHeadAttention`, the query/value
+projection layers are `keras.layers.EinsumDense` layers.
+"""
+
+import math
+
+
+class LoraLayer(keras.layers.Layer):
+ def __init__(
+ self,
+ original_layer,
+ rank=8,
+ alpha=32,
+ trainable=False,
+ **kwargs,
+ ):
+ # We want to keep the name of this layer the same as the original
+ # dense layer.
+ original_layer_config = original_layer.get_config()
+ name = original_layer_config["name"]
+
+ kwargs.pop("name", None)
+
+ super().__init__(name=name, trainable=trainable, **kwargs)
+
+ self.rank = rank
+ self.alpha = alpha
+
+ self._scale = alpha / rank
+
+ self._num_heads = original_layer_config["output_shape"][-2]
+ self._hidden_dim = self._num_heads * original_layer_config["output_shape"][-1]
+
+ # Layers.
+
+ # Original dense layer.
+ self.original_layer = original_layer
+ # No matter whether we are training the model or are in inference mode,
+ # this layer should be frozen.
+ self.original_layer.trainable = False
+
+ # LoRA dense layers.
+ self.A = keras.layers.Dense(
+ units=rank,
+ use_bias=False,
+ # Note: the original paper mentions that normal distribution was
+ # used for initialization. However, the official LoRA implementation
+ # uses "Kaiming/He Initialization".
+ kernel_initializer=keras.initializers.VarianceScaling(
+ scale=math.sqrt(5), mode="fan_in", distribution="uniform"
+ ),
+ trainable=trainable,
+            name="lora_A",
+ )
+ # B has the same `equation` and `output_shape` as the original layer.
+ # `equation = abc,cde->abde`, where `a`: batch size, `b`: sequence
+ # length, `c`: `hidden_dim`, `d`: `num_heads`,
+ # `e`: `hidden_dim//num_heads`. The only difference is that in layer `B`,
+ # `c` represents `rank`.
+ self.B = keras.layers.EinsumDense(
+ equation=original_layer_config["equation"],
+ output_shape=original_layer_config["output_shape"],
+ kernel_initializer="zeros",
+ trainable=trainable,
+            name="lora_B",
+ )
+
+ def call(self, inputs):
+ original_output = self.original_layer(inputs)
+ if self.trainable:
+ # If we are fine-tuning the model, we will add LoRA layers' output
+ # to the original layer's output.
+ lora_output = self.B(self.A(inputs)) * self._scale
+ return original_output + lora_output
+
+ # If we are in inference mode, we "merge" the LoRA layers' weights into
+ # the original layer's weights - more on this in the text generation
+ # section!
+ return original_output
+
+
+"""
+### Inject LoRA layer into the model
+
+We will now hack the original GPT-2 model and inject LoRA layers into it. Let's
+do a couple of things before doing that:
+
+- Delete previous model;
+- Reset "peak" GPU memory usage using `tf.config.experimental.reset_memory_stats`;
+- Load a new GPT-2 model.
+"""
+
+del gpt2_lm
+del optimizer
+del loss
+
+# This resets "peak" memory usage to "current" memory usage.
+tf.config.experimental.reset_memory_stats("GPU:0")
+
+# Load the original model.
+preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
+ "gpt2_base_en",
+ sequence_length=128,
+)
+lora_model = keras_hub.models.GPT2CausalLM.from_preset(
+ "gpt2_base_en",
+ preprocessor=preprocessor,
+)
+
+"""
+We will now override the original query/value projection matrices with our
+new LoRA layers.
+"""
+
+for layer_idx in range(lora_model.backbone.num_layers):
+    # Get the self-attention layer of this decoder block.
+ decoder_layer = lora_model.backbone.get_layer(f"transformer_layer_{layer_idx}")
+ self_attention_layer = decoder_layer._self_attention_layer
+ # Allow mutation to Keras layer state.
+ self_attention_layer._tracker.locked = False
+
+ # Change query dense layer.
+ self_attention_layer._query_dense = LoraLayer(
+ self_attention_layer._query_dense,
+ rank=RANK,
+ alpha=ALPHA,
+ trainable=True,
+ )
+
+ # Change value dense layer.
+ self_attention_layer._value_dense = LoraLayer(
+ self_attention_layer._value_dense,
+ rank=RANK,
+ alpha=ALPHA,
+ trainable=True,
+ )
+
+"""
+Let's now do a forward pass to make sure we still have a valid chain of
+computation.
+"""
+
+lora_model(preprocessor(["LoRA is very useful for quick LLM finetuning"])[0])
+pass
+
+"""
+Freeze the entire LLM; only the LoRA layers should be trainable.
+"""
+
+for layer in lora_model._flatten_layers():
+ lst_of_sublayers = list(layer._flatten_layers())
+
+ if len(lst_of_sublayers) == 1: # "leaves of the model"
+ if layer.name in ["lora_A", "lora_B"]:
+ layer.trainable = True
+ else:
+ layer.trainable = False
+
+"""
+Print the model's summary and see if the number of non-trainable parameters and
+total parameters are correct.
+
+In a previous section, we calculated the number of parameters associated with
+the LoRA layers to be 6,144. The total trainable parameters in the model should
+be `num_layers * (query, value) * 6,144 = 12 * 2 * 6,144 = 147,456`. The
+number of non-trainable parameters should be the same as the total number of
+parameters in the original GPT-2 model, which is `124,439,808`.
+"""
+
+lora_model.summary()
+
+"""
+### Fine-tune LoRA GPT-2
+
+Now that we have hacked and verified the LoRA GPT-2 model, let's train it!
+"""
+
+gpu_memory_callback = GPUMemoryCallback(
+ target_batches=[5, 10, 25, 50, 100, 150, 200, 300, 400, 500],
+ print_stats=True,
+)
+
+optimizer, loss = get_optimizer_and_loss()
+
+lora_model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ weighted_metrics=["accuracy"],
+)
+
+lora_model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ callbacks=[gpu_memory_callback],
+)
+lora_model_memory_usage = gpu_memory_callback.memory_usage
+
+"""
+And we are done fine-tuning the model! Before we generate text, let's compare
+the training time and memory usage of the two models. The training time of GPT-2
+on a 16 GB Tesla T4 (Colab) is 7 minutes, and for LoRA, it is 5 minutes, a 30%
+decrease. The memory usage of LoRA GPT-2 is roughly 35% lower than that of GPT-2.
+"""
+
+plt.bar(
+ ["GPT-2", "LoRA GPT-2"],
+ [max(gpt2_lm_memory_usage), max(lora_model_memory_usage)],
+ color=["red", "blue"],
+)
+
+plt.xlabel("Model")
+plt.ylabel("GPU Memory Usage (in GB)")
+
+plt.title("GPU Memory Usage Comparison")
+plt.show()
+
+"""
+### Merge weights and generate text!
+
+One of the biggest advantages of LoRA over other adapter methods is that it
+does not incur any additional inference latency. Let's understand why.
+
+Recall our LoRA equation: `output = W0x + b0 + BAx`. We can rewrite this as:
+`output = (W0 + BA)x + b0 = Wx + b0`, where `W = W0 + BA`. This means that if
+we merge the weights of the original model and the adapter, we will essentially
+be doing the same computation as the original model!
+"""
+
+for layer_idx in range(lora_model.backbone.num_layers):
+ self_attention_layer = lora_model.backbone.get_layer(
+ f"transformer_layer_{layer_idx}"
+ )._self_attention_layer
+
+ # Merge query dense layer.
+ query_lora_layer = self_attention_layer._query_dense
+
+    A_weights = query_lora_layer.A.kernel  # Shape: (768, rank) -> (a, b).
+    B_weights = query_lora_layer.B.kernel  # Shape: (rank, 12, 64) -> (b, c, d).
+ increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
+ query_lora_layer.original_layer.kernel.assign_add(increment_weights)
+
+ # Merge value dense layer.
+ value_lora_layer = self_attention_layer._value_dense
+
+    A_weights = value_lora_layer.A.kernel  # Shape: (768, rank) -> (a, b).
+    B_weights = value_lora_layer.B.kernel  # Shape: (rank, 12, 64) -> (b, c, d).
+ increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
+ value_lora_layer.original_layer.kernel.assign_add(increment_weights)
+
+    # Put the original layers, now carrying the merged weights, back in place.
+ self_attention_layer._query_dense = query_lora_layer.original_layer
+ self_attention_layer._value_dense = value_lora_layer.original_layer
+
+"""
+We are now all set to generate text with our LoRA model :).
+"""
+
+# Freezing the weights is not necessary during generation, since no weights are updated.
+generate_text(lora_model, "I like basketball", max_length=MAX_GENERATION_LENGTH)
+generate_text(
+ lora_model, "That Italian restaurant is", max_length=MAX_GENERATION_LENGTH
+)
+
+"""
+And we're all done!
+"""
diff --git a/knowledge_base/nlp/pretrained_word_embeddings.py b/knowledge_base/nlp/pretrained_word_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a717cd8af9fb495dc18982bd1b6de21e7bd023a
--- /dev/null
+++ b/knowledge_base/nlp/pretrained_word_embeddings.py
@@ -0,0 +1,305 @@
+"""
+Title: Using pre-trained word embeddings
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/05/05
+Last modified: 2020/05/05
+Description: Text classification on the Newsgroup20 dataset using pre-trained GloVe word embeddings.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+# Only the TensorFlow backend supports string inputs.
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import pathlib
+import numpy as np
+import tensorflow.data as tf_data
+import keras
+from keras import layers
+
+"""
+## Introduction
+
+In this example, we show how to train a text classification model that uses pre-trained
+word embeddings.
+
+We'll work with the Newsgroup20 dataset, a set of 20,000 message board messages
+belonging to 20 different topic categories.
+
+For the pre-trained word embeddings, we'll use
+[GloVe embeddings](http://nlp.stanford.edu/projects/glove/).
+"""
+
+"""
+## Download the Newsgroup20 data
+"""
+
+data_path = keras.utils.get_file(
+ "news20.tar.gz",
+ "http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/news20.tar.gz",
+ untar=True,
+)
+
+"""
+## Let's take a look at the data
+"""
+
+data_dir = pathlib.Path(data_path).parent / "20_newsgroup"
+dirnames = os.listdir(data_dir)
+print("Number of directories:", len(dirnames))
+print("Directory names:", dirnames)
+
+fnames = os.listdir(data_dir / "comp.graphics")
+print("Number of files in comp.graphics:", len(fnames))
+print("Some example filenames:", fnames[:5])
+
+"""
+Here's an example of what one file contains:
+"""
+
+print(open(data_dir / "comp.graphics" / "38987").read())
+
+"""
+As you can see, there are header lines that leak the file's category, either
+explicitly (the first line is literally the category name), or implicitly, e.g. via the
+`Organization` field. Let's get rid of the headers:
+"""
+
+samples = []
+labels = []
+class_names = []
+class_index = 0
+for dirname in sorted(os.listdir(data_dir)):
+ class_names.append(dirname)
+ dirpath = data_dir / dirname
+ fnames = os.listdir(dirpath)
+ print("Processing %s, %d files found" % (dirname, len(fnames)))
+ for fname in fnames:
+ fpath = dirpath / fname
+ f = open(fpath, encoding="latin-1")
+ content = f.read()
+ lines = content.split("\n")
+ lines = lines[10:]
+ content = "\n".join(lines)
+ samples.append(content)
+ labels.append(class_index)
+ class_index += 1
+
+print("Classes:", class_names)
+print("Number of samples:", len(samples))
+
+"""
+There's actually one category that doesn't have the expected number of files, but the
+difference is small enough that the problem remains a balanced classification problem.
+"""
+
+"""
+## Shuffle and split the data into training & validation sets
+"""
+
+# Shuffle the data
+seed = 1337
+rng = np.random.RandomState(seed)
+rng.shuffle(samples)
+rng = np.random.RandomState(seed)
+rng.shuffle(labels)
+
+# Extract a training & validation split
+validation_split = 0.2
+num_validation_samples = int(validation_split * len(samples))
+train_samples = samples[:-num_validation_samples]
+val_samples = samples[-num_validation_samples:]
+train_labels = labels[:-num_validation_samples]
+val_labels = labels[-num_validation_samples:]
+
+"""
+## Create a vocabulary index
+
+Let's use the `TextVectorization` layer to index the vocabulary found in the dataset.
+Later, we'll use the same layer instance to vectorize the samples.
+
+Our layer will only consider the top 20,000 words, and will truncate or pad sequences to
+be exactly 200 tokens long.
+"""
+
+vectorizer = layers.TextVectorization(max_tokens=20000, output_sequence_length=200)
+text_ds = tf_data.Dataset.from_tensor_slices(train_samples).batch(128)
+vectorizer.adapt(text_ds)
+
+"""
+You can retrieve the computed vocabulary used via `vectorizer.get_vocabulary()`. Let's
+print the top 5 words:
+"""
+
+vectorizer.get_vocabulary()[:5]
+
+"""
+Let's vectorize a test sentence:
+"""
+
+output = vectorizer([["the cat sat on the mat"]])
+output.numpy()[0, :6]
+
+"""
+As you can see, "the" gets represented as "2". Why not 0, given that "the" was the first
+word in the vocabulary? That's because index 0 is reserved for padding and index 1 is
+reserved for "out of vocabulary" tokens.
+
+Here's a dict mapping words to their indices:
+"""
+
+voc = vectorizer.get_vocabulary()
+word_index = dict(zip(voc, range(len(voc))))
+
+"""
+As you can see, we obtain the same encoding as above for our test sentence:
+"""
+
+test = ["the", "cat", "sat", "on", "the", "mat"]
+[word_index[w] for w in test]
+
+"""
+## Load pre-trained word embeddings
+"""
+
+"""
+Let's download pre-trained GloVe embeddings (an 822MB zip file).
+
+You'll need to run the following commands:
+"""
+
+"""shell
+wget https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
+unzip -q glove.6B.zip
+"""
+
+"""
+The archive contains text-encoded vectors of various sizes: 50-dimensional,
+100-dimensional, 200-dimensional, 300-dimensional. We'll use the 100D ones.
+
+Let's make a dict mapping words (strings) to their NumPy vector representation:
+"""
+
+path_to_glove_file = "glove.6B.100d.txt"
+
+embeddings_index = {}
+with open(path_to_glove_file) as f:
+ for line in f:
+ word, coefs = line.split(maxsplit=1)
+ coefs = np.fromstring(coefs, "f", sep=" ")
+ embeddings_index[word] = coefs
+
+print("Found %s word vectors." % len(embeddings_index))
+
+"""
+Now, let's prepare a corresponding embedding matrix that we can use in a Keras
+`Embedding` layer. It's a simple NumPy matrix where entry at index `i` is the pre-trained
+vector for the word of index `i` in our `vectorizer`'s vocabulary.
+"""
+
+num_tokens = len(voc) + 2
+embedding_dim = 100
+hits = 0
+misses = 0
+
+# Prepare embedding matrix
+embedding_matrix = np.zeros((num_tokens, embedding_dim))
+for word, i in word_index.items():
+ embedding_vector = embeddings_index.get(word)
+ if embedding_vector is not None:
+ # Words not found in embedding index will be all-zeros.
+ # This includes the representation for "padding" and "OOV"
+ embedding_matrix[i] = embedding_vector
+ hits += 1
+ else:
+ misses += 1
+print("Converted %d words (%d misses)" % (hits, misses))
+
+
+"""
+Next, we load the pre-trained word embeddings matrix into an `Embedding` layer.
+
+Note that we set `trainable=False` so as to keep the embeddings fixed (we don't want to
+update them during training).
+"""
+
+from keras.layers import Embedding
+
+embedding_layer = Embedding(
+ num_tokens,
+ embedding_dim,
+ trainable=False,
+)
+embedding_layer.build((1,))
+embedding_layer.set_weights([embedding_matrix])
+
+"""
+## Build the model
+
+A simple 1D convnet with global max pooling and a classifier at the end.
+"""
+
+int_sequences_input = keras.Input(shape=(None,), dtype="int32")
+embedded_sequences = embedding_layer(int_sequences_input)
+x = layers.Conv1D(128, 5, activation="relu")(embedded_sequences)
+x = layers.MaxPooling1D(5)(x)
+x = layers.Conv1D(128, 5, activation="relu")(x)
+x = layers.MaxPooling1D(5)(x)
+x = layers.Conv1D(128, 5, activation="relu")(x)
+x = layers.GlobalMaxPooling1D()(x)
+x = layers.Dense(128, activation="relu")(x)
+x = layers.Dropout(0.5)(x)
+preds = layers.Dense(len(class_names), activation="softmax")(x)
+model = keras.Model(int_sequences_input, preds)
+model.summary()
+
+"""
+## Train the model
+
+First, convert our list-of-strings data to NumPy arrays of integer indices. The arrays
+are right-padded.
+"""
+
+x_train = vectorizer(np.array([[s] for s in train_samples])).numpy()
+x_val = vectorizer(np.array([[s] for s in val_samples])).numpy()
+
+y_train = np.array(train_labels)
+y_val = np.array(val_labels)
+
+"""
+We use categorical crossentropy as our loss since we're doing softmax classification.
+Specifically, we use the sparse variant, `sparse_categorical_crossentropy`, since our
+labels are integers.
+"""
+
+model.compile(
+ loss="sparse_categorical_crossentropy", optimizer="rmsprop", metrics=["acc"]
+)
+model.fit(x_train, y_train, batch_size=128, epochs=20, validation_data=(x_val, y_val))
+
+"""
+## Export an end-to-end model
+
+Now, we may want to export a `Model` object that takes as input a string of arbitrary
+length, rather than a sequence of indices. It would make the model much more portable,
+since you wouldn't have to worry about the input preprocessing pipeline.
+
+Our `vectorizer` is actually a Keras layer, so it's simple:
+"""
+
+string_input = keras.Input(shape=(1,), dtype="string")
+x = vectorizer(string_input)
+preds = model(x)
+end_to_end_model = keras.Model(string_input, preds)
+
+probabilities = end_to_end_model(
+ keras.ops.convert_to_tensor(
+ [["this message is about computer graphics and 3D modeling"]]
+ )
+)
+
+print(class_names[np.argmax(probabilities[0])])
diff --git a/knowledge_base/nlp/semantic_similarity_with_bert.py b/knowledge_base/nlp/semantic_similarity_with_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ef00e66d22d9be941831cd53ccaeb3a7e81503
--- /dev/null
+++ b/knowledge_base/nlp/semantic_similarity_with_bert.py
@@ -0,0 +1,397 @@
+"""
+Title: Semantic Similarity with BERT
+Author: [Mohamad Merchant](https://twitter.com/mohmadmerchant1)
+Date created: 2020/08/15
+Last modified: 2020/08/29
+Description: Natural Language Inference by fine-tuning BERT model on SNLI Corpus.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Semantic Similarity is the task of determining how similar
+two sentences are, in terms of what they mean.
+This example demonstrates the use of SNLI (Stanford Natural Language Inference) Corpus
+to predict sentence semantic similarity with Transformers.
+We will fine-tune a BERT model that takes two sentences as inputs
+and that outputs a similarity score for these two sentences.
+
+### References
+
+* [BERT](https://arxiv.org/pdf/1810.04805.pdf)
+* [SNLI](https://nlp.stanford.edu/projects/snli/)
+"""
+
+"""
+## Setup
+
+Note: install HuggingFace `transformers` via `pip install transformers` (version >= 2.11.0).
+"""
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import transformers
+
+"""
+## Configuration
+"""
+
+max_length = 128 # Maximum length of input sentence to the model.
+batch_size = 32
+epochs = 2
+
+# Labels in our dataset.
+labels = ["contradiction", "entailment", "neutral"]
+
+"""
+## Load the Data
+"""
+
+"""shell
+curl -LO https://raw.githubusercontent.com/MohamadMerchant/SNLI/master/data.tar.gz
+tar -xvzf data.tar.gz
+"""
+# There are more than 550k samples in total; we will use 100k for this example.
+train_df = pd.read_csv("SNLI_Corpus/snli_1.0_train.csv", nrows=100000)
+valid_df = pd.read_csv("SNLI_Corpus/snli_1.0_dev.csv")
+test_df = pd.read_csv("SNLI_Corpus/snli_1.0_test.csv")
+
+# Shape of the data
+print(f"Total train samples : {train_df.shape[0]}")
+print(f"Total validation samples: {valid_df.shape[0]}")
+print(f"Total test samples: {test_df.shape[0]}")
+
+"""
+Dataset Overview:
+
+- sentence1: The premise caption that was supplied to the author of the pair.
+- sentence2: The hypothesis caption that was written by the author of the pair.
+- similarity: This is the label chosen by the majority of annotators.
+Where no majority exists, the label "-" is used (we will skip such samples here).
+
+Here are the "similarity" label values in our dataset:
+
+- Contradiction: The sentences share no similarity.
+- Entailment: The sentences have similar meaning.
+- Neutral: The sentences are neither clearly similar nor clearly dissimilar.
+"""
+
+"""
+Let's look at one sample from the dataset:
+"""
+print(f"Sentence1: {train_df.loc[1, 'sentence1']}")
+print(f"Sentence2: {train_df.loc[1, 'sentence2']}")
+print(f"Similarity: {train_df.loc[1, 'similarity']}")
+
+"""
+## Preprocessing
+"""
+
+# We have some NaN entries in our train data; we will simply drop them.
+print("Number of missing values")
+print(train_df.isnull().sum())
+train_df.dropna(axis=0, inplace=True)
+
+"""
+Distribution of our training targets.
+"""
+print("Train Target Distribution")
+print(train_df.similarity.value_counts())
+
+"""
+Distribution of our validation targets.
+"""
+print("Validation Target Distribution")
+print(valid_df.similarity.value_counts())
+
+"""
+The value "-" appears as part of our training and validation targets.
+We will skip these samples.
+"""
+train_df = (
+ train_df[train_df.similarity != "-"]
+ .sample(frac=1.0, random_state=42)
+ .reset_index(drop=True)
+)
+valid_df = (
+ valid_df[valid_df.similarity != "-"]
+ .sample(frac=1.0, random_state=42)
+ .reset_index(drop=True)
+)
+
+"""
+One-hot encode training, validation, and test labels.
+"""
+train_df["label"] = train_df["similarity"].apply(
+ lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
+)
+y_train = tf.keras.utils.to_categorical(train_df.label, num_classes=3)
+
+valid_df["label"] = valid_df["similarity"].apply(
+ lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
+)
+y_val = tf.keras.utils.to_categorical(valid_df.label, num_classes=3)
+
+test_df["label"] = test_df["similarity"].apply(
+ lambda x: 0 if x == "contradiction" else 1 if x == "entailment" else 2
+)
+y_test = tf.keras.utils.to_categorical(test_df.label, num_classes=3)
+
+"""
+## Create a custom data generator
+"""
+
+
+class BertSemanticDataGenerator(tf.keras.utils.Sequence):
+ """Generates batches of data.
+
+ Args:
+ sentence_pairs: Array of premise and hypothesis input sentences.
+ labels: Array of labels.
+ batch_size: Integer batch size.
+ shuffle: boolean, whether to shuffle the data.
+ include_targets: boolean, whether to include the labels.
+
+    Returns:
+        Tuples `([input_ids, attention_mask, token_type_ids], labels)`
+        (or just `[input_ids, attention_mask, token_type_ids]`
+        if `include_targets=False`).
+ """
+
+ def __init__(
+ self,
+ sentence_pairs,
+ labels,
+ batch_size=batch_size,
+ shuffle=True,
+ include_targets=True,
+ ):
+ self.sentence_pairs = sentence_pairs
+ self.labels = labels
+ self.shuffle = shuffle
+ self.batch_size = batch_size
+ self.include_targets = include_targets
+        # Load our BERT Tokenizer to encode the text.
+        # We will use the bert-base-uncased pretrained model.
+ self.tokenizer = transformers.BertTokenizer.from_pretrained(
+ "bert-base-uncased", do_lower_case=True
+ )
+ self.indexes = np.arange(len(self.sentence_pairs))
+ self.on_epoch_end()
+
+ def __len__(self):
+ # Denotes the number of batches per epoch.
+ return len(self.sentence_pairs) // self.batch_size
+
+ def __getitem__(self, idx):
+        # Retrieve the batch at position `idx`.
+ indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]
+ sentence_pairs = self.sentence_pairs[indexes]
+
+        # With the BERT tokenizer's batch_encode_plus, the two sentences in each
+        # pair are encoded together, separated by the [SEP] token.
+ encoded = self.tokenizer.batch_encode_plus(
+ sentence_pairs.tolist(),
+ add_special_tokens=True,
+ max_length=max_length,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ pad_to_max_length=True,
+ return_tensors="tf",
+ )
+
+ # Convert batch of encoded features to numpy array.
+ input_ids = np.array(encoded["input_ids"], dtype="int32")
+ attention_masks = np.array(encoded["attention_mask"], dtype="int32")
+ token_type_ids = np.array(encoded["token_type_ids"], dtype="int32")
+
+ # Set to true if data generator is used for training/validation.
+ if self.include_targets:
+ labels = np.array(self.labels[indexes], dtype="int32")
+ return [input_ids, attention_masks, token_type_ids], labels
+ else:
+ return [input_ids, attention_masks, token_type_ids]
+
+ def on_epoch_end(self):
+ # Shuffle indexes after each epoch if shuffle is set to True.
+ if self.shuffle:
+ np.random.RandomState(42).shuffle(self.indexes)
+
+
+"""
+## Build the model
+"""
+# Create the model under a distribution strategy scope.
+strategy = tf.distribute.MirroredStrategy()
+
+with strategy.scope():
+ # Encoded token ids from BERT tokenizer.
+ input_ids = tf.keras.layers.Input(
+ shape=(max_length,), dtype=tf.int32, name="input_ids"
+ )
+    # Attention masks indicate to the model which tokens should be attended to.
+ attention_masks = tf.keras.layers.Input(
+ shape=(max_length,), dtype=tf.int32, name="attention_masks"
+ )
+ # Token type ids are binary masks identifying different sequences in the model.
+ token_type_ids = tf.keras.layers.Input(
+ shape=(max_length,), dtype=tf.int32, name="token_type_ids"
+ )
+ # Loading pretrained BERT model.
+ bert_model = transformers.TFBertModel.from_pretrained("bert-base-uncased")
+ # Freeze the BERT model to reuse the pretrained features without modifying them.
+ bert_model.trainable = False
+
+ bert_output = bert_model.bert(
+ input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
+ )
+ sequence_output = bert_output.last_hidden_state
+ pooled_output = bert_output.pooler_output
+
+ # Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.
+ bi_lstm = tf.keras.layers.Bidirectional(
+ tf.keras.layers.LSTM(64, return_sequences=True)
+ )(sequence_output)
+ # Applying hybrid pooling approach to bi_lstm sequence output.
+ avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
+ max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
+ concat = tf.keras.layers.concatenate([avg_pool, max_pool])
+ dropout = tf.keras.layers.Dropout(0.3)(concat)
+ output = tf.keras.layers.Dense(3, activation="softmax")(dropout)
+ model = tf.keras.models.Model(
+ inputs=[input_ids, attention_masks, token_type_ids], outputs=output
+ )
+
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(),
+ loss="categorical_crossentropy",
+ metrics=["acc"],
+ )
+
+
+print(f"Strategy: {strategy}")
+model.summary()
+
+"""
+Create train and validation data generators
+"""
+train_data = BertSemanticDataGenerator(
+ train_df[["sentence1", "sentence2"]].values.astype("str"),
+ y_train,
+ batch_size=batch_size,
+ shuffle=True,
+)
+valid_data = BertSemanticDataGenerator(
+ valid_df[["sentence1", "sentence2"]].values.astype("str"),
+ y_val,
+ batch_size=batch_size,
+ shuffle=False,
+)
+
+"""
+## Train the Model
+
+Training is done only for the top layers to perform "feature extraction",
+which will allow the model to use the representations of the pretrained model.
+"""
+history = model.fit(
+ train_data,
+ validation_data=valid_data,
+ epochs=epochs,
+ use_multiprocessing=True,
+ workers=-1,
+)
+
+"""
+## Fine-tuning
+
+This step must only be performed after the feature extraction model has
+been trained to convergence on the new data.
+
+This is an optional last step where `bert_model` is unfrozen and retrained
+with a very low learning rate. This can deliver meaningful improvement by
+incrementally adapting the pretrained features to the new data.
+"""
+
+# Unfreeze the bert_model.
+bert_model.trainable = True
+# Recompile the model to make the change effective.
+model.compile(
+ optimizer=tf.keras.optimizers.Adam(1e-5),
+ loss="categorical_crossentropy",
+ metrics=["accuracy"],
+)
+model.summary()
+
+"""
+## Train the entire model end-to-end
+"""
+history = model.fit(
+ train_data,
+ validation_data=valid_data,
+ epochs=epochs,
+ use_multiprocessing=True,
+ workers=-1,
+)
+
+"""
+## Evaluate model on the test set
+"""
+test_data = BertSemanticDataGenerator(
+ test_df[["sentence1", "sentence2"]].values.astype("str"),
+ y_test,
+ batch_size=batch_size,
+ shuffle=False,
+)
+model.evaluate(test_data, verbose=1)
+
+"""
+## Inference on custom sentences
+"""
+
+
+def check_similarity(sentence1, sentence2):
+ sentence_pairs = np.array([[str(sentence1), str(sentence2)]])
+ test_data = BertSemanticDataGenerator(
+ sentence_pairs,
+ labels=None,
+ batch_size=1,
+ shuffle=False,
+ include_targets=False,
+ )
+
+ proba = model.predict(test_data[0])[0]
+ idx = np.argmax(proba)
+    # Convert the winning class probability to a percentage string.
+    proba = f"{proba[idx] * 100:.2f}%"
+ pred = labels[idx]
+ return pred, proba
+
+
+"""
+Check results on some example sentence pairs.
+"""
+sentence1 = "Two women are observing something together."
+sentence2 = "Two women are standing with their eyes closed."
+check_similarity(sentence1, sentence2)
+"""
+Check results on some example sentence pairs.
+"""
+sentence1 = "A smiling costumed woman is holding an umbrella"
+sentence2 = "A happy woman in a fairy costume holds an umbrella"
+check_similarity(sentence1, sentence2)
+
+"""
+Check results on some example sentence pairs.
+"""
+sentence1 = "A soccer game with multiple males playing"
+sentence2 = "Some men are playing a sport"
+check_similarity(sentence1, sentence2)
+
+"""
+Example available on HuggingFace
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/bert-semantic-similarity) | [](https://huggingface.co/spaces/keras-io/bert-semantic-similarity) |
+"""
diff --git a/knowledge_base/nlp/semantic_similarity_with_keras_hub.py b/knowledge_base/nlp/semantic_similarity_with_keras_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..409177406eb8b0f24adefcaf45b53be3b9531f9d
--- /dev/null
+++ b/knowledge_base/nlp/semantic_similarity_with_keras_hub.py
@@ -0,0 +1,323 @@
+"""
+Title: Semantic Similarity with KerasHub
+Author: [Anshuman Mishra](https://github.com/shivance/)
+Date created: 2023/02/25
+Last modified: 2023/02/25
+Description: Use pretrained models from KerasHub for the Semantic Similarity Task.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Semantic similarity refers to the task of determining the degree of similarity between two
+sentences in terms of their meaning. We already saw in [this](https://keras.io/examples/nlp/semantic_similarity_with_bert/)
+example how to use SNLI (Stanford Natural Language Inference) corpus to predict sentence
+semantic similarity with the HuggingFace Transformers library. In this tutorial we will
+learn how to use [KerasHub](https://keras.io/keras_hub/), an extension of the core Keras API,
+for the same task. Furthermore, we will discover how KerasHub effectively reduces boilerplate
+code and simplifies the process of building and utilizing models. For more information on KerasHub,
+please refer to [KerasHub's official documentation](https://keras.io/keras_hub/).
+
+This guide is broken down into the following parts:
+
+1. *Setup*, task definition, and establishing a baseline.
+2. *Establishing baseline* with BERT.
+3. *Saving and Reloading* the model.
+4. *Performing inference* with the model.
+5. *Improving accuracy* with RoBERTa.
+
+## Setup
+
+The following guide uses [Keras Core](https://keras.io/keras_core/) to work in
+any of `tensorflow`, `jax` or `torch`. Support for Keras Core is baked into
+KerasHub; simply change the `KERAS_BACKEND` environment variable below to the
+backend you would like to use. We select the `jax` backend below, which gives
+us a particularly fast train step.
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import os
+
+# Select the backend, as discussed above; any of the three will work.
+os.environ["KERAS_BACKEND"] = "jax"
+
+import numpy as np
+import tensorflow as tf
+import keras
+import keras_hub
+import tensorflow_datasets as tfds
+
+"""
+To load the SNLI dataset, we use the tensorflow-datasets library. The full
+dataset contains over 550,000 samples; to ensure that this example runs
+quickly, we use only 20% of the training samples.
+
+## Overview of SNLI Dataset
+
+Every sample in the dataset contains three components: `hypothesis`, `premise`,
+and `label`. The premise represents the original caption provided to the author
+of the pair, while the hypothesis refers to the hypothesis caption created by
+the author of the pair. The label is assigned by annotators to indicate the
+similarity between the two sentences.
+
+The dataset contains three possible similarity label values: Contradiction, Entailment,
+and Neutral. Contradiction represents completely dissimilar sentences, while Entailment
+denotes similar meaning sentences. Lastly, Neutral refers to sentences where no clear
+similarity or dissimilarity can be established between them.
+"""
+
+snli_train = tfds.load("snli", split="train[:20%]")
+snli_val = tfds.load("snli", split="validation")
+snli_test = tfds.load("snli", split="test")
+
+# Here's an example of what our samples look like; we take a batch of
+# four samples from the test split:
+sample = snli_test.batch(4).take(1).get_single_element()
+sample
+
+"""
+### Preprocessing
+
+In our dataset, we have identified that some samples have missing or incorrectly labeled
+data, which is denoted by a value of -1. To ensure the accuracy and reliability of our model,
+we simply filter out these samples from our dataset.
+"""
+
+
+def filter_labels(sample):
+ return sample["label"] >= 0
+
+
+"""
+Here's a utility function that splits the example into an `(x, y)` tuple that is suitable
+for `model.fit()`. By default, `keras_hub.models.BertClassifier` will tokenize and pack
+together raw strings using a `"[SEP]"` token during training. Therefore, this label
+splitting is all the data preparation that we need to perform.
+"""
+
+
+def split_labels(sample):
+ x = (sample["hypothesis"], sample["premise"])
+ y = sample["label"]
+ return x, y
+
+
+train_ds = (
+ snli_train.filter(filter_labels)
+ .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(16)
+)
+val_ds = (
+ snli_val.filter(filter_labels)
+ .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(16)
+)
+test_ds = (
+ snli_test.filter(filter_labels)
+ .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(16)
+)
+
+
+"""
+## Establishing baseline with BERT.
+
+We use the BERT model from KerasHub to establish a baseline for our semantic similarity
+task. The `keras_hub.models.BertClassifier` class attaches a classification head to the BERT
+Backbone, mapping the backbone outputs to a logit output suitable for a classification task.
+This significantly reduces the need for custom code.
+
+KerasHub models have built-in tokenization capabilities that handle tokenization by default
+based on the selected model. However, users can also use custom preprocessing techniques
+as per their specific needs. If we pass a tuple as input, the model will tokenize all the
+strings and concatenate them with a `"[SEP]"` separator.
+
+We use this model with pretrained weights, and we can use the `from_preset()` method
+to use our own preprocessor. For the SNLI dataset, we set `num_classes` to 3.
+"""
+
+bert_classifier = keras_hub.models.BertClassifier.from_preset(
+ "bert_tiny_en_uncased", num_classes=3
+)
+
+"""
+Please note that the BERT Tiny model has only 4,386,307 trainable parameters.
+
+KerasHub task models come with compilation defaults. We can now train the model we just
+instantiated by calling the `fit()` method.
+"""
+
+bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
+
+"""
+Our BERT classifier achieved an accuracy of around 76% on the validation split. Now,
+let's evaluate its performance on the test split.
+
+### Evaluate the performance of the trained model on test data.
+"""
+
+bert_classifier.evaluate(test_ds)
+
+"""
+Our baseline BERT model achieved a similar accuracy of around 76% on the test split.
+Now, let's try to improve its performance by recompiling the model with a slightly
+higher learning rate.
+"""
+
+bert_classifier = keras_hub.models.BertClassifier.from_preset(
+ "bert_tiny_en_uncased", num_classes=3
+)
+bert_classifier.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ optimizer=keras.optimizers.Adam(5e-5),
+ metrics=["accuracy"],
+)
+
+bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
+bert_classifier.evaluate(test_ds)
+
+"""
+Tweaking the learning rate alone was not enough to boost performance, which
+stayed right around 76%. Let's try again, but this time with
+`keras.optimizers.AdamW`, and a learning rate schedule.
+"""
+
+
+class TriangularSchedule(keras.optimizers.schedules.LearningRateSchedule):
+ """Linear ramp up for `warmup` steps, then linear decay to zero at `total` steps."""
+
+ def __init__(self, rate, warmup, total):
+ self.rate = rate
+ self.warmup = warmup
+ self.total = total
+
+ def get_config(self):
+ config = {"rate": self.rate, "warmup": self.warmup, "total": self.total}
+ return config
+
+ def __call__(self, step):
+ step = keras.ops.cast(step, dtype="float32")
+ rate = keras.ops.cast(self.rate, dtype="float32")
+ warmup = keras.ops.cast(self.warmup, dtype="float32")
+ total = keras.ops.cast(self.total, dtype="float32")
+
+        warmup_rate = rate * step / warmup
+        cooldown_rate = rate * (total - step) / (total - warmup)
+ triangular_rate = keras.ops.minimum(warmup_rate, cooldown_rate)
+ return keras.ops.maximum(triangular_rate, 0.0)
+
+
+bert_classifier = keras_hub.models.BertClassifier.from_preset(
+ "bert_tiny_en_uncased", num_classes=3
+)
+
+# Get the total count of training batches.
+# This requires walking the dataset to filter all -1 labels.
+epochs = 3
+total_steps = sum(1 for _ in train_ds.as_numpy_iterator()) * epochs
+warmup_steps = int(total_steps * 0.2)
+
+bert_classifier.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ optimizer=keras.optimizers.AdamW(
+ TriangularSchedule(1e-4, warmup_steps, total_steps)
+ ),
+ metrics=["accuracy"],
+)
+
+bert_classifier.fit(train_ds, validation_data=val_ds, epochs=epochs)
+
+"""
+Success! With the learning rate scheduler and the `AdamW` optimizer, our validation
+accuracy improved to around 79%.
+
+Now, let's evaluate our final model on the test set and see how it performs.
+"""
+
+bert_classifier.evaluate(test_ds)
+
+"""
+Our Tiny BERT model achieved an accuracy of approximately 79% on the test set
+with the use of a learning rate scheduler. This is a significant improvement over
+our previous results. Fine-tuning a pretrained BERT
+model can be a powerful tool in natural language processing tasks, and even a
+small model like Tiny BERT can achieve impressive results.
+
+Let's save our model for now
+and move on to learning how to perform inference with it.
+
+## Save and Reload the model
+"""
+bert_classifier.save("bert_classifier.keras")
+restored_model = keras.models.load_model("bert_classifier.keras")
+restored_model.evaluate(test_ds)
+
+"""
+## Performing inference with the model.
+
+Let's see how to perform inference with KerasHub models.
+"""
+
+# Convert to a (hypothesis, premise) pair for a forward pass through the model.
+sample = (sample["hypothesis"], sample["premise"])
+sample
+
+"""
+The default preprocessor in KerasHub models handles input tokenization automatically,
+so we don't need to perform tokenization explicitly.
+"""
+predictions = bert_classifier.predict(sample)
+
+
+def softmax(x):
+    # Normalize each row of logits into a probability distribution.
+    return np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)
+
+
+# Convert the logits into class probabilities.
+predictions = softmax(predictions)
+
+"""
+## Improving accuracy with RoBERTa
+
+Now that we have established a baseline, we can attempt to improve our results
+by experimenting with different models. Thanks to KerasHub, fine-tuning a RoBERTa
+checkpoint on the same dataset is easy with just a few lines of code.
+"""
+
+# Initialize a RoBERTa classifier from a preset.
+roberta_classifier = keras_hub.models.RobertaClassifier.from_preset(
+ "roberta_base_en", num_classes=3
+)
+
+roberta_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
+
+roberta_classifier.evaluate(test_ds)
+
+"""
+The RoBERTa base model has significantly more trainable parameters than the BERT
+Tiny model, with almost 30 times as many at 124,645,635 parameters. As a result, it took
+approximately 1.5 hours to train on a P100 GPU. However, the performance
+improvement was substantial, with accuracy increasing to 88% on both the validation
+and test splits. With RoBERTa, we were able to fit a maximum batch size of 16 on
+our P100 GPU.
+
+Despite using a different model, the steps to perform inference with RoBERTa are
+the same as with BERT!
+"""
+
+predictions = roberta_classifier.predict(sample)
+print(tf.math.argmax(predictions, axis=1).numpy())
+
+"""
+We hope this tutorial has been helpful in demonstrating the ease and effectiveness
+of using KerasHub and BERT for semantic similarity tasks.
+
+Throughout this tutorial, we demonstrated how to use a pretrained BERT model to
+establish a baseline and improve performance by training a larger RoBERTa model
+using just a few lines of code.
+
+The KerasHub toolbox provides a range of modular building blocks for preprocessing
+text, including pretrained state-of-the-art models and low-level Transformer Encoder
+layers. We believe that this makes experimenting with natural language solutions
+more accessible and efficient.
+"""
diff --git a/knowledge_base/nlp/sentence_embeddings_with_sbert.py b/knowledge_base/nlp/sentence_embeddings_with_sbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4867eb28bcb84638ddec4bda18bfb58d9102a537
--- /dev/null
+++ b/knowledge_base/nlp/sentence_embeddings_with_sbert.py
@@ -0,0 +1,466 @@
+"""
+Title: Sentence embeddings using Siamese RoBERTa-networks
+Author: [Mohammed Abu El-Nasr](https://github.com/abuelnasr0)
+Date created: 2023/07/14
+Last modified: 2023/07/14
+Description: Fine-tune a RoBERTa model to generate sentence embeddings using KerasHub.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+BERT and RoBERTa can be used for semantic textual similarity tasks, where two sentences
+are passed to the model and the network predicts whether they are similar or not. But
+what if we have a large collection of sentences and want to find the most similar pairs
+in that collection? That will take n*(n-1)/2 inference computations, where n is the
+number of sentences in the collection. For example, if n = 10000, the required time will
+be 65 hours on a V100 GPU.
+
+A common method to overcome the time overhead issue is to pass one sentence at a time
+to the model, then average the output of the model, or take the first token (the [CLS]
+token), and use that as a [sentence embedding](https://en.wikipedia.org/wiki/Sentence_embedding), then
+use a vector similarity measure like cosine similarity or Manhattan / Euclidean distance
+to find close sentences (semantically similar sentences). That will reduce the time to
+find the most similar pairs in a collection of 10,000 sentences from 65 hours to 5
+seconds!
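+
+To see where that quadratic blow-up comes from, here is the pair count for
+n = 10,000 sentences:
+"""
+
+# Number of pairwise comparisons a cross-encoder would need.
+n = 10_000
+num_pairs = n * (n - 1) // 2
+print(num_pairs)  # 49,995,000 forward passes.
+
+"""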
+
+If we use RoBERTa directly, that will yield rather bad sentence embeddings. But if we
+fine-tune RoBERTa using a Siamese network, that will generate semantically meaningful
+sentence embeddings. This will enable RoBERTa to be used for new tasks. These tasks
+include:
+
+- Large-scale semantic similarity comparison.
+- Clustering.
+- Information retrieval via semantic search.
+
+In this example, we will show how to fine-tune a RoBERTa model using a Siamese network
+such that it will be able to produce semantically meaningful sentence embeddings and use
+them in a semantic search and clustering example.
+This method of fine-tuning was introduced in
+[Sentence-BERT](https://arxiv.org/abs/1908.10084)
+"""
+
+"""
+## Setup
+
+Let's install and import the libraries we need. We'll be using the KerasHub library in
+this example.
+
+We will also enable [mixed precision](https://www.tensorflow.org/guide/mixed_precision)
+training. This will help us reduce the training time.
+"""
+
+"""shell
+pip install -q --upgrade keras-hub
+pip install -q --upgrade keras # Upgrade to Keras 3.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import keras_hub
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import sklearn.cluster as cluster
+
+keras.mixed_precision.set_global_policy("mixed_float16")
+
+"""
+## Fine-tune the model using siamese networks
+
+[Siamese network](https://en.wikipedia.org/wiki/Siamese_neural_network) is a neural
+network architecture that contains two or more subnetworks. The subnetworks share the
+same weights. It is used to generate feature vectors for each input and then compare them
+for similarity.
+
+For our example, the subnetwork will be a RoBERTa model that has a pooling layer on top
+of it to produce the embeddings of the input sentences. These embeddings will then be
+compared to each other to learn to produce semantically meaningful embeddings.
+
+The pooling strategies used are mean, max, and CLS pooling. Mean pooling produces the
+best results. We will use it in our examples.
+"""
+
+"""
+### Fine-tune using the regression objective function
+
+For building the siamese network with the regression objective function, the siamese
+network is asked to predict the cosine similarity between the embeddings of the two input
+sentences.
+
+Cosine similarity indicates the angle between the sentence embeddings. If the cosine
+similarity is high, that means there is a small angle between the embeddings; hence, they
+are semantically similar.
+"""
+
+"""
+#### Load the dataset
+
+We will use the STSB dataset to fine-tune the model for the regression objective. STSB
+consists of a collection of sentence pairs that are labelled in the range [0, 5]. 0
+indicates the least semantic similarity between the two sentences, and 5 indicates the
+most semantic similarity between the two sentences.
+
+The range of the cosine similarity is [-1, 1] and it's the output of the siamese network,
+but the range of the labels in the dataset is [0, 5]. We need to unify the range between
+the cosine similarity and the dataset labels, so while preparing the dataset, we will
+divide the labels by 2.5 and subtract 1.
+"""
+
+TRAIN_BATCH_SIZE = 6
+VALIDATION_BATCH_SIZE = 8
+
+TRAIN_NUM_BATCHES = 300
+VALIDATION_NUM_BATCHES = 40
+
+AUTOTUNE = tf.data.experimental.AUTOTUNE
+
+
+def change_range(x):
+ return (x / 2.5) - 1
+
+
+def prepare_dataset(dataset, num_batches, batch_size):
+ dataset = dataset.map(
+ lambda z: (
+ [z["sentence1"], z["sentence2"]],
+ [tf.cast(change_range(z["label"]), tf.float32)],
+ ),
+ num_parallel_calls=AUTOTUNE,
+ )
+ dataset = dataset.batch(batch_size)
+ dataset = dataset.take(num_batches)
+ dataset = dataset.prefetch(AUTOTUNE)
+ return dataset
+
+
+stsb_ds = tfds.load(
+ "glue/stsb",
+)
+stsb_train, stsb_valid = stsb_ds["train"], stsb_ds["validation"]
+
+stsb_train = prepare_dataset(stsb_train, TRAIN_NUM_BATCHES, TRAIN_BATCH_SIZE)
+stsb_valid = prepare_dataset(stsb_valid, VALIDATION_NUM_BATCHES, VALIDATION_BATCH_SIZE)
+
+"""
+Let's see examples from the dataset of two sentences and their similarity.
+"""
+
+for x, y in stsb_train:
+ for i, example in enumerate(x):
+ print(f"sentence 1 : {example[0]} ")
+ print(f"sentence 2 : {example[1]} ")
+ print(f"similarity : {y[i]} \n")
+ break
+
+"""
+#### Build the encoder model.
+
+Now, we'll build the encoder model that will produce the sentence embeddings. It consists
+of:
+
+- A preprocessor layer to tokenize and generate padding masks for the sentences.
+- A backbone model that will generate the contextual representation of each token in the
+sentence.
+- A mean pooling layer to produce the embeddings. We will use `keras.layers.GlobalAveragePooling1D`
+to apply the mean pooling to the backbone outputs. We will pass the padding mask to the
+layer to exclude padded tokens from being averaged.
+- A normalization layer to normalize the embeddings as we are using the cosine similarity.
+"""
+
+preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
+backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
+inputs = keras.Input(shape=(1,), dtype="string", name="sentence")
+x = preprocessor(inputs)
+h = backbone(x)
+embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
+ h, x["padding_mask"]
+)
+n_embedding = keras.layers.UnitNormalization(axis=1)(embedding)
+roberta_normal_encoder = keras.Model(inputs=inputs, outputs=n_embedding)
+
+roberta_normal_encoder.summary()
+
+"""
+#### Build the Siamese network with the regression objective function.
+
+As described above, a Siamese network has two or more subnetworks, so for this
+Siamese model we need two encoders. We only have one encoder, however, so we pass
+both sentences through it. That way, we get two paths to the embeddings, with the
+weights shared between the two paths.
+
+After passing the two sentences to the model and getting the normalized embeddings, we
+will multiply the two normalized embeddings to get the cosine similarity between the two
+sentences.
+"""
+
+
+class RegressionSiamese(keras.Model):
+ def __init__(self, encoder, **kwargs):
+ inputs = keras.Input(shape=(2,), dtype="string", name="sentences")
+ sen1, sen2 = keras.ops.split(inputs, 2, axis=1)
+ u = encoder(sen1)
+ v = encoder(sen2)
+ cosine_similarity_scores = keras.ops.matmul(u, keras.ops.transpose(v))
+
+ super().__init__(
+ inputs=inputs,
+ outputs=cosine_similarity_scores,
+ **kwargs,
+ )
+
+ self.encoder = encoder
+
+ def get_encoder(self):
+ return self.encoder
+
+
+"""
+#### Fit the model
+
+Let's try this example before training and compare it to the output after training.
+"""
+
+sentences = [
+ "Today is a very sunny day.",
+ "I am hungry, I will get my meal.",
+ "The dog is eating his food.",
+]
+query = ["The dog is enjoying his meal."]
+
+encoder = roberta_normal_encoder
+
+sentence_embeddings = encoder(tf.constant(sentences))
+query_embedding = encoder(tf.constant(query))
+
+cosine_similarity_scores = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
+for i, sim in enumerate(cosine_similarity_scores[0]):
+ print(f"cosine similarity score between sentence {i+1} and the query = {sim} ")
+
+"""
+For training, we will use `MeanSquaredError()` as the loss function, and the `Adam()`
+optimizer with a learning rate of 2e-5.
+"""
+
+roberta_regression_siamese = RegressionSiamese(roberta_normal_encoder)
+
+roberta_regression_siamese.compile(
+ loss=keras.losses.MeanSquaredError(),
+ optimizer=keras.optimizers.Adam(2e-5),
+ jit_compile=False,
+)
+
+roberta_regression_siamese.fit(stsb_train, validation_data=stsb_valid, epochs=1)
+
+"""
+Let's try the model after training; we will notice a huge difference in the output.
+This means that the fine-tuned model is capable of producing semantically meaningful
+embeddings, where semantically similar sentences have a small angle between them, and
+semantically dissimilar sentences have a large angle between them.
+"""
+
+sentences = [
+ "Today is a very sunny day.",
+ "I am hungry, I will get my meal.",
+ "The dog is eating his food.",
+]
+query = ["The dog is enjoying his food."]
+
+encoder = roberta_regression_siamese.get_encoder()
+
+sentence_embeddings = encoder(tf.constant(sentences))
+query_embedding = encoder(tf.constant(query))
+
+cosine_similarities = tf.matmul(query_embedding, tf.transpose(sentence_embeddings))
+for i, sim in enumerate(cosine_similarities[0]):
+ print(f"cosine similarity between sentence {i+1} and the query = {sim} ")
+
+"""
+### Fine-tune Using the triplet Objective Function
+
+For the Siamese network with the triplet objective function, three sentences are passed
+to the Siamese network *anchor*, *positive*, and *negative* sentences. *anchor* and
+*positive* sentences are semantically similar, and *anchor* and *negative* sentences are
+semantically dissimilar. The objective is to minimize the distance between the *anchor*
+sentence and the *positive* sentence, and to maximize the distance between the *anchor*
+sentence and the *negative* sentence.
+"""
+
+"""
+#### Load the dataset
+
+We will use the Wikipedia-sections-triplets dataset for fine-tuning. This dataset
+consists of sentences derived from the Wikipedia website. Each example is a collection
+of 3 sentences: *anchor*, *positive*, and *negative*. The *anchor* and *positive*
+sentences are derived from the same section, while the *anchor* and *negative*
+sentences are derived from different sections.
+
+This dataset has 1.8 million training triplets and 220,000 test triplets. In this
+example, we will only use 1200 triplets for training and 300 for testing.
+"""
+
+"""shell
+wget https://sbert.net/datasets/wikipedia-sections-triplets.zip -q
+unzip wikipedia-sections-triplets.zip -d wikipedia-sections-triplets
+"""
+
+NUM_TRAIN_BATCHES = 200
+NUM_TEST_BATCHES = 75
+AUTOTUNE = tf.data.experimental.AUTOTUNE
+
+
+def prepare_wiki_data(dataset, num_batches):
+ dataset = dataset.map(
+ lambda z: ((z["Sentence1"], z["Sentence2"], z["Sentence3"]), 0)
+ )
+ dataset = dataset.batch(6)
+ dataset = dataset.take(num_batches)
+ dataset = dataset.prefetch(AUTOTUNE)
+ return dataset
+
+
+wiki_train = tf.data.experimental.make_csv_dataset(
+ "wikipedia-sections-triplets/train.csv",
+ batch_size=1,
+ num_epochs=1,
+)
+wiki_test = tf.data.experimental.make_csv_dataset(
+ "wikipedia-sections-triplets/test.csv",
+ batch_size=1,
+ num_epochs=1,
+)
+
+wiki_train = prepare_wiki_data(wiki_train, NUM_TRAIN_BATCHES)
+wiki_test = prepare_wiki_data(wiki_test, NUM_TEST_BATCHES)
+
+"""
+#### Build the encoder model
+
+For this encoder model, we will use RoBERTa with mean pooling and we will not normalize
+the output embeddings. The encoder model consists of:
+
+- A preprocessor layer to tokenize and generate padding masks for the sentences.
+- A backbone model that will generate the contextual representation of each token in the
+sentence.
+- A mean pooling layer to produce the embeddings.
+"""
+
+preprocessor = keras_hub.models.RobertaPreprocessor.from_preset("roberta_base_en")
+backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
+input = keras.Input(shape=(1,), dtype="string", name="sentence")
+
+x = preprocessor(input)
+h = backbone(x)
+embedding = keras.layers.GlobalAveragePooling1D(name="pooling_layer")(
+ h, x["padding_mask"]
+)
+
+roberta_encoder = keras.Model(inputs=input, outputs=embedding)
+
+
+roberta_encoder.summary()
+
+"""
+#### Build the Siamese network with the triplet objective function
+
+For the Siamese network with the triplet objective function, we will build the model with
+an encoder, and we will pass the three sentences through that encoder. We will get an
+embedding for each sentence, and we will calculate the `positive_dist` and
+`negative_dist` that will be passed to the loss function described below.
+"""
+
+
+class TripletSiamese(keras.Model):
+ def __init__(self, encoder, **kwargs):
+ anchor = keras.Input(shape=(1,), dtype="string")
+ positive = keras.Input(shape=(1,), dtype="string")
+ negative = keras.Input(shape=(1,), dtype="string")
+
+ ea = encoder(anchor)
+ ep = encoder(positive)
+ en = encoder(negative)
+
+ positive_dist = keras.ops.sum(keras.ops.square(ea - ep), axis=1)
+ negative_dist = keras.ops.sum(keras.ops.square(ea - en), axis=1)
+
+ positive_dist = keras.ops.sqrt(positive_dist)
+ negative_dist = keras.ops.sqrt(negative_dist)
+
+ output = keras.ops.stack([positive_dist, negative_dist], axis=0)
+
+ super().__init__(inputs=[anchor, positive, negative], outputs=output, **kwargs)
+
+ self.encoder = encoder
+
+ def get_encoder(self):
+ return self.encoder
+
+
+"""
+We will use a custom loss function for the triplet objective. The loss function will
+receive the distance between the *anchor* and the *positive* embeddings `positive_dist`,
+and the distance between the *anchor* and the *negative* embeddings `negative_dist`,
+where they are stacked together in `y_pred`.
+
+We will use `positive_dist` and `negative_dist` to compute the loss such that
+`negative_dist` is larger than `positive_dist` by at least a specific margin.
+Mathematically, we will minimize this loss function:
+`max(positive_dist - negative_dist + margin, 0)`.
+
+There is no `y_true` used in this loss function. Note that we set the labels in the
+dataset to zero, but they will not be used.
+"""
+
+
+class TripletLoss(keras.losses.Loss):
+ def __init__(self, margin=1, **kwargs):
+ super().__init__(**kwargs)
+ self.margin = margin
+
+ def call(self, y_true, y_pred):
+ positive_dist, negative_dist = tf.unstack(y_pred, axis=0)
+
+ losses = keras.ops.relu(positive_dist - negative_dist + self.margin)
+ return keras.ops.mean(losses, axis=0)
+
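+
+"""
+A quick numeric check of the loss (a minimal sketch with made-up distances): with the
+default margin of 1, the pair `(positive_dist=0.5, negative_dist=2.0)` contributes
+`max(0.5 - 2.0 + 1, 0) = 0`, while `(1.5, 1.0)` contributes `1.5`, for a mean of `0.75`.
+"""
+
+demo_y_pred = tf.constant([[0.5, 1.5], [2.0, 1.0]])
+print(float(TripletLoss().call(y_true=None, y_pred=demo_y_pred)))  # 0.75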
+
+"""
+#### Fit the model
+
+For the training, we will use the custom `TripletLoss()` loss function, and `Adam()`
+optimizer with a learning rate = 2e-5.
+"""
+
+roberta_triplet_siamese = TripletSiamese(roberta_encoder)
+
+roberta_triplet_siamese.compile(
+ loss=TripletLoss(),
+ optimizer=keras.optimizers.Adam(2e-5),
+ jit_compile=False,
+)
+
+roberta_triplet_siamese.fit(wiki_train, validation_data=wiki_test, epochs=1)
+
+"""
+Let's try this model in a clustering example. Here are 6 questions: the first 3 are
+about learning English, and the last 3 are about working online. Let's see if the
+embeddings produced by our encoder cluster them correctly.
+"""
+
+questions = [
+ "What should I do to improve my English writting?",
+ "How to be good at speaking English?",
+ "How can I improve my English?",
+ "How to earn money online?",
+ "How do I earn money online?",
+ "How to work and earn money through internet?",
+]
+
+encoder = roberta_triplet_siamese.get_encoder()
+embeddings = encoder(tf.constant(questions))
+kmeans = cluster.KMeans(n_clusters=2, random_state=0, n_init="auto").fit(embeddings)
+
+for i, label in enumerate(kmeans.labels_):
+ print(f"sentence ({questions[i]}) belongs to cluster {label}")
diff --git a/knowledge_base/nlp/text_classification_from_scratch.py b/knowledge_base/nlp/text_classification_from_scratch.py
new file mode 100644
index 0000000000000000000000000000000000000000..048e076b3a67bf1c7ef9f4d2702b57d48b201897
--- /dev/null
+++ b/knowledge_base/nlp/text_classification_from_scratch.py
@@ -0,0 +1,302 @@
+"""
+Title: Text classification from scratch
+Authors: Mark Omernick, Francois Chollet
+Date created: 2019/11/06
+Last modified: 2020/05/17
+Description: Text sentiment classification starting from raw text files.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example shows how to do text classification starting from raw text (as
+a set of text files on disk). We demonstrate the workflow on the IMDB sentiment
+classification dataset (unprocessed version). We use the `TextVectorization` layer for
+ word splitting & indexing.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import tensorflow as tf
+import numpy as np
+from keras import layers
+
+"""
+## Load the data: IMDB movie review sentiment classification
+
+Let's download the data and inspect its structure.
+"""
+
+"""shell
+curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+tar -xf aclImdb_v1.tar.gz
+"""
+
+"""
+The `aclImdb` folder contains a `train` and `test` subfolder:
+"""
+
+"""shell
+ls aclImdb
+"""
+
+"""shell
+ls aclImdb/test
+"""
+
+"""shell
+ls aclImdb/train
+"""
+
+"""
+The `aclImdb/train/pos` and `aclImdb/train/neg` folders contain text files, each of
+ which represents one review (either positive or negative):
+"""
+
+"""shell
+cat aclImdb/train/pos/6248_7.txt
+"""
+
+"""
+We are only interested in the `pos` and `neg` subfolders, so let's delete the `unsup` subfolder, which contains unlabeled reviews:
+"""
+
+"""shell
+rm -r aclImdb/train/unsup
+"""
+
+"""
+You can use the utility `keras.utils.text_dataset_from_directory` to
+generate a labeled `tf.data.Dataset` object from a set of text files on disk filed
+ into class-specific folders.
+
+Let's use it to generate the training, validation, and test datasets. The validation
+and training datasets are generated from two subsets of the `train` directory, with 20%
+of samples going to the validation dataset and 80% going to the training dataset.
+
+Having a validation dataset in addition to the test dataset is useful for tuning
+hyperparameters, such as the model architecture, for which the test dataset should not
+be used.
+
+Before putting the model out into the real world, however, it should be retrained using all
+available training data (without creating a validation dataset), so its performance is maximized.
+
+When using the `validation_split` & `subset` arguments, make sure to either specify a
+random seed, or to pass `shuffle=False`, so that the validation & training splits you
+get have no overlap.
+
+"""
+
+batch_size = 32
+raw_train_ds = keras.utils.text_dataset_from_directory(
+ "aclImdb/train",
+ batch_size=batch_size,
+ validation_split=0.2,
+ subset="training",
+ seed=1337,
+)
+raw_val_ds = keras.utils.text_dataset_from_directory(
+ "aclImdb/train",
+ batch_size=batch_size,
+ validation_split=0.2,
+ subset="validation",
+ seed=1337,
+)
+raw_test_ds = keras.utils.text_dataset_from_directory(
+ "aclImdb/test", batch_size=batch_size
+)
+
+print(f"Number of batches in raw_train_ds: {raw_train_ds.cardinality()}")
+print(f"Number of batches in raw_val_ds: {raw_val_ds.cardinality()}")
+print(f"Number of batches in raw_test_ds: {raw_test_ds.cardinality()}")
+
+"""
+Let's preview a few samples:
+"""
+
+# It's important to take a look at your raw data to ensure your normalization
+# and tokenization will work as expected. We can do that by taking a few
+# examples from the training set and looking at them.
+# This is one of the places where eager execution shines:
+# we can just evaluate these tensors using .numpy()
+# instead of needing to evaluate them in a Session/Graph context.
+for text_batch, label_batch in raw_train_ds.take(1):
+ for i in range(5):
+ print(text_batch.numpy()[i])
+ print(label_batch.numpy()[i])
+
+"""
+## Prepare the data
+
+In particular, we remove `<br />` tags.
+"""
+
+import string
+import re
+
+
+# Having looked at our data above, we see that the raw text contains HTML break
+# tags of the form '<br />'. These tags will not be removed by the default
+# standardizer (which doesn't strip HTML). Because of this, we will need to
+# create a custom standardization function.
+def custom_standardization(input_data):
+ lowercase = tf.strings.lower(input_data)
+    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
+ return tf.strings.regex_replace(
+ stripped_html, f"[{re.escape(string.punctuation)}]", ""
+ )
+
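+
+"""
+As a quick sanity check (a minimal sketch on a made-up review snippet), we can run the
+standardizer eagerly and inspect the result:
+"""
+
+print(custom_standardization(tf.constant("This was GREAT!<br />Loved it.")).numpy())
+# Expected (under these assumptions): b'this was great loved it'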
+
+# Model constants.
+max_features = 20000
+embedding_dim = 128
+sequence_length = 500
+
+# Now that we have our custom standardization, we can instantiate our text
+# vectorization layer. We are using this layer to normalize, split, and map
+# strings to integers, so we set our 'output_mode' to 'int'.
+# Note that we're using the default split function,
+# and the custom standardization defined above.
+# We also set an explicit maximum sequence length, since the CNNs later in our
+# model won't support ragged sequences.
+vectorize_layer = keras.layers.TextVectorization(
+ standardize=custom_standardization,
+ max_tokens=max_features,
+ output_mode="int",
+ output_sequence_length=sequence_length,
+)
+
+# Now that the vectorize_layer has been created, call `adapt` on a text-only
+# dataset to create the vocabulary. You don't have to batch, but for very large
+# datasets this means you're not keeping spare copies of the dataset in memory.
+
+# Let's make a text-only dataset (no labels):
+text_ds = raw_train_ds.map(lambda x, y: x)
+# Let's call `adapt`:
+vectorize_layer.adapt(text_ds)
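+
+"""
+As a quick check (a minimal sketch with a made-up sentence), we can vectorize a sample
+string and look at the first few token indices:
+"""
+
+print(vectorize_layer(tf.constant(["this movie was great"]))[0, :6])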
+
+"""
+## Two options to vectorize the data
+
+There are 2 ways we can use our text vectorization layer:
+
+**Option 1: Make it part of the model**, so as to obtain a model that processes raw
+ strings, like this:
+"""
+
+"""
+
+```python
+text_input = keras.Input(shape=(1,), dtype=tf.string, name='text')
+x = vectorize_layer(text_input)
+x = layers.Embedding(max_features + 1, embedding_dim)(x)
+...
+```
+
+**Option 2: Apply it to the text dataset** to obtain a dataset of word indices, then
+ feed it into a model that expects integer sequences as inputs.
+
+An important difference between the two is that option 2 enables you to do
+**asynchronous CPU processing and buffering** of your data when training on GPU.
+So if you're training the model on GPU, you probably want to go with this option to get
+ the best performance. This is what we will do below.
+
+If we were to export our model to production, we'd ship a model that accepts raw
+strings as input, like in the code snippet for option 1 above. This can be done after
+ training. We do this in the last section.
+
+
+"""
+
+
+def vectorize_text(text, label):
+ text = tf.expand_dims(text, -1)
+ return vectorize_layer(text), label
+
+
+# Vectorize the data.
+train_ds = raw_train_ds.map(vectorize_text)
+val_ds = raw_val_ds.map(vectorize_text)
+test_ds = raw_test_ds.map(vectorize_text)
+
+# Do async prefetching / buffering of the data for best performance on GPU.
+train_ds = train_ds.cache().prefetch(buffer_size=10)
+val_ds = val_ds.cache().prefetch(buffer_size=10)
+test_ds = test_ds.cache().prefetch(buffer_size=10)
+
+"""
+## Build a model
+
+We choose a simple 1D convnet starting with an `Embedding` layer.
+"""
+
+# An integer input for vocab indices.
+inputs = keras.Input(shape=(None,), dtype="int64")
+
+# Next, we add a layer to map those vocab indices into a space of dimensionality
+# 'embedding_dim'.
+x = layers.Embedding(max_features, embedding_dim)(inputs)
+x = layers.Dropout(0.5)(x)
+
+# Conv1D + global max pooling
+x = layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3)(x)
+x = layers.Conv1D(128, 7, padding="valid", activation="relu", strides=3)(x)
+x = layers.GlobalMaxPooling1D()(x)
+
+# We add a vanilla hidden layer:
+x = layers.Dense(128, activation="relu")(x)
+x = layers.Dropout(0.5)(x)
+
+# We project onto a single unit output layer, and squash it with a sigmoid:
+predictions = layers.Dense(1, activation="sigmoid", name="predictions")(x)
+
+model = keras.Model(inputs, predictions)
+
+# Compile the model with binary crossentropy loss and an adam optimizer.
+model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
+
+"""
+## Train the model
+"""
+
+epochs = 3
+
+# Fit the model using the train and validation datasets.
+model.fit(train_ds, validation_data=val_ds, epochs=epochs)
+
+"""
+## Evaluate the model on the test set
+"""
+
+model.evaluate(test_ds)
+
+"""
+## Make an end-to-end model
+
+If you want to obtain a model capable of processing raw strings, you can simply
+create a new model (using the weights we just trained):
+"""
+
+# A string input
+inputs = keras.Input(shape=(1,), dtype="string")
+# Turn strings into vocab indices
+indices = vectorize_layer(inputs)
+# Turn vocab indices into predictions
+outputs = model(indices)
+
+# Our end to end model
+end_to_end_model = keras.Model(inputs, outputs)
+end_to_end_model.compile(
+ loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]
+)
+
+# Test it with `raw_test_ds`, which yields raw strings
+end_to_end_model.evaluate(raw_test_ds)
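+
+# A quick usage sketch (with a made-up review): the end-to-end model consumes raw
+# strings directly and outputs the probability that the review is positive.
+print(end_to_end_model.predict(tf.constant([["This movie was fantastic!"]])))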
diff --git a/knowledge_base/nlp/text_classification_with_switch_transformer.py b/knowledge_base/nlp/text_classification_with_switch_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f19a2e187d7cdf489fb8f671eeb3359f37fec30
--- /dev/null
+++ b/knowledge_base/nlp/text_classification_with_switch_transformer.py
@@ -0,0 +1,328 @@
+"""
+Title: Text classification with Switch Transformer
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2020/05/10
+Last modified: 2021/02/15
+Description: Implement a Switch Transformer for text classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates the implementation of the
+[Switch Transformer](https://arxiv.org/abs/2101.03961) model for text
+classification.
+
+The Switch Transformer replaces the feedforward network (FFN) layer in the standard
+Transformer with a Mixture of Expert (MoE) routing layer, where each expert operates
+independently on the tokens in the sequence. This allows increasing the model size without
+increasing the computation needed to process each example.
+
+Note that, for training the Switch Transformer efficiently, data and model parallelism
+need to be applied, so that expert modules can run simultaneously, each on its own accelerator.
+While the implementation described in the paper uses the
+[TensorFlow Mesh](https://github.com/tensorflow/mesh) framework for distributed training,
+this example presents a simple, non-distributed implementation of the Switch Transformer
+model for demonstration purposes.
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import ops
+from keras import layers
+
+"""
+## Download and prepare dataset
+"""
+
+vocab_size = 20000 # Only consider the top 20k words
+num_tokens_per_example = 200 # Only consider the first 200 words of each movie review
+(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
+print(len(x_train), "Training sequences")
+print(len(x_val), "Validation sequences")
+x_train = keras.utils.pad_sequences(x_train, maxlen=num_tokens_per_example)
+x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
+
+"""
+## Define hyperparameters
+"""
+
+embed_dim = 32 # Embedding size for each token.
+num_heads = 2 # Number of attention heads
+ff_dim = 32 # Hidden layer size in feedforward network.
+num_experts = 10 # Number of experts used in the Switch Transformer.
+batch_size = 50 # Batch size.
+learning_rate = 0.001 # Learning rate.
+dropout_rate = 0.25 # Dropout rate.
+num_epochs = 3 # Number of epochs.
+num_tokens_per_batch = (
+ batch_size * num_tokens_per_example
+) # Total number of tokens per batch.
+print(f"Number of tokens per batch: {num_tokens_per_batch}")
+
+"""
+## Implement token & position embedding layer
+
+It consists of two separate embedding layers, one for tokens, one for token index (positions).
+"""
+
+
+class TokenAndPositionEmbedding(layers.Layer):
+ def __init__(self, maxlen, vocab_size, embed_dim):
+ super().__init__()
+ self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
+ self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
+
+ def call(self, x):
+ maxlen = ops.shape(x)[-1]
+ positions = ops.arange(start=0, stop=maxlen, step=1)
+ positions = self.pos_emb(positions)
+ x = self.token_emb(x)
+ return x + positions
+
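+"""
+A quick shape check (a minimal sketch with made-up token ids): the layer maps a batch of
+integer sequences of length `maxlen` to a batch of embedding sequences.
+"""
+
+demo_embedding_layer = TokenAndPositionEmbedding(maxlen=8, vocab_size=100, embed_dim=16)
+print(demo_embedding_layer(ops.zeros((2, 8), dtype="int32")).shape)  # (2, 8, 16)
+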
+
+"""
+## Implement the feedforward network
+
+This is used as the Mixture of Experts in the Switch Transformer.
+"""
+
+
+def create_feedforward_network(ff_dim, embed_dim, name=None):
+ return keras.Sequential(
+ [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim)], name=name
+ )
+
+
+"""
+## Implement the load-balanced loss
+
+This is an auxiliary loss to encourage a balanced load across experts.
+"""
+
+
+def load_balanced_loss(router_probs, expert_mask):
+ # router_probs [tokens_per_batch, num_experts] is the probability assigned for
+ # each expert per token. expert_mask [tokens_per_batch, num_experts] contains
+    # the expert with the highest router probability in one-hot format.
+
+ num_experts = ops.shape(expert_mask)[-1]
+ # Get the fraction of tokens routed to each expert.
+ # density is a vector of length num experts that sums to 1.
+ density = ops.mean(expert_mask, axis=0)
+ # Get fraction of probability mass assigned to each expert from the router
+ # across all tokens. density_proxy is a vector of length num experts that sums to 1.
+ density_proxy = ops.mean(router_probs, axis=0)
+ # Want both vectors to have uniform allocation (1/num experts) across all
+ # num_expert elements. The two vectors will be pushed towards uniform allocation
+ # when the dot product is minimized.
+ loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), "float32")
+ return loss
+
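+
+"""
+A quick numeric check (a minimal sketch with made-up routing tensors): with 8 tokens
+routed uniformly across 4 experts, both `density` and `density_proxy` equal 1/4, so the
+loss evaluates to `mean((1/4)^2) * 4^2 = 1.0`.
+"""
+
+import numpy as np  # only needed for this sanity check
+
+# 8 tokens, one-hot routed, 2 tokens per expert:
+demo_expert_mask = ops.convert_to_tensor(np.tile(np.eye(4, dtype="float32"), (2, 1)))
+demo_router_probs = ops.full((8, 4), 1 / 4)
+print(load_balanced_loss(demo_router_probs, demo_expert_mask))  # ~1.0
+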
+
+"""
+### Implement the router as a layer
+"""
+
+
+class Router(layers.Layer):
+ def __init__(self, num_experts, expert_capacity):
+ self.num_experts = num_experts
+ self.route = layers.Dense(units=num_experts)
+ self.expert_capacity = expert_capacity
+ super().__init__()
+
+ def call(self, inputs, training=False):
+ # inputs shape: [tokens_per_batch, embed_dim]
+ # router_logits shape: [tokens_per_batch, num_experts]
+ router_logits = self.route(inputs)
+
+ if training:
+ # Add noise for exploration across experts.
+ router_logits += keras.random.uniform(
+ shape=router_logits.shape, minval=0.9, maxval=1.1
+ )
+ # Probabilities for each token of what expert it should be sent to.
+ router_probs = keras.activations.softmax(router_logits, axis=-1)
+        # Get the top-1 expert for each token. expert_gate is the top-1 probability
+ # from the router for each token. expert_index is what expert each token
+ # is going to be routed to.
+ expert_gate, expert_index = ops.top_k(router_probs, k=1)
+ # expert_mask shape: [tokens_per_batch, num_experts]
+ expert_mask = ops.one_hot(expert_index, self.num_experts)
+ # Compute load balancing loss.
+ aux_loss = load_balanced_loss(router_probs, expert_mask)
+ self.add_loss(aux_loss)
+        # Experts have a fixed capacity; ensure we do not exceed it. Construct each
+        # token's position within its assigned expert, so that no more than
+        # `expert_capacity` tokens can be routed to each expert.
+ position_in_expert = ops.cast(
+ ops.cumsum(expert_mask, axis=0) * expert_mask, "int32"
+ )
+ # Keep only tokens that fit within expert capacity.
+ expert_mask *= ops.cast(
+ ops.less(ops.cast(position_in_expert, "int32"), self.expert_capacity),
+ "float32",
+ )
+ expert_mask_flat = ops.sum(expert_mask, axis=-1)
+ # Mask out the experts that have overflowed the expert capacity.
+ expert_gate *= expert_mask_flat
+        # Combine the expert outputs, scaled by the router probability.
+ # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
+ combined_tensor = ops.expand_dims(
+ expert_gate
+ * expert_mask_flat
+ * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
+ -1,
+ ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)
+ # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
+ # that is 1 if the token gets routed to the corresponding expert.
+ dispatch_tensor = ops.cast(combined_tensor, "float32")
+
+ return dispatch_tensor, combined_tensor
+
+
+"""
+### Implement a Switch layer
+"""
+
+
+class Switch(layers.Layer):
+ def __init__(
+ self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1
+ ):
+ self.num_experts = num_experts
+ self.embed_dim = embed_dim
+ self.experts = [
+ create_feedforward_network(ff_dim, embed_dim) for _ in range(num_experts)
+ ]
+
+ self.expert_capacity = num_tokens_per_batch // self.num_experts
+ self.router = Router(self.num_experts, self.expert_capacity)
+ super().__init__()
+
+ def call(self, inputs):
+ batch_size = ops.shape(inputs)[0]
+ num_tokens_per_example = ops.shape(inputs)[1]
+
+ # inputs shape: [num_tokens_per_batch, embed_dim]
+ inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
+        # dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
+ # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
+ dispatch_tensor, combine_tensor = self.router(inputs)
+ # expert_inputs shape: [num_experts, expert_capacity, embed_dim]
+ expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
+ expert_inputs = ops.reshape(
+ expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
+ )
+ # Dispatch to experts
+ expert_input_list = ops.unstack(expert_inputs, axis=0)
+ expert_output_list = [
+ self.experts[idx](expert_input)
+ for idx, expert_input in enumerate(expert_input_list)
+ ]
+ # expert_outputs shape: [expert_capacity, num_experts, embed_dim]
+ expert_outputs = ops.stack(expert_output_list, axis=1)
+ # expert_outputs_combined shape: [tokens_per_batch, embed_dim]
+ expert_outputs_combined = ops.einsum(
+ "abc,xba->xc", expert_outputs, combine_tensor
+ )
+ # output shape: [batch_size, num_tokens_per_example, embed_dim]
+ outputs = ops.reshape(
+ expert_outputs_combined,
+ [batch_size, num_tokens_per_example, self.embed_dim],
+ )
+ return outputs
+
+
+"""
+## Implement a Transformer block layer
+"""
+
+
+class TransformerBlock(layers.Layer):
+ def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):
+ super().__init__()
+ self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
+ # The ffn can be either a standard feedforward network or a switch
+ # layer with a Mixture of Experts.
+ self.ffn = ffn
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.dropout1 = layers.Dropout(dropout_rate)
+ self.dropout2 = layers.Dropout(dropout_rate)
+
+ def call(self, inputs, training=False):
+ attn_output = self.att(inputs, inputs)
+ attn_output = self.dropout1(attn_output, training=training)
+ out1 = self.layernorm1(inputs + attn_output)
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output, training=training)
+ return self.layernorm2(out1 + ffn_output)
+
+
+"""
+## Implement the classifier
+
+The `TransformerBlock` layer outputs one vector for each time step of our input sequence.
+Here, we take the mean across all time steps and use a feedforward network on top
+of it to classify text.
+"""
+
+
+def create_classifier():
+ switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)
+ transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)
+
+ inputs = layers.Input(shape=(num_tokens_per_example,))
+ embedding_layer = TokenAndPositionEmbedding(
+ num_tokens_per_example, vocab_size, embed_dim
+ )
+ x = embedding_layer(inputs)
+ x = transformer_block(x)
+ x = layers.GlobalAveragePooling1D()(x)
+ x = layers.Dropout(dropout_rate)(x)
+ x = layers.Dense(ff_dim, activation="relu")(x)
+ x = layers.Dropout(dropout_rate)(x)
+ outputs = layers.Dense(2, activation="softmax")(x)
+
+ classifier = keras.Model(inputs=inputs, outputs=outputs)
+ return classifier
+
+
+"""
+## Train and evaluate the model
+"""
+
+
+def run_experiment(classifier):
+ classifier.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ )
+ history = classifier.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_data=(x_val, y_val),
+ )
+ return history
+
+
+classifier = create_classifier()
+run_experiment(classifier)
+
+
+"""
+## Conclusion
+
+Compared to the standard Transformer architecture, the Switch Transformer can have a much
+larger number of parameters, leading to increased model
+capacity, while maintaining a reasonable computational cost.
+"""
diff --git a/knowledge_base/nlp/text_classification_with_transformer.py b/knowledge_base/nlp/text_classification_with_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a049bb9f8f6a72b1f9bba102fadcc6f0a700fd6
--- /dev/null
+++ b/knowledge_base/nlp/text_classification_with_transformer.py
@@ -0,0 +1,118 @@
+"""
+Title: Text classification with Transformer
+Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
+Date created: 2020/05/10
+Last modified: 2024/01/18
+Description: Implement a Transformer block as a Keras layer and use it for text classification.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import ops
+from keras import layers
+
+
+"""
+## Implement a Transformer block as a layer
+"""
+
+
+class TransformerBlock(layers.Layer):
+ def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
+ super().__init__()
+ self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(ff_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.dropout1 = layers.Dropout(rate)
+ self.dropout2 = layers.Dropout(rate)
+
+ def call(self, inputs):
+ attn_output = self.att(inputs, inputs)
+ attn_output = self.dropout1(attn_output)
+ out1 = self.layernorm1(inputs + attn_output)
+ ffn_output = self.ffn(out1)
+ ffn_output = self.dropout2(ffn_output)
+ return self.layernorm2(out1 + ffn_output)
+
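+"""
+A quick shape check (a minimal sketch with a made-up random input): the block preserves
+the `(batch, sequence, embed_dim)` shape of its input.
+"""
+
+demo_block = TransformerBlock(embed_dim=32, num_heads=2, ff_dim=32)
+print(demo_block(keras.random.normal((2, 10, 32))).shape)  # (2, 10, 32)
+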
+
+"""
+## Implement embedding layer
+
+Two separate embedding layers, one for tokens, one for token index (positions).
+"""
+
+
+class TokenAndPositionEmbedding(layers.Layer):
+ def __init__(self, maxlen, vocab_size, embed_dim):
+ super().__init__()
+ self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
+ self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
+
+ def call(self, x):
+ maxlen = ops.shape(x)[-1]
+ positions = ops.arange(start=0, stop=maxlen, step=1)
+ positions = self.pos_emb(positions)
+ x = self.token_emb(x)
+ return x + positions
+
+
+"""
+## Download and prepare dataset
+"""
+
+vocab_size = 20000 # Only consider the top 20k words
+maxlen = 200 # Only consider the first 200 words of each movie review
+(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)
+print(len(x_train), "Training sequences")
+print(len(x_val), "Validation sequences")
+x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
+x_val = keras.utils.pad_sequences(x_val, maxlen=maxlen)
+
+"""
+## Create classifier model using transformer layer
+
+Transformer layer outputs one vector for each time step of our input sequence.
+Here, we take the mean across all time steps and
+use a feed forward network on top of it to classify text.
+"""
+
+
+embed_dim = 32 # Embedding size for each token
+num_heads = 2 # Number of attention heads
+ff_dim = 32 # Hidden layer size in feed forward network inside transformer
+
+inputs = layers.Input(shape=(maxlen,))
+embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
+x = embedding_layer(inputs)
+transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
+x = transformer_block(x)
+x = layers.GlobalAveragePooling1D()(x)
+x = layers.Dropout(0.1)(x)
+x = layers.Dense(20, activation="relu")(x)
+x = layers.Dropout(0.1)(x)
+outputs = layers.Dense(2, activation="softmax")(x)
+
+model = keras.Model(inputs=inputs, outputs=outputs)
+
+
+"""
+## Train and Evaluate
+"""
+
+model.compile(
+ optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+)
+history = model.fit(
+ x_train, y_train, batch_size=32, epochs=2, validation_data=(x_val, y_val)
+)
diff --git a/knowledge_base/nlp/text_extraction_with_bert.py b/knowledge_base/nlp/text_extraction_with_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..7964c776bb8495609cb8d5cfaef359bf851bbacc
--- /dev/null
+++ b/knowledge_base/nlp/text_extraction_with_bert.py
@@ -0,0 +1,347 @@
+"""
+Title: Text Extraction with BERT
+Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
+Date created: 2020/05/23
+Last modified: 2020/05/23
+Description: Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.
+Accelerator: TPU
+"""
+
+"""
+## Introduction
+
+This demonstration uses SQuAD (Stanford Question-Answering Dataset).
+In SQuAD, an input consists of a question, and a paragraph for context.
+The goal is to find the span of text in the paragraph that answers the question.
+We evaluate our performance on this data with the "Exact Match" metric,
+which measures the percentage of predictions that exactly match any one of the
+ground-truth answers.
+
+We fine-tune a BERT model to perform this task as follows:
+
+1. Feed the context and the question as inputs to BERT.
+2. Take two vectors S and T with dimensions equal to that of
+ hidden states in BERT.
+3. Compute the probability of each token being the start and end of
+ the answer span. The probability of a token being the start of
+ the answer is given by a dot product between S and the representation
+ of the token in the last layer of BERT, followed by a softmax over all tokens.
+ The probability of a token being the end of the answer is computed
+ similarly with the vector T.
+4. Fine-tune BERT and learn S and T along the way.
+
+**References:**
+
+- [BERT](https://arxiv.org/abs/1810.04805)
+- [SQuAD](https://arxiv.org/abs/1606.05250)
+"""
+"""
+## Setup
+"""
+import os
+import re
+import json
+import string
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tokenizers import BertWordPieceTokenizer
+from transformers import BertTokenizer, TFBertModel, BertConfig
+
+max_len = 384
+configuration = BertConfig() # default parameters and configuration for BERT
+
+"""
+## Set-up BERT tokenizer
+"""
+# Save the slow pretrained tokenizer
+slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+save_path = "bert_base_uncased/"
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+slow_tokenizer.save_pretrained(save_path)
+
+# Load the fast tokenizer from saved file
+tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)
+
+"""
+## Load the data
+"""
+train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
+train_path = keras.utils.get_file("train.json", train_data_url)
+eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
+eval_path = keras.utils.get_file("eval.json", eval_data_url)
+
+"""
+## Preprocess the data
+
+1. Go through the JSON file and store every record as a `SquadExample` object.
+2. Go through each `SquadExample` and create `x_train, y_train, x_eval, y_eval`.
+"""
+
+
+class SquadExample:
+ def __init__(self, question, context, start_char_idx, answer_text, all_answers):
+ self.question = question
+ self.context = context
+ self.start_char_idx = start_char_idx
+ self.answer_text = answer_text
+ self.all_answers = all_answers
+ self.skip = False
+
+ def preprocess(self):
+ context = self.context
+ question = self.question
+ answer_text = self.answer_text
+ start_char_idx = self.start_char_idx
+
+ # Clean context, answer and question
+ context = " ".join(str(context).split())
+ question = " ".join(str(question).split())
+ answer = " ".join(str(answer_text).split())
+
+ # Find end character index of answer in context
+ end_char_idx = start_char_idx + len(answer)
+ if end_char_idx >= len(context):
+ self.skip = True
+ return
+
+ # Mark the character indexes in context that are in answer
+ is_char_in_ans = [0] * len(context)
+ for idx in range(start_char_idx, end_char_idx):
+ is_char_in_ans[idx] = 1
+
+ # Tokenize context
+ tokenized_context = tokenizer.encode(context)
+
+ # Find tokens that were created from answer characters
+ ans_token_idx = []
+ for idx, (start, end) in enumerate(tokenized_context.offsets):
+ if sum(is_char_in_ans[start:end]) > 0:
+ ans_token_idx.append(idx)
+
+ if len(ans_token_idx) == 0:
+ self.skip = True
+ return
+
+ # Find start and end token index for tokens from answer
+ start_token_idx = ans_token_idx[0]
+ end_token_idx = ans_token_idx[-1]
+
+ # Tokenize question
+ tokenized_question = tokenizer.encode(question)
+
+ # Create inputs
+ input_ids = tokenized_context.ids + tokenized_question.ids[1:]
+ token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(
+ tokenized_question.ids[1:]
+ )
+ attention_mask = [1] * len(input_ids)
+
+ # Pad and create attention masks.
+ # Skip if truncation is needed
+ padding_length = max_len - len(input_ids)
+ if padding_length > 0: # pad
+ input_ids = input_ids + ([0] * padding_length)
+ attention_mask = attention_mask + ([0] * padding_length)
+ token_type_ids = token_type_ids + ([0] * padding_length)
+ elif padding_length < 0: # skip
+ self.skip = True
+ return
+
+ self.input_ids = input_ids
+ self.token_type_ids = token_type_ids
+ self.attention_mask = attention_mask
+ self.start_token_idx = start_token_idx
+ self.end_token_idx = end_token_idx
+ self.context_token_to_char = tokenized_context.offsets
+
+
+with open(train_path) as f:
+ raw_train_data = json.load(f)
+
+with open(eval_path) as f:
+ raw_eval_data = json.load(f)
+
+
+def create_squad_examples(raw_data):
+ squad_examples = []
+ for item in raw_data["data"]:
+ for para in item["paragraphs"]:
+ context = para["context"]
+ for qa in para["qas"]:
+ question = qa["question"]
+ answer_text = qa["answers"][0]["text"]
+ all_answers = [_["text"] for _ in qa["answers"]]
+ start_char_idx = qa["answers"][0]["answer_start"]
+ squad_eg = SquadExample(
+ question, context, start_char_idx, answer_text, all_answers
+ )
+ squad_eg.preprocess()
+ squad_examples.append(squad_eg)
+ return squad_examples
+
+
+def create_inputs_targets(squad_examples):
+ dataset_dict = {
+ "input_ids": [],
+ "token_type_ids": [],
+ "attention_mask": [],
+ "start_token_idx": [],
+ "end_token_idx": [],
+ }
+ for item in squad_examples:
+        if not item.skip:
+ for key in dataset_dict:
+ dataset_dict[key].append(getattr(item, key))
+ for key in dataset_dict:
+ dataset_dict[key] = np.array(dataset_dict[key])
+
+ x = [
+ dataset_dict["input_ids"],
+ dataset_dict["token_type_ids"],
+ dataset_dict["attention_mask"],
+ ]
+ y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
+ return x, y
+
+
+train_squad_examples = create_squad_examples(raw_train_data)
+x_train, y_train = create_inputs_targets(train_squad_examples)
+print(f"{len(train_squad_examples)} training points created.")
+
+eval_squad_examples = create_squad_examples(raw_eval_data)
+x_eval, y_eval = create_inputs_targets(eval_squad_examples)
+print(f"{len(eval_squad_examples)} evaluation points created.")
+
+"""
+Create the Question-Answering Model using BERT and Functional API
+"""
+
+
+def create_model():
+ ## BERT encoder
+ encoder = TFBertModel.from_pretrained("bert-base-uncased")
+
+ ## QA Model
+ input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
+ token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
+ attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
+ embedding = encoder(
+ input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
+ )[0]
+
+ start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
+ start_logits = layers.Flatten()(start_logits)
+
+ end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
+ end_logits = layers.Flatten()(end_logits)
+
+ start_probs = layers.Activation(keras.activations.softmax)(start_logits)
+ end_probs = layers.Activation(keras.activations.softmax)(end_logits)
+
+ model = keras.Model(
+ inputs=[input_ids, token_type_ids, attention_mask],
+ outputs=[start_probs, end_probs],
+ )
+ loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
+    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
+ model.compile(optimizer=optimizer, loss=[loss, loss])
+ return model
+
+
+"""
+This code should preferably be run on a Google Colab TPU runtime.
+With Colab TPUs, each epoch will take 5-6 minutes.
+"""
+use_tpu = True
+if use_tpu:
+ # Create distribution strategy
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
+ strategy = tf.distribute.TPUStrategy(tpu)
+
+ # Create model
+ with strategy.scope():
+ model = create_model()
+else:
+ model = create_model()
+
+model.summary()
+
+"""
+## Create evaluation Callback
+
+This callback will compute the exact match score using the validation data
+after every epoch.
+"""
+
+
+def normalize_text(text):
+ text = text.lower()
+
+ # Remove punctuations
+ exclude = set(string.punctuation)
+ text = "".join(ch for ch in text if ch not in exclude)
+
+ # Remove articles
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+ text = re.sub(regex, " ", text)
+
+ # Remove extra white space
+ text = " ".join(text.split())
+ return text
+
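+"""
+A quick check of the normalizer (a minimal sketch with a made-up answer string):
+"""
+print(normalize_text("The  Norman   Conquest!"))  # -> "norman conquest"
+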
+
+class ExactMatch(keras.callbacks.Callback):
+ """
+ Each `SquadExample` object contains the character level offsets for each token
+ in its input paragraph. We use them to get back the span of text corresponding
+ to the tokens between our predicted start and end tokens.
+ All the ground-truth answers are also present in each `SquadExample` object.
+ We calculate the percentage of data points where the span of text obtained
+ from model predictions matches one of the ground-truth answers.
+ """
+
+ def __init__(self, x_eval, y_eval):
+ self.x_eval = x_eval
+ self.y_eval = y_eval
+
+ def on_epoch_end(self, epoch, logs=None):
+ pred_start, pred_end = self.model.predict(self.x_eval)
+ count = 0
+        eval_examples_no_skip = [_ for _ in eval_squad_examples if not _.skip]
+ for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
+ squad_eg = eval_examples_no_skip[idx]
+ offsets = squad_eg.context_token_to_char
+ start = np.argmax(start)
+ end = np.argmax(end)
+ if start >= len(offsets):
+ continue
+ pred_char_start = offsets[start][0]
+ if end < len(offsets):
+ pred_char_end = offsets[end][1]
+ pred_ans = squad_eg.context[pred_char_start:pred_char_end]
+ else:
+ pred_ans = squad_eg.context[pred_char_start:]
+
+ normalized_pred_ans = normalize_text(pred_ans)
+ normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
+ if normalized_pred_ans in normalized_true_ans:
+ count += 1
+ acc = count / len(self.y_eval[0])
+ print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")
+
+
+"""
+## Train and Evaluate
+"""
+exact_match_callback = ExactMatch(x_eval, y_eval)
+model.fit(
+ x_train,
+ y_train,
+ epochs=1, # For demonstration, 3 epochs are recommended
+ verbose=2,
+ batch_size=64,
+ callbacks=[exact_match_callback],
+)
diff --git a/knowledge_base/nlp/tweet-classification-using-tfdf.py b/knowledge_base/nlp/tweet-classification-using-tfdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8320a0bc6eaf1a5cb72676085e58ef90242fbcf6
--- /dev/null
+++ b/knowledge_base/nlp/tweet-classification-using-tfdf.py
@@ -0,0 +1,297 @@
+"""
+Title: Text classification using Decision Forests and pretrained embeddings
+Author: Gitesh Chawda
+Date created: 09/05/2022
+Last modified: 09/05/2022
+Description: Using Tensorflow Decision Forests for text classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+[TensorFlow Decision Forests](https://www.tensorflow.org/decision_forests) (TF-DF)
+is a collection of state-of-the-art algorithms for Decision Forest models that are
+compatible with Keras APIs. The module includes Random Forests, Gradient Boosted Trees,
+and CART, and can be used for regression, classification, and ranking tasks.
+
+In this example we will use Gradient Boosted Trees with pretrained embeddings to
+classify disaster-related tweets.
+
+### See also:
+
+- [TF-DF beginner tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab)
+- [TF-DF intermediate tutorial](https://www.tensorflow.org/decision_forests/tutorials/intermediate_colab).
+"""
+
+"""
+Install TensorFlow Decision Forests using the following command:
+`pip install tensorflow_decision_forests`
+"""
+
+
+"""
+## Imports
+"""
+
+import pandas as pd
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+import tensorflow_hub as hub
+from tensorflow.keras import layers
+import tensorflow_decision_forests as tfdf
+import matplotlib.pyplot as plt
+
+"""
+## Get the data
+
+The Dataset is available on [Kaggle](https://www.kaggle.com/c/nlp-getting-started)
+
+Dataset description:
+
+**Files:**
+
+- train.csv: the training set
+
+**Columns:**
+
+- id: a unique identifier for each tweet
+- text: the text of the tweet
+- location: the location the tweet was sent from (may be blank)
+- keyword: a particular keyword from the tweet (may be blank)
+- target: in train.csv only, this denotes whether a tweet is about a real disaster (1) or not (0)
+"""
+
+# Turn .csv files into pandas DataFrame's
+df = pd.read_csv(
+ "https://raw.githubusercontent.com/IMvision12/Tweets-Classification-NLP/main/train.csv"
+)
+print(df.head())
+
+"""
+The dataset includes 7613 samples with 5 columns:
+"""
+
+print(f"Training dataset shape: {df.shape}")
+
+"""
+Shuffling and dropping unnecessary columns:
+"""
+
+df_shuffled = df.sample(frac=1, random_state=42)
+# Dropping the id, keyword and location columns, as these columns consist mostly of
+# NaN values; we will only use the text and target columns.
+df_shuffled.drop(["id", "keyword", "location"], axis=1, inplace=True)
+df_shuffled.reset_index(inplace=True, drop=True)
+print(df_shuffled.head())
+
+"""
+Printing information about the shuffled dataframe:
+"""
+
+print(df_shuffled.info())
+
+"""
+Total number of "disaster" and "non-disaster" tweets:
+"""
+
+print(
+ "Total Number of disaster and non-disaster tweets: "
+ f"{df_shuffled.target.value_counts()}"
+)
+
+"""
+Let's preview a few samples:
+"""
+
+for index, example in df_shuffled[:5].iterrows():
+ print(f"Example #{index}")
+ print(f"\tTarget : {example['target']}")
+ print(f"\tText : {example['text']}")
+
+"""
+Splitting dataset into training and test sets:
+"""
+test_df = df_shuffled.sample(frac=0.1, random_state=42)
+train_df = df_shuffled.drop(test_df.index)
+print(f"Using {len(train_df)} samples for training and {len(test_df)} for validation")
+
+"""
+Total number of "disaster" and "non-disaster" tweets in the training data:
+"""
+print(train_df["target"].value_counts())
+
+"""
+Total number of "disaster" and "non-disaster" tweets in the test data:
+"""
+
+print(test_df["target"].value_counts())
+
+"""
+## Convert data to a `tf.data.Dataset`
+"""
+
+
+def create_dataset(dataframe):
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (dataframe["text"].to_numpy(), dataframe["target"].to_numpy())
+ )
+ dataset = dataset.batch(100)
+ dataset = dataset.prefetch(tf.data.AUTOTUNE)
+ return dataset
+
+
+train_ds = create_dataset(train_df)
+test_ds = create_dataset(test_df)
+
+"""
+## Downloading pretrained embeddings
+
+The Universal Sentence Encoder embeddings encode text into high-dimensional vectors that can be
+used for text classification, semantic similarity, clustering and other natural language
+tasks. They're trained on a variety of data sources and a variety of tasks. Their input is
+variable-length English text and their output is a 512-dimensional vector.
+
+To learn more about these pretrained embeddings, see
+[Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder/4).
+
+"""
+
+sentence_encoder_layer = hub.KerasLayer(
+ "https://tfhub.dev/google/universal-sentence-encoder/4"
+)
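+
+"""
+A quick shape check (a minimal sketch with a made-up tweet): the encoder maps a batch of
+raw strings to 512-dimensional vectors.
+"""
+
+demo_embeddings = sentence_encoder_layer(tf.constant(["There is a fire in the hills"]))
+print(demo_embeddings.shape)  # Expected: (1, 512)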
+
+"""
+## Creating our models
+
+We create two models. In the first model (model_1), raw text will first be encoded via
+pretrained embeddings and then passed to a Gradient Boosted Trees model for
+classification. In the second model (model_2), raw text will be directly passed to
+the Gradient Boosted Trees model.
+"""
+
+"""
+Building model_1
+"""
+
+inputs = layers.Input(shape=(), dtype=tf.string)
+outputs = sentence_encoder_layer(inputs)
+preprocessor = keras.Model(inputs=inputs, outputs=outputs)
+model_1 = tfdf.keras.GradientBoostedTreesModel(preprocessing=preprocessor)
+
+"""
+Building model_2
+"""
+
+model_2 = tfdf.keras.GradientBoostedTreesModel()
+
+"""
+## Train the models
+
+We compile our model by passing the metrics `Accuracy`, `Recall`, `Precision` and
+`AUC`. When it comes to the loss, TF-DF automatically detects the best loss for the task
+(classification or regression). It is printed in the model summary.
+
+Also, because they're batch-training models rather than mini-batch gradient descent models,
+TF-DF models do not need a validation dataset to monitor overfitting, or to stop
+training early. Some algorithms do not use a validation dataset (e.g. Random Forest)
+while some others do (e.g. Gradient Boosted Trees). If a validation dataset is
+needed, it will be extracted automatically from the training dataset.
+"""
+
+# Compiling model_1
+model_1.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])
+# We do not specify epochs here, as TF-DF trains on exactly one pass over the dataset
+model_1.fit(train_ds)
+
+# Compiling model_2
+model_2.compile(metrics=["Accuracy", "Recall", "Precision", "AUC"])
+# We do not specify epochs here, as TF-DF trains on exactly one pass over the dataset
+model_2.fit(train_ds)
+
+"""
+Prints training logs of model_1
+"""
+
+logs_1 = model_1.make_inspector().training_logs()
+print(logs_1)
+
+"""
+Prints training logs of model_2
+"""
+
+logs_2 = model_2.make_inspector().training_logs()
+print(logs_2)
+
+"""
+The model.summary() method prints a variety of information about your decision tree model, including model type, task, input features, and feature importance.
+"""
+
+print("model_1 summary: ")
+print(model_1.summary())
+print()
+print("model_2 summary: ")
+print(model_2.summary())
+
+"""
+## Plotting training metrics
+"""
+
+
+def plot_curve(logs):
+ plt.figure(figsize=(12, 4))
+
+ plt.subplot(1, 2, 1)
+ plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])
+ plt.xlabel("Number of trees")
+ plt.ylabel("Accuracy")
+
+ plt.subplot(1, 2, 2)
+ plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])
+ plt.xlabel("Number of trees")
+ plt.ylabel("Loss")
+
+ plt.show()
+
+
+plot_curve(logs_1)
+plot_curve(logs_2)
+
+"""
+## Evaluating on test data
+"""
+
+results = model_1.evaluate(test_ds, return_dict=True, verbose=0)
+print("model_1 Evaluation: \n")
+for name, value in results.items():
+ print(f"{name}: {value:.4f}")
+
+results = model_2.evaluate(test_ds, return_dict=True, verbose=0)
+print("model_2 Evaluation: \n")
+for name, value in results.items():
+ print(f"{name}: {value:.4f}")
+
+"""
+## Predicting on test data
+"""
+
+test_df.reset_index(inplace=True, drop=True)
+for index, row in test_df.iterrows():
+ text = tf.expand_dims(row["text"], axis=0)
+ preds = model_1.predict_step(text)
+ preds = tf.squeeze(tf.round(preds))
+ print(f"Text: {row['text']}")
+ print(f"Prediction: {int(preds)}")
+ print(f"Ground Truth : {row['target']}")
+ if index == 10:
+ break
+
+"""
+## Concluding remarks
+
+The TensorFlow Decision Forests package provides powerful models
+that work especially well with structured data. In our experiments,
+the Gradient Boosted Tree model with pretrained embeddings achieved 81.6%
+test accuracy while the plain Gradient Boosted Tree model had 54.4% accuracy.
+"""
diff --git a/knowledge_base/rl/actor_critic_cartpole.py b/knowledge_base/rl/actor_critic_cartpole.py
new file mode 100644
index 0000000000000000000000000000000000000000..c44552446c7d810465296fd886c53761886b9f21
--- /dev/null
+++ b/knowledge_base/rl/actor_critic_cartpole.py
@@ -0,0 +1,189 @@
+"""
+Title: Actor Critic Method
+Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
+Date created: 2020/05/13
+Last modified: 2024/02/22
+Description: Implement Actor Critic Method in CartPole environment.
+Accelerator: NONE
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+This script shows an implementation of the Actor Critic method on the CartPole-v0 environment.
+
+### Actor Critic Method
+
+As an agent takes actions and moves through an environment, it learns to map
+the observed state of the environment to two possible outputs:
+
+1. Recommended action: A probability value for each action in the action space.
+ The part of the agent responsible for this output is called the **actor**.
+2. Estimated rewards in the future: Sum of all rewards it expects to receive in the
+ future. The part of the agent responsible for this output is the **critic**.
+
+The actor and critic learn to perform their tasks such that the recommended actions
+from the actor maximize the rewards.
+
+### CartPole-V0
+
+A pole is attached to a cart placed on a frictionless track. The agent has to apply
+force to move the cart. It is rewarded for every time step the pole
+remains upright. The agent, therefore, must learn to keep the pole from falling over.
+
+### References
+
+- [Environment documentation](https://gymnasium.farama.org/environments/classic_control/cart_pole/)
+- [CartPole paper](http://www.derongliu.org/adp/adp-cdrom/Barto1983.pdf)
+- [Actor Critic Method](https://hal.inria.fr/hal-00840470/document)
+"""
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import gym
+import numpy as np
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
+
+# Configuration parameters for the whole setup
+seed = 42
+gamma = 0.99 # Discount factor for past rewards
+max_steps_per_episode = 10000
+# Adding `render_mode='human'` will show the attempts of the agent
+env = gym.make("CartPole-v0") # Create the environment
+env.reset(seed=seed)
+eps = np.finfo(np.float32).eps.item() # Smallest number such that 1.0 + eps != 1.0
+
+"""
+## Implement Actor Critic network
+
+This network learns two functions:
+
+1. Actor: This takes as input the state of our environment and returns a
+probability value for each action in its action space.
+2. Critic: This takes as input the state of our environment and returns
+an estimate of total rewards in the future.
+
+In our implementation, they share the initial layer.
+"""
+
+num_inputs = 4
+num_actions = 2
+num_hidden = 128
+
+inputs = layers.Input(shape=(num_inputs,))
+common = layers.Dense(num_hidden, activation="relu")(inputs)
+action = layers.Dense(num_actions, activation="softmax")(common)
+critic = layers.Dense(1)(common)
+
+model = keras.Model(inputs=inputs, outputs=[action, critic])
+
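+"""
+A quick shape check (a minimal sketch with a made-up state): the model returns action
+probabilities over the two actions and a single critic value estimate.
+"""
+
+demo_action_probs, demo_critic_value = model(ops.zeros((1, num_inputs)))
+print(demo_action_probs.shape, demo_critic_value.shape)  # (1, 2) (1, 1)
+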
+"""
+## Train
+"""
+
+optimizer = keras.optimizers.Adam(learning_rate=0.01)
+huber_loss = keras.losses.Huber()
+action_probs_history = []
+critic_value_history = []
+rewards_history = []
+running_reward = 0
+episode_count = 0
+
+while True: # Run until solved
+ state = env.reset()[0]
+ episode_reward = 0
+ with tf.GradientTape() as tape:
+ for timestep in range(1, max_steps_per_episode):
+
+ state = ops.convert_to_tensor(state)
+ state = ops.expand_dims(state, 0)
+
+ # Predict action probabilities and estimated future rewards
+ # from environment state
+ action_probs, critic_value = model(state)
+ critic_value_history.append(critic_value[0, 0])
+
+ # Sample action from action probability distribution
+ action = np.random.choice(num_actions, p=np.squeeze(action_probs))
+ action_probs_history.append(ops.log(action_probs[0, action]))
+
+ # Apply the sampled action in our environment
+ state, reward, done, *_ = env.step(action)
+ rewards_history.append(reward)
+ episode_reward += reward
+
+ if done:
+ break
+
+ # Update running reward to check condition for solving
+ running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
+
+ # Calculate expected value from rewards
+ # - At each timestep what was the total reward received after that timestep
+ # - Rewards in the past are discounted by multiplying them with gamma
+ # - These are the labels for our critic
+ returns = []
+ discounted_sum = 0
+ for r in rewards_history[::-1]:
+ discounted_sum = r + gamma * discounted_sum
+ returns.insert(0, discounted_sum)
+
+ # Normalize
+ returns = np.array(returns)
+ returns = (returns - np.mean(returns)) / (np.std(returns) + eps)
+ returns = returns.tolist()
+
+ # Calculating loss values to update our network
+ history = zip(action_probs_history, critic_value_history, returns)
+ actor_losses = []
+ critic_losses = []
+ for log_prob, value, ret in history:
+ # At this point in history, the critic estimated that we would get a
+ # total reward = `value` in the future. We took an action with log probability
+ # of `log_prob` and ended up receiving a total reward = `ret`.
+ # The actor must be updated so that it predicts an action that leads to
+ # high rewards (compared to critic's estimate) with high probability.
+ diff = ret - value
+ actor_losses.append(-log_prob * diff) # actor loss
+
+ # The critic must be updated so that it predicts a better estimate of
+ # the future rewards.
+ critic_losses.append(
+ huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
+ )
+
+ # Backpropagation
+ loss_value = sum(actor_losses) + sum(critic_losses)
+ grads = tape.gradient(loss_value, model.trainable_variables)
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
+
+ # Clear the loss and reward history
+ action_probs_history.clear()
+ critic_value_history.clear()
+ rewards_history.clear()
+
+ # Log details
+ episode_count += 1
+ if episode_count % 10 == 0:
+ template = "running reward: {:.2f} at episode {}"
+ print(template.format(running_reward, episode_count))
+
+ if running_reward > 195: # Condition to consider the task solved
+ print("Solved at episode {}!".format(episode_count))
+ break
+"""
+## Visualizations
+In early stages of training:
+
+
+In later stages of training:
+
+"""
diff --git a/knowledge_base/rl/ddpg_pendulum.py b/knowledge_base/rl/ddpg_pendulum.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d5b192fea12e321136a80bfc4a576549f4461f5
--- /dev/null
+++ b/knowledge_base/rl/ddpg_pendulum.py
@@ -0,0 +1,433 @@
+"""
+Title: Deep Deterministic Policy Gradient (DDPG)
+Author: [amifunny](https://github.com/amifunny)
+Date created: 2020/06/04
+Last modified: 2024/03/23
+Description: Implementing DDPG algorithm on the Inverted Pendulum Problem.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+**Deep Deterministic Policy Gradient (DDPG)** is a model-free off-policy algorithm for
+learning continuous actions.
+
+It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network).
+It uses Experience Replay and slow-learning target networks from DQN, and it is based on
+DPG, which can operate over continuous action spaces.
+
+This tutorial closely follows this paper:
+[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
+
+## Problem
+
+We are trying to solve the classic **Inverted Pendulum** control problem.
+In this setting, we can take only two actions: swing left or swing right.
+
+What makes this problem challenging for Q-Learning algorithms is that the actions
+are **continuous** instead of being **discrete**. That is, instead of using two
+discrete actions like `-1` or `+1`, we have to select from infinite actions
+ranging from `-2` to `+2`.
+
+## Quick theory
+
+Just like the Actor-Critic method, we have two networks:
+
+1. Actor - It proposes an action given a state.
+2. Critic - It predicts if the action is good (positive value) or bad (negative value)
+given a state and an action.
+
+DDPG uses two more techniques not present in the original DQN:
+
+**First, it uses two Target networks.**
+
+**Why?** Because it adds stability to training. In short, we are learning from estimated
+targets, and since the Target networks are updated slowly, our estimated targets remain
+stable.
+
+Conceptually, this is like saying, "I have an idea of how to play this well,
+I'm going to try it out for a bit until I find something better",
+as opposed to saying "I'm going to re-learn how to play this entire game after every
+move".
+See this [StackOverflow answer](https://stackoverflow.com/a/54238556/13475679).
+
+**Second, it uses Experience Replay.**
+
+We store a list of tuples `(state, action, reward, next_state)`, and instead of
+learning only from recent experience, we learn by sampling from all of the experience
+accumulated so far.
+
+Now, let's see how it is implemented.
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import tensorflow as tf
+import gymnasium as gym
+import numpy as np
+import matplotlib.pyplot as plt
+
+"""
+We use [Gymnasium](https://gymnasium.farama.org/) to create the environment.
+We will use the `upper_bound` parameter to scale our actions later.
+"""
+
+# Specify the `render_mode` parameter to show the attempts of the agent in a pop up window.
+env = gym.make("Pendulum-v1") # , render_mode="human")
+
+num_states = env.observation_space.shape[0]
+print("Size of State Space -> {}".format(num_states))
+num_actions = env.action_space.shape[0]
+print("Size of Action Space -> {}".format(num_actions))
+
+upper_bound = env.action_space.high[0]
+lower_bound = env.action_space.low[0]
+
+print("Max Value of Action -> {}".format(upper_bound))
+print("Min Value of Action -> {}".format(lower_bound))
+
+"""
+To implement better exploration by the Actor network, we use noisy perturbations,
+specifically
+an **Ornstein-Uhlenbeck process** for generating noise, as described in the paper.
+It samples noise from a correlated normal distribution.
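+
+Concretely, the discretized update implemented in `OUActionNoise.__call__` below is
+`x_{t+1} = x_t + theta * (mean - x_t) * dt + std_dev * sqrt(dt) * N(0, 1)`,
+so each sample is a mean-reverting step away from the previous one plus Gaussian noise.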
+"""
+
+
+class OUActionNoise:
+ def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):
+ self.theta = theta
+ self.mean = mean
+ self.std_dev = std_deviation
+ self.dt = dt
+ self.x_initial = x_initial
+ self.reset()
+
+ def __call__(self):
+ # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process
+ x = (
+ self.x_prev
+ + self.theta * (self.mean - self.x_prev) * self.dt
+ + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)
+ )
+ # Store x into x_prev
+ # Makes next noise dependent on current one
+ self.x_prev = x
+ return x
+
+ def reset(self):
+ if self.x_initial is not None:
+ self.x_prev = self.x_initial
+ else:
+ self.x_prev = np.zeros_like(self.mean)
+
+
+"""
+The `Buffer` class implements Experience Replay.
+
+**Critic loss** - Mean Squared Error of `y - Q(s, a)`
+where `y` is the expected return as seen by the Target network,
+and `Q(s, a)` is the action value predicted by the Critic network. `y` is a moving target
+that the critic model tries to achieve; we make this target
+stable by updating the Target model slowly.
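+Concretely, `y = r + gamma * Q_target(s', mu_target(s'))`, where `mu_target` is the
+target Actor and `Q_target` the target Critic, exactly as computed in `Buffer.update()` below.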
+
+**Actor loss** - This is computed using the mean of the value given by the Critic network
+for the actions taken by the Actor network. We seek to maximize this quantity.
+
+Hence we update the Actor network so that it produces actions that get
+the maximum predicted value as seen by the Critic, for a given state.
+"""
+
+
+class Buffer:
+ def __init__(self, buffer_capacity=100000, batch_size=64):
+ # Number of "experiences" to store at max
+ self.buffer_capacity = buffer_capacity
+ # Num of tuples to train on.
+ self.batch_size = batch_size
+
+        # It tells us the number of times record() was called.
+ self.buffer_counter = 0
+
+        # Instead of a list of tuples, as the experience replay concept suggests,
+        # we use a separate np.array for each tuple element
+ self.state_buffer = np.zeros((self.buffer_capacity, num_states))
+ self.action_buffer = np.zeros((self.buffer_capacity, num_actions))
+ self.reward_buffer = np.zeros((self.buffer_capacity, 1))
+ self.next_state_buffer = np.zeros((self.buffer_capacity, num_states))
+
+ # Takes (s,a,r,s') observation tuple as input
+ def record(self, obs_tuple):
+ # Set index to zero if buffer_capacity is exceeded,
+ # replacing old records
+ index = self.buffer_counter % self.buffer_capacity
+
+ self.state_buffer[index] = obs_tuple[0]
+ self.action_buffer[index] = obs_tuple[1]
+ self.reward_buffer[index] = obs_tuple[2]
+ self.next_state_buffer[index] = obs_tuple[3]
+
+ self.buffer_counter += 1
+
+ # Eager execution is turned on by default in TensorFlow 2. Decorating with tf.function allows
+ # TensorFlow to build a static graph out of the logic and computations in our function.
+ # This provides a large speed up for blocks of code that contain many small TensorFlow operations such as this one.
+ @tf.function
+ def update(
+ self,
+ state_batch,
+ action_batch,
+ reward_batch,
+ next_state_batch,
+ ):
+ # Training and updating Actor & Critic networks.
+ # See Pseudo Code.
+ with tf.GradientTape() as tape:
+ target_actions = target_actor(next_state_batch, training=True)
+ y = reward_batch + gamma * target_critic(
+ [next_state_batch, target_actions], training=True
+ )
+ critic_value = critic_model([state_batch, action_batch], training=True)
+ critic_loss = keras.ops.mean(keras.ops.square(y - critic_value))
+
+ critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)
+ critic_optimizer.apply_gradients(
+ zip(critic_grad, critic_model.trainable_variables)
+ )
+
+ with tf.GradientTape() as tape:
+ actions = actor_model(state_batch, training=True)
+ critic_value = critic_model([state_batch, actions], training=True)
+ # Used `-value` as we want to maximize the value given
+ # by the critic for our actions
+ actor_loss = -keras.ops.mean(critic_value)
+
+ actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)
+ actor_optimizer.apply_gradients(
+ zip(actor_grad, actor_model.trainable_variables)
+ )
+
+ # We compute the loss and update parameters
+ def learn(self):
+ # Get sampling range
+ record_range = min(self.buffer_counter, self.buffer_capacity)
+ # Randomly sample indices
+ batch_indices = np.random.choice(record_range, self.batch_size)
+
+ # Convert to tensors
+ state_batch = keras.ops.convert_to_tensor(self.state_buffer[batch_indices])
+ action_batch = keras.ops.convert_to_tensor(self.action_buffer[batch_indices])
+ reward_batch = keras.ops.convert_to_tensor(self.reward_buffer[batch_indices])
+ reward_batch = keras.ops.cast(reward_batch, dtype="float32")
+ next_state_batch = keras.ops.convert_to_tensor(
+ self.next_state_buffer[batch_indices]
+ )
+
+ self.update(state_batch, action_batch, reward_batch, next_state_batch)
+
+
+# This updates target parameters slowly,
+# based on rate `tau`, which is much less than one.
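+# The soft update rule is:
+#   target_weights <- tau * original_weights + (1 - tau) * target_weights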
+def update_target(target, original, tau):
+ target_weights = target.get_weights()
+ original_weights = original.get_weights()
+
+ for i in range(len(target_weights)):
+ target_weights[i] = original_weights[i] * tau + target_weights[i] * (1 - tau)
+
+ target.set_weights(target_weights)
+
+
+"""
+Here we define the Actor and Critic networks. These are basic Dense models
+with `ReLU` activation.
+
+Note: We need the initialization for the last layer of the Actor to be between
+`-0.003` and `0.003` as this prevents us from getting `1` or `-1` output values in
+the initial stages, which would squash our gradients to zero,
+as we use the `tanh` activation.
+"""
+
+
+def get_actor():
+    # Initialize weights between -3e-3 and 3e-3
+ last_init = keras.initializers.RandomUniform(minval=-0.003, maxval=0.003)
+
+ inputs = layers.Input(shape=(num_states,))
+ out = layers.Dense(256, activation="relu")(inputs)
+ out = layers.Dense(256, activation="relu")(out)
+ outputs = layers.Dense(1, activation="tanh", kernel_initializer=last_init)(out)
+
+ # Our upper bound is 2.0 for Pendulum.
+ outputs = outputs * upper_bound
+ model = keras.Model(inputs, outputs)
+ return model
+
+
+def get_critic():
+ # State as input
+ state_input = layers.Input(shape=(num_states,))
+ state_out = layers.Dense(16, activation="relu")(state_input)
+ state_out = layers.Dense(32, activation="relu")(state_out)
+
+ # Action as input
+ action_input = layers.Input(shape=(num_actions,))
+ action_out = layers.Dense(32, activation="relu")(action_input)
+
+ # Both are passed through separate layer before concatenating
+ concat = layers.Concatenate()([state_out, action_out])
+
+ out = layers.Dense(256, activation="relu")(concat)
+ out = layers.Dense(256, activation="relu")(out)
+ outputs = layers.Dense(1)(out)
+
+    # Outputs a single value for a given state-action pair
+ model = keras.Model([state_input, action_input], outputs)
+
+ return model
+
+
+"""
+`policy()` returns an action sampled from our Actor network plus some noise for
+exploration.
+"""
+
+
+def policy(state, noise_object):
+ sampled_actions = keras.ops.squeeze(actor_model(state))
+ noise = noise_object()
+ # Adding noise to action
+ sampled_actions = sampled_actions.numpy() + noise
+
+ # We make sure action is within bounds
+ legal_action = np.clip(sampled_actions, lower_bound, upper_bound)
+
+ return [np.squeeze(legal_action)]
+
+
+"""
+## Training hyperparameters
+"""
+
+std_dev = 0.2
+ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=float(std_dev) * np.ones(1))
+
+actor_model = get_actor()
+critic_model = get_critic()
+
+target_actor = get_actor()
+target_critic = get_critic()
+
+# Making the weights equal initially
+target_actor.set_weights(actor_model.get_weights())
+target_critic.set_weights(critic_model.get_weights())
+
+# Learning rate for actor-critic models
+critic_lr = 0.002
+actor_lr = 0.001
+
+critic_optimizer = keras.optimizers.Adam(critic_lr)
+actor_optimizer = keras.optimizers.Adam(actor_lr)
+
+total_episodes = 100
+# Discount factor for future rewards
+gamma = 0.99
+# Used to update target networks
+tau = 0.005
+
+buffer = Buffer(50000, 64)
+
+"""
+Now we implement our main training loop, and iterate over episodes.
+We sample actions using `policy()` and train with `learn()` at each time step,
+along with updating the Target networks at a rate `tau`.
+"""
+
+# To store reward history of each episode
+ep_reward_list = []
+# To store average reward history of last few episodes
+avg_reward_list = []
+
+# Takes about 4 min to train
+for ep in range(total_episodes):
+ prev_state, _ = env.reset()
+ episodic_reward = 0
+
+ while True:
+ tf_prev_state = keras.ops.expand_dims(
+ keras.ops.convert_to_tensor(prev_state), 0
+ )
+
+ action = policy(tf_prev_state, ou_noise)
+ # Receive state and reward from environment.
+ state, reward, done, truncated, _ = env.step(action)
+
+ buffer.record((prev_state, action, reward, state))
+ episodic_reward += reward
+
+ buffer.learn()
+
+ update_target(target_actor, actor_model, tau)
+ update_target(target_critic, critic_model, tau)
+
+ # End this episode when `done` or `truncated` is True
+ if done or truncated:
+ break
+
+ prev_state = state
+
+ ep_reward_list.append(episodic_reward)
+
+ # Mean of last 40 episodes
+ avg_reward = np.mean(ep_reward_list[-40:])
+ print("Episode * {} * Avg Reward is ==> {}".format(ep, avg_reward))
+ avg_reward_list.append(avg_reward)
+
+# Plotting graph
+# Episodes versus Avg. Rewards
+plt.plot(avg_reward_list)
+plt.xlabel("Episode")
+plt.ylabel("Avg. Episodic Reward")
+plt.show()
+
+"""
+If training proceeds correctly, the average episodic reward will increase with time.
+
+Feel free to try different learning rates, `tau` values, and architectures for the
+Actor and Critic networks.
+
+The Inverted Pendulum problem has low complexity, but DDPG works great on many other
+problems.
+
+Another great environment to try this on is `LunarLander-v2` in continuous mode
+(`gym.make("LunarLander-v2", continuous=True)`), but it will take
+more episodes to obtain good results.
+"""
+
+# Save the weights
+actor_model.save_weights("pendulum_actor.weights.h5")
+critic_model.save_weights("pendulum_critic.weights.h5")
+
+target_actor.save_weights("pendulum_target_actor.weights.h5")
+target_critic.save_weights("pendulum_target_critic.weights.h5")
+
+"""
+Before Training:
+
+
+"""
+
+"""
+After 100 episodes:
+
+
+"""
diff --git a/knowledge_base/rl/deep_q_network_breakout.py b/knowledge_base/rl/deep_q_network_breakout.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2d70afe4280af0887df868dd14ead9871d96af4
--- /dev/null
+++ b/knowledge_base/rl/deep_q_network_breakout.py
@@ -0,0 +1,285 @@
+"""
+Title: Deep Q-Learning for Atari Breakout
+Author: [Jacob Chapman](https://twitter.com/jacoblchapman) and [Mathias Lechner](https://twitter.com/MLech20)
+Date created: 2020/05/23
+Last modified: 2024/03/17
+Description: Play Atari Breakout with a Deep Q-Network.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+This script shows an implementation of Deep Q-Learning on the
+`BreakoutNoFrameskip-v4` environment.
+
+### Deep Q-Learning
+
+As an agent takes actions and moves through an environment, it learns to map
+the observed state of the environment to an action. An agent will choose an action
+in a given state based on a "Q-value": an estimate of the discounted long-term
+reward the agent expects to collect from that state-action pair. A Q-Learning
+Agent learns to perform its task such that the recommended action maximizes the
+potential future rewards.
+This method is considered an "Off-Policy" method,
+meaning its Q values are updated assuming that the best action was chosen, even
+if the best action was not chosen.
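+
+In tabular form, the classic Q-Learning update nudges the Q-value toward a
+bootstrapped target:
+
+`Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))`
+
+where `alpha` is the learning rate and `gamma` the discount factor. Deep
+Q-Learning replaces the table with a neural network trained by gradient descent
+toward a similar target, as implemented below.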
+
+### Atari Breakout
+
+In this environment, a board moves along the bottom of the screen returning a ball that
+will destroy blocks at the top of the screen.
+The aim of the game is to remove all blocks and break out of the
+level. The agent must learn to control the board by moving left and right, returning the
+ball and removing all the blocks without the ball passing the board.
+
+### Note
+
+The Deepmind paper trained for "a total of 50 million frames (that is, around 38 days of
+game experience in total)". However, this script will give good results at around 10
+million frames, which are processed in less than 24 hours on a modern machine.
+
+You can control the number of episodes by setting the `max_episodes` variable
+to a value greater than 0.
+
+### References
+
+- [Q-Learning](https://link.springer.com/content/pdf/10.1007/BF00992698.pdf)
+- [Deep Q-Learning](https://www.semanticscholar.org/paper/Human-level-control-through-deep-reinforcement-Mnih-Kavukcuoglu/340f48901f72278f6bf78a04ee5b01df208cc508)
+"""
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import gymnasium as gym
+from gymnasium.wrappers import AtariPreprocessing, FrameStack
+import numpy as np
+import tensorflow as tf
+
+# Configuration parameters for the whole setup
+seed = 42
+gamma = 0.99 # Discount factor for past rewards
+epsilon = 1.0 # Epsilon greedy parameter
+epsilon_min = 0.1 # Minimum epsilon greedy parameter
+epsilon_max = 1.0 # Maximum epsilon greedy parameter
+epsilon_interval = (
+ epsilon_max - epsilon_min
+) # Rate at which to reduce chance of random action being taken
+batch_size = 32 # Size of batch taken from replay buffer
+max_steps_per_episode = 10000
+max_episodes = 10 # Limit training episodes, will run until solved if smaller than 1
+
+# Use the Atari environment
+# Specify the `render_mode` parameter to show the attempts of the agent in a pop up window.
+env = gym.make("BreakoutNoFrameskip-v4") # , render_mode="human")
+# Environment preprocessing: frame skipping, grayscale conversion, and 84x84 downscaling
+env = AtariPreprocessing(env)
+# Stack four frames
+env = FrameStack(env, 4)
+# Gymnasium environments are seeded via `reset()` rather than a `seed()` method
+env.reset(seed=seed)
+"""
+## Implement the Deep Q-Network
+
+This network learns an approximation of the Q-table, which is a mapping between
+the states and actions that an agent will take. For every state there are four
+possible actions. The environment provides the state, and the action
+is chosen by selecting the largest of the four Q-values predicted in the output layer.
+
+"""
+
+num_actions = 4
+
+
+def create_q_model():
+ # Network defined by the Deepmind paper
+ return keras.Sequential(
+ [
+ layers.Lambda(
+ lambda tensor: keras.ops.transpose(tensor, [0, 2, 3, 1]),
+ output_shape=(84, 84, 4),
+ input_shape=(4, 84, 84),
+ ),
+ # Convolutions on the frames on the screen
+ layers.Conv2D(32, 8, strides=4, activation="relu", input_shape=(4, 84, 84)),
+ layers.Conv2D(64, 4, strides=2, activation="relu"),
+ layers.Conv2D(64, 3, strides=1, activation="relu"),
+ layers.Flatten(),
+ layers.Dense(512, activation="relu"),
+ layers.Dense(num_actions, activation="linear"),
+ ]
+ )
+
+
+# The first model makes the predictions for Q-values which are used to
+# make an action.
+model = create_q_model()
+# Build a target model for the prediction of future rewards.
+# The weights of the target model get updated every 10000 steps; thus, when the
+# loss between the Q-values is calculated, the target Q-value is stable.
+model_target = create_q_model()
+
+
+"""
+## Train
+"""
+# In the Deepmind paper they use RMSProp; however, the Adam optimizer
+# improves training time
+optimizer = keras.optimizers.Adam(learning_rate=0.00025, clipnorm=1.0)
+
+# Experience replay buffers
+action_history = []
+state_history = []
+state_next_history = []
+rewards_history = []
+done_history = []
+episode_reward_history = []
+running_reward = 0
+episode_count = 0
+frame_count = 0
+# Number of frames to take random action and observe output
+epsilon_random_frames = 50000
+# Number of frames for exploration
+epsilon_greedy_frames = 1000000.0
+# Maximum replay length
+# Note: The Deepmind paper suggests 1000000 however this causes memory issues
+max_memory_length = 100000
+# Train the model after 4 actions
+update_after_actions = 4
+# How often to update the target network
+update_target_network = 10000
+# Using huber loss for stability
+loss_function = keras.losses.Huber()
+
+while True:
+ observation, _ = env.reset()
+ state = np.array(observation)
+ episode_reward = 0
+
+ for timestep in range(1, max_steps_per_episode):
+ frame_count += 1
+
+ # Use epsilon-greedy for exploration
+ if frame_count < epsilon_random_frames or epsilon > np.random.rand(1)[0]:
+ # Take random action
+ action = np.random.choice(num_actions)
+ else:
+ # Predict action Q-values
+ # From environment state
+ state_tensor = keras.ops.convert_to_tensor(state)
+ state_tensor = keras.ops.expand_dims(state_tensor, 0)
+ action_probs = model(state_tensor, training=False)
+ # Take best action
+ action = keras.ops.argmax(action_probs[0]).numpy()
+
+ # Decay probability of taking random action
+ epsilon -= epsilon_interval / epsilon_greedy_frames
+ epsilon = max(epsilon, epsilon_min)
+
+ # Apply the sampled action in our environment
+ state_next, reward, done, _, _ = env.step(action)
+ state_next = np.array(state_next)
+
+ episode_reward += reward
+
+ # Save actions and states in replay buffer
+ action_history.append(action)
+ state_history.append(state)
+ state_next_history.append(state_next)
+ done_history.append(done)
+ rewards_history.append(reward)
+ state = state_next
+
+        # Update every fourth frame, once the replay history exceeds the batch size
+ if frame_count % update_after_actions == 0 and len(done_history) > batch_size:
+ # Get indices of samples for replay buffers
+ indices = np.random.choice(range(len(done_history)), size=batch_size)
+
+ # Using list comprehension to sample from replay buffer
+ state_sample = np.array([state_history[i] for i in indices])
+ state_next_sample = np.array([state_next_history[i] for i in indices])
+ rewards_sample = [rewards_history[i] for i in indices]
+ action_sample = [action_history[i] for i in indices]
+ done_sample = keras.ops.convert_to_tensor(
+ [float(done_history[i]) for i in indices]
+ )
+
+ # Build the updated Q-values for the sampled future states
+ # Use the target model for stability
+ future_rewards = model_target.predict(state_next_sample)
+ # Q value = reward + discount factor * expected future reward
+ updated_q_values = rewards_sample + gamma * keras.ops.amax(
+ future_rewards, axis=1
+ )
+
+            # For terminal frames, replace the target with -1: multiplying by
+            # (1 - done_sample) zeroes the bootstrapped value and subtracting
+            # done_sample sets the target to -1
+ updated_q_values = updated_q_values * (1 - done_sample) - done_sample
+
+ # Create a mask so we only calculate loss on the updated Q-values
+ masks = keras.ops.one_hot(action_sample, num_actions)
+
+ with tf.GradientTape() as tape:
+ # Train the model on the states and updated Q-values
+ q_values = model(state_sample)
+
+ # Apply the masks to the Q-values to get the Q-value for action taken
+ q_action = keras.ops.sum(keras.ops.multiply(q_values, masks), axis=1)
+ # Calculate loss between new Q-value and old Q-value
+ loss = loss_function(updated_q_values, q_action)
+
+ # Backpropagation
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
+
+ if frame_count % update_target_network == 0:
+        # update the target network with new weights
+ model_target.set_weights(model.get_weights())
+ # Log details
+ template = "running reward: {:.2f} at episode {}, frame count {}"
+ print(template.format(running_reward, episode_count, frame_count))
+
+ # Limit the state and reward history
+ if len(rewards_history) > max_memory_length:
+ del rewards_history[:1]
+ del state_history[:1]
+ del state_next_history[:1]
+ del action_history[:1]
+ del done_history[:1]
+
+ if done:
+ break
+
+ # Update running reward to check condition for solving
+ episode_reward_history.append(episode_reward)
+ if len(episode_reward_history) > 100:
+ del episode_reward_history[:1]
+ running_reward = np.mean(episode_reward_history)
+
+ episode_count += 1
+
+ if running_reward > 40: # Condition to consider the task solved
+ print("Solved at episode {}!".format(episode_count))
+ break
+
+ if (
+ max_episodes > 0 and episode_count >= max_episodes
+ ): # Maximum number of episodes reached
+ print("Stopped at episode {}!".format(episode_count))
+ break
+
+"""
+## Visualizations
+Before any training:
+
+
+In early stages of training:
+
+
+In later stages of training:
+
+"""
diff --git a/knowledge_base/rl/ppo_cartpole.py b/knowledge_base/rl/ppo_cartpole.py
new file mode 100644
index 0000000000000000000000000000000000000000..815127c2b917d3c4aa7d4f9aa54330e262f1c184
--- /dev/null
+++ b/knowledge_base/rl/ppo_cartpole.py
@@ -0,0 +1,340 @@
+"""
+Title: Proximal Policy Optimization
+Author: [Ilias Chrysovergis](https://twitter.com/iliachry)
+Date created: 2021/06/24
+Last modified: 2024/03/12
+Description: Implementation of a Proximal Policy Optimization agent for the CartPole-v1 environment.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+This code example solves the CartPole-v1 environment using a Proximal Policy Optimization (PPO) agent.
+
+### CartPole-v1
+
+A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
+The system is controlled by applying a force of +1 or -1 to the cart.
+The pendulum starts upright, and the goal is to prevent it from falling over.
+A reward of +1 is provided for every timestep that the pole remains upright.
+The episode ends when the pole is more than 12 degrees from vertical, or the cart moves more than 2.4 units from the center.
+The episode is also truncated after 500 steps. Thus, the highest return we can get is equal to 500.
+
+[CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/)
+
+### Proximal Policy Optimization
+
+PPO is a policy gradient method and can be used for environments with either discrete or continuous action spaces.
+It trains a stochastic policy in an on-policy way. Also, it utilizes the actor-critic method. The actor maps the
+observation to an action and the critic gives an expectation of the rewards of the agent for the observation given.
+Firstly, it collects a set of trajectories for each epoch by sampling from the latest version of the stochastic policy.
+Then, the rewards-to-go and the advantage estimates are computed in order to update the policy and fit the value function.
+The policy is updated via a stochastic gradient ascent optimizer, while the value function is fitted via some gradient descent algorithm.
+This procedure is applied for many epochs until the environment is solved.
+
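+As a reference, the clipped surrogate objective that `train_policy()` below maximizes is
+
+`L(theta) = E[ min(r_t(theta) * A_t, clip(r_t(theta), 1 - epsilon, 1 + epsilon) * A_t) ]`
+
+where `r_t(theta)` is the probability ratio between the new and the old policy,
+`A_t` is the advantage estimate, and `epsilon` is the `clip_ratio` hyperparameter.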
+
+
+- [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
+- [OpenAI Spinning Up docs - PPO](https://spinningup.openai.com/en/latest/algorithms/ppo.html)
+
+### Note
+
+This code example uses Keras and TensorFlow v2. It is based on the original PPO paper,
+OpenAI's Spinning Up docs for PPO, and OpenAI's Spinning Up implementation of PPO using TensorFlow v1.
+
+[OpenAI Spinning Up Github - PPO](https://github.com/openai/spinningup/blob/master/spinup/algos/tf1/ppo/ppo.py)
+"""
+
+"""
+## Libraries
+
+For this example the following libraries are used:
+
+1. `numpy` for n-dimensional arrays
+2. `tensorflow` and `keras` for building the deep RL PPO agent
+3. `gymnasium` for getting everything we need about the environment
+4. `scipy.signal` for calculating the discounted cumulative sums of vectors
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import numpy as np
+import tensorflow as tf
+import gymnasium as gym
+import scipy.signal
+
+"""
+## Functions and class
+"""
+
+
+def discounted_cumulative_sums(x, discount):
+ # Discounted cumulative sums of vectors for computing rewards-to-go and advantage estimates
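+    # e.g. discounted_cumulative_sums([1, 1, 1], 0.99) -> [2.9701, 1.99, 1.0],
+    # since output[t] = x[t] + discount * output[t + 1]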
+ return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
+
+
+class Buffer:
+ # Buffer for storing trajectories
+ def __init__(self, observation_dimensions, size, gamma=0.99, lam=0.95):
+ # Buffer initialization
+ self.observation_buffer = np.zeros(
+ (size, observation_dimensions), dtype=np.float32
+ )
+ self.action_buffer = np.zeros(size, dtype=np.int32)
+ self.advantage_buffer = np.zeros(size, dtype=np.float32)
+ self.reward_buffer = np.zeros(size, dtype=np.float32)
+ self.return_buffer = np.zeros(size, dtype=np.float32)
+ self.value_buffer = np.zeros(size, dtype=np.float32)
+ self.logprobability_buffer = np.zeros(size, dtype=np.float32)
+ self.gamma, self.lam = gamma, lam
+ self.pointer, self.trajectory_start_index = 0, 0
+
+ def store(self, observation, action, reward, value, logprobability):
+ # Append one step of agent-environment interaction
+ self.observation_buffer[self.pointer] = observation
+ self.action_buffer[self.pointer] = action
+ self.reward_buffer[self.pointer] = reward
+ self.value_buffer[self.pointer] = value
+ self.logprobability_buffer[self.pointer] = logprobability
+ self.pointer += 1
+
+ def finish_trajectory(self, last_value=0):
+ # Finish the trajectory by computing advantage estimates and rewards-to-go
+ path_slice = slice(self.trajectory_start_index, self.pointer)
+ rewards = np.append(self.reward_buffer[path_slice], last_value)
+ values = np.append(self.value_buffer[path_slice], last_value)
+
+ deltas = rewards[:-1] + self.gamma * values[1:] - values[:-1]
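+        # `deltas` holds the TD residuals delta_t = r_t + gamma * V(s_{t+1}) - V(s_t);
+        # the GAE advantage is their discounted cumulative sum with factor gamma * lam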
+
+ self.advantage_buffer[path_slice] = discounted_cumulative_sums(
+ deltas, self.gamma * self.lam
+ )
+ self.return_buffer[path_slice] = discounted_cumulative_sums(
+ rewards, self.gamma
+ )[:-1]
+
+ self.trajectory_start_index = self.pointer
+
+ def get(self):
+ # Get all data of the buffer and normalize the advantages
+ self.pointer, self.trajectory_start_index = 0, 0
+ advantage_mean, advantage_std = (
+ np.mean(self.advantage_buffer),
+ np.std(self.advantage_buffer),
+ )
+ self.advantage_buffer = (self.advantage_buffer - advantage_mean) / advantage_std
+ return (
+ self.observation_buffer,
+ self.action_buffer,
+ self.advantage_buffer,
+ self.return_buffer,
+ self.logprobability_buffer,
+ )
+
+
+def mlp(x, sizes, activation=keras.activations.tanh, output_activation=None):
+ # Build a feedforward neural network
+ for size in sizes[:-1]:
+ x = layers.Dense(units=size, activation=activation)(x)
+ return layers.Dense(units=sizes[-1], activation=output_activation)(x)
+
+
+def logprobabilities(logits, a):
+ # Compute the log-probabilities of taking actions a by using the logits (i.e. the output of the actor)
+ logprobabilities_all = keras.ops.log_softmax(logits)
+ logprobability = keras.ops.sum(
+ keras.ops.one_hot(a, num_actions) * logprobabilities_all, axis=1
+ )
+ return logprobability
+
+
+seed_generator = keras.random.SeedGenerator(1337)
+
+
+# Sample action from actor
+@tf.function
+def sample_action(observation):
+ logits = actor(observation)
+ action = keras.ops.squeeze(
+ keras.random.categorical(logits, 1, seed=seed_generator), axis=1
+ )
+ return logits, action
+
+
+# Train the policy by maximizing the PPO-Clip objective
+@tf.function
+def train_policy(
+ observation_buffer, action_buffer, logprobability_buffer, advantage_buffer
+):
+ with tf.GradientTape() as tape: # Record operations for automatic differentiation.
+ ratio = keras.ops.exp(
+ logprobabilities(actor(observation_buffer), action_buffer)
+ - logprobability_buffer
+ )
+ min_advantage = keras.ops.where(
+ advantage_buffer > 0,
+ (1 + clip_ratio) * advantage_buffer,
+ (1 - clip_ratio) * advantage_buffer,
+ )
+
+ policy_loss = -keras.ops.mean(
+ keras.ops.minimum(ratio * advantage_buffer, min_advantage)
+ )
+ policy_grads = tape.gradient(policy_loss, actor.trainable_variables)
+ policy_optimizer.apply_gradients(zip(policy_grads, actor.trainable_variables))
+
+ kl = keras.ops.mean(
+ logprobability_buffer
+ - logprobabilities(actor(observation_buffer), action_buffer)
+ )
+ kl = keras.ops.sum(kl)
+ return kl
+
+
+# Train the value function by regression on mean-squared error
+@tf.function
+def train_value_function(observation_buffer, return_buffer):
+ with tf.GradientTape() as tape: # Record operations for automatic differentiation.
+ value_loss = keras.ops.mean((return_buffer - critic(observation_buffer)) ** 2)
+ value_grads = tape.gradient(value_loss, critic.trainable_variables)
+ value_optimizer.apply_gradients(zip(value_grads, critic.trainable_variables))
+
+
+"""
+## Hyperparameters
+"""
+
+# Hyperparameters of the PPO algorithm
+steps_per_epoch = 4000
+epochs = 30
+gamma = 0.99
+clip_ratio = 0.2
+policy_learning_rate = 3e-4
+value_function_learning_rate = 1e-3
+train_policy_iterations = 80
+train_value_iterations = 80
+lam = 0.97
+target_kl = 0.01
+hidden_sizes = (64, 64)
+
+# True if you want to render the environment
+render = False
+
+"""
+## Initializations
+"""
+
+# Initialize the environment and get the dimensionality of the
+# observation space and the number of possible actions
+env = gym.make("CartPole-v1")
+observation_dimensions = env.observation_space.shape[0]
+num_actions = env.action_space.n
+
+# Initialize the buffer
+buffer = Buffer(observation_dimensions, steps_per_epoch)
+
+# Initialize the actor and the critic as keras models
+observation_input = keras.Input(shape=(observation_dimensions,), dtype="float32")
+logits = mlp(observation_input, list(hidden_sizes) + [num_actions])
+actor = keras.Model(inputs=observation_input, outputs=logits)
+value = keras.ops.squeeze(mlp(observation_input, list(hidden_sizes) + [1]), axis=1)
+critic = keras.Model(inputs=observation_input, outputs=value)
+
+# Initialize the policy and the value function optimizers
+policy_optimizer = keras.optimizers.Adam(learning_rate=policy_learning_rate)
+value_optimizer = keras.optimizers.Adam(learning_rate=value_function_learning_rate)
+
+# Initialize the observation, episode return and episode length
+observation, _ = env.reset()
+episode_return, episode_length = 0, 0
+
+"""
+## Train
+"""
+# Iterate over the number of epochs
+for epoch in range(epochs):
+ # Initialize the sum of the returns, lengths and number of episodes for each epoch
+ sum_return = 0
+ sum_length = 0
+ num_episodes = 0
+
+ # Iterate over the steps of each epoch
+ for t in range(steps_per_epoch):
+ if render:
+ env.render()
+
+ # Get the logits, action, and take one step in the environment
+ observation = observation.reshape(1, -1)
+ logits, action = sample_action(observation)
+ observation_new, reward, done, _, _ = env.step(action[0].numpy())
+ episode_return += reward
+ episode_length += 1
+
+ # Get the value and log-probability of the action
+ value_t = critic(observation)
+ logprobability_t = logprobabilities(logits, action)
+
+ # Store obs, act, rew, v_t, logp_pi_t
+ buffer.store(observation, action, reward, value_t, logprobability_t)
+
+ # Update the observation
+ observation = observation_new
+
+        # Finish the trajectory if a terminal state is reached
+ terminal = done
+ if terminal or (t == steps_per_epoch - 1):
+ last_value = 0 if done else critic(observation.reshape(1, -1))
+ buffer.finish_trajectory(last_value)
+ sum_return += episode_return
+ sum_length += episode_length
+ num_episodes += 1
+ observation, _ = env.reset()
+ episode_return, episode_length = 0, 0
+
+ # Get values from the buffer
+ (
+ observation_buffer,
+ action_buffer,
+ advantage_buffer,
+ return_buffer,
+ logprobability_buffer,
+ ) = buffer.get()
+
+ # Update the policy and implement early stopping using KL divergence
+ for _ in range(train_policy_iterations):
+ kl = train_policy(
+ observation_buffer, action_buffer, logprobability_buffer, advantage_buffer
+ )
+ if kl > 1.5 * target_kl:
+ # Early Stopping
+ break
+
+ # Update the value function
+ for _ in range(train_value_iterations):
+ train_value_function(observation_buffer, return_buffer)
+
+ # Print mean return and length for each epoch
+ print(
+ f" Epoch: {epoch + 1}. Mean Return: {sum_return / num_episodes}. Mean Length: {sum_length / num_episodes}"
+ )
+
+"""
+## Visualizations
+
+Before training:
+
+
+
+After 8 epochs of training:
+
+
+
+After 20 epochs of training:
+
+
+"""
diff --git a/knowledge_base/structured_data/class_with_grn_and_vsn_with_hyperparameters_tuning.py b/knowledge_base/structured_data/class_with_grn_and_vsn_with_hyperparameters_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..961dd8b2ad467f8b8b804d8649b11aaa2d5bf724
--- /dev/null
+++ b/knowledge_base/structured_data/class_with_grn_and_vsn_with_hyperparameters_tuning.py
@@ -0,0 +1,708 @@
+"""
+Title: Classification with Gated Residual and Variable Selection Networks with HyperParameters tuning
+Author: [Humbulani Ndou](https://github.com/Humbulani1234)
+Date created: 2025/03/17
+Last modified: 2025/03/17
+Description: Gated Residual and Variable Selection Networks prediction with HyperParameters tuning.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The following example extends the script `structured_data/classification_with_grn_and_vsn.py` by incorporating hyperparameter tuning
+using [Autokeras](https://github.com/keras-team/autokeras) and [KerasTuner](https://github.com/keras-team/keras-tuner). Specifics regarding
+which APIs are used from these two packages are described in detail in the relevant code sections.
+
+This example demonstrates the use of Gated
+Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by
+Bryan Lim et al. in
+[Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363),
+for structured data classification. GRNs give the model the flexibility to apply
+non-linear processing only where needed. VSNs allow the model to softly remove any
+unnecessary noisy inputs which could negatively impact performance.
+Together, those techniques help improve the learning capacity of deep neural
+network models.
+
+Note that this example implements only the GRN and VSN components described
+in the paper, rather than the whole TFT model, as GRN and VSN can be useful on
+their own for structured data learning tasks.
+
+
+To run the code you need to use TensorFlow 2.3 or higher.
+"""
+
+"""
+## The dataset
+
+[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the
+Cleveland Clinic Foundation for Heart Disease.
+It's a CSV file with 303 rows. Each row contains information about a patient (a
+**sample**), and each column describes an attribute of the patient (a **feature**). We
+use the features to predict whether a patient has a heart disease (**binary
+classification**).
+
+Here's the description of each feature:
+
+Column| Description| Feature Type
+------------|--------------------|----------------------
+Age | Age in years | Numerical
+Sex | (1 = male; 0 = female) | Categorical
+CP | Chest pain type (0, 1, 2, 3, 4) | Categorical
+Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical
+Chol | Serum cholesterol in mg/dl | Numerical
+FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Categorical
+RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical
+Thalach | Maximum heart rate achieved | Numerical
+Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical
+Oldpeak | ST depression induced by exercise relative to rest | Numerical
+Slope | Slope of the peak exercise ST segment | Numerical
+CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical
+Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical
+Target | Diagnosis of heart disease (1 = true; 0 = false) | Target
+"""
+
+"""
+## Setup
+"""
+
+
+import os
+import subprocess
+import tarfile
+import numpy as np
+import pandas as pd
+import tree
+from typing import Optional, Union
+
+os.environ["KERAS_BACKEND"] = "tensorflow" # or jax, or torch
+
+# Keras imports
+import keras
+from keras import layers
+
+# KerasTuner imports
+import keras_tuner
+from keras_tuner import HyperParameters
+
+# AutoKeras imports
+import autokeras as ak
+from autokeras.utils import utils, types
+
+
+"""
+## Preparing the data
+
+Let's download the data and load it into a Pandas dataframe:
+"""
+
+file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
+dataframe = pd.read_csv(file_url)
+
+"""
+The dataset includes 303 samples with 14 columns per sample (13 features, plus the target
+label):
+"""
+
+dataframe.shape
+
+"""
+Here's a preview of a few samples:
+"""
+
+dataframe.head()
+
+"""
+The last column, "target", indicates whether the patient has a heart disease (1) or not
+(0).
+
+Let's split the data into a training and validation set:
+"""
+
+val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
+train_dataframe = dataframe.drop(val_dataframe.index)
+
+print(
+ f"Using {len(train_dataframe)} samples for training "
+ f"and {len(val_dataframe)} for validation"
+)
+
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading and
+parsing the data into input features, and encoding the input features with respect
+to their types.
+"""
+
+COLUMN_NAMES = [
+ "age",
+ "sex",
+ "cp",
+ "trestbps",
+ "chol",
+ "fbs",
+ "restecg",
+ "thalach",
+ "exang",
+ "oldpeak",
+ "slope",
+ "ca",
+ "thal",
+ "target",
+]
+# Target feature name.
+TARGET_FEATURE_NAME = "target"
+# Numeric feature names.
+NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
+# Categorical features and their vocabulary lists.
+# Note that integer categorical values are kept as ints and string values as strs,
+# to make sure they match the dtypes of the corresponding dataframe columns.
+
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ feature_name: sorted(
+ [
+            # Integer categorical values must be int and string values must be str
+ value if dataframe[feature_name].dtype == "int64" else str(value)
+ for value in list(dataframe[feature_name].unique())
+ ]
+ )
+ for feature_name in COLUMN_NAMES
+ if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
+}
+# All features names.
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
+ CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
+)
+
+
+"""
+## Feature preprocessing with Keras layers
+
+
+The following features are categorical features encoded as integers:
+
+- `sex`
+- `cp`
+- `fbs`
+- `restecg`
+- `exang`
+- `ca`
+
+We will encode these features using **one-hot encoding**. We have two options
+here:
+
+ - Use `CategoryEncoding()`, which requires knowing the range of input values
+ and will error on input outside the range.
+ - Use `IntegerLookup()` which will build a lookup table for inputs and reserve
+ an output index for unknown input values.
+
+For this example, we want a simple solution that will handle out of range inputs
+at inference, so we will use `IntegerLookup()`.
+
+We also have a categorical feature encoded as a string: `thal`. We will create an
+index of all possible feature values and encode the output using the `StringLookup()` layer.
+
+Finally, the following features are continuous numerical features:
+
+- `age`
+- `trestbps`
+- `chol`
+- `thalach`
+- `oldpeak`
+- `slope`
+
+For each of these features, we will use a `Normalization()` layer to make sure the mean
+of each feature is 0 and its standard deviation is 1.
+
+Below, we define a utility function to do the operations:
+
+- `encode_categorical` to one-hot encode string or integer categorical features.
+"""
+
+# Tensorflow required for tf.data.Dataset
+import tensorflow as tf
+
+
+# We encode the categorical elements of our datasets here, converting them to indices,
+# to avoid doing this step during model training, since only the TensorFlow backend supports strings.
+def encode_categorical(features, target):
+ for f in features:
+ if f in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ cls = (
+ layers.StringLookup
+ if features[f].dtype == "string"
+ else layers.IntegerLookup
+ )
+ features[f] = cls(
+ vocabulary=CATEGORICAL_FEATURES_WITH_VOCABULARY[f],
+ mask_token=None,
+ num_oov_indices=0,
+ output_mode="binary",
+ )(features[f])
+
+ # Change features from OrderedDict to Dict to match Inputs as they are Dict.
+ return dict(features), target
+
+
+"""
+Let's generate `tf.data.Dataset` objects for each dataframe:
+"""
+
+
+def dataframe_to_dataset(dataframe):
+ dataframe = dataframe.copy()
+ labels = dataframe.pop("target")
+ ds = (
+ tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
+ .map(encode_categorical)
+ .shuffle(buffer_size=len(dataframe))
+ )
+ return ds
+
+
+train_ds = dataframe_to_dataset(train_dataframe)
+val_ds = dataframe_to_dataset(val_dataframe)
+
+
+"""
+Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
+and `target` is the value `0` or `1`:
+"""
+
+for x, y in train_ds.take(1):
+ print("Input:", x)
+ print("Target:", y)
+
+"""
+Let's batch the datasets:
+"""
+
+train_ds = train_ds.batch(32)
+val_ds = val_ds.batch(32)
+
+"""
+## Subclassing Autokeras Graph
+
+Here we subclass the Autokeras `Graph`:
+
+- `build`: we override this method to be able to handle model `Inputs` passed
+as dictionaries. In structured data analysis, Inputs are normally passed as
+dictionaries, with one entry per feature of interest.
+
+"""
+
+
+class Graph(ak.graph.Graph):
+
+ def build(self, hp):
+ """Build the HyperModel into a Keras Model."""
+ keras_nodes = {}
+ keras_input_nodes = []
+ for node in self.inputs:
+ node_id = self._node_to_id[node]
+ input_node = node.build_node(hp)
+ output_node = node.build(hp, input_node)
+ keras_input_nodes.append(input_node)
+ keras_nodes[node_id] = output_node
+ for block in self.blocks:
+ temp_inputs = (
+ {
+ n.name: keras_nodes[self._node_to_id[n]]
+ for n in block.inputs
+ if isinstance(n, ak.Input)
+ }
+ if isinstance(block.inputs[0], ak.Input)
+ else [keras_nodes[self._node_to_id[n]] for n in block.inputs]
+ )
+ outputs = tree.flatten(block.build(hp, inputs=temp_inputs))
+ for n, o in zip(block.outputs, outputs):
+ keras_nodes[self._node_to_id[n]] = o
+ model = keras.models.Model(
+ keras_input_nodes,
+ [
+ keras_nodes[self._node_to_id[output_node]]
+ for output_node in self.outputs
+ ],
+ )
+ return self._compile_keras_model(hp, model)
+
+ def _compile_keras_model(self, hp, model):
+ # Specify hyperparameters from compile(...)
+ optimizer_name = hp.Choice(
+ "optimizer",
+ ["adam", "sgd"],
+ default="adam",
+ )
+ learning_rate = hp.Choice(
+ "learning_rate", [1e-1, 1e-2, 1e-3, 1e-4, 2e-5, 1e-5], default=1e-3
+ )
+ if optimizer_name == "adam":
+ optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
+ elif optimizer_name == "sgd":
+ optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
+ model.compile(
+ optimizer=optimizer,
+ metrics=self._get_metrics(),
+ loss=self._get_loss(),
+ )
+ return model
+
+
+"""
+
+## Subclassing Autokeras `Input`
+
+Here we subclass the Autokeras `Input` node object and override the `dtype` attribute
+from None to a user-supplied value. We also override the `build_node` method to
+use the user-supplied name for the `Input` layers.
+
+"""
+
+
+class Input(ak.Input):
+ def __init__(self, dtype, name=None, **kwargs):
+ super().__init__(name=name, **kwargs)
+ # Override dtype to a user dtype value
+ self.dtype = dtype
+ self.name = name
+
+ def build_node(self, hp):
+ return keras.Input(name=self.name, shape=self.shape, dtype=self.dtype)
+
+
+"""
+
+## Subclassing ClassificationHead
+
+Here we subclass the Autokeras `ClassificationHead` and override the `__init__` method, and
+we add the method `get_expected_shape` to infer the labels shape.
+We remove the preprocessing functionality, as we prefer to perform it manually.
+"""
+
+
+class ClassifierHead(ak.ClassificationHead):
+
+ def __init__(
+ self,
+ num_classes: Optional[int] = None,
+ multi_label: bool = False,
+ loss: Optional[types.LossType] = None,
+ metrics: Optional[types.MetricsType] = None,
+ dropout: Optional[float] = None,
+ **kwargs,
+ ):
+ self.num_classes = num_classes
+ self.multi_label = multi_label
+ self.dropout = dropout
+ if metrics is None:
+ metrics = ["accuracy"]
+ if loss is None:
+ loss = self.infer_loss()
+ ak.Head.__init__(self, loss=loss, metrics=metrics, **kwargs)
+ self.shape = self.get_expected_shape()
+
+ def get_expected_shape(self):
+ # Compute expected shape from num_classes.
+ if self.num_classes == 2 and not self.multi_label:
+ return [1]
+ return [self.num_classes]
+
+
+"""
+## GatedLinearUnit Layer
+
+This is a Keras layer defined in the script `structured_data/classification_with_grn_and_vsn.py`.
+More details about this layer may be found in the relevant script.
+
+"""
+
+
+class GatedLinearUnit(layers.Layer):
+ def __init__(self, num_units, activation, **kwargs):
+ super().__init__(**kwargs)
+ self.linear = layers.Dense(num_units)
+ self.sigmoid = layers.Dense(num_units, activation=activation)
+
+ def call(self, inputs):
+ return self.linear(inputs) * self.sigmoid(inputs)
+
+ def build(self):
+ self.built = True
+
+
+"""
+## GatedResidualNetwork Layer
+
+This is a Keras layer defined in the script `structured_data/classification_with_grn_and_vsn.py`.
+More details about this layer may be found in the relevant script.
+
+"""
+
+
+class GatedResidualNetwork(layers.Layer):
+
+ def __init__(
+ self, num_units, dropout_rate, activation, use_layernorm=None, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_units = num_units
+ self.use_layernorm = use_layernorm
+ self.elu_dense = layers.Dense(num_units, activation=activation)
+ self.linear_dense = layers.Dense(num_units)
+ self.dropout = layers.Dropout(dropout_rate)
+ self.gated_linear_unit = GatedLinearUnit(num_units, activation)
+ self.layer_norm = layers.LayerNormalization()
+ self.project = layers.Dense(num_units)
+
+ def call(self, inputs, hp):
+ x = self.elu_dense(inputs)
+ x = self.linear_dense(x)
+ x = self.dropout(x)
+ if inputs.shape[-1] != self.num_units:
+ inputs = self.project(inputs)
+ x = inputs + self.gated_linear_unit(x)
+ use_layernorm = self.use_layernorm
+ if use_layernorm is None:
+ use_layernorm = hp.Boolean("use_layernorm", default=True)
+ if use_layernorm:
+ x = self.layer_norm(x)
+ return x
+
+ def build(self):
+ self.built = True
+
+
+"""
+## Building the Autokeras `VariableSelection Block`
+
+We have converted the following Keras layer to an Autokeras Block to include
+hyperparameters to tune. Refer to the Autokeras Blocks API for writing custom Blocks.
+
+"""
+
+
+class VariableSelection(ak.Block):
+ def __init__(
+ self,
+ num_units: Optional[Union[int, HyperParameters.Choice]] = None,
+ dropout_rate: Optional[Union[float, HyperParameters.Choice]] = None,
+ activation: Optional[Union[str, HyperParameters.Choice]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dropout = utils.get_hyperparameter(
+ dropout_rate,
+ HyperParameters().Choice("dropout", [0.0, 0.25, 0.5], default=0.0),
+ float,
+ )
+ self.num_units = utils.get_hyperparameter(
+ num_units,
+ HyperParameters().Choice(
+ "num_units", [16, 32, 64, 128, 256, 512, 1024], default=16
+ ),
+ int,
+ )
+ self.activation = utils.get_hyperparameter(
+ activation,
+ HyperParameters().Choice(
+ "vsn_activation", ["sigmoid", "elu"], default="sigmoid"
+ ),
+ str,
+ )
+
+ def build(self, hp, inputs):
+ num_units = utils.add_to_hp(self.num_units, hp, "num_units")
+ dropout_rate = utils.add_to_hp(self.dropout, hp, "dropout_rate")
+ activation = utils.add_to_hp(self.activation, hp, "activation")
+ concat_inputs = []
+ # Project the features to 'num_units' dimension
+ for input_ in inputs:
+ if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ concat_inputs.append(
+ keras.layers.Dense(units=num_units)(inputs[input_])
+ )
+ else:
+ # Create a Normalization layer for our feature
+ normalizer = layers.Normalization()
+ # Prepare a Dataset that only yields our feature
+ feature_ds = train_ds.map(lambda x, y: x[input_]).map(
+ lambda x: keras.ops.expand_dims(x, -1)
+ )
+ # Learn the statistics of the data
+ normalizer.adapt(feature_ds)
+ # Normalize the input feature
+ normal_feature = normalizer(inputs[input_])
+ concat_inputs.append(
+ keras.layers.Dense(units=num_units)(normal_feature)
+ )
+ v = layers.concatenate(concat_inputs)
+ v = GatedResidualNetwork(
+ num_units=num_units, dropout_rate=dropout_rate, activation=activation
+ )(v, hp=hp)
+ v = keras.ops.expand_dims(
+ layers.Dense(units=len(inputs), activation=activation)(v), axis=-1
+ )
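+        # `v` holds the per-feature selection weights; below, each projected
+        # feature is passed through its own GRN and the outputs are combined
+        # as a weighted sum via the matmul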
+ x = []
+ x += [
+ GatedResidualNetwork(num_units, dropout_rate, activation)(i, hp=hp)
+ for i in concat_inputs
+ ]
+ x = keras.ops.stack(x, axis=1)
+ return keras.ops.squeeze(
+ keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
+ )
+
+
+"""
+
+## Creating the HyperModel Inputs
+
+We create the HyperModel (from KerasTuner) Inputs, which will be built into
+Keras `Input` objects.
+
+"""
+
+
+# Categorical features have different shapes after the encoding, dependent on the
+# vocabulary or unique values of each feature. We create them accordingly to match the
+# input data elements generated by tf.data.Dataset after pre-processing them.
+def create_model_inputs():
+ inputs = {
+ f: (
+ Input(
+ name=f,
+ shape=(len(CATEGORICAL_FEATURES_WITH_VOCABULARY[f]),),
+ dtype="int64",
+ )
+ if f in CATEGORICAL_FEATURES_WITH_VOCABULARY
+ else Input(name=f, shape=(1,), dtype="float32")
+ )
+ for f in FEATURE_NAMES
+ }
+ return inputs
+
+
+"""
+
+## KerasTuner `HyperModel`
+
+Here we use the Autokeras `Functional` API to construct a network of Blocks which will
+be built into a KerasTuner HyperModel and finally into a Keras Model.
+
+"""
+
+
+class MyHyperModel(keras_tuner.HyperModel):
+
+ def build(self, hp):
+ inputs = create_model_inputs()
+ features = VariableSelection()(inputs)
+ outputs = ClassifierHead(num_classes=2, multi_label=False)(features)
+ model = Graph(inputs=inputs, outputs=outputs)
+ model = model.build(hp)
+ return model
+
+ def fit(self, hp, model, *args, **kwargs):
+ return model.fit(
+ *args,
+ # Tune whether to shuffle the data in each epoch.
+ shuffle=hp.Boolean("shuffle"),
+ **kwargs,
+ )
+
+
+"""
+
+## Using `RandomSearch` Tuner to find best HyperParameters
+
+We use the `RandomSearch` tuner to search for hyperparameters in the search space.
+We also display the search space.
+
+"""
+
+print("Start training and searching for the best model...")
+
+tuner = keras_tuner.RandomSearch(
+ MyHyperModel(),
+ objective="val_accuracy",
+ max_trials=3,
+ overwrite=True,
+ directory="my_dir",
+ project_name="tune_hypermodel",
+)
+
+# Show the search space summary
+print("Tuner search space summary:\n")
+tuner.search_space_summary()
+# Search for best model
+tuner.search(train_ds, epochs=2, validation_data=val_ds)
+
+"""
+## Extracting the best model
+"""
+
+# Get the top model.
+models = tuner.get_best_models(num_models=1)
+best_model = models[0]
+best_model.summary()
+
+
+"""
+## Inference on new data
+
+To get a prediction for a new sample, you can simply call `model.predict()`. There are
+just two things you need to do:
+
+1. Wrap scalars into a list so as to have a batch dimension (models only process batches
+of data, not single samples)
+2. Call `convert_to_tensor` on each feature
+"""
+
+sample = {
+ "age": 60,
+ "sex": 1,
+ "cp": 1,
+ "trestbps": 145,
+ "chol": 233,
+ "fbs": 1,
+ "restecg": 2,
+ "thalach": 150,
+ "exang": 0,
+ "oldpeak": 2.3,
+ "slope": 3,
+ "ca": 0,
+ "thal": "fixed",
+}
+
+
+# Given the category (in the sample above - key) and the category value (in the sample above - value),
+# we return its one-hot encoding
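+# e.g. get_cat_encoding([0, 1, 2, 3], 1) -> [0, 1, 0, 0]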
+def get_cat_encoding(cat, cat_value):
+ # Create a list of zeros with the same length as categories
+ encoding = [0] * len(cat)
+ # Find the index of category_value in categories and set the corresponding position to 1
+ if cat_value in cat:
+ encoding[cat.index(cat_value)] = 1
+ return encoding
+
+
+for name, value in sample.items():
+ if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ sample.update(
+ {
+ name: get_cat_encoding(
+ CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
+ )
+ }
+ )
+# Convert inputs to tensors
+input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
+predictions = best_model.predict(input_dict)
+
+print(
+ f"This particular patient had a {100 * predictions[0][0]:.1f} "
+ "percent probability of having a heart disease, "
+ "as evaluated by our model."
+)
diff --git a/knowledge_base/structured_data/classification_with_grn_and_vsn.py b/knowledge_base/structured_data/classification_with_grn_and_vsn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1209cb137b1cef0b9f5dcf13707db6e483091d5d
--- /dev/null
+++ b/knowledge_base/structured_data/classification_with_grn_and_vsn.py
@@ -0,0 +1,529 @@
+"""
+Title: Classification with Gated Residual and Variable Selection Networks
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/02/10
+Last modified: 2025/01/08
+Description: Using Gated Residual and Variable Selection Networks for income level prediction.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+This example demonstrates the use of Gated
+Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by
+Bryan Lim et al. in
+[Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/abs/1912.09363),
+for structured data classification. GRNs give the model the flexibility to apply
+non-linear processing only where needed. VSNs allow the model to softly remove any
+unnecessary noisy inputs which could negatively impact performance.
+Together, those techniques help improve the learning capacity of deep neural
+network models.
+
+Note that this example implements only the GRN and VSN components described
+in the paper, rather than the whole TFT model, as GRN and VSN can be useful on
+their own for structured data learning tasks.
+
+
+To run the code you need to use TensorFlow 2.3 or higher.
+"""
+
+"""
+## The dataset
+
+This example uses the
+[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29)
+provided by the
+[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
+The task is binary classification to determine whether a person makes over 50K a year.
+
+The dataset includes ~300K instances with 41 input features: 7 numerical features
+and 34 categorical features.
+"""
+
+"""
+## Setup
+"""
+
+import os
+import subprocess
+import tarfile
+
+os.environ["KERAS_BACKEND"] = "torch" # or jax, or tensorflow
+
+import numpy as np
+import pandas as pd
+import keras
+from keras import layers
+
+"""
+## Prepare the data
+
+First we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.
+"""
+
+# Column names.
+CSV_HEADER = [
+ "age",
+ "class_of_worker",
+ "detailed_industry_recode",
+ "detailed_occupation_recode",
+ "education",
+ "wage_per_hour",
+ "enroll_in_edu_inst_last_wk",
+ "marital_stat",
+ "major_industry_code",
+ "major_occupation_code",
+ "race",
+ "hispanic_origin",
+ "sex",
+ "member_of_a_labor_union",
+ "reason_for_unemployment",
+ "full_or_part_time_employment_stat",
+ "capital_gains",
+ "capital_losses",
+ "dividends_from_stocks",
+ "tax_filer_stat",
+ "region_of_previous_residence",
+ "state_of_previous_residence",
+ "detailed_household_and_family_stat",
+ "detailed_household_summary_in_household",
+ "instance_weight",
+ "migration_code-change_in_msa",
+ "migration_code-change_in_reg",
+ "migration_code-move_within_reg",
+ "live_in_this_house_1_year_ago",
+ "migration_prev_res_in_sunbelt",
+ "num_persons_worked_for_employer",
+ "family_members_under_18",
+ "country_of_birth_father",
+ "country_of_birth_mother",
+ "country_of_birth_self",
+ "citizenship",
+ "own_business_or_self_employed",
+ "fill_inc_questionnaire_for_veterans_admin",
+ "veterans_benefits",
+ "weeks_worked_in_year",
+ "year",
+ "income_level",
+]
+
+data_url = "https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip"
+keras.utils.get_file(origin=data_url, extract=True)
+
+"""
+Determine the path of the downloaded `.tar.gz` file and extract its contents.
+"""
+
+extracted_path = os.path.join(
+ os.path.expanduser("~"), ".keras", "datasets", "census+income+kdd.zip"
+)
+for root, dirs, files in os.walk(extracted_path):
+ for file in files:
+ if file.endswith(".tar.gz"):
+ tar_gz_path = os.path.join(root, file)
+ with tarfile.open(tar_gz_path, "r:gz") as tar:
+ tar.extractall(path=root)
+
+train_data_path = os.path.join(
+ os.path.expanduser("~"),
+ ".keras",
+ "datasets",
+ "census+income+kdd.zip",
+ "census-income.data",
+)
+test_data_path = os.path.join(
+ os.path.expanduser("~"),
+ ".keras",
+ "datasets",
+ "census+income+kdd.zip",
+ "census-income.test",
+)
+
+data = pd.read_csv(train_data_path, header=None, names=CSV_HEADER)
+test_data = pd.read_csv(test_data_path, header=None, names=CSV_HEADER)
+
+print(f"Data shape: {data.shape}")
+print(f"Test data shape: {test_data.shape}")
+
+
+"""
+We convert the target column from string to integer.
+"""
+
+data["income_level"] = data["income_level"].apply(
+ lambda x: 0 if x == " - 50000." else 1
+)
+test_data["income_level"] = test_data["income_level"].apply(
+ lambda x: 0 if x == " - 50000." else 1
+)
+
+
+"""
+Then, we split the dataset into training and validation sets.
+"""
+
+random_selection = np.random.rand(len(data.index)) <= 0.85
+train_data = data[random_selection]
+valid_data = data[~random_selection]
+
+
+"""
+Finally, we store the training, validation, and test splits locally as CSV files.
+"""
+
+train_data_file = "train_data.csv"
+valid_data_file = "valid_data.csv"
+test_data_file = "test_data.csv"
+
+train_data.to_csv(train_data_file, index=False, header=False)
+valid_data.to_csv(valid_data_file, index=False, header=False)
+test_data.to_csv(test_data_file, index=False, header=False)
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading and
+parsing the data into input features, and encoding the input features with respect
+to their types.
+"""
+
+# Target feature name.
+TARGET_FEATURE_NAME = "income_level"
+# Weight column name.
+WEIGHT_COLUMN_NAME = "instance_weight"
+# Numeric feature names.
+NUMERIC_FEATURE_NAMES = [
+ "age",
+ "wage_per_hour",
+ "capital_gains",
+ "capital_losses",
+ "dividends_from_stocks",
+ "num_persons_worked_for_employer",
+ "weeks_worked_in_year",
+]
+# Categorical features and their vocabulary lists.
+# Note that we cast all categorical feature values to strings to make sure
+# they are treated consistently during vocabulary lookup.
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ feature_name: sorted([str(value) for value in list(data[feature_name].unique())])
+ for feature_name in CSV_HEADER
+ if feature_name
+ not in list(NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME, TARGET_FEATURE_NAME])
+}
+# All feature names.
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
+ CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
+)
+# Feature default values.
+COLUMN_DEFAULTS = [
+ (
+ [0.0]
+ if feature_name
+ in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]
+ else ["NA"]
+ )
+ for feature_name in CSV_HEADER
+]
+
+"""
+## Create a `tf.data.Dataset` for training and evaluation
+
+We create an input function to read and parse the file, and convert features and
+labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets) for
+training and evaluation.
+"""
+
+# TensorFlow is required for the tf.data.Dataset pipelines
+import tensorflow as tf
+
+
+# We preprocess the categorical elements of our datasets here, converting them to
+# integer indices, so this step is not needed during model training
+# (only the TensorFlow backend supports string tensors).
+def process(features, target):
+ for feature_name in features:
+ if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ # Cast categorical feature values to string.
+ features[feature_name] = tf.cast(features[feature_name], "string")
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ index = layers.StringLookup(
+ vocabulary=vocabulary,
+ mask_token=None,
+ num_oov_indices=0,
+ output_mode="int",
+ )
+ # Convert the string input values into integer indices.
+ value_index = index(features[feature_name])
+ features[feature_name] = value_index
+ else:
+ # Do nothing for numerical features
+ pass
+
+ # Get the instance weight.
+ weight = features.pop(WEIGHT_COLUMN_NAME)
+    # Convert the features from an OrderedDict to a plain dict, to match the model inputs.
+ return dict(features), target, weight
+
+
+def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
+ dataset = tf.data.experimental.make_csv_dataset(
+ csv_file_path,
+ batch_size=batch_size,
+ column_names=CSV_HEADER,
+ column_defaults=COLUMN_DEFAULTS,
+ label_name=TARGET_FEATURE_NAME,
+ num_epochs=1,
+ header=False,
+ shuffle=shuffle,
+ ).map(process)
+
+ return dataset
+
+
+"""
+## Create model inputs
+"""
+
+
+def create_model_inputs():
+ inputs = {}
+ for feature_name in FEATURE_NAMES:
+ if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+            # Categorical features are encoded as int64 indices (whole units).
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="int64"
+ )
+ else:
+            # Numerical features are float32 real numbers.
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="float32"
+ )
+ return inputs
+
+
+"""
+## Implement the Gated Linear Unit
+
+[Gated Linear Units (GLUs)](https://arxiv.org/abs/1612.08083) provide the
+flexibility to suppress inputs that are not relevant for a given task.
+"""
+
+
+class GatedLinearUnit(layers.Layer):
+ def __init__(self, units):
+ super().__init__()
+ self.linear = layers.Dense(units)
+ self.sigmoid = layers.Dense(units, activation="sigmoid")
+
+ def call(self, inputs):
+ return self.linear(inputs) * self.sigmoid(inputs)
+
+ # Remove build warnings
+ def build(self):
+ self.built = True
+
+
+"""
+## Implement the Gated Residual Network
+
+The Gated Residual Network (GRN) works as follows:
+
+1. Applies the nonlinear ELU transformation to the inputs.
+2. Applies a linear transformation followed by dropout.
+3. Applies a GLU and adds the original inputs to the output of the GLU to form a skip
+(residual) connection.
+4. Applies layer normalization and produces the output.
+"""
+
+
+class GatedResidualNetwork(layers.Layer):
+ def __init__(self, units, dropout_rate):
+ super().__init__()
+ self.units = units
+ self.elu_dense = layers.Dense(units, activation="elu")
+ self.linear_dense = layers.Dense(units)
+ self.dropout = layers.Dropout(dropout_rate)
+ self.gated_linear_unit = GatedLinearUnit(units)
+ self.layer_norm = layers.LayerNormalization()
+ self.project = layers.Dense(units)
+
+ def call(self, inputs):
+ x = self.elu_dense(inputs)
+ x = self.linear_dense(x)
+ x = self.dropout(x)
+ if inputs.shape[-1] != self.units:
+ inputs = self.project(inputs)
+ x = inputs + self.gated_linear_unit(x)
+ x = self.layer_norm(x)
+ return x
+
+ # Remove build warnings
+ def build(self):
+ self.built = True
+
+
+"""
+## Implement the Variable Selection Network
+
+The Variable Selection Network (VSN) works as follows:
+
+1. Applies a GRN to each feature individually.
+2. Applies a GRN on the concatenation of all the features, followed by a softmax to
+produce feature weights.
+3. Produces a weighted sum of the outputs of the individual GRNs.
+
+Note that the output of the VSN is [batch_size, encoding_size], regardless of the
+number of input features.
+
+For categorical features, we encode them using `layers.Embedding`, with
+`encoding_size` as the embedding dimension. For the numerical features,
+we apply a linear transformation using `layers.Dense` to project each feature into
+an `encoding_size`-dimensional vector. Thus, all the encoded features have the
+same dimensionality.
+
+"""
+
+
+class VariableSelection(layers.Layer):
+ def __init__(self, num_features, units, dropout_rate):
+ super().__init__()
+ self.units = units
+        # Create an embedding layer for each categorical feature, with the specified dimensions.
+ self.embeddings = dict()
+ for input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[input_]
+ embedding_encoder = layers.Embedding(
+ input_dim=len(vocabulary), output_dim=self.units, name=input_
+ )
+ self.embeddings[input_] = embedding_encoder
+
+ # Projection layers for numeric features
+ self.proj_layer = dict()
+ for input_ in NUMERIC_FEATURE_NAMES:
+ proj_layer = layers.Dense(units=self.units)
+ self.proj_layer[input_] = proj_layer
+
+ self.grns = list()
+ # Create a GRN for each feature independently
+ for idx in range(num_features):
+ grn = GatedResidualNetwork(units, dropout_rate)
+ self.grns.append(grn)
+ # Create a GRN for the concatenation of all the features
+ self.grn_concat = GatedResidualNetwork(units, dropout_rate)
+ self.softmax = layers.Dense(units=num_features, activation="softmax")
+
+ def call(self, inputs):
+ concat_inputs = []
+ for input_ in inputs:
+ if input_ in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+                # Clamp the indices to the valid range: the torch backend raised
+                # index errors for out-of-range embedding lookups.
+                max_index = self.embeddings[input_].input_dim - 1
+ embedded_feature = self.embeddings[input_](
+ keras.ops.clip(inputs[input_], 0, max_index)
+ )
+ concat_inputs.append(embedded_feature)
+ else:
+ # Project the numeric feature to encoding_size using linear transformation.
+ proj_feature = keras.ops.expand_dims(inputs[input_], -1)
+ proj_feature = self.proj_layer[input_](proj_feature)
+ concat_inputs.append(proj_feature)
+
+ v = layers.concatenate(concat_inputs)
+ v = self.grn_concat(v)
+ v = keras.ops.expand_dims(self.softmax(v), axis=-1)
+ x = []
+        for idx, feature in enumerate(concat_inputs):
+            x.append(self.grns[idx](feature))
+ x = keras.ops.stack(x, axis=1)
+ return keras.ops.squeeze(
+ keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
+ )
+
+    # Remove build warnings
+ def build(self):
+ self.built = True
+
+
+"""
+## Create Gated Residual and Variable Selection Networks model
+"""
+
+
+def create_model(encoding_size):
+ inputs = create_model_inputs()
+ num_features = len(inputs)
+ features = VariableSelection(num_features, encoding_size, dropout_rate)(inputs)
+ outputs = layers.Dense(units=1, activation="sigmoid")(features)
+ # Functional model
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+## Compile, train, and evaluate the model
+"""
+
+learning_rate = 0.001
+dropout_rate = 0.15
+batch_size = 265
+num_epochs = 20 # may be adjusted to a desired value
+encoding_size = 16
+
+model = create_model(encoding_size)
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
+ loss=keras.losses.BinaryCrossentropy(),
+ metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
+)
+
+"""
+Let's visualize our connectivity graph:
+"""
+
+# `rankdir='LR'` is to make the graph horizontal.
+keras.utils.plot_model(model, show_shapes=True, show_layer_names=True, rankdir="LR")
+
+
+# Create an early stopping callback.
+early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=5, restore_best_weights=True
+)
+
+print("Start training the model...")
+train_dataset = get_dataset_from_csv(
+ train_data_file, shuffle=True, batch_size=batch_size
+)
+valid_dataset = get_dataset_from_csv(valid_data_file, batch_size=batch_size)
+model.fit(
+ train_dataset,
+ epochs=num_epochs,
+ validation_data=valid_dataset,
+ callbacks=[early_stopping],
+)
+print("Model training finished.")
+
+print("Evaluating model performance...")
+test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)
+_, accuracy = model.evaluate(test_dataset)
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+"""
+You should achieve more than 95% accuracy on the test set.
+
+To increase the learning capacity of the model, you can try increasing the
+`encoding_size` value, or stacking multiple GRN layers on top of the VSN layer.
+This may also require increasing the `dropout_rate` value to avoid overfitting.
+A sketch of such stacking follows.
+"""
+
+"""
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/structured-data-classification-grn-vsn) | [](https://huggingface.co/spaces/keras-io/structured-data-classification-grn-vsn) |
+"""
diff --git a/knowledge_base/structured_data/classification_with_tfdf.py b/knowledge_base/structured_data/classification_with_tfdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61285809f46ec4ad1a635bb8f9d643d4bfc2f36
--- /dev/null
+++ b/knowledge_base/structured_data/classification_with_tfdf.py
@@ -0,0 +1,670 @@
+"""
+Title: Classification with TensorFlow Decision Forests
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2022/01/25
+Last modified: 2022/01/25
+Description: Using TensorFlow Decision Forests for structured data classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+[TensorFlow Decision Forests](https://www.tensorflow.org/decision_forests)
+is a collection of state-of-the-art algorithms for Decision Forest models
+that are compatible with Keras APIs.
+The models include [Random Forests](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/RandomForestModel),
+[Gradient Boosted Trees](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel),
+and [CART](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/CartModel),
+and can be used for regression, classification, and ranking tasks.
+For a beginner's guide to TensorFlow Decision Forests,
+please refer to this [tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab).
+
+
+This example uses a Gradient Boosted Trees model for binary classification of
+structured data, and covers the following scenarios:
+
+1. Build a decision forests model by specifying the input feature usage.
+2. Implement a custom *Binary Target encoder* as a [Keras Preprocessing layer](https://keras.io/api/layers/preprocessing_layers/)
+to encode the categorical features with respect to their target value co-occurrences,
+and then use the encoded features to build a decision forests model.
+3. Encode the categorical features as [embeddings](https://keras.io/api/layers/core_layers/embedding),
+train these embeddings in a simple NN model, and then use the
+trained embeddings as inputs to build a decision forests model.
+
+This example uses TensorFlow 2.7 or higher,
+as well as [TensorFlow Decision Forests](https://www.tensorflow.org/decision_forests),
+which you can install using the following command:
+
+```shell
+pip install -U tensorflow_decision_forests
+```
+"""
+
+"""
+## Setup
+"""
+
+import math
+import urllib
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import tensorflow_decision_forests as tfdf
+
+"""
+## Prepare the data
+
+This example uses the
+[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census-Income+%28KDD%29)
+provided by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
+The task is binary classification to determine whether a person makes over 50K a year.
+
+The dataset includes ~300K instances with 41 input features: 7 numerical features
+and 34 categorical features.
+
+First we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.
+"""
+
+BASE_PATH = "https://kdd.ics.uci.edu/databases/census-income/census-income"
+CSV_HEADER = [
+ l.decode("utf-8").split(":")[0].replace(" ", "_")
+ for l in urllib.request.urlopen(f"{BASE_PATH}.names")
+ if not l.startswith(b"|")
+][2:]
+CSV_HEADER.append("income_level")
+
+train_data = pd.read_csv(
+ f"{BASE_PATH}.data.gz",
+ header=None,
+ names=CSV_HEADER,
+)
+test_data = pd.read_csv(
+ f"{BASE_PATH}.test.gz",
+ header=None,
+ names=CSV_HEADER,
+)
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for encoding
+the input features with respect to their types.
+"""
+
+# Target column name.
+TARGET_COLUMN_NAME = "income_level"
+# The labels of the target columns.
+TARGET_LABELS = [" - 50000.", " 50000+."]
+# Weight column name.
+WEIGHT_COLUMN_NAME = "instance_weight"
+# Numeric feature names.
+NUMERIC_FEATURE_NAMES = [
+ "age",
+ "wage_per_hour",
+ "capital_gains",
+ "capital_losses",
+ "dividends_from_stocks",
+ "num_persons_worked_for_employer",
+ "weeks_worked_in_year",
+]
+# Categorical features and their vocabulary lists.
+CATEGORICAL_FEATURE_NAMES = [
+ "class_of_worker",
+ "detailed_industry_recode",
+ "detailed_occupation_recode",
+ "education",
+ "enroll_in_edu_inst_last_wk",
+ "marital_stat",
+ "major_industry_code",
+ "major_occupation_code",
+ "race",
+ "hispanic_origin",
+ "sex",
+ "member_of_a_labor_union",
+ "reason_for_unemployment",
+ "full_or_part_time_employment_stat",
+ "tax_filer_stat",
+ "region_of_previous_residence",
+ "state_of_previous_residence",
+ "detailed_household_and_family_stat",
+ "detailed_household_summary_in_household",
+ "migration_code-change_in_msa",
+ "migration_code-change_in_reg",
+ "migration_code-move_within_reg",
+ "live_in_this_house_1_year_ago",
+ "migration_prev_res_in_sunbelt",
+ "family_members_under_18",
+ "country_of_birth_father",
+ "country_of_birth_mother",
+ "country_of_birth_self",
+ "citizenship",
+ "own_business_or_self_employed",
+ "fill_inc_questionnaire_for_veteran's_admin",
+ "veterans_benefits",
+ "year",
+]
+
+
+"""
+Now we perform basic data preparation.
+"""
+
+
+def prepare_dataframe(dataframe):
+ # Convert the target labels from string to integer.
+ dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
+ TARGET_LABELS.index
+ )
+ # Cast the categorical features to string.
+ for feature_name in CATEGORICAL_FEATURE_NAMES:
+ dataframe[feature_name] = dataframe[feature_name].astype(str)
+
+
+prepare_dataframe(train_data)
+prepare_dataframe(test_data)
+
+"""
+Now let's show the shapes of the training and test dataframes, and display some instances.
+"""
+
+print(f"Train data shape: {train_data.shape}")
+print(f"Test data shape: {test_data.shape}")
+print(train_data.head().T)
+
+"""
+## Configure hyperparameters
+
+You can find all the parameters of the Gradient Boosted Tree model in the
+[documentation](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel)
+"""
+
+# Maximum number of decision trees. The effective number of trained trees can be smaller if early stopping is enabled.
+NUM_TREES = 250
+# Minimum number of examples in a node.
+MIN_EXAMPLES = 6
+# Maximum depth of the tree. max_depth=1 means that all trees will be roots.
+MAX_DEPTH = 5
+# Ratio of the dataset (sampling without replacement) used to train individual trees for the random sampling method.
+SUBSAMPLE = 0.65
+# Control the sampling of the datasets used to train individual trees.
+SAMPLING_METHOD = "RANDOM"
+# Ratio of the training dataset used to monitor the training. Required to be >0 if early stopping is enabled.
+VALIDATION_RATIO = 0.1
+
+"""
+## Implement a training and evaluation procedure
+
+The `run_experiment()` method is responsible for loading the train and test datasets,
+training a given model, and evaluating the trained model.
+
+Note that when training a Decision Forests model, only one epoch is needed to
+read the full dataset. Any extra steps will only result in unnecessarily slower training.
+Therefore, the default `num_epochs=1` is used in the `run_experiment()` method.
+"""
+
+
+def run_experiment(model, train_data, test_data, num_epochs=1, batch_size=None):
+ train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
+ train_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
+ )
+ test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
+ test_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
+ )
+
+ model.fit(train_dataset, epochs=num_epochs, batch_size=batch_size)
+ _, accuracy = model.evaluate(test_dataset, verbose=0)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+
+"""
+## Experiment 1: Decision Forests with raw features
+"""
+
+"""
+### Specify model input feature usages
+
+You can attach semantics to each feature to control how it is used by the model.
+If not specified, the semantics are inferred from the representation type.
+It is recommended to specify the [feature usages](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/FeatureUsage)
+explicitly to avoid incorrectly inferred semantics.
+For example, a categorical value identifier (integer) will be inferred as numerical,
+while it is semantically categorical.
+
+For numerical features, you can set the `discretized` parameter to the number
+of buckets by which the numerical feature should be discretized.
+This makes the training faster but may lead to worse models.
+"""
+
+
+def specify_feature_usages():
+ feature_usages = []
+
+ for feature_name in NUMERIC_FEATURE_NAMES:
+ feature_usage = tfdf.keras.FeatureUsage(
+ name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
+ )
+ feature_usages.append(feature_usage)
+
+ for feature_name in CATEGORICAL_FEATURE_NAMES:
+ feature_usage = tfdf.keras.FeatureUsage(
+ name=feature_name, semantic=tfdf.keras.FeatureSemantic.CATEGORICAL
+ )
+ feature_usages.append(feature_usage)
+
+ return feature_usages
+
+
+"""
+### Create a Gradient Boosted Trees model
+
+When compiling a decision forests model, you may only provide extra evaluation metrics.
+The loss is specified in the model construction,
+and the optimizer is irrelevant to decision forests models.
+"""
+
+
+def create_gbt_model():
+ # See all the model parameters in https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel
+ gbt_model = tfdf.keras.GradientBoostedTreesModel(
+ features=specify_feature_usages(),
+ exclude_non_specified_features=True,
+ num_trees=NUM_TREES,
+ max_depth=MAX_DEPTH,
+ min_examples=MIN_EXAMPLES,
+ subsample=SUBSAMPLE,
+ validation_ratio=VALIDATION_RATIO,
+ task=tfdf.keras.Task.CLASSIFICATION,
+ )
+
+ gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
+ return gbt_model
+
+
+"""
+### Train and evaluate the model
+"""
+
+gbt_model = create_gbt_model()
+run_experiment(gbt_model, train_data, test_data)
+
+"""
+### Inspect the model
+
+The `model.summary()` method will display several types of information about
+your decision trees model: the model type, task, input features, and feature importances.
+"""
+
+print(gbt_model.summary())
+
+"""
+## Experiment 2: Decision Forests with target encoding
+
+[Target encoding](https://dl.acm.org/doi/10.1145/507533.507538) is a common preprocessing
+technique for categorical features that converts them into numerical features.
+Using categorical features with high cardinality as-is may lead to overfitting.
+Target encoding aims to replace each categorical feature value with one or more
+numerical values that represent its co-occurrence with the target labels.
+
+More precisely, given a categorical feature, the binary target encoder in this example
+will produce three new numerical features:
+
+1. `positive_frequency`: How many times each feature value occurred with a positive target label.
+2. `negative_frequency`: How many times each feature value occurred with a negative target label.
+3. `positive_probability`: The probability that the target label is positive,
+given the feature value, which is computed as
+`positive_frequency / (positive_frequency + negative_frequency + correction)`.
+The `correction` term is added to make the division more stable for rare categorical values.
+The default value for `correction` is 1.0.
+
+
+
+Note that target encoding is effective with models that cannot automatically
+learn dense representations of categorical features, such as decision forests
+or kernel methods. If neural network models are used, it is recommended to
+encode categorical features as embeddings.
+"""
+
+"""
+### Implement Binary Target Encoder
+
+For simplicity, we assume that the inputs for the `adapt` and `call` methods
+are in the expected data types and shapes, so no validation logic is added.
+
+It is recommended to pass the `vocabulary_size` of the categorical feature to the
+`BinaryTargetEncoding` constructor. If not specified, it will be computed during
+the `adapt()` method execution.
+"""
+
+
+class BinaryTargetEncoding(layers.Layer):
+    def __init__(self, vocabulary_size=None, correction=1.0, **kwargs):
+        super().__init__(**kwargs)
+        self.vocabulary_size = vocabulary_size
+        self.correction = correction
+        # Computed by `adapt()`; `call()` raises an error until it is set.
+        self.target_encoding_statistics = None
+
+ def adapt(self, data):
+        # data is expected to be an integer numpy array or a Tensor of shape [num_examples, 2].
+ # This contains feature values for a given feature in the dataset, and target values.
+
+ # Convert the data to a tensor.
+ data = tf.convert_to_tensor(data)
+ # Separate the feature values and target values
+ feature_values = tf.cast(data[:, 0], tf.dtypes.int32)
+ target_values = tf.cast(data[:, 1], tf.dtypes.bool)
+
+        # Compute the vocabulary_size if not specified.
+ if self.vocabulary_size is None:
+ self.vocabulary_size = tf.unique(feature_values).y.shape[0]
+
+ # Filter the data where the target label is positive.
+ positive_indices = tf.where(condition=target_values)
+ positive_feature_values = tf.gather_nd(
+ params=feature_values, indices=positive_indices
+ )
+ # Compute how many times each feature value occurred with a positive target label.
+ positive_frequency = tf.math.unsorted_segment_sum(
+ data=tf.ones(
+ shape=(positive_feature_values.shape[0], 1), dtype=tf.dtypes.float64
+ ),
+ segment_ids=positive_feature_values,
+ num_segments=self.vocabulary_size,
+ )
+
+ # Filter the data where the target label is negative.
+ negative_indices = tf.where(condition=tf.math.logical_not(target_values))
+ negative_feature_values = tf.gather_nd(
+ params=feature_values, indices=negative_indices
+ )
+ # Compute how many times each feature value occurred with a negative target label.
+ negative_frequency = tf.math.unsorted_segment_sum(
+ data=tf.ones(
+ shape=(negative_feature_values.shape[0], 1), dtype=tf.dtypes.float64
+ ),
+ segment_ids=negative_feature_values,
+ num_segments=self.vocabulary_size,
+ )
+ # Compute positive probability for the input feature values.
+ positive_probability = positive_frequency / (
+ positive_frequency + negative_frequency + self.correction
+ )
+        # Concatenate the computed statistics for target_encoding.
+ target_encoding_statistics = tf.cast(
+ tf.concat(
+ [positive_frequency, negative_frequency, positive_probability], axis=1
+ ),
+ dtype=tf.dtypes.float32,
+ )
+ self.target_encoding_statistics = tf.constant(target_encoding_statistics)
+
+ def call(self, inputs):
+        # inputs is expected to be an integer numpy array or a Tensor of shape [num_examples, 1].
+ # This includes the feature values for a given feature in the dataset.
+
+ # Raise an error if the target encoding statistics are not computed.
+        if self.target_encoding_statistics is None:
+            raise ValueError(
+                "You need to call the adapt method to compute target encoding statistics."
+ )
+
+ # Convert the inputs to a tensor.
+ inputs = tf.convert_to_tensor(inputs)
+        # Cast the inputs to int64.
+ inputs = tf.cast(inputs, tf.dtypes.int64)
+ # Lookup target encoding statistics for the input feature values.
+ target_encoding_statistics = tf.cast(
+ tf.gather_nd(self.target_encoding_statistics, inputs),
+ dtype=tf.dtypes.float32,
+ )
+ return target_encoding_statistics
+
+
+"""
+Let's test the binary target encoder.
+"""
+
+data = tf.constant(
+ [
+ [0, 1],
+ [2, 0],
+ [0, 1],
+ [1, 1],
+ [1, 1],
+ [2, 0],
+ [1, 0],
+ [0, 1],
+ [2, 1],
+ [1, 0],
+ [0, 1],
+ [2, 0],
+ [0, 1],
+ [1, 1],
+ [1, 1],
+ [2, 0],
+ [1, 0],
+ [0, 1],
+ [2, 0],
+ ]
+)
+
+binary_target_encoder = BinaryTargetEncoding()
+binary_target_encoder.adapt(data)
+print(binary_target_encoder([[0], [1], [2]]))
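+
+"""
+As a sanity check on the toy data above: feature value `0` co-occurs with a
+positive label 6 times and with a negative label 0 times, so with the default
+`correction=1.0` its encoding is `[6., 0., 6 / (6 + 0 + 1)]`, i.e. roughly
+`[6., 0., 0.857]`. Likewise, value `1` yields `[4., 3., 0.5]` and value `2`
+yields approximately `[1., 5., 0.143]`.
+"""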
+
+"""
+### Create model inputs
+"""
+
+
+def create_model_inputs():
+ inputs = {}
+
+ for feature_name in NUMERIC_FEATURE_NAMES:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype=tf.float32
+ )
+
+ for feature_name in CATEGORICAL_FEATURE_NAMES:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype=tf.string
+ )
+
+ return inputs
+
+
+"""
+### Implement a feature encoding with target encoding
+"""
+
+
+def create_target_encoder():
+ inputs = create_model_inputs()
+ target_values = train_data[[TARGET_COLUMN_NAME]].to_numpy()
+ encoded_features = []
+ for feature_name in inputs:
+ if feature_name in CATEGORICAL_FEATURE_NAMES:
+ # Get the vocabulary of the categorical feature.
+ vocabulary = sorted(
+ [str(value) for value in list(train_data[feature_name].unique())]
+ )
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ lookup = layers.StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ # Convert the string input values into integer indices.
+ value_indices = lookup(inputs[feature_name])
+ # Prepare the data to adapt the target encoding.
+ print("### Adapting target encoding for:", feature_name)
+ feature_values = train_data[[feature_name]].to_numpy().astype(str)
+ feature_value_indices = lookup(feature_values)
+ data = tf.concat([feature_value_indices, target_values], axis=1)
+ feature_encoder = BinaryTargetEncoding()
+ feature_encoder.adapt(data)
+ # Convert the feature value indices to target encoding representations.
+ encoded_feature = feature_encoder(tf.expand_dims(value_indices, -1))
+ else:
+ # Expand the dimensions of the numerical input feature and use it as-is.
+ encoded_feature = tf.expand_dims(inputs[feature_name], -1)
+ # Add the encoded feature to the list.
+ encoded_features.append(encoded_feature)
+ # Concatenate all the encoded features.
+ encoded_features = tf.concat(encoded_features, axis=1)
+ # Create and return a Keras model with encoded features as outputs.
+ return keras.Model(inputs=inputs, outputs=encoded_features)
+
+
+"""
+### Create a Gradient Boosted Trees model with a preprocessor
+
+In this scenario, we use the target encoding as a preprocessor for the Gradient Boosted Tree model,
+and let the model infer semantics of the input features.
+"""
+
+
+def create_gbt_with_preprocessor(preprocessor):
+ gbt_model = tfdf.keras.GradientBoostedTreesModel(
+ preprocessing=preprocessor,
+ num_trees=NUM_TREES,
+ max_depth=MAX_DEPTH,
+ min_examples=MIN_EXAMPLES,
+ subsample=SUBSAMPLE,
+ validation_ratio=VALIDATION_RATIO,
+ task=tfdf.keras.Task.CLASSIFICATION,
+ )
+
+ gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
+
+ return gbt_model
+
+
+"""
+### Train and evaluate the model
+"""
+
+gbt_model = create_gbt_with_preprocessor(create_target_encoder())
+run_experiment(gbt_model, train_data, test_data)
+
+"""
+## Experiment 3: Decision Forests with trained embeddings
+
+In this scenario, we build an encoder model that encodes the categorical
+features as embeddings, where the embedding size for a given categorical
+feature is the square root of the size of its vocabulary.
+
+We train these embeddings in a simple NN model through backpropagation.
+After the embedding encoder is trained, we use it as a preprocessor for the
+input features of a Gradient Boosted Trees model.
+
+Note that the embeddings and a decision forest model cannot be trained
+synergistically in one phase, since decision forest models do not train with backpropagation.
+Rather, the embeddings have to be trained in an initial phase,
+and then used as static inputs to the decision forest model.
+"""
+
+"""
+### Implement feature encoding with embeddings
+"""
+
+
+def create_embedding_encoder(size=None):
+ inputs = create_model_inputs()
+ encoded_features = []
+ for feature_name in inputs:
+ if feature_name in CATEGORICAL_FEATURE_NAMES:
+ # Get the vocabulary of the categorical feature.
+ vocabulary = sorted(
+ [str(value) for value in list(train_data[feature_name].unique())]
+ )
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ lookup = layers.StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ # Convert the string input values into integer indices.
+ value_index = lookup(inputs[feature_name])
+ # Create an embedding layer with the specified dimensions
+ vocabulary_size = len(vocabulary)
+ embedding_size = int(math.sqrt(vocabulary_size))
+ feature_encoder = layers.Embedding(
+ input_dim=len(vocabulary), output_dim=embedding_size
+ )
+ # Convert the index values to embedding representations.
+ encoded_feature = feature_encoder(value_index)
+ else:
+ # Expand the dimensions of the numerical input feature and use it as-is.
+ encoded_feature = tf.expand_dims(inputs[feature_name], -1)
+ # Add the encoded feature to the list.
+ encoded_features.append(encoded_feature)
+ # Concatenate all the encoded features.
+ encoded_features = layers.concatenate(encoded_features, axis=1)
+ # Apply dropout.
+ encoded_features = layers.Dropout(rate=0.25)(encoded_features)
+ # Perform non-linearity projection.
+ encoded_features = layers.Dense(
+ units=size if size else encoded_features.shape[-1], activation="gelu"
+ )(encoded_features)
+ # Create and return a Keras model with encoded features as outputs.
+ return keras.Model(inputs=inputs, outputs=encoded_features)
+
+
+"""
+### Build an NN model to train the embeddings
+"""
+
+
+def create_nn_model(encoder):
+ inputs = create_model_inputs()
+ embeddings = encoder(inputs)
+ output = layers.Dense(units=1, activation="sigmoid")(embeddings)
+
+ nn_model = keras.Model(inputs=inputs, outputs=output)
+ nn_model.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.BinaryCrossentropy(),
+ metrics=[keras.metrics.BinaryAccuracy("accuracy")],
+ )
+ return nn_model
+
+
+embedding_encoder = create_embedding_encoder(size=64)
+run_experiment(
+ create_nn_model(embedding_encoder),
+ train_data,
+ test_data,
+ num_epochs=5,
+ batch_size=256,
+)
+
+"""
+### Train and evaluate a Gradient Boosted Tree model with embeddings
+"""
+
+gbt_model = create_gbt_with_preprocessor(embedding_encoder)
+run_experiment(gbt_model, train_data, test_data)
+
+"""
+## Concluding remarks
+
+TensorFlow Decision Forests provide powerful models, especially with structured data.
+In our experiments, the Gradient Boosted Tree model achieved 95.79% test accuracy.
+When using target encoding for the categorical features, the same model achieved 95.81% test accuracy.
+When pretraining embeddings to be used as inputs to the Gradient Boosted Tree model,
+we achieved 95.82% test accuracy.
+
+Decision Forests can be used with Neural Networks, either by
+1) using Neural Networks to learn useful representations of the input data,
+and then using Decision Forests for the supervised learning task, or by
+2) creating an ensemble of both Decision Forests and Neural Network models,
+as sketched below.
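+
+As a minimal sketch of option 2), assuming a trained neural network bound to a
+hypothetical `nn_model` variable (e.g. the model returned by `create_nn_model`
+above) and a `test_dataset` built with `tfdf.keras.pd_dataframe_to_tf_dataset`
+as in `run_experiment`, the predicted probabilities of the two models can simply
+be averaged:
+
+```python
+nn_probs = nn_model.predict(test_dataset)  # Neural network probabilities.
+gbt_probs = gbt_model.predict(test_dataset)  # Decision forest probabilities.
+ensemble_probs = (nn_probs + gbt_probs) / 2  # Uniformly weighted average.
+```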
+
+Note that TensorFlow Decision Forests does not (yet) support hardware accelerators.
+All training and inference is done on the CPU.
+In addition, Decision Forests require a finite dataset that fits in memory
+for their training procedures. However, there are diminishing returns
+for increasing the size of the dataset, and Decision Forests algorithms
+arguably need fewer examples for convergence than large Neural Network models.
+"""
diff --git a/knowledge_base/structured_data/collaborative_filtering_movielens.py b/knowledge_base/structured_data/collaborative_filtering_movielens.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5558ff9c3fa1286ae469551d803d286c3e5c88a
--- /dev/null
+++ b/knowledge_base/structured_data/collaborative_filtering_movielens.py
@@ -0,0 +1,230 @@
+"""
+Title: Collaborative Filtering for Movie Recommendations
+Author: [Siddhartha Banerjee](https://twitter.com/sidd2006)
+Date created: 2020/05/24
+Last modified: 2020/05/24
+Description: Recommending movies using a model trained on Movielens dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates
+[Collaborative filtering](https://en.wikipedia.org/wiki/Collaborative_filtering)
+using the [Movielens dataset](https://www.kaggle.com/c/movielens-100k)
+to recommend movies to users.
+The MovieLens ratings dataset lists the ratings given by a set of users to a set of movies.
+Our goal is to be able to predict ratings for movies a user has not yet watched.
+The movies with the highest predicted ratings can then be recommended to the user.
+
+The steps in the model are as follows:
+
+1. Map user ID to a "user vector" via an embedding matrix
+2. Map movie ID to a "movie vector" via an embedding matrix
+3. Compute the dot product between the user vector and movie vector, to obtain
+a match score between the user and the movie (predicted rating).
+4. Train the embeddings via gradient descent using all known user-movie pairs.
+
+**References:**
+
+- [Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/371920.372071)
+- [Neural Collaborative Filtering](https://dl.acm.org/doi/pdf/10.1145/3038912.3052569)
+"""
+
+import pandas as pd
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+from zipfile import ZipFile
+
+import keras
+from keras import layers
+from keras import ops
+
+"""
+## First, load the data and apply preprocessing
+"""
+
+# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip
+# Use the ratings.csv file
+movielens_data_file_url = (
+ "http://files.grouplens.org/datasets/movielens/ml-latest-small.zip"
+)
+movielens_zipped_file = keras.utils.get_file(
+ "ml-latest-small.zip", movielens_data_file_url, extract=False
+)
+keras_datasets_path = Path(movielens_zipped_file).parents[0]
+movielens_dir = keras_datasets_path / "ml-latest-small"
+
+# Only extract the data the first time the script is run.
+if not movielens_dir.exists():
+ with ZipFile(movielens_zipped_file, "r") as zip:
+ # Extract files
+ print("Extracting all the files now...")
+ zip.extractall(path=keras_datasets_path)
+ print("Done!")
+
+ratings_file = movielens_dir / "ratings.csv"
+df = pd.read_csv(ratings_file)
+
+"""
+First, we need to perform some preprocessing to encode users and movies as integer indices.
+"""
+user_ids = df["userId"].unique().tolist()
+user2user_encoded = {x: i for i, x in enumerate(user_ids)}
+userencoded2user = {i: x for i, x in enumerate(user_ids)}
+movie_ids = df["movieId"].unique().tolist()
+movie2movie_encoded = {x: i for i, x in enumerate(movie_ids)}
+movie_encoded2movie = {i: x for i, x in enumerate(movie_ids)}
+df["user"] = df["userId"].map(user2user_encoded)
+df["movie"] = df["movieId"].map(movie2movie_encoded)
+
+num_users = len(user2user_encoded)
+num_movies = len(movie_encoded2movie)
+df["rating"] = df["rating"].values.astype(np.float32)
+# min and max ratings will be used to normalize the ratings later
+min_rating = min(df["rating"])
+max_rating = max(df["rating"])
+
+print(
+ "Number of users: {}, Number of Movies: {}, Min rating: {}, Max rating: {}".format(
+ num_users, num_movies, min_rating, max_rating
+ )
+)
+
+"""
+## Prepare training and validation data
+"""
+df = df.sample(frac=1, random_state=42)
+x = df[["user", "movie"]].values
+# Normalize the targets between 0 and 1. Makes it easy to train.
+y = df["rating"].apply(lambda x: (x - min_rating) / (max_rating - min_rating)).values
+# Assuming training on 90% of the data and validating on 10%.
+train_indices = int(0.9 * df.shape[0])
+x_train, x_val, y_train, y_val = (
+ x[:train_indices],
+ x[train_indices:],
+ y[:train_indices],
+ y[train_indices:],
+)
+
+"""
+## Create the model
+
+We embed both users and movies into 50-dimensional vectors.
+
+The model computes a match score between user and movie embeddings via a dot product,
+and adds a per-movie and per-user bias. The match score is scaled to the `[0, 1]`
+interval via a sigmoid (since our ratings are normalized to this range).
+"""
+EMBEDDING_SIZE = 50
+
+
+class RecommenderNet(keras.Model):
+ def __init__(self, num_users, num_movies, embedding_size, **kwargs):
+ super().__init__(**kwargs)
+ self.num_users = num_users
+ self.num_movies = num_movies
+ self.embedding_size = embedding_size
+ self.user_embedding = layers.Embedding(
+ num_users,
+ embedding_size,
+ embeddings_initializer="he_normal",
+ embeddings_regularizer=keras.regularizers.l2(1e-6),
+ )
+ self.user_bias = layers.Embedding(num_users, 1)
+ self.movie_embedding = layers.Embedding(
+ num_movies,
+ embedding_size,
+ embeddings_initializer="he_normal",
+ embeddings_regularizer=keras.regularizers.l2(1e-6),
+ )
+ self.movie_bias = layers.Embedding(num_movies, 1)
+
+ def call(self, inputs):
+ user_vector = self.user_embedding(inputs[:, 0])
+ user_bias = self.user_bias(inputs[:, 0])
+ movie_vector = self.movie_embedding(inputs[:, 1])
+ movie_bias = self.movie_bias(inputs[:, 1])
+ dot_user_movie = ops.tensordot(user_vector, movie_vector, 2)
+ # Add all the components (including bias)
+ x = dot_user_movie + user_bias + movie_bias
+        # The sigmoid activation forces the rating to be between 0 and 1
+ return ops.nn.sigmoid(x)
+
+
+model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)
+model.compile(
+ loss=keras.losses.BinaryCrossentropy(),
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+)
+
+"""
+## Train the model based on the data split
+"""
+history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=64,
+ epochs=5,
+ verbose=1,
+ validation_data=(x_val, y_val),
+)
+
+"""
+## Plot training and validation loss
+"""
+plt.plot(history.history["loss"])
+plt.plot(history.history["val_loss"])
+plt.title("model loss")
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.legend(["train", "test"], loc="upper left")
+plt.show()
+
+"""
+## Show top 10 movie recommendations to a user
+"""
+
+movie_df = pd.read_csv(movielens_dir / "movies.csv")
+
+# Let us get a user and see the top recommendations.
+user_id = df.userId.sample(1).iloc[0]
+movies_watched_by_user = df[df.userId == user_id]
+movies_not_watched = movie_df[
+ ~movie_df["movieId"].isin(movies_watched_by_user.movieId.values)
+]["movieId"]
+movies_not_watched = list(
+ set(movies_not_watched).intersection(set(movie2movie_encoded.keys()))
+)
+movies_not_watched = [[movie2movie_encoded.get(x)] for x in movies_not_watched]
+user_encoder = user2user_encoded.get(user_id)
+user_movie_array = np.hstack(
+ ([[user_encoder]] * len(movies_not_watched), movies_not_watched)
+)
+ratings = model.predict(user_movie_array).flatten()
+top_ratings_indices = ratings.argsort()[-10:][::-1]
+recommended_movie_ids = [
+ movie_encoded2movie.get(movies_not_watched[x][0]) for x in top_ratings_indices
+]
+
+print("Showing recommendations for user: {}".format(user_id))
+print("====" * 9)
+print("Movies with high ratings from user")
+print("----" * 8)
+top_movies_user = (
+ movies_watched_by_user.sort_values(by="rating", ascending=False)
+ .head(5)
+ .movieId.values
+)
+movie_df_rows = movie_df[movie_df["movieId"].isin(top_movies_user)]
+for row in movie_df_rows.itertuples():
+ print(row.title, ":", row.genres)
+
+print("----" * 8)
+print("Top 10 movie recommendations")
+print("----" * 8)
+recommended_movies = movie_df[movie_df["movieId"].isin(recommended_movie_ids)]
+for row in recommended_movies.itertuples():
+ print(row.title, ":", row.genres)
diff --git a/knowledge_base/structured_data/customer_lifetime_value.py b/knowledge_base/structured_data/customer_lifetime_value.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca97b17c402a8d221e9882debb05b7be4af683c
--- /dev/null
+++ b/knowledge_base/structured_data/customer_lifetime_value.py
@@ -0,0 +1,561 @@
+"""
+Title: Deep Learning for Customer Lifetime Value
+Author: [Praveen Hosdrug](https://www.linkedin.com/in/praveenhosdrug/)
+Date created: 2024/11/23
+Last modified: 2024/11/27
+Description: A hybrid deep learning architecture for predicting customer purchase patterns and lifetime value.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+A hybrid deep learning architecture combining Transformer encoders and LSTM networks
+for predicting customer purchase patterns and lifetime value using transaction history.
+While many existing review articles focus on classic parametric models and traditional
+machine learning algorithms, this implementation leverages recent advancements in
+Transformer-based models for time series prediction.
+The approach handles multi-granularity prediction across different temporal scales.
+
+"""
+
+"""
+## Setting up Libraries for the Deep Learning Project
+"""
+import subprocess
+
+
+def install_packages(packages):
+ """
+ Install a list of packages using pip.
+
+ Args:
+ packages (list): A list of package names to install.
+ """
+ for package in packages:
+ subprocess.run(["pip", "install", package], check=True)
+
+
+"""
+## List of Packages to Install
+
+1. ucimlrepo: For the purposes of this tutorial, we will be using
+ the UK Retail [Dataset](https://archive.ics.uci.edu/dataset/352/online+retail).
+2. keras_hub: Provides access to the Transformer encoder layer.
+"""
+
+packages_to_install = ["ucimlrepo", "keras_hub"]
+
+# Install the packages
+install_packages(packages_to_install)
+
+# Core data processing and numerical libraries
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+import keras
+import numpy as np
+import pandas as pd
+from typing import Dict
+
+
+# Visualization
+import matplotlib.pyplot as plt
+
+# Keras imports
+from keras import layers
+from keras import Model
+from keras import ops
+from keras_hub.layers import TransformerEncoder
+from keras import regularizers
+
+# UK Retail Dataset
+from ucimlrepo import fetch_ucirepo
+
+"""
+## Preprocessing the UK Retail dataset
+"""
+
+
+def prepare_time_series_data(data):
+ """
+ Preprocess retail transaction data for deep learning.
+
+ Args:
+ data: Raw transaction data containing InvoiceDate, UnitPrice, etc.
+ Returns:
+ Processed DataFrame with calculated features
+ """
+ processed_data = data.copy()
+
+ # Essential datetime handling for temporal ordering
+ processed_data["InvoiceDate"] = pd.to_datetime(processed_data["InvoiceDate"])
+
+ # Basic business constraints and calculations
+ processed_data = processed_data[processed_data["UnitPrice"] > 0]
+ processed_data["Amount"] = processed_data["UnitPrice"] * processed_data["Quantity"]
+ processed_data["CustomerID"] = processed_data["CustomerID"].fillna(99999.0)
+
+ # Handle outliers in Amount using statistical thresholds
+ q1 = processed_data["Amount"].quantile(0.25)
+ q3 = processed_data["Amount"].quantile(0.75)
+
+ # Define bounds - using 1.5 IQR rule
+ lower_bound = q1 - 1.5 * (q3 - q1)
+ upper_bound = q3 + 1.5 * (q3 - q1)
+
+ # Filter outliers
+ processed_data = processed_data[
+ (processed_data["Amount"] >= lower_bound)
+ & (processed_data["Amount"] <= upper_bound)
+ ]
+
+ return processed_data
+
+
+# Load Data
+
+online_retail = fetch_ucirepo(id=352)
+raw_data = online_retail.data.features
+transformed_data = prepare_time_series_data(raw_data)
+
+
+def prepare_data_for_modeling(
+ df: pd.DataFrame,
+ input_sequence_length: int = 6,
+ output_sequence_length: int = 6,
+) -> Dict:
+ """
+ Transform retail data into sequence-to-sequence format with separate
+ temporal and trend components.
+ """
+ df = df.copy()
+
+ # Daily aggregation
+ daily_purchases = (
+ df.groupby(["CustomerID", pd.Grouper(key="InvoiceDate", freq="D")])
+ .agg({"Amount": "sum", "Quantity": "sum", "Country": "first"})
+ .reset_index()
+ )
+
+ daily_purchases["frequency"] = np.where(daily_purchases["Amount"] > 0, 1, 0)
+
+ # Monthly resampling
+ monthly_purchases = (
+ daily_purchases.set_index("InvoiceDate")
+ .groupby("CustomerID")
+ .resample("M")
+ .agg(
+ {"Amount": "sum", "Quantity": "sum", "frequency": "sum", "Country": "first"}
+ )
+ .reset_index()
+ )
+
+ # Add cyclical temporal features
+ def prepare_temporal_features(input_window: pd.DataFrame) -> np.ndarray:
+
+ month = input_window["InvoiceDate"].dt.month
+ month_sin = np.sin(2 * np.pi * month / 12)
+ month_cos = np.cos(2 * np.pi * month / 12)
+ is_quarter_start = (month % 3 == 1).astype(int)
+
+ temporal_features = np.column_stack(
+ [
+ month,
+ input_window["InvoiceDate"].dt.year,
+ month_sin,
+ month_cos,
+ is_quarter_start,
+ ]
+ )
+ return temporal_features
+
+ # Prepare trend features with lagged values
+ def prepare_trend_features(input_window: pd.DataFrame, lag: int = 3) -> np.ndarray:
+
+ lagged_data = pd.DataFrame()
+ for i in range(1, lag + 1):
+ lagged_data[f"Amount_lag_{i}"] = input_window["Amount"].shift(i)
+ lagged_data[f"Quantity_lag_{i}"] = input_window["Quantity"].shift(i)
+ lagged_data[f"frequency_lag_{i}"] = input_window["frequency"].shift(i)
+
+ lagged_data = lagged_data.fillna(0)
+
+ trend_features = np.column_stack(
+ [
+ input_window["Amount"].values,
+ input_window["Quantity"].values,
+ input_window["frequency"].values,
+ lagged_data.values,
+ ]
+ )
+ return trend_features
+
+ sequence_containers = {
+ "temporal_sequences": [],
+ "trend_sequences": [],
+ "static_features": [],
+ "output_sequences": [],
+ }
+
+ # Process sequences for each customer
+ for customer_id, customer_data in monthly_purchases.groupby("CustomerID"):
+ customer_data = customer_data.sort_values("InvoiceDate")
+ sequence_ranges = (
+ len(customer_data) - input_sequence_length - output_sequence_length + 1
+ )
+
+ country = customer_data["Country"].iloc[0]
+
+ for i in range(sequence_ranges):
+ input_window = customer_data.iloc[i : i + input_sequence_length]
+ output_window = customer_data.iloc[
+ i
+ + input_sequence_length : i
+ + input_sequence_length
+ + output_sequence_length
+ ]
+
+ if (
+ len(input_window) == input_sequence_length
+ and len(output_window) == output_sequence_length
+ ):
+ temporal_features = prepare_temporal_features(input_window)
+ trend_features = prepare_trend_features(input_window)
+
+ sequence_containers["temporal_sequences"].append(temporal_features)
+ sequence_containers["trend_sequences"].append(trend_features)
+ sequence_containers["static_features"].append(country)
+ sequence_containers["output_sequences"].append(
+ output_window["Amount"].values
+ )
+
+ return {
+ "temporal_sequences": (
+ np.array(sequence_containers["temporal_sequences"], dtype=np.float32)
+ ),
+ "trend_sequences": (
+ np.array(sequence_containers["trend_sequences"], dtype=np.float32)
+ ),
+ "static_features": np.array(sequence_containers["static_features"]),
+ "output_sequences": (
+ np.array(sequence_containers["output_sequences"], dtype=np.float32)
+ ),
+ }
+
+
+# Transform the data into a dictionary of input and output sequences
+output = prepare_data_for_modeling(
+ df=transformed_data, input_sequence_length=6, output_sequence_length=6
+)
+
+"""
+## Scaling and Splitting
+"""
+
+
+def robust_scale(data):
+ """
+    Min-max scaling to the [0, 1] range; used because the standard deviation is high.
+ """
+ data = np.array(data)
+ data_min = np.min(data)
+ data_max = np.max(data)
+ scaled = (data - data_min) / (data_max - data_min)
+ return scaled
+
+
+def create_temporal_splits_with_scaling(
+ prepared_data: Dict[str, np.ndarray],
+ test_ratio: float = 0.2,
+ val_ratio: float = 0.2,
+):
+ total_sequences = len(prepared_data["trend_sequences"])
+ # Calculate split points
+ test_size = int(total_sequences * test_ratio)
+ val_size = int(total_sequences * val_ratio)
+ train_size = total_sequences - (test_size + val_size)
+
+ # Scale trend sequences
+ trend_shape = prepared_data["trend_sequences"].shape
+ scaled_trends = np.zeros_like(prepared_data["trend_sequences"])
+
+ # Scale each feature independently
+ for i in range(trend_shape[-1]):
+ scaled_trends[..., i] = robust_scale(prepared_data["trend_sequences"][..., i])
+ # Scale output sequences
+ scaled_outputs = robust_scale(prepared_data["output_sequences"])
+
+ # Create splits
+ train_data = {
+ "trend_sequences": scaled_trends[:train_size],
+ "temporal_sequences": prepared_data["temporal_sequences"][:train_size],
+ "static_features": prepared_data["static_features"][:train_size],
+ "output_sequences": scaled_outputs[:train_size],
+ }
+
+ val_data = {
+ "trend_sequences": scaled_trends[train_size : train_size + val_size],
+ "temporal_sequences": prepared_data["temporal_sequences"][
+ train_size : train_size + val_size
+ ],
+ "static_features": prepared_data["static_features"][
+ train_size : train_size + val_size
+ ],
+ "output_sequences": scaled_outputs[train_size : train_size + val_size],
+ }
+
+ test_data = {
+ "trend_sequences": scaled_trends[train_size + val_size :],
+ "temporal_sequences": prepared_data["temporal_sequences"][
+ train_size + val_size :
+ ],
+ "static_features": prepared_data["static_features"][train_size + val_size :],
+ "output_sequences": scaled_outputs[train_size + val_size :],
+ }
+
+ return train_data, val_data, test_data
+
+
+# Usage
+train_data, val_data, test_data = create_temporal_splits_with_scaling(output)
+
+"""
+## Evaluation
+"""
+
+
+def calculate_metrics(y_true, y_pred):
+ """
+    Calculates RMSE, MAE and R².
+ """
+ # Convert inputs to "float32"
+ y_true = ops.cast(y_true, dtype="float32")
+ y_pred = ops.cast(y_pred, dtype="float32")
+
+ # RMSE
+ rmse = np.sqrt(np.mean(np.square(y_true - y_pred)))
+
+    # R² (coefficient of determination)
+ ss_res = np.sum(np.square(y_true - y_pred))
+ ss_tot = np.sum(np.square(y_true - np.mean(y_true)))
+ r2 = 1 - (ss_res / ss_tot)
+
+ return {"mae": np.mean(np.abs(y_true - y_pred)), "rmse": rmse, "r2": r2}
+
+
+def plot_lorenz_analysis(y_true, y_pred):
+ """
+ Plots Lorenz curves to show distribution of high and low value users
+ """
+ # Convert to numpy arrays and flatten
+ y_true = np.array(y_true).flatten()
+ y_pred = np.array(y_pred).flatten()
+
+ # Sort values in descending order (for high-value users analysis)
+ true_sorted = np.sort(-y_true)
+ pred_sorted = np.sort(-y_pred)
+
+ # Calculate cumulative sums
+ true_cumsum = np.cumsum(true_sorted)
+ pred_cumsum = np.cumsum(pred_sorted)
+
+ # Normalize to percentages
+ true_cumsum_pct = true_cumsum / true_cumsum[-1]
+ pred_cumsum_pct = pred_cumsum / pred_cumsum[-1]
+
+ # Generate percentiles for x-axis
+ percentiles = np.linspace(0, 1, len(y_true))
+
+ # Calculate Mutual Gini (area between curves)
+ mutual_gini = np.abs(
+ np.trapz(true_cumsum_pct, percentiles) - np.trapz(pred_cumsum_pct, percentiles)
+ )
+
+ # Create plot
+ plt.figure(figsize=(10, 6))
+ plt.plot(percentiles, true_cumsum_pct, "g-", label="True Values")
+ plt.plot(percentiles, pred_cumsum_pct, "r-", label="Predicted Values")
+ plt.xlabel("Cumulative % of Users (Descending Order)")
+ plt.ylabel("Cumulative % of LTV")
+ plt.title("Lorenz Curves: True vs Predicted Values")
+ plt.legend()
+ plt.grid(True)
+ print(f"\nMutual Gini: {mutual_gini:.4f} (lower is better)")
+ plt.show()
+
+ return mutual_gini
+
+
+"""
+## Hybrid Transformer / LSTM model architecture
+
+The hybrid nature of this model is particularly significant because it combines the
+RNN's ability to handle sequential data with the Transformer's attention mechanisms
+for capturing global patterns across countries and seasonality.
+"""
+
+
+def build_hybrid_model(
+ input_sequence_length: int,
+ output_sequence_length: int,
+ num_countries: int,
+ d_model: int = 8,
+ num_heads: int = 4,
+):
+
+ keras.utils.set_random_seed(seed=42)
+
+ # Inputs
+ temporal_inputs = layers.Input(
+ shape=(input_sequence_length, 5), name="temporal_inputs"
+ )
+ trend_inputs = layers.Input(shape=(input_sequence_length, 12), name="trend_inputs")
+ country_inputs = layers.Input(
+ shape=(num_countries,), dtype="int32", name="country_inputs"
+ )
+
+ # Process country features
+ country_embedding = layers.Embedding(
+ input_dim=num_countries,
+ output_dim=d_model,
+ mask_zero=False,
+ name="country_embedding",
+ )(
+ country_inputs
+ ) # Output shape: (batch_size, 1, d_model)
+
+ # Flatten the embedding output
+ country_embedding = layers.Flatten(name="flatten_country_embedding")(
+ country_embedding
+ )
+
+ # Repeat the country embedding across timesteps
+ country_embedding_repeated = layers.RepeatVector(
+ input_sequence_length, name="repeat_country_embedding"
+ )(country_embedding)
+
+ # Projection of temporal inputs to match Transformer dimensions
+ temporal_projection = layers.Dense(
+ d_model, activation="tanh", name="temporal_projection"
+ )(temporal_inputs)
+
+ # Combine all features
+ combined_features = layers.Concatenate()(
+ [temporal_projection, country_embedding_repeated]
+ )
+
+ transformer_output = combined_features
+ for _ in range(3):
+ transformer_output = TransformerEncoder(
+ intermediate_dim=16, num_heads=num_heads
+ )(transformer_output)
+
+ lstm_output = layers.LSTM(units=64, name="lstm_trend")(trend_inputs)
+
+ transformer_flattened = layers.GlobalAveragePooling1D(name="flatten_transformer")(
+ transformer_output
+ )
+ transformer_flattened = layers.Dense(1, activation="sigmoid")(transformer_flattened)
+ # Concatenate flattened Transformer output with LSTM output
+ merged_features = layers.Concatenate(name="concatenate_transformer_lstm")(
+ [transformer_flattened, lstm_output]
+ )
+ # Repeat the merged features to match the output sequence length
+ decoder_initial = layers.RepeatVector(
+ output_sequence_length, name="repeat_merged_features"
+ )(merged_features)
+
+ decoder_lstm = layers.LSTM(
+ units=64,
+ return_sequences=True,
+ recurrent_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4),
+ )(decoder_initial)
+
+ # Output Dense layer
+ output = layers.Dense(units=1, activation="linear", name="output_dense")(
+ decoder_lstm
+ )
+
+ model = Model(
+ inputs=[temporal_inputs, trend_inputs, country_inputs], outputs=output
+ )
+
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ loss="mse",
+ metrics=["mse"],
+ )
+
+ return model
+
+
+# Create the hybrid model
+model = build_hybrid_model(
+ input_sequence_length=6,
+ output_sequence_length=6,
+ num_countries=len(np.unique(train_data["static_features"])) + 1,
+ d_model=8,
+ num_heads=4,
+)
+
+# Configure StringLookup
+label_encoder = layers.StringLookup(output_mode="one_hot", num_oov_indices=1)
+
+# Adapt and encode
+label_encoder.adapt(train_data["static_features"])
+
+train_static_encoded = label_encoder(train_data["static_features"])
+val_static_encoded = label_encoder(val_data["static_features"])
+test_static_encoded = label_encoder(test_data["static_features"])
+
+# Convert sequences with proper type casting
+x_train_seq = np.asarray(train_data["trend_sequences"]).astype(np.float32)
+x_val_seq = np.asarray(val_data["trend_sequences"]).astype(np.float32)
+x_train_temporal = np.asarray(train_data["temporal_sequences"]).astype(np.float32)
+x_val_temporal = np.asarray(val_data["temporal_sequences"]).astype(np.float32)
+train_outputs = np.asarray(train_data["output_sequences"]).astype(np.float32)
+val_outputs = np.asarray(val_data["output_sequences"]).astype(np.float32)
+test_output = np.asarray(test_data["output_sequences"]).astype(np.float32)
+# Training setup
+keras.utils.set_random_seed(seed=42)
+
+history = model.fit(
+ [x_train_temporal, x_train_seq, train_static_encoded],
+ train_outputs,
+ validation_data=(
+ [x_val_temporal, x_val_seq, val_static_encoded],
+        val_outputs,
+ ),
+ epochs=20,
+ batch_size=30,
+)
+
+# Make predictions
+predictions = model.predict(
+ [
+ test_data["temporal_sequences"].astype(np.float32),
+ test_data["trend_sequences"].astype(np.float32),
+ test_static_encoded,
+ ]
+)
+
+# Squeeze the predictions to drop the trailing singleton dimension
+predictions = np.squeeze(predictions)
+
+# Calculate basic metrics
+hybrid_metrics = calculate_metrics(test_output, predictions)
+
+# Plot Lorenz curves and get the Mutual Gini
+hybrid_mutual_gini = plot_lorenz_analysis(test_output, predictions)
+
+"""
+## Conclusion
+
+LSTMs excel at sequence-to-sequence learning, as demonstrated by Sutskever, I., Vinyals, O., & Le, Q. V. (2014),
+*Sequence to Sequence Learning with Neural Networks*; the hybrid approach here builds on that foundation.
+The addition of attention mechanisms allows the model to adaptively focus on relevant temporal and geographical
+patterns while maintaining the LSTM's inherent strengths in sequence learning. This combination has proven
+especially effective for handling both periodic patterns and special events in time series forecasting, as shown
+by Zhou, H., Zhang, S., Peng, J., Zhang, S., Li, J., Xiong, H., & Zhang, W. (2021),
+*Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting*.
+"""
diff --git a/knowledge_base/structured_data/deep_neural_decision_forests.py b/knowledge_base/structured_data/deep_neural_decision_forests.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a3e38fc06e67b3d2f6cda0eef36cdc344e7149f
--- /dev/null
+++ b/knowledge_base/structured_data/deep_neural_decision_forests.py
@@ -0,0 +1,460 @@
+"""
+Title: Classification with Neural Decision Forests
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/01/15
+Last modified: 2021/01/15
+Description: How to train differentiable decision trees for end-to-end learning in deep neural networks.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example provides an implementation of the
+[Deep Neural Decision Forest](https://ieeexplore.ieee.org/document/7410529)
+model introduced by P. Kontschieder et al. for structured data classification.
+It demonstrates how to build a stochastic and differentiable decision tree model,
+train it end-to-end, and unify decision trees with deep representation learning.
+
+## The dataset
+
+This example uses the
+[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)
+provided by the
+[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
+The task is binary classification
+to predict whether a person is likely to be making over USD 50,000 a year.
+
+The dataset includes 48,842 instances with 14 input features (such as age, work class, education, occupation, and so on): 5 numerical features
+and 9 categorical features.
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import layers
+from keras.layers import StringLookup
+from keras import ops
+
+
+from tensorflow import data as tf_data
+import numpy as np
+import pandas as pd
+
+import math
+
+
+"""
+## Prepare the data
+"""
+
+CSV_HEADER = [
+ "age",
+ "workclass",
+ "fnlwgt",
+ "education",
+ "education_num",
+ "marital_status",
+ "occupation",
+ "relationship",
+ "race",
+ "gender",
+ "capital_gain",
+ "capital_loss",
+ "hours_per_week",
+ "native_country",
+ "income_bracket",
+]
+
+train_data_url = (
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
+)
+train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)
+
+test_data_url = (
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
+)
+test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)
+
+print(f"Train dataset shape: {train_data.shape}")
+print(f"Test dataset shape: {test_data.shape}")
+
+"""
+Remove the first record (because it is not a valid data example) and a trailing
+'dot' in the class labels.
+"""
+
+test_data = test_data[1:]
+test_data.income_bracket = test_data.income_bracket.apply(
+ lambda value: value.replace(".", "")
+)
+
+"""
+We store the training and test data splits locally as CSV files.
+"""
+
+train_data_file = "train_data.csv"
+test_data_file = "test_data.csv"
+
+train_data.to_csv(train_data_file, index=False, header=False)
+test_data.to_csv(test_data_file, index=False, header=False)
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading, parsing,
+and encoding input features.
+"""
+
+# A list of the numerical feature names.
+NUMERIC_FEATURE_NAMES = [
+ "age",
+ "education_num",
+ "capital_gain",
+ "capital_loss",
+ "hours_per_week",
+]
+# A dictionary of the categorical features and their vocabulary.
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ "workclass": sorted(list(train_data["workclass"].unique())),
+ "education": sorted(list(train_data["education"].unique())),
+ "marital_status": sorted(list(train_data["marital_status"].unique())),
+ "occupation": sorted(list(train_data["occupation"].unique())),
+ "relationship": sorted(list(train_data["relationship"].unique())),
+ "race": sorted(list(train_data["race"].unique())),
+ "gender": sorted(list(train_data["gender"].unique())),
+ "native_country": sorted(list(train_data["native_country"].unique())),
+}
+# A list of the columns to ignore from the dataset.
+IGNORE_COLUMN_NAMES = ["fnlwgt"]
+# A list of the categorical feature names.
+CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())
+# A list of all the input features.
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES
+# A list of column default values for each feature.
+COLUMN_DEFAULTS = [
+ [0.0] if feature_name in NUMERIC_FEATURE_NAMES + IGNORE_COLUMN_NAMES else ["NA"]
+ for feature_name in CSV_HEADER
+]
+# The name of the target feature.
+TARGET_FEATURE_NAME = "income_bracket"
+# A list of the labels of the target feature. Note the leading space: the raw CSV
+# values are not stripped, so the labels must match them exactly.
+TARGET_LABELS = [" <=50K", " >50K"]
+
+"""
+## Create `tf_data.Dataset` objects for training and validation
+
+We create an input function to read and parse the file, and convert features and labels
+into a [`tf_data.Dataset`](https://www.tensorflow.org/guide/datasets)
+for training and validation. We also preprocess the input by mapping the target label
+to an index.
+"""
+
+
+target_label_lookup = StringLookup(
+ vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0
+)
+
+
+lookup_dict = {}
+for feature_name in CATEGORICAL_FEATURE_NAMES:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+    # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token, nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ lookup = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)
+ lookup_dict[feature_name] = lookup
+
+
+def encode_categorical(batch_x, batch_y):
+ for feature_name in CATEGORICAL_FEATURE_NAMES:
+ batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])
+
+ return batch_x, batch_y
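+
+
+"""
+A quick sanity check of the lookup mechanics (an illustrative snippet, not used by the
+training pipeline): since we use no mask token and no OOV indices, the first vocabulary
+entry maps to index 0 and the last one to `vocabulary_size() - 1`.
+"""
+
+_education_vocab = CATEGORICAL_FEATURES_WITH_VOCABULARY["education"]
+print(lookup_dict["education"](np.array([_education_vocab[0], _education_vocab[-1]])))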
+
+
+def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
+ dataset = (
+ tf_data.experimental.make_csv_dataset(
+ csv_file_path,
+ batch_size=batch_size,
+ column_names=CSV_HEADER,
+ column_defaults=COLUMN_DEFAULTS,
+ label_name=TARGET_FEATURE_NAME,
+ num_epochs=1,
+ header=False,
+ na_value="?",
+ shuffle=shuffle,
+ )
+ .map(lambda features, target: (features, target_label_lookup(target)))
+ .map(encode_categorical)
+ )
+
+ return dataset.cache()
+
+
+"""
+## Create model inputs
+"""
+
+
+def create_model_inputs():
+ inputs = {}
+ for feature_name in FEATURE_NAMES:
+ if feature_name in NUMERIC_FEATURE_NAMES:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="float32"
+ )
+ else:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="int32"
+ )
+ return inputs
+
+
+"""
+## Encode input features
+"""
+
+
+def encode_inputs(inputs):
+ encoded_features = []
+ for feature_name in inputs:
+ if feature_name in CATEGORICAL_FEATURE_NAMES:
+            # Reuse the lookup created for this feature above. Fetching it from
+            # `lookup_dict` (rather than relying on the `lookup` variable leaked by
+            # the module-level loop) ensures each feature uses its own vocabulary size.
+            lookup = lookup_dict[feature_name]
+            value_index = inputs[feature_name]
+            embedding_dims = int(math.sqrt(lookup.vocabulary_size()))
+ # Create an embedding layer with the specified dimensions.
+ embedding = layers.Embedding(
+ input_dim=lookup.vocabulary_size(), output_dim=embedding_dims
+ )
+ # Convert the index values to embedding representations.
+ encoded_feature = embedding(value_index)
+ else:
+ # Use the numerical features as-is.
+ encoded_feature = inputs[feature_name]
+ if inputs[feature_name].shape[-1] is None:
+ encoded_feature = keras.ops.expand_dims(encoded_feature, -1)
+
+ encoded_features.append(encoded_feature)
+
+ encoded_features = layers.concatenate(encoded_features)
+ return encoded_features
+
+
+"""
+## Deep Neural Decision Tree
+
+A neural decision tree model has two sets of weights to learn. The first set is `pi`,
+which represents the probability distribution of the classes in the tree leaves.
+The second set is the weights of the routing layer `decision_fn`, which represents the
+probability of going to each leaf. The forward pass of the model works as follows:
+
+1. The model expects input `features` as a single vector encoding all the features of an instance
+in the batch. This vector can be generated from a Convolutional Neural Network (CNN) applied to images
+or dense transformations applied to structured data features.
+2. The model first applies a `used_features_mask` to randomly select a subset of input features to use.
+3. Then, the model computes the probabilities (`mu`) for the input instances to reach the tree leaves
+by iteratively performing a *stochastic* routing throughout the tree levels (see the NumPy sketch below).
+4. Finally, the probabilities of reaching the leaves are combined with the class probabilities at the
+leaves to produce the final `outputs`.
+"""
+
+
+class NeuralDecisionTree(keras.Model):
+ def __init__(self, depth, num_features, used_features_rate, num_classes):
+ super().__init__()
+ self.depth = depth
+ self.num_leaves = 2**depth
+ self.num_classes = num_classes
+
+ # Create a mask for the randomly selected features.
+ num_used_features = int(num_features * used_features_rate)
+ one_hot = np.eye(num_features)
+ sampled_feature_indices = np.random.choice(
+ np.arange(num_features), num_used_features, replace=False
+ )
+ self.used_features_mask = ops.convert_to_tensor(
+ one_hot[sampled_feature_indices], dtype="float32"
+ )
+
+ # Initialize the weights of the classes in leaves.
+ self.pi = self.add_weight(
+ initializer="random_normal",
+ shape=[self.num_leaves, self.num_classes],
+ dtype="float32",
+ trainable=True,
+ )
+
+ # Initialize the stochastic routing layer.
+ self.decision_fn = layers.Dense(
+ units=self.num_leaves, activation="sigmoid", name="decision"
+ )
+
+ def call(self, features):
+ batch_size = ops.shape(features)[0]
+
+ # Apply the feature mask to the input features.
+ features = ops.matmul(
+ features, ops.transpose(self.used_features_mask)
+ ) # [batch_size, num_used_features]
+ # Compute the routing probabilities.
+ decisions = ops.expand_dims(
+ self.decision_fn(features), axis=2
+ ) # [batch_size, num_leaves, 1]
+ # Concatenate the routing probabilities with their complements.
+ decisions = layers.concatenate(
+ [decisions, 1 - decisions], axis=2
+ ) # [batch_size, num_leaves, 2]
+
+ mu = ops.ones([batch_size, 1, 1])
+
+ begin_idx = 1
+ end_idx = 2
+ # Traverse the tree in breadth-first order.
+ for level in range(self.depth):
+ mu = ops.reshape(mu, [batch_size, -1, 1]) # [batch_size, 2 ** level, 1]
+ mu = ops.tile(mu, (1, 1, 2)) # [batch_size, 2 ** level, 2]
+ level_decisions = decisions[
+ :, begin_idx:end_idx, :
+ ] # [batch_size, 2 ** level, 2]
+ mu = mu * level_decisions # [batch_size, 2**level, 2]
+ begin_idx = end_idx
+ end_idx = begin_idx + 2 ** (level + 1)
+
+ mu = ops.reshape(mu, [batch_size, self.num_leaves]) # [batch_size, num_leaves]
+ probabilities = keras.activations.softmax(self.pi) # [num_leaves, num_classes]
+ outputs = ops.matmul(mu, probabilities) # [batch_size, num_classes]
+ return outputs
+
+
+"""
+## Deep Neural Decision Forest
+
+The neural decision forest model consists of a set of neural decision trees that are
+trained simultaneously. The output of the forest model is the average outputs of its trees.
+"""
+
+
+class NeuralDecisionForest(keras.Model):
+ def __init__(self, num_trees, depth, num_features, used_features_rate, num_classes):
+ super().__init__()
+        self.num_classes = num_classes
+        self.ensemble = []
+ # Initialize the ensemble by adding NeuralDecisionTree instances.
+ # Each tree will have its own randomly selected input features to use.
+ for _ in range(num_trees):
+ self.ensemble.append(
+ NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)
+ )
+
+ def call(self, inputs):
+ # Initialize the outputs: a [batch_size, num_classes] matrix of zeros.
+ batch_size = ops.shape(inputs)[0]
+        outputs = ops.zeros([batch_size, self.num_classes])
+
+ # Aggregate the outputs of trees in the ensemble.
+ for tree in self.ensemble:
+ outputs += tree(inputs)
+ # Divide the outputs by the ensemble size to get the average.
+ outputs /= len(self.ensemble)
+ return outputs
+
+
+"""
+Finally, let's set up the code that will train and evaluate the model.
+"""
+
+learning_rate = 0.01
+batch_size = 265
+num_epochs = 10
+
+
+def run_experiment(model):
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
+ loss=keras.losses.SparseCategoricalCrossentropy(),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ )
+
+ print("Start training the model...")
+ train_dataset = get_dataset_from_csv(
+ train_data_file, shuffle=True, batch_size=batch_size
+ )
+
+ model.fit(train_dataset, epochs=num_epochs)
+ print("Model training finished")
+
+ print("Evaluating the model on the test data...")
+ test_dataset = get_dataset_from_csv(test_data_file, batch_size=batch_size)
+
+ _, accuracy = model.evaluate(test_dataset)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+
+"""
+## Experiment 1: train a decision tree model
+
+In this experiment, we train a single neural decision tree model
+where we use all input features.
+"""
+
+num_trees = 10
+depth = 10
+used_features_rate = 1.0
+num_classes = len(TARGET_LABELS)
+
+
+def create_tree_model():
+ inputs = create_model_inputs()
+ features = encode_inputs(inputs)
+ features = layers.BatchNormalization()(features)
+ num_features = features.shape[1]
+
+ tree = NeuralDecisionTree(depth, num_features, used_features_rate, num_classes)
+
+ outputs = tree(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+tree_model = create_tree_model()
+run_experiment(tree_model)
+
+
+"""
+## Experiment 2: train a forest model
+
+In this experiment, we train a neural decision forest with `num_trees` trees,
+where each tree uses a randomly selected 50% of the input features. You can control the number
+of features to be used in each tree by setting the `used_features_rate` variable.
+In addition, we set the depth to 5, instead of the 10 used in the previous experiment.
+"""
+
+num_trees = 25
+depth = 5
+used_features_rate = 0.5
+
+
+def create_forest_model():
+ inputs = create_model_inputs()
+ features = encode_inputs(inputs)
+ features = layers.BatchNormalization()(features)
+ num_features = features.shape[1]
+
+ forest_model = NeuralDecisionForest(
+ num_trees, depth, num_features, used_features_rate, num_classes
+ )
+
+ outputs = forest_model(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+forest_model = create_forest_model()
+
+run_experiment(forest_model)
diff --git a/knowledge_base/structured_data/feature_space_advanced.py b/knowledge_base/structured_data/feature_space_advanced.py
new file mode 100644
index 0000000000000000000000000000000000000000..12826728c5dddf4d8781672586d462b0af651e3e
--- /dev/null
+++ b/knowledge_base/structured_data/feature_space_advanced.py
@@ -0,0 +1,608 @@
+"""
+Title: FeatureSpace advanced use cases
+Author: [Dimitre Oliveira](https://www.linkedin.com/in/dimitre-oliveira-7a1a0113a/)
+Date created: 2023/07/01
+Last modified: 2025/01/03
+Description: How to use FeatureSpace for advanced preprocessing use cases.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+This example is an extension of the
+[Structured data classification with FeatureSpace](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/)
+code example, and here we will extend it to cover more complex use
+cases of the [`keras.utils.FeatureSpace`](https://keras.io/api/utils/feature_space/)
+preprocessing utility, like feature hashing, feature crosses, handling missing values and
+integrating [Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/)
+with FeatureSpace.
+
+The general task is still structured data classification (also known as tabular data
+classification), using data that includes numerical features, integer categorical
+features, and string categorical features.
+"""
+
+"""
+### The dataset
+
+[Our dataset](https://archive.ics.uci.edu/dataset/222/bank+marketing) is provided by a
+Portuguese banking institution.
+It's a CSV file with 4119 rows. Each row contains information about marketing campaigns
+based on phone calls, and each column describes an attribute of the client. We use the
+features to predict whether the client subscribed ('yes') or not ('no') to the product
+(bank term deposit).
+
+Here's the description of each feature:
+
+Column| Description| Feature Type
+------|------------|-------------
+Age | Age of the client | Numerical
+Job | Type of job | Categorical
+Marital | Marital status | Categorical
+Education | Education level of the client | Categorical
+Default | Has credit in default? | Categorical
+Housing | Has housing loan? | Categorical
+Loan | Has personal loan? | Categorical
+Contact | Contact communication type | Categorical
+Month | Last contact month of year | Categorical
+Day_of_week | Last contact day of the week | Categorical
+Duration | Last contact duration, in seconds | Numerical
+Campaign | Number of contacts performed during this campaign and for this client | Numerical
+Pdays | Number of days that passed by after the client was last contacted from a previous campaign | Numerical
+Previous | Number of contacts performed before this campaign and for this client | Numerical
+Poutcome | Outcome of the previous marketing campaign | Categorical
+Emp.var.rate | Employment variation rate | Numerical
+Cons.price.idx | Consumer price index | Numerical
+Cons.conf.idx | Consumer confidence index | Numerical
+Euribor3m | Euribor 3 month rate | Numerical
+Nr.employed | Number of employees | Numerical
+Y | Has the client subscribed a term deposit? | Target
+
+**Important note regarding the feature `duration`**: this attribute highly affects the
+output target (e.g., if duration=0 then y='no'). Yet, the duration is not known before a
+call is performed. Also, after the end of the call y is obviously known. Thus, this input
+should only be included for benchmark purposes and should be discarded if the intention
+is to have a realistic predictive model. For this reason we will drop it.
+
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras.utils import FeatureSpace
+import pandas as pd
+import tensorflow as tf
+from pathlib import Path
+from zipfile import ZipFile
+
+"""
+## Load the data
+
+Let's download the data and load it into a Pandas dataframe:
+"""
+
+data_url = "https://archive.ics.uci.edu/static/public/222/bank+marketing.zip"
+data_zipped_path = keras.utils.get_file("bank_marketing.zip", data_url, extract=True)
+keras_datasets_path = Path(data_zipped_path)
+with ZipFile(f"{keras_datasets_path}/bank-additional.zip", "r") as zip:
+ # Extract files
+ zip.extractall(path=keras_datasets_path)
+
+dataframe = pd.read_csv(
+ f"{keras_datasets_path}/bank-additional/bank-additional.csv", sep=";"
+)
+
+"""
+We will create a new feature, `previously_contacted`, based on `pdays`, to demonstrate
+some useful preprocessing techniques. According to the dataset documentation,
+`pdays = 999` means that the client was not previously contacted, so let's create a
+feature to capture that.
+"""
+
+# Dropping `duration` to avoid target leakage
+dataframe.drop("duration", axis=1, inplace=True)
+# Creating the new feature `previously_contacted`
+dataframe["previously_contacted"] = dataframe["pdays"].map(
+ lambda x: 0 if x == 999 else 1
+)
+
+"""
+The dataset includes 4119 samples with 21 columns per sample (20 features, plus the
+target label), here's a preview of a few samples:
+"""
+
+print(f"Dataframe shape: {dataframe.shape}")
+print(dataframe.head())
+
+"""
+The column "y" indicates whether the client has subscribed to a term deposit or not.
+"""
+
+"""
+## Train/validation split
+
+Let's split the data into a training and validation set:
+"""
+
+valid_dataframe = dataframe.sample(frac=0.2, random_state=0)
+train_dataframe = dataframe.drop(valid_dataframe.index)
+
+print(
+ f"Using {len(train_dataframe)} samples for training and "
+ f"{len(valid_dataframe)} for validation"
+)
+
+"""
+## Generating TF datasets
+
+Let's generate
+[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) objects
+for each dataframe. Since our target column `y` is a string, we also need to encode it
+as an integer to be able to train our model with it. To achieve this we will create a
+`StringLookup` layer that will map the strings "no" and "yes" to "0" and "1"
+respectively.
+"""
+
+label_lookup = keras.layers.StringLookup(
+ # the order here is important since the first index will be encoded as 0
+ vocabulary=["no", "yes"],
+ num_oov_indices=0,
+)
+
+
+def encode_label(x, y):
+ encoded_y = label_lookup(y)
+ return x, encoded_y
+
+
+def dataframe_to_dataset(dataframe):
+ dataframe = dataframe.copy()
+ labels = dataframe.pop("y")
+ ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
+ ds = ds.map(encode_label, num_parallel_calls=tf.data.AUTOTUNE)
+ ds = ds.shuffle(buffer_size=len(dataframe))
+ return ds
+
+
+train_ds = dataframe_to_dataset(train_dataframe)
+valid_ds = dataframe_to_dataset(valid_dataframe)
+
+"""
+Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
+and `target` is the value `0` or `1`:
+"""
+
+for x, y in dataframe_to_dataset(train_dataframe).take(1):
+ print(f"Input: {x}")
+ print(f"Target: {y}")
+
+"""
+## Preprocessing
+
+Usually our data is not in the proper or best format for modeling, which is why most of
+the time we need to do some kind of preprocessing on the features to make them compatible
+with the model or to extract the most from them for the task. We need to do this
+preprocessing step for training, but at inference we also need to make sure that the
+data goes through the same process. This is where a utility like `FeatureSpace` shines:
+we can define all the preprocessing once and re-use it at different stages of our system.
+
+Here we will see how to use `FeatureSpace` to perform more complex transformations and
+its flexibility, then combine everything together into a single component to preprocess
+data for our model.
+"""
+
+"""
+The `FeatureSpace` utility learns how to process the data via its `adapt()` function.
+This requires a dataset containing only features, so let's create it, together with a
+utility function to show the preprocessing examples in practice:
+"""
+
+train_ds_with_no_labels = train_ds.map(lambda x, _: x)
+
+
+def example_feature_space(dataset, feature_space, feature_names):
+ feature_space.adapt(dataset)
+ for x in dataset.take(1):
+ inputs = {feature_name: x[feature_name] for feature_name in feature_names}
+ preprocessed_x = feature_space(inputs)
+ print(f"Input: {[{k:v.numpy()} for k, v in inputs.items()]}")
+ print(
+ f"Preprocessed output: {[{k:v.numpy()} for k, v in preprocessed_x.items()]}"
+ )
+
+
+"""
+### Feature hashing
+"""
+
+"""
+**Feature hashing** means hashing or encoding a set of values into a defined number of
+bins. Here we have `campaign` (the number of contacts performed during this campaign and
+for a client), a numerical feature that can assume a varying range of values, and we
+will hash it into 4 bins: any possible value of the original feature will be placed into
+one of those 4 bins. The output here can be a one-hot encoded vector or a single number.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ "campaign": FeatureSpace.integer_hashed(num_bins=4, output_mode="one_hot")
+ },
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["campaign"])
+
+"""
+**Feature hashing** can also be used for string features.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ "education": FeatureSpace.string_hashed(num_bins=3, output_mode="one_hot")
+ },
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["education"])
+
+"""
+For numerical features we can get a similar behavior by using the `float_discretized`
+option. The main difference between this and `integer_hashed` is that with the former we
+bin the values while keeping some numerical relationship (close values will likely be
+placed in the same bin), while with the latter (hashing) we cannot guarantee that close
+numbers will land in the same bin; it depends on the hashing function.
+"""
+
+feature_space = FeatureSpace(
+ features={"age": FeatureSpace.float_discretized(num_bins=3, output_mode="one_hot")},
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["age"])
+
+"""
+### Feature indexing
+"""
+
+"""
+**Indexing** a string feature essentially means creating a discrete numerical
+representation for it. This is especially important for string features, since most
+models only accept numerical features. This transformation will place the string values
+into different categories. The output here can be a one-hot encoded vector or a single
+number.
+
+Note that by specifying `num_oov_indices=1` we leave one spot in our output vector for
+OOV (out-of-vocabulary) values; this is an important tool to handle missing or unseen
+values after training (values that were not seen during the `adapt()` step).
+"""
+
+feature_space = FeatureSpace(
+ features={
+ "default": FeatureSpace.string_categorical(
+ num_oov_indices=1, output_mode="one_hot"
+ )
+ },
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["default"])
+
+"""
+We can also do **feature indexing** for integer features. This can be quite important for
+some datasets where categorical features are replaced by numbers, for instance features
+like `sex` or `gender` where values like `1` and `0` do not have a numerical relationship
+between them; they are just different categories. This behavior can be perfectly captured
+by this transformation.
+
+On this dataset we can use the feature that we created, `previously_contacted`. For this
+case we want to explicitly set `num_oov_indices=0`: we only expect two possible values
+for the feature, so anything else would be either wrong input or an issue with the data
+creation, and we would probably just want the code to throw an error so that we can be
+aware of the issue and fix it.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ "previously_contacted": FeatureSpace.integer_categorical(
+ num_oov_indices=0, output_mode="one_hot"
+ )
+ },
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["previously_contacted"])
+
+"""
+### Feature crosses (mixing features of diverse types)
+
+With **crosses** we can do feature interactions between an arbitrary number of features
+of mixed types, as long as they are categorical features. Instead of having a feature
+{'age': 20} and another {'job': 'entrepreneur'}, you can think of having a combined
+feature {'age_X_job': 20_entrepreneur}; with `FeatureSpace` and **crosses** we can apply
+specific preprocessing to each individual feature and to the feature cross itself. This
+option can be very powerful for specific use cases, and it might be a good option here,
+since age combined with job can have different meanings in the banking domain.
+
+We will cross `age` and `job` and hash their combined output into a vector
+representation of size 8. The output here can be a one-hot encoded vector or a single
+number.
+
+Sometimes the combination of multiple features can result in a very large feature
+space; think about crossing someone's ZIP code with their last name, where the
+possibilities would be in the thousands. This is why the `crossing_dim` parameter is so
+important: it limits the output dimension of the cross feature.
+
+Note that the number of possible combinations of the 6 bins of `age` and the 12 values
+of `job` is 72, so by choosing `crossing_dim = 8` we are choosing to constrain the
+output vector.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ "age": FeatureSpace.integer_hashed(num_bins=6, output_mode="one_hot"),
+ "job": FeatureSpace.string_categorical(
+ num_oov_indices=0, output_mode="one_hot"
+ ),
+ },
+ crosses=[
+ FeatureSpace.cross(
+ feature_names=("age", "job"),
+ crossing_dim=8,
+ output_mode="one_hot",
+ )
+ ],
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["age", "job"])
+
+"""
+### FeatureSpace using a Keras preprocessing layer
+
+To be really flexible and extensible, we cannot rely only on those pre-defined
+transformations; we must be able to re-use other transformations from the Keras/TensorFlow
+ecosystem and customize our own. This is why `FeatureSpace` is also designed to work with
+[Keras preprocessing layers](https://keras.io/api/layers/preprocessing_layers/): this way we
+can use sophisticated data transformations provided by the framework, and you can even create
+your own custom Keras preprocessing layers and use them in the same way.
+
+Here we are going to use the
+[`keras.layers.TextVectorization`](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization/#textvectorization-class)
+preprocessing layer to create a TF-IDF
+feature from our data. Note that this feature is not really a good use case for TF-IDF;
+it is used here just for demonstration purposes.
+"""
+
+custom_layer = keras.layers.TextVectorization(output_mode="tf_idf")
+
+feature_space = FeatureSpace(
+ features={
+ "education": FeatureSpace.feature(
+ preprocessor=custom_layer, dtype="string", output_mode="float"
+ )
+ },
+ output_mode="dict",
+)
+example_feature_space(train_ds_with_no_labels, feature_space, ["education"])
+
+"""
+## Configuring the final `FeatureSpace`
+
+Now that we know how to use `FeatureSpace` for more complex use cases, let's pick the
+ones that look most useful for this task and create the final `FeatureSpace` component.
+
+To configure how each feature should be preprocessed,
+we instantiate a `keras.utils.FeatureSpace`, and we
+pass to it a dictionary that maps the name of our features
+to the feature transformation function.
+
+"""
+
+feature_space = FeatureSpace(
+ features={
+ # Categorical features encoded as integers
+ "previously_contacted": FeatureSpace.integer_categorical(num_oov_indices=0),
+ # Categorical features encoded as string
+ "marital": FeatureSpace.string_categorical(num_oov_indices=0),
+ "education": FeatureSpace.string_categorical(num_oov_indices=0),
+ "default": FeatureSpace.string_categorical(num_oov_indices=0),
+ "housing": FeatureSpace.string_categorical(num_oov_indices=0),
+ "loan": FeatureSpace.string_categorical(num_oov_indices=0),
+ "contact": FeatureSpace.string_categorical(num_oov_indices=0),
+ "month": FeatureSpace.string_categorical(num_oov_indices=0),
+ "day_of_week": FeatureSpace.string_categorical(num_oov_indices=0),
+ "poutcome": FeatureSpace.string_categorical(num_oov_indices=0),
+ # Categorical features to hash and bin
+ "job": FeatureSpace.string_hashed(num_bins=3),
+ # Numerical features to hash and bin
+ "pdays": FeatureSpace.integer_hashed(num_bins=4),
+ # Numerical features to normalize and bin
+ "age": FeatureSpace.float_discretized(num_bins=4),
+ # Numerical features to normalize
+ "campaign": FeatureSpace.float_normalized(),
+ "previous": FeatureSpace.float_normalized(),
+ "emp.var.rate": FeatureSpace.float_normalized(),
+ "cons.price.idx": FeatureSpace.float_normalized(),
+ "cons.conf.idx": FeatureSpace.float_normalized(),
+ "euribor3m": FeatureSpace.float_normalized(),
+ "nr.employed": FeatureSpace.float_normalized(),
+ },
+ # Specify feature cross with a custom crossing dim.
+ crosses=[
+ FeatureSpace.cross(feature_names=("age", "job"), crossing_dim=8),
+ FeatureSpace.cross(feature_names=("housing", "loan"), crossing_dim=6),
+ FeatureSpace.cross(
+ feature_names=("poutcome", "previously_contacted"), crossing_dim=2
+ ),
+ ],
+ output_mode="concat",
+)
+
+"""
+## Adapt the `FeatureSpace` to the training data
+
+Before we start using the `FeatureSpace` to build a model, we have
+to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:
+
+- Index the set of possible values for categorical features.
+- Compute the mean and variance for numerical features to normalize.
+- Compute the value boundaries for the different bins for numerical features to
+discretize.
+- Perform any other kind of preprocessing required by custom layers.
+
+Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts
+of feature values -- no labels.
+
+But first, let's batch the datasets:
+"""
+
+train_ds = train_ds.batch(32)
+valid_ds = valid_ds.batch(32)
+
+train_ds_with_no_labels = train_ds.map(lambda x, _: x)
+feature_space.adapt(train_ds_with_no_labels)
+
+"""
+At this point, the `FeatureSpace` can be called on a dict of raw feature values, and
+because we set `output_mode="concat"` it will return a single concatenated vector for each
+sample, combining encoded features and feature crosses.
+"""
+
+for x, _ in train_ds.take(1):
+ preprocessed_x = feature_space(x)
+ print(f"preprocessed_x shape: {preprocessed_x.shape}")
+ print(f"preprocessed_x sample: \n{preprocessed_x[0]}")
+
+"""
+## Saving the `FeatureSpace`
+
+At this point we can choose to save our `FeatureSpace` component. This has many
+advantages, like re-using it in different experiments that use the same model, saving time
+if you need to re-run the preprocessing step, and mainly for model deployment, where by
+loading it you can be sure that you will be applying the same preprocessing steps no
+matter the device or environment. This is a great way to reduce
+[training/serving skew](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew).
+"""
+
+feature_space.save("myfeaturespace.keras")
+
+"""
+## Preprocessing with `FeatureSpace` as part of the tf.data pipeline
+
+We will opt to use our component asynchronously by making it part of the tf.data
+pipeline, as noted in the
+[previous guide](https://keras.io/examples/structured_data/structured_data_classification_with_feature_space/).
+This enables asynchronous parallel preprocessing of the data on CPU before it
+hits the model. Usually, this is almost always the right thing to do during training.
+
+Let's create a training and validation dataset of preprocessed batches:
+"""
+
+preprocessed_train_ds = train_ds.map(
+ lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
+).prefetch(tf.data.AUTOTUNE)
+
+preprocessed_valid_ds = valid_ds.map(
+ lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
+).prefetch(tf.data.AUTOTUNE)
+
+"""
+## Model
+
+We will take advantage of our `FeatureSpace` component to build the model. Since we want
+the model to be compatible with our preprocessing function, let's use the `FeatureSpace`
+feature map as the input of our model.
+"""
+
+encoded_features = feature_space.get_encoded_features()
+print(encoded_features)
+
+"""
+This model is intentionally trivial and only for demonstration purposes, so don't pay
+too much attention to the architecture.
+"""
+
+x = keras.layers.Dense(64, activation="relu")(encoded_features)
+x = keras.layers.Dropout(0.5)(x)
+output = keras.layers.Dense(1, activation="sigmoid")(x)
+
+model = keras.Model(inputs=encoded_features, outputs=output)
+model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
+
+"""
+## Training
+
+Let's train our model for 10 epochs. Note that feature preprocessing is happening as part
+of the tf.data pipeline, not as part of the model.
+"""
+
+model.fit(
+ preprocessed_train_ds, validation_data=preprocessed_valid_ds, epochs=10, verbose=2
+)
+
+"""
+## Inference on new data with the end-to-end model
+
+Now, we can build our inference model (which includes the `FeatureSpace`) to make
+predictions based on dicts of raw feature values, as follows:
+"""
+
+"""
+### Loading the `FeatureSpace`
+
+First, let's load the `FeatureSpace` that we saved a few moments ago. This can be quite
+handy if you train a model but want to do inference at a different time, possibly using
+a different device or environment.
+"""
+
+loaded_feature_space = keras.saving.load_model("myfeaturespace.keras")
+
+"""
+### Building the inference end-to-end model
+
+To build the inference model we need both the feature input map and the preprocessing
+encoded Keras tensors.
+"""
+
+dict_inputs = loaded_feature_space.get_inputs()
+encoded_features = loaded_feature_space.get_encoded_features()
+print(encoded_features)
+
+print(dict_inputs)
+
+outputs = model(encoded_features)
+inference_model = keras.Model(inputs=dict_inputs, outputs=outputs)
+
+sample = {
+ "age": 30,
+ "job": "blue-collar",
+ "marital": "married",
+ "education": "basic.9y",
+ "default": "no",
+ "housing": "yes",
+ "loan": "no",
+ "contact": "cellular",
+ "month": "may",
+ "day_of_week": "fri",
+ "campaign": 2,
+ "pdays": 999,
+ "previous": 0,
+ "poutcome": "nonexistent",
+ "emp.var.rate": -1.8,
+ "cons.price.idx": 92.893,
+ "cons.conf.idx": -46.2,
+ "euribor3m": 1.313,
+ "nr.employed": 5099.1,
+ "previously_contacted": 0,
+}
+
+input_dict = {
+ name: keras.ops.convert_to_tensor([value]) for name, value in sample.items()
+}
+predictions = inference_model.predict(input_dict)
+
+print(
+    f"This particular client has a {100 * predictions[0][0]:.2f}% probability "
+    "of subscribing to a term deposit, as evaluated by our model."
+)
diff --git a/knowledge_base/structured_data/imbalanced_classification.py b/knowledge_base/structured_data/imbalanced_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..663e0024c324c2c676c5c3281b781427cca005ec
--- /dev/null
+++ b/knowledge_base/structured_data/imbalanced_classification.py
@@ -0,0 +1,149 @@
+"""
+Title: Imbalanced classification: credit card fraud detection
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2019/05/28
+Last modified: 2020/04/17
+Description: Demonstration of how to handle highly imbalanced classification problems.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example looks at the
+[Kaggle Credit Card Fraud Detection](https://www.kaggle.com/mlg-ulb/creditcardfraud/)
+dataset to demonstrate how
+to train a classification model on data with highly imbalanced classes.
+"""
+
+"""
+## First, vectorize the CSV data
+"""
+
+import csv
+import numpy as np
+
+# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/
+fname = "/Users/fchollet/Downloads/creditcard.csv"
+
+all_features = []
+all_targets = []
+with open(fname) as f:
+ for i, line in enumerate(f):
+ if i == 0:
+ print("HEADER:", line.strip())
+ continue # Skip header
+ fields = line.strip().split(",")
+ all_features.append([float(v.replace('"', "")) for v in fields[:-1]])
+ all_targets.append([int(fields[-1].replace('"', ""))])
+ if i == 1:
+ print("EXAMPLE FEATURES:", all_features[-1])
+
+features = np.array(all_features, dtype="float32")
+targets = np.array(all_targets, dtype="uint8")
+print("features.shape:", features.shape)
+print("targets.shape:", targets.shape)
+
+"""
+## Prepare a validation set
+"""
+
+num_val_samples = int(len(features) * 0.2)
+train_features = features[:-num_val_samples]
+train_targets = targets[:-num_val_samples]
+val_features = features[-num_val_samples:]
+val_targets = targets[-num_val_samples:]
+
+print("Number of training samples:", len(train_features))
+print("Number of validation samples:", len(val_features))
+
+"""
+## Analyze class imbalance in the targets
+"""
+
+counts = np.bincount(train_targets[:, 0])
+print(
+ "Number of positive samples in training data: {} ({:.2f}% of total)".format(
+ counts[1], 100 * float(counts[1]) / len(train_targets)
+ )
+)
+
+weight_for_0 = 1.0 / counts[0]
+weight_for_1 = 1.0 / counts[1]
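+
+"""
+With fraud making up only ~0.17% of the samples, these weights up-weight each
+fraudulent transaction by a factor of several hundred relative to a legitimate one, so
+both classes contribute comparably to the total loss. A quick illustration (added here
+for clarity):
+"""
+
+print(f"Weight for class 0 (legitimate): {weight_for_0:.2e}")
+print(f"Weight for class 1 (fraudulent): {weight_for_1:.2e}")
+print(f"Relative up-weighting of the positive class: {weight_for_1 / weight_for_0:.0f}x")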
+
+"""
+## Normalize the data using training set statistics
+"""
+
+mean = np.mean(train_features, axis=0)
+train_features -= mean
+val_features -= mean
+std = np.std(train_features, axis=0)
+train_features /= std
+val_features /= std
+
+"""
+## Build a binary classification model
+"""
+
+import keras
+
+model = keras.Sequential(
+ [
+ keras.Input(shape=train_features.shape[1:]),
+ keras.layers.Dense(256, activation="relu"),
+ keras.layers.Dense(256, activation="relu"),
+ keras.layers.Dropout(0.3),
+ keras.layers.Dense(256, activation="relu"),
+ keras.layers.Dropout(0.3),
+ keras.layers.Dense(1, activation="sigmoid"),
+ ]
+)
+model.summary()
+
+"""
+## Train the model with `class_weight` argument
+"""
+
+metrics = [
+ keras.metrics.FalseNegatives(name="fn"),
+ keras.metrics.FalsePositives(name="fp"),
+ keras.metrics.TrueNegatives(name="tn"),
+ keras.metrics.TruePositives(name="tp"),
+ keras.metrics.Precision(name="precision"),
+ keras.metrics.Recall(name="recall"),
+]
+
+model.compile(
+ optimizer=keras.optimizers.Adam(1e-2), loss="binary_crossentropy", metrics=metrics
+)
+
+callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
+class_weight = {0: weight_for_0, 1: weight_for_1}
+
+model.fit(
+ train_features,
+ train_targets,
+ batch_size=2048,
+ epochs=30,
+ verbose=2,
+ callbacks=callbacks,
+ validation_data=(val_features, val_targets),
+ class_weight=class_weight,
+)
+
+"""
+## Conclusions
+
+At the end of training, out of 56,961 validation transactions, we are:
+
+- Correctly identifying 66 of them as fraudulent
+- Missing 9 fraudulent transactions
+- At the cost of incorrectly flagging 441 legitimate transactions
+
+In the real world, one would put an even higher weight on class 1,
+so as to reflect that False Negatives are more costly than False Positives.
+
+Next time your credit card gets declined in an online purchase -- this is why.
+
+"""
diff --git a/knowledge_base/structured_data/movielens_recommendations_transformers.py b/knowledge_base/structured_data/movielens_recommendations_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e20bed0cbd51556c926e5717e132efc651e947
--- /dev/null
+++ b/knowledge_base/structured_data/movielens_recommendations_transformers.py
@@ -0,0 +1,578 @@
+"""
+Title: A Transformer-based recommendation system
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2020/12/30
+Last modified: 2025/01/27
+Description: Rating prediction using the Behavior Sequence Transformer (BST) model on the Movielens dataset.
+Accelerator: GPU
+Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+This example demonstrates the [Behavior Sequence Transformer (BST)](https://arxiv.org/abs/1905.06874)
+model, by Qiwei Chen et al., using the [Movielens dataset](https://grouplens.org/datasets/movielens/).
+The BST model leverages the sequential behaviour of the users in watching and rating movies,
+as well as user profile and movie features, to predict the rating the user would give to a target movie.
+
+More precisely, the BST model aims to predict the rating of a target movie by accepting
+the following inputs:
+
+1. A fixed-length *sequence* of `movie_ids` watched by a user.
+2. A fixed-length *sequence* of the `ratings` for the movies watched by a user.
+3. A *set* of user features, including `user_id`, `sex`, `occupation`, and `age_group`.
+4. A *set* of `genres` for each movie in the input sequence and the target movie.
+5. A `target_movie_id` for which to predict the rating.
+
+This example modifies the original BST model in the following ways:
+
+1. We incorporate the movie features (genres) into the processing of the embedding of each
+movie of the input sequence and the target movie, rather than treating them as "other features"
+outside the transformer layer.
+2. We utilize the ratings of movies in the input sequence, along with their positions
+in the sequence, to update them before feeding them into the self-attention layer.
+
+
+Note that this example should be run with TensorFlow 2.4 or higher.
+"""
+
+"""
+## The dataset
+
+We use the [1M version of the Movielens dataset](https://grouplens.org/datasets/movielens/1m/).
+The dataset includes around 1 million ratings from 6000 users on 4000 movies,
+along with some user features and movie genres. In addition, the timestamp of each user-movie
+rating is provided, which allows creating sequences of movie ratings for each user,
+as expected by the BST model.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # or torch, or tensorflow
+
+import math
+from zipfile import ZipFile
+from urllib.request import urlretrieve
+import numpy as np
+import pandas as pd
+
+import keras
+from keras import layers, ops
+from keras.layers import StringLookup
+
+"""
+## Prepare the data
+
+### Download and prepare the DataFrames
+
+First, let's download the movielens data.
+
+The downloaded folder will contain three data files: `users.dat`, `movies.dat`,
+and `ratings.dat`.
+"""
+
+urlretrieve("http://files.grouplens.org/datasets/movielens/ml-1m.zip", "movielens.zip")
+ZipFile("movielens.zip", "r").extractall()
+
+"""
+Then, we load the data into pandas DataFrames with their proper column names.
+"""
+
+users = pd.read_csv(
+ "ml-1m/users.dat",
+ sep="::",
+ names=["user_id", "sex", "age_group", "occupation", "zip_code"],
+ encoding="ISO-8859-1",
+ engine="python",
+)
+
+ratings = pd.read_csv(
+ "ml-1m/ratings.dat",
+ sep="::",
+ names=["user_id", "movie_id", "rating", "unix_timestamp"],
+ encoding="ISO-8859-1",
+ engine="python",
+)
+
+movies = pd.read_csv(
+ "ml-1m/movies.dat",
+ sep="::",
+ names=["movie_id", "title", "genres"],
+ encoding="ISO-8859-1",
+ engine="python",
+)
+
+"""
+Here, we do some simple data processing to fix the data types of the columns.
+"""
+
+users["user_id"] = users["user_id"].apply(lambda x: f"user_{x}")
+users["age_group"] = users["age_group"].apply(lambda x: f"group_{x}")
+users["occupation"] = users["occupation"].apply(lambda x: f"occupation_{x}")
+
+movies["movie_id"] = movies["movie_id"].apply(lambda x: f"movie_{x}")
+
+ratings["movie_id"] = ratings["movie_id"].apply(lambda x: f"movie_{x}")
+ratings["user_id"] = ratings["user_id"].apply(lambda x: f"user_{x}")
+ratings["rating"] = ratings["rating"].apply(lambda x: float(x))
+
+"""
+Each movie has multiple genres. We split them into separate columns in the `movies`
+DataFrame.
+"""
+
+genres = ["Action", "Adventure", "Animation", "Children's", "Comedy", "Crime"]
+genres += ["Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical"]
+genres += ["Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western"]
+
+for genre in genres:
+ movies[genre] = movies["genres"].apply(
+ lambda values: int(genre in values.split("|"))
+ )
+
+
+"""
+### Transform the movie ratings data into sequences
+
+First, let's sort the ratings data using the `unix_timestamp`, and then group the
+`movie_id` values and the `rating` values by `user_id`.
+
+The output DataFrame will have a record for each `user_id`, with two ordered lists
+(sorted by rating datetime): the movies they have rated, and their ratings of these movies.
+"""
+
+ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id")
+
+ratings_data = pd.DataFrame(
+ data={
+ "user_id": list(ratings_group.groups.keys()),
+ "movie_ids": list(ratings_group.movie_id.apply(list)),
+ "ratings": list(ratings_group.rating.apply(list)),
+ "timestamps": list(ratings_group.unix_timestamp.apply(list)),
+ }
+)
+
+
+"""
+Now, let's split the `movie_ids` list into a set of sequences of a fixed length.
+We do the same for the `ratings`. Set the `sequence_length` variable to change the length
+of the input sequence to the model. You can also change the `step_size` to control the
+number of sequences to generate for each user.
+"""
+
+sequence_length = 4
+step_size = 2
+
+
+def create_sequences(values, window_size, step_size):
+ sequences = []
+ start_index = 0
+ while True:
+ end_index = start_index + window_size
+ seq = values[start_index:end_index]
+        if len(seq) < window_size:
+            # Not enough items left for a full window: right-align the final window
+            # over the most recent items (if enough exist) and stop.
+            seq = values[-window_size:]
+            if len(seq) == window_size:
+                sequences.append(seq)
+            break
+ sequences.append(seq)
+ start_index += step_size
+ return sequences
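+
+
+"""
+A quick illustration of the windowing behavior (an illustrative call, not part of the
+data pipeline): the final window is right-aligned, so the most recent ratings are never
+dropped.
+"""
+
+print(create_sequences(list(range(7)), window_size=4, step_size=2))
+# [[0, 1, 2, 3], [2, 3, 4, 5], [3, 4, 5, 6]]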
+
+
+ratings_data.movie_ids = ratings_data.movie_ids.apply(
+ lambda ids: create_sequences(ids, sequence_length, step_size)
+)
+
+ratings_data.ratings = ratings_data.ratings.apply(
+ lambda ids: create_sequences(ids, sequence_length, step_size)
+)
+
+del ratings_data["timestamps"]
+
+"""
+After that, we process the output to have each sequence in a separate record in
+the DataFrame. In addition, we join the user features with the ratings data.
+"""
+
+ratings_data_movies = ratings_data[["user_id", "movie_ids"]].explode(
+ "movie_ids", ignore_index=True
+)
+ratings_data_rating = ratings_data[["ratings"]].explode("ratings", ignore_index=True)
+ratings_data_transformed = pd.concat([ratings_data_movies, ratings_data_rating], axis=1)
+ratings_data_transformed = ratings_data_transformed.join(
+ users.set_index("user_id"), on="user_id"
+)
+ratings_data_transformed.movie_ids = ratings_data_transformed.movie_ids.apply(
+ lambda x: ",".join(x)
+)
+ratings_data_transformed.ratings = ratings_data_transformed.ratings.apply(
+ lambda x: ",".join([str(v) for v in x])
+)
+
+del ratings_data_transformed["zip_code"]
+
+ratings_data_transformed.rename(
+ columns={"movie_ids": "sequence_movie_ids", "ratings": "sequence_ratings"},
+ inplace=True,
+)
+
+"""
+With `sequence_length` of 4 and `step_size` of 2, we end up with 498,623 sequences.
+
+Finally, we split the data into training and testing splits, with 85% and 15% of
+the instances, respectively, and store them to CSV files.
+"""
+
+random_selection = np.random.rand(len(ratings_data_transformed.index)) <= 0.85
+train_data = ratings_data_transformed[random_selection]
+test_data = ratings_data_transformed[~random_selection]
+
+train_data.to_csv("train_data.csv", index=False, sep="|", header=False)
+test_data.to_csv("test_data.csv", index=False, sep="|", header=False)
+
+"""
+## Define metadata
+"""
+
+CSV_HEADER = list(ratings_data_transformed.columns)
+
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ "user_id": list(users.user_id.unique()),
+ "movie_id": list(movies.movie_id.unique()),
+ "sex": list(users.sex.unique()),
+ "age_group": list(users.age_group.unique()),
+ "occupation": list(users.occupation.unique()),
+}
+
+USER_FEATURES = ["sex", "age_group", "occupation"]
+
+MOVIE_FEATURES = ["genres"]
+
+
+"""
+## Encode input features
+
+The `encode_input_features` function works as follows:
+
+1. Each categorical user feature is encoded using `layers.Embedding`, with embedding
+dimension equal to the square root of the vocabulary size of the feature.
+The embeddings of these features are concatenated to form a single input tensor.
+
+2. Each movie in the movie sequence and the target movie is encoded using `layers.Embedding`,
+where the dimension size is the square root of the number of movies (see the sizing note below).
+
+3. A multi-hot genres vector for each movie is concatenated with its embedding vector,
+and processed using a non-linear `layers.Dense` to output a vector of the same movie
+embedding dimensions.
+
+4. A positional embedding is added to each movie embedding in the sequence, and then
+multiplied by its rating from the ratings sequence.
+
+5. The target movie embedding is concatenated to the sequence movie embeddings, producing
+a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected
+by the attention layer for the transformer architecture.
+
+6. The method returns a tuple of two elements: `encoded_transformer_features` and
+`encoded_other_features`.
+"""
+
+# Required for tf.data.Dataset
+import tensorflow as tf
+
+
+def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):
+
+ def process(features):
+ movie_ids_string = features["sequence_movie_ids"]
+ sequence_movie_ids = tf.strings.split(movie_ids_string, ",").to_tensor()
+ # The last movie id in the sequence is the target movie.
+ features["target_movie_id"] = sequence_movie_ids[:, -1]
+ features["sequence_movie_ids"] = sequence_movie_ids[:, :-1]
+ # Sequence ratings
+ ratings_string = features["sequence_ratings"]
+ sequence_ratings = tf.strings.to_number(
+ tf.strings.split(ratings_string, ","), tf.dtypes.float32
+ ).to_tensor()
+ # The last rating in the sequence is the target for the model to predict.
+ target = sequence_ratings[:, -1]
+ features["sequence_ratings"] = sequence_ratings[:, :-1]
+
+ def encoding_helper(feature_name):
+
+            # These are `target_movie_id` and `sequence_movie_ids`; they share the same
+            # vocabulary as `movie_id`.
+ if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
+ index_lookup = StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ # Convert the string input values into integer indices.
+ value_index = index_lookup(features[feature_name])
+ features[feature_name] = value_index
+            elif feature_name != "movie_id":
+                # movie_id itself is not an input feature, so it is skipped here;
+                # it was only needed above for its vocabulary.
+                vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+                index_lookup = StringLookup(
+                    vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+                )
+                # Convert the string input values into integer indices.
+                value_index = index_lookup(features[feature_name])
+                features[feature_name] = value_index
+
+ # Encode the user features
+ for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ encoding_helper(feature_name)
+        # Encode target_movie_id (it is an input feature; the label is the last rating).
+ encoding_helper("target_movie_id")
+ # Encoding sequence movie_ids.
+ encoding_helper("sequence_movie_ids")
+ return dict(features), target
+
+ dataset = tf.data.experimental.make_csv_dataset(
+ csv_file_path,
+ batch_size=batch_size,
+ column_names=CSV_HEADER,
+ num_epochs=1,
+ header=False,
+ field_delim="|",
+ shuffle=shuffle,
+ ).map(process)
+ return dataset
+
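+"""
+As a quick sanity check (a minimal sketch, reusing the `train_data.csv` file written
+above), let's pull one small batch and inspect the decoded tensors -- the sequences
+should have length `sequence_length - 1`:
+"""
+
+for features, target in get_dataset_from_csv("train_data.csv", batch_size=2).take(1):
+    print("sequence_movie_ids:", features["sequence_movie_ids"].shape)
+    print("sequence_ratings:", features["sequence_ratings"].shape)
+    print("target:", target.shape)
+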
+
+def encode_input_features(
+ inputs,
+ include_user_id,
+ include_user_features,
+ include_movie_features,
+):
+ encoded_transformer_features = []
+ encoded_other_features = []
+
+ other_feature_names = []
+ if include_user_id:
+ other_feature_names.append("user_id")
+ if include_user_features:
+ other_feature_names.extend(USER_FEATURES)
+
+ ## Encode user features
+ for feature_name in other_feature_names:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+ # Compute embedding dimensions
+ embedding_dims = int(math.sqrt(len(vocabulary)))
+ # Create an embedding layer with the specified dimensions.
+ embedding_encoder = layers.Embedding(
+ input_dim=len(vocabulary),
+ output_dim=embedding_dims,
+ name=f"{feature_name}_embedding",
+ )
+ # Convert the index values to embedding representations.
+ encoded_other_features.append(embedding_encoder(inputs[feature_name]))
+
+ ## Create a single embedding vector for the user features
+ if len(encoded_other_features) > 1:
+ encoded_other_features = layers.concatenate(encoded_other_features)
+ elif len(encoded_other_features) == 1:
+ encoded_other_features = encoded_other_features[0]
+ else:
+ encoded_other_features = None
+
+ ## Create a movie embedding encoder
+ movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY["movie_id"]
+ movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))
+ # Create an embedding layer with the specified dimensions.
+ movie_embedding_encoder = layers.Embedding(
+ input_dim=len(movie_vocabulary),
+ output_dim=movie_embedding_dims,
+ name=f"movie_embedding",
+ )
+ # Create a vector lookup for movie genres.
+ genre_vectors = movies[genres].to_numpy()
+ movie_genres_lookup = layers.Embedding(
+ input_dim=genre_vectors.shape[0],
+ output_dim=genre_vectors.shape[1],
+ embeddings_initializer=keras.initializers.Constant(genre_vectors),
+ trainable=False,
+ name="genres_vector",
+ )
+ # Create a processing layer for genres.
+ movie_embedding_processor = layers.Dense(
+ units=movie_embedding_dims,
+ activation="relu",
+ name="process_movie_embedding_with_genres",
+ )
+
+ ## Define a function to encode a given movie id.
+ def encode_movie(movie_id):
+ # Convert the string input values into integer indices.
+ movie_embedding = movie_embedding_encoder(movie_id)
+ encoded_movie = movie_embedding
+ if include_movie_features:
+ movie_genres_vector = movie_genres_lookup(movie_id)
+ encoded_movie = movie_embedding_processor(
+ layers.concatenate([movie_embedding, movie_genres_vector])
+ )
+ return encoded_movie
+
+ ## Encoding target_movie_id
+ target_movie_id = inputs["target_movie_id"]
+ encoded_target_movie = encode_movie(target_movie_id)
+
+ ## Encoding sequence movie_ids.
+    sequence_movie_ids = inputs["sequence_movie_ids"]
+    encoded_sequence_movies = encode_movie(sequence_movie_ids)
+ # Create positional embedding.
+ position_embedding_encoder = layers.Embedding(
+ input_dim=sequence_length,
+ output_dim=movie_embedding_dims,
+ name="position_embedding",
+ )
+ positions = ops.arange(start=0, stop=sequence_length - 1, step=1)
+    encoded_positions = position_embedding_encoder(positions)
+ # Retrieve sequence ratings to incorporate them into the encoding of the movie.
+ sequence_ratings = inputs["sequence_ratings"]
+ sequence_ratings = ops.expand_dims(sequence_ratings, -1)
+ # Add the positional encoding to the movie encodings and multiply them by rating.
+    encoded_sequence_movies_with_position_and_rating = layers.Multiply()(
+        [(encoded_sequence_movies + encoded_positions), sequence_ratings]
+    )
+
+ # Construct the transformer inputs.
+ for i in range(sequence_length - 1):
+        feature = encoded_sequence_movies_with_position_and_rating[:, i, ...]
+ feature = ops.expand_dims(feature, 1)
+ encoded_transformer_features.append(feature)
+ encoded_transformer_features.append(encoded_target_movie)
+ encoded_transformer_features = layers.concatenate(
+ encoded_transformer_features, axis=1
+ )
+ return encoded_transformer_features, encoded_other_features
+
+
+"""
+## Create model inputs
+"""
+
+
+def create_model_inputs():
+ return {
+ "user_id": keras.Input(name="user_id", shape=(1,), dtype="int32"),
+ "sequence_movie_ids": keras.Input(
+ name="sequence_movie_ids", shape=(sequence_length - 1,), dtype="int32"
+ ),
+ "target_movie_id": keras.Input(
+ name="target_movie_id", shape=(1,), dtype="int32"
+ ),
+ "sequence_ratings": keras.Input(
+ name="sequence_ratings", shape=(sequence_length - 1,), dtype="float32"
+ ),
+ "sex": keras.Input(name="sex", shape=(1,), dtype="int32"),
+ "age_group": keras.Input(name="age_group", shape=(1,), dtype="int32"),
+ "occupation": keras.Input(name="occupation", shape=(1,), dtype="int32"),
+ }
+
+
+"""
+## Create a BST model
+"""
+
+include_user_id = False
+include_user_features = False
+include_movie_features = False
+
+hidden_units = [256, 128]
+dropout_rate = 0.1
+num_heads = 3
+
+
+def create_model():
+
+ inputs = create_model_inputs()
+ transformer_features, other_features = encode_input_features(
+ inputs, include_user_id, include_user_features, include_movie_features
+ )
+ # Create a multi-headed attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate
+ )(transformer_features, transformer_features)
+
+ # Transformer block.
+ attention_output = layers.Dropout(dropout_rate)(attention_output)
+ x1 = layers.Add()([transformer_features, attention_output])
+ x1 = layers.LayerNormalization()(x1)
+ x2 = layers.LeakyReLU()(x1)
+ x2 = layers.Dense(units=x2.shape[-1])(x2)
+ x2 = layers.Dropout(dropout_rate)(x2)
+ transformer_features = layers.Add()([x1, x2])
+ transformer_features = layers.LayerNormalization()(transformer_features)
+ features = layers.Flatten()(transformer_features)
+
+    # Include the other features, if any.
+ if other_features is not None:
+ features = layers.concatenate(
+ [features, layers.Reshape([other_features.shape[-1]])(other_features)]
+ )
+
+ # Fully-connected layers.
+ for num_units in hidden_units:
+ features = layers.Dense(num_units)(features)
+ features = layers.BatchNormalization()(features)
+ features = layers.LeakyReLU()(features)
+ features = layers.Dropout(dropout_rate)(features)
+ outputs = layers.Dense(units=1)(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+model = create_model()
+
+"""
+## Run training and evaluation experiment
+"""
+
+# Compile the model.
+model.compile(
+ optimizer=keras.optimizers.Adagrad(learning_rate=0.01),
+ loss=keras.losses.MeanSquaredError(),
+ metrics=[keras.metrics.MeanAbsoluteError()],
+)
+
+# Read the training data.
+
+train_dataset = get_dataset_from_csv("train_data.csv", batch_size=265, shuffle=True)
+
+# Fit the model with the training data.
+model.fit(train_dataset, epochs=2)
+
+# Read the test data.
+test_dataset = get_dataset_from_csv("test_data.csv", batch_size=265)
+
+# Evaluate the model on the test data.
+_, mae = model.evaluate(test_dataset, verbose=0)
+print(f"Test MAE: {round(mae, 3)}")
+
+"""
+You should achieve a Mean Absolute Error (MAE) at or around 0.7 on the test data.
+"""
+
+"""
+## Conclusion
+
+The BST model uses the Transformer layer in its architecture to capture the sequential signals underlying
+users' behavior sequences for recommendation.
+
+You can try training this model with different configurations, for example, by increasing
+the input sequence length and training the model for a larger number of epochs. In addition,
+you can try including other features like movie release year and customer
+zip code, as well as cross features like sex X genre, as sketched below.
+"""
diff --git a/knowledge_base/structured_data/structured_data_classification_from_scratch.py b/knowledge_base/structured_data/structured_data_classification_from_scratch.py
new file mode 100644
index 0000000000000000000000000000000000000000..818c1f7bec419d2182021ac70545d50f775e8a90
--- /dev/null
+++ b/knowledge_base/structured_data/structured_data_classification_from_scratch.py
@@ -0,0 +1,428 @@
+"""
+Title: Structured data classification from scratch
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/06/09
+Last modified: 2020/06/09
+Description: Binary classification of structured data including numerical and categorical features.
+Accelerator: GPU
+Made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to do structured data classification, starting from a raw
+CSV file. Our data includes both numerical and categorical features. We will use Keras
+preprocessing layers to normalize the numerical features and vectorize the categorical
+ones.
+
+Note that this example should be run with TensorFlow 2.5 or higher.
+
+### The dataset
+
+[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the
+Cleveland Clinic Foundation for Heart Disease.
+It's a CSV file with 303 rows. Each row contains information about a patient (a
+**sample**), and each column describes an attribute of the patient (a **feature**). We
+use the features to predict whether a patient has a heart disease (**binary
+classification**).
+
+Here's the description of each feature:
+
+Column| Description| Feature Type
+------------|--------------------|----------------------
+Age | Age in years | Numerical
+Sex | (1 = male; 0 = female) | Categorical
+CP | Chest pain type (0, 1, 2, 3, 4) | Categorical
+Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical
+Chol | Serum cholesterol in mg/dl | Numerical
+FBS | fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical
+RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical
+Thalach | Maximum heart rate achieved | Numerical
+Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical
+Oldpeak | ST depression induced by exercise relative to rest | Numerical
+Slope | Slope of the peak exercise ST segment | Numerical
+CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical
+Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical
+Target | Diagnosis of heart disease (1 = true; 0 = false) | Target
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "torch" # or torch, or tensorflow
+
+import pandas as pd
+import keras
+from keras import layers
+
+"""
+## Preparing the data
+
+Let's download the data and load it into a Pandas dataframe:
+"""
+
+file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
+dataframe = pd.read_csv(file_url)
+
+"""
+The dataset includes 303 samples with 14 columns per sample (13 features, plus the target
+label):
+"""
+
+dataframe.shape
+
+"""
+Here's a preview of a few samples:
+"""
+
+dataframe.head()
+
+"""
+The last column, "target", indicates whether the patient has a heart disease (1) or not
+(0).
+
+Let's split the data into a training and validation set:
+"""
+
+val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
+train_dataframe = dataframe.drop(val_dataframe.index)
+
+print(
+ f"Using {len(train_dataframe)} samples for training "
+ f"and {len(val_dataframe)} for validation"
+)
+
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading and
+parsing the data into input features, and encoding the input features with respect
+to their types.
+"""
+
+COLUMN_NAMES = [
+ "age",
+ "sex",
+ "cp",
+ "trestbps",
+ "chol",
+ "fbs",
+ "restecg",
+ "thalach",
+ "exang",
+ "oldpeak",
+ "slope",
+ "ca",
+ "thal",
+ "target",
+]
+# Target feature name.
+TARGET_FEATURE_NAME = "target"
+# Numeric feature names.
+NUMERIC_FEATURE_NAMES = ["age", "trestbps", "thalach", "oldpeak", "slope", "chol"]
+# Categorical features and their vocabulary lists.
+# Note that integer-valued categories are kept as integers, while all other
+# values are cast to strings, so each lookup layer gets values of a single type.
+
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ feature_name: sorted(
+ [
+            # Integer categorical values must be int, and string values must be str
+ value if dataframe[feature_name].dtype == "int64" else str(value)
+ for value in list(dataframe[feature_name].unique())
+ ]
+ )
+ for feature_name in COLUMN_NAMES
+ if feature_name not in list(NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME])
+}
+# All features names.
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
+ CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
+)
+
+
+"""
+## Feature preprocessing with Keras layers
+
+
+The following features are categorical features encoded as integers:
+
+- `sex`
+- `cp`
+- `fbs`
+- `restecg`
+- `exang`
+- `ca`
+
+We will encode these features using **one-hot encoding**. We have two options
+here:
+
+ - Use `CategoryEncoding()`, which requires knowing the range of input values
+ and will error on input outside the range.
+ - Use `IntegerLookup()` which will build a lookup table for inputs and reserve
+ an output index for unknown input values (see the small demo below).
+
+For this example, we want a simple solution that will handle out of range inputs
+at inference, so we will use `IntegerLookup()`.
+
+We also have a categorical feature encoded as a string: `thal`. We will create an
+index of all possible values and encode the outputs using the `StringLookup()` layer.
+
+Finally, the following feature are continuous numerical features:
+
+- `age`
+- `trestbps`
+- `chol`
+- `thalach`
+- `oldpeak`
+- `slope`
+
+For each of these features, we will use a `Normalization()` layer to make sure the mean
+of each feature is 0 and its standard deviation is 1.
+
+Below, we define two utility functions to do the operations:
+
+- `encode_numerical_feature` to apply featurewise normalization to numerical features.
+- `encode_categorical` to one-hot encode string or integer categorical features.
+"""
+
+# Tensorflow required for tf.data.Dataset
+import tensorflow as tf
+
+
+# We process the categorical elements of our datasets here, converting them to indices
+# ahead of model training, since only the TensorFlow backend supports string tensors.
+def encode_categorical(features, target):
+ for feature_name in features:
+ if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ lookup_class = (
+ layers.StringLookup
+ if features[feature_name].dtype == "string"
+ else layers.IntegerLookup
+ )
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ index = lookup_class(
+ vocabulary=vocabulary,
+ mask_token=None,
+ num_oov_indices=0,
+ output_mode="binary",
+ )
+ # Convert the string input values into integer indices.
+ value_index = index(features[feature_name])
+ features[feature_name] = value_index
+
+    # Convert features from an OrderedDict to a plain dict to match the model inputs.
+ return dict(features), target
+
+
+def encode_numerical_feature(feature, name, dataset):
+ # Create a Normalization layer for our feature
+ normalizer = layers.Normalization()
+ # Prepare a Dataset that only yields our feature
+ feature_ds = dataset.map(lambda x, y: x[name])
+ feature_ds = feature_ds.map(lambda x: tf.expand_dims(x, -1))
+ # Learn the statistics of the data
+ normalizer.adapt(feature_ds)
+ # Normalize the input feature
+ encoded_feature = normalizer(feature)
+ return encoded_feature
+
+
+"""
+Let's generate `tf.data.Dataset` objects for each dataframe:
+"""
+
+
+def dataframe_to_dataset(dataframe):
+ dataframe = dataframe.copy()
+ labels = dataframe.pop("target")
+ ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels)).map(
+ encode_categorical
+ )
+ ds = ds.shuffle(buffer_size=len(dataframe))
+ return ds
+
+
+train_ds = dataframe_to_dataset(train_dataframe)
+val_ds = dataframe_to_dataset(val_dataframe)
+
+"""
+Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
+and `target` is the value `0` or `1`:
+"""
+
+for x, y in train_ds.take(1):
+ print("Input:", x)
+ print("Target:", y)
+
+"""
+Let's batch the datasets:
+"""
+
+train_ds = train_ds.batch(32)
+val_ds = val_ds.batch(32)
+
+
+"""
+## Build a model
+
+With this done, we can create our end-to-end model:
+"""
+
+
+# Categorical features have different shapes after encoding, depending on the
+# vocabulary (number of unique values) of each feature. We create the inputs accordingly,
+# to match the data elements generated by the tf.data.Dataset after pre-processing.
+def create_model_inputs():
+ inputs = {}
+
+    # This is a helper function for creating categorical feature inputs
+ def create_input_helper(feature_name):
+ num_categories = len(CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name])
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(num_categories,), dtype="int64"
+ )
+ return inputs
+
+ for feature_name in FEATURE_NAMES:
+ if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ # Categorical features
+ create_input_helper(feature_name)
+ else:
+            # Numerical features are float32 real numbers.
+            feature_input = layers.Input(name=feature_name, shape=(1,), dtype="float32")
+            # Normalize the numerical input using statistics learned from train_ds.
+ inputs[feature_name] = encode_numerical_feature(
+ feature_input, feature_name, train_ds
+ )
+ return inputs
+
+
+# This Layer defines the logic of the Model to perform the classification
+class Classifier(keras.layers.Layer):
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dense_1 = layers.Dense(32, activation="relu")
+ self.dropout = layers.Dropout(0.5)
+ self.dense_2 = layers.Dense(1, activation="sigmoid")
+
+ def call(self, inputs):
+ all_features = layers.concatenate(list(inputs.values()))
+ x = self.dense_1(all_features)
+ x = self.dropout(x)
+ output = self.dense_2(x)
+ return output
+
+    # Suppress build warnings
+ def build(self, input_shape):
+ self.built = True
+
+
+# Create the Classifier model
+def create_model():
+ all_inputs = create_model_inputs()
+ output = Classifier()(all_inputs)
+ model = keras.Model(all_inputs, output)
+ return model
+
+
+model = create_model()
+model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
+
+"""
+Let's visualize our connectivity graph:
+"""
+
+# `rankdir='LR'` is to make the graph horizontal.
+keras.utils.plot_model(model, show_shapes=True, rankdir="LR")
+
+"""
+## Train the model
+"""
+
+model.fit(train_ds, epochs=50, validation_data=val_ds)
+
+
+"""
+We quickly get to 80% validation accuracy.
+"""
+
+"""
+## Inference on new data
+
+To get a prediction for a new sample, you can simply call `model.predict()`. There are
+just two things you need to do:
+
+1. Wrap scalars into a list so as to have a batch dimension (models only process batches
+of data, not single samples).
+2. Call `convert_to_tensor` on each feature.
+"""
+
+sample = {
+ "age": 60,
+ "sex": 1,
+ "cp": 1,
+ "trestbps": 145,
+ "chol": 233,
+ "fbs": 1,
+ "restecg": 2,
+ "thalach": 150,
+ "exang": 0,
+ "oldpeak": 2.3,
+ "slope": 3,
+ "ca": 0,
+ "thal": "fixed",
+}
+
+
+# Given a categorical feature's vocabulary (`cat`) and one of its values (`cat_value`),
+# we return the one-hot encoding of that value.
+def get_cat_encoding(cat, cat_value):
+    # Create a list of zeros with the same length as the vocabulary
+    encoding = [0] * len(cat)
+    # Find the index of cat_value in the vocabulary and set that position to 1
+ if cat_value in cat:
+ encoding[cat.index(cat_value)] = 1
+ return encoding
+
+
+for name, value in sample.items():
+ if name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
+ sample.update(
+ {
+ name: get_cat_encoding(
+ CATEGORICAL_FEATURES_WITH_VOCABULARY[name], sample[name]
+ )
+ }
+ )
+# Convert inputs to tensors
+input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
+predictions = model.predict(input_dict)
+
+print(
+ f"This particular patient had a {100 * predictions[0][0]:.1f} "
+ "percent probability of having a heart disease, "
+ "as evaluated by our model."
+)
+
+"""
+## Conclusions
+
+- The original model (the one that runs only on TensorFlow) converges quickly to around 80%
+validation accuracy, remains there for extended periods, and at times hits 85%.
+- The updated (backend-agnostic) model may fluctuate between 78% and 83%, at times hitting 86%
+validation accuracy, and also converges around 80%.
+
+"""
diff --git a/knowledge_base/structured_data/structured_data_classification_with_feature_space.py b/knowledge_base/structured_data/structured_data_classification_with_feature_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a521983891ee33b1df7cbb6452dcdb3599432fa
--- /dev/null
+++ b/knowledge_base/structured_data/structured_data_classification_with_feature_space.py
@@ -0,0 +1,383 @@
+"""
+Title: Structured data classification with FeatureSpace
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2022/11/09
+Last modified: 2022/11/09
+Description: Classify tabular data in a few lines of code.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to do structured data classification
+(also known as tabular data classification), starting from a raw
+CSV file. Our data includes numerical features, integer categorical features,
+and string categorical features.
+We will use the utility `keras.utils.FeatureSpace` to index,
+preprocess, and encode our features.
+
+The code is adapted from the example
+[Structured data classification from scratch](https://keras.io/examples/structured_data/structured_data_classification_from_scratch/).
+While the previous example managed its own low-level feature preprocessing and
+encoding with Keras preprocessing layers, in this example we
+delegate everything to `FeatureSpace`, making the workflow
+extremely quick and easy.
+
+### The dataset
+
+[Our dataset](https://archive.ics.uci.edu/ml/datasets/heart+Disease) is provided by the
+Cleveland Clinic Foundation for Heart Disease.
+It's a CSV file with 303 rows. Each row contains information about a patient (a
+**sample**), and each column describes an attribute of the patient (a **feature**). We
+use the features to predict whether a patient has a heart disease
+(**binary classification**).
+
+Here's the description of each feature:
+
+Column| Description| Feature Type
+------------|--------------------|----------------------
+Age | Age in years | Numerical
+Sex | (1 = male; 0 = female) | Categorical
+CP | Chest pain type (0, 1, 2, 3, 4) | Categorical
+Trestbpd | Resting blood pressure (in mm Hg on admission) | Numerical
+Chol | Serum cholesterol in mg/dl | Numerical
+FBS | fasting blood sugar in 120 mg/dl (1 = true; 0 = false) | Categorical
+RestECG | Resting electrocardiogram results (0, 1, 2) | Categorical
+Thalach | Maximum heart rate achieved | Numerical
+Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical
+Oldpeak | ST depression induced by exercise relative to rest | Numerical
+Slope | Slope of the peak exercise ST segment | Numerical
+CA | Number of major vessels (0-3) colored by fluoroscopy | Both numerical & categorical
+Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical
+Target | Diagnosis of heart disease (1 = true; 0 = false) | Target
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf
+import pandas as pd
+import keras
+from keras.utils import FeatureSpace
+
+"""
+## Preparing the data
+
+Let's download the data and load it into a Pandas dataframe:
+"""
+
+file_url = "http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
+dataframe = pd.read_csv(file_url)
+
+"""
+The dataset includes 303 samples with 14 columns per sample
+(13 features, plus the target label):
+"""
+
+print(dataframe.shape)
+
+"""
+Here's a preview of a few samples:
+"""
+
+dataframe.head()
+
+"""
+The last column, "target", indicates whether the patient
+has a heart disease (1) or not (0).
+
+Let's split the data into a training and validation set:
+"""
+
+val_dataframe = dataframe.sample(frac=0.2, random_state=1337)
+train_dataframe = dataframe.drop(val_dataframe.index)
+
+print(
+ "Using %d samples for training and %d for validation"
+ % (len(train_dataframe), len(val_dataframe))
+)
+
+"""
+Let's generate `tf.data.Dataset` objects for each dataframe:
+"""
+
+
+def dataframe_to_dataset(dataframe):
+ dataframe = dataframe.copy()
+ labels = dataframe.pop("target")
+ ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
+ ds = ds.shuffle(buffer_size=len(dataframe))
+ return ds
+
+
+train_ds = dataframe_to_dataset(train_dataframe)
+val_ds = dataframe_to_dataset(val_dataframe)
+
+"""
+Each `Dataset` yields a tuple `(input, target)` where `input` is a dictionary of features
+and `target` is the value `0` or `1`:
+"""
+
+for x, y in train_ds.take(1):
+ print("Input:", x)
+ print("Target:", y)
+
+"""
+Let's batch the datasets:
+"""
+
+train_ds = train_ds.batch(32)
+val_ds = val_ds.batch(32)
+
+"""
+## Configuring a `FeatureSpace`
+
+To configure how each feature should be preprocessed,
+we instantiate a `keras.utils.FeatureSpace`, and we
+pass to it a dictionary that maps the name of our features
+to a string that describes the feature type.
+
+We have a few "integer categorical" features such as `"FBS"`,
+one "string categorical" feature (`"thal"`),
+and a few numerical features, which we'd like to normalize
+-- except `"age"`, which we'd like to discretize into
+a number of bins.
+
+We also use the `crosses` argument
+to capture *feature interactions* for some categorical
+features, that is to say, create additional features
+that represent value co-occurrences for these categorical features.
+You can compute feature crosses like this for arbitrary sets of
+categorical features -- not just tuples of two features.
+Because the resulting co-occurrences are hashed
+into a fixed-size vector, you don't need to worry about whether
+the co-occurrence space is too large.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ # Categorical features encoded as integers
+ "sex": "integer_categorical",
+ "cp": "integer_categorical",
+ "fbs": "integer_categorical",
+ "restecg": "integer_categorical",
+ "exang": "integer_categorical",
+ "ca": "integer_categorical",
+ # Categorical feature encoded as string
+ "thal": "string_categorical",
+ # Numerical features to discretize
+ "age": "float_discretized",
+ # Numerical features to normalize
+ "trestbps": "float_normalized",
+ "chol": "float_normalized",
+ "thalach": "float_normalized",
+ "oldpeak": "float_normalized",
+ "slope": "float_normalized",
+ },
+ # We create additional features by hashing
+ # value co-occurrences for the
+ # following groups of categorical features.
+ crosses=[("sex", "age"), ("thal", "ca")],
+    # The hashing space for these co-occurrences
+    # will be 32-dimensional.
+ crossing_dim=32,
+ # Our utility will one-hot encode all categorical
+ # features and concat all features into a single
+ # vector (one vector per sample).
+ output_mode="concat",
+)
+
+"""
+## Further customizing a `FeatureSpace`
+
+Specifying the feature type via a string name is quick and easy,
+but sometimes you may want to further configure the preprocessing
+of each feature. For instance, in our case, our categorical
+features don't have a large set of possible values -- it's only
+a handful of values per feature (e.g. `1` and `0` for the feature `"FBS"`),
+and all possible values are represented in the training set.
+As a result, we don't need to reserve an index to represent "out of vocabulary" values
+for these features -- which would have been the default behavior.
+Below, we just specify `num_oov_indices=0` in each of these features
+to tell the feature preprocessor to skip "out of vocabulary" indexing.
+
+Other customizations you have access to include specifying the number of
+bins for discretizing features of type `"float_discretized"`,
+or the dimensionality of the hashing space for feature crossing.
+"""
+
+feature_space = FeatureSpace(
+ features={
+ # Categorical features encoded as integers
+ "sex": FeatureSpace.integer_categorical(num_oov_indices=0),
+ "cp": FeatureSpace.integer_categorical(num_oov_indices=0),
+ "fbs": FeatureSpace.integer_categorical(num_oov_indices=0),
+ "restecg": FeatureSpace.integer_categorical(num_oov_indices=0),
+ "exang": FeatureSpace.integer_categorical(num_oov_indices=0),
+ "ca": FeatureSpace.integer_categorical(num_oov_indices=0),
+ # Categorical feature encoded as string
+ "thal": FeatureSpace.string_categorical(num_oov_indices=0),
+ # Numerical features to discretize
+ "age": FeatureSpace.float_discretized(num_bins=30),
+ # Numerical features to normalize
+ "trestbps": FeatureSpace.float_normalized(),
+ "chol": FeatureSpace.float_normalized(),
+ "thalach": FeatureSpace.float_normalized(),
+ "oldpeak": FeatureSpace.float_normalized(),
+ "slope": FeatureSpace.float_normalized(),
+ },
+ # Specify feature cross with a custom crossing dim.
+ crosses=[
+ FeatureSpace.cross(feature_names=("sex", "age"), crossing_dim=64),
+ FeatureSpace.cross(
+ feature_names=("thal", "ca"),
+ crossing_dim=16,
+ ),
+ ],
+ output_mode="concat",
+)
+
+"""
+## Adapt the `FeatureSpace` to the training data
+
+Before we start using the `FeatureSpace` to build a model, we have
+to adapt it to the training data. During `adapt()`, the `FeatureSpace` will:
+
+- Index the set of possible values for categorical features.
+- Compute the mean and variance for numerical features to normalize.
+- Compute the value boundaries for the different bins for numerical features to discretize.
+
+Note that `adapt()` should be called on a `tf.data.Dataset` which yields dicts
+of feature values -- no labels.
+"""
+
+train_ds_with_no_labels = train_ds.map(lambda x, _: x)
+feature_space.adapt(train_ds_with_no_labels)
+
+"""
+At this point, the `FeatureSpace` can be called on a dict of raw feature values, and will return a
+single concatenated vector for each sample, combining encoded features and feature crosses.
+"""
+
+for x, _ in train_ds.take(1):
+ preprocessed_x = feature_space(x)
+ print("preprocessed_x.shape:", preprocessed_x.shape)
+ print("preprocessed_x.dtype:", preprocessed_x.dtype)
+
+"""
+## Two ways to manage preprocessing: as part of the `tf.data` pipeline, or in the model itself
+
+There are two ways in which you can leverage your `FeatureSpace`:
+
+### Asynchronous preprocessing in `tf.data`
+
+You can make it part of your data pipeline, before the model. This enables asynchronous parallel
+preprocessing of the data on CPU before it hits the model. Do this if you're training on GPU or TPU,
+or if you want to speed up preprocessing. This is almost always the right thing to do during training.
+
+### Synchronous preprocessing in the model
+
+You can make it part of your model. This means that the model will expect dicts of raw feature
+values, and preprocessing will be done synchronously (in a blocking manner) before the
+rest of the forward pass. Do this if you want to have an end-to-end model that can process
+raw feature values -- but keep in mind that your model will only be able to run on CPU,
+since most types of feature preprocessing (e.g. string preprocessing) are not GPU or TPU compatible.
+
+Do not do this on GPU / TPU or in performance-sensitive settings. In general, you want to do in-model
+preprocessing when you do inference on CPU.
+
+In our case, we will apply the `FeatureSpace` in the tf.data pipeline during training, but we will
+do inference with an end-to-end model that includes the `FeatureSpace`.
+"""
+
+"""
+Let's create a training and validation dataset of preprocessed batches:
+"""
+
+preprocessed_train_ds = train_ds.map(
+ lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
+)
+preprocessed_train_ds = preprocessed_train_ds.prefetch(tf.data.AUTOTUNE)
+
+preprocessed_val_ds = val_ds.map(
+ lambda x, y: (feature_space(x), y), num_parallel_calls=tf.data.AUTOTUNE
+)
+preprocessed_val_ds = preprocessed_val_ds.prefetch(tf.data.AUTOTUNE)
+
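+"""
+A quick check (a sketch): each element of the preprocessed datasets is now a dense
+feature vector plus a label, so the training model never sees raw feature dicts:
+"""
+
+for x, y in preprocessed_train_ds.take(1):
+    print("x.shape:", x.shape, " y.shape:", y.shape)
+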
+"""
+## Build a model
+
+Time to build a model -- or rather two models:
+
+- A training model that expects preprocessed features (one sample = one vector)
+- An inference model that expects raw features (one sample = dict of raw feature values)
+"""
+
+dict_inputs = feature_space.get_inputs()
+encoded_features = feature_space.get_encoded_features()
+
+x = keras.layers.Dense(32, activation="relu")(encoded_features)
+x = keras.layers.Dropout(0.5)(x)
+predictions = keras.layers.Dense(1, activation="sigmoid")(x)
+
+training_model = keras.Model(inputs=encoded_features, outputs=predictions)
+training_model.compile(
+ optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
+)
+
+inference_model = keras.Model(inputs=dict_inputs, outputs=predictions)
+
+"""
+## Train the model
+
+Let's train our model for 20 epochs. Note that feature preprocessing is happening
+as part of the tf.data pipeline, not as part of the model.
+"""
+
+training_model.fit(
+ preprocessed_train_ds,
+ epochs=20,
+ validation_data=preprocessed_val_ds,
+ verbose=2,
+)
+
+"""
+We quickly get to 80% validation accuracy.
+"""
+
+"""
+## Inference on new data with the end-to-end model
+
+Now, we can use our inference model (which includes the `FeatureSpace`)
+to make predictions based on dicts of raw features values, as follows:
+"""
+
+sample = {
+ "age": 60,
+ "sex": 1,
+ "cp": 1,
+ "trestbps": 145,
+ "chol": 233,
+ "fbs": 1,
+ "restecg": 2,
+ "thalach": 150,
+ "exang": 0,
+ "oldpeak": 2.3,
+ "slope": 3,
+ "ca": 0,
+ "thal": "fixed",
+}
+
+input_dict = {name: tf.convert_to_tensor([value]) for name, value in sample.items()}
+predictions = inference_model.predict(input_dict)
+
+print(
+ f"This particular patient had a {100 * predictions[0][0]:.2f}% probability "
+ "of having a heart disease, as evaluated by our model."
+)
diff --git a/knowledge_base/structured_data/tabtransformer.py b/knowledge_base/structured_data/tabtransformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..304bc442591b3cbb359d1f105f15e42ecb9f76ff
--- /dev/null
+++ b/knowledge_base/structured_data/tabtransformer.py
@@ -0,0 +1,569 @@
+"""
+Title: Structured data learning with TabTransformer
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2022/01/18
+Last modified: 2022/01/18
+Description: Using contextual embeddings for structured data classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to do structured data classification using
+[TabTransformer](https://arxiv.org/abs/2012.06678), a deep tabular data modeling
+architecture for supervised and semi-supervised learning.
+The TabTransformer is built upon self-attention based Transformers.
+The Transformer layers transform the embeddings of categorical features
+into robust contextual embeddings to achieve higher predictive accuracy.
+
+
+
+## Setup
+"""
+import keras
+from keras import layers
+from keras import ops
+
+import math
+import numpy as np
+import pandas as pd
+from tensorflow import data as tf_data
+import matplotlib.pyplot as plt
+from functools import partial
+
+"""
+## Prepare the data
+
+This example uses the
+[United States Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/census+income)
+provided by the
+[UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php).
+The task is binary classification
+to predict whether a person is likely to be making over USD 50,000 a year.
+
+The dataset includes 48,842 instances with 14 input features: 5 numerical features and 9 categorical features.
+
+First, let's load the dataset from the UCI Machine Learning Repository into a Pandas
+DataFrame:
+"""
+
+CSV_HEADER = [
+ "age",
+ "workclass",
+ "fnlwgt",
+ "education",
+ "education_num",
+ "marital_status",
+ "occupation",
+ "relationship",
+ "race",
+ "gender",
+ "capital_gain",
+ "capital_loss",
+ "hours_per_week",
+ "native_country",
+ "income_bracket",
+]
+
+train_data_url = (
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
+)
+train_data = pd.read_csv(train_data_url, header=None, names=CSV_HEADER)
+
+test_data_url = (
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
+)
+test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)
+
+print(f"Train dataset shape: {train_data.shape}")
+print(f"Test dataset shape: {test_data.shape}")
+
+"""
+Remove the first record (because it is not a valid data example) and a trailing 'dot' in the class labels.
+"""
+
+test_data = test_data[1:]
+test_data.income_bracket = test_data.income_bracket.apply(
+ lambda value: value.replace(".", "")
+)
+
+"""
+Now we store the training and test data in separate CSV files.
+"""
+
+train_data_file = "train_data.csv"
+test_data_file = "test_data.csv"
+
+train_data.to_csv(train_data_file, index=False, header=False)
+test_data.to_csv(test_data_file, index=False, header=False)
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading and parsing
+the data into input features, and encoding the input features with respect to their types.
+"""
+
+# A list of the numerical feature names.
+NUMERIC_FEATURE_NAMES = [
+ "age",
+ "education_num",
+ "capital_gain",
+ "capital_loss",
+ "hours_per_week",
+]
+# A dictionary of the categorical features and their vocabulary.
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ "workclass": sorted(list(train_data["workclass"].unique())),
+ "education": sorted(list(train_data["education"].unique())),
+ "marital_status": sorted(list(train_data["marital_status"].unique())),
+ "occupation": sorted(list(train_data["occupation"].unique())),
+ "relationship": sorted(list(train_data["relationship"].unique())),
+ "race": sorted(list(train_data["race"].unique())),
+ "gender": sorted(list(train_data["gender"].unique())),
+ "native_country": sorted(list(train_data["native_country"].unique())),
+}
+# Name of the column to be used as instance weights.
+WEIGHT_COLUMN_NAME = "fnlwgt"
+# A list of the categorical feature names.
+CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())
+# A list of all the input features.
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES
+# A list of column default values for each feature.
+COLUMN_DEFAULTS = [
+ [0.0] if feature_name in NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME] else ["NA"]
+ for feature_name in CSV_HEADER
+]
+# The name of the target feature.
+TARGET_FEATURE_NAME = "income_bracket"
+# A list of the labels of the target features.
+TARGET_LABELS = [" <=50K", " >50K"]
+
+"""
+## Configure the hyperparameters
+
+The hyperparameters include the model architecture and training configurations.
+"""
+
+LEARNING_RATE = 0.001
+WEIGHT_DECAY = 0.0001
+DROPOUT_RATE = 0.2
+BATCH_SIZE = 265
+NUM_EPOCHS = 15
+
+NUM_TRANSFORMER_BLOCKS = 3 # Number of transformer blocks.
+NUM_HEADS = 4 # Number of attention heads.
+EMBEDDING_DIMS = 16 # Embedding dimensions of the categorical features.
+MLP_HIDDEN_UNITS_FACTORS = [
+ 2,
+ 1,
+] # MLP hidden layer units, as factors of the number of inputs.
+NUM_MLP_BLOCKS = 2 # Number of MLP blocks in the baseline model.
+
+"""
+## Implement data reading pipeline
+
+We define an input function that reads and parses the file, then converts features
+and labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)
+for training or evaluation.
+"""
+
+target_label_lookup = layers.StringLookup(
+ vocabulary=TARGET_LABELS, mask_token=None, num_oov_indices=0
+)
+
+
+def prepare_example(features, target):
+ target_index = target_label_lookup(target)
+ weights = features.pop(WEIGHT_COLUMN_NAME)
+ return features, target_index, weights
+
+
+lookup_dict = {}
+for feature_name in CATEGORICAL_FEATURE_NAMES:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+    # Create a lookup to convert string values to integer indices.
+    # Since we are not using a mask token, nor expecting any out-of-vocabulary
+    # (OOV) token, we set mask_token to None and num_oov_indices to 0.
+ lookup = layers.StringLookup(
+ vocabulary=vocabulary, mask_token=None, num_oov_indices=0
+ )
+ lookup_dict[feature_name] = lookup
+
+
+def encode_categorical(batch_x, batch_y, weights):
+ for feature_name in CATEGORICAL_FEATURE_NAMES:
+ batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])
+
+ return batch_x, batch_y, weights
+
+
+def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=False):
+ dataset = (
+ tf_data.experimental.make_csv_dataset(
+ csv_file_path,
+ batch_size=batch_size,
+ column_names=CSV_HEADER,
+ column_defaults=COLUMN_DEFAULTS,
+ label_name=TARGET_FEATURE_NAME,
+ num_epochs=1,
+ header=False,
+ na_value="?",
+ shuffle=shuffle,
+ )
+ .map(prepare_example, num_parallel_calls=tf_data.AUTOTUNE, deterministic=False)
+ .map(encode_categorical)
+ )
+ return dataset.cache()
+
+
+"""
+## Implement a training and evaluation procedure
+"""
+
+
+def run_experiment(
+ model,
+ train_data_file,
+ test_data_file,
+ num_epochs,
+ learning_rate,
+ weight_decay,
+ batch_size,
+):
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ )
+
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.BinaryCrossentropy(),
+ metrics=[keras.metrics.BinaryAccuracy(name="accuracy")],
+ )
+
+ train_dataset = get_dataset_from_csv(train_data_file, batch_size, shuffle=True)
+ validation_dataset = get_dataset_from_csv(test_data_file, batch_size)
+
+ print("Start training the model...")
+ history = model.fit(
+ train_dataset, epochs=num_epochs, validation_data=validation_dataset
+ )
+ print("Model training finished")
+
+ _, accuracy = model.evaluate(validation_dataset, verbose=0)
+
+ print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
+
+ return history
+
+
+"""
+## Create model inputs
+
+Now, define the inputs for the models as a dictionary, where the key is the feature name,
+and the value is a `keras.layers.Input` tensor with the corresponding feature shape
+and data type.
+"""
+
+
+def create_model_inputs():
+ inputs = {}
+ for feature_name in FEATURE_NAMES:
+ if feature_name in NUMERIC_FEATURE_NAMES:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="float32"
+ )
+ else:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="int32"
+ )
+ return inputs
+
+
+"""
+## Encode features
+
+The `encode_inputs` function returns `encoded_categorical_feature_list` and `numerical_feature_list`.
+We encode the categorical features as embeddings, using a fixed `embedding_dims` for all the features,
+regardless of their vocabulary sizes. This is required for the Transformer model.
+"""
+
+
+def encode_inputs(inputs, embedding_dims):
+ encoded_categorical_feature_list = []
+ numerical_feature_list = []
+
+ for feature_name in inputs:
+ if feature_name in CATEGORICAL_FEATURE_NAMES:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+            # The string-to-index lookup was already applied in the tf.data pipeline
+            # (see `encode_categorical` above), so the input here contains integer indices.
+
+ # Create an embedding layer with the specified dimensions.
+ embedding = layers.Embedding(
+ input_dim=len(vocabulary), output_dim=embedding_dims
+ )
+
+ # Convert the index values to embedding representations.
+ encoded_categorical_feature = embedding(inputs[feature_name])
+ encoded_categorical_feature_list.append(encoded_categorical_feature)
+
+ else:
+ # Use the numerical features as-is.
+ numerical_feature = ops.expand_dims(inputs[feature_name], -1)
+ numerical_feature_list.append(numerical_feature)
+
+ return encoded_categorical_feature_list, numerical_feature_list
+
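+"""
+As a quick shape check (a sketch, reusing `create_model_inputs` and `encode_inputs` from
+above): each categorical feature becomes a `(batch, EMBEDDING_DIMS)` embedding, while
+each numerical feature stays a single `(batch, 1)` column:
+"""
+
+demo_inputs = create_model_inputs()
+demo_categorical, demo_numerical = encode_inputs(demo_inputs, EMBEDDING_DIMS)
+print("categorical shapes:", [tuple(t.shape) for t in demo_categorical])
+print("numerical shapes:", [tuple(t.shape) for t in demo_numerical])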
+
+"""
+## Implement an MLP block
+"""
+
+
+def create_mlp(hidden_units, dropout_rate, activation, normalization_layer, name=None):
+ mlp_layers = []
+ for units in hidden_units:
+ mlp_layers.append(normalization_layer())
+ mlp_layers.append(layers.Dense(units, activation=activation))
+ mlp_layers.append(layers.Dropout(dropout_rate))
+
+ return keras.Sequential(mlp_layers, name=name)
+
+
+"""
+## Experiment 1: a baseline model
+
+In the first experiment, we create a simple multi-layer feed-forward network.
+"""
+
+
+def create_baseline_model(
+ embedding_dims, num_mlp_blocks, mlp_hidden_units_factors, dropout_rate
+):
+ # Create model inputs.
+ inputs = create_model_inputs()
+    # Encode features.
+ encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
+ inputs, embedding_dims
+ )
+ # Concatenate all features.
+ features = layers.concatenate(
+ encoded_categorical_feature_list + numerical_feature_list
+ )
+ # Compute Feedforward layer units.
+ feedforward_units = [features.shape[-1]]
+
+    # Create several feedforward layers with skip connections.
+ for layer_idx in range(num_mlp_blocks):
+ features = create_mlp(
+ hidden_units=feedforward_units,
+ dropout_rate=dropout_rate,
+ activation=keras.activations.gelu,
+ normalization_layer=layers.LayerNormalization,
+ name=f"feedforward_{layer_idx}",
+ )(features)
+
+ # Compute MLP hidden_units.
+ mlp_hidden_units = [
+ factor * features.shape[-1] for factor in mlp_hidden_units_factors
+ ]
+ # Create final MLP.
+ features = create_mlp(
+ hidden_units=mlp_hidden_units,
+ dropout_rate=dropout_rate,
+ activation=keras.activations.selu,
+ normalization_layer=layers.BatchNormalization,
+ name="MLP",
+ )(features)
+
+    # Add a sigmoid as a binary classifier.
+ outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+baseline_model = create_baseline_model(
+ embedding_dims=EMBEDDING_DIMS,
+ num_mlp_blocks=NUM_MLP_BLOCKS,
+ mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
+ dropout_rate=DROPOUT_RATE,
+)
+
+print("Total model weights:", baseline_model.count_params())
+keras.utils.plot_model(baseline_model, show_shapes=True, rankdir="LR")
+
+"""
+Let's train and evaluate the baseline model:
+"""
+
+history = run_experiment(
+ model=baseline_model,
+ train_data_file=train_data_file,
+ test_data_file=test_data_file,
+ num_epochs=NUM_EPOCHS,
+ learning_rate=LEARNING_RATE,
+ weight_decay=WEIGHT_DECAY,
+ batch_size=BATCH_SIZE,
+)
+
+"""
+The baseline MLP model achieves ~81% validation accuracy.
+"""
+
+"""
+## Experiment 2: TabTransformer
+
+The TabTransformer architecture works as follows:
+
+1. All the categorical features are encoded as embeddings, using the same `embedding_dims`.
+This means that each value in each categorical feature will have its own embedding vector.
+2. A column embedding, one embedding vector for each categorical feature, is added (point-wise) to the categorical feature embedding.
+3. The embedded categorical features are fed into a stack of Transformer blocks.
+Each Transformer block consists of a multi-head self-attention layer followed by a feed-forward layer.
+4. The outputs of the final Transformer layer, which are the *contextual embeddings* of the categorical features,
+are concatenated with the input numerical features, and fed into a final MLP block.
+5. A `softmax` classifier is applied at the end of the model.
+
+The [paper](https://arxiv.org/abs/2012.06678) discusses both addition and concatenation of the column embedding in the
+*Appendix: Experiment and Model Details* section.
+A diagram of the TabTransformer architecture can be found in the paper.
+
+
+"""
+
+
+def create_tabtransformer_classifier(
+ num_transformer_blocks,
+ num_heads,
+ embedding_dims,
+ mlp_hidden_units_factors,
+ dropout_rate,
+ use_column_embedding=False,
+):
+ # Create model inputs.
+ inputs = create_model_inputs()
+    # Encode features.
+ encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
+ inputs, embedding_dims
+ )
+    # Stack categorical feature embeddings for the Transformer.
+ encoded_categorical_features = ops.stack(encoded_categorical_feature_list, axis=1)
+ # Concatenate numerical features.
+ numerical_features = layers.concatenate(numerical_feature_list)
+
+ # Add column embedding to categorical feature embeddings.
+ if use_column_embedding:
+ num_columns = encoded_categorical_features.shape[1]
+ column_embedding = layers.Embedding(
+ input_dim=num_columns, output_dim=embedding_dims
+ )
+ column_indices = ops.arange(start=0, stop=num_columns, step=1)
+ encoded_categorical_features = encoded_categorical_features + column_embedding(
+ column_indices
+ )
+
+ # Create multiple layers of the Transformer block.
+ for block_idx in range(num_transformer_blocks):
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=embedding_dims,
+ dropout=dropout_rate,
+ name=f"multihead_attention_{block_idx}",
+ )(encoded_categorical_features, encoded_categorical_features)
+ # Skip connection 1.
+ x = layers.Add(name=f"skip_connection1_{block_idx}")(
+ [attention_output, encoded_categorical_features]
+ )
+ # Layer normalization 1.
+ x = layers.LayerNormalization(name=f"layer_norm1_{block_idx}", epsilon=1e-6)(x)
+ # Feedforward.
+ feedforward_output = create_mlp(
+ hidden_units=[embedding_dims],
+ dropout_rate=dropout_rate,
+ activation=keras.activations.gelu,
+ normalization_layer=partial(
+ layers.LayerNormalization, epsilon=1e-6
+ ), # using partial to provide keyword arguments before initialization
+ name=f"feedforward_{block_idx}",
+ )(x)
+ # Skip connection 2.
+ x = layers.Add(name=f"skip_connection2_{block_idx}")([feedforward_output, x])
+ # Layer normalization 2.
+ encoded_categorical_features = layers.LayerNormalization(
+ name=f"layer_norm2_{block_idx}", epsilon=1e-6
+ )(x)
+
+ # Flatten the "contextualized" embeddings of the categorical features.
+ categorical_features = layers.Flatten()(encoded_categorical_features)
+ # Apply layer normalization to the numerical features.
+ numerical_features = layers.LayerNormalization(epsilon=1e-6)(numerical_features)
+ # Prepare the input for the final MLP block.
+ features = layers.concatenate([categorical_features, numerical_features])
+
+ # Compute MLP hidden_units.
+ mlp_hidden_units = [
+ factor * features.shape[-1] for factor in mlp_hidden_units_factors
+ ]
+ # Create final MLP.
+ features = create_mlp(
+ hidden_units=mlp_hidden_units,
+ dropout_rate=dropout_rate,
+ activation=keras.activations.selu,
+ normalization_layer=layers.BatchNormalization,
+ name="MLP",
+ )(features)
+
+    # Add a sigmoid as a binary classifier.
+ outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+tabtransformer_model = create_tabtransformer_classifier(
+ num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
+ num_heads=NUM_HEADS,
+ embedding_dims=EMBEDDING_DIMS,
+ mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
+ dropout_rate=DROPOUT_RATE,
+)
+
+print("Total model weights:", tabtransformer_model.count_params())
+keras.utils.plot_model(tabtransformer_model, show_shapes=True, rankdir="LR")
+
+"""
+Let's train and evaluate the TabTransformer model:
+"""
+
+history = run_experiment(
+ model=tabtransformer_model,
+ train_data_file=train_data_file,
+ test_data_file=test_data_file,
+ num_epochs=NUM_EPOCHS,
+ learning_rate=LEARNING_RATE,
+ weight_decay=WEIGHT_DECAY,
+ batch_size=BATCH_SIZE,
+)
+
+"""
+The TabTransformer model achieves ~85% validation accuracy.
+Note that, with the default parameter configurations, the baseline and the TabTransformer
+models have a similar number of trainable weights: 109,895 and 87,745 respectively, and both use the same training hyperparameters.
+"""
+
+"""
+## Conclusion
+
+TabTransformer significantly outperforms MLPs and recent
+deep networks for tabular data while matching the performance of tree-based ensemble models.
+TabTransformer can be trained end-to-end with supervised learning on labeled examples.
+For a scenario where there are a few labeled examples and a large number of unlabeled
+examples, a pre-training procedure can be employed to train the Transformer layers using unlabeled data.
+This is followed by fine-tuning of the pre-trained Transformer layers along with
+the top MLP layer using the labeled data.
+"""
diff --git a/knowledge_base/structured_data/wide_deep_cross_networks.py b/knowledge_base/structured_data/wide_deep_cross_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..36ee4149eee0d153f584902f42b9885c6e60bda5
--- /dev/null
+++ b/knowledge_base/structured_data/wide_deep_cross_networks.py
@@ -0,0 +1,450 @@
+"""
+Title: Structured data learning with Wide, Deep, and Cross networks
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2020/12/31
+Last modified: 2025/01/03
+Description: Using Wide & Deep and Deep & Cross networks for structured data classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to do structured data classification using the two modeling
+techniques:
+
+1. [Wide & Deep](https://ai.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) models
+2. [Deep & Cross](https://arxiv.org/abs/1708.05123) models
+
+Note that this example should be run with TensorFlow 2.5 or higher.
+"""
+
+"""
+## The dataset
+
+This example uses the [Covertype](https://archive.ics.uci.edu/ml/datasets/covertype) dataset from the UCI
+Machine Learning Repository. The task is to predict forest cover type from cartographic variables.
+The dataset includes 581,012 instances with 12 input features: 10 numerical features and 2
+categorical features. Each instance is categorized into 1 of 7 classes.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+# Only the TensorFlow backend supports string inputs.
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import math
+import numpy as np
+import pandas as pd
+from tensorflow import data as tf_data
+import keras
+from keras import layers
+
+"""
+## Prepare the data
+
+First, let's load the dataset from the UCI Machine Learning Repository into a Pandas
+DataFrame:
+"""
+
+data_url = (
+ "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
+)
+raw_data = pd.read_csv(data_url, header=None)
+print(f"Dataset shape: {raw_data.shape}")
+raw_data.head()
+
+"""
+The two categorical features in the dataset are binary-encoded.
+We will convert this dataset representation to the typical representation, where each
+categorical feature is represented as a single integer value.
+"""
+
+soil_type_values = [f"soil_type_{idx+1}" for idx in range(40)]
+wilderness_area_values = [f"area_type_{idx+1}" for idx in range(4)]
+
+soil_type = raw_data.loc[:, 14:53].apply(
+    lambda x: soil_type_values[x.to_numpy().nonzero()[0][0]], axis=1
+)
+wilderness_area = raw_data.loc[:, 10:13].apply(
+    lambda x: wilderness_area_values[x.to_numpy().nonzero()[0][0]], axis=1
+)
+
+CSV_HEADER = [
+ "Elevation",
+ "Aspect",
+ "Slope",
+ "Horizontal_Distance_To_Hydrology",
+ "Vertical_Distance_To_Hydrology",
+ "Horizontal_Distance_To_Roadways",
+ "Hillshade_9am",
+ "Hillshade_Noon",
+ "Hillshade_3pm",
+ "Horizontal_Distance_To_Fire_Points",
+ "Wilderness_Area",
+ "Soil_Type",
+ "Cover_Type",
+]
+
+data = pd.concat(
+ [raw_data.loc[:, 0:9], wilderness_area, soil_type, raw_data.loc[:, 54]],
+ axis=1,
+ ignore_index=True,
+)
+data.columns = CSV_HEADER
+
+# Convert the target label indices into a range from 0 to 6 (there are 7 labels in total).
+data["Cover_Type"] = data["Cover_Type"] - 1
+
+print(f"Dataset shape: {data.shape}")
+data.head().T
+
+"""
+The shape of the DataFrame shows there are 13 columns per sample
+(12 for the features and 1 for the target label).
+
+Let's split the data into training (85%) and test (15%) sets.
+"""
+
+train_splits = []
+test_splits = []
+
+for _, group_data in data.groupby("Cover_Type"):
+ random_selection = np.random.rand(len(group_data.index)) <= 0.85
+ train_splits.append(group_data[random_selection])
+ test_splits.append(group_data[~random_selection])
+
+train_data = pd.concat(train_splits).sample(frac=1).reset_index(drop=True)
+test_data = pd.concat(test_splits).sample(frac=1).reset_index(drop=True)
+
+print(f"Train split size: {len(train_data.index)}")
+print(f"Test split size: {len(test_data.index)}")
+
+"""
+Next, store the training and test data in separate CSV files.
+"""
+
+train_data_file = "train_data.csv"
+test_data_file = "test_data.csv"
+
+train_data.to_csv(train_data_file, index=False)
+test_data.to_csv(test_data_file, index=False)
+
+"""
+## Define dataset metadata
+
+Here, we define the metadata of the dataset that will be useful for reading and parsing
+the data into input features, and encoding the input features with respect to their types.
+"""
+
+TARGET_FEATURE_NAME = "Cover_Type"
+
+TARGET_FEATURE_LABELS = ["0", "1", "2", "3", "4", "5", "6"]
+
+NUMERIC_FEATURE_NAMES = [
+ "Aspect",
+ "Elevation",
+ "Hillshade_3pm",
+ "Hillshade_9am",
+ "Hillshade_Noon",
+ "Horizontal_Distance_To_Fire_Points",
+ "Horizontal_Distance_To_Hydrology",
+ "Horizontal_Distance_To_Roadways",
+ "Slope",
+ "Vertical_Distance_To_Hydrology",
+]
+
+CATEGORICAL_FEATURES_WITH_VOCABULARY = {
+ "Soil_Type": list(data["Soil_Type"].unique()),
+ "Wilderness_Area": list(data["Wilderness_Area"].unique()),
+}
+
+CATEGORICAL_FEATURE_NAMES = list(CATEGORICAL_FEATURES_WITH_VOCABULARY.keys())
+
+FEATURE_NAMES = NUMERIC_FEATURE_NAMES + CATEGORICAL_FEATURE_NAMES
+
+COLUMN_DEFAULTS = [
+ [0] if feature_name in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME] else ["NA"]
+ for feature_name in CSV_HEADER
+]
+
+NUM_CLASSES = len(TARGET_FEATURE_LABELS)
+
+"""
+## Experiment setup
+
+Next, let's define an input function that reads and parses the file, then converts features
+and labels into a [`tf.data.Dataset`](https://www.tensorflow.org/guide/datasets)
+for training or evaluation.
+"""
+
+
+# Convert each dataset element from an OrderedDict to a plain dictionary.
+def process(features, target):
+ return dict(features), target
+
+
+def get_dataset_from_csv(csv_file_path, batch_size, shuffle=False):
+ dataset = tf_data.experimental.make_csv_dataset(
+ csv_file_path,
+ batch_size=batch_size,
+ column_names=CSV_HEADER,
+ column_defaults=COLUMN_DEFAULTS,
+ label_name=TARGET_FEATURE_NAME,
+ num_epochs=1,
+ header=True,
+ shuffle=shuffle,
+ ).map(process)
+ return dataset.cache()
+
+
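+"""
+As a quick sanity check, we can look at a single parsed batch: each element is a
+`(features, label)` pair, where `features` is a plain dictionary mapping feature
+names to batched tensors.
+"""
+
+for features, targets in get_dataset_from_csv(train_data_file, batch_size=4).take(1):
+    print(sorted(features.keys()))
+    print(targets)
+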
+"""
+Here we configure the parameters and implement the procedure for running a training and
+evaluation experiment given a model.
+"""
+
+learning_rate = 0.001
+dropout_rate = 0.1
+batch_size = 265
+num_epochs = 1
+
+hidden_units = [32, 32]
+
+
+def run_experiment(model):
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
+ loss=keras.losses.SparseCategoricalCrossentropy(),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ )
+
+ train_dataset = get_dataset_from_csv(train_data_file, batch_size, shuffle=True)
+
+ test_dataset = get_dataset_from_csv(test_data_file, batch_size)
+
+ print("Start training the model...")
+ history = model.fit(train_dataset, epochs=num_epochs)
+ print("Model training finished")
+
+ _, accuracy = model.evaluate(test_dataset, verbose=0)
+
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+
+"""
+## Create model inputs
+
+Now, define the inputs for the models as a dictionary, where the key is the feature name,
+and the value is a `keras.layers.Input` tensor with the corresponding feature shape
+and data type.
+"""
+
+
+def create_model_inputs():
+ inputs = {}
+ for feature_name in FEATURE_NAMES:
+ if feature_name in NUMERIC_FEATURE_NAMES:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="float32"
+ )
+ else:
+ inputs[feature_name] = layers.Input(
+ name=feature_name, shape=(), dtype="string"
+ )
+ return inputs
+
+
+"""
+## Encode features
+
+We create two representations of our input features: sparse and dense:
+1. In the **sparse** representation, the categorical features are encoded with one-hot
+encoding using the `CategoryEncoding` layer. This representation can be useful for the
+model to *memorize* particular feature values to make certain predictions.
+2. In the **dense** representation, the categorical features are encoded with
+low-dimensional embeddings using the `Embedding` layer. This representation helps
+the model to *generalize* well to unseen feature combinations.
+"""
+
+
+def encode_inputs(inputs, use_embedding=False):
+ encoded_features = []
+ for feature_name in inputs:
+ if feature_name in CATEGORICAL_FEATURE_NAMES:
+ vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]
+            # Create a lookup to convert string values to integer indices.
+ # Since we are not using a mask token nor expecting any out of vocabulary
+ # (oov) token, we set mask_token to None and num_oov_indices to 0.
+ lookup = layers.StringLookup(
+ vocabulary=vocabulary,
+ mask_token=None,
+ num_oov_indices=0,
+ output_mode="int" if use_embedding else "binary",
+ )
+ if use_embedding:
+ # Convert the string input values into integer indices.
+ encoded_feature = lookup(inputs[feature_name])
+ embedding_dims = int(math.sqrt(len(vocabulary)))
+ # Create an embedding layer with the specified dimensions.
+ embedding = layers.Embedding(
+ input_dim=len(vocabulary), output_dim=embedding_dims
+ )
+ # Convert the index values to embedding representations.
+ encoded_feature = embedding(encoded_feature)
+ else:
+ # Convert the string input values into a one hot encoding.
+ encoded_feature = lookup(
+ keras.ops.expand_dims(inputs[feature_name], -1)
+ )
+ else:
+ # Use the numerical features as-is.
+ encoded_feature = keras.ops.expand_dims(inputs[feature_name], -1)
+
+ encoded_features.append(encoded_feature)
+
+ all_features = layers.concatenate(encoded_features)
+ return all_features
+
+
+"""
+## Experiment 1: a baseline model
+
+In the first experiment, let's create a multi-layer feed-forward network,
+where the categorical features are one-hot encoded.
+"""
+
+
+def create_baseline_model():
+ inputs = create_model_inputs()
+ features = encode_inputs(inputs)
+
+ for units in hidden_units:
+ features = layers.Dense(units)(features)
+ features = layers.BatchNormalization()(features)
+ features = layers.ReLU()(features)
+ features = layers.Dropout(dropout_rate)(features)
+
+ outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(features)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+baseline_model = create_baseline_model()
+keras.utils.plot_model(baseline_model, show_shapes=True, rankdir="LR")
+
+"""
+Let's run it:
+"""
+
+run_experiment(baseline_model)
+
+"""
+The baseline model achieves ~76% test accuracy.
+"""
+
+"""
+## Experiment 2: Wide & Deep model
+
+In the second experiment, we create a Wide & Deep model. The wide part of the model is
+a linear model, while the deep part is a multi-layer feed-forward network.
+
+Use the sparse representation of the input features in the wide part of the model and the
+dense representation of the input features for the deep part of the model.
+
+Note that every input feature contributes to both parts of the model, each with a
+different representation.
+"""
+
+
+def create_wide_and_deep_model():
+ inputs = create_model_inputs()
+ wide = encode_inputs(inputs)
+ wide = layers.BatchNormalization()(wide)
+
+ deep = encode_inputs(inputs, use_embedding=True)
+ for units in hidden_units:
+ deep = layers.Dense(units)(deep)
+ deep = layers.BatchNormalization()(deep)
+ deep = layers.ReLU()(deep)
+ deep = layers.Dropout(dropout_rate)(deep)
+
+ merged = layers.concatenate([wide, deep])
+ outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(merged)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+wide_and_deep_model = create_wide_and_deep_model()
+keras.utils.plot_model(wide_and_deep_model, show_shapes=True, rankdir="LR")
+
+"""
+Let's run it:
+"""
+
+run_experiment(wide_and_deep_model)
+
+"""
+The wide and deep model achieves ~79% test accuracy.
+"""
+
+"""
+## Experiment 3: Deep & Cross model
+
+In the third experiment, we create a Deep & Cross model. The deep part of this model
+is the same as the deep part created in the previous experiment. The key idea of
+the cross part is to apply explicit feature crossing in an efficient way,
+where the degree of cross features grows with layer depth.
+"""
+
+
+def create_deep_and_cross_model():
+ inputs = create_model_inputs()
+ x0 = encode_inputs(inputs, use_embedding=True)
+
+ cross = x0
+ for _ in hidden_units:
+ units = cross.shape[-1]
+ x = layers.Dense(units)(cross)
+ cross = x0 * x + cross
+ cross = layers.BatchNormalization()(cross)
+
+ deep = x0
+ for units in hidden_units:
+ deep = layers.Dense(units)(deep)
+ deep = layers.BatchNormalization()(deep)
+ deep = layers.ReLU()(deep)
+ deep = layers.Dropout(dropout_rate)(deep)
+
+ merged = layers.concatenate([cross, deep])
+ outputs = layers.Dense(units=NUM_CLASSES, activation="softmax")(merged)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+deep_and_cross_model = create_deep_and_cross_model()
+keras.utils.plot_model(deep_and_cross_model, show_shapes=True, rankdir="LR")
+
+"""
+Let's run it:
+"""
+
+run_experiment(deep_and_cross_model)
+
+"""
+The deep and cross model achieves ~81% test accuracy.
+"""
+
+"""
+## Conclusion
+
+You can use Keras Preprocessing Layers to easily handle categorical features
+with different encoding mechanisms, including one-hot encoding and feature embedding.
+In addition, different model architectures, like wide, deep, and cross networks,
+have different advantages with respect to different dataset properties.
+You can explore using them independently or combining them to achieve the best result
+for your dataset.
+"""
diff --git a/knowledge_base/timeseries/eeg_bci_ssvepformer.py b/knowledge_base/timeseries/eeg_bci_ssvepformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe4d0ec4e4c854410d63731d4be76df8970a59d0
--- /dev/null
+++ b/knowledge_base/timeseries/eeg_bci_ssvepformer.py
@@ -0,0 +1,655 @@
+"""
+Title: Electroencephalogram Signal Classification for Brain-Computer Interface
+Author: [Okba Bekhelifi](https://github.com/okbalefthanded)
+Date created: 2025/01/08
+Last modified: 2025/01/08
+Description: A Transformer-based model for classifying EEG signals in a BCI task.
+Accelerator: GPU
+"""
+
+"""
+# Introduction
+
+This tutorial will explain how to build a Transformer-based neural network to classify
+brain-computer interface (BCI) electroencephalography (EEG) data recorded in a
+Steady-State Visual Evoked Potentials (SSVEPs) experiment for the application of a
+brain-controlled speller.
+
+The tutorial reproduces an experiment from the SSVEPFormer study [1]
+( [arXiv preprint](https://arxiv.org/abs/2210.04172) /
+[Peer-Reviewed paper](https://www.sciencedirect.com/science/article/abs/pii/S0893608023002319) ).
+This model was the first Transformer-based model introduced for SSVEP data classification.
+We will test it on the Nakanishi et al. [2] public dataset, referred to as dataset 1 in the paper.
+
+The process follows an inter-subject classification experiment. Given data from N subjects in
+the dataset, the training partition contains data from N-1 subjects, and the remaining
+subject's data is used for testing. The training set contains no samples from
+the test subject; this way we construct a true subject-independent model. We keep the
+same parameters and settings as the original paper in all processing operations, from
+preprocessing to training.
+
+
+The tutorial begins with a quick BCI and dataset description; then, we go through the
+technicalities following these sections:
+- Setup and imports.
+- Dataset download and extraction.
+- Data preprocessing: EEG data filtering, segmentation, and visualization of raw and
+filtered data, and the frequency response for a well-performing participant.
+- Layers and model creation.
+- Evaluation: classification of a single participant's data as an example, then of all
+participants' data.
+- Visualization: we show the results of training and inference time comparisons among
+the Keras 3 available backends (JAX, TensorFlow, and PyTorch) on three different GPUs.
+- Conclusion: final discussion and remarks.
+
+"""
+
+"""
+# Dataset description
+
+## BCI and SSVEP:
+A BCI offers the ability to communicate using only brain activity. This can be achieved
+through exogenous stimuli that generate specific responses indicating the intent of the
+subject. The responses are elicited when the user focuses their attention on the target
+stimulus. We can use visual stimuli by presenting the subject with a set of options,
+typically on a monitor as a grid, to select one command at a time. Each stimulus flickers
+at a fixed frequency and phase; the resulting EEG, recorded over occipital
+and occipito-parietal areas of the cortex (the visual cortex), will have higher power at the
+frequency associated with the stimulus the subject is looking at. This BCI
+paradigm is called Steady-State Visual Evoked Potentials (SSVEPs), and it has become
+widely used for multiple applications due to its reliability and high classification
+performance, as well as its rapidity: 1 second of EEG is sufficient to issue a command. Other
+types of brain responses exist that do not require external stimulation; however, they
+are less reliable.
+[Demo video](https://www.youtube.com/watch?v=VtA6jsEMIug)
+
+This tutorial uses the 12-command (class) public SSVEP dataset [2] with the following
+interface, which emulates dialing numbers on a phone.
+
+
+The dataset was recorded from 10 participants, each facing the above 12 SSVEP stimuli (A).
+The stimulation frequencies ranged from 9.25 Hz to 14.75 Hz in 0.5 Hz steps, and the phases
+ranged from 0 to 1.5π in 0.5π steps for each row (B). The EEG signal was acquired
+with 8 electrodes (channels) (PO7, PO3, POz,
+PO4, PO8, O1, Oz, O2) at a sampling frequency of 2048 Hz; the stored data were then
+downsampled to 256 Hz. The subjects completed 15 blocks of recordings, each consisting
+of 12 randomly ordered stimulations (1 for each class) of 4 seconds each. In total,
+each subject conducted 180 trials.
+
+
+"""
+
+"""
+# Setup
+"""
+
+"""
+## Select JAX backend
+
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+
+"""
+## Install dependencies
+
+"""
+
+"""shell
+pip install -q numpy
+pip install -q scipy
+pip install -q matplotlib
+"""
+
+"""
+# Imports
+
+
+"""
+
+# deep learning libraries
+from keras import backend as K
+from keras import layers
+import keras
+
+# visualization and signal processing imports
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import numpy as np
+from scipy.signal import butter, filtfilt
+from scipy.io import loadmat
+
+# setting the seed and the Keras channel format
+K.set_image_data_format("channels_first")
+keras.utils.set_random_seed(42)
+
+"""
+# Download and extract dataset
+
+
+"""
+
+"""
+## Nakanishi et al. 2015 [DataSet Repo](https://github.com/mnakanishi/12JFPM_SSVEP)
+"""
+
+"""shell
+curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
+unzip cca_ssvep.zip
+"""
+
+"""
+# Pre-Processing
+
+The preprocessing steps are, first, to read the EEG data for each subject, then
+to filter the raw data in a frequency interval where most of the useful information lies,
+then to select a fixed duration of signal starting from the stimulation onset
+(to account for the latency delay caused by the visual system, we add 135 milliseconds to
+the stimulation onset). Lastly, all subjects' data are concatenated into a single tensor
+of shape [subjects x samples x channels x trials]. The data labels are also
+concatenated following the order of the trials in the experiments and form a
+matrix of shape [subjects x trials]
+(by channels we mean electrodes; we use this notation throughout the tutorial).
+"""
+
+
+def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
+ """selecting a 1-second segment of the raw EEG signal for
+ subject 1.
+ """
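+    # Skip the pre-stimulation period: a fixed 38-sample offset into the stored
+    # trials (as used by the reference implementation) plus the ~135 ms visual
+    # latency delay converted to samples.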
+ onset = 38 + int(onset * fs)
+ end = int(duration * fs)
+ data = loadmat(f"{folder}/s1.mat")
+ # samples, channels, trials, targets
+ eeg = data["eeg"].transpose((2, 1, 3, 0))
+ # segment data
+ eeg = eeg[onset : onset + end, :, :, :]
+ return eeg
+
+
+def segment_eeg(
+ folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
+):
+ """Filtering and segmenting EEG signals for all subjects."""
+    n_subjects = 10
+ onset = 38 + int(onset * fs)
+ end = int(duration * fs)
+ X, Y = [], [] # empty data and labels
+
+    for subj in range(1, n_subjects + 1):
+        data = loadmat(f"{folder}/s{subj}.mat")
+ # samples, channels, trials, targets
+ eeg = data["eeg"].transpose((2, 1, 3, 0))
+ # filter data
+ eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
+ # segment data
+ eeg = eeg[onset : onset + end, :, :, :]
+ # reshape labels
+ samples, channels, blocks, targets = eeg.shape
+ y = np.tile(np.arange(1, targets + 1), (blocks, 1))
+ y = y.reshape((1, blocks * targets), order="F")
+
+ X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
+ Y.append(y)
+
+ X = np.array(X, dtype=np.float32, order="F")
+ Y = np.array(Y, dtype=np.float32).squeeze()
+
+ return X, Y
+
+
+def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
+ """Filter EEG signal using a zero-phase IIR filter"""
+ B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
+ return filtfilt(B, A, data, axis=0)
+
+
+"""
+## Segment data into epochs
+"""
+
+data_folder = os.path.abspath("./cca_ssvep")
+band = [8, 64]  # low-frequency / high-frequency cutoffs
+order = 4 # filter order
+fs = 256 # sampling frequency
+duration = 1.0 # 1 second
+
+# raw signal
+X_raw = raw_signal(data_folder, fs=fs, duration=duration)
+print(
+ f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
+)
+
+# segmented signal
+X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
+print(
+ f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
+)
+print(f"data labels (Y) shape: {Y.shape} [Subject x Trials]")
+
+samples = X.shape[1]
+time = np.linspace(0.0, samples / fs, samples) * 1000
+
+"""
+## Visualize EEG signal
+"""
+
+"""
+## EEG in time
+
+Raw EEG vs. filtered EEG:
+the same 1-second recording for subject s1 at Oz (a central electrode over the visual cortex,
+at the back of the head) is illustrated. On the left is the raw EEG as recorded, and on the right is
+the EEG filtered to the [8, 64] Hz frequency band. We see less noise and
+normalized amplitude values in a natural EEG range.
+"""
+
+
+elec = 6 # Oz channel
+
+x_label = "Time (ms)"
+y_label = "Voltage (uV)"
+# Create subplots
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
+
+# Plot data on the first subplot
+ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
+ax1.set_xlabel(x_label)
+ax1.set_ylabel(y_label)
+ax1.set_title("Raw EEG : 1 second at Oz ")
+
+# Plot data on the second subplot
+ax2.plot(time, X[0, :, elec, 0], "b-")
+ax2.set_xlabel(x_label)
+ax2.set_ylabel(y_label)
+ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")
+
+# Adjust spacing between subplots
+plt.tight_layout()
+
+# Show the plot
+plt.show()
+
+"""
+## EEG frequency representation
+
+Using the Welch method, we visualize the frequency power for a well-performing subject
+over the entire 4-second EEG recording at the Oz electrode, for each stimulus. The red peaks
+indicate the stimulus fundamental frequency and the 2nd harmonic (double the fundamental
+frequency). We see clear peaks, showing the strong responses from this subject, which means
+that this subject is a good candidate for SSVEP BCI control. In many cases the peaks
+are weak or absent, meaning that the subject did not perform the task correctly.
+
+
+"""
+
+
+"""
+# Create Layers and model
+
+Create layers in a cross-framework custom component fashion.
+In the SSVEPFormer, the data is first transformed to the frequency domain through a
+Fast Fourier transform (FFT), to construct a complex spectrum representation consisting of
+the concatenation of frequency and phase information in a fixed frequency band. To keep
+the model in an end-to-end format, we implement the complex spectrum transformation as a
+non-trainable layer.
+
+
+Unlike the original Transformer architecture, the SSVEPFormer does not contain positional
+encoding/embedding layers; they are replaced by a channel combination block consisting of
+a Conv1D layer with a kernel size of 1 and twice the number of electrodes as filters,
+followed by LayerNorm, GELU activation, and dropout.
+Another difference from Transformers is the absence of multi-head attention layers;
+the attention mechanism is realized with convolutions instead.
+The model encoder contains two identical, successive blocks. Each block has two
+sub-blocks: a CNN module and an MLP module. The CNN module consists of a LayerNorm, a Conv1D
+with the same number of filters as the channel combination block, LayerNorm, GELU, dropout, and a
+residual connection. The MLP module consists of a LayerNorm, a Dense layer, GELU, dropout,
+and a residual connection; the Dense layer is applied to each channel separately.
+The last block of the model is an MLP head with a Flatten layer, Dropout, Dense, LayerNorm,
+GELU, Dropout, and a Dense layer with softmax activation.
+All trainable weights are initialized from a normal distribution with mean 0 and standard
+deviation 0.01, as stated in the original paper.
+"""
+
+
+class ComplexSpectrum(keras.layers.Layer):
+ def __init__(self, nfft=512, fft_start=8, fft_end=64):
+ super().__init__()
+ self.nfft = nfft
+ self.fft_start = fft_start
+ self.fft_end = fft_end
+
+ def call(self, x):
+ samples = x.shape[-1]
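+        # keras.ops.rfft returns the real and imaginary parts as a pair of tensors.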
+ x = keras.ops.rfft(x, fft_length=self.nfft)
+ real = x[0] / samples
+ imag = x[1] / samples
+ real = real[:, :, self.fft_start : self.fft_end]
+ imag = imag[:, :, self.fft_start : self.fft_end]
+ x = keras.ops.concatenate((real, imag), axis=-1)
+ return x
+
+
+class ChannelComb(keras.layers.Layer):
+ def __init__(self, n_channels, drop_rate=0.5):
+ super().__init__()
+ self.conv = layers.Conv1D(
+ 2 * n_channels,
+ 1,
+ padding="same",
+ kernel_initializer=keras.initializers.RandomNormal(
+ mean=0.0, stddev=0.01, seed=None
+ ),
+ )
+ self.normalization = layers.LayerNormalization()
+ self.activation = layers.Activation(activation="gelu")
+ self.drop = layers.Dropout(drop_rate)
+
+ def call(self, x):
+ x = self.conv(x)
+ x = self.normalization(x)
+ x = self.activation(x)
+ x = self.drop(x)
+ return x
+
+
+class ConvAttention(keras.layers.Layer):
+ def __init__(self, n_channels, drop_rate=0.5):
+ super().__init__()
+ self.norm = layers.LayerNormalization()
+ self.conv = layers.Conv1D(
+ 2 * n_channels,
+ 31,
+ padding="same",
+ kernel_initializer=keras.initializers.RandomNormal(
+ mean=0.0, stddev=0.01, seed=None
+ ),
+ )
+ self.activation = layers.Activation(activation="gelu")
+ self.drop = layers.Dropout(drop_rate)
+
+ def call(self, x):
+ input = x
+ x = self.norm(x)
+ x = self.conv(x)
+ x = self.activation(x)
+ x = self.drop(x)
+ x = x + input
+ return x
+
+
+class ChannelMLP(keras.layers.Layer):
+ def __init__(self, n_features, drop_rate=0.5):
+ super().__init__()
+ self.norm = layers.LayerNormalization()
+ self.mlp = layers.Dense(
+ 2 * n_features,
+ kernel_initializer=keras.initializers.RandomNormal(
+ mean=0.0, stddev=0.01, seed=None
+ ),
+ )
+ self.activation = layers.Activation(activation="gelu")
+ self.drop = layers.Dropout(drop_rate)
+ self.cat = layers.Concatenate(axis=1)
+
+ def call(self, x):
+ input = x
+ channels = x.shape[1] # x shape : NCF
+ x = self.norm(x)
+ output_channels = []
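+        # Apply the shared Dense layer to each channel's feature vector separately,
+        # then re-stack the per-channel outputs along the channel axis.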
+ for i in range(channels):
+            c = self.mlp(x[:, i, :])
+ c = layers.Reshape([1, -1])(c)
+ output_channels.append(c)
+ x = self.cat(output_channels)
+ x = self.activation(x)
+ x = self.drop(x)
+ x = x + input
+ return x
+
+
+class Encoder(keras.layers.Layer):
+ def __init__(self, n_channels, n_features, drop_rate=0.5):
+ super().__init__()
+ self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
+ self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
+ self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
+ self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)
+
+ def call(self, x):
+ x = self.attention1(x)
+ x = self.mlp1(x)
+ x = self.attention2(x)
+ x = self.mlp2(x)
+ return x
+
+
+class MlpHead(keras.layers.Layer):
+ def __init__(self, n_classes, drop_rate=0.5):
+ super().__init__()
+ self.flatten = layers.Flatten()
+ self.drop = layers.Dropout(drop_rate)
+ self.linear1 = layers.Dense(
+ 6 * n_classes,
+ kernel_initializer=keras.initializers.RandomNormal(
+ mean=0.0, stddev=0.01, seed=None
+ ),
+ )
+ self.norm = layers.LayerNormalization()
+ self.activation = layers.Activation(activation="gelu")
+ self.drop2 = layers.Dropout(drop_rate)
+ self.linear2 = layers.Dense(
+ n_classes,
+ kernel_initializer=keras.initializers.RandomNormal(
+ mean=0.0, stddev=0.01, seed=None
+ ),
+ )
+
+ def call(self, x):
+ x = self.flatten(x)
+ x = self.drop(x)
+ x = self.linear1(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ x = self.drop2(x)
+ x = self.linear2(x)
+ return x
+
+
+"""
+### Create a sequential model with the layers above
+"""
+
+
+def create_ssvepformer(
+ input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
+):
+ nfft = round(fs / resolution)
+ fft_start = int(fq_band[0] / resolution)
+ fft_end = int(fq_band[1] / resolution) + 1
+ n_features = fft_end - fft_start
+
+ model = keras.Sequential(
+ [
+ keras.Input(shape=input_shape),
+ ComplexSpectrum(nfft, fft_start, fft_end),
+ ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
+ Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
+ Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
+ MlpHead(n_classes=n_classes, drop_rate=drop_rate),
+ layers.Activation(activation="softmax"),
+ ]
+ )
+
+ return model
+
+
+"""
+# Evaluation
+"""
+
+# Training settings same as the original paper
+BATCH_SIZE = 128
+EPOCHS = 100
+LR = 0.001 # learning rate
+WD = 0.001 # weight decay
+MOMENTUM = 0.9
+DROP_RATE = 0.5
+
+resolution = 0.25
+
+"""
+From the entire dataset, we select the folds for each subject's evaluation. We then
+construct tf.data dataset objects for the training and testing data, create the model,
+and launch training using the SGD optimizer.
+"""
+
+
+def concatenate_subjects(x, y, fold):
+ X = np.concatenate([x[idx] for idx in fold], axis=-1)
+ Y = np.concatenate([y[idx] for idx in fold], axis=-1)
+ X = X.transpose((2, 1, 0)) # trials x channels x samples
+ return X, Y - 1 # transform labels to values from 0...11
+
+
+def evaluate_subject(
+ x_train,
+ y_train,
+ x_val,
+ y_val,
+ input_shape,
+ fs=256,
+ resolution=0.25,
+ band=[8, 64],
+ channels=8,
+ n_classes=12,
+ drop_rate=DROP_RATE,
+):
+
+ train_dataset = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .batch(BATCH_SIZE)
+ .prefetch(tf.data.AUTOTUNE)
+ )
+
+ test_dataset = (
+ tf.data.Dataset.from_tensor_slices((x_val, y_val))
+ .batch(BATCH_SIZE)
+ .prefetch(tf.data.AUTOTUNE)
+ )
+
+ model = create_ssvepformer(
+ input_shape, fs, resolution, band, channels, n_classes, drop_rate
+ )
+ sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)
+
+ model.compile(
+ loss="sparse_categorical_crossentropy",
+ optimizer=sgd,
+ metrics=["accuracy"],
+ jit_compile=True,
+ )
+
+ history = model.fit(
+ train_dataset,
+ batch_size=BATCH_SIZE,
+ epochs=EPOCHS,
+ validation_data=test_dataset,
+ verbose=0,
+ )
+ loss, acc = model.evaluate(test_dataset)
+ return acc * 100
+
+
+"""
+## Run evaluation
+"""
+
+channels = X.shape[2]
+samples = X.shape[1]
+input_shape = (channels, samples)
+n_classes = 12
+
+model = create_ssvepformer(
+ input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
+)
+model.summary()
+
+"""
+## Evaluation on all subjects following a leave-one-subject-out data repartition scheme
+"""
+
+accs = np.zeros(10)
+
+for subject in range(10):
+    print(f"Testing subject: {subject + 1}")
+
+ # create train / test folds
+ folds = np.delete(np.arange(10), subject)
+ train_index = folds
+ test_index = [subject]
+
+ # create data split for each subject
+ x_train, y_train = concatenate_subjects(X, Y, train_index)
+ x_val, y_val = concatenate_subjects(X, Y, test_index)
+
+    # train and evaluate a fold
+ acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)
+
+ accs[subject] = acc
+
+print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")
+
+"""
+And that's it! We see how some subjects with no data in the training set can still achieve
+almost 100% correct commands, while others show poor performance, around 50%. In the original
+paper, using PyTorch, the average accuracy was 84.04% with a standard deviation of 17.37. We
+reached comparable values, given the stochastic nature of deep learning.
+"""
+
+"""
+# Visualizations
+
+Training and inference time comparison between the different backends (JAX, TensorFlow,
+and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+: T4, L4, and A100.
+
+
+"""
+
+"""
+## Training Time
+
+
+"""
+
+"""
+## Inference Time
+
+
+"""
+
+"""
+The JAX backend was the best for both training and inference on all the GPUs; PyTorch was
+extremely slow because the jit compilation option had to be disabled, since the complex
+data type produced by the FFT is not supported by the PyTorch jit compiler.
+"""
+
+"""
+# Acknowledgment
+
+I thank Chris Perry [X](https://x.com/thechrisperry) @GoogleColab for supporting this
+work with GPU compute.
+"""
+
+"""
+# References
+[1] Chen, J. et al. (2023) "A transformer-based deep neural network model for SSVEP
+classification", Neural Networks, 164, pp. 521-534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.
+
+[2] Nakanishi, M. et al. (2015) "A Comparison Study of Canonical Correlation Analysis
+Based Methods for Detecting Steady-State Visual Evoked Potentials", PLOS ONE, 10(10), p.
+e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703
+"""
diff --git a/knowledge_base/timeseries/eeg_signal_classification.py b/knowledge_base/timeseries/eeg_signal_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cc86437d0cbd0ee09a4893fb588a3455b64316c
--- /dev/null
+++ b/knowledge_base/timeseries/eeg_signal_classification.py
@@ -0,0 +1,560 @@
+"""
+Title: Electroencephalogram Signal Classification for action identification
+Author: [Suvaditya Mukherjee](https://github.com/suvadityamuk)
+Date created: 2022/11/03
+Last modified: 2022/11/05
+Description: Training a Convolutional model to classify EEG signals produced by exposure to certain stimuli.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The following example explores how we can make a Convolution-based Neural Network to
+perform classification on Electroencephalogram signals captured when subjects were
+exposed to different stimuli.
+We train a model from scratch since such signal-classification models are fairly scarce
+in pre-trained format.
+The data we use is sourced from the UC Berkeley-Biosense Lab where the data was collected
+from 15 subjects at the same time.
+Our process is as follows:
+
+- Load the [UC Berkeley-Biosense Synchronized Brainwave Dataset](https://www.kaggle.com/datasets/berkeley-biosense/synchronized-brainwave-dataset)
+- Visualize random samples from the data
+- Pre-process, collate and scale the data to finally make a `tf.data.Dataset`
+- Prepare class weights in order to tackle major imbalances
+- Create a Conv1D and Dense-based model to perform classification
+- Define callbacks and hyperparameters
+- Train the model
+- Plot metrics from History and perform evaluation
+
+This example needs the following external dependencies (Gdown, Scikit-learn, Pandas,
+Numpy, Matplotlib). You can install them via the following commands.
+
+Gdown is an external package used to download large files from Google Drive. To know
+more, you can refer to its [PyPi page here](https://pypi.org/project/gdown)
+"""
+
+
+"""
+## Setup and Data Downloads
+
+First, let's install our dependencies:
+"""
+
+"""shell
+pip install gdown -q
+pip install scikit-learn -q
+pip install pandas -q
+pip install numpy -q
+pip install matplotlib -q
+"""
+
+"""
+Next, let's download our dataset.
+The gdown package makes it easy to download the data from Google Drive:
+"""
+
+"""shell
+gdown 1V5B7Bt6aJm0UHbR7cRKBEK8jx7lYPVuX
+# gdown will download eeg-data.csv onto the local drive for use. Total size of
+# eeg-data.csv is 105.7 MB
+"""
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import json
+import numpy as np
+import keras
+from keras import layers
+import tensorflow as tf
+from sklearn import preprocessing, model_selection
+import random
+
+QUALITY_THRESHOLD = 128
+BATCH_SIZE = 64
+SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 2
+
+"""
+## Read data from `eeg-data.csv`
+
+We use the Pandas library to read the `eeg-data.csv` file and display the first 5 rows
+using the `.head()` command
+"""
+
+eeg = pd.read_csv("eeg-data.csv")
+
+"""
+We remove unlabeled samples from our dataset as they do not contribute to the model. We
+also perform a `.drop()` operation on the columns that are not required for training data
+preparation
+"""
+
+unlabeled_eeg = eeg[eeg["label"] == "unlabeled"]
+eeg = eeg.loc[eeg["label"] != "unlabeled"]
+eeg = eeg.loc[eeg["label"] != "everyone paired"]
+
+eeg.drop(
+ [
+ "indra_time",
+ "Unnamed: 0",
+ "browser_latency",
+ "reading_time",
+ "attention_esense",
+ "meditation_esense",
+ "updatedAt",
+ "createdAt",
+ ],
+ axis=1,
+ inplace=True,
+)
+
+eeg.reset_index(drop=True, inplace=True)
+eeg.head()
+
+"""
+In the data, the recorded samples are given a score from 0 to 200 based on how
+well-calibrated the sensor was (0 being best, 200 being worst). We keep only the
+samples whose score is below an arbitrary cutoff limit of 128.
+"""
+
+
+def convert_string_data_to_values(value_string):
+ str_list = json.loads(value_string)
+ return str_list
+
+
+eeg["raw_values"] = eeg["raw_values"].apply(convert_string_data_to_values)
+
+eeg = eeg.loc[eeg["signal_quality"] < QUALITY_THRESHOLD]
+eeg.head()
+
+"""
+## Visualize one random sample from the data
+"""
+
+"""
+We visualize one sample from the data to understand what the stimulus-induced signal
+looks like
+"""
+
+
+def view_eeg_plot(idx):
+ data = eeg.loc[idx, "raw_values"]
+ plt.plot(data)
+    plt.title("Sample random plot")
+ plt.show()
+
+
+view_eeg_plot(7)
+
+"""
+## Pre-process and collate data
+"""
+
+"""
+There are a total of 67 different labels present in the data, many of which are numbered
+sub-labels. We collate them under a single label according to their numbering and replace
+them in the data itself. Following this process, we perform simple label encoding to get
+them into an integer format.
+"""
+
+print("Before replacing labels")
+print(eeg["label"].unique(), "\n")
+print(len(eeg["label"].unique()), "\n")
+
+
+eeg.replace(
+ {
+ "label": {
+ "blink1": "blink",
+ "blink2": "blink",
+ "blink3": "blink",
+ "blink4": "blink",
+ "blink5": "blink",
+ "math1": "math",
+ "math2": "math",
+ "math3": "math",
+ "math4": "math",
+ "math5": "math",
+ "math6": "math",
+ "math7": "math",
+ "math8": "math",
+ "math9": "math",
+ "math10": "math",
+ "math11": "math",
+ "math12": "math",
+ "thinkOfItems-ver1": "thinkOfItems",
+ "thinkOfItems-ver2": "thinkOfItems",
+ "video-ver1": "video",
+ "video-ver2": "video",
+ "thinkOfItemsInstruction-ver1": "thinkOfItemsInstruction",
+ "thinkOfItemsInstruction-ver2": "thinkOfItemsInstruction",
+ "colorRound1-1": "colorRound1",
+ "colorRound1-2": "colorRound1",
+ "colorRound1-3": "colorRound1",
+ "colorRound1-4": "colorRound1",
+ "colorRound1-5": "colorRound1",
+ "colorRound1-6": "colorRound1",
+ "colorRound2-1": "colorRound2",
+ "colorRound2-2": "colorRound2",
+ "colorRound2-3": "colorRound2",
+ "colorRound2-4": "colorRound2",
+ "colorRound2-5": "colorRound2",
+ "colorRound2-6": "colorRound2",
+ "colorRound3-1": "colorRound3",
+ "colorRound3-2": "colorRound3",
+ "colorRound3-3": "colorRound3",
+ "colorRound3-4": "colorRound3",
+ "colorRound3-5": "colorRound3",
+ "colorRound3-6": "colorRound3",
+ "colorRound4-1": "colorRound4",
+ "colorRound4-2": "colorRound4",
+ "colorRound4-3": "colorRound4",
+ "colorRound4-4": "colorRound4",
+ "colorRound4-5": "colorRound4",
+ "colorRound4-6": "colorRound4",
+ "colorRound5-1": "colorRound5",
+ "colorRound5-2": "colorRound5",
+ "colorRound5-3": "colorRound5",
+ "colorRound5-4": "colorRound5",
+ "colorRound5-5": "colorRound5",
+ "colorRound5-6": "colorRound5",
+ "colorInstruction1": "colorInstruction",
+ "colorInstruction2": "colorInstruction",
+ "readyRound1": "readyRound",
+ "readyRound2": "readyRound",
+ "readyRound3": "readyRound",
+ "readyRound4": "readyRound",
+ "readyRound5": "readyRound",
+ "colorRound1": "colorRound",
+ "colorRound2": "colorRound",
+ "colorRound3": "colorRound",
+ "colorRound4": "colorRound",
+ "colorRound5": "colorRound",
+ }
+ },
+ inplace=True,
+)
+
+print("After replacing labels")
+print(eeg["label"].unique())
+print(len(eeg["label"].unique()))
+
+le = preprocessing.LabelEncoder() # Generates a look-up table
+le.fit(eeg["label"])
+eeg["label"] = le.transform(eeg["label"])
+
+"""
+We extract the number of unique classes present in the data
+"""
+
+num_classes = len(eeg["label"].unique())
+print(num_classes)
+
+"""
+We now visualize the number of samples present in each class using a Bar plot.
+"""
+
+plt.bar(range(num_classes), eeg["label"].value_counts())
+plt.title("Number of samples per class")
+plt.show()
+
+"""
+## Scale and split data
+"""
+
+"""
+We perform a simple Min-Max scaling to bring the value-range between 0 and 1. We do not
+use Standard Scaling as the data does not follow a Gaussian distribution.
+"""
+
+scaler = preprocessing.MinMaxScaler()
+series_list = [
+ scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in eeg["raw_values"]
+]
+
+labels_list = [i for i in eeg["label"]]
+
+"""
+We now create a Train-test split with a 15% holdout set. Following this, we reshape the
+data to create a sequence of length 512. We also convert the labels from their current
+label-encoded form to a one-hot encoding to enable use of several different
+`keras.metrics` functions.
+"""
+
+x_train, x_test, y_train, y_test = model_selection.train_test_split(
+ series_list, labels_list, test_size=0.15, random_state=42, shuffle=True
+)
+
+print(
+ f"Length of x_train : {len(x_train)}\nLength of x_test : {len(x_test)}\nLength of y_train : {len(y_train)}\nLength of y_test : {len(y_test)}"
+)
+
+x_train = np.asarray(x_train).astype(np.float32).reshape(-1, 512, 1)
+y_train = np.asarray(y_train).astype(np.float32).reshape(-1, 1)
+y_train = keras.utils.to_categorical(y_train)
+
+x_test = np.asarray(x_test).astype(np.float32).reshape(-1, 512, 1)
+y_test = np.asarray(y_test).astype(np.float32).reshape(-1, 1)
+y_test = keras.utils.to_categorical(y_test)
+
+"""
+## Prepare `tf.data.Dataset`
+"""
+
+"""
+We now create a `tf.data.Dataset` from this data to prepare it for training. We also
+shuffle and batch the data for use later.
+"""
+
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+
+train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
+test_dataset = test_dataset.batch(BATCH_SIZE)
+
+"""
+## Make Class Weights using Naive method
+"""
+
+"""
+As we can see from the plot of the number of samples per class, the dataset is imbalanced.
+Hence, we **calculate weights for each class** to make sure that the model is trained in
+a fair manner without preference to any specific class due to a greater number of samples.
+
+We use a naive method to calculate these weights, finding an **inverse proportion** of
+each class and using that as the weight.
+"""
+
+vals_dict = {}
+for i in eeg["label"]:
+ if i in vals_dict.keys():
+ vals_dict[i] += 1
+ else:
+ vals_dict[i] = 1
+total = sum(vals_dict.values())
+
+# Formula used - Naive method where
+# weight = 1 - (no. of samples present / total no. of samples)
+# So more the samples, lower the weight
+
+weight_dict = {k: (1 - (v / total)) for k, v in vals_dict.items()}
+print(weight_dict)
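+
+"""
+For reference, scikit-learn ships a ready-made "balanced" heuristic serving the same
+purpose; a sketch of the equivalent computation (not used in the rest of this example):
+
+```python
+from sklearn.utils.class_weight import compute_class_weight
+
+classes = np.unique(eeg["label"])
+balanced = compute_class_weight(class_weight="balanced", classes=classes, y=eeg["label"])
+balanced_weight_dict = dict(zip(classes, balanced))
+```
+"""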
+
+"""
+## Define simple function to plot all the metrics present in a `keras.callbacks.History`
+object
+"""
+
+
+def plot_history_metrics(history: keras.callbacks.History):
+ total_plots = len(history.history)
+ cols = total_plots // 2
+
+ rows = total_plots // cols
+
+ if total_plots % cols != 0:
+ rows += 1
+
+ pos = range(1, total_plots + 1)
+ plt.figure(figsize=(15, 10))
+ for i, (key, value) in enumerate(history.history.items()):
+ plt.subplot(rows, cols, pos[i])
+ plt.plot(range(len(value)), value)
+ plt.title(str(key))
+ plt.show()
+
+
+"""
+## Define function to generate Convolutional model
+"""
+
+
+def create_model():
+ input_layer = keras.Input(shape=(512, 1))
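+    # Six Conv1D blocks with stride 2 halve the sequence length at every step
+    # (512 -> 8) while increasing the filter count; a regularized Dense head follows.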
+
+ x = layers.Conv1D(
+ filters=32, kernel_size=3, strides=2, activation="relu", padding="same"
+ )(input_layer)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv1D(
+ filters=64, kernel_size=3, strides=2, activation="relu", padding="same"
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv1D(
+ filters=128, kernel_size=5, strides=2, activation="relu", padding="same"
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv1D(
+ filters=256, kernel_size=5, strides=2, activation="relu", padding="same"
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv1D(
+ filters=512, kernel_size=7, strides=2, activation="relu", padding="same"
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv1D(
+ filters=1024,
+ kernel_size=7,
+ strides=2,
+ activation="relu",
+ padding="same",
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Dropout(0.2)(x)
+
+ x = layers.Flatten()(x)
+
+ x = layers.Dense(4096, activation="relu")(x)
+ x = layers.Dropout(0.2)(x)
+
+ x = layers.Dense(
+ 2048, activation="relu", kernel_regularizer=keras.regularizers.L2()
+ )(x)
+ x = layers.Dropout(0.2)(x)
+
+ x = layers.Dense(
+ 1024, activation="relu", kernel_regularizer=keras.regularizers.L2()
+ )(x)
+ x = layers.Dropout(0.2)(x)
+ x = layers.Dense(
+ 128, activation="relu", kernel_regularizer=keras.regularizers.L2()
+ )(x)
+ output_layer = layers.Dense(num_classes, activation="softmax")(x)
+
+ return keras.Model(inputs=input_layer, outputs=output_layer)
+
+
+"""
+## Get Model summary
+"""
+
+conv_model = create_model()
+conv_model.summary()
+
+"""
+## Define callbacks, optimizer, loss and metrics
+"""
+
+"""
+We set the number of epochs to 30 after extensive experimentation; this was found to be
+the optimal number, confirmed by an Early-Stopping analysis as well.
+We define a ModelCheckpoint callback to make sure that we only keep the best model
+weights.
+We also define a ReduceLROnPlateau callback, as there were several cases found during
+experimentation where the loss stagnated after a certain point. On the other hand, a
+direct LRScheduler was found to be too aggressive in its decay.
+"""
+
+epochs = 30
+
+callbacks = [
+ keras.callbacks.ModelCheckpoint(
+ "best_model.keras", save_best_only=True, monitor="loss"
+ ),
+ keras.callbacks.ReduceLROnPlateau(
+ monitor="val_top_k_categorical_accuracy",
+ factor=0.2,
+ patience=2,
+ min_lr=0.000001,
+ ),
+]
+
+optimizer = keras.optimizers.Adam(amsgrad=True, learning_rate=0.001)
+loss = keras.losses.CategoricalCrossentropy()
+
+"""
+## Compile model and call `model.fit()`
+"""
+
+"""
+We use the `Adam` optimizer since it is a commonly used default choice for
+preliminary training, and it performed best in our experiments.
+We use `CategoricalCrossentropy` as the loss as our labels are in a one-hot-encoded form.
+
+We define the `TopKCategoricalAccuracy(k=3)`, `AUC`, `Precision` and `Recall` metrics to
+further aid in understanding the model.
+"""
+
+conv_model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=[
+ keras.metrics.TopKCategoricalAccuracy(k=3),
+ keras.metrics.AUC(),
+ keras.metrics.Precision(),
+ keras.metrics.Recall(),
+ ],
+)
+
+conv_model_history = conv_model.fit(
+ train_dataset,
+ epochs=epochs,
+ callbacks=callbacks,
+ validation_data=test_dataset,
+ class_weight=weight_dict,
+)
+
+"""
+## Visualize model metrics during training
+"""
+
+"""
+We use the function defined above to see model metrics during training.
+"""
+
+plot_history_metrics(conv_model_history)
+
+"""
+## Evaluate model on test data
+"""
+
+loss, accuracy, auc, precision, recall = conv_model.evaluate(test_dataset)
+print(f"Loss : {loss}")
+print(f"Top 3 Categorical Accuracy : {accuracy}")
+print(f"Area under the Curve (ROC) : {auc}")
+print(f"Precision : {precision}")
+print(f"Recall : {recall}")
+
+
+def view_evaluated_eeg_plots(model):
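+    # Pick 12 consecutive rows at random, rescale them the same way as the
+    # training data, and compare the predicted labels against the originals.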
+ start_index = random.randint(10, len(eeg))
+ end_index = start_index + 11
+ data = eeg.loc[start_index:end_index, "raw_values"]
+ data_array = [scaler.fit_transform(np.asarray(i).reshape(-1, 1)) for i in data]
+ data_array = [np.asarray(data_array).astype(np.float32).reshape(-1, 512, 1)]
+ original_labels = eeg.loc[start_index:end_index, "label"]
+ predicted_labels = np.argmax(model.predict(data_array, verbose=0), axis=1)
+ original_labels = [
+ le.inverse_transform(np.array(label).reshape(-1))[0]
+ for label in original_labels
+ ]
+ predicted_labels = [
+ le.inverse_transform(np.array(label).reshape(-1))[0]
+ for label in predicted_labels
+ ]
+ total_plots = 12
+ cols = total_plots // 3
+ rows = total_plots // cols
+ if total_plots % cols != 0:
+ rows += 1
+ pos = range(1, total_plots + 1)
+ fig = plt.figure(figsize=(20, 10))
+ for i, (plot_data, og_label, pred_label) in enumerate(
+ zip(data, original_labels, predicted_labels)
+ ):
+ plt.subplot(rows, cols, pos[i])
+ plt.plot(plot_data)
+ plt.title(f"Actual Label : {og_label}\nPredicted Label : {pred_label}")
+ fig.subplots_adjust(hspace=0.5)
+ plt.show()
+
+
+view_evaluated_eeg_plots(conv_model)
diff --git a/knowledge_base/timeseries/event_classification_for_payment_card_fraud_detection.py b/knowledge_base/timeseries/event_classification_for_payment_card_fraud_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..65632d9a80c4b80f72c3bbab59343375d2804771
--- /dev/null
+++ b/knowledge_base/timeseries/event_classification_for_payment_card_fraud_detection.py
@@ -0,0 +1,541 @@
+"""
+Title: Event classification for payment card fraud detection
+Author: [achoum](https://github.com/achoum/)
+Date created: 2024/02/01
+Last modified: 2024/02/01
+Description: Detection of fraudulent payment card transactions using Temporian and a feed-forward neural network.
+Accelerator: GPU
+"""
+
+"""
+This notebook depends on Keras 3, Temporian, and a few other libraries. You can
+install them as follows:
+
+```shell
+pip install temporian keras pandas tf-nightly scikit-learn -U
+```
+"""
+
+import keras # To train the Machine Learning model
+import temporian as tp # To convert transactions into tabular data
+
+import numpy as np
+import os
+import pandas as pd
+import datetime
+import math
+import tensorflow as tf
+from sklearn.metrics import RocCurveDisplay
+
+"""
+## Introduction
+
+Payment fraud detection is critical for banks, businesses, and consumers. In
+Europe alone, fraudulent transactions were estimated at
+[โฌ1.89 billion in 2019](https://www.ecb.europa.eu/pub/pdf/cardfraud/ecb.cardfraudreport202110~cac4c418e8.en.pdf).
+Worldwide, approximately
+[3.6%](https://www.cybersource.com/content/dam/documents/campaign/fraud-report/global-fraud-report-2022.pdf)
+of commerce revenue is lost to fraud. In this notebook, we train and evaluate a
+model to detect fraudulent transactions using the synthetic dataset attached to
+the book
+[Reproducible Machine Learning for Credit Card Fraud Detection](https://fraud-detection-handbook.github.io/fraud-detection-handbook/Foreword.html)
+by Le Borgne et al.
+
+Fraudulent transactions often cannot be detected by looking at transactions in
+isolation. Instead, fraudulent transactions are detected by looking at patterns
+across multiple transactions from the same user, to the same merchant, or with
+other types of relationships. To express these relationships in a way that is
+understandable by a machine learning model, and to augment features with feature
+engineering, we use the
+[Temporian](https://temporian.readthedocs.io/en/latest) preprocessing library.
+
+We preprocess a transaction dataset into a tabular dataset and use a
+feed-forward neural network to learn the patterns of fraud and make predictions.
+
+## Loading the dataset
+
+The dataset contains payment transactions sampled between April 1, 2018 and
+September 30, 2018. The transactions are stored in one file per day.
+
+**Note:** Downloading the dataset takes ~1 minute.
+"""
+
+start_date = datetime.date(2018, 4, 1)
+end_date = datetime.date(2018, 9, 30)
+
+# Load the dataset as a Pandas dataframe.
+cache_path = "fraud_detection_cache.csv"
+if not os.path.exists(cache_path):
+ print("Download dataset")
+ dataframes = []
+ num_files = (end_date - start_date).days
+ counter = 0
+ while start_date <= end_date:
+ if counter % (num_files // 10) == 0:
+ print(f"[{100 * (counter+1) // num_files}%]", end="", flush=True)
+ print(".", end="", flush=True)
+ url = f"https://github.com/Fraud-Detection-Handbook/simulated-data-raw/raw/6e67dbd0a3bfe0d7ec33abc4bce5f37cd4ff0d6a/data/{start_date}.pkl"
+ dataframes.append(pd.read_pickle(url))
+ start_date += datetime.timedelta(days=1)
+ counter += 1
+ print("done", flush=True)
+ transactions_dataframe = pd.concat(dataframes)
+ transactions_dataframe.to_csv(cache_path, index=False)
+else:
+ print("Load dataset from cache")
+ transactions_dataframe = pd.read_csv(
+ cache_path, dtype={"CUSTOMER_ID": bytes, "TERMINAL_ID": bytes}
+ )
+
+print(f"Found {len(transactions_dataframe)} transactions")
+
+"""
+Each transaction is represented by a single row, with the following columns of
+interest:
+
+- **TX_DATETIME**: The date and time of the transaction.
+- **CUSTOMER_ID**: The unique identifier of the customer.
+- **TERMINAL_ID**: The identifier of the terminal where the transaction was
+ made.
+- **TX_AMOUNT**: The amount of the transaction.
+- **TX_FRAUD**: Whether the transaction is fraudulent (1) or not (0).
+"""
+
+transactions_dataframe = transactions_dataframe[
+ ["TX_DATETIME", "CUSTOMER_ID", "TERMINAL_ID", "TX_AMOUNT", "TX_FRAUD"]
+]
+
+transactions_dataframe.head(4)
+
+"""
+The dataset is highly imbalanced, with the majority of transactions being
+legitimate.
+"""
+
+fraudulent_rate = transactions_dataframe["TX_FRAUD"].mean()
+print("Rate of fraudulent transactions:", fraudulent_rate)
+
+"""
+The
+[pandas dataframe](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html)
+is converted into a
+[Temporian EventSet](https://temporian.readthedocs.io/en/latest/reference/temporian/EventSet/),
+which is better suited for the data exploration and feature preprocessing of the
+next steps.
+"""
+
+transactions_evset = tp.from_pandas(transactions_dataframe, timestamps="TX_DATETIME")
+
+transactions_evset
+
+"""
+It is possible to plot the entire dataset, but the resulting plot will be
+difficult to read. Instead, we can group the transactions per client.
+"""
+
+transactions_evset.add_index("CUSTOMER_ID").plot(indexes="3774")
+
+"""
+Note the few fraudulent transactions for this client.
+
+## Preparing the training data
+
+Fraudulent transactions in isolation cannot be detected. Instead, we need to
+connect related transactions. For each transaction, we compute the sum and count
+of transactions for the same terminal in the last `n` days. Because we don't
+know the correct value for `n`, we use multiple values for `n` and compute a
+set of features for each of them.
+"""
+
+# Group the transactions per terminal
+transactions_per_terminal = transactions_evset.add_index("TERMINAL_ID")
+
+# Moving statistics per terminal
+tmp_features = []
+for n in [7, 14, 28]:
+ tmp_features.append(
+ transactions_per_terminal["TX_AMOUNT"]
+ .moving_sum(tp.duration.days(n))
+ .rename(f"sum_transactions_{n}_days")
+ )
+
+ tmp_features.append(
+ transactions_per_terminal.moving_count(tp.duration.days(n)).rename(
+ f"count_transactions_{n}_days"
+ )
+ )
+
+feature_set_1 = tp.glue(*tmp_features)
+
+feature_set_1
+
+"""
+Let's look at the features of terminal "3774".
+"""
+
+feature_set_1.plot(indexes="3774")
+
+"""
+A transaction's fraudulent status is not known at the time of the transaction
+(otherwise, there would be no problem). However, the bank knows if a
+transaction is fraudulent one week after it is made. We create a set of
+features that indicate the number and ratio of fraudulent transactions in the
+last N days.
+"""
+
+# Lag the transactions by one week.
+lagged_transactions = transactions_per_terminal.lag(tp.duration.weeks(1))
+
+# Moving statistics per customer
+tmp_features = []
+for n in [7, 14, 28]:
+ tmp_features.append(
+ lagged_transactions["TX_FRAUD"]
+ .moving_sum(tp.duration.days(n), sampling=transactions_per_terminal)
+ .rename(f"count_fraud_transactions_{n}_days")
+ )
+
+ tmp_features.append(
+ lagged_transactions["TX_FRAUD"]
+ .cast(tp.float32)
+ .simple_moving_average(tp.duration.days(n), sampling=transactions_per_terminal)
+ .rename(f"rate_fraud_transactions_{n}_days")
+ )
+
+feature_set_2 = tp.glue(*tmp_features)
+
+"""
+Transaction date and time can be correlated with fraud. While each transaction
+has a timestamp, a machine learning model might struggle to consume them
+directly. Instead, we extract various informative calendar features from the
+timestamps, such as hour, day of the week (e.g., Monday, Tuesday), and day of
+the month (1-31).
+"""
+
+feature_set_3 = tp.glue(
+ transactions_per_terminal.calendar_hour(),
+ transactions_per_terminal.calendar_day_of_week(),
+)
+
+"""
+Finally, we group together all the features and the label.
+"""
+
+all_data = tp.glue(
+ transactions_per_terminal, feature_set_1, feature_set_2, feature_set_3
+).drop_index()
+
+print("All the available features:")
+all_data.schema.feature_names()
+
+"""
+We extract the names of the input features.
+"""
+
+input_feature_names = [k for k in all_data.schema.feature_names() if k.islower()]
+
+print("The model's input features:")
+input_feature_names
+
+"""
+For neural networks to work correctly, numerical inputs must be normalized. A
+common approach is to apply z-normalization, which involves subtracting from each
+value the mean, and dividing by the standard deviation, estimated from the training
+data. In forecasting, such z-normalization is not recommended as it would lead
+to future leakage. Specifically, to classify a transaction at time t, we cannot
+rely on data after time t since, at serving time when making a prediction at
+time t, no subsequent data is available yet. In short, at time t, we are limited
+to using data that precedes or is concurrent with time t.
+
+The solution is therefore to apply z-normalization **over time**, which means
+that we normalize each transaction using the mean and standard deviation
+computed from the past data **for that transaction**.
+
+Future leakage is pernicious. Luckily, Temporian is here to help: the only
+operator that can cause future leakage is `EventSet.leak()`. If you are not
+using `EventSet.leak()`, your preprocessing is **guaranteed** not to create
+future leakage.
+
+**Note:** For advanced pipelines, you can also check programmatically that a
+feature does not depend on an `EventSet.leak()` operation.
+"""
+
+# Cast all values (e.g. ints) to floats.
+values = all_data[input_feature_names].cast(tp.float32)
+
+# Apply z-normalization over time.
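+# A window length of math.inf turns these into expanding statistics computed over
+# all past events, so each value is normalized using only data up to its timestamp.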
+normalized_features = (
+ values - values.simple_moving_average(math.inf)
+) / values.moving_standard_deviation(math.inf)
+
+# Restore the original names of the features.
+normalized_features = normalized_features.rename(values.schema.feature_names())
+
+print(normalized_features)
+
+"""
+The first transactions will be normalized using poor estimates of the mean and
+standard deviation since there are only a few transactions before them. To
+mitigate this issue, we remove the first week of data from the training dataset.
+
+Notice that the first values contain NaN. In Temporian, NaN represents missing
+values, and all operators handle them accordingly. For instance, when
+calculating a moving average, NaN values are not included in the calculation
+and do not generate a NaN result.
+
+However, neural networks cannot natively handle NaN values. So, we replace them
+with zeros.
+"""
+
+normalized_features = normalized_features.fillna(0.0)
+
+"""
+Finally, we group together the features and the labels.
+"""
+
+normalized_all_data = tp.glue(normalized_features, all_data["TX_FRAUD"])
+
+"""
+## Split dataset into a train, validation and test set
+
+To evaluate the quality of our machine learning model, we need training,
+validation and test sets. Since the system is dynamic (new fraud patterns are
+being created all the time), it is important for the training set to come before
+the validation set, and for the validation set to come before the testing set:
+
+- **Training:** April 8, 2018 to July 31, 2018
+- **Validation:** August 1, 2018 to August 31, 2018
+- **Testing:** September 1, 2018 to September 30, 2018
+
+For the example to run faster, we will effectively reduce the size of the
+training set to:
+- **Training:** July 1, 2018 to July 31, 2018
+"""
+
+# begin_train = datetime.datetime(2018, 4, 8).timestamp() # Full training dataset
+begin_train = datetime.datetime(2018, 7, 1).timestamp() # Reduced training dataset
+begin_valid = datetime.datetime(2018, 8, 1).timestamp()
+begin_test = datetime.datetime(2018, 9, 1).timestamp()
+
+is_train = (normalized_all_data.timestamps() >= begin_train) & (
+ normalized_all_data.timestamps() < begin_valid
+)
+is_valid = (normalized_all_data.timestamps() >= begin_valid) & (
+ normalized_all_data.timestamps() < begin_test
+)
+is_test = normalized_all_data.timestamps() >= begin_test
+
+"""
+`is_train`, `is_valid` and `is_test` are boolean features over time that
+indicate the limits of the three folds. Let's plot them.
+"""
+
+tp.plot(
+ [
+ is_train.rename("is_train"),
+ is_valid.rename("is_valid"),
+ is_test.rename("is_test"),
+ ]
+)
+
+"""
+We filter the input features and label in each fold.
+"""
+
+train_ds_evset = normalized_all_data.filter(is_train)
+valid_ds_evset = normalized_all_data.filter(is_valid)
+test_ds_evset = normalized_all_data.filter(is_test)
+
+print(f"Training examples: {train_ds_evset.num_events()}")
+print(f"Validation examples: {valid_ds_evset.num_events()}")
+print(f"Testing examples: {test_ds_evset.num_events()}")
+
+"""
+It is important to split the dataset **after** the features have been computed,
+because some of the features for the validation and testing datasets are
+computed from transactions in the training window.
+
+## Create TensorFlow datasets
+
+We convert the datasets from EventSets to TensorFlow Datasets as Keras consumes
+them natively.
+"""
+
+non_batched_train_ds = tp.to_tensorflow_dataset(train_ds_evset)
+non_batched_valid_ds = tp.to_tensorflow_dataset(valid_ds_evset)
+non_batched_test_ds = tp.to_tensorflow_dataset(test_ds_evset)
+
+"""
+The following processing steps are applied using TensorFlow datasets:
+
+1. The features and labels are separated using `extract_features_and_label` in
+ the format that Keras expects.
+1. The dataset is batched, which means that the examples are grouped into
+ mini-batches.
+1. The training examples are shuffled to improve the quality of mini-batch
+ training.
+
+As we noted before, the dataset is imbalanced in the direction of legitimate
+transactions. While we want to evaluate our model on this original distribution,
+neural networks often train poorly on strongly imbalanced datasets. Therefore,
+we resample the training dataset to a ratio of 80% legitimate / 20% fraudulent
+using `rejection_resample`.
+"""
+
+
+def extract_features_and_label(example):
+ features = {k: example[k] for k in input_feature_names}
+ labels = tf.cast(example["TX_FRAUD"], tf.int32)
+ return features, labels
+
+
+# Target ratio of fraudulent transactions in the training dataset.
+target_rate = 0.2
+
+# Number of examples in a mini-batch.
+batch_size = 32
+
+train_ds = (
+ non_batched_train_ds.shuffle(10000)
+ .rejection_resample(
+ class_func=lambda x: tf.cast(x["TX_FRAUD"], tf.int32),
+ target_dist=[1 - target_rate, target_rate],
+ initial_dist=[1 - fraudulent_rate, fraudulent_rate],
+ )
+ .map(lambda _, x: x) # Remove the label copy added by "rejection_resample".
+ .batch(batch_size)
+ .map(extract_features_and_label)
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+# The validation and test datasets do not need resampling or shuffling.
+valid_ds = (
+ non_batched_valid_ds.batch(batch_size)
+ .map(extract_features_and_label)
+ .prefetch(tf.data.AUTOTUNE)
+)
+test_ds = (
+ non_batched_test_ds.batch(batch_size)
+ .map(extract_features_and_label)
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+"""
+We print the first four examples of the training dataset. This is a simple way
+to identify some of the errors that could have been made above.
+"""
+
+for features, labels in train_ds.take(1):
+ print("features")
+ for feature_name, feature_value in features.items():
+ print(f"\t{feature_name}: {feature_value[:4]}")
+ print(f"labels: {labels[:4]}")
+
+"""
+## Train the model
+
+The original dataset is transactional, but the processed data is tabular and
+only contains normalized numerical values. Therefore, we train a feed-forward
+neural network.
+"""
+
+inputs = [keras.Input(shape=(1,), name=name) for name in input_feature_names]
+x = keras.layers.concatenate(inputs)
+x = keras.layers.Dense(32, activation="sigmoid")(x)
+x = keras.layers.Dense(16, activation="sigmoid")(x)
+x = keras.layers.Dense(1, activation="sigmoid")(x)
+model = keras.Model(inputs=inputs, outputs=x)
+
+"""
+Our goal is to differentiate between the fraudulent and legitimate transactions,
+so we use a binary classification objective. Because the dataset is imbalanced,
+accuracy is not an informative metric. Instead, we evaluate the model using the
+[area under the curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve)
+(AUC).
+"""
+
+model.compile(
+ optimizer=keras.optimizers.Adam(0.01),
+ loss=keras.losses.BinaryCrossentropy(),
+    # BinaryAccuracy thresholds the sigmoid output at 0.5; the plain Accuracy
+    # metric would compare raw probabilities to labels for exact equality.
+    metrics=[keras.metrics.BinaryAccuracy(), keras.metrics.AUC()],
+)
+model.fit(train_ds, validation_data=valid_ds)
+
+"""
+We evaluate the model on the test dataset.
+"""
+
+model.evaluate(test_ds)
+
+"""
+With an AUC of ~83%, our simple fraud detector is showing encouraging
+results.
+
+
+Plotting the ROC curve is a good way to understand and select the operating
+point of the model, i.e. the threshold applied to the model output to
+differentiate between fraudulent and legitimate transactions.
+
+Compute the test predictions:
+"""
+
+predictions = model.predict(test_ds)
+predictions = np.nan_to_num(predictions, nan=0)
+
+"""
+Extract the labels from the test set:
+"""
+
+labels = np.concatenate([label for _, label in test_ds])
+
+"""
+Finally, we plot the ROC curve.
+"""
+
+_ = RocCurveDisplay.from_predictions(labels, predictions)
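+
+"""
+To pick a concrete operating point, we can scan the ROC curve for the highest
+true positive rate under a false positive rate budget. The 1% budget below is
+an arbitrary illustration, not a recommendation from this example.
+"""
+
+from sklearn.metrics import roc_curve
+
+fpr, tpr, thresholds = roc_curve(labels, predictions.ravel())
+# Keep only the operating points with a false positive rate below 1%.
+candidates = fpr <= 0.01
+best = np.argmax(tpr[candidates])
+print("Suggested threshold:", thresholds[candidates][best])
+print("True positive rate at that threshold:", tpr[candidates][best])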
+
+
+"""
+The Keras model is ready to be used on transactions with an unknown fraud
+status, a.k.a. serving. We save the model on disk for future use.
+
+**Note:** The model does not include the data preparation and preprocessing steps
+done in Pandas and Temporian. They have to be applied manually to the data fed
+into the model. While not demonstrated here, Temporian preprocessing can also be
+saved to disk with
+[tp.save](https://temporian.readthedocs.io/en/latest/reference/temporian/serialization/save/).
+"""
+
+model.save("fraud_detection_model.keras")
+
+"""
+The model can be later reloaded with:
+"""
+
+loaded_model = keras.saving.load_model("fraud_detection_model.keras")
+
+# Generate predictions with the loaded model on 5 test examples.
+loaded_model.predict(test_ds.rebatch(5).take(1))
+
+"""
+## Conclusion
+
+We trained a feed-forward neural network to identify fraudulent transactions. To
+feed them into the model, the transactions were preprocessed and transformed
+into a tabular dataset using
+[Temporian](https://temporian.readthedocs.io/en/latest/). Now, a question to the
+reader: What could be done to further improve the model's performance?
+
+Here are some ideas:
+
+- Train the model on the entire dataset instead of a single month of data.
+- Train the model for more epochs and use early stopping to ensure that the
+ model is fully trained without overfitting.
+- Make the feed-forward network more powerful by increasing the number of layers
+ while ensuring that the model is regularized.
+- Compute additional preprocessing features. For example, in addition to
+ aggregating transactions by terminal, aggregate transactions by client.
+- Use the Keras Tuner to perform hyperparameter tuning on the model. Note that
+ the parameters of the preprocessing (e.g., the number of days of
+ aggregations) are also hyperparameters that can be tuned.
+"""
diff --git a/knowledge_base/timeseries/timeseries_anomaly_detection.py b/knowledge_base/timeseries/timeseries_anomaly_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..452e4d761af537d444a75133a60f8db47ccf3ecf
--- /dev/null
+++ b/knowledge_base/timeseries/timeseries_anomaly_detection.py
@@ -0,0 +1,306 @@
+"""
+Title: Timeseries anomaly detection using an Autoencoder
+Author: [pavithrasv](https://github.com/pavithrasv)
+Date created: 2020/05/31
+Last modified: 2020/05/31
+Description: Detect anomalies in a timeseries using an Autoencoder.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This script demonstrates how you can use a reconstruction convolutional
+autoencoder model to detect anomalies in timeseries data.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import pandas as pd
+import keras
+from keras import layers
+from matplotlib import pyplot as plt
+
+"""
+## Load the data
+
+We will use the [Numenta Anomaly Benchmark (NAB)](
+https://www.kaggle.com/boltzmannbrain/nab) dataset. It provides artificial
+timeseries data containing labeled anomalous periods of behavior. Data are
+ordered, timestamped, single-valued metrics.
+
+We will use the `art_daily_small_noise.csv` file for training and the
+`art_daily_jumpsup.csv` file for testing. The simplicity of this dataset
+allows us to demonstrate anomaly detection effectively.
+"""
+
+master_url_root = "https://raw.githubusercontent.com/numenta/NAB/master/data/"
+
+df_small_noise_url_suffix = "artificialNoAnomaly/art_daily_small_noise.csv"
+df_small_noise_url = master_url_root + df_small_noise_url_suffix
+df_small_noise = pd.read_csv(
+ df_small_noise_url, parse_dates=True, index_col="timestamp"
+)
+
+df_daily_jumpsup_url_suffix = "artificialWithAnomaly/art_daily_jumpsup.csv"
+df_daily_jumpsup_url = master_url_root + df_daily_jumpsup_url_suffix
+df_daily_jumpsup = pd.read_csv(
+ df_daily_jumpsup_url, parse_dates=True, index_col="timestamp"
+)
+
+"""
+## Quick look at the data
+"""
+
+print(df_small_noise.head())
+
+print(df_daily_jumpsup.head())
+
+"""
+## Visualize the data
+### Timeseries data without anomalies
+
+We will use the following data for training.
+"""
+fig, ax = plt.subplots()
+df_small_noise.plot(legend=False, ax=ax)
+plt.show()
+
+"""
+### Timeseries data with anomalies
+
+We will use the following data for testing and see if the sudden jump up in the
+data is detected as an anomaly.
+"""
+fig, ax = plt.subplots()
+df_daily_jumpsup.plot(legend=False, ax=ax)
+plt.show()
+
+"""
+## Prepare training data
+
+Get data values from the training timeseries data file and normalize the
+`value` data. We have one `value` every 5 minutes for 14 days.
+
+- 24 * 60 / 5 = **288 timesteps per day**
+- 288 * 14 = **4032 data points** in total
+"""
+
+
+# Normalize and save the mean and std we get,
+# for normalizing test data.
+training_mean = df_small_noise.mean()
+training_std = df_small_noise.std()
+df_training_value = (df_small_noise - training_mean) / training_std
+print("Number of training samples:", len(df_training_value))
+
+"""
+### Create sequences
+Create sequences combining `TIME_STEPS` contiguous data values from the
+training data.
+"""
+
+TIME_STEPS = 288
+
+
+# Generate training sequences for use in the model.
+def create_sequences(values, time_steps=TIME_STEPS):
+ output = []
+ for i in range(len(values) - time_steps + 1):
+ output.append(values[i : (i + time_steps)])
+ return np.stack(output)
+
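+"""
+As a quick illustration, on a toy array of 5 values with `time_steps=3`, the
+function yields 3 overlapping windows:
+"""
+
+# Illustration only: prints the windows (0, 1, 2), (1, 2, 3) and (2, 3, 4).
+print(create_sequences(np.arange(5.0), time_steps=3))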
+
+x_train = create_sequences(df_training_value.values)
+print("Training input shape: ", x_train.shape)
+
+"""
+## Build a model
+
+We will build a convolutional reconstruction autoencoder model. The model will
+take input of shape `(batch_size, sequence_length, num_features)` and return
+output of the same shape. In this case, `sequence_length` is 288 and
+`num_features` is 1.
+"""
+
+model = keras.Sequential(
+ [
+ layers.Input(shape=(x_train.shape[1], x_train.shape[2])),
+ layers.Conv1D(
+ filters=32,
+ kernel_size=7,
+ padding="same",
+ strides=2,
+ activation="relu",
+ ),
+ layers.Dropout(rate=0.2),
+ layers.Conv1D(
+ filters=16,
+ kernel_size=7,
+ padding="same",
+ strides=2,
+ activation="relu",
+ ),
+ layers.Conv1DTranspose(
+ filters=16,
+ kernel_size=7,
+ padding="same",
+ strides=2,
+ activation="relu",
+ ),
+ layers.Dropout(rate=0.2),
+ layers.Conv1DTranspose(
+ filters=32,
+ kernel_size=7,
+ padding="same",
+ strides=2,
+ activation="relu",
+ ),
+ layers.Conv1DTranspose(filters=1, kernel_size=7, padding="same"),
+ ]
+)
+model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001), loss="mse")
+model.summary()
+
+"""
+## Train the model
+
+Please note that we are using `x_train` as both the input and the target
+since this is a reconstruction model.
+"""
+
+history = model.fit(
+ x_train,
+ x_train,
+ epochs=50,
+ batch_size=128,
+ validation_split=0.1,
+ callbacks=[
+ keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, mode="min")
+ ],
+)
+
+"""
+Let's plot training and validation loss to see how the training went.
+"""
+
+plt.plot(history.history["loss"], label="Training Loss")
+plt.plot(history.history["val_loss"], label="Validation Loss")
+plt.legend()
+plt.show()
+
+"""
+## Detecting anomalies
+
+We will detect anomalies by determining how well our model can reconstruct
+the input data.
+
+
+1. Find MAE loss on training samples.
+2. Find max MAE loss value. This is the worst our model has performed trying
+to reconstruct a sample. We will make this the `threshold` for anomaly
+detection.
+3. If the reconstruction loss for a sample is greater than this `threshold`
+value then we can infer that the model is seeing a pattern that it isn't
+familiar with. We will label this sample as an `anomaly`.
+
+
+"""
+
+# Get train MAE loss.
+x_train_pred = model.predict(x_train)
+train_mae_loss = np.mean(np.abs(x_train_pred - x_train), axis=1)
+
+plt.hist(train_mae_loss, bins=50)
+plt.xlabel("Train MAE loss")
+plt.ylabel("No of samples")
+plt.show()
+
+# Get reconstruction loss threshold.
+threshold = np.max(train_mae_loss)
+print("Reconstruction error threshold: ", threshold)
+
+"""
+### Compare reconstruction
+
+Just for fun, let's see how our model has reconstructed the first sample.
+This is the 288 timesteps from day 1 of our training dataset.
+"""
+
+# Checking how the first sequence is learnt
+plt.plot(x_train[0])
+plt.plot(x_train_pred[0])
+plt.show()
+
+"""
+### Prepare test data
+"""
+
+
+df_test_value = (df_daily_jumpsup - training_mean) / training_std
+fig, ax = plt.subplots()
+df_test_value.plot(legend=False, ax=ax)
+plt.show()
+
+# Create sequences from test values.
+x_test = create_sequences(df_test_value.values)
+print("Test input shape: ", x_test.shape)
+
+# Get test MAE loss.
+x_test_pred = model.predict(x_test)
+test_mae_loss = np.mean(np.abs(x_test_pred - x_test), axis=1)
+test_mae_loss = test_mae_loss.reshape((-1))
+
+plt.hist(test_mae_loss, bins=50)
+plt.xlabel("test MAE loss")
+plt.ylabel("No of samples")
+plt.show()
+
+# Detect all the samples which are anomalies.
+anomalies = test_mae_loss > threshold
+print("Number of anomaly samples: ", np.sum(anomalies))
+print("Indices of anomaly samples: ", np.where(anomalies))
+
+"""
+## Plot anomalies
+
+We now know the samples of the data which are anomalies. With this, we will
+find the corresponding `timestamps` from the original test data. We will be
+using the following method to do that:
+
+Let's say time_steps = 3 and we have 10 training values. Our `x_train` will
+look like this:
+
+- 0, 1, 2
+- 1, 2, 3
+- 2, 3, 4
+- 3, 4, 5
+- 4, 5, 6
+- 5, 6, 7
+- 6, 7, 8
+- 7, 8, 9
+
+All data values, except for the initial and final `time_steps - 1` values,
+appear in `time_steps` samples. So, if we know that the samples
+[(3, 4, 5), (4, 5, 6), (5, 6, 7)] are anomalies, we can say that the data point
+5 is an anomaly.
+"""
+
+# Data point i is an anomaly if all the samples covering it, i.e. samples
+# (i - time_steps + 1) through (i), are anomalous.
+anomalous_data_indices = []
+for data_idx in range(TIME_STEPS - 1, len(df_test_value) - TIME_STEPS + 1):
+    if np.all(anomalies[data_idx - TIME_STEPS + 1 : data_idx + 1]):
+        anomalous_data_indices.append(data_idx)
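+
+"""
+As a sanity check, we can run the same logic on the toy example above:
+`time_steps = 3`, 10 data values, 8 samples, with samples (3, 4, 5), (4, 5, 6)
+and (5, 6, 7) anomalous. Only data point 5 is covered exclusively by anomalous
+samples.
+"""
+
+# Toy verification of the index mapping (illustration only).
+toy_time_steps = 3
+toy_anomalies = np.zeros(8, dtype=bool)
+toy_anomalies[3:6] = True  # samples 3, 4 and 5 are anomalous
+toy_indices = [
+    i
+    for i in range(toy_time_steps - 1, 10 - toy_time_steps + 1)
+    if np.all(toy_anomalies[i - toy_time_steps + 1 : i + 1])
+]
+print(toy_indices)  # [5]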
+
+"""
+Let's overlay the anomalies on the original test data plot.
+"""
+
+df_subset = df_daily_jumpsup.iloc[anomalous_data_indices]
+fig, ax = plt.subplots()
+df_daily_jumpsup.plot(legend=False, ax=ax)
+df_subset.plot(legend=False, ax=ax, color="r")
+plt.show()
diff --git a/knowledge_base/timeseries/timeseries_classification_from_scratch.py b/knowledge_base/timeseries/timeseries_classification_from_scratch.py
new file mode 100755
index 0000000000000000000000000000000000000000..b04ebe8de19c9168d82aa04729ebd3f1b7ba4e81
--- /dev/null
+++ b/knowledge_base/timeseries/timeseries_classification_from_scratch.py
@@ -0,0 +1,227 @@
+"""
+Title: Timeseries classification from scratch
+Author: [hfawaz](https://github.com/hfawaz/)
+Date created: 2020/07/21
+Last modified: 2023/11/10
+Description: Training a timeseries classifier from scratch on the FordA dataset from the UCR/UEA archive.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example shows how to do timeseries classification from scratch, starting from raw
+CSV timeseries files on disk. We demonstrate the workflow on the FordA dataset from the
+[UCR/UEA archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).
+
+"""
+
+"""
+## Setup
+
+"""
+import keras
+import numpy as np
+import matplotlib.pyplot as plt
+
+"""
+## Load the data: the FordA dataset
+
+### Dataset description
+
+The dataset we are using here is called FordA.
+The data comes from the UCR archive.
+The dataset contains 3601 training instances and another 1320 testing instances.
+Each timeseries corresponds to a measurement of engine noise captured by a motor sensor.
+For this task, the goal is to automatically detect the presence of a specific issue with
+the engine. The problem is a balanced binary classification task. The full description of
+this dataset can be found [here](http://www.j-wichard.de/publications/FordPaper.pdf).
+
+### Read the TSV data
+
+We will use the `FordA_TRAIN` file for training and the
+`FordA_TEST` file for testing. The simplicity of this dataset
+allows us to demonstrate effectively how to use ConvNets for timeseries classification.
+In this file, the first column corresponds to the label.
+"""
+
+
+def readucr(filename):
+ data = np.loadtxt(filename, delimiter="\t")
+ y = data[:, 0]
+ x = data[:, 1:]
+ return x, y.astype(int)
+
+
+root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"
+
+x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
+x_test, y_test = readucr(root_url + "FordA_TEST.tsv")
+
+"""
+## Visualize the data
+
+Here we visualize one timeseries example for each class in the dataset.
+
+"""
+
+classes = np.unique(np.concatenate((y_train, y_test), axis=0))
+
+plt.figure()
+for c in classes:
+ c_x_train = x_train[y_train == c]
+ plt.plot(c_x_train[0], label="class " + str(c))
+plt.legend(loc="best")
+plt.show()
+plt.close()
+
+"""
+## Standardize the data
+
+Our timeseries are already of equal length (500 timesteps). However, their
+values are usually in various ranges. This is not ideal for a neural network;
+in general we should seek to normalize the input values.
+For this specific dataset, the data is already z-normalized: each timeseries sample
+has a mean equal to zero and a standard deviation equal to one. This type of
+normalization is very common for timeseries classification problems, see
+[Bagnall et al. (2016)](https://link.springer.com/article/10.1007/s10618-016-0483-9).
+
+Note that the timeseries data used here are univariate, meaning we only have one channel
+per timeseries example.
+We will therefore transform the timeseries into a multivariate one with one channel
+using a simple reshaping via numpy.
+This will allow us to construct a model that is easily applicable to multivariate time
+series.
+"""
+
+x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
+x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))
+
+"""
+Finally, in order to use `sparse_categorical_crossentropy`, we will have to count
+the number of classes beforehand.
+"""
+
+num_classes = len(np.unique(y_train))
+
+"""
+Now we shuffle the training set, because we will be using the
+`validation_split` option later when training: `validation_split` takes the
+last fraction of the samples without shuffling, so shuffling first makes the
+validation split representative.
+"""
+
+idx = np.random.permutation(len(x_train))
+x_train = x_train[idx]
+y_train = y_train[idx]
+
+"""
+Standardize the labels to positive integers.
+The expected labels will then be 0 and 1.
+"""
+
+y_train[y_train == -1] = 0
+y_test[y_test == -1] = 0
+
+"""
+## Build a model
+
+We build a Fully Convolutional Neural Network originally proposed in
+[this paper](https://arxiv.org/abs/1611.06455).
+The implementation is based on the TF 2 version provided
+[here](https://github.com/hfawaz/dl-4-tsc/).
+The following hyperparameters (kernel_size, filters, the usage of BatchNorm) were found
+via random search using [KerasTuner](https://github.com/keras-team/keras-tuner).
+
+"""
+
+
+def make_model(input_shape):
+ input_layer = keras.layers.Input(input_shape)
+
+ conv1 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(input_layer)
+ conv1 = keras.layers.BatchNormalization()(conv1)
+ conv1 = keras.layers.ReLU()(conv1)
+
+ conv2 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv1)
+ conv2 = keras.layers.BatchNormalization()(conv2)
+ conv2 = keras.layers.ReLU()(conv2)
+
+ conv3 = keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv2)
+ conv3 = keras.layers.BatchNormalization()(conv3)
+ conv3 = keras.layers.ReLU()(conv3)
+
+ gap = keras.layers.GlobalAveragePooling1D()(conv3)
+
+ output_layer = keras.layers.Dense(num_classes, activation="softmax")(gap)
+
+ return keras.models.Model(inputs=input_layer, outputs=output_layer)
+
+
+model = make_model(input_shape=x_train.shape[1:])
+keras.utils.plot_model(model, show_shapes=True)
+
+"""
+## Train the model
+
+"""
+
+epochs = 500
+batch_size = 32
+
+callbacks = [
+ keras.callbacks.ModelCheckpoint(
+ "best_model.keras", save_best_only=True, monitor="val_loss"
+ ),
+ keras.callbacks.ReduceLROnPlateau(
+ monitor="val_loss", factor=0.5, patience=20, min_lr=0.0001
+ ),
+ keras.callbacks.EarlyStopping(monitor="val_loss", patience=50, verbose=1),
+]
+model.compile(
+ optimizer="adam",
+ loss="sparse_categorical_crossentropy",
+ metrics=["sparse_categorical_accuracy"],
+)
+history = model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs,
+ callbacks=callbacks,
+ validation_split=0.2,
+ verbose=1,
+)
+
+"""
+## Evaluate model on test data
+"""
+
+model = keras.models.load_model("best_model.keras")
+
+test_loss, test_acc = model.evaluate(x_test, y_test)
+
+print("Test accuracy", test_acc)
+print("Test loss", test_loss)
+
+"""
+## Plot the model's training and validation loss
+"""
+
+metric = "sparse_categorical_accuracy"
+plt.figure()
+plt.plot(history.history[metric])
+plt.plot(history.history["val_" + metric])
+plt.title("model " + metric)
+plt.ylabel(metric, fontsize="large")
+plt.xlabel("epoch", fontsize="large")
+plt.legend(["train", "val"], loc="best")
+plt.show()
+plt.close()
+
+"""
+We can see how the training accuracy reaches almost 0.95 after 100 epochs.
+However, by observing the validation accuracy we can see how the network still
+needs training until it reaches almost 0.97 for both the validation and the
+training accuracy after 200 epochs. Beyond the 200th epoch, if we continued
+training, the validation accuracy would start decreasing while the training
+accuracy would continue increasing: the model would start overfitting.
+"""
diff --git a/knowledge_base/timeseries/timeseries_classification_transformer.py b/knowledge_base/timeseries/timeseries_classification_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfc11d1da4ae4535c3e7f1cd3028393a199d8f2
--- /dev/null
+++ b/knowledge_base/timeseries/timeseries_classification_transformer.py
@@ -0,0 +1,174 @@
+"""
+Title: Timeseries classification with a Transformer model
+Author: [Theodoros Ntakouris](https://github.com/ntakouris)
+Date created: 2021/06/25
+Last modified: 2021/08/05
+Description: This notebook demonstrates how to do timeseries classification using a Transformer model.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This is the Transformer architecture from
+[Attention Is All You Need](https://arxiv.org/abs/1706.03762),
+applied to timeseries instead of natural language.
+
+This example requires TensorFlow 2.4 or higher.
+
+## Load the dataset
+
+We are going to use the same dataset and preprocessing as the
+[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)
+example.
+"""
+
+import numpy as np
+import keras
+from keras import layers
+
+
+def readucr(filename):
+ data = np.loadtxt(filename, delimiter="\t")
+ y = data[:, 0]
+ x = data[:, 1:]
+ return x, y.astype(int)
+
+
+root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"
+
+x_train, y_train = readucr(root_url + "FordA_TRAIN.tsv")
+x_test, y_test = readucr(root_url + "FordA_TEST.tsv")
+
+x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
+x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))
+
+n_classes = len(np.unique(y_train))
+
+idx = np.random.permutation(len(x_train))
+x_train = x_train[idx]
+y_train = y_train[idx]
+
+y_train[y_train == -1] = 0
+y_test[y_test == -1] = 0
+
+"""
+## Build the model
+
+Our model processes a tensor of shape `(batch size, sequence length, features)`,
+where `sequence length` is the number of time steps and `features` is the
+number of channels of each input timeseries (1 in our case, since the data is
+univariate).
+
+You can replace your classification RNN layers with this one: the
+inputs are fully compatible!
+
+We include residual connections, layer normalization, and dropout.
+The resulting layer can be stacked multiple times.
+
+The projection layers are implemented through `keras.layers.Conv1D`.
+"""
+
+# In this implementation, Layer Normalization is applied to each sublayer's
+# output before it is added to the residual branch. This tends to produce
+# better-behaved gradients and can remove the need for learning rate warm-up.
+
+
+def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
+ # Attention and Normalization
+ x = layers.MultiHeadAttention(
+ key_dim=head_size, num_heads=num_heads, dropout=dropout
+ )(inputs, inputs)
+ x = layers.Dropout(dropout)(x)
+ x = layers.LayerNormalization(epsilon=1e-6)(x)
+ res = x + inputs
+
+ # Feed Forward Part
+ x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(res)
+ x = layers.Dropout(dropout)(x)
+ x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
+ x = layers.LayerNormalization(epsilon=1e-6)(x)
+ return x + res
+
+
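+"""
+Since the encoder block preserves the shape of its input, it can be stacked and
+used as a drop-in replacement for a recurrent layer. A quick illustrative shape
+check (the hyperparameter values are arbitrary):
+"""
+
+# The output shape equals the input shape: (None, 500, 1).
+toy_inputs = keras.Input(shape=(500, 1))
+print(transformer_encoder(toy_inputs, head_size=8, num_heads=2, ff_dim=4).shape)
+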
+"""
+The main part of our model is now complete. We can stack multiple of those
+`transformer_encoder` blocks and we can also proceed to add the final
+Multi-Layer Perceptron classification head. Apart from a stack of `Dense`
+layers, we need to reduce the output tensor of the `TransformerEncoder` part of
+our model down to a vector of features for each data point in the current
+batch. A common way to achieve this is to use a pooling layer. For
+this example, a `GlobalAveragePooling1D` layer is sufficient.
+"""
+
+
+def build_model(
+ input_shape,
+ head_size,
+ num_heads,
+ ff_dim,
+ num_transformer_blocks,
+ mlp_units,
+ dropout=0,
+ mlp_dropout=0,
+):
+ inputs = keras.Input(shape=input_shape)
+ x = inputs
+ for _ in range(num_transformer_blocks):
+ x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
+
+ x = layers.GlobalAveragePooling1D(data_format="channels_last")(x)
+ for dim in mlp_units:
+ x = layers.Dense(dim, activation="relu")(x)
+ x = layers.Dropout(mlp_dropout)(x)
+ outputs = layers.Dense(n_classes, activation="softmax")(x)
+ return keras.Model(inputs, outputs)
+
+
+"""
+## Train and evaluate
+"""
+
+input_shape = x_train.shape[1:]
+
+model = build_model(
+ input_shape,
+ head_size=256,
+ num_heads=4,
+ ff_dim=4,
+ num_transformer_blocks=4,
+ mlp_units=[128],
+ mlp_dropout=0.4,
+ dropout=0.25,
+)
+
+model.compile(
+ loss="sparse_categorical_crossentropy",
+ optimizer=keras.optimizers.Adam(learning_rate=1e-4),
+ metrics=["sparse_categorical_accuracy"],
+)
+model.summary()
+
+callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
+
+model.fit(
+ x_train,
+ y_train,
+ validation_split=0.2,
+ epochs=150,
+ batch_size=64,
+ callbacks=callbacks,
+)
+
+model.evaluate(x_test, y_test, verbose=1)
+
+"""
+## Conclusions
+
+In about 110-120 epochs (25s each on Colab), the model reaches a training
+accuracy of ~0.95, a validation accuracy of ~0.84, and a testing
+accuracy of ~0.85, without hyperparameter tuning. And that is for a model
+with less than 100k parameters. Of course, parameter count and accuracy could be
+improved by a hyperparameter search and a more sophisticated learning rate
+schedule, or a different optimizer.
+
+"""
diff --git a/knowledge_base/timeseries/timeseries_traffic_forecasting.py b/knowledge_base/timeseries/timeseries_traffic_forecasting.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea66ee2131e54cbfbfe2a9fa6ab3e13e0a10dc82
--- /dev/null
+++ b/knowledge_base/timeseries/timeseries_traffic_forecasting.py
@@ -0,0 +1,660 @@
+"""
+Title: Traffic forecasting using graph neural networks and LSTM
+Author: [Arash Khodadadi](https://www.linkedin.com/in/arash-khodadadi-08a02490/)
+Date created: 2021/12/28
+Last modified: 2023/11/22
+Description: This example demonstrates how to do timeseries forecasting over graphs.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example shows how to forecast traffic conditions using graph neural networks and LSTM.
+Specifically, we are interested in predicting the future values of the traffic speed given
+a history of the traffic speed for a collection of road segments.
+
+One popular method to
+solve this problem is to consider each road segment's traffic speed as a separate
+timeseries and predict the future values of each timeseries
+using the past values of the same timeseries.
+
+This method, however, ignores the dependency of the traffic speed of one road segment on
+the neighboring segments. To be able to take into account the complex interactions between
+the traffic speed on a collection of neighboring roads, we can define the traffic network
+as a graph and consider the traffic speed as a signal on this graph. In this example,
+we implement a neural network architecture which can process timeseries data over a graph.
+We first show how to process the data and create a
+[tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for
+forecasting over graphs. Then, we implement a model which uses graph convolution and
+LSTM layers to perform forecasting over a graph.
+
+The data processing and the model architecture are inspired by this paper:
+
+Yu, Bing, Haoteng Yin, and Zhanxing Zhu. "Spatio-temporal graph convolutional networks:
+a deep learning framework for traffic forecasting." Proceedings of the 27th International
+Joint Conference on Artificial Intelligence, 2018.
+([github](https://github.com/VeritasYin/STGCN_IJCAI-18))
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import pandas as pd
+import numpy as np
+import typing
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+import keras
+from keras import layers
+from keras import ops
+
+"""
+## Data preparation
+"""
+
+"""
+### Data description
+
+We use a real-world traffic speed dataset named `PeMSD7`. We use the version
+collected and prepared by [Yu et al., 2018](https://arxiv.org/abs/1709.04875)
+and available
+[here](https://github.com/VeritasYin/STGCN_IJCAI-18/tree/master/dataset).
+
+The data consists of two files:
+
+- `PeMSD7_W_228.csv` contains the distances between 228
+stations across the District 7 of California.
+- `PeMSD7_V_228.csv` contains traffic
+speed collected for those stations in the weekdays of May and June of 2012.
+
+The full description of the dataset can be found in
+[Yu et al., 2018](https://arxiv.org/abs/1709.04875).
+"""
+
+"""
+### Loading data
+"""
+
+url = "https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip"
+data_dir = keras.utils.get_file(origin=url, extract=True, archive_format="zip")
+# `str.rstrip` strips a set of characters rather than a suffix, so use
+# `removesuffix` to get the directory containing the extracted files.
+data_dir = data_dir.removesuffix("PeMSD7_Full.zip")
+
+route_distances = pd.read_csv(
+ os.path.join(data_dir, "PeMSD7_W_228.csv"), header=None
+).to_numpy()
+speeds_array = pd.read_csv(
+ os.path.join(data_dir, "PeMSD7_V_228.csv"), header=None
+).to_numpy()
+
+print(f"route_distances shape={route_distances.shape}")
+print(f"speeds_array shape={speeds_array.shape}")
+
+"""
+### Sub-sampling roads
+
+To reduce the problem size and make the training faster, we will only
+work with a sample of 26 roads out of the 228 roads in the dataset.
+We have chosen the roads by starting from road 0, choosing the 5 closest
+roads to it, and continuing this process until we get a total of 26 roads. You can choose
+any other subset of the roads. We chose the roads in this way to increase the likelihood
+of having roads with correlated speed timeseries.
+`sample_routes` contains the IDs of the selected roads.
+"""
+
+sample_routes = [
+ 0,
+ 1,
+ 4,
+ 7,
+ 8,
+ 11,
+ 15,
+ 108,
+ 109,
+ 114,
+ 115,
+ 118,
+ 120,
+ 123,
+ 124,
+ 126,
+ 127,
+ 129,
+ 130,
+ 132,
+ 133,
+ 136,
+ 139,
+ 144,
+ 147,
+ 216,
+]
+route_distances = route_distances[np.ix_(sample_routes, sample_routes)]
+speeds_array = speeds_array[:, sample_routes]
+
+print(f"route_distances shape={route_distances.shape}")
+print(f"speeds_array shape={speeds_array.shape}")
+
+"""
+### Data visualization
+
+Here are the timeseries of the traffic speed for two of the routes:
+"""
+
+plt.figure(figsize=(18, 6))
+plt.plot(speeds_array[:, [0, -1]])
+plt.legend(["route_0", "route_25"])
+
+"""
+We can also visualize the correlation between the timeseries in different routes.
+"""
+
+plt.figure(figsize=(8, 8))
+plt.matshow(np.corrcoef(speeds_array.T), 0)
+plt.xlabel("road number")
+plt.ylabel("road number")
+
+"""
+Using this correlation heatmap, we can see that, for example, the speeds in
+routes 4, 5 and 6 are highly correlated.
+"""
+
+"""
+### Splitting and normalizing data
+
+Next, we split the speed values array into train/validation/test sets,
+and normalize the resulting arrays:
+"""
+
+train_size, val_size = 0.5, 0.2
+
+
+def preprocess(data_array: np.ndarray, train_size: float, val_size: float):
+ """Splits data into train/val/test sets and normalizes the data.
+
+ Args:
+ data_array: ndarray of shape `(num_time_steps, num_routes)`
+        train_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset
+            to include in the train split.
+        val_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset
+            to include in the validation split.
+
+ Returns:
+ `train_array`, `val_array`, `test_array`
+ """
+
+ num_time_steps = data_array.shape[0]
+ num_train, num_val = (
+ int(num_time_steps * train_size),
+ int(num_time_steps * val_size),
+ )
+ train_array = data_array[:num_train]
+ mean, std = train_array.mean(axis=0), train_array.std(axis=0)
+
+ train_array = (train_array - mean) / std
+ val_array = (data_array[num_train : (num_train + num_val)] - mean) / std
+ test_array = (data_array[(num_train + num_val) :] - mean) / std
+
+ return train_array, val_array, test_array
+
+
+train_array, val_array, test_array = preprocess(speeds_array, train_size, val_size)
+
+print(f"train set size: {train_array.shape}")
+print(f"validation set size: {val_array.shape}")
+print(f"test set size: {test_array.shape}")
+
+"""
+### Creating TensorFlow Datasets
+
+Next, we create the datasets for our forecasting problem. The forecasting problem
+can be stated as follows: given a sequence of the
+road speed values at times `t+1, t+2, ..., t+T`, we want to predict the future values of
+the road speeds for times `t+T+1, ..., t+T+h`. So for each time `t` the inputs to our
+model are `T` vectors each of size `N` and the targets are `h` vectors each of size `N`,
+where `N` is the number of roads.
+"""
+
+"""
+We use the Keras built-in function
+`keras.utils.timeseries_dataset_from_array`.
+The function `create_tf_dataset()` below takes as input a `numpy.ndarray` and returns a
+`tf.data.Dataset`. In this function `input_sequence_length=T` and `forecast_horizon=h`.
+
+The argument `multi_horizon` needs more explanation. Assume `forecast_horizon=3`.
+If `multi_horizon=True`, the model makes a forecast for time steps
+`t+T+1, t+T+2, t+T+3`, so the target for each sample has shape
+`(3, num_routes)`. If `multi_horizon=False`, the model makes a forecast only
+for time step `t+T+3`, so the target has shape `(1, num_routes)`.
+
+You may notice that the input tensor in each batch has shape
+`(batch_size, input_sequence_length, num_routes, 1)`. The last dimension is added to
+make the model more general: at each time step, the input features for each road may
+contain multiple timeseries. For instance, one might want to use temperature timeseries
+in addition to historical values of the speed as input features. In this example,
+however, the last dimension of the input is always 1.
+
+We use the last 12 values of the speed in each road to forecast the speed for 3 time
+steps ahead:
+"""
+
+batch_size = 64
+input_sequence_length = 12
+forecast_horizon = 3
+multi_horizon = False
+
+
+def create_tf_dataset(
+ data_array: np.ndarray,
+ input_sequence_length: int,
+ forecast_horizon: int,
+ batch_size: int = 128,
+ shuffle=True,
+ multi_horizon=True,
+):
+ """Creates tensorflow dataset from numpy array.
+
+ This function creates a dataset where each element is a tuple `(inputs, targets)`.
+ `inputs` is a Tensor
+ of shape `(batch_size, input_sequence_length, num_routes, 1)` containing
+ the `input_sequence_length` past values of the timeseries for each node.
+ `targets` is a Tensor of shape `(batch_size, forecast_horizon, num_routes)`
+ containing the `forecast_horizon`
+ future values of the timeseries for each node.
+
+ Args:
+ data_array: np.ndarray with shape `(num_time_steps, num_routes)`
+ input_sequence_length: Length of the input sequence (in number of timesteps).
+ forecast_horizon: If `multi_horizon=True`, the target will be the values of the timeseries for 1 to
+ `forecast_horizon` timesteps ahead. If `multi_horizon=False`, the target will be the value of the
+ timeseries `forecast_horizon` steps ahead (only one value).
+ batch_size: Number of timeseries samples in each batch.
+ shuffle: Whether to shuffle output samples, or instead draw them in chronological order.
+ multi_horizon: See `forecast_horizon`.
+
+ Returns:
+ A tf.data.Dataset instance.
+ """
+
+ inputs = keras.utils.timeseries_dataset_from_array(
+ np.expand_dims(data_array[:-forecast_horizon], axis=-1),
+ None,
+ sequence_length=input_sequence_length,
+ shuffle=False,
+ batch_size=batch_size,
+ )
+
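+    # The first input window covers rows [0, input_sequence_length). With
+    # `multi_horizon=True` its targets start right after the window; with
+    # `multi_horizon=False` we skip further ahead so that the single target is
+    # exactly `forecast_horizon` steps after the end of the window.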
+ target_offset = (
+ input_sequence_length
+ if multi_horizon
+ else input_sequence_length + forecast_horizon - 1
+ )
+ target_seq_length = forecast_horizon if multi_horizon else 1
+ targets = keras.utils.timeseries_dataset_from_array(
+ data_array[target_offset:],
+ None,
+ sequence_length=target_seq_length,
+ shuffle=False,
+ batch_size=batch_size,
+ )
+
+ dataset = tf.data.Dataset.zip((inputs, targets))
+ if shuffle:
+ dataset = dataset.shuffle(100)
+
+ return dataset.prefetch(16).cache()
+
+
+train_dataset, val_dataset = (
+ create_tf_dataset(data_array, input_sequence_length, forecast_horizon, batch_size)
+ for data_array in [train_array, val_array]
+)
+
+test_dataset = create_tf_dataset(
+ test_array,
+ input_sequence_length,
+ forecast_horizon,
+ batch_size=test_array.shape[0],
+ shuffle=False,
+ multi_horizon=multi_horizon,
+)
+
+
+"""
+### Roads Graph
+
+As mentioned before, we assume that the road segments form a graph.
+The `PeMSD7` dataset has the road segments distance. The next step
+is to create the graph adjacency matrix from these distances. Following
+[Yu et al., 2018](https://arxiv.org/abs/1709.04875) (equation 10) we assume there
+is an edge between two nodes in the graph if the distance between the corresponding roads
+is less than a threshold.
+"""
+
+
+def compute_adjacency_matrix(
+ route_distances: np.ndarray, sigma2: float, epsilon: float
+):
+ """Computes the adjacency matrix from distances matrix.
+
+ It uses the formula in https://github.com/VeritasYin/STGCN_IJCAI-18#data-preprocessing to
+ compute an adjacency matrix from the distance matrix.
+ The implementation follows that paper.
+
+ Args:
+ route_distances: np.ndarray of shape `(num_routes, num_routes)`. Entry `i,j` of this array is the
+ distance between roads `i,j`.
+ sigma2: Determines the width of the Gaussian kernel applied to the square distances matrix.
+ epsilon: A threshold specifying if there is an edge between two nodes. Specifically, `A[i,j]=1`
+ if `np.exp(-w2[i,j] / sigma2) >= epsilon` and `A[i,j]=0` otherwise, where `A` is the adjacency
+ matrix and `w2=route_distances * route_distances`
+
+ Returns:
+ A boolean graph adjacency matrix.
+ """
+ num_routes = route_distances.shape[0]
+ route_distances = route_distances / 10000.0
+ w2, w_mask = (
+ route_distances * route_distances,
+ np.ones([num_routes, num_routes]) - np.identity(num_routes),
+ )
+ return (np.exp(-w2 / sigma2) >= epsilon) * w_mask
+
+
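+"""
+To see the thresholding at work, here is a tiny illustration with made-up
+distances: only the two close roads get an edge, and the diagonal is masked
+out.
+"""
+
+toy_distances = np.array(
+    [
+        [0.0, 500.0, 30000.0],
+        [500.0, 0.0, 30000.0],
+        [30000.0, 30000.0, 0.0],
+    ]
+)
+print(compute_adjacency_matrix(toy_distances, sigma2=0.1, epsilon=0.5))
+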
+"""
+The function `compute_adjacency_matrix()` returns a boolean adjacency matrix
+where 1 means there is an edge between two nodes. We use the following class
+to store the information about the graph.
+"""
+
+
+class GraphInfo:
+ def __init__(self, edges: typing.Tuple[list, list], num_nodes: int):
+ self.edges = edges
+ self.num_nodes = num_nodes
+
+
+sigma2 = 0.1
+epsilon = 0.5
+adjacency_matrix = compute_adjacency_matrix(route_distances, sigma2, epsilon)
+node_indices, neighbor_indices = np.where(adjacency_matrix == 1)
+graph = GraphInfo(
+ edges=(node_indices.tolist(), neighbor_indices.tolist()),
+ num_nodes=adjacency_matrix.shape[0],
+)
+print(f"number of nodes: {graph.num_nodes}, number of edges: {len(graph.edges[0])}")
+
+"""
+## Network architecture
+
+Our model for forecasting over the graph consists of a graph convolution
+layer and a LSTM layer.
+"""
+
+"""
+### Graph convolution layer
+
+Our implementation of the graph convolution layer resembles the implementation
+in [this Keras example](https://keras.io/examples/graph/gnn_citations/). Note that
+in that example input to the layer is a 2D tensor of shape `(num_nodes,in_feat)`
+but in our example the input to the layer is a 4D tensor of shape
+`(num_nodes, batch_size, input_seq_length, in_feat)`. The graph convolution layer
+performs the following steps:
+
+- The nodes' representations are computed in `self.compute_nodes_representation()`
+by multiplying the input features by `self.weight`
+- The aggregated neighbors' messages are computed in `self.compute_aggregated_messages()`
+by first aggregating the neighbors' representations and then multiplying the results by
+`self.weight`
+- The final output of the layer is computed in `self.update()` by combining the nodes
+representations and the neighbors' aggregated messages
+"""
+
+
+class GraphConv(layers.Layer):
+ def __init__(
+ self,
+ in_feat,
+ out_feat,
+ graph_info: GraphInfo,
+ aggregation_type="mean",
+ combination_type="concat",
+ activation: typing.Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.in_feat = in_feat
+ self.out_feat = out_feat
+ self.graph_info = graph_info
+ self.aggregation_type = aggregation_type
+ self.combination_type = combination_type
+ self.weight = self.add_weight(
+ initializer=keras.initializers.GlorotUniform(),
+ shape=(in_feat, out_feat),
+ dtype="float32",
+ trainable=True,
+ )
+ self.activation = layers.Activation(activation)
+
+ def aggregate(self, neighbour_representations):
+ aggregation_func = {
+ "sum": tf.math.unsorted_segment_sum,
+ "mean": tf.math.unsorted_segment_mean,
+ "max": tf.math.unsorted_segment_max,
+ }.get(self.aggregation_type)
+
+ if aggregation_func:
+ return aggregation_func(
+ neighbour_representations,
+ self.graph_info.edges[0],
+ num_segments=self.graph_info.num_nodes,
+ )
+
+ raise ValueError(f"Invalid aggregation type: {self.aggregation_type}")
+
+ def compute_nodes_representation(self, features):
+ """Computes each node's representation.
+
+ The nodes' representations are obtained by multiplying the features tensor with
+ `self.weight`. Note that
+ `self.weight` has shape `(in_feat, out_feat)`.
+
+ Args:
+ features: Tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`
+
+ Returns:
+ A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`
+ """
+ return ops.matmul(features, self.weight)
+
+ def compute_aggregated_messages(self, features):
+ neighbour_representations = tf.gather(features, self.graph_info.edges[1])
+ aggregated_messages = self.aggregate(neighbour_representations)
+ return ops.matmul(aggregated_messages, self.weight)
+
+ def update(self, nodes_representation, aggregated_messages):
+ if self.combination_type == "concat":
+ h = ops.concatenate([nodes_representation, aggregated_messages], axis=-1)
+ elif self.combination_type == "add":
+ h = nodes_representation + aggregated_messages
+ else:
+ raise ValueError(f"Invalid combination type: {self.combination_type}.")
+ return self.activation(h)
+
+ def call(self, features):
+ """Forward pass.
+
+ Args:
+ features: tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`
+
+ Returns:
+ A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`
+ """
+ nodes_representation = self.compute_nodes_representation(features)
+ aggregated_messages = self.compute_aggregated_messages(features)
+ return self.update(nodes_representation, aggregated_messages)
+
+
+"""
+### LSTM plus graph convolution
+
+By applying the graph convolution layer to the input tensor, we get another tensor
+containing the nodes' representations over time (another 4D tensor). For each time
+step, a node's representation is informed by the information from its neighbors.
+
+To make good forecasts, however, we need not only information from a node's
+neighbors, but we also need to process that information over time. To this end,
+we can pass each node's tensor through a recurrent layer. The `LSTMGC` layer
+below first applies
+a graph convolution layer to the inputs and then passes the results through a
+`LSTM` layer.
+"""
+
+
+class LSTMGC(layers.Layer):
+ """Layer comprising a convolution layer followed by LSTM and dense layers."""
+
+ def __init__(
+ self,
+ in_feat,
+ out_feat,
+ lstm_units: int,
+ input_seq_len: int,
+ output_seq_len: int,
+ graph_info: GraphInfo,
+ graph_conv_params: typing.Optional[dict] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # graph conv layer
+ if graph_conv_params is None:
+ graph_conv_params = {
+ "aggregation_type": "mean",
+ "combination_type": "concat",
+ "activation": None,
+ }
+ self.graph_conv = GraphConv(in_feat, out_feat, graph_info, **graph_conv_params)
+
+ self.lstm = layers.LSTM(lstm_units, activation="relu")
+ self.dense = layers.Dense(output_seq_len)
+
+ self.input_seq_len, self.output_seq_len = input_seq_len, output_seq_len
+
+ def call(self, inputs):
+ """Forward pass.
+
+ Args:
+ inputs: tensor of shape `(batch_size, input_seq_len, num_nodes, in_feat)`
+
+ Returns:
+ A tensor of shape `(batch_size, output_seq_len, num_nodes)`.
+ """
+
+ # convert shape to (num_nodes, batch_size, input_seq_len, in_feat)
+ inputs = ops.transpose(inputs, [2, 0, 1, 3])
+
+ gcn_out = self.graph_conv(
+ inputs
+ ) # gcn_out has shape: (num_nodes, batch_size, input_seq_len, out_feat)
+ shape = ops.shape(gcn_out)
+ num_nodes, batch_size, input_seq_len, out_feat = (
+ shape[0],
+ shape[1],
+ shape[2],
+ shape[3],
+ )
+
+ # LSTM takes only 3D tensors as input
+ gcn_out = ops.reshape(
+ gcn_out, (batch_size * num_nodes, input_seq_len, out_feat)
+ )
+ lstm_out = self.lstm(
+ gcn_out
+ ) # lstm_out has shape: (batch_size * num_nodes, lstm_units)
+
+ dense_output = self.dense(
+ lstm_out
+ ) # dense_output has shape: (batch_size * num_nodes, output_seq_len)
+ output = ops.reshape(dense_output, (num_nodes, batch_size, self.output_seq_len))
+ return ops.transpose(
+ output, [1, 2, 0]
+ ) # returns Tensor of shape (batch_size, output_seq_len, num_nodes)
+
+
+"""
+## Model training
+"""
+
+in_feat = 1
+batch_size = 64
+epochs = 20
+input_sequence_length = 12
+forecast_horizon = 3
+multi_horizon = False
+out_feat = 10
+lstm_units = 64
+graph_conv_params = {
+ "aggregation_type": "mean",
+ "combination_type": "concat",
+ "activation": None,
+}
+
+st_gcn = LSTMGC(
+ in_feat,
+ out_feat,
+ lstm_units,
+ input_sequence_length,
+ forecast_horizon,
+ graph,
+ graph_conv_params,
+)
+inputs = layers.Input((input_sequence_length, graph.num_nodes, in_feat))
+outputs = st_gcn(inputs)
+
+model = keras.models.Model(inputs, outputs)
+model.compile(
+ optimizer=keras.optimizers.RMSprop(learning_rate=0.0002),
+ loss=keras.losses.MeanSquaredError(),
+)
+model.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=epochs,
+ callbacks=[keras.callbacks.EarlyStopping(patience=10)],
+)
+
+"""
+## Making forecasts on test set
+
+Now we can use the trained model to make forecasts for the test set. Below, we
+compute the MAE of the model and compare it to the MAE of naive forecasts.
+The naive forecasts are the last value of the speed for each node.
+"""
+
+x_test, y = next(test_dataset.as_numpy_iterator())
+y_pred = model.predict(x_test)
+plt.figure(figsize=(18, 6))
+plt.plot(y[:, 0, 0])
+plt.plot(y_pred[:, 0, 0])
+plt.legend(["actual", "forecast"])
+
+naive_mae, model_mae = (
+    np.abs(x_test[:, -1, :, 0] - y[:, 0, :]).mean(),
+    np.abs(y_pred[:, 0, :] - y[:, 0, :]).mean(),
+)
+print(f"naive MAE: {naive_mae}, model MAE: {model_mae}")
+
+"""
+Of course, the goal here is to demonstrate the method,
+not to achieve the best performance. To improve the
+model's accuracy, all model hyperparameters should be tuned carefully. In addition,
+several of the `LSTMGC` blocks can be stacked to increase the representation power
+of the model.
+"""
diff --git a/knowledge_base/timeseries/timeseries_weather_forecasting.py b/knowledge_base/timeseries/timeseries_weather_forecasting.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b30771948734df88e7ccb7d1ffd620a66521692
--- /dev/null
+++ b/knowledge_base/timeseries/timeseries_weather_forecasting.py
@@ -0,0 +1,361 @@
+"""
+Title: Timeseries forecasting for weather prediction
+Authors: [Prabhanshu Attri](https://prabhanshu.com/github), [Yashika Sharma](https://github.com/yashika51), [Kristi Takach](https://github.com/ktakattack), [Falak Shah](https://github.com/falaktheoptimist)
+Date created: 2020/06/23
+Last modified: 2023/11/22
+Description: This notebook demonstrates how to do timeseries forecasting using a LSTM model.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import pandas as pd
+import matplotlib.pyplot as plt
+import keras
+
+"""
+## Climate Data Time-Series
+
+We will be using the Jena Climate dataset recorded by the
+[Max Planck Institute for Biogeochemistry](https://www.bgc-jena.mpg.de/wetter/).
+The dataset consists of 14 features such as temperature, pressure, humidity etc, recorded once per
+10 minutes.
+
+**Location**: Weather Station, Max Planck Institute for Biogeochemistry
+in Jena, Germany
+
+**Time-frame Considered**: Jan 10, 2009 - December 31, 2016
+
+
+The table below shows the column names, their value formats, and their description.
+
+Index| Features |Format |Description
+-----|---------------|-------------------|-----------------------
+1 |Date Time |01.01.2009 00:10:00|Date-time reference
+2 |p (mbar) |996.52 |The pascal SI derived unit of pressure used to quantify internal pressure. Meteorological reports typically state atmospheric pressure in millibars.
+3 |T (degC) |-8.02 |Temperature in Celsius
+4 |Tpot (K) |265.4 |Temperature in Kelvin
+5 |Tdew (degC) |-8.9 |Temperature in Celsius relative to humidity. Dew Point is a measure of the absolute amount of water in the air, the DP is the temperature at which the air cannot hold all the moisture in it and water condenses.
+6 |rh (%) |93.3 |Relative Humidity is a measure of how saturated the air is with water vapor, the %RH determines the amount of water contained within collection objects.
+7 |VPmax (mbar) |3.33 |Saturation vapor pressure
+8 |VPact (mbar) |3.11 |Vapor pressure
+9 |VPdef (mbar) |0.22 |Vapor pressure deficit
+10 |sh (g/kg) |1.94 |Specific humidity
+11 |H2OC (mmol/mol)|3.12 |Water vapor concentration
+12 |rho (g/m**3) |1307.75 |Air density
+13 |wv (m/s) |1.03 |Wind speed
+14 |max. wv (m/s) |1.75 |Maximum wind speed
+15 |wd (deg) |152.3 |Wind direction in degrees
+"""
+
+from zipfile import ZipFile
+
+uri = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
+zip_path = keras.utils.get_file(origin=uri, fname="jena_climate_2009_2016.csv.zip")
+zip_file = ZipFile(zip_path)
+zip_file.extractall()
+csv_path = "jena_climate_2009_2016.csv"
+
+df = pd.read_csv(csv_path)
+
+"""
+## Raw Data Visualization
+
+To give us a sense of the data we are working with, each feature has been plotted below.
+This shows the distinct pattern of each feature over the time period from 2009 to 2016.
+It also shows where anomalies are present, which will be addressed during normalization.
+"""
+
+titles = [
+ "Pressure",
+ "Temperature",
+ "Temperature in Kelvin",
+ "Temperature (dew point)",
+ "Relative Humidity",
+ "Saturation vapor pressure",
+ "Vapor pressure",
+ "Vapor pressure deficit",
+ "Specific humidity",
+ "Water vapor concentration",
+    "Air density",
+ "Wind speed",
+ "Maximum wind speed",
+ "Wind direction in degrees",
+]
+
+feature_keys = [
+ "p (mbar)",
+ "T (degC)",
+ "Tpot (K)",
+ "Tdew (degC)",
+ "rh (%)",
+ "VPmax (mbar)",
+ "VPact (mbar)",
+ "VPdef (mbar)",
+ "sh (g/kg)",
+ "H2OC (mmol/mol)",
+ "rho (g/m**3)",
+ "wv (m/s)",
+ "max. wv (m/s)",
+ "wd (deg)",
+]
+
+colors = [
+ "blue",
+ "orange",
+ "green",
+ "red",
+ "purple",
+ "brown",
+ "pink",
+ "gray",
+ "olive",
+ "cyan",
+]
+
+date_time_key = "Date Time"
+
+
+def show_raw_visualization(data):
+ time_data = data[date_time_key]
+ fig, axes = plt.subplots(
+ nrows=7, ncols=2, figsize=(15, 20), dpi=80, facecolor="w", edgecolor="k"
+ )
+ for i in range(len(feature_keys)):
+ key = feature_keys[i]
+ c = colors[i % (len(colors))]
+ t_data = data[key]
+ t_data.index = time_data
+ t_data.head()
+ ax = t_data.plot(
+ ax=axes[i // 2, i % 2],
+ color=c,
+ title="{} - {}".format(titles[i], key),
+ rot=25,
+ )
+ ax.legend([titles[i]])
+ plt.tight_layout()
+
+
+show_raw_visualization(df)
+
+
+"""
+## Data Preprocessing
+
+Here we are picking ~300,000 data points for training. An observation is
+recorded every 10 minutes, that means 6 times per hour. We will resample to one
+point per hour, since no drastic change is expected within 60 minutes. We do
+this via the `sampling_rate` argument of the `timeseries_dataset_from_array`
+utility.
+
+We are tracking data from the past 720 timestamps (720/6 = 120 hours). This
+data will be used to predict the temperature after 72 timestamps (72/6 = 12
+hours).
+
+Since every feature has values with varying ranges, we normalize each feature
+before training a neural network, by subtracting its mean and dividing by its
+standard deviation (both computed on the training split).
+
+71.5% of the data will be used to train the model, i.e. 300,693 rows.
+`split_fraction` can be changed to alter this percentage.
+
+The model is shown the raw data of the past 720 timestamps (5 days), sampled
+once per hour (i.e. 120 observations). The temperature 72 timestamps (12 hours)
+after the end of this window is used as the label.
+"""
+
+split_fraction = 0.715
+train_split = int(split_fraction * int(df.shape[0]))
+step = 6
+
+past = 720
+future = 72
+learning_rate = 0.001
+batch_size = 256
+epochs = 10
+
+
+def normalize(data, train_split):
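+    # Compute the statistics on the training split only, so that no information
+    # from the validation period leaks into the normalization.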
+ data_mean = data[:train_split].mean(axis=0)
+ data_std = data[:train_split].std(axis=0)
+ return (data - data_mean) / data_std
+
+
+"""
+From a correlation analysis of the features, we can see that a few parameters,
+like Relative Humidity and Specific Humidity, are redundant. Hence we will be
+using a subset of the features, not all of them.
+"""
+
+print(
+ "The selected parameters are:",
+ ", ".join([titles[i] for i in [0, 1, 5, 7, 8, 10, 11]]),
+)
+selected_features = [feature_keys[i] for i in [0, 1, 5, 7, 8, 10, 11]]
+features = df[selected_features]
+features.index = df[date_time_key]
+features.head()
+
+features = normalize(features.values, train_split)
+features = pd.DataFrame(features)
+features.head()
+
+train_data = features.loc[0 : train_split - 1]
+val_data = features.loc[train_split:]
+
+"""
+## Training dataset
+
+The training dataset labels starts from the 792nd observation (720 + 72).
+"""
+
+start = past + future
+end = start + train_split
+
+x_train = train_data[[i for i in range(7)]].values
+y_train = features.iloc[start:end][[1]]
+
+sequence_length = int(past / step)
+
+"""
+The `timeseries_dataset_from_array` function takes in a sequence of data-points gathered at
+equal intervals, along with time series parameters such as the length of the
+sequences/windows, the spacing between two sequences/windows, etc., to produce batches of
+sub-timeseries inputs and targets sampled from the main timeseries.
+"""
+
+dataset_train = keras.preprocessing.timeseries_dataset_from_array(
+ x_train,
+ y_train,
+ sequence_length=sequence_length,
+ sampling_rate=step,
+ batch_size=batch_size,
+)
+
+"""
+## Validation dataset
+
+The validation dataset must not contain the last 792 rows, as we won't have label data
+for those records; hence 792 must be subtracted from the end of the data.
+
+The validation label dataset must start 792 observations after `train_split`; hence we
+must add `past + future` (792) to `label_start`.
+"""
+
+x_end = len(val_data) - past - future
+
+label_start = train_split + past + future
+
+x_val = val_data.iloc[:x_end][[i for i in range(7)]].values
+y_val = features.iloc[label_start:][[1]]
+
+dataset_val = keras.preprocessing.timeseries_dataset_from_array(
+ x_val,
+ y_val,
+ sequence_length=sequence_length,
+ sampling_rate=step,
+ batch_size=batch_size,
+)
+
+
+for batch in dataset_train.take(1):
+ inputs, targets = batch
+
+print("Input shape:", inputs.numpy().shape)
+print("Target shape:", targets.numpy().shape)
+
+"""
+## Training
+"""
+
+inputs = keras.layers.Input(shape=(inputs.shape[1], inputs.shape[2]))
+lstm_out = keras.layers.LSTM(32)(inputs)
+outputs = keras.layers.Dense(1)(lstm_out)
+
+model = keras.Model(inputs=inputs, outputs=outputs)
+model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss="mse")
+model.summary()
+
+"""
+We'll use the `ModelCheckpoint` callback to regularly save checkpoints, and
+the `EarlyStopping` callback to interrupt training when the validation loss
+is no longer improving.
+"""
+
+path_checkpoint = "model_checkpoint.weights.h5"
+es_callback = keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=0, patience=5)
+
+modelckpt_callback = keras.callbacks.ModelCheckpoint(
+ monitor="val_loss",
+ filepath=path_checkpoint,
+ verbose=1,
+ save_weights_only=True,
+ save_best_only=True,
+)
+
+history = model.fit(
+ dataset_train,
+ epochs=epochs,
+ validation_data=dataset_val,
+ callbacks=[es_callback, modelckpt_callback],
+)
+
+"""
+We can visualize the loss with the function below. After a certain point, the loss
+stops decreasing.
+"""
+
+
+def visualize_loss(history, title):
+ loss = history.history["loss"]
+ val_loss = history.history["val_loss"]
+ epochs = range(len(loss))
+ plt.figure()
+ plt.plot(epochs, loss, "b", label="Training loss")
+ plt.plot(epochs, val_loss, "r", label="Validation loss")
+ plt.title(title)
+ plt.xlabel("Epochs")
+ plt.ylabel("Loss")
+ plt.legend()
+ plt.show()
+
+
+visualize_loss(history, "Training and Validation Loss")
+
+"""
+## Prediction
+
+The trained model above is now able to make predictions for 5 sets of values from the
+validation set.
+"""
+
+
+def show_plot(plot_data, delta, title):
+ labels = ["History", "True Future", "Model Prediction"]
+ marker = [".-", "rx", "go"]
+ time_steps = list(range(-(plot_data[0].shape[0]), 0))
+ if delta:
+ future = delta
+ else:
+ future = 0
+
+ plt.title(title)
+ for i, val in enumerate(plot_data):
+ if i:
+ plt.plot(future, plot_data[i], marker[i], markersize=10, label=labels[i])
+ else:
+ plt.plot(time_steps, plot_data[i].flatten(), marker[i], label=labels[i])
+ plt.legend()
+ plt.xlim([time_steps[0], (future + 5) * 2])
+ plt.xlabel("Time-Step")
+ plt.show()
+ return
+
+
+for x, y in dataset_val.take(5):
+ show_plot(
+ [x[0][:, 1].numpy(), y[0].numpy(), model.predict(x)[0]],
+ 12,
+ "Single Step Prediction",
+ )
diff --git a/knowledge_base/vision/3D_image_classification.py b/knowledge_base/vision/3D_image_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..9449a8efdb30f7c801743b8d0268fcdce9a0cea0
--- /dev/null
+++ b/knowledge_base/vision/3D_image_classification.py
@@ -0,0 +1,439 @@
+"""
+Title: 3D image classification from CT scans
+Author: [Hasib Zunair](https://twitter.com/hasibzunair)
+Date created: 2020/09/23
+Last modified: 2024/01/11
+Description: Train a 3D convolutional neural network to predict presence of pneumonia.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example will show the steps needed to build a 3D convolutional neural network (CNN)
+to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are
+commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D
+equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. slices in a CT scan).
+3D CNNs are a powerful model for learning representations of volumetric data.
+
+## References
+
+- [A survey on Deep Learning Advances on Different 3D Data Representations](https://arxiv.org/abs/1808.01462)
+- [VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition](https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
+- [FusionNet: 3D Object Classification Using Multiple Data Representations](https://arxiv.org/abs/1607.05695)
+- [Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction](https://arxiv.org/abs/2007.13224)
+"""
+"""
+## Setup
+"""
+
+import os
+import zipfile
+import numpy as np
+import tensorflow as tf # for data preprocessing
+
+import keras
+from keras import layers
+
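+"""
+To make the notion of a "3D input" concrete, here is a tiny illustrative sketch of how
+a single `Conv3D` layer transforms a `(batch, height, width, depth, channels)` volume:
+"""
+
+demo_volume = np.zeros((1, 128, 128, 64, 1), dtype="float32")
+demo_output = layers.Conv3D(filters=8, kernel_size=3)(demo_volume)
+print(demo_output.shape)  # (1, 126, 126, 62, 8): each spatial dim shrinks by 2
+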
+"""
+## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings
+
+In this example, we use a subset of the
+[MosMedData: Chest CT Scans with COVID-19 Related Findings](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1).
+This dataset consists of lung CT scans with COVID-19 related findings, as well as without such findings.
+
+We will be using the associated radiological findings of the CT scans as labels to build
+a classifier to predict presence of viral pneumonia.
+Hence, the task is a binary classification problem.
+"""
+
+# Download url of normal CT scans.
+url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"
+filename = os.path.join(os.getcwd(), "CT-0.zip")
+keras.utils.get_file(filename, url)
+
+# Download url of abnormal CT scans.
+url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
+filename = os.path.join(os.getcwd(), "CT-23.zip")
+keras.utils.get_file(filename, url)
+
+# Make a directory to store the data.
+os.makedirs("MosMedData")
+
+# Unzip data in the newly created directory.
+with zipfile.ZipFile("CT-0.zip", "r") as z_fp:
+ z_fp.extractall("./MosMedData/")
+
+with zipfile.ZipFile("CT-23.zip", "r") as z_fp:
+ z_fp.extractall("./MosMedData/")
+
+"""
+## Loading data and preprocessing
+
+The files are provided in NIfTI format with the extension `.nii`. To read the
+scans, we use the `nibabel` package.
+You can install the package via `pip install nibabel`. CT scans store raw voxel
+intensity in Hounsfield units (HU). They range from -1024 to above 2000 in this dataset.
+Values above 400 correspond to bones of varying radiointensity, so 400 is used as the
+upper bound. A threshold between -1000 and 400 is commonly used to normalize CT scans.
+
+To process the data, we do the following:
+
+* We first rotate the volumes by 90 degrees, so the orientation is fixed.
+* We scale the HU values to be between 0 and 1.
+* We resize width, height and depth.
+
+Here we define several helper functions to process the data. These functions
+will be used when building training and validation datasets.
+"""
+
+
+import nibabel as nib
+
+from scipy import ndimage
+
+
+def read_nifti_file(filepath):
+ """Read and load volume"""
+ # Read file
+ scan = nib.load(filepath)
+ # Get raw data
+ scan = scan.get_fdata()
+ return scan
+
+
+def normalize(volume):
+ """Normalize the volume"""
+ min = -1000
+ max = 400
+ volume[volume < min] = min
+ volume[volume > max] = max
+ volume = (volume - min) / (max - min)
+ volume = volume.astype("float32")
+ return volume
+
+
+def resize_volume(img):
+ """Resize across z-axis"""
+ # Set the desired depth
+ desired_depth = 64
+ desired_width = 128
+ desired_height = 128
+ # Get current depth
+ current_depth = img.shape[-1]
+ current_width = img.shape[0]
+ current_height = img.shape[1]
+ # Compute depth factor
+ depth = current_depth / desired_depth
+ width = current_width / desired_width
+ height = current_height / desired_height
+ depth_factor = 1 / depth
+ width_factor = 1 / width
+ height_factor = 1 / height
+ # Rotate
+ img = ndimage.rotate(img, 90, reshape=False)
+ # Resize across z-axis
+ img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
+ return img
+
+
+def process_scan(path):
+ """Read and resize volume"""
+ # Read scan
+ volume = read_nifti_file(path)
+ # Normalize
+ volume = normalize(volume)
+ # Resize width, height and depth
+ volume = resize_volume(volume)
+ return volume
+
+
+"""
+Let's read the paths of the CT scans from the class directories.
+"""
+
+# Folder "CT-0" consist of CT scans having normal lung tissue,
+# no CT-signs of viral pneumonia.
+normal_scan_paths = [
+ os.path.join(os.getcwd(), "MosMedData/CT-0", x)
+ for x in os.listdir("MosMedData/CT-0")
+]
+# Folder "CT-23" consist of CT scans having several ground-glass opacifications,
+# involvement of lung parenchyma.
+abnormal_scan_paths = [
+ os.path.join(os.getcwd(), "MosMedData/CT-23", x)
+ for x in os.listdir("MosMedData/CT-23")
+]
+
+print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
+print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))
+
+
+"""
+## Build train and validation datasets
+
+Read the scans from the class directories and assign labels. Downsample the scans to
+have a shape of 128x128x64. Rescale the raw HU values to the range 0 to 1.
+Lastly, split the dataset into train and validation subsets.
+"""
+
+# Read and process the scans.
+# Each scan is resized across height, width, and depth and rescaled.
+abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
+normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
+
+# For the CT scans having presence of viral pneumonia
+# assign 1, for the normal ones assign 0.
+abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
+normal_labels = np.array([0 for _ in range(len(normal_scans))])
+
+# Split data in the ratio 70-30 for training and validation.
+x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
+y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
+x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
+y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
+print(
+ "Number of samples in train and validation are %d and %d."
+ % (x_train.shape[0], x_val.shape[0])
+)
+
+"""
+## Data augmentation
+
+The CT scans are also augmented by rotating them at random angles during training. Since
+the data is stored in rank-3 tensors of shape `(samples, height, width, depth)`,
+we add a dimension of size 1 at axis 4 to be able to perform 3D convolutions on
+the data. The new shape is thus `(samples, height, width, depth, 1)`. There are
+different kinds of preprocessing and augmentation techniques out there,
+this example shows a few simple ones to get started.
+"""
+
+import random
+
+from scipy import ndimage
+
+
+def rotate(volume):
+ """Rotate the volume by a few degrees"""
+
+ def scipy_rotate(volume):
+ # define some rotation angles
+ angles = [-20, -10, -5, 5, 10, 20]
+ # pick angles at random
+ angle = random.choice(angles)
+ # rotate volume
+ volume = ndimage.rotate(volume, angle, reshape=False)
+ volume[volume < 0] = 0
+ volume[volume > 1] = 1
+ return volume
+
+ augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
+ return augmented_volume
+
+
+def train_preprocessing(volume, label):
+ """Process training data by rotating and adding a channel."""
+ # Rotate volume
+ volume = rotate(volume)
+ volume = tf.expand_dims(volume, axis=3)
+ return volume, label
+
+
+def validation_preprocessing(volume, label):
+ """Process validation data by only adding a channel."""
+ volume = tf.expand_dims(volume, axis=3)
+ return volume, label
+
+
+"""
+While defining the train and validation data loaders, the training data is passed through
+an augmentation function which randomly rotates the volumes at different angles. Note that
+both training and validation data are already rescaled to have values between 0 and 1.
+"""
+
+# Define data loaders.
+train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+
+batch_size = 2
+# Augment the data on the fly during training.
+train_dataset = (
+ train_loader.shuffle(len(x_train))
+ .map(train_preprocessing)
+ .batch(batch_size)
+ .prefetch(2)
+)
+# Only rescale.
+validation_dataset = (
+ validation_loader.shuffle(len(x_val))
+ .map(validation_preprocessing)
+ .batch(batch_size)
+ .prefetch(2)
+)
+
+"""
+Visualize an augmented CT scan.
+"""
+
+import matplotlib.pyplot as plt
+
+data = train_dataset.take(1)
+images, labels = list(data)[0]
+images = images.numpy()
+image = images[0]
+print("Dimension of the CT scan is:", image.shape)
+plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
+
+
+"""
+Since a CT scan has many slices, let's visualize a montage of the slices.
+"""
+
+
+def plot_slices(num_rows, num_columns, width, height, data):
+    """Plot a montage of CT slices arranged in a grid."""
+ data = np.rot90(np.array(data))
+ data = np.transpose(data)
+ data = np.reshape(data, (num_rows, num_columns, width, height))
+ rows_data, columns_data = data.shape[0], data.shape[1]
+ heights = [slc[0].shape[0] for slc in data]
+ widths = [slc.shape[1] for slc in data[0]]
+ fig_width = 12.0
+ fig_height = fig_width * sum(heights) / sum(widths)
+ f, axarr = plt.subplots(
+ rows_data,
+ columns_data,
+ figsize=(fig_width, fig_height),
+ gridspec_kw={"height_ratios": heights},
+ )
+ for i in range(rows_data):
+ for j in range(columns_data):
+ axarr[i, j].imshow(data[i][j], cmap="gray")
+ axarr[i, j].axis("off")
+ plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
+ plt.show()
+
+
+# Visualize montage of slices.
+# 4 rows and 10 columns for 40 slices of the CT scan.
+plot_slices(4, 10, 128, 128, image[:, :, :40])
+
+"""
+## Define a 3D convolutional neural network
+
+To make the model easier to understand, we structure it into blocks.
+The architecture of the 3D CNN used in this example
+is based on [this paper](https://arxiv.org/abs/2007.13224).
+"""
+
+
+def get_model(width=128, height=128, depth=64):
+ """Build a 3D convolutional neural network model."""
+
+ inputs = keras.Input((width, height, depth, 1))
+
+ x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
+ x = layers.MaxPool3D(pool_size=2)(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
+ x = layers.MaxPool3D(pool_size=2)(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
+ x = layers.MaxPool3D(pool_size=2)(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
+ x = layers.MaxPool3D(pool_size=2)(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.GlobalAveragePooling3D()(x)
+ x = layers.Dense(units=512, activation="relu")(x)
+ x = layers.Dropout(0.3)(x)
+
+ outputs = layers.Dense(units=1, activation="sigmoid")(x)
+
+ # Define the model.
+ model = keras.Model(inputs, outputs, name="3dcnn")
+ return model
+
+
+# Build model.
+model = get_model(width=128, height=128, depth=64)
+model.summary()
+
+"""
+## Train model
+"""
+
+# Compile model.
+initial_learning_rate = 0.0001
+lr_schedule = keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
+)
+model.compile(
+ loss="binary_crossentropy",
+ optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
+ metrics=["acc"],
+ run_eagerly=True,
+)
+
+# Define callbacks.
+checkpoint_cb = keras.callbacks.ModelCheckpoint(
+ "3d_image_classification.keras", save_best_only=True
+)
+early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
+
+# Train the model, doing validation at the end of each epoch
+epochs = 100
+model.fit(
+ train_dataset,
+ validation_data=validation_dataset,
+ epochs=epochs,
+ shuffle=True,
+ verbose=2,
+ callbacks=[checkpoint_cb, early_stopping_cb],
+)
+
+"""
+It is important to note that the number of samples is very small (only 200) and we don't
+specify a random seed. As such, you can expect significant variance in the results. The full dataset
+which consists of over 1000 CT scans can be found [here](https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1). Using the full
+dataset, an accuracy of 83% was achieved. A variability of 6-7% in the classification
+performance is observed in both cases.
+"""
+
+"""
+## Visualizing model performance
+
+Here the model accuracy and loss for the training and the validation sets are plotted.
+Since the validation set is class-balanced, accuracy provides an unbiased representation
+of the model's performance.
+"""
+
+fig, ax = plt.subplots(1, 2, figsize=(20, 3))
+ax = ax.ravel()
+
+for i, metric in enumerate(["acc", "loss"]):
+ ax[i].plot(model.history.history[metric])
+ ax[i].plot(model.history.history["val_" + metric])
+ ax[i].set_title("Model {}".format(metric))
+ ax[i].set_xlabel("epochs")
+ ax[i].set_ylabel(metric)
+ ax[i].legend(["train", "val"])
+
+"""
+## Make predictions on a single CT scan
+"""
+
+# Load best weights.
+model.load_weights("3d_image_classification.keras")
+prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
+scores = [1 - prediction[0], prediction[0]]
+
+class_names = ["normal", "abnormal"]
+for score, name in zip(scores, class_names):
+ print(
+ "This model is %.2f percent confident that CT scan is %s"
+ % ((100 * score), name)
+ )
diff --git a/knowledge_base/vision/adamatch.py b/knowledge_base/vision/adamatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..afd2365499594753432bcb71c7aca5d04ee1d422
--- /dev/null
+++ b/knowledge_base/vision/adamatch.py
@@ -0,0 +1,603 @@
+"""
+Title: Semi-supervision and domain adaptation with AdaMatch
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/06/19
+Last modified: 2021/06/19
+Description: Unifying semi-supervised learning and unsupervised domain adaptation with AdaMatch.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we will implement the AdaMatch algorithm, proposed in
+[AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation](https://arxiv.org/abs/2106.04732)
+by Berthelot et al. It sets a new state-of-the-art in unsupervised domain adaptation (as of
+June 2021). AdaMatch is particularly interesting because it
+unifies semi-supervised learning (SSL) and unsupervised domain adaptation
+(UDA) under one framework. It thereby provides a way to perform semi-supervised domain
+adaptation (SSDA).
+
+This example requires TensorFlow 2.5 or higher, as well as KerasCV (for its
+`RandAugment` layer), which can be installed using the following command:
+"""
+
+"""shell
+pip install -q keras-cv
+"""
+
+"""
+Before we proceed, let's review a few preliminary concepts underlying this example.
+"""
+
+"""
+## Preliminaries
+
+In **semi-supervised learning (SSL)**, we use a small amount of labeled data to
+train models on a bigger unlabeled dataset. Popular semi-supervised learning methods
+for computer vision include [FixMatch](https://arxiv.org/abs/2001.07685),
+[MixMatch](https://arxiv.org/abs/1905.02249),
+[Noisy Student Training](https://arxiv.org/abs/1911.04252), etc. You can refer to
+[this example](https://keras.io/examples/vision/consistency_training/) to get an idea
+of what a standard SSL workflow looks like.
+
+In **unsupervised domain adaptation**, we have access to a source labeled dataset and
+a target *unlabeled* dataset. Then the task is to learn a model that can generalize well
+to the target dataset. The source and the target datasets vary in terms of distribution.
+In the present example, we use the
+[MNIST dataset](http://yann.lecun.com/exdb/mnist/) as the source dataset, while the target dataset is
+[SVHN](http://ufldl.stanford.edu/housenumbers/), which consists of images of house
+numbers. The two datasets differ in factors such as texture, viewpoint, and
+appearance: their domains, or distributions, are different from one
+another.
+
+
+
+Popular domain adaptation algorithms in deep learning include
+[Deep CORAL](https://arxiv.org/abs/1612.01939),
+[Moment Matching](https://arxiv.org/abs/1812.01754), etc.
+"""
+
+"""
+## Setup
+"""
+
+import tensorflow as tf
+
+tf.random.set_seed(42)
+
+import numpy as np
+
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import regularizers
+from keras_cv.layers import RandAugment
+
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+"""
+## Prepare the data
+"""
+
+# MNIST
+(
+ (mnist_x_train, mnist_y_train),
+ (mnist_x_test, mnist_y_test),
+) = keras.datasets.mnist.load_data()
+
+# Add a channel dimension
+mnist_x_train = tf.expand_dims(mnist_x_train, -1)
+mnist_x_test = tf.expand_dims(mnist_x_test, -1)
+
+# Convert the labels to one-hot encoded vectors
+mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()
+
+# SVHN
+svhn_train, svhn_test = tfds.load(
+ "svhn_cropped", split=["train", "test"], as_supervised=True
+)
+
+"""
+## Define constants and hyperparameters
+"""
+
+RESIZE_TO = 32
+
+SOURCE_BATCH_SIZE = 64
+TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE # Reference: Section 3.2
+EPOCHS = 10
+STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
+TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
+
+AUTO = tf.data.AUTOTUNE
+LEARNING_RATE = 0.03
+
+WEIGHT_DECAY = 0.0005
+INIT = "he_normal"
+DEPTH = 28
+WIDTH_MULT = 2
+
+"""
+## Data augmentation utilities
+
+A standard element of SSL algorithms is to feed weakly and strongly augmented versions of
+the same images to the learning model to make its predictions consistent. For strong
+augmentation, [RandAugment](https://arxiv.org/abs/1909.13719) is a standard choice. For
+weak augmentation, we will use horizontal flipping and random cropping.
+"""
+
+# Initialize `RandAugment` object with 2 layers of
+# augmentation transforms and a magnitude of 0.5.
+augmenter = RandAugment(value_range=(0, 255), augmentations_per_image=2, magnitude=0.5)
+
+
+def weak_augment(image, source=True):
+ if image.dtype != tf.float32:
+ image = tf.cast(image, tf.float32)
+
+    # MNIST images are grayscale, which is why we first convert them to
+    # RGB images.
+ if source:
+ image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
+ image = tf.tile(image, [1, 1, 3])
+ image = tf.image.random_flip_left_right(image)
+ image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
+ return image
+
+
+def strong_augment(image, source=True):
+ if image.dtype != tf.float32:
+ image = tf.cast(image, tf.float32)
+
+ if source:
+ image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
+ image = tf.tile(image, [1, 1, 3])
+ image = augmenter(image)
+ return image
+
+
+"""
+## Data loading utilities
+"""
+
+
+def create_individual_ds(ds, aug_func, source=True):
+ if source:
+ batch_size = SOURCE_BATCH_SIZE
+ else:
+ # During training 3x more target unlabeled samples are shown
+ # to the model in AdaMatch (Section 3.2 of the paper).
+ batch_size = TARGET_BATCH_SIZE
+ ds = ds.shuffle(batch_size * 10, seed=42)
+
+ if source:
+ ds = ds.map(lambda x, y: (aug_func(x), y), num_parallel_calls=AUTO)
+ else:
+ ds = ds.map(lambda x, y: (aug_func(x, False), y), num_parallel_calls=AUTO)
+
+ ds = ds.batch(batch_size).prefetch(AUTO)
+ return ds
+
+
+"""
+The `_w` and `_s` suffixes denote weak and strong, respectively.
+"""
+
+source_ds = tf.data.Dataset.from_tensor_slices((mnist_x_train, mnist_y_train))
+source_ds_w = create_individual_ds(source_ds, weak_augment)
+source_ds_s = create_individual_ds(source_ds, strong_augment)
+final_source_ds = tf.data.Dataset.zip((source_ds_w, source_ds_s))
+
+target_ds_w = create_individual_ds(svhn_train, weak_augment, source=False)
+target_ds_s = create_individual_ds(svhn_train, strong_augment, source=False)
+final_target_ds = tf.data.Dataset.zip((target_ds_w, target_ds_s))
+
+"""
+Here's what a single image batch looks like. Below is a minimal plotting sketch
+(reusing the `source_ds_w` dataset defined above):
+"""
+
+"""
+## Loss computation utilities
+"""
+
+
+def compute_loss_source(source_labels, logits_source_w, logits_source_s):
+ loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
+ # First compute the losses between original source labels and
+ # predictions made on the weakly and strongly augmented versions
+ # of the same images.
+ w_loss = loss_func(source_labels, logits_source_w)
+ s_loss = loss_func(source_labels, logits_source_s)
+ return w_loss + s_loss
+
+
+def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
+ loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
+ target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
+ # For calculating loss for the target samples, we treat the pseudo labels
+ # as the ground-truth. These are not considered during backpropagation
+ # which is a standard SSL practice.
+ target_loss = loss_func(target_pseudo_labels_w, logits_target_s)
+
+ # More on `mask` later.
+ mask = tf.cast(mask, target_loss.dtype)
+ target_loss *= mask
+ return tf.reduce_mean(target_loss, 0)
+
+
+"""
+## Subclassed model for AdaMatch training
+
+The [original paper](https://arxiv.org/abs/2106.04732) presents a figure of the overall
+AdaMatch workflow.
+
+
+
+Here's a brief step-by-step breakdown of the workflow:
+
+1. We first retrieve the weakly and strongly augmented pairs of images from the source and
+target datasets.
+2. We prepare two concatenated copies:
+ i. One where both pairs are concatenated.
+ ii. One where only the source data image pair is concatenated.
+3. We run two forward passes through the model:
+ i. The first forward pass uses the concatenated copy obtained from **2.i**. In
+this forward pass, the [Batch Normalization](https://arxiv.org/abs/1502.03167) statistics
+are updated.
+ ii. In the second forward pass, we only use the concatenated copy obtained from **2.ii**.
+ Batch Normalization layers are run in inference mode.
+4. The respective logits are computed for both the forward passes.
+5. The logits go through a series of transformations, introduced in the paper (which
+we will discuss shortly).
+6. We compute the loss and update the gradients of the underlying model.
+"""
+
+
+class AdaMatch(keras.Model):
+ def __init__(self, model, total_steps, tau=0.9):
+ super().__init__()
+ self.model = model
+ self.tau = tau # Denotes the confidence threshold
+ self.loss_tracker = tf.keras.metrics.Mean(name="loss")
+ self.total_steps = total_steps
+ self.current_step = tf.Variable(0, dtype="int64")
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker]
+
+ # This is a warmup schedule to update the weight of the
+ # loss contributed by the target unlabeled samples. More
+ # on this in the text.
+ def compute_mu(self):
+ pi = tf.constant(np.pi, dtype="float32")
+ step = tf.cast(self.current_step, dtype="float32")
+ return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2
+
+ def train_step(self, data):
+ ## Unpack and organize the data ##
+ source_ds, target_ds = data
+ (source_w, source_labels), (source_s, _) = source_ds
+ (
+ (target_w, _),
+ (target_s, _),
+ ) = target_ds # Notice that we are NOT using any labels here.
+
+ combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
+ combined_source = tf.concat([source_w, source_s], 0)
+
+ total_source = tf.shape(combined_source)[0]
+ total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]
+
+ with tf.GradientTape() as tape:
+ ## Forward passes ##
+ combined_logits = self.model(combined_images, training=True)
+ z_d_prime_source = self.model(
+ combined_source, training=False
+ ) # No BatchNorm update.
+ z_prime_source = combined_logits[:total_source]
+
+ ## 1. Random logit interpolation for the source images ##
+ lambd = tf.random.uniform((total_source, 10), 0, 1)
+ final_source_logits = (lambd * z_prime_source) + (
+ (1 - lambd) * z_d_prime_source
+ )
+
+ ## 2. Distribution alignment (only consider weakly augmented images) ##
+ # Compute softmax for logits of the WEAKLY augmented SOURCE images.
+ y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])
+
+ # Extract logits for the WEAKLY augmented TARGET images and compute softmax.
+ logits_target = combined_logits[total_source:]
+ logits_target_w = logits_target[: tf.shape(target_w)[0]]
+ y_hat_target_w = tf.nn.softmax(logits_target_w)
+
+ # Align the target label distribution to that of the source.
+ expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
+ y_hat_target_w
+ )
+ y_tilde_target_w = tf.math.l2_normalize(
+ y_hat_target_w * expectation_ratio, 1
+ )
+
+ ## 3. Relative confidence thresholding ##
+ row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
+ final_sum = tf.reduce_mean(row_wise_max, 0)
+ c_tau = self.tau * final_sum
+ mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau
+
+ ## Compute losses (pay attention to the indexing) ##
+ source_loss = compute_loss_source(
+ source_labels,
+ final_source_logits[: tf.shape(source_w)[0]],
+ final_source_logits[tf.shape(source_w)[0] :],
+ )
+ target_loss = compute_loss_target(
+ y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
+ )
+
+ t = self.compute_mu() # Compute weight for the target loss
+ total_loss = source_loss + (t * target_loss)
+ self.current_step.assign_add(
+ 1
+ ) # Update current training step for the scheduler
+
+ gradients = tape.gradient(total_loss, self.model.trainable_variables)
+ self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+
+ self.loss_tracker.update_state(total_loss)
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+The authors introduce three improvements in the paper:
+
+* In AdaMatch, we perform two forward passes, and only one of them is responsible for
+updating the Batch Normalization statistics. This is done to account for distribution
+shifts in the target dataset. In the other forward pass, we only use the source sample,
+and the Batch Normalization layers are run in inference mode. Logits for the source
+samples (weakly and strongly augmented versions) from these two passes are slightly
+different from one another because of how Batch Normalization layers are run. Final
+logits for the source samples are computed by linearly interpolating between these two
+different pairs of logits. This induces a form of consistency regularization. This step
+is referred to as **random logit interpolation**.
+* **Distribution alignment** is used to align the source and target label distributions.
+This further helps the underlying model learn *domain-invariant representations*. In case
+of unsupervised domain adaptation, we don't have access to any labels of the target
+dataset. This is why pseudo labels are generated from the underlying model.
+* The underlying model generates pseudo-labels for the target samples. It's likely that
+the model would make faulty predictions. Those can propagate back as we make progress in
+the training, and hurt the overall performance. To compensate for that, we keep only the
+high-confidence predictions, based on a threshold (hence the use of `mask` inside
+`compute_loss_target()`). In AdaMatch, this threshold is adjusted relatively, which is why
+it is called **relative confidence thresholding**.
+
+For more details on these methods and on how each of them contributes, please refer to
+[the paper](https://arxiv.org/abs/2106.04732).
+
+**About `compute_mu()`**:
+
+Rather than using a fixed scalar quantity, AdaMatch uses a varying scalar that
+denotes the weight of the loss contributed by the target samples. This warmup
+schedule increases the weight of the target domain loss from 0 to 1 over the first
+half of the training, then keeps that weight at 1 for the second half of the training.
+"""
+
+"""
+## Instantiate a Wide-ResNet-28-2
+
+The authors use a [WideResNet-28-2](https://arxiv.org/abs/1605.07146) for the dataset
+pairs we are using in this example. Most of the following code has been adapted from
+[this script](https://github.com/asmith26/wide_resnets_keras/blob/master/main.py). Note
+that the following model has a scaling layer inside it that scales the pixel values to
+[0, 1].
+"""
+
+
+def wide_basic(x, n_input_plane, n_output_plane, stride):
+ conv_params = [[3, 3, stride, "same"], [3, 3, (1, 1), "same"]]
+
+ n_bottleneck_plane = n_output_plane
+
+ # Residual block
+ for i, v in enumerate(conv_params):
+ if i == 0:
+ if n_input_plane != n_output_plane:
+ x = layers.BatchNormalization()(x)
+ x = layers.Activation("relu")(x)
+ convs = x
+ else:
+ convs = layers.BatchNormalization()(x)
+ convs = layers.Activation("relu")(convs)
+ convs = layers.Conv2D(
+ n_bottleneck_plane,
+ (v[0], v[1]),
+ strides=v[2],
+ padding=v[3],
+ kernel_initializer=INIT,
+ kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
+ use_bias=False,
+ )(convs)
+ else:
+ convs = layers.BatchNormalization()(convs)
+ convs = layers.Activation("relu")(convs)
+ convs = layers.Conv2D(
+ n_bottleneck_plane,
+ (v[0], v[1]),
+ strides=v[2],
+ padding=v[3],
+ kernel_initializer=INIT,
+ kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
+ use_bias=False,
+ )(convs)
+
+    # Shortcut connection: identity function or 1x1 convolution
+    # (depends on the difference between input & output shapes -- this
+    # corresponds to whether we are using the first block in each
+    # group; see `block_series()`).
+ if n_input_plane != n_output_plane:
+ shortcut = layers.Conv2D(
+ n_output_plane,
+ (1, 1),
+ strides=stride,
+ padding="same",
+ kernel_initializer=INIT,
+ kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
+ use_bias=False,
+ )(x)
+ else:
+ shortcut = x
+
+ return layers.Add()([convs, shortcut])
+
+
+# Stacking residual units on the same stage
+def block_series(x, n_input_plane, n_output_plane, count, stride):
+ x = wide_basic(x, n_input_plane, n_output_plane, stride)
+ for i in range(2, int(count + 1)):
+ x = wide_basic(x, n_output_plane, n_output_plane, stride=1)
+ return x
+
+
+def get_network(image_size=32, num_classes=10):
+ n = (DEPTH - 4) / 6
+ n_stages = [16, 16 * WIDTH_MULT, 32 * WIDTH_MULT, 64 * WIDTH_MULT]
+
+ inputs = keras.Input(shape=(image_size, image_size, 3))
+ x = layers.Rescaling(scale=1.0 / 255)(inputs)
+
+ conv1 = layers.Conv2D(
+ n_stages[0],
+ (3, 3),
+ strides=1,
+ padding="same",
+ kernel_initializer=INIT,
+ kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
+ use_bias=False,
+ )(x)
+
+ ## Add wide residual blocks ##
+
+ conv2 = block_series(
+ conv1,
+ n_input_plane=n_stages[0],
+ n_output_plane=n_stages[1],
+ count=n,
+ stride=(1, 1),
+ ) # Stage 1
+
+ conv3 = block_series(
+ conv2,
+ n_input_plane=n_stages[1],
+ n_output_plane=n_stages[2],
+ count=n,
+ stride=(2, 2),
+ ) # Stage 2
+
+ conv4 = block_series(
+ conv3,
+ n_input_plane=n_stages[2],
+ n_output_plane=n_stages[3],
+ count=n,
+ stride=(2, 2),
+ ) # Stage 3
+
+ batch_norm = layers.BatchNormalization()(conv4)
+ relu = layers.Activation("relu")(batch_norm)
+
+ # Classifier
+ trunk_outputs = layers.GlobalAveragePooling2D()(relu)
+ outputs = layers.Dense(
+ num_classes, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
+ )(trunk_outputs)
+
+ return keras.Model(inputs, outputs)
+
+
+"""
+We can now instantiate a Wide ResNet model like so. Note that the purpose of using a
+Wide ResNet here is to keep the implementation as close to the original one
+as possible.
+"""
+
+wrn_model = get_network()
+print(f"Model has {wrn_model.count_params()/1e6} Million parameters.")
+
+"""
+## Instantiate AdaMatch model and compile it
+"""
+
+reduce_lr = keras.optimizers.schedules.CosineDecay(LEARNING_RATE, TOTAL_STEPS, 0.25)
+optimizer = keras.optimizers.Adam(reduce_lr)
+
+adamatch_trainer = AdaMatch(model=wrn_model, total_steps=TOTAL_STEPS)
+adamatch_trainer.compile(optimizer=optimizer)
+
+"""
+## Model training
+"""
+
+total_ds = tf.data.Dataset.zip((final_source_ds, final_target_ds))
+adamatch_trainer.fit(total_ds, epochs=EPOCHS)
+
+"""
+## Evaluation on the target and source test sets
+"""
+
+# Compile the AdaMatch model to yield accuracy.
+adamatch_trained_model = adamatch_trainer.model
+adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())
+
+# Score on the target test set.
+svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
+_, accuracy = adamatch_trained_model.evaluate(svhn_test)
+print(f"Accuracy on target test set: {accuracy * 100:.2f}%")
+
+"""
+With more training, this score improves. When this same network is trained with
+standard classification objective, it yields an accuracy of **7.20%** which is
+significantly lower than what we got with AdaMatch. You can check out
+[this notebook](https://colab.research.google.com/github/sayakpaul/AdaMatch-TF/blob/main/Vanilla_WideResNet.ipynb)
+to learn more about the hyperparameters and other experimental details.
+"""
+
+
+# Utility function for preprocessing the source test set.
+def prepare_test_ds_source(image, label):
+ image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
+ image = tf.tile(image, [1, 1, 3])
+ return image, label
+
+
+source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
+source_test_ds = (
+ source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
+ .batch(TARGET_BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# Evaluation on the source test set.
+_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
+print(f"Accuracy on source test set: {accuracy * 100:.2f}%")
+
+"""
+You can reproduce the results by using these
+[model weights](https://github.com/sayakpaul/AdaMatch-TF/releases/tag/v1.0.0).
+"""
+
+"""
+**Example available on HuggingFace**
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/adamatch-domain-adaption) | [](https://huggingface.co/spaces/keras-io/adamatch-domain-adaption) |
+"""
diff --git a/knowledge_base/vision/attention_mil_classification.py b/knowledge_base/vision/attention_mil_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..a83dab57bd1e0d2f5a626b4772fc544b080b6608
--- /dev/null
+++ b/knowledge_base/vision/attention_mil_classification.py
@@ -0,0 +1,565 @@
+"""
+Title: Classification using Attention-based Deep Multiple Instance Learning (MIL).
+Author: [Mohamad Jaber](https://www.linkedin.com/in/mohamadjaber1/)
+Date created: 2021/08/16
+Last modified: 2021/11/25
+Description: MIL approach to classify bags of instances and get their individual instance score.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+### What is Multiple Instance Learning (MIL)?
+
+Usually, with supervised learning algorithms, the learner receives labels for a set of
+instances. In the case of MIL, the learner receives labels for a set of bags, each of which
+contains a set of instances. The bag is labeled positive if it contains at least
+one positive instance, and negative if it does not contain any.
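+
+As a tiny hypothetical illustration of this labeling rule:
+"""
+
+# A bag is positive iff at least one of its instances is positive.
+instance_labels = [0, 0, 1]  # hypothetical instance labels within one bag
+bag_label = int(any(instance_labels))  # -> 1
+
+"""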
+
+### Motivation
+
+It is often assumed in image classification tasks that each image clearly represents a
+class label. In medical imaging (e.g. computational pathology), an *entire image* is
+represented by a single class label (cancerous/non-cancerous), or a region of interest
+could be given. However, one will be interested in knowing which patterns in the image
+are actually causing it to belong to that class. In this context, the image(s) will be
+divided and the subimages will form the bag of instances.
+
+Therefore, the goals are to:
+
+1. Learn a model to predict a class label for a bag of instances.
+2. Find out which instances within the bag caused a positive class label
+prediction.
+
+### Implementation
+
+The following steps describe how the model works:
+
+1. The feature extractor layers extract feature embeddings.
+2. The embeddings are fed into the MIL attention layer to get
+the attention scores. The layer is designed as permutation-invariant.
+3. Input features and their corresponding attention scores are multiplied together.
+4. The resulting output is passed to a softmax function for classification.
+
+### References
+
+- [Attention-based Deep Multiple Instance Learning](https://arxiv.org/abs/1802.04712).
+- Some of the attention operator code implementation was inspired by https://github.com/utayao/Atten_Deep_MIL.
+- Imbalanced data [tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data)
+by TensorFlow.
+
+"""
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+from keras import layers
+from keras import ops
+from tqdm import tqdm
+from matplotlib import pyplot as plt
+
+plt.style.use("ggplot")
+
+"""
+## Create dataset
+
+We will create a set of bags and assign their labels according to their contents.
+If at least one positive instance
+is available in a bag, the bag is considered as a positive bag. If it does not contain any
+positive instance, the bag will be considered as negative.
+
+### Configuration parameters
+
+- `POSITIVE_CLASS`: The desired class to be kept in the positive bag.
+- `BAG_COUNT`: The number of training bags.
+- `VAL_BAG_COUNT`: The number of validation bags.
+- `BAG_SIZE`: The number of instances in a bag.
+- `PLOT_SIZE`: The number of bags to plot.
+- `ENSEMBLE_AVG_COUNT`: The number of models to create and average together. (Optional:
+often results in better performance - set to 1 for single model)
+"""
+
+POSITIVE_CLASS = 1
+BAG_COUNT = 1000
+VAL_BAG_COUNT = 300
+BAG_SIZE = 3
+PLOT_SIZE = 3
+ENSEMBLE_AVG_COUNT = 1
+
+"""
+### Prepare bags
+
+Since the attention operator is a permutation-invariant operator, an instance with a
+positive class label is randomly placed among the instances in the positive bag.
+"""
+
+
+def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
+ # Set up bags.
+ bags = []
+ bag_labels = []
+
+ # Normalize input data.
+ input_data = np.divide(input_data, 255.0)
+
+ # Count positive samples.
+ count = 0
+
+ for _ in range(bag_count):
+ # Pick a fixed size random subset of samples.
+ index = np.random.choice(input_data.shape[0], instance_count, replace=False)
+ instances_data = input_data[index]
+ instances_labels = input_labels[index]
+
+ # By default, all bags are labeled as 0.
+ bag_label = 0
+
+ # Check if there is at least a positive class in the bag.
+ if positive_class in instances_labels:
+ # Positive bag will be labeled as 1.
+ bag_label = 1
+ count += 1
+
+ bags.append(instances_data)
+ bag_labels.append(np.array([bag_label]))
+
+ print(f"Positive bags: {count}")
+ print(f"Negative bags: {bag_count - count}")
+
+ return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))
+
+
+# Load the MNIST dataset.
+(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()
+
+# Create training data.
+train_data, train_labels = create_bags(
+ x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE
+)
+
+# Create validation data.
+val_data, val_labels = create_bags(
+ x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
+)
+
+"""
+## Create the model
+
+We will now build the attention layer, prepare some utilities, then build and train the
+entire model.
+
+### Attention operator implementation
+
+The output size of this layer is decided by the size of a single bag.
+
+The attention mechanism uses a weighted average of the instances in a bag, in which the
+sum of the weights must equal 1 (invariant to the bag size).
+
+The weight matrices (parameters) are **w** and **v**. To include positive and negative
+values, an element-wise hyperbolic tangent non-linearity is utilized.
+
+A **gated attention mechanism** can be used to deal with complex relations. Another weight
+matrix, **u**, is added to the computation.
+A sigmoid non-linearity is used to overcome the approximately linear behavior of the
+hyperbolic tangent non-linearity for *x* ∈ [-1, 1].
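+
+Concretely, for instance embeddings $\mathbf{h}_k$, the gated attention weight of
+instance $k$ is computed (following the paper) as
+
+$$a_k = \frac{\exp\{\mathbf{w}^\top(\tanh(\mathbf{V}\mathbf{h}_k^\top) \odot \mathrm{sigm}(\mathbf{U}\mathbf{h}_k^\top))\}}{\sum_{j=1}^{K}\exp\{\mathbf{w}^\top(\tanh(\mathbf{V}\mathbf{h}_j^\top) \odot \mathrm{sigm}(\mathbf{U}\mathbf{h}_j^\top))\}}$$
+
+where $\odot$ denotes element-wise multiplication and $K$ is the number of instances
+in the bag.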
+"""
+
+
+class MILAttentionLayer(layers.Layer):
+ """Implementation of the attention-based Deep MIL layer.
+
+ Args:
+ weight_params_dim: Positive Integer. Dimension of the weight matrix.
+ kernel_initializer: Initializer for the `kernel` matrix.
+ kernel_regularizer: Regularizer function applied to the `kernel` matrix.
+ use_gated: Boolean, whether or not to use the gated mechanism.
+
+ Returns:
+ List of 2D tensors with BAG_SIZE length.
+ The tensors are the attention scores after softmax with shape `(batch_size, 1)`.
+ """
+
+ def __init__(
+ self,
+ weight_params_dim,
+ kernel_initializer="glorot_uniform",
+ kernel_regularizer=None,
+ use_gated=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.weight_params_dim = weight_params_dim
+ self.use_gated = use_gated
+
+ self.kernel_initializer = keras.initializers.get(kernel_initializer)
+ self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+
+ self.v_init = self.kernel_initializer
+ self.w_init = self.kernel_initializer
+ self.u_init = self.kernel_initializer
+
+ self.v_regularizer = self.kernel_regularizer
+ self.w_regularizer = self.kernel_regularizer
+ self.u_regularizer = self.kernel_regularizer
+
+ def build(self, input_shape):
+ # Input shape.
+ # List of 2D tensors with shape: (batch_size, input_dim).
+ input_dim = input_shape[0][1]
+
+ self.v_weight_params = self.add_weight(
+ shape=(input_dim, self.weight_params_dim),
+ initializer=self.v_init,
+ name="v",
+ regularizer=self.v_regularizer,
+ trainable=True,
+ )
+
+ self.w_weight_params = self.add_weight(
+ shape=(self.weight_params_dim, 1),
+ initializer=self.w_init,
+ name="w",
+ regularizer=self.w_regularizer,
+ trainable=True,
+ )
+
+ if self.use_gated:
+ self.u_weight_params = self.add_weight(
+ shape=(input_dim, self.weight_params_dim),
+ initializer=self.u_init,
+ name="u",
+ regularizer=self.u_regularizer,
+ trainable=True,
+ )
+ else:
+ self.u_weight_params = None
+
+ self.input_built = True
+
+ def call(self, inputs):
+ # Assigning variables from the number of inputs.
+ instances = [self.compute_attention_scores(instance) for instance in inputs]
+
+ # Stack instances into a single tensor.
+ instances = ops.stack(instances)
+
+ # Apply softmax over instances such that the output summation is equal to 1.
+ alpha = ops.softmax(instances, axis=0)
+
+ # Split to recreate the same array of tensors we had as inputs.
+ return [alpha[i] for i in range(alpha.shape[0])]
+
+ def compute_attention_scores(self, instance):
+        # Keep the original instance in case the "gated mechanism" is used.
+ original_instance = instance
+
+ # tanh(v*h_k^T)
+ instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))
+
+ # for learning non-linear relations efficiently.
+ if self.use_gated:
+ instance = instance * ops.sigmoid(
+ ops.tensordot(original_instance, self.u_weight_params, axes=1)
+ )
+
+ # w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))
+ return ops.tensordot(instance, self.w_weight_params, axes=1)
+
+
+"""
+## Visualizer tool
+
+Plot the number of bags (given by `PLOT_SIZE`) with respect to the class.
+
+Moreover, if activated, the class label prediction with its associated instance score
+for each bag (after the model has been trained) can be seen.
+"""
+
+
+def plot(data, labels, bag_class, predictions=None, attention_weights=None):
+    """Utility for plotting bags and attention weights.
+
+ Args:
+ data: Input data that contains the bags of instances.
+ labels: The associated bag labels of the input data.
+ bag_class: String name of the desired bag class.
+ The options are: "positive" or "negative".
+ predictions: Class labels model predictions.
+ If you don't specify anything, ground truth labels will be used.
+ attention_weights: Attention weights for each instance within the input data.
+ If you don't specify anything, the values won't be displayed.
+ """
+ labels = np.array(labels).reshape(-1)
+
+ if bag_class == "positive":
+ if predictions is not None:
+ labels = np.where(predictions.argmax(1) == 1)[0]
+ bags = np.array(data)[:, labels[0:PLOT_SIZE]]
+
+ else:
+ labels = np.where(labels == 1)[0]
+ bags = np.array(data)[:, labels[0:PLOT_SIZE]]
+
+ elif bag_class == "negative":
+ if predictions is not None:
+ labels = np.where(predictions.argmax(1) == 0)[0]
+ bags = np.array(data)[:, labels[0:PLOT_SIZE]]
+ else:
+ labels = np.where(labels == 0)[0]
+ bags = np.array(data)[:, labels[0:PLOT_SIZE]]
+
+ else:
+ print(f"There is no class {bag_class}")
+ return
+
+ print(f"The bag class label is {bag_class}")
+ for i in range(PLOT_SIZE):
+ figure = plt.figure(figsize=(8, 8))
+ print(f"Bag number: {labels[i]}")
+ for j in range(BAG_SIZE):
+ image = bags[j][i]
+ figure.add_subplot(1, BAG_SIZE, j + 1)
+ plt.grid(False)
+ if attention_weights is not None:
+ plt.title(np.around(attention_weights[labels[i]][j], 2))
+ plt.imshow(image)
+ plt.show()
+
+
+# Plot some of validation data bags per class.
+plot(val_data, val_labels, "positive")
+plot(val_data, val_labels, "negative")
+
+"""
+## Create model
+
+First we will create some embeddings per instance, invoke the attention operator and then
+use the softmax function to output the class probabilities.
+"""
+
+
+def create_model(instance_shape):
+ # Extract features from inputs.
+ inputs, embeddings = [], []
+ shared_dense_layer_1 = layers.Dense(128, activation="relu")
+ shared_dense_layer_2 = layers.Dense(64, activation="relu")
+ for _ in range(BAG_SIZE):
+ inp = layers.Input(instance_shape)
+ flatten = layers.Flatten()(inp)
+ dense_1 = shared_dense_layer_1(flatten)
+ dense_2 = shared_dense_layer_2(dense_1)
+ inputs.append(inp)
+ embeddings.append(dense_2)
+
+ # Invoke the attention layer.
+ alpha = MILAttentionLayer(
+ weight_params_dim=256,
+ kernel_regularizer=keras.regularizers.L2(0.01),
+ use_gated=True,
+ name="alpha",
+ )(embeddings)
+
+ # Multiply attention weights with the input layers.
+ multiply_layers = [
+ layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))
+ ]
+
+ # Concatenate layers.
+ concat = layers.concatenate(multiply_layers, axis=1)
+
+ # Classification output node.
+ output = layers.Dense(2, activation="softmax")(concat)
+
+ return keras.Model(inputs, output)
+
+
+"""
+## Class weights
+
+Since this kind of problem could simply turn into imbalanced data classification problem,
+class weighting should be considered.
+
+Let's say there are 1000 bags. There could often be cases where ~90% of the bags do not
+contain any positive label and ~10% do.
+Such data can be referred to as **imbalanced data**.
+
+Using class weights, the model will tend to give a higher weight to the rare class.
+"""
+
+
+def compute_class_weights(labels):
+ # Count number of positive and negative bags.
+ negative_count = len(np.where(labels == 0)[0])
+ positive_count = len(np.where(labels == 1)[0])
+ total_count = negative_count + positive_count
+
+ # Build class weight dictionary.
+ return {
+ 0: (1 / negative_count) * (total_count / 2),
+ 1: (1 / positive_count) * (total_count / 2),
+ }
+
+
+"""
+## Build and train model
+
+The model is built and trained in this section.
+"""
+
+
+def train(train_data, train_labels, val_data, val_labels, model):
+ # Train model.
+ # Prepare callbacks.
+ # Path where to save best weights.
+
+ # Take the file name from the wrapper.
+ file_path = "/tmp/best_model.weights.h5"
+
+ # Initialize model checkpoint callback.
+ model_checkpoint = keras.callbacks.ModelCheckpoint(
+ file_path,
+ monitor="val_loss",
+ verbose=0,
+ mode="min",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+    # Initialize early stopping callback.
+    # The model performance is monitored on the validation data; training stops
+    # when the generalization error ceases to decrease.
+ early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=10, mode="min"
+ )
+
+ # Compile model.
+ model.compile(
+ optimizer="adam",
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ )
+
+ # Fit model.
+ model.fit(
+ train_data,
+ train_labels,
+ validation_data=(val_data, val_labels),
+ epochs=20,
+ class_weight=compute_class_weights(train_labels),
+ batch_size=1,
+ callbacks=[early_stopping, model_checkpoint],
+ verbose=0,
+ )
+
+ # Load best weights.
+ model.load_weights(file_path)
+
+ return model
+
+
+# Building model(s).
+instance_shape = train_data[0][0].shape
+models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]
+
+# Show single model architecture.
+print(models[0].summary())
+
+# Training model(s).
+trained_models = [
+ train(train_data, train_labels, val_data, val_labels, model)
+ for model in tqdm(models)
+]
+
+"""
+## Model evaluation
+
+The models are now ready for evaluation.
+With each model we also create an associated intermediate model to get the
+weights from the attention layer.
+
+We will compute a prediction for each of our `ENSEMBLE_AVG_COUNT` models, and
+average them together for our final prediction.
+"""
+
+
+def predict(data, labels, trained_models):
+ # Collect info per model.
+ models_predictions = []
+ models_attention_weights = []
+ models_losses = []
+ models_accuracies = []
+
+ for model in trained_models:
+ # Predict output classes on data.
+ predictions = model.predict(data)
+ models_predictions.append(predictions)
+
+ # Create intermediate model to get MIL attention layer weights.
+ intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)
+
+ # Predict MIL attention layer weights.
+ intermediate_predictions = intermediate_model.predict(data)
+
+ attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
+ models_attention_weights.append(attention_weights)
+
+ loss, accuracy = model.evaluate(data, labels, verbose=0)
+ models_losses.append(loss)
+ models_accuracies.append(accuracy)
+
+ print(
+ f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
+ f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
+ )
+
+ return (
+ np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
+ np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
+ )
+
+
+# Evaluate and predict classes and attention scores on validation data.
+class_predictions, attention_params = predict(val_data, val_labels, trained_models)
+
+# Plot some results from our validation data.
+plot(
+ val_data,
+ val_labels,
+ "positive",
+ predictions=class_predictions,
+ attention_weights=attention_params,
+)
+plot(
+ val_data,
+ val_labels,
+ "negative",
+ predictions=class_predictions,
+ attention_weights=attention_params,
+)
+
+"""
+## Conclusion
+
+From the above plot, you can notice that the weights always sum to 1. In a
+positively predicted bag, the instance which resulted in the positive labeling will have
+a substantially higher attention score than the rest of the bag. However, in a negatively
+predicted bag, there are two cases:
+
+* All instances will have approximately similar scores.
+* An instance will have a relatively higher score (though not as high as that of a positive instance).
+This is because the feature space of this instance is close to that of the positive instance.
+
+## Remarks
+
+- If the model is overfit, the weights will be equally distributed across all bags. Hence,
+regularization techniques are necessary.
+- In the paper, the bag sizes can differ from one bag to another. For simplicity, the
+bag sizes are fixed here.
+- In order not to rely on the random initial weights of a single model, averaging ensemble
+methods should be considered.
+"""
diff --git a/knowledge_base/vision/autoencoder.py b/knowledge_base/vision/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d36cfb23f69853934663e806ca4c30578bc55a
--- /dev/null
+++ b/knowledge_base/vision/autoencoder.py
@@ -0,0 +1,165 @@
+"""
+Title: Convolutional autoencoder for image denoising
+Author: [Santiago L. Valdarrama](https://twitter.com/svpino)
+Date created: 2021/03/01
+Last modified: 2021/03/01
+Description: How to train a deep convolutional autoencoder for image denoising.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to implement a deep convolutional autoencoder
+for image denoising, mapping noisy digit images from the MNIST dataset to
+clean digit images. This implementation is based on an original blog post
+titled [Building Autoencoders in Keras](https://blog.keras.io/building-autoencoders-in-keras.html)
+by [François Chollet](https://twitter.com/fchollet).
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from keras import layers
+from keras.datasets import mnist
+from keras.models import Model
+
+
+def preprocess(array):
+ """Normalizes the supplied array and reshapes it."""
+ array = array.astype("float32") / 255.0
+ array = np.reshape(array, (len(array), 28, 28, 1))
+ return array
+
+
+def noise(array):
+ """Adds random noise to each image in the supplied array."""
+ noise_factor = 0.4
+ noisy_array = array + noise_factor * np.random.normal(
+ loc=0.0, scale=1.0, size=array.shape
+ )
+
+ return np.clip(noisy_array, 0.0, 1.0)
+
+
+def display(array1, array2):
+ """Displays ten random images from each array."""
+ n = 10
+ indices = np.random.randint(len(array1), size=n)
+ images1 = array1[indices, :]
+ images2 = array2[indices, :]
+
+ plt.figure(figsize=(20, 4))
+ for i, (image1, image2) in enumerate(zip(images1, images2)):
+ ax = plt.subplot(2, n, i + 1)
+ plt.imshow(image1.reshape(28, 28))
+ plt.gray()
+ ax.get_xaxis().set_visible(False)
+ ax.get_yaxis().set_visible(False)
+
+ ax = plt.subplot(2, n, i + 1 + n)
+ plt.imshow(image2.reshape(28, 28))
+ plt.gray()
+ ax.get_xaxis().set_visible(False)
+ ax.get_yaxis().set_visible(False)
+
+ plt.show()
+
+
+"""
+## Prepare the data
+"""
+
+# Since we only need images from the dataset to encode and decode, we
+# won't use the labels.
+(train_data, _), (test_data, _) = mnist.load_data()
+
+# Normalize and reshape the data
+train_data = preprocess(train_data)
+test_data = preprocess(test_data)
+
+# Create a copy of the data with added noise
+noisy_train_data = noise(train_data)
+noisy_test_data = noise(test_data)
+
+# Display the train data and a version of it with added noise
+display(train_data, noisy_train_data)
+
+"""
+## Build the autoencoder
+
+We are going to use the Functional API to build our convolutional autoencoder.
+"""
+
+input = layers.Input(shape=(28, 28, 1))
+
+# Encoder
+x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(input)
+x = layers.MaxPooling2D((2, 2), padding="same")(x)
+x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
+x = layers.MaxPooling2D((2, 2), padding="same")(x)
+
+# Decoder
+x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
+x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation="relu", padding="same")(x)
+x = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(x)
+
+# Autoencoder
+autoencoder = Model(input, x)
+autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
+autoencoder.summary()
+
+"""
+Now we can train our autoencoder using `train_data` as both our input data
+and target. Notice we are setting up the validation data using the same
+format.
+"""
+
+autoencoder.fit(
+ x=train_data,
+ y=train_data,
+ epochs=50,
+ batch_size=128,
+ shuffle=True,
+ validation_data=(test_data, test_data),
+)
+
+"""
+Let's predict on our test dataset and display the original image together with
+the prediction from our autoencoder.
+
+Notice how the predictions are pretty close to the original images, although
+not quite the same.
+"""
+
+predictions = autoencoder.predict(test_data)
+display(test_data, predictions)
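+
+"""
+As a quick, optional sanity check (not part of the original example), we can also
+quantify reconstruction quality with the mean squared error over the test set; the
+exact value will vary from run to run.
+"""
+
+# Average per-pixel squared reconstruction error (lower is better).
+mse = np.mean((test_data - predictions) ** 2)
+print(f"Mean squared reconstruction error: {mse:.5f}")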
+
+"""
+Now that we know that our autoencoder works, let's retrain it using the noisy
+data as our input and the clean data as our target. We want our autoencoder to
+learn how to denoise the images.
+"""
+
+autoencoder.fit(
+ x=noisy_train_data,
+ y=train_data,
+ epochs=100,
+ batch_size=128,
+ shuffle=True,
+ validation_data=(noisy_test_data, test_data),
+)
+
+"""
+Let's now predict on the noisy data and display the results of our autoencoder.
+
+Notice how the autoencoder does an amazing job at removing the noise from the
+input images.
+"""
+
+predictions = autoencoder.predict(noisy_test_data)
+display(noisy_test_data, predictions)
diff --git a/knowledge_base/vision/barlow_twins.py b/knowledge_base/vision/barlow_twins.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaea461b19802088f46890d2664693a3ea3b9c5d
--- /dev/null
+++ b/knowledge_base/vision/barlow_twins.py
@@ -0,0 +1,1050 @@
+"""
+Title: Barlow Twins for Contrastive SSL
+Author: [Abhiraam Eranti](https://github.com/dewball345)
+Date created: 11/4/21
+Last modified: 12/20/21
+Description: A Keras implementation of Barlow Twins (contrastive SSL with redundancy reduction).
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+"""
+
+"""
+Self-supervised learning (SSL) is a relatively novel technique in which a model
+learns from unlabeled data, and is often used when the data is corrupted or
+when there is very little of it. A practical use for SSL is to create
+intermediate embeddings that are learned from the data. These embeddings are
+based on the dataset itself, with similar images having similar embeddings, and
+vice versa. They are then attached to the rest of the model, which uses them
+as information to learn and make predictions properly.
+Ideally, these embeddings should contain as much information and insight about
+the data as possible, so that the model can make better predictions. However,
+a common problem is that the model creates embeddings that are
+redundant. For example, if two images are similar, the model might create
+embeddings that are just a string of 1's, or some other value that
+contains repeating bits of information. This is no better than a one-hot
+encoding or just having one bit as the model's representation; it defeats the
+purpose of the embeddings, as they do not learn as much about the dataset as
+possible. Other approaches solve this problem by carefully
+configuring the model so that it tries not to be redundant.
+
+
+Barlow Twins is a new approach to this problem; while other solutions mainly
+tackle the first goal of invariance (similar images have similar embeddings),
+the Barlow Twins method also prioritizes the goal of reducing redundancy.
+
+It also has the advantage of being much simpler than other methods, and its
+model architecture is symmetric, meaning that both twins in the model do the
+same thing. It is also near state-of-the-art on ImageNet, even exceeding methods
+like SimCLR.
+
+One disadvantage of Barlow Twins is that it is heavily dependent on
+augmentations, suffering major decreases in accuracy without them.
+
+TL;DR: Barlow Twins creates representations that are:
+
+* Invariant.
+* Not redundant, and carry as much information about the dataset as possible.
+
+It is also simpler than other methods.
+
+This notebook can train a Barlow Twins model and reach up to
+64% validation accuracy on the CIFAR-10 dataset.
+"""
+
+"""
+
+
+
+
+
+
+
+"""
+
+"""
+### High-Level Theory
+
+
+"""
+
+"""
+The model takes two versions of the same image (with different augmentations) as
+input. Then it makes a prediction for each of them, creating representations.
+These are then used to compute a cross-correlation matrix.
+
+Cross-correlation matrix:
+```
+(pred_1.T @ pred_2) / batch_size
+```
+
+The cross-correlation matrix measures the correlation between the output
+neurons in the two representations made by the model's predictions of the two
+augmented versions of the data. Ideally, the cross-correlation matrix should look
+like an identity matrix if the two inputs come from the same image.
+
+When this happens, it means that the representations:
+
+1. Are invariant. The diagonal shows the correlation between each
+representation's neurons and those of its augmented counterpart. Because the two
+versions come from the same image, the diagonal of the matrix should show a
+strong correlation between them. If the inputs are different images, there
+shouldn't be a strong diagonal.
+2. Do not show signs of redundancy. If a neuron correlates with a neuron off the
+diagonal, it is not contributing unique information about the similarities
+between the two augmented images, which means it is redundant.
+
+Here is a good way of understanding this in pseudocode (adapted from the original
+paper); a small NumPy sketch of the same property follows after it:
+
+```
+c[i][i] = 1
+c[i][j] = 0
+
+where:
+ c is the cross-correlation matrix
+ i is the index of one representation's neuron
+ j is the index of the second representation's neuron
+```
+"""
+
+"""
+### References
+"""
+
+"""
+Paper:
+[Barlow Twins: Self-Supervised Learning via Redundancy
+Reduction](https://arxiv.org/abs/2103.03230)
+
+Original implementation:
+[facebookresearch/barlowtwins](https://github.com/facebookresearch/barlowtwins)
+"""
+
+"""
+## Setup
+"""
+
+"""shell
+pip install tensorflow-addons
+"""
+
+import os
+
+# Slightly faster training: roughly a 30 second decrease on the first epoch and a
+# 1-2 second decrease in epoch time afterwards. Overall saves approx. 5 min of training time.
+
+# Allocates dedicated threads for the GPU ("gpu_private"), which allows more
+# operations to be done faster
+os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
+
+import tensorflow as tf # framework
+from tensorflow import keras # for tf.keras
+import tensorflow_addons as tfa # LAMB optimizer and gaussian_blur_2d function
+import numpy as np # np.random.random
+import matplotlib.pyplot as plt # graphs
+import datetime # tensorboard logs naming
+
+# XLA optimization for faster performance (up to 10-15 minutes of total time saved)
+tf.config.optimizer.set_jit(True)
+
+"""
+## Load the CIFAR-10 dataset
+"""
+
+[
+ (train_features, train_labels),
+ (test_features, test_labels),
+] = keras.datasets.cifar10.load_data()
+
+train_features = train_features / 255.0
+test_features = test_features / 255.0
+
+"""
+## Necessary Hyperparameters
+"""
+
+# Batch size of dataset
+BATCH_SIZE = 512
+# Width and height of image
+IMAGE_SIZE = 32
+
+"""
+## Augmentation Utilities
+The Barlow Twins algorithm is heavily reliant on
+augmentation. One unique feature of the method is that augmentations are
+applied probabilistically.
+
+**Augmentations**
+
+* *RandomToGrayscale*: randomly applies grayscale to image 20% of the time
+* *RandomColorJitter*: randomly applies color jitter 80% of the time
+* *RandomFlip*: randomly flips image horizontally 50% of the time
+* *RandomResizedCrop*: randomly crops an image to a random size then resizes. This
+happens 100% of the time
+* *RandomSolarize*: randomly applies solarization to an image 20% of the time
+* *RandomBlur*: randomly blurs an image 20% of the time
+"""
+
+
+class Augmentation(keras.layers.Layer):
+ """Base augmentation class.
+
+ Base augmentation class. Contains the random_execute method.
+
+ Methods:
+ random_execute: method that returns true or false based
+ on a probability. Used to determine whether an augmentation
+ will be run.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ @tf.function
+ def random_execute(self, prob: float) -> bool:
+ """random_execute function.
+
+ Arguments:
+ prob: a float value from 0-1 that determines the
+ probability.
+
+ Returns:
+ returns true or false based on the probability.
+ """
+
+ return tf.random.uniform([], minval=0, maxval=1) < prob
+
+
+class RandomToGrayscale(Augmentation):
+ """RandomToGrayscale class.
+
+ RandomToGrayscale class. Randomly makes an image
+ grayscaled based on the random_execute method. There
+ is a 20% chance that an image will be grayscaled.
+
+ Methods:
+ call: method that grayscales an image 20% of
+ the time.
+ """
+
+ @tf.function
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a grayscaled version of the image 20% of the time
+ and the original image 80% of the time.
+ """
+
+ if self.random_execute(0.2):
+ x = tf.image.rgb_to_grayscale(x)
+ x = tf.tile(x, [1, 1, 3])
+ return x
+
+
+class RandomColorJitter(Augmentation):
+ """RandomColorJitter class.
+
+ RandomColorJitter class. Randomly adds color jitter to an image.
+ Color jitter means to add random brightness, contrast,
+ saturation, and hue to an image. There is an 80% chance that an
+ image will be randomly color-jittered.
+
+ Methods:
+ call: method that color-jitters an image 80% of
+ the time.
+ """
+
+ @tf.function
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Adds color jitter to image, including:
+ Brightness change by a max-delta of 0.8
+ Contrast change by a max-delta of 0.8
+ Saturation change by a max-delta of 0.8
+ Hue change by a max-delta of 0.2
+ Originally, the same deltas as in the original paper
+ were used, but doubling them was found to give a performance
+ boost of almost 2%.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a color-jittered version of the image 80% of the time
+ and the original image 20% of the time.
+ """
+
+ if self.random_execute(0.8):
+ x = tf.image.random_brightness(x, 0.8)
+ x = tf.image.random_contrast(x, 0.4, 1.6)
+ x = tf.image.random_saturation(x, 0.4, 1.6)
+ x = tf.image.random_hue(x, 0.2)
+ return x
+
+
+class RandomFlip(Augmentation):
+ """RandomFlip class.
+
+ RandomFlip class. Randomly flips image horizontally. There is a 50%
+ chance that an image will be randomly flipped.
+
+ Methods:
+ call: method that flips an image 50% of
+ the time.
+ """
+
+ @tf.function
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Randomly flips the image.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a flipped version of the image 50% of the time
+ and the original image 50% of the time.
+ """
+
+ if self.random_execute(0.5):
+ x = tf.image.random_flip_left_right(x)
+ return x
+
+
+class RandomResizedCrop(Augmentation):
+ """RandomResizedCrop class.
+
+ RandomResizedCrop class. Randomly crop an image to a random size,
+ then resize the image back to the original size.
+
+ Attributes:
+ image_size: The dimension of the image
+
+ Methods:
+ call: method that applies a random resized crop to the image.
+ """
+
+ def __init__(self, image_size):
+ super().__init__()
+ self.image_size = image_size
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Does random resize crop by randomly cropping an image to a random
+ size 75% - 100% the size of the image. Then resizes it.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a randomly cropped image.
+ """
+
+ rand_size = tf.random.uniform(
+ shape=[],
+ minval=int(0.75 * self.image_size),
+ maxval=1 * self.image_size,
+ dtype=tf.int32,
+ )
+
+ crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
+ crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
+ return crop_resize
+
+
+class RandomSolarize(Augmentation):
+ """RandomSolarize class.
+
+ RandomSolarize class. Randomly solarizes an image.
+ Solarization inverts the values of pixels above a threshold.
+
+ Methods:
+ call: method that does random solarization 20% of the time.
+ """
+
+ @tf.function
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Randomly solarizes the image.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a solarized version of the image 20% of the time
+ and the original image 80% of the time.
+ """
+
+ if self.random_execute(0.2):
+ # Inverts pixels above the threshold. The images in this pipeline are
+ # normalized to [0, 1], so the original threshold of 10 (meant for
+ # [0, 255] pixel values) is scaled accordingly.
+ x = tf.where(x < 10 / 255, x, 1 - x)
+ return x
+
+
+class RandomBlur(Augmentation):
+ """RandomBlur class.
+
+ RandomBlur class. Randomly blurs an image.
+
+ Methods:
+ call: method that does random blur 20% of the time.
+ """
+
+ @tf.function
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """call function.
+
+ Randomly blurs the image.
+
+ Arguments:
+ x: a tf.Tensor representing the image.
+
+ Returns:
+ returns a blurred version of the image 20% of the time
+ and the original image 80% of the time.
+ """
+
+ if self.random_execute(0.2):
+ s = np.random.random()
+ return tfa.image.gaussian_filter2d(image=x, sigma=s)
+ return x
+
+
+class RandomAugmentor(keras.Model):
+ """RandomAugmentor class.
+
+ RandomAugmentor class. Chains all the augmentations into
+ one pipeline.
+
+ Attributes:
+ image_size: An integer representing the width and height
+ of the image. Designed to be used for square images.
+ random_resized_crop: Instance variable representing the
+ RandomResizedCrop layer.
+ random_flip: Instance variable representing the
+ RandomFlip layer.
+ random_color_jitter: Instance variable representing the
+ RandomColorJitter layer.
+ random_blur: Instance variable representing the
+ RandomBlur layer
+ random_to_grayscale: Instance variable representing the
+ RandomToGrayscale layer
+ random_solarize: Instance variable representing the
+ RandomSolarize layer
+
+ Methods:
+ call: chains layers in pipeline together
+ """
+
+ def __init__(self, image_size: int):
+ super().__init__()
+
+ self.image_size = image_size
+ self.random_resized_crop = RandomResizedCrop(image_size)
+ self.random_flip = RandomFlip()
+ self.random_color_jitter = RandomColorJitter()
+ self.random_blur = RandomBlur()
+ self.random_to_grayscale = RandomToGrayscale()
+ self.random_solarize = RandomSolarize()
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ x = self.random_resized_crop(x)
+ x = self.random_flip(x)
+ x = self.random_color_jitter(x)
+ x = self.random_blur(x)
+ x = self.random_to_grayscale(x)
+ x = self.random_solarize(x)
+
+ x = tf.clip_by_value(x, 0, 1)
+ return x
+
+
+bt_augmentor = RandomAugmentor(IMAGE_SIZE)
+
+"""
+## Data Loading
+
+A class that creates the barlow twins' dataset.
+
+The dataset consists of two copies of each image, with each copy receiving different
+augmentations.
+"""
+
+
+class BTDatasetCreator:
+ """Barlow twins dataset creator class.
+
+ BTDatasetCreator class. Responsible for creating the
+ barlow twins' dataset.
+
+ Attributes:
+ options: tf.data.Options needed to configure a setting
+ that may improve performance.
+ seed: random seed for shuffling. Used to synchronize two
+ augmented versions.
+ augmentor: augmentor used for augmentation.
+
+ Methods:
+ __call__: creates barlow dataset.
+ augmented_version: creates one half of the dataset.
+ """
+
+ def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
+ self.options = tf.data.Options()
+ self.options.threading.max_intra_op_parallelism = 1
+ self.seed = seed
+ self.augmentor = augmentor
+
+ def augmented_version(self, ds: list) -> tf.data.Dataset:
+ return (
+ tf.data.Dataset.from_tensor_slices(ds)
+ .shuffle(1000, seed=self.seed)
+ .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(BATCH_SIZE, drop_remainder=True)
+ .prefetch(tf.data.AUTOTUNE)
+ .with_options(self.options)
+ )
+
+ def __call__(self, ds: list) -> tf.data.Dataset:
+ a1 = self.augmented_version(ds)
+ a2 = self.augmented_version(ds)
+
+ return tf.data.Dataset.zip((a1, a2)).with_options(self.options)
+
+
+augment_versions = BTDatasetCreator(bt_augmentor)(train_features)
+
+"""
+View examples of dataset.
+"""
+
+sample_augment_versions = iter(augment_versions)
+
+
+def plot_values(batch: tuple):
+ fig, axs = plt.subplots(3, 3)
+ fig1, axs1 = plt.subplots(3, 3)
+
+ fig.suptitle("Augmentation 1")
+ fig1.suptitle("Augmentation 2")
+
+ a1, a2 = batch
+
+ # plots images on both tables
+ for i in range(3):
+ for j in range(3):
+ axs[i][j].imshow(a1[3 * i + j])
+ axs[i][j].axis("off")
+ axs1[i][j].imshow(a2[3 * i + j])
+ axs1[i][j].axis("off")
+
+ plt.show()
+
+
+plot_values(next(sample_augment_versions))
+
+"""
+## Pseudocode of loss and model
+The following sections follow the original authors' pseudocode, which covers both the
+model and the loss function. A reference for the variables used is given below.
+"""
+
+"""
+
+"""
+
+"""
+Reference:
+
+```
+y_a: first augmented version of the original image.
+y_b: second augmented version of the original image.
+z_a: model representation (embeddings) of y_a.
+z_b: model representation (embeddings) of y_b.
+z_a_norm: normalized z_a.
+z_b_norm: normalized z_b.
+c: cross-correlation matrix.
+c_diff: diagonal portion of the loss (invariance term).
+off_diag: off-diagonal portion of the loss (redundancy reduction term).
+```
+"""
+
+"""
+## BarlowLoss: barlow twins model's loss function
+
+Barlow Twins uses the cross-correlation matrix for its loss. There are two parts to the
+loss function:
+
+* ***The invariance term*** (diagonal). This part is used to drive the diagonal of the
+matrix towards 1s. When this is the case, the matrix shows that the images are
+correlated (the same).
+ * The loss function subtracts 1 from the diagonal and squares the values.
+* ***The redundancy reduction term*** (off-diagonal). Here, the Barlow Twins loss
+function aims to make these values zero. As mentioned before, the representation is
+redundant if the representation neurons are correlated with values that are not on
+the diagonal.
+ * Off-diagonal values are squared and weighted by a lambda factor.
+
+After this, the two parts are summed together.
+"""
+
+
+class BarlowLoss(keras.losses.Loss):
+ """BarlowLoss class.
+
+ BarlowLoss class. Creates a loss function based on the cross-correlation
+ matrix.
+
+ Attributes:
+ batch_size: the batch size of the dataset
+ lambda_amt: the value for lambda(used in cross_corr_matrix_loss)
+
+ Methods:
+ __init__: gets instance variables
+ call: gets the loss based on the cross-correlation matrix
+ make_diag_zeros: Used in calculating off-diagonal section
+ of loss function; makes diagonals zeros.
+ cross_corr_matrix_loss: creates loss based on cross correlation
+ matrix.
+ """
+
+ def __init__(self, batch_size: int):
+ """__init__ method.
+
+ Gets the instance variables
+
+ Arguments:
+ batch_size: An integer value representing the batch size of the
+ dataset. Used for cross correlation matrix calculation.
+ """
+
+ super().__init__()
+ self.lambda_amt = 5e-3
+ self.batch_size = batch_size
+
+ def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
+ """get_off_diag method.
+
+ Makes the diagonals of the cross correlation matrix zeros.
+ This is used in the off-diagonal portion of the loss function,
+ where we take the squares of the off-diagonal values and sum them.
+
+ Arguments:
+ c: A tf.tensor that represents the cross correlation
+ matrix
+
+ Returns:
+ Returns a tf.tensor which represents the cross correlation
+ matrix with its diagonals as zeros.
+ """
+
+ zero_diag = tf.zeros(c.shape[-1])
+ return tf.linalg.set_diag(c, zero_diag)
+
+ def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
+ """cross_corr_matrix_loss method.
+
+ Gets the loss based on the cross correlation matrix.
+ We want the diagonals to be 1's and everything else to be
+ zeros to show that the two augmented images are similar.
+
+ Loss function procedure:
+ take the diagonal of the cross-correlation matrix, subtract by 1,
+ and square that value so no negatives.
+
+ Take the off-diagonal of the cc-matrix(see get_off_diag()),
+ square those values to get rid of negatives and increase the value,
+ and multiply it by a lambda to weight it such that it is of equal
+ value to the optimizer as the diagonal(there are more values off-diag
+ then on-diag)
+
+ Take the sum of the first and second parts and then sum them together.
+
+ Arguments:
+ c: A tf.tensor that represents the cross correlation
+ matrix
+
+ Returns:
+ Returns a tf.tensor which represents the cross correlation
+ matrix with its diagonals as zeros.
+ """
+
+ # subtracts diagonals by one and squares them(first part)
+ c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)
+
+ # takes off diagonal, squares it, multiplies with lambda(second part)
+ off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt
+
+ # sum first and second parts together
+ loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)
+
+ return loss
+
+ def normalize(self, output: tf.Tensor) -> tf.Tensor:
+ """normalize method.
+
+ Normalizes the model prediction.
+
+ Arguments:
+ output: the model prediction.
+
+ Returns:
+ Returns a normalized version of the model prediction.
+ """
+
+ return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
+ output, axis=0
+ )
+
+ def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
+ """cross_corr_matrix method.
+
+ Creates a cross correlation matrix from the predictions.
+ It transposes the first prediction and multiplies this with
+ the second, creating a matrix with shape (n_dense_units, n_dense_units).
+ See build_twin() for more info. Then it divides this with the
+ batch size.
+
+ Arguments:
+ z_a_norm: A normalized version of the first prediction.
+ z_b_norm: A normalized version of the second prediction.
+
+ Returns:
+ Returns a cross correlation matrix.
+ """
+ return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size
+
+ def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
+ """call method.
+
+ Makes the cross-correlation loss. Uses the CreateCrossCorr
+ class to make the cross corr matrix, then finds the loss and
+ returns it(see cross_corr_matrix_loss()).
+
+ Arguments:
+ z_a: The prediction of the first set of augmented data.
+ z_b: the prediction of the second set of augmented data.
+
+ Returns:
+ Returns a (rank-0) tf.Tensor that represents the loss.
+ """
+
+ z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
+ c = self.cross_corr_matrix(z_a_norm, z_b_norm)
+ loss = self.cross_corr_matrix_loss(c)
+ return loss
+
+
+"""
+## Barlow Twins' Model Architecture
+The model has two parts:
+
+* The encoder network, which is a resnet-34.
+* The projector network, which creates the model embeddings.
+ * This consists of an MLP with 3 dense-batchnorm-relu layers.
+"""
+
+"""
+Resnet encoder network implementation:
+"""
+
+
+class ResNet34:
+ """Resnet34 class.
+
+ Responsible for the Resnet 34 architecture.
+ Modified from
+ https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
+ View their website for more information.
+ """
+
+ def identity_block(self, x, filter):
+ # copy tensor to variable called x_skip
+ x_skip = x
+ # Layer 1
+ x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
+ x = tf.keras.layers.BatchNormalization(axis=3)(x)
+ x = tf.keras.layers.Activation("relu")(x)
+ # Layer 2
+ x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
+ x = tf.keras.layers.BatchNormalization(axis=3)(x)
+ # Add Residue
+ x = tf.keras.layers.Add()([x, x_skip])
+ x = tf.keras.layers.Activation("relu")(x)
+ return x
+
+ def convolutional_block(self, x, filter):
+ # copy tensor to variable called x_skip
+ x_skip = x
+ # Layer 1
+ x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
+ x = tf.keras.layers.BatchNormalization(axis=3)(x)
+ x = tf.keras.layers.Activation("relu")(x)
+ # Layer 2
+ x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
+ x = tf.keras.layers.BatchNormalization(axis=3)(x)
+ # Processing Residue with conv(1,1)
+ x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
+ # Add Residue
+ x = tf.keras.layers.Add()([x, x_skip])
+ x = tf.keras.layers.Activation("relu")(x)
+ return x
+
+ def __call__(self, shape=(32, 32, 3)):
+ # Step 1 (Setup Input Layer)
+ x_input = tf.keras.layers.Input(shape)
+ x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
+ # Step 2 (Initial Conv layer along with maxPool)
+ x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
+ x = tf.keras.layers.BatchNormalization()(x)
+ x = tf.keras.layers.Activation("relu")(x)
+ x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
+ # Define size of sub-blocks and initial filter size
+ block_layers = [3, 4, 6, 3]
+ filter_size = 64
+ # Step 3 Add the Resnet Blocks
+ for i in range(4):
+ if i == 0:
+ # For sub-block 1 Residual/Convolutional block not needed
+ for j in range(block_layers[i]):
+ x = self.identity_block(x, filter_size)
+ else:
+ # One Residual/Convolutional Block followed by Identity blocks
+ # The filter size will go on increasing by a factor of 2
+ filter_size = filter_size * 2
+ x = self.convolutional_block(x, filter_size)
+ for j in range(block_layers[i] - 1):
+ x = self.identity_block(x, filter_size)
+ # Step 4 End Dense Network
+ x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
+ x = tf.keras.layers.Flatten()(x)
+ model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
+ return model
+
+
+"""
+Projector network:
+"""
+
+
+def build_twin() -> keras.Model:
+ """build_twin method.
+
+ Builds a barlow twins model consisting of an encoder (resnet-34)
+ and a projector, which generates embeddings for the images.
+
+ Returns:
+ returns a barlow twins model
+ """
+
+ # number of dense neurons in the projector
+ n_dense_neurons = 5000
+
+ # encoder network
+ resnet = ResNet34()()
+ last_layer = resnet.layers[-1].output
+
+ # intermediate layers of the projector network
+ n_layers = 2
+ for i in range(n_layers):
+ dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
+ if i == 0:
+ x = dense(last_layer)
+ else:
+ x = dense(x)
+ x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
+ x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)
+
+ x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)
+
+ model = keras.Model(resnet.input, x)
+ return model
+
+
+"""
+## Training Loop Model
+
+See pseudocode for reference.
+"""
+
+
+class BarlowModel(keras.Model):
+ """BarlowModel class.
+
+ BarlowModel class. Responsible for making predictions and handling
+ gradient descent with the optimizer.
+
+ Attributes:
+ model: the barlow model architecture.
+ loss_tracker: the loss metric.
+
+ Methods:
+ train_step: one train step; do model predictions, loss, and
+ optimizer step.
+ metrics: Returns metrics.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.model = build_twin()
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker]
+
+ def train_step(self, batch: tf.Tensor) -> tf.Tensor:
+ """train_step method.
+
+ Do one train step. Make model predictions, find loss, pass loss to
+ optimizer, and make optimizer apply gradients.
+
+ Arguments:
+ batch: one batch of data to be given to the loss function.
+
+ Returns:
+ Returns a dictionary with the loss metric.
+ """
+
+ # get the two augmentations from the batch
+ y_a, y_b = batch
+
+ with tf.GradientTape() as tape:
+ # get two versions of predictions
+ z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
+ loss = self.loss(z_a, z_b)
+
+ grads_model = tape.gradient(loss, self.model.trainable_variables)
+
+ self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
+ self.loss_tracker.update_state(loss)
+
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+## Model Training
+
+* We use the LAMB optimizer instead of ADAM or SGD.
+* LAMB is similar to the LARS optimizer used in the paper, and lets the model converge
+much faster than the alternatives.
+* Expected training time: 1 hour 30 min. Go and eat a snack or take a nap or something.
+"""
+
+# sets up model, optimizer, loss
+
+bm = BarlowModel()
+# chose the LAMB optimizer due to high batch sizes. Converged MUCH faster
+# than ADAM or SGD
+optimizer = tfa.optimizers.LAMB()
+loss = BarlowLoss(BATCH_SIZE)
+
+bm.compile(optimizer=optimizer, loss=loss)
+
+# Expected training time: 1 hour 30 min
+
+history = bm.fit(augment_versions, epochs=160)
+plt.plot(history.history["loss"])
+plt.show()
+
+"""
+## Evaluation
+
+**Linear evaluation:** to evaluate the model's performance, we add
+a linear dense layer at the end and freeze the main model's weights, only letting the
+dense layer be tuned. If the model actually learned something, then the accuracy will
+be significantly higher than random chance.
+
+**Accuracy on CIFAR-10** : 64% for this notebook. This is much better than the 10% we get
+from random guessing.
+"""
+
+# Approx: 64% accuracy with this barlow twins model.
+
+xy_ds = (
+ tf.data.Dataset.from_tensor_slices((train_features, train_labels))
+ .shuffle(1000)
+ .batch(BATCH_SIZE, drop_remainder=True)
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+test_ds = (
+ tf.data.Dataset.from_tensor_slices((test_features, test_labels))
+ .shuffle(1000)
+ .batch(BATCH_SIZE, drop_remainder=True)
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+model = keras.models.Sequential(
+ [
+ bm.model,
+ keras.layers.Dense(
+ 10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
+ ),
+ ]
+)
+
+model.layers[0].trainable = False
+
+linear_optimizer = tfa.optimizers.LAMB()
+model.compile(
+ optimizer=linear_optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+
+model.fit(xy_ds, epochs=35, validation_data=test_ds)
+
+"""
+## Conclusion
+
+* Barlow Twins is a simple and concise method for contrastive and self-supervised
+learning.
+* With this resnet-34 model architecture, we were able to reach 62-64% validation
+accuracy.
+
+## Use-Cases of Barlow Twins (and contrastive learning in general)
+
+* Semi-supervised learning: you can see that this model reached 62-64% accuracy
+when it wasn't even trained with the labels. Contrastive learning can be used when you
+have little labeled data but a lot of unlabeled data.
+* You do Barlow Twins training on the unlabeled data, and then you do secondary training
+with the labeled data.
+
+## Helpful links
+
+* [Paper](https://arxiv.org/abs/2103.03230)
+* [Original Pytorch Implementation](https://github.com/facebookresearch/barlowtwins)
+* [Sayak Paul's Implementation](https://colab.research.google.com/github/sayakpaul/Barlow-Twins-TF/blob/main/Barlow_Twins.ipynb#scrollTo=GlWepkM8_prl).
+* Thanks to Sayak Paul for his implementation. It helped me with debugging and with
+comparing accuracy and loss.
+* [resnet34 implementation](https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2)
+ * Thanks to Yashowardhan Shinde for writing the article.
+"""
diff --git a/knowledge_base/vision/basnet_segmentation.py b/knowledge_base/vision/basnet_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad2ce373733c0a639a8249818d710403526edd3
--- /dev/null
+++ b/knowledge_base/vision/basnet_segmentation.py
@@ -0,0 +1,485 @@
+"""
+Title: Highly accurate boundaries segmentation using BASNet
+Author: [Hamid Ali](https://github.com/hamidriasat)
+Date created: 2023/05/30
+Last modified: 2024/10/02
+Description: Boundaries aware segmentation model trained on the DUTS dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Deep semantic segmentation algorithms have improved a lot recently, but they still fail
+to correctly predict pixels around object boundaries. In this example we implement
+**Boundary-Aware Segmentation Network (BASNet)**, which uses a two-stage predict-and-refine
+architecture and a hybrid loss to predict highly accurate boundaries and fine structures
+for image segmentation.
+
+### References:
+
+- [Boundary-Aware Segmentation Network for Mobile and Web Applications](https://arxiv.org/abs/2101.04704)
+- [BASNet Keras Implementation](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
+- [Learning to Detect Salient Objects with Image-level Supervision](https://openaccess.thecvf.com/content_cvpr_2017/html/Wang_Learning_to_Detect_CVPR_2017_paper.html)
+"""
+
+"""
+## Download the Data
+
+We will use the [DUTS-TE](http://saliencydetection.net/duts/) dataset for training. It has 5,019
+images, but we will use 140 for training and validation to save notebook running time. DUTS is a
+relatively large salient object segmentation dataset, containing diversified textures and
+structures common to real-world images in both foreground and background.
+"""
+
+import os
+
+# Because of the use of tf.image.ssim in the loss,
+# this example requires TensorFlow. The rest of the code
+# is backend-agnostic.
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+from glob import glob
+import matplotlib.pyplot as plt
+
+import keras_hub
+import tensorflow as tf
+import keras
+from keras import layers, ops
+
+keras.config.disable_traceback_filtering()
+
+"""
+## Define Hyperparameters
+"""
+
+IMAGE_SIZE = 288
+BATCH_SIZE = 4
+OUT_CLASSES = 1
+TRAIN_SPLIT_RATIO = 0.90
+
+"""
+## Create `PyDataset`s
+
+We will use `load_paths()` to load and split 140 paths into train and validation sets, and
+then convert the paths into `PyDataset` objects.
+"""
+
+data_dir = keras.utils.get_file(
+ origin="http://saliencydetection.net/duts/download/DUTS-TE.zip",
+ extract=True,
+)
+data_dir = os.path.join(data_dir, "DUTS-TE")
+
+
+def load_paths(path, split_ratio):
+ images = sorted(glob(os.path.join(path, "DUTS-TE-Image/*")))[:140]
+ masks = sorted(glob(os.path.join(path, "DUTS-TE-Mask/*")))[:140]
+ len_ = int(len(images) * split_ratio)
+ return (images[:len_], masks[:len_]), (images[len_:], masks[len_:])
+
+
+class Dataset(keras.utils.PyDataset):
+ def __init__(
+ self,
+ image_paths,
+ mask_paths,
+ img_size,
+ out_classes,
+ batch,
+ shuffle=True,
+ **kwargs,
+ ):
+ if shuffle:
+ perm = np.random.permutation(len(image_paths))
+ image_paths = [image_paths[i] for i in perm]
+ mask_paths = [mask_paths[i] for i in perm]
+ self.image_paths = image_paths
+ self.mask_paths = mask_paths
+ self.img_size = img_size
+ self.out_classes = out_classes
+ self.batch_size = batch
+ super().__init__(**kwargs)
+
+ def __len__(self):
+ return len(self.image_paths) // self.batch_size
+
+ def __getitem__(self, idx):
+ batch_x, batch_y = [], []
+ for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
+ x, y = self.preprocess(
+ self.image_paths[i],
+ self.mask_paths[i],
+ self.img_size,
+ )
+ batch_x.append(x)
+ batch_y.append(y)
+ batch_x = np.stack(batch_x, axis=0)
+ batch_y = np.stack(batch_y, axis=0)
+ return batch_x, batch_y
+
+ def read_image(self, path, size, mode):
+ x = keras.utils.load_img(path, target_size=size, color_mode=mode)
+ x = keras.utils.img_to_array(x)
+ x = (x / 255.0).astype(np.float32)
+ return x
+
+ def preprocess(self, x_batch, y_batch, img_size):
+ images = self.read_image(x_batch, (img_size, img_size), mode="rgb") # image
+ masks = self.read_image(y_batch, (img_size, img_size), mode="grayscale") # mask
+ return images, masks
+
+
+train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)
+
+train_dataset = Dataset(
+ train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True
+)
+val_dataset = Dataset(
+ val_paths[0], val_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=False
+)
+
+"""
+## Visualize Data
+"""
+
+
+def display(display_list):
+ title = ["Input Image", "True Mask", "Predicted Mask"]
+
+ for i in range(len(display_list)):
+ plt.subplot(1, len(display_list), i + 1)
+ plt.title(title[i])
+ plt.imshow(keras.utils.array_to_img(display_list[i]), cmap="gray")
+ plt.axis("off")
+ plt.show()
+
+
+for image, mask in val_dataset:
+ display([image[0], mask[0]])
+ break
+
+"""
+## Analyze Mask
+
+Let's print the unique values of the mask displayed above. You can see that, despite
+belonging to a single class, the mask's intensity varies between low (0) and high (255).
+This variation in intensity makes it hard for the network to generate a good segmentation
+map for **salient or camouflaged object segmentation**. Because of its Residual Refinement
+Module (RRM), BASNet is good at generating highly accurate boundaries and fine structures.
+"""
+
+print(f"Unique values count: {len(np.unique((mask[0] * 255)))}")
+print("Unique values:")
+print(np.unique((mask[0] * 255)).astype(int))
+
+"""
+## Building the BASNet Model
+
+BASNet comprises a predict-refine architecture and a hybrid loss. The predict-refine
+architecture consists of a densely supervised encoder-decoder network and a residual refinement
+module, which are respectively used to predict and refine a segmentation probability map.
+
+
+"""
+
+
+def basic_block(x_input, filters, stride=1, down_sample=None, activation=None):
+ """Creates a residual(identity) block with two 3*3 convolutions."""
+ residual = x_input
+
+ x = layers.Conv2D(filters, (3, 3), strides=stride, padding="same", use_bias=False)(
+ x_input
+ )
+ x = layers.BatchNormalization()(x)
+ x = layers.Activation("relu")(x)
+
+ x = layers.Conv2D(filters, (3, 3), strides=(1, 1), padding="same", use_bias=False)(
+ x
+ )
+ x = layers.BatchNormalization()(x)
+
+ if down_sample is not None:
+ residual = down_sample
+
+ x = layers.Add()([x, residual])
+
+ if activation is not None:
+ x = layers.Activation(activation)(x)
+
+ return x
+
+
+def convolution_block(x_input, filters, dilation=1):
+ """Apply convolution + batch normalization + relu layer."""
+ x = layers.Conv2D(filters, (3, 3), padding="same", dilation_rate=dilation)(x_input)
+ x = layers.BatchNormalization()(x)
+ return layers.Activation("relu")(x)
+
+
+def segmentation_head(x_input, out_classes, final_size):
+ """Map each decoder stage output to model output classes."""
+ x = layers.Conv2D(out_classes, kernel_size=(3, 3), padding="same")(x_input)
+
+ if final_size is not None:
+ x = layers.Resizing(final_size[0], final_size[1])(x)
+
+ return x
+
+
+def get_resnet_block(resnet, block_num):
+ """Extract and return a ResNet-34 block."""
+ extractor_levels = ["P2", "P3", "P4", "P5"]
+ num_blocks = resnet.stackwise_num_blocks
+ if block_num == 0:
+ x = resnet.get_layer("pool1_pool").output
+ else:
+ x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
+ y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
+ return keras.models.Model(
+ inputs=x,
+ outputs=y,
+ name=f"resnet_block{block_num + 1}",
+ )
+
+
+"""
+## Prediction Module
+
+The prediction module is a heavy encoder-decoder structure like U-Net. The encoder
+includes an input convolutional layer and six stages. The first four are adopted from
+ResNet-34 and the rest are basic res-blocks. Since the first convolution and pooling
+layer of ResNet-34 are skipped, we will use `get_resnet_block()` to extract the first
+four blocks. Both the bridge and the decoder use three convolutional layers with side
+outputs. The module produces seven segmentation probability maps during training, with
+the last one considered the final output.
+"""
+
+
+def basnet_predict(input_shape, out_classes):
+ """BASNet Prediction Module, it outputs coarse label map."""
+ filters = 64
+ num_stages = 6
+
+ x_input = layers.Input(input_shape)
+
+ # -------------Encoder--------------
+ x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)
+
+ resnet = keras_hub.models.ResNetBackbone(
+ input_conv_filters=[64],
+ input_conv_kernel_sizes=[7],
+ stackwise_num_filters=[64, 128, 256, 512],
+ stackwise_num_blocks=[3, 4, 6, 3],
+ stackwise_num_strides=[1, 2, 2, 2],
+ block_type="basic_block",
+ )
+
+ encoder_blocks = []
+ for i in range(num_stages):
+ if i < 4: # First four stages are adopted from ResNet-34 blocks.
+ x = get_resnet_block(resnet, i)(x)
+ encoder_blocks.append(x)
+ x = layers.Activation("relu")(x)
+ else: # Last 2 stages consist of three basic resnet blocks.
+ x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
+ x = basic_block(x, filters=filters * 8, activation="relu")
+ x = basic_block(x, filters=filters * 8, activation="relu")
+ x = basic_block(x, filters=filters * 8, activation="relu")
+ encoder_blocks.append(x)
+
+ # -------------Bridge-------------
+ x = convolution_block(x, filters=filters * 8, dilation=2)
+ x = convolution_block(x, filters=filters * 8, dilation=2)
+ x = convolution_block(x, filters=filters * 8, dilation=2)
+ encoder_blocks.append(x)
+
+ # -------------Decoder-------------
+ decoder_blocks = []
+ for i in reversed(range(num_stages)):
+ if i != (num_stages - 1): # Except first, scale other decoder stages.
+ shape = x.shape
+ x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
+
+ x = layers.concatenate([encoder_blocks[i], x], axis=-1)
+ x = convolution_block(x, filters=filters * 8)
+ x = convolution_block(x, filters=filters * 8)
+ x = convolution_block(x, filters=filters * 8)
+ decoder_blocks.append(x)
+
+ decoder_blocks.reverse() # Change order from last to first decoder stage.
+ decoder_blocks.append(encoder_blocks[-1]) # Copy bridge to decoder.
+
+ # -------------Side Outputs--------------
+ decoder_blocks = [
+ segmentation_head(decoder_block, out_classes, input_shape[:2])
+ for decoder_block in decoder_blocks
+ ]
+
+ return keras.models.Model(inputs=x_input, outputs=decoder_blocks)
+
+
+"""
+## Residual Refinement Module
+
+The Residual Refinement Module (RRM), designed as a residual block, aims to refine the
+coarse (blurry and noisy boundaries) segmentation maps generated by the prediction
+module. Like the prediction module, it is also an encoder-decoder structure, but a
+lightweight one with four stages, each containing one `convolution_block()`. At the end
+it adds the coarse and residual outputs to generate the refined output.
+"""
+
+
+def basnet_rrm(base_model, out_classes):
+ """BASNet Residual Refinement Module(RRM) module, output fine label map."""
+ num_stages = 4
+ filters = 64
+
+ x_input = base_model.output[0]
+
+ # -------------Encoder--------------
+ x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)
+
+ encoder_blocks = []
+ for _ in range(num_stages):
+ x = convolution_block(x, filters=filters)
+ encoder_blocks.append(x)
+ x = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(x)
+
+ # -------------Bridge--------------
+ x = convolution_block(x, filters=filters)
+
+ # -------------Decoder--------------
+ for i in reversed(range(num_stages)):
+ shape = x.shape
+ x = layers.Resizing(shape[1] * 2, shape[2] * 2)(x)
+ x = layers.concatenate([encoder_blocks[i], x], axis=-1)
+ x = convolution_block(x, filters=filters)
+
+ x = segmentation_head(x, out_classes, None) # Segmentation head.
+
+ # ------------- refined = coarse + residual
+ x = layers.Add()([x_input, x]) # Add prediction + refinement output
+
+ return keras.models.Model(inputs=[base_model.input], outputs=[x])
+
+
+"""
+## Combine Predict and Refinement Module
+"""
+
+
+class BASNet(keras.Model):
+ def __init__(self, input_shape, out_classes):
+ """BASNet, it's a combination of two modules
+ Prediction Module and Residual Refinement Module(RRM)."""
+
+ # Prediction model.
+ predict_model = basnet_predict(input_shape, out_classes)
+ # Refinement model.
+ refine_model = basnet_rrm(predict_model, out_classes)
+
+ output = refine_model.outputs # Combine outputs.
+ output.extend(predict_model.output)
+
+ # Activations.
+ output = [layers.Activation("sigmoid")(x) for x in output]
+ super().__init__(inputs=predict_model.input, outputs=output)
+
+ self.smooth = 1.0e-9
+ # Binary Cross Entropy loss.
+ self.cross_entropy_loss = keras.losses.BinaryCrossentropy()
+ # Structural Similarity Index value.
+ self.ssim_value = tf.image.ssim
+ # Jaccard / IoU loss.
+ self.iou_value = self.calculate_iou
+
+ def calculate_iou(
+ self,
+ y_true,
+ y_pred,
+ ):
+ """Calculate intersection over union (IoU) between images."""
+ intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])
+ union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])
+ union = union - intersection
+ return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)
+
+ def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
+ total = 0.0
+ for y_pred_i in y_pred: # y_pred = refine_model.outputs + predict_model.output
+ cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)
+
+ # Use the current side output (y_pred_i), not the full list of outputs.
+ ssim_value = self.ssim_value(y_true, y_pred_i, max_val=1)
+ ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)
+
+ iou_value = self.iou_value(y_true, y_pred_i)
+ iou_loss = 1 - iou_value
+
+ # Add all three losses.
+ total += cross_entropy_loss + ssim_loss + iou_loss
+ return total
+
+
+"""
+## Hybrid Loss
+
+Another important feature of BASNet is its hybrid loss function, a combination of
+binary cross-entropy, structural similarity, and intersection-over-union losses, which
+guides the network to learn three-level (i.e., pixel-, patch-, and map-level) hierarchy
+representations.
+"""
+
+
+basnet_model = BASNet(
+ input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES
+) # Create model.
+basnet_model.summary() # Show model summary.
+
+optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)
+# Compile model.
+basnet_model.compile(
+ optimizer=optimizer,
+ metrics=[keras.metrics.MeanAbsoluteError(name="mae") for _ in basnet_model.outputs],
+)
+
+"""
+### Train the Model
+"""
+
+basnet_model.fit(train_dataset, validation_data=val_dataset, epochs=1)
+
+"""
+### Visualize Predictions
+
+In the paper, BASNet was trained on the DUTS-TR dataset, which has 10,553 images. The
+model was trained for 400k iterations with a batch size of eight and without a
+validation dataset. After training, the model was evaluated on the DUTS-TE dataset and
+achieved a mean absolute error of `0.042`.
+
+Since BASNet is a deep model and cannot be trained in a short amount of time, which is a
+requirement for Keras example notebooks, we will load pretrained weights from [here](https://github.com/hamidriasat/BASNet/tree/basnet_keras)
+to show the model's predictions. Due to computing power limitations, this model was
+trained for 120k iterations, but it still demonstrates its capabilities. For further
+details about training parameters, please check the given link.
+"""
+
+import gdown
+
+gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")
+
+
+def normalize_output(prediction):
+ max_value = np.max(prediction)
+ min_value = np.min(prediction)
+ return (prediction - min_value) / (max_value - min_value)
+
+
+# Load weights.
+basnet_model.load_weights("./basnet_weights.h5")
+
+"""
+### Make Predictions
+"""
+
+for (image, mask), _ in zip(val_dataset, range(1)):
+ pred_mask = basnet_model.predict(image)
+ display([image[0], mask[0], normalize_output(pred_mask[0][0])])
diff --git a/knowledge_base/vision/bit.py b/knowledge_base/vision/bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7d5ffa17e1a6121c35433531c4ee864875c015
--- /dev/null
+++ b/knowledge_base/vision/bit.py
@@ -0,0 +1,311 @@
+"""
+Title: Image Classification using BigTransfer (BiT)
+Author: [Sayan Nath](https://twitter.com/sayannath2350)
+Date created: 2021/09/24
+Last modified: 2024/01/03
+Description: BigTransfer (BiT) State-of-the-art transfer learning for image classification.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+BigTransfer (also known as BiT) is a state-of-the-art transfer learning method for image
+classification. Transfer of pre-trained representations improves sample efficiency and
+simplifies hyperparameter tuning when training deep neural networks for vision. BiT
+revisit the paradigm of pre-training on large supervised datasets and fine-tuning the
+model on a target task. The importance of appropriately choosing normalization layers and
+scaling the architecture capacity as the amount of pre-training data increases.
+
+BigTransfer(BiT) is trained on public datasets, along with code in
+[TF2, Jax and Pytorch](https://github.com/google-research/big_transfer). This will help anyone to reach
+state of the art performance on their task of interest, even with just a handful of
+labeled images per class.
+
+You can find BiT models pre-trained on
+[ImageNet](https://image-net.org/challenges/LSVRC/2012/index) and ImageNet-21k in
+[TFHub](https://tfhub.dev/google/collections/bit/1) as TensorFlow2 SavedModels that you
+can use easily as Keras Layers. There are a variety of sizes ranging from a standard
+ResNet50 to a ResNet152x4 (152 layers deep, 4x wider than a typical ResNet50) for users
+with larger computational and memory budgets but higher accuracy requirements.
+
+
+Figure: The x-axis shows the number of images used per class, ranging from 1 to the full
+dataset. On the plots on the left, the curve in blue above is our BiT-L model, whereas
+the curve below is a ResNet-50 pre-trained on ImageNet (ILSVRC-2012).
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import keras
+from keras import ops
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+SEEDS = 42
+
+keras.utils.set_random_seed(SEEDS)
+
+"""
+## Gather Flower Dataset
+"""
+
+train_ds, validation_ds = tfds.load(
+ "tf_flowers",
+ split=["train[:85%]", "train[85%:]"],
+ as_supervised=True,
+)
+
+"""
+## Visualise the dataset
+"""
+
+plt.figure(figsize=(10, 10))
+for i, (image, label) in enumerate(train_ds.take(9)):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image)
+ plt.title(int(label))
+ plt.axis("off")
+
+"""
+## Define hyperparameters
+"""
+
+RESIZE_TO = 384
+CROP_TO = 224
+BATCH_SIZE = 64
+STEPS_PER_EPOCH = 10
+AUTO = tf.data.AUTOTUNE # optimise the pipeline performance
+NUM_CLASSES = 5 # number of classes
+SCHEDULE_LENGTH = (
+ 500 # we will train on lower resolution images and will still attain good results
+)
+SCHEDULE_BOUNDARIES = [
+ 200,
+ 300,
+ 400,
+] # the larger the dataset, the longer the schedule
+
+"""
+Hyperparameters like `SCHEDULE_LENGTH` and `SCHEDULE_BOUNDARIES` are determined based
+on empirical results. The method has been explained in the [original
+paper](https://arxiv.org/abs/1912.11370) and in their [Google AI Blog
+Post](https://ai.googleblog.com/2020/05/open-sourcing-bit-exploring-large-scale.html).
+
+The `SCHEDULE_LENGTH` also determines whether to use [MixUp
+Augmentation](https://arxiv.org/abs/1710.09412) or not. You can also find an easy MixUp
+implementation in [Keras Coding Examples](https://keras.io/examples/vision/mixup/).
+
+
+"""
+
+"""
+## Define preprocessing helper functions
+"""
+
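+# Rescale the schedule length for our batch size; the base schedule assumes a batch size of 512.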
+SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
+
+random_flip = keras.layers.RandomFlip("horizontal")
+random_crop = keras.layers.RandomCrop(CROP_TO, CROP_TO)
+
+
+def preprocess_train(image, label):
+ image = random_flip(image)
+ image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
+ image = random_crop(image)
+ image = image / 255.0
+ return (image, label)
+
+
+def preprocess_test(image, label):
+ image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
+ image = ops.cast(image, dtype="float32")
+ image = image / 255.0
+ return (image, label)
+
+
+DATASET_NUM_TRAIN_EXAMPLES = train_ds.cardinality().numpy()
+
+repeat_count = int(
+ SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH
+)
+repeat_count += 50 + 1 # To ensure there are at least 50 epochs of training
+
+"""
+## Define the data pipeline
+"""
+
+# Training pipeline
+pipeline_train = (
+ train_ds.shuffle(10000)
+ .repeat(repeat_count) # Repeat dataset_size / num_steps
+ .map(preprocess_train, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# Validation pipeline
+pipeline_validation = (
+ validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+"""
+## Visualise the training samples
+"""
+
+image_batch, label_batch = next(iter(pipeline_train))
+
+plt.figure(figsize=(10, 10))
+for n in range(25):
+ ax = plt.subplot(5, 5, n + 1)
+ plt.imshow(image_batch[n])
+ plt.title(label_batch[n].numpy())
+ plt.axis("off")
+
+"""
+## Load pretrained TF-Hub model into a `KerasLayer`
+"""
+
+bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
+bit_module = hub.load(bit_model_url)
+
+"""
+## Create BigTransfer (BiT) model
+
+To create the new model, we:
+
+1. Cut off the BiT model's original head. This leaves us with the "pre-logits" output.
+We do not have to do this if we use the "feature extractor" models (i.e. all those in
+subdirectories titled `feature_vectors`), since for those models the head has already
+been cut off.
+
+2. Add a new head with the number of outputs equal to the number of classes of our new
+task. Note that it is important that we initialise the head to all zeroes.
+"""
+
+
+class MyBiTModel(keras.Model):
+ def __init__(self, num_classes, module, **kwargs):
+ super().__init__(**kwargs)
+
+ self.num_classes = num_classes
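+ # Zero-initialize the new head, as noted above.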
+ self.head = keras.layers.Dense(num_classes, kernel_initializer="zeros")
+ self.bit_model = module
+
+ def call(self, images):
+ bit_embedding = self.bit_model(images)
+ return self.head(bit_embedding)
+
+
+model = MyBiTModel(num_classes=NUM_CLASSES, module=bit_module)
+
+"""
+## Define optimizer and loss
+"""
+
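+# Linearly scale the base learning rate (0.003 at batch size 512) to our batch size.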
+learning_rate = 0.003 * BATCH_SIZE / 512
+
+# Decay learning rate by a factor of 10 at SCHEDULE_BOUNDARIES.
+lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=SCHEDULE_BOUNDARIES,
+ values=[
+ learning_rate,
+ learning_rate * 0.1,
+ learning_rate * 0.01,
+ learning_rate * 0.001,
+ ],
+)
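+
+# A quick illustrative check (not used by training): sampling the schedule on either
+# side of each boundary shows the 10x learning rate drops.
+for step in [0, 250, 350, 450]:
+    print(f"Step {step}: lr = {float(lr_schedule(step)):.6f}")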
+optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
+
+loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+"""
+## Compile the model
+"""
+
+model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])
+
+"""
+## Set up callbacks
+"""
+
+train_callbacks = [
+ keras.callbacks.EarlyStopping(
+ monitor="val_accuracy", patience=2, restore_best_weights=True
+ )
+]
+
+"""
+## Train the model
+"""
+
+history = model.fit(
+ pipeline_train,
+ batch_size=BATCH_SIZE,
+ epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),
+ steps_per_epoch=STEPS_PER_EPOCH,
+ validation_data=pipeline_validation,
+ callbacks=train_callbacks,
+)
+
+"""
+## Plot the training and validation metrics
+"""
+
+
+def plot_hist(hist):
+ plt.plot(hist.history["accuracy"])
+ plt.plot(hist.history["val_accuracy"])
+ plt.plot(hist.history["loss"])
+ plt.plot(hist.history["val_loss"])
+ plt.title("Training Progress")
+ plt.ylabel("Accuracy/Loss")
+ plt.xlabel("Epochs")
+ plt.legend(["train_acc", "val_acc", "train_loss", "val_loss"], loc="upper left")
+ plt.show()
+
+
+plot_hist(history)
+
+"""
+## Evaluate the model
+"""
+
+accuracy = model.evaluate(pipeline_validation)[1] * 100
+print("Accuracy: {:.2f}%".format(accuracy))
+
+"""
+## Conclusion
+
+BiT performs well across a surprisingly wide range of data regimes
+-- from 1 example per class to 1M total examples. BiT achieves 87.5% top-1 accuracy on
+ILSVRC-2012, 99.4% on CIFAR-10, and 76.3% on the 19 task Visual Task Adaptation Benchmark
+(VTAB). On small datasets, BiT attains 76.8% on ILSVRC-2012 with 10 examples per class,
+and 97.0% on CIFAR-10 with 10 examples per class.
+
+
+
+You can experiment further with the BigTransfer Method by following the
+[original paper](https://arxiv.org/abs/1912.11370).
+
+
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/bit) | [](https://huggingface.co/spaces/keras-io/siamese-contrastive) |
+"""
diff --git a/knowledge_base/vision/cait.py b/knowledge_base/vision/cait.py
new file mode 100644
index 0000000000000000000000000000000000000000..39502d71beae009f2beb313f6cedb99d3d1a7539
--- /dev/null
+++ b/knowledge_base/vision/cait.py
@@ -0,0 +1,1031 @@
+"""
+Title: Class Attention Image Transformers with LayerScale
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2022/09/19
+Last modified: 2022/11/21
+Description: Implementing an image transformer equipped with Class Attention and LayerScale.
+Accelerator: None
+"""
+
+"""
+
+## Introduction
+
+In this tutorial, we implement the CaiT (Class-Attention in Image Transformers)
+proposed in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) by
+Touvron et al. Depth scaling, i.e. increasing the model depth for obtaining better
+performance and generalization has been quite successful for convolutional neural
+networks ([Tan et al.](https://arxiv.org/abs/1905.11946),
+[Dollár et al.](https://arxiv.org/abs/2103.06877), for example). But applying
+the same model scaling principles to
+Vision Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) doesn't
+translate equally well -- their performance gets saturated quickly with depth scaling.
+Note that one assumption here is that the underlying pre-training dataset is
+always kept fixed when performing model scaling.
+
+In the CaiT paper, the authors investigate this phenomenon and propose modifications to
+the vanilla ViT (Vision Transformers) architecture to mitigate this problem.
+
+The tutorial is structured like so:
+
+* Implementation of the individual blocks of CaiT
+* Collating all the blocks to create the CaiT model
+* Loading a pre-trained CaiT model
+* Obtaining prediction results
+* Visualization of the different attention layers of CaiT
+
+The readers are assumed to be familiar with Vision Transformers already. Here is
+an implementation of Vision Transformers in Keras:
+[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
+"""
+
+"""
+## Imports
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import io
+import typing
+from urllib.request import urlopen
+
+import matplotlib.pyplot as plt
+import numpy as np
+import PIL
+import keras
+from keras import layers
+from keras import ops
+
+"""
+## The LayerScale layer
+
+We begin by implementing a **LayerScale** layer which is one of the two modifications
+proposed in the CaiT paper.
+
+When increasing the depth of the ViT models, they meet with optimization instability and
+eventually don't converge. The residual connections within each Transformer block
+introduce an information bottleneck. When there is an increased amount of depth, this
+bottleneck can quickly explode and deviate the optimization pathway for the underlying
+model.
+
+The following equations denote where residual connections are added within a Transformer
+block:
+
+
+

+
+
+where **SA** stands for self-attention, **FFN** stands for feed-forward network, and
+**eta** denotes the LayerNorm operator ([Ba et al.](https://arxiv.org/abs/1607.06450)).
+
+LayerScale is formally implemented like so:
+
+
+

+
+
+where the lambdas are learnable parameters, initialized with a very small value
+({0.1, 1e-5, 1e-6}). **diag** represents a diagonal matrix.
+
+Intuitively, LayerScale helps control the contribution of the residual branches. The
+learnable parameters of LayerScale are initialized to a small value to let the branches
+act like identity functions and then let them figure out the degrees of interactions
+during the training. The diagonal matrix additionally helps control the contributions
+of the individual dimensions of the residual inputs as it is applied on a per-channel
+basis.
+
+The practical implementation of LayerScale is simpler than it might sound.
+"""
+
+
+class LayerScale(layers.Layer):
+ """LayerScale as introduced in CaiT: https://arxiv.org/abs/2103.17239.
+
+ Args:
+ init_values (float): value to initialize the diagonal matrix of LayerScale.
+ projection_dim (int): projection dimension used in LayerScale.
+ """
+
+ def __init__(self, init_values: float, projection_dim: int, **kwargs):
+ super().__init__(**kwargs)
+ self.gamma = self.add_weight(
+ shape=(projection_dim,),
+ initializer=keras.initializers.Constant(init_values),
+ )
+
+ def call(self, x, training=False):
+ return x * self.gamma
+
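+"""
+As a quick illustration, a `LayerScale` layer initialized with a small `init_values`
+scales its inputs down to near zero, which is what lets each residual branch start out
+close to an identity function (the values below are arbitrary):
+"""
+
+demo_layer_scale = LayerScale(init_values=1e-5, projection_dim=4)
+print(demo_layer_scale(ops.ones((1, 2, 4))))  # Every entry is ~1e-5.
+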
+
+"""
+## Stochastic depth layer
+
+Since its introduction ([Huang et al.](https://arxiv.org/abs/1603.09382)), Stochastic
+Depth has become a favorite component in almost all modern neural network architectures.
+CaiT is no exception. Discussing Stochastic Depth is out of scope for this notebook. You
+can refer to [this resource](https://paperswithcode.com/method/stochastic-depth) in case
+you need a refresher.
+"""
+
+
+class StochasticDepth(layers.Layer):
+ """Stochastic Depth layer (https://arxiv.org/abs/1603.09382).
+
+ Reference:
+ https://github.com/rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_prob: float, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_prob = drop_prob
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def call(self, x, training=False):
+ if training:
+ keep_prob = 1 - self.drop_prob
+ shape = (ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
+            random_tensor = keep_prob + keras.random.uniform(
+                shape, minval=0, maxval=1, seed=self.seed_generator
+            )
+ random_tensor = ops.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
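+"""
+A minimal illustration of the behavior: during training, each sample in the batch
+either survives (scaled up by `1 / keep_prob`) or has its residual zeroed out
+entirely; at inference time the layer is an identity function.
+"""
+
+demo_sd = StochasticDepth(drop_prob=0.5)
+print(demo_sd(ops.ones((4, 2)), training=True))  # Rows of 2.0 (kept) or 0.0 (dropped).
+print(demo_sd(ops.ones((4, 2)), training=False))  # Identity: all ones.
+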
+
+"""
+## Class attention
+
+The vanilla ViT uses self-attention (SA) layers for modelling how the image patches and
+the _learnable_ CLS token interact with each other. The CaiT authors propose to decouple
+the attention layers responsible for attending to the image patches and the CLS tokens.
+
+When using ViTs for any discriminative tasks (classification, for example), we usually
+take the representations belonging to the CLS token and then pass them to the
+task-specific heads. This is as opposed to using something like global average pooling as
+is typically done in convolutional neural networks.
+
+The interactions between the CLS token and other image patches are processed uniformly
+through self-attention layers. As the CaiT authors point out, this setup has an
+entangled effect. On one hand, the self-attention layers are responsible for modelling
+the image patches. On the other hand, they're also responsible for summarizing the
+modelled information via the CLS token so that it's useful for the learning objective.
+
+To help disentangle these two things, the authors propose to:
+
+* Introduce the CLS token at a later stage in the network.
+* Model the interaction between the CLS token and the representations related to the
+image patches through a separate set of attention layers. The authors call this **Class
+Attention** (CA).
+
+The figure below (taken from the original paper) depicts this idea:
+
+
+

+
+
+This is achieved by treating the CLS token embeddings as the queries in the CA layers.
+CLS token embeddings and the image patch embeddings are fed as keys as well as values.
+
+**Note** that "embeddings" and "representations" have been used interchangeably here.
+"""
+
+
+class ClassAttention(layers.Layer):
+ """Class attention as proposed in CaiT: https://arxiv.org/abs/2103.17239.
+
+ Args:
+ projection_dim (int): projection dimension for the query, key, and value
+ of attention.
+ num_heads (int): number of attention heads.
+ dropout_rate (float): dropout rate to be used for dropout in the attention
+ scores as well as the final projected outputs.
+ """
+
+ def __init__(
+ self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_heads = num_heads
+
+ head_dim = projection_dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.q = layers.Dense(projection_dim)
+ self.k = layers.Dense(projection_dim)
+ self.v = layers.Dense(projection_dim)
+ self.attn_drop = layers.Dropout(dropout_rate)
+ self.proj = layers.Dense(projection_dim)
+ self.proj_drop = layers.Dropout(dropout_rate)
+
+ def call(self, x, training=False):
+ batch_size, num_patches, num_channels = (
+ ops.shape(x)[0],
+ ops.shape(x)[1],
+ ops.shape(x)[2],
+ )
+
+ # Query projection. `cls_token` embeddings are queries.
+ q = ops.expand_dims(self.q(x[:, 0]), axis=1)
+ q = ops.reshape(
+ q, (batch_size, 1, self.num_heads, num_channels // self.num_heads)
+ ) # Shape: (batch_size, 1, num_heads, dimension_per_head)
+ q = ops.transpose(q, axes=[0, 2, 1, 3])
+ scale = ops.cast(self.scale, dtype=q.dtype)
+ q = q * scale
+
+        # Key projection. Patch embeddings as well as the cls embedding are used as keys.
+ k = self.k(x)
+ k = ops.reshape(
+ k, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
+ ) # Shape: (batch_size, num_tokens, num_heads, dimension_per_head)
+ k = ops.transpose(k, axes=[0, 2, 3, 1])
+
+        # Value projection. Patch embeddings as well as the cls embedding are used as values.
+ v = self.v(x)
+ v = ops.reshape(
+ v, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
+ )
+ v = ops.transpose(v, axes=[0, 2, 1, 3])
+
+ # Calculate attention scores between cls_token embedding and patch embeddings.
+ attn = ops.matmul(q, k)
+ attn = ops.nn.softmax(attn, axis=-1)
+ attn = self.attn_drop(attn, training=training)
+
+ x_cls = ops.matmul(attn, v)
+ x_cls = ops.transpose(x_cls, axes=[0, 2, 1, 3])
+ x_cls = ops.reshape(x_cls, (batch_size, 1, num_channels))
+ x_cls = self.proj(x_cls)
+ x_cls = self.proj_drop(x_cls, training=training)
+
+ return x_cls, attn
+
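+"""
+A shape-level sketch with illustrative values: for a sequence of one CLS token
+followed by 196 patch tokens, the layer returns a refined CLS embedding along with
+per-head attention scores over all 197 tokens.
+"""
+
+demo_ca = ClassAttention(projection_dim=192, num_heads=4, dropout_rate=0.0)
+demo_x_cls, demo_ca_attn = demo_ca(ops.ones((1, 197, 192)))
+print(demo_x_cls.shape)  # (1, 1, 192)
+print(demo_ca_attn.shape)  # (1, 4, 1, 197)
+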
+
+"""
+## Talking Head Attention
+
+The CaiT authors use the Talking Head attention
+([Shazeer et al.](https://arxiv.org/abs/2003.02436))
+instead of the vanilla scaled dot-product multi-head attention used in
+the original Transformer paper
+([Vaswani et al.](https://papers.nips.cc/paper/7181-attention-is-all-you-need)).
+They introduce two linear projections before and after the softmax
+operations for obtaining better results.
+
+For a more rigorous treatment of the Talking Head attention and the vanilla attention
+mechanisms, please refer to their respective papers (linked above).
+"""
+
+
+class TalkingHeadAttention(layers.Layer):
+ """Talking-head attention as proposed in CaiT: https://arxiv.org/abs/2003.02436.
+
+ Args:
+ projection_dim (int): projection dimension for the query, key, and value
+ of attention.
+ num_heads (int): number of attention heads.
+ dropout_rate (float): dropout rate to be used for dropout in the attention
+ scores as well as the final projected outputs.
+ """
+
+ def __init__(
+ self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
+ ):
+ super().__init__(**kwargs)
+
+ self.num_heads = num_heads
+
+ head_dim = projection_dim // self.num_heads
+
+ self.scale = head_dim**-0.5
+
+ self.qkv = layers.Dense(projection_dim * 3)
+ self.attn_drop = layers.Dropout(dropout_rate)
+
+ self.proj = layers.Dense(projection_dim)
+
+ self.proj_l = layers.Dense(self.num_heads)
+ self.proj_w = layers.Dense(self.num_heads)
+
+ self.proj_drop = layers.Dropout(dropout_rate)
+
+ def call(self, x, training=False):
+ B, N, C = ops.shape(x)[0], ops.shape(x)[1], ops.shape(x)[2]
+
+ # Project the inputs all at once.
+ qkv = self.qkv(x)
+
+ # Reshape the projected output so that they're segregated in terms of
+ # query, key, and value projections.
+ qkv = ops.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))
+
+ # Transpose so that the `num_heads` becomes the leading dimensions.
+ # Helps to better segregate the representation sub-spaces.
+ qkv = ops.transpose(qkv, axes=[2, 0, 3, 1, 4])
+ scale = ops.cast(self.scale, dtype=qkv.dtype)
+ q, k, v = qkv[0] * scale, qkv[1], qkv[2]
+
+ # Obtain the raw attention scores.
+ attn = ops.matmul(q, ops.transpose(k, axes=[0, 1, 3, 2]))
+
+ # Linear projection of the similarities between the query and key projections.
+ attn = self.proj_l(ops.transpose(attn, axes=[0, 2, 3, 1]))
+
+ # Normalize the attention scores.
+ attn = ops.transpose(attn, axes=[0, 3, 1, 2])
+ attn = ops.nn.softmax(attn, axis=-1)
+
+ # Linear projection on the softmaxed scores.
+ attn = self.proj_w(ops.transpose(attn, axes=[0, 2, 3, 1]))
+ attn = ops.transpose(attn, axes=[0, 3, 1, 2])
+ attn = self.attn_drop(attn, training=training)
+
+ # Final set of projections as done in the vanilla attention mechanism.
+ x = ops.matmul(attn, v)
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
+ x = ops.reshape(x, (B, N, C))
+
+ x = self.proj(x)
+ x = self.proj_drop(x, training=training)
+
+ return x, attn
+
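+"""
+Shape-wise, talking-head attention is a drop-in replacement for vanilla multi-head
+self-attention; the extra projections only mix information across the head axis
+(illustrative values below):
+"""
+
+demo_tha = TalkingHeadAttention(projection_dim=192, num_heads=4, dropout_rate=0.0)
+demo_tha_out, demo_tha_attn = demo_tha(ops.ones((1, 196, 192)))
+print(demo_tha_out.shape)  # (1, 196, 192)
+print(demo_tha_attn.shape)  # (1, 4, 196, 196)
+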
+
+"""
+## Feed-forward Network
+
+Next, we implement the feed-forward network which is one of the components within a
+Transformer block.
+"""
+
+
+def mlp(x, dropout_rate: float, hidden_units: typing.List[int]):
+ """FFN for a Transformer block."""
+ for idx, units in enumerate(hidden_units):
+ x = layers.Dense(
+ units,
+ activation=ops.nn.gelu if idx == 0 else None,
+ bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
+ )(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## Other blocks
+
+In the next two cells, we implement the remaining blocks as standalone functions:
+
+* `LayerScaleBlockClassAttention()` which returns a `keras.Model`. It is a Transformer block
+equipped with Class Attention, LayerScale, and Stochastic Depth. It operates on the CLS
+embeddings and the image patch embeddings.
+* `LayerScaleBlock()` which returns a `keras.Model`. It is also a Transformer block that
+operates only on the embeddings of the image patches. It is equipped with LayerScale and
+Stochastic Depth.
+"""
+
+
+def LayerScaleBlockClassAttention(
+ projection_dim: int,
+ num_heads: int,
+ layer_norm_eps: float,
+ init_values: float,
+ mlp_units: typing.List[int],
+ dropout_rate: float,
+ sd_prob: float,
+ name: str,
+):
+ """Pre-norm transformer block meant to be applied to the embeddings of the
+ cls token and the embeddings of image patches.
+
+ Includes LayerScale and Stochastic Depth.
+
+ Args:
+ projection_dim (int): projection dimension to be used in the
+ Transformer blocks and patch projection layer.
+ num_heads (int): number of attention heads.
+ layer_norm_eps (float): epsilon to be used for Layer Normalization.
+ init_values (float): initial value for the diagonal matrix used in LayerScale.
+ mlp_units (List[int]): dimensions of the feed-forward network used in
+ the Transformer blocks.
+ dropout_rate (float): dropout rate to be used for dropout in the attention
+ scores as well as the final projected outputs.
+ sd_prob (float): stochastic depth rate.
+ name (str): a name identifier for the block.
+
+ Returns:
+ A keras.Model instance.
+ """
+ x = keras.Input((None, projection_dim))
+ x_cls = keras.Input((None, projection_dim))
+ inputs = keras.layers.Concatenate(axis=1)([x_cls, x])
+
+ # Class attention (CA).
+ x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(inputs)
+ attn_output, attn_scores = ClassAttention(projection_dim, num_heads, dropout_rate)(
+ x1
+ )
+ attn_output = (
+ LayerScale(init_values, projection_dim)(attn_output)
+ if init_values
+ else attn_output
+ )
+ attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
+ x2 = keras.layers.Add()([x_cls, attn_output])
+
+ # FFN.
+ x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
+ x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
+ x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
+ x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
+ outputs = keras.layers.Add()([x2, x4])
+
+ return keras.Model([x, x_cls], [outputs, attn_scores], name=name)
+
+
+def LayerScaleBlock(
+ projection_dim: int,
+ num_heads: int,
+ layer_norm_eps: float,
+ init_values: float,
+ mlp_units: typing.List[int],
+ dropout_rate: float,
+ sd_prob: float,
+ name: str,
+):
+ """Pre-norm transformer block meant to be applied to the embeddings of the
+ image patches.
+
+ Includes LayerScale and Stochastic Depth.
+
+ Args:
+ projection_dim (int): projection dimension to be used in the
+ Transformer blocks and patch projection layer.
+ num_heads (int): number of attention heads.
+ layer_norm_eps (float): epsilon to be used for Layer Normalization.
+ init_values (float): initial value for the diagonal matrix used in LayerScale.
+ mlp_units (List[int]): dimensions of the feed-forward network used in
+ the Transformer blocks.
+ dropout_rate (float): dropout rate to be used for dropout in the attention
+ scores as well as the final projected outputs.
+ sd_prob (float): stochastic depth rate.
+ name (str): a name identifier for the block.
+
+ Returns:
+ A keras.Model instance.
+ """
+ encoded_patches = keras.Input((None, projection_dim))
+
+ # Self-attention.
+ x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
+ attn_output, attn_scores = TalkingHeadAttention(
+ projection_dim, num_heads, dropout_rate
+ )(x1)
+ attn_output = (
+ LayerScale(init_values, projection_dim)(attn_output)
+ if init_values
+ else attn_output
+ )
+ attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
+ x2 = layers.Add()([encoded_patches, attn_output])
+
+ # FFN.
+ x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
+ x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
+ x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
+ x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
+ outputs = layers.Add()([x2, x4])
+
+ return keras.Model(encoded_patches, [outputs, attn_scores], name=name)
+
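+"""
+As a quick smoke test (with illustrative hyperparameters), a single self-attention
+block preserves the token shape while also exposing its attention scores:
+"""
+
+demo_block = LayerScaleBlock(
+    projection_dim=192,
+    num_heads=4,
+    layer_norm_eps=1e-6,
+    init_values=1e-5,
+    mlp_units=[768, 192],
+    dropout_rate=0.0,
+    sd_prob=0.0,
+    name="demo_sa_ffn_block",
+)
+demo_tokens, demo_block_attn = demo_block(ops.ones((1, 196, 192)))
+print(demo_tokens.shape)  # (1, 196, 192)
+print(demo_block_attn.shape)  # (1, 4, 196, 196)
+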
+
+"""
+Given all these blocks, we are now ready to collate them into the final CaiT model.
+"""
+
+"""
+## Putting the pieces together: The CaiT model
+"""
+
+
+class CaiT(keras.Model):
+ """CaiT model.
+
+ Args:
+ projection_dim (int): projection dimension to be used in the
+ Transformer blocks and patch projection layer.
+ patch_size (int): patch size of the input images.
+ num_patches (int): number of patches after extracting the image patches.
+ init_values (float): initial value for the diagonal matrix used in LayerScale.
+        mlp_units (List[int]): dimensions of the feed-forward network used in
+ the Transformer blocks.
+ sa_ffn_layers (int): number of self-attention Transformer blocks.
+ ca_ffn_layers (int): number of class-attention Transformer blocks.
+ num_heads (int): number of attention heads.
+ layer_norm_eps (float): epsilon to be used for Layer Normalization.
+ dropout_rate (float): dropout rate to be used for dropout in the attention
+ scores as well as the final projected outputs.
+ sd_prob (float): stochastic depth rate.
+ global_pool (str): denotes how to pool the representations coming out of
+ the final Transformer block.
+ pre_logits (bool): if set to True then don't add a classification head.
+ num_classes (int): number of classes to construct the final classification
+ layer with.
+ """
+
+ def __init__(
+ self,
+ projection_dim: int,
+ patch_size: int,
+ num_patches: int,
+ init_values: float,
+ mlp_units: typing.List[int],
+ sa_ffn_layers: int,
+ ca_ffn_layers: int,
+ num_heads: int,
+ layer_norm_eps: float,
+ dropout_rate: float,
+ sd_prob: float,
+ global_pool: str,
+ pre_logits: bool,
+ num_classes: int,
+ **kwargs,
+ ):
+ if global_pool not in ["token", "avg"]:
+ raise ValueError(
+ 'Invalid value received for `global_pool`, should be either `"token"` or `"avg"`.'
+ )
+
+ super().__init__(**kwargs)
+
+        # Responsible for patchifying the input images and then linearly projecting them.
+ self.projection = keras.Sequential(
+ [
+ layers.Conv2D(
+ filters=projection_dim,
+ kernel_size=(patch_size, patch_size),
+ strides=(patch_size, patch_size),
+ padding="VALID",
+ name="conv_projection",
+ kernel_initializer="lecun_normal",
+ ),
+ layers.Reshape(
+ target_shape=(-1, projection_dim),
+ name="flatten_projection",
+ ),
+ ],
+ name="projection",
+ )
+
+ # CLS token and the positional embeddings.
+ self.cls_token = self.add_weight(
+ shape=(1, 1, projection_dim), initializer="zeros"
+ )
+ self.pos_embed = self.add_weight(
+ shape=(1, num_patches, projection_dim), initializer="zeros"
+ )
+
+ # Projection dropout.
+ self.pos_drop = layers.Dropout(dropout_rate, name="projection_dropout")
+
+ # Stochastic depth schedule.
+ dpr = [sd_prob for _ in range(sa_ffn_layers)]
+
+ # Self-attention (SA) Transformer blocks operating only on the image patch
+ # embeddings.
+ self.blocks = [
+ LayerScaleBlock(
+ projection_dim=projection_dim,
+ num_heads=num_heads,
+ layer_norm_eps=layer_norm_eps,
+ init_values=init_values,
+ mlp_units=mlp_units,
+ dropout_rate=dropout_rate,
+ sd_prob=dpr[i],
+ name=f"sa_ffn_block_{i}",
+ )
+ for i in range(sa_ffn_layers)
+ ]
+
+ # Class Attention (CA) Transformer blocks operating on the CLS token and image patch
+ # embeddings.
+ self.blocks_token_only = [
+ LayerScaleBlockClassAttention(
+ projection_dim=projection_dim,
+ num_heads=num_heads,
+ layer_norm_eps=layer_norm_eps,
+ init_values=init_values,
+ mlp_units=mlp_units,
+ dropout_rate=dropout_rate,
+ name=f"ca_ffn_block_{i}",
+ sd_prob=0.0, # No Stochastic Depth in the class attention layers.
+ )
+ for i in range(ca_ffn_layers)
+ ]
+
+ # Pre-classification layer normalization.
+ self.norm = layers.LayerNormalization(epsilon=layer_norm_eps, name="head_norm")
+
+ # Representation pooling for classification head.
+ self.global_pool = global_pool
+
+ # Classification head.
+ self.pre_logits = pre_logits
+ self.num_classes = num_classes
+ if not pre_logits:
+ self.head = layers.Dense(num_classes, name="classification_head")
+
+ def call(self, x, training=False):
+ # Notice how CLS token is not added here.
+ x = self.projection(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ # SA+FFN layers.
+ sa_ffn_attn = {}
+ for blk in self.blocks:
+ x, attn_scores = blk(x)
+ sa_ffn_attn[f"{blk.name}_att"] = attn_scores
+
+ # CA+FFN layers.
+ ca_ffn_attn = {}
+ cls_tokens = ops.tile(self.cls_token, (ops.shape(x)[0], 1, 1))
+ for blk in self.blocks_token_only:
+ cls_tokens, attn_scores = blk([x, cls_tokens])
+ ca_ffn_attn[f"{blk.name}_att"] = attn_scores
+
+ x = ops.concatenate([cls_tokens, x], axis=1)
+ x = self.norm(x)
+
+ # Always return the attention scores from the SA+FFN and CA+FFN layers
+ # for convenience.
+ if self.global_pool:
+ x = (
+                ops.mean(x[:, 1:], axis=1)
+ if self.global_pool == "avg"
+ else x[:, 0]
+ )
+ return (
+ (x, sa_ffn_attn, ca_ffn_attn)
+ if self.pre_logits
+ else (self.head(x), sa_ffn_attn, ca_ffn_attn)
+ )
+
+
+"""
+Having the SA and CA layers segregated this way helps the model focus more
+concretely on its two underlying objectives:
+
+* model dependencies in between the image patches
+* summarize the information from the image patches in a CLS token that can be used for
+the task at hand
+
+Now that we have defined the CaiT model, it's time to test it. We will start by defining
+a model configuration that will be passed to our `CaiT` class for initialization.
+"""
+
+"""
+## Defining Model Configuration
+"""
+
+
+def get_config(
+ image_size: int = 224,
+ patch_size: int = 16,
+ projection_dim: int = 192,
+ sa_ffn_layers: int = 24,
+ ca_ffn_layers: int = 2,
+ num_heads: int = 4,
+ mlp_ratio: int = 4,
+ layer_norm_eps=1e-6,
+ init_values: float = 1e-5,
+ dropout_rate: float = 0.0,
+ sd_prob: float = 0.0,
+ global_pool: str = "token",
+ pre_logits: bool = False,
+ num_classes: int = 1000,
+) -> typing.Dict:
+ """Default configuration for CaiT models (cait_xxs24_224).
+
+ Reference:
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cait.py
+ """
+ config = {}
+
+ # Patchification and projection.
+ config["patch_size"] = patch_size
+ config["num_patches"] = (image_size // patch_size) ** 2
+
+ # LayerScale.
+ config["init_values"] = init_values
+
+ # Dropout and Stochastic Depth.
+ config["dropout_rate"] = dropout_rate
+ config["sd_prob"] = sd_prob
+
+ # Shared across different blocks and layers.
+ config["layer_norm_eps"] = layer_norm_eps
+ config["projection_dim"] = projection_dim
+ config["mlp_units"] = [
+ projection_dim * mlp_ratio,
+ projection_dim,
+ ]
+
+ # Attention layers.
+ config["num_heads"] = num_heads
+ config["sa_ffn_layers"] = sa_ffn_layers
+ config["ca_ffn_layers"] = ca_ffn_layers
+
+ # Representation pooling and task specific parameters.
+ config["global_pool"] = global_pool
+ config["pre_logits"] = pre_logits
+ config["num_classes"] = num_classes
+
+ return config
+
+
+"""
+Most of the configuration variables should sound familiar if you already know the
+ViT architecture. Note in particular `sa_ffn_layers` and `ca_ffn_layers`, which
+control the number of SA-Transformer blocks and CA-Transformer blocks. You can easily
+amend this `get_config()` method to instantiate a CaiT model for your own dataset, as
+sketched below.
+"""
+
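+"""
+For example, a hypothetical (untuned) configuration for 32x32 inputs and 10 classes:
+"""
+
+cifar_config = get_config(image_size=32, patch_size=4, num_classes=10)
+print(cifar_config["num_patches"])  # (32 // 4) ** 2 = 64
+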
+"""
+## Model Instantiation
+"""
+
+image_size = 224
+num_channels = 3
+batch_size = 2
+
+config = get_config()
+cait_xxs24_224 = CaiT(**config)
+
+dummy_inputs = ops.ones((batch_size, image_size, image_size, num_channels))
+_ = cait_xxs24_224(dummy_inputs)
+
+"""
+We can successfully perform inference with the model. But what about implementation
+correctness? There are many ways to verify it:
+
+* Obtain the performance of the model (given it's been populated with the pre-trained
+parameters) on the ImageNet-1k validation set (as the pretraining dataset was
+ImageNet-1k).
+* Fine-tune the model on a different dataset.
+
+In order to verify that, we will load another instance of the same model that has been
+already populated with the pre-trained parameters. Please refer to
+[this repository](https://github.com/sayakpaul/cait-tf)
+(developed by the author of this notebook) for more details.
+Additionally, the repository provides code to verify model performance on the
+[ImageNet-1k validation set](https://github.com/sayakpaul/cait-tf/tree/main/i1k_eval)
+as well as
+[fine-tuning](https://github.com/sayakpaul/cait-tf/blob/main/notebooks/finetune.ipynb).
+"""
+
+"""
+## Load a pretrained model
+"""
+
+model_gcs_path = "gs://kaggle-tfhub-models-uncompressed/tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
+pretrained_model = keras.Sequential(
+ [keras.layers.TFSMLayer(model_gcs_path, call_endpoint="serving_default")]
+)
+
+"""
+## Inference utilities
+
+In the next couple of cells, we develop preprocessing utilities needed to run inference
+with the pretrained model.
+"""
+# The preprocessing transformations include center cropping, and normalizing
+# the pixel values with the ImageNet-1k training stats (mean and standard deviation).
+crop_layer = keras.layers.CenterCrop(image_size, image_size)
+norm_layer = keras.layers.Normalization(
+ mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+ variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
+)
+
+
+def preprocess_image(image, size=image_size):
+ image = np.array(image)
+ image_resized = ops.expand_dims(image, 0)
+ resize_size = int((256 / image_size) * size)
+ image_resized = ops.image.resize(
+ image_resized, (resize_size, resize_size), interpolation="bicubic"
+ )
+ image_resized = crop_layer(image_resized)
+ return norm_layer(image_resized).numpy()
+
+
+def load_image_from_url(url):
+ image_bytes = io.BytesIO(urlopen(url).read())
+ image = PIL.Image.open(image_bytes)
+ preprocessed_image = preprocess_image(image)
+ return image, preprocessed_image
+
+
+"""
+Now, we retrieve and load the ImageNet-1k class labels, since the model we're
+using was pretrained on the ImageNet-1k dataset.
+"""
+
+# ImageNet-1k class labels.
+imagenet_labels = (
+ "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
+)
+label_path = keras.utils.get_file(origin=imagenet_labels)
+
+with open(label_path, "r") as f:
+ lines = f.readlines()
+imagenet_labels = [line.rstrip() for line in lines]
+
+"""
+## Load an Image
+"""
+
+img_url = "https://i.imgur.com/ErgfLTn.jpg"
+image, preprocessed_image = load_image_from_url(img_url)
+
+# https://unsplash.com/photos/Ho93gVTRWW8
+plt.imshow(image)
+plt.axis("off")
+plt.show()
+
+"""
+## Obtain Predictions
+"""
+
+outputs = pretrained_model.predict(preprocessed_image)
+logits = outputs["output_1"]
+ca_ffn_block_0_att = outputs["output_3_ca_ffn_block_0_att"]
+ca_ffn_block_1_att = outputs["output_3_ca_ffn_block_1_att"]
+
+predicted_label = imagenet_labels[int(np.argmax(logits))]
+print(predicted_label)
+
+"""
+Now that we have obtained the predictions (which appear to be as expected), we can
+further extend our investigation. Following the CaiT authors, we can investigate the
+attention scores from the attention layers. This helps us to get deeper insights into the
+modifications introduced in the CaiT paper.
+"""
+
+"""
+## Visualizing the Attention Layers
+
+We start by inspecting the shape of the attention weights returned by a Class Attention
+layer.
+"""
+
+# (batch_size, nb_attention_heads, num_cls_token, seq_length)
+print("Shape of the attention scores from a class attention block:")
+print(ca_ffn_block_0_att.shape)
+
+"""
+The shape shows that we have attention weights for each individual attention head.
+They quantify how the CLS token is related to itself and to the rest of the image
+patches.
+
+Next, we write a utility to:
+
+* Visualize what the individual attention heads in the Class Attention layers are
+focusing on. This helps us to get an idea of how the _spatial-class relationship_ is
+induced in the CaiT model.
+* Obtain a saliency map from the first Class Attention layer that helps to understand
+how the CA layer aggregates information from the region(s) of interest in the images.
+
+This utility is adapted from Figures 6 and 7 of the original
+[CaiT paper](https://arxiv.org/abs/2103.17239). This is also a part of
+[this notebook](https://github.com/sayakpaul/cait-tf/blob/main/notebooks/classification.ipynb)
+(developed by the author of this tutorial).
+"""
+
+# Reference:
+# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
+
+patch_size = 16
+
+
+def get_cls_attention_map(
+ attention_scores,
+ return_saliency=False,
+) -> np.ndarray:
+ """
+ Returns attention scores from a particular attention block.
+
+ Args:
+ attention_scores: the attention scores from the attention block to
+ visualize.
+ return_saliency: a boolean flag if set to True also returns the salient
+ representations of the attention block.
+ """
+ w_featmap = preprocessed_image.shape[2] // patch_size
+ h_featmap = preprocessed_image.shape[1] // patch_size
+
+ nh = attention_scores.shape[1] # Number of attention heads.
+
+    # Taking the representations from the CLS token.
+ attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)
+
+ # Reshape the attention scores to resemble mini patches.
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
+
+ if not return_saliency:
+ attentions = attentions.transpose((1, 2, 0))
+
+ else:
+ attentions = np.mean(attentions, axis=0)
+ attentions = (attentions - attentions.min()) / (
+ attentions.max() - attentions.min()
+ )
+ attentions = np.expand_dims(attentions, -1)
+
+ # Resize the attention patches to 224x224 (224: 14x16)
+ attentions = ops.image.resize(
+ attentions,
+ size=(h_featmap * patch_size, w_featmap * patch_size),
+ interpolation="bicubic",
+ )
+
+ return attentions
+
+
+"""
+In the first CA layer, we notice that the model is focusing solely on the region of
+interest.
+"""
+
+attentions_ca_block_0 = get_cls_attention_map(ca_ffn_block_0_att)
+
+
+fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
+img_count = 0
+
+for i in range(attentions_ca_block_0.shape[-1]):
+ if img_count < attentions_ca_block_0.shape[-1]:
+ axes[i].imshow(attentions_ca_block_0[:, :, img_count])
+ axes[i].title.set_text(f"Attention head: {img_count}")
+ axes[i].axis("off")
+ img_count += 1
+
+fig.tight_layout()
+plt.show()
+
+"""
+Whereas in the second CA layer, the model is trying to focus more on the context that
+contains discriminative signals.
+"""
+
+attentions_ca_block_1 = get_cls_attention_map(ca_ffn_block_1_att)
+
+
+fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
+img_count = 0
+
+for i in range(attentions_ca_block_1.shape[-1]):
+ if img_count < attentions_ca_block_1.shape[-1]:
+ axes[i].imshow(attentions_ca_block_1[:, :, img_count])
+ axes[i].title.set_text(f"Attention head: {img_count}")
+ axes[i].axis("off")
+ img_count += 1
+
+fig.tight_layout()
+plt.show()
+
+"""
+Finally, we obtain the saliency map for the given image.
+"""
+
+saliency_attention = get_cls_attention_map(ca_ffn_block_0_att, return_saliency=True)
+
+image = np.array(image)
+image_resized = ops.expand_dims(image, 0)
+resize_size = int((256 / 224) * image_size)
+image_resized = ops.image.resize(
+ image_resized, (resize_size, resize_size), interpolation="bicubic"
+)
+image_resized = crop_layer(image_resized)
+
+plt.imshow(image_resized.numpy().squeeze().astype("int32"))
+plt.imshow(saliency_attention.numpy().squeeze(), cmap="cividis", alpha=0.9)
+plt.axis("off")
+
+plt.show()
+
+"""
+## Conclusion
+
+In this notebook, we implemented the CaiT model. It shows how to mitigate the issues in
+ViTs when trying to scale their depth while keeping the pretraining dataset fixed. I hope
+the additional visualizations provided in the notebook spark excitement in the community
+and people develop interesting methods to probe what models like ViT learn.
+
+## Acknowledgement
+
+Thanks to the ML Developer Programs team at Google for providing Google Cloud Platform
+support.
+"""
diff --git a/knowledge_base/vision/captcha_ocr.py b/knowledge_base/vision/captcha_ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..574a6a1b240df917d8905663b623583580a70f8f
--- /dev/null
+++ b/knowledge_base/vision/captcha_ocr.py
@@ -0,0 +1,428 @@
+"""
+Title: OCR model for reading Captchas
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2020/06/14
+Last modified: 2024/03/13
+Description: How to implement an OCR model using CNNs, RNNs and CTC loss.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+This example demonstrates a simple OCR model built with the Functional API. Apart from
+combining CNN and RNN, it also illustrates how you can instantiate a new layer
+and use it as an "Endpoint layer" for implementing CTC loss. For a detailed
+guide to layer subclassing, please check out
+[this page](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)
+in the developer guides.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from pathlib import Path
+
+import tensorflow as tf
+import keras
+from keras import ops
+from keras import layers
+
+"""
+## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)
+Let's download the data.
+"""
+
+
+"""shell
+curl -LO https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip
+unzip -qq captcha_images_v2.zip
+"""
+
+
+"""
+The dataset contains 1040 captcha files as `png` images. The label for each sample is a string,
+the name of the file (minus the file extension).
+We will map each character in the string to an integer for training the model. Similarly,
+we will need to map the predictions of the model back to strings. For this purpose
+we will maintain two dictionaries, mapping characters to integers, and integers to characters,
+respectively.
+"""
+
+
+# Path to the data directory
+data_dir = Path("./captcha_images_v2/")
+
+# Get list of all the images
+images = sorted(list(map(str, list(data_dir.glob("*.png")))))
+labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
+characters = set(char for label in labels for char in label)
+characters = sorted(list(characters))
+
+print("Number of images found: ", len(images))
+print("Number of labels found: ", len(labels))
+print("Number of unique characters: ", len(characters))
+print("Characters present: ", characters)
+
+# Batch size for training and validation
+batch_size = 16
+
+# Desired image dimensions
+img_width = 200
+img_height = 50
+
+# Factor by which the image is going to be downsampled
+# by the convolutional blocks. We will be using two
+# convolution blocks and each block will have
+# a pooling layer which downsamples the features by a factor of 2.
+# Hence, the total downsampling factor is 4.
+downsample_factor = 4
+
+# Maximum length of any captcha in the dataset
+max_length = max([len(label) for label in labels])
+
+
+"""
+## Preprocessing
+"""
+
+
+# Mapping characters to integers
+char_to_num = layers.StringLookup(vocabulary=list(characters), mask_token=None)
+
+# Mapping integers back to original characters
+num_to_char = layers.StringLookup(
+ vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
+)
+
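+# A quick round-trip check of the two lookup tables (illustrative):
+example_label = labels[0]
+encoded_label = char_to_num(tf.strings.unicode_split(example_label, input_encoding="UTF-8"))
+decoded_label = tf.strings.reduce_join(num_to_char(encoded_label)).numpy().decode("utf-8")
+print(example_label, "->", encoded_label.numpy(), "->", decoded_label)
+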
+
+def split_data(images, labels, train_size=0.9, shuffle=True):
+ # 1. Get the total size of the dataset
+ size = len(images)
+ # 2. Make an indices array and shuffle it, if required
+ indices = ops.arange(size)
+ if shuffle:
+ indices = keras.random.shuffle(indices)
+ # 3. Get the size of training samples
+ train_samples = int(size * train_size)
+ # 4. Split data into training and validation sets
+ x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]
+ x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]
+ return x_train, x_valid, y_train, y_valid
+
+
+# Splitting data into training and validation sets
+x_train, x_valid, y_train, y_valid = split_data(np.array(images), np.array(labels))
+
+
+def encode_single_sample(img_path, label):
+ # 1. Read image
+ img = tf.io.read_file(img_path)
+ # 2. Decode and convert to grayscale
+ img = tf.io.decode_png(img, channels=1)
+ # 3. Convert to float32 in [0, 1] range
+ img = tf.image.convert_image_dtype(img, tf.float32)
+ # 4. Resize to the desired size
+ img = ops.image.resize(img, [img_height, img_width])
+ # 5. Transpose the image because we want the time
+ # dimension to correspond to the width of the image.
+ img = ops.transpose(img, axes=[1, 0, 2])
+ # 6. Map the characters in label to numbers
+ label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
+ # 7. Return a dict as our model is expecting two inputs
+ return {"image": img, "label": label}
+
+
+"""
+## Create `Dataset` objects
+"""
+
+
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_dataset = (
+ train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(batch_size)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+)
+
+validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
+validation_dataset = (
+ validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(batch_size)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+)
+
+"""
+## Visualize the data
+"""
+
+
+_, ax = plt.subplots(4, 4, figsize=(10, 5))
+for batch in train_dataset.take(1):
+ images = batch["image"]
+ labels = batch["label"]
+ for i in range(16):
+ img = (images[i] * 255).numpy().astype("uint8")
+ label = tf.strings.reduce_join(num_to_char(labels[i])).numpy().decode("utf-8")
+ ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
+ ax[i // 4, i % 4].set_title(label)
+ ax[i // 4, i % 4].axis("off")
+plt.show()
+
+"""
+## Model
+"""
+
+
+def ctc_batch_cost(y_true, y_pred, input_length, label_length):
+ label_length = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")
+ input_length = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")
+ sparse_labels = ops.cast(
+ ctc_label_dense_to_sparse(y_true, label_length), dtype="int32"
+ )
+
+ y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
+
+ return ops.expand_dims(
+ tf.compat.v1.nn.ctc_loss(
+ inputs=y_pred, labels=sparse_labels, sequence_length=input_length
+ ),
+ 1,
+ )
+
+
+def ctc_label_dense_to_sparse(labels, label_lengths):
+ label_shape = ops.shape(labels)
+ num_batches_tns = ops.stack([label_shape[0]])
+ max_num_labels_tns = ops.stack([label_shape[1]])
+
+ def range_less_than(old_input, current_input):
+ return ops.expand_dims(ops.arange(ops.shape(old_input)[1]), 0) < tf.fill(
+ max_num_labels_tns, current_input
+ )
+
+ init = ops.cast(tf.fill([1, label_shape[1]], 0), dtype="bool")
+ dense_mask = tf.compat.v1.scan(
+ range_less_than, label_lengths, initializer=init, parallel_iterations=1
+ )
+ dense_mask = dense_mask[:, 0, :]
+
+ label_array = ops.reshape(
+ ops.tile(ops.arange(0, label_shape[1]), num_batches_tns), label_shape
+ )
+ label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
+
+ batch_array = ops.transpose(
+ ops.reshape(
+ ops.tile(ops.arange(0, label_shape[0]), max_num_labels_tns),
+ tf.reverse(label_shape, [0]),
+ )
+ )
+ batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
+ indices = ops.transpose(
+ ops.reshape(ops.concatenate([batch_ind, label_ind], axis=0), [2, -1])
+ )
+
+ vals_sparse = tf.compat.v1.gather_nd(labels, indices)
+
+ return tf.SparseTensor(
+ ops.cast(indices, dtype="int64"),
+ vals_sparse,
+ ops.cast(label_shape, dtype="int64"),
+ )
+
+
+class CTCLayer(layers.Layer):
+ def __init__(self, name=None):
+ super().__init__(name=name)
+ self.loss_fn = ctc_batch_cost
+
+ def call(self, y_true, y_pred):
+ # Compute the training-time loss value and add it
+ # to the layer using `self.add_loss()`.
+ batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
+ input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
+ label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")
+
+ input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+ label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+
+ loss = self.loss_fn(y_true, y_pred, input_length, label_length)
+ self.add_loss(loss)
+
+ # At test time, just return the computed predictions
+ return y_pred
+
+
+def build_model():
+ # Inputs to the model
+ input_img = layers.Input(
+ shape=(img_width, img_height, 1), name="image", dtype="float32"
+ )
+ labels = layers.Input(name="label", shape=(None,), dtype="float32")
+
+ # First conv block
+ x = layers.Conv2D(
+ 32,
+ (3, 3),
+ activation="relu",
+ kernel_initializer="he_normal",
+ padding="same",
+ name="Conv1",
+ )(input_img)
+ x = layers.MaxPooling2D((2, 2), name="pool1")(x)
+
+ # Second conv block
+ x = layers.Conv2D(
+ 64,
+ (3, 3),
+ activation="relu",
+ kernel_initializer="he_normal",
+ padding="same",
+ name="Conv2",
+ )(x)
+ x = layers.MaxPooling2D((2, 2), name="pool2")(x)
+
+    # We have used two max-pooling layers with pool size and strides of 2.
+    # Hence, the downsampled feature maps are 4x smaller. The number of
+ # filters in the last layer is 64. Reshape accordingly before
+ # passing the output to the RNN part of the model
+ new_shape = ((img_width // 4), (img_height // 4) * 64)
+ x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
+ x = layers.Dense(64, activation="relu", name="dense1")(x)
+ x = layers.Dropout(0.2)(x)
+
+ # RNNs
+ x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
+ x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)
+
+ # Output layer
+ x = layers.Dense(
+ len(char_to_num.get_vocabulary()) + 1, activation="softmax", name="dense2"
+ )(x)
+
+ # Add CTC layer for calculating CTC loss at each step
+ output = CTCLayer(name="ctc_loss")(labels, x)
+
+ # Define the model
+ model = keras.models.Model(
+ inputs=[input_img, labels], outputs=output, name="ocr_model_v1"
+ )
+ # Optimizer
+ opt = keras.optimizers.Adam()
+ # Compile the model and return
+ model.compile(optimizer=opt)
+ return model
+
+
+# Get the model
+model = build_model()
+model.summary()
+
+"""
+## Training
+"""
+
+
+# TODO restore epoch count.
+epochs = 100
+early_stopping_patience = 10
+# Add early stopping
+early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True
+)
+
+# Train the model
+history = model.fit(
+ train_dataset,
+ validation_data=validation_dataset,
+ epochs=epochs,
+ callbacks=[early_stopping],
+)
+
+
+"""
+## Inference
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/ocr-for-captcha)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/ocr-for-captcha).
+"""
+
+
+def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
+ input_shape = ops.shape(y_pred)
+ num_samples, num_steps = input_shape[0], input_shape[1]
+ y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
+ input_length = ops.cast(input_length, dtype="int32")
+
+ if greedy:
+ (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
+ inputs=y_pred, sequence_length=input_length
+ )
+ else:
+ (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
+ inputs=y_pred,
+ sequence_length=input_length,
+ beam_width=beam_width,
+ top_paths=top_paths,
+ )
+ decoded_dense = []
+ for st in decoded:
+ st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
+ decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
+ return (decoded_dense, log_prob)
+
+
+# Get the prediction model by extracting layers till the output layer
+prediction_model = keras.models.Model(
+ model.input[0], model.get_layer(name="dense2").output
+)
+prediction_model.summary()
+
+
+# A utility function to decode the output of the network
+def decode_batch_predictions(pred):
+ input_len = np.ones(pred.shape[0]) * pred.shape[1]
+ # Use greedy search. For complex tasks, you can use beam search
+ results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
+ :, :max_length
+ ]
+ # Iterate over the results and get back the text
+ output_text = []
+ for res in results:
+ res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
+ output_text.append(res)
+ return output_text
+
+
+# Let's check results on some validation samples
+for batch in validation_dataset.take(1):
+ batch_images = batch["image"]
+ batch_labels = batch["label"]
+
+ preds = prediction_model.predict(batch_images)
+ pred_texts = decode_batch_predictions(preds)
+
+ orig_texts = []
+ for label in batch_labels:
+ label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
+ orig_texts.append(label)
+
+ _, ax = plt.subplots(4, 4, figsize=(15, 5))
+ for i in range(len(pred_texts)):
+ img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
+ img = img.T
+ title = f"Prediction: {pred_texts[i]}"
+ ax[i // 4, i % 4].imshow(img, cmap="gray")
+ ax[i // 4, i % 4].set_title(title)
+ ax[i // 4, i % 4].axis("off")
+plt.show()
diff --git a/knowledge_base/vision/cct.py b/knowledge_base/vision/cct.py
new file mode 100644
index 0000000000000000000000000000000000000000..999856597337efefa9126d9d27f438103e53a1b6
--- /dev/null
+++ b/knowledge_base/vision/cct.py
@@ -0,0 +1,441 @@
+"""
+Title: Compact Convolutional Transformers
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/06/30
+Last modified: 2023/08/07
+Description: Compact Convolutional Transformers for efficient image classification.
+Accelerator: GPU
+Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com), [Guillaume Baquiast](https://www.linkedin.com/in/guillaume-baquiast-478965ba/)
+"""
+
+"""
+As discussed in the [Vision Transformers (ViT)](https://arxiv.org/abs/2010.11929) paper,
+a Transformer-based architecture for vision typically requires a larger dataset than
+usual, as well as a longer pre-training schedule. [ImageNet-1k](http://imagenet.org/)
+(which has about a million images) is considered to fall under the medium-sized data regime with
+respect to ViTs. This is primarily because, unlike CNNs, ViTs (or a typical
+Transformer-based architecture) do not have well-informed inductive biases (such as
+convolutions for processing images). This begs the question: can't we combine the
+benefits of convolution and the benefits of Transformers
+in a single network architecture? These benefits include parameter-efficiency and
+self-attention to process long-range and global dependencies (interactions between
+different regions in an image).
+
+In [Escaping the Big Data Paradigm with Compact Transformers](https://arxiv.org/abs/2104.05704),
+Hassani et al. present an approach for doing exactly this. They proposed the
+**Compact Convolutional Transformer** (CCT) architecture. In this example, we will work on an
+implementation of CCT and we will see how well it performs on the CIFAR-10 dataset.
+
+If you are unfamiliar with the concept of self-attention or Transformers, you can read
+[this chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11/r-3/312)
+from François Chollet's book *Deep Learning with Python*. This example uses
+code snippets from another example,
+[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
+"""
+
+"""
+## Imports
+"""
+
+from keras import layers
+import keras
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+## Hyperparameters and constants
+"""
+
+positional_emb = True
+conv_layers = 2
+projection_dim = 128
+
+num_heads = 2
+transformer_units = [
+ projection_dim,
+ projection_dim,
+]
+transformer_layers = 2
+stochastic_depth_rate = 0.1
+
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 128
+num_epochs = 30
+image_size = 32
+
+"""
+## Load CIFAR-10 dataset
+"""
+
+num_classes = 10
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+"""
+## The CCT tokenizer
+
+The first recipe introduced by the CCT authors is the tokenizer for processing the
+images. In a standard ViT, images are organized into uniform *non-overlapping* patches.
+This eliminates the boundary-level information present between adjacent patches,
+information that a neural network needs in order to effectively exploit locality. The
+figure below presents an illustration of how images are organized into patches.
+
+
+
+We already know that convolutions are quite good at exploiting locality information. So,
+based on this, the authors introduce an all-convolution mini-network to produce image
+patches.
+"""
+
+
+class CCTTokenizer(layers.Layer):
+ def __init__(
+ self,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ pooling_kernel_size=3,
+ pooling_stride=2,
+ num_conv_layers=conv_layers,
+ num_output_channels=[64, 128],
+ positional_emb=positional_emb,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # This is our tokenizer.
+ self.conv_model = keras.Sequential()
+ for i in range(num_conv_layers):
+ self.conv_model.add(
+ layers.Conv2D(
+ num_output_channels[i],
+ kernel_size,
+ stride,
+ padding="valid",
+ use_bias=False,
+ activation="relu",
+ kernel_initializer="he_normal",
+ )
+ )
+ self.conv_model.add(layers.ZeroPadding2D(padding))
+ self.conv_model.add(
+ layers.MaxPooling2D(pooling_kernel_size, pooling_stride, "same")
+ )
+
+ self.positional_emb = positional_emb
+
+ def call(self, images):
+ outputs = self.conv_model(images)
+ # After passing the images through our mini-network the spatial dimensions
+ # are flattened to form sequences.
+ reshaped = keras.ops.reshape(
+ outputs,
+ (
+ -1,
+ keras.ops.shape(outputs)[1] * keras.ops.shape(outputs)[2],
+ keras.ops.shape(outputs)[-1],
+ ),
+ )
+ return reshaped
+
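+"""
+A quick shape check with illustrative inputs: for 32x32 RGB images, the two
+convolutional blocks above produce a sequence of patch embeddings rather than a
+2D feature map.
+"""
+
+demo_tokenizer = CCTTokenizer()
+print(demo_tokenizer(keras.ops.ones((1, 32, 32, 3))).shape)  # (1, 64, 128)
+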
+
+"""
+Positional embeddings are optional in CCT. If we want to use them, we can use
+the Layer defined below.
+"""
+
+
+class PositionEmbedding(keras.layers.Layer):
+ def __init__(
+ self,
+ sequence_length,
+ initializer="glorot_uniform",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if sequence_length is None:
+ raise ValueError("`sequence_length` must be an Integer, received `None`.")
+ self.sequence_length = int(sequence_length)
+ self.initializer = keras.initializers.get(initializer)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "sequence_length": self.sequence_length,
+ "initializer": keras.initializers.serialize(self.initializer),
+ }
+ )
+ return config
+
+ def build(self, input_shape):
+ feature_size = input_shape[-1]
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.sequence_length, feature_size],
+ initializer=self.initializer,
+ trainable=True,
+ )
+
+ super().build(input_shape)
+
+ def call(self, inputs, start_index=0):
+ shape = keras.ops.shape(inputs)
+ feature_length = shape[-1]
+ sequence_length = shape[-2]
+ # trim to match the length of the input sequence, which might be less
+ # than the sequence_length of the layer.
+ position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
+ position_embeddings = keras.ops.slice(
+ position_embeddings,
+ (start_index, 0),
+ (sequence_length, feature_length),
+ )
+ return keras.ops.broadcast_to(position_embeddings, shape)
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+"""
+## Sequence Pooling
+Another recipe introduced in CCT is attention pooling or sequence pooling. In ViT, only
+the feature map corresponding to the class token is pooled and is then used for the
+subsequent classification task (or any other downstream task).
+"""
+
+
+class SequencePooling(layers.Layer):
+ def __init__(self):
+ super().__init__()
+ self.attention = layers.Dense(1)
+
+ def call(self, x):
+ attention_weights = keras.ops.softmax(self.attention(x), axis=1)
+ attention_weights = keras.ops.transpose(attention_weights, axes=(0, 2, 1))
+ weighted_representation = keras.ops.matmul(attention_weights, x)
+ return keras.ops.squeeze(weighted_representation, -2)
+
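+"""
+A shape-level illustration: sequence pooling collapses a batch of token sequences into
+one vector per example via a learned weighted average over the tokens.
+"""
+
+demo_pool = SequencePooling()
+print(demo_pool(keras.ops.ones((4, 10, 128))).shape)  # (4, 128)
+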
+
+"""
+## Stochastic depth for regularization
+
+[Stochastic depth](https://arxiv.org/abs/1603.09382) is a regularization technique that
+randomly drops a set of layers. During inference, the layers are kept as they are. It is
+very similar to [Dropout](https://jmlr.org/papers/v15/srivastava14a.html), except
+that it operates on a block of layers rather than on individual nodes present inside a
+layer. In CCT, stochastic depth is used just before the residual blocks of a Transformer
+encoder.
+"""
+
+
+# Referred from: github.com:rwightman/pytorch-image-models.
+class StochasticDepth(layers.Layer):
+ def __init__(self, drop_prop, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_prob = drop_prop
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def call(self, x, training=None):
+ if training:
+ keep_prob = 1 - self.drop_prob
+ shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
+ random_tensor = keep_prob + keras.random.uniform(
+ shape, 0, 1, seed=self.seed_generator
+ )
+ random_tensor = keras.ops.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
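+
+
+"""
+Illustrative only: with a `drop_prob` of 0.5 and `training=True`, roughly half of the
+samples in a batch have their residual branch zeroed out, while the survivors are scaled
+up by `1 / keep_prob` so that the expected value is preserved:
+"""
+
+# Each row of the output is either all zeros (dropped) or the input scaled by 2.
+print(StochasticDepth(0.5)(keras.ops.ones((8, 1)), training=True))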
+
+
+"""
+## MLP for the Transformer encoder
+"""
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=keras.ops.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## Data augmentation
+
+In the [original paper](https://arxiv.org/abs/2104.05704), the authors use
+[AutoAugment](https://arxiv.org/abs/1805.09501) to induce stronger regularization. For
+this example, we will be using the standard geometric augmentations like random cropping
+and flipping.
+"""
+
+# Note the rescaling layer. These layers have pre-defined inference behavior.
+data_augmentation = keras.Sequential(
+ [
+ layers.Rescaling(scale=1.0 / 255),
+ layers.RandomCrop(image_size, image_size),
+ layers.RandomFlip("horizontal"),
+ ],
+ name="data_augmentation",
+)
+
+"""
+## The final CCT model
+
+In CCT, the outputs of the Transformer encoder are weighted by sequence pooling and then
+passed on to the final task-specific layer (in this example, we do classification).
+"""
+
+
+def create_cct_model(
+ image_size=image_size,
+ input_shape=input_shape,
+ num_heads=num_heads,
+ projection_dim=projection_dim,
+ transformer_units=transformer_units,
+):
+ inputs = layers.Input(input_shape)
+
+ # Augment data.
+ augmented = data_augmentation(inputs)
+
+ # Encode patches.
+ cct_tokenizer = CCTTokenizer()
+ encoded_patches = cct_tokenizer(augmented)
+
+ # Apply positional embedding.
+ if positional_emb:
+ sequence_length = encoded_patches.shape[1]
+ encoded_patches += PositionEmbedding(sequence_length=sequence_length)(
+ encoded_patches
+ )
+
+ # Calculate Stochastic Depth probabilities.
+    dpr = list(np.linspace(0, stochastic_depth_rate, transformer_layers))
+
+ # Create multiple layers of the Transformer block.
+ for i in range(transformer_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
+
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+
+ # Skip connection 1.
+ attention_output = StochasticDepth(dpr[i])(attention_output)
+ x2 = layers.Add()([attention_output, encoded_patches])
+
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-5)(x2)
+
+ # MLP.
+ x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
+
+ # Skip connection 2.
+ x3 = StochasticDepth(dpr[i])(x3)
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Apply sequence pooling.
+ representation = layers.LayerNormalization(epsilon=1e-5)(encoded_patches)
+ weighted_representation = SequencePooling()(representation)
+
+ # Classify outputs.
+ logits = layers.Dense(num_classes)(weighted_representation)
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=logits)
+ return model
+
+
+"""
+## Model training and evaluation
+"""
+
+
+def run_experiment(model):
+ optimizer = keras.optimizers.AdamW(learning_rate=0.001, weight_decay=0.0001)
+
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.CategoricalCrossentropy(
+ from_logits=True, label_smoothing=0.1
+ ),
+ metrics=[
+ keras.metrics.CategoricalAccuracy(name="accuracy"),
+ keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ checkpoint_filepath = "/tmp/checkpoint.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[checkpoint_callback],
+ )
+
+ model.load_weights(checkpoint_filepath)
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ return history
+
+
+cct_model = create_cct_model()
+history = run_experiment(cct_model)
+
+"""
+Let's now visualize the training progress of the model.
+"""
+
+plt.plot(history.history["loss"], label="train_loss")
+plt.plot(history.history["val_loss"], label="val_loss")
+plt.xlabel("Epochs")
+plt.ylabel("Loss")
+plt.title("Train and Validation Losses Over Epochs", fontsize=14)
+plt.legend()
+plt.grid()
+plt.show()
+
+"""
+The CCT model we just trained has only **0.4 million** parameters, and it gets us to
+~79% top-1 accuracy within 30 epochs. The plot above shows no signs of overfitting
+either, which means we can train this network for longer (perhaps with a bit more
+regularization) and obtain even better performance. This performance can be further
+improved by additional recipes like a cosine decay learning rate schedule (see the
+sketch below) or other data augmentation techniques such as
+[AutoAugment](https://arxiv.org/abs/1805.09501),
+[MixUp](https://arxiv.org/abs/1710.09412) or
+[CutMix](https://arxiv.org/abs/1905.04899). With these modifications, the authors report
+95.1% top-1 accuracy on the CIFAR-10 dataset. The authors also present a number of
+experiments to study how the number of convolution blocks, Transformer layers, etc.
+affect the final performance of CCTs.
+
+For comparison, a ViT model takes about **4.7 million** parameters and **100
+epochs** of training to reach a top-1 accuracy of 78.22% on the CIFAR-10 dataset. You can
+refer to
+[this notebook](https://colab.research.google.com/gist/sayakpaul/1a80d9f582b044354a1a26c5cb3d69e5/image_classification_with_vision_transformer.ipynb)
+to learn more about the experimental setup.
+
+The authors also demonstrate the performance of Compact Convolutional Transformers on
+NLP tasks, where they report competitive results.
+"""
diff --git a/knowledge_base/vision/consistency_training.py b/knowledge_base/vision/consistency_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..39ef2af899a61c3b160d3db1f81a83cadecfdc72
--- /dev/null
+++ b/knowledge_base/vision/consistency_training.py
@@ -0,0 +1,396 @@
+"""
+Title: Consistency training with supervision
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/04/13
+Last modified: 2021/04/19
+Description: Training with consistency regularization for robustness against data distribution shifts.
+Accelerator: GPU
+"""
+
+"""
+Deep learning models excel in many image recognition tasks when the data is independent
+and identically distributed (i.i.d.). However, they can suffer from performance
+degradation caused by subtle distribution shifts in the input data (such as random
+noise, contrast change, and blurring). So, naturally, the question arises: why? As
+discussed in [A Fourier Perspective on Model Robustness in Computer Vision](https://arxiv.org/pdf/1906.08988.pdf),
+there's no reason for deep learning models to be robust against such shifts. Standard
+model training procedures (such as standard image classification training workflows)
+*don't* enable a model to learn beyond what's fed to it in the form of training data.
+
+In this example, we will be training an image classification model enforcing a sense of
+*consistency* inside it by doing the following:
+
+* Train a standard image classification model.
+* Train an _equal or larger_ model on a noisy version of the dataset (augmented using
+[RandAugment](https://arxiv.org/abs/1909.13719)).
+* To do this, we will first obtain predictions of the previous model on the clean images
+of the dataset.
+* We will then use these predictions and train the second model to match these
+predictions on the noisy variant of the same images. This is identical to the workflow of
+[*Knowledge Distillation*](https://keras.io/examples/vision/knowledge_distillation/), but
+since the student model is equal or larger in size, this process is also referred to as
+***Self-Training***.
+
+This overall training workflow finds its roots in works like
+[FixMatch](https://arxiv.org/abs/2001.07685), [Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848),
+and [Noisy Student Training](https://arxiv.org/abs/1911.04252). Since this training
+process encourages a model to yield consistent predictions for clean as well as noisy
+images, it's often referred to as *consistency training* or *training with consistency
+regularization*. Although this example focuses on using consistency training to enhance
+the robustness of models to common corruptions, it can also serve as a template
+for performing _weakly supervised learning_.
+
+This example requires TensorFlow 2.4 or higher, as well as TensorFlow Models and
+TensorFlow Addons, which can be installed using the following command:
+
+"""
+
+"""shell
+pip install -q tf-models-official tensorflow-addons
+"""
+
+"""
+## Imports and setup
+"""
+
+from official.vision.image_classification.augment import RandAugment
+from tensorflow.keras import layers
+
+import tensorflow as tf
+import tensorflow_addons as tfa
+import matplotlib.pyplot as plt
+
+tf.random.set_seed(42)
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 128
+EPOCHS = 5
+
+CROP_TO = 72
+RESIZE_TO = 96
+
+"""
+## Load the CIFAR-10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
+
+# Reserve the last 500 training images for validation.
+val_samples = 49500
+new_train_x, new_y_train = x_train[:val_samples], y_train[:val_samples]
+val_x, val_y = x_train[val_samples:], y_train[val_samples:]
+
+"""
+## Create TensorFlow `Dataset` objects
+"""
+
+# Initialize `RandAugment` object with 2 layers of
+# augmentation transforms and strength of 9.
+augmenter = RandAugment(num_layers=2, magnitude=9)
+
+"""
+For training the teacher model, we will only be using two geometric augmentation
+transforms: random horizontal flip and random crop.
+"""
+
+
+def preprocess_train(image, label, noisy=True):
+ image = tf.image.random_flip_left_right(image)
+ # We first resize the original image to a larger dimension
+ # and then we take random crops from it.
+ image = tf.image.resize(image, [RESIZE_TO, RESIZE_TO])
+ image = tf.image.random_crop(image, [CROP_TO, CROP_TO, 3])
+ if noisy:
+ image = augmenter.distort(image)
+ return image, label
+
+
+def preprocess_test(image, label):
+ image = tf.image.resize(image, [CROP_TO, CROP_TO])
+ return image, label
+
+
+train_ds = tf.data.Dataset.from_tensor_slices((new_train_x, new_y_train))
+validation_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+
+"""
+We make sure `train_clean_ds` and `train_noisy_ds` are shuffled using the *same* seed to
+ensure their orders are exactly the same. This will be helpful when training the
+student model.
+"""
+
+# This dataset will be used to train the first model.
+train_clean_ds = (
+ train_ds.shuffle(BATCH_SIZE * 10, seed=42)
+ .map(lambda x, y: (preprocess_train(x, y, noisy=False)), num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# This prepares the `Dataset` object to use RandAugment.
+train_noisy_ds = (
+ train_ds.shuffle(BATCH_SIZE * 10, seed=42)
+ .map(preprocess_train, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+validation_ds = (
+ validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+test_ds = (
+ test_ds.map(preprocess_test, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# This dataset will be used to train the second model.
+consistency_training_ds = tf.data.Dataset.zip((train_clean_ds, train_noisy_ds))
+
+"""
+## Visualize the datasets
+"""
+
+sample_images, sample_labels = next(iter(train_clean_ds))
+plt.figure(figsize=(10, 10))
+for i, image in enumerate(sample_images[:9]):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().astype("int"))
+ plt.axis("off")
+
+sample_images, sample_labels = next(iter(train_noisy_ds))
+plt.figure(figsize=(10, 10))
+for i, image in enumerate(sample_images[:9]):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().astype("int"))
+ plt.axis("off")
+
+"""
+## Define a model building utility function
+
+We now define our model building utility. Our model is based on the [ResNet50V2 architecture](https://arxiv.org/abs/1603.05027).
+"""
+
+
+def get_training_model(num_classes=10):
+ resnet50_v2 = tf.keras.applications.ResNet50V2(
+ weights=None,
+ include_top=False,
+ input_shape=(CROP_TO, CROP_TO, 3),
+ )
+ model = tf.keras.Sequential(
+ [
+ layers.Input((CROP_TO, CROP_TO, 3)),
+ layers.Rescaling(scale=1.0 / 127.5, offset=-1),
+ resnet50_v2,
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(num_classes),
+ ]
+ )
+ return model
+
+
+"""
+In the interest of reproducibility, we serialize the initial random weights of the
+teacher network.
+"""
+
+initial_teacher_model = get_training_model()
+initial_teacher_model.save_weights("initial_teacher_model.h5")
+
+"""
+## Train the teacher model
+
+As noted in Noisy Student Training, training the teacher model with *geometric
+ensembling* and forcing the student model to mimic it leads to better
+performance. The original work uses [Stochastic Depth](https://arxiv.org/abs/1603.09382)
+and [Dropout](https://jmlr.org/papers/v15/srivastava14a.html) to bring in the ensembling
+part, but for this example we will use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407)
+(SWA), which also resembles geometric ensembling.
+"""
+
+# Define the callbacks.
+reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=3)
+early_stopping = tf.keras.callbacks.EarlyStopping(
+ patience=10, restore_best_weights=True
+)
+
+# Initialize SWA from TensorFlow Addons.
+SWA = tfa.optimizers.SWA
+
+# Compile and train the teacher model.
+teacher_model = get_training_model()
+teacher_model.load_weights("initial_teacher_model.h5")
+teacher_model.compile(
+ # Notice that we are wrapping our optimizer within SWA
+ optimizer=SWA(tf.keras.optimizers.Adam()),
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=["accuracy"],
+)
+history = teacher_model.fit(
+ train_clean_ds,
+ epochs=EPOCHS,
+ validation_data=validation_ds,
+ callbacks=[reduce_lr, early_stopping],
+)
+
+# Evaluate the teacher model on the test set.
+_, acc = teacher_model.evaluate(test_ds, verbose=0)
+print(f"Test accuracy: {acc*100}%")
+
+"""
+## Define a self-training utility
+
+For this part, we will borrow the `Distiller` class from [this Keras Example](https://keras.io/examples/vision/knowledge_distillation/).
+"""
+
+
+# Majority of the code is taken from:
+# https://keras.io/examples/vision/knowledge_distillation/
+class SelfTrainer(tf.keras.Model):
+ def __init__(self, student, teacher):
+ super().__init__()
+ self.student = student
+ self.teacher = teacher
+
+ def compile(
+ self,
+ optimizer,
+ metrics,
+ student_loss_fn,
+ distillation_loss_fn,
+ temperature=3,
+ ):
+ super().compile(optimizer=optimizer, metrics=metrics)
+ self.student_loss_fn = student_loss_fn
+ self.distillation_loss_fn = distillation_loss_fn
+ self.temperature = temperature
+
+ def train_step(self, data):
+ # Since our dataset is a zip of two independent datasets,
+ # after initially parsing them, we segregate the
+ # respective images and labels next.
+ clean_ds, noisy_ds = data
+ clean_images, _ = clean_ds
+ noisy_images, y = noisy_ds
+
+ # Forward pass of teacher
+ teacher_predictions = self.teacher(clean_images, training=False)
+
+ with tf.GradientTape() as tape:
+ # Forward pass of student
+ student_predictions = self.student(noisy_images, training=True)
+
+ # Compute losses
+ student_loss = self.student_loss_fn(y, student_predictions)
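+            # Dividing the logits by `temperature` (> 1) softens both
+            # distributions, exposing the teacher's relative class preferences
+            # beyond the argmax.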
+ distillation_loss = self.distillation_loss_fn(
+ tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
+ tf.nn.softmax(student_predictions / self.temperature, axis=1),
+ )
+ total_loss = (student_loss + distillation_loss) / 2
+
+ # Compute gradients
+ trainable_vars = self.student.trainable_variables
+ gradients = tape.gradient(total_loss, trainable_vars)
+
+ # Update weights
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+ # Update the metrics configured in `compile()`
+ self.compiled_metrics.update_state(
+ y, tf.nn.softmax(student_predictions, axis=1)
+ )
+
+ # Return a dict of performance
+ results = {m.name: m.result() for m in self.metrics}
+ results.update({"total_loss": total_loss})
+ return results
+
+ def test_step(self, data):
+        # During inference, we only pass a dataset consisting of images and labels.
+ x, y = data
+
+ # Compute predictions
+ y_prediction = self.student(x, training=False)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(y, tf.nn.softmax(y_prediction, axis=1))
+
+ # Return a dict of performance
+ results = {m.name: m.result() for m in self.metrics}
+ return results
+
+
+"""
+The only difference in this implementation is the way the loss is calculated. **Instead
+of weighting the distillation loss and the student loss differently, we take their
+average, following Noisy Student Training**.
+"""
+
+"""
+## Train the student model
+"""
+
+# Define the callbacks.
+# We are using a larger decay factor to stabilize the training.
+reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
+ patience=3, factor=0.5, monitor="val_accuracy"
+)
+early_stopping = tf.keras.callbacks.EarlyStopping(
+ patience=10, restore_best_weights=True, monitor="val_accuracy"
+)
+
+# Compile and train the student model.
+self_trainer = SelfTrainer(student=get_training_model(), teacher=teacher_model)
+self_trainer.compile(
+ # Notice we are *not* using SWA here.
+ optimizer="adam",
+ metrics=["accuracy"],
+ student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ distillation_loss_fn=tf.keras.losses.KLDivergence(),
+ temperature=10,
+)
+history = self_trainer.fit(
+ consistency_training_ds,
+ epochs=EPOCHS,
+ validation_data=validation_ds,
+ callbacks=[reduce_lr, early_stopping],
+)
+
+# Evaluate the student model.
+acc = self_trainer.evaluate(test_ds, verbose=0)
+print(f"Test accuracy from student model: {acc*100}%")
+
+"""
+## Assess the robustness of the models
+
+A standard benchmark of assessing the robustness of vision models is to record their
+performance on corrupted datasets like ImageNet-C and CIFAR-10-C both of which were
+proposed in [Benchmarking Neural Network Robustness to Common Corruptions and
+Perturbations](https://arxiv.org/abs/1903.12261). For this example, we will be using the
+CIFAR-10-C dataset, which has 19 different corruptions at 5 different severity levels. To
+assess the robustness of the models on this dataset, we will do the following:
+
+* Run the pre-trained models at the highest severity level and obtain the top-1
+accuracies.
+* Compute the mean top-1 accuracy (see the sketch at the end of this example).
+
+For the purpose of this example, we won't be going through these steps. This is why we
+trained the models for only 5 epochs. You can check out [this
+repository](https://github.com/sayakpaul/Consistency-Training-with-Supervision) that
+demonstrates the full-scale training experiments and also the aforementioned assessment.
+The figure below presents an executive summary of that assessment:
+
+
+
+**Mean Top-1** results stand for the CIFAR-10-C dataset and **Test Top-1** results stand
+for the CIFAR-10 test set. It's clear that consistency training not only enhances model
+robustness but also improves standard test performance.
+"""
diff --git a/knowledge_base/vision/conv_lstm.py b/knowledge_base/vision/conv_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b17725dea97a36be1ffe04bcfd55c4de4d2d15bc
--- /dev/null
+++ b/knowledge_base/vision/conv_lstm.py
@@ -0,0 +1,294 @@
+"""
+Title: Next-Frame Video Prediction with Convolutional LSTMs
+Author: [Amogh Joshi](https://github.com/amogh7joshi)
+Date created: 2021/06/02
+Last modified: 2023/11/10
+Description: How to build and train a convolutional LSTM model for next-frame video prediction.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The
+[Convolutional LSTM](https://papers.nips.cc/paper/2015/file/07563a3fe3bbe7e3ba84431ad9d055af-Paper.pdf)
+architecture brings together time series processing and computer vision by
+introducing a convolutional recurrent cell in an LSTM layer. In this example, we will explore the
+Convolutional LSTM model in an application to next-frame prediction, the process
+of predicting what video frames come next given a series of past frames.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import keras
+from keras import layers
+
+import io
+import imageio
+from IPython.display import Image, display
+from ipywidgets import widgets, Layout, HBox
+
+"""
+## Dataset Construction
+
+For this example, we will be using the
+[Moving MNIST](http://www.cs.toronto.edu/~nitish/unsupervised_video/)
+dataset.
+
+We will download the dataset and then construct and
+preprocess training and validation sets.
+
+For next-frame prediction, our model will be using a previous frame,
+which we'll call `f_n`, to predict a new frame, called `f_(n + 1)`.
+To allow the model to create these predictions, we'll need to process
+the data such that we have "shifted" inputs and outputs, where the
+input frames `f_0, ..., f_(n - 1)` are used to predict the target frames
+`f_1, ..., f_n`.
+"""
+
+# Download and load the dataset.
+fpath = keras.utils.get_file(
+ "moving_mnist.npy",
+ "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
+)
+dataset = np.load(fpath)
+
+# Swap the axes representing the number of frames and number of data samples.
+dataset = np.swapaxes(dataset, 0, 1)
+# We'll pick out 1000 of the 10000 total examples and use those.
+dataset = dataset[:1000, ...]
+# Add a channel dimension since the images are grayscale.
+dataset = np.expand_dims(dataset, axis=-1)
+
+# Split into train and validation sets using indexing to optimize memory.
+indexes = np.arange(dataset.shape[0])
+np.random.shuffle(indexes)
+train_index = indexes[: int(0.9 * dataset.shape[0])]
+val_index = indexes[int(0.9 * dataset.shape[0]) :]
+train_dataset = dataset[train_index]
+val_dataset = dataset[val_index]
+
+# Normalize the data to the 0-1 range.
+train_dataset = train_dataset / 255
+val_dataset = val_dataset / 255
+
+
+# We'll define a helper function to shift the frames, where
+# `x` is frames 0 to n - 1, and `y` is frames 1 to n.
+def create_shifted_frames(data):
+ x = data[:, 0 : data.shape[1] - 1, :, :]
+ y = data[:, 1 : data.shape[1], :, :]
+ return x, y
+
+
+# Apply the processing function to the datasets.
+x_train, y_train = create_shifted_frames(train_dataset)
+x_val, y_val = create_shifted_frames(val_dataset)
+
+# Inspect the dataset.
+print("Training Dataset Shapes: " + str(x_train.shape) + ", " + str(y_train.shape))
+print("Validation Dataset Shapes: " + str(x_val.shape) + ", " + str(y_val.shape))
+
+"""
+## Data Visualization
+
+Our data consists of sequences of frames, each of which
+are used to predict the upcoming frame. Let's take a look
+at some of these sequential frames.
+"""
+
+# Construct a figure on which we will visualize the images.
+fig, axes = plt.subplots(4, 5, figsize=(10, 8))
+
+# Plot each of the sequential images for one random data example.
+data_choice = np.random.choice(range(len(train_dataset)), size=1)[0]
+for idx, ax in enumerate(axes.flat):
+ ax.imshow(np.squeeze(train_dataset[data_choice][idx]), cmap="gray")
+ ax.set_title(f"Frame {idx + 1}")
+ ax.axis("off")
+
+# Print information and display the figure.
+print(f"Displaying frames for example {data_choice}.")
+plt.show()
+
+"""
+## Model Construction
+
+To build a Convolutional LSTM model, we will use the
+`ConvLSTM2D` layer, which accepts inputs of shape
+`(batch_size, num_frames, width, height, channels)`. A final `Conv3D` layer
+with a single sigmoid filter then maps the features back to a predicted
+movie of the same shape as the input.
+"""
+
+# Construct the input layer with no definite frame size.
+inp = layers.Input(shape=(None, *x_train.shape[2:]))
+
+# We will construct 3 `ConvLSTM2D` layers with batch normalization,
+# followed by a `Conv3D` layer for the spatiotemporal outputs.
+x = layers.ConvLSTM2D(
+ filters=64,
+ kernel_size=(5, 5),
+ padding="same",
+ return_sequences=True,
+ activation="relu",
+)(inp)
+x = layers.BatchNormalization()(x)
+x = layers.ConvLSTM2D(
+ filters=64,
+ kernel_size=(3, 3),
+ padding="same",
+ return_sequences=True,
+ activation="relu",
+)(x)
+x = layers.BatchNormalization()(x)
+x = layers.ConvLSTM2D(
+ filters=64,
+ kernel_size=(1, 1),
+ padding="same",
+ return_sequences=True,
+ activation="relu",
+)(x)
+x = layers.Conv3D(
+ filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
+)(x)
+
+# Next, we will build the complete model and compile it.
+model = keras.models.Model(inp, x)
+model.compile(
+ loss=keras.losses.binary_crossentropy,
+ optimizer=keras.optimizers.Adam(),
+)
+
+"""
+## Model Training
+
+With our model and data constructed, we can now train the model.
+"""
+
+# Define some callbacks to improve training.
+early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
+reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)
+
+# Define modifiable training hyperparameters.
+epochs = 20
+batch_size = 5
+
+# Fit the model to the training data.
+model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=epochs,
+ validation_data=(x_val, y_val),
+ callbacks=[early_stopping, reduce_lr],
+)
+
+"""
+## Frame Prediction Visualizations
+
+With our model now constructed and trained, we can generate
+some example frame predictions based on a new video.
+
+We'll pick a random example from the validation set and
+then choose the first ten frames from it. From there, we can
+allow the model to predict 10 new frames, which we can compare
+to the ground truth frames.
+"""
+
+# Select a random example from the validation dataset.
+example = val_dataset[np.random.choice(range(len(val_dataset)), size=1)[0]]
+
+# Pick the first/last ten frames from the example.
+frames = example[:10, ...]
+original_frames = example[10:, ...]
+
+# Predict a new set of 10 frames.
+for _ in range(10):
+ # Extract the model's prediction and post-process it.
+ new_prediction = model.predict(np.expand_dims(frames, axis=0))
+ new_prediction = np.squeeze(new_prediction, axis=0)
+ predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)
+
+ # Extend the set of prediction frames.
+ frames = np.concatenate((frames, predicted_frame), axis=0)
+
+# Construct a figure for the original and new frames.
+fig, axes = plt.subplots(2, 10, figsize=(20, 4))
+
+# Plot the original frames.
+for idx, ax in enumerate(axes[0]):
+ ax.imshow(np.squeeze(original_frames[idx]), cmap="gray")
+ ax.set_title(f"Frame {idx + 11}")
+ ax.axis("off")
+
+# Plot the new frames.
+new_frames = frames[10:, ...]
+for idx, ax in enumerate(axes[1]):
+ ax.imshow(np.squeeze(new_frames[idx]), cmap="gray")
+ ax.set_title(f"Frame {idx + 11}")
+ ax.axis("off")
+
+# Display the figure.
+plt.show()
+
+"""
+## Predicted Videos
+
+Finally, we'll pick a few examples from the validation set
+and construct some GIFs with them to see the model's
+predicted videos.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conv-lstm).
+"""
+
+# Select a few random examples from the dataset.
+examples = val_dataset[np.random.choice(range(len(val_dataset)), size=5)]
+
+# Iterate over the examples and predict the frames.
+predicted_videos = []
+for example in examples:
+ # Pick the first/last ten frames from the example.
+ frames = example[:10, ...]
+ original_frames = example[10:, ...]
+ new_predictions = np.zeros(shape=(10, *frames[0].shape))
+
+ # Predict a new set of 10 frames.
+ for i in range(10):
+ # Extract the model's prediction and post-process it.
+ frames = example[: 10 + i + 1, ...]
+ new_prediction = model.predict(np.expand_dims(frames, axis=0))
+ new_prediction = np.squeeze(new_prediction, axis=0)
+ predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)
+
+ # Extend the set of prediction frames.
+ new_predictions[i] = predicted_frame
+
+ # Create and save GIFs for each of the ground truth/prediction images.
+ for frame_set in [original_frames, new_predictions]:
+ # Construct a GIF from the selected video frames.
+ current_frames = np.squeeze(frame_set)
+ current_frames = current_frames[..., np.newaxis] * np.ones(3)
+ current_frames = (current_frames * 255).astype(np.uint8)
+ current_frames = list(current_frames)
+
+ # Construct a GIF from the frames.
+ with io.BytesIO() as gif:
+ imageio.mimsave(gif, current_frames, "GIF", duration=200)
+ predicted_videos.append(gif.getvalue())
+
+# Display the videos.
+print(" Truth\tPrediction")
+for i in range(0, len(predicted_videos), 2):
+ # Construct and display an `HBox` with the ground truth and prediction.
+ box = HBox(
+ [
+ widgets.Image(value=predicted_videos[i]),
+ widgets.Image(value=predicted_videos[i + 1]),
+ ]
+ )
+ display(box)
diff --git a/knowledge_base/vision/convmixer.py b/knowledge_base/vision/convmixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78f701a7f33bda80a6fbf075d49c6c0861a1734
--- /dev/null
+++ b/knowledge_base/vision/convmixer.py
@@ -0,0 +1,318 @@
+"""
+Title: Image classification with ConvMixer
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/10/12
+Last modified: 2021/10/12
+Description: An all-convolutional network applied to patches of images.
+Accelerator: GPU
+Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)
+"""
+
+"""
+## Introduction
+
+Vision Transformers (ViT; [Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) extract
+small patches from the input images, linearly project them, and then apply the
+Transformer ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) blocks. The application
+of ViTs to image recognition tasks is quickly becoming a promising area of research,
+because ViTs eliminate the need to have strong inductive biases (such as convolutions) for
+modeling locality. This presents them as a general computation primitive capable of
+learning just from the training data with as minimal inductive priors as possible. ViTs
+yield great downstream performance when trained with proper regularization, data
+augmentation, and relatively large datasets.
+
+In the [Patches Are All You Need](https://openreview.net/pdf?id=TVHS5Y4dNvM) paper (note:
+at the time of writing, it is a submission to the ICLR 2022 conference), the authors
+extend the idea of using patches to train an all-convolutional network and demonstrate
+competitive results. Their architecture, namely **ConvMixer**, uses recipes from recent
+isotropic architectures like ViT and MLP-Mixer
+([Tolstikhin et al.](https://arxiv.org/abs/2105.01601)), such as using the same
+depth and resolution across different layers in the network, residual connections,
+and so on.
+
+In this example, we will implement the ConvMixer model and demonstrate its performance on
+the CIFAR-10 dataset.
+"""
+
+"""
+## Imports
+"""
+
+import keras
+from keras import layers
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import numpy as np
+
+"""
+## Hyperparameters
+
+To keep run time short, we will train the model for only 10 epochs. To focus on
+the core ideas of ConvMixer, we will not use other training-specific elements like
+RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)). If you are interested in
+learning more about those details, please refer to the
+[original paper](https://openreview.net/pdf?id=TVHS5Y4dNvM).
+"""
+
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 128
+num_epochs = 10
+
+"""
+## Load the CIFAR-10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+val_split = 0.1
+
+val_indices = int(len(x_train) * val_split)
+new_x_train, new_y_train = x_train[val_indices:], y_train[val_indices:]
+x_val, y_val = x_train[:val_indices], y_train[:val_indices]
+
+print(f"Training data samples: {len(new_x_train)}")
+print(f"Validation data samples: {len(x_val)}")
+print(f"Test data samples: {len(x_test)}")
+
+"""
+## Prepare `tf.data.Dataset` objects
+
+Our data augmentation pipeline is different from what the authors used for the CIFAR-10
+dataset, which is fine for the purposes of this example.
+Note that it's OK to use **TF APIs for data I/O and preprocessing** with other backends
+(JAX, PyTorch), as `tf.data` is a feature-complete framework for data preprocessing.
+"""
+
+image_size = 32
+auto = tf.data.AUTOTUNE
+
+augmentation_layers = [
+ keras.layers.RandomCrop(image_size, image_size),
+ keras.layers.RandomFlip("horizontal"),
+]
+
+
+def augment_images(images):
+ for layer in augmentation_layers:
+ images = layer(images, training=True)
+ return images
+
+
+def make_datasets(images, labels, is_train=False):
+ dataset = tf.data.Dataset.from_tensor_slices((images, labels))
+ if is_train:
+ dataset = dataset.shuffle(batch_size * 10)
+ dataset = dataset.batch(batch_size)
+ if is_train:
+ dataset = dataset.map(
+ lambda x, y: (augment_images(x), y), num_parallel_calls=auto
+ )
+ return dataset.prefetch(auto)
+
+
+train_dataset = make_datasets(new_x_train, new_y_train, is_train=True)
+val_dataset = make_datasets(x_val, y_val)
+test_dataset = make_datasets(x_test, y_test)
+
+"""
+## ConvMixer utilities
+
+The following figure (taken from the original paper) depicts the ConvMixer model:
+
+
+
+ConvMixer is very similar to the MLP-Mixer model, with the following key
+differences:
+
+* Instead of using fully-connected layers, it uses standard convolution layers.
+* Instead of LayerNorm (which is typical for ViTs and MLP-Mixers), it uses BatchNorm.
+
+Two types of convolution layers are used in ConvMixer: **(1)** depthwise convolutions,
+for mixing spatial locations of the images, and **(2)** pointwise convolutions (which follow
+the depthwise convolutions), for mixing channel-wise information across the patches.
+Another key point is the use of *larger kernel sizes* to allow a larger receptive field.
+"""
+
+
+def activation_block(x):
+ x = layers.Activation("gelu")(x)
+ return layers.BatchNormalization()(x)
+
+
+def conv_stem(x, filters: int, patch_size: int):
+ x = layers.Conv2D(filters, kernel_size=patch_size, strides=patch_size)(x)
+ return activation_block(x)
+
+
+def conv_mixer_block(x, filters: int, kernel_size: int):
+ # Depthwise convolution.
+ x0 = x
+ x = layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
+ x = layers.Add()([activation_block(x), x0]) # Residual.
+
+ # Pointwise convolution.
+ x = layers.Conv2D(filters, kernel_size=1)(x)
+ x = activation_block(x)
+
+ return x
+
+
+def get_conv_mixer_256_8(
+ image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10
+):
+ """ConvMixer-256/8: https://openreview.net/pdf?id=TVHS5Y4dNvM.
+ The hyperparameter values are taken from the paper.
+ """
+ inputs = keras.Input((image_size, image_size, 3))
+ x = layers.Rescaling(scale=1.0 / 255)(inputs)
+
+ # Extract patch embeddings.
+ x = conv_stem(x, filters, patch_size)
+
+ # ConvMixer blocks.
+ for _ in range(depth):
+ x = conv_mixer_block(x, filters, kernel_size)
+
+ # Classification block.
+ x = layers.GlobalAvgPool2D()(x)
+ outputs = layers.Dense(num_classes, activation="softmax")(x)
+
+ return keras.Model(inputs, outputs)
+
+
+"""
+The model used in this experiment is termed **ConvMixer-256/8**, where 256 denotes the
+number of channels and 8 denotes the depth. The resulting model only has 0.8 million
+parameters.
+"""
+
+"""
+## Model training and evaluation utility
+"""
+
+# Code reference:
+# https://keras.io/examples/vision/image_classification_with_vision_transformer/.
+
+
+def run_experiment(model):
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ )
+
+ model.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ )
+
+ checkpoint_filepath = "/tmp/checkpoint.keras"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=False,
+ )
+
+ history = model.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=num_epochs,
+ callbacks=[checkpoint_callback],
+ )
+
+ model.load_weights(checkpoint_filepath)
+ _, accuracy = model.evaluate(test_dataset)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+ return history, model
+
+
+"""
+## Train and evaluate model
+"""
+
+conv_mixer_model = get_conv_mixer_256_8()
+history, conv_mixer_model = run_experiment(conv_mixer_model)
+
+"""
+The gap in training and validation performance can be mitigated by using additional
+regularization techniques. Nevertheless, being able to get to ~83% accuracy within 10
+epochs with 0.8 million parameters is a strong result.
+"""
+
+"""
+## Visualizing the internals of ConvMixer
+
+We can visualize the patch embeddings and the learned convolution filters. Recall
+that each patch embedding and intermediate feature map have the same number of channels
+(256 in this case). This will make our visualization utility easier to implement.
+"""
+
+# Code reference: https://bit.ly/3awIRbP.
+
+
+def visualization_plot(weights, idx=1):
+    # First, apply min-max normalization to the
+    # given weights to avoid isotropic scaling.
+ p_min, p_max = weights.min(), weights.max()
+ weights = (weights - p_min) / (p_max - p_min)
+
+ # Visualize all the filters.
+ num_filters = 256
+ plt.figure(figsize=(8, 8))
+
+ for i in range(num_filters):
+ current_weight = weights[:, :, :, i]
+ if current_weight.shape[-1] == 1:
+ current_weight = current_weight.squeeze()
+ ax = plt.subplot(16, 16, idx)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ plt.imshow(current_weight)
+ idx += 1
+
+
+# We first visualize the learned patch embeddings.
+patch_embeddings = conv_mixer_model.layers[2].get_weights()[0]
+visualization_plot(patch_embeddings)
+
+"""
+Even though we did not train the network to convergence, we can notice that different
+patches show different patterns. Some share similarity with others while some are very
+different. These visualizations are more salient with larger image sizes.
+
+Similarly, we can visualize the raw convolution kernels. This can help us understand
+the patterns to which a given kernel is receptive.
+"""
+
+# First, print the indices of the convolution layers that are not
+# pointwise convolutions.
+for i, layer in enumerate(conv_mixer_model.layers):
+ if isinstance(layer, layers.DepthwiseConv2D):
+ if layer.get_config()["kernel_size"] == (5, 5):
+ print(i, layer)
+
+idx = 26 # Taking a kernel from the middle of the network.
+
+kernel = conv_mixer_model.layers[idx].get_weights()[0]
+kernel = np.expand_dims(kernel.squeeze(), axis=2)
+visualization_plot(kernel)
+
+"""
+We see that different filters in the kernel have different locality spans, and this
+pattern is likely to evolve with more training.
+"""
+
+"""
+## Final notes
+
+There's been a recent trend of fusing convolutions with other data-agnostic operations
+like self-attention. The following works are along this line of research:
+
+* ConViT ([d'Ascoli et al.](https://arxiv.org/abs/2103.10697))
+* CCT ([Hassani et al.](https://arxiv.org/abs/2104.05704))
+* CoAtNet ([Dai et al.](https://arxiv.org/abs/2106.04803))
+"""
diff --git a/knowledge_base/vision/cutmix.py b/knowledge_base/vision/cutmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a27c0c7a997ff8c2a986c8bd5b4209a691b5080
--- /dev/null
+++ b/knowledge_base/vision/cutmix.py
@@ -0,0 +1,387 @@
+"""
+Title: CutMix data augmentation for image classification
+Author: [Sayan Nath](https://twitter.com/sayannath2350)
+Date created: 2021/06/08
+Last modified: 2023/11/14
+Description: Data augmentation with CutMix for image classification on CIFAR-10.
+Accelerator: GPU
+Converted to Keras 3 by: [Piyush Thakur](https://github.com/cosmo3769)
+"""
+
+"""
+## Introduction
+"""
+
+"""
+_CutMix_ is a data augmentation technique that addresses the issue of information loss
+and inefficiency present in regional dropout strategies.
+Instead of removing pixels and filling them with black or grey pixels or Gaussian noise,
+you replace the removed regions with a patch from another image,
+while the ground truth labels are mixed proportionally to the number of pixels of combined images.
+CutMix was proposed in
+[CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features](https://arxiv.org/abs/1905.04899)
+(Yun et al., 2019).
+
+It's implemented via the following formulas:
+
+$\tilde{x} = M \odot x_A + (1 - M) \odot x_B$
+
+$\tilde{y} = \lambda y_A + (1 - \lambda) y_B$
+
+where `M` is the binary mask which indicates the cutout and the fill-in
+regions from the two randomly drawn images and `λ` (in `[0, 1]`) is drawn from a
+[`Beta(α, α)` distribution](https://en.wikipedia.org/wiki/Beta_distribution).
+
+The coordinates of the bounding box are `B = (r_x, r_y, r_w, r_h)`,
+which indicates the cutout and fill-in regions of the two images.
+The bounding box sampling is represented by:
+
+$r_x \sim \mathrm{Unif}(0, W), \qquad r_w = W\sqrt{1 - \lambda}$
+
+$r_y \sim \mathrm{Unif}(0, H), \qquad r_h = H\sqrt{1 - \lambda}$
+
+where `r_x, r_y` are drawn from uniform distributions with upper bounds `W` and `H`
+(the image width and height). The cut patch therefore covers a `1 - λ` fraction of the
+image area; for example, with `λ = 0.75` the patch pasted from the second image covers a
+quarter of the pixels, and the labels are mixed in a `0.75 : 0.25` ratio.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+import matplotlib.pyplot as plt
+
+from keras import layers
+
+# TF imports related to tf.data preprocessing
+from tensorflow import clip_by_value
+from tensorflow import data as tf_data
+from tensorflow import image as tf_image
+from tensorflow import random as tf_random
+
+keras.utils.set_random_seed(42)
+
+"""
+## Load the CIFAR-10 dataset
+
+In this example, we will use the
+[CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+y_train = keras.utils.to_categorical(y_train, num_classes=10)
+y_test = keras.utils.to_categorical(y_test, num_classes=10)
+
+print(x_train.shape)
+print(y_train.shape)
+print(x_test.shape)
+print(y_test.shape)
+
+class_names = [
+ "Airplane",
+ "Automobile",
+ "Bird",
+ "Cat",
+ "Deer",
+ "Dog",
+ "Frog",
+ "Horse",
+ "Ship",
+ "Truck",
+]
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf_data.AUTOTUNE
+BATCH_SIZE = 32
+IMG_SIZE = 32
+
+"""
+## Define the image preprocessing function
+"""
+
+
+def preprocess_image(image, label):
+ image = tf_image.resize(image, (IMG_SIZE, IMG_SIZE))
+ image = tf_image.convert_image_dtype(image, "float32") / 255.0
+ label = keras.ops.cast(label, dtype="float32")
+ return image, label
+
+
+"""
+## Convert the data into TensorFlow `Dataset` objects
+"""
+
+train_ds_one = (
+ tf_data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(1024)
+ .map(preprocess_image, num_parallel_calls=AUTO)
+)
+train_ds_two = (
+ tf_data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(1024)
+ .map(preprocess_image, num_parallel_calls=AUTO)
+)
+
+train_ds_simple = tf_data.Dataset.from_tensor_slices((x_train, y_train))
+
+test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
+
+train_ds_simple = (
+ train_ds_simple.map(preprocess_image, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# Combine two shuffled datasets from the same training data.
+train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))
+
+test_ds = (
+ test_ds.map(preprocess_image, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+"""
+## Define the CutMix data augmentation function
+
+The CutMix function takes two `image` and `label` pairs to perform the augmentation.
+It samples `λ` from the [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
+and gets a bounding box from the `get_box` function. We then crop a patch from the second
+image (`image2`), zero out the corresponding region of the first image (`image1`), and
+paste the patch there, so the two crop/pad results can simply be added together.
+"""
+
+
+def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
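+    # A Beta(a, b) sample can be drawn as X / (X + Y), where X ~ Gamma(a, 1)
+    # and Y ~ Gamma(b, 1); this lets us sample from the Beta distribution
+    # using only TensorFlow's Gamma sampler.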
+ gamma_1_sample = tf_random.gamma(shape=[size], alpha=concentration_1)
+ gamma_2_sample = tf_random.gamma(shape=[size], alpha=concentration_0)
+ return gamma_1_sample / (gamma_1_sample + gamma_2_sample)
+
+
+def get_box(lambda_value):
+ cut_rat = keras.ops.sqrt(1.0 - lambda_value)
+
+ cut_w = IMG_SIZE * cut_rat # rw
+ cut_w = keras.ops.cast(cut_w, "int32")
+
+ cut_h = IMG_SIZE * cut_rat # rh
+ cut_h = keras.ops.cast(cut_h, "int32")
+
+ cut_x = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE) # rx
+ cut_x = keras.ops.cast(cut_x, "int32")
+ cut_y = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE) # ry
+ cut_y = keras.ops.cast(cut_y, "int32")
+
+ boundaryx1 = clip_by_value(cut_x[0] - cut_w // 2, 0, IMG_SIZE)
+ boundaryy1 = clip_by_value(cut_y[0] - cut_h // 2, 0, IMG_SIZE)
+ bbx2 = clip_by_value(cut_x[0] + cut_w // 2, 0, IMG_SIZE)
+ bby2 = clip_by_value(cut_y[0] + cut_h // 2, 0, IMG_SIZE)
+
+ target_h = bby2 - boundaryy1
+ if target_h == 0:
+ target_h += 1
+
+ target_w = bbx2 - boundaryx1
+ if target_w == 0:
+ target_w += 1
+
+ return boundaryx1, boundaryy1, target_h, target_w
+
+
+def cutmix(train_ds_one, train_ds_two):
+ (image1, label1), (image2, label2) = train_ds_one, train_ds_two
+
+ alpha = [0.25]
+ beta = [0.25]
+
+ # Get a sample from the Beta distribution
+ lambda_value = sample_beta_distribution(1, alpha, beta)
+
+ # Define Lambda
+ lambda_value = lambda_value[0][0]
+
+ # Get the bounding box offsets, heights and widths
+ boundaryx1, boundaryy1, target_h, target_w = get_box(lambda_value)
+
+ # Get a patch from the second image (`image2`)
+ crop2 = tf_image.crop_to_bounding_box(
+ image2, boundaryy1, boundaryx1, target_h, target_w
+ )
+ # Pad the `image2` patch (`crop2`) with the same offset
+ image2 = tf_image.pad_to_bounding_box(
+ crop2, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
+ )
+ # Get a patch from the first image (`image1`)
+ crop1 = tf_image.crop_to_bounding_box(
+ image1, boundaryy1, boundaryx1, target_h, target_w
+ )
+ # Pad the `image1` patch (`crop1`) with the same offset
+ img1 = tf_image.pad_to_bounding_box(
+ crop1, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
+ )
+
+ # Modify the first image by subtracting the patch from `image1`
+ # (before applying the `image2` patch)
+ image1 = image1 - img1
+ # Add the modified `image1` and `image2` together to get the CutMix image
+ image = image1 + image2
+
+    # Adjust lambda in accordance with the pixel ratio
+ lambda_value = 1 - (target_w * target_h) / (IMG_SIZE * IMG_SIZE)
+ lambda_value = keras.ops.cast(lambda_value, "float32")
+
+ # Combine the labels of both images
+ label = lambda_value * label1 + (1 - lambda_value) * label2
+ return image, label
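+
+
+"""
+To connect the crop-and-pad arithmetic above back to the mask formulation, here is a
+tiny illustrative NumPy check (hypothetical 4x4 single-channel "images"; this is an
+addition to the original recipe):
+"""
+
+demo_a = np.ones((4, 4))  # stands in for `image1`
+demo_b = np.zeros((4, 4))  # stands in for `image2`
+demo_mask = np.ones((4, 4))
+demo_mask[1:3, 1:3] = 0.0  # cut out a 2x2 region
+# Mixing with the binary mask reproduces the subtract-and-add trick above.
+demo_mixed = demo_mask * demo_a + (1 - demo_mask) * demo_b
+# The adjusted lambda is the fraction of pixels kept from the first image.
+demo_lambda = 1 - (2 * 2) / (4 * 4)
+print(demo_mixed)
+print("Adjusted lambda:", demo_lambda)  # 0.75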
+
+
+"""
+**Note**: we are combining two images to create a single one.
+
+## Visualize the new dataset after applying the CutMix augmentation
+"""
+
+# Create the new dataset using our `cutmix` utility
+train_ds_cmu = (
+ train_ds.shuffle(1024)
+ .map(cutmix, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# Let's preview 9 samples from the dataset
+image_batch, label_batch = next(iter(train_ds_cmu))
+plt.figure(figsize=(10, 10))
+for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.title(class_names[np.argmax(label_batch[i])])
+ plt.imshow(image_batch[i])
+ plt.axis("off")
+
+"""
+## Define a ResNet-20 model
+"""
+
+
+def resnet_layer(
+ inputs,
+ num_filters=16,
+ kernel_size=3,
+ strides=1,
+ activation="relu",
+ batch_normalization=True,
+ conv_first=True,
+):
+ conv = layers.Conv2D(
+ num_filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding="same",
+ kernel_initializer="he_normal",
+ kernel_regularizer=keras.regularizers.L2(1e-4),
+ )
+ x = inputs
+ if conv_first:
+ x = conv(x)
+ if batch_normalization:
+ x = layers.BatchNormalization()(x)
+ if activation is not None:
+ x = layers.Activation(activation)(x)
+ else:
+ if batch_normalization:
+ x = layers.BatchNormalization()(x)
+ if activation is not None:
+ x = layers.Activation(activation)(x)
+ x = conv(x)
+ return x
+
+
+def resnet_v20(input_shape, depth, num_classes=10):
+ if (depth - 2) % 6 != 0:
+ raise ValueError("depth should be 6n+2 (eg 20, 32, 44 in [a])")
+ # Start model definition.
+ num_filters = 16
+ num_res_blocks = int((depth - 2) / 6)
+
+ inputs = layers.Input(shape=input_shape)
+ x = resnet_layer(inputs=inputs)
+ # Instantiate the stack of residual units
+ for stack in range(3):
+ for res_block in range(num_res_blocks):
+ strides = 1
+ if stack > 0 and res_block == 0: # first layer but not first stack
+ strides = 2 # downsample
+ y = resnet_layer(inputs=x, num_filters=num_filters, strides=strides)
+ y = resnet_layer(inputs=y, num_filters=num_filters, activation=None)
+ if stack > 0 and res_block == 0: # first layer but not first stack
+ # linear projection residual shortcut connection to match
+ # changed dims
+ x = resnet_layer(
+ inputs=x,
+ num_filters=num_filters,
+ kernel_size=1,
+ strides=strides,
+ activation=None,
+ batch_normalization=False,
+ )
+ x = layers.add([x, y])
+ x = layers.Activation("relu")(x)
+ num_filters *= 2
+
+ # Add classifier on top.
+ # v1 does not use BN after last shortcut connection-ReLU
+ x = layers.AveragePooling2D(pool_size=8)(x)
+ y = layers.Flatten()(x)
+ outputs = layers.Dense(
+ num_classes, activation="softmax", kernel_initializer="he_normal"
+ )(y)
+
+ # Instantiate model.
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+def training_model():
+ return resnet_v20((32, 32, 3), 20)
+
+
+initial_model = training_model()
+initial_model.save_weights("initial_weights.weights.h5")
+
+"""
+## Train the model with the dataset augmented by CutMix
+"""
+
+model = training_model()
+model.load_weights("initial_weights.weights.h5")
+
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+model.fit(train_ds_cmu, validation_data=test_ds, epochs=15)
+
+test_loss, test_accuracy = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_accuracy * 100))
+
+"""
+## Train the model using the original non-augmented dataset
+"""
+
+model = training_model()
+model.load_weights("initial_weights.weights.h5")
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+model.fit(train_ds_simple, validation_data=test_ds, epochs=15)
+
+test_loss, test_accuracy = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_accuracy * 100))
+
+"""
+## Notes
+
+In this example, we trained our model for 15 epochs.
+In our experiment, the model with CutMix achieves a better accuracy on the CIFAR-10 dataset
+(77.34% in our experiment) compared to the model that doesn't use the augmentation (66.90%).
+You may notice it takes less time to train the model with the CutMix augmentation.
+
+You can experiment further with the CutMix technique by following the
+[original paper](https://arxiv.org/abs/1905.04899).
+"""
diff --git a/knowledge_base/vision/deeplabv3_plus.py b/knowledge_base/vision/deeplabv3_plus.py
new file mode 100644
index 0000000000000000000000000000000000000000..894f2ee22a70d2945f3db59046a490953ca3aaa7
--- /dev/null
+++ b/knowledge_base/vision/deeplabv3_plus.py
@@ -0,0 +1,331 @@
+"""
+Title: Multiclass semantic segmentation using DeepLabV3+
+Author: [Soumik Rakshit](http://github.com/soumik12345)
+Date created: 2021/08/31
+Last modified: 2024/01/05
+Description: Implement DeepLabV3+ architecture for Multi-class Semantic Segmentation.
+Accelerator: GPU
+Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)
+"""
+
+"""
+## Introduction
+
+Semantic segmentation, with the goal of assigning semantic labels to every pixel in an image,
+is an essential computer vision task. In this example, we implement
+the **DeepLabV3+** model for multi-class semantic segmentation, a fully-convolutional
+architecture that performs well on semantic segmentation benchmarks.
+
+### References:
+
+- [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
+- [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
+- [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](https://arxiv.org/abs/1606.00915)
+"""
+
+"""
+## Downloading the data
+
+We will use the [Crowd Instance-level Human Parsing Dataset](https://arxiv.org/abs/1811.12596)
+for training our model. The Crowd Instance-level Human Parsing (CIHP) dataset has 38,280 diverse human images.
+Each image in CIHP is labeled with pixel-wise annotations for 20 categories, as well as instance-level identification.
+This dataset can be used for the "human part segmentation" task.
+"""
+
+
+import keras
+from keras import layers
+from keras import ops
+
+import os
+import numpy as np
+from glob import glob
+import cv2
+from scipy.io import loadmat
+import matplotlib.pyplot as plt
+
+# For data preprocessing
+from tensorflow import image as tf_image
+from tensorflow import data as tf_data
+from tensorflow import io as tf_io
+
+"""shell
+gdown "1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz&confirm=t"
+unzip -q instance-level-human-parsing.zip
+"""
+
+"""
+## Creating a TensorFlow Dataset
+
+Training on the entire CIHP dataset with 38,280 images takes a lot of time, hence we will be using
+a smaller subset of 1,000 training images (and 50 validation images) in this example.
+"""
+
+IMAGE_SIZE = 512
+BATCH_SIZE = 4
+NUM_CLASSES = 20
+DATA_DIR = "./instance-level_human_parsing/instance-level_human_parsing/Training"
+NUM_TRAIN_IMAGES = 1000
+NUM_VAL_IMAGES = 50
+
+train_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[:NUM_TRAIN_IMAGES]
+train_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[:NUM_TRAIN_IMAGES]
+val_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[
+ NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
+]
+val_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[
+ NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
+]
+
+
+def read_image(image_path, mask=False):
+ image = tf_io.read_file(image_path)
+ if mask:
+ image = tf_image.decode_png(image, channels=1)
+ image.set_shape([None, None, 1])
+ image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
+ else:
+ image = tf_image.decode_png(image, channels=3)
+ image.set_shape([None, None, 3])
+ image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
+ return image
+
+
+def load_data(image_list, mask_list):
+ image = read_image(image_list)
+ mask = read_image(mask_list, mask=True)
+ return image, mask
+
+
+def data_generator(image_list, mask_list):
+ dataset = tf_data.Dataset.from_tensor_slices((image_list, mask_list))
+ dataset = dataset.map(load_data, num_parallel_calls=tf_data.AUTOTUNE)
+ dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
+ return dataset
+
+
+train_dataset = data_generator(train_images, train_masks)
+val_dataset = data_generator(val_images, val_masks)
+
+print("Train Dataset:", train_dataset)
+print("Val Dataset:", val_dataset)
+
+"""
+## Building the DeepLabV3+ model
+
+DeepLabv3+ extends DeepLabv3 by adding an encoder-decoder structure. The encoder module
+processes multiscale contextual information by applying dilated convolution at multiple
+scales, while the decoder module refines the segmentation results along object boundaries.
+
+
+
+**Dilated convolution:** With dilated convolution, as we go deeper in the network we can
+keep the stride constant but with a larger field-of-view, without increasing the number
+of parameters or the amount of computation. Besides, it enables larger output feature
+maps, which is useful for semantic segmentation.
+
+The reason for using **Dilated Spatial Pyramid Pooling** is that it was shown that, as
+the sampling rate becomes larger, the number of valid filter weights (i.e., weights that
+are applied to the valid feature region, instead of padded zeros) becomes smaller.
+"""
+
+
+def convolution_block(
+ block_input,
+ num_filters=256,
+ kernel_size=3,
+ dilation_rate=1,
+ use_bias=False,
+):
+ x = layers.Conv2D(
+ num_filters,
+ kernel_size=kernel_size,
+ dilation_rate=dilation_rate,
+ padding="same",
+ use_bias=use_bias,
+ kernel_initializer=keras.initializers.HeNormal(),
+ )(block_input)
+ x = layers.BatchNormalization()(x)
+ return ops.nn.relu(x)
+
+
+def DilatedSpatialPyramidPooling(dspp_input):
+ dims = dspp_input.shape
+ x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+ x = convolution_block(x, kernel_size=1, use_bias=True)
+ out_pool = layers.UpSampling2D(
+ size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
+ interpolation="bilinear",
+ )(x)
+
+ out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+ out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+ out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+ out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+ x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+ output = convolution_block(x, kernel_size=1)
+ return output
+
+
+"""
+The encoder features are first bilinearly upsampled by a factor 4, and then
+concatenated with the corresponding low-level features from the network backbone that
+have the same spatial resolution. For this example, we
+use a ResNet50 pretrained on ImageNet as the backbone model, and we use
+the low-level features from the `conv4_block6_2_relu` block of the backbone.
+"""
+
+
+def DeeplabV3Plus(image_size, num_classes):
+ model_input = keras.Input(shape=(image_size, image_size, 3))
+ preprocessed = keras.applications.resnet50.preprocess_input(model_input)
+ resnet50 = keras.applications.ResNet50(
+ weights="imagenet", include_top=False, input_tensor=preprocessed
+ )
+ x = resnet50.get_layer("conv4_block6_2_relu").output
+ x = DilatedSpatialPyramidPooling(x)
+
+ input_a = layers.UpSampling2D(
+ size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
+ interpolation="bilinear",
+ )(x)
+ input_b = resnet50.get_layer("conv2_block3_2_relu").output
+ input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+ x = layers.Concatenate(axis=-1)([input_a, input_b])
+ x = convolution_block(x)
+ x = convolution_block(x)
+ x = layers.UpSampling2D(
+ size=(image_size // x.shape[1], image_size // x.shape[2]),
+ interpolation="bilinear",
+ )(x)
+ model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+ return keras.Model(inputs=model_input, outputs=model_output)
+
+
+model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
+model.summary()
+
+"""
+## Training
+
+We train the model using sparse categorical crossentropy as the loss function, and
+Adam as the optimizer.
+"""
+
+loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ loss=loss,
+ metrics=["accuracy"],
+)
+
+history = model.fit(train_dataset, validation_data=val_dataset, epochs=25)
+
+plt.plot(history.history["loss"])
+plt.title("Training Loss")
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.show()
+
+plt.plot(history.history["accuracy"])
+plt.title("Training Accuracy")
+plt.ylabel("accuracy")
+plt.xlabel("epoch")
+plt.show()
+
+plt.plot(history.history["val_loss"])
+plt.title("Validation Loss")
+plt.ylabel("val_loss")
+plt.xlabel("epoch")
+plt.show()
+
+plt.plot(history.history["val_accuracy"])
+plt.title("Validation Accuracy")
+plt.ylabel("val_accuracy")
+plt.xlabel("epoch")
+plt.show()
+
+"""
+## Inference using Colormap Overlay
+
+The raw predictions from the model are a tensor of shape `(N, 512, 512, 20)`, where each of
+the 20 channels holds the scores for one of the predicted labels.
+In order to visualize the results, we plot them as RGB segmentation masks where each pixel
+is represented by a unique color corresponding to the particular label predicted. We can easily
+find the color corresponding to each label in the `human_colormap.mat` file provided as part
+of the dataset. We also plot an overlay of the RGB segmentation mask on the input image, as
+this further helps us identify the different categories present in the image more intuitively.
+"""
+
+# Loading the Colormap
+colormap = loadmat(
+ "./instance-level_human_parsing/instance-level_human_parsing/human_colormap.mat"
+)["colormap"]
+colormap = colormap * 100
+colormap = colormap.astype(np.uint8)
+
+
+def infer(model, image_tensor):
+ predictions = model.predict(np.expand_dims((image_tensor), axis=0))
+ predictions = np.squeeze(predictions)
+ predictions = np.argmax(predictions, axis=2)
+ return predictions
+
+
+def decode_segmentation_masks(mask, colormap, n_classes):
+ r = np.zeros_like(mask).astype(np.uint8)
+ g = np.zeros_like(mask).astype(np.uint8)
+ b = np.zeros_like(mask).astype(np.uint8)
+ for l in range(0, n_classes):
+ idx = mask == l
+ r[idx] = colormap[l, 0]
+ g[idx] = colormap[l, 1]
+ b[idx] = colormap[l, 2]
+ rgb = np.stack([r, g, b], axis=2)
+ return rgb
+
+
+def get_overlay(image, colored_mask):
+ image = keras.utils.array_to_img(image)
+ image = np.array(image).astype(np.uint8)
+ overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
+ return overlay
+
+
+def plot_samples_matplotlib(display_list, figsize=(5, 3)):
+ _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
+ for i in range(len(display_list)):
+ if display_list[i].shape[-1] == 3:
+ axes[i].imshow(keras.utils.array_to_img(display_list[i]))
+ else:
+ axes[i].imshow(display_list[i])
+ plt.show()
+
+
+def plot_predictions(images_list, colormap, model):
+ for image_file in images_list:
+ image_tensor = read_image(image_file)
+ prediction_mask = infer(image_tensor=image_tensor, model=model)
+ prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
+ overlay = get_overlay(image_tensor, prediction_colormap)
+ plot_samples_matplotlib(
+ [image_tensor, overlay, prediction_colormap], figsize=(18, 14)
+ )
+
+
+"""
+### Inference on Train Images
+"""
+
+plot_predictions(train_images[:4], colormap, model=model)
+
+"""
+### Inference on Validation Images
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deeplabv3p-resnet50)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Human-Part-Segmentation).
+"""
+
+plot_predictions(val_images[:4], colormap, model=model)
diff --git a/knowledge_base/vision/deit.py b/knowledge_base/vision/deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..47529a38e6b5a2b229eecef8478723c20e8a0e76
--- /dev/null
+++ b/knowledge_base/vision/deit.py
@@ -0,0 +1,618 @@
+"""
+Title: Distilling Vision Transformers
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2022/04/05
+Last modified: 2022/04/08
+Description: Distillation of Vision Transformers through attention.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In the original *Vision Transformers* (ViT) paper
+([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)),
+the authors concluded that to perform on par with Convolutional Neural Networks (CNNs),
+ViTs need to be pre-trained on larger datasets. The larger the better. This is mainly
+due to the lack of inductive biases in the ViT architecture -- unlike CNNs,
+they don't have layers that exploit locality. In a follow-up paper
+([Steiner et al.](https://arxiv.org/abs/2106.10270)),
+the authors show that it is possible to substantially improve the performance of ViTs
+with stronger regularization and longer training.
+
+Many groups have proposed different ways to deal with the problem
+of data-intensiveness of ViT training.
+One such way was shown in the *Data-efficient image Transformers*,
+(DeiT) paper ([Touvron et al.](https://arxiv.org/abs/2012.12877)). The
+authors introduced a distillation technique that is specific to transformer-based vision
+models. DeiT is among the first works to show that it's possible to train ViTs well
+without using larger datasets.
+
+In this example, we implement the distillation recipe proposed in DeiT. This
+requires us to slightly tweak the original ViT architecture and write a custom training
+loop to implement the distillation recipe.
+
+To run the example, you'll need TensorFlow Addons, which you can install with the
+following command:
+
+```
+pip install tensorflow-addons
+```
+
+To comfortably navigate through this example, you'll be expected to know how a ViT and
+knowledge distillation work. The following are good resources in case you need a
+refresher:
+
+* [ViT on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer)
+* [Knowledge distillation on keras.io](https://keras.io/examples/vision/knowledge_distillation/)
+"""
+
+"""
+## Imports
+"""
+
+from typing import List
+
+import tensorflow as tf
+import tensorflow_addons as tfa
+import tensorflow_datasets as tfds
+import tensorflow_hub as hub
+from tensorflow import keras
+from tensorflow.keras import layers
+
+tfds.disable_progress_bar()
+tf.keras.utils.set_random_seed(42)
+
+"""
+## Constants
+"""
+
+# Model
+MODEL_TYPE = "deit_distilled_tiny_patch16_224"
+RESOLUTION = 224
+PATCH_SIZE = 16
+NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
+LAYER_NORM_EPS = 1e-6
+PROJECTION_DIM = 192
+NUM_HEADS = 3
+NUM_LAYERS = 12
+MLP_UNITS = [
+ PROJECTION_DIM * 4,
+ PROJECTION_DIM,
+]
+DROPOUT_RATE = 0.0
+DROP_PATH_RATE = 0.1
+
+# Training
+NUM_EPOCHS = 20
+BASE_LR = 0.0005
+WEIGHT_DECAY = 0.0001
+
+# Data
+BATCH_SIZE = 256
+AUTO = tf.data.AUTOTUNE
+NUM_CLASSES = 5
+
+"""
+You probably noticed that `DROPOUT_RATE` has been set to 0.0. Dropout has been included
+in the implementation to keep it complete. For smaller models (like the one used in
+this example), you don't need it, but for bigger models, using dropout helps.
+"""
+
+"""
+## Load the `tf_flowers` dataset and prepare preprocessing utilities
+
+The authors use an array of different augmentation techniques, including MixUp
+([Zhang et al.](https://arxiv.org/abs/1710.09412)),
+RandAugment ([Cubuk et al.](https://arxiv.org/abs/1909.13719)),
+and so on. However, to keep the example simple to work through, we'll discard them.
+"""
+
+
+def preprocess_dataset(is_training=True):
+ def fn(image, label):
+ if is_training:
+ # Resize to a bigger spatial resolution and take the random
+ # crops.
+ image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
+ image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
+ image = tf.image.random_flip_left_right(image)
+ else:
+ image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
+ label = tf.one_hot(label, depth=NUM_CLASSES)
+ return image, label
+
+ return fn
+
+
+def prepare_dataset(dataset, is_training=True):
+ if is_training:
+ dataset = dataset.shuffle(BATCH_SIZE * 10)
+ dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
+ return dataset.batch(BATCH_SIZE).prefetch(AUTO)
+
+
+train_dataset, val_dataset = tfds.load(
+ "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
+)
+num_train = train_dataset.cardinality()
+num_val = val_dataset.cardinality()
+print(f"Number of training examples: {num_train}")
+print(f"Number of validation examples: {num_val}")
+
+train_dataset = prepare_dataset(train_dataset, is_training=True)
+val_dataset = prepare_dataset(val_dataset, is_training=False)
+
+"""
+## Implementing the DeiT variants of ViT
+
+Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend
+it to support DeiT's components.
+
+First, we'll implement a layer for Stochastic Depth
+([Huang et al.](https://arxiv.org/abs/1603.09382))
+which is used in DeiT for regularization.
+"""
+
+
+# Referred from: github.com/rwightman/pytorch-image-models.
+class StochasticDepth(layers.Layer):
+ def __init__(self, drop_prop, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_prob = drop_prop
+
+ def call(self, x, training=True):
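+        # Randomly drop the entire residual branch: each sample in the batch is
+        # kept with probability `keep_prob` and rescaled by 1 / keep_prob so
+        # the expected output is unchanged.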
+ if training:
+ keep_prob = 1 - self.drop_prob
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+"""
+Now, we'll implement the MLP and Transformer blocks.
+"""
+
+
+def mlp(x, dropout_rate: float, hidden_units: List):
+ """FFN for a Transformer block."""
+ # Iterate over the hidden units and
+ # add Dense => Dropout.
+ for idx, units in enumerate(hidden_units):
+ x = layers.Dense(
+ units,
+ activation=tf.nn.gelu if idx == 0 else None,
+ )(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+def transformer(drop_prob: float, name: str) -> keras.Model:
+ """Transformer block with pre-norm."""
+ num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
+ encoded_patches = layers.Input((num_patches, PROJECTION_DIM))
+
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
+
+ # Multi Head Self Attention layer 1.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=NUM_HEADS,
+ key_dim=PROJECTION_DIM,
+ dropout=DROPOUT_RATE,
+ )(x1, x1)
+ attention_output = (
+ StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
+ )
+
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
+
+ # MLP layer 1.
+ x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
+ x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4
+
+ # Skip connection 2.
+ outputs = layers.Add()([x2, x4])
+
+ return keras.Model(encoded_patches, outputs, name=name)
+
+
+"""
+We'll now implement a `ViTClassifier` class building on top of the components we just
+developed. Here we'll be following the original pooling strategy used in the ViT paper --
+use a class token and use the feature representations corresponding to it for
+classification.
+"""
+
+
+class ViTClassifier(keras.Model):
+ """Vision Transformer base class."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ # Patchify + linear projection + reshaping.
+ self.projection = keras.Sequential(
+ [
+ layers.Conv2D(
+ filters=PROJECTION_DIM,
+ kernel_size=(PATCH_SIZE, PATCH_SIZE),
+ strides=(PATCH_SIZE, PATCH_SIZE),
+ padding="VALID",
+ name="conv_projection",
+ ),
+ layers.Reshape(
+ target_shape=(NUM_PATCHES, PROJECTION_DIM),
+ name="flatten_projection",
+ ),
+ ],
+ name="projection",
+ )
+
+ # Positional embedding.
+ init_shape = (
+ 1,
+ NUM_PATCHES + 1,
+ PROJECTION_DIM,
+ )
+ self.positional_embedding = tf.Variable(
+ tf.zeros(init_shape), name="pos_embedding"
+ )
+
+ # Transformer blocks.
+ dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
+ self.transformer_blocks = [
+ transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
+ for i in range(NUM_LAYERS)
+ ]
+
+ # CLS token.
+ initial_value = tf.zeros((1, 1, PROJECTION_DIM))
+ self.cls_token = tf.Variable(
+ initial_value=initial_value, trainable=True, name="cls"
+ )
+
+ # Other layers.
+ self.dropout = layers.Dropout(DROPOUT_RATE)
+ self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
+ self.head = layers.Dense(
+ NUM_CLASSES,
+ name="classification_head",
+ )
+
+ def call(self, inputs, training=True):
+ n = tf.shape(inputs)[0]
+
+ # Create patches and project the patches.
+ projected_patches = self.projection(inputs)
+
+ # Append class token if needed.
+ cls_token = tf.tile(self.cls_token, (n, 1, 1))
+ cls_token = tf.cast(cls_token, projected_patches.dtype)
+ projected_patches = tf.concat([cls_token, projected_patches], axis=1)
+
+ # Add positional embeddings to the projected patches.
+ encoded_patches = (
+ self.positional_embedding + projected_patches
+ ) # (B, number_patches, projection_dim)
+ encoded_patches = self.dropout(encoded_patches)
+
+ # Iterate over the number of layers and stack up blocks of
+ # Transformer.
+ for transformer_module in self.transformer_blocks:
+ # Add a Transformer block.
+ encoded_patches = transformer_module(encoded_patches)
+
+ # Final layer normalization.
+ representation = self.layer_norm(encoded_patches)
+
+ # Pool representation.
+ encoded_patches = representation[:, 0]
+
+ # Classification head.
+ output = self.head(encoded_patches)
+ return output
+
+
+"""
+This class can be used standalone as ViT and is end-to-end trainable. Just remove the
+`distilled` phrase in `MODEL_TYPE` and it should work with `vit_tiny = ViTClassifier()`.
+Let's now extend it to DeiT. The following figure presents the schematic of DeiT (taken
+from the DeiT paper):
+
+
+
+Apart from the class token, DeiT has another token for distillation. During distillation,
+the logits corresponding to the class token are compared to the true labels, and the
+logits corresponding to the distillation token are compared to the teacher's predictions.
+"""
+
+
+class ViTDistilled(ViTClassifier):
+ def __init__(self, regular_training=False, **kwargs):
+ super().__init__(**kwargs)
+ self.num_tokens = 2
+ self.regular_training = regular_training
+
+ # CLS and distillation tokens, positional embedding.
+ init_value = tf.zeros((1, 1, PROJECTION_DIM))
+ self.dist_token = tf.Variable(init_value, name="dist_token")
+ self.positional_embedding = tf.Variable(
+ tf.zeros(
+ (
+ 1,
+ NUM_PATCHES + self.num_tokens,
+ PROJECTION_DIM,
+ )
+ ),
+ name="pos_embedding",
+ )
+
+ # Head layers.
+ self.head = layers.Dense(
+ NUM_CLASSES,
+ name="classification_head",
+ )
+ self.head_dist = layers.Dense(
+ NUM_CLASSES,
+ name="distillation_head",
+ )
+
+ def call(self, inputs, training=True):
+ n = tf.shape(inputs)[0]
+
+ # Create patches and project the patches.
+ projected_patches = self.projection(inputs)
+
+ # Append the tokens.
+ cls_token = tf.tile(self.cls_token, (n, 1, 1))
+ dist_token = tf.tile(self.dist_token, (n, 1, 1))
+ cls_token = tf.cast(cls_token, projected_patches.dtype)
+ dist_token = tf.cast(dist_token, projected_patches.dtype)
+ projected_patches = tf.concat(
+ [cls_token, dist_token, projected_patches], axis=1
+ )
+
+ # Add positional embeddings to the projected patches.
+ encoded_patches = (
+ self.positional_embedding + projected_patches
+ ) # (B, number_patches, projection_dim)
+ encoded_patches = self.dropout(encoded_patches)
+
+ # Iterate over the number of layers and stack up blocks of
+ # Transformer.
+ for transformer_module in self.transformer_blocks:
+ # Add a Transformer block.
+ encoded_patches = transformer_module(encoded_patches)
+
+ # Final layer normalization.
+ representation = self.layer_norm(encoded_patches)
+
+ # Classification heads.
+ x, x_dist = (
+ self.head(representation[:, 0]),
+ self.head_dist(representation[:, 1]),
+ )
+
+ if not training or self.regular_training:
+            # During standard training / fine-tuning and at inference time,
+            # return the average of the two classifier predictions.
+ return (x + x_dist) / 2
+
+ elif training:
+ # Only return separate classification predictions when training in distilled
+ # mode.
+ return x, x_dist
+
+
+"""
+Let's verify if the `ViTDistilled` class can be initialized and called as expected.
+"""
+
+deit_tiny_distilled = ViTDistilled()
+
+dummy_inputs = tf.ones((2, 224, 224, 3))
+outputs = deit_tiny_distilled(dummy_inputs, training=False)
+print(outputs.shape)
+
+"""
+## Implementing the trainer
+
+Unlike what happens in standard knowledge distillation
+([Hinton et al.](https://arxiv.org/abs/1503.02531)),
+where a temperature-scaled softmax is used as well as KL divergence,
+DeiT authors use the following loss function:
+
+```
+loss = 0.5 * CE(psi(Z_s), y) + 0.5 * CE(psi(Z_s_dist), y_t)
+```
+
+Here,
+
+* CE is cross-entropy
+* `psi` is the softmax function
+* Z_s denotes the logits from the student's class token, and Z_s_dist the logits
+from its distillation token
+* y denotes the true labels
+* y_t denotes the teacher's (hard) predictions
+"""
+
+
+class DeiT(keras.Model):
+ # Reference:
+ # https://keras.io/examples/vision/knowledge_distillation/
+ def __init__(self, student, teacher, **kwargs):
+ super().__init__(**kwargs)
+ self.student = student
+ self.teacher = teacher
+
+ self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
+ self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")
+
+ @property
+ def metrics(self):
+ metrics = super().metrics
+ metrics.append(self.student_loss_tracker)
+ metrics.append(self.dist_loss_tracker)
+ return metrics
+
+ def compile(
+ self,
+ optimizer,
+ metrics,
+ student_loss_fn,
+ distillation_loss_fn,
+ ):
+ super().compile(optimizer=optimizer, metrics=metrics)
+ self.student_loss_fn = student_loss_fn
+ self.distillation_loss_fn = distillation_loss_fn
+
+ def train_step(self, data):
+ # Unpack data.
+ x, y = data
+
+ # Forward pass of teacher
+ teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
+ teacher_predictions = tf.argmax(teacher_predictions, -1)
+
+ with tf.GradientTape() as tape:
+ # Forward pass of student.
+ cls_predictions, dist_predictions = self.student(x / 255.0, training=True)
+
+ # Compute losses.
+ student_loss = self.student_loss_fn(y, cls_predictions)
+ distillation_loss = self.distillation_loss_fn(
+ teacher_predictions, dist_predictions
+ )
+ loss = (student_loss + distillation_loss) / 2
+
+ # Compute gradients.
+ trainable_vars = self.student.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+
+ # Update weights.
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+
+ # Update the metrics configured in `compile()`.
+ student_predictions = (cls_predictions + dist_predictions) / 2
+ self.compiled_metrics.update_state(y, student_predictions)
+ self.dist_loss_tracker.update_state(distillation_loss)
+ self.student_loss_tracker.update_state(student_loss)
+
+ # Return a dict of performance.
+ results = {m.name: m.result() for m in self.metrics}
+ return results
+
+ def test_step(self, data):
+ # Unpack the data.
+ x, y = data
+
+ # Compute predictions.
+ y_prediction = self.student(x / 255.0, training=False)
+
+ # Calculate the loss.
+ student_loss = self.student_loss_fn(y, y_prediction)
+
+ # Update the metrics.
+ self.compiled_metrics.update_state(y, y_prediction)
+ self.student_loss_tracker.update_state(student_loss)
+
+ # Return a dict of performance.
+ results = {m.name: m.result() for m in self.metrics}
+ return results
+
+ def call(self, inputs):
+ return self.student(inputs / 255.0, training=False)
+
+
+"""
+## Load the teacher model
+
+This model is based on the BiT family of ResNets
+([Kolesnikov et al.](https://arxiv.org/abs/1912.11370))
+fine-tuned on the `tf_flowers` dataset. You can refer to
+[this notebook](https://github.com/sayakpaul/deit-tf/blob/main/notebooks/bit-teacher.ipynb)
+to know how the training was performed. The teacher model has about 212 million parameters,
+which is about **40x more** than the student has.
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
+unzip -q bit_teacher_flowers.zip
+"""
+
+bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
+
+"""
+## Training through distillation
+"""
+
+deit_tiny = ViTDistilled()
+deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
+
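+# Scale the learning rate linearly with the global batch size, relative to
+# the base learning rate specified for a batch size of 512 (as in the DeiT paper).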
+lr_scaled = (BASE_LR / 512) * BATCH_SIZE
+deit_distiller.compile(
+ optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
+ metrics=["accuracy"],
+ student_loss_fn=keras.losses.CategoricalCrossentropy(
+ from_logits=True, label_smoothing=0.1
+ ),
+ distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+)
+_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
+
+"""
+If we had trained the same model (the `ViTClassifier`) from scratch with the exact same
+hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code
+to reproduce this result:
+
+```
+vit_tiny = ViTClassifier()
+
+inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
+x = keras.layers.Rescaling(scale=1./255)(inputs)
+outputs = vit_tiny(x)
+model = keras.Model(inputs, outputs)
+
+model.compile(...)
+model.fit(...)
+```
+"""
+
+"""
+## Notes
+
+* Through the use of distillation, we're effectively transferring the inductive biases of
+a CNN-based teacher model.
+* Interestingly enough, this distillation strategy works better with a CNN as the teacher
+model rather than a Transformer as shown in the paper.
+* The use of regularization to train DeiT models is very important.
+* ViT models are initialized with a combination of different initializers including
+truncated normal, random normal, Glorot uniform, etc. If you're looking for
+end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
+* If you want to explore the pre-trained DeiT models in TensorFlow and Keras with code
+for fine-tuning, [check out these models on TF-Hub](https://tfhub.dev/sayakpaul/collections/deit/1).
+
+## Acknowledgements
+
+* Ross Wightman for keeping
+[`timm`](https://github.com/rwightman/pytorch-image-models)
+updated with readable implementations. I referred to the implementations of ViT and DeiT
+a lot during implementing them in TensorFlow.
+* [Aritra Roy Gosthipaty](https://github.com/ariG23498)
+who implemented some portions of the `ViTClassifier` in another project.
+* [Google Developers Experts](https://developers.google.com/programs/experts/)
+program for supporting me with GCP credits which were used to run experiments for this
+example.
+
+Example available on HuggingFace:
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/deit) | [](https://huggingface.co/spaces/keras-io/deit/) |
+
+"""
diff --git a/knowledge_base/vision/depth_estimation.py b/knowledge_base/vision/depth_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..22acc36f09854c0a730471fcfde15ad7baadb263
--- /dev/null
+++ b/knowledge_base/vision/depth_estimation.py
@@ -0,0 +1,533 @@
+"""
+Title: Monocular depth estimation
+Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
+Date created: 2021/08/30
+Last modified: 2024/08/13
+Description: Implement a depth estimation model with a convnet.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+_Depth estimation_ is a crucial step towards inferring scene geometry from 2D images.
+The goal in _monocular depth estimation_ is to predict the depth value of each pixel,
+given only a single RGB image as input.
+This example will show an approach to build a depth estimation model with a convnet
+and simple loss functions.
+
+
+
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import sys
+
+import tensorflow as tf
+import keras
+from keras import layers
+from keras import ops
+import pandas as pd
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
+keras.utils.set_random_seed(123)
+
+"""
+## Downloading the dataset
+
+We will be using the dataset **DIODE: A Dense Indoor and Outdoor Depth Dataset** for this
+tutorial. However, we only use the validation split, from which we generate the training and
+evaluation subsets for our model. We use the validation split rather than the training split of
+the original dataset because the training set consists of 81GB of data, which is much more
+challenging to download than the 2.6GB validation set.
+Other datasets that you could use are
+**[NYU-v2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)**
+and **[KITTI](http://www.cvlibs.net/datasets/kitti/)**.
+"""
+
+annotation_folder = "/dataset/"
+if not os.path.exists(os.path.abspath(".") + annotation_folder):
+ annotation_zip = keras.utils.get_file(
+ "val.tar.gz",
+ cache_subdir=os.path.abspath("."),
+ origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
+ extract=True,
+ )
+
+"""
+## Preparing the dataset
+
+We only use the indoor images to train our depth estimation model.
+"""
+
+path = "val/indoors"
+
+filelist = []
+
+for root, dirs, files in os.walk(path):
+ for file in files:
+ filelist.append(os.path.join(root, file))
+
+filelist.sort()
+data = {
+ "image": [x for x in filelist if x.endswith(".png")],
+ "depth": [x for x in filelist if x.endswith("_depth.npy")],
+ "mask": [x for x in filelist if x.endswith("_depth_mask.npy")],
+}
+df = pd.DataFrame(data)
+
+df = df.sample(frac=1, random_state=42)
+
+"""
+## Preparing hyperparameters
+"""
+
+HEIGHT = 256
+WIDTH = 256
+LR = 0.00001
+EPOCHS = 30
+BATCH_SIZE = 32
+
+"""
+## Building a data pipeline
+
+1. The pipeline takes a dataframe containing the paths for the RGB images,
+as well as the depth and depth mask files.
+2. It reads and resizes the RGB images.
+3. It reads the depth and depth mask files, and processes them to generate the depth map
+image, which it also resizes.
+4. It returns the RGB images and the depth map images for a batch.
+"""
+
+
+class DataGenerator(keras.utils.PyDataset):
+ def __init__(self, data, batch_size=6, dim=(768, 1024), n_channels=3, shuffle=True):
+        """Initialization."""
+        super().__init__()
+ self.data = data
+ self.indices = self.data.index.tolist()
+ self.dim = dim
+ self.n_channels = n_channels
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.min_depth = 0.1
+ self.on_epoch_end()
+
+ def __len__(self):
+ return int(np.ceil(len(self.data) / self.batch_size))
+
+ def __getitem__(self, index):
+ if (index + 1) * self.batch_size > len(self.indices):
+ self.batch_size = len(self.indices) - index * self.batch_size
+ # Generate one batch of data
+ # Generate indices of the batch
+ index = self.indices[index * self.batch_size : (index + 1) * self.batch_size]
+ # Find list of IDs
+ batch = [self.indices[k] for k in index]
+ x, y = self.data_generation(batch)
+
+ return x, y
+
+ def on_epoch_end(self):
+ """
+ Updates indexes after each epoch
+ """
+ self.index = np.arange(len(self.indices))
+        if self.shuffle:
+ np.random.shuffle(self.index)
+
+ def load(self, image_path, depth_map, mask):
+ """Load input and target image."""
+
+ image_ = cv2.imread(image_path)
+ image_ = cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)
+ image_ = cv2.resize(image_, self.dim)
+ image_ = tf.image.convert_image_dtype(image_, tf.float32)
+
+ depth_map = np.load(depth_map).squeeze()
+
+ mask = np.load(mask)
+ mask = mask > 0
+
+ max_depth = min(300, np.percentile(depth_map, 99))
+ depth_map = np.clip(depth_map, self.min_depth, max_depth)
+ depth_map = np.log(depth_map, where=mask)
+
+ depth_map = np.ma.masked_where(~mask, depth_map)
+
+ depth_map = np.clip(depth_map, 0.1, np.log(max_depth))
+ depth_map = cv2.resize(depth_map, self.dim)
+ depth_map = np.expand_dims(depth_map, axis=2)
+ depth_map = tf.image.convert_image_dtype(depth_map, tf.float32)
+
+ return image_, depth_map
+
+ def data_generation(self, batch):
+ x = np.empty((self.batch_size, *self.dim, self.n_channels))
+ y = np.empty((self.batch_size, *self.dim, 1))
+
+ for i, batch_id in enumerate(batch):
+ x[i,], y[i,] = self.load(
+ self.data["image"][batch_id],
+ self.data["depth"][batch_id],
+ self.data["mask"][batch_id],
+ )
+ x, y = x.astype("float32"), y.astype("float32")
+ return x, y
+
+
+"""
+## Visualizing samples
+"""
+
+
+def visualize_depth_map(samples, test=False, model=None):
+ input, target = samples
+ cmap = plt.cm.jet
+ cmap.set_bad(color="black")
+
+ if test:
+ pred = model.predict(input)
+ fig, ax = plt.subplots(6, 3, figsize=(50, 50))
+ for i in range(6):
+ ax[i, 0].imshow((input[i].squeeze()))
+ ax[i, 1].imshow((target[i].squeeze()), cmap=cmap)
+ ax[i, 2].imshow((pred[i].squeeze()), cmap=cmap)
+
+ else:
+ fig, ax = plt.subplots(6, 2, figsize=(50, 50))
+ for i in range(6):
+ ax[i, 0].imshow((input[i].squeeze()))
+ ax[i, 1].imshow((target[i].squeeze()), cmap=cmap)
+
+
+visualize_samples = next(
+ iter(DataGenerator(data=df, batch_size=6, dim=(HEIGHT, WIDTH)))
+)
+visualize_depth_map(visualize_samples)
+
+"""
+## 3D point cloud visualization
+"""
+
+depth_vis = np.flipud(visualize_samples[1][1].squeeze()) # target
+img_vis = np.flipud(visualize_samples[0][1].squeeze()) # input
+
+fig = plt.figure(figsize=(15, 10))
+ax = plt.axes(projection="3d")
+
+STEP = 3
+for x in range(0, img_vis.shape[0], STEP):
+ for y in range(0, img_vis.shape[1], STEP):
+ ax.scatter(
+ [depth_vis[x, y]] * 3,
+ [y] * 3,
+ [x] * 3,
+ c=tuple(img_vis[x, y, :3] / 255),
+ s=3,
+ )
+ ax.view_init(45, 135)
+
+"""
+## Building the model
+
+1. The basic model is from U-Net.
+2. Additive skip-connections are implemented in the downscaling block.
+"""
+
+
+class DownscaleBlock(layers.Layer):
+ def __init__(
+ self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.reluA = layers.LeakyReLU(negative_slope=0.2)
+ self.reluB = layers.LeakyReLU(negative_slope=0.2)
+ self.bn2a = layers.BatchNormalization()
+ self.bn2b = layers.BatchNormalization()
+
+ self.pool = layers.MaxPool2D((2, 2), (2, 2))
+
+ def call(self, input_tensor):
+ d = self.convA(input_tensor)
+ x = self.bn2a(d)
+ x = self.reluA(x)
+
+ x = self.convB(x)
+ x = self.bn2b(x)
+ x = self.reluB(x)
+
+ x += d
+ p = self.pool(x)
+ return x, p
+
+
+class UpscaleBlock(layers.Layer):
+ def __init__(
+ self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.us = layers.UpSampling2D((2, 2))
+ self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.reluA = layers.LeakyReLU(negative_slope=0.2)
+ self.reluB = layers.LeakyReLU(negative_slope=0.2)
+ self.bn2a = layers.BatchNormalization()
+ self.bn2b = layers.BatchNormalization()
+ self.conc = layers.Concatenate()
+
+ def call(self, x, skip):
+ x = self.us(x)
+ concat = self.conc([x, skip])
+ x = self.convA(concat)
+ x = self.bn2a(x)
+ x = self.reluA(x)
+
+ x = self.convB(x)
+ x = self.bn2b(x)
+ x = self.reluB(x)
+
+ return x
+
+
+class BottleNeckBlock(layers.Layer):
+ def __init__(
+ self, filters, kernel_size=(3, 3), padding="same", strides=1, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
+ self.reluA = layers.LeakyReLU(negative_slope=0.2)
+ self.reluB = layers.LeakyReLU(negative_slope=0.2)
+
+ def call(self, x):
+ x = self.convA(x)
+ x = self.reluA(x)
+ x = self.convB(x)
+ x = self.reluB(x)
+ return x
+
+
+"""
+## Defining the loss
+
+We will optimize 3 losses in our model:
+1. Structural similarity index (SSIM).
+2. L1-loss, or point-wise depth in our case.
+3. Depth smoothness loss.
+
+Out of the three loss functions, SSIM contributes the most to improving model performance.
+"""
+
+
+def image_gradients(image):
+ if len(ops.shape(image)) != 4:
+ raise ValueError(
+ "image_gradients expects a 4D tensor "
+ "[batch_size, h, w, d], not {}.".format(ops.shape(image))
+ )
+
+ image_shape = ops.shape(image)
+ batch_size, height, width, depth = ops.unstack(image_shape)
+
+ dy = image[:, 1:, :, :] - image[:, :-1, :, :]
+ dx = image[:, :, 1:, :] - image[:, :, :-1, :]
+
+ # Return tensors with same size as original image by concatenating
+ # zeros. Place the gradient [I(x+1,y) - I(x,y)] on the base pixel (x, y).
+ shape = ops.stack([batch_size, 1, width, depth])
+ dy = ops.concatenate([dy, ops.zeros(shape, dtype=image.dtype)], axis=1)
+ dy = ops.reshape(dy, image_shape)
+
+ shape = ops.stack([batch_size, height, 1, depth])
+ dx = ops.concatenate([dx, ops.zeros(shape, dtype=image.dtype)], axis=2)
+ dx = ops.reshape(dx, image_shape)
+
+ return dy, dx
+
+
+class DepthEstimationModel(keras.Model):
+ def __init__(self):
+ super().__init__()
+ self.ssim_loss_weight = 0.85
+ self.l1_loss_weight = 0.1
+ self.edge_loss_weight = 0.9
+ self.loss_metric = keras.metrics.Mean(name="loss")
+ f = [16, 32, 64, 128, 256]
+ self.downscale_blocks = [
+ DownscaleBlock(f[0]),
+ DownscaleBlock(f[1]),
+ DownscaleBlock(f[2]),
+ DownscaleBlock(f[3]),
+ ]
+ self.bottle_neck_block = BottleNeckBlock(f[4])
+ self.upscale_blocks = [
+ UpscaleBlock(f[3]),
+ UpscaleBlock(f[2]),
+ UpscaleBlock(f[1]),
+ UpscaleBlock(f[0]),
+ ]
+ self.conv_layer = layers.Conv2D(1, (1, 1), padding="same", activation="tanh")
+
+ def calculate_loss(self, target, pred):
+ # Edges
+ dy_true, dx_true = image_gradients(target)
+ dy_pred, dx_pred = image_gradients(pred)
+ weights_x = ops.cast(ops.exp(ops.mean(ops.abs(dx_true))), "float32")
+ weights_y = ops.cast(ops.exp(ops.mean(ops.abs(dy_true))), "float32")
+
+ # Depth smoothness
+ smoothness_x = dx_pred * weights_x
+ smoothness_y = dy_pred * weights_y
+
+        depth_smoothness_loss = ops.mean(ops.abs(smoothness_x)) + ops.mean(
+            ops.abs(smoothness_y)
+        )
+
+ # Structural similarity (SSIM) index
+ ssim_loss = ops.mean(
+ 1
+ - tf.image.ssim(
+ target, pred, max_val=WIDTH, filter_size=7, k1=0.01**2, k2=0.03**2
+ )
+ )
+ # Point-wise depth
+ l1_loss = ops.mean(ops.abs(target - pred))
+
+ loss = (
+ (self.ssim_loss_weight * ssim_loss)
+ + (self.l1_loss_weight * l1_loss)
+ + (self.edge_loss_weight * depth_smoothness_loss)
+ )
+
+ return loss
+
+ @property
+ def metrics(self):
+ return [self.loss_metric]
+
+ def train_step(self, batch_data):
+ input, target = batch_data
+ with tf.GradientTape() as tape:
+ pred = self(input, training=True)
+ loss = self.calculate_loss(target, pred)
+
+ gradients = tape.gradient(loss, self.trainable_variables)
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+ self.loss_metric.update_state(loss)
+ return {
+ "loss": self.loss_metric.result(),
+ }
+
+ def test_step(self, batch_data):
+ input, target = batch_data
+
+ pred = self(input, training=False)
+ loss = self.calculate_loss(target, pred)
+
+ self.loss_metric.update_state(loss)
+ return {
+ "loss": self.loss_metric.result(),
+ }
+
+ def call(self, x):
+ c1, p1 = self.downscale_blocks[0](x)
+ c2, p2 = self.downscale_blocks[1](p1)
+ c3, p3 = self.downscale_blocks[2](p2)
+ c4, p4 = self.downscale_blocks[3](p3)
+
+ bn = self.bottle_neck_block(p4)
+
+ u1 = self.upscale_blocks[0](bn, c4)
+ u2 = self.upscale_blocks[1](u1, c3)
+ u3 = self.upscale_blocks[2](u2, c2)
+ u4 = self.upscale_blocks[3](u3, c1)
+
+ return self.conv_layer(u4)
+
+
+"""
+## Model training
+"""
+
+optimizer = keras.optimizers.SGD(
+ learning_rate=LR,
+ nesterov=False,
+)
+model = DepthEstimationModel()
+# Compile the model
+model.compile(optimizer)
+
+train_loader = DataGenerator(
+ data=df[:260].reset_index(drop="true"), batch_size=BATCH_SIZE, dim=(HEIGHT, WIDTH)
+)
+validation_loader = DataGenerator(
+ data=df[260:].reset_index(drop="true"), batch_size=BATCH_SIZE, dim=(HEIGHT, WIDTH)
+)
+model.fit(
+ train_loader,
+ epochs=EPOCHS,
+ validation_data=validation_loader,
+)
+
+"""
+## Visualizing model output
+
+We visualize the model output over the validation set.
+The first image is the RGB image, the second image is the ground truth depth map image
+and the third one is the predicted depth map image.
+"""
+
+test_loader = next(
+ iter(
+ DataGenerator(
+ data=df[265:].reset_index(drop="true"), batch_size=6, dim=(HEIGHT, WIDTH)
+ )
+ )
+)
+visualize_depth_map(test_loader, test=True, model=model)
+
+test_loader = next(
+ iter(
+ DataGenerator(
+ data=df[300:].reset_index(drop="true"), batch_size=6, dim=(HEIGHT, WIDTH)
+ )
+ )
+)
+visualize_depth_map(test_loader, test=True, model=model)
+
+"""
+## Possible improvements
+
+1. You can improve this model by replacing the encoding part of the U-Net with a
+pretrained DenseNet or ResNet (see the sketch below).
+2. Loss functions play an important role in solving this problem.
+Tuning the loss functions may yield significant improvement.
+"""
+
+"""
+## References
+
+The following papers go deeper into possible approaches for depth estimation.
+1. [Depth Prediction Without the Sensors: Leveraging Structure for Unsupervised Learning from Monocular Videos](https://arxiv.org/abs/1811.06152v1)
+2. [Digging Into Self-Supervised Monocular Depth Estimation](https://openaccess.thecvf.com/content_ICCV_2019/papers/Godard_Digging_Into_Self-Supervised_Monocular_Depth_Estimation_ICCV_2019_paper.pdf)
+3. [Deeper Depth Prediction with Fully Convolutional Residual Networks](https://arxiv.org/abs/1606.00373v2)
+
+You can also find helpful implementations on Papers With Code for the depth estimation task.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/monocular-depth-estimation)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Monocular-Depth-Estimation).
+"""
diff --git a/knowledge_base/vision/eanet.py b/knowledge_base/vision/eanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..45de6b317421d1204ed372f7aac3db4316325a2c
--- /dev/null
+++ b/knowledge_base/vision/eanet.py
@@ -0,0 +1,322 @@
+"""
+Title: Image classification with EANet (External Attention Transformer)
+Author: [ZhiYong Chang](https://github.com/czy00000)
+Date created: 2021/10/19
+Last modified: 2023/07/18
+Description: Image classification with a Transformer that leverages external attention.
+Accelerator: GPU
+Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com)
+"""
+
+"""
+## Introduction
+
+This example implements the [EANet](https://arxiv.org/abs/2105.02358)
+model for image classification, and demonstrates it on the CIFAR-100 dataset.
+EANet introduces a novel attention mechanism
+named ***external attention***, based on two external, small, learnable, and
+shared memories, which can be implemented easily by simply using two cascaded
+linear layers and two normalization layers. It conveniently replaces self-attention
+as used in existing architectures. External attention has linear complexity, as it only
+implicitly considers the correlations between all samples.
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import layers
+from keras import ops
+
+import matplotlib.pyplot as plt
+
+
+"""
+## Prepare the data
+"""
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+"""
+## Configure the hyperparameters
+"""
+
+weight_decay = 0.0001
+learning_rate = 0.001
+label_smoothing = 0.1
+validation_split = 0.2
+batch_size = 128
+num_epochs = 50
+patch_size = 2 # Size of the patches to be extracted from the input images.
+num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches
+embedding_dim = 64 # Number of hidden units.
+mlp_dim = 64
+dim_coefficient = 4
+num_heads = 4
+attention_dropout = 0.2
+projection_dropout = 0.2
+num_transformer_blocks = 8 # Number of repetitions of the transformer layer
+
+print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
+print(f"Patches per image: {num_patches}")
+
+
+"""
+## Use data augmentation
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(factor=0.1),
+ layers.RandomContrast(factor=0.1),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+"""
+## Implement the patch extraction and encoding layer
+"""
+
+
+class PatchExtract(layers.Layer):
+ def __init__(self, patch_size, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+
+ def call(self, x):
+ B, C = ops.shape(x)[0], ops.shape(x)[-1]
+ x = ops.image.extract_patches(x, self.patch_size)
+ x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
+ return x
+
+
+class PatchEmbedding(layers.Layer):
+ def __init__(self, num_patch, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.num_patch = num_patch
+ self.proj = layers.Dense(embed_dim)
+ self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
+
+ def call(self, patch):
+ pos = ops.arange(start=0, stop=self.num_patch, step=1)
+ return self.proj(patch) + self.pos_embed(pos)
+
+
+"""
+## Implement the external attention block
+"""
+
+
+def external_attention(
+ x,
+ dim,
+ num_heads,
+ dim_coefficient=4,
+ attention_dropout=0,
+ projection_dropout=0,
+):
+ _, num_patch, channel = x.shape
+ assert dim % num_heads == 0
+ num_heads = num_heads * dim_coefficient
+
+ x = layers.Dense(dim * dim_coefficient)(x)
+ # create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]
+ x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
+ # a linear layer M_k
+ attn = layers.Dense(dim // dim_coefficient)(x)
+ # normalize attention map
+ attn = layers.Softmax(axis=2)(attn)
+    # double-normalization
+ attn = layers.Lambda(
+ lambda attn: ops.divide(
+ attn,
+ ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
+ )
+ )(attn)
+ attn = layers.Dropout(attention_dropout)(attn)
+ # a linear layer M_v
+ x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
+ x = ops.transpose(x, axes=[0, 2, 1, 3])
+ x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
+ # a linear layer to project original dim
+ x = layers.Dense(dim)(x)
+ x = layers.Dropout(projection_dropout)(x)
+ return x
+
+
+"""
+## Implement the MLP block
+"""
+
+
+def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
+ x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
+ x = layers.Dropout(drop_rate)(x)
+ x = layers.Dense(embedding_dim)(x)
+ x = layers.Dropout(drop_rate)(x)
+ return x
+
+
+"""
+## Implement the Transformer block
+"""
+
+
+def transformer_encoder(
+ x,
+ embedding_dim,
+ mlp_dim,
+ num_heads,
+ dim_coefficient,
+ attention_dropout,
+ projection_dropout,
+ attention_type="external_attention",
+):
+ residual_1 = x
+ x = layers.LayerNormalization(epsilon=1e-5)(x)
+ if attention_type == "external_attention":
+ x = external_attention(
+ x,
+ embedding_dim,
+ num_heads,
+ dim_coefficient,
+ attention_dropout,
+ projection_dropout,
+ )
+ elif attention_type == "self_attention":
+ x = layers.MultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=embedding_dim,
+ dropout=attention_dropout,
+ )(x, x)
+ x = layers.add([x, residual_1])
+ residual_2 = x
+ x = layers.LayerNormalization(epsilon=1e-5)(x)
+ x = mlp(x, embedding_dim, mlp_dim)
+ x = layers.add([x, residual_2])
+ return x
+
+
+"""
+## Implement the EANet model
+"""
+
+"""
+The EANet model leverages external attention.
+The computational complexity of traditional self-attention is `O(d * N ** 2)`,
+where `d` is the embedding size and `N` is the number of patches.
+The authors found that most pixels are closely related to just a few other
+pixels, so an `N`-to-`N` attention matrix may be redundant.
+As an alternative, they propose an external
+attention module with a computational complexity of `O(d * S * N)`.
+As `d` and `S` are hyperparameters,
+the proposed algorithm is linear in the number of pixels. In fact, this is equivalent
+to a drop-patch operation, because much of the information contained in a patch
+of an image is redundant and unimportant.
+"""
+
+
+def get_model(attention_type="external_attention"):
+ inputs = layers.Input(shape=input_shape)
+ # Image augment
+ x = data_augmentation(inputs)
+ # Extract patches.
+ x = PatchExtract(patch_size)(x)
+ # Create patch embedding.
+ x = PatchEmbedding(num_patches, embedding_dim)(x)
+ # Create Transformer block.
+ for _ in range(num_transformer_blocks):
+ x = transformer_encoder(
+ x,
+ embedding_dim,
+ mlp_dim,
+ num_heads,
+ dim_coefficient,
+ attention_dropout,
+ projection_dropout,
+ attention_type,
+ )
+
+ x = layers.GlobalAveragePooling1D()(x)
+ outputs = layers.Dense(num_classes, activation="softmax")(x)
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+## Train on CIFAR-100
+
+"""
+
+
+model = get_model(attention_type="external_attention")
+
+model.compile(
+ loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
+ optimizer=keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ ),
+ metrics=[
+ keras.metrics.CategoricalAccuracy(name="accuracy"),
+ keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+
+history = model.fit(
+ x_train,
+ y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=validation_split,
+)
+
+"""
+### Let's visualize the training progress of the model.
+
+"""
+
+plt.plot(history.history["loss"], label="train_loss")
+plt.plot(history.history["val_loss"], label="val_loss")
+plt.xlabel("Epochs")
+plt.ylabel("Loss")
+plt.title("Train and Validation Losses Over Epochs", fontsize=14)
+plt.legend()
+plt.grid()
+plt.show()
+
+"""
+### Let's display the final results of the test on CIFAR-100.
+
+"""
+
+loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+print(f"Test loss: {round(loss, 2)}")
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+"""
+EANet simply replaces self-attention in ViT with external attention.
+The traditional ViT, with 0.6M parameters, achieved ~73% test top-5 accuracy and ~41% test
+top-1 accuracy after training for 50 epochs. Under the same experimental environment
+and the same hyperparameters, the EANet model we just trained has just 0.3M parameters,
+and it gets us to ~73% test top-5 accuracy and ~43% top-1 accuracy. This fully demonstrates the
+effectiveness of external attention.
+
+We only showed the training
+process of EANet; you can train ViT under the same experimental conditions and observe
+the test results.
+"""
diff --git a/knowledge_base/vision/edsr.py b/knowledge_base/vision/edsr.py
new file mode 100644
index 0000000000000000000000000000000000000000..87cb1ccd8d1a8f76e89b72b439c240c107a3201a
--- /dev/null
+++ b/knowledge_base/vision/edsr.py
@@ -0,0 +1,357 @@
+"""
+Title: Enhanced Deep Residual Networks for single-image super-resolution
+Author: Gitesh Chawda
+Date created: 2022/04/07
+Last modified: 2024/08/27
+Description: Training an EDSR model on the DIV2K Dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we implement
+[Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)](https://arxiv.org/abs/1707.02921)
+by Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee.
+
+The EDSR architecture is based on the SRResNet architecture and consists of multiple
+residual blocks. It uses constant scaling layers instead of batch normalization layers to
+produce consistent results (input and output have similar distributions, thus
+normalizing intermediate features may not be desirable). Instead of using a L2 loss (mean squared error),
+the authors employed an L1 loss (mean absolute error), which performs better empirically.
+
+Our implementation only includes 16 residual blocks with 64 channels.
+
+Alternatively, as shown in the Keras example
+[Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/#image-superresolution-using-an-efficient-subpixel-cnn),
+you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five
+best-performing super-resolution methods based on PSNR scores. However, it has more
+parameters and requires more computational power than other approaches.
+It has a PSNR value (โ34db) that is slightly higher than ESPCN (โ32db).
+As per the survey paper, EDSR performs better than ESPCN.
+
+Paper:
+[A comprehensive review of deep learning based single image super-resolution](https://arxiv.org/abs/2102.09351)
+
+Comparison Graph:
+
+"""
+
+"""
+## Imports
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import matplotlib.pyplot as plt
+
+import keras
+from keras import layers
+from keras import ops
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+"""
+## Download the training dataset
+
+We use the DIV2K Dataset, a prominent single-image super-resolution dataset with 1,000
+images of scenes with various sorts of degradations,
+divided into 800 images for training, 100 images for validation, and 100
+images for testing. We use 4x bicubic downsampled images as our "low quality" reference.
+"""
+
+# Download DIV2K from TF Datasets
+# Using bicubic 4x degradation type
+div2k_data = tfds.image.Div2k(config="bicubic_x4")
+div2k_data.download_and_prepare()
+
+# Taking train data from div2k_data object
+train = div2k_data.as_dataset(split="train", as_supervised=True)
+train_cache = train.cache()
+# Validation data
+val = div2k_data.as_dataset(split="validation", as_supervised=True)
+val_cache = val.cache()
+
+"""
+## Flip, crop and resize images
+"""
+
+
+def flip_left_right(lowres_img, highres_img):
+ """Flips Images to left and right."""
+
+ # Outputs random values from a uniform distribution in between 0 to 1
+ rn = keras.random.uniform(shape=(), maxval=1)
+ # If rn is less than 0.5 it returns original lowres_img and highres_img
+ # If rn is greater than 0.5 it returns flipped image
+ return ops.cond(
+ rn < 0.5,
+ lambda: (lowres_img, highres_img),
+ lambda: (
+            ops.flip(lowres_img, axis=1),
+            ops.flip(highres_img, axis=1),
+ ),
+ )
+
+
+def random_rotate(lowres_img, highres_img):
+ """Rotates Images by 90 degrees."""
+
+ # Outputs random values from uniform distribution in between 0 to 4
+ rn = ops.cast(
+ keras.random.uniform(shape=(), maxval=4, dtype="float32"), dtype="int32"
+ )
+ # Here rn signifies number of times the image(s) are rotated by 90 degrees
+ return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)
+
+
+def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):
+ """Crop images.
+
+ low resolution images: 24x24
+ high resolution images: 96x96
+ """
+ lowres_crop_size = hr_crop_size // scale # 96//4=24
+ lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)
+
+ lowres_width = ops.cast(
+ keras.random.uniform(
+ shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype="float32"
+ ),
+ dtype="int32",
+ )
+ lowres_height = ops.cast(
+ keras.random.uniform(
+ shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype="float32"
+ ),
+ dtype="int32",
+ )
+
+ highres_width = lowres_width * scale
+ highres_height = lowres_height * scale
+
+ lowres_img_cropped = lowres_img[
+ lowres_height : lowres_height + lowres_crop_size,
+ lowres_width : lowres_width + lowres_crop_size,
+ ] # 24x24
+ highres_img_cropped = highres_img[
+ highres_height : highres_height + hr_crop_size,
+ highres_width : highres_width + hr_crop_size,
+ ] # 96x96
+
+ return lowres_img_cropped, highres_img_cropped
+
+
+"""
+## Prepare a `tf.data.Dataset` object
+
+We augment the training data with random horizontal flips and 90° rotations.
+
+As low resolution images, we use 24x24 RGB input patches.
+"""
+
+
+def dataset_object(dataset_cache, training=True):
+ ds = dataset_cache
+ ds = ds.map(
+ lambda lowres, highres: random_crop(lowres, highres, scale=4),
+ num_parallel_calls=AUTOTUNE,
+ )
+
+ if training:
+ ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
+ ds = ds.map(flip_left_right, num_parallel_calls=AUTOTUNE)
+ # Batching Data
+ ds = ds.batch(16)
+
+ if training:
+        # Repeating data so that the cardinality of the dataset becomes infinite
+ ds = ds.repeat()
+ # prefetching allows later images to be prepared while the current image is being processed
+ ds = ds.prefetch(buffer_size=AUTOTUNE)
+ return ds
+
+
+train_ds = dataset_object(train_cache, training=True)
+val_ds = dataset_object(val_cache, training=False)
+
+"""
+## Visualize the data
+
+Let's visualize a few sample images:
+"""
+
+lowres, highres = next(iter(train_ds))
+
+# High Resolution Images
+plt.figure(figsize=(10, 10))
+for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(highres[i].numpy().astype("uint8"))
+ plt.title(highres[i].shape)
+ plt.axis("off")
+
+# Low Resolution Images
+plt.figure(figsize=(10, 10))
+for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(lowres[i].numpy().astype("uint8"))
+ plt.title(lowres[i].shape)
+ plt.axis("off")
+
+
+def PSNR(super_resolution, high_resolution):
+ """Compute the peak signal-to-noise ratio, measures quality of image."""
+ # Max value of pixel is 255
+ psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]
+ return psnr_value
+
+
+"""
+## Build the model
+
+In the paper, the authors train three models: EDSR, MDSR, and a baseline model. In this code example,
+we only train the baseline model.
+
+### Comparison with model with three residual blocks
+
+The residual block design of EDSR differs from that of ResNet. Batch normalization
+layers have been removed (together with the final ReLU activation): since batch normalization
+layers normalize the features, they hurt output value range flexibility.
+It is thus better to remove them. Further, it also helps reduce the
+amount of GPU RAM required by the model, since the batch normalization layers consume the same amount of
+memory as the preceding convolutional layers.
+
+
+"""
+
+
+class EDSRModel(keras.Model):
+ def train_step(self, data):
+ # Unpack the data. Its structure depends on your model and
+ # on what you pass to `fit()`.
+ x, y = data
+
+ with tf.GradientTape() as tape:
+ y_pred = self(x, training=True) # Forward pass
+ # Compute the loss value
+ # (the loss function is configured in `compile()`)
+ loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
+
+ # Compute gradients
+ trainable_vars = self.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+ # Update weights
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+ # Update metrics (includes the metric that tracks the loss)
+ self.compiled_metrics.update_state(y, y_pred)
+ # Return a dict mapping metric names to current value
+ return {m.name: m.result() for m in self.metrics}
+
+ def predict_step(self, x):
+        # Add a batch dimension with tf.expand_dims and cast to float32 with ops.cast
+ x = ops.cast(tf.expand_dims(x, axis=0), dtype="float32")
+ # Passing low resolution image to model
+ super_resolution_img = self(x, training=False)
+ # Clips the tensor from min(0) to max(255)
+ super_resolution_img = ops.clip(super_resolution_img, 0, 255)
+ # Rounds the values of a tensor to the nearest integer
+ super_resolution_img = ops.round(super_resolution_img)
+        # Remove the batch dimension (size 1) and cast to uint8
+ super_resolution_img = ops.squeeze(
+ ops.cast(super_resolution_img, dtype="uint8"), axis=0
+ )
+ return super_resolution_img
+
+
+# Residual Block
+def ResBlock(inputs):
+ x = layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
+ x = layers.Conv2D(64, 3, padding="same")(x)
+ x = layers.Add()([inputs, x])
+ return x
+
+
+# Upsampling Block
+def Upsampling(inputs, factor=2, **kwargs):
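+    # Each Conv2D expands the channels by `factor**2`, and `depth_to_space`
+    # rearranges those channels into a `factor`-times larger spatial grid
+    # (sub-pixel / "pixel shuffle" upsampling). Applying the pair twice upscales 4x.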
+ x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(inputs)
+ x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
+ x = layers.Conv2D(64 * (factor**2), 3, padding="same", **kwargs)(x)
+ x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)
+ return x
+
+
+def make_model(num_filters, num_of_residual_blocks):
+    # Flexible input shape: the model accepts images of any spatial size
+ input_layer = layers.Input(shape=(None, None, 3))
+ # Scaling Pixel Values
+ x = layers.Rescaling(scale=1.0 / 255)(input_layer)
+ x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)
+
+    # Stack of residual blocks (16 in this example)
+ for _ in range(num_of_residual_blocks):
+ x_new = ResBlock(x_new)
+
+ x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
+ x = layers.Add()([x, x_new])
+
+ x = Upsampling(x)
+ x = layers.Conv2D(3, 3, padding="same")(x)
+
+ output_layer = layers.Rescaling(scale=255)(x)
+ return EDSRModel(input_layer, output_layer)
+
+
+model = make_model(num_filters=64, num_of_residual_blocks=16)
+
+"""
+## Train the model
+"""
+
+# Use the Adam optimizer with an initial learning rate of 1e-4, decayed to 5e-5 after 5000 steps
+optim_edsr = keras.optimizers.Adam(
+ learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=[5000], values=[1e-4, 5e-5]
+ )
+)
+# Compile the model with mean absolute error (L1 loss) and the PSNR metric
+model.compile(optimizer=optim_edsr, loss="mae", metrics=[PSNR])
+# Training for more epochs will improve results
+model.fit(train_ds, epochs=100, steps_per_epoch=200, validation_data=val_ds)
+
+"""
+## Run inference on new images and plot the results
+"""
+
+
+def plot_results(lowres, preds):
+ """
+    Displays the low-resolution input next to the super-resolution prediction
+ """
+ plt.figure(figsize=(24, 14))
+ plt.subplot(132), plt.imshow(lowres), plt.title("Low resolution")
+ plt.subplot(133), plt.imshow(preds), plt.title("Prediction")
+ plt.show()
+
+
+for lowres, highres in val.take(10):
+ lowres = tf.image.random_crop(lowres, (150, 150, 3))
+ preds = model.predict_step(lowres)
+ plot_results(lowres, preds)
+
+"""
+## Final remarks
+
+In this example, we implemented the EDSR model (Enhanced Deep Residual Networks for Single Image
+Super-Resolution). You could improve the model accuracy by training the model for more epochs, as well as
+training the model with a wider variety of inputs with mixed downgrading factors, so as to
+be able to handle a greater range of real-world images.
+
+You could also improve on the given baseline EDSR model by implementing EDSR+,
+or MDSR (multi-scale super-resolution) and MDSR+,
+which were proposed in the same paper.
+"""
diff --git a/knowledge_base/vision/fixres.py b/knowledge_base/vision/fixres.py
new file mode 100644
index 0000000000000000000000000000000000000000..dafb2f76e1dc46b897d25211ff273b623d69c864
--- /dev/null
+++ b/knowledge_base/vision/fixres.py
@@ -0,0 +1,338 @@
+"""
+Title: FixRes: Fixing train-test resolution discrepancy
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/10/08
+Last modified: 2021/10/10
+Description: Mitigating resolution discrepancy between training and test sets.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+It is a common practice to use the same input image resolution while training and testing
+vision models. However, as investigated in
+[Fixing the train-test resolution discrepancy](https://arxiv.org/abs/1906.06423)
+(Touvron et al.), this practice leads to suboptimal performance. Data augmentation
+is an indispensable part of the training process of deep neural networks. For vision models, we
+typically use random resized crops during training and center crops during inference.
+This introduces a discrepancy in the object sizes seen during training and inference.
+As shown by Touvron et al., if we can fix this discrepancy, we can significantly
+boost model performance.
+
+In this example, we implement the **FixRes** techniques introduced by Touvron et al.
+to fix this discrepancy.
+"""
+
+"""
+## Imports
+"""
+
+import keras
+from keras import layers
+import tensorflow as tf # just for image processing and pipeline
+
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+import matplotlib.pyplot as plt
+
+"""
+## Load the `tf_flowers` dataset
+"""
+
+train_dataset, val_dataset = tfds.load(
+ "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
+)
+
+num_train = train_dataset.cardinality()
+num_val = val_dataset.cardinality()
+print(f"Number of training examples: {num_train}")
+print(f"Number of validation examples: {num_val}")
+
+"""
+## Data preprocessing utilities
+"""
+
+"""
+We create three datasets:
+
+1. A dataset with a smaller resolution - 128x128.
+2. Two datasets with a larger resolution - 224x224.
+
+We will apply different augmentation transforms to the larger-resolution datasets.
+
+The idea of FixRes is to first train a model on a smaller resolution dataset and then fine-tune
+it on a larger resolution dataset. This simple yet effective recipe leads to non-trivial performance
+improvements. Please refer to the [original paper](https://arxiv.org/abs/1906.06423) for
+results.
+"""
+
+# Reference: https://github.com/facebookresearch/FixRes/blob/main/transforms_v2.py.
+
+batch_size = 32
+auto = tf.data.AUTOTUNE
+smaller_size = 128
+bigger_size = 224
+
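+# With these values, the fine-tuning resize below works out to
+# int((224 / 128) * 224) = 392 pixels, which preserves the apparent object-size
+# ratio before the final 224x224 center crop.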
+size_for_resizing = int((bigger_size / smaller_size) * bigger_size)
+central_crop_layer = layers.CenterCrop(bigger_size, bigger_size)
+
+
+def preprocess_initial(train, image_size):
+ """Initial preprocessing function for training on smaller resolution.
+
+    For training, do random_crop -> resize -> random_horizontal_flip.
+ For validation, just resize.
+ No color-jittering has been used.
+ """
+
+ def _pp(image, label, train):
+ if train:
+ channels = image.shape[-1]
+ begin, size, _ = tf.image.sample_distorted_bounding_box(
+ tf.shape(image),
+ tf.zeros([0, 0, 4], tf.float32),
+ area_range=(0.05, 1.0),
+ min_object_covered=0,
+ use_image_if_no_bounding_boxes=True,
+ )
+ image = tf.slice(image, begin, size)
+
+ image.set_shape([None, None, channels])
+ image = tf.image.resize(image, [image_size, image_size])
+ image = tf.image.random_flip_left_right(image)
+ else:
+ image = tf.image.resize(image, [image_size, image_size])
+
+ return image, label
+
+ return _pp
+
+
+def preprocess_finetune(image, label, train):
+ """Preprocessing function for fine-tuning on a higher resolution.
+
+ For training, resize to a bigger resolution to maintain the ratio ->
+ random_horizontal_flip -> center_crop.
+ For validation, do the same without any horizontal flipping.
+ No color-jittering has been used.
+ """
+ image = tf.image.resize(image, [size_for_resizing, size_for_resizing])
+ if train:
+ image = tf.image.random_flip_left_right(image)
+ image = central_crop_layer(image[None, ...])[0]
+
+ return image, label
+
+
+def make_dataset(
+ dataset: tf.data.Dataset,
+ train: bool,
+ image_size: int = smaller_size,
+ fixres: bool = True,
+ num_parallel_calls=auto,
+):
+ if image_size not in [smaller_size, bigger_size]:
+ raise ValueError(f"{image_size} resolution is not supported.")
+
+ # Determine which preprocessing function we are using.
+ if image_size == smaller_size:
+ preprocess_func = preprocess_initial(train, image_size)
+ elif not fixres and image_size == bigger_size:
+ preprocess_func = preprocess_initial(train, image_size)
+ else:
+ preprocess_func = preprocess_finetune
+
+ dataset = dataset.map(
+ lambda x, y: preprocess_func(x, y, train),
+ num_parallel_calls=num_parallel_calls,
+ )
+ dataset = dataset.batch(batch_size)
+
+ if train:
+ dataset = dataset.shuffle(batch_size * 10)
+
+ return dataset.prefetch(num_parallel_calls)
+
+
+"""
+Notice how the augmentation transforms vary depending on the kind of dataset we are preparing.
+"""
+
+"""
+## Prepare datasets
+"""
+
+initial_train_dataset = make_dataset(train_dataset, train=True, image_size=smaller_size)
+initial_val_dataset = make_dataset(val_dataset, train=False, image_size=smaller_size)
+
+finetune_train_dataset = make_dataset(train_dataset, train=True, image_size=bigger_size)
+finetune_val_dataset = make_dataset(val_dataset, train=False, image_size=bigger_size)
+
+vanilla_train_dataset = make_dataset(
+ train_dataset, train=True, image_size=bigger_size, fixres=False
+)
+vanilla_val_dataset = make_dataset(
+ val_dataset, train=False, image_size=bigger_size, fixres=False
+)
+
+"""
+## Visualize the datasets
+"""
+
+
+def visualize_dataset(batch_images):
+ plt.figure(figsize=(10, 10))
+ for n in range(25):
+ ax = plt.subplot(5, 5, n + 1)
+ plt.imshow(batch_images[n].numpy().astype("int"))
+ plt.axis("off")
+ plt.show()
+
+ print(f"Batch shape: {batch_images.shape}.")
+
+
+# Smaller resolution.
+initial_sample_images, _ = next(iter(initial_train_dataset))
+visualize_dataset(initial_sample_images)
+
+# Bigger resolution, only for fine-tuning.
+finetune_sample_images, _ = next(iter(finetune_train_dataset))
+visualize_dataset(finetune_sample_images)
+
+# Bigger resolution, with the same augmentation transforms as
+# the smaller resolution dataset.
+vanilla_sample_images, _ = next(iter(vanilla_train_dataset))
+visualize_dataset(vanilla_sample_images)
+
+"""
+## Model training utilities
+
+We train multiple variants of ResNet50V2
+([He et al.](https://arxiv.org/abs/1603.05027)):
+
+1. Train one model from scratch on the smaller-resolution dataset (128x128).
+2. Fine-tune the model from step 1 on the larger-resolution (224x224) dataset.
+3. Train another ResNet50V2 from scratch on the larger-resolution dataset.
+
+As a reminder, the larger resolution datasets differ in terms of their augmentation
+transforms.
+"""
+
+
+def get_training_model(num_classes=5):
+ inputs = layers.Input((None, None, 3))
+ resnet_base = keras.applications.ResNet50V2(
+ include_top=False, weights=None, pooling="avg"
+ )
+ resnet_base.trainable = True
+
+ x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
+ x = resnet_base(x)
+ outputs = layers.Dense(num_classes, activation="softmax")(x)
+ return keras.Model(inputs, outputs)
+
+
+def train_and_evaluate(
+ model,
+ train_ds,
+ val_ds,
+ epochs,
+ learning_rate=1e-3,
+ use_early_stopping=False,
+):
+ optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
+ model.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ )
+
+ if use_early_stopping:
+ es_callback = keras.callbacks.EarlyStopping(patience=5)
+ callbacks = [es_callback]
+ else:
+ callbacks = None
+
+ model.fit(
+ train_ds,
+ validation_data=val_ds,
+ epochs=epochs,
+ callbacks=callbacks,
+ )
+
+ _, accuracy = model.evaluate(val_ds)
+ print(f"Top-1 accuracy on the validation set: {accuracy*100:.2f}%.")
+ return model
+
+
+"""
+## Experiment 1: Train on 128x128 and then fine-tune on 224x224
+"""
+
+epochs = 30
+
+smaller_res_model = get_training_model()
+smaller_res_model = train_and_evaluate(
+ smaller_res_model, initial_train_dataset, initial_val_dataset, epochs
+)
+
+"""
+### Freeze all the layers except for the final Batch Normalization layer
+
+For fine-tuning, we train only two layers:
+
+* The final Batch Normalization ([Ioffe et al.](https://arxiv.org/abs/1502.03167)) layer.
+* The classification layer.
+
+We are unfreezing the final Batch Normalization layer to compensate for the change in
+activation statistics before the global average pooling layer. As shown in
+[the paper](https://arxiv.org/abs/1906.06423), unfreezing the final Batch
+Normalization layer is enough.
+
+For a comprehensive guide on fine-tuning models in Keras, refer to
+[this tutorial](https://keras.io/guides/transfer_learning/).
+"""
+
+for layer in smaller_res_model.layers[2].layers:
+ layer.trainable = False
+
+smaller_res_model.layers[2].get_layer("post_bn").trainable = True
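+
+# Optional sanity check (not part of the original recipe): inside the ResNet50V2
+# backbone, only `post_bn` should now be trainable; the classification head at the
+# outer-model level stays trainable.
+print(
+    "Trainable backbone layers:",
+    [layer.name for layer in smaller_res_model.layers[2].layers if layer.trainable],
+)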
+
+epochs = 10
+
+# Use a lower learning rate during fine-tuning.
+bigger_res_model = train_and_evaluate(
+ smaller_res_model,
+ finetune_train_dataset,
+ finetune_val_dataset,
+ epochs,
+ learning_rate=1e-4,
+)
+
+"""
+## Experiment 2: Train a model on 224x224 resolution from scratch
+
+Now, we train another model from scratch on the larger resolution dataset. Recall that
+the augmentation transforms used in this dataset are different from before.
+"""
+
+epochs = 30
+
+vanilla_bigger_res_model = get_training_model()
+vanilla_bigger_res_model = train_and_evaluate(
+ vanilla_bigger_res_model, vanilla_train_dataset, vanilla_val_dataset, epochs
+)
+
+"""
+As we can see from the cells above, FixRes leads to better performance. Other
+advantages of FixRes are the reduced total training time and lower GPU memory usage.
+FixRes is model-agnostic: you can use it with any image classification model
+to potentially boost performance.
+
+You can find more results
+[here](https://tensorboard.dev/experiment/BQOg28w0TlmvuJYeqsVntw)
+that were gathered by running the same code with different random seeds.
+"""
diff --git a/knowledge_base/vision/focal_modulation_network.py b/knowledge_base/vision/focal_modulation_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..d50fa5fd36cf5b0a369f743d86b02c004176f3d5
--- /dev/null
+++ b/knowledge_base/vision/focal_modulation_network.py
@@ -0,0 +1,1085 @@
+"""
+Title: Focal Modulation: A replacement for Self-Attention
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+Date created: 2023/01/25
+Last modified: 2023/02/15
+Description: Image classification with Focal Modulation Networks.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This tutorial aims to provide a comprehensive guide to the implementation of
+Focal Modulation Networks, as presented in
+[Yang et al.](https://arxiv.org/abs/2203.11926).
+
+This tutorial will provide a formal, minimalistic approach to implementing Focal
+Modulation Networks and explore its potential applications in the field of Deep Learning.
+
+**Problem statement**
+
+The Transformer architecture ([Vaswani et al.](https://arxiv.org/abs/1706.03762)),
+which has become the de facto standard in most Natural Language Processing tasks, has
+also been applied to the field of computer vision, e.g. Vision
+Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929v2)).
+
+> In Transformers, the self-attention (SA) is arguably the key to its success which
+enables input-dependent global interactions, in contrast to convolution operation which
+constraints interactions in a local region with a shared kernel.
+
+The **Attention** module is mathematically written as shown in **Equation 1**.
+
+| `Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V` |
+| :--: |
+| Equation 1: The mathematical equation of attention (Source: Aritra and Ritwik) |
+
+Where:
+
+- `Q` is the query
+- `K` is the key
+- `V` is the value
+- `d_k` is the dimension of the key
+
+With **self-attention**, the query, key, and value are all sourced from the input
+sequence. Let us rewrite the attention equation for self-attention as shown in **Equation
+2**.
+
+| `Attention(X) = softmax(X X^T / sqrt(d_k)) X`, with `Q = K = V = X` |
+| :--: |
+| Equation 2: The mathematical equation of self-attention (Source: Aritra and Ritwik) |
+
+Looking at the equation of self-attention, we see that its cost is quadratic in the
+number of tokens: as the number of tokens increases, so do the computation time and cost.
+To mitigate this problem and make Transformers more interpretable, Yang et al.
+have tried to replace the Self-Attention module with better components.
+
+**The Solution**
+
+Yang et al. introduce the Focal Modulation layer to serve as a
+seamless replacement for the Self-Attention Layer. The layer boasts high
+interpretability, making it a valuable tool for Deep Learning practitioners.
+
+In this tutorial, we will delve into the practical application of this layer by training
+the entire model on the CIFAR-10 dataset and visually interpreting the layer's
+performance.
+
+Note: We try to align our implementation with the
+[official implementation](https://github.com/microsoft/FocalNet).
+"""
+
+"""
+## Setup and Imports
+
+We use tensorflow version `2.11.0` for this tutorial.
+"""
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras.optimizers.experimental import AdamW
+from typing import Optional, Tuple, List
+from matplotlib import pyplot as plt
+from random import randint
+
+# Set seed for reproducibility.
+tf.keras.utils.set_random_seed(42)
+
+"""
+## Global Configuration
+
+We do not have any strong rationale behind choosing these hyperparameters. Please feel
+free to change the configuration and train the model.
+"""
+
+# DATA
+TRAIN_SLICE = 40000
+BUFFER_SIZE = 2048
+BATCH_SIZE = 1024
+AUTO = tf.data.AUTOTUNE
+INPUT_SHAPE = (32, 32, 3)
+IMAGE_SIZE = 48
+NUM_CLASSES = 10
+
+# OPTIMIZER
+LEARNING_RATE = 1e-4
+WEIGHT_DECAY = 1e-4
+
+# TRAINING
+EPOCHS = 25
+
+"""
+## Load and process the CIFAR-10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
+ (x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
+)
+
+"""
+### Build the augmentations
+
+We use the `keras.Sequential` API to compose all the individual augmentation steps
+into one API.
+"""
+
+# Build the `train` augmentation pipeline.
+train_aug = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
+ layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ ],
+ name="train_data_augmentation",
+)
+
+# Build the `val` and `test` data pipeline.
+test_aug = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
+ ],
+ name="test_data_augmentation",
+)
+
+"""
+### Build `tf.data` pipeline
+"""
+
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = (
+ train_ds.map(
+ lambda image, label: (train_aug(image), label), num_parallel_calls=AUTO
+ )
+ .shuffle(BUFFER_SIZE)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = (
+ val_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = (
+ test_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+"""
+## Architecture
+
+We pause here to take a quick look at the architecture of the Focal Modulation Network.
+**Figure 1** shows how the individual layers are composed into a single model. This gives
+us a bird's eye view of the entire architecture.
+
+Figure 1: A diagram of the Focal Modulation model (Source: Aritra and Ritwik)
+
+We dive deep into each of these layers in the following sections. This is the order we
+will follow:
+
+
+- Patch Embedding Layer
+- Focal Modulation Block
+ - Multi-Layer Perceptron
+ - Focal Modulation Layer
+ - Hierarchical Contextualization
+ - Gated Aggregation
+ - Building Focal Modulation Block
+- Building the Basic Layer
+
+To better understand the architecture in a format we are well versed in, let us see how
+the Focal Modulation Network would look when drawn like a Transformer architecture.
+
+**Figure 2** shows the encoder layer of a traditional Transformer architecture where Self
+Attention is replaced with the Focal Modulation layer.
+
+The blue blocks represent the Focal Modulation block. A stack
+of these blocks builds a single Basic Layer. The green blocks
+represent the Focal Modulation layer.
+
+Figure 2: The Entire Architecture (Source: Aritra and Ritwik)
+"""
+
+"""
+## Patch Embedding Layer
+
+The patch embedding layer is used to patchify the input images and project them into a
+latent space. This layer is also used as the down-sampling layer in the architecture.
+"""
+
+
+class PatchEmbed(layers.Layer):
+ """Image patch embedding layer, also acts as the down-sampling layer.
+
+ Args:
+ image_size (Tuple[int]): Input image resolution.
+ patch_size (Tuple[int]): Patch spatial resolution.
+ embed_dim (int): Embedding dimension.
+ """
+
+ def __init__(
+ self,
+ image_size: Tuple[int] = (224, 224),
+ patch_size: Tuple[int] = (4, 4),
+ embed_dim: int = 96,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ patch_resolution = [
+ image_size[0] // patch_size[0],
+ image_size[1] // patch_size[1],
+ ]
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.patch_resolution = patch_resolution
+ self.num_patches = patch_resolution[0] * patch_resolution[1]
+ self.proj = layers.Conv2D(
+ filters=embed_dim, kernel_size=patch_size, strides=patch_size
+ )
+ self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
+ self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
+
+ def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, int, int, int]:
+ """Patchifies the image and converts into tokens.
+
+ Args:
+ x: Tensor of shape (B, H, W, C)
+
+ Returns:
+ A tuple of the processed tensor, height of the projected
+ feature map, width of the projected feature map, number
+ of channels of the projected feature map.
+ """
+ # Project the inputs.
+ x = self.proj(x)
+
+ # Obtain the shape from the projected tensor.
+ height = tf.shape(x)[1]
+ width = tf.shape(x)[2]
+ channels = tf.shape(x)[3]
+
+ # B, H, W, C -> B, H*W, C
+ x = self.norm(self.flatten(x))
+
+ return x, height, width, channels
+
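+
+"""
+As a quick shape check: with the configuration used later in this example
+(`image_size=(48, 48)`, `patch_size=(4, 4)`), this layer produces
+`(48 / 4) * (48 / 4) = 144` tokens per image, each of dimension `embed_dim`.
+"""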
+
+"""
+## Focal Modulation block
+
+A Focal Modulation block can be considered a single Transformer block with the Self
+Attention (SA) module replaced by the Focal Modulation module, as we saw in **Figure
+2**.
+
+Let us recall what a Focal Modulation block looks like with the aid of **Figure 3**.
+
+
+Figure 3: The isolated view of the Focal Modulation Block (Source: Aritra and Ritwik)
+
+The Focal Modulation Block consists of:
+- Multilayer Perceptron
+- Focal Modulation layer
+"""
+
+"""
+### Multilayer Perceptron
+"""
+
+
+def MLP(
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ mlp_drop_rate: float = 0.0,
+):
+ hidden_features = hidden_features or in_features
+ out_features = out_features or in_features
+
+ return keras.Sequential(
+ [
+ layers.Dense(units=hidden_features, activation=keras.activations.gelu),
+ layers.Dense(units=out_features),
+ layers.Dropout(rate=mlp_drop_rate),
+ ]
+ )
+
+
+"""
+### Focal Modulation layer
+
+In a typical Transformer architecture, for each visual token (**query**) `x_i in R^C` in
+an input feature map `X in R^{HxWxC}` a **generic encoding process** produces a feature
+representation `y_i in R^C`.
+
+The encoding process consists of **interaction** (with its surroundings, e.g. a dot
+product) and **aggregation** (over the contexts, e.g. a weighted mean).
+
+We will talk about two types of encoding here:
+- Interaction and then Aggregation in **Self-Attention**
+- Aggregation and then Interaction in **Focal Modulation**
+
+**Self-Attention**
+
+**Figure 4**: Self-Attention module. (Source: Aritra and Ritwik)
+
+| `y_i = M_1(T_1(x_i, X), X)` |
+| :--: |
+| **Equation 3:** Aggregation and Interaction in Self-Attention (Source: Aritra and Ritwik) |
+
+As shown in **Figure 4** the query and the key interact (in the interaction step) with
+each other to output the attention scores. The weighted aggregation of the value comes
+next, known as the aggregation step.
+
+**Focal Modulation**
+
+**Figure 5**: Focal Modulation module. (Source: Aritra and Ritwik)
+
+| `y_i = T_2(M_2(i, X), x_i)` |
+| :--: |
+| **Equation 4:** Aggregation and Interaction in Focal Modulation (Source: Aritra and Ritwik) |
+
+**Figure 5** depicts the Focal Modulation layer. `q()` is the query projection
+function: a **linear layer** that projects the query into a latent space. `m()` is
+the context aggregation function. Unlike in self-attention, in focal modulation the
+aggregation step takes place before the interaction step.
+"""
+
+"""
+While `q()` is pretty straightforward to understand, the context aggregation function
+`m()` is more complex. Therefore, this section will focus on `m()`.
+
+**Figure 6**: Context Aggregation function `m()`. (Source: Aritra and Ritwik)
+
+The context aggregation function `m()` consists of two parts as shown in **Figure 6**:
+- Hierarchical Contextualization
+- Gated Aggregation
+"""
+
+"""
+#### Hierarchical Contextualization
+
+**Figure 7**: Hierarchical Contextualization (Source: Aritra and Ritwik)
+
+In **Figure 7**, we see that the input is first projected linearly. This linear projection
+produces `Z^0`, which can be expressed as follows:
+
+| `Z^0 = f_z(X)` |
+| :--: |
+| Equation 5: Linear projection of `Z^0` (Source: Aritra and Ritwik) |
+
+`Z^0` is then passed on to a series of depth-wise convolution (DWConv) and
+[GeLU](https://www.tensorflow.org/api_docs/python/tf/keras/activations/gelu) layers. The
+authors term each DWConv-and-GeLU block a level, denoted by `l`. In **Figure 6** we
+have two levels. Mathematically, this is represented as:
+
+| `Z^l = GeLU(DWConv(Z^{l-1}))` |
+| :--: |
+| Equation 6: Levels of the modulation layer (Source: Aritra and Ritwik) |
+
+where `l in {1, ... , L}`
+
+The final feature map goes through a Global Average Pooling Layer. This can be expressed
+as follows:
+
+| `Z^{L+1} = AvgPool(Z^L)` |
+| :--: |
+| Equation 7: Average Pooling of the final feature (Source: Aritra and Ritwik) |
+"""
+
+"""
+#### Gated Aggregation
+
+**Figure 8**: Gated Aggregation (Source: Aritra and Ritwik)
+
+Now that we have `L+1` intermediate feature maps from the Hierarchical
+Contextualization step, we need a gating mechanism that lets some features pass and
+suppresses others. This is implemented with gates derived from the input itself.
+Later in the tutorial, we will visualize these gates to better understand their
+usefulness.
+
+First, we build the weights for aggregation. Here we apply a **linear layer** on the input
+feature map that projects it into `L+1` dimensions.
+
+| `G = f_g(X)` |
+| :--: |
+| Equation 8: Gates (Source: Aritra and Ritwik) |
+
+Next we perform the weighted aggregation over the contexts.
+
+| `Z^out = sum_{l=1}^{L+1} G^l * Z^l` |
+| :--: |
+| Equation 9: Final feature map (Source: Aritra and Ritwik) |
+
+To enable communication across different channels, we use another linear layer `h()`
+to obtain the modulator
+
+| `M = h(Z^out)` |
+| :--: |
+| Equation 10: Modulator (Source: Aritra and Ritwik) |
+
+To sum up the Focal Modulation layer we have:
+
+| `y_i = q(x_i) * M_i` |
+| :--: |
+| Equation 11: Focal Modulation Layer (Source: Aritra and Ritwik) |
+"""
+
+
+class FocalModulationLayer(layers.Layer):
+ """The Focal Modulation layer includes query projection & context aggregation.
+
+ Args:
+ dim (int): Projection dimension.
+ focal_window (int): Window size for focal modulation.
+ focal_level (int): The current focal level.
+ focal_factor (int): Factor of focal modulation.
+ proj_drop_rate (float): Rate of dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ focal_window: int,
+ focal_level: int,
+ focal_factor: int = 2,
+ proj_drop_rate: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.focal_factor = focal_factor
+ self.proj_drop_rate = proj_drop_rate
+
+ # Project the input feature into a new feature space using a
+ # linear layer. Note the `units` used. We will be projecting the input
+ # feature all at once and split the projection into query, context,
+ # and gates.
+ self.initial_proj = layers.Dense(
+ units=(2 * self.dim) + (self.focal_level + 1),
+ use_bias=True,
+ )
+ self.focal_layers = list()
+ self.kernel_sizes = list()
+ for idx in range(self.focal_level):
+ kernel_size = (self.focal_factor * idx) + self.focal_window
+ depth_gelu_block = keras.Sequential(
+ [
+ layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2)),
+ layers.Conv2D(
+ filters=self.dim,
+ kernel_size=kernel_size,
+ activation=keras.activations.gelu,
+ groups=self.dim,
+ use_bias=False,
+ ),
+ ]
+ )
+ self.focal_layers.append(depth_gelu_block)
+ self.kernel_sizes.append(kernel_size)
+ self.activation = keras.activations.gelu
+ self.gap = layers.GlobalAveragePooling2D(keepdims=True)
+ self.modulator_proj = layers.Conv2D(
+ filters=self.dim,
+ kernel_size=(1, 1),
+ use_bias=True,
+ )
+ self.proj = layers.Dense(units=self.dim)
+ self.proj_drop = layers.Dropout(self.proj_drop_rate)
+
+ def call(self, x: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
+ """Forward pass of the layer.
+
+ Args:
+ x: Tensor of shape (B, H, W, C)
+ """
+        # Apply the linear projection to the input feature map
+ x_proj = self.initial_proj(x)
+
+ # Split the projected x into query, context and gates
+ query, context, self.gates = tf.split(
+ value=x_proj,
+ num_or_size_splits=[self.dim, self.dim, self.focal_level + 1],
+ axis=-1,
+ )
+
+ # Context aggregation
+ context = self.focal_layers[0](context)
+ context_all = context * self.gates[..., 0:1]
+ for idx in range(1, self.focal_level):
+ context = self.focal_layers[idx](context)
+ context_all += context * self.gates[..., idx : idx + 1]
+
+ # Build the global context
+ context_global = self.activation(self.gap(context))
+ context_all += context_global * self.gates[..., self.focal_level :]
+
+ # Focal Modulation
+ self.modulator = self.modulator_proj(context_all)
+ x_output = query * self.modulator
+
+ # Project the output and apply dropout
+ x_output = self.proj(x_output)
+ x_output = self.proj_drop(x_output)
+
+ return x_output
+
+
+"""
+### The Focal Modulation block
+
+Finally, we have all the components we need to build the Focal Modulation block. Here we
+take the MLP and Focal Modulation layer together and build the Focal Modulation block.
+"""
+
+
+class FocalModulationBlock(layers.Layer):
+ """Combine FFN and Focal Modulation Layer.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (Tuple[int]): Input resolution.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float): Dropout rate.
+ drop_path (float): Stochastic depth rate.
+ focal_level (int): Number of focal levels.
+ focal_window (int): Focal window size at first focal level
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ input_resolution: Tuple[int],
+ mlp_ratio: float = 4.0,
+ drop: float = 0.0,
+ drop_path: float = 0.0,
+ focal_level: int = 1,
+ focal_window: int = 3,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.mlp_ratio = mlp_ratio
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.norm = layers.LayerNormalization(epsilon=1e-5)
+ self.modulation = FocalModulationLayer(
+ dim=self.dim,
+ focal_window=self.focal_window,
+ focal_level=self.focal_level,
+ proj_drop_rate=drop,
+ )
+ mlp_hidden_dim = int(self.dim * self.mlp_ratio)
+ self.mlp = MLP(
+ in_features=self.dim,
+ hidden_features=mlp_hidden_dim,
+ mlp_drop_rate=drop,
+ )
+
+ def call(self, x: tf.Tensor, height: int, width: int, channels: int) -> tf.Tensor:
+ """Processes the input tensor through the focal modulation block.
+
+ Args:
+ x (tf.Tensor): Inputs of the shape (B, L, C)
+ height (int): The height of the feature map
+ width (int): The width of the feature map
+ channels (int): The number of channels of the feature map
+
+ Returns:
+ The processed tensor.
+ """
+ shortcut = x
+
+ # Focal Modulation
+ x = tf.reshape(x, shape=(-1, height, width, channels))
+ x = self.modulation(x)
+ x = tf.reshape(x, shape=(-1, height * width, channels))
+
+ # FFN
+ x = shortcut + x
+ x = x + self.mlp(self.norm(x))
+ return x
+
+
+"""
+## The Basic Layer
+
+The basic layer consists of a collection of Focal Modulation blocks. This is
+illustrated in **Figure 9**.
+
+**Figure 9**: Basic Layer, a collection of focal modulation blocks. (Source: Aritra and Ritwik)
+
+Notice how in **Figure 9** there is more than one Focal Modulation block, denoted by `Nx`.
+This shows how the Basic Layer is a stack of Focal Modulation blocks.
+"""
+
+
+class BasicLayer(layers.Layer):
+ """Collection of Focal Modulation Blocks.
+
+ Args:
+ dim (int): Dimensions of the model.
+ out_dim (int): Dimension used by the Patch Embedding Layer.
+ input_resolution (Tuple[int]): Input image resolution.
+ depth (int): The number of Focal Modulation Blocks.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float): Dropout rate.
+ downsample (tf.keras.layers.Layer): Downsampling layer at the end of the layer.
+ focal_level (int): The current focal level.
+ focal_window (int): Focal window used.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ out_dim: int,
+ input_resolution: Tuple[int],
+ depth: int,
+ mlp_ratio: float = 4.0,
+ drop: float = 0.0,
+ downsample=None,
+ focal_level: int = 1,
+ focal_window: int = 1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.blocks = [
+ FocalModulationBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ focal_level=focal_level,
+ focal_window=focal_window,
+ )
+ for i in range(self.depth)
+ ]
+
+ # Downsample layer at the end of the layer
+ if downsample is not None:
+ self.downsample = downsample(
+ image_size=input_resolution,
+ patch_size=(2, 2),
+ embed_dim=out_dim,
+ )
+ else:
+ self.downsample = None
+
+ def call(
+ self, x: tf.Tensor, height: int, width: int, channels: int
+ ) -> Tuple[tf.Tensor, int, int, int]:
+ """Forward pass of the layer.
+
+ Args:
+ x (tf.Tensor): Tensor of shape (B, L, C)
+ height (int): Height of feature map
+ width (int): Width of feature map
+ channels (int): Embed Dim of feature map
+
+ Returns:
+ A tuple of the processed tensor, changed height, width, and
+ dim of the tensor.
+ """
+ # Apply Focal Modulation Blocks
+ for block in self.blocks:
+ x = block(x, height, width, channels)
+
+        # Except for the last Basic Layer, every Basic Layer has a
+        # downsampling layer at the end.
+ if self.downsample is not None:
+ x = tf.reshape(x, shape=(-1, height, width, channels))
+ x, height_o, width_o, channels_o = self.downsample(x)
+ else:
+ height_o, width_o, channels_o = height, width, channels
+
+ return x, height_o, width_o, channels_o
+
+
+"""
+## The Focal Modulation Network model
+
+This is the model that ties everything together.
+It consists of a collection of Basic Layers with a classification head.
+For a recap of how this is structured refer to **Figure 1**.
+"""
+
+
+class FocalModulationNetwork(keras.Model):
+ """The Focal Modulation Network.
+
+ Parameters:
+ image_size (Tuple[int]): Spatial size of images used.
+ patch_size (Tuple[int]): Patch size of each patch.
+ num_classes (int): Number of classes used for classification.
+ embed_dim (int): Patch embedding dimension.
+ depths (List[int]): Depth of each Focal Transformer block.
+ mlp_ratio (float): Ratio of expansion for the intermediate layer of MLP.
+ drop_rate (float): The dropout rate for FM and MLP layers.
+ focal_levels (list): How many focal levels at all stages.
+ Note that this excludes the finest-grain level.
+ focal_windows (list): The focal window size at all stages.
+ """
+
+ def __init__(
+ self,
+ image_size: Tuple[int] = (48, 48),
+ patch_size: Tuple[int] = (4, 4),
+ num_classes: int = 10,
+ embed_dim: int = 256,
+ depths: List[int] = [2, 3, 2],
+ mlp_ratio: float = 4.0,
+ drop_rate: float = 0.1,
+ focal_levels=[2, 2, 2],
+ focal_windows=[3, 3, 3],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.num_layers = len(depths)
+ embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
+ self.num_classes = num_classes
+ self.embed_dim = embed_dim
+ self.num_features = embed_dim[-1]
+ self.mlp_ratio = mlp_ratio
+ self.patch_embed = PatchEmbed(
+ image_size=image_size,
+ patch_size=patch_size,
+ embed_dim=embed_dim[0],
+ )
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patch_resolution
+ self.patches_resolution = patches_resolution
+ self.pos_drop = layers.Dropout(drop_rate)
+ self.basic_layers = list()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=embed_dim[i_layer],
+ out_dim=(
+ embed_dim[i_layer + 1] if (i_layer < self.num_layers - 1) else None
+ ),
+ input_resolution=(
+ patches_resolution[0] // (2**i_layer),
+ patches_resolution[1] // (2**i_layer),
+ ),
+ depth=depths[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ drop=drop_rate,
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
+ focal_level=focal_levels[i_layer],
+ focal_window=focal_windows[i_layer],
+ )
+ self.basic_layers.append(layer)
+ self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
+ self.avgpool = layers.GlobalAveragePooling1D()
+ self.flatten = layers.Flatten()
+ self.head = layers.Dense(self.num_classes, activation="softmax")
+
+ def call(self, x: tf.Tensor) -> tf.Tensor:
+ """Forward pass of the layer.
+
+ Args:
+ x: Tensor of shape (B, H, W, C)
+
+ Returns:
+ The logits.
+ """
+ # Patch Embed the input images.
+ x, height, width, channels = self.patch_embed(x)
+ x = self.pos_drop(x)
+
+ for idx, layer in enumerate(self.basic_layers):
+ x, height, width, channels = layer(x, height, width, channels)
+
+ x = self.norm(x)
+ x = self.avgpool(x)
+ x = self.flatten(x)
+ x = self.head(x)
+ return x
+
+
+"""
+## Train the model
+
+Now with all the components in place and the architecture actually built, we are ready to
+put it to good use.
+
+In this section, we train our Focal Modulation model on the CIFAR-10 dataset.
+"""
+
+"""
+### Visualization Callback
+
+A key feature of the Focal Modulation Network is explicit input-dependency. This means
+the modulator is calculated by looking at the local features around the target location,
+so it depends on the input. In very simple terms, this makes interpretation easy: we can
+simply lay the gating values and the original image next to each other to see how
+the gating mechanism works.
+
+The authors of the paper visualize the gates and the modulator in order to focus on the
+interpretability of the Focal Modulation layer. Below is a visualization
+callback that shows the gates and modulator of a specific layer in the model while the
+model trains.
+
+We will notice later that as the model trains, the visualizations get better.
+
+The gates appear to selectively permit certain aspects of the input image to pass
+through, while gently disregarding others, ultimately leading to improved classification
+accuracy.
+"""
+
+
+def display_grid(
+ test_images: tf.Tensor,
+ gates: tf.Tensor,
+ modulator: tf.Tensor,
+):
+ """Displays the image with the gates and modulator overlayed.
+
+ Args:
+ test_images (tf.Tensor): A batch of test images.
+        gates (tf.Tensor): The gates of the Focal Modulation Layer.
+ modulator (tf.Tensor): The modulator of the Focal Modulation Layer.
+ """
+ fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))
+
+    # Randomly sample an image from the batch.
+ index = randint(0, BATCH_SIZE - 1)
+ orig_image = test_images[index]
+ gate_image = gates[index]
+ modulator_image = modulator[index]
+
+ # Original Image
+ ax[0].imshow(orig_image)
+ ax[0].set_title("Original:")
+ ax[0].axis("off")
+
+ for index in range(1, 5):
+ img = ax[index].imshow(orig_image)
+ if index != 4:
+ overlay_image = gate_image[..., index - 1]
+ title = f"G {index}:"
+ else:
+ overlay_image = tf.norm(modulator_image, ord=2, axis=-1)
+            title = "MOD:"
+
+ ax[index].imshow(
+ overlay_image, cmap="inferno", alpha=0.6, extent=img.get_extent()
+ )
+ ax[index].set_title(title)
+ ax[index].axis("off")
+
+ plt.axis("off")
+ plt.show()
+ plt.close()
+
+
+"""
+### TrainMonitor
+"""
+
+# Taking a batch of test inputs to measure the model's progress.
+test_images, test_labels = next(iter(test_ds))
+upsampler = tf.keras.layers.UpSampling2D(
+ size=(4, 4),
+ interpolation="bilinear",
+)
+
+
+class TrainMonitor(keras.callbacks.Callback):
+    def __init__(self, epoch_interval=None):
+        super().__init__()
+        self.epoch_interval = epoch_interval
+
+ def on_epoch_end(self, epoch, logs=None):
+ if self.epoch_interval and epoch % self.epoch_interval == 0:
+ _ = self.model(test_images)
+
+ # Take the mid layer for visualization
+ gates = self.model.basic_layers[1].blocks[-1].modulation.gates
+ gates = upsampler(gates)
+ modulator = self.model.basic_layers[1].blocks[-1].modulation.modulator
+ modulator = upsampler(modulator)
+
+ # Display the grid of gates and modulator.
+ display_grid(test_images=test_images, gates=gates, modulator=modulator)
+
+
+"""
+### Learning Rate scheduler
+"""
+
+
+# Some code is taken from:
+# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(
+ self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
+ ):
+ super().__init__()
+ self.learning_rate_base = learning_rate_base
+ self.total_steps = total_steps
+ self.warmup_learning_rate = warmup_learning_rate
+ self.warmup_steps = warmup_steps
+ self.pi = tf.constant(np.pi)
+
+ def __call__(self, step):
+ if self.total_steps < self.warmup_steps:
+ raise ValueError("Total_steps must be larger or equal to warmup_steps.")
+ cos_annealed_lr = tf.cos(
+ self.pi
+ * (tf.cast(step, tf.float32) - self.warmup_steps)
+ / float(self.total_steps - self.warmup_steps)
+ )
+ learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
+ if self.warmup_steps > 0:
+ if self.learning_rate_base < self.warmup_learning_rate:
+                raise ValueError(
+                    "learning_rate_base must be larger than or equal to "
+                    "warmup_learning_rate."
+                )
+ slope = (
+ self.learning_rate_base - self.warmup_learning_rate
+ ) / self.warmup_steps
+ warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
+ learning_rate = tf.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+ return tf.where(
+ step > self.total_steps, 0.0, learning_rate, name="learning_rate"
+ )
+
+
+total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+scheduled_lrs = WarmUpCosine(
+ learning_rate_base=LEARNING_RATE,
+ total_steps=total_steps,
+ warmup_learning_rate=0.0,
+ warmup_steps=warmup_steps,
+)
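+
+# Optional: plot the schedule over the full run as a quick sanity check of the
+# warmup and cosine-decay phases.
+lrs = [scheduled_lrs(step) for step in range(total_steps)]
+plt.plot(lrs)
+plt.xlabel("Step")
+plt.ylabel("Learning rate")
+plt.show()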
+
+"""
+### Initialize, compile and train the model
+"""
+
+focal_mod_net = FocalModulationNetwork()
+optimizer = AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)
+
+# Compile and train the model.
+focal_mod_net.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+history = focal_mod_net.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=val_ds,
+ callbacks=[TrainMonitor(epoch_interval=10)],
+)
+
+"""
+## Plot loss and accuracy
+"""
+
+plt.plot(history.history["loss"], label="loss")
+plt.plot(history.history["val_loss"], label="val_loss")
+plt.legend()
+plt.show()
+
+plt.plot(history.history["accuracy"], label="accuracy")
+plt.plot(history.history["val_accuracy"], label="val_accuracy")
+plt.legend()
+plt.show()
+
+"""
+## Test visualizations
+
+Let's test our model on some test images and see what the gates look like.
+"""
+
+test_images, test_labels = next(iter(test_ds))
+_ = focal_mod_net(test_images)
+
+# Take the mid layer for visualization
+gates = focal_mod_net.basic_layers[1].blocks[-1].modulation.gates
+gates = upsampler(gates)
+modulator = focal_mod_net.basic_layers[1].blocks[-1].modulation.modulator
+modulator = upsampler(modulator)
+
+# Plot the test images with the gates and modulator overlaid.
+for row in range(5):
+ display_grid(
+ test_images=test_images,
+ gates=gates,
+ modulator=modulator,
+ )
+
+"""
+## Conclusion
+
+The proposed Focal Modulation Network is a mechanism that allows different
+parts of an image to interact with each other in a way that depends on the image itself.
+It works by first gathering different levels of context information around each part of
+the image (the "query token"), then using a gate to decide which context information is
+most relevant, and finally combining the chosen information in a simple but effective
+way.
+
+It is meant as a replacement for the Self-Attention mechanism of the Transformer
+architecture. What makes this research notable is not the conception of
+attention-free networks, but rather the introduction of an equally powerful architecture
+that is interpretable.
+
+The authors also mention that they created a series of Focal Modulation Networks
+(FocalNets) that significantly outperform their Self-Attention counterparts while using a
+fraction of the parameters and pretraining data.
+
+The FocalNets architecture has the potential to deliver impressive results and offers a
+simple implementation. Its promising performance and ease of use make it an attractive
+alternative to Self-Attention for researchers to explore in their own projects. It could
+potentially become widely adopted by the Deep Learning community in the near future.
+
+## Acknowledgement
+
+We would like to thank [PyImageSearch](https://pyimagesearch.com/) for providing us with a
+Colab Pro account, [JarvisLabs.ai](https://cloud.jarvislabs.ai/) for GPU credits,
+and also Microsoft Research for providing an
+[official implementation](https://github.com/microsoft/FocalNet) of their paper.
+We would also like to extend our gratitude to the first author of the
+paper [Jianwei Yang](https://twitter.com/jw2yang4ai) who reviewed this tutorial
+extensively.
+"""
diff --git a/knowledge_base/vision/forwardforward.py b/knowledge_base/vision/forwardforward.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0d8573c322974faab63b71ee4df927629255f1
--- /dev/null
+++ b/knowledge_base/vision/forwardforward.py
@@ -0,0 +1,446 @@
+"""
+Title: Using the Forward-Forward Algorithm for Image Classification
+Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
+Date created: 2023/01/08
+Last modified: 2024/09/17
+Description: Training a Dense-layer model using the Forward-Forward algorithm.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The following example explores how to use the Forward-Forward algorithm to perform
+training instead of the traditionally-used method of backpropagation, as proposed by
+Hinton in
+[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)
+(2022).
+
+The concept was inspired by the understanding behind
+[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation
+involves calculating the difference between actual and predicted output via a cost
+function to adjust network weights. On the other hand, the FF Algorithm suggests the
+analogy of neurons which get "excited" based on looking at a certain recognized
+combination of an image and its correct corresponding label.
+
+This method takes certain inspiration from the biological learning process that occurs in
+the cortex. A significant advantage that this method brings is the fact that
+backpropagation through the network does not need to be performed anymore, and that
+weight updates are local to the layer itself.
+
+As this is still an experimental method, it does not yield state-of-the-art results.
+But with proper tuning, it is supposed to come close to backpropagation's performance.
+Through this example, we will examine a process that allows us to implement the
+Forward-Forward algorithm within the layers themselves, instead of the traditional method
+of relying on the global loss functions and optimizers.
+
+The tutorial is structured as follows:
+
+- Perform necessary imports
+- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)
+- Visualize Random samples from the MNIST dataset
+- Define a `FFDense` Layer to override `call` and implement a custom `forward_forward`
+method which performs weight updates.
+- Define a `FFNetwork` Model to override `train_step`, `predict` and implement 2 custom
+functions for per-sample prediction and overlaying labels
+- Convert MNIST from `NumPy` arrays to `tf.data.Dataset`
+- Fit the network
+- Visualize results
+- Perform inference on test samples
+
+As this example requires the customization of certain core functions with
+`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for
+a primer on how to do so:
+
+- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)
+- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
+"""
+
+"""
+## Setup imports
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf
+import keras
+from keras import ops
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
+import random
+from tensorflow.compiler.tf2xla.python import xla
+
+"""
+## Load the dataset and visualize the data
+
+We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset
+in the form of `NumPy` arrays. We then arrange it in the form of the train and test
+splits.
+
+Following loading the dataset, we select 4 random samples from within the training set
+and visualize them using `matplotlib.pyplot`.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+print("4 Random Training samples and labels")
+idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)
+
+img1 = (x_train[idx1], y_train[idx1])
+img2 = (x_train[idx2], y_train[idx2])
+img3 = (x_train[idx3], y_train[idx3])
+img4 = (x_train[idx4], y_train[idx4])
+
+imgs = [img1, img2, img3, img4]
+
+plt.figure(figsize=(10, 10))
+
+for idx, item in enumerate(imgs):
+ image, label = item[0], item[1]
+ plt.subplot(2, 2, idx + 1)
+ plt.imshow(image, cmap="gray")
+ plt.title(f"Label : {label}")
+plt.show()
+
+"""
+## Define `FFDense` custom layer
+
+In this custom layer, we have a base `keras.layers.Dense` object which acts as the
+base `Dense` layer within. Since weight updates will happen within the layer itself, we
+add a `keras.optimizers.Optimizer` object that is accepted from the user. Here, we
+use `Adam` as our optimizer with a rather high learning rate of `0.03`.
+
+Following the algorithm's specifics, we must set a `threshold` parameter that will be
+used to make the positive-negative decision in each prediction. This is set to a default
+of `1.5`.
+As the epochs are localized to the layer itself, we also set a `num_epochs` parameter
+(defaults to 50).
+
+We override the `call` method in order to perform a normalization over the complete
+input space followed by running it through the base `Dense` layer as would happen in a
+normal `Dense` layer call.
+
+We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each
+representing the positive and negative samples respectively. We write a custom training
+loop here using `tf.GradientTape()`: within it, we calculate a per-sample loss based on
+the distance of each prediction from the threshold, and take the mean to get a
+`mean_loss` metric.
+
+With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable
+base `Dense` layer and apply them using the layer's local optimizer.
+
+Finally, we return the `call` result as the `Dense` results of the positive and negative
+samples while also returning the last `mean_loss` metric and all the loss values over a
+certain all-epoch run.
+"""
+
+
+class FFDense(keras.layers.Layer):
+ """
+ A custom ForwardForward-enabled Dense layer. It has an implementation of the
+ Forward-Forward network internally for use.
+ This layer must be used in conjunction with the `FFNetwork` model.
+ """
+
+ def __init__(
+ self,
+ units,
+ init_optimizer,
+ loss_metric,
+ num_epochs=50,
+ use_bias=True,
+ kernel_initializer="glorot_uniform",
+ bias_initializer="zeros",
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dense = keras.layers.Dense(
+ units=units,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ )
+ self.relu = keras.layers.ReLU()
+ self.optimizer = init_optimizer()
+ self.loss_metric = loss_metric
+ self.threshold = 1.5
+ self.num_epochs = num_epochs
+
+ # We perform a normalization step before we run the input through the Dense
+ # layer.
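+    # Normalizing each input vector to unit length hides the previous layer's
+    # "goodness" (which lives in the activity vector's magnitude), so each layer
+    # must judge its input by orientation alone.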
+
+ def call(self, x):
+ x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
+ x_norm = x_norm + 1e-4
+ x_dir = x / x_norm
+ res = self.dense(x_dir)
+ return self.relu(res)
+
+ # The Forward-Forward algorithm is below. We first perform the Dense-layer
+ # operation and then get a Mean Square value for all positive and negative
+ # samples respectively.
+ # The custom loss function finds the distance between the Mean-squared
+ # result and the threshold value we set (a hyperparameter) that will define
+ # whether the prediction is positive or negative in nature. Once the loss is
+    # calculated, we take the mean across the entire batch and perform a
+    # gradient calculation and optimization step. This does not technically
+    # qualify as backpropagation, since no gradient is sent to any previous
+    # layer: the update is completely local to this layer.
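+    # Concretely, the loss below is the softplus log(1 + exp(z)) applied to the
+    # signed distances between each sample's goodness (mean squared activation)
+    # and the threshold: positive samples are pushed above the threshold and
+    # negative samples below it.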
+
+ def forward_forward(self, x_pos, x_neg):
+ for i in range(self.num_epochs):
+ with tf.GradientTape() as tape:
+ g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
+ g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)
+
+ loss = ops.log(
+ 1
+ + ops.exp(
+ ops.concatenate(
+ [-g_pos + self.threshold, g_neg - self.threshold], 0
+ )
+ )
+ )
+ mean_loss = ops.cast(ops.mean(loss), dtype="float32")
+ self.loss_metric.update_state([mean_loss])
+ gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
+ self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
+ return (
+ ops.stop_gradient(self.call(x_pos)),
+ ops.stop_gradient(self.call(x_neg)),
+ self.loss_metric.result(),
+ )
+
+
+"""
+## Define the `FFNetwork` Custom Model
+
+With our custom layer defined, we also need to override the `train_step` method and
+define a custom `keras.models.Model` that works with our `FFDense` layer.
+
+For this algorithm, we must 'embed' the labels onto the original image. To do so, we
+exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We
+use that as a label space in order to visually one-hot-encode the labels within the image
+itself. This action is performed by the `overlay_y_on_x` function.
+
+We break the prediction down into a per-sample prediction function, which is then
+called over the entire test set by the overridden `predict()` function. The prediction is
+performed by measuring the "excitation" of the neurons per layer for each image, i.e.
+the mean squared activation. This is then summed over all layers to calculate a
+network-wide "goodness score". The label with the highest goodness score is then chosen
+as the sample prediction.
+
+The `train_step` function is overridden to act as the main controlling loop, running
+training on each layer for that layer's number of epochs.
+"""
+
+
+class FFNetwork(keras.Model):
+ """
+ A `keras.Model` that supports a `FFDense` network creation. This model
+ can work for any kind of classification task. It has an internal
+ implementation with some details specific to the MNIST dataset which can be
+ changed as per the use-case.
+ """
+
+ # Since each layer runs gradient-calculation and optimization locally, each
+ # layer has its own optimizer that we pass. As a standard choice, we pass
+ # the `Adam` optimizer with a default learning rate of 0.03 as that was
+ # found to be the best rate after experimentation.
+ # Loss is tracked using `loss_var` and `loss_count` variables.
+
+ def __init__(
+ self,
+ dims,
+ init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.init_layer_optimizer = init_layer_optimizer
+ self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
+ self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
+ self.layer_list = [keras.Input(shape=(dims[0],))]
+ self.metrics_built = False
+ for d in range(len(dims) - 1):
+ self.layer_list += [
+ FFDense(
+ dims[d + 1],
+ init_optimizer=self.init_layer_optimizer,
+ loss_metric=keras.metrics.Mean(),
+ )
+ ]
+
+ # This function makes a dynamic change to the image wherein the labels are
+ # put on top of the original image (for this example, as MNIST has 10
+ # unique labels, we take the top-left corner's first 10 pixels). This
+ # function returns the original data tensor with the first 10 pixels being
+ # a pixel-based one-hot representation of the labels.
+
+ @tf.function(reduce_retracing=True)
+ def overlay_y_on_x(self, data):
+ X_sample, y_sample = data
+ max_sample = ops.amax(X_sample, axis=0, keepdims=True)
+ max_sample = ops.cast(max_sample, dtype="float64")
+ X_zeros = ops.zeros([10], dtype="float64")
+ X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
+ X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
+ return X_sample, y_sample
+
+ # A custom `predict_one_sample` performs predictions by passing the images
+ # through the network, measuring the results produced by each layer (i.e.
+ # how high/low the output values are with respect to the set threshold for
+ # each label) and then simply picking the label with the highest values.
+ # In other words, each image is tested for its 'goodness' against every
+ # label.
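+ # The 'goodness' score for a candidate label y is the sum over layers of
+ # mean(h_layer ** 2), where h_layer is the layer's activation for the image
+ # with label y embedded onto it.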
+
+ @tf.function(reduce_retracing=True)
+ def predict_one_sample(self, x):
+ goodness_per_label = []
+ x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
+ for label in range(10):
+ h, label = self.overlay_y_on_x(data=(x, label))
+ h = ops.reshape(h, [-1, ops.shape(h)[0]])
+ goodness = []
+ for layer_idx in range(1, len(self.layer_list)):
+ layer = self.layer_list[layer_idx]
+ h = layer(h)
+ goodness += [ops.mean(ops.power(h, 2), 1)]
+ goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
+ goodness_per_label = tf.concat(goodness_per_label, 1)
+ return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")
+
+ def predict(self, data):
+ x = data
+ preds = ops.vectorized_map(self.predict_one_sample, x)
+ return np.asarray(preds, dtype=int)
+
+ # This custom `train_step` function overrides the internal `train_step`
+ # implementation. We take all the input image tensors, flatten them and
+ # subsequently produce positive and negative samples on the images.
+ # A positive sample is an image that has the right label encoded on it with
+ # the `overlay_y_on_x` function. A negative sample is an image that has an
+ # erroneous label present on it.
+ # With the samples ready, we pass them through each `FFLayer` and perform
+ # the Forward-Forward computation on it. The returned loss is the final
+ # loss value over all the layers.
+
+ @tf.function(jit_compile=False)
+ def train_step(self, data):
+ x, y = data
+
+ if not self.metrics_built:
+ # build metrics to ensure they can be queried without erroring out.
+ # We can't update the metrics' state, as we would usually do, since
+ # we do not perform predictions within the train step
+ for metric in self.metrics:
+ if hasattr(metric, "build"):
+ metric.build(y, y)
+ self.metrics_built = True
+
+ # Flatten op
+ x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])
+
+ x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))
+
+ random_y = tf.random.shuffle(y)
+ x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))
+
+ h_pos, h_neg = x_pos, x_neg
+
+ for idx, layer in enumerate(self.layers):
+ if isinstance(layer, FFDense):
+ print(f"Training layer {idx+1} now : ")
+ h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)
+ self.loss_var.assign_add(loss)
+ self.loss_count.assign_add(1.0)
+ else:
+ print(f"Passing layer {idx+1} now : ")
+ x = layer(x)
+ mean_res = ops.divide(self.loss_var, self.loss_count)
+ return {"FinalLoss": mean_res}
+
+
+"""
+## Convert MNIST `NumPy` arrays to `tf.data.Dataset`
+
+We now perform some preliminary processing on the `NumPy` arrays and then convert them
+into the `tf.data.Dataset` format which allows for optimized loading.
+"""
+
+x_train = x_train.astype(float) / 255
+x_test = x_test.astype(float) / 255
+y_train = y_train.astype(int)
+y_test = y_test.astype(int)
+
+train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+
+train_dataset = train_dataset.batch(60000)
+test_dataset = test_dataset.batch(10000)
+
+"""
+## Fit the network and visualize results
+
+Having performed all the previous set-up, we now run `model.fit()` for 250 model epochs.
+Since each `FFDense` layer trains itself for 50 epochs per call, this amounts to 50*250
+epochs on each layer. We then plot the loss curve as each layer is trained.
+"""
+
+model = FFNetwork(dims=[784, 500, 500])
+
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=0.03),
+ loss="mse",
+ jit_compile=False,
+ metrics=[],
+)
+
+epochs = 250
+history = model.fit(train_dataset, epochs=epochs)
+
+"""
+## Perform inference and testing
+
+Having trained the model to a large extent, we now see how it performs on the
+test set. We calculate the Accuracy Score to understand the results closely.
+"""
+
+preds = model.predict(ops.convert_to_tensor(x_test))
+
+preds = preds.reshape((preds.shape[0], preds.shape[1]))
+
+results = accuracy_score(preds, y_test)
+
+print(f"Test Accuracy score : {results*100}%")
+
+plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"])
+plt.title("Loss over training")
+plt.show()
+
+"""
+## Conclusion
+
+This example has demonstrated how the Forward-Forward algorithm works using
+the TensorFlow and Keras packages. While the results presented by Prof. Hinton
+in his paper are currently still limited to smaller models and datasets like MNIST and
+Fashion-MNIST, subsequent results on larger models like LLMs are expected in future
+papers.
+
+In the paper, Prof. Hinton reported a test error rate of 1.36% with a
+2000-unit, 4-hidden-layer, fully-connected network run over 60 epochs (while mentioning
+that backpropagation takes only 20 epochs to achieve similar performance). Another run,
+doubling the learning rate and training for 40 epochs, yields a slightly worse error rate
+of 1.46%.
+
+The current example does not yield state-of-the-art results. But with proper tuning of
+the learning rate and the model architecture (number of units in `Dense` layers, kernel
+activations, initializations, regularization, etc.), the results can be improved
+to match the claims of the paper.
+"""
diff --git a/knowledge_base/vision/fully_convolutional_network.py b/knowledge_base/vision/fully_convolutional_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa792f17185c40cea8e49e7183a165e68678a210
--- /dev/null
+++ b/knowledge_base/vision/fully_convolutional_network.py
@@ -0,0 +1,634 @@
+"""
+Title: Image Segmentation using Composable Fully-Convolutional Networks
+Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
+Date created: 2023/06/16
+Last modified: 2023/12/25
+Description: Using the Fully-Convolutional Network for Image Segmentation.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The following example walks through the steps to implement Fully-Convolutional Networks
+for Image Segmentation on the Oxford-IIIT Pets dataset.
+The model was proposed in the paper,
+[Fully Convolutional Networks for Semantic Segmentation by Long et al. (2014)](https://arxiv.org/abs/1411.4038).
+Image segmentation is one of the most common and introductory tasks when it comes to
+Computer Vision, where we extend the problem of Image Classification from
+one-label-per-image to a pixel-wise classification problem.
+In this example, we will assemble the aforementioned Fully-Convolutional Segmentation architecture,
+capable of performing Image Segmentation.
+The network extends the pooling layer outputs from the VGG in order to perform upsampling
+and get a final result. The intermediate outputs coming from the 3rd, 4th and 5th Max-Pooling layers of VGG19 are
+extracted out and upsampled at different levels and factors to get a final output with the same shape as that
+of the input, but carrying the class of each pixel at each location instead of pixel intensity values.
+Different intermediate pool layers are extracted and processed for different versions of the network.
+The FCN architecture has 3 versions of differing quality:
+
+- FCN-32S
+- FCN-16S
+- FCN-8S
+
+All versions of the model derive their outputs through an iterative processing of
+successive intermediate pool layers of the main backbone used.
+A better idea can be gained from the figure below.
+
+|  |
+| :--: |
+| **Diagram 1**: Combined Architecture Versions (Source: Paper) |
+
+To get a better idea of Image Segmentation, or to find more pre-trained models, feel free to
+navigate to the [Hugging Face Image Segmentation Models](https://huggingface.co/models?pipeline_tag=image-segmentation) page,
+or a [PyImageSearch Blog on Semantic Segmentation](https://pyimagesearch.com/2018/09/03/semantic-segmentation-with-opencv-and-deep-learning/)
+
+"""
+
+"""
+## Setup Imports
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+from keras import ops
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import tensorflow_datasets as tfds
+import numpy as np
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+"""
+## Set configurations for notebook variables
+
+We set the required parameters for the experiment.
+The chosen dataset has a total of 4 segmentation-mask classes per image.
+We also set our hyperparameters in this cell.
+
+Mixed Precision is also available as an option on systems that support it, to reduce
+load.
+It makes most tensors use `16-bit float` values instead of `32-bit float`
+values, in places where doing so will not adversely affect computation.
+This means that, during computation, TensorFlow uses `16-bit float` Tensors to increase speed at the cost of precision,
+while storing the values in their original default `32-bit float` form.
+"""
+
+NUM_CLASSES = 4
+INPUT_HEIGHT = 224
+INPUT_WIDTH = 224
+LEARNING_RATE = 1e-3
+WEIGHT_DECAY = 1e-4
+EPOCHS = 20
+BATCH_SIZE = 32
+MIXED_PRECISION = True
+SHUFFLE = True
+
+# Mixed-precision setting
+if MIXED_PRECISION:
+ policy = keras.mixed_precision.Policy("mixed_float16")
+ keras.mixed_precision.set_global_policy(policy)
+
+"""
+## Load dataset
+
+We make use of the [Oxford-IIIT Pets dataset](http://www.robots.ox.ac.uk/~vgg/data/pets/)
+which contains a total of 7,349 samples and their segmentation masks.
+We have 37 classes, with roughly 200 samples per class.
+Our training and validation datasets have 3,128 and 552 samples respectively.
+Aside from this, our test split has a total of 3,669 samples.
+
+We set a `batch_size` parameter that batches our samples together, and a `shuffle`
+parameter that mixes our samples up.
+"""
+
+(train_ds, valid_ds, test_ds) = tfds.load(
+ "oxford_iiit_pet",
+ split=["train[:85%]", "train[85%:]", "test"],
+ batch_size=BATCH_SIZE,
+ shuffle_files=SHUFFLE,
+)
+
+"""
+## Unpack and preprocess dataset
+
+We define a simple function that performs resizing over our
+training, validation and test datasets.
+We apply the same process to the masks as well, to make sure both are aligned in terms of shape and size.
+"""
+
+
+# Image and Mask Pre-processing
+def unpack_resize_data(section):
+ image = section["image"]
+ segmentation_mask = section["segmentation_mask"]
+
+ resize_layer = keras.layers.Resizing(INPUT_HEIGHT, INPUT_WIDTH)
+
+ image = resize_layer(image)
+ segmentation_mask = resize_layer(segmentation_mask)
+
+ return image, segmentation_mask
+
+
+train_ds = train_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
+valid_ds = valid_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
+test_ds = test_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
+"""
+## Visualize one random sample from the pre-processed dataset
+
+We visualize what a random sample in our test split of the dataset looks like, and plot
+the segmentation mask on top to see the effective mask areas.
+Note that we have performed pre-processing on this dataset too,
+which makes the image and mask the same size.
+"""
+
+# Select random image and mask. Cast to NumPy array
+# for Matplotlib visualization.
+
+images, masks = next(iter(test_ds))
+random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)
+
+test_image = images[int(random_idx)].numpy().astype("float")
+test_mask = masks[int(random_idx)].numpy().astype("float")
+
+# Overlay segmentation mask on top of image.
+fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
+
+ax[0].set_title("Image")
+ax[0].imshow(test_image / 255.0)
+
+ax[1].set_title("Image with segmentation mask overlay")
+ax[1].imshow(test_image / 255.0)
+ax[1].imshow(
+ test_mask,
+ cmap="inferno",
+ alpha=0.6,
+)
+plt.show()
+
+"""
+## Perform VGG-specific pre-processing
+
+`keras.applications.VGG19` requires the use of a `preprocess_input` function that
+performs the ImageNet-style normalization scheme expected by the pre-trained weights.
+"""
+
+
+def preprocess_data(image, segmentation_mask):
+ image = keras.applications.vgg19.preprocess_input(image)
+
+ return image, segmentation_mask
+
+
+train_ds = (
+ train_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
+ .shuffle(buffer_size=1024)
+ .prefetch(buffer_size=1024)
+)
+valid_ds = (
+ valid_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
+ .shuffle(buffer_size=1024)
+ .prefetch(buffer_size=1024)
+)
+test_ds = (
+ test_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE)
+ .shuffle(buffer_size=1024)
+ .prefetch(buffer_size=1024)
+)
+"""
+## Model Definition
+
+The Fully-Convolutional Network boasts a simple architecture composed of only
+`keras.layers.Conv2D` Layers, `keras.layers.Dense` layers and `keras.layers.Dropout`
+layers.
+
+|  |
+| :--: |
+| **Diagram 2**: Generic FCN Forward Pass (Source: Paper)|
+
+Pixel-wise prediction is performed by a softmax Convolutional layer with the same
+size as the image, such that we can perform a direct comparison with the ground-truth mask.
+We can track several important metrics such as Accuracy and Mean-Intersection-over-Union on the network.
+"""
+
+"""
+### Backbone (VGG-19)
+
+We use the [VGG-19 network](https://keras.io/api/applications/vgg/) as the backbone, as
+the paper suggests it to be one of the most effective backbones for this network.
+We extract different outputs from the network by making use of `keras.models.Model`.
+Following this, we add layers on top to make a network perfectly simulating that of
+Diagram 1.
+The backbone's `keras.layers.Dense` layers will be converted to `keras.layers.Conv2D`
+layers based on the [original Caffe code present here.](https://github.com/linxi159/FCN-caffe/blob/master/pascalcontext-fcn16s/net.py)
+All 3 networks will share the same backbone weights, but will have differing results
+based on their extensions.
+We make the backbone non-trainable to reduce training time.
+It is also noted in the paper that making the backbone trainable does not yield major benefits.
+"""
+
+input_layer = keras.Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, 3))
+
+# VGG Model backbone with pre-trained ImageNet weights.
+vgg_model = keras.applications.vgg19.VGG19(include_top=True, weights="imagenet")
+
+# Extracting different outputs from same model
+fcn_backbone = keras.models.Model(
+ inputs=vgg_model.layers[1].input,
+ outputs=[
+ vgg_model.get_layer(block_name).output
+ for block_name in ["block3_pool", "block4_pool", "block5_pool"]
+ ],
+)
+
+# Setting backbone to be non-trainable
+fcn_backbone.trainable = False
+
+x = fcn_backbone(input_layer)
+
+# Converting Dense layers to Conv2D layers
+units = [4096, 4096]
+dense_convs = []
+
+for filter_idx in range(len(units)):
+ dense_conv = keras.layers.Conv2D(
+ filters=units[filter_idx],
+ kernel_size=(7, 7) if filter_idx == 0 else (1, 1),
+ strides=(1, 1),
+ activation="relu",
+ padding="same",
+ use_bias=False,
+ kernel_initializer=keras.initializers.Constant(1.0),
+ )
+ dense_convs.append(dense_conv)
+ dropout_layer = keras.layers.Dropout(0.5)
+ dense_convs.append(dropout_layer)
+
+dense_convs = keras.Sequential(dense_convs)
+dense_convs.trainable = False
+
+x[-1] = dense_convs(x[-1])
+
+pool3_output, pool4_output, pool5_output = x
+
+"""
+### FCN-32S
+
+We extend the last output, perform a `1x1 Convolution` and perform 2D Bilinear Upsampling
+by a factor of 32 to get an image of the same size as that of our input.
+We use a simple `keras.layers.UpSampling2D` layer over a `keras.layers.Conv2DTranspose`
+since it yields performance benefits from being a deterministic mathematical operation
+rather than a Convolutional operation.
+It is also noted in the paper that making the Up-sampling parameters trainable does not yield benefits.
+The original experiments of the paper used Upsampling as well.
+"""
+
+# 1x1 convolution to set channels = number of classes
+pool5 = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ padding="same",
+ strides=(1, 1),
+ activation="relu",
+)
+
+# Get Softmax outputs for all classes
+fcn32s_conv_layer = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ activation="softmax",
+ padding="same",
+ strides=(1, 1),
+)
+
+# Up-sample to original image size
+fcn32s_upsampling = keras.layers.UpSampling2D(
+ size=(32, 32),
+ data_format=keras.backend.image_data_format(),
+ interpolation="bilinear",
+)
+
+final_fcn32s_pool = pool5(pool5_output)
+final_fcn32s_output = fcn32s_conv_layer(final_fcn32s_pool)
+final_fcn32s_output = fcn32s_upsampling(final_fcn32s_output)
+
+fcn32s_model = keras.Model(inputs=input_layer, outputs=final_fcn32s_output)
+
+"""
+### FCN-16S
+
+The pooling output from the FCN-32S is extended and added to the 4th-level Pooling output
+of our backbone.
+Following this, we upsample by a factor of 16 to get an image of the same
+size as that of our input.
+"""
+
+# 1x1 convolution to set channels = number of classes
+# Followed from the original Caffe implementation
+pool4 = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ padding="same",
+ strides=(1, 1),
+ activation="linear",
+ kernel_initializer=keras.initializers.Zeros(),
+)(pool4_output)
+
+# Intermediate up-sample
+pool5 = keras.layers.UpSampling2D(
+ size=(2, 2),
+ data_format=keras.backend.image_data_format(),
+ interpolation="bilinear",
+)(final_fcn32s_pool)
+
+# Get Softmax outputs for all classes
+fcn16s_conv_layer = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ activation="softmax",
+ padding="same",
+ strides=(1, 1),
+)
+
+# Up-sample to original image size
+fcn16s_upsample_layer = keras.layers.UpSampling2D(
+ size=(16, 16),
+ data_format=keras.backend.image_data_format(),
+ interpolation="bilinear",
+)
+
+# Add intermediate outputs
+final_fcn16s_pool = keras.layers.Add()([pool4, pool5])
+final_fcn16s_output = fcn16s_conv_layer(final_fcn16s_pool)
+final_fcn16s_output = fcn16s_upsample_layer(final_fcn16s_output)
+
+fcn16s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn16s_output)
+
+"""
+### FCN-8S
+
+The pooling output from the FCN-16S is extended once more, and added to the 3rd-level
+Pooling output of our backbone.
+This result is upsampled by a factor of 8 to get an image of the same size as that of our input.
+"""
+
+# 1x1 convolution to set channels = number of classes
+# Followed from the original Caffe implementation
+pool3 = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ padding="same",
+ strides=(1, 1),
+ activation="linear",
+ kernel_initializer=keras.initializers.Zeros(),
+)(pool3_output)
+
+# Intermediate up-sample
+intermediate_pool_output = keras.layers.UpSampling2D(
+ size=(2, 2),
+ data_format=keras.backend.image_data_format(),
+ interpolation="bilinear",
+)(final_fcn16s_pool)
+
+# Get Softmax outputs for all classes
+fcn8s_conv_layer = keras.layers.Conv2D(
+ filters=NUM_CLASSES,
+ kernel_size=(1, 1),
+ activation="softmax",
+ padding="same",
+ strides=(1, 1),
+)
+
+# Up-sample to original image size
+fcn8s_upsample_layer = keras.layers.UpSampling2D(
+ size=(8, 8),
+ data_format=keras.backend.image_data_format(),
+ interpolation="bilinear",
+)
+
+# Add intermediate outputs
+final_fcn8s_pool = keras.layers.Add()([pool3, intermediate_pool_output])
+final_fcn8s_output = fcn8s_conv_layer(final_fcn8s_pool)
+final_fcn8s_output = fcn8s_upsample_layer(final_fcn8s_output)
+
+fcn8s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn8s_output)
+
+"""
+### Load weights into backbone
+
+It was noted in the paper, as well as through experimentation, that extracting the weights
+of the last 2 Fully-connected Dense layers of the backbone, reshaping them to
+fit the `keras.layers.Conv2D` layers into which we had previously converted the
+`keras.layers.Dense` layers, and setting them on those layers yields far better results
+and a significant increase in mIOU performance.
+"""
+
+# VGG's last 2 layers
+weights1 = vgg_model.get_layer("fc1").get_weights()[0]
+weights2 = vgg_model.get_layer("fc2").get_weights()[0]
+
+weights1 = weights1.reshape(7, 7, 512, 4096)
+weights2 = weights2.reshape(1, 1, 4096, 4096)
+
+dense_convs.layers[0].set_weights([weights1])
+dense_convs.layers[2].set_weights([weights2])
+
+"""
+## Training
+
+The original paper talks about making use of [SGD with Momentum](https://keras.io/api/optimizers/sgd/) as the optimizer of choice.
+But it was noticed during experimentation that
+[AdamW](https://keras.io/api/optimizers/adamw/)
+yielded better results in terms of mIOU and Pixel-wise Accuracy.
+"""
+
+"""
+### FCN-32S
+"""
+
+fcn32s_optimizer = keras.optimizers.AdamW(
+ learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+)
+
+fcn32s_loss = keras.losses.SparseCategoricalCrossentropy()
+
+# Maintain mIOU and Pixel-wise Accuracy as metrics
+fcn32s_model.compile(
+ optimizer=fcn32s_optimizer,
+ loss=fcn32s_loss,
+ metrics=[
+ keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
+ keras.metrics.SparseCategoricalAccuracy(),
+ ],
+)
+
+fcn32s_history = fcn32s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
+
+"""
+### FCN-16S
+"""
+
+fcn16s_optimizer = keras.optimizers.AdamW(
+ learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+)
+
+fcn16s_loss = keras.losses.SparseCategoricalCrossentropy()
+
+# Maintain mIOU and Pixel-wise Accuracy as metrics
+fcn16s_model.compile(
+ optimizer=fcn16s_optimizer,
+ loss=fcn16s_loss,
+ metrics=[
+ keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
+ keras.metrics.SparseCategoricalAccuracy(),
+ ],
+)
+
+fcn16s_history = fcn16s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
+
+"""
+### FCN-8S
+"""
+
+fcn8s_optimizer = keras.optimizers.AdamW(
+ learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+)
+
+fcn8s_loss = keras.losses.SparseCategoricalCrossentropy()
+
+# Maintain mIOU and Pixel-wise Accuracy as metrics
+fcn8s_model.compile(
+ optimizer=fcn8s_optimizer,
+ loss=fcn8s_loss,
+ metrics=[
+ keras.metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),
+ keras.metrics.SparseCategoricalAccuracy(),
+ ],
+)
+
+fcn8s_history = fcn8s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
+"""
+## Visualizations
+"""
+
+"""
+### Plotting metrics for training run
+
+We perform a comparative study between all 3 versions of the model by tracking training
+and validation metrics of Accuracy, Loss and Mean IoU.
+"""
+
+total_plots = len(fcn32s_history.history)
+cols = total_plots // 2
+
+rows = total_plots // cols
+
+if total_plots % cols != 0:
+ rows += 1
+
+# Set all history dictionary objects
+fcn32s_dict = fcn32s_history.history
+fcn16s_dict = fcn16s_history.history
+fcn8s_dict = fcn8s_history.history
+
+pos = range(1, total_plots + 1)
+plt.figure(figsize=(15, 10))
+
+for i, ((key_32s, value_32s), (key_16s, value_16s), (key_8s, value_8s)) in enumerate(
+ zip(fcn32s_dict.items(), fcn16s_dict.items(), fcn8s_dict.items())
+):
+ plt.subplot(rows, cols, pos[i])
+ plt.plot(range(len(value_32s)), value_32s)
+ plt.plot(range(len(value_16s)), value_16s)
+ plt.plot(range(len(value_8s)), value_8s)
+ plt.title(str(key_32s) + " (combined)")
+ plt.legend(["FCN-32S", "FCN-16S", "FCN-8S"])
+
+plt.show()
+
+"""
+### Visualizing predicted segmentation masks
+
+To understand the results better, we pick a random image from the test
+dataset and perform inference on it to see the masks generated by each model.
+Note: For better results, the model must be trained for a higher number of epochs.
+"""
+
+images, masks = next(iter(test_ds))
+random_idx = keras.random.uniform([], minval=0, maxval=BATCH_SIZE, seed=10)
+
+# Get random test image and mask
+test_image = images[int(random_idx)].numpy().astype("float")
+test_mask = masks[int(random_idx)].numpy().astype("float")
+
+pred_image = ops.expand_dims(test_image, axis=0)
+pred_image = keras.applications.vgg19.preprocess_input(pred_image)
+
+# Perform inference on FCN-32S
+pred_mask_32s = fcn32s_model.predict(pred_image, verbose=0).astype("float")
+pred_mask_32s = np.argmax(pred_mask_32s, axis=-1)
+pred_mask_32s = pred_mask_32s[0, ...]
+
+# Perform inference on FCN-16S
+pred_mask_16s = fcn16s_model.predict(pred_image, verbose=0).astype("float")
+pred_mask_16s = np.argmax(pred_mask_16s, axis=-1)
+pred_mask_16s = pred_mask_16s[0, ...]
+
+# Perform inference on FCN-8S
+pred_mask_8s = fcn8s_model.predict(pred_image, verbose=0).astype("float")
+pred_mask_8s = np.argmax(pred_mask_8s, axis=-1)
+pred_mask_8s = pred_mask_8s[0, ...]
+
+# Plot all results
+fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))
+
+fig.delaxes(ax[0, 2])
+
+ax[0, 0].set_title("Image")
+ax[0, 0].imshow(test_image / 255.0)
+
+ax[0, 1].set_title("Image with ground truth overlay")
+ax[0, 1].imshow(test_image / 255.0)
+ax[0, 1].imshow(
+ test_mask,
+ cmap="inferno",
+ alpha=0.6,
+)
+
+ax[1, 0].set_title("Image with FCN-32S mask overlay")
+ax[1, 0].imshow(test_image / 255.0)
+ax[1, 0].imshow(pred_mask_32s, cmap="inferno", alpha=0.6)
+
+ax[1, 1].set_title("Image with FCN-16S mask overlay")
+ax[1, 1].imshow(test_image / 255.0)
+ax[1, 1].imshow(pred_mask_16s, cmap="inferno", alpha=0.6)
+
+ax[1, 2].set_title("Image with FCN-8S mask overlay")
+ax[1, 2].imshow(test_image / 255.0)
+ax[1, 2].imshow(pred_mask_8s, cmap="inferno", alpha=0.6)
+
+plt.show()
+
+"""
+## Conclusion
+
+The Fully-Convolutional Network is an exceptionally simple network that has yielded
+strong results in Image Segmentation tasks across different benchmarks.
+With the advent of better mechanisms like [Attention](https://arxiv.org/abs/1706.03762) as used in
+[SegFormer](https://arxiv.org/abs/2105.15203) and
+[DeTR](https://arxiv.org/abs/2005.12872), this model serves as a quick way to iterate and
+find baselines for this task on unknown data.
+"""
+
+"""
+## Acknowledgements
+
+I thank [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush
+Thakur](https://twitter.com/ayushthakur0) and [Ritwik
+Raha](https://twitter.com/ritwik_raha) for giving a preliminary review of the example.
+I also thank the [Google Developer
+Experts](https://developers.google.com/community/experts) program.
+
+"""
diff --git a/knowledge_base/vision/grad_cam.py b/knowledge_base/vision/grad_cam.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b6531c216ca330a1dbc35ab511e7077840a0b7
--- /dev/null
+++ b/knowledge_base/vision/grad_cam.py
@@ -0,0 +1,203 @@
+"""
+Title: Grad-CAM class activation visualization
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/04/26
+Last modified: 2021/03/07
+Description: How to obtain a class activation heatmap for an image classification model.
+Accelerator: GPU
+"""
+
+"""
+Adapted from Deep Learning with Python (2017).
+
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import tensorflow as tf
+import keras
+
+# Display
+from IPython.display import Image, display
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+
+
+"""
+## Configurable parameters
+
+You can change these to another model.
+
+To get the values for `last_conv_layer_name` use `model.summary()`
+to see the names of all layers in the model.
+"""
+
+model_builder = keras.applications.xception.Xception
+img_size = (299, 299)
+preprocess_input = keras.applications.xception.preprocess_input
+decode_predictions = keras.applications.xception.decode_predictions
+
+last_conv_layer_name = "block14_sepconv2_act"
+
+# The local path to our target image
+img_path = keras.utils.get_file(
+ "african_elephant.jpg", "https://i.imgur.com/Bvro0YD.png"
+)
+
+display(Image(img_path))
+
+
+"""
+## The Grad-CAM algorithm
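+
+As a brief refresher (restating the original Grad-CAM formulation in our own notation):
+for a class score `y_c` and the last conv layer's feature maps `A_k`, Grad-CAM computes
+channel weights `alpha_k` as the global average of `d y_c / d A_k`, and the heatmap as
+`ReLU(sum_k alpha_k * A_k)`. The code below implements these steps, plus a final
+normalization for display.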
+"""
+
+
+def get_img_array(img_path, size):
+ # `img` is a PIL image of size 299x299
+ img = keras.utils.load_img(img_path, target_size=size)
+ # `array` is a float32 Numpy array of shape (299, 299, 3)
+ array = keras.utils.img_to_array(img)
+ # We add a dimension to transform our array into a "batch"
+ # of size (1, 299, 299, 3)
+ array = np.expand_dims(array, axis=0)
+ return array
+
+
+def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
+ # First, we create a model that maps the input image to the activations
+ # of the last conv layer as well as the output predictions
+ grad_model = keras.models.Model(
+ model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
+ )
+
+ # Then, we compute the gradient of the top predicted class for our input image
+ # with respect to the activations of the last conv layer
+ with tf.GradientTape() as tape:
+ last_conv_layer_output, preds = grad_model(img_array)
+ if pred_index is None:
+ pred_index = tf.argmax(preds[0])
+ class_channel = preds[:, pred_index]
+
+ # This is the gradient of the output neuron (top predicted or chosen)
+ # with regard to the output feature map of the last conv layer
+ grads = tape.gradient(class_channel, last_conv_layer_output)
+
+ # This is a vector where each entry is the mean intensity of the gradient
+ # over a specific feature map channel
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+
+ # We multiply each channel in the feature map array
+ # by "how important this channel is" with regard to the top predicted class
+ # then sum all the channels to obtain the heatmap class activation
+ last_conv_layer_output = last_conv_layer_output[0]
+ heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
+ heatmap = tf.squeeze(heatmap)
+
+ # For visualization purpose, we will also normalize the heatmap between 0 & 1
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+ return heatmap.numpy()
+
+
+"""
+## Let's test-drive it
+"""
+
+# Prepare image
+img_array = preprocess_input(get_img_array(img_path, size=img_size))
+
+# Make model
+model = model_builder(weights="imagenet")
+
+# Remove last layer's softmax
+model.layers[-1].activation = None
+
+# Print what the top predicted class is
+preds = model.predict(img_array)
+print("Predicted:", decode_predictions(preds, top=1)[0])
+
+# Generate class activation heatmap
+heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)
+
+# Display heatmap
+plt.matshow(heatmap)
+plt.show()
+
+
+"""
+## Create a superimposed visualization
+"""
+
+
+def save_and_display_gradcam(img_path, heatmap, cam_path="cam.jpg", alpha=0.4):
+ # Load the original image
+ img = keras.utils.load_img(img_path)
+ img = keras.utils.img_to_array(img)
+
+ # Rescale heatmap to a range 0-255
+ heatmap = np.uint8(255 * heatmap)
+
+ # Use jet colormap to colorize heatmap
+ jet = mpl.colormaps["jet"]
+
+ # Use RGB values of the colormap
+ jet_colors = jet(np.arange(256))[:, :3]
+ jet_heatmap = jet_colors[heatmap]
+
+ # Create an image with RGB colorized heatmap
+ jet_heatmap = keras.utils.array_to_img(jet_heatmap)
+ jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
+ jet_heatmap = keras.utils.img_to_array(jet_heatmap)
+
+ # Superimpose the heatmap on original image
+ superimposed_img = jet_heatmap * alpha + img
+ superimposed_img = keras.utils.array_to_img(superimposed_img)
+
+ # Save the superimposed image
+ superimposed_img.save(cam_path)
+
+ # Display Grad CAM
+ display(Image(cam_path))
+
+
+save_and_display_gradcam(img_path, heatmap)
+
+"""
+## Let's try another image
+
+We will see how Grad-CAM explains the model's outputs for a multi-label image. Let's
+try an image with a cat and a dog together, and see how Grad-CAM behaves.
+"""
+
+img_path = keras.utils.get_file(
+ "cat_and_dog.jpg",
+ "https://storage.googleapis.com/petbacker/images/blog/2017/dog-and-cat-cover.jpg",
+)
+
+display(Image(img_path))
+
+# Prepare image
+img_array = preprocess_input(get_img_array(img_path, size=img_size))
+
+# Print what the two top predicted classes are
+preds = model.predict(img_array)
+print("Predicted:", decode_predictions(preds, top=2)[0])
+
+"""
+We generate a class activation heatmap for "chow", whose class index is 260.
+"""
+
+heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=260)
+
+save_and_display_gradcam(img_path, heatmap)
+
+"""
+We generate a class activation heatmap for "egyptian cat", whose class index is 285.
+"""
+
+heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=285)
+
+save_and_display_gradcam(img_path, heatmap)
diff --git a/knowledge_base/vision/gradient_centralization.py b/knowledge_base/vision/gradient_centralization.py
new file mode 100644
index 0000000000000000000000000000000000000000..aefc5e3e9ff1e00fe353c29b467926bacc4cd7aa
--- /dev/null
+++ b/knowledge_base/vision/gradient_centralization.py
@@ -0,0 +1,273 @@
+"""
+Title: Gradient Centralization for Better Training Performance
+Author: [Rishit Dagli](https://github.com/Rishit-dagli)
+Date created: 06/18/21
+Last modified: 07/25/23
+Description: Implement Gradient Centralization to improve training performance of DNNs.
+Accelerator: GPU
+Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
+"""
+
+"""
+## Introduction
+
+This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a
+new optimization technique for Deep Neural Networks by Yong et al., and demonstrates it
+on Laurence Moroney's [Horses or Humans
+Dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans). Gradient
+Centralization can both speedup training process and improve the final generalization
+performance of DNNs. It operates directly on gradients by centralizing the gradient
+vectors to have zero mean. Gradient Centralization morever improves the Lipschitzness of
+the loss function and its gradient so that the training process becomes more efficient
+and stable.
+
+This example requires `tensorflow_datasets` which can be installed with this command:
+
+```
+pip install tensorflow-datasets
+```
+"""
+
+"""
+## Setup
+"""
+
+from time import time
+
+import keras
+from keras import layers
+from keras.optimizers import RMSprop
+from keras import ops
+
+from tensorflow import data as tf_data
+import tensorflow_datasets as tfds
+
+
+"""
+## Prepare the data
+
+For this example, we will be using the [Horses or Humans
+dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans).
+"""
+
+num_classes = 2
+input_shape = (300, 300, 3)
+dataset_name = "horses_or_humans"
+batch_size = 128
+AUTOTUNE = tf_data.AUTOTUNE
+
+(train_ds, test_ds), metadata = tfds.load(
+ name=dataset_name,
+ split=[tfds.Split.TRAIN, tfds.Split.TEST],
+ with_info=True,
+ as_supervised=True,
+)
+
+print(f"Image shape: {metadata.features['image'].shape}")
+print(f"Training images: {metadata.splits['train'].num_examples}")
+print(f"Test images: {metadata.splits['test'].num_examples}")
+
+"""
+## Use Data Augmentation
+
+We will rescale the data to `[0, 1]` and perform simple augmentations to our data.
+"""
+
+rescale = layers.Rescaling(1.0 / 255)
+
+data_augmentation = [
+ layers.RandomFlip("horizontal_and_vertical"),
+ layers.RandomRotation(0.3),
+ layers.RandomZoom(0.2),
+]
+
+
+# Helper to apply augmentation
+def apply_aug(x):
+ for aug in data_augmentation:
+ x = aug(x)
+ return x
+
+
+def prepare(ds, shuffle=False, augment=False):
+ # Rescale dataset
+ ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
+
+ if shuffle:
+ ds = ds.shuffle(1024)
+
+ # Batch dataset
+ ds = ds.batch(batch_size)
+
+ # Use data augmentation only on the training set
+ if augment:
+ ds = ds.map(
+ lambda x, y: (apply_aug(x), y),
+ num_parallel_calls=AUTOTUNE,
+ )
+
+ # Use buffered prefecting
+ return ds.prefetch(buffer_size=AUTOTUNE)
+
+
+"""
+Rescale and augment the data
+"""
+
+train_ds = prepare(train_ds, shuffle=True, augment=True)
+test_ds = prepare(test_ds)
+"""
+## Define a model
+
+In this section we will define a Convolutional neural network.
+"""
+
+model = keras.Sequential(
+ [
+ layers.Input(shape=input_shape),
+ layers.Conv2D(16, (3, 3), activation="relu"),
+ layers.MaxPooling2D(2, 2),
+ layers.Conv2D(32, (3, 3), activation="relu"),
+ layers.Dropout(0.5),
+ layers.MaxPooling2D(2, 2),
+ layers.Conv2D(64, (3, 3), activation="relu"),
+ layers.Dropout(0.5),
+ layers.MaxPooling2D(2, 2),
+ layers.Conv2D(64, (3, 3), activation="relu"),
+ layers.MaxPooling2D(2, 2),
+ layers.Conv2D(64, (3, 3), activation="relu"),
+ layers.MaxPooling2D(2, 2),
+ layers.Flatten(),
+ layers.Dropout(0.5),
+ layers.Dense(512, activation="relu"),
+ layers.Dense(1, activation="sigmoid"),
+ ]
+)
+
+"""
+## Implement Gradient Centralization
+
+We will now subclass the `RMSProp` optimizer class, modifying the
+`keras.optimizers.Optimizer.get_gradients()` method to implement Gradient
+Centralization. At a high level, the idea is this: say we obtain the gradients
+of a Dense or Convolution layer through backpropagation; we then compute the mean of the
+column vectors of the gradient matrix, and remove that mean from each column vector.
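+
+In symbols (restating the paper in our notation): for a column `w_i` of a weight matrix
+with gradient `grad(w_i)`, the centralized gradient is
+`Phi(grad(w_i)) = grad(w_i) - mean(grad(w_i))`, where the mean is taken over the elements
+of that column.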
+
+The experiments in [this paper](https://arxiv.org/abs/2004.01461) on various
+applications, including general image classification, fine-grained image classification,
+detection and segmentation and Person ReID demonstrate that GC can consistently improve
+the performance of DNN learning.
+
+Also, for simplicity, at the moment we are not implementing gradient clipping functionality,
+however this is quite easy to add.
+
+At the moment we are just creating a subclass for the `RMSProp` optimizer
+however you could easily reproduce this for any other optimizer or on a custom
+optimizer in the same way. We will be using this class in the later section when
+we train a model with Gradient Centralization.
+"""
+
+
+class GCRMSprop(RMSprop):
+ def get_gradients(self, loss, params):
+ # We just provide a modified get_gradients() function here, since we
+ # only need to compute the centralized gradients.
+
+ grads = []
+ gradients = super().get_gradients(loss, params)
+ for grad in gradients:
+ grad_len = len(grad.shape)
+ if grad_len > 1:
+ # Centralize over all axes except the last (the output axis).
+ axis = list(range(grad_len - 1))
+ grad -= ops.mean(grad, axis=axis, keepdims=True)
+ grads.append(grad)
+
+ return grads
+
+
+optimizer = GCRMSprop(learning_rate=1e-4)
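+
+# A quick, illustrative sanity check (not part of training): centralizing a toy
+# "gradient" tensor should make it zero-mean along every axis but the last.
+toy_grad = ops.reshape(ops.arange(12, dtype="float32"), (3, 4))
+toy_centralized = toy_grad - ops.mean(toy_grad, axis=[0], keepdims=True)
+print(ops.mean(toy_centralized, axis=0)) # ~[0. 0. 0. 0.]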
+
+"""
+## Training utilities
+
+We will also create a callback which allows us to easily measure the total training time
+and the time taken for each epoch since we are interested in comparing the effect of
+Gradient Centralization on the model we built above.
+"""
+
+
+class TimeHistory(keras.callbacks.Callback):
+ def on_train_begin(self, logs={}):
+ self.times = []
+
+ def on_epoch_begin(self, batch, logs={}):
+ self.epoch_time_start = time()
+
+ def on_epoch_end(self, batch, logs={}):
+ self.times.append(time() - self.epoch_time_start)
+
+
+"""
+## Train the model without GC
+
+We now train the model we built earlier without Gradient Centralization which we can
+compare to the training performance of the model trained with Gradient Centralization.
+"""
+
+time_callback_no_gc = TimeHistory()
+model.compile(
+ loss="binary_crossentropy",
+ optimizer=RMSprop(learning_rate=1e-4),
+ metrics=["accuracy"],
+)
+
+model.summary()
+
+"""
+We also save the history since we later want to compare our model trained with and not
+trained with Gradient Centralization
+"""
+
+history_no_gc = model.fit(
+ train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
+)
+
+"""
+## Train the model with GC
+
+We will now train the same model, this time using Gradient Centralization,
+notice our optimizer is the one using Gradient Centralization this time.
+"""
+
+time_callback_gc = TimeHistory()
+model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
+
+model.summary()
+
+history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])
+
+"""
+## Comparing performance
+"""
+
+print("Not using Gradient Centralization")
+print(f"Loss: {history_no_gc.history['loss'][-1]}")
+print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
+print(f"Training Time: {sum(time_callback_no_gc.times)}")
+
+print("Using Gradient Centralization")
+print(f"Loss: {history_gc.history['loss'][-1]}")
+print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
+print(f"Training Time: {sum(time_callback_gc.times)}")
+
+"""
+Readers are encouraged to try out Gradient Centralization on different datasets from
+different domains and experiment with its effect. You are strongly advised to check out
+the [original paper](https://arxiv.org/abs/2004.01461) as well - the authors present
+several studies on Gradient Centralization showing how it can improve general
+performance, generalization, and training time, as well as training efficiency.
+
+Many thanks to [Ali Mustufa Shaikh](https://github.com/ialimustufa) for reviewing this
+implementation.
+"""
diff --git a/knowledge_base/vision/handwriting_recognition.py b/knowledge_base/vision/handwriting_recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0fc453170df7da2e379db8f7cfbd664fa9f410d
--- /dev/null
+++ b/knowledge_base/vision/handwriting_recognition.py
@@ -0,0 +1,580 @@
+"""
+Title: Handwriting recognition
+Authors: [A_K_Nain](https://twitter.com/A_K_Nain), [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/08/16
+Last modified: 2025/09/29
+Description: Training a handwriting recognition model with variable-length sequences.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example shows how the [Captcha OCR](https://keras.io/examples/vision/captcha_ocr/)
+example can be extended to the
+[IAM Dataset](https://fki.tic.heia-fr.ch/databases/iam-handwriting-database),
+which has variable length ground-truth targets. Each sample in the dataset is an image of some
+handwritten text, and its corresponding target is the string present in the image.
+The IAM Dataset is widely used across many OCR benchmarks, so we hope this example can serve as a
+good starting point for building OCR systems.
+"""
+
+"""
+## Data collection
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/Handwriting-Recognizer-in-Keras/releases/download/v1.0.0/IAM_Words.zip
+unzip -qq IAM_Words.zip
+
+mkdir data
+mkdir data/words
+tar -xf IAM_Words/words.tgz -C data/words
+mv IAM_Words/words.txt data
+"""
+
+"""
+Preview how the dataset is organized. Lines starting with "#" are just metadata.
+"""
+
+"""shell
+head -20 data/words.txt
+"""
+
+"""
+## Imports
+"""
+
+import keras
+from keras.layers import StringLookup
+from keras import ops
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import numpy as np
+import os
+
+np.random.seed(42)
+keras.utils.set_random_seed(42)
+
+"""
+## Dataset splitting
+"""
+
+base_path = "data"
+words_list = []
+
+words = open(f"{base_path}/words.txt", "r").readlines()
+for line in words:
+ if line[0] == "#":
+ continue
+ if line.split(" ")[1] != "err": # We don't need to deal with errored entries.
+ words_list.append(line)
+
+len(words_list)
+
+np.random.shuffle(words_list)
+
+"""
+We will split the dataset into three subsets with a 90:5:5 ratio (train:validation:test).
+"""
+
+split_idx = int(0.9 * len(words_list))
+train_samples = words_list[:split_idx]
+test_samples = words_list[split_idx:]
+
+val_split_idx = int(0.5 * len(test_samples))
+validation_samples = test_samples[:val_split_idx]
+test_samples = test_samples[val_split_idx:]
+
+assert len(words_list) == len(train_samples) + len(validation_samples) + len(
+ test_samples
+)
+
+print(f"Total training samples: {len(train_samples)}")
+print(f"Total validation samples: {len(validation_samples)}")
+print(f"Total test samples: {len(test_samples)}")
+
+"""
+## Data input pipeline
+
+We start building our data input pipeline by first preparing the image paths.
+"""
+
+base_image_path = os.path.join(base_path, "words")
+
+
+def get_image_paths_and_labels(samples):
+ paths = []
+ corrected_samples = []
+ for i, file_line in enumerate(samples):
+ line_split = file_line.strip()
+ line_split = line_split.split(" ")
+
+ # Each line split will have this format for the corresponding image:
+ # part1/part1-part2/part1-part2-part3.png
+ image_name = line_split[0]
+ partI = image_name.split("-")[0]
+ partII = image_name.split("-")[1]
+ img_path = os.path.join(
+ base_image_path, partI, partI + "-" + partII, image_name + ".png"
+ )
+ if os.path.getsize(img_path):
+ paths.append(img_path)
+ corrected_samples.append(file_line.split("\n")[0])
+
+ return paths, corrected_samples
+
+
+train_img_paths, train_labels = get_image_paths_and_labels(train_samples)
+validation_img_paths, validation_labels = get_image_paths_and_labels(validation_samples)
+test_img_paths, test_labels = get_image_paths_and_labels(test_samples)
+
+"""
+Then we prepare the ground-truth labels.
+"""
+
+# Find maximum length and the size of the vocabulary in the training data.
+train_labels_cleaned = []
+characters = set()
+max_len = 0
+
+for label in train_labels:
+ label = label.split(" ")[-1].strip()
+ for char in label:
+ characters.add(char)
+
+ max_len = max(max_len, len(label))
+ train_labels_cleaned.append(label)
+
+characters = sorted(list(characters))
+
+print("Maximum length: ", max_len)
+print("Vocab size: ", len(characters))
+
+# Check some label samples.
+train_labels_cleaned[:10]
+
+"""
+Now we clean the validation and the test labels as well.
+"""
+
+
+def clean_labels(labels):
+ cleaned_labels = []
+ for label in labels:
+ label = label.split(" ")[-1].strip()
+ cleaned_labels.append(label)
+ return cleaned_labels
+
+
+validation_labels_cleaned = clean_labels(validation_labels)
+test_labels_cleaned = clean_labels(test_labels)
+
+"""
+### Building the character vocabulary
+
+Keras provides different preprocessing layers to deal with different modalities of data.
+[This guide](https://keras.io/api/layers/preprocessing_layers/) provides a comprehensive introduction.
+Our example involves preprocessing labels at the character
+level. This means that if there are two labels, e.g. "cat" and "dog", then our character
+vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the
+[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup/)
+layer for this purpose.
+"""
+
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+# Mapping characters to integers.
+char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)
+
+# Mapping integers back to original characters.
+num_to_char = StringLookup(
+ vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
+)
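+
+# A quick, illustrative round trip (assuming 'c', 'a' and 't' appear in the
+# training vocabulary): encode a string to integer ids and decode it back.
+demo_ids = char_to_num(tf.strings.unicode_split("cat", input_encoding="UTF-8"))
+print(demo_ids)
+print(tf.strings.reduce_join(num_to_char(demo_ids)).numpy().decode("utf-8")) # "cat"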
+
+"""
+### Resizing images without distortion
+
+Instead of square images, many OCR models work with rectangular images. This will become
+clearer in a moment when we visualize a few samples from the dataset. While
+aspect-unaware resizing of square images does not introduce a significant amount of
+distortion, this is not the case for rectangular images. But resizing images to a uniform
+size is a requirement for mini-batching. So we need to perform our resizing such that
+the following criteria are met:
+
+* Aspect ratio is preserved.
+* Content of the images is not affected.
+"""
+
+
+def distortion_free_resize(image, img_size):
+ w, h = img_size
+ image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
+
+ # Check the amount of padding needed.
+ pad_height = h - ops.shape(image)[0]
+ pad_width = w - ops.shape(image)[1]
+
+ # Only necessary if you want to do the same amount of padding on both sides.
+ if pad_height % 2 != 0:
+ height = pad_height // 2
+ pad_height_top = height + 1
+ pad_height_bottom = height
+ else:
+ pad_height_top = pad_height_bottom = pad_height // 2
+
+ if pad_width % 2 != 0:
+ width = pad_width // 2
+ pad_width_left = width + 1
+ pad_width_right = width
+ else:
+ pad_width_left = pad_width_right = pad_width // 2
+
+ image = tf.pad(
+ image,
+ paddings=[
+ [pad_height_top, pad_height_bottom],
+ [pad_width_left, pad_width_right],
+ [0, 0],
+ ],
+ )
+
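+ # Transpose (and flip) so that the width axis comes first; after the
+ # reshape in the model, width acts as the time dimension the RNN reads over.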
+ image = ops.transpose(image, (1, 0, 2))
+ image = tf.image.flip_left_right(image)
+ return image
+
+
+"""
+If we just go with the plain resizing then the images would look like so:
+
+
+
+Notice how this resizing would have introduced unnecessary stretching.
+"""
+
+"""
+### Putting the utilities together
+"""
+
+batch_size = 64
+padding_token = 99
+image_width = 128
+image_height = 32
+
+
+def preprocess_image(image_path, img_size=(image_width, image_height)):
+ image = tf.io.read_file(image_path)
+ image = tf.image.decode_png(image, 1)
+ image = distortion_free_resize(image, img_size)
+ image = ops.cast(image, tf.float32) / 255.0
+ return image
+
+
+def vectorize_label(label):
+ label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
+ length = ops.shape(label)[0]
+ pad_amount = max_len - length
+ label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
+ return label
+
+
+def process_images_labels(image_path, label):
+ image = preprocess_image(image_path)
+ label = vectorize_label(label)
+ return {"image": image, "label": label}
+
+
+def prepare_dataset(image_paths, labels):
+ dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
+ process_images_labels, num_parallel_calls=AUTOTUNE
+ )
+ return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)
+
+
+"""
+## Prepare `tf.data.Dataset` objects
+"""
+
+train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
+validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
+test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)
+
+"""
+## Visualize a few samples
+"""
+
+for data in train_ds.take(1):
+ images, labels = data["image"], data["label"]
+
+ _, ax = plt.subplots(4, 4, figsize=(15, 8))
+
+ for i in range(16):
+ img = images[i]
+ img = tf.image.flip_left_right(img)
+ img = ops.transpose(img, (1, 0, 2))
+ img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
+ img = img[:, :, 0]
+
+ # Gather indices where label!= padding_token.
+ label = labels[i]
+ indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
+ # Convert to string.
+ label = tf.strings.reduce_join(num_to_char(indices))
+ label = label.numpy().decode("utf-8")
+
+ ax[i // 4, i % 4].imshow(img, cmap="gray")
+ ax[i // 4, i % 4].set_title(label)
+ ax[i // 4, i % 4].axis("off")
+
+
+plt.show()
+
+"""
+You will notice that the content of the original image is kept as faithful as possible and has
+been padded accordingly.
+"""
+
+"""
+## Model
+
+Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the
+CTC loss, refer to [this post](https://distill.pub/2017/ctc/).
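+
+As a brief refresher (restating the standard CTC formulation in our own notation): given
+per-timestep character probabilities, CTC computes `p(y|x)` by summing the probabilities
+of all frame-level alignments (including blanks and repeats) that collapse to the target
+string `y`, and the loss is `-log p(y|x)`.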
+"""
+
+
+class CTCLayer(keras.layers.Layer):
+ def __init__(self, name=None):
+ super().__init__(name=name)
+ self.loss_fn = tf.keras.backend.ctc_batch_cost
+
+ def call(self, y_true, y_pred):
+ batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
+ input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
+ label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")
+
+ input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+ label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+ loss = self.loss_fn(y_true, y_pred, input_length, label_length)
+ self.add_loss(loss)
+
+ # At test time, just return the computed predictions.
+ return y_pred
+
+
+def build_model():
+ # Inputs to the model
+ input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
+ labels = keras.layers.Input(name="label", shape=(None,))
+
+ # First conv block.
+ x = keras.layers.Conv2D(
+ 32,
+ (3, 3),
+ activation="relu",
+ kernel_initializer="he_normal",
+ padding="same",
+ name="Conv1",
+ )(input_img)
+ x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)
+
+ # Second conv block.
+ x = keras.layers.Conv2D(
+ 64,
+ (3, 3),
+ activation="relu",
+ kernel_initializer="he_normal",
+ padding="same",
+ name="Conv2",
+ )(x)
+ x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)
+
+ # We have used two max pool with pool size and strides 2.
+ # Hence, downsampled feature maps are 4x smaller. The number of
+ # filters in the last layer is 64. Reshape accordingly before
+ # passing the output to the RNN part of the model.
+ new_shape = ((image_width // 4), (image_height // 4) * 64)
+ x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
+ x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
+ x = keras.layers.Dropout(0.2)(x)
+
+ # RNNs.
+ x = keras.layers.Bidirectional(
+ keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
+ )(x)
+ x = keras.layers.Bidirectional(
+ keras.layers.LSTM(64, return_sequences=True, dropout=0.25)
+ )(x)
+
+ # +2 is to account for the two special tokens introduced by the CTC loss.
+ # The recommendation comes here: https://git.io/J0eXP.
+ x = keras.layers.Dense(
+ len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
+ )(x)
+
+ # Add CTC layer for calculating CTC loss at each step.
+ output = CTCLayer(name="ctc_loss")(labels, x)
+
+ # Define the model.
+ model = keras.models.Model(
+ inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
+ )
+ # Optimizer.
+ opt = keras.optimizers.Adam()
+ # Compile the model and return.
+ model.compile(optimizer=opt)
+ return model
+
+
+# Get the model.
+model = build_model()
+model.summary()
+
+"""
+## Evaluation metric
+
+[Edit Distance](https://en.wikipedia.org/wiki/Edit_distance)
+is the most widely used metric for evaluating OCR models. In this section, we will
+implement it and use it as a callback to monitor our model.
+"""
+
+"""
+We first segregate the validation images and their labels for convenience.
+"""
+validation_images = []
+validation_labels = []
+
+for batch in validation_ds:
+ validation_images.append(batch["image"])
+ validation_labels.append(batch["label"])
+
+"""
+Now, we create a callback to monitor the edit distances.
+"""
+
+
+def calculate_edit_distance(labels, predictions):
+ # Get a single batch and convert its labels to sparse tensors.
+ sparse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)
+
+ # Make predictions and convert them to sparse tensors.
+ input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
+ predictions_decoded = keras.ops.nn.ctc_decode(
+ predictions, sequence_lengths=input_len
+ )[0][0][:, :max_len]
+ sparse_predictions = ops.cast(
+ tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
+ )
+
+ # Compute individual edit distances and average them out.
+ edit_distances = tf.edit_distance(
+ sparse_predictions, sparse_labels, normalize=False
+ )
+ return tf.reduce_mean(edit_distances)
+
+
+class EditDistanceCallback(keras.callbacks.Callback):
+ def __init__(self, pred_model):
+ super().__init__()
+ self.prediction_model = pred_model
+
+ def on_epoch_end(self, epoch, logs=None):
+ edit_distances = []
+
+ for i in range(len(validation_images)):
+ labels = validation_labels[i]
+ predictions = self.prediction_model.predict(validation_images[i])
+ edit_distances.append(calculate_edit_distance(labels, predictions).numpy())
+
+ print(
+ f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
+ )
+
+
+"""
+## Training
+
+Now we are ready to kick off model training.
+"""
+
+epochs = 10 # To get good results this should be at least 50.
+
+model = build_model()
+prediction_model = keras.models.Model(
+ model.get_layer(name="image").output, model.get_layer(name="dense2").output
+)
+edit_distance_callback = EditDistanceCallback(prediction_model)
+
+# Train the model.
+history = model.fit(
+ train_ds,
+ validation_data=validation_ds,
+ epochs=epochs,
+ callbacks=[edit_distance_callback],
+)
+
+
+"""
+## Inference
+"""
+
+
+# A utility function to decode the output of the network.
+def decode_batch_predictions(pred):
+ input_len = np.ones(pred.shape[0]) * pred.shape[1]
+    # Use greedy search. For complex tasks, you can use beam search (a sketch follows below).
+ results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
+ :, :max_len
+ ]
+ # Iterate over the results and get back the text.
+ output_text = []
+ for res in results:
+ res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
+ res = (
+ tf.strings.reduce_join(num_to_char(res))
+ .numpy()
+ .decode("utf-8")
+ .replace("[UNK]", "")
+ )
+ output_text.append(res)
+ return output_text
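+
+
+# For harder tasks, beam search usually decodes better than greedy search.
+# A minimal sketch, assuming the `strategy` and `beam_width` arguments of
+# `keras.ops.nn.ctc_decode` in Keras 3:
+def decode_batch_predictions_beam(pred, beam_width=50):
+    input_len = np.ones(pred.shape[0]) * pred.shape[1]
+    results = keras.ops.nn.ctc_decode(
+        pred, sequence_lengths=input_len, strategy="beam_search", beam_width=beam_width
+    )[0][0][:, :max_len]
+    output_text = []
+    for res in results:
+        res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
+        res = (
+            tf.strings.reduce_join(num_to_char(res))
+            .numpy()
+            .decode("utf-8")
+            .replace("[UNK]", "")
+        )
+        output_text.append(res)
+    return output_text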
+
+
+# Let's check results on some test samples.
+for batch in test_ds.take(1):
+ batch_images = batch["image"]
+ _, ax = plt.subplots(4, 4, figsize=(15, 8))
+
+ preds = prediction_model.predict(batch_images)
+ pred_texts = decode_batch_predictions(preds)
+
+ for i in range(16):
+ img = batch_images[i]
+ img = tf.image.flip_left_right(img)
+ img = ops.transpose(img, (1, 0, 2))
+ img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
+ img = img[:, :, 0]
+
+ title = f"Prediction: {pred_texts[i]}"
+ ax[i // 4, i % 4].imshow(img, cmap="gray")
+ ax[i // 4, i % 4].set_title(title)
+ ax[i // 4, i % 4].axis("off")
+
+plt.show()
+
+"""
+To get better results, the model should be trained for at least 50 epochs.
+"""
+
+"""
+## Final remarks
+
+* The `prediction_model` is fully compatible with TensorFlow Lite. If you are interested,
+you can use it inside a mobile application. You may find
+[this notebook](https://github.com/tulasiram58827/ocr_tflite/blob/main/colabs/captcha_ocr_tflite.ipynb)
+to be useful in this regard. A conversion sketch follows below.
+* Not all the training examples are perfectly aligned, as observed in this example. This
+can hurt model performance for complex sequences. To this end, we can leverage
+Spatial Transformer Networks ([Jaderberg et al.](https://arxiv.org/abs/1506.02025)),
+which can help the model learn affine transformations that maximize its performance.
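+
+As a minimal sketch of the TFLite conversion mentioned above (assuming the standard
+`tf.lite.TFLiteConverter` API; adjust the output path to your setup):
+
+```python
+import tensorflow as tf
+
+converter = tf.lite.TFLiteConverter.from_keras_model(prediction_model)
+tflite_model = converter.convert()
+with open("handwriting_recognizer.tflite", "wb") as f:
+    f.write(tflite_model)
+```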
+"""
diff --git a/knowledge_base/vision/image_captioning.py b/knowledge_base/vision/image_captioning.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b82821b23780ce579841bb2c5faa3ba600b0cc8
--- /dev/null
+++ b/knowledge_base/vision/image_captioning.py
@@ -0,0 +1,654 @@
+"""
+Title: Image Captioning
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2021/05/29
+Last modified: 2021/10/31
+Description: Implement an image captioning model using a CNN and a Transformer.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import re
+import numpy as np
+import matplotlib.pyplot as plt
+
+import tensorflow as tf
+import keras
+from keras import layers
+from keras.applications import efficientnet
+from keras.layers import TextVectorization
+
+keras.utils.set_random_seed(111)
+
+"""
+## Download the dataset
+
+We will be using the Flickr8K dataset for this tutorial. This dataset comprises over
+8,000 images, each of which is paired with five different captions.
+"""
+
+
+"""shell
+wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
+wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
+unzip -qq Flickr8k_Dataset.zip
+unzip -qq Flickr8k_text.zip
+rm Flickr8k_Dataset.zip Flickr8k_text.zip
+"""
+
+
+# Path to the images
+IMAGES_PATH = "Flicker8k_Dataset"
+
+# Desired image dimensions
+IMAGE_SIZE = (299, 299)
+
+# Vocabulary size
+VOCAB_SIZE = 10000
+
+# Fixed length allowed for any sequence
+SEQ_LENGTH = 25
+
+# Dimension for the image embeddings and token embeddings
+EMBED_DIM = 512
+
+# Per-layer units in the feed-forward network
+FF_DIM = 512
+
+# Other training parameters
+BATCH_SIZE = 64
+EPOCHS = 30
+AUTOTUNE = tf.data.AUTOTUNE
+
+"""
+## Preparing the dataset
+"""
+
+
+def load_captions_data(filename):
+ """Loads captions (text) data and maps them to corresponding images.
+
+ Args:
+ filename: Path to the text file containing caption data.
+
+ Returns:
+        caption_mapping: Dictionary mapping image names to the corresponding captions
+ text_data: List containing all the available captions
+ """
+
+ with open(filename) as caption_file:
+ caption_data = caption_file.readlines()
+ caption_mapping = {}
+ text_data = []
+ images_to_skip = set()
+
+ for line in caption_data:
+ line = line.rstrip("\n")
+ # Image name and captions are separated using a tab
+ img_name, caption = line.split("\t")
+
+ # Each image is repeated five times for the five different captions.
+ # Each image name has a suffix `#(caption_number)`
+ img_name = img_name.split("#")[0]
+ img_name = os.path.join(IMAGES_PATH, img_name.strip())
+
+        # We will remove captions that are either too short or too long
+ tokens = caption.strip().split()
+
+ if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
+ images_to_skip.add(img_name)
+ continue
+
+ if img_name.endswith("jpg") and img_name not in images_to_skip:
+ # We will add a start and an end token to each caption
+ caption = " " + caption.strip() + " "
+ text_data.append(caption)
+
+ if img_name in caption_mapping:
+ caption_mapping[img_name].append(caption)
+ else:
+ caption_mapping[img_name] = [caption]
+
+ for img_name in images_to_skip:
+ if img_name in caption_mapping:
+ del caption_mapping[img_name]
+
+ return caption_mapping, text_data
+
+
+def train_val_split(caption_data, train_size=0.8, shuffle=True):
+ """Split the captioning dataset into train and validation sets.
+
+ Args:
+ caption_data (dict): Dictionary containing the mapped caption data
+        train_size (float): Fraction of the full dataset to use as training data
+ shuffle (bool): Whether to shuffle the dataset before splitting
+
+ Returns:
+        Training and validation datasets as two separate dicts
+ """
+
+ # 1. Get the list of all image names
+ all_images = list(caption_data.keys())
+
+ # 2. Shuffle if necessary
+ if shuffle:
+ np.random.shuffle(all_images)
+
+ # 3. Split into training and validation sets
+ train_size = int(len(caption_data) * train_size)
+
+ training_data = {
+ img_name: caption_data[img_name] for img_name in all_images[:train_size]
+ }
+ validation_data = {
+ img_name: caption_data[img_name] for img_name in all_images[train_size:]
+ }
+
+ # 4. Return the splits
+ return training_data, validation_data
+
+
+# Load the dataset
+captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")
+
+# Split the dataset into training and validation sets
+train_data, valid_data = train_val_split(captions_mapping)
+print("Number of training samples: ", len(train_data))
+print("Number of validation samples: ", len(valid_data))
+
+"""
+## Vectorizing the text data
+
+We'll use the `TextVectorization` layer to vectorize the text data,
+that is to say, to turn the
+original strings into integer sequences where each integer represents the index of
+a word in a vocabulary. We will use a custom string standardization scheme
+(strip punctuation characters except `<` and `>`) and the default
+splitting scheme (split on whitespace).
+"""
+
+
+def custom_standardization(input_string):
+ lowercase = tf.strings.lower(input_string)
+ return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")
+
+
+strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
+strip_chars = strip_chars.replace("<", "")
+strip_chars = strip_chars.replace(">", "")
+
+vectorization = TextVectorization(
+ max_tokens=VOCAB_SIZE,
+ output_mode="int",
+ output_sequence_length=SEQ_LENGTH,
+ standardize=custom_standardization,
+)
+vectorization.adapt(text_data)
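+
+# Quick illustrative check (not part of the original pipeline): calling the
+# adapted layer on a batch of strings returns int token ids of shape
+# (batch_size, SEQ_LENGTH).
+print(vectorization(["<start> a dog runs on the grass <end>"]))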
+
+# Data augmentation for image data
+image_augmentation = keras.Sequential(
+ [
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(0.2),
+ layers.RandomContrast(0.3),
+ ]
+)
+
+
+"""
+## Building a `tf.data.Dataset` pipeline for training
+
+We will generate pairs of images and corresponding captions using a `tf.data.Dataset` object.
+The pipeline consists of two steps:
+
+1. Read the image from the disk
+2. Tokenize all the five captions corresponding to the image
+"""
+
+
+def decode_and_resize(img_path):
+ img = tf.io.read_file(img_path)
+ img = tf.image.decode_jpeg(img, channels=3)
+ img = tf.image.resize(img, IMAGE_SIZE)
+ img = tf.image.convert_image_dtype(img, tf.float32)
+ return img
+
+
+def process_input(img_path, captions):
+ return decode_and_resize(img_path), vectorization(captions)
+
+
+def make_dataset(images, captions):
+ dataset = tf.data.Dataset.from_tensor_slices((images, captions))
+ dataset = dataset.shuffle(BATCH_SIZE * 8)
+ dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
+ dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
+
+ return dataset
+
+
+# Pass the list of images and the list of corresponding captions
+train_dataset = make_dataset(list(train_data.keys()), list(train_data.values()))
+
+valid_dataset = make_dataset(list(valid_data.keys()), list(valid_data.values()))
+
+
+"""
+## Building the model
+
+Our image captioning architecture consists of three models:
+
+1. A CNN: used to extract the image features
+2. A TransformerEncoder: The extracted image features are then passed to a
+    Transformer-based encoder that generates a new representation of the inputs
+3. A TransformerDecoder: This model takes the encoder output and the text data
+ (sequences) as inputs and tries to learn to generate the caption.
+"""
+
+
+def get_cnn_model():
+ base_model = efficientnet.EfficientNetB0(
+ input_shape=(*IMAGE_SIZE, 3),
+ include_top=False,
+ weights="imagenet",
+ )
+ # We freeze our feature extractor
+ base_model.trainable = False
+ base_model_out = base_model.output
+ base_model_out = layers.Reshape((-1, base_model_out.shape[-1]))(base_model_out)
+ cnn_model = keras.models.Model(base_model.input, base_model_out)
+ return cnn_model
+
+
+class TransformerEncoderBlock(layers.Layer):
+ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.dense_dim = dense_dim
+ self.num_heads = num_heads
+ self.attention_1 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.0
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.dense_1 = layers.Dense(embed_dim, activation="relu")
+
+ def call(self, inputs, training, mask=None):
+ inputs = self.layernorm_1(inputs)
+ inputs = self.dense_1(inputs)
+
+ attention_output_1 = self.attention_1(
+ query=inputs,
+ value=inputs,
+ key=inputs,
+ attention_mask=None,
+ training=training,
+ )
+ out_1 = self.layernorm_2(inputs + attention_output_1)
+ return out_1
+
+
+class PositionalEmbedding(layers.Layer):
+ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.token_embeddings = layers.Embedding(
+ input_dim=vocab_size, output_dim=embed_dim
+ )
+ self.position_embeddings = layers.Embedding(
+ input_dim=sequence_length, output_dim=embed_dim
+ )
+ self.sequence_length = sequence_length
+ self.vocab_size = vocab_size
+ self.embed_dim = embed_dim
+ self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))
+
+ def call(self, inputs):
+ length = tf.shape(inputs)[-1]
+ positions = tf.range(start=0, limit=length, delta=1)
+ embedded_tokens = self.token_embeddings(inputs)
+ embedded_tokens = embedded_tokens * self.embed_scale
+ embedded_positions = self.position_embeddings(positions)
+ return embedded_tokens + embedded_positions
+
+ def compute_mask(self, inputs, mask=None):
+ return tf.math.not_equal(inputs, 0)
+
+
+class TransformerDecoderBlock(layers.Layer):
+ def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.ff_dim = ff_dim
+ self.num_heads = num_heads
+ self.attention_1 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+ )
+ self.attention_2 = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
+ )
+ self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
+ self.ffn_layer_2 = layers.Dense(embed_dim)
+
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+ self.layernorm_3 = layers.LayerNormalization()
+
+ self.embedding = PositionalEmbedding(
+ embed_dim=EMBED_DIM,
+ sequence_length=SEQ_LENGTH,
+ vocab_size=VOCAB_SIZE,
+ )
+ self.out = layers.Dense(VOCAB_SIZE, activation="softmax")
+
+ self.dropout_1 = layers.Dropout(0.3)
+ self.dropout_2 = layers.Dropout(0.5)
+ self.supports_masking = True
+
+ def call(self, inputs, encoder_outputs, training, mask=None):
+ inputs = self.embedding(inputs)
+ causal_mask = self.get_causal_attention_mask(inputs)
+
+        if mask is not None:
+            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
+            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
+            combined_mask = tf.minimum(combined_mask, causal_mask)
+        else:
+            # Fall back to the causal mask alone when no padding mask is given.
+            padding_mask = None
+            combined_mask = causal_mask
+
+ attention_output_1 = self.attention_1(
+ query=inputs,
+ value=inputs,
+ key=inputs,
+ attention_mask=combined_mask,
+ training=training,
+ )
+ out_1 = self.layernorm_1(inputs + attention_output_1)
+
+ attention_output_2 = self.attention_2(
+ query=out_1,
+ value=encoder_outputs,
+ key=encoder_outputs,
+ attention_mask=padding_mask,
+ training=training,
+ )
+ out_2 = self.layernorm_2(out_1 + attention_output_2)
+
+ ffn_out = self.ffn_layer_1(out_2)
+ ffn_out = self.dropout_1(ffn_out, training=training)
+ ffn_out = self.ffn_layer_2(ffn_out)
+
+ ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
+ ffn_out = self.dropout_2(ffn_out, training=training)
+ preds = self.out(ffn_out)
+ return preds
+
+    def get_causal_attention_mask(self, inputs):
+        input_shape = tf.shape(inputs)
+        batch_size, sequence_length = input_shape[0], input_shape[1]
+        # Build a lower-triangular mask: position i may only attend to j <= i.
+        i = tf.range(sequence_length)[:, tf.newaxis]
+        j = tf.range(sequence_length)
+        mask = tf.cast(i >= j, dtype="int32")
+        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
+        # Tile the (1, seq_len, seq_len) mask across the batch dimension.
+        mult = tf.concat(
+            [
+                tf.expand_dims(batch_size, -1),
+                tf.constant([1, 1], dtype=tf.int32),
+            ],
+            axis=0,
+        )
+        return tf.tile(mask, mult)
+
+
+class ImageCaptioningModel(keras.Model):
+ def __init__(
+ self,
+ cnn_model,
+ encoder,
+ decoder,
+ num_captions_per_image=5,
+ image_aug=None,
+ ):
+ super().__init__()
+ self.cnn_model = cnn_model
+ self.encoder = encoder
+ self.decoder = decoder
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+ self.acc_tracker = keras.metrics.Mean(name="accuracy")
+ self.num_captions_per_image = num_captions_per_image
+ self.image_aug = image_aug
+
+ def calculate_loss(self, y_true, y_pred, mask):
+ loss = self.loss(y_true, y_pred)
+ mask = tf.cast(mask, dtype=loss.dtype)
+ loss *= mask
+ return tf.reduce_sum(loss) / tf.reduce_sum(mask)
+
+ def calculate_accuracy(self, y_true, y_pred, mask):
+ accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
+ accuracy = tf.math.logical_and(mask, accuracy)
+ accuracy = tf.cast(accuracy, dtype=tf.float32)
+ mask = tf.cast(mask, dtype=tf.float32)
+ return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
+
+ def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
+ encoder_out = self.encoder(img_embed, training=training)
+ batch_seq_inp = batch_seq[:, :-1]
+ batch_seq_true = batch_seq[:, 1:]
+ mask = tf.math.not_equal(batch_seq_true, 0)
+ batch_seq_pred = self.decoder(
+ batch_seq_inp, encoder_out, training=training, mask=mask
+ )
+ loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
+ acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
+ return loss, acc
+
+ def train_step(self, batch_data):
+ batch_img, batch_seq = batch_data
+ batch_loss = 0
+ batch_acc = 0
+
+ if self.image_aug:
+ batch_img = self.image_aug(batch_img)
+
+ # 1. Get image embeddings
+ img_embed = self.cnn_model(batch_img)
+
+ # 2. Pass each of the five captions one by one to the decoder
+ # along with the encoder outputs and compute the loss as well as accuracy
+ # for each caption.
+ for i in range(self.num_captions_per_image):
+ with tf.GradientTape() as tape:
+ loss, acc = self._compute_caption_loss_and_acc(
+ img_embed, batch_seq[:, i, :], training=True
+ )
+
+ # 3. Update loss and accuracy
+ batch_loss += loss
+ batch_acc += acc
+
+ # 4. Get the list of all the trainable weights
+ train_vars = (
+ self.encoder.trainable_variables + self.decoder.trainable_variables
+ )
+
+ # 5. Get the gradients
+ grads = tape.gradient(loss, train_vars)
+
+ # 6. Update the trainable weights
+ self.optimizer.apply_gradients(zip(grads, train_vars))
+
+ # 7. Update the trackers
+ batch_acc /= float(self.num_captions_per_image)
+ self.loss_tracker.update_state(batch_loss)
+ self.acc_tracker.update_state(batch_acc)
+
+ # 8. Return the loss and accuracy values
+ return {
+ "loss": self.loss_tracker.result(),
+ "acc": self.acc_tracker.result(),
+ }
+
+ def test_step(self, batch_data):
+ batch_img, batch_seq = batch_data
+ batch_loss = 0
+ batch_acc = 0
+
+ # 1. Get image embeddings
+ img_embed = self.cnn_model(batch_img)
+
+ # 2. Pass each of the five captions one by one to the decoder
+ # along with the encoder outputs and compute the loss as well as accuracy
+ # for each caption.
+ for i in range(self.num_captions_per_image):
+ loss, acc = self._compute_caption_loss_and_acc(
+ img_embed, batch_seq[:, i, :], training=False
+ )
+
+ # 3. Update batch loss and batch accuracy
+ batch_loss += loss
+ batch_acc += acc
+
+ batch_acc /= float(self.num_captions_per_image)
+
+ # 4. Update the trackers
+ self.loss_tracker.update_state(batch_loss)
+ self.acc_tracker.update_state(batch_acc)
+
+ # 5. Return the loss and accuracy values
+ return {
+ "loss": self.loss_tracker.result(),
+ "acc": self.acc_tracker.result(),
+ }
+
+ @property
+ def metrics(self):
+        # We need to list our metrics here so that `reset_states()` can be
+        # called automatically.
+ return [self.loss_tracker, self.acc_tracker]
+
+
+cnn_model = get_cnn_model()
+encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
+decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
+caption_model = ImageCaptioningModel(
+ cnn_model=cnn_model,
+ encoder=encoder,
+ decoder=decoder,
+ image_aug=image_augmentation,
+)
+
+"""
+## Model training
+"""
+
+
+# Define the loss function
+cross_entropy = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=False,
+ reduction=None,
+)
+
+# EarlyStopping criteria
+early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
+
+
+# Learning Rate Scheduler for the optimizer
+class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(self, post_warmup_learning_rate, warmup_steps):
+ super().__init__()
+ self.post_warmup_learning_rate = post_warmup_learning_rate
+ self.warmup_steps = warmup_steps
+
+ def __call__(self, step):
+ global_step = tf.cast(step, tf.float32)
+ warmup_steps = tf.cast(self.warmup_steps, tf.float32)
+ warmup_progress = global_step / warmup_steps
+ warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
+ return tf.cond(
+ global_step < warmup_steps,
+ lambda: warmup_learning_rate,
+ lambda: self.post_warmup_learning_rate,
+ )
+
+
+# Create a learning rate schedule
+num_train_steps = len(train_dataset) * EPOCHS
+num_warmup_steps = num_train_steps // 15
+lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)
+
+# Compile the model
+caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy)
+
+# Fit the model
+caption_model.fit(
+ train_dataset,
+ epochs=EPOCHS,
+ validation_data=valid_dataset,
+ callbacks=[early_stopping],
+)
+
+"""
+## Check sample predictions
+"""
+
+vocab = vectorization.get_vocabulary()
+index_lookup = dict(zip(range(len(vocab)), vocab))
+max_decoded_sentence_length = SEQ_LENGTH - 1
+valid_images = list(valid_data.keys())
+
+
+def generate_caption():
+ # Select a random image from the validation dataset
+ sample_img = np.random.choice(valid_images)
+
+ # Read the image from the disk
+ sample_img = decode_and_resize(sample_img)
+ img = sample_img.numpy().clip(0, 255).astype(np.uint8)
+ plt.imshow(img)
+ plt.show()
+
+ # Pass the image to the CNN
+ img = tf.expand_dims(sample_img, 0)
+ img = caption_model.cnn_model(img)
+
+ # Pass the image features to the Transformer encoder
+ encoded_img = caption_model.encoder(img, training=False)
+
+ # Generate the caption using the Transformer decoder
+ decoded_caption = " "
+ for i in range(max_decoded_sentence_length):
+ tokenized_caption = vectorization([decoded_caption])[:, :-1]
+ mask = tf.math.not_equal(tokenized_caption, 0)
+ predictions = caption_model.decoder(
+ tokenized_caption, encoded_img, training=False, mask=mask
+ )
+ sampled_token_index = np.argmax(predictions[0, i, :])
+ sampled_token = index_lookup[sampled_token_index]
+ if sampled_token == "":
+ break
+ decoded_caption += " " + sampled_token
+
+ decoded_caption = decoded_caption.replace(" ", "")
+ decoded_caption = decoded_caption.replace(" ", "").strip()
+ print("Predicted Caption: ", decoded_caption)
+
+
+# Check predictions for a few samples
+generate_caption()
+generate_caption()
+generate_caption()
+
+"""
+## End Notes
+
+We saw that the model starts to generate reasonable captions after a few epochs. To keep
+this example easily runnable, we have trained it with a few constraints, like a minimal
+number of attention heads. To improve the predictions, you can try changing these training
+settings and find a good model for your use case.
+"""
diff --git a/knowledge_base/vision/image_classification_efficientnet_fine_tuning.py b/knowledge_base/vision/image_classification_efficientnet_fine_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..2073af2210a62ce82988cb305fd9a62d6473eeaf
--- /dev/null
+++ b/knowledge_base/vision/image_classification_efficientnet_fine_tuning.py
@@ -0,0 +1,442 @@
+"""
+Title: Image classification via fine-tuning with EfficientNet
+Author: [Yixing Fu](https://github.com/yixingfu)
+Date created: 2020/06/30
+Last modified: 2023/07/10
+Description: Use EfficientNet with weights pre-trained on imagenet for Stanford Dogs classification.
+Accelerator: GPU
+"""
+
+"""
+
+## Introduction: what is EfficientNet
+
+EfficientNet, first introduced in [Tan and Le, 2019](https://arxiv.org/abs/1905.11946),
+is among the most efficient models (i.e. requiring the least FLOPS for inference)
+that reach state-of-the-art accuracy on both
+ImageNet and common image classification transfer learning tasks.
+
+The smallest base model is similar to [MnasNet](https://arxiv.org/abs/1807.11626), which
+reached near-SOTA with a significantly smaller model. By introducing a heuristic way to
+scale the model, EfficientNet provides a family of models (B0 to B7) that represents a
+good combination of efficiency and accuracy at a variety of scales. Such a scaling
+heuristic (compound scaling; see
+[Tan and Le, 2019](https://arxiv.org/abs/1905.11946) for details) allows the
+efficiency-oriented base model (B0) to surpass models at every scale, while avoiding
+an extensive grid search of hyperparameters.
+
+A summary of the latest updates on the model is available
+[here](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet), where various
+augmentation schemes and semi-supervised learning approaches are applied to further
+improve the ImageNet performance of the models. These extensions of the model can be used
+by updating the weights without changing the model architecture.
+
+## B0 to B7 variants of EfficientNet
+
+*(This section provides some details on "compound scaling", and can be skipped
+if you're only interested in using the models)*
+
+Based on the [original paper](https://arxiv.org/abs/1905.11946), people may have the
+impression that EfficientNet is a continuous family of models created by arbitrarily
+choosing a scaling factor, as in Eq. (3) of the paper. However, the choice of resolution,
+depth and width is also restricted by many factors:
+
+- Resolution: Resolutions not divisible by 8, 16, etc. cause zero-padding near the boundaries
+of some layers, which wastes computational resources. This especially applies to smaller
+variants of the model, hence the input resolutions for B0 and B1 are chosen as 224 and
+240.
+
+- Depth and width: The building blocks of EfficientNet demand channel sizes to be
+multiples of 8.
+
+- Resource limit: Memory limitations may bottleneck resolution when depth
+and width can still increase. In such a situation, increasing depth and/or
+width while keeping resolution fixed can still improve performance.
+
+As a result, the depth, width and resolution of each variant of the EfficientNet models
+are hand-picked and proven to produce good results, though they may be significantly
+off from the compound scaling formula.
+Therefore, the Keras implementation (detailed below) only provides these eight models, B0 to B7,
+instead of allowing an arbitrary choice of width / depth / resolution parameters.
+
+## Keras implementation of EfficientNet
+
+An implementation of EfficientNet B0 to B7 has been shipped with Keras since v2.3. To
+use EfficientNetB0 for classifying 1000 classes of images from ImageNet, run:
+
+```python
+from tensorflow.keras.applications import EfficientNetB0
+model = EfficientNetB0(weights='imagenet')
+```
+
+This model takes input images of shape `(224, 224, 3)`, and the input data should be in the
+range `[0, 255]`. Normalization is included as part of the model.
+
+Training EfficientNet on ImageNet takes a tremendous amount of resources and
+several techniques that are not part of the model architecture itself. Hence the Keras
+implementation by default loads pre-trained weights obtained via training with
+[AutoAugment](https://arxiv.org/abs/1805.09501).
+
+For the B0 to B7 base models, the input shapes are different. Here is the input shape
+expected by each model:
+
+| Base model | resolution|
+|----------------|-----|
+| EfficientNetB0 | 224 |
+| EfficientNetB1 | 240 |
+| EfficientNetB2 | 260 |
+| EfficientNetB3 | 300 |
+| EfficientNetB4 | 380 |
+| EfficientNetB5 | 456 |
+| EfficientNetB6 | 528 |
+| EfficientNetB7 | 600 |
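+
+For example, to use a larger variant, instantiate the matching class and feed it inputs
+at the resolution from the table. A minimal sketch (using the stock `keras.applications`
+classes):
+
+```python
+from keras.applications import EfficientNetB3
+
+# EfficientNetB3 expects 300x300 RGB inputs in the [0, 255] range.
+model = EfficientNetB3(weights="imagenet")  # input shape (300, 300, 3)
+```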
+
+When the model is intended for transfer learning, the Keras implementation
+provides an option to remove the top layers:
+
+```python
+model = EfficientNetB0(include_top=False, weights='imagenet')
+```
+This option excludes the final `Dense` layer that turns 1280 features on the penultimate
+layer into prediction of the 1000 ImageNet classes. Replacing the top layer with custom
+layers allows using EfficientNet as a feature extractor in a transfer learning workflow.
+
+Another argument in the model constructor worth noticing is `drop_connect_rate` which controls
+the dropout rate responsible for [stochastic depth](https://arxiv.org/abs/1603.09382).
+This parameter serves as a toggle for extra regularization in finetuning, but does not
+affect loaded weights. For example, when stronger regularization is desired, try:
+
+```python
+model = EfficientNetB0(weights='imagenet', drop_connect_rate=0.4)
+```
+The default value is 0.2.
+
+## Example: EfficientNetB0 for Stanford Dogs
+
+EfficientNet is capable of a wide range of image classification tasks.
+This makes it a good model for transfer learning.
+As an end-to-end example, we will show using pre-trained EfficientNetB0 on
+[Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/main.html) dataset.
+
+"""
+
+"""
+## Setup and data loading
+"""
+
+import numpy as np
+import tensorflow_datasets as tfds
+import tensorflow as tf # For tf.data
+import matplotlib.pyplot as plt
+import keras
+from keras import layers
+from keras.applications import EfficientNetB0
+
+# IMG_SIZE is determined by EfficientNet model choice
+IMG_SIZE = 224
+BATCH_SIZE = 64
+
+
+"""
+### Loading data
+
+Here we load data from [tensorflow_datasets](https://www.tensorflow.org/datasets)
+(hereafter TFDS).
+Stanford Dogs dataset is provided in
+TFDS as [stanford_dogs](https://www.tensorflow.org/datasets/catalog/stanford_dogs).
+It features 20,580 images that belong to 120 classes of dog breeds
+(12,000 for training and 8,580 for testing).
+
+By simply changing `dataset_name` below, you may also try this notebook for
+other datasets in TFDS such as
+[cifar10](https://www.tensorflow.org/datasets/catalog/cifar10),
+[cifar100](https://www.tensorflow.org/datasets/catalog/cifar100),
+[food101](https://www.tensorflow.org/datasets/catalog/food101),
+etc. When the images are much smaller than the size of EfficientNet input,
+we can simply upsample the input images. It has been shown in
+[Tan and Le, 2019](https://arxiv.org/abs/1905.11946) that transfer learning
+result is better for increased resolution even if input images remain small.
+"""
+
+dataset_name = "stanford_dogs"
+(ds_train, ds_test), ds_info = tfds.load(
+ dataset_name, split=["train", "test"], with_info=True, as_supervised=True
+)
+NUM_CLASSES = ds_info.features["label"].num_classes
+
+
+"""
+When the dataset includes images of various sizes, we need to resize them to a
+shared size. The Stanford Dogs dataset includes only images that are at least 200x200
+pixels in size. Here we resize the images to the input size needed for EfficientNet.
+"""
+
+size = (IMG_SIZE, IMG_SIZE)
+ds_train = ds_train.map(lambda image, label: (tf.image.resize(image, size), label))
+ds_test = ds_test.map(lambda image, label: (tf.image.resize(image, size), label))
+
+"""
+### Visualizing the data
+
+The following code shows the first 9 images with their labels.
+"""
+
+
+def format_label(label):
+ string_label = label_info.int2str(label)
+ return string_label.split("-")[1]
+
+
+label_info = ds_info.features["label"]
+for i, (image, label) in enumerate(ds_train.take(9)):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().astype("uint8"))
+ plt.title("{}".format(format_label(label)))
+ plt.axis("off")
+
+
+"""
+### Data augmentation
+
+We can use the preprocessing layers APIs for image augmentation.
+"""
+
+img_augmentation_layers = [
+ layers.RandomRotation(factor=0.15),
+ layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
+ layers.RandomFlip(),
+ layers.RandomContrast(factor=0.1),
+]
+
+
+def img_augmentation(images):
+ for layer in img_augmentation_layers:
+ images = layer(images)
+ return images
+
+
+"""
+These augmentation layers can be applied both as part of
+the model we later build, and as a function to preprocess
+data before feeding it into the model. Using them as a function makes
+it easy to visualize the augmented images. Here we plot 9 examples
+of the augmentation result for a given figure.
+"""
+
+for image, label in ds_train.take(1):
+ for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ aug_img = img_augmentation(np.expand_dims(image.numpy(), axis=0))
+ aug_img = np.array(aug_img)
+ plt.imshow(aug_img[0].astype("uint8"))
+ plt.title("{}".format(format_label(label)))
+ plt.axis("off")
+
+
+"""
+### Prepare inputs
+
+Once we verify the input data and augmentation are working correctly,
+we prepare the dataset for training. The input data are resized to a uniform
+`IMG_SIZE`. The labels are put into one-hot
+(a.k.a. categorical) encoding. The dataset is batched.
+
+Note: `prefetch` and `AUTOTUNE` may in some situations improve
+performance, depending on the environment and the specific dataset used.
+See this [guide](https://www.tensorflow.org/guide/data_performance)
+for more information on data pipeline performance.
+"""
+
+
+# One-hot / categorical encoding
+def input_preprocess_train(image, label):
+ image = img_augmentation(image)
+ label = tf.one_hot(label, NUM_CLASSES)
+ return image, label
+
+
+def input_preprocess_test(image, label):
+ label = tf.one_hot(label, NUM_CLASSES)
+ return image, label
+
+
+ds_train = ds_train.map(input_preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
+ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
+ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
+
+ds_test = ds_test.map(input_preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
+ds_test = ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True)
+
+
+"""
+## Training a model from scratch
+
+We build an EfficientNetB0 with 120 output classes, initialized from scratch:
+
+Note: the accuracy will increase very slowly and the model may overfit.
+"""
+
+model = EfficientNetB0(
+ include_top=True,
+ weights=None,
+ classes=NUM_CLASSES,
+ input_shape=(IMG_SIZE, IMG_SIZE, 3),
+)
+model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
+
+model.summary()
+
+epochs = 40 # @param {type: "slider", min:10, max:100}
+hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
+
+
+"""
+Training the model is relatively fast. This might make it sound easy to simply train
+EfficientNet from scratch on any dataset you want. However, training EfficientNet on smaller datasets,
+especially those with lower resolution like CIFAR-100, faces the significant challenge of
+overfitting.
+
+Hence training from scratch requires a very careful choice of hyperparameters, and it is
+difficult to find suitable regularization. It would also be much more demanding in resources.
+Plotting the training and validation accuracy
+makes it clear that validation accuracy stagnates at a low value.
+"""
+
+
+def plot_hist(hist):
+ plt.plot(hist.history["accuracy"])
+ plt.plot(hist.history["val_accuracy"])
+ plt.title("model accuracy")
+ plt.ylabel("accuracy")
+ plt.xlabel("epoch")
+ plt.legend(["train", "validation"], loc="upper left")
+ plt.show()
+
+
+plot_hist(hist)
+
+"""
+## Transfer learning from pre-trained weights
+
+Here we initialize the model with pre-trained ImageNet weights,
+and we fine-tune it on our own dataset.
+"""
+
+
+def build_model(num_classes):
+ inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
+ model = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")
+
+ # Freeze the pretrained weights
+ model.trainable = False
+
+ # Rebuild top
+ x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
+ x = layers.BatchNormalization()(x)
+
+ top_dropout_rate = 0.2
+ x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
+ outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)
+
+ # Compile
+ model = keras.Model(inputs, outputs, name="EfficientNet")
+ optimizer = keras.optimizers.Adam(learning_rate=1e-2)
+ model.compile(
+ optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
+ )
+ return model
+
+
+"""
+The first step to transfer learning is to freeze all layers and train only the top
+layers. For this step, a relatively large learning rate (1e-2) can be used.
+Note that validation accuracy and loss will usually be better than training
+accuracy and loss. This is because the regularization is strong, which only
+suppresses training-time metrics.
+
+Note that convergence may take up to 50 epochs, depending on the choice of learning rate.
+If image augmentation layers were not
+applied, the validation accuracy may only reach ~60%.
+"""
+
+model = build_model(num_classes=NUM_CLASSES)
+
+epochs = 25 # @param {type: "slider", min:8, max:80}
+hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
+plot_hist(hist)
+
+"""
+The second step is to unfreeze a number of layers and fit the model using a smaller
+learning rate. In this example we unfreeze the top 20 layers, but depending on the
+specific dataset it may be desirable to unfreeze more or fewer layers.
+
+When feature extraction with the
+pretrained model works well enough, this step gives only a very limited gain in
+validation accuracy. In our case we see only a small improvement,
+as ImageNet pretraining already exposed the model to a good amount of dogs.
+
+On the other hand, when we use pretrained weights on a dataset that is more different
+from ImageNet, this fine-tuning step can be crucial, as the feature extractor also
+needs to be adjusted by a considerable amount. Such a situation can be demonstrated
+by choosing the CIFAR-100 dataset instead, where fine-tuning boosts validation accuracy
+by about 10% to pass 80% on `EfficientNetB0`.
+
+A side note on freezing/unfreezing models: setting `trainable` of a `Model` will
+simultaneously set all layers belonging to the `Model` to the same `trainable`
+attribute. Each layer is trainable only if both the layer itself and the model
+containing it are trainable. Hence when we need to partially freeze/unfreeze
+a model, we need to make sure the `trainable` attribute of the model is set
+to `True`.
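+
+A minimal sketch of that pattern (illustrative only; `base_model` stands for any frozen
+Keras model):
+
+```python
+base_model.trainable = True  # the model itself must be trainable...
+for layer in base_model.layers[:-20]:
+    layer.trainable = False  # ...then freeze everything except the last 20 layers
+```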
+"""
+
+
+def unfreeze_model(model):
+ # We unfreeze the top 20 layers while leaving BatchNorm layers frozen
+ for layer in model.layers[-20:]:
+ if not isinstance(layer, layers.BatchNormalization):
+ layer.trainable = True
+
+ optimizer = keras.optimizers.Adam(learning_rate=1e-5)
+ model.compile(
+ optimizer=optimizer, loss="categorical_crossentropy", metrics=["accuracy"]
+ )
+
+
+unfreeze_model(model)
+
+epochs = 4 # @param {type: "slider", min:4, max:10}
+hist = model.fit(ds_train, epochs=epochs, validation_data=ds_test)
+plot_hist(hist)
+
+"""
+### Tips for fine-tuning EfficientNet
+
+On unfreezing layers:
+
+- The `BatchNormalization` layers need to be kept frozen
+([more details](https://keras.io/guides/transfer_learning/)).
+If they are also turned trainable, the
+first epoch after unfreezing will significantly reduce accuracy.
+- In some cases it may be beneficial to open up only a portion of layers instead of
+unfreezing all. This will make fine-tuning much faster when going to larger models like
+B7.
+- Each block needs to be turned on or off in its entirety. This is because the architecture includes
+a shortcut from the first layer to the last layer of each block. Not respecting blocks
+also significantly harms the final performance.
+
+Some other tips for utilizing EfficientNet:
+
+- Larger variants of EfficientNet do not guarantee improved performance, especially for
+tasks with less data or fewer classes. In such a case, the larger the EfficientNet variant
+chosen, the harder it is to tune hyperparameters.
+- EMA (Exponential Moving Average) is very helpful in training EfficientNet from scratch,
+but not so much for transfer learning.
+- Do not use the RMSprop setup as in the original paper for transfer learning. The
+momentum and learning rate are too high for transfer learning. They will easily corrupt the
+pretrained weights and blow up the loss. A quick check is to see whether the loss (as categorical
+cross-entropy) is getting significantly larger than log(NUM_CLASSES) after the same
+epoch; if so, the initial learning rate/momentum is too high (a sketch of this check follows below).
+- A smaller batch size benefits validation accuracy, possibly by effectively providing
+regularization.
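+
+As a minimal sketch of the quick check mentioned above (illustrative only):
+
+```python
+import numpy as np
+
+# For the 120 Stanford Dogs classes, a random-guess categorical cross-entropy
+# is log(120) ~= 4.79; a loss staying far above this suggests the learning
+# rate / momentum is too high.
+print(np.log(120))  # ~4.787
+```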
+"""
diff --git a/knowledge_base/vision/image_classification_from_scratch.py b/knowledge_base/vision/image_classification_from_scratch.py
new file mode 100644
index 0000000000000000000000000000000000000000..12266b54596ff5a703b0f21099ee03bf69d99a73
--- /dev/null
+++ b/knowledge_base/vision/image_classification_from_scratch.py
@@ -0,0 +1,325 @@
+"""
+Title: Image classification from scratch
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/04/27
+Last modified: 2023/11/09
+Description: Training an image classifier from scratch on the Kaggle Cats vs Dogs dataset.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example shows how to do image classification from scratch, starting from JPEG
+image files on disk, without leveraging pre-trained weights or a pre-made Keras
+Application model. We demonstrate the workflow on the Kaggle Cats vs Dogs binary
+classification dataset.
+
+We use the `image_dataset_from_directory` utility to generate the datasets, and
+we use Keras image preprocessing layers for image standardization and data augmentation.
+"""
+
+"""
+## Setup
+"""
+
+import os
+import numpy as np
+import keras
+from keras import layers
+from tensorflow import data as tf_data
+import matplotlib.pyplot as plt
+
+"""
+## Load the data: the Cats vs Dogs dataset
+
+### Raw data download
+
+First, let's download the 786M ZIP archive of the raw data:
+"""
+
+"""shell
+curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
+"""
+
+"""shell
+unzip -q kagglecatsanddogs_5340.zip
+ls
+"""
+
+"""
+Now we have a `PetImages` folder which contains two subfolders, `Cat` and `Dog`. Each
+subfolder contains image files for each category.
+"""
+
+"""shell
+ls PetImages
+"""
+
+"""
+### Filter out corrupted images
+
+When working with lots of real-world image data, corrupted images are a common
+occurrence. Let's filter out badly-encoded images that do not feature the string "JFIF"
+in their header.
+"""
+
+num_skipped = 0
+for folder_name in ("Cat", "Dog"):
+ folder_path = os.path.join("PetImages", folder_name)
+ for fname in os.listdir(folder_path):
+ fpath = os.path.join(folder_path, fname)
+        with open(fpath, "rb") as fobj:
+            is_jfif = b"JFIF" in fobj.peek(10)
+
+ if not is_jfif:
+ num_skipped += 1
+ # Delete corrupted image
+ os.remove(fpath)
+
+print(f"Deleted {num_skipped} images.")
+
+"""
+## Generate a `Dataset`
+"""
+
+image_size = (180, 180)
+batch_size = 128
+
+train_ds, val_ds = keras.utils.image_dataset_from_directory(
+ "PetImages",
+ validation_split=0.2,
+ subset="both",
+ seed=1337,
+ image_size=image_size,
+ batch_size=batch_size,
+)
+
+"""
+## Visualize the data
+
+Here are the first 9 images in the training dataset.
+"""
+
+
+plt.figure(figsize=(10, 10))
+for images, labels in train_ds.take(1):
+ for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(np.array(images[i]).astype("uint8"))
+ plt.title(int(labels[i]))
+ plt.axis("off")
+
+"""
+## Using image data augmentation
+
+When you don't have a large image dataset, it's a good practice to artificially
+introduce sample diversity by applying random yet realistic transformations to the
+training images, such as random horizontal flipping or small random rotations. This
+helps expose the model to different aspects of the training data while slowing down
+overfitting.
+"""
+
+data_augmentation_layers = [
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(0.1),
+]
+
+
+def data_augmentation(images):
+ for layer in data_augmentation_layers:
+ images = layer(images)
+ return images
+
+
+"""
+Let's visualize what the augmented samples look like, by applying `data_augmentation`
+repeatedly to the first few images in the dataset:
+"""
+
+plt.figure(figsize=(10, 10))
+for images, _ in train_ds.take(1):
+ for i in range(9):
+ augmented_images = data_augmentation(images)
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(np.array(augmented_images[0]).astype("uint8"))
+ plt.axis("off")
+
+
+"""
+## Standardizing the data
+
+Our images are already at a standard size (180x180), as they are being yielded as
+contiguous `float32` batches by our dataset. However, their RGB channel values are in
+the `[0, 255]` range. This is not ideal for a neural network;
+in general you should seek to make your input values small. Here, we will
+standardize values to be in the `[0, 1]` range by using a `Rescaling` layer at the start of
+our model.
+"""
+
+"""
+## Two options to preprocess the data
+
+There are two ways you could be using the `data_augmentation` preprocessor:
+
+**Option 1: Make it part of the model**, like this:
+
+```python
+inputs = keras.Input(shape=input_shape)
+x = data_augmentation(inputs)
+x = layers.Rescaling(1./255)(x)
+... # Rest of the model
+```
+
+With this option, your data augmentation will happen *on device*, synchronously
+with the rest of the model execution, meaning that it will benefit from GPU
+acceleration.
+
+Note that data augmentation is inactive at test time, so the input samples will only be
+augmented during `fit()`, not when calling `evaluate()` or `predict()`.
+
+If you're training on GPU, this may be a good option.
+
+**Option 2: apply it to the dataset**, so as to obtain a dataset that yields batches of
+augmented images, like this:
+
+```python
+augmented_train_ds = train_ds.map(
+ lambda x, y: (data_augmentation(x, training=True), y))
+```
+
+With this option, your data augmentation will happen **on CPU**, asynchronously, and will
+be buffered before going into the model.
+
+If you're training on CPU, this is the better option, since it makes data augmentation
+asynchronous and non-blocking.
+
+In our case, we'll go with the second option. If you're not sure
+which one to pick, this second option (asynchronous preprocessing) is always a solid choice.
+"""
+
+"""
+## Configure the dataset for performance
+
+Let's apply data augmentation to our training dataset,
+and let's make sure to use buffered prefetching so we can yield data from disk without
+having I/O becoming blocking:
+"""
+
+# Apply `data_augmentation` to the training images.
+train_ds = train_ds.map(
+ lambda img, label: (data_augmentation(img), label),
+ num_parallel_calls=tf_data.AUTOTUNE,
+)
+# Prefetching samples in GPU memory helps maximize GPU utilization.
+train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
+val_ds = val_ds.prefetch(tf_data.AUTOTUNE)
+
+"""
+## Build a model
+
+We'll build a small version of the Xception network. We haven't particularly tried to
+optimize the architecture; if you want to do a systematic search for the best model
+configuration, consider using
+[KerasTuner](https://github.com/keras-team/keras-tuner).
+
+Note that:
+
+- We start the model with the `data_augmentation` preprocessor, followed by a
+ `Rescaling` layer.
+- We include a `Dropout` layer before the final classification layer.
+"""
+
+
+def make_model(input_shape, num_classes):
+ inputs = keras.Input(shape=input_shape)
+
+ # Entry block
+ x = layers.Rescaling(1.0 / 255)(inputs)
+ x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+ x = layers.Activation("relu")(x)
+
+ previous_block_activation = x # Set aside residual
+
+ for size in [256, 512, 728]:
+ x = layers.Activation("relu")(x)
+ x = layers.SeparableConv2D(size, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Activation("relu")(x)
+ x = layers.SeparableConv2D(size, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
+
+ # Project residual
+ residual = layers.Conv2D(size, 1, strides=2, padding="same")(
+ previous_block_activation
+ )
+ x = layers.add([x, residual]) # Add back residual
+ previous_block_activation = x # Set aside next residual
+
+ x = layers.SeparableConv2D(1024, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+ x = layers.Activation("relu")(x)
+
+ x = layers.GlobalAveragePooling2D()(x)
+ if num_classes == 2:
+ units = 1
+ else:
+ units = num_classes
+
+ x = layers.Dropout(0.25)(x)
+ # We specify activation=None so as to return logits
+ outputs = layers.Dense(units, activation=None)(x)
+ return keras.Model(inputs, outputs)
+
+
+model = make_model(input_shape=image_size + (3,), num_classes=2)
+keras.utils.plot_model(model, show_shapes=True)
+
+"""
+## Train the model
+"""
+
+epochs = 25
+
+callbacks = [
+ keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
+]
+model.compile(
+ optimizer=keras.optimizers.Adam(3e-4),
+ loss=keras.losses.BinaryCrossentropy(from_logits=True),
+ metrics=[keras.metrics.BinaryAccuracy(name="acc")],
+)
+model.fit(
+ train_ds,
+ epochs=epochs,
+ callbacks=callbacks,
+ validation_data=val_ds,
+)
+
+"""
+We get to >90% validation accuracy after training for 25 epochs on the full dataset
+(in practice, you can train for 50+ epochs before validation performance starts degrading).
+"""
+
+"""
+## Run inference on new data
+
+Note that data augmentation and dropout are inactive at inference time.
+"""
+
+img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
+plt.imshow(img)
+
+img_array = keras.utils.img_to_array(img)
+img_array = keras.ops.expand_dims(img_array, 0) # Create batch axis
+
+predictions = model.predict(img_array)
+score = float(keras.ops.sigmoid(predictions[0][0]))
+print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")
diff --git a/knowledge_base/vision/image_classification_using_global_context_vision_transformer.py b/knowledge_base/vision/image_classification_using_global_context_vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d0a3b78af2968aa38df2c612cd76578ac8b093f
--- /dev/null
+++ b/knowledge_base/vision/image_classification_using_global_context_vision_transformer.py
@@ -0,0 +1,1187 @@
+"""
+Title: Image Classification using Global Context Vision Transformer
+Author: Md Awsafur Rahman
+Date created: 2023/10/30
+Last modified: 2023/10/30
+Description: Implementation and fine-tuning of Global Context Vision Transformer for image classification.
+Accelerator: GPU
+"""
+
+"""
+# Setup
+"""
+
+"""shell
+pip install --upgrade keras_cv tensorflow
+pip install --upgrade keras
+"""
+
+import keras
+from keras_cv.layers import DropPath
+from keras import ops
+from keras import layers
+
+import tensorflow as tf # only for dataloader
+import tensorflow_datasets as tfds # for flower dataset
+
+from skimage.data import chelsea
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+## Introduction
+
+In this notebook, we will utilize multi-backend Keras 3.0 to implement the
+[**GCViT: Global Context Vision Transformer**](https://arxiv.org/abs/2206.09959) paper,
+presented at ICML 2023 by A Hatamizadeh et al. Then, we will fine-tune the model on the
+Flower dataset for an image classification task, leveraging the official ImageNet pre-trained
+weights. A highlight of this notebook is its compatibility with multiple backends:
+TensorFlow, PyTorch, and JAX, showcasing the true potential of multi-backend Keras.
+"""
+
+"""
+## Motivation
+
+> **Note:** In this section we'll learn about the backstory of GCViT and try to
+understand why it is proposed.
+
+* During recent years, **Transformers** have achieved dominance in **Natural Language
+Processing (NLP)** tasks, thanks to the **self-attention** mechanism, which allows for
+capturing both long-range and short-range information.
+* Following this trend, **Vision Transformer (ViT)** proposed to utilize image patches as
+tokens in a gigantic architecture similar to the encoder of the original Transformer.
+* Despite the historic dominance of **Convolutional Neural Network (CNN)** in computer
+vision, **ViT-based** models have shown **SOTA or competitive performance** in various
+computer vision tasks.
+
+
+* However, the **quadratic [`O(n^2)`] computational complexity** of self-attention and the **lack
+of multi-scale information** make it difficult for **ViT** to be considered a
+general-purpose architecture for Computer Vision tasks like **segmentation and object
+detection**, which require **dense prediction at the pixel level**.
+* Swin Transformer has attempted to address the issues of **ViT** by proposing
+**multi-resolution/hierarchical** architectures in which the self-attention is computed
+in **local windows** and cross-window connections such as **window shifting** are used
+for modeling the interactions across different regions. But the **limited receptive field
+of local windows** cannot capture long-range information, and cross-window-connection
+schemes such as **window-shifting** only cover a small neighborhood in the vicinity of
+each window. Also, it lacks the **inductive bias** that encourages a certain translation
+invariance, which is still preferable for general-purpose visual modeling, particularly for the
+dense prediction tasks of object detection and semantic segmentation.
+
+
+
+
+* To address above limitations, **Global Context (GC) ViT** network is proposed.
+"""
+
+"""
+## Architecture
+
+Let's have a quick **overview** of our key components,
+1. `Stem/PatchEmbed:` A stem/patchify layer processes images at the network's beginning.
+For this network, it creates **patches/tokens** and converts them into **embeddings**.
+2. `Level:` It is the repetitive building block that extracts features using different
+blocks.
+3. `Global Token Gen./FeatureExtraction:` It generates **global tokens/patches** with
+**Depthwise-CNN**, **SqueezeAndExcitation (Squeeze-Excitation)**, **CNN** and
+**MaxPooling**. So basically
+it's a Feature Extractor.
+4. `Block:` It is the repetitive module that applies attention to the features and
+projects them to a certain dimension.
+ 1. `Local-MSA:` Local Multi head Self Attention.
+ 2. `Global-MSA:` Global Multi head Self Attention.
+ 3. `MLP:` Linear layer that projects a vector to another dimension.
+5. `Downsample/ReduceSize:` It is very similar to **Global Token Gen.** module except it
+uses **CNN** instead of **MaxPooling** to downsample with additional **Layer
+Normalization** modules.
+6. `Head:` It is the module responsible for the classification task.
+ 1. `Pooling:` It converts `N x 2D` features to `N x 1D` features.
+ 2. `Classifier:` It processes `N x 1D` features to make a decision about class.
+
+I've annotated the architecture figure to make it easier to digest.
+
+"""
+
+"""
+### Unit Blocks
+
+> **Note:** These blocks are used to build other modules throughout the paper. Most of the
+blocks are either borrowed from other works or are modified versions of older works.
+
+1. `SqueezeAndExcitation`: The **Squeeze-Excitation (SE)**, aka **Bottleneck**, module acts as a
+kind of **channel
+attention**. It consists of **AvgPooling**, **Dense/FullyConnected (FC)/Linear**,
+**GELU** and **Sigmoid** modules.
+
+
+2. `Fused-MBConv:` This is similar to the one used in **EfficientNetV2**. It uses
+**Depthwise-Conv**, **GELU**, **SqueezeAndExcitation**, and **Conv** to extract features with
+a residual
+connection. Note that no new module is declared for this one; we simply apply the
+corresponding modules directly.
+
+
+3. `ReduceSize`: It is a **CNN**-based **downsample** module which uses the above-mentioned
+`Fused-MBConv` module to extract features, a **strided Conv** to simultaneously reduce the
+spatial dimension and increase the channel-wise dimension of the features, and finally a
+**LayerNormalization** module to normalize features. In the paper/figure this module is
+referred to as the **downsample** module. It is worth mentioning that **SwinTransformer**
+used a `PatchMerging` module instead of `ReduceSize` to reduce the spatial dimension and
+increase the channel-wise dimension, which uses a **fully-connected/dense/linear** module.
+According to the **GCViT** paper, one of the purposes of using `ReduceSize` is to add
+inductive bias through the **CNN** module.
+
+
+4. `MLP:` This is our very own **Multi Layer Perceptron** module. This is a
+feed-forward/fully-connected/linear module which simply projects the input to an arbitrary
+dimension.
+"""
+
+
+class SqueezeAndExcitation(layers.Layer):
+ """Squeeze and excitation block.
+
+ Args:
+ output_dim: output features dimension, if `None` use same dim as input.
+ expansion: expansion ratio.
+ """
+
+ def __init__(self, output_dim=None, expansion=0.25, **kwargs):
+ super().__init__(**kwargs)
+ self.expansion = expansion
+ self.output_dim = output_dim
+
+ def build(self, input_shape):
+ inp = input_shape[-1]
+ self.output_dim = self.output_dim or inp
+ self.avg_pool = layers.GlobalAvgPool2D(keepdims=True, name="avg_pool")
+ self.fc = [
+ layers.Dense(int(inp * self.expansion), use_bias=False, name="fc_0"),
+ layers.Activation("gelu", name="fc_1"),
+ layers.Dense(self.output_dim, use_bias=False, name="fc_2"),
+ layers.Activation("sigmoid", name="fc_3"),
+ ]
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ x = self.avg_pool(inputs)
+ for layer in self.fc:
+ x = layer(x)
+ return x * inputs
+
+
+class ReduceSize(layers.Layer):
+ """Down-sampling block.
+
+ Args:
+ keepdims: if False spatial dim is reduced and channel dim is increased
+ """
+
+ def __init__(self, keepdims=False, **kwargs):
+ super().__init__(**kwargs)
+ self.keepdims = keepdims
+
+ def build(self, input_shape):
+ embed_dim = input_shape[-1]
+ dim_out = embed_dim if self.keepdims else 2 * embed_dim
+ self.pad1 = layers.ZeroPadding2D(1, name="pad1")
+ self.pad2 = layers.ZeroPadding2D(1, name="pad2")
+ self.conv = [
+ layers.DepthwiseConv2D(
+ kernel_size=3, strides=1, padding="valid", use_bias=False, name="conv_0"
+ ),
+ layers.Activation("gelu", name="conv_1"),
+ SqueezeAndExcitation(name="conv_2"),
+ layers.Conv2D(
+ embed_dim,
+ kernel_size=1,
+ strides=1,
+ padding="valid",
+ use_bias=False,
+ name="conv_3",
+ ),
+ ]
+ self.reduction = layers.Conv2D(
+ dim_out,
+ kernel_size=3,
+ strides=2,
+ padding="valid",
+ use_bias=False,
+ name="reduction",
+ )
+ self.norm1 = layers.LayerNormalization(
+ -1, 1e-05, name="norm1"
+ ) # eps like PyTorch
+ self.norm2 = layers.LayerNormalization(-1, 1e-05, name="norm2")
+
+ def call(self, inputs, **kwargs):
+ x = self.norm1(inputs)
+ xr = self.pad1(x)
+ for layer in self.conv:
+ xr = layer(xr)
+ x = x + xr
+ x = self.pad2(x)
+ x = self.reduction(x)
+ x = self.norm2(x)
+ return x
+
+
+class MLP(layers.Layer):
+ """Multi-Layer Perceptron (MLP) block.
+
+ Args:
+ hidden_features: hidden features dimension.
+ out_features: output features dimension.
+ activation: activation function.
+ dropout: dropout rate.
+ """
+
+ def __init__(
+ self,
+ hidden_features=None,
+ out_features=None,
+ activation="gelu",
+ dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_features = hidden_features
+ self.out_features = out_features
+ self.activation = activation
+ self.dropout = dropout
+
+ def build(self, input_shape):
+ self.in_features = input_shape[-1]
+ self.hidden_features = self.hidden_features or self.in_features
+ self.out_features = self.out_features or self.in_features
+ self.fc1 = layers.Dense(self.hidden_features, name="fc1")
+ self.act = layers.Activation(self.activation, name="act")
+ self.fc2 = layers.Dense(self.out_features, name="fc2")
+ self.drop1 = layers.Dropout(self.dropout, name="drop1")
+ self.drop2 = layers.Dropout(self.dropout, name="drop2")
+
+ def call(self, inputs, **kwargs):
+ x = self.fc1(inputs)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
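+
+
+"""
+As a quick, illustrative sanity check of the three blocks above (the `*_demo` names are
+hypothetical): `SqueezeAndExcitation` only re-weights channels, so it preserves the
+input shape; `ReduceSize` with `keepdims=False` halves the spatial size and doubles the
+channels; `MLP` projects the last axis back to the input dimension by default.
+"""
+
+dummy = np.zeros((1, 56, 56, 64), dtype="float32")
+print(SqueezeAndExcitation(name="se_demo")(dummy).shape)  # (1, 56, 56, 64)
+print(ReduceSize(keepdims=False, name="reduce_demo")(dummy).shape)  # (1, 28, 28, 128)
+print(MLP(name="mlp_demo")(np.zeros((1, 49, 64), dtype="float32")).shape)  # (1, 49, 64)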
+
+
+"""
+### Stem
+
+> **Note:** In the code, this module is referred to as **PatchEmbed** but in the paper,
+it is referred to as **Stem**.
+
+In the model, the first module we use is `patch_embed`. Let's try to understand this
+module. As we can see from the `call` method,
+1. This module first **pads** the input.
+2. Then it uses **convolutions** to extract patches with embeddings.
+3. Finally, it uses the `ReduceSize` module (with `keepdims=True`) to further extract
+features with **convolution**; this halves the spatial dimension again but does not
+increase the channel-wise dimension.
+4. One important point to notice: unlike **ViT** or **SwinTransformer**, **GCViT**
+creates **overlapping patches**. We can notice that from the code,
+`Conv2D(self.embed_dim, kernel_size=3, strides=2, name='proj')`. If we wanted
+**non-overlapping** patches we would've used the same `kernel_size` and `strides`.
+5. This module reduces the spatial dimension of the input by `4x` (see the quick check
+after the class below).
+> Summary: image → padding → convolution → (feature_extraction + downsample)
+"""
+
+
+class PatchEmbed(layers.Layer):
+ """Patch embedding block.
+
+ Args:
+ embed_dim: feature size dimension.
+ """
+
+ def __init__(self, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ def build(self, input_shape):
+ self.pad = layers.ZeroPadding2D(1, name="pad")
+ self.proj = layers.Conv2D(self.embed_dim, 3, 2, name="proj")
+ self.conv_down = ReduceSize(keepdims=True, name="conv_down")
+
+ def call(self, inputs, **kwargs):
+ x = self.pad(inputs)
+ x = self.proj(x)
+ x = self.conv_down(x)
+ return x
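+
+
+"""
+An illustrative check of the `4x` spatial reduction (the stride-2 `proj` convolution
+followed by the stride-2 reduction inside `ReduceSize`); `patch_demo` is a hypothetical
+name:
+"""
+
+patch_demo = PatchEmbed(embed_dim=64, name="patch_demo")
+print(patch_demo(np.zeros((1, 224, 224, 3), dtype="float32")).shape)  # (1, 56, 56, 64)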
+
+
+"""
+### Global Token Gen.
+
+> **Note:** This is one of the two **CNN** modules that are used to impose inductive
+bias.
+
+As we can see from the cell above, in the `level` we first use `to_q_global/Global
+Token Gen./FeatureExtraction`. Let's try to understand how it works,
+
+* This module is a series of `FeatureExtraction` modules. According to the paper, we
+need to repeat this module `K` times, where `K = log2(H/h)`, `H = feature_map_height`
+and `h = window_height` (similarly for the width).
+* `FeatureExtraction:` This layer is very similar to the `ReduceSize` module, except it
+uses a **MaxPooling** module to reduce the dimension, it doesn't increase the feature
+(channel-wise) dimension, and it doesn't use **LayerNormalization**. This module is
+used repeatedly in the `Global Token Gen.` module to generate **global tokens** for
+**global-context-attention**.
+* One important point to notice from the figure is that the **global tokens** are
+shared across the whole image, which means we use only **one global window** for **all
+local tokens** in an image. This makes the computation very efficient.
+* For an input feature map with shape `(B, H, W, C)`, we'll get an output shape of
+`(B, h, w, C)`. If we copy these global tokens for a total of `M` local windows in an
+image, where `M = (H x W)/(h x w) = num_window`, then the output shape is
+`(B * M, h, w, C)`.
+
+> Summary: This module is used to `resize` the image to fit the window (see the
+illustrative check after the `GlobalQueryGenerator` class below).
+
+
+"""
+
+
+class FeatureExtraction(layers.Layer):
+ """Feature extraction block.
+
+ Args:
+ keepdims: bool argument for maintaining the resolution.
+ """
+
+ def __init__(self, keepdims=False, **kwargs):
+ super().__init__(**kwargs)
+ self.keepdims = keepdims
+
+ def build(self, input_shape):
+ embed_dim = input_shape[-1]
+ self.pad1 = layers.ZeroPadding2D(1, name="pad1")
+ self.pad2 = layers.ZeroPadding2D(1, name="pad2")
+ self.conv = [
+ layers.DepthwiseConv2D(3, 1, use_bias=False, name="conv_0"),
+ layers.Activation("gelu", name="conv_1"),
+ SqueezeAndExcitation(name="conv_2"),
+ layers.Conv2D(embed_dim, 1, 1, use_bias=False, name="conv_3"),
+ ]
+ if not self.keepdims:
+ self.pool = layers.MaxPool2D(3, 2, name="pool")
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ x = inputs
+ xr = self.pad1(x)
+ for layer in self.conv:
+ xr = layer(xr)
+ x = x + xr
+ if not self.keepdims:
+ x = self.pool(self.pad2(x))
+ return x
+
+
+class GlobalQueryGenerator(layers.Layer):
+ """Global query generator.
+
+ Args:
+        keepdims: to keep the dimension of FeatureExtraction layer.
+            For instance, repeating log2(56/7) = 3 blocks, with input
+            window dimension 56 and output window dimension 7 at down-sampling
+            ratio 2. Please check Fig. 5 of the GC ViT paper for details.
+ """
+
+ def __init__(self, keepdims=False, **kwargs):
+ super().__init__(**kwargs)
+ self.keepdims = keepdims
+
+ def build(self, input_shape):
+ self.to_q_global = [
+ FeatureExtraction(keepdims, name=f"to_q_global_{i}")
+ for i, keepdims in enumerate(self.keepdims)
+ ]
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ x = inputs
+ for layer in self.to_q_global:
+ x = layer(x)
+ return x
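+
+
+"""
+An illustrative run of the generator (`q_gen_demo` is a hypothetical name): with
+`keepdims=(0, 0, 0)` we stack `K = log2(56/7) = 3` stride-2 `FeatureExtraction`
+blocks, shrinking a `56 x 56` feature map down to the `7 x 7` window size while
+keeping the channel dimension.
+"""
+
+q_gen_demo = GlobalQueryGenerator(keepdims=(0, 0, 0), name="q_gen_demo")
+print(q_gen_demo(np.zeros((1, 56, 56, 64), dtype="float32")).shape)  # (1, 7, 7, 64)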
+
+
+"""
+### Attention
+
+> **Notes:** This is the core contribution of the paper.
+
+As we can see from the `call` method,
+1. The `WindowAttention` module applies either **local** or **global** window attention
+depending on the `global_query` parameter.
+
+2. First, it converts input features into `query, key, value` for local attention and
+`key, value` for global attention. For global attention, it takes the global query from
+`Global Token Gen.`. One thing to notice from the code is that we divide the **features
+or embed_dim** among all the **heads of the Transformer** to reduce the computation.
+`qkv = ops.reshape(qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads])`
+3. Before sending the query, key and value for attention, the **global token** goes
+through an important process: the same global tokens, i.e. one global window, get
+copied for all the local windows to increase efficiency.
+`q_global = ops.repeat(q_global, repeats=B_//B, axis=0)`, here `B_//B` means
+`num_windows` in an image.
+4. Then it simply applies `local-window-self-attention` or `global-window-attention`
+depending on the `global_query` parameter. One thing to notice from the code is that we
+add the **relative-positional-embedding** to the **attention scores** instead of to the
+**patch embeddings**.
+`attn = attn + relative_position_bias[None,]`
+
+5. Now, let's think for a bit and try to understand what is happening here. Let's focus
+on the figure below. We can see from the left that in **local-attention** the
+**query is local** and it's **limited to the local window** (red square border), hence
+we don't have access to long-range information. But on the right, due to the **global
+query**, we're now **not limited to local windows** (blue square border) and we have
+access to long-range information.
+
+6. In **ViT** we compare (attend) image-tokens with image-tokens, in
+**SwinTransformer** we compare window-tokens with window-tokens, but in **GCViT** we
+compare image-tokens with window-tokens. But now you may ask, how can we compare
+(attend) image-tokens with window-tokens when image-tokens have larger dimensions than
+window-tokens? (From the figure above, image-tokens have shape `(1, 8, 8, 3)` and
+window-tokens have shape `(1, 4, 4, 3)`.) Yes, you are right, we can't directly compare
+them, hence we resize image-tokens to fit window-tokens with the `Global Token
+Gen./FeatureExtraction` **CNN** module. The following table should give you a clear
+comparison,
+
+| Model | Query Tokens | Key-Value Tokens | Attention Type | Attention Coverage |
+|------------------|-----------------|-------------------|---------------------------|--------------------|
+| ViT | image | image | self-attention | global |
+| SwinTransformer | window | window | self-attention | local |
+| **GCViT** | **resized-image** | **window** | **image-window attention** | **global** |
+
+"""
+
+
+class WindowAttention(layers.Layer):
+ """Local window attention.
+
+ This implementation was proposed by
+ [Liu et al., 2021](https://arxiv.org/abs/2103.14030) in SwinTransformer.
+
+ Args:
+        window_size: window size.
+        num_heads: number of attention heads.
+        global_query: if True, the input also contains a global query.
+        qkv_bias: bool argument for query, key, value learnable bias.
+        qk_scale: overrides the default query-key scale of `head_dim ** -0.5` if set.
+        attention_dropout: attention dropout rate.
+        projection_dropout: output dropout rate.
+ """
+
+ def __init__(
+ self,
+ window_size,
+ num_heads,
+ global_query,
+ qkv_bias=True,
+ qk_scale=None,
+ attention_dropout=0.0,
+ projection_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ window_size = (window_size, window_size)
+ self.window_size = window_size
+ self.num_heads = num_heads
+ self.global_query = global_query
+ self.qkv_bias = qkv_bias
+ self.qk_scale = qk_scale
+ self.attention_dropout = attention_dropout
+ self.projection_dropout = projection_dropout
+
+ def build(self, input_shape):
+ embed_dim = input_shape[0][-1]
+ head_dim = embed_dim // self.num_heads
+ self.scale = self.qk_scale or head_dim**-0.5
+        self.qkv_size = 3 - int(self.global_query)  # 3 (q, k, v) for local, 2 (k, v) for global
+ self.qkv = layers.Dense(
+ embed_dim * self.qkv_size, use_bias=self.qkv_bias, name="qkv"
+ )
+ self.relative_position_bias_table = self.add_weight(
+ name="relative_position_bias_table",
+ shape=[
+ (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
+ self.num_heads,
+ ],
+ initializer=keras.initializers.TruncatedNormal(stddev=0.02),
+ trainable=True,
+ dtype=self.dtype,
+ )
+ self.attn_drop = layers.Dropout(self.attention_dropout, name="attn_drop")
+ self.proj = layers.Dense(embed_dim, name="proj")
+ self.proj_drop = layers.Dropout(self.projection_dropout, name="proj_drop")
+ self.softmax = layers.Activation("softmax", name="softmax")
+ super().build(input_shape)
+
+    def get_relative_position_index(self):
+        # Build a (window_size^2, window_size^2) table of relative-position
+        # indices, Swin-style: for every pair of tokens in a window, compute
+        # their 2D coordinate offset and map it to a unique index into
+        # `relative_position_bias_table`.
+        coords_h = ops.arange(self.window_size[0])
+        coords_w = ops.arange(self.window_size[1])
+        coords = ops.stack(ops.meshgrid(coords_h, coords_w, indexing="ij"), axis=0)
+        coords_flatten = ops.reshape(coords, [2, -1])
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = ops.transpose(relative_coords, axes=[1, 2, 0])
+        # Shift offsets to start from 0, then collapse the 2D offset into a
+        # single flat index.
+        relative_coords_xx = relative_coords[:, :, 0] + self.window_size[0] - 1
+        relative_coords_yy = relative_coords[:, :, 1] + self.window_size[1] - 1
+        relative_coords_xx = relative_coords_xx * (2 * self.window_size[1] - 1)
+        relative_position_index = relative_coords_xx + relative_coords_yy
+        return relative_position_index
+
+ def call(self, inputs, **kwargs):
+ if self.global_query:
+ inputs, q_global = inputs
+ B = ops.shape(q_global)[0] # B, N, C
+ else:
+ inputs = inputs[0]
+ B_, N, C = ops.shape(inputs) # B*num_window, num_tokens, channels
+ qkv = self.qkv(inputs)
+ qkv = ops.reshape(
+ qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads]
+ )
+ qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])
+ if self.global_query:
+            k, v = ops.split(
+                qkv, indices_or_sections=2, axis=0
+            )  # for an unknown shape, num=None would throw an error
+            q_global = ops.repeat(
+                q_global, repeats=B_ // B, axis=0
+            )  # num_windows = B_ // B, so q_global is shared by all windows in an image
+ q = ops.reshape(q_global, [B_, N, self.num_heads, C // self.num_heads])
+ q = ops.transpose(q, axes=[0, 2, 1, 3])
+ else:
+ q, k, v = ops.split(qkv, indices_or_sections=3, axis=0)
+ q = ops.squeeze(q, axis=0)
+
+ k = ops.squeeze(k, axis=0)
+ v = ops.squeeze(v, axis=0)
+
+ q = q * self.scale
+ attn = q @ ops.transpose(k, axes=[0, 1, 3, 2])
+        relative_position_bias = ops.take(
+            self.relative_position_bias_table,
+            ops.reshape(self.get_relative_position_index(), [-1]),
+            axis=0,  # gather one row of the bias table per relative position
+        )
+ relative_position_bias = ops.reshape(
+ relative_position_bias,
+ [
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1,
+ ],
+ )
+ relative_position_bias = ops.transpose(relative_position_bias, axes=[2, 0, 1])
+ attn = attn + relative_position_bias[None,]
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn)
+
+ x = ops.transpose((attn @ v), axes=[0, 2, 1, 3])
+ x = ops.reshape(x, [B_, N, C])
+ x = self.proj_drop(self.proj(x))
+ return x
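+
+
+"""
+A minimal, illustrative run of local window attention (`attn_demo` is a hypothetical
+name): four `7 x 7` windows of 64-dim tokens go in, and the shape is preserved.
+"""
+
+attn_demo = WindowAttention(window_size=7, num_heads=4, global_query=False, name="attn_demo")
+print(attn_demo([np.zeros((4, 49, 64), dtype="float32")]).shape)  # (4, 49, 64)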
+
+
+"""
+### Block
+
+> **Note:** This module doesn't have any convolutional module.
+
+In the `level`, the second module that we use is `block`. Let's try to understand how
+it works. As we can see from the `call` method,
+1. The `Block` module takes either only feature maps (for local attention) or feature
+maps plus an additional global query (for global attention).
+2. Before sending the feature maps for attention, this module converts **batch feature
+maps** to **batch windows**, as we'll be applying **Window Attention**.
+3. Then we send the **batch windows** for attention.
+4. After attention has been applied, we revert the **batch windows** back to **batch
+feature maps**.
+5. Before producing the output, this module applies **Stochastic Depth** regularization
+in the residual connection. Also, before applying **Stochastic Depth**, it rescales the
+input with trainable parameters. Note that this **Stochastic Depth** block isn't shown
+in the figure of the paper.
+
+
+### Window
+In the `block` module, we have created **windows** before and after applying attention.
+Let's try to understand how we're creating windows,
+* The following module converts feature maps `(B, H, W, C)` to stacked windows
+`(B x H/h x W/w, h, w, C)` → `(num_windows_batch, window_size, window_size, channel)`.
+* This module uses `reshape` & `transpose` to create these windows out of the image
+instead of iterating over them. A standalone sketch of this trick is shown right after
+the `Block` class below.
+"""
+
+
+class Block(layers.Layer):
+ """GCViT block.
+
+ Args:
+        window_size: window size.
+        num_heads: number of attention heads.
+        global_query: if True, apply global window attention.
+        mlp_ratio: MLP ratio.
+        qkv_bias: bool argument for query, key, value learnable bias.
+        qk_scale: overrides the default query-key scale of `head_dim ** -0.5` if set.
+        dropout: dropout rate.
+        attention_dropout: attention dropout rate.
+        path_drop: drop path rate.
+        activation: activation function.
+        layer_scale: layer scaling coefficient.
+ """
+
+ def __init__(
+ self,
+ window_size,
+ num_heads,
+ global_query,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ dropout=0.0,
+ attention_dropout=0.0,
+ path_drop=0.0,
+ activation="gelu",
+ layer_scale=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.window_size = window_size
+ self.num_heads = num_heads
+ self.global_query = global_query
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.qk_scale = qk_scale
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.path_drop = path_drop
+ self.activation = activation
+ self.layer_scale = layer_scale
+
+ def build(self, input_shape):
+ B, H, W, C = input_shape[0]
+ self.norm1 = layers.LayerNormalization(-1, 1e-05, name="norm1")
+ self.attn = WindowAttention(
+ window_size=self.window_size,
+ num_heads=self.num_heads,
+ global_query=self.global_query,
+ qkv_bias=self.qkv_bias,
+ qk_scale=self.qk_scale,
+ attention_dropout=self.attention_dropout,
+ projection_dropout=self.dropout,
+ name="attn",
+ )
+ self.drop_path1 = DropPath(self.path_drop)
+ self.drop_path2 = DropPath(self.path_drop)
+ self.norm2 = layers.LayerNormalization(-1, 1e-05, name="norm2")
+ self.mlp = MLP(
+ hidden_features=int(C * self.mlp_ratio),
+ dropout=self.dropout,
+ activation=self.activation,
+ name="mlp",
+ )
+ if self.layer_scale is not None:
+ self.gamma1 = self.add_weight(
+ name="gamma1",
+ shape=[C],
+ initializer=keras.initializers.Constant(self.layer_scale),
+ trainable=True,
+ dtype=self.dtype,
+ )
+ self.gamma2 = self.add_weight(
+ name="gamma2",
+ shape=[C],
+ initializer=keras.initializers.Constant(self.layer_scale),
+ trainable=True,
+ dtype=self.dtype,
+ )
+ else:
+ self.gamma1 = 1.0
+ self.gamma2 = 1.0
+ self.num_windows = int(H // self.window_size) * int(W // self.window_size)
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ if self.global_query:
+ inputs, q_global = inputs
+ else:
+ inputs = inputs[0]
+ B, H, W, C = ops.shape(inputs)
+ x = self.norm1(inputs)
+ # create windows and concat them in batch axis
+ x = self.window_partition(x, self.window_size) # (B_, win_h, win_w, C)
+ # flatten patch
+ x = ops.reshape(x, [-1, self.window_size * self.window_size, C])
+ # attention
+ if self.global_query:
+ x = self.attn([x, q_global])
+ else:
+ x = self.attn([x])
+ # reverse window partition
+ x = self.window_reverse(x, self.window_size, H, W, C)
+ # FFN
+ x = inputs + self.drop_path1(x * self.gamma1)
+ x = x + self.drop_path2(self.gamma2 * self.mlp(self.norm2(x)))
+ return x
+
+ def window_partition(self, x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size: window size
+ Returns:
+ local window features (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = ops.shape(x)
+ x = ops.reshape(
+ x,
+ [
+ -1,
+ H // window_size,
+ window_size,
+ W // window_size,
+ window_size,
+ C,
+ ],
+ )
+ x = ops.transpose(x, axes=[0, 1, 3, 2, 4, 5])
+ windows = ops.reshape(x, [-1, window_size, window_size, C])
+ return windows
+
+ def window_reverse(self, windows, window_size, H, W, C):
+ """
+ Args:
+ windows: local window features (num_windows*B, window_size, window_size, C)
+ window_size: Window size
+ H: Height of image
+ W: Width of image
+ C: Channel of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ x = ops.reshape(
+ windows,
+ [
+ -1,
+ H // window_size,
+ W // window_size,
+ window_size,
+ window_size,
+ C,
+ ],
+ )
+ x = ops.transpose(x, axes=[0, 1, 3, 2, 4, 5])
+ x = ops.reshape(x, [-1, H, W, C])
+ return x
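+
+
+"""
+A standalone, illustrative version of the partition trick described above: an `8 x 8`
+feature map becomes four `4 x 4` windows stacked along the batch axis, using only
+`reshape` and `transpose` (no loops).
+"""
+
+fmap = ops.reshape(ops.arange(8 * 8 * 3, dtype="float32"), (1, 8, 8, 3))
+win = 4
+wins = ops.reshape(fmap, (-1, 8 // win, win, 8 // win, win, 3))
+wins = ops.transpose(wins, (0, 1, 3, 2, 4, 5))
+print(ops.reshape(wins, (-1, win, win, 3)).shape)  # (4, 4, 4, 3)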
+
+
+"""
+### Level
+
+> **Note:** This module has both Transformer and CNN modules.
+
+In the model, the second module that we use is `level`. Let's try to understand this
+module. As we can see from the `call` method,
+1. First, it creates a **global_token** with a series of `FeatureExtraction` modules.
+As we saw earlier, `FeatureExtraction` is nothing but a simple **CNN**-based module.
+2. Then it uses a series of `Block` modules to apply **local or global window
+attention**, depending on the depth index (see the pattern printed after the `Level`
+class below).
+3. Finally, it uses `ReduceSize` to reduce the dimension of the **contextualized
+features**.
+
+> Summary: feature_map → global_token → local/global window attention → downsample
+
+
+"""
+
+
+class Level(layers.Layer):
+ """GCViT level.
+
+ Args:
+ depth: number of layers in each stage.
+ num_heads: number of heads in each stage.
+ window_size: window size in each stage.
+ keepdims: dims to keep in FeatureExtraction.
+ downsample: bool argument for down-sampling.
+ mlp_ratio: MLP ratio.
+        qkv_bias: bool argument for query, key, value learnable bias.
+        qk_scale: overrides the default query-key scale of `head_dim ** -0.5` if set.
+        dropout: dropout rate.
+ attention_dropout: attention dropout rate.
+ path_drop: drop path rate.
+ layer_scale: layer scaling coefficient.
+ """
+
+ def __init__(
+ self,
+ depth,
+ num_heads,
+ window_size,
+ keepdims,
+ downsample=True,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ dropout=0.0,
+ attention_dropout=0.0,
+ path_drop=0.0,
+ layer_scale=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.depth = depth
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.keepdims = keepdims
+ self.downsample = downsample
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.qk_scale = qk_scale
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.path_drop = path_drop
+ self.layer_scale = layer_scale
+
+ def build(self, input_shape):
+ path_drop = (
+ [self.path_drop] * self.depth
+ if not isinstance(self.path_drop, list)
+ else self.path_drop
+ )
+ self.blocks = [
+ Block(
+ window_size=self.window_size,
+ num_heads=self.num_heads,
+ global_query=bool(i % 2),
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ qk_scale=self.qk_scale,
+ dropout=self.dropout,
+ attention_dropout=self.attention_dropout,
+ path_drop=path_drop[i],
+ layer_scale=self.layer_scale,
+ name=f"blocks_{i}",
+ )
+ for i in range(self.depth)
+ ]
+ self.down = ReduceSize(keepdims=False, name="downsample")
+ self.q_global_gen = GlobalQueryGenerator(self.keepdims, name="q_global_gen")
+ super().build(input_shape)
+
+ def call(self, inputs, **kwargs):
+ x = inputs
+ q_global = self.q_global_gen(x) # shape: (B, win_size, win_size, C)
+ for i, blk in enumerate(self.blocks):
+ if i % 2:
+ x = blk([x, q_global]) # shape: (B, H, W, C)
+ else:
+ x = blk([x]) # shape: (B, H, W, C)
+ if self.downsample:
+ x = self.down(x) # shape: (B, H//2, W//2, 2*C)
+ return x
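+
+
+"""
+To make the local/global alternation concrete (`global_query=bool(i % 2)` in the
+`build` method above): even-indexed blocks run plain local window attention and
+odd-indexed blocks consume the shared global query. For a stage of depth 6 the
+pattern is:
+"""
+
+print(["global" if i % 2 else "local" for i in range(6)])
+# ['local', 'global', 'local', 'global', 'local', 'global']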
+
+
+"""
+### Model
+
+Let's jump directly to the model. As we can see from the `call` method,
+1. It creates patch embeddings from an image. This layer doesn't flatten these
+embeddings, which means the output of this module will be
+`(batch, height/window_size, width/window_size, embed_dim)` instead of
+`(batch, height x width/window_size^2, embed_dim)`.
+2. Then it applies a `Dropout` module, which randomly sets input units to 0.
+3. It passes these embeddings to a series of `Level` modules, which we call `level`,
+where,
+    1. Global token is generated.
+    1. Both local & global attention are applied.
+    1. Finally, downsampling is applied.
+4. So, the output after `n` **levels** has shape
+`(batch, height/(4 x 2^{n-1}), width/(4 x 2^{n-1}), embed_dim x 2^{n-1})`, since the
+stem reduces the input by `4x` and each downsampling level halves it again while
+doubling the channels. In the last level, the paper uses neither **downsampling** nor a
+**channel** increase.
+5. The output of the above layer is normalized using a `LayerNormalization` module.
+6. In the head, the 2D features are converted to 1D features with a `Pooling` module.
+The output shape after this module is `(batch, embed_dim x 2^{n-1})`.
+7. Finally, the pooled features are sent to a `Dense/Linear` module for classification.
+
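+For example, with the **GCViT-XXTiny** configuration used below (`224 x 224 x 3` input,
+`embed_dim=64`, 4 levels), the feature map evolves as `56 x 56 x 64` → `28 x 28 x 128`
+→ `14 x 14 x 256` → `7 x 7 x 512` (no downsampling in the last level), and pooling
+yields a `512`-dim vector.
+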
+> Summary: image → (patches + embedding) → dropout
+→ (attention + feature extraction) → normalization →
+pooling → classify
+"""
+
+
+class GCViT(keras.Model):
+ """GCViT model.
+
+ Args:
+ window_size: window size in each stage.
+ embed_dim: feature size dimension.
+ depths: number of layers in each stage.
+ num_heads: number of heads in each stage.
+ drop_rate: dropout rate.
+ mlp_ratio: MLP ratio.
+        qkv_bias: bool argument for query, key, value learnable bias.
+        qk_scale: overrides the default query-key scale of `head_dim ** -0.5` if set.
+        attention_dropout: attention dropout rate.
+ path_drop: drop path rate.
+ layer_scale: layer scaling coefficient.
+ num_classes: number of classes.
+ head_activation: activation function for head.
+ """
+
+ def __init__(
+ self,
+ window_size,
+ embed_dim,
+ depths,
+ num_heads,
+ drop_rate=0.0,
+ mlp_ratio=3.0,
+ qkv_bias=True,
+ qk_scale=None,
+ attention_dropout=0.0,
+ path_drop=0.1,
+ layer_scale=None,
+ num_classes=1000,
+ head_activation="softmax",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.window_size = window_size
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.num_heads = num_heads
+ self.drop_rate = drop_rate
+ self.mlp_ratio = mlp_ratio
+ self.qkv_bias = qkv_bias
+ self.qk_scale = qk_scale
+ self.attention_dropout = attention_dropout
+ self.path_drop = path_drop
+ self.layer_scale = layer_scale
+ self.num_classes = num_classes
+ self.head_activation = head_activation
+
+ self.patch_embed = PatchEmbed(embed_dim=embed_dim, name="patch_embed")
+ self.pos_drop = layers.Dropout(drop_rate, name="pos_drop")
+ path_drops = np.linspace(0.0, path_drop, sum(depths))
+ keepdims = [(0, 0, 0), (0, 0), (1,), (1,)]
+ self.levels = []
+ for i in range(len(depths)):
+ path_drop = path_drops[sum(depths[:i]) : sum(depths[: i + 1])].tolist()
+ level = Level(
+ depth=depths[i],
+ num_heads=num_heads[i],
+ window_size=window_size[i],
+ keepdims=keepdims[i],
+ downsample=(i < len(depths) - 1),
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ dropout=drop_rate,
+ attention_dropout=attention_dropout,
+ path_drop=path_drop,
+ layer_scale=layer_scale,
+ name=f"levels_{i}",
+ )
+ self.levels.append(level)
+ self.norm = layers.LayerNormalization(axis=-1, epsilon=1e-05, name="norm")
+ self.pool = layers.GlobalAvgPool2D(name="pool")
+ self.head = layers.Dense(num_classes, name="head", activation=head_activation)
+
+ def build(self, input_shape):
+ super().build(input_shape)
+ self.built = True
+
+ def call(self, inputs, **kwargs):
+ x = self.patch_embed(inputs) # shape: (B, H, W, C)
+ x = self.pos_drop(x)
+ for level in self.levels:
+ x = level(x) # shape: (B, H_, W_, C_)
+ x = self.norm(x)
+ x = self.pool(x) # shape: (B, C__)
+ x = self.head(x)
+ return x
+
+ def build_graph(self, input_shape=(224, 224, 3)):
+ """
+ ref: https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam
+ """
+ x = keras.Input(shape=input_shape)
+ return keras.Model(inputs=[x], outputs=self.call(x), name=self.name)
+
+ def summary(self, input_shape=(224, 224, 3)):
+ return self.build_graph(input_shape).summary()
+
+
+"""
+## Build Model
+
+* Let's build a complete model with all the modules that we've explained above. We'll
+build the **GCViT-XXTiny** model with the configuration mentioned in the paper.
+* We'll also load the ported official **pre-trained** weights and try some
+predictions.
+"""
+
+# Model Configs
+config = {
+ "window_size": (7, 7, 14, 7),
+ "embed_dim": 64,
+ "depths": (2, 2, 6, 2),
+ "num_heads": (2, 4, 8, 16),
+ "mlp_ratio": 3.0,
+ "path_drop": 0.2,
+}
+ckpt_link = (
+ "https://github.com/awsaf49/gcvit-tf/releases/download/v1.1.6/gcvitxxtiny.keras"
+)
+
+# Build Model
+model = GCViT(**config)
+inp = ops.array(np.random.uniform(size=(1, 224, 224, 3)))
+out = model(inp)
+
+# Load Weights
+ckpt_path = keras.utils.get_file(ckpt_link.split("/")[-1], ckpt_link)
+model.load_weights(ckpt_path)
+
+# Summary
+model.summary((224, 224, 3))
+
+"""
+## Sanity check for Pre-Trained Weights
+"""
+
+img = keras.applications.imagenet_utils.preprocess_input(
+ chelsea(), mode="torch"
+) # Chelsea the cat
+img = ops.image.resize(img, (224, 224))[None,] # resize & create batch
+pred = model(img)
+pred_dec = keras.applications.imagenet_utils.decode_predictions(pred)[0]
+
+print("\n# Image:")
+plt.figure(figsize=(6, 6))
+plt.imshow(chelsea())
+plt.show()
+print()
+
+print("# Prediction (Top 5):")
+for i in range(5):
+ print("{:<12} : {:0.2f}".format(pred_dec[i][1], pred_dec[i][2]))
+
+"""
+# Fine-tune **GCViT** Model
+
+In the following cells, we will fine-tune a **GCViT** model on the Flower Dataset,
+which consists of `104` classes.
+"""
+
+"""
+### Configs
+"""
+
+# Model
+IMAGE_SIZE = (224, 224)
+
+# Hyper Params
+BATCH_SIZE = 32
+EPOCHS = 5
+
+# Dataset
+CLASSES = [
+ "dandelion",
+ "daisy",
+ "tulips",
+ "sunflowers",
+ "roses",
+] # don't change the order
+
+# Other constants
+MEAN = 255 * np.array([0.485, 0.456, 0.406], dtype="float32") # imagenet mean
+STD = 255 * np.array([0.229, 0.224, 0.225], dtype="float32") # imagenet std
+AUTO = tf.data.AUTOTUNE
+
+"""
+## Data Loader
+"""
+
+
+def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: tuple = IMAGE_SIZE):
+ def preprocess(image, label):
+ # for training, do augmentation
+ if train:
+ if tf.random.uniform(shape=[]) > 0.5:
+ image = tf.image.flip_left_right(image)
+ image = tf.image.resize(image, size=image_size, method="bicubic")
+ image = (image - MEAN) / STD # normalization
+ return image, label
+
+ if train:
+ dataset = dataset.shuffle(BATCH_SIZE * 10)
+
+ return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)
+
+
+"""
+### Flower Dataset
+"""
+
+train_dataset, val_dataset = tfds.load(
+ "tf_flowers",
+ split=["train[:90%]", "train[90%:]"],
+ as_supervised=True,
+    try_gcs=False,  # gcs_path is necessary for TPU
+)
+
+train_dataset = make_dataset(train_dataset, True)
+val_dataset = make_dataset(val_dataset, False)
+
+"""
+### Re-Build Model for Flower Dataset
+"""
+
+# Re-Build Model
+model = GCViT(**config, num_classes=104)
+inp = ops.array(np.random.uniform(size=(1, 224, 224, 3)))
+out = model(inp)
+
+# Load Weights
+ckpt_path = keras.utils.get_file(ckpt_link.split("/")[-1], ckpt_link)
+model.load_weights(ckpt_path, skip_mismatch=True)
+
+model.compile(
+ loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
+)
+
+"""
+### Training
+"""
+
+history = model.fit(
+ train_dataset, validation_data=val_dataset, epochs=EPOCHS, verbose=1
+)
+
+"""
+## Reference
+
+* [gcvit-tf - A Python library for GCViT with TF2.0](https://github.com/awsaf49/gcvit-tf)
+* [gcvit - Official codebase for GCViT](https://github.com/NVlabs/GCVit)
+"""
diff --git a/knowledge_base/vision/image_classification_with_vision_transformer.py b/knowledge_base/vision/image_classification_with_vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25edd9ce86a33de40d9a0e4b77b1162f9b68263
--- /dev/null
+++ b/knowledge_base/vision/image_classification_with_vision_transformer.py
@@ -0,0 +1,332 @@
+"""
+Title: Image classification with Vision Transformer
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/01/18
+Last modified: 2021/01/18
+Description: Implementing the Vision Transformer (ViT) model for image classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
+model by Alexey Dosovitskiy et al. for image classification,
+and demonstrates it on the CIFAR-100 dataset.
+The ViT model applies the Transformer architecture with self-attention to sequences of
+image patches, without using convolution layers.
+
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # @param ["tensorflow", "jax", "torch"]
+
+import keras
+from keras import layers
+from keras import ops
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+"""
+## Prepare the data
+"""
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+
+"""
+## Configure the hyperparameters
+"""
+
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 256
+num_epochs = 10 # For real training, use num_epochs=100. 10 is a test value
+image_size = 72 # We'll resize input images to this size
+patch_size = 6  # Size of the patches to be extracted from the input images
+num_patches = (image_size // patch_size) ** 2
+projection_dim = 64
+num_heads = 4
+transformer_units = [
+ projection_dim * 2,
+ projection_dim,
+] # Size of the transformer layers
+transformer_layers = 8
+mlp_head_units = [
+ 2048,
+ 1024,
+] # Size of the dense layers of the final classifier
+
+
+"""
+## Use data augmentation
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.Resizing(image_size, image_size),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(factor=0.02),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+
+"""
+## Implement multilayer perceptron (MLP)
+"""
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=keras.activations.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## Implement patch creation as a layer
+"""
+
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+
+ def call(self, images):
+ input_shape = ops.shape(images)
+ batch_size = input_shape[0]
+ height = input_shape[1]
+ width = input_shape[2]
+ channels = input_shape[3]
+ num_patches_h = height // self.patch_size
+ num_patches_w = width // self.patch_size
+ patches = keras.ops.image.extract_patches(images, size=self.patch_size)
+ patches = ops.reshape(
+ patches,
+ (
+ batch_size,
+ num_patches_h * num_patches_w,
+ self.patch_size * self.patch_size * channels,
+ ),
+ )
+ return patches
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"patch_size": self.patch_size})
+ return config
+
+
+"""
+Let's display patches for a sample image
+"""
+
+plt.figure(figsize=(4, 4))
+image = x_train[np.random.choice(range(x_train.shape[0]))]
+plt.imshow(image.astype("uint8"))
+plt.axis("off")
+
+resized_image = ops.image.resize(
+ ops.convert_to_tensor([image]), size=(image_size, image_size)
+)
+patches = Patches(patch_size)(resized_image)
+print(f"Image size: {image_size} X {image_size}")
+print(f"Patch size: {patch_size} X {patch_size}")
+print(f"Patches per image: {patches.shape[1]}")
+print(f"Elements per patch: {patches.shape[-1]}")
+
+n = int(np.sqrt(patches.shape[1]))
+plt.figure(figsize=(4, 4))
+for i, patch in enumerate(patches[0]):
+ ax = plt.subplot(n, n, i + 1)
+ patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
+ plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
+ plt.axis("off")
+
+"""
+## Implement the patch encoding layer
+
+The `PatchEncoder` layer will linearly transform a patch by projecting it into a
+vector of size `projection_dim`. In addition, it adds a learnable position
+embedding to the projected vector.
+"""
+
+
+class PatchEncoder(layers.Layer):
+ def __init__(self, num_patches, projection_dim):
+ super().__init__()
+ self.num_patches = num_patches
+ self.projection = layers.Dense(units=projection_dim)
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+
+ def call(self, patch):
+ positions = ops.expand_dims(
+ ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+ )
+ projected_patches = self.projection(patch)
+ encoded = projected_patches + self.position_embedding(positions)
+ return encoded
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"num_patches": self.num_patches})
+ return config
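+
+
+"""
+As a quick, illustrative check (reusing the `patches` computed from the sample image
+above): the encoder maps each flattened patch to a `projection_dim`-sized vector and
+adds a position embedding, so the output has shape
+`(batch_size, num_patches, projection_dim)`.
+"""
+
+encoded = PatchEncoder(num_patches, projection_dim)(patches)
+print(f"Encoded patches shape: {encoded.shape}")  # (1, 144, 64)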
+
+
+"""
+## Build the ViT model
+
+The ViT model consists of multiple Transformer blocks,
+which use the `layers.MultiHeadAttention` layer as a self-attention mechanism
+applied to the sequence of patches. The Transformer blocks produce a
+`[batch_size, num_patches, projection_dim]` tensor, which is processed via a
+classifier head with softmax to produce the final class probabilities output.
+
+Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929),
+which prepends a learnable embedding to the sequence of encoded patches to serve
+as the image representation, all the outputs of the final Transformer block are
+reshaped with `layers.Flatten()` and used as the image
+representation input to the classifier head.
+Note that the `layers.GlobalAveragePooling1D` layer
+could also be used instead to aggregate the outputs of the Transformer block,
+especially when the number of patches and the projection dimensions are large.
+"""
+
+
+def create_vit_classifier():
+ inputs = keras.Input(shape=input_shape)
+ # Augment data.
+ augmented = data_augmentation(inputs)
+ # Create patches.
+ patches = Patches(patch_size)(augmented)
+ # Encode patches.
+ encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+ # Create multiple layers of the Transformer block.
+ for _ in range(transformer_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # MLP.
+ x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
+ # Skip connection 2.
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Create a [batch_size, projection_dim] tensor.
+ representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ representation = layers.Flatten()(representation)
+ representation = layers.Dropout(0.5)(representation)
+ # Add MLP.
+ features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
+ # Classify outputs.
+ logits = layers.Dense(num_classes)(features)
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=logits)
+ return model
+
+
+"""
+## Compile, train, and evaluate the model
+"""
+
+
+def run_experiment(model):
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ )
+
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ checkpoint_filepath = "/tmp/checkpoint.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[checkpoint_callback],
+ )
+
+ model.load_weights(checkpoint_filepath)
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ return history
+
+
+vit_classifier = create_vit_classifier()
+history = run_experiment(vit_classifier)
+
+
+def plot_history(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_history("loss")
+plot_history("top-5-accuracy")
+
+
+"""
+After 100 epochs, the ViT model achieves around 55% accuracy and
+82% top-5 accuracy on the test data. These are not competitive results on the CIFAR-100 dataset,
+as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy.
+
+Note that the state of the art results reported in the
+[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model using
+the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality
+without pre-training, you can try to train the model for more epochs, use a larger number of
+Transformer layers, resize the input images, change the patch size, or increase the projection dimensions.
+Besides, as mentioned in the paper, the quality of the model is affected not only by architecture choices,
+but also by parameters such as the learning rate schedule, optimizer, weight decay, etc.
+In practice, it's recommended to fine-tune a ViT model
+that was pre-trained using a large, high-resolution dataset.
+"""
diff --git a/knowledge_base/vision/integrated_gradients.py b/knowledge_base/vision/integrated_gradients.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7d012a8b462a5a9d30afa74acd43e0cc41ebc87
--- /dev/null
+++ b/knowledge_base/vision/integrated_gradients.py
@@ -0,0 +1,498 @@
+"""
+Title: Model interpretability with Integrated Gradients
+Author: [A_K_Nain](https://twitter.com/A_K_Nain)
+Date created: 2020/06/02
+Last modified: 2020/06/02
+Description: How to obtain integrated gradients for a classification model.
+Accelerator: None
+"""
+
+"""
+## Integrated Gradients
+
+[Integrated Gradients](https://arxiv.org/abs/1703.01365) is a technique for
+attributing a classification model's prediction to its input features. It is
+a model interpretability technique: you can use it to visualize the relationship
+between input features and model predictions.
+
+Integrated Gradients is a variation on computing
+the gradient of the prediction output with regard to features of the input.
+To compute integrated gradients, we need to perform the following steps:
+
+1. Identify the input and the output. In our case, the input is an image and the
+output is the last layer of our model (dense layer with softmax activation).
+
+2. Compute which features are important to a neural network
+when making a prediction on a particular data point. To identify these features, we
+need to choose a baseline input. A baseline input can be a black image (all pixel
+values set to zero) or random noise. The shape of the baseline input needs to be
+the same as our input image, e.g. (299, 299, 3).
+
+3. Interpolate between the baseline and the input image for a given number of steps.
+The number of steps controls the quality of the gradient approximation for a given
+input image and is a hyperparameter. The authors recommend using anywhere between
+20 and 1000 steps.
+
+4. Preprocess these interpolated images and do a forward pass.
+5. Get the gradients for these interpolated images.
+6. Approximate the gradients integral using the trapezoidal rule.
+
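+Putting steps 3-6 together: with baseline `x'`, input `x` and `m` interpolation steps,
+the quantity we approximate for each input feature `i` is
+
+`IntegratedGrads_i(x) ≈ (x_i - x'_i) * (1/m) * sum_{k=1..m} dF(x' + (k/m)(x - x'))/dx_i`
+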
+To read in-depth about integrated gradients and why this method works,
+consider reading this excellent
+[article](https://distill.pub/2020/attribution-baselines/).
+
+**References:**
+
+- Integrated Gradients original [paper](https://arxiv.org/abs/1703.01365)
+- [Original implementation](https://github.com/ankurtaly/Integrated-Gradients)
+"""
+
+"""
+## Setup
+"""
+
+
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy import ndimage
+from IPython.display import Image, display
+
+import tensorflow as tf
+import keras
+from keras import layers
+from keras.applications import xception
+
+
+# Size of the input image
+img_size = (299, 299, 3)
+
+# Load Xception model with imagenet weights
+model = xception.Xception(weights="imagenet")
+
+# The local path to our target image
+img_path = keras.utils.get_file("elephant.jpg", "https://i.imgur.com/Bvro0YD.png")
+display(Image(img_path))
+
+"""
+## Integrated Gradients algorithm
+"""
+
+
+def get_img_array(img_path, size=(299, 299)):
+ # `img` is a PIL image of size 299x299
+ img = keras.utils.load_img(img_path, target_size=size)
+ # `array` is a float32 Numpy array of shape (299, 299, 3)
+ array = keras.utils.img_to_array(img)
+ # We add a dimension to transform our array into a "batch"
+ # of size (1, 299, 299, 3)
+ array = np.expand_dims(array, axis=0)
+ return array
+
+
+def get_gradients(img_input, top_pred_idx):
+ """Computes the gradients of outputs w.r.t input image.
+
+ Args:
+ img_input: 4D image tensor
+ top_pred_idx: Predicted label for the input image
+
+ Returns:
+ Gradients of the predictions w.r.t img_input
+ """
+ images = tf.cast(img_input, tf.float32)
+
+ with tf.GradientTape() as tape:
+ tape.watch(images)
+ preds = model(images)
+ top_class = preds[:, top_pred_idx]
+
+ grads = tape.gradient(top_class, images)
+ return grads
+
+
+def get_integrated_gradients(img_input, top_pred_idx, baseline=None, num_steps=50):
+ """Computes Integrated Gradients for a predicted label.
+
+ Args:
+ img_input (ndarray): Original image
+ top_pred_idx: Predicted label for the input image
+ baseline (ndarray): The baseline image to start with for interpolation
+ num_steps: Number of interpolation steps between the baseline
+ and the input used in the computation of integrated gradients. These
+            steps alone determine the integral approximation error. By default,
+ num_steps is set to 50.
+
+ Returns:
+ Integrated gradients w.r.t input image
+ """
+ # If baseline is not provided, start with a black image
+ # having same size as the input image.
+ if baseline is None:
+ baseline = np.zeros(img_size).astype(np.float32)
+ else:
+ baseline = baseline.astype(np.float32)
+
+ # 1. Do interpolation.
+ img_input = img_input.astype(np.float32)
+ interpolated_image = [
+ baseline + (step / num_steps) * (img_input - baseline)
+ for step in range(num_steps + 1)
+ ]
+ interpolated_image = np.array(interpolated_image).astype(np.float32)
+
+ # 2. Preprocess the interpolated images
+ interpolated_image = xception.preprocess_input(interpolated_image)
+
+ # 3. Get the gradients
+ grads = []
+ for i, img in enumerate(interpolated_image):
+ img = tf.expand_dims(img, axis=0)
+ grad = get_gradients(img, top_pred_idx=top_pred_idx)
+ grads.append(grad[0])
+ grads = tf.convert_to_tensor(grads, dtype=tf.float32)
+
+ # 4. Approximate the integral using the trapezoidal rule
+ grads = (grads[:-1] + grads[1:]) / 2.0
+ avg_grads = tf.reduce_mean(grads, axis=0)
+
+ # 5. Calculate integrated gradients and return
+ integrated_grads = (img_input - baseline) * avg_grads
+ return integrated_grads
+
+
+def random_baseline_integrated_gradients(
+ img_input, top_pred_idx, num_steps=50, num_runs=2
+):
+ """Generates a number of random baseline images.
+
+ Args:
+ img_input (ndarray): 3D image
+ top_pred_idx: Predicted label for the input image
+ num_steps: Number of interpolation steps between the baseline
+ and the input used in the computation of integrated gradients. These
+            steps alone determine the integral approximation error. By default,
+ num_steps is set to 50.
+ num_runs: number of baseline images to generate
+
+ Returns:
+ Averaged integrated gradients for `num_runs` baseline images
+ """
+ # 1. List to keep track of Integrated Gradients (IG) for all the images
+ integrated_grads = []
+
+ # 2. Get the integrated gradients for all the baselines
+ for run in range(num_runs):
+ baseline = np.random.random(img_size) * 255
+ igrads = get_integrated_gradients(
+ img_input=img_input,
+ top_pred_idx=top_pred_idx,
+ baseline=baseline,
+ num_steps=num_steps,
+ )
+ integrated_grads.append(igrads)
+
+ # 3. Return the average integrated gradients for the image
+ integrated_grads = tf.convert_to_tensor(integrated_grads)
+ return tf.reduce_mean(integrated_grads, axis=0)
+
+
+"""
+## Helper class for visualizing gradients and integrated gradients
+"""
+
+
+class GradVisualizer:
+ """Plot gradients of the outputs w.r.t an input image."""
+
+ def __init__(self, positive_channel=None, negative_channel=None):
+ if positive_channel is None:
+ self.positive_channel = [0, 255, 0]
+ else:
+ self.positive_channel = positive_channel
+
+ if negative_channel is None:
+ self.negative_channel = [255, 0, 0]
+ else:
+ self.negative_channel = negative_channel
+
+ def apply_polarity(self, attributions, polarity):
+ if polarity == "positive":
+ return np.clip(attributions, 0, 1)
+ else:
+ return np.clip(attributions, -1, 0)
+
+ def apply_linear_transformation(
+ self,
+ attributions,
+ clip_above_percentile=99.9,
+ clip_below_percentile=70.0,
+ lower_end=0.2,
+ ):
+ # 1. Get the thresholds
+ m = self.get_thresholded_attributions(
+ attributions, percentage=100 - clip_above_percentile
+ )
+ e = self.get_thresholded_attributions(
+ attributions, percentage=100 - clip_below_percentile
+ )
+
+ # 2. Transform the attributions by a linear function f(x) = a*x + b such that
+ # f(m) = 1.0 and f(e) = lower_end
+ transformed_attributions = (1 - lower_end) * (np.abs(attributions) - e) / (
+ m - e
+ ) + lower_end
+
+ # 3. Make sure that the sign of transformed attributions is the same as original attributions
+ transformed_attributions *= np.sign(attributions)
+
+ # 4. Only keep values that are bigger than the lower_end
+ transformed_attributions *= transformed_attributions >= lower_end
+
+ # 5. Clip values and return
+ transformed_attributions = np.clip(transformed_attributions, 0.0, 1.0)
+ return transformed_attributions
+
+ def get_thresholded_attributions(self, attributions, percentage):
+ if percentage == 100.0:
+ return np.min(attributions)
+
+ # 1. Flatten the attributions
+ flatten_attr = attributions.flatten()
+
+ # 2. Get the sum of the attributions
+ total = np.sum(flatten_attr)
+
+ # 3. Sort the attributions from largest to smallest.
+ sorted_attributions = np.sort(np.abs(flatten_attr))[::-1]
+
+        # 4. Calculate the percentage of the total sum that each attribution
+        # and the values above it contribute.
+ cum_sum = 100.0 * np.cumsum(sorted_attributions) / total
+
+ # 5. Threshold the attributions by the percentage
+ indices_to_consider = np.where(cum_sum >= percentage)[0][0]
+
+ # 6. Select the desired attributions and return
+ attributions = sorted_attributions[indices_to_consider]
+ return attributions
+
+ def binarize(self, attributions, threshold=0.001):
+ return attributions > threshold
+
+ def morphological_cleanup_fn(self, attributions, structure=np.ones((4, 4))):
+ closed = ndimage.grey_closing(attributions, structure=structure)
+ opened = ndimage.grey_opening(closed, structure=structure)
+ return opened
+
+ def draw_outlines(
+ self,
+ attributions,
+ percentage=90,
+ connected_component_structure=np.ones((3, 3)),
+ ):
+ # 1. Binarize the attributions.
+ attributions = self.binarize(attributions)
+
+ # 2. Fill the gaps
+ attributions = ndimage.binary_fill_holes(attributions)
+
+ # 3. Compute connected components
+ connected_components, num_comp = ndimage.label(
+ attributions, structure=connected_component_structure
+ )
+
+ # 4. Sum up the attributions for each component
+ total = np.sum(attributions[connected_components > 0])
+ component_sums = []
+ for comp in range(1, num_comp + 1):
+ mask = connected_components == comp
+ component_sum = np.sum(attributions[mask])
+ component_sums.append((component_sum, mask))
+
+ # 5. Compute the percentage of top components to keep
+ sorted_sums_and_masks = sorted(component_sums, key=lambda x: x[0], reverse=True)
+ sorted_sums = list(zip(*sorted_sums_and_masks))[0]
+ cumulative_sorted_sums = np.cumsum(sorted_sums)
+ cutoff_threshold = percentage * total / 100
+ cutoff_idx = np.where(cumulative_sorted_sums >= cutoff_threshold)[0][0]
+ if cutoff_idx > 2:
+ cutoff_idx = 2
+
+ # 6. Set the values for the kept components
+ border_mask = np.zeros_like(attributions)
+ for i in range(cutoff_idx + 1):
+ border_mask[sorted_sums_and_masks[i][1]] = 1
+
+ # 7. Make the mask hollow and show only the border
+ eroded_mask = ndimage.binary_erosion(border_mask, iterations=1)
+ border_mask[eroded_mask] = 0
+
+ # 8. Return the outlined mask
+ return border_mask
+
+ def process_grads(
+ self,
+ image,
+ attributions,
+ polarity="positive",
+ clip_above_percentile=99.9,
+ clip_below_percentile=0,
+ morphological_cleanup=False,
+ structure=np.ones((3, 3)),
+ outlines=False,
+ outlines_component_percentage=90,
+ overlay=True,
+ ):
+ if polarity not in ["positive", "negative"]:
+ raise ValueError(
+ f""" Allowed polarity values: 'positive' or 'negative'
+ but provided {polarity}"""
+ )
+ if clip_above_percentile < 0 or clip_above_percentile > 100:
+ raise ValueError("clip_above_percentile must be in [0, 100]")
+
+ if clip_below_percentile < 0 or clip_below_percentile > 100:
+ raise ValueError("clip_below_percentile must be in [0, 100]")
+
+ # 1. Apply polarity
+ if polarity == "positive":
+ attributions = self.apply_polarity(attributions, polarity=polarity)
+ channel = self.positive_channel
+ else:
+ attributions = self.apply_polarity(attributions, polarity=polarity)
+ attributions = np.abs(attributions)
+ channel = self.negative_channel
+
+ # 2. Take average over the channels
+ attributions = np.average(attributions, axis=2)
+
+ # 3. Apply linear transformation to the attributions
+ attributions = self.apply_linear_transformation(
+ attributions,
+ clip_above_percentile=clip_above_percentile,
+ clip_below_percentile=clip_below_percentile,
+ lower_end=0.0,
+ )
+
+ # 4. Cleanup
+ if morphological_cleanup:
+ attributions = self.morphological_cleanup_fn(
+ attributions, structure=structure
+ )
+ # 5. Draw the outlines
+ if outlines:
+ attributions = self.draw_outlines(
+ attributions, percentage=outlines_component_percentage
+ )
+
+ # 6. Expand the channel axis and convert to RGB
+ attributions = np.expand_dims(attributions, 2) * channel
+
+ # 7.Superimpose on the original image
+ if overlay:
+ attributions = np.clip((attributions * 0.8 + image), 0, 255)
+ return attributions
+
+ def visualize(
+ self,
+ image,
+ gradients,
+ integrated_gradients,
+ polarity="positive",
+ clip_above_percentile=99.9,
+ clip_below_percentile=0,
+ morphological_cleanup=False,
+ structure=np.ones((3, 3)),
+ outlines=False,
+ outlines_component_percentage=90,
+ overlay=True,
+ figsize=(15, 8),
+ ):
+ # 1. Make two copies of the original image
+ img1 = np.copy(image)
+ img2 = np.copy(image)
+
+ # 2. Process the normal gradients
+ grads_attr = self.process_grads(
+ image=img1,
+ attributions=gradients,
+ polarity=polarity,
+ clip_above_percentile=clip_above_percentile,
+ clip_below_percentile=clip_below_percentile,
+ morphological_cleanup=morphological_cleanup,
+ structure=structure,
+ outlines=outlines,
+ outlines_component_percentage=outlines_component_percentage,
+ overlay=overlay,
+ )
+
+ # 3. Process the integrated gradients
+ igrads_attr = self.process_grads(
+ image=img2,
+ attributions=integrated_gradients,
+ polarity=polarity,
+ clip_above_percentile=clip_above_percentile,
+ clip_below_percentile=clip_below_percentile,
+ morphological_cleanup=morphological_cleanup,
+ structure=structure,
+ outlines=outlines,
+ outlines_component_percentage=outlines_component_percentage,
+ overlay=overlay,
+ )
+
+ _, ax = plt.subplots(1, 3, figsize=figsize)
+ ax[0].imshow(image)
+ ax[1].imshow(grads_attr.astype(np.uint8))
+ ax[2].imshow(igrads_attr.astype(np.uint8))
+
+ ax[0].set_title("Input")
+ ax[1].set_title("Normal gradients")
+ ax[2].set_title("Integrated gradients")
+ plt.show()
+
+
+"""
+## Let's test-drive it
+"""
+
+# 1. Convert the image to numpy array
+img = get_img_array(img_path)
+
+# 2. Keep a copy of the original image
+orig_img = np.copy(img[0]).astype(np.uint8)
+
+# 3. Preprocess the image
+img_processed = tf.cast(xception.preprocess_input(img), dtype=tf.float32)
+
+# 4. Get model predictions
+preds = model.predict(img_processed)
+top_pred_idx = tf.argmax(preds[0])
+print("Predicted:", top_pred_idx, xception.decode_predictions(preds, top=1)[0])
+
+# 5. Get the gradients of the last layer for the predicted label
+grads = get_gradients(img_processed, top_pred_idx=top_pred_idx)
+
+# 6. Get the integrated gradients
+igrads = random_baseline_integrated_gradients(
+ np.copy(orig_img), top_pred_idx=top_pred_idx, num_steps=50, num_runs=2
+)
+
+# 7. Process the gradients and plot
+vis = GradVisualizer()
+vis.visualize(
+ image=orig_img,
+ gradients=grads[0].numpy(),
+ integrated_gradients=igrads.numpy(),
+ clip_above_percentile=99,
+ clip_below_percentile=0,
+)
+
+vis.visualize(
+ image=orig_img,
+ gradients=grads[0].numpy(),
+ integrated_gradients=igrads.numpy(),
+ clip_above_percentile=95,
+ clip_below_percentile=28,
+ morphological_cleanup=True,
+ outlines=True,
+)
diff --git a/knowledge_base/vision/involution.py b/knowledge_base/vision/involution.py
new file mode 100644
index 0000000000000000000000000000000000000000..571bf00b7ead6ab5f32c253bc7927c47dc86a9f2
--- /dev/null
+++ b/knowledge_base/vision/involution.py
@@ -0,0 +1,483 @@
+"""
+Title: Involutional neural networks
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
+Date created: 2021/07/25
+Last modified: 2021/07/25
+Description: Deep dive into location-specific and channel-agnostic "involution" kernels.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Convolution has been the basis of most modern neural
+networks for computer vision. A convolution kernel is
+spatial-agnostic and channel-specific. Because of this, it isn't able
+to adapt to different visual patterns with respect to
+different spatial locations. Along with location-related problems, the
+receptive field of convolution creates challenges with regard to capturing
+long-range spatial interactions.
+
+To address the above issues, Li et al. rethink the properties
+of convolution in
+[Involution: Inverting the Inherence of Convolution for Visual Recognition](https://arxiv.org/abs/2103.06255).
+The authors propose the "involution kernel", that is location-specific and
+channel-agnostic. Due to the location-specific nature of the operation,
+the authors say that self-attention falls under the design paradigm of
+involution.
+
+This example describes the involution kernel, compares two image
+classification models, one with convolution and the other with
+involution, and also tries drawing a parallel with the self-attention
+layer.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf
+import keras
+import matplotlib.pyplot as plt
+
+# Set seed for reproducibility.
+tf.random.set_seed(42)
+
+"""
+## Convolution
+
+Convolution remains the mainstay of deep neural networks for computer vision.
+To understand Involution, it is necessary to talk about the
+convolution operation.
+
+
+
+Consider an input tensor **X** with dimensions **H**, **W** and
+**C_in**. We take a collection of **C_out** convolution kernels each of
+shape **K**, **K**, **C_in**. With the multiply-add operation between
+the input tensor and the kernels we obtain an output tensor **Y** with
+dimensions **H**, **W**, **C_out**.
+
+In the diagram above, `C_out=3`, which makes the output tensor of shape H,
+W and 3. One can notice that the convolution kernel does not depend on
+the spatial position of the input tensor, which makes it
+**location-agnostic**. On the other hand, each channel in the output
+tensor is based on a specific convolution filter, which makes it
+**channel-specific**.
+"""
+
+"""
+## Involution
+
+The idea is to have an operation that is both **location-specific**
+and **channel-agnostic**. Trying to implement these specific properties poses
+a challenge. With a fixed number of involution kernels (for each
+spatial position) we will **not** be able to process variable-resolution
+input tensors.
+
+To solve this problem, the authors have considered *generating* each
+kernel conditioned on specific spatial positions. With this method, we
+should be able to process variable-resolution input tensors with ease.
+The diagram below provides an intuition on this kernel generation
+method.
+
+
+"""
+
+
+class Involution(keras.layers.Layer):
+ def __init__(
+ self, channel, group_number, kernel_size, stride, reduction_ratio, name
+ ):
+ super().__init__(name=name)
+
+ # Initialize the parameters.
+ self.channel = channel
+ self.group_number = group_number
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.reduction_ratio = reduction_ratio
+
+ def build(self, input_shape):
+ # Get the shape of the input.
+ (_, height, width, num_channels) = input_shape
+
+ # Scale the height and width with respect to the strides.
+ height = height // self.stride
+ width = width // self.stride
+
+ # Define a layer that average pools the input tensor
+ # if stride is more than 1.
+ self.stride_layer = (
+ keras.layers.AveragePooling2D(
+ pool_size=self.stride, strides=self.stride, padding="same"
+ )
+ if self.stride > 1
+ else tf.identity
+ )
+ # Define the kernel generation layer.
+ self.kernel_gen = keras.Sequential(
+ [
+ keras.layers.Conv2D(
+ filters=self.channel // self.reduction_ratio, kernel_size=1
+ ),
+ keras.layers.BatchNormalization(),
+ keras.layers.ReLU(),
+ keras.layers.Conv2D(
+ filters=self.kernel_size * self.kernel_size * self.group_number,
+ kernel_size=1,
+ ),
+ ]
+ )
+ # Define reshape layers
+ self.kernel_reshape = keras.layers.Reshape(
+ target_shape=(
+ height,
+ width,
+ self.kernel_size * self.kernel_size,
+ 1,
+ self.group_number,
+ )
+ )
+ self.input_patches_reshape = keras.layers.Reshape(
+ target_shape=(
+ height,
+ width,
+ self.kernel_size * self.kernel_size,
+ num_channels // self.group_number,
+ self.group_number,
+ )
+ )
+ self.output_reshape = keras.layers.Reshape(
+ target_shape=(height, width, num_channels)
+ )
+
+ def call(self, x):
+ # Generate the kernel with respect to the input tensor.
+ # B, H, W, K*K*G
+ kernel_input = self.stride_layer(x)
+ kernel = self.kernel_gen(kernel_input)
+
+ # Reshape the kernel.
+ # B, H, W, K*K, 1, G
+ kernel = self.kernel_reshape(kernel)
+
+ # Extract input patches.
+ # B, H, W, K*K*C
+ input_patches = tf.image.extract_patches(
+ images=x,
+ sizes=[1, self.kernel_size, self.kernel_size, 1],
+ strides=[1, self.stride, self.stride, 1],
+ rates=[1, 1, 1, 1],
+ padding="SAME",
+ )
+
+ # Reshape the input patches to align with later operations.
+ # B, H, W, K*K, C//G, G
+ input_patches = self.input_patches_reshape(input_patches)
+
+ # Compute the multiply-add operation of kernels and patches.
+ # B, H, W, K*K, C//G, G
+ output = tf.multiply(kernel, input_patches)
+ # B, H, W, C//G, G
+ output = tf.reduce_sum(output, axis=3)
+
+ # Reshape the output kernel.
+ # B, H, W, C
+ output = self.output_reshape(output)
+
+ # Return the output tensor and the kernel.
+ return output, kernel
+
+
+"""
+## Testing the Involution layer
+"""
+
+# Define the input tensor.
+input_tensor = tf.random.normal((32, 256, 256, 3))
+
+# Compute involution with stride 1.
+output_tensor, _ = Involution(
+ channel=3, group_number=1, kernel_size=5, stride=1, reduction_ratio=1, name="inv_1"
+)(input_tensor)
+print(f"with stride 1 ouput shape: {output_tensor.shape}")
+
+# Compute involution with stride 2.
+output_tensor, _ = Involution(
+ channel=3, group_number=1, kernel_size=5, stride=2, reduction_ratio=1, name="inv_2"
+)(input_tensor)
+print(f"with stride 2 ouput shape: {output_tensor.shape}")
+
+# Compute involution with stride 1, channel 16 and reduction ratio 2.
+output_tensor, _ = Involution(
+ channel=16, group_number=1, kernel_size=5, stride=1, reduction_ratio=2, name="inv_3"
+)(input_tensor)
+print(
+ "with channel 16 and reduction ratio 2 ouput shape: {}".format(output_tensor.shape)
+)
+
+"""
+## Image Classification
+
+In this section, we will build an image-classifier model. There will
+be two models: one with convolutions and the other with involutions.
+
+The image-classification model is heavily inspired by this
+[Convolutional Neural Network (CNN)](https://www.tensorflow.org/tutorials/images/cnn)
+tutorial from Google.
+"""
+
+"""
+## Get the CIFAR10 Dataset
+"""
+
+# Load the CIFAR10 dataset.
+print("loading the CIFAR10 dataset...")
+(
+ (train_images, train_labels),
+ (
+ test_images,
+ test_labels,
+ ),
+) = keras.datasets.cifar10.load_data()
+
+# Normalize pixel values to be between 0 and 1.
+(train_images, test_images) = (train_images / 255.0, test_images / 255.0)
+
+# Shuffle and batch the dataset.
+train_ds = (
+ tf.data.Dataset.from_tensor_slices((train_images, train_labels))
+ .shuffle(256)
+ .batch(256)
+)
+test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(256)
+
+"""
+## Visualise the data
+"""
+
+class_names = [
+ "airplane",
+ "automobile",
+ "bird",
+ "cat",
+ "deer",
+ "dog",
+ "frog",
+ "horse",
+ "ship",
+ "truck",
+]
+
+plt.figure(figsize=(10, 10))
+for i in range(25):
+ plt.subplot(5, 5, i + 1)
+ plt.xticks([])
+ plt.yticks([])
+ plt.grid(False)
+ plt.imshow(train_images[i])
+ plt.xlabel(class_names[train_labels[i][0]])
+plt.show()
+
+"""
+## Convolutional Neural Network
+"""
+
+# Build the conv model.
+print("building the convolution model...")
+conv_model = keras.Sequential(
+ [
+ keras.layers.Conv2D(32, (3, 3), input_shape=(32, 32, 3), padding="same"),
+ keras.layers.ReLU(name="relu1"),
+ keras.layers.MaxPooling2D((2, 2)),
+ keras.layers.Conv2D(64, (3, 3), padding="same"),
+ keras.layers.ReLU(name="relu2"),
+ keras.layers.MaxPooling2D((2, 2)),
+ keras.layers.Conv2D(64, (3, 3), padding="same"),
+ keras.layers.ReLU(name="relu3"),
+ keras.layers.Flatten(),
+ keras.layers.Dense(64, activation="relu"),
+ keras.layers.Dense(10),
+ ]
+)
+
+# Compile the model with the necessary loss function and optimizer.
+print("compiling the convolution model...")
+conv_model.compile(
+ optimizer="adam",
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=["accuracy"],
+)
+
+# Train the model.
+print("conv model training...")
+conv_hist = conv_model.fit(train_ds, epochs=20, validation_data=test_ds)
+
+"""
+## Involutional Neural Network
+"""
+
+# Build the involution model.
+print("building the involution model...")
+
+inputs = keras.Input(shape=(32, 32, 3))
+x, _ = Involution(
+ channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_1"
+)(inputs)
+x = keras.layers.ReLU()(x)
+x = keras.layers.MaxPooling2D((2, 2))(x)
+x, _ = Involution(
+ channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_2"
+)(x)
+x = keras.layers.ReLU()(x)
+x = keras.layers.MaxPooling2D((2, 2))(x)
+x, _ = Involution(
+ channel=3, group_number=1, kernel_size=3, stride=1, reduction_ratio=2, name="inv_3"
+)(x)
+x = keras.layers.ReLU()(x)
+x = keras.layers.Flatten()(x)
+x = keras.layers.Dense(64, activation="relu")(x)
+outputs = keras.layers.Dense(10)(x)
+
+inv_model = keras.Model(inputs=[inputs], outputs=[outputs], name="inv_model")
+
+# Compile the model with the necessary loss function and optimizer.
+print("compiling the involution model...")
+inv_model.compile(
+ optimizer="adam",
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=["accuracy"],
+)
+
+# Train the model.
+print("inv model training...")
+inv_hist = inv_model.fit(train_ds, epochs=20, validation_data=test_ds)
+
+"""
+## Comparisons
+
+In this section, we will look at both models and compare a
+few aspects.
+"""
+
+"""
+### Parameters
+
+One can see that, with a similar architecture, the parameter count of a CNN
+is much larger than that of an INN (Involutional Neural Network).
+"""
+
+conv_model.summary()
+
+inv_model.summary()
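+
+"""
+As a quick side-by-side figure (a small sketch complementing the summaries above),
+we can also print the raw parameter counts:
+"""
+
+print("Convolution model parameters:", conv_model.count_params())
+print("Involution model parameters:", inv_model.count_params())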
+
+"""
+### Loss and Accuracy Plots
+
+Here, the loss and the accuracy plots demonstrate that INNs are slower
+learners (with fewer parameters).
+"""
+
+plt.figure(figsize=(20, 5))
+
+plt.subplot(1, 2, 1)
+plt.title("Convolution Loss")
+plt.plot(conv_hist.history["loss"], label="loss")
+plt.plot(conv_hist.history["val_loss"], label="val_loss")
+plt.legend()
+
+plt.subplot(1, 2, 2)
+plt.title("Involution Loss")
+plt.plot(inv_hist.history["loss"], label="loss")
+plt.plot(inv_hist.history["val_loss"], label="val_loss")
+plt.legend()
+
+plt.show()
+
+plt.figure(figsize=(20, 5))
+
+plt.subplot(1, 2, 1)
+plt.title("Convolution Accuracy")
+plt.plot(conv_hist.history["accuracy"], label="accuracy")
+plt.plot(conv_hist.history["val_accuracy"], label="val_accuracy")
+plt.legend()
+
+plt.subplot(1, 2, 2)
+plt.title("Involution Accuracy")
+plt.plot(inv_hist.history["accuracy"], label="accuracy")
+plt.plot(inv_hist.history["val_accuracy"], label="val_accuracy")
+plt.legend()
+
+plt.show()
+
+"""
+## Visualizing Involution Kernels
+
+To visualize the kernels, we take the sum of **K×K** values from each
+involution kernel. **The representative values at the different spatial
+locations together form the corresponding heat map.**
+
+The authors mention:
+
+"Our proposed involution is reminiscent of self-attention and
+essentially could become a generalized version of it."
+
+With the visualization of the kernel we can indeed obtain an attention
+map of the image. The learned involution kernels provide attention to
+individual spatial positions of the input tensor. The
+**location-specific** property makes involution a generic space of models
+to which self-attention belongs.
+"""
+
+layer_names = ["inv_1", "inv_2", "inv_3"]
+outputs = [inv_model.get_layer(name).output[1] for name in layer_names]
+vis_model = keras.Model(inv_model.input, outputs)
+
+fig, axes = plt.subplots(nrows=10, ncols=4, figsize=(10, 30))
+
+for ax, test_image in zip(axes, test_images[:10]):
+ (inv1_kernel, inv2_kernel, inv3_kernel) = vis_model.predict(test_image[None, ...])
+ inv1_kernel = tf.reduce_sum(inv1_kernel, axis=[-1, -2, -3])
+ inv2_kernel = tf.reduce_sum(inv2_kernel, axis=[-1, -2, -3])
+ inv3_kernel = tf.reduce_sum(inv3_kernel, axis=[-1, -2, -3])
+
+ ax[0].imshow(keras.utils.array_to_img(test_image))
+ ax[0].set_title("Input Image")
+
+ ax[1].imshow(keras.utils.array_to_img(inv1_kernel[0, ..., None]))
+ ax[1].set_title("Involution Kernel 1")
+
+ ax[2].imshow(keras.utils.array_to_img(inv2_kernel[0, ..., None]))
+ ax[2].set_title("Involution Kernel 2")
+
+ ax[3].imshow(keras.utils.array_to_img(inv3_kernel[0, ..., None]))
+ ax[3].set_title("Involution Kernel 3")
+
+"""
+## Conclusions
+
+In this example, the main focus was to build an `Involution` layer which
+can be easily reused. While our comparisons were based on a specific
+task, feel free to use the layer for different tasks and report your
+results.
+
+In my opinion, the key takeaway of involution is its
+relationship with self-attention. The intuition behind location-specific
+and channel-specific processing makes sense in a lot of tasks.
+
+Moving forward, one can:
+
+- Look at [Yannick's video](https://youtu.be/pH2jZun8MoY) on
+ involution for a better understanding.
+- Experiment with the various hyperparameters of the involution layer.
+- Build different models with the involution layer.
+- Try building a different kernel generation method altogether.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/involution)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/involution).
+"""
diff --git a/knowledge_base/vision/keypoint_detection.py b/knowledge_base/vision/keypoint_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..13a7a69ac4801b3f1d4d75cf1f08c9f6ac6365f2
--- /dev/null
+++ b/knowledge_base/vision/keypoint_detection.py
@@ -0,0 +1,459 @@
+"""
+Title: Keypoint Detection with Transfer Learning
+Author: [Sayak Paul](https://twitter.com/RisingSayak), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)
+Date created: 2021/05/02
+Last modified: 2023/07/19
+Description: Training a keypoint detector with data augmentation and transfer learning.
+Accelerator: GPU
+"""
+
+"""
+Keypoint detection consists of locating key object parts. For example, the key parts
+of our faces include nose tips, eyebrows, eye corners, and so on. These parts help to
+represent the underlying object in a feature-rich manner. Keypoint detection has
+applications that include pose estimation, face detection, etc.
+
+In this example, we will build a keypoint detector using the
+[StanfordExtra dataset](https://github.com/benjiebob/StanfordExtra),
+using transfer learning. This example requires TensorFlow 2.4 or higher,
+as well as the [`imgaug`](https://imgaug.readthedocs.io/) library,
+which can be installed using the following command:
+"""
+
+"""shell
+pip install -q -U imgaug
+"""
+
+"""
+## Data collection
+"""
+
+"""
+The StanfordExtra dataset contains 12,000 images of dogs together with keypoints and
+segmentation maps. It is developed from the [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/).
+It can be downloaded with the command below:
+"""
+
+"""shell
+wget -q http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
+"""
+
+"""
+Annotations are provided as a single JSON file in the StanfordExtra dataset and one needs
+to fill [this form](https://forms.gle/sRtbicgxsWvRtRmUA) to get access to it. The
+authors explicitly instruct users not to share the JSON file, and this example respects this wish:
+you should obtain the JSON file yourself.
+
+The JSON file is expected to be locally available as `stanfordextra_v12.zip`.
+
+After the files are downloaded, we can extract the archives.
+"""
+
+"""shell
+tar xf images.tar
+unzip -qq ~/stanfordextra_v12.zip
+"""
+
+"""
+## Imports
+"""
+from keras import layers
+import keras
+
+from imgaug.augmentables.kps import KeypointsOnImage
+from imgaug.augmentables.kps import Keypoint
+import imgaug.augmenters as iaa
+
+from PIL import Image
+from sklearn.model_selection import train_test_split
+from matplotlib import pyplot as plt
+import pandas as pd
+import numpy as np
+import json
+import os
+
+"""
+## Define hyperparameters
+"""
+
+IMG_SIZE = 224
+BATCH_SIZE = 64
+EPOCHS = 5
+NUM_KEYPOINTS = 24 * 2  # 24 keypoints, each having x and y coordinates
+
+"""
+## Load data
+
+The authors also provide a metadata file that specifies additional information about the
+keypoints, like color information, animal pose name, etc. We will load this file in a `pandas`
+dataframe to extract information for visualization purposes.
+"""
+
+IMG_DIR = "Images"
+JSON = "StanfordExtra_V12/StanfordExtra_v12.json"
+KEYPOINT_DEF = (
+ "https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"
+)
+
+# Load the ground-truth annotations.
+with open(JSON) as infile:
+ json_data = json.load(infile)
+
+# Set up a dictionary, mapping all the ground-truth information
+# with respect to the path of the image.
+json_dict = {i["img_path"]: i for i in json_data}
+
+"""
+A single entry of `json_dict` looks like the following:
+
+```
+'n02085782-Japanese_spaniel/n02085782_2886.jpg':
+{'img_bbox': [205, 20, 116, 201],
+ 'img_height': 272,
+ 'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg',
+ 'img_width': 350,
+ 'is_multiple_dogs': False,
+ 'joints': [[108.66666666666667, 252.0, 1],
+ [147.66666666666666, 229.0, 1],
+ [163.5, 208.5, 1],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [54.0, 244.0, 1],
+ [77.33333333333333, 225.33333333333334, 1],
+ [79.0, 196.5, 1],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [150.66666666666666, 86.66666666666667, 1],
+ [88.66666666666667, 73.0, 1],
+ [116.0, 106.33333333333333, 1],
+ [109.0, 123.33333333333333, 1],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]],
+ 'seg': ...}
+```
+"""
+
+"""
+In this example, the keys we are interested in are:
+
+* `img_path`
+* `joints`
+
+There are a total of 24 entries present inside `joints`. Each entry has 3 values:
+
+* x-coordinate
+* y-coordinate
+* visibility flag of the keypoints (1 indicates visibility and 0 indicates non-visibility)
+
+As we can see, `joints` contains multiple `[0, 0, 0]` entries, which denote that those
+keypoints were not labeled. In this example, we will consider both non-visible as well as
+unlabeled keypoints in order to allow mini-batch learning.
+"""
+
+# Load the metadata definition file and preview it.
+keypoint_def = pd.read_csv(KEYPOINT_DEF)
+keypoint_def.head()
+
+# Extract the colours and labels.
+colours = keypoint_def["Hex colour"].values.tolist()
+colours = ["#" + colour for colour in colours]
+labels = keypoint_def["Name"].values.tolist()
+
+
+# Utility for reading an image and for getting its annotations.
+def get_dog(name):
+ data = json_dict[name]
+ img_data = plt.imread(os.path.join(IMG_DIR, data["img_path"]))
+ # If the image is RGBA convert it to RGB.
+ if img_data.shape[-1] == 4:
+ img_data = img_data.astype(np.uint8)
+ img_data = Image.fromarray(img_data)
+ img_data = np.array(img_data.convert("RGB"))
+ data["img_data"] = img_data
+
+ return data
+
+
+"""
+## Visualize data
+
+Now, we write a utility function to visualize the images and their keypoints.
+"""
+
+
+# Parts of this code come from here:
+# https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb
+def visualize_keypoints(images, keypoints):
+ fig, axes = plt.subplots(nrows=len(images), ncols=2, figsize=(16, 12))
+ [ax.axis("off") for ax in np.ravel(axes)]
+
+ for (ax_orig, ax_all), image, current_keypoint in zip(axes, images, keypoints):
+ ax_orig.imshow(image)
+ ax_all.imshow(image)
+
+ # If the keypoints were formed by `imgaug` then the coordinates need
+ # to be iterated differently.
+ if isinstance(current_keypoint, KeypointsOnImage):
+ for idx, kp in enumerate(current_keypoint.keypoints):
+ ax_all.scatter(
+ [kp.x],
+ [kp.y],
+ c=colours[idx],
+ marker="x",
+ s=50,
+ linewidths=5,
+ )
+ else:
+ current_keypoint = np.array(current_keypoint)
+ # Since the last entry is the visibility flag, we discard it.
+ current_keypoint = current_keypoint[:, :2]
+ for idx, (x, y) in enumerate(current_keypoint):
+ ax_all.scatter([x], [y], c=colours[idx], marker="x", s=50, linewidths=5)
+
+ plt.tight_layout(pad=2.0)
+ plt.show()
+
+
+# Select four samples randomly for visualization.
+samples = list(json_dict.keys())
+num_samples = 4
+selected_samples = np.random.choice(samples, num_samples, replace=False)
+
+images, keypoints = [], []
+
+for sample in selected_samples:
+ data = get_dog(sample)
+ image = data["img_data"]
+ keypoint = data["joints"]
+
+ images.append(image)
+ keypoints.append(keypoint)
+
+visualize_keypoints(images, keypoints)
+
+"""
+The plots show that we have images of non-uniform sizes, which is expected in most
+real-world scenarios. However, if we resize these images to have a uniform shape (for
+instance (224 x 224)) their ground-truth annotations will also be affected. The same
+applies if we apply any geometric transformation (e.g. a horizontal flip) to an image.
+Fortunately, `imgaug` provides utilities that can handle this issue.
+In the next section, we will write a data generator inheriting the
+[`keras.utils.Sequence`](https://keras.io/api/utils/python_utils/#sequence-class) class
+that applies data augmentation on batches of data using `imgaug`.
+"""
+
+"""
+## Prepare data generator
+"""
+
+
+class KeyPointsDataset(keras.utils.PyDataset):
+ def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):
+ super().__init__(**kwargs)
+ self.image_keys = image_keys
+ self.aug = aug
+ self.batch_size = batch_size
+ self.train = train
+ self.on_epoch_end()
+
+ def __len__(self):
+ return len(self.image_keys) // self.batch_size
+
+ def on_epoch_end(self):
+ self.indexes = np.arange(len(self.image_keys))
+ if self.train:
+ np.random.shuffle(self.indexes)
+
+ def __getitem__(self, index):
+ indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
+ image_keys_temp = [self.image_keys[k] for k in indexes]
+ (images, keypoints) = self.__data_generation(image_keys_temp)
+
+ return (images, keypoints)
+
+ def __data_generation(self, image_keys_temp):
+ batch_images = np.empty((self.batch_size, IMG_SIZE, IMG_SIZE, 3), dtype="int")
+ batch_keypoints = np.empty(
+ (self.batch_size, 1, 1, NUM_KEYPOINTS), dtype="float32"
+ )
+
+ for i, key in enumerate(image_keys_temp):
+ data = get_dog(key)
+ current_keypoint = np.array(data["joints"])[:, :2]
+ kps = []
+
+ # To apply our data augmentation pipeline, we first need to
+ # form Keypoint objects with the original coordinates.
+ for j in range(0, len(current_keypoint)):
+ kps.append(Keypoint(x=current_keypoint[j][0], y=current_keypoint[j][1]))
+
+ # We then project the original image and its keypoint coordinates.
+ current_image = data["img_data"]
+ kps_obj = KeypointsOnImage(kps, shape=current_image.shape)
+
+ # Apply the augmentation pipeline.
+ (new_image, new_kps_obj) = self.aug(image=current_image, keypoints=kps_obj)
+ batch_images[i,] = new_image
+
+ # Parse the coordinates from the new keypoint object.
+ kp_temp = []
+ for keypoint in new_kps_obj:
+ kp_temp.append(np.nan_to_num(keypoint.x))
+ kp_temp.append(np.nan_to_num(keypoint.y))
+
+ # More on why this reshaping later.
+ batch_keypoints[i,] = np.array(kp_temp).reshape(1, 1, 24 * 2)
+
+ # Scale the coordinates to [0, 1] range.
+ batch_keypoints = batch_keypoints / IMG_SIZE
+
+ return (batch_images, batch_keypoints)
+
+
+"""
+To know more about how to operate with keypoints in `imgaug` check out
+[this document](https://imgaug.readthedocs.io/en/latest/source/examples_keypoints.html).
+"""
+
+"""
+## Define augmentation transforms
+"""
+
+train_aug = iaa.Sequential(
+ [
+ iaa.Resize(IMG_SIZE, interpolation="linear"),
+ iaa.Fliplr(0.3),
+ # `Sometimes()` applies a function randomly to the inputs with
+ # a given probability (0.3, in this case).
+ iaa.Sometimes(0.3, iaa.Affine(rotate=10, scale=(0.5, 0.7))),
+ ]
+)
+
+test_aug = iaa.Sequential([iaa.Resize(IMG_SIZE, interpolation="linear")])
+
+"""
+## Create training and validation splits
+"""
+
+np.random.shuffle(samples)
+train_keys, validation_keys = (
+ samples[int(len(samples) * 0.15) :],
+ samples[: int(len(samples) * 0.15)],
+)
+
+
+"""
+## Data generator investigation
+"""
+
+train_dataset = KeyPointsDataset(
+ train_keys, train_aug, workers=2, use_multiprocessing=True
+)
+validation_dataset = KeyPointsDataset(
+ validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True
+)
+
+print(f"Total batches in training set: {len(train_dataset)}")
+print(f"Total batches in validation set: {len(validation_dataset)}")
+
+sample_images, sample_keypoints = next(iter(train_dataset))
+assert sample_keypoints.max() == 1.0
+assert sample_keypoints.min() == 0.0
+
+sample_keypoints = sample_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
+visualize_keypoints(sample_images[:4], sample_keypoints)
+
+"""
+## Model building
+
+The [Stanford dogs dataset](http://vision.stanford.edu/aditya86/ImageNetDogs/) (on which
+the StanfordExtra dataset is based) was built using the [ImageNet-1k dataset](http://image-net.org/).
+So, it is likely that the models pretrained on the ImageNet-1k dataset would be useful
+for this task. We will use a MobileNetV2 pre-trained on this dataset as a backbone to
+extract meaningful features from the images and then pass those to a custom regression
+head for predicting coordinates.
+"""
+
+
+def get_model():
+ # Load the pre-trained weights of MobileNetV2 and freeze the weights
+ backbone = keras.applications.MobileNetV2(
+ weights="imagenet",
+ include_top=False,
+ input_shape=(IMG_SIZE, IMG_SIZE, 3),
+ )
+ backbone.trainable = False
+
+ inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
+ x = keras.applications.mobilenet_v2.preprocess_input(inputs)
+ x = backbone(x)
+ x = layers.Dropout(0.3)(x)
+ x = layers.SeparableConv2D(
+ NUM_KEYPOINTS, kernel_size=5, strides=1, activation="relu"
+ )(x)
+ outputs = layers.SeparableConv2D(
+ NUM_KEYPOINTS, kernel_size=3, strides=1, activation="sigmoid"
+ )(x)
+
+ return keras.Model(inputs, outputs, name="keypoint_detector")
+
+
+"""
+Our custom network is fully-convolutional, which makes it more parameter-friendly than the
+same version of the network with fully-connected dense layers.
+"""
+
+get_model().summary()
+
+"""
+Notice the output shape of the network: `(None, 1, 1, 48)`. This is why we have reshaped
+the coordinates as: `batch_keypoints[i, :] = np.array(kp_temp).reshape(1, 1, 24 * 2)`.
+"""
+
+"""
+## Model compilation and training
+
+For this example, we will train the network only for five epochs.
+"""
+
+model = get_model()
+model.compile(loss="mse", optimizer=keras.optimizers.Adam(1e-4))
+model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS)
+
+"""
+## Make predictions and visualize them
+"""
+
+sample_val_images, sample_val_keypoints = next(iter(validation_dataset))
+sample_val_images = sample_val_images[:4]
+sample_val_keypoints = sample_val_keypoints[:4].reshape(-1, 24, 2) * IMG_SIZE
+predictions = model.predict(sample_val_images).reshape(-1, 24, 2) * IMG_SIZE
+
+# Ground-truth
+visualize_keypoints(sample_val_images, sample_val_keypoints)
+
+# Predictions
+visualize_keypoints(sample_val_images, predictions)
+
+"""
+Predictions will likely improve with more training.
+"""
+
+"""
+## Going further
+
+* Try using other augmentation transforms from `imgaug` to investigate how that changes
+the results.
+* Here, we transferred the features from the pre-trained network linearly, that is, we did
+not [fine-tune](https://keras.io/guides/transfer_learning/) it. You are encouraged to fine-tune it on this task and see if that
+improves the performance. You can also try different architectures and see how they
+affect the final performance.
+"""
diff --git a/knowledge_base/vision/knowledge_distillation.py b/knowledge_base/vision/knowledge_distillation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4af348658131d42db6a79d81868784662dbb28d1
--- /dev/null
+++ b/knowledge_base/vision/knowledge_distillation.py
@@ -0,0 +1,242 @@
+"""
+Title: Knowledge Distillation
+Author: [Kenneth Borup](https://twitter.com/Kennethborup)
+Date created: 2020/09/01
+Last modified: 2020/09/01
+Description: Implementation of classical Knowledge Distillation.
+Accelerator: GPU
+Converted to Keras 3 by: [Md Awsafur Rahman](https://awsaf49.github.io)
+"""
+
+"""
+## Introduction to Knowledge Distillation
+
+Knowledge Distillation is a procedure for model
+compression, in which a small (student) model is trained to match a large pre-trained
+(teacher) model. Knowledge is transferred from the teacher model to the student
+by minimizing a loss function, aimed at matching softened teacher logits as well as
+ground-truth labels.
+
+The logits are softened by applying a "temperature" scaling function in the softmax,
+effectively smoothing out the probability distribution and revealing
+inter-class relationships learned by the teacher.
+
+**Reference:**
+
+- [Hinton et al. (2015)](https://arxiv.org/abs/1503.02531)
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+import keras
+from keras import layers
+from keras import ops
+import numpy as np
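+
+"""
+As a quick illustration of the temperature scaling described above (a sketch with
+made-up logits, not part of the original recipe): dividing the logits by a
+temperature `T > 1` before the softmax flattens the distribution.
+"""
+
+
+def _softmax(z):
+    # Numerically stable softmax for a 1-D array of hypothetical logits.
+    e = np.exp(z - z.max())
+    return e / e.sum()
+
+
+z = np.array([2.0, 1.0, 0.1])
+print(_softmax(z))  # ~[0.66, 0.24, 0.10] -- sharp (T = 1)
+print(_softmax(z / 10.0))  # ~[0.37, 0.33, 0.30] -- softened (T = 10)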
+
+"""
+## Construct `Distiller()` class
+
+The custom `Distiller()` class overrides the `Model` methods `compile`, `compute_loss`,
+and `call`. In order to use the distiller, we need:
+
+- A trained teacher model
+- A student model to train
+- A student loss function on the difference between student predictions and ground-truth
+- A distillation loss function, along with a `temperature`, on the difference between the
+soft student predictions and the soft teacher labels
+- An `alpha` factor to weight the student and distillation loss
+- An optimizer for the student and (optional) metrics to evaluate performance
+
+In the `compute_loss` method, we perform a forward pass of the teacher, reuse the
+student predictions (`y_pred`), and calculate the total loss by weighting the
+`student_loss` and `distillation_loss` by `alpha` and `1 - alpha`, respectively.
+Note: only the student weights are updated.
+"""
+
+
+class Distiller(keras.Model):
+ def __init__(self, student, teacher):
+ super().__init__()
+ self.teacher = teacher
+ self.student = student
+
+ def compile(
+ self,
+ optimizer,
+ metrics,
+ student_loss_fn,
+ distillation_loss_fn,
+ alpha=0.1,
+ temperature=3,
+ ):
+ """Configure the distiller.
+
+ Args:
+ optimizer: Keras optimizer for the student weights
+ metrics: Keras metrics for evaluation
+ student_loss_fn: Loss function of difference between student
+ predictions and ground-truth
+ distillation_loss_fn: Loss function of difference between soft
+ student predictions and soft teacher predictions
+ alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
+ temperature: Temperature for softening probability distributions.
+ Larger temperature gives softer distributions.
+ """
+ super().compile(optimizer=optimizer, metrics=metrics)
+ self.student_loss_fn = student_loss_fn
+ self.distillation_loss_fn = distillation_loss_fn
+ self.alpha = alpha
+ self.temperature = temperature
+
+ def compute_loss(
+ self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
+ ):
+ teacher_pred = self.teacher(x, training=False)
+ student_loss = self.student_loss_fn(y, y_pred)
+
+ distillation_loss = self.distillation_loss_fn(
+ ops.softmax(teacher_pred / self.temperature, axis=1),
+ ops.softmax(y_pred / self.temperature, axis=1),
+ ) * (self.temperature**2)
+
+ loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
+ return loss
+
+ def call(self, x):
+ return self.student(x)
+
+
+"""
+## Create student and teacher models
+
+Initially, we create a teacher model and a smaller student model. Both models are
+convolutional neural networks and created using `Sequential()`,
+but could be any Keras model.
+"""
+
+# Create the teacher
+teacher = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1)),
+ layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
+ layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
+ layers.Flatten(),
+ layers.Dense(10),
+ ],
+ name="teacher",
+)
+
+# Create the student
+student = keras.Sequential(
+ [
+ keras.Input(shape=(28, 28, 1)),
+ layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
+ layers.LeakyReLU(negative_slope=0.2),
+ layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
+ layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
+ layers.Flatten(),
+ layers.Dense(10),
+ ],
+ name="student",
+)
+
+# Clone student for later comparison
+student_scratch = keras.models.clone_model(student)
+
+"""
+## Prepare the dataset
+
+The dataset used for training the teacher and distilling the teacher is
+[MNIST](https://keras.io/api/datasets/mnist/), and the procedure would be equivalent for
+any other
+dataset, e.g. [CIFAR-10](https://keras.io/api/datasets/cifar10/), with a suitable choice
+of models. Both the student and teacher are trained on the training set and evaluated on
+the test set.
+"""
+
+# Prepare the train and test dataset.
+batch_size = 64
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+# Normalize data
+x_train = x_train.astype("float32") / 255.0
+x_train = np.reshape(x_train, (-1, 28, 28, 1))
+
+x_test = x_test.astype("float32") / 255.0
+x_test = np.reshape(x_test, (-1, 28, 28, 1))
+
+
+"""
+## Train the teacher
+
+In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start
+by training the teacher model on the training set in the usual way.
+"""
+
+# Train teacher as usual
+teacher.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+)
+
+# Train and evaluate teacher on data.
+teacher.fit(x_train, y_train, epochs=5)
+teacher.evaluate(x_test, y_test)
+
+"""
+## Distill teacher to student
+
+We have already trained the teacher model, and we only need to initialize a
+`Distiller(student, teacher)` instance, `compile()` it with the desired losses,
+hyperparameters and optimizer, and distill the teacher to the student.
+"""
+
+# Initialize and compile distiller
+distiller = Distiller(student=student, teacher=teacher)
+distiller.compile(
+ optimizer=keras.optimizers.Adam(),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ distillation_loss_fn=keras.losses.KLDivergence(),
+ alpha=0.1,
+ temperature=10,
+)
+
+# Distill teacher to student
+distiller.fit(x_train, y_train, epochs=3)
+
+# Evaluate student on test dataset
+distiller.evaluate(x_test, y_test)
+
+"""
+## Train student from scratch for comparison
+
+We can also train an equivalent student model from scratch without the teacher, in order
+to evaluate the performance gain obtained by knowledge distillation.
+"""
+
+# Train the student model as usual.
+student_scratch.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+)
+
+# Train and evaluate student trained from scratch.
+student_scratch.fit(x_train, y_train, epochs=3)
+student_scratch.evaluate(x_test, y_test)
+
+"""
+If the teacher is trained for 5 full epochs and the student is distilled on this teacher
+for 3 full epochs, you should in this example experience a performance boost compared to
+training the same student model from scratch, and even compared to the teacher itself.
+You should expect the teacher to have accuracy around 97.6%, the student trained from
+scratch should be around 97.6%, and the distilled student should be around 98.1%. Remove
+or try out different seeds to use different weight initializations.
+"""
diff --git a/knowledge_base/vision/learnable_resizer.py b/knowledge_base/vision/learnable_resizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f13ff73eb5743e20e16243b53190e9a4c1030206
--- /dev/null
+++ b/knowledge_base/vision/learnable_resizer.py
@@ -0,0 +1,321 @@
+"""
+Title: Learning to Resize in Computer Vision
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/04/30
+Last modified: 2023/12/18
+Description: How to optimally learn representations of images for a given resolution.
+Accelerator: GPU
+"""
+
+"""
+It is a common belief that if we constrain vision models to perceive things as humans do,
+their performance can be improved. For example, in [this work](https://arxiv.org/abs/1811.12231),
+Geirhos et al. showed that the vision models pre-trained on the ImageNet-1k dataset are
+biased towards texture, whereas human beings mostly use the shape descriptor to develop a
+common perception. But does this belief always apply, especially when it comes to improving
+the performance of vision models?
+
+It turns out it may not always be the case. When training vision models, it is common to
+resize images to a lower dimension ((224 x 224), (299 x 299), etc.) to allow mini-batch
+learning and also to stay within compute limitations. We generally make use of image
+resizing methods like **bilinear interpolation** for this step and the resized images do
+not lose much of their perceptual character to the human eyes. In
+[Learning to Resize Images for Computer Vision Tasks](https://arxiv.org/abs/2103.09950v1), Talebi et al. show
+that if we try to optimize the perceptual quality of the images for the vision models
+rather than the human eyes, their performance can further be improved. They investigate
+the following question:
+
+**For a given image resolution and a model, how to best resize the given images?**
+
+As shown in the paper, this idea helps to consistently improve the performance of the
+common vision models (pre-trained on ImageNet-1k) like DenseNet-121, ResNet-50,
+MobileNetV2, and EfficientNets. In this example, we will implement the learnable image
+resizing module as proposed in the paper and demonstrate that on the
+[Cats and Dogs dataset](https://www.microsoft.com/en-us/download/details.aspx?id=54765)
+using the [DenseNet-121](https://arxiv.org/abs/1608.06993) architecture.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
+
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+## Define hyperparameters
+"""
+
+"""
+In order to facilitate mini-batch learning, we need to have a fixed shape for the images
+inside a given batch. This is why an initial resizing is required. We first resize all
+the images to (300 x 300) shape and then learn their optimal representation for the
+(150 x 150) resolution.
+"""
+
+INP_SIZE = (300, 300)
+TARGET_SIZE = (150, 150)
+INTERPOLATION = "bilinear"
+
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 64
+EPOCHS = 5
+
+"""
+In this example, we will use bilinear interpolation, but the learnable image resizer
+module is not dependent on any specific interpolation method. We can also use others,
+such as bicubic.
+"""
+
+"""
+## Load and prepare the dataset
+
+For this example, we will only use 40% of the total training dataset.
+"""
+
+train_ds, validation_ds = tfds.load(
+ "cats_vs_dogs",
+ # Reserve 10% for validation
+ split=["train[:40%]", "train[40%:50%]"],
+ as_supervised=True,
+)
+
+
+def preprocess_dataset(image, label):
+ image = ops.image.resize(image, (INP_SIZE[0], INP_SIZE[1]))
+ label = ops.one_hot(label, num_classes=2)
+ return (image, label)
+
+
+train_ds = (
+ train_ds.shuffle(BATCH_SIZE * 100)
+ .map(preprocess_dataset, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+validation_ds = (
+ validation_ds.map(preprocess_dataset, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+"""
+## Define the learnable resizer utilities
+
+The figure below (courtesy: [Learning to Resize Images for Computer Vision Tasks](https://arxiv.org/abs/2103.09950v1))
+presents the structure of the learnable resizing module:
+
+
+"""
+
+
+def conv_block(x, filters, kernel_size, strides, activation=layers.LeakyReLU(0.2)):
+ x = layers.Conv2D(filters, kernel_size, strides, padding="same", use_bias=False)(x)
+ x = layers.BatchNormalization()(x)
+ if activation:
+ x = activation(x)
+ return x
+
+
+def res_block(x):
+ inputs = x
+ x = conv_block(x, 16, 3, 1)
+ x = conv_block(x, 16, 3, 1, activation=None)
+ return layers.Add()([inputs, x])
+
+# Note: one can pass num_res_blocks > 1 to `get_learnable_resizer` if needed.
+
+
+def get_learnable_resizer(filters=16, num_res_blocks=1, interpolation=INTERPOLATION):
+ inputs = layers.Input(shape=[None, None, 3])
+
+ # First, perform naive resizing.
+ naive_resize = layers.Resizing(*TARGET_SIZE, interpolation=interpolation)(inputs)
+
+ # First convolution block without batch normalization.
+ x = layers.Conv2D(filters=filters, kernel_size=7, strides=1, padding="same")(inputs)
+ x = layers.LeakyReLU(0.2)(x)
+
+ # Second convolution block with batch normalization.
+ x = layers.Conv2D(filters=filters, kernel_size=1, strides=1, padding="same")(x)
+ x = layers.LeakyReLU(0.2)(x)
+ x = layers.BatchNormalization()(x)
+
+ # Intermediate resizing as a bottleneck.
+ bottleneck = layers.Resizing(*TARGET_SIZE, interpolation=interpolation)(x)
+
+ # Residual passes.
+ # First res_block will get bottleneck output as input
+ x = res_block(bottleneck)
+ # Remaining res_blocks will get previous res_block output as input
+ for _ in range(num_res_blocks - 1):
+ x = res_block(x)
+
+ # Projection.
+ x = layers.Conv2D(
+ filters=filters, kernel_size=3, strides=1, padding="same", use_bias=False
+ )(x)
+ x = layers.BatchNormalization()(x)
+
+ # Skip connection.
+ x = layers.Add()([bottleneck, x])
+
+ # Final resized image.
+ x = layers.Conv2D(filters=3, kernel_size=7, strides=1, padding="same")(x)
+ final_resize = layers.Add()([naive_resize, x])
+
+ return keras.Model(inputs, final_resize, name="learnable_resizer")
+
+
+learnable_resizer = get_learnable_resizer()
+
+"""
+## Visualize the outputs of the learnable resizing module
+
+Here, we visualize what the resized images look like after being passed through the
+random weights of the resizer.
+"""
+
+sample_images, _ = next(iter(train_ds))
+
+
+plt.figure(figsize=(16, 10))
+for i, image in enumerate(sample_images[:6]):
+ image = image / 255
+
+ ax = plt.subplot(3, 4, 2 * i + 1)
+ plt.title("Input Image")
+ plt.imshow(image.numpy().squeeze())
+ plt.axis("off")
+
+ ax = plt.subplot(3, 4, 2 * i + 2)
+ resized_image = learnable_resizer(image[None, ...])
+ plt.title("Resized Image")
+ plt.imshow(resized_image.numpy().squeeze())
+ plt.axis("off")
+
+"""
+## Model building utility
+"""
+
+
+def get_model():
+ backbone = keras.applications.DenseNet121(
+ weights=None,
+ include_top=True,
+ classes=2,
+ input_shape=((TARGET_SIZE[0], TARGET_SIZE[1], 3)),
+ )
+ backbone.trainable = True
+
+ inputs = layers.Input((INP_SIZE[0], INP_SIZE[1], 3))
+ x = layers.Rescaling(scale=1.0 / 255)(inputs)
+ x = learnable_resizer(x)
+ outputs = backbone(x)
+
+ return keras.Model(inputs, outputs)
+
+
+"""
+The structure of the learnable image resizer module allows for flexible integrations with
+different vision models.
+"""
+
+"""
+## Compile and train our model with learnable resizer
+"""
+
+model = get_model()
+model.compile(
+ loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
+ optimizer="sgd",
+ metrics=["accuracy"],
+)
+model.fit(train_ds, validation_data=validation_ds, epochs=EPOCHS)
+
+"""
+## Visualize the outputs of the trained resizer
+"""
+
+plt.figure(figsize=(16, 10))
+for i, image in enumerate(sample_images[:6]):
+ image = image / 255
+
+ ax = plt.subplot(3, 4, 2 * i + 1)
+ plt.title("Input Image")
+ plt.imshow(image.numpy().squeeze())
+ plt.axis("off")
+
+ ax = plt.subplot(3, 4, 2 * i + 2)
+ resized_image = learnable_resizer(image[None, ...])
+ plt.title("Resized Image")
+ plt.imshow(resized_image.numpy().squeeze() / 10)
+ plt.axis("off")
+
+"""
+The plot shows that the visuals of the images have improved with training. The following
+table shows the benefits of using the resizing module in comparison to using bilinear
+interpolation:
+
+| Model | Number of parameters (Million) | Top-1 accuracy |
+|:-------------------------: |:-------------------------------: |:--------------: |
+| With the learnable resizer | 7.051717 | 67.67% |
+| Without the learnable resizer | 7.039554 | 60.19% |
+
+For more details, you can check out [this repository](https://github.com/sayakpaul/Learnable-Image-Resizing).
+Note the above-reported models were trained for 10 epochs on 90% of the training set of
+Cats and Dogs, unlike this example. Also, note that the increase in the number of
+parameters due to the resizing module is negligible. To ensure that the improvement
+in the performance is not due to stochasticity, the models were trained using the same
+initial random weights.
+
+Now, a question worth asking here is - _isn't the improved accuracy simply a consequence
+of adding more layers (the resizer is a mini network after all) to the model, compared to
+the baseline?_
+
+To show that it is not the case, the authors conduct the following experiment:
+
+* Take a pre-trained model trained at some size, say (224 x 224).
+
+* Now, first, use it to infer predictions on images resized to a lower resolution. Record
+the performance.
+
+* For the second experiment, plug in the resizer module at the top of the pre-trained
+model and warm-start the training. Record the performance.
+
+Now, the authors argue that using the second option is better because it helps the model
+learn how to adjust the representations better with respect to the given resolution.
+Since the results are purely empirical, a few more experiments such as analyzing the
+cross-channel interaction would have been even better. It is worth noting that elements
+like [Squeeze and Excitation (SE) blocks](https://arxiv.org/abs/1709.01507), [Global Context (GC) blocks](https://arxiv.org/abs/1904.11492) also add a few
+parameters to an existing network but they are known to help a network process
+information in systematic ways to improve the overall performance.
+"""
+
+"""
+## Notes
+
+* To impose shape bias inside the vision models, Geirhos et al. trained them with a
+combination of natural and stylized images. It might be interesting to investigate if
+this learnable resizing module could achieve something similar, as the outputs seem to
+discard the texture information.
+
+* The resizer module can handle arbitrary resolutions and aspect ratios, which is very
+important for tasks like object detection and segmentation.
+
+* There is another closely related topic on ***adaptive image resizing*** that attempts
+to resize images/feature maps adaptively during training. [EfficientNetV2](https://arxiv.org/abs/2104.00298)
+uses this idea.
+"""
diff --git a/knowledge_base/vision/masked_image_modeling.py b/knowledge_base/vision/masked_image_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd471a3a0369c5631a54412b508cd279562b5d02
--- /dev/null
+++ b/knowledge_base/vision/masked_image_modeling.py
@@ -0,0 +1,942 @@
+"""
+Title: Masked image modeling with Autoencoders
+Author: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/12/20
+Last modified: 2021/12/21
+Description: Implementing Masked Autoencoders for self-supervised pretraining.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In deep learning, models with growing **capacity** and **capability** can easily overfit
+on large datasets (ImageNet-1K). In the field of natural language processing, the
+appetite for data has been **successfully addressed** by self-supervised pretraining.
+
+In the academic paper
+[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
+by He et al., the authors propose a simple yet effective method to pretrain large
+vision models (here [ViT Huge](https://arxiv.org/abs/2010.11929)). Inspired by
+the pretraining algorithm of BERT ([Devlin et al.](https://arxiv.org/abs/1810.04805)),
+they mask patches of an image and, through an autoencoder, predict the masked patches.
+In the spirit of "masked language modeling", this pretraining task could be referred
+to as "masked image modeling".
+
+In this example, we implement
+[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
+with the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. After
+pretraining a scaled down version of ViT, we also implement the linear evaluation
+pipeline on CIFAR-10.
+
+
+This implementation covers (MAE refers to Masked Autoencoder):
+
+- The masking algorithm
+- MAE encoder
+- MAE decoder
+- Evaluation with linear probing
+
+As a reference, we reuse some of the code presented in
+[this example](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
+
+"""
+
+"""
+## Imports
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import tensorflow as tf
+import keras
+from keras import layers
+
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+
+# Setting seeds for reproducibility.
+SEED = 42
+keras.utils.set_random_seed(SEED)
+
+"""
+## Hyperparameters for pretraining
+
+Please feel free to change the hyperparameters and check your results. The best way to
+get an intuition about the architecture is to experiment with it. Our hyperparameters are
+heavily inspired by the design guidelines laid out by the authors in
+[the original paper](https://arxiv.org/abs/2111.06377).
+"""
+
+# DATA
+BUFFER_SIZE = 1024
+BATCH_SIZE = 256
+AUTO = tf.data.AUTOTUNE
+INPUT_SHAPE = (32, 32, 3)
+NUM_CLASSES = 10
+
+# OPTIMIZER
+LEARNING_RATE = 5e-3
+WEIGHT_DECAY = 1e-4
+
+# PRETRAINING
+EPOCHS = 100
+
+# AUGMENTATION
+IMAGE_SIZE = 48 # We will resize input images to this size.
+PATCH_SIZE = 6 # Size of the patches to be extracted from the input images.
+NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
+MASK_PROPORTION = 0.75 # We have found 75% masking to give us the best results.
+
+# ENCODER and DECODER
+LAYER_NORM_EPS = 1e-6
+ENC_PROJECTION_DIM = 128
+DEC_PROJECTION_DIM = 64
+ENC_NUM_HEADS = 4
+ENC_LAYERS = 6
+DEC_NUM_HEADS = 4
+DEC_LAYERS = (
+ 2 # The decoder is lightweight but should be reasonably deep for reconstruction.
+)
+ENC_TRANSFORMER_UNITS = [
+ ENC_PROJECTION_DIM * 2,
+ ENC_PROJECTION_DIM,
+] # Size of the transformer layers.
+DEC_TRANSFORMER_UNITS = [
+ DEC_PROJECTION_DIM * 2,
+ DEC_PROJECTION_DIM,
+]
+
+"""
+## Load and prepare the CIFAR-10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+train_ds = tf.data.Dataset.from_tensor_slices(x_train)
+train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)
+
+val_ds = tf.data.Dataset.from_tensor_slices(x_val)
+val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+test_ds = tf.data.Dataset.from_tensor_slices(x_test)
+test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+"""
+## Data augmentation
+
+In previous self-supervised pretraining methodologies
+(such as [SimCLR](https://arxiv.org/abs/2002.05709)), we have noticed that the data
+augmentation pipeline plays an important role. On the other hand, the authors of this
+paper point out that Masked Autoencoders **do not** rely on augmentations. They propose a
+simple augmentation pipeline of:
+
+
+- Resizing
+- Random cropping (fixed-sized or random sized)
+- Random horizontal flipping
+"""
+
+
+def get_train_augmentation_model():
+ model = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
+ layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ ],
+ name="train_data_augmentation",
+ )
+ return model
+
+
+def get_test_augmentation_model():
+ model = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
+ ],
+ name="test_data_augmentation",
+ )
+ return model
+
+
+"""
+## A layer for extracting patches from images
+
+This layer takes images as input and divides them into patches. The layer also includes
+two utility methods:
+
+- `show_patched_image` -- Takes a batch of images and its corresponding patches to plot a
+random pair of image and patches.
+- `reconstruct_from_patch` -- Takes a single instance of patches and stitches them
+together into the original image.
+"""
+
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size=PATCH_SIZE, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+
+ # Assuming the image has three channels, each patch would be
+ # of size (patch_size, patch_size, 3).
+ self.resize = layers.Reshape((-1, patch_size * patch_size * 3))
+
+ def call(self, images):
+ # Create patches from the input images
+ patches = tf.image.extract_patches(
+ images=images,
+ sizes=[1, self.patch_size, self.patch_size, 1],
+ strides=[1, self.patch_size, self.patch_size, 1],
+ rates=[1, 1, 1, 1],
+ padding="VALID",
+ )
+
+ # Reshape the patches to (batch, num_patches, patch_area) and return it.
+ patches = self.resize(patches)
+ return patches
+
+ def show_patched_image(self, images, patches):
+ # This is a utility function which accepts a batch of images and its
+ # corresponding patches and helps visualize one image and its patches
+ # side by side.
+ idx = np.random.choice(patches.shape[0])
+ print(f"Index selected: {idx}.")
+
+ plt.figure(figsize=(4, 4))
+ plt.imshow(keras.utils.array_to_img(images[idx]))
+ plt.axis("off")
+ plt.show()
+
+ n = int(np.sqrt(patches.shape[1]))
+ plt.figure(figsize=(4, 4))
+ for i, patch in enumerate(patches[idx]):
+ ax = plt.subplot(n, n, i + 1)
+ patch_img = tf.reshape(patch, (self.patch_size, self.patch_size, 3))
+ plt.imshow(keras.utils.img_to_array(patch_img))
+ plt.axis("off")
+ plt.show()
+
+ # Return the index chosen to validate it outside the method.
+ return idx
+
+ # taken from https://stackoverflow.com/a/58082878/10319735
+ def reconstruct_from_patch(self, patch):
+ # This utility function takes patches from a *single* image and
+ # reconstructs it back into the image. This is useful for the train
+ # monitor callback.
+ num_patches = patch.shape[0]
+ n = int(np.sqrt(num_patches))
+ patch = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))
+ rows = tf.split(patch, n, axis=0)
+ rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
+ reconstructed = tf.concat(rows, axis=0)
+ return reconstructed
+
+
+"""
+Let's visualize the image patches.
+"""
+
+# Get a batch of images.
+image_batch = next(iter(train_ds))
+
+# Augment the images.
+augmentation_model = get_train_augmentation_model()
+augmented_images = augmentation_model(image_batch)
+
+# Define the patch layer.
+patch_layer = Patches()
+
+# Get the patches from the batched images.
+patches = patch_layer(images=augmented_images)
+
+# Now pass the images and the corresponding patches
+# to the `show_patched_image` method.
+random_index = patch_layer.show_patched_image(images=augmented_images, patches=patches)
+
+# Choose the same image and try reconstructing the patches
+# into the original image.
+image = patch_layer.reconstruct_from_patch(patches[random_index])
+plt.imshow(image)
+plt.axis("off")
+plt.show()
+
+"""
+## Patch encoding with masking
+
+Quoting the paper
+
+> Following ViT, we divide an image into regular non-overlapping patches. Then we sample
+a subset of patches and mask (i.e., remove) the remaining ones. Our sampling strategy is
+straightforward: we sample random patches without replacement, following a uniform
+distribution. We simply refer to this as "random sampling".
+
+This layer includes masking and encoding the patches.
+
+The utility methods of the layer are:
+
+- `get_random_indices` -- Provides the mask and unmask indices.
+- `generate_masked_image` -- Takes patches and unmask indices, results in a random masked
+image. This is an essential utility method for our training monitor callback (defined
+later).
+"""
+
+
+class PatchEncoder(layers.Layer):
+ def __init__(
+ self,
+ patch_size=PATCH_SIZE,
+ projection_dim=ENC_PROJECTION_DIM,
+ mask_proportion=MASK_PROPORTION,
+ downstream=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+ self.projection_dim = projection_dim
+ self.mask_proportion = mask_proportion
+ self.downstream = downstream
+
+ # This is a trainable mask token initialized randomly from a normal
+ # distribution.
+ self.mask_token = tf.Variable(
+ tf.random.normal([1, patch_size * patch_size * 3]), trainable=True
+ )
+
+ def build(self, input_shape):
+ (_, self.num_patches, self.patch_area) = input_shape
+
+ # Create the projection layer for the patches.
+ self.projection = layers.Dense(units=self.projection_dim)
+
+ # Create the positional embedding layer.
+ self.position_embedding = layers.Embedding(
+ input_dim=self.num_patches, output_dim=self.projection_dim
+ )
+
+ # Number of patches that will be masked.
+ self.num_mask = int(self.mask_proportion * self.num_patches)
+
+ def call(self, patches):
+ # Get the positional embeddings.
+ batch_size = tf.shape(patches)[0]
+ positions = tf.range(start=0, limit=self.num_patches, delta=1)
+ pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
+ pos_embeddings = tf.tile(
+ pos_embeddings, [batch_size, 1, 1]
+ ) # (B, num_patches, projection_dim)
+
+ # Embed the patches.
+ patch_embeddings = (
+ self.projection(patches) + pos_embeddings
+ ) # (B, num_patches, projection_dim)
+
+ if self.downstream:
+ return patch_embeddings
+ else:
+ mask_indices, unmask_indices = self.get_random_indices(batch_size)
+ # The encoder input is the unmasked patch embeddings. Here we gather
+ # all the patches that should be unmasked.
+ unmasked_embeddings = tf.gather(
+ patch_embeddings, unmask_indices, axis=1, batch_dims=1
+ ) # (B, unmask_numbers, projection_dim)
+
+ # Get the unmasked and masked position embeddings. We will need them
+ # for the decoder.
+ unmasked_positions = tf.gather(
+ pos_embeddings, unmask_indices, axis=1, batch_dims=1
+ ) # (B, unmask_numbers, projection_dim)
+ masked_positions = tf.gather(
+ pos_embeddings, mask_indices, axis=1, batch_dims=1
+ ) # (B, mask_numbers, projection_dim)
+
+            # Repeat the mask token `num_mask` times for each image in the
+            # batch. The mask tokens stand in for the masked patches.
+ mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
+ mask_tokens = tf.repeat(
+ mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
+ )
+
+ # Get the masked embeddings for the tokens.
+ masked_embeddings = self.projection(mask_tokens) + masked_positions
+ return (
+ unmasked_embeddings, # Input to the encoder.
+ masked_embeddings, # First part of input to the decoder.
+ unmasked_positions, # Added to the encoder outputs.
+ mask_indices, # The indices that were masked.
+                unmask_indices,  # The indices that were unmasked.
+ )
+
+ def get_random_indices(self, batch_size):
+ # Create random indices from a uniform distribution and then split
+ # it into mask and unmask indices.
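+        # Note that `tf.argsort` over uniform noise yields an independent random
+        # permutation of the patch indices for each example in the batch.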
+ rand_indices = tf.argsort(
+ tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
+ )
+ mask_indices = rand_indices[:, : self.num_mask]
+ unmask_indices = rand_indices[:, self.num_mask :]
+ return mask_indices, unmask_indices
+
+ def generate_masked_image(self, patches, unmask_indices):
+        # Choose a random patch and its corresponding unmask index.
+ idx = np.random.choice(patches.shape[0])
+ patch = patches[idx]
+ unmask_index = unmask_indices[idx]
+
+ # Build a numpy array of same shape as patch.
+ new_patch = np.zeros_like(patch)
+
+        # Iterate over the unmask indices and plug the corresponding
+        # unmasked patches into the new_patch array.
+        for i in range(unmask_index.shape[0]):
+            new_patch[unmask_index[i]] = patch[unmask_index[i]]
+ return new_patch, idx
+
+
+"""
+Let's see the masking process in action on a sample image.
+"""
+
+# Create the patch encoder layer.
+patch_encoder = PatchEncoder()
+
+# Get the embeddings and positions.
+(
+ unmasked_embeddings,
+ masked_embeddings,
+ unmasked_positions,
+ mask_indices,
+ unmask_indices,
+) = patch_encoder(patches=patches)
+
+
+# Show a masked patch image.
+new_patch, random_index = patch_encoder.generate_masked_image(patches, unmask_indices)
+
+plt.figure(figsize=(10, 10))
+plt.subplot(1, 2, 1)
+img = patch_layer.reconstruct_from_patch(new_patch)
+plt.imshow(keras.utils.array_to_img(img))
+plt.axis("off")
+plt.title("Masked")
+plt.subplot(1, 2, 2)
+img = augmented_images[random_index]
+plt.imshow(keras.utils.array_to_img(img))
+plt.axis("off")
+plt.title("Original")
+plt.show()
+
+"""
+## MLP
+
+This serves as the fully connected feed forward network of the transformer architecture.
+"""
+
+
+def mlp(x, dropout_rate, hidden_units):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=tf.nn.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## MAE encoder
+
+The MAE encoder is ViT. The only point to note here is that the encoder outputs a layer
+normalized output.
+"""
+
+
+def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
+ inputs = layers.Input((None, ENC_PROJECTION_DIM))
+ x = inputs
+
+ for _ in range(num_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
+
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
+ )(x1, x1)
+
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, x])
+
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
+
+ # MLP.
+ x3 = mlp(x3, hidden_units=ENC_TRANSFORMER_UNITS, dropout_rate=0.1)
+
+ # Skip connection 2.
+ x = layers.Add()([x3, x2])
+
+ outputs = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
+ return keras.Model(inputs, outputs, name="mae_encoder")
+
+
+"""
+## MAE decoder
+
+The authors point out that they use an **asymmetric** autoencoder model. They use a
+lightweight decoder that takes "<10% computation per token vs. the encoder". We do not
+match the "<10% computation" figure exactly in our implementation, but we do use a
+smaller decoder (both in terms of depth and projection dimensions).
+"""
+
+
+def create_decoder(
+ num_layers=DEC_LAYERS, num_heads=DEC_NUM_HEADS, image_size=IMAGE_SIZE
+):
+ inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
+ x = layers.Dense(DEC_PROJECTION_DIM)(inputs)
+
+ for _ in range(num_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
+
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
+ )(x1, x1)
+
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, x])
+
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
+
+ # MLP.
+ x3 = mlp(x3, hidden_units=DEC_TRANSFORMER_UNITS, dropout_rate=0.1)
+
+ # Skip connection 2.
+ x = layers.Add()([x3, x2])
+
+ x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
+ x = layers.Flatten()(x)
+ pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
+ outputs = layers.Reshape((image_size, image_size, 3))(pre_final)
+
+ return keras.Model(inputs, outputs, name="mae_decoder")
+
+
+"""
+## MAE trainer
+
+This is the trainer module. We wrap the encoder and decoder inside of a `tf.keras.Model`
+subclass. This allows us to customize what happens in the `model.fit()` loop.
+"""
+
+
+class MaskedAutoencoder(keras.Model):
+ def __init__(
+ self,
+ train_augmentation_model,
+ test_augmentation_model,
+ patch_layer,
+ patch_encoder,
+ encoder,
+ decoder,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.train_augmentation_model = train_augmentation_model
+ self.test_augmentation_model = test_augmentation_model
+ self.patch_layer = patch_layer
+ self.patch_encoder = patch_encoder
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def calculate_loss(self, images, test=False):
+ # Augment the input images.
+ if test:
+ augmented_images = self.test_augmentation_model(images)
+ else:
+ augmented_images = self.train_augmentation_model(images)
+
+ # Patch the augmented images.
+ patches = self.patch_layer(augmented_images)
+
+ # Encode the patches.
+ (
+ unmasked_embeddings,
+ masked_embeddings,
+ unmasked_positions,
+ mask_indices,
+ unmask_indices,
+ ) = self.patch_encoder(patches)
+
+        # Pass the unmasked patches to the encoder.
+ encoder_outputs = self.encoder(unmasked_embeddings)
+
+ # Create the decoder inputs.
+ encoder_outputs = encoder_outputs + unmasked_positions
+ decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
+
+ # Decode the inputs.
+ decoder_outputs = self.decoder(decoder_inputs)
+ decoder_patches = self.patch_layer(decoder_outputs)
+
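+        # Following the MAE paper, the loss is computed only on the masked patches:
+        # gather the ground-truth and reconstructed patches at the masked indices.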
+ loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
+ loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)
+
+ # Compute the total loss.
+ total_loss = self.compute_loss(y=loss_patch, y_pred=loss_output)
+
+ return total_loss, loss_patch, loss_output
+
+ def train_step(self, images):
+ with tf.GradientTape() as tape:
+ total_loss, loss_patch, loss_output = self.calculate_loss(images)
+
+ # Apply gradients.
+ train_vars = [
+ self.train_augmentation_model.trainable_variables,
+ self.patch_layer.trainable_variables,
+ self.patch_encoder.trainable_variables,
+ self.encoder.trainable_variables,
+ self.decoder.trainable_variables,
+ ]
+ grads = tape.gradient(total_loss, train_vars)
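+        # `train_vars` is a list of per-module variable lists, so we flatten the
+        # gradients and variables into matching (grad, var) pairs before applying them.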
+ tv_list = []
+ for grad, var in zip(grads, train_vars):
+ for g, v in zip(grad, var):
+ tv_list.append((g, v))
+ self.optimizer.apply_gradients(tv_list)
+
+ # Report progress.
+ results = {}
+ for metric in self.metrics:
+ metric.update_state(loss_patch, loss_output)
+ results[metric.name] = metric.result()
+ return results
+
+ def test_step(self, images):
+ total_loss, loss_patch, loss_output = self.calculate_loss(images, test=True)
+
+ # Update the trackers.
+ results = {}
+ for metric in self.metrics:
+ metric.update_state(loss_patch, loss_output)
+ results[metric.name] = metric.result()
+ return results
+
+
+"""
+## Model initialization
+"""
+
+train_augmentation_model = get_train_augmentation_model()
+test_augmentation_model = get_test_augmentation_model()
+patch_layer = Patches()
+patch_encoder = PatchEncoder()
+encoder = create_encoder()
+decoder = create_decoder()
+
+mae_model = MaskedAutoencoder(
+ train_augmentation_model=train_augmentation_model,
+ test_augmentation_model=test_augmentation_model,
+ patch_layer=patch_layer,
+ patch_encoder=patch_encoder,
+ encoder=encoder,
+ decoder=decoder,
+)
+
+"""
+## Training callbacks
+"""
+
+"""
+### Visualization callback
+"""
+
+# Take a batch of test inputs to measure the model's progress.
+test_images = next(iter(test_ds))
+
+
+class TrainMonitor(keras.callbacks.Callback):
+    def __init__(self, epoch_interval=None):
+        super().__init__()
+        self.epoch_interval = epoch_interval
+
+ def on_epoch_end(self, epoch, logs=None):
+ if self.epoch_interval and epoch % self.epoch_interval == 0:
+ test_augmented_images = self.model.test_augmentation_model(test_images)
+ test_patches = self.model.patch_layer(test_augmented_images)
+ (
+ test_unmasked_embeddings,
+ test_masked_embeddings,
+ test_unmasked_positions,
+ test_mask_indices,
+ test_unmask_indices,
+ ) = self.model.patch_encoder(test_patches)
+ test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
+ test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
+ test_decoder_inputs = tf.concat(
+ [test_encoder_outputs, test_masked_embeddings], axis=1
+ )
+ test_decoder_outputs = self.model.decoder(test_decoder_inputs)
+
+            # Show a masked patch image.
+ test_masked_patch, idx = self.model.patch_encoder.generate_masked_image(
+ test_patches, test_unmask_indices
+ )
+ print(f"\nIdx chosen: {idx}")
+ original_image = test_augmented_images[idx]
+ masked_image = self.model.patch_layer.reconstruct_from_patch(
+ test_masked_patch
+ )
+ reconstructed_image = test_decoder_outputs[idx]
+
+ fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
+ ax[0].imshow(original_image)
+ ax[0].set_title(f"Original: {epoch:03d}")
+
+ ax[1].imshow(masked_image)
+ ax[1].set_title(f"Masked: {epoch:03d}")
+
+ ax[2].imshow(reconstructed_image)
+            ax[2].set_title(f"Reconstructed: {epoch:03d}")
+
+ plt.show()
+ plt.close()
+
+
+"""
+### Learning rate scheduler
+"""
+
+# Some code is taken from:
+# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+
+
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(
+ self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
+ ):
+ super().__init__()
+
+ self.learning_rate_base = learning_rate_base
+ self.total_steps = total_steps
+ self.warmup_learning_rate = warmup_learning_rate
+ self.warmup_steps = warmup_steps
+ self.pi = tf.constant(np.pi)
+
+ def __call__(self, step):
+ if self.total_steps < self.warmup_steps:
+            raise ValueError("total_steps must be larger than or equal to warmup_steps.")
+
+ cos_annealed_lr = tf.cos(
+ self.pi
+ * (tf.cast(step, tf.float32) - self.warmup_steps)
+ / float(self.total_steps - self.warmup_steps)
+ )
+ learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
+
+ if self.warmup_steps > 0:
+            if self.learning_rate_base < self.warmup_learning_rate:
+                raise ValueError(
+                    "learning_rate_base must be larger than or equal to "
+                    "warmup_learning_rate."
+                )
+ slope = (
+ self.learning_rate_base - self.warmup_learning_rate
+ ) / self.warmup_steps
+ warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
+ learning_rate = tf.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+ return tf.where(
+ step > self.total_steps, 0.0, learning_rate, name="learning_rate"
+ )
+
+
+total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+scheduled_lrs = WarmUpCosine(
+ learning_rate_base=LEARNING_RATE,
+ total_steps=total_steps,
+ warmup_learning_rate=0.0,
+ warmup_steps=warmup_steps,
+)
+
+lrs = [scheduled_lrs(step) for step in range(total_steps)]
+plt.plot(lrs)
+plt.xlabel("Step", fontsize=14)
+plt.ylabel("LR", fontsize=14)
+plt.show()
+
+# Assemble the callbacks.
+train_callbacks = [TrainMonitor(epoch_interval=5)]
+
+"""
+## Model compilation and training
+"""
+
+optimizer = keras.optimizers.AdamW(
+ learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
+)
+
+# Compile and pretrain the model.
+mae_model.compile(
+ optimizer=optimizer, loss=keras.losses.MeanSquaredError(), metrics=["mae"]
+)
+history = mae_model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=val_ds,
+ callbacks=train_callbacks,
+)
+
+# Measure its performance.
+loss, mae = mae_model.evaluate(test_ds)
+print(f"Loss: {loss:.2f}")
+print(f"MAE: {mae:.2f}")
+
+"""
+## Evaluation with linear probing
+"""
+
+"""
+### Extract the encoder model along with other layers
+"""
+
+# Extract the augmentation layers.
+train_augmentation_model = mae_model.train_augmentation_model
+test_augmentation_model = mae_model.test_augmentation_model
+
+# Extract the patchers.
+patch_layer = mae_model.patch_layer
+patch_encoder = mae_model.patch_encoder
+patch_encoder.downstream = True  # Switch the downstream flag to True.
+
+# Extract the encoder.
+encoder = mae_model.encoder
+
+# Pack as a model.
+downstream_model = keras.Sequential(
+ [
+ layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
+ patch_layer,
+ patch_encoder,
+ encoder,
+ layers.BatchNormalization(), # Refer to A.1 (Linear probing).
+ layers.GlobalAveragePooling1D(),
+ layers.Dense(NUM_CLASSES, activation="softmax"),
+ ],
+ name="linear_probe_model",
+)
+
+# Only the final classification layer of the `downstream_model` should be trainable.
+for layer in downstream_model.layers[:-1]:
+ layer.trainable = False
+
+downstream_model.summary()
+
+"""
+We are using average pooling to extract learned representations from the MAE encoder.
+Another approach would be to use a learnable dummy token inside the encoder during
+pretraining (resembling the [CLS] token). Then we can extract representations from that
+token during the downstream tasks.
+"""
+
+"""
+### Prepare datasets for linear probing
+"""
+
+
+def prepare_data(images, labels, is_train=True):
+ if is_train:
+ augmentation_model = train_augmentation_model
+ else:
+ augmentation_model = test_augmentation_model
+
+ dataset = tf.data.Dataset.from_tensor_slices((images, labels))
+ if is_train:
+ dataset = dataset.shuffle(BUFFER_SIZE)
+
+ dataset = dataset.batch(BATCH_SIZE).map(
+ lambda x, y: (augmentation_model(x), y), num_parallel_calls=AUTO
+ )
+ return dataset.prefetch(AUTO)
+
+
+train_ds = prepare_data(x_train, y_train)
+val_ds = prepare_data(x_train, y_train, is_train=False)
+test_ds = prepare_data(x_test, y_test, is_train=False)
+
+"""
+### Perform linear probing
+"""
+
+linear_probe_epochs = 50
+linear_probe_lr = 0.1
+warm_epoch_percentage = 0.1
+steps = int((len(x_train) // BATCH_SIZE) * linear_probe_epochs)
+
+warmup_steps = int(steps * warm_epoch_percentage)
+scheduled_lrs = WarmUpCosine(
+    learning_rate_base=linear_probe_lr,
+ total_steps=steps,
+ warmup_learning_rate=0.0,
+ warmup_steps=warmup_steps,
+)
+
+optimizer = keras.optimizers.SGD(learning_rate=scheduled_lrs, momentum=0.9)
+downstream_model.compile(
+ optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+)
+downstream_model.fit(train_ds, validation_data=val_ds, epochs=linear_probe_epochs)
+
+loss, accuracy = downstream_model.evaluate(test_ds)
+accuracy = round(accuracy * 100, 2)
+print(f"Accuracy on the test set: {accuracy}%.")
+
+"""
+We believe that with a more sophisticated hyperparameter tuning process and a longer
+pretraining it is possible to improve this performance further. For comparison, we took
+the encoder architecture and
+[trained it from scratch](https://github.com/ariG23498/mae-scalable-vision-learners/blob/master/regular-classification.ipynb)
+in a fully supervised manner. This gave us ~76% test top-1 accuracy. The authors of
+MAE demonstrate strong performance on the ImageNet-1k dataset as well as on
+other downstream tasks like object detection and semantic segmentation.
+"""
+
+"""
+## Final notes
+
+We refer interested readers to other examples of self-supervised learning on
+keras.io:
+
+* [SimCLR](https://keras.io/examples/vision/semisupervised_simclr/)
+* [NNCLR](https://keras.io/examples/vision/nnclr)
+* [SimSiam](https://keras.io/examples/vision/simsiam)
+
+This idea of using BERT flavored pretraining in computer vision was also explored in
+[Selfie](https://arxiv.org/abs/1906.02940), but it could not demonstrate strong results.
+Another concurrent work that explores the idea of masked image modeling is
+[SimMIM](https://arxiv.org/abs/2111.09886). Finally, as a fun fact, we, the authors of
+this example, also explored the idea of ["reconstruction as a pretext task"](https://i.ibb.co/k5CpwDX/image.png)
+in 2020, but we could not prevent the network from suffering representation collapse,
+and hence did not get strong downstream performance.
+
+We would like to thank [Xinlei Chen](http://xinleic.xyz/)
+(one of the authors of MAE) for helpful discussions. We are grateful to
+[JarvisLabs](https://jarvislabs.ai/) and the
+[Google Developers Experts](https://developers.google.com/programs/experts/)
+program for helping with GPU credits.
+"""
diff --git a/knowledge_base/vision/metric_learning.py b/knowledge_base/vision/metric_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..3376afd63c396d657e8c98c81d1a335b18f39018
--- /dev/null
+++ b/knowledge_base/vision/metric_learning.py
@@ -0,0 +1,325 @@
+"""
+Title: Metric learning for image similarity search
+Author: [Mat Kelcey](https://twitter.com/mat_kelcey)
+Date created: 2020/06/05
+Last modified: 2020/06/09
+Description: Example of using similarity metric learning on CIFAR-10 images.
+Accelerator: GPU
+"""
+
+"""
+## Overview
+
+Metric learning aims to train models that can embed inputs into a high-dimensional space
+such that "similar" inputs, as defined by the training scheme, are located close to each
+other. Once trained, these models can produce embeddings for downstream systems where
+such similarity is useful; for instance, as a ranking signal for search or as a form of
+pretrained embedding model for another supervised problem.
+
+For a more detailed overview of metric learning see:
+
+* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
+* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
+"""
+
+"""
+## Setup
+
+Set the Keras backend to TensorFlow.
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import random
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+from collections import defaultdict
+from PIL import Image
+from sklearn.metrics import ConfusionMatrixDisplay
+import keras
+from keras import layers
+
+"""
+## Dataset
+
+For this example we will be using the
+[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
+"""
+
+from keras.datasets import cifar10
+
+
+(x_train, y_train), (x_test, y_test) = cifar10.load_data()
+
+x_train = x_train.astype("float32") / 255.0
+y_train = np.squeeze(y_train)
+x_test = x_test.astype("float32") / 255.0
+y_test = np.squeeze(y_test)
+
+"""
+To get a sense of the dataset we can visualise a grid of 25 random examples.
+"""
+
+height_width = 32
+
+
+def show_collage(examples):
+ box_size = height_width + 2
+ num_rows, num_cols = examples.shape[:2]
+
+ collage = Image.new(
+ mode="RGB",
+ size=(num_cols * box_size, num_rows * box_size),
+ color=(250, 250, 250),
+ )
+ for row_idx in range(num_rows):
+ for col_idx in range(num_cols):
+ array = (np.array(examples[row_idx, col_idx]) * 255).astype(np.uint8)
+ collage.paste(
+ Image.fromarray(array), (col_idx * box_size, row_idx * box_size)
+ )
+
+ # Double size for visualisation.
+ collage = collage.resize((2 * num_cols * box_size, 2 * num_rows * box_size))
+ return collage
+
+
+# Show a collage of 5x5 random images.
+sample_idxs = np.random.randint(0, 50000, size=(5, 5))
+examples = x_train[sample_idxs]
+show_collage(examples)
+
+"""
+Metric learning provides training data not as explicit `(X, y)` pairs but instead uses
+multiple instances that are related in the way we want to express similarity. In our
+example we will use instances of the same class to represent similarity; a single
+training instance will not be one image, but a pair of images of the same class. When
+referring to the images in this pair we'll use the common metric learning names of the
+`anchor` (a randomly chosen image) and the `positive` (another randomly chosen image of
+the same class).
+
+To facilitate this we need to build a form of lookup that maps from classes to the
+instances of that class. When generating data for training we will sample from this
+lookup.
+"""
+
+class_idx_to_train_idxs = defaultdict(list)
+for y_train_idx, y in enumerate(y_train):
+ class_idx_to_train_idxs[y].append(y_train_idx)
+
+class_idx_to_test_idxs = defaultdict(list)
+for y_test_idx, y in enumerate(y_test):
+ class_idx_to_test_idxs[y].append(y_test_idx)
+
+"""
+For this example we are using the simplest approach to training; a batch will consist of
+`(anchor, positive)` pairs spread across the classes. The goal of learning will be to
+move the anchor and positive pairs closer together and further away from other instances
+in the batch. In this case the batch size will be dictated by the number of classes; for
+CIFAR-10 this is 10.
+"""
+
+num_classes = 10
+
+
+class AnchorPositivePairs(keras.utils.Sequence):
+ def __init__(self, num_batches):
+ super().__init__()
+ self.num_batches = num_batches
+
+ def __len__(self):
+ return self.num_batches
+
+ def __getitem__(self, _idx):
+ x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
+ for class_idx in range(num_classes):
+ examples_for_class = class_idx_to_train_idxs[class_idx]
+ anchor_idx = random.choice(examples_for_class)
+ positive_idx = random.choice(examples_for_class)
+ while positive_idx == anchor_idx:
+ positive_idx = random.choice(examples_for_class)
+ x[0, class_idx] = x_train[anchor_idx]
+ x[1, class_idx] = x_train[positive_idx]
+ return x
+
+
+"""
+We can visualise a batch in another collage. The top row shows randomly chosen anchors
+from the 10 classes, the bottom row shows the corresponding 10 positives.
+"""
+
+examples = next(iter(AnchorPositivePairs(num_batches=1)))
+
+show_collage(examples)
+
+"""
+## Embedding model
+
+We define a custom model with a `train_step` that first embeds both anchors and positives
+and then uses their pairwise dot products as logits for a softmax.
+"""
+
+
+class EmbeddingModel(keras.Model):
+ def train_step(self, data):
+ # Note: Workaround for open issue, to be removed.
+ if isinstance(data, tuple):
+ data = data[0]
+ anchors, positives = data[0], data[1]
+
+ with tf.GradientTape() as tape:
+ # Run both anchors and positives through model.
+ anchor_embeddings = self(anchors, training=True)
+ positive_embeddings = self(positives, training=True)
+
+            # Calculate cosine similarity between anchors and positives. As they
+            # have been normalised this is just the pairwise dot products.
+ similarities = keras.ops.einsum(
+ "ae,pe->ap", anchor_embeddings, positive_embeddings
+ )
+
+ # Since we intend to use these as logits we scale them by a temperature.
+ # This value would normally be chosen as a hyper parameter.
+ temperature = 0.2
+ similarities /= temperature
+
+ # We use these similarities as logits for a softmax. The labels for
+            # this call are just the sequence [0, 1, ..., num_classes - 1] since we
+ # want the main diagonal values, which correspond to the anchor/positive
+ # pairs, to be high. This loss will move embeddings for the
+ # anchor/positive pairs together and move all other pairs apart.
+ sparse_labels = keras.ops.arange(num_classes)
+ loss = self.compute_loss(y=sparse_labels, y_pred=similarities)
+
+ # Calculate gradients and apply via optimizer.
+ gradients = tape.gradient(loss, self.trainable_variables)
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+
+ # Update and return metrics (specifically the one for the loss value).
+ for metric in self.metrics:
+ # Calling `self.compile` will by default add a `keras.metrics.Mean` loss
+ if metric.name == "loss":
+ metric.update_state(loss)
+ else:
+ metric.update_state(sparse_labels, similarities)
+
+ return {m.name: m.result() for m in self.metrics}
+
+
+"""
+Next we describe the architecture that maps from an image to an embedding. This model
+simply consists of a sequence of 2d convolutions followed by global pooling with a final
+linear projection to an embedding space. As is common in metric learning we normalise the
+embeddings so that we can use simple dot products to measure similarity. For simplicity
+this model is intentionally small.
+"""
+
+inputs = layers.Input(shape=(height_width, height_width, 3))
+x = layers.Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
+x = layers.Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
+x = layers.Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
+x = layers.GlobalAveragePooling2D()(x)
+embeddings = layers.Dense(units=8, activation=None)(x)
+embeddings = layers.UnitNormalization()(embeddings)
+
+model = EmbeddingModel(inputs, embeddings)
+
+"""
+Finally we run the training. On a Google Colab GPU instance this takes about a minute.
+"""
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=1e-3),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+)
+
+history = model.fit(AnchorPositivePairs(num_batches=1000), epochs=20)
+
+plt.plot(history.history["loss"])
+plt.show()
+
+"""
+## Testing
+
+We can review the quality of this model by applying it to the test set and considering
+near neighbours in the embedding space.
+
+First we embed the test set and calculate all near neighbours. Recall that since the
+embeddings are unit length we can calculate cosine similarity via dot products.
+"""
+
+near_neighbours_per_example = 10
+
+embeddings = model.predict(x_test)
+gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
+near_neighbours = np.argsort(gram_matrix.T)[:, -(near_neighbours_per_example + 1) :]
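+# Note: `argsort` sorts in ascending order, so the last entry of each row is the
+# example itself (self-similarity is maximal). We therefore keep one extra
+# neighbour here and drop the self-match later.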
+
+"""
+As a visual check of these embeddings we can build a collage of the near neighbours for 5
+random examples. The first column of the image below is a randomly selected image, the
+following 10 columns show the nearest neighbours in order of similarity.
+"""
+
+num_collage_examples = 5
+
+examples = np.empty(
+ (
+ num_collage_examples,
+ near_neighbours_per_example + 1,
+ height_width,
+ height_width,
+ 3,
+ ),
+ dtype=np.float32,
+)
+for row_idx in range(num_collage_examples):
+ examples[row_idx, 0] = x_test[row_idx]
+ anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])
+ for col_idx, nn_idx in enumerate(anchor_near_neighbours):
+ examples[row_idx, col_idx + 1] = x_test[nn_idx]
+
+show_collage(examples)
+
+"""
+We can also get a quantified view of the performance by considering the correctness of
+near neighbours in terms of a confusion matrix.
+
+Let us sample 10 examples from each of the 10 classes and consider their near neighbours
+as a form of prediction; that is, does the example and its near neighbours share the same
+class?
+
+We observe that each animal class generally does well, and is confused the most with
+the other animal classes. The vehicle classes follow the same pattern.
+"""
+
+confusion_matrix = np.zeros((num_classes, num_classes))
+
+# For each class.
+for class_idx in range(num_classes):
+ # Consider 10 examples.
+ example_idxs = class_idx_to_test_idxs[class_idx][:10]
+ for y_test_idx in example_idxs:
+ # And count the classes of its near neighbours.
+ for nn_idx in near_neighbours[y_test_idx][:-1]:
+ nn_class_idx = y_test[nn_idx]
+ confusion_matrix[class_idx, nn_class_idx] += 1
+
+# Display a confusion matrix.
+labels = [
+ "Airplane",
+ "Automobile",
+ "Bird",
+ "Cat",
+ "Deer",
+ "Dog",
+ "Frog",
+ "Horse",
+ "Ship",
+ "Truck",
+]
+disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
+disp.plot(include_values=True, cmap="viridis", ax=None, xticks_rotation="vertical")
+plt.show()
diff --git a/knowledge_base/vision/metric_learning_tf_similarity.py b/knowledge_base/vision/metric_learning_tf_similarity.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffcfc1c4dc0c154ffbfd58d4f3ad7bcb374d4c7b
--- /dev/null
+++ b/knowledge_base/vision/metric_learning_tf_similarity.py
@@ -0,0 +1,428 @@
+"""
+Title: Metric learning for image similarity search using TensorFlow Similarity
+Author: [Owen Vallis](https://twitter.com/owenvallis)
+Date created: 2021/09/30
+Last modified: 2022/02/29
+Description: Example of using similarity metric learning on CIFAR-10 images.
+Accelerator: GPU
+"""
+
+"""
+## Overview
+
+This example is based on the
+["Metric learning for image similarity search" example](https://keras.io/examples/vision/metric_learning/).
+We aim to use the same data set but implement the model using
+[TensorFlow Similarity](https://github.com/tensorflow/similarity).
+
+Metric learning aims to train models that can embed inputs into a
+high-dimensional space such that "similar" inputs are pulled closer to each
+other and "dissimilar" inputs are pushed farther apart. Once trained, these
+models can produce embeddings for downstream systems where such similarity is
+useful, for instance as a ranking signal for search or as a form of pretrained
+embedding model for another supervised problem.
+
+For a more detailed overview of metric learning, see:
+
+* [What is metric learning?](http://contrib.scikit-learn.org/metric-learn/introduction.html)
+* ["Using crossentropy for metric learning" tutorial](https://www.youtube.com/watch?v=Jb4Ewl5RzkI)
+"""
+
+"""
+## Setup
+
+This tutorial will use the [TensorFlow Similarity](https://github.com/tensorflow/similarity) library
+to learn and evaluate the similarity embedding.
+TensorFlow Similarity provides components that:
+
+* Make training contrastive models simple and fast.
+* Make it easier to ensure that batches contain pairs of examples.
+* Enable the evaluation of the quality of the embedding.
+
+TensorFlow Similarity can be installed easily via pip, as follows:
+
+```
+pip -q install tensorflow_similarity
+```
+
+"""
+
+import random
+
+from matplotlib import pyplot as plt
+from mpl_toolkits import axes_grid1
+import numpy as np
+
+import tensorflow as tf
+from tensorflow import keras
+
+import tensorflow_similarity as tfsim
+
+
+tfsim.utils.tf_cap_memory()
+
+print("TensorFlow:", tf.__version__)
+print("TensorFlow Similarity:", tfsim.__version__)
+
+"""
+## Dataset samplers
+
+We will be using the
+[CIFAR-10](https://www.tensorflow.org/datasets/catalog/cifar10)
+dataset for this tutorial.
+
+For a similarity model to learn efficiently, each batch must contain at least 2
+examples of each class.
+
+To make this easy, TensorFlow Similarity offers `Sampler` objects that enable you to set both
+the number of classes and the minimum number of examples of each class per
+batch.
+
+The training and validation datasets will be created using the
+`TFDatasetMultiShotMemorySampler` object. This creates a sampler that loads datasets
+from [TensorFlow Datasets](https://www.tensorflow.org/datasets) and yields
+batches containing a target number of classes and a target number of examples
+per class. Additionally, we can restrict the sampler to only yield the subset of
+classes defined in `class_list`, enabling us to train on a subset of the classes
+and then test how the embedding generalizes to the unseen classes. This can be
+useful when working on few-shot learning problems.
+
+The following cell creates a `train_ds` sampler that:
+
+* Loads the CIFAR-10 dataset from TFDS.
+* Restricts the sampler to the classes defined in `class_list`.
+* Ensures each batch contains 10 different classes with
+`examples_per_class_per_batch` (8) examples each.
+
+We also create a validation dataset in the same way, but we limit the total number of
+examples per class to 100 and the examples per class per batch is set to the
+default of 2.
+"""
+# This determines the number of classes used during training.
+# Here we are using all the classes.
+num_known_classes = 10
+class_list = random.sample(population=range(10), k=num_known_classes)
+
+classes_per_batch = 10
+# Passing multiple examples per class per batch ensures that each example has
+# multiple positive pairs. This can be useful when performing triplet mining or
+# when using losses like `MultiSimilarityLoss` or `CircleLoss` as these can
+# take a weighted mix of all the positive pairs. In general, more examples per
+# class will lead to more information for the positive pairs, while more classes
+# per batch will provide more varied information in the negative pairs. However,
+# the losses compute the pairwise distance between the examples in a batch so
+# the upper limit of the batch size is restricted by the memory.
+examples_per_class_per_batch = 8
+
+print(
+ "Batch size is: "
+ f"{min(classes_per_batch, num_known_classes) * examples_per_class_per_batch}"
+)
+
+print(" Create Training Data ".center(34, "#"))
+train_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
+ "cifar10",
+ classes_per_batch=min(classes_per_batch, num_known_classes),
+ splits="train",
+ steps_per_epoch=4000,
+ examples_per_class_per_batch=examples_per_class_per_batch,
+ class_list=class_list,
+)
+
+print("\n" + " Create Validation Data ".center(34, "#"))
+val_ds = tfsim.samplers.TFDatasetMultiShotMemorySampler(
+ "cifar10",
+ classes_per_batch=classes_per_batch,
+ splits="test",
+ total_examples_per_class=100,
+)
+
+"""
+## Visualize the dataset
+
+The samplers will shuffle the dataset, so we can get a sense of the dataset by
+plotting the first 25 images.
+
+The samplers provide a `get_slice(begin, size)` method that allows us to easily
+select a block of samples.
+
+Alternatively, we can use the `generate_batch()` method to yield a batch. This
+can allow us to check that a batch contains the expected number of classes and
+examples per class.
+"""
+
+num_cols = num_rows = 5
+# Get the first 25 examples.
+x_slice, y_slice = train_ds.get_slice(begin=0, size=num_cols * num_rows)
+
+fig = plt.figure(figsize=(6.0, 6.0))
+grid = axes_grid1.ImageGrid(fig, 111, nrows_ncols=(num_cols, num_rows), axes_pad=0.1)
+
+for ax, im, label in zip(grid, x_slice, y_slice):
+ ax.imshow(im)
+ ax.axis("off")
+
+"""
+## Embedding model
+
+Next we define a `SimilarityModel` using the Keras Functional API. The model
+is a standard convnet with the addition of a `MetricEmbedding` layer that
+applies L2 normalization. The metric embedding layer is helpful when using
+`Cosine` distance as we only care about the angle between the vectors.
+
+Additionally, the `SimilarityModel` provides a number of helper methods for:
+
+* Indexing embedded examples
+* Performing example lookups
+* Evaluating the classification
+* Evaluating the quality of the embedding space
+
+See the [TensorFlow Similarity documentation](https://github.com/tensorflow/similarity)
+for more details.
+"""
+
+embedding_size = 256
+
+inputs = keras.layers.Input((32, 32, 3))
+x = keras.layers.Rescaling(scale=1.0 / 255)(inputs)
+x = keras.layers.Conv2D(64, 3, activation="relu")(x)
+x = keras.layers.BatchNormalization()(x)
+x = keras.layers.Conv2D(128, 3, activation="relu")(x)
+x = keras.layers.BatchNormalization()(x)
+x = keras.layers.MaxPool2D((4, 4))(x)
+x = keras.layers.Conv2D(256, 3, activation="relu")(x)
+x = keras.layers.BatchNormalization()(x)
+x = keras.layers.Conv2D(256, 3, activation="relu")(x)
+x = keras.layers.GlobalMaxPool2D()(x)
+outputs = tfsim.layers.MetricEmbedding(embedding_size)(x)
+
+# Build the model.
+model = tfsim.models.SimilarityModel(inputs, outputs)
+model.summary()
+
+"""
+## Similarity loss
+
+The similarity loss expects batches containing at least 2 examples of each
+class, from which it computes the loss over the pairwise positive and negative
+distances. Here we are using `MultiSimilarityLoss()`
+([paper](https://arxiv.org/abs/1904.06627)), one of several losses in
+[TensorFlow Similarity](https://github.com/tensorflow/similarity). This loss
+attempts to use all informative pairs in the batch, taking into account the
+self-similarity, positive-similarity, and the negative-similarity.
+"""
+
+epochs = 3
+learning_rate = 0.002
+val_steps = 50
+
+# Initialize the similarity loss.
+loss = tfsim.losses.MultiSimilarityLoss()
+
+# Compile and train the model.
+model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=loss,
+ steps_per_execution=10,
+)
+history = model.fit(
+ train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps
+)
+
+"""
+## Indexing
+
+Now that we have trained our model, we can create an index of examples. Here we
+batch index the first 200 validation examples by passing the x and y to the index
+along with storing the image in the data parameter. The `x_index` values are
+embedded and then added to the index to make them searchable. The `y_index` and
+data parameters are optional but allow the user to associate metadata with the
+embedded example.
+"""
+
+x_index, y_index = val_ds.get_slice(begin=0, size=200)
+model.reset_index()
+model.index(x_index, y_index, data=x_index)
+
+"""
+## Calibration
+
+Once the index is built, we can calibrate a distance threshold using a matching
+strategy and a calibration metric.
+
+Here we are searching for the optimal F1 score while using K=1 as our
+classifier. All matches at or below the calibrated threshold distance will be
+labeled as a Positive match between the query example and the label associated
+with the match result, while all matches above the threshold distance will be
+labeled as a Negative match.
+
+Additionally, we pass in extra metrics to compute as well. All values in the
+output are computed at the calibrated threshold.
+
+Finally, `model.calibrate()` returns a `CalibrationResults` object containing:
+
+* `"cutpoints"`: A Python dict mapping the cutpoint name to a dict containing the
+`ClassificationMetric` values associated with a particular distance threshold,
+e.g., `"optimal" : {"acc": 0.90, "f1": 0.92}`.
+* `"thresholds"`: A Python dict mapping `ClassificationMetric` names to a list
+containing the metric's value computed at each of the distance thresholds, e.g.,
+`{"f1": [0.99, 0.80], "distance": [0.0, 1.0]}`.
+"""
+
+x_train, y_train = train_ds.get_slice(begin=0, size=1000)
+calibration = model.calibrate(
+ x_train,
+ y_train,
+ calibration_metric="f1",
+ matcher="match_nearest",
+ extra_metrics=["precision", "recall", "binary_accuracy"],
+ verbose=1,
+)
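+
+# The metrics at the calibrated threshold can be inspected directly, e.g. the
+# "optimal" cutpoint (a sketch based on the `CalibrationResults` structure
+# described above).
+print(calibration.cutpoints["optimal"])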
+
+"""
+## Visualization
+
+It may be difficult to get a sense of the model quality from the metrics alone.
+A complementary approach is to manually inspect a set of query results to get a
+feel for the match quality.
+
+Here we take 10 validation examples and plot them with their 5 nearest
+neighbors and the distances to the query example. Looking at the results, we see
+that while they are imperfect they still represent meaningfully similar images,
+and that the model is able to find similar images irrespective of their pose or
+image illumination.
+
+We can also see that the model is very confident with certain images, resulting
+in very small distances between the query and the neighbors. Conversely, we see
+more mistakes in the class labels as the distances become larger. This is one of
+the reasons why calibration is critical for matching applications.
+"""
+
+num_neighbors = 5
+labels = [
+ "Airplane",
+ "Automobile",
+ "Bird",
+ "Cat",
+ "Deer",
+ "Dog",
+ "Frog",
+ "Horse",
+ "Ship",
+ "Truck",
+ "Unknown",
+]
+class_mapping = {c_id: c_lbl for c_id, c_lbl in zip(range(11), labels)}
+
+x_display, y_display = val_ds.get_slice(begin=200, size=10)
+# lookup nearest neighbors in the index
+nns = model.lookup(x_display, k=num_neighbors)
+
+# display
+for idx in np.argsort(y_display):
+ tfsim.visualization.viz_neigbors_imgs(
+ x_display[idx],
+ y_display[idx],
+ nns[idx],
+ class_mapping=class_mapping,
+ fig_size=(16, 2),
+ )
+
+"""
+## Metrics
+
+We can also plot the extra metrics contained in the `CalibrationResults` to get
+a sense of the matching performance as the distance threshold increases.
+
+The following plots show the Precision, Recall, and F1 Score. We can see that
+the matching precision degrades as the distance increases, but that the
+percentage of the queries that we accept as positive matches (recall) grows
+faster up to the calibrated distance threshold.
+"""
+
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+x = calibration.thresholds["distance"]
+
+ax1.plot(x, calibration.thresholds["precision"], label="precision")
+ax1.plot(x, calibration.thresholds["recall"], label="recall")
+ax1.plot(x, calibration.thresholds["f1"], label="f1 score")
+ax1.legend()
+ax1.set_title("Metric evolution as distance increase")
+ax1.set_xlabel("Distance")
+ax1.set_ylim((-0.05, 1.05))
+
+ax2.plot(calibration.thresholds["recall"], calibration.thresholds["precision"])
+ax2.set_title("Precision recall curve")
+ax2.set_xlabel("Recall")
+ax2.set_ylabel("Precision")
+ax2.set_ylim((-0.05, 1.05))
+plt.show()
+
+"""
+We can also take 100 examples for each class and plot the confusion matrix for
+each example and its nearest match. We also add an extra "no match" class
+(label 10) to represent the matches above the calibrated distance threshold.
+
+We can see that most of the errors are between the animal classes with an
+interesting number of confusions between Airplane and Bird. Additionally, we see
+that only a few of the 100 examples for each class returned matches outside of
+the calibrated distance threshold.
+"""
+
+cutpoint = "optimal"
+
+# This yields 100 examples for each class.
+# We defined this when we created the val_ds sampler.
+x_confusion, y_confusion = val_ds.get_slice(0, -1)
+
+matches = model.match(x_confusion, cutpoint=cutpoint, no_match_label=10)
+cm = tfsim.visualization.confusion_matrix(
+ matches,
+ y_confusion,
+ labels=labels,
+    title=f"Confusion matrix for cutpoint: {cutpoint}",
+ normalize=False,
+)
+
+"""
+## No Match
+
+We can plot the examples outside of the calibrated threshold to see which images
+are not matching any indexed examples.
+
+This may provide insight into what other examples may need to be indexed or
+surface anomalous examples within the class.
+"""
+
+idx_no_match = np.where(np.array(matches) == 10)
+no_match_queries = x_confusion[idx_no_match]
+if len(no_match_queries):
+ plt.imshow(no_match_queries[0])
+else:
+ print("All queries have a match below the distance threshold.")
+
+"""
+## Visualize clusters
+
+One of the best ways to quickly get a sense of how well the model is doing, and
+to understand its shortcomings, is to project the embeddings into a 2D space.
+
+This allows us to inspect clusters of images and understand which classes are
+entangled.
+"""
+
+# Each class in val_ds was restricted to 100 examples.
+num_examples_to_clusters = 1000
+thumb_size = 96
+plot_size = 800
+vx, vy = val_ds.get_slice(0, num_examples_to_clusters)
+
+# Uncomment to run the interactive projector.
+# tfsim.visualization.projector(
+# model.predict(vx),
+# labels=vy,
+# images=vx,
+# class_mapping=class_mapping,
+# image_size=thumb_size,
+# plot_size=plot_size,
+# )
diff --git a/knowledge_base/vision/mirnet.py b/knowledge_base/vision/mirnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1099cd042cb8c87527fabc52e5b0d9794a1ef1
--- /dev/null
+++ b/knowledge_base/vision/mirnet.py
@@ -0,0 +1,507 @@
+"""
+Title: Low-light image enhancement using MIRNet
+Author: [Soumik Rakshit](http://github.com/soumik12345)
+Date created: 2021/09/11
+Last modified: 2023/07/15
+Description: Implementing the MIRNet architecture for low-light image enhancement.
+Accelerator: GPU
+Converted to Keras 3 by: [Soumik Rakshit](http://github.com/soumik12345)
+"""
+
+"""
+## Introduction
+
+With the goal of recovering high-quality image content from its degraded version, image
+restoration enjoys numerous applications, such as in
+photography, security, medical imaging, and remote sensing. In this example, we implement the
+**MIRNet** model for low-light image enhancement, a fully-convolutional architecture that
+learns an enriched set of
+features that combines contextual information from multiple scales, while
+simultaneously preserving the high-resolution spatial details.
+
+### References:
+
+- [Learning Enriched Features for Real Image Restoration and Enhancement](https://arxiv.org/abs/2003.06792)
+- [The Retinex Theory of Color Vision](http://www.cnbc.cmu.edu/~tai/cp_papers/E.Land_Retinex_Theory_ScientifcAmerican.pdf)
+- [Two deterministic half-quadratic regularization algorithms for computed imaging](https://ieeexplore.ieee.org/document/413553)
+"""
+
+"""
+## Downloading LOLDataset
+
+The **LoL Dataset** has been created for low-light image enhancement.
+It provides 485 images for training and 15 for testing. Each image pair in the dataset
+consists of a low-light input image and its corresponding well-exposed reference image.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import random
+import numpy as np
+from glob import glob
+from PIL import Image, ImageOps
+import matplotlib.pyplot as plt
+
+import keras
+from keras import layers
+
+import tensorflow as tf
+
+"""shell
+wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
+unzip -q lol_dataset.zip && rm lol_dataset.zip
+"""
+
+"""
+## Creating a TensorFlow Dataset
+
+We use 300 image pairs from the LoL Dataset's training set for training,
+and we use the remaining 185 image pairs for validation.
+We generate random crops of size `128 x 128` from the image pairs to be
+used for both training and validation.
+"""
+
+random.seed(10)
+
+IMAGE_SIZE = 128
+BATCH_SIZE = 4
+MAX_TRAIN_IMAGES = 300
+
+
+def read_image(image_path):
+ image = tf.io.read_file(image_path)
+ image = tf.image.decode_png(image, channels=3)
+ image.set_shape([None, None, 3])
+ image = tf.cast(image, dtype=tf.float32) / 255.0
+ return image
+
+
+def random_crop(low_image, enhanced_image):
+ low_image_shape = tf.shape(low_image)[:2]
+ low_w = tf.random.uniform(
+ shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
+ )
+ low_h = tf.random.uniform(
+ shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
+ )
+ low_image_cropped = low_image[
+ low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
+ ]
+ enhanced_image_cropped = enhanced_image[
+ low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
+ ]
+    # In order to avoid `None` during shape inference.
+ low_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
+ enhanced_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
+ return low_image_cropped, enhanced_image_cropped
+
+
+def load_data(low_light_image_path, enhanced_image_path):
+ low_light_image = read_image(low_light_image_path)
+ enhanced_image = read_image(enhanced_image_path)
+ low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
+ return low_light_image, enhanced_image
+
+
+def get_dataset(low_light_images, enhanced_images):
+ dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
+ dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
+ dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
+ return dataset
+
+
+train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
+train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]
+
+val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
+val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]
+
+test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))
+test_enhanced_images = sorted(glob("./lol_dataset/eval15/high/*"))
+
+
+train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
+val_dataset = get_dataset(val_low_light_images, val_enhanced_images)
+
+
+print("Train Dataset:", train_dataset.element_spec)
+print("Val Dataset:", val_dataset.element_spec)
+
+"""
+## MIRNet Model
+
+Here are the main features of the MIRNet model:
+
+- A feature extraction model that computes a complementary set of features across multiple
+spatial scales, while maintaining the original high-resolution features to preserve
+precise spatial details.
+- A regularly repeated mechanism for information exchange, where the features across
+multi-resolution branches are progressively fused together for improved representation
+learning.
+- A new approach to fuse multi-scale features using a selective kernel network
+that dynamically combines variable receptive fields and faithfully preserves
+the original feature information at each spatial resolution.
+- A recursive residual design that progressively breaks down the input signal
+in order to simplify the overall learning process, and allows the construction
+of very deep networks.
+
+
+
+"""
+
+"""
+### Selective Kernel Feature Fusion
+
+The Selective Kernel Feature Fusion or SKFF module performs dynamic adjustment of
+receptive fields via two operations: **Fuse** and **Select**. The Fuse operator generates
+global feature descriptors by combining the information from multi-resolution streams.
+The Select operator uses these descriptors to recalibrate the feature maps (of different
+streams) followed by their aggregation.
+
+**Fuse**: The SKFF receives inputs from three parallel convolution streams carrying
+different scales of information. We first combine these multi-scale features using an
+element-wise sum, on which we apply Global Average Pooling (GAP) across the spatial
+dimension. Next, we apply a channel-downscaling convolution layer to generate a compact
+feature representation which passes through three parallel channel-upscaling convolution
+layers (one for each resolution stream) and provides us with three feature descriptors.
+
+**Select**: This operator applies the softmax function to the feature descriptors to
+obtain the corresponding activations that are used to adaptively recalibrate multi-scale
+feature maps. The aggregated features are defined as the sum of the products of the
+corresponding multi-scale features and feature descriptors.
+
+
+"""
+
+
+def selective_kernel_feature_fusion(
+ multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
+):
+ channels = list(multi_scale_feature_1.shape)[-1]
+ combined_feature = layers.Add()(
+ [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
+ )
+ gap = layers.GlobalAveragePooling2D()(combined_feature)
+ channel_wise_statistics = layers.Reshape((1, 1, channels))(gap)
+ compact_feature_representation = layers.Conv2D(
+ filters=channels // 8, kernel_size=(1, 1), activation="relu"
+ )(channel_wise_statistics)
+ feature_descriptor_1 = layers.Conv2D(
+ channels, kernel_size=(1, 1), activation="softmax"
+ )(compact_feature_representation)
+ feature_descriptor_2 = layers.Conv2D(
+ channels, kernel_size=(1, 1), activation="softmax"
+ )(compact_feature_representation)
+ feature_descriptor_3 = layers.Conv2D(
+ channels, kernel_size=(1, 1), activation="softmax"
+ )(compact_feature_representation)
+ feature_1 = multi_scale_feature_1 * feature_descriptor_1
+ feature_2 = multi_scale_feature_2 * feature_descriptor_2
+ feature_3 = multi_scale_feature_3 * feature_descriptor_3
+ aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
+ return aggregated_feature
+
+
+"""
+### Dual Attention Unit
+
+The Dual Attention Unit or DAU is used to extract features in the convolutional streams.
+While the SKFF block fuses information across multi-resolution branches, we also need a
+mechanism to share information within a feature tensor, along both the spatial and the
+channel dimensions; this is what the DAU block does. The DAU suppresses less useful
+features and only allows more informative ones to pass further. This feature
+recalibration is achieved by using **Channel Attention** and **Spatial Attention**
+mechanisms.
+
+The **Channel Attention** branch exploits the inter-channel relationships of the
+convolutional feature maps by applying squeeze and excitation operations. Given a feature
+map, the squeeze operation applies Global Average Pooling across spatial dimensions to
+encode global context, thus yielding a feature descriptor. The excitation operator passes
+this feature descriptor through two convolutional layers followed by the sigmoid gating
+and generates activations. Finally, the output of Channel Attention branch is obtained by
+rescaling the input feature map with the output activations.
+
+The **Spatial Attention** branch is designed to exploit the inter-spatial dependencies of
+convolutional features. The goal of Spatial Attention is to generate a spatial attention
+map and use it to recalibrate the incoming features. To generate the spatial attention
+map, the Spatial Attention branch first independently applies Global Average Pooling and
+Max Pooling operations on input features along the channel dimensions and concatenates
+the outputs to form a resultant feature map which is then passed through a convolution
+and sigmoid activation to obtain the spatial attention map. This spatial attention map is
+then used to rescale the input feature map.
+
+
+"""
+
+
+class ChannelPooling(layers.Layer):
+ def __init__(self, axis=-1, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.axis = axis
+ self.concat = layers.Concatenate(axis=self.axis)
+
+ def call(self, inputs):
+ average_pooling = tf.expand_dims(tf.reduce_mean(inputs, axis=-1), axis=-1)
+ max_pooling = tf.expand_dims(tf.reduce_max(inputs, axis=-1), axis=-1)
+ return self.concat([average_pooling, max_pooling])
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"axis": self.axis})
+        return config
+
+def spatial_attention_block(input_tensor):
+ compressed_feature_map = ChannelPooling(axis=-1)(input_tensor)
+ feature_map = layers.Conv2D(1, kernel_size=(1, 1))(compressed_feature_map)
+ feature_map = keras.activations.sigmoid(feature_map)
+ return input_tensor * feature_map
+
+
+def channel_attention_block(input_tensor):
+ channels = list(input_tensor.shape)[-1]
+ average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
+ feature_descriptor = layers.Reshape((1, 1, channels))(average_pooling)
+ feature_activations = layers.Conv2D(
+ filters=channels // 8, kernel_size=(1, 1), activation="relu"
+ )(feature_descriptor)
+ feature_activations = layers.Conv2D(
+ filters=channels, kernel_size=(1, 1), activation="sigmoid"
+ )(feature_activations)
+ return input_tensor * feature_activations
+
+
+def dual_attention_unit_block(input_tensor):
+ channels = list(input_tensor.shape)[-1]
+ feature_map = layers.Conv2D(
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
+ )(input_tensor)
+ feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
+ feature_map
+ )
+ channel_attention = channel_attention_block(feature_map)
+ spatial_attention = spatial_attention_block(feature_map)
+ concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
+ concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
+ return layers.Add()([input_tensor, concatenation])
+
+
+"""
+### Multi-Scale Residual Block
+
+The Multi-Scale Residual Block is capable of generating a spatially-precise output by
+maintaining high-resolution representations, while receiving rich contextual information
+from low-resolutions. The MRB consists of multiple (three in this paper)
+fully-convolutional streams connected in parallel. It allows information exchange across
+parallel streams in order to consolidate the high-resolution features with the help of
+low-resolution features, and vice versa. The MIRNet employs a recursive residual design
+(with skip connections) to ease the flow of information during the learning process. In
+order to maintain the residual nature of the architecture, residual resizing modules are
+used to perform the downsampling and upsampling operations inside the Multi-Scale
+Residual Block.
+
+
+"""
+
+# Recursive Residual Modules
+
+
+def down_sampling_module(input_tensor):
+ channels = list(input_tensor.shape)[-1]
+ main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+ input_tensor
+ )
+ main_branch = layers.Conv2D(
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
+ )(main_branch)
+ main_branch = layers.MaxPooling2D()(main_branch)
+ main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
+ skip_branch = layers.MaxPooling2D()(input_tensor)
+ skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
+ return layers.Add()([skip_branch, main_branch])
+
+
+def up_sampling_module(input_tensor):
+ channels = list(input_tensor.shape)[-1]
+ main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
+ input_tensor
+ )
+ main_branch = layers.Conv2D(
+ channels, kernel_size=(3, 3), padding="same", activation="relu"
+ )(main_branch)
+ main_branch = layers.UpSampling2D()(main_branch)
+ main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
+ skip_branch = layers.UpSampling2D()(input_tensor)
+ skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
+ return layers.Add()([skip_branch, main_branch])
+
+
+# MRB Block
+def multi_scale_residual_block(input_tensor, channels):
+ # features
+ level1 = input_tensor
+ level2 = down_sampling_module(input_tensor)
+ level3 = down_sampling_module(level2)
+ # DAU
+ level1_dau = dual_attention_unit_block(level1)
+ level2_dau = dual_attention_unit_block(level2)
+ level3_dau = dual_attention_unit_block(level3)
+ # SKFF
+ level1_skff = selective_kernel_feature_fusion(
+ level1_dau,
+ up_sampling_module(level2_dau),
+ up_sampling_module(up_sampling_module(level3_dau)),
+ )
+ level2_skff = selective_kernel_feature_fusion(
+ down_sampling_module(level1_dau),
+ level2_dau,
+ up_sampling_module(level3_dau),
+ )
+ level3_skff = selective_kernel_feature_fusion(
+ down_sampling_module(down_sampling_module(level1_dau)),
+ down_sampling_module(level2_dau),
+ level3_dau,
+ )
+ # DAU 2
+ level1_dau_2 = dual_attention_unit_block(level1_skff)
+    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
+ level3_dau_2 = up_sampling_module(
+ up_sampling_module(dual_attention_unit_block(level3_skff))
+ )
+ # SKFF 2
+ skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
+ conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
+ return layers.Add()([input_tensor, conv])
+
+
+"""
+### MIRNet Model
+"""
+
+
+def recursive_residual_group(input_tensor, num_mrb, channels):
+ conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+ for _ in range(num_mrb):
+ conv1 = multi_scale_residual_block(conv1, channels)
+ conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
+ return layers.Add()([conv2, input_tensor])
+
+
+def mirnet_model(num_rrg, num_mrb, channels):
+ input_tensor = keras.Input(shape=[None, None, 3])
+ x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
+ for _ in range(num_rrg):
+ x1 = recursive_residual_group(x1, num_mrb, channels)
+ conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
+ output_tensor = layers.Add()([input_tensor, conv])
+ return keras.Model(input_tensor, output_tensor)
+
+
+model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
+
+"""
+## Training
+
+- We train MIRNet using **Charbonnier Loss** as the loss function and the **Adam
+optimizer** with a learning rate of `1e-4`.
+- We use **Peak Signal-to-Noise Ratio** (PSNR) as a metric, which expresses the ratio
+between the maximum possible value (power) of a signal and the power of the distorting
+noise that affects the quality of its representation.
+"""
+
+
+def charbonnier_loss(y_true, y_pred):
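+    # Charbonnier loss is a smooth L1 variant: sqrt(diff^2 + eps^2) with
+    # eps = 1e-3, which stays differentiable at zero unlike plain |diff|.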
+ return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))
+
+
+def peak_signal_noise_ratio(y_true, y_pred):
+ return tf.image.psnr(y_pred, y_true, max_val=255.0)
+
+
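+# A quick, illustrative sanity check of the metric (ours, not part of the original
+# pipeline): a uniform error of 5 intensity levels gives MSE = 25, so
+# PSNR = 10 * log10(255^2 / 25) ≈ 34.15 dB.
+_demo_true = tf.ones((1, 8, 8, 3)) * 255.0
+_demo_pred = _demo_true - 5.0
+print(peak_signal_noise_ratio(_demo_true, _demo_pred).numpy())
+
+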
+optimizer = keras.optimizers.Adam(learning_rate=1e-4)
+model.compile(
+ optimizer=optimizer,
+ loss=charbonnier_loss,
+ metrics=[peak_signal_noise_ratio],
+)
+
+history = model.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=50,
+ callbacks=[
+ keras.callbacks.ReduceLROnPlateau(
+ monitor="val_peak_signal_noise_ratio",
+ factor=0.5,
+ patience=5,
+ verbose=1,
+ min_delta=1e-7,
+ mode="max",
+ )
+ ],
+)
+
+
+def plot_history(value, name):
+ plt.plot(history.history[value], label=f"train_{name.lower()}")
+ plt.plot(history.history[f"val_{value}"], label=f"val_{name.lower()}")
+ plt.xlabel("Epochs")
+ plt.ylabel(name)
+ plt.title(f"Train and Validation {name} Over Epochs", fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_history("loss", "Loss")
+plot_history("peak_signal_noise_ratio", "PSNR")
+
+"""
+## Inference
+"""
+
+
+def plot_results(images, titles, figure_size=(12, 12)):
+ fig = plt.figure(figsize=figure_size)
+ for i in range(len(images)):
+ fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
+ _ = plt.imshow(images[i])
+ plt.axis("off")
+ plt.show()
+
+
+def infer(original_image):
+ image = keras.utils.img_to_array(original_image)
+ image = image.astype("float32") / 255.0
+ image = np.expand_dims(image, axis=0)
+ output = model.predict(image, verbose=0)
+ output_image = output[0] * 255.0
+ output_image = output_image.clip(0, 255)
+ output_image = output_image.reshape(
+ (np.shape(output_image)[0], np.shape(output_image)[1], 3)
+ )
+ output_image = Image.fromarray(np.uint8(output_image))
+ return output_image
+
+
+"""
+### Inference on Test Images
+
+We compare the test images from LOLDataset enhanced by MIRNet with images
+enhanced via the `PIL.ImageOps.autocontrast()` function.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/lowlight-enhance-mirnet)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Enhance_Low_Light_Image).
+"""
+
+
+for low_light_image in random.sample(test_low_light_images, 6):
+ original_image = Image.open(low_light_image)
+ enhanced_image = infer(original_image)
+ plot_results(
+ [original_image, ImageOps.autocontrast(original_image), enhanced_image],
+ ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
+ (20, 12),
+ )
diff --git a/knowledge_base/vision/mixup.py b/knowledge_base/vision/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c9c049ee1a613be109f4db60a6053b9a876117
--- /dev/null
+++ b/knowledge_base/vision/mixup.py
@@ -0,0 +1,242 @@
+"""
+Title: MixUp augmentation for image classification
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/03/06
+Last modified: 2023/07/24
+Description: Data augmentation using the mixup technique for image classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+"""
+
+"""
+_mixup_ is a *domain-agnostic* data augmentation technique proposed in [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
+by Zhang et al. It's implemented with the following formulas:
+
+`new_x = lambda * x_i + (1 - lambda) * x_j` and
+`new_y = lambda * y_i + (1 - lambda) * y_j`,
+
+where `(x_i, y_i)` and `(x_j, y_j)` are two randomly drawn (image, label) pairs from the
+training data.
+
+(Note that the lambda values are within the [0, 1] range and are sampled from the
+[Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution).)
+
+The technique is quite systematically named. We are literally mixing up the features and
+their corresponding labels. Implementation-wise it's simple. Neural networks are prone
+to [memorizing corrupt labels](https://arxiv.org/abs/1611.03530). mixup relaxes this by
+combining different features with one another (and doing the same for the labels) so that
+a network does not get overconfident about the relationship between the features and
+their labels.
+
+mixup is especially useful when we are not sure about selecting a set of augmentation
+transforms for a given dataset (medical imaging datasets, for example). mixup can be
+extended to a variety of data modalities, such as computer vision, natural language
+processing, speech, and so on.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import numpy as np
+import keras
+import matplotlib.pyplot as plt
+
+from keras import layers
+
+# TF imports related to tf.data preprocessing
+from tensorflow import data as tf_data
+from tensorflow import image as tf_image
+from tensorflow.random import gamma as tf_random_gamma
+
+
+"""
+## Prepare the dataset
+
+In this example, we will be using the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. But this same recipe can
+be used for other classification datasets as well.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
+
+x_train = x_train.astype("float32") / 255.0
+x_train = np.reshape(x_train, (-1, 28, 28, 1))
+y_train = keras.ops.one_hot(y_train, 10)
+
+x_test = x_test.astype("float32") / 255.0
+x_test = np.reshape(x_test, (-1, 28, 28, 1))
+y_test = keras.ops.one_hot(y_test, 10)
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf_data.AUTOTUNE
+BATCH_SIZE = 64
+EPOCHS = 10
+
+"""
+## Convert the data into TensorFlow `Dataset` objects
+"""
+
+# Put aside a few samples to create our validation set
+val_samples = 2000
+x_val, y_val = x_train[:val_samples], y_train[:val_samples]
+new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]
+
+train_ds_one = (
+ tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
+ .shuffle(BATCH_SIZE * 100)
+ .batch(BATCH_SIZE)
+)
+train_ds_two = (
+ tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
+ .shuffle(BATCH_SIZE * 100)
+ .batch(BATCH_SIZE)
+)
+# Because we will be mixing up the images and their corresponding labels, we will be
+# combining two shuffled datasets from the same training data.
+train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))
+
+val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
+
+test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
+
+"""
+## Define the mixup technique function
+
+To perform the mixup routine, we create new virtual datasets using the training data from
+the same dataset, and apply a lambda value within the [0, 1] range sampled from a [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution)
+such that, for example, `new_x = lambda * x1 + (1 - lambda) * x2` (where
+`x1` and `x2` are images) and the same equation is applied to the labels as well.
+"""
+
+
+def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
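+    # A Beta(a, b) sample equals g1 / (g1 + g2), where g1 ~ Gamma(a) and
+    # g2 ~ Gamma(b); TF has no direct Beta sampler, so we build one from Gammas.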
+ gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
+ gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
+ return gamma_1_sample / (gamma_1_sample + gamma_2_sample)
+
+
+def mix_up(ds_one, ds_two, alpha=0.2):
+ # Unpack two datasets
+ images_one, labels_one = ds_one
+ images_two, labels_two = ds_two
+ batch_size = keras.ops.shape(images_one)[0]
+
+ # Sample lambda and reshape it to do the mixup
+ l = sample_beta_distribution(batch_size, alpha, alpha)
+ x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
+ y_l = keras.ops.reshape(l, (batch_size, 1))
+
+ # Perform mixup on both images and labels by combining a pair of images/labels
+ # (one from each dataset) into one image/label
+ images = images_one * x_l + images_two * (1 - x_l)
+ labels = labels_one * y_l + labels_two * (1 - y_l)
+ return (images, labels)
+
+
+"""
+**Note** that here, we are combining two images to create a single one. Theoretically,
+we can combine as many as we want, but that comes at an increased computation cost and,
+in certain cases, may not improve performance. A hypothetical three-sample sketch is
+shown below.
+"""
+
+"""
+## Visualize the new augmented dataset
+"""
+
+# First create the new dataset using our `mix_up` utility
+train_ds_mu = train_ds.map(
+ lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
+ num_parallel_calls=AUTO,
+)
+
+# Let's preview 9 samples from the dataset
+sample_images, sample_labels = next(iter(train_ds_mu))
+plt.figure(figsize=(10, 10))
+for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().squeeze())
+ print(label.numpy().tolist())
+ plt.axis("off")
+
+"""
+## Model building
+"""
+
+
+def get_training_model():
+ model = keras.Sequential(
+ [
+ layers.Input(shape=(28, 28, 1)),
+ layers.Conv2D(16, (5, 5), activation="relu"),
+ layers.MaxPooling2D(pool_size=(2, 2)),
+ layers.Conv2D(32, (5, 5), activation="relu"),
+ layers.MaxPooling2D(pool_size=(2, 2)),
+ layers.Dropout(0.2),
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(128, activation="relu"),
+ layers.Dense(10, activation="softmax"),
+ ]
+ )
+ return model
+
+
+"""
+For the sake of reproducibility, we serialize the initial random weights of our shallow
+network.
+"""
+
+initial_model = get_training_model()
+initial_model.save_weights("initial_weights.weights.h5")
+
+"""
+## 1. Train the model with the mixed up dataset
+"""
+
+model = get_training_model()
+model.load_weights("initial_weights.weights.h5")
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
+_, test_acc = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+## 2. Train the model *without* the mixed up dataset
+"""
+
+model = get_training_model()
+model.load_weights("initial_weights.weights.h5")
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+# Notice that we are NOT using the mixed up dataset here
+model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
+_, test_acc = model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+Readers are encouraged to try out mixup on different datasets from different domains and
+experiment with the lambda parameter. You are strongly advised to check out the
+[original paper](https://arxiv.org/abs/1710.09412) as well; the authors present several ablation studies on mixup
+showing how it can improve generalization, as well as show their results of combining
+more than two images to create a single one.
+"""
+
+"""
+## Notes
+
+* With mixup, you can create synthetic examples, especially when you lack a large
+dataset, without incurring high computational costs.
+* [Label smoothing](https://www.pyimagesearch.com/2019/12/30/label-smoothing-with-keras-tensorflow-and-deep-learning/) and mixup usually do not work well together because label smoothing
+already modifies the hard labels by some factor.
+* mixup does not work well when you are using [Supervised Contrastive
+Learning](https://arxiv.org/abs/2004.11362) (SCL) since SCL expects the true labels
+during its pre-training phase.
+* A few other benefits of mixup include (as described in the [paper](https://arxiv.org/abs/1710.09412)) robustness to
+adversarial examples and stabilized GAN (Generative Adversarial Networks) training.
+* There are a number of data augmentation techniques that extend mixup such as
+[CutMix](https://arxiv.org/abs/1905.04899) and [AugMix](https://arxiv.org/abs/1912.02781).
+"""
diff --git a/knowledge_base/vision/mlp_image_classification.py b/knowledge_base/vision/mlp_image_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..81d513cc7422dcb1fa451b658b4890e77a2e65cd
--- /dev/null
+++ b/knowledge_base/vision/mlp_image_classification.py
@@ -0,0 +1,476 @@
+"""
+Title: Image classification with modern MLP models
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/05/30
+Last modified: 2023/08/03
+Description: Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image
+classification, demonstrated on the CIFAR-100 dataset:
+
+1. The [MLP-Mixer](https://arxiv.org/abs/2105.01601) model, by Ilya Tolstikhin et al., based on two types of MLPs.
+2. The [FNet](https://arxiv.org/abs/2105.03824) model, by James Lee-Thorp et al., based on unparameterized
+Fourier Transform.
+3. The [gMLP](https://arxiv.org/abs/2105.08050) model, by Hanxiao Liu et al., based on MLP with gating.
+
+The purpose of the example is not to compare these models, as they might perform differently on
+different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their
+main building blocks.
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+from keras import layers
+
+"""
+## Prepare the data
+"""
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+"""
+## Configure the hyperparameters
+"""
+
+weight_decay = 0.0001
+batch_size = 128
+num_epochs = 1 # Recommended num_epochs = 50
+dropout_rate = 0.2
+image_size = 64 # We'll resize input images to this size.
+patch_size = 8 # Size of the patches to be extracted from the input images.
+num_patches = (image_size // patch_size) ** 2 # Size of the data array.
+embedding_dim = 256 # Number of hidden units.
+num_blocks = 4 # Number of blocks.
+
+print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
+print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
+print(f"Patches per image: {num_patches}")
+print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
+
+"""
+## Build a classification model
+
+We implement a method that builds a classifier given the processing blocks. It references
+the `data_augmentation` pipeline and the `Patches` and `PositionEmbedding` layers, which
+are defined in the sections that follow.
+"""
+
+
+def build_classifier(blocks, positional_encoding=False):
+ inputs = layers.Input(shape=input_shape)
+ # Augment data.
+ augmented = data_augmentation(inputs)
+ # Create patches.
+ patches = Patches(patch_size)(augmented)
+ # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
+ x = layers.Dense(units=embedding_dim)(patches)
+ if positional_encoding:
+ x = x + PositionEmbedding(sequence_length=num_patches)(x)
+ # Process x using the module blocks.
+ x = blocks(x)
+ # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
+ representation = layers.GlobalAveragePooling1D()(x)
+ # Apply dropout.
+ representation = layers.Dropout(rate=dropout_rate)(representation)
+ # Compute logits outputs.
+ logits = layers.Dense(num_classes)(representation)
+ # Create the Keras model.
+ return keras.Model(inputs=inputs, outputs=logits)
+
+
+"""
+## Define an experiment
+
+We implement a utility function to compile, train, and evaluate a given model.
+"""
+
+
+def run_experiment(model):
+ # Create Adam optimizer with weight decay.
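+    # Note: `learning_rate` is a module-level variable set before each
+    # experiment below (it differs per model).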
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=learning_rate,
+ weight_decay=weight_decay,
+ )
+ # Compile the model.
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="acc"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
+ ],
+ )
+ # Create a learning rate scheduler callback.
+ reduce_lr = keras.callbacks.ReduceLROnPlateau(
+ monitor="val_loss", factor=0.5, patience=5
+ )
+ # Create an early stopping callback.
+ early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=10, restore_best_weights=True
+ )
+ # Fit the model.
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[early_stopping, reduce_lr],
+ )
+
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ # Return history to plot learning curves.
+ return history
+
+
+"""
+## Use data augmentation
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.Resizing(image_size, image_size),
+ layers.RandomFlip("horizontal"),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+
+"""
+## Implement patch extraction as a layer
+"""
+
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_size = patch_size
+
+ def call(self, x):
+ patches = keras.ops.image.extract_patches(x, self.patch_size)
+ batch_size = keras.ops.shape(patches)[0]
+ num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
+ patch_dim = keras.ops.shape(patches)[3]
+ out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))
+ return out
+
+
+"""
+## Implement position embedding as a layer
+"""
+
+
+class PositionEmbedding(keras.layers.Layer):
+ def __init__(
+ self,
+ sequence_length,
+ initializer="glorot_uniform",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ if sequence_length is None:
+ raise ValueError("`sequence_length` must be an Integer, received `None`.")
+ self.sequence_length = int(sequence_length)
+ self.initializer = keras.initializers.get(initializer)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "sequence_length": self.sequence_length,
+ "initializer": keras.initializers.serialize(self.initializer),
+ }
+ )
+ return config
+
+ def build(self, input_shape):
+ feature_size = input_shape[-1]
+ self.position_embeddings = self.add_weight(
+ name="embeddings",
+ shape=[self.sequence_length, feature_size],
+ initializer=self.initializer,
+ trainable=True,
+ )
+
+ super().build(input_shape)
+
+ def call(self, inputs, start_index=0):
+ shape = keras.ops.shape(inputs)
+ feature_length = shape[-1]
+ sequence_length = shape[-2]
+ # trim to match the length of the input sequence, which might be less
+ # than the sequence_length of the layer.
+ position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
+ position_embeddings = keras.ops.slice(
+ position_embeddings,
+ (start_index, 0),
+ (sequence_length, feature_length),
+ )
+ return keras.ops.broadcast_to(position_embeddings, shape)
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
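+"""
+As a quick sanity check (ours, not part of the original example), we can confirm that
+the layer broadcasts its learned `(sequence_length, feature_size)` table across a batch:
+"""
+
+_demo_inputs = keras.ops.zeros((2, num_patches, embedding_dim))
+_demo_positions = PositionEmbedding(sequence_length=num_patches)(_demo_inputs)
+print(_demo_positions.shape)  # (2, num_patches, embedding_dim)
+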
+"""
+## The MLP-Mixer model
+
+The MLP-Mixer is an architecture based exclusively on
+multi-layer perceptrons (MLPs), that contains two types of MLP layers:
+
+1. One applied independently to image patches, which mixes the per-location features.
+2. The other applied across patches (along channels), which mixes spatial information.
+
+This is similar to a [depthwise separable convolution based model](https://arxiv.org/abs/1610.02357)
+such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization
+instead of batch normalization.
+"""
+
+"""
+### Implement the MLP-Mixer module
+"""
+
+
+class MLPMixerLayer(layers.Layer):
+ def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.mlp1 = keras.Sequential(
+ [
+ layers.Dense(units=num_patches, activation="gelu"),
+ layers.Dense(units=num_patches),
+ layers.Dropout(rate=dropout_rate),
+ ]
+ )
+ self.mlp2 = keras.Sequential(
+ [
+ layers.Dense(units=num_patches, activation="gelu"),
+ layers.Dense(units=hidden_units),
+ layers.Dropout(rate=dropout_rate),
+ ]
+ )
+ self.normalize = layers.LayerNormalization(epsilon=1e-6)
+
+ def build(self, input_shape):
+ return super().build(input_shape)
+
+ def call(self, inputs):
+ # Apply layer normalization.
+ x = self.normalize(inputs)
+ # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
+ x_channels = keras.ops.transpose(x, axes=(0, 2, 1))
+ # Apply mlp1 on each channel independently.
+ mlp1_outputs = self.mlp1(x_channels)
+ # Transpose mlp1_outputs from [num_batches, hidden_units, num_patches] to [num_batches, num_patches, hidden_units].
+ mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))
+ # Add skip connection.
+ x = mlp1_outputs + inputs
+ # Apply layer normalization.
+ x_patches = self.normalize(x)
+        # Apply mlp2 on each patch independently.
+ mlp2_outputs = self.mlp2(x_patches)
+ # Add skip connection.
+ x = x + mlp2_outputs
+ return x
+
+
+"""
+### Build, train, and evaluate the MLP-Mixer model
+
+Note that training the model with the current settings on a V100 GPU
+takes around 8 seconds per epoch.
+"""
+
+mlpmixer_blocks = keras.Sequential(
+ [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
+)
+learning_rate = 0.005
+mlpmixer_classifier = build_classifier(mlpmixer_blocks)
+history = run_experiment(mlpmixer_classifier)
+
+"""
+The MLP-Mixer model tends to have far fewer parameters than convolutional and
+transformer-based models, which translates to lower training and serving
+computational costs.
+
+As mentioned in the [MLP-Mixer](https://arxiv.org/abs/2105.01601) paper,
+when pre-trained on large datasets, or with modern regularization schemes,
+the MLP-Mixer attains competitive scores to state-of-the-art models.
+You can obtain better results by increasing the embedding dimensions,
+increasing the number of mixer blocks, and training the model for longer.
+You may also try to increase the size of the input images and use different patch sizes.
+"""
+
+"""
+## The FNet model
+
+FNet uses a block similar to the Transformer block. However, it replaces the self-attention layer
+in the Transformer block with a parameter-free 2D Fourier transformation layer:
+
+1. One 1D Fourier Transform is applied along the patches.
+2. One 1D Fourier Transform is applied along the channels.
+"""
+
+"""
+### Implement the FNet module
+"""
+
+
+class FNetLayer(layers.Layer):
+ def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(units=embedding_dim, activation="gelu"),
+ layers.Dropout(rate=dropout_rate),
+ layers.Dense(units=embedding_dim),
+ ]
+ )
+
+ self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
+ self.normalize2 = layers.LayerNormalization(epsilon=1e-6)
+
+ def call(self, inputs):
+ # Apply fourier transformations.
+ real_part = inputs
+ im_part = keras.ops.zeros_like(inputs)
+ x = keras.ops.fft2((real_part, im_part))[0]
+ # Add skip connection.
+ x = x + inputs
+ # Apply layer normalization.
+ x = self.normalize1(x)
+        # Apply the feedforward network.
+ x_ffn = self.ffn(x)
+ # Add skip connection.
+ x = x + x_ffn
+ # Apply layer normalization.
+ return self.normalize2(x)
+
+
+"""
+### Build, train, and evaluate the FNet model
+
+Note that training the model with the current settings on a V100 GPU
+takes around 8 seconds per epoch.
+"""
+
+fnet_blocks = keras.Sequential(
+ [FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]
+)
+learning_rate = 0.001
+fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
+history = run_experiment(fnet_classifier)
+
+"""
+As shown in the [FNet](https://arxiv.org/abs/2105.03824) paper,
+better results can be achieved by increasing the embedding dimensions,
+increasing the number of FNet blocks, and training the model for longer.
+You may also try to increase the size of the input images and use different patch sizes.
+The FNet scales very efficiently to long inputs, runs much faster than attention-based
+Transformer models, and produces competitive accuracy results.
+"""
+
+"""
+## The gMLP model
+
+The gMLP is an MLP architecture that features a Spatial Gating Unit (SGU).
+The SGU enables cross-patch interactions along the spatial (patch) dimension, by:
+
+1. Transforming the input spatially by applying linear projection across patches (along channels).
+2. Applying element-wise multiplication of the input and its spatial transformation.
+"""
+
+"""
+### Implement the gMLP module
+"""
+
+
+class gMLPLayer(layers.Layer):
+ def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.channel_projection1 = keras.Sequential(
+ [
+ layers.Dense(units=embedding_dim * 2, activation="gelu"),
+ layers.Dropout(rate=dropout_rate),
+ ]
+ )
+
+ self.channel_projection2 = layers.Dense(units=embedding_dim)
+
+ self.spatial_projection = layers.Dense(
+ units=num_patches, bias_initializer="Ones"
+ )
+
+ self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
+ self.normalize2 = layers.LayerNormalization(epsilon=1e-6)
+
+ def spatial_gating_unit(self, x):
+ # Split x along the channel dimensions.
+        # Tensors u and v will be in the shape of [batch_size, num_patches, embedding_dim].
+ u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
+ # Apply layer normalization.
+ v = self.normalize2(v)
+ # Apply spatial projection.
+ v_channels = keras.ops.transpose(v, axes=(0, 2, 1))
+ v_projected = self.spatial_projection(v_channels)
+ v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))
+ # Apply element-wise multiplication.
+ return u * v_projected
+
+ def call(self, inputs):
+ # Apply layer normalization.
+ x = self.normalize1(inputs)
+ # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
+ x_projected = self.channel_projection1(x)
+ # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
+ x_spatial = self.spatial_gating_unit(x_projected)
+ # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
+ x_projected = self.channel_projection2(x_spatial)
+ # Add skip connection.
+ return x + x_projected
+
+
+"""
+### Build, train, and evaluate the gMLP model
+
+Note that training the model with the current settings on a V100 GPU
+takes around 9 seconds per epoch.
+"""
+
+gmlp_blocks = keras.Sequential(
+ [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
+)
+learning_rate = 0.003
+gmlp_classifier = build_classifier(gmlp_blocks)
+history = run_experiment(gmlp_classifier)
+
+"""
+As shown in the [gMLP](https://arxiv.org/abs/2105.08050) paper,
+better results can be achieved by increasing the embedding dimensions,
+increasing the number of gMLP blocks, and training the model for longer.
+You may also try to increase the size of the input images and use different patch sizes.
+Note that the paper used advanced regularization strategies, such as MixUp and CutMix,
+as well as AutoAugment.
+"""
diff --git a/knowledge_base/vision/mnist_convnet.py b/knowledge_base/vision/mnist_convnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0f355e652c4fb98f8c80c5e69f3d8aa5f928ea
--- /dev/null
+++ b/knowledge_base/vision/mnist_convnet.py
@@ -0,0 +1,80 @@
+"""
+Title: Simple MNIST convnet
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2015/06/19
+Last modified: 2020/04/21
+Description: A simple convnet that achieves ~99% test accuracy on MNIST.
+Accelerator: GPU
+"""
+
+"""
+## Setup
+"""
+
+import numpy as np
+import keras
+from keras import layers
+
+"""
+## Prepare the data
+"""
+
+# Model / data parameters
+num_classes = 10
+input_shape = (28, 28, 1)
+
+# Load the data and split it between train and test sets
+(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+# Scale images to the [0, 1] range
+x_train = x_train.astype("float32") / 255
+x_test = x_test.astype("float32") / 255
+# Make sure images have shape (28, 28, 1)
+x_train = np.expand_dims(x_train, -1)
+x_test = np.expand_dims(x_test, -1)
+print("x_train shape:", x_train.shape)
+print(x_train.shape[0], "train samples")
+print(x_test.shape[0], "test samples")
+
+
+# convert class vectors to binary class matrices
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
+
+"""
+## Build the model
+"""
+
+model = keras.Sequential(
+ [
+ keras.Input(shape=input_shape),
+ layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+ layers.MaxPooling2D(pool_size=(2, 2)),
+ layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+ layers.MaxPooling2D(pool_size=(2, 2)),
+ layers.Flatten(),
+ layers.Dropout(0.5),
+ layers.Dense(num_classes, activation="softmax"),
+ ]
+)
+
+model.summary()
+
+"""
+## Train the model
+"""
+
+batch_size = 128
+epochs = 15
+
+model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+
+model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
+
+"""
+## Evaluate the trained model
+"""
+
+score = model.evaluate(x_test, y_test, verbose=0)
+print("Test loss:", score[0])
+print("Test accuracy:", score[1])
diff --git a/knowledge_base/vision/mobilevit.py b/knowledge_base/vision/mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..630de253db19b2c5f1a5e30b7eb9e6397d99ecd4
--- /dev/null
+++ b/knowledge_base/vision/mobilevit.py
@@ -0,0 +1,404 @@
+"""
+Title: MobileViT: A mobile-friendly Transformer-based model for image classification
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/10/20
+Last modified: 2025/09/30
+Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we implement the MobileViT architecture
+([Mehta et al.](https://arxiv.org/abs/2110.02178)),
+which combines the benefits of Transformers
+([Vaswani et al.](https://arxiv.org/abs/1706.03762))
+and convolutions. With Transformers, we can capture long-range dependencies that result
+in global representations. With convolutions, we can capture spatial relationships that
+model locality.
+
+Besides combining the properties of Transformers and convolutions, the authors introduce
+MobileViT as a general-purpose mobile-friendly backbone for different image recognition
+tasks. Their findings suggest that, performance-wise, MobileViT is better than other
+models with the same or higher complexity ([MobileNetV3](https://arxiv.org/abs/1905.02244),
+for example), while being efficient on mobile devices.
+
+Note: This example should be run with TensorFlow 2.13 or higher.
+"""
+
+"""
+## Imports
+"""
+
+import os
+import tensorflow as tf
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import backend
+
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+"""
+## Hyperparameters
+"""
+
+# Values are from table 4.
+patch_size = 4 # 2x2, for the Transformer blocks.
+image_size = 256
+expansion_factor = 2 # expansion factor for the MobileNetV2 blocks.
+
+"""
+## MobileViT utilities
+
+The MobileViT architecture is comprised of the following blocks:
+
+* Strided 3x3 convolutions that process the input image.
+* [MobileNetV2](https://arxiv.org/abs/1801.04381)-style inverted residual blocks for
+downsampling the resolution of the intermediate feature maps.
+* MobileViT blocks that combine the benefits of Transformers and convolutions
+(see the figure in the [original paper](https://arxiv.org/abs/2110.02178)).
+
+"""
+
+
+def conv_block(x, filters=16, kernel_size=3, strides=2):
+ conv_layer = layers.Conv2D(
+ filters,
+ kernel_size,
+ strides=strides,
+ activation=keras.activations.swish,
+ padding="same",
+ )
+ return conv_layer(x)
+
+
+# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435
+
+
+def correct_pad(inputs, kernel_size):
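+    """Return an asymmetric ((top, bottom), (left, right)) zero-padding tuple so
+    that a subsequent stride-2 "valid" convolution downsamples exactly by 2."""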
+ img_dim = 2 if backend.image_data_format() == "channels_first" else 1
+ input_size = inputs.shape[img_dim : (img_dim + 2)]
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ if input_size[0] is None:
+ adjust = (1, 1)
+ else:
+ adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
+ correct = (kernel_size[0] // 2, kernel_size[1] // 2)
+ return (
+ (correct[0] - adjust[0], correct[0]),
+ (correct[1] - adjust[1], correct[1]),
+ )
+
+
+# Reference: https://git.io/JKgtC
+
+
+def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
+ m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
+ m = layers.BatchNormalization()(m)
+ m = keras.activations.swish(m)
+
+ if strides == 2:
+ m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
+ m = layers.DepthwiseConv2D(
+ 3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
+ )(m)
+ m = layers.BatchNormalization()(m)
+ m = keras.activations.swish(m)
+
+ m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
+ m = layers.BatchNormalization()(m)
+
+ if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
+ return layers.Add()([m, x])
+ return m
+
+
+# Reference:
+# https://keras.io/examples/vision/image_classification_with_vision_transformer/
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=keras.activations.swish)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
+ for _ in range(transformer_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(x)
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, x])
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # MLP.
+ x3 = mlp(
+ x3,
+ hidden_units=[x.shape[-1] * 2, x.shape[-1]],
+ dropout_rate=0.1,
+ )
+ # Skip connection 2.
+ x = layers.Add()([x3, x2])
+
+ return x
+
+
+def mobilevit_block(x, num_blocks, projection_dim, strides=1):
+ # Local projection with convolutions.
+ local_features = conv_block(x, filters=projection_dim, strides=strides)
+ local_features = conv_block(
+ local_features, filters=projection_dim, kernel_size=1, strides=strides
+ )
+
+ # Unfold into patches and then pass through Transformers.
+ num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
+ non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
+ local_features
+ )
+ global_features = transformer_block(
+ non_overlapping_patches, num_blocks, projection_dim
+ )
+
+ # Fold into conv-like feature-maps.
+ folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
+ global_features
+ )
+
+ # Apply point-wise conv -> concatenate with the input features.
+ folded_feature_map = conv_block(
+ folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
+ )
+ local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])
+
+    # Fuse the local and global features using a convolution layer.
+ local_global_features = conv_block(
+ local_global_features, filters=projection_dim, strides=strides
+ )
+
+ return local_global_features
+
+
+"""
+**More on the MobileViT block**:
+
+* First, the feature representations (A) go through convolution blocks that capture local
+relationships. The expected shape of a single entry here would be `(h, w, num_channels)`.
+* Then they get unfolded into another vector with shape `(p, n, num_channels)`,
+where `p` is the area of a small patch, and `n` is `(h * w) / p`. So, we end up with `n`
+non-overlapping patches.
+* This unfolded vector is then passed through a Transformer block that captures global
+relationships between the patches.
+* The output vector (B) is again folded into a vector of shape `(h, w, num_channels)`
+resembling a feature map coming out of convolutions.
+
+Vectors A and B are then passed through two more convolutional layers to fuse the local
+and global representations. Notice how the spatial resolution of the final vector remains
+unchanged at this point. The authors also present an explanation of how the MobileViT
+block resembles a convolution block of a CNN. For more details, please refer to the
+original paper.
+"""
+
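+"""
+The unfold/fold bookkeeping can be verified with a tiny, illustrative shape check
+(the toy numbers below are ours, not from the paper):
+"""
+
+# An 8x8 feature map with 16 channels and patch area p = 4 (2x2 patches) unfolds
+# into n = (8 * 8) / 4 = 16 non-overlapping patches.
+_feature_map = tf.zeros((1, 8, 8, 16))
+_unfolded = tf.reshape(_feature_map, (1, patch_size, (8 * 8) // patch_size, 16))
+print(_unfolded.shape)  # (1, 4, 16, 16)
+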
+"""
+Next, we combine these blocks together and implement the MobileViT architecture (XXS
+variant). A schematic representation of the full architecture can be found in the
+[original paper](https://arxiv.org/abs/2110.02178).
+"""
+
+
+def create_mobilevit(num_classes=5):
+ inputs = keras.Input((image_size, image_size, 3))
+ x = layers.Rescaling(scale=1.0 / 255)(inputs)
+
+ # Initial conv-stem -> MV2 block.
+ x = conv_block(x, filters=16)
+ x = inverted_residual_block(
+ x, expanded_channels=16 * expansion_factor, output_channels=16
+ )
+
+ # Downsampling with MV2 block.
+ x = inverted_residual_block(
+ x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
+ )
+ x = inverted_residual_block(
+ x, expanded_channels=24 * expansion_factor, output_channels=24
+ )
+ x = inverted_residual_block(
+ x, expanded_channels=24 * expansion_factor, output_channels=24
+ )
+
+ # First MV2 -> MobileViT block.
+ x = inverted_residual_block(
+ x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
+ )
+ x = mobilevit_block(x, num_blocks=2, projection_dim=64)
+
+ # Second MV2 -> MobileViT block.
+ x = inverted_residual_block(
+ x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
+ )
+ x = mobilevit_block(x, num_blocks=4, projection_dim=80)
+
+ # Third MV2 -> MobileViT block.
+ x = inverted_residual_block(
+ x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
+ )
+ x = mobilevit_block(x, num_blocks=3, projection_dim=96)
+ x = conv_block(x, filters=320, kernel_size=1, strides=1)
+
+ # Classification head.
+ x = layers.GlobalAvgPool2D()(x)
+ outputs = layers.Dense(num_classes, activation="softmax")(x)
+
+ return keras.Model(inputs, outputs)
+
+
+mobilevit_xxs = create_mobilevit()
+mobilevit_xxs.summary()
+
+"""
+## Dataset preparation
+
+We will be using the
+[`tf_flowers`](https://www.tensorflow.org/datasets/catalog/tf_flowers)
+dataset to demonstrate the model. Unlike other Transformer-based architectures,
+MobileViT uses a simple augmentation pipeline primarily because it has the properties
+of a CNN.
+"""
+
+batch_size = 64
+auto = tf.data.AUTOTUNE
+resize_bigger = 280
+num_classes = 5
+
+
+def preprocess_dataset(is_training=True):
+ def _pp(image, label):
+ if is_training:
+ # Resize to a bigger spatial resolution and take the random
+ # crops.
+ image = tf.image.resize(image, (resize_bigger, resize_bigger))
+ image = tf.image.random_crop(image, (image_size, image_size, 3))
+ image = tf.image.random_flip_left_right(image)
+ else:
+ image = tf.image.resize(image, (image_size, image_size))
+ label = tf.one_hot(label, depth=num_classes)
+ return image, label
+
+ return _pp
+
+
+def prepare_dataset(dataset, is_training=True):
+ if is_training:
+ dataset = dataset.shuffle(batch_size * 10)
+ dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
+ return dataset.batch(batch_size).prefetch(auto)
+
+
+"""
+The authors use a multi-scale data sampler to help the model learn representations of
+varied scales. In this example, we discard this part.
+"""
+
+"""
+## Load and prepare the dataset
+"""
+
+train_dataset, val_dataset = tfds.load(
+ "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
+)
+
+num_train = train_dataset.cardinality()
+num_val = val_dataset.cardinality()
+print(f"Number of training examples: {num_train}")
+print(f"Number of validation examples: {num_val}")
+
+train_dataset = prepare_dataset(train_dataset, is_training=True)
+val_dataset = prepare_dataset(val_dataset, is_training=False)
+
+"""
+## Train a MobileViT (XXS) model
+"""
+
+learning_rate = 0.002
+label_smoothing_factor = 0.1
+epochs = 30
+
+optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
+loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)
+
+
+def run_experiment(epochs=epochs):
+ mobilevit_xxs = create_mobilevit(num_classes=num_classes)
+ mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])
+
+ # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
+ checkpoint_filepath = "/tmp/checkpoint.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ mobilevit_xxs.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=epochs,
+ callbacks=[checkpoint_callback],
+ )
+ mobilevit_xxs.load_weights(checkpoint_filepath)
+ _, accuracy = mobilevit_xxs.evaluate(val_dataset)
+ print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
+ return mobilevit_xxs
+
+
+mobilevit_xxs = run_experiment()
+
+"""
+## Results and TFLite conversion
+
+With about one million parameters, getting to ~85% top-1 accuracy on 256x256 resolution is
+a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite)
+and can be converted with the following code:
+"""
+
+# Serialize the model as a SavedModel.
+tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")
+
+# Convert to TFLite. This form of quantization is called
+# post-training dynamic-range quantization in TFLite.
+converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [
+ tf.lite.OpsSet.TFLITE_BUILTINS, # Enable TensorFlow Lite ops.
+ tf.lite.OpsSet.SELECT_TF_OPS, # Enable TensorFlow ops.
+]
+tflite_model = converter.convert()
+open("mobilevit_xxs.tflite", "wb").write(tflite_model)
+
+"""
+To learn more about different quantization recipes available in TFLite and running
+inference with TFLite models, check out
+[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/mobile-vit-xxs)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Flowers-Classification-MobileViT).
+"""
diff --git a/knowledge_base/vision/near_dup_search.py b/knowledge_base/vision/near_dup_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..0030edacbb1e700996e035fa4a0c28e33bd2306e
--- /dev/null
+++ b/knowledge_base/vision/near_dup_search.py
@@ -0,0 +1,562 @@
+"""
+Title: Near-duplicate image search
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/09/10
+Last modified: 2023/08/30
+Description: Building a near-duplicate image search utility using deep learning and locality-sensitive hashing.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Fetching similar images in (near) real time is an important use case of information
+retrieval systems. Some popular products utilizing it include Pinterest, Google Image
+Search, etc. In this example, we will build a similar image search utility using
+[Locality Sensitive Hashing](https://towardsdatascience.com/understanding-locality-sensitive-hashing-49f6d1f6134)
+(LSH) and [random projection](https://en.wikipedia.org/wiki/Random_projection) on top
+of the image representations computed by a pretrained image classifier.
+This kind of search engine is also known
+as a _near-duplicate (or near-dup) image detector_.
+We will also look into optimizing the inference performance of
+our search utility on GPU using [TensorRT](https://developer.nvidia.com/tensorrt).
+
+There are other examples under [keras.io/examples/vision](https://keras.io/examples/vision)
+that are worth checking out in this regard:
+
+* [Metric learning for image similarity search](https://keras.io/examples/vision/metric_learning)
+* [Image similarity estimation using a Siamese Network with a triplet loss](https://keras.io/examples/vision/siamese_network)
+
+Finally, this example uses the following resource as a reference and as such reuses some
+of its code:
+[Locality Sensitive Hashing for Similar Item Search](https://towardsdatascience.com/locality-sensitive-hashing-for-music-search-f2f1940ace23).
+
+_Note that in order to run the TensorRT optimization part of this example,
+you should have a GPU runtime available._
+"""
+
+"""
+## Setup
+"""
+
+"""shell
+pip install tensorrt
+"""
+
+"""
+## Imports
+"""
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorrt
+import numpy as np
+import time
+
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+
+"""
+## Load the dataset and create a training set of 1,000 images
+
+To keep the run time of the example short, we will be using a subset of 1,000 images from
+the `tf_flowers` dataset (available through
+[TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/tf_flowers))
+to build our vocabulary.
+"""
+
+train_ds, validation_ds = tfds.load(
+ "tf_flowers", split=["train[:85%]", "train[85%:]"], as_supervised=True
+)
+
+IMAGE_SIZE = 224
+NUM_IMAGES = 1000
+
+images = []
+labels = []
+
+for image, label in train_ds.take(NUM_IMAGES):
+ image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
+ images.append(image.numpy())
+ labels.append(label.numpy())
+
+images = np.array(images)
+labels = np.array(labels)
+
+"""
+## Load a pre-trained model
+"""
+
+"""
+In this section, we load an image classification model that was trained on the
+`tf_flowers` dataset. 85% of the total images were used to build the training set. For
+more details on the training, refer to
+[this notebook](https://github.com/sayakpaul/near-dup-parser/blob/main/bit-supervised-training.ipynb).
+
+The underlying model is a BiT-ResNet (proposed in
+[Big Transfer (BiT): General Visual Representation Learning](https://arxiv.org/abs/1912.11370)).
+The BiT-ResNet family of models is known to provide excellent transfer performance across
+a wide variety of different downstream tasks.
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/near-dup-parser/releases/download/v0.1.0/flower_model_bit_0.96875.zip
+unzip -qq flower_model_bit_0.96875.zip
+"""
+
+bit_model = tf.keras.models.load_model("flower_model_bit_0.96875")
+bit_model.count_params()
+
+"""
+## Create an embedding model
+
+To retrieve similar images given a query image, we need to first generate vector
+representations of all the images involved. We do this via an
+embedding model that extracts output features from our pretrained classifier and
+normalizes the resulting feature vectors.
+"""
+
+embedding_model = tf.keras.Sequential(
+ [
+ tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
+ tf.keras.layers.Rescaling(scale=1.0 / 255),
+ bit_model.layers[1],
+ tf.keras.layers.Normalization(mean=0, variance=1),
+ ],
+ name="embedding_model",
+)
+
+embedding_model.summary()
+
+"""
+Take note of the normalization layer inside the model. It is used to project the
+representation vectors to the space of unit-spheres.
+"""
+
+"""
+## Hashing utilities
+"""
+
+
+def hash_func(embedding, random_vectors):
+ embedding = np.array(embedding)
+
+ # Random projection.
+ bools = np.dot(embedding, random_vectors) > 0
+ return [bool2int(bool_vec) for bool_vec in bools]
+
+
+def bool2int(x):
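+    # Pack a boolean vector into an integer, little-endian:
+    # e.g. [True, False, True] -> 1 + 4 = 5.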
+ y = 0
+ for i, j in enumerate(x):
+ if j:
+ y += 1 << i
+ return y
+
+
+"""
+The shape of the vectors coming out of `embedding_model` is `(2048,)`, and considering practical
+aspects (storage, retrieval performance, etc.) it is quite large. So, there arises a need
+to reduce the dimensionality of the embedding vectors without reducing their information
+content. This is where *random projection* comes into the picture.
+It is based on the principle that if the
+distance between a group of points on a given plane is _approximately_ preserved, the
+dimensionality of that plane can further be reduced.
+
+Inside `hash_func()`, we first reduce the dimensionality of the embedding vectors. Then
+we compute the bitwise hash values of the images to determine their hash buckets. Images
+having same hash values are likely to go into the same hash bucket. From a deployment
+perspective, bitwise hash values are cheaper to store and operate on.
+"""
+
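+# Tiny, illustrative example with toy dimensions (not the real 2048-d embeddings):
+# project a 4-d embedding onto 3 random hyperplanes and pack the signs into an int.
+toy_embedding = np.array([[0.5, -1.2, 0.3, 0.9]])
+toy_random_vectors = np.random.randn(4, 3)
+print(hash_func(toy_embedding, toy_random_vectors))  # e.g. [5]
+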
+"""
+## Query utilities
+
+The `Table` class is responsible for building a single hash table. Each entry in the hash
+table is a mapping between the reduced embedding of an image from our dataset and a
+unique identifier. Because our dimensionality reduction technique involves randomness, it
+can happen that similar images are not mapped to the same hash bucket every time the
+process is run. To reduce this effect, we will take results from multiple tables into
+consideration -- the number of tables and the reduction dimensionality are the key
+hyperparameters here.
+
+Crucially, you wouldn't reimplement locality-sensitive hashing yourself in real-world
+applications. Instead, you'd likely use one of the following popular libraries:
+
+* [ScaNN](https://github.com/google-research/google-research/tree/master/scann)
+* [Annoy](https://github.com/spotify/annoy)
+* [Vald](https://github.com/vdaas/vald)
+"""
+
+
+class Table:
+ def __init__(self, hash_size, dim):
+ self.table = {}
+ self.hash_size = hash_size
+ self.random_vectors = np.random.randn(hash_size, dim).T
+
+ def add(self, id, vectors, label):
+        # Create a unique identifier.
+ entry = {"id_label": str(id) + "_" + str(label)}
+
+ # Compute the hash values.
+ hashes = hash_func(vectors, self.random_vectors)
+
+ # Add the hash values to the current table.
+ for h in hashes:
+ if h in self.table:
+ self.table[h].append(entry)
+ else:
+ self.table[h] = [entry]
+
+ def query(self, vectors):
+ # Compute hash value for the query vector.
+ hashes = hash_func(vectors, self.random_vectors)
+ results = []
+
+ # Loop over the query hashes and determine if they exist in
+ # the current table.
+ for h in hashes:
+ if h in self.table:
+ results.extend(self.table[h])
+ return results
+
+
+"""
+In the following `LSH` class we will pack the utilities to have multiple hash tables.
+"""
+
+
+class LSH:
+ def __init__(self, hash_size, dim, num_tables):
+ self.num_tables = num_tables
+ self.tables = []
+ for i in range(self.num_tables):
+ self.tables.append(Table(hash_size, dim))
+
+ def add(self, id, vectors, label):
+ for table in self.tables:
+ table.add(id, vectors, label)
+
+ def query(self, vectors):
+ results = []
+ for table in self.tables:
+ results.extend(table.query(vectors))
+ return results
+
+
+"""
+Now we can encapsulate the logic for building and operating with the master LSH table (a
+collection of many tables) inside a class. It has two methods:
+
+* `train()`: Responsible for building the final LSH table.
+* `query()`: Computes the number of matches given a query image and also quantifies the
+similarity score.
+"""
+
+
+class BuildLSHTable:
+ def __init__(
+ self,
+ prediction_model,
+ concrete_function=False,
+ hash_size=8,
+ dim=2048,
+ num_tables=10,
+ ):
+ self.hash_size = hash_size
+ self.dim = dim
+ self.num_tables = num_tables
+ self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
+
+ self.prediction_model = prediction_model
+ self.concrete_function = concrete_function
+
+ def train(self, training_files):
+ for id, training_file in enumerate(training_files):
+ # Unpack the data.
+ image, label = training_file
+ if len(image.shape) < 4:
+ image = image[None, ...]
+
+ # Compute embeddings and update the LSH tables.
+ # More on `self.concrete_function()` later.
+ if self.concrete_function:
+ features = self.prediction_model(tf.constant(image))[
+ "normalization"
+ ].numpy()
+ else:
+ features = self.prediction_model.predict(image)
+ self.lsh.add(id, features, label)
+
+ def query(self, image, verbose=True):
+ # Compute the embeddings of the query image and fetch the results.
+ if len(image.shape) < 4:
+ image = image[None, ...]
+
+ if self.concrete_function:
+ features = self.prediction_model(tf.constant(image))[
+ "normalization"
+ ].numpy()
+ else:
+ features = self.prediction_model.predict(image)
+
+ results = self.lsh.query(features)
+ if verbose:
+ print("Matches:", len(results))
+
+ # Calculate Jaccard index to quantify the similarity.
+ counts = {}
+ for r in results:
+ if r["id_label"] in counts:
+ counts[r["id_label"]] += 1
+ else:
+ counts[r["id_label"]] = 1
+ for k in counts:
+ counts[k] = float(counts[k]) / self.dim
+ return counts
+
+
+"""
+## Create LSH tables
+
+With our helper utilities and classes implemented, we can now build our LSH table. Since
+we will be benchmarking performance between optimized and unoptimized embedding models, we
+will also warm up our GPU to avoid any unfair comparison.
+"""
+
+
+# Utility to warm up the GPU.
+def warmup():
+ dummy_sample = tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3))
+ for _ in range(100):
+ _ = embedding_model.predict(dummy_sample)
+
+
+"""
+Now we can first do the GPU warm-up and proceed to build the master LSH table with
+`embedding_model`.
+"""
+
+warmup()
+
+training_files = zip(images, labels)
+lsh_builder = BuildLSHTable(embedding_model)
+lsh_builder.train(training_files)
+
+
+"""
+At the time of writing, the wall time was 54.1 seconds on a Tesla T4 GPU. This timing may
+vary based on the GPU you are using.
+"""
+
+"""
+## Optimize the model with TensorRT
+
+For NVIDIA-based GPUs, the
+[TensorRT framework](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html)
+can be used to dramatically reduce inference latency by applying various model
+optimization techniques like pruning, constant folding, layer fusion, and so on. Here we
+will use the `tf.experimental.tensorrt` module to optimize our embedding model.
+"""
+
+# First serialize the embedding model as a SavedModel.
+embedding_model.save("embedding_model")
+
+# Initialize the conversion parameters.
+params = tf.experimental.tensorrt.ConversionParams(
+ precision_mode="FP16", maximum_cached_engines=16
+)
+
+# Run the conversion.
+converter = tf.experimental.tensorrt.Converter(
+ input_saved_model_dir="embedding_model", conversion_params=params
+)
+converter.convert()
+converter.save("tensorrt_embedding_model")
+
+"""
+**Notes on the parameters inside of `tf.experimental.tensorrt.ConversionParams()`**:
+
+* `precision_mode` defines the numerical precision of the operations in the
+to-be-converted model.
+* `maximum_cached_engines` specifies the maximum number of TRT engines that will be
+cached to handle dynamic operations (operations with unknown shapes).
+
+To learn more about the other options, refer to the
+[official documentation](https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt/ConversionParams).
+You can also explore the different quantization options provided by the
+`tf.experimental.tensorrt` module.
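+
+For example, an INT8 conversion would additionally need representative calibration
+data. A hypothetical sketch (not run in this example):
+
+```python
+params_int8 = tf.experimental.tensorrt.ConversionParams(precision_mode="INT8")
+converter_int8 = tf.experimental.tensorrt.Converter(
+    input_saved_model_dir="embedding_model", conversion_params=params_int8
+)
+
+
+def calibration_input_fn():
+    # Yield batches that resemble real inference inputs.
+    yield (tf.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3)),)
+
+
+converter_int8.convert(calibration_input_fn=calibration_input_fn)
+```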
+"""
+
+# Load the converted model.
+root = tf.saved_model.load("tensorrt_embedding_model")
+trt_model_function = root.signatures["serving_default"]
+
+"""
+## Build LSH tables with optimized model
+"""
+
+warmup()
+
+training_files = zip(images, labels)
+lsh_builder_trt = BuildLSHTable(trt_model_function, concrete_function=True)
+lsh_builder_trt.train(training_files)
+
+"""
+Notice the difference in the wall time: it is **13.1 seconds**, down from the
+**54.1 seconds** we measured earlier with the unoptimized model.
+
+We can take a closer look at one of the hash tables to get an idea of how they are
+represented.
+"""
+
+idx = 0
+for hash_val, entry in lsh_builder_trt.lsh.tables[0].table.items():
+    if idx == 5:
+        break
+    if len(entry) < 5:
+        print(hash_val, entry)
+    idx += 1
+
+"""
+## Visualize results on validation images
+
+In this section we will first write a couple of utility functions to visualize the
+similar image parsing process. Then we will benchmark the query performance of the models
+with and without optimization.
+"""
+
+"""
+First, we take 100 images from the validation set for testing purposes.
+"""
+
+validation_images = []
+validation_labels = []
+
+for image, label in validation_ds.take(100):
+ image = tf.image.resize(image, (224, 224))
+ validation_images.append(image.numpy())
+ validation_labels.append(label.numpy())
+
+validation_images = np.array(validation_images)
+validation_labels = np.array(validation_labels)
+validation_images.shape, validation_labels.shape
+
+
+"""
+Now we write our visualization utilities.
+"""
+
+
+def plot_images(images, labels):
+ plt.figure(figsize=(20, 10))
+ columns = 5
+ for i, image in enumerate(images):
+ ax = plt.subplot(len(images) // columns + 1, columns, i + 1)
+ if i == 0:
+ ax.set_title("Query Image\n" + "Label: {}".format(labels[i]))
+ else:
+ ax.set_title("Similar Image # " + str(i) + "\nLabel: {}".format(labels[i]))
+ plt.imshow(image.astype("int"))
+ plt.axis("off")
+
+
+def visualize_lsh(lsh_class):
+ idx = np.random.choice(len(validation_images))
+ image = validation_images[idx]
+ label = validation_labels[idx]
+ results = lsh_class.query(image)
+
+ candidates = []
+ labels = []
+ overlaps = []
+
+ for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
+ if idx == 4:
+ break
+ image_id, label = r.split("_")[0], r.split("_")[1]
+ candidates.append(images[int(image_id)])
+ labels.append(label)
+ overlaps.append(results[r])
+
+ candidates.insert(0, image)
+ labels.insert(0, label)
+
+ plot_images(candidates, labels)
+
+
+"""
+### Non-TRT model
+"""
+
+for _ in range(5):
+ visualize_lsh(lsh_builder)
+
+visualize_lsh(lsh_builder)
+
+"""
+### TRT model
+"""
+
+for _ in range(5):
+ visualize_lsh(lsh_builder_trt)
+
+"""
+As you may have noticed, there are a couple of incorrect results. This can be mitigated in
+a few ways:
+
+* Using better models for generating the initial embeddings, especially for noisy samples. We can
+use techniques like [ArcFace](https://arxiv.org/abs/1801.07698),
+[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362), etc.
+that implicitly encourage better learning of representations for retrieval purposes.
+* The trade-off between the number of tables and the reduction dimensionality is crucial
+and helps set the right recall for your application; a quick way to probe this trade-off
+is sketched below.
+"""
+
+"""
+## Benchmarking query performance
+"""
+
+
+def benchmark(lsh_class):
+ warmup()
+
+ start_time = time.time()
+ for _ in range(1000):
+ image = np.ones((1, 224, 224, 3)).astype("float32")
+ _ = lsh_class.query(image, verbose=False)
+ end_time = time.time() - start_time
+ print(f"Time taken: {end_time:.3f}")
+
+
+benchmark(lsh_builder)
+
+benchmark(lsh_builder_trt)
+
+"""
+We can immediately notice a stark difference between the query performance of the two
+models.
+"""
+
+"""
+## Final remarks
+
+In this example, we explored the TensorRT framework from NVIDIA for optimizing our model.
+It's best suited for GPU-based inference servers. There are other choices for such
+frameworks that cater to different hardware platforms:
+
+* [TensorFlow Lite](https://www.tensorflow.org/lite) for mobile and edge devices.
+* [ONNX](https://onnx.ai/) for commodity CPU-based servers.
+* [Apache TVM](https://tvm.apache.org/), a compiler for machine learning models that
+covers various platforms.
+
+Here are a few resources you might want to check out to learn more
+about applications based on vector similarity search in general:
+
+* [ANN Benchmarks](http://ann-benchmarks.com/)
+* [Accelerating Large-Scale Inference with Anisotropic Vector Quantization (ScaNN)](https://arxiv.org/abs/1908.10396)
+* [Spreading vectors for similarity search](https://arxiv.org/abs/1806.03198)
+* [Building a real-time embeddings similarity matching system](https://cloud.google.com/architecture/building-real-time-embeddings-similarity-matching-system)
+"""
diff --git a/knowledge_base/vision/nerf.py b/knowledge_base/vision/nerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..580eaffe9ad210ebb803ca9cdb31f1ad0352a0c8
--- /dev/null
+++ b/knowledge_base/vision/nerf.py
@@ -0,0 +1,777 @@
+"""
+Title: 3D volumetric rendering with NeRF
+Authors: [Aritra Roy Gosthipaty](https://twitter.com/arig23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+Date created: 2021/08/09
+Last modified: 2023/11/13
+Description: Minimal implementation of volumetric rendering as shown in NeRF.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we present a minimal implementation of the research paper
+[**NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis**](https://arxiv.org/abs/2003.08934)
+by Ben Mildenhall et al. The authors have proposed an ingenious way
+to *synthesize novel views of a scene* by modelling the *volumetric
+scene function* through a neural network.
+
+To help you understand this intuitively, let's start with the following question:
+*would it be possible to give to a neural
+network the position of a pixel in an image, and ask the network
+to predict the color at that position?*
+
+|  |
+| :---: |
+| **Figure 1**: A neural network being given coordinates of an image
+as input and asked to predict the color at the coordinates. |
+
+The neural network would hypothetically *memorize* (overfit on) the
+image. This means that our neural network would have encoded the entire image
+in its weights. We could query the neural network with each position,
+and it would eventually reconstruct the entire image.
+
+|  |
+| :---: |
+| **Figure 2**: The trained neural network recreates the image from scratch. |
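+
+To make the idea concrete, such a network could be as simple as the following toy
+coordinate MLP (illustrative only, not part of the NeRF pipeline below):
+
+```python
+import keras
+from keras import layers
+
+coord_to_rgb = keras.Sequential(
+    [
+        keras.Input(shape=(2,)),  # a normalized (x, y) pixel position
+        layers.Dense(64, activation="relu"),
+        layers.Dense(64, activation="relu"),
+        layers.Dense(3, activation="sigmoid"),  # the predicted (r, g, b) color
+    ]
+)
+# Trained with a pixel-wise MSE loss on a single image, this network would
+# effectively memorize that image in its weights.
+```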
+
+A question now arises: how do we extend this idea to learn a 3D
+volumetric scene? Implementing a similar process as above would
+require knowledge of every voxel (volume pixel), which turns out
+to be quite a challenging task.
+
+The authors of the paper propose a minimal and elegant way to learn a
+3D scene using a few images of the scene. They discard the use of
+voxels for training. The network learns to model the volumetric scene,
+thus generating novel views (images) of the 3D scene that the model
+was not shown at training time.
+
+There are a few prerequisites one needs to understand to fully
+appreciate the process. We structure the example in such a way that
+you will have all the required knowledge before starting the
+implementation.
+"""
+
+"""
+## Setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+# Setting random seed to obtain reproducible results.
+import tensorflow as tf
+
+tf.random.set_seed(42)
+
+import keras
+from keras import layers
+
+import os
+import glob
+import imageio.v2 as imageio
+import numpy as np
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+# Initialize global variables.
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 5
+NUM_SAMPLES = 32
+POS_ENCODE_DIMS = 16
+EPOCHS = 20
+
+"""
+## Download and load the data
+
+The `npz` data file contains images, camera poses, and a focal length.
+The images are taken from multiple camera angles as shown in
+**Figure 3**.
+
+|  |
+| :---: |
+| **Figure 3**: Multiple camera angles
+[Source: NeRF](https://arxiv.org/abs/2003.08934) |
+
+
+To understand camera poses in this context we have to first allow
+ourselves to think that a *camera is a mapping between the real-world
+and the 2-D image*.
+
+|  |
+| :---: |
+| **Figure 4**: 3-D world to 2-D image mapping through a camera
+[Source: Mathworks](https://www.mathworks.com/help/vision/ug/camera-calibration.html) |
+
+Consider the following equation:
+
+`x = PX`
+
+where **x** is the 2-D image point, **X** is the 3-D world point and
+**P** is the camera-matrix. **P** is a 3 x 4 matrix that plays the
+crucial role of mapping the real-world object onto an image plane.
+
+
+
+The camera-matrix is an *affine transform matrix* that is
+concatenated with a 3 x 1 column `[image height, image width, focal length]`
+to produce the *pose matrix*. This matrix is of
+dimensions 3 x 5 where the first 3 x 3 block is in the camera's point
+of view. The axes are `[down, right, backwards]` or `[-y, x, z]`
+where the camera is facing forwards `-z`.
+
+|  |
+| :---: |
+| **Figure 5**: The affine transformation. |
+
+The COLMAP frame is `[right, down, forwards]` or `[x, -y, -z]`. Read
+more about COLMAP [here](https://colmap.github.io/).
+"""
+
+# Download the data if it does not already exist.
+url = (
+ "http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"
+)
+data = keras.utils.get_file(origin=url)
+
+data = np.load(data)
+images = data["images"]
+im_shape = images.shape
+(num_images, H, W, _) = images.shape
+(poses, focal) = (data["poses"], data["focal"])
+
+# Plot a random image from the dataset for visualization.
+plt.imshow(images[np.random.randint(low=0, high=num_images)])
+plt.show()
+
+"""
+## Data pipeline
+
+Now that you've understood the notion of camera matrix
+and the mapping from a 3D scene to 2D images,
+let's talk about the inverse mapping, i.e. from 2D image to the 3D scene.
+
+We'll need to talk about volumetric rendering with ray casting and tracing,
+which are common computer graphics techniques.
+This section will help you get up to speed with these techniques.
+
+Consider an image with `N` pixels. We shoot a ray through each pixel
+and sample some points on the ray. A ray is commonly parameterized by
+the equation `r(t) = o + td` where `t` is the parameter, `o` is the
+origin and `d` is the unit directional vector as shown in **Figure 6**.
+
+|  |
+| :---: |
+| **Figure 6**: `r(t) = o + td`, where `t` is the ray parameter |
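+
+As a quick numeric illustration of `r(t) = o + td` (using the same near and far bounds
+of 2.0 and 6.0 that we adopt later):
+
+```python
+import numpy as np
+
+o = np.zeros(3)  # ray origin
+d = np.array([0.0, 0.0, -1.0])  # unit direction (the camera looks along -z)
+t = np.linspace(2.0, 6.0, 5)  # sample depths between the near and far bounds
+points = o + t[:, None] * d  # one 3-D sample point per value of t
+print(points.shape)  # (5, 3)
+```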
+
+In **Figure 7**, we consider a ray, and we sample some random points on
+the ray. These sample points each have a unique location `(x, y, z)`
+and the ray has a viewing angle `(theta, phi)`. The viewing angle is
+particularly interesting as we can shoot a ray through a single pixel
+in a lot of different ways, each with a unique viewing angle. Another
+interesting thing to notice here is the noise that is added to the
+sampling process. We add a uniform noise to each sample so that the
+samples correspond to a continuous distribution. In **Figure 7** the
+blue points are the evenly distributed samples and the white points
+`(t1, t2, t3)` are randomly placed between the samples.
+
+|  |
+| :---: |
+| **Figure 7**: Sampling the points from a ray. |
+
+**Figure 8** showcases the entire sampling process in 3D, where you
+can see the rays coming out of the white image. This means that each
+pixel will have its corresponding rays and each ray will be sampled at
+distinct points.
+
+|  |
+| :---: |
+| **Figure 8**: Shooting rays from all the pixels of an image in 3-D |
+
+These sampled points act as the input to the NeRF model. The model is
+then asked to predict the RGB color and the volume density at that
+point.
+
+|  |
+| :---: |
+| **Figure 9**: Data pipeline
+[Source: NeRF](https://arxiv.org/abs/2003.08934) |
+
+"""
+
+
+def encode_position(x):
+ """Encodes the position into its corresponding Fourier feature.
+
+ Args:
+ x: The input coordinate.
+
+ Returns:
+ Fourier features tensors of the position.
+ """
+ positions = [x]
+ for i in range(POS_ENCODE_DIMS):
+ for fn in [tf.sin, tf.cos]:
+ positions.append(fn(2.0**i * x))
+ return tf.concat(positions, axis=-1)
+
+
+def get_rays(height, width, focal, pose):
+ """Computes origin point and direction vector of rays.
+
+ Args:
+ height: Height of the image.
+ width: Width of the image.
+ focal: The focal length between the images and the camera.
+ pose: The pose matrix of the camera.
+
+ Returns:
+ Tuple of origin point and direction vector for rays.
+ """
+ # Build a meshgrid for the rays.
+ i, j = tf.meshgrid(
+ tf.range(width, dtype=tf.float32),
+ tf.range(height, dtype=tf.float32),
+ indexing="xy",
+ )
+
+ # Normalize the x axis coordinates.
+ transformed_i = (i - width * 0.5) / focal
+
+ # Normalize the y axis coordinates.
+ transformed_j = (j - height * 0.5) / focal
+
+ # Create the direction unit vectors.
+ directions = tf.stack([transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1)
+
+ # Get the camera matrix.
+ camera_matrix = pose[:3, :3]
+ height_width_focal = pose[:3, -1]
+
+ # Get origins and directions for the rays.
+ transformed_dirs = directions[..., None, :]
+ camera_dirs = transformed_dirs * camera_matrix
+ ray_directions = tf.reduce_sum(camera_dirs, axis=-1)
+ ray_origins = tf.broadcast_to(height_width_focal, tf.shape(ray_directions))
+
+ # Return the origins and directions.
+ return (ray_origins, ray_directions)
+
+
+def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
+ """Renders the rays and flattens it.
+
+ Args:
+ ray_origins: The origin points for rays.
+ ray_directions: The direction unit vectors for the rays.
+ near: The near bound of the volumetric scene.
+ far: The far bound of the volumetric scene.
+ num_samples: Number of sample points in a ray.
+ rand: Choice for randomising the sampling strategy.
+
+ Returns:
+ Tuple of flattened rays and sample points on each rays.
+ """
+ # Compute 3D query points.
+ # Equation: r(t) = o+td -> Building the "t" here.
+ t_vals = tf.linspace(near, far, num_samples)
+ if rand:
+ # Inject uniform noise into sample space to make the sampling
+ # continuous.
+ shape = list(ray_origins.shape[:-1]) + [num_samples]
+ noise = tf.random.uniform(shape=shape) * (far - near) / num_samples
+ t_vals = t_vals + noise
+
+ # Equation: r(t) = o + td -> Building the "r" here.
+ rays = ray_origins[..., None, :] + (
+ ray_directions[..., None, :] * t_vals[..., None]
+ )
+ rays_flat = tf.reshape(rays, [-1, 3])
+ rays_flat = encode_position(rays_flat)
+ return (rays_flat, t_vals)
+
+
+def map_fn(pose):
+ """Maps individual pose to flattened rays and sample points.
+
+ Args:
+ pose: The pose matrix of the camera.
+
+ Returns:
+ Tuple of flattened rays and sample points corresponding to the
+ camera pose.
+ """
+ (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
+ (rays_flat, t_vals) = render_flat_rays(
+ ray_origins=ray_origins,
+ ray_directions=ray_directions,
+ near=2.0,
+ far=6.0,
+ num_samples=NUM_SAMPLES,
+ rand=True,
+ )
+ return (rays_flat, t_vals)
+
+
+# Create the training split.
+split_index = int(num_images * 0.8)
+
+# Split the images into training and validation.
+train_images = images[:split_index]
+val_images = images[split_index:]
+
+# Split the poses into training and validation.
+train_poses = poses[:split_index]
+val_poses = poses[split_index:]
+
+# Make the training pipeline.
+train_img_ds = tf.data.Dataset.from_tensor_slices(train_images)
+train_pose_ds = tf.data.Dataset.from_tensor_slices(train_poses)
+train_ray_ds = train_pose_ds.map(map_fn, num_parallel_calls=AUTO)
+training_ds = tf.data.Dataset.zip((train_img_ds, train_ray_ds))
+train_ds = (
+ training_ds.shuffle(BATCH_SIZE)
+ .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
+ .prefetch(AUTO)
+)
+
+# Make the validation pipeline.
+val_img_ds = tf.data.Dataset.from_tensor_slices(val_images)
+val_pose_ds = tf.data.Dataset.from_tensor_slices(val_poses)
+val_ray_ds = val_pose_ds.map(map_fn, num_parallel_calls=AUTO)
+validation_ds = tf.data.Dataset.zip((val_img_ds, val_ray_ds))
+val_ds = (
+ validation_ds.shuffle(BATCH_SIZE)
+ .batch(BATCH_SIZE, drop_remainder=True, num_parallel_calls=AUTO)
+ .prefetch(AUTO)
+)
+
+"""
+## NeRF model
+
+The model is a multi-layer perceptron (MLP), with ReLU as its non-linearity.
+
+An excerpt from the paper:
+
+*"We encourage the representation to be multiview-consistent by
+restricting the network to predict the volume density sigma as a
+function of only the location `x`, while allowing the RGB color `c` to be
+predicted as a function of both location and viewing direction. To
+accomplish this, the MLP first processes the input 3D coordinate `x`
+with 8 fully-connected layers (using ReLU activations and 256 channels
+per layer), and outputs sigma and a 256-dimensional feature vector.
+This feature vector is then concatenated with the camera ray's viewing
+direction and passed to one additional fully-connected layer (using a
+ReLU activation and 128 channels) that output the view-dependent RGB
+color."*
+
+Here we have gone for a minimal implementation and have used 64
+Dense units instead of 256 as mentioned in the paper.
+"""
+
+
+def get_nerf_model(num_layers, num_pos):
+ """Generates the NeRF neural network.
+
+ Args:
+ num_layers: The number of MLP layers.
+ num_pos: The number of dimensions of positional encoding.
+
+ Returns:
+ The `keras` model.
+ """
+ inputs = keras.Input(shape=(num_pos, 2 * 3 * POS_ENCODE_DIMS + 3))
+ x = inputs
+ for i in range(num_layers):
+ x = layers.Dense(units=64, activation="relu")(x)
+ if i % 4 == 0 and i > 0:
+ # Inject residual connection.
+ x = layers.concatenate([x, inputs], axis=-1)
+ outputs = layers.Dense(units=4)(x)
+ return keras.Model(inputs=inputs, outputs=outputs)
+
+
+def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
+ """Generates the RGB image and depth map from model prediction.
+
+ Args:
+ model: The MLP model that is trained to predict the rgb and
+ volume density of the volumetric scene.
+ rays_flat: The flattened rays that serve as the input to
+ the NeRF model.
+ t_vals: The sample points for the rays.
+ rand: Choice to randomise the sampling strategy.
+ train: Whether the model is in the training or testing phase.
+
+ Returns:
+ Tuple of rgb image and depth map.
+ """
+ # Get the predictions from the nerf model and reshape it.
+ if train:
+ predictions = model(rays_flat)
+ else:
+ predictions = model.predict(rays_flat)
+ predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))
+
+ # Slice the predictions into rgb and sigma.
+ rgb = tf.sigmoid(predictions[..., :-1])
+ sigma_a = tf.nn.relu(predictions[..., -1])
+
+ # Get the distance of adjacent intervals.
+ delta = t_vals[..., 1:] - t_vals[..., :-1]
+ # delta shape = (num_samples)
+ if rand:
+ delta = tf.concat(
+ [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
+ )
+ alpha = 1.0 - tf.exp(-sigma_a * delta)
+ else:
+ delta = tf.concat(
+ [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
+ )
+ alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])
+
+ # Get transmittance.
+ exp_term = 1.0 - alpha
+ epsilon = 1e-10
+ transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
+ weights = alpha * transmittance
+ rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
+
+ if rand:
+ depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
+ else:
+ depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
+ return (rgb, depth_map)
+
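+"""
+For reference, `render_rgb_depth` implements the discrete volume-rendering quadrature
+from the NeRF paper: with sample spacing `delta_i`, each sample gets an opacity
+`alpha_i = 1 - exp(-sigma_i * delta_i)` and a transmittance
+`T_i = prod_{j < i} (1 - alpha_j)`; the pixel color is then
+`sum_i T_i * alpha_i * rgb_i` and the depth is `sum_i T_i * alpha_i * t_i`.
+"""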
+
+"""
+## Training
+
+The training step is implemented as part of a custom `keras.Model` subclass
+so that we can make use of the `model.fit` functionality.
+"""
+
+
+class NeRF(keras.Model):
+ def __init__(self, nerf_model):
+ super().__init__()
+ self.nerf_model = nerf_model
+
+ def compile(self, optimizer, loss_fn):
+ super().compile()
+ self.optimizer = optimizer
+ self.loss_fn = loss_fn
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+ self.psnr_metric = keras.metrics.Mean(name="psnr")
+
+ def train_step(self, inputs):
+ # Get the images and the rays.
+ (images, rays) = inputs
+ (rays_flat, t_vals) = rays
+
+ with tf.GradientTape() as tape:
+ # Get the predictions from the model.
+ rgb, _ = render_rgb_depth(
+ model=self.nerf_model, rays_flat=rays_flat, t_vals=t_vals, rand=True
+ )
+ loss = self.loss_fn(images, rgb)
+
+ # Get the trainable variables.
+ trainable_variables = self.nerf_model.trainable_variables
+
+        # Get the gradients of the trainable variables with respect to the loss.
+ gradients = tape.gradient(loss, trainable_variables)
+
+ # Apply the grads and optimize the model.
+ self.optimizer.apply_gradients(zip(gradients, trainable_variables))
+
+ # Get the PSNR of the reconstructed images and the source images.
+ psnr = tf.image.psnr(images, rgb, max_val=1.0)
+
+ # Compute our own metrics
+ self.loss_tracker.update_state(loss)
+ self.psnr_metric.update_state(psnr)
+ return {"loss": self.loss_tracker.result(), "psnr": self.psnr_metric.result()}
+
+ def test_step(self, inputs):
+ # Get the images and the rays.
+ (images, rays) = inputs
+ (rays_flat, t_vals) = rays
+
+ # Get the predictions from the model.
+ rgb, _ = render_rgb_depth(
+ model=self.nerf_model, rays_flat=rays_flat, t_vals=t_vals, rand=True
+ )
+ loss = self.loss_fn(images, rgb)
+
+ # Get the PSNR of the reconstructed images and the source images.
+ psnr = tf.image.psnr(images, rgb, max_val=1.0)
+
+ # Compute our own metrics
+ self.loss_tracker.update_state(loss)
+ self.psnr_metric.update_state(psnr)
+ return {"loss": self.loss_tracker.result(), "psnr": self.psnr_metric.result()}
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker, self.psnr_metric]
+
+
+test_imgs, test_rays = next(iter(train_ds))
+test_rays_flat, test_t_vals = test_rays
+
+loss_list = []
+
+
+class TrainMonitor(keras.callbacks.Callback):
+ def on_epoch_end(self, epoch, logs=None):
+ loss = logs["loss"]
+ loss_list.append(loss)
+ test_recons_images, depth_maps = render_rgb_depth(
+ model=self.model.nerf_model,
+ rays_flat=test_rays_flat,
+ t_vals=test_t_vals,
+ rand=True,
+ train=False,
+ )
+
+ # Plot the rgb, depth and the loss plot.
+ fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 5))
+ ax[0].imshow(keras.utils.array_to_img(test_recons_images[0]))
+ ax[0].set_title(f"Predicted Image: {epoch:03d}")
+
+ ax[1].imshow(keras.utils.array_to_img(depth_maps[0, ..., None]))
+ ax[1].set_title(f"Depth Map: {epoch:03d}")
+
+ ax[2].plot(loss_list)
+ ax[2].set_xticks(np.arange(0, EPOCHS + 1, 5.0))
+ ax[2].set_title(f"Loss Plot: {epoch:03d}")
+
+ fig.savefig(f"images/{epoch:03d}.png")
+ plt.show()
+ plt.close()
+
+
+num_pos = H * W * NUM_SAMPLES
+nerf_model = get_nerf_model(num_layers=8, num_pos=num_pos)
+
+model = NeRF(nerf_model)
+model.compile(
+ optimizer=keras.optimizers.Adam(), loss_fn=keras.losses.MeanSquaredError()
+)
+
+# Create a directory to save the images during training.
+if not os.path.exists("images"):
+ os.makedirs("images")
+
+model.fit(
+ train_ds,
+ validation_data=val_ds,
+ batch_size=BATCH_SIZE,
+ epochs=EPOCHS,
+ callbacks=[TrainMonitor()],
+)
+
+
+def create_gif(path_to_images, name_gif):
+ filenames = glob.glob(path_to_images)
+ filenames = sorted(filenames)
+ images = []
+ for filename in tqdm(filenames):
+ images.append(imageio.imread(filename))
+ kargs = {"duration": 0.25}
+ imageio.mimsave(name_gif, images, "GIF", **kargs)
+
+
+create_gif("images/*.png", "training.gif")
+
+"""
+## Visualize the training step
+
+Here we see the training step. With the decreasing loss, the rendered
+image and the depth maps are getting better. In your local system, you
+will see the `training.gif` file generated.
+
+
+"""
+
+"""
+## Inference
+
+In this section, we ask the model to build novel views of the scene.
+The model was given `106` views of the scene in the training step. The
+collection of training images cannot contain each and every angle of
+the scene, yet a trained model can represent the entire 3-D scene from
+this sparse set of training images.
+
+Here we provide different poses to the model and ask it to give us the
+2-D image corresponding to that camera view. If we run inference for
+all 360 degrees, it should provide an overview of the entire scene
+from all around.
+"""
+
+# Get the trained NeRF model and infer.
+nerf_model = model.nerf_model
+test_recons_images, depth_maps = render_rgb_depth(
+ model=nerf_model,
+ rays_flat=test_rays_flat,
+ t_vals=test_t_vals,
+ rand=True,
+ train=False,
+)
+
+# Create subplots.
+fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(10, 20))
+
+for ax, ori_img, recons_img, depth_map in zip(
+ axes, test_imgs, test_recons_images, depth_maps
+):
+ ax[0].imshow(keras.utils.array_to_img(ori_img))
+ ax[0].set_title("Original")
+
+ ax[1].imshow(keras.utils.array_to_img(recons_img))
+ ax[1].set_title("Reconstructed")
+
+ ax[2].imshow(keras.utils.array_to_img(depth_map[..., None]), cmap="inferno")
+ ax[2].set_title("Depth Map")
+
+"""
+## Render 3D Scene
+
+Here we will synthesize novel 3D views and stitch all of them together
+to render a video encompassing the 360-degree view.
+"""
+
+
+def get_translation_t(t):
+ """Get the translation matrix for movement in t."""
+ matrix = [
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, t],
+ [0, 0, 0, 1],
+ ]
+ return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+def get_rotation_phi(phi):
+ """Get the rotation matrix for movement in phi."""
+ matrix = [
+ [1, 0, 0, 0],
+ [0, tf.cos(phi), -tf.sin(phi), 0],
+ [0, tf.sin(phi), tf.cos(phi), 0],
+ [0, 0, 0, 1],
+ ]
+ return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+def get_rotation_theta(theta):
+ """Get the rotation matrix for movement in theta."""
+ matrix = [
+ [tf.cos(theta), 0, -tf.sin(theta), 0],
+ [0, 1, 0, 0],
+ [tf.sin(theta), 0, tf.cos(theta), 0],
+ [0, 0, 0, 1],
+ ]
+ return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+def pose_spherical(theta, phi, t):
+ """
+ Get the camera to world matrix for the corresponding theta, phi
+ and t.
+ """
+ c2w = get_translation_t(t)
+ c2w = get_rotation_phi(phi / 180.0 * np.pi) @ c2w
+ c2w = get_rotation_theta(theta / 180.0 * np.pi) @ c2w
+ c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
+ return c2w
+
+
+rgb_frames = []
+batch_flat = []
+batch_t = []
+
+# Iterate over different theta value and generate scenes.
+for index, theta in tqdm(enumerate(np.linspace(0.0, 360.0, 120, endpoint=False))):
+ # Get the camera to world matrix.
+ c2w = pose_spherical(theta, -30.0, 4.0)
+
+    # Compute the rays and the flattened sample points for this pose.
+ ray_oris, ray_dirs = get_rays(H, W, focal, c2w)
+ rays_flat, t_vals = render_flat_rays(
+ ray_oris, ray_dirs, near=2.0, far=6.0, num_samples=NUM_SAMPLES, rand=False
+ )
+
+ if index % BATCH_SIZE == 0 and index > 0:
+ batched_flat = tf.stack(batch_flat, axis=0)
+ batch_flat = [rays_flat]
+
+ batched_t = tf.stack(batch_t, axis=0)
+ batch_t = [t_vals]
+
+ rgb, _ = render_rgb_depth(
+ nerf_model, batched_flat, batched_t, rand=False, train=False
+ )
+
+ temp_rgb = [np.clip(255 * img, 0.0, 255.0).astype(np.uint8) for img in rgb]
+
+ rgb_frames = rgb_frames + temp_rgb
+ else:
+ batch_flat.append(rays_flat)
+ batch_t.append(t_vals)
+
+rgb_video = "rgb_video.mp4"
+imageio.mimwrite(rgb_video, rgb_frames, fps=30, quality=7, macro_block_size=None)
+
+"""
+### Visualize the video
+
+Here we can see the rendered 360-degree view of the scene. The model
+has successfully learned the entire volumetric space through the
+sparse set of images in **only 20 epochs**. You can view the
+rendered video saved locally, named `rgb_video.mp4`.
+
+
+"""
+
+"""
+## Conclusion
+
+We have produced a minimal implementation of NeRF to provide an intuition of its
+core ideas and methodology. This method has been used in various
+other works in the computer graphics space.
+
+We would like to encourage our readers to use this code as an example
+and play with the hyperparameters and visualize the outputs. Below we
+have also provided the outputs of the model trained for more epochs.
+
+| Epochs | GIF of the training step |
+| :--- | :---: |
+| **100** |  |
+| **200** |  |
+
+## Way forward
+
+If you are interested in going deeper into NeRF, we have built a 3-part blog
+series at [PyImageSearch](https://pyimagesearch.com/).
+
+- [Prerequisites of NeRF](https://www.pyimagesearch.com/2021/11/10/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-1/)
+- [Concepts of NeRF](https://www.pyimagesearch.com/2021/11/17/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-2/)
+- [Implementing NeRF](https://www.pyimagesearch.com/2021/11/24/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-3/)
+
+## Reference
+
+- [NeRF repository](https://github.com/bmild/nerf): The official
+ repository for NeRF.
+- [NeRF paper](https://arxiv.org/abs/2003.08934): The paper on NeRF.
+- [Manim Repository](https://github.com/3b1b/manim): We have used
+ manim to build all the animations.
+- [Mathworks](https://www.mathworks.com/help/vision/ug/camera-calibration.html):
+ Mathworks for the camera calibration article.
+- [Mathew's video](https://www.youtube.com/watch?v=dPWLybp4LL0): A
+ great video on NeRF.
+
+You can try the model on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/NeRF).
+"""
diff --git a/knowledge_base/vision/nl_image_search.py b/knowledge_base/vision/nl_image_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..942d808656141c926855de3806000dcbec6b8565
--- /dev/null
+++ b/knowledge_base/vision/nl_image_search.py
@@ -0,0 +1,594 @@
+"""
+Title: Natural language image search with a Dual Encoder
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/01/30
+Last modified: 2021/01/30
+Description: Implementation of a dual encoder model for retrieving images that match natural language queries.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to build a dual encoder (also known as a two-tower) neural network
+model to search for images using natural language. The model is inspired by
+the [CLIP](https://openai.com/blog/clip/)
+approach, introduced by Alec Radford et al. The idea is to train a vision encoder and a text
+encoder jointly to project the representation of images and their captions into the same embedding
+space, such that the caption embeddings are located near the embeddings of the images they describe.
+
+This example requires TensorFlow 2.4 or higher.
+In addition, [TensorFlow Hub](https://www.tensorflow.org/hub)
+and [TensorFlow Text](https://www.tensorflow.org/tutorials/tensorflow_text/intro)
+are required for the BERT model, and [TensorFlow Addons](https://www.tensorflow.org/addons)
+is required for the AdamW optimizer. These libraries can be installed using the
+following command:
+
+```python
+pip install -q -U tensorflow-hub tensorflow-text tensorflow-addons
+```
+"""
+
+"""
+## Setup
+"""
+
+import os
+import collections
+import json
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import tensorflow_hub as hub
+import tensorflow_text as text
+import tensorflow_addons as tfa
+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
+from tqdm import tqdm
+
+# Suppressing tf.hub warnings
+tf.get_logger().setLevel("ERROR")
+
+"""
+## Prepare the data
+
+We will use the [MS-COCO](https://cocodataset.org/#home) dataset to train our
+dual encoder model. MS-COCO contains over 82,000 images, each of which has at least
+5 different caption annotations. The dataset is usually used for
+[image captioning](https://www.tensorflow.org/tutorials/text/image_captioning)
+tasks, but we can repurpose the image-caption pairs to train our dual encoder
+model for image search.
+
+### Download and extract the data
+
+First, let's download the dataset, which consists of two compressed folders:
+one with images, and the other with the associated image captions.
+Note that the compressed images folder is 13GB in size.
+"""
+
+root_dir = "datasets"
+annotations_dir = os.path.join(root_dir, "annotations")
+images_dir = os.path.join(root_dir, "train2014")
+tfrecords_dir = os.path.join(root_dir, "tfrecords")
+annotation_file = os.path.join(annotations_dir, "captions_train2014.json")
+
+# Download caption annotation files
+if not os.path.exists(annotations_dir):
+ annotation_zip = tf.keras.utils.get_file(
+ "captions.zip",
+ cache_dir=os.path.abspath("."),
+ origin="http://images.cocodataset.org/annotations/annotations_trainval2014.zip",
+ extract=True,
+ )
+ os.remove(annotation_zip)
+
+# Download image files
+if not os.path.exists(images_dir):
+ image_zip = tf.keras.utils.get_file(
+ "train2014.zip",
+ cache_dir=os.path.abspath("."),
+ origin="http://images.cocodataset.org/zips/train2014.zip",
+ extract=True,
+ )
+ os.remove(image_zip)
+
+print("Dataset is downloaded and extracted successfully.")
+
+with open(annotation_file, "r") as f:
+ annotations = json.load(f)["annotations"]
+
+image_path_to_caption = collections.defaultdict(list)
+for element in annotations:
+ caption = f"{element['caption'].lower().rstrip('.')}"
+ image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
+ image_path_to_caption[image_path].append(caption)
+
+image_paths = list(image_path_to_caption.keys())
+print(f"Number of images: {len(image_paths)}")
+
+"""
+### Process and save the data to TFRecord files
+
+You can change the `train_size` parameter to control how many image-caption pairs
+will be used for training the dual encoder model.
+In this example we set `train_size` to 30,000 images,
+which is about 35% of the dataset. We use 2 captions for each
+image, thus producing 60,000 image-caption pairs. The size of the training set
+affects the quality of the produced encoders, but more examples would lead to
+longer training time.
+"""
+
+train_size = 30000
+valid_size = 5000
+captions_per_image = 2
+images_per_file = 2000
+
+train_image_paths = image_paths[:train_size]
+num_train_files = int(np.ceil(train_size / images_per_file))
+train_files_prefix = os.path.join(tfrecords_dir, "train")
+
+valid_image_paths = image_paths[-valid_size:]
+num_valid_files = int(np.ceil(valid_size / images_per_file))
+valid_files_prefix = os.path.join(tfrecords_dir, "valid")
+
+tf.io.gfile.makedirs(tfrecords_dir)
+
+
+def bytes_feature(value):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def create_example(image_path, caption):
+ feature = {
+ "caption": bytes_feature(caption.encode()),
+ "raw_image": bytes_feature(tf.io.read_file(image_path).numpy()),
+ }
+ return tf.train.Example(features=tf.train.Features(feature=feature))
+
+
+def write_tfrecords(file_name, image_paths):
+ caption_list = []
+ image_path_list = []
+ for image_path in image_paths:
+ captions = image_path_to_caption[image_path][:captions_per_image]
+ caption_list.extend(captions)
+ image_path_list.extend([image_path] * len(captions))
+
+ with tf.io.TFRecordWriter(file_name) as writer:
+ for example_idx in range(len(image_path_list)):
+ example = create_example(
+ image_path_list[example_idx], caption_list[example_idx]
+ )
+ writer.write(example.SerializeToString())
+ return example_idx + 1
+
+
+def write_data(image_paths, num_files, files_prefix):
+ example_counter = 0
+ for file_idx in tqdm(range(num_files)):
+ file_name = files_prefix + "-%02d.tfrecord" % (file_idx)
+ start_idx = images_per_file * file_idx
+ end_idx = start_idx + images_per_file
+ example_counter += write_tfrecords(file_name, image_paths[start_idx:end_idx])
+ return example_counter
+
+
+train_example_count = write_data(train_image_paths, num_train_files, train_files_prefix)
+print(f"{train_example_count} training examples were written to tfrecord files.")
+
+valid_example_count = write_data(valid_image_paths, num_valid_files, valid_files_prefix)
+print(f"{valid_example_count} evaluation examples were written to tfrecord files.")
+
+"""
+### Create `tf.data.Dataset` for training and evaluation
+"""
+
+
+feature_description = {
+ "caption": tf.io.FixedLenFeature([], tf.string),
+ "raw_image": tf.io.FixedLenFeature([], tf.string),
+}
+
+
+def read_example(example):
+ features = tf.io.parse_single_example(example, feature_description)
+ raw_image = features.pop("raw_image")
+ features["image"] = tf.image.resize(
+ tf.image.decode_jpeg(raw_image, channels=3), size=(299, 299)
+ )
+ return features
+
+
+def get_dataset(file_pattern, batch_size):
+ return (
+ tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))
+ .map(
+ read_example,
+ num_parallel_calls=tf.data.AUTOTUNE,
+ deterministic=False,
+ )
+ .shuffle(batch_size * 10)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+ .batch(batch_size)
+ )
+
+
+"""
+## Implement the projection head
+
+The projection head is used to transform the image and the text embeddings to
+the same embedding space with the same dimensionality.
+"""
+
+
+def project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+):
+ projected_embeddings = layers.Dense(units=projection_dims)(embeddings)
+ for _ in range(num_projection_layers):
+ x = tf.nn.gelu(projected_embeddings)
+ x = layers.Dense(projection_dims)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ x = layers.Add()([projected_embeddings, x])
+ projected_embeddings = layers.LayerNormalization()(x)
+ return projected_embeddings
+
+
+"""
+## Implement the vision encoder
+
+In this example, we use [Xception](https://keras.io/api/applications/xception/)
+from [Keras Applications](https://keras.io/api/applications/) as the base for the
+vision encoder.
+"""
+
+
+def create_vision_encoder(
+ num_projection_layers, projection_dims, dropout_rate, trainable=False
+):
+ # Load the pre-trained Xception model to be used as the base encoder.
+ xception = keras.applications.Xception(
+ include_top=False, weights="imagenet", pooling="avg"
+ )
+ # Set the trainability of the base encoder.
+ for layer in xception.layers:
+ layer.trainable = trainable
+ # Receive the images as inputs.
+ inputs = layers.Input(shape=(299, 299, 3), name="image_input")
+ # Preprocess the input image.
+ xception_input = tf.keras.applications.xception.preprocess_input(inputs)
+ # Generate the embeddings for the images using the xception model.
+ embeddings = xception(xception_input)
+ # Project the embeddings produced by the model.
+ outputs = project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+ )
+ # Create the vision encoder model.
+ return keras.Model(inputs, outputs, name="vision_encoder")
+
+
+"""
+## Implement the text encoder
+
+We use [BERT](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1)
+from [TensorFlow Hub](https://tfhub.dev) as the text encoder.
+"""
+
+
+def create_text_encoder(
+ num_projection_layers, projection_dims, dropout_rate, trainable=False
+):
+ # Load the BERT preprocessing module.
+ preprocess = hub.KerasLayer(
+ "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2",
+ name="text_preprocessing",
+ )
+ # Load the pre-trained BERT model to be used as the base encoder.
+ bert = hub.KerasLayer(
+ "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1",
+ "bert",
+ )
+ # Set the trainability of the base encoder.
+ bert.trainable = trainable
+ # Receive the text as inputs.
+ inputs = layers.Input(shape=(), dtype=tf.string, name="text_input")
+ # Preprocess the text.
+ bert_inputs = preprocess(inputs)
+ # Generate embeddings for the preprocessed text using the BERT model.
+ embeddings = bert(bert_inputs)["pooled_output"]
+ # Project the embeddings produced by the model.
+ outputs = project_embeddings(
+ embeddings, num_projection_layers, projection_dims, dropout_rate
+ )
+ # Create the text encoder model.
+ return keras.Model(inputs, outputs, name="text_encoder")
+
+
+"""
+## Implement the dual encoder
+
+To calculate the loss, we compute the pairwise dot-product similarity between
+each `caption_i` and `images_j` in the batch as the predictions.
+The target similarity between `caption_i` and `image_j` is computed as
+the average of the (dot-product similarity between `caption_i` and `caption_j`)
+and (the dot-product similarity between `image_i` and `image_j`).
+Then, we use crossentropy to compute the loss between the targets and the predictions.
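+
+In symbols, with caption embeddings `C`, image embeddings `I` (one row per batch item),
+and temperature `tau`, the code below computes `logits = C @ I.T / tau` and
+`targets = softmax((C @ C.T + I @ I.T) / (2 * tau))`; the final loss is the mean of
+the caption-side and image-side crossentropies between `targets` and `logits`.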
+"""
+
+
+class DualEncoder(keras.Model):
+ def __init__(self, text_encoder, image_encoder, temperature=1.0, **kwargs):
+ super().__init__(**kwargs)
+ self.text_encoder = text_encoder
+ self.image_encoder = image_encoder
+ self.temperature = temperature
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker]
+
+ def call(self, features, training=False):
+ # Place each encoder on a separate GPU (if available).
+        # TF will fall back to available devices if there are fewer than 2 GPUs.
+ with tf.device("/gpu:0"):
+ # Get the embeddings for the captions.
+ caption_embeddings = text_encoder(features["caption"], training=training)
+ with tf.device("/gpu:1"):
+ # Get the embeddings for the images.
+ image_embeddings = vision_encoder(features["image"], training=training)
+ return caption_embeddings, image_embeddings
+
+ def compute_loss(self, caption_embeddings, image_embeddings):
+ # logits[i][j] is the dot_similarity(caption_i, image_j).
+ logits = (
+ tf.matmul(caption_embeddings, image_embeddings, transpose_b=True)
+ / self.temperature
+ )
+ # images_similarity[i][j] is the dot_similarity(image_i, image_j).
+ images_similarity = tf.matmul(
+ image_embeddings, image_embeddings, transpose_b=True
+ )
+ # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
+ captions_similarity = tf.matmul(
+ caption_embeddings, caption_embeddings, transpose_b=True
+ )
+        # targets[i][j] = average of dot_similarity(caption_i, caption_j) and dot_similarity(image_i, image_j).
+ targets = keras.activations.softmax(
+ (captions_similarity + images_similarity) / (2 * self.temperature)
+ )
+ # Compute the loss for the captions using crossentropy
+ captions_loss = keras.losses.categorical_crossentropy(
+ y_true=targets, y_pred=logits, from_logits=True
+ )
+ # Compute the loss for the images using crossentropy
+ images_loss = keras.losses.categorical_crossentropy(
+ y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True
+ )
+ # Return the mean of the loss over the batch.
+ return (captions_loss + images_loss) / 2
+
+ def train_step(self, features):
+ with tf.GradientTape() as tape:
+ # Forward pass
+ caption_embeddings, image_embeddings = self(features, training=True)
+ loss = self.compute_loss(caption_embeddings, image_embeddings)
+ # Backward pass
+ gradients = tape.gradient(loss, self.trainable_variables)
+ self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
+ # Monitor loss
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+ def test_step(self, features):
+ caption_embeddings, image_embeddings = self(features, training=False)
+ loss = self.compute_loss(caption_embeddings, image_embeddings)
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+## Train the dual encoder model
+
+In this experiment, we freeze the base encoders for text and images, and make only
+the projection head trainable.
+"""
+
+num_epochs = 5 # In practice, train for at least 30 epochs
+batch_size = 256
+
+vision_encoder = create_vision_encoder(
+ num_projection_layers=1, projection_dims=256, dropout_rate=0.1
+)
+text_encoder = create_text_encoder(
+ num_projection_layers=1, projection_dims=256, dropout_rate=0.1
+)
+dual_encoder = DualEncoder(text_encoder, vision_encoder, temperature=0.05)
+dual_encoder.compile(
+ optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001)
+)
+
+"""
+Note that training the model with 60,000 image-caption pairs, with a batch size of 256,
+takes around 12 minutes per epoch using a V100 GPU accelerator. If 2 GPUs are available,
+the epoch takes around 8 minutes.
+"""
+
+print(f"Number of GPUs: {len(tf.config.list_physical_devices('GPU'))}")
+print(f"Number of examples (caption-image pairs): {train_example_count}")
+print(f"Batch size: {batch_size}")
+print(f"Steps per epoch: {int(np.ceil(train_example_count / batch_size))}")
+train_dataset = get_dataset(os.path.join(tfrecords_dir, "train-*.tfrecord"), batch_size)
+valid_dataset = get_dataset(os.path.join(tfrecords_dir, "valid-*.tfrecord"), batch_size)
+# Create a learning rate scheduler callback.
+reduce_lr = keras.callbacks.ReduceLROnPlateau(
+ monitor="val_loss", factor=0.2, patience=3
+)
+# Create an early stopping callback.
+early_stopping = tf.keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=5, restore_best_weights=True
+)
+history = dual_encoder.fit(
+ train_dataset,
+ epochs=num_epochs,
+ validation_data=valid_dataset,
+ callbacks=[reduce_lr, early_stopping],
+)
+print("Training completed. Saving vision and text encoders...")
+vision_encoder.save("vision_encoder")
+text_encoder.save("text_encoder")
+print("Models are saved.")
+
+"""
+Plotting the training loss:
+"""
+
+plt.plot(history.history["loss"])
+plt.plot(history.history["val_loss"])
+plt.ylabel("Loss")
+plt.xlabel("Epoch")
+plt.legend(["train", "valid"], loc="upper right")
+plt.show()
+
+"""
+## Search for images using natural language queries
+
+We can then retrieve images corresponding to natural language queries via
+the following steps:
+
+1. Generate embeddings for the images by feeding them into the `vision_encoder`.
+2. Feed the natural language query to the `text_encoder` to generate a query embedding.
+3. Compute the similarity between the query embedding and the image embeddings
+in the index to retrieve the indices of the top matches.
+4. Look up the paths of the top matching images to display them.
+
+Note that, after training the `dual encoder`, only the fine-tuned `vision_encoder`
+and `text_encoder` models will be used, while the `dual_encoder` model will be discarded.
+"""
+
+"""
+### Generate embeddings for the images
+
+We load the images and feed them into the `vision_encoder` to generate their embeddings.
+In large scale systems, this step is performed using a parallel data processing framework,
+such as [Apache Spark](https://spark.apache.org) or [Apache Beam](https://beam.apache.org).
+Generating the image embeddings may take several minutes.
+"""
+print("Loading vision and text encoders...")
+vision_encoder = keras.models.load_model("vision_encoder")
+text_encoder = keras.models.load_model("text_encoder")
+print("Models are loaded.")
+
+
+def read_image(image_path):
+ image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
+ return tf.image.resize(image_array, (299, 299))
+
+
+print(f"Generating embeddings for {len(image_paths)} images...")
+image_embeddings = vision_encoder.predict(
+ tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),
+ verbose=1,
+)
+print(f"Image embeddings shape: {image_embeddings.shape}.")
+
+"""
+### Retrieve relevant images
+
+In this example, we use exact matching by computing the dot product similarity
+between the input query embedding and the image embeddings, and retrieve the top k
+matches. However, *approximate* similarity matching, using frameworks like
+[ScaNN](https://github.com/google-research/google-research/tree/master/scann),
+[Annoy](https://github.com/spotify/annoy), or [Faiss](https://github.com/facebookresearch/faiss)
+is preferred in real-time use cases to scale with a large number of images.
+"""
+
+
+def find_matches(image_embeddings, queries, k=9, normalize=True):
+ # Get the embedding for the query.
+ query_embedding = text_encoder(tf.convert_to_tensor(queries))
+ # Normalize the query and the image embeddings.
+ if normalize:
+ image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
+ query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
+ # Compute the dot product between the query and the image embeddings.
+ dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
+ # Retrieve top k indices.
+ results = tf.math.top_k(dot_similarity, k).indices.numpy()
+ # Return matching image paths.
+ return [[image_paths[idx] for idx in indices] for indices in results]
+
+
+"""
+Set the `query` variable to the type of images you want to search for.
+Try things like: 'a plate of healthy food',
+'a woman wearing a hat is walking down a sidewalk',
+'a bird sits near to the water', or 'wild animals are standing in a field'.
+"""
+
+query = "a family standing next to the ocean on a sandy beach with a surf board"
+matches = find_matches(image_embeddings, [query], normalize=True)[0]
+
+plt.figure(figsize=(20, 20))
+for i in range(9):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(mpimg.imread(matches[i]))
+ plt.axis("off")
+
+
+"""
+## Evaluate the retrieval quality
+
+To evaluate the dual encoder model, we use the captions as queries.
+We use the out-of-training-sample images and captions to evaluate the retrieval quality,
+using top k accuracy. A true prediction is counted if, for a given caption, its associated image
+is retrieved within the top k matches.
+"""
+
+
+def compute_top_k_accuracy(image_paths, k=100):
+ hits = 0
+ num_batches = int(np.ceil(len(image_paths) / batch_size))
+ for idx in tqdm(range(num_batches)):
+ start_idx = idx * batch_size
+ end_idx = start_idx + batch_size
+ current_image_paths = image_paths[start_idx:end_idx]
+ queries = [
+ image_path_to_caption[image_path][0] for image_path in current_image_paths
+ ]
+ result = find_matches(image_embeddings, queries, k)
+ hits += sum(
+ [
+ image_path in matches
+ for (image_path, matches) in list(zip(current_image_paths, result))
+ ]
+ )
+
+ return hits / len(image_paths)
+
+
+print("Scoring training data...")
+train_accuracy = compute_top_k_accuracy(train_image_paths)
+print(f"Train accuracy: {round(train_accuracy * 100, 3)}%")
+
+print("Scoring evaluation data...")
+eval_accuracy = compute_top_k_accuracy(image_paths[train_size:])
+print(f"Eval accuracy: {round(eval_accuracy * 100, 3)}%")
+
+
+"""
+## Final remarks
+
+You can obtain better results by increasing the size of the training sample,
+train for more epochs, explore other base encoders for images and text,
+set the base encoders to be trainable, and tune the hyperparameters,
+especially the `temperature` for the softmax in the loss computation.
+
+An example is available on Hugging Face:
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/dual-encoder-image-search) | [](https://huggingface.co/spaces/keras-io/dual-encoder-image-search) |
+"""
diff --git a/knowledge_base/vision/nnclr.py b/knowledge_base/vision/nnclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..68c1443928aa331101692777f9f797f718d9856c
--- /dev/null
+++ b/knowledge_base/vision/nnclr.py
@@ -0,0 +1,527 @@
+"""
+Title: Self-supervised contrastive learning with NNCLR
+Author: [Rishit Dagli](https://twitter.com/rishit_dagli)
+Date created: 2021/09/13
+Last modified: 2024/01/22
+Description: Implementation of NNCLR, a self-supervised learning method for computer vision.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+### Self-supervised learning
+
+Self-supervised representation learning aims to obtain robust representations of samples
+from raw data without expensive labels or annotations. Early methods in this field
+focused on defining pretraining tasks which involved a surrogate task on a domain with ample
+weak supervision labels. Encoders trained to solve such tasks are expected to
+learn general features that might be useful for other downstream tasks requiring
+expensive annotations like image classification.
+
+### Contrastive Learning
+
+A broad category of self-supervised learning techniques are those that use *contrastive
+losses*, which have been used in a wide range of computer vision applications like
+[image similarity](https://www.jmlr.org/papers/v11/chechik10a.html),
+[dimensionality reduction (DrLIM)](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
+and [face verification/identification](https://openaccess.thecvf.com/content_cvpr_2015/html/Schroff_FaceNet_A_Unified_2015_CVPR_paper.html).
+These methods learn a latent space that clusters positive samples together while
+pushing apart negative samples.
+
+### NNCLR
+
+In this example, we implement NNCLR as proposed in the paper
+[With a Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations](https://arxiv.org/abs/2104.14548),
+by Google Research and DeepMind.
+
+NNCLR learns self-supervised representations that go beyond single-instance positives, which
+allows for learning better features that are invariant to different viewpoints, deformations,
+and even intra-class variations.
+Clustering-based methods offer a great approach to go beyond single instance positives,
+but assuming the entire cluster to be positives could hurt performance due to early
+over-generalization. Instead, NNCLR uses nearest neighbors in the learned representation
+space as positives.
+In addition, NNCLR increases the performance of existing contrastive learning methods like
+[SimCLR](https://arxiv.org/abs/2002.05709)([Keras Example](https://keras.io/examples/vision/semisupervised_simclr))
+and reduces the reliance of self-supervised methods on data augmentation strategies.
+
+Here is a great visualization by the paper authors showing how NNCLR builds on ideas from
+SimCLR:
+
+
+
+We can see that SimCLR uses two views of the same image as the positive pair. These two
+views, which are produced using random data augmentations, are fed through an encoder to
+obtain the positive embedding pair; we thus end up using two augmentations per
+image. NNCLR instead
+keeps a _support set_ of embeddings representing the full data distribution, and forms
+the positive pairs using nearest-neighbours. A support set is used as memory during
+training, similar to a queue (i.e. first-in-first-out) as in
+[MoCo](https://arxiv.org/abs/1911.05722).
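+
+Below is a minimal NumPy sketch of these two ideas, separate from the model code
+that follows (array names and sizes are illustrative only): picking a
+nearest-neighbour positive from the support set, and updating that support set
+as a first-in-first-out queue.
+
+```python
+import numpy as np
+
+queue_len, dim, batch = 8, 4, 2  # illustrative sizes
+feature_queue = np.random.randn(queue_len, dim)  # the support set (FIFO queue)
+projections = np.random.randn(batch, dim)  # embeddings of the current batch
+
+# Nearest-neighbour positives: the most similar queue entry per projection.
+similarities = projections @ feature_queue.T  # (batch, queue_len)
+nn_positives = feature_queue[similarities.argmax(axis=1)]  # (batch, dim)
+
+# FIFO update: newest projections in, oldest queue entries out.
+feature_queue = np.concatenate([projections, feature_queue[:-batch]], axis=0)
+```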
+
+This example requires `tensorflow_datasets`, which can
+be installed with this command:
+"""
+
+"""shell
+pip install tensorflow-datasets
+"""
+
+"""
+## Setup
+"""
+
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import keras_cv
+from keras import ops
+from keras import layers
+
+"""
+## Hyperparameters
+
+A greater `queue_size` most likely means better performance as shown in the original
+paper, but introduces significant computational overhead. The authors show that the best
+results of NNCLR are achieved with a queue size of 98,304 (the largest `queue_size` they
+experimented with). Here we use 10,000 to keep the example lightweight.
+"""
+
+AUTOTUNE = tf.data.AUTOTUNE
+shuffle_buffer = 5000
+# The below two values are taken from https://www.tensorflow.org/datasets/catalog/stl10
+labelled_train_images = 5000
+unlabelled_images = 100000
+
+temperature = 0.1
+queue_size = 10000
+contrastive_augmenter = {
+ "brightness": 0.5,
+ "name": "contrastive_augmenter",
+ "scale": (0.2, 1.0),
+}
+classification_augmenter = {
+ "brightness": 0.2,
+ "name": "classification_augmenter",
+ "scale": (0.5, 1.0),
+}
+input_shape = (96, 96, 3)
+width = 128
+num_epochs = 5 # Use 25 for better results
+steps_per_epoch = 50 # Use 200 for better results
+
+"""
+## Load the Dataset
+
+We load the [STL-10](http://ai.stanford.edu/~acoates/stl10/) dataset from
+TensorFlow Datasets, an image recognition dataset for developing unsupervised
+feature learning, deep learning, self-taught learning algorithms. It is inspired by the
+CIFAR-10 dataset, with some modifications.
+"""
+
+dataset_name = "stl10"
+
+
+def prepare_dataset():
+ unlabeled_batch_size = unlabelled_images // steps_per_epoch
+ labeled_batch_size = labelled_train_images // steps_per_epoch
+ batch_size = unlabeled_batch_size + labeled_batch_size
+
+ unlabeled_train_dataset = (
+ tfds.load(
+ dataset_name, split="unlabelled", as_supervised=True, shuffle_files=True
+ )
+ .shuffle(buffer_size=shuffle_buffer)
+ .batch(unlabeled_batch_size, drop_remainder=True)
+ )
+ labeled_train_dataset = (
+ tfds.load(dataset_name, split="train", as_supervised=True, shuffle_files=True)
+ .shuffle(buffer_size=shuffle_buffer)
+ .batch(labeled_batch_size, drop_remainder=True)
+ )
+ test_dataset = (
+ tfds.load(dataset_name, split="test", as_supervised=True)
+ .batch(batch_size)
+ .prefetch(buffer_size=AUTOTUNE)
+ )
+ train_dataset = tf.data.Dataset.zip(
+ (unlabeled_train_dataset, labeled_train_dataset)
+ ).prefetch(buffer_size=AUTOTUNE)
+
+ return batch_size, train_dataset, labeled_train_dataset, test_dataset
+
+
+batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
+
+"""
+## Augmentations
+
+Other self-supervised techniques like [SimCLR](https://arxiv.org/abs/2002.05709),
+[BYOL](https://arxiv.org/abs/2006.07733), [SwAV](https://arxiv.org/abs/2006.09882) etc.
+rely heavily on a well-designed data augmentation pipeline to get the best performance.
+However, NNCLR is _less_ dependent on complex augmentations as nearest-neighbors already
+provide richness in sample variations. A few common techniques often included in
+augmentation pipelines are:
+
+- Random resized crops
+- Multiple color distortions
+- Gaussian blur
+
+Since NNCLR is less dependent on complex augmentations, we will only use random
+crops and random brightness for augmenting the input images.
+"""
+
+
+"""
+### Prepare augmentation module
+"""
+
+
+def augmenter(brightness, name, scale):
+ return keras.Sequential(
+ [
+ layers.Input(shape=input_shape),
+ layers.Rescaling(1 / 255),
+ layers.RandomFlip("horizontal"),
+ keras_cv.layers.RandomCropAndResize(
+ target_size=(input_shape[0], input_shape[1]),
+ crop_area_factor=scale,
+ aspect_ratio_factor=(3 / 4, 4 / 3),
+ ),
+ keras_cv.layers.RandomBrightness(factor=brightness, value_range=(0.0, 1.0)),
+ ],
+ name=name,
+ )
+
+
+"""
+### Encoder architecture
+
+Using a ResNet-50 as the encoder architecture
+is standard in the literature. In the original paper, the authors use a ResNet-50 and
+spatially average its outputs. However, keep in
+mind that more powerful models will not only increase training time but will also
+require more memory and will limit the maximal batch size you can use. For the purpose of
+this example, we just use four convolutional layers.
+"""
+
+
+def encoder():
+ return keras.Sequential(
+ [
+ layers.Input(shape=input_shape),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Flatten(),
+ layers.Dense(width, activation="relu"),
+ ],
+ name="encoder",
+ )
+
+
+"""
+## The NNCLR model for contrastive pre-training
+
+We train an encoder on unlabeled images with a contrastive loss. A nonlinear projection
+head is attached to the top of the encoder, as it improves the quality of representations
+of the encoder.
+"""
+
+
+class NNCLR(keras.Model):
+ def __init__(
+ self,
+ temperature,
+ queue_size,
+ ):
+ super().__init__()
+ self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
+ self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
+ self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
+ self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+ self.contrastive_augmenter = augmenter(**contrastive_augmenter)
+ self.classification_augmenter = augmenter(**classification_augmenter)
+ self.encoder = encoder()
+ self.projection_head = keras.Sequential(
+ [
+ layers.Input(shape=(width,)),
+ layers.Dense(width, activation="relu"),
+ layers.Dense(width),
+ ],
+ name="projection_head",
+ )
+ self.linear_probe = keras.Sequential(
+ [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
+ )
+ self.temperature = temperature
+
+ feature_dimensions = self.encoder.output_shape[1]
+ self.feature_queue = keras.Variable(
+ keras.utils.normalize(
+ keras.random.normal(shape=(queue_size, feature_dimensions)),
+ axis=1,
+ order=2,
+ ),
+ trainable=False,
+ )
+
+ def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
+ super().compile(**kwargs)
+ self.contrastive_optimizer = contrastive_optimizer
+ self.probe_optimizer = probe_optimizer
+
+ def nearest_neighbour(self, projections):
+ support_similarities = ops.matmul(
+ projections, ops.transpose(self.feature_queue)
+ )
+ nn_projections = ops.take(
+ self.feature_queue, ops.argmax(support_similarities, axis=1), axis=0
+ )
+ return projections + ops.stop_gradient(nn_projections - projections)
+
+ def update_contrastive_accuracy(self, features_1, features_2):
+ features_1 = keras.utils.normalize(features_1, axis=1, order=2)
+ features_2 = keras.utils.normalize(features_2, axis=1, order=2)
+ similarities = ops.matmul(features_1, ops.transpose(features_2))
+ batch_size = ops.shape(features_1)[0]
+ contrastive_labels = ops.arange(batch_size)
+ self.contrastive_accuracy.update_state(
+ ops.concatenate([contrastive_labels, contrastive_labels], axis=0),
+ ops.concatenate([similarities, ops.transpose(similarities)], axis=0),
+ )
+
+ def update_correlation_accuracy(self, features_1, features_2):
+ features_1 = (features_1 - ops.mean(features_1, axis=0)) / ops.std(
+ features_1, axis=0
+ )
+ features_2 = (features_2 - ops.mean(features_2, axis=0)) / ops.std(
+ features_2, axis=0
+ )
+
+ batch_size = ops.shape(features_1)[0]
+ cross_correlation = (
+ ops.matmul(ops.transpose(features_1), features_2) / batch_size
+ )
+
+ feature_dim = ops.shape(features_1)[1]
+ correlation_labels = ops.arange(feature_dim)
+ self.correlation_accuracy.update_state(
+ ops.concatenate([correlation_labels, correlation_labels], axis=0),
+ ops.concatenate(
+ [cross_correlation, ops.transpose(cross_correlation)], axis=0
+ ),
+ )
+
+ def contrastive_loss(self, projections_1, projections_2):
+ projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
+ projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)
+
+ similarities_1_2_1 = (
+ ops.matmul(
+ self.nearest_neighbour(projections_1), ops.transpose(projections_2)
+ )
+ / self.temperature
+ )
+ similarities_1_2_2 = (
+ ops.matmul(
+ projections_2, ops.transpose(self.nearest_neighbour(projections_1))
+ )
+ / self.temperature
+ )
+
+        similarities_2_1_1 = (
+ ops.matmul(
+ self.nearest_neighbour(projections_2), ops.transpose(projections_1)
+ )
+ / self.temperature
+ )
+ similarities_2_1_2 = (
+ ops.matmul(
+ projections_1, ops.transpose(self.nearest_neighbour(projections_2))
+ )
+ / self.temperature
+ )
+
+ batch_size = ops.shape(projections_1)[0]
+ contrastive_labels = ops.arange(batch_size)
+ loss = keras.losses.sparse_categorical_crossentropy(
+ ops.concatenate(
+ [
+ contrastive_labels,
+ contrastive_labels,
+ contrastive_labels,
+ contrastive_labels,
+ ],
+ axis=0,
+ ),
+ ops.concatenate(
+ [
+ similarities_1_2_1,
+ similarities_1_2_2,
+ similarities_2_1_1,
+ similarities_2_1_2,
+ ],
+ axis=0,
+ ),
+ from_logits=True,
+ )
+
+ self.feature_queue.assign(
+ ops.concatenate([projections_1, self.feature_queue[:-batch_size]], axis=0)
+ )
+ return loss
+
+ def train_step(self, data):
+ (unlabeled_images, _), (labeled_images, labels) = data
+ images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
+ augmented_images_1 = self.contrastive_augmenter(images)
+ augmented_images_2 = self.contrastive_augmenter(images)
+
+ with tf.GradientTape() as tape:
+ features_1 = self.encoder(augmented_images_1)
+ features_2 = self.encoder(augmented_images_2)
+ projections_1 = self.projection_head(features_1)
+ projections_2 = self.projection_head(features_2)
+ contrastive_loss = self.contrastive_loss(projections_1, projections_2)
+ gradients = tape.gradient(
+ contrastive_loss,
+ self.encoder.trainable_weights + self.projection_head.trainable_weights,
+ )
+ self.contrastive_optimizer.apply_gradients(
+ zip(
+ gradients,
+ self.encoder.trainable_weights + self.projection_head.trainable_weights,
+ )
+ )
+ self.update_contrastive_accuracy(features_1, features_2)
+ self.update_correlation_accuracy(features_1, features_2)
+ preprocessed_images = self.classification_augmenter(labeled_images)
+
+ with tf.GradientTape() as tape:
+ features = self.encoder(preprocessed_images)
+ class_logits = self.linear_probe(features)
+ probe_loss = self.probe_loss(labels, class_logits)
+ gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
+ self.probe_optimizer.apply_gradients(
+ zip(gradients, self.linear_probe.trainable_weights)
+ )
+ self.probe_accuracy.update_state(labels, class_logits)
+
+ return {
+ "c_loss": contrastive_loss,
+ "c_acc": self.contrastive_accuracy.result(),
+ "r_acc": self.correlation_accuracy.result(),
+ "p_loss": probe_loss,
+ "p_acc": self.probe_accuracy.result(),
+ }
+
+ def test_step(self, data):
+ labeled_images, labels = data
+
+ preprocessed_images = self.classification_augmenter(
+ labeled_images, training=False
+ )
+ features = self.encoder(preprocessed_images, training=False)
+ class_logits = self.linear_probe(features, training=False)
+ probe_loss = self.probe_loss(labels, class_logits)
+
+ self.probe_accuracy.update_state(labels, class_logits)
+ return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}
+
+
+"""
+## Pre-train NNCLR
+
+We train the network using a `temperature` of 0.1 as suggested in the paper and
+a `queue_size` of 10,000 as explained earlier. We use Adam as our contrastive and probe
+optimizer. For this example we train the model for only 5 epochs (see `num_epochs`
+above), but it should be trained for more epochs for better performance.
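+
+Concretely, the contrastive objective is the InfoNCE loss from the paper (the
+code below symmetrizes it over both views): with `z_i` and `z'_i` the two
+projected views of image `i`, `nn(.)` the nearest-neighbour lookup in the
+support set, and `sim` the cosine similarity,
+
+`loss_i = -log( exp(sim(nn(z_i), z'_i) / temperature) / sum_j exp(sim(nn(z_i), z'_j) / temperature) )`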
+
+The following two metrics can be used for monitoring the pretraining performance
+which we also log (taken from
+[this Keras example](https://keras.io/examples/vision/semisupervised_simclr/#selfsupervised-model-for-contrastive-pretraining)):
+
+- Contrastive accuracy: a self-supervised metric, the ratio of cases in which the
+representation of an image is more similar to the representation of its differently
+augmented version than to the representation of any other image in the current batch
+(a toy illustration follows the list). Self-supervised metrics can be used for
+hyperparameter tuning even when there are no labeled examples.
+- Linear probing accuracy: linear probing is a popular metric for evaluating self-supervised
+classifiers. It is computed as the accuracy of a logistic regression classifier trained
+on top of the encoder's features. In our case, this is done by training a single dense
+layer on top of the frozen encoder. Note that contrary to the traditional approach, where the
+classifier is trained after the pretraining phase, in this example we train it during
+pretraining. This might slightly decrease its accuracy, but that way we can monitor its
+value during training, which helps with experimentation and debugging.
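+
+To make the contrastive accuracy concrete, here is a toy NumPy illustration
+(the similarity values are made up):
+
+```python
+import numpy as np
+
+# Toy similarity matrix between 3 images: rows are view 1, columns view 2.
+similarities = np.array(
+    [[0.9, 0.1, 0.2],
+     [0.3, 0.2, 0.8],
+     [0.1, 0.0, 0.7]]
+)
+# An image counts as "correct" when it is most similar to its own other
+# view, i.e. when the row-wise argmax lands on the diagonal.
+correct = similarities.argmax(axis=1) == np.arange(3)
+print(correct.mean())  # 2/3 here: rows 0 and 2 match, row 1 does not.
+```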
+"""
+
+model = NNCLR(temperature=temperature, queue_size=queue_size)
+model.compile(
+ contrastive_optimizer=keras.optimizers.Adam(),
+ probe_optimizer=keras.optimizers.Adam(),
+ jit_compile=False,
+)
+pretrain_history = model.fit(
+ train_dataset, epochs=num_epochs, validation_data=test_dataset
+)
+
+"""
+## Evaluate our model
+
+A popular way to evaluate an SSL method in computer vision (or, for that matter, any
+other pre-training method) is to learn a linear classifier on the frozen features of the
+trained backbone model and evaluate the classifier on unseen images. Other methods often
+include fine-tuning on the source dataset, or even on a target dataset with only 5% or
+10% of the labels present. You can use the backbone we just trained for any downstream
+task such as image classification (as we do here), segmentation, or detection, where
+backbone models are usually pre-trained with supervised learning.
+"""
+
+finetuning_model = keras.Sequential(
+ [
+ layers.Input(shape=input_shape),
+ augmenter(**classification_augmenter),
+ model.encoder,
+ layers.Dense(10),
+ ],
+ name="finetuning_model",
+)
+finetuning_model.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
+ jit_compile=False,
+)
+
+finetuning_history = finetuning_model.fit(
+ labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
+)
+
+"""
+Self-supervised learning is particularly helpful when you only have access to very
+limited labeled training data but can manage to build a large corpus of unlabeled
+data, as shown by previous methods like [SEER](https://arxiv.org/abs/2103.01988),
+[SimCLR](https://arxiv.org/abs/2002.05709), [SwAV](https://arxiv.org/abs/2006.09882) and
+more.
+
+You should also take a look at the blog posts for these papers which neatly show that it is
+possible to achieve good results with few class labels by first pretraining on a large
+unlabeled dataset and then fine-tuning on a smaller labeled dataset:
+
+- [Advancing Self-Supervised and Semi-Supervised Learning with SimCLR](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html)
+- [High-performance self-supervised image classification with contrastive clustering](https://ai.facebook.com/blog/high-performance-self-supervised-image-classification-with-contrastive-clustering/)
+- [Self-supervised learning: The dark matter of intelligence](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/)
+
+You are also advised to check out the [original paper](https://arxiv.org/abs/2104.14548).
+
+*Many thanks to [Debidatta Dwibedi](https://twitter.com/debidatta) (Google Research),
+primary author of the NNCLR paper for his super-insightful reviews for this example.
+This example also takes inspiration from the [SimCLR Keras Example](https://keras.io/examples/vision/semisupervised_simclr/).*
+"""
diff --git a/knowledge_base/vision/object_detection_using_vision_transformer.py b/knowledge_base/vision/object_detection_using_vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb93f23a4673f394f87d7e4cc4310584caf78214
--- /dev/null
+++ b/knowledge_base/vision/object_detection_using_vision_transformer.py
@@ -0,0 +1,521 @@
+"""
+Title: Object detection with Vision Transformers
+Author: [Karan V. Dave](https://www.linkedin.com/in/karan-dave-811413164/)
+Date created: 2022/03/27
+Last modified: 2023/11/20
+Description: A simple Keras implementation of object detection using Vision Transformers.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+The article
+[Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
+by Alexey Dosovitskiy et al.
+demonstrates that a pure transformer applied directly to sequences of image
+patches can perform well on image classification tasks. Here, we extend that
+idea to object detection, framing it as a bounding-box regression problem.
+
+In this Keras example, we implement an object detection ViT
+and we train it on the
+[Caltech 101 dataset](http://www.vision.caltech.edu/datasets/)
+to detect an airplane in the given image.
+"""
+
+"""
+## Imports and setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax" # @param ["tensorflow", "jax", "torch"]
+
+
+import numpy as np
+import keras
+from keras import layers
+from keras import ops
+import matplotlib.pyplot as plt
+import cv2
+import scipy.io
+import shutil
+
+"""
+## Prepare dataset
+
+We use the [Caltech 101 Dataset](https://data.caltech.edu/records/mzrjq-6wc02).
+"""
+
+# Path to images and annotations
+path_images = "./101_ObjectCategories/airplanes/"
+path_annot = "./Annotations/Airplanes_Side_2/"
+
+path_to_downloaded_file = keras.utils.get_file(
+ fname="caltech_101_zipped",
+ origin="https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip",
+ extract=True,
+ archive_format="zip", # downloaded file format
+ cache_dir="/", # cache and extract in current directory
+)
+download_base_dir = os.path.dirname(path_to_downloaded_file)
+
+# Extracting tar files found inside main zip file
+shutil.unpack_archive(
+ os.path.join(download_base_dir, "caltech-101", "101_ObjectCategories.tar.gz"), "."
+)
+shutil.unpack_archive(
+ os.path.join(download_base_dir, "caltech-101", "Annotations.tar"), "."
+)
+
+# list of paths to images and annotations
+image_paths = [
+ f for f in os.listdir(path_images) if os.path.isfile(os.path.join(path_images, f))
+]
+annot_paths = [
+ f for f in os.listdir(path_annot) if os.path.isfile(os.path.join(path_annot, f))
+]
+
+image_paths.sort()
+annot_paths.sort()
+
+image_size = 224 # resize input images to this size
+
+images, targets = [], []
+
+# loop over the annotations and images, preprocess them and store in lists
+for i in range(len(annot_paths)):
+ # Access bounding box coordinates
+ annot = scipy.io.loadmat(path_annot + annot_paths[i])["box_coord"][0]
+
+ top_left_x, top_left_y = annot[2], annot[0]
+ bottom_right_x, bottom_right_y = annot[3], annot[1]
+
+ image = keras.utils.load_img(
+ path_images + image_paths[i],
+ )
+ (w, h) = image.size[:2]
+
+ # resize images
+ image = image.resize((image_size, image_size))
+
+ # convert image to array and append to list
+ images.append(keras.utils.img_to_array(image))
+
+    # scale the bounding boxes relative to the image size and append to the list
+ targets.append(
+ (
+ float(top_left_x) / w,
+ float(top_left_y) / h,
+ float(bottom_right_x) / w,
+ float(bottom_right_y) / h,
+ )
+ )
+
+# Convert the list to numpy array, split to train and test dataset
+x_train, y_train = (
+    np.asarray(images[: int(len(images) * 0.8)]),
+    np.asarray(targets[: int(len(targets) * 0.8)]),
+)
+x_test, y_test = (
+    np.asarray(images[int(len(images) * 0.8) :]),
+    np.asarray(targets[int(len(targets) * 0.8) :]),
+)
+
+"""
+## Implement multilayer-perceptron (MLP)
+
+We use the code from the Keras example
+[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/)
+as a reference.
+"""
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation=keras.activations.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## Implement the patch creation layer
+"""
+
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+
+ def call(self, images):
+ input_shape = ops.shape(images)
+ batch_size = input_shape[0]
+ height = input_shape[1]
+ width = input_shape[2]
+ channels = input_shape[3]
+ num_patches_h = height // self.patch_size
+ num_patches_w = width // self.patch_size
+ patches = keras.ops.image.extract_patches(images, size=self.patch_size)
+ patches = ops.reshape(
+ patches,
+ (
+ batch_size,
+ num_patches_h * num_patches_w,
+ self.patch_size * self.patch_size * channels,
+ ),
+ )
+ return patches
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"patch_size": self.patch_size})
+ return config
+
+
+"""
+## Display patches for an input image
+"""
+
+patch_size = 32 # Size of the patches to be extracted from the input images
+
+plt.figure(figsize=(4, 4))
+plt.imshow(x_train[0].astype("uint8"))
+plt.axis("off")
+
+patches = Patches(patch_size)(np.expand_dims(x_train[0], axis=0))
+print(f"Image size: {image_size} X {image_size}")
+print(f"Patch size: {patch_size} X {patch_size}")
+print(f"{patches.shape[1]} patches per image \n{patches.shape[-1]} elements per patch")
+
+
+n = int(np.sqrt(patches.shape[1]))
+plt.figure(figsize=(4, 4))
+for i, patch in enumerate(patches[0]):
+ ax = plt.subplot(n, n, i + 1)
+ patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
+ plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
+ plt.axis("off")
+
+"""
+## Implement the patch encoding layer
+
+The `PatchEncoder` layer linearly transforms a patch by projecting it into a
+vector of size `projection_dim`. It also adds a learnable position
+embedding to the projected vector.
+"""
+
+
+class PatchEncoder(layers.Layer):
+    def __init__(self, num_patches, projection_dim):
+        super().__init__()
+        self.num_patches = num_patches
+        self.projection_dim = projection_dim
+        self.projection = layers.Dense(units=projection_dim)
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )
+
+    # Override get_config so the layer is serializable. Serialize only the
+    # arguments this layer actually takes; the original version referenced
+    # unrelated global variables, which breaks reloading via `from_config`.
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_patches": self.num_patches,
+                "projection_dim": self.projection_dim,
+            }
+        )
+        return config
+
+ def call(self, patch):
+ positions = ops.expand_dims(
+ ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+ )
+ projected_patches = self.projection(patch)
+ encoded = projected_patches + self.position_embedding(positions)
+ return encoded
+
+
+"""
+## Build the ViT model
+
+The ViT model has multiple Transformer blocks.
+The `MultiHeadAttention` layer is used for self-attention,
+applied to the sequence of image patches. The encoded patches (skip connection)
+and self-attention layer outputs are normalized and fed into a
+multilayer perceptron (MLP).
+The model outputs four dimensions representing
+the bounding box coordinates of an object.
+"""
+
+
+def create_vit_object_detector(
+ input_shape,
+ patch_size,
+ num_patches,
+ projection_dim,
+ num_heads,
+ transformer_units,
+ transformer_layers,
+ mlp_head_units,
+):
+ inputs = keras.Input(shape=input_shape)
+ # Create patches
+ patches = Patches(patch_size)(inputs)
+ # Encode patches
+ encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
+
+ # Create multiple layers of the Transformer block.
+ for _ in range(transformer_layers):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ # Create a multi-head attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # MLP
+ x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
+ # Skip connection 2.
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Create a [batch_size, projection_dim] tensor.
+ representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ representation = layers.Flatten()(representation)
+ representation = layers.Dropout(0.3)(representation)
+ # Add MLP.
+ features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.3)
+
+ bounding_box = layers.Dense(4)(
+ features
+ ) # Final four neurons that output bounding box
+
+ # return Keras model.
+ return keras.Model(inputs=inputs, outputs=bounding_box)
+
+
+"""
+## Run the experiment
+"""
+
+
+def run_experiment(model, learning_rate, weight_decay, batch_size, num_epochs):
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ )
+
+ # Compile model.
+ model.compile(optimizer=optimizer, loss=keras.losses.MeanSquaredError())
+
+ checkpoint_filepath = "vit_object_detector.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_loss",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[
+ checkpoint_callback,
+ keras.callbacks.EarlyStopping(monitor="val_loss", patience=10),
+ ],
+ )
+
+ return history
+
+
+input_shape = (image_size, image_size, 3) # input image shape
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 32
+num_epochs = 100
+num_patches = (image_size // patch_size) ** 2
+projection_dim = 64
+num_heads = 4
+# Size of the transformer layers
+transformer_units = [
+ projection_dim * 2,
+ projection_dim,
+]
+transformer_layers = 4
+mlp_head_units = [2048, 1024, 512, 64, 32] # Size of the dense layers
+
+
+
+vit_object_detector = create_vit_object_detector(
+ input_shape,
+ patch_size,
+ num_patches,
+ projection_dim,
+ num_heads,
+ transformer_units,
+ transformer_layers,
+ mlp_head_units,
+)
+
+# Train model
+history = run_experiment(
+ vit_object_detector, learning_rate, weight_decay, batch_size, num_epochs
+)
+
+
+def plot_history(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_history("loss")
+
+
+"""
+## Evaluate the model
+"""
+
+import matplotlib.patches as patches
+
+# Saves the model in current path
+vit_object_detector.save("vit_object_detector.keras")
+
+
+# To calculate IoU (intersection over union, given two bounding boxes)
+def bounding_box_intersection_over_union(box_predicted, box_truth):
+ # get (x, y) coordinates of intersection of bounding boxes
+ top_x_intersect = max(box_predicted[0], box_truth[0])
+ top_y_intersect = max(box_predicted[1], box_truth[1])
+ bottom_x_intersect = min(box_predicted[2], box_truth[2])
+ bottom_y_intersect = min(box_predicted[3], box_truth[3])
+
+ # calculate area of the intersection bb (bounding box)
+ intersection_area = max(0, bottom_x_intersect - top_x_intersect + 1) * max(
+ 0, bottom_y_intersect - top_y_intersect + 1
+ )
+
+ # calculate area of the prediction bb and ground-truth bb
+ box_predicted_area = (box_predicted[2] - box_predicted[0] + 1) * (
+ box_predicted[3] - box_predicted[1] + 1
+ )
+ box_truth_area = (box_truth[2] - box_truth[0] + 1) * (
+ box_truth[3] - box_truth[1] + 1
+ )
+
+    # calculate intersection over union by dividing the intersection
+    # area by the sum of the predicted bb and ground truth bb areas,
+    # minus the intersection area
+
+    # return IoU
+ return intersection_area / float(
+ box_predicted_area + box_truth_area - intersection_area
+ )
+
+
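+# Quick sanity check with two made-up boxes in pixel coordinates: the
+# intersection is 21 x 21 = 441, each box is 41 x 41 = 1681, so the IoU
+# should be 441 / (1681 + 1681 - 441), roughly 0.15.
+print(bounding_box_intersection_over_union((10, 10, 50, 50), (30, 30, 70, 70)))
+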
+i, mean_iou = 0, 0
+
+# Compare results for 10 images in the test set
+for input_image in x_test[:10]:
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 15))
+ im = input_image
+
+ # Display the image
+ ax1.imshow(im.astype("uint8"))
+ ax2.imshow(im.astype("uint8"))
+
+ input_image = cv2.resize(
+ input_image, (image_size, image_size), interpolation=cv2.INTER_AREA
+ )
+ input_image = np.expand_dims(input_image, axis=0)
+ preds = vit_object_detector.predict(input_image)[0]
+
+ (h, w) = (im).shape[0:2]
+
+ top_left_x, top_left_y = int(preds[0] * w), int(preds[1] * h)
+
+ bottom_right_x, bottom_right_y = int(preds[2] * w), int(preds[3] * h)
+
+ box_predicted = [top_left_x, top_left_y, bottom_right_x, bottom_right_y]
+ # Create the bounding box
+ rect = patches.Rectangle(
+ (top_left_x, top_left_y),
+ bottom_right_x - top_left_x,
+ bottom_right_y - top_left_y,
+ facecolor="none",
+ edgecolor="red",
+ linewidth=1,
+ )
+ # Add the bounding box to the image
+ ax1.add_patch(rect)
+ ax1.set_xlabel(
+ "Predicted: "
+ + str(top_left_x)
+ + ", "
+ + str(top_left_y)
+ + ", "
+ + str(bottom_right_x)
+ + ", "
+ + str(bottom_right_y)
+ )
+
+ top_left_x, top_left_y = int(y_test[i][0] * w), int(y_test[i][1] * h)
+
+ bottom_right_x, bottom_right_y = int(y_test[i][2] * w), int(y_test[i][3] * h)
+
+ box_truth = top_left_x, top_left_y, bottom_right_x, bottom_right_y
+
+ mean_iou += bounding_box_intersection_over_union(box_predicted, box_truth)
+ # Create the bounding box
+ rect = patches.Rectangle(
+ (top_left_x, top_left_y),
+ bottom_right_x - top_left_x,
+ bottom_right_y - top_left_y,
+ facecolor="none",
+ edgecolor="red",
+ linewidth=1,
+ )
+ # Add the bounding box to the image
+ ax2.add_patch(rect)
+ ax2.set_xlabel(
+ "Target: "
+ + str(top_left_x)
+ + ", "
+ + str(top_left_y)
+ + ", "
+ + str(bottom_right_x)
+ + ", "
+ + str(bottom_right_y)
+ + "\n"
+        + "IoU: "
+ + str(bounding_box_intersection_over_union(box_predicted, box_truth))
+ )
+ i = i + 1
+
+print("mean_iou: " + str(mean_iou / len(x_test[:10])))
+plt.show()
+
+"""
+This example demonstrates that a pure Transformer can be trained
+to predict the bounding boxes of an object in a given image,
+thus extending the use of Transformers to object detection tasks.
+The model can be improved further by tuning hyper-parameters and pre-training.
+"""
diff --git a/knowledge_base/vision/oxford_pets_image_segmentation.py b/knowledge_base/vision/oxford_pets_image_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f624a18d86754b524bab14fa348719750aaf097b
--- /dev/null
+++ b/knowledge_base/vision/oxford_pets_image_segmentation.py
@@ -0,0 +1,270 @@
+"""
+Title: Image segmentation with a U-Net-like architecture
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2019/03/20
+Last modified: 2020/04/20
+Description: Image segmentation model trained from scratch on the Oxford Pets dataset.
+Accelerator: GPU
+"""
+
+"""
+## Download the data
+"""
+
+"""shell
+curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
+curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz
+
+tar -xf images.tar.gz
+tar -xf annotations.tar.gz
+"""
+
+"""
+## Prepare paths of input images and target segmentation masks
+"""
+
+import os
+
+input_dir = "images/"
+target_dir = "annotations/trimaps/"
+img_size = (160, 160)
+num_classes = 3
+batch_size = 32
+
+input_img_paths = sorted(
+ [
+ os.path.join(input_dir, fname)
+ for fname in os.listdir(input_dir)
+ if fname.endswith(".jpg")
+ ]
+)
+target_img_paths = sorted(
+ [
+ os.path.join(target_dir, fname)
+ for fname in os.listdir(target_dir)
+ if fname.endswith(".png") and not fname.startswith(".")
+ ]
+)
+
+print("Number of samples:", len(input_img_paths))
+
+for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
+ print(input_path, "|", target_path)
+
+"""
+## What does one input image and corresponding segmentation mask look like?
+"""
+
+from IPython.display import Image, display
+from keras.utils import load_img
+from PIL import ImageOps
+
+# Display input image #9
+display(Image(filename=input_img_paths[9]))
+
+# Display auto-contrast version of corresponding target (per-pixel categories)
+img = ImageOps.autocontrast(load_img(target_img_paths[9]))
+display(img)
+
+"""
+## Prepare dataset to load & vectorize batches of data
+"""
+
+import keras
+import numpy as np
+from tensorflow import data as tf_data
+from tensorflow import image as tf_image
+from tensorflow import io as tf_io
+
+
+def get_dataset(
+ batch_size,
+ img_size,
+ input_img_paths,
+ target_img_paths,
+ max_dataset_len=None,
+):
+ """Returns a TF Dataset."""
+
+ def load_img_masks(input_img_path, target_img_path):
+ input_img = tf_io.read_file(input_img_path)
+ input_img = tf_io.decode_png(input_img, channels=3)
+ input_img = tf_image.resize(input_img, img_size)
+ input_img = tf_image.convert_image_dtype(input_img, "float32")
+
+ target_img = tf_io.read_file(target_img_path)
+ target_img = tf_io.decode_png(target_img, channels=1)
+ target_img = tf_image.resize(target_img, img_size, method="nearest")
+ target_img = tf_image.convert_image_dtype(target_img, "uint8")
+
+ # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
+ target_img -= 1
+ return input_img, target_img
+
+ # For faster debugging, limit the size of data
+ if max_dataset_len:
+ input_img_paths = input_img_paths[:max_dataset_len]
+ target_img_paths = target_img_paths[:max_dataset_len]
+ dataset = tf_data.Dataset.from_tensor_slices((input_img_paths, target_img_paths))
+ dataset = dataset.map(load_img_masks, num_parallel_calls=tf_data.AUTOTUNE)
+ return dataset.batch(batch_size)
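+
+
+# Optional sanity check (sizes are illustrative): build a tiny dataset from
+# the first few paths loaded above and confirm the batch shapes.
+sample_dataset = get_dataset(
+    batch_size, img_size, input_img_paths[:8], target_img_paths[:8]
+)
+for images, masks in sample_dataset.take(1):
+    print(images.shape, masks.shape)  # (8, 160, 160, 3) (8, 160, 160, 1)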
+
+
+"""
+## Prepare U-Net Xception-style model
+"""
+
+from keras import layers
+
+
+def get_model(img_size, num_classes):
+ inputs = keras.Input(shape=img_size + (3,))
+
+ ### [First half of the network: downsampling inputs] ###
+
+ # Entry block
+ x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
+ x = layers.BatchNormalization()(x)
+ x = layers.Activation("relu")(x)
+
+ previous_block_activation = x # Set aside residual
+
+ # Blocks 1, 2, 3 are identical apart from the feature depth.
+ for filters in [64, 128, 256]:
+ x = layers.Activation("relu")(x)
+ x = layers.SeparableConv2D(filters, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Activation("relu")(x)
+ x = layers.SeparableConv2D(filters, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
+
+ # Project residual
+ residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
+ previous_block_activation
+ )
+ x = layers.add([x, residual]) # Add back residual
+ previous_block_activation = x # Set aside next residual
+
+ ### [Second half of the network: upsampling inputs] ###
+
+ for filters in [256, 128, 64, 32]:
+ x = layers.Activation("relu")(x)
+ x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.Activation("relu")(x)
+ x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+
+ x = layers.UpSampling2D(2)(x)
+
+ # Project residual
+ residual = layers.UpSampling2D(2)(previous_block_activation)
+ residual = layers.Conv2D(filters, 1, padding="same")(residual)
+ x = layers.add([x, residual]) # Add back residual
+ previous_block_activation = x # Set aside next residual
+
+ # Add a per-pixel classification layer
+ outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)
+
+ # Define the model
+ model = keras.Model(inputs, outputs)
+ return model
+
+
+# Build model
+model = get_model(img_size, num_classes)
+model.summary()
+
+"""
+## Set aside a validation split
+"""
+
+import random
+
+# Split our img paths into a training and a validation set
+val_samples = 1000
+random.Random(1337).shuffle(input_img_paths)
+random.Random(1337).shuffle(target_img_paths)
+train_input_img_paths = input_img_paths[:-val_samples]
+train_target_img_paths = target_img_paths[:-val_samples]
+val_input_img_paths = input_img_paths[-val_samples:]
+val_target_img_paths = target_img_paths[-val_samples:]
+
+# Instantiate dataset for each split
+# Limit the number of input files with `max_dataset_len` for faster epoch training time.
+# Remove the `max_dataset_len` arg when running with the full dataset.
+train_dataset = get_dataset(
+ batch_size,
+ img_size,
+ train_input_img_paths,
+ train_target_img_paths,
+ max_dataset_len=1000,
+)
+valid_dataset = get_dataset(
+ batch_size, img_size, val_input_img_paths, val_target_img_paths
+)
+
+"""
+## Train the model
+"""
+
+# Configure the model for training.
+# We use the "sparse" version of categorical_crossentropy
+# because our target data is integers.
+model.compile(
+ optimizer=keras.optimizers.Adam(1e-4), loss="sparse_categorical_crossentropy"
+)
+
+callbacks = [
+ keras.callbacks.ModelCheckpoint("oxford_segmentation.keras", save_best_only=True)
+]
+
+# Train the model, doing validation at the end of each epoch.
+epochs = 50
+model.fit(
+ train_dataset,
+ epochs=epochs,
+ validation_data=valid_dataset,
+ callbacks=callbacks,
+ verbose=2,
+)
+
+"""
+## Visualize predictions
+"""
+
+# Generate predictions for all images in the validation set
+
+val_dataset = get_dataset(
+ batch_size, img_size, val_input_img_paths, val_target_img_paths
+)
+val_preds = model.predict(val_dataset)
+
+
+def display_mask(i):
+ """Quick utility to display a model's prediction."""
+ mask = np.argmax(val_preds[i], axis=-1)
+ mask = np.expand_dims(mask, axis=-1)
+ img = ImageOps.autocontrast(keras.utils.array_to_img(mask))
+ display(img)
+
+
+# Display results for validation image #10
+i = 10
+
+# Display input image
+display(Image(filename=val_input_img_paths[i]))
+
+# Display ground-truth target mask
+img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
+display(img)
+
+# Display mask predicted by our model
+display_mask(i)  # Note that the model only sees inputs at 160x160.
diff --git a/knowledge_base/vision/patch_convnet.py b/knowledge_base/vision/patch_convnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e364bc4ce1f969980bb318772b63661637619e
--- /dev/null
+++ b/knowledge_base/vision/patch_convnet.py
@@ -0,0 +1,690 @@
+"""
+Title: Augmenting convnets with aggregated attention
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
+Date created: 2022/01/22
+Last modified: 2022/01/22
+Description: Building a patch-convnet architecture and visualizing its attention maps.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Vision transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929))
+have emerged as a powerful alternative to Convolutional Neural Networks.
+ViTs process the images in a patch-based manner. The image information
+is then aggregated into a `CLASS` token. This token correlates to the
+most important patches of the image for a particular classification decision.
+
+The interaction between the `CLASS` token and the patches can be visualized
+to help explain a classification decision. In the academic paper
+[Augmenting convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692)
+by Touvron et al., the authors propose to set up an equivalent visualization for
+convnets. They propose to substitute the global average pooling layer
+of a convnet with a Transformer layer. The self-attention layer of the
+Transformer would produce attention maps that correspond to the
+most attended patches of the image for the classification decision.
+
+In this example, we minimally implement the ideas of
+[Augmenting Convolutional networks with attention-based aggregation](https://arxiv.org/abs/2112.13692).
+The main goal of this example is to cover the following ideas, with
+minor modifications (to adapt the implementation to CIFAR-10):
+
+- The simple design for the attention-based pooling layer, such that
+ it explicitly provides the weights (importance) of the different
+ patches.
+- The novel convnet architecture, called **PatchConvNet**, which
+  deviates from the age-old pyramidal architecture.
+"""
+
+"""
+## Setup and Imports
+"""
+
+import math
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import keras
+from keras import layers
+from keras import ops
+from tensorflow import data as tf_data
+
+# Set seed for reproducibility
+SEED = 42
+keras.utils.set_random_seed(SEED)
+
+"""
+## Hyperparameters
+"""
+
+# DATA
+BATCH_SIZE = 128
+BUFFER_SIZE = BATCH_SIZE * 2
+AUTO = tf_data.AUTOTUNE
+INPUT_SHAPE = (32, 32, 3)
+NUM_CLASSES = 10 # for CIFAR 10
+
+# AUGMENTATION
+IMAGE_SIZE = 48 # We will resize input images to this size.
+
+# ARCHITECTURE
+DIMENSIONS = 256
+SE_RATIO = 8
+TRUNK_DEPTH = 2
+
+# OPTIMIZER
+LEARNING_RATE = 1e-3
+WEIGHT_DECAY = 1e-4
+
+# PRETRAINING
+EPOCHS = 50
+
+"""
+## Load the CIFAR10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)
+
+val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+"""
+## Augmentation layers
+"""
+
+
+def get_preprocessing():
+ model = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
+ ],
+ name="preprocessing",
+ )
+ return model
+
+
+def get_train_augmentation_model():
+ model = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
+ layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ ],
+ name="train_data_augmentation",
+ )
+ return model
+
+
+"""
+## Convolutional stem
+
+The stem of the model is a lightweight preprocessing module that
+maps image pixels to a set of vectors (patches).
+"""
+
+
+def build_convolutional_stem(dimensions):
+ """Build the convolutional stem.
+
+ Args:
+ dimensions: The embedding dimension of the patches (d in paper).
+
+    Returns:
+        The convolutional stem as a Keras sequential
+        model.
+ """
+ config = {
+ "kernel_size": (3, 3),
+ "strides": (2, 2),
+ "activation": ops.gelu,
+ "padding": "same",
+ }
+
+ convolutional_stem = keras.Sequential(
+ [
+ layers.Conv2D(filters=dimensions // 2, **config),
+ layers.Conv2D(filters=dimensions, **config),
+ ],
+ name="convolutional_stem",
+ )
+
+ return convolutional_stem
+
+
+"""
+## Convolutional trunk
+
+The trunk of the model is the most compute-intensive part. It consists
+of `N` stacked residual convolutional blocks.
+"""
+
+
+class SqueezeExcite(layers.Layer):
+ """Applies squeeze and excitation to input feature maps as seen in
+ https://arxiv.org/abs/1709.01507.
+
+ Args:
+ ratio: The ratio with which the feature map needs to be reduced in
+ the reduction phase.
+
+ Inputs:
+ Convolutional features.
+
+ Outputs:
+ Attention modified feature maps.
+ """
+
+ def __init__(self, ratio, **kwargs):
+ super().__init__(**kwargs)
+ self.ratio = ratio
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"ratio": self.ratio})
+ return config
+
+ def build(self, input_shape):
+ filters = input_shape[-1]
+ self.squeeze = layers.GlobalAveragePooling2D(keepdims=True)
+ self.reduction = layers.Dense(
+ units=filters // self.ratio,
+ activation="relu",
+ use_bias=False,
+ )
+ self.excite = layers.Dense(units=filters, activation="sigmoid", use_bias=False)
+ self.multiply = layers.Multiply()
+
+ def call(self, x):
+ shortcut = x
+ x = self.squeeze(x)
+ x = self.reduction(x)
+ x = self.excite(x)
+ x = self.multiply([shortcut, x])
+ return x
+
+
+class Trunk(layers.Layer):
+ """Convolutional residual trunk as in the https://arxiv.org/abs/2112.13692
+
+ Args:
+ depth: Number of trunk residual blocks
+        dimensions: Dimension of the model (denoted by d in the paper)
+ ratio: The Squeeze-Excitation ratio
+
+ Inputs:
+ Convolutional features extracted from the conv stem.
+
+ Outputs:
+ Flattened patches.
+ """
+
+ def __init__(self, depth, dimensions, ratio, **kwargs):
+ super().__init__(**kwargs)
+ self.ratio = ratio
+ self.dimensions = dimensions
+ self.depth = depth
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "ratio": self.ratio,
+ "dimensions": self.dimensions,
+ "depth": self.depth,
+ }
+ )
+ return config
+
+ def build(self, input_shape):
+ config = {
+ "filters": self.dimensions,
+ "activation": ops.gelu,
+ "padding": "same",
+ }
+
+ trunk_block = [
+ layers.LayerNormalization(epsilon=1e-6),
+ layers.Conv2D(kernel_size=(1, 1), **config),
+ layers.Conv2D(kernel_size=(3, 3), **config),
+ SqueezeExcite(ratio=self.ratio),
+ layers.Conv2D(kernel_size=(1, 1), filters=self.dimensions, padding="same"),
+ ]
+
+ self.trunk_blocks = [keras.Sequential(trunk_block) for _ in range(self.depth)]
+ self.add = layers.Add()
+ self.flatten_spatial = layers.Reshape((-1, self.dimensions))
+
+ def call(self, x):
+ # Remember the input.
+ shortcut = x
+ for trunk_block in self.trunk_blocks:
+ output = trunk_block(x)
+ shortcut = self.add([output, shortcut])
+ x = shortcut
+ # Flatten the patches.
+ x = self.flatten_spatial(x)
+ return x
+
+
+"""
+## Attention Pooling
+
+The output of the convolutional trunk is attended with a trainable
+_query_ class token. The resulting attention map is the weight of
+every patch of the image for a classification decision.
+"""
+
+
+class AttentionPooling(layers.Layer):
+    """Applies attention to the patches extracted from the
+ trunk with the CLS token.
+
+ Args:
+ dimensions: The dimension of the whole architecture.
+ num_classes: The number of classes in the dataset.
+
+ Inputs:
+ Flattened patches from the trunk.
+
+ Outputs:
+        The modified CLS token.
+ """
+
+ def __init__(self, dimensions, num_classes, **kwargs):
+ super().__init__(**kwargs)
+ self.dimensions = dimensions
+ self.num_classes = num_classes
+ self.cls = keras.Variable(ops.zeros((1, 1, dimensions)))
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "dimensions": self.dimensions,
+ "num_classes": self.num_classes,
+ "cls": self.cls.numpy(),
+ }
+ )
+ return config
+
+ def build(self, input_shape):
+ self.attention = layers.MultiHeadAttention(
+ num_heads=1,
+ key_dim=self.dimensions,
+ dropout=0.2,
+ )
+ self.layer_norm1 = layers.LayerNormalization(epsilon=1e-6)
+ self.layer_norm2 = layers.LayerNormalization(epsilon=1e-6)
+ self.layer_norm3 = layers.LayerNormalization(epsilon=1e-6)
+ self.mlp = keras.Sequential(
+ [
+ layers.Dense(units=self.dimensions, activation=ops.gelu),
+ layers.Dropout(0.2),
+ layers.Dense(units=self.dimensions, activation=ops.gelu),
+ ]
+ )
+ self.dense = layers.Dense(units=self.num_classes)
+ self.flatten = layers.Flatten()
+
+ def call(self, x):
+ batch_size = ops.shape(x)[0]
+ # Expand the class token batch number of times.
+ class_token = ops.repeat(self.cls, repeats=batch_size, axis=0)
+ # Concat the input with the trainable class token.
+ x = ops.concatenate([class_token, x], axis=1)
+ # Apply attention to x.
+ x = self.layer_norm1(x)
+ x, viz_weights = self.attention(
+ query=x[:, 0:1], key=x, value=x, return_attention_scores=True
+ )
+ class_token = class_token + x
+ class_token = self.layer_norm2(class_token)
+ class_token = self.flatten(class_token)
+ class_token = self.layer_norm3(class_token)
+ class_token = class_token + self.mlp(class_token)
+ # Build the logits
+ logits = self.dense(class_token)
+ return logits, ops.squeeze(viz_weights)[..., 1:]
+
+
+"""
+## Patch convnet
+
+The patch-convnet is shown in the figure below.
+
+| Architecture diagram of the PatchConvNet |
+| :--: |
+| [Source](https://arxiv.org/abs/2112.13692) |
+
+All the modules in the architecture were built in the earlier sections.
+In this section, we stack all of the different modules together.
+"""
+
+
+class PatchConvNet(keras.Model):
+ def __init__(
+ self,
+ stem,
+ trunk,
+ attention_pooling,
+ preprocessing_model,
+ train_augmentation_model,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.stem = stem
+ self.trunk = trunk
+ self.attention_pooling = attention_pooling
+ self.train_augmentation_model = train_augmentation_model
+ self.preprocessing_model = preprocessing_model
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "stem": self.stem,
+ "trunk": self.trunk,
+ "attention_pooling": self.attention_pooling,
+ "train_augmentation_model": self.train_augmentation_model,
+ "preprocessing_model": self.preprocessing_model,
+ }
+ )
+ return config
+
+ def _calculate_loss(self, inputs, test=False):
+ images, labels = inputs
+ # Augment the input images.
+ if test:
+ augmented_images = self.preprocessing_model(images)
+ else:
+ augmented_images = self.train_augmentation_model(images)
+ # Pass through the stem.
+ x = self.stem(augmented_images)
+ # Pass through the trunk.
+ x = self.trunk(x)
+ # Pass through the attention pooling block.
+ logits, _ = self.attention_pooling(x)
+ # Compute the total loss.
+ total_loss = self.compiled_loss(labels, logits)
+ return total_loss, logits
+
+ def train_step(self, inputs):
+ with tf.GradientTape() as tape:
+ total_loss, logits = self._calculate_loss(inputs)
+ # Apply gradients.
+ train_vars = [
+ self.stem.trainable_variables,
+ self.trunk.trainable_variables,
+ self.attention_pooling.trainable_variables,
+ ]
+ grads = tape.gradient(total_loss, train_vars)
+ trainable_variable_list = []
+ for grad, var in zip(grads, train_vars):
+ for g, v in zip(grad, var):
+ trainable_variable_list.append((g, v))
+ self.optimizer.apply_gradients(trainable_variable_list)
+ # Report progress.
+ _, labels = inputs
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, inputs):
+ total_loss, logits = self._calculate_loss(inputs, test=True)
+ # Report progress.
+ _, labels = inputs
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def call(self, images):
+ # Augment the input images.
+ augmented_images = self.preprocessing_model(images)
+ # Pass through the stem.
+ x = self.stem(augmented_images)
+ # Pass through the trunk.
+ x = self.trunk(x)
+ # Pass through the attention pooling block.
+ logits, viz_weights = self.attention_pooling(x)
+ return logits, viz_weights
+
+
+"""
+## Callbacks
+
+This callback plots a test image with its attention map overlaid on
+it at regular epoch intervals.
+"""
+
+# Taking a batch of test inputs to measure model's progress.
+test_images, test_labels = next(iter(test_ds))
+
+
+class TrainMonitor(keras.callbacks.Callback):
+    def __init__(self, epoch_interval=None):
+        super().__init__()
+        self.epoch_interval = epoch_interval
+
+ def on_epoch_end(self, epoch, logs=None):
+ if self.epoch_interval and epoch % self.epoch_interval == 4:
+ test_augmented_images = self.model.preprocessing_model(test_images)
+ # Pass through the stem.
+ test_x = self.model.stem(test_augmented_images)
+ # Pass through the trunk.
+ test_x = self.model.trunk(test_x)
+ # Pass through the attention pooling block.
+ _, test_viz_weights = self.model.attention_pooling(test_x)
+            # Reshape the visualization weights
+ num_patches = ops.shape(test_viz_weights)[-1]
+ height = width = int(math.sqrt(num_patches))
+ test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
+ # Take a random image and its attention weights.
+ index = np.random.randint(low=0, high=ops.shape(test_augmented_images)[0])
+ selected_image = test_augmented_images[index]
+ selected_weight = test_viz_weights[index]
+ # Plot the images and the overlayed attention map.
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
+ ax[0].imshow(selected_image)
+ ax[0].set_title(f"Original: {epoch:03d}")
+ ax[0].axis("off")
+ img = ax[1].imshow(selected_image)
+ ax[1].imshow(
+ selected_weight, cmap="inferno", alpha=0.6, extent=img.get_extent()
+ )
+ ax[1].set_title(f"Attended: {epoch:03d}")
+ ax[1].axis("off")
+ plt.axis("off")
+ plt.show()
+ plt.close()
+
+
+"""
+## Learning rate schedule
+"""
+
+
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(
+ self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
+ ):
+ super().__init__()
+ self.learning_rate_base = learning_rate_base
+ self.total_steps = total_steps
+ self.warmup_learning_rate = warmup_learning_rate
+ self.warmup_steps = warmup_steps
+ self.pi = np.pi
+
+ def __call__(self, step):
+ if self.total_steps < self.warmup_steps:
+ raise ValueError("Total_steps must be larger or equal to warmup_steps.")
+ cos_annealed_lr = ops.cos(
+ self.pi
+ * (ops.cast(step, "float32") - self.warmup_steps)
+ / float(self.total_steps - self.warmup_steps)
+ )
+ learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
+ if self.warmup_steps > 0:
+ if self.learning_rate_base < self.warmup_learning_rate:
+ raise ValueError(
+ "Learning_rate_base must be larger or equal to "
+ "warmup_learning_rate."
+ )
+ slope = (
+ self.learning_rate_base - self.warmup_learning_rate
+ ) / self.warmup_steps
+ warmup_rate = slope * ops.cast(step, "float32") + self.warmup_learning_rate
+ learning_rate = ops.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+ return ops.where(
+ step > self.total_steps,
+ 0.0,
+ learning_rate,
+ )
+
+
+total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+scheduled_lrs = WarmUpCosine(
+ learning_rate_base=LEARNING_RATE,
+ total_steps=total_steps,
+ warmup_learning_rate=0.0,
+ warmup_steps=warmup_steps,
+)
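+
+"""
+As a quick, optional sanity check, we can plot the schedule and verify the
+linear warmup followed by the cosine decay:
+"""
+
+schedule_steps = range(0, total_steps, 25)
+schedule_lrs = [float(scheduled_lrs(step)) for step in schedule_steps]
+plt.plot(list(schedule_steps), schedule_lrs)
+plt.xlabel("Step")
+plt.ylabel("Learning rate")
+plt.show()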
+
+"""
+## Training
+
+We build the model, compile it, and train it.
+"""
+
+train_augmentation_model = get_train_augmentation_model()
+preprocessing_model = get_preprocessing()
+conv_stem = build_convolutional_stem(dimensions=DIMENSIONS)
+conv_trunk = Trunk(depth=TRUNK_DEPTH, dimensions=DIMENSIONS, ratio=SE_RATIO)
+attention_pooling = AttentionPooling(dimensions=DIMENSIONS, num_classes=NUM_CLASSES)
+
+patch_conv_net = PatchConvNet(
+ stem=conv_stem,
+ trunk=conv_trunk,
+ attention_pooling=attention_pooling,
+ train_augmentation_model=train_augmentation_model,
+ preprocessing_model=preprocessing_model,
+)
+
+# Assemble the callbacks.
+train_callbacks = [TrainMonitor(epoch_interval=5)]
+# Get the optimizer.
+optimizer = keras.optimizers.AdamW(
+ learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
+)
+# Compile and pretrain the model.
+patch_conv_net.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+history = patch_conv_net.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=val_ds,
+ callbacks=train_callbacks,
+)
+
+# Evaluate the model with the test dataset.
+loss, acc_top1, acc_top5 = patch_conv_net.evaluate(test_ds)
+print(f"Loss: {loss:0.2f}")
+print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
+print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
+
+"""
+## Inference
+
+Here, we use the trained model to plot the attention map.
+"""
+
+
+def plot_attention(image):
+ """Plots the attention map on top of the image.
+
+ Args:
+ image: A numpy image of arbitrary size.
+ """
+ # Resize the image to a (32, 32) dim.
+ image = ops.image.resize(image, (32, 32))
+ image = image[np.newaxis, ...]
+ test_augmented_images = patch_conv_net.preprocessing_model(image)
+ # Pass through the stem.
+ test_x = patch_conv_net.stem(test_augmented_images)
+ # Pass through the trunk.
+ test_x = patch_conv_net.trunk(test_x)
+ # Pass through the attention pooling block.
+ _, test_viz_weights = patch_conv_net.attention_pooling(test_x)
+ test_viz_weights = test_viz_weights[np.newaxis, ...]
+    # Reshape the visualization weights.
+ num_patches = ops.shape(test_viz_weights)[-1]
+ height = width = int(math.sqrt(num_patches))
+ test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
+ selected_image = test_augmented_images[0]
+ selected_weight = test_viz_weights[0]
+ # Plot the images.
+ fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
+ ax[0].imshow(selected_image)
+ ax[0].set_title(f"Original")
+ ax[0].axis("off")
+ img = ax[1].imshow(selected_image)
+ ax[1].imshow(selected_weight, cmap="inferno", alpha=0.6, extent=img.get_extent())
+ ax[1].set_title(f"Attended")
+ ax[1].axis("off")
+ plt.axis("off")
+ plt.show()
+ plt.close()
+
+
+url = "http://farm9.staticflickr.com/8017/7140384795_385b1f48df_z.jpg"
+image_name = keras.utils.get_file(fname="image.jpg", origin=url)
+image = keras.utils.load_img(image_name)
+image = keras.utils.img_to_array(image)
+plot_attention(image)
+
+"""
+## Conclusions
+
+The attention map corresponding to the trainable `CLASS`
+token and the patches of the image helps explain the classification
+decision. One should also note that the attention maps gradually get
+better: in the initial training regime, the attention is scattered all
+around, while at a later stage it focuses more on the objects in the
+image.
+
+The non-pyramidal convnet achieves ~84-85% top-1 test accuracy.
+
+I would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for
+providing GPU credits for this project.
+"""
diff --git a/knowledge_base/vision/perceiver_image_classification.py b/knowledge_base/vision/perceiver_image_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..95428396d961de8ea3afcaa427c15f58f7d703b3
--- /dev/null
+++ b/knowledge_base/vision/perceiver_image_classification.py
@@ -0,0 +1,469 @@
+"""
+Title: Image classification with Perceiver
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/04/30
+Last modified: 2023/12/30
+Description: Implementing the Perceiver model for image classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example implements the
+[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)
+model by Andrew Jaegle et al. for image classification,
+and demonstrates it on the CIFAR-100 dataset.
+
+The Perceiver model leverages an asymmetric attention mechanism to iteratively
+distill inputs into a tight latent bottleneck,
+allowing it to scale to handle very large inputs.
+
+In other words: let's assume that your input data array (e.g. image) has `M` elements (i.e. patches), where `M` is large.
+In a standard Transformer model, a self-attention operation is performed for the `M` elements.
+The complexity of this operation is `O(M^2)`.
+However, the Perceiver model creates a latent array of size `N` elements, where `N << M`,
+and performs two operations iteratively:
+
+1. Cross-attention Transformer between the latent array and the data array - The complexity of this operation is `O(M.N)`.
+2. Self-attention Transformer on the latent array - The complexity of this operation is `O(N^2)`.
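+
+A rough back-of-the-envelope comparison (with illustrative numbers, not from the paper)
+shows why this matters:
+
+```python
+M, N = 1024, 256  # number of input elements vs. latent size
+print("self-attention on inputs:", M * M)  # O(M^2) -> 1,048,576
+print("cross-attention:", M * N)  # O(M.N) -> 262,144
+print("latent self-attention:", N * N)  # O(N^2) -> 65,536
+```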
+
+This example requires Keras 3.0 or higher.
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import layers, activations, ops
+
+"""
+## Prepare the data
+"""
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+"""
+## Configure the hyperparameters
+"""
+
+learning_rate = 0.001
+weight_decay = 0.0001
+batch_size = 64
+num_epochs = 2 # It is recommended to run 50 epochs to observe improvements in accuracy
+dropout_rate = 0.2
+image_size = 64 # We'll resize input images to this size.
+patch_size = 2  # Size of the patches to be extracted from the input images.
+num_patches = (image_size // patch_size) ** 2 # Size of the data array.
+latent_dim = 256 # Size of the latent array.
+projection_dim = 256 # Embedding size of each element in the data and latent arrays.
+num_heads = 8 # Number of Transformer heads.
+ffn_units = [
+ projection_dim,
+ projection_dim,
+] # Size of the Transformer Feedforward network.
+num_transformer_blocks = 4
+num_iterations = 2 # Repetitions of the cross-attention and Transformer modules.
+classifier_units = [
+ projection_dim,
+ num_classes,
+] # Size of the Feedforward network of the final classifier.
+
+print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
+print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
+print(f"Patches per image: {num_patches}")
+print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")
+print(f"Latent array shape: {latent_dim} X {projection_dim}")
+print(f"Data array shape: {num_patches} X {projection_dim}")
+
+"""
+Note that, in order to use each pixel as an individual input in the data array,
+set `patch_size` to 1.
+"""
+
+"""
+## Use data augmentation
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.Resizing(image_size, image_size),
+ layers.RandomFlip("horizontal"),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+"""
+## Implement Feedforward network (FFN)
+"""
+
+
+def create_ffn(hidden_units, dropout_rate):
+ ffn_layers = []
+ for units in hidden_units[:-1]:
+ ffn_layers.append(layers.Dense(units, activation=activations.gelu))
+
+ ffn_layers.append(layers.Dense(units=hidden_units[-1]))
+ ffn_layers.append(layers.Dropout(dropout_rate))
+
+ ffn = keras.Sequential(ffn_layers)
+ return ffn
+
+
+"""
+## Implement patch creation as a layer
+"""
+
+
+class Patches(layers.Layer):
+ def __init__(self, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+
+ def call(self, images):
+ batch_size = ops.shape(images)[0]
+ patches = ops.image.extract_patches(
+ images=images,
+ size=(self.patch_size, self.patch_size),
+ strides=(self.patch_size, self.patch_size),
+ dilation_rate=1,
+ padding="valid",
+ )
+ patch_dims = patches.shape[-1]
+ patches = ops.reshape(patches, [batch_size, -1, patch_dims])
+ return patches
+
+
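+"""
+As a quick shape check (not part of the original example): with `image_size = 64` and
+`patch_size = 2`, each image should yield `(64 / 2) ** 2 = 1024` patches of
+`2 * 2 * 3 = 12` values each.
+"""
+
+dummy_images = keras.random.normal((4, image_size, image_size, 3))
+print(Patches(patch_size)(dummy_images).shape)  # (4, 1024, 12)
+
+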
+"""
+## Implement the patch encoding layer
+
+The `PatchEncoder` layer will linearly transform a patch by projecting it into
+a vector of size `latent_dim`. In addition, it adds a learnable position embedding
+to the projected vector.
+
+Note that the original Perceiver paper uses the Fourier feature positional encodings.
+"""
+
+
+class PatchEncoder(layers.Layer):
+ def __init__(self, num_patches, projection_dim):
+ super().__init__()
+ self.num_patches = num_patches
+ self.projection = layers.Dense(units=projection_dim)
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+
+ def call(self, patches):
+ positions = ops.arange(start=0, stop=self.num_patches, step=1)
+ encoded = self.projection(patches) + self.position_embedding(positions)
+ return encoded
+
+
+"""
+## Build the Perceiver model
+
+The Perceiver consists of two modules: a cross-attention
+module and a standard Transformer with self-attention.
+"""
+
+"""
+### Cross-attention module
+
+The cross-attention expects a `(latent_dim, projection_dim)` latent array,
+and the `(data_dim, projection_dim)` data array as inputs,
+to produce a `(latent_dim, projection_dim)` latent array as an output.
+To apply cross-attention, the `query` vectors are generated from the latent array,
+while the `key` and `value` vectors are generated from the encoded image.
+
+Note that the data array in this example is the image,
+where the `data_dim` is set to the `num_patches`.
+"""
+
+
+def create_cross_attention_module(
+ latent_dim, data_dim, projection_dim, ffn_units, dropout_rate
+):
+ inputs = {
+        # Receive the latent array as an input of shape [1, latent_dim, projection_dim].
+ "latent_array": layers.Input(
+ shape=(latent_dim, projection_dim), name="latent_array"
+ ),
+        # Receive the data_array (encoded image) as an input of shape [batch_size, data_dim, projection_dim].
+ "data_array": layers.Input(shape=(data_dim, projection_dim), name="data_array"),
+ }
+
+ # Apply layer norm to the inputs
+ latent_array = layers.LayerNormalization(epsilon=1e-6)(inputs["latent_array"])
+ data_array = layers.LayerNormalization(epsilon=1e-6)(inputs["data_array"])
+
+ # Create query tensor: [1, latent_dim, projection_dim].
+ query = layers.Dense(units=projection_dim)(latent_array)
+ # Create key tensor: [batch_size, data_dim, projection_dim].
+ key = layers.Dense(units=projection_dim)(data_array)
+ # Create value tensor: [batch_size, data_dim, projection_dim].
+ value = layers.Dense(units=projection_dim)(data_array)
+
+ # Generate cross-attention outputs: [batch_size, latent_dim, projection_dim].
+ attention_output = layers.Attention(use_scale=True, dropout=0.1)(
+ [query, key, value], return_attention_scores=False
+ )
+ # Skip connection 1.
+ attention_output = layers.Add()([attention_output, latent_array])
+
+ # Apply layer norm.
+ attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)
+ # Apply Feedforward network.
+ ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
+ outputs = ffn(attention_output)
+ # Skip connection 2.
+ outputs = layers.Add()([outputs, attention_output])
+
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+### Transformer module
+
+The Transformer expects the output latent vector from the cross-attention module
+as an input, applies multi-head self-attention to its `latent_dim` elements,
+followed by feedforward network, to produce another `(latent_dim, projection_dim)` latent array.
+"""
+
+
+def create_transformer_module(
+ latent_dim,
+ projection_dim,
+ num_heads,
+ num_transformer_blocks,
+ ffn_units,
+ dropout_rate,
+):
+ # input_shape: [1, latent_dim, projection_dim]
+ inputs = layers.Input(shape=(latent_dim, projection_dim))
+
+ x0 = inputs
+ # Create multiple layers of the Transformer block.
+ for _ in range(num_transformer_blocks):
+ # Apply layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(x0)
+ # Create a multi-head self-attention layer.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=projection_dim, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, x0])
+ # Apply layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # Apply Feedforward network.
+ ffn = create_ffn(hidden_units=ffn_units, dropout_rate=dropout_rate)
+ x3 = ffn(x3)
+ # Skip connection 2.
+ x0 = layers.Add()([x3, x2])
+
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=x0)
+ return model
+
+
+"""
+### Perceiver model
+
+The Perceiver model repeats the cross-attention and Transformer modules
+`num_iterations` times (with shared weights and skip connections) to allow
+the latent array to iteratively extract information from the input image as it is needed.
+"""
+
+
+class Perceiver(keras.Model):
+ def __init__(
+ self,
+ patch_size,
+ data_dim,
+ latent_dim,
+ projection_dim,
+ num_heads,
+ num_transformer_blocks,
+ ffn_units,
+ dropout_rate,
+ num_iterations,
+ classifier_units,
+ ):
+ super().__init__()
+
+ self.latent_dim = latent_dim
+ self.data_dim = data_dim
+ self.patch_size = patch_size
+ self.projection_dim = projection_dim
+ self.num_heads = num_heads
+ self.num_transformer_blocks = num_transformer_blocks
+ self.ffn_units = ffn_units
+ self.dropout_rate = dropout_rate
+ self.num_iterations = num_iterations
+ self.classifier_units = classifier_units
+
+ def build(self, input_shape):
+ # Create latent array.
+ self.latent_array = self.add_weight(
+ shape=(self.latent_dim, self.projection_dim),
+ initializer="random_normal",
+ trainable=True,
+ )
+
+ # Create patching module.
+ self.patcher = Patches(self.patch_size)
+
+ # Create patch encoder.
+ self.patch_encoder = PatchEncoder(self.data_dim, self.projection_dim)
+
+        # Create cross-attention module.
+ self.cross_attention = create_cross_attention_module(
+ self.latent_dim,
+ self.data_dim,
+ self.projection_dim,
+ self.ffn_units,
+ self.dropout_rate,
+ )
+
+ # Create Transformer module.
+ self.transformer = create_transformer_module(
+ self.latent_dim,
+ self.projection_dim,
+ self.num_heads,
+ self.num_transformer_blocks,
+ self.ffn_units,
+ self.dropout_rate,
+ )
+
+ # Create global average pooling layer.
+ self.global_average_pooling = layers.GlobalAveragePooling1D()
+
+ # Create a classification head.
+ self.classification_head = create_ffn(
+ hidden_units=self.classifier_units, dropout_rate=self.dropout_rate
+ )
+
+ super().build(input_shape)
+
+ def call(self, inputs):
+ # Augment data.
+ augmented = data_augmentation(inputs)
+ # Create patches.
+ patches = self.patcher(augmented)
+ # Encode patches.
+ encoded_patches = self.patch_encoder(patches)
+ # Prepare cross-attention inputs.
+ cross_attention_inputs = {
+ "latent_array": ops.expand_dims(self.latent_array, 0),
+ "data_array": encoded_patches,
+ }
+ # Apply the cross-attention and the Transformer modules iteratively.
+ for _ in range(self.num_iterations):
+ # Apply cross-attention from the latent array to the data array.
+ latent_array = self.cross_attention(cross_attention_inputs)
+ # Apply self-attention Transformer to the latent array.
+ latent_array = self.transformer(latent_array)
+ # Set the latent array of the next iteration.
+ cross_attention_inputs["latent_array"] = latent_array
+
+        # Apply global average pooling to generate a [batch_size, projection_dim] representation tensor.
+ representation = self.global_average_pooling(latent_array)
+ # Generate logits.
+ logits = self.classification_head(representation)
+ return logits
+
+
+"""
+## Compile, train, and evaluate the model
+"""
+
+
+def run_experiment(model):
+    # Create an Adam optimizer (the paper uses LAMB, which isn't supported in Keras yet).
+ optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
+
+ # Compile the model.
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="acc"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
+ ],
+ )
+
+ # Create a learning rate scheduler callback.
+ reduce_lr = keras.callbacks.ReduceLROnPlateau(
+ monitor="val_loss", factor=0.2, patience=3
+ )
+
+ # Create an early stopping callback.
+ early_stopping = keras.callbacks.EarlyStopping(
+ monitor="val_loss", patience=15, restore_best_weights=True
+ )
+
+ # Fit the model.
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_split=0.1,
+ callbacks=[early_stopping, reduce_lr],
+ )
+
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ # Return history to plot learning curves.
+ return history
+
+
+"""
+Note that training the Perceiver model with the current settings on a V100 GPU takes
+around 200 seconds.
+"""
+
+perceiver_classifier = Perceiver(
+ patch_size,
+ num_patches,
+ latent_dim,
+ projection_dim,
+ num_heads,
+ num_transformer_blocks,
+ ffn_units,
+ dropout_rate,
+ num_iterations,
+ classifier_units,
+)
+
+
+history = run_experiment(perceiver_classifier)
+
+"""
+After 40 epochs, the Perceiver model achieves around 53% accuracy and 81% top-5 accuracy on the test data.
+
+As mentioned in the ablations of the [Perceiver](https://arxiv.org/abs/2103.03206) paper,
+you can obtain better results by increasing the latent array size,
+increasing the (projection) dimensions of the latent array and data array elements,
+increasing the number of blocks in the Transformer module, and increasing the number of iterations of applying
+the cross-attention and the latent Transformer modules. You may also try to increase the size of the input images
+and use different patch sizes.
+
+The Perceiver benefits from increasing the model size. However, larger models need bigger accelerators
+to fit in memory and train efficiently. This is why the Perceiver paper used 32 TPU cores to run its experiments.
+"""
diff --git a/knowledge_base/vision/pointnet.py b/knowledge_base/vision/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..af637f41c32640916a9dd47dc2d3496a7ded8e65
--- /dev/null
+++ b/knowledge_base/vision/pointnet.py
@@ -0,0 +1,300 @@
+"""
+Title: Point cloud classification with PointNet
+Author: [David Griffiths](https://dgriffiths3.github.io)
+Date created: 2020/05/25
+Last modified: 2024/01/09
+Description: Implementation of PointNet for ModelNet10 classification.
+Accelerator: GPU
+"""
+
+"""
+# Point cloud classification
+"""
+
+"""
+## Introduction
+
+Classification, detection and segmentation of unordered 3D point sets, i.e. point clouds,
+is a core problem in computer vision. This example implements the seminal point cloud
+deep learning paper [PointNet (Qi et al., 2017)](https://arxiv.org/abs/1612.00593). For a
+detailed introduction to PointNet, see [this blog
+post](https://medium.com/@luis_gonzales/an-in-depth-look-at-pointnet-111d7efdaa1a).
+"""
+
+"""
+## Setup
+
+If running in Colab, first install trimesh with `!pip install trimesh`.
+"""
+
+
+import os
+import glob
+import trimesh
+import numpy as np
+from tensorflow import data as tf_data
+from keras import ops
+import keras
+from keras import layers
+from matplotlib import pyplot as plt
+
+keras.utils.set_random_seed(seed=42)
+
+"""
+## Load dataset
+
+We use the ModelNet10 dataset, the smaller 10-class version of the ModelNet40
+dataset. First download the data:
+"""
+
+DATA_DIR = keras.utils.get_file(
+ "modelnet.zip",
+ "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
+ extract=True,
+)
+DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), "ModelNet10")
+
+"""
+We can use the `trimesh` package to read and visualize the `.off` mesh files.
+"""
+
+mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
+mesh.show()
+
+"""
+To convert a mesh file to a point cloud we first need to sample points on the mesh
+surface. `.sample()` performs a uniform random sampling. Here we sample at 2048 locations
+and visualize in `matplotlib`.
+"""
+
+points = mesh.sample(2048)
+
+fig = plt.figure(figsize=(5, 5))
+ax = fig.add_subplot(111, projection="3d")
+ax.scatter(points[:, 0], points[:, 1], points[:, 2])
+ax.set_axis_off()
+plt.show()
+
+"""
+To generate a `tf.data.Dataset()` we first need to parse the ModelNet data
+folders. Each mesh is loaded and sampled into a point cloud before being added to a
+standard Python list and converted to a `numpy` array. We also store the current
+enumerate index value as the object label and use a dictionary to recall this later.
+"""
+
+
+def parse_dataset(num_points=2048):
+ train_points = []
+ train_labels = []
+ test_points = []
+ test_labels = []
+ class_map = {}
+ folders = glob.glob(os.path.join(DATA_DIR, "[!README]*"))
+
+ for i, folder in enumerate(folders):
+ print("processing class: {}".format(os.path.basename(folder)))
+ # store folder name with ID so we can retrieve later
+ class_map[i] = folder.split("/")[-1]
+ # gather all files
+ train_files = glob.glob(os.path.join(folder, "train/*"))
+ test_files = glob.glob(os.path.join(folder, "test/*"))
+
+ for f in train_files:
+ train_points.append(trimesh.load(f).sample(num_points))
+ train_labels.append(i)
+
+ for f in test_files:
+ test_points.append(trimesh.load(f).sample(num_points))
+ test_labels.append(i)
+
+ return (
+ np.array(train_points),
+ np.array(test_points),
+ np.array(train_labels),
+ np.array(test_labels),
+ class_map,
+ )
+
+
+"""
+Set the number of points to sample and the batch size, then parse the dataset. This can
+take ~5 minutes to complete.
+"""
+
+NUM_POINTS = 2048
+NUM_CLASSES = 10
+BATCH_SIZE = 32
+
+train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(
+ NUM_POINTS
+)
+
+"""
+Our data can now be read into a `tf.data.Dataset()` object. We set the shuffle buffer
+size to the entire size of the dataset, as prior to this the data is ordered by class.
+Data augmentation is important when working with point cloud data. We create an
+augmentation function to jitter and shuffle the train dataset.
+"""
+
+
+def augment(points, label):
+ # jitter points
+ points += keras.random.uniform(points.shape, -0.005, 0.005, dtype="float64")
+ # shuffle points
+ points = keras.random.shuffle(points)
+ return points, label
+
+
+train_size = 0.8
+dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))
+test_dataset = tf_data.Dataset.from_tensor_slices((test_points, test_labels))
+train_dataset_size = int(len(dataset) * train_size)
+
+dataset = dataset.shuffle(len(train_points)).map(augment)
+test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)
+
+train_dataset = dataset.take(train_dataset_size).batch(BATCH_SIZE)
+validation_dataset = dataset.skip(train_dataset_size).batch(BATCH_SIZE)
+
+"""
+### Build a model
+
+Each convolution and fully-connected layer (with the exception of the final layers) consists of
+Convolution / Dense -> Batch Normalization -> ReLU Activation.
+"""
+
+
+def conv_bn(x, filters):
+ x = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)
+ x = layers.BatchNormalization(momentum=0.0)(x)
+ return layers.Activation("relu")(x)
+
+
+def dense_bn(x, filters):
+ x = layers.Dense(filters)(x)
+ x = layers.BatchNormalization(momentum=0.0)(x)
+ return layers.Activation("relu")(x)
+
+
+"""
+PointNet consists of two core components: the primary MLP network, and the transformer
+net (T-net). The T-net aims to learn an affine transformation matrix with its own mini
+network. The T-net is used twice: the first time to transform the input features (n, 3)
+into a canonical representation, and the second time to perform an affine transformation
+for alignment in feature space. As per the original paper, we constrain the transformation
+to be close to an orthogonal matrix (i.e. `||X*X^T - I|| = 0`).
+"""
+
+
+class OrthogonalRegularizer(keras.regularizers.Regularizer):
+ def __init__(self, num_features, l2reg=0.001):
+ self.num_features = num_features
+ self.l2reg = l2reg
+ self.eye = ops.eye(num_features)
+
+ def __call__(self, x):
+ x = ops.reshape(x, (-1, self.num_features, self.num_features))
+ xxt = ops.tensordot(x, x, axes=(2, 2))
+ xxt = ops.reshape(xxt, (-1, self.num_features, self.num_features))
+ return ops.sum(self.l2reg * ops.square(xxt - self.eye))
+
+
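+"""
+As a quick sanity check (not part of the original example), the penalty should be zero for
+an orthogonal transform such as the identity, and positive otherwise:
+"""
+
+reg_check = OrthogonalRegularizer(num_features=3)
+identity_flat = ops.reshape(ops.eye(3), (1, -1))
+print("penalty(identity):", float(reg_check(identity_flat)))  # ~0.0
+print("penalty(2 * identity):", float(reg_check(2.0 * identity_flat)))  # > 0
+
+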
+"""
+We can then define a general function to build T-net layers.
+"""
+
+
+def tnet(inputs, num_features):
+ # Initialise bias as the identity matrix
+ bias = keras.initializers.Constant(np.eye(num_features).flatten())
+ reg = OrthogonalRegularizer(num_features)
+
+ x = conv_bn(inputs, 32)
+ x = conv_bn(x, 64)
+ x = conv_bn(x, 512)
+ x = layers.GlobalMaxPooling1D()(x)
+ x = dense_bn(x, 256)
+ x = dense_bn(x, 128)
+ x = layers.Dense(
+ num_features * num_features,
+ kernel_initializer="zeros",
+ bias_initializer=bias,
+ activity_regularizer=reg,
+ )(x)
+ feat_T = layers.Reshape((num_features, num_features))(x)
+ # Apply affine transformation to input features
+ return layers.Dot(axes=(2, 1))([inputs, feat_T])
+
+
+"""
+The main network can then be implemented in the same manner, with the T-net mini models
+dropped in as layers in the graph. Here we replicate the network architecture
+published in the original paper, but with half the number of weights at each layer, as we
+are using the smaller 10-class ModelNet dataset.
+"""
+
+inputs = keras.Input(shape=(NUM_POINTS, 3))
+
+x = tnet(inputs, 3)
+x = conv_bn(x, 32)
+x = conv_bn(x, 32)
+x = tnet(x, 32)
+x = conv_bn(x, 32)
+x = conv_bn(x, 64)
+x = conv_bn(x, 512)
+x = layers.GlobalMaxPooling1D()(x)
+x = dense_bn(x, 256)
+x = layers.Dropout(0.3)(x)
+x = dense_bn(x, 128)
+x = layers.Dropout(0.3)(x)
+
+outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
+
+model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
+model.summary()
+
+"""
+### Train model
+
+Once the model is defined it can be trained like any other standard classification model
+using `.compile()` and `.fit()`.
+"""
+
+model.compile(
+ loss="sparse_categorical_crossentropy",
+ optimizer=keras.optimizers.Adam(learning_rate=0.001),
+ metrics=["sparse_categorical_accuracy"],
+)
+
+model.fit(train_dataset, epochs=20, validation_data=validation_dataset)
+
+"""
+## Visualize predictions
+
+We can use matplotlib to visualize our trained model performance.
+"""
+
+data = test_dataset.take(1)
+
+points, labels = list(data)[0]
+points = points[:8, ...]
+labels = labels[:8, ...]
+
+# run test data through model
+preds = model.predict(points)
+preds = ops.argmax(preds, -1)
+
+points = points.numpy()
+
+# plot points with predicted class and label
+fig = plt.figure(figsize=(15, 10))
+for i in range(8):
+ ax = fig.add_subplot(2, 4, i + 1, projection="3d")
+ ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2])
+ ax.set_title(
+ "pred: {:}, label: {:}".format(
+ CLASS_MAP[preds[i].numpy()], CLASS_MAP[labels.numpy()[i]]
+ )
+ )
+ ax.set_axis_off()
+plt.show()
diff --git a/knowledge_base/vision/pointnet_segmentation.py b/knowledge_base/vision/pointnet_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a66f42c57707b9ecf0257dbe0a0cdce4e8f1297
--- /dev/null
+++ b/knowledge_base/vision/pointnet_segmentation.py
@@ -0,0 +1,609 @@
+"""
+Title: Point cloud segmentation with PointNet
+Author: [Soumik Rakshit](https://github.com/soumik12345), [Sayak Paul](https://github.com/sayakpaul)
+Date created: 2020/10/23
+Last modified: 2020/10/24
+Description: Implementation of a PointNet-based model for segmenting point clouds.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+A "point cloud" is an important type of data structure for storing geometric shape data.
+Due to its irregular format, it's often transformed into
+regular 3D voxel grids or collections of images before being used in deep learning applications,
+a step which makes the data unnecessarily large.
+The PointNet family of models solves this problem by directly consuming point clouds, respecting
+the permutation-invariance property of the point data. The PointNet family of
+models provides a simple, unified architecture
+for applications ranging from **object classification**, **part segmentation**, to
+**scene semantic parsing**.
+
+In this example, we demonstrate the implementation of the PointNet architecture
+for shape segmentation.
+
+### References
+
+- [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://arxiv.org/abs/1612.00593)
+- [Point cloud classification with PointNet](https://keras.io/examples/vision/pointnet/)
+- [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025)
+"""
+
+"""
+## Imports
+"""
+
+import os
+import json
+import random
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from glob import glob
+
+import tensorflow as tf # For tf.data
+import keras
+from keras import layers
+
+import matplotlib.pyplot as plt
+
+"""
+## Downloading Dataset
+
+The [ShapeNet dataset](https://shapenet.org/) is an ongoing effort to establish a richly-annotated,
+large-scale dataset of 3D shapes. **ShapeNetCore** is a subset of the full ShapeNet
+dataset with clean single 3D models and manually verified category and alignment
+annotations. It covers 55 common object categories, with about 51,300 unique 3D models.
+
+For this example, we use one of the 12 object categories of
+[PASCAL 3D+](http://cvgl.stanford.edu/projects/pascal3d.html),
+included as part of the ShapeNetCore dataset.
+"""
+
+dataset_url = "https://git.io/JiY4i"
+
+dataset_path = keras.utils.get_file(
+ fname="shapenet.zip",
+ origin=dataset_url,
+ cache_subdir="datasets",
+ hash_algorithm="auto",
+ extract=True,
+ archive_format="auto",
+ cache_dir="datasets",
+)
+
+"""
+## Loading the dataset
+
+We parse the dataset metadata in order to easily map model categories to their
+respective directories and segmentation classes to colors for the purpose of
+visualization.
+"""
+
+with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
+ metadata = json.load(json_file)
+
+print(metadata)
+
+"""
+In this example, we train PointNet to segment the parts of an `Airplane` model.
+"""
+
+points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
+ metadata["Airplane"]["directory"]
+)
+labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
+ metadata["Airplane"]["directory"]
+)
+LABELS = metadata["Airplane"]["lables"]
+COLORS = metadata["Airplane"]["colors"]
+
+VAL_SPLIT = 0.2
+NUM_SAMPLE_POINTS = 1024
+BATCH_SIZE = 32
+EPOCHS = 60
+INITIAL_LR = 1e-3
+
+"""
+## Structuring the dataset
+
+We generate the following in-memory data structures from the Airplane point clouds and
+their labels:
+
+- `point_clouds` is a list of `np.array` objects that represent the point cloud data in
+the form of x, y and z coordinates. Axis 0 represents the number of points in the
+point cloud, while axis 1 represents the coordinates. `all_labels` is the list
+that represents the label of each coordinate as a string (needed mainly for
+visualization purposes).
+- `test_point_clouds` is in the same format as `point_clouds`, but doesn't have
+the corresponding labels of the point clouds.
+- `all_labels` is a list of `np.array` objects that represent the point cloud labels
+for each coordinate, corresponding to the `point_clouds` list.
+- `point_cloud_labels` is a list of `np.array` objects that represent the point cloud
+labels for each coordinate in one-hot encoded form, corresponding to the `point_clouds`
+list.
+"""
+
+point_clouds, test_point_clouds = [], []
+point_cloud_labels, all_labels = [], []
+
+points_files = glob(os.path.join(points_dir, "*.pts"))
+for point_file in tqdm(points_files):
+ point_cloud = np.loadtxt(point_file)
+ if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
+ continue
+
+ # Get the file-id of the current point cloud for parsing its
+ # labels.
+ file_id = point_file.split("/")[-1].split(".")[0]
+ label_data, num_labels = {}, 0
+ for label in LABELS:
+ label_file = os.path.join(labels_dir, label, file_id + ".seg")
+ if os.path.exists(label_file):
+ label_data[label] = np.loadtxt(label_file).astype("float32")
+ num_labels = len(label_data[label])
+
+ # Point clouds having labels will be our training samples.
+ try:
+ label_map = ["none"] * num_labels
+ for label in LABELS:
+ for i, data in enumerate(label_data[label]):
+ label_map[i] = label if data == 1 else label_map[i]
+ label_data = [
+ LABELS.index(label) if label != "none" else len(LABELS)
+ for label in label_map
+ ]
+ # Apply one-hot encoding to the dense label representation.
+ label_data = keras.utils.to_categorical(label_data, num_classes=len(LABELS) + 1)
+
+ point_clouds.append(point_cloud)
+ point_cloud_labels.append(label_data)
+ all_labels.append(label_map)
+ except KeyError:
+ test_point_clouds.append(point_cloud)
+
+"""
+Next, we take a look at some samples from the in-memory arrays we just generated:
+"""
+
+for _ in range(5):
+ i = random.randint(0, len(point_clouds) - 1)
+ print(f"point_clouds[{i}].shape:", point_clouds[0].shape)
+ print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[0].shape)
+ for j in range(5):
+ print(
+ f"all_labels[{i}][{j}]:",
+ all_labels[i][j],
+ f"\tpoint_cloud_labels[{i}][{j}]:",
+ point_cloud_labels[i][j],
+ "\n",
+ )
+
+"""
+Now, let's visualize some of the point clouds along with their labels.
+"""
+
+
+def visualize_data(point_cloud, labels):
+ df = pd.DataFrame(
+ data={
+ "x": point_cloud[:, 0],
+ "y": point_cloud[:, 1],
+ "z": point_cloud[:, 2],
+ "label": labels,
+ }
+ )
+ fig = plt.figure(figsize=(15, 10))
+ ax = plt.axes(projection="3d")
+ for index, label in enumerate(LABELS):
+ c_df = df[df["label"] == label]
+ try:
+ ax.scatter(
+ c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
+ )
+ except IndexError:
+ pass
+ ax.legend()
+ plt.show()
+
+
+visualize_data(point_clouds[0], all_labels[0])
+visualize_data(point_clouds[300], all_labels[300])
+
+"""
+### Preprocessing
+
+Note that all the point clouds that we have loaded consist of a variable number of points,
+which makes it difficult for us to batch them together. In order to overcome this problem, we
+randomly sample a fixed number of points from each point cloud. We also normalize the
+point clouds in order to make the data scale-invariant.
+"""
+
+for index in tqdm(range(len(point_clouds))):
+ current_point_cloud = point_clouds[index]
+ current_label_cloud = point_cloud_labels[index]
+ current_labels = all_labels[index]
+ num_points = len(current_point_cloud)
+ # Randomly sampling respective indices.
+ sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
+ # Sampling points corresponding to sampled indices.
+ sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
+ # Sampling corresponding one-hot encoded labels.
+ sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
+ # Sampling corresponding labels for visualization.
+ sampled_labels = np.array([current_labels[i] for i in sampled_indices])
+ # Normalizing sampled point cloud.
+ norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
+ norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
+ point_clouds[index] = norm_point_cloud
+ point_cloud_labels[index] = sampled_label_cloud
+ all_labels[index] = sampled_labels
+
+"""
+Let's visualize the sampled and normalized point clouds along with their corresponding
+labels.
+"""
+
+visualize_data(point_clouds[0], all_labels[0])
+visualize_data(point_clouds[300], all_labels[300])
+
+"""
+### Creating TensorFlow datasets
+
+We create `tf.data.Dataset` objects for the training and validation data.
+We also augment the training point clouds by applying random jitter to them.
+"""
+
+
+def load_data(point_cloud_batch, label_cloud_batch):
+ point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
+ label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
+ return point_cloud_batch, label_cloud_batch
+
+
+def augment(point_cloud_batch, label_cloud_batch):
+ noise = tf.random.uniform(
+ tf.shape(label_cloud_batch), -0.001, 0.001, dtype=tf.float64
+ )
+ point_cloud_batch += noise[:, :, :3]
+ return point_cloud_batch, label_cloud_batch
+
+
+def generate_dataset(point_clouds, label_clouds, is_training=True):
+ dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
+ dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
+ dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
+ dataset = dataset.batch(batch_size=BATCH_SIZE)
+ dataset = (
+ dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
+ if is_training
+ else dataset
+ )
+ return dataset
+
+
+split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
+train_point_clouds = point_clouds[:split_index]
+train_label_cloud = point_cloud_labels[:split_index]
+total_training_examples = len(train_point_clouds)
+
+val_point_clouds = point_clouds[split_index:]
+val_label_cloud = point_cloud_labels[split_index:]
+
+print("Num train point clouds:", len(train_point_clouds))
+print("Num train point cloud labels:", len(train_label_cloud))
+print("Num val point clouds:", len(val_point_clouds))
+print("Num val point cloud labels:", len(val_label_cloud))
+
+train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
+val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)
+
+print("Train Dataset:", train_dataset)
+print("Validation Dataset:", val_dataset)
+
+"""
+## PointNet model
+
+The figure below depicts the internals of the PointNet model family:
+
+
+
+Given that PointNet is meant to consume an ***unordered set*** of coordinates as its input data,
+its architecture needs to match the following characteristic properties
+of point cloud data:
+
+### Permutation invariance
+
+Given the unstructured nature of point cloud data, a scan made up of `n` points has `n!`
+permutations. The subsequent data processing must be invariant to the different
+representations. In order to make PointNet invariant to input permutations, we use a
+symmetric function (such as max-pooling) once the `n` input points are mapped to
+higher-dimensional space. The result is a **global feature vector** that aims to capture
+an aggregate signature of the `n` input points. The global feature vector is used alongside
+local point features for segmentation.
+
+
+
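+A minimal illustration (not part of the original example): any symmetric function, such
+as max-pooling, produces the same output regardless of point order.
+
+```python
+pts = np.random.rand(8, 3).astype("float32")
+shuffled = np.random.permutation(pts)
+print(np.allclose(pts.max(axis=0), shuffled.max(axis=0)))  # True
+```
+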
+### Transformation invariance
+
+Segmentation outputs should be unchanged if the object undergoes certain transformations,
+such as translation or scaling. For a given input point cloud, we apply an appropriate
+rigid or affine transformation to achieve pose normalization. Because each of the `n` input
+points is represented as a vector and is mapped to the embedding space independently,
+applying a geometric transformation simply amounts to matrix-multiplying each point with
+a transformation matrix. This is motivated by the concept of
+[Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
+
+The operations comprising the T-Net are motivated by the higher-level architecture of
+PointNet. MLPs (or fully-connected layers) are used to map the input points independently
+and identically to a higher-dimensional space; max-pooling is used to encode a global
+feature vector whose dimensionality is then reduced with fully-connected layers. The
+input-dependent features at the final fully-connected layer are then combined with
+globally trainable weights and biases, resulting in a 3-by-3 transformation matrix.
+
+
+
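+Concretely (an illustrative sketch, not from the paper), applying the predicted transform
+is just a batched matrix multiplication of the points with the matrix, which is what
+`layers.Dot(axes=(2, 1))` computes later in this example:
+
+```python
+pts = np.random.rand(1, 1024, 3).astype("float32")  # (batch, n, 3)
+T = np.eye(3, dtype="float32")[None, ...]           # one 3x3 transform per batch item
+transformed = np.matmul(pts, T)                     # shape: (1, 1024, 3)
+```
+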
+### Point interactions
+
+The interaction between neighboring points often carries useful information (i.e., a
+single point should not be treated in isolation). Whereas classification need only make
+use of global features, segmentation must be able to leverage local point features along
+with global point features.
+
+
+**Note**: The figures presented in this section have been taken from the
+[original paper](https://arxiv.org/abs/1612.00593).
+"""
+
+"""
+Now that we know the pieces that compose the PointNet model, we can implement the model.
+We start by implementing the basic blocks i.e., the convolutional block and the multi-layer
+perceptron block.
+"""
+
+
+def conv_block(x, filters, name):
+ x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
+ x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
+ return layers.Activation("relu", name=f"{name}_relu")(x)
+
+
+def mlp_block(x, filters, name):
+ x = layers.Dense(filters, name=f"{name}_dense")(x)
+ x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
+ return layers.Activation("relu", name=f"{name}_relu")(x)
+
+
+"""
+We implement a regularizer (taken from
+[this example](https://keras.io/examples/vision/pointnet/#build-a-model))
+to enforce orthogonality in the feature space. This is needed to ensure
+that the magnitudes of the transformed features do not vary too much.
+"""
+
+
+class OrthogonalRegularizer(keras.regularizers.Regularizer):
+ """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""
+
+ def __init__(self, num_features, l2reg=0.001):
+ self.num_features = num_features
+ self.l2reg = l2reg
+ self.identity = keras.ops.eye(num_features)
+
+ def __call__(self, x):
+ x = keras.ops.reshape(x, (-1, self.num_features, self.num_features))
+ xxt = keras.ops.tensordot(x, x, axes=(2, 2))
+ xxt = keras.ops.reshape(xxt, (-1, self.num_features, self.num_features))
+ return keras.ops.sum(self.l2reg * keras.ops.square(xxt - self.identity))
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"num_features": self.num_features, "l2reg_strength": self.l2reg})
+ return config
+
+
+"""
+The next piece is the transformation network which we explained earlier.
+"""
+
+
+def transformation_net(inputs, num_features, name):
+ """
+ Reference: https://keras.io/examples/vision/pointnet/#build-a-model.
+
+ The `filters` values come from the original paper:
+ https://arxiv.org/abs/1612.00593.
+ """
+ x = conv_block(inputs, filters=64, name=f"{name}_1")
+ x = conv_block(x, filters=128, name=f"{name}_2")
+ x = conv_block(x, filters=1024, name=f"{name}_3")
+ x = layers.GlobalMaxPooling1D()(x)
+ x = mlp_block(x, filters=512, name=f"{name}_1_1")
+ x = mlp_block(x, filters=256, name=f"{name}_2_1")
+ return layers.Dense(
+ num_features * num_features,
+ kernel_initializer="zeros",
+ bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
+ activity_regularizer=OrthogonalRegularizer(num_features),
+ name=f"{name}_final",
+ )(x)
+
+
+def transformation_block(inputs, num_features, name):
+ transformed_features = transformation_net(inputs, num_features, name=name)
+ transformed_features = layers.Reshape((num_features, num_features))(
+ transformed_features
+ )
+ return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])
+
+
+"""
+Finally, we piece the above blocks together and implement the segmentation model.
+"""
+
+
+def get_shape_segmentation_model(num_points, num_classes):
+ input_points = keras.Input(shape=(None, 3))
+
+ # PointNet Classification Network.
+ transformed_inputs = transformation_block(
+ input_points, num_features=3, name="input_transformation_block"
+ )
+ features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
+ features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
+ features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
+ transformed_features = transformation_block(
+ features_128_2, num_features=128, name="transformed_features"
+ )
+ features_512 = conv_block(transformed_features, filters=512, name="features_512")
+ features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
+ global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
+ features_2048
+ )
+ global_features = keras.ops.tile(global_features, [1, num_points, 1])
+
+ # Segmentation head.
+ segmentation_input = layers.Concatenate(name="segmentation_input")(
+ [
+ features_64,
+ features_128_1,
+ features_128_2,
+ transformed_features,
+ features_512,
+ global_features,
+ ]
+ )
+ segmentation_features = conv_block(
+ segmentation_input, filters=128, name="segmentation_features"
+ )
+ outputs = layers.Conv1D(
+ num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
+ )(segmentation_features)
+ return keras.Model(input_points, outputs)
+
+
+"""
+## Instantiate the model
+"""
+
+x, y = next(iter(train_dataset))
+
+num_points = x.shape[1]
+num_classes = y.shape[-1]
+
+segmentation_model = get_shape_segmentation_model(num_points, num_classes)
+segmentation_model.summary()
+
+"""
+## Training
+
+For training, the authors recommend using a learning rate schedule that decays the
+initial learning rate by half every 20 epochs. In this example, we halve it every 5 epochs.
+"""
+
+steps_per_epoch = total_training_examples // BATCH_SIZE
+total_training_steps = steps_per_epoch * EPOCHS
+print(f"Steps per epoch: {steps_per_epoch}.")
+print(f"Total training steps: {total_training_steps}.")
+
+lr_schedule = keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=0.003,
+ decay_steps=steps_per_epoch * 5,
+ decay_rate=0.5,
+ staircase=True,
+)
+
+steps = range(total_training_steps)
+lrs = [lr_schedule(step) for step in steps]
+
+plt.plot(lrs)
+plt.xlabel("Steps")
+plt.ylabel("Learning Rate")
+plt.show()
+
+"""
+Finally, we implement a utility for running our experiments and launch model training.
+"""
+
+
+def run_experiment(epochs):
+ segmentation_model = get_shape_segmentation_model(num_points, num_classes)
+ segmentation_model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
+ loss=keras.losses.CategoricalCrossentropy(),
+ metrics=["accuracy"],
+ )
+
+ checkpoint_filepath = "checkpoint.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_loss",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ history = segmentation_model.fit(
+ train_dataset,
+ validation_data=val_dataset,
+ epochs=epochs,
+ callbacks=[checkpoint_callback],
+ )
+
+ segmentation_model.load_weights(checkpoint_filepath)
+ return segmentation_model, history
+
+
+segmentation_model, history = run_experiment(epochs=EPOCHS)
+
+"""
+## Visualize the training landscape
+"""
+
+
+def plot_result(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_result("loss")
+plot_result("accuracy")
+
+"""
+## Inference
+"""
+
+validation_batch = next(iter(val_dataset))
+val_predictions = segmentation_model.predict(validation_batch[0])
+print(f"Validation prediction shape: {val_predictions.shape}")
+
+
+def visualize_single_point_cloud(point_clouds, label_clouds, idx):
+ label_map = LABELS + ["none"]
+ point_cloud = point_clouds[idx]
+ label_cloud = label_clouds[idx]
+ visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])
+
+
+idx = np.random.choice(len(validation_batch[0]))
+print(f"Index selected: {idx}")
+
+# Plotting with ground-truth.
+visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)
+
+# Plotting with predicted labels.
+visualize_single_point_cloud(validation_batch[0], val_predictions, idx)
+
+"""
+## Final notes
+
+If you are interested in learning more about this topic, you may find
+[this repository](https://github.com/soumik12345/point-cloud-segmentation)
+useful.
+"""
diff --git a/knowledge_base/vision/probing_vits.py b/knowledge_base/vision/probing_vits.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f7b1a8815c50226a9178fab3e21f0d1bef9cb9
--- /dev/null
+++ b/knowledge_base/vision/probing_vits.py
@@ -0,0 +1,706 @@
+"""
+Title: Investigating Vision Transformer representations
+Authors: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)
+Date created: 2022/04/12
+Last modified: 2023/11/20
+Description: Looking into the representations learned by different Vision Transformer variants.
+Accelerator: None
+"""
+
+"""
+## Introduction
+
+In this example, we look into the representations learned by different Vision
+Transformer (ViT) models. Our main goal with this example is to provide insights into
+what empowers ViTs to learn from image data. In particular, the example discusses
+implementations of a few different ViT analysis tools.
+
+**Note:** when we say "Vision Transformer", we refer to a computer vision architecture that
+involves Transformer blocks ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) and not
+necessarily the original Vision Transformer model
+([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)).
+
+## Models considered
+
+Since the inception of the original Vision Transformer, the computer vision community has
+seen a number of different ViT variants improving upon the original in various ways:
+training improvements, architecture improvements, and so on.
+In this example, we consider the following ViT model families:
+
+* ViTs trained using supervised pretraining with the ImageNet-1k and ImageNet-21k
+datasets ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929))
+* ViTs trained using supervised pretraining but only with the ImageNet-1k dataset with
+more regularization and distillation ([Touvron et al.](https://arxiv.org/abs/2012.12877))
+(DeiT).
+* ViTs trained using self-supervised pretraining ([Caron et al.](https://arxiv.org/abs/2104.14294))
+(DINO).
+
+Since the pretrained models are not implemented in Keras, we first implemented them as
+faithfully as possible. We then populated them with the official pretrained parameters.
+Finally, we evaluated our implementations on the ImageNet-1k validation set to ensure the
+evaluation numbers were matching with the original implementations. The details of our implementations
+are available in [this repository](https://github.com/sayakpaul/probing-vits).
+
+To keep the example concise, we won't exhaustively pair each model with the analysis
+methods. We'll provide notes in the respective sections so that you can pick up the
+pieces.
+
+To run this example on Google Colab, we need to update the `gdown` library like so:
+
+```shell
+pip install -U gdown -q
+```
+"""
+
+"""
+## Imports
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import zipfile
+from io import BytesIO
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import requests
+
+from PIL import Image
+from sklearn.preprocessing import MinMaxScaler
+import keras
+from keras import ops
+
+"""
+## Constants
+"""
+
+RESOLUTION = 224
+PATCH_SIZE = 16
+GITHUB_RELEASE = "https://github.com/sayakpaul/probing-vits/releases/download/v1.0.0/probing_vits.zip"
+FNAME = "probing_vits.zip"
+MODELS_ZIP = {
+ "vit_dino_base16": "Probing_ViTs/vit_dino_base16.zip",
+ "vit_b16_patch16_224": "Probing_ViTs/vit_b16_patch16_224.zip",
+ "vit_b16_patch16_224-i1k_pretrained": "Probing_ViTs/vit_b16_patch16_224-i1k_pretrained.zip",
+}
+
+"""
+## Data utilities
+
+For the original ViT models, the input images need to be scaled to the range `[-1, 1]`. For
+the other model families mentioned at the beginning, we need to normalize the images with
+channel-wise mean and standard deviation of the ImageNet-1k training set.
+
+
+crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
+norm_layer = keras.layers.Normalization(
+ mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
+ variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
+)
+rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)
+
+
+def preprocess_image(image, model_type, size=RESOLUTION):
+ # Turn the image into a numpy array and add batch dim.
+ image = np.array(image)
+ image = ops.expand_dims(image, 0)
+
+    # If the model type is the original ViT, rescale the image to [-1, 1].
+ if model_type == "original_vit":
+ image = rescale_layer(image)
+
+ # Resize the image using bicubic interpolation.
+ resize_size = int((256 / 224) * size)
+ image = ops.image.resize(image, (resize_size, resize_size), interpolation="bicubic")
+
+ # Crop the image.
+ image = crop_layer(image)
+
+    # If the model type is DeiT or DINO, normalize the image.
+ if model_type != "original_vit":
+ image = norm_layer(image)
+
+ return ops.convert_to_numpy(image)
+
+
+def load_image_from_url(url, model_type):
+ # Credit: Willi Gierke
+ response = requests.get(url)
+ image = Image.open(BytesIO(response.content))
+ preprocessed_image = preprocess_image(image, model_type)
+ return image, preprocessed_image
+
+
+"""
+## Load a test image and display it
+"""
+
+# Download the ImageNet-1k label mapping file and load it.
+
+mapping_file = keras.utils.get_file(
+ origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
+)
+
+with open(mapping_file, "r") as f:
+ lines = f.readlines()
+imagenet_int_to_str = [line.rstrip() for line in lines]
+
+img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
+image, preprocessed_image = load_image_from_url(img_url, model_type="original_vit")
+
+plt.imshow(image)
+plt.axis("off")
+plt.show()
+
+"""
+## Load a model
+"""
+
+
+zip_path = keras.utils.get_file(
+ fname=FNAME,
+ origin=GITHUB_RELEASE,
+)
+
+with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ zip_ref.extractall("./")
+
+os.rename("Probing ViTs", "Probing_ViTs")
+
+
+def load_model(model_path: str) -> keras.Model:
+ with zipfile.ZipFile(model_path, "r") as zip_ref:
+ zip_ref.extractall("Probing_ViTs/")
+ model_name = model_path.split(".")[0]
+
+ inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
+ model = keras.layers.TFSMLayer(model_name, call_endpoint="serving_default")
+ outputs = model(inputs, training=False)
+
+ return keras.Model(inputs, outputs=outputs)
+
+
+vit_base_i21k_patch16_224 = load_model(MODELS_ZIP["vit_b16_patch16_224-i1k_pretrained"])
+print("Model loaded.")
+
+"""
+**More about the model**:
+
+This model was pretrained on the ImageNet-21k dataset and was then fine-tuned on the
+ImageNet-1k dataset. To learn more about how we developed this model in TensorFlow
+(with pretrained weights from
+[this source](https://github.com/google-research/vision_transformer/)) refer to
+[this notebook](https://github.com/sayakpaul/probing-vits/blob/main/notebooks/load-jax-weights-vitb16.ipynb).
+"""
+
+"""
+## Running regular inference with the model
+
+We now run inference with the loaded model on our test image.
+"""
+
+
+def split_prediction_and_attention_scores(outputs):
+ predictions = outputs["output_1"]
+ attention_score_dict = {}
+ for key, value in outputs.items():
+ if key.startswith("output_2_"):
+ attention_score_dict[key[len("output_2_") :]] = value
+ return predictions, attention_score_dict
+
+
+predictions, attention_score_dict = split_prediction_and_attention_scores(
+ vit_base_i21k_patch16_224.predict(preprocessed_image)
+)
+predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]
+print(predicted_label)
+
+"""
+`attention_score_dict` contains the attention scores (softmaxed outputs) from each
+attention head of each Transformer block.
+"""
+
+"""
+## Method I: Mean attention distance
+
+[Dosovitskiy et al.](https://arxiv.org/abs/2010.11929) and
+[Raghu et al.](https://arxiv.org/abs/2108.08810) use a measure called
+"mean attention distance" from each attention head of different
+Transformer blocks to understand how local and global information flows
+into Vision Transformers.
+
+Mean attention distance is defined as the distance between query tokens and the other
+tokens, weighted by the attention scores. So, for a single image
+
+* we take individual patches (tokens) extracted from the image,
+* calculate their geometric distance, and
+* multiply that with the attention scores.
+
+Attention scores are calculated here after forward passing the image in inference mode
+through the network. The following figure may help you understand the process a
+little bit better.
+
+
+
+This animation is created by [Ritwik Raha](https://twitter.com/ritwik_raha).
+"""
+
+
+def compute_distance_matrix(patch_size, num_patches, length):
+ distance_matrix = np.zeros((num_patches, num_patches))
+ for i in range(num_patches):
+ for j in range(num_patches):
+ if i == j: # zero distance
+ continue
+
+ xi, yi = (int(i / length)), (i % length)
+ xj, yj = (int(j / length)), (j % length)
+ distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])
+
+ return distance_matrix
+
+
+def compute_mean_attention_dist(patch_size, attention_weights, model_type):
+ num_cls_tokens = 2 if "distilled" in model_type else 1
+
+ # The attention_weights shape = (batch, num_heads, num_patches, num_patches)
+ attention_weights = attention_weights[
+ ..., num_cls_tokens:, num_cls_tokens:
+ ] # Removing the CLS token
+ num_patches = attention_weights.shape[-1]
+ length = int(np.sqrt(num_patches))
+ assert length**2 == num_patches, "Num patches is not perfect square"
+
+ distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
+ h, w = distance_matrix.shape
+
+ distance_matrix = distance_matrix.reshape((1, 1, h, w))
+ # The attention_weights along the last axis adds to 1
+ # this is due to the fact that they are softmax of the raw logits
+ # summation of the (attention_weights * distance_matrix)
+ # should result in an average distance per token.
+ mean_distances = attention_weights * distance_matrix
+ mean_distances = np.sum(
+ mean_distances, axis=-1
+ ) # Sum along last axis to get average distance per token
+ mean_distances = np.mean(
+ mean_distances, axis=-1
+ ) # Now average across all the tokens
+
+ return mean_distances
+
+
+"""
+Thanks to [Simon Kornblith](https://scholar.google.com/citations?user=1O3RPmsAAAAJ&hl=en)
+from Google who helped us with this code snippet. It can be found
+[here](https://gist.github.com/simonster/155894d48aef2bd36bd2dd8267e62391). Let's now use
+these utilities to generate a plot of attention distances with our loaded model and test
+image.
+"""
+
+# Build the mean distances for every Transformer block.
+mean_distances = {
+ f"{name}_mean_dist": compute_mean_attention_dist(
+ patch_size=PATCH_SIZE,
+ attention_weights=attention_weight,
+ model_type="original_vit",
+ )
+ for name, attention_weight in attention_score_dict.items()
+}
+
+# Get the number of heads from the mean distance output.
+num_heads = mean_distances["transformer_block_0_att_mean_dist"].shape[-1]
+
+# Print the shapes
+print(f"Num Heads: {num_heads}.")
+
+plt.figure(figsize=(9, 9))
+
+for idx in range(len(mean_distances)):
+ mean_distance = mean_distances[f"transformer_block_{idx}_att_mean_dist"]
+ x = [idx] * num_heads
+ y = mean_distance[0, :]
+ plt.scatter(x=x, y=y, label=f"transformer_block_{idx}")
+
+plt.legend(loc="lower right")
+plt.xlabel("Transformer Blocks", fontsize=14)
+plt.ylabel("Attention Distance", fontsize=14)
+plt.title("vit_base_i21k_patch16_224", fontsize=14)
+plt.grid()
+plt.show()
+
+"""
+### Inspecting the plots
+
+**How does self-attention span the input space? Do the attention heads attend
+to input regions locally or globally?**
+
+The promise of self-attention is to enable the learning of contextual dependencies
+so that a model can attend to the regions of inputs which are the most salient w.r.t.
+the objective. From the above plots we can notice that different attention heads yield
+different attention distances, suggesting they use both local and global information
+from an image. But as we go deeper into the Transformer blocks, the heads tend to
+focus more on global aggregate information.
+
+Inspired by [Raghu et al.](https://arxiv.org/abs/2108.08810) we computed mean attention
+distances over 1000 images randomly taken from the ImageNet-1k validation set and we
+repeated the process for all the models mentioned at the beginning. Interestingly, we
+notice the following:
+
+* Pretraining with a larger dataset helps with more global attention spans:
+
+
+| Pretrained on ImageNet-21k, fine-tuned on ImageNet-1k | Pretrained on ImageNet-1k |
+| :--: | :--: |
+|  |  |
+
+
+* When distilled from a CNN, ViTs tend to have less global attention spans:
+
+
+| No distillation (ViT B-16 from DeiT) | Distilled ViT B-16 from DeiT |
+| :--: | :--: |
+|  |  |
+
+To reproduce these plots, please refer to
+[this notebook](https://github.com/sayakpaul/probing-vits/blob/main/notebooks/mean-attention-distance-1k.ipynb).
+
+"""
+
+"""
+## Method II: Attention Rollout
+
+[Abnar et al.](https://arxiv.org/abs/2005.00928) introduce "Attention rollout" for
+quantifying how information flows through the self-attention layers of Transformer blocks.
+The original ViT authors used this method to investigate the learned representations, stating:
+
+> Briefly, we averaged attention weights of ViTL/16 across all heads and then recursively
+multiplied the weight matrices of all layers. This accounts for the mixing of attention
+across tokens through all layers.
+
+We used
+[this notebook](https://colab.research.google.com/github/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb)
+and modified the attention rollout code from it for compatibility with our models.
+"""
+
+
+def attention_rollout_map(image, attention_score_dict, model_type):
+ num_cls_tokens = 2 if "distilled" in model_type else 1
+
+ # Stack the individual attention matrices from individual Transformer blocks.
+ attn_mat = ops.stack([attention_score_dict[k] for k in attention_score_dict.keys()])
+ attn_mat = ops.squeeze(attn_mat, axis=1)
+
+ # Average the attention weights across all heads.
+ attn_mat = ops.mean(attn_mat, axis=1)
+
+ # To account for residual connections, we add an identity matrix to the
+ # attention matrix and re-normalize the weights.
+ residual_attn = ops.eye(attn_mat.shape[1])
+ aug_attn_mat = attn_mat + residual_attn
+ aug_attn_mat = aug_attn_mat / ops.sum(aug_attn_mat, axis=-1)[..., None]
+ aug_attn_mat = ops.convert_to_numpy(aug_attn_mat)
+
+ # Recursively multiply the weight matrices.
+ joint_attentions = np.zeros(aug_attn_mat.shape)
+ joint_attentions[0] = aug_attn_mat[0]
+
+ for n in range(1, aug_attn_mat.shape[0]):
+ joint_attentions[n] = np.matmul(aug_attn_mat[n], joint_attentions[n - 1])
+
+ # Attention from the output token to the input space.
+ v = joint_attentions[-1]
+ grid_size = int(np.sqrt(aug_attn_mat.shape[-1]))
+ mask = v[0, num_cls_tokens:].reshape(grid_size, grid_size)
+ mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
+ result = (mask * image).astype("uint8")
+ return result
+
+
+"""
+Let's now use these utilities to generate an attention plot based on our previous results
+from the "Running regular inference with the model" section. Following are the links to
+download each individual model:
+
+* [Original ViT model (pretrained on ImageNet-21k)](https://drive.google.com/file/d/1mbtnliT3jRb3yJUHhbItWw8unfYZw8KJ/view?usp=sharing)
+* [Original ViT model (pretrained on ImageNet-1k)](https://drive.google.com/file/d/1ApOdYe4NXxhPhJABefgZ3KVvqsQzhCL7/view?usp=sharing)
+* [DINO model (pretrained on ImageNet-1k)](https://drive.google.com/file/d/16_1oDm0PeCGJ_KGBG5UKVN7TsAtiRNrN/view?usp=sharing)
+* [DeiT models (pretrained on ImageNet-1k including distilled and non-distilled ones)](https://tfhub.dev/sayakpaul/collections/deit/1)
+"""
+
+attn_rollout_result = attention_rollout_map(
+ image, attention_score_dict, model_type="original_vit"
+)
+
+fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))
+fig.suptitle(f"Predicted label: {predicted_label}.", fontsize=20)
+
+_ = ax1.imshow(image)
+_ = ax2.imshow(attn_rollout_result)
+ax1.set_title("Input Image", fontsize=16)
+ax2.set_title("Attention Map", fontsize=16)
+ax1.axis("off")
+ax2.axis("off")
+
+fig.tight_layout()
+fig.subplots_adjust(top=1.35)
+fig.show()
+
+"""
+### Inspecting the plots
+
+**How can we quantify the information flow that propagates through the
+attention layers?**
+
+We notice that the model is able to focus its attention on the
+salient parts of the input image. We encourage you to apply this
+method to the other models we mentioned and compare the results. The
+attention rollout plots will differ according to the tasks and
+augmentation the model was trained with. We observe that DeiT has the
+best rollout plot, likely due to its augmentation regime.
+"""
+
+"""
+## Method III: Attention heatmaps
+
+A simple yet useful way to probe into the representations of a Vision Transformer is to
+visualize the attention maps overlaid on the input images. This helps form an intuition
+about what the model attends to. We use the DINO model for this purpose, because it
+yields better attention heatmaps.
+"""
+
+# Load the model.
+vit_dino_base16 = load_model(MODELS_ZIP["vit_dino_base16"])
+print("Model loaded.")
+
+# Preprocess the same image, but with normalization.
+img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
+image, preprocessed_image = load_image_from_url(img_url, model_type="dino")
+
+# Grab the predictions.
+predictions, attention_score_dict = split_prediction_and_attention_scores(
+ vit_dino_base16.predict(preprocessed_image)
+)
+
+"""
+A Transformer block consists of multiple heads. Each head in a Transformer block projects
+the input data into a different sub-space. This helps each individual head attend to
+different parts of the image. Therefore, it is useful to visualize each attention head's
+map separately, to understand what each head looks at.
+
+**Notes**:
+
+* The following code has been copy-modified from the
+[original DINO codebase](https://github.com/facebookresearch/dino/blob/main/visualize_attention.py).
+* Here we grab the attention maps of the last Transformer block.
+* [DINO](https://arxiv.org/abs/2104.14294) was pretrained using a self-supervised
+objective.
+"""
+
+
+def attention_heatmap(attention_score_dict, image, model_type="dino"):
+ num_tokens = 2 if "distilled" in model_type else 1
+
+ # Sort the Transformer blocks in order of their depth.
+ attention_score_list = list(attention_score_dict.keys())
+ attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)
+
+ # Process the attention maps for overlay.
+ w_featmap = image.shape[2] // PATCH_SIZE
+ h_featmap = image.shape[1] // PATCH_SIZE
+ attention_scores = attention_score_dict[attention_score_list[0]]
+
+ # Taking the representations from CLS token.
+ attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)
+
+ # Reshape the attention scores to resemble mini patches.
+ attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
+ attentions = attentions.transpose((1, 2, 0))
+
+ # Resize the attention patches to 224x224 (224: 14x16).
+ attentions = ops.image.resize(
+ attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
+ )
+ return attentions
+
+
+"""
+We can use the same image we used for inference with DINO and the `attention_score_dict`
+we extracted from the results.
+"""
+
+# De-normalize the image for visual clarity.
+in1k_mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
+in1k_std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])
+preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
+preprocessed_img_orig = preprocessed_img_orig / 255.0
+preprocessed_img_orig = ops.convert_to_numpy(ops.clip(preprocessed_img_orig, 0.0, 1.0))
+
+# Generate the attention heatmaps.
+attentions = attention_heatmap(attention_score_dict, preprocessed_img_orig)
+
+# Plot the maps.
+fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
+img_count = 0
+
+for i in range(3):
+ for j in range(4):
+ if img_count < len(attentions):
+ axes[i, j].imshow(preprocessed_img_orig[0])
+ axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
+ axes[i, j].title.set_text(f"Attention head: {img_count}")
+ axes[i, j].axis("off")
+ img_count += 1
+
+"""
+### Inspecting the plots
+
+**How can we qualitatively evaluate the attention weights?**
+
+The attention weights of a Transformer block are computed between the
+key and the query. The weights quantify how important the key is to the query.
+In ViTs, the key and the query come from the same image, so
+the weights determine which parts of the image are important.
+
+Plotting the attention weights overlaid on the image gives us great
+intuition about the parts of the image that are important to the Transformer,
+and qualitatively evaluates what the attention weights capture.
+"""
+
+"""
+## Method IV: Visualizing the learned projection filters
+
+After extracting non-overlapping patches, ViTs flatten those patches across their
+spatial dimensions, and then linearly project them. One might wonder what these
+projections look like. Below, we take the ViT B-16 model and visualize its
+learned projections.
+"""
+
+
+def extract_weights(model, name):
+ for variable in model.weights:
+ if variable.name.startswith(name):
+ return variable.numpy()
+
+
+# Extract the projections.
+projections = extract_weights(vit_base_i21k_patch16_224, "conv_projection/kernel")
+projection_dim = projections.shape[-1]
+patch_h, patch_w, patch_channels = projections.shape[:-1]
+
+# Scale the projections.
+scaled_projections = MinMaxScaler().fit_transform(
+ projections.reshape(-1, projection_dim)
+)
+
+# Reshape the scaled projections so that the leading
+# three dimensions resemble an image.
+scaled_projections = scaled_projections.reshape(patch_h, patch_w, patch_channels, -1)
+
+# Visualize the first 128 filters of the learned
+# projections.
+fig, axes = plt.subplots(nrows=8, ncols=16, figsize=(13, 8))
+img_count = 0
+limit = 128
+
+for i in range(8):
+ for j in range(16):
+ if img_count < limit:
+ axes[i, j].imshow(scaled_projections[..., img_count])
+ axes[i, j].axis("off")
+ img_count += 1
+
+fig.tight_layout()
+
+"""
+### Inspecting the plots
+
+**What do the projection filters learn?**
+
+[When visualized](https://distill.pub/2017/feature-visualization/),
+the kernels of a convolutional neural network show
+the patterns that they look for in an image. These could be circles or
+lines; when combined together (in the later stages of a ConvNet), the filters
+transform into more complex shapes. We have found a stark similarity between such
+ConvNet kernels and the projection filters of a ViT.
+"""
+
+"""
+## Method V: Visualizing the positional embeddings
+
+Transformers are permutation-invariant. This means that they do not take the
+spatial position of the input tokens into account. To overcome this
+limitation, we add positional information to the input tokens.
+
+The positional information can be in the form of learned positional
+embeddings or handcrafted constant embeddings. In our case, all
+three variants of ViTs feature learned positional embeddings.
+
+In this section, we visualize the pairwise similarities of the
+learned positional embeddings. Below, we take the ViT B-16
+model and visualize the similarities of the positional embeddings by
+taking their dot product.
+"""
+
+position_embeddings = extract_weights(vit_base_i21k_patch16_224, "pos_embedding")
+
+# Discard the batch dimension and the position embeddings of the
+# cls token.
+position_embeddings = position_embeddings.squeeze()[1:, ...]
+
+similarity = position_embeddings @ position_embeddings.T
+plt.imshow(similarity, cmap="inferno")
+plt.show()
+
+"""
+### Inspecting the plots
+
+**What do the positional embeddings tell us?**
+
+The plot has a distinctive diagonal pattern. The main diagonal is the brightest,
+signifying that a position is most similar to itself. An interesting
+pattern to look out for is the repeating diagonals. The repeating pattern
+portrays a sinusoidal function, which is close in essence to what was proposed by
+[Vaswani et al.](https://arxiv.org/abs/1706.03762) as a hand-crafted feature.
+"""
+
+"""
+## Notes
+
+* DINO extended the attention heatmap generation process to videos. We also
+[applied](https://github.com/sayakpaul/probing-vits/blob/main/notebooks/dino-attention-maps-video.ipynb)
+our DINO implementation on a series of videos and obtained similar
+results. Here's one such video of attention heatmaps:
+
+ 
+
+* [Raghu et al.](https://arxiv.org/abs/2108.08810) use an array of techniques to
+investigate the representations learned by ViTs and make comparisons with that of
+ResNets. We strongly recommend reading their work.
+
+* To author this example, we developed
+[this repository](https://github.com/sayakpaul/probing-vits) to guide our readers so that they
+can easily reproduce the experiments and extend them.
+
+* Another repository that you may find interesting in this regard is
+[`vit-explain`](https://github.com/jacobgil/vit-explain).
+
+* One can also plot the attention rollout and attention heat maps with
+custom images using our Hugging Face spaces.
+
+| Attention Heat Maps | Attention Rollout |
+| :--: | :--: |
+| [](https://huggingface.co/spaces/probing-vits/attention-heat-maps) | [](https://huggingface.co/spaces/probing-vits/attention-rollout) |
+"""
+
+"""
+## Acknowledgements
+
+- [PyImageSearch](https://pyimagesearch.com)
+- [Jarvislabs.ai](https://jarvislabs.ai/)
+- [GDE Program](https://developers.google.com/programs/experts/)
+"""
diff --git a/knowledge_base/vision/randaugment.py b/knowledge_base/vision/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..1196c2872b308b6ffbee0bf60d62fb2774674e28
--- /dev/null
+++ b/knowledge_base/vision/randaugment.py
@@ -0,0 +1,296 @@
+"""
+Title: RandAugment for Image Classification for Improved Robustness
+Authors: [Sayak Paul](https://twitter.com/RisingSayak), [Sachin Prasad](https://github.com/sachinprasadhs)
+Date created: 2021/03/13
+Last modified: 2023/12/12
+Description: RandAugment for training an image classification model with improved robustness.
+Accelerator: GPU
+"""
+
+"""
+Data augmentation is a very useful technique that can help to improve the translational
+invariance of convolutional neural networks (CNNs). RandAugment is a stochastic data
+augmentation routine for vision data and was proposed in
+[RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719).
+It is composed of strong augmentation transforms like color jitters, Gaussian blurs,
+saturations, etc. along with more traditional augmentation transforms such as
+random crops.
+
+RandAugment has two tunable parameters: `n`, the number of augmentation transforms
+applied to each image, and `m`, the magnitude (strength) of those transforms. These
+parameters are tuned for a given dataset and network architecture. The authors of
+RandAugment also provide pseudocode for RandAugment in the original paper (Figure 2).
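+
+A minimal sketch of that policy (with hypothetical `transforms` functions, not the
+exact implementation used below) looks like this:
+
+```python
+import random
+
+def rand_augment(image, transforms, n=2, magnitude=9):
+    # Apply `n` randomly chosen transforms, each at strength `magnitude`.
+    for op in random.choices(transforms, k=n):
+        image = op(image, magnitude)
+    return image
+```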
+
+Recently, it has been a key component of works like
+[Noisy Student Training](https://arxiv.org/abs/1911.04252) and
+[Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848).
+It has also been central to the
+success of [EfficientNets](https://arxiv.org/abs/1905.11946).
+
+This example requires the `keras_cv` package, which you can install with:
+
+```shell
+pip install keras-cv
+```
+"""
+
+"""
+## Imports & setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import keras_cv
+from keras import ops
+from keras import layers
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+import tensorflow_datasets as tfds
+
+tfds.disable_progress_bar()
+keras.utils.set_random_seed(42)
+
+"""
+## Load the CIFAR10 dataset
+
+For this example, we will be using the
+[CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+print(f"Total training examples: {len(x_train)}")
+print(f"Total test examples: {len(x_test)}")
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 128
+EPOCHS = 1
+IMAGE_SIZE = 72
+
+"""
+## Initialize `RandAugment` object
+
+Now, we will initialize a `RandAugment` object from the `keras_cv.layers` module with
+the parameters suggested by the RandAugment authors.
+"""
+
+rand_augment = keras_cv.layers.RandAugment(
+ value_range=(0, 255), augmentations_per_image=3, magnitude=0.8
+)
+
+
+"""
+## Create TensorFlow `Dataset` objects
+"""
+
+train_ds_rand = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(BATCH_SIZE * 100)
+ .batch(BATCH_SIZE)
+ .map(
+ lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
+ num_parallel_calls=AUTO,
+ )
+ .map(
+ lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),
+ num_parallel_calls=AUTO,
+ )
+ .prefetch(AUTO)
+)
+
+test_ds = (
+ tf.data.Dataset.from_tensor_slices((x_test, y_test))
+ .batch(BATCH_SIZE)
+ .map(
+ lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
+ num_parallel_calls=AUTO,
+ )
+ .prefetch(AUTO)
+)
+
+"""
+For comparison purposes, let's also define a simple augmentation pipeline consisting of
+random flips, random rotations, and random zoomings.
+"""
+
+simple_aug = keras.Sequential(
+ [
+ layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(factor=0.02),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ]
+)
+
+# Now, map the augmentation pipeline to our training dataset
+train_ds_simple = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .shuffle(BATCH_SIZE * 100)
+ .batch(BATCH_SIZE)
+ .map(lambda x, y: (simple_aug(x), y), num_parallel_calls=AUTO)
+ .prefetch(AUTO)
+)
+
+"""
+## Visualize the dataset augmented with RandAugment
+"""
+
+sample_images, _ = next(iter(train_ds_rand))
+plt.figure(figsize=(10, 10))
+for i, image in enumerate(sample_images[:9]):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().astype("int"))
+ plt.axis("off")
+
+"""
+You are encouraged to run the above code block a couple of times to see different
+variations.
+"""
+
+"""
+## Visualize the dataset augmented with `simple_aug`
+"""
+
+sample_images, _ = next(iter(train_ds_simple))
+plt.figure(figsize=(10, 10))
+for i, image in enumerate(sample_images[:9]):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(image.numpy().astype("int"))
+ plt.axis("off")
+
+"""
+## Define a model building utility function
+
+Now, we define a CNN model that is based on the
+[ResNet50V2 architecture](https://arxiv.org/abs/1603.05027). Also,
+notice that the network already has a rescaling layer inside it. This eliminates the need
+to do any separate preprocessing on our dataset and is specifically very useful for
+deployment purposes.
+"""
+
+
+def get_training_model():
+ resnet50_v2 = keras.applications.ResNet50V2(
+ weights=None,
+ include_top=True,
+ input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
+ classes=10,
+ )
+ model = keras.Sequential(
+ [
+ layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
+ layers.Rescaling(scale=1.0 / 127.5, offset=-1),
+ resnet50_v2,
+ ]
+ )
+ return model
+
+
+get_training_model().summary()
+
+
+"""
+We will train this network on two different versions of our dataset:
+
+* One augmented with RandAugment.
+* Another one augmented with `simple_aug`.
+
+Since RandAugment is known to enhance the robustness of models to common perturbations
+and corruptions, we will also evaluate our models on the CIFAR-10-C dataset, proposed in
+[Benchmarking Neural Network Robustness to Common Corruptions and Perturbations](https://arxiv.org/abs/1903.12261)
+by Hendrycks et al. The CIFAR-10-C dataset
+consists of 19 different image corruptions and perturbations (for example speckle noise,
+fog, Gaussian blur, etc.), each at varying severity levels. For this example we will
+be using the following configuration:
+[`cifar10_corrupted/saturate_5`](https://www.tensorflow.org/datasets/catalog/cifar10_corrupted#cifar10_corruptedsaturate_5).
+The images from this configuration look like so:
+
+
+
+In the interest of reproducibility, we serialize the initial random weights of our
+network.
+"""
+
+initial_model = get_training_model()
+initial_model.save_weights("initial.weights.h5")
+
+"""
+## Train model with RandAugment
+"""
+
+rand_aug_model = get_training_model()
+rand_aug_model.load_weights("initial.weights.h5")
+rand_aug_model.compile(
+ loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
+)
+rand_aug_model.fit(train_ds_rand, validation_data=test_ds, epochs=EPOCHS)
+_, test_acc = rand_aug_model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+## Train model with `simple_aug`
+"""
+
+simple_aug_model = get_training_model()
+simple_aug_model.load_weights("initial.weights.h5")
+simple_aug_model.compile(
+ loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
+)
+simple_aug_model.fit(train_ds_simple, validation_data=test_ds, epochs=EPOCHS)
+_, test_acc = simple_aug_model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+## Load the CIFAR-10-C dataset and evaluate performance
+"""
+
+# Load and prepare the CIFAR-10-C dataset
+# (If it's not already downloaded, it takes ~10 minutes of time to download)
+cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test", as_supervised=True)
+cifar_10_c = cifar_10_c.batch(BATCH_SIZE).map(
+ lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
+ num_parallel_calls=AUTO,
+)
+
+# Evaluate `rand_aug_model`
+_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
+print(
+ "Accuracy with RandAugment on CIFAR-10-C (saturate_5): {:.2f}%".format(
+ test_acc * 100
+ )
+)
+
+# Evaluate `simple_aug_model`
+_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
+print(
+ "Accuracy with simple_aug on CIFAR-10-C (saturate_5): {:.2f}%".format(
+ test_acc * 100
+ )
+)
+
+"""
+For the purpose of this example, we trained the models for only a single epoch. On the
+CIFAR-10-C dataset, the model trained with RandAugment can achieve higher accuracy
+(for example, 76.64% in one experiment) than the model trained with `simple_aug`
+(e.g., 64.80%). RandAugment can also help stabilize the training.
+
+In the notebook, you may notice that, at the expense of increased training time with
+RandAugment, we are able to carve out far better performance on the CIFAR-10-C dataset.
+You can experiment with the other corruption and perturbation settings that come with
+the CIFAR-10-C dataset and see if RandAugment helps.
+
+You can also experiment with different values of `n` (`augmentations_per_image`) and
+`m` (`magnitude`) in the `RandAugment` object. In the
+[original paper](https://arxiv.org/abs/1909.13719), the authors show
+the impact of the individual augmentation transforms for a particular task and a range of
+ablation studies. You are welcome to check them out.
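+
+For instance, a weaker policy (hypothetical settings, shown only for illustration)
+could look like this:
+
+```python
+rand_augment_weak = keras_cv.layers.RandAugment(
+    value_range=(0, 255), augmentations_per_image=2, magnitude=0.3
+)
+```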
+
+RandAugment has shown great progress in improving the robustness of deep models for
+computer vision as shown in works like [Noisy Student Training](https://arxiv.org/abs/1911.04252) and
+[FixMatch](https://arxiv.org/abs/2001.07685). This makes RandAugment quite a useful
+recipe for training different vision models.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/randaugment)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/randaugment).
+"""
diff --git a/knowledge_base/vision/reptile.py b/knowledge_base/vision/reptile.py
new file mode 100644
index 0000000000000000000000000000000000000000..184e4b81771df47fbe1d8b8632c7f8f030c8a019
--- /dev/null
+++ b/knowledge_base/vision/reptile.py
@@ -0,0 +1,296 @@
+"""
+Title: Few-Shot learning with Reptile
+Author: [ADMoreau](https://github.com/ADMoreau)
+Date created: 2020/05/21
+Last modified: 2023/07/20
+Description: Few-shot classification on the Omniglot dataset using Reptile.
+Accelerator: GPU
+Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
+"""
+
+"""
+## Introduction
+
+The [Reptile](https://arxiv.org/abs/1803.02999) algorithm was developed by OpenAI to
+perform model-agnostic meta-learning. Specifically, this algorithm was designed to
+quickly learn to perform new tasks with minimal training (few-shot learning).
+The algorithm works by performing Stochastic Gradient Descent using the
+difference between the weights trained on a mini-batch of never-seen-before data and the
+model weights prior to training, over a fixed number of meta-iterations.
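+
+A minimal sketch of the meta-update (with a hypothetical `sgd_on_task` helper that
+runs a few inner SGD steps on a newly sampled task) is:
+
+```python
+for _ in range(meta_iters):
+    old_weights = model.get_weights()
+    adapted_weights = sgd_on_task(model)  # inner-loop SGD on one task
+    # Move the meta-weights a fraction of the way toward the adapted weights.
+    model.set_weights(
+        [old + meta_step_size * (new - old)
+         for old, new in zip(old_weights, adapted_weights)]
+    )
+```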
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+import matplotlib.pyplot as plt
+import numpy as np
+import random
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+"""
+## Define the Hyperparameters
+"""
+
+learning_rate = 0.003
+meta_step_size = 0.25
+
+inner_batch_size = 25
+eval_batch_size = 25
+
+meta_iters = 2000
+eval_iters = 5
+inner_iters = 4
+
+eval_interval = 1
+train_shots = 20
+shots = 5
+classes = 5
+
+"""
+## Prepare the data
+
+The [Omniglot dataset](https://github.com/brendenlake/omniglot/) is a dataset of 1,623
+characters taken from 50 different alphabets, with 20 examples for each character.
+The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the
+few-shot learning task, `k` samples (or "shots") are drawn randomly from `n` randomly-chosen
+classes. These `n` numerical values are used to create a new set of temporary labels to use
+to test the model's ability to learn a new task given few examples. In other words, if you
+are training on 5 classes, your new class labels will be either 0, 1, 2, 3, or 4.
+Omniglot is a great dataset for this task since there are many different classes to draw
+from, with a reasonable number of samples for each class.
+"""
+
+
+class Dataset:
+    # This class will facilitate the creation of a few-shot dataset
+    # from the Omniglot dataset that can be sampled from quickly, while also
+    # allowing new labels to be created at the same time.
+ def __init__(self, training):
+ # Download the tfrecord files containing the omniglot data and convert to a
+ # dataset.
+ split = "train" if training else "test"
+ ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)
+ # Iterate over the dataset to get each individual image and its class,
+ # and put that data into a dictionary.
+ self.data = {}
+
+ def extraction(image, label):
+ # This function will shrink the Omniglot images to the desired size,
+ # scale pixel values and convert the RGB image to grayscale
+ image = tf.image.convert_image_dtype(image, tf.float32)
+ image = tf.image.rgb_to_grayscale(image)
+ image = tf.image.resize(image, [28, 28])
+ return image, label
+
+ for image, label in ds.map(extraction):
+ image = image.numpy()
+ label = str(label.numpy())
+ if label not in self.data:
+ self.data[label] = []
+ self.data[label].append(image)
+ self.labels = list(self.data.keys())
+
+ def get_mini_dataset(
+ self, batch_size, repetitions, shots, num_classes, split=False
+ ):
+ temp_labels = np.zeros(shape=(num_classes * shots))
+ temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))
+ if split:
+ test_labels = np.zeros(shape=(num_classes))
+ test_images = np.zeros(shape=(num_classes, 28, 28, 1))
+
+ # Get a random subset of labels from the entire label set.
+ label_subset = random.choices(self.labels, k=num_classes)
+ for class_idx, class_obj in enumerate(label_subset):
+ # Use enumerated index value as a temporary label for mini-batch in
+ # few shot learning.
+ temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
+ # If creating a split dataset for testing, select an extra sample from each
+ # label to create the test dataset.
+ if split:
+ test_labels[class_idx] = class_idx
+ images_to_split = random.choices(
+ self.data[label_subset[class_idx]], k=shots + 1
+ )
+ test_images[class_idx] = images_to_split[-1]
+ temp_images[class_idx * shots : (class_idx + 1) * shots] = (
+ images_to_split[:-1]
+ )
+ else:
+ # For each index in the randomly selected label_subset, sample the
+ # necessary number of images.
+ temp_images[class_idx * shots : (class_idx + 1) * shots] = (
+ random.choices(self.data[label_subset[class_idx]], k=shots)
+ )
+
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (temp_images.astype(np.float32), temp_labels.astype(np.int32))
+ )
+ dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)
+ if split:
+ return dataset, test_images, test_labels
+ return dataset
+
+
+import urllib3
+
+urllib3.disable_warnings() # Disable SSL warnings that may happen during download.
+train_dataset = Dataset(training=True)
+test_dataset = Dataset(training=False)
+
+"""
+## Visualize some examples from the dataset
+"""
+
+_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))
+
+sample_keys = list(train_dataset.data.keys())
+
+for a in range(5):
+ for b in range(5):
+ temp_image = train_dataset.data[sample_keys[a]][b]
+ temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)
+ temp_image *= 255
+ temp_image = np.clip(temp_image, 0, 255).astype("uint8")
+ if b == 2:
+ axarr[a, b].set_title("Class : " + sample_keys[a])
+ axarr[a, b].imshow(temp_image, cmap="gray")
+ axarr[a, b].xaxis.set_visible(False)
+ axarr[a, b].yaxis.set_visible(False)
+plt.show()
+
+"""
+## Build the model
+"""
+
+
+def conv_bn(x):
+ x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)
+ x = layers.BatchNormalization()(x)
+ return layers.ReLU()(x)
+
+
+inputs = layers.Input(shape=(28, 28, 1))
+x = conv_bn(inputs)
+x = conv_bn(x)
+x = conv_bn(x)
+x = conv_bn(x)
+x = layers.Flatten()(x)
+outputs = layers.Dense(classes, activation="softmax")(x)
+model = keras.Model(inputs=inputs, outputs=outputs)
+model.compile()
+optimizer = keras.optimizers.SGD(learning_rate=learning_rate)
+
+"""
+## Train the model
+"""
+
+training = []
+testing = []
+for meta_iter in range(meta_iters):
+ frac_done = meta_iter / meta_iters
+ cur_meta_step_size = (1 - frac_done) * meta_step_size
+ # Temporarily save the weights from the model.
+ old_vars = model.get_weights()
+ # Get a sample from the full dataset.
+ mini_dataset = train_dataset.get_mini_dataset(
+ inner_batch_size, inner_iters, train_shots, classes
+ )
+ for images, labels in mini_dataset:
+ with tf.GradientTape() as tape:
+ preds = model(images)
+ loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
+ grads = tape.gradient(loss, model.trainable_weights)
+ optimizer.apply_gradients(zip(grads, model.trainable_weights))
+ new_vars = model.get_weights()
+ # Perform SGD for the meta step.
+ for var in range(len(new_vars)):
+ new_vars[var] = old_vars[var] + (
+ (new_vars[var] - old_vars[var]) * cur_meta_step_size
+ )
+ # After the meta-learning step, reload the newly-trained weights into the model.
+ model.set_weights(new_vars)
+ # Evaluation loop
+ if meta_iter % eval_interval == 0:
+ accuracies = []
+ for dataset in (train_dataset, test_dataset):
+ # Sample a mini dataset from the full dataset.
+ train_set, test_images, test_labels = dataset.get_mini_dataset(
+ eval_batch_size, eval_iters, shots, classes, split=True
+ )
+ old_vars = model.get_weights()
+ # Train on the samples and get the resulting accuracies.
+ for images, labels in train_set:
+ with tf.GradientTape() as tape:
+ preds = model(images)
+ loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
+ grads = tape.gradient(loss, model.trainable_weights)
+ optimizer.apply_gradients(zip(grads, model.trainable_weights))
+ test_preds = model.predict(test_images, verbose=0)
+            test_preds = tf.argmax(test_preds, axis=1).numpy()  # Predicted class per sample.
+ num_correct = (test_preds == test_labels).sum()
+ # Reset the weights after getting the evaluation accuracies.
+ model.set_weights(old_vars)
+ accuracies.append(num_correct / classes)
+ training.append(accuracies[0])
+ testing.append(accuracies[1])
+ if meta_iter % 100 == 0:
+ print(
+ "batch %d: train=%f test=%f" % (meta_iter, accuracies[0], accuracies[1])
+ )
+
+"""
+## Visualize Results
+"""
+
+# First, some preprocessing to smooth the training and testing arrays for display.
+window_length = 100
+train_s = np.r_[
+ training[window_length - 1 : 0 : -1],
+ training,
+ training[-1:-window_length:-1],
+]
+test_s = np.r_[
+ testing[window_length - 1 : 0 : -1], testing, testing[-1:-window_length:-1]
+]
+w = np.hamming(window_length)
+train_y = np.convolve(w / w.sum(), train_s, mode="valid")
+test_y = np.convolve(w / w.sum(), test_s, mode="valid")
+
+# Display the training accuracies.
+x = np.arange(0, len(test_y), 1)
+plt.plot(x, test_y, x, train_y)
+plt.legend(["test", "train"])
+plt.grid()
+
+train_set, test_images, test_labels = test_dataset.get_mini_dataset(
+ eval_batch_size, eval_iters, shots, classes, split=True
+)
+for images, labels in train_set:
+ with tf.GradientTape() as tape:
+ preds = model(images)
+ loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
+ grads = tape.gradient(loss, model.trainable_weights)
+ optimizer.apply_gradients(zip(grads, model.trainable_weights))
+test_preds = model.predict(test_images)
+test_preds = tf.argmax(test_preds, axis=1).numpy()  # Predicted class per sample.
+
+_, axarr = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))
+
+sample_keys = list(train_dataset.data.keys())
+
+for i, ax in zip(range(5), axarr):
+ temp_image = np.stack((test_images[i, :, :, 0],) * 3, axis=2)
+ temp_image *= 255
+ temp_image = np.clip(temp_image, 0, 255).astype("uint8")
+ ax.set_title(
+ "Label : {}, Prediction : {}".format(int(test_labels[i]), test_preds[i])
+ )
+ ax.imshow(temp_image, cmap="gray")
+ ax.xaxis.set_visible(False)
+ ax.yaxis.set_visible(False)
+plt.show()
diff --git a/knowledge_base/vision/retinanet.py b/knowledge_base/vision/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c27685616f19f070940f81f79b3feb4077eb3b97
--- /dev/null
+++ b/knowledge_base/vision/retinanet.py
@@ -0,0 +1,984 @@
+"""
+Title: Object Detection with RetinaNet
+Author: [Srihari Humbarwadi](https://twitter.com/srihari_rh)
+Date created: 2020/05/17
+Last modified: 2023/07/10
+Description: Implementing RetinaNet: Focal Loss for Dense Object Detection.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Object detection is a very important problem in computer
+vision. Here, the model is tasked with localizing the objects present in an
+image, and at the same time, classifying them into different categories.
+Object detection models can be broadly classified into "single-stage" and
+"two-stage" detectors. Two-stage detectors are often more accurate but at the
+cost of being slower. Here in this example, we will implement RetinaNet,
+a popular single-stage detector, which is accurate and runs fast.
+RetinaNet uses a feature pyramid network to efficiently detect objects at
+multiple scales and introduces a new loss, the Focal loss function, to alleviate
+the problem of the extreme foreground-background class imbalance.
+
+**References:**
+
+- [RetinaNet Paper](https://arxiv.org/abs/1708.02002)
+- [Feature Pyramid Network Paper](https://arxiv.org/abs/1612.03144)
+"""
+
+
+import os
+import re
+import zipfile
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+
+import matplotlib.pyplot as plt
+import tensorflow_datasets as tfds
+
+
+"""
+## Downloading the COCO2017 dataset
+
+Training on the entire COCO2017 dataset, which has around 118k images, takes a
+lot of time; hence we will be using a smaller subset of ~500 images for
+training in this example.
+"""
+
+url = "https://github.com/srihari-humbarwadi/datasets/releases/download/v0.1.0/data.zip"
+filename = os.path.join(os.getcwd(), "data.zip")
+keras.utils.get_file(filename, url)
+
+
+with zipfile.ZipFile("data.zip", "r") as z_fp:
+ z_fp.extractall("./")
+
+
+"""
+## Implementing utility functions
+
+Bounding boxes can be represented in multiple ways, the most common formats are:
+
+- Storing the coordinates of the corners `[xmin, ymin, xmax, ymax]`
+- Storing the coordinates of the center and the box dimensions
+`[x, y, width, height]`
+
+Since we require both formats, we will be implementing functions for converting
+between the formats.
+"""
+
+
+def swap_xy(boxes):
+ """Swaps order the of x and y coordinates of the boxes.
+
+ Arguments:
+ boxes: A tensor with shape `(num_boxes, 4)` representing bounding boxes.
+
+ Returns:
+ swapped boxes with shape same as that of boxes.
+ """
+ return tf.stack([boxes[:, 1], boxes[:, 0], boxes[:, 3], boxes[:, 2]], axis=-1)
+
+
+def convert_to_xywh(boxes):
+ """Changes the box format to center, width and height.
+
+ Arguments:
+ boxes: A tensor of rank 2 or higher with a shape of `(..., num_boxes, 4)`
+ representing bounding boxes where each box is of the format
+ `[xmin, ymin, xmax, ymax]`.
+
+ Returns:
+ converted boxes with shape same as that of boxes.
+ """
+ return tf.concat(
+ [(boxes[..., :2] + boxes[..., 2:]) / 2.0, boxes[..., 2:] - boxes[..., :2]],
+ axis=-1,
+ )
+
+
+def convert_to_corners(boxes):
+ """Changes the box format to corner coordinates
+
+ Arguments:
+ boxes: A tensor of rank 2 or higher with a shape of `(..., num_boxes, 4)`
+ representing bounding boxes where each box is of the format
+ `[x, y, width, height]`.
+
+ Returns:
+ converted boxes with shape same as that of boxes.
+ """
+ return tf.concat(
+ [boxes[..., :2] - boxes[..., 2:] / 2.0, boxes[..., :2] + boxes[..., 2:] / 2.0],
+ axis=-1,
+ )
+
+
+"""
+## Computing pairwise Intersection Over Union (IOU)
+
+As we will see later in the example, we will be assigning ground truth boxes
+to anchor boxes based on the extent of overlap. This will require us to
+calculate the Intersection Over Union (IOU) between all pairs of anchor
+boxes and ground truth boxes.
+"""
+
+
+def compute_iou(boxes1, boxes2):
+ """Computes pairwise IOU matrix for given two sets of boxes
+
+ Arguments:
+ boxes1: A tensor with shape `(N, 4)` representing bounding boxes
+ where each box is of the format `[x, y, width, height]`.
+ boxes2: A tensor with shape `(M, 4)` representing bounding boxes
+ where each box is of the format `[x, y, width, height]`.
+
+ Returns:
+ pairwise IOU matrix with shape `(N, M)`, where the value at ith row
+ jth column holds the IOU between ith box and jth box from
+ boxes1 and boxes2 respectively.
+ """
+ boxes1_corners = convert_to_corners(boxes1)
+ boxes2_corners = convert_to_corners(boxes2)
+ lu = tf.maximum(boxes1_corners[:, None, :2], boxes2_corners[:, :2])
+ rd = tf.minimum(boxes1_corners[:, None, 2:], boxes2_corners[:, 2:])
+ intersection = tf.maximum(0.0, rd - lu)
+ intersection_area = intersection[:, :, 0] * intersection[:, :, 1]
+ boxes1_area = boxes1[:, 2] * boxes1[:, 3]
+ boxes2_area = boxes2[:, 2] * boxes2[:, 3]
+ union_area = tf.maximum(
+ boxes1_area[:, None] + boxes2_area - intersection_area, 1e-8
+ )
+ return tf.clip_by_value(intersection_area / union_area, 0.0, 1.0)
+
+
+def visualize_detections(
+ image, boxes, classes, scores, figsize=(7, 7), linewidth=1, color=[0, 0, 1]
+):
+ """Visualize Detections"""
+ image = np.array(image, dtype=np.uint8)
+ plt.figure(figsize=figsize)
+ plt.axis("off")
+ plt.imshow(image)
+ ax = plt.gca()
+ for box, _cls, score in zip(boxes, classes, scores):
+ text = "{}: {:.2f}".format(_cls, score)
+ x1, y1, x2, y2 = box
+ w, h = x2 - x1, y2 - y1
+ patch = plt.Rectangle(
+ [x1, y1], w, h, fill=False, edgecolor=color, linewidth=linewidth
+ )
+ ax.add_patch(patch)
+ ax.text(
+ x1,
+ y1,
+ text,
+ bbox={"facecolor": color, "alpha": 0.4},
+ clip_box=ax.clipbox,
+ clip_on=True,
+ )
+ plt.show()
+ return ax
+
+
+"""
+## Implementing Anchor generator
+
+Anchor boxes are fixed sized boxes that the model uses to predict the bounding
+box for an object. It does this by regressing the offset between the location
+of the object's center and the center of an anchor box, and then uses the width
+and height of the anchor box to predict a relative scale of the object. In the
+case of RetinaNet, each location on a given feature map has nine anchor boxes
+(at three scales and three ratios).
+"""
+
+
+class AnchorBox:
+ """Generates anchor boxes.
+
+    This class has operations to generate anchor boxes for feature maps at
+    strides `[8, 16, 32, 64, 128]`, where each anchor box is of the
+    format `[x, y, width, height]`.
+
+ Attributes:
+ aspect_ratios: A list of float values representing the aspect ratios of
+ the anchor boxes at each location on the feature map
+ scales: A list of float values representing the scale of the anchor boxes
+ at each location on the feature map.
+ num_anchors: The number of anchor boxes at each location on feature map
+ areas: A list of float values representing the areas of the anchor
+ boxes for each feature map in the feature pyramid.
+      strides: A list of float values representing the strides for each feature
+ map in the feature pyramid.
+ """
+
+ def __init__(self):
+ self.aspect_ratios = [0.5, 1.0, 2.0]
+ self.scales = [2**x for x in [0, 1 / 3, 2 / 3]]
+
+ self._num_anchors = len(self.aspect_ratios) * len(self.scales)
+ self._strides = [2**i for i in range(3, 8)]
+ self._areas = [x**2 for x in [32.0, 64.0, 128.0, 256.0, 512.0]]
+ self._anchor_dims = self._compute_dims()
+
+ def _compute_dims(self):
+ """Computes anchor box dimensions for all ratios and scales at all levels
+ of the feature pyramid.
+ """
+ anchor_dims_all = []
+ for area in self._areas:
+ anchor_dims = []
+ for ratio in self.aspect_ratios:
+ anchor_height = tf.math.sqrt(area / ratio)
+ anchor_width = area / anchor_height
+ dims = tf.reshape(
+ tf.stack([anchor_width, anchor_height], axis=-1), [1, 1, 2]
+ )
+ for scale in self.scales:
+ anchor_dims.append(scale * dims)
+ anchor_dims_all.append(tf.stack(anchor_dims, axis=-2))
+ return anchor_dims_all
+
+ def _get_anchors(self, feature_height, feature_width, level):
+ """Generates anchor boxes for a given feature map size and level
+
+ Arguments:
+ feature_height: An integer representing the height of the feature map.
+ feature_width: An integer representing the width of the feature map.
+ level: An integer representing the level of the feature map in the
+ feature pyramid.
+
+ Returns:
+ anchor boxes with the shape
+ `(feature_height * feature_width * num_anchors, 4)`
+ """
+ rx = tf.range(feature_width, dtype=tf.float32) + 0.5
+ ry = tf.range(feature_height, dtype=tf.float32) + 0.5
+ centers = tf.stack(tf.meshgrid(rx, ry), axis=-1) * self._strides[level - 3]
+ centers = tf.expand_dims(centers, axis=-2)
+ centers = tf.tile(centers, [1, 1, self._num_anchors, 1])
+ dims = tf.tile(
+ self._anchor_dims[level - 3], [feature_height, feature_width, 1, 1]
+ )
+ anchors = tf.concat([centers, dims], axis=-1)
+ return tf.reshape(
+ anchors, [feature_height * feature_width * self._num_anchors, 4]
+ )
+
+ def get_anchors(self, image_height, image_width):
+ """Generates anchor boxes for all the feature maps of the feature pyramid.
+
+ Arguments:
+ image_height: Height of the input image.
+ image_width: Width of the input image.
+
+ Returns:
+ anchor boxes for all the feature maps, stacked as a single tensor
+ with shape `(total_anchors, 4)`
+ """
+ anchors = [
+ self._get_anchors(
+ tf.math.ceil(image_height / 2**i),
+ tf.math.ceil(image_width / 2**i),
+ i,
+ )
+ for i in range(3, 8)
+ ]
+ return tf.concat(anchors, axis=0)
+
+
+"""
+## Preprocessing data
+
+Preprocessing the images involves two steps:
+
+- Resizing the image: Images are resized such that the shorter side is equal
+to 800 px. If, after resizing, the longer side of the image exceeds 1333 px,
+the image is resized again such that the longer side is capped at 1333 px.
+- Applying augmentation: Random scale jittering and random horizontal flipping
+are the only augmentations applied to the images.
+
+Along with the images, bounding boxes are rescaled and flipped if required.
+"""
+
+
+def random_flip_horizontal(image, boxes):
+ """Flips image and boxes horizontally with 50% chance
+
+ Arguments:
+ image: A 3-D tensor of shape `(height, width, channels)` representing an
+ image.
+ boxes: A tensor with shape `(num_boxes, 4)` representing bounding boxes,
+ having normalized coordinates.
+
+ Returns:
+ Randomly flipped image and boxes
+ """
+ if tf.random.uniform(()) > 0.5:
+ image = tf.image.flip_left_right(image)
+ boxes = tf.stack(
+ [1 - boxes[:, 2], boxes[:, 1], 1 - boxes[:, 0], boxes[:, 3]], axis=-1
+ )
+ return image, boxes
+
+
+def resize_and_pad_image(
+ image, min_side=800.0, max_side=1333.0, jitter=[640, 1024], stride=128.0
+):
+ """Resizes and pads image while preserving aspect ratio.
+
+ 1. Resizes images so that the shorter side is equal to `min_side`
+ 2. If the longer side is greater than `max_side`, then resize the image
+ with longer side equal to `max_side`
+ 3. Pad with zeros on right and bottom to make the image shape divisible by
+ `stride`
+
+ Arguments:
+ image: A 3-D tensor of shape `(height, width, channels)` representing an
+ image.
+ min_side: The shorter side of the image is resized to this value, if
+ `jitter` is set to None.
+ max_side: If the longer side of the image exceeds this value after
+ resizing, the image is resized such that the longer side now equals to
+ this value.
+ jitter: A list of floats containing minimum and maximum size for scale
+ jittering. If available, the shorter side of the image will be
+ resized to a random value in this range.
+ stride: The stride of the smallest feature map in the feature pyramid.
+ Can be calculated using `image_size / feature_map_size`.
+
+ Returns:
+ image: Resized and padded image.
+ image_shape: Shape of the image before padding.
+ ratio: The scaling factor used to resize the image
+ """
+ image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
+ if jitter is not None:
+ min_side = tf.random.uniform((), jitter[0], jitter[1], dtype=tf.float32)
+ ratio = min_side / tf.reduce_min(image_shape)
+ if ratio * tf.reduce_max(image_shape) > max_side:
+ ratio = max_side / tf.reduce_max(image_shape)
+ image_shape = ratio * image_shape
+ image = tf.image.resize(image, tf.cast(image_shape, dtype=tf.int32))
+ padded_image_shape = tf.cast(
+ tf.math.ceil(image_shape / stride) * stride, dtype=tf.int32
+ )
+ image = tf.image.pad_to_bounding_box(
+ image, 0, 0, padded_image_shape[0], padded_image_shape[1]
+ )
+ return image, image_shape, ratio
+
+
+def preprocess_data(sample):
+ """Applies preprocessing step to a single sample
+
+ Arguments:
+ sample: A dict representing a single training sample.
+
+ Returns:
+ image: Resized and padded image with random horizontal flipping applied.
+ bbox: Bounding boxes with the shape `(num_objects, 4)` where each box is
+ of the format `[x, y, width, height]`.
+      class_id: A tensor representing the class id of the objects, having
+        shape `(num_objects,)`.
+ """
+ image = sample["image"]
+ bbox = swap_xy(sample["objects"]["bbox"])
+ class_id = tf.cast(sample["objects"]["label"], dtype=tf.int32)
+
+ image, bbox = random_flip_horizontal(image, bbox)
+ image, image_shape, _ = resize_and_pad_image(image)
+
+ bbox = tf.stack(
+ [
+ bbox[:, 0] * image_shape[1],
+ bbox[:, 1] * image_shape[0],
+ bbox[:, 2] * image_shape[1],
+ bbox[:, 3] * image_shape[0],
+ ],
+ axis=-1,
+ )
+ bbox = convert_to_xywh(bbox)
+ return image, bbox, class_id
+
+
+"""
+## Encoding labels
+
+The raw labels, consisting of bounding boxes and class ids, need to be
+transformed into targets for training. This transformation consists of
+the following steps:
+
+- Generating anchor boxes for the given image dimensions
+- Assigning ground truth boxes to the anchor boxes
+- The anchor boxes that are not assigned any objects are either assigned the
+background class or ignored, depending on the IOU
+- Generating the classification and regression targets using anchor boxes
+"""
+
+
+class LabelEncoder:
+ """Transforms the raw labels into targets for training.
+
+ This class has operations to generate targets for a batch of samples which
+ is made up of the input images, bounding boxes for the objects present and
+ their class ids.
+
+ Attributes:
+ anchor_box: Anchor box generator to encode the bounding boxes.
+ box_variance: The scaling factors used to scale the bounding box targets.
+ """
+
+ def __init__(self):
+ self._anchor_box = AnchorBox()
+ self._box_variance = tf.convert_to_tensor(
+ [0.1, 0.1, 0.2, 0.2], dtype=tf.float32
+ )
+
+ def _match_anchor_boxes(
+ self, anchor_boxes, gt_boxes, match_iou=0.5, ignore_iou=0.4
+ ):
+ """Matches ground truth boxes to anchor boxes based on IOU.
+
+ 1. Calculates the pairwise IOU for the M `anchor_boxes` and N `gt_boxes`
+ to get a `(M, N)` shaped matrix.
+ 2. The ground truth box with the maximum IOU in each row is assigned to
+ the anchor box provided the IOU is greater than `match_iou`.
+ 3. If the maximum IOU in a row is less than `ignore_iou`, the anchor
+ box is assigned with the background class.
+ 4. The remaining anchor boxes that do not have any class assigned are
+ ignored during training.
+
+ Arguments:
+ anchor_boxes: A float tensor with the shape `(total_anchors, 4)`
+ representing all the anchor boxes for a given input image shape,
+ where each anchor box is of the format `[x, y, width, height]`.
+ gt_boxes: A float tensor with shape `(num_objects, 4)` representing
+ the ground truth boxes, where each box is of the format
+ `[x, y, width, height]`.
+ match_iou: A float value representing the minimum IOU threshold for
+ determining if a ground truth box can be assigned to an anchor box.
+ ignore_iou: A float value representing the IOU threshold under which
+ an anchor box is assigned to the background class.
+
+ Returns:
+ matched_gt_idx: Index of the matched object
+ positive_mask: A mask for anchor boxes that have been assigned ground
+ truth boxes.
+      ignore_mask: A mask for anchor boxes that need to be ignored during
+        training
+ """
+ iou_matrix = compute_iou(anchor_boxes, gt_boxes)
+ max_iou = tf.reduce_max(iou_matrix, axis=1)
+ matched_gt_idx = tf.argmax(iou_matrix, axis=1)
+ positive_mask = tf.greater_equal(max_iou, match_iou)
+ negative_mask = tf.less(max_iou, ignore_iou)
+ ignore_mask = tf.logical_not(tf.logical_or(positive_mask, negative_mask))
+ return (
+ matched_gt_idx,
+ tf.cast(positive_mask, dtype=tf.float32),
+ tf.cast(ignore_mask, dtype=tf.float32),
+ )
+
+ def _compute_box_target(self, anchor_boxes, matched_gt_boxes):
+ """Transforms the ground truth boxes into targets for training"""
+ box_target = tf.concat(
+ [
+ (matched_gt_boxes[:, :2] - anchor_boxes[:, :2]) / anchor_boxes[:, 2:],
+ tf.math.log(matched_gt_boxes[:, 2:] / anchor_boxes[:, 2:]),
+ ],
+ axis=-1,
+ )
+ box_target = box_target / self._box_variance
+ return box_target
+
+ def _encode_sample(self, image_shape, gt_boxes, cls_ids):
+ """Creates box and classification targets for a single sample"""
+ anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
+ cls_ids = tf.cast(cls_ids, dtype=tf.float32)
+ matched_gt_idx, positive_mask, ignore_mask = self._match_anchor_boxes(
+ anchor_boxes, gt_boxes
+ )
+ matched_gt_boxes = tf.gather(gt_boxes, matched_gt_idx)
+ box_target = self._compute_box_target(anchor_boxes, matched_gt_boxes)
+ matched_gt_cls_ids = tf.gather(cls_ids, matched_gt_idx)
+ cls_target = tf.where(
+ tf.not_equal(positive_mask, 1.0), -1.0, matched_gt_cls_ids
+ )
+ cls_target = tf.where(tf.equal(ignore_mask, 1.0), -2.0, cls_target)
+ cls_target = tf.expand_dims(cls_target, axis=-1)
+ label = tf.concat([box_target, cls_target], axis=-1)
+ return label
+
+ def encode_batch(self, batch_images, gt_boxes, cls_ids):
+ """Creates box and classification targets for a batch"""
+ images_shape = tf.shape(batch_images)
+ batch_size = images_shape[0]
+
+ labels = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=True)
+ for i in range(batch_size):
+ label = self._encode_sample(images_shape, gt_boxes[i], cls_ids[i])
+ labels = labels.write(i, label)
+ batch_images = tf.keras.applications.resnet.preprocess_input(batch_images)
+ return batch_images, labels.stack()
+
+
+"""
+## Building the ResNet50 backbone
+
+RetinaNet uses a ResNet-based backbone, on top of which a feature pyramid network
+is constructed. In this example we use ResNet50 as the backbone, and return the
+feature maps at strides 8, 16 and 32.
+"""
+
+
+def get_backbone():
+ """Builds ResNet50 with pre-trained imagenet weights"""
+ backbone = keras.applications.ResNet50(
+ include_top=False, input_shape=[None, None, 3]
+ )
+ c3_output, c4_output, c5_output = [
+ backbone.get_layer(layer_name).output
+ for layer_name in ["conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
+ ]
+ return keras.Model(
+ inputs=[backbone.inputs], outputs=[c3_output, c4_output, c5_output]
+ )
+
+
+"""
+## Building Feature Pyramid Network as a custom layer
+"""
+
+
+class FeaturePyramid(keras.layers.Layer):
+ """Builds the Feature Pyramid with the feature maps from the backbone.
+
+ Attributes:
+ num_classes: Number of classes in the dataset.
+ backbone: The backbone to build the feature pyramid from.
+ Currently supports ResNet50 only.
+ """
+
+ def __init__(self, backbone=None, **kwargs):
+ super().__init__(name="FeaturePyramid", **kwargs)
+ self.backbone = backbone if backbone else get_backbone()
+ self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
+ self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
+ self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same")
+ self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
+ self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
+ self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same")
+ self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
+ self.conv_c7_3x3 = keras.layers.Conv2D(256, 3, 2, "same")
+ self.upsample_2x = keras.layers.UpSampling2D(2)
+
+ def call(self, images, training=False):
+ c3_output, c4_output, c5_output = self.backbone(images, training=training)
+ p3_output = self.conv_c3_1x1(c3_output)
+ p4_output = self.conv_c4_1x1(c4_output)
+ p5_output = self.conv_c5_1x1(c5_output)
+ p4_output = p4_output + self.upsample_2x(p5_output)
+ p3_output = p3_output + self.upsample_2x(p4_output)
+ p3_output = self.conv_c3_3x3(p3_output)
+ p4_output = self.conv_c4_3x3(p4_output)
+ p5_output = self.conv_c5_3x3(p5_output)
+ p6_output = self.conv_c6_3x3(c5_output)
+ p7_output = self.conv_c7_3x3(tf.nn.relu(p6_output))
+ return p3_output, p4_output, p5_output, p6_output, p7_output
+
+
+"""
+## Building the classification and box regression heads
+
+The RetinaNet model has separate heads for bounding box regression and
+for predicting class probabilities for the objects. These heads are shared
+between all the feature maps of the feature pyramid.
+"""
+
+
+def build_head(output_filters, bias_init):
+ """Builds the class/box predictions head.
+
+ Arguments:
+ output_filters: Number of convolution filters in the final layer.
+ bias_init: Bias Initializer for the final convolution layer.
+
+ Returns:
+ A keras sequential model representing either the classification
+ or the box regression head depending on `output_filters`.
+ """
+ head = keras.Sequential([keras.Input(shape=[None, None, 256])])
+ kernel_init = tf.initializers.RandomNormal(0.0, 0.01)
+ for _ in range(4):
+ head.add(
+ keras.layers.Conv2D(256, 3, padding="same", kernel_initializer=kernel_init)
+ )
+ head.add(keras.layers.ReLU())
+ head.add(
+ keras.layers.Conv2D(
+ output_filters,
+ 3,
+ 1,
+ padding="same",
+ kernel_initializer=kernel_init,
+ bias_initializer=bias_init,
+ )
+ )
+ return head
+
+
+"""
+## Building RetinaNet using a subclassed model
+"""
+
+
+class RetinaNet(keras.Model):
+ """A subclassed Keras model implementing the RetinaNet architecture.
+
+ Attributes:
+ num_classes: Number of classes in the dataset.
+ backbone: The backbone to build the feature pyramid from.
+ Currently supports ResNet50 only.
+ """
+
+ def __init__(self, num_classes, backbone=None, **kwargs):
+ super().__init__(name="RetinaNet", **kwargs)
+ self.fpn = FeaturePyramid(backbone)
+ self.num_classes = num_classes
+
+ prior_probability = tf.constant_initializer(-np.log((1 - 0.01) / 0.01))
+ self.cls_head = build_head(9 * num_classes, prior_probability)
+ self.box_head = build_head(9 * 4, "zeros")
+
+ def call(self, image, training=False):
+ features = self.fpn(image, training=training)
+ N = tf.shape(image)[0]
+ cls_outputs = []
+ box_outputs = []
+ for feature in features:
+ box_outputs.append(tf.reshape(self.box_head(feature), [N, -1, 4]))
+ cls_outputs.append(
+ tf.reshape(self.cls_head(feature), [N, -1, self.num_classes])
+ )
+ cls_outputs = tf.concat(cls_outputs, axis=1)
+ box_outputs = tf.concat(box_outputs, axis=1)
+ return tf.concat([box_outputs, cls_outputs], axis=-1)
+
+
+"""
+## Implementing a custom layer to decode predictions
+"""
+
+
+class DecodePredictions(tf.keras.layers.Layer):
+ """A Keras layer that decodes predictions of the RetinaNet model.
+
+ Attributes:
+ num_classes: Number of classes in the dataset
+ confidence_threshold: Minimum class probability, below which detections
+ are pruned.
+ nms_iou_threshold: IOU threshold for the NMS operation
+ max_detections_per_class: Maximum number of detections to retain per
+ class.
+ max_detections: Maximum number of detections to retain across all
+ classes.
+ box_variance: The scaling factors used to scale the bounding box
+ predictions.
+ """
+
+ def __init__(
+ self,
+ num_classes=80,
+ confidence_threshold=0.05,
+ nms_iou_threshold=0.5,
+ max_detections_per_class=100,
+ max_detections=100,
+ box_variance=[0.1, 0.1, 0.2, 0.2],
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_classes = num_classes
+ self.confidence_threshold = confidence_threshold
+ self.nms_iou_threshold = nms_iou_threshold
+ self.max_detections_per_class = max_detections_per_class
+ self.max_detections = max_detections
+
+ self._anchor_box = AnchorBox()
+        self._box_variance = tf.convert_to_tensor(
+            box_variance, dtype=tf.float32
+        )
+
+ def _decode_box_predictions(self, anchor_boxes, box_predictions):
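+        # Undo the target normalization, then convert the [dx, dy, dw, dh]
+        # offsets (relative to each anchor) back to absolute
+        # [x, y, width, height] boxes.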
+ boxes = box_predictions * self._box_variance
+ boxes = tf.concat(
+ [
+ boxes[:, :, :2] * anchor_boxes[:, :, 2:] + anchor_boxes[:, :, :2],
+ tf.math.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:],
+ ],
+ axis=-1,
+ )
+ boxes_transformed = convert_to_corners(boxes)
+ return boxes_transformed
+
+ def call(self, images, predictions):
+ image_shape = tf.cast(tf.shape(images), dtype=tf.float32)
+ anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
+ box_predictions = predictions[:, :, :4]
+ cls_predictions = tf.nn.sigmoid(predictions[:, :, 4:])
+ boxes = self._decode_box_predictions(anchor_boxes[None, ...], box_predictions)
+
+ return tf.image.combined_non_max_suppression(
+ tf.expand_dims(boxes, axis=2),
+ cls_predictions,
+ self.max_detections_per_class,
+ self.max_detections,
+ self.nms_iou_threshold,
+ self.confidence_threshold,
+ clip_boxes=False,
+ )
+
+
+"""
+## Implementing Smooth L1 loss and Focal Loss as keras custom losses
+"""
+
+
+class RetinaNetBoxLoss(tf.losses.Loss):
+ """Implements Smooth L1 loss"""
+
+ def __init__(self, delta):
+ super().__init__(reduction="none", name="RetinaNetBoxLoss")
+ self._delta = delta
+
+ def call(self, y_true, y_pred):
+ difference = y_true - y_pred
+ absolute_difference = tf.abs(difference)
+ squared_difference = difference**2
+ loss = tf.where(
+ tf.less(absolute_difference, self._delta),
+ 0.5 * squared_difference,
+ absolute_difference - 0.5,
+ )
+ return tf.reduce_sum(loss, axis=-1)
+
+
+class RetinaNetClassificationLoss(tf.losses.Loss):
+ """Implements Focal loss"""
+
+ def __init__(self, alpha, gamma):
+ super().__init__(reduction="none", name="RetinaNetClassificationLoss")
+ self._alpha = alpha
+ self._gamma = gamma
+
+ def call(self, y_true, y_pred):
+ cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=y_true, logits=y_pred
+ )
+ probs = tf.nn.sigmoid(y_pred)
+ alpha = tf.where(tf.equal(y_true, 1.0), self._alpha, (1.0 - self._alpha))
+ pt = tf.where(tf.equal(y_true, 1.0), probs, 1 - probs)
+ loss = alpha * tf.pow(1.0 - pt, self._gamma) * cross_entropy
+ return tf.reduce_sum(loss, axis=-1)
+
+
+class RetinaNetLoss(tf.losses.Loss):
+ """Wrapper to combine both the losses"""
+
+ def __init__(self, num_classes=80, alpha=0.25, gamma=2.0, delta=1.0):
+ super().__init__(reduction="auto", name="RetinaNetLoss")
+ self._clf_loss = RetinaNetClassificationLoss(alpha, gamma)
+ self._box_loss = RetinaNetBoxLoss(delta)
+ self._num_classes = num_classes
+
+ def call(self, y_true, y_pred):
+ y_pred = tf.cast(y_pred, dtype=tf.float32)
+ box_labels = y_true[:, :, :4]
+ box_predictions = y_pred[:, :, :4]
+ cls_labels = tf.one_hot(
+ tf.cast(y_true[:, :, 4], dtype=tf.int32),
+ depth=self._num_classes,
+ dtype=tf.float32,
+ )
+ cls_predictions = y_pred[:, :, 4:]
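+        # Class id > -1 marks a matched (positive) anchor; -1 marks background
+        # and -2 marks anchors that are ignored by the classification loss.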
+ positive_mask = tf.cast(tf.greater(y_true[:, :, 4], -1.0), dtype=tf.float32)
+ ignore_mask = tf.cast(tf.equal(y_true[:, :, 4], -2.0), dtype=tf.float32)
+ clf_loss = self._clf_loss(cls_labels, cls_predictions)
+ box_loss = self._box_loss(box_labels, box_predictions)
+ clf_loss = tf.where(tf.equal(ignore_mask, 1.0), 0.0, clf_loss)
+ box_loss = tf.where(tf.equal(positive_mask, 1.0), box_loss, 0.0)
+ normalizer = tf.reduce_sum(positive_mask, axis=-1)
+ clf_loss = tf.math.divide_no_nan(tf.reduce_sum(clf_loss, axis=-1), normalizer)
+ box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss, axis=-1), normalizer)
+ loss = clf_loss + box_loss
+ return loss
+
+
+"""
+## Setting up training parameters
+"""
+
+model_dir = "retinanet/"
+label_encoder = LabelEncoder()
+
+num_classes = 80
+batch_size = 2
+
+learning_rates = [2.5e-06, 0.000625, 0.00125, 0.0025, 0.00025, 2.5e-05]
+learning_rate_boundaries = [125, 250, 500, 240000, 360000]
+learning_rate_fn = tf.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=learning_rate_boundaries, values=learning_rates
+)
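+
+"""
+A quick illustration (not part of the original example) of how the schedule
+behaves: `PiecewiseConstantDecay` returns the learning rate whose interval
+contains the current step, so the rate ramps up over the first 500 steps and
+decays after 240k and 360k steps.
+"""
+
+for demo_step in [0, 200, 1000, 300000, 400000]:
+    print("step", demo_step, "-> lr", float(learning_rate_fn(demo_step)))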
+
+"""
+## Initializing and compiling model
+"""
+
+resnet50_backbone = get_backbone()
+loss_fn = RetinaNetLoss(num_classes)
+model = RetinaNet(num_classes, resnet50_backbone)
+
+optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate_fn, momentum=0.9)
+model.compile(loss=loss_fn, optimizer=optimizer)
+
+"""
+## Setting up callbacks
+"""
+
+callbacks_list = [
+ tf.keras.callbacks.ModelCheckpoint(
+ filepath=os.path.join(model_dir, "weights" + "_epoch_{epoch}"),
+ monitor="loss",
+ save_best_only=False,
+ save_weights_only=True,
+ verbose=1,
+ )
+]
+
+"""
+## Load the COCO2017 dataset using TensorFlow Datasets
+"""
+
+# set `data_dir=None` to load the complete dataset
+
+(train_dataset, val_dataset), dataset_info = tfds.load(
+ "coco/2017", split=["train", "validation"], with_info=True, data_dir="data"
+)
+
+"""
+## Setting up a `tf.data` pipeline
+
+To ensure that the model is fed with data efficiently we will be using the
+`tf.data` API to create our input pipeline. The input pipeline
+consists of the following major processing steps:
+
+- Apply the preprocessing function to the samples
+- Create batches with a fixed batch size. Since images in the batch can
+have different dimensions, and can also have different numbers of
+objects, we use `padded_batch` to add the necessary padding to create
+rectangular tensors
+- Create targets for each sample in the batch using `LabelEncoder`
+"""
+
+autotune = tf.data.AUTOTUNE
+train_dataset = train_dataset.map(preprocess_data, num_parallel_calls=autotune)
+train_dataset = train_dataset.shuffle(8 * batch_size)
+train_dataset = train_dataset.padded_batch(
+ batch_size=batch_size, padding_values=(0.0, 1e-8, -1), drop_remainder=True
+)
+train_dataset = train_dataset.map(
+ label_encoder.encode_batch, num_parallel_calls=autotune
+)
+train_dataset = train_dataset.apply(tf.data.experimental.ignore_errors())
+train_dataset = train_dataset.prefetch(autotune)
+
+val_dataset = val_dataset.map(preprocess_data, num_parallel_calls=autotune)
+val_dataset = val_dataset.padded_batch(
+ batch_size=1, padding_values=(0.0, 1e-8, -1), drop_remainder=True
+)
+val_dataset = val_dataset.map(label_encoder.encode_batch, num_parallel_calls=autotune)
+val_dataset = val_dataset.apply(tf.data.experimental.ignore_errors())
+val_dataset = val_dataset.prefetch(autotune)
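+
+"""
+A quick peek at one encoded batch (illustrative): images are padded to a common
+size within the batch, and the encoded targets have shape
+`[batch_size, num_anchors, 5]` (four box targets plus a class id per anchor).
+"""
+
+sample_images, sample_targets = next(iter(train_dataset))
+print("images:", sample_images.shape, "- targets:", sample_targets.shape)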
+
+"""
+## Training the model
+"""
+
+# Uncomment the following lines, when training on full dataset
+# train_steps_per_epoch = dataset_info.splits["train"].num_examples // batch_size
+# val_steps_per_epoch = \
+# dataset_info.splits["validation"].num_examples // batch_size
+
+# train_steps = 4 * 100000
+# epochs = train_steps // train_steps_per_epoch
+
+epochs = 1
+
+# Running 100 training and 50 validation steps,
+# remove `.take` when training on the full dataset
+
+model.fit(
+ train_dataset.take(100),
+ validation_data=val_dataset.take(50),
+ epochs=epochs,
+ callbacks=callbacks_list,
+ verbose=1,
+)
+
+"""
+## Loading weights
+"""
+
+# Change this to `model_dir` when not using the downloaded weights
+weights_dir = "data"
+
+latest_checkpoint = tf.train.latest_checkpoint(weights_dir)
+model.load_weights(latest_checkpoint)
+
+"""
+## Building inference model
+"""
+
+image = tf.keras.Input(shape=[None, None, 3], name="image")
+predictions = model(image, training=False)
+detections = DecodePredictions(confidence_threshold=0.5)(image, predictions)
+inference_model = tf.keras.Model(inputs=image, outputs=detections)
+
+"""
+## Generating detections
+"""
+
+
+def prepare_image(image):
+ image, _, ratio = resize_and_pad_image(image, jitter=None)
+ image = tf.keras.applications.resnet.preprocess_input(image)
+ return tf.expand_dims(image, axis=0), ratio
+
+
+val_dataset = tfds.load("coco/2017", split="validation", data_dir="data")
+int2str = dataset_info.features["objects"]["label"].int2str
+
+for sample in val_dataset.take(2):
+ image = tf.cast(sample["image"], dtype=tf.float32)
+ input_image, ratio = prepare_image(image)
+ detections = inference_model.predict(input_image)
+ num_detections = detections.valid_detections[0]
+ class_names = [
+ int2str(int(x)) for x in detections.nmsed_classes[0][:num_detections]
+ ]
+ visualize_detections(
+ image,
+ detections.nmsed_boxes[0][:num_detections] / ratio,
+ class_names,
+ detections.nmsed_scores[0][:num_detections],
+ )
+
+"""
+Example available on Hugging Face.
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [Model on Hugging Face Hub](https://huggingface.co/keras-io/Object-Detection-RetinaNet) | [Demo on Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Object-Detection-Using-RetinaNet) |
+"""
diff --git a/knowledge_base/vision/semantic_image_clustering.py b/knowledge_base/vision/semantic_image_clustering.py
new file mode 100644
index 0000000000000000000000000000000000000000..d70dfb4d5a04ce7f3f19c9cdfc440b082538a66f
--- /dev/null
+++ b/knowledge_base/vision/semantic_image_clustering.py
@@ -0,0 +1,600 @@
+"""
+Title: Semantic Image Clustering
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2021/02/28
+Last modified: 2021/02/28
+Description: Semantic Clustering by Adopting Nearest neighbors (SCAN) algorithm.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This example demonstrates how to apply the [Semantic Clustering by Adopting Nearest neighbors
+(SCAN)](https://arxiv.org/abs/2005.12320) algorithm (Van Gansbeke et al., 2020) on the
+[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. The algorithm consists of
+two phases:
+
+1. Self-supervised visual representation learning of images, in which we use the
+[SimCLR](https://arxiv.org/abs/2002.05709) technique.
+2. Clustering of the learned visual representation vectors to maximize the agreement
+between the cluster assignments of neighboring vectors.
+"""
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+from collections import defaultdict
+import numpy as np
+import tensorflow as tf
+import keras
+from keras import layers
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+"""
+## Prepare the data
+"""
+
+num_classes = 10
+input_shape = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+x_data = np.concatenate([x_train, x_test])
+y_data = np.concatenate([y_train, y_test])
+
+print("x_data shape:", x_data.shape, "- y_data shape:", y_data.shape)
+
+classes = [
+ "airplane",
+ "automobile",
+ "bird",
+ "cat",
+ "deer",
+ "dog",
+ "frog",
+ "horse",
+ "ship",
+ "truck",
+]
+
+"""
+## Define hyperparameters
+"""
+
+target_size = 32 # Resize the input images.
+representation_dim = 512 # The dimensions of the features vector.
+projection_units = 128 # The projection head of the representation learner.
+num_clusters = 20 # Number of clusters.
+k_neighbours = 5 # Number of neighbours to consider during cluster learning.
+tune_encoder_during_clustering = False # Freeze the encoder in the cluster learning.
+
+"""
+## Implement data preprocessing
+
+The data preprocessing step resizes the input images to the desired `target_size` and applies
+feature-wise normalization. Note that, when using `keras.applications.ResNet50V2` as the
+visual encoder, resizing the images into 255 x 255 inputs would lead to more accurate results
+but require a longer time to train.
+"""
+
+data_preprocessing = keras.Sequential(
+ [
+ layers.Resizing(target_size, target_size),
+ layers.Normalization(),
+ ]
+)
+# Compute the mean and the variance from the data for normalization.
+data_preprocessing.layers[-1].adapt(x_data)
+
+"""
+## Data augmentation
+
+Unlike SimCLR, which randomly picks a single data augmentation function to apply to an input
+image, we apply a set of data augmentation functions randomly to the input image.
+(You can experiment with other image augmentation techniques by following
+the [data augmentation tutorial](https://www.tensorflow.org/tutorials/images/data_augmentation).)
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.RandomTranslation(
+ height_factor=(-0.2, 0.2), width_factor=(-0.2, 0.2), fill_mode="nearest"
+ ),
+ layers.RandomFlip(mode="horizontal"),
+ layers.RandomRotation(factor=0.15, fill_mode="nearest"),
+ layers.RandomZoom(
+ height_factor=(-0.3, 0.1), width_factor=(-0.3, 0.1), fill_mode="nearest"
+ ),
+ ]
+)
+
+"""
+Display a random image
+"""
+
+image_idx = np.random.choice(range(x_data.shape[0]))
+image = x_data[image_idx]
+image_class = classes[y_data[image_idx][0]]
+plt.figure(figsize=(3, 3))
+plt.imshow(x_data[image_idx].astype("uint8"))
+plt.title(image_class)
+_ = plt.axis("off")
+
+"""
+Display a sample of augmented versions of the image
+"""
+
+plt.figure(figsize=(10, 10))
+for i in range(9):
+ augmented_images = data_augmentation(np.array([image]))
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(augmented_images[0].numpy().astype("uint8"))
+ plt.axis("off")
+
+"""
+## Self-supervised representation learning
+"""
+
+"""
+### Implement the vision encoder
+"""
+
+
+def create_encoder(representation_dim):
+ encoder = keras.Sequential(
+ [
+ keras.applications.ResNet50V2(
+ include_top=False, weights=None, pooling="avg"
+ ),
+ layers.Dense(representation_dim),
+ ]
+ )
+ return encoder
+
+
+"""
+### Implement the unsupervised contrastive loss
+"""
+
+
+class RepresentationLearner(keras.Model):
+ def __init__(
+ self,
+ encoder,
+ projection_units,
+ num_augmentations,
+ temperature=1.0,
+ dropout_rate=0.1,
+ l2_normalize=False,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.encoder = encoder
+ # Create projection head.
+ self.projector = keras.Sequential(
+ [
+ layers.Dropout(dropout_rate),
+ layers.Dense(units=projection_units, use_bias=False),
+ layers.BatchNormalization(),
+ layers.ReLU(),
+ ]
+ )
+ self.num_augmentations = num_augmentations
+ self.temperature = temperature
+ self.l2_normalize = l2_normalize
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker]
+
+ def compute_contrastive_loss(self, feature_vectors, batch_size):
+ num_augmentations = keras.ops.shape(feature_vectors)[0] // batch_size
+ if self.l2_normalize:
+ feature_vectors = keras.utils.normalize(feature_vectors)
+ # The logits shape is [num_augmentations * batch_size, num_augmentations * batch_size].
+ logits = (
+ tf.linalg.matmul(feature_vectors, feature_vectors, transpose_b=True)
+ / self.temperature
+ )
+        # Apply the log-max trick for numerical stability: subtracting the
+        # per-row maximum (keepdims keeps it broadcastable) leaves the softmax
+        # unchanged.
+        logits_max = keras.ops.max(logits, axis=1, keepdims=True)
+        logits = logits - logits_max
+        # The shape of targets is [num_augmentations * batch_size, num_augmentations * batch_size].
+        # targets consists of num_augmentations x num_augmentations submatrices
+        # of shape [batch_size, batch_size], each an identity matrix (diagonal
+        # entries are ones).
+ targets = keras.ops.tile(
+ tf.eye(batch_size), [num_augmentations, num_augmentations]
+ )
+ # Compute cross entropy loss
+ return keras.losses.categorical_crossentropy(
+ y_true=targets, y_pred=logits, from_logits=True
+ )
+
+ def call(self, inputs):
+ # Preprocess the input images.
+ preprocessed = data_preprocessing(inputs)
+ # Create augmented versions of the images.
+ augmented = []
+ for _ in range(self.num_augmentations):
+ augmented.append(data_augmentation(preprocessed))
+ augmented = layers.Concatenate(axis=0)(augmented)
+ # Generate embedding representations of the images.
+ features = self.encoder(augmented)
+ # Apply projection head.
+ return self.projector(features)
+
+ def train_step(self, inputs):
+ batch_size = keras.ops.shape(inputs)[0]
+ # Run the forward pass and compute the contrastive loss
+ with tf.GradientTape() as tape:
+ feature_vectors = self(inputs, training=True)
+ loss = self.compute_contrastive_loss(feature_vectors, batch_size)
+ # Compute gradients
+ trainable_vars = self.trainable_variables
+ gradients = tape.gradient(loss, trainable_vars)
+ # Update weights
+ self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+ # Update loss tracker metric
+ self.loss_tracker.update_state(loss)
+ # Return a dict mapping metric names to current value
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, inputs):
+ batch_size = keras.ops.shape(inputs)[0]
+ feature_vectors = self(inputs, training=False)
+ loss = self.compute_contrastive_loss(feature_vectors, batch_size)
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+### Train the model
+"""
+# Create vision encoder.
+encoder = create_encoder(representation_dim)
+# Create representation learner.
+representation_learner = RepresentationLearner(
+ encoder, projection_units, num_augmentations=2, temperature=0.1
+)
+# Create a cosine decay learning rate scheduler.
+lr_scheduler = keras.optimizers.schedules.CosineDecay(
+ initial_learning_rate=0.001, decay_steps=500, alpha=0.1
+)
+# Compile the model.
+representation_learner.compile(
+ optimizer=keras.optimizers.AdamW(learning_rate=lr_scheduler, weight_decay=0.0001),
+ jit_compile=False,
+)
+# Fit the model.
+history = representation_learner.fit(
+ x=x_data,
+ batch_size=512,
+ epochs=50, # for better results, increase the number of epochs to 500.
+)
+
+
+"""
+Plot training loss
+"""
+
+plt.plot(history.history["loss"])
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.show()
+
+"""
+## Compute the nearest neighbors
+"""
+
+"""
+### Generate the embeddings for the images
+"""
+
+batch_size = 500
+# Get the feature vector representations of the images.
+feature_vectors = encoder.predict(x_data, batch_size=batch_size, verbose=1)
+# Normalize the feature vectors.
+feature_vectors = keras.utils.normalize(feature_vectors)
+
+"""
+### Find the *k* nearest neighbours for each embedding
+"""
+
+neighbours = []
+num_batches = feature_vectors.shape[0] // batch_size
+for batch_idx in tqdm(range(num_batches)):
+ start_idx = batch_idx * batch_size
+ end_idx = start_idx + batch_size
+ current_batch = feature_vectors[start_idx:end_idx]
+ # Compute the dot similarity.
+ similarities = tf.linalg.matmul(current_batch, feature_vectors, transpose_b=True)
+ # Get the indices of most similar vectors.
+ _, indices = keras.ops.top_k(similarities, k=k_neighbours + 1, sorted=True)
+ # Add the indices to the neighbours.
+ neighbours.append(indices[..., 1:])
+
+neighbours = np.reshape(np.array(neighbours), (-1, k_neighbours))
+
+"""
+Let's display some neighbors on each row
+"""
+
+nrows = 4
+ncols = k_neighbours + 1
+
+plt.figure(figsize=(12, 12))
+position = 1
+for _ in range(nrows):
+ anchor_idx = np.random.choice(range(x_data.shape[0]))
+    neighbour_indices = neighbours[anchor_idx]
+    indices = [anchor_idx] + neighbour_indices.tolist()
+ for j in range(ncols):
+ plt.subplot(nrows, ncols, position)
+ plt.imshow(x_data[indices[j]].astype("uint8"))
+ plt.title(classes[y_data[indices[j]][0]])
+ plt.axis("off")
+ position += 1
+
+"""
+Notice that the images in each row are visually similar and belong to similar classes.
+"""
+
+"""
+## Semantic clustering with nearest neighbours
+"""
+
+"""
+### Implement clustering consistency loss
+
+This loss tries to make sure that neighbours have the same clustering assignments.
+"""
+
+
+class ClustersConsistencyLoss(keras.losses.Loss):
+ def __init__(self):
+ super().__init__()
+
+ def __call__(self, target, similarity, sample_weight=None):
+ # Set targets to be ones.
+ target = keras.ops.ones_like(similarity)
+ # Compute cross entropy loss.
+ loss = keras.losses.binary_crossentropy(
+ y_true=target, y_pred=similarity, from_logits=True
+ )
+ return keras.ops.mean(loss)
+
+
+"""
+### Implement the clusters entropy loss
+
+This loss tries to make sure that the cluster distribution is roughly uniform, to avoid
+assigning most of the instances to one cluster.
+"""
+
+
+class ClustersEntropyLoss(keras.losses.Loss):
+ def __init__(self, entropy_loss_weight=1.0):
+ super().__init__()
+ self.entropy_loss_weight = entropy_loss_weight
+
+ def __call__(self, target, cluster_probabilities, sample_weight=None):
+ # Ideal entropy = log(num_clusters).
+ num_clusters = keras.ops.cast(
+ keras.ops.shape(cluster_probabilities)[-1], "float32"
+ )
+ target = keras.ops.log(num_clusters)
+ # Compute the overall clusters distribution.
+ cluster_probabilities = keras.ops.mean(cluster_probabilities, axis=0)
+ # Replacing zero probabilities - if any - with a very small value.
+ cluster_probabilities = keras.ops.clip(cluster_probabilities, 1e-8, 1.0)
+ # Compute the entropy over the clusters.
+ entropy = -keras.ops.sum(
+ cluster_probabilities * keras.ops.log(cluster_probabilities)
+ )
+ # Compute the difference between the target and the actual.
+ loss = target - entropy
+ return loss
+
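+"""
+As a quick check (illustrative, not part of the original example): a perfectly
+uniform cluster distribution has entropy `log(num_clusters)`, so the entropy
+loss above is approximately zero.
+"""
+
+uniform_probabilities = keras.ops.ones((4, num_clusters)) / num_clusters
+print(float(ClustersEntropyLoss()(None, uniform_probabilities)))  # ~0.0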
+
+"""
+### Implement clustering model
+
+This model takes a raw image as an input, generates its feature vector using the
+trained encoder, and produces a probability distribution over the clusters given
+the feature vector as the cluster assignment.
+"""
+
+
+def create_clustering_model(encoder, num_clusters, name=None):
+ inputs = keras.Input(shape=input_shape)
+ # Preprocess the input images.
+ preprocessed = data_preprocessing(inputs)
+ # Apply data augmentation to the images.
+ augmented = data_augmentation(preprocessed)
+ # Generate embedding representations of the images.
+ features = encoder(augmented)
+ # Assign the images to clusters.
+ outputs = layers.Dense(units=num_clusters, activation="softmax")(features)
+ # Create the model.
+ model = keras.Model(inputs=inputs, outputs=outputs, name=name)
+ return model
+
+
+"""
+### Implement clustering learner
+
+This model receives the input `anchor` image and its `neighbours`, produces the clusters
+assignments for them using the `clustering_model`, and produces two outputs:
+1. `similarity`: the similarity between the cluster assignments of the `anchor` image and
+its `neighbours`. This output is fed to the `ClustersConsistencyLoss`.
+2. `anchor_clustering`: cluster assignments of the `anchor` images. This is fed to the `ClustersEntropyLoss`.
+"""
+
+
+def create_clustering_learner(clustering_model):
+ anchor = keras.Input(shape=input_shape, name="anchors")
+ neighbours = keras.Input(
+ shape=tuple([k_neighbours]) + input_shape, name="neighbours"
+ )
+ # Changes neighbours shape to [batch_size * k_neighbours, width, height, channels]
+ neighbours_reshaped = keras.ops.reshape(neighbours, tuple([-1]) + input_shape)
+ # anchor_clustering shape: [batch_size, num_clusters]
+ anchor_clustering = clustering_model(anchor)
+ # neighbours_clustering shape: [batch_size * k_neighbours, num_clusters]
+ neighbours_clustering = clustering_model(neighbours_reshaped)
+ # Convert neighbours_clustering shape to [batch_size, k_neighbours, num_clusters]
+ neighbours_clustering = keras.ops.reshape(
+ neighbours_clustering,
+ (-1, k_neighbours, keras.ops.shape(neighbours_clustering)[-1]),
+ )
+ # similarity shape: [batch_size, 1, k_neighbours]
+ similarity = keras.ops.einsum(
+ "bij,bkj->bik",
+ keras.ops.expand_dims(anchor_clustering, axis=1),
+ neighbours_clustering,
+ )
+ # similarity shape: [batch_size, k_neighbours]
+ similarity = layers.Lambda(
+ lambda x: keras.ops.squeeze(x, axis=1), name="similarity"
+ )(similarity)
+ # Create the model.
+ model = keras.Model(
+ inputs=[anchor, neighbours],
+ outputs=[similarity, anchor_clustering],
+ name="clustering_learner",
+ )
+ return model
+
+
+"""
+### Train model
+"""
+
+# If tune_encoder_during_clustering is set to False,
+# then freeze the encoder weights.
+for layer in encoder.layers:
+ layer.trainable = tune_encoder_during_clustering
+# Create the clustering model and learner.
+clustering_model = create_clustering_model(encoder, num_clusters, name="clustering")
+clustering_learner = create_clustering_learner(clustering_model)
+# Instantiate the model losses.
+losses = [ClustersConsistencyLoss(), ClustersEntropyLoss(entropy_loss_weight=5)]
+# Create the model inputs and labels.
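+# `tf.gather` collects, for every anchor image, the pixel data of its k nearest
+# neighbours, giving shape [num_images, k_neighbours, 32, 32, 3].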
+inputs = {"anchors": x_data, "neighbours": tf.gather(x_data, neighbours)}
+labels = np.ones(shape=(x_data.shape[0]))
+# Compile the model.
+clustering_learner.compile(
+ optimizer=keras.optimizers.AdamW(learning_rate=0.0005, weight_decay=0.0001),
+ loss=losses,
+ jit_compile=False,
+)
+
+# Begin training the model.
+clustering_history = clustering_learner.fit(x=inputs, y=labels, batch_size=512, epochs=50)
+
+"""
+Plot training loss
+"""
+
+plt.plot(clustering_history.history["loss"])
+plt.ylabel("loss")
+plt.xlabel("epoch")
+plt.show()
+
+"""
+## Cluster analysis
+"""
+
+"""
+### Assign images to clusters
+"""
+
+# Get the cluster probability distribution of the input images.
+clustering_probs = clustering_model.predict(x_data, batch_size=batch_size, verbose=1)
+# Get the cluster of the highest probability.
+cluster_assignments = keras.ops.argmax(clustering_probs, axis=-1).numpy()
+# Store the clustering confidence.
+# Images with the highest clustering confidence are considered the 'prototypes'
+# of the clusters.
+cluster_confidence = keras.ops.max(clustering_probs, axis=-1).numpy()
+
+"""
+Let's compute the cluster sizes
+"""
+
+clusters = defaultdict(list)
+for idx, c in enumerate(cluster_assignments):
+ clusters[c].append((idx, cluster_confidence[idx]))
+
+non_empty_clusters = defaultdict(list)
+for c in clusters.keys():
+ if clusters[c]:
+ non_empty_clusters[c] = clusters[c]
+
+for c in range(num_clusters):
+ print("cluster", c, ":", len(clusters[c]))
+
+"""
+### Visualize cluster images
+
+Display the *prototypes* of each cluster, i.e. the instances with the highest clustering confidence:
+"""
+
+num_images = 8
+plt.figure(figsize=(15, 15))
+position = 1
+for c in non_empty_clusters.keys():
+ cluster_instances = sorted(
+ non_empty_clusters[c], key=lambda kv: kv[1], reverse=True
+ )
+
+ for j in range(num_images):
+ image_idx = cluster_instances[j][0]
+ plt.subplot(len(non_empty_clusters), num_images, position)
+ plt.imshow(x_data[image_idx].astype("uint8"))
+ plt.title(classes[y_data[image_idx][0]])
+ plt.axis("off")
+ position += 1
+
+"""
+### Compute clustering accuracy
+
+First, we assign a label for each cluster based on the majority label of its images.
+Then, we compute the accuracy of each cluster by dividing the number of images with the
+majority label by the size of the cluster.
+"""
+
+cluster_label_counts = dict()
+
+for c in range(num_clusters):
+ cluster_label_counts[c] = [0] * num_classes
+ instances = clusters[c]
+ for i, _ in instances:
+ cluster_label_counts[c][y_data[i][0]] += 1
+
+ cluster_label_idx = np.argmax(cluster_label_counts[c])
+ correct_count = np.max(cluster_label_counts[c])
+ cluster_size = len(clusters[c])
+ accuracy = (
+ np.round((correct_count / cluster_size) * 100, 2) if cluster_size > 0 else 0
+ )
+ cluster_label = classes[cluster_label_idx]
+ print("cluster", c, "label is:", cluster_label, " - accuracy:", accuracy, "%")
+
+"""
+## Conclusion
+
+To improve the accuracy results, you can: 1) increase the number
+of epochs in the representation learning and the clustering phases; 2)
+allow the encoder weights to be tuned during the clustering phase; and 3) perform a final
+fine-tuning step through self-labeling, as described in the [original SCAN paper](https://arxiv.org/abs/2005.12320).
+Note that unsupervised image clustering techniques are not expected to outperform
+supervised image classification techniques; rather, the goal is to show that they can
+learn the semantics of the images and group them into clusters that are similar to
+their original classes.
+"""
diff --git a/knowledge_base/vision/semisupervised_simclr.py b/knowledge_base/vision/semisupervised_simclr.py
new file mode 100644
index 0000000000000000000000000000000000000000..c398210962cfac708e763c6d88e2a0670a63fbbc
--- /dev/null
+++ b/knowledge_base/vision/semisupervised_simclr.py
@@ -0,0 +1,691 @@
+"""
+Title: Semi-supervised image classification using contrastive pretraining with SimCLR
+Author: [András Béres](https://www.linkedin.com/in/andras-beres-789190210)
+Date created: 2021/04/24
+Last modified: 2024/03/04
+Description: Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+### Semi-supervised learning
+
+Semi-supervised learning is a machine learning paradigm that deals with
+**partially labeled datasets**. When applying deep learning in the real world,
+one usually has to gather a large dataset to make it work well. However, while
+the cost of labeling scales linearly with the dataset size (labeling each
+example takes a constant time), model performance only scales
+[sublinearly](https://arxiv.org/abs/2001.08361) with it. This means that
+labeling more and more samples becomes less and less cost-efficient, while
+gathering unlabeled data is generally cheap, as it is usually readily available
+in large quantities.
+
+Semi-supervised learning offers to solve this problem by only requiring a
+partially labeled dataset, and by being label-efficient by utilizing the
+unlabeled examples for learning as well.
+
+In this example, we will pretrain an encoder with contrastive learning on the
+[STL-10](https://ai.stanford.edu/~acoates/stl10/) semi-supervised dataset using
+no labels at all, and then fine-tune it using only its labeled subset.
+
+### Contrastive learning
+
+On the highest level, the main idea behind contrastive learning is to **learn
+representations that are invariant to image augmentations** in a self-supervised
+manner. One problem with this objective is that it has a trivial degenerate
+solution: the case where the representations are constant, and do not depend at all on the
+input images.
+
+Contrastive learning avoids this trap by modifying the objective in the
+following way: it pulls representations of augmented versions/views of the same
+image closer to each other (contracting positives), while simultaneously pushing
+different images away from each other (contrasting negatives) in representation
+space.
+
+One such contrastive approach is [SimCLR](https://arxiv.org/abs/2002.05709),
+which essentially identifies the core components needed to optimize this
+objective, and can achieve high performance by scaling this simple approach.
+
+Another approach is [SimSiam](https://arxiv.org/abs/2011.10566)
+([Keras example](https://keras.io/examples/vision/simsiam/)),
+whose main difference from
+SimCLR is that SimSiam does not use any negatives in its loss. Therefore, it does not
+explicitly prevent the trivial solution, and, instead, avoids it implicitly by
+architecture design (asymmetric encoding paths using a predictor network and
+batch normalization (BatchNorm) are applied in the final layers).
+
+For further reading about SimCLR, check out
+[the official Google AI blog post](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html),
+and for an overview of self-supervised learning across both vision and language
+check out
+[this blog post](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/).
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+
+# Make sure we are able to handle large datasets
+import resource
+
+low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
+
+import math
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import tensorflow_datasets as tfds
+
+import keras
+from keras import ops
+from keras import layers
+
+"""
+## Hyperparameter setup
+"""
+# Dataset hyperparameters
+unlabeled_dataset_size = 100000
+labeled_dataset_size = 5000
+image_channels = 3
+
+# Algorithm hyperparameters
+num_epochs = 20
+batch_size = 525 # Corresponds to 200 steps per epoch
+width = 128
+temperature = 0.1
+# Stronger augmentations for contrastive, weaker ones for supervised training
+contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
+classification_augmentation = {
+ "min_area": 0.75,
+ "brightness": 0.3,
+ "jitter": 0.1,
+}
+
+"""
+## Dataset
+
+During training we will simultaneously load a large batch of unlabeled images along with a
+smaller batch of labeled images.
+"""
+
+
+def prepare_dataset():
+ # Labeled and unlabeled samples are loaded synchronously
+ # with batch sizes selected accordingly
+ steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
+ unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
+ labeled_batch_size = labeled_dataset_size // steps_per_epoch
+ print(
+ f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
+ )
+
+ # Turning off shuffle to lower resource usage
+ unlabeled_train_dataset = (
+ tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
+ .shuffle(buffer_size=10 * unlabeled_batch_size)
+ .batch(unlabeled_batch_size)
+ )
+ labeled_train_dataset = (
+ tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
+ .shuffle(buffer_size=10 * labeled_batch_size)
+ .batch(labeled_batch_size)
+ )
+ test_dataset = (
+ tfds.load("stl10", split="test", as_supervised=True)
+ .batch(batch_size)
+ .prefetch(buffer_size=tf.data.AUTOTUNE)
+ )
+
+ # Labeled and unlabeled datasets are zipped together
+ train_dataset = tf.data.Dataset.zip(
+ (unlabeled_train_dataset, labeled_train_dataset)
+ ).prefetch(buffer_size=tf.data.AUTOTUNE)
+
+ return train_dataset, labeled_train_dataset, test_dataset
+
+
+# Load STL10 dataset
+train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
+
+"""
+## Image augmentations
+
+The two most important image augmentations for contrastive learning are the
+following:
+
+- Cropping: forces the model to encode different parts of the same image
+similarly; we implement it with the
+[RandomTranslation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_translation/)
+and
+[RandomZoom](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_zoom/)
+layers.
+- Color jitter: prevents a trivial color histogram-based solution to the task by
+distorting color histograms. A principled way to implement that is by affine
+transformations in color space.
+
+In this example we use random horizontal flips as well. Stronger augmentations
+are applied for contrastive learning, along with weaker ones for supervised
+classification to avoid overfitting on the few labeled examples.
+
+We implement random color jitter as a custom preprocessing layer. Using
+preprocessing layers for data augmentation has the following two advantages:
+
+- The data augmentation will run on GPU in batches, so the training will not be
+bottlenecked by the data pipeline in environments with constrained CPU
+resources (such as a Colab Notebook, or a personal machine)
+- Deployment is easier as the data preprocessing pipeline is encapsulated in the
+model, and does not have to be reimplemented when deploying it
+"""
+
+
+# Distorts the color distributions of images
+class RandomColorAffine(layers.Layer):
+ def __init__(self, brightness=0, jitter=0, **kwargs):
+ super().__init__(**kwargs)
+
+ self.seed_generator = keras.random.SeedGenerator(1337)
+ self.brightness = brightness
+ self.jitter = jitter
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"brightness": self.brightness, "jitter": self.jitter})
+ return config
+
+ def call(self, images, training=True):
+ if training:
+ batch_size = ops.shape(images)[0]
+
+ # Same for all colors
+ brightness_scales = 1 + keras.random.uniform(
+ (batch_size, 1, 1, 1),
+ minval=-self.brightness,
+ maxval=self.brightness,
+ seed=self.seed_generator,
+ )
+ # Different for all colors
+ jitter_matrices = keras.random.uniform(
+ (batch_size, 1, 3, 3),
+ minval=-self.jitter,
+ maxval=self.jitter,
+ seed=self.seed_generator,
+ )
+
+ color_transforms = (
+ ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
+ * brightness_scales
+ + jitter_matrices
+ )
+ images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
+ return images
+
+
+# Image augmentation module
+def get_augmenter(min_area, brightness, jitter):
+ zoom_factor = 1.0 - math.sqrt(min_area)
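+    # A maximal zoom-in of `zoom_factor` per side leaves a visible region of at
+    # least sqrt(min_area) per side, i.e. `min_area` of the original image area.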
+ return keras.Sequential(
+ [
+ layers.Rescaling(1 / 255),
+ layers.RandomFlip("horizontal"),
+ layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
+ layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
+ RandomColorAffine(brightness, jitter),
+ ]
+ )
+
+
+def visualize_augmentations(num_images):
+ # Sample a batch from a dataset
+ images = next(iter(train_dataset))[0][0][:num_images]
+
+ # Apply augmentations
+ augmented_images = zip(
+ images,
+ get_augmenter(**classification_augmentation)(images),
+ get_augmenter(**contrastive_augmentation)(images),
+ get_augmenter(**contrastive_augmentation)(images),
+ )
+ row_titles = [
+ "Original:",
+ "Weakly augmented:",
+ "Strongly augmented:",
+ "Strongly augmented:",
+ ]
+ plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
+ for column, image_row in enumerate(augmented_images):
+ for row, image in enumerate(image_row):
+ plt.subplot(4, num_images, row * num_images + column + 1)
+ plt.imshow(image)
+ if column == 0:
+ plt.title(row_titles[row], loc="left")
+ plt.axis("off")
+ plt.tight_layout()
+
+
+visualize_augmentations(num_images=8)
+
+"""
+## Encoder architecture
+"""
+
+
+# Define the encoder architecture
+def get_encoder():
+ return keras.Sequential(
+ [
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
+ layers.Flatten(),
+ layers.Dense(width, activation="relu"),
+ ],
+ name="encoder",
+ )
+
+
+"""
+## Supervised baseline model
+
+A baseline supervised model is trained using random initialization.
+"""
+
+# Baseline supervised training with random initialization
+baseline_model = keras.Sequential(
+ [
+ get_augmenter(**classification_augmentation),
+ get_encoder(),
+ layers.Dense(10),
+ ],
+ name="baseline_model",
+)
+baseline_model.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
+)
+
+baseline_history = baseline_model.fit(
+ labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
+)
+
+print(
+ "Maximal validation accuracy: {:.2f}%".format(
+ max(baseline_history.history["val_acc"]) * 100
+ )
+)
+
+"""
+## Self-supervised model for contrastive pretraining
+
+We pretrain an encoder on unlabeled images with a contrastive loss.
+A nonlinear projection head is attached to the top of the encoder, as it
+improves the quality of representations of the encoder.
+
+We use the InfoNCE/NT-Xent/N-pairs loss, which can be interpreted in the
+following way:
+
+1. We treat each image in the batch as if it had its own class.
+2. Then, we have two examples (a pair of augmented views) for each "class".
+3. Each view's representation is compared to the representations of all views
+   from the other augmented set (and the loss is symmetrized over both sets).
+4. We use the temperature-scaled cosine similarity of compared representations as
+ logits.
+5. Finally, we use categorical cross-entropy as the "classification" loss
+
+The following two metrics are used for monitoring the pretraining performance:
+
+- [Contrastive accuracy (SimCLR Table 5)](https://arxiv.org/abs/2002.05709):
+Self-supervised metric, the ratio of cases in which the representation of an
+image is more similar to its differently augmented version's one, than to the
+representation of any other image in the current batch. Self-supervised
+metrics can be used for hyperparameter tuning even in the case when there are
+no labeled examples.
+- [Linear probing accuracy](https://arxiv.org/abs/1603.08511): Linear probing is
+a popular metric to evaluate self-supervised classifiers. It is computed as
+the accuracy of a logistic regression classifier trained on top of the
+encoder's features. In our case, this is done by training a single dense layer
+on top of the frozen encoder. Note that, contrary to the traditional approach where
+the classifier is trained after the pretraining phase, in this example we
+train it during pretraining. This might slightly decrease its accuracy, but
+that way we can monitor its value during training, which helps with
+experimentation and debugging.
+
+Another widely used supervised metric is the
+[KNN accuracy](https://arxiv.org/abs/1805.01978), which is the accuracy of a KNN
+classifier trained on top of the encoder's features, which is not implemented in
+this example.
+"""
+
+
+# Define the contrastive model with model-subclassing
+class ContrastiveModel(keras.Model):
+ def __init__(self):
+ super().__init__()
+
+ self.temperature = temperature
+ self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
+ self.classification_augmenter = get_augmenter(**classification_augmentation)
+ self.encoder = get_encoder()
+ # Non-linear MLP as projection head
+ self.projection_head = keras.Sequential(
+ [
+ keras.Input(shape=(width,)),
+ layers.Dense(width, activation="relu"),
+ layers.Dense(width),
+ ],
+ name="projection_head",
+ )
+ # Single dense layer for linear probing
+ self.linear_probe = keras.Sequential(
+ [layers.Input(shape=(width,)), layers.Dense(10)],
+ name="linear_probe",
+ )
+
+ self.encoder.summary()
+ self.projection_head.summary()
+ self.linear_probe.summary()
+
+ def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
+ super().compile(**kwargs)
+
+ self.contrastive_optimizer = contrastive_optimizer
+ self.probe_optimizer = probe_optimizer
+
+ # self.contrastive_loss will be defined as a method
+ self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+ self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
+ self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
+ name="c_acc"
+ )
+ self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
+ self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")
+
+ @property
+ def metrics(self):
+ return [
+ self.contrastive_loss_tracker,
+ self.contrastive_accuracy,
+ self.probe_loss_tracker,
+ self.probe_accuracy,
+ ]
+
+ def contrastive_loss(self, projections_1, projections_2):
+ # InfoNCE loss (information noise-contrastive estimation)
+ # NT-Xent loss (normalized temperature-scaled cross entropy)
+
+ # Cosine similarity: the dot product of the l2-normalized feature vectors
+ projections_1 = ops.normalize(projections_1, axis=1)
+ projections_2 = ops.normalize(projections_2, axis=1)
+ similarities = (
+ ops.matmul(projections_1, ops.transpose(projections_2)) / self.temperature
+ )
+
+ # The similarity between the representations of two augmented views of the
+ # same image should be higher than their similarity with other views
+ batch_size = ops.shape(projections_1)[0]
+ contrastive_labels = ops.arange(batch_size)
+ self.contrastive_accuracy.update_state(contrastive_labels, similarities)
+ self.contrastive_accuracy.update_state(
+ contrastive_labels, ops.transpose(similarities)
+ )
+
+ # The temperature-scaled similarities are used as logits for cross-entropy
+ # a symmetrized version of the loss is used here
+ loss_1_2 = keras.losses.sparse_categorical_crossentropy(
+ contrastive_labels, similarities, from_logits=True
+ )
+ loss_2_1 = keras.losses.sparse_categorical_crossentropy(
+ contrastive_labels, ops.transpose(similarities), from_logits=True
+ )
+ return (loss_1_2 + loss_2_1) / 2
+
+ def train_step(self, data):
+ (unlabeled_images, _), (labeled_images, labels) = data
+
+ # Both labeled and unlabeled images are used, without labels
+ images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
+ # Each image is augmented twice, differently
+ augmented_images_1 = self.contrastive_augmenter(images, training=True)
+ augmented_images_2 = self.contrastive_augmenter(images, training=True)
+ with tf.GradientTape() as tape:
+ features_1 = self.encoder(augmented_images_1, training=True)
+ features_2 = self.encoder(augmented_images_2, training=True)
+ # The representations are passed through a projection mlp
+ projections_1 = self.projection_head(features_1, training=True)
+ projections_2 = self.projection_head(features_2, training=True)
+ contrastive_loss = self.contrastive_loss(projections_1, projections_2)
+ gradients = tape.gradient(
+ contrastive_loss,
+ self.encoder.trainable_weights + self.projection_head.trainable_weights,
+ )
+ self.contrastive_optimizer.apply_gradients(
+ zip(
+ gradients,
+ self.encoder.trainable_weights + self.projection_head.trainable_weights,
+ )
+ )
+ self.contrastive_loss_tracker.update_state(contrastive_loss)
+
+        # Labels are only used in evaluation for an on-the-fly logistic regression
+ preprocessed_images = self.classification_augmenter(
+ labeled_images, training=True
+ )
+ with tf.GradientTape() as tape:
+            # The encoder is used in inference mode here to avoid regularization
+            # and updating the batch normalization parameters, if they are used.
+ features = self.encoder(preprocessed_images, training=False)
+ class_logits = self.linear_probe(features, training=True)
+ probe_loss = self.probe_loss(labels, class_logits)
+ gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
+ self.probe_optimizer.apply_gradients(
+ zip(gradients, self.linear_probe.trainable_weights)
+ )
+ self.probe_loss_tracker.update_state(probe_loss)
+ self.probe_accuracy.update_state(labels, class_logits)
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, data):
+ labeled_images, labels = data
+
+ # For testing the components are used with a training=False flag
+ preprocessed_images = self.classification_augmenter(
+ labeled_images, training=False
+ )
+ features = self.encoder(preprocessed_images, training=False)
+ class_logits = self.linear_probe(features, training=False)
+ probe_loss = self.probe_loss(labels, class_logits)
+ self.probe_loss_tracker.update_state(probe_loss)
+ self.probe_accuracy.update_state(labels, class_logits)
+
+ # Only the probe metrics are logged at test time
+ return {m.name: m.result() for m in self.metrics[2:]}
+
+
+# Contrastive pretraining
+pretraining_model = ContrastiveModel()
+pretraining_model.compile(
+ contrastive_optimizer=keras.optimizers.Adam(),
+ probe_optimizer=keras.optimizers.Adam(),
+)
+
+pretraining_history = pretraining_model.fit(
+ train_dataset, epochs=num_epochs, validation_data=test_dataset
+)
+print(
+ "Maximal validation accuracy: {:.2f}%".format(
+ max(pretraining_history.history["val_p_acc"]) * 100
+ )
+)
+
+"""
+## Supervised finetuning of the pretrained encoder
+
+We then finetune the encoder on the labeled examples, by attaching
+a single randomly initialized fully connected classification layer on top.
+"""
+
+# Supervised finetuning of the pretrained encoder
+finetuning_model = keras.Sequential(
+ [
+ get_augmenter(**classification_augmentation),
+ pretraining_model.encoder,
+ layers.Dense(10),
+ ],
+ name="finetuning_model",
+)
+finetuning_model.compile(
+ optimizer=keras.optimizers.Adam(),
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
+)
+
+finetuning_history = finetuning_model.fit(
+ labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
+)
+print(
+ "Maximal validation accuracy: {:.2f}%".format(
+ max(finetuning_history.history["val_acc"]) * 100
+ )
+)
+
+"""
+## Comparison against the baseline
+"""
+
+
+# The classification accuracies of the baseline and the pretraining + finetuning process:
+def plot_training_curves(pretraining_history, finetuning_history, baseline_history):
+ for metric_key, metric_name in zip(["acc", "loss"], ["accuracy", "loss"]):
+ plt.figure(figsize=(8, 5), dpi=100)
+ plt.plot(
+ baseline_history.history[f"val_{metric_key}"],
+ label="supervised baseline",
+ )
+ plt.plot(
+ pretraining_history.history[f"val_p_{metric_key}"],
+ label="self-supervised pretraining",
+ )
+ plt.plot(
+ finetuning_history.history[f"val_{metric_key}"],
+ label="supervised finetuning",
+ )
+ plt.legend()
+ plt.title(f"Classification {metric_name} during training")
+ plt.xlabel("epochs")
+ plt.ylabel(f"validation {metric_name}")
+
+
+plot_training_curves(pretraining_history, finetuning_history, baseline_history)
+
+"""
+By comparing the training curves, we can see that when using contrastive
+pretraining, a higher validation accuracy can be reached, paired with a lower
+validation loss, which means that the pretrained network was able to generalize
+better when seeing only a small amount of labeled examples.
+"""
+
+"""
+## Improving further
+
+### Architecture
+
+The experiment in the original paper demonstrated that increasing the width and depth of the
+models improves performance at a higher rate than for supervised learning. Also,
+using a [ResNet-50](https://keras.io/api/applications/resnet/#resnet50-function)
+encoder is quite standard in the literature. However, keep in mind that more
+powerful models will not only increase training time but will also require more
+memory and will limit the maximal batch size you can use.
+
+It has [been](https://arxiv.org/abs/1905.09272)
+[reported](https://arxiv.org/abs/1911.05722) that the usage of BatchNorm layers
+could sometimes degrade performance, as it introduces an intra-batch dependency
+between samples, which is why I did not use them in this example. In my
+experiments, however, using BatchNorm, especially in the projection head,
+improves performance.
+
+### Hyperparameters
+
+The hyperparameters used in this example have been tuned manually for this task and
+architecture. Therefore, without changing them, only marginal gains can be expected
+from further hyperparameter tuning.
+
+However for a different task or model architecture these would need tuning, so
+here are my notes on the most important ones:
+
+- **Batch size**: since the objective can be interpreted as a classification
+over a batch of images (loosely speaking), the batch size is actually a more
+important hyperparameter than usual. The higher, the better.
+- **Temperature**: the temperature defines the "softness" of the softmax
+distribution that is used in the cross-entropy loss, and is an important
+hyperparameter. Lower values generally lead to a higher contrastive accuracy.
+A recent trick (in [ALIGN](https://arxiv.org/abs/2102.05918)) is to learn
+the temperature's value as well (which can be done by defining it as a
+`tf.Variable` and applying gradients to it; see the sketch after this section).
+Even though this provides a good baseline value, in my experiments the learned
+temperature was somewhat lower than optimal, as it is optimized with respect to
+the contrastive loss, which is not a perfect proxy for representation quality.
+- **Image augmentation strength**: during pretraining stronger augmentations
+increase the difficulty of the task, however after a point too strong
+augmentations will degrade performance. During finetuning stronger
+augmentations reduce overfitting while in my experience too strong
+augmentations decrease the performance gains from pretraining. The whole data
+augmentation pipeline can be seen as an important hyperparameter of the
+algorithm, implementations of other custom image augmentation layers in Keras
+can be found in
+[this repository](https://github.com/beresandras/image-augmentation-layers-keras).
+- **Learning rate schedule**: a constant schedule is used here, but it is
+quite common in the literature to use a
+[cosine decay schedule](https://www.tensorflow.org/api_docs/python/tf/keras/experimental/CosineDecay),
+which can further improve performance.
+- **Optimizer**: Adam is used in this example, as it provides good performance
+with default parameters. SGD with momentum requires more tuning, however it
+could slightly increase performance.
+"""
+
+"""
+## Related works
+
+Other instance-level (image-level) contrastive learning methods:
+
+- [MoCo](https://arxiv.org/abs/1911.05722)
+([v2](https://arxiv.org/abs/2003.04297),
+[v3](https://arxiv.org/abs/2104.02057)): uses a momentum-encoder as well,
+whose weights are an exponential moving average of the target encoder
+- [SwAV](https://arxiv.org/abs/2006.09882): uses clustering instead of pairwise
+comparison
+- [BarlowTwins](https://arxiv.org/abs/2103.03230): uses a cross
+correlation-based objective instead of pairwise comparison
+
+Keras implementations of **MoCo** and **BarlowTwins** can be found in
+[this repository](https://github.com/beresandras/contrastive-classification-keras),
+which includes a Colab notebook.
+
+There is also a new line of works, which optimize a similar objective, but
+without the use of any negatives:
+
+- [BYOL](https://arxiv.org/abs/2006.07733): momentum-encoder + no negatives
+- [SimSiam](https://arxiv.org/abs/2011.10566)
+([Keras example](https://keras.io/examples/vision/simsiam/)):
+no momentum-encoder + no negatives
+
+In my experience, these methods are more brittle (they can collapse to a constant
+representation, I could not get them to work using this encoder architecture).
+Even though they are generally more dependent on the
+[model](https://generallyintelligent.ai/understanding-self-supervised-contrastive-learning.html)
+[architecture](https://arxiv.org/abs/2010.10241), they can improve
+performance at smaller batch sizes.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/semi-supervised-classification-simclr)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/semi-supervised-classification).
+"""
diff --git a/knowledge_base/vision/shiftvit.py b/knowledge_base/vision/shiftvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47ddc0785b47d085e3d7252ae089822abf0cf2d
--- /dev/null
+++ b/knowledge_base/vision/shiftvit.py
@@ -0,0 +1,1053 @@
+"""
+Title: A Vision Transformer without Attention
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
+Date created: 2022/02/24
+Last modified: 2024/12/06
+Description: A minimal implementation of ShiftViT.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+[Vision Transformers](https://arxiv.org/abs/2010.11929) (ViTs) have sparked a wave of
+research at the intersection of Transformers and Computer Vision (CV).
+
+ViTs can simultaneously model long- and short-range dependencies, thanks to
+the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe
+that the success of ViTs is purely due to the attention layer, and they seldom
+think about other parts of the ViT model.
+
+In the academic paper
+[When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism](https://arxiv.org/abs/2201.10801)
+the authors propose to demystify the success of ViTs by introducing a **NO
+PARAMETER** operation in place of the attention operation. They replace the
+attention operation with a shifting operation.
+
+In this example, we minimally implement the paper with close alignment to the authors'
+[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).
+
+This example requires TensorFlow 2.9 or higher.
+"""
+
+"""
+## Setup and imports
+"""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
+
+import pathlib
+import glob
+
+# Setting the seed for reproducibility
+SEED = 42
+keras.utils.set_random_seed(SEED)
+
+"""
+## Hyperparameters
+
+These are the hyperparameters that we have chosen for the experiment.
+Please feel free to tune them.
+"""
+
+
+class Config(object):
+ # DATA
+ batch_size = 256
+ buffer_size = batch_size * 2
+ input_shape = (32, 32, 3)
+ num_classes = 10
+
+ # AUGMENTATION
+ image_size = 48
+
+ # ARCHITECTURE
+ patch_size = 4
+ projected_dim = 96
+ num_shift_blocks_per_stages = [2, 4, 8, 2]
+ epsilon = 1e-5
+ stochastic_depth_rate = 0.2
+ mlp_dropout_rate = 0.2
+ num_div = 12
+ shift_pixel = 1
+ mlp_expand_ratio = 2
+
+ # OPTIMIZER
+ lr_start = 1e-5
+ lr_max = 1e-3
+ weight_decay = 1e-4
+
+ # TRAINING
+ epochs = 100
+
+ # INFERENCE
+ label_map = {
+ 0: "airplane",
+ 1: "automobile",
+ 2: "bird",
+ 3: "cat",
+ 4: "deer",
+ 5: "dog",
+ 6: "frog",
+ 7: "horse",
+ 8: "ship",
+ 9: "truck",
+ }
+ tf_ds_batch_size = 20
+
+
+config = Config()
+
+"""
+## Load the CIFAR-10 dataset
+
+We use the CIFAR-10 dataset for our experiments.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+AUTO = tf.data.AUTOTUNE
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
+
+val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)
+
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
+
+"""
+## Data Augmentation
+
+The augmentation pipeline consists of:
+
+- Rescaling
+- Resizing
+- Random cropping
+- Random horizontal flipping
+
+_Note_: The image data augmentation layers do not apply
+data transformations at inference time. This means that
+when these layers are called with `training=False` they
+behave differently. Refer to the
+[documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/)
+for more details.
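+
+As an illustrative check (a hypothetical snippet, not part of the training
+pipeline), you can call the augmentation model built below on the same batch
+with both flags and compare the outputs:
+
+```python
+augmenter = get_augmentation_model()
+images = x_train[:4].astype("float32")
+augmented = augmenter(images, training=True)  # random crop and flip are applied
+deterministic = augmenter(images, training=False)  # random layers act deterministically
+```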
+"""
+
+
+def get_augmentation_model():
+ """Build the data augmentation model."""
+ data_augmentation = keras.Sequential(
+ [
+ layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
+ layers.RandomCrop(config.image_size, config.image_size),
+ layers.RandomFlip("horizontal"),
+ layers.Rescaling(1 / 255.0),
+ ]
+ )
+ return data_augmentation
+
+
+"""
+## The ShiftViT architecture
+
+In this section, we build the architecture proposed in
+[the ShiftViT paper](https://arxiv.org/abs/2201.10801).
+
+|  |
+| :--: |
+| Figure 1: The entire architecture of ShiftViT.
+[Source](https://arxiv.org/abs/2201.10801) |
+
+The architecture as shown in Fig. 1, is inspired by
+[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030).
+Here the authors propose a modular architecture with 4 stages. Each stage works on its
+own spatial size, creating a hierarchical architecture.
+
+An input image of size `HxWx3` is split into non-overlapping patches of size `4x4`.
+This is done via the patchify layer, which results in individual tokens of feature size `48`
+(`4x4x3`). For the `48x48x3` augmented images used here, this yields a `12x12` grid of tokens.
+Each stage comprises two parts:
+
+1. Embedding Generation
+2. Stacked Shift Blocks
+
+We discuss the stages and the modules in detail in what follows.
+
+_Note_: Compared to the [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py)
+we restructure some key components to better fit the Keras API.
+"""
+
+"""
+### The ShiftViT Block
+
+|  |
+| :--: |
+| Figure 2: From the Model to a Shift Block. |
+
+Each stage in the ShiftViT architecture comprises a Shift Block as shown in Fig. 2.
+
+|  |
+| :--: |
+| Figure 3: The Shift ViT Block. [Source](https://arxiv.org/abs/2201.10801) |
+
+The Shift Block, as shown in Fig. 3, comprises the following:
+
+1. Shift Operation
+2. Layer Normalization
+3. MLP Layer
+"""
+
+"""
+#### The MLP block
+
+The MLP block is intended to be a stack of densely-connected layers.
+"""
+
+
+class MLP(layers.Layer):
+ """Get the MLP layer for each shift block.
+
+ Args:
+ mlp_expand_ratio (int): The ratio with which the first feature map is expanded.
+ mlp_dropout_rate (float): The rate for dropout.
+ """
+
+ def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
+ super().__init__(**kwargs)
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+
+ def build(self, input_shape):
+ input_channels = input_shape[-1]
+ initial_filters = int(self.mlp_expand_ratio * input_channels)
+
+ self.mlp = keras.Sequential(
+ [
+ layers.Dense(
+ units=initial_filters,
+ activation="gelu",
+ ),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ layers.Dense(units=input_channels),
+ layers.Dropout(rate=self.mlp_dropout_rate),
+ ]
+ )
+
+ def call(self, x):
+ x = self.mlp(x)
+ return x
+
+
+"""
+#### The DropPath layer
+
+Stochastic depth is a regularization technique that randomly drops a set of
+layers. During inference, the layers are kept as they are. It is very
+similar to Dropout, but it operates on a block of layers rather
+than on individual nodes present inside a layer.
+"""
+
+
+class DropPath(layers.Layer):
+ """Drop Path also known as the Stochastic Depth layer.
+
+ Refernece:
+ - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
+ - github.com:rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path_prob, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path_prob = drop_path_prob
+ self.seed_generator = keras.random.SeedGenerator(1337)
+
+ def call(self, x, training=False):
+ if training:
+ keep_prob = 1 - self.drop_path_prob
+ shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
+ random_tensor = keep_prob + keras.random.uniform(
+ shape, 0, 1, seed=self.seed_generator
+ )
+ random_tensor = ops.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+"""
+#### Block
+
+The most important operation in this paper is the **shift operation**. In this section,
+we describe the shift operation and compare it with its original implementation provided
+by the authors.
+
+A generic feature map is assumed to have the shape `[N, H, W, C]`. Here we choose a
+`num_div` parameter that decides the number of channel divisions. The first 4 divisions
+are shifted (by 1 pixel) in the left, right, up, and down directions. The remaining splits
+are kept as is. After partial shifting, the shifted channels are padded and the overflowing
+pixels are cropped off. This completes the partial shifting operation.
+
+In the original implementation, the code is approximately:
+
+```python
+out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:] # shift left
+out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1] # shift right
+out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :] # shift up
+out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :] # shift down
+
+out[:, g * 4:, :, :] = x[:, g * 4:, :, :] # no shift
+```
+
+In TensorFlow, tensors are immutable, so we cannot assign shifted values into slices of a
+tensor in the middle of the training process. This is why we resort to the following
+procedure:
+
+1. Split the channels with the `num_div` parameter.
+2. Select each of the first four splits and shift and pad them in the respective
+directions.
+3. After shifting and padding, concatenate the channels back together.
+
+|  |
+| :--: |
+| Figure 4: The TensorFlow style shifting |
+
+The entire procedure is illustrated in Fig. 4.
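+
+As a toy illustration of the idea (not part of the model code), here is the
+"shift left by one pixel" operation written with plain NumPy slicing; the
+`ShiftViTBlock` below implements the same idea using crop and pad operations:
+
+```python
+import numpy as np
+
+x = np.arange(16, dtype="float32").reshape(1, 4, 4, 1)  # [N, H, W, C]
+out = np.zeros_like(x)
+# Content moves one column to the left; zeros fill the vacated right edge.
+out[:, :, :-1, :] = x[:, :, 1:, :]
+```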
+"""
+
+
+class ShiftViTBlock(layers.Layer):
+ """A unit ShiftViT Block
+
+ Args:
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio by which the MLP features are
+ expanded. Defaults to 2.
+ mlp_dropout_rate (float): The dropout rate used in the MLP.
+ num_div (int): The number of divisions of the feature map's channels.
+ In total, 4/num_div of the channels will be shifted. Defaults to 12.
+ epsilon (float): Epsilon constant.
+ drop_path_prob (float): The drop probability for drop path.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ drop_path_prob,
+ mlp_dropout_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_div = num_div
+ self.epsilon = epsilon
+ self.drop_path_prob = drop_path_prob
+
+ def build(self, input_shape):
+ self.H = input_shape[1]
+ self.W = input_shape[2]
+ self.C = input_shape[3]
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+ self.drop_path = (
+ DropPath(drop_path_prob=self.drop_path_prob)
+ if self.drop_path_prob > 0.0
+ else layers.Activation("linear")
+ )
+ self.mlp = MLP(
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ )
+
+ def get_shift_pad(self, x, mode):
+ """Shifts the channels according to the mode chosen."""
+ if mode == "left":
+ offset_height = 0
+ offset_width = 0
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "right":
+ offset_height = 0
+ offset_width = self.shift_pixel
+ target_height = 0
+ target_width = self.shift_pixel
+ elif mode == "up":
+ offset_height = 0
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ else:
+ offset_height = self.shift_pixel
+ offset_width = 0
+ target_height = self.shift_pixel
+ target_width = 0
+ crop = ops.image.crop_images(
+ x,
+ top_cropping=offset_height,
+ left_cropping=offset_width,
+ target_height=self.H - target_height,
+ target_width=self.W - target_width,
+ )
+ shift_pad = ops.image.pad_images(
+ crop,
+ top_padding=offset_height,
+ left_padding=offset_width,
+ target_height=self.H,
+ target_width=self.W,
+ )
+ return shift_pad
+
+ def call(self, x, training=False):
+ # Split the feature maps into `num_div` groups along the channel axis,
+ # so that 4/num_div of the channels are shifted (as stated in the docstring)
+ x_splits = ops.split(x, indices_or_sections=self.num_div, axis=-1)
+
+ # Shift the feature maps
+ x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
+ x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
+ x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
+ x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
+
+ # Concatenate the shifted and unshifted feature maps
+ x = ops.concatenate(x_splits, axis=-1)
+
+ # Add the residual connection
+ shortcut = x
+ x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
+ return x
+
+
+"""
+### The ShiftViT blocks
+
+|  |
+| :--: |
+| Figure 5: Shift Blocks in the architecture. [Source](https://arxiv.org/abs/2201.10801) |
+
+Each stage of the architecture has shift blocks as shown in Fig. 5. Each of these blocks
+contains a variable number of stacked ShiftViT blocks (as built in the earlier section).
+
+Shift blocks are followed by a PatchMerging layer that scales down feature inputs. The
+PatchMerging layer helps in the pyramidal structure of the model.
+"""
+
+"""
+#### The PatchMerging layer
+
+This layer merges adjacent tokens (a `2x2` window at a time). It helps in scaling
+the features down spatially while expanding them channel-wise: for example, a
+`12x12x96` feature map becomes `6x6x192` after merging. We use a Conv2D layer to
+merge the patches.
+"""
+
+
+class PatchMerging(layers.Layer):
+ """The Patch Merging layer.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ """
+
+ def __init__(self, epsilon, **kwargs):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ filters = 2 * input_shape[-1]
+ self.reduction = layers.Conv2D(
+ filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
+ )
+ self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
+
+ def call(self, x):
+ # Apply the patch merging algorithm on the feature maps
+ x = self.layer_norm(x)
+ x = self.reduction(x)
+ return x
+
+
+"""
+#### Stacked Shift Blocks
+
+Each stage will have a variable number of stacked ShiftViT Blocks, as suggested in
+the paper. This is a generic layer that contains the stacked ShiftViT blocks
+together with the patch merging layer. Combining the two operations (ShiftViT
+block and patch merging) is a design choice we picked for better code reusability.
+"""
+
+
+# Note: This layer will have a different depth of stacking
+# for different stages of the model.
+class StackedShiftBlocks(layers.Layer):
+ """The layer containing stacked ShiftViTBlocks.
+
+ Args:
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ num_shift_blocks (int): The number of shift vit blocks for this stage.
+ stochastic_depth_rate (float): The maximum drop path rate chosen.
+ is_merge (boolean): A flag that determines the use of the Patch Merge
+ layer after the shift vit blocks.
+ num_div (int): The division of channels of the feature map. Defaults to 12.
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio by which the initial dense layer of
+ the MLP is expanded. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ epsilon,
+ mlp_dropout_rate,
+ num_shift_blocks,
+ stochastic_depth_rate,
+ is_merge,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+ self.mlp_dropout_rate = mlp_dropout_rate
+ self.num_shift_blocks = num_shift_blocks
+ self.stochastic_depth_rate = stochastic_depth_rate
+ self.is_merge = is_merge
+ self.num_div = num_div
+ self.shift_pixel = shift_pixel
+ self.mlp_expand_ratio = mlp_expand_ratio
+
+ def build(self, input_shapes):
+ # Calculate stochastic depth probabilities.
+ # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
+ dpr = np.linspace(
+ start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
+ ).tolist()
+
+ # Build the shift blocks as a list of ShiftViT Blocks
+ self.shift_blocks = list()
+ for num in range(self.num_shift_blocks):
+ self.shift_blocks.append(
+ ShiftViTBlock(
+ num_div=self.num_div,
+ epsilon=self.epsilon,
+ drop_path_prob=dpr[num],
+ mlp_dropout_rate=self.mlp_dropout_rate,
+ shift_pixel=self.shift_pixel,
+ mlp_expand_ratio=self.mlp_expand_ratio,
+ )
+ )
+ if self.is_merge:
+ self.patch_merge = PatchMerging(epsilon=self.epsilon)
+
+ def call(self, x, training=False):
+ for shift_block in self.shift_blocks:
+ x = shift_block(x, training=training)
+ if self.is_merge:
+ x = self.patch_merge(x)
+ return x
+
+ # Since this is a custom layer, we need to overwrite get_config()
+ # so that the model can be easily saved & loaded after training
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "epsilon": self.epsilon,
+ "mlp_dropout_rate": self.mlp_dropout_rate,
+ "num_shift_blocks": self.num_shift_blocks,
+ "stochastic_depth_rate": self.stochastic_depth_rate,
+ "is_merge": self.is_merge,
+ "num_div": self.num_div,
+ "shift_pixel": self.shift_pixel,
+ "mlp_expand_ratio": self.mlp_expand_ratio,
+ }
+ )
+ return config
+
+
+"""
+## The ShiftViT model
+
+Build the ShiftViT custom model.
+"""
+
+
+class ShiftViTModel(keras.Model):
+ """The ShiftViT Model.
+
+ Args:
+ data_augmentation (keras.Model): A data augmentation model.
+ projected_dim (int): The dimension to which the patches of the image are
+ projected.
+ patch_size (int): The patch size of the images.
+ num_shift_blocks_per_stages (list[int]): A list with the number of shift
+ blocks per stage.
+ epsilon (float): The epsilon constant.
+ mlp_dropout_rate (float): The dropout rate used in the MLP block.
+ stochastic_depth_rate (float): The maximum drop rate probability.
+ num_div (int): The number of divisions of the channels of the feature
+ map. Defaults to 12.
+ shift_pixel (int): The number of pixels to shift. Defaults to 1.
+ mlp_expand_ratio (int): The ratio by which the initial MLP dense layer
+ is expanded. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ data_augmentation,
+ projected_dim,
+ patch_size,
+ num_shift_blocks_per_stages,
+ epsilon,
+ mlp_dropout_rate,
+ stochastic_depth_rate,
+ num_div=12,
+ shift_pixel=1,
+ mlp_expand_ratio=2,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.data_augmentation = data_augmentation
+ self.patch_projection = layers.Conv2D(
+ filters=projected_dim,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="same",
+ )
+ self.stages = list()
+ for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
+ if index == len(num_shift_blocks_per_stages) - 1:
+ # This is the last stage, do not use the patch merge here.
+ is_merge = False
+ else:
+ is_merge = True
+ # Build the stages.
+ self.stages.append(
+ StackedShiftBlocks(
+ epsilon=epsilon,
+ mlp_dropout_rate=mlp_dropout_rate,
+ num_shift_blocks=num_shift_blocks,
+ stochastic_depth_rate=stochastic_depth_rate,
+ is_merge=is_merge,
+ num_div=num_div,
+ shift_pixel=shift_pixel,
+ mlp_expand_ratio=mlp_expand_ratio,
+ )
+ )
+ self.global_avg_pool = layers.GlobalAveragePooling2D()
+
+ self.classifier = layers.Dense(config.num_classes)
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "data_augmentation": self.data_augmentation,
+ "patch_projection": self.patch_projection,
+ "stages": self.stages,
+ "global_avg_pool": self.global_avg_pool,
+ "classifier": self.classifier,
+ }
+ )
+ return config
+
+ def _calculate_loss(self, data, training=False):
+ (images, labels) = data
+
+ # Augment the images
+ augmented_images = self.data_augmentation(images, training=training)
+
+ # Create patches and project the patches.
+ projected_patches = self.patch_projection(augmented_images)
+
+ # Pass through the stages
+ x = projected_patches
+ for stage in self.stages:
+ x = stage(x, training=training)
+
+ # Get the logits.
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
+
+ # Calculate the loss and return it.
+ total_loss = self.compiled_loss(labels, logits)
+ return total_loss, labels, logits
+
+ def train_step(self, inputs):
+ with tf.GradientTape() as tape:
+ total_loss, labels, logits = self._calculate_loss(
+ data=inputs, training=True
+ )
+
+ # Apply gradients.
+ train_vars = [
+ self.data_augmentation.trainable_variables,
+ self.patch_projection.trainable_variables,
+ self.global_avg_pool.trainable_variables,
+ self.classifier.trainable_variables,
+ ]
+ train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
+
+ # Optimize the gradients.
+ grads = tape.gradient(total_loss, train_vars)
+ trainable_variable_list = []
+ for grad, var in zip(grads, train_vars):
+ for g, v in zip(grad, var):
+ trainable_variable_list.append((g, v))
+ self.optimizer.apply_gradients(trainable_variable_list)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def test_step(self, data):
+ _, labels, logits = self._calculate_loss(data=data, training=False)
+
+ # Update the metrics
+ self.compiled_metrics.update_state(labels, logits)
+ return {m.name: m.result() for m in self.metrics}
+
+ def call(self, images):
+ augmented_images = self.data_augmentation(images)
+ x = self.patch_projection(augmented_images)
+ for stage in self.stages:
+ x = stage(x, training=False)
+ x = self.global_avg_pool(x)
+ logits = self.classifier(x)
+ return logits
+
+
+"""
+## Instantiate the model
+"""
+
+model = ShiftViTModel(
+ data_augmentation=get_augmentation_model(),
+ projected_dim=config.projected_dim,
+ patch_size=config.patch_size,
+ num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
+ epsilon=config.epsilon,
+ mlp_dropout_rate=config.mlp_dropout_rate,
+ stochastic_depth_rate=config.stochastic_depth_rate,
+ num_div=config.num_div,
+ shift_pixel=config.shift_pixel,
+ mlp_expand_ratio=config.mlp_expand_ratio,
+)
+
+"""
+## Learning rate schedule
+
+In many experiments, we want to warm up the model with a slowly increasing learning rate
+and then cool down the model with a slowly decaying learning rate. In the warmup cosine
+decay, the learning rate linearly increases for the warmup steps and then decays with a
+cosine decay.
+"""
+
+
+# Some code is taken from:
+# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ """A LearningRateSchedule that uses a warmup cosine decay schedule."""
+
+ def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
+ """
+ Args:
+ lr_start: The initial learning rate
+ lr_max: The maximum learning rate to which lr should increase to in
+ the warmup steps
+ warmup_steps: The number of steps for which the model warms up
+ total_steps: The total number of steps for the model training
+ """
+ super().__init__()
+ self.lr_start = lr_start
+ self.lr_max = lr_max
+ self.warmup_steps = warmup_steps
+ self.total_steps = total_steps
+ self.pi = ops.array(np.pi)
+
+ def __call__(self, step):
+ # Check whether the total number of steps is larger than the warmup
+ # steps. If not, then throw a value error.
+ if self.total_steps < self.warmup_steps:
+ raise ValueError(
+ f"Total number of steps {self.total_steps} must be"
+ + f"larger or equal to warmup steps {self.warmup_steps}."
+ )
+
+ # `cos_annealed_lr` is a graph that increases to 1 from the initial
+ # step to the warmup step. After that this graph decays to -1 at the
+ # final step mark.
+ cos_annealed_lr = ops.cos(
+ self.pi
+ * (ops.cast(step, dtype="float32") - self.warmup_steps)
+ / ops.cast(self.total_steps - self.warmup_steps, dtype="float32")
+ )
+
+ # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
+ # from 0 to 2. Normalize the graph with 0.5 so that now it goes from 0
+ # to 1. Then scale the normalized graph with `lr_max` so that
+ # it goes from 0 to `lr_max`.
+ learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
+
+ # Check whether warmup_steps is more than 0.
+ if self.warmup_steps > 0:
+ # Check whether lr_max is larger than lr_start. If not, throw a value
+ # error.
+ if self.lr_max < self.lr_start:
+ raise ValueError(
+ f"lr_start {self.lr_start} must be smaller than or "
+ + f"equal to lr_max {self.lr_max}."
+ )
+
+ # Calculate the slope with which the learning rate should increase
+ # in the warmup schedule. The formula for the slope is m = ((b - a) / steps).
+ slope = (self.lr_max - self.lr_start) / self.warmup_steps
+
+ # With the formula for a straight line (y = mx+c) build the warmup
+ # schedule
+ warmup_rate = slope * ops.cast(step, dtype="float32") + self.lr_start
+
+ # When the current step is less than the warmup steps, get the line
+ # graph. When the current step is greater than the warmup steps, get
+ # the scaled cos graph.
+ learning_rate = ops.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+
+ # When the current step is more than the total steps, return 0; else return
+ # the calculated graph.
+ return ops.where(step > self.total_steps, 0.0, learning_rate)
+
+ def get_config(self):
+ config = {
+ "lr_start": self.lr_start,
+ "lr_max": self.lr_max,
+ "total_steps": self.total_steps,
+ "warmup_steps": self.warmup_steps,
+ }
+ return config
+
+
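+"""
+As a quick sanity check (illustrative only, not part of the training run), the
+schedule can be probed at a few steps to verify the warmup and decay behavior:
+
+```python
+demo_schedule = WarmUpCosine(
+    lr_start=1e-5, lr_max=1e-3, warmup_steps=15, total_steps=100
+)
+print([round(float(demo_schedule(step)), 6) for step in [0, 10, 15, 50, 100]])
+```
+"""
+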
+"""
+## Compile and train the model
+"""
+
+# Pass sample data to the model so that the input shape is available at the
+# time of saving the model.
+sample_ds, _ = next(iter(train_ds))
+model(sample_ds, training=False)
+
+# Get the total number of steps for training.
+total_steps = int((len(x_train) / config.batch_size) * config.epochs)
+
+# Calculate the number of steps for warmup.
+warmup_epoch_percentage = 0.15
+warmup_steps = int(total_steps * warmup_epoch_percentage)
+
+# Initialize the WarmUpCosine schedule.
+scheduled_lrs = WarmUpCosine(
+ lr_start=1e-5,
+ lr_max=1e-3,
+ warmup_steps=warmup_steps,
+ total_steps=total_steps,
+)
+
+# Get the optimizer.
+optimizer = keras.optimizers.AdamW(
+ learning_rate=scheduled_lrs, weight_decay=config.weight_decay
+)
+
+# Compile and pretrain the model.
+model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+
+# Train the model
+history = model.fit(
+ train_ds,
+ epochs=config.epochs,
+ validation_data=val_ds,
+ callbacks=[
+ keras.callbacks.EarlyStopping(
+ monitor="val_accuracy",
+ patience=5,
+ mode="auto",
+ )
+ ],
+)
+
+# Evaluate the model with the test dataset.
+print("TESTING")
+loss, acc_top1, acc_top5 = model.evaluate(test_ds)
+print(f"Loss: {loss:0.2f}")
+print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
+print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
+
+"""
+## Save the trained model
+
+Since we created the model by subclassing, we can't save it in the HDF5 format.
+
+It can be saved in the TF SavedModel format only. In general, this is the recommended format for saving models as well.
+"""
+model.export("ShiftViT")
+
+"""
+## Model inference
+"""
+
+"""
+**Download sample data for inference**
+"""
+
+"""shell
+wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+unzip -q inference_set.zip
+"""
+
+
+"""
+**Load saved model**
+"""
+# Using TFSMLayer to reload the TF SavedModel as a Keras layer.
+# This is not limited to SavedModels that originate from Keras; it will work with any SavedModel, e.g. TF-Hub models.
+saved_model = keras.layers.TFSMLayer("ShiftViT", call_endpoint="serving_default")
+
+"""
+**Utility functions for inference**
+"""
+
+
+def process_image(img_path):
+ # read image file from string path
+ img = tf.io.read_file(img_path)
+
+ # decode jpeg to uint8 tensor
+ img = tf.io.decode_jpeg(img, channels=3)
+
+ # resize image to match input size accepted by model
+ # use `interpolation` as `nearest` to preserve dtype of input passed to `resize()`
+ img = ops.image.resize(
+ img, [config.input_shape[0], config.input_shape[1]], interpolation="nearest"
+ )
+ return img
+
+
+def create_tf_dataset(image_dir):
+ data_dir = pathlib.Path(image_dir)
+
+ # create tf.data dataset using directory of images
+ predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+ # use map to convert string paths to uint8 image tensors
+# setting `num_parallel_calls` helps in processing multiple images in parallel
+ predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+ # create a Prefetch Dataset for better latency & throughput
+ predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+ return predict_ds
+
+
+def predict(predict_ds):
+ # ShiftViT model returns logits (non-normalized predictions)
+ model = keras.Sequential([saved_model])
+ output_dict = model.predict(predict_ds)
+ logits = list(output_dict.values())[0]
+
+ # normalize predictions by calling softmax()
+ probabilities = ops.softmax(logits)
+ return probabilities
+
+
+def get_predicted_class(probabilities):
+ pred_label = np.argmax(probabilities)
+ predicted_class = config.label_map[pred_label]
+ return predicted_class
+
+
+def get_confidence_scores(probabilities):
+ # get the indices of the probability scores sorted in descending order
+ labels = np.argsort(probabilities)[::-1]
+ confidences = {
+ config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+ for label in labels
+ }
+ return confidences
+
+
+"""
+**Get predictions**
+"""
+
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+
+"""
+**View predictions**
+"""
+
+plt.figure(figsize=(10, 10))
+for images in predict_ds:
+ for i in range(min(6, probabilities.shape[0])):
+ ax = plt.subplot(3, 3, i + 1)
+ plt.imshow(images[i].numpy().astype("uint8"))
+ predicted_class = get_predicted_class(probabilities[i])
+ plt.title(predicted_class)
+ plt.axis("off")
+
+"""
+## Conclusion
+
+The most impactful contribution of the paper is not the novel architecture, but
+the idea that hierarchical ViTs trained with no attention can perform quite well. This
+opens up the question of how essential attention is to the performance of ViTs.
+
+For curious minds, we would suggest reading the
+[ConvNeXt](https://arxiv.org/abs/2201.03545) paper, which focuses more on the training
+paradigms and architectural details of ViTs rather than providing a novel architecture
+based on attention.
+
+Acknowledgements:
+
+- We would like to thank [PyImageSearch](https://pyimagesearch.com) for providing us with
+resources that helped in the completion of this project.
+- We would like to thank [JarvisLabs.ai](https://jarvislabs.ai/) for providing us with the
+GPU credits.
+- We would like to thank [Manim Community](https://www.manim.community/) for the manim
+library.
+- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for
+helping us with the Learning Rate Schedule.
+"""
+
+"""
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [](https://huggingface.co/keras-io/shiftvit) | [](https://huggingface.co/spaces/keras-io/shiftvit) |
+"""
diff --git a/knowledge_base/vision/siamese_contrastive.py b/knowledge_base/vision/siamese_contrastive.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ec2c8d2c3334e5f8b65d490605bd13a229158b
--- /dev/null
+++ b/knowledge_base/vision/siamese_contrastive.py
@@ -0,0 +1,406 @@
+"""
+Title: Image similarity estimation using a Siamese Network with a contrastive loss
+Author: Mehdi
+Date created: 2021/05/06
+Last modified: 2022/09/10
+Description: Similarity learning using a siamese network trained with a contrastive loss.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+[Siamese Networks](https://en.wikipedia.org/wiki/Siamese_neural_network)
+are neural networks which share weights between two or more sister networks,
+each producing embedding vectors of its respective inputs.
+
+In supervised similarity learning, the networks are then trained to maximize the
+contrast (distance) between embeddings of inputs of different classes, while minimizing the distance between
+embeddings of inputs from the same class, resulting in embedding spaces that reflect
+the class segmentation of the training inputs.
+"""
+
+"""
+## Setup
+"""
+
+import random
+import numpy as np
+import keras
+from keras import ops
+import matplotlib.pyplot as plt
+
+"""
+## Hyperparameters
+"""
+
+epochs = 10
+batch_size = 16
+margin = 1 # Margin for contrastive loss.
+
+"""
+## Load the MNIST dataset
+"""
+(x_train_val, y_train_val), (x_test, y_test) = keras.datasets.mnist.load_data()
+
+# Change the data type to a floating point format
+x_train_val = x_train_val.astype("float32")
+x_test = x_test.astype("float32")
+
+
+"""
+## Define training and validation sets
+"""
+
+# Keep 50% of train_val in validation set
+x_train, x_val = x_train_val[:30000], x_train_val[30000:]
+y_train, y_val = y_train_val[:30000], y_train_val[30000:]
+del x_train_val, y_train_val
+
+
+"""
+## Create pairs of images
+
+We will train the model to differentiate between digits of different classes. For
+example, digit `0` needs to be differentiated from the rest of the
+digits (`1` through `9`), digit `1` - from `0` and `2` through `9`, and so on.
+To carry this out, we will select N random images from class A (for example,
+for digit `0`) and pair them with N random images from another class B
+(for example, for digit `1`). Then, we can repeat this process for all classes
+of digits (until digit `9`). Once we have paired digit `0` with other digits,
+we can repeat this process for the remaining classes for the rest of the digits
+(from `1` until `9`).
+"""
+
+
+def make_pairs(x, y):
+ """Creates a tuple containing image pairs with corresponding label.
+
+ Arguments:
+ x: List containing images, each index in this list corresponds to one image.
+ y: List containing labels, each label with datatype of `int`.
+
+ Returns:
+ Tuple containing two numpy arrays as (pairs_of_samples, labels),
+ where pairs_of_samples' shape is (2 * len(x), 2, n_feature_dims) and
+ labels are a binary array of shape (2 * len(x),).
+ """
+
+ num_classes = max(y) + 1
+ digit_indices = [np.where(y == i)[0] for i in range(num_classes)]
+
+ pairs = []
+ labels = []
+
+ for idx1 in range(len(x)):
+ # add a matching example
+ x1 = x[idx1]
+ label1 = y[idx1]
+ idx2 = random.choice(digit_indices[label1])
+ x2 = x[idx2]
+
+ pairs += [[x1, x2]]
+ labels += [0]
+
+ # add a non-matching example
+ label2 = random.randint(0, num_classes - 1)
+ while label2 == label1:
+ label2 = random.randint(0, num_classes - 1)
+
+ idx2 = random.choice(digit_indices[label2])
+ x2 = x[idx2]
+
+ pairs += [[x1, x2]]
+ labels += [1]
+
+ return np.array(pairs), np.array(labels).astype("float32")
+
+
+# make train pairs
+pairs_train, labels_train = make_pairs(x_train, y_train)
+
+# make validation pairs
+pairs_val, labels_val = make_pairs(x_val, y_val)
+
+# make test pairs
+pairs_test, labels_test = make_pairs(x_test, y_test)
+
+"""
+We get:
+
+**pairs_train.shape = (60000, 2, 28, 28)**
+
+- We have 60,000 pairs
+- Each pair contains 2 images
+- Each image has shape `(28, 28)`
+"""
+
+"""
+Split the training pairs
+"""
+
+x_train_1 = pairs_train[:, 0] # x_train_1.shape is (60000, 28, 28)
+x_train_2 = pairs_train[:, 1]
+
+"""
+Split the validation pairs
+"""
+
+x_val_1 = pairs_val[:, 0] # x_val_1.shape = (60000, 28, 28)
+x_val_2 = pairs_val[:, 1]
+
+"""
+Split the test pairs
+"""
+
+x_test_1 = pairs_test[:, 0] # x_test_1.shape = (20000, 28, 28)
+x_test_2 = pairs_test[:, 1]
+
+
+"""
+## Visualize pairs and their labels
+"""
+
+
+def visualize(pairs, labels, to_show=6, num_col=3, predictions=None, test=False):
+ """Creates a plot of pairs and labels, and prediction if it's test dataset.
+
+ Arguments:
+ pairs: Numpy Array, of pairs to visualize, having shape
+ (Number of pairs, 2, 28, 28).
+ labels: Numpy Array, of labels corresponding to each pair.
+ to_show: Int, number of examples to visualize (default is 6).
+ `to_show` must be an integral multiple of `num_col`.
+ Otherwise it will be trimmed if it is greater than num_col,
+ and incremented if it is less than num_col.
+ num_col: Int, number of images in one row - (default is 3)
+ For test and train respectively, it should not exceed 3 and 7.
+ predictions: Numpy Array of predictions with shape (to_show, 1) -
+ (default is None)
+ Must be passed when test=True.
+ test: Boolean telling whether the dataset being visualized is
+ train dataset or test dataset - (default False).
+
+ Returns:
+ None.
+ """
+
+ # Define num_row.
+ # If to_show % num_col != 0, to_show is trimmed: num_row is limited to the
+ # point where to_show % num_col == 0.
+ # If to_show // num_col == 0, num_col is greater than to_show, so to_show
+ # is incremented by setting num_row to 1.
+ num_row = to_show // num_col if to_show // num_col != 0 else 1
+
+ # `to_show` must be an integral multiple of `num_col`. Now that we have
+ # both num_row and num_col, to_show is adjusted (incremented or
+ # decremented) by setting it equal to num_row * num_col.
+ to_show = num_row * num_col
+
+ # Plot the images
+ fig, axes = plt.subplots(num_row, num_col, figsize=(5, 5))
+ for i in range(to_show):
+ # If the number of rows is 1, the axes array is one-dimensional
+ if num_row == 1:
+ ax = axes[i % num_col]
+ else:
+ ax = axes[i // num_col, i % num_col]
+
+ ax.imshow(ops.concatenate([pairs[i][0], pairs[i][1]], axis=1), cmap="gray")
+ ax.set_axis_off()
+ if test:
+ ax.set_title("True: {} | Pred: {:.5f}".format(labels[i], predictions[i][0]))
+ else:
+ ax.set_title("Label: {}".format(labels[i]))
+ if test:
+ plt.tight_layout(rect=(0, 0, 1.9, 1.9), w_pad=0.0)
+ else:
+ plt.tight_layout(rect=(0, 0, 1.5, 1.5))
+ plt.show()
+
+
+"""
+Inspect training pairs
+"""
+
+visualize(pairs_train[:-1], labels_train[:-1], to_show=4, num_col=4)
+
+"""
+Inspect validation pairs
+"""
+
+visualize(pairs_val[:-1], labels_val[:-1], to_show=4, num_col=4)
+
+"""
+Inspect test pairs
+"""
+
+visualize(pairs_test[:-1], labels_test[:-1], to_show=4, num_col=4)
+
+"""
+## Define the model
+
+There are two input layers, each leading to its own network, which
+produces embeddings. A `Lambda` layer then merges them using a
+[Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) and the
+merged output is fed to the final network.
+"""
+
+
+# Provided two tensors t1 and t2
+# Euclidean distance = sqrt(sum(square(t1-t2)))
+def euclidean_distance(vects):
+ """Find the Euclidean distance between two vectors.
+
+ Arguments:
+ vects: List containing two tensors of same length.
+
+ Returns:
+ Tensor containing euclidean distance
+ (as floating point value) between vectors.
+ """
+
+ x, y = vects
+ sum_square = ops.sum(ops.square(x - y), axis=1, keepdims=True)
+ return ops.sqrt(ops.maximum(sum_square, keras.backend.epsilon()))
+
+
+input = keras.layers.Input((28, 28, 1))
+x = keras.layers.BatchNormalization()(input)
+x = keras.layers.Conv2D(4, (5, 5), activation="tanh")(x)
+x = keras.layers.AveragePooling2D(pool_size=(2, 2))(x)
+x = keras.layers.Conv2D(16, (5, 5), activation="tanh")(x)
+x = keras.layers.AveragePooling2D(pool_size=(2, 2))(x)
+x = keras.layers.Flatten()(x)
+
+x = keras.layers.BatchNormalization()(x)
+x = keras.layers.Dense(10, activation="tanh")(x)
+embedding_network = keras.Model(input, x)
+
+
+input_1 = keras.layers.Input((28, 28, 1))
+input_2 = keras.layers.Input((28, 28, 1))
+
+# As mentioned above, Siamese Networks share weights between
+# tower networks (sister networks). To allow this, we will use the
+# same embedding network for both tower networks.
+tower_1 = embedding_network(input_1)
+tower_2 = embedding_network(input_2)
+
+merge_layer = keras.layers.Lambda(euclidean_distance, output_shape=(1,))(
+ [tower_1, tower_2]
+)
+normal_layer = keras.layers.BatchNormalization()(merge_layer)
+output_layer = keras.layers.Dense(1, activation="sigmoid")(normal_layer)
+siamese = keras.Model(inputs=[input_1, input_2], outputs=output_layer)
+
+
+"""
+## Define the contrastive loss
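+
+For a pair with predicted distance `d = y_pred` and label `y = y_true`
+(`0` = similar pair, `1` = dissimilar pair), the loss implemented below is:
+
+`loss = mean((1 - y) * d^2 + y * max(margin - d, 0)^2)`
+
+Similar pairs are thus pulled together, while dissimilar pairs are pushed apart
+until their distance reaches the `margin`.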
+"""
+
+
+def loss(margin=1):
+ """Provides 'contrastive_loss' an enclosing scope with variable 'margin'.
+
+ Arguments:
+ margin: Integer, defines the baseline distance beyond which pairs
+ should be classified as dissimilar (default is 1).
+
+ Returns:
+ 'contrastive_loss' function with data ('margin') attached.
+ """
+
+ # Contrastive loss = mean( (1-true_value) * square(prediction) +
+ # true_value * square( max(margin-prediction, 0) ))
+ def contrastive_loss(y_true, y_pred):
+ """Calculates the contrastive loss.
+
+ Arguments:
+ y_true: List of labels, each label is of type float32.
+ y_pred: List of predictions of the same length as y_true,
+ each prediction is of type float32.
+
+ Returns:
+ A tensor containing contrastive loss as floating point value.
+ """
+
+ square_pred = ops.square(y_pred)
+ margin_square = ops.square(ops.maximum(margin - (y_pred), 0))
+ return ops.mean((1 - y_true) * square_pred + (y_true) * margin_square)
+
+ return contrastive_loss
+
+
+"""
+## Compile the model with the contrastive loss
+"""
+
+siamese.compile(loss=loss(margin=margin), optimizer="RMSprop", metrics=["accuracy"])
+siamese.summary()
+
+
+"""
+## Train the model
+"""
+
+history = siamese.fit(
+ [x_train_1, x_train_2],
+ labels_train,
+ validation_data=([x_val_1, x_val_2], labels_val),
+ batch_size=batch_size,
+ epochs=epochs,
+)
+
+"""
+## Visualize results
+"""
+
+
+def plt_metric(history, metric, title, has_valid=True):
+ """Plots the given 'metric' from 'history'.
+
+ Arguments:
+ history: history attribute of History object returned from Model.fit.
+ metric: Metric to plot, a string value present as key in 'history'.
+ title: A string to be used as title of plot.
+ has_valid: Boolean, true if valid data was passed to Model.fit else false.
+
+ Returns:
+ None.
+ """
+ plt.plot(history[metric])
+ if has_valid:
+ plt.plot(history["val_" + metric])
+ plt.legend(["train", "validation"], loc="upper left")
+ plt.title(title)
+ plt.ylabel(metric)
+ plt.xlabel("epoch")
+ plt.show()
+
+
+# Plot the accuracy
+plt_metric(history=history.history, metric="accuracy", title="Model accuracy")
+
+# Plot the contrastive loss
+plt_metric(history=history.history, metric="loss", title="Contrastive Loss")
+
+"""
+## Evaluate the model
+"""
+
+results = siamese.evaluate([x_test_1, x_test_2], labels_test)
+print("test loss, test acc:", results)
+
+"""
+## Visualize the predictions
+"""
+
+predictions = siamese.predict([x_test_1, x_test_2])
+visualize(pairs_test, labels_test, to_show=3, predictions=predictions, test=True)
diff --git a/knowledge_base/vision/siamese_network.py b/knowledge_base/vision/siamese_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcf776a3e724d8648a90ddd4191baa2ecf89943
--- /dev/null
+++ b/knowledge_base/vision/siamese_network.py
@@ -0,0 +1,417 @@
+"""
+Title: Image similarity estimation using a Siamese Network with a triplet loss
+Authors: [Hazem Essam](https://twitter.com/hazemessamm) and [Santiago L. Valdarrama](https://twitter.com/svpino)
+Date created: 2021/03/25
+Last modified: 2021/03/25
+Description: Training a Siamese Network to compare the similarity of images using a triplet loss function.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+A [Siamese Network](https://en.wikipedia.org/wiki/Siamese_neural_network) is a type of network architecture that
+contains two or more identical subnetworks used to generate feature vectors for each input and compare them.
+
+Siamese Networks can be applied to different use cases, like detecting duplicates, finding anomalies, and face recognition.
+
+This example uses a Siamese Network with three identical subnetworks. We will provide three images to the model, where
+two of them will be similar (_anchor_ and _positive_ samples), and the third will be unrelated (a _negative_ example.)
+Our goal is for the model to learn to estimate the similarity between images.
+
+For the network to learn, we use a triplet loss function. You can find an introduction to triplet loss in the
+[FaceNet paper](https://arxiv.org/abs/1503.03832) by Schroff et al., 2015. In this example, we define the triplet
+loss function as follows:
+
+`L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)`
+
+This example uses the [Totally Looks Like dataset](https://sites.google.com/view/totally-looks-like-dataset)
+by [Rosenfeld et al., 2018](https://arxiv.org/abs/1803.01485v3).
+"""
+
+"""
+## Setup
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import random
+import tensorflow as tf
+from pathlib import Path
+from keras import applications
+from keras import layers
+from keras import losses
+from keras import ops
+from keras import optimizers
+from keras import metrics
+from keras import Model
+from keras.applications import resnet
+
+
+target_shape = (200, 200)
+
+
+"""
+## Load the dataset
+
+We are going to load the *Totally Looks Like* dataset and unzip it inside the `~/.keras` directory
+in the local environment.
+
+The dataset consists of two separate files:
+
+* `left.zip` contains the images that we will use as the anchor.
+* `right.zip` contains the images that we will use as the positive sample (an image that looks like the anchor).
+"""
+
+cache_dir = Path(Path.home()) / ".keras"
+anchor_images_path = cache_dir / "left"
+positive_images_path = cache_dir / "right"
+
+"""shell
+gdown --id 1jvkbTr_giSP3Ru8OwGNCg6B4PvVbcO34
+gdown --id 1EzBZUb_mh_Dp_FKD0P4XiYYSd0QBH5zW
+unzip -oq left.zip -d $cache_dir
+unzip -oq right.zip -d $cache_dir
+"""
+
+"""
+## Preparing the data
+
+We are going to use a `tf.data` pipeline to load the data and generate the triplets that we
+need to train the Siamese network.
+
+We'll set up the pipeline using a zipped list with anchor, positive, and negative filenames as
+the source. The pipeline will load and preprocess the corresponding images.
+"""
+
+
+def preprocess_image(filename):
+ """
+ Load the specified file as a JPEG image, preprocess it and
+ resize it to the target shape.
+ """
+
+ image_string = tf.io.read_file(filename)
+ image = tf.image.decode_jpeg(image_string, channels=3)
+ image = tf.image.convert_image_dtype(image, tf.float32)
+ image = tf.image.resize(image, target_shape)
+ return image
+
+
+def preprocess_triplets(anchor, positive, negative):
+ """
+ Given the filenames corresponding to the three images, load and
+ preprocess them.
+ """
+
+ return (
+ preprocess_image(anchor),
+ preprocess_image(positive),
+ preprocess_image(negative),
+ )
+
+
+"""
+Let's set up our data pipeline using a zipped list with an anchor, positive,
+and negative image filename as the source. The output of the pipeline
+contains the same triplet with every image loaded and preprocessed.
+"""
+
+# We need to make sure both the anchor and positive images are loaded in
+# sorted order so we can match them together.
+anchor_images = sorted(
+ [str(anchor_images_path / f) for f in os.listdir(anchor_images_path)]
+)
+
+positive_images = sorted(
+ [str(positive_images_path / f) for f in os.listdir(positive_images_path)]
+)
+
+image_count = len(anchor_images)
+
+anchor_dataset = tf.data.Dataset.from_tensor_slices(anchor_images)
+positive_dataset = tf.data.Dataset.from_tensor_slices(positive_images)
+
+# To generate the list of negative images, let's randomize the list of
+# available images and concatenate them together.
+rng = np.random.RandomState(seed=42)
+rng.shuffle(anchor_images)
+rng.shuffle(positive_images)
+
+negative_images = anchor_images + positive_images
+np.random.RandomState(seed=32).shuffle(negative_images)
+
+negative_dataset = tf.data.Dataset.from_tensor_slices(negative_images)
+negative_dataset = negative_dataset.shuffle(buffer_size=4096)
+
+dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))
+dataset = dataset.shuffle(buffer_size=1024)
+dataset = dataset.map(preprocess_triplets)
+
+# Let's now split our dataset in train and validation.
+train_dataset = dataset.take(round(image_count * 0.8))
+val_dataset = dataset.skip(round(image_count * 0.8))
+
+train_dataset = train_dataset.batch(32, drop_remainder=False)
+train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
+
+val_dataset = val_dataset.batch(32, drop_remainder=False)
+val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
+
+
+"""
+Let's take a look at a few examples of triplets. Notice how the first two images
+look alike while the third one is always different.
+"""
+
+
+def visualize(anchor, positive, negative):
+ """Visualize a few triplets from the supplied batches."""
+
+ def show(ax, image):
+ ax.imshow(image)
+ ax.get_xaxis().set_visible(False)
+ ax.get_yaxis().set_visible(False)
+
+ fig = plt.figure(figsize=(9, 9))
+
+ axs = fig.subplots(3, 3)
+ for i in range(3):
+ show(axs[i, 0], anchor[i])
+ show(axs[i, 1], positive[i])
+ show(axs[i, 2], negative[i])
+
+
+visualize(*list(train_dataset.take(1).as_numpy_iterator())[0])
+
+"""
+## Setting up the embedding generator model
+
+Our Siamese Network will generate embeddings for each of the images of the
+triplet. To do this, we will use a ResNet50 model pretrained on ImageNet and
+connect a few `Dense` layers to it so we can learn to separate these
+embeddings.
+
+We will freeze the weights of all the layers of the model up until the layer `conv5_block1_out`.
+This is important to avoid affecting the weights that the model has already learned.
+We are going to leave the last few layers trainable, so that we can fine-tune their weights
+during training.
+"""
+
+base_cnn = resnet.ResNet50(
+ weights="imagenet", input_shape=target_shape + (3,), include_top=False
+)
+
+flatten = layers.Flatten()(base_cnn.output)
+dense1 = layers.Dense(512, activation="relu")(flatten)
+dense1 = layers.BatchNormalization()(dense1)
+dense2 = layers.Dense(256, activation="relu")(dense1)
+dense2 = layers.BatchNormalization()(dense2)
+output = layers.Dense(256)(dense2)
+
+embedding = Model(base_cnn.input, output, name="Embedding")
+
+trainable = False
+for layer in base_cnn.layers:
+ if layer.name == "conv5_block1_out":
+ trainable = True
+ layer.trainable = trainable
+
+"""
+## Setting up the Siamese Network model
+
+The Siamese network will receive each of the triplet images as an input,
+generate the embeddings, and output the distance between the anchor and the
+positive embedding, as well as the distance between the anchor and the negative
+embedding.
+
+To compute the distance, we can use a custom layer `DistanceLayer` that
+returns both values as a tuple.
+"""
+
+
+class DistanceLayer(layers.Layer):
+ """
+ This layer is responsible for computing the distance between the anchor
+ embedding and the positive embedding, and the anchor embedding and the
+ negative embedding.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def call(self, anchor, positive, negative):
+ ap_distance = ops.sum(ops.square(anchor - positive), -1)
+ an_distance = ops.sum(ops.square(anchor - negative), -1)
+ return (ap_distance, an_distance)
+
+
+anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
+positive_input = layers.Input(name="positive", shape=target_shape + (3,))
+negative_input = layers.Input(name="negative", shape=target_shape + (3,))
+
+distances = DistanceLayer()(
+ embedding(resnet.preprocess_input(anchor_input)),
+ embedding(resnet.preprocess_input(positive_input)),
+ embedding(resnet.preprocess_input(negative_input)),
+)
+
+siamese_network = Model(
+ inputs=[anchor_input, positive_input, negative_input], outputs=distances
+)
+
+"""
+## Putting everything together
+
+We now need to implement a model with custom training loop so we can compute
+the triplet loss using the three embeddings produced by the Siamese network.
+
+Let's create a `Mean` metric instance to track the loss of the training process.
+"""
+
+
+class SiameseModel(Model):
+ """The Siamese Network model with a custom training and testing loops.
+
+ Computes the triplet loss using the three embeddings produced by the
+ Siamese Network.
+
+ The triplet loss is defined as:
+ L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)
+ """
+
+ def __init__(self, siamese_network, margin=0.5):
+ super().__init__()
+ self.siamese_network = siamese_network
+ self.margin = margin
+ self.loss_tracker = metrics.Mean(name="loss")
+
+ def call(self, inputs):
+ return self.siamese_network(inputs)
+
+ def train_step(self, data):
+ # GradientTape is a context manager that records every operation that
+ # you do inside. We are using it here to compute the loss so we can get
+ # the gradients and apply them using the optimizer specified in
+ # `compile()`.
+ with tf.GradientTape() as tape:
+ loss = self._compute_loss(data)
+
+ # Storing the gradients of the loss function with respect to the
+ # weights/parameters.
+ gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
+
+ # Applying the gradients on the model using the specified optimizer
+ self.optimizer.apply_gradients(
+ zip(gradients, self.siamese_network.trainable_weights)
+ )
+
+ # Let's update and return the training loss metric.
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+ def test_step(self, data):
+ loss = self._compute_loss(data)
+
+ # Let's update and return the loss metric.
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+ def _compute_loss(self, data):
+ # The output of the network is a tuple containing the distances
+ # between the anchor and the positive example, and the anchor and
+ # the negative example.
+ ap_distance, an_distance = self.siamese_network(data)
+
+ # Computing the Triplet Loss by subtracting both distances and
+ # making sure we don't get a negative value.
+ loss = ap_distance - an_distance
+ loss = tf.maximum(loss + self.margin, 0.0)
+ return loss
+
+ @property
+ def metrics(self):
+ # We need to list our metrics here so that `reset_states()` can be
+ # called automatically.
+ return [self.loss_tracker]
+
+
+"""
+## Training
+
+We are now ready to train our model.
+"""
+
+siamese_model = SiameseModel(siamese_network)
+siamese_model.compile(optimizer=optimizers.Adam(0.0001))
+siamese_model.fit(train_dataset, epochs=10, validation_data=val_dataset)
+
+"""
+## Inspecting what the network has learned
+
+At this point, we can check how the network learned to separate the embeddings
+depending on whether they belong to similar images.
+
+We can use [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) to measure the
+similarity between embeddings.
+
+Let's pick a sample from the dataset to check the similarity between the
+embeddings generated for each image.
+"""
+sample = next(iter(train_dataset))
+visualize(*sample)
+
+anchor, positive, negative = sample
+anchor_embedding, positive_embedding, negative_embedding = (
+ embedding(resnet.preprocess_input(anchor)),
+ embedding(resnet.preprocess_input(positive)),
+ embedding(resnet.preprocess_input(negative)),
+)
+
+"""
+Finally, we can compute the cosine similarity between the anchor and positive
+images and compare it with the similarity between the anchor and the negative
+images.
+
+We should expect the similarity between the anchor and positive images to be
+larger than the similarity between the anchor and the negative images.
+"""
+
+cosine_similarity = metrics.CosineSimilarity()
+
+positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
+print("Positive similarity:", positive_similarity.numpy())
+
+negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
+print("Negative similarity", negative_similarity.numpy())
+
+
+"""
+## Summary
+
+1. The `tf.data` API enables you to build efficient input pipelines for your model. It is
+particularly useful if you have a large dataset. You can learn more about `tf.data`
+pipelines in [tf.data: Build TensorFlow input pipelines](https://www.tensorflow.org/guide/data).
+
+2. In this example, we use a pre-trained ResNet50 as part of the subnetwork that generates
+the feature embeddings. By using [transfer learning](https://www.tensorflow.org/guide/keras/transfer_learning?hl=en),
+we can significantly reduce the training time and the size of the dataset required.
+
+3. Notice how we are [fine-tuning](https://www.tensorflow.org/guide/keras/transfer_learning?hl=en#fine-tuning)
+the weights of the final layers of the ResNet50 network but keeping the rest of the layers untouched.
+Using the name assigned to each layer, we can freeze the weights to a certain point and keep the last few layers open.
+
+4. We can create custom layers by creating a class that inherits from `keras.layers.Layer`,
+as we did in the `DistanceLayer` class.
+
+5. We used a cosine similarity metric to measure how similar the two output embeddings are to each other.
+
+6. You can implement a custom training loop by overriding the `train_step()` method. `train_step()` uses
+[`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape),
+which records every operation that you perform inside it. In this example, we use it to access the
+gradients passed to the optimizer to update the model weights at every step. For more details, check out the
+[Intro to Keras for researchers](https://keras.io/getting_started/intro_to_keras_for_researchers/)
+and [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch?hl=en).
+
+"""
diff --git a/knowledge_base/vision/simsiam.py b/knowledge_base/vision/simsiam.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a671cd12f293f864e39acb4c82eacaf39916827
--- /dev/null
+++ b/knowledge_base/vision/simsiam.py
@@ -0,0 +1,450 @@
+"""
+Title: Self-supervised contrastive learning with SimSiam
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/03/19
+Last modified: 2023/12/29
+Description: Implementation of a self-supervised learning method for computer vision.
+Accelerator: GPU
+"""
+
+"""
+Self-supervised learning (SSL) is an interesting branch of study in the field of
+representation learning. SSL systems try to formulate a supervised signal from a corpus
+of unlabeled data points. An example is training a deep neural network to predict the
+next word from a given set of words. In the literature, these tasks are known as *pretext
+tasks* or *auxiliary tasks*. If we [train such a network](https://arxiv.org/abs/1801.06146) on a huge dataset (such as
+the [Wikipedia text corpus](https://www.corpusdata.org/wikipedia.asp)) it learns very effective
+representations that transfer well to downstream tasks. Language models like
+[BERT](https://arxiv.org/abs/1810.04805), [GPT-3](https://arxiv.org/abs/2005.14165),
+[ELMo](https://allennlp.org/elmo) all benefit from this.
+
+Much like language models, we can train computer vision models using similar
+approaches. To make things work in computer vision, we need to formulate the learning
+tasks such that the underlying model (a deep neural network) is able to make sense of the
+semantic information present in vision data. One such task is to have a model _contrast_
+between two different versions of the same image. The hope is that in this way the model
+will learn representations where similar images are grouped together as closely as
+possible while dissimilar images are pushed further away.
+
+In this example, we will be implementing one such system called **SimSiam** proposed in
+[Exploring Simple Siamese Representation Learning](https://arxiv.org/abs/2011.10566). It
+is implemented as the following:
+
+1. We create two different versions of the same dataset with a stochastic data
+augmentation pipeline. Note that the random initialization seed needs to be the same
+while creating these versions.
+2. We take a ResNet without any classification head (**backbone**) and we add a shallow
+fully-connected network (**projection head**) on top of it. Collectively, this is known
+as the **encoder**.
+3. We pass the output of the encoder through a **predictor**, which is again a shallow
+fully-connected network with an
+[AutoEncoder](https://en.wikipedia.org/wiki/Autoencoder)-like structure.
+4. We then train our encoder to maximize the cosine similarity between the two different
+versions of our dataset.
+
+"""
+
+"""
+## Setup
+"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+import keras_cv
+from keras import ops
+from keras import layers
+from keras import regularizers
+import tensorflow as tf
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+"""
+## Define hyperparameters
+"""
+
+AUTO = tf.data.AUTOTUNE
+BATCH_SIZE = 128
+EPOCHS = 5
+CROP_TO = 32
+SEED = 26
+
+PROJECT_DIM = 2048
+LATENT_DIM = 512
+WEIGHT_DECAY = 0.0005
+
+"""
+## Load the CIFAR-10 dataset
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+print(f"Total training examples: {len(x_train)}")
+print(f"Total test examples: {len(x_test)}")
+
+"""
+## Defining our data augmentation pipeline
+
+As studied in [SimCLR](https://arxiv.org/abs/2002.05709), having the right data
+augmentation pipeline is critical for SSL systems to work effectively in computer vision.
+Two particular augmentation transforms that seem to matter the most are: 1.) Random
+resized crops and 2.) Color distortions. Most of the other SSL systems for computer
+vision (such as [BYOL](https://arxiv.org/abs/2006.07733),
+[MoCoV2](https://arxiv.org/abs/2003.04297), [SwAV](https://arxiv.org/abs/2006.09882),
+etc.) include these in their training pipelines.
+"""
+
+
+strength = [0.4, 0.4, 0.3, 0.1]
+
+random_flip = layers.RandomFlip(mode="horizontal_and_vertical")
+random_crop = layers.RandomCrop(CROP_TO, CROP_TO)
+random_brightness = layers.RandomBrightness(0.8 * strength[0])
+random_contrast = layers.RandomContrast((1 - 0.8 * strength[1], 1 + 0.8 * strength[1]))
+random_saturation = keras_cv.layers.RandomSaturation(
+ (0.5 - 0.8 * strength[2], 0.5 + 0.8 * strength[2])
+)
+random_hue = keras_cv.layers.RandomHue(0.2 * strength[3], [0, 255])
+grayscale = keras_cv.layers.Grayscale()
+
+
+def flip_random_crop(image):
+ # With random crops we also apply horizontal flipping.
+ image = random_flip(image)
+ image = random_crop(image)
+ return image
+
+
+def color_jitter(x):
+ x = random_brightness(x)
+ x = random_contrast(x)
+ x = random_saturation(x)
+ x = random_hue(x)
+ # Affine transformations can disturb the natural range of
+ # RGB images, hence this is needed.
+ x = ops.clip(x, 0, 255)
+ return x
+
+
+def color_drop(x):
+ x = grayscale(x)
+ x = ops.tile(x, [1, 1, 3])
+ return x
+
+
+def random_apply(func, x, p):
+ if keras.random.uniform([], minval=0, maxval=1) < p:
+ return func(x)
+ else:
+ return x
+
+
+def custom_augment(image):
+ # As discussed in the SimCLR paper, the series of augmentation
+ # transformations (except for random crops) need to be applied
+ # randomly to impose translational invariance.
+ image = flip_random_crop(image)
+ image = random_apply(color_jitter, image, p=0.8)
+ image = random_apply(color_drop, image, p=0.2)
+ return image
+
+
+"""
+It should be noted that an augmentation pipeline is generally dependent on various
+properties of the dataset we are dealing with. For example, if images in the dataset are
+heavily object-centric then taking random crops with a very high probability may hurt the
+training performance.
+
+Let's now apply our augmentation pipeline to our dataset and visualize a few outputs.
+"""
+
+"""
+## Convert the data into TensorFlow `Dataset` objects
+
+Here we create two different versions of our dataset *without* any ground-truth labels.
+"""
+
+ssl_ds_one = tf.data.Dataset.from_tensor_slices(x_train)
+ssl_ds_one = (
+ ssl_ds_one.shuffle(1024, seed=SEED)
+ .map(custom_augment, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+ssl_ds_two = tf.data.Dataset.from_tensor_slices(x_train)
+ssl_ds_two = (
+ ssl_ds_two.shuffle(1024, seed=SEED)
+ .map(custom_augment, num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+
+# We then zip both of these datasets.
+ssl_ds = tf.data.Dataset.zip((ssl_ds_one, ssl_ds_two))
+
+# Visualize a few augmented images.
+sample_images_one = next(iter(ssl_ds_one))
+plt.figure(figsize=(10, 10))
+for n in range(25):
+ ax = plt.subplot(5, 5, n + 1)
+ plt.imshow(sample_images_one[n].numpy().astype("int"))
+ plt.axis("off")
+plt.show()
+
+# Ensure that the different versions of the dataset actually contain
+# identical images.
+sample_images_two = next(iter(ssl_ds_two))
+plt.figure(figsize=(10, 10))
+for n in range(25):
+ ax = plt.subplot(5, 5, n + 1)
+ plt.imshow(sample_images_two[n].numpy().astype("int"))
+ plt.axis("off")
+plt.show()
+
+"""
+Notice that the images in `sample_images_one` and `sample_images_two` are essentially
+the same but are augmented differently.
+"""
+
+"""
+## Defining the encoder and the predictor
+
+We use an implementation of ResNet20 that is specifically configured for the CIFAR10
+dataset. The code is taken from the
+[keras-idiomatic-programmer](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/blob/master/zoo/resnet/resnet_cifar10_v2.py) repository. The hyperparameters of
+these architectures are taken from Section 3 and Appendix A of [the original
+paper](https://arxiv.org/abs/2011.10566).
+"""
+
+"""shell
+wget -q https://git.io/JYx2x -O resnet_cifar10_v2.py
+"""
+
+import resnet_cifar10_v2
+
+N = 2
+DEPTH = N * 9 + 2
+NUM_BLOCKS = ((DEPTH - 2) // 9) - 1
+
+
+def get_encoder():
+ # Input and backbone.
+ inputs = layers.Input((CROP_TO, CROP_TO, 3))
+ x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(inputs)
+ x = resnet_cifar10_v2.stem(x)
+ x = resnet_cifar10_v2.learner(x, NUM_BLOCKS)
+ x = layers.GlobalAveragePooling2D(name="backbone_pool")(x)
+
+ # Projection head.
+ x = layers.Dense(
+ PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
+ )(x)
+ x = layers.BatchNormalization()(x)
+ x = layers.ReLU()(x)
+ x = layers.Dense(
+ PROJECT_DIM, use_bias=False, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
+ )(x)
+ outputs = layers.BatchNormalization()(x)
+ return keras.Model(inputs, outputs, name="encoder")
+
+
+def get_predictor():
+ model = keras.Sequential(
+ [
+ # Note the AutoEncoder-like structure.
+ layers.Input((PROJECT_DIM,)),
+ layers.Dense(
+ LATENT_DIM,
+ use_bias=False,
+ kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
+ ),
+ layers.ReLU(),
+ layers.BatchNormalization(),
+ layers.Dense(PROJECT_DIM),
+ ],
+ name="predictor",
+ )
+ return model
+
+
+"""
+## Defining the (pre-)training loop
+
+One of the main reasons behind training networks with these kinds of approaches is to
+utilize the learned representations for downstream tasks like classification. This is why
+this particular training phase is also referred to as _pre-training_.
+
+We start by defining the loss function.
+"""
+
+
+def compute_loss(p, z):
+ # The authors of SimSiam emphasize the impact of
+ # the `stop_gradient` operator in the paper as it
+ # has an important role in the overall optimization.
+ z = ops.stop_gradient(z)
+ p = keras.utils.normalize(p, axis=1, order=2)
+ z = keras.utils.normalize(z, axis=1, order=2)
+ # Negative cosine similarity (minimizing this is
+ # equivalent to maximizing the similarity).
+ return -ops.mean(ops.sum((p * z), axis=1))
+
+
+"""
+We then define our training loop by overriding the `train_step()` function of the
+`keras.Model` class.
+"""
+
+
+class SimSiam(keras.Model):
+ def __init__(self, encoder, predictor):
+ super().__init__()
+ self.encoder = encoder
+ self.predictor = predictor
+ self.loss_tracker = keras.metrics.Mean(name="loss")
+
+ @property
+ def metrics(self):
+ return [self.loss_tracker]
+
+ def train_step(self, data):
+ # Unpack the data.
+ ds_one, ds_two = data
+
+ # Forward pass through the encoder and predictor.
+ with tf.GradientTape() as tape:
+ z1, z2 = self.encoder(ds_one), self.encoder(ds_two)
+ p1, p2 = self.predictor(z1), self.predictor(z2)
+ # Note that here we are enforcing the network to match
+ # the representations of two differently augmented batches
+ # of data.
+ loss = compute_loss(p1, z2) / 2 + compute_loss(p2, z1) / 2
+
+ # Compute gradients and update the parameters.
+ learnable_params = (
+ self.encoder.trainable_variables + self.predictor.trainable_variables
+ )
+ gradients = tape.gradient(loss, learnable_params)
+ self.optimizer.apply_gradients(zip(gradients, learnable_params))
+
+ # Monitor loss.
+ self.loss_tracker.update_state(loss)
+ return {"loss": self.loss_tracker.result()}
+
+
+"""
+## Pre-training our networks
+
+To keep this example short, we will train the model for only 5 epochs. In practice,
+this should be at least 100 epochs.
+"""
+
+# Create a cosine decay learning scheduler.
+num_training_samples = len(x_train)
+steps = EPOCHS * (num_training_samples // BATCH_SIZE)
+lr_decayed_fn = keras.optimizers.schedules.CosineDecay(
+ initial_learning_rate=0.03, decay_steps=steps
+)
+
+# Create an early stopping callback.
+early_stopping = keras.callbacks.EarlyStopping(
+ monitor="loss", patience=5, restore_best_weights=True
+)
+
+# Compile model and start training.
+simsiam = SimSiam(get_encoder(), get_predictor())
+simsiam.compile(optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.6))
+history = simsiam.fit(ssl_ds, epochs=EPOCHS, callbacks=[early_stopping])
+
+# Visualize the training progress of the model.
+plt.plot(history.history["loss"])
+plt.grid()
+plt.title("Negative Cosine Similairty")
+plt.show()
+
+"""
+If your loss gets very close to -1 (its minimum value) very quickly with a
+different dataset and a different backbone architecture, that is likely because of
+*representation collapse*: a phenomenon where the encoder yields similar outputs for
+all the images. In that case, additional hyperparameter tuning is required, especially in
+the following areas:
+
+* Strength of the color distortions and their probabilities.
+* Learning rate and its schedule.
+* Architecture of both the backbone and the projection head.
+
+"""
+
+"""
+## Evaluating our SSL method
+
+The most popular way to evaluate an SSL method in computer vision (or any other
+pre-training method, for that matter) is to learn a linear classifier on the frozen
+features of the trained backbone model (in this case, ResNet20) and evaluate the
+classifier on unseen images. Other methods include
+[fine-tuning](https://keras.io/guides/transfer_learning/) on the source dataset, or even on a
+target dataset with only 5% or 10% of the labels present. In practice, we can use the
+backbone model for any downstream task, such as semantic segmentation or object
+detection, where backbone models are usually pre-trained with *pure supervised learning*.
+"""
+
+# We first create labeled `Dataset` objects.
+train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+
+# Then we shuffle, batch, and prefetch this dataset for performance. We
+# also apply random resized crops as an augmentation but only to the
+# training set.
+train_ds = (
+ train_ds.shuffle(1024)
+ .map(lambda x, y: (flip_random_crop(x), y), num_parallel_calls=AUTO)
+ .batch(BATCH_SIZE)
+ .prefetch(AUTO)
+)
+test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+# Extract the backbone ResNet20.
+backbone = keras.Model(
+ simsiam.encoder.input, simsiam.encoder.get_layer("backbone_pool").output
+)
+
+# We then create our linear classifier and train it.
+backbone.trainable = False
+inputs = layers.Input((CROP_TO, CROP_TO, 3))
+x = backbone(inputs, training=False)
+outputs = layers.Dense(10, activation="softmax")(x)
+linear_model = keras.Model(inputs, outputs, name="linear_model")
+
+# Compile model and start training.
+linear_model.compile(
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ optimizer=keras.optimizers.SGD(lr_decayed_fn, momentum=0.9),
+)
+history = linear_model.fit(
+ train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=[early_stopping]
+)
+_, test_acc = linear_model.evaluate(test_ds)
+print("Test accuracy: {:.2f}%".format(test_acc * 100))
+
+"""
+
+## Notes
+* More data and a longer pre-training schedule benefit SSL in general.
+* SSL is particularly helpful when you only have access to limited *labeled*
+training data but can manage to build a large corpus of unlabeled data. Recently,
+using an SSL method called [SwAV](https://arxiv.org/abs/2006.09882), a group of
+researchers at Facebook trained a [RegNet](https://arxiv.org/abs/2003.13678) on 2 billion
+images. They were able to achieve downstream performance very close to that achieved by
+pure supervised pre-training. For some downstream tasks, their method even outperformed
+the supervised counterparts. You can check out [their
+paper](https://arxiv.org/pdf/2103.01988.pdf) for the details.
+* If you are interested in understanding why contrastive SSL helps networks learn meaningful
+representations, you can check out the following resources:
+ * [Self-supervised learning: The dark matter of
+intelligence](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/)
+ * [Understanding self-supervised learning using controlled datasets with known
+structure](https://sslneuips20.github.io/files/CameraReadys%203-77/64/CameraReady/Understanding_self_supervised_learning.pdf)
+
+"""
diff --git a/knowledge_base/vision/super_resolution_sub_pixel.py b/knowledge_base/vision/super_resolution_sub_pixel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3cccb776b0acadbc3b4b691c559b981291f09bf
--- /dev/null
+++ b/knowledge_base/vision/super_resolution_sub_pixel.py
@@ -0,0 +1,409 @@
+"""
+Title: Image Super-Resolution using an Efficient Sub-Pixel CNN
+Author: [Xingyu Long](https://github.com/xingyu-long)
+Date created: 2020/07/28
+Last modified: 2020/08/27
+Description: Implementing Super-Resolution using the efficient sub-pixel model on BSDS500.
+Accelerator: GPU
+Converted to Keras 3 by: [Md Awsfalur Rahman](https://awsaf49.github.io)
+"""
+
+"""
+## Introduction
+
+ESPCN (Efficient Sub-Pixel CNN), proposed by [Shi, 2016](https://arxiv.org/abs/1609.05158),
+is a model that reconstructs a high-resolution version of an image given a low-resolution
+version.
+It leverages efficient "sub-pixel convolution" layers, which learn an array of
+image upscaling filters.
+
+In this code example, we will implement the model from the paper and train it on a small
+dataset,
+[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html).
+"""
+
+"""
+## Setup
+"""
+
+import keras
+from keras import layers
+from keras import ops
+from keras.utils import load_img
+from keras.utils import array_to_img
+from keras.utils import img_to_array
+from keras.preprocessing import image_dataset_from_directory
+import tensorflow as tf # only for data preprocessing
+
+import os
+import math
+import numpy as np
+
+from IPython.display import display
+
+"""
+## Load data: BSDS500 dataset
+
+### Download dataset
+
+We use the built-in `keras.utils.get_file` utility to retrieve the dataset.
+"""
+
+dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
+data_dir = keras.utils.get_file(origin=dataset_url, fname="BSR", untar=True)
+root_dir = os.path.join(data_dir, "BSDS500/data")
+
+"""
+We create training and validation datasets via `image_dataset_from_directory`.
+"""
+
+crop_size = 300
+upscale_factor = 3
+input_size = crop_size // upscale_factor
+batch_size = 8
+
+train_ds = image_dataset_from_directory(
+ root_dir,
+ batch_size=batch_size,
+ image_size=(crop_size, crop_size),
+ validation_split=0.2,
+ subset="training",
+ seed=1337,
+ label_mode=None,
+)
+
+valid_ds = image_dataset_from_directory(
+ root_dir,
+ batch_size=batch_size,
+ image_size=(crop_size, crop_size),
+ validation_split=0.2,
+ subset="validation",
+ seed=1337,
+ label_mode=None,
+)
+
+"""
+We rescale the images to take values in the range [0, 1].
+"""
+
+
+def scaling(input_image):
+ input_image = input_image / 255.0
+ return input_image
+
+
+# Scale from (0, 255) to (0, 1)
+train_ds = train_ds.map(scaling)
+valid_ds = valid_ds.map(scaling)
+
+"""
+Let's visualize a few sample images:
+"""
+
+for batch in train_ds.take(1):
+ for img in batch:
+ display(array_to_img(img))
+
+"""
+We prepare a dataset of test image paths that we will use for
+visual evaluation at the end of this example.
+"""
+
+dataset = os.path.join(root_dir, "images")
+test_path = os.path.join(dataset, "test")
+
+test_img_paths = sorted(
+ [
+ os.path.join(test_path, fname)
+ for fname in os.listdir(test_path)
+ if fname.endswith(".jpg")
+ ]
+)
+
+"""
+## Crop and resize images
+
+Let's process image data.
+First, we convert our images from the RGB color space to the
+[YUV colour space](https://en.wikipedia.org/wiki/YUV).
+
+For the input data (low-resolution images),
+we crop the image, retrieve the `y` channel (luminance),
+and resize it with the `area` method (use `BICUBIC` if you use PIL).
+We only consider the luminance channel
+in the YUV color space because humans are more sensitive to
+luminance change.
+
+For the target data (high-resolution images), we just crop the image
+and retrieve the `y` channel.
+"""
+
+
+# Use TF Ops to process.
+def process_input(input, input_size, upscale_factor):
+ input = tf.image.rgb_to_yuv(input)
+ last_dimension_axis = len(input.shape) - 1
+ y, u, v = tf.split(input, 3, axis=last_dimension_axis)
+ return tf.image.resize(y, [input_size, input_size], method="area")
+
+
+def process_target(input):
+ input = tf.image.rgb_to_yuv(input)
+ last_dimension_axis = len(input.shape) - 1
+ y, u, v = tf.split(input, 3, axis=last_dimension_axis)
+ return y
+
+
+train_ds = train_ds.map(
+ lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
+)
+train_ds = train_ds.prefetch(buffer_size=32)
+
+valid_ds = valid_ds.map(
+ lambda x: (process_input(x, input_size, upscale_factor), process_target(x))
+)
+valid_ds = valid_ds.prefetch(buffer_size=32)
+
+"""
+Let's take a look at the input and target data.
+"""
+
+for batch in train_ds.take(1):
+ for img in batch[0]:
+ display(array_to_img(img))
+ for img in batch[1]:
+ display(array_to_img(img))
+
+"""
+## Build a model
+
+Compared to the paper, we add one more layer and we use the `relu` activation function
+instead of `tanh`.
+It achieves better performance even though we train the model for fewer epochs.
+"""
+
+
+class DepthToSpace(layers.Layer):
+ def __init__(self, block_size):
+ super().__init__()
+ self.block_size = block_size
+
+ def call(self, input):
+ batch, height, width, depth = ops.shape(input)
+ depth = depth // (self.block_size**2)
+
+ x = ops.reshape(
+ input, [batch, height, width, self.block_size, self.block_size, depth]
+ )
+ x = ops.transpose(x, [0, 1, 3, 2, 4, 5])
+ x = ops.reshape(
+ x, [batch, height * self.block_size, width * self.block_size, depth]
+ )
+ return x
+
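+"""
+As a quick sanity check (illustrative, not part of the original example), the
+layer rearranges a `(height, width, channels * block_size**2)` tensor into a
+`(height * block_size, width * block_size, channels)` tensor:
+"""
+
+demo_output = DepthToSpace(block_size=3)(ops.zeros((1, 8, 8, 9)))
+print(demo_output.shape)  # Expected: (1, 24, 24, 1)
+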
+
+def get_model(upscale_factor=3, channels=1):
+ conv_args = {
+ "activation": "relu",
+ "kernel_initializer": "orthogonal",
+ "padding": "same",
+ }
+ inputs = keras.Input(shape=(None, None, channels))
+ x = layers.Conv2D(64, 5, **conv_args)(inputs)
+ x = layers.Conv2D(64, 3, **conv_args)(x)
+ x = layers.Conv2D(32, 3, **conv_args)(x)
+ x = layers.Conv2D(channels * (upscale_factor**2), 3, **conv_args)(x)
+ outputs = DepthToSpace(upscale_factor)(x)
+
+ return keras.Model(inputs, outputs)
+
+
+"""
+## Define utility functions
+
+We need to define several utility functions to monitor our results:
+
+- `plot_results` to plot and save an image.
+- `get_lowres_image` to convert an image to its low-resolution version.
+- `upscale_image` to turn a low-resolution image into
+a high-resolution version reconstructed by the model.
+In this function, we use the `y` channel from the YUV color space
+as input to the model and then combine the output with the
+other channels to obtain an RGB image.
+"""
+
+import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
+from mpl_toolkits.axes_grid1.inset_locator import mark_inset
+import PIL
+
+
+def plot_results(img, prefix, title):
+ """Plot the result with zoom-in area."""
+ img_array = img_to_array(img)
+ img_array = img_array.astype("float32") / 255.0
+
+ # Create a new figure with a default 111 subplot.
+ fig, ax = plt.subplots()
+ im = ax.imshow(img_array[::-1], origin="lower")
+
+ plt.title(title)
+ # zoom-factor: 2.0, location: upper-left
+ axins = zoomed_inset_axes(ax, 2, loc=2)
+ axins.imshow(img_array[::-1], origin="lower")
+
+ # Specify the limits.
+ x1, x2, y1, y2 = 200, 300, 100, 200
+ # Apply the x-limits.
+ axins.set_xlim(x1, x2)
+ # Apply the y-limits.
+ axins.set_ylim(y1, y2)
+
+ plt.yticks(visible=False)
+ plt.xticks(visible=False)
+
+ # Make the line.
+ mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
+ plt.savefig(str(prefix) + "-" + title + ".png")
+ plt.show()
+
+
+def get_lowres_image(img, upscale_factor):
+ """Return low-resolution image to use as model input."""
+ return img.resize(
+ (img.size[0] // upscale_factor, img.size[1] // upscale_factor),
+ PIL.Image.BICUBIC,
+ )
+
+
+def upscale_image(model, img):
+ """Predict the result based on input image and restore the image as RGB."""
+ ycbcr = img.convert("YCbCr")
+ y, cb, cr = ycbcr.split()
+ y = img_to_array(y)
+ y = y.astype("float32") / 255.0
+
+ input = np.expand_dims(y, axis=0)
+ out = model.predict(input)
+
+ out_img_y = out[0]
+ out_img_y *= 255.0
+
+ # Restore the image in RGB color space.
+ out_img_y = out_img_y.clip(0, 255)
+ out_img_y = out_img_y.reshape((np.shape(out_img_y)[0], np.shape(out_img_y)[1]))
+ out_img_y = PIL.Image.fromarray(np.uint8(out_img_y), mode="L")
+ out_img_cb = cb.resize(out_img_y.size, PIL.Image.BICUBIC)
+ out_img_cr = cr.resize(out_img_y.size, PIL.Image.BICUBIC)
+ out_img = PIL.Image.merge("YCbCr", (out_img_y, out_img_cb, out_img_cr)).convert(
+ "RGB"
+ )
+ return out_img
+
+
+"""
+## Define callbacks to monitor training
+
+The `ESPCNCallback` object will compute and display
+the [PSNR](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) metric.
+This is the main metric we use to evaluate super-resolution performance.
+"""
+
+
+class ESPCNCallback(keras.callbacks.Callback):
+ def __init__(self):
+ super().__init__()
+ self.test_img = get_lowres_image(load_img(test_img_paths[0]), upscale_factor)
+
+ # Store PSNR value in each epoch.
+ def on_epoch_begin(self, epoch, logs=None):
+ self.psnr = []
+
+ def on_epoch_end(self, epoch, logs=None):
+ print("Mean PSNR for epoch: %.2f" % (np.mean(self.psnr)))
+ if epoch % 20 == 0:
+ prediction = upscale_image(self.model, self.test_img)
+ plot_results(prediction, "epoch-" + str(epoch), "prediction")
+
+ def on_test_batch_end(self, batch, logs=None):
+ self.psnr.append(10 * math.log10(1 / logs["loss"]))
+
+
+"""
+Define `ModelCheckpoint` and `EarlyStopping` callbacks.
+"""
+
+early_stopping_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=10)
+
+checkpoint_filepath = "/tmp/checkpoint.keras"
+
+model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ filepath=checkpoint_filepath,
+ save_weights_only=False,
+ monitor="loss",
+ mode="min",
+ save_best_only=True,
+)
+
+model = get_model(upscale_factor=upscale_factor, channels=1)
+model.summary()
+
+callbacks = [ESPCNCallback(), early_stopping_callback, model_checkpoint_callback]
+loss_fn = keras.losses.MeanSquaredError()
+optimizer = keras.optimizers.Adam(learning_rate=0.001)
+
+"""
+## Train the model
+"""
+
+epochs = 100
+
+model.compile(
+ optimizer=optimizer,
+ loss=loss_fn,
+)
+
+model.fit(
+ train_ds, epochs=epochs, callbacks=callbacks, validation_data=valid_ds, verbose=2
+)
+
+# Load the best model weights (tracked by the `ModelCheckpoint` callback) into the model.
+model.load_weights(checkpoint_filepath)
+
+"""
+## Run model prediction and plot the results
+
+Let's compute the reconstructed version of a few images and save the results.
+"""
+
+total_bicubic_psnr = 0.0
+total_test_psnr = 0.0
+
+for index, test_img_path in enumerate(test_img_paths[50:60]):
+ img = load_img(test_img_path)
+ lowres_input = get_lowres_image(img, upscale_factor)
+ w = lowres_input.size[0] * upscale_factor
+ h = lowres_input.size[1] * upscale_factor
+ highres_img = img.resize((w, h))
+ prediction = upscale_image(model, lowres_input)
+ lowres_img = lowres_input.resize((w, h))
+ lowres_img_arr = img_to_array(lowres_img)
+ highres_img_arr = img_to_array(highres_img)
+ predict_img_arr = img_to_array(prediction)
+ bicubic_psnr = tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255)
+ test_psnr = tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255)
+
+ total_bicubic_psnr += bicubic_psnr
+ total_test_psnr += test_psnr
+
+ print(
+ "PSNR of low resolution image and high resolution image is %.4f" % bicubic_psnr
+ )
+    print("PSNR of prediction and high resolution image is %.4f" % test_psnr)
+ plot_results(lowres_img, index, "lowres")
+ plot_results(highres_img, index, "highres")
+ plot_results(prediction, index, "prediction")
+
+print("Avg. PSNR of lowres images is %.4f" % (total_bicubic_psnr / 10))
+print("Avg. PSNR of reconstructions is %.4f" % (total_test_psnr / 10))
diff --git a/knowledge_base/vision/supervised-contrastive-learning.py b/knowledge_base/vision/supervised-contrastive-learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f6bfb8686e748c7b7f23786b0bc6fe5977bc178
--- /dev/null
+++ b/knowledge_base/vision/supervised-contrastive-learning.py
@@ -0,0 +1,238 @@
+"""
+Title: Supervised Contrastive Learning
+Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
+Date created: 2020/11/30
+Last modified: 2020/11/30
+Description: Using supervised contrastive learning for image classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)
+(Prannay Khosla et al.) is a training methodology that outperforms
+supervised training with crossentropy on classification tasks.
+
+Essentially, training an image classification model with Supervised Contrastive
+Learning is performed in two phases:
+
+1. Training an encoder to learn to produce vector representations of input images such
+that representations of images in the same class will be more similar compared to
+representations of images in different classes.
+2. Training a classifier on top of the frozen encoder.
+
+Note that this example requires [TensorFlow Addons](https://www.tensorflow.org/addons),
+which you can install using the following command:
+
+```shell
+pip install tensorflow-addons
+```
+
+## Setup
+"""
+
+import tensorflow as tf
+import tensorflow_addons as tfa
+import numpy as np
+from tensorflow import keras
+from tensorflow.keras import layers
+
+"""
+## Prepare the data
+"""
+
+num_classes = 10
+input_shape = (32, 32, 3)
+
+# Load the train and test data splits
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+
+# Display shapes of train and test datasets
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+
+"""
+## Using image data augmentation
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(0.02),
+ ]
+)
+
+# Setting the state of the normalization layer.
+data_augmentation.layers[0].adapt(x_train)
+
+"""
+## Build the encoder model
+
+The encoder model takes the image as input and turns it into a 2048-dimensional
+feature vector.
+"""
+
+
+def create_encoder():
+ resnet = keras.applications.ResNet50V2(
+ include_top=False, weights=None, input_shape=input_shape, pooling="avg"
+ )
+
+ inputs = keras.Input(shape=input_shape)
+ augmented = data_augmentation(inputs)
+ outputs = resnet(augmented)
+ model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-encoder")
+ return model
+
+
+encoder = create_encoder()
+encoder.summary()
+
+learning_rate = 0.001
+batch_size = 265
+hidden_units = 512
+projection_units = 128
+num_epochs = 50
+dropout_rate = 0.5
+temperature = 0.05
+
+"""
+## Build the classification model
+
+The classification model adds a fully-connected layer on top of the encoder,
+plus a softmax layer with the target classes.
+"""
+
+
+def create_classifier(encoder, trainable=True):
+ for layer in encoder.layers:
+ layer.trainable = trainable
+
+ inputs = keras.Input(shape=input_shape)
+ features = encoder(inputs)
+ features = layers.Dropout(dropout_rate)(features)
+ features = layers.Dense(hidden_units, activation="relu")(features)
+ features = layers.Dropout(dropout_rate)(features)
+ outputs = layers.Dense(num_classes, activation="softmax")(features)
+
+ model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-classifier")
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=keras.losses.SparseCategoricalCrossentropy(),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ )
+ return model
+
+
+"""
+## Experiment 1: Train the baseline classification model
+
+In this experiment, a baseline classifier is trained as usual, i.e., the
+encoder and the classifier parts are trained together as a single model
+to minimize the crossentropy loss.
+"""
+
+encoder = create_encoder()
+classifier = create_classifier(encoder)
+classifier.summary()
+
+history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
+
+accuracy = classifier.evaluate(x_test, y_test)[1]
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+
+"""
+## Experiment 2: Use supervised contrastive learning
+
+In this experiment, the model is trained in two phases. In the first phase,
+the encoder is pretrained to optimize the supervised contrastive loss,
+described in [Prannay Khosla et al.](https://arxiv.org/abs/2004.11362).
+
+In the second phase, the classifier is trained using the trained encoder with
+its weights frozen; only the weights of the fully-connected layers with the
+softmax are optimized.
+
+### 1. Supervised contrastive learning loss function
+"""
+
+
+class SupervisedContrastiveLoss(keras.losses.Loss):
+ def __init__(self, temperature=1, name=None):
+ super().__init__(name=name)
+ self.temperature = temperature
+
+ def __call__(self, labels, feature_vectors, sample_weight=None):
+ # Normalize feature vectors
+ feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
+ # Compute logits
+ logits = tf.divide(
+ tf.matmul(
+ feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
+ ),
+ self.temperature,
+ )
+ return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
+
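+"""
+A quick sanity check of the loss (illustrative, not part of the original
+example): it maps a batch of integer labels and feature vectors to a scalar,
+treating samples that share a label as positive pairs.
+"""
+
+demo_loss = SupervisedContrastiveLoss(temperature=0.05)
+print(demo_loss(tf.constant([0, 0, 1, 1]), tf.random.normal((4, 8))))
+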
+
+def add_projection_head(encoder):
+ inputs = keras.Input(shape=input_shape)
+ features = encoder(inputs)
+ outputs = layers.Dense(projection_units, activation="relu")(features)
+ model = keras.Model(
+ inputs=inputs, outputs=outputs, name="cifar-encoder_with_projection-head"
+ )
+ return model
+
+
+"""
+### 2. Pretrain the encoder
+"""
+
+encoder = create_encoder()
+
+encoder_with_projection_head = add_projection_head(encoder)
+encoder_with_projection_head.compile(
+ optimizer=keras.optimizers.Adam(learning_rate),
+ loss=SupervisedContrastiveLoss(temperature),
+)
+
+encoder_with_projection_head.summary()
+
+history = encoder_with_projection_head.fit(
+ x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs
+)
+
+"""
+### 3. Train the classifier with the frozen encoder
+"""
+
+classifier = create_classifier(encoder, trainable=False)
+
+history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
+
+accuracy = classifier.evaluate(x_test, y_test)[1]
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+"""
+We get an improved test accuracy.
+"""
+
+"""
+## Conclusion
+
+As shown in the experiments, using the supervised contrastive learning technique
+outperformed the conventional technique in terms of the test accuracy. Note that
+the same training budget (i.e., number of epochs) was given to each technique.
+Supervised contrastive learning pays off when the encoder involves a complex
+architecture, like ResNet, and multi-class problems with many labels.
+In addition, large batch sizes and multi-layer projection heads
+improve its effectiveness. See the [Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)
+paper for more details.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/supervised-contrastive-learning-cifar10)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/supervised-contrastive-learning).
+"""
diff --git a/knowledge_base/vision/swin_transformers.py b/knowledge_base/vision/swin_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..67172c7abf40bb3e633b2d5014f253da05ca6982
--- /dev/null
+++ b/knowledge_base/vision/swin_transformers.py
@@ -0,0 +1,601 @@
+"""
+Title: Image classification with Swin Transformers
+Author: [Rishit Dagli](https://twitter.com/rishit_dagli)
+Date created: 2021/09/08
+Last modified: 2021/09/08
+Description: Image classification using Swin Transformers, a general-purpose backbone for computer vision.
+Accelerator: GPU
+"""
+
+"""
+This example implements
+[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
+by Liu et al. for image classification, and demonstrates it on the
+[CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+
+Swin Transformer (**S**hifted **Win**dow Transformer) can serve as a
+general-purpose backbone for computer vision. Swin Transformer is a hierarchical
+Transformer whose representations are computed with _shifted windows_. The
+shifted window scheme brings greater efficiency by limiting self-attention
+computation to non-overlapping local windows while also allowing for
+cross-window connections. This architecture has the flexibility to model
+information at various scales and has a linear computational complexity with
+respect to image size.
+
+This example requires TensorFlow 2.5 or higher.
+"""
+
+"""
+## Setup
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf # For tf.data and preprocessing only.
+import keras
+from keras import layers
+from keras import ops
+
+"""
+## Configure the hyperparameters
+
+A key parameter to pick is the `patch_size`, the size of the input patches.
+In order to use each pixel as an individual input, you can set `patch_size` to
+`(1, 1)`. Below, we take inspiration from the original paper settings for
+training on ImageNet-1K, keeping most of the original settings for this example.
+"""
+
+num_classes = 100
+input_shape = (32, 32, 3)
+
+patch_size = (2, 2) # 2-by-2 sized patches
+dropout_rate = 0.03 # Dropout rate
+num_heads = 8 # Attention heads
+embed_dim = 64 # Embedding dimension
+num_mlp = 256 # MLP layer size
+# Convert embedded patches to query, key, and values with a learnable additive
+# value
+qkv_bias = True
+window_size = 2 # Size of attention window
+shift_size = 1 # Size of shifting window
+image_dimension = 32 # Initial image size
+
+num_patch_x = input_shape[0] // patch_size[0]
+num_patch_y = input_shape[1] // patch_size[1]
+
+learning_rate = 1e-3
+batch_size = 128
+num_epochs = 40
+validation_split = 0.1
+weight_decay = 0.0001
+label_smoothing = 0.1
+
+"""
+## Prepare the data
+
+We load the CIFAR-100 dataset through `keras.datasets`,
+normalize the images, and convert the integer labels to one-hot encoded vectors.
+"""
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+x_train, x_test = x_train / 255.0, x_test / 255.0
+y_train = keras.utils.to_categorical(y_train, num_classes)
+y_test = keras.utils.to_categorical(y_test, num_classes)
+num_train_samples = int(len(x_train) * (1 - validation_split))
+num_val_samples = len(x_train) - num_train_samples
+x_train, x_val = np.split(x_train, [num_train_samples])
+y_train, y_val = np.split(y_train, [num_train_samples])
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+plt.figure(figsize=(10, 10))
+for i in range(25):
+ plt.subplot(5, 5, i + 1)
+ plt.xticks([])
+ plt.yticks([])
+ plt.grid(False)
+ plt.imshow(x_train[i])
+plt.show()
+
+
+"""
+## Helper functions
+
+We create two helper functions that partition a feature map into a batch of
+non-overlapping windows, and merge such windows back into the feature map.
+"""
+
+
+def window_partition(x, window_size):
+ _, height, width, channels = x.shape
+ patch_num_y = height // window_size
+ patch_num_x = width // window_size
+ x = ops.reshape(
+ x,
+ (
+ -1,
+ patch_num_y,
+ window_size,
+ patch_num_x,
+ window_size,
+ channels,
+ ),
+ )
+ x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+ windows = ops.reshape(x, (-1, window_size, window_size, channels))
+ return windows
+
+
+def window_reverse(windows, window_size, height, width, channels):
+ patch_num_y = height // window_size
+ patch_num_x = width // window_size
+ x = ops.reshape(
+ windows,
+ (
+ -1,
+ patch_num_y,
+ patch_num_x,
+ window_size,
+ window_size,
+ channels,
+ ),
+ )
+ x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+ x = ops.reshape(x, (-1, height, width, channels))
+ return x
+
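+"""
+As a quick shape check (illustrative, not part of the original example):
+partitioning a `4x4` feature map with `3` channels into `2x2` windows yields
+`4` windows per image, and `window_reverse` restores the original shape.
+"""
+
+demo_features = ops.zeros((1, 4, 4, 3))
+demo_windows = window_partition(demo_features, window_size=2)
+print(demo_windows.shape)  # (4, 2, 2, 3)
+print(window_reverse(demo_windows, 2, 4, 4, 3).shape)  # (1, 4, 4, 3)
+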
+
+"""
+## Window based multi-head self-attention
+
+Usually Transformers perform global self-attention, where the relationships
+between a token and all other tokens are computed. The global computation leads
+to quadratic complexity with respect to the number of tokens. Here, as the
+[original paper](https://arxiv.org/abs/2103.14030) suggests, we compute
+self-attention within local windows, in a non-overlapping manner. Global
+self-attention leads to quadratic computational complexity in the number of
+patches, whereas window-based self-attention leads to linear complexity and is
+easily scalable.
+"""
+
+
+class WindowAttention(layers.Layer):
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ dropout_rate=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.dim = dim
+ self.window_size = window_size
+ self.num_heads = num_heads
+ self.scale = (dim // num_heads) ** -0.5
+ self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
+ self.dropout = layers.Dropout(dropout_rate)
+ self.proj = layers.Dense(dim)
+
+ num_window_elements = (2 * self.window_size[0] - 1) * (
+ 2 * self.window_size[1] - 1
+ )
+ self.relative_position_bias_table = self.add_weight(
+ shape=(num_window_elements, self.num_heads),
+ initializer=keras.initializers.Zeros(),
+ trainable=True,
+ )
+ coords_h = np.arange(self.window_size[0])
+ coords_w = np.arange(self.window_size[1])
+ coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
+ coords = np.stack(coords_matrix)
+ coords_flatten = coords.reshape(2, -1)
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ relative_coords = relative_coords.transpose([1, 2, 0])
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1)
+
+ self.relative_position_index = keras.Variable(
+ initializer=relative_position_index,
+ shape=relative_position_index.shape,
+ dtype="int",
+ trainable=False,
+ )
+
+ def call(self, x, mask=None):
+ _, size, channels = x.shape
+ head_dim = channels // self.num_heads
+ x_qkv = self.qkv(x)
+ x_qkv = ops.reshape(x_qkv, (-1, size, 3, self.num_heads, head_dim))
+ x_qkv = ops.transpose(x_qkv, (2, 0, 3, 1, 4))
+ q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
+ q = q * self.scale
+ k = ops.transpose(k, (0, 1, 3, 2))
+ attn = q @ k
+
+ num_window_elements = self.window_size[0] * self.window_size[1]
+ relative_position_index_flat = ops.reshape(self.relative_position_index, (-1,))
+ relative_position_bias = ops.take(
+ self.relative_position_bias_table,
+ relative_position_index_flat,
+ axis=0,
+ )
+ relative_position_bias = ops.reshape(
+ relative_position_bias,
+ (num_window_elements, num_window_elements, -1),
+ )
+ relative_position_bias = ops.transpose(relative_position_bias, (2, 0, 1))
+ attn = attn + ops.expand_dims(relative_position_bias, axis=0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ mask_float = ops.cast(
+ ops.expand_dims(ops.expand_dims(mask, axis=1), axis=0),
+ "float32",
+ )
+ attn = ops.reshape(attn, (-1, nW, self.num_heads, size, size)) + mask_float
+ attn = ops.reshape(attn, (-1, self.num_heads, size, size))
+ attn = keras.activations.softmax(attn, axis=-1)
+ else:
+ attn = keras.activations.softmax(attn, axis=-1)
+ attn = self.dropout(attn)
+
+ x_qkv = attn @ v
+ x_qkv = ops.transpose(x_qkv, (0, 2, 1, 3))
+ x_qkv = ops.reshape(x_qkv, (-1, size, channels))
+ x_qkv = self.proj(x_qkv)
+ x_qkv = self.dropout(x_qkv)
+ return x_qkv
+
+
+"""
+## The complete Swin Transformer model
+
+Finally, we put together the complete Swin Transformer by replacing the standard
+multi-head attention (MHA) with shifted-window attention. As suggested in the
+original paper, we create a model comprising a shifted window-based MHA
+layer, followed by a 2-layer MLP with GELU nonlinearity in between, applying
+`LayerNormalization` before each MSA layer and each MLP, and a residual
+connection after each of these layers.
+
+Notice that we only create a simple MLP with 2 Dense and
+2 Dropout layers. While larger feed-forward blocks are common in the
+literature, in this paper the authors use a
+2-layer MLP with GELU nonlinearity in between.
+"""
+
+
+class SwinTransformer(layers.Layer):
+ def __init__(
+ self,
+ dim,
+ num_patch,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ num_mlp=1024,
+ qkv_bias=True,
+ dropout_rate=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.dim = dim # number of input dimensions
+ self.num_patch = num_patch # number of embedded patches
+ self.num_heads = num_heads # number of attention heads
+ self.window_size = window_size # size of window
+ self.shift_size = shift_size # size of window shift
+ self.num_mlp = num_mlp # number of MLP nodes
+
+ self.norm1 = layers.LayerNormalization(epsilon=1e-5)
+ self.attn = WindowAttention(
+ dim,
+ window_size=(self.window_size, self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+ )
+ self.drop_path = layers.Dropout(dropout_rate)
+ self.norm2 = layers.LayerNormalization(epsilon=1e-5)
+
+ self.mlp = keras.Sequential(
+ [
+ layers.Dense(num_mlp),
+ layers.Activation(keras.activations.gelu),
+ layers.Dropout(dropout_rate),
+ layers.Dense(dim),
+ layers.Dropout(dropout_rate),
+ ]
+ )
+
+ if min(self.num_patch) < self.window_size:
+ self.shift_size = 0
+ self.window_size = min(self.num_patch)
+
+ def build(self, input_shape):
+ if self.shift_size == 0:
+ self.attn_mask = None
+ else:
+ height, width = self.num_patch
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ mask_array = np.zeros((1, height, width, 1))
+ count = 0
+ for h in h_slices:
+ for w in w_slices:
+ mask_array[:, h, w, :] = count
+ count += 1
+ mask_array = ops.convert_to_tensor(mask_array)
+
+ # mask array to windows
+ mask_windows = window_partition(mask_array, self.window_size)
+ mask_windows = ops.reshape(
+ mask_windows, [-1, self.window_size * self.window_size]
+ )
+ attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims(
+ mask_windows, axis=2
+ )
+ attn_mask = ops.where(attn_mask != 0, -100.0, attn_mask)
+ attn_mask = ops.where(attn_mask == 0, 0.0, attn_mask)
+ self.attn_mask = keras.Variable(
+ initializer=attn_mask,
+ shape=attn_mask.shape,
+ dtype=attn_mask.dtype,
+ trainable=False,
+ )
+
+ def call(self, x, training=False):
+ height, width = self.num_patch
+ _, num_patches_before, channels = x.shape
+ x_skip = x
+ x = self.norm1(x)
+ x = ops.reshape(x, (-1, height, width, channels))
+ if self.shift_size > 0:
+ shifted_x = ops.roll(
+ x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
+ )
+ else:
+ shifted_x = x
+
+ x_windows = window_partition(shifted_x, self.window_size)
+ x_windows = ops.reshape(
+ x_windows, (-1, self.window_size * self.window_size, channels)
+ )
+ attn_windows = self.attn(x_windows, mask=self.attn_mask)
+
+ attn_windows = ops.reshape(
+ attn_windows,
+ (-1, self.window_size, self.window_size, channels),
+ )
+ shifted_x = window_reverse(
+ attn_windows, self.window_size, height, width, channels
+ )
+ if self.shift_size > 0:
+ x = ops.roll(
+ shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
+ )
+ else:
+ x = shifted_x
+
+ x = ops.reshape(x, (-1, height * width, channels))
+ x = self.drop_path(x, training=training)
+ x = x_skip + x
+ x_skip = x
+ x = self.norm2(x)
+ x = self.mlp(x)
+ x = self.drop_path(x)
+ x = x_skip + x
+ return x
+
+
+"""
+## Model training and evaluation
+
+### Extract and embed patches
+
+We first create a patch-extraction function and 2 layers to help us extract, embed,
+and merge patches from the images, on top of which we will later use the Swin
+Transformer class we built.
+"""
+
+
+# Using tf ops since it is only used in tf.data.
+def patch_extract(images):
+ batch_size = tf.shape(images)[0]
+ patches = tf.image.extract_patches(
+ images=images,
+ sizes=(1, patch_size[0], patch_size[1], 1),
+ strides=(1, patch_size[0], patch_size[1], 1),
+ rates=(1, 1, 1, 1),
+ padding="VALID",
+ )
+ patch_dim = patches.shape[-1]
+ patch_num = patches.shape[1]
+ return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))
+
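+"""
+As a quick check (illustrative, not part of the original example): extracting
+`2x2` patches from a batch of `32x32x3` images yields `16 * 16 = 256` patches
+of dimension `2 * 2 * 3 = 12` each, matching the model input shape used below.
+"""
+
+print(patch_extract(x_train[:1].astype("float32")).shape)  # (1, 256, 12)
+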
+
+class PatchEmbedding(layers.Layer):
+ def __init__(self, num_patch, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.num_patch = num_patch
+ self.proj = layers.Dense(embed_dim)
+ self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)
+
+ def call(self, patch):
+ pos = ops.arange(start=0, stop=self.num_patch)
+ return self.proj(patch) + self.pos_embed(pos)
+
+
+class PatchMerging(keras.layers.Layer):
+ def __init__(self, num_patch, embed_dim):
+ super().__init__()
+ self.num_patch = num_patch
+ self.embed_dim = embed_dim
+ self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)
+
+ def call(self, x):
+ height, width = self.num_patch
+ _, _, C = x.shape
+ x = ops.reshape(x, (-1, height, width, C))
+ x0 = x[:, 0::2, 0::2, :]
+ x1 = x[:, 1::2, 0::2, :]
+ x2 = x[:, 0::2, 1::2, :]
+ x3 = x[:, 1::2, 1::2, :]
+ x = ops.concatenate((x0, x1, x2, x3), axis=-1)
+ x = ops.reshape(x, (-1, (height // 2) * (width // 2), 4 * C))
+ return self.linear_trans(x)
+
+
+"""
+### Prepare the tf.data.Dataset
+
+We perform all the steps that do not involve trainable weights with tf.data,
+and prepare the training, validation, and testing sets.
+
+"""
+
+
+def augment(x):
+ x = tf.image.random_crop(x, size=(image_dimension, image_dimension, 3))
+ x = tf.image.random_flip_left_right(x)
+ return x
+
+
+dataset = (
+ tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ .map(lambda x, y: (augment(x), y))
+ .batch(batch_size=batch_size)
+ .map(lambda x, y: (patch_extract(x), y))
+ .prefetch(tf.data.experimental.AUTOTUNE)
+)
+
+dataset_val = (
+ tf.data.Dataset.from_tensor_slices((x_val, y_val))
+ .batch(batch_size=batch_size)
+ .map(lambda x, y: (patch_extract(x), y))
+ .prefetch(tf.data.experimental.AUTOTUNE)
+)
+
+dataset_test = (
+ tf.data.Dataset.from_tensor_slices((x_test, y_test))
+ .batch(batch_size=batch_size)
+ .map(lambda x, y: (patch_extract(x), y))
+ .prefetch(tf.data.experimental.AUTOTUNE)
+)
+
+"""
+### Build the model
+
+We put together the Swin Transformer model.
+"""
+
+input = layers.Input(shape=(256, 12))  # (num_patch_x * num_patch_y, patch_size[0] * patch_size[1] * 3)
+x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(input)
+x = SwinTransformer(
+ dim=embed_dim,
+ num_patch=(num_patch_x, num_patch_y),
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0,
+ num_mlp=num_mlp,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+)(x)
+x = SwinTransformer(
+ dim=embed_dim,
+ num_patch=(num_patch_x, num_patch_y),
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=shift_size,
+ num_mlp=num_mlp,
+ qkv_bias=qkv_bias,
+ dropout_rate=dropout_rate,
+)(x)
+x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
+x = layers.GlobalAveragePooling1D()(x)
+output = layers.Dense(num_classes, activation="softmax")(x)
+
+"""
+### Train on CIFAR-100
+
+We train the model on CIFAR-100. Here, we only train the model
+for 40 epochs to keep the training time short in this example.
+In practice, you should train for 150 epochs to reach convergence.
+"""
+
+model = keras.Model(input, output)
+model.compile(
+ loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
+ optimizer=keras.optimizers.AdamW(
+ learning_rate=learning_rate, weight_decay=weight_decay
+ ),
+ metrics=[
+ keras.metrics.CategoricalAccuracy(name="accuracy"),
+ keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+)
+
+history = model.fit(
+ dataset,
+ batch_size=batch_size,
+ epochs=num_epochs,
+ validation_data=dataset_val,
+)
+
+"""
+Let's visualize the training progress of the model.
+"""
+
+plt.plot(history.history["loss"], label="train_loss")
+plt.plot(history.history["val_loss"], label="val_loss")
+plt.xlabel("Epochs")
+plt.ylabel("Loss")
+plt.title("Train and Validation Losses Over Epochs", fontsize=14)
+plt.legend()
+plt.grid()
+plt.show()
+
+"""
+Let's display the final results of the training on CIFAR-100.
+"""
+
+loss, accuracy, top_5_accuracy = model.evaluate(dataset_test)
+print(f"Test loss: {round(loss, 2)}")
+print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+"""
+The Swin Transformer model we just trained has just 152K parameters, and it gets
+us to ~75% test top-5 accuracy within just 40 epochs without any signs of
+overfitting, as seen in the graph above. This means we can train this network
+for longer (perhaps with a bit more regularization) and obtain even better
+performance. This performance can be improved further with additional techniques
+like a cosine decay learning rate schedule and other data augmentation techniques.
+While experimenting, I tried training the model for 150 epochs with slightly
+higher dropout and greater embedding dimensions, which pushes the performance to
+~72% test accuracy on CIFAR-100.
+
+The authors present a top-1 accuracy of 87.3% on ImageNet. The authors also
+present a number of experiments to study how input sizes, optimizers etc. affect
+the final performance of this model. The authors further demonstrate using this
+model for object detection, semantic segmentation and instance segmentation, and
+report competitive results for these tasks. You are strongly advised to also
+check out the [original paper](https://arxiv.org/abs/2103.14030).
+
+This example takes inspiration from the official
+[PyTorch](https://github.com/microsoft/Swin-Transformer) and
+[TensorFlow](https://github.com/VcampSoldiers/Swin-Transformer-Tensorflow)
+implementations.
+"""
diff --git a/knowledge_base/vision/temporal_latent_bottleneck.py b/knowledge_base/vision/temporal_latent_bottleneck.py
new file mode 100644
index 0000000000000000000000000000000000000000..21279501c5bc353ca9653a8f322121d796db2e93
--- /dev/null
+++ b/knowledge_base/vision/temporal_latent_bottleneck.py
@@ -0,0 +1,909 @@
+"""
+Title: When Recurrence meets Transformers
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
+Date created: 2023/03/12
+Last modified: 2024/10/29
+Description: Image Classification with Temporal Latent Bottleneck Networks.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+A simple Recurrent Neural Network (RNN) displays a strong inductive bias towards learning
+**temporally compressed representations**. **Equation 1** shows the recurrence formula,
+where `h_t` is the compressed representation (a single vector) of the entire input
+sequence `x`.
+
+| `h_t = f(h_{t-1}, x_t)` |
+| :--: |
+| **Equation 1**: The recurrence equation. (Source: Aritra and Suvaditya) |
+
+On the other hand, Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)) have
+little inductive bias towards learning temporally compressed representations.
+Transformers have achieved SoTA results in Natural Language Processing (NLP)
+and vision tasks with their pairwise attention mechanism.
+
+While the Transformer has the ability to **attend** to different sections of the input
+sequence, the computation of attention is quadratic in nature.
+
+[Didolkar et al.](https://arxiv.org/abs/2205.14794) argue that having a more compressed
+representation of a sequence may be beneficial for *generalization*, as it can be easily
+**re-used** and **re-purposed** with fewer irrelevant details. While compression is good,
+they also notice that too much of it can harm expressiveness.
+
+The authors propose a solution that divides computation into **two streams**. A *slow
+stream* that is recurrent in nature and a *fast stream* that is parameterized as a
+Transformer. While this method has the novelty of introducing different processing
+streams in order to preserve and process latent states, it has parallels drawn in other
+works like the [Perceiver Mechanism (by Jaegle et. al.)](https://arxiv.org/abs/2103.03206)
+and [Grounded Language Learning Fast and Slow (by Hill et. al.)](https://arxiv.org/abs/2009.01719).
+
+The following example explores how we can make use of the new Temporal Latent Bottleneck
+mechanism to perform image classification on the CIFAR-10 dataset. We implement this
+model by making a custom `RNNCell` implementation in order to make a **performant** and
+**vectorized** design.
+"""
+
+"""
+## Setup imports
+"""
+import os
+
+import keras
+from keras import layers, ops, mixed_precision
+from keras.optimizers import AdamW
+import numpy as np
+import random
+from matplotlib import pyplot as plt
+
+# Set seed for reproducibility.
+keras.utils.set_random_seed(42)
+
+"""
+## Setting required configuration
+
+We set a few configuration parameters that are needed within the pipeline we have
+designed. The current parameters are for use with the
+[CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+
+
+The model also supports `mixed-precision` settings, which would cast the model to use
+`16-bit` float numbers where it can, while keeping some parameters in `32-bit` as needed
+for numerical stability. This brings performance benefits, as the footprint of the model
+decreases significantly while bringing speed boosts at inference time.
+"""
+
+config = {
+ "mixed_precision": True,
+ "dataset": "cifar10",
+ "train_slice": 40_000,
+ "batch_size": 2048,
+ "buffer_size": 2048 * 2,
+ "input_shape": [32, 32, 3],
+ "image_size": 48,
+ "num_classes": 10,
+ "learning_rate": 1e-4,
+ "weight_decay": 1e-4,
+ "epochs": 30,
+ "patch_size": 4,
+ "embed_dim": 64,
+ "chunk_size": 8,
+ "r": 2,
+ "num_layers": 4,
+ "ffn_drop": 0.2,
+ "attn_drop": 0.2,
+ "num_heads": 1,
+}
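+
+# A quick sanity check (ours, not part of the original example): with these settings,
+# each 48x48 image yields (48 // 4) ** 2 = 144 patch tokens, which are grouped into
+# 144 // 8 = 18 chunks of 8 tokens each -- the sequence length the RNN will see.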
+
+if config["mixed_precision"]:
+ policy = mixed_precision.Policy("mixed_float16")
+ mixed_precision.set_global_policy(policy)
+
+"""
+## Loading the CIFAR-10 dataset
+
+We are going to use the CIFAR10 dataset for running our experiments. This dataset
+contains a training set of `50,000` images across `10` classes, with the standard image size
+of `(32, 32, 3)`.
+
+It also has a separate test set of `10,000` images with similar characteristics. More
+information about the dataset may be found at the official site for the dataset, as well
+as in the [`keras.datasets.cifar10`](https://keras.io/api/datasets/cifar10/) API reference.
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
+ (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
+)
+
+"""
+## Define data augmentation for the training and validation/test pipelines
+
+We define separate pipelines for performing image augmentation on our data. This step is
+important to make the model more robust to changes, helping it generalize better.
+The preprocessing and augmentation steps we perform are as follows:
+
+- `Rescaling` (training, test): This step is performed to normalize all image pixel
+values from the `[0, 255]` range to `[0, 1]`. This helps in maintaining numerical stability
+later during training.
+
+- `Resizing` (training, test): We resize the image from its original size of `(32, 32)` to
+`(52, 52)`. This is done to account for the Random Crop, as well as to comply with the
+specifications of the data given in the paper.
+
+- `RandomCrop` (training): This layer randomly selects a crop/sub-region of the image
+with size `(48, 48)`.
+
+- `RandomFlip` (training): This layer randomly flips all the images horizontally,
+keeping image sizes the same.
+"""
+
+# Build the `train` augmentation pipeline.
+train_augmentation = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0, dtype="float32"),
+ layers.Resizing(
+ config["input_shape"][0] + 20,
+ config["input_shape"][0] + 20,
+ dtype="float32",
+ ),
+ layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
+ layers.RandomFlip("horizontal", dtype="float32"),
+ ],
+ name="train_data_augmentation",
+)
+
+# Build the `val` and `test` data pipeline.
+test_augmentation = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0, dtype="float32"),
+ layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
+ ],
+ name="test_data_augmentation",
+)
+
+# We define functions in place of simple lambda functions to run through the
+# `keras.Sequential` in order to solve this warning:
+# (https://github.com/tensorflow/tensorflow/issues/56089)
+
+
+def train_map_fn(image, label):
+ return train_augmentation(image), label
+
+
+def test_map_fn(image, label):
+ return test_augmentation(image), label
+
+
+"""
+## Load dataset into `PyDataset` object
+
+We take the `np.ndarray` instances of the datasets, wrap them in a
+`keras.utils.PyDataset` subclass, and apply augmentations with Keras
+preprocessing layers.
+"""
+
+
+class Dataset(keras.utils.PyDataset):
+ def __init__(
+ self, x_data, y_data, batch_size, preprocess_fn=None, shuffle=False, **kwargs
+ ):
+ if shuffle:
+ perm = np.random.permutation(len(x_data))
+ x_data = x_data[perm]
+ y_data = y_data[perm]
+ self.x_data = x_data
+ self.y_data = y_data
+ self.preprocess_fn = preprocess_fn
+ self.batch_size = batch_size
+        super().__init__(**kwargs)
+
+ def __len__(self):
+ return len(self.x_data) // self.batch_size
+
+ def __getitem__(self, idx):
+ batch_x, batch_y = [], []
+ for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
+ x, y = self.x_data[i], self.y_data[i]
+ if self.preprocess_fn:
+ x, y = self.preprocess_fn(x, y)
+ batch_x.append(x)
+ batch_y.append(y)
+ batch_x = ops.stack(batch_x, axis=0)
+ batch_y = ops.stack(batch_y, axis=0)
+ return batch_x, batch_y
+
+
+train_ds = Dataset(
+ x_train, y_train, config["batch_size"], preprocess_fn=train_map_fn, shuffle=True
+)
+val_ds = Dataset(x_val, y_val, config["batch_size"], preprocess_fn=test_map_fn)
+test_ds = Dataset(x_test, y_test, config["batch_size"], preprocess_fn=test_map_fn)
+
+"""
+## Temporal Latent Bottleneck
+
+An excerpt from the paper:
+
+> In the brain, short-term and long-term memory have developed in a specialized way.
+Short-term memory is allowed to change very quickly to react to immediate sensory inputs
+and perception. By contrast, long-term memory changes slowly, is highly selective and
+involves repeated consolidation.
+
+Inspired by short-term and long-term memory, the authors introduce the fast stream
+and slow stream computation. The fast stream has a short-term memory with a high capacity
+that reacts quickly to sensory input (Transformers). The slow stream has long-term memory
+which updates at a slower rate and summarizes the most relevant information (Recurrence).
+
+To implement this idea we need to:
+
+- Take a sequence of data.
+- Divide the sequence into fixed-size chunks.
+- The fast stream operates within each chunk. It provides fine-grained local information.
+- The slow stream consolidates and aggregates information across chunks. It provides
+coarse-grained distant information.
+
+The fast and slow streams induce what is called **information asymmetry**. The two streams
+interact with each other through a bottleneck of attention. **Figure 1** shows the
+architecture of the model.
+
+*Figure 1: Architecture of the model. (Source: https://arxiv.org/abs/2205.14794)*
+
+A PyTorch-style pseudocode is also proposed by the authors as shown in **Algorithm 1**.
+
+*Algorithm 1: PyTorch-style pseudocode. (Source: https://arxiv.org/abs/2205.14794)*
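+
+Since the figure cannot be reproduced here, below is our own rough paraphrase of the
+two-stream recurrence in Python-style pseudocode (a sketch, not the authors' exact
+code; `chunks` is hypothetical, while `perceptual_module`, `tlb_module` and `r` mirror
+the names used in the implementation further down):
+
+```python
+slow = initial_state                      # the temporal latent bottleneck state
+for chunk in chunks:                      # the fast stream works chunk by chunk
+    fast = chunk
+    for idx, layer in enumerate(perceptual_module):
+        fast = layer(fast, fast, fast, stream="fast")      # self-attention + FFN
+        if idx % r == 0:
+            fast = layer(fast, slow, slow, stream="slow")  # read the slow state
+    slow = tlb_module(slow, fast, fast)   # cross-attention updates the slow state
+```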
+
+"""
+
+"""
+### `PatchEmbedding` layer
+
+This custom `keras.layers.Layer` is useful for generating patches from the image and
+transforming them into a higher-dimensional embedding space using `keras.layers.Embedding`.
+The patching operation is done using a `keras.layers.Conv2D` instance.
+
+Once the patching of images is complete, we reshape the image patches in order to get a
+flattened representation where the number of dimensions is the embedding dimension. At
+this stage, we also inject positional information to the tokens.
+
+After we obtain the tokens, we chunk them. The chunking operation involves taking
+fixed-size sequences from the embedding output to create 'chunks', which will then be
+used as the final input to the model.
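+
+As a quick sanity check, here is a hedged sketch (our addition, using the configuration
+defined above) of how the shapes flow through this layer:
+
+```python
+embed = PatchEmbedding(
+    image_size=(48, 48), patch_size=(4, 4), embed_dim=64, chunk_size=8
+)
+tokens = embed(ops.ones((2, 48, 48, 3)))
+# Conv2D: (2, 48, 48, 3) -> (2, 12, 12, 64)
+# Reshape: -> (2, 144, 64), then position embeddings are added
+# Chunking: -> (2, 18, 8, 64)
+print(tokens.shape)  # expected: (2, 18, 8, 64)
+```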
+"""
+
+
+class PatchEmbedding(layers.Layer):
+ """Image to Patch Embedding.
+ Args:
+ image_size (`Tuple[int]`): Size of the input image.
+ patch_size (`Tuple[int]`): Size of the patch.
+ embed_dim (`int`): Dimension of the embedding.
+ chunk_size (`int`): Number of patches to be chunked.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ patch_size,
+ embed_dim,
+ chunk_size,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ # Compute the patch resolution.
+ patch_resolution = [
+ image_size[0] // patch_size[0],
+ image_size[1] // patch_size[1],
+ ]
+
+ # Store the parameters.
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.patch_resolution = patch_resolution
+ self.num_patches = patch_resolution[0] * patch_resolution[1]
+
+ # Define the positions of the patches.
+ self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
+
+ # Create the layers.
+ self.projection = layers.Conv2D(
+ filters=embed_dim,
+ kernel_size=patch_size,
+ strides=patch_size,
+ name="projection",
+ )
+ self.flatten = layers.Reshape(
+ target_shape=(-1, embed_dim),
+ name="flatten",
+ )
+ self.position_embedding = layers.Embedding(
+ input_dim=self.num_patches,
+ output_dim=embed_dim,
+ name="position_embedding",
+ )
+ self.layernorm = keras.layers.LayerNormalization(
+ epsilon=1e-5,
+ name="layernorm",
+ )
+ self.chunking_layer = layers.Reshape(
+ target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
+ name="chunking_layer",
+ )
+
+ def call(self, inputs):
+ # Project the inputs to the embedding dimension.
+ x = self.projection(inputs)
+
+        # Flatten the patches and add position embedding.
+ x = self.flatten(x)
+ x = x + self.position_embedding(self.positions)
+
+ # Normalize the embeddings.
+ x = self.layernorm(x)
+
+ # Chunk the tokens.
+ x = self.chunking_layer(x)
+
+ return x
+
+
+"""
+### `FeedForwardNetwork` Layer
+
+This custom `keras.layers.Layer` instance allows us to define a generic FFN along with a
+dropout.
+"""
+
+
+class FeedForwardNetwork(layers.Layer):
+ """Feed Forward Network.
+ Args:
+ dims (`int`): Number of units in FFN.
+ dropout (`float`): Dropout probability for FFN.
+ """
+
+ def __init__(self, dims, dropout, **kwargs):
+ super().__init__(**kwargs)
+
+ # Create the layers.
+ self.ffn = keras.Sequential(
+ [
+ layers.Dense(units=4 * dims, activation="gelu"),
+ layers.Dense(units=dims),
+ layers.Dropout(rate=dropout),
+ ],
+ name="ffn",
+ )
+ self.layernorm = layers.LayerNormalization(
+ epsilon=1e-5,
+ name="layernorm",
+ )
+
+ def call(self, inputs):
+ # Apply the FFN.
+ x = self.layernorm(inputs)
+ x = inputs + self.ffn(x)
+ return x
+
+
+"""
+### `BaseAttention` layer
+
+This custom `keras.layers.Layer` instance is a `super`/`base` class that wraps a
+`keras.layers.MultiHeadAttention` layer along with some other components. This gives us
+basic common denominator functionality for all the Attention layers/modules in our model.
+"""
+
+
+class BaseAttention(layers.Layer):
+ """Base Attention Module.
+ Args:
+ num_heads (`int`): Number of attention heads.
+ key_dim (`int`): Size of each attention head for key.
+ dropout (`float`): Dropout probability for attention module.
+ """
+
+ def __init__(self, num_heads, key_dim, dropout, **kwargs):
+ super().__init__(**kwargs)
+ self.multi_head_attention = layers.MultiHeadAttention(
+ num_heads=num_heads,
+ key_dim=key_dim,
+ dropout=dropout,
+ name="mha",
+ )
+ self.query_layernorm = layers.LayerNormalization(
+ epsilon=1e-5,
+ name="q_layernorm",
+ )
+ self.key_layernorm = layers.LayerNormalization(
+ epsilon=1e-5,
+ name="k_layernorm",
+ )
+ self.value_layernorm = layers.LayerNormalization(
+ epsilon=1e-5,
+ name="v_layernorm",
+ )
+
+ self.attention_scores = None
+
+ def call(self, input_query, key, value):
+ # Apply the attention module.
+ query = self.query_layernorm(input_query)
+ key = self.key_layernorm(key)
+ value = self.value_layernorm(value)
+ (attention_outputs, attention_scores) = self.multi_head_attention(
+ query=query,
+ key=key,
+ value=value,
+ return_attention_scores=True,
+ )
+
+ # Save the attention scores for later visualization.
+ self.attention_scores = attention_scores
+
+ # Add the input to the attention output.
+ x = input_query + attention_outputs
+ return x
+
+
+"""
+### `Attention` with `FeedForwardNetwork` layer
+
+This custom `keras.layers.Layer` implementation combines the `BaseAttention` and
+`FeedForwardNetwork` components to develop one block which will be used repeatedly within
+the model. This module is highly customizable and flexible, allowing for changes within
+the internal layers.
+"""
+
+
+class AttentionWithFFN(layers.Layer):
+ """Attention with Feed Forward Network.
+ Args:
+ ffn_dims (`int`): Number of units in FFN.
+ ffn_dropout (`float`): Dropout probability for FFN.
+ num_heads (`int`): Number of attention heads.
+ key_dim (`int`): Size of each attention head for key.
+ attn_dropout (`float`): Dropout probability for attention module.
+ """
+
+ def __init__(
+ self,
+ ffn_dims,
+ ffn_dropout,
+ num_heads,
+ key_dim,
+ attn_dropout,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ # Create the layers.
+        self.fast_stream_attention = BaseAttention(
+            num_heads=num_heads,
+            key_dim=key_dim,
+            dropout=attn_dropout,
+            name="fast_stream_attn",
+        )
+        self.slow_stream_attention = BaseAttention(
+            num_heads=num_heads,
+            key_dim=key_dim,
+            dropout=attn_dropout,
+            name="slow_stream_attn",
+        )
+ self.ffn = FeedForwardNetwork(
+ dims=ffn_dims,
+ dropout=ffn_dropout,
+ name="ffn",
+ )
+
+ self.attention_scores = None
+
+ def build(self, input_shape):
+ self.built = True
+
+ def call(self, query, key, value, stream="fast"):
+ # Apply the attention module.
+ attention_layer = {
+ "fast": self.fast_stream_attention,
+ "slow": self.slow_stream_attention,
+ }[stream]
+ if len(query.shape) == 2:
+ query = ops.expand_dims(query, -1)
+ if len(key.shape) == 2:
+ key = ops.expand_dims(key, -1)
+ if len(value.shape) == 2:
+ value = ops.expand_dims(value, -1)
+ x = attention_layer(query, key, value)
+
+ # Save the attention scores for later visualization.
+ self.attention_scores = attention_layer.attention_scores
+
+ # Apply the FFN.
+ x = self.ffn(x)
+ return x
+
+
+"""
+### Custom RNN Cell for **Temporal Latent Bottleneck** and **Perceptual Module**
+
+**Algorithm 1** (the pseudocode) depicts recurrence with the help of for loops. Looping
+makes the implementation simpler, but it harms training time. In this section we wrap
+the custom recurrence logic inside the `CustomRecurrentCell`. This custom cell will
+then be wrapped with the [Keras RNN API](https://keras.io/api/layers/recurrent_layers/rnn/)
+that makes the entire code vectorizable.
+
+This custom cell, implemented as a `keras.layers.Layer`, is the integral part of the
+logic for the model.
+The cell's functionality can be divided into two parts:
+
+- **Slow Stream (Temporal Latent Bottleneck):**
+  - This module consists of a single `AttentionWithFFN` layer that parses the output of the
+  previous Slow Stream, an intermediate hidden representation (which is the *latent* in
+  Temporal Latent Bottleneck) as the Query, and the output of the latest Fast Stream as Key
+  and Value. This layer can also be construed as a *CrossAttention* layer.
+- **Fast Stream (Perceptual Module):**
+  - This module consists of intertwined `AttentionWithFFN` layers. This stream consists of
+  *n* layers of `SelfAttention` and `CrossAttention` in a sequential manner.
+  - Some of these layers take the chunked input as the Query, Key and Value (the
+  *SelfAttention* layers).
+  - The other layers take the intermediate state outputs from within the Temporal Latent
+  Bottleneck module as the Query while using the output of the previous Self-Attention
+  layers before it as the Key and Value.
+"""
+
+
+class CustomRecurrentCell(layers.Layer):
+ """Custom Recurrent Cell.
+ Args:
+ chunk_size (`int`): Number of tokens in a chunk.
+ r (`int`): One Cross Attention per **r** Self Attention.
+ num_layers (`int`): Number of layers.
+ ffn_dims (`int`): Number of units in FFN.
+ ffn_dropout (`float`): Dropout probability for FFN.
+ num_heads (`int`): Number of attention heads.
+ key_dim (`int`): Size of each attention head for key.
+ attn_dropout (`float`): Dropout probability for attention module.
+ """
+
+ def __init__(
+ self,
+ chunk_size,
+ r,
+ num_layers,
+ ffn_dims,
+ ffn_dropout,
+ num_heads,
+ key_dim,
+ attn_dropout,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ # Save the arguments.
+ self.chunk_size = chunk_size
+ self.r = r
+ self.num_layers = num_layers
+ self.ffn_dims = ffn_dims
+        self.ffn_dropout = ffn_dropout
+ self.num_heads = num_heads
+ self.key_dim = key_dim
+ self.attn_dropout = attn_dropout
+
+ # Create state_size. This is important for
+ # custom recurrence logic.
+ self.state_size = chunk_size * ffn_dims
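+        # Note (ours): the Keras RNN API keeps the recurrent state as a flat
+        # (batch, state_size) tensor, so the slow-stream latents are stored
+        # flattened and reshaped back to (batch, chunk_size, ffn_dims) in `call()`.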
+
+ self.get_attention_scores = False
+ self.attention_scores = []
+
+ # Perceptual Module
+ perceptual_module = list()
+ for layer_idx in range(num_layers):
+ perceptual_module.append(
+ AttentionWithFFN(
+ ffn_dims=ffn_dims,
+ ffn_dropout=ffn_dropout,
+ num_heads=num_heads,
+ key_dim=key_dim,
+ attn_dropout=attn_dropout,
+ name=f"pm_self_attn_{layer_idx}",
+ )
+ )
+ if layer_idx % r == 0:
+ perceptual_module.append(
+ AttentionWithFFN(
+ ffn_dims=ffn_dims,
+ ffn_dropout=ffn_dropout,
+ num_heads=num_heads,
+ key_dim=key_dim,
+ attn_dropout=attn_dropout,
+ name=f"pm_cross_attn_ffn_{layer_idx}",
+ )
+ )
+ self.perceptual_module = perceptual_module
+
+ # Temporal Latent Bottleneck Module
+ self.tlb_module = AttentionWithFFN(
+ ffn_dims=ffn_dims,
+ ffn_dropout=ffn_dropout,
+ num_heads=num_heads,
+ key_dim=key_dim,
+ attn_dropout=attn_dropout,
+            name="tlb_cross_attn_ffn",
+ )
+
+ def build(self, input_shape):
+ self.built = True
+
+ def call(self, inputs, states):
+ # inputs => (batch, chunk_size, dims)
+ # states => [(batch, chunk_size, units)]
+ slow_stream = ops.reshape(states[0], (-1, self.chunk_size, self.ffn_dims))
+ fast_stream = inputs
+
+ for layer_idx, layer in enumerate(self.perceptual_module):
+ fast_stream = layer(
+ query=fast_stream, key=fast_stream, value=fast_stream, stream="fast"
+ )
+
+ if layer_idx % self.r == 0:
+ fast_stream = layer(
+ query=fast_stream, key=slow_stream, value=slow_stream, stream="slow"
+ )
+
+ slow_stream = self.tlb_module(
+ query=slow_stream, key=fast_stream, value=fast_stream
+ )
+
+ # Save the attention scores for later visualization.
+ if self.get_attention_scores:
+ self.attention_scores.append(self.tlb_module.attention_scores)
+
+ return fast_stream, [
+ ops.reshape(slow_stream, (-1, self.chunk_size * self.ffn_dims))
+ ]
+
+
+"""
+### `TemporalLatentBottleneckModel` to encapsulate full model
+
+Here, we just wrap the full model so as to expose it for training.
+"""
+
+
+class TemporalLatentBottleneckModel(keras.Model):
+ """Model Trainer.
+ Args:
+ patch_layer (`keras.layers.Layer`): Patching layer.
+ custom_cell (`keras.layers.Layer`): Custom Recurrent Cell.
+ """
+
+ def __init__(self, patch_layer, custom_cell, unroll_loops=False, **kwargs):
+ super().__init__(**kwargs)
+ self.patch_layer = patch_layer
+ self.rnn = layers.RNN(custom_cell, unroll=unroll_loops, name="rnn")
+ self.gap = layers.GlobalAveragePooling1D(name="gap")
+ self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")
+
+ def call(self, inputs):
+ x = self.patch_layer(inputs)
+ x = self.rnn(x)
+ x = self.gap(x)
+ outputs = self.head(x)
+ return outputs
+
+
+"""
+## Build the model
+
+To begin training, we now define the components individually and pass them as arguments
+to our wrapper class, which will prepare the final model for training. We define a
+`PatchEmbedding` layer, and an RNN based on the `CustomRecurrentCell`.
+"""
+
+# Build the model.
+patch_layer = PatchEmbedding(
+ image_size=(config["image_size"], config["image_size"]),
+ patch_size=(config["patch_size"], config["patch_size"]),
+ embed_dim=config["embed_dim"],
+ chunk_size=config["chunk_size"],
+)
+custom_rnn_cell = CustomRecurrentCell(
+ chunk_size=config["chunk_size"],
+ r=config["r"],
+ num_layers=config["num_layers"],
+ ffn_dims=config["embed_dim"],
+ ffn_dropout=config["ffn_drop"],
+ num_heads=config["num_heads"],
+ key_dim=config["embed_dim"],
+ attn_dropout=config["attn_drop"],
+)
+model = TemporalLatentBottleneckModel(
+ patch_layer=patch_layer,
+ custom_cell=custom_rnn_cell,
+)
+
+"""
+## Metrics and Callbacks
+
+We use the `AdamW` optimizer since it has been shown to perform very well on several benchmark
+tasks from an optimization perspective. It is a version of the `keras.optimizers.Adam`
+optimizer with decoupled weight decay built in.
+
+For a loss function, we make use of `keras.losses.SparseCategoricalCrossentropy`, which
+computes the cross-entropy between the predicted class probabilities and the integer labels. We
+also calculate accuracy on our data as a sanity-check.
+"""
+
+optimizer = AdamW(
+ learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
+)
+model.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+)
+
+"""
+## Train the model with `model.fit()`
+
+We pass the training dataset and run training.
+"""
+
+history = model.fit(
+ train_ds,
+ epochs=config["epochs"],
+ validation_data=val_ds,
+)
+
+"""
+## Visualize training metrics
+
+`model.fit()` returns a `History` object, which stores the values of the metrics
+generated during the training run (but it is ephemeral and needs to be saved manually).
+
+We now display the Loss and Accuracy curves for the training and validation sets.
+"""
+
+plt.plot(history.history["loss"], label="loss")
+plt.plot(history.history["val_loss"], label="val_loss")
+plt.legend()
+plt.show()
+
+plt.plot(history.history["accuracy"], label="accuracy")
+plt.plot(history.history["val_accuracy"], label="val_accuracy")
+plt.legend()
+plt.show()
+
+"""
+## Visualize attention maps from the Temporal Latent Bottleneck
+
+Now that we have trained our model, it is time for some visualizations. The Fast Stream
+(Transformers) processes a chunk of tokens. The Slow Stream processes each chunk and
+attends to tokens that are useful for the task.
+
+In this section we visualize the attention map of the Slow Stream. This is done by
+extracting the attention scores from the TLB layer at each chunk and
+storing them on the RNN cell. This is followed by upsampling ('ballooning') them
+back up to the image resolution for display.
+"""
+
+
+def score_to_viz(chunk_score):
+ # get the most attended token
+ chunk_viz = ops.max(chunk_score, axis=-2)
+ # get the mean across heads
+ chunk_viz = ops.mean(chunk_viz, axis=1)
+ return chunk_viz
+
+
+# Get a batch of images and labels from the testing dataset
+images, labels = next(iter(test_ds))
+
+# Create a new model instance that is executed eagerly to allow saving
+# attention scores. This also requires unrolling loops
+eager_model = TemporalLatentBottleneckModel(
+ patch_layer=patch_layer, custom_cell=custom_rnn_cell, unroll_loops=True
+)
+eager_model.compile(run_eagerly=True, jit_compile=False)
+model.save("weights.keras")
+eager_model.load_weights("weights.keras")
+
+# Set the `get_attention_scores` flag to True.
+eager_model.rnn.cell.get_attention_scores = True
+
+# Run the model with the testing images and grab the
+# attention scores.
+outputs = eager_model(images)
+list_chunk_scores = eager_model.rnn.cell.attention_scores
+
+# Process the attention scores in order to visualize them
+num_chunks = (config["image_size"] // config["patch_size"]) ** 2 // config["chunk_size"]
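+# With the default config: (48 // 4) ** 2 // 8 = 144 // 8 = 18 chunks per image.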
+list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores[-num_chunks:]]
+chunk_viz = ops.concatenate(list_chunk_viz, axis=-1)
+chunk_viz = ops.reshape(
+ chunk_viz,
+ (
+ config["batch_size"],
+ config["image_size"] // config["patch_size"],
+ config["image_size"] // config["patch_size"],
+ 1,
+ ),
+)
+upsampled_heat_map = layers.UpSampling2D(
+ size=(4, 4), interpolation="bilinear", dtype="float32"
+)(chunk_viz)
+
+"""
+Run the following code snippet to get different images and their attention maps.
+"""
+
+# Sample a random image
+index = random.randint(0, config["batch_size"] - 1)
+orig_image = images[index]
+overlay_image = upsampled_heat_map[index, ..., 0]
+
+if keras.backend.backend() == "torch":
+ # when using the torch backend, we are required to ensure that the
+ # image is copied from the GPU
+ orig_image = orig_image.cpu().detach().numpy()
+ overlay_image = overlay_image.cpu().detach().numpy()
+
+# Plot the visualization
+fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
+
+ax[0].imshow(orig_image)
+ax[0].set_title("Original:")
+ax[0].axis("off")
+
+image = ax[1].imshow(orig_image)
+ax[1].imshow(
+ overlay_image,
+ cmap="inferno",
+ alpha=0.6,
+ extent=image.get_extent(),
+)
+ax[1].set_title("TLB Attention:")
+
+plt.show()
+
+"""
+## Conclusion
+
+This example has demonstrated an implementation of the Temporal Latent Bottleneck
+mechanism. It highlights the compression and storage of historical states
+in the form of a Temporal Latent Bottleneck, with regular updates from a Perceptual
+Module, as an effective way to balance compression with expressiveness.
+
+In the original paper, the authors conducted extensive tests across different
+modalities, ranging from Supervised Image Classification to applications in Reinforcement
+Learning.
+
+While we have only displayed a method to apply this mechanism to Image Classification, it
+can be extended to other modalities too with minimal changes.
+
+*Note*: While building this example we did not have the official code to refer to. This
+means that our implementation is inspired by the paper with no claims of being a
+complete reproduction. For more details on the training process one can head over to
+[our GitHub repository](https://github.com/suvadityamuk/Temporal-Latent-Bottleneck-TF).
+"""
+
+"""
+## Acknowledgement
+
+Thanks to [Aniket Didolkar](https://www.aniketdidolkar.in/) (the first author) and
+[Anirudh Goyal](https://anirudh9119.github.io/) (the third author)
+for reviewing our work.
+
+We would like to thank
+[PyImageSearch](https://pyimagesearch.com/) for a Colab Pro account and
+[JarvisLabs.ai](https://cloud.jarvislabs.ai/) for the GPU credits.
+"""
diff --git a/knowledge_base/vision/token_learner.py b/knowledge_base/vision/token_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f86a5983ec8e5acf8a7a8ef25ba5646e1b727d3c
--- /dev/null
+++ b/knowledge_base/vision/token_learner.py
@@ -0,0 +1,515 @@
+"""
+Title: Learning to tokenize in Vision Transformers
+Authors: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)
+Date created: 2021/12/10
+Last modified: 2023/08/14
+Description: Adaptively generating a smaller number of tokens for Vision Transformers.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Vision Transformers ([Dosovitskiy et al.](https://arxiv.org/abs/2010.11929)) and many
+other Transformer-based architectures ([Liu et al.](https://arxiv.org/abs/2103.14030),
+[Yuan et al.](https://arxiv.org/abs/2101.11986), etc.) have shown strong results in
+image recognition. The following provides a brief overview of the components involved in the
+Vision Transformer architecture for image classification:
+
+* Extract small patches from input images.
+* Linearly project those patches.
+* Add positional embeddings to these linear projections.
+* Run these projections through a series of Transformer ([Vaswani et al.](https://arxiv.org/abs/1706.03762))
+blocks.
+* Finally, take the representation from the final Transformer block and add a
+classification head.
+
+If we take 224x224 images and extract 16x16 patches, we get a total of 196 patches (also
+called tokens) for each image. The number of patches increases as we increase the
+resolution, leading to a higher memory footprint. Could we use a reduced
+number of patches without having to compromise performance?
+Ryoo et al. investigate this question in
+[TokenLearner: Adaptive Space-Time Tokenization for Videos](https://openreview.net/forum?id=z-l1kpDXs88).
+They introduce a novel module called **TokenLearner** that can help reduce the number
+of patches used by a Vision Transformer (ViT) in an adaptive manner. With TokenLearner
+incorporated in the standard ViT architecture, they are able to reduce the amount of
+compute (measured in FLOPS) used by the model.
+
+In this example, we implement the TokenLearner module and demonstrate its
+performance with a mini ViT and the CIFAR-10 dataset. We make use of the following
+references:
+
+* [Official TokenLearner code](https://github.com/google-research/scenic/blob/main/scenic/projects/token_learner/model.py)
+* [Image Classification with ViTs on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer/)
+* [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf)
+"""
+
+
+"""
+## Imports
+"""
+
+import keras
+from keras import layers
+from keras import ops
+from tensorflow import data as tf_data
+
+
+from datetime import datetime
+import matplotlib.pyplot as plt
+import numpy as np
+
+import math
+
+"""
+## Hyperparameters
+
+Please feel free to change the hyperparameters and check your results. The best way to
+develop intuition about the architecture is to experiment with it.
+"""
+
+# DATA
+BATCH_SIZE = 256
+AUTO = tf_data.AUTOTUNE
+INPUT_SHAPE = (32, 32, 3)
+NUM_CLASSES = 10
+
+# OPTIMIZER
+LEARNING_RATE = 1e-3
+WEIGHT_DECAY = 1e-4
+
+# TRAINING
+EPOCHS = 1
+
+# AUGMENTATION
+IMAGE_SIZE = 48 # We will resize input images to this size.
+PATCH_SIZE = 6 # Size of the patches to be extracted from the input images.
+NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
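+# With the defaults above, this is (48 // 6) ** 2 = 64 patches per image.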
+
+# ViT ARCHITECTURE
+LAYER_NORM_EPS = 1e-6
+PROJECTION_DIM = 128
+NUM_HEADS = 4
+NUM_LAYERS = 4
+MLP_UNITS = [
+ PROJECTION_DIM * 2,
+ PROJECTION_DIM,
+]
+
+# TOKENLEARNER
+NUM_TOKENS = 4
+
+"""
+## Load and prepare the CIFAR-10 dataset
+"""
+
+# Load the CIFAR-10 dataset.
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
+(x_train, y_train), (x_val, y_val) = (
+ (x_train[:40000], y_train[:40000]),
+ (x_train[40000:], y_train[40000:]),
+)
+print(f"Training samples: {len(x_train)}")
+print(f"Validation samples: {len(x_val)}")
+print(f"Testing samples: {len(x_test)}")
+
+# Convert to tf.data.Dataset objects.
+train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
+train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)
+
+val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
+val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
+test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
+
+"""
+## Data augmentation
+
+The augmentation pipeline consists of:
+
+- Rescaling
+- Resizing
+- Random cropping (fixed-sized or random sized)
+- Random horizontal flipping
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Rescaling(1 / 255.0),
+ layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
+ layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ ],
+ name="data_augmentation",
+)
+
+"""
+Note that image data augmentation layers do not apply data transformations at inference time.
+This means that when these layers are called with `training=False` they behave differently. Refer
+[to the documentation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/) for more
+details.
+"""
+
+"""
+## Positional embedding module
+
+A [Transformer](https://arxiv.org/abs/1706.03762) architecture consists of **multi-head
+self attention** layers and **fully-connected feed forward** networks (MLP) as the main
+components. Both these components are _permutation invariant_: they're not aware of
+feature order.
+
+To overcome this problem we inject tokens with positional information. The
+`position_embedding` function adds this positional information to the linearly projected
+tokens.
+"""
+
+
+class PatchEncoder(layers.Layer):
+ def __init__(self, num_patches, projection_dim):
+ super().__init__()
+ self.num_patches = num_patches
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+
+ def call(self, patch):
+ positions = ops.expand_dims(
+ ops.arange(start=0, stop=self.num_patches, step=1), axis=0
+ )
+ encoded = patch + self.position_embedding(positions)
+ return encoded
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_patches": self.num_patches,
+                "projection_dim": self.position_embedding.output_dim,
+            }
+        )
+        return config
+
+
+"""
+## MLP block for Transformer
+
+This serves as the Fully Connected Feed Forward block for our Transformer.
+"""
+
+
+def mlp(x, dropout_rate, hidden_units):
+ # Iterate over the hidden units and
+ # add Dense => Dropout.
+ for units in hidden_units:
+ x = layers.Dense(units, activation=ops.gelu)(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+"""
+## TokenLearner module
+
+A pictorial overview of the module is available in the accompanying
+[Google AI blog post](https://ai.googleblog.com/2021/12/improving-vision-transformer-efficiency.html).
+
+The TokenLearner module takes as input an image-shaped tensor. It then passes it through
+multiple single-channel convolutional layers extracting different spatial attention maps
+focusing on different parts of the input. These attention maps are then element-wise
+multiplied with the input and the result is aggregated with pooling. This pooled output can be
+treated as a summary of the input and has far fewer patches (8, for example)
+than the original (196, for example).
+
+Using multiple convolution layers helps with expressivity. Imposing a form of spatial
+attention helps retain relevant information from the inputs. Both of these components are
+crucial to make TokenLearner work, especially when we are significantly reducing the number of patches.
+"""
+
+
+def token_learner(inputs, number_of_tokens=NUM_TOKENS):
+ # Layer normalize the inputs.
+ x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs) # (B, H, W, C)
+
+ # Applying Conv2D => Reshape => Permute
+    # The reshape and permute are done to help with the next steps of
+ # multiplication and Global Average Pooling.
+ attention_maps = keras.Sequential(
+ [
+ # 3 layers of conv with gelu activation as suggested
+ # in the paper.
+ layers.Conv2D(
+ filters=number_of_tokens,
+ kernel_size=(3, 3),
+ activation=ops.gelu,
+ padding="same",
+ use_bias=False,
+ ),
+ layers.Conv2D(
+ filters=number_of_tokens,
+ kernel_size=(3, 3),
+ activation=ops.gelu,
+ padding="same",
+ use_bias=False,
+ ),
+ layers.Conv2D(
+ filters=number_of_tokens,
+ kernel_size=(3, 3),
+ activation=ops.gelu,
+ padding="same",
+ use_bias=False,
+ ),
+ # This conv layer will generate the attention maps
+ layers.Conv2D(
+ filters=number_of_tokens,
+ kernel_size=(3, 3),
+ activation="sigmoid", # Note sigmoid for [0, 1] output
+ padding="same",
+ use_bias=False,
+ ),
+ # Reshape and Permute
+ layers.Reshape((-1, number_of_tokens)), # (B, H*W, num_of_tokens)
+ layers.Permute((2, 1)),
+ ]
+ )(
+ x
+ ) # (B, num_of_tokens, H*W)
+
+ # Reshape the input to align it with the output of the conv block.
+ num_filters = inputs.shape[-1]
+ inputs = layers.Reshape((1, -1, num_filters))(inputs) # inputs == (B, 1, H*W, C)
+
+ # Element-Wise multiplication of the attention maps and the inputs
+ attended_inputs = (
+ ops.expand_dims(attention_maps, axis=-1) * inputs
+ ) # (B, num_tokens, H*W, C)
+
+ # Global average pooling the element wise multiplication result.
+ outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)
+ return outputs
+
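+
+# Usage sketch (ours, hedged): with the defaults above, a (B, 8, 8, 128) feature
+# map goes in and (B, NUM_TOKENS, 128) = (B, 4, 128) learned tokens come out,
+# reducing 64 spatial tokens to 4.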
+
+"""
+## Transformer block
+"""
+
+
+def transformer(encoded_patches):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
+
+ # Multi Head Self Attention layer 1.
+ attention_output = layers.MultiHeadAttention(
+ num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
+ )(x1, x1)
+
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
+
+ # MLP layer 1.
+ x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)
+
+ # Skip connection 2.
+ encoded_patches = layers.Add()([x4, x2])
+ return encoded_patches
+
+
+"""
+## ViT model with the TokenLearner module
+"""
+
+
+def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
+ inputs = layers.Input(shape=INPUT_SHAPE) # (B, H, W, C)
+
+ # Augment data.
+ augmented = data_augmentation(inputs)
+
+    # Create patches and project the patches.
+ projected_patches = layers.Conv2D(
+ filters=PROJECTION_DIM,
+ kernel_size=(PATCH_SIZE, PATCH_SIZE),
+ strides=(PATCH_SIZE, PATCH_SIZE),
+ padding="VALID",
+ )(augmented)
+ _, h, w, c = projected_patches.shape
+ projected_patches = layers.Reshape((h * w, c))(
+ projected_patches
+ ) # (B, number_patches, projection_dim)
+
+ # Add positional embeddings to the projected patches.
+ encoded_patches = PatchEncoder(
+ num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
+ )(
+ projected_patches
+ ) # (B, number_patches, projection_dim)
+ encoded_patches = layers.Dropout(0.1)(encoded_patches)
+
+ # Iterate over the number of layers and stack up blocks of
+ # Transformer.
+ for i in range(NUM_LAYERS):
+ # Add a Transformer block.
+ encoded_patches = transformer(encoded_patches)
+
+ # Add TokenLearner layer in the middle of the
+ # architecture. The paper suggests that anywhere
+        # between 1/2 and 3/4 of the way through works well.
+ if use_token_learner and i == NUM_LAYERS // 2:
+ _, hh, c = encoded_patches.shape
+ h = int(math.sqrt(hh))
+ encoded_patches = layers.Reshape((h, h, c))(
+ encoded_patches
+ ) # (B, h, h, projection_dim)
+ encoded_patches = token_learner(
+ encoded_patches, token_learner_units
+ ) # (B, num_tokens, c)
+
+ # Layer normalization and Global average pooling.
+ representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
+ representation = layers.GlobalAvgPool1D()(representation)
+
+ # Classify outputs.
+ outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)
+
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+As shown in the [TokenLearner paper](https://openreview.net/forum?id=z-l1kpDXs88), it is
+almost always advantageous to include the TokenLearner module in the middle of the
+network.
+"""
+
+"""
+## Training utility
+"""
+
+
+def run_experiment(model):
+ # Initialize the AdamW optimizer.
+ optimizer = keras.optimizers.AdamW(
+ learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
+ )
+
+ # Compile the model with the optimizer, loss function
+ # and the metrics.
+ model.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ # Define callbacks
+ checkpoint_filepath = "/tmp/checkpoint.weights.h5"
+ checkpoint_callback = keras.callbacks.ModelCheckpoint(
+ checkpoint_filepath,
+ monitor="val_accuracy",
+ save_best_only=True,
+ save_weights_only=True,
+ )
+
+ # Train the model.
+ _ = model.fit(
+ train_ds,
+ epochs=EPOCHS,
+ validation_data=val_ds,
+ callbacks=[checkpoint_callback],
+ )
+
+ model.load_weights(checkpoint_filepath)
+ _, accuracy, top_5_accuracy = model.evaluate(test_ds)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+
+"""
+## Train and evaluate a ViT with TokenLearner
+"""
+
+vit_token_learner = create_vit_classifier()
+run_experiment(vit_token_learner)
+
+"""
+## Results
+
+We experimented with and without the TokenLearner inside the mini ViT we implemented
+(with the same hyperparameters presented in this example). Here are our results:
+
+| **TokenLearner** | **# tokens in TokenLearner** | **Top-1 Acc (Averaged across 5 runs)** | **GFLOPs** | **TensorBoard** |
+|:---:|:---:|:---:|:---:|:---:|
+| N | - | 56.112% | 0.0184 | [Link](https://tensorboard.dev/experiment/vkCwM49dQZ2RiK0ZT4mj7w/) |
+| Y | 8 | **56.55%** | **0.0153** | [Link](https://tensorboard.dev/experiment/vkCwM49dQZ2RiK0ZT4mj7w/) |
+| N | - | 56.37% | 0.0184 | [Link](https://tensorboard.dev/experiment/hdyJ4wznQROwqZTgbtmztQ/) |
+| Y | 4 | **56.4980%** | **0.0147** | [Link](https://tensorboard.dev/experiment/hdyJ4wznQROwqZTgbtmztQ/) |
+| N | - (# Transformer layers: 8) | 55.36% | 0.0359 | [Link](https://tensorboard.dev/experiment/sepBK5zNSaOtdCeEG6SV9w/) |
+
+TokenLearner is able to consistently outperform our mini ViT without the module. It is
+also interesting to notice that it was able to outperform a deeper version of our
+mini ViT (with 8 layers). The authors report similar observations in the paper and
+they attribute this to the adaptiveness of TokenLearner.
+
+One should also note that the FLOPs count **decreases** considerably with the addition of
+the TokenLearner module. With a lower FLOPs count, the TokenLearner module is able to
+deliver better results. This aligns very well with the authors' findings.
+
+Additionally, the authors [introduced](https://github.com/google-research/scenic/blob/main/scenic/projects/token_learner/model.py#L104)
+a newer version of the TokenLearner for smaller training data regimes. Quoting the authors:
+
+> Instead of using 4 conv. layers with small channels to implement spatial attention,
+ this version uses 2 grouped conv. layers with more channels. It also uses softmax
+ instead of sigmoid. We confirmed that this version works better when having limited
+ training data, such as training with ImageNet1K from scratch.
+
+We experimented with this module and in the following table we summarize the results:
+
+| **# Groups** | **# Tokens** | **Top-1 Acc** | **GFLOPs** | **TensorBoard** |
+|:---:|:---:|:---:|:---:|:---:|
+| 4 | 4 | 54.638% | 0.0149 | [Link](https://tensorboard.dev/experiment/KmfkGqAGQjikEw85phySmw/) |
+| 8 | 8 | 54.898% | 0.0146 | [Link](https://tensorboard.dev/experiment/0PpgYOq9RFWV9njX6NJQ2w/) |
+| 4 | 8 | 55.196% | 0.0149 | [Link](https://tensorboard.dev/experiment/WUkrHbZASdu3zrfmY4ETZg/) |
+
+Please note that we used the same hyperparameters presented in this example. Our
+implementation is available
+[in this notebook](https://github.com/ariG23498/TokenLearner/blob/master/TokenLearner-V1.1.ipynb).
+We acknowledge that the results with this new TokenLearner module are slightly worse
+than expected; this might improve with hyperparameter tuning.
+
+*Note*: To compute the FLOPs of our models we used
+[this utility](https://github.com/AdityaKane2001/regnety/blob/main/regnety/utils/model_utils.py#L27)
+from [this repository](https://github.com/AdityaKane2001/regnety).
+"""
+
+"""
+## Number of parameters
+
+You may have noticed that adding the TokenLearner module increases the number of
+parameters of the base network. But that does not mean it is less efficient as shown by
+[Dehghani et al.](https://arxiv.org/abs/2110.12894). Similar findings were reported
+by [Bello et al.](https://arxiv.org/abs/2103.07579) as well. The TokenLearner module
+helps reduce the FLOPs in the overall network, thereby helping to reduce the memory
+footprint.
+"""
+
+"""
+## Final notes
+
+* TokenFuser: The authors of the paper also propose another module named TokenFuser. This
+module helps in remapping the representation of the TokenLearner output back to its
+original spatial resolution. To reuse the TokenLearner in the ViT architecture, the
+TokenFuser is a must. We first learn the tokens from the TokenLearner, build a
+representation of the tokens with a Transformer layer, and then remap the representation
+into the original spatial resolution, so that it can again be consumed by a TokenLearner.
+Note that without pairing it with the TokenFuser, you can only use the TokenLearner
+module once in the entire ViT model.
+* Use of these modules for video: The authors also suggest that TokenFuser goes really
+well with Vision Transformers for Videos ([Arnab et al.](https://arxiv.org/abs/2103.15691)).
+
+We are grateful to [JarvisLabs](https://jarvislabs.ai/) and
+[Google Developers Experts](https://developers.google.com/programs/experts/)
+program for helping with GPU credits. Also, we are thankful to Michael Ryoo (first
+author of TokenLearner) for fruitful discussions.
+"""
diff --git a/knowledge_base/vision/video_classification.py b/knowledge_base/vision/video_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d3805164a90ed3ec9af70269806aaf735836da2
--- /dev/null
+++ b/knowledge_base/vision/video_classification.py
@@ -0,0 +1,384 @@
+"""
+Title: Video Classification with a CNN-RNN Architecture
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/05/28
+Last modified: 2023/12/08
+Description: Training a video classifier with transfer learning and a recurrent model on the UCF101 dataset.
+Accelerator: GPU
+"""
+
+"""
+This example demonstrates video classification, an important use-case with
+applications in recommendations, security, and so on.
+We will be using the [UCF101 dataset](https://www.crcv.ucf.edu/data/UCF101.php)
+to build our video classifier. The dataset consists of videos categorized into different
+actions, like cricket shot, punching, biking, etc. This dataset is commonly used to
+build action recognizers, which are an application of video classification.
+
+A video consists of an ordered sequence of frames. Each frame contains *spatial*
+information, and the sequence of those frames contains *temporal* information. To model
+both of these aspects, we use a hybrid architecture that consists of convolutions
+(for spatial processing) as well as recurrent layers (for temporal processing).
+Specifically, we'll use a Convolutional Neural Network (CNN) and a Recurrent Neural
+Network (RNN) consisting of [GRU layers](https://keras.io/api/layers/recurrent_layers/gru/).
+This kind of hybrid architecture is popularly known as a **CNN-RNN**.
+
+This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be
+installed using the following command:
+"""
+
+"""shell
+pip install -q git+https://github.com/tensorflow/docs
+"""
+
+"""
+## Data collection
+
+In order to keep the runtime of this example relatively short, we will be using a
+subsampled version of the original UCF101 dataset. You can refer to
+[this notebook](https://colab.research.google.com/github/sayakpaul/Action-Recognition-in-TensorFlow/blob/main/Data_Preparation_UCF101.ipynb)
+to know how the subsampling was done.
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz
+tar xf ucf101_top5.tar.gz
+"""
+
+"""
+## Setup
+"""
+import os
+
+import keras
+from imutils import paths
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+import imageio
+import cv2
+from IPython.display import Image
+
+"""
+## Define hyperparameters
+"""
+
+IMG_SIZE = 224
+BATCH_SIZE = 64
+EPOCHS = 10
+
+MAX_SEQ_LENGTH = 20
+NUM_FEATURES = 2048
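+# Note (ours): NUM_FEATURES = 2048 matches the pooled output size of the
+# InceptionV3 feature extractor built below.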
+
+"""
+## Data preparation
+"""
+
+train_df = pd.read_csv("train.csv")
+test_df = pd.read_csv("test.csv")
+
+print(f"Total videos for training: {len(train_df)}")
+print(f"Total videos for testing: {len(test_df)}")
+
+train_df.sample(10)
+
+"""
+One of the many challenges of training video classifiers is figuring out a way to feed
+the videos to a network. [This blog post](https://blog.coast.ai/five-video-classification-methods-implemented-in-keras-and-tensorflow-99cad29cc0b5)
+discusses five such methods. Since a video is an ordered sequence of frames, we could
+just extract the frames and put them in a 3D tensor. But the number of frames may differ
+from video to video which would prevent us from stacking them into batches
+(unless we use padding). As an alternative, we can **save video frames at a fixed
+interval until a maximum frame count is reached**. In this example we will do
+the following:
+
+1. Capture the frames of a video.
+2. Extract frames from the videos until a maximum frame count is reached.
+3. In cases where a video's frame count is less than the maximum frame count, we
+will pad the video with zeros.
+
+Note that this workflow is identical to [problems involving text sequences](https://developers.google.com/machine-learning/guides/text-classification/). Videos of the UCF101 dataset are [known](https://www.crcv.ucf.edu/papers/UCF101_CRCV-TR-12-01.pdf)
+not to contain extreme variations in objects and actions across frames. Because of this,
+it may be okay to only consider a few frames for the learning task. But this approach may
+not generalize well to other video classification problems. We will be using
+[OpenCV's `VideoCapture()` method](https://docs.opencv.org/master/dd/d43/tutorial_py_video_display.html)
+to read frames from videos.
+"""
+
+# The following two methods are taken from this tutorial:
+# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub
+
+
+def crop_center_square(frame):
+ y, x = frame.shape[0:2]
+ min_dim = min(y, x)
+ start_x = (x // 2) - (min_dim // 2)
+ start_y = (y // 2) - (min_dim // 2)
+ return frame[start_y : start_y + min_dim, start_x : start_x + min_dim]
+
+
+def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
+ cap = cv2.VideoCapture(path)
+ frames = []
+ try:
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ frame = crop_center_square(frame)
+ frame = cv2.resize(frame, resize)
+ frame = frame[:, :, [2, 1, 0]]
+ frames.append(frame)
+
+ if len(frames) == max_frames:
+ break
+ finally:
+ cap.release()
+ return np.array(frames)
+
+
+"""
+We can use a pre-trained network to extract meaningful features from the extracted
+frames. The [`Keras Applications`](https://keras.io/api/applications/) module provides
+a number of state-of-the-art models pre-trained on the [ImageNet-1k dataset](http://image-net.org/).
+We will be using the [InceptionV3 model](https://arxiv.org/abs/1512.00567) for this purpose.
+"""
+
+
+def build_feature_extractor():
+ feature_extractor = keras.applications.InceptionV3(
+ weights="imagenet",
+ include_top=False,
+ pooling="avg",
+ input_shape=(IMG_SIZE, IMG_SIZE, 3),
+ )
+ preprocess_input = keras.applications.inception_v3.preprocess_input
+
+ inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
+ preprocessed = preprocess_input(inputs)
+
+ outputs = feature_extractor(preprocessed)
+ return keras.Model(inputs, outputs, name="feature_extractor")
+
+
+feature_extractor = build_feature_extractor()
+
+"""
+The labels of the videos are strings. Neural networks do not understand string values,
+so they must be converted to some numerical form before they are fed to the model. Here
+we will use the [`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup)
+layer to encode the class labels as integers.
+"""
+
+label_processor = keras.layers.StringLookup(
+ num_oov_indices=0, vocabulary=np.unique(train_df["tag"])
+)
+print(label_processor.get_vocabulary())
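+# With the top-5 UCF101 subsample used here, this is expected to print:
+# ['CricketShot', 'PlayingCello', 'Punch', 'ShavingBeard', 'TennisSwing']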
+
+"""
+Finally, we can put all the pieces together to create our data processing utility.
+"""
+
+
+def prepare_all_videos(df, root_dir):
+ num_samples = len(df)
+ video_paths = df["video_name"].values.tolist()
+ labels = df["tag"].values
+ labels = keras.ops.convert_to_numpy(label_processor(labels[..., None]))
+
+ # `frame_masks` and `frame_features` are what we will feed to our sequence model.
+ # `frame_masks` will contain a bunch of booleans denoting if a timestep is
+ # masked with padding or not.
+ frame_masks = np.zeros(shape=(num_samples, MAX_SEQ_LENGTH), dtype="bool")
+ frame_features = np.zeros(
+ shape=(num_samples, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
+ )
+
+ # For each video.
+ for idx, path in enumerate(video_paths):
+ # Gather all its frames and add a batch dimension.
+ frames = load_video(os.path.join(root_dir, path))
+ frames = frames[None, ...]
+
+ # Initialize placeholders to store the masks and features of the current video.
+ temp_frame_mask = np.zeros(
+ shape=(
+ 1,
+ MAX_SEQ_LENGTH,
+ ),
+ dtype="bool",
+ )
+ temp_frame_features = np.zeros(
+ shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
+ )
+
+ # Extract features from the frames of the current video.
+ for i, batch in enumerate(frames):
+ video_length = batch.shape[0]
+ length = min(MAX_SEQ_LENGTH, video_length)
+ for j in range(length):
+ temp_frame_features[i, j, :] = feature_extractor.predict(
+ batch[None, j, :],
+ verbose=0,
+ )
+ temp_frame_mask[i, :length] = 1 # 1 = not masked, 0 = masked
+
+ frame_features[idx,] = temp_frame_features.squeeze()
+ frame_masks[idx,] = temp_frame_mask.squeeze()
+
+ return (frame_features, frame_masks), labels
+
+
+train_data, train_labels = prepare_all_videos(train_df, "train")
+test_data, test_labels = prepare_all_videos(test_df, "test")
+
+print(f"Frame features in train set: {train_data[0].shape}")
+print(f"Frame masks in train set: {train_data[1].shape}")
+
+"""
+The above code block will take ~20 minutes to execute, depending on the machine it's
+executed on.
+"""
+
+"""
+## The sequence model
+
+Now, we can feed this data to a sequence model consisting of recurrent layers like `GRU`.
+
+"""
+
+
+# Utility for our sequence model.
+def get_sequence_model():
+ class_vocab = label_processor.get_vocabulary()
+
+ frame_features_input = keras.Input((MAX_SEQ_LENGTH, NUM_FEATURES))
+ mask_input = keras.Input((MAX_SEQ_LENGTH,), dtype="bool")
+
+ # Refer to the following tutorial to understand the significance of using `mask`:
+ # https://keras.io/api/layers/recurrent_layers/gru/
+ x = keras.layers.GRU(16, return_sequences=True)(
+ frame_features_input, mask=mask_input
+ )
+ x = keras.layers.GRU(8)(x)
+ x = keras.layers.Dropout(0.4)(x)
+ x = keras.layers.Dense(8, activation="relu")(x)
+ output = keras.layers.Dense(len(class_vocab), activation="softmax")(x)
+
+ rnn_model = keras.Model([frame_features_input, mask_input], output)
+
+ rnn_model.compile(
+ loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
+ )
+ return rnn_model
+
+
+# Utility for running experiments.
+def run_experiment():
+ filepath = "/tmp/video_classifier/ckpt.weights.h5"
+ checkpoint = keras.callbacks.ModelCheckpoint(
+ filepath, save_weights_only=True, save_best_only=True, verbose=1
+ )
+
+ seq_model = get_sequence_model()
+ history = seq_model.fit(
+ [train_data[0], train_data[1]],
+ train_labels,
+ validation_split=0.3,
+ epochs=EPOCHS,
+ callbacks=[checkpoint],
+ )
+
+ seq_model.load_weights(filepath)
+ _, accuracy = seq_model.evaluate([test_data[0], test_data[1]], test_labels)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+ return history, seq_model
+
+
+_, sequence_model = run_experiment()
+
+"""
+**Note**: To keep the runtime of this example relatively short, we just used a few
+training examples. This number of training examples is low with respect to the sequence
+model being used that has 99,909 trainable parameters. You are encouraged to sample more
+data from the UCF101 dataset using [the notebook](https://colab.research.google.com/github/sayakpaul/Action-Recognition-in-TensorFlow/blob/main/Data_Preparation_UCF101.ipynb) mentioned above and train the same model.
+"""
+
+"""
+## Inference
+"""
+
+
+def prepare_single_video(frames):
+ frames = frames[None, ...]
+ frame_mask = np.zeros(
+ shape=(
+ 1,
+ MAX_SEQ_LENGTH,
+ ),
+ dtype="bool",
+ )
+ frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32")
+
+ for i, batch in enumerate(frames):
+ video_length = batch.shape[0]
+ length = min(MAX_SEQ_LENGTH, video_length)
+ for j in range(length):
+ frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
+ frame_mask[i, :length] = 1 # 1 = not masked, 0 = masked
+
+ return frame_features, frame_mask
+
+
+def sequence_prediction(path):
+ class_vocab = label_processor.get_vocabulary()
+
+ frames = load_video(os.path.join("test", path))
+ frame_features, frame_mask = prepare_single_video(frames)
+ probabilities = sequence_model.predict([frame_features, frame_mask])[0]
+
+ for i in np.argsort(probabilities)[::-1]:
+ print(f" {class_vocab[i]}: {probabilities[i] * 100:5.2f}%")
+ return frames
+
+
+# This utility is for visualization.
+# Referenced from:
+# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub
+def to_gif(images):
+ converted_images = images.astype(np.uint8)
+ imageio.mimsave("animation.gif", converted_images, duration=100)
+ return Image("animation.gif")
+
+
+test_video = np.random.choice(test_df["video_name"].values.tolist())
+print(f"Test video path: {test_video}")
+test_frames = sequence_prediction(test_video)
+to_gif(test_frames[:MAX_SEQ_LENGTH])
+
+"""
+## Next steps
+
+* In this example, we made use of transfer learning for extracting meaningful features
+from video frames. You could also fine-tune the pre-trained network to notice how that
+affects the end results.
+* For speed-accuracy trade-offs, you can try out other models present inside
+`keras.applications`.
+* Try different combinations of `MAX_SEQ_LENGTH` to observe how that affects the
+performance.
+* Train on a higher number of classes and see if you are able to get good performance.
+* Following [this tutorial](https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub), try a
+[pre-trained action recognition model](https://arxiv.org/abs/1705.07750) from DeepMind.
+* Rolling-averaging can be a useful technique for video classification, and it can be
+combined with a standard image classification model to infer on videos.
+[This tutorial](https://www.pyimagesearch.com/2019/07/15/video-classification-with-keras-and-deep-learning/)
+will help you understand how to use rolling-averaging with an image classifier.
+* When there are variations between the frames of a video, not all the frames might be
+equally important for deciding its category. In those situations, putting a
+[self-attention layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention) in the
+sequence model will likely yield better results.
+* Following [this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11),
+you can implement Transformers-based models for processing videos.
+"""
diff --git a/knowledge_base/vision/video_transformers.py b/knowledge_base/vision/video_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b21e1ae3a598588f3327a38c7115e3858448a4
--- /dev/null
+++ b/knowledge_base/vision/video_transformers.py
@@ -0,0 +1,414 @@
+"""
+Title: Video Classification with Transformers
+Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Date created: 2021/06/08
+Last modified: 2023/07/22
+Description: Training a video classifier with hybrid transformers.
+Accelerator: GPU
+Converted to Keras 3 by: [Soumik Rakshit](http://github.com/soumik12345)
+"""
+
+"""
+This example is a follow-up to the
+[Video Classification with a CNN-RNN Architecture](https://keras.io/examples/vision/video_classification/)
+example. This time, we will be using a Transformer-based model
+([Vaswani et al.](https://arxiv.org/abs/1706.03762)) to classify videos. You can follow
+[this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11)
+in case you need an introduction to Transformers (with code). After reading this
+example, you will know how to develop hybrid Transformer-based models for video
+classification that operate on CNN feature maps.
+"""
+
+"""shell
+pip install -q git+https://github.com/tensorflow/docs
+"""
+
+"""
+## Data collection
+
+As done in the [predecessor](https://keras.io/examples/vision/video_classification/) to
+this example, we will be using a subsampled version of the
+[UCF101 dataset](https://www.crcv.ucf.edu/data/UCF101.php),
+a well-known benchmark dataset. In case you want to operate on a larger subsample or
+even the entire dataset, please refer to
+[this notebook](https://colab.research.google.com/github/sayakpaul/Action-Recognition-in-TensorFlow/blob/main/Data_Preparation_UCF101.ipynb).
+"""
+
+"""shell
+wget -q https://github.com/sayakpaul/Action-Recognition-in-TensorFlow/releases/download/v1.0.0/ucf101_top5.tar.gz
+tar -xf ucf101_top5.tar.gz
+"""
+
+"""
+## Setup
+"""
+
+import os
+import keras
+from keras import layers
+from keras.applications.densenet import DenseNet121
+
+from tensorflow_docs.vis import embed
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import numpy as np
+import imageio
+import cv2
+
+keras.utils.set_random_seed(1234)
+
+"""
+## Define hyperparameters
+"""
+
+MAX_SEQ_LENGTH = 20
+NUM_FEATURES = 1024
+IMG_SIZE = 128
+
+EPOCHS = 5
+
+"""
+## Data preparation
+
+We will mostly be following the same data preparation steps in this example, except for
+the following changes:
+
+* We reduce the image size to 128x128 instead of 224x224 to speed up computation.
+* Instead of using a pre-trained [InceptionV3](https://arxiv.org/abs/1512.00567) network,
+we use a pre-trained
+[DenseNet121](http://openaccess.thecvf.com/content_cvpr_2017/papers/Huang_Densely_Connected_Convolutional_CVPR_2017_paper.pdf)
+for feature extraction.
+* We directly pad shorter videos to length `MAX_SEQ_LENGTH`.
+
+First, let's load up the
+[DataFrames](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html).
+"""
+
+train_df = pd.read_csv("train.csv")
+test_df = pd.read_csv("test.csv")
+
+print(f"Total videos for training: {len(train_df)}")
+print(f"Total videos for testing: {len(test_df)}")
+
+center_crop_layer = layers.CenterCrop(IMG_SIZE, IMG_SIZE)
+
+
+def crop_center(frame):
+ cropped = center_crop_layer(frame[None, ...])
+ cropped = keras.ops.convert_to_numpy(cropped)
+ cropped = keras.ops.squeeze(cropped)
+ return cropped
+
+
+# Following method is modified from this tutorial:
+# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub
+def load_video(path, max_frames=0, offload_to_cpu=False):
+    cap = cv2.VideoCapture(path)
+    frames = []
+    try:
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            frame = frame[:, :, [2, 1, 0]]  # BGR -> RGB
+            frame = crop_center(frame)
+            # `crop_center` already returns a NumPy array, so the torch-specific
+            # offload is only needed if a frame is still a torch tensor.
+            if offload_to_cpu and hasattr(frame, "to"):
+                frame = frame.to("cpu").numpy()
+            frames.append(frame)
+
+            if len(frames) == max_frames:
+                break
+    finally:
+        cap.release()
+    return np.array(frames)
+
+
+def build_feature_extractor():
+ feature_extractor = DenseNet121(
+ weights="imagenet",
+ include_top=False,
+ pooling="avg",
+ input_shape=(IMG_SIZE, IMG_SIZE, 3),
+ )
+ preprocess_input = keras.applications.densenet.preprocess_input
+
+ inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
+ preprocessed = preprocess_input(inputs)
+
+ outputs = feature_extractor(preprocessed)
+ return keras.Model(inputs, outputs, name="feature_extractor")
+
+
+feature_extractor = build_feature_extractor()
+
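+"""
+A quick shape check (illustrative, not part of the original example): thanks to
+average pooling, the extractor maps a single 128x128 RGB frame to one
+`NUM_FEATURES`-dimensional vector.
+"""
+
+print(feature_extractor.output_shape)  # (None, 1024)
+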
+
+# Label preprocessing with StringLookup.
+label_processor = keras.layers.StringLookup(
+ num_oov_indices=0, vocabulary=np.unique(train_df["tag"]), mask_token=None
+)
+print(label_processor.get_vocabulary())
+
+
+def prepare_all_videos(df, root_dir):
+ num_samples = len(df)
+ video_paths = df["video_name"].values.tolist()
+ labels = df["tag"].values
+ labels = label_processor(labels[..., None]).numpy()
+
+ # `frame_features` are what we will feed to our sequence model.
+ frame_features = np.zeros(
+ shape=(num_samples, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
+ )
+
+ # For each video.
+ for idx, path in enumerate(video_paths):
+ # Gather all its frames and add a batch dimension.
+ frames = load_video(os.path.join(root_dir, path))
+
+ # Pad shorter videos.
+ if len(frames) < MAX_SEQ_LENGTH:
+ diff = MAX_SEQ_LENGTH - len(frames)
+ padding = np.zeros((diff, IMG_SIZE, IMG_SIZE, 3))
+            frames = np.concatenate((frames, padding))
+
+ frames = frames[None, ...]
+
+ # Initialize placeholder to store the features of the current video.
+ temp_frame_features = np.zeros(
+ shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32"
+ )
+
+ # Extract features from the frames of the current video.
+ for i, batch in enumerate(frames):
+ video_length = batch.shape[0]
+ length = min(MAX_SEQ_LENGTH, video_length)
+ for j in range(length):
+ if np.mean(batch[j, :]) > 0.0:
+ temp_frame_features[i, j, :] = feature_extractor.predict(
+ batch[None, j, :]
+ )
+ else:
+ temp_frame_features[i, j, :] = 0.0
+
+        frame_features[idx, ...] = temp_frame_features.squeeze()
+
+ return frame_features, labels
+
+
+"""
+Calling `prepare_all_videos()` on `train_df` and `test_df` takes ~20 minutes to
+complete. To save time, we instead download already-preprocessed NumPy arrays:
+"""
+
+"""shell
+wget -q https://git.io/JZmf4 -O top5_data_prepared.tar.gz
+tar -xf top5_data_prepared.tar.gz
+"""
+
+train_data, train_labels = np.load("train_data.npy"), np.load("train_labels.npy")
+test_data, test_labels = np.load("test_data.npy"), np.load("test_labels.npy")
+
+print(f"Frame features in train set: {train_data.shape}")
+
+"""
+## Building the Transformer-based model
+
+We will be building on top of the code shared in
+[this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11) of
+[Deep Learning with Python (Second ed.)](https://www.manning.com/books/deep-learning-with-python)
+by François Chollet.
+
+First, self-attention layers that form the basic blocks of a Transformer are
+order-agnostic. Since videos are ordered sequences of frames, we need our
+Transformer model to take into account order information.
+We do this via **positional encoding**.
+We simply embed the positions of the frames present inside videos with an
+[`Embedding` layer](https://keras.io/api/layers/core_layers/embedding). We then
+add these positional embeddings to the precomputed CNN feature maps.
+"""
+
+
+class PositionalEmbedding(layers.Layer):
+ def __init__(self, sequence_length, output_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.position_embeddings = layers.Embedding(
+ input_dim=sequence_length, output_dim=output_dim
+ )
+ self.sequence_length = sequence_length
+ self.output_dim = output_dim
+
+ def build(self, input_shape):
+ self.position_embeddings.build(input_shape)
+
+ def call(self, inputs):
+ # The inputs are of shape: `(batch_size, frames, num_features)`
+ inputs = keras.ops.cast(inputs, self.compute_dtype)
+ length = keras.ops.shape(inputs)[1]
+ positions = keras.ops.arange(start=0, stop=length, step=1)
+ embedded_positions = self.position_embeddings(positions)
+ return inputs + embedded_positions
+
+
+"""
+Now, we can create a subclassed layer for the Transformer.
+"""
+
+
+class TransformerEncoder(layers.Layer):
+ def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+ self.dense_dim = dense_dim
+ self.num_heads = num_heads
+ self.attention = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.3
+ )
+ self.dense_proj = keras.Sequential(
+ [
+ layers.Dense(dense_dim, activation=keras.activations.gelu),
+ layers.Dense(embed_dim),
+ ]
+ )
+ self.layernorm_1 = layers.LayerNormalization()
+ self.layernorm_2 = layers.LayerNormalization()
+
+ def call(self, inputs, mask=None):
+ attention_output = self.attention(inputs, inputs, attention_mask=mask)
+ proj_input = self.layernorm_1(inputs + attention_output)
+ proj_output = self.dense_proj(proj_input)
+ return self.layernorm_2(proj_input + proj_output)
+
+
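+"""
+Before wiring these layers into a model, a quick shape check helps confirm the
+data flow (illustrative only; the zero tensor below is a stand-in, not data
+from the example):
+"""
+
+# Two dummy "videos", each a sequence of MAX_SEQ_LENGTH feature vectors.
+dummy_batch = keras.ops.zeros((2, MAX_SEQ_LENGTH, NUM_FEATURES))
+demo_tokens = PositionalEmbedding(MAX_SEQ_LENGTH, NUM_FEATURES)(dummy_batch)
+demo_encoded = TransformerEncoder(NUM_FEATURES, dense_dim=4, num_heads=1)(demo_tokens)
+print(demo_encoded.shape)  # (2, 20, 1024): the token layout is preserved.
+
+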
+"""
+## Utility functions for training
+"""
+
+
+def get_compiled_model(shape):
+ sequence_length = MAX_SEQ_LENGTH
+ embed_dim = NUM_FEATURES
+ dense_dim = 4
+ num_heads = 1
+ classes = len(label_processor.get_vocabulary())
+
+ inputs = keras.Input(shape=shape)
+ x = PositionalEmbedding(
+ sequence_length, embed_dim, name="frame_position_embedding"
+ )(inputs)
+ x = TransformerEncoder(embed_dim, dense_dim, num_heads, name="transformer_layer")(x)
+ x = layers.GlobalMaxPooling1D()(x)
+ x = layers.Dropout(0.5)(x)
+ outputs = layers.Dense(classes, activation="softmax")(x)
+ model = keras.Model(inputs, outputs)
+
+ model.compile(
+ optimizer="adam",
+ loss="sparse_categorical_crossentropy",
+ metrics=["accuracy"],
+ )
+ return model
+
+
+def run_experiment():
+ filepath = "/tmp/video_classifier.weights.h5"
+ checkpoint = keras.callbacks.ModelCheckpoint(
+ filepath, save_weights_only=True, save_best_only=True, verbose=1
+ )
+
+ model = get_compiled_model(train_data.shape[1:])
+ history = model.fit(
+ train_data,
+ train_labels,
+ validation_split=0.15,
+ epochs=EPOCHS,
+ callbacks=[checkpoint],
+ )
+
+ model.load_weights(filepath)
+ _, accuracy = model.evaluate(test_data, test_labels)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+
+ return model
+
+
+"""
+## Model training and inference
+"""
+
+trained_model = run_experiment()
+
+"""
+**Note**: This model has ~4.23 million parameters, which is far more than the sequence
+model (99,918 parameters) we used in the prequel to this example. This kind of
+Transformer model works best with a larger dataset and a longer pre-training schedule.
+"""
+
+
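+"""
+You can verify this directly; `count_params()` is a standard Keras `Model`
+method (a quick check, not part of the original walkthrough):
+"""
+
+print(f"Total parameters: {trained_model.count_params():,}")
+
+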
+def prepare_single_video(frames):
+ frame_features = np.zeros(shape=(1, MAX_SEQ_LENGTH, NUM_FEATURES), dtype="float32")
+
+ # Pad shorter videos.
+ if len(frames) < MAX_SEQ_LENGTH:
+ diff = MAX_SEQ_LENGTH - len(frames)
+ padding = np.zeros((diff, IMG_SIZE, IMG_SIZE, 3))
+        frames = np.concatenate((frames, padding))
+
+ frames = frames[None, ...]
+
+ # Extract features from the frames of the current video.
+ for i, batch in enumerate(frames):
+ video_length = batch.shape[0]
+ length = min(MAX_SEQ_LENGTH, video_length)
+ for j in range(length):
+ if np.mean(batch[j, :]) > 0.0:
+ frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
+ else:
+ frame_features[i, j, :] = 0.0
+
+ return frame_features
+
+
+def predict_action(path):
+ class_vocab = label_processor.get_vocabulary()
+
+ frames = load_video(os.path.join("test", path), offload_to_cpu=True)
+ frame_features = prepare_single_video(frames)
+ probabilities = trained_model.predict(frame_features)[0]
+
+ plot_x_axis, plot_y_axis = [], []
+
+ for i in np.argsort(probabilities)[::-1]:
+ plot_x_axis.append(class_vocab[i])
+ plot_y_axis.append(probabilities[i])
+ print(f" {class_vocab[i]}: {probabilities[i] * 100:5.2f}%")
+
+ plt.bar(plot_x_axis, plot_y_axis, label=plot_x_axis)
+ plt.xlabel("class_label")
+ plt.xlabel("Probability")
+ plt.show()
+
+ return frames
+
+
+# This utility is for visualization.
+# Referenced from:
+# https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub
+def to_gif(images):
+ converted_images = images.astype(np.uint8)
+ imageio.mimsave("animation.gif", converted_images, fps=10)
+ return embed.embed_file("animation.gif")
+
+
+test_video = np.random.choice(test_df["video_name"].values.tolist())
+print(f"Test video path: {test_video}")
+test_frames = predict_action(test_video)
+to_gif(test_frames[:MAX_SEQ_LENGTH])
+
+"""
+The performance of our model is far from optimal, because it was trained on a
+small dataset.
+"""
diff --git a/knowledge_base/vision/visualizing_what_convnets_learn.py b/knowledge_base/vision/visualizing_what_convnets_learn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d855ec84e8a00dd1b4e9d50147437227694a5f58
--- /dev/null
+++ b/knowledge_base/vision/visualizing_what_convnets_learn.py
@@ -0,0 +1,204 @@
+"""
+Title: Visualizing what convnets learn
+Author: [fchollet](https://twitter.com/fchollet)
+Date created: 2020/05/29
+Last modified: 2020/05/29
+Description: Displaying the visual patterns that convnet filters respond to.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+In this example, we look into what sort of visual patterns image classification models
+learn. We'll be using the `ResNet50V2` model, trained on the ImageNet dataset.
+
+Our process is simple: we will create input images that maximize the activation of
+specific filters in a target layer (picked somewhere in the middle of the model: layer
+`conv3_block4_out`). Such images represent a visualization of the
+pattern that the filter responds to.
+"""
+
+"""
+## Setup
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import numpy as np
+import tensorflow as tf
+
+# The dimensions of our input image
+img_width = 180
+img_height = 180
+# Our target layer: we will visualize the filters from this layer.
+# See `model.summary()` for list of layer names, if you want to change this.
+layer_name = "conv3_block4_out"
+
+"""
+## Build a feature extraction model
+"""
+
+# Build a ResNet50V2 model loaded with pre-trained ImageNet weights
+model = keras.applications.ResNet50V2(weights="imagenet", include_top=False)
+
+# Set up a model that returns the activation values for our target layer
+layer = model.get_layer(name=layer_name)
+feature_extractor = keras.Model(inputs=model.inputs, outputs=layer.output)
+
+"""
+## Set up the gradient ascent process
+
+The "loss" we will maximize is simply the mean of the activation of a specific filter in
+our target layer. To avoid border effects, we exclude border pixels.
+"""
+
+
+def compute_loss(input_image, filter_index):
+ activation = feature_extractor(input_image)
+ # We avoid border artifacts by only involving non-border pixels in the loss.
+ filter_activation = activation[:, 2:-2, 2:-2, filter_index]
+ return tf.reduce_mean(filter_activation)
+
+
+"""
+Our gradient ascent function simply computes the gradients of the loss above
+with regard to the input image, and updates the input image so as to move it
+towards a state that will activate the target filter more strongly.
+"""
+
+
+@tf.function
+def gradient_ascent_step(img, filter_index, learning_rate):
+ with tf.GradientTape() as tape:
+ tape.watch(img)
+ loss = compute_loss(img, filter_index)
+ # Compute gradients.
+ grads = tape.gradient(loss, img)
+ # Normalize gradients.
+ grads = tf.math.l2_normalize(grads)
+ img += learning_rate * grads
+ return loss, img
+
+
+"""
+## Set up the end-to-end filter visualization loop
+
+Our process is as follows:
+
+- Start from a random image that is close to "all gray" (i.e. visually neutral)
+- Repeatedly apply the gradient ascent step function defined above
+- Convert the resulting input image back to a displayable form, by normalizing it,
+center-cropping it, and restricting it to the [0, 255] range.
+"""
+
+
+def initialize_image():
+ # We start from a gray image with some random noise
+ img = tf.random.uniform((1, img_width, img_height, 3))
+ # ResNet50V2 expects inputs in the range [-1, +1].
+ # Here we scale our random inputs to [-0.125, +0.125]
+ return (img - 0.5) * 0.25
+
+
+def visualize_filter(filter_index):
+    # We run gradient ascent for 30 steps
+ iterations = 30
+ learning_rate = 10.0
+ img = initialize_image()
+ for iteration in range(iterations):
+ loss, img = gradient_ascent_step(img, filter_index, learning_rate)
+
+ # Decode the resulting input image
+ img = deprocess_image(img[0].numpy())
+ return loss, img
+
+
+def deprocess_image(img):
+    # Normalize array: center on 0., ensure standard deviation is 0.15
+ img -= img.mean()
+ img /= img.std() + 1e-5
+ img *= 0.15
+
+ # Center crop
+ img = img[25:-25, 25:-25, :]
+
+ # Clip to [0, 1]
+ img += 0.5
+ img = np.clip(img, 0, 1)
+
+ # Convert to RGB array
+ img *= 255
+ img = np.clip(img, 0, 255).astype("uint8")
+ return img
+
+
+"""
+Let's try it out with filter 0 in the target layer:
+"""
+
+from IPython.display import Image, display
+
+loss, img = visualize_filter(0)
+keras.utils.save_img("0.png", img)
+
+"""
+This is what an input that maximizes the response of filter 0 in the target layer would
+look like:
+"""
+
+display(Image("0.png"))
+
+"""
+## Visualize the first 64 filters in the target layer
+
+Now, let's make an 8x8 grid of the first 64 filters
+in the target layer to get a feel for the range
+of different visual patterns that the model has learned.
+"""
+
+# Compute image inputs that maximize per-filter activations
+# for the first 64 filters of our target layer
+all_imgs = []
+for filter_index in range(64):
+ print("Processing filter %d" % (filter_index,))
+ loss, img = visualize_filter(filter_index)
+ all_imgs.append(img)
+
+# Build a black picture with enough space for
+# our 8 x 8 filters of size 130 x 130, with a 5px margin in between
+margin = 5
+n = 8
+cropped_width = img_width - 25 * 2
+cropped_height = img_height - 25 * 2
+width = n * cropped_width + (n - 1) * margin
+height = n * cropped_height + (n - 1) * margin
+stitched_filters = np.zeros((width, height, 3))
+
+# Fill the picture with our saved filters
+for i in range(n):
+ for j in range(n):
+ img = all_imgs[i * n + j]
+ stitched_filters[
+ (cropped_width + margin) * i : (cropped_width + margin) * i + cropped_width,
+ (cropped_height + margin) * j : (cropped_height + margin) * j
+ + cropped_height,
+ :,
+ ] = img
+keras.utils.save_img("stiched_filters.png", stitched_filters)
+
+from IPython.display import Image, display
+
+display(Image("stiched_filters.png"))
+
+"""
+Image classification models see the world by decomposing their inputs over a "vector
+basis" of texture filters such as these.
+
+See also
+[this old blog post](https://blog.keras.io/how-convolutional-neural-networks-see-the-world.html)
+for analysis and interpretation.
+"""
diff --git a/knowledge_base/vision/vit_small_ds.py b/knowledge_base/vision/vit_small_ds.py
new file mode 100644
index 0000000000000000000000000000000000000000..525967fca619e09dc2867dd46ba437942917ec94
--- /dev/null
+++ b/knowledge_base/vision/vit_small_ds.py
@@ -0,0 +1,552 @@
+"""
+Title: Train a Vision Transformer on small datasets
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
+Date created: 2022/01/07
+Last modified: 2024/11/27
+Description: Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention.
+Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
+"""
+
+"""
+## Introduction
+
+In the academic paper
+[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929),
+the authors mention that Vision Transformers (ViT) are data-hungry. Therefore,
+pretraining a ViT on a large-sized dataset like JFT300M and fine-tuning
+it on medium-sized datasets (like ImageNet) is the only way to beat
+state-of-the-art Convolutional Neural Network models.
+
+The self-attention layer of ViT lacks **locality inductive bias** (the notion that
+image pixels are locally correlated and that their correlation maps are translation-invariant).
+This is the reason why ViTs need more data. On the other hand, CNNs look at images through
+spatial sliding windows, which helps them get better results with smaller datasets.
+
+In the academic paper
+[Vision Transformer for Small-Size Datasets](https://arxiv.org/abs/2112.13492v1),
+the authors set out to tackle the problem of locality inductive bias in ViTs.
+
+The main ideas are:
+
+- **Shifted Patch Tokenization**
+- **Locality Self Attention**
+
+This example implements the ideas of the paper. A large part of this
+example is inspired by
+[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
+
+_Note_: This example requires TensorFlow 2.6 or higher.
+"""
+"""
+## Setup
+"""
+import math
+import numpy as np
+import keras
+from keras import ops
+from keras import layers
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+# Setting seed for reproducibility
+SEED = 42
+keras.utils.set_random_seed(SEED)
+
+"""
+## Prepare the data
+"""
+
+NUM_CLASSES = 100
+INPUT_SHAPE = (32, 32, 3)
+
+(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
+
+print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
+print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
+
+"""
+## Configure the hyperparameters
+
+The hyperparameters are different from the paper. Feel free to tune
+the hyperparameters yourself.
+"""
+
+# DATA
+BUFFER_SIZE = 512
+BATCH_SIZE = 256
+
+# AUGMENTATION
+IMAGE_SIZE = 72
+PATCH_SIZE = 6
+NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
+
+# OPTIMIZER
+LEARNING_RATE = 0.001
+WEIGHT_DECAY = 0.0001
+
+# TRAINING
+EPOCHS = 50
+
+# ARCHITECTURE
+LAYER_NORM_EPS = 1e-6
+TRANSFORMER_LAYERS = 8
+PROJECTION_DIM = 64
+NUM_HEADS = 4
+TRANSFORMER_UNITS = [
+ PROJECTION_DIM * 2,
+ PROJECTION_DIM,
+]
+MLP_HEAD_UNITS = [2048, 1024]
+
+"""
+## Use data augmentation
+
+A snippet from the paper:
+
+*"According to DeiT, various techniques are required to effectively
+train ViTs. Thus, we applied data augmentations such as CutMix, Mixup,
+Auto Augment, Repeated Augment to all models."*
+
+In this example, we will focus solely on the novelty of the approach
+and not on reproducing the paper results. For this reason, we
+don't use the mentioned data augmentation schemes. Please feel
+free to add to or remove from the augmentation pipeline.
+"""
+
+data_augmentation = keras.Sequential(
+ [
+ layers.Normalization(),
+ layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
+ layers.RandomFlip("horizontal"),
+ layers.RandomRotation(factor=0.02),
+ layers.RandomZoom(height_factor=0.2, width_factor=0.2),
+ ],
+ name="data_augmentation",
+)
+# Compute the mean and the variance of the training data for normalization.
+data_augmentation.layers[0].adapt(x_train)
+
+"""
+## Implement Shifted Patch Tokenization
+
+In a ViT pipeline, the input images are divided into patches that are
+then linearly projected into tokens. Shifted Patch Tokenization (SPT)
+is introduced to combat the low receptive field of ViTs. The steps
+for Shifted Patch Tokenization are as follows:
+
+- Start with an image.
+- Shift the image in diagonal directions.
+- Concat the diagonally shifted images with the original image.
+- Extract patches of the concatenated images.
+- Flatten the spatial dimension of all patches.
+- Layer normalize the flattened patches and then project them.
+
+|  |
+| :--: |
+| Shifted Patch Tokenization [Source](https://arxiv.org/abs/2112.13492v1) |
+"""
+
+
+class ShiftedPatchTokenization(layers.Layer):
+ def __init__(
+ self,
+ image_size=IMAGE_SIZE,
+ patch_size=PATCH_SIZE,
+ num_patches=NUM_PATCHES,
+ projection_dim=PROJECTION_DIM,
+ vanilla=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+        self.vanilla = vanilla  # Flag to switch to vanilla patch extractor
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.half_patch = patch_size // 2
+ self.flatten_patches = layers.Reshape((num_patches, -1))
+ self.projection = layers.Dense(units=projection_dim)
+ self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
+
+ def crop_shift_pad(self, images, mode):
+ # Build the diagonally shifted images
+ if mode == "left-up":
+ crop_height = self.half_patch
+ crop_width = self.half_patch
+ shift_height = 0
+ shift_width = 0
+ elif mode == "left-down":
+ crop_height = 0
+ crop_width = self.half_patch
+ shift_height = self.half_patch
+ shift_width = 0
+ elif mode == "right-up":
+ crop_height = self.half_patch
+ crop_width = 0
+ shift_height = 0
+ shift_width = self.half_patch
+ else:
+ crop_height = 0
+ crop_width = 0
+ shift_height = self.half_patch
+ shift_width = self.half_patch
+
+ # Crop the shifted images and pad them
+ crop = ops.image.crop_images(
+ images,
+ top_cropping=crop_height,
+ left_cropping=crop_width,
+ target_height=self.image_size - self.half_patch,
+ target_width=self.image_size - self.half_patch,
+ )
+ shift_pad = ops.image.pad_images(
+ crop,
+ top_padding=shift_height,
+ left_padding=shift_width,
+ target_height=self.image_size,
+ target_width=self.image_size,
+ )
+ return shift_pad
+
+ def call(self, images):
+ if not self.vanilla:
+ # Concat the shifted images with the original image
+ images = ops.concatenate(
+ [
+ images,
+ self.crop_shift_pad(images, mode="left-up"),
+ self.crop_shift_pad(images, mode="left-down"),
+ self.crop_shift_pad(images, mode="right-up"),
+ self.crop_shift_pad(images, mode="right-down"),
+ ],
+ axis=-1,
+ )
+ # Patchify the images and flatten it
+ patches = ops.image.extract_patches(
+ images=images,
+ size=(self.patch_size, self.patch_size),
+ strides=[1, self.patch_size, self.patch_size, 1],
+ dilation_rate=1,
+ padding="VALID",
+ )
+ flat_patches = self.flatten_patches(patches)
+ if not self.vanilla:
+ # Layer normalize the flat patches and linearly project it
+ tokens = self.layer_norm(flat_patches)
+ tokens = self.projection(tokens)
+ else:
+ # Linearly project the flat patches
+ tokens = self.projection(flat_patches)
+ return (tokens, patches)
+
+
+"""
+### Visualize the patches
+"""
+
+# Get a random image from the training dataset
+# and resize the image
+image = x_train[np.random.choice(range(x_train.shape[0]))]
+resized_image = ops.cast(
+ ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
+ dtype="float32",
+)
+
+# Vanilla patch maker: This takes an image and divides it into
+# patches as in the original ViT paper
+(token, patch) = ShiftedPatchTokenization(vanilla=True)(resized_image / 255.0)
+(token, patch) = (token[0], patch[0])
+n = patch.shape[0]
+count = 1
+plt.figure(figsize=(4, 4))
+for row in range(n):
+ for col in range(n):
+ plt.subplot(n, n, count)
+ count = count + 1
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
+ plt.imshow(image)
+ plt.axis("off")
+plt.show()
+
+# Shifted Patch Tokenization: This layer takes the image, shifts it
+# diagonally and then extracts patches from the concatenated images
+(token, patch) = ShiftedPatchTokenization(vanilla=False)(resized_image / 255.0)
+(token, patch) = (token[0], patch[0])
+n = patch.shape[0]
+shifted_images = ["ORIGINAL", "LEFT-UP", "LEFT-DOWN", "RIGHT-UP", "RIGHT-DOWN"]
+for index, name in enumerate(shifted_images):
+ print(name)
+ count = 1
+ plt.figure(figsize=(4, 4))
+ for row in range(n):
+ for col in range(n):
+ plt.subplot(n, n, count)
+ count = count + 1
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
+ plt.imshow(image[..., 3 * index : 3 * index + 3])
+ plt.axis("off")
+ plt.show()
+
+"""
+## Implement the patch encoding layer
+
+This layer accepts projected patches and then adds positional
+information to them.
+"""
+
+
+class PatchEncoder(layers.Layer):
+ def __init__(
+ self, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM, **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.num_patches = num_patches
+ self.position_embedding = layers.Embedding(
+ input_dim=num_patches, output_dim=projection_dim
+ )
+ self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
+
+ def call(self, encoded_patches):
+ encoded_positions = self.position_embedding(self.positions)
+ encoded_patches = encoded_patches + encoded_positions
+ return encoded_patches
+
+
+"""
+## Implement Locality Self Attention
+
+The regular attention equation is stated below.
+
+|  |
+| :--: |
+| [Source](https://towardsdatascience.com/attention-is-all-you-need-discovering-the-transformer-paper-73e5ff5e0634) |
+
+The attention module takes a query, key, and value. First, we compute the
+similarity between the query and key via a dot product. Then, the result
+is scaled by the square root of the key dimension. The scaling prevents
+the softmax function from having an overly small gradient. Softmax is then
+applied to the scaled dot product to produce the attention weights.
+The value is then modulated via the attention weights.
+
+In self-attention, query, key and value come from the same input.
+The dot product would result in large self-token relations rather than
+inter-token relations. This also means that the softmax gives higher
+probabilities to self-token relations than to inter-token relations.
+To combat this, the authors propose masking the diagonal of the dot product.
+This way, we force the attention module to pay more attention to the
+inter-token relations.
+
+The scaling factor is a constant in the regular attention module.
+This acts like a temperature term that can modulate the softmax function.
+The authors suggest a learnable temperature term instead of a constant.
+
+|  |
+| :--: |
+| Locality Self Attention [Source](https://arxiv.org/abs/2112.13492v1) |
+
+Together, these two changes make up Locality Self Attention. We have subclassed the
+[`layers.MultiHeadAttention`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention)
+and implemented the trainable temperature. The attention mask is built
+at a later stage.
+"""
+
+
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # The trainable temperature term. The initial value is
+ # the square root of the key dimension.
+ self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)
+
+ def _compute_attention(self, query, key, value, attention_mask=None, training=None):
+ query = ops.multiply(query, 1.0 / self.tau)
+ attention_scores = ops.einsum(self._dot_product_equation, key, query)
+ attention_scores = self._masked_softmax(attention_scores, attention_mask)
+ attention_scores_dropout = self._dropout_layer(
+ attention_scores, training=training
+ )
+ attention_output = ops.einsum(
+ self._combine_equation, attention_scores_dropout, value
+ )
+ return attention_output, attention_scores
+
+
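+"""
+A quick check of the temperature initialization (illustrative only): with
+`PROJECTION_DIM = 64` as the key dimension, `tau` starts at `sqrt(64) = 8`.
+"""
+
+demo_lsa = MultiHeadAttentionLSA(num_heads=NUM_HEADS, key_dim=PROJECTION_DIM)
+print(ops.convert_to_numpy(demo_lsa.tau))  # 8.0
+
+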
+"""
+## Implement the MLP
+"""
+
+
+def mlp(x, hidden_units, dropout_rate):
+ for units in hidden_units:
+ x = layers.Dense(units, activation="gelu")(x)
+ x = layers.Dropout(dropout_rate)(x)
+ return x
+
+
+# Build the diagonal attention mask
+diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
+diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")
+
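+"""
+For intuition (an illustrative check, not needed for training), the same
+construction on a toy 4-token sequence zeroes the diagonal, forcing every token
+to attend only to the *other* tokens:
+"""
+
+print(ops.convert_to_numpy(1 - ops.eye(4)))
+# [[0. 1. 1. 1.]
+#  [1. 0. 1. 1.]
+#  [1. 1. 0. 1.]
+#  [1. 1. 1. 0.]]
+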
+"""
+## Build the ViT
+"""
+
+
+def create_vit_classifier(vanilla=False):
+ inputs = layers.Input(shape=INPUT_SHAPE)
+ # Augment data.
+ augmented = data_augmentation(inputs)
+ # Create patches.
+ (tokens, _) = ShiftedPatchTokenization(vanilla=vanilla)(augmented)
+ # Encode patches.
+ encoded_patches = PatchEncoder()(tokens)
+
+ # Create multiple layers of the Transformer block.
+ for _ in range(TRANSFORMER_LAYERS):
+ # Layer normalization 1.
+ x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ # Create a multi-head attention layer.
+ if not vanilla:
+ attention_output = MultiHeadAttentionLSA(
+ num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
+ )(x1, x1, attention_mask=diag_attn_mask)
+ else:
+ attention_output = layers.MultiHeadAttention(
+ num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
+ )(x1, x1)
+ # Skip connection 1.
+ x2 = layers.Add()([attention_output, encoded_patches])
+ # Layer normalization 2.
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ # MLP.
+ x3 = mlp(x3, hidden_units=TRANSFORMER_UNITS, dropout_rate=0.1)
+ # Skip connection 2.
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Create a [batch_size, projection_dim] tensor.
+ representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ representation = layers.Flatten()(representation)
+ representation = layers.Dropout(0.5)(representation)
+ # Add MLP.
+ features = mlp(representation, hidden_units=MLP_HEAD_UNITS, dropout_rate=0.5)
+ # Classify outputs.
+ logits = layers.Dense(NUM_CLASSES)(features)
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=logits)
+ return model
+
+
+"""
+## Compile, train, and evaluate the model
+"""
+
+
+# Some code is taken from:
+# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
+class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
+ def __init__(
+ self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
+ ):
+ super().__init__()
+
+ self.learning_rate_base = learning_rate_base
+ self.total_steps = total_steps
+ self.warmup_learning_rate = warmup_learning_rate
+ self.warmup_steps = warmup_steps
+ self.pi = ops.array(np.pi)
+
+ def __call__(self, step):
+ if self.total_steps < self.warmup_steps:
+ raise ValueError("Total_steps must be larger or equal to warmup_steps.")
+
+ cos_annealed_lr = ops.cos(
+ self.pi
+ * (ops.cast(step, dtype="float32") - self.warmup_steps)
+ / float(self.total_steps - self.warmup_steps)
+ )
+ learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
+
+ if self.warmup_steps > 0:
+ if self.learning_rate_base < self.warmup_learning_rate:
+ raise ValueError(
+ "Learning_rate_base must be larger or equal to "
+ "warmup_learning_rate."
+ )
+ slope = (
+ self.learning_rate_base - self.warmup_learning_rate
+ ) / self.warmup_steps
+ warmup_rate = (
+ slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate
+ )
+ learning_rate = ops.where(
+ step < self.warmup_steps, warmup_rate, learning_rate
+ )
+        # Note: `keras.ops.where` takes no `name` argument (unlike `tf.where`).
+        return ops.where(step > self.total_steps, 0.0, learning_rate)
+
+
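+"""
+It can help to sanity-check the schedule before training. This snippet is
+illustrative (not part of the original example); the toy step counts are
+arbitrary:
+"""
+
+demo_schedule = WarmUpCosine(
+    learning_rate_base=LEARNING_RATE,
+    total_steps=1000,
+    warmup_learning_rate=0.0,
+    warmup_steps=100,
+)
+for demo_step in [0, 50, 100, 500, 1000]:
+    print(demo_step, float(demo_schedule(demo_step)))
+
+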
+def run_experiment(model):
+ total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
+ warmup_epoch_percentage = 0.10
+ warmup_steps = int(total_steps * warmup_epoch_percentage)
+ scheduled_lrs = WarmUpCosine(
+ learning_rate_base=LEARNING_RATE,
+ total_steps=total_steps,
+ warmup_learning_rate=0.0,
+ warmup_steps=warmup_steps,
+ )
+
+ optimizer = keras.optimizers.AdamW(
+        learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY
+ )
+
+ model.compile(
+ optimizer=optimizer,
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ history = model.fit(
+ x=x_train,
+ y=y_train,
+ batch_size=BATCH_SIZE,
+ epochs=EPOCHS,
+ validation_split=0.1,
+ )
+ _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ return history
+
+
+# Run experiments with the vanilla ViT
+vit = create_vit_classifier(vanilla=True)
+history = run_experiment(vit)
+
+# Run experiments with the Shifted Patch Tokenization and
+# Locality Self Attention modified ViT
+vit_sl = create_vit_classifier(vanilla=False)
+history = run_experiment(vit_sl)
+
+"""
+# Final Notes
+
+With the help of Shifted Patch Tokenization and Locality Self Attention,
+we were able to get ~**3-4%** top-1 accuracy gains on CIFAR100.
+
+The ideas behind Shifted Patch Tokenization and Locality Self Attention
+are very intuitive and easy to implement. The authors also ablate
+different shifting strategies for Shifted Patch Tokenization in the
+supplementary material of the paper.
+
+I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for
+generously helping with GPU credits.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds).
+"""
diff --git a/knowledge_base/vision/vivit.py b/knowledge_base/vision/vivit.py
new file mode 100644
index 0000000000000000000000000000000000000000..116bddf14229548c31e897f1c18c408c3d81f868
--- /dev/null
+++ b/knowledge_base/vision/vivit.py
@@ -0,0 +1,447 @@
+"""
+Title: Video Vision Transformer
+Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ayush Thakur](https://twitter.com/ayushthakur0) (equal contribution)
+Date created: 2022/01/12
+Last modified: 2024/01/15
+Description: A Transformer-based architecture for video classification.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+Videos are sequences of images. Let's assume you have an image
+representation model (CNN, ViT, etc.) and a sequence model
+(RNN, LSTM, etc.) at hand. We ask you to tweak the model for video
+classification. The simplest approach would be to apply the image
+model to individual frames, use the sequence model to learn
+sequences of image features, then apply a classification head on
+the learned sequence representation.
+The Keras example
+[Video Classification with a CNN-RNN Architecture](https://keras.io/examples/vision/video_classification/)
+explains this approach in detail. Alternatively, you can also
+build a hybrid Transformer-based model for video classification as shown in the Keras example
+[Video Classification with Transformers](https://keras.io/examples/vision/video_transformers/).
+
+In this example, we minimally implement
+[ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
+by Arnab et al., a **pure Transformer-based** model
+for video classification. The authors propose a novel embedding scheme
+and a number of Transformer variants to model video clips. We implement
+the embedding scheme and one of the variants of the Transformer
+architecture, for simplicity.
+
+This example requires the `medmnist` package, which can be installed
+by running the code cell below.
+"""
+
+"""shell
+pip install -qq medmnist
+"""
+
+"""
+## Imports
+"""
+
+import os
+import io
+import imageio
+import medmnist
+import ipywidgets
+import numpy as np
+import tensorflow as tf # for data preprocessing only
+import keras
+from keras import layers, ops
+
+# Setting seed for reproducibility
+SEED = 42
+os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
+keras.utils.set_random_seed(SEED)
+
+"""
+## Hyperparameters
+
+The hyperparameters were chosen via hyperparameter
+search. You can learn more about the process in the "Final thoughts" section.
+"""
+
+# DATA
+DATASET_NAME = "organmnist3d"
+BATCH_SIZE = 32
+AUTO = tf.data.AUTOTUNE
+INPUT_SHAPE = (28, 28, 28, 1)
+NUM_CLASSES = 11
+
+# OPTIMIZER
+LEARNING_RATE = 1e-4
+WEIGHT_DECAY = 1e-5
+
+# TRAINING
+EPOCHS = 60
+
+# TUBELET EMBEDDING
+PATCH_SIZE = (8, 8, 8)
+NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2
+
+# ViViT ARCHITECTURE
+LAYER_NORM_EPS = 1e-6
+PROJECTION_DIM = 128
+NUM_HEADS = 8
+NUM_LAYERS = 8
+
+"""
+## Dataset
+
+For our example we use the
+[MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification](https://medmnist.com/)
+dataset. The videos are lightweight and easy to train on.
+"""
+
+
+def download_and_prepare_dataset(data_info: dict):
+ """Utility function to download the dataset.
+
+ Arguments:
+ data_info (dict): Dataset metadata.
+ """
+ data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])
+
+ with np.load(data_path) as data:
+ # Get videos
+ train_videos = data["train_images"]
+ valid_videos = data["val_images"]
+ test_videos = data["test_images"]
+
+ # Get labels
+ train_labels = data["train_labels"].flatten()
+ valid_labels = data["val_labels"].flatten()
+ test_labels = data["test_labels"].flatten()
+
+ return (
+ (train_videos, train_labels),
+ (valid_videos, valid_labels),
+ (test_videos, test_labels),
+ )
+
+
+# Get the metadata of the dataset
+info = medmnist.INFO[DATASET_NAME]
+
+# Get the dataset
+prepared_dataset = download_and_prepare_dataset(info)
+(train_videos, train_labels) = prepared_dataset[0]
+(valid_videos, valid_labels) = prepared_dataset[1]
+(test_videos, test_labels) = prepared_dataset[2]
+
+"""
+### `tf.data` pipeline
+"""
+
+
+def preprocess(frames: tf.Tensor, label: tf.Tensor):
+ """Preprocess the frames tensors and parse the labels."""
+ # Preprocess images
+ frames = tf.image.convert_image_dtype(
+ frames[
+ ..., tf.newaxis
+ ], # The new axis is to help for further processing with Conv3D layers
+ tf.float32,
+ )
+ # Parse label
+ label = tf.cast(label, tf.float32)
+ return frames, label
+
+
+def prepare_dataloader(
+ videos: np.ndarray,
+ labels: np.ndarray,
+ loader_type: str = "train",
+ batch_size: int = BATCH_SIZE,
+):
+ """Utility function to prepare the dataloader."""
+ dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
+
+ if loader_type == "train":
+ dataset = dataset.shuffle(BATCH_SIZE * 2)
+
+ dataloader = (
+ dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
+ .batch(batch_size)
+ .prefetch(tf.data.AUTOTUNE)
+ )
+ return dataloader
+
+
+trainloader = prepare_dataloader(train_videos, train_labels, "train")
+validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
+testloader = prepare_dataloader(test_videos, test_labels, "test")
+
+"""
+## Tubelet Embedding
+
+In ViTs, an image is divided into patches, which are then spatially
+flattened, a process known as tokenization. For a video, one can
+repeat this process for individual frames. **Uniform frame sampling**
+as suggested by the authors is a tokenization scheme in which we
+sample frames from the video clip and perform simple ViT tokenization.
+
+|  |
+| :--: |
+| Uniform Frame Sampling [Source](https://arxiv.org/abs/2103.15691) |
+
+**Tubelet Embedding** is different in terms of capturing temporal
+information from the video.
+First, we extract volumes from the video -- these volumes contain
+patches of the frame and the temporal information as well. The volumes
+are then flattened to build video tokens.
+
+|  |
+| :--: |
+| Tubelet Embedding [Source](https://arxiv.org/abs/2103.15691) |
+"""
+
+
+class TubeletEmbedding(layers.Layer):
+ def __init__(self, embed_dim, patch_size, **kwargs):
+ super().__init__(**kwargs)
+ self.projection = layers.Conv3D(
+ filters=embed_dim,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="VALID",
+ )
+ self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
+
+ def call(self, videos):
+ projected_patches = self.projection(videos)
+ flattened_patches = self.flatten(projected_patches)
+ return flattened_patches
+
+
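+"""
+A quick shape check (illustrative only; the zero tensor is a stand-in): with
+28x28x28 input volumes and 8x8x8 tubelets, each video becomes a sequence of
+`(28 // 8) ** 3 = 27` projected tokens.
+"""
+
+demo_videos = ops.zeros((1,) + INPUT_SHAPE)
+demo_embedding = TubeletEmbedding(embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE)
+print(demo_embedding(demo_videos).shape)  # (1, 27, 128)
+
+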
+"""
+## Positional Embedding
+
+This layer adds positional information to the encoded video tokens.
+"""
+
+
+class PositionalEncoder(layers.Layer):
+ def __init__(self, embed_dim, **kwargs):
+ super().__init__(**kwargs)
+ self.embed_dim = embed_dim
+
+ def build(self, input_shape):
+ _, num_tokens, _ = input_shape
+ self.position_embedding = layers.Embedding(
+ input_dim=num_tokens, output_dim=self.embed_dim
+ )
+ self.positions = ops.arange(0, num_tokens, 1)
+
+ def call(self, encoded_tokens):
+ # Encode the positions and add it to the encoded tokens
+ encoded_positions = self.position_embedding(self.positions)
+ encoded_tokens = encoded_tokens + encoded_positions
+ return encoded_tokens
+
+
+"""
+## Video Vision Transformer
+
+The authors suggest 4 variants of Vision Transformer:
+
+- Spatio-temporal attention
+- Factorized encoder
+- Factorized self-attention
+- Factorized dot-product attention
+
+In this example, we will implement the **Spatio-temporal attention**
+model for simplicity. The following code snippet is heavily inspired by
+[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
+One can also refer to the
+[official repository of ViViT](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit)
+which contains all the variants, implemented in JAX.
+"""
+
+
+def create_vivit_classifier(
+ tubelet_embedder,
+ positional_encoder,
+ input_shape=INPUT_SHAPE,
+ transformer_layers=NUM_LAYERS,
+ num_heads=NUM_HEADS,
+ embed_dim=PROJECTION_DIM,
+ layer_norm_eps=LAYER_NORM_EPS,
+ num_classes=NUM_CLASSES,
+):
+ # Get the input layer
+ inputs = layers.Input(shape=input_shape)
+ # Create patches.
+ patches = tubelet_embedder(inputs)
+ # Encode patches.
+ encoded_patches = positional_encoder(patches)
+
+ # Create multiple layers of the Transformer block.
+ for _ in range(transformer_layers):
+ # Layer normalization and MHSA
+ x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
+ attention_output = layers.MultiHeadAttention(
+ num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
+ )(x1, x1)
+
+ # Skip connection
+ x2 = layers.Add()([attention_output, encoded_patches])
+
+ # Layer Normalization and MLP
+ x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
+ x3 = keras.Sequential(
+ [
+ layers.Dense(units=embed_dim * 4, activation=ops.gelu),
+ layers.Dense(units=embed_dim, activation=ops.gelu),
+ ]
+ )(x3)
+
+ # Skip connection
+ encoded_patches = layers.Add()([x3, x2])
+
+ # Layer normalization and Global average pooling.
+ representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
+ representation = layers.GlobalAvgPool1D()(representation)
+
+ # Classify outputs.
+ outputs = layers.Dense(units=num_classes, activation="softmax")(representation)
+
+ # Create the Keras model.
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+## Train
+"""
+
+
+def run_experiment():
+ # Initialize model
+ model = create_vivit_classifier(
+ tubelet_embedder=TubeletEmbedding(
+ embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
+ ),
+ positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
+ )
+
+ # Compile the model with the optimizer, loss function
+ # and the metrics.
+ optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
+ model.compile(
+ optimizer=optimizer,
+ loss="sparse_categorical_crossentropy",
+ metrics=[
+ keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+ keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
+ ],
+ )
+
+ # Train the model.
+ _ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)
+
+ _, accuracy, top_5_accuracy = model.evaluate(testloader)
+ print(f"Test accuracy: {round(accuracy * 100, 2)}%")
+ print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
+
+ return model
+
+
+model = run_experiment()
+
+"""
+## Inference
+"""
+
+NUM_SAMPLES_VIZ = 25
+testsamples, labels = next(iter(testloader))
+testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]
+
+ground_truths = []
+preds = []
+videos = []
+
+for i, (testsample, label) in enumerate(zip(testsamples, labels)):
+ # Generate gif
+ testsample = np.reshape(testsample.numpy(), (-1, 28, 28))
+ with io.BytesIO() as gif:
+ imageio.mimsave(gif, (testsample * 255).astype("uint8"), "GIF", fps=5)
+ videos.append(gif.getvalue())
+
+ # Get model prediction
+    # Restore the batch and channel axes so the shape matches the model input.
+    output = model.predict(testsample[None, ..., None])[0]
+ pred = np.argmax(output, axis=0)
+
+ ground_truths.append(label.numpy().astype("int"))
+ preds.append(pred)
+
+
+def make_box_for_grid(image_widget, fit):
+ """Make a VBox to hold caption/image for demonstrating option_fit values.
+
+ Source: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
+ """
+ # Make the caption
+ if fit is not None:
+ fit_str = "'{}'".format(fit)
+ else:
+ fit_str = str(fit)
+
+ h = ipywidgets.HTML(value="" + str(fit_str) + "")
+
+ # Make the green box with the image widget inside it
+ boxb = ipywidgets.widgets.Box()
+ boxb.children = [image_widget]
+
+ # Compose into a vertical box
+ vb = ipywidgets.widgets.VBox()
+ vb.layout.align_items = "center"
+ vb.children = [h, boxb]
+ return vb
+
+
+boxes = []
+for i in range(NUM_SAMPLES_VIZ):
+ ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
+ true_class = info["label"][str(ground_truths[i])]
+ pred_class = info["label"][str(preds[i])]
+ caption = f"T: {true_class} | P: {pred_class}"
+
+ boxes.append(make_box_for_grid(ib, caption))
+
+ipywidgets.widgets.GridBox(
+ boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
+)
+
+"""
+## Final thoughts
+
+With a vanilla implementation, we achieve ~79-80% Top-1 accuracy on the
+test dataset.
+
+The hyperparameters used in this tutorial were finalized by running a
+hyperparameter search using
+[W&B Sweeps](https://docs.wandb.ai/guides/sweeps).
+You can find out our sweeps result
+[here](https://wandb.ai/minimal-implementations/vivit/sweeps/66fp0lhz)
+and our quick analysis of the results
+[here](https://wandb.ai/minimal-implementations/vivit/reports/Hyperparameter-Tuning-Analysis--VmlldzoxNDEwNzcx).
+
+For further improvement, you could look into the following:
+
+- Using data augmentation for videos.
+- Using a better regularization scheme for training.
+- Apply different variants of the transformer model as in the paper.
+
+We would like to thank [Anurag Arnab](https://anuragarnab.github.io/)
+(first author of ViViT) for helpful discussion. We are grateful to
+[Weights and Biases](https://wandb.ai/site) program for helping with
+GPU credits.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/video-vision-transformer)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/video-vision-transformer-CT).
+"""
diff --git a/knowledge_base/vision/xray_classification_with_tpus.py b/knowledge_base/vision/xray_classification_with_tpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9f0d502142fefeca3fe3d43b7aa883dccc1dd68
--- /dev/null
+++ b/knowledge_base/vision/xray_classification_with_tpus.py
@@ -0,0 +1,441 @@
+"""
+Title: Pneumonia Classification on TPU
+Author: Amy MiHyun Jang
+Date created: 2020/07/28
+Last modified: 2020/08/24
+Description: Medical image classification on TPU.
+Accelerator: TPU
+"""
+
+"""
+## Introduction + Set-up
+
+This tutorial will explain how to build an X-ray image classification model
+to predict whether an X-ray scan shows presence of pneumonia.
+"""
+
+import re
+import os
+import random
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+try:
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
+ print("Device:", tpu.master())
+ strategy = tf.distribute.TPUStrategy(tpu)
+except:
+ strategy = tf.distribute.get_strategy()
+print("Number of replicas:", strategy.num_replicas_in_sync)
+
+"""
+To load the data with a TPU, we need our data hosted at a Google Cloud link.
+Below, we define the key configuration parameters we'll use in this example.
+To run on a TPU, this example must be run on Colab with the TPU runtime selected.
+"""
+
+AUTOTUNE = tf.data.AUTOTUNE
+BATCH_SIZE = 25 * strategy.num_replicas_in_sync
+IMAGE_SIZE = [180, 180]
+CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
+
+"""
+## Load the data
+
+The Chest X-ray data we are using from
+[*Cell*](https://www.cell.com/cell/fulltext/S0092-8674(18)30154-5) divides the data into
+training and test files. Let's first load in the training TFRecords.
+"""
+
+train_images = tf.data.TFRecordDataset(
+ "gs://download.tensorflow.org/data/ChestXRay2017/train/images.tfrec"
+)
+train_paths = tf.data.TFRecordDataset(
+ "gs://download.tensorflow.org/data/ChestXRay2017/train/paths.tfrec"
+)
+
+ds = tf.data.Dataset.zip((train_images, train_paths))
+
+"""
+Let's count how many healthy/normal chest X-rays we have and how many
+pneumonia chest X-rays we have:
+"""
+
+COUNT_NORMAL = len(
+ [
+ filename
+ for filename in train_paths
+ if "NORMAL" in filename.numpy().decode("utf-8")
+ ]
+)
+print("Normal images count in training set: " + str(COUNT_NORMAL))
+
+COUNT_PNEUMONIA = len(
+ [
+ filename
+ for filename in train_paths
+ if "PNEUMONIA" in filename.numpy().decode("utf-8")
+ ]
+)
+print("Pneumonia images count in training set: " + str(COUNT_PNEUMONIA))
+
+"""
+Notice that there are far more images classified as pneumonia than normal. This
+shows that we have an imbalance in our data. We will correct for this imbalance later on
+in our notebook.
+"""
+
+"""
+We want to map each filename to the corresponding (image, label) pair. The following
+methods will help us do that.
+
+As we only have two labels, we will encode the label so that `1` or `True` indicates
+pneumonia and `0` or `False` indicates normal.
+"""
+
+
+def get_label(file_path):
+ # convert the path to a list of path components
+ parts = tf.strings.split(file_path, "/")
+ # The second to last is the class-directory
+ if parts[-2] == "PNEUMONIA":
+ return 1
+ else:
+ return 0
+
+
+def decode_img(img):
+ # convert the compressed string to a 3D uint8 tensor
+ img = tf.image.decode_jpeg(img, channels=3)
+ # resize the image to the desired size.
+ return tf.image.resize(img, IMAGE_SIZE)
+
+
+def process_path(image, path):
+ label = get_label(path)
+ # load the raw data from the file as a string
+ img = decode_img(image)
+ return img, label
+
+
+ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
+
+"""
+Let's split the data into training and validation datasets.
+"""
+
+ds = ds.shuffle(10000)
+train_ds = ds.take(4200)
+val_ds = ds.skip(4200)
+
+"""
+Let's visualize the shape of an (image, label) pair.
+"""
+
+for image, label in train_ds.take(1):
+ print("Image shape: ", image.numpy().shape)
+ print("Label: ", label.numpy())
+
+"""
+Load and format the test data as well.
+"""
+
+test_images = tf.data.TFRecordDataset(
+ "gs://download.tensorflow.org/data/ChestXRay2017/test/images.tfrec"
+)
+test_paths = tf.data.TFRecordDataset(
+ "gs://download.tensorflow.org/data/ChestXRay2017/test/paths.tfrec"
+)
+test_ds = tf.data.Dataset.zip((test_images, test_paths))
+
+test_ds = test_ds.map(process_path, num_parallel_calls=AUTOTUNE)
+test_ds = test_ds.batch(BATCH_SIZE)
+
+"""
+## Visualize the dataset
+
+First, let's use buffered prefetching so we can yield data from disk without having I/O
+become blocking.
+
+Please note that large image datasets should not be cached in memory. We do it here
+because the dataset is not very large and we want to train on TPU.
+"""
+
+
+def prepare_for_training(ds, cache=True):
+ # This is a small dataset, only load it once, and keep it in memory.
+ # use `.cache(filename)` to cache preprocessing work for datasets that don't
+ # fit in memory.
+ if cache:
+ if isinstance(cache, str):
+ ds = ds.cache(cache)
+ else:
+ ds = ds.cache()
+
+ ds = ds.batch(BATCH_SIZE)
+
+ # `prefetch` lets the dataset fetch batches in the background while the model
+ # is training.
+ ds = ds.prefetch(buffer_size=AUTOTUNE)
+
+ return ds
+
+
+"""
+Call the next batch iteration of the training data.
+"""
+
+train_ds = prepare_for_training(train_ds)
+val_ds = prepare_for_training(val_ds)
+
+image_batch, label_batch = next(iter(train_ds))
+
+"""
+Define the method to show the images in the batch.
+"""
+
+
+def show_batch(image_batch, label_batch):
+ plt.figure(figsize=(10, 10))
+ for n in range(25):
+ ax = plt.subplot(5, 5, n + 1)
+ plt.imshow(image_batch[n] / 255)
+ if label_batch[n]:
+ plt.title("PNEUMONIA")
+ else:
+ plt.title("NORMAL")
+ plt.axis("off")
+
+
+"""
+As the method takes in NumPy arrays as its parameters, call the numpy function on the
+batches to return the tensor in NumPy array form.
+"""
+
+show_batch(image_batch.numpy(), label_batch.numpy())
+
+"""
+## Build the CNN
+
+To make our model more modular and easier to understand, let's define some blocks. As
+we're building a convolutional neural network, we'll create a convolution block and a dense
+layer block.
+
+The architecture for this CNN has been inspired by this
+[article](https://towardsdatascience.com/deep-learning-for-detecting-pneumonia-from-x-ray-images-fc9a3d9fdba8).
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+
+
+def conv_block(filters, inputs):
+ x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(inputs)
+ x = layers.SeparableConv2D(filters, 3, activation="relu", padding="same")(x)
+ x = layers.BatchNormalization()(x)
+ outputs = layers.MaxPool2D()(x)
+
+ return outputs
+
+
+def dense_block(units, dropout_rate, inputs):
+ x = layers.Dense(units, activation="relu")(inputs)
+ x = layers.BatchNormalization()(x)
+ outputs = layers.Dropout(dropout_rate)(x)
+
+ return outputs
+
+
+"""
+The following method will define the function to build our model for us.
+
+The images originally have values that range from [0, 255]. CNNs work better with smaller
+numbers so we will scale this down for our input.
+
+The Dropout layers are important, as they
+reduce the likelihood of the model overfitting. We want to end the model with a `Dense`
+layer with one node, as this will be the binary output that determines if an X-ray shows
+presence of pneumonia.
+"""
+
+
+def build_model():
+ inputs = keras.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
+ x = layers.Rescaling(1.0 / 255)(inputs)
+ x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
+ x = layers.Conv2D(16, 3, activation="relu", padding="same")(x)
+ x = layers.MaxPool2D()(x)
+
+ x = conv_block(32, x)
+ x = conv_block(64, x)
+
+ x = conv_block(128, x)
+ x = layers.Dropout(0.2)(x)
+
+ x = conv_block(256, x)
+ x = layers.Dropout(0.2)(x)
+
+ x = layers.Flatten()(x)
+ x = dense_block(512, 0.7, x)
+ x = dense_block(128, 0.5, x)
+ x = dense_block(64, 0.3, x)
+
+ outputs = layers.Dense(1, activation="sigmoid")(x)
+
+ model = keras.Model(inputs=inputs, outputs=outputs)
+ return model
+
+
+"""
+## Correct for data imbalance
+
+We saw earlier in this example that the data was imbalanced, with more images classified
+as pneumonia than normal. We will correct for that by using class weighting:
+"""
+
+initial_bias = np.log([COUNT_PNEUMONIA / COUNT_NORMAL])
+print("Initial bias: {:.5f}".format(initial_bias[0]))
+
+TRAIN_IMG_COUNT = COUNT_NORMAL + COUNT_PNEUMONIA
+weight_for_0 = (1 / COUNT_NORMAL) * (TRAIN_IMG_COUNT) / 2.0
+weight_for_1 = (1 / COUNT_PNEUMONIA) * (TRAIN_IMG_COUNT) / 2.0
+
+class_weight = {0: weight_for_0, 1: weight_for_1}
+
+print("Weight for class 0: {:.2f}".format(weight_for_0))
+print("Weight for class 1: {:.2f}".format(weight_for_1))
+
+"""
+The weight for class `0` (Normal) is a lot higher than the weight for class `1`
+(Pneumonia). Because there are fewer normal images, each normal image is weighted
+more heavily to balance the data, as the CNN works best when the training data is balanced.
+"""
+
+"""
+## Train the model
+"""
+
+"""
+### Defining callbacks
+
+The checkpoint callback saves the best weights of the model, so next time we want to use
+the model, we do not have to spend time training it. The early stopping callback stops
+the training process when the model starts becoming stagnant, or even worse, when the
+model starts overfitting.
+"""
+
+checkpoint_cb = keras.callbacks.ModelCheckpoint("xray_model.keras", save_best_only=True)
+
+early_stopping_cb = keras.callbacks.EarlyStopping(
+ patience=10, restore_best_weights=True
+)
+
+"""
+We also want to tune our learning rate. Too high a learning rate will cause the model
+to diverge; too low a learning rate will make training too slow. We implement the
+exponential learning rate schedule below.
+"""
+
+initial_learning_rate = 0.015
+lr_schedule = keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
+)
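+
+"""
+With `staircase=True`, the schedule keeps the learning rate piecewise constant:
+`lr(step) = initial_learning_rate * decay_rate ** (step // decay_steps)`. The schedule
+object is callable, so we can sanity-check it directly (a quick illustration):
+"""
+
+# The rate only drops once a full decay period of 100,000 steps has elapsed.
+print(float(lr_schedule(0)))  # 0.015
+print(float(lr_schedule(100000)))  # 0.015 * 0.96 = 0.0144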
+
+"""
+### Fit the model
+
+For our metrics, we want to include precision and recall as they will provide us with a
+more informed picture of how good our model is. Accuracy tells us what fraction of the
+labels is correct. Since our data is not balanced, accuracy might give a skewed sense of
+a good model (i.e. a model that always predicts PNEUMONIA will be 74% accurate but is not
+a good model).
+
+Precision is the number of true positives (TP) over the sum of TP and false positives
+(FP). It shows what fraction of labeled positives are actually correct.
+
+Recall is the number of TP over the sum of TP and false negatives (FN). It shows what
+fraction of actual positives are correctly identified.
+
+Since there are only two possible labels for the image, we will be using the
+binary crossentropy loss. When we fit the model, remember to specify the class weights,
+which we defined earlier. Because we are using a TPU, training will be quick - less than
+2 minutes.
+"""
+
+with strategy.scope():
+ model = build_model()
+
+ METRICS = [
+ keras.metrics.BinaryAccuracy(),
+ keras.metrics.Precision(name="precision"),
+ keras.metrics.Recall(name="recall"),
+ ]
+ model.compile(
+ optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
+ loss="binary_crossentropy",
+ metrics=METRICS,
+ )
+
+history = model.fit(
+ train_ds,
+ epochs=100,
+ validation_data=val_ds,
+ class_weight=class_weight,
+ callbacks=[checkpoint_cb, early_stopping_cb],
+)
+
+"""
+## Visualizing model performance
+
+Let's plot the model accuracy and loss for the training and validation sets. Note that
+no random seed is specified for this notebook, so your results may vary slightly.
+"""
+
+fig, ax = plt.subplots(1, 4, figsize=(20, 3))
+ax = ax.ravel()
+
+for i, met in enumerate(["precision", "recall", "binary_accuracy", "loss"]):
+ ax[i].plot(history.history[met])
+ ax[i].plot(history.history["val_" + met])
+ ax[i].set_title("Model {}".format(met))
+ ax[i].set_xlabel("epochs")
+ ax[i].set_ylabel(met)
+ ax[i].legend(["train", "val"])
+
+"""
+We see that the accuracy for our model is around 95%.
+"""
+
+"""
+## Predict and evaluate results
+
+Let's evaluate the model on our test data!
+"""
+
+model.evaluate(test_ds, return_dict=True)
+
+"""
+We see that the accuracy on our test data is lower than the accuracy on our validation
+set. This may indicate overfitting.
+
+Our recall is greater than our precision, indicating that almost all pneumonia images are
+correctly identified but some normal images are falsely identified. We should aim to
+increase our precision.
+"""
+
+for image, label in test_ds.take(1):
+ plt.imshow(image[0] / 255.0)
+ plt.title(CLASS_NAMES[label[0].numpy()])
+
+prediction = model.predict(test_ds.take(1))[0]
+scores = [1 - prediction, prediction]
+
+for score, name in zip(scores, CLASS_NAMES):
+ print("This image is %.2f percent %s" % ((100 * score), name))
diff --git a/knowledge_base/vision/yolov8.py b/knowledge_base/vision/yolov8.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e70a929e6baedb447fc60ae137c0f8bdf2ddcee
--- /dev/null
+++ b/knowledge_base/vision/yolov8.py
@@ -0,0 +1,631 @@
+"""
+Title: Efficient Object Detection with YOLOV8 and KerasCV
+Author: [Gitesh Chawda](https://twitter.com/gitesh12_)
+Date created: 2023/06/26
+Last modified: 2023/06/26
+Description: Train custom YOLOV8 object detection model with KerasCV.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+"""
+
+"""
+KerasCV is an extension of Keras for computer vision tasks. In this example, we'll see
+how to train a YOLOV8 object detection model using KerasCV.
+
+KerasCV includes pre-trained models for popular computer vision datasets, such as
+ImageNet, COCO, and Pascal VOC, which can be used for transfer learning. KerasCV also
+provides a range of visualization tools for inspecting the intermediate representations
+learned by the model and for visualizing the results of object detection and segmentation
+tasks.
+"""
+
+"""
+If you're interested in learning about object detection using KerasCV, I highly suggest
+taking a look at the guide created by lukewood. This resource, available at
+[Object Detection With KerasCV](https://keras.io/guides/keras_cv/object_detection_keras_cv/#object-detection-introduction),
+provides a comprehensive overview of the fundamental concepts and techniques
+required for building object detection models with KerasCV.
+"""
+
+"""shell
+pip install --upgrade git+https://github.com/keras-team/keras-cv -q
+"""
+
+"""
+## Setup
+"""
+
+import os
+from tqdm.auto import tqdm
+import xml.etree.ElementTree as ET
+
+import tensorflow as tf
+from tensorflow import keras
+
+import keras_cv
+from keras_cv import bounding_box
+from keras_cv import visualization
+
+"""
+## Load Data
+"""
+
+"""
+For this guide, we will be utilizing the Self-Driving Car Dataset obtained from
+[roboflow](https://public.roboflow.com/object-detection/self-driving-car). In order to
+make the dataset more manageable, I have extracted a subset of the larger dataset, which
+originally consisted of 15,000 data samples. From this subset, I have chosen 7,316
+samples for model training.
+
+To simplify the task at hand and focus our efforts, we will be working with a reduced
+number of object classes. Specifically, we will be considering five primary classes for
+detection and classification: car, pedestrian, traffic light, biker, and truck. These
+classes represent some of the most common and significant objects encountered in the
+context of self-driving cars.
+
+By narrowing down the dataset to these specific classes, we can concentrate on building a
+robust object detection model that can accurately identify and classify these important
+objects.
+"""
+
+"""
+The TensorFlow Datasets library provides a convenient way to download and use various
+datasets, including the object detection dataset. This can be a great option for those
+who want to quickly start working with the data without having to manually download and
+preprocess it.
+
+You can view various object detection datasets here
+[TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview#object_detection)
+
+However, in this code example, we will demonstrate how to load the dataset from scratch
+using TensorFlow's `tf.data` pipeline. This approach provides more flexibility and allows
+you to customize the preprocessing steps as needed.
+
+Loading custom datasets that are not available in the TensorFlow Datasets library is one
+of the main advantages of using the `tf.data` pipeline. This approach allows you to
+create a custom data preprocessing pipeline tailored to the specific needs and
+requirements of your dataset.
+"""
+
+"""
+## Hyperparameters
+"""
+
+SPLIT_RATIO = 0.2
+BATCH_SIZE = 4
+LEARNING_RATE = 0.001
+EPOCH = 5
+GLOBAL_CLIPNORM = 10.0
+
+"""
+A dictionary is created to map each unique numerical identifier to a class name. This
+mapping is used to encode and decode the class labels during training and inference in
+object detection tasks.
+"""
+
+class_ids = [
+ "car",
+ "pedestrian",
+ "trafficLight",
+ "biker",
+ "truck",
+]
+class_mapping = dict(zip(range(len(class_ids)), class_ids))
+
+# Path to images and annotations
+path_images = "/kaggle/input/dataset/data/images/"
+path_annot = "/kaggle/input/dataset/data/annotations/"
+
+# Get all XML file paths in path_annot and sort them
+xml_files = sorted(
+ [
+ os.path.join(path_annot, file_name)
+ for file_name in os.listdir(path_annot)
+ if file_name.endswith(".xml")
+ ]
+)
+
+# Get all JPEG image file paths in path_images and sort them
+jpg_files = sorted(
+ [
+ os.path.join(path_images, file_name)
+ for file_name in os.listdir(path_images)
+ if file_name.endswith(".jpg")
+ ]
+)
+
+"""
+The function below reads the XML file and finds the image name and path, and then
+iterates over each object in the XML file to extract the bounding box coordinates and
+class labels for each object.
+
+The function returns three values: the image path, a list of bounding boxes (each
+represented as a list of four floats: xmin, ymin, xmax, ymax), and a list of class IDs
+(represented as integers) corresponding to each bounding box. The class IDs are obtained
+by mapping the class labels to integer values using a dictionary called `class_mapping`.
+"""
+
+
+def parse_annotation(xml_file):
+ tree = ET.parse(xml_file)
+ root = tree.getroot()
+
+ image_name = root.find("filename").text
+ image_path = os.path.join(path_images, image_name)
+
+ boxes = []
+ classes = []
+ for obj in root.iter("object"):
+ cls = obj.find("name").text
+ classes.append(cls)
+
+ bbox = obj.find("bndbox")
+ xmin = float(bbox.find("xmin").text)
+ ymin = float(bbox.find("ymin").text)
+ xmax = float(bbox.find("xmax").text)
+ ymax = float(bbox.find("ymax").text)
+ boxes.append([xmin, ymin, xmax, ymax])
+
+ class_ids = [
+ list(class_mapping.keys())[list(class_mapping.values()).index(cls)]
+ for cls in classes
+ ]
+ return image_path, boxes, class_ids
+
+
+image_paths = []
+bbox = []
+classes = []
+for xml_file in tqdm(xml_files):
+ image_path, boxes, class_ids = parse_annotation(xml_file)
+ image_paths.append(image_path)
+ bbox.append(boxes)
+ classes.append(class_ids)
+
+"""
+Here we are using `tf.ragged.constant` to create ragged tensors from the `bbox` and
+`classes` lists. A ragged tensor is a type of tensor that can handle varying lengths of
+data along one or more dimensions. This is useful when dealing with data that has
+variable-length sequences, such as text or time series data.
+
+```python
+classes = [
+ [8, 8, 8, 8, 8], # 5 classes
+ [12, 14, 14, 14], # 4 classes
+ [1], # 1 class
+ [7, 7], # 2 classes
+ ...]
+```
+
+```python
+bbox = [
+ [[199.0, 19.0, 390.0, 401.0],
+ [217.0, 15.0, 270.0, 157.0],
+ [393.0, 18.0, 432.0, 162.0],
+ [1.0, 15.0, 226.0, 276.0],
+ [19.0, 95.0, 458.0, 443.0]], #image 1 has 4 objects
+ [[52.0, 117.0, 109.0, 177.0]], #image 2 has 1 object
+ [[88.0, 87.0, 235.0, 322.0],
+ [113.0, 117.0, 218.0, 471.0]], #image 3 has 2 objects
+ ...]
+```
+
+In this case, the `bbox` and `classes` lists have different lengths for each image,
+depending on the number of objects in the image and the corresponding bounding boxes and
+classes. To handle this variability, ragged tensors are used instead of regular tensors.
+
+Later, these ragged tensors are used to create a `tf.data.Dataset` using the
+`from_tensor_slices` method. This method creates a dataset from the input tensors by
+slicing them along the first dimension. By using ragged tensors, the dataset can handle
+varying lengths of data for each image and provide a flexible input pipeline for further
+processing.
+"""
+
+bbox = tf.ragged.constant(bbox)
+classes = tf.ragged.constant(classes)
+image_paths = tf.ragged.constant(image_paths)
+
+data = tf.data.Dataset.from_tensor_slices((image_paths, classes, bbox))
+
+"""
+Splitting the data into training and validation sets
+"""
+
+# Determine the number of validation samples
+num_val = int(len(xml_files) * SPLIT_RATIO)
+
+# Split the dataset into train and validation sets
+val_data = data.take(num_val)
+train_data = data.skip(num_val)
+
+"""
+Next, let's look at data loading and bounding box formatting. Bounding
+boxes in KerasCV have a predetermined format. To do this, you must bundle your bounding
+boxes into a dictionary that complies with the requirements listed below:
+
+```python
+bounding_boxes = {
+ # num_boxes may be a Ragged dimension
+ 'boxes': Tensor(shape=[batch, num_boxes, 4]),
+ 'classes': Tensor(shape=[batch, num_boxes])
+}
+```
+
+The dictionary has two keys, `'boxes'` and `'classes'`, each of which maps to a
+TensorFlow RaggedTensor or Tensor object. The `'boxes'` Tensor has a shape of `[batch,
+num_boxes, 4]`, where batch is the number of images in the batch and num_boxes is the
+maximum number of bounding boxes in any image. The 4 represents the four values needed to
+define a bounding box: xmin, ymin, xmax, ymax.
+
+The `'classes'` Tensor has a shape of `[batch, num_boxes]`, where each element represents
+the class label for the corresponding bounding box in the `'boxes'` Tensor. The num_boxes
+dimension may be ragged, which means that the number of boxes may vary across images in
+the batch.
+
+The final dictionary should be:
+```python
+{"images": images, "bounding_boxes": bounding_boxes}
+```
+"""
+
+
+def load_image(image_path):
+ image = tf.io.read_file(image_path)
+ image = tf.image.decode_jpeg(image, channels=3)
+ return image
+
+
+def load_dataset(image_path, classes, bbox):
+ # Read Image
+ image = load_image(image_path)
+ bounding_boxes = {
+ "classes": tf.cast(classes, dtype=tf.float32),
+ "boxes": bbox,
+ }
+ return {"images": tf.cast(image, tf.float32), "bounding_boxes": bounding_boxes}
+
+
+"""
+Here we create a layer that resizes images to 640x640 pixels, while maintaining the
+original aspect ratio. The bounding boxes associated with the image are specified in the
+`xyxy` format. If necessary, the resized image will be padded with zeros to maintain the
+original aspect ratio.
+
+Bounding Box Formats supported by KerasCV:
+1. CENTER_XYWH
+2. XYWH
+3. XYXY
+4. REL_XYXY
+5. REL_XYWH
+6. YXYX
+7. REL_YXYX
+
+
+You can read more about KerasCV bounding box formats in
+[docs](https://keras.io/api/keras_cv/bounding_box/formats/).
+
+Furthermore, it is possible to perform format conversion between any two pairs:
+
+```python
+boxes = keras_cv.bounding_box.convert_format(
+ bounding_box,
+ images=image,
+ source="xyxy", # Original Format
+ target="xywh", # Target Format (to which we want to convert)
+ )
+```
+"""
+
+"""
+## Data Augmentation
+
+One of the most challenging tasks when constructing object detection pipelines is data
+augmentation. It involves applying various transformations to the input images to
+increase the diversity of the training data and improve the model's ability to
+generalize. However, when working with object detection tasks, it becomes even more
+complex as these transformations need to be aware of the underlying bounding boxes and
+update them accordingly.
+
+KerasCV provides native support for bounding box augmentation. KerasCV offers an
+extensive collection of data augmentation layers specifically designed to handle bounding
+boxes. These layers intelligently adjust the bounding box coordinates as the image is
+transformed, ensuring that the bounding boxes remain accurate and aligned with the
+augmented images.
+
+By leveraging KerasCV's capabilities, developers can conveniently integrate bounding
+box-friendly data augmentation into their object detection pipelines. By performing
+on-the-fly augmentation within a tf.data pipeline, the process becomes seamless and
+efficient, enabling better training and more accurate object detection results.
+"""
+
+augmenter = keras.Sequential(
+ layers=[
+ keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xyxy"),
+ keras_cv.layers.RandomShear(
+ x_factor=0.2, y_factor=0.2, bounding_box_format="xyxy"
+ ),
+ keras_cv.layers.JitteredResize(
+ target_size=(640, 640), scale_factor=(0.75, 1.3), bounding_box_format="xyxy"
+ ),
+ ]
+)
+
+"""
+## Creating Training Dataset
+"""
+
+train_ds = train_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
+train_ds = train_ds.shuffle(BATCH_SIZE * 4)
+train_ds = train_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
+train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)
+
+"""
+## Creating Validation Dataset
+"""
+
+resizing = keras_cv.layers.JitteredResize(
+ target_size=(640, 640),
+ scale_factor=(0.75, 1.3),
+ bounding_box_format="xyxy",
+)
+
+val_ds = val_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
+val_ds = val_ds.shuffle(BATCH_SIZE * 4)
+val_ds = val_ds.ragged_batch(BATCH_SIZE, drop_remainder=True)
+val_ds = val_ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
+
+"""
+## Visualization
+"""
+
+
+def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
+ inputs = next(iter(inputs.take(1)))
+ images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
+ visualization.plot_bounding_box_gallery(
+ images,
+ value_range=value_range,
+ rows=rows,
+ cols=cols,
+ y_true=bounding_boxes,
+ scale=5,
+ font_scale=0.7,
+ bounding_box_format=bounding_box_format,
+ class_mapping=class_mapping,
+ )
+
+
+visualize_dataset(
+ train_ds, bounding_box_format="xyxy", value_range=(0, 255), rows=2, cols=2
+)
+
+visualize_dataset(
+ val_ds, bounding_box_format="xyxy", value_range=(0, 255), rows=2, cols=2
+)
+
+"""
+We need to extract the inputs from the preprocessing dictionary and get them ready to be
+fed into the model.
+"""
+
+
+def dict_to_tuple(inputs):
+ return inputs["images"], inputs["bounding_boxes"]
+
+
+train_ds = train_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
+train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
+
+val_ds = val_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
+val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
+
+"""
+## Creating Model
+"""
+
+"""
+YOLOv8 is a cutting-edge YOLO model that is used for a variety of computer vision tasks,
+such as object detection, image classification, and instance segmentation. Ultralytics,
+the creators of YOLOv5, also developed YOLOv8, which incorporates many improvements and
+changes in architecture and developer experience compared to its predecessor. YOLOv8 is
+the latest state-of-the-art model that is highly regarded in the industry.
+"""
+
+"""
+The table below compares the performance metrics of five YOLOv8 model variants:
+YOLOv8n, YOLOv8s, YOLOv8m, YOLOv8l, and YOLOv8x. The metrics include mean average
+precision (mAP) at IoU thresholds 50-95 on validation data, inference speed on CPU (ONNX
+format) and on an A100 GPU (TensorRT), number of parameters (in millions), and number of
+floating-point operations (in billions). As the size of the model increases, the mAP,
+parameter count, and FLOPs generally increase while the inference speed decreases.
+YOLOv8x has the highest mAP, parameters, and FLOPs but also the slowest inference speed,
+while YOLOv8n has the smallest size, fastest inference speed, and lowest mAP, parameters,
+and FLOPs.
+
+| Model   | size (pixels) | mAP val 50-95 | Speed CPU ONNX (ms) | Speed A100 TensorRT (ms) | params (M) | FLOPs (B) |
+| ------- | ------------- | ------------- | ------------------- | ------------------------ | ---------- | --------- |
+| YOLOv8n | 640           | 37.3          | 80.4                | 0.99                     | 3.2        | 8.7       |
+| YOLOv8s | 640           | 44.9          | 128.4               | 1.20                     | 11.2       | 28.6      |
+| YOLOv8m | 640           | 50.2          | 234.7               | 1.83                     | 25.9       | 78.9      |
+| YOLOv8l | 640           | 52.9          | 375.2               | 2.39                     | 43.7       | 165.2     |
+| YOLOv8x | 640           | 53.9          | 479.1               | 3.53                     | 68.2       | 257.8     |
+"""
+
+"""
+You can read more about YOLOV8 and its architecture in this
+[RoboFlow Blog](https://blog.roboflow.com/whats-new-in-yolov8/)
+"""
+
+"""
+First, we will create an instance of a backbone which will be used by our YOLOV8
+detector class.
+
+YOLOV8 backbones available in KerasCV:
+
+1. Without weights:
+
+    1. yolo_v8_xs_backbone
+    2. yolo_v8_s_backbone
+    3. yolo_v8_m_backbone
+    4. yolo_v8_l_backbone
+    5. yolo_v8_xl_backbone
+
+2. With pre-trained COCO weights:
+
+    1. yolo_v8_xs_backbone_coco
+    2. yolo_v8_s_backbone_coco
+    3. yolo_v8_m_backbone_coco
+    4. yolo_v8_l_backbone_coco
+    5. yolo_v8_xl_backbone_coco
+"""
+
+backbone = keras_cv.models.YOLOV8Backbone.from_preset(
+ "yolo_v8_s_backbone_coco" # We will use yolov8 small backbone with coco weights
+)
+
+"""
+Next, let's build a YOLOV8 model using the `YOLOV8Detector`, which accepts a feature
+extractor as the `backbone` argument, a `num_classes` argument that specifies the number
+of object classes to detect based on the size of the `class_mapping` list, a
+`bounding_box_format` argument that informs the model of the format of the bounding
+boxes in the dataset, and finally an `fpn_depth` argument that specifies the depth of
+the feature pyramid network (FPN).
+
+Thanks to KerasCV, it is simple to build a YOLOV8 detector with any of the
+aforementioned backbones.
+"""
+
+yolo = keras_cv.models.YOLOV8Detector(
+ num_classes=len(class_mapping),
+ bounding_box_format="xyxy",
+ backbone=backbone,
+ fpn_depth=1,
+)
+
+"""
+## Compile the Model
+"""
+
+"""
+Losses used for YOLOV8:
+
+1. Classification Loss: This loss function calculates the discrepancy between predicted
+class probabilities and actual class probabilities. In this instance,
+`binary_crossentropy`, a prominent solution for binary classification problems, is
+utilized. We use binary crossentropy since each detected object is either classified as
+belonging or not belonging to a certain object class (such as a person, a car, etc.).
+
+2. Box Loss: `box_loss` is the loss function used to measure the difference between the
+predicted bounding boxes and the ground truth. In this case, the Complete IoU (CIoU)
+metric is used, which not only measures the overlap between predicted and ground truth
+bounding boxes but also considers the difference in aspect ratio, center distance, and
+box size. Together, these loss functions help optimize the model for object detection by
+minimizing the difference between the predicted and ground truth class probabilities and
+bounding boxes.
+"""
+
+optimizer = tf.keras.optimizers.Adam(
+ learning_rate=LEARNING_RATE,
+ global_clipnorm=GLOBAL_CLIPNORM,
+)
+
+yolo.compile(
+ optimizer=optimizer, classification_loss="binary_crossentropy", box_loss="ciou"
+)
+
+"""
+## COCO Metric Callback
+
+We will be using `BoxCOCOMetrics` from KerasCV to evaluate the model and calculate the
+mAP (Mean Average Precision) score, recall, and precision. We also save our model when
+the mAP score improves.
+"""
+
+
+class EvaluateCOCOMetricsCallback(keras.callbacks.Callback):
+ def __init__(self, data, save_path):
+ super().__init__()
+ self.data = data
+ self.metrics = keras_cv.metrics.BoxCOCOMetrics(
+ bounding_box_format="xyxy",
+ evaluate_freq=1e9,
+ )
+
+ self.save_path = save_path
+ self.best_map = -1.0
+
+ def on_epoch_end(self, epoch, logs):
+ self.metrics.reset_state()
+ for batch in self.data:
+ images, y_true = batch[0], batch[1]
+ y_pred = self.model.predict(images, verbose=0)
+ self.metrics.update_state(y_true, y_pred)
+
+ metrics = self.metrics.result(force=True)
+ logs.update(metrics)
+
+ current_map = metrics["MaP"]
+ if current_map > self.best_map:
+ self.best_map = current_map
+ self.model.save(self.save_path) # Save the model when mAP improves
+
+ return logs
+
+
+"""
+## Train the Model
+"""
+
+yolo.fit(
+ train_ds,
+ validation_data=val_ds,
+ epochs=3,
+ callbacks=[EvaluateCOCOMetricsCallback(val_ds, "model.h5")],
+)
+
+"""
+## Visualize Predictions
+"""
+
+
+def visualize_detections(model, dataset, bounding_box_format):
+ images, y_true = next(iter(dataset.take(1)))
+ y_pred = model.predict(images)
+ y_pred = bounding_box.to_ragged(y_pred)
+ visualization.plot_bounding_box_gallery(
+ images,
+ value_range=(0, 255),
+ bounding_box_format=bounding_box_format,
+ y_true=y_true,
+ y_pred=y_pred,
+ scale=4,
+ rows=2,
+ cols=2,
+ show=True,
+ font_scale=0.7,
+ class_mapping=class_mapping,
+ )
+
+
+visualize_detections(yolo, dataset=val_ds, bounding_box_format="xyxy")
diff --git a/knowledge_base/vision/zero_dce.py b/knowledge_base/vision/zero_dce.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdc6cab1fedb5342d4c5ac8b1b1402f319b2ed1f
--- /dev/null
+++ b/knowledge_base/vision/zero_dce.py
@@ -0,0 +1,535 @@
+"""
+Title: Zero-DCE for low-light image enhancement
+Author: [Soumik Rakshit](http://github.com/soumik12345)
+Date created: 2021/09/18
+Last modified: 2023/07/15
+Description: Implementing Zero-Reference Deep Curve Estimation for low-light image enhancement.
+Accelerator: GPU
+Converted to Keras 3 by: [Soumik Rakshit](http://github.com/soumik12345)
+"""
+
+"""
+## Introduction
+
+**Zero-Reference Deep Curve Estimation** or **Zero-DCE** formulates low-light image
+enhancement as the task of estimating an image-specific
+[*tonal curve*](https://en.wikipedia.org/wiki/Curve_(tonality)) with a deep neural network.
+In this example, we train a lightweight deep network, **DCE-Net**, to estimate
+pixel-wise and high-order tonal curves for dynamic range adjustment of a given image.
+
+Zero-DCE takes a low-light image as input and produces high-order tonal curves as its output.
+These curves are then used for pixel-wise adjustment on the dynamic range of the input to
+obtain an enhanced image. The curve estimation process is done in such a way that it maintains
+the range of the enhanced image and preserves the contrast of neighboring pixels. This
+curve estimation is inspired by curves adjustment used in photo editing software such as
+Adobe Photoshop, where users can adjust points throughout an image's tonal range.
+
+Zero-DCE is appealing because of its relaxed assumptions with regard to reference images:
+it does not require any input/output image pairs during training.
+This is achieved through a set of carefully formulated non-reference loss functions,
+which implicitly measure the enhancement quality and guide the training of the network.
+
+### References
+
+- [Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement](https://arxiv.org/abs/2001.06826)
+- [Curves adjustment in Adobe Photoshop](https://helpx.adobe.com/photoshop/using/curves-adjustment.html)
+"""
+
+"""
+## Downloading LOLDataset
+
+The **LoL Dataset** has been created for low-light image enhancement. It provides 485
+images for training and 15 for testing. Each image pair in the dataset consists of a
+low-light input image and its corresponding well-exposed reference image.
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import random
+import numpy as np
+from glob import glob
+from PIL import Image, ImageOps
+import matplotlib.pyplot as plt
+
+import keras
+from keras import layers
+
+import tensorflow as tf
+
+"""shell
+wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
+unzip -q lol_dataset.zip && rm lol_dataset.zip
+"""
+
+"""
+## Creating a TensorFlow Dataset
+
+We use 400 low-light images from the LoL Dataset training set for training, and we use
+the remaining 85 low-light images for validation. We resize the images to size `256 x
+256` to be used for both training and validation. Note that in order to train the DCE-Net,
+we will not require the corresponding enhanced images.
+"""
+
+IMAGE_SIZE = 256
+BATCH_SIZE = 16
+MAX_TRAIN_IMAGES = 400
+
+
+def load_data(image_path):
+ image = tf.io.read_file(image_path)
+ image = tf.image.decode_png(image, channels=3)
+ image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
+ image = image / 255.0
+ return image
+
+
+def data_generator(low_light_images):
+ dataset = tf.data.Dataset.from_tensor_slices((low_light_images))
+ dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
+ dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
+ return dataset
+
+
+train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
+val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
+test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))
+
+
+train_dataset = data_generator(train_low_light_images)
+val_dataset = data_generator(val_low_light_images)
+
+print("Train Dataset:", train_dataset)
+print("Validation Dataset:", val_dataset)
+
+"""
+## The Zero-DCE Framework
+
+The goal of DCE-Net is to estimate a set of best-fitting light-enhancement curves
+(LE-curves) given an input image. The framework then maps all pixels of the input's RGB
+channels by applying the curves iteratively to obtain the final enhanced image.
+
+### Understanding light-enhancement curves
+
+A light-enhancement curve is a kind of curve that can map a low-light image
+to its enhanced version automatically,
+where the self-adaptive curve parameters are solely dependent on the input image.
+When designing such a curve, three objectives should be taken into account:
+
+- Each pixel value of the enhanced image should be in the normalized range `[0,1]`, in order to
+avoid information loss induced by overflow truncation.
+- It should be monotonous, to preserve the contrast between neighboring pixels.
+- The shape of this curve should be as simple as possible,
+and the curve should be differentiable to allow backpropagation.
+
+The light-enhancement curve is separately applied to three RGB channels instead of solely on the
+illumination channel. The three-channel adjustment can better preserve the inherent color and reduce
+the risk of over-saturation.
+
+
+
+### DCE-Net
+
+The DCE-Net is a lightweight deep neural network that learns the mapping between an input
+image and its best-fitting curve parameter maps. The input to the DCE-Net is a low-light
+image while the outputs are a set of pixel-wise curve parameter maps for corresponding
+higher-order curves. It is a plain CNN of seven convolutional layers with symmetrical
+concatenation. Each layer consists of 32 convolutional kernels of size 3×3 and stride 1
+followed by the ReLU activation function. The last convolutional layer is followed by the
+Tanh activation function, which produces 24 parameter maps for 8 iterations, where each
+iteration requires three curve parameter maps for the three channels.
+
+
+"""
+
+
+def build_dce_net():
+ input_img = keras.Input(shape=[None, None, 3])
+ conv1 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(input_img)
+ conv2 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(conv1)
+ conv3 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(conv2)
+ conv4 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(conv3)
+ int_con1 = layers.Concatenate(axis=-1)([conv4, conv3])
+ conv5 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(int_con1)
+ int_con2 = layers.Concatenate(axis=-1)([conv5, conv2])
+ conv6 = layers.Conv2D(
+ 32, (3, 3), strides=(1, 1), activation="relu", padding="same"
+ )(int_con2)
+ int_con3 = layers.Concatenate(axis=-1)([conv6, conv1])
+ x_r = layers.Conv2D(24, (3, 3), strides=(1, 1), activation="tanh", padding="same")(
+ int_con3
+ )
+ return keras.Model(inputs=input_img, outputs=x_r)
+
+
+"""
+## Loss functions
+
+To enable zero-reference learning in DCE-Net, we use a set of differentiable
+zero-reference losses that allow us to evaluate the quality of enhanced images.
+"""
+
+"""
+### Color constancy loss
+
+The *color constancy loss* is used to correct the potential color deviations in the
+enhanced image.
+"""
+
+
+def color_constancy_loss(x):
+ mean_rgb = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
+ mr, mg, mb = (
+ mean_rgb[:, :, :, 0],
+ mean_rgb[:, :, :, 1],
+ mean_rgb[:, :, :, 2],
+ )
+ d_rg = tf.square(mr - mg)
+ d_rb = tf.square(mr - mb)
+ d_gb = tf.square(mb - mg)
+ return tf.sqrt(tf.square(d_rg) + tf.square(d_rb) + tf.square(d_gb))
+
+
+"""
+### Exposure loss
+
+To restrain under-/over-exposed regions, we use the *exposure control loss*.
+It measures the distance between the average intensity value of a local region
+and a preset well-exposedness level (set to `0.6`).
+"""
+
+
+def exposure_loss(x, mean_val=0.6):
+ x = tf.reduce_mean(x, axis=3, keepdims=True)
+ mean = tf.nn.avg_pool2d(x, ksize=16, strides=16, padding="VALID")
+ return tf.reduce_mean(tf.square(mean - mean_val))
+
+
+"""
+### Illumination smoothness loss
+
+To preserve the monotonicity relations between neighboring pixels, the
+*illumination smoothness loss* is added to each curve parameter map.
+"""
+
+
+def illumination_smoothness_loss(x):
+ batch_size = tf.shape(x)[0]
+ h_x = tf.shape(x)[1]
+ w_x = tf.shape(x)[2]
+ count_h = (tf.shape(x)[2] - 1) * tf.shape(x)[3]
+ count_w = tf.shape(x)[2] * (tf.shape(x)[3] - 1)
+ h_tv = tf.reduce_sum(tf.square((x[:, 1:, :, :] - x[:, : h_x - 1, :, :])))
+ w_tv = tf.reduce_sum(tf.square((x[:, :, 1:, :] - x[:, :, : w_x - 1, :])))
+ batch_size = tf.cast(batch_size, dtype=tf.float32)
+ count_h = tf.cast(count_h, dtype=tf.float32)
+ count_w = tf.cast(count_w, dtype=tf.float32)
+ return 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+
+"""
+### Spatial consistency loss
+
+The *spatial consistency loss* encourages spatial coherence of the enhanced image by
+preserving the contrast between neighboring regions across the input image and its enhanced version.
+"""
+
+
+class SpatialConsistencyLoss(keras.losses.Loss):
+ def __init__(self, **kwargs):
+ super().__init__(reduction="none")
+
+ self.left_kernel = tf.constant(
+ [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
+ )
+ self.right_kernel = tf.constant(
+ [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
+ )
+ self.up_kernel = tf.constant(
+ [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
+ )
+ self.down_kernel = tf.constant(
+ [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
+ )
+
+ def call(self, y_true, y_pred):
+ original_mean = tf.reduce_mean(y_true, 3, keepdims=True)
+ enhanced_mean = tf.reduce_mean(y_pred, 3, keepdims=True)
+ original_pool = tf.nn.avg_pool2d(
+ original_mean, ksize=4, strides=4, padding="VALID"
+ )
+ enhanced_pool = tf.nn.avg_pool2d(
+ enhanced_mean, ksize=4, strides=4, padding="VALID"
+ )
+
+ d_original_left = tf.nn.conv2d(
+ original_pool,
+ self.left_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+ d_original_right = tf.nn.conv2d(
+ original_pool,
+ self.right_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+ d_original_up = tf.nn.conv2d(
+ original_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
+ )
+ d_original_down = tf.nn.conv2d(
+ original_pool,
+ self.down_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+
+ d_enhanced_left = tf.nn.conv2d(
+ enhanced_pool,
+ self.left_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+ d_enhanced_right = tf.nn.conv2d(
+ enhanced_pool,
+ self.right_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+ d_enhanced_up = tf.nn.conv2d(
+ enhanced_pool, self.up_kernel, strides=[1, 1, 1, 1], padding="SAME"
+ )
+ d_enhanced_down = tf.nn.conv2d(
+ enhanced_pool,
+ self.down_kernel,
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ )
+
+ d_left = tf.square(d_original_left - d_enhanced_left)
+ d_right = tf.square(d_original_right - d_enhanced_right)
+ d_up = tf.square(d_original_up - d_enhanced_up)
+ d_down = tf.square(d_original_down - d_enhanced_down)
+ return d_left + d_right + d_up + d_down
+
+
+"""
+### Deep curve estimation model
+
+We implement the Zero-DCE framework as a Keras subclassed model.
+"""
+
+
+class ZeroDCE(keras.Model):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dce_model = build_dce_net()
+
+ def compile(self, learning_rate, **kwargs):
+ super().compile(**kwargs)
+ self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
+ self.spatial_constancy_loss = SpatialConsistencyLoss(reduction="none")
+ self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
+ self.illumination_smoothness_loss_tracker = keras.metrics.Mean(
+ name="illumination_smoothness_loss"
+ )
+ self.spatial_constancy_loss_tracker = keras.metrics.Mean(
+ name="spatial_constancy_loss"
+ )
+ self.color_constancy_loss_tracker = keras.metrics.Mean(
+ name="color_constancy_loss"
+ )
+ self.exposure_loss_tracker = keras.metrics.Mean(name="exposure_loss")
+
+ @property
+ def metrics(self):
+ return [
+ self.total_loss_tracker,
+ self.illumination_smoothness_loss_tracker,
+ self.spatial_constancy_loss_tracker,
+ self.color_constancy_loss_tracker,
+ self.exposure_loss_tracker,
+ ]
+
+ def get_enhanced_image(self, data, output):
+ r1 = output[:, :, :, :3]
+ r2 = output[:, :, :, 3:6]
+ r3 = output[:, :, :, 6:9]
+ r4 = output[:, :, :, 9:12]
+ r5 = output[:, :, :, 12:15]
+ r6 = output[:, :, :, 15:18]
+ r7 = output[:, :, :, 18:21]
+ r8 = output[:, :, :, 21:24]
+ x = data + r1 * (tf.square(data) - data)
+ x = x + r2 * (tf.square(x) - x)
+ x = x + r3 * (tf.square(x) - x)
+ enhanced_image = x + r4 * (tf.square(x) - x)
+ x = enhanced_image + r5 * (tf.square(enhanced_image) - enhanced_image)
+ x = x + r6 * (tf.square(x) - x)
+ x = x + r7 * (tf.square(x) - x)
+ enhanced_image = x + r8 * (tf.square(x) - x)
+ return enhanced_image
+
+ def call(self, data):
+ dce_net_output = self.dce_model(data)
+ return self.get_enhanced_image(data, dce_net_output)
+
+ def compute_losses(self, data, output):
+ enhanced_image = self.get_enhanced_image(data, output)
+ loss_illumination = 200 * illumination_smoothness_loss(output)
+ loss_spatial_constancy = tf.reduce_mean(
+ self.spatial_constancy_loss(enhanced_image, data)
+ )
+ loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
+ loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
+ total_loss = (
+ loss_illumination
+ + loss_spatial_constancy
+ + loss_color_constancy
+ + loss_exposure
+ )
+
+ return {
+ "total_loss": total_loss,
+ "illumination_smoothness_loss": loss_illumination,
+ "spatial_constancy_loss": loss_spatial_constancy,
+ "color_constancy_loss": loss_color_constancy,
+ "exposure_loss": loss_exposure,
+ }
+
+ def train_step(self, data):
+ with tf.GradientTape() as tape:
+ output = self.dce_model(data)
+ losses = self.compute_losses(data, output)
+
+ gradients = tape.gradient(
+ losses["total_loss"], self.dce_model.trainable_weights
+ )
+ self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
+
+ self.total_loss_tracker.update_state(losses["total_loss"])
+ self.illumination_smoothness_loss_tracker.update_state(
+ losses["illumination_smoothness_loss"]
+ )
+ self.spatial_constancy_loss_tracker.update_state(
+ losses["spatial_constancy_loss"]
+ )
+ self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
+ self.exposure_loss_tracker.update_state(losses["exposure_loss"])
+
+ return {metric.name: metric.result() for metric in self.metrics}
+
+ def test_step(self, data):
+ output = self.dce_model(data)
+ losses = self.compute_losses(data, output)
+
+ self.total_loss_tracker.update_state(losses["total_loss"])
+ self.illumination_smoothness_loss_tracker.update_state(
+ losses["illumination_smoothness_loss"]
+ )
+ self.spatial_constancy_loss_tracker.update_state(
+ losses["spatial_constancy_loss"]
+ )
+ self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
+ self.exposure_loss_tracker.update_state(losses["exposure_loss"])
+
+ return {metric.name: metric.result() for metric in self.metrics}
+
+ def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
+ """While saving the weights, we simply save the weights of the DCE-Net"""
+ self.dce_model.save_weights(
+ filepath,
+ overwrite=overwrite,
+ save_format=save_format,
+ options=options,
+ )
+
+ def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
+ """While loading the weights, we simply load the weights of the DCE-Net"""
+ self.dce_model.load_weights(
+ filepath=filepath,
+ by_name=by_name,
+ skip_mismatch=skip_mismatch,
+ options=options,
+ )
+
+
+"""
+## Training
+"""
+
+zero_dce_model = ZeroDCE()
+zero_dce_model.compile(learning_rate=1e-4)
+history = zero_dce_model.fit(train_dataset, validation_data=val_dataset, epochs=100)
+
+
+def plot_result(item):
+ plt.plot(history.history[item], label=item)
+ plt.plot(history.history["val_" + item], label="val_" + item)
+ plt.xlabel("Epochs")
+ plt.ylabel(item)
+ plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
+ plt.legend()
+ plt.grid()
+ plt.show()
+
+
+plot_result("total_loss")
+plot_result("illumination_smoothness_loss")
+plot_result("spatial_constancy_loss")
+plot_result("color_constancy_loss")
+plot_result("exposure_loss")
+
+"""
+## Inference
+"""
+
+
+def plot_results(images, titles, figure_size=(12, 12)):
+ fig = plt.figure(figsize=figure_size)
+ for i in range(len(images)):
+ fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
+ _ = plt.imshow(images[i])
+ plt.axis("off")
+ plt.show()
+
+
+def infer(original_image):
+ image = keras.utils.img_to_array(original_image)
+ image = image.astype("float32") / 255.0
+ image = np.expand_dims(image, axis=0)
+ output_image = zero_dce_model(image)
+ output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
+ output_image = Image.fromarray(output_image.numpy())
+ return output_image
+
+
+"""
+### Inference on test images
+
+We compare the test images from LOLDataset enhanced by Zero-DCE with images enhanced via
+the `PIL.ImageOps.autocontrast()` function.
+
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/low-light-image-enhancement)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/low-light-image-enhancement).
+"""
+
+for val_image_file in test_low_light_images:
+ original_image = Image.open(val_image_file)
+ enhanced_image = infer(original_image)
+ plot_results(
+ [original_image, ImageOps.autocontrast(original_image), enhanced_image],
+ ["Original", "PIL Autocontrast", "Enhanced"],
+ (20, 12),
+ )
diff --git a/mcp_server/__init__.py b/mcp_server/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..837133252d826e2a105c6088e4bd6bf530b45c80
--- /dev/null
+++ b/mcp_server/__init__.py
@@ -0,0 +1,2 @@
+__all__ = ["loader", "embeddings"]
+__version__ = "0.1.0"
\ No newline at end of file
diff --git a/mcp_server/embeddings.py b/mcp_server/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a9679576682eafda0600cea95de98a88d0a33e
--- /dev/null
+++ b/mcp_server/embeddings.py
@@ -0,0 +1,155 @@
+from __future__ import annotations
+
+import os
+import threading
+from dataclasses import dataclass
+from typing import List, Sequence, Tuple, Optional
+
+import numpy as np
+
+# Determinism knobs
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+try:
+ import torch # type: ignore
+except Exception: # pragma: no cover
+ torch = None # type: ignore
+
+try:
+ from sentence_transformers import SentenceTransformer # type: ignore
+except Exception as e: # pragma: no cover
+ raise RuntimeError(
+ "sentence-transformers is required. Add it to requirements.txt and install."
+ ) from e
+
+
+DEFAULT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+
+
+def _set_deterministic(seed: int = 42) -> None:
+ """Best-effort determinism across numpy and torch."""
+ np.random.seed(seed)
+ if torch is not None:
+ try:
+ torch.manual_seed(seed)
+ torch.use_deterministic_algorithms(True) # type: ignore[attr-defined]
+ except Exception:
+ # Older PyTorch or restricted envs
+ pass
+
+
+def _l2_normalize(arr: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+ norms = np.linalg.norm(arr, axis=1, keepdims=True)
+ norms = np.maximum(norms, eps)
+ return arr / norms
+
+
+@dataclass(frozen=True)
+class IndexedItem:
+ """Holds the minimal info needed by the search index."""
+ id: str
+ category: str
+ filename: str
+ path: str # relative to project root
+ summary: str
+
+
+class Embedder:
+ """
+ Thin wrapper around SentenceTransformer to produce normalized embeddings.
+ Thread-safe, lazily initialized, deterministic.
+ """
+
+ def __init__(self, model_name: str = DEFAULT_MODEL, device: Optional[str] = None, seed: int = 42):
+ self.model_name = model_name
+ self.device = device
+ self.seed = seed
+ self._model: Optional[SentenceTransformer] = None
+ self._lock = threading.Lock()
+ _set_deterministic(self.seed)
+
+ def _ensure_model(self) -> SentenceTransformer:
+ if self._model is not None:
+ return self._model
+ with self._lock:
+ if self._model is None:
+ self._model = SentenceTransformer(self.model_name, device=self.device)
+ return self._model # type: ignore[return-value]
+
+ def embed(self, texts: Sequence[str], batch_size: int = 64) -> np.ndarray:
+ """
+ Embed a list of strings into a 2D numpy array (N, D), L2-normalized.
+ """
+ if not texts:
+ return np.zeros((0, 384), dtype=np.float32) # model dim for MiniLM
+ model = self._ensure_model()
+ vecs = model.encode(
+ list(texts),
+ batch_size=batch_size,
+ convert_to_numpy=True,
+ normalize_embeddings=False,
+ show_progress_bar=False,
+ )
+ if not isinstance(vecs, np.ndarray):
+ vecs = np.array(vecs)
+ vecs = vecs.astype(np.float32, copy=False)
+ return _l2_normalize(vecs)
+
+ def embed_one(self, text: str) -> np.ndarray:
+ return self.embed([text])[0:1]
+
+
+class EmbeddingIndex:
+ """
+ In-memory embedding index for KB items using cosine similarity.
+
+ - Stores L2-normalized item vectors in a (N, D) float32 matrix.
+ - search(query) computes normalized query vector and returns argmax cosine.
+ """
+
+ def __init__(self, embedder: Embedder):
+ self.embedder = embedder
+ self.items: List[IndexedItem] = []
+ self.matrix: Optional[np.ndarray] = None # (N, D), float32, normalized
+ self.dim: Optional[int] = None
+
+ def build(self, items: Sequence[IndexedItem], texts: Sequence[str]) -> None:
+ if len(items) != len(texts):
+ raise ValueError("items and texts must have the same length")
+ if not items:
+ # Empty index
+ self.items = []
+ self.matrix = np.zeros((0, 1), dtype=np.float32)
+ self.dim = 1
+ return
+ vecs = self.embedder.embed(texts)
+ self.items = list(items)
+ self.matrix = vecs # already L2-normalized
+ self.dim = int(vecs.shape[1])
+
+ def is_built(self) -> bool:
+ return self.matrix is not None and self.items is not None
+
+ def search_one(self, query_text: str) -> Tuple[IndexedItem, float]:
+ """
+        Return (best_item, best_score) where the score is cosine similarity in [-1, 1].
+ """
+ if self.matrix is None or len(self.items) == 0:
+ raise RuntimeError("Index is empty. Build the index before searching.")
+ q = self.embedder.embed_one(query_text) # (1, D), normalized
+ # Cosine for normalized vectors reduces to dot product
+ sims = (q @ self.matrix.T).astype(np.float32) # (1, N)
+ best_idx = int(np.argmax(sims, axis=1)[0])
+ best_score = float(sims[0, best_idx])
+ return self.items[best_idx], best_score
+
+ def search_topk(self, query_text: str, k: int = 5) -> List[Tuple[IndexedItem, float]]:
+ if self.matrix is None or len(self.items) == 0:
+ raise RuntimeError("Index is empty. Build the index before searching.")
+ k = max(1, min(k, len(self.items)))
+ q = self.embedder.embed_one(query_text)
+ sims = (q @ self.matrix.T).astype(np.float32).ravel()
+ topk_idx = np.argpartition(-sims, kth=k - 1)[:k]
+ # sort exact top-k
+ topk_sorted = sorted(((int(i), float(sims[int(i)])) for i in topk_idx), key=lambda t: -t[1])
+ return [(self.items[i], score) for i, score in topk_sorted]
\ No newline at end of file
diff --git a/mcp_server/loader.py b/mcp_server/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fa1a98dcf80e3b4b64010bb132d053a71ab2618
--- /dev/null
+++ b/mcp_server/loader.py
@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import ast
+from dataclasses import dataclass, asdict
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+KB_DIRNAME = "knowledge_base"
+
+@dataclass(frozen=True)
+class KBItem:
+ id: str
+ category: str
+ filename: str
+ path: str
+ summary: str
+
+def project_root() -> Path:
+ """Return project root (folder containing 'mcp_server' and 'knowledge_base')."""
+ return Path(__file__).resolve().parent.parent
+
+def kb_root() -> Path:
+ """Return absolute path to knowledge_base directory."""
+ return project_root() / KB_DIRNAME
+
+def extract_docstring(p: Path) -> str:
+ """Extract module docstring; fallback to first non-empty line if absent."""
+ try:
+ src = p.read_text(encoding="utf-8")
+ except Exception:
+ return ""
+ try:
+ module = ast.parse(src)
+ doc = ast.get_docstring(module) or ""
+ if doc:
+ return " ".join(doc.strip().split())
+ # fallback: first non-empty non-comment line
+ for ln in src.splitlines():
+ s = ln.strip()
+ if s:
+ return s[:300]
+ return ""
+ except Exception:
+ return ""
+
+def to_rel_path(p: Path) -> str:
+ """Return path relative to project root for stable IDs."""
+ try:
+ return str(p.relative_to(project_root()))
+ except Exception:
+ return str(p)
+
+def scan_knowledge_base() -> List[KBItem]:
+ """Scan knowledge_base for Python examples and return KBItem list."""
+ root = kb_root()
+ if not root.exists():
+ raise FileNotFoundError(f"Knowledge base folder not found: {root}")
+ items: List[KBItem] = []
+ for category_dir in sorted([d for d in root.iterdir() if d.is_dir()]):
+ category = category_dir.name
+ for py in sorted(category_dir.rglob("*.py")):
+ filename = py.name
+ rel_path = to_rel_path(py)
+ summary = extract_docstring(py) or filename
+ item_id = f"{category}/{filename}"
+ items.append(
+ KBItem(
+ id=item_id,
+ category=category,
+ filename=filename,
+ path=rel_path,
+ summary=summary,
+ )
+ )
+ return items
+
+# Simple in-process cache. Re-scan lazily on first access.
+_ITEMS_CACHE: Optional[List[KBItem]] = None
+
+def list_items_dict() -> List[dict]:
+ """Return KB items as plain dicts suitable for JSON serialization."""
+ global _ITEMS_CACHE
+ if _ITEMS_CACHE is None:
+ _ITEMS_CACHE = scan_knowledge_base()
+ return [asdict(item) for item in _ITEMS_CACHE]
+
+def get_embedding_text(item: KBItem) -> str:
+ """Canonical string used for embedding a KB item."""
+ head = (item.summary or "").strip()
+ return f"{item.category}/{item.filename}: {head}"
+
+def ensure_kb_path(path_str: str) -> Path:
+ """
+ Validate and resolve a path under knowledge_base.
+ Accepts either:
+    - 'knowledge_base/<category>/<filename>.py'
+    - '<category>/<filename>.py'
+ Raises if outside KB or not a Python file.
+ """
+ base = kb_root().resolve()
+ # Normalize input
+ candidate = Path(path_str)
+ if not candidate.is_absolute():
+ # Allow either direct rel to project root or category/file
+ p1 = (project_root() / candidate).resolve()
+ p2 = (base / candidate).resolve()
+ # Prefer inside KB
+ p = p1 if str(p1).startswith(str(base)) and p1.exists() else p2
+ else:
+ p = candidate.resolve()
+ try:
+ p.relative_to(base)
+ except Exception:
+ raise ValueError(f"Path is outside knowledge base: {path_str}")
+ if not p.exists() or not p.is_file() or p.suffix != ".py":
+ raise FileNotFoundError(f"KB file not found: {p}")
+ return p
+
+def read_code(path_str: str) -> str:
+ """Read and return full source code of a KB file."""
+ p = ensure_kb_path(path_str)
+ return p.read_text(encoding="utf-8")
+
+def get_items_for_embedding() -> List[Tuple[KBItem, str]]:
+ """Return list of (KBItem, text) pairs for embedding."""
+ global _ITEMS_CACHE
+ if _ITEMS_CACHE is None:
+ _ITEMS_CACHE = scan_knowledge_base()
+ return [(it, get_embedding_text(it)) for it in _ITEMS_CACHE]
\ No newline at end of file
diff --git a/mcp_server/server.py b/mcp_server/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..322a0de0af3278e2c1a4a19b3a24726b15a9bb7e
--- /dev/null
+++ b/mcp_server/server.py
@@ -0,0 +1,273 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List
+
+import gradio as gr
+
+from mcp_server.tools.list_items import list_items as tool_list_items
+from mcp_server.tools.semantic_search import semantic_search as tool_semantic_search
+from mcp_server.tools.get_code import get_code as tool_get_code
+
+def create_gradio_blocks() -> gr.Blocks:
+ """
+    Build a Gradio UI that, when launched with mcp_server=True, exposes a remote MCP
+    server at: http://<host>:<port>/gradio_api/mcp/sse
+
+ Tools exposed (via function signatures and docstrings):
+ - list_items()
+ - semantic_search(problem_markdown: str)
+ - get_code(path: str)
+
+ Polished UI/UX:
+ - Themed interfaces, custom CSS, clear titles and descriptions
+ - Curated examples for Semantic Search and Get Code
+ - Helpful hero/guide text on the List Items tab
+ """
+
+ # Lightweight custom CSS for a more polished look
+ custom_css = """
+ :root {
+ --radius-md: 12px;
+ --shadow-md: 0 6px 24px rgba(0,0,0,.08);
+ --color-accent: #3B82F6; /* Blue 500 */
+ --color-accent-hover: #2563EB; /* Blue 600 */
+ --color-accent-soft: rgba(59,130,246,.15);
+ --link-text-color: #3B82F6;
+ }
+ .gradio-container { max-width: 1120px !important; margin: 0 auto; }
+
+ /* Buttons and controls -> blue accent */
+ .gr-button {
+ border-radius: 12px;
+ box-shadow: var(--shadow-md);
+ background: var(--color-accent) !important;
+ color: #fff !important;
+ border: 1px solid transparent !important;
+ }
+ .gr-button:hover { background: var(--color-accent-hover) !important; }
+ .gr-button:focus-visible { outline: 2px solid var(--color-accent); outline-offset: 2px; }
+
+ /* Tabs -> blue accent on active/hover */
+ .gr-tabs .tab-nav button[aria-selected="true"] {
+ border-bottom: 2px solid var(--color-accent) !important;
+ color: var(--color-accent) !important;
+ }
+ .gr-tabs .tab-nav button:hover { color: var(--color-accent) !important; }
+
+ /* Examples (chips/buttons) */
+ .gr-examples button, .examples button {
+ border-color: var(--color-accent) !important;
+ color: var(--color-accent) !important;
+ background: transparent !important;
+ }
+ .gr-examples button:hover, .examples button:hover {
+ background: var(--color-accent-soft) !important;
+ }
+
+ .gr-textbox textarea {
+ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
+ }
+ h1, .prose h1 {
+ background: linear-gradient(90deg, #60A5FA, #22D3EE, #1D4ED8);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent;
+ }
+ a, .prose a { color: var(--link-text-color) !important; }
+ .prose p, .prose li { font-size: 15px; line-height: 1.65; }
+ """
+
+ def list_items() -> List[Dict[str, Any]]:
+ """
+ List all knowledge base items.
+
+ Returns:
+ A JSON-serializable list of items, each with:
+ id (str): '<category>/<filename>.py'
+ category (str)
+ filename (str)
+ path (str): 'knowledge_base/<category>/<filename>.py'
+ summary (str): Docstring or first non-empty line
+ """
+ return tool_list_items()
+
+ def semantic_search(problem_markdown: str) -> Dict[str, Any]:
+ """
+ Semantic search over the knowledge base.
+
+ Args:
+ problem_markdown: Markdown text describing the task/problem.
+
+ Returns:
+ {
+ "best_match": {
+ "id": str,
+ "category": str,
+ "filename": str,
+ "path": str,
+ "summary": str
+ },
+ "score": float # cosine similarity in [-1, 1]
+ }
+ """
+ return tool_semantic_search(problem_markdown)
+
+ def get_code(path: str) -> str:
+ """
+ Return the full Python source code for a KB file.
+
+ Args:
+ path: Either 'knowledge_base/<category>/<filename>.py' or '<category>/<filename>.py'
+
+ Returns:
+ UTF-8 Python source as a string.
+ """
+ return tool_get_code(path)
+
+ # Curated examples for a smoother first-run UX
+ search_examples = [
+ "I want to fine-tune a transformer for sentiment classification.",
+ "Train a GNN on citation networks for node classification.",
+ "Image generation with GANs; how to stabilize training?",
+ ]
+ code_examples = [
+ "knowledge_base/nlp/text_classification_with_transformer.py",
+ "knowledge_base/graph/gnn_citations.py",
+ "knowledge_base/generative/dcgan_overriding_train_step.py",
+ ]
+
+ hero_md = """
+# ⚡️ ML Starter: Your ML Launchpad
+
+## **Starting an ML project and overwhelmed by where to begin?**
+## **LLMs not specialized enough for your domain?**
+## **Need real, reusable code instead of vague suggestions?**
+
+### **Describe your problem → get the top-ranked match → pull the exact code file.**
+
+---
+
+### 🔥 Why you'll love it
+- 🎯 **Problem-to-code in one flow** – search semantically, explore context, and download source.
+- 🧠 **Domain-tuned knowledge** – embeddings built over curated ML projects across vision, NLP, audio, structured data, and more.
+- 🤖 **Automation ready** – the same tools power IDEs/agents via MCP over SSE.
+
+### 🚀 What you can do
+- 📚 **Browse Items** – scan the entire library with instant summaries.
+- 🔍 **Semantic Search** – paste your challenge and get the closest-fit recipe plus similarity score.
+- 💻 **Get Code** – drop in the path and copy the full Python implementation.
+
+### 🔍 Under the hood
+- Sentence-transformer embeddings + cosine similarity for precise retrieval.
+- Rich metadata (id, category, path, summary) for fast filtering.
+- Remote MCP endpoint at `/gradio_api/mcp/sse` exposing `list_items()`, `semantic_search()`, `get_code()`.
+
+### ⏱ Quickstart
+1. Head to “🔍 Semantic Search”, describe what you're building, and submit.
+2. Copy the suggested path from the results.
+3. Open “💻 Get Code”, paste the path, and grab the exact source.
+4. Want the big picture first? Start with “📚 Browse Items”.
+
+### 💡 Power tip
+Run locally or on Spaces, then connect any MCP-compatible client to orchestrate the same workflow programmatically.
+"""
+
+ search_article = """
+🧭 How to use
+1) Describe your task with as much signal as possible (dataset, modality, constraints, target metric).
+2) Click Submit or pick an example. We compute embeddings and retrieve the closest KB match.
+3) Copy the 'path' value and open it in the “💻 Get Code” tab to view the full source.
+
+🔧 Notes
+- Markdown is supported. Bullet points and short snippets help a lot.
+- Similarity is cosine similarity on L2-normalized sentence-transformer embeddings.
+ """
+ code_article = """
+Paste a valid knowledge base path to fetch the full Python source.
+
+📝 Examples
+- knowledge_base/nlp/text_classification_with_transformer.py
+- nlp/text_classification_with_transformer.py
+
+💡 Tips
+- Accepts both 'knowledge_base/<category>/<filename>.py' and '<category>/<filename>.py' forms.
+- The code block is copy-friendly for quick reuse.
+ """
+
+ list_ui = gr.Interface(
+ fn=list_items,
+ inputs=None,
+ outputs=gr.JSON(label="๐ฆ Items (JSON)"),
+ title="๐ Browse Items",
+ description="Explore every ML Starter KB entry โ id, category, path, and summary.",
+ article="",
+ )
+
+ search_ui = gr.Interface(
+ fn=semantic_search,
+ inputs=gr.Textbox(
+ lines=10,
+ label="โ๏ธ Describe your problem (Markdown supported)",
+ placeholder="e.g., Fine-tune a transformer for sentiment classification on IMDB (dataset, goal, constraints)"
+ ),
+ outputs=gr.JSON(label="๐ Best match + similarity score"),
+ title="๐ Semantic Search",
+ description="Paste your task. We compute embeddings and return the closest KB recipe with a score.",
+ examples=search_examples,
+ article=search_article,
+ )
+
+ code_ui = gr.Interface(
+ fn=get_code,
+ inputs=gr.Textbox(
+ lines=1,
+ label="๐ KB file path",
+ placeholder="knowledge_base/nlp/text_classification_with_transformer.py"
+ ),
+ outputs=gr.Code(label="๐งฉ Python source", language="python"),
+ title="๐ป Get Code",
+ description="Paste a KB path and copy the exact source into your project.",
+ examples=code_examples,
+ article=code_article,
+ )
+
+ # Compose top-level layout: explanation on top, tabs below
+ with gr.Blocks() as blocks:
+ gr.HTML(f"")
+ gr.Markdown(hero_md)
+ with gr.Tabs():
+ with gr.Tab("๐ List Items"):
+ list_ui.render()
+ with gr.Tab("๐ Semantic Search"):
+ search_ui.render()
+ with gr.Tab("๐ป Get Code"):
+ code_ui.render()
+ return blocks
+
+
+def main() -> None:
+ """
+ Entry point: Launch Gradio UI and expose remote MCP over SSE at /gradio_api/mcp/sse
+ """
+ parser = argparse.ArgumentParser(description="ML Starter MCP Server (Gradio Remote Only)")
+ parser.add_argument("--host", default="127.0.0.1", help="Host for Gradio")
+ parser.add_argument("--port", type=int, default=7860, help="Port for Gradio")
+ args = parser.parse_args()
+
+ # Derive host/port from environment for Hugging Face Spaces and containers
+ env_host = os.getenv("GRADIO_SERVER_NAME") or os.getenv("HOST") or args.host
+ env_port_str = os.getenv("GRADIO_SERVER_PORT") or os.getenv("PORT")
+ env_port = int(env_port_str) if env_port_str and env_port_str.isdigit() else args.port
+
+ # If running on HF Spaces, bind to 0.0.0.0 unless explicitly overridden
+ if os.getenv("SPACE_ID") and env_host in ("127.0.0.1", "localhost"):
+ env_host = "0.0.0.0"
+
+ blocks = create_gradio_blocks()
+ blocks.launch(server_name=env_host, server_port=env_port, mcp_server=True)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
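
Once the server is running (`python -m mcp_server.server`), the SSE endpoint can be probed from any MCP-compatible client. A hedged sketch using the official `mcp` Python SDK, which is not part of `requirements.txt` (assumes `pip install mcp` and the default host/port):

```python
import asyncio

from mcp import ClientSession
from mcp.client.sse import sse_client


async def main() -> None:
    url = "http://127.0.0.1:7860/gradio_api/mcp/sse"
    async with sse_client(url) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools = await session.list_tools()
            print("exposed tools:", [tool.name for tool in tools.tools])


asyncio.run(main())
```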
diff --git a/mcp_server/tools/__init__.py b/mcp_server/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ffc238b15cc0177fba96ab16b255fb8e4f7c55e
--- /dev/null
+++ b/mcp_server/tools/__init__.py
@@ -0,0 +1 @@
+__all__ = ["list_items", "semantic_search", "get_code"]
\ No newline at end of file
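
Note that `__all__` here lists submodule names rather than re-exported callables, so a star import binds the modules and callers reach each function one attribute deeper; a small illustration (assumes a populated `knowledge_base/`):

```python
# __all__ names submodules, so star-import binds modules, not functions.
from mcp_server.tools import *  # noqa: F401,F403

print(list_items.list_items()[0]["id"])  # module.function, one level deeper
```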
diff --git a/mcp_server/tools/get_code.py b/mcp_server/tools/get_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..d578e3ff447b3817862dadf810bc40b4325748e7
--- /dev/null
+++ b/mcp_server/tools/get_code.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from mcp_server.loader import read_code
+
+
+__all__ = ["get_code"]
+
+
+def get_code(path: str) -> str:
+ """
+ Return full Python source code for a knowledge base file.
+
+ Args:
+ path: Either:
+ - "knowledge_base//.py"
+ - "/.py"
+
+ Raises:
+ ValueError, FileNotFoundError per validation rules.
+
+ Returns:
+ The full UTF-8 text of the Python file.
+ """
+ return read_code(path)
\ No newline at end of file
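
Both accepted path shapes resolve to the same file. An illustrative check, assuming the NLP recipe referenced in the examples exists:

```python
from mcp_server.tools.get_code import get_code

full = get_code("knowledge_base/nlp/text_classification_with_transformer.py")
short = get_code("nlp/text_classification_with_transformer.py")
assert full == short  # the category/filename form is normalized into the KB

print(full.splitlines()[0])  # peek at the first line of the source
```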
diff --git a/mcp_server/tools/list_items.py b/mcp_server/tools/list_items.py
new file mode 100644
index 0000000000000000000000000000000000000000..562aae48ba85a7e97a3d1f19dc4ccc61c4796df6
--- /dev/null
+++ b/mcp_server/tools/list_items.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from typing import List, Dict
+
+from mcp_server.loader import list_items_dict
+
+
+__all__ = ["list_items"]
+
+
+def list_items() -> List[Dict]:
+ """
+ Return all KB items with minimal metadata:
+ [
+ {
+ "id": "nlp/text_classification_with_transformer.py",
+ "category": "nlp",
+ "filename": "text_classification_with_transformer.py",
+ "path": "knowledge_base/nlp/text_classification_with_transformer.py",
+ "summary": "...",
+ },
+ ...
+ ]
+ """
+ return list_items_dict()
\ No newline at end of file
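
The flat list is easy to slice client-side; for instance, a per-category census of the catalog (illustrative only):

```python
from collections import Counter

from mcp_server.tools.list_items import list_items

items = list_items()
print(Counter(item["category"] for item in items).most_common())
```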
diff --git a/mcp_server/tools/semantic_search.py b/mcp_server/tools/semantic_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a02b862d00518744d63b151de1c7e3d6d3d904e
--- /dev/null
+++ b/mcp_server/tools/semantic_search.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from typing import Dict, List
+
+from mcp_server.loader import get_items_for_embedding
+from mcp_server.embeddings import Embedder, EmbeddingIndex, IndexedItem
+
+
+__all__ = ["semantic_search"]
+
+# Lazy singletons
+_embedder: Embedder | None = None
+_index: EmbeddingIndex | None = None
+_built: bool = False
+
+
+def _ensure_index() -> EmbeddingIndex:
+ global _embedder, _index, _built
+ if _embedder is None:
+ _embedder = Embedder()
+ if _index is None:
+ _index = EmbeddingIndex(_embedder)
+ if not _built:
+ pairs = get_items_for_embedding() # List[ (KBItem, text) ]
+ items: List[IndexedItem] = [
+ IndexedItem(
+ id=it.id,
+ category=it.category,
+ filename=it.filename,
+ path=it.path,
+ summary=it.summary,
+ )
+ for it, _ in pairs
+ ]
+ texts: List[str] = [text for _, text in pairs]  # reuse the precomputed embedding text
+ _index.build(items, texts)
+ _built = True
+ return _index
+
+
+def semantic_search(problem_markdown: str) -> Dict:
+ """
+ Return only the best match and its similarity score, e.g.:
+ {
+ "best_match": {"id": ..., "category": ..., "filename": ..., "path": ..., "summary": ...},
+ "score": 0.89
+ }
+ """
+ if not isinstance(problem_markdown, str) or not problem_markdown.strip():
+ raise ValueError("problem_markdown must be a non-empty string")
+ index = _ensure_index()
+ best_item, score = index.search_one(problem_markdown)
+ return {
+ "best_match": best_item,
+ "score": round(float(score), 6),
+ }
\ No newline at end of file
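
The first call pays the one-time cost of embedding every KB item to build the index; later calls only embed the query. A usage sketch (the exact `best_match` payload depends on what `EmbeddingIndex.search_one` returns):

```python
from mcp_server.tools.semantic_search import semantic_search

# First call builds the index; subsequent calls reuse the lazy singletons.
result = semantic_search("Train a GNN on citation networks for node classification.")
print(result["score"], result["best_match"])
```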
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0444f8330e48e1e50f49bfad1aa5ec06a367266
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+# Gradio MCP (remote server only)
+gradio[mcp]>=5.0.0
+
+# Retrieval
+sentence-transformers>=3.0.1,<4
+numpy>=1.26,<3
+# PyTorch backend for sentence-transformers; pick a CPU or CUDA build as appropriate for your system
+torch>=2.1.0
+huggingface_hub>=0.23.0
\ No newline at end of file
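
A minimal import check confirms the pinned stack resolves in the active environment (versions printed are whatever pip selected within the ranges above):

```python
import gradio
import numpy
import sentence_transformers
import torch

for module in (gradio, numpy, sentence_transformers, torch):
    print(f"{module.__name__}=={module.__version__}")
```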