from modules import Detector, IterativeSanitizer, Classifier


class Guard:
    """Wraps a model callable with input detection, optional attack
    classification, and optional sanitization."""

    def __init__(self, fn):
        self.fn = fn
        self.detector = Detector(binary=True)
        self.sanitizer = IterativeSanitizer()
        self.classifier = Classifier()
    def __call__(self, inp, classifier=False, sanitizer=False):
        output = {
            "safe": [],
            "class": [],
            "sanitized": [],
        }
        # Accept a single prompt as well as a batch of prompts.
        if isinstance(inp, str):
            inp = [inp]
        # Binary vulnerability predictions, one per prompt, e.g. [0, 1, 1, 1, 0, 0]
        vuln = self.detector.forward(inp)
        v = vuln[0]
        output["safe"].append(v == 0)
        if v == 0:
            # Safe input: forward it to the wrapped model unchanged.
            output["class"].append("safe input (no classification)")
            output["sanitized"].append("safe input (no sanitization)")
            response = self.fn.forward(inp[0])
        else:  # v == 1 -> unsafe input
            if classifier:
                # Optionally label the type of attack.
                classification = self.classifier.forward(inp)
                output["class"].append(classification)
            if sanitizer:
                # Rewrite the unsafe prompt and query the model with the
                # sanitized version instead.
                sanitized = self.sanitizer.forward(inp)
                output["sanitized"].append(sanitized)
                response = self.fn.forward(sanitized)
            else:
                response = "Sorry, this is detected as a dangerous input."
        return response, output
| """ | |
| actual call: | |
| gpt = GPT() | |
| out = gpt(inp) | |
| llm = Guard(llm) | |
| print(llm("what is the meaning of life?")) | |
| """ |