joaco7172 commited on
Commit
52f1cd7
Β·
verified Β·
1 Parent(s): 01fbb70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -20
app.py CHANGED
@@ -23,7 +23,7 @@ base_model = AutoModelForCausalLM.from_pretrained(
23
  token=access_token,
24
  trust_remote_code=True,
25
  device_map="auto",
26
- load_in_8bit=True,
27
  offload_folder="offload/"
28
  )
29
  model = PeftModel.from_pretrained(
@@ -48,6 +48,7 @@ SYSTEM_PROMPT = "You are a seasoned stock market analyst. Your task is to list t
48
 
49
 
50
  def print_gpu_utilization():
 
51
  nvmlInit()
52
  handle = nvmlDeviceGetHandleByIndex(0)
53
  info = nvmlDeviceGetMemoryInfo(handle)
@@ -55,31 +56,37 @@ def print_gpu_utilization():
55
 
56
 
57
  def get_curday():
 
58
  return date.today().strftime("%Y-%m-%d")
59
 
60
 
61
  def n_weeks_before(date_string, n):
 
62
  date = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7*n)
 
63
  return date.strftime("%Y-%m-%d")
64
 
65
 
66
  def get_stock_data(stock_symbol, steps):
 
67
  stock_data = yf.download(stock_symbol, steps[0], steps[-1])
68
  if len(stock_data) == 0:
69
  raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
70
 
 
 
71
  dates, prices = [], []
72
- available_dates = stock_data.index.astype(str).tolist()
73
 
74
  for date in steps[:-1]:
75
  for i in range(len(stock_data)):
76
  if available_dates[i] >= date:
77
- prices.append(stock_data['Close'].iloc[i])
78
  dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
79
  break
80
 
81
  dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
82
- prices.append(stock_data['Close'].iloc[-1])
83
 
84
  return pd.DataFrame({
85
  "Start Date": dates[:-1], "End Date": dates[1:],
@@ -88,12 +95,14 @@ def get_stock_data(stock_symbol, steps):
88
 
89
 
90
  def get_news(symbol, data):
 
91
  news_list = []
92
 
93
- for _, row in data.iterrows():
94
  start_date = row['Start Date'].strftime('%Y-%m-%d')
95
  end_date = row['End Date'].strftime('%Y-%m-%d')
96
- time.sleep(1) # control qpm
 
97
  weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
98
  if len(weekly_news) == 0:
99
  raise gr.Error(f"No company news found for symbol {symbol} from finnhub!")
@@ -105,7 +114,7 @@ def get_news(symbol, data):
105
  } for n in weekly_news
106
  ]
107
  weekly_news.sort(key=lambda x: x['date'])
108
- news_list.append(weekly_news)
109
 
110
  data['News'] = news_list
111
 
@@ -113,6 +122,7 @@ def get_news(symbol, data):
113
 
114
 
115
  def get_company_prompt(symbol):
 
116
  profile = finnhub_client.company_profile2(symbol=symbol)
117
  if not profile:
118
  raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
@@ -126,14 +136,15 @@ def get_company_prompt(symbol):
126
 
127
 
128
  def get_prompt_by_row(symbol, row):
 
129
  start_date = row['Start Date'] if isinstance(row['Start Date'], str) else row['Start Date'].strftime('%Y-%m-%d')
130
  end_date = row['End Date'] if isinstance(row['End Date'], str) else row['End Date'].strftime('%Y-%m-%d')
131
  term = 'increased' if row['End Price'] > row['Start Price'] else 'decreased'
132
  head = "From {} to {}, {}'s stock price {} from {:.2f} to {:.2f}. Company news during this period are listed below:\n\n".format(
133
  start_date, end_date, symbol, term, row['Start Price'], row['End Price'])
134
 
135
- news = row["News"] if isinstance(row["News"], list) else json.loads(row["News"])
136
- news_formatted = ["[Headline]: {}\n[Summary]: {}\n".format(
137
  n['headline'], n['summary']) for n in news if n['date'][:8] <= end_date.replace('-', '') and \
138
  not n['summary'].startswith("Looking for stock market analysis and research with proves results?")]
139
 
@@ -144,21 +155,20 @@ def get_prompt_by_row(symbol, row):
144
  else:
145
  basics = "[Basic Financials]:\n\nNo basic financial reported."
146
 
147
- return head, news_formatted, basics
148
 
149
 
150
  def sample_news(news, k=5):
 
151
  return [news[i] for i in sorted(random.sample(range(len(news)), k))]
152
-
153
-
154
  def latest_news(news, k=5):
155
- if not isinstance(news, list) or not all(isinstance(item, dict) for item in news):
156
- raise ValueError("News must be a list of dictionaries.")
157
  sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
158
  return sorted_news[:k]
159
 
160
 
161
  def get_current_basics(symbol, curday):
 
162
  basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
163
  if not basic_financials['series']:
164
  raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")
@@ -183,6 +193,7 @@ def get_current_basics(symbol, curday):
183
 
184
 
185
  def get_all_prompts_online(symbol, data, curday, with_basics=True):
 
186
  company_prompt = get_company_prompt(symbol)
187
 
188
  prev_rows = []
@@ -194,7 +205,10 @@ def get_all_prompts_online(symbol, data, curday, with_basics=True):
194
  prompt = ""
195
  for i in range(-len(prev_rows), 0):
196
  prompt += "\n" + prev_rows[i][0]
197
- latest_news_items = latest_news(prev_rows[i][1], min(5, len(prev_rows[i][1])))
 
 
 
198
  if latest_news_items:
199
  prompt += "\n".join(latest_news_items)
200
  else:
@@ -216,7 +230,9 @@ def get_all_prompts_online(symbol, data, curday, with_basics=True):
216
  return info, prompt
217
 
218
 
 
219
  def construct_prompt(ticker, curday, n_weeks, use_basics):
 
220
  try:
221
  steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
222
  except Exception:
@@ -225,26 +241,31 @@ def construct_prompt(ticker, curday, n_weeks, use_basics):
225
  data = get_stock_data(ticker, steps)
226
  data = get_news(ticker, data)
227
  data['Basics'] = [json.dumps({})] * len(data)
 
228
 
229
  info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
230
 
231
  prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
 
232
 
233
  return info, prompt
234
 
235
 
236
  def predict(ticker, date, n_weeks, use_basics):
 
237
  print_gpu_utilization()
238
 
239
  info, prompt = construct_prompt(ticker, date, n_weeks, use_basics)
240
 
241
- inputs = tokenizer(prompt, return_tensors='pt', padding=False)
 
 
242
  inputs = {key: value.to(model.device) for key, value in inputs.items()}
243
 
244
  print("Inputs loaded onto devices.")
245
 
246
  res = model.generate(
247
- **inputs, max_length=4096, do_sample=False,
248
  eos_token_id=tokenizer.eos_token_id,
249
  use_cache=True, streamer=streamer
250
  )
@@ -291,9 +312,14 @@ demo = gr.Interface(
291
  label="Response"
292
  )
293
  ],
294
- title="Pro Capital",
295
- description="""Implementation**
 
 
 
 
 
296
  """
297
  )
298
 
299
- demo.launch()
 
23
  token=access_token,
24
  trust_remote_code=True,
25
  device_map="auto",
26
+ torch_dtype=torch.float16,
27
  offload_folder="offload/"
28
  )
29
  model = PeftModel.from_pretrained(
 
48
 
49
 
50
  def print_gpu_utilization():
51
+
52
  nvmlInit()
53
  handle = nvmlDeviceGetHandleByIndex(0)
54
  info = nvmlDeviceGetMemoryInfo(handle)
 
56
 
57
 
58
  def get_curday():
59
+
60
  return date.today().strftime("%Y-%m-%d")
61
 
62
 
63
  def n_weeks_before(date_string, n):
64
+
65
  date = datetime.strptime(date_string, "%Y-%m-%d") - timedelta(days=7*n)
66
+
67
  return date.strftime("%Y-%m-%d")
68
 
69
 
70
  def get_stock_data(stock_symbol, steps):
71
+
72
  stock_data = yf.download(stock_symbol, steps[0], steps[-1])
73
  if len(stock_data) == 0:
74
  raise gr.Error(f"Failed to download stock price data for symbol {stock_symbol} from yfinance!")
75
 
76
+ # print(stock_data)
77
+
78
  dates, prices = [], []
79
+ available_dates = stock_data.index.format()
80
 
81
  for date in steps[:-1]:
82
  for i in range(len(stock_data)):
83
  if available_dates[i] >= date:
84
+ prices.append(stock_data['Close'][i])
85
  dates.append(datetime.strptime(available_dates[i], "%Y-%m-%d"))
86
  break
87
 
88
  dates.append(datetime.strptime(available_dates[-1], "%Y-%m-%d"))
89
+ prices.append(stock_data['Close'][-1])
90
 
91
  return pd.DataFrame({
92
  "Start Date": dates[:-1], "End Date": dates[1:],
 
95
 
96
 
97
  def get_news(symbol, data):
98
+
99
  news_list = []
100
 
101
+ for end_date, row in data.iterrows():
102
  start_date = row['Start Date'].strftime('%Y-%m-%d')
103
  end_date = row['End Date'].strftime('%Y-%m-%d')
104
+ # print(symbol, ': ', start_date, ' - ', end_date)
105
+ time.sleep(1) # control qpm
106
  weekly_news = finnhub_client.company_news(symbol, _from=start_date, to=end_date)
107
  if len(weekly_news) == 0:
108
  raise gr.Error(f"No company news found for symbol {symbol} from finnhub!")
 
114
  } for n in weekly_news
115
  ]
116
  weekly_news.sort(key=lambda x: x['date'])
117
+ news_list.append(json.dumps(weekly_news))
118
 
119
  data['News'] = news_list
120
 
 
122
 
123
 
124
  def get_company_prompt(symbol):
125
+
126
  profile = finnhub_client.company_profile2(symbol=symbol)
127
  if not profile:
128
  raise gr.Error(f"Failed to find company profile for symbol {symbol} from finnhub!")
 
136
 
137
 
138
  def get_prompt_by_row(symbol, row):
139
+
140
  start_date = row['Start Date'] if isinstance(row['Start Date'], str) else row['Start Date'].strftime('%Y-%m-%d')
141
  end_date = row['End Date'] if isinstance(row['End Date'], str) else row['End Date'].strftime('%Y-%m-%d')
142
  term = 'increased' if row['End Price'] > row['Start Price'] else 'decreased'
143
  head = "From {} to {}, {}'s stock price {} from {:.2f} to {:.2f}. Company news during this period are listed below:\n\n".format(
144
  start_date, end_date, symbol, term, row['Start Price'], row['End Price'])
145
 
146
+ news = json.loads(row["News"])
147
+ news = ["[Headline]: {}\n[Summary]: {}\n".format(
148
  n['headline'], n['summary']) for n in news if n['date'][:8] <= end_date.replace('-', '') and \
149
  not n['summary'].startswith("Looking for stock market analysis and research with proves results?")]
150
 
 
155
  else:
156
  basics = "[Basic Financials]:\n\nNo basic financial reported."
157
 
158
+ return head, news, basics
159
 
160
 
161
  def sample_news(news, k=5):
162
+
163
  return [news[i] for i in sorted(random.sample(range(len(news)), k))]
164
+
 
165
  def latest_news(news, k=5):
 
 
166
  sorted_news = sorted(news, key=lambda x: x['date'], reverse=True)
167
  return sorted_news[:k]
168
 
169
 
170
  def get_current_basics(symbol, curday):
171
+
172
  basic_financials = finnhub_client.company_basic_financials(symbol, 'all')
173
  if not basic_financials['series']:
174
  raise gr.Error(f"Failed to find basic financials for symbol {symbol} from finnhub!")
 
193
 
194
 
195
  def get_all_prompts_online(symbol, data, curday, with_basics=True):
196
+
197
  company_prompt = get_company_prompt(symbol)
198
 
199
  prev_rows = []
 
205
  prompt = ""
206
  for i in range(-len(prev_rows), 0):
207
  prompt += "\n" + prev_rows[i][0]
208
+ latest_news_items = latest_news(
209
+ prev_rows[i][1],
210
+ min(5, len(prev_rows[i][1]))
211
+ )
212
  if latest_news_items:
213
  prompt += "\n".join(latest_news_items)
214
  else:
 
230
  return info, prompt
231
 
232
 
233
+
234
  def construct_prompt(ticker, curday, n_weeks, use_basics):
235
+
236
  try:
237
  steps = [n_weeks_before(curday, n) for n in range(n_weeks + 1)][::-1]
238
  except Exception:
 
241
  data = get_stock_data(ticker, steps)
242
  data = get_news(ticker, data)
243
  data['Basics'] = [json.dumps({})] * len(data)
244
+ # print(data)
245
 
246
  info, prompt = get_all_prompts_online(ticker, data, curday, use_basics)
247
 
248
  prompt = B_INST + B_SYS + SYSTEM_PROMPT + E_SYS + prompt + E_INST
249
+ # print(prompt)
250
 
251
  return info, prompt
252
 
253
 
254
  def predict(ticker, date, n_weeks, use_basics):
255
+
256
  print_gpu_utilization()
257
 
258
  info, prompt = construct_prompt(ticker, date, n_weeks, use_basics)
259
 
260
+ inputs = tokenizer(
261
+ prompt, return_tensors='pt', padding=False
262
+ )
263
  inputs = {key: value.to(model.device) for key, value in inputs.items()}
264
 
265
  print("Inputs loaded onto devices.")
266
 
267
  res = model.generate(
268
+ **inputs, max_length=4096, do_sample=True,
269
  eos_token_id=tokenizer.eos_token_id,
270
  use_cache=True, streamer=streamer
271
  )
 
312
  label="Response"
313
  )
314
  ],
315
+ title="FinGPT-Forecaster",
316
+ description="""FinGPT-Forecaster takes random market news and optional basic financials related to the specified company from the past few weeks as input and responds with the company's **positive developments** and **potential concerns**. Then it gives out a **prediction** of stock price movement for the coming week and its **analysis** summary.
317
+ This model is finetuned on Llama2-7b-chat-hf with LoRA on the past year's DOW30 market data. Inference in this demo uses fp16 and **welcomes any ticker symbol**.
318
+ Company profile & Market news & Basic financials & Stock prices are retrieved using **yfinance & finnhub**.
319
+ This is just a demo showing what this model is capable of. Results inferred from randomly chosen news can be strongly biased.
320
+ For more detailed and customized implementation, refer to our FinGPT project: <https://github.com/AI4Finance-Foundation/FinGPT>
321
+ **Disclaimer: Nothing herein is financial advice, and NOT a recommendation to trade real money. Please use common sense and always first consult a professional before trading or investing.**
322
  """
323
  )
324
 
325
+ demo.launch()