import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

from peft import BOFTConfig, PeftModel, get_peft_model
from peft.utils import infer_device


class TestBoft:
    device = infer_device()

    def test_boft_state_dict(self, tmp_path):
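        # Check that the boft_P buffer is not saved in the checkpoint file and is not needed to load the adapter
        # correctly.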
        torch.manual_seed(0)

        inputs = torch.arange(10).view(-1, 1).to(self.device)
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
        model.eval()
        output_base = model(inputs).logits

        config = BOFTConfig(init_weights=False)
        model = get_peft_model(model, config)
        model.eval()
        output_peft = model(inputs).logits

        atol, rtol = 1e-5, 1e-8

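        # sanity check: with init_weights=False, the BOFT adapter changes the output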
        assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)

        model.save_pretrained(tmp_path)
        del model

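        # the boft_P buffer must not be part of the saved state dict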
        state_dict = load_file(tmp_path / "adapter_model.safetensors")
        assert not any("boft_P" in key for key in state_dict)

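        # loading the saved adapter must reproduce the same output as before saving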
        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
        model = PeftModel.from_pretrained(model, tmp_path)
        output_loaded = model(inputs).logits
        assert torch.allclose(output_peft, output_loaded, atol=atol, rtol=rtol)

    def test_boft_old_checkpoint_including_boft_P(self, tmp_path):
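        # Check that an old checkpoint whose state dict still contains the boft_P buffer can be loaded correctly and
        # produces the same output as a freshly initialized adapter.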
        torch.manual_seed(0)

        inputs = torch.arange(10).view(-1, 1).to(self.device)
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)

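        # create the expected output with a freshly initialized BOFT adapter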
        config = BOFTConfig(init_weights=False)
        model = get_peft_model(model, config)
        model.eval()
        output_peft = model(inputs).logits
        del model

        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)

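        # old checkpoint created with PEFT v0.12, whose state dict still contains boft_P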
        hub_id = "peft-internal-testing/boft-tiny-opt-peft-v0.12"
        model = PeftModel.from_pretrained(model, hub_id)
        output_old = model(inputs).logits

        atol, rtol = 1e-5, 1e-8
        assert torch.allclose(output_peft, output_old, atol=atol, rtol=rtol)