# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importfunctoolsimportimportlib.utilimportosimportsysfromcollectionsimportnamedtuplefromenumimportEnum,autofrompathlibimportPathfromtypingimportAny,Callable,Optionalfromoumi.utils.loggingimportlogger
class RegistryKey(namedtuple("RegistryKey", ["name", "registry_type"])):
    """Hashable ``(name, registry_type)`` key whose name is stored lowercased."""

    def __new__(cls, name: str, registry_type: RegistryType):
        """Create a new RegistryKey instance.

        Args:
            name: The name of the registry key.
            registry_type: The type of the registry.

        Returns:
            A new RegistryKey instance with lowercase name.
        """
        # Lowercasing here makes every registry lookup case-insensitive.
        return super().__new__(cls, name.lower(), registry_type)


def _load_user_requirements(requirements_file: str):
    """Loads user-defined registry modules listed in a requirements file.

    Each non-empty line that does not start with ``#`` is treated as a path to
    a Python module. The module's parent directory is added to ``sys.path``
    (once) and the module is imported by its file stem, so that any
    registration code it runs at import time takes effect.

    NOTE(review): because modules are imported by stem name and the directory
    is appended to the END of ``sys.path``, a user module whose stem matches an
    already-imported or installed module (e.g. ``datasets.py``) will resolve to
    that other module instead — confirm whether this needs guarding.

    Args:
        requirements_file: Path to the requirements file, typically the value
            of the ``OUMI_EXTRA_DEPS_FILE`` environment variable.

    Raises:
        FileNotFoundError: If ``requirements_file`` does not exist.
        ImportError: If any listed module fails to import.
    """
    logger.info(f"Loading user-defined registry from: {requirements_file}")
    logger.info(
        "This value can be set using the OUMI_EXTRA_DEPS_FILE "
        "environment variable."
    )
    requirements_path = Path(requirements_file)
    if not requirements_path.exists():
        logger.error(f"OUMI_EXTRA_DEPS_FILE file not found: {requirements_file}")
        raise FileNotFoundError(
            f"OUMI_EXTRA_DEPS_FILE file not found: {requirements_file}"
        )

    import_count = 0
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(requirements_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and comment lines.
            if not line or line.startswith("#"):
                continue
            import_count += 1
            import_path = Path(line)
            logger.info(f"Loading user-defined registry module: {import_path}")
            mod_name = import_path.stem
            module_dir = str(import_path.parent)
            # Only add each directory once; several modules commonly share one
            # directory and repeated entries would just bloat sys.path.
            if module_dir not in sys.path:
                sys.path.append(module_dir)
            try:
                importlib.import_module(mod_name)
            except Exception as e:
                logger.error(
                    "Failed to load a user-defined module in "
                    f"OUMI_EXTRA_DEPS_FILE: {line}"
                )
                raise ImportError(
                    f"Failed to load user-defined module: {line}"
                ) from e
    logger.info(f"Loaded {import_count} user-defined registry modules.")


def _register_dependencies(cls_function):
    """Decorator to ensure core dependencies are added to the Registry.

    On the first call of any decorated method on a given instance, this
    imports the core oumi packages and, if ``OUMI_EXTRA_DEPS_FILE`` is set, the
    user modules it lists. It then runs the wrapped method. Later calls go
    straight to the wrapped method.
    """

    @functools.wraps(cls_function)
    def wrapper(self, *args, **kwargs):
        if not self._initialized:
            # Immediately set the initialized flag to avoid infinite recursion:
            # the imported modules presumably register records through
            # decorated registry methods, which re-enter this wrapper.
            self._initialized = True
            # Import all core dependencies.
            import oumi.datasets  # noqa: F401
            import oumi.judges  # noqa: F401
            import oumi.launcher  # noqa: F401
            import oumi.models  # noqa: F401

            # Import user-defined dependencies.
            user_req_file = os.environ.get("OUMI_EXTRA_DEPS_FILE", None)
            if user_req_file:
                _load_user_requirements(user_req_file)
        return cls_function(self, *args, **kwargs)

    return wrapper
class Registry:
    """Registry of objects keyed by case-insensitive name and `RegistryType`.

    Methods decorated with `_register_dependencies` trigger a one-time import
    of the core oumi packages and of any modules listed in
    ``OUMI_EXTRA_DEPS_FILE`` the first time one of them is called on an
    instance. This lets records that are registered at import time be found
    by the very first lookup.
    """

    # Class-level default; `_register_dependencies` sets it to True on the
    # instance once the dependency imports have been triggered.
    _initialized: bool = False

    def __init__(self):
        """Initializes the class Registry."""
        self._registry = {}

    #
    # Public functions
    #
    @_register_dependencies
    def contains(self, name: str, type: RegistryType) -> bool:
        """Indicates whether a record exists in the registry."""
        return self._contains(RegistryKey(name, type))

    @_register_dependencies
    def clear(self) -> None:
        """Clears the registry.

        Only the stored records are dropped. The instance stays marked as
        initialized, so dependencies are not re-imported afterwards.
        """
        self._registry = {}

    @_register_dependencies
    def register(self, name: str, type: RegistryType, value: Any) -> None:
        """Registers a new record.

        Args:
            name: Record name (matched case-insensitively).
            type: Record type.
            value: The object to store.

        Raises:
            ValueError: If a record with the same name and type already exists.
        """
        registry_key = RegistryKey(name, type)
        if self._contains(registry_key):
            current_value = self._registry[registry_key]
            raise ValueError(
                f"Registry: `{name}` of `{type}` "
                f"is already registered as `{current_value}`."
            )
        self._registry[registry_key] = value

    @_register_dependencies
    def get(
        self,
        name: str,
        type: RegistryType,
    ) -> Optional[Callable]:
        """Gets a record by name and type, or None if it is not registered."""
        registry_key = RegistryKey(name, type)
        return self._registry.get(registry_key)

    @_register_dependencies
    def get_all(self, type: RegistryType) -> dict:
        """Gets all records of a specific type, keyed by their lowercased name."""
        return {
            key.name: value
            for key, value in self._registry.items()
            if key.registry_type == type
        }

    #
    # Convenience public function wrappers.
    #
    def get_model(self, name: str) -> Optional[Callable]:
        """Gets a record that corresponds to a registered model."""
        return self.get(name, RegistryType.MODEL)

    def get_model_config(self, name: str) -> Optional[Callable]:
        """Gets a record that corresponds to a registered model config."""
        return self.get(name, RegistryType.MODEL_CONFIG)

    def get_metrics_function(self, name: str) -> Optional[Callable]:
        """Gets a record that corresponds to a registered metrics function."""
        return self.get(name, RegistryType.METRICS_FUNCTION)

    def get_judge_config(self, name: str) -> Optional[Callable]:
        """Gets a record that corresponds to a registered judge config."""
        return self.get(name, RegistryType.JUDGE_CONFIG)

    def get_dataset(
        self, name: str, subset: Optional[str] = None
    ) -> Optional[Callable]:
        """Gets a record that corresponds to a registered dataset.

        Args:
            name: Dataset name.
            subset: Optional subset name. If given, ``"{name}/{subset}"`` is
                looked up first, and the bare ``name`` is used as a fallback.

        Returns:
            The registered dataset record, or None if neither key exists.
        """
        if subset:
            # If a subset is provided, first check for subset-specific dataset.
            # If not found, try to get the dataset directly.
            dataset_cls = self.get(f"{name}/{subset}", RegistryType.DATASET)
            if dataset_cls is not None:
                return dataset_cls
        return self.get(name, RegistryType.DATASET)

    #
    # Private functions
    #
    def _contains(self, key: RegistryKey) -> bool:
        """Indicates whether a record already exists in the registry."""
        return key in self._registry

    #
    # Magic methods
    #
    @_register_dependencies
    def __getitem__(self, args: tuple[str, RegistryType]) -> Callable:
        """Gets a record by name and type via ``registry[name, type]``.

        Like the other public accessors, this triggers the one-time dependency
        import. That way ``registry[name, type]`` also works when it is the
        first access to the registry.

        Raises:
            ValueError: If `args` is not a tuple of length 2.
            KeyError: If no record with that name and type exists.
        """
        if not isinstance(args, tuple) or len(args) != 2:
            raise ValueError(
                "Expected a tuple of length 2 with the first element being the name "
                "and the second element being the type."
            )
        name, type = args
        registry_key = RegistryKey(name, type)
        if not self._contains(registry_key):
            raise KeyError(f"Registry: `{name}` of `{type}` does not exist.")
        return self._registry[registry_key]

    def __repr__(self) -> str:
        """Defines how this class is properly printed.

        Note: this method is not decorated with `_register_dependencies`, so
        it lists only the records registered so far.
        """
        return "\n".join(f"{key}: {value}" for key, value in self._registry.items())
# Module-level registry instance that the `register*` decorators in this
# module write into.
REGISTRY: Registry = Registry()
def register(registry_name: str, registry_type: RegistryType) -> Callable:
    """Builds a decorator that stores its target in the global Oumi registry.

    Args:
        registry_name: Name under which the decorated object is stored.
        registry_type: Registry category that the decorated object belongs to.

    Returns:
        A decorator that registers its argument and returns it unchanged.
    """

    def _decorate(target):
        """Adds `target` to `REGISTRY` under the captured name and type."""
        REGISTRY.register(name=registry_name, type=registry_type, value=target)
        return target

    return _decorate
def register_dataset(registry_name: str, subset: Optional[str] = None) -> Callable:
    """Returns a decorator that registers a dataset in the Oumi global registry.

    The decorated object is stored with type ``RegistryType.DATASET``.

    Args:
        registry_name: The name that the dataset should be registered with.
        subset: Optional subset name. If provided (non-empty), the dataset is
            registered under ``"{registry_name}/{subset}"`` instead of the bare
            ``registry_name``. This lets a subset-specific dataset be looked up
            before the generic one.

    Returns:
        Decorator function that registers the target object and returns it
        unchanged.
    """
    # The key does not depend on the decorated object, so compute it once.
    full_name = f"{registry_name}/{subset}" if subset else registry_name

    def decorator_register(obj):
        """Decorator to register its target `obj`."""
        REGISTRY.register(name=full_name, type=RegistryType.DATASET, value=obj)
        return obj

    return decorator_register
def register_cloud_builder(registry_name: str) -> Callable:
    """Creates a decorator that puts a cloud builder into the Oumi global registry.

    Apply the returned decorator to a cloud builder function: a function that
    takes no arguments and returns an instance of a class implementing the
    `BaseCloud` interface. It is stored with type ``RegistryType.CLOUD``.

    Args:
        registry_name: Name under which the builder is stored.

    Returns:
        A decorator that registers the builder and hands it back unchanged.
    """

    def _decorate(builder):
        """Stores `builder` in `REGISTRY` as a cloud builder."""
        REGISTRY.register(name=registry_name, type=RegistryType.CLOUD, value=builder)
        return builder

    return _decorate
def register_judge(registry_name: str) -> Callable:
    """Returns a decorator that registers a judge config in the Oumi registry.

    This decorator is used to register judge configuration in the global
    registry, under type ``RegistryType.JUDGE_CONFIG``. A judge configuration
    function typically returns a JudgeConfig object that defines the
    parameters and attributes for a specific judge.

    Args:
        registry_name: The name under which the judge configuration should be
            registered.

    Returns:
        Callable: A decorator function that registers the target judge
        configuration and returns it unchanged.

    Example:
        .. code-block:: python

            @register_judge("my_custom_judge")
            def my_judge_config() -> JudgeConfig:
                return JudgeConfig(...)
    """

    def decorator_register(obj):
        """Decorator to register its target judge configuration."""
        REGISTRY.register(
            name=registry_name, type=RegistryType.JUDGE_CONFIG, value=obj
        )
        return obj

    return decorator_register