In [2]:
%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import ray
import torch
import logging
import pandas as pd
from ray.util.placement_group import (
    placement_group,
    placement_group_table,
    remove_placement_group,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import LLM, SamplingParams

%watermark -a 'Ethen' -d -v -u -p transformers,torch,numpy,pandas,ray,vllm
Author: Ethen

Last updated: 2025-06-27

Python implementation: CPython
Python version       : 3.11.11
IPython version      : 8.32.0

transformers: 4.51.3
torch       : 2.6.0
numpy       : 1.26.4
pandas      : 2.2.3
ray         : 2.46.0
vllm        : 0.8.5

LLM Batch Inference with Ray and vLLM

This article builds toward a minimal LLM batch inference pipeline. Along the way, we'll give quick introductions to Ray, vLLM, and tensor parallelism before putting these building blocks together into the final solution.

Quick Introduction to Ray

In [3]:
# Create a single node Ray cluster with pre-defined resources.
ray.init(
    num_cpus=8,
    num_gpus=4,
    # avoid polluting notebook with log info
    log_to_driver=False,
)
2025-06-17 00:36:27,157	INFO worker.py:1879 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265 
Out[3]:

Ray enables arbitrary functions to be executed asynchronously on separate Python workers. Such functions are called Ray remote functions and their asynchronous invocations are called Ray tasks.

In [4]:
# By adding the `@ray.remote` decorator, a regular Python function
# becomes a Ray remote function
@ray.remote
def my_function():
    return 1

# To invoke this remote function, use the `remote` method.
# This will immediately return an object ref (a future) and then create
# a task that will be executed on a worker process.
# This call is non-blocking
obj_ref = my_function.remote()

# The result can be retrieved with ``ray.get``.
assert ray.get(obj_ref) == 1

Actors extend the Ray API from functions (tasks) to classes. An actor is essentially a stateful worker.

In [5]:
# The ray.remote decorator indicates that instances of the Counter class are actors
# We can specify resource requirements as part of the remote decorator
@ray.remote(num_cpus=1, num_gpus=1)
class Counter:
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value

    def get_counter(self):
        return self.value


# Create an actor from this class.
counter = Counter.remote()
# Call the actor.
obj_ref = counter.increment.remote()
assert ray.get(obj_ref) == 1

Ray's placement groups can be used to reserve groups of resources across nodes. For example, in distributed hyperparameter tuning, we must ensure all the resources needed for a given trial are available at the same time, and pack the required resources together so node failures have minimal impact.

While creating these resource bundles, the strategy argument lets us specify whether the bundles should be spread across multiple nodes (STRICT_SPREAD) or placed within the same node (STRICT_PACK).

We can schedule Ray actors or tasks onto a placement group once it has been created.

In [6]:
# a bundle is a collection of resources, e.g. 1 CPU and 1 GPU,
# and a placement group is represented by a list of bundles
pg = placement_group([{"CPU": 1, "GPU": 1}], strategy="STRICT_PACK")
ray.get(pg.ready())

# we can show placement group info through the placement_group_table API
print(placement_group_table(pg))
{'placement_group_id': '3ebe34c338652490a05ba98f0ec501000000', 'name': '', 'bundles': {0: {'CPU': 1.0, 'GPU': 1.0}}, 'bundles_to_node_id': {0: 'ab8f683a5b2ed845357850df3db9de9f807e15ed5ebbdbbdb8cc5bba'}, 'strategy': 'STRICT_PACK', 'state': 'CREATED', 'stats': {'end_to_end_creation_latency_ms': 1.641, 'scheduling_latency_ms': 1.568, 'scheduling_attempt': 1, 'highest_retry_delay_ms': 0.0, 'scheduling_state': 'FINISHED'}}
In [7]:
# Schedule an actor onto a placement group.
counter = Counter.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg,
        placement_group_bundle_index=0
    )
).remote()
obj_ref = counter.increment.remote()
assert ray.get(obj_ref) == 1
In [8]:
remove_placement_group(pg)
print(placement_group_table(pg))
{'placement_group_id': '3ebe34c338652490a05ba98f0ec501000000', 'name': '', 'bundles': {0: {'CPU': 1.0, 'GPU': 1.0}}, 'bundles_to_node_id': {0: 'ab8f683a5b2ed845357850df3db9de9f807e15ed5ebbdbbdb8cc5bba'}, 'strategy': 'STRICT_PACK', 'state': 'REMOVED', 'stats': {'end_to_end_creation_latency_ms': 1.641, 'scheduling_latency_ms': 1.568, 'scheduling_attempt': 1, 'highest_retry_delay_ms': 0.0, 'scheduling_state': 'REMOVED'}}

Quick Introduction to vLLM

When interacting with vLLM, LLM is the main class for initializing the vLLM engine, while SamplingParams defines various parameters for the sampling process.
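
The llm engine object used in the next cell was created in an earlier cell that is not shown here. As a minimal sketch, assuming the same Qwen instruct model used later in the batch inference section, it could have been initialized along these lines:

# hypothetical initialization of the engine used below; the exact model name
# is an assumption based on the batch inference section later in this article
llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", dtype=torch.bfloat16)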

In [10]:
# we expect our model to be a modern instruct version that takes in chat messages
messages = [
    {"role": "user", "content": "Give me a short introduction to large language model."}
]
sampling_params = SamplingParams(n=1, max_tokens=512)
# we invoke the chat method; compared to generate, it automatically applies
# the model's corresponding chat template
request_outputs = llm.chat(messages, sampling_params)
request_outputs
INFO 06-17 00:37:18 [chat_utils.py:397] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s, est. speed input: 42.93 toks/s, output: 151.91 toks/s]
Out[10]:
[RequestOutput(request_id=0, prompt=None, prompt_token_ids=[151644, 8948, 198, 2610, 525, 1207, 16948, 11, 3465, 553, 54364, 14817, 13, 1446, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 35127, 752, 264, 2805, 16800, 311, 3460, 4128, 1614, 13, 151645, 198, 151644, 77091, 198], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='A large language model, also known as a Language Generation Model, is a type of deep learning model that can generate text or language-like outputs. These models are designed to be able to generate coherent and contextually relevant text based on language rules and statistical patterns observed in language data.\nIn recent years, the field of large language models has grown rapidly, with the development of models such as BERT, GPT, and Transformers. These models have been used for a wide range of tasks, including machine translation, text generation, language modeling, and more. Today, large language models are being used in a variety of applications, including chatbots, virtual assistants, and content creation tools.', token_ids=[32, 3460, 4128, 1614, 11, 1083, 3881, 438, 264, 11434, 23470, 4903, 11, 374, 264, 943, 315, 5538, 6832, 1614, 429, 646, 6923, 1467, 476, 4128, 12681, 16275, 13, 4220, 4119, 525, 6188, 311, 387, 2952, 311, 6923, 55787, 323, 2266, 1832, 9760, 1467, 3118, 389, 4128, 5601, 323, 28464, 12624, 13166, 304, 4128, 821, 624, 641, 3213, 1635, 11, 279, 2070, 315, 3460, 4128, 4119, 702, 14700, 18512, 11, 448, 279, 4401, 315, 4119, 1741, 438, 425, 3399, 11, 479, 2828, 11, 323, 80532, 13, 4220, 4119, 614, 1012, 1483, 369, 264, 6884, 2088, 315, 9079, 11, 2670, 5662, 14468, 11, 1467, 9471, 11, 4128, 33479, 11, 323, 803, 13, 11201, 11, 3460, 4128, 4119, 525, 1660, 1483, 304, 264, 8045, 315, 8357, 11, 2670, 6236, 61905, 11, 4108, 56519, 11, 323, 2213, 9688, 7375, 13, 151645], cumulative_logprob=None, logprobs=None, finish_reason=stop, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=None, multi_modal_placeholders={})]

From this output class, we only parse out the generated text/response as well as the finish_reason. The finish reason is useful for determining whether we've set an appropriate max generation token limit, i.e. to prevent cropping the model's unfinished response.

In [11]:
finished_reasons = []
generated_texts = []
for request_output in request_outputs:
    # we assume only sampling 1 output
    output = request_output.outputs[0]
    generated_texts.append(output.text)
    finished_reasons.append(output.finish_reason)

predictions = {
    "generated_texts": generated_texts,
    "finished_reasons": finished_reasons
}
predictions
Out[11]:
{'generated_texts': ['A large language model, also known as a Language Generation Model, is a type of deep learning model that can generate text or language-like outputs. These models are designed to be able to generate coherent and contextually relevant text based on language rules and statistical patterns observed in language data.\nIn recent years, the field of large language models has grown rapidly, with the development of models such as BERT, GPT, and Transformers. These models have been used for a wide range of tasks, including machine translation, text generation, language modeling, and more. Today, large language models are being used in a variety of applications, including chatbots, virtual assistants, and content creation tools.'],
 'finished_reasons': ['stop']}
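
Since the finish reason flags whether generation stopped naturally ("stop") or hit the token limit ("length"), a quick follow-up check on the predictions dict above, sketched below, can tell us whether max_tokens needs to be raised:

# count responses that were cut off by the max_tokens limit;
# a non-zero count suggests raising the max generation tokens
num_truncated = sum(reason == "length" for reason in predictions["finished_reasons"])
print(f"truncated responses: {num_truncated}")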
In [12]:
# free the standalone vLLM engine before building the batch inference pipeline
del llm

Quick Introduction to Tensor Parallelism

The following diagrams are directly copied from PyTorch Lightning's tensor parallelism illustration [3].

Data parallelism is the most common form of parallelism due to its simplicity: the model is replicated on each device, and each replica processes a shard of the data in parallel. Tensor parallelism is a form of model parallelism used for large-scale models, distributing individual layers across multiple devices. This approach significantly reduces memory requirements per device, as each device only needs to store and process a portion of the weight matrix. There are two ways in which a linear layer can be distributed, column-wise or row-wise; a small numerical sketch after the lists below illustrates both.

Column-wise Parallelism:

  • Weight matrix is divided evenly along the column dimension.
  • Each device receives identical input and performs a matrix multiplication with its allocated portion of the weight matrix.
  • Final output is formed by concatenating results from all devices.

Row-wise Parallelism:

  • Rows of the weight matrix are divided evenly across devices. Since each device now holds only a slice of the weight's rows, the input also needs to be split along its inner dimension.
  • Each device then performs a matrix multiplication with its portion of the weight matrix and inputs.
  • Outputs from each device can be summed up element-wise (all-reduce) to form the final output.

Combined Column- and Row-wise Parallelism:

  • This hybrid approach is particularly effective for model architectures that have sequential linear layers, such as MLPs or Transformers.
  • The output of a column-wise parallel layer is maintained in its distributed form and directly fed into a subsequent row-wise parallel layer. This strategy minimizes inter-device data transfers, optimizing computational efficiency.
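
To make the column-wise and row-wise schemes concrete, below is a minimal numerical sketch using plain PyTorch tensor arithmetic (not vLLM's actual sharding code): it splits a single linear layer y = x @ W across two hypothetical devices and verifies that both schemes recover the full result.

import torch

x = torch.randn(4, 8)  # batch of 4 inputs, hidden size 8
W = torch.randn(8, 6)  # weight matrix mapping 8 -> 6

# column-wise: split W along its output (column) dimension,
# every device sees the full input, partial outputs are concatenated
W_cols = W.chunk(2, dim=1)
y_col = torch.cat([x @ w for w in W_cols], dim=1)

# row-wise: split W along its input (row) dimension, the input is split
# along its inner dimension, partial outputs are summed (an all-reduce)
W_rows = W.chunk(2, dim=0)
x_splits = x.chunk(2, dim=1)
y_row = sum(xi @ wi for xi, wi in zip(x_splits, W_rows))

assert torch.allclose(y_col, x @ W, atol=1e-5)
assert torch.allclose(y_row, x @ W, atol=1e-5)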

Batch Inference

We will now combine the knowledge we've accumulated around Ray Core and vLLM, and add in Ray Data to build a batch inference pipeline using 2D parallelism: data parallel plus tensor parallel [1] [2].

Our example demonstrates parallelization configured for multiple GPUs within a single machine. However, the primary application of 2D parallelism is in multi-node environments, where it typically involves applying data parallelism across nodes and tensor parallelism within each node. The reason is that tensor parallelism requires blocking collective calls, making fast communication crucial for maintaining high throughput.

  Data Parallelism (across nodes)
  <----------------------------->

Node 0              Node 1         
+------------+     +------------+ 
|  GPU 0     |     |  GPU 0     | 
|    ↕       |     |    ↕       |    
|  GPU 1     |     |  GPU 1     |     
+------------+     +------------+     
     ↕                   ↕          
Tensor Parallel     Tensor Parallel

Legend:
↔ Data Parallelism: Horizontal scaling across nodes
↕ Tensor Parallelism: Vertical scaling across all GPUs within a node
In [ ]:
# we pick a smaller model to quickly showcase the concept; it's very likely
# that a pure data parallel approach would be faster for a model this small
pretrained_model_name_or_path = "Qwen/Qwen2.5-1.5B-Instruct"

# Set the tensor parallelism degree for each LLM engine replica.
tensor_parallel_size = 2

# Total number of GPUs to use. Each replica occupies tensor_parallel_size GPUs,
# so concurrency is the number of replicas that run in parallel.
num_instances = 4
concurrency = num_instances // tensor_parallel_size

prediction_path = "vllm_prediction"

sampling_params = SamplingParams(n=1, temperature=0.6, max_tokens=512)
In [14]:
# Create a Ray Dataset from a list of dictionaries, so we can quickly mock some input data;
# in real-world scenarios, read actual data via ray.data.read_parquet / ray.data.read_text
data_dicts = [{"messages": messages}] * 32
ds = ray.data.from_items(data_dicts).repartition(concurrency)
2025-06-17 00:37:19,523	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
In [15]:
# Define a class for inference.
# Use a class to initialize the model just once in `__init__`
# and re-use it for inference across multiple batches.
class LLMPredictor:

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        sampling_params: SamplingParams,
        tensor_parallel_size: int = 1,
    ):
        self.llm = LLM(
            pretrained_model_name_or_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=torch.bfloat16,
        )
        self.sampling_params = sampling_params

    # Logic for inference on 1 batch of data.
    def __call__(self, batch):
        # ray data passes each batch as a dict of numpy arrays; convert every
        # "messages" entry back to a plain python list of chat messages
        batch_messages = [messages.tolist() for messages in batch["messages"]]
        request_outputs = self.llm.chat(batch_messages, self.sampling_params)

        finished_reasons = []
        generated_texts = []
        for request_output in request_outputs:
            # we assume only sampling 1 output
            output = request_output.outputs[0]
            generated_texts.append(output.text)
            finished_reasons.append(output.finish_reason)
    
        return {
            "generated_texts": generated_texts,
            "finished_reasons": finished_reasons
        }
In [16]:
def scheduling_strategy_fn():
    """For tensor_parallel_size > 1, we need to create one bundle per tensor parallel worker"""
    pg = placement_group(
        [{"CPU": 1, "GPU": 1}] * tensor_parallel_size,
        strategy="STRICT_PACK"
    )
    return dict(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            pg,
            # ensure vLLM's internal tensor parallel workers (child tasks/actors)
            # are scheduled inside this same placement group
            placement_group_capture_child_tasks=True,
        )
    )


# define resources required for each actor
resources_kwarg = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
In [ ]:
ds_prediction = ds.map_batches(
    LLMPredictor,
    concurrency=concurrency,
    batch_size=4,
    fn_constructor_kwargs={
        "pretrained_model_name_or_path": pretrained_model_name_or_path,
        "tensor_parallel_size": tensor_parallel_size,
        "sampling_params": sampling_params,
    },
    **resources_kwarg,
)
ds_prediction.write_parquet(prediction_path)
In [22]:
# read some sample output to showcase a valid
# LLM inference result
pd.read_parquet(prediction_path).iloc[:3].values
Out[22]:
array([['A large language model is a type of artificial intelligence (AI) system designed to understand and generate human language. These models are trained on vast amounts of text data, allowing them to learn patterns and relationships within language. They can be used for a variety of tasks, such as language translation, text summarization, question answering, and even creative writing. The ability of large language models to generate text that is contextually appropriate and coherent has made them popular in recent years, particularly for their ability to improve the quality of human-generated language.',
        'stop'],
       ['Large language models, also known as AI language models, are artificial intelligence systems that are designed to understand, generate, and respond to human language. These models are based on advanced algorithms and machine learning techniques that allow them to analyze vast amounts of text data and learn patterns and relationships between words and phrases.\nOne of the key characteristics of large language models is their ability to generate human-like text that is coherent and contextually appropriate. This means that they can be used for a wide range of tasks, including natural language processing, language translation, text generation, and even chatbot development.\nIn recent years, large language models have become increasingly popular due to their ability to process large amounts of text data quickly and accurately. They are used in a variety of industries, including tech, finance, healthcare, and more.',
        'stop'],
       ['A large language model is a type of artificial intelligence that is designed to generate human-like text based on the input it receives. These models are trained on vast amounts of text data, such as books, articles, and other written material, and are capable of generating coherent and contextually relevant responses to a wide range of prompts or questions. Large language models are used in a variety of applications, including language translation, chatbots, virtual assistants, and content generation.',
        'stop']], dtype=object)

Reference

  • [1] vLLM offline batch inference example
  • [2] End-to-end: Offline Batch Inference
  • [3] PyTorch Lightning Tensor Parallelism