calvini committed (verified)
Commit 2f4c052 · 1 Parent(s): 7aac7aa

Upload 10 files

Causal Language Classifier

all_thresholds.json ADDED
@@ -0,0 +1,20 @@
+{
+  "f1": {
+    "threshold": 0.3500000000000001,
+    "metrics": {
+      "accuracy": 0.8336666666666667,
+      "precision": 0.7284546805349182,
+      "recall": 0.88016157989228,
+      "f1": 0.7971544715447154
+    }
+  },
+  "balanced": {
+    "threshold": 0.5500000000000002,
+    "metrics": {
+      "accuracy": 0.847,
+      "precision": 0.7971869328493648,
+      "recall": 0.7885996409335727,
+      "f1": 0.7928700361010831
+    }
+  }
+}
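
all_thresholds.json stores two operating points for the classifier: an F1-optimized threshold (≈0.35, trading some precision for recall) and a balanced one (≈0.55, where precision and recall are nearly equal). As a minimal sketch, assuming the file sits in the causal-classifier/ model directory used by the prediction script below, a downstream consumer could select an operating point like this:

import json

# Load both saved operating points and pick one by name.
with open("causal-classifier/all_thresholds.json") as f:
    thresholds = json.load(f)

# "balanced" equalizes precision and recall; use "f1" to favor recall instead.
threshold = thresholds["balanced"]["threshold"]  # ≈0.55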
config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "_name_or_path": "KM4STfulltext/SSCI-BERT-e4",
+  "architectures": [
+    "BertForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "classifier_dropout": null,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "problem_type": "single_label_classification",
+  "torch_dtype": "float32",
+  "transformers_version": "4.48.0",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 28996
+}
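
The config describes a standard BERT-base setup (12 layers, 12 heads, hidden size 768, 28996-token vocabulary) fine-tuned from KM4STfulltext/SSCI-BERT-e4 for single-label sequence classification. Since no id2label mapping is stored, transformers falls back to two generic labels. A quick sketch for inspecting this without loading the weights, assuming the local causal-classifier/ directory:

from transformers import AutoConfig

# Reads only config.json; no weights are downloaded or loaded.
config = AutoConfig.from_pretrained("causal-classifier")
print(config.model_type, config.num_hidden_layers, config.num_labels)  # bert 12 2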
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d83ffc7765f5797aefd4327502bda64752563cd0ddd2c8ed025f2902e875548f
+size 433270768
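
This entry is a Git LFS pointer, not the weights themselves: the actual ~433 MB safetensors file lives in LFS storage, and a plain clone without LFS yields only this three-line stub. A hedged sketch for fetching the real file with huggingface_hub (the repo id below is a placeholder, since the page does not show it):

from huggingface_hub import hf_hub_download

# repo_id is hypothetical; substitute the actual Hub repository.
path = hf_hub_download(repo_id="calvini/causal-classifier", filename="model.safetensors")
print(path)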
predict_ft_model_causal.py ADDED
@@ -0,0 +1,94 @@
+import json
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+def predict_with_threshold(text, model_dir="causal-classifier", threshold=None, threshold_type=None):
+    """
+    Make a prediction using the saved model and a custom decision threshold.
+
+    Args:
+        text: Input text to classify
+        model_dir: Directory where the model and threshold files are saved
+        threshold: Custom threshold to use (if None, loads from saved config)
+        threshold_type: Which entry of all_thresholds.json to use ('f1' or 'balanced')
+
+    Returns:
+        Dictionary with prediction results
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+
+    threshold_source = "custom"
+
+    # Resolve the threshold if the caller did not supply one.
+    if threshold is None:
+        if threshold_type is not None:
+            # First try all_thresholds.json when a threshold type is specified.
+            try:
+                with open(f"{model_dir}/all_thresholds.json", "r") as f:
+                    all_thresholds = json.load(f)
+                if threshold_type in all_thresholds:
+                    threshold = all_thresholds[threshold_type]["threshold"]
+                    threshold_source = f"all_thresholds.json ({threshold_type})"
+                else:
+                    print(f"Threshold type '{threshold_type}' not found. Available types: {list(all_thresholds.keys())}")
+            except FileNotFoundError:
+                pass
+
+        # If still unresolved, fall back to threshold_config.json.
+        if threshold is None:
+            try:
+                with open(f"{model_dir}/threshold_config.json", "r") as f:
+                    config = json.load(f)
+                threshold = config["threshold"]
+                threshold_source = "threshold_config.json"
+            except FileNotFoundError:
+                # Default to 0.5 if no threshold configuration is found.
+                threshold = 0.5
+                threshold_source = "default"
+                print("No threshold configuration found. Using default threshold of 0.5.")
+
+    # Tokenize the input text.
+    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
+
+    # Get the model prediction.
+    model.eval()
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Convert logits to probabilities.
+    probs = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze().tolist()
+
+    # Apply the threshold to get the final prediction.
+    if isinstance(probs, list):
+        # Two-class case: probs is [p(Descriptive), p(Causal)] for one example.
+        prediction = 1 if probs[1] > threshold else 0
+        class_probs = {
+            "Causal": probs[1],
+            "Descriptive": probs[0]
+        }
+    else:
+        # Fallback: squeeze() yielded a scalar (single-logit model).
+        prediction = 1 if probs > threshold else 0
+        class_probs = {
+            "Causal": probs,
+            "Descriptive": 1 - probs
+        }
+
+    # Map the numeric prediction back to its label name.
+    label_names = {1: "Causal", 0: "Descriptive"}
+
+    return {
+        "prediction": label_names[prediction],
+        "probabilities": class_probs,
+        "threshold_used": threshold,
+        "threshold_source": threshold_source
+    }
+
+
+if __name__ == "__main__":
+    sample_text = 'This is a causal study that aims to investigate the relationship between smoking and lung cancer.'
+    result = predict_with_threshold(sample_text)
+    print(result)
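
As a usage sketch for the function above (the input sentence is illustrative, not from the repository's data), selecting the F1-optimized operating point explicitly:

from predict_ft_model_causal import predict_with_threshold

result = predict_with_threshold(
    "Smoking increases the risk of lung cancer in longitudinal cohorts.",
    model_dir="causal-classifier",
    threshold_type="f1",
)
print(result["prediction"], result["threshold_used"], result["threshold_source"])
# e.g. Causal 0.3500000000000001 all_thresholds.json (f1)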
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+{
+  "cls_token": "[CLS]",
+  "mask_token": "[MASK]",
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "unk_token": "[UNK]"
+}
threshold_config.json ADDED
@@ -0,0 +1,10 @@
+{
+  "threshold_type": "balanced",
+  "threshold": 0.5500000000000002,
+  "metrics": {
+    "accuracy": 0.847,
+    "precision": 0.7971869328493648,
+    "recall": 0.7885996409335727,
+    "f1": 0.7928700361010831
+  }
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "[PAD]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "100": {
+      "content": "[UNK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "101": {
+      "content": "[CLS]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "102": {
+      "content": "[SEP]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "103": {
+      "content": "[MASK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": true,
+  "extra_special_tokens": {},
+  "mask_token": "[MASK]",
+  "model_max_length": 1000000000000000019884624838656,
+  "never_split": null,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "unk_token": "[UNK]"
+}
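
Note the model_max_length sentinel (a 10^18-scale value): this tokenizer config carries no real length limit, so truncation must be capped explicitly at the model's positional capacity (max_position_embeddings = 512 in config.json), exactly as the prediction script does. A small sketch of that pattern, assuming the local causal-classifier/ directory:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("causal-classifier")
text = "Example sentence to encode."
# Clamp to BERT's 512-position limit; model_max_length here is only a sentinel.
inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")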
training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4159057d9ab821b7ae87bf1469f0a43b53f84ca0c9b6a7a3b71f8376df8a4ea1
+size 5304
vocab.txt ADDED
The diff for this file is too large to render. See raw diff