Source code for oumi.datasets.sft.chatqa
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, cast
import datasets
import pandas as pd
from typing_extensions import override
from oumi.core.datasets import BaseSftDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import Conversation, Message, Role
[docs]
@register_dataset("nvidia/ChatQA-Training-Data")
class ChatqaDataset(BaseSftDataset):
default_dataset = "nvidia/ChatQA-Training-Data"
default_subset = "sft"
def _get_system_message(self) -> Optional[str]:
if self.dataset_subset == "sft":
return None
if self.dataset_subset == "synthetic_convqa":
return "Please give a full and complete answer for the question."
if self.dataset_subset in ("tatqa-arithmetic", "tatqa"):
# Note: `tatqa` is a combination of `tatqa-arithmetic` and `tatqa-others``
# Here we use the same prompt as tqta-arithmetic for both as it's larger
# (8k vs 3k) Samples.
# Preferably, each dataset should be loaded separately instead of loading
# the combined `tatqa` subset.
return (
"Answer the following question with a number "
"from context or the math arithmetic"
)
if self.dataset_subset == "tatqa-others":
return (
"Answer the following question with a short span, "
"or a full and complete answer"
)
if self.dataset_subset in (
"drop",
"narrativeqa",
"quoref",
"ropes",
"squad1.1",
"squad2.0",
"newsqa",
):
return "Answer the following question with a short span."
raise ValueError(f"Unknown dataset subset: {self.dataset_subset}")
[docs]
@override
def transform_conversation(
self, raw_conversation: Union[dict, pd.Series]
) -> Conversation:
"""Preprocesses the inputs of the example and returns a dictionary.
ChatQA is a conversational question answering dataset.
It contains 10 subsets. Some subsets contain grounding documents.
See the dataset page for more information:
https://huggingface.co/datasets/nvidia/ChatQA-Training-Data
Args:
raw_conversation: The raw conversation example.
Returns:
dict: The preprocessed inputs as an Oumi conversation.
"""
messages = []
# Step 1. Add system message. Most subsets contain one.
system_message = self._get_system_message()
if system_message:
messages.append(Message(role=Role.SYSTEM, content=system_message))
# Step 2. Add grounding context and system instruction
has_context = raw_conversation.get("document") is not None
if has_context:
# Step 2.1. If the sample has a context, we add a system prompt
# to only use information from the context to answer the question
context_message = (
"Only use the information from the user "
"provided context to answer the question."
)
messages.append(Message(role=Role.SYSTEM, content=context_message))
# Step 2.2. Add context document, wrapped in <context> tags
# Note: This is not part of the original dataset
# but is added to make the context more explicit.
document = f"<context>{raw_conversation['document']}</document>"
messages.append(Message(role=Role.USER, content=document))
# Step 3. Add conversation history
# Can contain one or multiple user/assistant turns.
for message in raw_conversation["messages"]:
messages.append(Message(role=message["role"], content=message["content"]))
# Step 4. Add final assistant response, which is encoded differently
# depending on the subset.
if self.dataset_subset == "narrativeqa":
# `narrativeqa` contains an array of arrays of strings
# Note: All rows contain two answers.
# We arbitrarily use the first one answer.
response = raw_conversation["answers"][0][0]
elif self.dataset_subset in ("squad1.1", "squad2.0"):
# `squad1.1` and `squad2.0` contain a list of dicts
# All rows contain a single answer.
response = raw_conversation["answers"][0]["text"]
else:
# All other subsets contain a list of strings
# All rows contain a single answer.
response = raw_conversation["answers"][0]
messages.append({"role": Role.ASSISTANT, "content": response})
return Conversation(messages=messages)
[docs]
@register_dataset("nvidia/ChatQA-Training-Data", subset="tatqa-arithmetic")
@register_dataset("nvidia/ChatQA-Training-Data", subset="tatqa-others")
class ChatqaTatqaDataset(ChatqaDataset):
"""ChatQA Subclass to handle tatqa subsets.
The tatqa subsets require loading a specific file from
the dataset repository, thus requiring us to override the
default loading behavior.
"""
default_subset = "tatqa-arithmetic"
@override
def _load_hf_hub_dataset(self) -> pd.DataFrame:
if self.dataset_subset == "tatqa-arithmetic":
filename = "tatqa/train_arithmetic.json"
else:
filename = "tatqa/train_others.json"
if self.split is not None and self.split != "train":
raise ValueError("Only the `train` split is supported for this dataset.")
dataset = datasets.load_dataset(
self.dataset_name, data_files={"train": filename}
)
dataset = cast(datasets.DatasetDict, dataset)
return cast(pd.DataFrame, dataset["train"].to_pandas())