Source code for oumi.core.synthesis.synthesis_pipeline
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from pathlib import Path
from typing import Any

from oumi.core.configs.synthesis_config import SynthesisConfig
from oumi.core.synthesis.attribute_synthesizer import AttributeSynthesizer
from oumi.core.synthesis.attribute_transformation import AttributeTransformer
from oumi.core.synthesis.data_synthesizer import DataSynthesizer
from oumi.core.synthesis.dataset_planner import DatasetPlanner
from oumi.utils.io_utils import save_jsonlines
from oumi.utils.logging import logger


class SynthesisPipeline:
    """Pipeline for synthesizing a dataset."""

    def __init__(self, config: SynthesisConfig):
        """Initialize the synthesis pipeline."""
        self._config = config
        attribute_synthesizer = AttributeSynthesizer(
            config.strategy_params, config.inference_config
        )
        self._attribute_transformer = AttributeTransformer(config.strategy_params)
        self._dataset_planner = DatasetPlanner()
        self._data_synthesizer = (
            DataSynthesizer(
                config.strategy_params.generated_attributes,
                attribute_synthesizer,
            )
            if config.strategy_params.generated_attributes
            else None
        )

    def synthesize(self) -> list[dict[str, Any]]:
        """Synthesize a dataset."""
        # Populate the dataset plan with column values for each non-generated
        # attribute.
        logger.info(
            f"Loading dependencies to synthesize dataset with "
            f"{self._config.num_samples} samples"
        )
        dataset = self._dataset_planner.plan(
            self._config.strategy_params,
            self._config.num_samples,
        )

        # Synthesize the generated attributes.
        if self._data_synthesizer:
            logger.info("Synthesizing generated attributes")
            dataset = self._data_synthesizer.synthesize(dataset)

        # Add the transformed attributes to the dataset.
        if self._config.strategy_params.transformed_attributes:
            logger.info("Adding transformed attributes")
            dataset = self._attribute_transformer.transform(dataset)

        # If passthrough attributes are specified, keep only those attributes.
        if self._config.strategy_params.passthrough_attributes:
            logger.info("Keeping passthrough attributes")
            dataset = self._passthrough_attributes(dataset)

        # Save the dataset to the output path, if one was specified.
        if self._config.output_path:
            logger.info("Saving dataset")
            self._save_dataset(dataset)

        logger.info("Synthesis complete")
        return dataset
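    # Shape of the returned dataset (illustrative, not from the original
    # source; the actual keys depend on the configured attributes): one dict
    # per sample, e.g.
    #
    #   [{"topic": "...", "question": "...", "answer": "..."}, ...]
    #
    # "topic", "question", and "answer" are hypothetical attribute names used
    # only for illustration.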

    def _passthrough_attributes(
        self,
        dataset: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Keep only the passthrough attributes in the dataset.

        Supports both simple keys and bracket notation for nested access:

        - Simple: "conversation" -> sample["conversation"]
        - Bracket: "examples[0].field" -> sample["examples"][0]["field"]
        """
        if not self._config.strategy_params.passthrough_attributes:
            return dataset

        passthrough_attributes = self._config.strategy_params.passthrough_attributes

        # Separate simple keys from bracket notation paths.
        simple_keys = set()
        bracket_paths = []
        for attr in passthrough_attributes:
            if "[" in attr and "]" in attr:
                bracket_paths.append(attr)
            else:
                simple_keys.add(attr)

        result = []
        for sample in dataset:
            filtered_sample = {}

            # Add simple passthrough attributes.
            for key in simple_keys:
                if key in sample:
                    filtered_sample[key] = sample[key]

            # Add bracket notation attributes.
            for path in bracket_paths:
                try:
                    value = self._extract_nested_value(sample, path)
                    # Store using the full path as the key.
                    filtered_sample[path] = value
                except (KeyError, IndexError, ValueError):
                    # Skip if the path doesn't exist in the sample.
                    pass

            result.append(filtered_sample)

        return result
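    # Illustrative example (not part of the original source). With
    # passthrough_attributes = ["conversation", "examples[0].field"] and
    #
    #   sample = {"conversation": [...], "examples": [{"field": "abc"}], "tmp": 1}
    #
    # the filtered sample keeps the simple key unchanged and stores the nested
    # value under the full path string, while "tmp" is dropped:
    #
    #   {"conversation": [...], "examples[0].field": "abc"}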

    def _extract_nested_value(self, sample: dict[str, Any], path: str) -> Any:
        """Extract a value from a nested structure using bracket notation.

        Args:
            sample: The sample dictionary to extract from.
            path: Path like "examples[0].field" or "data[1].nested.value"

        Returns:
            The extracted value.

        Raises:
            KeyError: If a key doesn't exist.
            IndexError: If an index is out of range.
            ValueError: If the path format is invalid.
        """
        # Parse the path: "examples[0].field" -> ["examples", "[0]", "field"]
        # Match: word, [index], or .word
        pattern = r"([^\[\].]+|\[\d+\])"
        parts = re.findall(pattern, path)

        current: Any = sample
        for part in parts:
            if part.startswith("[") and part.endswith("]"):
                # Array index access
                index = int(part[1:-1])
                if not isinstance(current, list):
                    raise ValueError(
                        f"Cannot index into non-list type {type(current).__name__}"
                    )
                current = current[index]
            else:
                # Dictionary key access
                if isinstance(current, dict):
                    current = current[part]
                else:
                    raise ValueError(
                        f"Cannot access key '{part}' on non-dict type "
                        f"{type(current).__name__}"
                    )

        return current
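    # Illustrative example (not part of the original source):
    #
    #   self._extract_nested_value({"examples": [{"field": "abc"}]}, "examples[0].field")
    #
    # parses the path into ["examples", "[0]", "field"] and returns "abc". A
    # missing key surfaces as KeyError and an out-of-range index as IndexError,
    # both of which _passthrough_attributes silently skips.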

    def _save_dataset(self, dataset: list[dict[str, Any]]):
        """Save the dataset to the output path."""
        if not self._config.output_path:
            raise ValueError("SynthesisConfig.output_path is not specified.")

        path_str = self._config.output_path
        path = Path(path_str)
        path.parent.mkdir(parents=True, exist_ok=True)

        if path.suffix == ".jsonl":
            save_jsonlines(path, dataset)
        else:
            raise ValueError(
                f"Unsupported output path: {path_str}. "
                "Only .jsonl output files are supported."
            )