Spaces:
Runtime error
Runtime error
Commit
·
5e33295
1
Parent(s):
177344c
update: LlamaGuardFineTuner
Browse files
.gitignore
CHANGED
|
@@ -168,4 +168,5 @@ temp.txt
|
|
| 168 |
binary-classifier/
|
| 169 |
wandb/
|
| 170 |
artifacts/
|
| 171 |
-
evaluation_results/
|
|
|
|
|
|
| 168 |
binary-classifier/
|
| 169 |
wandb/
|
| 170 |
artifacts/
|
| 171 |
+
evaluation_results/
|
| 172 |
+
checkpoints/
|
application_pages/llama_guard_fine_tuning.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
|
| 4 |
|
| 5 |
|
| 6 |
def initialize_session_state():
|
| 7 |
-
st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
if "dataset_address" not in st.session_state:
|
| 9 |
st.session_state.dataset_address = ""
|
| 10 |
if "train_dataset_range" not in st.session_state:
|
|
@@ -25,6 +31,14 @@ def initialize_session_state():
|
|
| 25 |
st.session_state.evaluation_batch_size = None
|
| 26 |
if "evaluation_temperature" not in st.session_state:
|
| 27 |
st.session_state.evaluation_temperature = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
initialize_session_state()
|
|
@@ -43,18 +57,34 @@ if st.session_state.dataset_address != "":
|
|
| 43 |
st.session_state.train_dataset_range = train_dataset_range
|
| 44 |
st.session_state.test_dataset_range = test_dataset_range
|
| 45 |
|
| 46 |
-
model_name = st.sidebar.
|
| 47 |
-
"Model Name",
|
| 48 |
-
["meta-llama/Prompt-Guard-86M"],
|
| 49 |
)
|
| 50 |
st.session_state.model_name = model_name
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
preview_dataset = st.sidebar.toggle("Preview Dataset")
|
| 53 |
st.session_state.preview_dataset = preview_dataset
|
| 54 |
|
| 55 |
evaluate_model = st.sidebar.toggle("Evaluate Model")
|
| 56 |
st.session_state.evaluate_model = evaluate_model
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
|
| 59 |
st.session_state.load_fine_tuner_button = load_fine_tuner_button
|
| 60 |
|
|
@@ -68,13 +98,19 @@ if st.session_state.dataset_address != "":
|
|
| 68 |
)
|
| 69 |
)
|
| 70 |
st.session_state.llama_guard_fine_tuner.load_model(
|
| 71 |
-
model_name=st.session_state.model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
if st.session_state.preview_dataset:
|
| 74 |
st.session_state.llama_guard_fine_tuner.show_dataset_sample()
|
| 75 |
if st.session_state.evaluate_model:
|
| 76 |
st.session_state.llama_guard_fine_tuner.evaluate_model(
|
| 77 |
-
batch_size=
|
| 78 |
-
|
|
|
|
| 79 |
)
|
| 80 |
st.session_state.is_fine_tuner_loaded = True
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
from guardrails_genie.train.llama_guard import DatasetArgs, LlamaGuardFineTuner
|
| 6 |
|
| 7 |
|
| 8 |
def initialize_session_state():
|
| 9 |
+
st.session_state.llama_guard_fine_tuner = LlamaGuardFineTuner(
|
| 10 |
+
wandb_project=os.getenv("WANDB_PROJECT_NAME"),
|
| 11 |
+
wandb_entity=os.getenv("WANDB_ENTITY_NAME"),
|
| 12 |
+
streamlit_mode=True,
|
| 13 |
+
)
|
| 14 |
if "dataset_address" not in st.session_state:
|
| 15 |
st.session_state.dataset_address = ""
|
| 16 |
if "train_dataset_range" not in st.session_state:
|
|
|
|
| 31 |
st.session_state.evaluation_batch_size = None
|
| 32 |
if "evaluation_temperature" not in st.session_state:
|
| 33 |
st.session_state.evaluation_temperature = None
|
| 34 |
+
if "checkpoint" not in st.session_state:
|
| 35 |
+
st.session_state.checkpoint = None
|
| 36 |
+
if "eval_batch_size" not in st.session_state:
|
| 37 |
+
st.session_state.eval_batch_size = 32
|
| 38 |
+
if "eval_positive_label" not in st.session_state:
|
| 39 |
+
st.session_state.eval_positive_label = 2
|
| 40 |
+
if "eval_temperature" not in st.session_state:
|
| 41 |
+
st.session_state.eval_temperature = 1.0
|
| 42 |
|
| 43 |
|
| 44 |
initialize_session_state()
|
|
|
|
| 57 |
st.session_state.train_dataset_range = train_dataset_range
|
| 58 |
st.session_state.test_dataset_range = test_dataset_range
|
| 59 |
|
| 60 |
+
model_name = st.sidebar.text_input(
|
| 61 |
+
label="Model Name", value="meta-llama/Prompt-Guard-86M"
|
|
|
|
| 62 |
)
|
| 63 |
st.session_state.model_name = model_name
|
| 64 |
|
| 65 |
+
checkpoint = st.sidebar.text_input(label="Fine-tuned Model Checkpoint", value="")
|
| 66 |
+
st.session_state.checkpoint = checkpoint
|
| 67 |
+
|
| 68 |
preview_dataset = st.sidebar.toggle("Preview Dataset")
|
| 69 |
st.session_state.preview_dataset = preview_dataset
|
| 70 |
|
| 71 |
evaluate_model = st.sidebar.toggle("Evaluate Model")
|
| 72 |
st.session_state.evaluate_model = evaluate_model
|
| 73 |
|
| 74 |
+
if st.session_state.evaluate_model:
|
| 75 |
+
eval_batch_size = st.sidebar.slider(
|
| 76 |
+
label="Eval Batch Size", min_value=16, max_value=1024, value=32
|
| 77 |
+
)
|
| 78 |
+
st.session_state.eval_batch_size = eval_batch_size
|
| 79 |
+
|
| 80 |
+
eval_positive_label = st.sidebar.number_input("Eval Positive Label", value=2)
|
| 81 |
+
st.session_state.eval_positive_label = eval_positive_label
|
| 82 |
+
|
| 83 |
+
eval_temperature = st.sidebar.slider(
|
| 84 |
+
label="Eval Temperature", min_value=0.0, max_value=5.0, value=1.0
|
| 85 |
+
)
|
| 86 |
+
st.session_state.eval_temperature = eval_temperature
|
| 87 |
+
|
| 88 |
load_fine_tuner_button = st.sidebar.button("Load Fine-Tuner")
|
| 89 |
st.session_state.load_fine_tuner_button = load_fine_tuner_button
|
| 90 |
|
|
|
|
| 98 |
)
|
| 99 |
)
|
| 100 |
st.session_state.llama_guard_fine_tuner.load_model(
|
| 101 |
+
model_name=st.session_state.model_name,
|
| 102 |
+
checkpoint=(
|
| 103 |
+
None
|
| 104 |
+
if st.session_state.checkpoint == ""
|
| 105 |
+
else st.session_state.checkpoint
|
| 106 |
+
),
|
| 107 |
)
|
| 108 |
if st.session_state.preview_dataset:
|
| 109 |
st.session_state.llama_guard_fine_tuner.show_dataset_sample()
|
| 110 |
if st.session_state.evaluate_model:
|
| 111 |
st.session_state.llama_guard_fine_tuner.evaluate_model(
|
| 112 |
+
batch_size=st.session_state.eval_batch_size,
|
| 113 |
+
positive_label=st.session_state.eval_positive_label,
|
| 114 |
+
temperature=st.session_state.eval_temperature,
|
| 115 |
)
|
| 116 |
st.session_state.is_fine_tuner_loaded = True
|
guardrails_genie/train/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
|
|
| 1 |
from .train_classifier import train_binary_classifier
|
| 2 |
-
from .llama_guard import LlamaGuardFineTuner, DatasetArgs
|
| 3 |
|
| 4 |
-
__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
|
|
|
|
| 1 |
+
from .llama_guard import DatasetArgs, LlamaGuardFineTuner
|
| 2 |
from .train_classifier import train_binary_classifier
|
|
|
|
| 3 |
|
| 4 |
+
__all__ = ["train_binary_classifier", "LlamaGuardFineTuner", "DatasetArgs"]
|
guardrails_genie/train/llama_guard.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
import shutil
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
import streamlit as st
|
|
@@ -7,15 +9,16 @@ import torch
|
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
import torch.optim as optim
|
| 10 |
-
import wandb
|
| 11 |
from datasets import load_dataset
|
| 12 |
from pydantic import BaseModel
|
| 13 |
from rich.progress import track
|
| 14 |
-
from safetensors.torch import save_model
|
| 15 |
from sklearn.metrics import roc_auc_score, roc_curve
|
| 16 |
from torch.utils.data import DataLoader
|
| 17 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
class DatasetArgs(BaseModel):
|
| 21 |
dataset_address: str
|
|
@@ -30,7 +33,7 @@ class LlamaGuardFineTuner:
|
|
| 30 |
classification tasks, specifically for detecting prompt injection attacks. It
|
| 31 |
integrates with Weights & Biases for experiment tracking and optionally
|
| 32 |
displays progress in a Streamlit app.
|
| 33 |
-
|
| 34 |
!!! example "Sample Usage"
|
| 35 |
```python
|
| 36 |
from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
|
|
@@ -98,7 +101,11 @@ class LlamaGuardFineTuner:
|
|
| 98 |
else dataset["test"].select(range(dataset_args.test_dataset_range))
|
| 99 |
)
|
| 100 |
|
| 101 |
-
def load_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
"""
|
| 103 |
Loads the specified pre-trained model and tokenizer for sequence classification tasks.
|
| 104 |
|
|
@@ -118,9 +125,20 @@ class LlamaGuardFineTuner:
|
|
| 118 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 119 |
self.model_name = model_name
|
| 120 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 121 |
-
|
| 122 |
-
self.
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def show_dataset_sample(self):
|
| 126 |
"""
|
|
|
|
| 1 |
import os
|
| 2 |
import shutil
|
| 3 |
+
from glob import glob
|
| 4 |
+
from typing import Optional
|
| 5 |
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import streamlit as st
|
|
|
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
import torch.optim as optim
|
|
|
|
| 12 |
from datasets import load_dataset
|
| 13 |
from pydantic import BaseModel
|
| 14 |
from rich.progress import track
|
| 15 |
+
from safetensors.torch import load_model, save_model
|
| 16 |
from sklearn.metrics import roc_auc_score, roc_curve
|
| 17 |
from torch.utils.data import DataLoader
|
| 18 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 19 |
|
| 20 |
+
import wandb
|
| 21 |
+
|
| 22 |
|
| 23 |
class DatasetArgs(BaseModel):
|
| 24 |
dataset_address: str
|
|
|
|
| 33 |
classification tasks, specifically for detecting prompt injection attacks. It
|
| 34 |
integrates with Weights & Biases for experiment tracking and optionally
|
| 35 |
displays progress in a Streamlit app.
|
| 36 |
+
|
| 37 |
!!! example "Sample Usage"
|
| 38 |
```python
|
| 39 |
from guardrails_genie.train.llama_guard import LlamaGuardFineTuner, DatasetArgs
|
|
|
|
| 101 |
else dataset["test"].select(range(dataset_args.test_dataset_range))
|
| 102 |
)
|
| 103 |
|
| 104 |
+
def load_model(
|
| 105 |
+
self,
|
| 106 |
+
model_name: str = "meta-llama/Prompt-Guard-86M",
|
| 107 |
+
checkpoint: Optional[str] = None,
|
| 108 |
+
):
|
| 109 |
"""
|
| 110 |
Loads the specified pre-trained model and tokenizer for sequence classification tasks.
|
| 111 |
|
|
|
|
| 125 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 126 |
self.model_name = model_name
|
| 127 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 128 |
+
if checkpoint is None:
|
| 129 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 130 |
+
model_name
|
| 131 |
+
).to(self.device)
|
| 132 |
+
else:
|
| 133 |
+
api = wandb.Api()
|
| 134 |
+
artifact = api.artifact(checkpoint.removeprefix("wandb://"))
|
| 135 |
+
artifact_dir = artifact.download()
|
| 136 |
+
model_file_path = glob(os.path.join(artifact_dir, "model-*.safetensors"))[0]
|
| 137 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
| 138 |
+
self.model.classifier = nn.Linear(self.model.classifier.in_features, 2)
|
| 139 |
+
self.model.num_labels = 2
|
| 140 |
+
load_model(self.model, model_file_path)
|
| 141 |
+
self.model = self.model.to(self.device)
|
| 142 |
|
| 143 |
def show_dataset_sample(self):
|
| 144 |
"""
|
guardrails_genie/train/train_classifier.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import evaluate
|
| 2 |
import numpy as np
|
| 3 |
import streamlit as st
|
| 4 |
-
import wandb
|
| 5 |
from datasets import load_dataset
|
| 6 |
from transformers import (
|
| 7 |
AutoModelForSequenceClassification,
|
|
@@ -11,6 +10,7 @@ from transformers import (
|
|
| 11 |
TrainingArguments,
|
| 12 |
)
|
| 13 |
|
|
|
|
| 14 |
from guardrails_genie.utils import StreamlitProgressbarCallback
|
| 15 |
|
| 16 |
|
|
|
|
| 1 |
import evaluate
|
| 2 |
import numpy as np
|
| 3 |
import streamlit as st
|
|
|
|
| 4 |
from datasets import load_dataset
|
| 5 |
from transformers import (
|
| 6 |
AutoModelForSequenceClassification,
|
|
|
|
| 10 |
TrainingArguments,
|
| 11 |
)
|
| 12 |
|
| 13 |
+
import wandb
|
| 14 |
from guardrails_genie.utils import StreamlitProgressbarCallback
|
| 15 |
|
| 16 |
|