Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Commit 
							
							·
						
						8f9f548
	
1
								Parent(s):
							
							52e8ef5
								
Deploy to HF Space
Browse files- __pycache__/configs.cpython-39.pyc +0 -0
 - __pycache__/fastapi_app.cpython-39.pyc +0 -0
 - __pycache__/models.cpython-39.pyc +0 -0
 - __pycache__/utils.cpython-39.pyc +0 -0
 - codereviewerapp.py +153 -0
 - configs.py +252 -0
 - models.py +208 -0
 - utils.py +823 -0
 
    	
        __pycache__/configs.cpython-39.pyc
    ADDED
    
    | 
         Binary file (5.38 kB). View file 
     | 
| 
         | 
    	
        __pycache__/fastapi_app.cpython-39.pyc
    ADDED
    
    | 
         Binary file (4.81 kB). View file 
     | 
| 
         | 
    	
        __pycache__/models.cpython-39.pyc
    ADDED
    
    | 
         Binary file (6.67 kB). View file 
     | 
| 
         | 
    	
        __pycache__/utils.cpython-39.pyc
    ADDED
    
    | 
         Binary file (27.8 kB). View file 
     | 
| 
         | 
    	
        codereviewerapp.py
    ADDED
    
    | 
         @@ -0,0 +1,153 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from fastapi import FastAPI, Request, Form
         
     | 
| 2 | 
         
            +
            from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
         
     | 
| 3 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 4 | 
         
            +
            from typing import List
         
     | 
| 5 | 
         
            +
            from clearml import Model
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
            from configs import add_args
         
     | 
| 8 | 
         
            +
            from models import build_or_load_gen_model
         
     | 
| 9 | 
         
            +
            import argparse
         
     | 
| 10 | 
         
            +
            from argparse import Namespace
         
     | 
| 11 | 
         
            +
            import os
         
     | 
| 12 | 
         
            +
            from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
         
     | 
# Fixed model input width; every encoded sequence is padded/truncated to this.
MAX_SOURCE_LENGTH = 512

def pad_assert(tokenizer, source_ids):
    """Truncate *source_ids*, wrap it in BOS/EOS, and right-pad to MAX_SOURCE_LENGTH.

    Two slots are reserved for the BOS/EOS sentinels, so at most
    MAX_SOURCE_LENGTH - 2 of the input ids survive truncation.
    """
    trimmed = source_ids[: MAX_SOURCE_LENGTH - 2]
    wrapped = [tokenizer.bos_id, *trimmed, tokenizer.eos_id]
    padding = [tokenizer.pad_id] * (MAX_SOURCE_LENGTH - len(wrapped))
    padded = wrapped + padding
    assert len(padded) == MAX_SOURCE_LENGTH, "Not equal length."
    return padded
| 23 | 
         
            +
             
     | 
# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    """Encode *code* and *comment* into a fixed-length id sequence.

    Layout before final padding: [BOS] code_ids [EOS] [MSG] comment_ids.
    The combined sequence is then truncated, re-wrapped in BOS/EOS, and
    padded to MAX_SOURCE_LENGTH by pad_assert (this duplicated that logic
    inline before; it now delegates to the shared helper).
    """
    # [1:-1] strips the BOS/EOS the tokenizer adds to each piece.
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # pad_assert truncates, wraps in BOS/EOS, pads, and length-checks —
    # byte-for-byte the same steps the previous inline copy performed.
    return pad_assert(tokenizer, source_ids)
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
            # Load base model architecture and tokenizer from HuggingFace
         
     | 
| 42 | 
         
            +
            BASE_MODEL_NAME = "microsoft/codereviewer"
         
     | 
| 43 | 
         
            +
            args = Namespace(
         
     | 
| 44 | 
         
            +
                model_name_or_path=BASE_MODEL_NAME,
         
     | 
| 45 | 
         
            +
                load_model_path=None,
         
     | 
| 46 | 
         
            +
                # Add other necessary default arguments if build_or_load_gen_model requires them
         
     | 
| 47 | 
         
            +
            )
         
     | 
| 48 | 
         
            +
            print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
         
     | 
| 49 | 
         
            +
            config, base_model, tokenizer = build_or_load_gen_model(args)
         
     | 
| 50 | 
         
            +
            print("Base model architecture and tokenizer loaded.")
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
            # Download the fine-tuned weights from ClearML
         
     | 
| 53 | 
         
            +
            CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
         
     | 
| 54 | 
         
            +
            model_obj = Model(model_id=CLEARML_MODEL_ID)
         
     | 
| 55 | 
         
            +
            finetuned_weights_path = model_obj.get_local_copy()
         
     | 
| 56 | 
         
            +
            adapter_dir = os.path.dirname(finetuned_weights_path)
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
            # Create LoRA configuration matching the fine-tuned checkpoint
         
     | 
| 61 | 
         
            +
            lora_cfg = LoraConfig(
         
     | 
| 62 | 
         
            +
                r=64,
         
     | 
| 63 | 
         
            +
                lora_alpha=128,
         
     | 
| 64 | 
         
            +
                target_modules=["q", "wo", "wi", "v", "o", "k"],
         
     | 
| 65 | 
         
            +
                lora_dropout=0.05,
         
     | 
| 66 | 
         
            +
                bias="none",
         
     | 
| 67 | 
         
            +
                task_type="SEQ_2_SEQ_LM"
         
     | 
| 68 | 
         
            +
            )
         
     | 
| 69 | 
         
            +
            # Wrap base model with PEFT LoRA
         
     | 
| 70 | 
         
            +
            peft_model = get_peft_model(base_model, lora_cfg)
         
     | 
| 71 | 
         
            +
            # Load adapter-only weights and merge into base
         
     | 
| 72 | 
         
            +
            adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
         
     | 
| 73 | 
         
            +
            peft_model.load_state_dict(adapter_state, strict=False)
         
     | 
| 74 | 
         
            +
            model = peft_model.merge_and_unload()
         
     | 
| 75 | 
         
            +
            print("Merged base model with LoRA adapters.")
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
            model.to("cpu")
         
     | 
| 78 | 
         
            +
            model.eval()
         
     | 
| 79 | 
         
            +
            print("Model ready for inference.")
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
            app = FastAPI()
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
            last_payload = {"comment": "", "files": []}
         
     | 
| 84 | 
         
            +
            last_infer_result = {"generated_code": ""}
         
     | 
| 85 | 
         
            +
             
     | 
class FileContent(BaseModel):
    """One file attached to a PR payload."""
    filename: str  # file path/name as supplied by the client
    content: str   # full text content of the file
class PRPayload(BaseModel):
    """Request body for /pr-comments: a review comment plus the PR's files."""
    comment: str
    files: List[FileContent]
class InferenceRequest(BaseModel):
    """Request body for /infer — same shape as PRPayload.

    NOTE(review): structurally identical to PRPayload; kept separate,
    presumably so the two endpoints can diverge later.
    """
    comment: str
    files: List[FileContent]
@app.get("/")
def root():
    """Liveness endpoint: confirms the service is up."""
    status_message = "FastAPI PR comment service is running"
    return {"message": status_message}
@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    """Cache the posted PR payload and acknowledge receipt.

    The stored payload is what /show renders; the "redirect" field tells
    the client where to view it.
    """
    global last_payload
    last_payload = payload.dict()
    response_body = {
        "status": "received",
        "payload": last_payload,
        "redirect": "/show",
    }
    # Return the received payload as JSON and also redirect to /show
    return JSONResponse(content=response_body)
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    """Render the last received PR comment and its files as HTML.

    The payload originates from an external client, so every interpolated
    string is escaped to prevent HTML/script injection into the page
    (previously the raw text was embedded unescaped).
    """
    import html as _html  # stdlib; local import keeps the module header untouched

    page = f"<h2>Received Comment</h2><p>{_html.escape(last_payload['comment'])}</p><hr>"
    for file in last_payload["files"]:
        page += (
            f"<h3>{_html.escape(file['filename'])}</h3>"
            f"<pre>{_html.escape(file['content'])}</pre><hr>"
        )
    return page
@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate a review message from the first file's content and the comment.

    Only the first file in the request is used as code context; any others
    are ignored. The decoded output is returned as
    {"generated_code": <text>} and cached for /show-infer.
    """
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())

    # Use the first file's content as code context (empty string if no files).
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    # Attention mask: attend to every non-pad position.
    inputs_mask = inputs.ne(tokenizer.pad_id)

    # Inference only — no_grad avoids building autograd graphs, cutting
    # memory use during beam-search generation (was missing before).
    with torch.no_grad():
        preds = model.generate(
            inputs,
            attention_mask=inputs_mask,
            use_cache=True,
            num_beams=5,
            early_stopping=True,
            max_length=100,
            num_return_sequences=1,
        )

    pred = preds[0].cpu().numpy()
    # Skip the first two generated ids before decoding — presumably the
    # decoder-start/special tokens; TODO confirm against the tokenizer.
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
    """Render the last /infer output as HTML.

    Model output is free-form text, so it is escaped before being placed
    inside the <pre> element (previously embedded unescaped).
    """
    import html as _html  # stdlib; local import keeps the module header untouched

    generated = _html.escape(last_infer_result["generated_code"])
    return f"<h2>Generated Message</h2><pre>{generated}</pre>"
if __name__ == "__main__":
    # Intentionally empty: the app is served by uvicorn, which imports this
    # module directly and never runs this guard.
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
    	
        configs.py
    ADDED
    
    | 
         @@ -0,0 +1,252 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import random
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
            import logging
         
     | 
| 4 | 
         
            +
            import multiprocessing
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            def add_args(parser):
         
     | 
| 11 | 
         
            +
                parser.add_argument(
         
     | 
| 12 | 
         
            +
                    "--task",
         
     | 
| 13 | 
         
            +
                    type=str,
         
     | 
| 14 | 
         
            +
                    required=False,
         
     | 
| 15 | 
         
            +
                    choices=[
         
     | 
| 16 | 
         
            +
                        "review",
         
     | 
| 17 | 
         
            +
                    ],
         
     | 
| 18 | 
         
            +
                )
         
     | 
| 19 | 
         
            +
                parser.add_argument(
         
     | 
| 20 | 
         
            +
                    "--model_type",
         
     | 
| 21 | 
         
            +
                    default="codet5",
         
     | 
| 22 | 
         
            +
                    type=str,
         
     | 
| 23 | 
         
            +
                    choices=["roberta", "t5", "bart", "codet5", "scratch"],
         
     | 
| 24 | 
         
            +
                )
         
     | 
| 25 | 
         
            +
                parser.add_argument("--add_lang_ids", action="store_true")
         
     | 
| 26 | 
         
            +
                parser.add_argument("--from_scratch", action="store_true")
         
     | 
| 27 | 
         
            +
                parser.add_argument("--debug", action="store_true")
         
     | 
| 28 | 
         
            +
                parser.add_argument("--start_epoch", default=0, type=int)
         
     | 
| 29 | 
         
            +
                parser.add_argument("--train_epochs", default=10, type=int)
         
     | 
| 30 | 
         
            +
                parser.add_argument("--tokenizer_path", type=str, required=False)
         
     | 
| 31 | 
         
            +
                
         
     | 
| 32 | 
         
            +
                parser.add_argument(
         
     | 
| 33 | 
         
            +
                    "--output_dir",
         
     | 
| 34 | 
         
            +
                    default=None,
         
     | 
| 35 | 
         
            +
                    type=str,
         
     | 
| 36 | 
         
            +
                    required=False,
         
     | 
| 37 | 
         
            +
                    help="The output directory where the model predictions and checkpoints will be written.",
         
     | 
| 38 | 
         
            +
                )
         
     | 
| 39 | 
         
            +
                parser.add_argument(
         
     | 
| 40 | 
         
            +
                    "--load_model_path",
         
     | 
| 41 | 
         
            +
                    default=None,
         
     | 
| 42 | 
         
            +
                    type=str,
         
     | 
| 43 | 
         
            +
                    required=False
         
     | 
| 44 | 
         
            +
                )
         
     | 
| 45 | 
         
            +
                parser.add_argument(
         
     | 
| 46 | 
         
            +
                    "--model_name_or_path",
         
     | 
| 47 | 
         
            +
                    default=None,
         
     | 
| 48 | 
         
            +
                    type=str,
         
     | 
| 49 | 
         
            +
                    help="Path to trained model: Should contain the .bin files",
         
     | 
| 50 | 
         
            +
                )
         
     | 
| 51 | 
         
            +
                ## Other parameters
         
     | 
| 52 | 
         
            +
                parser.add_argument(
         
     | 
| 53 | 
         
            +
                    "--train_path",
         
     | 
| 54 | 
         
            +
                    default=None,
         
     | 
| 55 | 
         
            +
                    type=str,
         
     | 
| 56 | 
         
            +
                    help="The pretrain files path. Should contain the .jsonl files for this task.",
         
     | 
| 57 | 
         
            +
                )
         
     | 
| 58 | 
         
            +
                parser.add_argument(
         
     | 
| 59 | 
         
            +
                    "--eval_chunkname",
         
     | 
| 60 | 
         
            +
                    default=None,
         
     | 
| 61 | 
         
            +
                    type=str,
         
     | 
| 62 | 
         
            +
                    help="The eval file name.",
         
     | 
| 63 | 
         
            +
                )
         
     | 
| 64 | 
         
            +
                parser.add_argument(
         
     | 
| 65 | 
         
            +
                    "--train_filename",
         
     | 
| 66 | 
         
            +
                    default=None,
         
     | 
| 67 | 
         
            +
                    type=str,
         
     | 
| 68 | 
         
            +
                    help="The train filename. Should contain the .jsonl files for this task.",
         
     | 
| 69 | 
         
            +
                )
         
     | 
| 70 | 
         
            +
                parser.add_argument(
         
     | 
| 71 | 
         
            +
                    "--dev_filename",
         
     | 
| 72 | 
         
            +
                    default=None,
         
     | 
| 73 | 
         
            +
                    type=str,
         
     | 
| 74 | 
         
            +
                    help="The dev filename. Should contain the .jsonl files for this task.",
         
     | 
| 75 | 
         
            +
                )
         
     | 
| 76 | 
         
            +
                parser.add_argument(
         
     | 
| 77 | 
         
            +
                    "--test_filename",
         
     | 
| 78 | 
         
            +
                    default=None,
         
     | 
| 79 | 
         
            +
                    type=str,
         
     | 
| 80 | 
         
            +
                    help="The test filename. Should contain the .jsonl files for this task.",
         
     | 
| 81 | 
         
            +
                )
         
     | 
| 82 | 
         
            +
                parser.add_argument(
         
     | 
| 83 | 
         
            +
                    "--gold_filename",
         
     | 
| 84 | 
         
            +
                    default=None,
         
     | 
| 85 | 
         
            +
                    type=str,
         
     | 
| 86 | 
         
            +
                    help="The gold filename. Should contain the .jsonl files for this task.",
         
     | 
| 87 | 
         
            +
                )
         
     | 
| 88 | 
         
            +
                parser.add_argument(
         
     | 
| 89 | 
         
            +
                    "--config_name",
         
     | 
| 90 | 
         
            +
                    default="Salesforce/codet5-base",
         
     | 
| 91 | 
         
            +
                    type=str,
         
     | 
| 92 | 
         
            +
                    help="Pretrained config name or path if not the same as model_name",
         
     | 
| 93 | 
         
            +
                )
         
     | 
| 94 | 
         
            +
                parser.add_argument(
         
     | 
| 95 | 
         
            +
                    "--max_source_length",
         
     | 
| 96 | 
         
            +
                    default=64,
         
     | 
| 97 | 
         
            +
                    type=int,
         
     | 
| 98 | 
         
            +
                    help="The maximum total source sequence length after tokenization. Sequences longer "
         
     | 
| 99 | 
         
            +
                    "than this will be truncated, sequences shorter will be padded.",
         
     | 
| 100 | 
         
            +
                )
         
     | 
| 101 | 
         
            +
                parser.add_argument(
         
     | 
| 102 | 
         
            +
                    "--max_target_length",
         
     | 
| 103 | 
         
            +
                    default=32,
         
     | 
| 104 | 
         
            +
                    type=int,
         
     | 
| 105 | 
         
            +
                    help="The maximum total target sequence length after tokenization. Sequences longer "
         
     | 
| 106 | 
         
            +
                    "than this will be truncated, sequences shorter will be padded.",
         
     | 
| 107 | 
         
            +
                )
         
     | 
| 108 | 
         
            +
                parser.add_argument(
         
     | 
| 109 | 
         
            +
                    "--do_train", action="store_true", help="Whether to run eval on the train set."
         
     | 
| 110 | 
         
            +
                )
         
     | 
| 111 | 
         
            +
                parser.add_argument(
         
     | 
| 112 | 
         
            +
                    "--do_eval", action="store_true", help="Whether to run eval on the dev set."
         
     | 
| 113 | 
         
            +
                )
         
     | 
| 114 | 
         
            +
                parser.add_argument(
         
     | 
| 115 | 
         
            +
                    "--do_test", action="store_true", help="Whether to run eval on the dev set."
         
     | 
| 116 | 
         
            +
                )
         
     | 
| 117 | 
         
            +
                parser.add_argument(
         
     | 
| 118 | 
         
            +
                    "--raw_input", action="store_true", help="Whether to use simple input format (set for baselines)."
         
     | 
| 119 | 
         
            +
                )
         
     | 
| 120 | 
         
            +
                parser.add_argument(
         
     | 
| 121 | 
         
            +
                    "--do_lower_case",
         
     | 
| 122 | 
         
            +
                    action="store_true",
         
     | 
| 123 | 
         
            +
                    help="Set this flag if you are using an uncased model.",
         
     | 
| 124 | 
         
            +
                )
         
     | 
| 125 | 
         
            +
                parser.add_argument(
         
     | 
| 126 | 
         
            +
                    "--no_cuda", action="store_true", help="Avoid using CUDA when available"
         
     | 
| 127 | 
         
            +
                )
         
     | 
| 128 | 
         
            +
                parser.add_argument(
         
     | 
| 129 | 
         
            +
                    "--train_batch_size",
         
     | 
| 130 | 
         
            +
                    default=8,
         
     | 
| 131 | 
         
            +
                    type=int,
         
     | 
| 132 | 
         
            +
                    help="Batch size per GPU/CPU for training.",
         
     | 
| 133 | 
         
            +
                )
         
     | 
| 134 | 
         
            +
                parser.add_argument(
         
     | 
| 135 | 
         
            +
                    "--eval_batch_size",
         
     | 
| 136 | 
         
            +
                    default=8,
         
     | 
| 137 | 
         
            +
                    type=int,
         
     | 
| 138 | 
         
            +
                    help="Batch size per GPU/CPU for evaluation.",
         
     | 
| 139 | 
         
            +
                )
         
     | 
| 140 | 
         
            +
                parser.add_argument(
         
     | 
| 141 | 
         
            +
                    "--gradient_accumulation_steps",
         
     | 
| 142 | 
         
            +
                    type=int,
         
     | 
| 143 | 
         
            +
                    default=1,
         
     | 
| 144 | 
         
            +
                    help="Number of updates steps to accumulate before performing a backward/update pass.",
         
     | 
| 145 | 
         
            +
                )
         
     | 
| 146 | 
         
            +
                parser.add_argument(
         
     | 
| 147 | 
         
            +
                    "--learning_rate",
         
     | 
| 148 | 
         
            +
                    default=5e-5,
         
     | 
| 149 | 
         
            +
                    type=float,
         
     | 
| 150 | 
         
            +
                    help="The initial learning rate for Adam.",
         
     | 
| 151 | 
         
            +
                )
         
     | 
| 152 | 
         
            +
                parser.add_argument(
         
     | 
| 153 | 
         
            +
                    "--mask_rate", default=0.15, type=float, help="The masked rate of input lines.",
         
     | 
| 154 | 
         
            +
                )
         
     | 
| 155 | 
         
            +
                parser.add_argument(
         
     | 
| 156 | 
         
            +
                    "--beam_size", default=6, type=int, help="beam size for beam search"
         
     | 
| 157 | 
         
            +
                )
         
     | 
| 158 | 
         
            +
                parser.add_argument(
         
     | 
| 159 | 
         
            +
                    "--weight_decay", default=0.0, type=float, help="Weight deay if we apply some."
         
     | 
| 160 | 
         
            +
                )
         
     | 
| 161 | 
         
            +
                parser.add_argument(
         
     | 
| 162 | 
         
            +
                    "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
         
     | 
| 163 | 
         
            +
                )
         
     | 
| 164 | 
         
            +
                parser.add_argument(
         
     | 
| 165 | 
         
            +
                    "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
         
     | 
| 166 | 
         
            +
                )
         
     | 
| 167 | 
         
            +
                parser.add_argument(
         
     | 
| 168 | 
         
            +
                    "--save_steps", default=-1, type=int,
         
     | 
| 169 | 
         
            +
                )
         
     | 
| 170 | 
         
            +
                parser.add_argument(
         
     | 
| 171 | 
         
            +
                    "--log_steps", default=-1, type=int,
         
     | 
| 172 | 
         
            +
                )
         
     | 
| 173 | 
         
            +
                parser.add_argument("--eval_steps", default=-1, type=int, help="")
         
     | 
| 174 | 
         
            +
                parser.add_argument("--eval_file", default="", type=str)
         
     | 
| 175 | 
         
            +
                parser.add_argument("--out_file", default="", type=str)
         
     | 
| 176 | 
         
            +
                parser.add_argument("--break_cnt", default=-1, type=int)
         
     | 
| 177 | 
         
            +
                parser.add_argument("--train_steps", default=-1, type=int, help="")
         
     | 
| 178 | 
         
            +
                parser.add_argument(
         
     | 
| 179 | 
         
            +
                    "--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps."
         
     | 
| 180 | 
         
            +
                )
         
     | 
| 181 | 
         
            +
                parser.add_argument(
         
     | 
| 182 | 
         
            +
                    "--gpu_per_node",
         
     | 
| 183 | 
         
            +
                    type=int,
         
     | 
| 184 | 
         
            +
                    default=4,
         
     | 
| 185 | 
         
            +
                    help="gpus per node",
         
     | 
| 186 | 
         
            +
                )
         
     | 
| 187 | 
         
            +
                parser.add_argument(
         
     | 
| 188 | 
         
            +
                    "--node_index",
         
     | 
| 189 | 
         
            +
                    type=int,
         
     | 
| 190 | 
         
            +
                    default=0,
         
     | 
| 191 | 
         
            +
                    help="For distributed training: node_index",
         
     | 
| 192 | 
         
            +
                )
         
     | 
| 193 | 
         
            +
                parser.add_argument(
         
     | 
| 194 | 
         
            +
                    "--local_rank",
         
     | 
| 195 | 
         
            +
                    type=int,
         
     | 
| 196 | 
         
            +
                    default=-1,
         
     | 
| 197 | 
         
            +
                    help="For distributed training: local_rank",
         
     | 
| 198 | 
         
            +
                )
         
     | 
| 199 | 
         
            +
                parser.add_argument(
         
     | 
| 200 | 
         
            +
                    "--seed", type=int, default=2233, help="random seed for initialization"
         
     | 
| 201 | 
         
            +
                )  # previous one 42
         
     | 
| 202 | 
         
            +
                # Or in configs.py if add_args is defined there
         
     | 
| 203 | 
         
            +
             
     | 
| 204 | 
         
            +
                parser.add_argument(
         
     | 
| 205 | 
         
            +
                    "--clearml_train_dataset_id",
         
     | 
| 206 | 
         
            +
                    type=str,
         
     | 
| 207 | 
         
            +
                    default=None,
         
     | 
| 208 | 
         
            +
                    help="ClearML Dataset ID to fetch training data from. Overrides train_filename if provided.",
         
     | 
| 209 | 
         
            +
                )
         
     | 
| 210 | 
         
            +
                parser.add_argument(
         
     | 
| 211 | 
         
            +
                    "--clearml_valid_dataset_id",
         
     | 
| 212 | 
         
            +
                    type=str,
         
     | 
| 213 | 
         
            +
                    default=None,
         
     | 
| 214 | 
         
            +
                    help="ClearML Dataset ID to fetch validation data from. Overrides dev_filename if provided.",
         
     | 
| 215 | 
         
            +
                )
         
     | 
| 216 | 
         
            +
                args = parser.parse_args()
         
     | 
| 217 | 
         
            +
                return args
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
            def set_dist(args):
         
     | 
| 221 | 
         
            +
                # Setup CUDA, GPU & distributed training
         
     | 
| 222 | 
         
            +
                if args.local_rank == -1 or args.no_cuda:
         
     | 
| 223 | 
         
            +
                    device = torch.device(
         
     | 
| 224 | 
         
            +
                        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
         
     | 
| 225 | 
         
            +
                    )
         
     | 
| 226 | 
         
            +
                    args.n_gpu = torch.cuda.device_count()
         
     | 
| 227 | 
         
            +
                else:
         
     | 
| 228 | 
         
            +
                    # Setup for distributed data parallel
         
     | 
| 229 | 
         
            +
                    torch.cuda.set_device(args.local_rank)
         
     | 
| 230 | 
         
            +
                    device = torch.device("cuda", args.local_rank)
         
     | 
| 231 | 
         
            +
                    torch.distributed.init_process_group(backend="nccl")
         
     | 
| 232 | 
         
            +
                    args.n_gpu = 1
         
     | 
| 233 | 
         
            +
                cpu_count = multiprocessing.cpu_count()
         
     | 
| 234 | 
         
            +
                logger.warning(
         
     | 
| 235 | 
         
            +
                    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, cpu count: %d",
         
     | 
| 236 | 
         
            +
                    args.local_rank,
         
     | 
| 237 | 
         
            +
                    device,
         
     | 
| 238 | 
         
            +
                    args.n_gpu,
         
     | 
| 239 | 
         
            +
                    bool(args.local_rank != -1),
         
     | 
| 240 | 
         
            +
                    cpu_count,
         
     | 
| 241 | 
         
            +
                )
         
     | 
| 242 | 
         
            +
                args.device = device
         
     | 
| 243 | 
         
            +
                args.cpu_count = cpu_count
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
            def set_seed(args):
         
     | 
| 247 | 
         
            +
                """set random seed."""
         
     | 
| 248 | 
         
            +
                random.seed(args.seed)
         
     | 
| 249 | 
         
            +
                np.random.seed(args.seed)
         
     | 
| 250 | 
         
            +
                torch.manual_seed(args.seed)
         
     | 
| 251 | 
         
            +
                # if args.n_gpu > 0:
         
     | 
| 252 | 
         
            +
                torch.cuda.manual_seed_all(args.seed)
         
     | 
    	
        models.py
    ADDED
    
    | 
         @@ -0,0 +1,208 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import torch.nn as nn
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 5 | 
         
            +
            from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
         
     | 
| 6 | 
         
            +
            import numpy as np
         
     | 
| 7 | 
         
            +
            from utils import MyTokenizer
         
     | 
| 8 | 
         
            +
            from transformers import (
         
     | 
| 9 | 
         
            +
                RobertaConfig,
         
     | 
| 10 | 
         
            +
                RobertaModel,
         
     | 
| 11 | 
         
            +
                RobertaTokenizer,
         
     | 
| 12 | 
         
            +
                BartConfig,
         
     | 
| 13 | 
         
            +
                BartForConditionalGeneration,
         
     | 
| 14 | 
         
            +
                BartTokenizer,
         
     | 
| 15 | 
         
            +
                T5Config,
         
     | 
| 16 | 
         
            +
                T5ForConditionalGeneration,
         
     | 
| 17 | 
         
            +
                T5Tokenizer,
         
     | 
| 18 | 
         
            +
            )
         
     | 
| 19 | 
         
            +
            import logging
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            logger = logging.getLogger(__name__)
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
            class ReviewerModel(T5ForConditionalGeneration):
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                def __init__(self, config):
         
     | 
| 27 | 
         
            +
                    super().__init__(config)
         
     | 
| 28 | 
         
            +
                    self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
         
     | 
| 29 | 
         
            +
                    self.init()
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                def init(self):
         
     | 
| 32 | 
         
            +
                    nn.init.xavier_uniform_(self.lm_head.weight)
         
     | 
| 33 | 
         
            +
                    factor = self.config.initializer_factor
         
     | 
| 34 | 
         
            +
                    self.cls_head.weight.data.normal_(mean=0.0, \
         
     | 
| 35 | 
         
            +
                        std=factor * ((self.config.d_model) ** -0.5))
         
     | 
| 36 | 
         
            +
                    self.cls_head.bias.data.zero_()
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                def forward(
         
     | 
| 39 | 
         
            +
                    self, *argv, **kwargs
         
     | 
| 40 | 
         
            +
                ):
         
     | 
| 41 | 
         
            +
                    r"""
         
     | 
| 42 | 
         
            +
                    Doc from Huggingface transformers:
         
     | 
| 43 | 
         
            +
                    labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
         
     | 
| 44 | 
         
            +
                        Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
         
     | 
| 45 | 
         
            +
                        config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
         
     | 
| 46 | 
         
            +
                        labels in ``[0, ..., config.vocab_size]``
         
     | 
| 47 | 
         
            +
                    Returns:
         
     | 
| 48 | 
         
            +
                    Examples::
         
     | 
| 49 | 
         
            +
                        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
         
     | 
| 50 | 
         
            +
                        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
         
     | 
| 51 | 
         
            +
                        >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
         
     | 
| 52 | 
         
            +
                        >>> # training
         
     | 
| 53 | 
         
            +
                        >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
         
     | 
| 54 | 
         
            +
                        >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
         
     | 
| 55 | 
         
            +
                        >>> outputs = model(input_ids=input_ids, labels=labels)
         
     | 
| 56 | 
         
            +
                        >>> loss = outputs.loss
         
     | 
| 57 | 
         
            +
                        >>> logits = outputs.logits
         
     | 
| 58 | 
         
            +
                        >>> # inference
         
     | 
| 59 | 
         
            +
                        >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
         
     | 
| 60 | 
         
            +
                        >>> outputs = model.generate(input_ids)
         
     | 
| 61 | 
         
            +
                        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
         
     | 
| 62 | 
         
            +
                        >>> # studies have shown that owning a dog is good for you.
         
     | 
| 63 | 
         
            +
                    """
         
     | 
| 64 | 
         
            +
                    if "cls" in kwargs:
         
     | 
| 65 | 
         
            +
                        assert (
         
     | 
| 66 | 
         
            +
                            "input_ids" in kwargs and \
         
     | 
| 67 | 
         
            +
                            "labels" in kwargs and \
         
     | 
| 68 | 
         
            +
                            "attention_mask" in kwargs
         
     | 
| 69 | 
         
            +
                        )
         
     | 
| 70 | 
         
            +
                        return self.cls(
         
     | 
| 71 | 
         
            +
                            input_ids=kwargs["input_ids"],
         
     | 
| 72 | 
         
            +
                            labels=kwargs["labels"],
         
     | 
| 73 | 
         
            +
                            attention_mask=kwargs["attention_mask"],
         
     | 
| 74 | 
         
            +
                        )
         
     | 
| 75 | 
         
            +
                    if "input_labels" in kwargs:
         
     | 
| 76 | 
         
            +
                        assert (
         
     | 
| 77 | 
         
            +
                            "input_ids" in kwargs and \
         
     | 
| 78 | 
         
            +
                            "input_labels" in kwargs and \
         
     | 
| 79 | 
         
            +
                            "decoder_input_ids" in kwargs and \
         
     | 
| 80 | 
         
            +
                            "attention_mask" in kwargs and \
         
     | 
| 81 | 
         
            +
                            "decoder_attention_mask" in kwargs
         
     | 
| 82 | 
         
            +
                        ), "Please give these arg keys."
         
     | 
| 83 | 
         
            +
                        input_ids = kwargs["input_ids"]
         
     | 
| 84 | 
         
            +
                        input_labels = kwargs["input_labels"]
         
     | 
| 85 | 
         
            +
                        decoder_input_ids = kwargs["decoder_input_ids"]
         
     | 
| 86 | 
         
            +
                        attention_mask = kwargs["attention_mask"]
         
     | 
| 87 | 
         
            +
                        decoder_attention_mask = kwargs["decoder_attention_mask"]
         
     | 
| 88 | 
         
            +
                        if "encoder_loss" not in kwargs:
         
     | 
| 89 | 
         
            +
                            encoder_loss = True
         
     | 
| 90 | 
         
            +
                        else:
         
     | 
| 91 | 
         
            +
                            encoder_loss = kwargs["encoder_loss"]
         
     | 
| 92 | 
         
            +
                        return self.review_forward(input_ids, input_labels, decoder_input_ids, attention_mask, decoder_attention_mask, encoder_loss)
         
     | 
| 93 | 
         
            +
                    return super().forward(*argv, **kwargs)
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
                def cls(
         
     | 
| 96 | 
         
            +
                    self,
         
     | 
| 97 | 
         
            +
                    input_ids,
         
     | 
| 98 | 
         
            +
                    labels,
         
     | 
| 99 | 
         
            +
                    attention_mask,
         
     | 
| 100 | 
         
            +
                ):
         
     | 
| 101 | 
         
            +
                    encoder_outputs = self.encoder( \
         
     | 
| 102 | 
         
            +
                        input_ids=input_ids,
         
     | 
| 103 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 104 | 
         
            +
                        output_attentions=False,
         
     | 
| 105 | 
         
            +
                        return_dict=False
         
     | 
| 106 | 
         
            +
                    )
         
     | 
| 107 | 
         
            +
                    hidden_states = encoder_outputs[0]
         
     | 
| 108 | 
         
            +
                    first_hidden = hidden_states[:, 0, :]
         
     | 
| 109 | 
         
            +
                    first_hidden = nn.Dropout(0.3)(first_hidden)
         
     | 
| 110 | 
         
            +
                    logits = self.cls_head(first_hidden)
         
     | 
| 111 | 
         
            +
                    loss_fct = CrossEntropyLoss()
         
     | 
| 112 | 
         
            +
                    if labels != None:
         
     | 
| 113 | 
         
            +
                        loss = loss_fct(logits, labels)
         
     | 
| 114 | 
         
            +
                        return loss
         
     | 
| 115 | 
         
            +
                    return logits
         
     | 
| 116 | 
         
            +
             
     | 
| 117 | 
         
            +
                def review_forward(
         
     | 
| 118 | 
         
            +
                    self,
         
     | 
| 119 | 
         
            +
                    input_ids,
         
     | 
| 120 | 
         
            +
                    input_labels,
         
     | 
| 121 | 
         
            +
                    decoder_input_ids,
         
     | 
| 122 | 
         
            +
                    attention_mask,
         
     | 
| 123 | 
         
            +
                    decoder_attention_mask,
         
     | 
| 124 | 
         
            +
                    encoder_loss=True
         
     | 
| 125 | 
         
            +
                ):
         
     | 
| 126 | 
         
            +
                    encoder_outputs = self.encoder( \
         
     | 
| 127 | 
         
            +
                        input_ids=input_ids,
         
     | 
| 128 | 
         
            +
                        attention_mask=attention_mask,
         
     | 
| 129 | 
         
            +
                        output_attentions=False,
         
     | 
| 130 | 
         
            +
                        return_dict=False
         
     | 
| 131 | 
         
            +
                    )
         
     | 
| 132 | 
         
            +
                    hidden_states = encoder_outputs[0]
         
     | 
| 133 | 
         
            +
                    decoder_inputs = self._shift_right(decoder_input_ids)
         
     | 
| 134 | 
         
            +
                    # Decode
         
     | 
| 135 | 
         
            +
                    decoder_outputs = self.decoder(
         
     | 
| 136 | 
         
            +
                        input_ids=decoder_inputs,
         
     | 
| 137 | 
         
            +
                        attention_mask=decoder_attention_mask,
         
     | 
| 138 | 
         
            +
                        encoder_hidden_states=hidden_states,
         
     | 
| 139 | 
         
            +
                        encoder_attention_mask=attention_mask,
         
     | 
| 140 | 
         
            +
                        output_attentions=False,
         
     | 
| 141 | 
         
            +
                        return_dict=False
         
     | 
| 142 | 
         
            +
                    )
         
     | 
| 143 | 
         
            +
                    sequence_output = decoder_outputs[0]
         
     | 
| 144 | 
         
            +
                    if self.config.tie_word_embeddings: # this is True default
         
     | 
| 145 | 
         
            +
                        sequence_output = sequence_output * (self.model_dim ** -0.5)
         
     | 
| 146 | 
         
            +
                    if encoder_loss:
         
     | 
| 147 | 
         
            +
                        # print(self.encoder.get_input_embeddings().weight.shape)
         
     | 
| 148 | 
         
            +
                        cls_logits = nn.functional.linear(hidden_states, self.encoder.get_input_embeddings().weight)
         
     | 
| 149 | 
         
            +
                        # cls_logits = self.cls_head(hidden_states)
         
     | 
| 150 | 
         
            +
                    lm_logits = self.lm_head(sequence_output)
         
     | 
| 151 | 
         
            +
                    if decoder_input_ids is not None:
         
     | 
| 152 | 
         
            +
                        lm_loss_fct = CrossEntropyLoss(ignore_index=0)      # Warning: PAD_ID should be 0
         
     | 
| 153 | 
         
            +
                        loss = lm_loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
         
     | 
| 154 | 
         
            +
                        if encoder_loss and input_labels is not None:
         
     | 
| 155 | 
         
            +
                            cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
         
     | 
| 156 | 
         
            +
                            loss += cls_loss_fct(cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
         
     | 
| 157 | 
         
            +
                        return loss
         
     | 
| 158 | 
         
            +
                    return cls_logits, lm_logits
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
            def get_model_size(model):
         
     | 
| 161 | 
         
            +
                model_parameters = filter(lambda p: p.requires_grad, model.parameters())
         
     | 
| 162 | 
         
            +
                model_size = sum([np.prod(p.size()) for p in model_parameters])
         
     | 
| 163 | 
         
            +
                return "{}M".format(round(model_size / 1e6))
         
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
             
     | 
| 166 | 
         
            +
            def build_or_load_gen_model(args):
         
     | 
| 167 | 
         
            +
                config_class, model_class, tokenizer_class = T5Config, ReviewerModel, RobertaTokenizer
         
     | 
| 168 | 
         
            +
                
         
     | 
| 169 | 
         
            +
                config = config_class.from_pretrained(args.model_name_or_path)
         
     | 
| 170 | 
         
            +
                tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
         
     | 
| 171 | 
         
            +
                model = model_class.from_pretrained(args.model_name_or_path, config=config)
         
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
                tokenizer.special_dict = {
         
     | 
| 174 | 
         
            +
                    f"<e{i}>" : tokenizer.get_vocab()[f"<e{i}>"] for i in range(99, -1, -1)
         
     | 
| 175 | 
         
            +
                }
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                tokenizer.mask_id = tokenizer.get_vocab()["<mask>"]
         
     | 
| 178 | 
         
            +
                tokenizer.bos_id = tokenizer.get_vocab()["<s>"]
         
     | 
| 179 | 
         
            +
                tokenizer.pad_id = tokenizer.get_vocab()["<pad>"]
         
     | 
| 180 | 
         
            +
                tokenizer.eos_id = tokenizer.get_vocab()["</s>"]
         
     | 
| 181 | 
         
            +
                tokenizer.msg_id = tokenizer.get_vocab()["<msg>"]
         
     | 
| 182 | 
         
            +
                tokenizer.keep_id = tokenizer.get_vocab()["<keep>"]
         
     | 
| 183 | 
         
            +
                tokenizer.add_id = tokenizer.get_vocab()["<add>"]
         
     | 
| 184 | 
         
            +
                tokenizer.del_id = tokenizer.get_vocab()["<del>"]
         
     | 
| 185 | 
         
            +
                tokenizer.start_id = tokenizer.get_vocab()["<start>"]
         
     | 
| 186 | 
         
            +
                tokenizer.end_id = tokenizer.get_vocab()["<end>"]
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                logger.info(
         
     | 
| 189 | 
         
            +
                    "Finish loading model [%s] from %s",
         
     | 
| 190 | 
         
            +
                    get_model_size(model),
         
     | 
| 191 | 
         
            +
                    args.model_name_or_path,
         
     | 
| 192 | 
         
            +
                )
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                if args.load_model_path is not None:
         
     | 
| 195 | 
         
            +
                    model_path = os.path.join(args.load_model_path, "pytorch_model.bin")
         
     | 
| 196 | 
         
            +
                    logger.info("Reload model from {}".format(model_path))
         
     | 
| 197 | 
         
            +
                    try:
         
     | 
| 198 | 
         
            +
                        model.load_state_dict(torch.load(model_path, map_location="cpu"))
         
     | 
| 199 | 
         
            +
                    except RuntimeError:
         
     | 
| 200 | 
         
            +
                        saved = model.cls_head
         
     | 
| 201 | 
         
            +
                        model.cls_head = None
         
     | 
| 202 | 
         
            +
                        model.load_state_dict(torch.load(model_path, map_location="cpu"))
         
     | 
| 203 | 
         
            +
                        model.cls_head = saved
         
     | 
| 204 | 
         
            +
                    model.to(args.local_rank)
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                return config, model, tokenizer
         
     | 
| 207 | 
         
            +
             
     | 
| 208 | 
         
            +
             
     | 
    	
        utils.py
    ADDED
    
    | 
         @@ -0,0 +1,823 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import re, json
         
     | 
| 2 | 
         
            +
            import os, random
         
     | 
| 3 | 
         
            +
            import torch, logging
         
     | 
| 4 | 
         
            +
            from copy import deepcopy as cp
         
     | 
| 5 | 
         
            +
            from torch.utils.data import Dataset
         
     | 
| 6 | 
         
            +
            from tokenizers import ByteLevelBPETokenizer
         
     | 
| 7 | 
         
            +
            from transformers import T5Tokenizer, RobertaTokenizer
         
     | 
| 8 | 
         
            +
            import nltk
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
# Configure root logging once at import time so every module in the project
# shares the same timestamped message format; this module logs via `logger`.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            class MyTokenizer(object):
         
     | 
| 20 | 
         
            +
                """
         
     | 
| 21 | 
         
            +
                Wrapper for ByteLevelBPETokenizer
         
     | 
| 22 | 
         
            +
                """
         
     | 
| 23 | 
         
            +
                def __init__(self, vocab=None, merges=None, **kwargs):
         
     | 
| 24 | 
         
            +
                    self.tokenizer = ByteLevelBPETokenizer(vocab, merges, **kwargs)
         
     | 
| 25 | 
         
            +
                    self.update_id2token()
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                @staticmethod
         
     | 
| 28 | 
         
            +
                def from_pretrained(path):
         
     | 
| 29 | 
         
            +
                    vocabp = os.path.join(path, "vocab.json")
         
     | 
| 30 | 
         
            +
                    mergesp = os.path.join(path, "merges.txt")
         
     | 
| 31 | 
         
            +
                    mytoken = MyTokenizer(vocabp, mergesp)
         
     | 
| 32 | 
         
            +
                    return mytoken
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                def update_id2token(self):
         
     | 
| 35 | 
         
            +
                    vocab = self.tokenizer.get_vocab()
         
     | 
| 36 | 
         
            +
                    self.id2token = {vocab[token]: token for token in vocab}
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                def add_special_tokens(self, dic):
         
     | 
| 39 | 
         
            +
                    for values in dic.values():
         
     | 
| 40 | 
         
            +
                        self.tokenizer.add_special_tokens(values)
         
     | 
| 41 | 
         
            +
                    self.update_id2token()
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                def convert_ids_to_tokens(self, ids):
         
     | 
| 44 | 
         
            +
                    vocab = self.id2token
         
     | 
| 45 | 
         
            +
                    return [vocab[i] for i in ids]
         
     | 
| 46 | 
         
            +
                
         
     | 
| 47 | 
         
            +
                def decode(self, ids, **kwargs):    ##### to be update
         
     | 
| 48 | 
         
            +
                    tokens = self.convert_ids_to_tokens(ids)
         
     | 
| 49 | 
         
            +
                    return " ".join(tokens)
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                def encode(self, text, **kwargs):
         
     | 
| 52 | 
         
            +
                    text = text.encode("ascii", errors="ignore").decode("ascii")
         
     | 
| 53 | 
         
            +
                    return self.tokenizer.encode(text).ids
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                def get_vocab(self):
         
     | 
| 56 | 
         
            +
                    return self.tokenizer.get_vocab()
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
                def __len__(self):
         
     | 
| 59 | 
         
            +
                    return len(self.tokenizer.get_vocab())
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
            class RefineFeatures(object):
         
     | 
| 63 | 
         
            +
                def __init__(self, example_id, source_ids, target_ids):
         
     | 
| 64 | 
         
            +
                    self.example_id = example_id
         
     | 
| 65 | 
         
            +
                    self.source_ids = source_ids
         
     | 
| 66 | 
         
            +
                    self.target_ids = target_ids
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
            class RefineDataset(Dataset):
         
     | 
| 69 | 
         
            +
                def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
         
     | 
| 70 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 71 | 
         
            +
                    self.args = args
         
     | 
| 72 | 
         
            +
                    logger.info("Reading examples from {}".format(file_path))
         
     | 
| 73 | 
         
            +
                    examples = [json.loads(line) for line in open(file_path)]
         
     | 
| 74 | 
         
            +
                    for i in range(len(examples)):
         
     | 
| 75 | 
         
            +
                        if "id" not in examples[i]:
         
     | 
| 76 | 
         
            +
                            examples[i]["id"] = i
         
     | 
| 77 | 
         
            +
                    if samplenum > 0:
         
     | 
| 78 | 
         
            +
                        examples = examples[:samplenum]
         
     | 
| 79 | 
         
            +
                    logger.info(f"Tokenize examples: {file_path}")
         
     | 
| 80 | 
         
            +
                    self.feats = pool.map(self.tokenize, \
         
     | 
| 81 | 
         
            +
                        [(example, tokenizer, args) for example in examples])
         
     | 
| 82 | 
         
            +
                    
         
     | 
| 83 | 
         
            +
                def tokenize(self, item):
         
     | 
| 84 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 85 | 
         
            +
                    oldlines = example["old"].split("\n")
         
     | 
| 86 | 
         
            +
                    newlines = example["new"].split("\n")
         
     | 
| 87 | 
         
            +
                    oldlines = [line[1:].strip() for line in oldlines]
         
     | 
| 88 | 
         
            +
                    newlines = [line[1:].strip() for line in newlines]
         
     | 
| 89 | 
         
            +
                    oldlines = "\n".join(oldlines)
         
     | 
| 90 | 
         
            +
                    newlines = "\n".join(newlines)
         
     | 
| 91 | 
         
            +
                    oldlines = "<add>" + oldlines.replace("\n", "<add>")
         
     | 
| 92 | 
         
            +
                    newlines = "<add>" + newlines.replace("\n", "<add>")
         
     | 
| 93 | 
         
            +
                    comment = example["comment"]
         
     | 
| 94 | 
         
            +
                    srcids = self.encode_remove(tokenizer, oldlines, args)
         
     | 
| 95 | 
         
            +
                    srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
         
     | 
| 96 | 
         
            +
                    tgtids = self.encode_remove(tokenizer, newlines, args)
         
     | 
| 97 | 
         
            +
                    srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
         
     | 
| 98 | 
         
            +
                    return RefineFeatures(example["id"], srcids, tgtids)
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                @staticmethod
         
     | 
| 101 | 
         
            +
                def process_pred_gold(pred, gold):
         
     | 
| 102 | 
         
            +
                    gold = gold.split("\n")
         
     | 
| 103 | 
         
            +
                    gold = [line[1:].strip() for line in gold]
         
     | 
| 104 | 
         
            +
                    gold = " ".join(gold)
         
     | 
| 105 | 
         
            +
                    pred = " ".join(pred.split())
         
     | 
| 106 | 
         
            +
                    pred = pred.replace("<add> ", "")
         
     | 
| 107 | 
         
            +
                    return pred, gold
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
                def pad_assert(self, source_ids, target_ids, args, tokenizer):
         
     | 
| 110 | 
         
            +
                    source_ids = source_ids[:args.max_source_length - 2]
         
     | 
| 111 | 
         
            +
                    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
         
     | 
| 112 | 
         
            +
                    pad_len = args.max_source_length - len(source_ids)
         
     | 
| 113 | 
         
            +
                    source_ids += [tokenizer.pad_id] * pad_len
         
     | 
| 114 | 
         
            +
                    target_ids = target_ids[:args.max_target_length - 2]
         
     | 
| 115 | 
         
            +
                    target_ids = [tokenizer.bos_id] + target_ids + [tokenizer.eos_id]
         
     | 
| 116 | 
         
            +
                    pad_len = args.max_target_length - len(target_ids)
         
     | 
| 117 | 
         
            +
                    target_ids += [tokenizer.pad_id] * pad_len
         
     | 
| 118 | 
         
            +
                    assert len(source_ids) == args.max_source_length, "Not equal length."
         
     | 
| 119 | 
         
            +
                    assert len(target_ids) == args.max_target_length, "Not equal length."
         
     | 
| 120 | 
         
            +
                    return source_ids, target_ids
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                def encode_remove(self, tokenizer, text, args):
         
     | 
| 123 | 
         
            +
                    text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
         
     | 
| 124 | 
         
            +
                    if type(tokenizer) == T5Tokenizer:
         
     | 
| 125 | 
         
            +
                        return text[:-1]
         
     | 
| 126 | 
         
            +
                    elif type(tokenizer) == RobertaTokenizer:
         
     | 
| 127 | 
         
            +
                        return text[1:-1]
         
     | 
| 128 | 
         
            +
                    elif type(tokenizer) == MyTokenizer:
         
     | 
| 129 | 
         
            +
                        return text
         
     | 
| 130 | 
         
            +
                    else:
         
     | 
| 131 | 
         
            +
                        raise NotImplementedError
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                def __len__(self):
         
     | 
| 134 | 
         
            +
                    return len(self.feats)
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                def __getitem__(self, i):
         
     | 
| 137 | 
         
            +
                    return self.feats[i]
         
     | 
| 138 | 
         
            +
             
     | 
| 139 | 
         
            +
            class SimpleRefineDataset(RefineDataset):
         
     | 
| 140 | 
         
            +
                def tokenize(self, item):
         
     | 
| 141 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 142 | 
         
            +
                    oldlines = example["old"].split("\n")
         
     | 
| 143 | 
         
            +
                    newlines = example["new"].split("\n")
         
     | 
| 144 | 
         
            +
                    oldlines = [line[1:].strip() for line in oldlines]
         
     | 
| 145 | 
         
            +
                    newlines = [line[1:].strip() for line in newlines]
         
     | 
| 146 | 
         
            +
                    oldlines = " ".join(oldlines)
         
     | 
| 147 | 
         
            +
                    newlines = " ".join(newlines)
         
     | 
| 148 | 
         
            +
                    comment = example["comment"]
         
     | 
| 149 | 
         
            +
                    srcids = self.encode_remove(tokenizer, oldlines, args)
         
     | 
| 150 | 
         
            +
                    srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
         
     | 
| 151 | 
         
            +
                    tgtids = self.encode_remove(tokenizer, newlines, args)
         
     | 
| 152 | 
         
            +
                    srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
         
     | 
| 153 | 
         
            +
                    return RefineFeatures(example["id"], srcids, tgtids)
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                @staticmethod
         
     | 
| 156 | 
         
            +
                def process_pred_gold(pred, gold):
         
     | 
| 157 | 
         
            +
                    gold = gold.split("\n")
         
     | 
| 158 | 
         
            +
                    gold = [line[1:].strip() for line in gold]
         
     | 
| 159 | 
         
            +
                    gold = " ".join(gold)
         
     | 
| 160 | 
         
            +
                    pred = " ".join(pred.split())
         
     | 
| 161 | 
         
            +
                    return pred, gold
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
            class Seq2SeqDataset(RefineDataset):
         
     | 
| 165 | 
         
            +
                def tokenize(self, item):
         
     | 
| 166 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 167 | 
         
            +
                    inputs, outputs = example["old"], example["new"]
         
     | 
| 168 | 
         
            +
                    inputs = " ".join(inputs.split())
         
     | 
| 169 | 
         
            +
                    outputs = " ".join(outputs.split())
         
     | 
| 170 | 
         
            +
                    srcids = self.encode_remove(tokenizer, inputs, args)
         
     | 
| 171 | 
         
            +
                    tgtids = self.encode_remove(tokenizer, outputs, args)
         
     | 
| 172 | 
         
            +
                    srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
         
     | 
| 173 | 
         
            +
                    return RefineFeatures(example["id"], srcids, tgtids)
         
     | 
| 174 | 
         
            +
                
         
     | 
| 175 | 
         
            +
                @staticmethod
         
     | 
| 176 | 
         
            +
                def process_pred_gold(pred, gold):
         
     | 
| 177 | 
         
            +
                    gold = " ".join(gold.split())
         
     | 
| 178 | 
         
            +
                    pred = " ".join(pred.split())
         
     | 
| 179 | 
         
            +
                    return pred, gold
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
class TextDataset(Dataset):
    """Self-supervised pre-training dataset for the code-review model.

    Loads review examples from a ``.jsonl`` file (or a cached ``.exps``
    torch file), tokenizes them in parallel through ``pool``, and turns
    each example into one or more ``ReviewFeatures`` via randomly chosen
    objectives: diff-tag labeling (``encoder_example``), masked-line
    infilling (``decoder_example``), review-message generation
    (``genmsg_example``) and message denoising (``daemsg_example``).
    """

    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        """Load or build tokenized examples, then expand them to features.

        Args:
            tokenizer: MyTokenizer / T5Tokenizer / RobertaTokenizer instance.
            pool: multiprocessing pool used for parallel map.
            args: namespace providing max_source_length, max_target_length,
                mask_rate, etc.
            file_path: path to the ``.jsonl`` input file.
            samplenum: cap on the number of examples to read (-1 = all).
        """
        self.cnt = 0
        self.tokenizer = tokenizer
        self.args = args
        # The tokenizer type is baked into the cache file name so caches
        # built with different tokenizers do not collide.  Note the T5
        # suffix is the empty string (legacy naming).
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
        # savep = "/home/v-zhuoli1/lzzz/processed/chunk_25.exps"
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            examples = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            examples = read_review_examples(file_path, samplenum, tokenizer)
            logger.info(f"Tokenize examples: {file_path}")
            examples = pool.map(self.tokenize, \
                [(example, tokenizer, args) for example in examples])
            torch.save(examples, savep)
        logger.info("Convert examples to features...")
        self.set_start_end_ids(examples)
        # Each example may yield several features (up-sampling), hence a
        # list of lists that is flattened just below.
        self.featss = pool.map(self.convert_examples_to_features, \
            [(example, tokenizer, args) for example in examples])
        self.feats = [feat for feats in self.featss for feat in feats]  # expand the lists

    def __len__(self):
        """Number of flattened feature records."""
        return len(self.feats)

    def __getitem__(self, i):
        """Return the i-th feature record."""
        return self.feats[i]

    def reset_len(self, data_len):
        """Truncate the dataset to its first ``data_len`` features."""
        assert len(self.feats) >= data_len
        self.feats = self.feats[:data_len]

    def set_start_end_ids(self, examples):
        """Record on each example the first and last line index whose
        label is not -100, i.e. the span of labeled diff lines."""
        for example in examples:
            labels = example.labels
            start_id = 0
            end_id = len(labels) - 1
            for i, label in enumerate(labels):
                if label != -100:               # find the first label
                    start_id = i
                    break
            for i in range(len(labels) - 1, -1, -1):
                label = labels[i]
                if label != -100:               # find the last label
                    end_id = i
                    break
            example.start_id = start_id
            example.end_id = end_id

    def tokenize(self, item):
        """Tokenize one example and split its ids back into lines.

        The encoded input uses the ``<e0>`` id as a line separator: the
        id sequence is re-joined as a space-separated string, split on
        that separator, and trimmed until (line count + total tokens)
        fits within ``args.max_source_length - 2``.
        """
        example, tokenizer, args = item
        example.input = self.encode_remove(tokenizer, example.input, args)
        e0id = tokenizer.special_dict["<e0>"]
        inputs = " ".join(str(id) for id in example.input)
        lines = inputs.split(" " + str(e0id) + " ")
        lines = [
            [int(v) for v in line.split(" ") if len(v) > 0] for line in lines
        ]
        lens = [len(line) for line in lines]
        # if 0 in lens:
        #     logger.info("Warning: empty line in an example.")
        # NOTE(review): `lens` is recomputed here, duplicating the
        # comprehension above — one of the two is redundant.
        lens = list(map(len, lines))
        # Budget: one separator slot per line plus the line tokens.
        curlen = len(lens) + sum(lens)
        left, right = 0, len(lines)
        # Trim lines until the example fits.  `left` advances only while
        # it is even, so at most one line is dropped from the front and
        # all further trimming happens at the tail.
        # NOTE(review): possibly meant to alternate front/back — confirm.
        while curlen > args.max_source_length - 2:
            if left % 2 == 0:
                curlen -= 1 + len(lines[left])
                left += 1
            else:
                right -= 1
                curlen -= 1 + len(lines[right])
        lines = lines[left:right]
        labels = example.labels[left:right]
        assert len(lines) + sum(map(len, lines)) <= args.max_source_length - 2, "Too long inputs in TextDataset.tokenize."
        if len(lines) != len(labels):
            logger.info("Not equal length in TextDataset.tokenize.")
            lines = lines[:len(labels)]
            labels = labels[:len(lines)]
        example.lines = lines
        example.labels = labels
        example.msg = self.encode_remove(tokenizer, example.msg, args)
        return example

    def convert_examples_to_features(self, item):
        """Turn one tokenized example into a list of ReviewFeatures.

        Examples with a review message are up-sampled: three features are
        drawn, each randomly either message generation or message
        denoising.  Examples without a message yield one feature,
        randomly either the encoder (tag-labeling) or decoder
        (line-infilling) objective.
        """
        example, _, _ = item
        if len(example.msg) > 0:
            exs = []
            for _ in range(3):  # up sampling
                if random.random() < 0.5:
                    exs.append(self.genmsg_example(item))
                else:
                    exs.append(self.daemsg_example(item))
            return exs
        if random.random() < 0.5:
            return [self.encoder_example(item)]
        return [self.decoder_example(item)]

    def encoder_example(self, item):
        """Build a tag-prediction ("label") feature.

        The source holds the diff lines with a mask token inserted before
        every labeled line; ``input_labels`` carries the del/add/keep id
        at those mask positions and -100 elsewhere.  The target is all
        padding (the labels are the supervision signal).
        """
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels
        target_ids = [tokenizer.pad_id] * args.max_target_length
        source_ids, input_labels = [], []
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
                input_labels.append(-100)
            if label != -100:       # only insert special tokens at diffs, not context
                source_ids.append(tokenizer.mask_id)
                input_labels.append(label)
            source_ids.extend(line)
            input_labels.extend([-100] * len(line))
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
                input_labels.append(-100)
        assert len(input_labels) == len(source_ids), "Not equal length."
        assert len(input_labels) <= args.max_source_length, f"Too long inputs: {len(input_labels)}."
        source_ids = source_ids[:args.max_source_length - 2]
        input_labels = input_labels[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        input_labels = [-100] + input_labels + [-100]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        input_labels += [-100] * pad_len

        # Map raw labels (0/1/2) to the tokenizer's del/add/keep ids.
        new_input_labels = []
        map_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for label in input_labels:
            if label == -100:
                new_input_labels.append(-100)
            else:
                new_input_labels.append(map_dict[label])
        input_labels = new_input_labels
        assert len(source_ids) == args.max_source_length, "Not equal length."
        assert len(input_labels) == args.max_source_length, "Not equal length."
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="label")

    def decoder_example(self, item):
        """Build a masked-line infilling ("line") feature.

        A random subset of lines (rate ``args.mask_rate``) is replaced in
        the source by ``<e#>`` sentinels; the target reproduces each
        sentinel followed by the original line tokens.
        """
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels

        input_labels = [-100] * args.max_source_length
        source_ids, target_ids = [], []
        SPECIAL_ID = 0
        # NOTE(review): random.choices samples with replacement, so fewer
        # than len(lines) * mask_rate distinct lines may end up masked.
        mask_idxs = random.choices(range(len(lines)), k=int(len(lines) * args.mask_rate))
        id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
            if label in id_dict:
                source_ids.append(id_dict[label])
            if i in mask_idxs:
                source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
                target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
                target_ids.extend(line)
                if SPECIAL_ID < 99:     # only 0-99 ids in vocab
                    SPECIAL_ID += 1
            else:
                source_ids.extend(line)
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="line")

    def genmsg_example(self, item):
        """Build a review-message generation ("genmsg") feature: the
        tagged diff as source, ``msg_id`` plus message tokens as target."""
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels
        input_labels = [-100] * args.max_source_length
        source_ids, target_ids = [], []
        id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
            if label != -100:
                source_ids.append(id_dict[label])
            source_ids.extend(line)
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
        target_ids.append(tokenizer.msg_id)
        target_ids.extend(example.msg)
        assert len(source_ids) <= args.max_source_length, f"Too long inputs: {len(source_ids)}."
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="genmsg")

    def daemsg_example(self, item):
        """Build a message-denoising ("daemsg") feature.

        Each message token is masked with probability 0.2 (at least one
        token is always masked); contiguous masked runs become ``<e#>``
        sentinels in the source and are spelled out in the target.
        """
        example, tokenizer, args = item
        input_labels = [-100] * args.max_source_length
        source_ids, target_ids = [], []
        msg_ids = cp(example.msg)
        masks = [random.random() < 0.20 for _ in range(len(msg_ids))]
        if sum(masks) == 0:
            idx = random.choice(range(len(msg_ids)))
            masks[idx] = True
        source_ids, target_ids = [], []
        i = 0
        SPECIAL_ID = 0
        while i < len(masks):
            j = i
            # Copy the unmasked run into the source.
            while j < len(masks) and not masks[j]:
                source_ids.append(msg_ids[j])
                j += 1
            if j == len(masks):
                break
            source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
            target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
            # Spell the masked run out in the target.
            while j < len(masks) and masks[j]:
                target_ids.append(msg_ids[j])
                j += 1
            if SPECIAL_ID < 99:     # only 0-99 ids in vocab
                SPECIAL_ID += 1
            i = j
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="daemsg")

    def pad_assert(self, source_ids, target_ids, args, tokenizer):
        """Truncate and pad both sequences to their fixed lengths.

        The source is framed with BOS/EOS; the target gets EOS only (no
        BOS), unlike RefineDataset.pad_assert.
        """
        source_ids = source_ids[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        target_ids = target_ids[:args.max_target_length - 1]
        target_ids = target_ids + [tokenizer.eos_id]
        pad_len = args.max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_id] * pad_len
        assert len(source_ids) == args.max_source_length, "Not equal length."
        assert len(target_ids) == args.max_target_length, "Not equal length."
        return source_ids, target_ids

    def encode_remove(self, tokenizer, text, args):
        """Encode ``text`` and strip tokenizer-added framing ids: the last
        id for T5Tokenizer, first and last for RobertaTokenizer, none for
        MyTokenizer."""
        text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
        if type(tokenizer) == T5Tokenizer:
            return text[:-1]
        elif type(tokenizer) == RobertaTokenizer:
            return text[1:-1]
        elif type(tokenizer) == MyTokenizer:
            return text
        else:
            raise NotImplementedError
     | 
| 431 | 
         
            +
             
     | 
| 432 | 
         
            +
            class CommentGenDataset(TextDataset):
         
     | 
| 433 | 
         
            +
                def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
         
     | 
| 434 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 435 | 
         
            +
                    if isinstance(tokenizer, MyTokenizer):
         
     | 
| 436 | 
         
            +
                        tokenizer_type = "mytok"
         
     | 
| 437 | 
         
            +
                    elif isinstance(tokenizer, T5Tokenizer):
         
     | 
| 438 | 
         
            +
                        tokenizer_type = ""
         
     | 
| 439 | 
         
            +
                    elif isinstance(tokenizer, RobertaTokenizer):
         
     | 
| 440 | 
         
            +
                        tokenizer_type = "rb"
         
     | 
| 441 | 
         
            +
                    else:
         
     | 
| 442 | 
         
            +
                        tokenizer_type = "unk"
         
     | 
| 443 | 
         
            +
                    savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
         
     | 
| 444 | 
         
            +
                    if os.path.exists(savep):
         
     | 
| 445 | 
         
            +
                        logger.info("Loading examples from {}".format(savep))
         
     | 
| 446 | 
         
            +
                        examples = torch.load(savep)
         
     | 
| 447 | 
         
            +
                    else:
         
     | 
| 448 | 
         
            +
                        logger.info("Reading examples from {}".format(file_path))
         
     | 
| 449 | 
         
            +
                        examples = read_review_examples(file_path, samplenum, tokenizer)
         
     | 
| 450 | 
         
            +
                        # for i in range(len(examples)):
         
     | 
| 451 | 
         
            +
                        #     examples[i].msg = " ".join(nltk.word_tokenize(examples[i].msg))
         
     | 
| 452 | 
         
            +
                        logger.info(f"Tokenize examples: {file_path}")
         
     | 
| 453 | 
         
            +
                        examples = pool.map(self.tokenize, \
         
     | 
| 454 | 
         
            +
                            [(example, tokenizer, args) for example in examples])
         
     | 
| 455 | 
         
            +
                        torch.save(examples, savep)
         
     | 
| 456 | 
         
            +
                    logger.info("Convert examples to features...")
         
     | 
| 457 | 
         
            +
                    self.set_start_end_ids(examples)
         
     | 
| 458 | 
         
            +
                    self.feats = pool.map(self.convert_examples_to_features, \
         
     | 
| 459 | 
         
            +
                        [(example, tokenizer, args) for example in examples])
         
     | 
| 460 | 
         
            +
                    self.feats = [feat for feat in self.feats if feat is not None]
         
     | 
| 461 | 
         
            +
             
     | 
| 462 | 
         
            +
                def convert_examples_to_features(self, item):
         
     | 
| 463 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 464 | 
         
            +
                    if len(example.msg) == 0:
         
     | 
| 465 | 
         
            +
                        return None
         
     | 
| 466 | 
         
            +
                    return self.genmsg_example(item)
         
     | 
| 467 | 
         
            +
             
     | 
| 468 | 
         
            +
             
     | 
| 469 | 
         
            +
            class CommentClsDataset(TextDataset):
         
     | 
| 470 | 
         
            +
                def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
         
     | 
| 471 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 472 | 
         
            +
                    if isinstance(tokenizer, MyTokenizer):
         
     | 
| 473 | 
         
            +
                        tokenizer_type = "mytok"
         
     | 
| 474 | 
         
            +
                    elif isinstance(tokenizer, T5Tokenizer):
         
     | 
| 475 | 
         
            +
                        tokenizer_type = ""
         
     | 
| 476 | 
         
            +
                    elif isinstance(tokenizer, RobertaTokenizer):
         
     | 
| 477 | 
         
            +
                        tokenizer_type = "rb"
         
     | 
| 478 | 
         
            +
                    else:
         
     | 
| 479 | 
         
            +
                        tokenizer_type = "unk"
         
     | 
| 480 | 
         
            +
                    savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
         
     | 
| 481 | 
         
            +
                    if os.path.exists(savep):
         
     | 
| 482 | 
         
            +
                        logger.info("Loading examples from {}".format(savep))
         
     | 
| 483 | 
         
            +
                        examples = torch.load(savep)
         
     | 
| 484 | 
         
            +
                    else:
         
     | 
| 485 | 
         
            +
                        logger.info("Reading examples from {}".format(file_path))
         
     | 
| 486 | 
         
            +
                        examples = read_review_examples(file_path, samplenum, tokenizer)
         
     | 
| 487 | 
         
            +
                        logger.info(f"Tokenize examples: {file_path}")
         
     | 
| 488 | 
         
            +
                        examples = pool.map(self.tokenize, \
         
     | 
| 489 | 
         
            +
                            [(example, tokenizer, args) for example in examples])
         
     | 
| 490 | 
         
            +
                        torch.save(examples, savep)
         
     | 
| 491 | 
         
            +
                    logger.info("Convert examples to features...")
         
     | 
| 492 | 
         
            +
                    self.set_start_end_ids(examples)
         
     | 
| 493 | 
         
            +
                    self.feats = pool.map(self.convert_examples_to_features, \
         
     | 
| 494 | 
         
            +
                        [(example, tokenizer, args) for example in examples])
         
     | 
| 495 | 
         
            +
             
     | 
| 496 | 
         
            +
                def convert_examples_to_features(self, item):
         
     | 
| 497 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 498 | 
         
            +
                    tmpfeature = self.genmsg_example(item)
         
     | 
| 499 | 
         
            +
                    return ClsFeatures(tmpfeature.example_id, tmpfeature.source_ids, example.y)
         
     | 
| 500 | 
         
            +
             
     | 
| 501 | 
         
            +
             
     | 
| 502 | 
         
            +
            class SimpleClsDataset(TextDataset):
         
     | 
| 503 | 
         
            +
                def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
         
     | 
| 504 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 505 | 
         
            +
                    if isinstance(tokenizer, MyTokenizer):
         
     | 
| 506 | 
         
            +
                        tokenizer_type = "mytok"
         
     | 
| 507 | 
         
            +
                    elif isinstance(tokenizer, T5Tokenizer):
         
     | 
| 508 | 
         
            +
                        tokenizer_type = ""
         
     | 
| 509 | 
         
            +
                    elif isinstance(tokenizer, RobertaTokenizer):
         
     | 
| 510 | 
         
            +
                        tokenizer_type = "rb"
         
     | 
| 511 | 
         
            +
                    else:
         
     | 
| 512 | 
         
            +
                        tokenizer_type = "unk"
         
     | 
| 513 | 
         
            +
                    savep = file_path.replace(".jsonl", tokenizer_type + ".simpexps")
         
     | 
| 514 | 
         
            +
                    if os.path.exists(savep):
         
     | 
| 515 | 
         
            +
                        logger.info("Loading examples from {}".format(savep))
         
     | 
| 516 | 
         
            +
                        self.feats = torch.load(savep)
         
     | 
| 517 | 
         
            +
                    else:
         
     | 
| 518 | 
         
            +
                        logger.info("Reading examples from {}".format(file_path))
         
     | 
| 519 | 
         
            +
                        examples = read_review_examples(file_path, samplenum, tokenizer)
         
     | 
| 520 | 
         
            +
                        logger.info(f"Tokenize examples: {file_path}")
         
     | 
| 521 | 
         
            +
                        self.feats = pool.map(self.convert_examples_to_features, \
         
     | 
| 522 | 
         
            +
                            [(example, tokenizer, args) for example in examples])
         
     | 
| 523 | 
         
            +
                        torch.save(self.feats, savep)
         
     | 
| 524 | 
         
            +
             
     | 
| 525 | 
         
            +
                def convert_examples_to_features(self, item):
         
     | 
| 526 | 
         
            +
                    example, tokenizer, args = item
         
     | 
| 527 | 
         
            +
                    example.input_lines = example.input.split("<e0>")
         
     | 
| 528 | 
         
            +
                    labels_l = len(example.labels)
         
     | 
| 529 | 
         
            +
                    example.input_lines = example.input_lines[:labels_l]
         
     | 
| 530 | 
         
            +
                    for i in range(len(example.input_lines)):
         
     | 
| 531 | 
         
            +
                        if example.labels[i] == 1:
         
     | 
| 532 | 
         
            +
                            example.input_lines[i] = "+ " + example.input_lines[i]
         
     | 
| 533 | 
         
            +
                        elif example.labels[i] == 0:
         
     | 
| 534 | 
         
            +
                            example.input_lines[i] = "- " + example.input_lines[i]
         
     | 
| 535 | 
         
            +
                    example.input = " ".join(example.input_lines)
         
     | 
| 536 | 
         
            +
                    input_ids = self.encode_remove(tokenizer, example.input, args)
         
     | 
| 537 | 
         
            +
                    exceed_l = len(input_ids) - args.max_source_length + 2
         
     | 
| 538 | 
         
            +
                    if exceed_l > 0:
         
     | 
| 539 | 
         
            +
                        halfexl = (exceed_l + 1) // 2
         
     | 
| 540 | 
         
            +
                        input_ids = input_ids[halfexl:-halfexl]
         
     | 
| 541 | 
         
            +
                    source_ids = input_ids[:args.max_source_length - 2]
         
     | 
| 542 | 
         
            +
                    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
         
     | 
| 543 | 
         
            +
                    pad_len = args.max_source_length - len(source_ids)
         
     | 
| 544 | 
         
            +
                    source_ids += [tokenizer.pad_id] * pad_len
         
     | 
| 545 | 
         
            +
                    example_id = example.idx
         
     | 
| 546 | 
         
            +
                    y = example.y
         
     | 
| 547 | 
         
            +
                    return ClsFeatures(example_id, source_ids, y)
         
     | 
| 548 | 
         
            +
             
     | 
| 549 | 
         
            +
             
     | 
| 550 | 
         
            +
            class SimpleGenDataset(TextDataset):
         
     | 
| 551 | 
         
            +
                def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
         
     | 
| 552 | 
         
            +
                    self.tokenizer = tokenizer
         
     | 
| 553 | 
         
            +
                    if isinstance(tokenizer, MyTokenizer):
         
     | 
| 554 | 
         
            +
                        tokenizer_type = "mytok"
         
     | 
| 555 | 
         
            +
                    elif isinstance(tokenizer, T5Tokenizer):
         
     | 
| 556 | 
         
            +
                        tokenizer_type = ""
         
     | 
| 557 | 
         
            +
                    elif isinstance(tokenizer, RobertaTokenizer):
         
     | 
| 558 | 
         
            +
                        tokenizer_type = "rb"
         
     | 
| 559 | 
         
            +
                    else:
         
     | 
| 560 | 
         
            +
                        tokenizer_type = "unk"
         
     | 
| 561 | 
         
            +
                    savep = file_path.replace(".jsonl", tokenizer_type + ".simpgenexps")
         
     | 
| 562 | 
         
            +
                    if os.path.exists(savep):
         
     | 
| 563 | 
         
            +
                        logger.info("Loading examples from {}".format(savep))
         
     | 
| 564 | 
         
            +
                        self.feats = torch.load(savep)
         
     | 
| 565 | 
         
            +
                    else:
         
     | 
| 566 | 
         
            +
                        logger.info("Reading examples from {}".format(file_path))
         
     | 
| 567 | 
         
            +
                        data = read_jsonl(file_path)
         
     | 
| 568 | 
         
            +
                        # data = [dic for dic in data if len(dic["patch"].split("\n")) <= 20]
         
     | 
| 569 | 
         
            +
                        for i in range(len(data)):
         
     | 
| 570 | 
         
            +
                            data[i]["idx"] = i
         
     | 
| 571 | 
         
            +
                        logger.info(f"Tokenize examples: {file_path}")
         
     | 
| 572 | 
         
            +
                        # self.feats = pool.map(self.convert_examples_to_features, \
         
     | 
| 573 | 
         
            +
                        #     [(dic, tokenizer, args) for dic in data])
         
     | 
| 574 | 
         
            +
                        self.feats = [self.convert_examples_to_features((dic, tokenizer, args)) for dic in data]
         
     | 
| 575 | 
         
            +
                        torch.save(self.feats, savep)
         
     | 
| 576 | 
         
            +
             
     | 
| 577 | 
         
            +
                def convert_examples_to_features(self, item):
         
     | 
| 578 | 
         
            +
                    dic, tokenizer, args = item
         
     | 
| 579 | 
         
            +
                    diff, msg = dic["patch"], dic["msg"]
         
     | 
| 580 | 
         
            +
                    difflines = diff.split("\n")[1:]        # remove start @@
         
     | 
| 581 | 
         
            +
                    difflines = [line for line in difflines if len(line.strip()) > 0]
         
     | 
| 582 | 
         
            +
                    map_dic = {"-": 0, "+": 1, " ": 2}
         
     | 
| 583 | 
         
            +
                    def f(s):
         
     | 
| 584 | 
         
            +
                        if s in map_dic:
         
     | 
| 585 | 
         
            +
                            return map_dic[s]
         
     | 
| 586 | 
         
            +
                        else:
         
     | 
| 587 | 
         
            +
                            return 2
         
     | 
| 588 | 
         
            +
                    labels = [f(line[0]) for line in difflines]
         
     | 
| 589 | 
         
            +
                    difflines = [line[1:].strip() for line in difflines]
         
     | 
| 590 | 
         
            +
                    inputstr = ""
         
     | 
| 591 | 
         
            +
                    for label, line in zip(labels, difflines):
         
     | 
| 592 | 
         
            +
                        if label == 1:
         
     | 
| 593 | 
         
            +
                            inputstr += "<add>" + line
         
     | 
| 594 | 
         
            +
                        elif label == 0:
         
     | 
| 595 | 
         
            +
                            inputstr += "<del>" + line
         
     | 
| 596 | 
         
            +
                        else:
         
     | 
| 597 | 
         
            +
                            inputstr += "<keep>" + line
         
     | 
| 598 | 
         
            +
                    source_ids = self.encode_remove(tokenizer, inputstr, args)
         
     | 
| 599 | 
         
            +
                    target_ids = []
         
     | 
| 600 | 
         
            +
                    target_ids.append(tokenizer.msg_id)
         
     | 
| 601 | 
         
            +
                    msg = self.encode_remove(tokenizer, dic["msg"], args)
         
     | 
| 602 | 
         
            +
                    target_ids.extend(msg)
         
     | 
| 603 | 
         
            +
                    source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
         
     | 
| 604 | 
         
            +
                    input_labels = [-100] * len(source_ids)
         
     | 
| 605 | 
         
            +
                    return ReviewFeatures(dic["idx"], source_ids, input_labels, target_ids, type="genmsg")
         
     | 
| 606 | 
         
            +
             
     | 
| 607 | 
         
            +
             
     | 
| 608 | 
         
            +
            class InputFeatures(object):
         
     | 
| 609 | 
         
            +
                """A single training/test features for a example."""
         
     | 
| 610 | 
         
            +
             
     | 
| 611 | 
         
            +
                def __init__(self, example_id, source_ids, target_ids, url=None):
         
     | 
| 612 | 
         
            +
                    self.example_id = example_id
         
     | 
| 613 | 
         
            +
                    self.source_ids = source_ids
         
     | 
| 614 | 
         
            +
                    self.target_ids = target_ids
         
     | 
| 615 | 
         
            +
                    self.url = url
         
     | 
| 616 | 
         
            +
             
     | 
| 617 | 
         
            +
             
     | 
| 618 | 
         
            +
            class ReviewFeatures(object):
         
     | 
| 619 | 
         
            +
                def __init__(self, example_id, source_ids, source_labels, target_ids, type):
         
     | 
| 620 | 
         
            +
                    self.example_id = example_id
         
     | 
| 621 | 
         
            +
                    self.source_ids = source_ids
         
     | 
| 622 | 
         
            +
                    self.source_labels = source_labels
         
     | 
| 623 | 
         
            +
                    self.target_ids = target_ids
         
     | 
| 624 | 
         
            +
                    assert type in ("label", "line", "genmsg", "daemsg")
         
     | 
| 625 | 
         
            +
                    self.type = type
         
     | 
| 626 | 
         
            +
             
     | 
| 627 | 
         
            +
            class ClsFeatures(object):
         
     | 
| 628 | 
         
            +
                def __init__(self, example_id, source_ids, y):
         
     | 
| 629 | 
         
            +
                    self.example_id = example_id
         
     | 
| 630 | 
         
            +
                    self.source_ids = source_ids
         
     | 
| 631 | 
         
            +
                    self.y = y
         
     | 
| 632 | 
         
            +
             
     | 
| 633 | 
         
            +
            class ReviewExample(object):
         
     | 
| 634 | 
         
            +
                """A single training/test example."""
         
     | 
| 635 | 
         
            +
             
     | 
| 636 | 
         
            +
                def __init__(
         
     | 
| 637 | 
         
            +
                    self, idx, oldf, diff, msg, cmtid, max_len, y
         
     | 
| 638 | 
         
            +
                ):
         
     | 
| 639 | 
         
            +
                    self.idx = idx      # idx is useless yet
         
     | 
| 640 | 
         
            +
                    self.oldf = oldf
         
     | 
| 641 | 
         
            +
                    self.diff = diff
         
     | 
| 642 | 
         
            +
                    self.msg = msg
         
     | 
| 643 | 
         
            +
                    self.cmtid = cmtid
         
     | 
| 644 | 
         
            +
                    self.max_len = max_len
         
     | 
| 645 | 
         
            +
                    self.y = y
         
     | 
| 646 | 
         
            +
                    self.prevlines = []
         
     | 
| 647 | 
         
            +
                    self.afterlines = []
         
     | 
| 648 | 
         
            +
                    self.lines = []
         
     | 
| 649 | 
         
            +
                    self.labels = []
         
     | 
| 650 | 
         
            +
                    self.avail = False
         
     | 
| 651 | 
         
            +
                    self.input = ""
         
     | 
| 652 | 
         
            +
                    self.align_and_clean()
         
     | 
| 653 | 
         
            +
                    self.postprocess()
         
     | 
| 654 | 
         
            +
             
     | 
| 655 | 
         
            +
                def postprocess(self):
         
     | 
| 656 | 
         
            +
                    if not self.avail:
         
     | 
| 657 | 
         
            +
                        return
         
     | 
| 658 | 
         
            +
                    # Warning: lines is not self.lines
         
     | 
| 659 | 
         
            +
                    # lines for rough length estimation
         
     | 
| 660 | 
         
            +
                    lines = [source_str.split() for source_str in self.lines]
         
     | 
| 661 | 
         
            +
                    inputl = len(lines) # line tag
         
     | 
| 662 | 
         
            +
                    inputl += sum(map(len, lines))
         
     | 
| 663 | 
         
            +
                    left, right = 0, len(lines)
         
     | 
| 664 | 
         
            +
                    while inputl > self.max_len:
         
     | 
| 665 | 
         
            +
                        if left % 2 == 0:
         
     | 
| 666 | 
         
            +
                            inputl -= len(lines[left]) + 1
         
     | 
| 667 | 
         
            +
                            left += 1
         
     | 
| 668 | 
         
            +
                        else:
         
     | 
| 669 | 
         
            +
                            right -= 1
         
     | 
| 670 | 
         
            +
                            inputl -= len(lines[right]) + 1
         
     | 
| 671 | 
         
            +
                    lines = lines[left:right]
         
     | 
| 672 | 
         
            +
                    self.lines = self.lines[left:right]
         
     | 
| 673 | 
         
            +
                    self.labels = self.labels[left:right]
         
     | 
| 674 | 
         
            +
                    prevlines = self.prevlines
         
     | 
| 675 | 
         
            +
                    afterlines = self.afterlines
         
     | 
| 676 | 
         
            +
                    prev_after_len = max(len(prevlines), len(afterlines))
         
     | 
| 677 | 
         
            +
                    i = 0
         
     | 
| 678 | 
         
            +
                    while inputl < self.max_len and i < prev_after_len:
         
     | 
| 679 | 
         
            +
                        if i < len(prevlines):
         
     | 
| 680 | 
         
            +
                            newl = inputl + len(prevlines[-1-i].split()) + 1
         
     | 
| 681 | 
         
            +
                            if newl > self.max_len:
         
     | 
| 682 | 
         
            +
                                break
         
     | 
| 683 | 
         
            +
                            self.lines.insert(0, prevlines[-1-i])
         
     | 
| 684 | 
         
            +
                            self.labels.insert(0, -100)
         
     | 
| 685 | 
         
            +
                            inputl = newl  # tag
         
     | 
| 686 | 
         
            +
                        if i < len(afterlines):
         
     | 
| 687 | 
         
            +
                            newl = inputl + len(afterlines[i].split()) + 1
         
     | 
| 688 | 
         
            +
                            if newl > self.max_len:
         
     | 
| 689 | 
         
            +
                                break
         
     | 
| 690 | 
         
            +
                            self.lines.append(afterlines[i])
         
     | 
| 691 | 
         
            +
                            self.labels.append(-100)
         
     | 
| 692 | 
         
            +
                            inputl = newl    # tag
         
     | 
| 693 | 
         
            +
                        i += 1
         
     | 
| 694 | 
         
            +
                    assert inputl <= self.max_len, "Too long inputs."
         
     | 
| 695 | 
         
            +
                    assert len(self.lines) == len(self.labels), "Not equal length."
         
     | 
| 696 | 
         
            +
                    self.input = "<e0>".join(self.lines)
         
     | 
| 697 | 
         
            +
                    self.prevlines, self.lines, self.afterlines = [], [], []
         
     | 
| 698 | 
         
            +
             
     | 
| 699 | 
         
            +
                def remove_space_clean(self, line):
         
     | 
| 700 | 
         
            +
                    """
         
     | 
| 701 | 
         
            +
                        Remove start and end empty chars.
         
     | 
| 702 | 
         
            +
                    """
         
     | 
| 703 | 
         
            +
                    rep = " \t\r"
         
     | 
| 704 | 
         
            +
                    totallen = len(line)
         
     | 
| 705 | 
         
            +
                    i = 0
         
     | 
| 706 | 
         
            +
                    while i < totallen and line[i] in rep:
         
     | 
| 707 | 
         
            +
                        i += 1
         
     | 
| 708 | 
         
            +
                    j = totallen - 1
         
     | 
| 709 | 
         
            +
                    while j >= 0 and line[j] in rep:
         
     | 
| 710 | 
         
            +
                        j -= 1
         
     | 
| 711 | 
         
            +
                    line = line[i : j + 1]
         
     | 
| 712 | 
         
            +
                    return line
         
     | 
| 713 | 
         
            +
             
     | 
| 714 | 
         
            +
                def align_and_clean(self):
         
     | 
| 715 | 
         
            +
                    oldflines = self.oldf.split("\n")
         
     | 
| 716 | 
         
            +
                    difflines = self.diff.split("\n")
         
     | 
| 717 | 
         
            +
                    first_line = difflines[0]
         
     | 
| 718 | 
         
            +
                    difflines = difflines[1:]
         
     | 
| 719 | 
         
            +
                    difflines = [line for line in difflines if line != r""]
         
     | 
| 720 | 
         
            +
                    regex = r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@"
         
     | 
| 721 | 
         
            +
                    matchres = re.match(regex, first_line)
         
     | 
| 722 | 
         
            +
                    if matchres:
         
     | 
| 723 | 
         
            +
                        startline, rangelen, startpos, endpos = matchres.groups()
         
     | 
| 724 | 
         
            +
                        self.avail = True
         
     | 
| 725 | 
         
            +
                    else:
         
     | 
| 726 | 
         
            +
                        self.avail = False
         
     | 
| 727 | 
         
            +
                        return
         
     | 
| 728 | 
         
            +
                    startline, rangelen = int(startline) - 1, int(rangelen)
         
     | 
| 729 | 
         
            +
                    endline = startline + rangelen
         
     | 
| 730 | 
         
            +
                    self.prevlines = oldflines[:startline]
         
     | 
| 731 | 
         
            +
                    self.afterlines = oldflines[endline:]
         
     | 
| 732 | 
         
            +
                    for line in difflines:
         
     | 
| 733 | 
         
            +
                        if line.startswith("-"):
         
     | 
| 734 | 
         
            +
                            self.lines.append(line[1:])
         
     | 
| 735 | 
         
            +
                            self.labels.append(0)
         
     | 
| 736 | 
         
            +
                        elif line.startswith("+"):
         
     | 
| 737 | 
         
            +
                            self.lines.append(line[1:])
         
     | 
| 738 | 
         
            +
                            self.labels.append(1)
         
     | 
| 739 | 
         
            +
                        else:
         
     | 
| 740 | 
         
            +
                            self.lines.append(line)
         
     | 
| 741 | 
         
            +
                            self.labels.append(2)
         
     | 
| 742 | 
         
            +
                    self.prevlines = [self.remove_space_clean(line) for line in self.prevlines]
         
     | 
| 743 | 
         
            +
                    self.afterlines = [self.remove_space_clean(line) for line in self.afterlines]
         
     | 
| 744 | 
         
            +
                    self.lines = [self.remove_space_clean(line) for line in self.lines]
         
     | 
| 745 | 
         
            +
                    self.msg = self.remove_space_clean(self.msg)
         
     | 
| 746 | 
         
            +
                    self.prevlines = [line for line in self.prevlines if len(line) > 0]
         
     | 
| 747 | 
         
            +
                    self.afterlines = [line for line in self.afterlines if len(line) > 0]
         
     | 
| 748 | 
         
            +
                    # print("\n".join(self.prevlines))
         
     | 
| 749 | 
         
            +
                    # print("\n\n\n\n")
         
     | 
| 750 | 
         
            +
                    # print("\n".join(self.lines))
         
     | 
| 751 | 
         
            +
                    # print("\n\n\n\n")
         
     | 
| 752 | 
         
            +
                    # print("\n".join(self.afterlines))
         
     | 
| 753 | 
         
            +
                    # print("\n\n\n\n")
         
     | 
| 754 | 
         
            +
                    assert len(self.lines) == len(self.labels), "Not equal length in align."
         
     | 
| 755 | 
         
            +
                    topack = list(
         
     | 
| 756 | 
         
            +
                        zip(
         
     | 
| 757 | 
         
            +
                            *[
         
     | 
| 758 | 
         
            +
                                (line, label)
         
     | 
| 759 | 
         
            +
                                for line, label in zip(self.lines, self.labels)
         
     | 
| 760 | 
         
            +
                                if len(line) > 0
         
     | 
| 761 | 
         
            +
                            ]
         
     | 
| 762 | 
         
            +
                        )
         
     | 
| 763 | 
         
            +
                    )
         
     | 
| 764 | 
         
            +
                    if topack == []:
         
     | 
| 765 | 
         
            +
                        self.avail = False
         
     | 
| 766 | 
         
            +
                        return
         
     | 
| 767 | 
         
            +
                    else:
         
     | 
| 768 | 
         
            +
                        self.lines, self.labels = topack
         
     | 
| 769 | 
         
            +
                    # tuple->list, convenient for later operation
         
     | 
| 770 | 
         
            +
                    self.lines = list(self.lines)
         
     | 
| 771 | 
         
            +
                    self.labels = list(self.labels)
         
     | 
| 772 | 
         
            +
             
     | 
| 773 | 
         
            +
             
     | 
| 774 | 
         
            +
            def read_review_examples(filename, data_num=-1, tokenizer=None):
         
     | 
| 775 | 
         
            +
                """Read examples from filename."""
         
     | 
| 776 | 
         
            +
                examples = []
         
     | 
| 777 | 
         
            +
                idx = 0
         
     | 
| 778 | 
         
            +
                with open(filename) as f:
         
     | 
| 779 | 
         
            +
                    for line in f:
         
     | 
| 780 | 
         
            +
                        try:
         
     | 
| 781 | 
         
            +
                            js = json.loads(line.strip())
         
     | 
| 782 | 
         
            +
                        except:
         
     | 
| 783 | 
         
            +
                            print("Error during reading json data.")
         
     | 
| 784 | 
         
            +
                            continue
         
     | 
| 785 | 
         
            +
                        maxl = 200
         
     | 
| 786 | 
         
            +
                        if "y" not in js:
         
     | 
| 787 | 
         
            +
                            js["y"] = 0
         
     | 
| 788 | 
         
            +
                        if "msg" in js and len(js["msg"]) > 0:
         
     | 
| 789 | 
         
            +
                            js["y"] = 1
         
     | 
| 790 | 
         
            +
                        example = ReviewExample(
         
     | 
| 791 | 
         
            +
                                    idx=idx,
         
     | 
| 792 | 
         
            +
                                    oldf=js["oldf"],
         
     | 
| 793 | 
         
            +
                                    diff=js["patch"],
         
     | 
| 794 | 
         
            +
                                    msg=js["msg"] if "msg" in js else "",
         
     | 
| 795 | 
         
            +
                                    cmtid=js["cmtid"] if "cmtid" in js else "",
         
     | 
| 796 | 
         
            +
                                    max_len=maxl,
         
     | 
| 797 | 
         
            +
                                    y=js["y"]
         
     | 
| 798 | 
         
            +
                                )
         
     | 
| 799 | 
         
            +
                        if example.avail:
         
     | 
| 800 | 
         
            +
                            examples.append(example)
         
     | 
| 801 | 
         
            +
                            idx += 1
         
     | 
| 802 | 
         
            +
                            if idx == data_num:
         
     | 
| 803 | 
         
            +
                                break
         
     | 
| 804 | 
         
            +
                        else:
         
     | 
| 805 | 
         
            +
                            # print(f"Passing {idx} because of invalid diff.")
         
     | 
| 806 | 
         
            +
                            idx += 1 
         
     | 
| 807 | 
         
            +
                            if idx == data_num:
         
     | 
| 808 | 
         
            +
                                break
         
     | 
| 809 | 
         
            +
                            
         
     | 
| 810 | 
         
            +
                return examples
         
     | 
| 811 | 
         
            +
             
     | 
| 812 | 
         
            +
             
     | 
| 813 | 
         
            +
            def read_jsonl(path):
         
     | 
| 814 | 
         
            +
                data = []
         
     | 
| 815 | 
         
            +
                with open(path) as f:
         
     | 
| 816 | 
         
            +
                    for line in f:
         
     | 
| 817 | 
         
            +
                        try:
         
     | 
| 818 | 
         
            +
                            js = json.loads(line.strip())
         
     | 
| 819 | 
         
            +
                        except:
         
     | 
| 820 | 
         
            +
                            print("Error during reading json data.")
         
     | 
| 821 | 
         
            +
                            continue
         
     | 
| 822 | 
         
            +
                        data.append(js)
         
     | 
| 823 | 
         
            +
                return data
         
     |