2084: DeepStock - Can you train DeepSeek to do stock trading?
A short guide to using DeepSeek-style reinforcement learning for stock trading.
Recently I've been fascinated by DeepSeek's breakthrough in training language models without human annotation. What made it particularly interesting was their use of objective evaluation criteria - if the math is right, it's right, no humans needed to judge.
This got me thinking: Stock prices are also pretty objective - they either go up or down. Could we use similar reinforcement learning techniques to teach a model about stock trading? Let's find out by building DeepStock, a DeepSeek-inspired model for trading stocks.
Building the Dataset
The core idea is simple: Given information about a stock, predict whether to buy or sell. The reward? Simply whether we were right about the price movement. No human judgments needed - just pure market performance.
I decided to focus on the S&P 500 stocks, collecting the following data points for each trading day:
News headlines from the past week
Price movements over the past week
Opening and closing prices
Annual financials (via yfinance)
Company description and name
The relevant code is below (there's also some other code scattered around the Colab):
```python
%%writefile create_deepstock_dataset.py
import abc
import json
import os
import pickle
from datetime import date, datetime, timedelta
from typing import Dict, List, Optional, Tuple

import holidays
import modal
import pandas as pd
import yfinance as yf
from datasets import Dataset, load_dataset
from pydantic import BaseModel
from tqdm.auto import tqdm

# Directory holding the pickled news/price caches (baked into the Modal image below)
CACHE_VOLUME = "/cache"
# Select country for the holiday calendar
us_holidays = holidays.US()

# REMEMBER TO SET NEWSAPI_SECRET in your environment variables!

OLDEST_POSSIBLE_DATE = "2021-06-30"
NEWEST_POSSIBLE_DATE = "2023-12-31"


def get_business_days(start_date="2021-06-30", end_date=None):
    """
    Generate a list of business days between start_date and end_date (inclusive).
    If end_date is not provided, uses current date.

    Args:
        start_date (str): Start date in 'YYYY-MM-DD' format
        end_date (str, optional): End date in 'YYYY-MM-DD' format

    Returns:
        list: pandas Timestamps for every Mon-Fri in the range (holidays are not removed)
    """
    if end_date is None:
        end_date = datetime.now().strftime("%Y-%m-%d")
    # Convert string dates to datetime objects
    start = pd.to_datetime(start_date)
    end = pd.to_datetime(end_date)
    # Generate business days using pandas
    business_days = pd.date_range(start=start, end=end, freq="B")
    return business_days.tolist()


class FinancialsT(BaseModel):
    year: date
    financials: str


class PriceT(BaseModel):
    open: float
    close: float
    price_date: date
    open_previous: float
    close_previous: float
    previous_date: date


class CompanyInfoT(BaseModel):
    name: str
    description: str


class NewsT(BaseModel):
    news_headlines: List[str]
    news_date: date


class CompanyInfoAtDate(BaseModel):
    ticker: str
    current_date: date
    company_info: CompanyInfoT
    news: NewsT
    financials: FinancialsT
    price: PriceT


class AbstractCompanyInfoCreator:
    @abc.abstractmethod
    def fetch_company_info(self, ticker: str, current_date: date) -> CompanyInfoAtDate:
        pass


def format_datetime(newsdate: date):
    return newsdate.strftime("%Y-%m-%d")


class NewsDatabase:
    def __init__(self, start_date: date, end_date: date):
        self.ds = None
        self.cache = {}
        self.cache_file = os.path.join(CACHE_VOLUME, "news_cache_7.pkl")
        self._load_cache(start_date, end_date)

    def _load_cache(self, start_date: date, end_date: date):
        """Load cache from disk if it exists, otherwise build it from the dataset"""
        try:
            with open(self.cache_file, "rb") as f:
                self.cache = pickle.load(f)
        except FileNotFoundError:
            self.cache = {}
            self.preprocess_date_range(start_date, end_date)

    def _save_cache(self):
        """Save cache to disk"""
        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache, f)

    def preprocess_date_range(self, start_date: date, end_date: date):
        """
        Preprocess and cache news headlines for all stocks between start_date and end_date.

        Args:
            start_date (date): Start date for preprocessing
            end_date (date): End date for preprocessing
        """
        print(os.listdir(CACHE_VOLUME), os.listdir("/"))
        self.ds = load_dataset("2084Collective/FNSPID_IMPROVED", split="train").to_pandas()
        # Convert dates to string format for comparison with dataset
        start_str = format_datetime(start_date)
        end_str = format_datetime(end_date)
        # Filter dataset for date range
        date_filtered = self.ds[(self.ds["date"] >= start_str) & (self.ds["date"] <= end_str)]
        # Group headlines by (date, stock)
        grouped = date_filtered.groupby(["date", "stock"])["title"].apply(list).to_dict()
        # Update cache: cache[YYYY-MM-DD][stock] -> list of headlines
        for (date_str, stock), headlines in tqdm(grouped.items()):
            date_str = date_str[:10]
            if date_str not in self.cache:
                print("not in cache")
                self.cache[date_str] = {}
            self.cache[date_str][stock] = headlines
        # Save cache to disk
        self._save_cache()

    def fetch_news_for_date(self, newsdate: date, stock: str, company_info: CompanyInfoT) -> NewsT:
        """
        Fetch news for the 7 days before newsdate for a stock, using the cache.

        Args:
            newsdate (date): Date to fetch news for
            stock (str): Stock symbol
            company_info (CompanyInfoT): Company information

        Returns:
            NewsT: News headlines and the start date of the window
        """
        seven_days_ago = newsdate - timedelta(days=7)
        headlines = []
        # Collect cached headlines for each day in [newsdate - 7, newsdate)
        current_date = seven_days_ago
        while current_date < newsdate:
            date_str = format_datetime(current_date)
            # Days with no cached headlines are simply skipped
            if date_str in self.cache and stock in self.cache[date_str]:
                headlines.extend(self.cache[date_str][stock])
            current_date += timedelta(days=1)
        return NewsT(news_headlines=headlines, news_date=seven_days_ago)

    def summary(self) -> dict:
        return {
            "max_date": self.ds["date"].max(),
            "min_date": self.ds["date"].min(),
            "stock_count": self.ds["stock"].nunique(),
        }


class PriceOpenPriceCloseDatabase:
    CACHE_FILE = os.path.join(CACHE_VOLUME, "price_cache.pkl")

    def __init__(self):
        self.ds = None
        self.cache = self._load_or_create_cache()

    def _load_or_create_cache(self):
        if os.path.exists(self.CACHE_FILE):
            # Load existing cache
            with open(self.CACHE_FILE, "rb") as f:
                return pickle.load(f)
        # Create and save new cache
        cache = self._preprocess_data()
        with open(self.CACHE_FILE, "wb") as f:
            pickle.dump(cache, f)
        return cache

    def _preprocess_data(self):
        from collections import defaultdict

        self.ds = load_dataset(
            "2084Collective/deepstock-stock-historical-prices-dataset-processed",
            split="train",
        ).to_pandas()
        print("Creating new cache...")
        # cache[YYYY-MM-DD][stock] -> {"open": ..., "close": ...}
        cache = defaultdict(dict)
        for _, row in tqdm(self.ds.iterrows(), total=13900000):
            cache[row["date"]][row["stock"]] = {"open": row["open"], "close": row["close"]}
        print("Cache creation complete")
        return cache

    def fetch_open_close_for_date(self, price_date: date, stock: str) -> PriceT:
        seven_days_ago = price_date - timedelta(days=7)
        current_data = self.get_stock_price(stock, price_date)
        assert current_data is not None, f"Could not fetch data for {stock} on {price_date}"
        # Step back one day at a time until we hit a day with price data (weekend/holiday)
        while (seven_days_ago_data := self.get_stock_price(stock, seven_days_ago)) is None:
            seven_days_ago -= timedelta(days=1)
        return PriceT(
            open=current_data["open"],
            close=current_data["close"],
            price_date=price_date,
            open_previous=seven_days_ago_data["open"],
            close_previous=seven_days_ago_data["close"],
            previous_date=seven_days_ago,
        )

    def get_stock_price(self, stock: str, pricedate: date) -> Optional[Dict[str, float]]:
        date_str = format_datetime(pricedate)
        try:
            return self.cache[date_str][stock]
        except KeyError:
            # Cache miss: fall back to yfinance for that single day
            try:
                stock_data = yf.Ticker(stock).history(start=pricedate, end=pricedate + timedelta(days=1))
                print(f"Fetching {stock} data for {date_str}", stock_data)
                open_price = stock_data["Open"].iloc[0]
                close_price = stock_data["Close"].iloc[0]
                self.cache[date_str][stock] = {"open": open_price, "close": close_price}
                return self.cache[date_str][stock]
            except Exception as e:
                print(f"Error fetching {stock} data for {date_str}: {e}")
                return None


class FinancialsDatabase:
    def __init__(self):
        self.financials_cache = {}

    def fetch_financials_for_date(self, stock_date: date, stock: str) -> FinancialsT:
        if stock not in self.financials_cache:
            self.financials_cache[stock] = yf.Ticker(stock).financials
        dates = [col.date() for col in self.financials_cache[stock].columns]
        sorted_dates = sorted(dates)
        right_date = None
        # Scan sorted_dates[1:] for the first report dated after stock_date;
        # if stock_date is later than every report, fall back to the newest one
        for i in range(len(sorted_dates) - 1):
            ind = min(len(sorted_dates) - 1, i + 1)
            if sorted_dates[ind] > stock_date:
                right_date = sorted_dates[ind]
                break
        if right_date is None and stock_date > sorted_dates[-1]:
            right_date = sorted_dates[-1]
        return FinancialsT(
            financials=json.dumps(self.financials_cache[stock][right_date.strftime("%Y-%m-%d")].to_dict()),
            year=right_date,
        )


class CompanyInfoDatabase:
    def __init__(self):
        self.company_info_cache = {}

    def fetch_company_info(self, stock: str) -> CompanyInfoT:
        if stock not in self.company_info_cache:
            self.company_info_cache[stock] = yf.Ticker(stock).info
        return CompanyInfoT(
            name=self.company_info_cache[stock]["shortName"],
            description=self.company_info_cache[stock]["longBusinessSummary"],
        )


class CompanyInfoCreator(AbstractCompanyInfoCreator):
    def __init__(self, earliest_date: date, latest_date: date):
        # Pad the news window by 10 days on each side
        self.news_db = NewsDatabase(earliest_date - timedelta(days=10), latest_date + timedelta(days=10))
        self.price_db = PriceOpenPriceCloseDatabase()
        self.financials_db = FinancialsDatabase()
        self.company_info_db = CompanyInfoDatabase()

    def fetch_company_info(self, ticker: str, current_date: date) -> CompanyInfoAtDate:
        company_info = self.company_info_db.fetch_company_info(ticker)
        news = self.news_db.fetch_news_for_date(current_date, ticker, company_info)
        financials = self.financials_db.fetch_financials_for_date(current_date, ticker)
        price = self.price_db.fetch_open_close_for_date(current_date, ticker)
        return CompanyInfoAtDate(
            ticker=ticker,
            current_date=current_date,
            company_info=company_info,
            news=news,
            financials=financials,
            price=price,
        )


def get_sp500_tickers() -> List[str]:
    return pd.read_html("https://en.wikipedia.org/wiki/List_of_S%26P_500_companies")[0]["Symbol"].tolist()


def dump_company(company_info: Optional[CompanyInfoAtDate]) -> Optional[dict]:
    if company_info is None:
        return None
    return company_info.model_dump()


def process_single_stock(data):
    cic = CompanyInfoCreator(
        datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
        datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
    )
    company_info: List[CompanyInfoAtDate] = []
    for ticker, day in zip(data["ticker"], data["day"]):
        try:
            company_info.append(cic.fetch_company_info(ticker, day.date()))
        except Exception as e:
            print(e)
            print(ticker, day)
            company_info.append(None)
    return {"company_info": [dump_company(ci) for ci in company_info]}


# Build the local caches once so they can be copied into the Modal image
if __name__ == "__main__" and not (os.path.exists("price_cache.pkl") and os.path.exists("news_cache_7.pkl")):
    NewsDatabase(
        datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
        datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
    )
    PriceOpenPriceCloseDatabase()

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git")
    .pip_install("torch==2.2.1")
    .pip_install([
        "datasets",
        "yfinance",
        "pandas",
        "requests",
        "pydantic",
        "tqdm",
        "holidays",
        "modal",
        "numpy",
        "transformers",
        "huggingface_hub",
    ])
    .run_commands("mkdir /cache")
    .add_local_file("price_cache.pkl", remote_path="/cache/price_cache.pkl")
    .add_local_file("news_cache_7.pkl", remote_path="/cache/news_cache_7.pkl")
)
app = modal.App(name="deepstock", image=image)


@app.function(timeout=2000)
def get_company_info(ticker: str) -> Tuple[List[CompanyInfoAtDate], str]:
    cic = CompanyInfoCreator(
        datetime.strptime(OLDEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
        datetime.strptime(NEWEST_POSSIBLE_DATE, "%Y-%m-%d").date(),
    )
    company_info: List[CompanyInfoAtDate] = []
    for day in get_business_days(OLDEST_POSSIBLE_DATE, NEWEST_POSSIBLE_DATE):
        try:
            company_info.append(cic.fetch_company_info(ticker, day.date()))
        except Exception as e:
            # Missing data (e.g. holidays) is recorded as None
            print(e)
            print(ticker, day.date())
            company_info.append(None)
    return company_info, ticker


@app.local_entrypoint()
def main():
    tickers = get_sp500_tickers()
    # Drop tickers that were not listed for the whole date range
    for late_listing in ["KVUE", "CEG", "VLTO", "GEHC"]:
        tickers.remove(late_listing)
    company_info_info = {}
    # Fan out one Modal container per ticker
    for company_info_dates, ticker in get_company_info.map(tickers):
        company_info_info[ticker] = company_info_dates
    with open("company_info.pkl", "wb") as f:
        pickle.dump(company_info_info, f)
    dataset = []
    count_none = 0
    count_total = 0
    for ticker, company_info_dates in company_info_info.items():
        for company_info in company_info_dates:
            count_total += 1
            if company_info is None:
                count_none += 1
                continue
            dataset.append({"ticker": ticker, "company_info": company_info.model_dump()})
    print(f"Total number of data points: {count_total}")
    print(f"Number of missing data points: {count_none}")
    dataset = Dataset.from_list(dataset)
    dataset.push_to_hub("2084Collective/deepstock-sp500-companies-with-info")
```
Here's how I formatted this data into prompts:
You are a seasoned stock market analyst who is trying to predict whether the prices will go down or up over the day, 2021-06-30, for a specific stock, by offering a buy or sell rating.
[Company Name]
3M Company
[Company Description]
3M Company provides diversified technology services in the United States and internationally. The company's Safety and Industrial segment offers industrial abrasives and finishing for metalworking applications; autobody repair solutions; closure systems for personal hygiene products, masking, and packaging materials; electrical products and materials for construction and maintenance, power distribution, and electrical original equipment manufacturers; structural adhesives and tapes; respiratory, hearing, eye, and fall protection solutions; and natural and color-coated mineral granules for shingles. Its Transportation and Electronics segment provides ceramic solutions; attachment/bonding products, films, sound, and temperature management for transportation vehicles; premium large format graphic films for advertising and fleet signage; light management films and electronics assembly solutions; packaging and interconnection solutions; semiconductor production materials; data centers solutions; and reflective signage for highway, and vehicle safety. The company's Consumer segment provides consumer bandages, braces, supports, and consumer respirators; home cleaning products; retail abrasives, paint accessories, car care DIY products, picture hanging, and consumer air quality solutions; and stationery products. It offers its products through e-commerce and traditional wholesalers, retailers, jobbers, distributors, and dealers. 3M Company was founded in 1902 and is headquartered in Saint Paul, Minnesota.
[Price Movement]
It was 142.9438018798828 on 2021-06-23.
The price of the stock on 2021-06-30 started at 140.97916179854076.
[News since 2021-06-23]
Dow Movers: MMM, CVX
2 Stocks I'm Never Selling
Better Buy: GE vs. 3M
C3.ai Is Down More Than 60% From Its Peak. Here's What Happened
Which Industrial Stocks Are Better Bets Compared To Johnson Controls?
Have Insiders Been Selling 3M Company (NYSE:MMM) Shares?
[Financials]
Basic EPS: $10.23
Normalized EBITDA: $9.607000e+09
Net Income: $5.921000e+09
Your answer should look like the following
<think>reasoning about why the stock would go up or down here for example
- Recent news highlights insider selling, which could signal low confidence.
- EPS is strong, but EBITDA has dipped slightly.
- The stock has been trending downward for the past week.
</think><answer>down</answer>
Please reason about and provide several reasons for why you think the stock would go up or down in the <think></think> tags. Please provide your answer as a single rating, 'buy' or 'sell', in the <answer></answer> tags, with buy meaning that the stock price will go up,
and sell meaning that the stock price will go down.
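The formatting code itself lives in the Colab. As a rough sketch of how a prompt like the one above could be assembled from one day's record, here is a hypothetical `build_prompt` helper. It assumes the "It was ..." line uses the previous week's close, takes the financials as an already-parsed dict, and abbreviates the answer-format instructions; none of these details are confirmed by the original code.

```python
from typing import Dict, List


def build_prompt(
    name: str,
    description: str,
    day: str,
    previous_day: str,
    previous_close: float,
    open_price: float,
    headlines: List[str],
    financials: Dict[str, float],
) -> str:
    """Hypothetical helper that assembles a DeepStock-style prompt for one trading day."""
    lines = [
        "You are a seasoned stock market analyst who is trying to predict whether the prices "
        f"will go down or up over the day, {day}, for a specific stock, by offering a buy or sell rating.",
        "[Company Name]",
        name,
        "[Company Description]",
        description,
        "[Price Movement]",
        # Assumption: the look-back price is the close from about a week earlier
        f"It was {previous_close} on {previous_day}.",
        f"The price of the stock on {day} started at {open_price}.",
        f"[News since {previous_day}]",
        *headlines,
        "[Financials]",
        # One "Metric: $value" line per financial item
        *[f"{metric}: ${value}" for metric, value in financials.items()],
        # Abbreviated version of the answer-format instructions
        "Please provide your answer as a single rating, 'buy' or 'sell', in the <answer></answer> tags.",
    ]
    return "\n".join(lines)


prompt = build_prompt(
    "3M Company", "3M Company provides diversified technology services...",
    "2021-06-30", "2021-06-23", 142.94, 140.98,
    ["Dow Movers: MMM, CVX", "Better Buy: GE vs. 3M"],
    {"Basic EPS": 10.23},
)
print(prompt)
```

Keeping the section headers fixed ("[Company Name]", "[News since ...]", and so on) gives the model a stable layout to attend to across all ~500 tickers.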
I set the date range from June 2021 to December 2023, mainly due to data availability constraints. For the heavy lifting of data processing, I used Modal - seriously, it's amazing how easily you can parallelize data processing with it. Just write your Python function and watch the magic happen across 100 containers. The code is in the attached colab.
The Data Wrangling Nightmare
Let’s be real: 80% of this project was data cleanup.
News APIs are trash. Most have laughable rate limits (looking at you, 100 calls/day). I ended up using a pre-scraped Hugging Face dataset, FNSPID (thanks, Zihan1004), but even that was broken. After hacking the CSV and re-uploading it, I got ~28 million headlines: good enough.
yfinance rate limits. Fetching 500 stocks’ financials took hours of staggered requests.
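The exact staggering used isn't shown, but the usual pattern for rate-limited APIs is retrying with exponential backoff. This is a generic sketch, not the code used for DeepStock; `call_with_backoff` is a hypothetical helper, and the delays are arbitrary defaults.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(
    fn: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on any exception with exponentially growing, jittered delays."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of retries: surface the last error
            # Wait base_delay * 2^attempt seconds plus up to 1s of random jitter,
            # so many parallel workers don't all retry at the same instant
            sleep(base_delay * 2 ** attempt + random.random())
    raise RuntimeError("unreachable")
```

Usage would look like `call_with_backoff(lambda: yf.Ticker("MMM").financials)`. The `sleep` parameter is injectable so the retry logic can be tested without actually waiting.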
Date alignment. News from the last week? Easy. But aligning annual financials (from yfinance) to daily predictions required backtracking to the latest available report. Then there's the fact that the stock market only trades on weekdays and closes for holidays, so you need to detect and skip non-trading days when looking up price movements.
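Here is a minimal sketch of both alignment steps using only the standard library. It assumes you supply the exchange holiday set yourself: NYSE holidays differ from US federal ones (the market closes on Good Friday, for example, but stays open on Columbus Day). The function names are hypothetical, and `latest_report_on_or_before` only ever looks backward, so no future report leaks into a prediction. One caveat: yfinance labels financials by fiscal period end, and reports are published some weeks later, so a stricter version would also add a publication lag.

```python
from bisect import bisect_right
from datetime import date, timedelta
from typing import List, Optional, Set


def is_trading_day(day: date, market_holidays: Set[date]) -> bool:
    """A trading day is a weekday that is not an exchange holiday."""
    return day.weekday() < 5 and day not in market_holidays


def previous_trading_day(day: date, market_holidays: Set[date]) -> date:
    """Step back from `day` (inclusive) until landing on a trading day."""
    while not is_trading_day(day, market_holidays):
        day -= timedelta(days=1)
    return day


def latest_report_on_or_before(report_dates: List[date], day: date) -> Optional[date]:
    """Most recent report dated <= day, or None if every report is newer."""
    ordered = sorted(report_dates)
    idx = bisect_right(ordered, day)  # number of reports dated on or before `day`
    return ordered[idx - 1] if idx > 0 else None


good_friday_2023 = {date(2023, 4, 7)}
print(previous_trading_day(date(2023, 4, 9), good_friday_2023))  # Sunday -> Thursday 2023-04-06
```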
The Magic: Reinforcement Learning
Now for the fun part: how does the reinforcement learning actually work? For DeepStock I used the GRPO trainer from Hugging Face's TRL library (the trainer the open-r1 project builds on), which is pretty clever. It implements Group Relative Policy Optimization (GRPO), the RL algorithm DeepSeek used to train R1-Zero.
At each training step, it does the following:
1. Generate 20 Completions Per Prompt
For each input prompt (e.g., “Predict MMM’s price on 2021-06-30”), the model generates 20 different responses. Think of this as brainstorming 20 possible answers, each with its own reasoning.
Example output:
<think>
- Recent news highlights insider selling, which could signal low confidence.
- EPS is strong, but EBITDA has dipped slightly.
- The stock has been trending downward for the past week.
</think>
<answer>sell</answer>
2. Calculate Rewards
Each completion gets two rewards:
Accuracy reward: 1.0 if the answer in the <answer> tag matches the actual price movement, i.e. did it say “buy” and did the stock close above its open that day (or “sell” and it didn't)? Otherwise 0.0, including when the output can't be parsed.
Format reward: 1.0 if the output follows the <think>...</think><answer>buy|sell</answer> structure, 0.0 otherwise.
Rewards are objective: no human judgment needed. Here's the code:
```python
import re


def accuracy_reward(completions, company_info, **kwargs):
    """
    Reward function that checks if the completion correctly predicted price movement.
    Returns 1.0 if prediction matches actual movement, 0.0 otherwise.
    Ignores whitespace in the answer.
    """
    pattern = r"^<think>.*?</think><answer>\s*(buy|sell)\s*</answer>$"
    try:
        rewards = []
        contents = [completion[0]["content"] for completion in completions]
        for completion_content, info in zip(contents, company_info):
            # Label: "buy" if the stock closed above its open that day, else "sell"
            close_price = float(info["price"]["close"])
            open_price = float(info["price"]["open"])
            actual_movement = "buy" if close_price > open_price else "sell"
            # DOTALL lets the <think> block span multiple lines
            match = re.match(pattern, completion_content.strip(), re.IGNORECASE | re.DOTALL)
            if not match:
                rewards.append(0.0)
                continue
            # Extract prediction and remove all whitespace
            prediction = match.group(1).lower().strip()
            # Compare prediction with actual movement
            rewards.append(1.0 if prediction == actual_movement else 0.0)
        return rewards
    except Exception:
        print(company_info, completions)
        raise


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format and answer content."""
    pattern = r"^<think>.*?</think><answer>\s*(buy|sell)\s*</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.IGNORECASE | re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```
Here's how reward functions plug into the GRPO trainer: it calls each one with the batch of completions plus every other column of your dataset as keyword arguments, each a list aligned with the completions. You zip them together, iterate, and return one reward per completion.
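To make that calling convention concrete, here is a toy reward function invoked by hand the way the trainer would call it for a conversational dataset: each completion is a list holding one assistant message, and the `company_info` column arrives as a parallel list. `direction_reward` is a hypothetical, condensed stand-in for the two functions above; the `prompts=None` argument just shows that extra keyword arguments are absorbed by `**kwargs`.

```python
import re
from typing import Any, Dict, List

PATTERN = re.compile(
    r"^<think>.*?</think><answer>\s*(buy|sell)\s*</answer>$", re.IGNORECASE | re.DOTALL
)


def direction_reward(
    completions: List[List[Dict[str, str]]], company_info: List[Dict[str, Any]], **kwargs
) -> List[float]:
    """1.0 if the parsed answer matches the day's move (close > open means "buy"), else 0.0."""
    rewards = []
    # completions[i] and company_info[i] describe the same sample, so zip them
    for completion, info in zip(completions, company_info):
        match = PATTERN.match(completion[0]["content"].strip())
        actual = "buy" if info["price"]["close"] > info["price"]["open"] else "sell"
        rewards.append(1.0 if match and match.group(1).lower() == actual else 0.0)
    return rewards


# Mimic one trainer call: two completions for the same prompt, with the
# dataset column repeated once per completion.
info = {"price": {"open": 140.98, "close": 142.10}}
batch_completions = [
    [{"role": "assistant", "content": "<think>\n- EPS is strong.\n</think><answer>buy</answer>"}],
    [{"role": "assistant", "content": "<think>Insider selling.</think><answer>sell</answer>"}],
]
print(direction_reward(batch_completions, company_info=[info, info], prompts=None))  # [1.0, 0.0]
```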
3. Normalize Rewards into "Advantages"
To avoid skewed gradients, rewards are normalized across the 20 completions:
Subtract the mean reward of the group.
Divide by the standard deviation.
This ensures the model focuses on relative performance (e.g., “completion #3 was better than average”) rather than absolute rewards.
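```python
# rewards: 1-D tensor with one summed reward per completion, ordered so the
# num_generations completions of each prompt are adjacent (20 per prompt here)
grouped_rewards = rewards.view(-1, num_generations)
mean_rewards = grouped_rewards.mean(dim=1)
std_rewards = grouped_rewards.std(dim=1)
# Broadcast each group's statistics back to its completions
mean_rewards = mean_rewards.repeat_interleave(num_generations, dim=0)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
# Normalize
advantages = (rewards - mean_rewards) / (std_rewards + 1e-4)
```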
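A small worked example helps. This pure-Python sketch mirrors the normalization above for two prompts with four completions each (instead of 20, to keep it short), using the sample standard deviation as `torch.std` does by default. `group_advantages` is a hypothetical helper; the rewards are made-up sums of the accuracy and format rewards.

```python
from statistics import mean, stdev
from typing import List


def group_advantages(rewards: List[float], num_generations: int, eps: float = 1e-4) -> List[float]:
    """Normalize rewards within each consecutive group of num_generations completions."""
    if num_generations < 2 or len(rewards) % num_generations != 0:
        raise ValueError("need whole groups of at least 2 completions")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu, sigma = mean(group), stdev(group)  # stdev = sample std, like torch.std
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


rewards = [2.0, 1.0, 1.0, 2.0,   # mixed results -> non-zero advantages
           1.0, 1.0, 1.0, 1.0]   # every completion scored the same -> all zeros
print([round(a, 3) for a in group_advantages(rewards, num_generations=4)])
# [0.866, -0.866, -0.866, 0.866, 0.0, 0.0, 0.0, 0.0]
```

The second group is the interesting one: when every completion for a prompt earns the same reward (say, the model says “buy” all 20 times), all advantages are zero and that prompt contributes no reward signal to the update.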
data:image/s3,"s3://crabby-images/530e2/530e2553811354e2cac864b3f2cfcedc735828fd" alt=""
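One quirky observation: the loss actually increases as the reward improves. This is expected behavior according to the GRPO trainer documentation, so don't panic if you see it happening. Roughly, the reason is that the advantages within each group average out to zero, so the reported loss value is essentially just the KL penalty, which grows as the model moves away from the model it started from.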
4. Update the Model with Gradients
Here’s where the RL magic happens. The model’s loss is calculated as:
```python
# x - x.detach() preserves the gradient of x while its value is always 0
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
```
Reward-weighted log probabilities: Gradients are scaled by the advantages. High-reward completions get reinforced; low-reward ones are suppressed. The trick is the ratio torch.exp(per_token_logps - per_token_logps.detach()). Subtracting the detached copy is division in log space, so the ratio's value is always exactly 1, but its gradient is the gradient of per_token_logps. Multiplying by the advantages therefore gives a plain policy-gradient update: tokens from completions with a positive advantage have their probabilities pushed up, tokens from completions with a negative advantage have them pushed down, and the size of the push is proportional to how far the completion was from its group's average. This is the GRPO objective from the original paper (the big equations) in the special case where the model takes one update per batch of generations, so the probability ratio is 1 and the clipping term never activates.
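You can check the exp(x - x.detach()) trick without PyTorch. In this toy, a made-up one-parameter "policy" emits the chosen token with probability sigmoid(theta), and `theta_old` plays the role of the detached copy: it has the same value as theta but is held constant while differentiating. Finite differences then show that the surrogate's gradient equals advantage times the gradient of the log probability.

```python
import math


def log_prob(theta: float) -> float:
    """Toy policy: log sigmoid(theta), the log probability of the sampled token."""
    return -math.log1p(math.exp(-theta))


def surrogate(theta: float, theta_old: float, advantage: float) -> float:
    """exp(logp(theta) - logp(theta_old)) * A; theta_old stands in for the detached copy."""
    return math.exp(log_prob(theta) - log_prob(theta_old)) * advantage


def numeric_grad(f, x: float, h: float = 1e-6) -> float:
    """Central finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)


theta, advantage = 0.3, 0.87
# Value at theta_old == theta: the ratio is exactly 1, so the surrogate equals A
value = surrogate(theta, theta, advantage)
# Differentiate w.r.t. theta while keeping the "detached" value fixed
g_surrogate = numeric_grad(lambda t: surrogate(t, theta, advantage), theta)
# Policy-gradient target: A * d/dtheta log_prob(theta)
g_policy = advantage * numeric_grad(log_prob, theta)
print(value, g_surrogate, g_policy)
```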
Over many steps, this gradually pushes the model toward better predictions. KL divergence penalty: on top of the above, it also calculates the per-token KL divergence between the model and the original (reference) model, weighted by self.beta, to prevent the model from drifting too far from its original behavior (critical for stability): you don't want it to output gobbledygook.
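The loss snippet uses per_token_kl without defining it. GRPO, following the DeepSeekMath paper, uses the per-token "k3" estimator sketched below, and as far as I know TRL's GRPOTrainer computed it the same way at the time of writing (check your installed version). The beta of 0.04 in the example is only an illustrative value.

```python
import math


def per_token_kl(logp: float, ref_logp: float) -> float:
    """k3 estimate of KL(policy || reference) for one sampled token: e^d - d - 1, d = ref - policy."""
    d = ref_logp - logp
    return math.exp(d) - d - 1.0  # always >= 0, since e^d >= 1 + d


def per_token_loss(advantage: float, logp: float, ref_logp: float, beta: float) -> float:
    """Value of the GRPO per-token loss when exp(logp - logp.detach()) == 1."""
    ratio = 1.0
    return -(ratio * advantage - beta * per_token_kl(logp, ref_logp))


print(per_token_kl(-1.2, -1.2))   # identical models: 0.0
print(per_token_kl(-0.5, -2.0))   # policy now far more confident than the reference: positive
# With a zero advantage, the loss is purely the KL term, and it is positive
print(per_token_loss(advantage=0.0, logp=-0.5, ref_logp=-2.0, beta=0.04))
```

The last line is a tiny version of the "loss goes up" effect: once the advantages average out, what remains is beta times the KL, which grows as training moves the model away from the reference.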
5. Profit???
Let the model run for a while, and the GRPO trainer should push it toward higher rewards using the techniques above.
Full code
```python
import re
from dataclasses import dataclass, field

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, get_peft_config

# accuracy_reward and format_reward are the reward functions shown earlier in this post

MODEL_ID = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DATASET_ID = "2084Collective/deepstock-sp500-companies-with-info-and-user-prompt_buy_sell"


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args: GRPOConfig, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation: the trainer expects a "prompt" column of chat messages
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["user_prompt"]},
            ]
        }

    dataset = dataset.map(make_conversation)

    # Initialize the GRPO trainer
    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
    )

    # Train
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


config = GRPOConfig(
    log_level="debug",
    max_completion_length=256,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    logging_steps=1,
    # Prompts longer than this are truncated by the trainer (the example prompt above is longer)
    max_prompt_length=256,
    output_dir="DeepSeek-R1-Distill-Qwen-7B-GRPO",
    run_name="deepstock-check",
    num_train_epochs=1,
    learning_rate=1e-4,
)
print(config.device)

script_args = GRPOScriptArguments(dataset_name=DATASET_ID)
model_args = ModelConfig(model_name_or_path=MODEL_ID, use_peft=True)
main(script_args, config, model_args)
```
Mini note about GRPO Trainer
The trainer expects your dataset to have a “prompt” column containing a list of messages, each with “role” and “content” set. The trainer then applies the chat template attached to the model's tokenizer to turn these messages into the final prompt fed into the model.
The gradient_accumulation_steps argument sets how many forward/backward passes the trainer accumulates gradients over before it actually updates the model. It's a way of letting smaller GPUs process larger effective batches: with per_device_train_batch_size=1 and gradient_accumulation_steps=16, each update uses a batch of 16.
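Why accumulation is equivalent to a bigger batch can be checked on a toy problem. This sketch uses a hypothetical one-parameter linear model with made-up data and a mean-squared-error loss: summing the gradients of 16 micro-batches of one example, each scaled by 1/16, gives the same gradient as one batch of 16. It assumes the loss is a mean over examples and the micro-batches are equal-sized.

```python
from typing import List, Tuple


def grad_mse(w: float, batch: List[Tuple[float, float]]) -> float:
    """d/dw of the mean squared error of the prediction w * x over a batch of (x, y)."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Tuple[float, float]], accumulation_steps: int) -> float:
    """Split the batch into equal micro-batches and accumulate their grads scaled by 1/steps."""
    if len(batch) % accumulation_steps != 0:
        raise ValueError("batch must split into equal micro-batches")
    size = len(batch) // accumulation_steps
    total = 0.0
    for step in range(accumulation_steps):
        micro_batch = batch[step * size:(step + 1) * size]
        total += grad_mse(w, micro_batch) / accumulation_steps
    return total


batch = [(float(i), 2.0 * i + 1.0) for i in range(16)]
full = grad_mse(0.5, batch)                # one batch of 16
accumulated = accumulated_grad(0.5, batch, 16)  # 16 micro-batches of 1
print(full, accumulated)
```

The price is time rather than memory: 16 sequential passes per optimizer step instead of one wide pass.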
Training
So after all the above, I ran a mini version using a smaller model, “HuggingFaceTB/SmolLM2-1.7B-Instruct”, with the dataset I created as described above. I also used PEFT (a LoRA adapter) to reduce memory usage by training only a small set of extra parameters instead of the full model.
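The savings from LoRA are easy to estimate. LoRA freezes a weight matrix W of shape d_out x d_in and learns a low-rank update B @ A, with B of shape d_out x r and A of shape r x d_in, so each adapted matrix adds r * (d_in + d_out) trainable parameters. The 2048 x 2048 matrix and rank 16 below are illustrative assumptions, not SmolLM2's confirmed configuration.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds for one d_out x d_in weight: B (d_out x r) plus A (r x d_in)."""
    return rank * (d_in + d_out)


full_matrix = 2048 * 2048
lora = lora_trainable_params(2048, 2048, rank=16)
print(lora, full_matrix, f"{lora / full_matrix:.2%}")
```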
The exciting news? DeepStock trains! While I don't have the resources to train it to completion, I've made all the training code available in a Colab notebook. I'll let it run over the next week and see what happens.
Early results
The accuracy is going up over time! Soon it’ll take on all of Wall Street. r/WallStreetBets here I come.
The loss is also going up, although as mentioned that is to be expected.
Try It Yourself
All code is in this Colab notebook. You'll need:
* A Hugging Face API key (for the dataset)
* A Modal account (for scaling data processing)
* A *lot* of patience (or a big GPU budget)
As I said above, I'm letting this train for the next week. If it works, maybe we'll see DeepStock on Wall Street. If not—well, at least we tried. If anyone has some spare compute that they’d be willing to let me use to train a much bigger model on the dataset, I’d be all ears.
Dataset
What's Next?
I've got some ideas brewing for future improvements:
Add more data sources (earnings call transcripts, SEC filings)
Experiment with longer-term predictions (weekly/monthly): keep the “buy” or “sell” rating, but score it on a “hold for a year” outcome instead of a single day.
Let the model size itself up ("I'm never selling!")
Let me know in the comments if you'd like a follow-up! Also please subscribe if you like my work and want it to continue!