feat: improve model download/convert; copy model assets
cdiddy77 committed Oct 5, 2024
1 parent 0636d71 commit 9aa0210
Showing 11 changed files with 1,680 additions and 106 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -94,3 +94,6 @@ ungitable/
# download/convert model files
models/download/
models/converted/

# environment variables
.env
6 changes: 5 additions & 1 deletion example/.gitignore
@@ -1,2 +1,6 @@
vendor/
Pods/
falcon-*.bin
gemma-*.bin
phi-*.bin
stablelm-*.bin
16 changes: 16 additions & 0 deletions example/android/app/build.gradle
@@ -104,6 +104,22 @@ android {
}
}

project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'

def sourceFile = file('../../../models/converted/gemma-2b-it-cpu-int4.bin')
def destinationFile = file("${project.ext.ASSET_DIR}/gemma-2b-it-cpu-int4.bin")

task copyModelFile(type: Copy) {
    onlyIf { !destinationFile.exists() }
    from sourceFile
    into project.ext.ASSET_DIR
    doFirst {
        println "Copying ${sourceFile} to ${destinationFile}"
    }
}

preBuild.dependsOn copyModelFile

dependencies {
// The version of react-native is set by the React Native Gradle Plugin
implementation("com.facebook.react:react-android")
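The `copyModelFile` task stages the converted model into the Android assets directory before every build, skipping the copy when the asset is already present. A quick way to confirm the staging worked (a verification sketch; paths taken from the task above):

```bash
# After an Android build, check that the model landed in the assets dir
ls -lh example/android/app/src/main/assets/gemma-2b-it-cpu-int4.bin
```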
20 changes: 20 additions & 0 deletions example/ios/LlmMediapipeExample.xcodeproj/project.pbxproj
@@ -180,6 +180,7 @@
buildConfigurationList = 13B07F931A680F5B00A75B9A /* Build configuration list for PBXNativeTarget "LlmMediapipeExample" */;
buildPhases = (
C38B50BA6285516D6DCD4F65 /* [CP] Check Pods Manifest.lock */,
37D0E69A2C37215F007DB3A5 /* Copy Gemma Model */,
13B07F871A680F5B00A75B9A /* Sources */,
13B07F8C1A680F5B00A75B9A /* Frameworks */,
13B07F8E1A680F5B00A75B9A /* Resources */,
@@ -286,6 +287,25 @@
shellScript = "\"${PODS_ROOT}/Target Support Files/Pods-LlmMediapipeExample/Pods-LlmMediapipeExample-frameworks.sh\"\n";
showEnvVarsInLog = 0;
};
37D0E69A2C37215F007DB3A5 /* Copy Gemma Model */ = {
isa = PBXShellScriptBuildPhase;
alwaysOutOfDate = 1;
buildActionMask = 2147483647;
files = (
);
inputFileListPaths = (
);
inputPaths = (
);
name = "Copy Gemma Model";
outputFileListPaths = (
);
outputPaths = (
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "$SRCROOT/RunScripts/copy_model.sh\n";
};
A55EABD7B0C7F3A422A6CC61 /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
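The new build phase invokes `RunScripts/copy_model.sh` directly via `/bin/sh`, so the script file needs its executable bit set. If the phase fails with a permission error, this should fix it (an assumption; the bit may already be tracked in git):

```bash
chmod +x example/ios/RunScripts/copy_model.sh
```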
10 changes: 10 additions & 0 deletions example/ios/RunScripts/copy_model.sh
@@ -0,0 +1,10 @@
echo "INFO: Copying model for iOS."

# Copy gemma-2b-it-cpu-int4.bin from the models built folder if it doesn't exist.
TFLITE_FILE=./gemma-2b-it-cpu-int4.bin
if test -f "$TFLITE_FILE"; then
  echo "INFO: gemma-2b-it-cpu-int4.bin exists. Skip downloading and use the local model."
else
  cp ../../models/converted/gemma-2b-it-cpu-int4.bin ${TFLITE_FILE}
  echo "INFO: Copied gemma-2b-it-cpu-int4.bin to $TFLITE_FILE ."
fi
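Note that the script exits successfully even when the source model is missing, and the "Skip downloading" message really refers to skipping the copy. A slightly hardened sketch of the same script, assuming the same working directory and relative layout:

```bash
#!/bin/sh
set -e

echo "INFO: Copying model for iOS."

TFLITE_FILE=./gemma-2b-it-cpu-int4.bin
SRC_FILE=../../models/converted/gemma-2b-it-cpu-int4.bin

if test -f "$TFLITE_FILE"; then
  echo "INFO: $TFLITE_FILE exists. Skipping copy."
elif test -f "$SRC_FILE"; then
  cp "$SRC_FILE" "$TFLITE_FILE"
  echo "INFO: Copied $SRC_FILE to $TFLITE_FILE."
else
  echo "ERROR: $SRC_FILE not found. Run the model build in models/ first." >&2
  exit 1
fi
```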
1 change: 1 addition & 0 deletions models/.env.example
@@ -0,0 +1 @@
HF_API_TOKEN="YOUR_HUGGINGFACE_API_TOKEN"
53 changes: 13 additions & 40 deletions models/README.md
@@ -4,68 +4,41 @@ This README provides detailed instructions on how to build model files using a Python script

## Prerequisites

- Python 3.10 installed on your system (tested with `pyenv`)
- Poetry installed
- Command line access (Terminal on macOS and Linux, CMD or PowerShell on Windows)
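If you go the `pyenv` route, pinning the interpreter for this directory might look like the following (a sketch; any Python 3.10 installation works, and the exact patch version is an assumption):

```bash
pyenv install 3.10.14
pyenv local 3.10.14   # writes .python-version so this directory uses 3.10
```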

## Setup

```bash
cd models
poetry install
```

## Building the Models

1. **Run the Build Script**:
Execute the `build.py` script located in the models directory to start the model building process:

```bash
poetry run python build.py
```

This script will download and process model files as necessary in accordance with the [instructions on the MediaPipe website](https://developers.google.com/mediapipe/solutions/genai/llm_inference#models). Ensure that no errors are displayed in the command line output.

> NOTE: the Gemma model requires that you accept its terms and conditions. Go to the
> [model page](https://huggingface.co/google/gemma-2b-it) and accept the agreement, then put your Hugging Face API token in `models/.env` as `HF_API_TOKEN` (see `models/.env.example`).
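For example, starting from the checked-in template (the token value is a placeholder):

```bash
cd models
cp .env.example .env
# edit .env and replace YOUR_HUGGINGFACE_API_TOKEN with your real token
```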

> NOTE: as of 7/4/2024, I was unable to convert the Gemma model successfully. You can also
> download the converted model directly from Kaggle ([cpu](https://www.kaggle.com/models/google/gemma/tfLite/gemma-2b-it-cpu-int4), [gpu](https://www.kaggle.com/models/google/gemma/tfLite/gemma-2b-it-gpu-int4)).
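If you take the Kaggle route instead, place the downloaded file where the copy steps above expect it; the source path below is an assumption about where your browser saved it:

```bash
mkdir -p models/converted
mv ~/Downloads/gemma-2b-it-cpu-int4.bin models/converted/
```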

2. **Verify Output**:
If the script runs successfully, the built or converted models will be located in the `models/converted` directory. Verify that the models exist in this directory.

```bash
ls models/converted
```

## Troubleshooting

If you encounter any errors during the setup or build process, please ensure that:

- Python 3.10 is properly installed and the path is correctly configured (we prefer to use `pyenv`)
- All commands are run through Poetry (e.g. `poetry run python build.py`) or inside the Poetry-managed virtual environment.

## Conclusion

You have successfully set up your environment and run the build script to download and convert model files. These models are now ready to be copied to the appropriate asset directories in your target application.
78 changes: 52 additions & 26 deletions models/build.py
@@ -3,20 +3,35 @@
import sys
import mediapipe as mp
from mediapipe.tasks.python.genai import converter
import dotenv

dotenv.load_dotenv()

def download_files(
    url_path_pairs, chunk_size=1024 * 1024, force=False
):  # Default chunk size: 1 MB
    for url, path in url_path_pairs:
        # Ensure the directory exists
        os.makedirs(os.path.dirname(path), exist_ok=True)

        if not force and os.path.exists(path):
            print(f"File '{path}' already exists. Skipping download.")
            continue

        headers = (
            {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"}
            if os.getenv("HF_API_TOKEN")
            else {}
        )

        try:
            # Stream the download to handle large files
            with requests.get(url, stream=True, headers=headers) as response:
                response.raise_for_status()  # Raise an exception for HTTP errors

                # Write the content of the response to a file in chunks
                with open(path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        # filter out keep-alive new chunks
                        if chunk:
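Each entry in `url_path_pairs` boils down to a streamed, optionally token-authenticated HTTP download. For reference, fetching one file by hand looks roughly like this (a sketch; the URL follows the same pattern as the pairs below):

```bash
mkdir -p download/gemma-2b-it
curl -L \
  -H "Authorization: Bearer $HF_API_TOKEN" \
  -o download/gemma-2b-it/model-00001-of-00002.safetensors \
  "https://huggingface.co/google/gemma-2b-it/resolve/main/model-00001-of-00002.safetensors?download=true"
```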
@@ -33,14 +48,14 @@ def convert_models(convert_jobs):
        # Convert the model
        print(f"Converting model: {job['title']}")
        config = converter.ConversionConfig(
            input_ckpt=job["input_ckpt"],
            ckpt_format=job["ckpt_format"],
            model_type=job["model_type"],
            backend=job["backend"],
            output_dir=job["output_dir"],
            combine_file_only=False,
            vocab_model_file=job["vocab_model_file"],
            output_tflite_file=job["output_tflite_file"],
        )

        converter.convert_checkpoint(config)
@@ -66,7 +81,7 @@ def convert_models(convert_jobs):
"download/gemma-2b-it/model-00001-of-00002.safetensors",
),
(
"https://huggingface.co/google/gemma-2b-it/resolve/main/model-00002-of-00002.safetensors.safetensors?download=true",
"https://huggingface.co/google/gemma-2b-it/resolve/main/model-00002-of-00002.safetensors?download=true",
"download/gemma-2b-it/model-00002-of-00002.safetensors.safetensors",
),
(
@@ -132,12 +147,12 @@ def create_convert_params(name, backend, model_type, ckpt, ckpt_format):
    create_convert_params(
        "falcon-rw-1b", "cpu", "FALCON_RW_1B", "pytorch_model.bin", "pytorch"
    ),
    # create_convert_params(
    #     "gemma-2b-it", "gpu", "GEMMA_2B", "model-*.safetensors", "safetensors"
    # ),
    # create_convert_params(
    #     "gemma-2b-it", "cpu", "GEMMA_2B", "model-*.safetensors", "safetensors"
    # ),
    create_convert_params(
        "stablelm-3b-4e1t",
        "gpu",
@@ -161,13 +176,24 @@ def create_convert_params(name, backend, model_type, ckpt, ckpt_format):
]

if __name__ == "__main__":
if len(sys.argv) == 1:
# No command-line arguments were provided, run both functions
download_files(url_path_pairs)
import argparse

parser = argparse.ArgumentParser(description="Download and convert models.")
parser.add_argument("--download", action="store_true", help="Download files")
parser.add_argument("--convert", action="store_true", help="Convert models")
parser.add_argument(
"--force",
action="store_true",
help="Force download even if files already exist",
)

args = parser.parse_args()

if args.download:
download_files(url_path_pairs, force=args.force)
if args.convert:
convert_models(convert_jobs)

if not args.download and not args.convert:
download_files(url_path_pairs, force=args.force)
convert_models(convert_jobs)
else:
# Process command-line arguments
if "--download" in sys.argv:
download_files(url_path_pairs)
if "--convert" in sys.argv:
convert_models(convert_jobs)
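With the argument parsing above, typical invocations from the `models` directory look like this:

```bash
poetry run python build.py                      # no flags: download missing files, then convert
poetry run python build.py --download           # download only, skipping existing files
poetry run python build.py --download --force   # re-download even if files exist
poetry run python build.py --convert            # convert only
```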