Concurrent Async Calls to OpenAI with Rate Limiting

0. Motivation

In a previous post (Process Calls with Open AI), we explored how to extract insights from audio recordings using OpenAI models. While extracting insights is critical, fully automating this process presents additional challenges, primarily in managing API quota limits and efficiently handling multiple concurrent API calls.

Consider a scenario where you need to process thousands of audio recordings daily. Without proper management, exceeding the API rate limits could lead to failed requests, delays, or even service disruptions. This post addresses these challenges by demonstrating how to implement a rate limiter that effectively manages concurrent asynchronous calls to the OpenAI API. By ensuring compliance with API quotas and dynamically handling limits, you can maximize the throughput of your requests while maintaining service reliability. Whether you’re using Whisper for transcription or ChatGPT for extracting insights, this guide will help you streamline operations, avoid common pitfalls, and optimize the use of OpenAI’s powerful capabilities.

1. Approach overview

To efficiently manage concurrent API calls to OpenAI while respecting rate limits, we use a two-pronged approach:

  1. Rate Limiting: We implement a custom RateLimiter class to monitor and control the number of requests and tokens used. This rate limiter dynamically adjusts based on API response headers, ensuring that the application remains within OpenAI’s usage policies. By doing so, it prevents overuse and potential service disruptions by pausing execution when limits are reached and resuming once quotas are reset.

  2. Asynchronous Execution: Leveraging Python’s asyncio and the aiohttp library, we can handle multiple API calls concurrently without blocking other tasks. This asynchronous approach is vital for maintaining high throughput and efficiency, as it allows each API call to be managed in a non-blocking manner. The system can handle other operations while waiting for API responses, maximizing resource utilization.

Combining these strategies enables the processing of a large volume of API requests both efficiently and reliably. By structuring outputs using Pydantic models, we ensure that data remains organized and easily manageable. This approach not only optimizes performance but also preserves the integrity and reliability of data processing tasks, making it well-suited for scaling up operations.
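As a minimal sketch of the asynchronous half of this approach, the snippet below runs several simulated calls concurrently while a semaphore caps how many are in flight. `fake_api_call` and `run_all` are hypothetical names used only for illustration, and no network request is made.

```python
import asyncio


async def fake_api_call(i, semaphore, state):
    """Stand-in for a real API call (hypothetical, no network involved)."""
    async with semaphore:
        state["running"] += 1
        state["peak"] = max(state["peak"], state["running"])
        await asyncio.sleep(0.01)  # Simulate network latency
        state["running"] -= 1
    return i * 2


async def run_all(n_calls=10, max_concurrency=3):
    semaphore = asyncio.Semaphore(max_concurrency)
    state = {"running": 0, "peak": 0}

    # All coroutines are scheduled at once; the semaphore limits concurrency
    tasks = [fake_api_call(i, semaphore, state) for i in range(n_calls)]
    results = await asyncio.gather(*tasks)
    return results, state["peak"]


results, peak = asyncio.run(run_all())
print(f"{len(results)} calls done, at most {peak} running at the same time")
```

The results come back in the same order as the submitted tasks, which makes it easy to match each output to its input.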

2. Quota limits

2.1. Understanding OpenAI Rate Limits

OpenAI enforces rate limits to manage API usage effectively, ensuring fair use and stable performance for all users. The primary metrics to consider are:

  • Requests Per Minute (RPM): Limits the number of API requests allowed per minute. This is crucial for high-frequency tasks like transcribing audio with whisper-1, where RPM can range significantly depending on the subscription tier.
  • Requests Per Day (RPD): Caps the total number of API requests that can be made in a day.
  • Tokens Per Minute (TPM): Limits the number of text tokens processed per minute, including both input and output. This is particularly relevant for ChatGPT models, where the number of tokens depends on the length of the conversation.
  • Tokens Per Day (TPD): Sets a daily limit on token usage across all operations.

The specific rate limits vary by subscription tier. For more details, refer to Open AI | Rate Limits.
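To see which limit binds in practice, a quick back-of-the-envelope calculation helps. The sketch below is illustrative only: `max_calls_per_minute` is a hypothetical helper and the numbers are made up, not real tier limits.

```python
def max_calls_per_minute(rpm, tpm, tokens_per_call):
    """Return how many calls fit in one minute under both limits."""
    calls_allowed_by_tokens = tpm // tokens_per_call
    return min(rpm, calls_allowed_by_tokens)


# Hypothetical limits: 500 RPM, 200k TPM and calls of about 1,500 tokens each
limit = max_calls_per_minute(rpm=500, tpm=200_000, tokens_per_call=1_500)
print(limit)  # The token limit binds here, well below the 500 RPM
```

With long prompts the token limit is usually reached first, while many short requests tend to hit the request limit first.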

2.2. Managing Rate Limits

Understanding and managing these limits is essential, especially when scaling operations or integrating with automated workflows. Here’s how to handle them effectively:

  • Requests Per Minute (RPM): Critical when making frequent API calls. For example, using whisper-1 for transcriptions may involve a high number of requests, requiring careful monitoring of RPM to prevent exceeding the limit.

  • Tokens Per Minute (TPM): When using models like ChatGPT, it’s important to estimate token usage accurately before making a request. The following function provides a good approximation of token count:

import tiktoken
from pydantic import BaseModel
from pydantic import Field

MODEL_GPT_4 = "gpt-4"
MODEL_TOKENIZER_DEFAULT = MODEL_GPT_4


class GPTTokens(BaseModel):
    prompt_tokens: int = Field(
        0, description="The number of tokens used in the prompt."
    )
    completion_tokens: int = Field(
        0, description="The number of tokens used in the completion."
    )
    total_tokens: int = Field(
        0, description="The total number of tokens used (prompt + completion)."
    )


def count_tokens(messages, model=MODEL_TOKENIZER_DEFAULT):
    """Estimate the number of tokens used by the messages."""

    # I believe the tokenization does not change between sub-models
    if model.startswith(MODEL_GPT_4):
        model = MODEL_GPT_4

    encoding = tiktoken.encoding_for_model(model)

    tokens_per_message = 3  # A rough estimate for the start of each message
    tokens_per_name = 1  # If names are used, they add an extra token

    total_tokens = 0
    for message in messages:
        total_tokens += tokens_per_message
        for key, value in message.items():
            total_tokens += len(encoding.encode(value))
            if key == "name":
                total_tokens += tokens_per_name

    # Adding 3 tokens for the overall start and end of the message
    total_tokens += 3

    return total_tokens

This function estimates token usage effectively, but note that actual usage may vary slightly.

The GPTTokens class will be used to track actual token usage during API calls, which will be detailed in the following sections.

3. Rate Limiter

We use a Rate Limiter to prevent exceeding API usage limits, ensuring compliance and avoiding throttling or service disruptions. This approach maximizes the number of valid requests per minute.

3.1. How It Works

The RateLimiter class is designed to monitor and control both request and token usage. It initializes with specified limits for maximum requests (max_requests) and tokens (max_tokens). Before each API call, it checks the remaining quota. If limits are reached, the limiter pauses execution until quotas reset, preventing the application from sending too many requests simultaneously.

After each API call, the limiter updates its counters using response headers. This dynamic adjustment keeps the rate limiter in sync with real-time API limits, ensuring that subsequent calls respect the current usage status. The global set_rate_limiter function simplifies setup by initializing a shared instance that can be accessed throughout the application.
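Before the full listing, the core idea of a fixed time window can be sketched in isolation. This is a simplified illustration rather than the class used in this post: `FixedWindowQuota` is a hypothetical name, it tracks requests only, and the clock is injectable so the behaviour can be checked without sleeping.

```python
import time


class FixedWindowQuota:
    """Toy quota: allow `limit` calls per `window` seconds."""

    def __init__(self, limit, window=60.0, clock=time.monotonic):
        self.limit = limit
        self.window = window
        self.clock = clock
        self.remaining = limit
        self.reset_at = clock() + window

    def seconds_until_available(self):
        """Return 0.0 if a call can go out now, else the time left to wait."""
        now = self.clock()
        if now >= self.reset_at:
            # The window has elapsed: refill the quota and start a new window
            self.remaining = self.limit
            self.reset_at = now + self.window
        if self.remaining > 0:
            return 0.0
        return self.reset_at - now

    def consume(self):
        """Record that one call was sent."""
        self.remaining -= 1


quota = FixedWindowQuota(limit=2)
quota.consume()
quota.consume()
print(f"Quota used up, next slot in {quota.seconds_until_available():.0f}s")
```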

Here is the full code for the RateLimiter:

rate_limiter.py

import asyncio
import re
import time

from loguru import logger


# Define a regular expression to capture time components
REGEX_TIME = re.compile(r"(?P<value>\d+)(?P<unit>[smhms]+)")

RATE_LIMITER = None
MAX_SLEEP_TIME = 2 * 60  # 2 Minutes


def set_rate_limiter(max_requests: int = None, max_tokens: int = None):
    global RATE_LIMITER
    RATE_LIMITER = RateLimiter(max_requests, max_tokens)


class RateLimiter:
    def __init__(self, max_requests: int = None, max_tokens: int = None):
        self.max_requests = max_requests
        self.max_tokens = max_tokens

        self.remaining_requests = max_requests
        self.remaining_tokens = max_tokens

        # Assume reset in 60 seconds initially
        self.reset_time_requests = time.monotonic() + 60
        self.reset_time_tokens = time.monotonic() + 60

        logger.info(f"Setting {self}")

    def __repr__(self):
        max_reqs = self.max_requests
        max_tokens = self.max_tokens
        rem_reqs = self.remaining_requests
        rem_tokens = self.remaining_tokens
        reset_t_reqs = round(self.reset_time_requests)
        reset_t_tokens = round(self.reset_time_tokens)

        if max_tokens is None:
            return f"RateLimiter({max_reqs=}, {rem_reqs=}, {reset_t_reqs=} [no_tokens])"

        return (
            f"RateLimiter({max_reqs=}, {max_tokens=}, "
            f"{rem_reqs=}, {rem_tokens=}, {reset_t_reqs=}, {reset_t_tokens=})"
        )

    def _get_seconds_to_sleep(self):
        if self.max_tokens is None:
            # Only consider requests if tokens are not being used
            sleep_time = self.reset_time_requests - time.monotonic()
        else:
            sleep_time = min(
                self.reset_time_requests - time.monotonic(),
                self.reset_time_tokens - time.monotonic(),
            )

        # Do not wait really high times. Better to try anyway
        sleep_time = min(sleep_time, MAX_SLEEP_TIME)

        # Ensure sleep_time is at least 1 second
        return max(sleep_time, 1)

    async def wait_for_availability(self, required_tokens=None):
        if self.max_tokens is None or required_tokens is None:
            # Ignore tokens if there is no token limit or no estimate was given
            required_tokens = 0

        while self.remaining_requests <= 0 or (
            self.max_tokens is not None and self.remaining_tokens < required_tokens
        ):
            self.update_limits()

            seconds_to_sleep = self._get_seconds_to_sleep()
            logger.debug(f"Sleeping {seconds_to_sleep=}")
            await asyncio.sleep(seconds_to_sleep)

        self.remaining_requests -= 1
        if self.max_tokens is not None:
            self.remaining_tokens -= required_tokens

    def update_limits(self):
        current_time = time.monotonic()

        # If we've passed the reset time, reset the limits
        if current_time >= self.reset_time_requests:
            self.remaining_requests = self.max_requests
            self.reset_time_requests = current_time + 60  # Reset the time window

        if self.max_tokens is not None and current_time >= self.reset_time_tokens:
            self.remaining_tokens = self.max_tokens
            self.reset_time_tokens = current_time + 60  # Reset the time window

        logger.debug(f"Updating limits:\n{self}")

    def _parse_reset_time(self, reset_time_str):
        """Convert a reset string like '1s', '6m0s', '60ms' or '1.5s' into seconds."""

        # Local pattern instead of REGEX_TIME so that decimal values are parsed too.
        # Assumption: the API may send fractional durations such as '1.5s'.
        pattern = r"(?P<value>\d+(?:\.\d+)?)(?P<unit>ms|h|m|s)"

        total_seconds = 0
        for match in re.finditer(pattern, reset_time_str):
            value = float(match.group("value"))
            unit = match.group("unit")

            if unit == "s":
                total_seconds += value
            elif unit == "m":
                total_seconds += value * 60
            elif unit == "h":
                total_seconds += value * 3600
            elif unit == "ms":
                total_seconds += value / 1000.0

        # Default to 60 seconds if the format is unexpected
        return total_seconds if total_seconds > 0 else 60

    def update_from_headers(self, headers):
        """Update the rate limits based on headers from the API response."""

        self.remaining_requests = int(
            headers.get("x-ratelimit-remaining-requests", self.remaining_requests)
        )

        if self.max_tokens is not None:
            self.remaining_tokens = int(
                headers.get("x-ratelimit-remaining-tokens", self.remaining_tokens)
            )

            reset_tokens_seconds = self._parse_reset_time(
                headers.get("x-ratelimit-reset-tokens", "60s")
            )
            self.reset_time_tokens = time.monotonic() + reset_tokens_seconds

        reset_requests_seconds = self._parse_reset_time(
            headers.get("x-ratelimit-reset-requests", "60s")
        )

        self.reset_time_requests = time.monotonic() + reset_requests_seconds
        logger.debug(f"Updating limits from headers:\n{self}")

There is some overlap between the update_limits and update_from_headers methods. While update_limits resets limits based on a fixed time window, update_from_headers allows dynamic adjustments based on real-time data from API responses. This overlap is intentional, as using update_from_headers ensures more accurate limit management and optimizes the number of valid requests per minute. By relying on actual API feedback, we can better handle fluctuating limits and reduce the risk of hitting rate caps unexpectedly.

4. Using the Rate Limiter with Async calls

As seen in Process Calls with Open AI, we were using the OpenAI Python package. This simplifies the process but prevents us from getting relevant data from the headers (as explained in Rate limits in headers | Open AI). That data is really useful for knowing:

  • The remaining quota
  • The time until the quota is reset

Since this gives us more accurate information, which in turn lets us increase the number of requests we can make, we will call the API with manual HTTP requests.
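As a rough sketch of what can be read from those headers, the snippet below turns a headers mapping into a small snapshot object. The header names are the ones used by the RateLimiter in this post. `QuotaSnapshot`, `parse_duration` and `snapshot_from_headers` are hypothetical names, and the sample values are made up for illustration.

```python
import re
from dataclasses import dataclass

# Matches durations such as "1s", "6m0s", "60ms" or "1.5s"
_DURATION = re.compile(r"(\d+(?:\.\d+)?)(ms|h|m|s)")
_UNIT_SECONDS = {"h": 3600.0, "m": 60.0, "s": 1.0, "ms": 0.001}


def parse_duration(text):
    """Convert a duration string into seconds."""
    return sum(float(v) * _UNIT_SECONDS[u] for v, u in _DURATION.findall(text))


@dataclass
class QuotaSnapshot:
    remaining_requests: int
    remaining_tokens: int
    reset_requests_s: float
    reset_tokens_s: float


def snapshot_from_headers(headers):
    """Build a snapshot of the current quota from the response headers."""
    return QuotaSnapshot(
        remaining_requests=int(headers["x-ratelimit-remaining-requests"]),
        remaining_tokens=int(headers["x-ratelimit-remaining-tokens"]),
        reset_requests_s=parse_duration(headers["x-ratelimit-reset-requests"]),
        reset_tokens_s=parse_duration(headers["x-ratelimit-reset-tokens"]),
    )


# Made-up header values, for illustration only
sample_headers = {
    "x-ratelimit-remaining-requests": "59",
    "x-ratelimit-remaining-tokens": "149984",
    "x-ratelimit-reset-requests": "1s",
    "x-ratelimit-reset-tokens": "6m0s",
}
print(snapshot_from_headers(sample_headers))
```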

4.1. Calling the API

We will use aiohttp for making API requests. The general approach is similar to what is described in Best strategy on managing concurrent calls ? (Python/Asyncio).

The general idea is to have a function that waits for sufficient quota (managed by the RateLimiter) before calling the API. If an HTTP 429 error (rate limit exceeded) occurs, the function will recursively call itself to retry the request. After each API call, it updates the RateLimiter with quota information from the response headers.

A semaphore is also used to limit the number of concurrent calls, ensuring that we don’t overwhelm the system. More details on this will be covered later.
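The retry logic can also be written as a loop instead of recursion. The sketch below is one possible variant, not the code used in this post: it caps the number of attempts, which is an added assumption, and it uses a hypothetical `RateLimitError` and a fake request in place of a real HTTP 429 response.

```python
import asyncio


class RateLimitError(Exception):
    """Stand-in for an HTTP 429 response (hypothetical)."""


async def call_with_retries(request, max_attempts=5, base_delay=0.01):
    """Call `request` until it succeeds or the attempts run out."""
    for attempt in range(1, max_attempts + 1):
        try:
            return await request()
        except RateLimitError:
            if attempt == max_attempts:
                raise  # Give up and let the caller handle the error
            # Wait a little longer after each failed attempt
            await asyncio.sleep(base_delay * attempt)


async def demo():
    calls = {"n": 0}

    async def flaky_request():
        # Fails twice with a rate limit error, then succeeds
        calls["n"] += 1
        if calls["n"] < 3:
            raise RateLimitError("429")
        return "ok"

    result = await call_with_retries(flaky_request)
    return result, calls["n"]


result, attempts = asyncio.run(demo())
print(f"{result=} after {attempts} attempts")
```

A bounded loop avoids retrying forever when the quota never recovers, at the cost of surfacing the error to the caller after the last attempt.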

In the next section, you will see the full code for making calls to both Chat and Transcription endpoints.

4.2. Calling the Transcription endpoint

For transcriptions, we use whisper-1, which is limited by requests (RPM) but does not have TPM (Tokens Per Minute) limits. This makes the code simpler compared to the one used for Chat.

Here is the code for it:

transcription.py

import io
import time
from datetime import datetime
from typing import Optional

import aiohttp
from mutagen.mp3 import HeaderNotFoundError
from mutagen.mp3 import MP3
from pydantic import BaseModel
from pydantic import Field
from loguru import logger

import rate_limiter as rt
import utils as u

MIN_SECONDS = 5
DEFAULT_MODEL = "whisper-1"
ENDPOINT_TRANSCRIPTIONS = "https://api.openai.com/v1/audio/transcriptions"

MAX_REQUESTS_PER_MIN = 100

async def _call_openai_transcription(data):
    await rt.RATE_LIMITER.wait_for_availability()
    async with u.RATE_LIMITER_SEMAPHORE:  # Ensure no more than N tasks run concurrently
        async with aiohttp.ClientSession() as session:
            async with session.post(
                ENDPOINT_TRANSCRIPTIONS, headers=u.HEADERS, data=data
            ) as res:
                # Update the rate limiter with the response headers
                # We want this whether the request succeeded or failed
                rt.RATE_LIMITER.update_from_headers(res.headers)

                try:
                    res.raise_for_status()

                except aiohttp.ClientResponseError as e:
                    if e.status == 429:
                        logger.warning("Rate limit exceeded, retrying")
                        # The recursive call waits for availability before retrying.
                        # Caveat: a multipart FormData and its file object are consumed
                        # when sent, so a retry may need the form to be rebuilt.
                        return await _call_openai_transcription(data)
                    raise

                result = await res.json()
                return result["text"]

class TranscriptionOutput(BaseModel):
    """
    Export relevant metadata
    """

    filename: str
    duration_seconds: float
    transcription: Optional[str]
    transcription_time: Optional[float] = None
    # default_factory calls datetime.now() per instance, not once at import time
    transcribed_at: datetime = Field(default_factory=datetime.now)
    gpt_model: str


async def call_openai_transcription(filename, temperature=0.2, gpt_model=DEFAULT_MODEL):
    logger.info(f"Processing {filename=}")

    t0 = time.monotonic()
    # `s3.read_file` and `get_duration` are project helpers not shown in this post
    file = io.BytesIO(s3.read_file(filename))
    duration_seconds = get_duration(file)

    transcription = None
    trans_time = None

    if duration_seconds > MIN_SECONDS:
        data = aiohttp.FormData()
        data.add_field(
            "file",
            file,
            filename=filename.split("/")[-1],
            content_type="audio/mpeg",
        )
        data.add_field("model", gpt_model)
        data.add_field("response_format", "json")
        data.add_field("temperature", str(temperature))

        transcription = await _call_openai_transcription(data)
        trans_time = time.monotonic() - t0
        logger.info(f"Successfully transcribed {filename=} in {trans_time:.2f} seconds")
    else:
        logger.info(
            f"Skipping transcription for {filename=}, "
            f"{duration_seconds=:.2f} is less than {MIN_SECONDS=}"
        )

    return TranscriptionOutput(
        filename=filename,
        duration_seconds=duration_seconds,
        transcription=transcription,
        transcription_time=trans_time,
        gpt_model=gpt_model,
    )

The _call_openai_transcription function handles the rate limits, ensuring compliance by checking availability before making API calls. If a rate limit is exceeded, indicated by an HTTP 429 error, it retries the request after waiting.

The call_openai_transcription function manages the preparation of input data for the API call and formats the output as a Pydantic object for structured response handling (more details on this can be found in Process Calls with Open AI).

4.3. Calling the Chat Endpoint

Calls to the Chat endpoint are quite similar to those for Transcriptions. The key differences are that we need to manage the number of tokens used in each request and ensure that the output is structured according to a Pydantic class. More details on this can be found in Process Calls with Open AI.

chat.py

import json
from datetime import datetime
from typing import Optional

import aiohttp
from pydantic import BaseModel
from pydantic import Field
from loguru import logger

import rate_limiter as rt
import utils as u
from tokens import count_tokens
from tokens import GPTTokens

DEFAULT_MODEL = "gpt-4o-mini"
ENDPOINT_CHAT = "https://api.openai.com/v1/chat/completions"

MAX_REQUESTS_PER_MIN = 10_000
MAX_TOKENS_PER_MIN = 10_000_000


class BaseChatResponse(BaseModel):
    gpt_tokens_used: GPTTokens = Field(
        default_factory=GPTTokens,
        description="Information about the tokens used by the ChatGPT API, "
        "including prompt, completion, and total tokens.",
    )


class BadResponseException(Exception):
    def __init__(self, message="OpenAI API response is malformed or incomplete"):
        super().__init__(message)


def _construct_messages(prompt, text=None):
    messages = []
    if text is not None:
        messages += [{"role": "system", "content": prompt}]

    messages += [{"role": "user", "content": text or prompt}]
    return messages


def get_json_schema_from_pydantic(pydantic_model):
    """
    To force ChatGPT to output json data.
    More info: https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
    """
    schema = pydantic_model.model_json_schema()

    # Manually add "additionalProperties": false to the schema
    schema["additionalProperties"] = False

    json_schema = {
        "name": pydantic_model.__name__,
        "description": pydantic_model.Config.description,
        "schema": schema,
        "strict": True,
    }

    return {"type": "json_schema", "json_schema": json_schema}


async def _call_openai_chat(data, required_tokens):
    await rt.RATE_LIMITER.wait_for_availability(required_tokens)
    async with u.RATE_LIMITER_SEMAPHORE:  # Ensure no more than N tasks run concurrently
        async with aiohttp.ClientSession() as session:
            async with session.post(ENDPOINT_CHAT, headers=u.HEADERS, json=data) as res:
                # Update the rate limiter with the response headers
                # We do it before raising so that 429 responses also refresh the quotas
                rt.RATE_LIMITER.update_from_headers(res.headers)

                try:
                    res.raise_for_status()

                except aiohttp.ClientResponseError as e:
                    if e.status == 429:
                        logger.warning("Rate limit exceeded, retrying")
                        # The recursive call waits for availability
                        # (using the token estimate) before retrying
                        return await _call_openai_chat(data, required_tokens)
                    raise

                # Parse the response
                return await res.json()


async def call_openai_chat(
    prompt, text=None, pydantic_model=None, gpt_model=DEFAULT_MODEL, id=None
):
    messages = _construct_messages(prompt, text)
    required_tokens = count_tokens(messages, model=gpt_model)

    data = {"model": gpt_model, "messages": messages}

    if pydantic_model is not None:
        data["response_format"] = get_json_schema_from_pydantic(pydantic_model)

    response = await _call_openai_chat(data, required_tokens)

    if "usage" not in response or "choices" not in response or not response["choices"]:
        raise BadResponseException(f"Missing expected fields in {response=}")

    usage = response["usage"]
    result = response["choices"][0]["message"]["content"]

    if pydantic_model is None:
        return result, usage

    result_json = json.loads(result)

    class ChatOutput(pydantic_model):
        """
        This class is an extension of the input `pydantic_model`
        It adds multiple metadata meant for debugging and issues resolution
        """

        gpt_called_at: datetime = datetime.now()
        gpt_tokens_used: GPTTokens
        gpt_raw_input: str = text or prompt
        gpt_model: str = data["model"]
        id: Optional[str]

    return ChatOutput(id=id, gpt_tokens_used=usage, **result_json)

4.4. Utils

I use a utils file to set the headers needed for the API calls and to control the number of concurrent jobs with a Semaphore.

There is also a function that validates the output by checking that each record is an instance of the Pydantic class we expect.

utils.py

import asyncio

from loguru import logger

HEADERS = None
RATE_LIMITER_SEMAPHORE = asyncio.Semaphore(25)


def init_openai(secret, n_jobs=None, json=True):
    global HEADERS
    HEADERS = {"Authorization": f"Bearer {secret['api_key']}"}
    if json:
        HEADERS["Content-Type"] = "application/json"

    if n_jobs:
        global RATE_LIMITER_SEMAPHORE
        RATE_LIMITER_SEMAPHORE = asyncio.Semaphore(n_jobs)


def split_valid_and_invalid_records(records, pydantic_model):
    if invalid_results := [x for x in records if not isinstance(x, pydantic_model)]:
        logger.error(f"There are {len(invalid_results)} failed OpenAI calls")
    else:
        logger.info("All calls were successful")

    valid_results = [x for x in records if isinstance(x, pydantic_model)]
    return valid_results, invalid_results

5. Doing batches

The idea is to initialize the RateLimiter with the limits for each endpoint and then run all the calls concurrently with asyncio.gather. Finally, since the outputs are always a Pydantic class, we can easily transform the results into a pandas DataFrame and export it as a table.
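The pattern of gathering results and separating failures can be sketched without any API call. In this illustration a dataclass stands in for the Pydantic output class and the pandas step is left out. `Output`, `process` and `run_batch` are hypothetical names.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Output:
    name: str
    length: int


async def process(name):
    """Stand-in for one API call: fails on empty names."""
    if not name:
        raise ValueError("empty name")
    await asyncio.sleep(0)  # Yield control, as a real request would
    return Output(name=name, length=len(name))


async def run_batch(names):
    # return_exceptions=True keeps one failure from cancelling the whole batch
    results = await asyncio.gather(
        *(process(n) for n in names), return_exceptions=True
    )
    valid = [r for r in results if isinstance(r, Output)]
    errors = [r for r in results if not isinstance(r, Output)]
    return valid, errors


valid, errors = asyncio.run(run_batch(["a.mp3", "", "b.mp3"]))
print(f"{len(valid)} valid results, {len(errors)} errors")
```

Because exceptions are returned in place of results, they can be logged or retried later without losing the calls that did succeed.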

5.1. Transcriptions in batches

Here is the code for the transcriptions:

import asyncio
import time

import pandas as pd

async def async_call_open_ai_transcription(files, gpt_model=DEFAULT_MODEL):
    t0 = time.monotonic()
    rt.set_rate_limiter(MAX_REQUESTS_PER_MIN)

    msg_jobs = f"(n_jobs={u.RATE_LIMITER_SEMAPHORE._value})"
    logger.info(f"Processing {len(files)} calls to OpenAI asynchronously {msg_jobs}")
    tasks = [call_openai_transcription(x, gpt_model=gpt_model) for x in files]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    logger.info(
        f"All calls done in {(time.monotonic() - t0)/ 60:.2f} mins {msg_jobs}"
    )

    output, errors = u.split_valid_and_invalid_records(results, TranscriptionOutput)
    # model_dump() is the Pydantic v2 replacement for the deprecated dict()
    return pd.DataFrame([x.model_dump() for x in output]), errors

5.2. Chats in batches

The code for chats is very similar. The only difference is that it can either return the raw text as is or structured data as a Pydantic class.

import asyncio
import time

import pandas as pd

async def async_call_open_ai_chat(
    prompt, df, pydantic_model=None, gpt_model=DEFAULT_MODEL
):
    t0 = time.monotonic()
    rt.set_rate_limiter(MAX_REQUESTS_PER_MIN, MAX_TOKENS_PER_MIN)

    msg_jobs = f"(n_jobs={u.RATE_LIMITER_SEMAPHORE._value})"
    logger.info(f"Processing {df.shape[0]} calls to OpenAI asynchronously {msg_jobs}")
    tasks = [
        call_openai_chat(
            prompt, pydantic_model=pydantic_model, gpt_model=gpt_model, **row
        )
        for _, row in df.iterrows()
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    logger.info(
        f"All calls done in {(time.monotonic() - t0)/ 60:.2f} mins {msg_jobs}"
    )
    if pydantic_model is None:
        return results

    output, errors = u.split_valid_and_invalid_records(results, pydantic_model)

    # If we have a pydantic_model, we can extract tabulated data
    # model_dump() is the Pydantic v2 replacement for the deprecated dict()
    return pd.DataFrame([x.model_dump() for x in output]), errors

6. Further Reading and Resources