Source code for oumi.core.configs.params.synthesis_params
# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importrefromdataclassesimportdataclass,fieldfromenumimportEnumfrompathlibimportPathfromtypingimportAny,Optionalfromoumi.core.configs.params.base_paramsimportBaseParamsfromoumi.core.types.conversationimportConversation,Message,Role_SUPPORTED_DATASET_FILE_TYPES={".jsonl",".json",".csv",".parquet",".tsv"}
@dataclass
class TextMessage:
    """Chat message restricted to plain-text content.

    Keeping ``content`` a plain ``str`` is what lets this type be used in
    omegaconf configs, unlike the richer :class:`Message`.
    """

    role: Role
    content: str

    def to_message(self) -> Message:
        """Build the equivalent :class:`Message` with the same role and content."""
        role, content = self.role, self.content
        return Message(role=role, content=content)
@dataclass
class TextConversation:
    """Conversation made only of :class:`TextMessage` items (omegaconf-friendly)."""

    messages: list[TextMessage]
    conversation_id: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_conversation(self) -> Conversation:
        """Build the equivalent :class:`Conversation`.

        Every message is converted via ``TextMessage.to_message``. The
        conversation ID and the metadata dict itself (the same object, not a
        copy) are passed through unchanged.
        """
        converted_messages = [msg.to_message() for msg in self.messages]
        return Conversation(
            messages=converted_messages,
            conversation_id=self.conversation_id,
            metadata=self.metadata,
        )
@dataclass
class DatasetSource:
    """A dataset whose rows and columns feed into synthesis."""

    path: str
    """Location of the dataset."""

    hf_split: Optional[str] = None
    """HuggingFace dataset split to read during synthesis, if any."""

    hf_revision: Optional[str] = None
    """HuggingFace dataset revision to read during synthesis, if any."""

    attribute_map: Optional[dict[str, str]] = None
    """Optional mapping applied to the dataset's attributes for synthesis; its
    values are the attribute IDs exposed to synthesis. When omitted, the keys
    already present in the dataset are used."""
class SegmentationStrategy(str, Enum):
    """Ways a document can be split into segments."""

    TOKENS = "tokens"
    """Split the document by token count."""
@dataclass
class DocumentSegmentationParams:
    """Settings that control how a document is split into segments."""

    id: str
    """ID under which each document segment is referenced during synthesis."""

    segmentation_strategy: SegmentationStrategy = SegmentationStrategy.TOKENS
    """Strategy used to split the document."""

    tokenizer: str = "openai-community/gpt2"
    """Tokenizer used for segmentation, given as a HuggingFace Hub ID or a direct
    file path. Defaults to the GPT-2 tokenizer from the HuggingFace Hub and must
    be non-empty when the strategy is TOKENS."""

    segment_length: int = 2048
    """Length of each segment, in units set by the segmentation strategy. Must be
    positive."""

    segment_overlap: int = 0
    """Overlap between segments. Must be non-negative and less than
    segment_length."""

    keep_original_text: bool = False
    """Whether to keep the original text of the document."""

    def __post_init__(self):
        """Validate lengths and tokenizer; raise ``ValueError`` on bad settings."""
        length = self.segment_length
        overlap = self.segment_overlap
        if length <= 0:
            raise ValueError("Segment length must be positive.")
        if overlap < 0:
            raise ValueError("Segment overlap must be non-negative.")
        if overlap >= length:
            raise ValueError("Segment overlap must be less than segment length.")

        # Only token-based segmentation needs a tokenizer.
        uses_tokens = self.segmentation_strategy == SegmentationStrategy.TOKENS
        if uses_tokens and not self.tokenizer:
            raise ValueError(
                "DocumentSegmentationParams.tokenizer cannot be empty when "
                "segmentation_strategy is TOKENS."
            )
@dataclass
class DocumentSource:
    """A document whose text (and optionally segments) feeds into synthesis."""

    path: str
    """Where the document is located."""

    id: str
    """ID under which the document is referenced during synthesis."""

    segmentation_params: Optional[DocumentSegmentationParams] = None
    """Optional settings for splitting the document into segments."""

    def __post_init__(self):
        """Reject an empty ``path`` or ``id`` with ``ValueError``."""
        required = (
            (self.path, "DocumentSource.path cannot be empty."),
            (self.id, "DocumentSource.id cannot be empty."),
        )
        for field_value, error_message in required:
            if not field_value:
                raise ValueError(error_message)
@dataclass
class ExampleSource:
    """Examples written directly in the config for use in synthesis."""

    examples: list[dict[str, Any]]
    """The examples themselves; every example must share the same keys."""

    def __post_init__(self):
        """Reject an empty example list or examples whose key sets differ."""
        if not self.examples:
            raise ValueError("ExampleSource.examples cannot be empty.")

        # The first example defines the expected key set for all the others.
        reference_keys = self.examples[0].keys()
        if any(ex.keys() != reference_keys for ex in self.examples):
            raise ValueError("All examples must have the same keys.")
@dataclass
class PermutableAttributeValue:
    """One candidate value of a permutable attribute."""

    id: str
    """ID under which this value is referenced during synthesis."""

    value: str
    """The value itself. Referenced as {attribute_id.value}"""

    description: str
    """Description of the value. Referenced as {attribute_id.value.description}"""

    sample_rate: Optional[float] = None
    """Probability of sampling this value, between 0 and 1. When unset, uniform
    sampling among the possible values is assumed."""

    def __post_init__(self):
        """Reject empty text fields or a ``sample_rate`` outside [0, 1]."""
        required = (
            (self.id, "PermutableAttributeValue.id cannot be empty."),
            (self.value, "PermutableAttributeValue.value cannot be empty."),
            (
                self.description,
                "PermutableAttributeValue.description cannot be empty.",
            ),
        )
        for field_value, error_message in required:
            if not field_value:
                raise ValueError(error_message)

        rate = self.sample_rate
        if rate is None:
            return
        # Written as two comparisons (not a chained range test) so NaN passes,
        # exactly as the original check allowed.
        if rate > 1 or rate < 0:
            raise ValueError(
                "PermutableAttributeValue.sample_rate must be between 0 and 1."
            )
@dataclass
class PermutableAttribute:
    """Attributes to be varied across the dataset."""

    id: str
    """ID to be used when referencing the attribute during synthesis."""

    attribute: str
    """Plaintext name of the attribute. Referenced as {attribute_id}"""

    description: str
    """Description of the attribute. Referenced as {attribute_id.description}"""

    possible_values: list[PermutableAttributeValue]
    """Possible values of the attribute, sampled according to their sample
    rates."""

    def get_value_distribution(self) -> dict[str, float]:
        """Get the distribution of attribute values.

        Returns:
            A mapping from each possible value's ``id`` to its ``sample_rate``.
            Once ``__post_init__`` has run, every value has a rate assigned.
        """
        return {value.id: value.sample_rate for value in self.possible_values}

    def __post_init__(self):
        """Verifies params and fills in unspecified sample rates.

        Values without a ``sample_rate`` get an equal share of the probability
        left after the explicitly specified rates. Explicit rates are kept as
        given; they are not rescaled.

        Raises:
            ValueError: If ``id``, ``attribute``, ``description`` or
                ``possible_values`` is empty, if value IDs are not unique, or if
                the explicitly specified sample rates sum to more than 1.0.
        """
        if not self.id:
            raise ValueError("PermutableAttribute.id cannot be empty.")
        if not self.attribute:
            raise ValueError("PermutableAttribute.attribute cannot be empty.")
        if not self.description:
            raise ValueError("PermutableAttribute.description cannot be empty.")
        if not self.possible_values:
            raise ValueError("PermutableAttribute.possible_values cannot be empty.")

        value_ids = [value.id for value in self.possible_values]
        if len(value_ids) != len(set(value_ids)):
            raise ValueError(
                "PermutableAttribute.possible_values must have unique IDs."
            )

        # Summing decimal rates from a config can exceed 1.0 by a hair purely
        # through floating-point rounding; allow that slack instead of rejecting
        # a configuration that is valid on paper.
        tolerance = 1e-9
        defined_total = 0.0
        undefined_count = 0
        for value in self.possible_values:
            if value.sample_rate is None:
                undefined_count += 1
            else:
                defined_total += value.sample_rate

        if defined_total > 1.0 + tolerance:
            # Message kept unchanged for compatibility; the actual requirement
            # is that explicit rates sum to at most 1.0.
            raise ValueError("PermutableAttribute.possible_values must sum to 1.0.")

        if undefined_count == 0:
            return

        # Split the leftover probability evenly among values without a rate.
        # Clamp at 0.0 so a tolerated rounding overshoot cannot produce a tiny
        # negative rate.
        share = max(0.0, 1.0 - defined_total) / undefined_count
        for value in self.possible_values:
            if value.sample_rate is None:
                value.sample_rate = share
@dataclass
class AttributeCombination:
    """A specific combination of attribute values and how often to sample it."""

    combination: dict[str, str]
    """Mapping of attributes to the values that make up this combination."""

    sample_rate: float
    """Probability of sampling this combination, between 0 and 1."""

    def __post_init__(self):
        """Reject out-of-range rates and empty or single-entry combinations."""
        rate = self.sample_rate
        if rate < 0 or rate > 1:
            raise ValueError("AttributeCombination.sample_rate must be between 0 and 1.")

        if not self.combination:
            raise ValueError("AttributeCombination.combination cannot be empty.")

        for attribute_key, attribute_value in self.combination.items():
            if not attribute_key:
                raise ValueError(
                    "AttributeCombination.combination key cannot be empty."
                )
            if not attribute_value:
                raise ValueError(
                    "AttributeCombination.combination value cannot be empty."
                )

        # A "combination" of a single attribute is meaningless.
        if len(self.combination.keys()) < 2:
            raise ValueError(
                "AttributeCombination.combination must have at least two keys."
            )
@dataclass
class GeneratedAttributePostprocessingParams:
    """Postprocessing parameters for generated attributes."""

    id: str
    """ID to be used when referencing the postprocessing parameters during
    synthesis."""

    keep_original_text_attribute: bool = True
    """Whether to keep the original text of the generated attribute.

    If True, the original text will be returned as an attribute. If False, the
    original text will be discarded."""

    cut_prefix: Optional[str] = None
    """Cut off value before and including prefix."""

    cut_suffix: Optional[str] = None
    """Cut off value after and including suffix."""

    regex: Optional[str] = None
    """Regex to be used to pull out the value from the generated text."""

    strip_whitespace: bool = True
    """Whether to strip whitespace from the value."""

    added_prefix: Optional[str] = None
    """Prefix to be added to the value."""

    added_suffix: Optional[str] = None
    """Suffix to be added to the value."""

    def __post_init__(self):
        """Verifies/populates params.

        Raises:
            ValueError: If ``id`` is empty, or if ``regex`` is set but cannot be
                compiled. In the latter case the original compile error is
                attached as ``__cause__``.
        """
        if not self.id:
            raise ValueError("GeneratedAttributePostprocessingParams.id cannot be empty.")
        if self.regex:
            try:
                re.compile(self.regex)
            # Deliberately broad: any failure to compile the configured pattern
            # (re.error, TypeError for non-str input, overflow in repeat counts)
            # is reported uniformly as a config error.
            except Exception as e:
                raise ValueError(
                    f"Error compiling GeneratedAttributePostprocessingParams.regex: {e}"
                ) from e
@dataclass
class GeneratedAttribute:
    """An attribute whose value is produced by running a chat with the model."""

    id: str
    """ID under which the generated attribute is referenced during synthesis."""

    instruction_messages: list[TextMessage]
    """Messages that instruct the model how to generate this attribute."""

    postprocessing_params: Optional[GeneratedAttributePostprocessingParams] = None
    """Optional postprocessing applied to the generated value."""

    def __post_init__(self):
        """Reject an empty ID or instruction list, or an ID reused by postprocessing."""
        if not self.id:
            raise ValueError("GeneratedAttribute.id cannot be empty.")
        if not self.instruction_messages:
            raise ValueError("GeneratedAttribute.instruction_messages cannot be empty.")

        postprocessing = self.postprocessing_params
        # The postprocessing output needs its own attribute ID.
        if postprocessing and postprocessing.id == self.id:
            raise ValueError(
                "GeneratedAttribute.id and "
                "GeneratedAttributePostprocessingParams.id "
                "cannot be the same."
            )
class TransformationType(str, Enum):
    """Kinds of transformation a ``TransformationStrategy`` can describe."""

    STRING = "string"  # produce a single string
    LIST = "list"  # produce a list, one transform per element
    DICT = "dict"  # produce a dict, one transform per key
    CHAT = "chat"  # produce a chat conversation
@dataclass
class TransformationStrategy:
    """Discriminated union for transformation strategies that works with OmegaConf.

    ``type`` selects which ``*_transform`` field is meaningful. ``__post_init__``
    validates that field and resets all the others to ``None``.
    """

    type: TransformationType
    """The type of transformation strategy."""

    # For string transformations
    string_transform: Optional[str] = None
    """String transformation template (used when type=STRING)."""

    # For list transformations
    list_transform: Optional[list[str]] = None
    """List of transforms for each element (used when type=LIST)."""

    # For dict transformations
    dict_transform: Optional[dict[str, str]] = None
    """Mapping of dictionary keys to their transforms (used when type=DICT)."""

    # For chat transformations
    chat_transform: Optional[TextConversation] = None
    """Chat transform for chat messages (used when type=CHAT)."""

    def __post_init__(self):
        """Verifies/populates params based on the type.

        Raises:
            ValueError: If the transform selected by ``type`` is missing or
                empty, if a chat message's content is empty or not a string, or
                if ``type`` is not a known :class:`TransformationType`.
        """
        if self.type == TransformationType.STRING:
            if self.string_transform is None or self.string_transform == "":
                raise ValueError("string_transform cannot be empty when type=STRING")
            # Clear other fields
            self.list_transform = None
            self.dict_transform = None
            self.chat_transform = None
        elif self.type == TransformationType.LIST:
            if not self.list_transform:
                raise ValueError("list_transform cannot be empty when type=LIST")
            # Clear other fields
            self.string_transform = None
            self.dict_transform = None
            self.chat_transform = None
        elif self.type == TransformationType.DICT:
            if not self.dict_transform:
                raise ValueError("dict_transform cannot be empty when type=DICT")
            # Clear other fields
            self.string_transform = None
            self.list_transform = None
            self.chat_transform = None
        elif self.type == TransformationType.CHAT:
            # `not ...messages` also covers messages=None, which previously
            # crashed inside len() with a TypeError.
            if not self.chat_transform or not self.chat_transform.messages:
                raise ValueError("chat_transform cannot be empty when type=CHAT")
            for message in self.chat_transform.messages:
                content = message.content
                if not isinstance(content, str):
                    raise ValueError("chat_transform message content must be a string")
                if not content:
                    raise ValueError("chat_transform message content cannot be empty")
            # Clear other fields
            self.string_transform = None
            self.list_transform = None
            self.dict_transform = None
        else:
            # Previously an unknown type was silently accepted with no field
            # validated at all.
            valid_types = [t.value for t in TransformationType]
            raise ValueError(
                f"Unsupported transformation type: {self.type!r}. "
                f"Expected one of {valid_types}."
            )
@dataclass
class TransformedAttribute:
    """An attribute derived from existing attributes without calling the model."""

    id: str
    """ID under which the transformed attribute is referenced during synthesis."""

    transformation_strategy: TransformationStrategy
    """How the existing attributes are transformed."""

    def __post_init__(self):
        """Reject an empty ID or a strategy that is not a TransformationStrategy."""
        if not self.id:
            raise ValueError("TransformedAttribute.id cannot be empty.")

        strategy = self.transformation_strategy
        if isinstance(strategy, TransformationStrategy):
            return
        raise ValueError(
            "TransformedAttribute.transformation_strategy must be a "
            f"TransformationStrategy, got {type(strategy)}"
        )

    def get_strategy(self) -> TransformationStrategy:
        """Return the configured transformation strategy."""
        strategy = self.transformation_strategy
        return strategy
@dataclass
class GeneralSynthesisParams(BaseParams):
    """General synthesis parameters."""

    input_data: Optional[list[DatasetSource]] = None
    """Datasets whose rows and columns will be used in synthesis.

    Rows will be enumerated during sampling, and columns can be referenced as
    attributes when generating new attributes."""

    input_documents: Optional[list[DocumentSource]] = None
    """Documents to be used in synthesis.

    Documents will be enumerated during sampling, and both documents and document
    segments can be referenced as attributes when generating new attributes."""

    input_examples: Optional[list[ExampleSource]] = None
    """In-line examples to be used in synthesis.

    Examples will be enumerated during sampling, and their keys can be referenced
    as attributes when generating new attributes."""

    permutable_attributes: Optional[list[PermutableAttribute]] = None
    """Attributes to be varied across the dataset.

    Each attribute has a set of possible values, which will be randomly sampled
    according to their sample rates. The sample rates of an attribute's values
    must sum to <= 1.0; values that do not have a sample rate are each given an
    equal share of whatever probability remains.

    For example, if an attribute has 3 values with sample rates of 0.5, 0.3, and
    0.2, the total sample rate is 1.0: the first value will be sampled 50% of the
    time, the second 30% of the time, and the third 20% of the time. If the last
    two values had no sample rate, they would each be sampled 25% of the time, as
    (1.0 - 0.5) / 2 = 0.25."""

    combination_sampling: Optional[list[AttributeCombination]] = None
    """Sampling rates for combinations of attributes.

    Each combination is a dictionary of attribute IDs to their values. The sample
    rate is the probability of sampling this combination. The sample rates of all
    combinations must sum to <= 1.0."""

    generated_attributes: Optional[list[GeneratedAttribute]] = None
    """Attributes to be generated.

    Generated attributes are created by running a chat with the model. The chat
    is specified by a list of messages. The messages will be populated with
    attribute values specific to that data point. The output of the chat is the
    generated attribute.

    For example, if one of the previous attributes is "name", and you use the
    following instruction messages::

        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "How do you pronounce the name {name}?"}
        ]

    Then assuming your data point has a value of "Oumi" for the "name"
    attribute, the chat will be run with the following messages::

        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "How do you pronounce the name Oumi?"}
        ]

    The model's response to these messages will be the value of the generated
    attribute for that data point."""

    transformed_attributes: Optional[list[TransformedAttribute]] = None
    """Transformation of existing attributes.

    Transformed attributes involve no model interaction and instead are for the
    convenience of transforming parts of your data into a new form.

    For example, if you have "prompt" and "response" attributes, you can create a
    "chat" attribute by transforming the "prompt" and "response" attributes into
    a chat message::

        [
            {"role": "user", "content": "{prompt}"},
            {"role": "assistant", "content": "{response}"}
        ]
    """

    passthrough_attributes: Optional[list[str]] = None
    """When specified, will ONLY pass through these attributes in final output.

    If left unspecified, all attributes are saved. If an attribute is specified
    in passthrough_attributes but doesn't exist, it will be ignored."""

    def _check_attribute_ids(self, attribute_ids: set[str], id: str):
        """Record ``id`` in ``attribute_ids``, rejecting duplicates.

        The parameter keeps the name ``id`` (shadowing the builtin) so that
        existing keyword callers keep working.

        Raises:
            ValueError: If ``id`` is already present in ``attribute_ids``.
        """
        if id in attribute_ids:
            raise ValueError(
                f"GeneralSynthesisParams contains duplicate attribute IDs: {id}"
            )
        attribute_ids.add(id)

    def _check_dataset_source_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from dataset sources for uniqueness."""
        if self.input_data is None:
            return

        if len(self.input_data) == 0:
            raise ValueError("GeneralSynthesisParams.input_data cannot be empty.")

        for dataset_source in self.input_data:
            # Only explicitly mapped names are known here; unmapped dataset
            # columns cannot be checked without reading the data.
            if dataset_source.attribute_map:
                for new_key in dataset_source.attribute_map.values():
                    self._check_attribute_ids(all_attribute_ids, new_key)

    def _check_document_source_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from document sources for uniqueness.

        Both the document ID and, when segmentation is configured, the segment
        ID are attribute IDs that can be referenced during synthesis.
        """
        if self.input_documents is None:
            return

        if len(self.input_documents) == 0:
            raise ValueError("GeneralSynthesisParams.input_documents cannot be empty.")

        for document_source in self.input_documents:
            # The document itself is referenced by its ``id`` during synthesis,
            # so it shares the attribute-ID namespace. Previously only the
            # segment ID was checked, letting a document ID silently collide
            # with other attributes.
            self._check_attribute_ids(all_attribute_ids, document_source.id)
            if document_source.segmentation_params:
                seg_key = document_source.segmentation_params.id
                self._check_attribute_ids(all_attribute_ids, seg_key)

    def _check_example_source_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from example sources for uniqueness."""
        if self.input_examples is None:
            return

        if len(self.input_examples) == 0:
            raise ValueError("GeneralSynthesisParams.input_examples cannot be empty.")

        for example_source in self.input_examples:
            # ExampleSource guarantees a non-empty list whose examples all share
            # the first example's keys.
            example_keys = example_source.examples[0].keys()
            for new_key in example_keys:
                self._check_attribute_ids(all_attribute_ids, new_key)

    def _check_permutable_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from permutable attributes for uniqueness."""
        if self.permutable_attributes is None:
            return

        if len(self.permutable_attributes) == 0:
            raise ValueError(
                "GeneralSynthesisParams.permutable_attributes cannot be empty."
            )

        for permutable_attribute in self.permutable_attributes:
            attribute_id = permutable_attribute.id
            self._check_attribute_ids(all_attribute_ids, attribute_id)

    def _check_generated_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from generated attributes for uniqueness.

        Postprocessing parameters introduce an extra attribute ID of their own.
        """
        if self.generated_attributes is None:
            return

        if len(self.generated_attributes) == 0:
            raise ValueError(
                "GeneralSynthesisParams.generated_attributes cannot be empty."
            )

        for generated_attribute in self.generated_attributes:
            attribute_id = generated_attribute.id
            self._check_attribute_ids(all_attribute_ids, attribute_id)
            if generated_attribute.postprocessing_params:
                postprocessing_id = generated_attribute.postprocessing_params.id
                self._check_attribute_ids(all_attribute_ids, postprocessing_id)

    def _check_transformed_attribute_ids(self, all_attribute_ids: set[str]) -> None:
        """Check attribute IDs from transformed attributes for uniqueness."""
        if self.transformed_attributes is None:
            return

        if len(self.transformed_attributes) == 0:
            raise ValueError(
                "GeneralSynthesisParams.transformed_attributes cannot be empty."
            )

        for transformed_attribute in self.transformed_attributes:
            attribute_id = transformed_attribute.id
            self._check_attribute_ids(all_attribute_ids, attribute_id)

    def _check_combination_sampling_sample_rates(self) -> None:
        """Validate that the combination sample rates are <= 1.0.

        Raises:
            ValueError: If ``combination_sampling`` is an empty list, or if the
                rates sum to more than 1.0 (beyond floating-point slack).
        """
        if self.combination_sampling is None:
            return

        if len(self.combination_sampling) == 0:
            raise ValueError("GeneralSynthesisParams.combination_sampling cannot be empty.")

        total = sum(combination.sample_rate for combination in self.combination_sampling)
        # Summing decimal rates can exceed 1.0 slightly through floating-point
        # rounding alone; allow that slack rather than rejecting a valid config.
        if total > 1.0 + 1e-9:
            raise ValueError(
                "GeneralSynthesisParams.combination_sampling sample rates must be "
                "less than or equal to 1.0."
            )

    def _check_passthrough_attribute_ids(self) -> None:
        """Validate that passthrough attributes are non-empty when defined."""
        if self.passthrough_attributes is None:
            return

        if len(self.passthrough_attributes) == 0:
            raise ValueError(
                "GeneralSynthesisParams.passthrough_attributes cannot be empty."
            )