Source code for oumi.core.configs.params.grpo_params
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass, field, fields
from typing import Any, Optional

from oumi.core.configs.params.base_params import BaseParams
@dataclass
class GrpoParams(BaseParams):
    model_init_kwargs: dict[str, Any] = field(default_factory=dict)
    """Keyword arguments for `AutoModelForCausalLM.from_pretrained(...)`"""

    max_prompt_length: Optional[int] = None
    """Maximum length of the prompt.

    If the prompt is longer than this value, it will be truncated left.
    If unspecified (`None`), defaults to 512.
    """

    max_completion_length: Optional[int] = None
    """Maximum length of the generated completion.

    If unspecified (`None`), defaults to 256.
    """

    num_generations: Optional[int] = None
    """Number of generations per prompt to sample.

    The global batch size (num_processes * per_device_batch_size) must be
    divisible by this value. If unspecified (`None`), defaults to 8.
    """

    temperature: float = 0.9
    """Temperature for sampling.

    The higher the temperature, the more random the completions.
    """

    remove_unused_columns: bool = False
    """Whether to only keep the column `"prompt"` in the dataset.

    If you use a custom reward function that requires any column other than
    `"prompts"` and `"completions"`, you should set it to `False`.
    """

    repetition_penalty: Optional[float] = 1.0
    """Float that penalizes new tokens if they appear in the prompt/response so far.

    Values > 1.0 encourage the model to use new tokens, while values < 1.0
    encourage the model to repeat tokens.
    """

    use_vllm: bool = False
    """Whether to use vLLM for generating completions.

    If set to `True`, ensure that a GPU is kept unused for training,
    as vLLM will require one for generation.
    """

    vllm_mode: Optional[str] = None
    """The mode to use for vLLM generation ("colocate" or "server").

    If set to `None`, defaults to "server".

    Server mode means that vLLM is running on a separate server that the
    trainer will communicate with. It requires the server to be started with
    `trl vllm-serve` beforehand.

    Colocate mode means that vLLM will run in the same process as the trainer
    and share GPUs. While this is simpler as it doesn't require a separate
    server, vLLM will contend with the trainer for GPU resources.
    """

    vllm_gpu_memory_utilization: float = 0.9
    """Ratio (between 0 and 1) of GPU memory to reserve.

    Fraction of VRAM reserved for the model weights, activations, and KV cache
    on the device dedicated to generation powered by vLLM. Higher values will
    increase the KV cache size and thus improve the model's throughput.
    However, if the value is too high, it may cause out-of-memory (OOM) errors
    during initialization.
    """

    epsilon: float = 0.2
    """Epsilon value for clipping the relative probability in the loss.

    For example, if epsilon is 0.2, then the new probability can only differ
    from the old probability by a factor of x0.8-1.2.
    """

    log_completions: bool = False
    """Whether to log prompt and completion pairs every `logging_steps` steps."""
    def __post_init__(self):
        """Verifies params."""
        if self.max_prompt_length is not None and self.max_prompt_length <= 0:
            raise ValueError(
                "GrpoParams.max_prompt_length must be positive. "
                f"Actual: {self.max_prompt_length}"
            )
        if self.max_completion_length is not None and self.max_completion_length <= 0:
            raise ValueError(
                "GrpoParams.max_completion_length must be positive. "
                f"Actual: {self.max_completion_length}"
            )
        if self.num_generations is not None and self.num_generations <= 0:
            raise ValueError(
                "GrpoParams.num_generations must be positive. "
                f"Actual: {self.num_generations}"
            )
        if not (
            math.isfinite(self.temperature)
            and self.temperature >= 0.0
            and self.temperature <= 1.0
        ):
            raise ValueError(
                "GrpoParams.temperature must be within [0.0, 1.0] range. "
                f"Actual: {self.temperature}"
            )
    def to_hf_trainer_kwargs(self) -> dict[str, Any]:
        """Converts GrpoParams to TRL's GRPOConfig kwargs."""
        result = {}
        if len(self.model_init_kwargs) > 0:
            result["model_init_kwargs"] = self.model_init_kwargs
        if self.max_prompt_length is not None:
            result["max_prompt_length"] = self.max_prompt_length
        if self.max_completion_length is not None:
            result["max_completion_length"] = self.max_completion_length
        if self.num_generations is not None:
            result["num_generations"] = self.num_generations
        already_processed_keys: set[str] = set(
            {
                "model_init_kwargs",
                "max_prompt_length",
                "max_completion_length",
                "num_generations",
            }
        )
        # Copy the majority of fields that aren't special-cased.
        for param in fields(self):
            if param.name.startswith("vllm_") or param.name in already_processed_keys:
                continue
            result[param.name] = getattr(self, param.name)
        if self.use_vllm:
            # Return vLLM params only if vLLM is enabled.
            if self.vllm_mode is not None:
                result["vllm_mode"] = self.vllm_mode
            result["vllm_gpu_memory_utilization"] = self.vllm_gpu_memory_utilization
        return result
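Usage sketch (not part of this module): a minimal, hypothetical example of building a GrpoParams instance and forwarding the resulting kwargs to TRL's GRPOConfig. It assumes a TRL version whose GRPOConfig accepts all of the keys produced by `to_hf_trainer_kwargs()`; the chosen values are illustrative only.

# Hypothetical usage sketch, assuming TRL's GRPOConfig accepts these kwargs.
from trl import GRPOConfig

from oumi.core.configs.params.grpo_params import GrpoParams

params = GrpoParams(
    max_prompt_length=512,
    max_completion_length=256,
    num_generations=8,
    temperature=0.7,
    use_vllm=True,
    vllm_mode="colocate",
)
# __post_init__ has already validated the values above; convert to kwargs
# and pass them through to the TRL trainer config.
grpo_config = GRPOConfig(output_dir="output/grpo", **params.to_hf_trainer_kwargs())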