Dan Flower committed on
Commit
53d10a6
·
1 Parent(s): d7e8c05

fix(model): Python 3.9 typing; move model_runner to model/; centralize flags in utils/

Browse files
Files changed (1) hide show
  1. model/model_runner.py +0 -101
model/model_runner.py DELETED
@@ -1,101 +0,0 @@
1
- # model_runner.py
2
- import os
3
- import sys
4
- from typing import List
5
- from llama_cpp import Llama
6
-
7
# ---- Phase 2: flags (no behavior change) ------------------------------------
# Reads LAB_* env toggles; all defaults preserve current behavior.
try:
    from TemplateA.utils import flags  # if your package path is different, adjust import
except Exception:
    # Fallback inline flags if template.utils.flags isn't available in this lab.
    # NOTE: use typing.Optional instead of `str | None` -- PEP 604 union syntax
    # in a signature is evaluated at `def` time and raises TypeError on
    # Python 3.9 (this module otherwise targets 3.9: see `typing.List` above).
    from typing import Optional

    def _as_bool(val: Optional[str], default: bool) -> bool:
        """Parse a truthy env-var string; return *default* when unset (None)."""
        if val is None:
            return default
        return val.strip().lower() in {"1", "true", "yes", "on", "y", "t"}

    class _F:
        """Inline stand-in exposing the same flag attributes as the package module."""
        SANITIZE_ENABLED = _as_bool(os.getenv("LAB_SANITIZE_ENABLED"), False)  # you don't sanitize today
        STOPSEQ_ENABLED = _as_bool(os.getenv("LAB_STOPSEQ_ENABLED"), False)  # extra stops only; defaults off
        CRITIC_ENABLED = _as_bool(os.getenv("LAB_CRITIC_ENABLED"), False)
        JSON_MODE = _as_bool(os.getenv("LAB_JSON_MODE"), False)
        EVIDENCE_GATE = _as_bool(os.getenv("LAB_EVIDENCE_GATE"), False)

        @staticmethod
        def snapshot():
            """Return the current flag values as a dict, for logging."""
            return {
                "LAB_SANITIZE_ENABLED": _F.SANITIZE_ENABLED,
                "LAB_STOPSEQ_ENABLED": _F.STOPSEQ_ENABLED,
                "LAB_CRITIC_ENABLED": _F.CRITIC_ENABLED,
                "LAB_JSON_MODE": _F.JSON_MODE,
                "LAB_EVIDENCE_GATE": _F.EVIDENCE_GATE,
            }

    flags = _F()

# Log the effective flag values once at import time (stderr, so stdout stays clean).
print("[flags] snapshot:", getattr(flags, "snapshot", lambda: {} )(), file=sys.stderr)
35
-
36
# Optional sanitizer hook (kept no-op unless enabled later)
def _sanitize(text: str) -> str:
    """Return *text* untouched unless the sanitize flag is on (then strip it)."""
    # Phase 2: default False -> no behavior change
    if not getattr(flags, "SANITIZE_ENABLED", False):
        return text
    # TODO: wire your real sanitizer in Phase 3+
    return text.strip()
43
-
44
# Stop sequences: keep today's defaults ALWAYS.
# If LAB_STOPSEQ_ENABLED=true, add *extra* stops from STOP_SEQUENCES env (comma-separated).
DEFAULT_STOPS: List[str] = ["\nUser:", "\nAssistant:"]

def _extra_stops_from_env() -> List[str]:
    """Extra stop strings parsed from STOP_SEQUENCES, only when the flag is on."""
    if not getattr(flags, "STOPSEQ_ENABLED", False):
        return []
    raw = os.getenv("STOP_SEQUENCES", "")
    return [tok.strip() for tok in raw.split(",") if tok.strip()]
54
-
55
# ---- Model cache / load ------------------------------------------------------
_model = None  # module-level cache

def load_model():
    """Load the GGUF model once and return the cached instance thereafter.

    Raises:
        ValueError: if MODEL_PATH is unset or does not point at an existing file.
    """
    global _model
    if _model is None:
        path = os.getenv("MODEL_PATH")
        if not path or not os.path.exists(path):
            raise ValueError(f"Model path does not exist or is not set: {path}")

        print(f"[INFO] Loading model from {path}")
        _model = Llama(
            model_path=path,
            n_ctx=1024,      # short context to reduce memory use
            n_threads=4,     # number of CPU threads
            n_gpu_layers=0,  # CPU only (Hugging Face free tier)
        )
    return _model
76
-
77
# ---- Inference ---------------------------------------------------------------
def generate(prompt: str, max_tokens: int = 256) -> str:
    """Run one completion against the cached model and return the cleaned text."""
    llm = load_model()

    # Preserve existing default stops; optionally extend with extra ones if flag is on
    stop_list = DEFAULT_STOPS + _extra_stops_from_env()

    result = llm(
        prompt,
        max_tokens=max_tokens,
        stop=stop_list,  # unchanged defaults; may include extra stops if enabled
        echo=False,
        temperature=0.7,
        top_p=0.95,
    )
    text = result["choices"][0]["text"]

    # Preserve current manual truncation by the same default stops (kept intentionally)
    # Extra stops are also applied here if enabled for consistency.
    for marker in stop_list:
        if marker and marker in text:
            text = text.split(marker)[0]

    return _sanitize(text).strip()