openai

OpenAIAPI

Bases: AsyncAPI

Class for querying the OpenAI API asynchronously.

Parameters:

Name      Type      Description                                Default
settings  Settings  The settings for the pipeline/experiment   required
log_file  str       The path to the log file                   required
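
The sketch below shows how an OpenAIAPI instance might be constructed directly. The import paths and the Settings constructor argument are assumptions (in normal use the prompto pipeline builds the Settings object for you); check the Settings documentation for the exact parameters.

```python
# Minimal sketch: constructing the API wrapper by hand.
# NOTE: the import paths and Settings(data_folder=...) argument are
# assumptions, not guaranteed signatures; adjust to your prompto version.
from prompto.apis.openai import OpenAIAPI
from prompto.settings import Settings

settings = Settings(data_folder="data")  # hypothetical constructor argument
api = OpenAIAPI(settings=settings, log_file="data/logs/experiment.log")
```
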
Source code in src/prompto/apis/openai/openai.py
class OpenAIAPI(AsyncAPI):
    """
    Class for querying the OpenAI API asynchronously.

    Parameters
    ----------
    settings : Settings
        The settings for the pipeline/experiment
    log_file : str
        The path to the log file
    """

    def __init__(
        self,
        settings: Settings,
        log_file: str,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(settings=settings, log_file=log_file, *args, **kwargs)
        self.api_type = "openai"

    @staticmethod
    def check_environment_variables() -> list[Exception]:
        """
        For OpenAI, there are some optional environment variables:
        - OPENAI_API_KEY

        These are optional only if the model_name is passed
        in the prompt dictionary. If the model_name is not
        passed, then the default values are taken from these
        environment variables.

        These are checked in the check_prompt_dict method to ensure that
        the required environment variables are set.

        Returns
        -------
        list[Exception]
            A list of exceptions or warnings if the environment variables
            are not set
        """
        issues = []

        # check the optional environment variables are set and warn if not
        issues.extend(check_optional_env_variables_set([API_KEY_VAR_NAME]))

        return issues

    @staticmethod
    def check_prompt_dict(prompt_dict: dict) -> list[Exception]:
        """
        For OpenAI, we make the following model-specific checks:
        - "prompt" key must be of type str, list[str], or list[dict[str,str]]
        - model-specific environment variable (OPENAI_API_KEY_{identifier})
          (where identifier is the model name with invalid characters replaced
          by underscores obtained using get_model_name_identifier function)
          can be set or the default environment variable must be set
        - if "mode" is passed, it must be one of 'chat' or 'completion'

        Parameters
        ----------
        prompt_dict : dict
            The prompt dictionary to check

        Returns
        -------
        list[Exception]
            A list of exceptions or warnings if the prompt dictionary
            is not valid
        """
        issues = []

        # check prompt is of the right type
        if isinstance(prompt_dict["prompt"], str):
            pass
        elif isinstance(prompt_dict["prompt"], list):
            if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
                pass
            elif all(
                isinstance(message, dict) for message in prompt_dict["prompt"]
            ) and all(
                [
                    set(d.keys()) == {"role", "content"}
                    and d["role"] in openai_chat_roles
                    for d in prompt_dict["prompt"]
                ]
            ):
                pass
            else:
                issues.append(TYPE_ERROR)
        else:
            issues.append(TYPE_ERROR)

        # use the model specific environment variables if they exist
        model_name = prompt_dict["model_name"]
        # replace any invalid characters in the model name
        identifier = get_model_name_identifier(model_name)

        # check the required environment variables are set
        # must either have the model specific key or the default key set
        issues.extend(
            check_either_required_env_variables_set(
                [
                    [f"{API_KEY_VAR_NAME}_{identifier}", API_KEY_VAR_NAME],
                ]
            )
        )

        # if mode is passed, check it is a valid value
        if "mode" in prompt_dict and prompt_dict["mode"] not in ["chat", "completion"]:
            issues.append(
                ValueError(
                    f"Invalid mode value. Must be 'chat' or 'completion', not {prompt_dict['mode']}"
                )
            )

        # TODO: add checks for prompt_dict["parameters"] being
        # valid arguments for OpenAI API without hardcoding

        return issues

    async def _obtain_model_inputs(
        self, prompt_dict: dict
    ) -> tuple[str, str, AsyncOpenAI, dict, str]:
        """
        Async method to obtain the model inputs from the prompt dictionary.

        Parameters
        ----------
        prompt_dict : dict
            The prompt dictionary to use for querying the model

        Returns
        -------
        tuple[str, str, AsyncOpenAI, dict, str]
            A tuple containing the prompt, model name, AsyncOpenAI client object,
            the generation config, and mode to use for querying the model
        """
        # obtain the prompt from the prompt dictionary
        prompt = prompt_dict["prompt"]

        # obtain model name
        model_name = prompt_dict["model_name"]
        api_key = get_environment_variable(
            env_variable=API_KEY_VAR_NAME, model_name=model_name
        )

        openai.api_key = api_key
        openai.api_type = self.api_type
        client = AsyncOpenAI(api_key=api_key, max_retries=1)

        # get parameters dict (if any)
        generation_config = prompt_dict.get("parameters", None)
        if generation_config is None:
            generation_config = {}
        if type(generation_config) is not dict:
            raise TypeError(
                f"parameters must be a dictionary, not {type(generation_config)}"
            )

        # obtain mode (default is chat)
        mode = prompt_dict.get("mode", "chat")
        if mode not in ["chat", "completion"]:
            raise ValueError(f"mode must be one of 'chat' or 'completion', not {mode}")

        return prompt, model_name, client, generation_config, mode

    async def _query_string(self, prompt_dict: dict, index: int | str) -> dict:
        """
        Async method for querying the model with a string prompt
        (prompt_dict["prompt"] is a string),
        i.e. single-turn completion or chat.
        """
        prompt, model_name, client, generation_config, mode = (
            await self._obtain_model_inputs(prompt_dict)
        )

        try:
            if mode == "chat":
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                    **generation_config,
                )
            elif mode == "completion":
                response = await client.completions.create(
                    model=model_name,
                    prompt=prompt,
                    **generation_config,
                )

            response_text = process_response(response)

            log_success_response_query(
                index=index,
                model=f"OpenAI ({model_name})",
                prompt=prompt,
                response_text=response_text,
                id=prompt_dict.get("id", "NA"),
            )

            prompt_dict["response"] = response_text
            return prompt_dict
        except Exception as err:
            error_as_string = f"{type(err).__name__} - {err}"
            log_message = log_error_response_query(
                index=index,
                model=f"OpenAI ({model_name})",
                prompt=prompt,
                error_as_string=error_as_string,
                id=prompt_dict.get("id", "NA"),
            )
            async with FILE_WRITE_LOCK:
                write_log_message(
                    log_file=self.log_file,
                    log_message=log_message,
                    log=True,
                )
            raise err

    async def _query_chat(self, prompt_dict: dict, index: int | str) -> dict:
        """
        Async method for querying the model with a chat prompt
        (prompt_dict["prompt"] is a list of strings to sequentially send to the model),
        i.e. multi-turn chat with history.
        """
        prompt, model_name, client, generation_config, _ = (
            await self._obtain_model_inputs(prompt_dict)
        )

        messages = []
        response_list = []
        try:
            for message_index, message in enumerate(prompt):
                # add the user message to the list of messages
                messages.append({"role": "user", "content": message})
                # obtain the response from the model
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    **generation_config,
                )
                # parse the response to obtain the response text
                response_text = process_response(response)
                # add the response to the list of responses
                response_list.append(response_text)
                # add the response message to the list of messages
                messages.append({"role": "assistant", "content": response_text})

                log_success_response_chat(
                    index=index,
                    model=f"OpenAI ({model_name})",
                    message_index=message_index,
                    n_messages=len(prompt),
                    message=message,
                    response_text=response_text,
                    id=prompt_dict.get("id", "NA"),
                )

            logging.info(
                f"Chat completed (i={index}, id={prompt_dict.get('id', 'NA')})"
            )

            prompt_dict["response"] = response_list
            return prompt_dict
        except Exception as err:
            error_as_string = f"{type(err).__name__} - {err}"
            log_message = log_error_response_chat(
                index=index,
                model=f"OpenAI ({model_name})",
                message_index=message_index,
                n_messages=len(prompt),
                message=message,
                responses_so_far=response_list,
                error_as_string=error_as_string,
                id=prompt_dict.get("id", "NA"),
            )
            async with FILE_WRITE_LOCK:
                write_log_message(
                    log_file=self.log_file,
                    log_message=log_message,
                    log=True,
                )
            raise err

    async def _query_history(self, prompt_dict: dict, index: int | str) -> dict:
        """
        Async method for querying the model with a chat prompt with history
        (prompt_dict["prompt"] is a list of dictionaries with keys "role" and "content",
        where "role" is one of "user", "assistant", or "system" and "content" is the message),
        i.e. multi-turn chat with history.
        """
        prompt, model_name, client, generation_config, _ = (
            await self._obtain_model_inputs(prompt_dict)
        )

        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=[
                    convert_dict_to_input(
                        content_dict=x, media_folder=self.settings.media_folder
                    )
                    for x in prompt
                ],
                **generation_config,
            )

            response_text = process_response(response)

            log_success_response_query(
                index=index,
                model=f"OpenAI ({model_name})",
                prompt=prompt,
                response_text=response_text,
                id=prompt_dict.get("id", "NA"),
            )

            prompt_dict["response"] = response_text
            return prompt_dict
        except Exception as err:
            error_as_string = f"{type(err).__name__} - {err}"
            log_message = log_error_response_query(
                index=index,
                model=f"OpenAI ({model_name})",
                prompt=prompt,
                error_as_string=error_as_string,
                id=prompt_dict.get("id", "NA"),
            )
            async with FILE_WRITE_LOCK:
                write_log_message(
                    log_file=self.log_file,
                    log_message=log_message,
                    log=True,
                )
            raise err

    async def query(self, prompt_dict: dict, index: int | str = "NA") -> dict:
        """
        Async method for querying the API/model.

        Parameters
        ----------
        prompt_dict : dict
            The prompt dictionary to use for querying the model
        index : int | str
            The index of the prompt in the experiment

        Returns
        -------
        dict
            Completed prompt_dict with "response" key storing the response(s)
            from the LLM

        Raises
        ------
        Exception
            If an error occurs during the querying process
        """
        if isinstance(prompt_dict["prompt"], str):
            return await self._query_string(
                prompt_dict=prompt_dict,
                index=index,
            )
        elif isinstance(prompt_dict["prompt"], list):
            if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
                return await self._query_chat(
                    prompt_dict=prompt_dict,
                    index=index,
                )
            elif all(
                isinstance(message, dict) for message in prompt_dict["prompt"]
            ) and all(
                [
                    set(d.keys()) == {"role", "content"}
                    and d["role"] in openai_chat_roles
                    for d in prompt_dict["prompt"]
                ]
            ):
                return await self._query_history(
                    prompt_dict=prompt_dict,
                    index=index,
                )

        raise TYPE_ERROR
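
query dispatches on the type of prompt_dict["prompt"]: a plain string is sent as a single-turn query (_query_string), a list of strings is sent turn by turn with earlier responses fed back as history (_query_chat), and a list of role/content dictionaries is sent as a ready-made chat history (_query_history). The dictionaries below are illustrative only; the model name is a placeholder.

```python
# Illustrative prompt dictionaries, one per supported "prompt" format.
# "gpt-4o" is a placeholder model name, not a recommendation.
string_prompt = {  # single-turn chat or completion
    "prompt": "Summarise the plot of Hamlet in one sentence.",
    "model_name": "gpt-4o",
}

chat_prompt = {  # each string is sent in turn, responses appended as history
    "prompt": ["Hello!", "What did I just say to you?"],
    "model_name": "gpt-4o",
}

history_prompt = {  # a full chat history supplied up front
    "prompt": [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Name three prime numbers."},
    ],
    "model_name": "gpt-4o",
}
```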

check_environment_variables staticmethod

check_environment_variables() -> list[Exception]

For OpenAI, there are some optional environment variables:

- OPENAI_API_KEY

These are optional only if the model_name is passed in the prompt dictionary. If the model_name is not passed, then the default values are taken from these environment variables.

These are checked in the check_prompt_dict method to ensure that the required environment variables are set.

Returns:

Type             Description
list[Exception]  A list of exceptions or warnings if the environment variables are not set

Source code in src/prompto/apis/openai/openai.py
@staticmethod
def check_environment_variables() -> list[Exception]:
    """
    For OpenAI, there are some optional environment variables:
    - OPENAI_API_KEY

    These are optional only if the model_name is passed
    in the prompt dictionary. If the model_name is not
    passed, then the default values are taken from these
    environment variables.

    These are checked in the check_prompt_dict method to ensure that
    the required environment variables are set.

    Returns
    -------
    list[Exception]
        A list of exceptions or warnings if the environment variables
        are not set
    """
    issues = []

    # check the optional environment variables are set and warn if not
    issues.extend(check_optional_env_variables_set([API_KEY_VAR_NAME]))

    return issues
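
A short sketch of how this check might be used before running an experiment. It only inspects the environment; with OPENAI_API_KEY unset it returns a warning-style issue rather than raising, since the key can also be supplied per model via OPENAI_API_KEY_{identifier}.

```python
import os

from prompto.apis.openai import OpenAIAPI  # assumed import path

# Without the default key, expect a warning-style issue in the returned list.
os.environ.pop("OPENAI_API_KEY", None)
for issue in OpenAIAPI.check_environment_variables():
    print(type(issue).__name__, issue)

# With the default key set, no issues are expected.
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder value
print(OpenAIAPI.check_environment_variables())  # expected: []
```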

check_prompt_dict staticmethod

check_prompt_dict(prompt_dict: dict) -> list[Exception]

For OpenAI, we make the following model-specific checks:

- “prompt” key must be of type str, list[str], or list[dict[str,str]]
- model-specific environment variable (OPENAI_API_KEY_{identifier}) (where identifier is the model name with invalid characters replaced by underscores obtained using get_model_name_identifier function) can be set or the default environment variable must be set
- if “mode” is passed, it must be one of ‘chat’ or ‘completion’

Parameters:

Name         Type  Description                     Default
prompt_dict  dict  The prompt dictionary to check  required

Returns:

Type             Description
list[Exception]  A list of exceptions or warnings if the prompt dictionary is not valid

Source code in src/prompto/apis/openai/openai.py
@staticmethod
def check_prompt_dict(prompt_dict: dict) -> list[Exception]:
    """
    For OpenAI, we make the following model-specific checks:
    - "prompt" key must be of type str, list[str], or list[dict[str,str]]
    - model-specific environment variable (OPENAI_API_KEY_{identifier})
      (where identifier is the model name with invalid characters replaced
      by underscores obtained using get_model_name_identifier function)
      can be set or the default environment variable must be set
    - if "mode" is passed, it must be one of 'chat' or 'completion'

    Parameters
    ----------
    prompt_dict : dict
        The prompt dictionary to check

    Returns
    -------
    list[Exception]
        A list of exceptions or warnings if the prompt dictionary
        is not valid
    """
    issues = []

    # check prompt is of the right type
    if isinstance(prompt_dict["prompt"], str):
        pass
    elif isinstance(prompt_dict["prompt"], list):
        if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
            pass
        elif all(
            isinstance(message, dict) for message in prompt_dict["prompt"]
        ) and all(
            [
                set(d.keys()) == {"role", "content"}
                and d["role"] in openai_chat_roles
                for d in prompt_dict["prompt"]
            ]
        ):
            pass
        else:
            issues.append(TYPE_ERROR)
    else:
        issues.append(TYPE_ERROR)

    # use the model specific environment variables if they exist
    model_name = prompt_dict["model_name"]
    # replace any invalid characters in the model name
    identifier = get_model_name_identifier(model_name)

    # check the required environment variables are set
    # must either have the model specific key or the default key set
    issues.extend(
        check_either_required_env_variables_set(
            [
                [f"{API_KEY_VAR_NAME}_{identifier}", API_KEY_VAR_NAME],
            ]
        )
    )

    # if mode is passed, check it is a valid value
    if "mode" in prompt_dict and prompt_dict["mode"] not in ["chat", "completion"]:
        issues.append(
            ValueError(
                f"Invalid mode value. Must be 'chat' or 'completion', not {prompt_dict['mode']}"
            )
        )

    # TODO: add checks for prompt_dict["parameters"] being
    # valid arguments for OpenAI API without hardcoding

    return issues
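
To illustrate, the sketch below runs the check on one valid and one invalid prompt dictionary. The model name and API key are placeholders, and the exact warning/error objects returned depend on the helper functions used internally (check_either_required_env_variables_set may still warn that the model-specific key is unset).

```python
import os

from prompto.apis.openai import OpenAIAPI  # assumed import path

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder so the key check can pass

valid = {
    "prompt": "Write a haiku about type checking.",
    "model_name": "gpt-4o",  # placeholder model name
    "mode": "chat",
}
print(OpenAIAPI.check_prompt_dict(valid))  # expected: no errors

invalid = {
    "prompt": {"not": "an accepted type"},  # dict is not str/list[str]/list[dict]
    "model_name": "gpt-4o",
    "mode": "stream",  # not 'chat' or 'completion'
}
for issue in OpenAIAPI.check_prompt_dict(invalid):
    print(type(issue).__name__, issue)
```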

query async

query(prompt_dict: dict, index: int | str = 'NA') -> dict

Async method for querying the API/model.

Parameters:

Name         Type       Description                                           Default
prompt_dict  dict       The prompt dictionary to use for querying the model   required
index        int | str  The index of the prompt in the experiment             'NA'

Returns:

Type  Description
dict  Completed prompt_dict with “response” key storing the response(s) from the LLM

Raises:

Type       Description
Exception  If an error occurs during the querying process

Source code in src/prompto/apis/openai/openai.py
async def query(self, prompt_dict: dict, index: int | str = "NA") -> dict:
    """
    Async method for querying the API/model.

    Parameters
    ----------
    prompt_dict : dict
        The prompt dictionary to use for querying the model
    index : int | str
        The index of the prompt in the experiment

    Returns
    -------
    dict
        Completed prompt_dict with "response" key storing the response(s)
        from the LLM

    Raises
    ------
    Exception
        If an error occurs during the querying process
    """
    if isinstance(prompt_dict["prompt"], str):
        return await self._query_string(
            prompt_dict=prompt_dict,
            index=index,
        )
    elif isinstance(prompt_dict["prompt"], list):
        if all([isinstance(message, str) for message in prompt_dict["prompt"]]):
            return await self._query_chat(
                prompt_dict=prompt_dict,
                index=index,
            )
        elif all(
            isinstance(message, dict) for message in prompt_dict["prompt"]
        ) and all(
            [
                set(d.keys()) == {"role", "content"}
                and d["role"] in openai_chat_roles
                for d in prompt_dict["prompt"]
            ]
        ):
            return await self._query_history(
                prompt_dict=prompt_dict,
                index=index,
            )

    raise TYPE_ERROR
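
Finally, a minimal end-to-end sketch of calling query directly. In normal use the class is driven by the prompto pipeline; the import paths, the Settings(data_folder=...) argument, and the model name here are assumptions or placeholders, and a real OPENAI_API_KEY must be set for the request to succeed.

```python
import asyncio
import os

from prompto.apis.openai import OpenAIAPI  # assumed import path
from prompto.settings import Settings      # assumed import path


async def main() -> None:
    # Settings(data_folder=...) is an assumed constructor argument; check the
    # Settings documentation for the exact signature.
    settings = Settings(data_folder="data")
    api = OpenAIAPI(settings=settings, log_file="data/logs/experiment.log")

    prompt_dict = {
        "id": "example-0",
        "prompt": "Give one sentence on why async I/O helps here.",
        "model_name": "gpt-4o",              # placeholder model name
        "parameters": {"temperature": 0.0},  # forwarded to the OpenAI client
    }
    completed = await api.query(prompt_dict, index=0)
    print(completed["response"])


if __name__ == "__main__":
    os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # placeholder value
    asyncio.run(main())
```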