dpo
In [1]:
# code for loading notebook's format
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)


%load_ext watermark
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import pandas as pd
import transformers
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from peft import LoraConfig
from trl import DPOTrainer, DPOConfig

%watermark -a 'Ethen' -d -v -u -p transformers,datasets,torch,trl,peft
Author: Ethen

Last updated: 2024-10-23

Python implementation: CPython
Python version       : 3.10.14
IPython version      : 8.26.0

transformers: 4.42.0
datasets    : 3.0.0
torch       : 2.1.2+cu121
trl         : 0.11.3
peft        : 0.10.0

Direct Preference Optimization (DPO)

A typical process of training a modern LLM involves an unsupervised/self-supervised pre-training stage, followed by an instruction tuning stage. The instruction tuning phase trains the LLM on higher-quality instruction-to-completion datasets. This helps conform the model's outputs to desired behaviors or tasks, making it more reliable and effective for various applications. Despite the success of instruction tuning, relative judgements of response quality are often easier to collect than expert-written completions, so subsequent LLM works added a so-called alignment stage, which tunes LLMs on preference datasets with reinforcement learning based algorithms, a.k.a. RLHF (Reinforcement Learning from Human Feedback).

In Reinforcement Learning from Human Feedback (RLHF), our objective can be written as:

$$ \max_{\pi_{\theta}} \mathbb{E}_{x \sim D, y \sim \pi_{\theta}(y|x)}[r(x, y)] - \beta \cdot \mathrm{KL}(\pi_{\theta}(y|x)||\pi_{\mathrm{ref}}(y|x)) $$
  • This objective function optimizes our LLM policy $\pi_{\theta}$ to maximize the expected reward $r(x, y)$ for the questions/prompts $x$ sampled from our dataset $D$ and the answers $y$ generated by the LLM policy $\pi_{\theta}(y|x)$. We can't simply run gradient descent on this objective, and instead resort to RL algorithms, because the outputs $y$ are produced by the LLM through decoding strategies (sampling, greedy, beam search, etc.), none of which are differentiable.
  • At the same time, it seeks to minimize the Kullback-Leibler (KL) divergence between the LLM policy $\pi_{\theta}(y|x)$ and the original reference policy $\pi_{\mathrm{ref}}(y|x)$, weighted by a factor $\beta$. A higher $\beta$ means less deviation from the reference model. This second term is added to prevent "reward hacking", a situation where the LLM generates sequences of tokens that achieve high reward scores but may be nonsensical.
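For a single response sampled from $\pi_{\theta}$, the objective above is commonly estimated by subtracting a $\beta$-weighted log ratio between the policy and the reference model from the reward, since the expectation of that log ratio under $\pi_{\theta}$ is exactly the KL divergence. A minimal sketch, assuming we already have a scalar reward and summed sequence log probabilities from both models (the function name `kl_penalized_reward` and the numbers are hypothetical):

```python
import torch


def kl_penalized_reward(reward, policy_logps, reference_logps, beta=0.1):
    """
    Single-sample estimate of r(x, y) - beta * KL(pi_theta || pi_ref).

    All inputs have shape [batch size]; the log probabilities are summed
    over the tokens of a response y sampled from pi_theta.
    """
    # log pi_theta(y|x) - log pi_ref(y|x); its expectation under pi_theta is the KL
    log_ratio = policy_logps - reference_logps
    return reward - beta * log_ratio


reward = torch.tensor([1.0, 1.0])
policy_logps = torch.tensor([-10.0, -5.0])
reference_logps = torch.tensor([-10.0, -15.0])
print(kl_penalized_reward(reward, policy_logps, reference_logps, beta=0.1))
```

Both responses earn the same raw reward, but the policy assigns the second one far more probability than the reference model does, so its KL penalty cancels its reward entirely.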

Given this objective function, RLHF methods usually involve fitting a reward model to a dataset of human preferences, and then using RL to optimize a language model policy to produce responses that are assigned high reward without drifting excessively far from the original model.

Suppose we have a preference dataset at hand; we can convert these pairwise preferences into scores by modeling them via the Bradley-Terry model:

$$\Pr(y_1 \succ y_2) = \frac{\exp(r(x, y_1))}{\exp(r(x, y_1)) + \exp(r(x, y_2))}$$

where:

  • $\Pr(y_1 \succ y_2)$ represents the probability that the first answer ($y_1$) is better than the second answer ($y_2$) in a paired comparison.
  • $r(x, y_1)$ is the (hidden) reward function that evaluates the quality of the first answer $y_1$ for an input prompt $x$, and the numerator $\exp(r(x, y_1))$ is its exponentiated reward. Similarly, for the second answer, we have $\exp(r(x, y_2))$.
  • The probability is computed by taking the ratio of the first answer's exponentiated reward over the sum of the exponentiated rewards for both answers. This normalization ensures the probabilities of the two outcomes are positive and sum to 1.

We can parameterize the reward function and estimate its parameters via maximum likelihood, which boils down to framing it as a binary classification problem:

$$ L = -\mathbb{E}_{(x,y_1,y_2)\sim D}\left[\log \sigma\left(r(x,y_1) - r(x,y_2)\right)\right] $$

Where $\sigma$ is the logistic function $\frac{1}{1 + e^{-x}}$.
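Because this is a logistic loss on the reward difference, it is exactly binary cross-entropy with "first answer preferred" as the positive label. A minimal sketch, assuming a reward model has already produced scalar scores `rewards_1` and `rewards_2` (hypothetical values) for the preferred and dispreferred answers:

```python
import torch
import torch.nn.functional as F

# hypothetical reward model scores for a batch of 3 preference pairs,
# where y_1 is the preferred answer in every pair
rewards_1 = torch.tensor([2.0, 0.5, -1.0])
rewards_2 = torch.tensor([1.0, 0.5, 0.0])

# negative log-likelihood of the Bradley-Terry model
loss = -F.logsigmoid(rewards_1 - rewards_2).mean()

# the same quantity expressed as binary classification with label 1
bce_loss = F.binary_cross_entropy_with_logits(
    rewards_1 - rewards_2, torch.ones_like(rewards_1)
)
print(loss, bce_loss)
```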

Note that the Bradley-Terry probability and the logistic function in this loss are connected. Letting $r(x,y_1) = A$ and $r(x,y_2) = B$:

$$ \frac{e^A}{e^A + e^B} = \frac{\frac{e^A}{e^A}}{\frac{e^A + e^B}{e^A}} = \frac{1}{1 + \left(\frac{e^A + e^B}{e^A} - 1\right)} = \frac{1}{1 + \left(\frac{e^A + e^B - e^A}{e^A}\right)} = \frac{1}{1 + \left(\frac{e^B}{e^A}\right)} = \frac{1}{1 + e^{B-A}} = \sigma(A - B) $$
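A quick numerical check of this identity on a few arbitrary reward pairs; `bradley_terry` and `sigmoid` are small helpers written for this illustration:

```python
import math


def bradley_terry(reward_1, reward_2):
    """Pr(y_1 > y_2) as a ratio of exponentiated rewards."""
    return math.exp(reward_1) / (math.exp(reward_1) + math.exp(reward_2))


def sigmoid(z):
    """Logistic function 1 / (1 + e^{-z})."""
    return 1.0 / (1.0 + math.exp(-z))


# a few arbitrary reward pairs (A, B)
for A, B in [(2.0, 1.0), (0.0, 0.0), (-1.5, 3.0)]:
    print(A, B, bradley_terry(A, B), sigmoid(A - B))
```

For $A = 2, B = 1$ both expressions give $\sigma(1) \approx 0.731$, and with equal rewards the preference is a coin flip (0.5).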

While RLHF is capable of producing models with impressive capabilities (supposedly one of the secret sauces behind ChatGPT), its pipeline is considerably more complex than supervised learning, involving training multiple LMs and sampling from the LLM policy inside the training loop. One of the key insights of Direct Preference Optimization (DPO) [9] is replacing this complex RLHF process with a supervised learning algorithm that implicitly optimizes the same objective as RLHF.

The DPO paper shows that the optimal solution to this KL-constrained reward maximization problem is:

$$\pi_{\theta}(y|x) = \frac{1}{Z(x)} \pi_{\mathrm{ref}}(y|x) \exp\left(\frac{1}{\beta}r(x, y)\right)$$

where:

$$Z(x) = \sum_{y'} \pi_{\mathrm{ref}}(y'|x) \exp\left(\frac{1}{\beta}r(x, y')\right)$$

While an exact solution exists, it is hard to utilize in practice. Computing $Z(x)$ requires summing over every possible answer $y'$ that our LLM could generate, which is computationally intractable.
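When the response space is tiny, $Z(x)$ is just a finite sum, so we can compute the optimal policy directly and see how $\beta$ trades off reward against staying close to the reference. A toy sketch, assuming a hypothetical prompt with only three candidate responses and made-up reference probabilities and rewards:

```python
import torch


def optimal_policy(reference_probs, rewards, beta):
    """pi*(y|x) = pi_ref(y|x) * exp(r(x, y) / beta) / Z(x) over a finite set of responses."""
    unnormalized = reference_probs * torch.exp(rewards / beta)
    partition = unnormalized.sum()  # Z(x): sum over every candidate response
    return unnormalized / partition


# hypothetical prompt with three candidate responses
reference_probs = torch.tensor([0.6, 0.3, 0.1])
rewards = torch.tensor([0.0, 1.0, 2.0])

for beta in [0.1, 1.0, 10.0]:
    print(beta, optimal_policy(reference_probs, rewards, beta))
```

A small $\beta$ puts almost all probability on the highest-reward response, while a large $\beta$ keeps the policy close to $\pi_{\mathrm{ref}}$. For an LLM, the sum defining $Z(x)$ ranges over every possible token sequence, which is why this direct computation is not an option.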

The trick is to take the log of both sides and rearrange the expression to solve for the reward function:

$$\log \pi_{\theta}(y|x) = \log \left[\frac{1}{Z(x)}\pi_{\mathrm{ref}}(y|x) \exp \left(\frac{1}{\beta} r(x,y)\right)\right] = \log \pi_{\mathrm{ref}}(y|x) - \log Z(x) + \log \exp \left(\frac{1}{\beta} r(x,y)\right) = \log \pi_{\mathrm{ref}}(y|x) - \log Z(x) + \frac{1}{\beta} r(x,y)$$

$$r(x, y) = \beta \log \frac{\pi_{\theta}(y|x)}{\pi_{\mathrm{ref}}(y|x)} + \beta \log Z(x)$$

Given this expression for the reward, we can now plug it back into the Bradley-Terry model:

$$\Pr(y_1 \succ y_2) = \sigma(r(x, y_1) - r(x, y_2)) = \sigma \left( \beta \log \frac{\pi_{\theta}(y_1|x)}{\pi_{\mathrm{ref}}(y_1|x)} + \beta \log Z(x) - \beta \log \frac{\pi_{\theta}(y_2|x)}{\pi_{\mathrm{ref}}(y_2|x)} - \beta \log Z(x) \right)$$

Since the Bradley-Terry model depends only on the reward difference between the two completions, the computationally expensive $Z(x)$ term cancels out, giving the final DPO loss:

$$L_{DPO}(\pi_{\theta}; \pi_{\mathrm{ref}}) = -\mathbb{E}_{(x,y_1,y_2)\sim D} \left[ \log \sigma \left( \beta \log \frac{\pi_{\theta}(y_1|x)}{\pi_{\mathrm{ref}}(y_1|x)} - \beta \log \frac{\pi_{\theta}(y_2|x)}{\pi_{\mathrm{ref}}(y_2|x)} \right) \right]$$

And with all that, we now have the probability of human preference data in terms of our optimal policy rather than a reward model.
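We can check the cancellation of $Z(x)$ on a toy example with three hypothetical candidate responses, taking $\pi_{\theta}$ to be the optimal policy. The implicit reward $\beta \log \frac{\pi_{\theta}(y|x)}{\pi_{\mathrm{ref}}(y|x)}$ differs from the true reward by the same constant $-\beta \log Z(x)$ for every response, so differences between responses, which are all the Bradley-Terry model uses, match the true reward differences exactly. The probabilities, rewards, and $\beta$ are made up for illustration:

```python
import torch

beta = 0.5
# hypothetical reference probabilities and rewards for three candidate responses
reference_probs = torch.tensor([0.6, 0.3, 0.1])
rewards = torch.tensor([0.0, 1.0, 2.0])

# optimal policy and its partition function Z(x)
unnormalized = reference_probs * torch.exp(rewards / beta)
partition = unnormalized.sum()
policy_probs = unnormalized / partition

# implicit reward: beta * log(pi_theta / pi_ref), computed from the two policies only
implicit_rewards = beta * torch.log(policy_probs / reference_probs)

# it equals the true reward shifted by the same constant, -beta * log Z(x) ...
print(implicit_rewards - rewards, -beta * torch.log(partition))
# ... so pairwise differences recover the true reward differences exactly
print(implicit_rewards[2] - implicit_rewards[0], rewards[2] - rewards[0])
```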

Similar to existing algorithms, Direct Preference Optimization (DPO) uses a preference model, such as the Bradley-Terry model, to evaluate how well a reward function aligns with empirical preference data. However, while RLHF methods train a reward model based on the preference model and then optimize a policy to maximize that learned reward, DPO takes a different approach: it uses a change of variables to directly define a preference loss as a function of the policy itself. This is DPO's main contribution: an algorithm that can train language models from human preferences through a simple binary cross-entropy objective, removing the need for reinforcement learning methods.

Implementation

This implementation section consists of two parts: we'll roll out the DPO loss calculation ourselves, then showcase how to leverage the trl library's DPOTrainer [3] [5]. We'll use UltraFeedback [10] as our dataset. UltraFeedback is a synthetic preference dataset collected via LLMs. At a high level, the authors compiled 60K diverse instructions, prompted a pool of distinct models at different capability levels to generate completions, and finally leveraged GPT-4 to rate the completions; the binarized version used here pairs a higher-rated (chosen) completion with a lower-rated (rejected) one.

In [3]:
dataset = load_dataset(
    "argilla/ultrafeedback-binarized-preferences-cleaned",
    split="train",
    verification_mode="no_checks",
    cache_dir="/data"
)
print(dataset)
dataset[0]
Dataset({
    features: ['source', 'prompt', 'chosen', 'chosen-rating', 'chosen-model', 'rejected', 'rejected-rating', 'rejected-model'],
    num_rows: 60917
})
Out[3]:
{'source': 'evol_instruct',
 'prompt': 'Can you write a C++ program that prompts the user to enter the name of a country and checks if it borders the Mediterranean Sea? Here\'s some starter code to help you out:\n#include <iostream>\n#include <string>\nusing namespace std;\nint main() {\n    string country;\n    // prompt user for input\n    cout << "Enter the name of a country: ";\n    cin >> country;\n    // check if country borders the Mediterranean Sea\n    // [C++ code]\n    return 0;\n}',
 'chosen': [{'content': 'Can you write a C++ program that prompts the user to enter the name of a country and checks if it borders the Mediterranean Sea? Here\'s some starter code to help you out:\n#include <iostream>\n#include <string>\nusing namespace std;\nint main() {\n    string country;\n    // prompt user for input\n    cout << "Enter the name of a country: ";\n    cin >> country;\n    // check if country borders the Mediterranean Sea\n    // [C++ code]\n    return 0;\n}',
   'role': 'user'},
  {'content': 'Here\'s a C++ program that prompts the user to enter the name of a country and checks if it borders the Mediterranean Sea:\n\n#include <iostream>\n#include <string>\n#include <set>\n#include <map>\n#include <algorithm>\n\nusing namespace std;\n\nint main() {\n    // store countries and their bordering seas in a map\n    map<string, set<string>> countries;\n    countries["Algeria"] = {"Mediterranean Sea", "North African Coast"};\n    countries["France"] = {"Mediterranean Sea", "English Channel"};\n    countries["Italy"] = {"Mediterranean Sea", "Adriatic Sea"};\n    countries["Malta"] = {"Mediterranean Sea", "North African Coast"};\n    countries["Spain"] = {"Mediterranean Sea", "Atlantic Ocean"};\n    countries["Tunisia"] = {"Mediterranean Sea", "North African Coast"};\n\n    string country;\n    cout << "Enter the name of a country: ";\n    cin >> country;\n\n    // check if country is in the map\n    if (countries.find(country)!= countries.end()) {\n        // check if country borders the Mediterranean Sea\n        set<string> borders = countries.at(country);\n        if (borders.find("Mediterranean Sea")!= borders.end()) {\n            cout << country << " borders the Mediterranean Sea." << endl;\n        } else {\n            cout << country << " does not border the Mediterranean Sea." << endl;\n        }\n    } else {\n        cout << country << " is not a recognized country." << endl;\n    }\n\n    return 0;\n}\n\nIn this program, we first define a map called countries that stores each country\'s name as a key and a set of its surrounding seas as the value. We use a set to store each sea name so that we can easily check if a country borders a certain sea.\n\nWe then present the user with a prompt to enter the name of a country. We use a while loop to ensure that the user enters a valid country name. 
We check if the country is in the map, and if so, we retrieve its bordering seas from the map.\n\nFinally, we check if the Mediterranean Sea is among the set of borders. If it is, we inform the user that the country borders the Mediterranean Sea. If not, we inform the user that the country does not border the Mediterranean Sea.',
   'role': 'assistant'}],
 'chosen-rating': 5.0,
 'chosen-model': 'starchat',
 'rejected': [{'content': 'Can you write a C++ program that prompts the user to enter the name of a country and checks if it borders the Mediterranean Sea? Here\'s some starter code to help you out:\n#include <iostream>\n#include <string>\nusing namespace std;\nint main() {\n    string country;\n    // prompt user for input\n    cout << "Enter the name of a country: ";\n    cin >> country;\n    // check if country borders the Mediterranean Sea\n    // [C++ code]\n    return 0;\n}',
   'role': 'user'},
  {'content': 'Sure, here is the program using the C++11 algorithm "cds::algorithm::GreaterEqual":\n#include <iostream>\n#include <string>\n#include <algorithm>\n#include <vector>\n#include <cctype>\n\nusing namespace std;\n\nint main() {\n    string country;\n    cout << "Enter the name of a country: ";\n    cin >> country;\n    std::vector<string> vec;\n    vec.push_back(country);\n    size_t index = std::find_if(vec.begin(), vec.end(), [](const string& s) {\n        return std::any_of(s.begin(), s.end(), [](const char& c) {\n            return c == \'}}\';\n    });\n    if (index != vec.end()) {\n        if ((*index)[0] == \'M\') {\n            cout << "Country is bordered by the Mediterranean Sea." << endl;\n        } else {\n            cout << "Country does not border the Mediterranean Sea." << endl;\n        }\n    } else {\n        cout << "Country is not found." << endl;\n    }\n    return 0;\n}',
   'role': 'assistant'}],
 'rejected-rating': 1.25,
 'rejected-model': 'pythia-12b'}
In [4]:
def create_preference_triplets(example):
    """
    Create preference triplets:
    
    - `prompt`: prompt that is given to a model for text generation.
    - `chosen`: preferred generated response for the corresponding prompt.
    - `rejected`: response that is not preferred.
    """
    chosen = extract_assistant_messages(example["chosen"], index=-1)
    rejected = extract_assistant_messages(example["rejected"], index=-1)
 
    return {
        "prompt": example["prompt"],
        "chosen": chosen,
        "rejected": rejected
    }


def extract_assistant_messages(messages, index=-1):
    """Recursively extract the last assistant message, searching backwards from the end of the conversation."""
    if messages[index]["role"] == "assistant":
        return messages[index]["content"]
    else:
        # keep searching towards the start and propagate the result back up
        return extract_assistant_messages(messages, index - 1)
In [5]:
dataset_dict = dataset.train_test_split(test_size=0.01, seed=54321)
dataset_dict_preprocessed = dataset_dict.map(
    create_preference_triplets,
    num_proc=8
)
dataset_dict_preprocessed["train"][0]
Out[5]:
{'source': 'flan_v2_niv2',
 'prompt': "In this task, you are given two sentences in the English language (Sentence 1 and Sentence 2). Your task is to identify the connecting word between the two sentences.\n\nExample input: Sentence 1:A suicide car bomber detonated his bomb as he attempted to drive an SUV into a Somali military camp . Sentence 2:However the vehicle exploded at the entrance when police guards opened fire on the insurgent driving the vehicle .\nExample output: however\nExample explanation: The above sentence is connected by the word 'however'.\nQ: Sentence 1:She went on to earn 39 caps in the Silver Ferns by the end of 2006 . Sentence 2:but was not selected for the 2007 World Netball Championships squad , subsequently retiring from netball to focus on beach volleyball , partnering with Susan Blundell .\nA:",
 'chosen': 'but\nConfidence: 90%',
 'chosen-rating': 4.5,
 'chosen-model': 'vicuna-33b',
 'rejected': 'The connecting word is "by"',
 'rejected-rating': 1.75,
 'rejected-model': 'starchat'}

For DPO, we need two copies of the model: one serving as the policy we wish to optimize, and another serving as the frozen reference model. Hence, the memory requirement will be higher than for supervised fine-tuning.
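A rough back-of-the-envelope sketch of weight memory, assuming fp32 weights, full fine-tuning with Adam (gradients plus two moment buffers for the policy), and roughly 0.5B parameters for Qwen2.5-0.5B; activations and framework overhead are ignored, and `model_memory_gib` is a hypothetical helper. It suggests the extra copy is cheaper than it sounds: the reference model only runs forward passes, so it needs its weights but no gradients or optimizer states.

```python
def model_memory_gib(num_params, bytes_per_param=4, trainable=True):
    """
    Rough memory for weights (+ gradients and two Adam moment buffers if trainable).
    Ignores activations, buffers, and framework overhead.
    """
    # weights only for a frozen model; weights + grads + Adam m + Adam v otherwise
    copies = 4 if trainable else 1
    return num_params * bytes_per_param * copies / 1024**3


num_params = 0.5e9  # approximate parameter count, assumed for illustration
policy_gib = model_memory_gib(num_params, trainable=True)
reference_gib = model_memory_gib(num_params, trainable=False)
print(f"policy: {policy_gib:.1f} GiB, reference: {reference_gib:.1f} GiB")
```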

In [6]:
model_name_or_path = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir="/data")
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

reference_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    cache_dir="/data"
)
reference_model.eval()

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    cache_dir="/data"
)
model.train()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Out[6]:
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
In [7]:
def get_batch_logps(logits, labels, label_pad_token_id=-100):
    """
    Get log probabilities summed over all tokens for a given batch.
    
    Parameters
    ----------
    logits: [batch size, seq length, vocab size]
    labels: [batch size, seq length]

    Returns
    -------
    logps: [batch size]
    """

    # next token prediction: labels are inputs shifted by one
    targets = labels[:, 1:].clone()
    # truncate logits to match labels' number of tokens
    logits = logits[:, :-1, :]

    loss_mask = targets != label_pad_token_id
    
    # dummy token; we'll ignore loss on these tokens later
    targets[targets == label_pad_token_id] = 0
    
    log_probs = logits.log_softmax(dim=-1)
    # [batch size, seq length - 1]
    per_token_logps = torch.gather(log_probs, dim=2, index=targets.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1)
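As a sanity check on the shifting and masking logic, the summed log probability equals the negative token-level cross-entropy summed over positions, with `ignore_index` playing the role of `label_pad_token_id`. A self-contained check on random tensors (the shapes and masked positions are arbitrary choices for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_length, vocab_size = 2, 5, 11
logits = torch.randn(batch_size, seq_length, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_length))
labels[0, :2] = -100  # e.g. masked prompt or padding tokens

# same shift as get_batch_logps: position t predicts token t + 1
targets = labels[:, 1:].clone()
shifted_logits = logits[:, :-1, :]

# gather-based version
mask = targets != -100
safe_targets = targets.masked_fill(~mask, 0)
per_token_logps = torch.gather(
    shifted_logits.log_softmax(dim=-1), dim=2, index=safe_targets.unsqueeze(2)
).squeeze(2)
gather_logps = (per_token_logps * mask).sum(-1)

# cross-entropy version: ignore_index zeroes out the same positions
ce = F.cross_entropy(
    shifted_logits.transpose(1, 2), targets, ignore_index=-100, reduction="none"
)
ce_logps = -ce.sum(-1)
print(gather_logps, ce_logps)
```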
In [8]:
example = dataset_dict_preprocessed["train"][0]
prompt = example["prompt"]
chosen = example["chosen"]
rejected = example["rejected"]
In [9]:
chosen_tokenized_outputs = tokenizer([prompt + chosen], padding=True, return_tensors="pt")
policy_chosen_model_outputs = model(
    chosen_tokenized_outputs["input_ids"],
    chosen_tokenized_outputs["attention_mask"]
)
policy_chosen_logps = get_batch_logps(
    policy_chosen_model_outputs.logits,
    chosen_tokenized_outputs["input_ids"]
)

rejected_tokenized_outputs = tokenizer([prompt + rejected], padding=True, return_tensors="pt")
policy_rejected_model_outputs = model(
    rejected_tokenized_outputs["input_ids"],
    rejected_tokenized_outputs["attention_mask"]
)
policy_rejected_logps = get_batch_logps(
    policy_rejected_model_outputs.logits,
    rejected_tokenized_outputs["input_ids"]
)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
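Note that the cell above passes the full `input_ids` as labels, so the summed log probabilities also include the prompt tokens. Because the chosen and rejected sequences share the same prompt, its contribution mostly cancels in the DPO loss, but common practice, and what TRL's DPOTrainer does during tokenization, is to set the prompt positions to `label_pad_token_id` so that only response tokens are scored. A minimal sketch of that masking, using a hypothetical helper `build_labels` on already tokenized (made-up) ids:

```python
import torch


def build_labels(prompt_ids, response_ids, label_pad_token_id=-100):
    """
    Concatenate prompt and response token ids and mask the prompt part of the labels.

    Returns input_ids and labels of shape [1, prompt length + response length].
    """
    input_ids = torch.tensor([prompt_ids + response_ids])
    labels = input_ids.clone()
    # prompt tokens are still fed to the model but excluded from the summed log prob
    labels[:, : len(prompt_ids)] = label_pad_token_id
    return input_ids, labels


# hypothetical token ids
input_ids, labels = build_labels([101, 7, 8, 9], [42, 43, 2])
print(input_ids)
print(labels)
```

Passing `labels` instead of `input_ids` as the second argument to `get_batch_logps` would then sum over the three response tokens only. Tokenizing the prompt and response separately also sidesteps the fact that tokenizing `prompt + chosen` as one string can merge tokens across the boundary.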
In [10]:
with torch.no_grad():
    reference_chosen_model_outputs = reference_model(
        chosen_tokenized_outputs["input_ids"],
        chosen_tokenized_outputs["attention_mask"]
    )
    reference_chosen_logps = get_batch_logps(
        reference_chosen_model_outputs.logits,
        chosen_tokenized_outputs["input_ids"]
    )
    
    reference_rejected_model_outputs = reference_model(
        rejected_tokenized_outputs["input_ids"],
        rejected_tokenized_outputs["attention_mask"]
    )
    reference_rejected_logps = get_batch_logps(
        reference_rejected_model_outputs.logits,
        rejected_tokenized_outputs["input_ids"]
    )
In [11]:
def dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    beta=0.1,
):
    """logps shape [batch size,]"""
    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps
    logits = chosen_logratios - rejected_logratios
    loss = -F.logsigmoid(beta * logits).mean(dim=-1)
    return loss
In [12]:
dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
)
Out[12]:
tensor(0.6931, grad_fn=<NegBackward0>)
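The value 0.6931 is no coincidence: before any training, the policy and reference model are identical, so both log ratios are zero and the loss reduces to $-\log \sigma(0) = \log 2$. Differentiating the loss with respect to the margin $z$ (chosen log ratio minus rejected log ratio) also shows how DPO weights examples: the gradient is $-\beta \sigma(-\beta z)$, which is largest in magnitude for pairs the policy currently ranks incorrectly. A small check, with the margins chosen arbitrarily:

```python
import math

import torch
import torch.nn.functional as F

beta = 0.1
# z = chosen log ratio - rejected log ratio, for three hypothetical pairs
margins = torch.tensor([0.0, 20.0, -20.0], requires_grad=True)
losses = -F.logsigmoid(beta * margins)
losses.sum().backward()

print(losses[0].item(), math.log(2))  # loss at initialization
print(margins.grad)  # equals -beta * sigmoid(-beta * margins)
```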

DPOTrainer

In [13]:
training_args = DPOConfig(
    output_dir="./dpo_sample_model",
    fp16=True,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_steps=100,
    logging_steps=50,
    learning_rate=0.0001,
    beta=0.1,
    max_length=512,
    max_prompt_length=256,
    remove_unused_columns=False,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
)
peft_config = LoraConfig(
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "down_proj",
        "up_proj",
        "gate_proj"
    ],
    modules_to_save=[
        "embed_tokens",
        "lm_head"
    ]
)
dpo_trainer = DPOTrainer(
    model,
    train_dataset=dataset_dict_preprocessed["train"],
    eval_dataset=dataset_dict_preprocessed["test"],
    tokenizer=tokenizer,
    args=training_args,
    peft_config=peft_config,
)
dpo_trainer.train()
Tokenizing train dataset: 100%|██████████| 60307/60307 [05:22<00:00, 187.11 examples/s]
Tokenizing eval dataset: 100%|██████████| 610/610 [00:03<00:00, 185.96 examples/s]
max_steps is given, it will override any value given in num_train_epochs
Could not estimate the number of tokens of the input, floating-point operations will not be computed
[100/100 06:22, Epoch 0/1]
Step | Training Loss | Validation Loss | Rewards/chosen | Rewards/rejected | Rewards/accuracies | Rewards/margins | Logps/rejected | Logps/chosen | Logits/rejected | Logits/chosen
50 | 0.884900 | 0.917661 | -1.439814 | -1.465053 | 0.511475 | 0.025239 | -272.778931 | -303.907043 | -1.459422 | -1.637277
100 | 0.854500 | 0.940485 | -1.702263 | -1.789815 | 0.522951 | 0.087551 | -276.026550 | -306.531555 | -1.460726 | -1.642910

Out[13]:
TrainOutput(global_step=100, training_loss=0.8696735000610352, metrics={'train_runtime': 384.2274, 'train_samples_per_second': 1.041, 'train_steps_per_second': 0.26, 'total_flos': 0.0, 'train_loss': 0.8696735000610352, 'epoch': 0.006632729202248495})
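The reported numbers can be reconciled with the configuration. With `per_device_train_batch_size=1` and `gradient_accumulation_steps=4`, and assuming a single device (which the throughput numbers are consistent with), each optimizer step consumes 4 preference pairs, so 100 steps cover 400 of the 60,307 training examples:

```python
train_size = 60307
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_devices = 1  # assumption: single-GPU run
max_steps = 100
train_runtime = 384.2274  # seconds, from the TrainOutput above

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
samples_seen = effective_batch_size * max_steps

print(samples_seen / train_size)     # fraction of an epoch
print(samples_seen / train_runtime)  # samples per second
print(max_steps / train_runtime)     # steps per second
```

These match the reported `epoch`, `train_samples_per_second`, and `train_steps_per_second`; in other words, this demo run covers less than 1% of one epoch.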

The DPO trainer logs several reward-related metrics for us to monitor during training:

  • rewards/chosen: the mean difference between the policy model's and the reference model's log probabilities for the chosen responses, scaled by beta.
  • rewards/rejected: the same as above, but for the rejected responses.
  • rewards/accuracies: how often rewards/chosen is greater than the corresponding rewards/rejected.
  • rewards/margins: the mean difference between the chosen and corresponding rejected rewards.
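All four metrics can be computed from the same log probabilities that go into the DPO loss. A sketch under the assumption that they are computed as described above, with the rewards detached so they don't affect gradients; `reward_metrics` is a hypothetical helper and the input values are made up:

```python
import torch


def reward_metrics(
    policy_chosen_logps,
    policy_rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    beta=0.1,
):
    """Implicit-reward metrics for a batch; every input has shape [batch size]."""
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return {
        "rewards/chosen": chosen_rewards.mean().item(),
        "rewards/rejected": rejected_rewards.mean().item(),
        "rewards/accuracies": (chosen_rewards > rejected_rewards).float().mean().item(),
        "rewards/margins": (chosen_rewards - rejected_rewards).mean().item(),
    }


metrics = reward_metrics(
    policy_chosen_logps=torch.tensor([-300.0, -250.0]),
    policy_rejected_logps=torch.tensor([-280.0, -260.0]),
    reference_chosen_logps=torch.tensor([-290.0, -255.0]),
    reference_rejected_logps=torch.tensor([-275.0, -250.0]),
)
print(metrics)
```

This also explains why, in the training log above, both rewards/chosen and rewards/rejected are negative while rewards/margins stays positive: the policy lowers the likelihood of both responses relative to the reference, and lowers the rejected one slightly more.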

We can also compare the final results using an LLM as a pairwise judge. In our case, this boils down to comparing the answers generated by the original base/instruct model with those of the model that was aligned using DPO.

llm_judge_responses.parquet contains pairwise comparisons between a Qwen 2.5 3B Instruct reference model and the same reference model after alignment with DPO on the UltraFeedback dataset.

  1. The 3B model was trained via dpo_train.py (approximately 1 hour on 16 A100 GPUs).
  2. Responses were generated via generate.py.
  3. LLM-based judgements were curated via claude_judge.py. The LLM-as-a-judge script largely follows the concept introduced in a separate notebook, except that we call Claude through the AWS Bedrock API to obtain judge responses.
In [16]:
df_llm_judge_responses = pd.read_parquet("llm_judge_responses_v7.parquet")
print(df_llm_judge_responses.shape)
df_llm_judge_responses.head()
(61, 6)
Out[16]:
prompts responses1 responses2 judge_responses judge_responses_swapped_position result
0 Is a raspberry pi powerful enough to run real ... I have a raspberry pi 3 and I want to run a c... For running real-time chess software on a Ras... 2 1 win
1 In this task, you will be presented with a que... B\nThe answer to this example can be: B\nHere... B\nThe answer to this example can be: B\nHere... 1 1 tie
2 Name one kind of technologie that stones use m... Stones are typically not used in daily life b... It's important to note that stones themselves... 2 1 win
3 Which clown is more famous in Thunderstorm? A. Red Aiken B. White Aiken C. Both Aiken clo... In the context of the play "Thunderstorm" by ... 1 1 tie
4 Did your mouth hear the music? Did your mouth taste the food? Did your mouth... Or did your ears hear the music?\n\nWhen we t... 2 1 win
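The result column combines two judgements, one per response ordering, to control for the judge's position bias. Judging from the rows shown, where responses1 is the instruct model and responses2 the DPO model, a 2 in judge_responses followed by a 1 in judge_responses_swapped_position means the judge preferred the DPO response in both orderings. The rule below is our reading of those rows rather than the exact logic in claude_judge.py, which is not shown here; the "lose" label and the toy judgements are hypothetical:

```python
import pandas as pd


def combine_judgements(original, swapped):
    """
    original: preferred position (1 or 2) with the instruct response shown first
    swapped: preferred position (1 or 2) with the DPO response shown first
    """
    if original == 2 and swapped == 1:
        return "win"  # DPO response preferred in both orderings
    if original == 1 and swapped == 2:
        return "lose"  # instruct response preferred in both orderings
    return "tie"  # the judge's preference flipped with position


# hypothetical judgements mirroring the first rows above
df = pd.DataFrame({
    "judge_responses": [2, 1, 2, 1],
    "judge_responses_swapped_position": [1, 1, 1, 2],
})
df["result"] = [
    combine_judgements(original, swapped)
    for original, swapped in zip(df["judge_responses"], df["judge_responses_swapped_position"])
]
print(df["result"].value_counts(normalize=True))
```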
In [18]:
# dpo model preferred 42%
# instruct model preferred 23%
index = 0
print(df_llm_judge_responses["prompts"].iloc[index])
print("\n instruct model response: \n")
print(df_llm_judge_responses["responses1"].iloc[index])
print("\n dpo model response: \n")
print(df_llm_judge_responses["responses2"].iloc[index])
Is a raspberry pi powerful enough to run real time chess software?

 instruct model response: 

 I have a raspberry pi 3 and I want to run a chess engine on it. I have a 16GB microSD card and 1GB RAM. I have a 1080p monitor and a keyboard and mouse. I have a 1.4GHz processor. I have a 100Mbps internet connection. I have a 100W power supply. I have a 100% clean environment. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I have a 100% stable internet connection. I have a 100% stable power supply. I

 dpo model response: 

 For running real-time chess software on a Raspberry Pi, it's important to consider both the hardware capabilities and the specific requirements of the chess software you plan to use.

Typically, chess software requires a significant amount of processing power, especially for high-level games and complex algorithms. The Raspberry Pi, while quite capable for many tasks, has some limitations that might affect its performance for chess software.

Here are some key factors to consider:

1. CPU: The Raspberry Pi 4 has a 1.5 GHz quad-core ARM Cortex-A72 CPU, which is quite powerful. However, chess software can still be quite demanding on CPU resources, especially for high-level games and complex algorithms.

2. Memory: The Raspberry Pi 4 has 4GB of RAM, which is sufficient for most tasks, but chess software might still require more memory for optimal performance.

3. Storage: The Raspberry Pi 4 has 64GB of eMMC storage, which is more than enough for chess software.

4. GPU: The Raspberry Pi 4 has a Broadcom VideoCore IV GPU, which is not as powerful as dedicated GPUs found in more powerful computers. This might affect chess software performance, especially for games with complex graphics or animations.

5.

Since DPO was introduced, subsequent works such as IPO and KTO have attempted to further improve upon it. However, thorough comparisons of these approaches so far seem to indicate that they all offer similar performance when hyperparameter scans are properly conducted [4].

Reference

  • [1] Youtube: Direct Preference Optimization (DPO) explained
  • [2] Youtube: CS 285: Eric Mitchell: Reinforcement Learning from Human Feedback: Algorithms & Applications
  • [3] Huggingface Blog: Fine-tune Llama 2 with DPO
  • [4] Huggingface Blog: Preference Tuning LLMs with Direct Preference Optimization Methods
  • [5] RLHF in 2024 with DPO & Hugging Face
  • [6] Unveiling the Hidden Reward System in Language Models: A Dive into DPO
  • [7] Github: DPO - Direct Preference Optimization
  • [8] Deriving DPO’s Loss
  • [9] Rafael Rafailov, Archit Sharma, Eric Mitchell et al. - Direct Preference Optimization: Your Language Model is Secretly a Reward Model (2023)
  • [10] Ganqu Cui, Lifan Yuan, et al. - UltraFeedback: Boosting Language Models with Scaled AI Feedback (2023)