Update readme with instructions on how to change the kernels. (#6)
- Update readme with instructions on how to change the kernels. (044aabbd6dbabb4c49e06f3fd961b656b9c8d734)
- fix import (8b0f5773ecf168d0eb6ad229bd6846148725e1d9)

Co-authored-by: Maximilian Beck <[email protected]>

README.md CHANGED

---
license: other
---

# xLSTM-7B

This xLSTM-7B was pre-trained on DCLM and selected high-quality data for a total of approx. 2.3 T tokens using the `xlstm-jax` framework.

## How to use it

First, install `xlstm`, which now uses the `mlstm_kernels` package for Triton kernels:

```bash
pip install xlstm
pip install mlstm_kernels
```
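
To confirm the installation, both packages should be importable; a minimal check (assuming the import names match the pip package names above):

```python
# Minimal install check; assumes the two pip installs above succeeded.
import xlstm
import mlstm_kernels

print("xlstm and mlstm_kernels imported successfully")
```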

For now, install the `transformers` repository fork from NX-AI (until the xLSTM integration is merged upstream):

```bash
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
```
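
If you do not have SSH access to GitHub configured, the equivalent HTTPS form should work as well (an assumption based on pip's standard git URL schemes, not stated in the original instructions): `pip install 'transformers @ git+https://github.com/NX-AI/transformers.git@integrate_xlstm'`.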

Use this model as:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

# the tokenizer is a fork of EleutherAI/gpt-neox-20b's tokenizer
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors="pt")["input_ids"].to(device="cuda")

out = xlstm.generate(tokens, max_new_tokens=20)

print(tokenizer.decode(out[0]))
```
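
With `device_map="auto"` the model's weights may be placed on any available device, so `tokens.to(xlstm.device)` can be a safer alternative to hard-coding `"cuda"`; this relies on the generic `PreTrainedModel.device` property, not on anything xLSTM-specific.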

If you cannot or do not want to use the Triton kernels, you can switch them to native PyTorch implementations:

```python
from transformers import AutoConfig, AutoModelForCausalLM

xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")
xlstm_config.step_kernel = "native"
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
xlstm_config.sequence_kernel = "native_sequence__native"

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", config=xlstm_config, device_map="auto")

# verify the selected kernels
from pprint import pprint
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
```
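
To gauge what a kernel choice costs on your hardware, a small timing helper can be used. The following is a sketch, not part of the model card: it assumes a CUDA device and reuses the `xlstm` model and `tokens` from the snippets above.

```python
import time

import torch

def time_generate(model, input_ids, n_tokens=64):
    """Time a single generate() call, synchronizing the GPU around it."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(input_ids, max_new_tokens=n_tokens)
    torch.cuda.synchronize()
    return time.perf_counter() - start

# Run once per kernel configuration (e.g. Triton vs. native) and compare.
print(f"{time_generate(xlstm, tokens):.2f}s for 64 new tokens")
```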

## Speed results

Generation speed using `torch.cuda.graph` and `torch.compile` optimizations on one NVIDIA H100:

*(generation speed plot)*

## Performance

*(benchmark results plot)*

Using `lm_eval` (EleutherAI's LM Evaluation Harness):

| BBH   | MMLU-Pro | Math  | MUSR  | GPQA  | IfEval |
|-------|----------|-------|-------|-------|--------|
| 0.381 | 0.242    | 0.036 | 0.379 | 0.280 | 0.244  |
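
These numbers can in principle be reproduced via the harness's Python API; a hedged sketch (the task names are assumptions, not taken from this card, and should be checked against the harness's task list):

```python
# Reproduction sketch using lm-evaluation-harness (pip install lm-eval).
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=NX-AI/xLSTM-7b,dtype=bfloat16",
    tasks=["bbh", "mmlu_pro", "musr", "gpqa", "ifeval"],  # assumed task names
)
print(results["results"])
```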

Using HuggingFace's `lighteval` in the Leaderboard-v1 settings:

| Arc-Challenge (25-shot) | MMLU (5-shot) | Hellaswag (10-shot) | Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8k (5-shot) | OpenbookQA (5-shot) | PiQA (5-shot) |
|-------------------------|---------------|---------------------|---------------------|---------------------|----------------|---------------------|---------------|
| 0.584                   | 0.589         | 0.710               | 0.742               | 0.420               | 0.004          | 0.443               | 0.817         |

## License

NXAI Community License (see the `LICENSE` file).
