# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass, field
from typing import Any, Optional

from oumi.core.configs.params.base_params import BaseParams

@dataclass
class GkdParams(BaseParams):
"""Parameters for Generalized Knowledge Distillation (GKD) training.
GKD implements on-policy distillation where the student model generates outputs
and learns from teacher corrections in real-time during training.
Based on "On-Policy Distillation of Language Models: Learning from Self-Generated
Mistakes" (https://arxiv.org/abs/2306.13649).
Warning:
GKDTrainer is experimental and may be changed or removed in future versions.
"""
teacher_model_name_or_path: Optional[str] = None
"""Path or identifier of the teacher model.
This is required for GKD training. Can be a HuggingFace model ID or local path.
"""
teacher_model_init_kwargs: dict[str, Any] = field(
default_factory=lambda: {"dtype": "auto"}
)
"""Keyword arguments for loading the teacher model.
Passed to `AutoModelForCausalLM.from_pretrained(...)` when loading the teacher.
Common kwargs include `device_map`, `attn_implementation`, etc.
Defaults to {"dtype": "auto"} to allow the model to use the default dtype
of the teacher model.
"""
temperature: float = 0.9
"""Temperature for sampling during generation.
Higher values (e.g., 1.0) produce more diverse outputs, while lower values
(e.g., 0.5) produce more focused outputs. Must be in range (0.0, 1.0].
"""
lmbda: float = 0.5
"""Student data fraction (lambda parameter).
Controls the mix between on-policy (student-generated) and off-policy
(dataset) examples. Value of 0.5 means 50% on-policy, 50% off-policy.
Must be in range [0.0, 1.0].
"""
beta: float = 0.5
"""Jensen-Shannon Divergence interpolation coefficient.
Controls the balance in the JSD loss function:
- beta = 0.0: Uses KL divergence (teacher → student)
- beta = 0.5: Uses symmetric JSD
- beta = 1.0: Uses reverse KL divergence (student → teacher)
Must be in range [0.0, 1.0].
"""
max_new_tokens: int = 128
"""Maximum number of tokens to generate per prompt.
This controls how long the student model's completions can be during
on-policy generation.
"""
disable_dropout: bool = True
"""Whether to disable dropout in the student model during training.
Recommended to keep as `True` for more stable distillation.
"""
seq_kd: bool = False
"""Whether to use sequence-level knowledge distillation.
If `True`, uses sequence-level KD where the loss is computed at the sequence
level. If `False`, uses token-level KD (default and recommended).
"""
def __post_init__(self):
"""Validates GKD parameters."""
if self.teacher_model_name_or_path is not None:
if not isinstance(self.teacher_model_name_or_path, str):
raise TypeError(
"GkdParams.teacher_model_name_or_path must be a string. "
f"Actual type: {type(self.teacher_model_name_or_path)}"
)
if not self.teacher_model_name_or_path.strip():
raise ValueError(
"GkdParams.teacher_model_name_or_path cannot be empty."
)
if not (
math.isfinite(self.temperature)
and self.temperature > 0.0
and self.temperature <= 1.0
):
raise ValueError(
"GkdParams.temperature must be in range (0.0, 1.0]. "
f"Actual: {self.temperature}"
)
if not (math.isfinite(self.lmbda) and 0.0 <= self.lmbda <= 1.0):
raise ValueError(
f"GkdParams.lmbda must be in range [0.0, 1.0]. Actual: {self.lmbda}"
)
if not (math.isfinite(self.beta) and 0.0 <= self.beta <= 1.0):
raise ValueError(
f"GkdParams.beta must be in range [0.0, 1.0]. Actual: {self.beta}"
)
if self.max_new_tokens <= 0:
raise ValueError(
"GkdParams.max_new_tokens must be positive. "
f"Actual: {self.max_new_tokens}"
)
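
    # Illustrative failure modes (a sketch of the checks above):
    #   GkdParams(temperature=0.0)   -> ValueError (must be in (0.0, 1.0])
    #   GkdParams(lmbda=1.5)         -> ValueError (must be in [0.0, 1.0])
    #   GkdParams(max_new_tokens=0)  -> ValueError (must be positive)
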
def to_hf_trainer_kwargs(self) -> dict[str, Any]:
"""Converts GkdParams to TRL's GKDConfig kwargs.
Note:
The teacher_model_name_or_path is NOT passed to GKDConfig.
Instead, it's passed to the GKDTrainer constructor via train.py.
The teacher_model_init_kwargs goes into GKDConfig for TRL to use when
loading the teacher model.
Returns:
Dictionary of kwargs to pass to TRL's GKDConfig.
"""
result = {
"temperature": self.temperature,
"lmbda": self.lmbda,
"beta": self.beta,
"max_new_tokens": self.max_new_tokens,
"disable_dropout": self.disable_dropout,
"seq_kd": self.seq_kd,
}
        # Copy so we never mutate the dataclass field when injecting defaults.
        teacher_kwargs = dict(self.teacher_model_init_kwargs)
        teacher_kwargs.setdefault("dtype", "auto")
        result["teacher_model_init_kwargs"] = teacher_kwargs
return result
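

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the teacher checkpoint name is
    # hypothetical): build params and inspect the kwargs handed to TRL's
    # GKDConfig.
    params = GkdParams(
        teacher_model_name_or_path="org/teacher-model",
        lmbda=0.25,
        max_new_tokens=256,
    )
    print(params.to_hf_trainer_kwargs())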