None grpo
In [1]:
import os
In [2]:
%load_ext watermark
%load_ext autoreload
%autoreload 2

import re
import vllm
import torch
import torch.nn.functional as F
from datasets import load_dataset
from math_verify import parse, verify
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

%watermark -a 'Ethen' -d -u -v -iv
/usr/local/lib/python3.12/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Author: Ethen

Last updated: 2026-06-05

Python implementation: CPython
Python version       : 3.12.13
IPython version      : 9.14.0

datasets    : 4.8.5
math_verify : 0.9.0
re          : 2.2.1
torch       : 2.10.0+cu129
transformers: 4.57.6
vllm        : 0.18.0

Introduction to GRPO (Group Relative Policy Optimization)

The standard recipe for building capable language models (LLM) involves: pre-training on massive corpora, followed by supervised/instruction fine-tuning stage (SFT) on curated instruction response pairs. While SFT is a potent method it's fundamental limitation is it trains the model to imitate demonstrations. On the other hand, in reinforcement learning (RL), the model learns to independently generate its own answers, explore new paths to optimize for a specific goal. Being "on-policy" allows the model to explore new pathways and solutions that might not be present in the demonstration training data. This is similar in real life, where we first bootstrap ourselves via imitation, either from school, studying how other people found success. But once we can acquire that foundation, it's always best to play to our strength, take our own actions and we learn through environment's feedback [7].

Using RL to improve LLM started with learning to summarize from human feedback [2], where they demonstrated training a reward model from human preferences, and using it as a reward signal to fine-tune a summarization policy with RL produce better summary than optimizing ROUGE as well as much larger models trained with supervised learning alone. InstructGPT [3] scaled this approach to general instruction following across arbitrary tasks not just summarization, establishing the reinforcement learning from human feedback (RLHF) recipe behind ChatGPT.

The standard RLHF pipeline, however, comes with significant complexity. The algorithm adopted, PPO, requires 4 models simultaneously (we'll dive into this a bit more in later sections).

┌─────────────┐     ┌──────────────┐     ┌──────────────┐     ┌─────────────┐
│  SFT Model  │ ──► │ Reward Model │ ──► │ PPO Training │ ──► │ Final Model │
│  (policy    │     │ (trained on  │     │ (policy +    │     │             │
│   init)     │     │  preferences)│     │  critic +    │     │             │
│             │     │              │     │  reward +    │     │             │
│             │     │              │     │  reference)  │     │             │
└─────────────┘     └──────────────┘     └──────────────┘     └─────────────┘
                                               │
                                    4 models in memory

Two subsequent work from Deepseek introducted simplifcations that make modern RL practical for folks that are just getting started.

DeepseekMath [4] introduced Group Relative Policy Optimization (GRPO), which eliminates the critic model entirely. Subsequent work Deepseek R1 [5] demonstrates for tasks with verifiable answers like math, we can also eliminate the learned reward model and replace it with a deterministic verifier. This motivates the concept of verification asymmetry [6], where some tasks should be easier to verify than to solve. Together these simplificactions reduced the RL pipeline from 4 models to 2.

Intuition: The Exam Analogy

Instead of diving directly into the math, let us build an intuition for why each component of GRPO exists. In reinforcement learning, simply optimizing for a high reward alone can lead to instability or shortcutting behaviors. To address these challenges, we often incorporate several mechanisms - critic, clip operation, reference.

Following the work understanding GRPO without any prior RL knowledge [1], We will frame RL training process as a school exam scenario. We (the model being trained) are like students trying to get higher grades, the teacher we grades our exam is like the reward model, and our father handing out pocket money based on our grades plays the rolw of the training algorithm.

Suppose my younger brother and I are in the same class, where I typically score above 80, while my brother often gets around 30. We take our test scores directly to our father for pocket money - whoever scores higher gets more. This however, becomes unfair, as if my brother improves from 30 to 60 through tremendous effort, his absolute reward still pales in comparison to my ususal 80+.

Critic: Recognizing this problem, our father sets a predicted score line for each of us, e.g. 80 for me and 40 for my brother. If we exceed each of our baseline, we get extra pocket money proportional to the surplus. Each person is now incentized to improve from their own baseline. And of course, our father must keep adjusting this baseline as we progress. This predicted baseline is so called the value function, and our training objective now becomes maximizing the advantage (reward - baseline), instead of raw reward.

Clip: Purely chasing for larger amount of pocket money might push us to adopt some extreme study patterns. Causing our grades to swing s between extremes, 95 one week, 60 the next. These outliers shouldn't dictate our overall study strategy. So our father caps how much extra pocket money any single exam can earn. This is called the clip mechanism, limiting how much the policy can shift in one update.

Reference Model: If I am solely fixated on high scores, I might resort to questionable tactics - cheating, or memorizing patterns without actually understanding. In LLM terms, this corresponds to reward hacking: producing degenerate outputs that score highly without being genuinely useful. Our father adds a rule: "You can't deviate too much from your original studying approach, else even with high scores, I will penalize you". In practice, this is calculated using the KL divergence against a reference model, it keeps the policy that we are training during RL stage from drifting too far from the initial model.

GRPO: One day, our father feels he is spending enormous amount of time maintaining our predicted score baseline, instead says "take a few practice tests yourself before the real exam, and use those average as your expected score". In LLM scenarios, this critic is often times a model that also needs to be trained, typically the same size as the policy model, effectively doubling the system's memory and compute requirements. GRPO's core insight is to replace the learned value function with the empirical mean of multiple samples from our model. This mechanism is self-calibrating, as we improve, our average rises automatically, fair given we are comparing only against ourselves, as well as offers computational advantage, as generating multiple samples is far cheaper than training a seprate value model.

Putting it Together

The complete GRPO training objective is to maximize:

$$ J_{\text{GRPO}}(\theta) = \mathbb{E}_{q, \{o_i\}_{i=1}^G} \left[\frac{1}{G}\sum_{i=1}^{G} \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} \left(\min\left(\rho_{i,t} \hat{A}_i, \; \text{clip}(\rho_{i,t}, 1\!-\!\epsilon, 1\!+\!\epsilon) \hat{A}_i\right) - \beta \cdot D_{\text{KL}}^{(i,t)}\right)\right] $$

GRPO's key insight is replacing the learned value function with a statistical baseline computed from a group of sampled completions. For each prompt $q$, GRPO:

  1. Samples a group of $G$ completions $\{o_1, o_2, \ldots, o_G\}$ from the current policy $\pi_\theta$.
  2. Evaluates each completion with a reward function to obtain rewards $\{r_1, r_2, \ldots, r_G\}$.
  3. Computes the group relative advantage for each completion using z-score normalization.
$$\hat{A}_i = \frac{r_i - \mu_G}{\sigma_G}$$

Completions with above average rewards gets positive advantges, and will get reinforced. Whereas completion with below average rewards gets negative advantage and suppressed.

One of the most critical failure modes in GRPO occurs when all completions within a group receives identical rewards. Since GRPO computes advantages via group level normalization, uniform rewards collapse the advantage to 0, which leads to zero gradients, and completely wasting that batch of compute.

When all rewards are identical ($r_1 = r_2 = \cdots = r_G$):

$$ \text{mean} = r_i \quad \Rightarrow \quad r_i - \text{mean} = 0 \quad \Rightarrow \quad \hat{A}_i = 0 \quad \forall i $$

This condition happens in two symmetric scenarios when the problem is too easy or too hard for the currency policy. If every sampled completion solves it correctly or no sampled completion succeeds, the reward lacks the variance/diversity needed to differentiate between completions and provide meaningful learning signals.

Importance sampling ratio: One fundamental challenge in RL is sample efficiency, every time we update our current policy's parameter $\pi_\theta$ by a single step, all previously collected samples become "stale". This can be wasteful in LLM setting, where generating completions requires expensive full autoregressive decoding. Importance sampling allows us to reuse samples from an old policy $\pi_{\theta_{\text{old}}}$ while still computing a valid gradient estimate for the current policy $\pi_\theta$.

In practice, GRPO computes the importance ratio at token level:

$$ \rho_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} = \exp\left(\log \pi_\theta(o_{i,t} \mid q, o_{i,<t}) - \log \pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})\right) $$

The ratio $\rho = \pi_\theta(o) / \pi_{\theta_{\text{old}}}(o)$ is called the importance sampling ratio, where it reweights each sample to account for the mismatch between the distribution it was drawn from (old policy) and the distribution we want to estimate expectations under (current policy). Intuition: If the current policy assigns a higher probability to a particular completion than the old policy did, then that completion is under-represented in our old samples. Given it should have appeared more often, the ratio $\rho > 1$ upweights it to compensate.

As for carrying out the computation, what we want (in probability space):

$\text{ratio} = \text{new prob} / \text{old prob}$

but in practice, we typically have model's log probabilities, so in log space:

$\text{log(ratio)} = \text{log(new prob / old prob)} = \text{log(new prob) - log(old prob)} = \text{log prob - old log prob}$

To obtain actual ratio:

$\text{ratio} = \text{exp(log(new prob / old prob))} = \text{exp(log prob - old log prob)}$

Clipping is necessary as importance sampling can have a very high variance when the two distributions diverge significantly.

$$\text{clip}(\rho_{i,t}, \; 1-\epsilon, \; 1+\epsilon)$$

By bounding the ratio via a clipping parameter $\epsilon$, we limit how much any single token's reweighting can influence the model update. This creates a trusted region, where the optimization proceeds as if the current policy can't deviate too far from the old policy in a single step.

It's important to note that when the ratio gets clamped, this style of clipping mechanism introduces a flat region in the loss landscape where the gradient becomes 0 to those elements that are clipped, effectively stopping learning for them. Because any value exceeding the thresholds is capped to a constant value, mathematically, the derivative of a constant is zero.

In [3]:
# create an input tensor with values below, within, and above the clamp limits
x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)

# apply torch.clamp between -1.0 and 1.0
y = torch.clamp(x, min=-1.0, max=1.0)

# compute a dummy loss (sum of elements) so we can backpropagate
loss = y.sum()
loss.backward()

# inspect the gradients
print("Original Input: ", x.tolist())
print("Clamped Output: ", y.tolist())
print("Gradients:      ", x.grad.tolist())
Original Input:  [-2.0, 0.5, 3.0]
Clamped Output:  [-1.0, 0.5, 1.0]
Gradients:       [0.0, 1.0, 0.0]

When num_grpo/ppo_epochs > 1 we perform multiple gradient steps on the same batch of rollouts. The importance sampling ratio drifts further from 1 with each step, and clipping prevents the later steps from making overly aggressive updates.

KL divergence regularization: In GRPO, KL divergence is added directly to the loss function (not to reward as in some other variants). The per-token KL divergence is computed against a reference policy $\pi_{\text{ref}}$ (typically the initial SFT model):

$$D_{\text{KL}}^{(i,t)} = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1$$

There're many different forms of KL divergence, this particular form is sometimes referred to as $k_3$ estimator. The coefficient $\beta$ controls the strength of regularization — larger $\beta$ keeps the policy closer to the reference model.

Note that advantage $\hat{A}_i$ is at sequence level (same value for all tokens in a completion), while importance ratio $\rho_{i,t}$ and KL divergence $D_{\text{KL}}^{(i,t)}$ are computed at the token level. Because of this, while KL divergence's primary role is to prevent the policy from diverging too far from the reference model, it effectively also serves as a per token process reward signal. Providing dense, token-level feedback regardless of whether the final sparse outcome reward.

To summarize:

Step Exam Analogy RL Component Problem Solved
1. Raw scores Pays based on absolute grade Reward $r(o)$ — (starting point)
2. Score line Sets a predicted baseline per child Value function $V_\psi(s)$ (Critic) High variance, unfair comparisons
3. Capped bonus Limits max bonus from any single exam Clip $[1-\epsilon, 1+\epsilon]$ Unstable over-updates
4. "Stay honest" rule Penalizes deviation from original habits KL penalty against $\pi_{\text{ref}}$ Reward hacking, policy collapse
5. Practice tests Use your own average from n tests as baseline Group mean $\mu_G$ replaces Critic Expensive Critic network eliminated

Implementation

This section provides an implementation of GRPO training loop consisting of 3 main phases per iteration:

  • Rollout generation
  • Reward/Advantage computation
  • Policy update.

We'll be using gsm8k grade school math dataset, and Qwen/Qwen3-0.6B model to illustrate the end to end flow.

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load policy model (trainable)
policy_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
).to(device)
policy_model.train()
# Load reference model, frozen copy of the initial policy
ref_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16
).to(device)
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False
`torch_dtype` is deprecated! Use `dtype` instead!
In [5]:
train_data = load_dataset("openai/gsm8k", "main")["train"]
print(train_data)
train_data[0]
Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})
Out[5]:
{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'}
In [6]:
def extract_hash_answer(text: str) -> str:
    return text.split("####")[1].strip()


def collate_fn(batch):
    prompts = []
    answers = []
    for example in batch:
        question = example["question"]
        answer = example["answer"]
        answer = extract_hash_answer(answer)

        prompt = [
            {"role": "user", "content": question}
        ]
        prompts.append(prompt)
        answers.append(answer)

    batch = {
        "prompts": prompts,
        "answers": answers
    }
    return batch
In [7]:
batch_size = 2
train_loader = DataLoader(
    train_data, batch_size=batch_size, collate_fn=collate_fn, num_workers=0,    
)
batch = next(iter(train_loader))
batch
Out[7]:
{'prompts': [[{'role': 'user',
    'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'}],
  [{'role': 'user',
    'content': 'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?'}]],
 'answers': ['72', '10']}

Rollout Generation

In [8]:
prompts = tokenizer.apply_chat_template(
    batch["prompts"],
    tokenize=False,
    enable_thinking=False,
    add_generation_prompt=True
)
prompt_encoded = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    truncation=True
)

# [batch size, sequence length]
input_ids = prompt_encoded["input_ids"].to(device)
attention_mask = prompt_encoded["attention_mask"].to(device)
input_ids.shape
Out[8]:
torch.Size([2, 51])
In [9]:
num_rollouts = 4

# roll out stage
policy_model.eval()
with torch.no_grad():
    # [batch size * number of rollouts, sequence length (prompt length + max new tokens)]
    # num_return_sequences internally repeats each prompt
    generated_ids = policy_model.generate(    
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        do_sample=True,
        num_return_sequences=num_rollouts,
        top_p=0.95,
        top_k=20,
        temperature=0.6,
        eos_token_id=tokenizer.eos_token_id
    )
policy_model.train()
print(generated_ids.shape)
print(tokenizer.batch_decode(generated_ids)[0])
torch.Size([8, 165])
<|im_start|>user
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>
<|im_start|>assistant
<think>

</think>

Natalia sold 48 clips in April.

In May, she sold **half as many clips** as she did in April. That means she sold:

$$
\frac{1}{2} \times 48 = 24 \text{ clips in May}
$$

Now, to find the total number of clips she sold in April and May:

$$
48 + 24 = 72
$$

So, Natalia sold **72 clips** in total in April and May.<|im_end|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
In [10]:
# decode completions for reward computation, since left padding
# ensures all prompts end at column prompt length, we simply slice
# from that position onwards, and batch decode handles variable length
# responses via skip special tokens
# [batch size * number of rollouts, response length]
prompt_len = input_ids.shape[1]
response_ids = generated_ids[:, prompt_len:]
completion_texts = tokenizer.batch_decode(response_ids, skip_special_tokens=True)

Reward Advantage Computation

In [11]:
def cleanup_reasoning(response: str) -> str:
    """Remove the <think> process."""
    response = response.strip()
    think_pattern = re.compile(r'(<think>)?(.+)(</think>)', re.DOTALL)
    matched = re.search(think_pattern, response)
    if matched:
        # remove think tags
        answer = response[matched.end(0):]
    else:
        answer = response

    answer = answer.strip()
    return answer


def compute_score(solution_str, ground_truth):
    """Binary reward: 1 if the parsed answer matches ground truth, 0 otherwise.

    We rely on math-verify, which offers more flexibility in terms
    of mathematical equivalence answers instead of strict answer match
    https://github.com/huggingface/Math-Verify
    """
    solution_str_cleaned = cleanup_reasoning(solution_str)
    gold = parse(ground_truth)
    answer = parse(solution_str_cleaned)
    score = int(verify(gold, answer))
    return score
In [12]:
def compute_group_advantages(
    rewards: torch.Tensor
) -> torch.Tensor:
    """
    Compute group-relative advantages rewards [batch_size, num_rollouts]
    """
    mean_rewards = rewards.mean(dim=-1, keepdim=True)
    std_rewards = rewards.std(dim=-1, keepdim=True)
    advantages = (rewards - mean_rewards) / (std_rewards + 1e-8)
    return advantages
In [13]:
answers = batch["answers"]
answers_repeated = [
    answer for answer in answers for _ in range(num_rollouts)
]
print(answers_repeated)
Out[13]:
['72', '72', '72', '72', '10', '10', '10', '10']
In [14]:
rewards = torch.tensor(
    [
        compute_score(solution_str, ground_truth)
        for solution_str, ground_truth in zip(completion_texts, answers_repeated)
    ],
    dtype=torch.float32,
    device=device,
)  # [batch_size * num_rollouts,]

# temporarily reshape to [batch_size, num_rollouts] for
# computing batch group-relative advantages
# the last batch in an epoch might be smaller than config's batch size
rewards_grouped = rewards.view(len(answers), num_rollouts)
advantages_grouped = compute_group_advantages(rewards_grouped)
# [batch_size * num_rollouts,]
advantages = advantages_grouped.view(-1)
advantages
Out[14]:
tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

Policy Update

In [15]:
def compute_logprobs(model, input_ids, attention_mask):
    """Compute per-token log probabilities from model logits."""
    output = model(input_ids, attention_mask, use_cache=False)
    logits = output.logits

    # shift logits and labels
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    logits_shifted = logits[:, :-1, :].contiguous()
    labels_shifted = labels[:, 1:].contiguous()
    # (batch_size, seq_len - 1)
    log_probs = -F.cross_entropy(
        logits_shifted.view(-1, logits_shifted.shape[-1]),
        labels_shifted.view(-1),
        reduction="none",
        ignore_index=-100
    ).view(labels_shifted.shape)
    return log_probs
In [16]:
gen_attention_mask = (generated_ids != tokenizer.pad_token_id).long()
gen_attention_mask[:, :prompt_len] = 0

# compute old policy's log prob, frozen used for importance sampling ratios
with torch.no_grad():
    log_probs_old = compute_logprobs(policy_model, generated_ids, gen_attention_mask)

# compute reference model's log prob, frozen used for KL divergence
with torch.no_grad():
    log_probs_ref = compute_logprobs(ref_model, generated_ids, gen_attention_mask)

# current policy log prob
policy_model.train()
log_probs_current = compute_logprobs(policy_model, generated_ids, gen_attention_mask)

# [batch_size * num_rollouts, seq length - 1]
print(log_probs_old.shape)
print(log_probs_current.shape)
torch.Size([8, 164])
torch.Size([8, 164])
In [20]:
def compute_kl_divergence(
    log_probs_current: torch.Tensor,
    log_probs_ref: torch.Tensor,
) -> torch.Tensor:
    log_ratio = log_probs_ref - log_probs_current
    kl_per_token = torch.exp(log_ratio) - log_ratio - 1
    return kl_per_token


def compute_grpo_loss(
    log_probs_current: torch.Tensor,
    log_probs_old: torch.Tensor,
    log_probs_ref: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_epsilon: float = 0.2,
    kl_coeff: float = 0.01
):
    # per token importance sampling ratio
    log_ratio = log_probs_current - log_probs_old
    ratio = torch.exp(log_ratio)
    
    # broadcast sequence level advantages to all token positions
    advantages_expanded = advantages.unsqueeze(-1)
    
    # clipped objective
    surrogate_unclipped = ratio * advantages_expanded
    surrogate_clipped = (
        torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)
        * advantages_expanded
    )
    surrogate_loss = torch.min(surrogate_unclipped, surrogate_clipped)
    
    # loss - beta * KL, ensure it is masked to calculate for response token only
    kl_penalty = compute_kl_divergence(log_probs_current, log_probs_ref)
    per_token_objective = (surrogate_loss - kl_coeff * kl_penalty) * response_mask
    
    # average over valid response per completion, then over batch size
    response_lengths = response_mask.sum(dim=-1)
    per_completion_objective = per_token_objective.sum(dim=-1) / response_lengths
    loss = -per_completion_objective.mean()
    return loss, surrogate_loss, per_token_objective, kl_penalty


# slice to match shifted tensor
response_mask = gen_attention_mask[:, 1:]
loss, surrogate_loss, per_token_objective, kl_penalty = compute_grpo_loss(
    log_probs_current=log_probs_current,
    log_probs_old=log_probs_old,
    log_probs_ref=log_probs_ref,
    advantages=advantages,
    response_mask=response_mask,
    clip_epsilon=0.2,
    kl_coeff=0.01,
)
print(loss)
tensor(4.8195e-06, device='cuda:0', grad_fn=<NegBackward0>)

Evaluation

We run the full training and evaluation separately outside of notebook via grpo.py and evaluate.py script folder link.

Model: Qwen/Qwen3-0.6B Accuracy: 58.98% (778/1319)

Model: Qwen/Qwen3-0.6B, 8 A100 for 200 steps Accuracy: 66.87% (882/1319)

Anecdotal example:

{'idx': 2, 'prompt': '<|im_start|>user\nJosh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repairs.  This increased the value of the house by 150%.  How much profit did he make?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n', 'ground_truth': '70000', 'generated_text': 'Josh buys a house for **$80,000** and puts in **$50,000** in repairs. The value of the house **increased by 150%**.\n\n---\n\n### Step 1: Find the new value of the house\n\nThe value increased by 150%, so the new value is:\n\n$$\n\\text{New Value} = \\text{Original Value} + 150\\% \\times \\text{Original Value}\n$$\n\n$$\n\\text{New Value} = 80,000 + 0.15 \\times 80,000 = 80,000 + 12,000 = 92,000\n$$\n\n---\n\n### Step 2: Calculate the profit\n\nProfit = New Value - Cost\n\n$$\n\\text{Profit} = 92,000 - 50,000 = 42,000\n$$\n\n---\n\n### ✅ Final Answer:\n\nJosh made a **profit of $42,000**.', 'correct': 0}

Qwen3-0.6B made two errors:

Computed 150% as 0.15 × 80,000 = 12,000 (confused 150% with 15%). Only subtracted repair cost, not total investment: 92,000 - 50,000

Qwen3-0.6B-GRPO, on the other hand, got it right:

{'idx': 2, 'prompt': '<|im_start|>user\nJosh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repairs.  This increased the value of the house by 150%.  How much profit did he make?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n', 'ground_truth': '70000', 'generated_text': 'We are given the following information:\n\n- The original price of the house is **$80,000**.\n- Josh puts in **$50,000** in repairs.\n- The value of the house **increased by 150%** due to the repairs.\n\n---\n\n### Step 1: Calculate the new value of the house after the increase\n\nThe increase in value is **150%** of the original price:\n\n$$\n\\text{Increase} = 150\\% \\times 80,000 = 1.5 \\times 80,000 = 120,000\n$$\n\n$$\n\\text{New value} = \\text{Original value} + \\text{Increase} = 80,000 + 120,000 = 200,000\n$$\n\n---\n\n### Step 2: Calculate the profit\n\nProfit is the difference between the **new value** and the **cost** (original price of the house and repairs):\n\n$$\n\\text{Profit} = \\text{New value} - (\\text{Original price} + \\text{Cost in repairs})\n$$\n\n$$\n\\text{Profit} = 200,000 - (80,000 + 50,000) = 200,000 - 130,000 = 70,000\n$$\n\n---\n\n### ✅ Final Answer: $ \\boxed{70,000} $ profit.', 'correct': 1}

Correctly computed 1.5 × 80,000 = 120,000 increase. Subtracted total cost (purchase + repairs): 200,000 - 130,000 = 70,000

Reference

  • [1] Yihua Zhang - Blog: DeepSeek-R1 Dissection: Understanding PPO & GRPO Without Any Prior Reinforcement Learning Knowledge (2025)
  • [2] Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Paul Christiano, et al. - Learning to summarize from human feedback (2020)
  • [3] Long Ouyang, Jeff Wu, Xu Jiang, Diogo Almeida, Carroll L. Wainwright, Pamela Mishkin, Amanda Askell, Paul Christiano, Jan Leike, Ryan Lowe, et al. - Training language models to follow instructions with human feedback (2022)
  • [4] Zhihong Shao, Peiyi Wang, Qihao Zhu, Daya Guo, et al. - DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models (2024)
  • [5] DeepSeek-AI - DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning (2025)
  • [6] Jason Wei - Blog: Asymmetry of verification and verifier’s rule
  • [7] Jason Wei - Blog: Life lessons from reinforcement learning (2025)